Prepare the data:

In [1]:
import math
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.autograd import Variable
from torchsummary import summary

In [2]:
# Select the compute device: the first CUDA GPU when one is visible to
# PyTorch, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)
# Bare last expression so the notebook shows whether CUDA is available.
torch.cuda.is_available()

cuda:0


True

In [3]:
# Per-image preprocessing: convert to a tensor, then normalise each of the
# three channels with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# Training split: shuffled mini-batches of 64 images.
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=64, shuffle=True, num_workers=4)

# Test split: fixed order, batches of 100 images.
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=False, num_workers=4)

# Human-readable class names, indexed by label id.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
class MaskedLinear(nn.Linear):
    """Fully connected layer whose weights can be pruned with a binary mask.

    Until a mask is installed the layer behaves exactly like ``nn.Linear``.
    After ``set_mask``/``prune`` the forward pass multiplies the weight by the
    mask, so masked weights contribute nothing and receive zero gradient.

    Attributes:
        mask_flag: ``True`` once a mask has been installed.
        mask: The installed mask tensor, or ``None`` when no mask is set.
    """

    def __init__(self, in_features, out_features, bias=True):
        super(MaskedLinear, self).__init__(in_features, out_features, bias)
        self.mask_flag = False
        # Registered as a non-persistent buffer: it moves with the module on
        # .to()/.cuda()/.double(), but stays out of state_dict() so saved
        # checkpoints keep the same keys as before.
        self.register_buffer('mask', None, persistent=False)

    def set_mask(self, mask):
        """Install ``mask`` and zero the weights it removes.

        Args:
            mask: Tensor broadcastable to the weight shape; entries equal to
                0 disable the corresponding weight, entries equal to 1 keep it.
        """
        # Align with the weight so the product in forward() never mixes
        # devices or dtypes, whatever the caller passed in.
        mask = mask.detach().to(device=self.weight.device,
                                dtype=self.weight.dtype)
        with torch.no_grad():
            self.weight.mul_(mask)
        self.mask = mask
        self.mask_flag = True

    def get_mask(self):
        """Return the installed mask, or ``None`` if the layer is unpruned."""
        return self.mask

    def prune(self, threshold):
        """Mask every weight whose magnitude is not greater than ``threshold``.

        The bias is never pruned.
        """
        keep = self.weight.detach().abs() > threshold
        self.set_mask(keep)

    def forward(self, x):
        """Apply the linear map, using the masked weight when a mask is set."""
        if self.mask_flag:
            return F.linear(x, self.weight * self.mask, self.bias)
        return F.linear(x, self.weight, self.bias)
        

        
class MaskedConv2d(nn.Conv2d):
    """2-D convolution whose kernel weights can be pruned with a binary mask.

    Until a mask is installed the layer behaves exactly like ``nn.Conv2d``
    with the same arguments. After ``set_mask``/``prune`` the forward pass
    multiplies the kernel by the mask, so masked weights contribute nothing
    and receive zero gradient.

    Attributes:
        mask_flag: ``True`` once a mask has been installed.
        mask: The installed mask tensor, or ``None`` when no mask is set.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super(MaskedConv2d, self).__init__(in_channels, out_channels,
            kernel_size, stride, padding, dilation, groups, bias)
        self.mask_flag = False
        # Registered as a non-persistent buffer: it moves with the module on
        # .to()/.cuda()/.double(), but stays out of state_dict() so saved
        # checkpoints keep the same keys as before.
        self.register_buffer('mask', None, persistent=False)

    def set_mask(self, mask):
        """Install ``mask`` and zero the kernel weights it removes.

        Args:
            mask: Tensor broadcastable to the weight shape; entries equal to
                0 disable the corresponding weight, entries equal to 1 keep it.
        """
        # Align with the weight so the product in forward() never mixes
        # devices or dtypes, whatever the caller passed in.
        mask = mask.detach().to(device=self.weight.device,
                                dtype=self.weight.dtype)
        with torch.no_grad():
            self.weight.mul_(mask)
        self.mask = mask
        self.mask_flag = True

    def get_mask(self):
        """Return the installed mask, or ``None`` if the layer is unpruned."""
        return self.mask

    def prune(self, threshold):
        """Mask every weight whose magnitude is not greater than ``threshold``.

        The bias is never pruned.
        """
        keep = self.weight.detach().abs() > threshold
        self.set_mask(keep)

    def forward(self, x):
        """Convolve ``x``, using the masked kernel when a mask is set."""
        weight = self.weight * self.mask if self.mask_flag else self.weight
        return F.conv2d(x, weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

Define the network:

In [5]:
class Net(nn.Module):
    """LeNet-style CNN built from maskable layers so it can be weight-pruned.

    Two conv + max-pool stages feed three fully connected layers; the last
    layer outputs 10 raw scores (no softmax). The first fully connected layer
    expects 16 * 5 * 5 features, which is what 3x32x32 inputs produce.
    """

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = MaskedConv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = MaskedConv2d(6, 16, 5)
        self.fc1 = MaskedLinear(16 * 5 * 5, 512)
        self.fc2 = MaskedLinear(512, 84)
        self.fc3 = MaskedLinear(84, 10)

    def forward(self, x):
        """Return the 10 class scores for a batch of images."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten per sample (not with a fixed feature count) so an input of
        # the wrong spatial size fails loudly in fc1 instead of being silently
        # folded into the batch dimension.
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

    def weight_prune(self, pruning_perc):
        """Globally prune the smallest-magnitude weights of the network.

        A single magnitude threshold is taken at the ``pruning_perc``-th
        percentile over all non-1-D parameters (i.e. weights, not biases) and
        every maskable layer is pruned against it.

        Args:
            pruning_perc: Percentile in [0, 100] of weights to remove.
        """
        # Gather magnitudes as arrays and concatenate once; building a Python
        # list of individual scalars is needlessly slow.
        magnitudes = [
            params.detach().abs().cpu().numpy().ravel()
            for params in self.parameters()
            if params.dim() != 1
        ]
        threshold = np.percentile(np.concatenate(magnitudes), pruning_perc)
        print('Pruning with threshold : %.4f' % threshold)

        # Select layers by type rather than by a hard-coded name list, so a
        # newly added masked layer is not silently skipped.
        for module in self.modules():
            if isinstance(module, (MaskedLinear, MaskedConv2d)):
                module.prune(threshold)

# Build the baseline network on the selected device (Module.to returns the
# module itself) and print a per-layer summary for one 3x32x32 input.
net = Net().to(device)
summary(net, (3, 32, 32))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
      MaskedConv2d-1            [-1, 6, 28, 28]             456
         MaxPool2d-2            [-1, 6, 14, 14]               0
      MaskedConv2d-3           [-1, 16, 10, 10]           2,416
         MaxPool2d-4             [-1, 16, 5, 5]               0
      MaskedLinear-5                  [-1, 512]         205,312
      MaskedLinear-6                   [-1, 84]          43,092
      MaskedLinear-7                   [-1, 10]             850
Total params: 252,126
Trainable params: 252,126
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.01
Forward/backward pass size (MB): 0.06
Params size (MB): 0.96
Estimated Total Size (MB): 1.04
----------------------------------------------------------------


In [6]:
#params = list(net.parameters())
#print(len(params))
#print(params[4])

In [7]:
# Cross-entropy loss applies the softmax itself, so the network's last layer
# outputs raw scores.
criterion = nn.CrossEntropyLoss()
# Plain SGD with momentum over every parameter of the baseline network.
optimizer = optim.SGD(params=net.parameters(), momentum=0.9, lr=0.001)
#initial_optimizer_state_dict = optimizer.state_dict()

In [8]:
def train(model, epochs, loader=None, opt=None, loss_fn=None, log_every=200):
    """Train ``model`` for ``epochs`` passes over the training data.

    Prints the mean loss every ``log_every`` mini-batches and the training
    accuracy after each epoch.

    Args:
        model: Network to optimise; it is switched to training mode.
        epochs: Number of passes over ``loader``.
        loader: Iterable of ``(inputs, labels)`` batches. Defaults to the
            notebook-level ``trainloader``.
        opt: Optimizer to step. Defaults to the notebook-level ``optimizer``
            as bound at call time.
        loss_fn: Loss callable. Defaults to the notebook-level ``criterion``.
        log_every: Positive number of mini-batches between loss printouts.

    Raises:
        ValueError: If an epoch yields no samples.
    """
    loader = trainloader if loader is None else loader
    opt = optimizer if opt is None else opt
    loss_fn = criterion if loss_fn is None else loss_fn

    # Send batches to wherever the model lives instead of trusting a global.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    # The caller may have left the model in eval mode (e.g. after pruning).
    model.train()

    for epoch in range(epochs):  # loop over the dataset multiple times
        train_correct = 0
        train_total = 0
        running_loss = 0.0

        for i, data in enumerate(loader, 0):
            # data is a sequence of [inputs, labels]
            inputs, labels = data[0].to(run_device), data[1].to(run_device)

            # forward + backward + optimize
            opt.zero_grad()
            outputs = model(inputs)
            loss = loss_fn(outputs, labels)
            loss.backward()
            opt.step()

            # running loss, reported every `log_every` mini-batches
            running_loss += loss.item()
            if i % log_every == log_every - 1:
                print('[%d, %5d] loss: %.3f' %
                      (epoch + 1, i + 1, running_loss / log_every))
                running_loss = 0.0

            # training accuracy on the batch just seen
            _, predicted = torch.max(outputs.detach(), 1)
            train_total += labels.size(0)
            train_correct += (predicted == labels).sum().item()

        if train_total == 0:
            raise ValueError("training loader yielded no samples")
        print('Train Accuracy: %.3f %%' % (100 * train_correct / train_total))

    print('Finished Training')

In [None]:
def test(model):
    correct = 0
    total = 0
    test_loss = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data[0].to(device), data[1].to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            test_loss += criterion(outputs, labels).item()
    
    print('Accuracy of the network on the 10000 test images: %.2f %%' % (
        100 * correct / total))
    print('Test loss: %.4f ' % (test_loss / 100))

In [None]:
def calculate_size(model):
    """Print the storage taken by the non-zero trainable parameters, in MB.

    Only parameters with ``requires_grad`` set are counted. Each non-zero
    entry is charged its tensor's own element size, so the figure is also
    correct for models that are not float32.
    """
    total_bytes = 0
    for param in model.parameters():
        if param.requires_grad:
            nonzero_count = int((param.detach() != 0).sum().item())
            total_bytes += nonzero_count * param.element_size()
    params_size = total_bytes / (1024 ** 2)
    print("Params size (MB): %0.3f" % params_size)

In [None]:
# Train the dense baseline, checkpoint it, then report test accuracy and size.
# Create the checkpoint directory up front so the save cannot fail on a
# missing folder after the long training run has finished.
os.makedirs("models", exist_ok=True)
train(net, 200)
torch.save(net.state_dict(), "models/LeNet.pt")
test(net)
calculate_size(net)

[1,   200] loss: 2.302
[1,   400] loss: 2.301
[1,   600] loss: 2.297
Train Accuracy: 14.380 %
[2,   200] loss: 2.258
[2,   400] loss: 2.164
[2,   600] loss: 2.090
Train Accuracy: 23.658 %
[3,   200] loss: 1.932
[3,   400] loss: 1.879
[3,   600] loss: 1.814
Train Accuracy: 33.184 %
[4,   200] loss: 1.743
[4,   400] loss: 1.687
[4,   600] loss: 1.662
Train Accuracy: 38.952 %
[5,   200] loss: 1.613
[5,   400] loss: 1.592
[5,   600] loss: 1.588
Train Accuracy: 42.526 %
[6,   200] loss: 1.541
[6,   400] loss: 1.512
[6,   600] loss: 1.506
Train Accuracy: 45.642 %
[7,   200] loss: 1.466
[7,   400] loss: 1.460
[7,   600] loss: 1.440
Train Accuracy: 47.824 %
[8,   200] loss: 1.407
[8,   400] loss: 1.398
[8,   600] loss: 1.411
Train Accuracy: 50.146 %
[9,   200] loss: 1.355
[9,   400] loss: 1.343
[9,   600] loss: 1.346
Train Accuracy: 51.814 %
[10,   200] loss: 1.321
[10,   400] loss: 1.290
[10,   600] loss: 1.292
Train Accuracy: 53.580 %
[11,   200] loss: 1.258
[11,   400] loss: 1.264
[11,   60

Train Accuracy: 100.000 %
[86,   200] loss: 0.001
[86,   400] loss: 0.001
[86,   600] loss: 0.001
Train Accuracy: 100.000 %
[87,   200] loss: 0.001
[87,   400] loss: 0.001
[87,   600] loss: 0.001
Train Accuracy: 100.000 %
[88,   200] loss: 0.001
[88,   400] loss: 0.001
[88,   600] loss: 0.001
Train Accuracy: 100.000 %
[89,   200] loss: 0.001
[89,   400] loss: 0.001
[89,   600] loss: 0.001
Train Accuracy: 100.000 %
[90,   200] loss: 0.001
[90,   400] loss: 0.001
[90,   600] loss: 0.001
Train Accuracy: 100.000 %
[91,   200] loss: 0.001
[91,   400] loss: 0.001
[91,   600] loss: 0.001
Train Accuracy: 100.000 %
[92,   200] loss: 0.001
[92,   400] loss: 0.001
[92,   600] loss: 0.001
Train Accuracy: 100.000 %
[93,   200] loss: 0.001
[93,   400] loss: 0.001
[93,   600] loss: 0.001
Train Accuracy: 100.000 %
[94,   200] loss: 0.001
[94,   400] loss: 0.001
[94,   600] loss: 0.001
Train Accuracy: 100.000 %
[95,   200] loss: 0.001
[95,   400] loss: 0.001
[95,   600] loss: 0.001
Train Accuracy: 100.

[167,   600] loss: 0.000
Train Accuracy: 100.000 %
[168,   200] loss: 0.000
[168,   400] loss: 0.000
[168,   600] loss: 0.000
Train Accuracy: 100.000 %
[169,   200] loss: 0.000
[169,   400] loss: 0.000
[169,   600] loss: 0.000
Train Accuracy: 100.000 %
[170,   200] loss: 0.000
[170,   400] loss: 0.000
[170,   600] loss: 0.000
Train Accuracy: 100.000 %
[171,   200] loss: 0.000
[171,   400] loss: 0.000
[171,   600] loss: 0.000
Train Accuracy: 100.000 %
[172,   200] loss: 0.000
[172,   400] loss: 0.000
[172,   600] loss: 0.000
Train Accuracy: 100.000 %
[173,   200] loss: 0.000
[173,   400] loss: 0.000
[173,   600] loss: 0.000
Train Accuracy: 100.000 %
[174,   200] loss: 0.000
[174,   400] loss: 0.000
[174,   600] loss: 0.000
Train Accuracy: 100.000 %
[175,   200] loss: 0.000
[175,   400] loss: 0.000
[175,   600] loss: 0.000
Train Accuracy: 100.000 %
[176,   200] loss: 0.000
[176,   400] loss: 0.000
[176,   600] loss: 0.000
Train Accuracy: 100.000 %
[177,   200] loss: 0.000
[177,   400] lo

In [43]:
# Reload the trained baseline into a fresh network, prune it at the 97.5th
# percentile of weight magnitude, and measure accuracy and size before any
# fine-tuning.
net_p = Net()
# map_location lets a checkpoint saved from the GPU load on a CPU-only machine.
net_p.load_state_dict(torch.load("models/LeNet.pt", map_location=device))
net_p.to(device)
net_p.eval()
net_p.weight_prune(97.5)
test(net_p)
#summary(net, (3, 32, 32))
calculate_size(net_p)

Pruning with threshold : 0.1084
Accuracy of the network on the 10000 test images: 14.99 %
Test loss: 4.3909 
Params size (MB): 0.026


In [44]:
# Fine-tune the pruned network. train() reads the notebook-level `optimizer`,
# so it is rebound here to a fresh SGD instance over net_p's parameters.
#optimizer.load_state_dict(initial_optimizer_state_dict) 
optimizer = optim.SGD(params=net_p.parameters(), momentum=0.9, lr=0.001)
train(net_p, 500)
test(net_p)

[1,   200] loss: 2.042
[1,   400] loss: 1.648
[1,   600] loss: 1.526
Train Accuracy: 40.438 %
[2,   200] loss: 1.427
[2,   400] loss: 1.412
[2,   600] loss: 1.389
Train Accuracy: 50.190 %
[3,   200] loss: 1.337
[3,   400] loss: 1.338
[3,   600] loss: 1.304
Train Accuracy: 52.870 %
[4,   200] loss: 1.291
[4,   400] loss: 1.294
[4,   600] loss: 1.280
Train Accuracy: 54.622 %
[5,   200] loss: 1.263
[5,   400] loss: 1.249
[5,   600] loss: 1.249
Train Accuracy: 55.782 %
[6,   200] loss: 1.226
[6,   400] loss: 1.234
[6,   600] loss: 1.218
Train Accuracy: 56.588 %
[7,   200] loss: 1.215
[7,   400] loss: 1.207
[7,   600] loss: 1.212
Train Accuracy: 57.458 %
[8,   200] loss: 1.187
[8,   400] loss: 1.188
[8,   600] loss: 1.210
Train Accuracy: 57.942 %
[9,   200] loss: 1.178
[9,   400] loss: 1.178
[9,   600] loss: 1.196
Train Accuracy: 58.388 %
[10,   200] loss: 1.167
[10,   400] loss: 1.173
[10,   600] loss: 1.171
Train Accuracy: 58.862 %
[11,   200] loss: 1.156
[11,   400] loss: 1.147
[11,   60

Train Accuracy: 66.080 %
[86,   200] loss: 0.966
[86,   400] loss: 0.968
[86,   600] loss: 0.966
Train Accuracy: 66.090 %
[87,   200] loss: 0.959
[87,   400] loss: 0.952
[87,   600] loss: 0.986
Train Accuracy: 66.166 %
[88,   200] loss: 0.955
[88,   400] loss: 0.962
[88,   600] loss: 0.962
Train Accuracy: 66.286 %
[89,   200] loss: 0.969
[89,   400] loss: 0.967
[89,   600] loss: 0.938
Train Accuracy: 66.168 %
[90,   200] loss: 0.947
[90,   400] loss: 0.970
[90,   600] loss: 0.978
Train Accuracy: 66.104 %
[91,   200] loss: 0.944
[91,   400] loss: 0.960
[91,   600] loss: 0.969
Train Accuracy: 66.152 %
[92,   200] loss: 0.952
[92,   400] loss: 0.951
[92,   600] loss: 0.961
Train Accuracy: 66.360 %
[93,   200] loss: 0.960
[93,   400] loss: 0.948
[93,   600] loss: 0.967
Train Accuracy: 66.308 %
[94,   200] loss: 0.956
[94,   400] loss: 0.951
[94,   600] loss: 0.969
Train Accuracy: 66.306 %
[95,   200] loss: 0.960
[95,   400] loss: 0.960
[95,   600] loss: 0.960
Train Accuracy: 66.414 %
[96, 

[168,   400] loss: 0.921
[168,   600] loss: 0.900
Train Accuracy: 68.236 %
[169,   200] loss: 0.907
[169,   400] loss: 0.898
[169,   600] loss: 0.923
Train Accuracy: 68.334 %
[170,   200] loss: 0.906
[170,   400] loss: 0.908
[170,   600] loss: 0.904
Train Accuracy: 68.120 %
[171,   200] loss: 0.915
[171,   400] loss: 0.898
[171,   600] loss: 0.912
Train Accuracy: 68.136 %
[172,   200] loss: 0.924
[172,   400] loss: 0.899
[172,   600] loss: 0.908
Train Accuracy: 68.370 %
[173,   200] loss: 0.911
[173,   400] loss: 0.899
[173,   600] loss: 0.906
Train Accuracy: 68.362 %
[174,   200] loss: 0.909
[174,   400] loss: 0.911
[174,   600] loss: 0.911
Train Accuracy: 68.242 %
[175,   200] loss: 0.900
[175,   400] loss: 0.924
[175,   600] loss: 0.906
Train Accuracy: 68.234 %
[176,   200] loss: 0.891
[176,   400] loss: 0.905
[176,   600] loss: 0.927
Train Accuracy: 68.406 %
[177,   200] loss: 0.902
[177,   400] loss: 0.902
[177,   600] loss: 0.919
Train Accuracy: 68.316 %
[178,   200] loss: 0.904


[250,   400] loss: 0.869
[250,   600] loss: 0.872
Train Accuracy: 69.244 %
[251,   200] loss: 0.876
[251,   400] loss: 0.884
[251,   600] loss: 0.893
Train Accuracy: 69.136 %
[252,   200] loss: 0.864
[252,   400] loss: 0.882
[252,   600] loss: 0.884
Train Accuracy: 69.272 %
[253,   200] loss: 0.859
[253,   400] loss: 0.885
[253,   600] loss: 0.878
Train Accuracy: 69.402 %
[254,   200] loss: 0.888
[254,   400] loss: 0.869
[254,   600] loss: 0.873
Train Accuracy: 69.384 %
[255,   200] loss: 0.861
[255,   400] loss: 0.878
[255,   600] loss: 0.883
Train Accuracy: 69.418 %
[256,   200] loss: 0.871
[256,   400] loss: 0.890
[256,   600] loss: 0.872
Train Accuracy: 69.382 %
[257,   200] loss: 0.867
[257,   400] loss: 0.879
[257,   600] loss: 0.877
Train Accuracy: 69.328 %
[258,   200] loss: 0.876
[258,   400] loss: 0.878
[258,   600] loss: 0.883
Train Accuracy: 69.248 %
[259,   200] loss: 0.864
[259,   400] loss: 0.874
[259,   600] loss: 0.885
Train Accuracy: 69.472 %
[260,   200] loss: 0.877


[332,   400] loss: 0.863
[332,   600] loss: 0.854
Train Accuracy: 70.094 %
[333,   200] loss: 0.856
[333,   400] loss: 0.845
[333,   600] loss: 0.868
Train Accuracy: 70.342 %
[334,   200] loss: 0.847
[334,   400] loss: 0.857
[334,   600] loss: 0.862
Train Accuracy: 70.228 %
[335,   200] loss: 0.845
[335,   400] loss: 0.860
[335,   600] loss: 0.857
Train Accuracy: 70.128 %
[336,   200] loss: 0.859
[336,   400] loss: 0.856
[336,   600] loss: 0.862
Train Accuracy: 70.216 %
[337,   200] loss: 0.843
[337,   400] loss: 0.859
[337,   600] loss: 0.857
Train Accuracy: 70.232 %
[338,   200] loss: 0.852
[338,   400] loss: 0.867
[338,   600] loss: 0.840
Train Accuracy: 70.142 %
[339,   200] loss: 0.847
[339,   400] loss: 0.860
[339,   600] loss: 0.850
Train Accuracy: 70.298 %
[340,   200] loss: 0.853
[340,   400] loss: 0.866
[340,   600] loss: 0.853
Train Accuracy: 70.252 %
[341,   200] loss: 0.842
[341,   400] loss: 0.860
[341,   600] loss: 0.845
Train Accuracy: 70.182 %
[342,   200] loss: 0.858


Train Accuracy: 70.952 %
[439,   200] loss: 0.829
[439,   400] loss: 0.850
[439,   600] loss: 0.844
Train Accuracy: 70.758 %
[440,   200] loss: 0.832
[440,   400] loss: 0.826
[440,   600] loss: 0.844
Train Accuracy: 70.838 %
[441,   200] loss: 0.821
[441,   400] loss: 0.840
[441,   600] loss: 0.849
Train Accuracy: 70.886 %
[442,   200] loss: 0.846
[442,   400] loss: 0.824
[442,   600] loss: 0.832
Train Accuracy: 70.960 %
[443,   200] loss: 0.815
[443,   400] loss: 0.845
[443,   600] loss: 0.838
Train Accuracy: 70.856 %
[444,   200] loss: 0.835
[444,   400] loss: 0.825
[444,   600] loss: 0.841
Train Accuracy: 70.794 %
[445,   200] loss: 0.832
[445,   400] loss: 0.829
[445,   600] loss: 0.840
Train Accuracy: 70.858 %
[446,   200] loss: 0.833
[446,   400] loss: 0.832
[446,   600] loss: 0.836
Train Accuracy: 70.926 %
[447,   200] loss: 0.826
[447,   400] loss: 0.841
[447,   600] loss: 0.841
Train Accuracy: 70.938 %
[448,   200] loss: 0.836
[448,   400] loss: 0.834
[448,   600] loss: 0.826
