# TensorIR练习

In [1]:
import numpy as np
import tvm
from tvm.script import tir as T
import IPython
from tvm.ir.module import IRModule

## 2.5.1 TensorIR

### 2.5.1.1 逐位相加

In [2]:
# numpy reference inputs. dtype is pinned to int64 because the TensorIR buffers
# below are declared 'int64', while np.arange's default integer type is
# platform-dependent (int32 on Windows with NumPy < 2).
a = np.arange(16, dtype=np.int64).reshape(4, 4)
b = np.arange(16, 0, -1, dtype=np.int64).reshape(4, 4)
c_np = a + b

In [3]:
# low-level numpy version
def lnumpy_add(a: np.ndarray, b: np.ndarray, c: np.ndarray):
    """Element-wise ``c = a + b`` written as explicit loops (mirrors the TensorIR loop nest).

    The result is written into ``c`` in place; nothing is returned. Loop bounds
    are taken from ``c.shape``, so any 2-D shape works as long as ``a`` and ``b``
    can be indexed at every ``(i, j)`` of ``c``.
    """
    rows, cols = c.shape
    for i in range(rows):
        for j in range(cols):
            c[i, j] = a[i, j] + b[i, j]
# Run the loop version into a fresh buffer and check it against vectorised numpy.
c_lnp = np.empty(shape=(4, 4), dtype="int64")
lnumpy_add(a, b, c=c_lnp)
np.testing.assert_equal(actual=c_lnp, desired=c_np)

In [4]:
# TensorIR
# IRModule holding one prim_func `add`: C[i, j] = A[i, j] + B[i, j] on 4x4 int64 buffers.
@tvm.script.ir_module
class MyAdd():
    @T.prim_func
    def add(A: T.Buffer[(4, 4), 'int64'],
            B: T.Buffer[(4, 4), 'int64'],
            C: T.Buffer[(4, 4), 'int64']):
        # global_symbol is the name the built function is looked up by (rt_lib['add']);
        # tir.noalias declares that A, B and C do not alias each other.
        T.func_attr({'global_symbol': 'add', 'tir.noalias': True})
        # T.grid(4, 4) yields the nested i/j loop pair.
        for i, j in T.grid(4, 4):
            with T.block("C"):
                # Bind block iterators to the loops; "SS" marks both as spatial axes.
                vi, vj = T.axis.remap("SS", [i, j])
                C[vi, vj] = A[vi, vj] + B[vi, vj]

# Compile for CPU, run on the numpy inputs, and check against the numpy result.
rt_lib = tvm.build(MyAdd, target="llvm")
a_tvm, b_tvm = tvm.nd.array(a), tvm.nd.array(b)
c_tvm = tvm.nd.empty((4, 4), dtype="int64")  # output buffer written by the kernel
rt_lib["add"](a_tvm, b_tvm, c_tvm)
np.testing.assert_equal(actual=c_tvm.numpy(), desired=c_np)


### 2.5.1.2 练习1：广播加

In [5]:
# numpy: b has shape (4,) and broadcasts across every row of a.
# dtype is pinned to int64 to match the TensorIR buffers (np.arange's default
# integer type is platform-dependent).
a = np.arange(16, dtype=np.int64).reshape(4, 4)
b = np.arange(4, 0, -1, dtype=np.int64)
c_np = a + b
c_np

array([[ 4,  4,  4,  4],
       [ 8,  8,  8,  8],
       [12, 12, 12, 12],
       [16, 16, 16, 16]])

In [6]:
# low-level numpy
def lnumpy_b_add(a: np.ndarray, b: np.ndarray, c: np.ndarray):
    """Broadcast add ``c[i, j] = a[i, j] + b[j]`` written as explicit loops.

    ``b`` is a 1-D vector added to every row of ``a``. The result is written
    into ``c`` in place. Loop bounds come from ``c.shape``, so any 2-D shape
    works as long as ``a`` matches it and ``b`` covers its columns.
    """
    rows, cols = c.shape
    for i in range(rows):
        for j in range(cols):
            c[i, j] = a[i, j] + b[j]

# Run the loop version into a fresh buffer and check it against numpy broadcasting.
c_lnp = np.empty(shape=(4, 4), dtype=np.int64)
lnumpy_b_add(a, b, c=c_lnp)
np.testing.assert_equal(actual=c_lnp, desired=c_np)

In [7]:
# TensorIR
# IRModule with one prim_func `b_add`: C[i, j] = A[i, j] + B[j], i.e. the 1-D B
# is broadcast across every row of the 4x4 A.
@tvm.script.ir_module
class MyBAdd():
    # Note: `(4)` in B's shape is the plain int 4 (not a tuple), so B is a
    # 1-D buffer of length 4.
    @T.prim_func
    def b_add(A: T.Buffer[(4, 4), 'int64'],
              B: T.Buffer[(4), 'int64'],
              C: T.Buffer[(4, 4), 'int64']):
        T.func_attr({"global_symbol": "b_add", "tir.noalias": True})
        for i, j in T.grid(4, 4):
            with T.block("C"):
                vi, vj = T.axis.remap("SS", [i, j])
                # B is indexed by the column iterator only: this is the broadcast.
                C[vi, vj] = A[vi, vj] + B[vj]

# Compile for CPU, run on the numpy inputs, and compare with numpy broadcasting.
rt_lib = tvm.build(MyBAdd, target="llvm")
a_tvm, b_tvm = tvm.nd.array(a), tvm.nd.array(b)
c_tvm = tvm.nd.empty((4, 4), dtype="int64")  # output buffer written by the kernel
rt_lib["b_add"](a_tvm, b_tvm, c_tvm)
np.testing.assert_equal(actual=c_tvm.numpy(), desired=c_np)

### 2.5.1.3 练习2：二维卷积

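$$
Conv[b,k,i,j] = \sum_{di,dj,q}{A[b,q,\mathrm{stride}\cdot i+di,\;\mathrm{stride}\cdot j+dj] \cdot W[k,q,di,dj]}
$$

下面的实现取 $\mathrm{stride}=1$ 且不做 padding，因此输出尺寸为 $(H-K+1)\times(W-K+1)$。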

In [8]:
# Conv2D problem size: batch N, input channels Ci, H x W input, output channels Co, K x K kernel.
# With stride 1 and no padding, the output spatial size is (H - K + 1) x (W - K + 1).
N, Ci, H, W, Co, K = 1, 1, 8, 8, 2, 3
Out_H, Out_W = H - K + 1, W - K + 1
# dtype pinned to int64 to match the TensorIR 'int64' buffers
# (np.arange's default integer type is platform-dependent).
data = np.arange(N * Ci * H * W, dtype=np.int64).reshape(N, Ci, H, W)
kernel = np.arange(K * K * Ci * Co, dtype=np.int64).reshape(Co, Ci, K, K)

In [9]:
# torch version (reference result)
import torch

# conv2d needs floating-point tensors. The integer inputs here are small enough
# that the float32 products and sums are exact for these sizes.
data_torch = torch.tensor(data, dtype=torch.float32)
kernel_torch = torch.tensor(kernel, dtype=torch.float32)
conv_torch = torch.nn.functional.conv2d(data_torch, kernel_torch)
# Round before casting: astype alone truncates toward zero, so a float result
# such as 473.9999 would otherwise become 473.
conv_torch = np.rint(conv_torch.numpy()).astype(np.int64)
conv_torch

array([[[[ 474,  510,  546,  582,  618,  654],
         [ 762,  798,  834,  870,  906,  942],
         [1050, 1086, 1122, 1158, 1194, 1230],
         [1338, 1374, 1410, 1446, 1482, 1518],
         [1626, 1662, 1698, 1734, 1770, 1806],
         [1914, 1950, 1986, 2022, 2058, 2094]],

        [[1203, 1320, 1437, 1554, 1671, 1788],
         [2139, 2256, 2373, 2490, 2607, 2724],
         [3075, 3192, 3309, 3426, 3543, 3660],
         [4011, 4128, 4245, 4362, 4479, 4596],
         [4947, 5064, 5181, 5298, 5415, 5532],
         [5883, 6000, 6117, 6234, 6351, 6468]]]])

In [10]:
# TensorIR
# Direct 2-D convolution (stride 1, no padding) on int64 buffers:
#   Y[n, co, h, w] = sum_{ci, k1, k2} X[n, ci, h + k1, w + k2] * K[co, ci, k1, k2]
@tvm.script.ir_module
class MyConv:
    @T.prim_func
    def conv(X: T.Buffer[(1, 1, 8, 8), 'int64'],
             K: T.Buffer[(2, 1, 3, 3), 'int64'],
             Y: T.Buffer[(1, 2, 6, 6), 'int64']):
        T.func_attr({'global_symbol': 'conv', 'tir.noalias': True})
        # Four spatial output axes (n, co, h, w) followed by three reduction axes (ci, k1, k2).
        for n, co, h, w, ci, k1, k2 in T.grid(1, 2, 6, 6, 1, 3, 3):
            with T.block("Y"):
                vn, vco, vh, vw, vci, vk1, vk2 = T.axis.remap("SSSSRRR", [n, co, h, w, ci, k1, k2])
                # Zero the accumulator on the first visit of each output element.
                with T.init():
                    Y[vn, vco, vh, vw] = T.int64(0)
                # Index only through the block iterators (vk1, vk2), never the outer
                # loop variables: a block body must reference its own iter vars so
                # that its access regions and later schedule transforms are correct.
                Y[vn, vco, vh, vw] = Y[vn, vco, vh, vw] + X[vn, vci, vh + vk1, vw + vk2] * K[vco, vci, vk1, vk2]

# Compile the TensorIR convolution for CPU and check it against the torch reference.
rt_lib = tvm.build(MyConv, target='llvm')
data_tvm = tvm.nd.array(data)
kernel_tvm = tvm.nd.array(kernel)
conv_tvm = tvm.nd.empty((N, Co, Out_H, Out_W), dtype='int64')
rt_lib['conv'](data_tvm, kernel_tvm, conv_tvm)
# Both results are int64, so compare exactly rather than with a float tolerance.
np.testing.assert_equal(conv_tvm.numpy(), conv_torch)

## 2.5.2 第二节：如何变换TensorIR

### 2.5.2.1 并行化、向量化与循环展开

parallel & vectorize & unroll

In [11]:
# Same element-wise add module as in section 2.5.1.1, redefined here so this
# section can be run on its own as the input to the scheduling example below.
@tvm.script.ir_module
class MyAdd:
    @T.prim_func
    def add(A: T.Buffer[(4, 4), 'int64'],
            B: T.Buffer[(4, 4), 'int64'],
            C: T.Buffer[(4, 4), 'int64']):
        T.func_attr({"global_symbol": "add", "tir.noalias": True})
        for i, j in T.grid(4, 4):
            # Block "C" is what sch.get_block('C', 'add') looks up below.
            with T.block("C"):
                vi, vj = T.axis.remap("SS", [i, j])
                C[vi, vj] = A[vi, vj] + B[vi, vj]

# Schedule the `add` prim_func: look up block "C" and its two loops.
sch = tvm.tir.Schedule(MyAdd)
block_C = sch.get_block('C', 'add')
i, j = sch.get_loops(block_C)
# Split the 4-iteration i loop into 2 x 2; the outer half runs in parallel,
# the inner half is unrolled, and the j loop is vectorized.
i0, i1 = sch.split(i, [2, 2])
sch.parallel(i0)
sch.unroll(i1)
sch.vectorize(j)
# Show the transformed module as TVMScript.
IPython.display.Code(sch.mod.script(), language='python')

### 2.5.2.2 练习3：变换批量矩阵乘法程序

In [12]:
# bmm_relu numpy version
def lnumpy_bmm_relu(A: np.ndarray, B: np.ndarray, C: np.ndarray):
    """Batched matmul followed by ReLU, written as explicit loops.

    Computes ``C[n] = max(A[n] @ B[n], 0)`` for every batch index ``n`` and
    writes the result into ``C`` in place. Loop bounds come from the operands
    (``A``: (batch, M, K), ``C``: (batch, M, N), ``B`` indexable as [n, k, j]),
    so the original 16x128x128 case and smaller shapes both work.
    """
    batch, rows, red = A.shape
    cols = C.shape[2]
    # Intermediate matmul result. Zero-filled so an empty reduction (K == 0)
    # still yields 0 rather than uninitialised memory.
    Y = np.zeros_like(C)
    for n in range(batch):
        for i in range(rows):
            for j in range(cols):
                for k in range(red):
                    # Reset on the first reduction step, mirroring T.init() in TensorIR.
                    if k == 0:
                        Y[n, i, j] = 0
                    Y[n, i, j] = Y[n, i, j] + A[n, i, k] * B[n, k, j]

    for n in range(batch):
        for i in range(rows):
            for j in range(cols):
                C[n, i, j] = max(Y[n, i, j], 0)

# Fixed seed so the random inputs (and every result derived from them) are reproducible.
rng = np.random.default_rng(0)
a = rng.random((16, 128, 128), dtype=np.float32)
b = rng.random((16, 128, 128), dtype=np.float32)
# Allocate float32 directly; np.empty(...).astype would first allocate float64 and copy.
c = np.empty((16, 128, 128), dtype=np.float32)
lnumpy_bmm_relu(a, b, c)

In [13]:
# TensorIR version
# Batched matmul + ReLU: Y = A @ B per batch (block "Y"), then C = max(Y, 0) (block "C").
@tvm.script.ir_module
class MyBmmRelu():
    @T.prim_func
    def bmm_relu(A: T.Buffer[(16, 128, 128), 'float32'],
                 B: T.Buffer[(16, 128, 128), 'float32'],
                 C: T.Buffer[(16, 128, 128), 'float32']):
        T.func_attr({'global_symbol': "bmm_relu", "tir.noalias": True})
        # Intermediate buffer holding the matmul result before ReLU.
        Y = T.alloc_buffer((16, 128, 128), 'float32')
        for n, i, j, k in T.grid(16, 128, 128, 128):
            with T.block("Y"):
                # n, i, j are spatial axes; k is the reduction axis ("R").
                vn, vi, vj, vk = T.axis.remap("SSSR", [n, i, j, k])
                with T.init():
                    Y[vn, vi, vj] = T.float32(0)
                Y[vn, vi, vj] = Y[vn, vi, vj] + A[vn, vi, vk] * B[vn, vk, vj]

        for n, i, j in T.grid(16, 128, 128):
            with T.block("C"):
                vn, vi, vj = T.axis.remap("SSS", [n, i, j])
                C[vn, vi, vj] = T.max(Y[vn, vi, vj], T.float32(0))

# Build the unscheduled module, check it against the loop reference, and show its script.
rt_lib = tvm.build(MyBmmRelu, target="llvm")
a_tvm, b_tvm = tvm.nd.array(a), tvm.nd.array(b)
c_tvm = tvm.nd.empty((16, 128, 128), dtype="float32")
rt_lib["bmm_relu"](a_tvm, b_tvm, c_tvm)
# float32 results need not be bit-identical to the loop reference, so use a relative tolerance.
np.testing.assert_allclose(actual=c_tvm.numpy(), desired=c, rtol=1e-5)
IPython.display.Code(MyBmmRelu.script(), language="python")

In [14]:
sch = tvm.tir.Schedule(MyBmmRelu)
# 1.get_block
block_Y = sch.get_block('Y', 'bmm_relu')
block_C = sch.get_block('C', 'bmm_relu')
# 2.get_loops
n, i, j, k = sch.get_loops(block_Y)

# 3.organize the loops
# Split j (128) into 16 x 8 and k (128) into 32 x 4, then reorder so the
# 8-wide j1 loop is innermost, below k0 and k1.
j0, j1 = sch.split(j, factors=[16, 8])
k0, k1 = sch.split(k, [32, 4])
sch.reorder(j0, k0, k1, j1)
# Move block C (ReLU) under loop j0 so it runs on each 8-wide tile of Y
# right after that tile is computed.
sch.reverse_compute_at(block_C, j0)


# 4 parallel/vectorize/unroll
sch.parallel(n)
sch.unroll(k1)

# 5.decompose reduction
# Separate the reduction's init into its own block (Y_init), placed before
# loop k0; the remaining update block becomes "Y_update" (see TargetModule).
Y_init = sch.decompose_reduction(block_Y,k0)

# Vectorize the innermost loop of block C and of the new init block.
# sch.mod is compared against TargetModule below, so keep this order of primitives.
n, i, j0, j1 = sch.get_loops(block_C)
sch.vectorize(j1)
n, i, j0, j1 = sch.get_loops(Y_init)
sch.vectorize(j1)

IPython.display.Code(sch.mod.script(), language='python')


In [15]:
# Build the scheduled module for the same explicit target as the unscheduled one,
# so the timing comparison below is like-for-like, then check correctness.
rt_lib_after = tvm.build(sch.mod, target='llvm')
c_tvm_after = tvm.nd.empty((16, 128, 128), 'float32')
rt_lib_after['bmm_relu'](a_tvm, b_tvm, c_tvm_after)
np.testing.assert_allclose(c_tvm_after.numpy(), c, rtol=1e-5)

In [16]:
# Expected result of the bmm_relu schedule above; checked with
# tvm.ir.assert_structural_equal in the next cell.
@tvm.script.ir_module
class TargetModule:
    @T.prim_func
    def bmm_relu(A: T.Buffer[(16, 128, 128), "float32"], B: T.Buffer[(16, 128, 128), "float32"], C: T.Buffer[(16, 128, 128), "float32"]) -> None:
        T.func_attr({"global_symbol": "bmm_relu", "tir.noalias": True})
        Y = T.alloc_buffer([16, 128, 128], dtype="float32")
        # Batch loop runs in parallel.
        for i0 in T.parallel(16):
            # i2_0 is the outer half of the split j loop (16 tiles of 8 columns).
            for i1, i2_0 in T.grid(128, 16):
                # Init block produced by decompose_reduction, vectorized over the tile.
                for ax0_init in T.vectorized(8):
                    with T.block("Y_init"):
                        n, i = T.axis.remap("SS", [i0, i1])
                        j = T.axis.spatial(128, i2_0 * 8 + ax0_init)
                        Y[n, i, j] = T.float32(0)
                # Reduction over k, split 32 x 4 with the inner part unrolled.
                for ax1_0 in T.serial(32):
                    for ax1_1 in T.unroll(4):
                        for ax0 in T.serial(8):
                            with T.block("Y_update"):
                                n, i = T.axis.remap("SS", [i0, i1])
                                j = T.axis.spatial(128, i2_0 * 8 + ax0)
                                k = T.axis.reduce(128, ax1_0 * 4 + ax1_1)
                                Y[n, i, j] = Y[n, i, j] + A[n, i, k] * B[n, k, j]
                # ReLU moved under the j tile loop and vectorized.
                for i2_1 in T.vectorized(8):
                    with T.block("C"):
                        n, i = T.axis.remap("SS", [i0, i1])
                        j = T.axis.spatial(128, i2_0 * 8 + i2_1)
                        C[n, i, j] = T.max(Y[n, i, j], T.float32(0))

In [17]:
# Raises if the scheduled module differs structurally from the expected TargetModule.
tvm.ir.assert_structural_equal(lhs=sch.mod, rhs=TargetModule)
print("Pass")

Pass


### 2.5.2.3 构建和评估

In [18]:
before_timer = rt_lib.time_evaluator("bmm_relu", tvm.cpu(), number=500, repeat=1)
after_timer = rt_lib_after.time_evaluator("bmm_relu", tvm.cpu(), number=500, repeat=1)

# Each timer averages 500 runs; .mean is the mean time per run.
for label, timer, out_tvm in (
    ("Before transformation:", before_timer, c_tvm),
    ("After transformation:", after_timer, c_tvm_after),
):
    print(label)
    print(timer(a_tvm, b_tvm, out_tvm).mean)

Before transformation:
0.036769920342
After transformation:
0.001340763776
