From 49f64793c75f37cf137a3f86077135a22097ac0e Mon Sep 17 00:00:00 2001 From: Keisuke Fujii Date: Tue, 9 Jan 2018 04:36:50 +0900 Subject: [PATCH] Support left-side np.number elemwise operations. (#67) * Added tests with np.float32 scalar * Add support np.scalar left-side multiplication. * Add an assertion to make sure nnz is preserved by scalar multiplication. * Add more tests for left-side operation with scalar. --- sparse/core.py | 9 +++++++- sparse/tests/test_core.py | 48 +++++++++++++++++++++++++++++++++------ 2 files changed, 49 insertions(+), 8 deletions(-) diff --git a/sparse/core.py b/sparse/core.py index e455a9dc..9a28145c 100644 --- a/sparse/core.py +++ b/sparse/core.py @@ -1,7 +1,7 @@ from __future__ import absolute_import, division, print_function from collections import Iterable, defaultdict, deque -from functools import reduce +from functools import reduce, partial import numbers import operator @@ -728,6 +728,13 @@ def _elemwise(func, *args, **kwargs): self = args[0] if isinstance(self, scipy.sparse.spmatrix): self = COO.from_numpy(self) + elif np.isscalar(self) or (isinstance(self, np.ndarray) + and self.ndim == 0): + func = partial(func, self) + other = args[1] + if isinstance(other, scipy.sparse.spmatrix): + other = COO.from_numpy(other) + return other._elemwise_unary(func, *args[2:], **kwargs) if len(args) == 1: return self._elemwise_unary(func, *args[1:], **kwargs) diff --git a/sparse/tests/test_core.py b/sparse/tests/test_core.py index 394abe7c..2858e026 100644 --- a/sparse/tests/test_core.py +++ b/sparse/tests/test_core.py @@ -265,8 +265,11 @@ def test_op_scipy_sparse(): (operator.le, -3), (operator.eq, 1) ]) -def test_elemwise_scalar(func, scalar): +@pytest.mark.parametrize('convert_to_np_number', [True, False]) +def test_elemwise_scalar(func, scalar, convert_to_np_number): xs = sparse.random((2, 3, 4), density=0.5) + if convert_to_np_number: + scalar = np.float32(scalar) y = scalar x = xs.todense() @@ -278,6 +281,33 @@ def test_elemwise_scalar(func, scalar): assert_eq(fs, func(x, y)) +@pytest.mark.parametrize('func, scalar', [ + (operator.mul, 5), + (operator.add, 0), + (operator.sub, 0), + (operator.gt, -5), + (operator.lt, 5), + (operator.ne, 0), + (operator.ge, -5), + (operator.le, 3), + (operator.eq, 1) +]) +@pytest.mark.parametrize('convert_to_np_number', [True, False]) +def test_leftside_elemwise_scalar(func, scalar, convert_to_np_number): + xs = sparse.random((2, 3, 4), density=0.5) + if convert_to_np_number: + scalar = np.float32(scalar) + y = scalar + + x = xs.todense() + fs = func(y, xs) + + assert isinstance(fs, COO) + assert xs.nnz >= fs.nnz + + assert_eq(fs, func(y, x)) + + @pytest.mark.parametrize('func, scalar', [ (operator.add, 5), (operator.sub, -5), @@ -601,15 +631,19 @@ def test_broadcast_to(shape1, shape2): assert_eq(np.broadcast_to(x, shape2), a.broadcast_to(shape2)) -def test_scalar_multiplication(): +@pytest.mark.parametrize('scalar', [2, 2.5, np.float32(2.0), np.int8(3)]) +def test_scalar_multiplication(scalar): a = sparse.random((2, 3, 4), density=0.5) x = a.todense() - assert_eq(x * 2, a * 2) - assert_eq(2 * x, 2 * a) - assert_eq(x / 2, a / 2) - assert_eq(x / 2.5, a / 2.5) - assert_eq(x // 2.5, a // 2.5) + assert_eq(x * scalar, a * scalar) + assert (a * scalar).nnz == a.nnz + assert_eq(scalar * x, scalar * a) + assert (scalar * a).nnz == a.nnz + assert_eq(x / scalar, a / scalar) + assert (a / scalar).nnz == a.nnz + assert_eq(x // scalar, a // scalar) + # division may reduce nnz. @pytest.mark.filterwarnings('ignore:divide by zero')