Update on "Deprecate device-specific GradScaler autocast API"
# Motivation

## For `torch.amp.GradScaler`
- `torch.cpu.amp.GradScaler(args...)` is completely equivalent to `torch.amp.GradScaler("cpu", args...)`.
- `torch.cuda.amp.GradScaler(args...)` is completely equivalent to `torch.amp.GradScaler("cuda", args...)`.

So, we intend to deprecate them and **strongly recommend** that developers use `torch.amp.GradScaler` instead.
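A minimal migration sketch (the `init_scale` value and the commented training-loop names `loss` and `optimizer` are illustrative placeholders, not code from this PR):

```python
import torch

# Deprecated, device-specific constructors:
#   scaler = torch.cuda.amp.GradScaler(init_scale=2.0**16)
#   scaler = torch.cpu.amp.GradScaler(init_scale=2.0**16)

# Device-agnostic replacement recommended by this PR:
scaler = torch.amp.GradScaler("cuda", init_scale=2.0**16)

# Typical usage inside a training step stays unchanged:
#   scaler.scale(loss).backward()
#   scaler.step(optimizer)
#   scaler.update()
```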

## For `custom_fwd` and `custom_bwd`
These decorators are a good solution for making a custom autograd function run correctly, with or without casting, inside an autocast-enabled region, and the mechanism can be shared by other backends such as CPU and XPU.
So we generalize them to be device-agnostic, move them into `torch/amp/autocast_mode.py`, and re-expose them as `torch.amp.custom_fwd` and `torch.amp.custom_bwd`. Meanwhile, we deprecate `torch.cuda.amp.custom_fwd` and `torch.cuda.amp.custom_bwd`.
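For illustration, a hedged sketch of how a custom autograd function might adopt the device-agnostic decorators; the `MyMM` class is a made-up example, not code from this PR:

```python
import torch

class MyMM(torch.autograd.Function):
    @staticmethod
    @torch.amp.custom_fwd(device_type="cuda")  # previously @torch.cuda.amp.custom_fwd
    def forward(ctx, a, b):
        ctx.save_for_backward(a, b)
        return a.mm(b)

    @staticmethod
    @torch.amp.custom_bwd(device_type="cuda")  # previously @torch.cuda.amp.custom_bwd
    def backward(ctx, grad):
        a, b = ctx.saved_tensors
        return grad.mm(b.t()), a.t().mm(grad)
```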

# Additional Context
Add a unit test to cover the deprecation warning.
No additional unit tests are needed for the functionality of `torch.amp.custom_fwd`/`custom_bwd`; the existing tests that previously covered `torch.cuda.amp.custom_fwd`/`custom_bwd` already cover it.
To facilitate review, we split these code changes into two PRs. The first PR covers `torch.amp.GradScaler`; the follow-up covers `custom_fwd` and `custom_bwd`.
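A minimal standalone sketch of what the warning check amounts to (the real test lives in `test/test_torch.py`; this assumes a build where the deprecated constructor is still importable):

```python
import warnings

import torch

# Constructing the deprecated class should emit a warning pointing at torch.amp.GradScaler.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    _ = torch.cuda.amp.GradScaler(init_scale=2.0)

assert any(
    "torch.cuda.amp.GradScaler(args...) is deprecated" in str(w.message) for w in caught
)
```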

cc mrshenli pritamdamania87 zhaojuanmao satgera gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 awgu penguinwu fegin XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 chauhang d4l3k jgong5 mingfeima XiaobingSuper sanchitintel ashokei jingxu10 mcarilli ptrblck leslie-fang-intel voznesenskym EikanWang Guobing-Chen zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng

[ghstack-poisoned]
guangyey committed May 23, 2024
2 parents b84ef62 + 2138b68 commit b5bf757
Showing 4 changed files with 7 additions and 12 deletions.
9 changes: 3 additions & 6 deletions docs/source/amp.rst
@@ -25,12 +25,9 @@ However, :class:`torch.autocast` and :class:`torch.GradScaler` are modular, and
As shown in the CPU example section of :class:`torch.autocast`, "automatic mixed precision training/inference" on CPU with
datatype of ``torch.bfloat16`` only uses :class:`torch.autocast`.

For CUDA and CPU, APIs are also provided separately:

* ``torch.autocast("cuda", args...)`` is equivalent to ``torch.cuda.amp.autocast(args...)``.
* ``torch.autocast("cpu", args...)`` is equivalent to ``torch.cpu.amp.autocast(args...)``. For CPU, only lower precision floating point datatype of ``torch.bfloat16`` is supported for now.
* ``torch.GradScaler("cuda", args...)`` is equivalent to ``torch.cuda.amp.GradScaler(args...)``.
* ``torch.GradScaler("cpu", args...)`` is equivalent to ``torch.cpu.amp.GradScaler(args...)``.
.. warning::
    ``torch.cuda.amp.autocast(args...)`` and ``torch.cpu.amp.autocast(args...)`` will be deprecated. Please use ``torch.autocast("cuda", args...)`` or ``torch.autocast("cpu", args...)`` instead.
    ``torch.cuda.amp.GradScaler(args...)`` and ``torch.cpu.amp.GradScaler(args...)`` will be deprecated. Please use ``torch.GradScaler("cuda", args...)`` or ``torch.GradScaler("cpu", args...)`` instead.

:class:`torch.autocast` and :class:`torch.cpu.amp.autocast` are new in version `1.10`.

4 changes: 2 additions & 2 deletions test/test_torch.py
@@ -6152,7 +6152,7 @@ def run(model0, model1, optimizer0, optimizer1, try_scaling_api):
@onlyNativeDeviceTypes
def test_grad_scaler_pass_itself(self, device):
device = torch.device(device)
GradScaler = partial(torch.amp.GradScaler, device=device)
GradScaler = partial(torch.amp.GradScaler, device=device.type)

class _PlaceHolderOptimizer(torch.optim.Optimizer):
tester = self
@@ -6195,7 +6195,7 @@ def test_grad_scaler_deprecated_warning(self, device):
GradScaler = torch.cuda.amp.GradScaler if "cuda" == device.type else torch.cpu.amp.GradScaler

with self.assertWarnsRegex(
DeprecationWarning,
UserWarning,
rf"torch.{device.type}.amp.GradScaler\(args...\) is deprecated.",
):
_ = GradScaler(init_scale=2.0)
3 changes: 1 addition & 2 deletions torch/cpu/amp/grad_scaler.py
@@ -20,8 +20,7 @@ def __init__(
enabled: bool = True,
) -> None:
warnings.warn(
"torch.cpu.amp.GradScaler(args...) is deprecated. Please use torch.amp.GradScaler('cpu', args...) instead.",
DeprecationWarning,
"torch.cpu.amp.GradScaler(args...) is deprecated. Please use torch.amp.GradScaler('cpu', args...) instead."
)
super().__init__(
"cpu",
3 changes: 1 addition & 2 deletions torch/cuda/amp/grad_scaler.py
@@ -20,8 +20,7 @@ def __init__(
enabled: bool = True,
) -> None:
warnings.warn(
"torch.cuda.amp.GradScaler(args...) is deprecated. Please use torch.amp.GradScaler('cuda', args...) instead.",
DeprecationWarning,
"torch.cuda.amp.GradScaler(args...) is deprecated. Please use torch.amp.GradScaler('cuda', args...) instead."
)
super().__init__(
"cuda",
