In [1]:
import datetime
import os
import sys
import time
import collections

import torch
import torch.utils.data
from torch import nn

from tqdm import tqdm

import torchvision
from torchvision import transforms

from pytorch_quantization import nn as quant_nn
from pytorch_quantization import calib
from pytorch_quantization.tensor_quant import QuantDescriptor

from absl import logging
logging.set_verbosity(logging.FATAL)  # Disable logging as they are too noisy in notebook



In [2]:
# For simplicity, import train and eval functions from the train script from torchvision instead of copything them here
# Download torchvision from https://github.com/pytorch/vision
sys.path.append("/objdet/vision/references/classification/")
from train import evaluate, train_one_epoch, load_data

## Set default QuantDescriptor to use histogram based calibration for activation

In [3]:
quant_desc_input = QuantDescriptor(calib_method='histogram')
quant_nn.QuantConv2d.set_default_quant_desc_input(quant_desc_input)
quant_nn.QuantLinear.set_default_quant_desc_input(quant_desc_input)

## Initialize quantized modules

In [4]:
from pytorch_quantization import quant_modules
quant_modules.initialize()

## Create model with pretrained weight

In [7]:
from torchvision_cust.models.classification import resnet50 as r50

model = r50(pretrained=True, quantize=True,  progress=False)
model.cuda()

ResNet(
  (conv1): QuantConv2d(
    3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
    (_input_quantizer): TensorQuantizer(8bit fake per-tensor amax=dynamic calibrator=HistogramCalibrator scale=1.0 quant)
    (_weight_quantizer): TensorQuantizer(8bit fake axis=0 amax=dynamic calibrator=MaxCalibrator scale=1.0 quant)
  )
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): QuantConv2d(
        64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False
        (_input_quantizer): TensorQuantizer(8bit fake per-tensor amax=dynamic calibrator=HistogramCalibrator scale=1.0 quant)
        (_weight_quantizer): TensorQuantizer(8bit fake axis=0 amax=dynamic calibrator=MaxCalibrator scale=1.0 quant)
      )
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=Tru

## Create data loader

In [8]:
data_path = "/objdet/imagenet/"
batch_size = 512

traindir = os.path.join(data_path, 'train')
valdir = os.path.join(data_path, 'val')
_args = collections.namedtuple('mock_args', ['model', 'distributed', 'cache_dataset', 'val_resize_size', 'val_crop_size', 'train_crop_size', 'interpolation', 'prototype'])
dataset, dataset_test, train_sampler, test_sampler = load_data(traindir, valdir, _args(model='resnet50', distributed=False, cache_dataset=False, val_resize_size=256, val_crop_size=224, train_crop_size=224, interpolation='bilinear', prototype=None))

data_loader = torch.utils.data.DataLoader(
    dataset, batch_size=batch_size,
    sampler=train_sampler, num_workers=4, pin_memory=True)

data_loader_test = torch.utils.data.DataLoader(
    dataset_test, batch_size=batch_size,
    sampler=test_sampler, num_workers=4, pin_memory=True)


Loading data
Loading training data
Took 7.04069447517395
Loading validation data
Creating data loaders


## Calibrate the model

In [9]:
def collect_stats(model, data_loader, num_batches):
    """Feed data to the network and collect statistic"""

    # Enable calibrators
    for name, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                module.disable_quant()
                module.enable_calib()
            else:
                module.disable()

    for i, (image, _) in tqdm(enumerate(data_loader), total=num_batches):
        model(image.cuda())
        if i >= num_batches:
            break

    # Disable calibrators
    for name, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                module.enable_quant()
                module.disable_calib()
            else:
                module.enable()
            
def compute_amax(model, **kwargs):
    # Load calib result
    for name, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                if isinstance(module._calibrator, calib.MaxCalibrator):
                    module.load_calib_amax()
                else:
                    module.load_calib_amax(**kwargs)
#             print(F"{name:40}: {module}")
    model.cuda()

In [10]:
# It is a bit slow since we collect histograms on CPU
with torch.no_grad():
    collect_stats(model, data_loader, num_batches=2)
    compute_amax(model, method="percentile", percentile=99.99)

100%|██████████| 2/2 [14:45<00:00, 442.86s/it]


## Now evaluate the calibrated model

In [11]:
criterion = nn.CrossEntropyLoss()
with torch.no_grad():
    evaluate(model, criterion, data_loader_test, device="cuda", print_freq=20)
    
# Save the model
torch.save(model.state_dict(), "./quant_resnet50-calibrated-opt.pth")

Test:   [ 0/98]  eta: 0:11:42  loss: 0.5699 (0.5699)  acc1: 85.5469 (85.5469)  acc5: 96.0938 (96.0938)  time: 7.1673  data: 5.3114  max mem: 7157
Test:   [20/98]  eta: 0:02:44  loss: 0.6679 (0.6834)  acc1: 82.8125 (82.4870)  acc5: 95.8984 (95.7589)  time: 1.8545  data: 0.0003  max mem: 7159
Test:   [40/98]  eta: 0:01:55  loss: 0.7024 (0.7166)  acc1: 79.8828 (81.3786)  acc5: 96.0938 (95.7174)  time: 1.8716  data: 0.0003  max mem: 7159
Test:   [60/98]  eta: 0:01:14  loss: 1.1042 (0.8603)  acc1: 71.8750 (78.2691)  acc5: 90.6250 (94.1182)  time: 1.8724  data: 0.0003  max mem: 7159
Test:   [80/98]  eta: 0:00:34  loss: 1.1309 (0.9388)  acc1: 72.2656 (76.7024)  acc5: 90.0391 (93.1472)  time: 1.8541  data: 0.0002  max mem: 7159
Test:  Total time: 0:03:05
Test:  Acc@1 76.120 Acc@5 92.930


## We can also try different calibrations and see which one works the best

In [None]:
with torch.no_grad():
    compute_amax(model, method="percentile", percentile=99.9)
    evaluate(model, criterion, data_loader_test, device="cuda", print_freq=20)

In [None]:
with torch.no_grad():
    for method in ["mse", "entropy"]:
        print(F"{method} calibration")
        compute_amax(model, method=method)
        evaluate(model, criterion, data_loader_test, device="cuda", print_freq=20)