## Using FashionMNIST to train, test and prune a CNN

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

import matplotlib.pyplot as plt
import math

from collections import OrderedDict

from run_classes import RunBuilder, RunManager

import pandas as pd
import numpy as np

#### Import Train and Test Set

In [2]:
train_set = torchvision.datasets.FashionMNIST(
    root = 'data/FashionMNIST'
    , train = True
    , download = True
    , transform = transforms.Compose([
        transforms.ToTensor()
        #, transforms.Normalize(mean = [], std = [])
    ]))


test_set = torchvision.datasets.FashionMNIST(
    root ='data/FashionMNIST'
    , train = False
    , download = True
    , transform = transforms.Compose([
        transforms.ToTensor()
        #, transforms.Normalize(mean = [], std = [])
    ]))

#### Calculate Mean and Std for Train and Test Set for normalization

In [3]:
loader = DataLoader(train_set, batch_size = len(train_set), num_workers = 1)
data = next(iter(loader))
data[0].mean(), data[0].std()

(tensor(0.2859), tensor(0.3530))

In [4]:
loader = DataLoader(test_set, batch_size = len(test_set), num_workers = 1)
data = next(iter(loader))
data[0].mean(), data[0].std()

(tensor(0.2869), tensor(0.3524))

#### Create normalized Train and Test Set

In [5]:
train_set_norm = torchvision.datasets.FashionMNIST(
    root = 'data/FashionMNIST'
    , train = True
    , download = True
    , transform = transforms.Compose([
        transforms.ToTensor()
        , transforms.Normalize(mean = [.2859], std = [.3524])
    ]))


test_set_norm = torchvision.datasets.FashionMNIST(
    root ='data/FashionMNIST'
    , train = False
    , download = True
    , transform = transforms.Compose([
        transforms.ToTensor()
        , transforms.Normalize(mean = [.2869], std = [.3524])
    ]))

In [6]:
train_set.classes

['T-shirt/top',
 'Trouser',
 'Pullover',
 'Dress',
 'Coat',
 'Sandal',
 'Shirt',
 'Sneaker',
 'Bag',
 'Ankle boot']

#### Define our Network

In [7]:
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels = 1, out_channels = 6, kernel_size = 5)
        self.conv2 = nn.Conv2d(in_channels = 6, out_channels = 12, kernel_size = 5)
        
        #table and formula to calculate the changes of img sizes:
        # https://deeplizard.com/learn/video/cin4YcGBh3Q
        self.fc1 = nn.Linear(in_features = 12*4*4, out_features = 120) #needed, because the img has the shape
                                                                        #(1, 12, 4, 4) when it arrives at the fc
                                                                        #because it is flattened, the input is 12*4*4
        self.fc2 = nn.Linear(in_features = 120, out_features = 60)
        self.out = nn.Linear(in_features = 60, out_features = 10)
        
    def forward(self, t):
        t = F.relu(self.conv1(t))
        t = F.max_pool2d(t, kernel_size = 2, stride = 2)
        
        t = F.relu(self.conv2(t))
        t = F.max_pool2d(t, kernel_size = 2, stride = 2)
        
        t = F.relu(self.fc1(t.reshape(-1, 12 * 4 * 4)))
        t = F.relu(self.fc2(t))
        
        t = self.out(t)
        #normally softmax, but is implicitly included in the cross entropy 
        return t

#### Defining Testing Dics

In [8]:
trainsets = {
    #'not_norm' : train_set,
    'norm' : train_set_norm
}

params = OrderedDict(
    lr = [.01]
    , batch_size = [1000]
    , num_workers = [1]
    , epochs = [5]
    , trainset = list(trainsets.keys())
)

#### Trainingsloop

In [9]:
m = RunManager()
for run in RunBuilder.get_runs(params):
    network = CNN()
    loader = DataLoader(trainsets[run.trainset]
                        , batch_size = run.batch_size
                        , num_workers = run.num_workers)
    optimizer = torch.optim.Adam(network.parameters(), lr = run.lr)
    
    m.begin_run(run, network, loader)
    for epoch in range(run.epochs):
        m.begin_epoch()
        
        for batch in loader:
            images, labels = batch
            preds = network(images)
            loss = F.cross_entropy(preds, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            m.track_loss(loss)
            m.track_num_correct(preds, labels)
        m.end_epoch()
    m.end_run()

Unnamed: 0,run,epoch,loss,accuracy,epoch duration,run duration,lr,batch_size,num_workers,epochs,trainset
0,1,1,0.834753,0.683817,11.147281,14.355769,0.01,1000,1,5,norm
1,1,2,0.45602,0.826767,12.111821,26.559419,0.01,1000,1,5,norm
2,1,3,0.378957,0.860817,12.516968,39.162835,0.01,1000,1,5,norm
3,1,4,0.335699,0.876917,12.605595,51.855749,0.01,1000,1,5,norm
4,1,5,0.31942,0.881417,12.678778,64.622526,0.01,1000,1,5,norm


### Applying the Test Set

In [10]:
testloader = DataLoader(test_set_norm
                        , batch_size = 1000
                        , num_workers = 1)

In [11]:
correct = 0
total = 0

network.eval()
with torch.no_grad():
    for batch in testloader:
        images, labels = batch
        preds = network(images)
        predicted = torch.max(preds, 1)[1]
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        
    print('Accuracy on test set in %:', (correct / total * 100))

Accuracy on test set in %: 86.69


### Measuring Accuracy of different classes

In [12]:
class_correct = list(0. for i in range(len(train_set.classes)))
class_total = list(0. for i in range(len(train_set.classes)))

with torch.no_grad():
    for batch in testloader:
        images, labels = batch
        preds = network(images)
        predicted = torch.max(preds, 1)[1]
        c = (predicted == labels)
        for i in range(4):
            label = labels[i]
            class_correct[label] += c[i].item()
            class_total[label] += 1
            
for i in range(len(train_set.classes)):
    print('Accuracy of', train_set.classes[i], class_correct[i] / class_total[i] * 100, '%')

Accuracy of T-shirt/top 75.0 %
Accuracy of Trouser 100.0 %
Accuracy of Pullover 75.0 %
Accuracy of Dress 80.0 %
Accuracy of Coat 66.66666666666666 %
Accuracy of Sandal 100.0 %
Accuracy of Shirt 100.0 %
Accuracy of Sneaker 100.0 %
Accuracy of Bag 100.0 %
Accuracy of Ankle boot 100.0 %


### Pruning the Model

In [13]:
import torch.nn.utils.prune as prune

In [14]:
module = network.conv1
print(list(module.named_parameters()))

[('weight', Parameter containing:
tensor([[[[ 0.1406,  0.3241, -0.3158, -0.1282, -0.3015],
          [ 0.1322,  0.1489,  0.1558, -0.1870, -0.0869],
          [-0.0017, -0.0868,  0.2813, -0.1761,  0.1028],
          [ 0.2488, -0.0814, -0.0869, -0.1579,  0.0719],
          [ 0.1737, -0.2295,  0.1568,  0.0553, -0.1848]]],


        [[[-0.1738,  0.1401,  0.3369,  0.1556,  0.1386],
          [-0.2808,  0.3735, -0.0281, -0.2210,  0.0766],
          [-0.1657,  0.2196, -0.0553, -0.0506, -0.1463],
          [-0.0950,  0.0341, -0.1670,  0.0759,  0.1428],
          [-0.0037,  0.2186, -0.0685, -0.0528, -0.0574]]],


        [[[ 0.0083,  0.1842, -0.0096,  0.1312, -0.2054],
          [-0.0536,  0.0767,  0.2360,  0.0437, -0.0317],
          [ 0.2464, -0.2670,  0.1162,  0.0054,  0.0795],
          [ 0.0985, -0.1735,  0.1174,  0.0308,  0.0866],
          [ 0.3539, -0.1757, -0.1554,  0.2130,  0.0505]]],


        [[[-0.1154, -0.0993,  0.2199,  0.1610, -0.0416],
          [-0.4733, -0.1686,  0.2232,  0.2

In [15]:
#The layer does not contain any buffers at the moment
print(list(module.named_buffers()))

[]


1) First, select a pruning method among those available in torch.nn.utils.prune or implement your own by using BasePruningMethod  
2) Then, specify module and name of parameter to prune.  
3) Use keyword arguments for the specific pruning techniques, specifying pruning parameters.  
  
Let's prune 30% of our network randomly.

In [16]:
#Prune the weights of the conv1 layer by 30%
# when amount is integer, results in number of connections pruned.
prune.random_unstructured(module, name = 'weight', amount = 0.3)

Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))

Pruning removes the `weight` from the parameters and gives a new parameter calles `weight_orig` which stores the unpruned version of the tensor. Bias was not pruned, so it stays.

In [17]:
print(list(module.named_parameters()))

[('bias', Parameter containing:
tensor([-0.0450,  0.2849, -0.0798,  0.0157, -0.0742, -0.4180],
       requires_grad=True)), ('weight_orig', Parameter containing:
tensor([[[[ 0.1406,  0.3241, -0.3158, -0.1282, -0.3015],
          [ 0.1322,  0.1489,  0.1558, -0.1870, -0.0869],
          [-0.0017, -0.0868,  0.2813, -0.1761,  0.1028],
          [ 0.2488, -0.0814, -0.0869, -0.1579,  0.0719],
          [ 0.1737, -0.2295,  0.1568,  0.0553, -0.1848]]],


        [[[-0.1738,  0.1401,  0.3369,  0.1556,  0.1386],
          [-0.2808,  0.3735, -0.0281, -0.2210,  0.0766],
          [-0.1657,  0.2196, -0.0553, -0.0506, -0.1463],
          [-0.0950,  0.0341, -0.1670,  0.0759,  0.1428],
          [-0.0037,  0.2186, -0.0685, -0.0528, -0.0574]]],


        [[[ 0.0083,  0.1842, -0.0096,  0.1312, -0.2054],
          [-0.0536,  0.0767,  0.2360,  0.0437, -0.0317],
          [ 0.2464, -0.2670,  0.1162,  0.0054,  0.0795],
          [ 0.0985, -0.1735,  0.1174,  0.0308,  0.0866],
          [ 0.3539, -0.1757, -0.

The mask, that prunes the parameter is stored as a module buffer named `weight_mask`.

In [18]:
print(list(module.named_buffers()))

[('weight_mask', tensor([[[[1., 1., 1., 0., 0.],
          [1., 1., 1., 1., 1.],
          [1., 0., 1., 0., 1.],
          [0., 1., 1., 1., 1.],
          [1., 1., 0., 1., 0.]]],


        [[[1., 1., 0., 1., 0.],
          [1., 0., 1., 0., 1.],
          [1., 0., 1., 0., 1.],
          [1., 1., 1., 1., 0.],
          [1., 1., 1., 0., 1.]]],


        [[[1., 1., 1., 1., 0.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 0., 1.],
          [1., 1., 1., 1., 0.]]],


        [[[1., 1., 1., 0., 0.],
          [0., 1., 1., 0., 1.],
          [1., 1., 0., 1., 0.],
          [1., 0., 1., 1., 1.],
          [1., 0., 1., 1., 1.]]],


        [[[0., 1., 1., 1., 1.],
          [1., 1., 0., 1., 0.],
          [0., 0., 0., 1., 1.],
          [0., 0., 1., 1., 1.],
          [1., 1., 1., 1., 1.]]],


        [[[1., 1., 0., 1., 1.],
          [1., 0., 1., 0., 1.],
          [0., 1., 0., 0., 1.],
          [1., 1., 0., 1., 0.],
          [0., 0., 0., 1., 1.]]]]))

But where is the updated weight stored now? As it is required for the forward-pass to function properly, the pruned version of the weights is stored in `weight`. It can be accessed as an attribute and is NOT stored in the parameter of the module anymore.

In [19]:
print(module.weight)

tensor([[[[ 0.1406,  0.3241, -0.3158, -0.0000, -0.0000],
          [ 0.1322,  0.1489,  0.1558, -0.1870, -0.0869],
          [-0.0017, -0.0000,  0.2813, -0.0000,  0.1028],
          [ 0.0000, -0.0814, -0.0869, -0.1579,  0.0719],
          [ 0.1737, -0.2295,  0.0000,  0.0553, -0.0000]]],


        [[[-0.1738,  0.1401,  0.0000,  0.1556,  0.0000],
          [-0.2808,  0.0000, -0.0281, -0.0000,  0.0766],
          [-0.1657,  0.0000, -0.0553, -0.0000, -0.1463],
          [-0.0950,  0.0341, -0.1670,  0.0759,  0.0000],
          [-0.0037,  0.2186, -0.0685, -0.0000, -0.0574]]],


        [[[ 0.0083,  0.1842, -0.0096,  0.1312, -0.0000],
          [-0.0536,  0.0767,  0.2360,  0.0437, -0.0317],
          [ 0.2464, -0.2670,  0.1162,  0.0054,  0.0795],
          [ 0.0985, -0.1735,  0.1174,  0.0000,  0.0866],
          [ 0.3539, -0.1757, -0.1554,  0.2130,  0.0000]]],


        [[[-0.1154, -0.0993,  0.2199,  0.0000, -0.0000],
          [-0.0000, -0.1686,  0.2232,  0.0000,  0.0517],
          [-0.2512,

Summarized:
* `name + _mask` is the binary mask which indicates pruning
* `name` usually containing original parameters is replaced with the pruned params
* `name + _orig` is a new parameter that stores the unpruned (original) parameters.

Pruning is applied prior to each forward pass with the `forward_pre_hooks`. It is acquired when the module is pruned. Only one hook is present, as only one parameter was pruned so far.

In [20]:
print(module._forward_pre_hooks)

OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured object at 0x7fa91e49cad0>)])


Let's prune the bias as well!

In [21]:
#pruning three smallest entries of the bias with L1 Norm.
prune.l1_unstructured(module, name = 'bias', amount = 3)

Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))

In [22]:
print(module._forward_pre_hooks) #we have 2 hooks now

OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured object at 0x7fa91e49cad0>), (1, <torch.nn.utils.prune.L1Unstructured object at 0x7fa933226510>)])


Now we have a `bias_orig` as well, `bias_mask` and two forward hooks.

#### Iterative Pruning

The same parameter can be tuned multiple times - the created masks are then applied in series. This mask-combination is handled by the `compute_mask` method.

In [23]:
#Using structured pruning along the output channels of conv1, based on L2 norm (?).
prune.ln_structured(module, name = "weight", amount = 0.5, n = 2, dim = 0) #half of the channels is zerod out

Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))

In [24]:
print(module.weight)

tensor([[[[ 0.0000,  0.0000, -0.0000, -0.0000, -0.0000],
          [ 0.0000,  0.0000,  0.0000, -0.0000, -0.0000],
          [-0.0000, -0.0000,  0.0000, -0.0000,  0.0000],
          [ 0.0000, -0.0000, -0.0000, -0.0000,  0.0000],
          [ 0.0000, -0.0000,  0.0000,  0.0000, -0.0000]]],


        [[[-0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
          [-0.0000,  0.0000, -0.0000, -0.0000,  0.0000],
          [-0.0000,  0.0000, -0.0000, -0.0000, -0.0000],
          [-0.0000,  0.0000, -0.0000,  0.0000,  0.0000],
          [-0.0000,  0.0000, -0.0000, -0.0000, -0.0000]]],


        [[[ 0.0083,  0.1842, -0.0096,  0.1312, -0.0000],
          [-0.0536,  0.0767,  0.2360,  0.0437, -0.0317],
          [ 0.2464, -0.2670,  0.1162,  0.0054,  0.0795],
          [ 0.0985, -0.1735,  0.1174,  0.0000,  0.0866],
          [ 0.3539, -0.1757, -0.1554,  0.2130,  0.0000]]],


        [[[-0.1154, -0.0993,  0.2199,  0.0000, -0.0000],
          [-0.0000, -0.1686,  0.2232,  0.0000,  0.0517],
          [-0.2512,

In [25]:
#History of pruning is further written:
for hook in module._forward_pre_hooks.values():
    if hook._tensor_name == 'weight':
        break
        
print(list(hook))

[<torch.nn.utils.prune.RandomUnstructured object at 0x7fa91e49cad0>, <torch.nn.utils.prune.LnStructured object at 0x7fa9332265d0>]


In [26]:
#Getting all relevant tensors of our network:
network.state_dict().keys()

odict_keys(['conv1.weight_orig', 'conv1.bias_orig', 'conv1.weight_mask', 'conv1.bias_mask', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'out.weight', 'out.bias'])

#### Remove pruning re-parametrization

= making pruning permanent, removing forward_pre_hook

In [27]:
print(list(module.named_parameters()))

[('weight_orig', Parameter containing:
tensor([[[[ 0.1406,  0.3241, -0.3158, -0.1282, -0.3015],
          [ 0.1322,  0.1489,  0.1558, -0.1870, -0.0869],
          [-0.0017, -0.0868,  0.2813, -0.1761,  0.1028],
          [ 0.2488, -0.0814, -0.0869, -0.1579,  0.0719],
          [ 0.1737, -0.2295,  0.1568,  0.0553, -0.1848]]],


        [[[-0.1738,  0.1401,  0.3369,  0.1556,  0.1386],
          [-0.2808,  0.3735, -0.0281, -0.2210,  0.0766],
          [-0.1657,  0.2196, -0.0553, -0.0506, -0.1463],
          [-0.0950,  0.0341, -0.1670,  0.0759,  0.1428],
          [-0.0037,  0.2186, -0.0685, -0.0528, -0.0574]]],


        [[[ 0.0083,  0.1842, -0.0096,  0.1312, -0.2054],
          [-0.0536,  0.0767,  0.2360,  0.0437, -0.0317],
          [ 0.2464, -0.2670,  0.1162,  0.0054,  0.0795],
          [ 0.0985, -0.1735,  0.1174,  0.0308,  0.0866],
          [ 0.3539, -0.1757, -0.1554,  0.2130,  0.0505]]],


        [[[-0.1154, -0.0993,  0.2199,  0.1610, -0.0416],
          [-0.4733, -0.1686,  0.2232,

In [28]:
print(list(module.named_buffers()))

[('weight_mask', tensor([[[[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.]]],


        [[[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.]]],


        [[[1., 1., 1., 1., 0.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 0., 1.],
          [1., 1., 1., 1., 0.]]],


        [[[1., 1., 1., 0., 0.],
          [0., 1., 1., 0., 1.],
          [1., 1., 0., 1., 0.],
          [1., 0., 1., 1., 1.],
          [1., 0., 1., 1., 1.]]],


        [[[0., 1., 1., 1., 1.],
          [1., 1., 0., 1., 0.],
          [0., 0., 0., 1., 1.],
          [0., 0., 1., 1., 1.],
          [1., 1., 1., 1., 1.]]],


        [[[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.]]]]))

In [29]:
print(module.weight)

tensor([[[[ 0.0000,  0.0000, -0.0000, -0.0000, -0.0000],
          [ 0.0000,  0.0000,  0.0000, -0.0000, -0.0000],
          [-0.0000, -0.0000,  0.0000, -0.0000,  0.0000],
          [ 0.0000, -0.0000, -0.0000, -0.0000,  0.0000],
          [ 0.0000, -0.0000,  0.0000,  0.0000, -0.0000]]],


        [[[-0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
          [-0.0000,  0.0000, -0.0000, -0.0000,  0.0000],
          [-0.0000,  0.0000, -0.0000, -0.0000, -0.0000],
          [-0.0000,  0.0000, -0.0000,  0.0000,  0.0000],
          [-0.0000,  0.0000, -0.0000, -0.0000, -0.0000]]],


        [[[ 0.0083,  0.1842, -0.0096,  0.1312, -0.0000],
          [-0.0536,  0.0767,  0.2360,  0.0437, -0.0317],
          [ 0.2464, -0.2670,  0.1162,  0.0054,  0.0795],
          [ 0.0985, -0.1735,  0.1174,  0.0000,  0.0866],
          [ 0.3539, -0.1757, -0.1554,  0.2130,  0.0000]]],


        [[[-0.1154, -0.0993,  0.2199,  0.0000, -0.0000],
          [-0.0000, -0.1686,  0.2232,  0.0000,  0.0517],
          [-0.2512,

In [31]:
#Removing reparametrization
prune.remove(module, 'weight')
print(list(module.named_parameters())) #under named parameters, 'weight' is now mentioned
                                        #, instead of weight_orig

[('bias_orig', Parameter containing:
tensor([-0.0450,  0.2849, -0.0798,  0.0157, -0.0742, -0.4180],
       requires_grad=True)), ('weight', Parameter containing:
tensor([[[[ 0.0000,  0.0000, -0.0000, -0.0000, -0.0000],
          [ 0.0000,  0.0000,  0.0000, -0.0000, -0.0000],
          [-0.0000, -0.0000,  0.0000, -0.0000,  0.0000],
          [ 0.0000, -0.0000, -0.0000, -0.0000,  0.0000],
          [ 0.0000, -0.0000,  0.0000,  0.0000, -0.0000]]],


        [[[-0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
          [-0.0000,  0.0000, -0.0000, -0.0000,  0.0000],
          [-0.0000,  0.0000, -0.0000, -0.0000, -0.0000],
          [-0.0000,  0.0000, -0.0000,  0.0000,  0.0000],
          [-0.0000,  0.0000, -0.0000, -0.0000, -0.0000]]],


        [[[ 0.0083,  0.1842, -0.0096,  0.1312, -0.0000],
          [-0.0536,  0.0767,  0.2360,  0.0437, -0.0317],
          [ 0.2464, -0.2670,  0.1162,  0.0054,  0.0795],
          [ 0.0985, -0.1735,  0.1174,  0.0000,  0.0866],
          [ 0.3539, -0.1757, -0.

In [42]:
print(list(module.named_buffers()))

[('bias_mask', tensor([0., 1., 1., 0., 0., 1.]))]


#### Pruning multiple parameters

In [None]:
#prune modules according to their type
for name, module in network.named_modules():
    
    if isinstance(module, torch.nn.Conv2d):
        prune.l1_unstructured(module, name = 'weight', amount = .2)
    elif isinstance(module, torch.nn.Linear):
        prune.l1_unstructured(module, name = 'weight', amount = .1)
        
    #...

#### Global Pruning

local = pruning tensors one by one  
global = removing a certain percentage of connections across the whole model. This means that sparsity may be higher or lower in certain layers, but overall being a specific defined percentage. This is done with `global_unstructured`.

In [None]:
parameters_to_prune = (
    (network.conv1, 'weight')
    #, ... 
)

prune.global_unstructured(
    parameters_to_prune,
    pruning_method = prune.L1Unstructured,
    amount = 0.1
)

For custom pruning methods, have a look here: https://pytorch.org/tutorials/intermediate/pruning_tutorial.html (at the bottom)