Skip to content

Commit

Permalink
add sparse.roll function
Browse files Browse the repository at this point in the history
  • Loading branch information
ahwillia authored and hameerabbasi committed Jun 25, 2018
1 parent eeae480 commit c23441b
Show file tree
Hide file tree
Showing 5 changed files with 142 additions and 1 deletion.
6 changes: 6 additions & 0 deletions docs/generated/sparse.roll.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
roll
====

.. currentmodule:: sparse

.. autofunction:: roll
2 changes: 2 additions & 0 deletions docs/generated/sparse.rst
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ API

random

roll

save_npz

stack
Expand Down
2 changes: 1 addition & 1 deletion sparse/coo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .core import COO, as_coo
from .umath import elemwise
from .common import tensordot, dot, concatenate, stack, triu, tril, where, \
nansum, nanprod, nanmin, nanmax, nanreduce
nansum, nanprod, nanmin, nanmax, nanreduce, roll
67 changes: 67 additions & 0 deletions sparse/coo/common.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from functools import reduce
import operator
import warnings
import collections

import numpy as np
import scipy.sparse
Expand Down Expand Up @@ -609,3 +610,69 @@ def nanreduce(x, method, identity=None, axis=None, keepdims=False, **kwargs):
"""
arr = _replace_nan(x, method.identity if identity is None else identity)
return arr.reduce(method, axis, keepdims, **kwargs)


def roll(a, shift, axis=None):
"""
Shifts elements of an array along specified axis. Elements that roll beyond
the last position are circulated and re-introduced at the first.
Parameters
----------
x : COO
Input array
shift : int or tuple of ints
Number of index positions that elements are shifted. If a tuple is
provided, then axis must be a tuple of the same size, and each of the
given axes is shifted by the corresponding number. If an int while axis
is a tuple of ints, then broadcasting is used so the same shift is
applied to all axes.
axis : int or tuple of ints, optional
Axis or tuple specifying multiple axes. By default, the
array is flattened before shifting, after which the original shape is
restored.
Returns
-------
res : ndarray
Output array, with the same shape as a.
"""
from .core import COO, as_coo
a = as_coo(a)

# roll flattened array
if axis is None:
return roll(a.reshape((-1,)), shift, 0).reshape(a.shape)

# roll across specified axis
else:
# parse axis input, wrap in tuple
axis = normalize_axis(axis, a.ndim)
if not isinstance(axis, tuple):
axis = (axis,)

# make shift iterable
if not isinstance(shift, collections.Iterable):
shift = (shift,)

elif np.ndim(shift) > 1:
raise ValueError(
"'shift' and 'axis' must be integers or 1D sequences.")

# handle broadcasting
if len(shift) == 1:
shift = np.full(len(axis), shift)

# check if dimensions are consistent
if len(axis) != len(shift):
raise ValueError(
"If 'shift' is a 1D sequence, "
"'axis' must have equal length.")

# shift elements
coords, data = np.copy(a.coords), np.copy(a.data)
for sh, ax in zip(shift, axis):
coords[ax] += sh
coords[ax] %= a.shape[ax]

return COO(coords, data=data, shape=a.shape, has_duplicates=False)
66 changes: 66 additions & 0 deletions sparse/tests/test_coo.py
Original file line number Diff line number Diff line change
Expand Up @@ -1527,3 +1527,69 @@ def test_invalid_iterable_error():
with pytest.raises(ValueError):
x = [((2.3, 4.5), 3.2)]
COO.from_iter(x)


class TestRoll(object):

# test on 1d array #
@pytest.mark.parametrize('shift', [0, 2, -2, 20, -20])
def test_1d(self, shift):
xs = sparse.random((100,), density=0.5)
x = xs.todense()
assert_eq(np.roll(x, shift), sparse.roll(xs, shift))
assert_eq(np.roll(x, shift), sparse.roll(x, shift))

# test on 2d array #
@pytest.mark.parametrize(
'shift', [0, 2, -2, 20, -20])
@pytest.mark.parametrize(
'ax', [None, 0, 1, (0, 1)])
def test_2d(self, shift, ax):
xs = sparse.random((10, 10), density=0.5)
x = xs.todense()
assert_eq(np.roll(x, shift, axis=ax), sparse.roll(xs, shift, axis=ax))
assert_eq(np.roll(x, shift, axis=ax), sparse.roll(x, shift, axis=ax))

# test on rolling multiple axes at once #
@pytest.mark.parametrize(
'shift', [(0, 0), (1, -1), (-1, 1), (10, -10)])
@pytest.mark.parametrize(
'ax', [(0, 1), (0, 2), (1, 2), (-1, 1)])
def test_multiaxis(self, shift, ax):
xs = sparse.random((9, 9, 9), density=0.5)
x = xs.todense()
assert_eq(np.roll(x, shift, axis=ax), sparse.roll(xs, shift, axis=ax))
assert_eq(np.roll(x, shift, axis=ax), sparse.roll(x, shift, axis=ax))

# test original is unchanged #
@pytest.mark.parametrize(
'shift', [0, 2, -2, 20, -20])
@pytest.mark.parametrize(
'ax', [None, 0, 1, (0, 1)])
def test_original_is_copied(self, shift, ax):
xs = sparse.random((10, 10), density=0.5)
xc = COO(np.copy(xs.coords), np.copy(xs.data), shape=xs.shape)
sparse.roll(xs, shift, axis=ax)
assert_eq(xs, xc)

# test on empty array #
def test_empty(self):
x = np.array([])
assert_eq(np.roll(x, 1), sparse.roll(sparse.as_coo(x), 1))

# test error handling #
@pytest.mark.parametrize(
'args', [
# iterable shift, but axis not iterable
((1, 1), 0),
# ndim(axis) != 1
(1, [[0, 1]]),
# ndim(shift) != 1
([[0, 1]], [0, 1]),
([[0, 1], [0, 1]], [0, 1])
]
)
def test_valerr(self, args):
x = sparse.random((2, 2, 2), density=1)
with pytest.raises(ValueError):
sparse.roll(x, *args)

0 comments on commit c23441b

Please sign in to comment.