In [None]:
!git clone https://github.com/telexyz/kim.git
%cd kim
!pip3 install --upgrade --no-deps git+https://github.com/dlsys10714/mugrade.git
!pip3 install pytest numpy numdifftools pybind11 requests
# !nvidia-smi

In [None]:
# Bring the checkout up to date with its upstream branch.
!git pull
# Run the default make target; only if it succeeds, run the pytest cases whose
# names match both "matmul" and "cuda" (-k keyword expression), verbosely.
!make && python3 -m pytest -v -k "matmul and cuda"

## CUDA MatMul Performance

In [53]:
import numpy as np
import torch
import sys
sys.path.append("../")

from kim import backend_ndarray as nd
# Benchmark operands: two k x k standard-normal float64 matrices (k = 2048),
# plus device-side copies for every implementation that is timed below.
k = 2**11
# Seed the legacy global RNG so a Restart & Run All benchmarks the same data.
np.random.seed(0)
# numpy matmul
A_np = np.random.randn(k, k)
B_np = np.random.randn(k, k)
# cpu matmul
A_cpu = nd.array(A_np, device=nd.cpu())
B_cpu = nd.array(B_np, device=nd.cpu())
# cuda simple matmul: (k-2) x (k-2) operands.
# NOTE(review): the trimmed shapes presumably steer kim to a different CUDA
# kernel than the full k x k case - confirm against the kim CUDA backend.
A_cuda_simple = nd.array(A_np[:-2,:-2], device=nd.cuda())
B_cuda_simple = nd.array(B_np[:-2,:-2], device=nd.cuda())
# cuda tiled matmul: (k x (k-4)) @ ((k-4) x k), i.e. only the inner dimension
# is trimmed (same caveat as above about why - confirm).
A_cuda_tiled = nd.array(A_np[:,:-4], device=nd.cuda())
B_cuda_tiled = nd.array(B_np[:-4,:], device=nd.cuda())
# cuda shared memory tiled matmul: full k x k operands
A_cuda_best = nd.array(A_np, device=nd.cuda())
B_cuda_best = nd.array(B_np, device=nd.cuda())
# torch: float32 copies of the same data on the CUDA device
A_torch = torch.tensor(A_np, dtype=torch.float32, device=torch.device("cuda"))
B_torch = torch.tensor(B_np, dtype=torch.float32, device=torch.device("cuda"))

In [None]:
%%timeit
A_np = np.random.randn(k, k)
B_np = np.random.randn(k, k)
A_np @ B_np

In [56]:
%%timeit
# Times only the product of the (k-2) x (k-2) device arrays built in the setup
# cell; no array construction happens inside the timed body.
A_cuda_simple @ B_cuda_simple

226 ms ± 9.83 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [62]:
%%timeit
A_cuda_tiled = nd.array(A_np[:,:-4], device=nd.cuda())
B_cuda_tiled = nd.array(B_np[:-4,:], device=nd.cuda())
A_cuda_tiled @ B_cuda_tiled

30.2 ms ± 312 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [61]:
%%timeit
A_cuda_best = nd.array(A_np, device=nd.cuda())
B_cuda_best = nd.array(B_np, device=nd.cuda())
A_cuda_best @ B_cuda_best

20.9 ms ± 800 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [60]:
%%timeit
A_torch = torch.tensor(A_np, dtype=torch.float32, device=torch.device("cuda"))
B_torch = torch.tensor(B_np, dtype=torch.float32, device=torch.device("cuda"))
A_torch @ B_torch

9.99 ms ± 2.06 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
