In [1]:
import torch, os, math, gzip, pickle
import matplotlib.pyplot as plt
from urllib.request import urlretrieve
from pathlib import Path

from torch import tensor
import torchvision as tv
import torchvision.transforms.functional as tvf
from torchvision import io
from torch.utils.cpp_extension import load_inline

## Matrix Multiplication

2D matrix multiplication: multiply an N×M matrix `A` by an M×N matrix `B` to produce an N×N matrix `C`.

In [2]:
# Problem size: A is N x M and B is M x N, so their product is N x N.
N = 50
M = 75

In [3]:
# Random operands, plus a zero-filled float32 accumulator for the product.
A = torch.randn((N, M))
B = torch.randn((M, N))
C = torch.zeros((N, N), dtype=torch.float32)

In [None]:
# Textbook triple loop: C[row, col] accumulates the dot product of A's row and B's column.
# NOTE: this adds into the existing C, so re-running the cell without re-creating C
# accumulates on top of the previous result.
for row in range(N):
  for col in range(N):
    for inner in range(M):
      C[row, col] = C[row, col] + A[row, inner] * B[inner, col]

In [None]:
def mat_mul(A, B):
  """Multiply two 2-D tensors with an explicit triple loop (slow reference version).

  The loop bounds and the output shape are taken from the arguments themselves,
  not from notebook-level globals, so the function keeps working after the
  problem size is changed in a later cell.

  Args:
    A: 2-D tensor of shape (rows, inner).
    B: 2-D tensor of shape (inner, cols).

  Returns:
    A new tensor of shape (rows, cols), allocated with `torch.zeros` defaults.

  Raises:
    ValueError: if either input is not 2-D or the inner dimensions differ.
  """
  if A.dim() != 2 or B.dim() != 2:
    raise ValueError(f"mat_mul expects 2-D tensors, got {A.dim()}-D and {B.dim()}-D")
  n_rows, n_inner = A.shape
  b_inner, n_cols = B.shape
  if n_inner != b_inner:
    raise ValueError(
      f"inner dimensions must match: A is {tuple(A.shape)}, B is {tuple(B.shape)}")

  C = torch.zeros(n_rows, n_cols)
  for i in range(n_rows):
    for j in range(n_cols):
      # Dot product of row i of A with column j of B.
      for k in range(n_inner):
        C[i,j] += A[i,k] * B[k,j]
  return C

In [None]:
%%timeit
# Time the pure-Python triple loop; this is the CPU baseline for the CUDA kernel timed later.
# NOTE(review): names assigned inside a %%timeit cell are presumably not kept in the
# notebook namespace, so `C` is unchanged after this cell — confirm against the IPython docs.
C = mat_mul(A, B)

30.3 s ± 343 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
# Show the dimensions of C; C was created as an N x N tensor.
# NOTE(review): the saved output shows [100, 100], which does not match N = 50 above —
# presumably left over from an earlier run; re-run the notebook to refresh it.
C.shape

torch.Size([100, 100])

## CUDA setup

In [4]:
# Ask CUDA for blocking kernel launches so errors surface at the call that caused them.
os.environ.update(CUDA_LAUNCH_BLOCKING='1')

# When a GPU is visible, target extension builds at exactly its compute capability.
if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    os.environ['TORCH_CUDA_ARCH_LIST'] = f"{major}.{minor}"
    print(f"Setting TORCH_CUDA_ARCH_LIST to: {os.environ['TORCH_CUDA_ARCH_LIST']}")

Setting TORCH_CUDA_ARCH_LIST to: 7.5


In [5]:
# Install helpers for the inline-extension workflow: wurlitzer is loaded in the next cell;
# ninja is presumably the build backend used by load_inline — TODO confirm.
# NOTE(review): versions are unpinned, so a fresh environment may pick up different releases.
%pip install -q wurlitzer ninja

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/422.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m422.8/422.8 kB[0m [31m26.4 MB/s[0m eta [36m0:00:00[0m
[?25h

In [6]:
# Enable the wurlitzer IPython extension installed above; per its documentation it forwards
# C-level stdout/stderr (e.g. printf from compiled code) into the notebook output.
%load_ext wurlitzer

In [14]:
def load_cuda(cuda_src, cpp_src, funcs, opt=False, verbose=False):
    """Build an inline extension from one CUDA source and one C++ source.

    Args:
        cuda_src: CUDA source text, passed as the single entry of `cuda_sources`.
        cpp_src: C++ source text, passed as the single entry of `cpp_sources`.
        funcs: function names forwarded as `functions`.
        opt: when truthy, pass "-O2" as an extra CUDA compiler flag.
        verbose: forwarded to `load_inline`.

    Returns:
        Whatever `load_inline` returns for the extension named "inline_ext".
    """
    cuda_flags = []
    if opt:
        cuda_flags.append("-O2")
    return load_inline(
        name="inline_ext",
        cpp_sources=[cpp_src],
        cuda_sources=[cuda_src],
        functions=funcs,
        extra_cuda_cflags=cuda_flags,
        verbose=verbose,
    )

In [15]:
# Shared C++/CUDA preamble that is prepended to the kernel source: the torch extension and
# CUDA-exception headers, TORCH_CHECK-based macros that require a tensor to be on a CUDA
# device and contiguous, and `cdiv`, a ceiling division used when sizing the launch grid.
cuda_begin = r'''
#include <torch/extension.h>
#include <stdio.h>
#include <c10/cuda/CUDAException.h>

#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

inline unsigned int cdiv(unsigned int a, unsigned int b) { return (a + b - 1) / b;}
'''

In [18]:
# CUDA source for a naive float matrix multiply: each thread computes one element of C.
# Fixes relative to the earlier version (which was only correct for square outputs):
#   * the bounds check compares the row index with m and the column index with o;
#   * the grid's x dimension covers columns (o) and its y dimension covers rows (m);
#   * the host wrapper rejects mismatched shapes and skips the launch for empty outputs.
cuda_src = cuda_begin + r'''

extern "C" __global__ void mat_mul_kernel(
  float *A, float *B, float *C,
  int m, int n, int o)
  {
    /*
    A: m x n
    B: n x o
    C: m x o
    All three are indexed row-major; the host wrapper requires contiguous inputs.
    */
    int r = blockIdx.y * blockDim.y + threadIdx.y;  // row of C, valid in [0, m)
    int c = blockIdx.x * blockDim.x + threadIdx.x;  // column of C, valid in [0, o)

    // Edge blocks contain threads past the matrix border: rows are bounded by m, columns by o.
    if (r < m && c < o) {
      float val = 0;
      for (int i = 0; i < n; i++) {
        val += A[r * n + i] * B[i * o + c];
      }
      C[r * o + c] = val;
    }
}


torch::Tensor mat_mul(torch::Tensor A, torch::Tensor B) {
    CHECK_INPUT(A);
    CHECK_INPUT(B);
    TORCH_CHECK(A.dim() == 2 && B.dim() == 2, "mat_mul expects 2-D tensors");
    TORCH_CHECK(A.size(1) == B.size(0),
                "inner dimensions must match: A is ", A.size(0), "x", A.size(1),
                ", B is ", B.size(0), "x", B.size(1));
    int m = A.size(0);
    int n = A.size(1);
    int o = B.size(1);

    torch::Tensor C = torch::zeros({m, o}, A.options());
    // An empty result needs no kernel; a grid with a zero dimension is an invalid launch.
    if (m == 0 || o == 0) return C;

    dim3 tpb(32,32);
    // The kernel derives columns from x and rows from y, so x must span o and y must span m.
    dim3 blocks(cdiv(o, tpb.x), cdiv(m, tpb.y));

    mat_mul_kernel<<<blocks, tpb>>>(
        A.data_ptr<float>(), B.data_ptr<float>(),
        C.data_ptr<float>(),  m, n, o);

    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return C;
    }

'''

In [23]:
# C++ declaration of the host wrapper, then compile the extension and expose `mat_mul`.
cpp_src = "torch::Tensor mat_mul(torch::Tensor m, torch::Tensor n);"
module = load_cuda(cuda_src=cuda_src, cpp_src=cpp_src, funcs=['mat_mul'])

In [24]:
# List the compiled module's attributes; the requested `mat_mul` binding should be among them.
dir(module)

['__doc__',
 '__file__',
 '__loader__',
 '__name__',
 '__package__',
 '__spec__',
 'mat_mul']

In [25]:
# Contiguous copies of both operands on the GPU (the extension rejects anything else).
Ac, Bc = (t.contiguous().cuda() for t in (A, B))

In [26]:
# Check the shapes of the device copies before calling the kernel.
Ac.shape, Bc.shape

(torch.Size([50, 75]), torch.Size([75, 50]))

In [27]:
%%time
# Run the CUDA matrix multiply on the device copies and time this single call.
# NOTE(review): this is the first call into the extension, so the measurement presumably
# includes one-time CUDA start-up cost — confirm by timing a second call.
C = module.mat_mul(Ac, Bc)

CPU times: user 852 µs, sys: 4.14 ms, total: 4.99 ms
Wall time: 13.4 ms


In [29]:
# Larger problem size for the GPU run: A is N x M, B is M x N.
N = 1150
M = 750

In [31]:
# Re-create the operands at the new size, then make contiguous GPU copies of them.
A = torch.randn((N, M))
B = torch.randn((M, N))
C = torch.zeros((N, N), dtype=torch.float32)
Ac, Bc = (t.contiguous().cuda() for t in (A, B))

In [32]:
%%time
# Time the CUDA matrix multiply on the larger operands; the result replaces C.
C = module.mat_mul(Ac, Bc)

CPU times: user 5.56 ms, sys: 936 µs, total: 6.49 ms
Wall time: 6.56 ms
