In [None]:
import torch
import triton
import triton.language as tl

# Host-side TMA tensor descriptor API
from triton.tools.tensor_descriptor import TensorDescriptor


@triton.jit
def _schedule_pid_mn(pid, group_size, m_group_size, m):
    """Map a linear tile id to grouped ``(pid_m, pid_n)`` tile coordinates.

    Tile ids are handed out in bands of ``m_group_size`` tile-rows. Inside a
    band, consecutive ids walk down the rows of one column before moving on
    to the next column.

    Args:
        pid: Linear tile id.
        group_size: Number of tiles in one full band (the callers in this
            file pass ``m_group_size * n_tiles``).
        m_group_size: Number of tile-rows in a full band.
        m: Total number of tile-rows.

    Returns:
        ``(pid_m, pid_n)``: the tile's row index and column index.
    """
    band = pid // group_size
    rank_in_band = pid % group_size
    first_row = band * m_group_size
    # The last band is shorter when m is not a multiple of m_group_size.
    rows_in_band = min(m_group_size, m - first_row)
    pid_m = first_row + (rank_in_band % rows_in_band)
    pid_n = rank_in_band // rows_in_band
    return pid_m, pid_n


def _get_mm_persistent_tma_config(pre_hook=None):
    return [
        triton.Config(
            {"BM": BM, "BN": BN, "BK": BK, "GROUP_SIZE_M": 8, "EPILOGUE_SUBTILE": epilogue_subtile},
            num_stages=stage,
            num_warps=num_warps,
            pre_hook=pre_hook,
        )
        for BM in [128]
        for BN in [128, 256]
        for BK in [64, 128]
        for stage in [2, 3, 4]
        for num_warps in [4, 8]
        for epilogue_subtile in {True, False}
    ]


def _set_mma_args_hook(nargs):
    BM, BK, BN = nargs["BM"], nargs["BK"], nargs["BN"]
    nargs["a_desc"].block_shape = [BM, BK]
    nargs["b_desc"].block_shape = [BN, BK]
    if nargs["EPILOGUE_SUBTILE"]:
        nargs["c_desc"].block_shape = [BM, BN // 2]
    else:
        nargs["c_desc"].block_shape = [BM, BN]


@triton.autotune(
    configs=_get_mm_persistent_tma_config(pre_hook=_set_mma_args_hook),
    key=["M", "N", "K"],
)
@triton.jit
def _matmul_tma_persistent_kernel(
    a_desc,
    b_desc,
    c_desc,
    M,
    N,
    K,
    BM: tl.constexpr,
    BN: tl.constexpr,
    BK: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    EPILOGUE_SUBTILE: tl.constexpr,
    FP8_OUTPUT: tl.constexpr,
    NUM_SMs: tl.constexpr,
    WARP_SPECIALIZE: tl.constexpr,
):
    """Persistent matmul kernel computing ``C = A @ B^T`` via tensor descriptors.

    Each program starts at output tile ``pid`` and advances by ``NUM_SMs``
    until all ``cdiv(M, BM) * cdiv(N, BN)`` tiles are processed. For every
    tile it accumulates over ``K`` in float32 and stores the result through
    ``c_desc``.

    Args:
        a_desc: Descriptor for A; loaded at ``[row, k]`` offsets. The autotune
            pre-hook ``_set_mma_args_hook`` sets its block shape to ``[BM, BK]``.
        b_desc: Descriptor for B; loaded at ``[col, k]`` offsets (block shape
            ``[BN, BK]``) and transposed before the dot product.
        c_desc: Descriptor for the output; block shape ``[BM, BN]``, or
            ``[BM, BN // 2]`` when ``EPILOGUE_SUBTILE`` is set.
        M, N, K: Problem sizes; also the autotune key.
        BM, BN, BK: Tile sizes.
        GROUP_SIZE_M: Number of tile-rows per scheduling band passed to
            ``_schedule_pid_mn``.
        EPILOGUE_SUBTILE: Store the output tile as two half-width blocks.
        FP8_OUTPUT: Cast the result to ``tl.float8e4nv`` when true, otherwise
            to ``tl.float16``.
        NUM_SMs: Stride of the tile loop.
        WARP_SPECIALIZE: Forwarded to ``tl.range(..., warp_specialize=...)``.
    """
    # NOTE(review): the output is always fp16 unless FP8_OUTPUT is set, while the
    # host wrapper in this file allocates `c` with `a.dtype` -- other input
    # dtypes look unsupported; confirm.
    dtype = tl.float8e4nv if FP8_OUTPUT else tl.float16

    pid = tl.program_id(axis=0)
    m_tiles = tl.cdiv(M, BM)
    n_tiles = tl.cdiv(N, BN)
    k_tiles = tl.cdiv(K, BK)
    total_tiles = m_tiles * n_tiles
    # Number of tiles in one full band of GROUP_SIZE_M tile-rows.
    group_tiles = GROUP_SIZE_M * n_tiles

    # Start one stride behind pid so the `+= NUM_SMs` inside the loop makes
    # tile_c equal to tile_id on every iteration.
    tile_c = pid - NUM_SMs
    for tile_id in tl.range(pid, total_tiles, NUM_SMs, flatten=True, warp_specialize=WARP_SPECIALIZE):
        # Map the linear tile_id to its grouped (pid_m, pid_n) tile coordinates.
        pid_m, pid_n = _schedule_pid_mn(tile_id, group_tiles, GROUP_SIZE_M, m_tiles)

        off_am = pid_m * BM
        off_bn = pid_n * BN
        # Accumulate in float32 regardless of the output dtype.
        accu = tl.zeros((BM, BN), dtype=tl.float32)
        for ki in tl.range(k_tiles):
            a = a_desc.load([off_am, ki * BK])
            b = b_desc.load([off_bn, ki * BK])
            # b is loaded with its N index first, so transpose it for the dot.
            accu = tl.dot(a, b.T, accu)

        # tile_c always equals tile_id here; a separate variable is used
        # deliberately to promote pipelining.
        tile_c += NUM_SMs
        pid_m, pid_n = _schedule_pid_mn(tile_c, group_tiles, GROUP_SIZE_M, m_tiles)
        off_am_c = pid_m * BM
        off_bn_c = pid_n * BN

        # Splitting the epilogue into two half-width stores may reduce shared-memory
        # consumption and thus allow more pipeline stages.
        if EPILOGUE_SUBTILE:
            # (BM, BN) -> (BM, 2, BN // 2) -> (BM, BN // 2, 2): the size-2 axis
            # ends up last so tl.split yields the left and right column halves.
            accu = tl.reshape(accu, (BM, 2, BN // 2))
            accu = tl.permute(accu, (0, 2, 1))
            acc0, acc1 = tl.split(accu)
            c0 = acc0.to(dtype)
            c_desc.store([off_am_c, off_bn_c], c0)
            c1 = acc1.to(dtype)
            c_desc.store([off_am_c, off_bn_c + BN // 2], c1)
        else:
            accu = accu.to(dtype)
            c_desc.store([off_am_c, off_bn_c], accu)


def my_matmul_tma_persistent(a: torch.Tensor, b: torch.Tensor, warp_specialize=False) -> torch.Tensor:
    """Compute ``a @ b.T`` with the persistent TMA matmul kernel.

    Args:
        a: Left operand of shape ``(M, K)``.
        b: Right operand of shape ``(N, K)``; it shares its last dimension
            with ``a`` and is transposed inside the kernel.
        warp_specialize: Forwarded to the kernel as ``WARP_SPECIALIZE``.

    Returns:
        A new ``(M, N)`` tensor on ``a.device`` with dtype ``a.dtype``.

    Raises:
        AssertionError: If ``a`` and ``b`` disagree on the ``K`` dimension.
    """
    assert a.shape[1] == b.shape[1], "a and b must share the reduction dimension K"
    M, K = a.shape
    N, K = b.shape

    # NOTE(review): the kernel emits fp16 unless a.dtype is float8_e4m3fn, so other
    # input dtypes are presumably unsupported -- confirm before relying on them.
    c = torch.empty((M, N), device=a.device, dtype=a.dtype)

    # Placeholder block shape; the autotune pre-hook overwrites block_shape on
    # each descriptor with the tile sizes of the config being launched.
    dummy_block = [1, 1]
    a_desc = TensorDescriptor(a, a.shape, a.stride(), dummy_block)
    b_desc = TensorDescriptor(b, b.shape, b.stride(), dummy_block)
    c_desc = TensorDescriptor(c, c.shape, c.stride(), dummy_block)

    # Query the device that actually holds the inputs rather than whichever
    # CUDA device happens to be the default one.
    NUM_SMs = torch.cuda.get_device_properties(a.device).multi_processor_count

    def grid(META):
        # One program per SM, but never more programs than output tiles.
        BM, BN = META["BM"], META["BN"]
        return (min(NUM_SMs, triton.cdiv(M, BM) * triton.cdiv(N, BN)),)

    _matmul_tma_persistent_kernel[grid](
        a_desc, b_desc, c_desc, M, N, K, NUM_SMs=NUM_SMs, FP8_OUTPUT=a.dtype == torch.float8_e4m3fn, WARP_SPECIALIZE=warp_specialize
    )
    return c

# Test

In [None]:
# Toy problem: 13 tile-rows by 7 tile-columns, scheduled in bands of 4 tile-rows.
m_group_size = 4
n = 7
m = 13
group_tiles = m_group_size * n
# Print the tile ids in plain row-major order as a reference layout.
for row in range(m):
    print(" ".join(f"{row * n + col:2d}" for col in range(n)), end=" \n")

 0  1  2  3  4  5  6 
 7  8  9 10 11 12 13 
14 15 16 17 18 19 20 
21 22 23 24 25 26 27 
28 29 30 31 32 33 34 
35 36 37 38 39 40 41 
42 43 44 45 46 47 48 
49 50 51 52 53 54 55 
56 57 58 59 60 61 62 
63 64 65 66 67 68 69 
70 71 72 73 74 75 76 
77 78 79 80 81 82 83 
84 85 86 87 88 89 90 


In [None]:
# Fill an m x n grid so that points[pid_m][pid_n] holds the linear tile id the
# scheduler assigns to that tile position, then print the grid.
points = [[0] * n for _ in range(m)]
for tile_id in range(m * n):
    # Call the underlying Python function directly with plain integers.
    pid_m, pid_n = _schedule_pid_mn.fn(tile_id, group_tiles, m_group_size, m)
    points[pid_m][pid_n] = tile_id
for grid_row in points:
    print(" ".join(f"{tile_id:2d}" for tile_id in grid_row), end=" \n")

 0  4  8 12 16 20 24 
 1  5  9 13 17 21 25 
 2  6 10 14 18 22 26 
 3  7 11 15 19 23 27 
28 32 36 40 44 48 52 
29 33 37 41 45 49 53 
30 34 38 42 46 50 54 
31 35 39 43 47 51 55 
56 60 64 68 72 76 80 
57 61 65 69 73 77 81 
58 62 66 70 74 78 82 
59 63 67 71 75 79 83 
84 85 86 87 88 89 90 
