Skip to content

Commit

Permalink
Merge pull request #423 from hameerabbasi/add-shape-kwarg-full-like
Browse files Browse the repository at this point in the history
Add shape= kwarg in {zeros, ones, full}_like.
  • Loading branch information
hameerabbasi committed Jan 4, 2021
2 parents 66a6757 + b94b42e commit 440809f
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 31 deletions.
3 changes: 0 additions & 3 deletions azure-pipelines.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,6 @@ jobs:
python_version: '3.7'
Python36:
python_version: '3.6'
ArrayFunction:
NUMPY_EXPERIMENTAL_ARRAY_FUNCTION: '1'
NUMPY_VERSION: '==1.16.2'
- job: MacOS
variables:
python_version: '3.6'
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
numpy
numpy>=1.17
scipy>=0.19
numba>=0.49
36 changes: 9 additions & 27 deletions sparse/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1375,7 +1375,7 @@ def full(shape, fill_value, dtype=None, format="coo", compressed_axes=None):
).asformat(format, compressed_axes=compressed_axes)


def full_like(a, fill_value, dtype=None, format=None, compressed_axes=None):
def full_like(a, fill_value, dtype=None, shape=None, format=None, compressed_axes=None):
"""Return a full array with the same shape and type as a given array.
Parameters
Expand Down Expand Up @@ -1403,12 +1403,12 @@ def full_like(a, fill_value, dtype=None, format=None, compressed_axes=None):
"""
if format is None and not isinstance(a, np.ndarray):
format = type(a).__name__.lower()
else:
elif format is None:
format = "coo"
if hasattr(a, "compressed_axes") and compressed_axes is None:
compressed_axes = a.compressed_axes
return full(
a.shape,
a.shape if shape is None else shape,
fill_value,
dtype=(a.dtype if dtype is None else dtype),
format=format,
Expand Down Expand Up @@ -1452,7 +1452,7 @@ def zeros(shape, dtype=float, format="coo", compressed_axes=None):
)


def zeros_like(a, dtype=None, format=None, compressed_axes=None):
def zeros_like(a, dtype=None, shape=None, format=None, compressed_axes=None):
"""Return a SparseArray of zeros with the same shape and type as ``a``.
Parameters
Expand All @@ -1478,17 +1478,8 @@ def zeros_like(a, dtype=None, format=None, compressed_axes=None):
array([[0, 0, 0],
[0, 0, 0]])
"""
if format is None and not isinstance(a, np.ndarray):
format = type(a).__name__.lower()
elif format is None:
format = "coo"
if hasattr(a, "compressed_axes") and compressed_axes is None:
compressed_axes = a.compressed_axes
return zeros(
a.shape,
dtype=(a.dtype if dtype is None else dtype),
format=format,
compressed_axes=compressed_axes,
return full_like(
a, 0, dtype=dtype, shape=shape, format=format, compressed_axes=compressed_axes
)


Expand Down Expand Up @@ -1528,7 +1519,7 @@ def ones(shape, dtype=float, format="coo", compressed_axes=None):
)


def ones_like(a, dtype=None, format=None, compressed_axes=None):
def ones_like(a, dtype=None, shape=None, format=None, compressed_axes=None):
"""Return a SparseArray of ones with the same shape and type as ``a``.
Parameters
Expand All @@ -1554,17 +1545,8 @@ def ones_like(a, dtype=None, format=None, compressed_axes=None):
array([[1, 1, 1],
[1, 1, 1]])
"""
if format is None and not isinstance(a, np.ndarray):
format = type(a).__name__.lower()
else:
format = "coo"
if hasattr(a, "compressed_axes") and compressed_axes is None:
compressed_axes = a.compressed_axes
return ones(
a.shape,
dtype=(a.dtype if dtype is None else dtype),
format=format,
compressed_axes=compressed_axes,
return full_like(
a, 1, dtype=dtype, shape=shape, format=format, compressed_axes=compressed_axes
)


Expand Down
4 changes: 4 additions & 0 deletions sparse/tests/test_coo.py
Original file line number Diff line number Diff line change
Expand Up @@ -1307,6 +1307,7 @@ def test_ones_zeros_like(funcname):
assert_eq(sp_func(x), np_func(x))
assert_eq(sp_func(x, dtype="f8"), np_func(x, dtype="f8"))
assert_eq(sp_func(x, dtype=None), np_func(x, dtype=None))
assert_eq(sp_func(x, shape=(2, 2)), np_func(x, shape=(2, 2)))


def test_full():
Expand All @@ -1320,6 +1321,9 @@ def test_full_like():
x = np.zeros((5, 5), dtype="i8")
assert_eq(sparse.full_like(x, 9.5), np.full_like(x, 9.5))
assert_eq(sparse.full_like(x, 9.5, dtype="f8"), np.full_like(x, 9.5, dtype="f8"))
assert_eq(
sparse.full_like(x, 9.5, shape=(2, 2)), np.full_like(x, 9.5, shape=(2, 2))
)


@pytest.mark.parametrize("complex", [True, False])
Expand Down

0 comments on commit 440809f

Please sign in to comment.