Prepare the data:

In [1]:
import math
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.autograd import Variable
from torchsummary import summary

In [2]:
# Prefer the first CUDA GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)
torch.cuda.is_available()

cuda:0


True

In [3]:
class CIFAR10RandomLabels(datasets.CIFAR10):
    """CIFAR10 dataset with optional random label corruption.

    Params
    ------
    corrupt_prob: float
        Default 0.0. The probability of a label being replaced with a
        label drawn at random from ``range(num_classes)``. The replacement
        may happen to equal the original label.
    num_classes: int
        Default 10. The number of classes in the dataset.
    **kwargs
        Forwarded unchanged to ``datasets.CIFAR10``.
    """
    def __init__(self, corrupt_prob=0.0, num_classes=10, **kwargs):
        super(CIFAR10RandomLabels, self).__init__(**kwargs)
        self.n_classes = num_classes
        # Labels are only touched when corruption was actually requested.
        if corrupt_prob > 0:
            self.corrupt_labels(corrupt_prob)

    def corrupt_labels(self, corrupt_prob):
        """Replace a random subset of ``self.targets`` with random labels.

        Each label is selected for replacement with probability
        ``corrupt_prob``. The result is stored back on ``self.targets`` as
        a tensor. The random draws are deterministic (fixed seed 12345).
        """
        labels = np.array(self.targets)
        # Use a private generator instead of np.random.seed(): this yields the
        # same draws as seeding the global legacy generator with 12345, but
        # does not reset the process-wide NumPy random state as a side effect.
        rng = np.random.RandomState(12345)
        mask = rng.rand(len(labels)) <= corrupt_prob
        rnd_labels = rng.choice(self.n_classes, mask.sum())
        labels[mask] = rnd_labels
        # Cast from NumPy integer scalars to builtin int before building the
        # tensor so the stored labels do not depend on NumPy's default
        # integer width.
        labels = [int(x) for x in labels]
        self.targets = torch.tensor(labels)

In [4]:
# Convert images to tensors, then normalise every channel with mean 0.5 and
# standard deviation 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Training split: 10% label-corruption probability, reshuffled every epoch.
trainloader = torch.utils.data.DataLoader(
    CIFAR10RandomLabels(
        root='./data', train=True, download=True,
        transform=transform, corrupt_prob=0.1,
    ),
    batch_size=64, shuffle=True, num_workers=4,
)

# Test split: same corruption probability, fixed order.
testloader = torch.utils.data.DataLoader(
    CIFAR10RandomLabels(
        root='./data', train=False, download=True,
        transform=transform, corrupt_prob=0.1,
    ),
    batch_size=100, shuffle=False, num_workers=4,
)

# Human-readable class names, indexed by label id.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

Files already downloaded and verified
Files already downloaded and verified


In [5]:
class MaskedLinear(nn.Linear):
    """``nn.Linear`` whose weight can be multiplied by a fixed pruning mask.

    Until a mask is installed the layer behaves exactly like ``nn.Linear``.
    After ``set_mask``/``prune`` the forward pass uses ``weight * mask``, so
    masked entries contribute nothing to the output and receive zero
    gradient through that product.

    The mask is a plain attribute (not a parameter or registered buffer), so
    it is not part of ``state_dict()``.
    """
    def __init__(self, in_features, out_features, bias=True):
        super(MaskedLinear, self).__init__(in_features, out_features, bias)
        # Define the attribute up front so get_mask() is safe before pruning.
        self.mask = None
        self.mask_flag = False

    def set_mask(self, mask):
        """Install ``mask`` and zero the masked-out weights in place.

        ``mask`` must be broadcastable to the weight's shape; it is detached
        and converted to the weight's device and dtype, so boolean or
        CPU-resident masks are accepted.
        """
        mask = mask.detach().to(device=self.weight.device,
                                dtype=self.weight.dtype)
        self.mask = mask
        # Edit the parameter without recording the operation in autograd.
        with torch.no_grad():
            self.weight.mul_(mask)
        self.mask_flag = True

    def get_mask(self):
        """Return the current mask, or ``None`` if no mask has been set."""
        return self.mask

    def prune(self, threshold):
        """Mask every weight whose absolute value is not above ``threshold``.

        The bias is never pruned.
        """
        self.set_mask(self.weight.detach().abs() > threshold)

    def forward(self, x):
        """Apply the linear map, using the masked weight when a mask is set."""
        weight = self.weight
        if self.mask_flag:
            # .to() is a no-op when the mask already matches the weight; it
            # keeps the layer usable if the module was moved after pruning.
            weight = weight * self.mask.to(device=weight.device,
                                           dtype=weight.dtype)
        return F.linear(x, weight, self.bias)
        

        
class MaskedConv2d(nn.Conv2d):
    """``nn.Conv2d`` whose kernel can be multiplied by a fixed pruning mask.

    Until a mask is installed the layer convolves with its raw weight.
    After ``set_mask``/``prune`` the forward pass uses ``weight * mask``, so
    masked kernel entries contribute nothing to the output and receive zero
    gradient through that product.

    The mask is a plain attribute (not a parameter or registered buffer), so
    it is not part of ``state_dict()``.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super(MaskedConv2d, self).__init__(in_channels, out_channels,
            kernel_size, stride, padding, dilation, groups, bias)
        # Define the attribute up front so get_mask() is safe before pruning.
        self.mask = None
        self.mask_flag = False

    def set_mask(self, mask):
        """Install ``mask`` and zero the masked-out kernel entries in place.

        ``mask`` must be broadcastable to the weight's shape; it is detached
        and converted to the weight's device and dtype, so boolean or
        CPU-resident masks are accepted.
        """
        mask = mask.detach().to(device=self.weight.device,
                                dtype=self.weight.dtype)
        self.mask = mask
        # Edit the parameter without recording the operation in autograd.
        with torch.no_grad():
            self.weight.mul_(mask)
        self.mask_flag = True

    def get_mask(self):
        """Return the current mask, or ``None`` if no mask has been set."""
        return self.mask

    def prune(self, threshold):
        """Mask every weight whose absolute value is not above ``threshold``.

        The bias is never pruned.
        """
        self.set_mask(self.weight.detach().abs() > threshold)

    def forward(self, x):
        """Convolve ``x``, using the masked kernel when a mask is set."""
        weight = self.weight
        if self.mask_flag:
            # .to() is a no-op when the mask already matches the weight; it
            # keeps the layer usable if the module was moved after pruning.
            weight = weight * self.mask.to(device=weight.device,
                                           dtype=weight.dtype)
        return F.conv2d(x, weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

Define the network:

In [6]:
class Net(nn.Module):
    """LeNet-style CNN built from maskable layers.

    Two 5x5 convolutions, each followed by ReLU and 2x2 max-pooling, feed
    three fully connected layers that produce 10 raw class scores. The
    first linear layer expects 16 * 5 * 5 features per sample, which is what
    the convolutional stack yields for 3x32x32 inputs.
    """
    def __init__(self):
        super(Net, self).__init__()
        linear = MaskedLinear
        conv2d = MaskedConv2d
        self.conv1 = conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = conv2d(6, 16, 5)
        self.fc1 = linear(16 * 5 * 5, 120)
        self.fc2 = linear(120, 84)
        self.fc3 = linear(84, 10)

    def forward(self, x):
        """Return the 10 unnormalised class scores for a batch of images."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten per sample. Keeping the batch dimension fixed means a
        # wrongly sized input fails in fc1 instead of being silently
        # reshaped into a different batch size.
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def weight_prune(self, pruning_perc):
        """Globally prune the smallest-magnitude weights of all masked layers.

        A single threshold is taken as the ``pruning_perc``-th percentile
        (0-100) of the absolute weights pooled over every ``MaskedConv2d``
        and ``MaskedLinear`` layer; each of those layers is then pruned with
        that threshold. Biases are left untouched.
        """
        # Select layers by type rather than by hard-coded attribute names.
        prunable = [module for module in self.modules()
                    if isinstance(module, (MaskedLinear, MaskedConv2d))]
        # Pool the magnitudes as one NumPy array instead of a Python list of
        # scalars.
        all_weights = np.concatenate(
            [module.weight.detach().abs().cpu().numpy().ravel()
             for module in prunable])
        threshold = float(np.percentile(all_weights, pruning_perc))
        print('Pruning with threshold : %.4f' % threshold)

        for module in prunable:
            module.prune(threshold)

# Build the network on the selected device and print its per-layer summary
# for a 3x32x32 input.
net = Net().to(device)
summary(net, (3, 32, 32))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
      MaskedConv2d-1            [-1, 6, 28, 28]             456
         MaxPool2d-2            [-1, 6, 14, 14]               0
      MaskedConv2d-3           [-1, 16, 10, 10]           2,416
         MaxPool2d-4             [-1, 16, 5, 5]               0
      MaskedLinear-5                  [-1, 120]          48,120
      MaskedLinear-6                   [-1, 84]          10,164
      MaskedLinear-7                   [-1, 10]             850
Total params: 62,006
Trainable params: 62,006
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.01
Forward/backward pass size (MB): 0.06
Params size (MB): 0.24
Estimated Total Size (MB): 0.31
----------------------------------------------------------------


In [7]:
#params = list(net.parameters())
#print(len(params))
#print(params[4])

In [8]:
# CrossEntropyLoss applies the softmax itself, so the network's last layer
# outputs raw scores and needs no softmax of its own.
criterion = nn.CrossEntropyLoss()
# Plain SGD with momentum over all parameters of `net`.
optimizer = optim.SGD(net.parameters(), momentum=0.9, lr=0.001)
#initial_optimizer_state_dict = optimizer.state_dict()

In [None]:
def train(model, epochs):
    """Train ``model`` for ``epochs`` passes over the training data.

    Relies on notebook-level globals: ``trainloader`` (batches),
    ``criterion`` (loss), ``optimizer`` (update rule) and ``device``. The
    global ``optimizer`` must have been built over ``model``'s parameters,
    otherwise the steps taken here update a different model.

    Prints the mean loss every 200 mini-batches and the training accuracy
    after every epoch. Returns nothing.
    """
    # Put the model in training mode explicitly: a caller may have switched
    # it to eval mode earlier (e.g. before evaluating or pruning).
    model.train()

    for epoch in range(epochs):  # loop over the dataset multiple times

        train_correct = 0
        train_total = 0
        running_loss = 0.0

        for i, data in enumerate(trainloader, 0):
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data[0].to(device), data[1].to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            if i % 200 == 199:    # print every 200 mini-batches
                print('[%d, %5d] loss: %.3f' %
                      (epoch + 1, i + 1, running_loss / 200))
                running_loss = 0.0

            # training accuracy, measured on the outputs of this very step
            _, predicted = torch.max(outputs.detach(), 1)
            train_total += labels.size(0)
            train_correct += (predicted == labels).sum().item()

        # An empty loader would otherwise raise ZeroDivisionError here.
        if train_total > 0:
            print('Train Accuracy: %.3f %%' % (100 * train_correct / train_total))

    print('Finished Training')

In [None]:
def test(model):
    """Evaluate ``model`` on the test data and print accuracy and mean loss.

    Relies on notebook-level globals: ``testloader``, ``criterion`` and
    ``device``. The model is evaluated in eval mode with gradients disabled,
    and its previous train/eval mode is restored afterwards. The reported
    loss is the mean of the per-batch losses. Returns nothing.
    """
    correct = 0
    total = 0
    test_loss = 0.0
    num_batches = 0

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for data in testloader:
                images, labels = data[0].to(device), data[1].to(device)
                outputs = model(images)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
                test_loss += criterion(outputs, labels).item()
                num_batches += 1
    finally:
        # Leave the model in whatever mode the caller had it in.
        model.train(was_training)

    # Use the observed counts instead of hard-coded sizes so the numbers
    # stay correct if the test set or batch size changes; max(..., 1) keeps
    # an empty loader from dividing by zero.
    print('Accuracy of the network on the %d test images: %.2f %%' % (
        total, 100 * correct / max(total, 1)))
    print('Test loss: %.4f ' % (test_loss / max(num_batches, 1)))

In [None]:
def calculate_size(model):
    """Print the storage needed for the model's non-zero trainable parameters.

    Only parameters with ``requires_grad`` set are counted, and only their
    non-zero entries, so pruned (zeroed) weights do not contribute. Each
    entry is charged its real element size in bytes rather than an assumed
    4 bytes, so the figure is also right for non-float32 models. The result
    is printed in MB (1 MB = 1024 ** 2 bytes); nothing is returned.
    """
    total_bytes = 0
    for param in model.parameters():
        if param.requires_grad:
            non_zeros = param.detach().nonzero().size(0)
            total_bytes += non_zeros * param.element_size()
    params_size = total_bytes / (1024 ** 2)
    print("Params size (MB): %0.3f" % params_size)

In [None]:
# Create the checkpoint directory before the long training run, so a missing
# folder cannot make torch.save fail after all epochs have completed.
os.makedirs("models", exist_ok=True)
train(net, 600)
torch.save(net.state_dict(), "models/LeNet.pt")
test(net)
calculate_size(net)

[1,   200] loss: 2.302
[1,   400] loss: 2.301
[1,   600] loss: 2.296
Train Accuracy: 12.658 %
[2,   200] loss: 2.263
[2,   400] loss: 2.230
[2,   600] loss: 2.190
Train Accuracy: 18.688 %
[3,   200] loss: 2.140
[3,   400] loss: 2.102
[3,   600] loss: 2.078
Train Accuracy: 24.624 %
[4,   200] loss: 2.032
[4,   400] loss: 1.997
[4,   600] loss: 1.972
Train Accuracy: 28.476 %
[5,   200] loss: 1.903
[5,   400] loss: 1.875
[5,   600] loss: 1.849
Train Accuracy: 33.500 %
[6,   200] loss: 1.802
[6,   400] loss: 1.788
[6,   600] loss: 1.784
Train Accuracy: 37.190 %
[7,   200] loss: 1.750
[7,   400] loss: 1.748
[7,   600] loss: 1.730
Train Accuracy: 39.482 %
[8,   200] loss: 1.706
[8,   400] loss: 1.700
[8,   600] loss: 1.697
Train Accuracy: 41.306 %
[9,   200] loss: 1.671
[9,   400] loss: 1.660
[9,   600] loss: 1.665
Train Accuracy: 42.950 %
[10,   200] loss: 1.641
[10,   400] loss: 1.653
[10,   600] loss: 1.626
Train Accuracy: 44.092 %
[11,   200] loss: 1.595
[11,   400] loss: 1.623
[11,   60

Train Accuracy: 76.382 %
[86,   200] loss: 0.691
[86,   400] loss: 0.722
[86,   600] loss: 0.733
Train Accuracy: 76.840 %
[87,   200] loss: 0.683
[87,   400] loss: 0.735
[87,   600] loss: 0.730
Train Accuracy: 77.118 %
[88,   200] loss: 0.694
[88,   400] loss: 0.703
[88,   600] loss: 0.714
Train Accuracy: 77.168 %
[89,   200] loss: 0.671
[89,   400] loss: 0.718
[89,   600] loss: 0.728
Train Accuracy: 77.508 %
[90,   200] loss: 0.663
[90,   400] loss: 0.693
[90,   600] loss: 0.707
Train Accuracy: 77.722 %
[91,   200] loss: 0.650
[91,   400] loss: 0.683
[91,   600] loss: 0.694
Train Accuracy: 77.798 %
[92,   200] loss: 0.665
[92,   400] loss: 0.663
[92,   600] loss: 0.696
Train Accuracy: 78.100 %
[93,   200] loss: 0.632
[93,   400] loss: 0.650
[93,   600] loss: 0.695
Train Accuracy: 78.354 %
[94,   200] loss: 0.629
[94,   400] loss: 0.665
[94,   600] loss: 0.678
Train Accuracy: 78.332 %
[95,   200] loss: 0.645
[95,   400] loss: 0.658
[95,   600] loss: 0.678
Train Accuracy: 78.384 %
[96, 

[168,   400] loss: 0.291
[168,   600] loss: 0.332
Train Accuracy: 89.088 %
[169,   200] loss: 0.263
[169,   400] loss: 0.280
[169,   600] loss: 0.322
Train Accuracy: 89.550 %
[170,   200] loss: 0.260
[170,   400] loss: 0.300
[170,   600] loss: 0.334
Train Accuracy: 89.098 %
[171,   200] loss: 0.277
[171,   400] loss: 0.281
[171,   600] loss: 0.319
Train Accuracy: 89.360 %
[172,   200] loss: 0.277
[172,   400] loss: 0.297
[172,   600] loss: 0.343
Train Accuracy: 88.668 %
[173,   200] loss: 0.280
[173,   400] loss: 0.290
[173,   600] loss: 0.315
Train Accuracy: 89.106 %
[174,   200] loss: 0.270
[174,   400] loss: 0.306
[174,   600] loss: 0.309
Train Accuracy: 89.252 %
[175,   200] loss: 0.267
[175,   400] loss: 0.294
[175,   600] loss: 0.303
Train Accuracy: 89.512 %
[176,   200] loss: 0.259
[176,   400] loss: 0.300
[176,   600] loss: 0.319
Train Accuracy: 89.476 %
[177,   200] loss: 0.253
[177,   400] loss: 0.291
[177,   600] loss: 0.312
Train Accuracy: 89.712 %
[178,   200] loss: 0.250


[250,   400] loss: 0.190
[250,   600] loss: 0.212
Train Accuracy: 93.024 %
[251,   200] loss: 0.149
[251,   400] loss: 0.148
[251,   600] loss: 0.202
Train Accuracy: 93.668 %
[252,   200] loss: 0.191
[252,   400] loss: 0.196
[252,   600] loss: 0.202
Train Accuracy: 92.474 %
[253,   200] loss: 0.196
[253,   400] loss: 0.187
[253,   600] loss: 0.212
Train Accuracy: 92.742 %
[254,   200] loss: 0.150
[254,   400] loss: 0.167
[254,   600] loss: 0.205
Train Accuracy: 93.410 %
[255,   200] loss: 0.171
[255,   400] loss: 0.150
[255,   600] loss: 0.171
Train Accuracy: 93.930 %
[256,   200] loss: 0.146
[256,   400] loss: 0.142
[256,   600] loss: 0.151
Train Accuracy: 94.292 %
[257,   200] loss: 0.139
[257,   400] loss: 0.169
[257,   600] loss: 0.206
Train Accuracy: 93.420 %
[258,   200] loss: 0.202
[258,   400] loss: 0.212
[258,   600] loss: 0.236
Train Accuracy: 92.282 %
[259,   200] loss: 0.189
[259,   400] loss: 0.201
[259,   600] loss: 0.207
Train Accuracy: 92.586 %
[260,   200] loss: 0.162


[332,   400] loss: 0.111
[332,   600] loss: 0.115
Train Accuracy: 95.900 %
[333,   200] loss: 0.103
[333,   400] loss: 0.101
[333,   600] loss: 0.110
Train Accuracy: 95.924 %
[334,   200] loss: 0.140
[334,   400] loss: 0.137
[334,   600] loss: 0.181
Train Accuracy: 94.268 %
[335,   200] loss: 0.140
[335,   400] loss: 0.170
[335,   600] loss: 0.211
Train Accuracy: 93.674 %
[336,   200] loss: 0.157
[336,   400] loss: 0.135
[336,   600] loss: 0.199
Train Accuracy: 93.784 %
[337,   200] loss: 0.164
[337,   400] loss: 0.174
[337,   600] loss: 0.171
Train Accuracy: 94.176 %
[338,   200] loss: 0.164
[338,   400] loss: 0.130
[338,   600] loss: 0.175
Train Accuracy: 94.586 %
[339,   200] loss: 0.100
[339,   400] loss: 0.089
[339,   600] loss: 0.103
Train Accuracy: 96.076 %
[340,   200] loss: 0.127
[340,   400] loss: 0.120
[340,   600] loss: 0.157
Train Accuracy: 95.114 %
[341,   200] loss: 0.115
[341,   400] loss: 0.135
[341,   600] loss: 0.132
Train Accuracy: 95.352 %
[342,   200] loss: 0.102


Train Accuracy: 100.000 %
[414,   200] loss: 0.001
[414,   400] loss: 0.001
[414,   600] loss: 0.001
Train Accuracy: 100.000 %
[415,   200] loss: 0.001
[415,   400] loss: 0.001
[415,   600] loss: 0.001
Train Accuracy: 100.000 %
[416,   200] loss: 0.001
[416,   400] loss: 0.001
[416,   600] loss: 0.001
Train Accuracy: 100.000 %
[417,   200] loss: 0.001
[417,   400] loss: 0.001
[417,   600] loss: 0.001
Train Accuracy: 100.000 %
[418,   200] loss: 0.001
[418,   400] loss: 0.001
[418,   600] loss: 0.001
Train Accuracy: 100.000 %
[419,   200] loss: 0.001
[419,   400] loss: 0.001
[419,   600] loss: 0.001
Train Accuracy: 100.000 %
[420,   200] loss: 0.001
[420,   400] loss: 0.001
[420,   600] loss: 0.001
Train Accuracy: 100.000 %
[421,   200] loss: 0.001
[421,   400] loss: 0.001
[421,   600] loss: 0.001
Train Accuracy: 100.000 %
[422,   200] loss: 0.001
[422,   400] loss: 0.001
[422,   600] loss: 0.001
Train Accuracy: 100.000 %
[423,   200] loss: 0.001
[423,   400] loss: 0.001
[423,   600] lo

[495,   200] loss: 0.000
[495,   400] loss: 0.000
[495,   600] loss: 0.000
Train Accuracy: 100.000 %
[496,   200] loss: 0.000
[496,   400] loss: 0.000
[496,   600] loss: 0.000
Train Accuracy: 100.000 %
[497,   200] loss: 0.000
[497,   400] loss: 0.000
[497,   600] loss: 0.000
Train Accuracy: 100.000 %
[498,   200] loss: 0.000
[498,   400] loss: 0.000
[498,   600] loss: 0.000
Train Accuracy: 100.000 %
[499,   200] loss: 0.000
[499,   400] loss: 0.000
[499,   600] loss: 0.000
Train Accuracy: 100.000 %
[500,   200] loss: 0.000
[500,   400] loss: 0.000
[500,   600] loss: 0.000
Train Accuracy: 100.000 %
[501,   200] loss: 0.000
[501,   400] loss: 0.000
[501,   600] loss: 0.000
Train Accuracy: 100.000 %
[502,   200] loss: 0.000
[502,   400] loss: 0.000
[502,   600] loss: 0.000
Train Accuracy: 100.000 %
[503,   200] loss: 0.000
[503,   400] loss: 0.000
[503,   600] loss: 0.000
Train Accuracy: 100.000 %
[504,   200] loss: 0.000
[504,   400] loss: 0.000
[504,   600] loss: 0.000
Train Accuracy: 

[576,   400] loss: 0.000
[576,   600] loss: 0.000
Train Accuracy: 100.000 %
[577,   200] loss: 0.000
[577,   400] loss: 0.000
[577,   600] loss: 0.000
Train Accuracy: 100.000 %
[578,   200] loss: 0.000
[578,   400] loss: 0.000
[578,   600] loss: 0.000
Train Accuracy: 100.000 %
[579,   200] loss: 0.000
[579,   400] loss: 0.000
[579,   600] loss: 0.000
Train Accuracy: 100.000 %
[580,   200] loss: 0.000
[580,   400] loss: 0.000
[580,   600] loss: 0.000
Train Accuracy: 100.000 %
[581,   200] loss: 0.000
[581,   400] loss: 0.000
[581,   600] loss: 0.000
Train Accuracy: 100.000 %
[582,   200] loss: 0.000
[582,   400] loss: 0.000
[582,   600] loss: 0.000
Train Accuracy: 100.000 %
[583,   200] loss: 0.000
[583,   400] loss: 0.000
[583,   600] loss: 0.000
Train Accuracy: 100.000 %
[584,   200] loss: 0.000
[584,   400] loss: 0.000
[584,   600] loss: 0.000
Train Accuracy: 100.000 %
[585,   200] loss: 0.000
[585,   400] loss: 0.000
[585,   600] loss: 0.000
Train Accuracy: 100.000 %
[586,   200] lo

In [None]:
# Rebuild the architecture and restore the trained weights. map_location
# lets a checkpoint saved on a GPU load on whichever device is in use now.
net_p = Net()
net_p.load_state_dict(torch.load("models/LeNet.pt", map_location=device))
net_p.to(device)
net_p.eval()
# Prune only after the move to `device`, so the masks are created from
# weights that already live there. 50 = prune the 50th percentile of
# weight magnitudes.
net_p.weight_prune(50)
test(net_p)
#summary(net, (3, 32, 32))
calculate_size(net_p)

Pruning with threshold : 0.1145
Accuracy of the network on the 10000 test images: 36.30 %
Test loss: 15.6036 
Params size (MB): 0.119


In [None]:
# train() steps the notebook-level `optimizer`, so rebind it to the pruned
# model's parameters (with fresh momentum state) before fine-tuning.
#optimizer.load_state_dict(initial_optimizer_state_dict) 
optimizer = optim.SGD(net_p.parameters(), momentum=0.9, lr=0.001)
train(net_p, 500)
test(net_p)

[1,   200] loss: 3.341
[1,   400] loss: 1.934
[1,   600] loss: 1.550
Train Accuracy: 64.186 %
[2,   200] loss: 0.985
[2,   400] loss: 0.932
[2,   600] loss: 0.953
Train Accuracy: 72.400 %
[3,   200] loss: 0.732
[3,   400] loss: 0.725
[3,   600] loss: 0.731
Train Accuracy: 76.090 %
[4,   200] loss: 0.611
[4,   400] loss: 0.619
[4,   600] loss: 0.626
Train Accuracy: 78.372 %
[5,   200] loss: 0.552
[5,   400] loss: 0.566
[5,   600] loss: 0.574
Train Accuracy: 79.782 %
[6,   200] loss: 0.506
[6,   400] loss: 0.524
[6,   600] loss: 0.560
Train Accuracy: 81.014 %
[7,   200] loss: 0.486
[7,   400] loss: 0.506
[7,   600] loss: 0.499
Train Accuracy: 81.672 %
[8,   200] loss: 0.452
[8,   400] loss: 0.490
[8,   600] loss: 0.501
Train Accuracy: 82.332 %
[9,   200] loss: 0.455
[9,   400] loss: 0.451
[9,   600] loss: 0.471
Train Accuracy: 82.890 %
[10,   200] loss: 0.431
[10,   400] loss: 0.440
[10,   600] loss: 0.463
Train Accuracy: 83.464 %
[11,   200] loss: 0.409
[11,   400] loss: 0.441
[11,   60

Train Accuracy: 89.322 %
[86,   200] loss: 0.238
[86,   400] loss: 0.246
[86,   600] loss: 0.292
Train Accuracy: 89.830 %
[87,   200] loss: 0.239
[87,   400] loss: 0.254
[87,   600] loss: 0.301
Train Accuracy: 89.646 %
[88,   200] loss: 0.219
[88,   400] loss: 0.252
[88,   600] loss: 0.302
Train Accuracy: 89.858 %
[89,   200] loss: 0.256
[89,   400] loss: 0.275
[89,   600] loss: 0.283
Train Accuracy: 89.604 %
[90,   200] loss: 0.237
[90,   400] loss: 0.251
[90,   600] loss: 0.301
Train Accuracy: 89.630 %
[91,   200] loss: 0.260
[91,   400] loss: 0.259
[91,   600] loss: 0.276
Train Accuracy: 89.514 %
[92,   200] loss: 0.254
[92,   400] loss: 0.255
[92,   600] loss: 0.293
Train Accuracy: 89.598 %
[93,   200] loss: 0.240
[93,   400] loss: 0.253
[93,   600] loss: 0.275
Train Accuracy: 90.124 %
[94,   200] loss: 0.243
[94,   400] loss: 0.253
[94,   600] loss: 0.266
Train Accuracy: 90.132 %
[95,   200] loss: 0.235
[95,   400] loss: 0.246
[95,   600] loss: 0.273
Train Accuracy: 89.992 %
[96, 