## Tests for `gemm_recompute` and `gemm_coded`

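A suite of test cases for the matrix-multiplication functions. Each test computes A·Bᵀ with both `gemm_coded` and `gemm_recompute` and checks the result against a local NumPy reference. The inputs cover square and non-square matrices of various sizes.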

In [12]:
from matmul import *

# numpywren.matrix.BigMatrix.numpy uses deprecated syntax that triggers warnings;
# silence them so calls to it do not flood the output (note: this ignores ALL warnings).
import warnings
warnings.filterwarnings('ignore')

In [13]:
def make_randn_block(id, X):
    """Fill one shard of ``X`` with i.i.d. standard-normal samples.

    ``id`` is the block index, unpacked as positional arguments to
    ``X.put_block``. The block's size is read from ``X.shard_sizes``.
    Always returns 0 (a placeholder result for the caller).
    """
    n_rows = X.shard_sizes[0]
    n_cols = X.shard_sizes[1]
    block = np.random.randn(n_rows, n_cols)
    X.put_block(block, *id)
    return 0

def randn_bigmtx(RB, CB, SS):
    """Make a dense (RB x CB) blocked numpywren.matrix.BigMatrix with SS x SS blocks.

    The matrix has shape ``(RB * SS, CB * SS)`` and is filled with
    standard-normal data. Its key depends only on the shape and block size,
    so a repeat call with the same arguments addresses the same matrix. Only
    the blocks listed in ``block_idxs_not_exist`` are generated, by mapping
    ``make_randn_block`` over them on a pywren lambda executor.

    Raises:
        Any exception raised by a remote block-filling job. It is re-raised
        here rather than returning a partially filled matrix.
    """
    num_rows, num_cols = int(RB * SS), int(CB * SS)
    X_sharded = matrix.BigMatrix("X_sharded{0}_{1}_{2}".format(num_rows, num_cols, SS),
                                 shape=(num_rows, num_cols),
                                 shard_sizes=[SS, SS],
                                 autosqueeze=False,
                                 write_header=True)
    # Query the missing blocks once and reuse the result (presumably this
    # property inspects storage on every access).
    missing = X_sharded.block_idxs_not_exist
    if len(missing) > 0:
        pwex = pywren.lambda_executor()
        futures = pwex.map(lambda idx: make_randn_block(idx, X_sharded), missing)
        pywren.wait(futures, ALL_COMPLETED)
        # wait() only blocks until the jobs finish; result() re-raises any
        # exception a remote job hit, so failures are not silently ignored.
        for future in futures:
            future.result()
    return X_sharded

In [15]:
# Test for AA.T
# Both gemm variants must reproduce the local NumPy product. assert_allclose is
# used instead of `assert np.allclose(...)`: it also checks shapes (allclose
# broadcasts, so a wrongly shaped result could pass), reports the mismatch, and
# is not stripped under `python -O`. The tolerances equal np.allclose's defaults.
tol = dict(rtol=1e-5, atol=1e-8, equal_nan=False)
A = randn_bigmtx(2, 2, 512)
A_local = A.numpy()
C = A_local.dot(A_local.T)
C_coding, _, _, _ = gemm_coded(A, A, blocks_per_parity=2, s3_key="coded_out")
C_recomp, _, _ = gemm_recompute(A, A, thresh=.7, s3_key="recomp_out")
np.testing.assert_allclose(C_coding.numpy(), C, err_msg="gemm_coded: AA.T mismatch", **tol)
np.testing.assert_allclose(C_recomp.numpy(), C, err_msg="gemm_recompute: AA.T mismatch", **tol)

In [17]:
# Test for AB.T with square A and nonsquare B
# assert_allclose checks that shapes match exactly (np.allclose would broadcast),
# which matters here because the output is non-square. The tolerances equal
# np.allclose's defaults.
tol = dict(rtol=1e-5, atol=1e-8, equal_nan=False)
A = randn_bigmtx(2, 2, 128)
B = randn_bigmtx(4, 2, 128)

A_local, B_local = A.numpy(), B.numpy()
C = A_local.dot(B_local.T)
C_coding, _, _, _ = gemm_coded(A, B, blocks_per_parity=2, s3_key="coded_out")
C_recomp, _, _ = gemm_recompute(A, B, thresh=.7, s3_key="recomp_out")
np.testing.assert_allclose(C_coding.numpy(), C, err_msg="gemm_coded: square A, nonsquare B mismatch", **tol)
np.testing.assert_allclose(C_recomp.numpy(), C, err_msg="gemm_recompute: square A, nonsquare B mismatch", **tol)

In [18]:
# Test for AB.T with nonsquare A and nonsquare B
# assert_allclose checks that shapes match exactly (np.allclose would broadcast)
# and reports where values differ. The tolerances equal np.allclose's defaults.
tol = dict(rtol=1e-5, atol=1e-8, equal_nan=False)
A = randn_bigmtx(3, 6, 64)
B = randn_bigmtx(9, 6, 64)

A_local, B_local = A.numpy(), B.numpy()
C = A_local.dot(B_local.T)
C_coding, _, _, _ = gemm_coded(A, B, blocks_per_parity=3, s3_key="coded_out")
C_recomp, _, _ = gemm_recompute(A, B, thresh=.7, s3_key="recomp_out")
np.testing.assert_allclose(C_coding.numpy(), C, err_msg="gemm_coded: nonsquare A, nonsquare B mismatch", **tol)
np.testing.assert_allclose(C_recomp.numpy(), C, err_msg="gemm_recompute: nonsquare A, nonsquare B mismatch", **tol)