TensorIR is a domain-specific language for deep learning programs. It serves two main purposes:

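+ Transforming and optimizing programs on a variety of hardware backends.
+ Providing an abstraction for automatic _tensorized_ program optimization.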

In [1]:
import tvm
from tvm.ir.module import IRModule
from tvm.script import tir as T
import numpy as np

An IRModule can be created by writing TVMScript, a round-trippable syntax for TVM IR.

Unlike building computations from tensor expressions (TE), TensorIR can be programmed directly in TVMScript, a language embedded in the Python AST.

In [2]:
@tvm.script.ir_module
class MyModule:
    @T.prim_func
    def main(a: T.handle, b: T.handle):
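        # Data is passed in through T.handle, which is similar to a memory pointer
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # Create Buffers from the handles
        A = T.match_buffer(a, (8,), dtype="float32")
        B = T.match_buffer(b, (8,), dtype="float32")
        for i in range(8):
            # A block is the abstraction for a unit of computation
            with T.block("B"):
                # Define a spatial (parallelizable) block iterator and bind its value to i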
                vi = T.axis.spatial(8, i)
                B[vi] = A[vi] + 1.0

ir_module = MyModule
print(type(ir_module))
print(ir_module.script())

<class 'tvm.ir.module.IRModule'>
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer((8,), "float32"), B: T.Buffer((8,), "float32")):
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # with T.block("root"):
        for i in range(8):
            with T.block("B"):
                vi = T.axis.spatial(8, i)
                T.reads(A[vi])
                T.writes(B[vi])
                B[vi] = A[vi] + T.float32(1)
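Note that the printed module contains `T.reads`/`T.writes` annotations that were not written in the input; they were inferred from the block body. As a rough mental model (a plain-Python sketch, not how TVM actually executes the block), the loop nest behaves like the function below. `run_main` is a hypothetical name used only for illustration.

```python
def run_main(A):
    """Plain-Python sketch of MyModule's main: B[vi] = A[vi] + 1.0 over 8 elements."""
    B = [0.0] * 8
    for i in range(8):          # the loop surrounding the block
        vi = i                  # T.axis.spatial(8, i): block iterator vi bound to i
        B[vi] = A[vi] + 1.0     # block body: reads A[vi], writes B[vi]
    return B

print(run_main([float(x) for x in range(8)]))
```

Each block iteration touches exactly one element of `A` and one of `B`, which is what the inferred `T.reads(A[vi])` and `T.writes(B[vi])` record.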


You can also write a simple operator with the tensor expression (TE) DSL and convert it into an IRModule.

In [3]:
from tvm import te

A = te.placeholder((8,), dtype="float32", name="A")
B = te.compute((8,), lambda *i: A(*i) + 1.0, name="B")
func = te.create_prim_func([A, B])
ir_module_from_te = IRModule({"main": func})
print(ir_module_from_te.script())

# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer((8,), "float32"), B: T.Buffer((8,), "float32")):
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # with T.block("root"):
        for i0 in range(8):
            with T.block("B"):
                v_i0 = T.axis.spatial(8, i0)
                T.reads(A[v_i0])
                T.writes(B[v_i0])
                B[v_i0] = A[v_i0] + T.float32(1)
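`te.compute` describes each output element as a function of its index, and the `lambda *i` form accepts an index tuple of any rank. The sketch below emulates that idea in plain Python. `compute` here is a hypothetical helper (not TVM's API) that evaluates the lambda at every index of the output shape, and `A` is wrapped as a callable so the lambda can be written exactly as in the TE cell.

```python
import itertools

def compute(shape, fcompute):
    """Hypothetical emulation of te.compute: evaluate fcompute at every index of `shape`."""
    out = {}
    for idx in itertools.product(*(range(n) for n in shape)):
        out[idx] = fcompute(*idx)   # same call convention as the TE lambda
    return out

A_data = [float(x) for x in range(8)]
A = lambda *i: A_data[i[0]]         # callable stand-in for the TE placeholder A

# Same lambda as in te.compute((8,), lambda *i: A(*i) + 1.0, name="B")
B = compute((8,), lambda *i: A(*i) + 1.0)
print([B[(k,)] for k in range(8)])
```

Because the lambda only describes one element, the same definition works for any shape; TVM turns it into the loop nest and block shown in the printed module.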


Build and run the IRModule.

In [4]:
mod = tvm.build(ir_module, target="llvm")
print(type(mod))

<class 'tvm.driver.build_module.OperatorModule'>


In [5]:
a = tvm.nd.array(np.arange(8).astype("float32"))
b = tvm.nd.array(np.zeros((8,)).astype("float32"))
mod(a, b)
print(a)
print(b)

[0. 1. 2. 3. 4. 5. 6. 7.]
[1. 2. 3. 4. 5. 6. 7. 8.]
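The printed `b` equals `a + 1` element-wise, as expected. A minimal sketch of an automatic check, assuming NumPy: `check_add_one` is a hypothetical helper, and with the real TVM arrays you would call `check_add_one(a.numpy(), b.numpy())`. Below, plain NumPy arrays holding the printed values stand in for them so the block runs without TVM.

```python
import numpy as np

def check_add_one(a_np, b_np):
    """Raise AssertionError unless b == a + 1 element-wise (float32 tolerance)."""
    np.testing.assert_allclose(b_np, a_np + 1.0, rtol=1e-6)

# Stand-ins for a.numpy() and b.numpy() from the cell above
a_np = np.arange(8).astype("float32")
b_np = np.arange(1, 9).astype("float32")
check_add_one(a_np, b_np)  # passes silently when the results match
```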


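The IRModule is the central data structure, and it can be transformed with a `Schedule`. A schedule provides a set of primitive methods for transforming the program interactively; each primitive rewrites the program in a specific way, for example to improve performance.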

In [6]:
sch = tvm.tir.Schedule(ir_module)
print(type(sch))

<class 'tvm.tir.schedule.schedule.Schedule'>


Split the loop `i` into three nested loops and print the result.

In [7]:
# Get the block by its name
block_b = sch.get_block("B")
# Get the loops surrounding the block
(i,) = sch.get_loops(block_b)
# Split loop i into three nested loops
i_0, i_1, i_2 = sch.split(i, factors=[2, 2, 2])
print(sch.mod.script())

# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer((8,), "float32"), B: T.Buffer((8,), "float32")):
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # with T.block("root"):
        for i_0, i_1, i_2 in T.grid(2, 2, 2):
            with T.block("B"):
                vi = T.axis.spatial(8, i_0 * 4 + i_1 * 2 + i_2)
                T.reads(A[vi])
                T.writes(B[vi])
                B[vi] = A[vi] + T.float32(1)
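With `factors=[2, 2, 2]` the new loops each have extent 2, and the printed binding `vi = i_0 * 4 + i_1 * 2 + i_2` uses strides 4, 2 and 1: each loop's stride is the product of the extents of the loops nested inside it. The plain-Python sketch below recomputes those strides (`split_strides` is a hypothetical helper) and checks that every index 0..7 is still visited exactly once, i.e. the split changes the loop structure but not what is computed.

```python
import itertools

def split_strides(factors):
    """Stride of each split loop = product of the extents of all loops inside it."""
    strides = []
    inner = 1
    for f in reversed(factors):
        strides.append(inner)
        inner *= f
    return list(reversed(strides))

factors = [2, 2, 2]
strides = split_strides(factors)  # strides for (i_0, i_1, i_2)
visited = [
    sum(iv * s for iv, s in zip(ivs, strides))
    for ivs in itertools.product(*(range(f) for f in factors))
]
print(strides, visited)
```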


The loops can also be reordered, for example by moving loop `i_2` outside of `i_1`.

In [8]:
sch.reorder(i_0, i_2, i_1)
print(sch.mod.script())

# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer((8,), "float32"), B: T.Buffer((8,), "float32")):
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # with T.block("root"):
        for i_0, i_2, i_1 in T.grid(2, 2, 2):
            with T.block("B"):
                vi = T.axis.spatial(8, i_0 * 4 + i_1 * 2 + i_2)
                T.reads(A[vi])
                T.writes(B[vi])
                B[vi] = A[vi] + T.float32(1)
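Reordering changes only the order in which the index space is traversed; the binding `vi = i_0 * 4 + i_1 * 2 + i_2` stays the same. The plain-Python sketch below lists the `vi` values in execution order for both loop nestings (`visit_order` is a hypothetical helper; in `itertools.product` the last loop varies fastest, matching the innermost loop).

```python
import itertools

EXTENTS = {"i_0": 2, "i_1": 2, "i_2": 2}

def visit_order(order):
    """vi values in execution order when loops are nested outermost-first as `order`."""
    seq = []
    for values in itertools.product(*(range(EXTENTS[n]) for n in order)):
        env = dict(zip(order, values))
        seq.append(env["i_0"] * 4 + env["i_1"] * 2 + env["i_2"])
    return seq

print(visit_order(["i_0", "i_1", "i_2"]))  # before reorder
print(visit_order(["i_0", "i_2", "i_1"]))  # after sch.reorder(i_0, i_2, i_1)
```

Both orders cover the same eight indices, so the result is unchanged; only the memory access pattern differs.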


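Transform into a GPU program. Deploying on a GPU requires binding loops to GPU threads, and this can also be done incrementally with schedule primitives.

After binding the threads, the IRModule can be built with the CUDA backend.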

In [9]:
sch.bind(i_0, "blockIdx.x")
sch.bind(i_2, "threadIdx.x")
print(sch.mod.script())

# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer((8,), "float32"), B: T.Buffer((8,), "float32")):
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # with T.block("root"):
        for i_0 in T.thread_binding(2, thread="blockIdx.x"):
            for i_2 in T.thread_binding(2, thread="threadIdx.x"):
                for i_1 in range(2):
                    with T.block("B"):
                        vi = T.axis.spatial(8, i_0 * 4 + i_1 * 2 + i_2)
                        T.reads(A[vi])
                        T.writes(B[vi])
                        B[vi] = A[vi] + T.float32(1)
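After `bind`, `i_0` is no longer a serial loop: each of its 2 iterations becomes a CUDA thread block (`blockIdx.x`), each of the 2 iterations of `i_2` becomes a thread inside that block (`threadIdx.x`), and only `i_1` remains a serial loop run by each thread. The sketch below simulates that launch in plain Python, with no GPU or TVM needed, to show which elements of `B` each thread writes. `simulate_launch` is a hypothetical name, and the threads run one after another here, whereas on a GPU they run in parallel.

```python
def simulate_launch(A, grid_dim=2, block_dim=2, serial=2):
    """Sequentially emulate the bound loop nest; returns B and the indices each thread wrote."""
    B = [0.0] * len(A)
    owner = {}
    for bx in range(grid_dim):            # i_0 -> blockIdx.x
        for tx in range(block_dim):       # i_2 -> threadIdx.x
            touched = []
            for i_1 in range(serial):     # i_1 stays a serial loop inside each thread
                # With the defaults this is vi = bx * 4 + i_1 * 2 + tx, as in the printed module
                vi = bx * (serial * block_dim) + i_1 * block_dim + tx
                B[vi] = A[vi] + 1.0
                touched.append(vi)
            owner[(bx, tx)] = touched
    return B, owner

B, owner = simulate_launch([float(x) for x in range(8)])
print(owner)
```

Each thread handles two elements spaced `block_dim` apart, so neighbouring threads in a block write neighbouring elements on each iteration of `i_1`.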
