In [None]:
import sys
from pathlib import Path

sys.path.insert(0, str(Path.cwd().parent.parent.absolute()))
import tilelang
import torch
import tilelang.language as T

# Tilelang Lazy JIT

## Tensor Annotation

Tilelang Lazy JIT 将 jit 生成和调用的逻辑合并到一起

函数签名的写法与 triton 相似，但做了大量增强，最主要的增强是允许对 Tensor 的标注：

例如，下面的代码用 T.Tensor[[int, int], T.float16] 来标注了一个二维 Tensor
1. 它的每个维度都是编译期常量，如果改变，会触发重新编译
2. 它的类型必须是 T.float16

DType 除了写确定的外，还可以写 Any 或者 None

In [2]:
# Tiled GEMM written with the lazy-JIT front end. Per the surrounding notebook
# text, calling `gemm(A, B, ...)` with tensors runs the whole JIT flow
# (compile + launch) and hands back C.
#
# Parameters
#   A, B       2-D tensors annotated as float16; dimensions annotated `int`
#              are compile-time constants.
#   out_dtype  dtype of the accumulator fragment and of the returned tensor.
#   block_M, block_N, block_K
#              tile sizes along M, N and the K reduction.
# Returns
#   C, an (M, N) tensor of `out_dtype` created with T.empty.
@tilelang.lazy_jit
def gemm(
    A: T.Tensor[[int, int], T.float16],
    B: T.Tensor[[int, int], T.float16],
    out_dtype: T.dtype = T.float32,
    block_M: int = 128,
    block_N: int = 128,
    block_K: int = 32,
):
    M, K = A.shape
    # K is rebound from B here; nothing in this body checks it against A's K.
    K, N = B.shape
    C = T.empty((M, N), out_dtype)
    # Launch grid: one 128-thread block per (block_M, block_N) tile of C.
    with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):
        # Staging tiles for A and B, plus the per-block accumulator.
        A_shared = T.alloc_shared((block_M, block_K), A.dtype)
        B_shared = T.alloc_shared((block_K, block_N), B.dtype)
        C_local = T.alloc_fragment((block_M, block_N), out_dtype)
        T.clear(C_local)
        # Walk the K dimension in block_K-sized steps (T.Pipelined, num_stages=3).
        for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
            T.copy(A[bx * block_M, k * block_K], A_shared)
            T.copy(B[k * block_K, by * block_N], B_shared)
            # C_local is cleared once before the loop, so this presumably
            # accumulates A_shared @ B_shared into it on every step.
            T.gemm(A_shared, B_shared, C_local)
        # Write the finished tile back to its position in C.
        T.copy(C_local, C[bx * block_M, by * block_N])
    return C

直接将 Tensor 作为参数调用，即可触发完整的 jit 编译运行流程：

In [3]:
# Random float16 operands on the GPU: A is (1024, 512), B is (512, 256).
A = torch.randn((1024, 512), device="cuda", dtype=torch.float16)
B = torch.randn((512, 256), device="cuda", dtype=torch.float16)

# Passing tensors straight to the lazy-JIT function compiles and runs it.
C = gemm(A, B)

# Compare against torch's own matmul, converted to float32 like gemm's
# default out_dtype.
C_ref = torch.matmul(A, B).to(torch.float32)
torch.testing.assert_close(C, C_ref, atol=1e-2, rtol=1e-2)

更改调用的参数，如果编译期参数发生了变化，会触发重新编译：

In [4]:
# New B shape and non-default block sizes: the compile-time arguments differ
# from the previous call, which (per the notebook text) triggers a recompile.
A = torch.randn(1024, 512, dtype=torch.float16, device="cuda")
B = torch.randn(512, 1024, dtype=torch.float16, device="cuda")
C = gemm(A, B, block_M=64, block_N=64)

# Verify the recompiled kernel the same way the first gemm call is verified,
# so a wrong result cannot slip through unnoticed.
C_ref = (A @ B).float()
torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)

你也可以手动调用 compile 函数编译 kernel（下面的 `ker` 指被 `lazy_jit` 装饰的函数，例如 `gemm`）

1. `ker.compile` 编译 kernel
2. `ker.get_tir` 获取 tir
3. `ker.par_compile` 并行编译

In [5]:
# Compile explicitly, then launch the returned kernel object.
# A and B are the tensors created in the previous cell.
kernel = gemm.compile(A, B, block_M=64, block_N=64)
C = kernel(A, B)

# Check the explicitly compiled kernel against the torch reference, with the
# same tolerances the other gemm checks in this notebook use.
torch.testing.assert_close(C, (A @ B).float(), rtol=1e-2, atol=1e-2)



## More Tensor Annotation

### 用 macro 来分离实现

接下来，我们会用各种方式来实现一个简单的 gemm，为了方便，我们先写一个 macro 把 gemm 的主要逻辑写出来：

In [6]:
# Shared GEMM body, factored out as a macro so the lazy-JIT entry points later
# in this notebook can reuse it.
#
# Parameters
#   A, B, C    operand and result buffers; the result is written into C and
#              nothing is returned.
#   M, N, K    problem sizes.
#   block_M, block_N, block_K
#              tile sizes along M, N and the K reduction.
@T.macro
def gemm_impl(A, B, C, M, N, K, block_M, block_N, block_K):
    # Launch grid: one 128-thread block per (block_M, block_N) tile of C.
    with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):
        # Staging tiles for A and B; the accumulator takes its dtype from C.
        A_shared = T.alloc_shared((block_M, block_K), A.dtype)
        B_shared = T.alloc_shared((block_K, block_N), B.dtype)
        C_local = T.alloc_fragment((block_M, block_N), C.dtype)
        T.clear(C_local)
        # Walk the K dimension in block_K-sized steps (T.Pipelined, num_stages=3).
        for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
            T.copy(A[bx * block_M, k * block_K], A_shared)
            T.copy(B[k * block_K, by * block_N], B_shared)
            # C_local is cleared once before the loop, so this presumably
            # accumulates A_shared @ B_shared into it on every step.
            T.gemm(A_shared, B_shared, C_local)
        # Write the finished tile back to its position in C.
        T.copy(C_local, C[bx * block_M, by * block_N])

### 用 T.dyn 标记动态 Shape

当某些维度是动态的时候，可以用 T.dyn 来标记。T.dyn 可以接受一个字符串参数，表示变量的名字

In [None]:
# GEMM with a dynamic reduction dimension: both annotations use T.dyn["K"], so
# A's second and B's first dimension share the runtime variable K, while the
# dimensions annotated `int` stay compile-time constants.
# Returns a float32 (M, N) tensor.
@tilelang.lazy_jit
def gemm_dyn_K(
    A: T.Tensor[[int, T.dyn["K"]], T.float16],  # noqa: F821
    B: T.Tensor[[T.dyn["K"], int], T.float16],  # noqa: F821
):
    M, K = A.shape
    K, N = B.shape
    C = T.empty((M, N), T.float32)
    # Trailing positional arguments are block_M=128, block_N=128, block_K=32.
    gemm_impl(A, B, C, M, N, K, 128, 128, 32)
    return C

查看 lazy_jit 的函数签名，其中带有后缀`$` 的是不确定的编译期常量，带有 `$dyn` 的是运行时的变量

In [8]:
# Display the tensor annotations recorded for gemm_dyn_K's parameters.
gemm_dyn_K.func.annot

{'A': TensorAnnot(shape=[A_shape_0$, K$dyn], strides=None, dtype=dtype('float16')),
 'B': TensorAnnot(shape=[K$dyn, B_shape_1$], strides=None, dtype=dtype('float16'))}

In [9]:
# float16 operands on the GPU; K (= 512) is the dimension marked dynamic.
A = torch.randn((1024, 512), device="cuda", dtype=torch.float16)
B = torch.randn((512, 256), device="cuda", dtype=torch.float16)

C = gemm_dyn_K(A, B)

# Reference from torch, converted to float32 to match the kernel's output dtype.
C_ref = torch.matmul(A, B).to(torch.float32)
torch.testing.assert_close(C, C_ref, atol=1e-2, rtol=1e-2)

### 用 T.StridedTensor 标记带 stride 的 Tensor

标记方法：T.StridedTensor[Shape, Stride, DType]，每个 Shape 或 Stride 可以写
* int: 表示编译期常量
* T.dyn：表示运行时变量

DType 可以写 None 或 Any

In [10]:
from typing import Any


# Copy a 2-D strided tensor into a new tensor B of the same shape and dtype.
# Every shape and stride entry is annotated T.dyn and the dtype is Any, so the
# signature places no compile-time constraint on the input layout.
# Returns B, created with T.empty.
@tilelang.lazy_jit
def as_contingious(A: T.StridedTensor[[T.dyn, T.dyn], [T.dyn, T.dyn], Any]):
    M, N = A.shape
    B = T.empty((M, N), A.dtype)
    # Fixed 128 x 128 tile per block.
    block_M = 128
    block_N = 128
    with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):
        # Copy tile (bx, by) of A into the same index range of B.
        T.copy(
            A[bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N],
            B[bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N],
        )
    return B

In [11]:
A = torch.randn((1024, 1024), device="cuda")

# Every second row and column of A: a strided view of the same storage.
A_view = A[::2, ::2]

B = as_contingious(A_view)
B_ref = A_view.contiguous()
torch.testing.assert_close(B, B_ref)

## More Annotation

### 用 T.ptr 标注 Tensor
lazy_jit 允许你用 T.ptr 来声明一个 handle，但必须在函数内用 T.match_buffer 给它定义 shape

In [12]:
# GEMM taking raw handles: A and B are declared as T.ptr and only receive a
# shape and dtype inside the body through T.match_buffer. M, N and K are
# annotated with Python `int`.
# Returns a float32 (M, N) tensor.
@tilelang.lazy_jit
def gemm_ptr(
    A: T.ptr,
    B: T.ptr,
    M: int,
    N: int,
    K: int,
):
    # Rebind the handles as float16 buffers of shape (M, K) and (K, N).
    A = T.match_buffer(A, (M, K), T.float16)
    B = T.match_buffer(B, (K, N), T.float16)
    C = T.empty((M, N), T.float32)
    gemm_impl(A, B, C, M, N, K, block_M=128, block_N=128, block_K=32)
    return C

In [13]:
A = torch.randn((1024, 512), device="cuda", dtype=torch.float16)
B = torch.randn((512, 256), device="cuda", dtype=torch.float16)

# The trailing arguments are M, N, K, matching the shapes of A and B above.
C = gemm_ptr(A, B, 1024, 256, 512)

C_ref = torch.matmul(A, B).to(torch.float32)
torch.testing.assert_close(C, C_ref, atol=1e-2, rtol=1e-2)

### 用 T.int32 标注运行时变量

lazy_jit 允许你用 T.int32 或其他类型来定义运行时变量，这样，你可以写一个完全动态的 gemm，这和 triton 非常相似

In [14]:
# Fully dynamic GEMM: A and B are raw T.ptr handles and M, N, K are annotated
# T.int32, which the notebook text describes as runtime variables.
# Returns a float32 (M, N) tensor.
@tilelang.lazy_jit
def gemm_ptr_dyn(
    A: T.ptr,
    B: T.ptr,
    M: T.int32,
    N: T.int32,
    K: T.int32,
):
    # Strides are given explicitly as (K, 1) and (N, 1), i.e. a row-major
    # layout for the (M, K) and (K, N) buffers.
    A = T.match_buffer(A, (M, K), T.float16, strides=(K, 1))
    B = T.match_buffer(B, (K, N), T.float16, strides=(N, 1))
    C = T.empty((M, N), T.float32)
    gemm_impl(A, B, C, M, N, K, block_M=128, block_N=128, block_K=32)
    return C

In [15]:
A = torch.randn((1024, 512), device="cuda", dtype=torch.float16)
B = torch.randn((512, 256), device="cuda", dtype=torch.float16)

# M, N, K are passed as plain numbers matching the shapes of A and B above.
C = gemm_ptr_dyn(A, B, 1024, 256, 512)

C_ref = torch.matmul(A, B).to(torch.float32)
torch.testing.assert_close(C, C_ref, atol=1e-2, rtol=1e-2)

## 编译与并行编译

lazy_jit 和原来的 jit 都支持并行编译

为了防止 torch.tensor 白白浪费内存，可以使用 T.Tensor 来创建 placeholder

In [16]:
from itertools import product


def get_configs():
    """Build the argument sets passed to ``gemm.par_compile``.

    Returns:
        A list with one dict per combination of ``block_M``, ``block_N`` and
        ``block_K`` drawn from ``[32, 64]`` (8 dicts in total). Each dict maps
        ``gemm``'s parameter names to values; ``A`` and ``B`` are ``T.Tensor``
        placeholders used in place of real torch tensors.
    """
    return [
        {
            # The placeholder dtype is float16 because gemm annotates both
            # operands as T.float16; a float32 placeholder would contradict
            # that signature.
            "A": T.Tensor((1024, 1024), T.float16),
            "B": T.Tensor((1024, 1024), T.float16),
            "block_M": block_M,
            "block_N": block_N,
            "block_K": block_K,
        }
        for block_M, block_N, block_K in product([32, 64], repeat=3)
    ]


# Compile all configurations in parallel; the cell displays the returned kernels.
gemm.par_compile(get_configs())

Elaborating:   0%|          | 0/8 [00:00<?, ?it/s]

Parallel Compiling:   0%|          | 0/8 [00:00<?, ?it/s]

[<tilelang.jit.kernel.JITKernel at 0x7f29c0072ed0>,
 <tilelang.jit.kernel.JITKernel at 0x7f29c00882f0>,
 <tilelang.jit.kernel.JITKernel at 0x7f29c00735f0>,
 <tilelang.jit.kernel.JITKernel at 0x7f29c0088890>,
 <tilelang.jit.kernel.JITKernel at 0x7f29c01f94c0>,
 <tilelang.jit.kernel.JITKernel at 0x7f29c0073fe0>,
 <tilelang.jit.kernel.JITKernel at 0x7f29c0070ce0>,
 <tilelang.jit.kernel.JITKernel at 0x7f29c00732f0>]

## 更便利的 Macro

tilelang 的 macro 现在已经升级：

1. 允许用 `T.Ref` 作为 annotation，这类似于 C++ 的引用传递
2. 允许返回多个值
3. 允许嵌套，递归

### T.Ref 传递引用

T.Ref 传递的引用可以是 var，也可以是 Buffer 的索引

In [None]:
# Macro whose parameter is annotated T.Ref (pass-by-reference): the assignment
# below is applied to whatever the caller passed in, as the generated TIR shown
# under this cell illustrates.
@T.macro
def macro_with_ref(x: T.Ref):
    x = 1  # noqa: F841


# Prim func that calls macro_with_ref on buffer elements, once with a constant
# index and once with a variable index.
@T.prim_func
def foo(x: T.Tensor((2,))):
    with T.Kernel(1) as _:
        # A constant index is supported
        macro_with_ref(x[1])

        # A variable index is supported as well
        idx = T.alloc_var(T.int32, 0)
        macro_with_ref(x[idx])


# Last expression of the cell: display the generated TIR.
foo

# import tilelang.language as T

@T.prim_func
def foo(x_handle: T.handle):
    x = T.match_buffer(x_handle, (2,), strides=(1,))
    # with T.block("root"):
    bx = T.launch_thread("blockIdx.x", 1)
    tx = T.launch_thread("threadIdx.x", 128)
    ty = T.launch_thread("threadIdx.y", 1)
    tz = T.launch_thread("threadIdx.z", 1)
    with T.block("tilelang_root"):
        T.reads()
        idx = T.Buffer((1,), "int32", scope="local.var")
        T.writes(x[T.min(1, idx[0]):T.min(1, idx[0]) + (T.max(1, idx[0]) + 1 - T.min(1, idx[0]))])
        T.block_attr({"tl.local_var_init": {idx.data: 0}})
        idx = T.alloc_buffer((1,), "int32", data=idx.data, scope="local.var")
        x[1] = T.float32(1.0)
        _tmp: T.int32 = idx[0]
        x[_tmp] = T.float32(1.0)

### 当作参数传递

你可以把 macro 当做参数传递

In [18]:
# Apply `fn` to every element of a 1-D tensor whose length is dynamic (T.dyn)
# and whose dtype is unconstrained (Any). `fn` is an unannotated parameter; the
# notebook passes a T.macro for it.
# Returns B, a new tensor with A's length and dtype.
@tilelang.lazy_jit
def element_wise(
    A: T.Tensor[[T.dyn], Any],
    fn,
):
    (N,) = A.shape
    B = T.empty((N,), dtype=A.dtype)
    # Each block handles block_N consecutive elements.
    block_N = 128
    with T.Kernel(T.ceildiv(N, block_N), threads=128) as bx:
        for i in T.Parallel(block_N):
            # Global element index of lane i in block bx.
            idx = bx * block_N + i
            # NOTE(review): there is no explicit `idx < N` guard; when N is not
            # a multiple of block_N this relies on the compiler handling the
            # out-of-range lanes. Confirm.
            B[idx] = fn(A[idx])
    return B


# Macro returning its argument plus one; it is passed to element_wise as `fn`.
@T.macro
def add_one(x):
    return x + 1

In [19]:
A = torch.randn((1024,), device="cuda")

# The macro is handed over as an ordinary argument.
B = element_wise(A, add_one)

# Same operation done by torch, used as the reference.
B_ref = torch.add(A, 1)
torch.testing.assert_close(B, B_ref)

### Macro 递归

虽然不确定是否有这种需求，但 macro 是可以递归的，前提是终止条件能在编译期确定

In [20]:
# Recursive macro following the 3n+1 rule. `x` drives the recursion (per the
# notebook text, the termination condition must be decidable at compile time);
# `var` is a T.Ref that receives the same halve / triple-plus-one step at every
# level.
@T.macro
def n31(x, var: T.Ref):
    if x == 1:
        # Recursion ends here; var is left unchanged.
        pass
    elif x % 2 == 0:
        # Even x: halve both the control value and the referenced value.
        var = var // 2
        n31(x // 2, var)
    else:
        # Odd x: apply 3n + 1 to both.
        var = var * 3 + 1
        n31(x * 3 + 1, var)


# Run n31 on the single element A[0], steered by the Python int `n`.
# NOTE: this rebinds the name `foo` that an earlier cell in this notebook
# already defined.
@tilelang.lazy_jit
def foo(A: T.Tensor[[1], T.int32], n: int):
    with T.Kernel(1) as _:
        n31(n, A[0])

In [21]:
A = torch.tensor([100], dtype=torch.int32, device="cuda")
foo(A, 5)

# n = 5 takes the steps 5 -> 16 -> 8 -> 4 -> 2 -> 1, and A[0] receives the same
# operations: 100 -> 301 -> 150 -> 75 -> 37 -> 18. Pin that value instead of
# relying on reading the printed tensor.
assert A.item() == 18
A

tensor([18], device='cuda:0', dtype=torch.int32)

### Macro 返回多个值

In [None]:
# Macro returning two values at once: (T.sin(x), T.cos(x)).
@T.macro
def sincos(x):
    return T.sin(x), T.cos(x)


# Prim func that unpacks both values returned by the sincos macro.
# NOTE: this rebinds the name `foo` that earlier cells in this notebook
# already defined.
# NOTE(review): x is the block index, and the generated TIR shown under this
# cell types s and c as T.int32. Cast x to a float type first if real
# sine/cosine values are wanted. Confirm.
@T.prim_func
def foo():
    with T.Kernel(32) as x:
        s, c = sincos(x)
        a = s + c  # noqa: F841
        b = s - c  # noqa: F841


# Last expression of the cell: display the generated TIR.
foo

# import tilelang.language as T

@T.prim_func
def foo():
    # with T.block("root"):
    x = T.launch_thread("blockIdx.x", 32)
    tx = T.launch_thread("threadIdx.x", 128)
    ty = T.launch_thread("threadIdx.y", 1)
    tz = T.launch_thread("threadIdx.z", 1)
    with T.block("tilelang_root"):
        T.reads()
        T.writes()
        s: T.int32 = T.sin(x)
        c: T.int32 = T.cos(x)
        a: T.int32 = s + c
        b: T.int32 = s - c
        T.evaluate(0)