diff --git a/test/test_autocast.py b/test/test_autocast.py
index 256aa627b5802..85998107a0629 100644
--- a/test/test_autocast.py
+++ b/test/test_autocast.py
@@ -17,7 +17,16 @@ def tearDown(self):
         del self.autocast_lists
         super().tearDown()
 
-    def _run_autocast_outofplace(self, op, args, run_as_type, out_type=None, module=torch, add_kwargs=None):
+    def _run_autocast_outofplace(
+        self,
+        op,
+        args,
+        run_as_type,
+        out_type=None,
+        module=torch,
+        add_kwargs=None,
+        amp_dtype=torch.bfloat16,
+    ):
         # helper to cast args
         def cast(val, to_type):
             if isinstance(val, torch.Tensor):
@@ -31,7 +40,7 @@ def cast(val, to_type):
             add_kwargs = {}
 
         self.assertFalse(torch.is_autocast_cpu_enabled())
-        with torch.cpu.amp.autocast():
+        with torch.cpu.amp.autocast(dtype=amp_dtype):
             self.assertTrue(torch.is_autocast_cpu_enabled())
             out_type = out_type if out_type is not None else run_as_type
             output = output_method = None
@@ -92,36 +101,61 @@ def args_maybe_kwargs(self, op_with_args):
         return op_with_args[0], op_with_args[1], op_with_args[2]
 
     def test_autocast_torch_expect_builtin_promote(self):
-        for op, args, out_type in self.autocast_lists.torch_expect_builtin_promote:
-            self._run_autocast_outofplace(op, args, torch.float32, out_type=out_type)
+        for op, args1, args2, out_type in self.autocast_lists.torch_expect_builtin_promote:
+            self._run_autocast_outofplace(op, args1, torch.float32, out_type=out_type)
+            self._run_autocast_outofplace(op, args2, torch.float32, out_type=out_type, amp_dtype=torch.float16)
 
     def test_autocast_methods_expect_builtin_promote(self):
-        for op, args, out_type in self.autocast_lists.methods_expect_builtin_promote:
-            self._run_autocast_outofplace(op, args, torch.float32, module=None, out_type=out_type)
+        for op, args1, args2, out_type in self.autocast_lists.methods_expect_builtin_promote:
+            self._run_autocast_outofplace(op, args1, torch.float32, module=None, out_type=out_type)
+            self._run_autocast_outofplace(op, args2, torch.float32, module=None, out_type=out_type, amp_dtype=torch.float16)
 
-    def test_autocast_torch_bf16(self):
-        for op_with_args in self.autocast_lists.torch_bf16:
+    def test_autocast_torch_16(self):
+        for op_with_args in self.autocast_lists.torch_16:
             op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
             self._run_autocast_outofplace(op, args, torch.bfloat16, add_kwargs=maybe_kwargs)
+            self._run_autocast_outofplace(op, args, torch.float16, add_kwargs=maybe_kwargs, amp_dtype=torch.float16)
 
-    def test_autocast_nn_bf16(self):
-        for op_with_args in self.autocast_lists.nn_bf16:
+    def test_autocast_nn_16(self):
+        for op_with_args in self.autocast_lists.nn_16:
             op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
-            self._run_autocast_outofplace(op, args, torch.bfloat16, module=torch._C._nn, add_kwargs=maybe_kwargs)
+            self._run_autocast_outofplace(
+                op, args, torch.bfloat16, module=torch._C._nn, add_kwargs=maybe_kwargs
+            )
+            self._run_autocast_outofplace(
+                op,
+                args,
+                torch.float16,
+                module=torch._C._nn,
+                add_kwargs=maybe_kwargs,
+                amp_dtype=torch.float16,
+            )
 
     def test_autocast_torch_fp32(self):
         for op_with_args in self.autocast_lists.torch_fp32:
             op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
             self._run_autocast_outofplace(op, args, torch.float32, add_kwargs=maybe_kwargs)
+            self._run_autocast_outofplace(op, args, torch.float32, add_kwargs=maybe_kwargs, amp_dtype=torch.float16)
 
     def test_autocast_nn_fp32(self):
         for op_with_args in self.autocast_lists.nn_fp32:
             op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
-            self._run_autocast_outofplace(op, args, torch.float32, module=torch._C._nn, add_kwargs=maybe_kwargs)
+            self._run_autocast_outofplace(
+                op, args, torch.float32, module=torch._C._nn, add_kwargs=maybe_kwargs
+            )
+            self._run_autocast_outofplace(
+                op,
+                args,
+                torch.float32,
+                module=torch._C._nn,
+                add_kwargs=maybe_kwargs,
+                amp_dtype=torch.float16,
+            )
 
     def test_autocast_torch_need_autocast_promote(self):
-        for op, args in self.autocast_lists.torch_need_autocast_promote:
-            self._run_autocast_outofplace(op, args, torch.float32)
+        for op, args1, args2 in self.autocast_lists.torch_need_autocast_promote:
+            self._run_autocast_outofplace(op, args1, torch.float32)
+            self._run_autocast_outofplace(op, args2, torch.float32, amp_dtype=torch.float16)
 
     @unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path")
     def test_autocast_rnn(self):
diff --git a/torch/amp/autocast_mode.py b/torch/amp/autocast_mode.py
index 8ed2c92b10ec5..30c6aefcf1bda 100644
--- a/torch/amp/autocast_mode.py
+++ b/torch/amp/autocast_mode.py
@@ -257,11 +257,12 @@ def __init__(
         self._cache_enabled = cache_enabled
 
         if self.device == "cpu":
-            supported_dtype = [torch.bfloat16]
+            supported_dtype = [torch.bfloat16, torch.float16]
             if self.fast_dtype not in supported_dtype and enabled:
                 error_message = "In CPU autocast, but the target dtype is not supported. Disabling autocast.\n"
+                error_message += "CPU Autocast only supports dtype of "
                 error_message += (
-                    "CPU Autocast only supports dtype of torch.bfloat16 currently."
+                    ", ".join(str(dtype) for dtype in supported_dtype) + " currently."
                 )
                 warnings.warn(error_message)
                 enabled = False
diff --git a/torch/testing/_internal/autocast_test_lists.py b/torch/testing/_internal/autocast_test_lists.py
index 5eced2f65c734..e6b6dcfc0f40d 100644
--- a/torch/testing/_internal/autocast_test_lists.py
+++ b/torch/testing/_internal/autocast_test_lists.py
@@ -244,6 +244,9 @@ def __init__(self, dev):
         mat1_bf16 = (torch.randn((n, n), dtype=torch.bfloat16, device=dev),)
         mat2_bf16 = (torch.randn((n, n), dtype=torch.bfloat16, device=dev),)
 
+        pointwise0_fp16 = (torch.randn(n, dtype=torch.float16, device=dev),)
+        pointwise1_fp16 = (torch.randn(n, dtype=torch.float16, device=dev),)
+
         dummy_dimsets = ((n,), (n, n), (n, n, n), (n, n, n, n), (n, n, n, n, n))
 
         dummy_bf16 = [(torch.randn(dimset, dtype=torch.bfloat16, device=dev),)
@@ -275,29 +278,30 @@ def __init__(self, dev):
         # Some ops implement built-in type promotion. These don't need autocasting,
         # but autocasting relies on their promotion, so we include tests to double-check.
         self.torch_expect_builtin_promote = [
-            ("eq", pointwise0_fp32 + pointwise1_bf16, torch.bool),
-            ("ge", pointwise0_fp32 + pointwise1_bf16, torch.bool),
-            ("gt", pointwise0_fp32 + pointwise1_bf16, torch.bool),
-            ("le", pointwise0_fp32 + pointwise1_bf16, torch.bool),
-            ("lt", pointwise0_fp32 + pointwise1_bf16, torch.bool),
-            ("ne", pointwise0_fp32 + pointwise1_bf16, torch.bool),
-            ("add", pointwise0_fp32 + pointwise1_bf16, torch.float32),
-            ("div", pointwise0_fp32 + pointwise1_bf16, torch.float32),
-            ("mul", pointwise0_fp32 + pointwise1_bf16, torch.float32),
+            ("eq", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("ge", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("gt", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("le", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("lt", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("ne", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("add", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32),
+            ("div", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32),
+            ("mul", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32),
         ]
+
         self.methods_expect_builtin_promote = [
-            ("__eq__", pointwise0_fp32 + pointwise1_bf16, torch.bool),
-            ("__ge__", pointwise0_fp32 + pointwise1_bf16, torch.bool),
-            ("__gt__", pointwise0_fp32 + pointwise1_bf16, torch.bool),
-            ("__le__", pointwise0_fp32 + pointwise1_bf16, torch.bool),
-            ("__lt__", pointwise0_fp32 + pointwise1_bf16, torch.bool),
-            ("__ne__", pointwise0_fp32 + pointwise1_bf16, torch.bool),
-            ("__add__", pointwise0_fp32 + pointwise1_bf16, torch.float32),
-            ("__div__", pointwise0_fp32 + pointwise1_bf16, torch.float32),
-            ("__mul__", pointwise0_fp32 + pointwise1_bf16, torch.float32),
+            ("__eq__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("__ge__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("__gt__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("__le__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("__lt__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("__ne__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("__add__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32),
+            ("__div__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32),
+            ("__mul__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32),
         ]
         # The remaining lists organize ops that autocast treats explicitly.
-        self.torch_bf16 = [
+        self.torch_16 = [
             ("conv1d", conv_args_fp32[0]),
             ("conv2d", conv_args_fp32[1]),
             ("conv3d", conv_args_fp32[2]),
@@ -337,7 +341,7 @@ def __init__(self, dev):
             ("triplet_margin_loss", mat0_bf16 + mat1_bf16 + mat2_bf16),
             ("binary_cross_entropy_with_logits", mat0_bf16 + (torch.rand((n, n), device=dev, dtype=torch.bfloat16),)),
         ]
-        self.nn_bf16 = [
+        self.nn_16 = [
             ("linear", mat0_fp32 + mat1_fp32, {}),
         ]
         self.nn_fp32 = [
@@ -358,6 +362,6 @@ def __init__(self, dev):
             ("huber_loss", mat0_bf16 + mat1_bf16),
         ]
         self.torch_need_autocast_promote = [
-            ("cat", (pointwise0_bf16 + pointwise1_fp32,)),
-            ("stack", (pointwise0_bf16 + pointwise1_fp32,)),
+            ("cat", (pointwise0_bf16 + pointwise1_fp32,), (pointwise0_fp16 + pointwise1_fp32,)),
+            ("stack", (pointwise0_bf16 + pointwise1_fp32,), (pointwise0_fp16 + pointwise1_fp32,)),
         ]
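Note (not part of the patch): a minimal usage sketch of the CPU float16 autocast path this change enables, for trying the behavior locally. torch.nn.functional.linear is chosen because "linear" is the op covered by the new nn_16 list; the expected output dtype mirrors what the float16 runs of _run_autocast_outofplace assert.

import torch

a = torch.randn(8, 8)  # float32 CPU inputs
b = torch.randn(8, 8)

# Before this patch, dtype=torch.float16 tripped the warning in
# torch/amp/autocast_mode.py and autocast was silently disabled; now it is accepted.
with torch.cpu.amp.autocast(dtype=torch.float16):
    assert torch.is_autocast_cpu_enabled()
    out = torch.nn.functional.linear(a, b)  # lowered-precision op from the nn_16 list

print(out.dtype)  # expected: torch.float16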