In [1]:
import torch

In [None]:
# Small random integer matrices for checking the CPU mock below. Seeded so a
# Restart & Run All reproduces the same inputs (and the same displayed output).
torch.manual_seed(0)
a = torch.randint(low=0, high=10, size=(4, 4), dtype=torch.int32)
b = torch.randint(low=0, high=10, size=(4, 4), dtype=torch.int32)
a, b

In [None]:
from itertools import product

def tiled_matmul(a: torch.Tensor, b: torch.Tensor, tile_width: int = 2) -> torch.Tensor:
    """Compute ``a @ b`` by mimicking a shared-memory tiled matmul kernel, step by step.

    The output is walked in ``tile_width x tile_width`` tiles (one mock thread
    block each). For every output tile the inner dimension is covered in phases
    of ``tile_width``: in each phase every mock thread ``(ti, tj)`` copies one
    element of ``a`` and one of ``b`` into mock "shared memory" buffers
    (writing 0 where the tile hangs off the matrix edge), and only after the
    whole tile is loaded does each thread accumulate a ``tile_width``-long dot
    product into its output element.

    Args:
        a: 2-D tensor of shape (M, K).
        b: 2-D tensor of shape (K, N).
        tile_width: side length of the square tiles. It need not divide M, N or K;
            partial edge tiles are zero-padded.

    Returns:
        Tensor of shape (M, N) equal to ``a @ b``, with dtype
        ``torch.result_type(a, b)``, allocated on ``a``'s device.

    Raises:
        AssertionError: if the inner dimensions of ``a`` and ``b`` differ.
        ValueError: if ``tile_width`` is smaller than 1.
    """
    assert a.shape[1] == b.shape[0], (
        f"inner dimensions differ: {tuple(a.shape)} @ {tuple(b.shape)}"
    )
    if tile_width < 1:
        raise ValueError(f"tile_width must be >= 1, got {tile_width}")

    m, k = a.shape[0], a.shape[1]
    n = b.shape[1]
    # Follow the inputs' dtype/device instead of forcing int32, so float inputs
    # are not silently truncated.
    dtype = torch.result_type(a, b)
    res = torch.zeros(m, n, dtype=dtype, device=a.device)

    # mock shared memory: one tile of `a` and one tile of `b` per phase
    sma = torch.empty((tile_width, tile_width), dtype=dtype, device=a.device)
    smb = torch.empty((tile_width, tile_width), dtype=dtype, device=a.device)

    # Stepping by tile_width up to the extent visits ceil(extent / tile_width)
    # tiles -- exactly the tiles that touch the matrix, no wasted empty ones.
    # (i, j) is the top-left corner of the current output tile.
    for i, j in product(range(0, m, tile_width), range(0, n, tile_width)):
        # each phase covers inner-dimension indices [kk, kk + tile_width)
        for kk in range(0, k, tile_width):

            # load phase: thread (ti, tj) copies one value of `a` (tile moves
            # along a row band) and one of `b` (tile moves down a column band);
            # out-of-range slots become 0 so they add nothing to the dot product
            for ti, tj in product(range(tile_width), repeat=2):
                if i + ti < m and kk + tj < k:
                    sma[ti, tj] = a[i + ti, kk + tj]
                else:
                    sma[ti, tj] = 0
                if kk + ti < k and j + tj < n:
                    smb[ti, tj] = b[kk + ti, j + tj]
                else:
                    smb[ti, tj] = 0

            # compute phase: a separate loop so the tile is fully loaded before
            # any dot product (in a CUDA kernel a barrier would sit here)
            for ti, tj in product(range(tile_width), repeat=2):
                if i + ti < m and j + tj < n:
                    # row ti of the `a` tile . column tj of the `b` tile:
                    # tile_width multiplies per thread
                    res[i + ti, j + tj] += (sma[ti] * smb[:, tj]).sum(dtype=dtype)

    return res


# Check the mock against PyTorch's reference matmul for tile widths that divide
# the matrix sides evenly and ones that do not, so the zero-padded edge tiles
# are exercised too. A mismatch raises instead of printing an easy-to-miss False.
for tw in (1, 2, 3, 4, 5):
    assert torch.equal(a @ b, tiled_matmul(a, b, tile_width=tw)), f"mismatch at tile_width={tw}"
a @ b, tiled_matmul(a, b, tile_width=4)


## Build and load the CUDA extension

In [2]:
from torch.utils.cpp_extension import load
# JIT-compile main.cpp and tiled_matmul.cu into an extension named 'm' and
# import it as `module`; verbose=True echoes the ninja build log below.
# The source paths are relative -- presumably resolved against the kernel's
# working directory, so run the notebook from the folder containing them.
module = load(sources=['main.cpp', 'tiled_matmul.cu'], name='m', verbose=True)



Using /home/seb/.cache/torch_extensions/py312_cu121 as PyTorch extensions root...
Detected CUDA files, patching ldflags
Emitting ninja build file /home/seb/.cache/torch_extensions/py312_cu121/m/build.ninja...
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
Building extension module m...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)


[1/3] c++ -MMD -MF main.o.d -DTORCH_EXTENSION_NAME=m -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -isystem /home/seb/CUDA/cudaenv/lib/python3.12/site-packages/torch/include -isystem /home/seb/CUDA/cudaenv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -isystem /home/seb/CUDA/cudaenv/lib/python3.12/site-packages/torch/include/TH -isystem /home/seb/CUDA/cudaenv/lib/python3.12/site-packages/torch/include/THC -isystem /usr/local/cuda-12.3/include -isystem /home/seb/miniconda3/include/python3.12 -D_GLIBCXX_USE_CXX11_ABI=0 -fPIC -std=c++17 -c /home/seb/CUDA/pmpp/tiled_matmul/main.cpp -o main.o 
[2/3] /usr/local/cuda-12.3/bin/nvcc --generate-dependencies-with-compile --dependency-output tiled_matmul.cuda.o.d -DTORCH_EXTENSION_NAME=m -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -isystem /home/seb

Loading extension module m...


In [5]:
# Random float inputs on the GPU for the CUDA extension, seeded so re-runs
# reproduce the same inputs. NOTE: this rebinds `a` and `b` (previously the
# int32 CPU matrices) to float CUDA tensors.
torch.manual_seed(0)
a = torch.rand((4, 4), device='cuda')
b = torch.rand((4, 4), device='cuda')
a, b

tensor([[0.5696, 0.0066, 0.3270, 0.8539],
        [0.9236, 0.4954, 0.7321, 0.8394],
        [0.9924, 0.8557, 0.0706, 0.9079],
        [0.9878, 0.3277, 0.0740, 0.4086]], device='cuda:0') tensor([[0.6172, 0.9494, 0.5361, 0.2185],
        [0.5762, 0.7989, 0.6186, 0.2083],
        [0.6684, 0.0917, 0.2277, 0.4593],
        [0.7602, 0.7561, 0.9238, 0.3692]], device='cuda:0')


In [6]:
# Verify the CUDA kernel against PyTorch's reference matmul. Float sums may be
# accumulated in a different order, so compare with a tolerance, not exactly.
cuda_result = module.matmul(a, b)
torch.testing.assert_close(cuda_result, a @ b)
cuda_result, a @ b

(tensor([[1.2231, 1.2217, 1.1727, 0.5913],
         [1.9829, 1.9744, 1.7437, 0.9511],
         [1.8429, 2.3187, 1.9161, 0.7627],
         [1.1586, 1.5153, 1.1266, 0.4689]], device='cuda:0'),
 tensor([[1.2231, 1.2217, 1.1727, 0.5913],
         [1.9829, 1.9744, 1.7437, 0.9511],
         [1.8429, 2.3187, 1.9161, 0.7627],
         [1.1586, 1.5153, 1.1266, 0.4689]], device='cuda:0'))