refactor autocast python APIs #124479

Closed
wants to merge 19 commits into from
12 changes: 10 additions & 2 deletions aten/src/ATen/autocast_mode.h
@@ -174,12 +174,20 @@ inline DispatchKey get_autocast_dispatch_key_from_device_type(
}
}

inline at::ScalarType get_lower_precision_fp_from_device_type(
c10::DeviceType device_type) {
inline bool is_autocast_available(c10::DeviceType device_type) {
if (device_type == at::kCPU || device_type == at::kCUDA ||
device_type == at::kXPU || device_type == at::kIPU ||
device_type == at::kHPU || device_type == at::kXLA ||
device_type == at::kPrivateUse1) {
return true;
} else {
return false;
}
}

inline at::ScalarType get_lower_precision_fp_from_device_type(
c10::DeviceType device_type) {
if (is_autocast_available(device_type)) {
return get_autocast_dtype(device_type);
} else {
throw std::runtime_error(
2 changes: 1 addition & 1 deletion test/test_autocast.py
@@ -336,7 +336,7 @@ def test_autocast_fast_dtype(self):

def test_invalid_device(self):
dev = "not a real device"
msg = f"unsupported autocast device_type '{dev}'"
msg = f"Invalid device string: '{dev}'"
with self.assertRaisesRegex(RuntimeError, msg):
with torch.autocast(device_type=dev):
_ = torch.tensor(1)
1 change: 1 addition & 0 deletions torch/_C/__init__.pyi.in
@@ -1263,6 +1263,7 @@ def clear_autocast_cache() -> None: ...
def set_autocast_cpu_enabled(enabled: _bool) -> None: ...
def is_autocast_cpu_enabled() -> _bool: ...
def _is_any_autocast_enabled() -> _bool: ...
def _is_autocast_available(device_type: str) -> _bool: ...
def set_autocast_cpu_dtype(dtype: _dtype) -> None: ...
def set_autocast_gpu_dtype(dtype: _dtype) -> None: ...
def get_autocast_cpu_dtype() -> _dtype: ...
120 changes: 16 additions & 104 deletions torch/amp/autocast_mode.py
@@ -199,35 +199,20 @@ def __init__(
assert dtype is not None
return
self.device = device_type
if not torch._C._is_autocast_available(self.device):
raise RuntimeError(
f"User specified an unsupported autocast device_type '{self.device}'"
)
self.custom_backend_name = torch._C._get_privateuse1_backend_name()
if self.device == "cuda":
self.fast_dtype = torch.get_autocast_gpu_dtype()
elif self.device == "cpu":
self.fast_dtype = torch.get_autocast_cpu_dtype()
elif self.device == "xpu":
self.fast_dtype = torch.xpu.get_autocast_xpu_dtype() # type: ignore[attr-defined]
elif self.device == "ipu":
self.fast_dtype = torch.get_autocast_ipu_dtype() # type: ignore[attr-defined]
elif self.device == "hpu":
self.fast_dtype = torch.hpu.get_autocast_hpu_dtype() # type: ignore[attr-defined]
elif self.device == "xla":
self.fast_dtype = torch.get_autocast_xla_dtype() # type: ignore[attr-defined]
elif self.device == self.custom_backend_name:
self.fast_dtype = torch.get_autocast_dtype(self.device)
if self.device == self.custom_backend_name:
necessary_funcs = [
"is_autocast_enabled",
"set_autocast_enabled",
"get_autocast_dtype",
"set_autocast_dtype",
"get_amp_supported_dtype",
]
message = f"Tried to use AMP with the `{self.custom_backend_name}` backend, but the backend has not "
message += "registered a module or the module miss some necessary funcs. The backend should register "
message += "a module by `torch._register_device_module`, and the module must have these funcs: \n"
message += "`is_autocast_enabled() -> bool`, `set_autocast_enabled(bool) -> None`, "
message += "`get_autocast_dtype() -> torch.dtype`, `set_autocast_dtype(torch.dtype) "
message += (
"-> None` and `get_amp_supported_dtype() -> List[torch.dtype]`. \n"
)
message += "`get_amp_supported_dtype() -> List[torch.dtype]`. \n"

assert hasattr(torch, self.custom_backend_name), message
self.custom_device_mod = getattr(torch, self.custom_backend_name)
@@ -236,11 +221,6 @@ def __init__(
message + f"But the func `{func}` is missing. \n"
)

self.fast_dtype = self.custom_device_mod.get_autocast_dtype()
else:
raise RuntimeError(
f"User specified an unsupported autocast device_type '{self.device}'"
)
self._cache_enabled = torch.is_autocast_cache_enabled()
if (
enabled
@@ -323,90 +303,22 @@ def __enter__(self):
return self

self.prev_cache_enabled = torch.is_autocast_cache_enabled()
if self.device == "cpu":
self.prev = torch.is_autocast_cpu_enabled()
self.prev_fastdtype = torch.get_autocast_cpu_dtype()
torch.set_autocast_cpu_enabled(self._enabled)
torch.set_autocast_cpu_dtype(self.fast_dtype) # type: ignore[arg-type]
torch.autocast_increment_nesting()
elif self.device == "xpu":
self.prev = torch.xpu.is_autocast_xpu_enabled() # type: ignore[attr-defined]
self.prev_fastdtype = torch.xpu.get_autocast_xpu_dtype() # type: ignore[attr-defined]
torch.xpu.set_autocast_xpu_enabled(self._enabled) # type: ignore[attr-defined]
torch.xpu.set_autocast_xpu_dtype(self.fast_dtype) # type: ignore[attr-defined]
torch.autocast_increment_nesting()
elif self.device == "ipu":
self.prev = torch.is_autocast_ipu_enabled() # type: ignore[attr-defined]
self.prev_fastdtype = torch.get_autocast_ipu_dtype() # type: ignore[attr-defined]
torch.set_autocast_ipu_enabled(self._enabled) # type: ignore[attr-defined]
torch.set_autocast_ipu_dtype(self.fast_dtype) # type: ignore[attr-defined]
torch.autocast_increment_nesting()
elif self.device == "hpu":
self.prev = torch.hpu.is_autocast_hpu_enabled() # type: ignore[attr-defined]
self.prev_fastdtype = torch.hpu.get_autocast_hpu_dtype() # type: ignore[attr-defined]
torch.hpu.set_autocast_hpu_enabled(self._enabled) # type: ignore[attr-defined]
torch.hpu.set_autocast_hpu_dtype(self.fast_dtype) # type: ignore[attr-defined]
torch.autocast_increment_nesting()
elif self.device == "xla":
self.prev = torch.is_autocast_xla_enabled() # type: ignore[attr-defined]
self.prev_fastdtype = torch.get_autocast_xla_dtype() # type: ignore[attr-defined]
torch.set_autocast_xla_enabled(self._enabled) # type: ignore[attr-defined]
torch.set_autocast_xla_dtype(self.fast_dtype) # type: ignore[attr-defined]
torch.autocast_increment_nesting()
elif self.device == self.custom_backend_name:
self.prev = self.custom_device_mod.is_autocast_enabled()
self.prev_fastdtype = self.custom_device_mod.get_autocast_dtype()
self.custom_device_mod.set_autocast_enabled(self._enabled)
self.custom_device_mod.set_autocast_dtype(self.fast_dtype)
torch.autocast_increment_nesting()
else:
self.prev = torch.is_autocast_enabled()
self.prev_fastdtype = torch.get_autocast_gpu_dtype()
torch.set_autocast_gpu_dtype(self.fast_dtype) # type: ignore[arg-type]
torch.set_autocast_enabled(self._enabled)
torch.autocast_increment_nesting()
self.prev = torch.is_autocast_enabled(self.device)
self.prev_fastdtype = torch.get_autocast_dtype(self.device)
torch.set_autocast_enabled(self.device, self._enabled)
torch.set_autocast_dtype(self.device, self.fast_dtype) # type: ignore[arg-type]
torch.autocast_increment_nesting()
torch.set_autocast_cache_enabled(self._cache_enabled)

def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any): # type: ignore[override]
if torch._jit_internal.is_scripting():
return

# Drop the cache when we exit to a nesting level that's outside any instance of autocast.
if self.device == "cpu":
if torch.autocast_decrement_nesting() == 0:
torch.clear_autocast_cache()
torch.set_autocast_cpu_enabled(self.prev)
torch.set_autocast_cpu_dtype(self.prev_fastdtype)
elif self.device == "xpu":
if torch.autocast_decrement_nesting() == 0:
torch.clear_autocast_cache()
torch.xpu.set_autocast_xpu_enabled(self.prev) # type: ignore[attr-defined]
torch.xpu.set_autocast_xpu_dtype(self.prev_fastdtype) # type: ignore[attr-defined]
elif self.device == "ipu":
if torch.autocast_decrement_nesting() == 0:
torch.clear_autocast_cache()
torch.set_autocast_ipu_enabled(self.prev) # type: ignore[attr-defined]
torch.set_autocast_ipu_dtype(self.prev_fastdtype) # type: ignore[attr-defined]
elif self.device == "hpu":
if torch.autocast_decrement_nesting() == 0:
torch.clear_autocast_cache()
torch.hpu.set_autocast_hpu_enabled(self.prev) # type: ignore[attr-defined]
torch.hpu.set_autocast_hpu_dtype(self.prev_fastdtype) # type: ignore[attr-defined]
elif self.device == "xla":
if torch.autocast_decrement_nesting() == 0:
torch.clear_autocast_cache()
torch.set_autocast_xla_enabled(self.prev) # type: ignore[attr-defined]
torch.set_autocast_xla_dtype(self.prev_fastdtype) # type: ignore[attr-defined]
elif self.device == self.custom_backend_name:
if torch.autocast_decrement_nesting() == 0:
torch.clear_autocast_cache()
self.custom_device_mod.set_autocast_enabled(self.prev)
self.custom_device_mod.set_autocast_dtype(self.prev_fastdtype)
else:
if torch.autocast_decrement_nesting() == 0:
torch.clear_autocast_cache()
torch.set_autocast_enabled(self.prev)
torch.set_autocast_gpu_dtype(self.prev_fastdtype)
if torch.autocast_decrement_nesting() == 0:
torch.clear_autocast_cache()
torch.set_autocast_enabled(self.device, self.prev)
torch.set_autocast_dtype(self.device, self.prev_fastdtype)
torch.set_autocast_cache_enabled(self.prev_cache_enabled)
return False

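For reference, the refactored context manager above reduces every backend branch to the same device-generic save/set/restore pattern. A minimal sketch of that pattern outside the class (device and dtype chosen for illustration):

import torch

device, dtype = "cpu", torch.bfloat16

# Save the previous autocast state for this device type.
prev_enabled = torch.is_autocast_enabled(device)
prev_dtype = torch.get_autocast_dtype(device)

# Enter: enable autocast for the device and set its fast dtype.
torch.set_autocast_enabled(device, True)
torch.set_autocast_dtype(device, dtype)
torch.autocast_increment_nesting()

# ... run mixed-precision work here ...

# Exit: drop the cache at the outermost nesting level, then restore state.
if torch.autocast_decrement_nesting() == 0:
    torch.clear_autocast_cache()
torch.set_autocast_enabled(device, prev_enabled)
torch.set_autocast_dtype(device, prev_dtype)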
22 changes: 22 additions & 0 deletions torch/csrc/autograd/init.cpp
@@ -571,6 +571,24 @@ static PyObject* is_any_autocast_enabled(PyObject* _unused, PyObject* arg) {
END_HANDLE_TH_ERRORS
}

static PyObject* is_autocast_available(
PyObject* _unused,
PyObject* args,
PyObject* kwargs) {
HANDLE_TH_ERRORS
static PythonArgParser parser(
{"_is_autocast_available(c10::string_view device_type)"});
ParsedArgs<1> parsed_args;
auto r = parser.parse(args, kwargs, parsed_args);
auto device_type = at::Device(r.string(0)).type();
if (at::autocast::is_autocast_available(device_type)) {
Py_RETURN_TRUE;
} else {
Py_RETURN_FALSE;
}
END_HANDLE_TH_ERRORS
}

static PyObject* set_autocast_cpu_enabled(PyObject* _unused, PyObject* arg) {
HANDLE_TH_ERRORS
TORCH_CHECK_TYPE(
@@ -1228,6 +1246,10 @@ static PyMethodDef methods[] = { // NOLINT
METH_VARARGS | METH_KEYWORDS,
nullptr},
{"_is_any_autocast_enabled", is_any_autocast_enabled, METH_NOARGS, nullptr},
{"_is_autocast_available",
castPyCFunctionWithKeywords(is_autocast_available),
METH_VARARGS | METH_KEYWORDS,
nullptr},
{"clear_autocast_cache", clear_autocast_cache, METH_NOARGS, nullptr},
{"set_autocast_cpu_enabled", set_autocast_cpu_enabled, METH_O, nullptr},
{"is_autocast_cpu_enabled", is_autocast_cpu_enabled, METH_NOARGS, nullptr},
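With the binding registered, the check is callable from Python. A minimal usage sketch (device strings chosen for illustration):

import torch

# True for device types with autocast support (cpu, cuda, xpu, ipu, hpu,
# xla, privateuse1), independent of whether the hardware is present.
torch._C._is_autocast_available("cpu")   # True
torch._C._is_autocast_available("cuda")  # True

# A string that is not a valid device fails in the device parser with
# "Invalid device string: ...", which is what test_invalid_device now expects.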
14 changes: 0 additions & 14 deletions torch/utils/backend_registration.py
@@ -36,20 +36,6 @@ def rename_privateuse1_backend(backend_name: str) -> None:
(1) ``get_amp_supported_dtype() -> List[torch.dtype]``
get the supported dtypes on your "foo" device in AMP, maybe the "foo" device supports one more dtype.

(2) ``is_autocast_enabled() -> bool``
check the AMP is enabled or not on your "foo" device.

(3) ``get_autocast_dtype() -> torch.dtype``
get the supported dtype on your "foo" device in AMP, which is set by ``set_autocast_dtype`` or the
default dtype, and the default dtype is ``torch.float16``.

(4) ``set_autocast_enabled(bool) -> None``
enable the AMP or not on your "foo" device.

(5) ``set_autocast_dtype(dtype) -> None``
set the supported dtype on your "foo" device in AMP, and the dtype be contained in the dtypes got
from ``get_amp_supported_dtype``.

Note(random): If you want to support to set seed for your device, BackendModule needs to have the following API's:

(1) ``_is_in_bad_fork() -> bool``
10 changes: 3 additions & 7 deletions torch/utils/checkpoint.py
@@ -194,7 +194,7 @@ def set_device_states(devices, states) -> None:


def _get_autocast_kwargs(device="cuda"):
if _supports_autocast(device):
if torch._C._is_autocast_available(device):
Collaborator:
This API sounds generally useful. Should we expose it as a public API, torch.amp.is_autocast_available()?
This can be a follow-up PR.
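A minimal sketch of what such a public wrapper could look like, assuming it simply forwards to the private binding (the name and location are a suggestion, not part of this PR):

# torch/amp/autocast_mode.py (hypothetical follow-up)
import torch

def is_autocast_available(device_type: str) -> bool:
    r"""Return a bool indicating if autocast is available on :attr:`device_type`."""
    return torch._C._is_autocast_available(device_type)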

Collaborator Author:
OK, thanks. I will add it in a follow-up PR.

Collaborator Author:
In addition, I have a quick question: is it reasonable to add torch.xpu.amp.autocast as an equivalent of torch.autocast('xpu')? PyTorch assumes in many places that each backend has its own autocast, for example:

device_autocast_ctx = device_module.amp.autocast(

If not, xpu will not be supported in the torch.xxx.amp.autocast scenario. Note that torch.cuda.amp.autocast is not equivalent to torch.autocast('cuda'): the former can also handle the JIT path, while torch.autocast(device_type) cannot handle the CUDA/CPU JIT path.

Collaborator:
The answer here is a bit tricky. The very honest answer is:

  • If you really want to do it, I understand you might need TorchScript (TS) support, and I'm not going to block it on principle.
  • But I think most of our users don't care about that; I personally don't know TS well enough and have limited bandwidth to review PRs, so I don't want to sign up to review this.
  • So you would need to find someone to review your TS-related changes. Given that there are zero engineers from the compiler team working on TS, I have no idea who to point you to for help with this.
  • I unfortunately don't have a good solution to recommend from here. If you don't have any users hard-blocked on this, I would suggest dropping it as the simplest thing to do. If you have users that have no option other than TS, you will most likely need to escalate this through the compiler team.

Collaborator Author:
We don't plan to support TorchScript on XPU. We just intend to add torch.xpu.amp.autocast, an xpu-specific wrapper around torch.amp.autocast, to support the torch.xxx.amp.autocast code style in eager mode, like the code below defined in torch/xpu/amp/autocast_mode.py:

class autocast(torch.amp.autocast_mode.autocast):
    r"""
    See :class:`torch.autocast`.
    ``torch.xpu.amp.autocast(args...)`` is equivalent to ``torch.autocast("xpu", args...)``
    """

    def __init__(self, enabled=True, dtype=torch.bfloat16, cache_enabled=True):
        super().__init__(
            "xpu", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled
        )

Collaborator:
I'm not sure what the benefit of such an API is. We are adding the generic API precisely so that users don't have to use one-off APIs that lock their code onto a specific device.

Collaborator:
@albanD
This is a legacy issue, indeed. Before this PR, users were accustomed to using torch.cuda.amp.autocast in their scripts.
Following CUDA, we provided torch.xpu.amp.autocast as a public API in our extension.
Although we have torch.amp.autocast now, it's difficult to ask all users to switch to the new API immediately.

Collaborator:
I'm not sure I follow.
The existing code contains torch.cuda.XXX calls, so user code must be changed to support xpu anyway; adding a new torch.xpu.XXX API doesn't remove the need for user-code changes.
And if we're changing user code, we might as well change it to use the torch.amp.XXX API directly.
Am I missing something here?
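Concretely, the migration I have in mind is something like this (a sketch; model and inputs are placeholders):

# Device-specific call sites today:
with torch.cuda.amp.autocast(dtype=torch.float16):
    out = model(x)

# Device-generic replacement; the same line works for "xpu", "cpu", etc.:
with torch.autocast(device_type="cuda", dtype=torch.float16):
    out = model(x)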

device_autocast_kwargs = {
"enabled": torch.is_autocast_enabled(device),
"dtype": torch.get_autocast_dtype(device),
@@ -211,10 +211,6 @@ def _get_autocast_kwargs(device="cuda"):

return device_autocast_kwargs, cpu_autocast_kwargs

def _supports_autocast(device):
device_module = _get_device_module(device)
return device == "cuda" or (hasattr(device_module, "is_autocast_enabled")
and hasattr(device_module, "get_autocast_dtype"))

class CheckpointFunction(torch.autograd.Function):
@staticmethod
@@ -293,7 +289,7 @@ def backward(ctx, *args):

device_autocast_ctx = device_module.amp.autocast(
**ctx.device_autocast_kwargs
) if _supports_autocast(ctx.device) else contextlib.nullcontext()
) if torch._C._is_autocast_available(ctx.device) else contextlib.nullcontext()
with torch.enable_grad(), device_autocast_ctx, \
torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
outputs = ctx.run_function(*detached_inputs)
@@ -1400,7 +1396,7 @@ def recompute_fn(*inputs):

device_autocast_ctx = device_module.amp.autocast(
**device_autocast_kwargs
) if _supports_autocast(device) else contextlib.nullcontext()
) if torch._C._is_autocast_available(device) else contextlib.nullcontext()
with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), \
recompute_context:
fn(*args, **kwargs)
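Putting the checkpoint changes together: the recompute path now gates autocast on the generic availability check rather than probing the backend module for autocast attributes. A condensed sketch of that pattern (helper names and the surrounding checkpoint machinery are simplified, not the actual implementation):

import contextlib
import torch

def _autocast_state(device: str):
    # Mirrors _get_autocast_kwargs: capture the current autocast state for
    # the given device type using the device-generic getters.
    return {
        "enabled": torch.is_autocast_enabled(device),
        "dtype": torch.get_autocast_dtype(device),
        "cache_enabled": torch.is_autocast_cache_enabled(),
    }

def _recompute(fn, device, device_module, *args):
    kwargs = _autocast_state(device)
    # Only build an autocast context when the device type actually supports
    # autocast; otherwise fall back to a no-op context.
    device_autocast_ctx = (
        device_module.amp.autocast(**kwargs)
        if torch._C._is_autocast_available(device)
        else contextlib.nullcontext()
    )
    with device_autocast_ctx, torch.cpu.amp.autocast(**_autocast_state("cpu")):
        return fn(*args)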