In [None]:
import numpy as np

# Algorithm 1.3.1: Strassen Matrix Multiplication

In [16]:
def strassen_matmul(A, B, n_min):
    """Multiply two square matrices with Strassen's recursive algorithm.

    Each operand is split into 2x2 blocks of half size. The product is then
    assembled from seven half-size products (P1..P7) instead of the eight a
    naive block multiplication needs. Recursion stops once the dimension is
    at most ``n_min``; at that point the product is delegated to ``A @ B``.

    Parameters
    ----------
    A, B : numpy.ndarray
        Square matrices of identical shape ``(n, n)``. An odd dimension is
        handled by zero-padding to the next even size at that level.
    n_min : int
        Crossover size (must be >= 1 whenever recursion is needed).
        Sub-problems with ``n <= n_min`` use the ordinary ``@`` product.

    Returns
    -------
    numpy.ndarray
        The ``(n, n)`` product of ``A`` and ``B``. Its dtype follows NumPy's
        promotion of the sub-products (e.g. complex stays complex) rather
        than being forced to float64.

    Raises
    ------
    ValueError
        If recursion is needed and either ``n_min < 1`` (which would never
        terminate) or ``A`` and ``B`` are not square matrices of equal shape.
    """
    n = A.shape[0]

    if n <= n_min:
        return A @ B

    if n_min < 1:
        raise ValueError(f"n_min must be >= 1, got {n_min}")
    if A.shape != (n, n) or B.shape != (n, n):
        raise ValueError(
            "strassen_matmul expects two square matrices of equal shape, "
            f"got {A.shape} and {B.shape}"
        )

    if n % 2:
        # Odd size: pad one zero row/column. Since
        # [[A, 0], [0, 0]] @ [[B, 0], [0, 0]] == [[A @ B, 0], [0, 0]],
        # cropping the padded product recovers A @ B exactly.
        A_pad = np.pad(A, ((0, 1), (0, 1)))
        B_pad = np.pad(B, ((0, 1), (0, 1)))
        return strassen_matmul(A_pad, B_pad, n_min)[:n, :n]

    m = n // 2
    A11, A12, A21, A22 = A[:m, :m], A[:m, m:], A[m:, :m], A[m:, m:]
    B11, B12, B21, B22 = B[:m, :m], B[:m, m:], B[m:, :m], B[m:, m:]

    # The seven Strassen products.
    P1 = strassen_matmul(A11 + A22, B11 + B22, n_min)
    P2 = strassen_matmul(A21 + A22, B11, n_min)
    P3 = strassen_matmul(A11, B12 - B22, n_min)
    P4 = strassen_matmul(A22, B21 - B11, n_min)
    P5 = strassen_matmul(A11 + A12, B22, n_min)
    P6 = strassen_matmul(A21 - A11, B11 + B12, n_min)
    P7 = strassen_matmul(A12 - A22, B21 + B22, n_min)

    # np.block infers the result dtype from the blocks, so nothing is
    # silently cast (the old np.empty((n, n)) was always float64).
    return np.block([
        [P1 + P4 - P5 + P7, P3 + P5],
        [P2 + P4, P1 + P3 - P2 + P6],
    ])

In [17]:
# Compare Strassen against NumPy's built-in product on a reproducible
# random problem. The generator is seeded so re-running gives the same data.
SEED = 0
N_MIN = 64  # crossover: blocks of at most this size fall back to A @ B

n = 2048    # 2048 = 64 * 2**5, so every recursive split is even
rng = np.random.default_rng(SEED)
A = rng.random((n, n))
B = rng.random((n, n))
C_ref = A @ B
C = strassen_matmul(A, B, N_MIN)

# Check results
print(np.allclose(C, C_ref))


True
