Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
ordabayevy committed Mar 17, 2024
1 parent f83c6c0 commit e867ee7
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions pyro/nn/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def __call__(
return PyroParam(init_value, self.constraint, self.event_dim)


@dataclass
@dataclass(frozen=True)
class PyroSample:
"""
Declares a Pyro-managed random attribute of a :class:`PyroModule`, similar
Expand Down Expand Up @@ -181,7 +181,8 @@ def __post_init__(self) -> None:
for p in inspect.signature(self.prior).parameters.values()
if p.default is inspect.Parameter.empty
), "prior should take the single argument 'self'"
self.name: Optional[str] = getattr(self.prior, "__name__", None)
object.__setattr__(self, "name", getattr(self.prior, "__name__", None))
self.name: Optional[str]
if self.name is not None:
# Ensure decorated function is accessible for pickling.
self.prior.__name__ = "_pyro_prior_" + self.prior.__name__
Expand Down Expand Up @@ -638,7 +639,7 @@ def __getattr__(self, name: str) -> Any:
def __setattr__(
self,
name: str,
value: Union[torch.Tensor, torch.nn.Module, PyroParam, PyroSample],
value: Any,
) -> None:
if isinstance(value, PyroModule):
# Create a new sub PyroModule, overwriting any old value.
Expand Down

0 comments on commit e867ee7

Please sign in to comment.