In [1]:
import torch

In [61]:
a = torch.randint(low=0, high=10, size=(7, 5))
b = torch.randint(low=0, high=10, size=(5, 9))
a, b

(tensor([[1, 1, 8, 2, 7],
         [4, 0, 6, 9, 7],
         [9, 6, 5, 4, 8],
         [4, 2, 4, 8, 8],
         [9, 3, 5, 6, 5],
         [9, 7, 2, 6, 1],
         [9, 8, 4, 9, 6]]),
 tensor([[7, 3, 2, 0, 6, 0, 8, 2, 6],
         [6, 4, 6, 5, 6, 9, 4, 5, 5],
         [3, 1, 9, 7, 3, 1, 8, 4, 3],
         [1, 6, 5, 1, 5, 3, 7, 8, 5],
         [6, 5, 2, 0, 4, 3, 5, 8, 5]]))

In [63]:
from itertools import product

def tiled_matmul(a: torch.Tensor, b: torch.Tensor, tile_width = 2):
    assert a.shape[1] == b.shape[0]
    res = torch.zeros(a.shape[0], b.shape[1], dtype=int)

    # mock shared memory
    sma = torch.empty((tile_width, tile_width), dtype=torch.int64)
    smb = torch.empty((tile_width, tile_width), dtype=torch.int64)

    # number of tiles for outer dims
    m_tiles = a.shape[0] // tile_width + (a.shape[0] % tile_width)
    n_tiles = b.shape[1] // tile_width + (b.shape[1] % tile_width)
    # inner dim tiles
    k_tiles = a.shape[1] // tile_width + (a.shape[1] % tile_width)

    print(m_tiles, k_tiles, n_tiles)
    
    # tiling coordinates
    for i, j in product(range(0, m_tiles*tile_width, tile_width), range(0, n_tiles*tile_width, tile_width)):
        # loop over each phase of tiled matmul
        for p in range(k_tiles):

            sma.zero_(), smb.zero_()

            # each thread loads a value from a and b into shared memory
            # a phases across a row, b phases across a column
            for ti, tj in product(range(tile_width), repeat=2):
                if (i+ti) < a.shape[0] and (tile_width*p+tj) < a.shape[1]:
                    sma[ti][tj] = a[i+ti][tile_width*p+tj]
                if (tile_width*p+ti) < b.shape[0] and (j+tj) < b.shape[1]:
                    smb[ti][tj] = b[tile_width*p+ti][j+tj]

            # need to do two separate loops to fully load in shared memory before dot products
            for ti, tj in product(range(tile_width), repeat=2):
                # dot the tith row of a with the tjth col of b
                if (i+ti) < res.shape[0] and (j+tj) < res.shape[1]:
                    res[i+ti][j+tj] += sum(sma[ti].flatten() * smb[:, tj].flatten()) # each thread does tile_width muls


    return res


print(torch.equal(a @ b, tiled_matmul(a, b, tile_width=4)))
a @ b, tiled_matmul(a, b, tile_width=4)


4 2 3
True
4 2 3


(tensor([[ 81,  62, 104,  63,  74,  44, 125, 111,  80],
         [ 97, 107, 121,  51, 115,  54, 178, 160, 122],
         [166, 120, 135,  69, 157,  95, 204, 164, 159],
         [108, 112, 112,  46, 120,  70, 168, 162, 126],
         [132, 105, 121,  56, 137,  65, 191, 141, 139],
         [123,  98, 110,  55, 136,  86, 163, 117, 130],
         [168, 147, 159,  77, 183, 121, 229, 194, 181]]),
 tensor([[ 81,  62, 104,  63,  74,  44, 125, 111,  80],
         [ 97, 107, 121,  51, 115,  54, 178, 160, 122],
         [166, 120, 135,  69, 157,  95, 204, 164, 159],
         [108, 112, 112,  46, 120,  70, 168, 162, 126],
         [132, 105, 121,  56, 137,  65, 191, 141, 139],
         [123,  98, 110,  55, 136,  86, 163, 117, 130],
         [168, 147, 159,  77, 183, 121, 229, 194, 181]]))