# Strassen multiplication
I’ve already spent days 10 and 20 on multiplication algorithms, but I couldn’t resist this one. The first reason is that I had never actually implemented the [Strassen algorithm](https://en.wikipedia.org/wiki/Strassen_algorithm), and the second is that the algorithm is intriguing.

When you look closely at the expressions Strassen discovered, they are very hard to comprehend. I spend a lot of time on math, and regardless of how difficult a topic is, there’s always an intuition as well as a formal expression behind an idea. Strassen’s expressions, however, completely lack any intuition, and I can’t imagine how he could possibly have made the discovery.
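For reference, these are the seven products computed by the `strassen` function (named $T_1, \dots, T_7$ there) and the way they recombine into the four blocks of $C = AB$:

$$
\begin{aligned}
T_1 &= (A_{11} + A_{22})(B_{11} + B_{22}) & T_5 &= (A_{11} + A_{12})\,B_{22} \\
T_2 &= (A_{21} + A_{22})\,B_{11} & T_6 &= (A_{21} - A_{11})(B_{11} + B_{12}) \\
T_3 &= A_{11}\,(B_{12} - B_{22}) & T_7 &= (A_{12} - A_{22})(B_{21} + B_{22}) \\
T_4 &= A_{22}\,(B_{21} - B_{11}) & &
\end{aligned}
$$

$$
\begin{aligned}
C_{11} &= T_1 + T_4 - T_5 + T_7, & C_{12} &= T_3 + T_5, \\
C_{21} &= T_2 + T_4, & C_{22} &= T_1 - T_2 + T_3 + T_6.
\end{aligned}
$$

Expanding the two off-diagonal blocks confirms that the extra terms cancel. Since matrix blocks do not commute, the $A$ factor stays on the left in every term. None of this explains where the expressions came from:

$$
\begin{aligned}
T_3 + T_5 &= A_{11}B_{12} - A_{11}B_{22} + A_{11}B_{22} + A_{12}B_{22} = A_{11}B_{12} + A_{12}B_{22} = C_{12}, \\
T_2 + T_4 &= A_{21}B_{11} + A_{22}B_{11} + A_{22}B_{21} - A_{22}B_{11} = A_{21}B_{11} + A_{22}B_{21} = C_{21}.
\end{aligned}
$$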

The algorithm was also unlucky. After its discovery it was rarely used because of concerns about loss of precision, and it took many years to show that those concerns were largely unfounded. Today it has become completely obsolete due to the work of [Mr. Kazushige Goto](https://en.wikipedia.org/wiki/Kazushige_Goto) and his extraordinary skill at writing highly optimized code for superscalar CPUs.

```python
import numpy as np
```

## algorithm

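```python
def strassen(A, B):
    """Multiply two n x n matrices with Strassen's algorithm.

    n must be a power of two: other sizes split into unequal blocks,
    which this recursion does not handle.
    """
    n = A.shape[0]
    assert A.shape == B.shape == (n, n) and n & (n - 1) == 0, "need n x n with n a power of two"
    k = n // 2
    if k == 0:
        # 1 x 1 base case: an ordinary scalar product
        return A * B

    # split both operands into four k x k blocks
    A11, A12 = A[:k, :k], A[:k, k:]
    A21, A22 = A[k:, :k], A[k:, k:]
    B11, B12 = B[:k, :k], B[:k, k:]
    B21, B22 = B[k:, :k], B[k:, k:]

    # seven recursive half-size products instead of the naive eight
    T1 = strassen(A11 + A22, B11 + B22)
    T2 = strassen(A21 + A22, B11)
    T3 = strassen(A11, B12 - B22)
    T4 = strassen(A22, B21 - B11)
    T5 = strassen(A11 + A12, B22)
    T6 = strassen(A21 - A11, B11 + B12)
    T7 = strassen(A12 - A22, B21 + B22)

    # combine the products into the four blocks of C;
    # result_type keeps e.g. an int x float product from being truncated to int
    C = np.zeros(A.shape, dtype=np.result_type(A, B))
    C[:k, :k] = T1 + T4 - T5 + T7
    C[:k, k:] = T3 + T5
    C[k:, :k] = T2 + T4
    C[k:, k:] = T1 - T2 + T3 + T6

    return C
```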
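Each level of the recursion replaces the eight half-size products of the ordinary block formula with seven. A small sketch counting scalar multiplications under the same power-of-two assumption shows how that compounds. Additions and subtractions are ignored here, and `strassen_mults` and `naive_mults` are hypothetical helper names. The recurrence is $M(1) = 1$, $M(n) = 7\,M(n/2)$, so $M(n) = 7^{\log_2 n} = n^{\log_2 7} \approx n^{2.807}$:

```python
import math


def strassen_mults(n):
    """Scalar multiplications done by the pure recursion (n a power of two)."""
    # M(1) = 1 and M(n) = 7 * M(n / 2)
    if n == 1:
        return 1
    return 7 * strassen_mults(n // 2)


def naive_mults(n):
    """Scalar multiplications done by the schoolbook triple loop."""
    return n ** 3


for n in (2, 8, 64, 1024):
    s, c = strassen_mults(n), naive_mults(n)
    print(f"n={n:5d}  strassen={s:>13,d}  naive={c:>13,d}  ratio={s / c:.3f}")

print(f"exponent log2(7) = {math.log2(7):.4f}")
```

For the 8 x 8 run in this notebook that is 343 scalar products instead of 512. The extra block additions are the price, which is one reason the saving only pays off for large matrices.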

## run

```python
X = np.random.randint(0, 10, (8, 8))
X
```

```
array([[9, 1, 1, 2, 5, 2, 8, 3],
       [4, 7, 1, 9, 6, 6, 9, 5],
       [0, 1, 0, 1, 7, 6, 0, 8],
       [9, 0, 4, 3, 3, 4, 4, 5],
       [6, 9, 4, 7, 1, 5, 5, 0],
       [3, 5, 8, 1, 6, 2, 3, 8],
       [3, 1, 0, 4, 2, 5, 1, 7],
       [2, 6, 4, 6, 3, 4, 6, 3]])
```

```python
Y = np.random.randint(0, 10, (8, 8))
Y
```

```
array([[8, 1, 1, 7, 6, 2, 7, 3],
       [2, 3, 6, 5, 8, 2, 0, 0],
       [5, 5, 1, 6, 3, 4, 4, 9],
       [4, 0, 6, 0, 4, 0, 7, 8],
       [2, 7, 0, 7, 3, 1, 5, 2],
       [8, 1, 9, 3, 3, 7, 6, 2],
       [8, 5, 7, 4, 7, 8, 9, 7],
       [4, 6, 2, 9, 2, 8, 0, 7]])
```

```python
strassen(X, Y)
```

```
array([[189, 112, 108, 174, 156, 131, 190, 143],
       [239, 153, 228, 210, 228, 186, 242, 215],
       [100, 106,  82, 144,  67, 115,  78,  90],
       [194, 104, 105, 181, 137, 137, 175, 164],
       [196,  90, 186, 153, 201, 122, 187, 157],
       [162, 165, 102, 226, 147, 156, 129, 182],
       [122,  72,  99, 122,  84, 109,  98, 111],
       [170, 113, 162, 152, 165, 135, 165, 167]])
```
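The version above recurses all the way down to 1 x 1 blocks and only accepts power-of-two sizes. Practical implementations usually stop the recursion at some block size and hand the small blocks to an optimized kernel. For float inputs, NumPy's `@` typically dispatches to a BLAS library, which is where tuned code of the kind the introduction credits to Goto takes over. Below is a sketch of such a hybrid. It assumes zero padding for arbitrary square sizes. The name `strassen_hybrid` and the cutoff `leaf=64` are hypothetical choices, not tuned values. It also checks the result against `X @ Y`:

```python
import numpy as np


def strassen_hybrid(A, B, leaf=64):
    """Strassen recursion that falls back to NumPy's `@` on small blocks.

    Sketch only: `leaf` is an assumed cutoff, not a tuned value, and square
    inputs of any size n >= 1 are zero-padded up to the next power of two.
    """
    n = A.shape[0]
    assert A.shape == B.shape == (n, n), "square matrices of equal size only"
    assert leaf >= 1, "the cutoff must stop the recursion at 1 x 1 or earlier"
    m = 1 << (n - 1).bit_length()  # smallest power of two >= n
    if m != n:
        # embed A and B in m x m zero matrices; the top-left n x n block of
        # the padded product is exactly A @ B
        dtype = np.result_type(A, B)
        Ap = np.zeros((m, m), dtype=dtype)
        Bp = np.zeros((m, m), dtype=dtype)
        Ap[:n, :n] = A
        Bp[:n, :n] = B
        return strassen_hybrid(Ap, Bp, leaf)[:n, :n]
    if n <= leaf:
        # small blocks: let NumPy's matmul do the work
        return A @ B

    k = n // 2
    A11, A12, A21, A22 = A[:k, :k], A[:k, k:], A[k:, :k], A[k:, k:]
    B11, B12, B21, B22 = B[:k, :k], B[:k, k:], B[k:, :k], B[k:, k:]

    # the same seven products as in strassen()
    T1 = strassen_hybrid(A11 + A22, B11 + B22, leaf)
    T2 = strassen_hybrid(A21 + A22, B11, leaf)
    T3 = strassen_hybrid(A11, B12 - B22, leaf)
    T4 = strassen_hybrid(A22, B21 - B11, leaf)
    T5 = strassen_hybrid(A11 + A12, B22, leaf)
    T6 = strassen_hybrid(A21 - A11, B11 + B12, leaf)
    T7 = strassen_hybrid(A12 - A22, B21 + B22, leaf)

    # assemble the four quadrants of C
    return np.block([[T1 + T4 - T5 + T7, T3 + T5],
                     [T2 + T4, T1 - T2 + T3 + T6]])


rng = np.random.default_rng(0)
X = rng.integers(0, 10, (13, 13))
Y = rng.integers(0, 10, (13, 13))
print(np.array_equal(strassen_hybrid(X, Y, leaf=2), X @ Y))  # True
```

A 13 x 13 input is padded to 16 x 16, and `leaf=2` forces a few levels of recursion so the Strassen combination step is actually exercised. With integer inputs the comparison against `X @ Y` is exact. For floats, `np.allclose` is the appropriate check, because the different order of operations changes the rounding.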