In [1]:
from tvm.ir.module import IRModule
from tvm.script import relax as R
from tvm.script import tir as T
from tvm import relax
import numpy as np
import tvm

In [2]:
# IRModule with a single TensorIR function that adds two 1024-element
# float32 vectors element-wise: C[i] = A[i] + B[i].
@tvm.script.ir_module
class MyModuleVecAdd:
    @T.prim_func
    def main(A: T.Buffer[(1024,), "float32"],
             B: T.Buffer[(1024,), "float32"],
             C: T.Buffer[(1024,), "float32"]) -> None:
        # Exported under the symbol "main"; the tir.noalias attribute is set to True.
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        for i in T.grid(1024):
            # The block name "C" is the handle later cells pass to sch.get_block.
            with T.block("C"):
                # Bind block axis vi to loop i ("S" = spatial axis in TVMScript).
                vi = T.axis.remap("S", [i])
                C[vi] = A[vi] + B[vi]

In [3]:
# Create a schedule for the vector-add module and fetch its only block.
sch = tvm.tir.Schedule(MyModuleVecAdd)
block_C = sch.get_block("C")
# Block "C" sits under exactly one loop, hence the single-element unpack.
i, = sch.get_loops(block=block_C)
# Split i into (i0, i1) with an inner factor of 128; None leaves the outer
# factor to be derived by TVM.
i0, i1 = sch.split(i, [None, 128])
# Print the module as it stands after the split.
sch.mod.show()

To print formatted TVM script, please install the formatter 'Black':
/opt/conda/bin/python3.8 -m pip install "black==22.3.0" --upgrade --user


In [4]:
# Map the outer loop onto GPU thread blocks and the inner (factor-128) loop
# onto the threads of each block.
sch.bind(i0, "blockIdx.x")
sch.bind(i1, "threadIdx.x")
# Print the module after thread binding.
sch.mod.show()

To print formatted TVM script, please install the formatter 'Black':
/opt/conda/bin/python3.8 -m pip install "black==22.3.0" --upgrade --user


In [5]:
# Compile the thread-bound vector-add schedule for CUDA.
rt_mod = tvm.build(sch.mod, target="cuda")

# Random float32 inputs on the host.
A_np = np.random.uniform(size=(1024,)).astype(np.float32)
B_np = np.random.uniform(size=(1024,)).astype(np.float32)

# Copy the inputs to GPU 0 and allocate a zero-filled output there.
A_nd = tvm.nd.array(A_np, tvm.cuda(0))
B_nd = tvm.nd.array(B_np, tvm.cuda(0))
C_nd = tvm.nd.array(np.zeros((1024,), dtype="float32"), tvm.cuda(0))

# Launch the kernel; the result is written into C_nd.
rt_mod["main"](A_nd, B_nd, C_nd)

# Check the GPU result against the NumPy reference instead of relying on
# eyeballing the printed arrays.
np.testing.assert_allclose(C_nd.numpy(), A_np + B_np, rtol=1e-5)

print(A_nd)
print(B_nd)
print(C_nd)

[0.7697751  0.1894259  0.15973441 ... 0.71162534 0.25710225 0.54057324]
[0.11879353 0.5710991  0.7792158  ... 0.17648411 0.8976469  0.04839446]
[0.88856864 0.760525   0.93895024 ... 0.88810945 1.1547492  0.5889677 ]


In [6]:
# IRModule with a single TensorIR function computing a sliding window sum of
# width 3: B[i] = A[i] + A[i + 1] + A[i + 2].
@tvm.script.ir_module
class MyModuleWindowSum:
    @T.prim_func
    def main(A: T.Buffer[(1026,), "float32"],
             B: T.Buffer[(1024,), "float32"]) -> None:
        # A holds 1026 elements so that A[vi + 2] stays in bounds for all 1024 outputs.
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        for i in T.grid(1024):
            # The block name "C" is the handle later cells pass to sch.get_block.
            with T.block("C"):
                # Bind block axis vi to loop i ("S" = spatial axis in TVMScript).
                vi = T.axis.remap("S", [i])
                B[vi] = A[vi] + A[vi + 1] + A[vi + 2]

In [7]:
# Fresh schedule for the window-sum module (rebinds `sch` from the vector-add cells).
sch = tvm.tir.Schedule(MyModuleWindowSum)
# Inner split factor; bound to threadIdx.x below and reused by the later
# cell that splits the shared-memory copy loop.
nthread = 128
block = sch.get_block("C")
i, = sch.get_loops(block)
# Split the single loop into (i0, i1) with inner factor nthread.
i0, i1 = sch.split(i, [None, nthread])

# Outer loop -> thread blocks, inner loop -> threads within a block.
sch.bind(i0, "blockIdx.x")
sch.bind(i1, "threadIdx.x")
sch.mod.show()

To print formatted TVM script, please install the formatter 'Black':
/opt/conda/bin/python3.8 -m pip install "black==22.3.0" --upgrade --user


In [8]:
# Stage the block's read buffer at index 0 (A, its only input) in "shared"
# scope; A_shared is the handle of the newly created copy block.
A_shared = sch.cache_read(block, read_buffer_index=0, storage_scope="shared")
# Move the copy block under the threadIdx.x loop i1.
sch.compute_at(A_shared, i1)
sch.mod.show()

To print formatted TVM script, please install the formatter 'Black':
/opt/conda/bin/python3.8 -m pip install "black==22.3.0" --upgrade --user


In [9]:
# Take the innermost loop surrounding the shared-memory copy block.
ax = sch.get_loops(A_shared)[-1]
# Split that copy loop by nthread and bind the inner part to threadIdx.x, so
# the copy is spread across the threads of a block rather than being done by
# a single thread.
ax0, ax1 = sch.split(ax, [None, nthread])
sch.bind(ax1, "threadIdx.x")
sch.mod.show()

To print formatted TVM script, please install the formatter 'Black':
/opt/conda/bin/python3.8 -m pip install "black==22.3.0" --upgrade --user


In [10]:
# Build the window-sum schedule for CUDA and print the generated kernel
# source, taken from the first imported module of the built runtime module.
rt_mod = tvm.build(sch.mod, target="cuda")
print(rt_mod.imported_modules[0].get_source())


#ifdef _WIN32
  using uint = unsigned int;
  using uchar = unsigned char;
  using ushort = unsigned short;
  using int64_t = long long;
  using uint64_t = unsigned long long;
#else
  #define uint unsigned int
  #define uchar unsigned char
  #define ushort unsigned short
  #define int64_t long long
  #define uint64_t unsigned long long
#endif
extern "C" __global__ void __launch_bounds__(128) main_kernel0(float* __restrict__ A, float* __restrict__ B) {
  __shared__ float A_shared[130];
  for (int ax0_0 = 0; ax0_0 < 2; ++ax0_0) {
    if (((ax0_0 * 64) + (((int)threadIdx.x) >> 1)) < 65) {
      A_shared[((ax0_0 * 128) + ((int)threadIdx.x))] = A[(((((int)blockIdx.x) * 128) + (ax0_0 * 128)) + ((int)threadIdx.x))];
    }
  }
  __syncthreads();
  B[((((int)blockIdx.x) * 128) + ((int)threadIdx.x))] = ((A_shared[((int)threadIdx.x)] + A_shared[(((int)threadIdx.x) + 1)]) + A_shared[(((int)threadIdx.x) + 2)]);
}




In [11]:
# Build the same schedule for the "metal" target and print its generated
# kernel source. This rebinds rt_mod, replacing the CUDA build from the
# previous cell.
rt_mod = tvm.build(sch.mod, target="metal")
print(rt_mod.imported_modules[0].get_source())



// Function: main_kernel0
#include <metal_stdlib>
using namespace metal;

union __TVMArgUnion {
 int v_int[2];
};

kernel void main_kernel0(  device float* A [[ buffer(0) ]],
  device float* B [[ buffer(1) ]],
  uint blockIdx [[threadgroup_position_in_grid]],
  uint threadIdx [[thread_position_in_threadgroup]]
) {
  threadgroup float A_shared[130];
  for (int ax0_0 = 0; ax0_0 < 2; ++ax0_0) {
    if (((ax0_0 * 64) + (((int)threadIdx) >> 1)) < 65) {
      A_shared[((ax0_0 * 128) + ((int)threadIdx))] = A[(((((int)blockIdx) * 128) + (ax0_0 * 128)) + ((int)threadIdx))];
    }
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  B[((((int)blockIdx) * 128) + ((int)threadIdx))] = ((A_shared[((int)threadIdx)] + A_shared[(((int)threadIdx) + 1)]) + A_shared[(((int)threadIdx) + 2)]);
}




In [12]:
# IRModule with a single TensorIR function computing the 1024x1024 float32
# matrix product C = A @ B.
@tvm.script.ir_module
class MyModuleMatMul:
    @T.prim_func
    def main(A: T.Buffer[(1024, 1024), "float32"],
             B: T.Buffer[(1024, 1024), "float32"],
             C: T.Buffer[(1024, 1024), "float32"]) -> None:
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        for i, j, k in T.grid(1024, 1024, 1024):
            # The block name "C" is the handle the scheduling functions pass to get_block.
            with T.block("C"):
                # "SSR": vi and vj are spatial axes, vk is the reduction axis.
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                # Zero the accumulator before the reduction over vk.
                with T.init():
                    C[vi, vj] = 0
                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]

In [13]:
def blocking(sch, tile_local_y, tile_local_x, tile_block_y, tile_block_x, tile_k):
    """Tile the matmul block "C" for a GPU, accumulating into a local buffer.

    Each of the i and j loops is split three ways: the outer part is bound to
    blockIdx, the middle part (``tile_block_*``) to threadIdx, and the inner
    part (``tile_local_*``) stays a serial loop inside each thread. The k
    loop is split by ``tile_k`` and its inner part is unrolled.

    Args:
        sch: schedule containing a block named "C" nested in three loops
            (unpacked here as i, j, k).
        tile_local_y: innermost split factor of i (rows handled per thread).
        tile_local_x: innermost split factor of j (columns handled per thread).
        tile_block_y: middle split factor of i, bound to threadIdx.y.
        tile_block_x: middle split factor of j, bound to threadIdx.x.
        tile_k: inner split factor of the reduction loop k.

    Returns:
        The same ``sch`` object, after the transformations have been applied.
    """
    block = sch.get_block("C")
    # Write results through a "local"-scope buffer; `local` is the write-back block.
    local = sch.cache_write(block, 0, "local")
    i, j, k = sch.get_loops(block)

    # The innermost factors must be the per-thread tile sizes. The previous
    # code passed tile_block_* twice, which silently ignored tile_local_*.
    i0, i1, i2 = sch.split(i, [None, tile_block_y, tile_local_y])
    j0, j1, j2 = sch.split(j, [None, tile_block_x, tile_local_x])
    k0, k1 = sch.split(k, [None, tile_k])

    sch.unroll(k1)
    # Block loops outermost, then thread loops, then reduction, then the local tile.
    sch.reorder(i0, j0, i1, j1, k0, k1, i2, j2)
    # Place the write-back of the local tile under the threadIdx.x loop.
    sch.reverse_compute_at(local, j1)

    sch.bind(i0, "blockIdx.y")
    sch.bind(j0, "blockIdx.x")
    sch.bind(i1, "threadIdx.y")
    sch.bind(j1, "threadIdx.x")

    # Split the T.init() part of the reduction out, in front of loop k0.
    sch.decompose_reduction(block, k0)
    return sch

# Fresh schedule for the matmul module.
sch = tvm.tir.Schedule(MyModuleMatMul)
# Positional arguments, in order: tile_local_y, tile_local_x, tile_block_y,
# tile_block_x, tile_k.
sch = blocking(sch, 8, 8, 8, 8, 4)
sch.mod.show()

To print formatted TVM script, please install the formatter 'Black':
/opt/conda/bin/python3.8 -m pip install "black==22.3.0" --upgrade --user


In [14]:
# Build the blocked matmul for CUDA and print the generated kernel source.
rt_mod = tvm.build(sch.mod, target="cuda")
print(rt_mod.imported_modules[0].get_source())

# Random float32 inputs on the host.
A_np = np.random.uniform(size=(1024, 1024)).astype("float32")
B_np = np.random.uniform(size=(1024, 1024)).astype("float32")

# Copy the inputs to GPU 0 and allocate a zero-filled output there. These
# arrays are reused by the later timing cells.
A_nd = tvm.nd.array(A_np, tvm.cuda(0))
B_nd = tvm.nd.array(B_np, tvm.cuda(0))
C_nd = tvm.nd.array(np.zeros((1024, 1024), dtype="float32"), tvm.cuda(0))

# One multiply and one add for each of the 1024^3 (i, j, k) iterations.
flops = (1024 ** 3) * 2
# Timer for "main" on GPU 0, configured with number=10.
evaluator = rt_mod.time_evaluator("main", tvm.cuda(0), number=10)
# The mean time is treated as seconds here, so flops / mean / 1e9 is GFLOPS.
print("GEMM blocking: %f GFLOPS" %(flops / evaluator(A_nd, B_nd, C_nd).mean / 1e9))


#ifdef _WIN32
  using uint = unsigned int;
  using uchar = unsigned char;
  using ushort = unsigned short;
  using int64_t = long long;
  using uint64_t = unsigned long long;
#else
  #define uint unsigned int
  #define uchar unsigned char
  #define ushort unsigned short
  #define int64_t long long
  #define uint64_t unsigned long long
#endif
extern "C" __global__ void __launch_bounds__(64) main_kernel0(float* __restrict__ A, float* __restrict__ B, float* __restrict__ C) {
  float C_local[64];
  for (int i_2_init = 0; i_2_init < 8; ++i_2_init) {
    for (int j_2_init = 0; j_2_init < 8; ++j_2_init) {
      C_local[((i_2_init * 8) + j_2_init)] = 0.000000e+00f;
    }
  }
  for (int k_0 = 0; k_0 < 256; ++k_0) {
    for (int i_2 = 0; i_2 < 8; ++i_2) {
      for (int j_2 = 0; j_2 < 8; ++j_2) {
        C_local[((i_2 * 8) + j_2)] = (C_local[((i_2 * 8) + j_2)] + (A[((((((int)blockIdx.y) * 65536) + (((int)threadIdx.y) * 8192)) + (i_2 * 1024)) + (k_0 * 4))] * B[((((k_0 * 4096) + (((int)blockIdx.x

In [15]:
def cache_read_and_coop_fetch(sch, block, nthread, read_idx, read_loc, vec_len=4):
    """Cache one input of ``block`` in shared scope and spread the copy over threads.

    The two innermost loops of the new copy block are fused, then split into
    (outer, nthread, vec_len). The ``nthread`` part is bound to threadIdx.x
    and the ``vec_len`` part is vectorized.

    Args:
        sch: schedule to transform.
        block: consumer block whose input is cached.
        nthread: split factor bound to threadIdx.x.
        read_idx: index of the read buffer of ``block`` to cache.
        read_loc: loop under which the copy block is placed via compute_at.
        vec_len: split factor of the vectorized innermost loop. Defaults to 4,
            the value that was previously hard-coded.

    Returns:
        None. ``sch`` is modified through its own methods.
    """
    read_cache = sch.cache_read(block=block, read_buffer_index=read_idx, storage_scope="shared")
    sch.compute_at(block=read_cache, loop=read_loc)
    # Flatten the two innermost copy loops so the copy can be divided evenly.
    inner0, inner1 = sch.get_loops(block=read_cache)[-2:]
    inner = sch.fuse(inner0, inner1)
    _, tx, vec = sch.split(loop=inner, factors=[None, nthread, vec_len])
    sch.vectorize(vec)
    sch.bind(tx, "threadIdx.x")
    
def blocking_with_shared(sch, tile_local_y, tile_local_x, tile_block_y, tile_block_x, tile_k):
    """Tile the matmul block "C" for a GPU and stage both inputs through shared scope.

    Args:
        sch: schedule containing a block named "C" nested in three loops
            (unpacked here as i, j, k).
        tile_local_y, tile_local_x: innermost split factors of i and j.
        tile_block_y, tile_block_x: middle split factors of i and j; the two
            resulting loops are fused and bound to threadIdx.x, and their
            product is the thread count handed to the cooperative fetch.
        tile_k: inner split factor of the loop k.

    Returns:
        The same ``sch`` object, after the transformations have been applied.
    """
    block = sch.get_block("C")
    # Write results through a "local"-scope buffer; `local` is the write-back block.
    local = sch.cache_write(block, 0, "local")
    
    i, j, k = sch.get_loops(block)
    i0, i1, i2 = sch.split(i, [None, tile_block_y, tile_local_y])
    j0, j1, j2 = sch.split(j, [None, tile_block_x, tile_local_x])
    k0, k1 = sch.split(k, [None, tile_k])
    
    # Block loops outermost, then thread loops, then reduction, then the local tile.
    sch.reorder(i0, j0, i1, j1, k0, k1, i2, j2)
    # Place the write-back of the local tile under loop j1 (before i1/j1 are fused).
    sch.reverse_compute_at(local, j1)
    
    sch.bind(i0, "blockIdx.y")
    sch.bind(j0, "blockIdx.x")
    
    # Use a single flat thread index instead of separate threadIdx.y / threadIdx.x.
    tx = sch.fuse(i1, j1)
    sch.bind(tx, "threadIdx.x")
    nthread = tile_block_y * tile_block_x
    
    # Cache read buffers 0 and 1 of the block under loop k0, each copied
    # cooperatively by nthread threads.
    cache_read_and_coop_fetch(sch, block, nthread, 0, k0)
    cache_read_and_coop_fetch(sch, block, nthread, 1, k0)
    # Split the T.init() part of the reduction out, in front of loop k0.
    sch.decompose_reduction(block, k0)
    
    return sch

# Fresh schedule for the matmul module.
sch = tvm.tir.Schedule(MyModuleMatMul)
# Positional arguments, in order: tile_local_y, tile_local_x, tile_block_y,
# tile_block_x, tile_k.
sch = blocking_with_shared(sch, 8, 8, 8, 8, 8)
sch.mod.show()

To print formatted TVM script, please install the formatter 'Black':
/opt/conda/bin/python3.8 -m pip install "black==22.3.0" --upgrade --user


In [16]:
# Build the shared-memory variant for CUDA and print the generated kernel source.
rt_mod = tvm.build(sch.mod, target="cuda")
print(rt_mod.imported_modules[0].get_source())
# Reuses flops, A_nd, B_nd and C_nd defined in the earlier timing cell, so
# that cell must have been run first.
evaluator = rt_mod.time_evaluator("main", tvm.cuda(0), number=10)
print("GEMM blocking shared: %f GFLOPS" %(flops / evaluator(A_nd, B_nd, C_nd).mean / 1e9))


#ifdef _WIN32
  using uint = unsigned int;
  using uchar = unsigned char;
  using ushort = unsigned short;
  using int64_t = long long;
  using uint64_t = unsigned long long;
#else
  #define uint unsigned int
  #define uchar unsigned char
  #define ushort unsigned short
  #define int64_t long long
  #define uint64_t unsigned long long
#endif
extern "C" __global__ void __launch_bounds__(64) main_kernel0(float* __restrict__ A, float* __restrict__ B, float* __restrict__ C) {
  float C_local[64];
  __shared__ float A_shared[512];
  __shared__ float B_shared[512];
  for (int i_2_init = 0; i_2_init < 8; ++i_2_init) {
    for (int j_2_init = 0; j_2_init < 8; ++j_2_init) {
      C_local[((i_2_init * 8) + j_2_init)] = 0.000000e+00f;
    }
  }
  for (int k_0 = 0; k_0 < 128; ++k_0) {
    __syncthreads();
    for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) {
      *(float4*)(A_shared + ((ax0_ax1_fused_0 * 256) + (((int)threadIdx.x) * 4))) = *(float4*)(A + (((((((int)blockIdx

In [17]:
# NOTE(review): this import sits mid-notebook; consider moving it to the
# import cell at the top so dependencies are visible in one place.
from tvm import meta_schedule as ms

# Tune the matmul module for the 'nvidia/geforce-rtx-3060' target with a
# budget of 64 trials, writing results under ./tune_tmp. The returned object
# is queried for the resulting schedule in the next cell.
sch_tuned = ms.tune_tir(mod=MyModuleMatMul,
                        target='nvidia/geforce-rtx-3060',
                        max_trials_global=64,
                        num_trials_per_iter=64,
                        work_dir='./tune_tmp',
                        task_name='main')

2023-02-13 05:06:06 [INFO] [task_scheduler.cc:260] Task #0 has finished. Remaining task(s): 0


Unnamed: 0,Name,FLOP,Weight,Speed (GFLOPS),Latency (us),Weighted Latency (us),Trials,Done
0,main,2147483648,1,4571.539,469.7507,469.7507,64,Y



Total trials: 64
Total latency (us): 469.751

2023-02-13 05:06:06 [DEBUG] [task_scheduler.cc:318] 
 ID | Name |       FLOP | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Done 
---------------------------------------------------------------------------------------------------------
  0 | main | 2147483648 |      1 |      4571.5390 |     469.7507 |              469.7507 |     64 |    Y 
---------------------------------------------------------------------------------------------------------
Total trials: 64
Total latency (us): 469.751



In [18]:
# Look up the schedule recorded for this module, target and workload name
# 'main'. NOTE(review): despite its name, `sch_tuned` is used as a store of
# tuning results rather than as a schedule (presumably a MetaSchedule
# database; confirm against the tune_tir documentation).
sch_meta_tuned = sch_tuned.query_schedule(MyModuleMatMul, target=tvm.target.Target('nvidia/geforce-rtx-3060'), workload_name='main')
sch_meta_tuned.mod.show()

To print formatted TVM script, please install the formatter 'Black':
/opt/conda/bin/python3.8 -m pip install "black==22.3.0" --upgrade --user


In [19]:
# Build the tuned schedule for the same target string that was used for tuning.
rt_mod = tvm.build(sch_meta_tuned.mod, target="nvidia/geforce-rtx-3060")
# Reuses flops, A_nd, B_nd and C_nd defined in the earlier timing cell.
evaluator = rt_mod.time_evaluator("main", tvm.cuda(0), number=10)
print("MetaSchedule: %f GFLOPS" %(flops / evaluator(A_nd, B_nd, C_nd).mean / 1e9))

MetaSchedule: 4645.884922 GFLOPS


In [20]:
# Print the generated kernel source of the tuned build (rt_mod from the previous cell).
print(rt_mod.imported_modules[0].get_source())


#ifdef _WIN32
  using uint = unsigned int;
  using uchar = unsigned char;
  using ushort = unsigned short;
  using int64_t = long long;
  using uint64_t = unsigned long long;
#else
  #define uint unsigned int
  #define uchar unsigned char
  #define ushort unsigned short
  #define int64_t long long
  #define uint64_t unsigned long long
#endif
extern "C" __global__ void __launch_bounds__(256) main_kernel0(float* __restrict__ A, float* __restrict__ B, float* __restrict__ C) {
  float C_local[64];
  __shared__ float A_shared[4096];
  __shared__ float B_shared[1024];
  C_local[0] = 0.000000e+00f;
  C_local[16] = 0.000000e+00f;
  C_local[32] = 0.000000e+00f;
  C_local[48] = 0.000000e+00f;
  C_local[1] = 0.000000e+00f;
  C_local[17] = 0.000000e+00f;
  C_local[33] = 0.000000e+00f;
  C_local[49] = 0.000000e+00f;
  C_local[2] = 0.000000e+00f;
  C_local[18] = 0.000000e+00f;
  C_local[34] = 0.000000e+00f;
  C_local[50] = 0.000000e+00f;
  C_local[3] = 0.000000e+00f;
  C_local[19] = 0.000000e+00f;
