In [None]:
import sys
from pathlib import Path

sys.path.insert(0, str(Path.cwd().parent.parent.absolute()))
import tilelang
import torch
import tilelang.language as T

# Tilelang Lazy JIT

## Tensor Annotation

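TileLang Lazy JIT merges kernel generation (JIT compilation) and invocation into a single step: calling the decorated function compiles the kernel if needed and then runs it.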

The function-signature syntax is similar to Triton's, but with significant enhancements — most notably, it allows Tensor annotations:

For example, the code below annotates two 2D tensors with `T.Tensor[[int, int], T.float16]`, which means:
1. Each dimension is a compile-time constant; changing it triggers recompilation.
2. The dtype must be `T.float16`.

Besides a concrete type, the dtype can also be `Any` or `None`.


In [2]:
# Lazy-JIT GEMM: returns C = A @ B with C allocated in `out_dtype`.
# A is (M, K) and B is (K, N), both float16. Every shape dimension is a
# compile-time constant, so a new shape (or new block sizes) triggers a
# recompilation. block_M / block_N / block_K are the tile sizes.
@tilelang.lazy_jit
def gemm(
    A: T.Tensor[[int, int], T.float16],
    B: T.Tensor[[int, int], T.float16],
    out_dtype: T.dtype = T.float32,
    block_M: int = 128,
    block_N: int = 128,
    block_K: int = 32,
):
    M, K = A.shape
    # NOTE(review): rebinds K from B; A.shape[1] == B.shape[0] is assumed,
    # this function itself does not check it.
    K, N = B.shape
    # Output tensor created by the kernel and returned to the caller.
    C = T.empty((M, N), out_dtype)
    # Launch grid: ceildiv(M, block_M) x ceildiv(N, block_N); (bx, by) picks
    # the output tile this block computes.
    with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):
        # Staging tiles for A and B, plus the accumulator tile for C.
        A_shared = T.alloc_shared((block_M, block_K), A.dtype)
        B_shared = T.alloc_shared((block_K, block_N), B.dtype)
        C_local = T.alloc_fragment((block_M, block_N), out_dtype)
        T.clear(C_local)
        # Walk the K dimension one block_K slice at a time (3-stage pipeline).
        for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
            T.copy(A[bx * block_M, k * block_K], A_shared)
            T.copy(B[k * block_K, by * block_N], B_shared)
            T.gemm(A_shared, B_shared, C_local)
        # Write the accumulated tile back to its place in C.
        T.copy(C_local, C[bx * block_M, by * block_N])
    return C

Call the function directly with tensors as arguments to trigger the full JIT compile-and-run workflow:

In [3]:
A = torch.randn(1024, 512, dtype=torch.float16, device="cuda")
B = torch.randn(512, 256, dtype=torch.float16, device="cuda")
C = gemm(A, B)  # first call with these shapes: compile, then run

# Reference result from PyTorch, upcast to gemm's default out_dtype (float32).
C_ref = torch.matmul(A, B).to(torch.float32)
torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)

Change the call-site arguments; if the compile-time parameters differ from a previous call, the kernel is recompiled:

In [4]:
A = torch.randn(1024, 512, dtype=torch.float16, device="cuda")
B = torch.randn(512, 1024, dtype=torch.float16, device="cuda")
# A new N and new block sizes change compile-time parameters -> recompilation.
C = gemm(A, B, block_M=64, block_N=64)
# Verify the recompiled kernel too, not only the first one.
torch.testing.assert_close(C, (A @ B).float(), rtol=1e-2, atol=1e-2)

You can also call the compilation helpers of the lazy-JIT function (here `gemm`) explicitly:

1. `gemm.compile` compiles the kernel and returns it
2. `gemm.get_tir` retrieves the TIR
3. `gemm.par_compile` compiles several configurations in parallel

In [5]:
# Build the kernel explicitly for these arguments, then launch it ourselves.
kernel = gemm.compile(A, B, block_M=64, block_N=64)
C = kernel(A, B)
# The explicitly compiled kernel must agree with PyTorch as well.
torch.testing.assert_close(C, (A @ B).float(), rtol=1e-2, atol=1e-2)



## More Tensor Annotation

### Separate the implementation with macros

Next we'll implement a simple gemm in several ways. For convenience, first write a macro that captures the main gemm logic:

In [6]:
# Reusable GEMM body (a T.macro) shared by the kernels in this notebook:
# fills C = A @ B, where A is indexed as (M, K), B as (K, N) and C as (M, N).
# It opens its own T.Kernel launch, so callers invoke it outside any
# T.Kernel (as every caller here does). Tiles are block_M x block_N x block_K,
# and the accumulator tile uses C.dtype.
@T.macro
def gemm_impl(A, B, C, M, N, K, block_M, block_N, block_K):
    with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):
        # Staging tiles for A and B, plus the accumulator tile for C.
        A_shared = T.alloc_shared((block_M, block_K), A.dtype)
        B_shared = T.alloc_shared((block_K, block_N), B.dtype)
        C_local = T.alloc_fragment((block_M, block_N), C.dtype)
        T.clear(C_local)
        # Reduce over K one block_K slice at a time (3-stage pipeline).
        for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
            T.copy(A[bx * block_M, k * block_K], A_shared)
            T.copy(B[k * block_K, by * block_N], B_shared)
            T.gemm(A_shared, B_shared, C_local)
        # Store the finished tile into C.
        T.copy(C_local, C[bx * block_M, by * block_N])

### Mark dynamic shapes with T.dyn

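When some dimensions are dynamic, mark them with `T.dyn`. `T.dyn` can take a string argument that names the variable. Annotations that use the same name share a single variable; below, `K` is shared by `A` and `B`.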

In [None]:
# GEMM with a runtime reduction extent. T.dyn["K"] marks K as dynamic and
# names it, and reusing the same name in A and B makes both refer to one
# variable K. M and N (plain `int`) stay compile-time constants.
@tilelang.lazy_jit
def gemm_dyn_K(
    # `noqa: F821`: flake8 reads the string "K" as an undefined name.
    A: T.Tensor[[int, T.dyn["K"]], T.float16],  # noqa: F821
    B: T.Tensor[[T.dyn["K"], int], T.float16],  # noqa: F821
):
    M, K = A.shape
    K, N = B.shape
    C = T.empty((M, N), T.float32)
    # Fixed 128 x 128 x 32 tiles.
    gemm_impl(A, B, C, M, N, K, 128, 128, 32)
    return C

Inspect the annotations parsed from the lazy-JIT function signature. Shape entries with a `$` suffix are compile-time constants; they may differ between calls, which triggers recompilation. Entries with a `$dyn` suffix are runtime variables.

In [8]:
# Display the annotations lazy_jit parsed from gemm_dyn_K's signature.
gemm_dyn_K.func.annot

{'A': TensorAnnot(shape=[A_shape_0$, K$dyn], strides=None, dtype=dtype('float16')),
 'B': TensorAnnot(shape=[K$dyn, B_shape_1$], strides=None, dtype=dtype('float16'))}

In [9]:
A = torch.randn(1024, 512, dtype=torch.float16, device="cuda")
B = torch.randn(512, 256, dtype=torch.float16, device="cuda")
# K (512 here) is a runtime value; other K values should reuse this kernel.
C = gemm_dyn_K(A, B)
C_ref = torch.matmul(A, B).float()
torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)

### Use T.StridedTensor to annotate tensors with strides

Annotation format: `T.StridedTensor[Shape, Stride, DType]`. Each Shape or Stride entry can be:
* `int`: a compile-time constant
* `T.dyn`: a runtime value

DType can also be `None` or `Any`.

In [10]:
from typing import Any


# Copy an arbitrarily strided 2D tensor into a new tensor created with
# T.empty (expected to be contiguous). Shape and strides are runtime (T.dyn)
# values, and the dtype is unconstrained (Any); the output keeps A's dtype.
@tilelang.lazy_jit
def as_contiguous(A: T.StridedTensor[[T.dyn, T.dyn], [T.dyn, T.dyn], Any]):
    M, N = A.shape
    B = T.empty((M, N), A.dtype)
    block_M = 128
    block_N = 128
    # One block per 128 x 128 tile; each block copies its tile from A to B.
    with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):
        T.copy(
            A[bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N],
            B[bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N],
        )
    return B


# Backward-compatible alias for the original, misspelled name.
as_contingious = as_contiguous

In [11]:
A = torch.randn(1024, 1024, device="cuda")
B_ref = A[::2, ::2].contiguous()  # PyTorch reference copy of the strided view
B = as_contingious(A[::2, ::2])  # every 2nd row/column: a non-contiguous view
torch.testing.assert_close(B, B_ref)

## More Annotation

### Annotate tensors with T.ptr
`lazy_jit` lets you declare a raw handle with `T.ptr`, but you must then give it a shape (and dtype) inside the function via `T.match_buffer`.

In [12]:
# GEMM on raw pointers: A and B arrive as T.ptr handles, and T.match_buffer
# gives them float16 views of shape (M, K) and (K, N). M, N and K are plain
# `int` parameters, so they are compile-time constants.
@tilelang.lazy_jit
def gemm_ptr(
    A: T.ptr,
    B: T.ptr,
    M: int,
    N: int,
    K: int,
):
    # Rebind the handles to typed buffers.
    # NOTE(review): presumably the pointed-to memory is not checked against
    # this shape/dtype, so the caller must pass matching float16 data.
    A = T.match_buffer(A, (M, K), T.float16)
    B = T.match_buffer(B, (K, N), T.float16)
    C = T.empty((M, N), T.float32)
    gemm_impl(A, B, C, M, N, K, block_M=128, block_N=128, block_K=32)
    return C

In [13]:
A = torch.randn(1024, 512, dtype=torch.float16, device="cuda")
B = torch.randn(512, 256, dtype=torch.float16, device="cuda")
# Pass the extents (M, N, K) = (1024, 256, 512), read from the tensors themselves.
C = gemm_ptr(A, B, A.shape[0], B.shape[1], A.shape[1])
C_ref = torch.matmul(A, B).float()
torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)

### Use T.int32 to annotate runtime variables

`lazy_jit` also lets you declare runtime scalar arguments with `T.int32` (or other dtypes) instead of compile-time `int`. This enables a fully dynamic GEMM similar to Triton.

In [14]:
# Fully dynamic GEMM: like gemm_ptr, but M, N and K are T.int32 runtime
# arguments instead of compile-time `int` constants.
@tilelang.lazy_jit
def gemm_ptr_dyn(
    A: T.ptr,
    B: T.ptr,
    M: T.int32,
    N: T.int32,
    K: T.int32,
):
    # Explicit row-major strides: rows of A are K elements apart and rows of
    # B are N apart, so the inputs must be contiguous row-major float16 data.
    A = T.match_buffer(A, (M, K), T.float16, strides=(K, 1))
    B = T.match_buffer(B, (K, N), T.float16, strides=(N, 1))
    C = T.empty((M, N), T.float32)
    gemm_impl(A, B, C, M, N, K, block_M=128, block_N=128, block_K=32)
    return C

In [15]:
A = torch.randn(1024, 512, dtype=torch.float16, device="cuda")
B = torch.randn(512, 256, dtype=torch.float16, device="cuda")
# M, N, K are runtime (T.int32) arguments, taken from the tensor sizes.
C = gemm_ptr_dyn(A, B, A.size(0), B.size(1), A.size(1))
C_ref = torch.matmul(A, B).float()
torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)

## Compilation and parallel compilation

`lazy_jit` and the original `jit` both support parallel compilation.

To avoid allocating real `torch` tensors just to serve as placeholders, describe the arguments with `T.Tensor(shape, dtype)` instead.

In [16]:
from itertools import product


def get_configs(block_sizes=(32, 64)):
    """Build keyword-argument dicts for ``gemm.par_compile``.

    Produces one config per (block_M, block_N, block_K) combination drawn from
    ``block_sizes`` (2**3 = 8 configs with the default). ``A`` and ``B`` are
    ``T.Tensor`` placeholders, so no torch tensors are allocated. Their dtype
    is float16 to match gemm's ``T.Tensor[[int, int], T.float16]`` annotations.

    Args:
        block_sizes: candidate tile sizes tried for each of block_M, block_N
            and block_K.

    Returns:
        A list of dicts mapping gemm parameter names to values.
    """
    shape = (1024, 1024)
    return [
        {
            "A": T.Tensor(shape, T.float16),
            "B": T.Tensor(shape, T.float16),
            "block_M": block_M,
            "block_N": block_N,
            "block_K": block_K,
        }
        for block_M, block_N, block_K in product(block_sizes, repeat=3)
    ]


gemm.par_compile(get_configs())

Elaborating:   0%|          | 0/8 [00:00<?, ?it/s]

Parallel Compiling:   0%|          | 0/8 [00:00<?, ?it/s]

[<tilelang.jit.kernel.JITKernel at 0x7f29c0072ed0>,
 <tilelang.jit.kernel.JITKernel at 0x7f29c00882f0>,
 <tilelang.jit.kernel.JITKernel at 0x7f29c00735f0>,
 <tilelang.jit.kernel.JITKernel at 0x7f29c0088890>,
 <tilelang.jit.kernel.JITKernel at 0x7f29c01f94c0>,
 <tilelang.jit.kernel.JITKernel at 0x7f29c0073fe0>,
 <tilelang.jit.kernel.JITKernel at 0x7f29c0070ce0>,
 <tilelang.jit.kernel.JITKernel at 0x7f29c00732f0>]

## More convenient macros

TileLang macros have been upgraded:

1. Allow `T.Ref` as an annotation, similar to C++ pass-by-reference
2. Allow returning multiple values
3. Allow nesting and recursion

### Passing references with T.Ref

A `T.Ref` parameter can refer to a variable or a buffer element. Assigning to it inside the macro writes through to the caller's target.

In [None]:
# Macro with a T.Ref parameter: assigning to `x` stores into whatever the
# caller passed (here, buffer elements) instead of rebinding a local name.
@T.macro
def macro_with_ref(x: T.Ref):
    # noqa: flake8 would otherwise flag `x` as assigned but never used.
    x = 1  # noqa: F841


# Plain T.prim_func showing T.Ref with a constant and with a runtime index;
# in the printed TIR both calls become stores `x[...] = T.float32(1.0)`.
@T.prim_func
def foo(x: T.Tensor((2,))):
    with T.Kernel(1) as _:
        # Supports constant indices
        macro_with_ref(x[1])

        # Also supports variable indices
        idx = T.alloc_var(T.int32, 0)
        macro_with_ref(x[idx])


# Displaying the prim_func prints its TIR.
foo

# import tilelang.language as T

@T.prim_func
def foo(x_handle: T.handle):
    x = T.match_buffer(x_handle, (2,), strides=(1,))
    # with T.block("root"):
    bx = T.launch_thread("blockIdx.x", 1)
    tx = T.launch_thread("threadIdx.x", 128)
    ty = T.launch_thread("threadIdx.y", 1)
    tz = T.launch_thread("threadIdx.z", 1)
    with T.block("tilelang_root"):
        T.reads()
        idx = T.Buffer((1,), "int32", scope="local.var")
        T.writes(x[T.min(1, idx[0]):T.min(1, idx[0]) + (T.max(1, idx[0]) + 1 - T.min(1, idx[0]))])
        T.block_attr({"tl.local_var_init": {idx.data: 0}})
        idx = T.alloc_buffer((1,), "int32", data=idx.data, scope="local.var")
        x[1] = T.float32(1.0)
        _tmp: T.int32 = idx[0]
        x[_tmp] = T.float32(1.0)

### Pass as arguments

Macros can be passed as arguments, like any other value:

In [18]:
# Apply `fn` (e.g. a T.macro such as add_one) to every element of the 1D
# tensor A and return the results in a new tensor B with A's dtype.
# N is a runtime (T.dyn) extent and the dtype is unconstrained (Any).
@tilelang.lazy_jit
def element_wise(
    A: T.Tensor[[T.dyn], Any],
    fn,
):
    (N,) = A.shape
    B = T.empty((N,), dtype=A.dtype)
    block_N = 128
    with T.Kernel(T.ceildiv(N, block_N), threads=128) as bx:
        for i in T.Parallel(block_N):
            idx = bx * block_N + i
            # The grid is rounded up to whole blocks, so when N is not a
            # multiple of block_N the tail of the last block would index past
            # the end of A and B; skip those positions.
            if idx < N:
                B[idx] = fn(A[idx])
    return B


# Example element-wise function: a macro that returns a value.
@T.macro
def add_one(x):
    return x + 1

In [19]:
A = torch.randn(1024, device="cuda")
B_ref = A + 1
B = element_wise(A, add_one)  # the macro is passed like any other argument
torch.testing.assert_close(B, B_ref)

### Macro recursion

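Macros can be recursive (though this is rarely needed), as long as the termination condition is known at compile time. In the example below, `n31` follows the Collatz (3n+1) sequence of the compile-time integer `n` and applies the same update to `A[0]` at each step.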

In [20]:
# Recursive macro over the Collatz ("3n+1") sequence of `x`. Until x == 1, it
# halves x when even and maps it to 3x + 1 when odd, applying the same update
# to `var` at each step. `x` is meant to be a Python int known at compile time,
# so the branches on x are presumably resolved while the kernel is traced.
# NOTE(review): x must be positive. 0 maps to 0 forever and negative values
# stay negative, so for x <= 0 the recursion never terminates.
@T.macro
def n31(x, var: T.Ref):
    if x == 1:
        pass  # base case: the sequence reached 1
    elif x % 2 == 0:
        var = var // 2  # T.Ref: writes through to the caller's target
        n31(x // 2, var)
    else:
        var = var * 3 + 1
        n31(x * 3 + 1, var)


# `n: int` is a compile-time constant, so each distinct n triggers its own
# compilation. The macro updates A[0] in place through T.Ref.
@tilelang.lazy_jit
def foo(A: T.Tensor[[1], T.int32], n: int):
    with T.Kernel(1) as _:
        n31(n, A[0])

In [21]:
A = torch.tensor([100], dtype=torch.int32, device="cuda")
foo(A, 5)  # rewrites A[0] in place through the T.Ref macro


def _n31_ref(x, var):
    """Pure-Python replay of the n31 macro: return the final value of `var`."""
    if x < 1:
        raise ValueError("x must be a positive integer")
    while x != 1:
        if x % 2 == 0:
            x, var = x // 2, var // 2
        else:
            x, var = x * 3 + 1, var * 3 + 1
    return var


# 100 -> 301 -> 150 -> 75 -> 37 -> 18
assert A.item() == _n31_ref(5, 100)
A

tensor([18], device='cuda:0', dtype=torch.int32)

### Macro returning multiple values

In [None]:
# Macro returning a tuple; the caller unpacks it like a normal Python call.
@T.macro
def sincos(x):
    return T.sin(x), T.cos(x)


# NOTE(review): x is the int32 block index, and the printed TIR types s and c
# as T.int32. Cast x to a float type before T.sin/T.cos if this kernel is
# meant to be compiled and to produce meaningful values.
@T.prim_func
def foo():
    with T.Kernel(32) as x:
        s, c = sincos(x)
        a = s + c  # noqa: F841
        b = s - c  # noqa: F841


# Displaying the prim_func prints its TIR.
foo

# import tilelang.language as T

@T.prim_func
def foo():
    # with T.block("root"):
    x = T.launch_thread("blockIdx.x", 32)
    tx = T.launch_thread("threadIdx.x", 128)
    ty = T.launch_thread("threadIdx.y", 1)
    tz = T.launch_thread("threadIdx.z", 1)
    with T.block("tilelang_root"):
        T.reads()
        T.writes()
        s: T.int32 = T.sin(x)
        c: T.int32 = T.cos(x)
        a: T.int32 = s + c
        b: T.int32 = s - c
        T.evaluate(0)