Skip to content

Commit

Permalink
Merge pull request #676 from pydata/mttkrp-example
Browse files Browse the repository at this point in the history
Add `MTTKRP` example
  • Loading branch information
mtsokol committed May 15, 2024
2 parents c12b29e + 4f4c80e commit eb78737
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 0 deletions.
62 changes: 62 additions & 0 deletions examples/mttkrp_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import time

import sparse

import numpy as np

I_ = 1000
J_ = 25
K_ = 1000
L_ = 100
DENSITY = 0.0001
ITERS = 3
rng = np.random.default_rng(0)


def benchmark(func, info, args):
print(info)
start = time.time()
for _ in range(ITERS):
func(*args)
elapsed = time.time() - start
print(f"Took {elapsed / ITERS} s.\n")


if __name__ == "__main__":
print("MTTKRP Example:\n")

B_sps = sparse.random((I_, K_, L_), density=DENSITY, random_state=rng) * 10
D_sps = rng.random((L_, J_)) * 10
C_sps = rng.random((K_, J_)) * 10

# Finch
with sparse.Backend(backend=sparse.BackendType.Finch):
B = sparse.asarray(B_sps.todense(), format="csf")
D = sparse.asarray(np.array(D_sps, order="F"))
C = sparse.asarray(np.array(C_sps, order="F"))

@sparse.compiled
def mttkrp_finch(B, D, C):
return sparse.sum(B[:, :, :, None] * D[None, None, :, :] * C[None, :, None, :], axis=(1, 2))

# Compile
result_finch = mttkrp_finch(B, D, C)
assert sparse.nonzero(result_finch)[0].size > 5
# Benchmark
benchmark(mttkrp_finch, info="Finch", args=[B, D, C])

# Numba
with sparse.Backend(backend=sparse.BackendType.Numba):
B = sparse.asarray(B_sps, format="gcxs")
D = D_sps
C = C_sps

def mttkrp_numba(B, D, C):
return sparse.sum(B[:, :, :, None] * D[None, None, :, :] * C[None, :, None, :], axis=(1, 2))

# Compile
result_numba = mttkrp_numba(B, D, C)
# Benchmark
benchmark(mttkrp_numba, info="Numba", args=[B, D, C])

np.testing.assert_allclose(result_finch.todense(), result_numba.todense())
2 changes: 2 additions & 0 deletions examples/sddmm_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ def benchmark(func, info, args):


if __name__ == "__main__":
print("SDDMM Example:\n")

a_sps = rng.random((LEN, LEN - 10)) * 10
b_sps = rng.random((LEN - 10, LEN)) * 10
s_sps = sps.random(LEN, LEN, format="coo", density=DENSITY, random_state=rng) * 10
Expand Down

0 comments on commit eb78737

Please sign in to comment.