In [2]:
from __future__ import annotations
import numpy as np
import tvm
from MLC import mlc
from tvm import te, relax, IRModule
from tvm.script import relax as R
from tvm.script import tir as T

In [3]:
import torch
import torch.nn as nn
from torch import fx

In [4]:
class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        self.w = nn.Parameter(torch.rand((900, 128)))
        self.b = nn.Parameter(torch.rand((128,)))
        self.conv2d = nn.Conv2d(1, 1, 3, 1)
        self.relu = nn.ReLU()
        self.linear = nn.Linear(128, 10)
    
    def forward(self, x):
        x = self.conv2d(x)
        x = torch.relu(x)
        x = x.view([1, -1])
        x = torch.matmul(x, self.w)
        x = torch.add(x, self.b)
        x = self.relu(x)
        x = self.linear(x)
        return x


In [5]:
DemoRelax = mlc.from_fx(Demo(), [(1, 1, 32, 32)])
# DemoRelax.show()

In [6]:

DemoFused = mlc.FuseDenseAddPass()(DemoRelax)

In [7]:

DemoModelTIR = mlc.LowerToTensorIRPass()(DemoFused)

In [8]:

DemoModelFinal = relax.transform.FuseTIR()(DemoModelTIR)

In [9]:
hw = 32
nrepeat = 10000
x = np.random.rand(1, 1, hw, hw).astype('float32')
x_torch_cuda = torch.from_numpy(x)
demo_cuda = Demo()
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(nrepeat):
    demo_cuda(x_torch_cuda)
end.record()

print(start.elapsed_time(end) / nrepeat)

0.369782421875


## torch ---> IRModule

### Step 1: TorchFX GraphModule

nn.Module ---> torch.fx.graph_module.GraphModule

In [10]:
# TorchFx Graphmodule
model = Demo()
fx_module = fx.symbolic_trace(model)
type(fx_module)

torch.fx.graph_module.GraphModule.__new__.<locals>.GraphModuleImpl

In [11]:
fx_module.graph.print_tabular()

opcode         name    target                                                     args             kwargs
-------------  ------  ---------------------------------------------------------  ---------------  --------
placeholder    x       x                                                          ()               {}
call_module    conv2d  conv2d                                                     (x,)             {}
call_function  relu    <built-in method relu of type object at 0x7f4e31692ec0>    (conv2d,)        {}
call_method    view    view                                                       (relu, [1, -1])  {}
get_attr       w       w                                                          ()               {}
call_function  matmul  <built-in method matmul of type object at 0x7f4e31692ec0>  (view, w)        {}
get_attr       b       b                                                          ()               {}
call_function  add     <built-in method add of type object at 0x7f4e3169

In [12]:
dict(fx_module.named_modules())
# for i in fx_module.graph.nodes:
#     print(i.op)
# getattr(fx_module, 'linear').bias
# type(relax.Var('x', R.Tensor((1,10))))

{'': Demo(
   (conv2d): Conv2d(1, 1, kernel_size=(3, 3), stride=(1, 1))
   (relu): ReLU()
   (linear): Linear(in_features=128, out_features=10, bias=True)
 ),
 'conv2d': Conv2d(1, 1, kernel_size=(3, 3), stride=(1, 1)),
 'relu': ReLU(),
 'linear': Linear(in_features=128, out_features=10, bias=True)}

### Step 3: 构造映射函数

In [13]:
def map_params(param: nn.Parameter):
    return relax.const(param.data.cpu().numpy(), dtype='float32')

def fetch_attr(fx_mod, target: str):
    # 获取mod属性
    target_atoms = target.split('.')
    attr_itr = fx_mod
    for i, atom in enumerate(target_atoms):
        if not hasattr(attr_itr, atom):
            raise RuntimeError(f"Node referenced nonexistant target {'.'.join(target_atoms[:i])}")
        attr_itr = getattr(attr_itr, atom)
    return attr_itr

def from_fx(fx_module: torch.fx.GraphModule, input_shapes, call_map_function, call_map_module, call_map_method):
    ''' 根据映射将torchgraph ---> relax function
        1. function 参数直接来自于node_map
        2. module 带参数+计算
        
        example:
        with bb.function('main'):
            with bb.dataflow():
                (var = bb.emt_)
                bb.emit_output()
        bb.emit_func_output(fn_output, fn_inputs)
    '''
    input_index = 0
    node_map = {}
    named_modules = dict(fx_module.named_modules())

    bb = relax.BlockBuilder()
    fn_inputs = []
    fn_output = None
    with bb.function('main'):
        with bb.dataflow():
            for node in fx_module.graph.nodes:
                if node.op == 'placeholder':
                    input_shape = input_shapes[input_index]
                    input_index = input_index + 1
                    fn_input = relax.Var(node.target, R.Tensor(input_shape, 'float32'))
                    fn_inputs.append(fn_input)
                    node_map[node] = fn_input
                elif node.op == 'get_attr':
                    node_map[node] = map_params(fetch_attr(fx_module, node.target))
                elif node.op == 'call_function':
                    node_map[node] = call_map_function[node.target](bb, node_map, node)
                # --------------------- add call_method ------------------------#
                elif node.op == 'call_method':
                    node_map[node] = call_map_method[node.target](bb, node_map, node)
                # --------------------------------------------------------------#
                elif node.op == 'call_module':
                    nn_module = named_modules[node.target]
                    # node_map[node] = call_map_module[nn_module](bb, node_map, node, nn_module)
                    node_map[node] = call_map_module[type(nn_module)](bb, node_map, node, nn_module)
                elif node.op == 'output':
                    output = node_map[node.args[0]]
                    if fn_output is not None:
                        raise Warning("error")
                    fn_output = bb.emit_output(output)
        bb.emit_func_output(fn_output, fn_inputs)
    return bb.get()

In [14]:
# call_function
def map_matmul(bb: relax.BlockBuilder, node_map, node):
    x = node_map[node.args[0]]
    w = node_map[node.args[1]]
    return bb.emit(relax.op.matmul(x, w))

def map_relu(bb: relax.BlockBuilder, node_map, node):
    x = node_map[node.args[0]]
    return bb.emit(relax.op.nn.relu(x))

def map_add(bb: relax.BlockBuilder, node_map, node):
    x = node_map[node.args[0]]
    b = node_map[node.args[1]]
    return bb.emit(relax.op.add(x, b))

# call_module
def map_nn_relu(bb: relax.BlockBuilder, node_map, node, nn_module):
    return bb.emit(relax.op.nn.relu(node_map[node.args[0]]))

def map_nn_linear(bb: relax.BlockBuilder, node_map, node, nn_module):
    x = node_map[node.args[0]]
    w = map_params(nn_module.weight)
    bias = map_params(nn_module.bias)
    return bb.emit(relax.op.linear(x, w, bias))

def map_nn_conv2d(bb: relax.BlockBuilder, node_map, node, nn_module):
    x = node_map[node.args[0]]
    kernel = map_params(nn_module.weight)
    bias = None
    if nn_module.bias is not None:
        bias = map_params(nn_module.bias)
    stride = nn_module.stride
    conv_out = bb.emit(relax.op.nn.conv2d(x, kernel, stride))
    if not bias:
        return bb.emit(relax.op.add(conv_out, bias))
    return conv_out

# call_method
def map_view(bb: relax.BlockBuilder, node_map, node):
    x = node_map[node.args[0]]
    shape = node.args[1]
    return bb.emit(relax.op.reshape(x, shape))

In [15]:
call_map_function = {
    torch.matmul: map_matmul,
    torch.relu: map_relu,
    torch.add: map_add
}

call_map_module = {
    torch.nn.Conv2d: map_nn_conv2d,
    torch.nn.ReLU: map_nn_relu,
    torch.nn.Linear: map_nn_linear
}

call_map_method = {
    'view': map_view
}

DemoModel = from_fx(fx_module, [(1, 1, hw, hw)], call_map_function, call_map_module, call_map_method)

# DemoModel.show(show_meta=False)

In [1]:
tvm.ir.assert_structural_equal(DemoRelax, DemoModel)

NameError: name 'tvm' is not defined

In [15]:
for i, func in DemoModel.functions.items():
    print(i, func, type(func))

I.GlobalVar("main") # from tvm.script import relax as R

@R.function
def main(x: R.Tensor((1, 1, 32, 32), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
    with R.dataflow():
        lv: R.Tensor((1, 1, 30, 30), dtype="float32") = R.nn.conv2d(x, metadata["relax.expr.Constant"][0], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
        lv1: R.Tensor((1, 1, 30, 30), dtype="float32") = R.nn.relu(lv)
        lv2: R.Tensor((1, 900), dtype="float32") = R.reshape(lv1, (1, 900))
        lv3: R.Tensor((1, 128), dtype="float32") = R.matmul(lv2, metadata["relax.expr.Constant"][1], out_dtype="void")
        lv4: R.Tensor((1, 128), dtype="float32") = R.add(lv3, metadata["relax.expr.Constant"][2])
        lv5: R.Tensor((1, 128), dtype="float32") = R.nn.relu(lv4)
        lv6: R.Tensor((128, 10), dtype="float32") = R.permute_dims(metadata["relax.expr.Constant"][3], axes=None)
        lv7: R.Tens

### Step 4: 图优化

eg: 融合matmul & add算子

In [16]:
def create_fuse_dense_add(call: relax.Call, value: relax.Call, value_pre=None, fn_name=None):
    b = call.args[1]
    x = value.args[0]
    w = value.args[1]
    if value_pre is not None:
        w = value_pre.args[0]

    params_x = relax.Var('x', x.struct_info)
    params_w = relax.Var('w', w.struct_info)
    params_b = relax.Var('b', b.struct_info)

    bb = relax.BlockBuilder()
    with bb.function(fn_name, [params_x, params_w, params_b]):
        with bb.dataflow():
            if value_pre is not None:
                lv0 = bb.emit(relax.op.linear(params_x, params_w))
            else:
                lv0 = bb.emit(relax.op.matmul(params_x, params_w))
            
            gv = bb.emit_output(bb.emit(relax.op.add(lv0, params_b)))
        bb.emit_func_output(gv)
    
    fused_fn = bb.get()[fn_name].with_attr("Primitive", 1)
    return fused_fn, x, w, b

@relax.expr_functor.mutator
class DenseAddFusor(relax.PyExprMutator):
    def __init__(self, mod: IRModule) -> None:
        super().__init__()
        self.mod_ = mod
        # cache pre-defined ops
        # permute -> matmul -> add
        self.add_op = tvm.ir.Op.get("relax.add")
        self.permute_op = tvm.ir.Op.get("relax.permute_dims")
        self.matmul_op = tvm.ir.Op.get("relax.matmul")
        self.counter = 0

    def transform(self) -> IRModule:
        for global_var, func in self.mod_.functions.items():
            if not isinstance(func, relax.Function):
                continue
            # avoid already fused primitive functions
            if func.attrs is not None and "Primitive" in func.attrs.keys() and func.attrs["Primitive"] != 0:
                continue
            updated_func = self.visit_expr(func)
            updated_func = relax.analysis.remove_all_unused(updated_func)
            self.builder_.update_func(global_var, updated_func)

        return self.builder_.get()

    def visit_call_(self, call):
        call = self.visit_expr_post_order(call)
        
        def match_call(node, op):
            if not isinstance(node, relax.Call):
                return False
            return node.op == op

        # pattern match dense => add
        if not match_call(call, self.add_op):
            return call

        value = self.lookup_binding(call.args[0])
        if value is None:
            return call

        if not match_call(value, self.matmul_op):
            return call
        
        fn_name = "fused_dense_add%d" % (self.counter)
        self.counter += 1
        if type(value.args[1]) == tvm.relax.expr.Constant:
            fused_fn, x, w, b= create_fuse_dense_add(call, value, fn_name=fn_name)
        else:
            value_1 = self.lookup_binding(value.args[1])
            if not match_call(value_1, self.permute_op):
                fused_fn, x, w, b = create_fuse_dense_add(call, value, fn_name=fn_name)
            else:
                fused_fn, x, w, b = create_fuse_dense_add(call, value, value_1, fn_name)
        # x = value.args[0]
        # w = value_1.args[0]
        # b = call.args[1]

        # # construct a new fused primitive function
        # param_x = relax.Var("x", x.struct_info)
        # param_w = relax.Var("w", w.struct_info)
        # param_b = relax.Var("b", b.struct_info)

        # bb = relax.BlockBuilder()

        # fn_name = "fused_dense_add%d" % (self.counter)
        # self.counter += 1
        # with bb.function(fn_name, [param_x, param_w, param_b]):
        #     with bb.dataflow():
        #         lv0 = bb.emit(relax.op.linear(param_x, param_w))
        #         gv = bb.emit_output(relax.op.add(lv0, param_b))
        #     bb.emit_func_output(gv)

        # # Add Primitive attribute to the fused funtions
        # fused_fn = bb.get()[fn_name].with_attr("Primitive", 1)
        global_var = self.builder_.add_func(fused_fn, fn_name)

        # construct call into the fused function
        return relax.Call(global_var, [x, w, b], None, None)

@tvm.ir.transform.module_pass(opt_level=2, name="DeseAddFuse")
class FuseDenseAddPass:
    """The wrapper for the LowerTensorIR pass."""
    def transform_module(self, mod, ctx):
        return DenseAddFusor(mod).transform()


DemoFused = FuseDenseAddPass()(DemoModel)
DemoFused.show()

To print formatted TVM script, please install the formatter 'Black':
/staff/qiaoliang/anaconda3/envs/MLC/bin/python -m pip install "black==22.3.0" --upgrade --user


: 

### Step 5: 映射到TensorIR Call

In [13]:
from tvm import topi
@relax.expr_functor.mutator
class LowerToTensorIR(relax.PyExprMutator):
    def __init__(self, mod: IRModule, op_map):
        super().__init__()
        self.mod_ = mod
        self.op_map = {
            tvm.ir.Op.get(k): v for k, v in op_map.items()
        }
        self.matmul_op = tvm.ir.Op.get('relax.matmul')
    
    def visit_call_(self, call: relax.Call):
        call = self.visit_expr_post_order(call)

        def match_call(node, op):
            if not isinstance(node, relax.Call):
                return False
            return node.op == op
        
        if call.op in self.op_map:
            call_pre = None
            if match_call(call, self.matmul_op):
                if isinstance(call.args[1], relax.expr.DataflowVar):
                    call_pre = self.lookup_binding(call.args[1])
            return self.op_map[call.op](self.builder_, call, call_pre)
        # print(call.op, type(call.op), call.args)
        return call

    def transform(self):
        for global_var, func in self.mod_.functions.items():
            if not isinstance(func, relax.Function):
                continue
            updated_fn = self.visit_expr(func)
            updated_fn = relax.analysis.remove_all_unused(updated_fn)
            self.builder_.update_func(global_var, updated_fn)

        return self.builder_.get()

# def map_attrs(n):
#     return (x.numpy() for x in n)

def map_add(bb: relax.BlockBuilder, call, call_pre=None):
    x, b = call.args
    return bb.call_te(topi.add, x, b)

def map_matmul(bb: relax.BlockBuilder, call, call_pre=None):
    x, w = call.args
    if call_pre and call_pre.op == tvm.ir.Op.get('relax.permute_dims'):
        w = call_pre.args[0]
        return bb.call_te(topi.nn.dense, x, w)
    return bb.call_te(topi.nn.matmul, x, w)

def map_relu(bb: relax.BlockBuilder, call, call_pre=None):
    x = call.args[0]
    return bb.call_te(topi.nn.relu, x)

def map_conv(bb: relax.BlockBuilder, call: relax.Call, call_pre=None):
    x, k = call.args
    attrs = call.attrs

    return bb.call_te(topi.nn.conv2d, x, k, attrs.strides, attrs.padding, attrs.dilation)

def map_reshape(bb: relax.BlockBuilder, call: relax.Call, call_pre=None):
    x, shape = call.args
    return bb.call_te(topi.reshape, x, shape)

op_map = {
    'relax.matmul': map_matmul,
    'relax.add': map_add,
    'relax.nn.relu': map_relu,
    'relax.nn.conv2d': map_conv,
    'relax.reshape': map_reshape
}

@tvm.ir.transform.module_pass(opt_level=0, name="LowerToTensorIR")
class LowerToTensorIRPass:
    """The wrapper for the LowerTensorIR pass."""
    def transform_module(self, mod, ctx):
        return LowerToTensorIR(mod, op_map).transform()


DemoModelTIR = LowerToTensorIRPass()(DemoFused)
# DemoModelTIR.show()


To print formatted TVM script, please install the formatter 'Black':
/staff/qiaoliang/anaconda3/envs/MLC/bin/python -m pip install "black==22.3.0" --upgrade --user


In [14]:
# relax.op.nn.conv2d

In [20]:
# 融合R.function 与 T.prim_func
DemoModelFinal = relax.transform.FuseTIR()(DemoModelTIR)
# DemoModelFinal.show()

In [21]:
from MLC import mlc
mod_tn = mlc.mlc_tune_tir(DemoModelFinal,target="cuda --max_threads_per_block=1024 --max_shared_memory_per_block=49152",
                           work_dir="./tune_tmp/",
                            task_name='main',
                            max_trials_global=1,
                            num_trials_per_iter=1, compile_tir_target='cuda')

2023-02-13 23:04:24 [INFO] [task_scheduler.cc:260] Task #0 has finished. Remaining task(s): 0


Unnamed: 0,Name,FLOP,Weight,Speed (GFLOPS),Latency (us),Weighted Latency (us),Trials,Done
0,main,1,1,0.0003,2.9171,2.9171,1,Y


2023-02-13 23:04:24 [DEBUG] [task_scheduler.cc:318] 
 ID | Name | FLOP | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Done 
---------------------------------------------------------------------------------------------------
  0 | main |    1 |      1 |         0.0003 |       2.9171 |                2.9171 |      1 |    Y 
---------------------------------------------------------------------------------------------------
Total trials: 1
Total latency (us): 2.91708



In [23]:
# mod_tn.show()

### Step 6 算子优化

In [16]:
x = np.random.rand(1, 1, hw, hw).astype('float32')
x_nd_cpu = tvm.nd.array(x)
x_nd_cuda = tvm.nd.array(x, tvm.cuda(0))
DemoModelFinal_ = DemoModelFinal
ex_cpu = relax.vm.build(DemoModelFinal_, target='llvm')
vm_cpu = relax.VirtualMachine(ex_cpu, tvm.cpu(0))

f_timer_cpu = vm_cpu.time_evaluator('main', tvm.cpu(0), number=nrepeat)

print("MyModuleWithParams2 time-cost: %g ms" % (f_timer_cpu(x_nd_cpu).mean * 1000))

MyModuleWithParams2 time-cost: 0.307037 ms


In [20]:
# API只支持一个main的IRModule

In [17]:
# for i in DemoModelFinal.functions:
#     print(i.name_hint)

fn_names = [x.name_hint for x in DemoModelFinal.functions]
fn_names.remove('main')
fn_names

['relu1', 'fused_dense_add1', 'conv2d', 'fused_dense_add0', 'relu', 'reshape']

In [18]:
from tvm import meta_schedule as ms

for fn_name in fn_names:
    print(fn_name)
    mod_ = tvm.IRModule.from_expr(DemoModelFinal[fn_name].with_attr("global_symbol", 'main'))


    tuned_record = ms.tune_tir(mod_, target="cuda --max_threads_per_block=1024 --max_shared_memory_per_block=49152",
                           work_dir="./tune_tmp/",
                            task_name='main',
                            max_trials_global=1,
                            num_trials_per_iter=1)
    
    tuned_sch = ms.tir_integration.compile_tir(tuned_record, mod_, target='cuda')
    new_func = tuned_sch.mod['main'].with_attr("global_symbol", fn_name)
    gv = DemoModelFinal.get_global_var(fn_name)
    DemoModelFinal.update_func(gv, new_func)
    # break



2023-02-13 23:01:23 [INFO] [task_scheduler.cc:260] Task #0 has finished. Remaining task(s): 0


Unnamed: 0,Name,FLOP,Weight,Speed (GFLOPS),Latency (us),Weighted Latency (us),Trials,Done
0,main,1,1,0.0003,3.1002,3.1002,1,Y


2023-02-13 23:01:23 [DEBUG] [task_scheduler.cc:318] 
 ID | Name | FLOP | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Done 
---------------------------------------------------------------------------------------------------
  0 | main |    1 |      1 |         0.0003 |       3.1002 |                3.1002 |      1 |    Y 
---------------------------------------------------------------------------------------------------
Total trials: 1
Total latency (us): 3.10024



In [19]:
DemoModelFinal_cuda = DemoModelFinal_
DemoModelFinal_cuda.show()

To print formatted TVM script, please install the formatter 'Black':
/staff/qiaoliang/anaconda3/envs/MLC/bin/python -m pip install "black==22.3.0" --upgrade --user


In [24]:
ex_cuda = relax.vm.build(DemoModelFinal_cuda, target='cuda')
vm_cuda = relax.VirtualMachine(ex_cuda, tvm.cuda(0))


f_timer_cuda = vm_cuda.time_evaluator('main', tvm.cuda(0), number=nrepeat)

print("Demo in cuda time-cost: %g ms" % (f_timer_cuda(x_nd_cuda).mean * 1000))

Demo in cuda time-cost: 0.023727 ms


## Demo
torch 0.34713 -> cpu 0.262574 ms -> auto tune 0.0233462 ms