In [1]:
import datetime
import os
import sys
import time
import collections

import torch
import torch.utils.data
from torch import nn

from tqdm import tqdm

import torchvision
from torchvision import transforms

from pytorch_quantization import nn as quant_nn
from pytorch_quantization import calib
from pytorch_quantization.tensor_quant import QuantDescriptor

from absl import logging
logging.set_verbosity(logging.FATAL)  # Disable logging as they are too noisy in notebook



In [2]:
# For simplicity, import train and eval functions from the train script from torchvision instead of copything them here
# Download torchvision from https://github.com/pytorch/vision
sys.path.append("/objdet/vision/references/classification/")
from train import evaluate, train_one_epoch, load_data

## Set default QuantDescriptor to use histogram based calibration for activation

In [3]:
quant_desc_input = QuantDescriptor(calib_method='histogram')
quant_nn.QuantConv2d.set_default_quant_desc_input(quant_desc_input)
quant_nn.QuantLinear.set_default_quant_desc_input(quant_desc_input)

## Initialize quantized modules

In [4]:
from pytorch_quantization import quant_modules
quant_modules.initialize()

## Create model with pretrained weight

In [5]:
model = torchvision.models.resnet50(pretrained=True, progress=False)
model.cuda()

ResNet(
  (conv1): QuantConv2d(
    3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
    (_input_quantizer): TensorQuantizer(8bit fake per-tensor amax=dynamic calibrator=HistogramCalibrator scale=1.0 quant)
    (_weight_quantizer): TensorQuantizer(8bit fake axis=0 amax=dynamic calibrator=MaxCalibrator scale=1.0 quant)
  )
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): QuantConv2d(
        64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False
        (_input_quantizer): TensorQuantizer(8bit fake per-tensor amax=dynamic calibrator=HistogramCalibrator scale=1.0 quant)
        (_weight_quantizer): TensorQuantizer(8bit fake axis=0 amax=dynamic calibrator=MaxCalibrator scale=1.0 quant)
      )
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=Tru

## Create data loader

In [6]:
data_path = "/objdet/imagenet/"
batch_size = 512

traindir = os.path.join(data_path, 'train')
valdir = os.path.join(data_path, 'val')
_args = collections.namedtuple('mock_args', ['model', 'distributed', 'cache_dataset', 'val_resize_size', 'val_crop_size', 'train_crop_size', 'interpolation', 'prototype'])
dataset, dataset_test, train_sampler, test_sampler = load_data(traindir, valdir, _args(model='resnet50', distributed=False, cache_dataset=False, val_resize_size=256, val_crop_size=224, train_crop_size=224, interpolation='bilinear', prototype=None))

data_loader = torch.utils.data.DataLoader(
    dataset, batch_size=batch_size,
    sampler=train_sampler, num_workers=4, pin_memory=True)

data_loader_test = torch.utils.data.DataLoader(
    dataset_test, batch_size=batch_size,
    sampler=test_sampler, num_workers=4, pin_memory=True)


Loading data
Loading training data
Took 5.052899122238159
Loading validation data
Creating data loaders


## Calibrate the model

In [7]:
def collect_stats(model, data_loader, num_batches):
    """Feed data to the network and collect statistic"""

    # Enable calibrators
    for name, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                module.disable_quant()
                module.enable_calib()
            else:
                module.disable()

    for i, (image, _) in tqdm(enumerate(data_loader), total=num_batches):
        model(image.cuda())
        if i >= num_batches:
            break

    # Disable calibrators
    for name, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                module.enable_quant()
                module.disable_calib()
            else:
                module.enable()
            
def compute_amax(model, **kwargs):
    # Load calib result
    for name, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                if isinstance(module._calibrator, calib.MaxCalibrator):
                    module.load_calib_amax()
                else:
                    module.load_calib_amax(**kwargs)
#             print(F"{name:40}: {module}")
    model.cuda()

In [8]:
# It is a bit slow since we collect histograms on CPU
with torch.no_grad():
    collect_stats(model, data_loader, num_batches=2)
    compute_amax(model, method="percentile", percentile=99.99)

  img = torch.as_tensor(np.asarray(pic))
  img = torch.as_tensor(np.asarray(pic))
  img = torch.as_tensor(np.asarray(pic))
  img = torch.as_tensor(np.asarray(pic))
  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
100%|██████████| 2/2 [09:21<00:00, 280.84s/it]


## Now evaluate the calibrated model

In [9]:
criterion = nn.CrossEntropyLoss()
with torch.no_grad():
    evaluate(model, criterion, data_loader_test, device="cuda", print_freq=20)
    
# Save the model
torch.save(model.state_dict(), "/tmp/quant_resnet50-calibrated.pth")

  img = torch.as_tensor(np.asarray(pic))
  img = torch.as_tensor(np.asarray(pic))
  img = torch.as_tensor(np.asarray(pic))
  img = torch.as_tensor(np.asarray(pic))


Test:   [ 0/98]  eta: 0:11:33  loss: 0.5672 (0.5672)  acc1: 85.7422 (85.7422)  acc5: 96.2891 (96.2891)  time: 7.0780  data: 5.5019  max mem: 5880
Test:   [20/98]  eta: 0:02:21  loss: 0.6620 (0.6810)  acc1: 83.2031 (82.6544)  acc5: 95.8984 (95.8705)  time: 1.5538  data: 0.0003  max mem: 5882
Test:   [40/98]  eta: 0:01:37  loss: 0.7045 (0.7154)  acc1: 80.2734 (81.4977)  acc5: 95.8984 (95.7746)  time: 1.5526  data: 0.0003  max mem: 5882
Test:   [60/98]  eta: 0:01:02  loss: 1.1120 (0.8594)  acc1: 71.6797 (78.3523)  acc5: 90.2344 (94.1310)  time: 1.5172  data: 0.0003  max mem: 5882
Test:   [80/98]  eta: 0:00:28  loss: 1.1234 (0.9377)  acc1: 72.6562 (76.7626)  acc5: 89.8438 (93.1448)  time: 1.4737  data: 0.0003  max mem: 5882
Test:  Total time: 0:02:33
Test:  Acc@1 76.194 Acc@5 92.916


## We can also try different calibrations and see which one works the best

In [10]:
with torch.no_grad():
    compute_amax(model, method="percentile", percentile=99.9)
    evaluate(model, criterion, data_loader_test, device="cuda", print_freq=20)

  img = torch.as_tensor(np.asarray(pic))
  img = torch.as_tensor(np.asarray(pic))
  img = torch.as_tensor(np.asarray(pic))
  img = torch.as_tensor(np.asarray(pic))


Test:   [ 0/98]  eta: 0:10:38  loss: 0.6082 (0.6082)  acc1: 85.1562 (85.1562)  acc5: 95.3125 (95.3125)  time: 6.5149  data: 5.0029  max mem: 5882
Test:   [20/98]  eta: 0:02:14  loss: 0.6760 (0.7042)  acc1: 81.8359 (81.8824)  acc5: 95.7031 (95.5171)  time: 1.4811  data: 0.0003  max mem: 5882
Test:   [40/98]  eta: 0:01:33  loss: 0.7239 (0.7349)  acc1: 79.1016 (80.8117)  acc5: 95.8984 (95.5316)  time: 1.4828  data: 0.0003  max mem: 5882
Test:   [60/98]  eta: 0:00:59  loss: 1.1130 (0.8802)  acc1: 71.4844 (77.6447)  acc5: 91.0156 (93.8653)  time: 1.4832  data: 0.0002  max mem: 5882
Test:   [80/98]  eta: 0:00:27  loss: 1.1479 (0.9612)  acc1: 71.2891 (76.0272)  acc5: 89.6484 (92.7686)  time: 1.4967  data: 0.0002  max mem: 5882
Test:  Total time: 0:02:30
Test:  Acc@1 75.452 Acc@5 92.530


In [11]:
with torch.no_grad():
    for method in ["mse", "entropy"]:
        print(F"{method} calibration")
        compute_amax(model, method=method)
        evaluate(model, criterion, data_loader_test, device="cuda", print_freq=20)

mse calibration


  img = torch.as_tensor(np.asarray(pic))
  img = torch.as_tensor(np.asarray(pic))
  img = torch.as_tensor(np.asarray(pic))
  img = torch.as_tensor(np.asarray(pic))


Test:   [ 0/98]  eta: 0:09:57  loss: 0.5587 (0.5587)  acc1: 85.3516 (85.3516)  acc5: 96.0938 (96.0938)  time: 6.0965  data: 4.7183  max mem: 5882
Test:   [20/98]  eta: 0:02:05  loss: 0.6701 (0.6823)  acc1: 82.2266 (82.4870)  acc5: 96.0938 (95.8333)  time: 1.3899  data: 0.0003  max mem: 5882
Test:   [40/98]  eta: 0:01:27  loss: 0.7019 (0.7156)  acc1: 79.8828 (81.3881)  acc5: 96.0938 (95.7555)  time: 1.3903  data: 0.0003  max mem: 5882
Test:   [60/98]  eta: 0:00:55  loss: 1.1151 (0.8587)  acc1: 72.2656 (78.2947)  acc5: 90.6250 (94.0926)  time: 1.3931  data: 0.0003  max mem: 5882
Test:   [80/98]  eta: 0:00:26  loss: 1.1234 (0.9370)  acc1: 72.2656 (76.7048)  acc5: 89.6484 (93.1496)  time: 1.3967  data: 0.0003  max mem: 5882
Test:  Total time: 0:02:20
Test:  Acc@1 76.150 Acc@5 92.926
entropy calibration


  img = torch.as_tensor(np.asarray(pic))
  img = torch.as_tensor(np.asarray(pic))
  img = torch.as_tensor(np.asarray(pic))
  img = torch.as_tensor(np.asarray(pic))


Test:   [ 0/98]  eta: 0:09:58  loss: 0.5608 (0.5608)  acc1: 84.9609 (84.9609)  acc5: 96.0938 (96.0938)  time: 6.1028  data: 4.7199  max mem: 5882
Test:   [20/98]  eta: 0:02:05  loss: 0.6716 (0.6821)  acc1: 82.6172 (82.5056)  acc5: 96.0938 (95.8426)  time: 1.3842  data: 0.0002  max mem: 5882
Test:   [40/98]  eta: 0:01:26  loss: 0.7012 (0.7154)  acc1: 80.0781 (81.4167)  acc5: 95.8984 (95.7317)  time: 1.3851  data: 0.0002  max mem: 5882
Test:   [60/98]  eta: 0:00:55  loss: 1.1127 (0.8588)  acc1: 72.8516 (78.3715)  acc5: 90.8203 (94.0798)  time: 1.3931  data: 0.0003  max mem: 5882
Test:   [80/98]  eta: 0:00:26  loss: 1.1266 (0.9372)  acc1: 72.2656 (76.7795)  acc5: 89.6484 (93.1544)  time: 1.3924  data: 0.0003  max mem: 5882
Test:  Total time: 0:02:20
Test:  Acc@1 76.204 Acc@5 92.942
