# TVM 自动量化过程剖析

以 PyTorch 的 resnet18 模型为例剖析 TVM 自动量化过程。

## PyTorch 模型翻译为 relay 模型

加载 PyTorch 模型：

In [1]:
def load_model(input_shape):
    """Return a TorchScript-traced, eval-mode resnet18 with ImageNet weights.

    Args:
        input_shape: Iterable of ints unpacked into ``torch.randn`` to build
            the example input used for tracing.

    Returns:
        The traced module, switched to eval mode.
    """
    # Imported lazily so that torch/torchvision are only needed when called.
    import torch
    from torchvision.models import ResNet18_Weights, resnet18

    net = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
    example_input = torch.randn(*input_shape)
    # Trace in eval mode, then put the traced module itself in eval mode too.
    traced = torch.jit.trace(net.eval(), example_input)
    return traced.eval()

PyTorch 模型翻译为 relay 模型：

In [2]:
import tvm
from tvm import relay

# Shape of the example input handed to load_model and to the relay frontend.
input_shape = 1, 3, 224, 224
# Name given to the relay input variable; the calibration/evaluation code
# below feeds tensors under this same "data" key.
input_name = "data"
traced_model = load_model(input_shape)
# Translate the traced PyTorch model into a relay module plus its parameters.
mod, params = relay.frontend.from_pytorch(
    traced_model, 
    [(input_name, input_shape)], 
    # use_parser_friendly_name=True
)
with tvm.transform.PassContext(opt_level=3): # pre-processing
    # Pre-quantization clean-up of the imported module; see TVM's
    # relay.quantize documentation for the exact passes it applies.
    opt_mod = relay.quantize.prerequisite_optimize(mod, params)

```{rubric} 加载数据
```

In [3]:
import tvm.testing
from tvm import relay
from tvm.relay import transform, build_module
from tvm.relay.testing import run_opt_pass

In [4]:
from dataclasses import dataclass
from pathlib import Path
from PIL import Image
from tqdm import tqdm
import numpy as np
from tvm_book.data.classification import ImageFolderDataset

def preprocess_image(
        image: np.ndarray,
        size: tuple[int] = (224, 224),
        mean: tuple[float] = (0.485, 0.456, 0.406),
        std: tuple[float] = (0.229, 0.224, 0.225)
    ):
    """Resize, center-crop and normalize one image.

    The image is resized to 256x256 (bilinear), center-cropped to ``size``,
    divided by 256, then normalized with ``mean`` and ``std``.

    Args:
        image: Array accepted by ``PIL.Image.fromarray``.
        size: Crop size as ``(height, width)``.
        mean: Values subtracted after scaling (broadcast over the last axis).
        std: Values the result is divided by (broadcast over the last axis).

    Returns:
        A ``float32`` array; for an RGB input it is channels-last with
        spatial shape exactly ``size``.
    """
    crop_h, crop_w = size
    im = Image.fromarray(image)
    im = im.resize((256, 256), Image.Resampling.BILINEAR)
    # PIL reports size as (width, height), not (height, width).
    full_w, full_h = im.size
    left = (full_w - crop_w) // 2
    top = (full_h - crop_h) // 2
    # Build the box from the origin plus the crop size so the result is
    # exactly crop_w x crop_h even when the margin is odd (mirroring the
    # margin on both sides would give one extra pixel in that case).
    im = im.crop((left, top, left + crop_w, top + crop_h))
    image = np.array(im, dtype="float32")
    im.close()
    # NOTE(review): scales by 1/256 rather than the more common 1/255;
    # kept as-is so results stay identical -- confirm this is intended.
    image = image/256
    image -= mean
    image /= std
    return image.astype(np.float32)

@dataclass
class ImageNet:
    """Bundle an image-folder dataset root with its preprocessing settings.

    After construction, ``root`` is a ``Path`` and ``valset`` / ``trainset``
    are ``ImageFolderDataset`` objects built from ``<root>/val`` and
    ``<root>/train``.
    """
    root: str
    size: tuple[int] = (224, 224)
    mean: tuple[float] = (0.485, 0.456, 0.406)
    std: tuple[float] = (0.229, 0.224, 0.225)

    def __post_init__(self):
        root_dir = Path(self.root)
        self.root = root_dir  # dataset root directory
        self.valset = ImageFolderDataset(f"{root_dir}/val")
        self.trainset = ImageFolderDataset(f"{root_dir}/train")

    def calibrateset(self, calibrate_num: int = 200):
        """Yield calibration inputs for TVM quantization.

        Walks ``trainset`` in order and yields one ``{"data": batch}`` dict
        per sample until ``calibrate_num`` samples have been produced.
        Labels are ignored.
        """
        for index, (raw, _label) in tqdm(enumerate(self.trainset)):
            if index >= calibrate_num:
                return
            processed = preprocess_image(raw, self.size, self.mean, self.std)
            # Add a leading batch axis, then move the last axis to position 1
            # (channels-last -> channels-first, assuming the preprocessed
            # image is H x W x C).
            batch = processed[np.newaxis, ...].transpose(0, 3, 1, 2)
            yield {"data": batch}


In [5]:
# Machine-specific dataset root; ImageNet reads the `train/` and `val/`
# sub-folders underneath it.
dataset = ImageNet("/media/pc/data/lxw/home/data/datasets/ILSVRC/")

## resnet18 算子融合

In [6]:
from tvm.relay import Call
from tvm.relay.function import Function, FunctionWithFields
from tvm.relay.quantize._partition import QPartitionExpr

@tvm.relay.transform.function_pass(opt_level=1)
class QPartitionTransform:
    """Wrap the body of fused functions in a realized ``QPartitionExpr``."""

    def transform_function(self, func, mod, ctx):
        class _BodyWrapper(tvm.relay.ExprMutator):
            def visit_function(self, fn):
                visited_params = [self.visit(param) for param in fn.params]
                body = self.visit(fn.body)
                # A body that directly calls another Function is left alone,
                # which keeps QPartitionExpr from being added repeatedly.
                calls_function = isinstance(body.op, Function)
                if not calls_function:
                    body = QPartitionExpr(body).realize()
                untouched = visited_params == list(fn.params) and body == fn.body
                if untouched:
                    return fn
                return FunctionWithFields(fn, list(visited_params), body)

        return _BodyWrapper().visit(func)
    
@tvm.relay.transform.function_pass(opt_level=1)
class SplitGraphTransform:
    """Save each called sub-function as a separate global function.

    Every call whose callee is a relay ``Function`` is redirected to a new
    global named ``f_0000``, ``f_0001``, ... registered in the module being
    transformed. The numbering lives on the instance and is only zeroed by
    ``reset()``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Restart the ``f_XXXX`` numbering from zero."""
        self._func_index = 0

    def transform_function(self, func, mod, ctx):
        # The nested mutator shadows `self`, so keep a handle on the pass.
        owner = self

        class _Hoister(tvm.relay.ExprMutator):
            def visit_call(self, call):
                callee = self.visit(call.op)
                operands = [self.visit(operand) for operand in call.args]
                if isinstance(callee, Function):
                    global_name = f"f_{owner._func_index:04d}"
                    # Constant-fold the sub-function before registering it.
                    mod[global_name] = run_opt_pass(
                        callee, relay.transform.FoldConstant()
                    )
                    callee = mod.get_global_var(global_name)
                    owner._func_index += 1
                if callee == call.op and operands == list(call.args):
                    return call
                return Call(callee, operands, call.attrs, call.type_args, call.span)

        return _Hoister().visit(func)


In [7]:
from tvm_book.tvm_utils.relay_pattern import *
# Configure the fusion rules: each entry pairs a composite name with the
# pattern builder (from tvm_book.tvm_utils.relay_pattern) that matches it.
compiler_name = "ccompiler"
pattern_table = [
    (f"{compiler_name}.conv_add_relu_max_pool2d", make_conv_add_relu_max_pool2d_pattern()),
    (f"{compiler_name}.conv2d_transpose_add_activate", make_conv2d_transpose_add_activate_pattern()),
    (f"{compiler_name}.conv_add_activate", make_conv_add_activate_pattern()),
    (f"{compiler_name}.max_pool2d", make_max_pool2d_pattern()),
    (f"{compiler_name}.dense_add", make_dense_add_pattern()),
    (f"{compiler_name}.adaptive_avg_pool2d", make_adaptive_avg_pool2d_pattern()),
    # NOTE(review): "avg_pool2dd" looks like a typo for "avg_pool2d" -- confirm
    # nothing matches on this composite name before renaming it.
    (f"{compiler_name}.avg_pool2dd", make_avg_pool2d_pattern()),
    (f"{compiler_name}.add_multiply_add", make_add_multiply_add_pattern()),
    (f"{compiler_name}.add", make_add_pattern()),
    (f"{compiler_name}.multiply", make_multiply_pattern()),
    # (f"{compiler_name}.strided_slice", make_strided_slice_pattern()),
]
# Pipeline: infer types, fuse the patterns into composite functions, wrap the
# fused bodies, then hoist every fused function into its own global.
merge_passes = tvm.transform.Sequential([
    relay.transform.InferType(),
    relay.transform.MergeComposite(pattern_table),
    QPartitionTransform(), # add the `QPartitionExpr` operator to fused functions
    # relay.transform.DefuseOps(),
    # relay.transform.MergeComposite(pattern_table),
    SplitGraphTransform(),
])

In [8]:
# Apply the fusion/split pipeline to the pre-processed module.
with tvm.transform.PassContext(opt_level=3):
    run_mod = merge_passes(opt_mod)

In [9]:
# Show the first global function created by SplitGraphTransform.
print(run_mod["f_0000"])

fn (%FunctionVar_19_0: Tensor[(1, 3, 224, 224), float32] /* ty=Tensor[(1, 3, 224, 224), float32] */, %FunctionVar_19_1: Tensor[(64, 3, 7, 7), float32] /* ty=Tensor[(64, 3, 7, 7), float32] */, %FunctionVar_19_2: Tensor[(64, 1, 1), float32] /* ty=Tensor[(64, 1, 1), float32] */, PartitionedFromPattern="nn.conv2d_add_nn.relu_", Composite="ccompiler.conv_add_activate") -> Tensor[(1, 64, 112, 112), float32] {
  %0 = nn.conv2d(%FunctionVar_19_0, %FunctionVar_19_1, strides=[2, 2], padding=[3, 3, 3, 3], channels=64, kernel_size=[7, 7]) /* ty=Tensor[(1, 64, 112, 112), float32] */;
  %1 = add(%0, %FunctionVar_19_2) /* ty=Tensor[(1, 64, 112, 112), float32] */;
  %2 = nn.relu(%1) /* ty=Tensor[(1, 64, 112, 112), float32] span=aten::relu__0:0:0 */;
  %3 = annotation.cast_hint(%2, dtype="int8") /* ty=Tensor[(1, 64, 112, 112), float32] */;
  annotation.stop_fusion(%3) /* ty=Tensor[(1, 64, 112, 112), float32] */
} /* ty=fn (Tensor[(1, 3, 224, 224), float32], Tensor[(64, 3, 7, 7), float32], Tensor[(64, 1

In [10]:
# Show `main`, which now calls the hoisted globals (@f_0000, @f_0001, ...).
print(run_mod["main"])

fn (%data: Tensor[(1, 3, 224, 224), float32] /* ty=Tensor[(1, 3, 224, 224), float32] span=aten::_convolution_0.data:0:0 */) -> Tensor[(1, 1000), float32] {
  %0 = @f_0000(%data, meta[relay.Constant][0] /* ty=Tensor[(64, 3, 7, 7), float32] */, meta[relay.Constant][1] /* ty=Tensor[(64, 1, 1), float32] */) /* ty=Tensor[(1, 64, 112, 112), float32] */;
  %1 = @f_0001(%0) /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %2 = @f_0002(%1, meta[relay.Constant][2] /* ty=Tensor[(64, 64, 3, 3), float32] */, meta[relay.Constant][3] /* ty=Tensor[(64, 1, 1), float32] */) /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %3 = @f_0003(%2, meta[relay.Constant][4] /* ty=Tensor[(64, 64, 3, 3), float32] */, meta[relay.Constant][5] /* ty=Tensor[(64, 1, 1), float32] */) /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %4 = @f_0004(%3, %1) /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %5 = @f_0005(%4, meta[relay.Constant][6] /* ty=Tensor[(64, 64, 3, 3), float32] */, meta[relay.Constant][7] /* ty=Tensor[(64, 1, 1), float32]

## 注解 resnet18 模型

In [None]:
# Re-run the fusion/split pipeline on the pre-processed module.
# NOTE(review): `merge_passes` holds the SplitGraphTransform created earlier,
# and its function counter is only zeroed by reset(); presumably a second run
# continues the numbering, so "f_0000" may not exist in this `run_mod` --
# confirm, or rebuild the pipeline before re-running.
with tvm.transform.PassContext(opt_level=3):
    run_mod = merge_passes(opt_mod)

    # run_mod = relay.quantize.annotate()(run_mod)

In [None]:
# Inspect the first hoisted function of the freshly transformed module.
print(run_mod["f_0000"])

In [None]:
# from tvm.ir import IRModule, structural_equal
import tvm

# Quantize with TVM's automatic quantization, calibrating on 200 samples
# from the training set.
# NOTE: this quantizes the originally imported `mod`/`params`, not the
# fused and split `run_mod`.
with tvm.transform.PassContext(opt_level=3):
    with relay.quantize.qconfig(
        skip_conv_layers=[],
        calibrate_mode="kl_divergence", 
        weight_scale="max",
        round_for_shift=True,
        # rounding="TONEAREST", # "UPWARD" or "TONEAREST"
        calibrate_skip_layers=[],
        skip_dense_layer=False,
    ):
        qmod = relay.quantize.quantize(mod, params, dataset.calibrateset(calibrate_num=200))

```{rubric} 度量 resnet18 结果
```

In [None]:
from tqdm import tqdm
from tvm.runtime.vm import VirtualMachine
from tvm_book.metric.classification import Accuracy, TopKAccuracy

# Compile the float (fused + split) module and the quantized module for the
# Relay virtual machine, both targeting the local CPU through LLVM.
with tvm.transform.PassContext(opt_level=3):
    vm_exec = relay.vm.compile(run_mod, target="llvm", params=params)
vm = VirtualMachine(vm_exec, tvm.cpu())
with tvm.transform.PassContext(opt_level=3):
    qvm_exec = relay.vm.compile(qmod, target="llvm", params=params)
qvm = VirtualMachine(qvm_exec, tvm.cpu())

# Separate metric objects per model so the running scores do not mix.
metric_top1 = Accuracy("浮点")
metric_top5 = TopKAccuracy(top_k=5)
qmetric_top1 = Accuracy("量化")
qmetric_top5 = TopKAccuracy(top_k=5)
for k, (data, label) in tqdm(enumerate(dataset.valset)):
    # Same preprocessing and layout change as the calibration set.
    image = preprocess_image(data, dataset.size, dataset.mean, dataset.std)
    images = np.expand_dims(image, 0)
    images = images.transpose((0, 3, 1, 2))
    input_dict = {"data": images}
    # NDArray.numpy() replaces the deprecated NDArray.asnumpy().
    output = vm.run(**input_dict).numpy()
    quant_output = qvm.run(**input_dict).numpy()
    label = np.array([label])
    # Update the accuracy metrics for both models.
    metric_top1.update(preds = output, labels = label)
    metric_top5.update(preds = output, labels = label)
    qmetric_top1.update(preds = quant_output, labels = label)
    qmetric_top5.update(preds = quant_output, labels = label)
    # Report the running scores every 1000 validation images.
    if k % 1000 == 0:
        print(f"浮点: {metric_top1.get(), metric_top5.get()}||量化: {qmetric_top1.get(), qmetric_top5.get()}")
    # break