From 054a4513366a0ef1c3219d257d5beda50249c3e1 Mon Sep 17 00:00:00 2001
From: "Yu, Guangye"
Date: Sat, 27 Apr 2024 23:26:10 +0000
Subject: [PATCH 1/4] make torch.amp.autocast more generic

[ghstack-poisoned]
---
 benchmarks/dynamo/common.py   | 47 ++++++++++++++++-------------------
 torch/_dynamo/output_graph.py | 16 +++---------
 torch/amp/autocast_mode.py    |  2 ++
 torch/utils/checkpoint.py     | 14 +++++------
 4 files changed, 34 insertions(+), 45 deletions(-)

diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py
index d610d7dd13540..250aa079f0a3f 100644
--- a/benchmarks/dynamo/common.py
+++ b/benchmarks/dynamo/common.py
@@ -2041,31 +2041,28 @@ def setup_amp(self, current_device=None):
         devices = [current_device] if current_device else self.args.devices
         if self.args.amp:
-            if devices == ["cuda"]:
-                # AMP training can lead to small loss values which can undeflow
-                # gradient values returning in zero gradients. To solve this
-                # problem, PyTorch introduces GradScaler. GradScaler is a stateful
-                # structure, that scales the loss values to prevent underflow. Loss
-                # values are big at the beginning of training (therefore not
-                # requiring scaling), while loss value tends to be small as network
-                # starts getting better (requiring scaling). GradScaler manages all
-                # of this fine tuning, checking the gradients are turning to inf,
-                # discarding such batches.
-
-                # Since we are not running a long iteration, default value of
-                # init_scale 65536 is going to turn all gradients to inf. Therefore,
-                # we just use a init_scale of 2.0 for benchmarking purpose.
-
-                # Disabling Gradscaler because
-                #  1) Benchmark setup runs 2 iterations of fwd-bwd. So, not useful.
-                #  2) Current setup shares grad_scaler for eager and dynamo model,
-                #  which is bad as Gradscaler has state and can adjust the scaling
-                #  factor between eager and dynamo run, making accuracy check
-                #  harder.
-                # self.grad_scaler = torch.cuda.amp.GradScaler(init_scale=2.0)
-                self.autocast = torch.cuda.amp.autocast
-            if devices == ["cpu"]:
-                self.autocast = torch.cpu.amp.autocast
+            # AMP training can lead to small loss values which can undeflow
+            # gradient values returning in zero gradients. To solve this
+            # problem, PyTorch introduces GradScaler. GradScaler is a stateful
+            # structure, that scales the loss values to prevent underflow. Loss
+            # values are big at the beginning of training (therefore not
+            # requiring scaling), while loss value tends to be small as network
+            # starts getting better (requiring scaling). GradScaler manages all
+            # of this fine tuning, checking the gradients are turning to inf,
+            # discarding such batches.
+
+            # Since we are not running a long iteration, default value of
+            # init_scale 65536 is going to turn all gradients to inf. Therefore,
+            # we just use a init_scale of 2.0 for benchmarking purpose.
+
+            # Disabling Gradscaler because
+            #  1) Benchmark setup runs 2 iterations of fwd-bwd. So, not useful.
+            #  2) Current setup shares grad_scaler for eager and dynamo model,
+            #  which is bad as Gradscaler has state and can adjust the scaling
+            #  factor between eager and dynamo run, making accuracy check
+            #  harder.
+            # self.grad_scaler = torch.cuda.amp.GradScaler(init_scale=2.0)
+            self.autocast = functools.partial(torch.amp.autocast, device_type=devices)
         if self.args.amp_dtype:
             amp_dtype = (
                 torch.float16
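For context, a minimal sketch of the pattern setup_amp() now builds: bind the device type once with functools.partial and enter autocast later. The model, shapes, and device probe below are illustrative and not part of the benchmark harness; a later patch in this series also switches the bound argument to a single device string rather than the list.

    import functools

    import torch

    # Bind the device type up front, enter autocast later -- mirrors the
    # functools.partial(torch.amp.autocast, device_type=...) built above.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    autocast_ctx = functools.partial(torch.amp.autocast, device_type=device)

    model = torch.nn.Linear(8, 8).to(device)
    x = torch.randn(4, 8, device=device)

    with autocast_ctx(dtype=torch.bfloat16):
        y = model(x)
    print(y.dtype)  # torch.bfloat16 inside the autocast region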
diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py
index 0be89d59e2b92..e6773dcd807d1 100644
--- a/torch/_dynamo/output_graph.py
+++ b/torch/_dynamo/output_graph.py
@@ -598,28 +598,20 @@ def save_global_state(self, out=None):
         )
         global_state["grad_enabled"] = (torch.set_grad_enabled, torch.is_grad_enabled())
-        def autocast_specific_backend(
-            device_type: str, func: Callable[[str, Any], None]
-        ):
-            def decorator(value):
-                return func(device_type, value)
-
-            return decorator
-
         global_state["autocast_enabled"] = (
-            autocast_specific_backend("cuda", torch.set_autocast_enabled),
+            functools.partial(torch.set_autocast_enabled, "cuda"),
             torch.is_autocast_enabled("cuda"),
         )
         global_state["autocast_cpu_enabled"] = (
-            autocast_specific_backend("cpu", torch.set_autocast_enabled),
+            functools.partial(torch.set_autocast_enabled, "cpu"),
             torch.is_autocast_enabled("cpu"),
         )
         global_state["autocast_gpu_dtype"] = (
-            autocast_specific_backend("cuda", torch.set_autocast_dtype),
+            functools.partial(torch.set_autocast_dtype, "cuda"),
             torch.get_autocast_dtype("cuda"),
         )
         global_state["autocast_cpu_dtype"] = (
-            autocast_specific_backend("cpu", torch.set_autocast_dtype),
+            functools.partial(torch.set_autocast_dtype, "cpu"),
             torch.get_autocast_dtype("cpu"),
         )
         global_state["autocast_cache_enabled"] = (

diff --git a/torch/amp/autocast_mode.py b/torch/amp/autocast_mode.py
index 523d8dc34d5ea..a02242b1e2f5e 100644
--- a/torch/amp/autocast_mode.py
+++ b/torch/amp/autocast_mode.py
@@ -203,6 +203,8 @@ def __init__(
         enabled: bool = True,
         cache_enabled: Optional[bool] = None,
     ):
+        if dtype is None:
+            dtype = torch.get_autocast_dtype(device_type)
         if torch._jit_internal.is_scripting():
             self._enabled = enabled
             self.device = device_type

diff --git a/torch/utils/checkpoint.py b/torch/utils/checkpoint.py
index 7c74b4c6be714..ba8de894e069a 100644
--- a/torch/utils/checkpoint.py
+++ b/torch/utils/checkpoint.py
@@ -287,11 +287,10 @@ def backward(ctx, *args):
                     set_device_states(ctx.fwd_devices, ctx.fwd_device_states)
             detached_inputs = detach_variable(tuple(inputs))
-            device_autocast_ctx = device_module.amp.autocast(
-                **ctx.device_autocast_kwargs
+            device_autocast_ctx = torch.amp.autocast(
+                device_type=ctx.device, **ctx.device_autocast_kwargs
             ) if torch.amp.is_autocast_available(ctx.device) else contextlib.nullcontext()
-            with torch.enable_grad(), device_autocast_ctx, \
-                 torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
+            with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
                 outputs = ctx.run_function(*detached_inputs)
         if isinstance(outputs, torch.Tensor):
@@ -1394,11 +1393,10 @@ def recompute_fn(*inputs):
                 if had_device_in_fwd:
                     set_device_states(fwd_devices, fwd_device_states)
-            device_autocast_ctx = device_module.amp.autocast(
-                **device_autocast_kwargs
+            device_autocast_ctx = torch.amp.autocast(
+                device_type=device, **device_autocast_kwargs
             ) if torch.amp.is_autocast_available(device) else contextlib.nullcontext()
-            with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), \
-                 recompute_context:
+            with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]
                 fn(*args, **kwargs)
     new_frame = _CheckpointFrame(
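For reference, a rough sketch of the re-entry pattern the checkpoint.py hunks switch to: autocast state is captured as keyword arguments during the forward pass, and the recompute path rebuilds the context with the device-generic torch.amp.autocast, falling back to a null context when the device has no autocast support. The device and captured kwargs below are illustrative stand-ins, not the actual checkpoint internals.

    import contextlib

    import torch

    # Capture autocast state for one device, roughly as the forward pass does
    # (illustrative; the field names only mirror the diff).
    device = "cpu"
    device_autocast_kwargs = {
        "enabled": torch.is_autocast_enabled(device),
        "dtype": torch.get_autocast_dtype(device),
        "cache_enabled": torch.is_autocast_cache_enabled(),
    }

    # Rebuild the context generically during recomputation: any device with
    # autocast support goes through torch.amp.autocast, others get a no-op.
    device_autocast_ctx = (
        torch.amp.autocast(device_type=device, **device_autocast_kwargs)
        if torch.amp.is_autocast_available(device)
        else contextlib.nullcontext()
    )

    with device_autocast_ctx:
        out = torch.mm(torch.randn(2, 2), torch.randn(2, 2))
    print(out.dtype)  # float32 here, since autocast was off when captured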
From 2b6089e0731d8e87556d135e7af5d78e3f1f7495 Mon Sep 17 00:00:00 2001
From: "Yu, Guangye"
Date: Sun, 28 Apr 2024 00:43:30 +0000
Subject: [PATCH 2/4] Update on "[WIP] make torch.amp.autocast more generic"

cc mcarilli ptrblck leslie-fang-intel jgong5 ezyang msaroufim bdhirsh
anijain2305 chauhang voznesenskym penguinwu EikanWang Guobing-Chen
XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng

[ghstack-poisoned]
---
 benchmarks/dynamo/common.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py
index 250aa079f0a3f..6996ce9f7e0c8 100644
--- a/benchmarks/dynamo/common.py
+++ b/benchmarks/dynamo/common.py
@@ -2062,7 +2062,7 @@ def setup_amp(self, current_device=None):
             #  factor between eager and dynamo run, making accuracy check
             #  harder.
             # self.grad_scaler = torch.cuda.amp.GradScaler(init_scale=2.0)
-            self.autocast = functools.partial(torch.amp.autocast, device_type=devices)
+            self.autocast = functools.partial(torch.amp.autocast, device_type=devices[0])
         if self.args.amp_dtype:
             amp_dtype = (
                 torch.float16

From c6e6fac00334757b84b2324481931187f6776bd2 Mon Sep 17 00:00:00 2001
From: "Yu, Guangye"
Date: Mon, 29 Apr 2024 00:55:36 +0000
Subject: [PATCH 3/4] Update on "[WIP] make torch.amp.autocast more generic"

cc mcarilli ptrblck leslie-fang-intel jgong5 ezyang msaroufim bdhirsh
anijain2305 chauhang voznesenskym penguinwu EikanWang Guobing-Chen
XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng

[ghstack-poisoned]
---
 benchmarks/dynamo/common.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py
index 6996ce9f7e0c8..364ce65a38153 100644
--- a/benchmarks/dynamo/common.py
+++ b/benchmarks/dynamo/common.py
@@ -2062,7 +2062,9 @@ def setup_amp(self, current_device=None):
             #  factor between eager and dynamo run, making accuracy check
             #  harder.
             # self.grad_scaler = torch.cuda.amp.GradScaler(init_scale=2.0)
-            self.autocast = functools.partial(torch.amp.autocast, device_type=devices[0])
+            self.autocast = functools.partial(
+                torch.amp.autocast, device_type=devices[0]
+            )
         if self.args.amp_dtype:
             amp_dtype = (
                 torch.float16
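The two benchmark patches above only fix the device_type argument and its formatting; the behavioral piece they rely on is the dtype default added to torch.amp.autocast in patch 1, where dtype=None now resolves through torch.get_autocast_dtype. A minimal eager-mode sketch of that behavior follows; the dtypes noted in the comments are the stock per-device defaults and assume they have not been overridden.

    import torch

    # With no dtype argument, torch.amp.autocast picks up the per-device
    # default from torch.get_autocast_dtype() -- bfloat16 on CPU and float16
    # on CUDA unless changed via torch.set_autocast_dtype().
    print(torch.get_autocast_dtype("cpu"))  # torch.bfloat16 (stock default)

    with torch.amp.autocast(device_type="cpu"):  # dtype omitted on purpose
        out = torch.mm(torch.randn(4, 4), torch.randn(4, 4))
    print(out.dtype)  # torch.bfloat16: matmul runs in the default autocast dtype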
From 58cababe03517520168d63d5c168408b168550b8 Mon Sep 17 00:00:00 2001
From: "Yu, Guangye"
Date: Mon, 29 Apr 2024 01:42:40 +0000
Subject: [PATCH 4/4] Update on "[WIP] make torch.amp.autocast more generic"

cc mcarilli ptrblck leslie-fang-intel jgong5 ezyang msaroufim bdhirsh
anijain2305 chauhang voznesenskym penguinwu EikanWang Guobing-Chen
XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng

[ghstack-poisoned]
---
 aten/src/ATen/core/interned_strings.h        |  1 +
 torch/csrc/jit/runtime/register_prim_ops.cpp | 10 ++++++++++
 2 files changed, 11 insertions(+)

diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h
index 0e63abad58fc6..4f6abd66cb887 100644
--- a/aten/src/ATen/core/interned_strings.h
+++ b/aten/src/ATen/core/interned_strings.h
@@ -227,6 +227,7 @@ namespace c10 {
   _(aten, is_autocast_enabled)     \
   _(aten, is_autocast_cpu_enabled) \
   _(aten, is_autocast_xla_enabled) \
+  _(aten, get_autocast_dtype)      \
   FORALL_ATEN_BASE_SYMBOLS(_)      \
   _(onnx, Add)                     \
   _(onnx, Concat)                  \

diff --git a/torch/csrc/jit/runtime/register_prim_ops.cpp b/torch/csrc/jit/runtime/register_prim_ops.cpp
index 485adcb5a8c21..d9109c5d60de8 100644
--- a/torch/csrc/jit/runtime/register_prim_ops.cpp
+++ b/torch/csrc/jit/runtime/register_prim_ops.cpp
@@ -815,6 +815,16 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
           push(stack, enabled);
         },
         aliasAnalysisConservative()),
+    OperatorGeneratorArgs(
+        TORCH_SELECTIVE_SCHEMA(
+            "aten::get_autocast_dtype(str device_type) -> ScalarType"),
+        [](Stack& stack) {
+          at::DeviceType device_type =
+              at::Device(pop(stack).toStringRef()).type();
+          at::ScalarType dtype = at::autocast::get_autocast_dtype(device_type);
+          push(stack, dtype);
+        },
+        aliasAnalysisConservative()),
     OperatorGeneratorArgs(
         TORCH_SELECTIVE_SCHEMA("prim::Uninitialized() -> Any"),
         unInitialized,
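Taken together, the series leans on the device-generic autocast state API (torch.set_autocast_enabled / torch.set_autocast_dtype with a device_type argument). A compact sketch of the save/restore shape that torch/_dynamo/output_graph.py adopts in patch 1 is below; the restore loop is illustrative, while the individual calls appear verbatim in the diff.

    import functools

    import torch

    # Pair each piece of autocast state with a setter that already has the
    # device bound, so restoring is a uniform setter(saved_value) call.
    global_state = {
        "autocast_enabled": (
            functools.partial(torch.set_autocast_enabled, "cuda"),
            torch.is_autocast_enabled("cuda"),
        ),
        "autocast_cpu_enabled": (
            functools.partial(torch.set_autocast_enabled, "cpu"),
            torch.is_autocast_enabled("cpu"),
        ),
        "autocast_gpu_dtype": (
            functools.partial(torch.set_autocast_dtype, "cuda"),
            torch.get_autocast_dtype("cuda"),
        ),
        "autocast_cpu_dtype": (
            functools.partial(torch.set_autocast_dtype, "cpu"),
            torch.get_autocast_dtype("cpu"),
        ),
    }

    # Later, restore exactly what was captured, device by device.
    for setter, saved_value in global_state.values():
        setter(saved_value)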