Skip to content

Commit

Permalink
Support left-side np.number elemwise operations. (#67)
Browse files Browse the repository at this point in the history
* Added tests with np.float32 scalar

* Add support np.scalar left-side multiplication.

* Add an assertion to make sure nnz is preserved by scalar multiplication.

* Add more tests for left-side operation with scalar.
  • Loading branch information
fujiisoup authored and hameerabbasi committed Jan 8, 2018
1 parent 9be8058 commit 49f6479
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 8 deletions.
9 changes: 8 additions & 1 deletion sparse/core.py
@@ -1,7 +1,7 @@
from __future__ import absolute_import, division, print_function

from collections import Iterable, defaultdict, deque
from functools import reduce
from functools import reduce, partial
import numbers
import operator

Expand Down Expand Up @@ -728,6 +728,13 @@ def _elemwise(func, *args, **kwargs):
self = args[0]
if isinstance(self, scipy.sparse.spmatrix):
self = COO.from_numpy(self)
elif np.isscalar(self) or (isinstance(self, np.ndarray)
and self.ndim == 0):
func = partial(func, self)
other = args[1]
if isinstance(other, scipy.sparse.spmatrix):
other = COO.from_numpy(other)
return other._elemwise_unary(func, *args[2:], **kwargs)

if len(args) == 1:
return self._elemwise_unary(func, *args[1:], **kwargs)
Expand Down
48 changes: 41 additions & 7 deletions sparse/tests/test_core.py
Expand Up @@ -265,8 +265,11 @@ def test_op_scipy_sparse():
(operator.le, -3),
(operator.eq, 1)
])
def test_elemwise_scalar(func, scalar):
@pytest.mark.parametrize('convert_to_np_number', [True, False])
def test_elemwise_scalar(func, scalar, convert_to_np_number):
xs = sparse.random((2, 3, 4), density=0.5)
if convert_to_np_number:
scalar = np.float32(scalar)
y = scalar

x = xs.todense()
Expand All @@ -278,6 +281,33 @@ def test_elemwise_scalar(func, scalar):
assert_eq(fs, func(x, y))


@pytest.mark.parametrize('func, scalar', [
(operator.mul, 5),
(operator.add, 0),
(operator.sub, 0),
(operator.gt, -5),
(operator.lt, 5),
(operator.ne, 0),
(operator.ge, -5),
(operator.le, 3),
(operator.eq, 1)
])
@pytest.mark.parametrize('convert_to_np_number', [True, False])
def test_leftside_elemwise_scalar(func, scalar, convert_to_np_number):
xs = sparse.random((2, 3, 4), density=0.5)
if convert_to_np_number:
scalar = np.float32(scalar)
y = scalar

x = xs.todense()
fs = func(y, xs)

assert isinstance(fs, COO)
assert xs.nnz >= fs.nnz

assert_eq(fs, func(y, x))


@pytest.mark.parametrize('func, scalar', [
(operator.add, 5),
(operator.sub, -5),
Expand Down Expand Up @@ -601,15 +631,19 @@ def test_broadcast_to(shape1, shape2):
assert_eq(np.broadcast_to(x, shape2), a.broadcast_to(shape2))


def test_scalar_multiplication():
@pytest.mark.parametrize('scalar', [2, 2.5, np.float32(2.0), np.int8(3)])
def test_scalar_multiplication(scalar):
a = sparse.random((2, 3, 4), density=0.5)
x = a.todense()

assert_eq(x * 2, a * 2)
assert_eq(2 * x, 2 * a)
assert_eq(x / 2, a / 2)
assert_eq(x / 2.5, a / 2.5)
assert_eq(x // 2.5, a // 2.5)
assert_eq(x * scalar, a * scalar)
assert (a * scalar).nnz == a.nnz
assert_eq(scalar * x, scalar * a)
assert (scalar * a).nnz == a.nnz
assert_eq(x / scalar, a / scalar)
assert (a / scalar).nnz == a.nnz
assert_eq(x // scalar, a // scalar)
# division may reduce nnz.


@pytest.mark.filterwarnings('ignore:divide by zero')
Expand Down

0 comments on commit 49f6479

Please sign in to comment.