# TensorIR练习

In [2]:
import numpy as np
import tvm
from tvm.script import tir as T
import IPython
from tvm.ir.module import IRModule

## 2.5.1 TensorIR

### 2.5.1.1 逐位相加

In [3]:
# numpy
a = np.arange(16).reshape(4, 4)
b = np.arange(16, 0, -1).reshape(4, 4)
c_np = a + b

In [4]:
# low-level numpy version
def lnumpy_add(a: np.ndarray, b: np.ndarray, c: np.ndarray):
    for i in range(4):
        for j in range(4):
            c[i, j] = a[i, j] + b[i, j]
c_lnp = np.empty((4, 4), dtype=np.int64)
lnumpy_add(a, b, c_lnp)
np.testing.assert_equal(c_lnp, c_np)

In [5]:
# TensorIR
@tvm.script.ir_module
class MyAdd():
    @T.prim_func
    def add(A: T.Buffer[(4, 4), 'int64'],
            B: T.Buffer[(4, 4), 'int64'],
            C: T.Buffer[(4, 4), 'int64']):
        T.func_attr({'global_symbol': 'add', 'tir.noalias': True})
        for i, j in T.grid(4, 4):
            with T.block("C"):
                vi, vj = T.axis.remap("SS", [i, j])
                C[vi, vj] = A[vi, vj] + B[vi, vj]

rt_lib = tvm.build(MyAdd, target='llvm')
a_tvm = tvm.nd.array(a)
b_tvm = tvm.nd.array(b)
c_tvm = tvm.nd.empty((4, 4), dtype='int64')
rt_lib['add'](a_tvm, b_tvm, c_tvm)
np.testing.assert_equal(c_tvm.numpy(), c_np)


### 2.5.1.2 广播加

In [6]:
# numpy
a = np.arange(16).reshape(4, 4)
b = np.arange(4, 0, -1).reshape(4)
c_np = a + b
c_np

array([[ 4,  4,  4,  4],
       [ 8,  8,  8,  8],
       [12, 12, 12, 12],
       [16, 16, 16, 16]])

In [7]:
# low-level numpy
def lnumpy_b_add(a: np.ndarray, b: np.ndarray, c: np.ndarray):
    for i in range(4):
        for j in range(4):
            c[i, j] = a[i, j] + b[j]

c_lnp = np.empty((4, 4), dtype='int64')
lnumpy_b_add(a, b, c_lnp)
np.testing.assert_equal(c_lnp, c_np)

In [8]:
# TensorIR
@tvm.script.ir_module
class MyBAdd():
    @T.prim_func
    def b_add(A: T.Buffer[(4, 4), 'int64'],
              B: T.Buffer[(4), 'int64'],
              C: T.Buffer[(4, 4), 'int64']):
        T.func_attr({"global_symbol": "b_add", "tir.noalias": True})
        for i, j in T.grid(4, 4):
            with T.block("C"):
                vi, vj = T.axis.remap("SS", [i, j])
                C[vi, vj] = A[vi, vj] + B[vj]

rt_lib = tvm.build(MyBAdd, target='llvm')
a_tvm = tvm.nd.array(a)
b_tvm = tvm.nd.array(b)
c_tvm = tvm.nd.empty((4, 4), dtype='int64')
rt_lib['b_add'](a_tvm, b_tvm, c_tvm)
np.testing.assert_equal(c_tvm.numpy(), c_np)

### 2.5.1.3 二维卷积

$$
Conv[b,k,i,j] = \sum_{di,dj,q}{A[b,q,stride*i+di,strides*j+dj]*W[k,q,di,dj]}
$$

In [9]:
N, Ci, H, W, Co, K = 1, 1, 8, 8, 2, 3
Out_H, Out_W = H - K + 1, W - K + 1
data = np.arange(N*Ci*H*W).reshape(N,Ci,H,W)
kernel = np.arange(K*K*Ci*Co).reshape(Co,Ci,K,K)

In [20]:
# torch version
import torch

data_torch = torch.Tensor(data)
kernel_torch = torch.Tensor(kernel)
conv_torch = torch.nn.functional.conv2d(data_torch, kernel_torch)
conv_torch = conv_torch.numpy().astype(np.int64)
conv_torch

array([[[[ 474,  510,  546,  582,  618,  654],
         [ 762,  798,  834,  870,  906,  942],
         [1050, 1086, 1122, 1158, 1194, 1230],
         [1338, 1374, 1410, 1446, 1482, 1518],
         [1626, 1662, 1698, 1734, 1770, 1806],
         [1914, 1950, 1986, 2022, 2058, 2094]],

        [[1203, 1320, 1437, 1554, 1671, 1788],
         [2139, 2256, 2373, 2490, 2607, 2724],
         [3075, 3192, 3309, 3426, 3543, 3660],
         [4011, 4128, 4245, 4362, 4479, 4596],
         [4947, 5064, 5181, 5298, 5415, 5532],
         [5883, 6000, 6117, 6234, 6351, 6468]]]])

In [36]:
# TensorIR
@tvm.script.ir_module
class MyConv:
    @T.prim_func
    def conv(X: T.Buffer[(1, 1, 8, 8), 'int64'],
             K: T.Buffer[(2, 1, 3, 3), 'int64'],
             Y: T.Buffer[(1, 2, 6, 6), 'int64']):
        T.func_attr({'global_symbol': 'conv', 'tir.noalias': True})
        for n, co, h, w, ci, k1, k2 in T.grid(1, 2, 6, 6, 1, 3, 3):
            with T.block("Y"):
                vn, vco, vh, vw, vci, vk1, vk2 = T.axis.remap("SSSSRRR", [n, co, h, w, ci, k1, k2])
                with T.init():
                    Y[vn, vco, vh, vw] = T.int64(0)
                Y[vn, vco, vh, vw] = Y[vn, vco, vh, vw] + X[vn, vci, vh + k1, vw + k2] * K[vco, vci, vk1, vk2]

# IPython.display.Code(MyConv.script(), language='python')
rt_lib = tvm.build(MyConv, target='llvm')
data_tvm = tvm.nd.array(data)
kernel_tvm = tvm.nd.array(kernel)
print(data_tvm)
print(kernel_tvm)
conv_tvm = tvm.nd.empty((N, Co, Out_H, Out_W), dtype='int64')
rt_lib['conv'](data_tvm, kernel_tvm, conv_tvm)
conv_tvm
# np.testing.assert_allclose(conv_tvm.numpy(), conv_torch, rtol=1e-5)

[[[[ 0  1  2  3  4  5  6  7]
   [ 8  9 10 11 12 13 14 15]
   [16 17 18 19 20 21 22 23]
   [24 25 26 27 28 29 30 31]
   [32 33 34 35 36 37 38 39]
   [40 41 42 43 44 45 46 47]
   [48 49 50 51 52 53 54 55]
   [56 57 58 59 60 61 62 63]]]]
[[[[ 0  1  2]
   [ 3  4  5]
   [ 6  7  8]]]


 [[[ 9 10 11]
   [12 13 14]
   [15 16 17]]]]


<tvm.nd.NDArray shape=(1, 2, 6, 6), cpu(0)>
array([[[[              0,  94682335971344,  94682333359128,
               3296526593,             128,             128],
         [ 94682335293008,  94682333360336,  94682335982315,
           94682335507832,               0,      4294967317],
         [140517446610048,               0,               0,
                        0,               0,  94682333826112],
         [ 94682336148440,             257,              -8,
                      449,  94682335759392,  94682335697648],
         [ 94682335982441,               0,               0,
                        3,  94682335507832,               0],
         [     8589934614, 140517446610000,               0,
                        0,  94682334779424,  94682334779424]],

        [[ 94682333826112,  94682335053056,     77089865985,
          140527034957823,             160,             128],
         [ 94682336285152,  94682335234480,  94682336504459,
           94682335507832,      