In [2]:
import torch

import triton
import triton.language as tl
from triton.runtime import driver

In [4]:
DEVICE = torch.device(f'cuda:{torch.cuda.current_device()}')

In [5]:
DEVICE

device(type='cuda', index=0)

### Autotune Configs
We sample different values of `BLOCK_SIZE_M`, `BLOCK_SIZE_N`, `BLOCK_SIZE_K`, `GROUP_SIZE_M` and `num_stages` and figure out the best combination that gives the best performance for the hardware

In [6]:
def get_cuda_autotune_config():
    return [
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3,
                      num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5,
                      num_warps=2),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5,
                      num_warps=2),
        # Good config for fp8 inputs.
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3,
                      num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3,
                      num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4)
    ]

In [31]:
# `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:
#   - A list of `triton.Config` objects that define different configurations of
#       meta-parameters (e.g., `BLOCK_SIZE_M`) and compilation options (e.g., `num_warps`) to try
#   - An auto-tuning *key* whose change in values will trigger evaluation of all the
#       provided configs
@triton.autotune(
    configs=get_cuda_autotune_config(),
    key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel(a_ptr: torch.Tensor, b_ptr: torch.Tensor, c_ptr: torch.Tensor,
                  M: int, N: int, K: int, stride_am: int, 
                  stride_ak: int, stride_bk: int, stride_bn: int,
                  stride_cm: int, stride_cn: int,
                  BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
                  BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr,
                  ACTIVATION: tl.constexpr):
    """Compute C = A x B
    Args:
        a_ptr: A torch Tensor of shape: M x K
        b_ptr: A torch Tensor of shape: K x N
        c_ptr: A torch Tensor of shape: M x N 
            This is where the output will be stored.
        M: Number of rows in A.
        K: Number of columns in A and rows in B.
        N: Number of columns in B.
        stride_am: int. How much to increase a_ptr by to reach the next row.
        atride_ak: int. How much to increase a_ptr by to reach next element.
        BLOCK_SIZE_M: Number of rows of A in a block.
        BLOCK_SIZE_K: Number of cols of A or rows of B in a block.
        BLOCK_SIZE_N: Number of cols of B in a block.
        GROUP_SIZE_M: Number of rows to process before moving on to the next column. 
            This promoted data-reuse. 
            GROUP_SIZE_M = 1 is row-major, maximum re-use of A but no re-use of B.
            GRPUP_SIZE_M = M is col-major, maximum reuse of Bbut no re-use of A.
            Somewhere in the middle we get optimal value for cache utilization.
            We'll use auto-tune config for this.
        ACTIVATION: activation to apply after matmul, fused operation.
    """
    
    pid = tl.program_id(0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) # Number of programs to process blocks of size M
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) # Number of programs to process blocks of size N
    # Number of programs/threads to process a group. Each group processes a tile of size: 
    # GROUP_SIZE_M (rows) × BLOCK_SIZE_N (columns)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    # pid is the program id of the current process and 
    # each group has num_pid_in_group programs
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    # If BLOCK_SIZE_M is not divisible GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn =  (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    a_ptrs = a_ptr + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_ak + offs_bn[None, :] * stride_bn
    
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        accumulator = tl.dot(a, b, accumulator)
    
    if ACTIVATION == "leaky_relu":
        accumulator = leaky_relu(accumulator)
    c = accumulator.to(tl.float16)
    
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) % M
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) % N
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)
    
@triton.jit
def leaky_relu(x: torch.Tensor) -> torch.Tensor:
    return tl.where(x >= 0, x, 0.01 * x)
                  

In [32]:
def matmul(a, b, activation=""):
    # Check constraints.
    assert a.shape[1] == b.shape[0], "Incompatible dimensions"
    assert a.is_contiguous(), "Matrix A must be contiguous"
    M, K = a.shape
    K, N = b.shape
    # Allocates output.
    c = torch.empty((M, N), device=a.device, dtype=torch.float16)
    # 1D launch kernel where each block gets its own program.
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
    matmul_kernel[grid](
        a, b, c,  #
        M, N, K, #
        a.stride(0), a.stride(1),  #
        b.stride(0), b.stride(1),  #
        c.stride(0), c.stride(1),  #
        ACTIVATION=activation  #
    )
    return c

In [34]:
torch.manual_seed(0)
a = torch.randn((512, 512), device=DEVICE, dtype=torch.float16)
b = torch.randn((512, 512), device=DEVICE, dtype=torch.float16)
triton_output = matmul(a, b)
torch_output = torch.matmul(a, b)
print(f"triton_output_with_fp16_inputs={triton_output}")
print(f"torch_output_with_fp16_inputs={torch_output}")
# Bigger tolerance for AMD MI200 devices.
# MI200 devices use reduced precision fp16 and bf16 and flush input and
# output denormal values to zero. Detailed info is at: https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices
rtol = 1e-2
if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=rtol):
    print("✅ Triton and Torch match")
else:
    print("❌ Triton and Torch differ")

TORCH_HAS_FP8 = hasattr(torch, "float8_e5m2")

triton_output_with_fp16_inputs=tensor([[  41.1250,   56.9062,   50.7500,  ...,  -93.9375,  -95.2500,
          153.2500],
        [ 126.1250,  -41.1875,  151.0000,  ...,   14.8984,  -10.0625,
         -103.5000],
        [  22.1719,   -8.8516,  -76.1875,  ...,   81.5625, -130.1250,
           39.4062],
        ...,
        [  32.0938,   34.1250,   -5.8906,  ...,   33.2188,   51.9062,
          135.1250],
        [ -26.1875,  -71.1875,   -8.1328,  ...,   16.0312,  -21.0156,
           11.0547],
        [  95.8750,   -5.9609,   77.6875,  ...,  104.1250,  109.0625,
           78.6250]], device='cuda:0', dtype=torch.float16)
torch_output_with_fp16_inputs=tensor([[ -4.4844, -18.9844,   8.2500,  ..., -16.2344,  36.6562, -14.1406],
        [ 27.0781, -38.7188, -24.4531,  ..., -21.7031, -26.5938,  17.4688],
        [-13.3438,  14.1719,   7.6016,  ...,  -9.1172, -43.2500,   9.6406],
        ...,
        [  9.6406,   1.0146,  -9.3047,  ...,  -7.7852,  39.8438,  13.6172],
        [-28.8594,   8.2