In [1]:
%matplotlib inline
from copy import deepcopy
from collections import OrderedDict
import gc
import matplotlib.pyplot as plt
from tqdm import tqdm_notebook
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import SGD,Adam,lr_scheduler
from torch.utils.data import random_split
import torchvision
from torchvision import transforms,models


In [2]:
from google.colab import files

In [3]:
train_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.RandomHorizontalFlip(p=.40),
    transforms.RandomRotation(30),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

test_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

traindata = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=train_transform)
trainset,valset = random_split(traindata,[42000,8000])
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64,shuffle=True)
valloader = torch.utils.data.DataLoader(valset, batch_size=64,shuffle=False)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=test_transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64,shuffle=False)

classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [4]:
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        base = models.resnet18(pretrained=True)
        self.base = nn.Sequential(*list(base.children())[:-1])
        in_features = base.fc.in_features
        self.drop = nn.Dropout()
        self.final = nn.Linear(in_features,10)
    
    def forward(self,x):
        x = self.base(x)
        x = self.drop(x.view(-1,self.final.in_features))
        return self.final(x)
    
model = Model().cuda()
[x for x,y in model.named_children()]

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


  0%|          | 0.00/44.7M [00:00<?, ?B/s]

['base', 'drop', 'final']

In [5]:
criterion = nn.CrossEntropyLoss()
param_groups = [
    {'params':model.base.parameters(),'lr':.0001},
    {'params':model.final.parameters(),'lr':.001}
]
optimizer = Adam(param_groups)
lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
states = {}

In [None]:
%%time
best_val_acc = -1000
best_val_model = None
for epoch in range(10):  
    model.train(True)
    running_loss = 0.0
    running_acc = 0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        inputs, labels = inputs.cuda(),labels.cuda()

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item() * inputs.size(0)
        out = torch.argmax(outputs.detach(),dim=1)
        assert out.shape==labels.shape
        running_acc += (labels==out).sum().item()
    print(f"Train loss {epoch+1}: {running_loss/len(trainset)},Train Acc:{running_acc*100/len(trainset)}%")
    
    correct = 0
    model.train(False)
    with torch.no_grad():
        for inputs,labels in valloader:
            out = model(inputs.cuda()).cpu()
            out = torch.argmax(out,dim=1)
            acc = (out==labels).sum().item()
            correct += acc
    print(f"Val accuracy:{correct*100/len(valset)}%")
    if correct>best_val_acc:
        best_val_acc = correct
        best_val_model = deepcopy(model.state_dict())
    lr_scheduler.step()
    
print('Finished Training') 

In [None]:
torch.save(model.state_dict(), 'checkpoint.pth')

# download checkpoint file
files.download('checkpoint.pth')

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [6]:
state_dict = torch.load('checkpoint.pth')
print(state_dict.keys())

odict_keys(['base.0.weight', 'base.1.weight', 'base.1.bias', 'base.1.running_mean', 'base.1.running_var', 'base.1.num_batches_tracked', 'base.4.0.conv1.weight', 'base.4.0.bn1.weight', 'base.4.0.bn1.bias', 'base.4.0.bn1.running_mean', 'base.4.0.bn1.running_var', 'base.4.0.bn1.num_batches_tracked', 'base.4.0.conv2.weight', 'base.4.0.bn2.weight', 'base.4.0.bn2.bias', 'base.4.0.bn2.running_mean', 'base.4.0.bn2.running_var', 'base.4.0.bn2.num_batches_tracked', 'base.4.1.conv1.weight', 'base.4.1.bn1.weight', 'base.4.1.bn1.bias', 'base.4.1.bn1.running_mean', 'base.4.1.bn1.running_var', 'base.4.1.bn1.num_batches_tracked', 'base.4.1.conv2.weight', 'base.4.1.bn2.weight', 'base.4.1.bn2.bias', 'base.4.1.bn2.running_mean', 'base.4.1.bn2.running_var', 'base.4.1.bn2.num_batches_tracked', 'base.5.0.conv1.weight', 'base.5.0.bn1.weight', 'base.5.0.bn1.bias', 'base.5.0.bn1.running_mean', 'base.5.0.bn1.running_var', 'base.5.0.bn1.num_batches_tracked', 'base.5.0.conv2.weight', 'base.5.0.bn2.weight', 'base.

In [7]:
model.load_state_dict(state_dict)

<All keys matched successfully>

In [16]:
#check accuracy for original model
correct = 0
model.train(False)
with torch.no_grad():
    for inputs,labels in valloader:
        out = model(inputs.cuda()).cpu()
        out = torch.argmax(out,dim=1)
        acc = (out==labels).sum().item()
        correct += acc
print(f"Val accuracy:{correct*100/len(valset)}%")

Val accuracy:9.525%


In [8]:
import torch.nn.utils.prune as prune

In [9]:
parameters_to_prune =[]

In [10]:
parameters_to_prune.append((model.base[0],'weight'))
parameters_to_prune.append((model.base[1],'weight'))
parameters_to_prune.append((model.base[1],'bias'))

In [11]:

for i in range(4,8):
  for j in range(0,2):
    parameters_to_prune.append((model.base[i][j].conv1,"weight"))
    parameters_to_prune.append((model.base[i][j].bn1,"weight"))
    parameters_to_prune.append((model.base[i][j].bn1,"bias"))
    parameters_to_prune.append((model.base[i][j].conv2,"weight"))
    parameters_to_prune.append((model.base[i][j].bn2,"weight"))
    parameters_to_prune.append((model.base[i][j].bn2,"bias"))
    


In [12]:
for i in range(5,8):
  parameters_to_prune.append((model.base[i][0].downsample[0],"weight"))
  parameters_to_prune.append((model.base[i][0].downsample[1],"weight"))
  parameters_to_prune.append((model.base[i][0].downsample[1],"bias"))

In [13]:
parameters_to_prune = tuple(parameters_to_prune)


In [15]:
prune.global_unstructured(
        parameters_to_prune,
        pruning_method=prune.RandomUnstructured,
        amount=0.5,
    )

In [None]:
#check accuracy for pruned model
correct = 0
model.train(False)
with torch.no_grad():
    for inputs,labels in valloader:
        out = model(inputs.cuda()).cpu()
        out = torch.argmax(out,dim=1)
        acc = (out==labels).sum().item()
        correct += acc
print(f"Val accuracy:{correct*100/len(valset)}%")