In [1]:
# Script demonstrating pruning of a Fully Connected layer with no bias.

In [2]:
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
import pdb
import numpy as np
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [17]:
transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
        ])
dataset1 = datasets.MNIST('../data', train=True, download=True,
                       transform=transform)
dataset2 = datasets.MNIST('../data', train=False,
                       transform=transform)

train_kwargs = {'batch_size': 64}
test_kwargs = {'batch_size': 1000}
if torch.cuda.is_available():
    cuda_kwargs = {'num_workers': 1,
                       'pin_memory': True,
                       'shuffle': True}
    train_kwargs.update(cuda_kwargs)
    test_kwargs.update(cuda_kwargs)

train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs)
test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

In [4]:
def train(log_interval, model, device, train_loader, optimizer, epoch, alpha =1):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data, alpha)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {}, alpha: {}, [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, alpha, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))



def test(model, device, test_loader, alpha=1):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data, alpha)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Alpha {}, Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(alpha,
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

32768it [00:02, 12643.13it/s]            
1654784it [00:02, 748890.60it/s]                            
8192it [00:01, 7217.05it/s]             


In [5]:
# Original Pre-trained model (Trained using mnist.py)
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 1, 3, 1, bias = False)
        self.conv2 = nn.Conv2d(1, 2, 3, 1, bias = False)
        self.conv3 = nn.Conv2d(2, 256, 3, 1, bias = False)
        
        self.bn1 = nn.BatchNorm2d(1)
        self.bn2 = nn.BatchNorm2d(2)
        self.bn3 = nn.BatchNorm2d(256)
        #self.dropout1 = nn.Dropout(0.25)
        #self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(256 * 1 * 1, 128, bias =False)
        self.fc2 = nn.Linear(128, 50, bias=False)
        self.fc3 = nn.Linear(50, 10, bias = False)
        

    def forward(self, x, alpha = 1):
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.conv2(x)
        x = self.bn2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.conv3(x)
        x = self.bn3(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        
        #x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        #x = self.dropout2(x)
        x = self.fc2(x)
        x= F.relu(x)
        x = self.fc3(x)
        output = F.log_softmax(x, dim=1)
        return output
    
model = Net().to(device)

In [6]:
# Load weights
PATH = 'mnist_cnn_nobias.pt'
model_dict= torch.load(PATH, map_location = device)
model.load_state_dict(model_dict)

<All keys matched successfully>

In [7]:
# Test to see if the weights of the pretrained model have been correctly loaded
test(model, device, test_loader)

9920512it [00:11, 864693.11it/s]                              



Test set: Alpha 1, Average loss: 0.1174, Accuracy: 9660/10000 (97%)



In [8]:
# A pretrinaed Network with a branch parrallel to FC1 for the purpose of redcucing from 128 to 64 neurons
# See also image 'highlevel.png'
class PrunedNet(nn.Module):
    def __init__(self,pretrainedmodel):
        super(PrunedNet, self).__init__()
        self.pretrainedNet = pretrainedmodel
        
        # No need to update weights of the pretrained network        
        for param in self.pretrainedNet.parameters():
            param.requires_grad=False
            
        # In case the weights of the pretrained network are also desired to be updated then 
        # comment out the above 2 lines of code
        
        # Branch parallel to the fc1 layer in the original pre-trained network
        self.fcc = nn.Linear(128, 64 , bias = False)
        self.fcd = nn.Linear(64, 128, bias = False)
        
    def forward(self, x, alpha=0.1):
        x = self.pretrainedNet.conv1(x)
        x = self.pretrainedNet.bn1(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.pretrainedNet.conv2(x)
        x = self.pretrainedNet.bn2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.pretrainedNet.conv3(x)
        x = self.pretrainedNet.bn3(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        
        xorg = torch.flatten(x, 1)
        x1 = self.pretrainedNet.fc1(xorg)
        
        x2 = self.fcd((F.relu(self.fcc(x1))))

        x = alpha * F.relu(x1) + (1-alpha) * x2

        
        x = self.pretrainedNet.fc2(x)

        x= F.relu(x)
        
        x = self.pretrainedNet.fc3(x)
        output = F.log_softmax(x, dim=1)
        return output
pruned_model = PrunedNet(model).to(device)

In [18]:
# Check to see if the original branch(alpha = 1 ) gives the same accuracy.
# If not, then something wrong with the implementation
print('Original accuracy:')
test(model, device, test_loader)
print('Accuracy of pruned_model at alpha = 1:')
print(test(pruned_model, device, test_loader, alpha = 1))

Original accuracy:


9920512it [47:24, 3487.12it/s]   



Test set: Alpha 1, Average loss: 0.1178, Accuracy: 9660/10000 (97%)

Accuracy of pruned_model at alpha = 1:


9920512it [47:26, 3485.51it/s]   



Test set: Alpha 1, Average loss: 0.1178, Accuracy: 9660/10000 (97%)

None


In [10]:
# 2 For loops, The outer is for varying the alpha values. The inner loop are the training epochs
# The schedule of how alpha is decayed is something to explore.
# In this case, a simple, linear decay has been used.
optimizer = optim.Adadelta(filter(lambda p: p.requires_grad, pruned_model.parameters()), lr=1.0)
for alpha in np.arange(0.9,-0.05,-0.1):
    scheduler = StepLR(optimizer, step_size=1, gamma=0.7)
    for epoch in range(1, 10 + 1):
        train(10,pruned_model, device, train_loader, optimizer, epoch, alpha)
        test(pruned_model, device, test_loader, alpha)
        scheduler.step()



9920512it [01:46, 92760.98it/s]  



Test set: Alpha 0.9, Average loss: 0.1149, Accuracy: 9656/10000 (97%)



9920512it [01:57, 84563.79it/s]  



Test set: Alpha 0.9, Average loss: 0.1150, Accuracy: 9654/10000 (97%)



9920512it [02:07, 77600.14it/s]  



Test set: Alpha 0.9, Average loss: 0.1149, Accuracy: 9656/10000 (97%)



9920512it [02:18, 71831.35it/s]  



Test set: Alpha 0.9, Average loss: 0.1154, Accuracy: 9649/10000 (96%)



9920512it [02:28, 66793.26it/s]  



Test set: Alpha 0.9, Average loss: 0.1147, Accuracy: 9660/10000 (97%)



9920512it [02:38, 62432.13it/s]  



Test set: Alpha 0.9, Average loss: 0.1148, Accuracy: 9660/10000 (97%)



9920512it [02:49, 58525.41it/s]  



Test set: Alpha 0.9, Average loss: 0.1151, Accuracy: 9663/10000 (97%)



9920512it [03:00, 55058.04it/s]  



Test set: Alpha 0.9, Average loss: 0.1145, Accuracy: 9662/10000 (97%)



9920512it [03:10, 52011.71it/s]  



Test set: Alpha 0.9, Average loss: 0.1146, Accuracy: 9661/10000 (97%)



9920512it [03:21, 49343.49it/s]  



Test set: Alpha 0.9, Average loss: 0.1151, Accuracy: 9656/10000 (97%)



9920512it [03:31, 46871.49it/s]  



Test set: Alpha 0.8, Average loss: 0.1131, Accuracy: 9667/10000 (97%)



9920512it [03:42, 44615.78it/s]  



Test set: Alpha 0.8, Average loss: 0.1145, Accuracy: 9645/10000 (96%)



9920512it [03:53, 42560.50it/s]  



Test set: Alpha 0.8, Average loss: 0.1121, Accuracy: 9656/10000 (97%)



9920512it [04:04, 40577.28it/s]  



Test set: Alpha 0.8, Average loss: 0.1131, Accuracy: 9663/10000 (97%)



9920512it [04:16, 38655.89it/s]  



Test set: Alpha 0.8, Average loss: 0.1130, Accuracy: 9651/10000 (97%)



9920512it [04:28, 36881.58it/s]  



Test set: Alpha 0.8, Average loss: 0.1139, Accuracy: 9648/10000 (96%)



9920512it [04:41, 35251.89it/s]  



Test set: Alpha 0.8, Average loss: 0.1135, Accuracy: 9659/10000 (97%)



9920512it [04:54, 33721.32it/s]  



Test set: Alpha 0.8, Average loss: 0.1137, Accuracy: 9660/10000 (97%)



9920512it [05:06, 32352.90it/s]  



Test set: Alpha 0.8, Average loss: 0.1137, Accuracy: 9656/10000 (97%)



9920512it [05:18, 31118.73it/s]  



Test set: Alpha 0.8, Average loss: 0.1135, Accuracy: 9661/10000 (97%)



9920512it [05:31, 29943.14it/s]  



Test set: Alpha 0.7000000000000001, Average loss: 0.1142, Accuracy: 9658/10000 (97%)



9920512it [05:43, 28865.52it/s]  



Test set: Alpha 0.7000000000000001, Average loss: 0.1160, Accuracy: 9636/10000 (96%)



9920512it [05:56, 27857.22it/s]  



Test set: Alpha 0.7000000000000001, Average loss: 0.1186, Accuracy: 9640/10000 (96%)



9920512it [06:08, 26928.93it/s]  



Test set: Alpha 0.7000000000000001, Average loss: 0.1160, Accuracy: 9662/10000 (97%)



9920512it [06:20, 26055.50it/s]  



Test set: Alpha 0.7000000000000001, Average loss: 0.1146, Accuracy: 9650/10000 (96%)



9920512it [06:33, 25220.41it/s]  



Test set: Alpha 0.7000000000000001, Average loss: 0.1158, Accuracy: 9646/10000 (96%)



9920512it [06:45, 24442.28it/s]  



Test set: Alpha 0.7000000000000001, Average loss: 0.1150, Accuracy: 9658/10000 (97%)



9920512it [06:58, 23710.50it/s]  



Test set: Alpha 0.7000000000000001, Average loss: 0.1149, Accuracy: 9657/10000 (97%)



9920512it [07:10, 23039.15it/s]  



Test set: Alpha 0.7000000000000001, Average loss: 0.1155, Accuracy: 9651/10000 (97%)



9920512it [07:23, 22389.13it/s]  



Test set: Alpha 0.7000000000000001, Average loss: 0.1158, Accuracy: 9658/10000 (97%)



9920512it [07:35, 21771.82it/s]  



Test set: Alpha 0.6000000000000001, Average loss: 0.1267, Accuracy: 9637/10000 (96%)



9920512it [07:48, 21184.39it/s]  



Test set: Alpha 0.6000000000000001, Average loss: 0.1240, Accuracy: 9631/10000 (96%)



9920512it [08:00, 20645.23it/s]  



Test set: Alpha 0.6000000000000001, Average loss: 0.1222, Accuracy: 9655/10000 (97%)



9920512it [08:12, 20133.14it/s]  



Test set: Alpha 0.6000000000000001, Average loss: 0.1221, Accuracy: 9647/10000 (96%)



9920512it [08:25, 19643.51it/s]  



Test set: Alpha 0.6000000000000001, Average loss: 0.1202, Accuracy: 9652/10000 (97%)



9920512it [08:37, 19173.83it/s]  



Test set: Alpha 0.6000000000000001, Average loss: 0.1197, Accuracy: 9655/10000 (97%)



9920512it [08:49, 18722.10it/s]  



Test set: Alpha 0.6000000000000001, Average loss: 0.1193, Accuracy: 9651/10000 (97%)



9920512it [09:02, 18299.64it/s]  



Test set: Alpha 0.6000000000000001, Average loss: 0.1206, Accuracy: 9658/10000 (97%)



9920512it [09:14, 17893.15it/s]  



Test set: Alpha 0.6000000000000001, Average loss: 0.1216, Accuracy: 9654/10000 (97%)



9920512it [09:26, 17502.34it/s]  



Test set: Alpha 0.6000000000000001, Average loss: 0.1207, Accuracy: 9660/10000 (97%)



9920512it [09:39, 17122.32it/s]  



Test set: Alpha 0.5000000000000001, Average loss: 0.1408, Accuracy: 9605/10000 (96%)



9920512it [09:52, 16753.86it/s]  



Test set: Alpha 0.5000000000000001, Average loss: 0.1287, Accuracy: 9638/10000 (96%)



9920512it [10:04, 16408.09it/s]  



Test set: Alpha 0.5000000000000001, Average loss: 0.1292, Accuracy: 9640/10000 (96%)



9920512it [10:17, 16077.19it/s]  



Test set: Alpha 0.5000000000000001, Average loss: 0.1290, Accuracy: 9640/10000 (96%)



9920512it [10:29, 15763.60it/s]  



Test set: Alpha 0.5000000000000001, Average loss: 0.1309, Accuracy: 9632/10000 (96%)



9920512it [10:41, 15461.92it/s]  



Test set: Alpha 0.5000000000000001, Average loss: 0.1259, Accuracy: 9640/10000 (96%)



9920512it [10:54, 15161.43it/s]  



Test set: Alpha 0.5000000000000001, Average loss: 0.1259, Accuracy: 9654/10000 (97%)



9920512it [11:07, 14872.26it/s]  



Test set: Alpha 0.5000000000000001, Average loss: 0.1253, Accuracy: 9650/10000 (96%)



9920512it [11:19, 14595.05it/s]  



Test set: Alpha 0.5000000000000001, Average loss: 0.1259, Accuracy: 9651/10000 (97%)



9920512it [11:32, 14330.31it/s]  



Test set: Alpha 0.5000000000000001, Average loss: 0.1265, Accuracy: 9648/10000 (96%)



9920512it [11:44, 14074.20it/s]  



Test set: Alpha 0.40000000000000013, Average loss: 0.1480, Accuracy: 9590/10000 (96%)



9920512it [11:57, 13827.91it/s]  



Test set: Alpha 0.40000000000000013, Average loss: 0.1364, Accuracy: 9619/10000 (96%)



9920512it [12:09, 13590.58it/s]  



Test set: Alpha 0.40000000000000013, Average loss: 0.1386, Accuracy: 9631/10000 (96%)



9920512it [12:22, 13359.62it/s]  



Test set: Alpha 0.40000000000000013, Average loss: 0.1328, Accuracy: 9654/10000 (97%)



9920512it [12:35, 13136.02it/s]  



Test set: Alpha 0.40000000000000013, Average loss: 0.1364, Accuracy: 9633/10000 (96%)



9920512it [12:47, 12926.95it/s]  



Test set: Alpha 0.40000000000000013, Average loss: 0.1325, Accuracy: 9648/10000 (96%)



9920512it [12:59, 12721.03it/s]  



Test set: Alpha 0.40000000000000013, Average loss: 0.1329, Accuracy: 9655/10000 (97%)



9920512it [13:12, 12517.50it/s]  



Test set: Alpha 0.40000000000000013, Average loss: 0.1342, Accuracy: 9638/10000 (96%)



9920512it [13:24, 12325.20it/s]  



Test set: Alpha 0.40000000000000013, Average loss: 0.1328, Accuracy: 9646/10000 (96%)



9920512it [13:37, 12138.70it/s]  



Test set: Alpha 0.40000000000000013, Average loss: 0.1328, Accuracy: 9645/10000 (96%)



9920512it [13:49, 11958.06it/s]  



Test set: Alpha 0.30000000000000016, Average loss: 0.1663, Accuracy: 9578/10000 (96%)



9920512it [14:02, 11777.41it/s]  



Test set: Alpha 0.30000000000000016, Average loss: 0.1566, Accuracy: 9598/10000 (96%)



9920512it [14:14, 11608.71it/s]  



Test set: Alpha 0.30000000000000016, Average loss: 0.1457, Accuracy: 9627/10000 (96%)



9920512it [14:26, 11444.58it/s]  



Test set: Alpha 0.30000000000000016, Average loss: 0.1439, Accuracy: 9626/10000 (96%)



9920512it [14:39, 11285.31it/s]  



Test set: Alpha 0.30000000000000016, Average loss: 0.1414, Accuracy: 9629/10000 (96%)



9920512it [14:51, 11122.46it/s]  



Test set: Alpha 0.30000000000000016, Average loss: 0.1414, Accuracy: 9634/10000 (96%)



9920512it [15:04, 10967.22it/s]  



Test set: Alpha 0.30000000000000016, Average loss: 0.1418, Accuracy: 9638/10000 (96%)



9920512it [15:17, 10814.66it/s]  



Test set: Alpha 0.30000000000000016, Average loss: 0.1420, Accuracy: 9640/10000 (96%)



9920512it [15:29, 10671.08it/s]  



Test set: Alpha 0.30000000000000016, Average loss: 0.1421, Accuracy: 9642/10000 (96%)



9920512it [15:42, 10528.56it/s]  



Test set: Alpha 0.30000000000000016, Average loss: 0.1422, Accuracy: 9632/10000 (96%)



9920512it [15:54, 10393.29it/s]  



Test set: Alpha 0.20000000000000018, Average loss: 0.1765, Accuracy: 9569/10000 (96%)



9920512it [16:07, 10256.44it/s]  



Test set: Alpha 0.20000000000000018, Average loss: 0.1612, Accuracy: 9586/10000 (96%)



9920512it [16:19, 10125.22it/s]  



Test set: Alpha 0.20000000000000018, Average loss: 0.1553, Accuracy: 9611/10000 (96%)



9920512it [16:32, 9996.91it/s]   



Test set: Alpha 0.20000000000000018, Average loss: 0.1535, Accuracy: 9621/10000 (96%)



9920512it [16:44, 9873.27it/s]   



Test set: Alpha 0.20000000000000018, Average loss: 0.1507, Accuracy: 9630/10000 (96%)



9920512it [16:57, 9753.32it/s]   



Test set: Alpha 0.20000000000000018, Average loss: 0.1498, Accuracy: 9621/10000 (96%)



9920512it [17:09, 9635.26it/s]   



Test set: Alpha 0.20000000000000018, Average loss: 0.1504, Accuracy: 9630/10000 (96%)



9920512it [17:22, 9519.49it/s]   



Test set: Alpha 0.20000000000000018, Average loss: 0.1483, Accuracy: 9635/10000 (96%)



9920512it [17:34, 9406.21it/s]   



Test set: Alpha 0.20000000000000018, Average loss: 0.1487, Accuracy: 9633/10000 (96%)



9920512it [17:46, 9299.08it/s]   



Test set: Alpha 0.20000000000000018, Average loss: 0.1500, Accuracy: 9633/10000 (96%)



9920512it [17:59, 9193.31it/s]   



Test set: Alpha 0.1000000000000002, Average loss: 0.1851, Accuracy: 9566/10000 (96%)



9920512it [18:11, 9090.00it/s]   



Test set: Alpha 0.1000000000000002, Average loss: 0.1751, Accuracy: 9580/10000 (96%)



9920512it [18:24, 8985.12it/s]   



Test set: Alpha 0.1000000000000002, Average loss: 0.1668, Accuracy: 9602/10000 (96%)



9920512it [18:36, 8884.52it/s]   



Test set: Alpha 0.1000000000000002, Average loss: 0.1560, Accuracy: 9623/10000 (96%)



9920512it [18:49, 8785.01it/s]   



Test set: Alpha 0.1000000000000002, Average loss: 0.1583, Accuracy: 9618/10000 (96%)



9920512it [19:01, 8691.69it/s]   



Test set: Alpha 0.1000000000000002, Average loss: 0.1586, Accuracy: 9626/10000 (96%)



9920512it [19:13, 8597.89it/s]   



Test set: Alpha 0.1000000000000002, Average loss: 0.1587, Accuracy: 9630/10000 (96%)



9920512it [19:26, 8506.30it/s]   



Test set: Alpha 0.1000000000000002, Average loss: 0.1567, Accuracy: 9638/10000 (96%)



9920512it [19:38, 8416.38it/s]   



Test set: Alpha 0.1000000000000002, Average loss: 0.1565, Accuracy: 9633/10000 (96%)



9920512it [19:51, 8327.94it/s]   



Test set: Alpha 0.1000000000000002, Average loss: 0.1579, Accuracy: 9635/10000 (96%)



9920512it [20:03, 8240.05it/s]   



Test set: Alpha 2.220446049250313e-16, Average loss: 0.1956, Accuracy: 9551/10000 (96%)



9920512it [20:16, 8156.38it/s]   



Test set: Alpha 2.220446049250313e-16, Average loss: 0.1677, Accuracy: 9608/10000 (96%)



9920512it [20:28, 8074.73it/s]   



Test set: Alpha 2.220446049250313e-16, Average loss: 0.1694, Accuracy: 9612/10000 (96%)



9920512it [20:40, 7996.90it/s]   



Test set: Alpha 2.220446049250313e-16, Average loss: 0.1713, Accuracy: 9612/10000 (96%)



9920512it [20:53, 7917.39it/s]   



Test set: Alpha 2.220446049250313e-16, Average loss: 0.1643, Accuracy: 9613/10000 (96%)



9920512it [21:05, 7839.13it/s]   



Test set: Alpha 2.220446049250313e-16, Average loss: 0.1617, Accuracy: 9633/10000 (96%)



9920512it [21:17, 7767.83it/s]   



Test set: Alpha 2.220446049250313e-16, Average loss: 0.1629, Accuracy: 9626/10000 (96%)



9920512it [21:27, 7704.09it/s]   



Test set: Alpha 2.220446049250313e-16, Average loss: 0.1636, Accuracy: 9627/10000 (96%)



9920512it [21:38, 7641.87it/s]   



Test set: Alpha 2.220446049250313e-16, Average loss: 0.1653, Accuracy: 9625/10000 (96%)



9920512it [21:48, 7580.42it/s]   



Test set: Alpha 2.220446049250313e-16, Average loss: 0.1640, Accuracy: 9619/10000 (96%)



In [16]:
# Test to see the performance of the parallel branch by setting alpha = 0
# If accuracy is still high(approx. equal to when alpha =1),
# then it hints that the parrallel branch by itself is enough.
# Hence, the original pretrained model can then be pruned to a smaller size
test(pruned_model, device, test_loader, 0 )
test(pruned_model, device, test_loader, 1 )

9920512it [34:18, 4819.88it/s]   



Test set: Alpha 0, Average loss: 0.1640, Accuracy: 9619/10000 (96%)



9920512it [34:19, 4816.72it/s]   



Test set: Alpha 1, Average loss: 0.1178, Accuracy: 9660/10000 (97%)



In [12]:
# Final pruned model after doing the appropriate matrix(or tensor) multiplications
# Note that FC1 output has now only 64 neurons instead of 128
# Likewise, the input of FC2 has 64 instead of 128 neurons
class SmallerNet(nn.Module):
    def __init__(self):
        super(SmallerNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 1, 3, 1, bias = False)
        self.conv2 = nn.Conv2d(1, 2, 3, 1, bias = False)
        self.conv3 = nn.Conv2d(2, 256, 3, 1, bias = False)
        
        self.bn1 = nn.BatchNorm2d(1)
        self.bn2 = nn.BatchNorm2d(2)
        self.bn3 = nn.BatchNorm2d(256)

        self.fc1 = nn.Linear(256 * 1 * 1, 64, bias =False)
        self.fc2 = nn.Linear(64, 50, bias=False)
        self.fc3 = nn.Linear(50, 10, bias = False)
        

    def forward(self, x, alpha = 1):
        # Note that alpha is not being used here
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.conv2(x)
        x = self.bn2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.conv3(x)
        x = self.bn3(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x= F.relu(x)
        x = self.fc3(x)
        output = F.log_softmax(x, dim=1)
        return output
    
smaller_model = SmallerNet().to(device)

In [13]:
# Since the model is randomly initilaized, the accuracy is low
test(smaller_model, device, test_loader, 0 )

9920512it [23:50, 6936.85it/s]   



Test set: Alpha 0, Average loss: 2.3021, Accuracy: 928/10000 (9%)



In [14]:
# The pretrained weights are loaded after the appropriate matrix multiplications
from copy import deepcopy
smaller_model_dict= deepcopy(model.state_dict())
print(smaller_model_dict["fc1.weight"].shape)
print(smaller_model_dict["fc2.weight"].shape)
smaller_model_dict["fc1.weight"] = pruned_model.state_dict()["fcc.weight"] @ pruned_model.state_dict()["pretrainedNet.fc1.weight"]
smaller_model_dict["fc2.weight"] = pruned_model.state_dict()["pretrainedNet.fc2.weight"] @ pruned_model.state_dict()["fcd.weight"] 
print(smaller_model_dict["fc1.weight"].shape)
print(smaller_model_dict["fc2.weight"].shape)
smaller_model_dict.keys()
smaller_model.load_state_dict(smaller_model_dict)

torch.Size([128, 256])
torch.Size([50, 128])
torch.Size([64, 256])
torch.Size([50, 64])


<All keys matched successfully>

In [15]:
# Test accuracy of the smaller_model is the same as pruned_model when alpha=0
# Note that only pruned_model needs alpha values not the smaller_model(Check the forward functions for both)
# However, for the sake of simplicity, test() and train() funcitons were defined assuming that
# alpha is needed. But this alpha when passed to the smaller_model is irrelevant.
# TODO: Reimplement the train() and test() functions such that alpha is only required for pruned_model
test(smaller_model, device, test_loader, 0 )

9920512it [23:55, 6909.00it/s]   



Test set: Alpha 0, Average loss: 0.1640, Accuracy: 9619/10000 (96%)



In [102]:
print(smaller_model.state_dict()["fc1.weight"].shape)
print(smaller_model.state_dict()["fc2.weight"].shape)
print("--"*10)
print(model.state_dict()["fc1.weight"].shape)
print(model.state_dict()["fc2.weight"].shape)
print("--"*10)
print(pruned_model.state_dict()["pretrainedNet.fc1.weight"].shape)
print(pruned_model.state_dict()["fcc.weight"].shape)
print((pruned_model.state_dict()["fcc.weight"] @ pruned_model.state_dict()["pretrainedNet.fc1.weight"]).shape)
print("--"*10)
print(pruned_model.state_dict()["pretrainedNet.fc2.weight"].shape)
print(pruned_model.state_dict()["fcd.weight"].shape)

torch.Size([64, 256])
torch.Size([50, 64])
--------------------
torch.Size([128, 256])
torch.Size([50, 128])
--------------------
torch.Size([128, 256])
torch.Size([64, 128])
torch.Size([64, 256])
--------------------
torch.Size([50, 128])
torch.Size([128, 64])


In [71]:
pruned_model.state_dict()["fcc.weight"].shape
pruned_model.state_dict()["pretrainedNet.fc1.weight"]

torch.Size([64, 128])

In [32]:
class PrunedNet(nn.Module):
    def __init__(self,pretrainedmodel):
        super(PrunedNet, self).__init__()
        self.pretrainedNet = pretrainedmodel
        
        # No need to update weights of the pretrained network        
        for param in self.pretrainedNet.parameters():
            param.requires_grad=False
        
        self.fca = nn.Linear(256, 30, bias = False)
        self.fcb = nn.Linear(30, 128, bias = False)
        
        self.fcc = nn.Linear(128, 10, bias = False)
        self.fcd = nn.Linear(10, 50, bias = False)
        
    def forward(self, x, alpha=0.9):
        x = self.pretrainedNet.conv1(x)
        x = self.pretrainedNet.bn1(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.pretrainedNet.conv2(x)
        x = self.pretrainedNet.bn2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.pretrainedNet.conv3(x)
        x = self.pretrainedNet.bn3(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        
        xorg = torch.flatten(x, 1)
        x1 = self.pretrainedNet.fc1(xorg)
        x2 = self.fca(xorg)
        x2 = self.fcb(x2)
        #x = alpha * x1 + (1-alpha) * x2
        xorg = F.relu(x2)
        
        #x1 = self.pretrainedNet.fc2(xorg)
        x2 = self.fcc(xorg)
        x2 = self.fcd(x2)
        #x = alpha * x1 + (1-alpha) * x2
        x= F.relu(x2)
        
        x = self.pretrainedNet.fc3(x)
        output = F.log_softmax(x, dim=1)
        return output
pruned_model = PrunedNet(model).to(device)

8

In [26]:
for p in pruned_model.pretrainedNet.parameters():
    pass
p.requires_grad

True

In [None]:
class smallerNet(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 1, 3, 1, bias = False)
        self.conv2 = nn.Conv2d(1, 2, 3, 1, bias = False)
        self.conv3 = nn.Conv2d(2, 256, 3, 1, bias = False)
        
        self.bn1 = nn.BatchNorm2d(1)
        self.bn2 = nn.BatchNorm2d(2)
        self.bn3 = nn.BatchNorm2d(256)
        #self.dropout1 = nn.Dropout(0.25)
        #self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(256 * 1 * 1, 64, bias =False)
        self.fc2 = nn.Linear(64, 50, bias=False)
        self.fc3 = nn.Linear(50, 10, bias = False)
        

    def forward(self, x, alpha = 1):
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.conv2(x)
        x = self.bn2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.conv3(x)
        x = self.bn3(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        
        #x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        #x = self.dropout2(x)
        x = self.fc2(x)
        x= F.relu(x)
        x = self.fc3(x)
        output = F.log_softmax(x, dim=1)
        return output
    
smaller_model = smallerNet().to(device)

In [11]:
layer3 = model_dict['conv3.weight'].cpu().numpy()
num_filters = layer3.shape[0]

In [12]:
layer3.shape

(256, 2, 3, 3)

In [13]:
# class PrunedNet(nn.Module):
#     def __init__(self,pretrainedmodel):
#         super(PrunedNet, self).__init__()
#         self.pretrainedNet = pretrainedmodel
        
#         # No need to update weights of the pretrained network        
#         for param in self.pretrainedNet.parameters():
#             param.requires_grad=False
        
#         self.fca = nn.Linear(256, 30, bias = False)
#         self.fcb = nn.Linear(30, 128, bias = False)
        
#         self.fcc = nn.Linear(128, 10, bias = False)
#         self.fcd = nn.Linear(10, 50, bias = False)
        
#     def forward(self, x, alpha=0.9):
#         x = self.pretrainedNet.conv1(x)
#         x = self.pretrainedNet.bn1(x)
#         x = F.relu(x)
#         x = F.max_pool2d(x, 2)
#         x = self.pretrainedNet.conv2(x)
#         x = self.pretrainedNet.bn2(x)
#         x = F.relu(x)
#         x = F.max_pool2d(x, 2)
#         x = self.pretrainedNet.conv3(x)
#         x = self.pretrainedNet.bn3(x)
#         x = F.relu(x)
#         x = F.max_pool2d(x, 2)
        
#         xorg = torch.flatten(x, 1)
#         x1 = self.pretrainedNet.fc1(xorg)
#         x2 = self.fca(xorg)
#         x2 = self.fcb(x2)
#         x = alpha * x1 + (1-alpha) * x2
#         xorg = F.relu(x)
        
#         x1 = self.pretrainedNet.fc2(xorg)
#         x2 = self.fcc(xorg)
#         x2 = self.fcd(x2)
#         x = alpha * x1 + (1-alpha) * x2
#         x= F.relu(x)
        
#         x = self.pretrainedNet.fc3(x)
#         output = F.log_softmax(x, dim=1)
#         return output
# pruned_model = PrunedNet(model).to(device)

256

In [14]:
scores = np.sum(np.sum(np.sum(np.abs(layer3),3),2),1)
print(scores)

[2.8458521 2.4906092 2.182644  2.4704623 2.604546  2.7074933 2.4273744
 2.4395137 2.7046404 2.2236407 2.724819  2.346136  2.4736838 2.704204
 2.352558  1.8051548 2.0248778 2.1634579 2.0994892 2.1173074 2.4371142
 2.56626   2.1074808 2.8850985 2.6320398 2.3986905 2.6123188 2.4185576
 2.3762336 2.3464718 2.3099728 2.3896155 2.3797746 2.571484  2.682601
 2.457549  2.5476065 2.703555  2.171604  2.8341658 2.7164273 2.3056388
 2.2262316 2.4162507 2.162299  2.1605859 2.3533275 2.154541  2.29719
 2.3968291 2.1048923 2.3813574 2.5073893 2.5824394 2.5691903 2.173055
 2.6351976 2.549543  2.6614966 2.5639572 2.6166143 2.8284776 2.1105638
 2.4044895 2.4763665 2.1381226 2.629759  2.30014   2.2500648 2.2014756
 1.8681186 1.7940662 2.735925  2.3465033 2.6203418 2.2923179 2.656416
 2.6224637 2.6273053 2.1582756 2.4212275 1.8324049 2.1633458 2.0751681
 2.4310641 2.3555272 2.2806036 2.3307126 2.0461018 2.3769479 2.0755684
 2.1711507 2.357153  2.3072407 2.2286    2.5354738 2.273097  2.405572
 2.4387927 2.

In [15]:
ordered_indices = np.argsort(scores)
ordered_indices

array([222,  71, 124,  15,  81, 179,  70, 229, 173, 123, 217, 119, 180,
       129, 161, 238, 213,  16, 206, 187, 105,  88, 139, 197, 209,  83,
        90, 131, 214, 177,  18, 166,  50, 143,  22,  62, 137, 127, 210,
        19, 133,  65, 171, 220, 208, 245, 144, 184,  47, 215,  79,  45,
        44,  82,  17,  91,  38, 150,  55, 128,   2, 120, 151, 191,  69,
       130, 159, 252, 181,   9,  42,  94, 227, 196, 231, 236, 247,  68,
        99, 138, 182, 116, 218,  96, 167,  86, 234, 235,  75, 233,  48,
        67,  41,  93,  30, 100, 107, 110, 135, 189,  87, 140, 154, 163,
        11,  29,  73, 205, 242,  14,  46, 125, 237,  85, 115,  92, 240,
       230, 228, 103, 142, 108,  28,  89, 145,  32,  51, 147, 255, 165,
       226, 212,  31,  49,  25, 117, 188,  63,  97, 168, 169, 104, 241,
       194,  43,  27,  80, 244,   6,  84, 109,  20, 232,  98,   7, 162,
       134, 198,  35, 126, 155, 132,   3, 114,  12,  64, 192, 246, 136,
       219, 221, 249,   1, 195,  52, 160, 172, 253, 122, 248, 20

In [16]:
num_filters_to_retain = int(num_filters * (1 - remove_ratio))
filters_idxs_to_retain = ordered_indices[:num_filters_to_retain]

In [17]:
model_dict['conv3.weight'][filters_idxs_to_retain].shape

torch.Size([192, 2, 3, 3])

In [18]:
model_dict['conv3.weight'] = model_dict['conv3.weight'][filters_idxs_to_retain]
model_dict["bn3.weight"] = model_dict['bn3.weight'][filters_idxs_to_retain]
model_dict["bn3.running_mean"] = model_dict['bn3.running_mean'][filters_idxs_to_retain]
model_dict["bn3.running_var"] = model_dict['bn3.running_var'][filters_idxs_to_retain]
model_dict["bn3.bias"] = model_dict['bn3.bias'][filters_idxs_to_retain]


In [19]:
model_dict["fc1.weight"] = model_dict["fc1.weight"][:,filters_idxs_to_retain]

In [20]:
#print(model_dict["fc1.weight"].shape)
#print(model_dict["fc1.weight"][:,filters_idxs_to_retain].shape)


In [26]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 1, 3, 1, bias = False)
        self.conv2 = nn.Conv2d(1, 2, 3, 1, bias = False)
        self.conv3 = nn.Conv2d(2, 192, 3, 1, bias = False)
        
        self.bn1 = nn.BatchNorm2d(1)
        self.bn2 = nn.BatchNorm2d(2)
        self.bn3 = nn.BatchNorm2d(192)
        #self.dropout1 = nn.Dropout(0.25)
        #self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(192 * 1 * 1, 128)
        self.fc2 = nn.Linear(128, 50)
        self.fc3 = nn.Linear(50, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.conv2(x)
        x = self.bn2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.conv3(x)
        x = self.bn3(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        
        #x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        #x = self.dropout2(x)
        x = self.fc2(x)
        x= F.relu(x)
        x = self.fc3(x)
        output = F.log_softmax(x, dim=1)
        return output
    
model = Net().to(device)

In [27]:
model.load_state_dict(model_dict)


<All keys matched successfully>

In [28]:
test(model, device, test_loader)


Test set: Average loss: 0.3935, Accuracy: 8718/10000 (87%)



In [29]:
optimizer = optim.Adadelta(model.parameters(), lr=1.0)

scheduler = StepLR(optimizer, step_size=1, gamma=0.7)
for epoch in range(1, 10 + 1):
    train(10,model, device, train_loader, optimizer, epoch)
    test(model, device, test_loader)
    scheduler.step()


Test set: Average loss: 0.1543, Accuracy: 9537/10000 (95%)




Test set: Average loss: 0.1545, Accuracy: 9533/10000 (95%)


Test set: Average loss: 0.1722, Accuracy: 9477/10000 (95%)




Test set: Average loss: 0.1266, Accuracy: 9613/10000 (96%)


Test set: Average loss: 0.1176, Accuracy: 9644/10000 (96%)




Test set: Average loss: 0.1124, Accuracy: 9656/10000 (97%)




Test set: Average loss: 0.1118, Accuracy: 9671/10000 (97%)


Test set: Average loss: 0.1110, Accuracy: 9683/10000 (97%)




Test set: Average loss: 0.1129, Accuracy: 9676/10000 (97%)


Test set: Average loss: 0.1100, Accuracy: 9695/10000 (97%)



In [23]:
torch.save(model, "pruned_model.pth")

  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "


In [24]:
new_model = torch.load("pruned_model.pth")

In [25]:
test(new_model, device, test_loader)


Test set: Average loss: 0.1149, Accuracy: 9673/10000 (97%)



In [26]:
import numpy as np
import torch
from torch.autograd import Variable

x = Variable(torch.Tensor([5]),requires_grad=True)
w0 = Variable(torch.Tensor([2]),requires_grad=False)

w1 = Variable(torch.Tensor([3]),requires_grad=False)


y1 = x**2*w0
y2 = y1*w1


gradients = []
y1.register_hook(lambda x: gradients.append(x))
y2.register_hook(lambda x: gradients.append(x))
x.register_hook(lambda x: gradients.append(x))

y2.backward()

gradients

[tensor([1.]), tensor([3.]), tensor([60.])]

In [27]:
def print_value(value):
    print("The value is {}".format(value.item()))

x = Variable(torch.Tensor([5]),requires_grad=True)
w0 = Variable(torch.Tensor([2]),requires_grad=False)

w1 = Variable(torch.Tensor([3]),requires_grad=False)


y1 = x**2*w0
y2 = y1*w1


gradients = []
y1.register_hook(lambda x: print_value(x))
y2.register_hook(lambda x: print_value(x))
x.register_hook(lambda x: print_value(x))



<torch.utils.hooks.RemovableHandle at 0x7f8133865a10>

In [28]:
y2.backward()

The value is 1.0
The value is 3.0
The value is 60.0
