In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets
from torch.utils.data import DataLoader
import numpy as np
import resnet
from dataset import get_dataset, get_handler
from torchvision import transforms
from operator import add

In [None]:
add5 = lambda x : x + 5
add1 = lambda x : x + 1
add10 = lambda x : x + 10

net = resnet.ResNet18(num_classes=4)

args = {'MNIST':
                {'n_epoch': 10, 'transform': transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]),
                'loader_tr_args':{'batch_size': 64, 'num_workers': 1},
                'loader_te_args':{'batch_size': 1000, 'num_workers': 1},
                'optimizer_args':{'lr': 0.01, 'momentum': 0.5}},
            'FashionMNIST':
                {'n_epoch': 10, 'transform': transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]),
                'loader_tr_args':{'batch_size': 64, 'num_workers': 1},
                'loader_te_args':{'batch_size': 1000, 'num_workers': 1},
                'optimizer_args':{'lr': 0.01, 'momentum': 0.5}},
            'SVHN':
                {'n_epoch': 20, 'transform': transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.4377, 0.4438, 0.4728), (0.1980, 0.2010, 0.1970))]),
                'loader_tr_args':{'batch_size': 64, 'num_workers': 1},
                'loader_te_args':{'batch_size': 1000, 'num_workers': 1},
                'optimizer_args':{'lr': 0.01, 'momentum': 0.5}},
            'CIFAR10':
                {'n_epoch': 3, 'transform': transforms.Compose([ 
                    transforms.RandomCrop(32, padding=4),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
                ]),
                'loader_tr_args':{'batch_size': 128, 'num_workers': 1},
                'loader_te_args':{'batch_size': 10, 'num_workers': 1}, # change back to 1000
                'optimizer_args':{'lr': 0.05, 'momentum': 0.3},
                'transformTest': transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))])}
                }
parameters = tuple(net.parameters())
X_tr, Y_tr, X_te, Y_te = get_dataset('CIFAR10', 'data')
dim = np.shape(X_tr)[1:]
handler = get_handler('CIFAR10')
test_loader = DataLoader(handler(X_tr, Y_tr, transform=args['CIFAR10']['transform']), shuffle=False, **args['CIFAR10']['loader_te_args'])
num_samples = 1 # used by FISH mask paper
idx = 0
sq_grads_expect = []
sq_grads_expect_orig = {i: np.zeros(p.shape,dtype=np.float32) for i, p in enumerate(parameters)}
parameters = tuple(net.parameters())

for test_batch, test_labels, idxs in test_loader:
 
    #Method 1
    # test_batch, test_labels = test_batch, test_labels
    outputs, e1 = net(test_batch)
    print('Outputs: ',outputs.shape)
    # _, preds = torch.max(outputs, 1)         
    probs = F.softmax(outputs, dim=1).to('cpu')
    log_probs = F.log_softmax(outputs, dim=1)
    N, C = log_probs.shape
    ps = probs.detach()
    f = log_probs #torch.sqrt(ps)*log_probs
    grad_list = []
    for c in range(C):
        grad_outs = torch.zeros_like(f)
        grad_outs[:, c] = 1
        grad = torch.autograd.grad(f, parameters, grad_outputs=grad_outs, retain_graph=True) #f
        print(grad[0].shape)
        #Below if-else used to add tuples element-wise
        # grad = torch.square(grad)
        
        if c == 0:
            grad_list = [(add10(g)/N).detach().numpy() for g in grad] #torch.square(g)/N
        else:
            grad_list = list( map(add, grad_list,  [(add10(g)/N).detach().numpy() for g in grad]) ) #torch.square(g)/N
            
        net.zero_grad()
    
    # grad_list = [np.array(i/N) for i in grad_list]
    #Below if-else used to add tuples element-wise
    if idx == 0:
        sq_grads_expect = grad_list
    else:
        sq_grads_expect = list( map(add, sq_grads_expect, grad_list) )
        
    #Method 2
    print("original way next")
    for n in range(N):
        for c in range(C):
            grad_list_orig = torch.autograd.grad(f[n][c], parameters, retain_graph=True) #log_probs
            for i, grad in enumerate(grad_list_orig):   # different layers
                if (n ==0) & (c==0) & (i ==0):
                    print(grad.shape)
                # gsq = torch.square(grad).to('cpu') * probs[n][c] / N
                gsq = add1(grad).to('cpu') / N # torch.square(grad)
                sq_grads_expect_orig[i] += gsq.detach().numpy() # sq_grads_expect[i] + gsq
                # del grad
            net.zero_grad()
    
    idx += 1
    if idx >= num_samples:
        break

print(sq_grads_expect[0])
print(sq_grads_expect_orig[0])

In [2]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
import resnet
import torchvision
from torchvision import transforms
from torch.func import functional_call, vmap, grad

net = resnet.ResNet18(num_classes=10)
args = {
            'CIFAR10':
                {
                    'n_epoch': 3, 
                    'transform': transforms.Compose([ 
                        transforms.RandomCrop(32, padding=4),
                        transforms.RandomHorizontalFlip(),
                        transforms.ToTensor(),
                        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))]),
                    'loader_tr_args':{'batch_size': 128, 'num_workers': 1},
                    'loader_te_args':{'batch_size': 10, 'num_workers': 1}, # change back to 1000
                    'optimizer_args':{'lr': 0.05, 'momentum': 0.3},
                    'transformTest': transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))])
                }
        }
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=args['CIFAR10']['transform'])
train_loader = torch.utils.data.DataLoader(trainset, batch_size=args['CIFAR10']['loader_tr_args']['batch_size'], shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=args['CIFAR10']['transform'])
testloader = torch.utils.data.DataLoader(testset, batch_size=args['CIFAR10']['loader_te_args']['batch_size'], shuffle=False, num_workers=2)


def compute_loss(params, buffers, sample, target):
    batch = sample.unsqueeze(0)
    #targets = target.unsqueeze(0)

    predictions = functional_call(net, (params, buffers), (batch,))
    loss = loss_fn(predictions[0], target)
    return loss
            

def loss_fn(predictions, targets):
    return torch.from_numpy(predictions[:, targets]) #F.nll_loss(predictions, targets)


parameters = tuple(net.parameters())
sq_grads_expect_orig = {i: np.zeros(p.shape,dtype=np.float32) for i, p in enumerate(parameters)}

num_samples=1
idx=0

#params = {k: v.detach() for k, v in net.named_parameters()}
#buffers = {k: v.detach() for k, v in net.named_buffers()}
#ft_compute_grad = grad(compute_loss)
#ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, None, 0, 0))
#ft_per_sample_grads = ft_compute_sample_grad(params, buffers, torch.permute(torch.from_numpy(trainset.data),(0,3,1,2)).float(), torch.tensor(trainset.targets))

for train_batch, train_labels in train_loader:
    outputs, e1 = net(train_batch)      
    probs = F.softmax(outputs, dim=1)#.to('cpu')
    log_probs = F.log_softmax(outputs, dim=1)
    N, C = log_probs.shape
    print("original way next")
    for n in range(N):
        for c in range(C):
            grad_list_orig = torch.autograd.grad(log_probs[n][c], parameters, retain_graph=True)
            for i, grad in enumerate(grad_list_orig):   # different layers
                # gsq = torch.square(grad).to('cpu') * probs[n][c] / N
                gsq = grad#.to('cpu') # torch.square(grad)
                sq_grads_expect_orig[i] += gsq.detach().numpy()
                # del grad
            net.zero_grad()

    params = {k: v.detach() for k, v in net.named_parameters()}
    buffers = {k: v.detach() for k, v in net.named_buffers()}
    grads_arr = []
    for c in range(C):
        ft_compute_grad = grad(compute_loss)
        ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, None, 0, 0))
        ft_per_sample_grads = sum(ft_compute_sample_grad(params, buffers, train_batch, c))#train_labels))
        grads_arr.append(ft_per_sample_grads)
        net.zero_grad()
    
    grads_arr = sum(grads_arr)
    idx += 1
    if idx >= num_samples:
        break

Files already downloaded and verified
Files already downloaded and verified
original way next


TypeError: 'Tensor' object is not callable

In [6]:
sq_grads_expect_orig[0][0]

array([[[12.704348  , 12.60496   ,  7.723358  ],
        [ 9.278099  ,  9.066983  ,  3.3033621 ],
        [ 2.2450042 ,  2.668539  , -3.1830425 ]],

       [[10.8174    , 11.973187  ,  8.793834  ],
        [ 6.7906795 ,  6.8204546 ,  3.349778  ],
        [ 0.9677659 ,  2.1111672 , -2.3996966 ]],

       [[12.784865  , 14.335051  , 11.729529  ],
        [ 9.406301  ,  9.912758  ,  7.700953  ],
        [ 2.698927  ,  4.172598  ,  0.61233765]]], dtype=float32)

In [5]:
sum(ft_per_sample_grads['conv1.weight'])[0]

tensor([[[ 3.3205,  3.0951,  6.5645],
         [ 2.6982,  1.3491,  0.6842],
         [ 1.7819,  4.7311,  1.5717]],

        [[ 2.2331,  2.5682,  6.2995],
         [ 1.3808,  1.3336,  0.0521],
         [ 0.5188,  4.1565, -0.4611]],

        [[ 2.0457,  2.7036,  6.1600],
         [ 0.5252,  0.5636, -0.4522],
         [ 0.2869,  3.7947, -0.5545]]])

In [None]:
print([i for (i, j) in net.named_parameters()])
print(sq_grads_expect_orig[0].shape)
print(sum(ft_per_sample_grads["conv1.weight"]).shape)

In [None]:
print(sq_grads_expect_orig[0][0])
print(sum(ft_per_sample_grads["conv1.weight"])[0])

In [None]:
# all([torch.allclose(sq_grads_expect[i], torch.from_numpy(sq_grads_expect_orig[i]).float(), atol=1e-4) for i in range(61)]) 
all([np.allclose(sq_grads_expect[i], sq_grads_expect_orig[i], atol=1e-4) for i in range(61)]) 

In [None]:
sq_grads_expect = []
sq_grads = {i: np.zeros(4,dtype=np.float32) for i in range(3)} # 4 is the n and 3 is the number of parameters

for idx in range(2):
    n = 4
    x = torch.rand(n, requires_grad = True)
    y = torch.rand(n, requires_grad = True)
    z = torch.rand(n, requires_grad = True)
    # sq_grads = {i: np.zeros(p.shape) for i, p in enumerate((x,y,z))}
    f = torch.column_stack([2 * x + 3 * torch.square(z), torch.square(y) + 4 * z * torch.square(x)])
    probs = torch.abs(torch.rand(f.shape, requires_grad=False))
    pf = torch.sqrt(probs) * f
    N, C = f.shape
    grad_list = []
    for c in range(C):
        grad_outs = torch.zeros_like(f)
        grad_outs[:, c] = 1
        n_grad = torch.autograd.grad(pf, (x, y, z), grad_outputs=grad_outs, retain_graph=True)
        if c == 0:
            grad_list = [(torch.square(g)/N).detach().numpy() for g in n_grad] # #n_grad
        else:
            grad_list = list( map(add, grad_list,  [(torch.square(g)/N).detach().numpy() for g in n_grad]) )
        
    if idx == 0:
        sq_grads_expect = grad_list
    else:
        sq_grads_expect = list( map(add, sq_grads_expect, grad_list) )        

    # print(grad_list)
    # print(f.grad)
    # verifying calculated gradients are correct
    # print(x, y, z)

    for idx in range(N):
        for c in range(C):
            new_grad = torch.autograd.grad(f[idx][c], (x, y, z),  retain_graph=True)
            # print(new_grad)
            for i, grad in enumerate(new_grad):
                # print(type(grad), grad.shape)
                gsq = torch.square(grad).to('cpu') * probs[idx][c]  / N
                sq_grads[i] += gsq.detach().numpy()

print(sq_grads_expect)
print(sq_grads)

In [None]:
type(sq_grads_expect[0][0])

In [None]:
type(sq_grads[0][0])

In [None]:
net = resnet.ResNet18(num_classes=4)
args = {'MNIST':
                {'n_epoch': 10, 'transform': transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]),
                'loader_tr_args':{'batch_size': 64, 'num_workers': 1},
                'loader_te_args':{'batch_size': 1000, 'num_workers': 1},
                'optimizer_args':{'lr': 0.01, 'momentum': 0.5}},
            'FashionMNIST':
                {'n_epoch': 10, 'transform': transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]),
                'loader_tr_args':{'batch_size': 64, 'num_workers': 1},
                'loader_te_args':{'batch_size': 1000, 'num_workers': 1},
                'optimizer_args':{'lr': 0.01, 'momentum': 0.5}},
            'SVHN':
                {'n_epoch': 20, 'transform': transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.4377, 0.4438, 0.4728), (0.1980, 0.2010, 0.1970))]),
                'loader_tr_args':{'batch_size': 64, 'num_workers': 1},
                'loader_te_args':{'batch_size': 1000, 'num_workers': 1},
                'optimizer_args':{'lr': 0.01, 'momentum': 0.5}},
            'CIFAR10':
                {'n_epoch': 3, 'transform': transforms.Compose([ 
                    transforms.RandomCrop(32, padding=4),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
                ]),
                'loader_tr_args':{'batch_size': 128, 'num_workers': 1},
                'loader_te_args':{'batch_size': 10, 'num_workers': 1}, # change back to 1000
                'optimizer_args':{'lr': 0.05, 'momentum': 0.3},
                'transformTest': transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))])}
                }
parameters = tuple(net.parameters())
X_tr, Y_tr, X_te, Y_te = get_dataset('CIFAR10', 'data')
dim = np.shape(X_tr)[1:]
handler = get_handler('CIFAR10')
test_loader = DataLoader(handler(X_tr, Y_tr, transform=args['CIFAR10']['transform']), shuffle=False, **args['CIFAR10']['loader_te_args'])
num_samples = 2 # used by FISH mask paper
idx = 0
sq_grads_expect = []
sq_grads_expect_orig = {i: np.zeros(p.shape) for i, p in enumerate(parameters)}
parameters = tuple(net.parameters())

for test_batch, test_labels, idxs in test_loader:
 
    
    test_batch, test_labels = test_batch, test_labels
    outputs, e1 = net(test_batch)
    print('Outputs: ',outputs.shape)
    _, preds = torch.max(outputs, 1)         
    probs = F.softmax(outputs, dim=1).to('cpu')
    log_probs = F.log_softmax(outputs, dim=1)
    N, C = log_probs.shape
    ps = probs.detach()
    f = log_probs*torch.sqrt(ps)

    for c in range(f.shape[1]):
        grad_outs = torch.zeros_like(f)
        grad_outs[:, c] = 1
        grad = torch.autograd.grad(f, parameters, grad_outputs=grad_outs, retain_graph=True)
        #Below if-else used to add tuples element-wise
        if c == 0:
            grad_list = [torch.square(g) for g in grad]
        else:
            grad_list = list( map(add, grad_list, [torch.square(g) for g in grad]) )
        
        net.zero_grad()
    
    grad_list = [np.array(i/N) for i in grad_list]
    #Below if-else used to add tuples element-wise
    if idx == 0:
        sq_grads_expect = grad_list
    else:
        sq_grads_expect = list( map(add, sq_grads_expect, grad_list) )

    print("original way next")
    for n in range(N):
        for c in range(C):
            grad_list_orig = torch.autograd.grad(log_probs[n][c], parameters, retain_graph=True)
            for i, grad in enumerate(grad_list_orig):    # different layers
                gsq = torch.square(grad).to('cpu') * probs[n][c] / N
                sq_grads_expect_orig[i] += gsq.detach().numpy() # sq_grads_expect[i] + gsq
                del gsq
            net.zero_grad()
    
    idx += 1
    if idx >= num_samples:
        break

print()

In [None]:
from functorch import make_functional_with_buffers, vmap, grad

In [None]:
net = resnet.ResNet18(num_classes=4)
parameters = tuple(net.parameters())
X_tr, Y_tr, X_te, Y_te = get_dataset('CIFAR10', 'data')

In [None]:
fmodel, params, buffers = make_functional_with_buffers(net)

In [None]:
fmodel

In [None]:
def compute_loss_stateless_model (params, buffers, sample, target):
    batch = sample.unsqueeze(0)
    targets = target.unsqueeze(0)

    predictions = fmodel(params, buffers, batch) 
    loss = loss_fn(predictions, targets)
    return loss

## PER-SAMPLE-GRADIENTS

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial

torch.manual_seed(0);

In [None]:
# Here's a simple CNN and loss function:

class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        output = x
        return output

def loss_fn(predictions, targets):
    return F.nll_loss(predictions, targets)

In [None]:
device = 'cuda'

num_models = 10
batch_size = 64
data = torch.randn(batch_size, 1, 28, 28, device=device)

targets = torch.randint(10, (64,), device=device)

In [None]:
model = SimpleCNN().to(device=device)
predictions = model(data) # move the entire mini-batch through the model

loss = loss_fn(predictions, targets)
loss.backward() # back propogate the 'average' gradient of this mini-batch

In [None]:
def compute_grad(sample, target):
    
    sample = sample.unsqueeze(0)  # prepend batch dimension for processing
    target = target.unsqueeze(0)

    prediction = model(sample)
    loss = loss_fn(prediction, target)

    return torch.autograd.grad(loss, list(model.parameters()))


def compute_sample_grads(data, targets):
    """ manually process each sample with per sample gradient """
    sample_grads = [compute_grad(data[i], targets[i]) for i in range(batch_size)]
    sample_grads = zip(*sample_grads)
    sample_grads = [torch.stack(shards) for shards in sample_grads]
    return sample_grads

per_sample_grads = compute_sample_grads(data, targets)

In [None]:
print(per_sample_grads[0].shape)

In [None]:
from functorch import make_functional_with_buffers, vmap, grad

fmodel, params, buffers = make_functional_with_buffers(model)

In [None]:
fmodel

In [None]:
def compute_loss_stateless_model (params, buffers, sample, target):
    batch = sample.unsqueeze(0)
    targets = target.unsqueeze(0)

    predictions = fmodel(params, buffers, batch) 
    loss = loss_fn(predictions, targets)
    return loss

In [None]:
ft_compute_grad = grad(compute_loss_stateless_model)

In [None]:
ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, None, 0, 0))

In [None]:
ft_per_sample_grads = ft_compute_sample_grad(params, buffers, data, targets)

# we can double check that the results using functorch grad and vmap match the results of hand processing each one individually:
for per_sample_grad, ft_per_sample_grad in zip(per_sample_grads, ft_per_sample_grads):
    assert torch.allclose(per_sample_grad, ft_per_sample_grad, atol=3e-3, rtol=1e-5)