Improve torch.cuda.amp type hints (#108630)
Fixes #108629

1. Add the following to their modules' `__all__` so that pyright considers them to be publicly exported:
* [`torch.autocast`](https://pytorch.org/docs/stable/amp.html#torch.autocast)
* [`torch.cuda.amp.GradScaler`](https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler)
* [`torch.cuda.amp.autocast`](https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast)
* [`torch.cuda.amp.custom_fwd`](https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.custom_fwd)
* [`torch.cuda.amp.custom_bwd`](https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.custom_bwd)
2. Add `overload`s for `torch.cuda.amp.GradScaler.scale` so that, based on the type of the `outputs` parameter, type checkers can tell whether a `torch.Tensor` or an `Iterable[torch.Tensor]` is returned.
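The `GradScaler` changes live in `torch/cuda/amp/grad_scaler.py`, whose diff is not rendered on this page. As a rough, hypothetical sketch of the overload pattern item 2 describes (the exact signatures in this PR may differ), the idea is:

```python
from typing import Iterable, List, Tuple, Union, overload

import torch


class GradScaler:
    """Sketch of the overload pattern only; not the actual implementation."""

    @overload
    def scale(self, outputs: torch.Tensor) -> torch.Tensor:
        ...

    @overload
    def scale(self, outputs: List[torch.Tensor]) -> List[torch.Tensor]:
        ...

    @overload
    def scale(self, outputs: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
        ...

    @overload
    def scale(self, outputs: Iterable[torch.Tensor]) -> Iterable[torch.Tensor]:
        ...

    def scale(
        self, outputs: Union[torch.Tensor, Iterable[torch.Tensor]]
    ) -> Union[torch.Tensor, Iterable[torch.Tensor]]:
        # Runtime behavior is unchanged; the overloads above only narrow
        # the static return type that pyright reports at each call site.
        raise NotImplementedError  # placeholder body for this sketch
```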

Pull Request resolved: #108630
Approved by: https://github.com/ezyang
ringohoffman authored and pytorchmergebot committed Sep 8, 2023
1 parent 6c72604 commit e40d6ae
Showing 5 changed files with 157 additions and 109 deletions.
2 changes: 1 addition & 1 deletion test/test_overrides.py
@@ -332,7 +332,7 @@ def generate_tensor_like_torch_implementations():
     # the problem. A more proper fix is to make the "not tested" check
     # a test on its own, and to make sure the monkeypatch is only installed
     # for the span of the relevant test (and deleted afterwards)
-    testing_ignore = {"sample_functional"}
+    testing_ignore = {"sample_functional", "autocast"}
     for namespace, funcs in get_overridable_functions().items():
         for func in funcs:
             if func not in testing_overrides and func.__name__ not in testing_ignore:
2 changes: 1 addition & 1 deletion torch/__init__.py
@@ -55,7 +55,7 @@ def _running_with_deploy():
     'set_warn_always', 'is_warn_always_enabled', 'SymInt', 'SymFloat',
     'SymBool', 'sym_not',
     'sym_int', 'sym_float', 'sym_max', 'sym_min', 'compile', 'vmap',
-    'export',
+    'export', 'autocast',
 ]

 ################################################################################
11 changes: 9 additions & 2 deletions torch/cuda/amp/__init__.py
@@ -1,2 +1,9 @@
-from .autocast_mode import autocast, custom_bwd, custom_fwd  # noqa: F401
-from .grad_scaler import GradScaler  # noqa: F401
+from .autocast_mode import autocast, custom_bwd, custom_fwd
+from .grad_scaler import GradScaler
+
+__all__ = [
+    "autocast",
+    "custom_bwd",
+    "custom_fwd",
+    "GradScaler",
+]
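For context (not part of this commit), here is a minimal, hypothetical snippet illustrating what the new hints give a pyright user: `GradScaler` and `autocast` are treated as public exports of `torch.cuda.amp`, and `scaler.scale(loss)` is inferred as `torch.Tensor` rather than a union type, so `.backward()` type-checks without a cast. The toy model, data, and optimizer are placeholders, and a CUDA device is required to run it.

```python
import torch
from torch.cuda.amp import GradScaler, autocast  # recognized as public exports

# Hypothetical toy setup; requires a CUDA device.
device = "cuda"
model = torch.nn.Linear(4, 2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = GradScaler()

inputs = torch.randn(8, 4, device=device)
targets = torch.randn(8, 2, device=device)

with autocast():
    loss = torch.nn.functional.mse_loss(model(inputs), targets)

# With the new overloads, `scaled_loss` is inferred as torch.Tensor
# (not Tensor | Iterable[Tensor]), so .backward() type-checks cleanly.
scaled_loss = scaler.scale(loss)
scaled_loss.backward()
scaler.step(optimizer)
scaler.update()
```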
