diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py
index d610d7dd13540..6996ce9f7e0c8 100644
--- a/benchmarks/dynamo/common.py
+++ b/benchmarks/dynamo/common.py
@@ -2041,31 +2041,28 @@ def setup_amp(self, current_device=None):
         devices = [current_device] if current_device else self.args.devices
         if self.args.amp:
-            if devices == ["cuda"]:
-                # AMP training can lead to small loss values which can undeflow
-                # gradient values returning in zero gradients. To solve this
-                # problem, PyTorch introduces GradScaler. GradScaler is a stateful
-                # structure, that scales the loss values to prevent underflow. Loss
-                # values are big at the beginning of training (therefore not
-                # requiring scaling), while loss value tends to be small as network
-                # starts getting better (requiring scaling). GradScaler manages all
-                # of this fine tuning, checking the gradients are turning to inf,
-                # discarding such batches.
-
-                # Since we are not running a long iteration, default value of
-                # init_scale 65536 is going to turn all gradients to inf. Therefore,
-                # we just use a init_scale of 2.0 for benchmarking purpose.
-
-                # Disabling Gradscaler because
-                #  1) Benchmark setup runs 2 iterations of fwd-bwd. So, not useful.
-                #  2) Current setup shares grad_scaler for eager and dynamo model,
-                #     which is bad as Gradscaler has state and can adjust the scaling
-                #     factor between eager and dynamo run, making accuracy check
-                #     harder.
-                # self.grad_scaler = torch.cuda.amp.GradScaler(init_scale=2.0)
-                self.autocast = torch.cuda.amp.autocast
-            if devices == ["cpu"]:
-                self.autocast = torch.cpu.amp.autocast
+            # AMP training can lead to small loss values which can underflow
+            # gradient values, resulting in zero gradients. To solve this
+            # problem, PyTorch introduces GradScaler. GradScaler is a stateful
+            # structure that scales the loss values to prevent underflow. Loss
+            # values are big at the beginning of training (therefore not
+            # requiring scaling), while loss values tend to be small as the
+            # network gets better (requiring scaling). GradScaler manages all
+            # of this fine tuning, checking whether gradients are turning to inf
+            # and discarding such batches.
+
+            # Since we are not running a long iteration, the default init_scale
+            # of 65536 would turn all gradients to inf. Therefore, we would use
+            # an init_scale of 2.0 for benchmarking purposes.
+
+            # Disabling GradScaler because
+            #  1) Benchmark setup runs 2 iterations of fwd-bwd. So, not useful.
+            #  2) Current setup shares grad_scaler for eager and dynamo model,
+            #     which is bad as GradScaler has state and can adjust the scaling
+            #     factor between eager and dynamo run, making accuracy check
+            #     harder.
+            # self.grad_scaler = torch.cuda.amp.GradScaler(init_scale=2.0)
+            self.autocast = functools.partial(torch.amp.autocast, device_type=devices[0])
             if self.args.amp_dtype:
                 amp_dtype = (
                     torch.float16
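
The benchmark harness change above swaps the per-backend torch.cuda.amp.autocast / torch.cpu.amp.autocast aliases for a single partial over the device-generic torch.amp.autocast, so self.autocast is still invoked the same way at its use sites. A minimal standalone sketch of that pattern, assuming a hypothetical run_step helper and a toy model (neither is part of the benchmark code):

import functools

import torch


def run_step(model, inputs, device_type="cpu"):
    # Build the same kind of partial the harness stores in self.autocast:
    # calling it yields an autocast context manager bound to one device type.
    autocast_ctx = functools.partial(torch.amp.autocast, device_type=device_type)

    # Equivalent to torch.cuda.amp.autocast(...) on CUDA or
    # torch.cpu.amp.autocast(...) on CPU, but device-agnostic.
    with autocast_ctx(dtype=torch.bfloat16):
        return model(inputs)


out = run_step(torch.nn.Linear(8, 8), torch.randn(2, 8), device_type="cpu")
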
diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py
index 0be89d59e2b92..e6773dcd807d1 100644
--- a/torch/_dynamo/output_graph.py
+++ b/torch/_dynamo/output_graph.py
@@ -598,28 +598,20 @@ def save_global_state(self, out=None):
         )
         global_state["grad_enabled"] = (torch.set_grad_enabled, torch.is_grad_enabled())
 
-        def autocast_specific_backend(
-            device_type: str, func: Callable[[str, Any], None]
-        ):
-            def decorator(value):
-                return func(device_type, value)
-
-            return decorator
-
         global_state["autocast_enabled"] = (
-            autocast_specific_backend("cuda", torch.set_autocast_enabled),
+            functools.partial(torch.set_autocast_enabled, "cuda"),
             torch.is_autocast_enabled("cuda"),
         )
         global_state["autocast_cpu_enabled"] = (
-            autocast_specific_backend("cpu", torch.set_autocast_enabled),
+            functools.partial(torch.set_autocast_enabled, "cpu"),
             torch.is_autocast_enabled("cpu"),
         )
         global_state["autocast_gpu_dtype"] = (
-            autocast_specific_backend("cuda", torch.set_autocast_dtype),
+            functools.partial(torch.set_autocast_dtype, "cuda"),
             torch.get_autocast_dtype("cuda"),
         )
         global_state["autocast_cpu_dtype"] = (
-            autocast_specific_backend("cpu", torch.set_autocast_dtype),
+            functools.partial(torch.set_autocast_dtype, "cpu"),
             torch.get_autocast_dtype("cpu"),
         )
         global_state["autocast_cache_enabled"] = (
diff --git a/torch/amp/autocast_mode.py b/torch/amp/autocast_mode.py
index 523d8dc34d5ea..a02242b1e2f5e 100644
--- a/torch/amp/autocast_mode.py
+++ b/torch/amp/autocast_mode.py
@@ -203,6 +203,8 @@ def __init__(
         enabled: bool = True,
         cache_enabled: Optional[bool] = None,
     ):
+        if dtype is None:
+            dtype = torch.get_autocast_dtype(device_type)
         if torch._jit_internal.is_scripting():
             self._enabled = enabled
             self.device = device_type
diff --git a/torch/utils/checkpoint.py b/torch/utils/checkpoint.py
index 7c74b4c6be714..ba8de894e069a 100644
--- a/torch/utils/checkpoint.py
+++ b/torch/utils/checkpoint.py
@@ -287,11 +287,10 @@ def backward(ctx, *args):
                     set_device_states(ctx.fwd_devices, ctx.fwd_device_states)
             detached_inputs = detach_variable(tuple(inputs))
 
-            device_autocast_ctx = device_module.amp.autocast(
-                **ctx.device_autocast_kwargs
+            device_autocast_ctx = torch.amp.autocast(
+                device_type=ctx.device, **ctx.device_autocast_kwargs
             ) if torch.amp.is_autocast_available(ctx.device) else contextlib.nullcontext()
-            with torch.enable_grad(), device_autocast_ctx, \
-                 torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
+            with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
                 outputs = ctx.run_function(*detached_inputs)
 
         if isinstance(outputs, torch.Tensor):
@@ -1394,11 +1393,10 @@ def recompute_fn(*inputs):
                 if had_device_in_fwd:
                     set_device_states(fwd_devices, fwd_device_states)
 
-            device_autocast_ctx = device_module.amp.autocast(
-                **device_autocast_kwargs
+            device_autocast_ctx = torch.amp.autocast(
+                device_type=device, **device_autocast_kwargs
             ) if torch.amp.is_autocast_available(device) else contextlib.nullcontext()
-            with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), \
-                 recompute_context:
+            with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]
                 fn(*args, **kwargs)
 
     new_frame = _CheckpointFrame(
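
The output_graph.py hunk replaces the hand-rolled autocast_specific_backend closure with functools.partial, which binds the leading device_type argument in exactly the same way. A self-contained sketch of that equivalence, using a toy setter in place of torch.set_autocast_enabled / torch.set_autocast_dtype:

import functools
from typing import Any, Callable


def autocast_specific_backend(device_type: str, func: Callable[[str, Any], None]):
    # The removed helper: a closure that pre-binds the device_type argument.
    def decorator(value):
        return func(device_type, value)

    return decorator


calls = []


def setter(device_type: str, value: Any) -> None:
    # Toy stand-in for the torch autocast setters; records its arguments.
    calls.append((device_type, value))


# Both forms forward a single value as setter("cuda", value).
autocast_specific_backend("cuda", setter)(True)
functools.partial(setter, "cuda")(True)
assert calls == [("cuda", True), ("cuda", True)]
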
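The autocast_mode.py hunk resolves dtype=None to torch.get_autocast_dtype(device_type) inside torch.amp.autocast.__init__, so an unspecified dtype falls back to the per-device default instead of relying on the per-backend wrappers to fill it in. A rough illustration of the intended behavior on CPU (a sketch, assuming the post-patch dtype resolution and that matmul is autocast-eligible on CPU):

import torch

# Per-device default that dtype=None should now resolve to
# (typically torch.bfloat16 on CPU).
default_cpu_dtype = torch.get_autocast_dtype("cpu")

with torch.amp.autocast(device_type="cpu"):  # dtype omitted -> default
    a = torch.randn(4, 4) @ torch.randn(4, 4)

with torch.amp.autocast(device_type="cpu", dtype=default_cpu_dtype):
    b = torch.randn(4, 4) @ torch.randn(4, 4)

# Both context managers should pick the same autocast dtype.
assert a.dtype == b.dtype == default_cpu_dtype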