In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets
from torch.utils.data import DataLoader
from torch.autograd import Variable
import numpy as np
import resnet
from dataset import get_dataset, get_handler
from torchvision import transforms
from operator import add

# Placeholder logits: 100 samples x 4 classes, drawn uniformly at random.
# (Drawn before the model is built so the order of random-number consumption is fixed.)
outputs = torch.rand(100, 4, requires_grad=True)

# 4-class ResNet-18; its parameters are frozen into a tuple that later cells
# pass to torch.autograd.grad.
net = resnet.ResNet18(num_classes=4)
parameters = tuple(net.parameters())

# Log-probabilities / probabilities of the placeholder logits, and their (N, C) shape.
log_probs = F.log_softmax(outputs, dim=1)
probs = F.softmax(outputs, dim=1).to('cpu')
N, C = log_probs.shape

# One zero-initialised accumulator per parameter tensor, keyed by its position
# in `parameters`.
sq_grads_expect_orig = {
    layer_idx: np.zeros(param.shape)
    for layer_idx, param in enumerate(parameters)
}

# CIFAR10 train/test arrays plus the dataset wrapper ("handler") used to build loaders.
X_tr, Y_tr, X_te, Y_te = get_dataset('CIFAR10', 'data')
handler = get_handler('CIFAR10')
dim = np.shape(X_tr)[1:]  # shape of X_tr without its leading axis
# Per-dataset experiment settings: number of epochs, input transform(s),
# DataLoader keyword arguments for train/test, and optimizer keyword arguments.
args = {
    'MNIST': {
        'n_epoch': 10,
        'transform': transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ]),
        'loader_tr_args': {'batch_size': 64, 'num_workers': 1},
        'loader_te_args': {'batch_size': 1000, 'num_workers': 1},
        'optimizer_args': {'lr': 0.01, 'momentum': 0.5},
    },
    'FashionMNIST': {
        'n_epoch': 10,
        'transform': transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ]),
        'loader_tr_args': {'batch_size': 64, 'num_workers': 1},
        'loader_te_args': {'batch_size': 1000, 'num_workers': 1},
        'optimizer_args': {'lr': 0.01, 'momentum': 0.5},
    },
    'SVHN': {
        'n_epoch': 20,
        'transform': transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4377, 0.4438, 0.4728), (0.1980, 0.2010, 0.1970)),
        ]),
        'loader_tr_args': {'batch_size': 64, 'num_workers': 1},
        'loader_te_args': {'batch_size': 1000, 'num_workers': 1},
        'optimizer_args': {'lr': 0.01, 'momentum': 0.5},
    },
    'CIFAR10': {
        'n_epoch': 3,
        # Training-time transform: random crop + horizontal flip, then normalise.
        'transform': transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
        ]),
        'loader_tr_args': {'batch_size': 128, 'num_workers': 1},
        'loader_te_args': {'batch_size': 100, 'num_workers': 1},  # change back to 1000
        'optimizer_args': {'lr': 0.05, 'momentum': 0.3},
        # Deterministic transform (no augmentation), same normalisation constants.
        'transformTest': transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
        ]),
    },
}


Files already downloaded and verified
Files already downloaded and verified


In [2]:
# Modified code
#
# Compares two estimates of the per-parameter squared-gradient (diagonal Fisher)
# statistics on the first `num_samples` batches:
#   * sq_grads_expect      -- "modified" estimate: a single backward pass per batch
#                             that selects each sample's predicted class through a
#                             one-hot `grad_outputs` and weights it by sqrt(prob).
#   * sq_grads_expect_orig -- reference estimate: one backward pass per
#                             (sample, class) pair, squared gradient weighted by
#                             prob / N.
#
# NOTE(review): the loader is built from X_tr / Y_tr with the (randomly augmenting)
# training transform even though it is called `test_loader` -- confirm intended.
# NOTE(review): the forward pass runs in whatever train/eval mode `net` is currently
# in; if resnet.ResNet18 contains BatchNorm/Dropout, consider net.eval() -- confirm.
test_loader = DataLoader(handler(X_tr, Y_tr, transform=args['CIFAR10']['transform']), shuffle=False, **args['CIFAR10']['loader_te_args'])
num_samples = 3 # used by FISH mask paper; number of batches processed below
idx = 0
sq_grads_expect = []
# Start the reference accumulator from zero inside this cell so that re-running
# the cell does not keep adding to the values left over from a previous run.
sq_grads_expect_orig = {i: np.zeros(p.shape) for i, p in enumerate(parameters)}

for test_batch, test_labels, idxs in test_loader:
    outputs, e1 = net(test_batch)
    print('Outputs: ',outputs.shape)
    _, preds = torch.max(outputs, 1)
    # The probabilities act purely as weights, so they are detached: otherwise
    # autograd would also differentiate through sqrt(probs) in the product below
    # and the result would no longer be sqrt(p) * d(log p)/d(theta).
    probs = F.softmax(outputs, dim=1).detach().to('cpu')
    log_probs = F.log_softmax(outputs, dim=1)
    N, C = log_probs.shape

    # --- modified estimate -------------------------------------------------
    # One-hot selector of each sample's predicted class (row n, column preds[n]).
    grad_outs = torch.zeros_like(log_probs)
    grad_outs[torch.arange(N, device=preds.device), preds] = 1
    # retain_graph=True: the same graph is reused by the reference loop below.
    grad_list = torch.autograd.grad(log_probs*torch.sqrt(probs), parameters, grad_outputs=grad_outs, retain_graph=True)
    grad_list = [torch.square(g).detach().cpu().numpy() for g in grad_list]
    # NOTE(review): unlike the reference estimate, this accumulator is not
    # divided by N -- confirm whether that scaling difference is intended.
    if idx == 0:
        sq_grads_expect = grad_list
    else:
        sq_grads_expect = list( map(add, sq_grads_expect, grad_list) )

    # --- reference estimate ------------------------------------------------
    # torch.autograd.grad returns the gradients instead of accumulating them in
    # `.grad`, so no net.zero_grad() is required between calls.
    for n in range(N):
        for c in range(C):
            grad_list_orig = torch.autograd.grad(log_probs[n][c], parameters, retain_graph=True)
            for i, grad in enumerate(grad_list_orig):    # different layers
                gsq = torch.square(grad).to('cpu') * probs[n][c] / N
                sq_grads_expect_orig[i] += gsq.numpy()

    idx += 1
    if idx >= num_samples:
        break
    

Outputs:  torch.Size([100, 4])
Outputs:  torch.Size([100, 4])
Outputs:  torch.Size([100, 4])
Outputs:  torch.Size([100, 4])
Outputs:  torch.Size([100, 4])
Outputs:  torch.Size([100, 4])


Computing gradients of a vector-valued function, one output component at a time

In [5]:
import torch
# Toy check: for a vector-valued output f = [f1, f2], calling autograd.grad with
# a `grad_outputs` mask that is 1 in column c and 0 elsewhere yields the gradient
# of component c alone.
n = 4
x, y, z = (torch.rand(n, requires_grad=True) for _ in range(3))

# The two components on their own, used for the cross-check at the end.
f1 = 2 * x + 3 * torch.square(z)
f2 = torch.square(y) + 4 * z * torch.square(x)

# f is assembled from freshly evaluated expressions (not from f1 / f2), so its
# autograd graph is independent of theirs.
f = torch.stack(
    (
        2 * x + 3 * torch.square(z),
        torch.square(y) + 4 * z * torch.square(x),
    ),
    dim=1,
)

num_components = f.shape[1]
for c in range(num_components):
    # Select output column c for every row.
    grad_outs = torch.zeros_like(f)
    grad_outs[:, c] = 1.0
    column_grads = torch.autograd.grad(
        outputs=f,
        inputs=(x, y, z),
        grad_outputs=grad_outs,
        retain_graph=True,  # f's graph is needed again for the next column
    )
    print(column_grads)

# verifying calculated gradients are correct: differentiate each component
# directly. allow_unused=True because f1 does not depend on y (its entry is None).
for component in (f1, f2):
    print(
        torch.autograd.grad(
            component,
            (x, y, z),
            grad_outputs=torch.ones_like(component),
            allow_unused=True,
        )
    )


(tensor([2., 2., 2., 2.]), tensor([0., 0., 0., 0.]), tensor([3.2179, 5.4168, 3.4557, 5.7125]))
(tensor([3.3963, 1.6956, 2.2514, 2.3443]), tensor([1.4453, 1.6891, 0.4078, 0.1368]), tensor([2.5065, 0.2205, 0.9550, 0.3789]))
(tensor([2., 2., 2., 2.]), None, tensor([3.2179, 5.4168, 3.4557, 5.7125]))
(tensor([3.3963, 1.6956, 2.2514, 2.3443]), tensor([1.4453, 1.6891, 0.4078, 0.1368]), tensor([2.5065, 0.2205, 0.9550, 0.3789]))
