Skip to content

Commit

Permalink
Fix PyroModule rendering error in local-parameter mode (#3366)
Browse files Browse the repository at this point in the history
* test

* add constraint kwarg to fake param statements
  • Loading branch information
eb8680 committed May 7, 2024
1 parent ca36025 commit 7511353
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 3 deletions.
16 changes: 13 additions & 3 deletions pyro/nn/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,7 +582,12 @@ def __getattr__(self, name: str) -> Any:
constrained_value.unconstrained = weakref.ref(unconstrained_value)
return pyro.poutine.runtime.effectful(type="param")(
lambda *_, **__: constrained_value
)(fullname, event_dim=event_dim, name=fullname)
)(
fullname,
constraint=constraint,
event_dim=event_dim,
name=fullname,
)
else: # Cannot determine supermodule and hence cannot compute fullname.
constrained_value = transform_to(constraint)(unconstrained_value)
constrained_value.unconstrained = weakref.ref(unconstrained_value)
Expand Down Expand Up @@ -621,7 +626,7 @@ def __getattr__(self, name: str) -> Any:
# even though we don't use the contents of the local parameter store
fullname = self._pyro_get_fullname(name)
pyro.poutine.runtime.effectful(type="param")(lambda *_, **__: result)(
fullname, result, name=fullname
fullname, result, constraint=constraints.real, name=fullname
)

if isinstance(result, torch.nn.Module):
Expand All @@ -645,7 +650,12 @@ def __getattr__(self, name: str) -> Any:
)
pyro.poutine.runtime.effectful(type="param")(
lambda *_, **__: param_value
)(fullname_param, param_value, name=fullname_param)
)(
fullname_param,
param_value,
constraint=constraints.real,
name=fullname_param,
)

return result

Expand Down
21 changes: 21 additions & 0 deletions tests/nn/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -1045,3 +1045,24 @@ def test_with_slice_indexing(self) -> None:

def test_module_list() -> None:
assert PyroModule[torch.nn.ModuleList] is pyro.nn.PyroModuleList


@pytest.mark.parametrize("use_module_local_params", [True, False])
def test_render_constrained_param(use_module_local_params):

class Model(PyroModule):

@PyroParam(constraint=constraints.positive)
def x(self):
return torch.tensor(1.234)

@PyroParam(constraint=constraints.real)
def y(self):
return torch.tensor(0.456)

def forward(self):
return self.x + self.y

with pyro.settings.context(module_local_params=use_module_local_params):
model = Model()
pyro.render_model(model)

0 comments on commit 7511353

Please sign in to comment.