## Dynamic weight pruning on ResNet18

In [2]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.models.quantization as models
import torchsummary
print(torch.__version__ , torchvision.__version__)

2.3.1+cpu 0.18.1+cpu


In [4]:
# Preprocessing applied to every image: resize to 224, convert to a tensor,
# then normalize each of the three channels with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.Resize(size=224),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

In [5]:
# CIFAR-10 training split (downloaded into ./data when missing), served in
# shuffled mini-batches of 64.
trainset = datasets.CIFAR10(root='./data', train=True,
                            download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64,
                                          shuffle=True, num_workers=16,
                                          # Pinned host memory only helps host-to-GPU copies.
                                          pin_memory=torch.cuda.is_available())

# CIFAR-10 test split, served in mini-batches of 32.
testset = datasets.CIFAR10(root='./data', train=False,
                           download=True, transform=transform)
# Evaluation does not benefit from shuffling; a fixed order keeps runs comparable.
testloader = torch.utils.data.DataLoader(testset, batch_size=32,
                                         shuffle=False, num_workers=16,
                                         pin_memory=torch.cuda.is_available())

Files already downloaded and verified
Files already downloaded and verified


In [7]:
# Display the class names of the test split (shown as the cell output).
testset.classes

['airplane',
 'automobile',
 'bird',
 'cat',
 'deer',
 'dog',
 'frog',
 'horse',
 'ship',
 'truck']

In [5]:
# Select the first CUDA device when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
device

device(type='cuda', index=0)

In [6]:
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions, each followed by batch norm.

    The block computes ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``.
    The addition goes through a ``FloatFunctional`` rather than ``+``.

    Args:
        in_channels: Number of channels of the incoming tensor.
        out_channels: Number of channels produced by both convolutions.
        identity_downsample: Optional module applied to the input before it is
            added back; when ``None`` the input is added unchanged.
        stride: Stride of the first convolution (the second always uses 1).
    """

    def __init__(self, in_channels, out_channels, identity_downsample= None, stride= 1):
        super().__init__()

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size= 3, stride= stride, padding= 1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size= 3, padding= 1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # Activation that follows bn1.
        self.relu = nn.ReLU(inplace= True)
        # Separate activation for the post-addition step. Sharing one ReLU
        # instance between both call sites means that fusing conv1/bn1/relu
        # (which swaps ``relu`` for an Identity) would also remove the final
        # activation. ReLU holds no parameters, so state_dict keys are unchanged.
        self.relu_out = nn.ReLU(inplace= True)
        self.identity_downsample = identity_downsample
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x):
        """Run the residual branch, add the shortcut, and apply the final ReLU."""
        identity = x

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)

        # Transform the shortcut when a downsample module was supplied.
        if self.identity_downsample is not None:
            identity = self.identity_downsample(identity)

        # Functional-module form of ``x += identity``.
        x = self.skip_add.add(x, identity)
        x = self.relu_out(x)
        return x


class ResNet(nn.Module):
    """Residual network: convolutional stem, four residual stages, linear head.

    Args:
        BasicBlock: Block class used to build every stage. It is called as
            ``BasicBlock(in_channels, out_channels, identity_downsample, stride)``
            for the first block of a stage and
            ``BasicBlock(in_channels, out_channels)`` for the rest.
        layers: Indexable of four block counts, one per stage.
        image_channels: Number of channels of the input images.
        num_classes: Size of the output of the final linear layer.
    """

    def __init__(self, BasicBlock, layers, image_channels, num_classes):
        super().__init__()
        # Running channel width; _make_layer reads it and advances it.
        self.in_channels = 64

        # Stem: 7x7 stride-2 convolution, batch norm, ReLU, 3x3 stride-2 max pool.
        self.conv1 = nn.Conv2d(image_channels, 64, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # (output width, stride of first block) for layer1 .. layer4.
        stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))
        for stage_idx, (width, first_stride) in enumerate(stage_specs):
            stage = self._make_layer(
                BasicBlock, layers[stage_idx], out_channels=width, stride=first_stride
            )
            setattr(self, f"layer{stage_idx + 1}", stage)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)
        # Stubs wrapping the part of the forward pass between quant and dequant.
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        """Return the output of the linear head for a batch of images ``x``."""
        out = self.quant(x)
        out = self.maxpool(self.relu(self.bn1(self.conv1(out))))

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)

        # Pool to 1x1, leave the stubbed region, then flatten per sample.
        out = self.dequant(self.avgpool(out))
        flat = out.reshape(out.shape[0], -1)
        return self.fc(flat)

    def _make_layer(self, BasicBlock, num_residual_blocks, out_channels, stride):
        """Build one stage as an ``nn.Sequential`` of blocks.

        The first block receives ``stride`` and, when the stride is not 1 or
        the channel width changes, a 1x1 convolution + batch norm shortcut.
        The remaining ``num_residual_blocks - 1`` blocks use the defaults.
        Updates ``self.in_channels`` to ``out_channels`` as a side effect.
        """
        shortcut = None
        shape_is_unchanged = stride == 1 and self.in_channels == out_channels
        if not shape_is_unchanged:
            shortcut = nn.Sequential(
                nn.Conv2d(self.in_channels, out_channels, kernel_size=1, stride=stride),
                nn.BatchNorm2d(out_channels),
            )

        first_block = BasicBlock(self.in_channels, out_channels, shortcut, stride)
        self.in_channels = out_channels

        followers = [
            BasicBlock(self.in_channels, out_channels)
            for _ in range(num_residual_blocks - 1)
        ]
        return nn.Sequential(first_block, *followers)

def ResNet18(img_channels= 3, num_classes= 10):
    """Build a ResNet from BasicBlock with two blocks in each of the four stages.

    Args:
        img_channels: Number of channels of the input images.
        num_classes: Size of the classifier output.
    """
    blocks_per_stage = [2, 2, 2, 2]
    return ResNet(
        BasicBlock,
        layers=blocks_per_stage,
        image_channels=img_channels,
        num_classes=num_classes,
    )


In [7]:
# Build the network with ResNet18's default arguments and move it to `device`.
model = ResNet18().to(device)

In [8]:
# Print a per-layer table of output shapes and parameter counts for a 3x224x224 input.
torchsummary.summary(model, input_size= (3,224,224))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
         QuantStub-1          [-1, 3, 224, 224]               0
            Conv2d-2         [-1, 64, 112, 112]           9,472
       BatchNorm2d-3         [-1, 64, 112, 112]             128
              ReLU-4         [-1, 64, 112, 112]               0
         MaxPool2d-5           [-1, 64, 56, 56]               0
            Conv2d-6           [-1, 64, 56, 56]          36,928
       BatchNorm2d-7           [-1, 64, 56, 56]             128
              ReLU-8           [-1, 64, 56, 56]               0
            Conv2d-9           [-1, 64, 56, 56]          36,928
      BatchNorm2d-10           [-1, 64, 56, 56]             128
         Identity-11           [-1, 64, 56, 56]               0
             ReLU-12           [-1, 64, 56, 56]               0
       BasicBlock-13           [-1, 64, 56, 56]               0
           Conv2d-14           [-1, 64,

In [9]:
# Baseline count of zero vs. non-zero parameters before any training or pruning.
# NOTE(review): `helper_functions` is not imported in any visible cell — confirm
# it is imported in the imports cell so the notebook survives Restart & Run All.
helper_functions.count_nonzero_params(model)

The number of non-zero parameters : 11181641
The number of zero parameters : 4801


In [10]:
# Train the model on the training loader with pruning interleaved.
# The 20 is presumably the epoch count: the log below runs 20 epochs and reports
# a pruning step after every fifth one — confirm against helper_functions.
helper_functions.train_with_pruning(model, trainloader, 20, device)

Epoch 1, Loss: 1.4067058670703712
Epoch 2, Loss: 0.956815745214672
Epoch 3, Loss: 0.7384919034855445
Epoch 4, Loss: 0.5778948737837165
Epoch 5, Loss: 0.44662247857321863
Pruning after epoch 5
Pruning done.
Epoch 6, Loss: 0.3388059772646336
Epoch 7, Loss: 0.23517573008413814
Epoch 8, Loss: 0.14617571805406104
Epoch 9, Loss: 0.08483502784472369
Epoch 10, Loss: 0.056397375510171856
Pruning after epoch 10
Pruning done.
Epoch 11, Loss: 0.04288635634617104
Epoch 12, Loss: 0.022173141134376434
Epoch 13, Loss: 0.008181728106593508
Epoch 14, Loss: 0.004467464989715773
Epoch 15, Loss: 0.0023382750003065086
Pruning after epoch 15
Pruning done.
Epoch 16, Loss: 0.001717682702906306
Epoch 17, Loss: 0.0013493761868940194
Epoch 18, Loss: 0.0010571652148237518
Epoch 19, Loss: 0.0009725533975952226
Epoch 20, Loss: 0.0011031236185993264
Pruning after epoch 20
Pruning done.


In [11]:
# Evaluate the pruned model on the test loader; the helper prints accuracy and evaluation time.
helper_functions.test(model, testloader, device)

Accuracy of the network on the 10000 test images: 83 %
 Evaluation time :21.592


In [12]:
# Re-count zero vs. non-zero parameters after training to see how many were pruned.
helper_functions.count_nonzero_params(model)

The number of non-zero parameters : 10069239
The number of zero parameters : 1117203


In [13]:
# Persist the trained, pruned weights (state_dict only, not the module object).
# NOTE(review): if the pruning helper leaves torch.nn.utils.prune reparametrizations
# in place, the saved keys would be `*_orig` / `*_mask` — confirm before reloading.
torch.save(model.state_dict(), 'resnet_with_10_class.pth')