Skip to content

Commit

Permalink
Make Constant and Shared variables subclasses of the respective Varia…
Browse files Browse the repository at this point in the history
…bles
  • Loading branch information
ricardoV94 authored and michaelosthege committed Mar 24, 2024
1 parent 97317a5 commit 60e2510
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion pytensor/sparse/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,7 @@ def pytensor_hash(self):
return hash_from_sparse(d)


class SparseConstant(TensorConstant, _sparse_py_operators):
class SparseConstant(SparseVariable, TensorConstant):
format = property(lambda self: self.type.format)

def signature(self):
Expand Down
4 changes: 2 additions & 2 deletions pytensor/sparse/sharedvar.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
import scipy.sparse

from pytensor.compile import shared_constructor
from pytensor.sparse.basic import SparseTensorType, _sparse_py_operators
from pytensor.sparse.basic import SparseTensorType, SparseVariable
from pytensor.tensor.sharedvar import TensorSharedVariable


class SparseTensorSharedVariable(TensorSharedVariable, _sparse_py_operators):
class SparseTensorSharedVariable(TensorSharedVariable, SparseVariable):
@property
def format(self):
return self.type.format
Expand Down
4 changes: 2 additions & 2 deletions pytensor/tensor/sharedvar.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from pytensor.misc.safe_asarray import _asarray
from pytensor.tensor import _get_vector_length
from pytensor.tensor.type import TensorType
from pytensor.tensor.variable import _tensor_py_operators
from pytensor.tensor.variable import TensorVariable


def __getattr__(name):
Expand All @@ -31,7 +31,7 @@ def load_shared_variable(val):
return tensor_constructor(val)


class TensorSharedVariable(_tensor_py_operators, SharedVariable):
class TensorSharedVariable(SharedVariable, TensorVariable):
def zero(self, borrow: bool = False):
r"""Set the values of a shared variable to 0.
Expand Down

0 comments on commit 60e2510

Please sign in to comment.