In [1]:
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from functools import wraps
from itertools import product
from threading import Barrier, BrokenBarrierError, Lock
from time import perf_counter
from typing import Callable, Tuple

import torch
from torch import Tensor

In [2]:
torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)
SEED = 0
torch.manual_seed(SEED)  # makes the random test matrices A, B reproducible on Restart & Run All
DTYPE = torch.float32    # dtype of every buffer the emulated kernels allocate

In [3]:
class Counter:
    """Tally of memory reads (loads) and writes (stores) reported by the emulated kernels.

    One instance is handed to every thread of a launch, so updates are serialized with a lock.
    """
    def __init__(self):
        self.loads, self.stores = 0, 0
        self._lock = Lock()

    def increment(self, loads: int=0, stores: int=0):
        """Add `loads` reads and `stores` writes as one atomic update."""
        with self._lock:
            # `self.x += n` is a separate read and write; unguarded, concurrent threads can lose updates
            self.loads += loads
            self.stores += stores

    def show(self): print(f"Total reads: {self.loads} Total writes: {self.stores}")

def timeit(func: Callable) -> Callable:
    """Decorator: call `func`, print its wall-clock duration in milliseconds, return its result."""
    @wraps(func)  # keep func's __name__/__doc__ on the wrapper
    def timer(*args, **kwargs):
        st = perf_counter()
        out = func(*args, **kwargs)
        et = perf_counter()
        print(f"Time elapsed: {(et - st) * 1000:.2f} ms")
        return out
    return timer

def cdiv(a: int, b: int):
    """Ceiling division in pure integer arithmetic: ceil(a / b) for b > 0, with no float rounding.

    Typical use: number of blocks of size `b` needed to cover `a` elements.
    """
    padded = a + b - 1  # bump a so that floor division rounds up instead of down
    return padded // b

def tidx(fidx: int, tw: int):
    """Split flat index `fidx` into its (row, col) pair for rows of width `tw` (tiledIdx from flatIdx)."""
    return divmod(fidx, tw)
def fidx(tidx0: int, tidx1: int, tw: int):
    """Row-major flat index of (tidx0, tidx1) for rows of width `tw` (flatIdx from tiledIdx)."""
    row_start = tidx0 * tw
    return row_start + tidx1

@dataclass
class dim3:
    """CUDA-style 3D extent or index; every dimension defaults to 1.

    Note: positional order is (z, y, x), the reverse of CUDA's dim3(x, y, z), so
    ``dim3(*t)`` accepts tuples produced by iterating z, then y, then x.
    """
    z: int = 1  # slowest-varying dimension
    y: int = 1
    x: int = 1  # fastest-varying dimension

    @property
    def size(self):
        """Total number of elements/threads: x * y * z."""
        return self.x * self.y * self.z

In [4]:
def launch_cuda_kernel(gridSize: dim3, blockSize: dim3, kernel: Callable, shared_size: int=0):
    """Build a CPU emulation of ``kernel<<<gridSize, blockSize, shared_size>>>``.

    The returned ``dispatch_kernel(*kernargs)`` is timed by ``timeit`` and runs the grid's blocks
    one after another. For each block it:
    - allocates a fresh zero-filled ``shared_mem`` tensor of ``shared_size`` DTYPE elements;
    - creates one ``Barrier`` for all ``blockSize.size`` threads;
    - runs every thread concurrently, each in its own pool worker, as
      ``kernel(blockIdx, threadIdx, blockSize, barrier, shared_mem, *kernargs)``.

    As with ``__syncthreads()``, every thread of a block must call ``barrier.wait()`` the same
    number of times; a thread that skips a wait leaves the rest of its block blocked forever.

    Raises:
        Any exception raised by a kernel thread, re-raised once its block has shut down. The
        original failure is preferred over the BrokenBarrierError it triggers in sibling threads.
    """
    def run_thread(blockIdx, threadIdx, barrier, shared_mem, kernargs):
        # One emulated CUDA thread.
        try:
            kernel(dim3(*blockIdx), dim3(*threadIdx), blockSize, barrier, shared_mem, *kernargs)
        except BaseException:
            # This thread will never reach barrier.wait() again: break the barrier so siblings
            # blocked on it get BrokenBarrierError instead of hanging (and hanging the pool's join).
            barrier.abort()
            raise

    @timeit
    def dispatch_kernel(*kernargs):
        for blockIdx in product(range(gridSize.z), range(gridSize.y), range(gridSize.x)):
            shared_mem = torch.zeros(shared_size, dtype=DTYPE)
            barrier = Barrier(blockSize.size)
            # one worker per thread: all threads of a block must be alive at once to meet at the barrier
            with ThreadPoolExecutor(max_workers=blockSize.size) as e:
                futures = [e.submit(run_thread, blockIdx, threadIdx, barrier, shared_mem, kernargs)
                           for threadIdx in product(range(blockSize.z), range(blockSize.y), range(blockSize.x))]
            # Futures keep exceptions to themselves unless asked; surface them here.
            errors = [err for err in (f.exception() for f in futures) if err is not None]
            if errors:
                raise next((err for err in errors if not isinstance(err, BrokenBarrierError)), errors[0])
    return dispatch_kernel

In [5]:
# test example: random A (M x K), B (K x N) and the reference product C_ref = A @ B
M, K, N = 12, 24, 18
A, B = torch.randn(M, K, dtype=DTYPE), torch.randn(K, N, dtype=DTYPE)
C_ref = torch.matmul(A, B)

In [6]:
def check_output(C: Tensor, ref: Tensor = None, atol: float = 1e-5) -> str:
    """Compare a kernel's output with the reference product.

    Args:
        C: matrix produced by one of the emulated kernels.
        ref: expected result; defaults to the notebook-level ``C_ref`` (= A @ B).
        atol: absolute tolerance for ``torch.allclose``. The kernels accumulate float32 dot
            products in a different order than ``A @ B``, so near-zero entries can differ by
            around 1e-6; 1e-5 (the usual float32 atol) keeps that noise from reading as FAILED.

    Returns:
        ``"PASSED std=…"`` or ``"FAILED std=…"``, where std is the standard deviation of C - ref.
    """
    if ref is None:
        ref = C_ref
    std = torch.std(C - ref).item()
    ret = "PASSED" if torch.allclose(C, ref, atol=atol) else "FAILED"
    return f"{ret} {std=:.4f}"

In [7]:
# sanity-check check_output with a broadcast (M, 1, K) * (1, N, K) product summed over K
print(check_output((A.unsqueeze(1) * B.T.unsqueeze(0)).sum(dim=-1)))

PASSED std=0.0000


# Matmul Naive 1D

In [8]:
def matmul_naive1D_kernel(blockIdx: dim3, threadIdx: dim3, blockSize: dim3, barrier: Barrier, shared_mem: Tensor,
                          buffers: Tuple[Tensor, ...], metadata: Tuple[int, ...], counter: Counter):
    """One thread per element of the flattened C (M * N entries), launched on a 1D grid.

    buffers = (C, A, B) as flat tensors indexed row-major; metadata = (M, K, N).
    `barrier` and `shared_mem` are unused. Counts 2 loads per k and 1 store per output.
    """
    C, A, B = buffers
    M, K, N = metadata
    flat = blockIdx.x * blockSize.x + threadIdx.x
    row, col = tidx(flat, N)
    # threads of the last block may overhang the M * N outputs; they do nothing
    if not (row < M and col < N):
        return
    total = 0.
    for k in range(K):
        total += A[fidx(row, k, K)] * B[fidx(k, col, N)]
        counter.increment(loads=2)
    C[flat] = total
    counter.increment(stores=1)

In [9]:
def matmul_naive1D(A: Tensor, B: Tensor, threads: int = 8):
    """Multiply A (M x K) by B (K x N) with the naive 1D kernel.

    Launches one thread per output element on a 1D grid of ``threads``-wide blocks, prints the
    elapsed time and the global load/store counts, and returns C (M x N, DTYPE).
    """
    (M, K), (K_, N) = A.shape, B.shape
    assert K == K_, f"inner dims should match, {K} != {K_}"
    A, B = A.contiguous(), B.contiguous()  # returns the tensor itself when already contiguous
    # zero-init so the result never depends on uninitialised memory
    C = torch.zeros(M, N, dtype=DTYPE)

    blockSize = dim3(x=threads)
    gridSize = dim3(x=cdiv(M * N, blockSize.x))
    kernel = launch_cuda_kernel(gridSize, blockSize, matmul_naive1D_kernel)
    # view(-1) always aliases C (or raises), so the kernel's writes land in C
    kernel((C.view(-1), A.view(-1), B.view(-1)), (M, K, N), counter:=Counter())
    counter.show()
    return C

In [10]:
# naive 1D kernel: one thread per output element
C = matmul_naive1D(A=A, B=B)
print(check_output(C=C))

Time elapsed: 110.04 ms
Total reads: 10368 Total writes: 216
PASSED std=0.0000


# Matmul Naive 2D Tiled

In [11]:
def matmul_naive2D_kernel(blockIdx: dim3, threadIdx: dim3, blockSize: dim3, barrier: Barrier, shared_mem: Tensor,
                          buffers: Tuple[Tensor, ...], metadata: Tuple[int, ...], counter: Counter):
    """One thread per output element C[m, n], launched on a 2D grid of 2D blocks.

    K is walked in chunks of ``tileWidth``, but only as loop structure: every operand is still
    read directly from A and B (``barrier`` and ``shared_mem`` are unused).
    buffers = (C, A, B) as flat tensors indexed row-major; metadata = (M, K, N, tileWidth).
    """
    m = (blockIdx.y * blockSize.y) + threadIdx.y
    n = (blockIdx.x * blockSize.x) + threadIdx.x
    (C, A, B), (M, K, N, tileWidth) = buffers, metadata
    if m < M and n < N:
        acc = 0.
        for tile_k in range(cdiv(K, tileWidth)):
            for t in range(tileWidth):
                k = fidx(tile_k, t, tileWidth)
                if k < K:  # last chunk may overhang K
                    acc += A[fidx(m, k, K)] * B[fidx(k, n, N)]
                    counter.increment(loads=2)
        # plain store: this thread alone owns C[m, n], so C's initial contents must not matter
        C[fidx(m, n, N)] = acc
        counter.increment(stores=1)

In [12]:
def matmul_naive2D(A: Tensor, B: Tensor, tileWidth: int):
    """Multiply A (M x K) by B (K x N) with the naive 2D kernel.

    One thread per C[m, n], tileWidth x tileWidth threads per block. Prints the elapsed time
    and the global load/store counts; returns C (M x N, DTYPE).
    """
    (M, K), (K_, N) = A.shape, B.shape
    assert K == K_, f"inner dims should match, {K} != {K_}"
    A, B = A.contiguous(), B.contiguous()  # returns the tensor itself when already contiguous
    # zeros, not empty: the result must never depend on uninitialised memory
    C = torch.zeros(M, N, dtype=DTYPE)

    blockSize = dim3(y=tileWidth, x=tileWidth)
    gridSize = dim3(y=cdiv(M, tileWidth), x=cdiv(N, tileWidth))
    kernel = launch_cuda_kernel(gridSize, blockSize, matmul_naive2D_kernel)
    # view(-1) always aliases C (or raises), so the kernel's writes land in C
    kernel((C.view(-1), A.view(-1), B.view(-1)), (M, K, N, tileWidth), counter:=Counter())
    counter.show()
    return C

In [13]:
# naive 2D kernel with 3x3 thread blocks
C = matmul_naive2D(A, B, tileWidth=3)
print(check_output(C=C))

Time elapsed: 120.11 ms
Total reads: 10368 Total writes: 216
PASSED std=0.0000


# Matmul 2D Tiled (Shared Memory Optimization)

In [14]:
def matmul_tiled2D_kernel(blockIdx: dim3, threadIdx: dim3, blockSize: dim3, barrier: Barrier, shared_mem: Tensor,
                          buffers: Tuple[Tensor, ...], metadata: Tuple[int, ...], counter: Counter, scounter: Counter):
    """Shared-memory tiled matmul: one thread per C[m, n], tileWidth x tileWidth threads per block.

    For each K-tile:
    - every thread stages one element of A and one of B into ``shared_mem``, with the A tile in
      [0, tileWidth**2) and the B tile in [tileWidth**2, 2 * tileWidth**2);
    - the block syncs;
    - each thread dots a row of the A tile with a column of the B tile.

    buffers = (C, A, B) as flat tensors indexed row-major; metadata = (M, K, N, tileWidth).
    ``counter`` tallies global-memory traffic and ``scounter`` shared-memory traffic.

    Every thread must reach every ``barrier.wait()``, like ``__syncthreads()``; a thread that
    skipped one would leave the rest of the block blocked forever. That includes threads whose
    (m, n) lies outside C when M or N is not a multiple of tileWidth. Such threads therefore still
    run the tile loop and stage their operand, zero-padded when out of range, because neighbours
    in the same row/column read it. They skip only the dot product and the final store.
    """
    m = (blockIdx.y * blockSize.y) + threadIdx.y
    n = (blockIdx.x * blockSize.x) + threadIdx.x
    (C, A, B), (M, K, N, tileWidth) = buffers, metadata
    offset = tileWidth ** 2  # start of the B tile inside shared_mem
    slot = fidx(threadIdx.y, threadIdx.x, tileWidth)  # this thread's cell in each shared tile
    owns_output = m < M and n < N
    acc = 0.
    for tile_k in range(cdiv(K, tileWidth)):
        # load tile onto shared memory (0 where the operand falls outside A or B)
        load_a, load_b = 0., 0.
        if m < M and (ak := fidx(tile_k, threadIdx.x, tileWidth)) < K:
            load_a = A[fidx(m, ak, K)]
            counter.increment(loads=1)
        if n < N and (bk := fidx(tile_k, threadIdx.y, tileWidth)) < K:
            load_b = B[fidx(bk, n, N)]
            counter.increment(loads=1)
        shared_mem[slot] = load_a
        shared_mem[slot + offset] = load_b
        scounter.increment(stores=2)
        barrier.wait()  # both tiles fully written before anyone reads them
        # compute dot products on tile
        if owns_output:
            for t in range(tileWidth):
                acc += shared_mem[fidx(threadIdx.y, t, tileWidth)] * shared_mem[fidx(t, threadIdx.x, tileWidth) + offset]
                scounter.increment(loads=2)
        barrier.wait()  # everyone done reading before the next tile overwrites shared_mem
    if owns_output:
        # plain store: this thread alone owns C[m, n], so C's initial contents must not matter
        C[fidx(m, n, N)] = acc
        counter.increment(stores=1)

In [15]:
def matmul_tiled2D(A: Tensor, B: Tensor, tileWidth: int):
    """Multiply A (M x K) by B (K x N) with the shared-memory tiled kernel.

    One thread per C[m, n], tileWidth x tileWidth threads per block, and 2 * tileWidth**2
    shared elements per block (one A tile plus one B tile). Prints the elapsed time, the global
    load/store counts and then the shared-memory counts; returns C (M x N, DTYPE).
    """
    (M, K), (K_, N) = A.shape, B.shape
    assert K == K_, f"inner dims should match, {K} != {K_}"
    A, B = A.contiguous(), B.contiguous()  # returns the tensor itself when already contiguous
    # zeros, not empty: the result must never depend on uninitialised memory
    C = torch.zeros(M, N, dtype=DTYPE)

    blockSize = dim3(y=tileWidth, x=tileWidth)
    gridSize = dim3(y=cdiv(M, tileWidth), x=cdiv(N, tileWidth))
    kernel = launch_cuda_kernel(gridSize, blockSize, matmul_tiled2D_kernel, 2 * tileWidth ** 2)
    # view(-1) always aliases C (or raises), so the kernel's writes land in C
    kernel((C.view(-1), A.view(-1), B.view(-1)), (M, K, N, tileWidth), counter:=Counter(), scounter:=Counter())
    counter.show()
    scounter.show()
    return C

In [16]:
# shared-memory tiled kernel with 3x3 tiles
C = matmul_tiled2D(A, B, tileWidth=3)
print(check_output(C=C))

Time elapsed: 109.05 ms
Total reads: 3456 Total writes: 216
Total reads: 10368 Total writes: 3456
PASSED std=0.0000


# Matmul Naive 3D Tiled (for fun)

In [17]:
# Stands in for CUDA's atomicAdd. Threads of one block that share (m, n) but differ in k all
# update C[m, n] at the same time, and `C[i] += x` is a separate read, add and write.
_ATOMIC_ADD_LOCK = Lock()

def matmul_naive3D_kernel(blockIdx: dim3, threadIdx: dim3, blockSize: dim3, barrier: Barrier, shared_mem: Tensor,
                          buffers: Tuple[Tensor, ...], metadata: Tuple[int, ...], counter: Counter):
    """One thread per (m, n, k) triple: adds the partial product A[m, k] * B[k, n] into C[m, n].

    m comes from the z axis, n from y and k from x. buffers = (C, A, B) as flat tensors indexed
    row-major; metadata = (M, K, N). C must be zero-initialised by the caller because every
    thread accumulates into it. ``barrier`` and ``shared_mem`` are unused.
    """
    m = (blockIdx.z * blockSize.z) + threadIdx.z
    n = (blockIdx.y * blockSize.y) + threadIdx.y
    k = (blockIdx.x * blockSize.x) + threadIdx.x
    (C, A, B), (M, K, N) = buffers, metadata
    if m < M and n < N and k < K:
        partial = A[fidx(m, k, K)] * B[fidx(k, n, N)]
        with _ATOMIC_ADD_LOCK:
            C[fidx(m, n, N)] += partial
        counter.increment(loads=2, stores=1)

In [18]:
def matmul_naive3D(A: Tensor, B: Tensor, tileWidth: int):
    """Multiply A (M x K) by B (K x N) with one thread per (m, n, k) triple.

    Blocks are tileWidth x tileWidth x tileWidth threads; m runs along z, n along y and k along x.
    Threads that share (m, n) each add their partial product into C[m, n], so C is the accumulator.
    Prints the elapsed time and the global load/store counts; returns C (M x N, DTYPE).
    """
    (M, K), (K_, N) = A.shape, B.shape
    assert K == K_, f"inner dims should match, {K} != {K_}"
    A, B = A.contiguous(), B.contiguous()  # returns the tensor itself when already contiguous
    # accumulator: must start at exactly zero (torch.empty would leave arbitrary values)
    C = torch.zeros(M, N, dtype=DTYPE)

    gridSize = dim3(z=cdiv(M, tileWidth), y=cdiv(N, tileWidth), x=cdiv(K, tileWidth))
    blockSize = dim3(z=tileWidth, y=tileWidth, x=tileWidth)
    kernel = launch_cuda_kernel(gridSize, blockSize, matmul_naive3D_kernel)
    # view(-1) always aliases C (or raises), so the kernel's writes land in C
    kernel((C.view(-1), A.view(-1), B.view(-1)), (M, K, N), counter:=Counter())
    counter.show()
    return C

In [19]:
# naive 3D kernel with 3x3x3 thread blocks, one thread per (m, n, k)
C = matmul_naive3D(A, B, tileWidth=3)
print(check_output(C=C))

Time elapsed: 179.80 ms
Total reads: 10368 Total writes: 5184
PASSED std=0.0000
