# Combining gone fishing (BAIT) with Fisher mask (FISH)

The Fisher mask code lives in the `FISH` folder.

In [None]:
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torchvision.models import resnet18

# Prefer the GPU when CUDA is usable; otherwise everything stays on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

# torchvision's ResNet-18, loaded with its pretrained weights.
model = resnet18(pretrained=True)

In [None]:
# Look at the gradient slot of the first convolution's weights.
# On a fresh kernel no backward pass has run before this cell, so `.grad`
# is expected to still be None here (the cell then displays NoneType).
w = model.conv1.weight.grad
type(w)

Next steps:

* Download the CIFAR-10 dataset and create a dataloader with appropriate preprocessing.
* Iterate over minibatches from the dataloader and calculate the cross-entropy loss for each one.
* In typical SGD training, we zero out the gradients before processing the next minibatch. This is done by calling `model.zero_grad()`, and then calling `loss.backward()` to get the new gradients. Since we are not interested in training right now, we want the sum of gradients over all minibatches, so we should not call `model.zero_grad()`.
* Identify the top 2% of all the weights in the model and make a new $pk \times k$ array, where $p$ is 2% of the total number of parameters.

In [None]:
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.optim as optim
import resnet

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

# Using the paper's model definition from resnet.py.
model = resnet.ResNet18()

# Arguments from run.py.
# 'transform' is the augmenting pipeline for the training split;
# 'transformTest' only converts to tensor and normalises, for the test split.
cifar_args = {
    'n_epoch': 3,
    'transform': transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
    ]),
    'loader_tr_args': {'batch_size': 128, 'num_workers': 1},
    'loader_te_args': {'batch_size': 1000, 'num_workers': 1},
    'optimizer_args': {'lr': 0.05, 'momentum': 0.3},
    'transformTest': transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
    ]),
}

# Training split: augmented, shuffled, small batches.
cifar_trainset = datasets.CIFAR10(
    root='./data', train=True, download=True, transform=cifar_args['transform']
)
cifar_trainloader = torch.utils.data.DataLoader(
    cifar_trainset,
    batch_size=cifar_args['loader_tr_args']['batch_size'],
    num_workers=cifar_args['loader_tr_args']['num_workers'],
    shuffle=True,
)

# Test split: no augmentation, fixed order, large batches.
# This loader gets its own name and wraps the test set; previously it
# overwrote `cifar_trainloader` and wrapped the training set again.
cifar_testset = datasets.CIFAR10(
    root='./data', train=False, download=True, transform=cifar_args['transformTest']
)
cifar_testloader = torch.utils.data.DataLoader(
    cifar_testset,
    batch_size=cifar_args['loader_te_args']['batch_size'],
    num_workers=cifar_args['loader_te_args']['num_workers'],
    shuffle=False,
)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    model.parameters(),
    lr=cifar_args['optimizer_args']['lr'],
    momentum=cifar_args['optimizer_args']['momentum'],
)

In [None]:
# List the gradient shape of every weight tensor in the model.
# `.grad` stays None until a backward pass has populated it, so guard against
# that instead of failing with AttributeError when this cell runs first.
for name, param in model.named_parameters():
    if "weight" not in name:
        continue
    if param.grad is None:
        print(name, None)
    else:
        print(name, param.grad.size())

In [None]:
# Accumulate the sum of minibatch gradients over the training loader.
# Feed each batch to whichever device the model's parameters already live on.
model_device = next(model.parameters()).device

for epoch in range(1):  # single pass for now; cifar_args['n_epoch'] holds the configured count
    running_loss = 0.0
    for i, data in enumerate(cifar_trainloader):
        inputs, labels = data
        inputs, labels = inputs.to(model_device), labels.to(model_device)
        # Gradients are deliberately NOT zeroed, so every backward() call adds
        # this minibatch's gradient onto the running sum stored in `.grad`.
        outputs, emb = model(inputs)  # the model returns (outputs, embedding)
        loss = criterion(outputs, labels)
        loss.backward()
        # print(model.layer1[0].conv1.weight.grad)  # how to access gradients of layers
        # No optimizer.step() here: the goal is the summed gradient at fixed
        # weights, and stepping with never-zeroed gradients would apply the
        # whole running sum to the weights on every minibatch.
        running_loss += loss.item()

    # Show the accumulated gradients once per pass rather than once per
    # minibatch, which would flood the notebook output.
    for name, param in model.named_parameters():
        if "weight" in name:
            print(name, param.grad)
print('Done Training.')