Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

get_scalar_constant_value now raises for non-scalar inputs #248

Merged
merged 7 commits into from
Mar 22, 2023
2 changes: 1 addition & 1 deletion doc/library/tensor/basic.rst
Original file line number Diff line number Diff line change
Expand Up @@ -577,7 +577,7 @@ them perfectly, but a `dscalar` otherwise.
.. method:: round(mode="half_away_from_zero")
:noindex:
.. method:: trace()
.. method:: get_scalar_constant_value()
.. method:: get_underlying_scalar_constant_value()
.. method:: zeros_like(model, dtype=None)

All the above methods are equivalent to NumPy for PyTensor on the current tensor.
Expand Down
6 changes: 3 additions & 3 deletions pytensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def _as_symbolic(x, **kwargs) -> Variable:
# isort: on


def get_scalar_constant_value(v):
def get_underlying_scalar_constant(v):
"""Return the constant scalar (i.e. 0-D) value underlying variable `v`.

If `v` is the output of dim-shuffles, fills, allocs, cast, etc.
Expand All @@ -153,8 +153,8 @@ def get_scalar_constant_value(v):
if sparse and isinstance(v.type, sparse.SparseTensorType):
if v.owner is not None and isinstance(v.owner.op, sparse.CSM):
data = v.owner.inputs[0]
return tensor.get_scalar_constant_value(data)
return tensor.get_scalar_constant_value(v)
return tensor.get_underlying_scalar_constant_value(data)
return tensor.get_underlying_scalar_constant_value(v)


# isort: off
Expand Down
4 changes: 2 additions & 2 deletions pytensor/gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -1325,7 +1325,7 @@ def try_to_copy_if_needed(var):
f" {i}. Since this input is only connected "
"to integer-valued outputs, it should "
"evaluate to zeros, but it evaluates to"
f"{pytensor.get_scalar_constant_value(term)}."
f"{pytensor.get_underlying_scalar_constant(term)}."
)
raise ValueError(msg)

Expand Down Expand Up @@ -2086,7 +2086,7 @@ def _is_zero(x):

no_constant_value = True
try:
constant_value = pytensor.get_scalar_constant_value(x)
constant_value = pytensor.get_underlying_scalar_constant(x)
no_constant_value = False
except pytensor.tensor.exceptions.NotScalarConstantError:
pass
Expand Down
6 changes: 3 additions & 3 deletions pytensor/link/jax/dispatch/tensor_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
ScalarFromTensor,
Split,
TensorFromScalar,
get_scalar_constant_value,
get_underlying_scalar_constant_value,
)
from pytensor.tensor.exceptions import NotScalarConstantError

Expand Down Expand Up @@ -106,7 +106,7 @@ def join(axis, *tensors):
def jax_funcify_Split(op: Split, node, **kwargs):
_, axis, splits = node.inputs
try:
constant_axis = get_scalar_constant_value(axis)
constant_axis = get_underlying_scalar_constant_value(axis)
except NotScalarConstantError:
constant_axis = None
warnings.warn(
Expand All @@ -116,7 +116,7 @@ def jax_funcify_Split(op: Split, node, **kwargs):
try:
constant_splits = np.array(
[
get_scalar_constant_value(splits[i])
get_underlying_scalar_constant_value(splits[i])
for i in range(get_vector_length(splits))
]
)
Expand Down
6 changes: 3 additions & 3 deletions pytensor/scan/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from pytensor.graph.utils import MissingInputError, TestValueError
from pytensor.scan.op import Scan, ScanInfo
from pytensor.scan.utils import expand_empty, safe_new, until
from pytensor.tensor.basic import get_scalar_constant_value
from pytensor.tensor.basic import get_underlying_scalar_constant_value
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import minimum
from pytensor.tensor.shape import shape_padleft, unbroadcast
Expand Down Expand Up @@ -147,7 +147,7 @@ def isNaN_or_Inf_or_None(x):
isStr = False
if not isNaN and not isInf:
try:
val = get_scalar_constant_value(x)
val = get_underlying_scalar_constant_value(x)
isInf = np.isinf(val)
isNaN = np.isnan(val)
except Exception:
Expand Down Expand Up @@ -476,7 +476,7 @@ def wrap_into_list(x):
n_fixed_steps = int(n_steps)
else:
try:
n_fixed_steps = at.get_scalar_constant_value(n_steps)
n_fixed_steps = at.get_underlying_scalar_constant_value(n_steps)
except NotScalarConstantError:
n_fixed_steps = None

Expand Down
10 changes: 7 additions & 3 deletions pytensor/scan/rewriting.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,11 @@
safe_new,
scan_can_remove_outs,
)
from pytensor.tensor.basic import Alloc, AllocEmpty, get_scalar_constant_value
from pytensor.tensor.basic import (
Alloc,
AllocEmpty,
get_underlying_scalar_constant_value,
)
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import Dot, dot, maximum, minimum
Expand Down Expand Up @@ -1956,13 +1960,13 @@ def belongs_to_set(self, node, set_nodes):

nsteps = node.inputs[0]
try:
nsteps = int(get_scalar_constant_value(nsteps))
nsteps = int(get_underlying_scalar_constant_value(nsteps))
except NotScalarConstantError:
pass

rep_nsteps = rep.inputs[0]
try:
rep_nsteps = int(get_scalar_constant_value(rep_nsteps))
rep_nsteps = int(get_underlying_scalar_constant_value(rep_nsteps))
except NotScalarConstantError:
pass

Expand Down
53 changes: 37 additions & 16 deletions pytensor/tensor/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,26 @@ def _obj_is_wrappable_as_tensor(x):


def get_scalar_constant_value(
    v, elemwise=True, only_process_constants=False, max_recur=10
):
    """Fetch the constant scalar value of `v`, requiring `v` to be 0-d.

    A strict front-end to `get_underlying_scalar_constant_value`: inputs
    that are symbolic variables or NumPy arrays must have ``ndim == 0``.
    Anything with a nonzero number of dimensions is rejected up front.

    Raises
    ------
    NotScalarConstantError
        If `v` is a `Variable` or `numpy.ndarray` with ``v.ndim != 0``
        (or if the delegated lookup fails to find a scalar constant).
    """
    # Reject non-scalar symbolic variables / arrays before delegating.
    if isinstance(v, (Variable, np.ndarray)) and v.ndim != 0:
        raise NotScalarConstantError()
    return get_underlying_scalar_constant_value(
        v, elemwise, only_process_constants, max_recur
    )


def get_underlying_scalar_constant_value(
orig_v, elemwise=True, only_process_constants=False, max_recur=10
):
"""Return the constant scalar(0-D) value underlying variable `v`.
Expand Down Expand Up @@ -358,7 +378,7 @@ def get_scalar_constant_value(
elif isinstance(v.owner.op, CheckAndRaise):
# check if all conditions are constant and true
conds = [
get_scalar_constant_value(c, max_recur=max_recur)
get_underlying_scalar_constant_value(c, max_recur=max_recur)
for c in v.owner.inputs[1:]
]
if builtins.all(0 == c.ndim and c != 0 for c in conds):
Expand All @@ -372,7 +392,7 @@ def get_scalar_constant_value(
continue
if isinstance(v.owner.op, _scalar_constant_value_elemwise_ops):
const = [
get_scalar_constant_value(i, max_recur=max_recur)
get_underlying_scalar_constant_value(i, max_recur=max_recur)
for i in v.owner.inputs
]
ret = [[None]]
Expand All @@ -391,7 +411,7 @@ def get_scalar_constant_value(
v.owner.op.scalar_op, _scalar_constant_value_elemwise_ops
):
const = [
get_scalar_constant_value(i, max_recur=max_recur)
get_underlying_scalar_constant_value(i, max_recur=max_recur)
for i in v.owner.inputs
]
ret = [[None]]
Expand Down Expand Up @@ -437,7 +457,7 @@ def get_scalar_constant_value(
):
idx = v.owner.op.idx_list[0]
if isinstance(idx, Type):
idx = get_scalar_constant_value(
idx = get_underlying_scalar_constant_value(
v.owner.inputs[1], max_recur=max_recur
)
try:
Expand Down Expand Up @@ -471,14 +491,14 @@ def get_scalar_constant_value(
):
idx = v.owner.op.idx_list[0]
if isinstance(idx, Type):
idx = get_scalar_constant_value(
idx = get_underlying_scalar_constant_value(
v.owner.inputs[1], max_recur=max_recur
)
# Python 2.4 does not support indexing with numpy.integer
# So we cast it.
idx = int(idx)
ret = v.owner.inputs[0].owner.inputs[idx]
ret = get_scalar_constant_value(ret, max_recur=max_recur)
ret = get_underlying_scalar_constant_value(ret, max_recur=max_recur)
# MakeVector can cast implicitly its input in some case.
return _asarray(ret, dtype=v.type.dtype)

Expand All @@ -493,7 +513,7 @@ def get_scalar_constant_value(
idx_list = op.idx_list
idx = idx_list[0]
if isinstance(idx, Type):
idx = get_scalar_constant_value(
idx = get_underlying_scalar_constant_value(
owner.inputs[1], max_recur=max_recur
)
grandparent = leftmost_parent.owner.inputs[0]
Expand All @@ -508,7 +528,7 @@ def get_scalar_constant_value(

if not (idx < ndim):
msg = (
"get_scalar_constant_value detected "
"get_underlying_scalar_constant_value detected "
f"deterministic IndexError: x.shape[{int(idx)}] "
f"when x.ndim={int(ndim)}."
)
Expand Down Expand Up @@ -1570,7 +1590,7 @@ def do_constant_folding(self, fgraph, node):
@_get_vector_length.register(Alloc)
def _get_vector_length_Alloc(var_inst, var):
try:
return get_scalar_constant_value(var.owner.inputs[1])
return get_underlying_scalar_constant_value(var.owner.inputs[1])
except NotScalarConstantError:
raise ValueError(f"Length of {var} cannot be determined")

Expand Down Expand Up @@ -1821,17 +1841,17 @@ def perform(self, node, inp, out_):

def extract_constant(x, elemwise=True, only_process_constants=False):
"""
This function is basically a call to tensor.get_scalar_constant_value.
This function is basically a call to tensor.get_underlying_scalar_constant_value.

The main difference is the behaviour in case of failure. While
get_scalar_constant_value raises an TypeError, this function returns x,
get_underlying_scalar_constant_value raises a TypeError, this function returns x,
as a tensor if possible. If x is a ScalarVariable from a
scalar_from_tensor, we remove the conversion. If x is just a
ScalarVariable, we convert it to a tensor with tensor_from_scalar.

"""
try:
x = get_scalar_constant_value(x, elemwise, only_process_constants)
x = get_underlying_scalar_constant_value(x, elemwise, only_process_constants)
except NotScalarConstantError:
pass
if isinstance(x, aes.ScalarVariable) or isinstance(
Expand Down Expand Up @@ -2201,7 +2221,7 @@ def make_node(self, axis, *tensors):

if not isinstance(axis, int):
try:
axis = int(get_scalar_constant_value(axis))
axis = int(get_underlying_scalar_constant_value(axis))
except NotScalarConstantError:
pass

Expand Down Expand Up @@ -2450,7 +2470,7 @@ def infer_shape(self, fgraph, node, ishapes):
def _get_vector_length_Join(op, var):
axis, *arrays = var.owner.inputs
try:
axis = get_scalar_constant_value(axis)
axis = get_underlying_scalar_constant_value(axis)
assert axis == 0 and builtins.all(a.ndim == 1 for a in arrays)
return builtins.sum(get_vector_length(a) for a in arrays)
except NotScalarConstantError:
Expand Down Expand Up @@ -2862,7 +2882,7 @@ def infer_shape(self, fgraph, node, i_shapes):

def is_constant_value(var, value):
try:
v = get_scalar_constant_value(var)
v = get_underlying_scalar_constant_value(var)
return np.all(v == value)
except NotScalarConstantError:
pass
Expand Down Expand Up @@ -3774,7 +3794,7 @@ def make_node(self, a, choices):
static_out_shape = ()
for s in out_shape:
try:
s_val = pytensor.get_scalar_constant_value(s)
s_val = pytensor.get_underlying_scalar_constant(s)
except (NotScalarConstantError, AttributeError):
s_val = None

Expand Down Expand Up @@ -4095,6 +4115,7 @@ def take_along_axis(arr, indices, axis=0):
"scalar_from_tensor",
"tensor_from_scalar",
"get_scalar_constant_value",
"get_underlying_scalar_constant_value",
"constant",
"as_tensor_variable",
"as_tensor",
Expand Down
2 changes: 1 addition & 1 deletion pytensor/tensor/blas.py
Original file line number Diff line number Diff line change
Expand Up @@ -1834,7 +1834,7 @@ def local_gemm_to_ger(fgraph, node):
xv = x.dimshuffle(0)
yv = y.dimshuffle(1)
try:
bval = at.get_scalar_constant_value(b)
bval = at.get_underlying_scalar_constant_value(b)
except NotScalarConstantError:
# b isn't a constant, GEMM is doing useful pre-scaling
return
Expand Down
19 changes: 13 additions & 6 deletions pytensor/tensor/conv/abstract_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@
from pytensor.graph.basic import Apply, Variable
from pytensor.graph.op import Op
from pytensor.raise_op import Assert
from pytensor.tensor.basic import as_tensor_variable, get_scalar_constant_value
from pytensor.tensor.basic import (
as_tensor_variable,
get_underlying_scalar_constant_value,
)
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.var import TensorConstant, TensorVariable

Expand Down Expand Up @@ -495,8 +498,8 @@ def check_dim(given, computed):
if given is None or computed is None:
return True
try:
given = get_scalar_constant_value(given)
computed = get_scalar_constant_value(computed)
given = get_underlying_scalar_constant_value(given)
computed = get_underlying_scalar_constant_value(computed)
return int(given) == int(computed)
except NotScalarConstantError:
# no answer possible, accept for now
Expand Down Expand Up @@ -532,7 +535,7 @@ def assert_conv_shape(shape):
out_shape = []
for i, n in enumerate(shape):
try:
const_n = get_scalar_constant_value(n)
const_n = get_underlying_scalar_constant_value(n)
if i < 2:
if const_n < 0:
raise ValueError(
Expand Down Expand Up @@ -2200,7 +2203,9 @@ def __init__(
if imshp_i is not None:
# Components of imshp should be constant or ints
try:
get_scalar_constant_value(imshp_i, only_process_constants=True)
get_underlying_scalar_constant_value(
imshp_i, only_process_constants=True
)
except NotScalarConstantError:
raise ValueError(
"imshp should be None or a tuple of constant int values"
Expand All @@ -2213,7 +2218,9 @@ def __init__(
if kshp_i is not None:
# Components of kshp should be constant or ints
try:
get_scalar_constant_value(kshp_i, only_process_constants=True)
get_underlying_scalar_constant_value(
kshp_i, only_process_constants=True
)
except NotScalarConstantError:
raise ValueError(
"kshp should be None or a tuple of constant int values"
Expand Down
2 changes: 1 addition & 1 deletion pytensor/tensor/elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -759,7 +759,7 @@ def perform(self, node, inputs, output_storage):
ufunc = self.ufunc
elif not hasattr(node.tag, "ufunc"):
# It happen that make_thunk isn't called, like in
# get_scalar_constant_value
# get_underlying_scalar_constant_value
self.prepare_node(node, None, None, "py")
# prepare_node will add ufunc to self or the tag
# depending if we can reuse it or not. So we need to
Expand Down
2 changes: 1 addition & 1 deletion pytensor/tensor/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ class ShapeError(Exception):

class NotScalarConstantError(Exception):
"""
Raised by get_scalar_constant_value if called on something that is
Raised by get_underlying_scalar_constant_value if called on something that is
not a scalar constant.
"""

Expand Down
2 changes: 1 addition & 1 deletion pytensor/tensor/extra_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,7 +671,7 @@ def make_node(self, x, repeats):
out_shape = [None]
else:
try:
const_reps = at.get_scalar_constant_value(repeats)
const_reps = at.get_underlying_scalar_constant_value(repeats)
except NotScalarConstantError:
const_reps = None
if const_reps == 1:
Expand Down