diff --git a/aten/src/ATen/autocast_mode.h b/aten/src/ATen/autocast_mode.h
index f4dd7d8766917..a2d7b8a24a467 100644
--- a/aten/src/ATen/autocast_mode.h
+++ b/aten/src/ATen/autocast_mode.h
@@ -174,12 +174,20 @@ inline DispatchKey get_autocast_dispatch_key_from_device_type(
   }
 }
 
-inline at::ScalarType get_lower_precision_fp_from_device_type(
-    c10::DeviceType device_type) {
+inline bool is_autocast_available(c10::DeviceType device_type) {
   if (device_type == at::kCPU || device_type == at::kCUDA ||
       device_type == at::kXPU || device_type == at::kIPU ||
       device_type == at::kHPU || device_type == at::kXLA ||
       device_type == at::kPrivateUse1) {
+    return true;
+  } else {
+    return false;
+  }
+}
+
+inline at::ScalarType get_lower_precision_fp_from_device_type(
+    c10::DeviceType device_type) {
+  if (is_autocast_available(device_type)) {
     return get_autocast_dtype(device_type);
   } else {
     throw std::runtime_error(
diff --git a/test/test_autocast.py b/test/test_autocast.py
index 2f788b7f65ae6..5054944932887 100644
--- a/test/test_autocast.py
+++ b/test/test_autocast.py
@@ -336,7 +336,7 @@ def test_autocast_fast_dtype(self):
 
     def test_invalid_device(self):
         dev = "not a real device"
-        msg = f"unsupported autocast device_type '{dev}'"
+        msg = f"Invalid device string: '{dev}'"
         with self.assertRaisesRegex(RuntimeError, msg):
             with torch.autocast(device_type=dev):
                 _ = torch.tensor(1)
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index 1b53c1b40f081..c44357c64ed48 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -1263,6 +1263,7 @@ def clear_autocast_cache() -> None: ...
 def set_autocast_cpu_enabled(enabled: _bool) -> None: ...
 def is_autocast_cpu_enabled() -> _bool: ...
 def _is_any_autocast_enabled() -> _bool: ...
+def _is_autocast_available(device_type: str) -> _bool: ...
 def set_autocast_cpu_dtype(dtype: _dtype) -> None: ...
 def set_autocast_gpu_dtype(dtype: _dtype) -> None: ...
 def get_autocast_cpu_dtype() -> _dtype: ...
diff --git a/torch/amp/autocast_mode.py b/torch/amp/autocast_mode.py
index 30c6aefcf1bda..87ff709fcfb2e 100644
--- a/torch/amp/autocast_mode.py
+++ b/torch/amp/autocast_mode.py
@@ -199,35 +199,20 @@ def __init__(
             assert dtype is not None
             return
         self.device = device_type
+        if not torch._C._is_autocast_available(self.device):
+            raise RuntimeError(
+                f"User specified an unsupported autocast device_type '{self.device}'"
+            )
         self.custom_backend_name = torch._C._get_privateuse1_backend_name()
-        if self.device == "cuda":
-            self.fast_dtype = torch.get_autocast_gpu_dtype()
-        elif self.device == "cpu":
-            self.fast_dtype = torch.get_autocast_cpu_dtype()
-        elif self.device == "xpu":
-            self.fast_dtype = torch.xpu.get_autocast_xpu_dtype()  # type: ignore[attr-defined]
-        elif self.device == "ipu":
-            self.fast_dtype = torch.get_autocast_ipu_dtype()  # type: ignore[attr-defined]
-        elif self.device == "hpu":
-            self.fast_dtype = torch.hpu.get_autocast_hpu_dtype()  # type: ignore[attr-defined]
-        elif self.device == "xla":
-            self.fast_dtype = torch.get_autocast_xla_dtype()  # type: ignore[attr-defined]
-        elif self.device == self.custom_backend_name:
+        self.fast_dtype = torch.get_autocast_dtype(self.device)
+        if self.device == self.custom_backend_name:
             necessary_funcs = [
-                "is_autocast_enabled",
-                "set_autocast_enabled",
-                "get_autocast_dtype",
-                "set_autocast_dtype",
                 "get_amp_supported_dtype",
             ]
             message = f"Tried to use AMP with the `{self.custom_backend_name}` backend, but the backend has not "
             message += "registered a module or the module miss some necessary funcs. The backend should register "
             message += "a module by `torch._register_device_module`, and the module must have these funcs: \n"
-            message += "`is_autocast_enabled() -> bool`, `set_autocast_enabled(bool) -> None`, "
-            message += "`get_autocast_dtype() -> torch.dtype`, `set_autocast_dtype(torch.dtype) "
-            message += (
-                "-> None` and `get_amp_supported_dtype() -> List[torch.dtype]`. \n"
-            )
+            message += "`get_amp_supported_dtype() -> List[torch.dtype]`. \n"
 
             assert hasattr(torch, self.custom_backend_name), message
             self.custom_device_mod = getattr(torch, self.custom_backend_name)
@@ -236,11 +221,6 @@ def __init__(
                     message + f"But the func `{func}` is missing. \n"
                 )
 
-            self.fast_dtype = self.custom_device_mod.get_autocast_dtype()
-        else:
-            raise RuntimeError(
-                f"User specified an unsupported autocast device_type '{self.device}'"
-            )
         self._cache_enabled = torch.is_autocast_cache_enabled()
         if (
             enabled
@@ -323,48 +303,11 @@ def __enter__(self):
             return self
 
         self.prev_cache_enabled = torch.is_autocast_cache_enabled()
-        if self.device == "cpu":
-            self.prev = torch.is_autocast_cpu_enabled()
-            self.prev_fastdtype = torch.get_autocast_cpu_dtype()
-            torch.set_autocast_cpu_enabled(self._enabled)
-            torch.set_autocast_cpu_dtype(self.fast_dtype)  # type: ignore[arg-type]
-            torch.autocast_increment_nesting()
-        elif self.device == "xpu":
-            self.prev = torch.xpu.is_autocast_xpu_enabled()  # type: ignore[attr-defined]
-            self.prev_fastdtype = torch.xpu.get_autocast_xpu_dtype()  # type: ignore[attr-defined]
-            torch.xpu.set_autocast_xpu_enabled(self._enabled)  # type: ignore[attr-defined]
-            torch.xpu.set_autocast_xpu_dtype(self.fast_dtype)  # type: ignore[attr-defined]
-            torch.autocast_increment_nesting()
-        elif self.device == "ipu":
-            self.prev = torch.is_autocast_ipu_enabled()  # type: ignore[attr-defined]
-            self.prev_fastdtype = torch.get_autocast_ipu_dtype()  # type: ignore[attr-defined]
-            torch.set_autocast_ipu_enabled(self._enabled)  # type: ignore[attr-defined]
-            torch.set_autocast_ipu_dtype(self.fast_dtype)  # type: ignore[attr-defined]
-            torch.autocast_increment_nesting()
-        elif self.device == "hpu":
-            self.prev = torch.hpu.is_autocast_hpu_enabled()  # type: ignore[attr-defined]
-            self.prev_fastdtype = torch.hpu.get_autocast_hpu_dtype()  # type: ignore[attr-defined]
-            torch.hpu.set_autocast_hpu_enabled(self._enabled)  # type: ignore[attr-defined]
-            torch.hpu.set_autocast_hpu_dtype(self.fast_dtype)  # type: ignore[attr-defined]
-            torch.autocast_increment_nesting()
-        elif self.device == "xla":
-            self.prev = torch.is_autocast_xla_enabled()  # type: ignore[attr-defined]
-            self.prev_fastdtype = torch.get_autocast_xla_dtype()  # type: ignore[attr-defined]
-            torch.set_autocast_xla_enabled(self._enabled)  # type: ignore[attr-defined]
-            torch.set_autocast_xla_dtype(self.fast_dtype)  # type: ignore[attr-defined]
-            torch.autocast_increment_nesting()
-        elif self.device == self.custom_backend_name:
-            self.prev = self.custom_device_mod.is_autocast_enabled()
-            self.prev_fastdtype = self.custom_device_mod.get_autocast_dtype()
-            self.custom_device_mod.set_autocast_enabled(self._enabled)
-            self.custom_device_mod.set_autocast_dtype(self.fast_dtype)
-            torch.autocast_increment_nesting()
-        else:
-            self.prev = torch.is_autocast_enabled()
-            self.prev_fastdtype = torch.get_autocast_gpu_dtype()
-            torch.set_autocast_gpu_dtype(self.fast_dtype)  # type: ignore[arg-type]
-            torch.set_autocast_enabled(self._enabled)
-            torch.autocast_increment_nesting()
+        self.prev = torch.is_autocast_enabled(self.device)
+        self.prev_fastdtype = torch.get_autocast_dtype(self.device)
+        torch.set_autocast_enabled(self.device, self._enabled)
+        torch.set_autocast_dtype(self.device, self.fast_dtype)  # type: ignore[arg-type]
+        torch.autocast_increment_nesting()
         torch.set_autocast_cache_enabled(self._cache_enabled)
 
     def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):  # type: ignore[override]
@@ -372,41 +315,10 @@ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):  # type: ignore[ov
             return
 
         # Drop the cache when we exit to a nesting level that's outside any instance of autocast.
-        if self.device == "cpu":
-            if torch.autocast_decrement_nesting() == 0:
-                torch.clear_autocast_cache()
-            torch.set_autocast_cpu_enabled(self.prev)
-            torch.set_autocast_cpu_dtype(self.prev_fastdtype)
-        elif self.device == "xpu":
-            if torch.autocast_decrement_nesting() == 0:
-                torch.clear_autocast_cache()
-            torch.xpu.set_autocast_xpu_enabled(self.prev)  # type: ignore[attr-defined]
-            torch.xpu.set_autocast_xpu_dtype(self.prev_fastdtype)  # type: ignore[attr-defined]
-        elif self.device == "ipu":
-            if torch.autocast_decrement_nesting() == 0:
-                torch.clear_autocast_cache()
-            torch.set_autocast_ipu_enabled(self.prev)  # type: ignore[attr-defined]
-            torch.set_autocast_ipu_dtype(self.prev_fastdtype)  # type: ignore[attr-defined]
-        elif self.device == "hpu":
-            if torch.autocast_decrement_nesting() == 0:
-                torch.clear_autocast_cache()
-            torch.hpu.set_autocast_hpu_enabled(self.prev)  # type: ignore[attr-defined]
-            torch.hpu.set_autocast_hpu_dtype(self.prev_fastdtype)  # type: ignore[attr-defined]
-        elif self.device == "xla":
-            if torch.autocast_decrement_nesting() == 0:
-                torch.clear_autocast_cache()
-            torch.set_autocast_xla_enabled(self.prev)  # type: ignore[attr-defined]
-            torch.set_autocast_xla_dtype(self.prev_fastdtype)  # type: ignore[attr-defined]
-        elif self.device == self.custom_backend_name:
-            if torch.autocast_decrement_nesting() == 0:
-                torch.clear_autocast_cache()
-            self.custom_device_mod.set_autocast_enabled(self.prev)
-            self.custom_device_mod.set_autocast_dtype(self.prev_fastdtype)
-        else:
-            if torch.autocast_decrement_nesting() == 0:
-                torch.clear_autocast_cache()
-            torch.set_autocast_enabled(self.prev)
-            torch.set_autocast_gpu_dtype(self.prev_fastdtype)
+        if torch.autocast_decrement_nesting() == 0:
+            torch.clear_autocast_cache()
+        torch.set_autocast_enabled(self.device, self.prev)
+        torch.set_autocast_dtype(self.device, self.prev_fastdtype)
         torch.set_autocast_cache_enabled(self.prev_cache_enabled)
         return False
 
diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp
index 5fedfb9be4b03..32a07426db773 100644
--- a/torch/csrc/autograd/init.cpp
+++ b/torch/csrc/autograd/init.cpp
@@ -571,6 +571,24 @@ static PyObject* is_any_autocast_enabled(PyObject* _unused, PyObject* arg) {
   END_HANDLE_TH_ERRORS
 }
 
+static PyObject* is_autocast_available(
+    PyObject* _unused,
+    PyObject* args,
+    PyObject* kwargs) {
+  HANDLE_TH_ERRORS
+  static PythonArgParser parser(
+      {"_is_autocast_available(c10::string_view device_type)"});
+  ParsedArgs<1> parsed_args;
+  auto r = parser.parse(args, kwargs, parsed_args);
+  auto device_type = at::Device(r.string(0)).type();
+  if (at::autocast::is_autocast_available(device_type)) {
+    Py_RETURN_TRUE;
+  } else {
+    Py_RETURN_FALSE;
+  }
+  END_HANDLE_TH_ERRORS
+}
+
 static PyObject* set_autocast_cpu_enabled(PyObject* _unused, PyObject* arg) {
   HANDLE_TH_ERRORS
   TORCH_CHECK_TYPE(
@@ -1228,6 +1246,10 @@ static PyMethodDef methods[] = { // NOLINT
      METH_VARARGS | METH_KEYWORDS,
      nullptr},
     {"_is_any_autocast_enabled", is_any_autocast_enabled, METH_NOARGS, nullptr},
+    {"_is_autocast_available",
+     castPyCFunctionWithKeywords(is_autocast_available),
+     METH_VARARGS | METH_KEYWORDS,
+     nullptr},
     {"clear_autocast_cache", clear_autocast_cache, METH_NOARGS, nullptr},
     {"set_autocast_cpu_enabled", set_autocast_cpu_enabled, METH_O, nullptr},
     {"is_autocast_cpu_enabled", is_autocast_cpu_enabled, METH_NOARGS, nullptr},
diff --git a/torch/utils/backend_registration.py b/torch/utils/backend_registration.py
index d2d2b1cb8977c..1fda089204da1 100644
--- a/torch/utils/backend_registration.py
+++ b/torch/utils/backend_registration.py
@@ -36,20 +36,6 @@ def rename_privateuse1_backend(backend_name: str) -> None:
     (1) ``get_amp_supported_dtype() -> List[torch.dtype]``
         get the supported dtypes on your "foo" device in AMP, maybe the "foo" device supports one more dtype.
 
-    (2) ``is_autocast_enabled() -> bool``
-        check the AMP is enabled or not on your "foo" device.
-
-    (3) ``get_autocast_dtype() -> torch.dtype``
-        get the supported dtype on your "foo" device in AMP, which is set by ``set_autocast_dtype`` or the
-        default dtype, and the default dtype is ``torch.float16``.
-
-    (4) ``set_autocast_enabled(bool) -> None``
-        enable the AMP or not on your "foo" device.
-
-    (5) ``set_autocast_dtype(dtype) -> None``
-        set the supported dtype on your "foo" device in AMP, and the dtype be contained in the dtypes got
-        from ``get_amp_supported_dtype``.
-
     Note(random): If you want to support to set seed for your device, BackendModule needs to have the following API's:
 
     (1) ``_is_in_bad_fork() -> bool``
diff --git a/torch/utils/checkpoint.py b/torch/utils/checkpoint.py
index fc536dd546ce5..ca0e39d53793f 100644
--- a/torch/utils/checkpoint.py
+++ b/torch/utils/checkpoint.py
@@ -194,7 +194,7 @@ def set_device_states(devices, states) -> None:
 
 
 def _get_autocast_kwargs(device="cuda"):
-    if _supports_autocast(device):
+    if torch._C._is_autocast_available(device):
        device_autocast_kwargs = {
            "enabled": torch.is_autocast_enabled(device),
            "dtype": torch.get_autocast_dtype(device),
@@ -211,10 +211,6 @@ def _get_autocast_kwargs(device="cuda"):
 
     return device_autocast_kwargs, cpu_autocast_kwargs
 
-def _supports_autocast(device):
-    device_module = _get_device_module(device)
-    return device == "cuda" or (hasattr(device_module, "is_autocast_enabled")
-                                and hasattr(device_module, "get_autocast_dtype"))
 
 class CheckpointFunction(torch.autograd.Function):
     @staticmethod
@@ -293,7 +289,7 @@ def backward(ctx, *args):
 
             device_autocast_ctx = device_module.amp.autocast(
                 **ctx.device_autocast_kwargs
-            ) if _supports_autocast(ctx.device) else contextlib.nullcontext()
+            ) if torch._C._is_autocast_available(ctx.device) else contextlib.nullcontext()
             with torch.enable_grad(), device_autocast_ctx, \
                  torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
                 outputs = ctx.run_function(*detached_inputs)
@@ -1400,7 +1396,7 @@ def recompute_fn(*inputs):
 
             device_autocast_ctx = device_module.amp.autocast(
                 **device_autocast_kwargs
-            ) if _supports_autocast(device) else contextlib.nullcontext()
+            ) if torch._C._is_autocast_available(device) else contextlib.nullcontext()
             with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), \
                  recompute_context:
                 fn(*args, **kwargs)
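
The net effect of this patch is that the per-backend `is_autocast_*`/`get_autocast_*`/`set_autocast_*` variants collapse into one device-parameterized API, gated by the new `_is_autocast_available` check. Below is a minimal sketch of how the refactored surface behaves, assuming a build that includes this patch; the `"cpu"` device type and the tensor shapes are arbitrary choices for illustration:

```python
import torch

device_type = "cpu"  # any device type the patch treats as autocast-capable

# The new availability check replaces the old per-device if/elif chains.
if torch._C._is_autocast_available(device_type):
    # One generic getter pair now serves every backend.
    enabled_before = torch.is_autocast_enabled(device_type)
    dtype_before = torch.get_autocast_dtype(device_type)

    with torch.autocast(device_type=device_type):
        # Inside the context, eligible ops run in the lower-precision dtype.
        out = torch.randn(4, 4) @ torch.randn(4, 4)

    # __exit__ restores the saved state through the same generic setters.
    assert torch.is_autocast_enabled(device_type) == enabled_before
    assert torch.get_autocast_dtype(device_type) == dtype_before
```

An unknown device string now fails inside `at::Device` parsing (raising `RuntimeError: Invalid device string: ...`) before the autocast-specific check runs, which is why `test_invalid_device` switches its expected message.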