# The Strassen Subcubic Matrix Multiplication Algorithm

### A simplifying assumption:
In this program, we implement the Strassen subcubic matrix multiplication algorithm. Again, we use the divide-and-conquer strategy. To apply this strategy and keep the presentation simple, we restrict our program to square $n\times n$ matrices where $n$ is a power of 2, i.e. $n=2^m$ for some positive integer $m$. One can later add a subroutine to remove this restriction.
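
One possible way to lift the power-of-2 restriction is to zero-pad the inputs up to the next power of 2 and crop the product afterwards. The sketch below illustrates the idea on plain Python lists of lists; `next_power_of_two` and `pad_to_power_of_two` are hypothetical helper names and are not part of this notebook.

```python
def next_power_of_two(n):
    """Smallest power of 2 that is >= n (for n >= 1)."""
    p = 1
    while p < n:
        p *= 2
    return p


def pad_to_power_of_two(rows):
    """Zero-pad a square matrix (list of lists) to the next power-of-2 size."""
    n = len(rows)
    m = next_power_of_two(n)
    # Extend every existing row with zeros on the right ...
    padded = [list(row) + [0] * (m - n) for row in rows]
    # ... then append all-zero rows at the bottom.
    padded += [[0] * m for _ in range(m - n)]
    return padded


padded = pad_to_power_of_two([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
```

If both $X$ and $Y$ are padded this way, the top-left $n\times n$ block of the padded product equals $XY$, because the extra rows and columns contribute only zeros.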

### Our choice of library:
We use PyTorch tensors for our work.

### High level overview:
Let $X$ and $Y$ be two $n\times n$ matrices as described above. We break each matrix into four submatrices as follows:
### $$
X=\left[\begin{array}{ll}
A & B \\
C & D
\end{array}\right]
,
Y=\left[\begin{array}{ll}
E & F \\
G & H
\end{array}\right]
$$
Then we have 
### $$
X Y=\left[\begin{array}{ll}
M & N \\
P & Q
\end{array}\right]
$$
where $M, N, P, Q$ are $\frac{n}{2} \times \frac{n}{2}$ square matrices computed as follows:

### \begin{equation}
\begin{aligned}
M &=A E+B G \\
N &=A F+B H \\
P &=C E+D G \\
Q &=C F+D H
\end{aligned}
\end{equation}
As the above equations show, computing $XY$ this way requires 8 multiplications of $\frac{n}{2} \times \frac{n}{2}$ matrices. If we recurse on all 8 products, the running time satisfies $T(n) = 8T(n/2) + O(n^2)$, which gives $O(n^3)$, no better than multiplying by definition. In 1969, Strassen proposed a clever method that computes $XY$ using only 7 smaller matrix multiplications. To implement his method, we define $\frac{n}{2} \times \frac{n}{2}$ square matrices $Z_1, \ldots, Z_7$ as follows:
### \begin{equation}
\begin{aligned}
Z_1 &=A(F-H) \\
Z_2 &=(A+B)H\\
Z_3 &=(C+D)E \\
Z_4 &=D(G-E) \\
Z_5 &=(A+D)(E+H) \\
Z_6 &=(B-D)(G+H) \\
Z_7 &=(A-C)(E+F) 
\end{aligned}
\end{equation}
Then we have 
### \begin{equation}
\begin{aligned}
M &=Z_5 + Z_4 - Z_2 + Z_6  \\
N &=Z_1 + Z_2  \\
P &=Z_3 + Z_4  \\
Q &=Z_1 + Z_5 - Z_3 - Z_7  
\end{aligned}
\end{equation}
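and we get an algorithm whose running time satisfies $T(n) = 7T(n/2) + O(n^2)$, i.e. $O(n^{\log_2 7}) \approx O(n^{2.81})$, which is asymptotically faster than $O(n^3)$.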
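
As a quick sanity check of the seven products, the following sketch evaluates them on scalars, that is, on $2\times 2$ matrices whose blocks are single numbers, and compares the result with the direct formulas for $M, N, P, Q$. It follows the combination used in `strassen_matrix_mult`; the function names `strassen_2x2` and `direct_2x2` are hypothetical and exist only for this illustration.

```python
import random


def strassen_2x2(a, b, c, d, e, f, g, h):
    """Return (M, N, P, Q) for [[a, b], [c, d]] @ [[e, f], [g, h]] using 7 products."""
    z1 = a * (f - h)
    z2 = (a + b) * h
    z3 = (c + d) * e
    z4 = d * (g - e)
    z5 = (a + d) * (e + h)
    z6 = (b - d) * (g + h)
    z7 = (a - c) * (e + f)
    return (z5 + z4 - z2 + z6, z1 + z2, z3 + z4, z1 + z5 - z3 - z7)


def direct_2x2(a, b, c, d, e, f, g, h):
    """Return (M, N, P, Q) computed with the usual 8 products."""
    return (a * e + b * g, a * f + b * h, c * e + d * g, c * f + d * h)


# Compare both formulas on random integer entries.
random.seed(0)
for _ in range(1000):
    entries = [random.randint(-10, 10) for _ in range(8)]
    assert strassen_2x2(*entries) == direct_2x2(*entries)
```

This only exercises the scalar case. The identities also hold for matrix blocks because each product keeps its left and right factors in the same order, so commutativity is never needed.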
### Comparison:
We also implement a naive divide-and-conquer matrix multiplication (with 8 recursive products) and a matrix multiplication computed directly from the definition. Finally, we compare the running times of the following matrix multiplications on two large matrices with random entries:

1. `torch.matmul` function in PyTorch
2. Strassen algorithm
3. naive matrix multiplication
4. by definition (using the function `torch.dot` defined in PyTorch)

We construct the random matrices both on the CPU (integers and possibly floats) and on CUDA (floats).
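
A small timing helper can make such comparisons easier to repeat. The sketch below is one way to write it, assuming wall-clock time from `time.perf_counter` is good enough; `time_call` and its `sync` parameter are hypothetical names, not part of this notebook. For CUDA tensors one would likely pass `torch.cuda.synchronize` as `sync`, since GPU kernels are launched asynchronously.

```python
import time


def time_call(fn, *args, sync=None):
    """Run fn(*args) once and return (result, elapsed_seconds).

    `sync` is an optional zero-argument callable invoked before starting
    and before stopping the clock (an assumption of this sketch).
    """
    if sync is not None:
        sync()
    start = time.perf_counter()
    result = fn(*args)
    if sync is not None:
        sync()
    return result, time.perf_counter() - start


# Usage with a built-in function standing in for a matrix product.
result, seconds = time_call(sum, [1, 2, 3])
```

Keeping `print` outside the timed region means the measurement covers only the multiplication itself.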

### We continue to use the above notation in our code:

In [1]:
import torch

In [2]:
import time

In [3]:
def strassen_matrix_mult(X, Y):
    """Multiply two n x n tensors (n a power of 2, n >= 2) with Strassen's method."""
    n = X.shape[0] // 2  # size of each half-block
    if n == 1:
        # Base case: multiply 2 x 2 matrices directly via row-by-column dot products.
        # dtype and device are passed explicitly so the result matches the inputs.
        return torch.tensor([[torch.dot(X[0,:],Y[:,0]), torch.dot(X[0,:],Y[:,1]) ],
                             [torch.dot(X[1,:],Y[:,0]), torch.dot(X[1,:],Y[:,1]) ]],
                            dtype=X.dtype, device=X.device)
    # 8 submatrices of X and Y
    A = X[:n, :n]
    B = X[:n, n:]
    C = X[n:, :n]
    D = X[n:, n:]
    E = Y[:n, :n]
    F = Y[:n, n:]
    G = Y[n:, :n]
    H = Y[n:, n:]
    # 7 recursive products of the Strassen method
    Z1 = strassen_matrix_mult(A, F-H)
    Z2 = strassen_matrix_mult(A+B, H)
    Z3 = strassen_matrix_mult(C+D, E)
    Z4 = strassen_matrix_mult(D, G-E)
    Z5 = strassen_matrix_mult(A+D, E+H)
    Z6 = strassen_matrix_mult(B-D, G+H)
    Z7 = strassen_matrix_mult(A-C, E+F)
    # Computing the four blocks of the product XY
    M = Z5 + Z4 - Z2 + Z6
    N = Z1 + Z2
    P = Z3 + Z4
    Q = Z1 + Z5 - Z3 - Z7
    # Stacking the blocks to build the full product
    h1 = torch.hstack([M,N])
    h2 = torch.hstack([P,Q])
    matrix_multiplication = torch.vstack([h1,h2])
    return matrix_multiplication

In [4]:
def naive_matrix_mult(X, Y):
    """Multiply two n x n tensors (n a power of 2, n >= 2) with 8 recursive block products."""
    n = X.shape[0] // 2  # size of each half-block
    if n == 1:
        # Base case: multiply 2 x 2 matrices directly via row-by-column dot products.
        # dtype and device are passed explicitly so the result matches the inputs.
        return torch.tensor([[torch.dot(X[0,:],Y[:,0]), torch.dot(X[0,:],Y[:,1]) ],
                             [torch.dot(X[1,:],Y[:,0]), torch.dot(X[1,:],Y[:,1]) ]],
                            dtype=X.dtype, device=X.device)
    # 8 submatrices of X and Y
    A = X[:n, :n]
    B = X[:n, n:]
    C = X[n:, :n]
    D = X[n:, n:]
    E = Y[:n, :n]
    F = Y[:n, n:]
    G = Y[n:, :n]
    H = Y[n:, n:]
    # Computing the four blocks of the product XY (8 recursive products)
    M = naive_matrix_mult(A,E) + naive_matrix_mult(B,G)
    N = naive_matrix_mult(A,F) + naive_matrix_mult(B,H)
    P = naive_matrix_mult(C,E) + naive_matrix_mult(D,G)
    Q = naive_matrix_mult(C,F) + naive_matrix_mult(D,H)
    # Stacking the blocks to build the full product
    h1 = torch.hstack([M,N])
    h2 = torch.hstack([P,Q])
    matrix_multiplication = torch.vstack([h1,h2])
    return matrix_multiplication
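
Both recursive functions above stop at $2\times 2$ blocks, so one rough way to compare them is to count how many base-case calls each one makes. The sketch below does that under the power-of-2 assumption; `base_case_calls` is a hypothetical helper, not part of this notebook, and it ignores the cost of the additions, slicing and stacking done at every level.

```python
def base_case_calls(n, branches):
    """Count 2 x 2 base-case calls when an n x n problem is split into
    `branches` half-size subproblems per level (n a power of 2, n >= 2)."""
    if n < 2 or n & (n - 1):
        raise ValueError("n must be a power of 2 and at least 2")
    calls = 1
    while n > 2:
        calls *= branches  # each call spawns `branches` half-size calls
        n //= 2
    return calls


strassen_calls = base_case_calls(256, 7)  # 7**7 = 823543
naive_calls = base_case_calls(256, 8)     # 8**7 = 2097152
```

For $n = 256$ the Strassen recursion reaches the base case about 0.39 times as often as the naive recursion, which gives a first-order idea of the saving in multiplications.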

In [5]:
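def by_definition_matrix_mult(X, Y):
    """Multiply two n x n tensors entry by entry, straight from the definition."""
    n = X.shape[0]
    # Note: torch.zeros uses the default dtype (float32 unless changed) on the CPU,
    # so the result is a float tensor even when X and Y hold integers.
    Z = torch.zeros(n,n)
    for i in range(n):
        for j in range(n):
            # Entry (i, j) is the dot product of row i of X and column j of Y
            Z[i,j] = torch.dot(X[i,:],Y[:,j])
    return Z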

In [12]:
# matrices with random integers in [-10, 10], created on the CPU
n = 256
X = torch.randint(low=-10, high=11, size=(n,n))
print(X)
Y = torch.randint(low=-10, high=11, size=(n,n))
print(Y)

tensor([[ -4,  -7,   2,  ...,  -4,   2,   6],
        [  1,   9,   6,  ...,  -5,   9,  -3],
        [ -5,  -8,  -8,  ...,   0,   2,  -7],
        ...,
        [ -7,  10,   3,  ...,  -9, -10,   8],
        [  4,   2,  -7,  ...,  -5,   1,  -2],
        [  1,   0,  -5,  ..., -10,   7,  -1]])
tensor([[  8,   8,  -9,  ...,   7, -10,  -8],
        [  7,  -1,   9,  ...,   7,   2,  -3],
        [  0,   8,  10,  ...,  -7,  -9,  -1],
        ...,
        [  8,   5,  -7,  ...,  -9,   9,  -1],
        [ -6,   4,  -8,  ...,   9, -10,   8],
        [-10,   6,  10,  ...,  -4,   0,   0]])


In [7]:
# matrices with random floats in [0, 1), created on CUDA
n = 64
X = torch.rand(size=(n,n), device=torch.device('cuda:0'))
print(X)
Y = torch.rand(size=(n,n), device=torch.device('cuda:0'))
print(Y)

tensor([[0.2805, 0.8426, 0.4132,  ..., 0.2835, 0.2148, 0.0123],
        [0.8611, 0.7705, 0.8043,  ..., 0.3545, 0.3727, 0.1676],
        [0.1211, 0.4358, 0.4004,  ..., 0.7909, 0.5573, 0.0433],
        ...,
        [0.3841, 0.5009, 0.6581,  ..., 0.8104, 0.5988, 0.5321],
        [0.5748, 0.0617, 0.2454,  ..., 0.4483, 0.2809, 0.0429],
        [0.8988, 0.9917, 0.3168,  ..., 0.6453, 0.6785, 0.4649]],
       device='cuda:0')
tensor([[0.9259, 0.2074, 0.9687,  ..., 0.6469, 0.0658, 0.3482],
        [0.6034, 0.4192, 0.4788,  ..., 0.3869, 0.6007, 0.1296],
        [0.7918, 0.3358, 0.1574,  ..., 0.8204, 0.7284, 0.8991],
        ...,
        [0.2003, 0.6265, 0.6476,  ..., 0.8986, 0.7269, 0.2619],
        [0.2055, 0.0269, 0.5605,  ..., 0.5356, 0.7297, 0.6870],
        [0.5165, 0.6963, 0.9101,  ..., 0.8222, 0.9616, 0.5181]],
       device='cuda:0')


In [13]:
start_time = time.time()
print(torch.matmul(X,Y))
print(time.time() - start_time)

tensor([[ -635,   656,    44,  ...,   102,  -153,   -92],
        [  496,  -258,   410,  ...,  -112,  -717,   727],
        [   28,  -695,  -704,  ...,  1012,  -358, -1031],
        ...,
        [ -245,  -271,  1949,  ...,   239,   978,   381],
        [ -803,  -366,   -45,  ...,   194,  -588,   833],
        [   94,  -257,  -611,  ...,    24,   268,   453]])
0.0728302001953125


In [14]:
start_time = time.time()
print(strassen_matrix_mult(X, Y))
print(time.time() - start_time)

tensor([[ -635,   656,    44,  ...,   102,  -153,   -92],
        [  496,  -258,   410,  ...,  -112,  -717,   727],
        [   28,  -695,  -704,  ...,  1012,  -358, -1031],
        ...,
        [ -245,  -271,  1949,  ...,   239,   978,   381],
        [ -803,  -366,   -45,  ...,   194,  -588,   833],
        [   94,  -257,  -611,  ...,    24,   268,   453]])
63.8971734046936


In [15]:
start_time = time.time()
print(naive_matrix_mult(X, Y))
print(time.time() - start_time)

tensor([[ -635,   656,    44,  ...,   102,  -153,   -92],
        [  496,  -258,   410,  ...,  -112,  -717,   727],
        [   28,  -695,  -704,  ...,  1012,  -358, -1031],
        ...,
        [ -245,  -271,  1949,  ...,   239,   978,   381],
        [ -803,  -366,   -45,  ...,   194,  -588,   833],
        [   94,  -257,  -611,  ...,    24,   268,   453]])
137.66592264175415


In [16]:
start_time = time.time()
print(by_definition_matrix_mult(X, Y))
print(time.time() - start_time)

tensor([[ -635.,   656.,    44.,  ...,   102.,  -153.,   -92.],
        [  496.,  -258.,   410.,  ...,  -112.,  -717.,   727.],
        [   28.,  -695.,  -704.,  ...,  1012.,  -358., -1031.],
        ...,
        [ -245.,  -271.,  1949.,  ...,   239.,   978.,   381.],
        [ -803.,  -366.,   -45.,  ...,   194.,  -588.,   833.],
        [   94.,  -257.,  -611.,  ...,    24.,   268.,   453.]])
1.3294436931610107
