# Post-Training Quantization with PyTorch 2 Export

Reference: [pt2e-PTQ](https://docs.pytorch.org/ao/main/tutorials_source/pt2e_quant_ptq.html#)

Prepare the model and the dataset:

```python
import torch
import torch.nn as nn
from torchvision.models.resnet import resnet18, ResNet18_Weights
from imagenet import ImageNet  # local helper module (not part of torchvision) that builds the data loaders

train_batch_size = 30
eval_batch_size = 50
# data_path = 'data/imagenet'
data_path = "/media/pc/data/lxw/home/data/datasets/ILSVRC"
dataset = ImageNet(data_path)
data_loader = dataset.train_loader(train_batch_size)
data_loader_test = dataset.test_loader(eval_batch_size)
example_inputs = (next(iter(data_loader))[0],)  # tuple of positional model inputs
criterion = nn.CrossEntropyLoss()
float_model = resnet18(weights=ResNet18_Weights.DEFAULT)
float_model = float_model.to("cpu")
```

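Post-training quantization (PTQ) requires the model to be in evaluation mode, so that layers such as `BatchNorm` and `Dropout` use their inference behavior.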

```python
model_to_quantize = float_model.eval()
```

## Export the model with {func}`torch.export.export`

```python
# Example input: a random tensor of shape (2, 3, 224, 224)
example_inputs = (torch.rand(2, 3, 224, 224),)

# For PyTorch 2.6 and later:
# export the model, capture the computation graph, and get the module
exported_model = torch.export.export(model_to_quantize, example_inputs).module()

# For PyTorch 2.5 and earlier:
# from torch._export import capture_pre_autograd_graph
# exported_model = capture_pre_autograd_graph(model_to_quantize, example_inputs)

# Alternatively, capture with a dynamic dimension
# For PyTorch 2.6 and later:
# mark dim 0 (the batch dimension) of the first input tensor as dynamic
dynamic_shapes = tuple(
    {0: torch.export.Dim("dim")} if i == 0 else None
    for i in range(len(example_inputs))
)
# export the model with the dynamic dimension
exported_model = torch.export.export(model_to_quantize, example_inputs, dynamic_shapes=dynamic_shapes).module()

# For PyTorch 2.5 and earlier
# (the dynamic-dimension API may differ):
# from torch._export import dynamic_dim
# exported_model = capture_pre_autograd_graph(model_to_quantize, example_inputs, constraints=[dynamic_dim(example_inputs[0], 0)])
```

## Import the backend-specific quantizer and configure how to quantize the model

The following snippet shows how to quantize the model:

```python
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
)
quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config())
```


`Quantizer` is backend-specific, and each `Quantizer` provides its own way for users to configure their model. For example, these are the configuration APIs supported by `XNNPACKQuantizer`:

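```python
# qconfig_opt is an optional QuantizationConfig object, for example:
qconfig_opt = get_symmetric_quantization_config()

# Each setter returns the quantizer, so the calls can be chained
(quantizer.set_global(qconfig_opt)  # global quantization config
    # config for every module of a given type, e.g. Conv2d
    .set_module_type(torch.nn.Conv2d, qconfig_opt)
    # config for an ATen operator type (an OpOverloadPacket)
    .set_operator_type(torch.ops.aten.linear, qconfig_opt)
    # config for one submodule, addressed by its fully qualified name
    # ("foo.bar" is a placeholder module path)
    .set_module_name("foo.bar", qconfig_opt))
```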

```{seealso}
Learn [how to write a new Quantizer](https://pytorch.org/tutorials/prototype/pt2e_quantizer.html).
```

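## Prepare the model for post-training quantization

`prepare_pt2e` folds `BatchNorm` operators into the preceding `Conv2d` operators and inserts observers at the appropriate places in the model.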
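
Folding is possible because inference-mode BatchNorm, `gamma * (y - mean) / sqrt(var + eps) + beta`, is an affine function of the convolution output. The plain-Python sketch below is an illustration, not the `prepare_pt2e` implementation: it reduces one output channel of a convolution to a dot product plus bias, folds hypothetical BatchNorm parameters into that channel's weights and bias, and checks that the result is unchanged.

```python
import math

def fold_bn(w, b, gamma, beta, mean, var, eps=1e-5):
    """Fold inference-mode BatchNorm into the weights and bias of one output channel."""
    factor = gamma / math.sqrt(var + eps)
    w_folded = [wi * factor for wi in w]     # scale every weight of the channel
    b_folded = (b - mean) * factor + beta    # shift and scale the bias
    return w_folded, b_folded

def channel_output(w, b, x):
    # one output channel of a convolution at one position: a dot product plus bias
    return sum(wi * xi for wi, xi in zip(w, x)) + b

def batchnorm(y, gamma, beta, mean, var, eps=1e-5):
    # inference-mode BatchNorm using the running statistics mean/var
    return gamma * (y - mean) / math.sqrt(var + eps) + beta

# hypothetical weights, BatchNorm parameters, and input patch
w, b = [0.5, -1.0, 2.0], 0.1
gamma, beta, mean, var = 1.5, -0.2, 0.3, 4.0
x = [1.0, 2.0, 3.0]

unfolded = batchnorm(channel_output(w, b, x), gamma, beta, mean, var)
w_f, b_f = fold_bn(w, b, gamma, beta, mean, var)
folded = channel_output(w_f, b_f, x)
print(unfolded, folded)  # equal up to floating-point rounding
```

After folding, the BatchNorm node disappears, so only one weight tensor per convolution needs an observer.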

```python
# Set up warnings
import warnings
warnings.filterwarnings(
    action='ignore',
    category=DeprecationWarning,
    module=r'.*'
)
warnings.filterwarnings(
    action='ignore',
    module=r'torch.fx.graph'
)
warnings.filterwarnings(
    action='default',
    module=r'torchao.quantization.pt2e'
)
from torchao.quantization.pt2e.quantize_pt2e import (
    prepare_pt2e,
    convert_pt2e,
)
prepared_model = prepare_pt2e(exported_model, quantizer)
print(prepared_model.graph)
```

```
graph():
    %conv1_weight : [num_users=1] = get_attr[target=conv1.weight]
    %activation_post_process_1 : [num_users=1] = call_module[target=activation_post_process_1](args = (%conv1_weight,), kwargs = {})
    %layer1_0_conv1_weight : [num_users=1] = get_attr[target=layer1.0.conv1.weight]
    %activation_post_process_4 : [num_users=1] = call_module[target=activation_post_process_4](args = (%layer1_0_conv1_weight,), kwargs = {})
    %layer1_0_conv2_weight : [num_users=1] = get_attr[target=layer1.0.conv2.weight]
    %activation_post_process_6 : [num_users=1] = call_module[target=activation_post_process_6](args = (%layer1_0_conv2_weight,), kwargs = {})
    %layer1_1_conv1_weight : [num_users=1] = get_attr[target=layer1.1.conv1.weight]
    %activation_post_process_9 : [num_users=1] = call_module[target=activation_post_process_9](args = (%layer1_1_conv1_weight,), kwargs = {})
    %layer1_1_conv2_weight : [num_users=1] = get_attr[target=layer1.1.conv2.weight]
    %activation_post_process_1
```

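## Calibration

After the observers have been inserted, run a calibration function. Calibration feeds a few representative samples (for example, samples from the training dataset) through the model so that the observers can record tensor statistics, which are later used to compute the quantization parameters.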

```python
def calibrate(model, data_loader):
    # model.eval()
    with torch.no_grad():
        for image, target in data_loader:
            model(image)
calibrate(prepared_model, data_loader_test)  # run calibration on sample data
```
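
A minimal plain-Python sketch of what an observer records during calibration and how quantization parameters follow from it. `SimpleMinMaxObserver` and `compute_qparams` are hypothetical helpers that follow the min/max formulas of PyTorch's `MinMaxObserver` for the int8 range [-128, 127]; the observers that `XNNPACKQuantizer` actually inserts may collect other statistics (for example histograms).

```python
FLOAT32_EPS = 1.1920928955078125e-07  # float32 machine epsilon, keeps the scale above 0

def compute_qparams(min_val, max_val, qmin=-128, qmax=127, symmetric=False):
    """Scale and zero point from an observed [min_val, max_val] range."""
    # widen the range so that it always contains 0.0 (0.0 must be exactly representable)
    min_neg = min(min_val, 0.0)
    max_pos = max(max_val, 0.0)
    if symmetric:
        bound = max(-min_neg, max_pos)
        scale = max(bound / ((qmax - qmin) / 2), FLOAT32_EPS)
        zero_point = 0
    else:
        scale = max((max_pos - min_neg) / (qmax - qmin), FLOAT32_EPS)
        zero_point = qmin - round(min_neg / scale)
        zero_point = max(qmin, min(qmax, zero_point))
    return scale, zero_point


class SimpleMinMaxObserver:
    """Hypothetical observer: tracks the running min/max of every batch it sees."""

    def __init__(self):
        self.min_val = float("inf")
        self.max_val = float("-inf")

    def __call__(self, batch):
        self.min_val = min(self.min_val, min(batch))
        self.max_val = max(self.max_val, max(batch))
        return batch  # the data passes through unchanged

    def calculate_qparams(self, **kwargs):
        return compute_qparams(self.min_val, self.max_val, **kwargs)


obs = SimpleMinMaxObserver()
for batch in ([0.5, 1.0, -0.25], [2.0, 0.0, -1.0]):  # two "calibration" batches
    obs(batch)
print(obs.min_val, obs.max_val)  # -1.0 2.0
print("asymmetric:", obs.calculate_qparams())
print("symmetric: ", obs.calculate_qparams(symmetric=True))
```

The observer only needs statistics, not labels, which is why `calibrate` ignores `target`.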

## Convert the calibrated model to a quantized model

`convert_pt2e` takes the calibrated model and produces the quantized model.
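
During conversion, each observer is replaced by quantize/dequantize ops whose scale and zero point are computed from the observed statistics. The plain-Python sketch below shows the per-tensor affine formulas those ops implement, `q = clamp(round(x / scale) + zero_point, qmin, qmax)` and `x_hat = (q - zero_point) * scale`, assuming the int8 range and hypothetical parameters for an observed range of [-1.0, 2.0].

```python
def quantize_per_tensor(x, scale, zero_point, qmin=-128, qmax=127):
    # q = clamp(round(x / scale) + zero_point, qmin, qmax)
    return [max(qmin, min(qmax, round(v / scale) + zero_point)) for v in x]

def dequantize_per_tensor(q, scale, zero_point):
    # x_hat = (q - zero_point) * scale
    return [(v - zero_point) * scale for v in q]

# hypothetical int8 parameters for an observed range of [-1.0, 2.0]
scale, zero_point = 3 / 255, -43
x = [-1.0, -0.4, 0.0, 0.77, 2.0, 5.0]  # 5.0 lies outside the observed range
q = quantize_per_tensor(x, scale, zero_point)
x_hat = dequantize_per_tensor(q, scale, zero_point)
for v, qv, vh in zip(x, q, x_hat):
    print(f"{v:6.2f} -> {qv:5d} -> {vh:8.4f}")
```

Values inside the calibrated range come back within `scale / 2`, 0.0 is reproduced exactly, and 5.0 saturates at `qmax` and comes back as 2.0. This is why representative calibration data matters.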

```python
quantized_model = convert_pt2e(prepared_model)
print(quantized_model)
```

```
GraphModule(
  (conv1): Module()
  (layer1): Module(
    (0): Module(
      (conv1): Module()
      (conv2): Module()
    )
    (1): Module(
      (conv1): Module()
      (conv2): Module()
    )
  )
  (layer2): Module(
    (0): Module(
      (conv1): Module()
      (conv2): Module()
      (downsample): Module(
        (0): Module()
      )
    )
    (1): Module(
      (conv1): Module()
      (conv2): Module()
    )
  )
  (layer3): Module(
    (0): Module(
      (conv1): Module()
      (conv2): Module()
      (downsample): Module(
        (0): Module()
      )
    )
    (1): Module(
      (conv1): Module()
      (conv2): Module()
    )
  )
  (layer4): Module(
    (0): Module(
      (conv1): Module()
      (conv2): Module()
      (downsample): Module(
        (0): Module()
      )
    )
    (1): Module(
      (conv1): Module()
      (conv2): Module()
    )
  )
  (fc): Module()
)



def forward(self, x):
    x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    quantize_per_tenso
```

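## Quantized representations

### Q/DQ representation

Two representations are currently available to choose from, though the representation offered in the long term may still change based on feedback from PyTorch users.
- Q/DQ representation (default): as in previous documentation, every quantized operator is represented as `dequantize -> fp32_op -> quantize` ([RFC](https://github.com/pytorch/rfcs/blob/master/RFC-0019-Extending-PyTorch-Quantization-to-Custom-Backends.md)).
- Reference quantized model representation (described below).

For example, a quantized linear in Q/DQ form: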

```python
import torch
# Importing this module registers the torch.ops.quantized_decomposed.* ops
import torch.ao.quantization.fx._decomposed  # noqa: F401

# int8 value range used by the quantize/dequantize ops below
QMIN, QMAX = -128, 127

def quantized_linear(x_int8, x_scale, x_zero_point, weight_int8, weight_scale, weight_zero_point,
                     bias_fp32, output_scale, output_zero_point):
    # dequantize the int8 activation and weight back to fp32
    x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
        x_int8, x_scale, x_zero_point, QMIN, QMAX, torch.int8)
    weight_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
        weight_int8, weight_scale, weight_zero_point, QMIN, QMAX, torch.int8)
    # fp32 linear: out = bias + x @ weight^T
    weight_permuted = torch.ops.aten.permute_copy.default(weight_fp32, [1, 0])
    out_fp32 = torch.ops.aten.addmm.default(bias_fp32, x_fp32, weight_permuted)
    # quantize the fp32 result back to int8
    out_int8 = torch.ops.quantized_decomposed.quantize_per_tensor(
        out_fp32, output_scale, output_zero_point, QMIN, QMAX, torch.int8)
    return out_int8
```
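
The same `dequantize -> fp32 linear -> quantize` sequence, sketched in plain Python for a single input vector with per-tensor parameters (all numbers hypothetical, all zero points 0), makes the data flow of the pattern easy to follow:

```python
def quantize(v, scale, zero_point, qmin=-128, qmax=127):
    return max(qmin, min(qmax, round(v / scale) + zero_point))

def dequantize(q, scale, zero_point):
    return (q - zero_point) * scale

def qdq_linear(x_q, x_scale, x_zp, w_q, w_scale, w_zp, bias_fp32, out_scale, out_zp):
    """dequantize -> fp32 linear -> quantize for one input vector."""
    x = [dequantize(v, x_scale, x_zp) for v in x_q]
    w = [[dequantize(v, w_scale, w_zp) for v in row] for row in w_q]  # (out_features, in_features)
    # fp32 linear: out[j] = bias[j] + sum_i x[i] * w[j][i]
    out = [b + sum(xi * wi for xi, wi in zip(x, row)) for row, b in zip(w, bias_fp32)]
    return [quantize(v, out_scale, out_zp) for v in out]

# hypothetical int8 data and per-tensor parameters
x_q = [10, -20, 30]
w_q = [[5, 0, -5], [1, 2, 3]]
out_q = qdq_linear(x_q, 0.1, 0, w_q, 0.05, 0, [0.0, 0.5], 0.02, 0)
print(out_q)  # [-25, 40]
```

Here the fp32 result is [-0.5, 0.8], which quantizes to [-25, 40] with an output scale of 0.02.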

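### Reference quantized model representation

This representation provides special forms for selected operators, such as quantized linear. Other operators are still represented as `dq -> float32_op -> q`, and the `q/dq` ops themselves are decomposed into more primitive operators. You can obtain this representation with `convert_pt2e(..., use_reference_representation=True)`.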

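```python
import torch
from torch._higher_order_ops.out_dtype import out_dtype

# Reference quantized pattern for quantized linear (int8 in, int8 out)
def quantized_linear(x_int8, x_scale, x_zero_point, weight_int8, weight_scale, weight_zero_point,
                     bias_fp32, output_scale, output_zero_point, qmin=-128, qmax=127):
    # widen to int16 so that subtracting the zero points cannot overflow
    x_int16 = x_int8.to(torch.int16)
    weight_int16 = weight_int8.to(torch.int16)
    # integer matmul with int32 accumulation: (x - zp_x) @ (w - zp_w)^T
    acc_int32 = out_dtype(torch.ops.aten.mm.default, torch.int32,
                          x_int16 - x_zero_point, (weight_int16 - weight_zero_point).t())
    # quantize the bias with scale x_scale * weight_scale (zero point 0) and add it
    bias_scale = x_scale * weight_scale
    bias_int32 = out_dtype(torch.ops.aten.div.Tensor, torch.int32, bias_fp32, bias_scale)
    acc_int32 = acc_int32 + bias_int32
    # requantize to the output scale and add the output zero point
    acc_int32 = out_dtype(torch.ops.aten.mul.Scalar, torch.int32, acc_int32,
                          x_scale * weight_scale / output_scale) + output_zero_point
    # clamp to the int8 range and cast
    out_int8 = torch.ops.aten.clamp(acc_int32, qmin, qmax).to(torch.int8)
    return out_int8
```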
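
The integer steps of the reference pattern can be traced in plain Python with hypothetical numbers: zero-point subtraction and accumulation, bias quantization with scale `x_scale * weight_scale`, requantization by `x_scale * weight_scale / output_scale`, and clamping. This sketch rounds to nearest at each float-to-int step, which is an assumption; the casts in the tensor version may truncate instead, so real kernels can differ by one quantization step.

```python
def ref_quantized_linear(x_q, x_scale, x_zp, w_q, w_scale, w_zp, bias_fp32,
                         out_scale, out_zp, qmin=-128, qmax=127):
    """Integer-only linear following the steps of the reference pattern (one input vector)."""
    # 1. subtract the zero points and accumulate integer dot products (int32 in the real pattern)
    acc = [sum((xi - x_zp) * (wi - w_zp) for xi, wi in zip(x_q, row)) for row in w_q]
    # 2. quantize the bias with scale x_scale * w_scale and zero point 0, then add it
    bias_scale = x_scale * w_scale
    acc = [a + round(b / bias_scale) for a, b in zip(acc, bias_fp32)]
    # 3. requantize to the output scale and add the output zero point
    multiplier = x_scale * w_scale / out_scale
    out = [round(a * multiplier) + out_zp for a in acc]
    # 4. clamp to the int8 range
    return [max(qmin, min(qmax, v)) for v in out]

# hypothetical int8 data and per-tensor parameters (all zero points 0)
x_q = [10, -20, 30]
w_q = [[5, 0, -5], [1, 2, 3]]
out_q = ref_quantized_linear(x_q, 0.1, 0, w_q, 0.05, 0, [0.0, 0.5], 0.02, 0)
print(out_q)  # [-25, 40]
```

With these inputs the integer path gives [-25, 40], the same as quantizing the fp32 result [-0.5, 0.8] directly with an output scale of 0.02; only the final requantization touches floating point.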

## Check model size and accuracy

Now compare the size and accuracy of the quantized model with those of the baseline model.

```python
from utils import print_size_of_model, evaluate  # local helper module
# Baseline model size and accuracy
print("Size of baseline model")
print_size_of_model(float_model)

top1, top5 = evaluate(float_model, criterion, data_loader_test)
print("Baseline Float Model Evaluation accuracy: %2.2f, %2.2f" % (top1.avg, top5.avg))

# Quantized model size and accuracy
print("Size of model after quantization")
# export again to remove unused weights; keep the batch dimension dynamic
# so that the exported module accepts the 50-image test batches
quantized_model = torch.export.export(quantized_model, example_inputs, dynamic_shapes=dynamic_shapes).module()
print_size_of_model(quantized_model)

top1, top5 = evaluate(quantized_model, criterion, data_loader_test)
print("[before serialization] Evaluation accuracy on test dataset: %2.2f, %2.2f" % (top1.avg, top5.avg))
```

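```
Size of baseline model
Size (MB): 46.828683

Baseline Float Model Evaluation accuracy: 69.23, 88.81
Size of model after quantization
Size (MB): 11.713877

RuntimeError: Expected input at *args[0].shape[0] to be equal to 2, but got 50
```

This `RuntimeError` is raised when the re-export in this cell is done without `dynamic_shapes`: the exported module then specializes on the batch size 2 of `example_inputs` and rejects the 50-image batches from `data_loader_test`. Passing `dynamic_shapes=dynamic_shapes` to this export, as in the earlier export, keeps the batch dimension dynamic.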

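```{tip}
1. Performance cannot be evaluated yet, because the model has not been lowered to a target device; it is only a representation of the quantized computation in ATen ops.
2. Depending on the version, the weights may still be stored as `fp32` after conversion; constant propagation of the quantize ops is what turns them into integer weights. The measured size above, roughly a quarter of the baseline, suggests that here the weights are already stored as `int8`.
```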
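
The two measured sizes can be sanity-checked with simple arithmetic. ResNet-18 has 11,689,512 parameters; at 4 bytes per `fp32` value and 1 byte per `int8` value, and assuming `print_size_of_model` reports bytes / 1e6 (as in the PyTorch quantization tutorials), the estimates land close to the measured values. The estimate ignores buffers, quantization parameters, and serialization overhead.

```python
# parameter count of torchvision's resnet18
num_params = 11_689_512
fp32_mb = num_params * 4 / 1e6  # 4 bytes per float32 parameter
int8_mb = num_params * 1 / 1e6  # 1 byte per int8 parameter
print(f"fp32 estimate: {fp32_mb:.2f} MB, int8 estimate: {int8_mb:.2f} MB")

# sizes printed by print_size_of_model above
measured_fp32, measured_int8 = 46.828683, 11.713877
print(f"measured ratio: {measured_fp32 / measured_int8:.2f}")
```

A ratio close to 4 is what storing every weight in one byte instead of four would predict.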

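For better accuracy or performance, try configuring the `quantizer` differently. Each `quantizer` has its own configuration options, so consult the documentation of the quantizer you use to learn how to control the way your model is quantized.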
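
One configuration choice that affects accuracy is per-tensor versus per-channel weight quantization. In recent versions `get_symmetric_quantization_config` takes an `is_per_channel` flag; check the signature in your version. The plain-Python sketch below compares the two schemes on hypothetical weights whose output channels have very different ranges. A single shared scale leaves the small channel only a few int8 levels.

```python
def symmetric_scale(values, qmax=127):
    # symmetric quantization: zero point 0, [-max|w|, max|w|] mapped to [-qmax, qmax]
    return max(abs(v) for v in values) / qmax

def quant_error(values, scale, qmax=127):
    # mean absolute round-trip error of symmetric quantize -> dequantize
    err = 0.0
    for v in values:
        q = max(-qmax, min(qmax, round(v / scale)))
        err += abs(v - q * scale)
    return err / len(values)

# hypothetical weights of two output channels with very different ranges
weights = [
    [0.010, -0.020, 0.015, -0.005],  # small-magnitude channel
    [1.500, -2.000, 0.750, 1.250],   # large-magnitude channel
]

# per-tensor: one scale shared by all channels
tensor_scale = symmetric_scale([v for row in weights for v in row])
per_tensor_err = [quant_error(row, tensor_scale) for row in weights]

# per-channel: one scale per output channel
per_channel_err = [quant_error(row, symmetric_scale(row)) for row in weights]

print("per-tensor :", per_tensor_err)
print("per-channel:", per_channel_err)
```

The large channel gets the same scale either way, so its error is unchanged, while the small channel's error drops by more than an order of magnitude once it has its own scale.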

## Save and load the quantized model

The following shows how to save and load the quantized model:

```python
import os

# 0. Store the reference output for the example inputs and check the evaluation accuracy
example_inputs = (next(iter(data_loader))[0],)
ref = quantized_model(*example_inputs)
top1, top5 = evaluate(quantized_model, criterion, data_loader_test)
print("[before serialization] Evaluation accuracy on test dataset: %2.2f, %2.2f" % (top1.avg, top5.avg))

# 1. Export the model and save the ExportedProgram
saved_model_dir = "data/"  # output directory (assumed; adjust to your environment)
os.makedirs(saved_model_dir, exist_ok=True)
pt2e_quantized_model_file_path = os.path.join(saved_model_dir, "resnet18_pt2e_quantized.pth")
# capture the model to get an ExportedProgram (batch dimension kept dynamic)
quantized_ep = torch.export.export(quantized_model, example_inputs, dynamic_shapes=dynamic_shapes)
# use torch.export.save to save an ExportedProgram
torch.export.save(quantized_ep, pt2e_quantized_model_file_path)

# 2. Load the saved ExportedProgram
loaded_quantized_ep = torch.export.load(pt2e_quantized_model_file_path)
loaded_quantized_model = loaded_quantized_ep.module()

# 3. Check the results for the example inputs and the evaluation accuracy again
res = loaded_quantized_model(*example_inputs)
print("max abs diff:", (ref - res).abs().max().item())

top1, top5 = evaluate(loaded_quantized_model, criterion, data_loader_test)
print("[after serialization/deserialization] Evaluation accuracy on test dataset: %2.2f, %2.2f" % (top1.avg, top5.avg))
```