<a href="https://colab.research.google.com/github/yingtongxiong/MLCnotebooks/blob/xytExercise/7_1_GPU_acceleration.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Install the mlc-ai-nightly-cu110 wheel, using https://mlc.ai/wheels as an extra find-links index.
!python3 -m pip install mlc-ai-nightly-cu110 -f https://mlc.ai/wheels

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in links: https://mlc.ai/wheels
Collecting mlc-ai-nightly-cu110
  Downloading https://github.com/mlc-ai/utils/releases/download/v0.9.dev0/mlc_ai_nightly_cu110-0.9.dev2227%2Bg06c0ef3be-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (238.8 MB)
[K     |████████████████████████████████| 238.8 MB 22 kB/s 
Collecting synr==0.6.0
  Downloading synr-0.6.0-py3-none-any.whl (18 kB)
Installing collected packages: synr, mlc-ai-nightly-cu110
Successfully installed mlc-ai-nightly-cu110-0.9.dev2227+g06c0ef3be synr-0.6.0


In [2]:
# This is needed for deferring annotation parsing in TVMScript
from __future__ import annotations
import numpy as np
import tvm
from tvm import relax
from tvm.ir.module import IRModule
from tvm.script import relax as R
from tvm.script import tir as T

In [3]:
# TVMScript module: element-wise addition of two float32 vectors of length 1024,
# C[i] = A[i] + B[i].
@tvm.script.ir_module
class MyModuleVecAdd:
  @T.prim_func
  def main(A: T.Buffer[(1024,), "float32"],
           B: T.Buffer[(1024,), "float32"],
           C: T.Buffer[(1024,), "float32"]) -> None:
      # Attach the function attributes: global symbol "main" and tir.noalias = True.
      T.func_attr({"global_symbol": "main", "tir.noalias": True})
      for i in T.grid(1024):
          # The block name "C" is what the schedule cells look up with get_block("C").
          with T.block("C"):
              # Bind block axis vi to loop i; "S" declares it a spatial axis.
              vi = T.axis.remap("S", [i])
              C[vi] = A[vi] + B[vi]


In [4]:
# Open a schedule on the vector-add module and tile its single loop.
sch = tvm.tir.Schedule(MyModuleVecAdd)
block_C = sch.get_block(name="C")
(i,) = sch.get_loops(block_C)
# The inner loop gets extent 128; the None factor is left for TVM to infer.
i0, i1 = sch.split(loop=i, factors=[None, 128])
sch.mod.show()

In [5]:
# Outer loop -> block index, inner loop -> thread index inside a block.
sch.bind(loop=i0, thread_axis="blockIdx.x")
sch.bind(loop=i1, thread_axis="threadIdx.x")
sch.mod.show()

In [6]:
# Compile the scheduled vector-add for CUDA.
rt_mod = tvm.build(sch.mod, target="cuda")

# All three buffers live on the same GPU, so create the device handle once.
dev = tvm.cuda(0)
A_np = np.random.uniform(size=(1024,)).astype("float32")
B_np = np.random.uniform(size=(1024,)).astype("float32")
A_nd = tvm.nd.array(A_np, dev)
B_nd = tvm.nd.array(B_np, dev)
C_nd = tvm.nd.array(np.zeros((1024,), dtype="float32"), dev)

rt_mod["main"](A_nd, B_nd, C_nd)
# Check the GPU result against a NumPy reference rather than relying on
# eyeballing the truncated printout below.
np.testing.assert_allclose(C_nd.numpy(), A_np + B_np, rtol=1e-5)
print(A_nd)
print(B_nd)
print(C_nd)

[0.04366452 0.48133028 0.6060154  ... 0.11618335 0.9300148  0.00836957]
[0.74900264 0.8771537  0.9234607  ... 0.01393905 0.87135917 0.793017  ]
[0.79266715 1.358484   1.5294762  ... 0.1301224  1.801374   0.8013866 ]


## 示例：窗口求和

In [15]:
# TVMScript module: 3-element sliding-window sum over a float32 vector,
# B[i] = A[i] + A[i + 1] + A[i + 2] for i in [0, 1024).
@tvm.script.ir_module
class MyModuleWindowSum:
  @T.prim_func
  def main(A: T.Buffer[(1027,), "float32"],
           B: T.Buffer[(1024,), "float32"]) -> None:
      T.func_attr({"global_symbol": "main", "tir.noalias": True})
      for i in T.grid(1024):
        # The block name "C" is what the schedule cells look up with get_block("C").
        with T.block("C"):
          vi = T.axis.remap("S", [i])
          # Highest index read is 1023 + 2 = 1025, inside A's declared extent of 1027.
          B[vi] = A[vi] + A[vi + 1] + A[vi + 2]

In [16]:
# Extent of the thread-bound inner loop; the cooperative-fetch cell reuses it.
nthread = 128
sch = tvm.tir.Schedule(MyModuleWindowSum)
block_C = sch.get_block(name="C")
(i,) = sch.get_loops(block_C)
i0, i1 = sch.split(loop=i, factors=[None, nthread])
# Outer loop -> block index, inner loop -> thread index inside a block.
sch.bind(loop=i0, thread_axis="blockIdx.x")
sch.bind(loop=i1, thread_axis="threadIdx.x")
sch.mod.show()

In [17]:
# Stage block C's read buffer 0 (A, its only input) in "shared" scope and
# place the copy under the thread loop i1.
A_shared = sch.cache_read(block=block_C, read_buffer_index=0, storage_scope="shared")
sch.compute_at(block=A_shared, loop=i1)
sch.mod.show()

In [18]:
# Take the innermost loop of the shared-memory copy, split it by nthread and
# bind the inner part to threadIdx.x so the copy is spread across the threads.
ax = sch.get_loops(block=A_shared)[-1]
ax0, ax1 = sch.split(loop=ax, factors=[None, nthread])
sch.bind(loop=ax1, thread_axis="threadIdx.x")
sch.mod.show()

In [19]:
# Build for CUDA, then print the source held by the first imported module
# (the generated kernel code).
rt_mod = tvm.build(sch.mod, target="cuda")
cuda_kernel_mod = rt_mod.imported_modules[0]
print(cuda_kernel_mod.get_source())


#ifdef _WIN32
  using uint = unsigned int;
  using uchar = unsigned char;
  using ushort = unsigned short;
  using int64_t = long long;
  using uint64_t = unsigned long long;
#else
  #define uint unsigned int
  #define uchar unsigned char
  #define ushort unsigned short
  #define int64_t long long
  #define uint64_t unsigned long long
#endif
extern "C" __global__ void __launch_bounds__(128) main_kernel0(float* __restrict__ A, float* __restrict__ B) {
  __shared__ float A_shared[130];
  for (int ax0_0 = 0; ax0_0 < 2; ++ax0_0) {
    if (((ax0_0 * 64) + (((int)threadIdx.x) >> 1)) < 65) {
      A_shared[((ax0_0 * 128) + ((int)threadIdx.x))] = A[(((((int)blockIdx.x) * 128) + (ax0_0 * 128)) + ((int)threadIdx.x))];
    }
  }
  __syncthreads();
  B[((((int)blockIdx.x) * 128) + ((int)threadIdx.x))] = ((A_shared[((int)threadIdx.x)] + A_shared[(((int)threadIdx.x) + 1)]) + A_shared[(((int)threadIdx.x) + 2)]);
}




In [20]:
# Build the same schedule for Metal and print the source held by the first
# imported module. Note that this rebinds rt_mod to the Metal build.
rt_mod = tvm.build(sch.mod, target="metal")
metal_kernel_mod = rt_mod.imported_modules[0]
print(metal_kernel_mod.get_source())

// Function: main_kernel0
#include <metal_stdlib>
using namespace metal;

union __TVMArgUnion {
 int v_int[2];
};

kernel void main_kernel0(  device float* A [[ buffer(0) ]],
  device float* B [[ buffer(1) ]],
  uint blockIdx [[threadgroup_position_in_grid]],
  uint threadIdx [[thread_position_in_threadgroup]]
) {
  threadgroup float A_shared[130];
  for (int ax0_0 = 0; ax0_0 < 2; ++ax0_0) {
    if (((ax0_0 * 64) + (((int)threadIdx) >> 1)) < 65) {
      A_shared[((ax0_0 * 128) + ((int)threadIdx))] = A[(((((int)blockIdx) * 128) + (ax0_0 * 128)) + ((int)threadIdx))];
    }
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  B[((((int)blockIdx) * 128) + ((int)threadIdx))] = ((A_shared[((int)threadIdx)] + A_shared[(((int)threadIdx) + 1)]) + A_shared[(((int)threadIdx) + 2)]);
}




## 矩阵乘法

In [21]:
# TVMScript module: 1024 x 1024 float32 matrix multiplication,
# C[i, j] = sum over k of A[i, k] * B[k, j].
@tvm.script.ir_module
class MyModuleMatmul:
  @T.prim_func
  def main(A: T.Buffer[(1024, 1024), "float32"],
           B: T.Buffer[(1024, 1024), "float32"],
           C: T.Buffer[(1024, 1024), "float32"]) -> None:
      T.func_attr({"global_symbol": "main", "tir.noalias": True})
      for i, j, k in T.grid(1024, 1024, 1024):
        # The block name "C" is what the schedule functions look up with get_block("C").
        with T.block("C"):
          # vi and vj are spatial ("S") axes; vk is the reduction ("R") axis.
          vi, vj, vk = T.axis.remap("SSR", [i, j, k])
          # Initial value of the accumulator for each (vi, vj).
          with T.init():
            C[vi, vj] = 0.0
          C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]

In [24]:
def blocking(sch,
             tile_local_y,
             tile_local_x,
             tile_block_y,
             tile_block_x,
             tile_k):
    """Tile block "C" of ``sch`` and bind the tiles to GPU block/thread axes.

    The two leading loops of block "C" are each split three ways
    ``[None, tile_block_*, tile_local_*]`` and the third loop two ways
    ``[None, tile_k]``. The outermost parts are bound to ``blockIdx.y/x``,
    the middle parts to ``threadIdx.y/x``, and the inner reduction part is
    unrolled. The block's write buffer 0 is cached in ``"local"`` scope and
    the write-back is moved under the ``threadIdx.x`` loop.

    Args:
        sch: schedule object exposing the ``tvm.tir.Schedule`` primitives used
            below; it is modified in place.
        tile_local_y: innermost split factor of the first loop.
        tile_local_x: innermost split factor of the second loop.
        tile_block_y: split factor of the first loop bound to ``threadIdx.y``.
        tile_block_x: split factor of the second loop bound to ``threadIdx.x``.
        tile_k: inner split factor of the third (reduction) loop.

    Returns:
        The same ``sch`` object that was passed in.
    """
    matmul_blk = sch.get_block(name="C")
    local_out = sch.cache_write(block=matmul_blk, write_buffer_index=0, storage_scope="local")

    row, col, red = sch.get_loops(matmul_blk)

    # Three-level tiling for rows/columns, two-level for the reduction loop.
    row_grid, row_thread, row_inner = sch.split(row, factors=[None, tile_block_y, tile_local_y])
    col_grid, col_thread, col_inner = sch.split(col, factors=[None, tile_block_x, tile_local_x])
    red_outer, red_inner = sch.split(red, factors=[None, tile_k])

    sch.unroll(loop=red_inner)
    sch.reorder(row_grid, col_grid, row_thread, col_thread,
                red_outer, red_inner, row_inner, col_inner)
    # Place the write-back of the local accumulator under the thread-level column loop.
    sch.reverse_compute_at(block=local_out, loop=col_thread)

    sch.bind(loop=row_grid, thread_axis="blockIdx.y")
    sch.bind(loop=col_grid, thread_axis="blockIdx.x")

    sch.bind(loop=row_thread, thread_axis="threadIdx.y")
    sch.bind(loop=col_thread, thread_axis="threadIdx.x")
    # Decompose the reduction block at the outer reduction loop.
    sch.decompose_reduction(block=matmul_blk, loop=red_outer)

    return sch

# Apply the blocking schedule to a fresh matmul schedule: 8x8 per-thread tiles,
# 8x8 thread-bound tiles, reduction inner factor 4.
sch = tvm.tir.Schedule(MyModuleMatmul)
sch = blocking(
    sch,
    tile_local_y=8,
    tile_local_x=8,
    tile_block_y=8,
    tile_block_x=8,
    tile_k=4,
)
sch.mod.show()

In [25]:
# Compile the blocked schedule for CUDA and benchmark it on GPU 0.
rt_mod = tvm.build(sch.mod, target="cuda")
dev = tvm.cuda(0)

# Random float32 operands and a zeroed output buffer, all placed on the device.
A_np = np.random.uniform(size=(1024, 1024)).astype("float32")
B_np = np.random.uniform(size=(1024, 1024)).astype("float32")
A_nd, B_nd = tvm.nd.array(A_np, dev), tvm.nd.array(B_np, dev)
C_nd = tvm.nd.array(np.zeros((1024, 1024), dtype="float32"), dev)

# One multiply and one add for every (i, j, k) triple.
num_flop = 2 * 1024 ** 3
evaluator = rt_mod.time_evaluator("main", dev, number=10)

mean_seconds = evaluator(A_nd, B_nd, C_nd).mean
print("GEMM-Blocking: %f GFLOPS" % (num_flop / mean_seconds / 1e9))

GEMM-Blocking: 872.594130 GFLOPS


In [26]:
def cache_read_and_coop_fetch(sch, block, nthread, read_idx, read_loc, vec_len=4):
    """Cache one input of ``block`` in shared scope and fetch it cooperatively.

    A ``"shared"``-scope read cache is created for read buffer ``read_idx`` of
    ``block`` and placed under ``read_loc``. The two innermost loops of the
    copy are fused and re-split into ``[None, nthread, vec_len]``; the last
    part is vectorized and the middle part is bound to ``threadIdx.x``.

    Args:
        sch: schedule object exposing the ``tvm.tir.Schedule`` primitives used
            below; it is modified in place.
        block: the consumer block whose input is cached.
        nthread: extent of the loop bound to ``threadIdx.x``.
        read_idx: index of the read buffer of ``block`` to cache.
        read_loc: loop under which the cache copy is computed.
        vec_len: extent of the vectorized innermost loop. Defaults to 4, the
            width that was previously hard-coded.

    Returns:
        None.
    """
    read_cache = sch.cache_read(block=block, read_buffer_index=read_idx, storage_scope="shared")
    sch.compute_at(block=read_cache, loop=read_loc)
    # vectorized cooperative fetch: fuse the two innermost copy loops, then
    # re-split them into (remaining, per-thread, vector lane) parts.
    inner0, inner1 = sch.get_loops(block=read_cache)[-2:]
    inner = sch.fuse(inner0, inner1)
    _, tx, vec = sch.split(loop=inner, factors=[None, nthread, vec_len])
    sch.vectorize(vec)
    sch.bind(tx, "threadIdx.x")


def blocking_with_shared(
    sch,
    tile_local_y,
    tile_local_x,
    tile_block_y,
    tile_block_x,
    tile_k):
    """Tile block "C" and stage both of its inputs through shared-scope caches.

    The loops of block "C" are split and reordered with the given tile
    factors, the outermost row/column parts are bound to ``blockIdx.y/x``,
    and the two thread-level parts are fused into a single loop bound to
    ``threadIdx.x``. Read buffers 0 and 1 of the block are then cached via
    ``cache_read_and_coop_fetch`` under the outer reduction loop.

    Args:
        sch: schedule object exposing the ``tvm.tir.Schedule`` primitives used
            below; it is modified in place.
        tile_local_y: innermost split factor of the first loop.
        tile_local_x: innermost split factor of the second loop.
        tile_block_y: thread-level split factor of the first loop.
        tile_block_x: thread-level split factor of the second loop.
        tile_k: inner split factor of the third (reduction) loop.

    Returns:
        The same ``sch`` object that was passed in.
    """
    matmul_blk = sch.get_block(name="C")
    local_out = sch.cache_write(block=matmul_blk, write_buffer_index=0, storage_scope="local")

    row, col, red = sch.get_loops(matmul_blk)

    row_grid, row_thread, row_inner = sch.split(row, factors=[None, tile_block_y, tile_local_y])
    col_grid, col_thread, col_inner = sch.split(col, factors=[None, tile_block_x, tile_local_x])
    red_outer, red_inner = sch.split(red, factors=[None, tile_k])

    sch.reorder(row_grid, col_grid, row_thread, col_thread,
                red_outer, red_inner, row_inner, col_inner)
    # Place the write-back of the local accumulator under the thread-level column loop.
    sch.reverse_compute_at(block=local_out, loop=col_thread)

    sch.bind(loop=row_grid, thread_axis="blockIdx.y")
    sch.bind(loop=col_grid, thread_axis="blockIdx.x")

    # Both thread-level loops collapse into one loop bound to threadIdx.x.
    fused_thread = sch.fuse(row_thread, col_thread)
    sch.bind(loop=fused_thread, thread_axis="threadIdx.x")

    threads_per_block = tile_block_y * tile_block_x
    # Read buffers 0 and 1 of block "C" are cached in that order.
    for operand_idx in (0, 1):
        cache_read_and_coop_fetch(sch, matmul_blk, threads_per_block, operand_idx, red_outer)
    sch.decompose_reduction(block=matmul_blk, loop=red_outer)

    return sch

# Apply the shared-memory blocking schedule to a fresh matmul schedule,
# using 8 for every tile factor.
sch = tvm.tir.Schedule(MyModuleMatmul)
sch = blocking_with_shared(
    sch,
    tile_local_y=8,
    tile_local_x=8,
    tile_block_y=8,
    tile_block_x=8,
    tile_k=8,
)
sch.mod.show()

In [27]:
from tvm import meta_schedule as ms

# Tuning budget: 64 trials overall, 64 trials per iteration.
tune_config = ms.TuneConfig(max_trials_global=64, num_trials_per_iter=64)

# Tune the matmul module for the "nvidia/tesla-p100" target; logs and the
# tuning database are written under ./tune_tmp.
sch_tuned = ms.tune_tir(
    mod=MyModuleMatmul,
    target="nvidia/tesla-p100",
    config=tune_config,
    work_dir="./tune_tmp",
    task_name="main",
)
sch_tuned.mod.show()

2022-10-16 06:44:42.154 INFO Logging directory: ./tune_tmp/logs
2022-10-16 06:44:42.163 INFO Logging directory: ./tune_tmp/logs
2022-10-16 06:44:42.169 INFO Working directory: ./tune_tmp
2022-10-16 06:44:42.175 INFO Creating JSONDatabase. Workload at: ./tune_tmp/database_workload.json. Tuning records at: ./tune_tmp/database_tuning_record.json
2022-10-16 06:44:42.178 INFO LocalBuilder: max_workers = 1
2022-10-16 06:44:42.937 INFO LocalRunner: max_workers = 1
2022-10-16 06:44:43.534 INFO Initializing Task #0: "main"
2022-10-16 06:44:43.551 INFO 
 ID | Name |       FLOP | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Terminated 
---------------------------------------------------------------------------------------------------------------
  0 | main | 2147483648 |      1 |            N/A |          N/A |                   N/A |      0 |            
---------------------------------------------------------------------------------------------------------------
To

In [31]:
# Build the tuned schedule with the same target string that was used for tuning.
# NOTE(review): the recorded run of this cell ended in a TVMError whose cause is
# not visible here — confirm that the tuning cell ran to completion and that
# the GPU in use matches this target.
rt_mod = tvm.build(sch_tuned.mod, target="nvidia/tesla-p100")
dev = tvm.cuda(0)
evaluator = rt_mod.time_evaluator("main", dev, number=10)

# Reuses A_nd, B_nd, C_nd and num_flop from the GEMM-Blocking benchmark cell.
# The figure is labelled so it reads the same way as the GEMM-Blocking result.
print("MetaSchedule: %f GFLOPS" % (num_flop / evaluator(A_nd, B_nd, C_nd).mean / 1e9))

TVMError: ignored