In [2]:
from __future__ import annotations
import numpy as np
import tvm
from tvm import relax
from tvm.script import relax as R
from tvm.script import tir as T
from tvm.ir import IRModule

In [1]:
import torch
import torch.nn as nn
import torch.fx as fx

In [269]:
class MyModule(nn.Module):
    """Toy CNN used as the PyTorch reference for the Relax lowering below.

    Pipeline: conv(3x3, no bias) -> relu -> flatten to (1, -1) -> @ weight
    -> + bias -> ReLU module -> Linear(10 -> 3).  `weight` has 64 rows, which
    matches a 1x1x10x10 input (a 3x3 valid conv gives 8x8 = 64 features).
    """

    def __init__(self) -> None:
        super().__init__()
        # Creation order is significant: it fixes the RNG draws and the
        # module/parameter registration order.
        self.weight = nn.Parameter(torch.rand(64, 10))
        self.bias = nn.Parameter(torch.rand(10))
        self.linear1 = nn.Linear(in_features=10, out_features=3, bias=True)
        self.relu = nn.ReLU()
        self.conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3, stride=1, bias=False)

    def forward(self, x):
        # Each call below becomes one fx node; the call forms (torch.relu,
        # Tensor.view, torch.matmul, torch.add, module calls) are what the
        # converter tables key on.
        feat = torch.relu(self.conv(x))
        flat = feat.view((1, -1))
        hidden = torch.add(torch.matmul(flat, self.weight), self.bias)
        return self.linear1(self.relu(hidden))


In [272]:
# Seed both RNGs before building the model and the input so that the PyTorch
# parameters, `data` and `torch_out` are reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

torch_mod = MyModule()
hw = 10  # spatial size of the square input image
data = np.random.rand(1, 1, hw, hw).astype('float32')

# PyTorch reference output; compared against the TVM-compiled module later.
torch_out = torch_mod(torch.from_numpy(data))

In [273]:
from tvm import te

In [274]:
def te_matmul(x: te.Tensor, w: te.Tensor):
    """2-D matrix product ``x @ w`` expressed as a TE compute named 'matmul'.

    x is indexed as (n, k) and w as (k, m); the result has shape (n, m).
    The reduction runs over ``x.shape[1]``.
    """
    rows, cols = x.shape[0], w.shape[1]
    red = te.reduce_axis((0, x.shape[1]), 'k')

    def body(i, j):
        return te.sum(x[i, red] * w[red, j], axis=red)

    return te.compute((rows, cols), body, name='matmul')

def te_add(x: te.Tensor, b: te.Tensor):
    """Element-wise ``x + b`` as a TE compute named 'add'; output has ``x.shape``.

    Supported layouts:
      * ``b`` has the same rank as ``x``: plain element-wise add. The shapes are
        assumed equal; size-1 broadcasting is NOT handled.
      * ``b`` is 1-D: it is broadcast along the last axis of ``x``. This is the
        bias-add case, e.g. ``torch.add(x, bias)`` with a 1-D bias.

    Raises:
        ValueError: for any other rank combination.
    """
    x_rank, b_rank = len(x.shape), len(b.shape)
    if b_rank == x_rank:
        return te.compute(x.shape, lambda *idx: x(*idx) + b(*idx), name='add')
    if b_rank == 1:
        # Bias add: b is indexed only by the last output coordinate.
        return te.compute(x.shape, lambda *idx: x(*idx) + b(idx[-1]), name='add')
    raise ValueError(f"te_add: cannot broadcast b of rank {b_rank} against x of rank {x_rank}")

def te_relu(x: te.Tensor):
    """Element-wise ReLU, ``max(x, 0)``, for a tensor of any rank; output named 'relu'."""
    def clamp_at_zero(*i):
        return te.max(x(*i), 0)

    return te.compute(x.shape, clamp_at_zero, name='relu')

def te_view(x: te.Tensor):
    """Flatten a rank-4 tensor (n, c, h, w) to (n, c*h*w) in row-major order.

    Output column ``j`` reads element ``[j // (h*w), (j % (h*w)) // w, (j % (h*w)) % w]``
    of row ``i``.  The result is named 'view'.
    """
    n, c, h, w = x.shape
    plane = h * w
    flat_len = np.prod(x.shape) // n

    def gather(i, j):
        in_plane = j % plane
        return x[i, j // plane, in_plane // w, in_plane % w]

    return te.compute((n, flat_len), gather, name='view')

def te_conv2d(x: te.Tensor, kernel: te.Tensor):
    """Direct 2-D convolution: stride 1, no padding, no dilation, no bias.

    Args:
        x: input indexed as (n, ci, h, w).
        kernel: weights in the PyTorch ``nn.Conv2d.weight`` layout
            (co, ci, kh, kw), i.e. output channels first.

    Returns:
        Tensor of shape (n, co, h - kh + 1, w - kw + 1) named 'conv'.
    """
    n, ci, h, w = x.shape
    # Output channels are dim 0 of a PyTorch conv weight. The previous code read
    # dim 1 (in_channels), which is only correct when ci == co.
    co, _, kh, kw = kernel.shape
    out_h, out_w = h - kh + 1, w - kw + 1
    # Reduce axes get distinct names so they no longer shadow the size integers.
    rc = te.reduce_axis((0, ci), name='ci')
    rkh = te.reduce_axis((0, kh), name='k1')
    rkw = te.reduce_axis((0, kw), name='k2')
    return te.compute(
        (n, co, out_h, out_w),
        lambda i, j, k, l: te.sum(
            x[i, rc, k + rkh, l + rkw] * kernel[j, rc, rkh, rkw], axis=[rc, rkh, rkw]
        ),
        name='conv',
    )

In [275]:
# A = te.placeholder((1,1,10,10), dtype='float32')
# K = te.placeholder((1, 1, 3, 3), dtype='float32')
# C = te_conv2d(A, K)
# print(C.shape)
# te.create_prim_func([A, K, C]).show()


In [276]:
# Trace the PyTorch model into a torch.fx GraphModule and show its graph as a table.
fx_module = fx.symbolic_trace(root=torch_mod)
fx_module.graph.print_tabular()

opcode         name     target                                                     args             kwargs
-------------  -------  ---------------------------------------------------------  ---------------  --------
placeholder    x        x                                                          ()               {}
call_module    conv     conv                                                       (x,)             {}
call_function  relu     <built-in method relu of type object at 0x7f8d3b810ec0>    (conv,)          {}
call_method    view     view                                                       (relu, (1, -1))  {}
get_attr       weight   weight                                                     ()               {}
call_function  matmul   <built-in method matmul of type object at 0x7f8d3b810ec0>  (view, weight)   {}
get_attr       bias     bias                                                       ()               {}
call_function  add      <built-in method add of type object at 

In [283]:
# Show each fx node together with the Python type of its `target`. In the output
# below, call_function targets are builtin callables; the other targets are strings.
for graph_node in fx_module.graph.nodes:
    print(graph_node, type(graph_node), type(graph_node.target))

x <class 'torch.fx.node.Node'> <class 'str'>
conv <class 'torch.fx.node.Node'> <class 'str'>
relu <class 'torch.fx.node.Node'> <class 'builtin_function_or_method'>
view <class 'torch.fx.node.Node'> <class 'str'>
weight <class 'torch.fx.node.Node'> <class 'str'>
matmul <class 'torch.fx.node.Node'> <class 'builtin_function_or_method'>
bias <class 'torch.fx.node.Node'> <class 'str'>
add <class 'torch.fx.node.Node'> <class 'builtin_function_or_method'>
relu_1 <class 'torch.fx.node.Node'> <class 'str'>
linear1 <class 'torch.fx.node.Node'> <class 'str'>
output <class 'torch.fx.node.Node'> <class 'str'>


```python
with bb.function('main'):
    with bb.dataflow():
        y = bb.emit_te(func_name, input_)
        ...
        out = bb.emit_output(y)
    bb.emit_func_output(out, fn_inputs)

ir_module = bb.get()  # ---> IRModule
```

In [278]:
# Build the mapping functions (PyTorch parameters / fx nodes -> Relax).
def map_params(param: nn.Parameter):
    """Return the parameter's values, taken as a CPU NumPy array, as a float32 Relax constant."""
    host_array = param.data.cpu().numpy()
    return relax.const(host_array, dtype='float32')

def fetch_attr(fx_mod, target: str):
    """Resolve a dotted attribute path (e.g. ``'sub.weight'``) starting at ``fx_mod``.

    Used for torch.fx ``get_attr`` nodes, whose target is such a qualified name.

    Raises:
        RuntimeError: if a component of the path does not exist. The message
            names the path up to and including the missing component.
    """
    target_atoms = target.split('.')
    attr_itr = fx_mod
    for i, atom in enumerate(target_atoms):
        if not hasattr(attr_itr, atom):
            # Slice through `i + 1` so the missing atom is reported; `[:i]` only
            # showed the prefix that did resolve (empty for a top-level miss).
            raise RuntimeError(
                f"Node referenced nonexistent target {'.'.join(target_atoms[:i + 1])}"
            )
        attr_itr = getattr(attr_itr, atom)
    return attr_itr

def _lookup_converter(table, key, node):
    """Return ``table[key]``; on a miss, raise a KeyError that names the fx node."""
    try:
        return table[key]
    except KeyError:
        # The default KeyError shows only the key (often an opaque builtin repr),
        # which makes it hard to tell which graph node lacks a converter.
        raise KeyError(
            f"no converter registered for {node.op} node '{node.name}' (target {key!r})"
        ) from None


def from_fx(fx_module, input_shapes, call_function_map, call_module_map, call_method_map):
    """Translate a torch.fx-traced module into a Relax IRModule with one function, 'main'.

    Args:
        fx_module: torch.fx GraphModule (e.g. from ``fx.symbolic_trace``).
        input_shapes: one shape per ``placeholder`` node, in graph order. Every
            input is declared as a float32 tensor.
        call_function_map: ``{function: converter(bb, node_map, node)}``.
        call_module_map: ``{module class: converter(bb, node_map, node, module)}``.
        call_method_map: ``{method name: converter(bb, node_map, node)}``.

    Returns:
        The module returned by ``relax.BlockBuilder.get()``.

    Raises:
        KeyError: a node has no registered converter (the message names the node).
        IndexError: the graph has more placeholders than ``input_shapes`` entries.
        NotImplementedError: the graph contains an fx opcode not handled here.
    """
    fn_inputs = []
    fn_output = None
    input_index = 0

    node_map = {}  # fx node -> Relax value produced for it
    named_modules = dict(fx_module.named_modules())

    bb = relax.BlockBuilder()

    with bb.function('main'):
        with bb.dataflow():
            for node in fx_module.graph.nodes:
                if node.op == "placeholder":
                    if input_index >= len(input_shapes):
                        raise IndexError(
                            f"input_shapes has {len(input_shapes)} entries but the graph has "
                            f"more placeholders (first unmatched: '{node.name}')"
                        )
                    shape = input_shapes[input_index]
                    input_index += 1
                    fn_input = relax.Var(node.target, R.Tensor(shape=shape, dtype='float32'))
                    fn_inputs.append(fn_input)
                    node_map[node] = fn_input
                elif node.op == "get_attr":
                    node_map[node] = map_params(fetch_attr(fx_module, node.target))
                elif node.op == "call_function":
                    convert = _lookup_converter(call_function_map, node.target, node)
                    node_map[node] = convert(bb, node_map, node)
                elif node.op == "call_module":
                    named_module = named_modules[node.target]
                    convert = _lookup_converter(call_module_map, type(named_module), node)
                    node_map[node] = convert(bb, node_map, node, named_module)
                elif node.op == "call_method":
                    convert = _lookup_converter(call_method_map, node.target, node)
                    node_map[node] = convert(bb, node_map, node)
                elif node.op == "output":
                    output = node_map[node.args[0]]
                    assert fn_output is None
                    fn_output = bb.emit_output(output)
                else:
                    # Previously an unknown opcode was skipped silently.
                    raise NotImplementedError(
                        f"unsupported fx opcode {node.op!r} (node '{node.name}')"
                    )
        bb.emit_func_output(fn_output, fn_inputs)
    return bb.get()


In [279]:
# call_function
def map_matmul(bb, node_map, node):
    """Lower ``torch.matmul(a, b)`` to ``te_matmul`` via ``bb.emit_te``."""
    lhs, rhs = node_map[node.args[0]], node_map[node.args[1]]
    return bb.emit_te(te_matmul, lhs, rhs)

def map_relu(bb, node_map, node):
    """Lower ``torch.relu(t)`` to the element-wise ``te_relu``."""
    return bb.emit_te(te_relu, node_map[node.args[0]])

def map_add(bb, node_map, node):
    """Lower ``torch.add(a, b)`` to ``te_add``; both operands must already be in ``node_map``."""
    lhs, rhs = node_map[node.args[0]], node_map[node.args[1]]
    return bb.emit_te(te_add, lhs, rhs)

# call_module
from tvm import topi
def map_nn_conv(bb, node_map, node, nn_mod):
    """Lower an ``nn.Conv2d`` call to ``te_conv2d``, with the module weight as a constant.

    ``te_conv2d`` implements only stride 1, zero padding, dilation 1, groups 1
    and no bias.  Other configurations are rejected instead of being computed
    with the wrong semantics.

    Raises:
        NotImplementedError: the module uses an unsupported setting.
    """
    supported = {
        'stride': tuple(nn_mod.stride) == (1, 1),
        'padding': nn_mod.padding in ((0, 0), 'valid'),
        'dilation': tuple(nn_mod.dilation) == (1, 1),
        'groups': nn_mod.groups == 1,
        'bias': nn_mod.bias is None,
    }
    unsupported = [name for name, ok in supported.items() if not ok]
    if unsupported:
        raise NotImplementedError(
            f"te_conv2d does not support nn.Conv2d with non-default {', '.join(unsupported)}"
        )
    x = node_map[node.args[0]]
    kernel = map_params(nn_mod.weight)
    return bb.emit_te(te_conv2d, x, kernel)

def map_nn_linear(bb, node_map, node, nn_mod):
    """Lower an ``nn.Linear`` call to ``topi.nn.dense`` with constant weight and optional bias."""
    data = node_map[node.args[0]]
    weight = map_params(nn_mod.weight)
    bias = map_params(nn_mod.bias) if nn_mod.bias is not None else None
    return bb.emit_te(topi.nn.dense, data, weight, bias)
    
def map_nn_relu(bb, node_map, node, nn_mod):
    """Lower an ``nn.ReLU`` module call to ``topi.nn.relu`` (``nn_mod`` itself is unused)."""
    return bb.emit_te(topi.nn.relu, node_map[node.args[0]])

# call method
def map_view(bb, node_map, node):
    """Lower ``Tensor.view(...)`` to ``te_view``, which flattens (n, c, h, w) to (n, c*h*w).

    ``te_view`` always performs that flatten, so only a rank-2 target shape
    ending in -1 is accepted, e.g. ``view((1, -1))`` or ``view(1, -1)``.  Any
    other reshape previously produced the flatten result silently.
    NOTE(review): the leading target dimension is assumed to equal the input
    batch size; this is not checked here.

    Raises:
        NotImplementedError: the requested shape is not of the form (n, -1).
    """
    shape_args = tuple(node.args[1:])
    # `view` accepts either one shape tuple or the sizes as separate arguments.
    if len(shape_args) == 1 and isinstance(shape_args[0], (tuple, list)):
        shape_args = tuple(shape_args[0])
    if len(shape_args) != 2 or shape_args[1] != -1:
        raise NotImplementedError(f"te_view only supports view(n, -1); got view{shape_args}")
    x = node_map[node.args[0]]
    return bb.emit_te(te_view, x)




In [280]:
# Convert the traced model to Relax, registering one converter per torch op kind.
MyDemo = from_fx(
    fx_module=fx_module,
    input_shapes=[(1, 1, hw, hw)],
    call_function_map={torch.relu: map_relu, torch.matmul: map_matmul, torch.add: map_add},
    call_module_map={
        torch.nn.Conv2d: map_nn_conv,
        torch.nn.ReLU: map_nn_relu,
        torch.nn.Linear: map_nn_linear,
    },
    call_method_map={'view': map_view},
)

In [284]:
# Print the generated IRModule as TVM script. As the output below notes,
# formatted printing needs the optional 'Black' package.
MyDemo.show()

To print formatted TVM script, please install the formatter 'Black':
/staff/qiaoliang/anaconda3/envs/MLC/bin/python -m pip install "black==22.3.0" --upgrade --user


In [282]:
# Compile the IRModule for the host CPU, run 'main' on the same input that was
# given to PyTorch, and check that the two outputs agree.
ex = relax.vm.build(MyDemo, target='llvm')
vm = relax.VirtualMachine(ex, tvm.cpu())

tvm_out = vm['main'](tvm.nd.array(data))
# An rtol-only check scales the allowed error by |expected|, so it fails spuriously
# for outputs near zero. A small atol absorbs float32 accumulation-order differences.
np.testing.assert_allclose(
    tvm_out.numpy(), torch_out.detach().cpu().numpy(), rtol=1e-5, atol=1e-5
)