<a href="https://colab.research.google.com/github/yingtongxiong/MLCnotebooks/blob/xytExercise/2_4_TensorIR.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!python3 -m  pip install mlc-ai-nightly -f https://mlc.ai/wheels

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in links: https://mlc.ai/wheels


In [None]:
import numpy as np
import tvm
from tvm.ir.module import IRModule
from tvm.script import tir as T

In [None]:
# Seed the global NumPy RNG so the random inputs (and everything derived from
# them) are identical on every Restart & Run All.
np.random.seed(0)
dtype = "float32"
a_np = np.random.rand(128, 128).astype(dtype)
b_np = np.random.rand(128, 128).astype(dtype)
# Reference result: matrix multiplication followed by ReLU.
# a @ b is equal to np.matmul(a, b)
c_mm_relu = np.maximum(a_np @ b_np, 0)

In [None]:
def lnumpy_mm_relu(A: np.ndarray, B: np.ndarray, C: np.ndarray):
  """Loop-level NumPy reference for ``C = relu(A @ B)``.

  The loop extents are taken from the operand shapes instead of being fixed
  at 128, so the function works for any compatible 2-D inputs.

  Args:
    A: left operand of shape (n, r).
    B: right operand of shape (r, m).
    C: output array, written in place; indexed as C[i, j] for i < n, j < m.

  Raises:
    ValueError: if the inner dimensions of A and B differ.

  Note:
    The intermediate Y is only initialized inside the reduction loop (at
    k == 0), so the reduction extent r must be at least 1.
  """
  n_rows, n_red = A.shape
  n_red_b, n_cols = B.shape
  if n_red != n_red_b:
    raise ValueError(
        "inner dimensions differ: A is %s, B is %s" % (A.shape, B.shape))
  # Intermediate matmul result, kept in float32 as in the original.
  Y = np.empty((n_rows, n_cols), dtype="float32")
  for i in range(n_rows):
    for j in range(n_cols):
      for k in range(n_red):
        # Initialize the accumulator on the first reduction step; this is the
        # loop-level counterpart of the T.init() block in MyModule.mm_relu.
        if k == 0:
          Y[i, j] = 0
        Y[i, j] = Y[i, j] + A[i, k] * B[k, j]
  # ReLU: clamp negative entries to zero.
  for i in range(n_rows):
    for j in range(n_cols):
      C[i, j] = max(Y[i, j], 0)

In [None]:
# Run the loop-level reference into a fresh output array and compare it with
# the vectorized NumPy result.
c_np = np.empty(shape=(128, 128), dtype=dtype)
lnumpy_mm_relu(A=a_np, B=b_np, C=c_np)
np.testing.assert_allclose(actual=c_mm_relu, desired=c_np, rtol=1e-5)

In [None]:
@tvm.script.ir_module
class MyModule:
  # TensorIR version of lnumpy_mm_relu: a 128x128 matmul into the intermediate
  # buffer Y, followed by an element-wise ReLU into C.
  @T.prim_func
  def mm_relu(A: T.Buffer[(128, 128), "float32"],
              B: T.Buffer[(128, 128), "float32"],
              C: T.Buffer[(128, 128), "float32"]):
    T.func_attr({"global_symbol": "mm_relu", "tir.noalias": True})
    # Intermediate buffer holding the matmul result before ReLU.
    Y = T.alloc_buffer((128, 128), dtype="float32")
    for i, j, k in T.grid(128, 128, 128):
      with T.block("Y"):
        # "SSR": i and j bind to spatial axes, k binds to the reduction axis.
        vi, vj, vk = T.axis.remap("SSR", [i, j, k])
        with T.init():
          Y[vi, vj] = T.float32(0)
        Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
    for i, j in T.grid(128, 128):
      with T.block("C"):
        # "SS": both loops bind to spatial axes.
        vi, vj = T.axis.remap("SS", [i, j])
        C[vi, vj] = T.max(Y[vi, vj], T.float32(0))

In [None]:
def lnumpy_mm_relu_v2(A: np.ndarray, B: np.ndarray, C: np.ndarray, jfactor: int = 4):
  """Loop-level ``C = relu(A @ B)`` with the j loop split and reordered.

  The column loop j is split into an outer loop j0 and an inner loop j1 of
  extent ``jfactor``, and the reduction loop k is placed between them
  (loop order i, j0, k, j1). The result is the same as lnumpy_mm_relu.

  Args:
    A: left operand of shape (n, r).
    B: right operand of shape (r, m).
    C: output array, written in place; indexed as C[i, j] for i < n, j < m.
    jfactor: extent of the inner loop j1; must be positive and divide m.
      Defaults to 4, the factor the original hard-coded.

  Raises:
    ValueError: if the inner dimensions of A and B differ, or if jfactor is
      not a positive divisor of the number of columns of B.

  Note:
    Y is only initialized inside the reduction loop (at k == 0), so the
    reduction extent r must be at least 1.
  """
  n_rows, n_red = A.shape
  n_red_b, n_cols = B.shape
  if n_red != n_red_b:
    raise ValueError(
        "inner dimensions differ: A is %s, B is %s" % (A.shape, B.shape))
  if jfactor <= 0 or n_cols % jfactor != 0:
    raise ValueError(
        "jfactor=%r must be a positive divisor of %d" % (jfactor, n_cols))
  Y = np.empty((n_rows, n_cols), dtype="float32")
  for i in range(n_rows):
    for j0 in range(n_cols // jfactor):
      for k in range(n_red):
        for j1 in range(jfactor):
          # Recover the original column index from the split loops.
          j = j0 * jfactor + j1
          if k == 0:
            Y[i, j] = 0
          Y[i, j] = Y[i, j] + A[i, k] * B[k, j]

  # ReLU: clamp negative entries to zero.
  for i in range(n_rows):
    for j in range(n_cols):
      C[i, j] = max(Y[i, j], 0)

# Run the split-and-reordered reference into a fresh output array and confirm
# it still matches the vectorized NumPy result.
c_np = np.empty(shape=(128, 128), dtype=dtype)
lnumpy_mm_relu_v2(A=a_np, B=b_np, C=c_np)
np.testing.assert_allclose(actual=c_mm_relu, desired=c_np, rtol=1e-5)

In [None]:
import IPython
# Render the TVMScript text of MyModule with Python syntax highlighting.
IPython.display.Code(MyModule.script(), language="python")

In [None]:
# Create a schedule over MyModule; the cells below transform mm_relu through
# it and read the transformed module back from sch.mod.
sch = tvm.tir.Schedule(MyModule)

In [None]:
# Look up block "Y" of mm_relu and the three loops around it, unpacked as i, j, k.
block_Y = sch.get_block("Y", func_name="mm_relu")
i, j, k = sch.get_loops(block_Y)

In [None]:
# Split loop j into an outer loop j0 and an inner loop j1 of extent 4;
# None leaves the outer extent for the schedule to infer.
j0, j1 = sch.split(j, factors=[None, 4])

In [None]:
# Show the module as it stands after the split.
IPython.display.Code(sch.mod.script(), language="python")

In [None]:
# Reorder the loops to j0, k, j1 -- the same nesting as lnumpy_mm_relu_v2.
sch.reorder(j0, k, j1)
IPython.display.Code(sch.mod.script(), language="python")

In [None]:
# Move the ReLU block "C" under loop j0 of the matmul loop nest.
block_C = sch.get_block("C", "mm_relu")
sch.reverse_compute_at(block_C, j0)
IPython.display.Code(sch.mod.script(), language="python")

In [None]:
# Separate the initialization of Y from its reduction update, with loop k
# as the position for the init step.
sch.decompose_reduction(block_Y, k)
IPython.display.Code(sch.mod.script(), language="python")

In [None]:
# Compile the original, unscheduled MyModule with the LLVM target.
rt_lib = tvm.build(MyModule, target="llvm")

In [None]:
# Wrap the NumPy inputs as TVM arrays and allocate an empty 128x128 float32
# array for the output.
a_nd = tvm.nd.array(a_np)
b_nd = tvm.nd.array(b_np)
c_nd = tvm.nd.empty((128, 128), dtype="float32")
# Display the runtime type of the output array.
type(c_nd)

tvm.runtime.ndarray.NDArray

In [None]:
# Fetch the compiled mm_relu function, run it (the result is written into
# c_nd) and compare against the NumPy reference.
func_mm_relu = rt_lib["mm_relu"]
func_mm_relu(a_nd, b_nd, c_nd)
np.testing.assert_allclose(c_mm_relu, c_nd.numpy(), rtol=1e-5)

In [None]:
# Build the scheduled module and check that it still computes relu(A @ B).
rt_lib_after = tvm.build(sch.mod, target="llvm")
# Use a fresh zero-filled output buffer: c_nd already holds the correct result
# from the untransformed run, so reusing it would let the assertion pass even
# if the transformed kernel wrote nothing.
c_nd_after = tvm.nd.array(np.zeros((128, 128), dtype="float32"))
rt_lib_after["mm_relu"](a_nd, b_nd, c_nd_after)
np.testing.assert_allclose(c_mm_relu, c_nd_after.numpy(), rtol=1e-5)

In [None]:
# Time mm_relu on the CPU device for the original build (rt_lib) and for the
# scheduled build (rt_lib_after); .mean is the mean of the measured run times.
f_timer_before = rt_lib.time_evaluator("mm_relu", tvm.cpu())
print("Time cost of MyModule %g sec" % f_timer_before(a_nd, b_nd, c_nd).mean)
f_timer_after = rt_lib_after.time_evaluator("mm_relu", tvm.cpu())
print("Time cost of transformed sch.mod %g sec" % f_timer_after(a_nd, b_nd, c_nd).mean)

Time cost of MyModule 0.00515859 sec
Time cost of transformed sch.mod 0.000732008 sec


In [None]:
def transform(mod, jfactor):
  """Apply the split / reorder / reverse_compute_at schedule to ``mm_relu``.

  Args:
    mod: module containing an ``mm_relu`` function with blocks "Y" and "C".
    jfactor: extent of the inner loop produced by splitting the j loop.

  Returns:
    The module held by the schedule after the transformations.
  """
  schedule = tvm.tir.Schedule(mod)
  matmul_block = schedule.get_block("Y", func_name="mm_relu")
  # Only the j and k loops are manipulated; the outermost loop stays as is.
  _, loop_j, loop_k = schedule.get_loops(matmul_block)
  j_outer, j_inner = schedule.split(loop_j, factors=[None, jfactor])
  schedule.reorder(j_outer, loop_k, j_inner)
  relu_block = schedule.get_block("C", "mm_relu")
  schedule.reverse_compute_at(relu_block, j_outer)
  return schedule.mod

# Re-apply the same schedule to MyModule with a split factor of 8.
mod_transformed = transform(MyModule, jfactor=8)

# Build the result with the LLVM target and report its mean run time on the CPU device.
rt_lib_transformed = tvm.build(mod_transformed, "llvm")
f_timer_transformed = rt_lib_transformed.time_evaluator("mm_relu", tvm.cpu())
print("Time cost of transformed mod_transformed %g sec" % f_timer_transformed(a_nd, b_nd, c_nd).mean)
# display the code below
IPython.display.Code(mod_transformed.script(), language="python")

Time cost of transformed mod_transformed 0.000651081 sec


In [None]:
# Use tensor expressions (TE) to generate TensorIR code

In [None]:
from tvm import te

# Describe relu(A @ B) as a tensor expression over two 128x128 float32 inputs.
A = te.placeholder((128, 128), "float32", name="A")
B = te.placeholder((128, 128), "float32", name="B")
k = te.reduce_axis((0, 128), "k")
# Matmul: Y[i, j] = sum over k of A[i, k] * B[k, j]. A must be indexed with the
# reduction axis k; the previous A[i, j] did not compute a matrix product.
Y = te.compute((128, 128), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name="Y")
# ReLU applied element-wise to the matmul result.
C = te.compute((128, 128), lambda i, j: te.max(Y[i, j], 0), name="C")

In [None]:
# Turn the tensor expression into a TensorIR function named mm_relu and wrap
# it in a module.
te_func = te.create_prim_func([A, B, C]).with_attr({"global_symbol": "mm_relu"})
MyModuleFromTE = tvm.IRModule({"mm_relu": te_func})
# Build and time the TE-generated module itself. The previous line re-used
# f_timer_transformed, so it measured the earlier scheduled module and
# reported it under the wrong label.
rt_lib_from_te = tvm.build(MyModuleFromTE, target="llvm")
f_timer_from_te = rt_lib_from_te.time_evaluator("mm_relu", tvm.cpu())
print("Time cost of MyModuleFromTE %g sec" % f_timer_from_te(a_nd, b_nd, c_nd).mean)
# display the code below
IPython.display.Code(MyModuleFromTE.script(), language="python")

Time cost of transformed mod_transformed 0.00135466 sec
