diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py
index 87b8e380d3..08a28af085 100644
--- a/pytensor/link/numba/dispatch/basic.py
+++ b/pytensor/link/numba/dispatch/basic.py
@@ -281,7 +281,7 @@ def numba_typify(data, dtype=None, **kwargs):
     return data
 
 
-def generate_fallback_impl(op, node=None, storage_map=None, **kwargs):
+def generate_fallback_impl(op, node, storage_map=None, **kwargs):
     """Create a Numba compatible function from a Pytensor `Op`."""
 
     warnings.warn(
diff --git a/pytensor/link/numba/dispatch/extra_ops.py b/pytensor/link/numba/dispatch/extra_ops.py
index f7700acf47..8f7734080c 100644
--- a/pytensor/link/numba/dispatch/extra_ops.py
+++ b/pytensor/link/numba/dispatch/extra_ops.py
@@ -6,7 +6,11 @@
 
 from pytensor.graph import Apply
 from pytensor.link.numba.dispatch import basic as numba_basic
-from pytensor.link.numba.dispatch.basic import get_numba_type, numba_funcify
+from pytensor.link.numba.dispatch.basic import (
+    generate_fallback_impl,
+    get_numba_type,
+    numba_funcify,
+)
 from pytensor.raise_op import CheckAndRaise
 from pytensor.tensor import TensorVariable
 from pytensor.tensor.extra_ops import (
@@ -200,45 +204,19 @@ def ravelmultiindex(*inp):
 @numba_funcify.register(Repeat)
 def numba_funcify_Repeat(op, node, **kwargs):
     axis = op.axis
+    a, _ = node.inputs
 
-    use_python = False
-
-    if axis is not None:
-        use_python = True
-
-    if use_python:
-        warnings.warn(
-            (
-                "Numba will use object mode to allow the "
-                "`axis` argument to `numpy.repeat`."
-            ),
-            UserWarning,
-        )
-
-        ret_sig = get_numba_type(node.outputs[0].type)
+    # Numba only supports axis=None, which in our case is when axis is 0 and the input is a vector
+    if axis == 0 and a.type.ndim == 1:
 
-        @numba_basic.numba_njit
+        @numba_basic.numba_njit(inline="always")
         def repeatop(x, repeats):
-            with numba.objmode(ret=ret_sig):
-                ret = np.repeat(x, repeats, axis)
-            return ret
+            return np.repeat(x, repeats)
 
-    else:
-        repeats_ndim = node.inputs[1].ndim
+        return repeatop
 
-        if repeats_ndim == 0:
-
-            @numba_basic.numba_njit(inline="always")
-            def repeatop(x, repeats):
-                return np.repeat(x, repeats.item())
-
-        else:
-
-            @numba_basic.numba_njit(inline="always")
-            def repeatop(x, repeats):
-                return np.repeat(x, repeats)
-
-    return repeatop
+    else:
+        return generate_fallback_impl(op, node)
 
 
 @numba_funcify.register(Unique)
diff --git a/pytensor/tensor/extra_ops.py b/pytensor/tensor/extra_ops.py
index a6eafcf485..80fded9697 100644
--- a/pytensor/tensor/extra_ops.py
+++ b/pytensor/tensor/extra_ops.py
@@ -668,53 +668,72 @@ class Repeat(Op):
 
     __props__ = ("axis",)
 
-    def __init__(self, axis: int | None = None):
-        if axis is not None:
-            if not isinstance(axis, int) or axis < 0:
+    def __init__(self, axis: int):
+        if isinstance(axis, int):
+            if axis < 0:
                 raise ValueError(
-                    f"Repeat only accepts positive integer axis or None, got {axis}"
+                    f"Repeat Op only accepts positive integer axis, got {axis}. "
+                    "Use the helper `pt.repeat` to handle negative axis."
                 )
+        elif axis is None:
+            raise ValueError(
+                "Repeat Op only accepts positive integer axis. "
+                "Use the helper `pt.repeat` to handle axis=None."
+            )
+        else:
+            raise TypeError(
+                f"Invalid type for axis {axis}, expected int got {type(axis)}"
+            )
+
         self.axis = axis
 
     def make_node(self, x, repeats):
         x = ptb.as_tensor_variable(x)
         repeats = ptb.as_tensor_variable(repeats, dtype="int64")
 
-        if repeats.dtype not in integer_dtypes:
-            raise TypeError("repeats.dtype must be an integer.")
+        if repeats.type.ndim != 1:
+            if repeats.type.ndim == 0:
+                raise ValueError(
+                    f"repeats {repeats} must have 1 dimension, got 0. Use the helper `pt.repeat` to handle scalar repeats."
+                )
+            else:
+                raise ValueError(
+                    f"repeats {repeats} must have 1 dimension, got {repeats.type.ndim}"
+                )
+
+        if repeats.type.dtype not in integer_dtypes:
+            raise TypeError(
+                f"repeats {repeats} dtype must be an integer, got {repeats.type.dtype}."
+            )
 
         # Some dtypes are not supported by numpy's implementation of repeat.
         # Until another one is available, we should fail at graph construction
         # time, not wait for execution.
-        ptr_bitwidth = LOCAL_BITWIDTH
-        if ptr_bitwidth == 64:
-            numpy_unsupported_dtypes = ("uint64",)
-        if ptr_bitwidth == 32:
-            numpy_unsupported_dtypes = ("uint32", "int64", "uint64")
-
-        if repeats.dtype in numpy_unsupported_dtypes:
+        numpy_unsupported_dtypes = (
+            ("uint64",) if LOCAL_BITWIDTH == 64 else ("uint64", "uint32", "int64")
+        )
+        if repeats.type.dtype in numpy_unsupported_dtypes:
             raise TypeError(
-                (
-                    f"dtypes {numpy_unsupported_dtypes!s} are not supported by numpy.repeat "
-                    "for the 'repeats' parameter, "
-                ),
-                repeats.dtype,
+                f"repeats {repeats} dtype {repeats.type.dtype} are not supported by numpy.repeat"
             )
 
-        if self.axis is None:
-            out_shape = [None]
-        else:
+        shape = list(x.type.shape)
+        axis_input_dim_length = shape[self.axis]
+        axis_output_dim_length = None
+
+        if axis_input_dim_length is not None:
+            # If we have a static dim and constant repeats we can infer the length of the output dim
+            # Right now we only support homogenous constant repeats
             try:
-                const_reps = ptb.get_scalar_constant_value(repeats)
+                const_reps = ptb.get_underlying_scalar_constant_value(repeats)
             except NotScalarConstantError:
-                const_reps = None
-            if const_reps == 1:
-                out_shape = x.type.shape
+                pass
             else:
-                out_shape = list(x.type.shape)
-                out_shape[self.axis] = None
+                axis_output_dim_length = int(const_reps * axis_input_dim_length)
+
+        shape[self.axis] = axis_output_dim_length
 
-        out_type = TensorType(x.dtype, shape=out_shape)
+        out_type = TensorType(x.dtype, shape=shape)
         return Apply(self, [x, repeats], [out_type()])
 
     def perform(self, node, inputs, output_storage):
@@ -728,36 +747,16 @@ def grad(self, inputs, gout):
         (x, repeats) = inputs
         (gz,) = gout
         axis = self.axis
 
-        if repeats.ndim == 0:
-            # When axis is a scalar (same number of reps for all elements),
-            # We can split the repetitions into their own axis with reshape and sum them back
-            # to the original element location
-            sum_axis = x.ndim if axis is None else axis + 1
-            shape = list(x.shape)
-            shape.insert(sum_axis, repeats)
-            gx = gz.reshape(shape).sum(axis=sum_axis)
-
-        elif repeats.ndim == 1:
-            # To sum the gradients that belong to the same repeated x,
-            # We create a repeated eye and dot product it with the gradient.
-            axis_size = x.size if axis is None else x.shape[axis]
-            repeated_eye = repeat(
-                ptb.eye(axis_size), repeats, axis=0
-            )  # A sparse repeat would be neat
-
-            if axis is None:
-                gx = gz @ repeated_eye
-                # Undo the ravelling when axis=None
-                gx = gx.reshape(x.shape)
-            else:
-                # Place gradient axis at end for dot product
-                gx = ptb.moveaxis(gz, axis, -1)
-                gx = gx @ repeated_eye
-                # Place gradient back into the correct axis
-                gx = ptb.moveaxis(gx, -1, axis)
-        else:
-            raise ValueError()
+        # Use IncSubtensor to sum the gradients that belong to the repeated entries of x
+        axis_size = x.shape[axis]
+        repeated_arange = repeat(ptb.arange(axis_size), repeats, axis=0)
+
+        # Move the axis to repeat to front for easier indexing
+        x_transpose = ptb.moveaxis(x, axis, 0)
+        gz_transpose = ptb.moveaxis(gz, axis, 0)
+        gx_transpose = ptb.zeros_like(x_transpose)[repeated_arange].inc(gz_transpose)
+        gx = ptb.moveaxis(gx_transpose, 0, axis)
 
         return [gx, DisconnectedType()()]
 
@@ -771,22 +770,8 @@ def infer_shape(self, fgraph, node, ins_shapes):
         dtype = None
         if repeats.dtype in ("uint8", "uint16", "uint32"):
             dtype = "int64"
-        if axis is None:
-            if repeats.ndim == 0:
-                if len(i0_shapes) == 0:
-                    out_shape = [repeats]
-                else:
-                    res = 1
-                    for d in i0_shapes:
-                        res = res * d
-                    out_shape = (res * repeats,)
-            else:
-                out_shape = [pt_sum(repeats, dtype=dtype)]
-        else:
-            if repeats.ndim == 0:
-                out_shape[axis] = out_shape[axis] * repeats
-            else:
-                out_shape[axis] = pt_sum(repeats, dtype=dtype)
+
+        out_shape[axis] = pt_sum(repeats, dtype=dtype)
 
         return [out_shape]
 
@@ -851,7 +836,10 @@ def repeat(
     """
    a = ptb.as_tensor_variable(a)
 
-    if axis is not None:
+    if axis is None:
+        axis = 0
+        a = a.flatten()
+    else:
         axis = normalize_axis_index(axis, a.ndim)
 
     repeats = ptb.as_tensor_variable(repeats, dtype=np.int64)
@@ -859,40 +847,31 @@ def repeat(
     if repeats.ndim > 1:
         raise ValueError("The dimension of repeats should not exceed 1.")
 
-    if repeats.ndim == 1 and not repeats.broadcastable[0]:
+    if repeats.type.broadcastable == (True,):
+        # This behaves the same as scalar repeat
+        repeats = repeats.squeeze()
+
+    if repeats.ndim == 1:
         # We only use the Repeat Op for vector repeats
         return Repeat(axis=axis)(a, repeats)
     else:
-        if repeats.ndim == 1:
-            repeats = repeats[0]
-
         if a.dtype == "uint64":
             # Multiplying int64 (shape) by uint64 (repeats) yields a float64
             # Which is not valid for the `reshape` operation at the end
             raise TypeError("repeat doesn't support dtype uint64")
 
-        if axis is None:
-            axis = 0
-            a = a.flatten()
-
-        repeat_shape = list(a.shape)
+        # Scalar repeat, we implement this with canonical Ops broadcast + reshape
+        a_shape = a.shape
 
-        # alloc_shape is the shape of the intermediate tensor which has
-        # an additional dimension comparing to x. We use alloc to
-        # allocate space for this intermediate tensor to replicate x
-        # along that additional dimension.
-        alloc_shape = repeat_shape[:]
-        alloc_shape.insert(axis + 1, repeats)
+        # Replicate a along a new axis (axis+1) repeats times
+        broadcast_shape = list(a_shape)
+        broadcast_shape.insert(axis + 1, repeats)
+        broadcast_a = broadcast_to(ptb.expand_dims(a, axis + 1), broadcast_shape)
 
-        # repeat_shape is now the shape of output, where shape[axis] becomes
-        # shape[axis]*repeats.
+        # Reshape broadcast_a to the final shape, merging axis and axis+1
+        repeat_shape = list(a_shape)
         repeat_shape[axis] = repeat_shape[axis] * repeats
-
-        # After the original tensor is duplicated along the additional
-        # dimension, we reshape it to the expected output shape
-        return ptb.alloc(ptb.expand_dims(a, axis + 1), *alloc_shape).reshape(
-            repeat_shape
-        )
+        return broadcast_a.reshape(repeat_shape)
 
 
 class Bartlett(Op):
diff --git a/pytensor/utils.py b/pytensor/utils.py
index c81fb74f56..d878cafac8 100644
--- a/pytensor/utils.py
+++ b/pytensor/utils.py
@@ -34,7 +34,6 @@
 Note that according to Python documentation, `platform.architecture()` is not
 reliable on OS X with universal binaries.
-Also, sys.maxsize does not exist in Python < 2.6.
 'P' denotes a void*, and the size is expressed in bytes.
 
 """
diff --git a/tests/link/numba/test_extra_ops.py b/tests/link/numba/test_extra_ops.py
index e9b6700c63..0c7813e791 100644
--- a/tests/link/numba/test_extra_ops.py
+++ b/tests/link/numba/test_extra_ops.py
@@ -212,27 +212,15 @@ def test_RavelMultiIndex(arr, shape, mode, order, exc):
 
 @pytest.mark.parametrize(
     "x, repeats, axis, exc",
     [
-        (
-            (pt.lscalar(), np.array(1, dtype="int64")),
-            (pt.lscalar(), np.array(0, dtype="int64")),
-            None,
-            None,
-        ),
-        (
-            (pt.lmatrix(), np.zeros((2, 2), dtype="int64")),
-            (pt.lscalar(), np.array(1, dtype="int64")),
-            None,
-            None,
-        ),
         (
             (pt.lvector(), np.arange(2, dtype="int64")),
-            (pt.lvector(), np.array([1, 1], dtype="int64")),
-            None,
+            (pt.lvector(), np.array([1, 3], dtype="int64")),
+            0,
             None,
         ),
         (
            (pt.lmatrix(), np.zeros((2, 2), dtype="int64")),
-            (pt.lscalar(), np.array(1, dtype="int64")),
+            (pt.lvector(), np.array([1, 3], dtype="int64")),
             0,
             UserWarning,
         ),
diff --git a/tests/tensor/test_extra_ops.py b/tests/tensor/test_extra_ops.py
index 352238adec..705afc3c54 100644
--- a/tests/tensor/test_extra_ops.py
+++ b/tests/tensor/test_extra_ops.py
@@ -530,7 +530,7 @@ def _possible_axis(self, ndim):
     def setup_method(self):
         super().setup_method()
         self.op_class = Repeat
-        self.op = Repeat()
+        self.op = Repeat(axis=0)
         # uint64 always fails
         # int64 and uint32 also fail if python int are 32-bit
         if LOCAL_BITWIDTH == 64:
@@ -595,43 +595,30 @@ def test_basic(self, ndim, dtype):
     def test_infer_shape(self, ndim, dtype):
         rng = np.random.default_rng(4282)
-        x = TensorType(config.floatX, shape=(None,) * ndim)()
+        a_var = TensorType(config.floatX, shape=(None,) * ndim)("a")
+        r_var = vector("r", dtype=dtype)
+
         shp = (np.arange(ndim) + 1) * 3
         a = rng.random(shp).astype(config.floatX)
 
         for axis in self._possible_axis(ndim):
-            if axis is not None and axis < 0:
-                # Operator does not support negative axis
-                continue
-
-            r_var = scalar(dtype=dtype)
-            r = np.asarray(3, dtype=dtype)
             if dtype in self.numpy_unsupported_dtypes:
-                r_var = vector(dtype=dtype)
                 with pytest.raises(TypeError):
-                    repeat(x, r_var)
-            else:
-                self._compile_and_check(
-                    [x, r_var],
-                    [Repeat(axis=axis)(x, r_var)],
-                    [a, r],
-                    self.op_class,
-                )
+                    repeat(a_var, r_var, axis=axis)
+                continue
 
-            r_var = vector(dtype=dtype)
-            if axis is None:
-                r = rng.integers(1, 6, size=a.size).astype(dtype)
-            elif a.size > 0:
-                r = rng.integers(1, 6, size=a.shape[axis]).astype(dtype)
-            else:
-                r = rng.integers(1, 6, size=(10,)).astype(dtype)
+            if axis is None or axis < 0:
+                # Operator Repeat does not support None or negative axis
+                continue
 
-            self._compile_and_check(
-                [x, r_var],
-                [Repeat(axis=axis)(x, r_var)],
-                [a, r],
-                self.op_class,
-            )
+            r = rng.integers(1, 6, size=a.shape[axis]).astype(dtype)
+
+            self._compile_and_check(
+                [a_var, r_var],
+                [Repeat(axis=axis)(a_var, r_var)],
+                [a, r],
+                self.op_class,
+            )
 
     @pytest.mark.parametrize("x_ndim", [2, 3], ids=lambda x: f"x_ndim={x}")
     @pytest.mark.parametrize("repeats_ndim", [0, 1], ids=lambda r: f"repeats_ndim={r}")
@@ -647,18 +634,38 @@ def test_grad(self, x_ndim, repeats_ndim, axis):
         repeats_size = (x_test.shape[axis] if axis is not None else x_test.size,)
         repeats = rng.integers(1, 6, size=repeats_size)
         utt.verify_grad(
-            lambda x: Repeat(axis=axis)(x, repeats),
+            lambda x: repeat(x, repeats, axis=axis),
             [x_test],
         )
 
-    def test_broadcastable(self):
-        x = TensorType(config.floatX, shape=(None, 1, None))()
-        r = Repeat(axis=1)(x, 2)
-        assert r.broadcastable == (False, False, False)
-        r = Repeat(axis=1)(x, 1)
-        assert r.broadcastable == (False, True, False)
-        r = Repeat(axis=0)(x, 2)
-        assert r.broadcastable == (False, True, False)
+    def test_static_shape(self):
+        x = TensorType(config.floatX, shape=(None, 1, 3))()
+        symbolic_r = scalar(dtype="int32")
+
+        r = repeat(x, 2, axis=0)
+        assert r.type.shape == (None, 1, 3)
+
+        r = repeat(x, 2, axis=1)
+        assert r.type.shape == (None, 2, 3)
+
+        r = repeat(x, [2], axis=1)
+        assert r.type.shape == (None, 2, 3)
+
+        r = repeat(x, symbolic_r, axis=1)
+        assert r.type.shape == (None, None, 3)
+
+        r = repeat(x, 1, axis=1)
+        assert r.type.shape == (None, 1, 3)
+
+        r = repeat(x, 2, axis=2)
+        assert r.type.shape == (None, 1, 6)
+
+        r = repeat(x, [2, 2, 2], axis=2)
+        assert r.type.shape == (None, 1, 6)
+
+        # This case could be implemented in the future
+        r = repeat(x, [1, 2, 4], axis=2)
+        assert r.type.shape == (None, 1, None)
 
 
 class TestBartlett(utt.InferShapeTester):
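
As a quick sanity check of the scalar-repeat path in the patch, here is a minimal NumPy-only sketch (not part of the diff; the array, axis, and repeat count below are illustrative) of the broadcast + reshape trick that the rewritten `pt.repeat` helper now uses instead of `alloc`:

    import numpy as np

    a = np.arange(6).reshape(2, 3)
    axis, repeats = 1, 2

    # Insert a new axis right after `axis` and replicate `a` along it `repeats` times
    shape = list(a.shape)
    shape.insert(axis + 1, repeats)
    replicated = np.broadcast_to(np.expand_dims(a, axis + 1), shape)

    # Merge `axis` and the new axis to obtain the repeated output
    out = replicated.reshape(a.shape[0], a.shape[axis] * repeats)
    assert np.array_equal(out, np.repeat(a, repeats, axis=axis))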
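
Similarly, a small NumPy analogue (again not part of the diff; `np.add.at` stands in for PyTensor's `inc_subtensor`, and the repeat counts are made up) of how the new `Repeat.grad` sums gradient contributions back onto the repeated entries via a repeated arange index:

    import numpy as np

    axis_size = 3
    reps = np.array([1, 2, 4])
    repeated_arange = np.repeat(np.arange(axis_size), reps)  # [0, 1, 1, 2, 2, 2, 2]

    gz = np.ones(reps.sum())            # incoming gradient, one value per repeated entry
    gx = np.zeros(axis_size)
    np.add.at(gx, repeated_arange, gz)  # accumulate duplicates back onto their source entry
    assert np.array_equal(gx, reps.astype(float))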