Skip to content

Commit

Permalink
add new API torch.amp.is_autocast_available
Browse files Browse the repository at this point in the history
ghstack-source-id: 0cf1a3696fe5123bf39b81dd33f0e53e665857ce
Pull Request resolved: #124938
  • Loading branch information
guangyey committed Apr 26, 2024
1 parent e04c7b1 commit 63ea78a
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 6 deletions.
4 changes: 4 additions & 0 deletions docs/source/amp.rst
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ For CUDA and CPU, APIs are also provided separately:

Autocasting
^^^^^^^^^^^
.. currentmodule:: torch.amp.autocast_mode

.. autofunction:: is_autocast_available

.. currentmodule:: torch

.. autoclass:: autocast
Expand Down
2 changes: 2 additions & 0 deletions test/test_autocast.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,8 @@ def test_invalid_device(self):
with self.assertRaisesRegex(RuntimeError, msg):
with torch.autocast(device_type=dev):
_ = torch.tensor(1)
with self.assertRaisesRegex(RuntimeError, msg):
assert torch.amp.is_autocast_available(device_type=dev)


if __name__ == "__main__":
Expand Down
7 changes: 6 additions & 1 deletion torch/amp/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,7 @@
from .autocast_mode import _enter_autocast, _exit_autocast, autocast
from .autocast_mode import (
_enter_autocast,
_exit_autocast,
autocast,
is_autocast_available,
)
from .grad_scaler import GradScaler
16 changes: 14 additions & 2 deletions torch/amp/autocast_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,19 @@
import torch
from torch.types import _dtype

__all__ = ["autocast_decorator", "autocast"]
__all__ = ["autocast_decorator", "autocast", "is_autocast_available"]


def is_autocast_available(device_type: str) -> bool:
r"""
Return a bool indicating if autocast is available on :attr:`device_type`.
Args:
device_type(str): Device type to use. Possible values are: 'cuda', 'cpu', 'xpu' and so on.
The type is the same as the `type` attribute of a :class:`torch.device`.
Thus, you may obtain the device type of a tensor using `Tensor.device.type`.
"""
return torch._C._is_autocast_available(device_type)


def autocast_decorator(autocast_instance, func):
Expand Down Expand Up @@ -199,7 +211,7 @@ def __init__(
assert dtype is not None
return
self.device = device_type
if not torch._C._is_autocast_available(self.device):
if not is_autocast_available(self.device):
raise RuntimeError(
f"User specified an unsupported autocast device_type '{self.device}'"
)
Expand Down
6 changes: 3 additions & 3 deletions torch/utils/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def set_device_states(devices, states) -> None:


def _get_autocast_kwargs(device="cuda"):
if torch._C._is_autocast_available(device):
if torch.amp.is_autocast_available(device):
device_autocast_kwargs = {
"enabled": torch.is_autocast_enabled(device),
"dtype": torch.get_autocast_dtype(device),
Expand Down Expand Up @@ -289,7 +289,7 @@ def backward(ctx, *args):

device_autocast_ctx = device_module.amp.autocast(
**ctx.device_autocast_kwargs
) if torch._C._is_autocast_available(ctx.device) else contextlib.nullcontext()
) if torch.amp.is_autocast_available(ctx.device) else contextlib.nullcontext()
with torch.enable_grad(), device_autocast_ctx, \
torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
outputs = ctx.run_function(*detached_inputs)
Expand Down Expand Up @@ -1396,7 +1396,7 @@ def recompute_fn(*inputs):

device_autocast_ctx = device_module.amp.autocast(
**device_autocast_kwargs
) if torch._C._is_autocast_available(device) else contextlib.nullcontext()
) if torch.amp.is_autocast_available(device) else contextlib.nullcontext()
with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), \
recompute_context:
fn(*args, **kwargs)
Expand Down

0 comments on commit 63ea78a

Please sign in to comment.