In [None]:

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

import torch
device = torch.device("cuda")
import torch.nn.functional as F
from torchvision.models import resnet34, resnet50
from torchvision.models.feature_extraction import get_graph_node_names
from torchvision.models.feature_extraction import create_feature_extractor
from torch.utils import data
from torchvision import datasets
from torchvision import transforms
import torch.nn as nn
from torch.optim import Adam, SGD, RMSprop
from torchvision.models import resnet34, resnet50
import numpy as np
from large_margin import LargeMarginLoss

# Seed PyTorch's global RNG so model weight initialisation and the
# torch.randn batch further down are reproducible across runs.
# NOTE(review): NumPy / Python `random` are not seeded here.
torch.manual_seed(0)

def _init_resnet_50(output_size, pretrained = False, features_hook = None):
    """Build a torchvision ResNet-50 with an ``output_size``-way linear head.

    Args:
        output_size: Number of outputs of the replacement ``fc`` layer.
        pretrained: Forwarded unchanged to ``torchvision.models.resnet50``.
        features_hook: Optional forward hook. When given, it is registered on
            each of the four residual stages ``layer1`` .. ``layer4``.

    Returns:
        The constructed model.
    """
    backbone = resnet50(pretrained=pretrained)
    # Swap the classification head; in_features is 2048 for ResNet-50.
    backbone.fc = torch.nn.Linear(backbone.fc.in_features, output_size)

    if features_hook is None:
        return backbone

    # These are the only top-level children with these names, which are
    # exactly the modules a `named_modules()` scan for them would find.
    for stage_name in ('layer1', 'layer2', 'layer3', 'layer4'):
        getattr(backbone, stage_name).register_forward_hook(features_hook)
    return backbone

def create_pretrained_model(architecture, n_classes, features_hook = None):
    """Return an ImageNet-pretrained backbone with an ``n_classes``-way head.

    Only architectures whose name contains ``'resnet50'`` are supported.

    Raises:
        NotImplementedError: for any other architecture string.
    """
    if 'resnet50' not in architecture:
        raise NotImplementedError()
    return _init_resnet_50(n_classes, True, features_hook)
    
def create_model(architecture, n_classes, features_hook = None):
    """Return a randomly initialised backbone with an ``n_classes``-way head.

    Only architectures whose name contains ``'resnet50'`` are supported.

    Raises:
        NotImplementedError: for any other architecture string.
    """
    if 'resnet50' not in architecture:
        raise NotImplementedError()
    return _init_resnet_50(n_classes, False, features_hook)

class FeatureExtractor(torch.nn.Module):
    """Wrap a classifier and collect intermediate feature maps via hooks.

    The wrapped model is built by ``create_pretrained_model`` when
    ``architecture`` contains ``'pretrained'`` and by ``create_model``
    otherwise. Both receive :meth:`feature_hook` and register it on the layers
    whose outputs should be captured.

    Args:
        architecture: Architecture name, e.g. ``'resnet50'``.
        n_classes: Number of outputs of the final classification layer.
    """

    def __init__(self, architecture, n_classes = None):
        super().__init__()
        # Filled by feature_hook during a forward pass; reset in forward().
        self._features = []
        if 'pretrained' in architecture:
            self.model = create_pretrained_model(
                architecture,
                n_classes,
                features_hook=self.feature_hook)
        else:
            self.model = create_model(
                architecture,
                n_classes,
                features_hook=self.feature_hook)

    def feature_hook(self, module, input, output):
        """Forward hook: record the hooked module's output.

        The full output tensor is stored. Indexing a batched tensor with
        ``[0]`` would keep only the first sample. It would also create a new
        tensor that is not on the path to the logits, so gradients with
        respect to it could not be taken. Tuple/list outputs are unwrapped to
        their first element.
        """
        if isinstance(output, (tuple, list)):
            output = output[0]
        self._features.append(output)

    def forward(self, x):
        """Run the wrapped model and return ``(logits, features)``.

        ``features`` contains only the outputs captured during this call, in
        the order the hooked layers executed.
        """
        # Fresh list per call. Otherwise the list would grow on every forward,
        # mix feature maps from earlier batches into the result, and keep their
        # autograd graphs alive. A new list (rather than .clear()) also leaves
        # lists returned by earlier calls intact.
        self._features = []
        logits = self.model(x)
        return logits, self._features

# NOTE(review): `my_module` is not used in the visible code. It builds a second,
# randomly initialised ResNet-50 (memory cost) and advances the torch RNG
# before `net` is created. Consider removing it.
my_module = FeatureExtractor('resnet50', 100)

# CIFAR-100 training split, downloaded to ./data if missing, converted to
# tensors and normalised with the standard ImageNet channel mean/std.
# Batches of 256, shuffled each epoch; the last partial batch is dropped so
# every batch has exactly 256 samples.
# NOTE(review): ImageNet statistics are the usual choice for a *pretrained*
# backbone, but the models built here are not pretrained — confirm intended.
# NOTE(review): `train_loader` is not iterated in the visible code.
train_loader = data.DataLoader(
        datasets.CIFAR100('./data', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229, 0.224, 0.225])
                       ])),
        batch_size=256, shuffle=True, drop_last=True)

# Large-margin loss from the project module `large_margin` (not shown here).
# Check the argument semantics there. In this cell it is only referenced by a
# commented-out call further down.
lm = LargeMarginLoss(
    gamma=10000,
    alpha_factor=4,
    top_k=1,
    dist_norm=np.inf
)

# Network under test: a non-pretrained ResNet-50 with a 100-way head, wrapped
# so that its forward returns (logits, captured feature maps). Moved to the
# CUDA device selected at the top of the file.
net = FeatureExtractor('resnet50', 100)
net.to(device)

# Module-level capture of each ResNet stage's output, so the gradient code
# below can differentiate with respect to it. Each hook overwrites its global
# with the output of the most recent forward pass.
layer_1_output = layer_2_output = layer_3_output = layer_4_output = None


def forward_hook_1(module, input, output):
    """Store the output of ``net.model.layer1`` in ``layer_1_output``."""
    global layer_1_output
    layer_1_output = output


def forward_hook_2(module, input, output):
    """Store the output of ``net.model.layer2`` in ``layer_2_output``."""
    global layer_2_output
    layer_2_output = output


def forward_hook_3(module, input, output):
    """Store the output of ``net.model.layer3`` in ``layer_3_output``."""
    global layer_3_output
    layer_3_output = output


def forward_hook_4(module, input, output):
    """Store the output of ``net.model.layer4`` in ``layer_4_output``."""
    global layer_4_output
    layer_4_output = output


for _stage, _hook in (
    (net.model.layer1, forward_hook_1),
    (net.model.layer2, forward_hook_2),
    (net.model.layer3, forward_hook_3),
    (net.model.layer4, forward_hook_4),
):
    _stage.register_forward_hook(_hook)
# Do not leave the loop variables behind as extra module globals.
del _stage, _hook

# Debug pass: run one random batch through `net` and take the gradient of the
# probability margin (true-class prob minus the top_k strongest competitors)
# with respect to each hooked ResNet stage output.
# NOTE(review): presumably mirrors the gradient step inside LargeMarginLoss
# (see the commented-out `lm(...)` call below) — verify against large_margin.
X = torch.randn(256, 3, 32, 32)
out, features = net(X.to(device))

# Synthetic labels 0..n_classes-1 repeated over the batch. num_classes is
# passed explicitly because F.one_hot otherwise infers it from the largest
# label present. That would give a width mismatch with `out` for any batch
# that does not contain the last class.
one_hot = F.one_hot(torch.arange(0, X.shape[0]) % out.shape[1],
                    num_classes=out.shape[1])

one_hot = one_hot.to(device)

top_k = 1

prob = F.softmax(out, dim=1)
correct_prob = prob * one_hot
# Probability of the true class, shape (B, 1).
correct_prob = torch.sum(correct_prob, dim=1, keepdim=True)
# Probabilities of every other class (true class zeroed out).
other_prob = prob * (1.0 - one_hot)

if top_k > 1:
    topk_prob, _ = other_prob.topk(top_k, dim=1)
else:
    topk_prob, _ = other_prob.max(dim=1, keepdim=True)

# Margin against each of the top_k competitors, shape (B, top_k).
diff_prob = correct_prob - topk_prob

loss = torch.empty(0, device=out.device)

# One entry per feature map: the gradients of diff_prob[:, i] w.r.t. that map,
# stacked along dim 1 -> shape (B, top_k, *feature_map.shape[1:]).
# Previously these gradients were computed and then thrown away.
diff_grads = []
for feature_map in [layer_1_output, layer_2_output, layer_3_output, layer_4_output]:
    grads_per_k = []
    for i in range(top_k):
        # grad_outputs follows diff_prob's dtype (a hard-coded float32 would
        # mismatch under mixed precision). retain_graph keeps the graph
        # alive for the next (i, feature_map) pair.
        grad_i, = torch.autograd.grad(diff_prob[:, i], feature_map,
            grad_outputs=torch.ones_like(diff_prob[:, i]),
            retain_graph=True)
        grads_per_k.append(grad_i)
    diff_grads.append(torch.stack(grads_per_k, dim=1))

# logits, onehot_labels, feature_maps
# loss = lm(out, one_hot.to(device), features)

# grad1 = torch.autograd.grad(out, layer_1_output, grad_outputs=torch.ones_like(out), retain_graph=True)
# grad2 = torch.autograd.grad(out, layer_2_output, grad_outputs=torch.ones_like(out), retain_graph=True)
# grad3 = torch.autograd.grad(out, layer_3_output, grad_outputs=torch.ones_like(out), retain_graph=True)
# grad4 = torch.autograd.grad(out, layer_4_output, grad_outputs=torch.ones_like(out), retain_graph=True)

# print(grad4)