Prevent observation masks improperly expanding samples in plates #3317
Conversation
Hi @austinv11, thanks for this fix! This should be able to merge after we fix our CI in #3318. It would be great to have a regression test for this, maybe editing or forking test_obs_mask_multivariate. (Actually I'm surprised and disappointed that test didn't catch this bug 🤔) Would you be up to either adding a regression test in this PR or giving us some hints about what a test might look like, based on the models where you found this error?
Hey @fritzo, I tried modifying some tests to see if I could trigger the error, but I am having difficulty replicating it outside of my model. Perhaps I don't understand the dimension broadcasting mechanisms in Pyro well enough. But here are snippets from my SVI model that always create issues when used with `Predictive`:

```python
def model(...):
    with pyro.poutine.scale(scale=annealing_factor):
        with pyro.plate(
            "cell_ligand_plate", total_cells, dim=-2, subsample=batch_idx
        ):
            with pyro.plate("ligand_plate", n_ligands, dim=-1):
                lavail = pyro.sample(
                    "ligand_availability",
                    dist.ContinuousBernoulli(
                        logits=ligand_available_logits[
                            data.samples.argmax(1)
                        ].unsqueeze(-1)
                    ).to_event(1),
                    obs=(data.ligand_X > 0).unsqueeze(-1).float(),
                    obs_mask=(data.ligand_X > 0),
                )

def guide(...):
    with pyro.poutine.scale(scale=annealing_factor):
        with pyro.plate(
            "cell_ligand_plate", total_cells, dim=-2, subsample=batch_idx
        ):
            with pyro.plate("ligand_plate", n_ligands, dim=-1):
                with pyro.poutine.mask(mask=data.ligand_X <= 0):
                    # Predict ligand availability from the current cell's profile
                    pyro.sample(
                        "ligand_availability_unobserved",
                        dist.ContinuousBernoulli(
                            logits=self._predict_ligand_activation_from_cell(
                                data.n_genes, data.n_ligands, data.dense_X.log1p()
                            ).unsqueeze(-1)
                        ).to_event(1),
                    )

...
predictive = pyro.infer.Predictive(
    model,
    guide=guide,
    num_samples=replicates,
    parallel=False,
)
# Calls to predictive(...) will now have dimensionality issues,
# but training with SVI does not.
```
Hi @austinv11, thanks for providing the example. Here's a regression test we could use, on the 3317-regression branch:

```diff
diff --git a/tests/test_primitives.py b/tests/test_primitives.py
index 663e3a67..c208a2cf 100644
--- a/tests/test_primitives.py
+++ b/tests/test_primitives.py
@@ -1,11 +1,14 @@
 # Copyright Contributors to the Pyro project.
 # SPDX-License-Identifier: Apache-2.0
 
+from typing import Optional
+
 import pytest
 import torch
 
 import pyro
 import pyro.distributions as dist
+from pyro import poutine
 
 pytestmark = pytest.mark.stage("unit")
@@ -31,3 +34,30 @@ def test_deterministic_ok():
     x = pyro.deterministic("x", torch.tensor(0.0))
     assert isinstance(x, torch.Tensor)
     assert x.shape == ()
+
+
+@pytest.mark.parametrize(
+    "mask",
+    [
+        None,
+        torch.tensor(True),
+        torch.tensor([True]),
+        torch.tensor([True, False, True]),
+    ],
+)
+def test_obs_mask_shape(mask: Optional[torch.Tensor]):
+    data = torch.randn(3, 2)
+
+    def model():
+        with pyro.plate("data", 3):
+            pyro.sample(
+                "y",
+                dist.MultivariateNormal(torch.zeros(2), scale_tril=torch.eye(2)),
+                obs=data,
+                obs_mask=mask,
+            )
+
+    trace = poutine.trace(model).get_trace()
+    y_dist = trace.nodes["y"]["fn"]
+    assert y_dist.batch_shape == (3,)
+    assert y_dist.event_shape == (2,)
```

Could you merge in recent changes to dev (so CI passes), and add this test to your branch? I'd like to get your fix into our upcoming 1.9 release. Thanks again!
LGTM, but let's add a regression test.
Hey @fritzo, I appreciate your effort in developing the regression test! My fix passes that test locally. Just updated the PR.
@austinv11 BTW, how did you do that cross-repo cherry-pick or merge? Did you do it in the GitHub GUI or on the git command line?
@fritzo It was a little bit of a pain, but I was able to do it using the Git interface in PyCharm: I added the original repo as a remote, fetched its changes, and then PyCharm let me choose that commit and cherry-pick it.
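For reference, a rough command-line equivalent of that workflow might look like the following. This is a sketch, not what was actually run: the two local repo paths stand in for the original repo and the fork, and the commit messages are placeholders.

```shell
set -e
tmp=$(mktemp -d)

# "Original" repo containing the commit we want to pick up.
git init -q "$tmp/original"
cd "$tmp/original"
git -c user.email=a@example.com -c user.name=a commit -q --allow-empty -m "base"
echo fix > fix.txt
git add fix.txt
git -c user.email=a@example.com -c user.name=a commit -q -m "the fix"
fix_sha=$(git rev-parse HEAD)

# The "fork": add the original repo as a remote, fetch its history,
# then cherry-pick the single commit onto the current branch.
git init -q "$tmp/fork"
cd "$tmp/fork"
git -c user.email=a@example.com -c user.name=a commit -q --allow-empty -m "start"
git remote add upstream "$tmp/original"
git fetch -q upstream
git -c user.email=a@example.com -c user.name=a cherry-pick "$fix_sha"
```

This mirrors the PyCharm steps exactly (add remote, fetch, cherry-pick); the GUI is just driving these same git commands.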
Occasionally (I haven't quite determined why this happens; I suspect it has to do with nesting subsampling plates) I have noticed that partially masked observations will double in shape, (A, B, C) -> (A, B, C, A, B, C), when using the Predictive interface. This patch fixes this, and all tests related to masking appear to pass.
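As a conceptual sketch of the intended `obs_mask` semantics (an illustration only, not Pyro's actual implementation): the observed value and an imputed value are combined elementwise under the mask, so the result should keep the site's original shape rather than gaining extra dimensions as in the bug above.

```python
import torch

# Shapes mirror the regression test: batch of 3, event dim of 2.
observed = torch.randn(3, 2)          # fully known observations
imputed = torch.zeros(3, 2)           # stand-in for sampled/imputed values
obs_mask = torch.tensor([True, False, True])  # True = use the observation

# Intended obs_mask semantics: pick observed entries where the mask is
# True and imputed entries where it is False. The mask only needs to be
# broadcast against the batch shape, so the result stays (3, 2).
value = torch.where(obs_mask.unsqueeze(-1), observed, imputed)
assert value.shape == (3, 2)  # no shape doubling
```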