In [1]:
from dataclasses import dataclass
from functools import wraps
from itertools import product
from time import perf_counter
from typing import Callable, Optional

import torch
from torch import Tensor

torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)
DTYPE = torch.float32  # dtype of the test matrices, kernel outputs and simulated shared memory
SEED = 0               # seeds torch's RNG so the random test matrices are identical on every run
torch.manual_seed(SEED)

In [2]:
class Counter:
    """Accumulates the read and write counts reported by the simulated kernels."""

    def __init__(self):
        self.reads = 0
        self.writes = 0

    def inc_reads(self, n):
        """Add `n` to the running read total."""
        self.reads = self.reads + n

    def inc_writes(self, n):
        """Add `n` to the running write total."""
        self.writes = self.writes + n

    def show(self):
        """Print the read total and the write total on two separate lines."""
        print("\n".join((f"Total reads: {self.reads}", f"Total writes: {self.writes}")))

def timeit(func: Callable):
    """Decorator that prints the wall-clock duration of every call to `func`, in milliseconds.

    The wrapped function's return value is passed through unchanged, and
    `functools.wraps` keeps its `__name__`, `__doc__` and `__wrapped__`.
    """
    @wraps(func)
    def timer(*args, **kwargs):
        start = perf_counter()
        out = func(*args, **kwargs)
        elapsed_ms = (perf_counter() - start) * 1000
        print(f"Time elapsed: {elapsed_ms:.2f} ms")
        return out
    return timer

def cdiv(a: int, b: int):
    """Ceiling division: return ceil(a / b) using exact integer arithmetic.

    Floor-dividing the negated numerator keeps the result exact for any nonzero
    integer `b` (no float rounding, and correct for negative divisors too).
    Raises ZeroDivisionError when `b == 0`.
    """
    return -(-a // b)

@dataclass
class D3:
    """(z, y, x) integer triple used by the simulated launcher for grid/block extents
    and for block/thread indices.

    Field order is z, y, x, so positional construction lists the outermost axis
    first. Every field defaults to 1.
    """
    z: int = 1
    y: int = 1
    x: int = 1

    @property
    def size(self):
        """Total number of elements covered: the product of the three extents."""
        total = 1
        for extent in (self.x, self.y, self.z):
            total *= extent
        return total

In [3]:
def launch_cuda_kernel(blocks: D3, threads: D3, malloc_shared_size: int=0, kernel: Callable=None):
    """Return a timed dispatcher that simulates launching `kernel` on the CPU, one thread at a time.

    Calling the dispatcher with the kernel's own arguments runs
    ``kernel(blockIdx, threadIdx, blockSize, shared_mem, *kernargs)`` for every
    block of the `blocks` grid and every thread of a `threads`-shaped block,
    with z as the outermost and x as the innermost loop. `blockSize` is `threads`.

    Each block gets a fresh, uninitialised ``torch.empty`` buffer of
    `malloc_shared_size` DTYPE elements that all of its threads share. Threads
    run to completion one after another, so there is no barrier between them.
    The dispatcher is wrapped with `timeit`, so each launch prints its elapsed time.
    """
    assert kernel is not None

    @timeit
    def dispatch_kernel(*kernargs):
        for block_zyx in product(range(blocks.z), range(blocks.y), range(blocks.x)):
            # one shared buffer per block, reused by all of that block's threads
            shared_mem = torch.empty((malloc_shared_size,), dtype=DTYPE)
            for thread_zyx in product(range(threads.z), range(threads.y), range(threads.x)):
                kernel(D3(*block_zyx), D3(*thread_zyx), threads, shared_mem, *kernargs)

    return dispatch_kernel

In [4]:
# Random test problem: A is (M, K) and B is (K, N). C_ref is torch's own product,
# the reference every simulated kernel below is checked against.
M, K, N = 12, 24, 15
A, B = torch.randn(M, K, dtype=DTYPE), torch.randn(K, N, dtype=DTYPE)  # A is drawn first, then B
C_ref = torch.matmul(A, B)
print(C_ref)

tensor([[ 5.03,  2.09, -9.03,  2.60,  3.83, -6.59,  7.71, -4.47, -6.40,  3.55,  0.10,  2.37, -3.79,  1.86, -2.75],
        [ 9.02, -0.43, -6.69, -2.15,  2.83, -4.01, -0.50,  4.68, -0.97,  6.35, -3.11,  0.69, -8.55,  1.70, -3.31],
        [-3.66, -8.28,  8.09, -3.47,  0.45, -5.76,  4.54, -0.28, -8.23, -9.75,  1.97,  4.39, -0.50,  0.77, -0.22],
        [-2.60,  4.26, -0.18, 13.37, -5.49, 13.48,  4.82, -0.98,  3.19,  6.84, 10.74, -7.64,  7.85, -2.85, -7.24],
        [-5.90,  5.01,  3.96,  0.84, -0.10, -4.95,  6.50,  5.18,  4.60,  4.21,  3.62, -8.26,  8.03,  3.57, -5.75],
        [ 2.15,  0.55, -0.10, -2.93,  1.21,  4.64, -1.13,  6.23, -1.13,  0.17,  8.36, -1.05, -5.91,  2.76, -1.31],
        [-1.76, -2.08, -0.72, 10.56, -0.92,  5.09, 10.48, -4.15,  2.66,  6.33,  3.80, -6.08,  5.25, -1.04, -8.29],
        [-1.68,  6.65, -2.83, 12.43, -5.49, 13.44, -0.77,  0.32,  4.95, 11.65,  8.31, -8.69,  7.76,  2.99, -8.59],
        [-3.49, -5.51, -2.55, -2.55, -4.06,  4.61, -0.28, -7.77,  4.09,  7.14,  

In [5]:
def check(C: Tensor, ref: Optional[Tensor] = None, atol: float = 1e-6) -> str:
    """Compare a kernel's output with a reference matmul result.

    Args:
        C: output of one of the simulated matmul kernels.
        ref: expected result. Defaults to the notebook-global `C_ref`, which is
            looked up at call time, so re-running the test-example cell is picked up.
        atol: absolute tolerance passed to `torch.allclose` (its default
            relative tolerance also applies).

    Returns:
        "PASSED" if the shapes match exactly and all elements are close, else "FAILED".
    """
    if ref is None:
        ref = C_ref
    # allclose broadcasts, so a wrongly shaped result (e.g. a single (1, N) row)
    # could otherwise pass, or raise instead of reporting FAILED
    if C.shape != ref.shape:
        return "FAILED"
    return "PASSED" if torch.allclose(C, ref, atol=atol) else "FAILED"

# Matmul Naive 1D

In [6]:
def matmul_naive_kernel(blockIdx, threadIdx, blockSize, shared_mem,
                        C: Tensor, A: Tensor, B: Tensor, M: int, K: int, N: int, counter: Counter):
    """One thread per output element over a 1D launch: thread `idx` computes flat C[idx].

    A, B and C are indexed as flat row-major buffers of M*K, K*N and M*N elements.
    Each thread records 2 reads per k and a single write of its finished sum.
    `shared_mem` is unused.
    """
    idx = blockIdx.x * blockSize.x + threadIdx.x
    if idx >= M * N:  # tail threads of the last block have no element to compute
        return
    row, col = divmod(idx, N)
    acc = 0.
    for k in range(K):
        acc += A[row * K + k] * B[k * N + col]
        counter.inc_reads(2)
    C[idx] = acc
    counter.inc_writes(1)

In [7]:
def matmul_naive(A: Tensor, B: Tensor, threads: int = 8):
    """Multiply A (M, K) by B (K, N) with the simulated 1D naive kernel.

    Launches a 1D grid of `threads`-wide blocks (one thread per output element)
    and prints the read/write totals gathered by the kernel.

    Args:
        A, B: 2-D tensors whose inner dimensions match.
        threads: threads per block (default 8); must be positive.

    Returns:
        C: (M, N) tensor of DTYPE holding A @ B.
    """
    (M, K), (K_, N) = A.shape, B.shape
    assert K == K_, f"inner dims should match! {K} != {K_}"
    assert threads > 0, f"threads must be positive, got {threads}"
    A, B = A.contiguous(), B.contiguous()  # no-op when already contiguous
    C = torch.empty(M, N, dtype=DTYPE)     # every element is assigned (not accumulated) by the kernel

    blocks = cdiv(M * N, threads)
    kernel = launch_cuda_kernel(D3(x=blocks), D3(x=threads), kernel=matmul_naive_kernel)
    # .view(-1) guarantees the flat buffers alias A, B and C; flatten() is allowed to copy,
    # which would silently discard the kernel's writes to C
    kernel(C.view(-1), A.view(-1), B.view(-1), M, K, N, counter:=Counter())
    counter.show()
    return C

In [8]:
# 1D naive kernel: run it (prints timing and read/write totals), then compare with C_ref
print(check(C := matmul_naive(A, B)))

Time elapsed: 18.41 ms
Total reads: 8640
Total writes: 180
PASSED


# Matmul Naive 2D Tiled

In [9]:
def matmul_naive2D_kernel(blockIdx, threadIdx, blockSize, shared_mem,
                          C: Tensor, A: Tensor, B: Tensor, M: int, K: int, N: int, tileWidth: int, counter: Counter):
    """One thread per output element over a 2D launch, walking K in tileWidth-sized chunks.

    Thread (row, col) adds A[row, k] * B[k, col] directly into flat C[row * N + col]
    for every k < K, so C must start zeroed. Each k records 2 reads and 1 write.
    `shared_mem` is unused.
    """
    row = blockIdx.y * blockSize.y + threadIdx.y
    col = blockIdx.x * blockSize.x + threadIdx.x
    num_tiles = cdiv(K, tileWidth)
    if row >= M or col >= N:
        return

    out = row * N + col  # the flat C index is the same for every k
    for tile in range(num_tiles):
        start = tile * tileWidth
        # the final tile may be ragged: stop at K
        for k in range(start, min(start + tileWidth, K)):
            C[out] += A[row * K + k] * B[k * N + col]
            counter.inc_reads(2)
            counter.inc_writes(1)

In [10]:
def matmul_naive2D(A: Tensor, B: Tensor, tileWidth: int):
    """Multiply A (M, K) by B (K, N) with the simulated 2D naive kernel.

    Launches a 2D grid of tileWidth x tileWidth blocks (one thread per output
    element) and prints the read/write totals gathered by the kernel.

    Returns:
        C: (M, N) tensor of DTYPE holding A @ B.
    """
    (M, K), (K_, N) = A.shape, B.shape
    assert K == K_, f"inner dims should match! {K} != {K_}"
    assert tileWidth > 0, f"tileWidth must be positive, got {tileWidth}"
    A, B = A.contiguous(), B.contiguous()  # no-op when already contiguous
    # The kernel accumulates with `C[idx] += ...`, so C must start at zero;
    # torch.empty would leave arbitrary memory in every sum.
    C = torch.zeros(M, N, dtype=DTYPE)

    blocks_m = cdiv(M, tileWidth)
    blocks_n = cdiv(N, tileWidth)
    kernel = launch_cuda_kernel(
        D3(y=blocks_m, x=blocks_n), D3(y=tileWidth, x=tileWidth), kernel=matmul_naive2D_kernel)
    # .view(-1) guarantees the flat buffers alias A, B and C instead of copying them
    kernel(C.view(-1), A.view(-1), B.view(-1), M, K, N, tileWidth, counter:=Counter())
    counter.show()
    return C

In [11]:
# 2D naive kernel with 5x5 thread blocks, compared with C_ref
print(check(C := matmul_naive2D(A, B, 5)))

Time elapsed: 29.50 ms
Total reads: 8640
Total writes: 4320
PASSED


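# Matmul 2D Tiled (Shared Memory Optimization)

For every K-tile, each block stages one tile of A and one tile of B in shared memory before computing. The simulated launcher runs threads one after another with no barrier, so every thread stages the full tile row and tile column it later reads itself. As a result, the global-read count below is the same as for the naive kernels (8640).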

In [12]:
def matmul_2Dshared_kernel(blockIdx, threadIdx, blockSize, shared_mem,
                           C: Tensor, A: Tensor, B: Tensor, M: int, K: int, N: int, tileWidth: int, counter: Counter):
    """2D tiled matmul that stages each K-tile of A and B in the block's shared buffer.

    `shared_mem` holds two row-major tileWidth x tileWidth tiles back to back: the
    A tile at offset 0 and the B tile at offset tileWidth**2. For every K-tile,
    thread (ty, tx) copies row ty of the A tile and column tx of the B tile
    (zero-padded past K). It then adds their dot product into flat
    C[row * N + col], so C must start zeroed.

    Threads are run one after another with no barrier. That is why each thread
    stages everything it later reads itself, and why the global-read count
    (2 per real k) equals the naive kernels'.
    """
    ty, tx = threadIdx.y, threadIdx.x
    row = blockIdx.y * blockSize.y + ty
    col = blockIdx.x * blockSize.x + tx
    b_base = tileWidth ** 2  # the B tile starts right after the A tile
    num_tiles = cdiv(K, tileWidth)
    if row >= M or col >= N:
        return

    out = row * N + col
    for tile in range(num_tiles):
        k0 = tile * tileWidth

        # stage: row ty of the A tile and column tx of the B tile (zeros past K)
        for k in range(tileWidth):
            gk = k0 + k
            in_range = gk < K
            shared_mem[ty * tileWidth + k] = A[row * K + gk] if in_range else 0
            shared_mem[b_base + k * tileWidth + tx] = B[gk * N + col] if in_range else 0
            if in_range:
                counter.inc_reads(2)

        # consume: dot product of the staged row and column, real k only
        for k in range(tileWidth):
            if k0 + k < K:
                C[out] += shared_mem[ty * tileWidth + k] * shared_mem[b_base + k * tileWidth + tx]
                counter.inc_writes(1)

In [13]:
def matmul_2Dshared(A: Tensor, B: Tensor, tileWidth: int):
    """Multiply A (M, K) by B (K, N) with the simulated shared-memory tiled kernel.

    Launches a 2D grid of tileWidth x tileWidth blocks. Each block gets
    2 * tileWidth**2 elements of shared memory (one A tile plus one B tile).
    Prints the read/write totals gathered by the kernel.

    Returns:
        C: (M, N) tensor of DTYPE holding A @ B.
    """
    (M, K), (K_, N) = A.shape, B.shape
    assert K == K_, f"inner dims should match! {K} != {K_}"
    assert tileWidth > 0, f"tileWidth must be positive, got {tileWidth}"
    A, B = A.contiguous(), B.contiguous()  # no-op when already contiguous
    # The kernel accumulates with `C[idx] += ...`, so C must start at zero;
    # torch.empty would leave arbitrary memory in every sum.
    C = torch.zeros(M, N, dtype=DTYPE)

    blocks_m = cdiv(M, tileWidth)
    blocks_n = cdiv(N, tileWidth)
    kernel = launch_cuda_kernel(
        D3(y=blocks_m, x=blocks_n), D3(y=tileWidth, x=tileWidth), 2*tileWidth**2, kernel=matmul_2Dshared_kernel)
    # .view(-1) guarantees the flat buffers alias A, B and C instead of copying them
    kernel(C.view(-1), A.view(-1), B.view(-1), M, K, N, tileWidth, counter:=Counter())
    counter.show()
    return C

In [14]:
# shared-memory tiled kernel with 3x3 thread blocks, compared with C_ref
print(check(C := matmul_2Dshared(A, B, 3)))

Time elapsed: 53.01 ms
Total reads: 8640
Total writes: 4320
PASSED


# Matmul Naive 3D Tiled (for fun)

In [15]:
def matmul_naive3D_kernel(blockIdx, threadIdx, blockSize, shared_mem,
                          C: Tensor, A: Tensor, B: Tensor, M: int, K: int, N: int, counter: Counter):
    """One thread per (m, n, k) triple: adds the single product A[m, k] * B[k, n] into C[m, n].

    The z axis indexes rows of C, y indexes columns, and x indexes the reduction
    dimension. K different threads add into the same C element, so C must start
    zeroed. The simulated launcher runs threads one at a time, which makes the
    plain `+=` safe here; a truly parallel launch would need an atomic add.
    `shared_mem` is unused.
    """
    m = blockIdx.z * blockSize.z + threadIdx.z
    n = blockIdx.y * blockSize.y + threadIdx.y
    k = blockIdx.x * blockSize.x + threadIdx.x
    if m >= M or n >= N or k >= K:
        return
    C[m * N + n] += A[m * K + k] * B[k * N + n]
    counter.inc_reads(2)
    counter.inc_writes(1)

In [16]:
def matmul_naive3D(A: Tensor, B: Tensor, tileWidth: int):
    """Multiply A (M, K) by B (K, N) with the simulated 3D naive kernel.

    Launches a 3D grid of tileWidth**3-thread blocks covering every (m, n, k)
    triple (z -> M, y -> N, x -> K) and prints the read/write totals gathered
    by the kernel.

    Returns:
        C: (M, N) tensor of DTYPE holding A @ B.
    """
    (M, K), (K_, N) = A.shape, B.shape
    assert K == K_, f"inner dims should match! {K} != {K_}"
    assert tileWidth > 0, f"tileWidth must be positive, got {tileWidth}"
    A, B = A.contiguous(), B.contiguous()  # no-op when already contiguous
    # Each output element is built from K separate `C[idx] += ...` updates, so C
    # must start at zero; torch.empty would leave arbitrary memory in every sum.
    C = torch.zeros(M, N, dtype=DTYPE)

    blocks_m = cdiv(M, tileWidth)
    blocks_n = cdiv(N, tileWidth)
    blocks_k = cdiv(K, tileWidth)
    kernel = launch_cuda_kernel(
        D3(z=blocks_m, y=blocks_n, x=blocks_k), D3(z=tileWidth, y=tileWidth, x=tileWidth),
        kernel=matmul_naive3D_kernel)
    # .view(-1) guarantees the flat buffers alias A, B and C instead of copying them
    kernel(C.view(-1), A.view(-1), B.view(-1), M, K, N, counter:=Counter())
    counter.show()
    return C

In [17]:
# 3D naive kernel with 3x3x3 thread blocks, compared with C_ref
print(check(C := matmul_naive3D(A, B, 3)))

Time elapsed: 34.81 ms
Total reads: 8640
Total writes: 4320
PASSED
