### Numpy Implementation of Strassen Matrix Multiplication

In [74]:
import numpy as np
'''
Numpy Implementation
'''
def split_mat(X,mdp):
    '''
    Matrix Split.
    Input:
        X: nxn matrix
        mdp: middle of of the matrix
    Output:
        4 sub matrices in the format of:
            a b
            c c
    '''
    a = X[:mdp,:mdp]
    b = X[:mdp,mdp:]
    c = X[mdp:,:mdp]
    d = X[mdp:,mdp:]
    return a,b,c,d


def np_strassen(X,Y):
    '''
    Input:
        nxn integer matrices X and Y
    Output:
        Z = X.Y
    Assume:
        n is power of 2     
    '''
    n = len(X)
    assert n==X.shape[0]==X.shape[1]==Y.shape[0] == Y.shape[1]
    
    
    if n==1:
        return X*Y
    
    
    
    mdp = len(X)//2 #middle point
    # Divide into sub-matrices
    A,B,C,D = split_mat(X,mdp)
    E,F,G,H = split_mat(Y,mdp)
    p1 = np_strassen(A,F-H)
    p2 = np_strassen(A+B,H)
    p3 = np_strassen(C+D,E)
    p4 = np_strassen(D,G-E)
    p5 = np_strassen(A+D,E+H)
    p6 = np_strassen(B-D,G+H)
    p7 = np_strassen(A-C,E+F)    
    # Strassen:
    upper_left = p5+p4-p2+p6
    upper_right = p1+p2
    lower_left = p3+p4
    lower_right = p1+p5-p3-p7
    # Construct New:
    final_output = np.zeros((len(upper_left)*2,len(upper_left)*2)) 
    final_output[:len(upper_left),:len(upper_left)] = upper_left
    final_output[len(upper_left):,:len(upper_left)] = lower_left
    final_output[:len(upper_left),len(upper_left):] = upper_right
    final_output[len(upper_left):,len(upper_left):] = lower_right
    
    return final_output

In [75]:
X = np.array([[1,2,1,1],[3,4,1,2],[5,2,3,1],[4,2,2,5]])
Y = np.array([[5,6,10,1],[1,1,7,8],[2,1,2,1],[3,5,2,-1]])

Test0:

In [76]:
np_strassen(X,Y)

array([[12., 14., 28., 17.],
       [27., 33., 64., 34.],
       [36., 40., 72., 23.],
       [41., 53., 68., 17.]])

In [77]:
# to validate:
np.dot(X,Y)

array([[12, 14, 28, 17],
       [27, 33, 64, 34],
       [36, 40, 72, 23],
       [41, 53, 68, 17]])

Test1:

In [78]:
X = np.array([[2,3],[9,4]])
Y = np.array([[5,7],[1,8]])

In [79]:
np_strassen(X,Y)

array([[13., 38.],
       [49., 95.]])

In [80]:
# to validate:
np.dot(X,Y)

array([[13, 38],
       [49, 95]])

### Native List Implementation of Strassen Matrix Multiplication

In [81]:
'''
Native List Py Implementation:
'''

def split_mat(X,mdp):
    '''
    Matrix Split.
    Input:
        X: nxn matrix
        mdp: middle of of the matrix
    Output:
        4 sub matrices in the format of:
            a b
            c c
    '''
    a = [i[:mdp] for i in X[:mdp]]
    b = [i[mdp:] for i in X[:mdp]]
    c = [i[:mdp] for i in X[mdp:]]
    d = [i[mdp:] for i in X[mdp:]]
    return a,b,c,d

def reconstruct(x):
    '''
    reconstruct submatrices for the output
    '''
    plen = int(len(x)/2)
    if plen==0:
        return x
    
    x_new = [[0 for i in range(plen)] for  j in range(plen)]
    cnt = 0
    for i in range(plen):
        for j in range(plen):
            x_new[i][j] = x[cnt]
            cnt+=1
    return x_new


def matrix_addition(X,Y):

    assert len(X) == len(Y)
    assert [len(i) for i in X] == [len(i) for i in Y]

    nrow = len(X)
    ncol = list(set([len(i) for i in X]))[0]
    output_mat = []
    for i in range(nrow):
        temp_row = []
        for j in range(ncol):
            temp_row.append(X[i][j] + Y[i][j])
        output_mat.append(temp_row)
    return output_mat
            
def matrix_subtraction(X,Y):
    assert len(X) == len(Y)
    assert [len(i) for i in X] == [len(i) for i in Y]
    nrow = len(X)
    ncol = list(set([len(i) for i in X]))[0]
    output_mat = []
    
    for i in range(nrow):
        temp_row = []
        for j in range(ncol):
            temp_row.append(X[i][j] - Y[i][j])
        output_mat.append(temp_row)
    return output_mat


def list_strassen(X,Y):
    '''
    Input:
        nxn integer matrices X and Y
    Output:
        Z = X.Y
    Assume:
        n is power of 2     
    '''
    n = len(X)
    assert n==len(X)== len(Y)
    
    if n==1:
        return [[X[0][0]*Y[0][0]]]
    
    
    
    mdp = len(X)//2 #middle point
    # Divide into sub-matrices
    A,B,C,D = split_mat(X,mdp)
    E,F,G,H = split_mat(Y,mdp)
    
    p1 = list_strassen(A,matrix_subtraction(F,H))
    p2 = list_strassen(matrix_addition(A,B),H)
    p3 = list_strassen(matrix_addition(C,D),E)
    p4 = list_strassen(D,matrix_subtraction(G,E))
    p5 = list_strassen(matrix_addition(A,D),matrix_addition(E,H))
    p6 = list_strassen(matrix_subtraction(B,D),matrix_addition(G,H))
    p7 = list_strassen(matrix_subtraction(A,C),matrix_addition(E,F))    
    # Strassen:
    upper_left =  matrix_addition(matrix_subtraction(matrix_addition(p5,p4),p2),p6)
    upper_right = matrix_addition(p1,p2)
    lower_left = matrix_addition(p3,p4)
    lower_right = matrix_subtraction(matrix_subtraction(matrix_addition(p1,p5),p3),p7) 
    
    
    res = []
    
    for row in range(len(upper_left)):
        res.append(upper_left[row] + upper_right[row])
    
    for row in range(len(upper_left)):
   
        res.append(lower_left[row] + lower_right[row])

    return res



    

In [82]:
X =  [[1,2,3,4],[5,6,7,8],[9,10,11,12],[13,14,15,16]]
Y =  [[2,2,3,4],[5,4,7,2],[92,2,1,5],[5,2,7,0]]

list_strassen(X,Y)

[[308, 24, 48, 23],
 [724, 64, 120, 67],
 [1140, 104, 192, 111],
 [1556, 144, 264, 155]]

In [83]:
# To validate:
np.dot(np.array(X),np.array(Y))

array([[ 308,   24,   48,   23],
       [ 724,   64,  120,   67],
       [1140,  104,  192,  111],
       [1556,  144,  264,  155]])

In [84]:
'''
Native List Py Implementation:
'''

def split_mat(X,mdp):
    '''
    Matrix Split.
    Input:
        X: nxn matrix
        mdp: middle of of the matrix
    Output:
        4 sub matrices in the format of:
            a b
            c c
    '''
    a = [i[:mdp] for i in X[:mdp]]
    b = [i[mdp:] for i in X[:mdp]]
    c = [i[:mdp] for i in X[mdp:]]
    d = [i[mdp:] for i in X[mdp:]]
    return a,b,c,d

def reconstruct(x):
    '''
    reconstruct submatrices for the output
    '''
    plen = int(len(x)/2)
    if plen==0:
        return x
    
    x_new = [[0 for i in range(plen)] for  j in range(plen)]
    cnt = 0
    for i in range(plen):
        for j in range(plen):
            x_new[i][j] = x[cnt]
            cnt+=1
    return x_new


def matrix_addition(X,Y):

    assert len(X) == len(Y)
    assert [len(i) for i in X] == [len(i) for i in Y]

    nrow = len(X)
    ncol = list(set([len(i) for i in X]))[0]
    output_mat = []
    for i in range(nrow):
        temp_row = []
        for j in range(ncol):
            temp_row.append(X[i][j] + Y[i][j])
        output_mat.append(temp_row)
    return output_mat
            
def matrix_subtraction(X,Y):
    assert len(X) == len(Y)
    assert [len(i) for i in X] == [len(i) for i in Y]
    nrow = len(X)
    ncol = list(set([len(i) for i in X]))[0]
    output_mat = []
    
    for i in range(nrow):
        temp_row = []
        for j in range(ncol):
            temp_row.append(X[i][j] - Y[i][j])
        output_mat.append(temp_row)
    return output_mat


def list_strassen(X,Y):
    '''
    Input:
        nxn integer matrices X and Y
    Output:
        Z = X.Y
    Assume:
        n is power of 2     
    '''
    n = len(X)
    assert n==len(X)== len(Y)
    
    if n==1:
        return [[X[0][0]*Y[0][0]]]
    
    
    
    mdp = len(X)//2 #middle point
    # Divide into sub-matrices
    A,B,C,D = split_mat(X,mdp)
    E,F,G,H = split_mat(Y,mdp)
    
    p1 = list_strassen(A,matrix_subtraction(F,H))
    p2 = list_strassen(matrix_addition(A,B),H)
    p3 = list_strassen(matrix_addition(C,D),E)
    p4 = list_strassen(D,matrix_subtraction(G,E))
    p5 = list_strassen(matrix_addition(A,D),matrix_addition(E,H))
    p6 = list_strassen(matrix_subtraction(B,D),matrix_addition(G,H))
    p7 = list_strassen(matrix_subtraction(A,C),matrix_addition(E,F))    
    # Strassen:
    upper_left =  matrix_addition(matrix_subtraction(matrix_addition(p5,p4),p2),p6)
    upper_right = matrix_addition(p1,p2)
    lower_left = matrix_addition(p3,p4)
    lower_right = matrix_subtraction(matrix_subtraction(matrix_addition(p1,p5),p3),p7) 
    
    
    res = [[0]*(len(upper_left)*2) for i in range(len(upper_left)*2)]

    # Construct:
    for row in range(len(upper_left)):
        for col in range(len(upper_left)):
            
            res[row][col] = upper_left[row][col]
            
            
            res[row][len(upper_left)+col] = upper_right[row][col]
            
            
            res[row+len(upper_left)][col] = lower_left[row][col]
            
            
            res[row+len(upper_left)][col+len(upper_left)] = lower_right[row][col]
            
 
    return res



    

In [85]:
X =  [[1,2,3,4],[5,6,7,8],[9,10,11,12],[13,14,15,16]]
Y =  [[2,2,3,4],[5,4,7,2],[92,2,1,5],[5,2,7,0]]


list_strassen(X,Y)

[[308, 24, 48, 23],
 [724, 64, 120, 67],
 [1140, 104, 192, 111],
 [1556, 144, 264, 155]]

Test:

In [86]:
np.dot(np.array(X),np.array(Y))

array([[ 308,   24,   48,   23],
       [ 724,   64,  120,   67],
       [1140,  104,  192,  111],
       [1556,  144,  264,  155]])