In [7]:
import numpy as np
from numba import cuda, jit, float32
import time
import math
import torch

In [2]:
def matrix_multiply_cpu(A, B):
    """Multiply ``A`` by ``B`` on the host with ``np.dot`` and time the call.

    Args:
        A: Left operand, anything ``np.dot`` accepts.
        B: Right operand, anything ``np.dot`` accepts.

    Returns:
        tuple: ``(C, cpu_time)`` where ``C`` is ``np.dot(A, B)`` and
        ``cpu_time`` is the wall-clock duration of that call in seconds,
        measured with ``time.time()``.
    """
    started_at = time.time()
    product = np.dot(A, B)
    elapsed = time.time() - started_at
    return product, elapsed

In [3]:
from numba import cuda
import numpy as np
import time

# Define the CUDA kernel
@cuda.jit
def matrix_mult_kernel(A, B, C):
    """Naive matrix-multiply kernel: one thread per scalar product.

    The launch is indexed in three dimensions: the thread at ``(row, col,
    inner)`` adds ``A[row, inner] * B[inner, col]`` to ``C[row, col]`` with an
    atomic add. Because the kernel only ever accumulates into ``C``, the
    caller must pass a ``C`` that already holds zeros.
    """
    row, col, inner = cuda.grid(3)

    # Threads that fall outside the problem do nothing (the grid is rounded up
    # to whole blocks by the launcher).
    if row >= A.shape[0] or col >= B.shape[1] or inner >= A.shape[1]:
        return

    cuda.atomic.add(C, (row, col), A[row, inner] * B[inner, col])

def matrix_mult_gpu(A, B):
    """Multiply ``A`` by ``B`` on the GPU with the naive atomic-add kernel.

    Args:
        A: 2-D host array of shape ``(n, k)``.
        B: 2-D host array of shape ``(k, m)``.

    Returns:
        tuple: ``(C, kernel_time, total_time)`` where ``C`` is the ``(n, m)``
        float32 result copied back to the host, ``kernel_time`` is the
        wall-clock time of the kernel launch (bracketed by
        ``cuda.synchronize()``), and ``total_time`` covers the host-to-device
        transfers, the kernel, and the device-to-host copy of ``C``.

    Raises:
        AssertionError: If the inner dimensions of ``A`` and ``B`` differ.
    """
    assert A.shape[1] == B.shape[0]

    total_start_time = time.time()

    A_global_mem = cuda.to_device(A)
    B_global_mem = cuda.to_device(B)
    # The kernel accumulates into C with atomic adds, so C has to start at
    # zero. cuda.device_array() returns uninitialised memory, which would leak
    # garbage into the result; upload an explicit zero matrix instead.
    C_global_mem = cuda.to_device(
        np.zeros((A.shape[0], B.shape[1]), dtype=np.float32)
    )

    # One thread per (row, col, inner) triple; round each grid axis up so the
    # whole iteration space is covered.
    threads_per_block = (8, 8, 8)
    blocks_per_grid_x = (A.shape[0] + threads_per_block[0] - 1) // threads_per_block[0]
    blocks_per_grid_y = (B.shape[1] + threads_per_block[1] - 1) // threads_per_block[1]
    blocks_per_grid_z = (A.shape[1] + threads_per_block[2] - 1) // threads_per_block[2]
    blocks_per_grid = (blocks_per_grid_x, blocks_per_grid_y, blocks_per_grid_z)

    # Drain pending work (including the uploads above) so it is not billed to
    # the kernel, then wait for the kernel itself before stopping the clock.
    cuda.synchronize()
    kernel_start_time = time.time()
    matrix_mult_kernel[blocks_per_grid, threads_per_block](A_global_mem, B_global_mem, C_global_mem)
    cuda.synchronize()
    kernel_time = time.time() - kernel_start_time

    C = C_global_mem.copy_to_host()

    # Stop the total clock only after the result is back on the host, matching
    # how matrix_multiply_gpu reports its end-to-end time.
    total_time = time.time() - total_start_time

    return C, kernel_time, total_time

In [4]:

@cuda.jit
def fast_matmul(A, B, C):
    """Tiled matrix multiply ``C = A @ B`` using shared memory.

    Each thread computes one element ``C[x, y]``. The dot product along the
    inner dimension is processed in chunks of ``TPB``: for every chunk the
    block cooperatively stages a ``TPB x TPB`` tile of ``A`` and of ``B`` in
    shared memory, then each thread accumulates its partial product from
    those tiles.

    The kernel must be launched with ``(TPB, TPB)`` threads per block, because
    the shared tiles are indexed directly by ``threadIdx``.
    """
    TPB = 16
    # Define an array in the shared memory
    # The size and type of the arrays must be known at compile time
    sA = cuda.shared.array(shape=(TPB, TPB), dtype=float32)
    sB = cuda.shared.array(shape=(TPB, TPB), dtype=float32)

    x, y = cuda.grid(2)

    tx = cuda.threadIdx.x
    ty = cuda.threadIdx.y

    # Number of TPB-wide chunks along the inner dimension. This is derived
    # from A's shape rather than from the grid size, so non-square inputs are
    # handled correctly.
    inner = A.shape[1]
    num_tiles = (inner + TPB - 1) // TPB

    # No thread may return early: every thread in the block has to reach each
    # syncthreads() and has to fill its slot of the shared tiles, even when
    # its own (x, y) lies outside C. Out-of-range slots are zero-filled so
    # they contribute nothing to the sums.
    tmp = 0.
    for i in range(num_tiles):
        a_col = ty + i * TPB
        b_row = tx + i * TPB

        # Preload data into shared memory, guarding every global read
        if x < A.shape[0] and a_col < inner:
            sA[tx, ty] = A[x, a_col]
        else:
            sA[tx, ty] = float32(0.)

        if b_row < B.shape[0] and y < B.shape[1]:
            sB[tx, ty] = B[b_row, y]
        else:
            sB[tx, ty] = float32(0.)

        # Wait until all threads finish preloading
        cuda.syncthreads()

        # Computes partial product on the shared memory
        for j in range(TPB):
            tmp += sA[tx, j] * sB[j, ty]

        # Wait until all threads finish computing before the tiles are
        # overwritten by the next iteration
        cuda.syncthreads()

    # Only threads that map to a real element of C write a result
    if x < C.shape[0] and y < C.shape[1]:
        C[x, y] = tmp

def matrix_multiply_gpu(A, B):
    """Multiply ``A`` by ``B`` on the GPU with the shared-memory tiled kernel.

    Args:
        A: 2-D host array of shape ``(n, k)``.
        B: 2-D host array of shape ``(k, m)``.

    Returns:
        tuple: ``(C, kernel_time, gpu_time)`` where ``C`` is the ``(n, m)``
        float32 result on the host, ``kernel_time`` is the wall-clock time of
        the ``fast_matmul`` launch (bracketed by ``cuda.synchronize()``), and
        ``gpu_time`` covers the transfers in both directions plus the kernel.

    Raises:
        ValueError: If the inner dimensions of ``A`` and ``B`` differ.
    """
    # Must match the tile size hard-coded inside fast_matmul.
    TPB = 16
    n, k = A.shape
    k_b, m = B.shape
    if k != k_b:
        raise ValueError(
            f"Incompatible shapes for matrix multiplication: {A.shape} and {B.shape}"
        )

    start = time.time()
    A_device = cuda.to_device(A)
    B_device = cuda.to_device(B)
    C_device = cuda.device_array((n, m), dtype=np.float32)

    threads_per_block = (TPB, TPB)
    # Integer ceil-division: enough blocks to cover every row/column of C.
    blocks_per_grid_x = (m + TPB - 1) // TPB
    blocks_per_grid_y = (n + TPB - 1) // TPB

    # Drain pending work (including the uploads above) so it is not billed to
    # the kernel.
    cuda.synchronize()
    kernel_start_time = time.time()
    # The kernel maps its first grid axis to rows of C, hence (rows, cols).
    fast_matmul[(blocks_per_grid_y, blocks_per_grid_x), threads_per_block](A_device, B_device, C_device)
    cuda.synchronize()  # Wait for all GPU activity to finish
    kernel_time = time.time() - kernel_start_time

    C = C_device.copy_to_host()
    gpu_time = time.time() - start
    return C, kernel_time, gpu_time

In [5]:
def random_matrix(n, seed=None):
    """Create two independent ``n x n`` float32 matrices of standard-normal samples.

    Args:
        n: Number of rows and columns of each matrix.
        seed: Optional seed forwarded to ``np.random.default_rng``. With the
            default ``None`` the generator is seeded from OS entropy, so every
            call yields different matrices; pass an integer to make a run
            reproducible.

    Returns:
        tuple: ``(A, B)``, two ``(n, n)`` arrays of dtype float32.
    """
    # One generator for both draws: with a fixed seed A and B are still
    # different matrices, and the pair is reproducible.
    rng = np.random.default_rng(seed)
    A = rng.standard_normal(size=(n, n), dtype=np.float32)
    B = rng.standard_normal(size=(n, n), dtype=np.float32)
    return A, B

In [None]:
def matrix_multiply_torch(A, B, device="cuda:0"):
    """Multiply ``A`` by ``B`` with ``torch.matmul`` and time the operation.

    Args:
        A: Left operand as a NumPy array.
        B: Right operand as a NumPy array.
        device: Torch device (string or ``torch.device``) to run on. Defaults
            to ``"cuda:0"``.

    Returns:
        tuple: ``(C, torch_time, kernel_time)`` where ``C`` is the product as
        a NumPy array, ``torch_time`` is the end-to-end wall-clock time
        (transfers in both directions plus the multiply), and ``kernel_time``
        is the time of the ``torch.matmul`` call alone. Note that the total
        comes before the kernel time in this tuple.
    """
    start_time = time.time()
    device = torch.device(device)
    on_cuda = device.type == "cuda"

    A_device = torch.from_numpy(A).to(device)
    B_device = torch.from_numpy(B).to(device)

    # CUDA work is queued asynchronously. Without a synchronize on each side
    # of the matmul, the timer would only capture the launch overhead and the
    # uploads could bleed into the measured interval.
    if on_cuda:
        torch.cuda.synchronize(device)
    kernel_start_time = time.time()
    C_device = torch.matmul(A_device, B_device)
    if on_cuda:
        torch.cuda.synchronize(device)
    kernel_time = time.time() - kernel_start_time

    C = C_device.cpu().numpy()
    torch_time = time.time() - start_time
    return C, torch_time, kernel_time

In [6]:
# Benchmark driver: build two 10000 x 10000 float32 matrices and time the
# NumPy, Numba and PyTorch multiplies on the same inputs.
A, B = random_matrix(10000)
print(A.shape)
print(B.shape)

# CPU baseline. total_cpu_time wraps the whole matrix_multiply_cpu call; the
# cpu_time value the function returns is not used in the report below.
cpu_start=time.time()
# NOTE(review): total_kernel_time is assigned but never read in this cell.
total_kernel_time=0.0
C_cpu, cpu_time = matrix_multiply_cpu(A, B)
total_cpu_time=time.time()-cpu_start
# NOTE(review): gpu_start is assigned but never read in this cell; the GPU
# timings printed below come from the values returned by the functions.
gpu_start=time.time()
# NOTE(review): this is the first fast_matmul launch in the visible cells, so
# kernel_time presumably includes Numba's JIT compilation — confirm, or add a
# warm-up call before timing.
C, kernel_time, gpu_time =  matrix_multiply_gpu(A, B)
# matrix_multiply_torch returns (result, total time, kernel time): the total
# comes first, unlike matrix_multiply_gpu.
C_torch, torch_time, torch_kernel_time = matrix_multiply_torch(A,B)
print(f"CPU time: {total_cpu_time}s, GPU time: {gpu_time}s, kernel time: {kernel_time}s, Torch time: {torch_time}s, Torch kernel time: {torch_kernel_time}s")

(10000, 10000)
(10000, 10000)
[[ -10.513769   -2.85761    35.93156  ...   86.984406  -62.66192
    30.3152  ]
 [  43.839413  -67.53968  -142.61398  ...   78.84571    17.246653
   -37.511566]
 [ 209.74535   -38.94636   -57.299988 ...   97.86541   126.144646
    75.9936  ]
 ...
 [ -24.119297  -59.03662  -105.70007  ...  162.62907    69.89179
   -10.81326 ]
 [  37.232277  -97.58978  -107.391754 ...  -89.335      87.85077
   -55.217102]
 [-114.6289    -40.85752     8.5603   ... -104.24077  -159.2414
   -87.373566]]
[[ -10.513738    -2.8576562   35.931572  ...   86.9844     -62.661915
    30.315174 ]
 [  43.83945    -67.539665  -142.61395   ...   78.84573     17.246628
   -37.511547 ]
 [ 209.74533    -38.946365   -57.299965  ...   97.8654     126.144684
    75.99359  ]
 ...
 [ -24.119337   -59.03658   -105.700066  ...  162.62909     69.891785
   -10.813235 ]
 [  37.232243   -97.589745  -107.39173   ...  -89.335       87.850815
   -55.217133 ]
 [-114.628876   -40.85749      8.560303  ... -10