Skip to content

Commit

Permalink
Add SDDMM example (#674)
Browse files Browse the repository at this point in the history
  • Loading branch information
mtsokol committed May 14, 2024
1 parent 79b9d71 commit c12b29e
Show file tree
Hide file tree
Showing 5 changed files with 100 additions and 7 deletions.
15 changes: 15 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,21 @@ jobs:
- name: Run benchmarks
run: |
asv run --quick
examples:
runs-on: ubuntu-latest
steps:
- name: Checkout Repo
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5.1.0
with:
python-version: '3.11'
- name: Build and install Sparse
run: |
python -m pip install '.[finch]' scipy
- name: Run examples
run: |
source ci/test_examples.sh
array_api_tests:
runs-on: ubuntu-latest
steps:
Expand Down
10 changes: 4 additions & 6 deletions benchmarks/benchmark_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from .utils import SkipNotImplemented

TIMEOUT: float = 200.0
TIMEOUT: float = 500.0
BACKEND: sparse.BackendType = sparse.backend_var.get()


Expand Down Expand Up @@ -42,8 +42,7 @@ def time_tensordot(self):

class SpMv:
timeout = TIMEOUT
# NOTE: https://github.com/willow-ahrens/Finch.jl/issues/488
params = [[True, False], [(10, 0.01)]] # (1000, 0.01), (1_000_000, 1e-05)
params = [[True, False], [(50, 0.1)]] # (1000, 0.01), (1_000_000, 1e-05)
param_names = ["lazy_mode", "size_and_density"]

def setup(self, lazy_mode, size_and_density):
Expand All @@ -55,9 +54,8 @@ def setup(self, lazy_mode, size_and_density):
random_kwargs["format"] = "gcxs"

self.M = sparse.random((size, size), **random_kwargs)
# NOTE: Once https://github.com/willow-ahrens/Finch.jl/issues/487 is fixed change to (size, 1).
self.v1 = rng.normal(size=(size, 2))
self.v2 = rng.normal(size=(size, 2))
self.v1 = rng.normal(size=(size, 1))
self.v2 = rng.normal(size=(size, 1))

if sparse.BackendType.Finch == BACKEND:
import finch
Expand Down
3 changes: 3 additions & 0 deletions ci/test_examples.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
for example in $(find ./examples/ -iname *.py); do
python $example
done
77 changes: 77 additions & 0 deletions examples/sddmm_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import time

import sparse

import numpy as np
import scipy.sparse as sps

LEN = 10000
DENSITY = 0.0001
ITERS = 3
rng = np.random.default_rng(0)


def benchmark(func, info, args):
print(info)
start = time.time()
for _ in range(ITERS):
func(*args)
elapsed = time.time() - start
print(f"Took {elapsed / ITERS} s.\n")


if __name__ == "__main__":
a_sps = rng.random((LEN, LEN - 10)) * 10
b_sps = rng.random((LEN - 10, LEN)) * 10
s_sps = sps.random(LEN, LEN, format="coo", density=DENSITY, random_state=rng) * 10
s_sps.sum_duplicates()

# Finch
with sparse.Backend(backend=sparse.BackendType.Finch):
s = sparse.asarray(s_sps)
a = sparse.asarray(np.array(a_sps, order="F"))
b = sparse.asarray(np.array(b_sps, order="C"))

@sparse.compiled
def sddmm_finch(s, a, b):
return sparse.sum(
s[:, :, None] * (a[:, None, :] * sparse.permute_dims(b, (1, 0))[None, :, :]),
axis=-1,
)

# Compile
result_finch = sddmm_finch(s, a, b)
assert sparse.nonzero(result_finch)[0].size > 5
# Benchmark
benchmark(sddmm_finch, info="Finch", args=[s, a, b])

# Numba
with sparse.Backend(backend=sparse.BackendType.Numba):
s = sparse.asarray(s_sps)
a = a_sps
b = b_sps

def sddmm_numba(s, a, b):
return s * (a @ b)

# Compile
result_numba = sddmm_numba(s, a, b)
assert sparse.nonzero(result_numba)[0].size > 5
# Benchmark
benchmark(sddmm_numba, info="Numba", args=[s, a, b])

# SciPy
def sddmm_scipy(s, a, b):
return s.multiply(a @ b)

s = s_sps.asformat("csr")
a = a_sps
b = b_sps

result_scipy = sddmm_scipy(s, a, b)
# Benchmark
benchmark(sddmm_scipy, info="SciPy", args=[s, a, b])

np.testing.assert_allclose(result_numba.todense(), result_scipy.toarray())
np.testing.assert_allclose(result_finch.todense(), result_numba.todense())
np.testing.assert_allclose(result_finch.todense(), result_scipy.toarray())
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ tests = [
]
tox = ["sparse[tests]", "tox"]
all = ["sparse[docs,tox]", "matrepr"]
finch = ["finch-tensor>=0.1.14"]
finch = ["finch-tensor>=0.1.19"]

[project.urls]
Documentation = "https://sparse.pydata.org/"
Expand Down

0 comments on commit c12b29e

Please sign in to comment.