# The Strassen Subcubic Matrix Multiplication Algorithm

### A simplifying assumption:
In this program, we implement the Strassen subcubic matrix multiplication algorithm. Again, we use the divide-and-conquer strategy. To apply this strategy and to keep the presentation simple, we restrict our program to square $n\times n$ matrices where $n$ is a power of 2, i.e. $n=2^m$ for some positive integer $m$. One can later add a subroutine to remove this restriction.
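
One common way to lift the power-of-2 restriction is to pad the input with zeros up to the next power of 2 and keep only the top-left block of the result. The following is a minimal sketch of that idea on plain Python lists of lists; the helper names `next_power_of_two` and `pad_to_power_of_two` are hypothetical and are not used elsewhere in this notebook. With PyTorch tensors the same idea could presumably be written with `torch.zeros` and a slice assignment.

```python
def next_power_of_two(k):
    """Return the smallest power of two that is >= k (for k >= 1)."""
    p = 1
    while p < k:
        p *= 2
    return p


def pad_to_power_of_two(rows):
    """Embed a square matrix (list of lists) in the top-left corner of a
    zero matrix whose side is a power of two and at least 2."""
    k = len(rows)
    n = max(2, next_power_of_two(k))
    padded = [[0] * n for _ in range(n)]
    for i in range(k):
        for j in range(k):
            padded[i][j] = rows[i][j]
    return padded


X = [[1, 2, 3],
     [4, 5, 6],
     [7, 8, 9]]
X_padded = pad_to_power_of_two(X)
print(len(X_padded), len(X_padded[0]))  # 4 4
```

Because the added rows and columns are all zero, the top-left $k\times k$ block of the product of two padded matrices equals the product of the original $k\times k$ matrices.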

### Our choice of library:
We use PyTorch tensors for our work.

### High level overview:
Let $X$ and $Y$ be two $n\times n$ matrices as described above. We break each matrix into four submatrices as follows:
### $$
X=\left[\begin{array}{ll}
A & B \\
C & D
\end{array}\right]
,
Y=\left[\begin{array}{ll}
E & F \\
G & H
\end{array}\right]
$$
Then we have 
### $$
X Y=\left[\begin{array}{ll}
M & N \\
P & Q
\end{array}\right]
$$
where $M, N, P, Q$ are $\frac{n}{2} \times \frac{n}{2}$ square matrices computed as follows:

### \begin{equation}
\begin{aligned}
M &=A E+B G \\
N &=A F+B H \\
P &=C E+D G \\
Q &=C F+D H
\end{aligned}
\end{equation}
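
The block formulas for $M, N, P, Q$ can be checked on a small example. The sketch below uses plain Python lists of lists rather than tensors, and the helpers `matmul`, `add`, `split` and `join` are hypothetical names introduced only for this illustration. It splits two $4\times 4$ matrices into quadrants, forms the four blocks, and compares the reassembled result with the direct product.

```python
def matmul(X, Y):
    """Textbook triple-loop product of two matrices stored as lists of lists."""
    rows, inner, cols = len(X), len(Y), len(Y[0])
    return [[sum(X[i][k] * Y[k][j] for k in range(inner)) for j in range(cols)]
            for i in range(rows)]


def add(X, Y):
    """Entrywise sum of two matrices of the same shape."""
    return [[a + b for a, b in zip(rx, ry)] for rx, ry in zip(X, Y)]


def split(X):
    """Return the quadrants: top-left, top-right, bottom-left, bottom-right."""
    h = len(X) // 2
    return ([r[:h] for r in X[:h]], [r[h:] for r in X[:h]],
            [r[:h] for r in X[h:]], [r[h:] for r in X[h:]])


def join(M, N, P, Q):
    """Reassemble four quadrants into one matrix."""
    return [m + n for m, n in zip(M, N)] + [p + q for p, q in zip(P, Q)]


X = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
Y = [[2, 0, 1, 3], [1, 1, 0, 2], [4, 3, 2, 1], [0, 5, 1, 1]]

A, B, C, D = split(X)
E, F, G, H = split(Y)
M = add(matmul(A, E), matmul(B, G))
N = add(matmul(A, F), matmul(B, H))
P = add(matmul(C, E), matmul(D, G))
Q = add(matmul(C, F), matmul(D, H))

blockwise = join(M, N, P, Q)
print(blockwise == matmul(X, Y))  # True
```
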
As seen in the above equations, we now have to perform 8 smaller matrix multiplications to compute $XY$. If we continue down this path recursively, the asymptotic running time of our algorithm would be $O(n^3)$. In 1969, Strassen proposed a clever method that computes the product $XY$ using only 7 smaller matrix multiplications. To implement his method, we define the $\frac{n}{2} \times \frac{n}{2}$ square matrices $Z_1, \ldots, Z_7$ as follows:
### \begin{equation}
\begin{aligned}
Z_1 &=A(F-H) \\
Z_2 &=(A+B)H\\
Z_3 &=(C+D)E \\
Z_4 &=D(G-E) \\
Z_5 &=(A+D)(E+H) \\
Z_6 &=(B-D)(G+H) \\
Z_7 &=(A-C)(E+F) 
\end{aligned}
\end{equation}
Then we have 
### \begin{equation}
\begin{aligned}
M &=Z_5 + Z_4 - Z_2 + Z_6  \\
N &=Z_1 + Z_2  \\
P &=Z_3 + Z_4  \\
Q &=Z_1 + Z_5 - Z_3 - Z_7
\end{aligned}
\end{equation}
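and we get an algorithm that is asymptotically faster than $O(n^3)$: seven recursive multiplications of half-size matrices, plus $O(n^2)$ work for the additions, give a running time of $O(n^{\log_2 7}) \approx O(n^{2.81})$.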
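
As a sanity check of the seven products, we can apply them to the smallest case, where $A, \ldots, H$ are plain numbers. The function `strassen_2x2` below is a hypothetical helper written only for this check. It combines $Z_1, \ldots, Z_7$ in the same way as the `strassen_matrix_mult` code in this notebook and uses exactly seven multiplications.

```python
def strassen_2x2(X, Y):
    """Multiply two 2x2 matrices of numbers using seven multiplications."""
    (A, B), (C, D) = X
    (E, F), (G, H) = Y
    # The seven Strassen products
    Z1 = A * (F - H)
    Z2 = (A + B) * H
    Z3 = (C + D) * E
    Z4 = D * (G - E)
    Z5 = (A + D) * (E + H)
    Z6 = (B - D) * (G + H)
    Z7 = (A - C) * (E + F)
    # The four entries of the product XY
    M = Z5 + Z4 - Z2 + Z6
    N = Z1 + Z2
    P = Z3 + Z4
    Q = Z1 + Z5 - Z3 - Z7
    return [[M, N], [P, Q]]


X = [[1, 2], [3, 4]]
Y = [[5, 6], [7, 8]]
print(strassen_2x2(X, Y))  # [[19, 22], [43, 50]]
```

The printed result agrees with the direct computation, e.g. the bottom-right entry is $3\cdot 6 + 4\cdot 8 = 50$.
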
### Comparison:
We also implement ordinary matrix multiplication, using the same recursive block decomposition but with all 8 smaller multiplications. Finally, we compare the running time of the Strassen algorithm with that of ordinary matrix multiplication and of the `torch.matmul` function in PyTorch on two large matrices with random entries.
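
A comparison of this kind can be made somewhat more robust by timing only the function call (not the printing of the result) and by keeping the best of a few repetitions. The sketch below shows one way to do that with `time.perf_counter`; the helper `time_call` is a hypothetical name and is not what the cells in this notebook use, which rely on `time.time()` around a `print`.

```python
import time


def time_call(fn, *args, repeats=3):
    """Call fn(*args) `repeats` times; return (last result, best time in seconds)."""
    best = float("inf")
    result = None
    for _ in range(repeats):
        start = time.perf_counter()
        result = fn(*args)
        best = min(best, time.perf_counter() - start)
    return result, best


# Stand-in workload; any function and arguments could be passed instead.
result, seconds = time_call(sum, range(1_000_000))
print(result)  # 499999500000
```
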

### We continue to use the above notation in our code:

In [1]:
import torch

In [19]:
import time

In [11]:
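def strassen_matrix_mult(X, Y):
    """Multiply two n x n tensors (n a power of 2, n >= 2) with Strassen's method."""
    n = X.shape[0] // 2
    # Base case: X and Y are 2 x 2, so compute the four entries directly
    if n == 1:
        return torch.tensor([[torch.dot(X[0,:],Y[:,0]), torch.dot(X[0,:],Y[:,1]) ],
                             [torch.dot(X[1,:],Y[:,0]), torch.dot(X[1,:],Y[:,1]) ]])
    # 8 submatrices of X and Y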
    A = X[:n, :n]
    B = X[:n, n:]
    C = X[n:, :n]
    D = X[n:, n:]
    E = Y[:n, :n]
    F = Y[:n, n:]
    G = Y[n:, :n]
    H = Y[n:, n:]
    # 7 matrices of Strassen's method
    Z1 = strassen_matrix_mult(A, F-H)
    Z2 = strassen_matrix_mult(A+B, H)
    Z3 = strassen_matrix_mult(C+D, E)
    Z4 = strassen_matrix_mult(D, G-E)
    Z5 = strassen_matrix_mult(A+D, E+H)
    Z6 = strassen_matrix_mult(B-D, G+H)
    Z7 = strassen_matrix_mult(A-C, E+F)
    # Computing the components of the multiplication XY
    M = Z5 + Z4 - Z2 + Z6  
    N = Z1 + Z2
    P = Z3 + Z4
    Q = Z1 + Z5 - Z3 - Z7
    # Stack the four blocks to assemble the product XY
    h1 = torch.hstack([M,N])
    h2 = torch.hstack([P,Q])
    matrix_multiplication = torch.vstack([h1,h2])
    return matrix_multiplication

In [28]:
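def ordinary_matrix_mult(X, Y):
    """Multiply two n x n tensors (n a power of 2, n >= 2) with the 8-product block formulas."""
    n = X.shape[0] // 2
    # Base case: X and Y are 2 x 2, so compute the four entries directly
    if n == 1:
        return torch.tensor([[torch.dot(X[0,:],Y[:,0]), torch.dot(X[0,:],Y[:,1]) ],
                             [torch.dot(X[1,:],Y[:,0]), torch.dot(X[1,:],Y[:,1]) ]])
    # 8 submatrices of X and Y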
    A = X[:n, :n]
    B = X[:n, n:]
    C = X[n:, :n]
    D = X[n:, n:]
    E = Y[:n, :n]
    F = Y[:n, n:]
    G = Y[n:, :n]
    H = Y[n:, n:]
    # Computing the components of the multiplication XY
    M = ordinary_matrix_mult(A,E) + ordinary_matrix_mult(B,G)  
    N = ordinary_matrix_mult(A,F) + ordinary_matrix_mult(B,H)  
    P = ordinary_matrix_mult(C,E) + ordinary_matrix_mult(D,G)  
    Q = ordinary_matrix_mult(C,F) + ordinary_matrix_mult(D,H)  
    # Stack the four blocks to assemble the product XY
    h1 = torch.hstack([M,N])
    h2 = torch.hstack([P,Q])
    matrix_multiplication = torch.vstack([h1,h2])
    return matrix_multiplication

In [55]:
n = 256
X = torch.randint(n*n, (n,n))
print(X)
Y = torch.randint(n*n, (n,n))
print(Y)

tensor([[46054, 63575,  8414,  ..., 30129,  6165, 52170],
        [45940, 18311, 50978,  ..., 57628, 12509, 18013],
        [53073, 41521,  1757,  ..., 42470, 35809,  9114],
        ...,
        [27296, 58590, 48791,  ..., 56130, 53086, 16077],
        [40391, 20040, 51890,  ..., 34412, 63461, 55285],
        [45532, 22264,  9981,  ..., 21428, 28455, 23347]])
tensor([[59173, 58592, 20934,  ..., 13411, 55370,  6986],
        [63237, 58029, 59260,  ..., 19813, 52887,  3668],
        [15039, 11020,  9207,  ..., 63262,  7127, 46368],
        ...,
        [23778, 39437,  9077,  ..., 54282, 33477, 35456],
        [46372, 54678, 49018,  ...,  1737, 23299, 14401],
        [54944, 15661, 59929,  ..., 29973, 43952, 11586]])


In [56]:
start_time = time.time()
print(torch.matmul(X,Y))
print(time.time() - start_time)

tensor([[271762007233, 263202843837, 254572589238,  ..., 281104117193,
         266192091199, 257412165554],
        [282411029445, 277203139824, 272814712034,  ..., 301558551249,
         267715354407, 283685389017],
        [271276742263, 270523240144, 241912566450,  ..., 281612050922,
         255104722054, 262067168605],
        ...,
        [261925232362, 271193187508, 266047462499,  ..., 293828739313,
         277511295859, 271003062631],
        [284497046250, 276750613555, 268808952557,  ..., 296538071465,
         281465741563, 288802938709],
        [273066060828, 273522019251, 250224024368,  ..., 289298857764,
         266567949099, 269083741823]])
0.07479643821716309


In [57]:
start_time = time.time()
print(strassen_matrix_mult(X, Y))
print(time.time() - start_time)

tensor([[271762007233, 263202843837, 254572589238,  ..., 281104117193,
         266192091199, 257412165554],
        [282411029445, 277203139824, 272814712034,  ..., 301558551249,
         267715354407, 283685389017],
        [271276742263, 270523240144, 241912566450,  ..., 281612050922,
         255104722054, 262067168605],
        ...,
        [261925232362, 271193187508, 266047462499,  ..., 293828739313,
         277511295859, 271003062631],
        [284497046250, 276750613555, 268808952557,  ..., 296538071465,
         281465741563, 288802938709],
        [273066060828, 273522019251, 250224024368,  ..., 289298857764,
         266567949099, 269083741823]])
64.92769837379456


In [58]:
start_time = time.time()
print(ordinary_matrix_mult(X, Y))
print(time.time() - start_time)

tensor([[271762007233, 263202843837, 254572589238,  ..., 281104117193,
         266192091199, 257412165554],
        [282411029445, 277203139824, 272814712034,  ..., 301558551249,
         267715354407, 283685389017],
        [271276742263, 270523240144, 241912566450,  ..., 281612050922,
         255104722054, 262067168605],
        ...,
        [261925232362, 271193187508, 266047462499,  ..., 293828739313,
         277511295859, 271003062631],
        [284497046250, 276750613555, 268808952557,  ..., 296538071465,
         281465741563, 288802938709],
        [273066060828, 273522019251, 250224024368,  ..., 289298857764,
         266567949099, 269083741823]])
141.20595574378967
