Fix broadcasting bug in vectorization of RandomVariables #738

Merged · 1 commit · May 9, 2024
pytensor/tensor/random/op.py (7 additions, 7 deletions)
@@ -20,6 +20,7 @@
 )
 from pytensor.tensor.random.type import RandomGeneratorType, RandomStateType, RandomType
 from pytensor.tensor.random.utils import (
+    compute_batch_shape,
     explicit_expand_dims,
     normalize_size_param,
 )
@@ -403,15 +404,14 @@ def vectorize_random_variable(
         original_expanded_dist_params, dict(zip(original_dist_params, dist_params))
     )

-    if len_old_size and equal_computations([old_size], [size]):
+    new_ndim = dist_params[0].type.ndim - original_expanded_dist_params[0].type.ndim
+
+    if new_ndim and len_old_size and equal_computations([old_size], [size]):
         # If the original RV had a size variable and a new one has not been provided,
         # we need to define a new size as the concatenation of the original size dimensions
         # and the novel ones implied by new broadcasted batched parameters dimensions.
-        # We use the first broadcasted batch dimension for reference.
-        bcasted_param = explicit_expand_dims(dist_params, op.ndims_params)[0]
-        new_param_ndim = (bcasted_param.type.ndim - op.ndims_params[0]) - len_old_size
-        if new_param_ndim >= 0:
-            new_size_dims = bcasted_param.shape[:new_param_ndim]
-            size = concatenate([new_size_dims, size])
+        broadcasted_batch_shape = compute_batch_shape(dist_params, op.ndims_params)
+        new_size_dims = broadcasted_batch_shape[:new_ndim]
+        size = concatenate([new_size_dims, size])

     return op.make_node(rng, size, dtype, *dist_params)
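
The removed branch sliced the shape of the first broadcasted parameter only, so any batch dimension carried solely by a later parameter was under-reported. A minimal NumPy sketch (not PyTensor; shapes chosen to mirror the new test case below) of the difference:

import numpy as np

mu = np.zeros((1, 5))     # first parameter: leading batch dim is a broadcastable 1
sigma = np.ones((10, 1))  # second parameter carries the real batch dim of 10

old_guess = mu.shape[:1]                                 # (1,): slice of the first param only
true_batch = np.broadcast_shapes(mu.shape, sigma.shape)  # (10, 5): broadcast over all params
print(old_guess, true_batch[:1])  # (1,) vs (10,): the old size would start with 1, not 10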
pytensor/tensor/random/utils.py (10 additions, 1 deletion)
@@ -11,7 +11,7 @@
 from pytensor.scalar import ScalarVariable
 from pytensor.tensor import get_vector_length
 from pytensor.tensor.basic import as_tensor_variable, cast, constant
-from pytensor.tensor.extra_ops import broadcast_to
+from pytensor.tensor.extra_ops import broadcast_arrays, broadcast_to
 from pytensor.tensor.math import maximum
 from pytensor.tensor.shape import shape_padleft, specify_shape
 from pytensor.tensor.type import int_dtypes
@@ -149,6 +149,15 @@ def explicit_expand_dims(
     return new_params


+def compute_batch_shape(params, ndims_params: Sequence[int]) -> TensorVariable:
+    params = explicit_expand_dims(params, ndims_params)
+    batch_params = [
+        param[(..., *(0,) * core_ndim)]
+        for param, core_ndim in zip(params, ndims_params)
+    ]
+    return broadcast_arrays(*batch_params)[0].shape
+
+
 def normalize_size_param(
     size: int | np.ndarray | Variable | Sequence | None,
 ) -> Variable:
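
The new helper indexes away each parameter's core dimensions (taking index 0 along each core axis) so that only batch dimensions remain, then broadcasts the batch-only arrays together and returns the shape of the result. A NumPy sketch of the same trick, with hypothetical multivariate-normal-style parameters (mean has core ndim 1, cov has core ndim 2); NumPy's implicit left-padding stands in for the explicit_expand_dims call:

import numpy as np

def batch_shape_sketch(params, ndims_params):
    batch_params = [
        p[(..., *(0,) * core_ndim)]  # index 0 along each core axis; batch dims survive
        for p, core_ndim in zip(params, ndims_params)
    ]
    return np.broadcast_arrays(*batch_params)[0].shape

mean = np.zeros((7, 1, 3))  # batch shape (7, 1), core shape (3,)
cov = np.eye(3)[None]       # batch shape (1,),   core shape (3, 3)
print(batch_shape_sketch([mean, cov], [1, 2]))  # (7, 1): broadcast of (7, 1) and (1,)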
tests/tensor/random/test_op.py (8 additions, 0 deletions)
@@ -292,6 +292,14 @@ def test_vectorize_node():
     assert vect_node.op is normal
     assert vect_node.default_output().type.shape == (10, 5)

+    node = normal(vec, size=(5,)).owner
+    new_inputs = node.inputs.copy()
+    new_inputs[3] = tensor("mu", shape=(1, 5))  # mu
+    new_inputs[4] = tensor("sigma", shape=(10,))  # sigma
+    vect_node = vectorize_node(node, *new_inputs)
+    assert vect_node.op is normal
+    assert vect_node.default_output().type.shape == (10, 5)
+
     # Test parameter broadcasting with expanding size
     node = normal(vec, size=(2, 5)).owner
     new_inputs = node.inputs.copy()
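
For reference, the same scenario can be reproduced through the public vectorize_graph API rather than calling vectorize_node on raw inputs. This is a hedged sketch: the assumption is that vectorize_graph routes RandomVariable nodes through vectorize_random_variable, and the shapes mirror the new test case:

import pytensor.tensor as pt
from pytensor.graph.replace import vectorize_graph
from pytensor.tensor.random.basic import normal

mu = pt.vector("mu", shape=(5,))
sigma = pt.scalar("sigma")
rv = normal(mu, sigma, size=(5,))

# Batch mu with a broadcastable leading dim; the real batch dim comes from sigma
batched = vectorize_graph(
    rv,
    {mu: pt.tensor("new_mu", shape=(1, 5)), sigma: pt.vector("new_sigma", shape=(10,))},
)
print(batched.type.shape)  # expected (10, 5) with this fix; the old code sized it from mu alone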