In [60]:
# This is a quick re-impl of Strassen's algorithm to understand it better
# I should go back and really try to speed it up with C++ or something
# It only has speedups on massively dense matrices - order 10^4

# this is a great resource I used
# https://stanford.edu/~rezab/classes/cme323/S16/notes/Lecture03/cme323_lec3.pdf
# https://www.cs.cmu.edu/afs/cs/academic/class/15750-s17/ScribeNotes/lecture1.pdf
# I plan on parallelizing

In [61]:
using LinearAlgebra

In [51]:
# create a random matrix reference
matrix = rand([0, 1], 4, 4)

4×4 Matrix{Int64}:
 1  0  1  1
 0  0  0  1
 0  1  1  1
 0  1  1  1

In [56]:
# This is not actually Strassen's algorithm. Strassen's algorithm is recursive.
function naive_strassens(A, B)
    n, _ = size(A)
    block_dim = Int64(n/2)

    A_11 = A[1:block_dim, 1:block_dim]
    A_12 = A[1:block_dim, block_dim+1:end]
    A_21 = A[block_dim+1:end, 1:block_dim]
    A_22 = A[block_dim+1:end, block_dim+1:end]

    B_11 = B[1:block_dim, 1:block_dim]
    B_12 = B[1:block_dim, block_dim+1:end]
    B_21 = B[block_dim+1:end, 1:block_dim]
    B_22 = B[block_dim+1:end, block_dim+1:end]
    
    M_1 = (A_11 + A_22) * (B_11 + B_22)
    M_2 = (A_21 + A_22) * B_11
    M_3 = A_11 * (B_12 - B_22)
    M_4 = A_22 * (B_21 - B_11)
    M_5 = (A_11 + A_12) * B_22
    M_6 = (A_21 - A_11) * (B_11 + B_12)
    M_7 = (A_12 - A_22) * (B_21 + B_22)

    C_11 = M_1 + M_4 - M_5 + M_7
    C_12 = M_3 + M_5
    C_21 = M_2 + M_4
    C_22 = M_1 - M_2 + M_3 + M_6

    C = [C_11 C_12;
         C_21 C_22]

    
    return C
end

naive_strassens (generic function with 1 method)

In [53]:
function confirm_matrices(A, B)
    return (A == B)
end

confirm_matrices (generic function with 1 method)

In [62]:
# create two random matrix references
A = rand([0, 1], 100, 100)
B = rand([0, 1], 100, 100)

C_strassen = naive_strassens(A, B)
C_regular = A * B
confirm_matrices(C_strassen, C_regular)

true

In [104]:
function get_blocks(A)
    n, _ = size(A)
    block_dim = Int64(n/2)
    A = pad_matrix(A)

    A_11 = A[1:block_dim, 1:block_dim]
    A_12 = A[1:block_dim, block_dim+1:end]
    A_21 = A[block_dim+1:end, 1:block_dim]
    A_22 = A[block_dim+1:end, block_dim+1:end]


    return A_11, A_12, A_21, A_22
end

get_blocks (generic function with 1 method)

In [107]:
function strassens(A, B)
    n, _ = size(A)
    block_dim = n/2

    if n == 1
        return (A[1][1] * B[1][1])
    end

    A_11, A_12, A_21, A_22 = get_blocks(A)
    B_11, B_12, B_21, B_22 = get_blocks(B)

    
    M_1 = strassens((A_11 + A_22), (B_11 + B_22))
    M_2 = strassens((A_21 + A_22), B_11)
    M_3 = strassens(A_11, (B_12 - B_22))
    M_4 = strassens(A_22, (B_21 - B_11))
    M_5 = strassens((A_11 + A_12), B_22)
    M_6 = strassens((A_21 - A_11), (B_11 + B_12))
    M_7 = strassens((A_12 - A_22), (B_21 + B_22))

    C_11 = M_1 + M_4 - M_5 + M_7
    C_12 = M_3 + M_5
    C_21 = M_2 + M_4
    C_22 = M_1 - M_2 + M_3 + M_6

    C = [C_11 C_12;
         C_21 C_22]

    return C
end

strassens (generic function with 1 method)

In [108]:
A = rand([0, 1], 20, 20)
B = rand([0, 1], 20, 20)

C_strassen = strassens(A, B)
print(size(C_strassen))
C_regular = A * B

confirm_matrices(C_strassen, C_regular)

LoadError: DimensionMismatch: dimensions must match: a has dims (Base.OneTo(10), Base.OneTo(10)), b has dims (Base.OneTo(22), Base.OneTo(22)), mismatch at 1