In [1]:
# Script demonstrating pruning of a conv layer together with a subsequent fully connected layer.
# The Fully connected has no bias.

In [2]:
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
import pdb
import numpy as np
from copy import deepcopy

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
        ])
dataset1 = datasets.MNIST('../data', train=True, download=True,
                       transform=transform)
dataset2 = datasets.MNIST('../data', train=False,
                       transform=transform)

train_kwargs = {'batch_size': 64}
test_kwargs = {'batch_size': 1000}
if torch.cuda.is_available():
    cuda_kwargs = {'num_workers': 1,
                       'pin_memory': True,
                       'shuffle': True}
    train_kwargs.update(cuda_kwargs)
    test_kwargs.update(cuda_kwargs)

train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs)
test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

In [4]:
def train(log_interval, model, device, train_loader, optimizer, epoch, alpha =1):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data, alpha)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {}, alpha: {}, [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, alpha, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))



def test(model, device, test_loader, alpha=1):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data, alpha)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Alpha {}, Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(alpha,
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

In [5]:
# Original Pre-trained model (Trained using mnist.py)
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 1, 3, 1, bias = False)
        self.conv2 = nn.Conv2d(1, 2, 3, 1, bias = False)
        self.conv3 = nn.Conv2d(2, 256, 3, 1, bias = False)
        
        self.bn1 = nn.BatchNorm2d(1)
        self.bn2 = nn.BatchNorm2d(2)
        self.bn3 = nn.BatchNorm2d(256)

        self.fc1 = nn.Linear(256 * 1 * 1, 128, bias =False)
        self.fc2 = nn.Linear(128, 50, bias=False)
        self.fc3 = nn.Linear(50, 10, bias = False)
        

    def forward(self, x, alpha = 1):
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.conv2(x)
        x = self.bn2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.conv3(x)
        x = self.bn3(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x= F.relu(x)
        x = self.fc3(x)
        output = F.log_softmax(x, dim=1)
        return output
    
model = Net().to(device)

In [6]:
# Load weights
# Note that the model loaded was trained with FC layers having no bias
PATH = 'mnist_cnn_nobias.pt'
model_dict= torch.load(PATH, map_location = device)
model.load_state_dict(model_dict)

<All keys matched successfully>

In [7]:
# Test to see if the weights of the pretrained model have been correctly loaded
test(model, device, test_loader)


Test set: Alpha 1, Average loss: 0.1174, Accuracy: 9660/10000 (97%)



In [8]:
class PrunedNet(nn.Module):
    def __init__(self,pretrainedmodel):
        super(PrunedNet, self).__init__()
        self.pretrainedNet = pretrainedmodel
        
        # No need to update weights of the pretrained network        
        for param in self.pretrainedNet.parameters():
            param.requires_grad=False
        # In case the weights of the pretrained network are also desired to be updated then 
        # comment out the above 2 lines of code            
            
        # Branch parallel to conv3 in the original pre-trained network
        # Note that the since, the output of conv3 is connected to a FC layer
        # Therefore fca is also a FC layer. Otherwise if conv3 output was connected to another conv layer
        # Than fca would have been a nn.Conv2d instead of nn.Linear
        self.conva = nn.Conv2d(256, 40, 1, 1, bias = False)
        self.fca = nn.Linear(40, 256, bias = False)
        
        self.bna = nn.BatchNorm2d(40)
        


        # Branch parallel to the fc1 layer in the original pre-trained network
        self.fcc = nn.Linear(128, 64 , bias = False)
        self.fcd = nn.Linear(64, 128, bias = False)
        
    def forward(self, x, alpha=1):
        x = self.pretrainedNet.conv1(x)
        x = self.pretrainedNet.bn1(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.pretrainedNet.conv2(x)
        x = self.pretrainedNet.bn2(x)
        x = F.relu(x)
        xorg = F.max_pool2d(x, 2)
        
        
        x1 = self.pretrainedNet.conv3(xorg)

        x2 = self.fca((F.max_pool2d(F.relu(self.bna(self.conva(x1))),2)).squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(2)
        
        x = alpha * F.max_pool2d(F.relu(self.pretrainedNet.bn3(x1)),2) + (1-alpha) * x2
        
        
        xorg = torch.flatten(x, 1)
                     
        
        x1 = self.pretrainedNet.fc1(xorg)
        
        x2 = self.fcd((F.relu(self.fcc(x1))))

        x = alpha * F.relu(x1) + (1-alpha) * x2

        
        x = self.pretrainedNet.fc2(x)

        x= F.relu(x)
        
        x = self.pretrainedNet.fc3(x)
        output = F.log_softmax(x, dim=1)
        return output
pruned_model = PrunedNet(model).to(device)

In [11]:
# Check to see if the original branch(alpha = 1 ) gives the same accuracy.
# If not, then something wrong with the implementation
print('Original accuracy:')
test(model, device, test_loader)
print('Accuracy of pruned_model at alpha = 1:')
print(test(pruned_model, device, test_loader, alpha = 1))

Original accuracy:

Test set: Alpha 1, Average loss: 0.1178, Accuracy: 9658/10000 (97%)

Accuracy of pruned_model at alpha = 1:

Test set: Alpha 1, Average loss: 0.1178, Accuracy: 9658/10000 (97%)

None


In [10]:
# 2 For loops, The outer is for varying the alpha values. The inner loop are the training epochs
# The schedule of how alpha is decayed is something to explore.
# In this case, a simple, linear decay has been used.
optimizer = optim.Adadelta(filter(lambda p: p.requires_grad, pruned_model.parameters()), lr=1.0)
for alpha in np.arange(0.9,-0.05,-0.1):
    scheduler = StepLR(optimizer, step_size=1, gamma=0.7)
    for epoch in range(1, 10 + 1):
        train(10,pruned_model, device, train_loader, optimizer, epoch, alpha)
        test(pruned_model, device, test_loader, alpha)
        scheduler.step()


Test set: Alpha 0.9, Average loss: 0.1156, Accuracy: 9653/10000 (97%)




Test set: Alpha 0.9, Average loss: 0.1136, Accuracy: 9659/10000 (97%)




Test set: Alpha 0.9, Average loss: 0.1144, Accuracy: 9652/10000 (97%)


Test set: Alpha 0.9, Average loss: 0.1142, Accuracy: 9657/10000 (97%)




Test set: Alpha 0.9, Average loss: 0.1136, Accuracy: 9662/10000 (97%)




Test set: Alpha 0.9, Average loss: 0.1145, Accuracy: 9658/10000 (97%)




Test set: Alpha 0.9, Average loss: 0.1141, Accuracy: 9650/10000 (96%)


Test set: Alpha 0.9, Average loss: 0.1139, Accuracy: 9657/10000 (97%)




Test set: Alpha 0.9, Average loss: 0.1146, Accuracy: 9651/10000 (97%)




Test set: Alpha 0.9, Average loss: 0.1144, Accuracy: 9647/10000 (96%)


Test set: Alpha 0.8, Average loss: 0.1152, Accuracy: 9641/10000 (96%)




Test set: Alpha 0.8, Average loss: 0.1150, Accuracy: 9647/10000 (96%)




Test set: Alpha 0.8, Average loss: 0.1174, Accuracy: 9641/10000 (96%)




Test set: Alpha 0.8, Average loss: 0.1131, Accuracy: 9653/10000 (97%)


Test set: Alpha 0.8, Average loss: 0.1136, Accuracy: 9660/10000 (97%)




Test set: Alpha 0.8, Average loss: 0.1123, Accuracy: 9669/10000 (97%)




Test set: Alpha 0.8, Average loss: 0.1126, Accuracy: 9662/10000 (97%)




Test set: Alpha 0.8, Average loss: 0.1133, Accuracy: 9663/10000 (97%)


Test set: Alpha 0.8, Average loss: 0.1135, Accuracy: 9657/10000 (97%)




Test set: Alpha 0.8, Average loss: 0.1122, Accuracy: 9667/10000 (97%)




Test set: Alpha 0.7000000000000001, Average loss: 0.1230, Accuracy: 9626/10000 (96%)




Test set: Alpha 0.7000000000000001, Average loss: 0.1193, Accuracy: 9634/10000 (96%)




Test set: Alpha 0.7000000000000001, Average loss: 0.1167, Accuracy: 9652/10000 (97%)




Test set: Alpha 0.7000000000000001, Average loss: 0.1182, Accuracy: 9642/10000 (96%)


Test set: Alpha 0.7000000000000001, Average loss: 0.1155, Accuracy: 9651/10000 (97%)




Test set: Alpha 0.7000000000000001, Average loss: 0.1149, Accuracy: 9662/10000 (97%)




Test set: Alpha 0.7000000000000001, Average loss: 0.1160, Accuracy: 9662/10000 (97%)




Test set: Alpha 0.7000000000000001, Average loss: 0.1150, Accuracy: 9662/10000 (97%)




Test set: Alpha 0.7000000000000001, Average loss: 0.1146, Accuracy: 9667/10000 (97%)




Test set: Alpha 0.7000000000000001, Average loss: 0.1155, Accuracy: 9654/10000 (97%)




Test set: Alpha 0.6000000000000001, Average loss: 0.1283, Accuracy: 9619/10000 (96%)




Test set: Alpha 0.6000000000000001, Average loss: 0.1274, Accuracy: 9627/10000 (96%)




Test set: Alpha 0.6000000000000001, Average loss: 0.1295, Accuracy: 9606/10000 (96%)


Test set: Alpha 0.6000000000000001, Average loss: 0.1198, Accuracy: 9661/10000 (97%)




Test set: Alpha 0.6000000000000001, Average loss: 0.1208, Accuracy: 9666/10000 (97%)




Test set: Alpha 0.6000000000000001, Average loss: 0.1184, Accuracy: 9667/10000 (97%)




Test set: Alpha 0.6000000000000001, Average loss: 0.1183, Accuracy: 9672/10000 (97%)




Test set: Alpha 0.6000000000000001, Average loss: 0.1180, Accuracy: 9679/10000 (97%)




Test set: Alpha 0.6000000000000001, Average loss: 0.1179, Accuracy: 9669/10000 (97%)




Test set: Alpha 0.6000000000000001, Average loss: 0.1194, Accuracy: 9670/10000 (97%)




Test set: Alpha 0.5000000000000001, Average loss: 0.1506, Accuracy: 9589/10000 (96%)




Test set: Alpha 0.5000000000000001, Average loss: 0.1403, Accuracy: 9609/10000 (96%)




Test set: Alpha 0.5000000000000001, Average loss: 0.1300, Accuracy: 9630/10000 (96%)


Test set: Alpha 0.5000000000000001, Average loss: 0.1246, Accuracy: 9647/10000 (96%)




Test set: Alpha 0.5000000000000001, Average loss: 0.1219, Accuracy: 9663/10000 (97%)




Test set: Alpha 0.5000000000000001, Average loss: 0.1218, Accuracy: 9661/10000 (97%)




Test set: Alpha 0.5000000000000001, Average loss: 0.1218, Accuracy: 9668/10000 (97%)




Test set: Alpha 0.5000000000000001, Average loss: 0.1218, Accuracy: 9661/10000 (97%)




Test set: Alpha 0.5000000000000001, Average loss: 0.1218, Accuracy: 9671/10000 (97%)




Test set: Alpha 0.5000000000000001, Average loss: 0.1227, Accuracy: 9669/10000 (97%)




Test set: Alpha 0.40000000000000013, Average loss: 0.1429, Accuracy: 9608/10000 (96%)




Test set: Alpha 0.40000000000000013, Average loss: 0.1385, Accuracy: 9600/10000 (96%)




Test set: Alpha 0.40000000000000013, Average loss: 0.1288, Accuracy: 9647/10000 (96%)


Test set: Alpha 0.40000000000000013, Average loss: 0.1348, Accuracy: 9626/10000 (96%)




Test set: Alpha 0.40000000000000013, Average loss: 0.1286, Accuracy: 9658/10000 (97%)




Test set: Alpha 0.40000000000000013, Average loss: 0.1297, Accuracy: 9643/10000 (96%)




Test set: Alpha 0.40000000000000013, Average loss: 0.1257, Accuracy: 9659/10000 (97%)




Test set: Alpha 0.40000000000000013, Average loss: 0.1256, Accuracy: 9660/10000 (97%)




Test set: Alpha 0.40000000000000013, Average loss: 0.1301, Accuracy: 9653/10000 (97%)




Test set: Alpha 0.40000000000000013, Average loss: 0.1278, Accuracy: 9659/10000 (97%)




Test set: Alpha 0.30000000000000016, Average loss: 0.1415, Accuracy: 9612/10000 (96%)




Test set: Alpha 0.30000000000000016, Average loss: 0.1285, Accuracy: 9655/10000 (97%)




Test set: Alpha 0.30000000000000016, Average loss: 0.1282, Accuracy: 9624/10000 (96%)


Test set: Alpha 0.30000000000000016, Average loss: 0.1331, Accuracy: 9643/10000 (96%)




Test set: Alpha 0.30000000000000016, Average loss: 0.1302, Accuracy: 9643/10000 (96%)




Test set: Alpha 0.30000000000000016, Average loss: 0.1278, Accuracy: 9653/10000 (97%)




Test set: Alpha 0.30000000000000016, Average loss: 0.1295, Accuracy: 9663/10000 (97%)




Test set: Alpha 0.30000000000000016, Average loss: 0.1298, Accuracy: 9653/10000 (97%)




Test set: Alpha 0.30000000000000016, Average loss: 0.1309, Accuracy: 9657/10000 (97%)




Test set: Alpha 0.30000000000000016, Average loss: 0.1307, Accuracy: 9663/10000 (97%)




Test set: Alpha 0.20000000000000018, Average loss: 0.1370, Accuracy: 9615/10000 (96%)




Test set: Alpha 0.20000000000000018, Average loss: 0.1394, Accuracy: 9618/10000 (96%)




Test set: Alpha 0.20000000000000018, Average loss: 0.1349, Accuracy: 9635/10000 (96%)




Test set: Alpha 0.20000000000000018, Average loss: 0.1310, Accuracy: 9647/10000 (96%)


Test set: Alpha 0.20000000000000018, Average loss: 0.1320, Accuracy: 9642/10000 (96%)




Test set: Alpha 0.20000000000000018, Average loss: 0.1292, Accuracy: 9652/10000 (97%)




Test set: Alpha 0.20000000000000018, Average loss: 0.1296, Accuracy: 9655/10000 (97%)




Test set: Alpha 0.20000000000000018, Average loss: 0.1307, Accuracy: 9651/10000 (97%)




Test set: Alpha 0.20000000000000018, Average loss: 0.1303, Accuracy: 9659/10000 (97%)




Test set: Alpha 0.20000000000000018, Average loss: 0.1302, Accuracy: 9654/10000 (97%)




Test set: Alpha 0.1000000000000002, Average loss: 0.1277, Accuracy: 9612/10000 (96%)




Test set: Alpha 0.1000000000000002, Average loss: 0.1246, Accuracy: 9643/10000 (96%)




Test set: Alpha 0.1000000000000002, Average loss: 0.1302, Accuracy: 9636/10000 (96%)




Test set: Alpha 0.1000000000000002, Average loss: 0.1272, Accuracy: 9638/10000 (96%)


Test set: Alpha 0.1000000000000002, Average loss: 0.1277, Accuracy: 9645/10000 (96%)




Test set: Alpha 0.1000000000000002, Average loss: 0.1264, Accuracy: 9642/10000 (96%)




Test set: Alpha 0.1000000000000002, Average loss: 0.1287, Accuracy: 9645/10000 (96%)




Test set: Alpha 0.1000000000000002, Average loss: 0.1277, Accuracy: 9641/10000 (96%)




Test set: Alpha 0.1000000000000002, Average loss: 0.1265, Accuracy: 9644/10000 (96%)




Test set: Alpha 0.1000000000000002, Average loss: 0.1273, Accuracy: 9641/10000 (96%)




Test set: Alpha 2.220446049250313e-16, Average loss: 0.3079, Accuracy: 9100/10000 (91%)




Test set: Alpha 2.220446049250313e-16, Average loss: 0.2451, Accuracy: 9257/10000 (93%)




Test set: Alpha 2.220446049250313e-16, Average loss: 0.2087, Accuracy: 9389/10000 (94%)




Test set: Alpha 2.220446049250313e-16, Average loss: 0.2001, Accuracy: 9414/10000 (94%)




Test set: Alpha 2.220446049250313e-16, Average loss: 0.1922, Accuracy: 9442/10000 (94%)




Test set: Alpha 2.220446049250313e-16, Average loss: 0.1881, Accuracy: 9472/10000 (95%)


Test set: Alpha 2.220446049250313e-16, Average loss: 0.1799, Accuracy: 9465/10000 (95%)




Test set: Alpha 2.220446049250313e-16, Average loss: 0.1797, Accuracy: 9465/10000 (95%)




Test set: Alpha 2.220446049250313e-16, Average loss: 0.1782, Accuracy: 9475/10000 (95%)




Test set: Alpha 2.220446049250313e-16, Average loss: 0.1769, Accuracy: 9482/10000 (95%)



In [12]:
print(test(pruned_model, device, test_loader, 0 ))
print(test(pruned_model, device, test_loader, 1 ))


Test set: Alpha 0, Average loss: 0.1769, Accuracy: 9482/10000 (95%)

None

Test set: Alpha 1, Average loss: 0.1178, Accuracy: 9658/10000 (97%)

None


In [15]:
# Final pruned model after doing the appropriate matrix multiplications (for FC layers)
# or tensor product for (convolution layers)
# for tensor products see: https://pytorch.org/docs/stable/generated/torch.einsum.html

# After pruning the conv3 has only 40 filters instead of the original 256
# Also the FC1 output has now only 64 neurons instead of 128

# Note that at the point where 2 consecutive layers are being pruned, 3 matrix multiplications need to be done.
# This can also be thought as 2 matrix multiplications done 2 times.
# TODO: Implement bias when there are consecutive layers being pruned
class SmallerNet(nn.Module):
    def __init__(self):
        super(SmallerNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 1, 3, 1, bias = False)
        self.conv2 = nn.Conv2d(1, 2, 3, 1, bias = False)
        self.conv3 = nn.Conv2d(2, 40, 3, 1, bias = False)
        
        self.bn1 = nn.BatchNorm2d(1)
        self.bn2 = nn.BatchNorm2d(2)
        self.bn3 = nn.BatchNorm2d(40)

        self.fc1 = nn.Linear(40 * 1 * 1, 64, bias =False)
        self.fc2 = nn.Linear(64, 50, bias=False)
        self.fc3 = nn.Linear(50, 10, bias = False)
        

    def forward(self, x, alpha = 1):
        # Note that alpha is not being used here
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.conv2(x)
        x = self.bn2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.conv3(x)
        x = self.bn3(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x= F.relu(x)
        x = self.fc3(x)
        output = F.log_softmax(x, dim=1)
        return output
    
smaller_model = SmallerNet().to(device)

In [16]:
# Since the model is randomly initilaized, the accuracy is low
test(smaller_model, device, test_loader, 0 )


Test set: Alpha 0, Average loss: 2.3036, Accuracy: 967/10000 (10%)



In [19]:
# The pretrained weights are loaded after the appropriate matrix/tensor multiplications to get the new weight matrix 

smaller_model_dict= deepcopy(model.state_dict())

pmd=pruned_model.state_dict()
print(pmd["conva.weight"].shape)
# print(pmd["convb.weight"].shape)
print(pmd['pretrainedNet.conv3.weight'].shape)
x = pmd["conva.weight"]
y = pmd['pretrainedNet.conv3.weight']
# 4 dimensional Tensor multplications in the next steps
# 2 alternate ways to do this. Both should yield the same result
out1 = torch.einsum("abcd, befg -> aefg",x,y)
out2 = torch.einsum("abcd, eafg -> ebcd",y,x)

smaller_model_dict["conv3.weight"] = out1
smaller_model_dict["bn3.weight"] = pmd["bna.weight"]
smaller_model_dict["bn3.bias"] = pmd["bna.bias"]
smaller_model_dict["bn3.running_mean"] = pmd["bna.running_mean"]
smaller_model_dict["bn3.running_var"]  = pmd["bna.running_var"]
smaller_model_dict["bn3.num_batches_tracked"] = pmd["bna.num_batches_tracked"]
smaller_model_dict["bn3.weight"] = pmd["bna.weight"]
# Note the 3 matrix multiplications in the next step
smaller_model_dict["fc1.weight"] = pmd["fcc.weight"] @ pmd["pretrainedNet.fc1.weight"] @ pmd["fca.weight"]
smaller_model_dict["fc2.weight"] = pmd["pretrainedNet.fc2.weight"] @ pmd["fcd.weight"]
smaller_model.load_state_dict(smaller_model_dict)

torch.Size([40, 256, 1, 1])
torch.Size([256, 2, 3, 3])


<All keys matched successfully>

In [20]:
test(smaller_model, device, test_loader, 0 )


Test set: Alpha 0, Average loss: 0.1769, Accuracy: 9482/10000 (95%)

