In [1]:
import torch

In [2]:
# Fix the RNG seed so that Restart & Run All regenerates the same input matrices.
torch.manual_seed(0)
a = torch.randint(low=0, high=10, size=(4, 4), dtype=torch.int32)
b = torch.randint(low=0, high=10, size=(4, 4), dtype=torch.int32)
a, b

(tensor([[3, 5, 3, 4],
         [2, 4, 7, 3],
         [3, 5, 6, 0],
         [7, 6, 2, 2]], dtype=torch.int32),
 tensor([[8, 9, 4, 1],
         [1, 7, 1, 8],
         [3, 2, 7, 4],
         [0, 1, 8, 5]], dtype=torch.int32))

In [3]:
from itertools import product

def tiled_matmul(a: torch.Tensor, b: torch.Tensor, tile_width: int = 2) -> torch.Tensor:
    """Compute ``a @ b`` by emulating a shared-memory tiled matmul kernel in Python.

    Each ``tile_width x tile_width`` output tile is produced in phases along the
    inner dimension. In every phase, each mock thread ``(ti, tj)`` first copies one
    element of ``a`` and one of ``b`` into the mock shared-memory tiles. Positions
    outside the matrices are written as 0. Only after the whole tile is loaded does
    each thread add its ``tile_width``-term partial dot product to its output element.

    Args:
        a: 2-D tensor of shape (M, K).
        b: 2-D tensor of shape (K, N).
        tile_width: side length of the square tiles; must be >= 1.

    Returns:
        An (M, N) tensor on ``a``'s device whose dtype is
        ``torch.promote_types(a.dtype, b.dtype)`` (int32 for two int32 inputs).

    Raises:
        AssertionError: if the inputs are not 2-D, their inner dimensions differ,
            or ``tile_width`` is not positive.
    """
    assert a.dim() == 2 and b.dim() == 2, "tiled_matmul expects 2-D tensors"
    assert a.shape[1] == b.shape[0], f"inner dimensions differ: {tuple(a.shape)} @ {tuple(b.shape)}"
    assert tile_width >= 1, f"tile_width must be >= 1, got {tile_width}"

    m, k = a.shape
    n = b.shape[1]
    # Follow the inputs' dtype and device: a hard-coded int32 would silently
    # truncate float inputs, and a CPU-only result breaks for CUDA inputs.
    dtype = torch.promote_types(a.dtype, b.dtype)
    res = torch.zeros(m, n, dtype=dtype, device=a.device)

    # mock shared memory, one tile per operand
    sma = torch.empty((tile_width, tile_width), dtype=dtype, device=a.device)
    smb = torch.empty((tile_width, tile_width), dtype=dtype, device=a.device)

    # number of tiles along each outer dim: ceil(dim / tile_width)
    m_tiles = (m + tile_width - 1) // tile_width
    n_tiles = (n + tile_width - 1) // tile_width
    # inner dim tiles = number of phases per output tile
    k_tiles = (k + tile_width - 1) // tile_width

    # (i, j) is the top-left corner of the current output tile
    for i, j in product(range(0, m_tiles * tile_width, tile_width), range(0, n_tiles * tile_width, tile_width)):
        # loop over each phase of tiled matmul
        for p in range(k_tiles):
            offset = p * tile_width  # inner-dim start of this phase

            # each thread loads one value from a and one from b into shared memory;
            # a phases across a row, b phases down a column; out-of-range slots are zero-padded
            for ti, tj in product(range(tile_width), repeat=2):
                if i + ti < m and offset + tj < k:
                    sma[ti, tj] = a[i + ti, offset + tj]
                else:
                    sma[ti, tj] = 0
                if offset + ti < k and j + tj < n:
                    smb[ti, tj] = b[offset + ti, j + tj]
                else:
                    smb[ti, tj] = 0

            # separate loop: the full tile must be loaded before any thread reads it
            # (a GPU kernel would need a barrier between these two steps)
            for ti, tj in product(range(tile_width), repeat=2):
                if i + ti < m and j + tj < n:
                    # dot the ti-th row of sma with the tj-th column of smb: tile_width muls per thread
                    res[i + ti, j + tj] += (sma[ti] * smb[:, tj]).sum()

    return res


# Check against PyTorch's matmul for several tile widths. tile_width=3 does not
# divide 4, so it exercises the zero-padded edge tiles; the sliced 3x4 @ 4x2 case
# exercises non-square shapes.
for tw in (1, 2, 3, 4):
    assert torch.equal(a @ b, tiled_matmul(a, b, tile_width=tw)), f"mismatch for tile_width={tw}"
    assert torch.equal(a[:3] @ b[:, :2], tiled_matmul(a[:3], b[:, :2], tile_width=tw)), f"non-square mismatch for tile_width={tw}"

a @ b, tiled_matmul(a, b, tile_width=4)


True


(tensor([[ 38,  72,  70,  75],
         [ 41,  63,  85,  77],
         [ 47,  74,  59,  67],
         [ 68, 111,  64,  73]], dtype=torch.int32),
 tensor([[ 38,  72,  70,  75],
         [ 41,  63,  85,  77],
         [ 47,  74,  59,  67],
         [ 68, 111,  64,  73]], dtype=torch.int32))

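## Load the CUDA version

JIT-compile `main.cpp` and `tiled_matmul.cu` into a PyTorch extension with `torch.utils.cpp_extension.load`, then check its `matmul` against `a @ b`.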

In [37]:
from torch.utils.cpp_extension import load
# JIT-compile the C++ binding and the CUDA kernel into an importable extension module.
# The source paths are relative, presumably resolved against the kernel's working
# directory, so run this notebook from the folder that holds both files.
module = load(name='m', sources=['main.cpp', 'tiled_matmul.cu'], verbose=True)



Using /home/seb/.cache/torch_extensions/py312_cu121 as PyTorch extensions root...
The input conditions for extension module m have changed. Bumping to version 4 and re-building as m_v4...
Detected CUDA files, patching ldflags
Emitting ninja build file /home/seb/.cache/torch_extensions/py312_cu121/m/build.ninja...
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
Building extension module m_v4...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)


[1/3] c++ -MMD -MF main.o.d -DTORCH_EXTENSION_NAME=m_v4 -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -isystem /home/seb/CUDA/cudaenv/lib/python3.12/site-packages/torch/include -isystem /home/seb/CUDA/cudaenv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -isystem /home/seb/CUDA/cudaenv/lib/python3.12/site-packages/torch/include/TH -isystem /home/seb/CUDA/cudaenv/lib/python3.12/site-packages/torch/include/THC -isystem /usr/local/cuda-12.3/include -isystem /home/seb/miniconda3/include/python3.12 -D_GLIBCXX_USE_CXX11_ABI=0 -fPIC -std=c++17 -c /home/seb/CUDA/pmpp/tiled_matmul/main.cpp -o main.o 
[2/3] /usr/local/cuda-12.3/bin/nvcc --generate-dependencies-with-compile --dependency-output tiled_matmul.cuda.o.d -DTORCH_EXTENSION_NAME=m_v4 -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -isystem /ho

Loading extension module m_v4...


In [38]:
# Non-square float inputs on the GPU: (4, 2) @ (2, 5) -> (4, 5).
a = torch.rand((4, 2), device='cuda')
b = torch.rand((2, 5), device='cuda')
a, b

tensor([[0.3178, 0.7526],
        [0.0085, 0.3547],
        [0.6547, 0.2386],
        [0.5069, 0.3096]], device='cuda:0') tensor([[0.6467, 0.8290, 0.1769, 0.1295, 0.7825],
        [0.3927, 0.2491, 0.2163, 0.6451, 0.7824]], device='cuda:0')


In [39]:
# Compare the CUDA kernel with PyTorch's matmul; assert_close applies float
# tolerances suited to the dtype instead of exact equality.
cuda_result = module.matmul(a, b)
torch.testing.assert_close(cuda_result, a @ b)
cuda_result, a @ b

(tensor([[0.5011, 0.4509, 0.2190, 0.5267, 0.8375],
         [0.1448, 0.0954, 0.0782, 0.2299, 0.2842],
         [0.5171, 0.6021, 0.1674, 0.2387, 0.6989],
         [0.4494, 0.4973, 0.1567, 0.2654, 0.6388]], device='cuda:0'),
 tensor([[0.5011, 0.4509, 0.2190, 0.5267, 0.8375],
         [0.1448, 0.0954, 0.0782, 0.2299, 0.2842],
         [0.5171, 0.6021, 0.1674, 0.2387, 0.6989],
         [0.4494, 0.4973, 0.1567, 0.2654, 0.6388]], device='cuda:0'))