## Baseline

In [None]:
# For Colab
# from google.colab import drive
# drive.mount('/content/drive/')

Mounted at /content/drive/
/content/drive/MyDrive/Course/cinnamon/pytorch-compression-cifar100


In [None]:
# !gdown 1zFsUJWH86xxB5-aQ66GJ9WyY8bkwrG0p

In [1]:
from copy import deepcopy
import torch
import torch.nn as nn

from conf import settings
from utils import get_test_dataloader
from benchmark import evaluate, create_torch_profile, get_model_size


def create_report(model, test_loader, device, dummy_input=None):
    """Print an accuracy / latency / size / profiler report for ``model``.

    Args:
        model: Model to benchmark.
        test_loader: Loader handed to ``evaluate`` for the error measurement.
        device: Device passed through to ``evaluate`` and the profiler.
        dummy_input: Input used for the torch profile. When omitted, a random
            ``(1, 3, 32, 32)`` tensor is created on ``device``.
    """
    # The profiler input is created first, before any evaluation runs.
    profile_input = dummy_input
    if profile_input is None:
        profile_input = torch.randn(1, 3, 32, 32, device=device)

    # Error rates and per-image latency as returned by `evaluate`.
    top1, top5, ms_per_image = evaluate(model, test_loader, device)
    for label, value in (('Top 1 err:', top1), ('Top 5 err:', top5)):
        print(label, value)
    print(f'Time per image: {ms_per_image} (ms)')

    # Size: the raw value is divided by 1e3 before being labelled MB
    # (presumably `get_model_size` returns KB -- TODO confirm).
    size_mb = get_model_size(model) / 1e3
    print(f"Model size: {size_mb} (MB)")

    # Operator-level profile on the single dummy input.
    create_torch_profile(model, profile_input, device)

In [2]:
# All benchmarks in this notebook use this device string.
device = 'cpu'

# Test-set loader shared by every report below. The CIFAR-100 train mean/std
# constants come from `settings` (presumably used for input normalisation --
# see `get_test_dataloader` to confirm).
test_loader = get_test_dataloader(
    settings.CIFAR100_TRAIN_MEAN,
    settings.CIFAR100_TRAIN_STD,
    num_workers=4,
    batch_size=16,
)

Files already downloaded and verified


In [6]:
from models.vgg import vgg19_bn

# Instantiate the VGG19-BN architecture and load the checkpoint weights,
# mapping all stored tensors onto the CPU.
baseline = vgg19_bn()
baseline.load_state_dict(torch.load('vgg19-61-best.pth', map_location='cpu'))

<All keys matched successfully>

In [7]:
# Reference numbers for the unmodified model; later sections compare against these.
create_report(baseline, test_loader, device)

Top 1 err: 0.3230999708175659
Top 5 err: 0.11979997158050537
Time per image: 21.051048493385313 (ms)
Model size: 157.384365 (MB)
---------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                             Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls  
---------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                      aten::empty         0.71%     280.000us         0.71%     280.000us       2.188us       5.47 Mb       5.47 Mb           128  
                     aten::conv2d         0.65%     256.000us        68.94%      27.241ms       1.703ms       1.16 Mb           0 b            16  
                aten::convolution         0.71%     282.000us        68.29%      26.985ms       1.687ms       1.16 Mb           0 b

## Quantize

### Post-Training Dynamic/Weight-only Quantization

In [4]:
from torch.quantization import quantize_dynamic

In [10]:
# Dynamic quantization restricted to nn.Linear modules, using float16.
# The result is bound to a new name and benchmarked like the baseline.
model_quantized_dynamic_float16 = quantize_dynamic(
    model=baseline, qconfig_spec={torch.nn.Linear}, dtype=torch.float16,
)
create_report(model_quantized_dynamic_float16, test_loader, device)



Top 1 err: 0.3230999708175659
Top 5 err: 0.11979997158050537
Time per image: 20.80197432041168 (ms)
Model size: 157.38644699999998 (MB)
----------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                              Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls  
----------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                       aten::empty         1.16%     485.000us         1.16%     485.000us       3.702us       5.48 Mb       5.48 Mb           131  
                      aten::conv2d         0.34%     141.000us        82.03%      34.256ms       2.141ms       1.16 Mb           0 b            16  
                 aten::convolution         1.47%     612.000us        81.69%      34.115ms       2.132ms       1.16 Mb 

In [8]:
# Dynamic quantization restricted to nn.Linear modules, using qint8.
# The result is bound to a new name and benchmarked like the baseline.
model_quantized_dynamic_int8 = quantize_dynamic(
    model=baseline, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8,
)
create_report(model_quantized_dynamic_int8, test_loader, device)

Top 1 err: 0.32260000705718994
Top 5 err: 0.11970001459121704
Time per image: 10.717263984680175 (ms)
Model size: 99.534671 (MB)
---------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                             Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls  
---------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                      aten::empty         4.12%     640.000us         4.12%     640.000us       4.776us       5.53 Mb       5.53 Mb           134  
                 aten::empty_like         0.30%      47.000us         0.68%     105.000us       5.526us       1.19 Mb           0 b            19  
                     aten::conv2d         2.69%     419.000us        76.13%      11.838ms     739.875us       1.16 Mb           0 b

### Post-Training Static Quantization

In [9]:
from torch.nn.utils.fusion import fuse_conv_bn_eval


def fuse_all_conv_bn(model):
    """
    Fuses every Conv2d that is immediately followed by a BatchNorm2d, in place.

    Walks the immediate children of ``model`` in registration order and
    recurses into non-leaf children. For each Conv2d directly followed by a
    BatchNorm2d, the conv is replaced by the fused conv returned by
    ``fuse_conv_bn_eval`` and the BatchNorm2d is replaced by ``nn.Identity``.

    A BatchNorm2d that does not directly follow a (still unfused) Conv2d --
    e.g. one that is the first child, or the second of two consecutive
    BatchNorm2d layers -- is left untouched.

    The model should be in eval mode before calling (``ptq`` calls
    ``.eval()`` first), since ``fuse_conv_bn_eval`` is an eval-time fusion.

    Args:
        model: Module whose children are fused; it is modified in place.

    License: Copyright Zeeshan Khan Suri, CC BY-NC 4.0
    """
    prev_name, prev_module = None, None
    # Snapshot the children: entries are replaced through setattr while looping.
    for name, module in list(model.named_children()):  # immediate children
        if list(module.named_children()):  # is not empty (not a leaf)
            fuse_all_conv_bn(module)

        if isinstance(module, nn.BatchNorm2d) and isinstance(prev_module, nn.Conv2d):
            setattr(model, prev_name, fuse_conv_bn_eval(prev_module, module))
            # Track the Identity as "previous" so a second consecutive
            # BatchNorm2d is not fused into the same (already fused) conv.
            module = nn.Identity()
            setattr(model, name, module)

        prev_name, prev_module = name, module

def ptq(model, sample_loader, device='cpu', backend='fbgemm', fuse_bn=True):
    """Post-training static quantization of a copy of ``model``.

    The input model is deep-copied and never modified. The copy is optionally
    Conv/BN-fused, wrapped between a QuantStub and a DeQuantStub, prepared
    with the default qconfig of ``backend``, calibrated on ``sample_loader``
    and finally converted.

    Args:
        model: Float model to quantize.
        sample_loader: Iterable of ``(data, target)`` pairs used for
            calibration; only ``data`` is fed to the model.
        device: Device the calibration runs on.
        backend: Quantization backend name. The default ``'fbgemm'`` targets
            x86 CPUs; use ``'qnnpack'`` when running on ARM.
        fuse_bn: When true, Conv2d/BatchNorm2d pairs are fused first.

    Returns:
        The converted ``nn.Sequential(QuantStub, model_copy, DeQuantStub)``.
    """
    quantized = deepcopy(model)
    quantized.eval()

    # Fusion happens on the eval-mode copy, before the stubs are added.
    if fuse_bn:
        fuse_all_conv_bn(quantized)

    # Quant/dequant stubs mark where tensors enter and leave the quantized domain.
    entry_stub = torch.quantization.QuantStub()
    exit_stub = torch.quantization.DeQuantStub()
    quantized = nn.Sequential(entry_stub, quantized, exit_stub)

    # Attach the backend's default qconfig and insert observers in place.
    # NOTE(review): only the qconfig is selected from `backend` here; the
    # global quantized engine is not changed by this function.
    quantized.qconfig = torch.quantization.get_default_qconfig(backend)
    torch.quantization.prepare(quantized, inplace=True)

    # Calibration pass: run the sample data through the observed model.
    quantized.to(device)
    quantized.eval()
    with torch.no_grad():
        for images, _labels in sample_loader:
            quantized(images.to(device))

    # Replace the observed modules by their quantized counterparts, in place.
    torch.quantization.convert(quantized, inplace=True)
    return quantized

In [14]:
# Static post-training quantization without Conv/BN fusion (fuse_bn=False).
# NOTE(review): calibration uses test_loader, the same data the report below
# is evaluated on -- consider a separate calibration split.
model_quantized_static_int8 = ptq(baseline, sample_loader=test_loader, device=device, backend='fbgemm', fuse_bn=False)
create_report(model_quantized_static_int8, test_loader, device)



Top 1 err: 0.3264999985694885
Top 5 err: 0.1226000189781189
Time per image: 10.191714763641357 (ms)
Model size: 39.726023 (MB)
---------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                             Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls  
---------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                      aten::empty         0.94%     118.000us         0.94%     118.000us       2.269us       1.13 Mb       1.13 Mb            52  
    aten::_empty_affine_quantized         1.43%     179.000us         1.43%     179.000us       4.366us     625.50 Kb     625.50 Kb            41  
          quantized::batch_norm2d         2.93%     367.000us         5.04%     632.000us      39.500us     297.50 Kb     -41.50 Kb  

In [10]:
# Static post-training quantization with Conv/BN fusion enabled (fuse_bn=True).
# NOTE(review): calibration uses test_loader, the same data the report below
# is evaluated on -- consider a separate calibration split.
model_quantized_static_fuse_int8 = ptq(baseline, sample_loader=test_loader, device=device, backend='fbgemm', fuse_bn=True)
create_report(model_quantized_static_fuse_int8, test_loader, device)
# torch.save(model_quantized_static_fuse_int8.state_dict(), 'vgg19_quantized_static_fuse_int8.pth')



Top 1 err: 0.32440000772476196
Top 5 err: 0.12139999866485596
Time per image: 4.995114207267761 (ms)
Model size: 39.609398999999996 (MB)
---------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                             Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls  
---------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                      aten::empty         0.86%      67.000us         0.86%      67.000us       3.350us       1.19 Mb       1.19 Mb            20  
    aten::_empty_affine_quantized         2.15%     168.000us         2.15%     168.000us       6.720us     329.50 Kb     329.50 Kb            25  
                quantized::conv2d        61.74%       4.824ms        73.23%       5.722ms     357.625us     296.00 Kb      

## Prune

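### PyTorch pruning
- PyTorch's built-in pruning is still at the research stage: it only zeroes weights through masks, so it does not yet provide practical benefits (the model size and inference time in the reports below are essentially unchanged).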

In [16]:
def measure_module_sparsity(module, weight=True, bias=False, use_mask=False):
    """Count zero entries in a module's weights and/or biases.

    Args:
        module: Module to inspect. Its tensors are enumerated through
            ``named_buffers()`` / ``named_parameters()``.
        weight: Include tensors whose name contains ``"weight"``
            (``"weight_mask"`` when ``use_mask`` is set).
        bias: Include tensors whose name contains ``"bias"``
            (``"bias_mask"`` when ``use_mask`` is set).
        use_mask: Count zeros in the pruning-mask buffers instead of in the
            parameters themselves.

    Returns:
        Tuple ``(num_zeros, num_elements, sparsity)``. ``sparsity`` is
        ``num_zeros / num_elements``, or ``0.0`` when no tensor matched
        (e.g. ``use_mask=True`` on a module without mask buffers).
    """
    if use_mask:
        named_tensors = module.named_buffers()
        weight_key, bias_key = "weight_mask", "bias_mask"
    else:
        named_tensors = module.named_parameters()
        weight_key, bias_key = "weight", "bias"

    # Names are matched by substring, so e.g. "weight_orig" counts as a weight.
    wanted_keys = []
    if weight:
        wanted_keys.append(weight_key)
    if bias:
        wanted_keys.append(bias_key)

    num_zeros = 0
    num_elements = 0
    for tensor_name, tensor in named_tensors:
        for key in wanted_keys:
            if key in tensor_name:
                num_zeros += torch.sum(tensor == 0).item()
                num_elements += tensor.nelement()

    # Nothing matched: report zero sparsity instead of dividing by zero.
    sparsity = num_zeros / num_elements if num_elements else 0.0

    return num_zeros, num_elements, sparsity


def measure_global_sparsity(model,
                            weight=True,
                            bias=False,
                            conv2d_use_mask=False,
                            linear_use_mask=False):
    """Sparsity over all Conv2d and Linear modules of ``model``.

    Args:
        model: Model whose ``named_modules()`` are scanned.
        weight: Forwarded to ``measure_module_sparsity`` (count weights).
        bias: Forwarded to ``measure_module_sparsity`` (count biases).
        conv2d_use_mask: Use pruning masks for Conv2d modules.
        linear_use_mask: Use pruning masks for Linear modules.

    Returns:
        Total zero count divided by total element count across the matching
        modules, or ``0.0`` when nothing was counted (e.g. the model has no
        Conv2d/Linear modules).
    """
    num_zeros = 0
    num_elements = 0

    for _, module in model.named_modules():
        # Only Conv2d and Linear contribute; each has its own mask switch.
        if isinstance(module, torch.nn.Conv2d):
            use_mask = conv2d_use_mask
        elif isinstance(module, torch.nn.Linear):
            use_mask = linear_use_mask
        else:
            continue

        module_num_zeros, module_num_elements, _ = measure_module_sparsity(
            module, weight=weight, bias=bias, use_mask=use_mask)
        num_zeros += module_num_zeros
        num_elements += module_num_elements

    # Avoid ZeroDivisionError when no element was counted.
    if num_elements == 0:
        return 0.0

    return num_zeros / num_elements

In [17]:
import torch.nn.utils.prune as prune


def remove_parameters(model, bias=False):
    """Make pruning permanent on every Conv2d / Linear module of ``model``.

    For each such module, ``prune.remove`` is called for ``"weight"`` (and
    for ``"bias"`` when ``bias`` is true), but only for parameters that are
    actually pruned on that module. Modules pruned on just one of the two
    parameters are therefore handled without ``prune.remove`` raising.

    Args:
        model: Model whose ``named_modules()`` are scanned; modified in place.
        bias: Also remove the pruning re-parametrization of ``"bias"``.
    """
    param_names = ("weight", "bias") if bias else ("weight",)

    for _, module in model.named_modules():
        if not isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
            continue
        for param_name in param_names:
            # Pruning registers a "<name>_mask" buffer; its absence means
            # this parameter was never pruned and there is nothing to remove.
            if prune.is_pruned(module) and hasattr(module, param_name + "_mask"):
                prune.remove(module, param_name)


def prune_model(model, grouped_pruning, conv2d_prune_amount=0, linear_prune_amount=0):
    """Return an L1-unstructured-pruned copy of ``model`` and print its sparsity.

    Args:
        model: Model to prune; it is deep-copied and left untouched.
        grouped_pruning: When true, all Conv2d weights are pruned together
            with ``prune.global_unstructured`` (Linear layers are not pruned
            in this mode). Otherwise each Conv2d / Linear weight is pruned
            on its own.
        conv2d_prune_amount: Amount passed to the pruning call for Conv2d.
        linear_prune_amount: Amount passed to the pruning call for Linear
            (only used when ``grouped_pruning`` is false).

    Returns:
        The pruned copy, with the pruning re-parametrization of the weights
        already removed.
    """
    pruned = deepcopy(model)

    if grouped_pruning:
        # One global magnitude ranking across every Conv2d weight tensor.
        conv_targets = [
            (layer, "weight")
            for layer in pruned.modules()
            if isinstance(layer, torch.nn.Conv2d)
        ]
        prune.global_unstructured(
            conv_targets,
            pruning_method=prune.L1Unstructured,
            amount=conv2d_prune_amount,
        )
    else:
        # Layer-wise pruning: every layer is ranked against itself only.
        for layer in pruned.modules():
            if isinstance(layer, torch.nn.Conv2d):
                layer_amount = conv2d_prune_amount
            elif isinstance(layer, torch.nn.Linear):
                layer_amount = linear_prune_amount
            else:
                continue
            prune.l1_unstructured(layer, name="weight", amount=layer_amount)

    # Bake the masks into the weights, then report the resulting sparsity.
    remove_parameters(pruned)
    print("Sparsity:", measure_global_sparsity(pruned))
    return pruned

In [18]:
# Global (grouped) L1 pruning over all Conv2d weights with amount 0.6;
# Linear layers are not pruned in this mode.
model_pruned_group = prune_model(baseline, grouped_pruning=True, conv2d_prune_amount=0.6)
create_report(model_pruned_group, test_loader, device)

Sparsity: 0.3056096087489639




Top 1 err: 0.3636000156402588
Top 5 err: 0.14230000972747803
Time per image: 22.443213438987733 (ms)
Model size: 157.384365 (MB)
---------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                             Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls  
---------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                      aten::empty         1.24%     591.000us         1.24%     591.000us       4.617us       5.48 Mb       5.48 Mb           128  
                     aten::conv2d         0.36%     174.000us        74.92%      35.719ms       2.232ms       1.16 Mb           0 b            16  
                aten::convolution         1.07%     511.000us        74.55%      35.545ms       2.222ms       1.16 Mb           0 b

In [19]:
# Layer-wise L1 pruning: amount 0.6 on each Conv2d weight, amount 0 on Linear weights.
model_pruned = prune_model(baseline, grouped_pruning=False, conv2d_prune_amount=0.6, linear_prune_amount=0)
create_report(model_pruned, test_loader, device)

Sparsity: 0.3056096596358615
Top 1 err: 0.4927999973297119
Top 5 err: 0.23309999704360962
Time per image: 21.303886651992798 (ms)
Model size: 157.384365 (MB)
---------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                             Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls  
---------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                      aten::empty         5.48%       3.147ms         5.48%       3.147ms      24.586us       5.47 Mb       5.47 Mb           128  
                     aten::conv2d         0.25%     144.000us        79.73%      45.805ms       2.863ms       1.16 Mb           0 b            16  
                aten::convolution         0.80%     458.000us        79.48%      45.661ms       2.854m

## ONNX

In [21]:
!pip install onnx
!pip install onnxruntime

Collecting onnx
  Downloading onnx-1.14.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (14.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m14.6/14.6 MB[0m [31m32.1 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: onnx
Successfully installed onnx-1.14.0
Collecting onnxruntime
  Downloading onnxruntime-1.15.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (5.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.9/5.9 MB[0m [31m14.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting coloredlogs (from onnxruntime)
  Downloading coloredlogs-15.0.1-py2.py3-none-any.whl (46 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m46.0/46.0 kB[0m [31m5.1 MB/s[0m eta [36m0:00:00[0m
Collecting humanfriendly>=9.1 (from coloredlogs->onnxruntime)
  Downloading humanfriendly-10.0-py2.py3-none-any.whl (86 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m86.8/86.8 kB[0m [31m10.5 MB/s[0m 

In [22]:
import onnx
import onnxruntime

from benchmark import evaluate_onnx

In [28]:
def export_onnx(model, save_path, device='cpu', verbose=False):
    """Export ``model`` to an ONNX file and validate the written file.

    The model is moved to ``device`` and switched to eval mode (both happen
    on the passed-in object), then traced with a random ``(1, 3, 32, 32)``
    input. The batch dimension of input and output is exported as dynamic.

    Args:
        model: Model to export.
        save_path: Destination of the ONNX file.
        device: Device used for the model and the sample input.
        verbose: Forwarded to ``torch.onnx.export``.
    """
    model.to(device)
    model.eval()
    sample = torch.randn(1, 3, 32, 32, requires_grad=True, device=device)

    export_options = dict(
        opset_version=15,            # ONNX opset to target
        do_constant_folding=True,    # fold constants during export
        input_names=['input'],
        output_names=['output'],
        # Axis 0 of both tensors is a variable-length batch axis.
        dynamic_axes={
            'input': {0: 'batch_size'},
            'output': {0: 'batch_size'},
        },
        verbose=verbose,
    )
    torch.onnx.export(model, sample, save_path, **export_options)

    # Reload the written file and run the ONNX checker on it.
    exported = onnx.load(save_path)
    onnx.checker.check_model(exported)

In [34]:
# Export the unmodified baseline; the file written here is loaded by the ONNX Runtime cell below.
export_onnx(baseline, save_path='vgg19_bn.onnx', device=device, verbose=True)

verbose: False, log level: Level.ERROR



In [23]:
# Load the exported ONNX file into an ONNX Runtime session and evaluate it
# on the same test loader, printing the same metrics as the PyTorch reports.
ort_session = onnxruntime.InferenceSession("vgg19_bn.onnx")
top1_err, top5_err, t = evaluate_onnx(ort_session, test_loader)
print('Top 1 err:', top1_err)
print('Top 5 err:', top5_err)
print(f'Time per image: {t} (ms)')



Top 1 err: 0.3230999708175659
Top 5 err: 0.11979997158050537
Time per image: 17.320813941955567 (ms)
