# 快速上手 TVM 自动量化

In [1]:
import logging
import set_env
from d2py.utils.log_config import config_logging
from d2py.utils.file import mkdir
# Logging setup: records for the "test" logger are written to .temp/test.log
# via d2py's config_logging (filemode="w" — presumably truncates the file on each run;
# filter_mod_names={"te_compiler"} presumably filters records from that module — see d2py).
logger_name = "test"
temp_dir = ".temp"
mkdir(temp_dir)  # presumably creates the log directory; must run before config_logging opens the file
config_logging(
    f"{temp_dir}/{logger_name}.log",
    logger_name,
    filemode="w",
    filter_mod_names={"te_compiler"},
)
logger = logging.getLogger(logger_name)

定义简单网络（Conv2d → BatchNorm2d → ReLU）：

In [2]:
import torch
from torch import nn

class TestModel(nn.Module):
    """Minimal Conv2d -> BatchNorm2d -> ReLU network used as the demo model.

    The convolution maps 3 input channels to ``planes`` (64) channels with a
    7x7 kernel, stride 2 and padding 3, without a bias term; a float32
    BatchNorm2d over the same 64 channels and a ReLU follow it.
    """

    def __init__(self):
        super().__init__()
        planes = 64
        self.planes = planes
        # bias=False: the BatchNorm2d below is affine=True and has its own learnable bias.
        self.conv = nn.Conv2d(
            in_channels=3,
            out_channels=planes,
            kernel_size=7,
            stride=2,
            padding=3,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(
            num_features=planes,
            eps=1e-05,
            momentum=0.1,
            affine=True,
            track_running_stats=True,
            dtype=torch.float32,
        )
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply convolution, batch normalization and ReLU, in that order."""
        return self.relu(self.bn(self.conv(x)))

将前端（PyTorch）模型转换为 Relay 模型：

In [6]:
import tvm
from tvm import relay
from tvm.relay import transform as _transform
from tvm.relay import expr as _expr
from tvm.relay import Call, Constant, Function
from tvm.ir.op import Op

def _bind_params(func, params):
    """Bind the values in ``params`` to the same-named parameters of ``func``.

    Parameters
    ----------
    func : relay Function
        Function whose parameters are matched by their ``name_hint``.
    params : dict
        Mapping from parameter name to value; each value is wrapped with
        ``_expr.const``. Names that ``func`` does not declare are skipped.

    Returns
    -------
    The result of ``_expr.bind(func, ...)``, i.e. ``func`` with the matched
    parameters replaced by constants.

    Raises
    ------
    ValueError
        If a name present in ``params`` is shared by more than one parameter
        of ``func`` (the binding would be ambiguous).
    """
    # name_hint -> unique parameter; None marks a name that occurs more than once.
    by_name = {}
    for param in func.params:
        hint = param.name_hint
        by_name[hint] = None if hint in by_name else param

    bindings = {}
    for k, value in params.items():
        if k not in by_name:
            continue
        target = by_name[k]
        if target is None:
            raise ValueError(f"Multiple args in the function have name {k}")
        bindings[target] = _expr.const(value)
    return _expr.bind(func, bindings)

# Seed PyTorch's RNG first: TestModel() draws random initial weights and torch.randn
# draws the trace input, so without a seed every Restart & Run All gives different values.
torch.manual_seed(0)

input_name = "data"
input_shape = (1, 3, 32, 32)
# Trace the eval-mode model with an example input to obtain a TorchScript module.
frontend_mod = torch.jit.trace(TestModel().eval(), torch.randn(*input_shape))
# Translate the frontend model into a Relay module; weights are returned separately in `params`.
origin_mod, params = relay.frontend.from_pytorch(frontend_mod, [(input_name, input_shape)])
logger.info(f'原始模型：{origin_mod["main"]}')
# Bind params into origin_mod so the weights become constants the passes below can fold.
if params:
    origin_mod["main"] = _bind_params(origin_mod["main"], params)
logger.info(f'原始模型(绑定参数)：{origin_mod["main"]}')
# Simplify for inference and fold constants: SimplifyInference, FoldConstant,
# FoldScaleAxis, CanonicalizeOps, then a final FoldConstant.
optimize = tvm.transform.Sequential([
        _transform.SimplifyInference(),
        _transform.FoldConstant(),
        _transform.FoldScaleAxis(),
        _transform.CanonicalizeOps(),
        _transform.FoldConstant(),
])
with tvm.transform.PassContext(opt_level=3):
    run_mod = optimize(origin_mod)
logger.info(f'原始模型(化简后)：{run_mod["main"]}')

INFO|2024-01-09 12:50:27,877|test| -> 原始模型：fn (%data: Tensor[(1, 3, 32, 32), float32] /* span=aten::_convolution_0.data:0:0 */, %aten::_convolution_0.weight: Tensor[(64, 3, 7, 7), float32] /* span=aten::_convolution_0.weight:0:0 */, %aten::batch_norm_0.weight: Tensor[(64), float32] /* span=aten::batch_norm_0.weight:0:0 */, %aten::batch_norm_0.bias: Tensor[(64), float32] /* span=aten::batch_norm_0.bias:0:0 */, %aten::batch_norm_0.running_mean: Tensor[(64), float32] /* span=aten::batch_norm_0.running_mean:0:0 */, %aten::batch_norm_0.running_var: Tensor[(64), float32] /* span=aten::batch_norm_0.running_var:0:0 */) {
  %0 = nn.conv2d(%data, %aten::_convolution_0.weight, strides=[2, 2], padding=[3, 3, 3, 3], channels=64, kernel_size=[7, 7]) /* span=aten::_convolution_0:0:0 */;
  %1 = nn.batch_norm(%0, %aten::batch_norm_0.weight, %aten::batch_norm_0.bias, %aten::batch_norm_0.running_mean, %aten::batch_norm_0.running_var) /* span=aten::batch_norm_0:0:0 */;
  %2 = %1.0 /* span=aten::batch_norm

In [None]:
import tvm
class _Transform(tvm.relay.ExprMutator):
    """ExprMutator that rebuilds every call and records each ``nn.batch_norm`` call.

    Recorded calls are stored in ``self.binds`` under keys of the form
    ``"nn.batch_norm_<i>"``, where ``<i>`` is the running counter
    ``self.func_id`` (incremented once per recorded call, in visit order).
    """

    def __init__(self):
        super().__init__()
        self.binds = {}  # key "nn.batch_norm_<i>" -> rebuilt batch_norm Call
        self.func_id = 0  # number of batch_norm calls recorded so far

    def visit_call(self, call):
        """Visit the operator and arguments, rebuild the call, and record it if it is batch_norm."""
        new_fn = self.visit(call.op)
        # Arguments are visited (in order) before this call is recorded.
        visited_args = list(map(self.visit, call.args))
        rebuilt = Call(new_fn, visited_args, call.attrs, call.type_args, call.span)
        if isinstance(new_fn, Op) and new_fn.name == "nn.batch_norm":
            self.binds[f"{new_fn.name}_{self.func_id}"] = rebuilt
            self.func_id += 1
        return rebuilt

In [8]:
# Extract intermediate expression #1 of the simplified module as a standalone module
# (the second argument is presumably the expression id — confirm against TVM docs).
# As the displayed output shows, the result is conv2d followed by an `add` with a constant.
# NOTE(review): "intermdeiate" looks like a typo but appears to be TVM's own spelling of
# this API name — confirm before "fixing" it.
relay.analysis.extract_intermdeiate_expr(run_mod, 1)

def @main(%data: Tensor[(1, 3, 32, 32), float32] /* ty=Tensor[(1, 3, 32, 32), float32] span=aten::_convolution_0.data:0:0 */) {
  %0 = nn.conv2d(%data, meta[relay.Constant][0] /* ty=Tensor[(64, 3, 7, 7), float32] */, strides=[2, 2], padding=[3, 3, 3, 3], channels=64, kernel_size=[7, 7]) /* ty=Tensor[(1, 64, 16, 16), float32] */;
  add(%0, meta[relay.Constant][1] /* ty=Tensor[(64, 1, 1), float32] */) /* ty=Tensor[(1, 64, 16, 16), float32] */
}
