Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Improve torch.cuda.amp type hints (#108630)
Fixes #108629 1. Add the following to their modules' `__all__` so that pyright considers them to be publicly exported: * [`torch.autocast`](https://pytorch.org/docs/stable/amp.html#torch.autocast) * [`torch.cuda.amp.GradScaler`](https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler) * [`torch.cuda.amp.autocast`](https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast) * [`torch.cuda.amp.custom_fwd`](https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.custom_fwd) * [`torch.cuda.amp.custom_bwd`](https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.custom_bwd) 2. Add `overload`s for `torch.cuda.amp.GradScaler.scale` to differentiate when a `torch.Tensor` is returned vs. an `Iterable[torch.Tensor]` is returned based on the type of the `outputs` parameter. Pull Request resolved: #108630 Approved by: https://github.com/ezyang
- Loading branch information