<a href="https://colab.research.google.com/github/wileyw/DeepLearningDemos/blob/master/Quantization/Quantization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import glob
from itertools import chain
import os
import random
import zipfile

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from tqdm.notebook import tqdm


print(f"Torch: {torch.__version__}")

# Training settings
batch_size = 64
epochs = 20
lr = 3e-5
gamma = 0.7
seed = 42

def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

seed_everything(seed)

device = torch.device('cuda')

# # Necessary to setup quantization
# qconfig = torch.quantization.get_default_qat_qconfig('qnnpack')
# torch.backends.quantized.engine = 'qnnpack'

def print_size_of_model(model, label=""):
    torch.save(model.state_dict(), "temp.p")
    size=os.path.getsize("temp.p")
    print("model: ",label,' \t','Size (KB):', size/1e3)
    os.remove('temp.p')
    return size

Torch: 1.10.0+cu111


In [2]:
!pip install tqdm requests regex sentencepiece

Collecting sentencepiece
  Downloading sentencepiece-0.1.96-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB)
[K     |████████████████████████████████| 1.2 MB 5.1 MB/s 
Installing collected packages: sentencepiece
Successfully installed sentencepiece-0.1.96


# Load Data & Augmentations

In [3]:
tsfm = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.repeat(3, 1, 1) )  # Grayscale to Color
])

fmnist_train_data = torchvision.datasets.FashionMNIST(root='/data', train=True, download=True, transform=tsfm)
fmnist_test_data = torchvision.datasets.FashionMNIST(root='/data', train=False, download=True, transform=tsfm)

train_loader = DataLoader(dataset = fmnist_train_data, batch_size = batch_size, shuffle = True)
test_loader = DataLoader(dataset = fmnist_test_data, batch_size = batch_size, shuffle = True)

print(len(fmnist_train_data), len(train_loader))
print(len(fmnist_test_data), len(test_loader))

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to /data/FashionMNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/26421880 [00:00<?, ?it/s]

Extracting /data/FashionMNIST/raw/train-images-idx3-ubyte.gz to /data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to /data/FashionMNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/29515 [00:00<?, ?it/s]

Extracting /data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to /data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to /data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/4422102 [00:00<?, ?it/s]

Extracting /data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to /data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to /data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/5148 [00:00<?, ?it/s]

Extracting /data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to /data/FashionMNIST/raw

60000 938
10000 157


# Load Model

In [None]:
"""
model = torchvision.models.resnet18(pretrained=True)
# or any of these variants
model = torchvision.models.resnet34(pretrained=True)
model = torchvision.models.resnet50(pretrained=True)
model = torchvision.models.resnet101(pretrained=True)
model = torchvision.models.resnet152(pretrained=True)
"""
model = torchvision.models.resnet18(pretrained=True)

model = model.train()  # Set model to training mode.
# model = model.eval()

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


  0%|          | 0.00/44.7M [00:00<?, ?B/s]

# Setup Quantization

Quantization Implementations:

*   Fine-tune as-is & post-training static quantization
*   Quantization aware training
*   [Dynamic quantization
](https://pytorch.org/tutorials/recipes/recipes/dynamic_quantization.html)


## 1. Dynamic Quantization

In [None]:
float_model = model

dq_model_fc = torch.quantization.quantize_dynamic(float_model, {torch.nn.Linear}, dtype=torch.qint8)
dq_model_conv2d = torch.quantization.quantize_dynamic(float_model, {torch.nn.Conv2d}, dtype=torch.qint8)
dq_model_bn = torch.quantization.quantize_dynamic(float_model, {torch.nn.BatchNorm2d}, dtype=torch.qint8)
dq_model = torch.quantization.quantize_dynamic(float_model, dtype=torch.qint8)

In [None]:
def print_size_of_model(model, label=""):
    torch.save(model.state_dict(), "temp.p")
    size=os.path.getsize("temp.p")
    print("model: ",label,' \t','Size (KB):', size/1e3)
    os.remove('temp.p')
    return size

f=print_size_of_model(float_model,"fp32")
print()

print('Quantize fc layer')
q1=print_size_of_model(dq_model_fc,"int8")
print("{0:.2f} times smaller".format(f/q1))
print()

print('Quantize convolution layer')
q2=print_size_of_model(dq_model_conv2d,"int8")
print("{0:.2f} times smaller".format(f/q2))
print()

print('Quantize batch norm layer')
q3=print_size_of_model(dq_model_bn,"int8")
print("{0:.2f} times smaller".format(f/q3))
print()

print('Quantize all layers')
q4=print_size_of_model(dq_model,"int8")
print("{0:.2f} times smaller".format(f/q4))
print()

model:  fp32  	 Size (KB): 46834.317

Quantize fc layer
model:  int8  	 Size (KB): 45299.145
1.03 times smaller

Quantize convolution layer
model:  int8  	 Size (KB): 46834.317
1.00 times smaller

Quantize batch norm layer
model:  int8  	 Size (KB): 46834.317
1.00 times smaller

Quantize all layers
model:  int8  	 Size (KB): 45299.145
1.03 times smaller



https://pytorch.org/docs/stable/quantization.html#torch.quantization.quantize_dynamic

"This is the simplest to apply form of quantization where the weights are quantized ahead of time but the activations are dynamically quantized during inference. This is used for situations where the model execution time is dominated by loading weights from memory rather than computing the matrix multiplications. This is true for for LSTM and Transformer type models with small batch size."

In [None]:
# Let's try with transformers
!pip install transformers

Collecting transformers
  Downloading transformers-4.15.0-py3-none-any.whl (3.4 MB)
[K     |████████████████████████████████| 3.4 MB 5.2 MB/s 
Collecting huggingface-hub<1.0,>=0.1.0
  Downloading huggingface_hub-0.4.0-py3-none-any.whl (67 kB)
[K     |████████████████████████████████| 67 kB 4.7 MB/s 
Collecting sacremoses
  Downloading sacremoses-0.0.47-py2.py3-none-any.whl (895 kB)
[K     |████████████████████████████████| 895 kB 35.6 MB/s 
Collecting tokenizers<0.11,>=0.10.1
  Downloading tokenizers-0.10.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (3.3 MB)
[K     |████████████████████████████████| 3.3 MB 47.7 MB/s 
[?25hCollecting pyyaml>=5.1
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
[K     |████████████████████████████████| 596 kB 47.6 MB/s 
Installing collected packages: pyyaml, tokenizers, sacremoses, huggingface-hub, transformers
  A

In [None]:
from transformers import BertModel

bert_model_float = BertModel.from_pretrained('bert-base-uncased')
bert_model_dq = torch.quantization.quantize_dynamic(bert_model_float, dtype=torch.qint8)

f=print_size_of_model(bert_model_float, "fp32")
q=print_size_of_model(bert_model_dq, "int8")
print("{0:.2f} times smaller".format(f/q))

Downloading:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/420M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


model:  fp32  	 Size (KB): 438007.537
model:  int8  	 Size (KB): 181490.125
2.41 times smaller


**TODO**
1. Finetune model to cat v. dog problem
2. Use training data as representative dataset for post-training static quantization on finetuned model & evaluate against finetuned model
3. Use training data to finetune Quantization Aware Training model

Create model for cat v. dog -> finetuning -> Post-training static quantization


 -> quantization aware training

 Note: https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html

# Pre 2 and 3: Create model for post-training static quantization and quantization aware training

## 2. Post-training Static Quantization


In [42]:
# Set up model to train.
# model_fp32 = torchvision.models.quantization.resnet18(pretrained=True)

# Copied from https://leimao.github.io/blog/PyTorch-Static-Quantization/
class QuantizedResNet18(nn.Module):
    def __init__(self, model_fp32):
        super(QuantizedResNet18, self).__init__()
        # QuantStub converts tensors from floating point to quantized.
        # This will only be used for inputs.
        self.quant = torch.quantization.QuantStub()
        # DeQuantStub converts tensors from quantized to floating point.
        # This will only be used for outputs.
        self.dequant = torch.quantization.DeQuantStub()
        # FP32 model
        self.model_fp32 = model_fp32

    def forward(self, x):
        # manually specify where tensors will be converted from floating
        # point to quantized in the quantized model
        x = self.quant(x)
        x = self.model_fp32(x)
        # manually specify where tensors will be converted from quantized
        # to floating point in the quantized model
        x = self.dequant(x)
        return x

model_fp32 = torchvision.models.resnet18(pretrained=True)
model_fp32.fc = nn.Linear(in_features=512, out_features=10, bias = True)
model_fp32 = QuantizedResNet18(model_fp32)
model_fp32.train()
params_to_update = model_fp32.parameters()

In [43]:
# Training settings
batch_size = 64
epochs = 1 # 10
lr = 1e-3
gamma = 0.7

In [44]:
# loss function
criterion = nn.BCEWithLogitsLoss()
# optimizer
optimizer = optim.Adam(params_to_update, lr=lr)
# scheduler
scheduler = StepLR(optimizer, step_size=1, gamma=gamma)

model_fp32 = model_fp32.to(device)

# Training loop.
for epoch in range(epochs):
    epoch_loss = 0
    epoch_accuracy = 0

    for data, label in tqdm(train_loader):
        label = torch.nn.functional.one_hot(label, 10).float()
        data = data.to(device)
        label = label.to(device)

        output = model_fp32(data)
        loss = criterion(output, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        acc = (output.argmax(dim=1) == label.argmax(dim=1)).float().mean()
        epoch_accuracy += acc / len(train_loader)
        epoch_loss += loss / len(train_loader)

    with torch.no_grad():
        epoch_val_accuracy = 0
        epoch_val_loss = 0
        for data, label in test_loader:
            data = data.to(device)
            label = torch.nn.functional.one_hot(label, 10).float()
            label = label.to(device)

            val_output = model_fp32(data)
            val_loss = criterion(val_output, label)

            acc = (val_output.argmax(dim=1) == label.argmax(dim=1)).float().mean()
            epoch_val_accuracy += acc / len(test_loader)
            epoch_val_loss += val_loss / len(test_loader)

    print(
        f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_accuracy:.4f}\n"
    )

  0%|          | 0/938 [00:00<?, ?it/s]

Epoch : 1 - loss : 0.0701 - acc: 0.8629 - val_loss : 0.0597 - val_acc: 0.8788



In [45]:
# set the qconfig for PTQ
ptq_qconfig = torch.quantization.get_default_qconfig('fbgemm')
# or, set the qconfig for QAT
qat_qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
# set the qengine to control weight packing
torch.backends.quantized.engine = 'fbgemm'

In [None]:
# Return model to cpu and setup quantization.
model_fp32 = model_fp32.to(torch.device('cpu'))
model_fp32.eval()
model_fp32.qconfig = torch.quantization.get_default_qconfig('fbgemm')
# model_fp32_quantizable = QuantizedResNet18(model_fp32)
# model_fp32_fused = torch.quantization.fuse_modules(model_fp32, [['conv1', 'bn1', 'relu']])
model_fp32_prepared = torch.quantization.prepare(model_fp32)

# Run representative dataset here.
input_fp32 = test_loader
for data, label in tqdm(input_fp32):
    model_fp32_prepared(data)

# Quantize.
model_int8 = torch.quantization.convert(model_fp32_prepared, inplace=False)

  reduce_range will be deprecated in a future release of PyTorch."


  0%|          | 0/157 [00:00<?, ?it/s]

In [47]:
# Calculate quantization space savings.
f=print_size_of_model(model_fp32, "fp32")
q=print_size_of_model(model_int8, "int8")
print("{0:.2f} times smaller".format(f/q))

model:  fp32  	 Size (KB): 44805.005
model:  int8  	 Size (KB): 11393.889
3.93 times smaller


In [49]:
# https://pytorch.org/tutorials/recipes/quantization.html

model_fp32 = model_fp32.to(torch.device('cpu'))
# quant = torch.quantization.QuantStub()
# dequant = torch.quantization.DeQuantStub()

with torch.no_grad():
    fp32_val_accuracy = 0
    fp32_val_loss = 0
    int8_val_accuracy = 0
    int8_val_loss = 0
    # for data, label in test_loader:
    for data, label in tqdm(input_fp32):
        label = torch.nn.functional.one_hot(label, 10).float()

        fp32_val_output = model_fp32(data)
        int8_val_output = model_int8(data)

        fp32_val_loss = criterion(fp32_val_output, label)
        int8_val_loss = criterion(int8_val_output, label)
        fp32_acc = (fp32_val_output.argmax(dim=1) == label.argmax(dim=1)).float().mean()
        int8_acc = (int8_val_output.argmax(dim=1) == label.argmax(dim=1)).float().mean()

        fp32_val_accuracy += fp32_acc / len(test_loader)
        fp32_val_loss += fp32_val_loss / len(test_loader)
        int8_val_accuracy += int8_acc / len(test_loader)
        int8_val_loss += int8_val_loss / len(test_loader)
print(f"FP32: {fp32_val_loss:.4f} - val_acc: {fp32_val_accuracy:.4f}\n")
print(f"INT8: {int8_val_loss:.4f} - val_acc: {int8_val_accuracy:.4f}\n")

  0%|          | 0/157 [00:00<?, ?it/s]

NotImplementedError: ignored

## 3. Quantization Aware Training


In [14]:
float_model = torchvision.models.quantization.resnet18(pretrained=True)
float_model.fc = nn.Linear(in_features=512, out_features=10, bias = True)
float_model.train()
params_to_update = float_model.parameters()

float_model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm') 
torch.backends.quantized.engine = 'fbgemm'
# TODO: When writing post, explore other qconfigs (fbgemm == server inference, qnnpack == mobile, what else?)
# Also talk about symmetic vs assymetric quantization, etc.
float_model_fused = torch.quantization.fuse_modules(float_model,
    [['conv1', 'bn1', 'relu']])
float_model_prepared = torch.quantization.prepare_qat(float_model_fused)
params_to_update = float_model_prepared.parameters()

  reduce_range will be deprecated in a future release of PyTorch."


In [15]:
# loss function
criterion = nn.BCEWithLogitsLoss()
# optimizer
optimizer = optim.Adam(params_to_update, lr=lr)
# scheduler
scheduler = StepLR(optimizer, step_size=1, gamma=gamma)

float_model_prepared = float_model_prepared.to(device)

for epoch in range(epochs):
    epoch_loss = 0
    epoch_accuracy = 0

    for data, label in tqdm(train_loader):
        # print(data)
        # label = label.reshape((-1, 1)).float()
        label = torch.nn.functional.one_hot(label, 10).float()
        data = data.to(device)
        label = label.to(device)

        output = float_model_prepared(data)
        # print(output.type())
        # print(label.type())
        # print(label)
        # print(data.shape, output.shape, label.shape)
        loss = criterion(output, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        acc = (output.argmax(dim=1) == label.argmax(dim=1)).float().mean()
        # acc = (output.argmax(dim=1) == label).float().mean()
        epoch_accuracy += acc / len(train_loader)
        epoch_loss += loss / len(train_loader)

    with torch.no_grad():
        epoch_val_accuracy = 0
        epoch_val_loss = 0
        # for data, label in valid_loader:
        for data, label in test_loader:
            data = data.to(device)
            label = torch.nn.functional.one_hot(label, 10).float()
            label = label.to(device)

            val_output = float_model_prepared(data)
            val_loss = criterion(val_output, label)
            # print(data.shape, val_output.shape, label.shape)

            acc = (val_output.argmax(dim=1) == label.argmax(dim=1)).float().mean()
            # acc = (val_output.argmax(dim=1) == label).float().mean()
            # epoch_val_accuracy += acc / len(valid_loader)
            # epoch_val_loss += val_loss / len(valid_loader)
            epoch_val_accuracy += acc / len(test_loader)
            epoch_val_loss += val_loss / len(test_loader)

    print(
        f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_accuracy:.4f}\n"
    )

  0%|          | 0/938 [00:00<?, ?it/s]

Epoch : 1 - loss : 0.0718 - acc: 0.8609 - val_loss : 0.0594 - val_acc: 0.8819



In [16]:
float_model_prepared = float_model_prepared.to('cpu')
float_model_prepared.eval()
model_int8 = torch.quantization.convert(float_model_prepared)

f=print_size_of_model(float_model_prepared,"fp32")
q=print_size_of_model(model_int8,"int8")
print("{0:.2f} times smaller".format(f/q))

model:  fp32  	 Size (KB): 45050.361
model:  int8  	 Size (KB): 11391.881
3.95 times smaller


In [17]:


with torch.no_grad():
    fp32_val_accuracy = 0
    fp32_val_loss = 0
    int8_val_accuracy = 0
    int8_val_loss = 0
    for data, label in test_loader:
        # data = data.to(device)
        label = torch.nn.functional.one_hot(label, 10).float()
        # label = label.to(device)
        # print(data.device)
        # print(label.device)
        # print(model_fp32)
        # print(model_int8)

        fp32_val_output = model_fp32(data)
        int8_val_output = model_int8(data)

        fp32_val_loss = criterion(fp32_val_output, label)
        int8_val_loss = criterion(int8_val_output, label)
        fp32_acc = (fp32_val_output.argmax(dim=1) == label.argmax(dim=1)).float().mean()
        int8_acc = (int8_val_output.argmax(dim=1) == label.argmax(dim=1)).float().mean()

        fp32_val_accuracy += fp32_acc / len(test_loader)
        fp32_val_loss += fp32_val_loss / len(test_loader)
        int8_val_accuracy += int8_acc / len(test_loader)
        int8_val_loss += int8_val_loss / len(test_loader)
print(f"FP32: {fp32_val_loss:.4f} - val_acc: {fp32_val_accuracy:.4f}\n")
print(f"INT8: {int8_val_loss:.4f} - val_acc: {int8_val_accuracy:.4f}\n")

FP32: 0.0492 - val_acc: 0.8896

INT8: 0.0594 - val_acc: 0.8898

