In [1]:
import torch
import triton
import triton.language as tl
import time

In [2]:
def is_cuda():
    return triton.runtime.driver.active.get_current_target().backend == "cuda"

is_cuda()

True

In [3]:
def is_hip_mi200():
    target = triton.runtime.driver.active.get_current_target()
    return target.backend == 'hip' and target.arch == 'gfx90a'

is_hip_mi200()

False

In [4]:
"""
PA2 Part 2: MatMul+Relu+Add Fused Optimization.
The kernel uses several optimization techniques:

  1. Shared memory tiling.
  2. Register tiling.
  3. Cooperative fetching.
  4. Operator Fusion
  5. Write cache / epilogue fusion.

Fill in the missing parts (marked with TODO).
"""

# -----------------------------------------------------------------------------
# Tiling parameters - You will need to change these to achieve better results.
# -----------------------------------------------------------------------------
BLOCK_M = 128  # Tile size in the M dimension.
BLOCK_N = 128 # Tile size in the N dimension.
BLOCK_K = 32 # Tile size in the K dimension.


# -----------------------------------------------------------------------------
# Triton Kernel: Matrix Multiplication + ReLU + Add
#
# The kernel uses:
#   Step 1: Tile assignment (each kernel computes a tile of C)
#   Step 2: Shared memory tiling + Cooperative Fetching: Load tiles of A and B.
#   Step 3: Register tiling: Use a register accumulator.
#   Step 4: Add and ReLU fusion
#   Step 5: Write cache/Epilogue: Write the final tile back to global memory.
# -----------------------------------------------------------------------------
@triton.jit
def matmul_add_relu_kernel_fp16(
    a_ptr, b_ptr, c_ptr, d_ptr,
    M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
    stride_am: tl.constexpr, stride_ak: tl.constexpr,
    stride_bk: tl.constexpr, stride_bn: tl.constexpr,
    stride_cm: tl.constexpr, stride_cn: tl.constexpr,
    stride_dm: tl.constexpr, stride_dn: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    # -------------------------------------------------------------------------
    # Step 1: Tile: Assignment
    #
    # Each kernel instance is mapped to a tile in the output matrix C.
    # Compute the starting indices (m_start, n_start) for this tile.
    # -------------------------------------------------------------------------
    # Compute the tile indices using program_id(0) for M and program_id(1) for N.
    grid_m = tl.program_id(0)
    grid_n = tl.program_id(1)
    m_start = grid_m * BLOCK_M
    n_start = grid_n * BLOCK_N
    off_m = m_start + tl.arange(0, BLOCK_M)
    off_n = n_start + tl.arange(0, BLOCK_N)
    mask_m = off_m < M
    mask_n = off_n < N

    # -------------------------------------------------------------------------
    # Step 2: Register Tiling
    # -------------------------------------------------------------------------
    # Initialize the accumulator "acc" with zeros (dtype: float16).
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float16)

    # -------------------------------------------------------------------------
    # Step 3: Shared Memory Tiling & Cooperative Fetching.
    # Compute pointers to the sub-tiles of A and B that are needed to compute
    # the current C tile. The offsets here serve to load BLOCK_SIZE_M x BLOCK_SIZE_K
    # and BLOCK_SIZE_K x BLOCK_SIZE_N blocks from A and B respectively.
    # -------------------------------------------------------------------------
    k_iterations = tl.cdiv(K, BLOCK_K)
    for k in range(0, k_iterations):
        k_start = k * BLOCK_K
        off_k = k_start + tl.arange(0, BLOCK_K)
        mask_k = off_k < K

        a_ptrs = a_ptr + off_m[:, None] * stride_am + off_k[None, :] * stride_ak
        b_ptrs = b_ptr + off_k[:, None] * stride_bk + off_n[None, :] * stride_bn
        a_mask = mask_m[:, None] & mask_k[None, :]
        b_mask = mask_k[:, None] & mask_n[None, :]
        a = tl.load(a_ptrs, mask=a_mask, other=0)
        b = tl.load(b_ptrs, mask=b_mask, other=0)

        acc += tl.dot(a, b, out_dtype=tl.float16)

    # -------------------------------------------------------------------------
    # Step 4: Apply ReLU and Add C to the accumulator
    # -------------------------------------------------------------------------
    c_ptrs = c_ptr + off_m[:, None] * stride_cm + off_n[None, :] * stride_cn
    c_mask = mask_m[:, None] & mask_n[None, :]  
    c = tl.load(c_ptrs, mask=c_mask)
    acc = acc + c
    acc = tl.maximum(acc, 0)

    # -------------------------------------------------------------------------
    # Step 5: Write Cache / Epilogue Fusion: Write the computed tile to D.
    # -------------------------------------------------------------------------
    d_ptrs = d_ptr + off_m[:, None] * stride_dm + off_n[None, :] * stride_dn
    d_mask = mask_m[:, None] & mask_n[None, :]
    tl.store(d_ptrs, acc, mask=d_mask)


In [5]:
def matmul_add_relu_fp16(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor,
                         block_m: int = BLOCK_M, block_n: int = BLOCK_N, block_k: int = BLOCK_K) -> torch.Tensor:
    """
    Computes Output = ReLU(A @ B + C) using fp16 precision for maximum throughput.
    """
    M, K = a.shape
    K2, N = b.shape
    assert K == K2, "Incompatible dimensions"

    d = torch.empty((M, N), device=a.device, dtype=torch.float16)
    # Create launch grid
    grid = (triton.cdiv(M, block_m), triton.cdiv(N, block_n))

    matmul_add_relu_kernel_fp16[grid](
        a, b, c, d,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
        d.stride(0), d.stride(1),
        BLOCK_M=block_m, BLOCK_N=block_n, BLOCK_K=block_k
    )
    return d

In [6]:
# Reference implementation using PyTorch
def reference_matmul_add_relu(A, B, C):
    result = torch.matmul(A, B).add(C).relu_()
    return result

In [7]:
# -----------------------------------------------------------------------------
# Accuracy Tests
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    a = torch.randn((512, 512), device=torch.device("cuda"), dtype=torch.float16)
    b = torch.randn((512, 512), device=torch.device("cuda"), dtype=torch.float16)
    c = torch.randn((512, 512), device=torch.device("cuda"), dtype=torch.float16)
    triton_output = matmul_add_relu_fp16(a, b, c)
    torch_output = reference_matmul_add_relu(a, b, c)
    print(f"triton_output_with_fp16_inputs={triton_output}")
    print(f"torch_output_with_fp16_inputs={torch_output}")
    rtol = 1e-2 if is_hip_mi200() else 0.032
    if torch.allclose(triton_output, torch_output, atol=0.15, rtol=rtol):
        print("✅ Triton and Torch match")
    else:
        diff = triton_output - torch_output
        abs_diff = torch.abs(diff)
        max_abs_diff = torch.max(abs_diff)
        print(f"❌ Triton and Torch differ: {max_abs_diff=}")

triton_output_with_fp16_inputs=tensor([[ 3.4297,  0.0000, 12.4453,  ...,  0.0000,  0.0000,  0.0000],
        [23.2656,  0.0000,  0.0000,  ...,  0.0000,  0.0000, 27.6250],
        [ 0.0000,  0.9302,  0.0000,  ..., 10.3906,  0.0000, 14.1016],
        ...,
        [14.2578,  0.0000, 10.1953,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000, 27.2812,  0.0000,  ...,  4.6367,  0.0000,  0.0000],
        [ 0.0000, 20.9375,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0', dtype=torch.float16)
torch_output_with_fp16_inputs=tensor([[ 3.4336,  0.0000, 12.4453,  ...,  0.0000,  0.0000,  0.0000],
        [23.3281,  0.0000,  0.0000,  ...,  0.0000,  0.0000, 27.5938],
        [ 0.0000,  0.9146,  0.0000,  ..., 10.3750,  0.0000, 14.0938],
        ...,
        [14.2578,  0.0000, 10.2031,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000, 27.2344,  0.0000,  ...,  4.6172,  0.0000,  0.0000],
        [ 0.0000, 20.9531,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0', dt

In [8]:
# -----------------------------------------------------------------------------
# Performance Benchmark 
# IMPORTANT: DO NOT CHANGE THIS CODE. 
# THIS IS THE EXACT CODE THAT WILL BE USED TO GRADE YOUR IMPLEMENTATION.
# ANY CHANGES TO THIS CODE (INCLUDING DIMENSIONS, REPEATS, etc.)
# WILL CAUSE YOU TO HAVE DIFFERENT SPEEDUP RESULTS.
# -----------------------------------------------------------------------------
M = 2048
K = 2048
N = 2048

# KEEP THESE MATRICES IN FP16. FP32 WILL NOT PROVIDE ACCURATE RESULTS
A = torch.randn((M, K), device="cuda", dtype=torch.float16)
B = torch.randn((K, N), device="cuda", dtype=torch.float16)
C = torch.randn((M, N), device="cuda", dtype=torch.float16)

# warmup
_ = matmul_add_relu_fp16(A, B, C)
_ = reference_matmul_add_relu(A, B, C)

REPEATS = 5000

# time your implementation
print("Triton implementation")
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(REPEATS):
    _ = matmul_add_relu_fp16(A, B, C)
torch.cuda.synchronize()
triton_time = (time.perf_counter() - start) / REPEATS

# time pytorch
print("PyTorch implementation")
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(REPEATS):
    _ = reference_matmul_add_relu(A, B, C)
torch.cuda.synchronize()
torch_time = (time.perf_counter() - start) / REPEATS

print(f"Performance comparison for matrix multiplication ({M}x{K} @ {K}x{N}):")
print(f"Triton implementation: {triton_time*1000:.2f} ms")
print(f"PyTorch implementation: {torch_time*1000:.2f} ms")

print(f"\nSpeedup of Triton vs PyTorch: {torch_time/triton_time:.2f}x")

Triton implementation
PyTorch implementation
Performance comparison for matrix multiplication (2048x2048 @ 2048x2048):
Triton implementation: 0.24 ms
PyTorch implementation: 0.45 ms

Speedup of Triton vs PyTorch: 1.85x


In [9]:
# Write your grid search here.

def perf_triton_matmul_add_relu(M, K, N, block_m, block_n, block_k):
    A = torch.randn((M, K), device="cuda", dtype=torch.float16)
    B = torch.randn((K, N), device="cuda", dtype=torch.float16)
    C = torch.randn((M, N), device="cuda", dtype=torch.float16)
    # warmup
    _ = matmul_add_relu_fp16(A, B, C, block_m, block_n, block_k)

    REPEATS = 5000

    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(REPEATS):
        _ = matmul_add_relu_fp16(A, B, C, block_m, block_n, block_k)
    torch.cuda.synchronize()
    triton_time = (time.perf_counter() - start) / REPEATS
    return triton_time

best_time = float('inf')
best_block_mnk = 0, 0, 0
for block_m in [64, 128, 256]:
    for block_n in [64, 128, 256]:
        for block_k in [16, 32, 64]:
            triton_time = perf_triton_matmul_add_relu(M, K, N, block_m, block_n, block_k)
            print(f"Block size: {block_m}x{block_n}x{block_k}, Triton time: {triton_time*1000:.2f} ms")
            if triton_time < best_time:
                best_time = triton_time
                best_block_mnk = block_m, block_n, block_k

print(f"Best block size: {best_block_mnk}")



Block size: 64x64x16, Triton time: 0.36 ms
Block size: 64x64x32, Triton time: 0.36 ms
Block size: 64x64x64, Triton time: 0.36 ms
Block size: 64x128x16, Triton time: 0.29 ms
Block size: 64x128x32, Triton time: 0.28 ms
Block size: 64x128x64, Triton time: 0.30 ms
Block size: 64x256x16, Triton time: 0.28 ms
Block size: 64x256x32, Triton time: 0.26 ms
Block size: 64x256x64, Triton time: 0.29 ms
Block size: 128x64x16, Triton time: 0.30 ms
Block size: 128x64x32, Triton time: 0.28 ms
Block size: 128x64x64, Triton time: 0.30 ms
Block size: 128x128x16, Triton time: 0.28 ms
Block size: 128x128x32, Triton time: 0.24 ms
Block size: 128x128x64, Triton time: 0.26 ms
Block size: 128x256x16, Triton time: 0.25 ms
Block size: 128x256x32, Triton time: 0.25 ms
Block size: 128x256x64, Triton time: 0.28 ms
Block size: 256x64x16, Triton time: 0.34 ms
Block size: 256x64x32, Triton time: 0.26 ms
Block size: 256x64x64, Triton time: 0.28 ms
Block size: 256x128x16, Triton time: 0.26 ms
Block size: 256x128x32, Trit

OutOfResources: out of resource: shared memory, Required: 131072, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.