refactor autocast python APIs

ghstack-source-id: 0b9883094b4639c3147fe43cc1807e0965cdcd28
Pull Request resolved: #124479

guangyey committed Apr 21, 2024
1 parent 2a0b179 commit 061c456
Showing 5 changed files with 23 additions and 138 deletions.
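In short, the commit collapses the per-device autocast accessors (`torch.get_autocast_gpu_dtype`, `torch.get_autocast_cpu_dtype`, `torch.xpu.get_autocast_xpu_dtype`, and friends) into device-generic functions keyed by a device-type string. A rough before/after sketch of the API surface, inferred from the diff below:

```python
import torch

# Before this commit (per-device accessors, removed below):
#   torch.get_autocast_gpu_dtype()
#   torch.get_autocast_cpu_dtype()
#   torch.is_autocast_cpu_enabled()
#   torch.set_autocast_gpu_dtype(dtype)

# After this commit (one generic accessor per operation):
print(torch.get_autocast_dtype("cuda"))   # torch.float16 by default
print(torch.get_autocast_dtype("cpu"))    # torch.bfloat16 by default
print(torch.is_autocast_enabled("cuda"))  # False outside any autocast region
```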
4 changes: 2 additions & 2 deletions test/test_autocast.py

@@ -269,8 +269,8 @@ def test_cache_disabled(self):

 class TestTorchAutocast(TestCase):
     def test_autocast_fast_dtype(self):
-        gpu_fast_dtype = torch.get_autocast_gpu_dtype()
-        cpu_fast_dtype = torch.get_autocast_cpu_dtype()
+        gpu_fast_dtype = torch.get_autocast_dtype('cuda')
+        cpu_fast_dtype = torch.get_autocast_dtype('cpu')
         self.assertEqual(gpu_fast_dtype, torch.half)
         self.assertEqual(cpu_fast_dtype, torch.bfloat16)

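Beyond the defaults checked in the test above, the generic accessors also track an active region. A small illustrative snippet (not part of the commit) using the post-refactor API:

```python
import torch

# Outside any region, CPU autocast is off and reports its default dtype.
assert not torch.is_autocast_enabled("cpu")
assert torch.get_autocast_dtype("cpu") == torch.bfloat16

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    # Inside the region the generic accessors reflect the requested state.
    assert torch.is_autocast_enabled("cpu")
    assert torch.get_autocast_dtype("cpu") == torch.bfloat16
```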
116 changes: 12 additions & 104 deletions torch/amp/autocast_mode.py

@@ -200,34 +200,15 @@ def __init__(
             return
         self.device = device_type
         self.custom_backend_name = torch._C._get_privateuse1_backend_name()
-        if self.device == "cuda":
-            self.fast_dtype = torch.get_autocast_gpu_dtype()
-        elif self.device == "cpu":
-            self.fast_dtype = torch.get_autocast_cpu_dtype()
-        elif self.device == "xpu":
-            self.fast_dtype = torch.xpu.get_autocast_xpu_dtype()  # type: ignore[attr-defined]
-        elif self.device == "ipu":
-            self.fast_dtype = torch.get_autocast_ipu_dtype()  # type: ignore[attr-defined]
-        elif self.device == "hpu":
-            self.fast_dtype = torch.hpu.get_autocast_hpu_dtype()  # type: ignore[attr-defined]
-        elif self.device == "xla":
-            self.fast_dtype = torch.get_autocast_xla_dtype()  # type: ignore[attr-defined]
-        elif self.device == self.custom_backend_name:
+        self.fast_dtype = torch.get_autocast_dtype(self.device)
+        if self.device == self.custom_backend_name:
             necessary_funcs = [
-                "is_autocast_enabled",
-                "set_autocast_enabled",
-                "get_autocast_dtype",
-                "set_autocast_dtype",
                 "get_amp_supported_dtype",
             ]
             message = f"Tried to use AMP with the `{self.custom_backend_name}` backend, but the backend has not "
             message += "registered a module or the module miss some necessary funcs. The backend should register "
             message += "a module by `torch._register_device_module`, and the module must have these funcs: \n"
-            message += "`is_autocast_enabled() -> bool`, `set_autocast_enabled(bool) -> None`, "
-            message += "`get_autocast_dtype() -> torch.dtype`, `set_autocast_dtype(torch.dtype) "
-            message += (
-                "-> None` and `get_amp_supported_dtype() -> List[torch.dtype]`. \n"
-            )
+            message += "`get_amp_supported_dtype() -> List[torch.dtype]`. \n"

             assert hasattr(torch, self.custom_backend_name), message
             self.custom_device_mod = getattr(torch, self.custom_backend_name)

@@ -236,11 +217,6 @@ def __init__(
                 assert hasattr(self.custom_device_mod, func), (
                     message + f"But the func `{func}` is missing. \n"
                 )

-            self.fast_dtype = self.custom_device_mod.get_autocast_dtype()
-        else:
-            raise RuntimeError(
-                f"User specified an unsupported autocast device_type '{self.device}'"
-            )
         self._cache_enabled = torch.is_autocast_cache_enabled()
         if (
             enabled
@@ -323,90 +299,22 @@ def __enter__(self):
             return self

         self.prev_cache_enabled = torch.is_autocast_cache_enabled()
-        if self.device == "cpu":
-            self.prev = torch.is_autocast_cpu_enabled()
-            self.prev_fastdtype = torch.get_autocast_cpu_dtype()
-            torch.set_autocast_cpu_enabled(self._enabled)
-            torch.set_autocast_cpu_dtype(self.fast_dtype)  # type: ignore[arg-type]
-            torch.autocast_increment_nesting()
-        elif self.device == "xpu":
-            self.prev = torch.xpu.is_autocast_xpu_enabled()  # type: ignore[attr-defined]
-            self.prev_fastdtype = torch.xpu.get_autocast_xpu_dtype()  # type: ignore[attr-defined]
-            torch.xpu.set_autocast_xpu_enabled(self._enabled)  # type: ignore[attr-defined]
-            torch.xpu.set_autocast_xpu_dtype(self.fast_dtype)  # type: ignore[attr-defined]
-            torch.autocast_increment_nesting()
-        elif self.device == "ipu":
-            self.prev = torch.is_autocast_ipu_enabled()  # type: ignore[attr-defined]
-            self.prev_fastdtype = torch.get_autocast_ipu_dtype()  # type: ignore[attr-defined]
-            torch.set_autocast_ipu_enabled(self._enabled)  # type: ignore[attr-defined]
-            torch.set_autocast_ipu_dtype(self.fast_dtype)  # type: ignore[attr-defined]
-            torch.autocast_increment_nesting()
-        elif self.device == "hpu":
-            self.prev = torch.hpu.is_autocast_hpu_enabled()  # type: ignore[attr-defined]
-            self.prev_fastdtype = torch.hpu.get_autocast_hpu_dtype()  # type: ignore[attr-defined]
-            torch.hpu.set_autocast_hpu_enabled(self._enabled)  # type: ignore[attr-defined]
-            torch.hpu.set_autocast_hpu_dtype(self.fast_dtype)  # type: ignore[attr-defined]
-            torch.autocast_increment_nesting()
-        elif self.device == "xla":
-            self.prev = torch.is_autocast_xla_enabled()  # type: ignore[attr-defined]
-            self.prev_fastdtype = torch.get_autocast_xla_dtype()  # type: ignore[attr-defined]
-            torch.set_autocast_xla_enabled(self._enabled)  # type: ignore[attr-defined]
-            torch.set_autocast_xla_dtype(self.fast_dtype)  # type: ignore[attr-defined]
-            torch.autocast_increment_nesting()
-        elif self.device == self.custom_backend_name:
-            self.prev = self.custom_device_mod.is_autocast_enabled()
-            self.prev_fastdtype = self.custom_device_mod.get_autocast_dtype()
-            self.custom_device_mod.set_autocast_enabled(self._enabled)
-            self.custom_device_mod.set_autocast_dtype(self.fast_dtype)
-            torch.autocast_increment_nesting()
-        else:
-            self.prev = torch.is_autocast_enabled()
-            self.prev_fastdtype = torch.get_autocast_gpu_dtype()
-            torch.set_autocast_gpu_dtype(self.fast_dtype)  # type: ignore[arg-type]
-            torch.set_autocast_enabled(self._enabled)
-            torch.autocast_increment_nesting()
+        self.prev = torch.is_autocast_enabled(self.device)
+        self.prev_fastdtype = torch.get_autocast_dtype(self.device)
+        torch.set_autocast_enabled(self.device, self._enabled)
+        torch.set_autocast_dtype(self.device, self.fast_dtype)
+        torch.autocast_increment_nesting()
         torch.set_autocast_cache_enabled(self._cache_enabled)

     def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):  # type: ignore[override]
         if torch._jit_internal.is_scripting():
             return

         # Drop the cache when we exit to a nesting level that's outside any instance of autocast.
-        if self.device == "cpu":
-            if torch.autocast_decrement_nesting() == 0:
-                torch.clear_autocast_cache()
-            torch.set_autocast_cpu_enabled(self.prev)
-            torch.set_autocast_cpu_dtype(self.prev_fastdtype)
-        elif self.device == "xpu":
-            if torch.autocast_decrement_nesting() == 0:
-                torch.clear_autocast_cache()
-            torch.xpu.set_autocast_xpu_enabled(self.prev)  # type: ignore[attr-defined]
-            torch.xpu.set_autocast_xpu_dtype(self.prev_fastdtype)  # type: ignore[attr-defined]
-        elif self.device == "ipu":
-            if torch.autocast_decrement_nesting() == 0:
-                torch.clear_autocast_cache()
-            torch.set_autocast_ipu_enabled(self.prev)  # type: ignore[attr-defined]
-            torch.set_autocast_ipu_dtype(self.prev_fastdtype)  # type: ignore[attr-defined]
-        elif self.device == "hpu":
-            if torch.autocast_decrement_nesting() == 0:
-                torch.clear_autocast_cache()
-            torch.hpu.set_autocast_hpu_enabled(self.prev)  # type: ignore[attr-defined]
-            torch.hpu.set_autocast_hpu_dtype(self.prev_fastdtype)  # type: ignore[attr-defined]
-        elif self.device == "xla":
-            if torch.autocast_decrement_nesting() == 0:
-                torch.clear_autocast_cache()
-            torch.set_autocast_xla_enabled(self.prev)  # type: ignore[attr-defined]
-            torch.set_autocast_xla_dtype(self.prev_fastdtype)  # type: ignore[attr-defined]
-        elif self.device == self.custom_backend_name:
-            if torch.autocast_decrement_nesting() == 0:
-                torch.clear_autocast_cache()
-            self.custom_device_mod.set_autocast_enabled(self.prev)
-            self.custom_device_mod.set_autocast_dtype(self.prev_fastdtype)
-        else:
-            if torch.autocast_decrement_nesting() == 0:
-                torch.clear_autocast_cache()
-            torch.set_autocast_enabled(self.prev)
-            torch.set_autocast_gpu_dtype(self.prev_fastdtype)
+        if torch.autocast_decrement_nesting() == 0:
+            torch.clear_autocast_cache()
+        torch.set_autocast_enabled(self.device, self.prev)
+        torch.set_autocast_dtype(self.device, self.prev_fastdtype)
         torch.set_autocast_cache_enabled(self.prev_cache_enabled)
         return False

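With the per-device branches gone, `__enter__` saves the previous state and `__exit__` restores it through the same four generic calls for every backend. An illustrative consequence (a sketch, not from the commit): nested regions restore the outer state on exit:

```python
import torch

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    with torch.autocast(device_type="cpu", enabled=False):
        # The inner __enter__ saved the outer state before disabling.
        assert not torch.is_autocast_enabled("cpu")
    # The inner __exit__ restored it; the cache is only cleared once
    # autocast_decrement_nesting() reaches the outermost level.
    assert torch.is_autocast_enabled("cpu")
```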
6 changes: 3 additions & 3 deletions torch/cuda/amp/autocast_mode.py

@@ -109,12 +109,12 @@ def custom_fwd(fwd=None, *, cast_inputs=None):

     @functools.wraps(fwd)
     def decorate_fwd(*args, **kwargs):
-        args[0]._dtype = torch.get_autocast_gpu_dtype()
+        args[0]._dtype = torch.get_autocast_dtype('cuda')
         if cast_inputs is None:
-            args[0]._fwd_used_autocast = torch.is_autocast_enabled()
+            args[0]._fwd_used_autocast = torch.is_autocast_enabled("cuda")
             return fwd(*args, **kwargs)
         else:
-            autocast_context = torch.is_autocast_enabled()
+            autocast_context = torch.is_autocast_enabled("cuda")
             args[0]._fwd_used_autocast = False
             if autocast_context:
                 with autocast(enabled=False):
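`decorate_fwd` behaves as before; its state queries simply name the "cuda" device explicitly now. For context, a typical `custom_fwd`/`custom_bwd` pairing in the style of the PyTorch AMP examples (illustrative, not part of this commit):

```python
import torch

class ScaledMM(torch.autograd.Function):
    @staticmethod
    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, a, b):
        # Under autocast, inputs are cast to float32 before fwd runs;
        # decorate_fwd records the ambient "cuda" autocast state on ctx.
        ctx.save_for_backward(a, b)
        return a.mm(b)

    @staticmethod
    @torch.cuda.amp.custom_bwd
    def backward(ctx, grad):
        a, b = ctx.saved_tensors
        return grad.mm(b.t()), a.t().mm(grad)
```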
14 changes: 0 additions & 14 deletions torch/utils/backend_registration.py

@@ -32,20 +32,6 @@ def rename_privateuse1_backend(backend_name: str) -> None:
     (1) ``get_amp_supported_dtype() -> List[torch.dtype]``
         get the supported dtypes on your "foo" device in AMP, maybe the "foo" device supports one more dtype.
-    (2) ``is_autocast_enabled() -> bool``
-        check the AMP is enabled or not on your "foo" device.
-    (3) ``get_autocast_dtype() -> torch.dtype``
-        get the supported dtype on your "foo" device in AMP, which is set by ``set_autocast_dtype`` or the
-        default dtype, and the default dtype is ``torch.float16``.
-    (4) ``set_autocast_enabled(bool) -> None``
-        enable the AMP or not on your "foo" device.
-    (5) ``set_autocast_dtype(dtype) -> None``
-        set the supported dtype on your "foo" device in AMP, and the dtype be contained in the dtypes got
-        from ``get_amp_supported_dtype``.

     Note(random): If you want to support to set seed for your device, BackendModule needs to have the following API's:
     (1) ``_is_in_bad_fork() -> bool``
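After this change, a renamed privateuse1 backend only has to expose `get_amp_supported_dtype`; the enabled/dtype bookkeeping moved into the generic `torch.*` autocast APIs. A hypothetical "foo" backend registration, sketched as a stub:

```python
import types
from typing import List

import torch

# Rename the privateuse1 backend and register a minimal module for it.
torch.utils.rename_privateuse1_backend("foo")
foo = types.ModuleType("torch.foo")

def get_amp_supported_dtype() -> List[torch.dtype]:
    # Dtypes the hypothetical "foo" device would support under AMP.
    return [torch.float16, torch.bfloat16]

foo.get_amp_supported_dtype = get_amp_supported_dtype
torch._register_device_module("foo", foo)
```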
21 changes: 6 additions & 15 deletions torch/utils/checkpoint.py

@@ -194,25 +194,16 @@ def set_device_states(devices, states) -> None:


 def _get_autocast_kwargs(device="cuda"):
-    if device == "cuda":
-        device_autocast_kwargs = {
-            "enabled": torch.is_autocast_enabled(),
-            "dtype": torch.get_autocast_gpu_dtype(),
+    assert device != 'cpu'
+    device_autocast_kwargs = {
+        "enabled": torch.is_autocast_enabled(device),
+        "dtype": torch.get_autocast_dtype(device),
         "cache_enabled": torch.is_autocast_cache_enabled(),
     }
-    elif _supports_autocast(device):
-        device_module = _get_device_module(device)
-        device_autocast_kwargs = {
-            "enabled": device_module.is_autocast_enabled(),
-            "dtype": device_module.get_autocast_dtype(),
-            "cache_enabled": torch.is_autocast_cache_enabled(),
-        }
-    else:
-        device_autocast_kwargs = None

     cpu_autocast_kwargs = {
-        "enabled": torch.is_autocast_cpu_enabled(),
-        "dtype": torch.get_autocast_cpu_dtype(),
+        "enabled": torch.is_autocast_enabled('cpu'),
+        "dtype": torch.get_autocast_dtype('cpu'),
         "cache_enabled": torch.is_autocast_cache_enabled(),
     }

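Since `_get_autocast_kwargs` is private to `torch.utils.checkpoint`, the snippet below is only illustrative; it assumes the function still returns the `(device_autocast_kwargs, cpu_autocast_kwargs)` pair it builds above:

```python
import torch
from torch.utils.checkpoint import _get_autocast_kwargs

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    device_kwargs, cpu_kwargs = _get_autocast_kwargs(device="cuda")
    # Both dicts are now built from the same generic accessors.
    assert cpu_kwargs["enabled"] and cpu_kwargs["dtype"] == torch.bfloat16
    assert device_kwargs["enabled"] is False  # no CUDA autocast active here
```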
