In [3]:
import os, math, sys, torch, re, numpy as np
from types import SimpleNamespace as ns
from collections import namedtuple
import matplotlib.pyplot as plt
from torch.utils.cpp_extension import load_inline
import torch
!pip install ninja
import ninja

Collecting ninja
  Downloading ninja-1.11.1.4-py3-none-manylinux_2_12_x86_64.manylinux2010_x86_64.whl.metadata (5.0 kB)
Downloading ninja-1.11.1.4-py3-none-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (422 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/422.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m419.8/422.8 kB[0m [31m13.9 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m422.8/422.8 kB[0m [31m10.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: ninja
Successfully installed ninja-1.11.1.4


In [4]:
dim3 = namedtuple('dim3', ['x', 'y', 'z'], defaults=(1, 1))

In [5]:
np.set_printoptions(precision=2, linewidth=140)
torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)

In [6]:
def show_img(x, figsize=(4, 3), **kwargs):
  plt.figure(figsize=figsize)
  plt.axis('off')
  if len(x.shape)==3: x = x.permute(1, 2, 0)
  plt.imshow(x.cpu(), **kwargs)

cuda_begin = r"""
#include <torch/extension.h>
#include <stdio.h>
#include <c10/cuda/CUDAException.h>

#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
#define CUDA_ERR(ans) { gpuAssert((ans), __FILE__, __LINE__); }
inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort=true)
{
  if (code != cudaSuccess)
  {
    fprintf(stderr,"GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line);
    if (abort) exit(code);
  }
}
__host__ __device__ inline unsigned int cdiv(unsigned int a, unsigned int b) { return (a+b-1)/b; }
"""

def load_cuda(cuda_src, cpp_src, funcs, opt=False, verbose=False, name=None):
  if name is None: name = funcs[0]
  return load_inline(cuda_sources=[cuda_src], cpp_sources=[cpp_src], functions=funcs,
                     extra_cuda_cflags=["-O2"] if opt else [], verbose=verbose, name=name)

def cdiv(a, b):
  return (a+b-1)//b

In [7]:
%load_ext wurlitzer

In [8]:
os.environ["CUDA_LAUNCH_BLOCKING"] = '1'
torch.manual_seed(1);

In [9]:
m1 = torch.rand(5120, 256)
m1s = m1[:4]
m2 = torch.rand(256, 5120)
m2s = m2[:, :4]

In [10]:
def blk_kernel2d(f, blocks, threads, *args):
  for i0 in range(blocks.y):
    for i1 in range(blocks.x):
      for j0 in range(threads.y):
        for j1 in range(threads.x):
          f(dim3(i1, i0), dim3(j1, j0), threads, *args)

In [11]:
def get_sig(fname, src):
  res = re.findall(rf"^(.+\s+{fname}\(.*?\))\s*{{?\s*$", src, re.MULTILINE)
  return res[0] + ';' if res else None

In [12]:
m1c, m2c = m1.contiguous().cuda(), m2.contiguous().cuda()

In [13]:
a = torch.zeros(5)
b, c = a[:3], a[3:]

In [14]:
b[1] = 2
c[0] = 6
a

tensor([0., 2., 0., 6., 0.])

In [15]:
def blk_kernel2d_shar(f, blocks, threads, sh_sz, *args, **kwargs):
  for i0 in range(blocks.y):
    for i1 in range(blocks.x):
      shared = torch.zeros(sh_sz)
      f(dim3(i1, i0), threads, shared, *args, **kwargs)

In [16]:
def matmul_tiled_blk(blockIdx, blockDim, shared, m, n, out, h, w, k, tw):
    shar_sz = tw*tw
    ms, ns = shared[:shar_sz], shared[shar_sz:]

    for ph in range(int(math.ceil(k/tw))): # ph is the tile idx
      idx = ph*tw
      for tr in range(blockDim.y):
        for tc in range(blockDim.x):
          r, c = blockIdx.y*blockDim.y + tr, blockIdx.x*blockDim.x + tc
          ms[tr*tw+tc] = m[tc+idx+r*k] if r<h  and idx+tc<k else 0.
          ns[tr*tw+tc] = n[(tr+idx)*w+c] if c<w  and idx+tr<k else 0.

      for t in range(blockDim.y):
        for tc in range(blockDim.x):
          r, c = blockIdx.y*blockDim.y + t, blockIdx.x*blockDim.x + tc
          for i in range(tw):
            if r*w+c<len(out):
               out[r*w+c] += ms[t*tw+i]*ns[tw*i+tc]


In [17]:
def matmul_2d(m, n, tw=16):
  h, k = m.shape
  k2, w = n.shape
  assert k == k2, "Size mismatch!"
  out = torch.zeros(h, w, dtype=m.dtype)
  tpb = dim3(tw, tw)
  blks = dim3(cdiv(w, tpb.x), cdiv(h, tpb.y))
  blk_kernel2d_shar(matmul_tiled_blk, blks, tpb, tw*tw*2, m.flatten(), n.flatten(),
                    out.flatten(), h, w, k, tw)
  return out


In [18]:
m1s.shape, m2.shape

(torch.Size([4, 256]), torch.Size([256, 5120]))

In [19]:
torch.isclose(matmul_2d(m1s, m2s, tw=10), m1s@m2s).all()

tensor(True)

# Run Threads

In [20]:
def run_threads(f, blockDim, *args, **kwargs):
  for i0 in range(blockDim.y):
    for i1 in range(blockDim.x):
      f(i0, i1, *args, **kwargs)

In [21]:
def matmul_tiled_blk(blockIdx, blockDim, shared, m, n, out, h, w, k, tw):
    shar_sz = tw*tw
    ms, ns = shared[:shar_sz], shared[shar_sz:]
    def fill(tr, tc, ph):
          r, c = blockIdx.y*blockDim.y + tr, blockIdx.x*blockDim.x + tc
          ms[tr*tw+tc] = m[tc+idx+r*k] if r<h  and idx+tc<k else 0.
          ns[tr*tw+tc] = n[(tr+idx)*w+c] if c<w  and idx+tr<k else 0.

    def dot_prod(tr, tc, ph):
          r, c = blockIdx.y*blockDim.y + tr, blockIdx.x*blockDim.x + tc
          for i in range(tw):
            if r*w+c<len(out):
               out[r*w+c] += ms[tr*tw+i]*ns[tw*i+tc]

    for ph in range(int(math.ceil(k/tw))): # ph is the tile idx
      idx = ph*tw
      run_threads(fill, blockDim, ph)
      run_threads(dot_prod, blockDim, ph)

In [22]:
def matmul_2d(m, n, tw=16):
  h, k = m.shape
  k2, w = n.shape
  assert k == k2, "Size mismatch!"
  out = torch.zeros(h, w, dtype=m.dtype)
  tpb = dim3(tw, tw)
  blks = dim3(cdiv(w, tpb.x), cdiv(h, tpb.y))
  blk_kernel2d_shar(matmul_tiled_blk, blks, tpb, tw*tw*2, m.flatten(), n.flatten(),
                    out.flatten(), h, w, k, tw)
  return out

In [23]:
torch.isclose(matmul_2d(m1s, m2s, tw=10), m1s@m2s).all()

tensor(True)

# Python threading lib

In [24]:
import threading
from threading import Barrier, Thread
from concurrent.futures import ThreadPoolExecutor

In [25]:
def blk_kernel2d_shar(f, blocks, tpb, sh_sz, *args, **kwargs):
  for i0 in range(blocks.y):
    for i1 in range(blocks.x):
      shar = torch.zeros(sh_sz)
      syncb = Barrier(tpb.y*tpb.x)
      threads = (
          Thread(target=f, args=(dim3(i1 ,i0), dim3(p, o), tpb, shar, syncb, *args), kwargs=kwargs)
          for o in range(tpb.y) for p in range(tpb.x)
      )
      for t in threads:
        t.start()
      for t in threads: t.join()


# def blk_kernel2d_shar(f, blocks, tpb, sh_sz, *args, **kwargs):
#     for i0 in range(blocks.y):
#         for i1 in range(blocks.x):
#           shar = torch.zeros(sh_sz)
#           syncb = Barrier(tpb.y*tpb.x)
#           threads = [
#              Thread(target=f, args=(dim3(i1,i0), dim3(p,o), tpb, shar, syncb, *args), kwargs=kwargs)
#                    for o in range(tpb.y) for p in range(tpb.x)
#           ]
#           for tr in threads: tr.start()
#           for tr in threads: tr.join()

In [26]:
def matmul_tiled_blk(blockIdx, threadIdx, blockDim, shared, syncb, m, n, out, h, w, k, tw):
  tc , tr = threadIdx.x, threadIdx.y
  r = blockIdx.y*blockDim.y+tr
  c = blockIdx.x*blockDim.x+tc

  shar_sz = tw*tw
  ms, ns = shared[:shar_sz], shared[shar_sz:]

  p = 0.
  for ph in range(cdiv(k, tw)):
    ms[tr*tw+tc] = m[tc+ph*tw +r*k] if r < h and (ph*tw+tc)<k else 0.
    ns[tr*tw+tc] = n[(tr+ph*tw)* w+c] if c < w and (ph*tw+tr)<k else 0.
    syncb.wait()
    for i in range(tw):
      p += ms[tr*tw+i] * ns[tw*i + tc]
    syncb.wait()

  if (r < h and c < w): out[r*w+c] = p

In [27]:
def matmul_2d(m, n, tw=16):
  h, k = m.shape
  k2, w = n.shape
  assert k == k2, "Size mismatch!"
  out = torch.zeros(h, w, dtype=m.dtype)
  tpb = dim3(tw, tw)
  blks = dim3(cdiv(w, tpb.x), cdiv(h, tpb.y))
  blk_kernel2d_shar(matmul_tiled_blk, blks, tpb, tw*tw*2, m.flatten(), n.flatten(),
                    out.flatten(), h, w, k, tw=tw)
  return out
matmul_2d(m1s, m2s, tw=10)
torch.isclose(matmul_2d(m1s, m2s, tw=10), m1s@m2s).all()

tensor(False)

In [28]:
m1s@m2s

tensor([[69.24, 64.11, 68.03, 63.76],
        [65.21, 62.05, 65.01, 61.14],
        [65.99, 65.22, 66.94, 62.76],
        [69.27, 61.63, 63.86, 61.43]])

In [29]:
cuda_src = cuda_begin + r"""
__global__ void matmul_k(float *m, float *n, float *out, int h, int w, int k, int tw) {
  int tc = threadIdx.x, tr=threadIdx.y;
  int r = blockIdx.y*blockDim.y+tr, c = blockIdx.x*blockDim.x+tc;

  extern __shared__ float ms[];
  float *ns = &ms[tw*tw];

  float p = 0.0f;
  for (int ph = 0; ph < cdiv(k, tw); ++ph) {
    int idx = ph*tw;
    ms[tr*tw+tc] = r < h && idx +tc < k ? m[tc+idx+r*k]: 0.0f;
    ns[tr*tw+tc] = c < w && idx +tr < k ? n[(tr+idx)*w+c]: 0.0f;
    __syncthreads();
    for (int i = 0; i < tw; ++i)
      p += ms[tr*tw+i] * ns[tw*i+tc];
    __syncthreads();
  }
  if (r < h && c < w) out[r*w+c] = p;
}
"""

In [30]:
cuda_src += r"""
torch::Tensor matmul_dyn(torch::Tensor m, torch::Tensor n) {
  CHECK_INPUT(m); CHECK_INPUT(n);
  int h = m.size(0), w = n.size(1), k = m.size(1);
  TORCH_CHECK(k==n.size(0), "Size mismatch!");
  auto out = torch::zeros({h, w}, m.options());

  int TW = 16;
  size_t size = TW*TW*2*sizeof(float);

  dim3 tpb(TW, TW);
  dim3 blocks(cdiv(w, tpb.x), cdiv(h, tpb.y));
  matmul_k<<<blocks, tpb, size>>>(
    m.data_ptr<float>(), n.data_ptr<float>(), out.data_ptr<float>(), h, w, k, TW);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
  return out;
}"""

In [31]:
fname = "matmul_dyn"

In [32]:
cpp_src = get_sig(fname, cuda_src)

In [33]:
cpp_src

'torch::Tensor matmul_dyn(torch::Tensor m, torch::Tensor n);'

In [38]:
# cuda_src += "\n// force rebuild"
# import shutil
# shutil.rmtree('/root/.cache', ignore_errors=True)
module = load_cuda(cuda_src, cpp_src, [fname], verbose=True, opt=True)

Using /root/.cache/torch_extensions/py311_cu124 as PyTorch extensions root...
The input conditions for extension module matmul_dyn have changed. Bumping to version 1 and re-building as matmul_dyn_v1...
Detected CUDA files, patching ldflags
Emitting ninja build file /root/.cache/torch_extensions/py311_cu124/matmul_dyn/build.ninja...
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
Building extension module matmul_dyn_v1...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)


[1/3] c++ -MMD -MF main.o.d -DTORCH_EXTENSION_NAME=matmul_dyn_v1 -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -isystem /usr/local/lib/python3.11/dist-packages/torch/include -isystem /usr/local/lib/python3.11/dist-packages/torch/include/torch/csrc/api/include -isystem /usr/local/lib/python3.11/dist-packages/torch/include/TH -isystem /usr/local/lib/python3.11/dist-packages/torch/include/THC -isystem /usr/local/cuda/include -isystem /usr/include/python3.11 -D_GLIBCXX_USE_CXX11_ABI=0 -fPIC -std=c++17 -c /root/.cache/torch_extensions/py311_cu124/matmul_dyn/main.cpp -o main.o 
[2/3] /usr/local/cuda/bin/nvcc --generate-dependencies-with-compile --dependency-output cuda.cuda.o.d -DTORCH_EXTENSION_NAME=matmul_dyn_v1 -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -isystem /usr/local/lib/python3.11/dist-packages/torch/i

Loading extension module matmul_dyn_v1...


In [39]:
torch.isclose(module.matmul_dyn(m1c, m2c), m1c@m2c).all()

tensor(True, device='cuda:0')

In [40]:
m1c@m2c

tensor([[69.24, 64.11, 68.03,  ..., 64.08, 63.17, 65.47],
        [65.21, 62.05, 65.01,  ..., 64.15, 59.35, 62.69],
        [65.99, 65.22, 66.94,  ..., 61.65, 59.49, 62.10],
        ...,
        [69.24, 66.72, 66.94,  ..., 67.49, 63.27, 64.29],
        [70.44, 65.72, 70.58,  ..., 65.22, 63.46, 67.78],
        [71.36, 66.22, 68.45,  ..., 62.70, 62.34, 65.90]], device='cuda:0')

In [41]:
%%timeit -n 10
module.matmul_dyn(m1c, m2c)
torch.cuda.synchronize()

28.3 ms ± 3.04 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


# CUDA static shared

In [48]:
cuda_src = cuda_begin + r"""
constexpr int tw = 16;

__global__ void matmul_k(float *m, float *n, float *out, int h, int w, int k) {
  __shared__ float ms[tw][tw], ns[tw][tw];
  int tc = threadIdx.x, tr=threadIdx.y;
  int r = blockIdx.y*blockDim.y+tr, c = blockIdx.x*blockDim.x+tc;

  float p = 0.0f;
  for (int ph = 0; ph < cdiv(k, tw); ++ph) {
    int idx = ph*tw;
    ms[tr][tc] = r < h && idx +tc < k ? m[tc+idx+r*k]: 0.0f;
    ns[tr][tc] = c < w && idx +tr < k ? n[(tr+idx)*w+c]: 0.0f;
    __syncthreads();
    for (int i = 0; i < tw; ++i)
      p += ms[tr][i] * ns[i][tc];
    __syncthreads();
  }
  if (r < h && c < w) out[r*w+c] = p;
}
"""

In [49]:
cuda_src += r"""
torch::Tensor matmul_static(torch::Tensor m, torch::Tensor n) {
  CHECK_INPUT(m); CHECK_INPUT(n);
  int h = m.size(0), w = n.size(1), k = m.size(1);
  TORCH_CHECK(k==n.size(0), "Size mismatch!");
  auto out = torch::zeros({h, w}, m.options());

  int TW = 16;
  size_t size = TW*TW*2*sizeof(float);

  dim3 tpb(TW, TW);
  dim3 blocks(cdiv(w, tpb.x), cdiv(h, tpb.y));
  matmul_k<<<blocks, tpb, size>>>(
    m.data_ptr<float>(), n.data_ptr<float>(), out.data_ptr<float>(), h, w, k);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
  return out;
}"""

In [50]:
fname = "matmul_static"
cpp_src = get_sig(fname, cuda_src)
module = load_cuda(cuda_src, cpp_src, [fname], verbose=True, opt=True)

Using /root/.cache/torch_extensions/py311_cu124 as PyTorch extensions root...
The input conditions for extension module matmul_static have changed. Bumping to version 1 and re-building as matmul_static_v1...
Detected CUDA files, patching ldflags
Emitting ninja build file /root/.cache/torch_extensions/py311_cu124/matmul_static/build.ninja...
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
Building extension module matmul_static_v1...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)


[1/3] c++ -MMD -MF main.o.d -DTORCH_EXTENSION_NAME=matmul_static_v1 -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -isystem /usr/local/lib/python3.11/dist-packages/torch/include -isystem /usr/local/lib/python3.11/dist-packages/torch/include/torch/csrc/api/include -isystem /usr/local/lib/python3.11/dist-packages/torch/include/TH -isystem /usr/local/lib/python3.11/dist-packages/torch/include/THC -isystem /usr/local/cuda/include -isystem /usr/include/python3.11 -D_GLIBCXX_USE_CXX11_ABI=0 -fPIC -std=c++17 -c /root/.cache/torch_extensions/py311_cu124/matmul_static/main.cpp -o main.o 
[2/3] /usr/local/cuda/bin/nvcc --generate-dependencies-with-compile --dependency-output cuda.cuda.o.d -DTORCH_EXTENSION_NAME=matmul_static_v1 -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -isystem /usr/local/lib/python3.11/dist-package

Loading extension module matmul_static_v1...


In [52]:
%%timeit -n 10
module.matmul_static(m1c, m2c)
torch.cuda.synchronize()

23.6 ms ± 2.82 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
