# 自定义 VTA Graph Pack

In [1]:
from copy import deepcopy
import tvm
from tvm import relay
from vta_utils.pack_tool import ExprGraphPack # WithVTAFunctionTransform

## VTA 模型样例

In [2]:
from torch import nn
import torch

class Model(nn.Module):
    """Toy network used for the VTA graph-pack experiments.

    ``forward`` currently applies only the 3 -> 32 channel 3x3 convolution
    (stride 1, padding 1, no bias). ``bn`` and ``relu`` are still constructed,
    so the BatchNorm parameters/buffers are registered on the module, but they
    are bypassed in ``forward``.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.conv = nn.Conv2d(
            in_channels=3,
            out_channels=32,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(num_features=32)
        self.relu = nn.ReLU()

    def forward(self, x):
        # bn and relu are deliberately skipped; to exercise conv -> bn -> relu,
        # return self.relu(self.bn(self.conv(x))) instead.
        return self.conv(x)

# Build the toy PyTorch model, trace it, import it into Relay and quantize it.
# NOTE: statement order matters here: Model() initialisation and torch.randn
# both draw from torch's global RNG.
pt_model = Model().eval().float()
ishape = (1, 3, 4, 4)  # torch NCHW input; C=3 matches Conv2d(3, ...)
input_name = "data"
input_shapes = [(input_name, ishape)]
# script_module = torch.jit.script(pt_model)
# mod, params = relay.frontend.from_pytorch(script_module, input_shapes)
idata = torch.randn(ishape).type(torch.float32)  # example input, used only for tracing
traced_model = torch.jit.trace(pt_model, idata)
# Convert traced_model into a TVM Relay module (mod) plus its parameters (params).
mod, params = relay.frontend.from_pytorch(traced_model, input_shapes)
# Quantize. skip_conv_layers=[] exempts no conv layer from quantization;
# weight_scale="max" selects the max-based weight scale (see relay.quantize.qconfig).
with tvm.transform.PassContext(opt_level=3):
    with relay.quantize.qconfig(skip_conv_layers=[], weight_scale="max",):
        mod = relay.quantize.quantize(mod, params)
mod.show()

## VTA Graph Pack

In [3]:
# from tvm.relay import op
# from tvm.relay.op import op as _op
# from tvm.relay import ExprMutator
# from tvm.relay.expr import Call
# from tvm.ir.op import Op
from tvm.relay.function import Function
# from vta.top.graphpack import (
#     _channel_const_match,
#     _to_shape,
#     _get_tensor_type,
#     _pack_weight,
#     _weight_shape_match,
#     _pack_weight_conv2d_transpose,
#     _weight_shape_match_transpose,
#     _pack_const,
#     _const_shape_match,
# )
from tvm.relay.testing import run_opt_pass

In [4]:
import vta

# Run the ExprGraphPack mutator over `main`, rebuild the function and re-type it.
# NOTE(review): the recorded run of this cell raised AssertionError, presumably
# inside transform.visit. The conv input here has only 3 channels, which may
# not be a multiple of the VTA block size expected by the packing helpers.
# Confirm before relying on this cell.
env = vta.get_env()
bfactor = env.BATCH
cfactor = env.BLOCK_OUT
weight_bits = env.WGT_WIDTH

run_mod = deepcopy(mod)
transform = ExprGraphPack(bfactor, cfactor, weight_bits)
new_fn = transform.visit(run_mod["main"])
new_body = new_fn.body
# ret_type is intentionally left as None. Nodes rebuilt by a mutator have no
# checked_type until type inference runs, so reading new_body.checked_type
# here can raise ValueError. The InferType pass below computes the return
# type instead.
new_fn = Function(
    list(new_fn.params), new_body,
    ret_type=None,
    type_params=new_fn.type_params,
    attrs=new_fn.attrs,
    span=new_fn.span,
)
new_fn = run_opt_pass(new_fn, relay.transform.InferType())
tvm.IRModule.from_expr(new_fn).show()

AssertionError: 

In [None]:
# (empty scratch cell: a stray `www` expression was removed here; it raised NameError)

## VTA 模型的算子融合

创建融合策略：

In [None]:
from tvm.relay.dataflow_pattern import (
    # TuplePattern, TupleGetItemPattern, 
    is_op, wildcard, is_constant
)
def preprocessing_pattern():
    """Return the ``vta_preprocessing`` dataflow pattern.

    Matches ``cast(clip(round(multiply(x, c))))``, where ``x`` is any
    expression and ``c`` must be a constant.
    """
    scaled = is_op("multiply")(wildcard(), is_constant())
    rounded = is_op("round")(scaled)
    clipped = is_op("clip")(rounded)
    return is_op("cast")(clipped)

def output_pattern():
    """Return the ``vta_output`` dataflow pattern: ``multiply(cast(x), c)``.

    ``x`` is any expression and ``c`` must be a constant.
    """
    return is_op("multiply")(is_op("cast")(wildcard()), is_constant())

def conv_add_activate_pattern():
    r"""Create the ``vta_conv2d`` dataflow pattern for a (quantized) conv2d block.

    Stages in parentheses are optional::

        nn.conv2d(x, w)
            |
        (add bias2)            bias2 bare or wrapped in simulated_quantize
            |
        (fixed_point_multiply -> cast)
            |
        (add bias)             bias bare or wrapped in simulated_quantize
            |
        (nn.relu | clip | nn.prelu(alpha) | sigmoid)    clip is the relu6 case
            |
        either  (simulated_quantize) -> (annotation.cast_hint)
        or      cast -> fixed_point_multiply -> clip -> cast -> cast
            |
        (annotation.stop_fusion)

    Returns
    -------
    tvm.relay.dataflow_pattern.DFPattern
        The combined pattern, for use in a MergeComposite pattern table.
    """
    sq_op = "relay.op.annotation.simulated_quantize"

    def maybe_quantized(p):
        # Match simulated_quantize(p, const, const, const), or `p` on its own.
        return is_op(sq_op)(p, is_constant(), is_constant(), is_constant()) | p

    x = wildcard()
    w = wildcard()
    bias = wildcard()
    bias2 = wildcard()
    alpha = wildcard()

    bias_ = maybe_quantized(bias)
    bias2_ = maybe_quantized(bias2)

    conv_node = is_op("nn.conv2d")(x, w)
    conv_node = is_op("add")(conv_node, bias2_) | conv_node

    # Optional requantization of the conv result.
    requantized = is_op("cast")(is_op("fixed_point_multiply")(conv_node))
    conv_node = requantized | conv_node
    r = is_op("add")(conv_node, bias_) | conv_node

    # Activation functions; each alternative also allows "no activation".
    r1 = r.optional(lambda p: is_op("nn.relu")(p))
    r2 = r.optional(lambda p: is_op("clip")(p))  # relu6
    # NOTE(review): unlike the biases, alpha is not wrapped in maybe_quantized
    # (the original built such a variant but never used it). Confirm which
    # is intended.
    r3 = r.optional(lambda p: is_op("nn.prelu")(p, alpha))  # prelu
    r4 = r.optional(lambda p: is_op("sigmoid")(p))  # sigmoid
    r = r1 | r2 | r3 | r4

    # Output form A: optional simulated_quantize, then optional cast_hint.
    r_s = maybe_quantized(r)
    r_s = is_op("annotation.cast_hint")(r_s) | r_s

    # Output form B: explicit cast/requantize/clip/cast chain.
    r_q = is_op("cast")(r)
    r_q = is_op("fixed_point_multiply")(r_q)
    r_q = is_op("clip")(r_q)
    r_q = is_op("cast")(r_q)
    r_q = is_op("cast")(r_q)

    r = r_s | r_q
    return is_op("annotation.stop_fusion")(r) | r

# MergeComposite table: (composite function name, dataflow pattern) pairs,
# with the pattern builders invoked in the order listed.
pattern_table = [
    (composite_name, build_pattern())
    for composite_name, build_pattern in (
        ("vta_preprocessing", preprocessing_pattern),
        ("vta_conv2d", conv_add_activate_pattern),
        ("vta_output", output_pattern),
    )
]

实现算子融合：

In [None]:
# from vta_utils.pack_tool import VTAGraphPackTransform

In [None]:
import vta

env = vta.get_env()
bfactor, cfactor, weight_bits = env.BATCH, env.BLOCK_OUT, env.WGT_WIDTH

# Fuse operators into composite functions via pattern_table, then re-infer types.
# NOTE(review): WithVTAFunctionTransform is not imported in any visible cell
# (its import in the first cell is commented out), so this raises NameError
# unless it is defined elsewhere.
prepare_transform = tvm.transform.Sequential(
    [
        relay.transform.InferType(),
        relay.transform.MergeComposite(pattern_table),  # operator fusion
        WithVTAFunctionTransform(),  # add ConvAttrs attributes to the fused vta_conv2d functions
        # VTAGraphPackTransform(bfactor, cfactor, weight_bits),
        relay.transform.InferType(),
    ]
)
with tvm.transform.PassContext(opt_level=3):
    run_mod = prepare_transform(deepcopy(mod))
run_mod.show()

In [None]:
import tvm
from tvm import relay
from tvm.relay.expr import Call, Let, Var
from tvm.relay.function import Function, FunctionWithFields
from tvm.relay.op.annotation import compiler_begin, compiler_end
from tvm.relay import op
from tvm.relay.op import op as _op
from tvm.relay.testing import run_opt_pass
from tvm.relay.dataflow_pattern import (
    # TuplePattern, TupleGetItemPattern, 
    is_op, wildcard, is_constant
)
from tvm.relay import ExprMutator #, ExprVisitor
from tvm.ir.op import Op
from vta.top.graphpack import (
    _to_shape,
    # _unpack_batch_channel,
    _channel_const_match,
    _const_shape_match,
    _weight_shape_match,
    # # _weight_shape_match_transpose, # 新增
    _pack_weight,
    # _pack_weight_conv2d_transpose,
    # # _pack_const, # 被修改
    _get_tensor_shape,
    _get_tensor_type,
)
from tvm.relay.expr import GlobalVar, Let
from tvm.relay.function import Function, FunctionWithFields
from vta_utils.utils import _pack_batch_channel
from tvm.relay.testing import run_opt_pass

In [None]:
# Fuse operators, then fetch one lifted composite function by name.
# This cell runs the passes without an explicit PassContext.
prepare_transform = tvm.transform.Sequential(
    [
        relay.transform.InferType(),
        relay.transform.MergeComposite(pattern_table),  # operator fusion
        WithVTAFunctionTransform(),  # add ConvAttrs attributes to the fused vta_conv2d functions
        # VTAGraphPackTransform(bfactor, cfactor, weight_bits),
        relay.transform.InferType(),
    ]
)
run_mod = prepare_transform(deepcopy(mod))
# NOTE(review): "vta_conv2d__1" is a hard-coded global name; it assumes the
# transform pipeline lifts composites into module globals with this naming.
new_fn = run_mod["vta_conv2d__1"]

# new_fn = run_opt_pass(new_fn, relay.transform.InferType())

In [None]:
# Rewrite the lifted VTA composite functions of the module into packed layout.
# NOTE(review): neither WithVTAFunctionTransform nor PackConv2dMutator is
# imported in any visible cell (the WithVTAFunctionTransform import in the
# first cell is commented out). Confirm where they are defined.
prepare_transform = tvm.transform.Sequential([
    relay.transform.InferType(),
    relay.transform.MergeComposite(pattern_table),  # operator fusion
    WithVTAFunctionTransform(),  # add ConvAttrs attributes to the fused vta_conv2d functions
    # VTAGraphPackTransform(bfactor, cfactor, weight_bits),
    relay.transform.InferType(),
])

run_mod = deepcopy(mod)
run_mod = prepare_transform(run_mod)
global_vars = [vv for vv in run_mod.get_global_vars() if vv.name_hint != "main"]
for op_var in global_vars:
    new_fn = run_mod[op_var]
    # Output shape taken from the function's inferred return type.
    oshape = list(_to_shape(new_fn.checked_type.ret_type.shape))
    new_body = new_fn.body
    if "vta_preprocessing" in op_var.name_hint:
        assert new_fn.attrs["Composite"] == "vta_preprocessing"
        # Pack the preprocessing result by (bfactor, cfactor).
        new_body = _pack_batch_channel(new_body, oshape, bfactor, cfactor)
    elif "vta_conv2d" in op_var.name_hint:
        assert new_fn.attrs["Composite"] == "vta_conv2d"
        transform = PackConv2dMutator(bfactor, cfactor, weight_bits)
        new_fn = transform.visit(new_fn)
        new_body = new_fn.body
    else:
        # TODO: vta_output still needs its parameters packed and its result
        # unpacked (cf. the commented-out _unpack_batch_channel import from
        # vta.top.graphpack). Other functions are left unchanged for now.
        continue
    # ret_type is intentionally None. The rewritten body is made of new nodes
    # whose checked_type is not populated until type inference runs, so
    # reading new_body.checked_type can raise ValueError. InferType below
    # computes the return type.
    new_fn = Function(
        list(new_fn.params), new_body,
        ret_type=None,
        type_params=new_fn.type_params,
        attrs=new_fn.attrs,
        span=new_fn.span,
    )
    run_mod[op_var] = run_opt_pass(new_fn, relay.transform.InferType())

In [None]:
# Print whatever `new_fn` was last bound to by an earlier cell (e.g. the last
# function handled by the packing loop).
print(new_fn)

In [None]:
# Display the current run_mod as Relay text.
run_mod.show()