In [1]:
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from functools import wraps
from itertools import product
from threading import Barrier, BrokenBarrierError, Lock
from time import perf_counter
from typing import Callable

import torch
from torch import Tensor

In [2]:
# Compact, non-scientific printing of tensors, plus the dtype every tensor below is allocated with.
torch.set_printoptions(sci_mode=False, linewidth=140, precision=2)
DTYPE = torch.float32

In [3]:
class Counter:
    """Tally of simulated memory reads and writes reported by kernel threads.

    Kernel threads run concurrently on a ThreadPoolExecutor and call `inc_reads` /
    `inc_writes` from many threads at once. `self.reads += n` is a separate load, add
    and store that another thread can interleave with, so every update holds a lock to
    keep the totals exact.
    """
    def __init__(self):
        self.reads, self.writes = 0, 0
        self._lock = Lock()  # serializes the read-modify-write updates below

    def inc_reads(self, n):
        """Add `n` to the read count."""
        with self._lock:
            self.reads += n

    def inc_writes(self, n):
        """Add `n` to the write count."""
        with self._lock:
            self.writes += n

    def show(self):
        """Print both totals."""
        print(f"Total reads: {self.reads}\nTotal writes: {self.writes}")

def timeit(func: Callable):
    """Decorator: call `func`, print its wall-clock time in milliseconds, return its result.

    `functools.wraps` copies the wrapped function's name and docstring onto the wrapper,
    so decorated functions still identify themselves correctly.
    """
    @wraps(func)
    def timer(*args, **kwargs):
        t1 = perf_counter()
        out = func(*args, **kwargs)
        elapsed_ms = (perf_counter() - t1) * 1000
        print(f"Time elapsed: {elapsed_ms:.2f} ms")
        return out
    return timer

def cdiv(a: int, b: int):
    """Integer ceiling division: the number of size-`b` chunks needed to cover `a` (for b > 0).

    Matches math.ceil(a / b) for positive `b`, using only integer arithmetic.
    """
    numerator = a + b - 1
    return numerator // b

@dataclass
class D3:
    """Three-component (z, y, x) triple used in this notebook both for launch dimensions
    (blocks / threads per block) and for block / thread indices. Every component defaults to 1.
    """
    z: int = 1
    y: int = 1
    x: int = 1

    @property
    def size(self):
        """Product of the three components, e.g. the number of threads in a block."""
        return self.z * self.y * self.x

In [4]:
def launch_cuda_kernel(blocks: D3, threads: D3, malloc_shared_size: int=0, kernel: Callable=None):
    """Build a function that emulates a CUDA-style launch of `kernel` on Python threads.

    Every (block, thread) pair becomes one task on a ThreadPoolExecutor with one worker per
    task, so all threads can be alive at the same time. Each block gets its own `shared_mem`
    buffer of `malloc_shared_size` uninitialized DTYPE elements and its own Barrier over that
    block's threads (the analogue of __syncthreads()). Every thread of a block must therefore
    reach each `barrier.wait()`.

    The returned function takes the kernel's trailing arguments and calls
    `kernel(blockIdx, threadIdx, threads, barrier, shared_mem, *kernargs)` once per thread.
    It is wrapped in `timeit`, so each call prints its elapsed time.

    Raises:
        ValueError: if `kernel` is not given.
        The returned function re-raises the first exception raised by any kernel thread.
    """
    if kernel is None:
        raise ValueError("kernel must be provided")

    @timeit
    def dispatch_kernel(*kernargs):
        workers = blocks.size * threads.size
        if workers == 0:
            return  # empty grid: nothing to run (ThreadPoolExecutor rejects max_workers=0)

        def run_thread(blockIdx, threadIdx, barrier, shared_mem):
            try:
                kernel(blockIdx, threadIdx, threads, barrier, shared_mem, *kernargs)
            except BaseException:
                # Break this block's barrier so its other threads fail fast instead of waiting forever.
                barrier.abort()
                raise

        futures = []
        with ThreadPoolExecutor(max_workers=workers) as pool:
            for bz, by, bx in product(range(blocks.z), range(blocks.y), range(blocks.x)):
                shared_mem = torch.empty((malloc_shared_size,), dtype=DTYPE)
                block_barrier = Barrier(threads.size)
                for tz, ty, tx in product(range(threads.z), range(threads.y), range(threads.x)):
                    futures.append(
                        pool.submit(run_thread, D3(bz, by, bx), D3(tz, ty, tx), block_barrier, shared_mem))

        # The executor only stores task exceptions on the futures, so surface them here. Threads that
        # merely hit an aborted barrier are collateral damage, so prefer the original error.
        errors = [err for err in (f.exception() for f in futures) if err is not None]
        if errors:
            raise next((err for err in errors if not isinstance(err, BrokenBarrierError)), errors[0])
    return dispatch_kernel

In [5]:
# test example
torch.manual_seed(0)  # fixed seed: Restart & Run All regenerates the same A and B
M, K, N = 12, 24, 15
A = torch.randn(M, K, dtype=DTYPE)
B = torch.randn(K, N, dtype=DTYPE)
C_ref = A @ B  # reference product from PyTorch's built-in matmul

In [6]:
def check(C: Tensor, ref: "Tensor | None" = None, atol: float = 1e-6):
    """Return "PASSED" if `C` matches the reference result, else "FAILED".

    `ref` defaults to the notebook-level `C_ref`, looked up at call time, so existing
    `check(C)` calls keep working. Pass it explicitly to compare against something else.
    A shape mismatch is reported as "FAILED" instead of being broadcast by `torch.allclose`,
    which would otherwise compare mismatched shapes or raise for incompatible ones.
    """
    if ref is None:
        ref = C_ref
    if C.shape != ref.shape:
        return "FAILED"
    return "PASSED" if torch.allclose(C, ref, atol=atol) else "FAILED"

# Matmul Naive 1D

In [7]:
def matmul_naive_kernel(blockIdx: D3, threadIdx: D3, blockSize: D3, barrier: Barrier, shared_mem: Tensor,
                        C: Tensor, A: Tensor, B: Tensor, M: int, K: int, N: int, counter: Counter):
    """One simulated thread: compute a single element of the flattened M x N output.

    The global 1D thread index is split into (row, col) of C. Threads whose index falls past
    M * N do nothing. `barrier` and `shared_mem` are unused here; they are part of the
    launcher's common kernel signature.
    """
    idx = blockIdx.x * blockSize.x + threadIdx.x
    if idx >= M * N:
        return  # surplus thread in the last block
    row, col = divmod(idx, N)
    total = 0.
    for k in range(K):
        total += A[row * K + k] * B[k * N + col]
        counter.inc_reads(2)  # one element of A, one of B
    C[idx] = total
    counter.inc_writes(1)

In [8]:
def matmul_naive(A: Tensor, B: Tensor, threads: int = 8):
    """Compute C = A @ B with the 1D naive kernel: one simulated thread per output element.

    Args:
        A: (M, K) matrix.
        B: (K, N) matrix.
        threads: threads per block (default 8). The grid has cdiv(M * N, threads) blocks.

    Returns:
        (M, N) tensor of dtype DTYPE. Elapsed time and read/write totals are printed.
    """
    (M, K), (K_, N) = A.shape, B.shape
    assert K == K_, f"inner dims should match! {K} != {K_}"
    A, B = A.contiguous(), B.contiguous()  # .contiguous() returns self when already contiguous
    # Zero-initialized so C never depends on uninitialized memory.
    C = torch.zeros(M, N, dtype=DTYPE)

    blocks = cdiv(M * N, threads)
    kernel = launch_cuda_kernel(D3(x=blocks), D3(x=threads), kernel=matmul_naive_kernel)
    # view(-1) always aliases C's storage (flatten() only promises a view "may" be returned),
    # so the kernel's writes are guaranteed to land in C.
    kernel(C.view(-1), A.view(-1), B.view(-1), M, K, N, counter:=Counter())
    counter.show()
    return C

In [9]:
# 1D launch: one simulated thread per element of the flattened M x N output.
C = matmul_naive(A=A, B=B)
print(check(C))

Time elapsed: 98.41 ms
Total reads: 8640
Total writes: 180
PASSED


# Matmul Naive 2D Tiled

In [10]:
def matmul_naive2D_kernel(blockIdx: D3, threadIdx: D3, blockSize: D3, barrier: Barrier, shared_mem: Tensor,
                          C: Tensor, A: Tensor, B: Tensor, M: int, K: int, N: int, tileWidth: int, counter: Counter):
    """One simulated thread of a 2D grid: compute C[idx_m, idx_n], walking K in tileWidth chunks.

    Threads outside the M x N output do nothing. Each in-bounds output element is owned by
    exactly one thread, so the result is assigned rather than accumulated with `+=`, which
    would add onto whatever the caller's buffer held (e.g. uninitialized `torch.empty` memory).
    `barrier` and `shared_mem` are unused here.
    """
    idx_m = (blockIdx.y * blockSize.y) + threadIdx.y
    idx_n = (blockIdx.x * blockSize.x) + threadIdx.x
    if idx_m >= M or idx_n >= N:
        return
    acc = 0.
    for tile_k in range(cdiv(K, tileWidth)):
        k_start = tile_k * tileWidth
        # the last chunk may be shorter than tileWidth when tileWidth does not divide K
        for idx_k in range(k_start, min(k_start + tileWidth, K)):
            acc += A[idx_m * K + idx_k] * B[idx_k * N + idx_n]
            counter.inc_reads(2)
    C[idx_m * N + idx_n] = acc
    counter.inc_writes(1)

In [11]:
def matmul_naive2D(A: Tensor, B: Tensor, tileWidth: int):
    """Compute C = A @ B on a 2D grid of tileWidth x tileWidth thread blocks (one thread per output).

    Args:
        A: (M, K) matrix.
        B: (K, N) matrix.
        tileWidth: edge length, in threads, of each square block.

    Returns:
        (M, N) tensor of dtype DTYPE. Elapsed time and read/write totals are printed.
    """
    (M, K), (K_, N) = A.shape, B.shape
    assert K == K_, f"inner dims should match! {K} != {K_}"
    A, B = A.contiguous(), B.contiguous()  # .contiguous() returns self when already contiguous
    # Zero-initialized (not torch.empty) so the result never depends on uninitialized memory,
    # whether the kernel assigns into C or accumulates into it.
    C = torch.zeros(M, N, dtype=DTYPE)

    blocks_m, blocks_n = cdiv(M, tileWidth), cdiv(N, tileWidth)
    kernel = launch_cuda_kernel(
        D3(y=blocks_m, x=blocks_n), D3(y=tileWidth, x=tileWidth), kernel=matmul_naive2D_kernel)
    # view(-1) always aliases C's storage, so the kernel's writes are guaranteed to land in C.
    kernel(C.view(-1), A.view(-1), B.view(-1), M, K, N, tileWidth, counter:=Counter())
    counter.show()
    return C

In [12]:
# 2D launch with square thread blocks; tileWidth is the block edge length in threads.
C = matmul_naive2D(A, B, tileWidth=5)
print(check(C))

Time elapsed: 117.39 ms
Total reads: 8640
Total writes: 180
PASSED


# Matmul 2D Tiled (Shared Memory Optimization)

In [13]:
def matmul_tiled2D_kernel(blockIdx: D3, threadIdx: D3, blockSize: D3, barrier: Barrier, shared_mem: Tensor,
                          C: Tensor, A: Tensor, B: Tensor, M: int, K: int, N: int, tileWidth: int, counter: Counter):
    """One simulated thread of the shared-memory tiled matmul: computes C[idx_m, idx_n].

    Each tileWidth-wide slice of K is handled in four steps:
    - the block's threads stage an A tile in shared_mem[0 : tileWidth**2];
    - they stage a B tile in shared_mem[tileWidth**2 : 2 * tileWidth**2], so shared_mem
      must hold at least 2 * tileWidth**2 elements;
    - all threads wait on the barrier;
    - each in-bounds thread adds its partial dot product from the staged tiles.

    Every thread of the block loads its tile cells (zero when out of range) and reaches
    every `barrier.wait()`, including threads whose output lies outside M x N. If those
    threads skipped the work, other threads would read unstaged shared memory and wait on a
    barrier that can never fill. That hangs the launch when tileWidth does not divide M or N.
    """
    idx_m = (blockIdx.y * blockSize.y) + threadIdx.y
    idx_n = (blockIdx.x * blockSize.x) + threadIdx.x
    in_bounds = idx_m < M and idx_n < N
    shared_offset = tileWidth ** 2                 # start of the B tile within shared_mem
    slot = threadIdx.y * tileWidth + threadIdx.x   # this thread's cell within each tile
    acc = 0.
    for tile_k in range(cdiv(K, tileWidth)):
        k_base = tile_k * tileWidth
        # load tile onto shared memory (zero-fill cells that fall outside A or B)
        idx_ak = k_base + threadIdx.x
        idx_bk = k_base + threadIdx.y
        load_a = idx_m < M and idx_ak < K
        load_b = idx_n < N and idx_bk < K
        shared_mem[slot] = A[idx_m * K + idx_ak] if load_a else 0
        shared_mem[shared_offset + slot] = B[idx_bk * N + idx_n] if load_b else 0
        if load_a: counter.inc_reads(1)
        if load_b: counter.inc_reads(1)
        barrier.wait()  # both tiles fully staged before anyone reads them
        # compute dot products on tile (the last tile may be partial when tileWidth does not divide K)
        if in_bounds:
            for k in range(min(tileWidth, K - k_base)):
                acc += shared_mem[threadIdx.y * tileWidth + k] * shared_mem[shared_offset + k * tileWidth + threadIdx.x]
        barrier.wait()  # everyone done reading before the next iteration overwrites the tiles
    if in_bounds:
        # each output element has exactly one owner, so assign rather than accumulate into C
        C[idx_m * N + idx_n] = acc
        counter.inc_writes(1)

In [14]:
def matmul_tiled2D(A: Tensor, B: Tensor, tileWidth: int):
    """Compute C = A @ B with the shared-memory tiled kernel on tileWidth x tileWidth blocks.

    Args:
        A: (M, K) matrix.
        B: (K, N) matrix.
        tileWidth: edge length of each square block and of the A / B tiles. Each block
            gets 2 * tileWidth**2 elements of shared memory, one tile of A and one of B.

    Returns:
        (M, N) tensor of dtype DTYPE. Elapsed time and read/write totals are printed.
    """
    (M, K), (K_, N) = A.shape, B.shape
    assert K == K_, f"inner dims should match! {K} != {K_}"
    A, B = A.contiguous(), B.contiguous()  # .contiguous() returns self when already contiguous
    # Zero-initialized (not torch.empty) so the result never depends on uninitialized memory,
    # whether the kernel assigns into C or accumulates into it.
    C = torch.zeros(M, N, dtype=DTYPE)

    blocks_m, blocks_n = cdiv(M, tileWidth), cdiv(N, tileWidth)
    kernel = launch_cuda_kernel(
        D3(y=blocks_m, x=blocks_n), D3(y=tileWidth, x=tileWidth), 2*tileWidth**2, kernel=matmul_tiled2D_kernel)
    # view(-1) always aliases C's storage, so the kernel's writes are guaranteed to land in C.
    kernel(C.view(-1), A.view(-1), B.view(-1), M, K, N, tileWidth, counter:=Counter())
    counter.show()
    return C

In [15]:
# Tiled launch: each square block stages tileWidth x tileWidth tiles of A and B in shared memory.
C = matmul_tiled2D(A, B, tileWidth=3)
print(check(C))

Time elapsed: 164.31 ms
Total reads: 2880
Total writes: 180
PASSED


# Matmul Naive 3D Tiled (for fun)

In [16]:
# Shared by every matmul_naive3D_kernel thread; serializes the `+=` into C (stand-in for CUDA's atomicAdd).
_naive3D_atomic_lock = Lock()

def matmul_naive3D_kernel(blockIdx: D3, threadIdx: D3, blockSize: D3, barrier: Barrier, shared_mem: Tensor,
                          C: Tensor, A: Tensor, B: Tensor, M: int, K: int, N: int, counter: Counter):
    """One simulated thread per (m, n, k) triple: adds A[m, k] * B[k, n] into C[m, n].

    K different threads add into the same C element concurrently, so the read-modify-write
    on C happens under a lock to prevent lost updates. The caller must zero-initialize C.
    The counter updates are made under the same lock. `barrier` and `shared_mem` are unused.
    """
    idx_m = (blockIdx.z * blockSize.z) + threadIdx.z
    idx_n = (blockIdx.y * blockSize.y) + threadIdx.y
    idx_k = (blockIdx.x * blockSize.x) + threadIdx.x
    if idx_m < M and idx_n < N and idx_k < K:
        product_term = A[idx_m * K + idx_k] * B[idx_k * N + idx_n]  # computed outside the lock
        with _naive3D_atomic_lock:
            C[idx_m * N + idx_n] += product_term
            counter.inc_reads(2)
            counter.inc_writes(1)

In [17]:
def matmul_naive3D(A: Tensor, B: Tensor, tileWidth: int):
    """Compute C = A @ B on a 3D grid with one simulated thread per (m, n, k) product term.

    Blocks are tileWidth x tileWidth x tileWidth threads. The grid's z axis covers M, its
    y axis covers N and its x axis covers K.

    Args:
        A: (M, K) matrix.
        B: (K, N) matrix.
        tileWidth: edge length, in threads, of each cubic block.

    Returns:
        (M, N) tensor of dtype DTYPE. Elapsed time and read/write totals are printed.
    """
    (M, K), (K_, N) = A.shape, B.shape
    assert K == K_, f"inner dims should match! {K} != {K_}"
    A, B = A.contiguous(), B.contiguous()  # .contiguous() returns self when already contiguous
    # Must start at zero: the kernel accumulates K partial products into each C element with `+=`.
    C = torch.zeros(M, N, dtype=DTYPE)

    blocks_m, blocks_n, blocks_k = cdiv(M, tileWidth), cdiv(N, tileWidth), cdiv(K, tileWidth)
    kernel = launch_cuda_kernel(
        D3(z=blocks_m, y=blocks_n, x=blocks_k), D3(z=tileWidth, y=tileWidth, x=tileWidth),
        kernel=matmul_naive3D_kernel)
    # view(-1) always aliases C's storage, so the kernel's writes are guaranteed to land in C.
    kernel(C.view(-1), A.view(-1), B.view(-1), M, K, N, counter:=Counter())
    counter.show()
    return C

In [18]:
# 3D launch: cubic blocks, one simulated thread per (m, n, k) product term.
C = matmul_naive3D(A, B, tileWidth=3)
print(check(C))

Time elapsed: 291.22 ms
Total reads: 8640
Total writes: 4320
PASSED
