tvm.script.tir
Tensor Intermediate Representation (TIR)

tvm.script.tir.prim_func 元张量函数
- 张量（Tensor）
    - T.Buffer
    - T.alloc_buffer


- 循环
    1. T.grid
    2. T.axis.spatial, T.axis.reduce


- 计算
    - T.init()
    - T.block()

In [1]:
from tvm.script import tir as T

# Primitive tensor function: element-wise addition C = A + B over
# 128-element float32 buffers.
@T.prim_func
def main(A: T.Buffer[128, "float32"],
         B: T.Buffer[128, "float32"],
         C: T.Buffer[128, "float32"]):
    for i in range(128):
        # Each loop iteration executes one instance of the block named "C".
        with T.block("C"):
            # Bind loop variable i to a spatial block axis of extent 128.
            vi = T.axis.spatial(128, i)
            C[vi] = A[vi] + B[vi]

Low-level NumPy

In [2]:
import numpy as np

# Random 128x128 float32 inputs shared by the NumPy and TVM examples below.
dtype = "float32"
a_np = np.random.rand(128, 128).astype(dtype)
b_np = np.random.rand(128, 128).astype(dtype)
# Expected result, computed with vectorized NumPy: max(A @ B, 0).
c_mm_relu = np.maximum(a_np @ b_np, 0)

def lnumpy_mm_relu(A: np.ndarray, B: np.ndarray, C: np.ndarray) -> None:
    """Write ``max(A @ B, 0)`` into ``C`` using explicit scalar loops.

    This is a deliberately low-level formulation (one scalar update per
    innermost iteration) rather than a vectorized one. Loop extents are
    taken from the input shapes instead of being fixed to 128, so any
    compatible 2-D operands work.

    Args:
        A: Left operand, 2-D with shape ``(n, r)``.
        B: Right operand, 2-D with shape ``(r, m)``.
        C: Output array, written in place at indices ``[0:n, 0:m]``.

    Raises:
        ValueError: If the inner dimensions of ``A`` and ``B`` differ.

    Note:
        The intermediate ``Y`` is zeroed on the first reduction step
        (``k == 0``). If the shared dimension ``r`` is 0 that step never
        runs and ``Y`` keeps whatever ``np.empty`` returned.
    """
    n, r = A.shape
    r_b, m = B.shape
    if r != r_b:
        raise ValueError(
            f"inner dimensions do not match: A is {A.shape}, B is {B.shape}"
        )

    # Intermediate matmul result; kept float32 as in the original version.
    Y = np.empty((n, m), dtype="float32")
    for i in range(n):
        for j in range(m):
            for k in range(r):
                # Initialise the accumulator on the first reduction step.
                if k == 0:
                    Y[i, j] = 0
                Y[i, j] = Y[i, j] + A[i, k] * B[k, j]

    # Element-wise ReLU over the intermediate result.
    for i in range(n):
        for j in range(m):
            C[i, j] = max(Y[i, j], 0)
            
# Run the loop-based implementation into a fresh output array and check it
# against the vectorized reference result.
c_np = np.empty((128, 128), dtype=dtype)
lnumpy_mm_relu(a_np, b_np, c_np)
np.testing.assert_allclose(c_mm_relu, c_np, rtol=1e-5)

IRModule 和 Schedule 是进行函数变换的入口 API
- tvm.ir.module.IRModule
- tvm.tir.Schedule

In [6]:
import tvm
from tvm.ir.module import IRModule

# IRModule holding a single primitive tensor function, mm_relu, which
# computes C = max(A @ B, 0) on 128x128 float32 buffers.
@tvm.script.ir_module
class MyModule:
    @T.prim_func
    def mm_relu(A: T.Buffer[(128, 128), "float32"],
                B: T.Buffer[(128, 128), "float32"],
                C: T.Buffer[(128, 128), "float32"]):
        # "global_symbol" is the name the function is looked up by later in
        # this file (sch.get_block(..., func_name="mm_relu"), rt_lib['mm_relu']).
        # NOTE(review): "tir.noalias" presumably declares that the buffers do
        # not overlap; confirm against the TVM docs.
        T.func_attr({"global_symbol": "mm_relu", "tir.noalias": True})
        # Intermediate buffer for the matmul result before ReLU.
        Y = T.alloc_buffer((128, 128), dtype="float32")
        for i, j, k in T.grid(128, 128, 128):
            with T.block("Y"):
                # i and j index the output (spatial axes); k is summed over
                # (reduction axis).
                vi = T.axis.spatial(128, i)
                vj = T.axis.spatial(128, j)
                vk = T.axis.reduce(128, k)
                # Shorthand alternative for the three bindings above:
                # vi, vj, vk = T.axis.remap("SSR", [i, j, k])

                # Initialisation of the accumulator for the reduction.
                with T.init():
                    Y[vi, vj] = T.float32(0)

                Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]

        for i, j in T.grid(128, 128):
            # Element-wise ReLU over the intermediate buffer.
            with T.block("C"):
                vi = T.axis.spatial(128, i)
                vj = T.axis.spatial(128, j)
                C[vi, vj] = T.max(Y[vi, vj], T.float32(0))

In [7]:
import IPython
# Display the module's TVMScript text with Python syntax highlighting.
IPython.display.Code(MyModule.script(), language="python")

In [18]:
# Create a schedule over MyModule, then fetch block "Y" of mm_relu and the
# three loops (i, j, k) that surround it.
sch = tvm.tir.Schedule(MyModule)
block_Y = sch.get_block("Y", func_name="mm_relu")
i, j, k = sch.get_loops(block_Y)

In [19]:
# Loop split: break loop j into an outer loop j0 and an inner loop j1.
# The inner factor is 4; None leaves the outer factor to be inferred.
j0, j1 = sch.split(j, factors=[None, 4])
IPython.display.Code(sch.mod.script(), language="python")

In [20]:
# Loop reorder: nest the loops in the order j0, k, j1, so the reduction
# loop k now sits outside the inner split loop j1.
sch.reorder(j0, k, j1)
IPython.display.Code(sch.mod.script(), language="python")

In [21]:
# Move consumer block "C" under loop j0 with reverse_compute_at, so the
# ReLU for a given j0 tile is placed after block "Y" inside that loop.
block_C = sch.get_block("C", "mm_relu")
sch.reverse_compute_at(block_C, j0)
IPython.display.Code(sch.mod.script(), language="python")

In [22]:
# Decompose the reduction in block "Y": its T.init() part becomes a separate
# init block placed before loop k, leaving the update step inside the loop.
sch.decompose_reduction(block_Y, k)
IPython.display.Code(sch.mod.script(), language='python')

编译运行IRModule
- tvm.build

In [23]:
# Compile the original (untransformed) MyModule for the LLVM target.
rt_lib = tvm.build(MyModule, target="llvm")

In [24]:
# Wrap the NumPy inputs as TVM NDArrays and allocate an output array that
# the compiled functions write into.
a_nd = tvm.nd.array(a_np)
b_nd = tvm.nd.array(b_np)
c_nd = tvm.nd.empty((128, 128), dtype="float32")

In [25]:
# Look up the compiled function by name, run it, and check its output
# against the NumPy reference.
func_mm_relu = rt_lib['mm_relu']
func_mm_relu(a_nd, b_nd, c_nd)
np.testing.assert_allclose(c_mm_relu, c_nd.numpy(), rtol=1e-5)

In [26]:
# Build the scheduled module (sch.mod) and confirm the transformed function
# still produces the reference result. c_nd is reused as the output array.
rt_lib_after = tvm.build(sch.mod, target="llvm")
rt_lib_after["mm_relu"](a_nd, b_nd, c_nd)
np.testing.assert_allclose(c_mm_relu, c_nd.numpy(), rtol=1e-5)

速度测试
- time_evaluator

In [28]:
# Time mm_relu from the original build on the CPU device and print the mean.
f_timer_before = rt_lib.time_evaluator("mm_relu", tvm.cpu())
print("Time cost of MyModule %g sec" % f_timer_before(a_nd, b_nd, c_nd).mean)

Time cost of MyModule 0.00181989 sec


In [29]:
# Time mm_relu from the scheduled build the same way, for comparison with
# the original build's timing.
f_timer_after = rt_lib_after.time_evaluator("mm_relu", tvm.cpu())
print("Time cost of transformed sch.mod %g sec" % f_timer_after(a_nd, b_nd, c_nd).mean)

Time cost of transformed sch.mod 0.000376969 sec


使用张量表达式构造元张量函数
- tvm.te

In [30]:
from tvm import te

# Build the same matmul + ReLU computation with tensor expressions (te)
# instead of hand-written TVMScript.
# Symbolic 128x128 float32 inputs.
A = te.placeholder((128, 128), "float32", name="A")
B = te.placeholder((128, 128), "float32", name="B")
# Reduction axis over [0, 128). Note: in this notebook this rebinds the
# name k, which earlier held a loop returned by sch.get_loops.
k = te.reduce_axis((0, 128), "k")

# Y[i, j] = sum over k of A[i, k] * B[k, j]
Y = te.compute((128, 128), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name="Y")
# C[i, j] = max(Y[i, j], 0)
C = te.compute((128, 128), lambda i, j: te.max(Y[i, j], 0), name="C")

# Turn the expression graph into a prim_func with parameters (A, B, C),
# give it the global symbol "mm_relu", and wrap it in an IRModule.
te_func = te.create_prim_func([A, B, C]).with_attr({"global_symbol": "mm_relu"})
MyModuleFromTE = tvm.IRModule({"mm_relu": te_func})

IPython.display.Code(MyModuleFromTE.script(), language="python")