In [32]:
# This is a quick re-impl of Strassen's algorithm to understand it better
# I should go back and really try to speed it up with C++ or something
# It only has speedups on massively dense matrices - order 10^4 apparently
# It's very slow here and I don't exactly know why? I suspect it's because of memory allocation

# this is a great resource I used
# https://stanford.edu/~rezab/classes/cme323/S16/notes/Lecture03/cme323_lec3.pdf
# https://www.cs.cmu.edu/afs/cs/academic/class/15750-s17/ScribeNotes/lecture1.pdf
# I plan on parallelizing

In [33]:
using LinearAlgebra

In [34]:
# create a random matrix reference
matrix = rand([0, 1], 4, 4)

4×4 Matrix{Int64}:
 0  1  0  1
 0  0  0  1
 0  0  1  0
 1  1  0  1

In [35]:
# This is not actually Strassen's algorithm. Strassen's algorithm is recursive.
function naive_strassens(A, B)
    n, _ = size(A)
    block_dim = Int64(n/2)

    A_11 = A[1:block_dim, 1:block_dim]
    A_12 = A[1:block_dim, block_dim+1:end]
    A_21 = A[block_dim+1:end, 1:block_dim]
    A_22 = A[block_dim+1:end, block_dim+1:end]

    B_11 = B[1:block_dim, 1:block_dim]
    B_12 = B[1:block_dim, block_dim+1:end]
    B_21 = B[block_dim+1:end, 1:block_dim]
    B_22 = B[block_dim+1:end, block_dim+1:end]
    
    M_1 = (A_11 + A_22) * (B_11 + B_22)
    M_2 = (A_21 + A_22) * B_11
    M_3 = A_11 * (B_12 - B_22)
    M_4 = A_22 * (B_21 - B_11)
    M_5 = (A_11 + A_12) * B_22
    M_6 = (A_21 - A_11) * (B_11 + B_12)
    M_7 = (A_12 - A_22) * (B_21 + B_22)

    C_11 = M_1 + M_4 - M_5 + M_7
    C_12 = M_3 + M_5
    C_21 = M_2 + M_4
    C_22 = M_1 - M_2 + M_3 + M_6

    C = [C_11 C_12;
         C_21 C_22]

    
    return C
end

naive_strassens (generic function with 1 method)

In [36]:
function confirm_matrices(A, B)
    return (A == B)
end

confirm_matrices (generic function with 1 method)

In [37]:
# create two random matrix references
A = rand([0, 1], 100, 100)
B = rand([0, 1], 100, 100)

C_strassen = naive_strassens(A, B)
C_regular = A * B
confirm_matrices(C_strassen, C_regular)

true

In [38]:
function get_blocks(A)
    n, _ = size(A)
    block_dim = Int64(n/2)
    
    A_11 = A[1:block_dim, 1:block_dim]
    A_12 = A[1:block_dim, block_dim+1:end]
    A_21 = A[block_dim+1:end, 1:block_dim]
    A_22 = A[block_dim+1:end, block_dim+1:end]


    return A_11, A_12, A_21, A_22
end

get_blocks (generic function with 1 method)

In [39]:
function strassens(A, B)
    n, _ = size(A)
    block_dim = n/2

    if n == 1
        return (A[1][1] * B[1][1])
    end

    A_11, A_12, A_21, A_22 = get_blocks(A)
    B_11, B_12, B_21, B_22 = get_blocks(B)

    
    M_1 = strassens((A_11 + A_22), (B_11 + B_22))
    M_2 = strassens((A_21 + A_22), B_11)
    M_3 = strassens(A_11, (B_12 - B_22))
    M_4 = strassens(A_22, (B_21 - B_11))
    M_5 = strassens((A_11 + A_12), B_22)
    M_6 = strassens((A_21 - A_11), (B_11 + B_12))
    M_7 = strassens((A_12 - A_22), (B_21 + B_22))

    C_11 = M_1 + M_4 - M_5 + M_7
    C_12 = M_3 + M_5
    C_21 = M_2 + M_4
    C_22 = M_1 - M_2 + M_3 + M_6

    C = [C_11 C_12;
         C_21 C_22]

    return C
end

strassens (generic function with 1 method)

In [40]:
A = rand([0, 1], 256, 256)
B = rand([0, 1], 256, 256)

C_strassen = strassens(A, B)
C_regular = A * B

confirm_matrices(C_strassen, C_regular)

true

In [41]:
function strassens_time(A, B)
    @time strassens(A, B)
end

strassens_time (generic function with 1 method)

In [42]:
function regular_time(A, B)
    @time A*B
end

regular_time (generic function with 1 method)

In [43]:
A = rand([0, 1], 512, 512)
B = rand([0, 1], 512, 512)

strassens_time(A, B)
regular_time(A, B)

  8.352625 seconds (138.17 M allocations: 9.948 GiB, 9.75% gc time)
  0.069124 seconds (5 allocations: 2.030 MiB)


512×512 Matrix{Int64}:
 141  153  126  138  150  136  136  133  …  136  147  125  135  132  134  129
 133  141  129  130  135  138  134  139     129  139  122  129  144  137  120
 142  143  122  126  139  132  129  138     125  133  125  124  135  124  117
 126  134  120  137  132  121  121  129     132  126  122  120  122  127  122
 139  137  125  126  144  134  135  135     132  137  125  127  139  136  117
 126  131  116  122  131  126  120  120  …  122  122  132  110  116  127  117
 136  138  124  142  134  133  134  139     133  132  131  130  129  135  128
 136  135  112  116  134  126  135  133     129  141  125  127  134  135  120
 128  143  122  133  126  129  122  131     132  146  128  118  133  136  126
 126  131  125  139  130  140  131  131     123  126  130  123  141  129  125
 130  139  127  144  143  137  130  130  …  142  130  133  129  136  139  129
 130  126  122  130  132  122  134  130     133  131  125  114  131  130  133
 133  146  128  133  139  129  132  136  

In [44]:
# looks like we are allocating way to much memory. Gotta re-impl this more 
# carefully in something like CUDA to see the full speed ups!