In [None]:
import sys
from pathlib import Path

sys.path.insert(0, str(Path.cwd().parent.parent.absolute()))
import tilelang
import torch
import tilelang.language as T

# Tilelang Lazy JIT

## Tensor Annotation

Tilelang Lazy JIT merges JIT kernel generation and invocation into a single workflow.

The function signature looks similar to Triton, but we add many enhancements; the most important one is allowing rich Tensor annotations:

* If a Tensor has complex shape constraints, we can move its annotation into the function body.
* Use `T.const` or `T.dynamic` to create shape variables, then annotate complex Tensors with `T.Tensor`.
* Use `T.empty` to declare return tensors.

In [2]:
# Tiled matrix multiply built with the lazy JIT: takes an (M, K) float16 `A`
# and a (K, N) float16 `B` and returns an (M, N) tensor of `out_dtype`.
# `block_M` / `block_N` / `block_K` are the tile sizes used inside the kernel.
@tilelang.lazy_jit
def gemm(
    A,
    B,
    out_dtype: T.dtype = T.float32,
    block_M: int = 128,
    block_N: int = 128,
    block_K: int = 32,
):
    # Shape variables created with T.const; the annotations below tie them to A and B.
    M, N, K = T.const("M, N, K")

    # In-body annotations: A is (M, K), B is (K, N), both float16.
    A: T.Tensor[[M, K], T.float16]
    B: T.Tensor[[K, N], T.float16]

    # T.empty declares the tensor that is returned to the caller.
    C = T.empty((M, N), out_dtype)

    # Launch ceil(M / block_M) x ceil(N / block_N) blocks with 128 threads each;
    # (bx, by) selects the output tile handled by a block.
    with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):
        # Per-block staging buffers for one tile of A and B, and the accumulator tile.
        A_shared = T.alloc_shared((block_M, block_K), A.dtype)
        B_shared = T.alloc_shared((block_K, block_N), B.dtype)
        C_local = T.alloc_fragment((block_M, block_N), out_dtype)
        T.clear(C_local)
        # Step through K in chunks of block_K, accumulating into C_local.
        for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
            T.copy(A[bx * block_M, k * block_K], A_shared)
            T.copy(B[k * block_K, by * block_N], B_shared)
            T.gemm(A_shared, B_shared, C_local)
        # Store the finished tile at its position in C.
        T.copy(C_local, C[bx * block_M, by * block_N])
    return C

Calling the function with Tensors directly triggers the full JIT compile-and-run pipeline:

In [3]:
# Inputs: A is 1024x512 and B is 512x256, both float16 on the GPU.
A = torch.randn(1024, 512, dtype=torch.float16, device="cuda")
B = torch.randn(512, 256, dtype=torch.float16, device="cuda")
# Calling the lazy-JIT function with tensors compiles and runs the kernel.
C = gemm(A, B)

# Validate against PyTorch's matmul, converted to float32 (gemm's default out_dtype).
C_ref = (A @ B).float()
torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)

Changing the call arguments may trigger a recompilation when compilation parameters change:

In [4]:
# New inputs: A is 1024x512, B is 512x1024 (float16, GPU).
A = torch.randn(1024, 512, dtype=torch.float16, device="cuda")
B = torch.randn(512, 1024, dtype=torch.float16, device="cuda")
# block_M / block_N are overridden here (64 instead of the default 128).
C = gemm(A, B, block_M=64, block_N=64)

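You can also build a kernel explicitly instead of relying on the first call. A `lazy_jit` function (written `ker` below; `gemm` in this notebook) provides:

1. `ker.compile` compiles the kernel
2. `ker.get_tir` retrieves the TIR
3. `ker.par_compile` compiles several configurations in parallel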

In [5]:
# Build the kernel explicitly for these arguments (A and B come from the
# previous cell), then call the returned kernel object with the tensors.
kernel = gemm.compile(A, B, block_M=64, block_N=64)
C = kernel(A, B)

## More Tensor Annotation

### Use macros to separate implementation

Next, we implement a simple GEMM in several different ways. For convenience, we first write a macro that contains the core GEMM logic:

In [6]:
# Reusable macro holding the tiled GEMM body. The caller supplies the tensors
# A, B, C, the problem sizes M, N, K and the tile sizes block_M/block_N/block_K.
# The result is written into C; the macro itself returns nothing.
@T.macro
def gemm_impl(A, B, C, M, N, K, block_M, block_N, block_K):
    # One block per (block_M x block_N) tile of C, 128 threads per block.
    with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):
        # Staging buffers for the current A and B tiles; the accumulator uses C's dtype.
        A_shared = T.alloc_shared((block_M, block_K), A.dtype)
        B_shared = T.alloc_shared((block_K, block_N), B.dtype)
        C_local = T.alloc_fragment((block_M, block_N), C.dtype)
        T.clear(C_local)
        # Iterate over K in block_K-sized steps and accumulate into C_local.
        for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
            T.copy(A[bx * block_M, k * block_K], A_shared)
            T.copy(B[k * block_K, by * block_N], B_shared)
            T.gemm(A_shared, B_shared, C_local)
        # Write the accumulated tile back to C.
        T.copy(C_local, C[bx * block_M, by * block_N])

### Use `T.dynamic` to mark dynamic shapes


In [7]:
# GEMM whose shape variables are created with T.dynamic instead of T.const.
# NOTE(review): despite the `_K` suffix, M, N and K are all declared dynamic here.
@tilelang.lazy_jit
def gemm_dyn_K(A, B):
    M, N, K = T.dynamic("M, N, K")
    # A is (M, K) and B is (K, N), both float16; the output C is float32.
    A: T.Tensor[[M, K], T.float16]
    B: T.Tensor[[K, N], T.float16]
    C = T.empty((M, N), T.float32)
    # Delegate to the shared macro with fixed 128x128x32 tiles.
    gemm_impl(A, B, C, M, N, K, 128, 128, 32)
    return C

In [8]:
# Run the dynamic-shape GEMM on a 1024x512 by 512x256 float16 problem.
A = torch.randn(1024, 512, dtype=torch.float16, device="cuda")
B = torch.randn(512, 256, dtype=torch.float16, device="cuda")
C = gemm_dyn_K(A, B)
# Reference result from PyTorch, converted to float32 to match C.
C_ref = (A @ B).float()
torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)

### Use `T.StridedTensor` to annotate tensors with strides


In [9]:
# Copies a strided (M, N) float32 tensor `A` into a freshly declared (M, N)
# tensor `B` and returns it.
# NOTE(review): the name misspells "contiguous"; it is left unchanged because
# the notebook calls the function by this name.
@tilelang.lazy_jit
def as_contingious(A):
    # Shape (M, N) and strides (dM, dN) are all dynamic.
    M, N, dM, dN = T.dynamic("M, N, dM, dN")
    A: T.StridedTensor[[M, N], [dM, dN], T.float32]
    B = T.empty((M, N), A.dtype)
    block_M = 128
    block_N = 128
    # One block per 128x128 tile; each block copies its tile from A to B.
    with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):
        T.copy(
            A[bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N],
            B[bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N],
        )
    return B

In [10]:
A = torch.randn(1024, 1024, device="cuda")
# Pass the transposed view A.T as the strided input.
B = as_contingious(A.T)
# PyTorch reference: a contiguous copy of the same transposed view.
B_ref = A.T.contiguous()
torch.testing.assert_close(B, B_ref)

## More Annotation

### Use parameters directly as annotations

You can directly use function parameters in the annotations.

In [11]:
# GEMM that receives the sizes M, N, K as ordinary parameters and uses them
# directly inside the tensor annotations.
@tilelang.lazy_jit
def gemm_ptr(
    A,
    B,
    M,
    N,
    K,
):
    # A is (M, K) and B is (K, N), both float16; C is the float32 (M, N) result.
    A: T.Tensor[[M, K], T.float16]
    B: T.Tensor[[K, N], T.float16]
    C = T.empty((M, N), T.float32)
    # Shared GEMM macro with 128x128x32 tiles.
    gemm_impl(A, B, C, M, N, K, block_M=128, block_N=128, block_K=32)
    return C

In [12]:
A = torch.randn(1024, 512, dtype=torch.float16, device="cuda")
B = torch.randn(512, 256, dtype=torch.float16, device="cuda")
# The trailing arguments are M, N, K in that order: 1024, 256, 512.
C = gemm_ptr(A, B, 1024, 256, 512)
# Reference result from PyTorch, converted to float32 to match C.
C_ref = (A @ B).float()
torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)

### Annotations for runtime variables

Runtime variables work the same; if the function annotation becomes too long, you can move it into the function body.

In [13]:
# Same GEMM as gemm_ptr, but M, N, K carry T.int32 annotations that are
# written in the function body rather than in the signature.
@tilelang.lazy_jit
def gemm_ptr_dyn(A, B, M, N, K):
    # Scalar annotations for the size parameters.
    M: T.int32
    N: T.int32
    K: T.int32
    # Tensor annotations expressed in terms of those parameters.
    A: T.Tensor[[M, K], T.float16]
    B: T.Tensor[[K, N], T.float16]
    C = T.empty((M, N), T.float32)
    gemm_impl(A, B, C, M, N, K, block_M=128, block_N=128, block_K=32)
    return C

In [14]:
A = torch.randn(1024, 512, dtype=torch.float16, device="cuda")
B = torch.randn(512, 256, dtype=torch.float16, device="cuda")
# The trailing arguments are M, N, K in that order: 1024, 256, 512.
C = gemm_ptr_dyn(A, B, 1024, 256, 512)
# Reference result from PyTorch, converted to float32 to match C.
C_ref = (A @ B).float()
torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)

### Constraints for constants

A constant created by `T.const` must appear directly in at least one buffer shape or stride. If every use is indirect (for example only `M * 2` and `M * 3`), an error is raised.

In [15]:
# Intentionally invalid kernel: the constant M only ever appears inside the
# expressions M * 2 and M * 3, never on its own in a shape.
@tilelang.lazy_jit
def example_wrong_kernel(A):
    M = T.const("M")
    A: T.Tensor[[M * 2, M * 3], T.float32]
    with T.Kernel(1) as _:
        A[0, 0]


# The failure is the point of this cell, so the exception is caught and its
# message printed instead of aborting the notebook.
try:
    A = torch.randn(64, 96, dtype=torch.float32, device="cuda")
    example_wrong_kernel(A)
except Exception as e:
    print(e)

Constexpr variable `M` is not used in any buffer shape or stride.
At least one **DIRECT** usage is required. Please check:
(1) the variable is not used
(2) all uses are indirect, e.g. M * 2, M * 3. (you can replace them with separate constexpr variables)
Buffer shapes: {A: [M * 2, M * 3]}
Buffer strides: {A: [M * 3, 1]}


### Dynamic dimensions

If parts of a Tensor annotation (for example its shape or its number of dimensions) need to vary with other arguments, it is recommended to switch to the `T.ptr` + `T.match_buffer` style.

In [16]:
# Shows a buffer whose layout depends on another argument: A arrives as a raw
# pointer and is bound to a 2-D or 1-D float32 buffer depending on `is_2d`.
@tilelang.lazy_jit
def dyn_annot(
    A: T.ptr,  # 1. T.ptr type annotation
    is_2d=False,
):
    if is_2d:
        M, N = T.const("M, N")
        # 2. dynamic shape annotation inside function body
        A = T.match_buffer(A, [M, N], T.float32)
        with T.Kernel(1) as _:
            A[0, 0]
    else:
        # 1-D variant: a single length L.
        L = T.const("L")
        A = T.match_buffer(A, [L], T.float32)
        with T.Kernel(1) as _:
            A[0]


# Exercise the 2-D branch with a 64x96 float32 tensor.
A = torch.randn(64, 96, dtype=torch.float32, device="cuda")
dyn_annot(A, is_2d=True)

[]

### Default arguments

Scalar annotations like `T.float32` can carry default values.

In [17]:
# Element-wise kernel returning Y = X + data for an (M, N) float32 tensor X.
# `data` is a float32 scalar argument with a default value of 1.
@tilelang.lazy_jit
def add_one(X, data: T.float32 = 1):
    M, N = T.const("M, N")
    X: T.Tensor[[M, N], T.float32]
    Y = T.empty((M, N), T.float32)
    # Each block covers a band of 128 rows (bx * 128 + i) across all N columns.
    with T.Kernel(T.ceildiv(M, 128), threads=128) as bx:
        for i, j in T.Parallel(128, N):
            Y[bx * 128 + i, j] = X[bx * 128 + i, j] + data
    return Y

In [18]:
X = torch.randn(1024, 1024, dtype=torch.float32, device="cuda")
# `data` is omitted, so its default of 1 is used.
Y = add_one(X)
torch.testing.assert_close(Y, X + 1)

## Overhead of argument matching

LazyJIT has very small overhead; each additional constant annotation costs about 200 ns.
* 200 ns is roughly the cost of an FFI call that reads parameters from a `torch.Tensor`'s shape/stride.

In [None]:
import time

# Small float16 inputs; their contents are irrelevant because the kernel body is empty.
A = torch.randn(128, 128, dtype=torch.float16, device="cuda")
B = torch.randn(128, 128, dtype=torch.float16, device="cuda")


# Kernel with two annotated tensors sharing the constants M and N and no work
# in its body, used below to time the call path alone.
@tilelang.lazy_jit
def dummy_kernel(A, B):
    M, N = T.const("M, N")
    A: T.Tensor[[M, N], T.float16]
    B: T.Tensor[[M, N], T.float16]
    with T.Kernel(1) as _:
        pass


# compile it first
dummy_kernel(A, B)


def eval_overhead(f, iterations=10000):
    """Return the mean wall-clock cost of one ``f()`` call, in microseconds.

    Only host-side time is measured: the calls are bracketed with
    ``time.perf_counter_ns`` and no device synchronization is done here.

    Args:
        f: Zero-argument callable to time.
        iterations: Number of back-to-back calls to average over. Defaults to
            10000, the count that was previously hard-coded.

    Returns:
        Average time per call in microseconds, as a float.

    Raises:
        ValueError: If ``iterations`` is not positive.
    """
    if iterations <= 0:
        raise ValueError("iterations must be a positive integer")
    start = time.perf_counter_ns()
    for _ in range(iterations):
        f()
    stop = time.perf_counter_ns()
    # Elapsed time is in nanoseconds: divide by the call count for the mean,
    # then by 1000 to convert nanoseconds to microseconds.
    return (stop - start) / iterations / 1000


# Time a full call of the already-compiled kernel versus only the
# parse_cache_key step for the same arguments.
kernel_call_overhead = eval_overhead(lambda: dummy_kernel(A, B))
parse_cache_key_overhead = eval_overhead(lambda: dummy_kernel.parse_cache_key(A, B))

# eval_overhead reports microseconds per call.
print(f"Kernel call    : {kernel_call_overhead:.2f} us")
print(f"Parse cache key: {parse_cache_key_overhead:.2f} us")

Kernel call    : 7.68 us
Parse cache key: 0.41 us


## Compilation and parallel compilation

Both `lazy_jit` and the original `jit` support parallel compilation.

To avoid wasting memory on temporary `torch.Tensor` objects, you can use `T.Tensor` to create placeholders.

In [20]:
from itertools import product


def get_configs():
    """Build the keyword-argument sets passed to ``gemm.par_compile``.

    Each entry describes one compilation: ``T.Tensor`` placeholders stand in
    for the 1024x1024 inputs ``A`` and ``B`` (so no real ``torch.Tensor`` has
    to be allocated), combined with one choice of tile sizes.

    Returns:
        A list of 8 dicts, one per combination of ``block_M``, ``block_N`` and
        ``block_K`` drawn from ``{32, 64}``.
    """
    return [
        {
            # gemm annotates A and B as float16, so the placeholders use the
            # same dtype rather than float32.
            "A": T.Tensor((1024, 1024), T.float16),
            "B": T.Tensor((1024, 1024), T.float16),
            "block_M": block_M,
            "block_N": block_N,
            "block_K": block_K,
        }
        for block_M, block_N, block_K in product([32, 64], repeat=3)
    ]


# Compile every configuration from get_configs(); the cell displays the returned list of kernels.
gemm.par_compile(get_configs())

Elaborating:   0%|          | 0/8 [00:00<?, ?it/s]

Parallel Compiling:   0%|          | 0/8 [00:00<?, ?it/s]

[<tilelang.jit.kernel.JITKernel at 0x7ef9f7de7d70>,
 <tilelang.jit.kernel.JITKernel at 0x7ef9f7de52b0>,
 <tilelang.jit.kernel.JITKernel at 0x7ef9f7e34b30>,
 <tilelang.jit.kernel.JITKernel at 0x7ef9f7e34530>,
 <tilelang.jit.kernel.JITKernel at 0x7ef9f7de6900>,
 <tilelang.jit.kernel.JITKernel at 0x7ef9f7e344a0>,
 <tilelang.jit.kernel.JITKernel at 0x7ef9f7e347a0>,
 <tilelang.jit.kernel.JITKernel at 0x7ef9f7fb25d0>]

## More convenient macros

tilelang's macros have been improved:

1. Allow using `T.Ref` as an annotation, similar to C++ references.
2. Allow returning multiple values.
3. Allow nesting and recursion.

### Passing references with `T.Ref`

A `T.Ref` reference can point to a scalar variable or to an element of a buffer.

In [21]:
# Macro taking its argument by reference (T.Ref): the assignment below is
# meant to write through to whatever the caller passed in.
@T.macro
def macro_with_ref(x: T.Ref):
    x = 1  # noqa: F841


# prim_func that passes buffer elements to the macro by reference.
@T.prim_func
def foo(x: T.Tensor((2,))):
    with T.Kernel(1) as _:
        # Supports constant indices
        macro_with_ref(x[1])

        # Also supports variable indices
        idx = T.alloc_var(T.int32, 0)
        macro_with_ref(x[idx])


# Bare expression so the notebook displays the generated TIR for foo.
foo

# from tvm.script import tir as T

@T.prim_func
def foo(x_handle: T.handle):
    x = T.match_buffer(x_handle, (2,), strides=(1,))
    # with T.block("root"):
    bx = T.launch_thread("blockIdx.x", 1)
    tx = T.launch_thread("threadIdx.x", 128)
    ty = T.launch_thread("threadIdx.y", 1)
    tz = T.launch_thread("threadIdx.z", 1)
    with T.block("tilelang_root"):
        T.reads()
        idx = T.Buffer((1,), "int32", scope="local.var")
        T.writes(x[T.min(1, idx[0]):T.min(1, idx[0]) + (T.max(1, idx[0]) + 1 - T.min(1, idx[0]))])
        T.block_attr({"tl.local_var_init": {idx.data: 0}})
        idx = T.alloc_buffer((1,), "int32", data=idx.data, scope="local.var")
        x[1] = T.float32(1.0)
        _tmp: T.int32 = idx[0]
        x[_tmp] = T.float32(1.0)

### Pass macros as arguments

You can pass a macro as a function argument.

In [22]:
# Applies `fn` to every element of the 1-D float32 tensor A and returns the
# results in a new tensor B of the same length and dtype.
@tilelang.lazy_jit
def element_wise(A, fn):
    N = T.dynamic("N")
    A: T.Tensor[[N], T.float32]
    B = T.empty((N,), dtype=A.dtype)
    block_N = 128
    # Each block processes block_N consecutive elements.
    with T.Kernel(T.ceildiv(N, block_N), threads=128) as bx:
        for i in T.Parallel(block_N):
            idx = bx * block_N + i
            B[idx] = fn(A[idx])
    return B


# Macro passed to element_wise as `fn`.
# NOTE(review): this rebinds the name `add_one`, which an earlier cell in this
# notebook defined as a lazy_jit kernel.
@T.macro
def add_one(x):
    return x + 1

In [23]:
A = torch.randn(1024, device="cuda")
# Pass the add_one macro as the per-element function.
B = element_wise(A, add_one)
B_ref = A + 1
torch.testing.assert_close(B, B_ref)

### Recursive macros

You may not need this often, but macros can be recursive as long as the termination condition is known at compile time.

In [24]:
# Recursive macro following the 3n+1 rule on `x`: halve when even, map to
# 3x + 1 when odd, and stop once x == 1. At every step the same update is
# applied to `var`, which is passed by reference (T.Ref).
@T.macro
def n31(x, var: T.Ref):
    if x == 1:
        # Termination case: nothing left to emit.
        pass
    elif x % 2 == 0:
        var = var // 2
        n31(x // 2, var)
    else:
        var = var * 3 + 1
        n31(x * 3 + 1, var)


# Kernel driving the macro: `n` is a Python int and A is a one-element int32
# tensor whose only element is updated through the reference.
# NOTE(review): `foo` was already defined by an earlier cell; this rebinds it.
@tilelang.lazy_jit
def foo(A: T.Tensor[[1], T.int32], n: int):
    with T.Kernel(1) as _:
        n31(n, A[0])

In [25]:
# Start from A[0] = 100 and run the macro chain for n = 5.
A = torch.tensor([100], dtype=torch.int32, device="cuda")
foo(A, 5)
# Bare expression so the notebook displays the updated tensor.
A

tensor([18], device='cuda:0', dtype=torch.int32)

### Macros returning multiple values

In [26]:
# Macro returning two values at once: (sin(x), cos(x)).
@T.macro
def sincos(x):
    return T.sin(x), T.cos(x)


# prim_func that unpacks both results of the macro.
# NOTE(review): `foo` was already defined by earlier cells; this rebinds it.
@T.prim_func
def foo():
    with T.Kernel(32) as x:
        s, c = sincos(x)
        a = s + c  # noqa: F841
        b = s - c  # noqa: F841


# Bare expression so the notebook displays the generated TIR for foo.
foo

# from tvm.script import tir as T

@T.prim_func
def foo():
    # with T.block("root"):
    x = T.launch_thread("blockIdx.x", 32)
    tx = T.launch_thread("threadIdx.x", 128)
    ty = T.launch_thread("threadIdx.y", 1)
    tz = T.launch_thread("threadIdx.z", 1)
    with T.block("tilelang_root"):
        T.reads()
        T.writes()
        s: T.int32 = T.sin(x)
        c: T.int32 = T.cos(x)
        a: T.int32 = s + c
        b: T.int32 = s - c
        T.evaluate(0)