Remove tensor.ScalarSharedVariable #396

Closed
ricardoV94 opened this issue Jul 24, 2023 · 0 comments · Fixed by #629

@ricardoV94
Member

```python
class ScalarSharedVariable(TensorSharedVariable):
    pass


@shared_constructor.register(np.number)
@shared_constructor.register(float)
@shared_constructor.register(int)
@shared_constructor.register(complex)
def scalar_constructor(
    value, name=None, strict=False, allow_downcast=None, borrow=False, target="cpu"
):
    """`SharedVariable` constructor for scalar values.

    Default: int64 or float64.

    Notes
    -----
    We implement this using 0-d tensors for now.

    We ignore the borrow parameter as we convert ``value`` to an
    ndarray (this is a new object). This respects the semantic of
    borrow, as it is a hint to PyTensor that we can reuse it.

    """
    if target != "cpu":
        raise TypeError("not for cpu")

    try:
        dtype = value.dtype
    except AttributeError:
        dtype = np.asarray(value).dtype

    dtype = str(dtype)
    value = _asarray(value, dtype=dtype)
    tensor_type = TensorType(dtype=str(value.dtype), shape=())

    # Do not pass the dtype to asarray because we want this to fail if
    # strict is True and the types do not match.
    rval = ScalarSharedVariable(
        type=tensor_type,
        value=np.array(value, copy=True),
        name=name,
        strict=strict,
        allow_downcast=allow_downcast,
    )
    return rval
```

This is an empty subclass returned by `shared(5)`. Internally, it behaves exactly like the shared 0-d tensor variable returned by `shared(np.array(5))`. In fact, `shared(shared(5).get_value())` returns a plain `TensorSharedVariable`, even though the original was a `ScalarSharedVariable`: `get_value()` returns a 0-d `ndarray`, which dispatches to the regular tensor constructor.
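The round trip can be reproduced with a small toy model. This is a sketch, not PyTensor code: `TensorSharedVariable` and `ScalarSharedVariable` below are hypothetical stand-ins that only store a value, and `shared_constructor` is modeled with `functools.singledispatch`, which supports the same `.register(type)` decorator pattern used above.

```python
from functools import singledispatch

import numpy as np


class TensorSharedVariable:
    """Hypothetical stand-in: only stores a copy of an array."""

    def __init__(self, value):
        self._value = np.array(value, copy=True)

    def get_value(self):
        # Hand back a copy of the stored array (0-d for scalars).
        return self._value.copy()


class ScalarSharedVariable(TensorSharedVariable):
    """Empty subclass: adds no state and no behaviour."""


@singledispatch
def shared_constructor(value):
    raise TypeError(f"No shared constructor for {type(value).__name__}")


@shared_constructor.register(np.ndarray)
def tensor_constructor(value):
    return TensorSharedVariable(value)


@shared_constructor.register(np.number)
@shared_constructor.register(float)
@shared_constructor.register(int)
@shared_constructor.register(complex)
def scalar_constructor(value):
    # Scalars are stored as 0-d arrays, exactly like tensor_constructor does.
    return ScalarSharedVariable(np.asarray(value))


x = shared_constructor(5)
y = shared_constructor(x.get_value())
print(type(x).__name__, type(y).__name__)  # ScalarSharedVariable TensorSharedVariable
print(x.get_value().shape == y.get_value().shape)  # True
```

Both objects hold identical 0-d data; only the class name differs, and it does not survive a `get_value()` round trip.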

There exists a true ScalarSharedVariable with scalar operators in

```python
class ScalarSharedVariable(_scalar_py_operators, SharedVariable):
    pass


# this is not installed in the default shared variable registry so that
# scalars are typically 0-d tensors.
# still, in case you need a shared variable scalar, you can get one
# by calling this function directly.
def shared(value, name=None, strict=False, allow_downcast=None):
    """
    SharedVariable constructor for scalar values. Default: int64 or float64.

    Notes
    -----
    We implement this using 0-d tensors for now.

    """
    if not isinstance(value, (np.number, float, int, complex)):
        raise TypeError()
    try:
        dtype = value.dtype
    except AttributeError:
        dtype = np.asarray(value).dtype

    dtype = str(dtype)
    value = getattr(np, dtype)(value)
    scalar_type = ScalarType(dtype=dtype)
    rval = ScalarSharedVariable(
        type=scalar_type,
        value=value,
        name=name,
        strict=strict,
        allow_downcast=allow_downcast,
    )
    return rval
```
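The two classes differ in what they store: the scalar-module constructor stores a NumPy scalar (`getattr(np, dtype)(value)`), while the tensor-module `scalar_constructor` stores a 0-d `ndarray`. A minimal NumPy-only sketch of the two representations (no PyTensor involved):

```python
import numpy as np

value = 5

# Same dtype inference as both constructors: a Python int has no .dtype attribute.
dtype = str(np.asarray(value).dtype)  # platform default integer, e.g. "int64"

# What the scalar-module ScalarSharedVariable stores: a NumPy scalar (np.generic).
scalar_storage = getattr(np, dtype)(value)

# What the tensor-module scalar_constructor stores: a 0-d ndarray.
tensor_storage = np.array(np.asarray(value, dtype=dtype), copy=True)

print(type(scalar_storage).__name__, isinstance(scalar_storage, np.generic))
print(type(tensor_storage).__name__, tensor_storage.ndim)  # ndarray 0
```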

We should probably not use that class when a user calls `shared(5)`, but I don't see any reason to keep `tensor.ScalarSharedVariable` either. The constructor should perform the same checks it does now, but return the same type as `shared(np.array(5))`.
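A minimal sketch of that proposal, again using hypothetical stand-in classes (a stored value and a name, none of PyTensor's type machinery) and `functools.singledispatch` to model the `.register(type)` pattern. It is only an illustration of the idea, not necessarily the change merged in #629:

```python
from functools import singledispatch

import numpy as np


class TensorSharedVariable:
    """Hypothetical stand-in: stores a copy of an array and an optional name."""

    def __init__(self, value, name=None):
        self.name = name
        self._value = np.array(value, copy=True)

    def get_value(self):
        return self._value.copy()


@singledispatch
def shared_constructor(value, name=None):
    raise TypeError(f"No shared constructor for {type(value).__name__}")


@shared_constructor.register(np.ndarray)
def tensor_constructor(value, name=None):
    return TensorSharedVariable(value, name=name)


@shared_constructor.register(np.number)
@shared_constructor.register(float)
@shared_constructor.register(int)
@shared_constructor.register(complex)
def scalar_constructor(value, name=None):
    # Keep the existing scalar-specific dtype inference...
    try:
        dtype = value.dtype
    except AttributeError:
        dtype = np.asarray(value).dtype
    # ...but return the plain 0-d tensor class, with no empty subclass.
    return TensorSharedVariable(np.asarray(value, dtype=str(dtype)), name=name)


a = shared_constructor(5, name="a")
b = shared_constructor(np.array(5), name="b")
print(type(a) is type(b), a.get_value().shape)  # True ()
```

With a single class, `shared(5)` and `shared(np.array(5))` are indistinguishable, so the `get_value()` round trip no longer changes the type.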

Keeping it is as if `type(pytensor.tensor.scalar())` returned a dummy subclass of `TensorVariable` confusingly named `ScalarVariable` (it does not).
