diff --git a/pytensor/sparse/basic.py b/pytensor/sparse/basic.py index cc7f191951..5f161a2b95 100644 --- a/pytensor/sparse/basic.py +++ b/pytensor/sparse/basic.py @@ -479,7 +479,7 @@ def pytensor_hash(self): return hash_from_sparse(d) -class SparseConstant(TensorConstant, _sparse_py_operators): +class SparseConstant(SparseVariable, TensorConstant): format = property(lambda self: self.type.format) def signature(self): diff --git a/pytensor/sparse/sharedvar.py b/pytensor/sparse/sharedvar.py index bdfb84cbb4..60b09656be 100644 --- a/pytensor/sparse/sharedvar.py +++ b/pytensor/sparse/sharedvar.py @@ -3,11 +3,11 @@ import scipy.sparse from pytensor.compile import shared_constructor -from pytensor.sparse.basic import SparseTensorType, _sparse_py_operators +from pytensor.sparse.basic import SparseTensorType, SparseVariable from pytensor.tensor.sharedvar import TensorSharedVariable -class SparseTensorSharedVariable(TensorSharedVariable, _sparse_py_operators): +class SparseTensorSharedVariable(TensorSharedVariable, SparseVariable): @property def format(self): return self.type.format diff --git a/pytensor/tensor/sharedvar.py b/pytensor/tensor/sharedvar.py index 28d9961da2..dad1751f9b 100644 --- a/pytensor/tensor/sharedvar.py +++ b/pytensor/tensor/sharedvar.py @@ -6,7 +6,7 @@ from pytensor.misc.safe_asarray import _asarray from pytensor.tensor import _get_vector_length from pytensor.tensor.type import TensorType -from pytensor.tensor.variable import _tensor_py_operators +from pytensor.tensor.variable import TensorVariable def __getattr__(name): @@ -31,7 +31,7 @@ def load_shared_variable(val): return tensor_constructor(val) -class TensorSharedVariable(_tensor_py_operators, SharedVariable): +class TensorSharedVariable(SharedVariable, TensorVariable): def zero(self, borrow: bool = False): r"""Set the values of a shared variable to 0.