<a href="https://colab.research.google.com/github/xuwangfmc/dlbook/blob/main/modelcompression/WeightQuantization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

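# Weight Quantization

Weight quantization is a technique for performing computations and storing tensors at bit widths lower than floating-point precision. A quantized model performs some or all of its tensor operations with integers instead of floating-point values. This allows a more compact model representation and the use of high-performance vectorized operations on many hardware platforms. Compared with a typical FP32 model, PyTorch's INT8 quantization reduces model size by 4x and memory-bandwidth requirements by 4x. Hardware support for INT8 computation is typically 2 to 4 times faster than FP32 computation. Quantization is primarily a technique for speeding up inference, and quantized operators support only the forward pass. With the quantization tooling currently bundled with PyTorch, everything except quantization aware training is supported only on x86 CPUs with AVX2 or later and on ARM CPUs.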
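To make the 4x size figure and the integer representation concrete, here is a minimal pure-Python sketch of affine (scale and zero-point) quantization. The helper names `quantize` and `dequantize` are hypothetical and chosen for illustration; this is not PyTorch's implementation, only the arithmetic it is assumed to resemble.

```python
from array import array

def quantize(values, num_bits=8):
    """Affine quantization of a list of floats to unsigned integers (illustrative)."""
    qmin, qmax = 0, 2 ** num_bits - 1
    # Extend the range to include 0.0 so that zero is exactly representable.
    lo, hi = min(min(values), 0.0), max(max(values), 0.0)
    scale = (hi - lo) / (qmax - qmin) or 1.0  # fall back to 1.0 for an all-zero input
    zero_point = round(qmin - lo / scale)
    # Round to the nearest integer step, shift by the zero point, clamp to range.
    q = [max(qmin, min(qmax, round(v / scale) + zero_point)) for v in values]
    return q, scale, zero_point

def dequantize(q, scale, zero_point):
    """Map integer codes back to approximate float values."""
    return [(qi - zero_point) * scale for qi in q]

weights = [-1.0, -0.25, 0.0, 0.5, 1.0]
q, scale, zero_point = quantize(weights)
restored = dequantize(q, scale, zero_point)

# Storage cost per tensor: 4 bytes per float32 value versus 1 byte per uint8 code.
fp32_bytes = array("f", weights).itemsize * len(weights)
int8_bytes = array("B", q).itemsize * len(q)
print("codes:", q)
print("restored:", restored)
print("bytes fp32 / int8:", fp32_bytes, int8_bytes)
```

The restored values differ from the originals by at most about one quantization step (`scale`), which is the accuracy cost paid for the smaller representation.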

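PyTorch provides two quantization modes: Eager Mode Quantization and FX Graph Mode Quantization.

- Eager Mode Quantization is a beta feature. The user has to fuse modules manually and specify where quantization and dequantization happen, and it supports only modules, not functionals.

- FX Graph Mode Quantization is PyTorch's newer automated quantization framework and is currently a prototype feature. It improves on Eager Mode Quantization by adding support for functionals and by automating the quantization process. It does not work on arbitrary models, so users may need to refactor a model to make it compatible with FX Graph Mode Quantization.

This tutorial covers how to quantize a model with the libraries bundled with PyTorch, and how to convert a model from 32 bits to 16 or 8 bits in practice.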
   

## Eager Mode Quantization
Eager Mode Quantization supports three types of quantization:
- Dynamic Quantization
- Static Quantization
- Quantization Aware Training

 **Dynamic Quantization**
 
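This is the simplest form of quantization: weights are quantized ahead of time, while activations are quantized dynamically during inference. It is used when model execution time is dominated by loading weights from memory rather than by computing matrix multiplications, which is the case for LSTM and Transformer models run with small batch sizes.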
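The split between ahead-of-time weight quantization and run-time activation quantization can be sketched in plain Python. The class `DynamicQuantLinear` and the helper `quantize_symmetric` are hypothetical names, and the symmetric scheme is a simplifying assumption; PyTorch's own kernels may use different quantization schemes.

```python
def quantize_symmetric(values, num_bits=8):
    """Symmetric quantization: map [-max|v|, max|v|] onto signed integers."""
    qmax = 2 ** (num_bits - 1) - 1  # 127 for 8 bits
    scale = max(abs(v) for v in values) / qmax or 1.0  # 1.0 if all values are zero
    return [round(v / scale) for v in values], scale

class DynamicQuantLinear:
    """Toy single-output linear layer y = w . x + b (illustration only)."""

    def __init__(self, weight, bias):
        # Weights are quantized once, before any inference call.
        self.q_weight, self.w_scale = quantize_symmetric(weight)
        self.bias = bias

    def __call__(self, x):
        # The activation scale is derived from the actual input at call time.
        q_x, x_scale = quantize_symmetric(x)
        # Integer multiply-accumulate, then one rescale back to floating point.
        acc = sum(qw * qx for qw, qx in zip(self.q_weight, q_x))
        return acc * self.w_scale * x_scale + self.bias

weight, bias = [0.5, -0.25, 0.75], 0.1
layer = DynamicQuantLinear(weight, bias)
x = [1.0, 2.0, -1.0]
y_quant = layer(x)
y_float = sum(w * xi for w, xi in zip(weight, x)) + bias
print("quantized:", y_quant, "float:", y_float)
```

No calibration data is needed here because the activation range is measured on each input, which is what distinguishes the dynamic variant from the static one.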

In [1]:
import torch

# define a floating point model
class M(torch.nn.Module):
    def __init__(self):
        super(M, self).__init__()
        self.fc = torch.nn.Linear(4, 4)

    def forward(self, x):
        x = self.fc(x)
        return x

# create a model instance
model_fp32 = M()
# create a quantized model instance
model_int8 = torch.quantization.quantize_dynamic(
    model_fp32,  # the original model
    {torch.nn.Linear},  # a set of layers to dynamically quantize
    dtype=torch.qint8)  # the target dtype for quantized weights

# run the model
input_fp32 = torch.randn(4, 4, 4, 4)
res = model_int8(input_fp32)

**Static Quantization** 

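Static Quantization quantizes both the weights and the activations of a model. Where possible, it fuses activations into the preceding layer. It requires calibration with a representative dataset to determine the optimal quantization parameters for the activations. Static Quantization is also known as Post Training Quantization (PTQ) and is typically used when savings in both memory bandwidth and compute matter, with CNNs being a common use case.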
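The calibration step can be pictured as an observer that records the range of activations over representative batches and then derives fixed quantization parameters from it. The following is a minimal sketch with a hypothetical `ToyMinMaxObserver` class; it assumes a plain min/max rule and is not the observer implementation that PyTorch ships.

```python
class ToyMinMaxObserver:
    """Records the running min/max of the values it sees (illustration only)."""

    def __init__(self):
        self.lo = float("inf")
        self.hi = float("-inf")

    def observe(self, batch):
        # Widen the recorded range with each calibration batch.
        self.lo = min(self.lo, min(batch))
        self.hi = max(self.hi, max(batch))

    def qparams(self, qmin=0, qmax=255):
        # Include 0.0 in the range so that zero maps to an exact integer code.
        lo, hi = min(self.lo, 0.0), max(self.hi, 0.0)
        scale = (hi - lo) / (qmax - qmin) or 1.0
        zero_point = max(qmin, min(qmax, round(qmin - lo / scale)))
        return scale, zero_point

# Made-up activation batches standing in for a representative dataset.
calibration_batches = ([0.1, 0.4, 2.0], [-0.5, 0.3, 1.2], [0.0, 3.0, 0.7])

observer = ToyMinMaxObserver()
for batch in calibration_batches:
    observer.observe(batch)

scale, zero_point = observer.qparams()
print("range:", observer.lo, observer.hi)
print("scale:", scale, "zero_point:", zero_point)
```

Once calibration ends, `scale` and `zero_point` stay fixed for inference, so an unrepresentative calibration set leads to clipped or coarsely resolved activations.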

In [2]:
import torch

# define a floating point model where some layers could be statically quantized
class M(torch.nn.Module):
    def __init__(self):
        super(M, self).__init__()
        # QuantStub converts tensors from floating point to quantized
        self.quant = torch.quantization.QuantStub()
        self.conv = torch.nn.Conv2d(1, 1, 1)
        self.relu = torch.nn.ReLU()
        # DeQuantStub converts tensors from quantized to floating point
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        # manually specify where tensors will be converted from floating
        # point to quantized in the quantized model
        x = self.quant(x)
        x = self.conv(x)
        x = self.relu(x)
        # manually specify where tensors will be converted from quantized
        # to floating point in the quantized model
        x = self.dequant(x)
        return x

# create a model instance
model_fp32 = M()

# model must be set to eval mode for static quantization logic to work
model_fp32.eval()

# attach a global qconfig, which contains information about what kind
# of observers to attach. Use 'fbgemm' for server inference and
# 'qnnpack' for mobile inference. Other quantization configurations such
# as selecting symmetric or asymmetric quantization and MinMax or L2Norm
# calibration techniques can be specified here.
model_fp32.qconfig = torch.quantization.get_default_qconfig('fbgemm')

# Fuse the activations to preceding layers, where applicable.
# This needs to be done manually depending on the model architecture.
# Common fusions include `conv + relu` and `conv + batchnorm + relu`
model_fp32_fused = torch.quantization.fuse_modules(model_fp32, [['conv', 'relu']])

# Prepare the model for static quantization. This inserts observers in
# the model that will observe activation tensors during calibration.
model_fp32_prepared = torch.quantization.prepare(model_fp32_fused)

# calibrate the prepared model to determine quantization parameters for activations
# in a real world setting, the calibration would be done with a representative dataset
input_fp32 = torch.randn(4, 1, 4, 4)
model_fp32_prepared(input_fp32)

# Convert the observed model to a quantized model. This does several things:
# quantizes the weights, computes and stores the scale and zero-point to be
# used with each activation tensor, and replaces key operators with quantized
# implementations.
model_int8 = torch.quantization.convert(model_fp32_prepared)

# run the model, relevant calculations will happen in int8
res = model_int8(input_fp32)



**Quantization Aware Training** 

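Quantization Aware Training (QAT) models the effects of quantization during training and yields higher accuracy than the other quantization methods, static quantization included. During training, all computation is done in floating point, and fake_quant modules simulate the effect of INT8 quantization through clamping and rounding. After model conversion, weights and activations are quantized, and activations are fused into the preceding layer. It is commonly used with CNNs.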
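The clamp-and-round behaviour of a fake_quant module can be sketched as a quantize-then-dequantize round trip that stays in floating point. The function `fake_quantize` below is a hypothetical stand-in written for illustration, and the `scale` and `zero_point` values are made up.

```python
def fake_quantize(x, scale, zero_point, qmin=0, qmax=255):
    """Simulate INT8 quantization of one float while keeping a float result."""
    q = round(x / scale) + zero_point  # round to the nearest integer step
    q = max(qmin, min(qmax, q))        # clamp to the representable integer range
    return (q - zero_point) * scale    # map straight back to floating point

scale, zero_point = 0.1, 128
for x in (0.26, 1.0, 100.0, -100.0):
    print(x, "->", fake_quantize(x, scale, zero_point))
```

Values inside the representable range snap to the nearest multiple of `scale`, while values outside it saturate at the range limits. Training against these distorted values is what lets the network adapt its weights to quantization error.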

In [3]:
import torch
import torchvision
from torch.utils.data import DataLoader

# define a floating point model where some layers could benefit from QAT
class M(torch.nn.Module):
    def __init__(self):
        super(M, self).__init__()
        # QuantStub converts tensors from floating point to quantized
        self.quant = torch.quantization.QuantStub()
        # Each layer gets its own ReLU. fuse_modules replaces a fused ReLU
        # with Identity, so a ReLU shared between layers would silently
        # disappear from every other place it is used.
        self.conv1 = torch.nn.Conv2d(1, 32, 3, 1, 1)
        self.relu1 = torch.nn.ReLU()
        self.conv2 = torch.nn.Conv2d(32, 64, 3, 1, 1)
        self.relu2 = torch.nn.ReLU()
        self.conv3 = torch.nn.Conv2d(64, 64, 3, 1, 1)
        self.relu3 = torch.nn.ReLU()
        # MaxPool2d has no parameters and is not fused, so one instance is reused
        self.maxpool = torch.nn.MaxPool2d(2)

        self.dense1 = torch.nn.Linear(64 * 3 * 3, 128)
        self.relu4 = torch.nn.ReLU()
        self.dense2 = torch.nn.Linear(128, 10)
        # DeQuantStub converts tensors from quantized to floating point
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.maxpool(self.relu1(self.conv1(x)))
        x = self.maxpool(self.relu2(self.conv2(x)))
        x = self.maxpool(self.relu3(self.conv3(x)))
        # flatten to (batch, features) before the fully connected layers
        res = x.contiguous().view(x.size(0), -1)
        out = self.dense2(self.relu4(self.dense1(res)))
        out = self.dequant(out)
        return out

train_data = torchvision.datasets.MNIST(
    './mnist', train=True, transform=torchvision.transforms.ToTensor(), download=True
)
train_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=True)

def training_loop(model):
    optimizer = torch.optim.Adam(model.parameters())
    loss_func = torch.nn.CrossEntropyLoss()
    # a single epoch is enough for this demonstration
    for epoch in range(1):
        for batch_x, batch_y in train_loader:
            out = model(batch_x)
            loss = loss_func(out, batch_y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()


# create a model instance
model_fp32 = M()

# model must be set to train mode for QAT logic to work
model_fp32.train()

# attach a global qconfig, which contains information about what kind
# of observers to attach. Use 'fbgemm' for server inference and
# 'qnnpack' for mobile inference. Other quantization configurations such
# as selecting symmetric or asymmetric quantization and MinMax or L2Norm
# calibration techniques can be specified here.
model_fp32.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')

# fuse the activations to preceding layers, where applicable
# this needs to be done manually depending on the model architecture
model_fp32_fused = torch.quantization.fuse_modules(model_fp32,
    [['conv1', 'relu1'], ['conv2', 'relu2'], ['conv3', 'relu3'],
     ['dense1', 'relu4']])

# Prepare the model for QAT. This inserts observers and fake_quants in
# the model that will observe weight and activation tensors during training.
model_fp32_prepared = torch.quantization.prepare_qat(model_fp32_fused)

# run the training loop defined above
training_loop(model_fp32_prepared)

# Convert the observed model to a quantized model. This does several things:
# quantizes the weights, computes and stores the scale and zero-point to be
# used with each activation tensor, fuses modules where appropriate,
# and replaces key operators with quantized implementations.
model_fp32_prepared.eval()
model_int8 = torch.quantization.convert(model_fp32_prepared)
input_fp32 = torch.randn(100, 1, 28, 28)
# run the model, relevant calculations will happen in int8
res = model_int8(input_fp32)



## FX Graph Mode Quantization
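FX Graph Mode Quantization supports two families of quantization, Post Training Quantization and Quantization Aware Training:

- Post Training Quantization: Weight Only Quantization, Dynamic Quantization, Static Quantization
- Quantization Aware Training: Static Quantization

Note: the code below only illustrates the different ways to configure FX Graph Mode. It is not runnable as-is, because `UserModel(...)` is a placeholder and the calibration and training steps are omitted.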

```python
import torch
import torch.quantization.quantize_fx as quantize_fx
import copy

model_fp = UserModel(...)

#
# post training dynamic/weight_only quantization
#

# we need to deepcopy if we still want to keep model_fp unchanged after quantization since quantization apis change the input model
model_to_quantize = copy.deepcopy(model_fp)
model_to_quantize.eval()
qconfig_dict = {"": torch.quantization.default_dynamic_qconfig}
# prepare
model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_dict)
# no calibration needed when we only have dynamic/weight_only quantization
# quantize
model_quantized = quantize_fx.convert_fx(model_prepared)

#
# post training static quantization
#

model_to_quantize = copy.deepcopy(model_fp)
qconfig_dict = {"": torch.quantization.get_default_qconfig('qnnpack')}
model_to_quantize.eval()
# prepare
model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_dict)
# calibrate (not shown)
# quantize
model_quantized = quantize_fx.convert_fx(model_prepared)

#
# quantization aware training for static quantization
#

model_to_quantize = copy.deepcopy(model_fp)
qconfig_dict = {"": torch.quantization.get_default_qat_qconfig('qnnpack')}
model_to_quantize.train()
# prepare
model_prepared = quantize_fx.prepare_qat_fx(model_to_quantize, qconfig_dict)
# training loop (not shown)
# quantize
model_quantized = quantize_fx.convert_fx(model_prepared)

#
# fusion
#
model_to_quantize = copy.deepcopy(model_fp)
model_fused = quantize_fx.fuse_fx(model_to_quantize)
```

## Hands-on Example
This example quantizes a model from 32 bits to 16 bits and to 8 bits, and compares the resulting file sizes.
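Before looking at whole-model file sizes, it can help to see what the 16-bit step does to a single value. This small sketch uses only the standard library's `struct` module, whose `"e"` format is IEEE 754 half precision; the helper name `to_float16_and_back` is hypothetical.

```python
import struct

def to_float16_and_back(x):
    """Round-trip a Python float through IEEE 754 half precision."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

value = 0.1
half = to_float16_and_back(value)

# Bytes needed per value in single and half precision.
print("bytes float32 / float16:", struct.calcsize("<f"), struct.calcsize("<e"))
print("float16 approximation of", value, "is", half)
print("absolute error:", abs(half - value))
```

Halving the bytes per value is what roughly halves the file size, at the cost of a small rounding error per weight. Half precision also has a much narrower range than single precision (its largest finite value is 65504), which is rarely a concern for typical weight magnitudes.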

In [4]:
import os
import torch
import pickle
import numpy as np

# download the trained StudentNet parameters
!gdown --id '12wtIa0WVRcpboQzhgRUJOpcXe23tgWUL' --output student_custom_small.bin

def encode16(params, save_path):
    """Compress params to 16 bits and save them to save_path."""
    custom_dict = {}
    for name, param in params.items():
        param = np.float64(param.numpy())
        # some entries are plain scalars rather than ndarrays; leave them uncompressed
        if isinstance(param, np.ndarray):
            custom_dict[name] = np.float16(param)
        else:
            custom_dict[name] = param

    with open(save_path, 'wb') as f:
        pickle.dump(custom_dict, f)

def decode16(fname):
    """Load 16-bit weights and return them as a state_dict of torch tensors."""
    with open(fname, 'rb') as f:
        params = pickle.load(f)
    custom_dict = {}
    for name, param in params.items():
        custom_dict[name] = torch.tensor(param)
    return custom_dict

def encode8(params, save_path):
    """Compress params to 8 bits and save them to save_path."""
    custom_dict = {}
    for name, param in params.items():
        param = np.float64(param.numpy())
        if isinstance(param, np.ndarray):
            min_val = np.min(param)
            max_val = np.max(param)
            # guard against division by zero when every element is identical
            span = (max_val - min_val) or 1.0
            # map [min_val, max_val] linearly onto the integer codes 0..255
            param = np.round((param - min_val) / span * 255)
            param = np.uint8(param)
            custom_dict[name] = (min_val, max_val, param)
        else:
            custom_dict[name] = param

    with open(save_path, 'wb') as f:
        pickle.dump(custom_dict, f)

def decode8(fname):
    """Load 8-bit weights and return them as a state_dict of torch tensors."""
    with open(fname, 'rb') as f:
        params = pickle.load(f)
    custom_dict = {}
    for name, param in params.items():
        if isinstance(param, tuple):
            # undo the linear mapping applied in encode8
            min_val, max_val, param = param
            param = np.float64(param)
            param = (param / 255 * (max_val - min_val)) + min_val
            param = torch.tensor(param)
        else:
            param = torch.tensor(param)

        custom_dict[name] = param
    return custom_dict


if __name__ == '__main__':
    print(f"Original Cost: {os.stat('./student_custom_small.bin').st_size} Bytes.")
    old_params = torch.load('./student_custom_small.bin', map_location='cpu')
    encode16(old_params, './student_model_16bit.bin')
    print(f"16-bit Cost: {os.stat('./student_model_16bit.bin').st_size} Bytes.")
    encode8(old_params, './student_model_8bit.bin')
    print(f"8-bit Cost: {os.stat('./student_model_8bit.bin').st_size} Bytes.")

Original Cost: 1047430 Bytes.
16-bit Cost: 522954 Bytes.
8-bit Cost: 268467 Bytes.
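The sizes above show the storage saving but not the accuracy cost. The following pure-Python sketch mirrors the min/max scheme of `encode8` and `decode8` on a plain list and measures the round-trip error. The function names `encode8_list` and `decode8_list` and the sample values are hypothetical, and the guard for constant inputs is an added assumption.

```python
def encode8_list(values):
    """Map floats linearly from [min, max] onto byte codes 0..255."""
    lo, hi = min(values), max(values)
    span = (hi - lo) or 1.0  # avoid division by zero for constant inputs
    codes = bytes(round((v - lo) / span * 255) for v in values)
    return lo, hi, codes

def decode8_list(lo, hi, codes):
    """Invert encode8_list up to rounding error."""
    span = (hi - lo) or 1.0
    return [c / 255 * span + lo for c in codes]

values = [-0.8, -0.1, 0.0, 0.33, 1.2]
lo, hi, codes = encode8_list(values)
restored = decode8_list(lo, hi, codes)

max_error = max(abs(a - b) for a, b in zip(values, restored))
half_step = (hi - lo) / 255 / 2  # worst case: half of one quantization step
print("bytes stored:", len(codes))
print("max error:", max_error, "bound:", half_step)
```

Each value costs one byte plus the shared `lo` and `hi`, and the reconstruction error stays within half a quantization step, so tensors with a wide value range lose more precision than narrow ones.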
