In [None]:
!git clone https://github.com/hao-ai-lab/cse234-w25-PA.git

fatal: destination path 'cse234-w25-PA' already exists and is not an empty directory.


In [None]:
# Switch the working directory to the assignment's pa2 folder
# (assumes the repository was cloned under /content).
%cd /content/cse234-w25-PA/pa2

/content/cse234-w25-PA/pa2


In [None]:
import torch
import triton
import triton.language as tl
import time

In [None]:
def is_cuda():
    """Return True when Triton's active driver reports the "cuda" backend."""
    current_target = triton.runtime.driver.active.get_current_target()
    backend_name = current_target.backend
    return backend_name == "cuda"

In [None]:
def is_hip_mi200():
    """Return True when the active Triton target is the 'hip' backend with arch 'gfx90a'."""
    current = triton.runtime.driver.active.get_current_target()
    # Only inspect the architecture once the backend is known to be HIP.
    if current.backend != 'hip':
        return False
    return current.arch == 'gfx90a'

In [None]:
# -----------------------------------------------------------------------------
# Tiling parameters (read by the launcher as module-level globals).
# You will need to change these to achieve better results.
# -----------------------------------------------------------------------------
BLOCK_M, BLOCK_N, BLOCK_K = (
    256,  # tile size in the M dimension
    64,   # tile size in the N dimension
    32,   # tile size in the K dimension
)


# -----------------------------------------------------------------------------
# Triton Kernel: Matrix Multiplication + ReLU + Add
#
# The kernel uses:
#   Step 1: Tile assignment (each kernel computes a tile of C)
#   Step 2: Shared memory tiling + Cooperative Fetching: Load tiles of A and B.
#   Step 3: Register tiling: Use a register accumulator.
#   Step 4: Add and ReLU fusion
#   Step 5: Write cache/Epilogue: Write the final tile back to global memory.
# -----------------------------------------------------------------------------
@triton.jit
def matmul_add_relu_kernel_fp16(
    a_ptr, b_ptr, c_ptr, d_ptr,
    M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
    stride_am: tl.constexpr, stride_ak: tl.constexpr,
    stride_bk: tl.constexpr, stride_bn: tl.constexpr,
    stride_cm: tl.constexpr, stride_cn: tl.constexpr,
    stride_dm: tl.constexpr, stride_dn: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    """Compute one BLOCK_M x BLOCK_N tile of D = ReLU(A @ B + C).

    Each program instance owns exactly one output tile, chosen by its
    coordinates on a 2D launch grid: axis 0 selects the tile along M and
    axis 1 selects the tile along N. The kernel must therefore be launched
    with grid = (cdiv(M, BLOCK_M), cdiv(N, BLOCK_N)).

    Args:
        a_ptr, b_ptr, c_ptr: Base pointers of the inputs, addressed as
            A[m, k], B[k, n] and C[m, n] through the strides below.
        d_ptr: Base pointer of the output, addressed as D[m, n]; values are
            stored as fp16.
        M, N, K: Problem sizes (A is M x K, B is K x N, C and D are M x N).
        stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
        stride_dm, stride_dn: Element strides of each tensor dimension.
        BLOCK_M, BLOCK_N, BLOCK_K: Tile sizes along M, N and the K reduction.
    """
    # -------------------------------------------------------------------------
    # Step 1: Compute program ID and offsets for this tile
    # -------------------------------------------------------------------------
    # The launcher uses a 2D grid, so each axis directly names a tile
    # coordinate. (Splitting program_id(0) with // and % is only valid for a
    # 1D launch; on a 2D grid it pins pid_m to 0 and ignores axis 1.)
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m_base = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n_base = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # Edge tiles may extend past the matrix; mask off the overhang.
    mask_m = offs_m_base < M
    mask_n = offs_n_base < N

    # -------------------------------------------------------------------------
    # Step 2: Initialize accumulator with higher precision
    # -------------------------------------------------------------------------
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    # -------------------------------------------------------------------------
    # Step 3: Main loop over the K dimension in steps of BLOCK_K
    # -------------------------------------------------------------------------
    for k in range(0, K, BLOCK_K):
        offs_k = k + tl.arange(0, BLOCK_K)
        mask_k = offs_k < K
        offs_a = (offs_m_base[:, None] * stride_am + offs_k[None, :] * stride_ak)
        offs_b = (offs_k[:, None] * stride_bk + offs_n_base[None, :] * stride_bn)
        # Out-of-range elements are loaded as 0.0 so they add nothing to the dot product.
        a = tl.load(a_ptr + offs_a, mask=mask_m[:, None] & mask_k[None, :], other=0.0)
        b = tl.load(b_ptr + offs_b, mask=mask_k[:, None] & mask_n[None, :], other=0.0)
        # fp16 operands, fp32 accumulation.
        acc += tl.dot(a.to(tl.float16), b.to(tl.float16))

    # -------------------------------------------------------------------------
    # Step 4: Load C matrix and apply fused add + ReLU
    # -------------------------------------------------------------------------
    offs_c = offs_m_base[:, None] * stride_cm + offs_n_base[None, :] * stride_cn
    c = tl.load(c_ptr + offs_c, mask=mask_m[:, None] & mask_n[None, :], other=0.0)
    acc = tl.maximum(acc + c.to(tl.float32), 0.0)

    # -------------------------------------------------------------------------
    # Step 5: Store the finished tile, cast back to fp16
    # -------------------------------------------------------------------------
    offs_d = offs_m_base[:, None] * stride_dm + offs_n_base[None, :] * stride_dn
    tl.store(d_ptr + offs_d, acc.to(tl.float16), mask=mask_m[:, None] & mask_n[None, :])

In [None]:
def matmul_add_relu_fp16(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
    """
    Computes Output = ReLU(A @ B + C) using fp16 precision for maximum throughput.

    Args:
        a: Left operand of shape (M, K).
        b: Right operand of shape (K, N).
        c: Addend of shape (M, N). It is read through its own strides, so it
            must already have the full output shape (no broadcasting).

    Returns:
        A new fp16 tensor of shape (M, N) allocated on ``a``'s device.

    Raises:
        AssertionError: If the inner dimensions of ``a`` and ``b`` differ, or
            if ``c`` does not have shape (M, N).
    """
    M, K = a.shape
    K2, N = b.shape
    assert K == K2, "Incompatible dimensions"
    # A mis-shaped C would be indexed with the wrong strides and silently
    # produce garbage (or read out of bounds), so reject it up front.
    assert c.shape == (M, N), f"C must have shape ({M}, {N}), got {tuple(c.shape)}"

    d = torch.empty((M, N), device=a.device, dtype=torch.float16)
    # Create launch grid: one program per output tile, tiles along M on
    # axis 0 and tiles along N on axis 1. Tile sizes come from the
    # module-level BLOCK_* globals.
    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))

    matmul_add_relu_kernel_fp16[grid](
        a, b, c, d,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
        d.stride(0), d.stride(1),
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K
    )
    return d

In [None]:
# Reference implementation using PyTorch
def reference_matmul_add_relu(A, B, C):
    """Baseline ReLU(A @ B + C) built from stock PyTorch operations.

    The ReLU is applied in place on the freshly created sum, so none of the
    inputs are modified.
    """
    biased_product = torch.add(torch.matmul(A, B), C)
    return biased_product.relu_()

In [None]:
# -----------------------------------------------------------------------------
# Accuracy Tests
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    # Fixed seed so the random operands are the same on every run.
    torch.manual_seed(0)
    # Draw the three 512x512 fp16 operands in the order a, b, c.
    a, b, c = (
        torch.randn((512, 512), device=torch.device("cuda"), dtype=torch.float16)
        for _ in range(3)
    )
    triton_output = matmul_add_relu_fp16(a, b, c)
    torch_output = reference_matmul_add_relu(a, b, c)
    print(f"triton_output_with_fp16_inputs={triton_output}")
    print(f"torch_output_with_fp16_inputs={torch_output}")
    # A tighter relative tolerance is used on the HIP gfx90a target.
    rtol = 0.032 if not is_hip_mi200() else 1e-2
    if not torch.allclose(triton_output, torch_output, atol=0.15, rtol=rtol):
        # Report the worst element-wise deviation to help debugging.
        diff = torch.sub(triton_output, torch_output)
        abs_diff = diff.abs()
        max_abs_diff = abs_diff.max()
        print(f"❌ Triton and Torch differ: {max_abs_diff=}")
    else:
        print("✅ Triton and Torch match")

triton_output_with_fp16_inputs=tensor([[ 0.0000,  6.1250,  0.0000,  ..., 10.0391,  0.0000,  0.0000],
        [ 7.9102, 15.6250, 26.6250,  ..., 11.4531,  5.3945, 18.6562],
        [ 2.7285,  0.0000,  0.0000,  ...,  0.0000, 26.1250,  0.0000],
        ...,
        [ 0.4316, 75.2500,  0.0000,  ..., 26.2812,  0.0000,  0.0000],
        [ 6.9570,  1.1260,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [27.6406, 26.9531, 22.9375,  ..., 13.5625,  6.0391, 21.6406]],
       device='cuda:0', dtype=torch.float16)
torch_output_with_fp16_inputs=tensor([[ 0.0000,  6.1289,  0.0000,  ..., 10.0391,  0.0000,  0.0000],
        [ 7.9102, 15.6328, 26.6250,  ..., 11.4531,  5.3945, 18.6562],
        [ 2.7266,  0.0000,  0.0000,  ...,  0.0000, 26.1250,  0.0000],
        ...,
        [ 0.4316, 75.2500,  0.0000,  ..., 26.2812,  0.0000,  0.0000],
        [ 6.9570,  1.1260,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [27.6406, 26.9531, 22.9375,  ..., 13.5625,  6.0391, 21.6406]],
       device='cuda:0', dt

In [None]:
# -----------------------------------------------------------------------------
# Performance Benchmark
# IMPORTANT: DO NOT CHANGE THIS CODE.
# THIS IS THE EXACT CODE THAT WILL BE USED TO GRADE YOUR IMPLEMENTATION.
# ANY CHANGES TO THIS CODE (INCLUDING DIMENSIONS, REPEATS, etc.)
# WILL CAUSE YOU TO HAVE DIFFERENT SPEEDUP RESULTS.
# -----------------------------------------------------------------------------
# Problem size used for grading: a 2048 x 2048 x 2048 matmul.
M = 2048
K = 2048
N = 2048

# KEEP THESE MATRICES IN FP16. FP32 WILL NOT PROVIDE ACCURATE RESULTS
A = torch.randn((M, K), device="cuda", dtype=torch.float16)
B = torch.randn((K, N), device="cuda", dtype=torch.float16)
C = torch.randn((M, N), device="cuda", dtype=torch.float16)

# warmup: run each implementation once before timing so one-time costs
# (presumably Triton's JIT compilation on first launch) are not measured.
_ = matmul_add_relu_fp16(A, B, C)
_ = reference_matmul_add_relu(A, B, C)

# Number of timed calls per implementation; the reported time is the mean.
REPEATS = 5000

# time your implementation
print("Triton implementation")
# Synchronize before starting and before stopping the clock so the measured
# interval covers exactly the GPU work issued by the loop.
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(REPEATS):
    _ = matmul_add_relu_fp16(A, B, C)
torch.cuda.synchronize()
triton_time = (time.perf_counter() - start) / REPEATS

# time pytorch (same procedure as above, using the reference implementation)
print("PyTorch implementation")
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(REPEATS):
    _ = reference_matmul_add_relu(A, B, C)
torch.cuda.synchronize()
torch_time = (time.perf_counter() - start) / REPEATS

# Per-call times are in seconds; multiply by 1000 to report milliseconds.
print(f"Performance comparison for matrix multiplication ({M}x{K} @ {K}x{N}):")
print(f"Triton implementation: {triton_time*1000:.2f} ms")
print(f"PyTorch implementation: {torch_time*1000:.2f} ms")

# Speedup > 1 means the Triton implementation is faster than PyTorch.
print(f"\nSpeedup of Triton vs PyTorch: {torch_time/triton_time:.2f}x")

Triton implementation
PyTorch implementation
Performance comparison for matrix multiplication (2048x2048 @ 2048x2048):
Triton implementation: 0.76 ms
PyTorch implementation: 0.93 ms

Speedup of Triton vs PyTorch: 1.22x


In [None]:
# Write your grid search here.

def benchmark_matmul(M, N, K, block_m, block_n, block_k, repeats=100):
    """Time the Triton launcher against the PyTorch reference for one tile shape.

    Side effect: overwrites the module-level BLOCK_M / BLOCK_N / BLOCK_K,
    because the launcher reads its tile sizes from those globals.

    Args:
        M, N, K: Problem size (A is M x K, B is K x N, C is M x N).
        block_m, block_n, block_k: Tile sizes to install before timing.
        repeats: Number of timed calls per implementation.

    Returns:
        torch_time / triton_time, i.e. values above 1 mean Triton was faster.
    """
    global BLOCK_M, BLOCK_N, BLOCK_K
    BLOCK_M = block_m
    BLOCK_N = block_n
    BLOCK_K = block_k

    def _random_fp16(rows, cols):
        # Fresh random fp16 operand on the GPU.
        return torch.randn((rows, cols), device="cuda", dtype=torch.float16)

    a = _random_fp16(M, K)
    b = _random_fp16(K, N)
    c = _random_fp16(M, N)

    def _mean_seconds(fn):
        # Bracket the loop with synchronizations so only its GPU work is timed.
        torch.cuda.synchronize()
        began = time.perf_counter()
        for _ in range(repeats):
            _ = fn(a, b, c)
        torch.cuda.synchronize()
        return (time.perf_counter() - began) / repeats

    # Warm up both implementations before any measurement.
    warmup_left = 10
    while warmup_left > 0:
        _ = matmul_add_relu_fp16(a, b, c)
        _ = reference_matmul_add_relu(a, b, c)
        torch.cuda.synchronize()
        warmup_left -= 1

    torch.cuda.empty_cache()

    # Benchmark Triton
    triton_time = _mean_seconds(matmul_add_relu_fp16)

    torch.cuda.empty_cache()

    # Benchmark PyTorch
    torch_time = _mean_seconds(reference_matmul_add_relu)

    speedup = torch_time / triton_time
    print(f"  Config (M={block_m}, N={block_n}, K={block_k}): Triton={triton_time*1000:.2f}ms, PyTorch={torch_time*1000:.2f}ms, Speedup={speedup:.2f}x")

    return speedup

def grid_search(
    block_m_options=(64, 128, 256),
    block_n_options=(64, 128, 256),
    block_k_options=(16, 32, 64),
    M=2048,
    N=2048,
    K=2048,
    pause_seconds=1,
):
    """Run a grid search to find optimal block sizes.

    Every (block_m, block_n, block_k) combination is timed with
    ``benchmark_matmul``. Configurations that raise are reported and skipped.
    On success the module-level BLOCK_M / BLOCK_N / BLOCK_K are set to the
    winning configuration.

    Args:
        block_m_options, block_n_options, block_k_options: Candidate tile
            sizes for each dimension.
        M, N, K: Problem size passed to ``benchmark_matmul``.
        pause_seconds: Sleep inserted after each successful measurement
            (presumably to let the GPU settle; 0 disables it).

    Returns:
        (best_config, best_speedup), where best_config is the
        (block_m, block_n, block_k) tuple with the highest speedup.

    Raises:
        RuntimeError: If no configuration produced a usable measurement.
    """
    global BLOCK_M, BLOCK_N, BLOCK_K

    best_speedup = 0
    best_config = None
    results = []

    for block_m in block_m_options:
        for block_n in block_n_options:
            for block_k in block_k_options:
                try:
                    print(f"Testing configuration: M={block_m}, N={block_n}, K={block_k}")
                    speedup = benchmark_matmul(M, N, K, block_m, block_n, block_k)

                    results.append({
                        'block_m': block_m,
                        'block_n': block_n,
                        'block_k': block_k,
                        'speedup': speedup
                    })

                    if speedup > best_speedup:
                        best_speedup = speedup
                        best_config = (block_m, block_n, block_k)

                    if pause_seconds:
                        time.sleep(pause_seconds)

                except Exception as e:
                    # Best effort: a tile shape that fails must not abort the whole sweep.
                    print(f"  Error with configuration {(block_m, block_n, block_k)}: {e}")

    if best_config is None:
        # Without this guard the summary below fails with an opaque
        # "'NoneType' object is not subscriptable".
        raise RuntimeError(
            "grid_search: no configuration produced a usable measurement; "
            "see the errors printed above"
        )

    results.sort(key=lambda x: x['speedup'], reverse=True)

    print("\nAll Configurations (by performance):")
    for i, result in enumerate(results):
        print(f"{i+1}. BLOCK_M={result['block_m']}, BLOCK_N={result['block_n']}, BLOCK_K={result['block_k']}: {result['speedup']:.2f}x speedup")

    print(f"\nBest configuration: BLOCK_M={best_config[0]}, BLOCK_N={best_config[1]}, BLOCK_K={best_config[2]}")
    print(f"Best speedup: {best_speedup:.2f}x")

    # Leave the winning tile shape installed for subsequent launcher calls.
    BLOCK_M, BLOCK_N, BLOCK_K = best_config

    return best_config, best_speedup

# Sweep the tile shapes, then re-time the winning configuration with many more repeats.
best_config, best_speedup = grid_search()
print("\nBest configuration:")
_ = benchmark_matmul(
    2048, 2048, 2048,
    block_m=best_config[0],
    block_n=best_config[1],
    block_k=best_config[2],
    repeats=5000,
)

Testing configuration: M=64, N=64, K=16
  Config (M=64, N=64, K=16): Triton=2.25ms, PyTorch=0.82ms, Speedup=0.36x
Testing configuration: M=64, N=64, K=32
  Config (M=64, N=64, K=32): Triton=1.34ms, PyTorch=0.90ms, Speedup=0.67x
Testing configuration: M=64, N=64, K=64
  Config (M=64, N=64, K=64): Triton=1.24ms, PyTorch=0.93ms, Speedup=0.75x
Testing configuration: M=64, N=128, K=16
  Config (M=64, N=128, K=16): Triton=0.85ms, PyTorch=0.80ms, Speedup=0.94x
Testing configuration: M=64, N=128, K=32
  Config (M=64, N=128, K=32): Triton=0.71ms, PyTorch=0.84ms, Speedup=1.18x
Testing configuration: M=64, N=128, K=64
  Config (M=64, N=128, K=64): Triton=0.76ms, PyTorch=0.83ms, Speedup=1.10x
Testing configuration: M=64, N=256, K=16
  Config (M=64, N=256, K=16): Triton=1.20ms, PyTorch=0.86ms, Speedup=0.72x
Testing configuration: M=64, N=256, K=32
  Config (M=64, N=256, K=32): Triton=1.03ms, PyTorch=0.82ms, Speedup=0.79x
Testing configuration: M=64, N=256, K=64
  Config (M=64, N=256, K=64): Triton=