In [1]:
import numpy as np

In [2]:
X = np.array([[1, 2], [3, 4], [5, 6]])
Y = np.array([[1, 2], [3, 4]])

print(X)
print(np.matmul(X, Y))
print(X[0, 0], X[1, 0], X[0, 1], X[1, 1])

print(X.shape)

[[1 2]
 [3 4]
 [5 6]]
[[ 7 10]
 [15 22]
 [23 34]]
1 3 2 4
(3, 2)


Standard method:

In [3]:
def mat_mul_bf(X, Y):
    
    # Define shape of input arrays
    X_row, X_col = X.shape
    Y_row, Y_col = Y.shape
    
    Z = np.zeros((X_row, Y_col), dtype = 'int32')
    
    if X_col != Y_row:
        return print("Error, X columns not equal to Y rows - matrix multiplication not defined.")
    
    for i in range(X_row):
        for k in range(Y_col):
            for j in range(X_col):
                Z[i, k] = Z[i, k] + X[i, j]*Y[j, k]
                
    return Z
    

In [4]:
mat_mul_bf(X, Y)

array([[ 7, 10],
       [15, 22],
       [23, 34]], dtype=int32)

Strassen method (d=2):

In [5]:
def strassen_sq_mult(X, Y):
    
    p1 = X[0, 0]*(Y[0, 1] - Y[1, 1])
    p2 = (X[0, 0] + X[0, 1])* Y[1, 1]
    p3 = (X[1, 0] + X[1, 1])* Y[0, 0]
    p4 = X[1, 1]*(Y[1, 0] - Y[0, 0])
    p5 = (X[0, 0] + X[1, 1])*(Y[0, 0] + Y[1, 1])
    p6 = (X[0, 1] - X[1, 1])*(Y[1, 0] + Y[1, 1])
    p7 = (X[0, 0] - X[1, 0])*(Y[0, 0] + Y[0, 1])
    
    res_00 = p5 + p4 - p2 + p6
    res_10 = p3 + p4
    res_01 = p1 + p2
    res_11 = p1 + p5 - p3 - p7
    
    result = np.array([[res_00, res_01], [res_10, res_11]])
    
    return result

In [6]:
strassen_sq_mult(X, Y)

array([[ 7, 10],
       [15, 22]])

In general, can compute strassen recursively on any even sized matrix:
For arbitrary matrices, need a method to divide each larger matrix into quadrants until size 2x2 is reached, can then compute the 2x2 product and combine results

The method below is defined for square matrices with dimension a power of two (2x2, 4x4, 8x8)

In [7]:
def strassen(X, Y):
    
    if (X.shape == (2, 2)) & (Y.shape == (2, 2)):
        return strassen_sq_mult(X, Y)
    
    # Define shape of input arrays
    X_row, X_col = X.shape
    Y_row, Y_col = Y.shape
    
    # Determine midpoint of each array
    X_mid_r = X_row // 2
    X_mid_c = X_col // 2
    
    Y_mid_r = Y_row // 2
    Y_mid_c = Y_col // 2
    
    # Calculate each subdivision of the input matrices
    A = X[:X_mid_r, :X_mid_c]
    B = X[:X_mid_r, X_mid_c:]
    C = X[X_mid_r:, :X_mid_c]
    D = X[X_mid_r:, X_mid_c:]
    E = Y[:Y_mid_r, :Y_mid_c]
    F = Y[:Y_mid_r, Y_mid_c:]
    G = Y[Y_mid_r:, :Y_mid_c]
    H = X[Y_mid_r:, Y_mid_c:]
    
    # Calculate the recursive products
    p1 = strassen(A, F - H)
    p2 = strassen(A + B, H) 
    p3 = strassen(C + D, E) 
    p4 = strassen(D, G - E)
    p5 = strassen(A + D, E + H) 
    p6 = strassen(B - D, E + H) 
    p7 = strassen(A - C, E + F)
    
    # Reconstruct each quadrant
    res_00 = p5 + p4 - p2 + p6
    res_10 = p3 + p4
    res_01 = p1 + p2
    res_11 = p1 + p5 - p3 - p7
    
    # Reassemble the product Z
    Z = np.row_stack((np.column_stack((res_00, res_01)), np.column_stack((res_10, res_11))))
    
    return Z
    
    

In [8]:
X = np.array([[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]])
Y = np.array([[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]])

X_asym = np.array([[1, 2], [1, 2], [1, 2], [1, 2]])

X_asym_pad = np.array([[1, 2, 0, 0], [1, 2, 0, 0], [1, 2, 0, 0], [1, 2, 0, 0]])

print(X)

[[1 2 3 4]
 [1 2 3 4]
 [1 2 3 4]
 [1 2 3 4]]


In [9]:
# Strassen method works for square matrices
print(strassen(X, Y))
print(np.matmul(X, Y))
#(X.shape == (4, 4)) & (Y.shape == (4, 4))

# How to generalise to any even matrix?
# Clue in the behaviour of asymmetric matrices padded with zeros:
print(np.matmul(X, X_asym))
print(np.matmul(X, X_asym_pad))

# This result could be trimmed, but would require a lot of unnecessary comupations
# Is there a way to return NULL in a recursive call with no work done?

[[10 20 30 40]
 [10 20 30 40]
 [10 20 30 40]
 [10 20 30 40]]
[[10 20 30 40]
 [10 20 30 40]
 [10 20 30 40]
 [10 20 30 40]]
[[10 20]
 [10 20]
 [10 20]
 [10 20]]
[[10 20  0  0]
 [10 20  0  0]
 [10 20  0  0]
 [10 20  0  0]]


  Z = np.row_stack((np.column_stack((res_00, res_01)), np.column_stack((res_10, res_11))))
