Add Half support for CPU autocast on eager mode #112484

Closed · wants to merge 1 commit
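For context, a minimal sketch of what this change enables (illustrative, not part of the PR itself): CPU autocast in eager mode can now be entered with `dtype=torch.float16` in addition to `torch.bfloat16`. The example assumes `linear` stays on the CPU lower-precision cast list, as the updated `nn_16` test list below suggests.

```python
import torch

# Sketch: with Half support, CPU autocast can run in float16 in eager mode.
x = torch.randn(4, 8)   # float32 inputs
w = torch.randn(16, 8)

with torch.autocast(device_type="cpu", dtype=torch.float16):
    out = torch.nn.functional.linear(x, w)  # cast to float16 under autocast

print(out.dtype)  # expected: torch.float16
```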
62 changes: 48 additions & 14 deletions test/test_autocast.py
@@ -17,7 +17,16 @@ def tearDown(self):
del self.autocast_lists
super().tearDown()

def _run_autocast_outofplace(self, op, args, run_as_type, out_type=None, module=torch, add_kwargs=None):
def _run_autocast_outofplace(
self,
op,
args,
run_as_type,
out_type=None,
module=torch,
add_kwargs=None,
amp_dtype=torch.bfloat16,
):
# helper to cast args
def cast(val, to_type):
if isinstance(val, torch.Tensor):
@@ -31,7 +40,7 @@ def cast(val, to_type):
add_kwargs = {}

self.assertFalse(torch.is_autocast_cpu_enabled())
with torch.cpu.amp.autocast():
with torch.cpu.amp.autocast(dtype=amp_dtype):
self.assertTrue(torch.is_autocast_cpu_enabled())
out_type = out_type if out_type is not None else run_as_type
output = output_method = None
@@ -92,36 +101,61 @@ def args_maybe_kwargs(self, op_with_args):
return op_with_args[0], op_with_args[1], op_with_args[2]

def test_autocast_torch_expect_builtin_promote(self):
for op, args, out_type in self.autocast_lists.torch_expect_builtin_promote:
self._run_autocast_outofplace(op, args, torch.float32, out_type=out_type)
for op, args1, args2, out_type in self.autocast_lists.torch_expect_builtin_promote:
self._run_autocast_outofplace(op, args1, torch.float32, out_type=out_type)
self._run_autocast_outofplace(op, args2, torch.float32, out_type=out_type, amp_dtype=torch.float16)

def test_autocast_methods_expect_builtin_promote(self):
for op, args, out_type in self.autocast_lists.methods_expect_builtin_promote:
self._run_autocast_outofplace(op, args, torch.float32, module=None, out_type=out_type)
for op, args1, args2, out_type in self.autocast_lists.methods_expect_builtin_promote:
self._run_autocast_outofplace(op, args1, torch.float32, module=None, out_type=out_type)
self._run_autocast_outofplace(op, args2, torch.float32, module=None, out_type=out_type, amp_dtype=torch.float16)

def test_autocast_torch_bf16(self):
for op_with_args in self.autocast_lists.torch_bf16:
def test_autocast_torch_16(self):
for op_with_args in self.autocast_lists.torch_16:
op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
self._run_autocast_outofplace(op, args, torch.bfloat16, add_kwargs=maybe_kwargs)
self._run_autocast_outofplace(op, args, torch.float16, add_kwargs=maybe_kwargs, amp_dtype=torch.float16)

def test_autocast_nn_bf16(self):
for op_with_args in self.autocast_lists.nn_bf16:
def test_autocast_nn_16(self):
for op_with_args in self.autocast_lists.nn_16:
op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
self._run_autocast_outofplace(op, args, torch.bfloat16, module=torch._C._nn, add_kwargs=maybe_kwargs)
self._run_autocast_outofplace(
op, args, torch.bfloat16, module=torch._C._nn, add_kwargs=maybe_kwargs
)
self._run_autocast_outofplace(
op,
args,
torch.float16,
module=torch._C._nn,
add_kwargs=maybe_kwargs,
amp_dtype=torch.float16,
)

def test_autocast_torch_fp32(self):
for op_with_args in self.autocast_lists.torch_fp32:
op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
self._run_autocast_outofplace(op, args, torch.float32, add_kwargs=maybe_kwargs)
self._run_autocast_outofplace(op, args, torch.float32, add_kwargs=maybe_kwargs, amp_dtype=torch.float16)

def test_autocast_nn_fp32(self):
for op_with_args in self.autocast_lists.nn_fp32:
op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
self._run_autocast_outofplace(op, args, torch.float32, module=torch._C._nn, add_kwargs=maybe_kwargs)
self._run_autocast_outofplace(
op, args, torch.float32, module=torch._C._nn, add_kwargs=maybe_kwargs
)
self._run_autocast_outofplace(
op,
args,
torch.float32,
module=torch._C._nn,
add_kwargs=maybe_kwargs,
amp_dtype=torch.float16,
)

def test_autocast_torch_need_autocast_promote(self):
for op, args in self.autocast_lists.torch_need_autocast_promote:
self._run_autocast_outofplace(op, args, torch.float32)
for op, args1, args2 in self.autocast_lists.torch_need_autocast_promote:
self._run_autocast_outofplace(op, args1, torch.float32)
self._run_autocast_outofplace(op, args2, torch.float32, amp_dtype=torch.float16)

@unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path")
def test_autocast_rnn(self):
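As a condensed illustration of what the parametrized helper above checks (a sketch, not the actual test code): each op is exercised once under the default bfloat16 autocast and once with `amp_dtype=torch.float16`, and the output dtype must match the expected run type. `conv1d` is taken from the `torch_16` list further down.

```python
import torch

# Condensed sketch of the bf16/fp16 dual run performed by _run_autocast_outofplace.
def out_dtype_under_autocast(op, args, amp_dtype):
    with torch.cpu.amp.autocast(dtype=amp_dtype):
        return getattr(torch, op)(*args).dtype

x = torch.randn(1, 8, 16)  # (N, C_in, L), float32
w = torch.randn(4, 8, 3)   # (C_out, C_in, kernel), float32

assert out_dtype_under_autocast("conv1d", (x, w), torch.bfloat16) == torch.bfloat16
assert out_dtype_under_autocast("conv1d", (x, w), torch.float16) == torch.float16
```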
5 changes: 3 additions & 2 deletions torch/amp/autocast_mode.py
@@ -257,11 +257,12 @@ def __init__(
self._cache_enabled = cache_enabled

if self.device == "cpu":
supported_dtype = [torch.bfloat16]
supported_dtype = [torch.bfloat16, torch.float16]
if self.fast_dtype not in supported_dtype and enabled:
error_message = "In CPU autocast, but the target dtype is not supported. Disabling autocast.\n"
error_message += "CPU Autocast only supports dtype of "
error_message += (
"CPU Autocast only supports dtype of torch.bfloat16 currently."
", ".join(str(dtype) for dtype in supported_dtype) + " currently."
)
warnings.warn(error_message)
enabled = False
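A quick sketch of the guard above (illustrative only): a CPU dtype outside the supported list now produces a warning that names both torch.bfloat16 and torch.float16, and autocast stays disabled inside the block.

```python
import warnings
import torch

# Sketch: an unsupported CPU autocast dtype warns and leaves autocast disabled.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    with torch.autocast(device_type="cpu", dtype=torch.float64):
        print(torch.is_autocast_cpu_enabled())  # False: autocast was disabled

print(any("CPU Autocast only supports" in str(w.message) for w in caught))  # True
```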
48 changes: 26 additions & 22 deletions torch/testing/_internal/autocast_test_lists.py
@@ -244,6 +244,9 @@ def __init__(self, dev):
mat1_bf16 = (torch.randn((n, n), dtype=torch.bfloat16, device=dev),)
mat2_bf16 = (torch.randn((n, n), dtype=torch.bfloat16, device=dev),)

pointwise0_fp16 = (torch.randn(n, dtype=torch.float16, device=dev),)
pointwise1_fp16 = (torch.randn(n, dtype=torch.float16, device=dev),)

dummy_dimsets = ((n,), (n, n), (n, n, n), (n, n, n, n), (n, n, n, n, n))

dummy_bf16 = [(torch.randn(dimset, dtype=torch.bfloat16, device=dev),)
@@ -275,29 +278,30 @@ def __init__(self, dev):
# Some ops implement built-in type promotion. These don't need autocasting,
# but autocasting relies on their promotion, so we include tests to double-check.
self.torch_expect_builtin_promote = [
("eq", pointwise0_fp32 + pointwise1_bf16, torch.bool),
("ge", pointwise0_fp32 + pointwise1_bf16, torch.bool),
("gt", pointwise0_fp32 + pointwise1_bf16, torch.bool),
("le", pointwise0_fp32 + pointwise1_bf16, torch.bool),
("lt", pointwise0_fp32 + pointwise1_bf16, torch.bool),
("ne", pointwise0_fp32 + pointwise1_bf16, torch.bool),
("add", pointwise0_fp32 + pointwise1_bf16, torch.float32),
("div", pointwise0_fp32 + pointwise1_bf16, torch.float32),
("mul", pointwise0_fp32 + pointwise1_bf16, torch.float32),
("eq", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
("ge", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
("gt", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
("le", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
("lt", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
("ne", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
("add", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32),
("div", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32),
("mul", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32),
]

self.methods_expect_builtin_promote = [
("__eq__", pointwise0_fp32 + pointwise1_bf16, torch.bool),
("__ge__", pointwise0_fp32 + pointwise1_bf16, torch.bool),
("__gt__", pointwise0_fp32 + pointwise1_bf16, torch.bool),
("__le__", pointwise0_fp32 + pointwise1_bf16, torch.bool),
("__lt__", pointwise0_fp32 + pointwise1_bf16, torch.bool),
("__ne__", pointwise0_fp32 + pointwise1_bf16, torch.bool),
("__add__", pointwise0_fp32 + pointwise1_bf16, torch.float32),
("__div__", pointwise0_fp32 + pointwise1_bf16, torch.float32),
("__mul__", pointwise0_fp32 + pointwise1_bf16, torch.float32),
("__eq__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
("__ge__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
("__gt__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
("__le__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
("__lt__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
("__ne__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
("__add__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32),
("__div__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32),
("__mul__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32),
]
# The remaining lists organize ops that autocast treats explicitly.
self.torch_bf16 = [
self.torch_16 = [
("conv1d", conv_args_fp32[0]),
("conv2d", conv_args_fp32[1]),
("conv3d", conv_args_fp32[2]),
@@ -337,7 +341,7 @@ def __init__(self, dev):
("triplet_margin_loss", mat0_bf16 + mat1_bf16 + mat2_bf16),
("binary_cross_entropy_with_logits", mat0_bf16 + (torch.rand((n, n), device=dev, dtype=torch.bfloat16),)),
]
self.nn_bf16 = [
self.nn_16 = [
("linear", mat0_fp32 + mat1_fp32, {}),
]
self.nn_fp32 = [
@@ -358,6 +362,6 @@ def __init__(self, dev):
("huber_loss", mat0_bf16 + mat1_bf16),
]
self.torch_need_autocast_promote = [
("cat", (pointwise0_bf16 + pointwise1_fp32,)),
("stack", (pointwise0_bf16 + pointwise1_fp32,)),
("cat", (pointwise0_bf16 + pointwise1_fp32,), (pointwise0_fp16 + pointwise1_fp32,)),
("stack", (pointwise0_bf16 + pointwise1_fp32,), (pointwise0_fp16 + pointwise1_fp32,)),
]
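For what these paired entries exercise, a small sketch (again illustrative, not part of the PR): `cat` with mixed float16/float32 inputs is promoted by autocast to float32, mirroring the new `(pointwise0_fp16 + pointwise1_fp32,)` arguments above.

```python
import torch

# Sketch of the promote path: mixed float16/float32 inputs to cat come out float32.
x16 = torch.randn(8, dtype=torch.float16)
x32 = torch.randn(8)  # float32

with torch.autocast(device_type="cpu", dtype=torch.float16):
    out = torch.cat((x16, x32))

print(out.dtype)  # expected: torch.float32
```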