## Setup

In [1]:
!pip install wurlitzer
!pip install Ninja
import os,math,sys,torch,re,numpy as np
from types import SimpleNamespace as ns
from collections import namedtuple
# from utils import show_img,load_cuda,cuda_begin,cdiv



# Utils

In [2]:
import torch
import matplotlib.pyplot as plt
from torch.utils.cpp_extension import load_inline

import os,math,sys,torch,re,numpy as np
from types import SimpleNamespace as ns
from collections import namedtuple

# Compact numeric printing: 2 decimals, 140-char lines, and no scientific notation for tensors
np.set_printoptions(linewidth=140, precision=2)
torch.set_printoptions(sci_mode=False, linewidth=140, precision=2)

def show_img(x, figsize=(4,3), **kwargs):
    """Display image tensor `x` in a new axis-free matplotlib figure.

    `x` may be HW or CHW; a 3-D input is permuted to HWC, the layout `plt.imshow`
    expects. The data is moved to CPU before plotting. Extra `kwargs` go to `plt.imshow`.
    """
    plt.figure(figsize=figsize)
    plt.axis('off')
    hwc = x.permute(1, 2, 0) if len(x.shape) == 3 else x  # CHW -> HWC
    plt.imshow(hwc.cpu(), **kwargs)

# C++/CUDA prelude prepended to each inline extension: input-validation macros,
# a CUDA error-check helper, and a host/device ceiling-division helper `cdiv`.
cuda_begin = r'''
#include <torch/extension.h>
#include <stdio.h>
#include <c10/cuda/CUDAException.h>

#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
// do { ... } while (0) makes each multi-statement macro a single statement,
// so it stays correct inside an unbraced if/else.
#define CHECK_INPUT(x) do { CHECK_CUDA(x); CHECK_CONTIGUOUS(x); } while (0)
#define CUDA_ERR(ans) do { gpuAssert((ans), __FILE__, __LINE__); } while (0)
inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort=true)
{
   if (code != cudaSuccess)
   {
      fprintf(stderr,"GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line);
      // NOTE: exit() terminates the whole host process (the Jupyter kernel when run here).
      if (abort) exit(code);
   }
}
__host__ __device__ inline unsigned int cdiv(unsigned int a, unsigned int b) { return (a+b-1)/b;}
'''

def load_cuda(cuda_src, cpp_src, funcs, opt=True, verbose=False, name=None):
    """Compile and load an inline CUDA/C++ extension with `torch.utils.cpp_extension.load_inline`.

    cuda_src: CUDA source containing the kernels and host wrapper functions.
    cpp_src: C++ declarations of the functions listed in `funcs`.
    funcs: names of the functions to expose on the returned module.
    opt: compile host, device (ptxas) and host-compiler code with -O3 when True,
        -O0 when False.
    verbose: forwarded to `load_inline`.
    name: extension name; defaults to the first entry of `funcs`.
    Returns the loaded extension module.
    """
    if name is None: name = funcs[0]
    # Previously `opt` was accepted but ignored. Each flag must be its own list element.
    level = "-O3" if opt else "-O0"
    flags = [level, "-Xptxas", level, "-Xcompiler", level]
    return load_inline(cuda_sources=[cuda_src], cpp_sources=[cpp_src], functions=funcs,
                       extra_cuda_cflags=flags, verbose=verbose, name=name)

def cdiv(a,b):
    """Ceiling division of `a` by `b`, e.g. how many size-`b` blocks are needed to cover `a` items."""
    padded = a + b - 1  # adding b-1 turns floor division into ceiling division
    return padded // b


In [3]:
%load_ext wurlitzer

## Python Version in CUDA format

In [4]:
# Setup: a CUDA-style dim3 (y and z default to 1) and random test matrices
dim3 = namedtuple('dim3', ['x','y','z'], defaults=(1,1))
d = dim3(2,3)  # example: dim3(x=2, y=3, z=1)
torch.manual_seed(42)  # fixed seed so the random test matrices are reproducible
m1 = torch.rand(5120, 256)
m1s = m1[:4]        # 4x256 slice, small enough for the pure-Python kernels
m2 = torch.rand(256,5120)
m2s = m2[:,:4]      # 256x4 slice (a non-contiguous view of m2)


In [14]:
# Functions
def iterate_kerenel(f, blocks, threads, *args):
    """Sequentially simulate a CUDA launch of kernel `f` over `blocks` x `threads`.

    `f` is called as `f(blockIdx, threadIdx, blockDim, *args)` once per (block, thread)
    pair, with `blockDim` set to `threads`. Blocks, and threads within each block, are
    visited z-major, then y, then x. All indices start at 0, as in CUDA. Previously z was
    left at dim3's default of 1. A missing `.z` on `blocks`/`threads` is treated as 1.
    Calls run one at a time, so kernels that need intra-block synchronization cannot be
    simulated this way.
    """
    from itertools import product
    blocks_z, threads_z = getattr(blocks, 'z', 1), getattr(threads, 'z', 1)
    for bz, by, bx in product(range(blocks_z), range(blocks.y), range(blocks.x)):
        for tz, ty, tx in product(range(threads_z), range(threads.y), range(threads.x)):
            f(dim3(bx, by, bz), dim3(tx, ty, tz), threads, *args)

# Correctly spelled alias; `iterate_kerenel` is kept so existing callers keep working.
iterate_kernel = iterate_kerenel
def get_sig(fname, src):
    """Return the C++ declaration of function `fname` found in `src`, or None if absent.

    Finds the first line ending in `fname(...)`, optionally followed by an opening
    brace, and turns it into a prototype by appending ';'. The result is suitable
    as `cpp_sources` for `load_inline`.
    """
    # re.escape so `fname` is matched literally, even if it contains regex metacharacters
    pattern = rf'^(.+\s+{re.escape(fname)}\(.*?\))\s*{{?\s*$'
    res = re.findall(pattern, src, re.MULTILINE)
    return res[0]+';' if res else None

def matmul_kernel(blockIdx, threadIdx, blockDim, m, n, out, h, w, k):
    """Simulated CUDA kernel: one 'thread' computes out[row, col] = sum_i m[row, i] * n[i, col].

    `m`, `n` and `out` are row-major flattened h x k, k x w and h x w buffers.
    Threads whose (row, col) falls outside the h x w output do nothing.
    """
    row = blockIdx.y*blockDim.y + threadIdx.y
    col = blockIdx.x*blockDim.x + threadIdx.x

    if row < h and col < w:
        acc = 0.0
        for i in range(k):
            acc += m[row*k + i] * n[i*w + col]
        out[row*w + col] = acc



def matmul(m, n):
    """Multiply `m` (h x k) by `n` (k x w) by simulating a 2D launch of `matmul_kernel` with 16x16 blocks."""
    (h, k), (k2, w) = m.shape, n.shape
    assert k==k2, "Size mismatch!"
    output = torch.zeros(h, w, dtype=m.dtype)
    threads_per_block = dim3(16, 16)
    n_blocks = dim3(cdiv(w, threads_per_block.x), cdiv(h, threads_per_block.y))
    # The kernel writes through output.flatten(); this relies on it being a view of the contiguous `output`
    flat_args = (m.flatten(), n.flatten(), output.flatten())
    iterate_kerenel(matmul_kernel, n_blocks, threads_per_block, *flat_args, h, w, k)
    return output

In [6]:
# Test
matmul(m1s, m2s).isclose(m1s @ m2s).all()  # simulated kernel vs PyTorch reference

tensor(True)

## CUDA version

In [23]:
# Naive CUDA matmul: each thread computes one output element out[r, c] by reading a full
# row of `m` and a full column of `n` from global memory. The host wrapper validates the
# inputs with CHECK_INPUT (from `cuda_begin`), launches 16x16 blocks covering the h x w
# output, and checks the launch with C10_CUDA_KERNEL_LAUNCH_CHECK.
# NOTE(review): only float32 is supported; data_ptr<float>() is expected to reject other dtypes — confirm.
cuda_src = cuda_begin + r'''
__global__ void matmul_k(float* m, float* n, float* out, int h, int w, int k) {
    int r = blockIdx.y*blockDim.y + threadIdx.y;
    int c = blockIdx.x*blockDim.x + threadIdx.x;

    if (r>=h || c>=w) return;
    float o = 0;
    for (int i = 0; i<k; ++i) o += m[r*k+i] * n[i*w+c];
    out[r*w+c] = o;
}

torch::Tensor matmul(torch::Tensor m, torch::Tensor n) {
    CHECK_INPUT(m); CHECK_INPUT(n);
    int h = m.size(0);
    int w = n.size(1);
    int k = m.size(1);
    TORCH_CHECK(k==n.size(0), "Size mismatch!");
    auto output = torch::zeros({h, w}, m.options());

    dim3 tpb(16,16);
    dim3 blocks(cdiv(w, tpb.x), cdiv(h, tpb.y));
    matmul_k<<<blocks, tpb>>>(
        m.data_ptr<float>(), n.data_ptr<float>(), output.data_ptr<float>(), h, w, k);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return output;
}
'''
# Name of the host function to expose from the extension
fname = 'matmul'

In [24]:
cpp_src = get_sig(src=cuda_src, fname=fname)  # C++ prototype used as load_inline's cpp_sources



In [25]:
print(str(cpp_src))  # show the extracted declaration

torch::Tensor matmul(torch::Tensor m, torch::Tensor n);


In [29]:
# Build the extension, then make contiguous GPU copies of the test matrices
module = load_cuda(cuda_src, cpp_src, funcs=[fname])
m1c = m1.contiguous().cuda()
m2c = m2.contiguous().cuda()


In [27]:
module.matmul(m1c, m2c).size()  # expected h x w = 5120 x 5120

torch.Size([5120, 5120])

In [28]:
module.matmul(m1c, m2c).isclose(m1c @ m2c).all()  # CUDA kernel vs PyTorch reference

tensor(True, device='cuda:0')

## Memory Tiling (Python)
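A pure-Python simulation of shared-memory tiling, the technique the next CUDA kernel uses to become more efficient. Each thread block loads a tile of each input into fast shared memory once and reuses it for many products, instead of every thread re-reading global memory.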

In [32]:
import threading
from threading import Barrier, Thread
from concurrent.futures import ThreadPoolExecutor
# The Kernel
def matmul_tiled_bk(blockIdx, threadIdx, blockDim, shared, syncb, m, n, out, h, w, k, tw):
    """Simulated shared-memory tiled matmul kernel (one Python thread per CUDA thread).

    Each thread computes out[row*w + col]. In every phase the block cooperatively loads a
    tw x tw tile of `m` into the first half of `shared` and a tile of `n` into the second
    half, zero-padding past the matrix edges. Each thread then accumulates its partial dot
    product from the two tiles. `syncb.wait()` stands in for `__syncthreads()`.
    `m`, `n` and `out` are row-major flattened h x k, k x w and h x w buffers.
    """
    tx, ty = threadIdx.x, threadIdx.y
    row = blockIdx.y*blockDim.y + ty
    col = blockIdx.x*blockDim.x + tx

    tile_area = tw*tw
    m_tile, n_tile = shared[:tile_area], shared[tile_area:]
    slot = ty*tw + tx  # this thread's cell within each tile

    acc = 0.0
    for phase in range(cdiv(k, tw)):
        base = phase*tw
        m_col, n_row = base + tx, base + ty
        m_tile[slot] = m[row*k + m_col] if row<h and m_col<k else 0.
        n_tile[slot] = n[n_row*w + col] if col<w and n_row<k else 0.
        syncb.wait()  # whole tile loaded before anyone reads it

        for i in range(tw):
            acc += m_tile[ty*tw + i] * n_tile[i*tw + tx]
        syncb.wait()  # everyone done reading before the next phase overwrites the tiles

    if row<h and col<w:
        out[row*w + col] = acc


# Simulates CUDA scheduling / processing
def blk_kernel2d_shar(f, blocks, tpb, sh_sz, *args, **kwargs):
    """Simulate a 2D CUDA launch of kernel `f` with `sh_sz` floats of per-block shared memory.

    Blocks run one after another. Within a block, one Python thread is started per
    (threadIdx.x, threadIdx.y). All threads in the block share a fresh `torch.zeros(sh_sz)`
    buffer and a `Barrier` that stands in for `__syncthreads()`. `f` is called as
    `f(blockIdx, threadIdx, blockDim, shared, syncb, *args, **kwargs)` with z indices of 0.

    If any simulated thread raises, the barrier is aborted so its siblings cannot hang
    forever in `syncb.wait()`, and the first non-barrier error is re-raised here.
    """
    from threading import BrokenBarrierError

    def _guarded(errors, syncb, *fargs):
        # Record the failure and break the barrier so sibling threads stop waiting
        try:
            f(*fargs, **kwargs)
        except Exception as e:
            errors.append(e)
            syncb.abort()

    for i0 in range(blocks.y):
        for i1 in range(blocks.x):
            shar = torch.zeros(sh_sz)
            syncb = Barrier(tpb.y*tpb.x)
            errors = []

            # Create threads. z is set to 0 explicitly, as in a 2D CUDA launch (dim3 defaults z to 1).
            threads = [Thread(target=_guarded,
                              args=(errors, syncb, dim3(i1,i0,0), dim3(p,o,0), tpb, shar, syncb, *args))
                       for o in range(tpb.y) for p in range(tpb.x)]
            for tr in threads: tr.start()
            for tr in threads: tr.join()
            if errors:
                # Prefer the original exception over the BrokenBarrierErrors it caused in other threads
                raise next((e for e in errors if not isinstance(e, BrokenBarrierError)), errors[0])


# Matrix multiplication
def matmul_2d(m, n, tw=16):
    """Multiply `m` (h x k) by `n` (k x w) with the simulated shared-memory tiled kernel.

    Uses `tw` x `tw` thread blocks, each with room for one tile of `m` and one of `n`.
    """
    (h, k), (k2, w) = m.shape, n.shape
    assert k==k2, "Size mismatch!"
    output = torch.zeros(h, w, dtype=m.dtype)
    tpb = dim3(tw, tw)
    grid = dim3(cdiv(w, tpb.x), cdiv(h, tpb.y))
    shared_floats = 2*tw*tw  # one m tile + one n tile

    # The kernel writes through output.flatten(); this relies on it being a view of the contiguous `output`
    blk_kernel2d_shar(matmul_tiled_bk, grid, tpb, shared_floats,
                      m.flatten(), n.flatten(), output.flatten(),
                      h, w, k, tw=tw)
    return output

In [33]:
matmul_2d(m1s, m2s, tw=8).isclose(m1s @ m2s).all()  # tiled simulation vs PyTorch reference

tensor(True)

## CUDA using tiling

In [43]:
# Tiled CUDA matmul: each TW x TW thread block stages one TW x TW tile of `m` and one of `n`
# in dynamic shared memory per phase, and every thread reuses those tiles for its partial sum.
cuda_src = cuda_begin + r'''
__global__ void matmul_tiled_bk(float* m, float* n, float* out, int h, int w, int k, int tw) {
    int tc = threadIdx.x;
    int tr = threadIdx.y;
    int r = blockIdx.y * blockDim.y + tr;
    int c = blockIdx.x * blockDim.x + tc;

    // Dynamic shared memory (size given by the 3rd launch parameter): the m tile, then the n tile
    extern __shared__ float shared[];

    float* ms = &shared[0];
    float* ns = &shared[tw * tw];

    // Float literals (0.0f): a bare 0.0 is a double and promotes the ternaries to double precision
    float p = 0.0f;
    for (int ph = 0; ph < cdiv(k, tw); ++ph) {

        // Each thread loads one element of each tile, zero-padding outside the matrices
        ms[tr * tw + tc] = (r < h && (ph * tw + tc) < k) ? m[tc + ph * tw + r * k] : 0.0f;
        ns[tr * tw + tc] = (c < w && (ph * tw + tr) < k) ? n[(tr + ph * tw) * w + c] : 0.0f;

        // Wait until the whole tile is loaded before reading it
        __syncthreads();

        // Accumulate the partial dot product from shared memory
        for (int i = 0; i < tw; ++i) p += ms[tr * tw + i] * ns[tw * i + tc];
        // Wait until every thread is done reading before the next phase overwrites the tiles
        __syncthreads();
    }

    if (r < h && c < w) out[r * w + c] = p;
}

torch::Tensor matmul_grid(torch::Tensor m, torch::Tensor n) {
    CHECK_INPUT(m); CHECK_INPUT(n);
    int h=m.size(0), w=n.size(1), k=m.size(1);
    TORCH_CHECK(k==n.size(0), "Size mismatch!");

    auto output = torch::zeros({h, w}, m.options());
    // Tile width; the kernel assumes blockDim == (TW, TW)
    int TW = 16;
    // Dynamic shared memory for two TW x TW float tiles
    size_t size = TW*TW * 2 * sizeof(float);
    dim3 tpb(TW,TW);
    dim3 blocks(cdiv(w, tpb.x), cdiv(h, tpb.y));
    matmul_tiled_bk<<<blocks,tpb,size>>>(
        m.data_ptr<float>(), n.data_ptr<float>(), output.data_ptr<float>(), h, w, k, TW);
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    return output;
}
'''
fun = "matmul_grid"
cpp_src = get_sig(fun, cuda_src)
cpp_src

'torch::Tensor matmul_grid(torch::Tensor m, torch::Tensor n);'

In [45]:
module = load_cuda(cuda_src, cpp_src, funcs=[fun], opt=True)  # build the tiled extension


In [47]:
module.matmul_grid(m1c, m2c).isclose(m1c @ m2c).all()  # tiled CUDA kernel vs PyTorch reference



tensor(True, device='cuda:0')

In [49]:
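# Numba lets us write the tiled kernel in Python syntax and compiles it to a real CUDA kernel. That is much simpler than simulating threads and barriers by hand as above, though its kernel API differs slightly from CUDA C++.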

from numba import cuda
from numba.cuda import as_cuda_array as ca

@cuda.jit
def matmul_k_numba(m, n, out, tw):
    """Numba CUDA kernel: shared-memory tiled matmul, one thread per element of `out`.

    `m`, `n` and `out` are 2-D device arrays (h x k, k x w, h x w). The block is expected to
    be `tw` x `tw` threads. Dynamic shared memory must hold two `tw` x `tw` float32 tiles;
    `matmul_2d_numba` launches it this way. Tiles are staged in float32 whatever the
    input dtype.
    """
    cbi,cbd,tid = cuda.blockIdx,cuda.blockDim,cuda.threadIdx
    tc,tr = tid.x,tid.y
    # Global output coordinates of this thread
    r,c = cbi.y * cbd.y + tr, cbi.x * cbd.x + tc
    h,k  = m.shape
    k2,w = n.shape  # k2 is unused; shape agreement is checked by the host wrapper

    # Dynamic shared array (size 0 here; actual bytes come from the launch configuration),
    # split into the m tile followed by the n tile
    shar = cuda.shared.array(0, dtype=np.float32)
    ms,ns = shar[:tw*tw],shar[tw*tw:2*tw*tw]

    p = np.float32(0.0)
    for ph in range(math.ceil(k/tw)):
        idx = ph*tw
        # Each thread loads one element of each tile, zero-padding outside the matrices
        ms[tr*tw+tc] = m[r, tc+idx] if r<h and idx+tc<k else 0.
        ns[tr*tw+tc] = n[tr+idx, c] if c<w and idx+tr<k else 0.
        # Whole tile loaded before anyone reads it
        cuda.syncthreads()


        for i in range(tw):
          p += ms[tr*tw+i] * ns[i*tw+tc]
        # Everyone done reading before the next phase overwrites the tiles
        cuda.syncthreads()

    if r < h and c < w:
      out[r, c] = p

def matmul_2d_numba(m, n, tw=16):
    """Tiled matmul of CUDA tensors `m` (h x k) and `n` (k x w) with the Numba kernel `matmul_k_numba`.

    Launches a 2D grid of `tw` x `tw` thread blocks. Each block gets dynamic shared memory
    for two `tw` x `tw` float32 tiles. Returns a new h x w tensor with `m`'s dtype on
    `m`'s device.
    """
    h,k  = m.shape
    k2,w = n.shape
    assert k==k2, "Size mismatch!"
    out = torch.zeros(h, w, dtype=m.dtype, device=m.device)
    # The kernel's shared array is float32, so size it from float32's itemsize rather than a bare 4
    dyn_shared_mem_size = 2 * tw * tw * np.dtype(np.float32).itemsize
    tpb = tw,tw
    blocks = cdiv(w,tpb[0]), cdiv(h,tpb[1])
    # Launch configuration: [grid, block, stream, dynamic shared-memory bytes]; stream 0 = default stream
    matmul_k_numba[blocks, tpb, 0, dyn_shared_mem_size](ca(m), ca(n), ca(out), tw)
    return out

In [50]:
# Warm-up launch (presumably also triggers Numba's lazy JIT compilation); wait for the GPU to finish
matmul_2d_numba(m1c, m2c, tw=16)
torch.cuda.synchronize()

In [52]:
matmul_2d_numba(m1c, m2c).isclose(m1c @ m2c).all()  # Numba kernel vs PyTorch reference

tensor(True, device='cuda:0')