From a85e55ff4af1d07cadbeb01f9f5c18b4b5b7e667 Mon Sep 17 00:00:00 2001
From: "Yu, Guangye"
Date: Thu, 23 May 2024 10:11:19 +0000
Subject: [PATCH] Update base for Update on "generalize custom_fwd&custom_bwd
 to be device-agnostic"

cc mcarilli ptrblck leslie-fang-intel jgong5 voznesenskym penguinwu EikanWang Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
---
 docs/source/amp.rst           | 9 +++------
 test/test_torch.py            | 4 ++--
 torch/cpu/amp/grad_scaler.py  | 3 +--
 torch/cuda/amp/grad_scaler.py | 3 +--
 4 files changed, 7 insertions(+), 12 deletions(-)

diff --git a/docs/source/amp.rst b/docs/source/amp.rst
index a9d98a0aa0989..96aee67e02926 100644
--- a/docs/source/amp.rst
+++ b/docs/source/amp.rst
@@ -25,12 +25,9 @@ However, :class:`torch.autocast` and :class:`torch.GradScaler` are modular, and
 As shown in the CPU example section of :class:`torch.autocast`, "automatic mixed precision training/inference" on CPU with
 datatype of ``torch.bfloat16`` only uses :class:`torch.autocast`.
 
-For CUDA and CPU, APIs are also provided separately:
-
-* ``torch.autocast("cuda", args...)`` is equivalent to ``torch.cuda.amp.autocast(args...)``.
-* ``torch.autocast("cpu", args...)`` is equivalent to ``torch.cpu.amp.autocast(args...)``. For CPU, only lower precision floating point datatype of ``torch.bfloat16`` is supported for now.
-* ``torch.GradScaler("cuda", args...)`` is equivalent to ``torch.cuda.amp.GradScaler(args...)``.
-* ``torch.GradScaler("cpu", args...)`` is equivalent to ``torch.cpu.amp.GradScaler(args...)``.
+.. warning::
+   ``torch.cuda.amp.autocast(args...)`` and ``torch.cpu.amp.autocast(args...)`` will be deprecated. Please use ``torch.autocast("cuda", args...)`` or ``torch.autocast("cpu", args...)`` instead.
+   ``torch.cuda.amp.GradScaler(args...)`` and ``torch.cpu.amp.GradScaler(args...)`` will be deprecated. Please use ``torch.GradScaler("cuda", args...)`` or ``torch.GradScaler("cpu", args...)`` instead.
 
 :class:`torch.autocast` and :class:`torch.cpu.amp.autocast` are new in version `1.10`.

diff --git a/test/test_torch.py b/test/test_torch.py
index 2e8fe9d9a250b..d0418813c9250 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -6152,7 +6152,7 @@ def run(model0, model1, optimizer0, optimizer1, try_scaling_api):
     @onlyNativeDeviceTypes
     def test_grad_scaler_pass_itself(self, device):
         device = torch.device(device)
-        GradScaler = partial(torch.amp.GradScaler, device=device)
+        GradScaler = partial(torch.amp.GradScaler, device=device.type)
 
         class _PlaceHolderOptimizer(torch.optim.Optimizer):
             tester = self
@@ -6195,7 +6195,7 @@ def test_grad_scaler_deprecated_warning(self, device):
         GradScaler = torch.cuda.amp.GradScaler if "cuda" == device.type else torch.cpu.amp.GradScaler
 
         with self.assertWarnsRegex(
-            DeprecationWarning,
+            UserWarning,
             rf"torch.{device.type}.amp.GradScaler\(args...\) is deprecated.",
         ):
             _ = GradScaler(init_scale=2.0)

diff --git a/torch/cpu/amp/grad_scaler.py b/torch/cpu/amp/grad_scaler.py
index a97d4343e71f0..2c93e0100f161 100644
--- a/torch/cpu/amp/grad_scaler.py
+++ b/torch/cpu/amp/grad_scaler.py
@@ -20,8 +20,7 @@ def __init__(
         enabled: bool = True,
     ) -> None:
         warnings.warn(
-            "torch.cpu.amp.GradScaler(args...) is deprecated. Please use torch.amp.GradScaler('cpu', args...) instead.",
-            DeprecationWarning,
+            "torch.cpu.amp.GradScaler(args...) is deprecated. Please use torch.amp.GradScaler('cpu', args...) instead."
         )
         super().__init__(
             "cpu",

diff --git a/torch/cuda/amp/grad_scaler.py b/torch/cuda/amp/grad_scaler.py
index c97f822f402fd..8263fcdb480de 100644
--- a/torch/cuda/amp/grad_scaler.py
+++ b/torch/cuda/amp/grad_scaler.py
@@ -20,8 +20,7 @@ def __init__(
         enabled: bool = True,
     ) -> None:
         warnings.warn(
-            "torch.cuda.amp.GradScaler(args...) is deprecated. Please use torch.amp.GradScaler('cuda', args...) instead.",
-            DeprecationWarning,
+            "torch.cuda.amp.GradScaler(args...) is deprecated. Please use torch.amp.GradScaler('cuda', args...) instead."
        )
        super().__init__(
            "cuda",