Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 45 additions & 16 deletions pytensor/sparse/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2268,36 +2268,49 @@ def add(x, y):
raise NotImplementedError()


def sub(x, y):
def subtract(
x: SparseVariable | TensorVariable, y: SparseVariable | TensorVariable
) -> SparseVariable:
"""
Subtract two matrices, at least one of which is sparse.

This method will provide the right op according
to the inputs.
This method will provide the right op according to the inputs.

Parameters
----------
x
x : SparseVariable or TensorVariable
A matrix variable.
y
y : SparseVariable or TensorVariable
A matrix variable.

Returns
-------
A sparse matrix
`x` - `y`
result: SparseVariable
Result of `x - y`, as a sparse matrix.

Notes
-----
At least one of `x` and `y` must be a sparse matrix.

The grad will be structured only when one of the variable will be a dense
matrix.

The grad will be structured only when one of the variable will be a dense matrix.
"""
return x + (-y)


def sub(x, y):
warn(
"pytensor.sparse.sub is deprecated and will be removed in a future version. Use "
"pytensor.sparse.subtract instead.",
category=DeprecationWarning,
stacklevel=2,
)

return subtract(x, y)


sub.__doc__ = subtract.__doc__


class MulSS(Op):
# mul(sparse, sparse)
# See the doc of mul() for more detail
Expand Down Expand Up @@ -2491,29 +2504,31 @@ def infer_shape(self, fgraph, node, ins_shapes):
mul_s_v = MulSV()


def mul(x, y):
def multiply(
x: SparseTensorType | TensorType, y: SparseTensorType | TensorType
) -> SparseVariable:
"""
Multiply elementwise two matrices, at least one of which is sparse.

This method will provide the right op according to the inputs.

Parameters
----------
x
x : SparseVariable
A matrix variable.
y
y : SparseVariable
A matrix variable.

Returns
-------
A sparse matrix
`x` * `y`
result: SparseVariable
The elementwise multiplication of `x` and `y`.

Notes
-----
At least one of `x` and `y` must be a sparse matrix.
The grad is regular, i.e. not structured.

The gradient is regular, i.e. not structured.
"""

x = as_sparse_or_tensor_variable(x)
Expand Down Expand Up @@ -2541,6 +2556,20 @@ def mul(x, y):
raise NotImplementedError()


def mul(x, y):
warn(
"pytensor.sparse.mul is deprecated and will be removed in a future version. Use "
"pytensor.sparse.multiply instead.",
category=DeprecationWarning,
stacklevel=2,
)

return multiply(x, y)


mul.__doc__ = multiply.__doc__


class __ComparisonOpSS(Op):
"""
Used as a superclass for all comparisons between two sparses matrices.
Expand Down
14 changes: 7 additions & 7 deletions tests/sparse/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@
gt,
le,
lt,
mul,
mul_s_v,
multiply,
sampling_dot,
sp_ones_like,
square_diagonal,
Expand Down Expand Up @@ -724,21 +724,21 @@ def test_AddDS(self):

def test_MulSS(self):
self._testSS(
mul,
multiply,
np.array([[1.0, 0], [3, 0], [0, 6]]),
np.array([[1.0, 2], [3, 0], [0, 6]]),
)

def test_MulSD(self):
self._testSD(
mul,
multiply,
np.array([[1.0, 0], [3, 0], [0, 6]]),
np.array([[1.0, 2], [3, 0], [0, 6]]),
)

def test_MulDS(self):
self._testDS(
mul,
multiply,
np.array([[1.0, 0], [3, 0], [0, 6]]),
np.array([[1.0, 2], [3, 0], [0, 6]]),
)
Expand Down Expand Up @@ -783,7 +783,7 @@ def _testSS(
assert np.all(val.todense() == array1 + array2)
if dtype1.startswith("float") and dtype2.startswith("float"):
verify_grad_sparse(op, [a, b], structured=False)
elif op is mul:
elif op is multiply:
assert np.all(val.todense() == array1 * array2)
if dtype1.startswith("float") and dtype2.startswith("float"):
verify_grad_sparse(op, [a, b], structured=False)
Expand Down Expand Up @@ -833,7 +833,7 @@ def _testSD(
continue
if dtype1.startswith("float") and dtype2.startswith("float"):
verify_grad_sparse(op, [a, b], structured=True)
elif op is mul:
elif op is multiply:
assert _is_sparse_variable(apb)
assert np.all(val.todense() == b.multiply(array1))
assert np.all(
Expand Down Expand Up @@ -887,7 +887,7 @@ def _testDS(
b = b.data
if dtype1.startswith("float") and dtype2.startswith("float"):
verify_grad_sparse(op, [a, b], structured=True)
elif op is mul:
elif op is multiply:
assert _is_sparse_variable(apb)
ans = np.array([[1, 0], [9, 0], [0, 36]])
assert np.all(val.todense() == (a.multiply(array2)))
Expand Down