Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Annotate handlers & add py.typed #3321

Merged
merged 1 commit into from
Feb 6, 2024
Merged

Annotate handlers & add py.typed #3321

merged 1 commit into from
Feb 6, 2024

Conversation

ordabayevy
Copy link
Member

No description provided.

@overload
def condition(
data: Union[Dict[str, "torch.Tensor"], "Trace"],
) -> ConditionMessenger: ...
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some handlers have arguments that are not optional so fn has to be required in the signature. This is the trick with overloading to have fn as optional. Based on the signature mypy can figure out which type annotations to use.

@ordabayevy ordabayevy mentioned this pull request Feb 5, 2024
23 tasks
Copy link
Member

@fritzo fritzo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great, I have just a couple questions.



@_make_handler(BlockMessenger)
def block( # type: ignore[empty-body]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be more idiomatic to use pass rather than ... # type: ignore[empty-body], here and in other targets of @_make_handler? Or would that cause mypy to complain about the return type?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just checked it, it gives the same empty-body mypy error:

(pyro) yordabay@yo-dl-dev:/mnt/disks/dev/repos/pyro$ git diff
diff --git a/pyro/poutine/handlers.py b/pyro/poutine/handlers.py
index c54da6cf..404ae667 100644
--- a/pyro/poutine/handlers.py
+++ b/pyro/poutine/handlers.py
@@ -165,7 +165,7 @@ def block(
 
 
 @_make_handler(BlockMessenger)
-def block(  # type: ignore[empty-body]
+def block(
     fn: Optional[Callable[_P, _T]] = None,
     hide_fn: Optional[Callable[["Message"], Optional[bool]]] = None,
     expose_fn: Optional[Callable[["Message"], Optional[bool]]] = None,
@@ -175,7 +175,8 @@ def block(  # type: ignore[empty-body]
     expose: Optional[List[str]] = None,
     hide_types: Optional[List[str]] = None,
     expose_types: Optional[List[str]] = None,
-) -> Union[BlockMessenger, Callable[_P, _T]]: ...
+) -> Union[BlockMessenger, Callable[_P, _T]]:
+    pass
 
 
 @overload
(pyro) yordabay@yo-dl-dev:/mnt/disks/dev/repos/pyro$ make lint
ruff check .
black --check *.py pyro examples tests scripts profiler
Skipping .ipynb files as Jupyter dependencies are not installed.
You can fix this by running ``pip install "black[jupyter]"``
All done! ✨ 🍰 ✨
621 files would be left unchanged.
python scripts/update_headers.py --check
mypy --install-types --non-interactive pyro scripts
pyro/poutine/handlers.py:168: error: Missing return statement  [empty-body]
Found 1 error in 1 file (checked 321 source files)
make: *** [Makefile:24: lint] Error 1

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for checking!

@@ -94,14 +94,14 @@ def _masked_observe(
name: str,
fn: TorchDistributionMixin,
obs: Optional[torch.Tensor],
obs_mask: torch.Tensor,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! Is torch.cuda.BoolTensor a subclass of torch.BoolTensor?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems like not:

>>> issubclass(torch.cuda.BoolTensor, torch.Tensor)
False
>>> issubclass(torch.cuda.BoolTensor, torch.BoolTensor)
False

But it should be okay since it is only used for type checking, right?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, it should be ok. Let's just keep in mind that assert isinstance(obs_mask, torch.BoolTensor) would fail in runtime.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah. That's why I try to use assert x is not None when x: Optional[torch.Tensor]

*args,
**kwargs,
) -> torch.Tensor:
# Split into two auxiliary sample sites.
with poutine.mask(mask=obs_mask):
observed = sample(f"{name}_observed", fn, *args, **kwargs, obs=obs)
with poutine.mask(mask=~obs_mask):
with poutine.mask(mask=~obs_mask): # type: ignore[call-overload]
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here, unfortunately, ~torch.BoolTensor returns a torch.Tensor

@fritzo fritzo merged commit 6d2a56f into dev Feb 6, 2024
9 checks passed
@ordabayevy ordabayevy deleted the type-handlers branch February 10, 2024 14:09
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Mypy Cannot fined implementation or library stub for module named pyro
2 participants