diff --git a/pytensor/tensor/random/op.py b/pytensor/tensor/random/op.py
index 005e1d55fa..3d29c0d81c 100644
--- a/pytensor/tensor/random/op.py
+++ b/pytensor/tensor/random/op.py
@@ -20,6 +20,7 @@
 )
 from pytensor.tensor.random.type import RandomGeneratorType, RandomStateType, RandomType
 from pytensor.tensor.random.utils import (
+    compute_batch_shape,
     explicit_expand_dims,
     normalize_size_param,
 )
@@ -403,15 +404,14 @@ def vectorize_random_variable(
         original_expanded_dist_params, dict(zip(original_dist_params, dist_params))
     )
 
-    if len_old_size and equal_computations([old_size], [size]):
+    new_ndim = dist_params[0].type.ndim - original_expanded_dist_params[0].type.ndim
+
+    if new_ndim and len_old_size and equal_computations([old_size], [size]):
         # If the original RV had a size variable and a new one has not been provided,
         # we need to define a new size as the concatenation of the original size dimensions
         # and the novel ones implied by new broadcasted batched parameters dimensions.
-        # We use the first broadcasted batch dimension for reference.
-        bcasted_param = explicit_expand_dims(dist_params, op.ndims_params)[0]
-        new_param_ndim = (bcasted_param.type.ndim - op.ndims_params[0]) - len_old_size
-        if new_param_ndim >= 0:
-            new_size_dims = bcasted_param.shape[:new_param_ndim]
-            size = concatenate([new_size_dims, size])
+        broadcasted_batch_shape = compute_batch_shape(dist_params, op.ndims_params)
+        new_size_dims = broadcasted_batch_shape[:new_ndim]
+        size = concatenate([new_size_dims, size])
 
     return op.make_node(rng, size, dtype, *dist_params)
diff --git a/pytensor/tensor/random/utils.py b/pytensor/tensor/random/utils.py
index 5d74a16e20..8f9f77d536 100644
--- a/pytensor/tensor/random/utils.py
+++ b/pytensor/tensor/random/utils.py
@@ -11,7 +11,7 @@
 from pytensor.scalar import ScalarVariable
 from pytensor.tensor import get_vector_length
 from pytensor.tensor.basic import as_tensor_variable, cast, constant
-from pytensor.tensor.extra_ops import broadcast_to
+from pytensor.tensor.extra_ops import broadcast_arrays, broadcast_to
 from pytensor.tensor.math import maximum
 from pytensor.tensor.shape import shape_padleft, specify_shape
 from pytensor.tensor.type import int_dtypes
@@ -149,6 +149,14 @@ def explicit_expand_dims(
     return new_params
 
 
+def compute_batch_shape(params, ndims_params: Sequence[int]) -> TensorVariable:
+    params = explicit_expand_dims(params, ndims_params)
+    batch_params = [
+        param[..., *(0,) * core_ndim] for param, core_ndim in zip(params, ndims_params)
+    ]
+    return broadcast_arrays(*batch_params)[0].shape
+
+
 def normalize_size_param(
     size: int | np.ndarray | Variable | Sequence | None,
 ) -> Variable:
diff --git a/tests/tensor/random/test_op.py b/tests/tensor/random/test_op.py
index f801ab731e..80ee011015 100644
--- a/tests/tensor/random/test_op.py
+++ b/tests/tensor/random/test_op.py
@@ -292,6 +292,14 @@ def test_vectorize_node():
     assert vect_node.op is normal
     assert vect_node.default_output().type.shape == (10, 5)
 
+    node = normal(vec, size=(5,)).owner
+    new_inputs = node.inputs.copy()
+    new_inputs[3] = tensor("mu", shape=(1, 5))  # mu
+    new_inputs[4] = tensor("sigma", shape=(10,))  # sigma
+    vect_node = vectorize_node(node, *new_inputs)
+    assert vect_node.op is normal
+    assert vect_node.default_output().type.shape == (10, 5)
+
     # Test parameter broadcasting with expanding size
     node = normal(vec, size=(2, 5)).owner
     new_inputs = node.inputs.copy()
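
A minimal standalone sketch (not part of the patch) of the behaviour the new test asserts: batched parameters with broadcastable leading dims get their broadcasted batch shape prepended to the original size. The import paths for vectorize_node and normal are assumptions based on the test module; the rest mirrors the added test case.

    import pytensor.tensor as pt
    from pytensor.graph.replace import vectorize_node
    from pytensor.tensor.random.basic import normal

    # Original RV: scalar-core normal with an explicit size of (5,).
    vec = pt.tensor("vec", shape=(5,))
    node = normal(vec, size=(5,)).owner

    # Swap in batched params: mu has shape (1, 5), sigma has shape (10,).
    # Their broadcasted batch shape is (10, 5); only the one new leading
    # dimension (new_ndim == 1) is prepended to the original size (5,).
    new_inputs = node.inputs.copy()
    new_inputs[3] = pt.tensor("mu", shape=(1, 5))    # mu
    new_inputs[4] = pt.tensor("sigma", shape=(10,))  # sigma

    vect_node = vectorize_node(node, *new_inputs)
    assert vect_node.default_output().type.shape == (10, 5)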