<a href="https://colab.research.google.com/github/yingtongxiong/MLCnotebooks/blob/xytExercise/2_4_exercise.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
!python3 -m  pip install mlc-ai-nightly -f https://mlc.ai/wheels

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in links: https://mlc.ai/wheels
Collecting mlc-ai-nightly
  Downloading https://github.com/mlc-ai/utils/releases/download/v0.9.dev0/mlc_ai_nightly-0.9.dev2227%2Bg06c0ef3be-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (45.5 MB)
[K     |████████████████████████████████| 45.5 MB 2.3 MB/s 
Collecting synr==0.6.0
  Downloading synr-0.6.0-py3-none-any.whl (18 kB)
Installing collected packages: synr, mlc-ai-nightly
Successfully installed mlc-ai-nightly-0.9.dev2227+g06c0ef3be synr-0.6.0


In [3]:
import IPython
import numpy as np
import tvm
from tvm.ir.module import IRModule
from tvm.script import tir as T

In [5]:
# init data
# The dtype is pinned to int64 because the TensorIR buffers below are declared
# "int64"; np.arange's default integer type is platform dependent.
a = np.arange(16, dtype=np.int64).reshape(4, 4)
b = np.arange(16, 0, -1, dtype=np.int64).reshape(4, 4)

In [6]:
# low-level numpy version add
def lnumpy_add(a: np.ndarray, b: np.ndarray, c: np.ndarray) -> None:
  """Element-wise add written with explicit loops: c[i, j] = a[i, j] + b[i, j].

  The iteration domain is taken from the shape of the 2-D output `c`, so the
  function is no longer limited to 4x4 operands.

  Args:
    a: Left operand; must be indexable at every (i, j) position of `c`.
    b: Right operand; must be indexable at every (i, j) position of `c`.
    c: 2-D output array, written in place.
  """
  rows, cols = c.shape
  for i in range(rows):
    for j in range(cols):
      c[i, j] = a[i, j] + b[i, j]

# Allocate the int64 output buffer, fill it with the loop-based add, and show it.
c_lnumpy = np.empty(shape=(4, 4), dtype=np.int64)
lnumpy_add(a, b, c_lnumpy)
c_lnumpy

array([[16, 16, 16, 16],
       [16, 16, 16, 16],
       [16, 16, 16, 16],
       [16, 16, 16, 16]])

In [10]:
# TensorIR version of the 4x4 element-wise add: C[i, j] = A[i, j] + B[i, j].
@tvm.script.ir_module
class MyAdd:
  @T.prim_func
  def add(A: T.Buffer[(4, 4), "int64"],
          B: T.Buffer[(4, 4), "int64"],
          C: T.Buffer[(4, 4), "int64"]):
    # "tir.noalias" is set here as in the other prim_funcs of this notebook;
    # the three buffers passed below are distinct arrays.
    T.func_attr({"global_symbol": "add", "tir.noalias": True})
    for i, j in T.grid(4, 4):
      with T.block("C"):
        # Both loop axes are spatial (no reduction).
        vi = T.axis.spatial(4, i)
        vj = T.axis.spatial(4, j)
        C[vi, vj] = A[vi, vj] + B[vi, vj]

# Build for the CPU, run on the NumPy inputs and compare with the loop-based result.
rt_lib = tvm.build(MyAdd, target="llvm")
a_tvm = tvm.nd.array(a)
b_tvm = tvm.nd.array(b)
c_tvm = tvm.nd.array(np.empty((4, 4), dtype=np.int64))
rt_lib["add"](a_tvm, b_tvm, c_tvm)
np.testing.assert_allclose(c_tvm.numpy(), c_lnumpy, rtol=1e-5)

## 练习1: 广播加法

In [11]:
# init data
# int64 is requested explicitly to match the "int64" TensorIR buffers below;
# np.arange's default integer type is platform dependent.
a = np.arange(16, dtype=np.int64).reshape(4, 4)
b = np.arange(4, 0, -1, dtype=np.int64)

In [12]:
# Display the 4x4 matrix operand.
a

array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11],
       [12, 13, 14, 15]])

In [13]:
# Display the length-4 vector operand that is broadcast against `a`.
b

array([4, 3, 2, 1])

In [14]:
# NumPy reference: broadcasting adds the vector `b` to every row of `a`.
c_np = np.add(a, b)
c_np

array([[ 4,  4,  4,  4],
       [ 8,  8,  8,  8],
       [12, 12, 12, 12],
       [16, 16, 16, 16]])

In [33]:
# Broadcast add in TensorIR: the length-4 vector B is added to every row of A.
@tvm.script.ir_module
class MyAdd:
  @T.prim_func
  def add(A: T.Buffer[(4, 4), "int64"],
          B: T.Buffer[(4,), "int64"],
          C: T.Buffer[(4, 4), "int64"]):
    T.func_attr({"global_symbol": "add", "tir.noalias": True})
    for row, col in T.grid(4, 4):
      with T.block("C"):
        # Both axes are spatial; B is indexed by the column axis only.
        v_row, v_col = T.axis.remap("SS", [row, col])
        C[v_row, v_col] = A[v_row, v_col] + B[v_col]

# Build for the CPU, run, and compare against the NumPy broadcast result.
rt_lib = tvm.build(MyAdd, target="llvm")
a_tvm, b_tvm = tvm.nd.array(a), tvm.nd.array(b)
c_tvm = tvm.nd.array(np.empty(shape=(4, 4), dtype=np.int64))
rt_lib["add"](a_tvm, b_tvm, c_tvm)
np.testing.assert_allclose(c_tvm.numpy(), c_np, rtol=1e-5)

## 练习2: 二维卷积

In [35]:
# Problem size: batch, input channels, input height/width, output channels, kernel size.
N, CI, H, W, CO, K = 1, 1, 8, 8, 2, 3
# Output spatial size for stride 1 and no padding.
OUT_H, OUT_W = H - K + 1, W - K + 1
# int64 is requested explicitly to match the "int64" TensorIR buffers used later;
# np.arange's default integer type is platform dependent.
data = np.arange(N * CI * H * W, dtype=np.int64).reshape(N, CI, H, W)
weight = np.arange(CO * CI * K * K, dtype=np.int64).reshape(CO, CI, K, K)

In [36]:
import torch

# PyTorch reference for the convolution (default stride 1, no padding).
# conv2d runs in floating point, so the integer inputs are converted explicitly
# to float64 (the legacy `torch.Tensor(...)` constructor silently uses float32).
data_torch = torch.tensor(data, dtype=torch.float64)
weight_torch = torch.tensor(weight, dtype=torch.float64)
conv_torch_f64 = torch.nn.functional.conv2d(data_torch, weight_torch)
# Round before casting: a bare astype(int64) truncates toward zero.
conv_torch = np.rint(conv_torch_f64.numpy()).astype(np.int64)
conv_torch

array([[[[ 474,  510,  546,  582,  618,  654],
         [ 762,  798,  834,  870,  906,  942],
         [1050, 1086, 1122, 1158, 1194, 1230],
         [1338, 1374, 1410, 1446, 1482, 1518],
         [1626, 1662, 1698, 1734, 1770, 1806],
         [1914, 1950, 1986, 2022, 2058, 2094]],

        [[1203, 1320, 1437, 1554, 1671, 1788],
         [2139, 2256, 2373, 2490, 2607, 2724],
         [3075, 3192, 3309, 3426, 3543, 3660],
         [4011, 4128, 4245, 4362, 4479, 4596],
         [4947, 5064, 5181, 5298, 5415, 5532],
         [5883, 6000, 6117, 6234, 6351, 6468]]]])

In [37]:
# Check the shape of the reference output.
conv_torch.shape

(1, 2, 6, 6)

In [42]:
# low-level numpy implementation
def lnumpy_conv2d(data: np.ndarray, weight: np.ndarray) -> np.ndarray:
  """Direct 2-D convolution (stride 1, no padding) written with explicit loops.

  Args:
    data: Input of shape (N, CI, H, W).
    weight: Filters of shape (CO, CI, KH, KW). The kernel may be non-square.

  Returns:
    A float64 array of shape (N, CO, H - KH + 1, W - KW + 1).

  Raises:
    ValueError: If the channel counts of `data` and `weight` differ.
  """
  N, CI, H, W = data.shape
  # Unpack the kernel height and width separately so non-square kernels work,
  # and keep the weight's channel count apart from the data's so it can be checked.
  CO, weight_ci, KH, KW = weight.shape
  if weight_ci != CI:
    raise ValueError(
        f"channel mismatch: data has {CI} input channels, weight expects {weight_ci}")
  OUT_H, OUT_W = H - KH + 1, W - KW + 1
  output = np.zeros((N, CO, OUT_H, OUT_W))
  for n in range(N):
    for co in range(CO):
      for oh in range(OUT_H):
        for ow in range(OUT_W):
          # Reduce over input channels and the kernel window.
          for ci in range(CI):
            for kh in range(KH):
              for kw in range(KW):
                output[n, co, oh, ow] += weight[co, ci, kh, kw] * data[n, ci, oh + kh, ow + kw]
  return output

# Validate the loop-based convolution against the PyTorch reference.
output_numpy = lnumpy_conv2d(data=data, weight=weight)
np.testing.assert_allclose(actual=output_numpy, desired=conv_torch, rtol=1e-5)



In [67]:
# Direct 2-D convolution in TensorIR, stride 1 and no padding.
#   data:   (N, CI, H, W)                  = (1, 1, 8, 8)
#   weight: (CO, CI, K, K)                 = (2, 1, 3, 3)
#   output: (N, CO, H - K + 1, W - K + 1)  = (1, 2, 6, 6)
@tvm.script.ir_module
class MyConv:
  @T.prim_func
  def conv(data: T.Buffer[(1, 1, 8, 8), "int64"],
           weight: T.Buffer[(2, 1, 3, 3), "int64"],
           output: T.Buffer[(1, 2, 6, 6), "int64"]):
    T.func_attr({"global_symbol": "conv", "tir.noalias": True})
    for n, co, oh, ow, ci, kh, kw in T.grid(1, 2, 6, 6, 1, 3, 3):
      with T.block("C"):
        # Four spatial axes (batch, out-channel, out-row, out-col) followed by
        # three reduction axes (in-channel, kernel-row, kernel-col).
        vn, vco, voh, vow, vci, vkh, vkw = T.axis.remap(
            "SSSSRRR", [n, co, oh, ow, ci, kh, kw])
        with T.init():
          output[vn, vco, voh, vow] = T.int64(0)
        output[vn, vco, voh, vow] = output[vn, vco, voh, vow] + weight[vco, vci, vkh, vkw] * data[vn, vci, voh + vkh, vow + vkw]


# Build for the CPU, run on the NumPy inputs and compare with the PyTorch reference.
rt_lib = tvm.build(MyConv, target="llvm")
data_tvm, weight_tvm = tvm.nd.array(data), tvm.nd.array(weight)
conv_tvm = tvm.nd.array(np.empty(shape=(N, CO, OUT_H, OUT_W), dtype=np.int64))
rt_lib["conv"](data_tvm, weight_tvm, conv_tvm)
np.testing.assert_allclose(conv_tvm.numpy(), conv_torch, rtol=1e-5)

## 练习3: 变化批量矩阵乘法程序

In [74]:
# Batched matrix multiply followed by ReLU:
#   Y[n] = A[n] @ B[n],  C = max(Y, 0), for 16 batches of 128x128 float32 matrices.
@tvm.script.ir_module
class MyBmmRelu:
  @T.prim_func
  def bmm_relu(A: T.Buffer[(16, 128, 128), "float32"],
               B: T.Buffer[(16, 128, 128), "float32"],
               C: T.Buffer[(16, 128, 128), "float32"]):
    T.func_attr({"global_symbol": "bmm_relu", "tir.noalias": True})
    # Intermediate buffer holding the matmul result before the ReLU.
    Y = T.alloc_buffer((16, 128, 128), dtype="float32")

    # Stage 1: batched matmul; the last loop is the reduction axis.
    for batch, row, col, red in T.grid(16, 128, 128, 128):
      with T.block("Y"):
        vb, vr, vc, vk = T.axis.remap("SSSR", [batch, row, col, red])
        with T.init():
          Y[vb, vr, vc] = T.float32(0)
        Y[vb, vr, vc] = Y[vb, vr, vc] + A[vb, vr, vk] * B[vb, vk, vc]

    # Stage 2: element-wise ReLU.
    for batch, row, col in T.grid(16, 128, 128):
      with T.block("C"):
        vb, vr, vc = T.axis.remap("SSS", [batch, row, col])
        C[vb, vr, vc] = T.max(Y[vb, vr, vc], T.float32(0))


        



In [98]:
# Schedule the `bmm_relu` function of MyBmmRelu step by step.
sch = tvm.tir.Schedule(MyBmmRelu)
# The primitives below depend on the order in which they are applied.
# Hints: you can use
# `IPython.display.Code(sch.mod.script(), language="python")`
# or `print(sch.mod.script())`
# to show the current program at any time during the transformation.

# Step 1. Get the matmul block
Y = sch.get_block("Y", func_name="bmm_relu")

# Step 2. Get its loops and parallelize the outermost one (the batch loop)
n, i, j, k = sch.get_loops(Y)
sch.parallel(n)

# Step 3. Organize the loops: split j with an inner factor of 8, move the
# reduction loop k between the two halves, then attach the ReLU block C at j0.
j0, j1 = sch.split(j, factors=[None, 8])
sch.reorder(j0, k, j1)
C = sch.get_block("C", func_name="bmm_relu")
sch.reverse_compute_at(C, j0)

# Step 4. Decompose the reduction at loop k into an init block and the
# "Y_update" block, then re-fetch the loops of each block.
Y_init = sch.decompose_reduction(Y, k)
# Only the innermost loop of the init block (j_1_init) is used below.
n, i, j_0, j_1_init = sch.get_loops(Y_init)
# Innermost loop of the ReLU block.
_, _, _, ax0 = sch.get_loops(C)
Y_update = sch.get_block("Y_update", func_name="bmm_relu")
_, _, _, k, j_1 = sch.get_loops(Y_update)
# Split the reduction loop with an inner factor of 4; k1 is unrolled below.
k0, k1 = sch.split(k, factors=[None, 4])

# Step 5. vectorize / parallel / unroll
sch.vectorize(j_1_init)
sch.vectorize(ax0)
sch.unroll(k1)


# Show the transformed program.
IPython.display.Code(sch.mod.script(), language="python")

In [99]:
# Build the original and the scheduled module for the CPU.
before_rt_lib = tvm.build(MyBmmRelu, target="llvm")
after_rt_lib = tvm.build(sch.mod, target="llvm")
a_tvm = tvm.nd.array(np.random.rand(16, 128, 128).astype("float32"))
b_tvm = tvm.nd.array(np.random.rand(16, 128, 128).astype("float32"))
c_tvm = tvm.nd.array(np.random.rand(16, 128, 128).astype("float32"))

# Correctness check: the scheduled kernel must produce the same result as the
# untransformed one before the timings are compared.
c_ref_tvm = tvm.nd.array(np.empty((16, 128, 128), dtype="float32"))
before_rt_lib["bmm_relu"](a_tvm, b_tvm, c_ref_tvm)
after_rt_lib["bmm_relu"](a_tvm, b_tvm, c_tvm)
np.testing.assert_allclose(c_tvm.numpy(), c_ref_tvm.numpy(), rtol=1e-5)

# Time both builds on the same inputs.
before_timer = before_rt_lib.time_evaluator("bmm_relu", tvm.cpu())
print("Before transformation:")
print(before_timer(a_tvm, b_tvm, c_tvm))

f_timer = after_rt_lib.time_evaluator("bmm_relu", tvm.cpu())
print("After transformation:")
print(f_timer(a_tvm, b_tvm, c_tvm))

Before transformation:
Execution time summary:
 mean (ms)   median (ms)    max (ms)     min (ms)     std (ms)  
  53.1903      53.1903      53.1903      53.1903       0.0000   
               
After transformation:
Execution time summary:
 mean (ms)   median (ms)    max (ms)     min (ms)     std (ms)  
  15.2437      15.2437      15.2437      15.2437       0.0000   
               
