<a href="https://colab.research.google.com/github/quxiaojing1985/OpenSearch/blob/main/Untitled3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Install the mlc-ai-nightly package; -f makes pip also look for wheels at https://mlc.ai/wheels.
!python3 -m  pip install mlc-ai-nightly -f https://mlc.ai/wheels

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in links: https://mlc.ai/wheels
Collecting mlc-ai-nightly
  Downloading https://github.com/mlc-ai/utils/releases/download/v0.9.dev0/mlc_ai_nightly-0.9.dev1956%2Bge3f218d71-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (44.2 MB)
[K     |████████████████████████████████| 44.2 MB 657 kB/s 
Collecting synr==0.6.0
  Downloading synr-0.6.0-py3-none-any.whl (18 kB)
Installing collected packages: synr, mlc-ai-nightly
Successfully installed mlc-ai-nightly-0.9.dev1956+ge3f218d71 synr-0.6.0


In [2]:
import tvm
from tvm.ir.module import IRModule
from tvm.script import tir as T
import numpy as np
import IPython

In [3]:
# TVMScript module holding a single primitive function, `add`, which writes the
# element-wise sum of two 4x4 int64 buffers into a third.
@tvm.script.ir_module
class MyAdd:
  @T.prim_func
  def add(A: T.Buffer[(4, 4), "int64"],
          B: T.Buffer[(4, 4), "int64"],
          C: T.Buffer[(4, 4), "int64"]):
    # "global_symbol" is the name the function is registered under; "tir.noalias"
    # is set to True (presumably declaring that A, B and C do not overlap -- confirm in TVM docs).
    T.func_attr({"global_symbol": "add", "tir.noalias" : True})
    for i, j in T.grid(4, 4):
      # Block "C" is the unit that the schedule below looks up by name.
      with T.block("C"):
        # Bind both loop variables to spatial block axes of extent 4.
        vi = T.axis.spatial(4, i)
        vj = T.axis.spatial(4, j)
        C[vi, vj] = A[vi, vj] + B[vi, vj]

# Create a schedule over MyAdd and look up block "C" together with its two loops.
sch = tvm.tir.Schedule(MyAdd)
block = sch.get_block(name="C", func_name="add")
i, j = sch.get_loops(block=block)

# Split the outer loop with factors [2, 2], then mark the resulting loops as
# parallel / unrolled and the inner j loop as vectorized.
i0, i1 = sch.split(loop=i, factors=[2, 2])
sch.parallel(loop=i0)
sch.unroll(loop=i1)
sch.vectorize(loop=j)


In [4]:
import IPython

def code2html(code):
    """Render ``code`` as pygments-highlighted HTML with its stylesheet inlined.

    The input string is lexed with pygments' Python 3 lexer. The returned
    markup is a ``<style>`` element carrying the formatter's ``.highlight``
    CSS rules, followed by the highlighted HTML and a trailing newline.
    """
    from pygments import highlight
    from pygments.formatters import HtmlFormatter
    from pygments.lexers import Python3Lexer

    html_formatter = HtmlFormatter()
    highlighted = highlight(code, Python3Lexer(), html_formatter)
    stylesheet = html_formatter.get_style_defs(".highlight")
    return "<style>%s</style>%s\n" % (stylesheet, highlighted)

In [5]:
# Display the scheduled module's TVMScript text as highlighted HTML (the cell's last expression is the output).
IPython.display.HTML(code2html(sch.mod.script()))

2.5.2.2 练习 3：变换批量矩阵乘法程序

In [16]:
def lnumpy_mm_relu_v2(A: np.ndarray, B: np.ndarray, C: np.ndarray):
  """Low-level NumPy reference for a batched matrix multiply followed by ReLU.

  For every batch index ``n`` this computes ``C[n] = max(A[n] @ B[n], 0)`` with
  explicit scalar loops, so the structure lines up with the TensorIR version
  of the same computation. The loop bounds are taken from the operand shapes
  rather than being fixed to (16, 128, 128).

  Args:
    A: left operand of shape (batch, rows, inner).
    B: right operand of shape (batch, inner, cols).
    C: output array of shape (batch, rows, cols); it is overwritten in place.

  Returns:
    None. The result is stored in ``C``.

  Raises:
    ValueError: if the inputs are not 3-D or their shapes do not agree.
  """
  if A.ndim != 3 or B.ndim != 3 or C.ndim != 3:
    raise ValueError(
        "expected 3-D arrays, got ndim %d, %d and %d" % (A.ndim, B.ndim, C.ndim))
  batch, rows, inner = A.shape
  cols = B.shape[2]
  if B.shape != (batch, inner, cols) or C.shape != (batch, rows, cols):
    raise ValueError(
        "incompatible shapes: A %s, B %s, C %s" % (A.shape, B.shape, C.shape))

  # Intermediate accumulator; float32 as in the original reference code.
  Y = np.empty((batch, rows, cols), dtype="float32")
  for n in range(batch):
    for i in range(rows):
      for j in range(cols):
        # Zero the accumulator before reducing over k. Doing it outside the k
        # loop keeps Y defined even when the reduction extent is 0.
        Y[n, i, j] = 0
        for k in range(inner):
          Y[n, i, j] = Y[n, i, j] + A[n, i, k] * B[n, k, j]
  # ReLU pass: clamp negative sums to zero while copying into the output.
  for n in range(batch):
    for i in range(rows):
      for j in range(cols):
        C[n, i, j] = max(Y[n, i, j], 0)

TVM 版本

In [7]:
# TVMScript module with one primitive function, `bmm_relu`: a batched matrix
# multiply over (16, 128, 128) float32 buffers followed by an element-wise
# max with zero.
@tvm.script.ir_module
class MyBmmRelu:
  @T.prim_func
  def bmm_relu(A: T.Buffer[(16, 128, 128), "float32"],
               B: T.Buffer[(16, 128, 128), "float32"],
               C: T.Buffer[(16, 128, 128), "float32"]):
    T.func_attr({"global_symbol" : "bmm_relu", "tir.noalias" : True})
    # Intermediate buffer holding the matmul result before the max is applied.
    Y = T.alloc_buffer([16, 128,128], dtype= "float32")
    for n, i, j, k in T.grid(16, 128, 128, 128):
      with T.block("Y"):
        # "SSSR": n, i, j are bound as spatial axes and k as the reduction axis.
        vn, vi, vj, vk = T.axis.remap("SSSR", [n, i, j, k])
        # Initial value of the accumulator for the reduction.
        with T.init():
          Y[vn, vi, vj] = T.float32(0)
        Y[vn, vi, vj] = Y[vn, vi, vj] + A[vn, vi, vk]* B[vn, vk, vj]
    for n, i, j in T.grid(16, 128, 128):
      with T.block("C"):
        vn, vi, vj = T.axis.remap("SSS", [n, i, j])
        # ReLU: keep Y where it is positive, otherwise write 0.
        C[vn, vi, vj] = T.max(Y[vn, vi, vj], T.float32(0))
         




添加测试程序，验证 TVM 版本与 NumPy 版本的计算结果是否一致。

In [19]:
# Use a seeded generator so the comparison inputs are identical on every run.
rng = np.random.default_rng(0)
a_np = rng.random((16, 128, 128)).astype("float32")
b_np = rng.random((16, 128, 128)).astype("float32")
# Output buffer for the NumPy reference; the call below fills it in place.
c_np = np.empty((16, 128, 128), dtype=np.float32)
lnumpy_mm_relu_v2(a_np, b_np, c_np)


In [23]:
# Wrap the NumPy operands as TVM arrays and allocate an output array of the same shape.
a_tvm = tvm.nd.array(a_np)
b_tvm = tvm.nd.array(b_np)
c_tvm = tvm.nd.array(np.empty((16, 128, 128), dtype=np.float32))

# Build the unscheduled module for the "llvm" target and run bmm_relu once.
rt_lib = tvm.build(MyBmmRelu, target="llvm")
rt_lib["bmm_relu"](a_tvm, b_tvm, c_tvm)

# The TVM result must match the NumPy reference within a relative tolerance of 1e-5.
np.testing.assert_allclose(actual=c_tvm.numpy(), desired=c_np, rtol=1e-5)

目标程序如下：

In [61]:
# Target form of bmm_relu: the schedule built from MyBmmRelu is compared
# against this module with tvm.ir.assert_structural_equal.
@tvm.script.ir_module
class TargetModule:
    @T.prim_func
    def bmm_relu(A: T.Buffer[(16, 128, 128), "float32"], B: T.Buffer[(16, 128, 128), "float32"], C: T.Buffer[(16, 128, 128), "float32"]) -> None:
        T.func_attr({"global_symbol": "bmm_relu", "tir.noalias": True})
        Y = T.alloc_buffer([16, 128, 128], dtype="float32")
        # Outermost (batch) loop is marked parallel.
        for i0 in T.parallel(16):
            # i1 walks the 128 rows; i2_0 walks 16 column tiles of width 8.
            for i1, i2_0 in T.grid(128, 16):
                # Vectorized zero-initialization of one 8-wide tile of Y.
                for ax0_init in T.vectorized(8):
                    with T.block("Y_init"):
                        n, i = T.axis.remap("SS", [i0, i1])
                        j = T.axis.spatial(128, i2_0 * 8 + ax0_init)
                        Y[n, i, j] = T.float32(0)
                # Reduction over k, split into 32 serial steps of 4 unrolled iterations.
                for ax1_0 in T.serial(32):
                    for ax1_1 in T.unroll(4):
                        for ax0 in T.serial(8):
                            with T.block("Y_update"):
                                n, i = T.axis.remap("SS", [i0, i1])
                                j = T.axis.spatial(128, i2_0 * 8 + ax0)
                                k = T.axis.reduce(128, ax1_0 * 4 + ax1_1)
                                Y[n, i, j] = Y[n, i, j] + A[n, i, k] * B[n, k, j]
                # Vectorized max-with-zero over the same 8-wide tile.
                for i2_1 in T.vectorized(8):
                    with T.block("C"):
                        n, i = T.axis.remap("SS", [i0, i1])
                        j = T.axis.spatial(128, i2_0 * 8 + i2_1)
                        C[n, i, j] = T.max(Y[n, i, j], T.float32(0))

变换的过程

In [75]:
# Start from the unscheduled module; each primitive below rewrites sch.mod.
sch = tvm.tir.Schedule(MyBmmRelu)

# Step 1. Look up the reduction block "Y" and its four loops.
Y = sch.get_block(name="Y", func_name="bmm_relu")
b, i, j, k = sch.get_loops(block=Y)

# Step 2. Split j into tiles of 8 and k into tiles of 4, then order the
# resulting loops as k0, k1, j1.
j0, j1 = sch.split(loop=j, factors=[None, 8])
k0, k1 = sch.split(loop=k, factors=[None, 4])
sch.reorder(k0, k1, j1)

# Step 3. Parallelize the outermost loop and decompose the reduction at k0,
# which produces the separate "Y_init" block fetched below.
sch.parallel(loop=b)
sch.decompose_reduction(block=Y, loop=k0)

# Step 4. Move block "C" under the j0 loop.
block_C = sch.get_block(name="C", func_name="bmm_relu")
sch.reverse_compute_at(block=block_C, loop=j0)

# Step 5. Vectorize the innermost loop of the init block.
Y_init = sch.get_block(name="Y_init", func_name="bmm_relu")
b, i, j0, j1 = sch.get_loops(block=Y_init)
sch.vectorize(loop=j1)

# Step 6. Vectorize the innermost loop of block "C" and unroll the k1 loop.
C = sch.get_block(name="C", func_name="bmm_relu")
b, i, j0, j1 = sch.get_loops(block=C)
sch.vectorize(loop=j1)
sch.unroll(loop=k1)

# The scheduled module must be structurally identical to the target.
tvm.ir.assert_structural_equal(sch.mod, TargetModule)
print("Pass")




Pass


In [73]:
# Display the transformed bmm_relu module's TVMScript text as highlighted HTML.
IPython.display.HTML(code2html(sch.mod.script()))

构建和评估

In [79]:
# Compile the original module and the scheduled module for the same target.
before_rt_lib = tvm.build(MyBmmRelu, target="llvm")
after_rt_lib = tvm.build(sch.mod, target="llvm")

# Random float32 operands plus an output array shared by both timing runs.
a_tvm = tvm.nd.array(np.random.rand(16, 128, 128).astype("float32"))
b_tvm = tvm.nd.array(np.random.rand(16, 128, 128).astype("float32"))
c_tvm = tvm.nd.array(np.empty((16, 128, 128), dtype=np.float32))

# One timer per build, both measuring bmm_relu on the CPU device.
before_timer = before_rt_lib.time_evaluator("bmm_relu", tvm.cpu())
f_timer = after_rt_lib.time_evaluator("bmm_relu", tvm.cpu())

print("Before transformation:")
print(before_timer(a_tvm, b_tvm, c_tvm))

print("After transformation:")
print(f_timer(a_tvm, b_tvm, c_tvm))


Before transformation:
Execution time summary:
 mean (ms)   median (ms)    max (ms)     min (ms)     std (ms)  
  74.1253      74.1253      74.1253      74.1253       0.0000   
               
After transformation:
Execution time summary:
 mean (ms)   median (ms)    max (ms)     min (ms)     std (ms)  
  14.7466      14.7466      14.7466      14.7466       0.0000   
               


In [77]:
# Print the schedule's trace: the list of primitives applied to sch so far.
print(sch.trace)

b0 = sch.get_block(name="Y", func_name="bmm_relu")
l1, l2, l3, l4 = sch.get_loops(block=b0)
l5, l6 = sch.split(loop=l3, factors=[None, 8], preserve_unit_iters=True)
l7, l8 = sch.split(loop=l4, factors=[None, 4], preserve_unit_iters=True)
sch.reorder(l7, l8, l6)
sch.parallel(loop=l1)
b9 = sch.decompose_reduction(block=b0, loop=l7)
b10 = sch.get_block(name="C", func_name="bmm_relu")
sch.reverse_compute_at(block=b10, loop=l5, preserve_unit_loops=False)
b11 = sch.get_block(name="Y_init", func_name="bmm_relu")
l12, l13, l14, l15 = sch.get_loops(block=b11)
sch.vectorize(loop=l15)
b16 = sch.get_block(name="C", func_name="bmm_relu")
l17, l18, l19, l20 = sch.get_loops(block=b16)
sch.vectorize(loop=l20)
sch.unroll(loop=l8)
