# 自定义 VTA Graph Pack

In [1]:
from copy import deepcopy
import tvm
from tvm import relay
from tvm_book.vta_utils.pack_tool import graph_pack, WithVTAFunctionTransform

## VTA 模型样例

In [2]:
from torch import nn
import torch

class Model(nn.Module):
    """Minimal Conv2d -> BatchNorm2d -> ReLU network used as the example model."""

    def __init__(self, *args, **kwargs) -> None:
        """Create the three layers; extra arguments are forwarded to ``nn.Module``."""
        super().__init__(*args, **kwargs)
        # 3 input channels -> 36 output channels, 3x3 kernel, stride 1, padding 1.
        self.conv = nn.Conv2d(
            in_channels=3,
            out_channels=36,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=True,
        )
        self.bn = nn.BatchNorm2d(num_features=36)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply convolution, batch normalisation and ReLU, in that order."""
        return self.relu(self.bn(self.conv(x)))

# Instantiate the PyTorch model in float32 inference mode.
pt_model = Model().float().eval()

# Name and shape of the single graph input, as expected by the Relay frontend.
input_name = "data"
ishape = (1, 3, 4, 4)
input_shapes = [(input_name, ishape)]

# Trace the model with a random float32 input (torch.jit.script would be the
# alternative way of obtaining a TorchScript module).
idata = torch.randn(ishape).float()
traced_model = torch.jit.trace(pt_model, idata)

# Translate the traced model into a TVM Relay module plus its parameters.
mod, params = relay.frontend.from_pytorch(traced_model, input_shapes)

# Quantize: no conv layer is skipped and weight scales use the "max" setting.
with tvm.transform.PassContext(opt_level=3), relay.quantize.qconfig(
    skip_conv_layers=[], weight_scale="max"
):
    mod = relay.quantize.quantize(mod, params)
mod.show()

## VTA Graph Pack

In [3]:
from tvm.relay.function import Function
from tvm.relay.testing import run_opt_pass
import vta

# Read the packing factors from the active VTA environment.
env = vta.get_env()
bfactor, cfactor, weight_bits = env.BATCH, env.BLOCK_OUT, env.WGT_WIDTH

# Pack the main function of a deep copy of the quantized module and print it.
run_mod = deepcopy(mod)
new_fn = graph_pack(run_mod["main"], bfactor, cfactor, weight_bits)
tvm.IRModule.from_expr(new_fn).show()

## VTA 模型的算子融合

创建融合策略：

In [4]:
from tvm.relay.dataflow_pattern import (
    # TuplePattern, TupleGetItemPattern, 
    is_op, wildcard, is_constant
)
def preprocessing_pattern():
    """Build the pattern ``multiply(*, const) -> round -> clip -> cast``.

    Returns:
        The dataflow pattern whose root is the final ``cast`` call.
    """
    pattern = is_op("multiply")(wildcard(), is_constant())
    # Wrap the multiply in the remaining unary ops, innermost first.
    for op_name in ("round", "clip", "cast"):
        pattern = is_op(op_name)(pattern)
    return pattern

def output_pattern():
    """Build the pattern ``cast(*) -> multiply(<cast>, const)``.

    Returns:
        The dataflow pattern whose root is the ``multiply`` call.
    """
    casted = is_op("cast")(wildcard())
    return is_op("multiply")(casted, is_constant())

def pad_reshape_transpose_pattern():
    """Build the pattern ``nn.pad -> (reshape) -> transpose -> (broadcast_to)``.

    Ops in parentheses are optional: each is expressed as an alternative
    between the wrapped and the unwrapped sub-pattern.

    Returns:
        The composed dataflow pattern.
    """
    padded = is_op("nn.pad")(wildcard(), wildcard())
    maybe_reshaped = is_op("reshape")(padded) | padded
    transposed = is_op("transpose")(maybe_reshaped)
    return is_op("broadcast_to")(transposed) | transposed

def _maybe_simulated_quantize(node):
    """Return a pattern for ``node`` with or without a ``simulated_quantize`` wrapper.

    The wrapped form is the left-hand alternative; the three extra operands of
    ``simulated_quantize`` are required to be constants.
    """
    wrapped = is_op("relay.op.annotation.simulated_quantize")(
        node, is_constant(), is_constant(), is_constant()
    )
    return wrapped | node


def conv_add_activate_pattern():
    r"""Create a pattern matching a conv2d followed by optional epilogue ops.

    Stages in parentheses are optional::

        nn.conv2d(x, w)
            |
          (add)                               second bias
            |
        (fixed_point_multiply -> cast)
            |
          (add)                               bias
            |
        (nn.relu | clip | nn.prelu | sigmoid)
            |
        one of:
          (simulated_quantize) -> (annotation.cast_hint)
          cast -> fixed_point_multiply -> clip -> cast -> cast
            |
        (annotation.stop_fusion)

    Both bias operands may themselves be wrapped in ``simulated_quantize``.

    Returns:
        The composed dataflow pattern.
    """
    data = wildcard()
    weight = wildcard()
    bias = _maybe_simulated_quantize(wildcard())
    bias2 = _maybe_simulated_quantize(wildcard())
    # NOTE(review): the PReLU slope is matched by a bare wildcard. The previous
    # code also built a simulated_quantize-wrapped variant for it but never
    # used it; confirm whether it should be wrapped like the biases.
    alpha = wildcard()

    conv_node = is_op("nn.conv2d")(data, weight)
    conv_node = is_op("add")(conv_node, bias2) | conv_node

    rescaled = is_op("fixed_point_multiply")(conv_node)
    rescaled = is_op("cast")(rescaled)
    conv_node = rescaled | conv_node
    r = is_op("add")(conv_node, bias) | conv_node

    # Activation functions: each alternative is "r" or "activation(r)".
    r1 = r.optional(lambda node: is_op("nn.relu")(node))
    r2 = r.optional(lambda node: is_op("clip")(node))  # relu6
    r3 = r.optional(lambda node: is_op("nn.prelu")(node, alpha))  # prelu
    r4 = r.optional(lambda node: is_op("sigmoid")(node))  # sigmoid
    r = r1 | r2 | r3 | r4

    # Tail variant 1: optional simulated_quantize, then optional cast_hint.
    r_s = _maybe_simulated_quantize(r)
    r_s = is_op("annotation.cast_hint")(r_s) | r_s

    # Tail variant 2: cast -> fixed_point_multiply -> clip -> cast -> cast.
    r_q = is_op("cast")(r)
    r_q = is_op("fixed_point_multiply")(r_q)
    r_q = is_op("clip")(r_q)
    r_q = is_op("cast")(r_q)
    r_q = is_op("cast")(r_q)

    r = r_s | r_q
    r = is_op("annotation.stop_fusion")(r) | r
    return r

# (composite name, pattern) pairs, in the order they are handed to
# relay.transform.MergeComposite. Each builder is called once, top to bottom.
pattern_table = [
    (composite_name, build_pattern())
    for composite_name, build_pattern in (
        ("vta_preprocessing", preprocessing_pattern),
        ("vta_reshape_transpose", pad_reshape_transpose_pattern),
        ("vta_conv2d", conv_add_activate_pattern),
        ("vta_output", output_pattern),
    )
]

实现算子融合：

In [8]:
import vta

# Read the packing factors from the active VTA environment.
env = vta.get_env()
bfactor, cfactor, weight_bits = env.BATCH, env.BLOCK_OUT, env.WGT_WIDTH

# Pack the main function of a deep copy of the quantized module and wrap the
# result in a fresh IRModule.
run_mod = deepcopy(mod)
new_fn = graph_pack(run_mod["main"], bfactor, cfactor, weight_bits)
run_mod = tvm.IRModule.from_expr(new_fn)

prepare_transform = tvm.transform.Sequential(
    [
        relay.transform.InferType(),
        # Fold constant parameters.
        relay.transform.FoldConstant(),
        # Operator fusion driven by the pattern table.
        relay.transform.MergeComposite(pattern_table),
        # Attach ConvAttrs attributes to the fused vta_conv2d functions.
        WithVTAFunctionTransform(),
        relay.transform.InferType(),
    ]
)

with tvm.transform.PassContext(opt_level=3):
    run_mod = prepare_transform(run_mod)
run_mod.show()