In [22]:
import torch
import torch.quantization

In [23]:
# 定义一个简单的模型
class MyModel(torch.nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv = torch.nn.Conv2d(1, 20, 5)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        x = self.conv(x)
        x = self.relu(x)
        return x

In [24]:
# 创建模型实例
model = MyModel()

# 切换到评估模式
model.eval()

MyModel(
  (conv): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  (relu): ReLU()
)

In [25]:
# 设置量化配置，使用fbgemm后端
model.qconfig = torch.quantization.get_default_qconfig('x86')

# 准备模型进行量化（自动插入量化相关层）
torch.quantization.prepare(model, inplace=True)

MyModel(
  (conv): Conv2d(
    1, 20, kernel_size=(5, 5), stride=(1, 1)
    (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
  )
  (relu): ReLU()
)

In [26]:
# 模拟校准数据运行一次前向传播，校准模型的量化参数
calibration_data = torch.randn(10, 1, 28, 28)  # 假设有10个校准数据
with torch.no_grad():
    for data in calibration_data:
        model(data)

In [27]:
# 执行量化转换
quantized_model = torch.quantization.convert(model, inplace=False)

In [28]:
quantized_model

MyModel(
  (conv): QuantizedConv2d(1, 20, kernel_size=(5, 5), stride=(1, 1), scale=0.038826461881399155, zero_point=64)
  (relu): ReLU()
)

In [31]:
print(quantized_model)

MyModel(
  (conv): QuantizedConv2d(1, 20, kernel_size=(5, 5), stride=(1, 1), scale=0.038826461881399155, zero_point=64)
  (relu): ReLU()
)
