In [1]:
import numpy as np
from numba import cuda, jit, float32
import time
import math

In [2]:
def matrix_multiply_cpu(A, B):
    cpu_start=time.time()
   
    C=np.dot(A, B)
    cpu_time=time.time()-cpu_start
    return C, cpu_time

In [3]:
from numba import cuda
import numpy as np
import time

# Define the CUDA kernel
@cuda.jit
def matrix_mult_kernel(A, B, C):
    i, j, k = cuda.grid(3)
    if i < A.shape[0] and j < B.shape[1] and k < A.shape[1]:
        product = A[i, k] * B[k, j]
        cuda.atomic.add(C, (i, j), product)

def matrix_mult_gpu(A, B):
    assert A.shape[1] == B.shape[0]

    total_start_time = time.time()

    A_global_mem = cuda.to_device(A)
    B_global_mem = cuda.to_device(B)
    C_global_mem = cuda.device_array((A.shape[0], B.shape[1]), dtype=np.float32)

    threads_per_block = (8, 8, 8)
    blocks_per_grid_x = (A.shape[0] + threads_per_block[0] - 1) // threads_per_block[0]
    blocks_per_grid_y = (B.shape[1] + threads_per_block[1] - 1) // threads_per_block[1]
    blocks_per_grid_z = (A.shape[1] + threads_per_block[2] - 1) // threads_per_block[2]
    blocks_per_grid = (blocks_per_grid_x, blocks_per_grid_y, blocks_per_grid_z)

    cuda.synchronize()
    kernel_start_time = time.time()
    matrix_mult_kernel[blocks_per_grid, threads_per_block](A_global_mem, B_global_mem, C_global_mem)
    cuda.synchronize()
    
    kernel_time = time.time() - kernel_start_time
    
    total_time = time.time() - total_start_time

    C = C_global_mem.copy_to_host()

    return C, kernel_time, total_time

In [4]:

@cuda.jit
def fast_matmul(A, B, C):
    TPB = 16
    # Define an array in the shared memory
    # The size and type of the arrays must be known at compile time
    sA = cuda.shared.array(shape=(TPB, TPB), dtype=float32)
    sB = cuda.shared.array(shape=(TPB, TPB), dtype=float32)

    x, y = cuda.grid(2)

    tx = cuda.threadIdx.x
    ty = cuda.threadIdx.y
    bpg = cuda.gridDim.x    # blocks per grid

    if x >= C.shape[0] and y >= C.shape[1]:
        # Quit if (x, y) is outside of valid C boundary
        return

    # Each thread computes one element in the result matrix.
    # The dot product is chunked into dot products of TPB-long vectors.
    tmp = 0.
    for i in range(bpg):
        # Preload data into shared memory
        sA[tx, ty] = A[x, ty + i * TPB]
        sB[tx, ty] = B[tx + i * TPB, y]

        # Wait until all threads finish preloading
        cuda.syncthreads()

        # Computes partial product on the shared memory
        for j in range(TPB):
            tmp += sA[tx, j] * sB[j, ty]

        # Wait until all threads finish computing
        cuda.syncthreads()

    C[x, y] = tmp

def matrix_multiply_gpu(A, B):
    TPB = 16
    n, k = A.shape
    k, m = B.shape
    start = time.time()
    A_device = cuda.to_device(A)
    B_device = cuda.to_device(B)
    C_device = cuda.device_array((n, m), dtype=np.float32)

    threads_per_block = (TPB, TPB)
    blocks_per_grid_x = math.ceil(m / TPB)
    blocks_per_grid_y = math.ceil(n / TPB)

    cuda.synchronize()
    kernel_start_time = time.time()
    fast_matmul[(blocks_per_grid_y, blocks_per_grid_x), threads_per_block](A_device, B_device, C_device)
    cuda.synchronize()  # Wait for all GPU activity to finish
    kernel_time = time.time()-kernel_start_time

    C = C_device.copy_to_host()
    gpu_time = time.time() - start
    return C, kernel_time, gpu_time

In [5]:
def random_matrix(n):
    A = np.random.default_rng().standard_normal(size=(n,n), dtype='float32')
    B = np.random.default_rng().standard_normal(size=(n,n), dtype='float32')
    return A, B

In [6]:
A, B = random_matrix(10000)
print(A.shape)
print(B.shape)

cpu_start=time.time()
total_kernel_time=0.0
C_cpu, cpu_time = matrix_multiply_cpu(A, B)
total_cpu_time=time.time()-cpu_start
gpu_start=time.time()
C, kernel_time, gpu_time =  matrix_multiply_gpu(A, B)
total_gpu_time=time.time()-gpu_start
print(C_cpu)
print(C)
print(np.sum(C-C_cpu))
print(f"CPU time: {total_cpu_time}s, GPU time: {total_gpu_time}s, kernel time: {kernel_time}s")

(10000, 10000)
(10000, 10000)
[[  57.37409   -119.11281     -9.604512  ...   63.047615    20.997852
   -22.503843 ]
 [  98.83168    -32.27753     45.27719   ...   41.65147   -127.59444
   -27.816784 ]
 [ 117.028015  -233.14578    -41.142445  ...   -1.0513175  -45.53106
   -40.725536 ]
 ...
 [-184.3483      99.350555   115.54753   ...    9.32468    -60.251015
    47.758987 ]
 [-158.54546    100.58536   -165.96198   ... -124.94373     -4.197003
   142.48087  ]
 [ 121.89091    194.1341      88.67832   ... -165.17607    -73.865036
   -17.21402  ]]
[[   57.37404   -119.112816    -9.60455  ...     4.952976   219.05614
    -35.121243]
 [   93.107285   -87.43592    -72.75129  ...   151.59071   -505.03323
    -21.697716]
 [   37.00121    134.99147     23.175694 ...  -133.4305     466.2569
    166.2491  ]
 ...
 [ 4210.505    -1120.3566    3348.48     ...    54.65859    -12.757178
     19.127314]
 [ -251.5593     875.8597     194.32184  ...  -242.38495    396.62753
    263.44785 ]
 [ 1263.6487   