Skip to content

Commit

Permalink
Prevent observation masks improperly expanding samples in plates (#3317)
Browse files Browse the repository at this point in the history
  • Loading branch information
austinv11 committed Feb 7, 2024
1 parent a52338c commit 6337ced
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 1 deletion.
2 changes: 1 addition & 1 deletion pyro/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def _masked_observe(
f"broadcastable to batch_shape = {tuple(batch_shape)}"
) from e
raise
return deterministic(name, value)
return deterministic(name, value, event_dim=fn.event_dim)


def sample(
Expand Down
30 changes: 30 additions & 0 deletions tests/test_primitives.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from typing import Optional

import pytest
import torch

import pyro
import pyro.distributions as dist
from pyro import poutine

pytestmark = pytest.mark.stage("unit")

Expand All @@ -31,3 +34,30 @@ def test_deterministic_ok():
x = pyro.deterministic("x", torch.tensor(0.0))
assert isinstance(x, torch.Tensor)
assert x.shape == ()


@pytest.mark.parametrize(
"mask",
[
None,
torch.tensor(True),
torch.tensor([True]),
torch.tensor([True, False, True]),
],
)
def test_obs_mask_shape(mask: Optional[torch.Tensor]):
data = torch.randn(3, 2)

def model():
with pyro.plate("data", 3):
pyro.sample(
"y",
dist.MultivariateNormal(torch.zeros(2), scale_tril=torch.eye(2)),
obs=data,
obs_mask=mask,
)

trace = poutine.trace(model).get_trace()
y_dist = trace.nodes["y"]["fn"]
assert y_dist.batch_shape == (3,)
assert y_dist.event_shape == (2,)

0 comments on commit 6337ced

Please sign in to comment.