# Strassen matrix multiply

- 假定有两个 2<sup>n</sup>×2<sup>n</sup> 的矩阵 A 和 B，求矩阵 C = A × B

In [9]:
import numpy as np
from random import sample
from functools import reduce
import datetime

# Side length of the square test matrices (512 = 2**9, a power of two).
N = 512

# Two independent N x N operands; `a` is drawn before `b`.
a, b = (np.random.random((N, N)) for _ in range(2))
# Tiny fixed operands that can be swapped in to check results by hand:
# a, b = np.array(((1, 2),(3 ,4))), np.array(((5, 6),(7, 8)))

## natural (naive) algorithm

In [10]:
def matrix_multiply_nature(a, b):
    """Multiply two equally sized square matrices with the schoolbook method.

    Each entry ``c[i][j]`` is the sum over ``k`` of ``a[i][k] * b[k][j]``,
    accumulated in pure Python (three nested loops over the side length).

    Args:
        a: Square two-dimensional NumPy array, the left operand.
        b: NumPy array with the same shape as ``a``, the right operand.

    Returns:
        A new array created with ``np.zeros(a.shape)`` holding the product.

    Raises:
        ValueError: If ``a`` is not two-dimensional, the shapes differ, or
            the matrices are not square.
    """
    # Reject when ANY requirement fails. The previous guard,
    # ``not a.shape == b.shape and a.shape[0] == a.shape[1]``, parsed as
    # ``(not same_shape) and square`` and let non-square inputs through.
    if a.ndim != 2 or a.shape != b.shape or a.shape[0] != a.shape[1]:
        raise ValueError('invalid matrix')

    n = a.shape[0]
    c = np.zeros(a.shape)

    for i in range(n):
        for j in range(n):
            # Dot product of row i of `a` with column j of `b`.
            c[i][j] = sum(a[i][k] * b[k][j] for k in range(n))

    return c

In [11]:
# Time one run of the schoolbook multiplication on the module-level operands.
t1 = datetime.datetime.now()
c = matrix_multiply_nature(a, b)
elapsed = datetime.datetime.now() - t1
print(elapsed)
# Spot-check a single entry so the three implementations can be compared.
c[25, 25]

0:02:21.904609


122.9289311208492

## recursive (divide-and-conquer) algorithm

In [12]:
def matrix_multiply(a, b):
    """Multiply two square matrices by recursive quadrant decomposition.

    Each operand is split into four half-size quadrants and the product is
    assembled from eight recursive half-size multiplications, down to 1x1
    blocks. Because every level halves the side length, the side length
    must be a power of two.

    Args:
        a: Square two-dimensional NumPy array, the left operand.
        b: NumPy array with the same shape as ``a``, the right operand.

    Returns:
        A new array created with ``np.zeros(a.shape)`` holding the product.
        A 0x0 input yields an empty 0x0 result.

    Raises:
        ValueError: If ``a`` is not two-dimensional, the shapes differ, the
            matrices are not square, or the side length is not a power of
            two.
    """
    # Reject when ANY requirement fails. The previous guard,
    # ``not a.shape == b.shape and a.shape[0] == a.shape[1]``, parsed as
    # ``(not same_shape) and square`` and let non-square inputs through.
    if a.ndim != 2 or a.shape != b.shape or a.shape[0] != a.shape[1]:
        raise ValueError('invalid matrix')

    n = a.shape[0]
    # n & (n - 1) clears the lowest set bit; it is zero only for 0 and
    # powers of two.
    if n & (n - 1):
        raise ValueError('matrix size must be a power of two')
    if n == 0:
        # Nothing to multiply; previously this crashed on c[0][0].
        return np.zeros(a.shape)

    # Inputs are validated once here; the recursion itself does no checks.
    return _block_product(a, b)


def _block_product(a, b):
    """Recursive worker for matrix_multiply; operands are already validated."""
    side = a.shape[0]
    c = np.zeros(a.shape)
    if side == 1:
        c[0][0] = a[0][0] * b[0][0]
        return c

    mid = side // 2
    a11, a12 = a[:mid, :mid], a[:mid, mid:]
    a21, a22 = a[mid:, :mid], a[mid:, mid:]
    b11, b12 = b[:mid, :mid], b[:mid, mid:]
    b21, b22 = b[mid:, :mid], b[mid:, mid:]

    # Block form of C = A x B: each quadrant is a sum of two block products.
    c[:mid, :mid] = _block_product(a11, b11) + _block_product(a12, b21)
    c[:mid, mid:] = _block_product(a11, b12) + _block_product(a12, b22)
    c[mid:, :mid] = _block_product(a21, b11) + _block_product(a22, b21)
    c[mid:, mid:] = _block_product(a21, b12) + _block_product(a22, b22)

    return c

In [13]:
# Time one run of the recursive quadrant multiplication on the same operands.
t1 = datetime.datetime.now()
c = matrix_multiply(a, b)
elapsed = datetime.datetime.now() - t1
print(elapsed)
# Spot-check a single entry so the three implementations can be compared.
c[25, 25]

0:09:48.067821


122.92893112084923

## Strassen algorithm

In [14]:
def strassen(a, b):
    """Multiply two square matrices with Strassen's algorithm.

    Each operand is split into four half-size quadrants. Ten quadrant
    sums/differences feed seven recursive half-size products, which are then
    combined into the four quadrants of the result. Because every level
    halves the side length, the side length must be a power of two.

    Args:
        a: Square two-dimensional NumPy array, the left operand.
        b: NumPy array with the same shape as ``a``, the right operand.

    Returns:
        A new array created with ``np.zeros(a.shape)`` holding the product.
        A 0x0 input yields an empty 0x0 result.

    Raises:
        ValueError: If ``a`` is not two-dimensional, the shapes differ, the
            matrices are not square, or the side length is not a power of
            two.
    """
    # Reject when ANY requirement fails. The previous guard,
    # ``not a.shape == b.shape and a.shape[0] == a.shape[1]``, parsed as
    # ``(not same_shape) and square`` and let non-square inputs through.
    if a.ndim != 2 or a.shape != b.shape or a.shape[0] != a.shape[1]:
        raise ValueError('invalid matrix')

    n = a.shape[0]
    # n & (n - 1) clears the lowest set bit; it is zero only for 0 and
    # powers of two.
    if n & (n - 1):
        raise ValueError('matrix size must be a power of two')
    if n == 0:
        # Nothing to multiply; previously this crashed on c[0][0].
        return np.zeros(a.shape)

    # Inputs are validated once here; the recursion itself does no checks.
    return _strassen_product(a, b)


def _strassen_product(a, b):
    """Recursive worker for strassen; operands are already validated."""
    side = a.shape[0]
    c = np.zeros(a.shape)
    if side == 1:
        c[0][0] = a[0][0] * b[0][0]
        return c

    mid = side // 2
    a11, a12 = a[:mid, :mid], a[:mid, mid:]
    a21, a22 = a[mid:, :mid], a[mid:, mid:]
    b11, b12 = b[:mid, :mid], b[:mid, mid:]
    b21, b22 = b[mid:, :mid], b[mid:, mid:]

    # Ten quadrant sums/differences that feed the seven products below.
    s1 = b12 - b22
    s2 = a11 + a12
    s3 = a21 + a22
    s4 = b21 - b11
    s5 = a11 + a22
    s6 = b11 + b22
    s7 = a12 - a22
    s8 = b21 + b22
    s9 = a11 - a21
    s10 = b11 + b12

    # Seven half-size products instead of the eight used by plain recursion.
    p1 = _strassen_product(a11, s1)
    p2 = _strassen_product(s2, b22)
    p3 = _strassen_product(s3, b11)
    p4 = _strassen_product(a22, s4)
    p5 = _strassen_product(s5, s6)
    p6 = _strassen_product(s7, s8)
    p7 = _strassen_product(s9, s10)

    # Recombine the products into the four quadrants of the result.
    c[:mid, :mid] = p5 + p4 - p2 + p6
    c[:mid, mid:] = p1 + p2
    c[mid:, :mid] = p3 + p4
    c[mid:, mid:] = p5 + p1 - p3 - p7

    return c

In [15]:
# Time one run of the Strassen multiplication on the same operands.
t1 = datetime.datetime.now()
c = strassen(a, b)
elapsed = datetime.datetime.now() - t1
print(elapsed)
# Spot-check a single entry so the three implementations can be compared.
c[25, 25]

0:05:04.511303


122.92893112089878