Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Type annotate messengers #3309

Merged
merged 6 commits into from
Jan 6, 2024
Merged

Type annotate messengers #3309

merged 6 commits into from
Jan 6, 2024

Conversation

ordabayevy
Copy link
Member

  • plate_messenger
  • reentrant_messenger
  • reparam_messenger

pyro/distributions/torch_distribution.py Show resolved Hide resolved
is_observed: bool
args: Tuple
kwargs: Dict
value: Optional[torch.Tensor]
value: Optional[T]
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've learned this neat trick with Generic and TypeVar where the type of value can be inferred from the Callable signature. I also fixed effectful so that it gives the correct signature for the decorated function when the return type is diferent from torch.Tensor (e.g. reparam_messenger._get_init_messengers).

@@ -368,6 +365,7 @@ def _fn(
)
# apply the stack and return its return value
apply_stack(msg)
assert msg["value"] is not None
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this always correct? All tests have passed

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pyro/poutine/reparam_messenger.py Outdated Show resolved Hide resolved
pyro/poutine/runtime.py Show resolved Hide resolved
Copy link
Member

@fritzo fritzo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great, just a couple nits.

pyro/distributions/torch_distribution.py Show resolved Hide resolved
pyro/poutine/messenger.py Outdated Show resolved Hide resolved
pyro/poutine/messenger.py Show resolved Hide resolved
pyro/poutine/reentrant_messenger.py Outdated Show resolved Hide resolved
@@ -368,6 +365,7 @@ def _fn(
)
# apply the stack and return its return value
apply_stack(msg)
assert msg["value"] is not None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member Author

@ordabayevy ordabayevy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@fritzo can you have another look? I have addressed your comments.

) -> Union[T, torch.Tensor, None]:
obs: Optional[_T] = None,
**kwargs: _P.kwargs,
) -> Optional[_T]:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed this back to return Optional and removed the assert msg["value"] is not None line. One concern I have is that if _T itself is None then it will raise an assertion error.

Copy link
Member

@fritzo fritzo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, Thanks for the ping!

@fritzo fritzo merged commit 670e9cb into dev Jan 6, 2024
9 checks passed
@ordabayevy ordabayevy deleted the type-handlers branch January 6, 2024 18:00
@ordabayevy ordabayevy mentioned this pull request Feb 5, 2024
23 tasks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants