# In [None]:
import multiprocessing
import os
import time

import matplotlib.pyplot as plt
import numpy as np
from numba import njit, prange

# multiprocessing.cpu_count() reports every logical CPU on the host, which can
# exceed the CPUs this process is allowed to run on (e.g. when its CPU affinity
# is restricted). Use the affinity mask where the platform exposes it and fall
# back to the host count elsewhere (sched_getaffinity is not on every OS).
if hasattr(os, 'sched_getaffinity'):
    usable_cpus = len(os.sched_getaffinity(0))
else:
    usable_cpus = multiprocessing.cpu_count()
print('CPUs available:', usable_cpus)

@njit(parallel=True)
def big_dot(A, B):
    """Return the matrix product ``A @ B`` computed with explicit parallel loops.

    Rows of the result are distributed over threads with ``prange``. Each
    outer iteration writes only row ``i`` of ``C``, so no two iterations
    write the same element.

    Parameters
    ----------
    A : 2-D array of shape (n, m)
    B : 2-D array of shape (m, p)

    Returns
    -------
    C : float64 array of shape (n, p)

    Raises
    ------
    ValueError
        If ``A.shape[1] != B.shape[0]``. Numba does not bounds-check indexing
        by default, so a mismatch would otherwise read past the end of an array.
    """
    n, m = A.shape
    m_b, p = B.shape
    if m != m_b:
        raise ValueError("big_dot: A.shape[1] must equal B.shape[0]")
    C = np.zeros((n, p))
    for i in prange(n):
        for k in range(m):
            # i-k-j order walks rows of B and C (contiguous for C-ordered
            # arrays) instead of striding down B's columns. Each C[i, j]
            # still accumulates its terms in ascending k, starting from 0.0.
            a_ik = A[i, k]
            for j in range(p):
                C[i, j] += a_ik * B[k, j]
    return C

# Matrix size. Each float64 matrix takes 8*n*n bytes (~5 MB at n=800), so
# memory is rarely the limit; the naive triple loop is O(n^3), so runtime
# grows ~8x each time n doubles. Lower n if the run takes too long.
n = 800
A = np.random.rand(n, n)
B = np.random.rand(n, n)

# The first call to an @njit function compiles it for the argument types.
# Warm up on tiny arrays of the same type (2-D, float64, C-contiguous) so the
# timed call below measures execution only, not JIT compilation.
big_dot(np.ones((1, 1)), np.ones((1, 1)))

# perf_counter is a monotonic, high-resolution clock intended for intervals.
start = time.perf_counter()
C = big_dot(A, B)
elapsed = time.perf_counter() - start
print(f'Elapsed: {elapsed:.2f} s')

# Sanity-check the hand-written kernel against NumPy's matrix product.
print('Matches A @ B:', np.allclose(C, A @ B))

# Show only the top-left (up to) 100x100 corner of the result.
plt.imshow(C[:100, :100], cmap='inferno')
plt.title(f'{n}x{n} Matrix Product Snapshot')
plt.show()