In [1]:
import torch

# define a floating point model where some layers could benefit from QAT
class M(torch.nn.Module):
    """Tiny Conv2d -> BatchNorm2d -> ReLU network bracketed by quantization stubs.

    The submodules are registered under the names ``quant``, ``conv``, ``bn``,
    ``relu`` and ``dequant`` (in that order), so they can be addressed by name.
    """

    def __init__(self):
        super().__init__()
        nn, quantization = torch.nn, torch.quantization

        # Entry stub: converts tensors from floating point to quantized.
        self.quant = quantization.QuantStub()
        # Single-channel 1x1 convolution followed by batch norm and ReLU.
        self.conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=1)
        self.bn = nn.BatchNorm2d(num_features=1)
        self.relu = nn.ReLU()
        # Exit stub: converts tensors from quantized back to floating point.
        self.dequant = quantization.DeQuantStub()

    def forward(self, x):
        """Apply quant -> conv -> bn -> relu -> dequant to ``x`` and return the result."""
        quantized = self.quant(x)
        activated = self.relu(self.bn(self.conv(quantized)))
        return self.dequant(activated)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Instantiate the floating-point model from the class M defined earlier in this notebook.
model_fp32 = M()

# The model must be in train mode for the QAT logic to work. As the cell's last
# expression, the value returned by this call is what the notebook displays.
model_fp32.train()

M(
  (quant): QuantStub()
  (conv): Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1))
  (bn): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU()
  (dequant): DeQuantStub()
)

In [3]:
# Set the quantization config: attach the default QAT qconfig for the
# 'fbgemm' backend to the model.
model_fp32.qconfig = torch.quantization.get_default_qat_qconfig(backend='fbgemm')

In [4]:
# Fuse the preceding layers (conv, bn) together with the activation (relu).
# Depending on the architecture, the list of modules to fuse may have to be
# written out by hand.
# NOTE(review): newer PyTorch releases provide `fuse_modules_qat` for models in
# train mode — confirm which call the installed version expects.
model_fp32_fused = torch.quantization.fuse_modules(
    model_fp32,
    modules_to_fuse=[['conv', 'bn', 'relu']],
)

In [6]:
# Display the original model; its repr still lists separate conv, bn and relu modules.
model_fp32

M(
  (quant): QuantStub()
  (conv): Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1))
  (bn): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU()
  (dequant): DeQuantStub()
)

In [5]:
# Display the fused model; its repr shows `conv` as a ConvBnReLU2d and `bn`/`relu` as Identity.
model_fp32_fused

M(
  (quant): QuantStub()
  (conv): ConvBnReLU2d(
    (0): Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1))
    (1): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (bn): Identity()
  (relu): Identity()
  (dequant): DeQuantStub()
)

In [None]:
# Prepare the fused model for QAT: observers and fake-quant modules are
# inserted so that weight and activation tensors are observed during
# calibration. The fused model is passed with inplace=False (the default),
# and the prepared model is bound to a new name.
model_fp32_prepared = torch.quantization.prepare_qat(
    model_fp32_fused,
    inplace=False,
)

In [None]:

# Run the training loop (not shown) on the prepared model.
# NOTE(review): `training_loop` is not defined in the visible cells — it must
# be supplied before this cell can run.
training_loop(model_fp32_prepared)

# Convert the observed model to a quantized model. This does several things:
# quantizes the weights, computes and stores the scale and bias value to be
# used with each activation tensor, fuses modules where appropriate,
# and replaces key operators with quantized implementations.
# The model is switched to eval mode first, then converted; keep this order.
model_fp32_prepared.eval()
model_int8 = torch.quantization.convert(model_fp32_prepared)

# Run the converted model; the relevant calculations will happen in int8.
# NOTE(review): `input_fp32` is not defined in the visible cells — presumably a
# float tensor matching the model's single input channel; confirm before running.
res = model_int8(input_fp32)