In [1]:
from torchvision import models
from torchsummary import summary
import torch.nn as nn
import torch
import os

# Load models

In [2]:
resnet_model = models.resnet101(pretrained=True)
vis_trans = models.vit_l_32(pretrained=True)

Downloading: "https://download.pytorch.org/models/resnet101-63fe2227.pth" to /root/.cache/torch/hub/checkpoints/resnet101-63fe2227.pth
100%|██████████| 171M/171M [00:01<00:00, 140MB/s]
Downloading: "https://download.pytorch.org/models/vit_l_32-c7638314.pth" to /root/.cache/torch/hub/checkpoints/vit_l_32-c7638314.pth
100%|██████████| 1.14G/1.14G [00:31<00:00, 38.9MB/s]


In [None]:
def get_model_size(model):
  dummy_input = torch.randn(1, 3, 224, 224)
  size = sum(torch.nn.utils.parameters_to_vector(model.parameters()).size() * 4) / (1024 * 1024)
  return size

In [None]:
get_model_size(resnet_model)

169.94155883789062

In [3]:
def print_size_of_model(model, label=""):
    torch.save(model.state_dict(), "temp.p")
    size=os.path.getsize("temp.p")
    print("model: ",label,' \t','Size (MB):', round(size/1024/1024, 2))
    os.remove('temp.p')

In [4]:
print_size_of_model(resnet_model, 'resnet')
print_size_of_model(vis_trans, 'visual_transformer')

model:  resnet  	 Size (MB): 170.5
model:  visual_transformer  	 Size (MB): 1169.43


# Quantization

## Dynamic

In [None]:
resnet_int8 = torch.ao.quantization.quantize_dynamic(resnet_model)
print_size_of_model(resnet_int8, 'resnet_int8')

model:  resnet_int8  	 Size (MB): 164.64


In [None]:
vis_trans_int8 = torch.ao.quantization.quantize_dynamic(vis_trans)
print_size_of_model(vis_trans_int8, 'vis_trans_int8')

model:  vis_trans_int8  	 Size (MB): 590.54


## Static

In [None]:
resnet_model.eval()
vis_trans.eval()

In [None]:
resnet_model.qconfig = torch.ao.quantization.get_default_qconfig('x86')
vis_trans.qconfig = torch.ao.quantization.get_default_qconfig('x86')

In [None]:
resnet_model_prepared = torch.ao.quantization.prepare(resnet_model)
vis_trans_prepared = torch.ao.quantization.prepare(vis_trans)



In [None]:
input_fp32 = torch.randn(4, 3, 224, 224)
resnet_model_prepared(input_fp32)
vis_trans_prepared(input_fp32)

tensor([[-0.4991, -0.4257, -0.1935,  ..., -0.3780,  0.0137, -0.5475],
        [-0.5217, -0.4660, -0.1405,  ..., -0.4145,  0.3004, -0.3279],
        [-0.5042, -0.3319, -0.0387,  ..., -0.3302,  0.0510, -0.4187],
        [-0.6160, -0.3664, -0.0742,  ..., -0.4268,  0.1430, -0.4532]],
       grad_fn=<AddmmBackward0>)

In [None]:
resnet_int8 = torch.ao.quantization.convert(resnet_model_prepared)
vis_trans_int8 = torch.ao.quantization.convert(vis_trans_prepared)



In [None]:
print_size_of_model(resnet_int8, 'resnet_int8')
print_size_of_model(vis_trans_int8, 'vis_trans_int8')

model:  resnet_int8  	 Size (MB): 44.32
model:  vis_trans_int8  	 Size (MB): 585.48


# Prunning