## Setup

In [14]:
!pip install wurlitzer
!pip install Ninja
import os,math,sys,torch,re,numpy as np
from types import SimpleNamespace as ns
from collections import namedtuple
# from utils import show_img,load_cuda,cuda_begin,cdiv

Collecting wurlitzer
  Downloading wurlitzer-3.0.3-py3-none-any.whl (7.3 kB)
Installing collected packages: wurlitzer
Successfully installed wurlitzer-3.0.3
Collecting Ninja
  Downloading ninja-1.11.1.1-py2.py3-none-manylinux1_x86_64.manylinux_2_5_x86_64.whl (307 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 307.2/307.2 kB 7.9 MB/s eta 0:00:00
Installing collected packages: Ninja
Successfully installed Ninja-1.11.1.1


# Utils

In [2]:
import torch
import matplotlib.pyplot as plt
from torch.utils.cpp_extension import load_inline

import os,math,sys,torch,re,numpy as np
from types import SimpleNamespace as ns
from collections import namedtuple

# Compact display for arrays and tensors: two decimals, 140-column lines,
# and (for torch) plain decimal notation instead of scientific notation.
np.set_printoptions(linewidth=140, precision=2)
torch.set_printoptions(linewidth=140, precision=2, sci_mode=False)

def show_img(x, figsize=(4,3), **kwargs):
    "Draw image tensor `x` (HW or CHW layout) in a new axis-free figure; `kwargs` go to `plt.imshow`"
    plt.figure(figsize=figsize)
    plt.axis('off')
    # A 3-d input is taken to be channels-first and reordered to channels-last (CHW -> HWC)
    is_chw = len(x.shape)==3
    img = x.permute(1,2,0) if is_chw else x
    plt.imshow(img.cpu(), **kwargs)

# C++/CUDA preamble kept as a raw string. It includes the torch extension, stdio and
# c10 CUDA-exception headers and defines:
#   CHECK_CUDA / CHECK_CONTIGUOUS / CHECK_INPUT - TORCH_CHECK guards on a tensor argument
#     (CHECK_INPUT expands to two statements, so use it as a statement of its own)
#   CUDA_ERR + gpuAssert - on a non-success cudaError_t, print the error string with
#     file and line to stderr and, unless abort=false, exit the process with that code
#   cdiv - unsigned ceiling division usable from both host and device code
# NOTE(review): presumably prepended to each kernel source handed to `load_cuda` - confirm at the call sites.
cuda_begin = r'''
#include <torch/extension.h>
#include <stdio.h>
#include <c10/cuda/CUDAException.h>

#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
#define CUDA_ERR(ans) { gpuAssert((ans), __FILE__, __LINE__); }
inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort=true)
{
   if (code != cudaSuccess)
   {
      fprintf(stderr,"GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line);
      if (abort) exit(code);
   }
}
__host__ __device__ inline unsigned int cdiv(unsigned int a, unsigned int b) { return (a+b-1)/b;}
'''

def load_cuda(cuda_src, cpp_src, funcs, opt=True, verbose=False, name=None):
    "Build and load inline CUDA/C++ sources through `load_inline`; `name` falls back to `funcs[0]`"
    # The same optimisation level is requested from nvcc, ptxas and the host compiler
    if opt: flags = "-O3 -Xptxas -O3 -Xcompiler -O3"
    else:   flags = "-O0 -Xptxas -O0 -Xcompiler -O0"
    ext_name = funcs[0] if name is None else name
    return load_inline(name=ext_name, cpp_sources=[cpp_src], cuda_sources=[cuda_src],
                       functions=funcs, extra_cuda_cflags=[flags], verbose=verbose)

def cdiv(a,b):
    "Ceiling of `a`/`b` using integer arithmetic only"
    # Adding b-1 before flooring bumps every inexact quotient up to the next integer
    padded = a+b-1
    return padded//b


In [None]:
%load_ext wurlitzer

## Python Version in CUDA format

In [31]:
# Setup
# CUDA-style dim3: `y` and `z` default to 1 when omitted
dim3 = namedtuple('dim3', ['x','y','z'], defaults=(1,1))
d = dim3(2,3)

# Draw the test matrices from a dedicated seeded generator so every run sees the same
# data; the global torch RNG state is left untouched.
_gen = torch.Generator().manual_seed(42)
m1 = torch.rand(5120, 256, generator=_gen)
m1s = m1[:4]     # first 4 rows: small input for the slow pure-Python kernel
m2 = torch.rand(256,5120, generator=_gen)
m2s = m2[:,:4]   # first 4 columns

In [32]:
# Functions
def iterate_kerenel(f, blocks, threads, *args):
    """Serially emulate a CUDA launch of kernel `f` over a grid of `blocks`, each of `threads` threads.

    `f` is called once per (block, thread) pair as `f(blockIdx, threadIdx, blockDim, *args)`,
    where `blockIdx` and `threadIdx` are zero-based `dim3` indices and `blockDim` is `threads`
    itself. `blocks` and `threads` need `.x` and `.y`; a missing `.z` is treated as 1.
    Iteration order is z (slowest), then y, then x (fastest), blocks outside threads.
    """
    from itertools import product
    def zyx(size):
        "Extents of `size` as (z, y, x); a size without `z` counts as one layer"
        return getattr(size, 'z', 1), size.y, size.x
    thread_extents = zyx(threads)
    for bz,by,bx in product(*map(range, zyx(blocks))):
        for tz,ty,tx in product(*map(range, thread_extents)):
            # Pass z explicitly: dim3's default of 1 is a size default, not a valid index
            f(dim3(bx,by,bz), dim3(tx,ty,tz), threads, *args)

def matmul_kernel(blockIdx, threadIdx, blockDim, m, n, out, h, w, k):
    "Work of one emulated thread: store one element of the `h`x`w` product of flat `m` (`h`x`k`) and flat `n` (`k`x`w`) in flat `out`"
    # Global output coordinates owned by this thread
    row = blockIdx.y*blockDim.y + threadIdx.y
    col = blockIdx.x*blockDim.x + threadIdx.x

    # Threads that land outside the output matrix do nothing
    if row < h and col < w:
        # Dot product of row `row` of m with column `col` of n, both indexed row-major
        acc = 0.0
        for i in range(k):
            acc += m[row*k+i] * n[i*w+col]
        out[row*w+col] = acc



def matmul(m, n):
    "Product of 2-d tensors `m` (h x k) and `n` (k x w), computed by running `matmul_kernel` through the emulated launch"
    (h, k), (k2, w) = m.shape, n.shape
    assert k==k2, "Size mismatch!"
    output = torch.zeros(h, w, dtype=m.dtype)
    # 16x16 threads per block, and enough blocks to cover every output element
    tpb = dim3(16,16)
    grid = dim3(cdiv(w,tpb.x), cdiv(h,tpb.y))
    # `output` is freshly allocated and contiguous, so its flatten() is a view and
    # the kernel's writes show up in `output`
    flat = (m.flatten(), n.flatten(), output.flatten())
    iterate_kerenel(matmul_kernel, grid, tpb, *flat, h, w, k)
    return output

In [34]:
# Test: the emulated kernel must agree elementwise with torch's own matmul on the small inputs
torch.all(torch.isclose(matmul(m1s, m2s), torch.matmul(m1s, m2s)))

tensor(True)

## CUDA version