In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.models.quantization as models
import torch.nn.utils.prune as prune
import os
import copy
import torchsummary

In [2]:
from timeit import default_timer as timer
def print_train_time(start:float, end : float, device : torch.device = None):
    total_time = end - start
    print(f' Train time :{total_time:.3f}')
    return total_time


In [3]:
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

In [4]:
trainset = datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64,
                                          shuffle=True, num_workers=16, pin_memory=True)

testset = datasets.CIFAR10(root='./data', train=False,
                                        download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=32,
                                          shuffle=True, num_workers=16, pin_memory= True)

Files already downloaded and verified
Files already downloaded and verified


In [5]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cpu')

In [6]:
import torch
import torch.nn as nn

class BasicBlock(nn.Module):
    def __init__(self, in_channels, out_channels, identity_downsample= None, stride= 1):
        super().__init__()

        self.expansion = 1
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size= 3, stride= stride, padding= 1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size= 3,padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace= True)
        self.identity_downsample = identity_downsample
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x):
        identity = x

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        # x = self.relu(x)

        if self.identity_downsample is not None:
            identity = self.identity_downsample(identity)

        # x += identity
        x = self.skip_add.add(x, identity)
        x = self.relu(x)
        return x


class ResNet(nn.Module):
    def __init__(self, BasicBlock, layers, image_channels, num_classes):
        super().__init__()
        self.in_channels = 64
        self.conv1 = nn.Conv2d(image_channels, 64, kernel_size= 7, stride= 2, padding= 3)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace= True)
        self.maxpool = nn.MaxPool2d(kernel_size= 3, stride= 2, padding= 1)

        self.layer1 = self._make_layer(BasicBlock, layers[0], out_channels= 64, stride= 1)
        self.layer2 = self._make_layer(BasicBlock, layers[1], out_channels= 128, stride= 2)
        self.layer3 = self._make_layer(BasicBlock, layers[2], out_channels= 256, stride= 2)
        self.layer4 = self._make_layer(BasicBlock, layers[3], out_channels= 512, stride= 2)

        self.avgpool = nn.AdaptiveAvgPool2d((1,1))
        self.fc = nn.Linear(512, num_classes)
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.reshape(x.shape[0], -1)
        x = self.fc(x)
        x = self.dequant(x)

        return x


    def _make_layer(self, BasicBlock, num_residual_blocks, out_channels, stride):
        identity_downsample = None
        layers = []
        if stride != 1 or self.in_channels != out_channels:
            identity_downsample = nn.Sequential(nn.Conv2d(self.in_channels, out_channels , kernel_size= 1, stride= stride),
                                               nn.BatchNorm2d(out_channels))
        layers.append(BasicBlock(self.in_channels, out_channels, identity_downsample, stride))
        self.in_channels = out_channels

        for i in range(num_residual_blocks - 1):
            layers.append(BasicBlock(self.in_channels, out_channels))

        return nn.Sequential(*layers)

def ResNet18(img_channels= 3, num_classes= 100):
    return ResNet(BasicBlock, [2, 2, 2, 2], img_channels, num_classes)


In [7]:
model = ResNet18()
model

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (skip_add): FloatFunctional(
        (activation_post_process): Identity()
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=

In [8]:
torchsummary.summary(model, input_size= (3,224,224))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
         QuantStub-1          [-1, 3, 224, 224]               0
            Conv2d-2         [-1, 64, 112, 112]           9,472
       BatchNorm2d-3         [-1, 64, 112, 112]             128
              ReLU-4         [-1, 64, 112, 112]               0
         MaxPool2d-5           [-1, 64, 56, 56]               0
            Conv2d-6           [-1, 64, 56, 56]          36,928
       BatchNorm2d-7           [-1, 64, 56, 56]             128
              ReLU-8           [-1, 64, 56, 56]               0
            Conv2d-9           [-1, 64, 56, 56]          36,928
      BatchNorm2d-10           [-1, 64, 56, 56]             128
         Identity-11           [-1, 64, 56, 56]               0
             ReLU-12           [-1, 64, 56, 56]               0
       BasicBlock-13           [-1, 64, 56, 56]               0
           Conv2d-14           [-1, 64,

In [9]:
for layers, _ in model.named_modules():
  print(layers)


conv1
bn1
relu
maxpool
layer1
layer1.0
layer1.0.conv1
layer1.0.bn1
layer1.0.conv2
layer1.0.bn2
layer1.0.relu
layer1.0.skip_add
layer1.0.skip_add.activation_post_process
layer1.1
layer1.1.conv1
layer1.1.bn1
layer1.1.conv2
layer1.1.bn2
layer1.1.relu
layer1.1.skip_add
layer1.1.skip_add.activation_post_process
layer2
layer2.0
layer2.0.conv1
layer2.0.bn1
layer2.0.conv2
layer2.0.bn2
layer2.0.relu
layer2.0.identity_downsample
layer2.0.identity_downsample.0
layer2.0.identity_downsample.1
layer2.0.skip_add
layer2.0.skip_add.activation_post_process
layer2.1
layer2.1.conv1
layer2.1.bn1
layer2.1.conv2
layer2.1.bn2
layer2.1.relu
layer2.1.skip_add
layer2.1.skip_add.activation_post_process
layer3
layer3.0
layer3.0.conv1
layer3.0.bn1
layer3.0.conv2
layer3.0.bn2
layer3.0.relu
layer3.0.identity_downsample
layer3.0.identity_downsample.0
layer3.0.identity_downsample.1
layer3.0.skip_add
layer3.0.skip_add.activation_post_process
layer3.1
layer3.1.conv1
layer3.1.bn1
layer3.1.conv2
layer3.1.bn2
layer3.1.relu

In [10]:
def calculate_time(start:float, end : float, device : torch.device = None):
    total_time = end - start
    print(f' Evaluation time :{total_time:.3f}')

In [11]:
def print_size_of_model(model):
    torch.save(model.state_dict(), "temp_delme.p")
    print('size (KB) :',os.path.getsize("temp_delme.p")/1e3)
    os.remove('temp_delme.p')

In [12]:
def prune_model(model, pruning_rate=0.3):
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear):

            # Applying unstructured L1 norm pruning
            prune.l1_unstructured(module, name='weight', amount=pruning_rate)

            prune.remove(module, 'weight')

In [13]:
def count_nonzero_params(model):
    # non_zero_count = 0
    # for param in model.parameters():
    #     non_zero_count += torch.count_nonzero(param).item()
    print("The number of non-zero parameters :", sum(p.numel() for p in model.parameters()))
    zero_count = 0
    for param in model.parameters():
        zero_count += torch.sum(param == 0).item()
    
    print("The number of zero parameters :", zero_count)

In [14]:
# count_nonzero_params(model)

In [15]:
def slice_dataloader(dataloader, start, end):
    sliced_data = []
    current_index = 0
    for inputs, labels in dataloader:
        batch_size = inputs.size(0)
        if current_index + batch_size > start:
            # Find the start index within the current batch
            start_idx = max(start - current_index, 0)
            # Find the end index within the current batch
            end_idx = min(end - current_index, batch_size)
            sliced_inputs = inputs[start_idx:end_idx]
            sliced_labels = labels[start_idx:end_idx]
            sliced_data.append((sliced_inputs, sliced_labels))
            if current_index + batch_size >= end:
                break
        current_index += batch_size
    return sliced_data

In [16]:
def train(model,dataloader, device):
  criterion = nn.CrossEntropyLoss()
  optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
  epochs = 4
  for epoch in range(epochs):
      running_loss = 0.0
      model.train()
      for i, data in enumerate(dataloader, 0):
          inputs, labels = data
          inputs, labels = inputs.to(device), labels.to(device)
          optimizer.zero_grad()
          model = model.to(device)
          outputs = model(inputs)
          loss = criterion(outputs, labels)
          loss.backward()
          optimizer.step()
          running_loss += loss.item()

      print(f'Epoch {epoch+1}, Loss: {running_loss/len(trainloader)}')

      if (epoch + 1) % 2 == 0:
          print(f'Pruning after epoch {epoch + 1}')
          prune_model(model, pruning_rate=0.1)
          print('Pruning done.')

In [17]:
def test(model, dataloader, device):
  start = timer()
  correct = 0
  total = 0
  model.eval()
  with torch.no_grad():
      for data in dataloader:
          inputs, labels = data
          inputs, labels = inputs.to(device), labels.to(device)
          model = model.to(device)
          outputs = model(inputs)
          _, predicted = torch.max(outputs.data, 1)
          total += labels.size(0)
          correct += (predicted == labels).sum().item()
      print(f'Accuracy of the network on the {total} test images: %d %%' % (100 * correct / total))
  end = timer()
  calculate_time(start, end, device)

In [18]:
def count_parameters(model: nn.Module) -> int:
     print(sum(p.numel() for p in model.parameters()))

In [19]:
count_parameters(model)

11232612


In [None]:
train(model, trainloader, device)

Epoch 1, Loss: 1.4423082346654


In [None]:
sliced_data_1 = slice_dataloader(testloader, start=0, end=2000)

In [None]:
test(model, sliced_data_1, device)

In [None]:
print_size_of_model(model)

In [None]:
# dynamic_quant_model = copy.deepcopy(model)

# state_dict = model.state_dict()
# dynamic_quant_model.load_state_dict(state_dict)
count_nonzero_params(model)

In [None]:
device = 'cpu'
model = model.to(device)

In [None]:
dynamic_quant_model = torch.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8 )

In [None]:
sliced_data_2 = slice_dataloader(testloader, start=1000, end=2000)

In [None]:
test(dynamic_quant_model, sliced_data_2, device)

In [None]:
print_size_of_model(dynamic_quant_model)

In [None]:
device = 'cpu'

In [None]:
# static_quant_model = copy.deepcopy(dynamic_quant_model)

# state_dict = dynamic_quant_model.state_dict()
# static_quant_model.load_state_dict(state_dict).to('cpu')
dynamic_quant_model.eval()
modules_to_fuse = modules_to_fuse = [
    ['conv1','bn1','relu'],
    ['layer1.0.conv1','layer1.0.bn1'],
    ['layer1.0.conv2','layer1.0.bn2','layer1.0.relu'],
    ['layer1.1.conv1','layer1.1.bn1'],
    ['layer1.1.conv2','layer1.1.bn2','layer1.1.relu'],
    ['layer2.0.conv1','layer2.0.bn1'],
    ['layer2.0.conv2','layer2.0.bn2','layer2.0.relu'],
    ['layer2.0.identity_downsample.0','layer2.0.identity_downsample.1'],
    ['layer2.1.conv1','layer2.1.bn1'],
    ['layer2.1.conv2','layer2.1.bn2','layer2.1.relu'],
    ['layer3.0.conv1','layer3.0.bn1'],
    ['layer3.0.conv2','layer3.0.bn2','layer3.0.relu'],
    ['layer3.0.identity_downsample.0','layer3.0.identity_downsample.1'],
    ['layer3.1.conv1','layer3.1.bn1'],
    ['layer3.1.conv2','layer3.1.bn2','layer3.1.relu'],
    ['layer4.0.conv1','layer4.0.bn1'],
    ['layer4.0.conv2','layer4.0.bn2','layer4.0.relu'],
    ['layer4.0.identity_downsample.0','layer4.0.identity_downsample.1'],
    ['layer4.1.conv1','layer4.1.bn1'],
    ['layer4.1.conv2','layer4.1.bn2','layer4.1.relu'],
]

dynamic_quant_model = torch.quantization.fuse_modules(dynamic_quant_model, modules_to_fuse)

In [None]:
[['conv1','bn1','relu'],
['layer1','layer1.0'],
['layer1.0.conv1','layer1.0.bn1'],
['layer1.0.conv2','layer1.0.bn2','layer1.0.relu'],
['layer1.0.skip_add','layer1.0.skip_add.activation_post_process'],
['layer1.1'],
['layer1.1.conv1','layer1.1.bn1'],
['layer1.1.conv2','layer1.1.bn2','layer1.1.relu'],
['layer1.1.skip_add','layer1.1.skip_add.activation_post_process'],
['layer2','layer2.0'],
['layer2.0.conv1','layer2.0.bn1'],
['layer2.0.conv2','layer2.0.bn2','layer2.0.relu'],
['layer2.0.identity_downsample','layer2.0.identity_downsample.0','layer2.0.identity_downsample.1'],
['layer2.0.skip_add','layer2.0.skip_add.activation_post_process'],
['layer2.1'],
['layer2.1.conv1','layer2.1.bn1'],
['layer2.1.conv2','layer2.1.bn2','layer2.1.relu'],
['layer2.1.skip_add','layer2.1.skip_add.activation_post_process'],
['layer3','layer3.0'],
['layer3.0.conv1','layer3.0.bn1'],
['layer3.0.conv2','layer3.0.bn2','layer3.0.relu'],
['layer3.0.identity_downsample','layer3.0.identity_downsample.0','layer3.0.identity_downsample.1'],
['layer3.0.skip_add','layer3.0.skip_add.activation_post_process'],
['layer3.1'],
['layer3.1.conv1','layer3.1.bn1'],
['layer3.1.conv2','layer3.1.bn2','layer3.1.relu'],
['layer3.1.skip_add','layer3.1.skip_add.activation_post_process'],
['layer4','layer4.0'],
['layer4.0.conv1','layer4.0.bn1'],
['layer4.0.conv2','layer4.0.bn2','layer4.0.relu'],
['layer4.0.identity_downsample','layer4.0.identity_downsample.0','layer4.0.identity_downsample.1'],
['layer4.0.skip_add','layer4.0.skip_add.activation_post_process'],
['layer4.1'],
['layer4.1.conv1','layer4.1.bn1'],
['layer4.1.conv2','layer4.1.bn2','layer4.1.relu'],
['layer4.1.skip_add','layer4.1.skip_add.activation_post_process'],
['avgpool','fc'],
['quant','dequant']]

In [None]:
dynamic_quant_model.qconfig = torch.quantization.default_qconfig
static_quant_model = torch.quantization.prepare(dynamic_quant_model, inplace=True)

In [None]:
sliced_data_3 = slice_dataloader(testloader, start=2000, end=2000)

In [None]:
test(static_quant_model, sliced_data_3, device)

In [None]:
torch.quantization.convert(static_quant_model, inplace=True)

In [None]:
print_size_of_model(static_quant_model)

In [None]:
sliced_data_4 = slice_dataloader(testloader, start=4000, end=5000)

In [None]:
test(static_quant_model, sliced_data_4, device)

In [None]:
print_size_of_model(static_quant_model)