In [1]:
import tilelang.language.v2 as tl
import torch
from tilelang import PassConfigKey
from typing import Tuple

## Quick Start

### Simple GEMM Kernel


Tilelang jit v2 allow you to write `tl.Tensor` to represent a schema for torch.Tensor
* `tl.Tensor[int]` is an 1-dimensional tensor, while `tl.Tensor[int, int]` is a 2-dimensional tensor

You can allocate global buffer using `tl.empty`, and return it.

In [2]:
@tl.jit
def gemm(
    A: tl.Tensor[int, int],
    B: tl.Tensor[int, int],
    accum_dtype: torch.dtype = torch.float32,
    block_N: int = 128,
    block_M: int = 128,
    block_K: int = 128,
):
    # use A.xxx_params() to get params
    #  params binding checks all param are the same
    N, K = A.shape_params()
    M, K = B.shape_params()

    C = tl.empty((M, N), dtype=accum_dtype)
    dims = [
        tl.ceildiv(M, block_M),
        tl.ceildiv(N, block_N),
    ]
    with tl.Kernel(*dims, threads=128) as (bx, by):
        A_shared = tl.alloc_shared((block_M, block_K), dtype=A.dtype)
        B_shared = tl.alloc_shared((block_K, block_N), dtype=B.dtype)
        C_local = tl.alloc_fragment((block_M, block_N), dtype=accum_dtype)
        tl.clear(C_local)
        for k in tl.Pipelined(tl.ceildiv(K, block_K), num_stages=3):
            tl.copy(A[by * block_M, k * block_K], A_shared)
            tl.copy(B[k * block_K, bx * block_N], B_shared)
            tl.gemm(A_shared, B_shared, C_local)
        tl.copy(C_local, C[by * block_M, bx * block_N])
    return C

Call the kernel with torch tensor will trigger jit compilation

In [3]:
A = torch.randn(1024, 1024, dtype=torch.float16).cuda()
B = torch.randn(1024, 1024, dtype=torch.float16).cuda()
C_tl = gemm(A, B)
C_torch = A.float() @ B.float()
torch.testing.assert_close(C_torch, C_tl, atol=1e-2, rtol=1e-2)

Benchmarking the kernel by `do_bench`

In [4]:
from tilelang.profiler import do_bench
do_bench(lambda: gemm(A, B))
# or you can use gemm.bench(A, B)

0.012575408695652115

### Pointer Based Kernel

JITv2 allow you to use pointer and make_tensor to write kernel, this is very similar to triton

In [5]:
@tl.jit
def gemm_ptr(
    A_ptr: tl.ptr,
    B_ptr: tl.ptr,
    M: int,
    N: int,
    K: int,
    dtype: torch.dtype = torch.float16,
    accum_dtype: torch.dtype = torch.float32,
    block_N: int = 128,
    block_M: int = 128,
    block_K: int = 128,
):
    # `tl.empty` is before `tl.make_tensor`, because `tl.empty` is host code, and `tl.make_tensor` is device code
    C = tl.empty((M, N), dtype=accum_dtype)

    A = tl.make_tensor(A_ptr, (M, K), dtype=dtype)
    B = tl.make_tensor(B_ptr, (N, K), dtype=dtype)
    dims = [
        tl.ceildiv(M, block_M),
        tl.ceildiv(N, block_N),
    ]
    with tl.Kernel(*dims, threads=128) as (bx, by):
        A_shared = tl.alloc_shared((block_M, block_K), dtype=A.dtype)
        B_shared = tl.alloc_shared((block_K, block_N), dtype=B.dtype)
        C_local = tl.alloc_fragment((block_M, block_N), dtype=accum_dtype)
        tl.clear(C_local)
        for k in tl.Pipelined(tl.ceildiv(K, block_K), num_stages=3):
            tl.copy(A[by * block_M, k * block_K], A_shared)
            tl.copy(B[k * block_K, bx * block_N], B_shared)
            tl.gemm(A_shared, B_shared, C_local)
        tl.copy(C_local, C[by * block_M, bx * block_N])
    return C

You need to manually convert a tensor to its pointer to call the kernel

In [6]:
A = torch.randn(1024, 1024, dtype=torch.float16).cuda()
B = torch.randn(1024, 1024, dtype=torch.float16).cuda()
C_torch = A.float() @ B.float()
C_tl = gemm_ptr(A.data_ptr(), B.data_ptr(), 1024, 1024, 1024, dtype=torch.float16)
torch.testing.assert_close(C_torch, C_tl, atol=1e-2, rtol=1e-2)

### Dynamic Arguments

Use `tl.dyn` to define dynamic argument, dynamic argument is a Var in TVM.

* `tl.dyn` is a dynamic int
* `tl.dyn[float]` is a dynamic float
* `tl.dyn[int, '_N']` is a dynamic int with name "_N"

In [7]:
_N = tl.dyn[int, '_N']
@tl.jit
def vec_add(
    A: tl.Tensor[_N],
    B: tl.Tensor[_N],
    block_N: int = 128 * 8,
):
    N, = A.shape
    assert A.dtype == B.dtype, "Expect 2 tensor with the same dtype"
    C = tl.empty((N,), dtype=A.dtype)
    with tl.Kernel(tl.ceildiv(N, block_N), threads=128) as bx:
        px = bx * block_N
        for i in tl.Parallel(block_N):
            C[px + i] = A[px + i] + B[px + i]
    return C

In [8]:
A = torch.randn(2**20, dtype=torch.float16).cuda()
B = torch.randn(2**20, dtype=torch.float16).cuda()
C_torch = A + B
C_tl = vec_add(A, B)
torch.testing.assert_close(C_torch, C_tl, atol=1e-2, rtol=1e-2)

Unnamed dynamic argument is also supported, you can use it both in tensor and as scalar

In [9]:
@tl.jit
def vec_add_scalar(
    A: tl.Tensor[int],
    cval: tl.dyn[float],
    block_N: int = 128 * 8
):
    N, = A.shape
    assert A.dtype == tl.get_tvm_dtype(float), "Expect A to have float dtype"
    C = tl.empty((N,), dtype=A.dtype)
    with tl.Kernel(tl.ceildiv(N, block_N), threads=128) as bx:
        px = bx * block_N
        for i in tl.Parallel(block_N):
            C[px + i] = A[px + i] + cval
    return C

In [10]:
A = torch.randn(2**20, dtype=torch.float32).cuda()
cval = 1.0
res_torch = A + cval
res_tl = vec_add_scalar(A, cval)
torch.testing.assert_close(res_torch, res_tl, atol=1e-2, rtol=1e-2)

### Strided Tensor

use tl.StridedTensor to to use strided tensor, you need to annotate which is static and which is dynamic
* Type hinting is not implemented in strided tensor, sorry...

In [11]:
@tl.jit
def get_contingous(
    # Fixed Shape, Variable Stride
    A: tl.StridedTensor[(int, int), (tl.dyn, tl.dyn)]
):
    M, N, dtype = A.params()
    C = tl.empty((M, N), dtype=dtype)
    with tl.Kernel(tl.ceildiv(M, 128), tl.ceildiv(N, 128), threads=128) as (bx, by):
        tl.copy(
            C[bx * 128: (bx + 1) * 128, by * 128: (by + 1) * 128],
            A[bx * 128: (bx + 1) * 128, by * 128: (by + 1) * 128],
        )
    return C

In [12]:
A = torch.randn(1024, 4, 1024, 4, dtype=torch.float16).cuda()
out = get_contingous(A[:, 0, :, 0])
gold = A[:, 0, :, 0].clone()
torch.testing.assert_close(out, gold, atol=1e-2, rtol=1e-2)

## Call Overhead

JITv2 has extremely low runtime overhead
* about **3us** additional overhead compared to torch native kernel
* the main overhead comes from
    * parse argument: ~1us
    * torch.current_stream: ~0.2us
    * torch.device('cuda'): ~0.2us
    * torch.empty: 1~2us

In [13]:
import time

print('GEMM 1024 x 1024 x 1024')

A = torch.randn(1024, 1024, dtype=torch.float16).cuda()
B = torch.randn(1024, 1024, dtype=torch.float16).cuda()
torch_beg = time.perf_counter()
for _ in range(10000):
    A @ B
torch_end = time.perf_counter()
elapsed = (torch_end - torch_beg) / 10000 * 1e6

print('Torch time:    ', elapsed, 'us')

tl_beg = time.perf_counter()
for _ in range(10000):
    gemm(A, B)
tl_end = time.perf_counter()
elapsed = (tl_end - tl_beg) / 10000 * 1e6

print('Tilelang time: ', elapsed, 'us')


GEMM 1024 x 1024 x 1024
Torch time:     19.696090300567448 us
Tilelang time:  12.795752682723105 us


In [14]:
import time

print('Vec Add 1024 x 1024')

A = torch.randn(1024 * 1024, dtype=torch.float16).cuda()
B = torch.randn(1024 * 1024, dtype=torch.float16).cuda()
torch_beg = time.perf_counter()
for _ in range(10000):
    A + B
torch_end = time.perf_counter()
elapsed = (torch_end - torch_beg) / 10000 * 1e6

print('Torch time:    ', elapsed, 'us')

tl_beg = time.perf_counter()
for _ in range(10000):
    vec_add(A, B)
tl_end = time.perf_counter()
elapsed = (tl_end - tl_beg) / 10000 * 1e6

print('Tilelang time: ', elapsed, 'us')

Vec Add 1024 x 1024
Torch time:     7.411740510724485 us
Tilelang time:  9.731462597846985 us


## Autotune

**We highly suggest to use autotune in development and make the parameter fixed in deployment**

### Run autotune

Run autotune with `ker.tune`

In [15]:
A = torch.randn(1024, 1024, dtype=torch.float16).cuda()
B = torch.randn(1024, 1024, dtype=torch.float16).cuda()
tune_result = gemm.tune(
    A,
    B,
    block_M=tl.tune([64, 128, 256]),
    block_N=tl.tune([64, 128, 256]),
    block_K=tl.tune([32, 64]),
)
tune_result

2025-10-13 15:42:01  [TileLang:tilelang.language.v2.jit:INFO]: Elaborate 18 configs


Elaborating:   0%|          | 0/18 [00:00<?, ?it/s]

Parallel Compiling:   0%|          | 0/18 [00:00<?, ?it/s]

Benchmarking:   0%|          | 0/18 [00:00<?, ?it/s]

AutoTuneResult(
  name=gemm,
  num_errors=0,
  best_latency=0.01013209827044034,
  best_args=gemm(
    tl.place(1024, 1024, dtype=torch.float16),
    tl.place(1024, 1024, dtype=torch.float16),
    block_M=64,
    block_N=128,
    block_K=64
  ),
  best={'A.dtype': torch.float16, 'A__shape_0': 1024, 'A__shape_1': 1024, 'A__stride_0': 1024, 'A__stride_1': 1, 'B.dtype': torch.float16, 'B__shape_0': 1024, 'B__shape_1': 1024, 'B__stride_0': 1024, 'B__stride_1': 1, 'accum_dtype': torch.float32, 'block_N': 128, 'block_M': 64, 'block_K': 64, 'latency': 0.01013209827044034, '_status': 'Success', '_error': ''},
  records=<18 records>,
)

In [16]:
best_kwargs = tune_result.best_args.kwargs
gemm.bench(A, B, **best_kwargs)



0.010162922123893914

The tune result can be converted into pandas for futher analysis

In [17]:
(
    tune_result
    .to_pandas()
    .pivot_table(index=['block_N', 'block_M'], columns='block_K', values='latency')
    .style
    .highlight_min(color='lightgreen')
)

Unnamed: 0_level_0,block_K,32,64
block_N,block_M,Unnamed: 2_level_1,Unnamed: 3_level_1
64,64,0.02341,0.015334
64,128,0.014278,0.010503
64,256,0.017084,0.013397
128,64,0.013945,0.010132
128,128,0.016207,0.01327
128,256,0.136026,0.143063
256,64,0.016845,0.012954
256,128,0.195256,0.213166
256,256,0.746829,0.735


### Advanced Tunning

Autotune is divided into 2 steps:
1. `kerl.get_tune_configs` to obtain all configs for tunning
2. `kerl.tune_configs` to tune all the configs

In [18]:
gemm.get_tune_configs(
    A,
    B,
    block_M=tl.tune([64, 128, 256]),
    block_N=tl.tune([64, 128, 256]),
    block_K=tl.tune([32, 64]),
)[:2]

[CallArgs(
   tl.place(1024, 1024, dtype=torch.float16),
   tl.place(1024, 1024, dtype=torch.float16),
   block_M=64,
   block_N=64,
   block_K=32
 ),
 CallArgs(
   tl.place(1024, 1024, dtype=torch.float16),
   tl.place(1024, 1024, dtype=torch.float16),
   block_M=64,
   block_N=64,
   block_K=64
 )]

For your advanced tuning, you can write a generator to generate call args:
* Return a **dict** to call the function with `**kwargs`
* Return a **list/tuple** to call the function with `*args`

In [20]:
from itertools import product
def generate_tune_args(A, B):
    for n, m, k in product((64, 128), (32, 64), (32, 64, 128)):
        yield {
            'A': A,
            'B': B,
            'block_N': n,
            'block_M': m,
            'block_K': k
        }
gemm.tune_configs(generate_tune_args(A, B))

2025-10-13 15:42:52  [TileLang:tilelang.language.v2.jit:INFO]: Elaborate 12 configs


Elaborating:   0%|          | 0/12 [00:00<?, ?it/s]

2025-10-13 15:42:52  [TileLang:tilelang.jit.kernel:INFO]: TileLang begins to compile kernel `gemm` with `out_idx=[2]`


Parallel Compiling:   0%|          | 0/12 [00:00<?, ?it/s]

2025-10-13 15:43:09  [TileLang:tilelang.jit.kernel:INFO]: TileLang completes to compile kernel `gemm`


Benchmarking:   0%|          | 0/12 [00:00<?, ?it/s]

AutoTuneResult(
  name=gemm,
  num_errors=0,
  best_latency=0.009626281249999693,
  best_args=gemm(
    A=tl.place(1024, 1024, dtype=torch.float16),
    B=tl.place(1024, 1024, dtype=torch.float16),
    block_N=128,
    block_M=64,
    block_K=128
  ),
  best={'A.dtype': torch.float16, 'A__shape_0': 1024, 'A__shape_1': 1024, 'A__stride_0': 1024, 'A__stride_1': 1, 'B.dtype': torch.float16, 'B__shape_0': 1024, 'B__shape_1': 1024, 'B__stride_0': 1024, 'B__stride_1': 1, 'accum_dtype': torch.float32, 'block_N': 128, 'block_M': 64, 'block_K': 128, 'latency': 0.009626281249999693, '_status': 'Success', '_error': ''},
  records=<12 records>,
)

### Tune Pass Configs and Compilation Options

In [21]:
from tilelang import PassConfigKey
@tl.jit
def gemm_tune_advanced(
    A: tl.Tensor[int, int],
    B: tl.Tensor[int, int],
    accum_dtype: torch.dtype = torch.float32,
    block_N: int = 128,
    block_M: int = 128,
    block_K: int = 128,
    disable_tma: bool = False,
    use_prec_sqrt: bool = False,
):
    tl.set_pass_configs({
        PassConfigKey.TL_DISABLE_TMA_LOWER: disable_tma
    })
    if use_prec_sqrt:
        tl.add_compile_flags(['--prec-sqrt', 'true'])
    (M, K), (N, K2) = A.shape, B.shape
    assert K == K2, "Expect matrix A and B to have the same number of columns"
    C = tl.empty((M, N), dtype=accum_dtype)
    dims = [
        tl.ceildiv(M, block_M),
        tl.ceildiv(N, block_N),
    ]
    with tl.Kernel(*dims, threads=128) as (bx, by):
        A_shared = tl.alloc_shared((block_M, block_K), dtype=A.dtype)
        B_shared = tl.alloc_shared((block_K, block_N), dtype=B.dtype)
        C_local = tl.alloc_fragment((block_M, block_N), dtype=accum_dtype)
        tl.clear(C_local)
        for k in tl.Pipelined(tl.ceildiv(K, block_K), num_stages=3):
            tl.copy(A[by * block_M, k * block_K], A_shared)
            tl.copy(B[k * block_K, bx * block_N], B_shared)
            for i, j in tl.Parallel(block_M, block_K):
                A_shared[i, j] = tl.sqrt(A_shared[i, j])
            tl.gemm(A_shared, B_shared, C_local)
        tl.copy(C_local, C[by * block_M, bx * block_N])
    return C

In [22]:
A = torch.randn(1024, 1024, dtype=torch.float16).cuda()
B = torch.randn(1024, 1024, dtype=torch.float16).cuda()
result = gemm_tune_advanced.tune(
    A, B,
    disable_tma=tl.tune([True, False]),
    use_prec_sqrt=tl.tune([True, False])
)
result.to_pandas().pivot_table(index='disable_tma', columns='use_prec_sqrt', values='latency')

2025-10-13 15:43:50  [TileLang:tilelang.language.v2.jit:INFO]: Elaborate 4 configs


Elaborating:   0%|          | 0/4 [00:00<?, ?it/s]

2025-10-13 15:43:50  [TileLang:tilelang.jit.kernel:INFO]: TileLang begins to compile kernel `gemm_tune_advanced` with `out_idx=[2]`
2025-10-13 15:43:50  [TileLang:tilelang.jit.kernel:INFO]: TileLang begins to compile kernel `gemm_tune_advanced` with `out_idx=[2]`
2025-10-13 15:43:50  [TileLang:tilelang.jit.kernel:INFO]: TileLang begins to compile kernel `gemm_tune_advanced` with `out_idx=[2]`
2025-10-13 15:43:50  [TileLang:tilelang.jit.kernel:INFO]: TileLang begins to compile kernel `gemm_tune_advanced` with `out_idx=[2]`


Parallel Compiling:   0%|          | 0/4 [00:00<?, ?it/s]

2025-10-13 15:44:02  [TileLang:tilelang.jit.kernel:INFO]: TileLang completes to compile kernel `gemm_tune_advanced`
2025-10-13 15:44:02  [TileLang:tilelang.jit.kernel:INFO]: TileLang completes to compile kernel `gemm_tune_advanced`
2025-10-13 15:44:02  [TileLang:tilelang.jit.kernel:INFO]: TileLang completes to compile kernel `gemm_tune_advanced`
2025-10-13 15:44:02  [TileLang:tilelang.jit.kernel:INFO]: TileLang completes to compile kernel `gemm_tune_advanced`


Benchmarking:   0%|          | 0/4 [00:00<?, ?it/s]

use_prec_sqrt,False,True
disable_tma,Unnamed: 1_level_1,Unnamed: 2_level_1
False,0.150496,0.150485
True,0.150605,0.150277


## Other Features

### Parallel Compilation

Use ker.par_compile to compile many kernel args parallely

In [23]:
def generate_args():
    available_nmk = [
        (1024, 2048, 1024),
        (1024, 1024, 1024),
        (2048, 1024, 1024),
        (1024, 2048, 2048),
        (2048, 2048, 2048),
    ]
    for n, m, k in available_nmk:
        yield {
            'A': tl.place(n, k, dtype=torch.float16),
            'B': tl.place(m, k, dtype=torch.float16),
        }
_ = gemm.par_compile(generate_args())

2025-10-13 15:44:06  [TileLang:tilelang.language.v2.jit:INFO]: Elaborate 5 configs


Elaborating:   0%|          | 0/5 [00:00<?, ?it/s]

2025-10-13 15:44:06  [TileLang:tilelang.jit.kernel:INFO]: TileLang begins to compile kernel `gemm` with `out_idx=[2]`
2025-10-13 15:44:06  [TileLang:tilelang.jit.kernel:INFO]: TileLang begins to compile kernel `gemm` with `out_idx=[2]`
2025-10-13 15:44:06  [TileLang:tilelang.jit.kernel:INFO]: TileLang begins to compile kernel `gemm` with `out_idx=[2]`
2025-10-13 15:44:06  [TileLang:tilelang.jit.kernel:INFO]: TileLang begins to compile kernel `gemm` with `out_idx=[2]`


Parallel Compiling:   0%|          | 0/5 [00:00<?, ?it/s]

2025-10-13 15:44:17  [TileLang:tilelang.jit.kernel:INFO]: TileLang completes to compile kernel `gemm`
2025-10-13 15:44:17  [TileLang:tilelang.jit.kernel:INFO]: TileLang completes to compile kernel `gemm`
2025-10-13 15:44:17  [TileLang:tilelang.jit.kernel:INFO]: TileLang completes to compile kernel `gemm`
2025-10-13 15:44:17  [TileLang:tilelang.jit.kernel:INFO]: TileLang completes to compile kernel `gemm`


Use `tl.empty_data_ptr()` as a place holder for `tl.ptr`

In [24]:
def generate_args_ptr():
    available_nmk = [
        (1024, 2048, 1024),
        (1024, 1024, 1024),
        (2048, 1024, 1024),
        (1024, 2048, 2048),
        (2048, 2048, 2048),
    ]
    for n, m, k in available_nmk:
        yield {
            'A_ptr': tl.empty_data_ptr(),
            'B_ptr': tl.empty_data_ptr(),
            'M': n,
            'N': m,
            'K': k,
        }
_ = gemm_ptr.par_compile(generate_args_ptr())

2025-10-13 15:44:29  [TileLang:tilelang.language.v2.jit:INFO]: Elaborate 5 configs


Elaborating:   0%|          | 0/5 [00:00<?, ?it/s]

2025-10-13 15:44:29  [TileLang:tilelang.jit.kernel:INFO]: TileLang begins to compile kernel `gemm_ptr` with `out_idx=[0]`
2025-10-13 15:44:29  [TileLang:tilelang.jit.kernel:INFO]: TileLang begins to compile kernel `gemm_ptr` with `out_idx=[0]`
2025-10-13 15:44:29  [TileLang:tilelang.jit.kernel:INFO]: TileLang begins to compile kernel `gemm_ptr` with `out_idx=[0]`
2025-10-13 15:44:29  [TileLang:tilelang.jit.kernel:INFO]: TileLang begins to compile kernel `gemm_ptr` with `out_idx=[0]`


Parallel Compiling:   0%|          | 0/5 [00:00<?, ?it/s]

2025-10-13 15:44:40  [TileLang:tilelang.jit.kernel:INFO]: TileLang completes to compile kernel `gemm_ptr`
2025-10-13 15:44:40  [TileLang:tilelang.jit.kernel:INFO]: TileLang completes to compile kernel `gemm_ptr`
2025-10-13 15:44:40  [TileLang:tilelang.jit.kernel:INFO]: TileLang completes to compile kernel `gemm_ptr`
2025-10-13 15:44:40  [TileLang:tilelang.jit.kernel:INFO]: TileLang completes to compile kernel `gemm_ptr`


## JITv2 Internals

JIT Compilation Flow:
1. User write Tilelang Kernel
2. **Elaboration**: Tilelang Kernel is converted to JITFunc
    * JITFunc contains all required data for compilation
3. **Compilation**: JITFunc is compiled to JITKernel
    * JITKernel is callable

You can use `kernel.partial` to generate the JITFunc, and manually compile it
* You can give it `torch.Tensor` or just `tl.place` as placeholder

In [25]:
func = gemm.partial(
    A=tl.place(1024, 1024, dtype=torch.float16),
    B=tl.place(1024, 1024, dtype=torch.float16),
    accum_dtype=torch.float32,
)
func

JITFunc(
  target=cuda -keys=cuda,gpu -arch=sm_90 -max_num_threads=1024 -thread_warp_size=32,
  target_host=None,
  global_allocs=[BufferSchema(name='C', shape=[1024, 1024], stride=[1024, 1], dtype=dtype('float32'), arg_idx=2)],
  outs=[BufferSchema(name='C', shape=(1024, 1024), stride=(1024, 1), dtype=dtype('float32'), arg_idx=2)],
  pass_configs={},
  compile_flags=[],
  arg_parser=<function parse_args.<locals>.gemm at 0x7f671c471260>,
  const_args=(torch.float16, 1024, 1024, 1024, 1, torch.float16, 1024, 1024, 1024, 1, torch.float32, 128, 128, 128),
  prim_func=r'''# from tvm.script import tir as T

@T.prim_func
def gemm(A_handle: T.handle, B_handle: T.handle, C_handle: T.handle):
    A = T.match_buffer(A_handle, (1024, 1024), "float16", strides=(1024, 1))
    B = T.match_buffer(B_handle, (1024, 1024), "float16", strides=(1024, 1))
    C = T.match_buffer(C_handle, (1024, 1024), strides=(1024, 1))
    # with T.block("root"):
    bx = T.launch_thread("blockIdx.x", 8)
    by = T.launch

Please pay attention to the `func.const_args`, it is the cache key used for kernel memory caching

It contains
* tensor shapes, strides, dtypes
* non dynamic kernel parameters
* autotune parameters

In [26]:
print(func.const_args)

(torch.float16, 1024, 1024, 1024, 1, torch.float16, 1024, 1024, 1024, 1, torch.float32, 128, 128, 128)


use `tl.compile` to compile the function, it generate the compiled kernel

* kernel.source is the source code of the device code
* kernel.wrapped_source contains both device code and host code

In [27]:
kernel = tl.compile(func, verbose=True)
kernel



JITKernel(
  lib_path='/home/zhoukexing/.tilelang/cache/28451a0742f2da6d1a7aa58dbb6c9e462bfd75ad6798a7a2b9baedbea38796d5/kernel_lib.so',
  lib=<cffi.api._make_ffi_library.<locals>.FFILibrary object at 0x7f6291bffa40>,
  lib_call=<function __closure.<locals>.wrapper at 0x7f6292446980>,
  source='#include <tl_templates/cuda/gemm.h>\n#include <tl_templates/cuda/copy.h>\n#include <tl_templates/cuda/reduce.h>\n#include <tl_templates/cuda/ldsm.h>\n#include <tl_templates/cuda/threadblock_swizzle.h>\n#include <tl_templates/cuda/debug.h>\n#ifdef ENABLE_BF16\n#include <tl_templates/cuda/cuda_bf16_fallbacks.cuh>\n#endif\n\nextern "C" __global__ void gemm_kernel(__grid_constant__ const CUtensorMap A_desc, __grid_constant__ const CUtensorMap B_desc, float* __restrict__ C);\nextern "C" __global__ void __launch_bounds__(256, 1) gemm_kernel(__grid_constant__ const CUtensorMap A_desc, __grid_constant__ const CUtensorMap B_desc, float* __restrict__ C) {\n  extern __shared__ __align__(1024) uchar buf_dyn

You can use `kernel.compile` to run both **elaboration** and **compilation**

In [28]:
kernel = gemm.compile(
    A=tl.place(1024, 1024, dtype=torch.float16),
    B=tl.place(1024, 1024, dtype=torch.float16),
    accum_dtype=torch.float32,
)
kernel

JITKernel(
  lib_path='/home/zhoukexing/.tilelang/cache/28451a0742f2da6d1a7aa58dbb6c9e462bfd75ad6798a7a2b9baedbea38796d5/kernel_lib.so',
  lib=<cffi.api._make_ffi_library.<locals>.FFILibrary object at 0x7f65e0797aa0>,
  lib_call=<function __closure.<locals>.wrapper at 0x7f65e08a5800>,
  source='#include <tl_templates/cuda/gemm.h>\n#include <tl_templates/cuda/copy.h>\n#include <tl_templates/cuda/reduce.h>\n#include <tl_templates/cuda/ldsm.h>\n#include <tl_templates/cuda/threadblock_swizzle.h>\n#include <tl_templates/cuda/debug.h>\n#ifdef ENABLE_BF16\n#include <tl_templates/cuda/cuda_bf16_fallbacks.cuh>\n#endif\n\nextern "C" __global__ void gemm_kernel(__grid_constant__ const CUtensorMap A_desc, __grid_constant__ const CUtensorMap B_desc, float* __restrict__ C);\nextern "C" __global__ void __launch_bounds__(256, 1) gemm_kernel(__grid_constant__ const CUtensorMap A_desc, __grid_constant__ const CUtensorMap B_desc, float* __restrict__ C) {\n  extern __shared__ __align__(1024) uchar buf_dyn

Placeholder is not a valid tensor, if you call the kernel with placeholder, it raises error

In [29]:
A = tl.place(1024, 1024, dtype=torch.float16)
B = tl.place(1024, 1024, dtype=torch.float16)
try:
    C = gemm(A, B)
except TypeError as e:
    print(repr(e))

TypeError('an integer is required')


You can manually parse the arguments using kernel.parse_arg

In [30]:
const_args, dyn_args = gemm.arg_parser(
    tl.place(1024, 1024, dtype=torch.float32),
    tl.place(1024, 1024, dtype=torch.float32)
)
print('const_args: ', const_args)
print('dyn_args:   ', dyn_args)

const_args:  (torch.float32, 1024, 1024, 1024, 1, torch.float32, 1024, 1024, 1024, 1, torch.float32, 128, 128, 128)
dyn_args:    (empty_data_ptr(), empty_data_ptr(), None, 0)


## Migration from JITv1

JITv2 support almost all grammar in JITv1, the only difference is the signature of function

When migrating from JITv1 to JITv2
1. Rewrite the function signature
2. Add shape checking
3. Allocate global buffer
4. Copy all `tl.Kernel` inside

Here is a example to migrate `flashattn_fwd` to JITv2

In [31]:
import tilelang.language as T
import tilelang.language.v2 as tl
import tilelang
from tilelang import PassConfigKey
import torch

In [32]:
@tilelang.jit(
    out_idx=[3, 4], pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    })
def flashattn_fwd_jitv1(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, groups=1):
    scale = (1.0 / dim_qk)**0.5 * 1.44269504  # log2(e)
    head_kv = heads // groups
    q_shape = [batch, seq_len, heads, dim_qk]
    k_shape = [batch, seq_len, head_kv, dim_qk]
    v_shape = [batch, seq_len, head_kv, dim_v]
    dtype = "float16"
    accum_dtype = "float"

    @T.prim_func
    def flash_fwd(
            Q: T.Tensor(q_shape, dtype),  # type: ignore
            K: T.Tensor(k_shape, dtype),  # type: ignore
            V: T.Tensor(v_shape, dtype),  # type: ignore
            Output: T.Tensor([batch, seq_len, heads, dim_v], dtype),  # type: ignore
            lse: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
    ):
        with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=256) as (bx, by, bz):
            Q_shared = T.alloc_shared([block_M, dim_qk], dtype)
            K_shared = T.alloc_shared([block_N, dim_qk], dtype)
            V_shared = T.alloc_shared([block_N, dim_v], dtype)
            acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
            acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
            acc_o = T.alloc_fragment([block_M, dim_v], accum_dtype)
            scores_max = T.alloc_fragment([block_M], accum_dtype)
            scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
            scores_scale = T.alloc_fragment([block_M], accum_dtype)
            scores_sum = T.alloc_fragment([block_M], accum_dtype)
            logsum = T.alloc_fragment([block_M], accum_dtype)

            T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)})
            T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared)
            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))
            loop_range = (
                T.ceildiv(
                    (bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N))
            for k in T.Pipelined(loop_range, num_stages=1):
                T.copy(K[bz, k * block_N:(k + 1) * block_N, by // groups, :], K_shared)
                if is_causal:
                    for i, j in T.Parallel(block_M, block_N):
                        acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
                                                     -T.infinity(acc_s.dtype))
                else:
                    T.clear(acc_s)
                T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared)
                T.copy(scores_max, scores_max_prev)
                T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                for i in T.Parallel(block_M):
                    scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                for i, j in T.Parallel(block_M, dim_v):
                    acc_o[i, j] *= scores_scale[i]
                for i, j in T.Parallel(block_M, block_N):
                    acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
                T.copy(acc_s, acc_s_cast)
                T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
                T.reduce_sum(acc_s, scores_sum, dim=1)
                for i in T.Parallel(block_M):
                    logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
            for i, j in T.Parallel(block_M, dim_v):
                acc_o[i, j] /= logsum[i]
            T.copy(acc_o, Output[bz, bx * block_M:(bx + 1) * block_M, by, :])
            for i in T.Parallel(block_M):
                logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
            T.copy(logsum, lse[bz, by, bx * block_M:(bx + 1) * block_M])

    return flash_fwd

In [33]:
pass_configs = {
    PassConfigKey.TL_ENABLE_FAST_MATH: True
}
@tl.jit(pass_configs=pass_configs)
def flashattn_fwd_jitv2(
    Q: tl.Tensor[int, int, int, int],
    K: tl.Tensor[int, int, int, int],
    V: tl.Tensor[int, int, int, int],
    is_causal: bool,
    block_M: int=128,
    block_N: int=64,
    accum_dtype: torch.dtype = torch.float32,
    threads: int = 256,
    groups: int = 1
):

    # 1. extract shape and dtype
    batch, seq_len, heads, dim_qk, dtype = Q.params()
    batch, seq_len, heads_kv, dim_qk, dtype = K.params()
    batch, seq_len, heads_kv, dim_v, dtype = V.params()

    # 3. allocate output
    Output = tl.empty((batch, seq_len, heads, dim_v), dtype)
    lse = tl.empty((batch, heads, seq_len), accum_dtype)

    scale = (1.0 / dim_qk)**0.5 * 1.44269504  # log2(e)

    # 4. paste all tilelang v1 code here
    with tl.Kernel(tl.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
        Q_shared = T.alloc_shared([block_M, dim_qk], dtype)
        K_shared = T.alloc_shared([block_N, dim_qk], dtype)
        V_shared = T.alloc_shared([block_N, dim_v], dtype)
        acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
        acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
        acc_o = T.alloc_fragment([block_M, dim_v], accum_dtype)
        scores_max = T.alloc_fragment([block_M], accum_dtype)
        scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
        scores_scale = T.alloc_fragment([block_M], accum_dtype)
        scores_sum = T.alloc_fragment([block_M], accum_dtype)
        logsum = T.alloc_fragment([block_M], accum_dtype)

        T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)})
        T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared)
        T.fill(acc_o, 0)
        T.fill(logsum, 0)
        T.fill(scores_max, -T.infinity(accum_dtype))
        loop_range = (
            T.ceildiv(
                (bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N))
        for k in T.Pipelined(loop_range, num_stages=1):
            T.copy(K[bz, k * block_N:(k + 1) * block_N, by // groups, :], K_shared)
            if is_causal:
                for i, j in T.Parallel(block_M, block_N):
                    acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
                                                    -T.infinity(acc_s.dtype))
            else:
                T.clear(acc_s)
            T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
            T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared)
            T.copy(scores_max, scores_max_prev)
            T.reduce_max(acc_s, scores_max, dim=1, clear=False)
            for i in T.Parallel(block_M):
                scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
            for i, j in T.Parallel(block_M, dim_v):
                acc_o[i, j] *= scores_scale[i]
            for i, j in T.Parallel(block_M, block_N):
                acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
            T.copy(acc_s, acc_s_cast)
            T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
            T.reduce_sum(acc_s, scores_sum, dim=1)
            for i in T.Parallel(block_M):
                logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
        for i, j in T.Parallel(block_M, dim_v):
            acc_o[i, j] /= logsum[i]
        T.copy(acc_o, Output[bz, bx * block_M:(bx + 1) * block_M, by, :])
        for i in T.Parallel(block_M):
            logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
        T.copy(logsum, lse[bz, by, bx * block_M:(bx + 1) * block_M])

    return Output, lse

In [34]:
BATCH: int = 1
H: int = 32
N_CTX: int = 256
D_HEAD_QK: int = 192
D_HEAD_V: int = 128
groups: int = 16
causal: bool = False
use_atomic: bool = True
head_kv = H // groups
Q = torch.randn(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half, device="cuda")
K = torch.randn(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half, device="cuda")
V = torch.randn(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half, device="cuda")
batch, seq_len, heads, dim_qk = Q.shape
batch, seq_len, heads_kv, dim_qk = K.shape
batch, seq_len, heads_kv, dim_v = V.shape

In [35]:
import time
jit_v1_kernel = flashattn_fwd_jitv1(batch, heads, seq_len, dim_qk, dim_v, causal, block_M=128, block_N=128, groups=groups)
flashattn_fwd_jitv2(Q, K, V, causal, block_M=128, block_N=128, groups=groups)

time_beg = time.perf_counter()
for _ in range(10000):
    res_v1 = jit_v1_kernel(Q, K, V)
time_end = time.perf_counter()
print('JITv1: ', (time_end - time_beg) / 10000 * 1e6, 'us')

time_beg = time.perf_counter()
for _ in range(10000):
    res_v2 = flashattn_fwd_jitv2(Q, K, V, causal, block_M=128, block_N=128, groups=groups)
time_end = time.perf_counter()
print('JITv2: ', (time_end - time_beg) / 10000 * 1e6, 'us')
torch.testing.assert_close(res_v1, res_v2, atol=1e-2, rtol=1e-2)

2025-10-13 15:44:58  [TileLang:tilelang.jit.kernel:INFO]: TileLang begins to compile kernel `flash_fwd` with `out_idx=[3, 4]`
2025-10-13 15:45:09  [TileLang:tilelang.jit.kernel:INFO]: TileLang completes to compile kernel `flash_fwd`
2025-10-13 15:45:10  [TileLang:tilelang.jit.kernel:INFO]: TileLang begins to compile kernel `flashattn_fwd_jitv2` with `out_idx=[3, 4]`
2025-10-13 15:45:21  [TileLang:tilelang.jit.kernel:INFO]: TileLang completes to compile kernel `flashattn_fwd_jitv2`
JITv1:  33.56990269385278 us
JITv2:  16.764699900522828 us


In [36]:
flashattn_fwd_jitv2.bench(Q, K, V, causal, block_M=128, block_N=128, groups=groups)

0.013378110394842798