In [1]:
import os, sys
sys.path.insert(0, '..')

In [2]:
import torch
from torchvision.io import read_image, write_png
from torch.utils.cpp_extension import load_inline
from utils import *

In [3]:
from numba import cuda
from numba.cuda import as_cuda_array as ca
import numpy as np
import math

## Numba matmul

In [4]:
@cuda.jit
def matmul_k(m, n, out, tw):
    """Tiled matrix multiply on the GPU: each thread produces one element of ``out``.

    The inner dimension is walked in chunks of ``tw``. For every chunk the
    threads of a block copy one ``tw x tw`` tile of ``m`` and one of ``n`` into
    shared memory (zero-filling positions that fall outside either matrix),
    synchronise, and then add that chunk's partial dot product to their
    running sum.

    Args:
        m: 2-D device array, indexed ``m[row, inner]``.
        n: 2-D device array, indexed ``n[inner, col]``.
        out: 2-D device array written at ``out[row, col]``.
        tw: tile width. The tile indexing uses ``threadIdx.y * tw + threadIdx.x``,
            so the kernel expects to be launched with ``tw x tw`` thread blocks.

    The shared array is declared with size 0, i.e. it is sized by the launch
    configuration, which must provide room for ``2 * tw * tw`` float32 values.
    """
    block, dims, thread = cuda.blockIdx, cuda.blockDim, cuda.threadIdx
    tx, ty = thread.x, thread.y
    row = block.y * dims.y + ty
    col = block.x * dims.x + tx
    n_rows, inner = m.shape
    _, n_cols = n.shape
    tile_len = tw * tw

    # One shared buffer, split in two: first half holds the m tile, second half the n tile.
    smem = cuda.shared.array(0, dtype=np.float32)
    m_tile = smem[:tile_len]
    n_tile = smem[tile_len:2 * tile_len]

    acc = np.float32(0.0)
    for phase in range(math.ceil(inner / tw)):
        base = phase * tw
        slot = ty * tw + tx

        # Every thread stores a value (0 when out of range) so the tiles are
        # always fully populated before the inner product below reads them.
        if row < n_rows and base + tx < inner:
            m_tile[slot] = m[row, base + tx]
        else:
            m_tile[slot] = 0.
        if col < n_cols and base + ty < inner:
            n_tile[slot] = n[base + ty, col]
        else:
            n_tile[slot] = 0.
        cuda.syncthreads()  # both tiles completely written before anyone reads

        for j in range(tw):
            acc += m_tile[ty * tw + j] * n_tile[j * tw + tx]
        cuda.syncthreads()  # all reads done before the next phase overwrites the tiles

    # Threads that map outside the output only helped with loading; they store nothing.
    if row < n_rows and col < n_cols:
        out[row, col] = acc


def matmul_2d_numba(m, n, tw=16):
    """Multiply two 2-D tensors with the tiled Numba kernel ``matmul_k``.

    Args:
        m: tensor of shape ``(h, k)``.
        n: tensor of shape ``(k, w)``.
        tw: tile width; the kernel is launched with ``tw x tw`` threads per block.

    Returns:
        A new ``(h, w)`` tensor with the dtype and device of ``m``.

    Raises:
        AssertionError: if the inner dimensions of ``m`` and ``n`` differ.
    """
    rows, inner = m.shape
    inner_n, cols = n.shape
    assert inner == inner_n, "Size mismatch!"

    result = torch.zeros(rows, cols, dtype=m.dtype, device=m.device)

    threads = (tw, tw)
    # Grid is (x, y) = (column blocks, row blocks); cdiv comes from utils
    # (presumably ceiling division -- not shown here).
    grid = (cdiv(cols, tw), cdiv(rows, tw))
    # Two tw x tw tiles of float32 (4 bytes each) live in dynamic shared memory.
    smem_bytes = 2 * tw * tw * 4

    # Launch configuration: grid, block, stream 0, dynamic shared memory size in bytes.
    launch = matmul_k[grid, threads, 0, smem_bytes]
    launch(ca(m), ca(n), ca(result), tw)
    return result

In [5]:
# Random operands for the matmul experiments: (5120 x 256) @ (256 x 5120).
m1, m2 = torch.rand(5120, 256), torch.rand(256, 5120)
# Small slices of the same data: the first 4 rows of m1 and the first 4 columns of m2.
m1s, m2s = m1[:4], m2[:, :4]

In [6]:
# Contiguous copies of both operands on the GPU; the cells below run on these.
m1c = m1.contiguous().cuda()
m2c = m2.contiguous().cuda()

In [7]:
# Sanity check: every element of the Numba result should be close to PyTorch's own matmul.
torch.isclose(
    matmul_2d_numba(m1c, m2c),  # tiled Numba kernel
    m1c @ m2c,                  # PyTorch reference
).all()

tensor(True, device='cuda:0')

In [8]:
%%timeit -n 10
# Time the tiled Numba kernel on the GPU operands (10 loops per run).
matmul_2d_numba(m1c,m2c)
# Wait for the queued GPU work to finish so the measurement covers the kernel
# itself and not only the launch call.
torch.cuda.synchronize()

48.1 ms ± 1.44 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


## CUDA

In [9]:
# CUDA/C++ source for the extension, appended to the shared `cuda_begin` preamble
# (defined in utils, not shown here; presumably it supplies the includes plus the
# CHECK_INPUT and cdiv helpers used below -- TODO confirm).
#
# * `matmul_k` (kernel): one thread per output element. Thread (r, c) returns early
#   when it falls outside the h x w output; otherwise it accumulates the dot product
#   of row r of `m` and column c of `n`, reading both through flat row-major indices
#   (m[r*k+i], n[i*w+c]), and stores the sum at out[r*w+c].
# * `matmul` (host wrapper): runs CHECK_INPUT on both tensors, checks that the inner
#   sizes agree ("Size mismatch!"), allocates a zero-filled h x w output with m's
#   options, launches the kernel with 16x16 thread blocks on a cdiv-sized grid, runs
#   C10_CUDA_KERNEL_LAUNCH_CHECK() and returns the output.
#
# NOTE(review): sizes are narrowed to `int` and the flat indices are computed in
# `int` arithmetic, so products such as h*k or h*w above INT_MAX would overflow.
cuda_src = cuda_begin + r'''
__global__ void matmul_k(float* m, float* n, float* out, int h, int w, int k) {
    int r = blockIdx.y*blockDim.y + threadIdx.y;
    int c = blockIdx.x*blockDim.x + threadIdx.x;

    if (r>=h || c>=w) return;
    float o = 0;
    for (int i = 0; i<k; ++i) o += m[r*k+i] * n[i*w+c];
    out[r*w+c] = o;
}

torch::Tensor matmul(torch::Tensor m, torch::Tensor n) {
    CHECK_INPUT(m); CHECK_INPUT(n);
    int h = m.size(0);
    int w = n.size(1);
    int k = m.size(1);
    TORCH_CHECK(k==n.size(0), "Size mismatch!");
    auto output = torch::zeros({h, w}, m.options());

    dim3 tpb(16,16);
    dim3 blocks(cdiv(w, tpb.x), cdiv(h, tpb.y));
    matmul_k<<<blocks, tpb>>>(
        m.data_ptr<float>(), n.data_ptr<float>(), output.data_ptr<float>(), h, w, k);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return output;
}
'''
# Name of the host function defined in `cuda_src`; the build cell passes it to
# get_sig and load_cuda.
fname = 'matmul'

In [10]:
# Build the extension module. get_sig and load_cuda come from utils (not shown
# here); presumably get_sig pulls the C++ prototype of `fname` out of the source
# and load_cuda compiles it and exposes the listed functions -- TODO confirm.
ext = load_cuda(
    cuda_src,
    get_sig(fname, cuda_src),
    [fname],
)

In [11]:
# Sanity check: the hand-written CUDA kernel should agree element-wise with PyTorch's matmul.
torch.isclose(
    ext.matmul(m1c, m2c),  # compiled CUDA extension
    m1c @ m2c,             # PyTorch reference
).all()

tensor(True, device='cuda:0')

In [13]:
%%timeit -n 10
# Time the compiled CUDA extension's matmul on the GPU operands (10 loops per run).
ext.matmul(m1c,m2c)
# Wait for the queued GPU work to finish so the measurement covers the kernel
# itself and not only the launch call.
torch.cuda.synchronize()

28.4 ms ± 7.63 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [15]:
%%timeit -n 10
# Baseline: PyTorch's built-in matmul on the same GPU operands (10 loops per run).
m1c@m2c
# Wait for the queued GPU work to finish so the measurement covers the kernel
# itself and not only the launch call.
torch.cuda.synchronize()

3.34 ms ± 167 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
