### Basic Imports

In [45]:
import numpy as np 


In [46]:
log_size = 8
size = 2**log_size

def create_data(size=2**10, random_seed=0):
    np.random.seed(random_seed)
    A = np.random.rand(size,size)
    B = np.random.rand(size,size)
    return A, B

In [47]:
A, B = create_data(size=size)

In [48]:
# Naive implementation
def naive_multiply(A, B):
    C = np.zeros_like(A)
    for i in range(A.shape[0]):
        for j in range(A.shape[1]):
            for k in range(A.shape[0]):
                C[i,j] += A[i,k] * B[k,j]
    return C

In [49]:
%timeit -o C_naive = naive_multiply(A, B)

8.62 s ± 177 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


<TimeitResult : 8.62 s ± 177 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)>

In [54]:
C_naive

array([[65.94820735, 69.09817932, 62.78252215, ..., 60.1436441 ,
        62.86313174, 64.71917913],
       [63.18062386, 69.09843956, 63.14943823, ..., 58.63464759,
        61.3564899 , 66.82890837],
       [64.22648084, 67.36524163, 62.52753755, ..., 58.58248039,
        61.18760561, 62.958112  ],
       ...,
       [65.25257151, 68.92121804, 64.55013279, ..., 64.08365403,
        64.81985536, 68.06049955],
       [67.08241328, 68.43813322, 65.41929326, ..., 60.81540601,
        61.73456743, 67.31183199],
       [63.81425552, 66.33777535, 62.62621379, ..., 61.18557896,
        63.87677378, 68.64852561]])

In [55]:
def divide_matrix_four_parts(A):
    n = A.shape[0]//2
    A11 = A[:n,:n]
    A12 = A[:n,n:]
    A21 = A[n:,:n]
    A22 = A[n:,n:]
    return A11, A12, A21, A22

In [56]:
A11, A12, A21, A22 = divide_matrix_four_parts(A)

In [57]:
A.shape

(256, 256)

In [58]:
A11.shape

(128, 128)

In [60]:
def strassen_multiply(A, B, threshold=32):
    # if A and B are threshold X threshold matrices directly multiply them
    if A.shape[0] <= threshold:
        return naive_multiply(A, B)
    else:
        A11, A12, A21, A22 = divide_matrix_four_parts(A)
        B11, B12, B21, B22 = divide_matrix_four_parts(B)
        M1 = strassen_multiply(A11 + A22, B11 + B22)
        M2 = strassen_multiply(A21 + A22, B11)
        M3 = strassen_multiply(A11, B12 - B22)
        M4 = strassen_multiply(A22, B21 - B11)
        M5 = strassen_multiply(A11 + A12, B22)
        M6 = strassen_multiply(A21 - A11, B11 + B12)
        M7 = strassen_multiply(A12 - A22, B21 + B22)
        C11 = M1 + M4 - M5 + M7
        C12 = M3 + M5
        C21 = M2 + M4
        C22 = M1 - M2 + M3 + M6
        C = np.vstack((np.hstack((C11, C12)), np.hstack((C21, C22))))
        return C

In [66]:
timeit_strassen_32 = %timeit -o -n 5 -q C_strassen_32 = strassen_multiply(A, B, 32)

KeyboardInterrupt: 

In [63]:
%timeit -o C_strassen_16 = strassen_multiply(A, B, 16)

5.76 s ± 21.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


<TimeitResult : 5.76 s ± 21.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)>

In [64]:
%timeit -o C_strassen_8 = strassen_multiply(A, B, 8)

5.81 s ± 134 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


<TimeitResult : 5.81 s ± 134 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)>

In [65]:
np.allclose(C_naive, C_strassen_32)


NameError: name 'C_strassen_32' is not defined