Update on "generalize custom_fwd&custom_bwd to be device-agnostic"
cc mcarilli ptrblck leslie-fang-intel jgong5 voznesenskym penguinwu EikanWang Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
guangyey committed May 23, 2024
2 parents fa5f5a5 + a85e55f commit d0417ae
Showing 6 changed files with 9 additions and 17 deletions.
9 changes: 3 additions & 6 deletions docs/source/amp.rst
@@ -25,12 +25,9 @@ However, :class:`torch.autocast` and :class:`torch.GradScaler` are modular, and
As shown in the CPU example section of :class:`torch.autocast`, "automatic mixed precision training/inference" on CPU with
datatype of ``torch.bfloat16`` only uses :class:`torch.autocast`.

-For CUDA and CPU, APIs are also provided separately:
-
-* ``torch.autocast("cuda", args...)`` is equivalent to ``torch.cuda.amp.autocast(args...)``.
-* ``torch.autocast("cpu", args...)`` is equivalent to ``torch.cpu.amp.autocast(args...)``. For CPU, only lower precision floating point datatype of ``torch.bfloat16`` is supported for now.
-* ``torch.GradScaler("cuda", args...)`` is equivalent to ``torch.cuda.amp.GradScaler(args...)``.
-* ``torch.GradScaler("cpu", args...)`` is equivalent to ``torch.cpu.amp.GradScaler(args...)``.
+.. warning::
+    ``torch.cuda.amp.autocast(args...)`` and ``torch.cpu.amp.autocast(args...)`` will be deprecated. Please use ``torch.autocast("cuda", args...)`` or ``torch.autocast("cpu", args...)`` instead.
+    ``torch.cuda.amp.GradScaler(args...)`` and ``torch.cpu.amp.GradScaler(args...)`` will be deprecated. Please use ``torch.GradScaler("cuda", args...)`` or ``torch.GradScaler("cpu", args...)`` instead.

:class:`torch.autocast` and :class:`torch.cpu.amp.autocast` are new in version `1.10`.
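
For readers migrating from the per-backend APIs removed above, a minimal sketch (not part of this diff) of the device-agnostic pattern the updated docs point to; the model, optimizer, and tensor names are placeholders:

import torch

device_type = "cuda" if torch.cuda.is_available() else "cpu"
amp_dtype = torch.float16 if device_type == "cuda" else torch.bfloat16

model = torch.nn.Linear(8, 8).to(device_type)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.amp.GradScaler(device_type)  # replaces torch.cuda.amp.GradScaler / torch.cpu.amp.GradScaler

data = torch.randn(4, 8, device=device_type)
target = torch.randn(4, 8, device=device_type)

# replaces torch.cuda.amp.autocast / torch.cpu.amp.autocast
with torch.autocast(device_type, dtype=amp_dtype):
    loss = torch.nn.functional.mse_loss(model(data), target)

scaler.scale(loss).backward()  # scale the loss so low-precision gradients do not underflow
scaler.step(optimizer)         # unscales gradients and skips the step if any are inf/nan
scaler.update()                # adjusts the scale factor for the next iteration
optimizer.zero_grad()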

1 change: 0 additions & 1 deletion test/test_cuda.py
@@ -1800,7 +1800,6 @@ def backward(ctx, grad):

    def test_autocast_custom_deprecated_warning(self):
        with warnings.catch_warnings(record=True) as w:
-            warnings.simplefilter("always", category=DeprecationWarning)

            class MyMM(torch.autograd.Function):
                @staticmethod
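
The deleted simplefilter call was specific to DeprecationWarning; once the deprecated wrappers emit a plain UserWarning (the default category of warnings.warn), no category-specific filter is needed for the warning to be recorded. A minimal, self-contained sketch of the recording pattern, with a hypothetical emit_deprecation() standing in for the decorated autograd Function the test exercises:

import warnings

def emit_deprecation():
    # Hypothetical stand-in for calling one of the deprecated wrappers.
    warnings.warn("torch.cuda.amp.custom_fwd(args...) is deprecated.")

with warnings.catch_warnings(record=True) as w:
    warnings.simplefilter("always")  # only to make this sketch independent of ambient filter state
    emit_deprecation()

assert any("is deprecated" in str(item.message) for item in w)
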
4 changes: 2 additions & 2 deletions test/test_torch.py
@@ -6152,7 +6152,7 @@ def run(model0, model1, optimizer0, optimizer1, try_scaling_api):
    @onlyNativeDeviceTypes
    def test_grad_scaler_pass_itself(self, device):
        device = torch.device(device)
-        GradScaler = partial(torch.amp.GradScaler, device=device)
+        GradScaler = partial(torch.amp.GradScaler, device=device.type)

        class _PlaceHolderOptimizer(torch.optim.Optimizer):
            tester = self
@@ -6195,7 +6195,7 @@ def test_grad_scaler_deprecated_warning(self, device):
        GradScaler = torch.cuda.amp.GradScaler if "cuda" == device.type else torch.cpu.amp.GradScaler

        with self.assertWarnsRegex(
-            DeprecationWarning,
+            UserWarning,
            rf"torch.{device.type}.amp.GradScaler\(args...\) is deprecated.",
        ):
            _ = GradScaler(init_scale=2.0)
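
The first hunk passes device.type rather than the torch.device object because torch.amp.GradScaler takes a device-type string such as "cuda" or "cpu"; the second matches the category the deprecated constructors now emit (UserWarning, the warnings.warn default). A small sketch of the distinction, not taken from the test itself:

import torch
from functools import partial

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# torch.amp.GradScaler expects the device type string ("cuda" or "cpu"),
# so the torch.device object is reduced to its .type attribute.
GradScaler = partial(torch.amp.GradScaler, device=device.type)
scaler = GradScaler(init_scale=2.0)
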
3 changes: 1 addition & 2 deletions torch/cpu/amp/grad_scaler.py
@@ -20,8 +20,7 @@ def __init__(
        enabled: bool = True,
    ) -> None:
        warnings.warn(
-            "torch.cpu.amp.GradScaler(args...) is deprecated. Please use torch.amp.GradScaler('cpu', args...) instead.",
-            DeprecationWarning,
+            "torch.cpu.amp.GradScaler(args...) is deprecated. Please use torch.amp.GradScaler('cpu', args...) instead."
        )
        super().__init__(
            "cpu",
6 changes: 2 additions & 4 deletions torch/cuda/amp/autocast_mode.py
@@ -55,8 +55,7 @@ def custom_fwd(fwd=None, *, cast_inputs=None):
    ``torch.amp.custom_fwd(args..., device_type='cuda')`` instead.
    """
    warnings.warn(
-        "torch.cuda.amp.custom_fwd(args...) is deprecated. Please use torch.amp.custom_fwd(args..., device_type='cuda') instead.",
-        DeprecationWarning,
+        "torch.cuda.amp.custom_fwd(args...) is deprecated. Please use torch.amp.custom_fwd(args..., device_type='cuda') instead."
    )
    return functools.partial(torch.amp.custom_fwd, device_type="cuda")(
        fwd=fwd, cast_inputs=cast_inputs
@@ -69,7 +68,6 @@ def custom_bwd(bwd):
    ``torch.amp.custom_bwd(args..., device_type='cuda')`` instead.
    """
    warnings.warn(
-        "torch.cuda.amp.custom_bwd(args...) is deprecated. Please use torch.amp.custom_bwd(args..., device_type='cuda') instead.",
-        DeprecationWarning,
+        "torch.cuda.amp.custom_bwd(args...) is deprecated. Please use torch.amp.custom_bwd(args..., device_type='cuda') instead."
    )
    return functools.partial(torch.amp.custom_bwd, device_type="cuda")(bwd)
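
Both wrappers now simply forward to the device-agnostic torch.amp decorators. A minimal sketch (assuming a CUDA device is available; MyMM here is illustrative, not the class from the test suite) of adopting them directly in a custom autograd Function instead of the deprecated torch.cuda.amp versions:

import torch

class MyMM(torch.autograd.Function):
    @staticmethod
    @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32)
    def forward(ctx, a, b):
        # Inside an autocast region, inputs are cast to float32 before the matmul.
        ctx.save_for_backward(a, b)
        return a.mm(b)

    @staticmethod
    @torch.amp.custom_bwd(device_type="cuda")
    def backward(ctx, grad):
        # Runs under the same autocast state that was active during forward.
        a, b = ctx.saved_tensors
        return grad.mm(b.t()), a.t().mm(grad)

a = torch.randn(4, 4, device="cuda", requires_grad=True)
b = torch.randn(4, 4, device="cuda", requires_grad=True)
with torch.autocast("cuda"):
    MyMM.apply(a, b).sum().backward()
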
3 changes: 1 addition & 2 deletions torch/cuda/amp/grad_scaler.py
@@ -20,8 +20,7 @@ def __init__(
        enabled: bool = True,
    ) -> None:
        warnings.warn(
-            "torch.cuda.amp.GradScaler(args...) is deprecated. Please use torch.amp.GradScaler('cuda', args...) instead.",
-            DeprecationWarning,
+            "torch.cuda.amp.GradScaler(args...) is deprecated. Please use torch.amp.GradScaler('cuda', args...) instead."
        )
        super().__init__(
            "cuda",
