In [1]:
import torch

In [55]:
a = torch.randint(low=0, high=10, size=(6, 4))
b = torch.randint(low=0, high=10, size=(4, 8))
a, b

(tensor([[8, 2, 4, 5],
         [4, 2, 2, 6],
         [5, 5, 4, 8],
         [3, 6, 8, 9],
         [5, 7, 3, 6],
         [8, 1, 9, 9]]),
 tensor([[0, 4, 4, 6, 9, 3, 4, 1],
         [8, 4, 4, 8, 5, 6, 1, 6],
         [2, 5, 1, 9, 9, 9, 4, 0],
         [6, 5, 7, 7, 9, 8, 5, 7]]))

In [56]:
from itertools import product

def tiled_matmul(a: torch.Tensor, b: torch.Tensor, tile_width = 2):
    assert a.shape[1] == b.shape[0]
    res = torch.zeros(a.shape[0], b.shape[1], dtype=int)

    # mock shared memory
    sma = torch.empty((tile_width, tile_width), dtype=torch.int64)
    smb = torch.empty((tile_width, tile_width), dtype=torch.int64)

    # number of tiles for outer dims
    m_tiles = a.shape[0] // tile_width + (a.shape[0] % tile_width)
    n_tiles = b.shape[1] // tile_width + (b.shape[1] % tile_width)
    # inner dim tiles
    k_tiles = a.shape[1] // tile_width + (a.shape[1] % tile_width)

    print(m_tiles, k_tiles, n_tiles)
    
    # tiling coordinates
    for i, j in product(range(0, m_tiles*tile_width, tile_width), range(0, n_tiles*tile_width, tile_width)):
        # loop over each phase of tiled matmul
        for p in range(k_tiles):

            sma.zero_(), smb.zero_()
            # print(f'({i, tile_width*p}), ({i, tile_width*(p+1)})')
            # print(f'({i+tile_width, tile_width*p}), ({i+tile_width, tile_width*(p+1)})')
            # load in shared memory
            # matrix a, we go across a row
            # matrix b, go across column
            ai_range = slice(i, min(i+tile_width, a.shape[0]))
            aj_range = slice(tile_width*p, min(tile_width*(p+1), a.shape[1]))
            bi_range = aj_range
            bj_range = slice(j, min(j+tile_width, b.shape[1]))

            sma = a[ai_range, aj_range].clone() #clone() prevents this from being a view on global tensor
            smb = b[bi_range, bj_range].clone()

            # print(f'Multiplying a tile i({i, i+tile_width}) j({tile_width*p, tile_width*(p+1)})')
            # print(f'Multiplying b tile i({tile_width*p, tile_width*(p+1)}) j({j, j+tile_width})\n')
            for ti, tj in product(range(tile_width), repeat=2):
                # dot the tith row of a with the tjth col of b
                res[i+ti][j+tj] += sum(sma[ti].flatten() * smb[:, tj].flatten()) # each thread does tile_width muls


    return res


print(torch.equal(a @ b, tiled_matmul(a, b)))
a @ b, tiled_matmul(a, b)


3 2 4
True
3 2 4


(tensor([[ 54,  85,  79, 135, 163, 112,  75,  55],
         [ 56,  64,  68, 100, 118,  90,  56,  58],
         [ 96, 100, 100, 162, 178, 145,  81,  91],
         [118, 121, 107, 201, 210, 189,  95, 102],
         [ 98,  93,  93, 155, 161, 132,  69,  89],
         [ 80, 126, 108, 200, 239, 183, 114,  77]]),
 tensor([[ 54,  85,  79, 135, 163, 112,  75,  55],
         [ 56,  64,  68, 100, 118,  90,  56,  58],
         [ 96, 100, 100, 162, 178, 145,  81,  91],
         [118, 121, 107, 201, 210, 189,  95, 102],
         [ 98,  93,  93, 155, 161, 132,  69,  89],
         [ 80, 126, 108, 200, 239, 183, 114,  77]]))