diff --git a/test/run_tests.sh b/test/run_tests.sh
index e867dfe3057..a7532cce4bf 100755
--- a/test/run_tests.sh
+++ b/test/run_tests.sh
@@ -169,7 +169,6 @@ function run_xla_op_tests {
   run_test "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY
   run_test_without_functionalization "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY
   run_test "$CDIR/test_async_closures.py"
-  run_test "$CDIR/test_autocast.py"
   run_test "$CDIR/test_xla_dist.py"
   run_test "$CDIR/test_profiler.py"
   run_test "$CDIR/test_ops.py"
diff --git a/test/test_autocast.py b/test/test_autocast.py
index 3c801068df8..e8f41d9519c 100644
--- a/test/test_autocast.py
+++ b/test/test_autocast.py
@@ -15,144 +15,7 @@
 from torch_xla.amp import autocast, GradScaler
 
 
-class AutocastTPUTestLists:
-  # Supplies ops and arguments for TPU autocast tests
-  def __init__(self, dev):
-    super().__init__()
-    n = 8
-    # Utility arguments, created as one-element tuples
-    pointwise0_bf16 = (torch.randn(n, dtype=torch.bfloat16, device=dev),)
-    pointwise1_bf16 = (torch.randn(n, dtype=torch.bfloat16, device=dev),)
-    pointwise2_bf16 = (torch.randn(n, dtype=torch.bfloat16, device=dev),)
-    mat0_bf16 = (torch.randn((n, n), dtype=torch.bfloat16, device=dev),)
-    mat1_bf16 = (torch.randn((n, n), dtype=torch.bfloat16, device=dev),)
-    mat2_bf16 = (torch.randn((n, n), dtype=torch.bfloat16, device=dev),)
-
-    dummy_dimsets = ((n,), (n, n), (n, n, n), (n, n, n, n), (n, n, n, n, n))
-
-    dummy_bf16 = [(torch.randn(dimset, dtype=torch.bfloat16, device=dev),)
-                  for dimset in dummy_dimsets]
-
-    dimsets = ((n, n, n), (n, n, n, n), (n, n, n, n, n))
-    conv_args_bf16 = [(torch.randn(dimset, dtype=torch.bfloat16, device=dev),
-                       torch.randn(dimset, dtype=torch.bfloat16, device=dev))
-                      for dimset in dimsets]
-    conv_args_fp32 = [(torch.randn(dimset, dtype=torch.float32, device=dev),
-                       torch.randn(dimset, dtype=torch.float32, device=dev))
-                      for dimset in dimsets]
-
-    bias_fp32 = (torch.randn((n,), dtype=torch.float32, device=dev),)
-    element0_fp32 = (torch.randn(1, dtype=torch.float32, device=dev),)
-    pointwise0_fp32 = (torch.randn(n, dtype=torch.float32, device=dev),)
-    pointwise1_fp32 = (torch.randn(n, dtype=torch.float32, device=dev),)
-    mat0_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),)
-    mat1_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),)
-    mat2_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),)
-    mat3_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),)
-
-    dummy_fp32 = [(torch.randn(dimset, dtype=torch.float32, device=dev),)
-                  for dimset in dummy_dimsets]
-    # The lists below organize ops that autocast needs to test.
-    # self.list_name corresponds to test_autocast_list_name .
-    # Each op is associated with a tuple of valid arguments.
-
-    # Some ops implement built-in type promotion. These don't need autocasting,
-    # but autocasting relies on their promotion, so we include tests to double-check.
-    self.torch_expect_builtin_promote = [
-        ("eq", pointwise0_fp32 + pointwise1_bf16, torch.bool),
-        ("ge", pointwise0_fp32 + pointwise1_bf16, torch.bool),
-        ("gt", pointwise0_fp32 + pointwise1_bf16, torch.bool),
-        ("le", pointwise0_fp32 + pointwise1_bf16, torch.bool),
-        ("lt", pointwise0_fp32 + pointwise1_bf16, torch.bool),
-        ("ne", pointwise0_fp32 + pointwise1_bf16, torch.bool),
-        ("add", pointwise0_fp32 + pointwise1_bf16, torch.float32),
-        ("div", pointwise0_fp32 + pointwise1_bf16, torch.float32),
-        ("mul", pointwise0_fp32 + pointwise1_bf16, torch.float32),
-    ]
-    self.methods_expect_builtin_promote = [
-        ("__eq__", pointwise0_fp32 + pointwise1_bf16, torch.bool),
-        ("__ge__", pointwise0_fp32 + pointwise1_bf16, torch.bool),
-        ("__gt__", pointwise0_fp32 + pointwise1_bf16, torch.bool),
-        ("__le__", pointwise0_fp32 + pointwise1_bf16, torch.bool),
-        ("__lt__", pointwise0_fp32 + pointwise1_bf16, torch.bool),
-        ("__ne__", pointwise0_fp32 + pointwise1_bf16, torch.bool),
-        ("__add__", pointwise0_fp32 + pointwise1_bf16, torch.float32),
-        ("__div__", pointwise0_fp32 + pointwise1_bf16, torch.float32),
-        ("__mul__", pointwise0_fp32 + pointwise1_bf16, torch.float32),
-    ]
-    # The remaining lists organize ops that autocast treats explicitly.
-    self.torch_bf16 = [
-        ("conv1d", conv_args_fp32[0]),
-        ("conv2d", conv_args_fp32[1]),
-        ("conv3d", conv_args_fp32[2]),
-        ("bmm", (torch.randn((n, n, n), device=dev, dtype=torch.float32),
-                 torch.randn((n, n, n), device=dev, dtype=torch.float32))),
-        ("mm", mat0_fp32 + mat1_fp32),
-        ("matmul", mat0_fp32 + mat1_fp32),
-        ("baddbmm", (torch.randn((n, n, n), device=dev, dtype=torch.float32),
-                     torch.randn((n, n, n), device=dev, dtype=torch.float32),
-                     torch.randn((n, n, n), device=dev, dtype=torch.float32))),
-        ("addmm", mat1_fp32 + mat2_fp32 + mat3_fp32),
-        ("addbmm",
-         mat0_fp32 + (torch.randn((n, n, n), device=dev, dtype=torch.float32),
-                      torch.randn((n, n, n), device=dev, dtype=torch.float32))),
-        ("conv_tbc", (torch.randn((10, 7, 3), device=dev, dtype=torch.float32),
-                      torch.randn((5, 3, 5), device=dev, dtype=torch.float32),
-                      torch.randn(5, device=dev, dtype=torch.float32), 0)),
-        ("conv_transpose1d", conv_args_fp32[0]),
-        ("conv_transpose2d", conv_args_fp32[1]),
-        ("conv_transpose3d", conv_args_fp32[2]),
-        ("prelu", pointwise0_fp32 + element0_fp32),
-        ("relu", pointwise0_fp32 + element0_fp32),
-    ]
-    self.torch_fp32 = [
-        ("cosine_embedding_loss", (torch.tensor([[1, 2, 3]],
-                                                device=dev,
-                                                dtype=torch.bfloat16),
-                                   torch.tensor([[1, 3, 4]],
-                                                device=dev,
-                                                dtype=torch.bfloat16),
-                                   torch.tensor([1],
-                                                device=dev,
-                                                dtype=torch.int))),
-        ("hinge_embedding_loss",
-         mat0_bf16 + (torch.ones(n, device=dev, dtype=torch.int),)),
-        ("margin_ranking_loss", mat0_bf16 + mat1_bf16 + (torch.ones(
-            (n,), device=dev, dtype=torch.bfloat16),)),
-        ("triplet_margin_loss", mat0_bf16 + mat1_bf16 + mat2_bf16),
-        ("binary_cross_entropy_with_logits", mat0_bf16 + (torch.rand(
-            (n, n), device=dev, dtype=torch.bfloat16),)),
-    ]
-    self.nn_bf16 = [
-        ("linear", mat0_fp32 + mat1_fp32, {}),
-    ]
-    self.nn_fp32 = [
-        ("nll_loss", (torch.rand((n, n), device=dev, dtype=torch.float),
-                      torch.zeros((n,), device=dev, dtype=torch.long))),
-        ("nll_loss2d", (torch.rand((n, n, n, n),
-                                   device=dev,
-                                   dtype=torch.bfloat16),
-                        torch.zeros((n, n, n), device=dev, dtype=torch.long))),
-        ("l1_loss", mat0_bf16 + mat1_bf16),
-        ("smooth_l1_loss", mat0_bf16 + mat1_bf16),
-        ("mse_loss", mat0_bf16 + mat1_bf16),
-        ("multilabel_margin_loss", mat0_bf16 + (torch.ones(
-            (n, n), device=dev, dtype=torch.long),)),
("soft_margin_loss", mat0_bf16 + (torch.ones( - (n, n), device=dev, dtype=torch.long),)), - ("multi_margin_loss", mat0_bf16 + (torch.ones( - (n,), device=dev, dtype=torch.long),)), - ] - self.torch_need_autocast_promote = [ - ("cat", (pointwise0_bf16 + pointwise1_fp32,)), - ("stack", (pointwise0_bf16 + pointwise1_fp32,)), - ] - self.methods_fp32 = [] - - self.methods_bf16 = [("__matmul__", mat0_bf16 + mat1_fp32)] - - -class AutocastCudaTestUnsupportedLists(object): +class AutocastTestUnsupportedLists(object): def __init__(self): super().__init__() @@ -167,6 +30,7 @@ def __init__(self): # The remaining lists organize ops that autocast treats explicitly. self.torch_fp16 = [ "_convolution_nogroup", # need lowering + "prelu", # need lowering "addmv", # need lowering ] self.torch_fp32 = [ @@ -187,22 +51,19 @@ class TestAutocastBase(unittest.TestCase): def setUp(self): super(TestAutocastBase, self).setUp() - self.is_autocast_enabled = None - self.autocast_lists = None - self.autocast_unsupported_lists = None + self.autocast_lists = AutocastTestLists(torch.device(xm.xla_device())) + self.autocast_unsupported_lists = AutocastTestUnsupportedLists() + self.skip_list = [] def tearDown(self): del self.autocast_lists super(TestAutocastBase, self).tearDown() def get_autocast_list(self, list_name): - if self.autocast_unsupported_lists: - return [ - tp for tp in getattr(self.autocast_lists, list_name) - if tp[0] not in getattr(self.autocast_unsupported_lists, list_name) - ] - else: - return [tp for tp in getattr(self.autocast_lists, list_name)] + return [ + tp for tp in getattr(self.autocast_lists, list_name) + if tp[0] not in getattr(self.autocast_unsupported_lists, list_name) + ] def args_maybe_kwargs(self, op_with_args): if len(op_with_args) == 2: @@ -229,9 +90,9 @@ def cast(val, to_type): if add_kwargs is None: add_kwargs = {} - self.assertFalse(self.is_autocast_enabled()) - with autocast(xm.xla_device()): - self.assertTrue(self.is_autocast_enabled()) + self.assertFalse(torch.is_autocast_enabled()) + with autocast(): + self.assertTrue(torch.is_autocast_enabled()) out_type = out_type if out_type is not None else run_as_type output = output_method = None @@ -242,7 +103,7 @@ def cast(val, to_type): if isinstance(output, torch.Tensor): self.assertTrue( out_type == output.dtype, - "autocast for {} produced {}, should produce {}".format( + "autocast for torch.{} produced {}, should produce {}".format( op, output.dtype, out_type)) # Try Tensor.* variant: @@ -251,8 +112,8 @@ def cast(val, to_type): if isinstance(output_method, torch.Tensor): self.assertTrue( out_type == output_method.dtype, - "autocast for {} produced {}, should produce torch.{}".format( - op, output_method.dtype, out_type)) + "autocast for torch.{} produced {}, should produce torch.{}". + format(op, output_method.dtype, out_type)) self.assertTrue((output is not None) or ( output_method is not None @@ -280,8 +141,8 @@ def compare(first, second): # Compare numerics to Python-side "autocasting" that (we expect) does the same thing # as the C++-side autocasting, and should be bitwise accurate. 
       output_to_compare = output if output is not None else output_method
-      with autocast(xm.xla_device(), enabled=False):
-        self.assertFalse(self.is_autocast_enabled())
+      with autocast(enabled=False):
+        self.assertFalse(torch.is_autocast_enabled())
 
         if module is not None and hasattr(module, op):
           control = getattr(module, op)(*cast(args, run_as_type), **add_kwargs)
@@ -292,41 +153,11 @@ def compare(first, second):
         comparison = compare(output_to_compare, control)
         self.assertTrue(comparison,
                         "torch.{} result did not match control".format(op))
-      self.assertTrue(self.is_autocast_enabled())
-    self.assertFalse(self.is_autocast_enabled())
-
-
-@unittest.skipIf(not xm.get_xla_supported_devices("GPU"), f"GPU autocast test.")
-class TestAutocastCuda(TestAutocastBase):
-
-  def setUp(self):
-    super(TestAutocastCuda, self).setUp()
-    self.is_autocast_enabled = torch.is_autocast_enabled
-    self.autocast_lists = AutocastTestLists(torch.device(xm.xla_device()))
-    self.autocast_unsupported_lists = AutocastCudaTestUnsupportedLists()
+      self.assertTrue(torch.is_autocast_enabled())
+    self.assertFalse(torch.is_autocast_enabled())
 
-  def test_autocast_nn_fp16(self):
-    with torch.backends.cudnn.flags(enabled=True, deterministic=True):
-      for op, args in self.get_autocast_list('nn_fp16'):
-        self._run_autocast_outofplace(
-            op, args, torch.float16, module=torch._C._nn)
-  def test_autocast_linalg_fp16(self):
-    with torch.backends.cudnn.flags(enabled=True, deterministic=True):
-      for op, args in self.get_autocast_list('linalg_fp16'):
-        self._run_autocast_outofplace(
-            op, args, torch.float16, module=torch._C._linalg)
-
-  def test_autocast_methods_fp16(self):
-    with torch.backends.cudnn.flags(enabled=True, deterministic=True):
-      for op, args in self.get_autocast_list('methods_fp16'):
-        self._run_autocast_outofplace(op, args, torch.float16, module=None)
-
-  def test_autocast_banned(self):
-    with torch.cuda.amp.autocast():
-      for op, args, module in self.get_autocast_list('banned'):
-        with self.assertRaises(RuntimeError):
-          getattr(module, op)(*args)
+class TestAutocast(TestAutocastBase):
 
   def test_autocast_torch_fp32(self):
     for op_with_args in self.get_autocast_list('torch_fp32'):
@@ -343,58 +174,30 @@ def test_autocast_torch_expect_builtin_promote(self):
         'torch_expect_builtin_promote'):
       self._run_autocast_outofplace(op, args, torch.float32, out_type=out_type)
 
+  def test_autocast_nn_fp16(self):
+    with torch.backends.cudnn.flags(enabled=True, deterministic=True):
+      for op, args in self.get_autocast_list('nn_fp16'):
+        self._run_autocast_outofplace(
+            op, args, torch.float16, module=torch._C._nn)
+
   def test_autocast_nn_fp32(self):
     for op, args in self.get_autocast_list('nn_fp32'):
       self._run_autocast_outofplace(
          op, args, torch.float32, module=torch._C._nn)
 
-  def test_autocast_methods_fp32(self):
-    for op, args in self.get_autocast_list('methods_fp32'):
-      print("autocast fp32", op)
-      self._run_autocast_outofplace(op, args, torch.float32, module=None)
-
-  def test_autocast_methods_expect_builtin_promote(self):
-    for op, args, out_type in self.get_autocast_list(
-        'methods_expect_builtin_promote'):
-      self._run_autocast_outofplace(
-          op, args, torch.float32, module=None, out_type=out_type)
-
-
-@unittest.skipIf(not xm.get_xla_supported_devices("TPU"), f"TPU autocast test.")
-class TestAutocastTPU(TestAutocastBase):
-
-  def setUp(self):
-    super(TestAutocastTPU, self).setUp()
-    self.is_autocast_enabled = torch.is_autocast_xla_enabled
-    self.autocast_lists = AutocastTPUTestLists(torch.device(xm.xla_device()))
-
-  def test_autocast_methods_bf16(self):
-    for op, args in self.get_autocast_list('methods_bf16'):
-      self._run_autocast_outofplace(op, args, torch.bfloat16, module=None)
-
-  def test_autocast_torch_fp32(self):
-    for op_with_args in self.get_autocast_list('torch_fp32'):
-      op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
-      self._run_autocast_outofplace(
-          op, args, torch.float32, add_kwargs=maybe_kwargs)
-
-  def test_autocast_torch_need_autocast_promote(self):
-    for op, args in self.get_autocast_list('torch_need_autocast_promote'):
-      self._run_autocast_outofplace(op, args, torch.float32)
-
-  def test_autocast_torch_expect_builtin_promote(self):
-    for op, args, out_type in self.get_autocast_list(
-        'torch_expect_builtin_promote'):
-      self._run_autocast_outofplace(op, args, torch.float32, out_type=out_type)
+  def test_autocast_linalg_fp16(self):
+    with torch.backends.cudnn.flags(enabled=True, deterministic=True):
+      for op, args in self.get_autocast_list('linalg_fp16'):
+        self._run_autocast_outofplace(
+            op, args, torch.float16, module=torch._C._linalg)
 
-  def test_autocast_nn_fp32(self):
-    for op, args in self.get_autocast_list('nn_fp32'):
-      self._run_autocast_outofplace(
-          op, args, torch.float32, module=torch._C._nn)
+  def test_autocast_methods_fp16(self):
+    with torch.backends.cudnn.flags(enabled=True, deterministic=True):
+      for op, args in self.get_autocast_list('methods_fp16'):
+        self._run_autocast_outofplace(op, args, torch.float16, module=None)
 
   def test_autocast_methods_fp32(self):
     for op, args in self.get_autocast_list('methods_fp32'):
-      print("autocast fp32", op)
       self._run_autocast_outofplace(op, args, torch.float32, module=None)
 
   def test_autocast_methods_expect_builtin_promote(self):
@@ -403,6 +206,12 @@ def test_autocast_methods_expect_builtin_promote(self):
       self._run_autocast_outofplace(
           op, args, torch.float32, module=None, out_type=out_type)
 
+  def test_autocast_banned(self):
+    with torch.cuda.amp.autocast():
+      for op, args, module in self.get_autocast_list('banned'):
+        with self.assertRaises(RuntimeError):
+          getattr(module, op)(*args)
+
 
 if __name__ == "__main__":
   test = unittest.main(verbosity=FLAGS.verbosity, exit=False)
diff --git a/test/test_train_mp_imagenet_amp.py b/test/test_train_mp_imagenet_amp.py
index bec112a9378..734d9c693fc 100644
--- a/test/test_train_mp_imagenet_amp.py
+++ b/test/test_train_mp_imagenet_amp.py
@@ -27,6 +27,7 @@
     '--test_only_at_end': {
         'action': 'store_true',
     },
+    # AMP only works with XLA:GPU
     '--amp': {
        'action': 'store_true',
    },
@@ -196,7 +197,6 @@ def train_imagenet():
   torch.manual_seed(42)
 
   device = xm.xla_device()
-  device_hw = xm.xla_device_hw(device)
   model = get_model_property('model_fn')().to(device)
   writer = None
   if xm.is_master_ordinal():
@@ -219,10 +219,7 @@ def train_imagenet():
         summary_writer=writer)
   loss_fn = nn.CrossEntropyLoss()
   if FLAGS.amp:
-    if device_hw == 'TPU':
-      scaler = None
-    elif device_hw == 'GPU':
-      scaler = GradScaler(use_zero_grad=FLAGS.use_zero_grad)
+    scaler = GradScaler(use_zero_grad=FLAGS.use_zero_grad)
 
   def train_loop_fn(loader, epoch):
     tracker = xm.RateTracker()
@@ -230,18 +227,15 @@ def train_loop_fn(loader, epoch):
     for step, (data, target) in enumerate(loader):
       optimizer.zero_grad()
       if FLAGS.amp:
-        with autocast(xm.xla_device()):
+        with autocast():
           output = model(data)
           loss = loss_fn(output, target)
-        if scaler:
-          scaler.scale(loss).backward()
-          gradients = xm._fetch_gradients(optimizer)
-          xm.all_reduce('sum', gradients, scale=1.0 / xm.xrt_world_size())
-          scaler.step(optimizer)
-          scaler.update()
-        else:
-          loss.backward()
-          xm.optimizer_step(optimizer)
+
+        scaler.scale(loss).backward()
+        gradients = xm._fetch_gradients(optimizer)
+        xm.all_reduce('sum', gradients, scale=1.0 / xm.xrt_world_size())
+        scaler.step(optimizer)
+        scaler.update()
       else:
         output = model(data)
         loss = loss_fn(output, target)
diff --git a/test/test_train_mp_mnist_amp.py b/test/test_train_mp_mnist_amp.py
index ae4db118300..2f773f4d032 100644
--- a/test/test_train_mp_mnist_amp.py
+++ b/test/test_train_mp_mnist_amp.py
@@ -66,10 +66,10 @@ def forward(self, x):
     return F.log_softmax(x, dim=1)
 
 
-def _train_update(device, step, loss, tracker, writer):
+def _train_update(device, x, loss, tracker, writer):
   test_utils.print_training_update(
       device,
-      step,
+      x,
       loss.item(),
       tracker.rate(),
       tracker.global_rate(),
@@ -130,42 +130,28 @@ def train_mnist(flags, **kwargs):
   lr = flags.lr * xm.xrt_world_size()
 
   device = xm.xla_device()
-  device_hw = xm.xla_device_hw(device)
   model = MNIST().to(device)
-
   writer = None
   if xm.is_master_ordinal():
     writer = test_utils.get_summary_writer(flags.logdir)
   optim_cls = syncfree.SGD if FLAGS.use_syncfree_optim else optim.SGD
   optimizer = optim_cls(model.parameters(), lr=lr, momentum=flags.momentum)
   loss_fn = nn.NLLLoss()
-
-  if device_hw == 'TPU':
-    scaler = None
-  elif device_hw == 'GPU':
-    # GradScaler only used for GPU
-    scaler = GradScaler(use_zero_grad=FLAGS.use_zero_grad)
-  else:
-    print("Only TPU or GPU supported for AMP.")
-    sys.exit(1)
+  scaler = GradScaler(use_zero_grad=FLAGS.use_zero_grad)
 
   def train_loop_fn(loader):
     tracker = xm.RateTracker()
     model.train()
     for step, (data, target) in enumerate(loader):
       optimizer.zero_grad()
-      with autocast(device):
+      with autocast():
        output = model(data)
        loss = loss_fn(output, target)
-      if scaler:
-        scaler.scale(loss).backward()
-        gradients = xm._fetch_gradients(optimizer)
-        xm.all_reduce('sum', gradients, scale=1.0 / xm.xrt_world_size())
-        scaler.step(optimizer)
-        scaler.update()
-      else:
-        loss.backward()
-        xm.optimizer_step(optimizer)
+      scaler.scale(loss).backward()
+      gradients = xm._fetch_gradients(optimizer)
+      xm.all_reduce('sum', gradients, scale=1.0 / xm.xrt_world_size())
+      scaler.step(optimizer)
+      scaler.update()
       tracker.add(flags.batch_size)
       if step % flags.log_steps == 0:
         xm.add_step_closure(
diff --git a/torch_patches/.torch_pin b/torch_patches/.torch_pin
deleted file mode 100644
index 7e0a6082f13..00000000000
--- a/torch_patches/.torch_pin
+++ /dev/null
@@ -1 +0,0 @@
-#96370
diff --git a/torch_xla/amp/__init__.py b/torch_xla/amp/__init__.py
index 739f55cc0dc..1c0ecd08876 100644
--- a/torch_xla/amp/__init__.py
+++ b/torch_xla/amp/__init__.py
@@ -1,2 +1,2 @@
-from .autocast_mode import autocast  # noqa: F401
+from .autocast_mode import autocast, custom_fwd, custom_bwd  # noqa: F401
 from .grad_scaler import GradScaler  # noqa: F401
diff --git a/torch_xla/amp/autocast_mode.py b/torch_xla/amp/autocast_mode.py
index 999a099b777..2f7a0a83555 100644
--- a/torch_xla/amp/autocast_mode.py
+++ b/torch_xla/amp/autocast_mode.py
@@ -1,40 +1,5 @@
 import torch
-import torch_xla.core.xla_model as xm
-from typing import Any
-
-class autocast(torch.amp.autocast_mode.autocast):
-  r"""
-  See :class:`torch.autocast`.
-  ``torch_xla.amp.autocast(device, args...)`` is equivalent to ``torch.autocast("xla", args...)`` for TPUs
-  ``torch.autocast("cuda", args...)`` for GPUs.
- """ - - def __init__(self, - device, - enabled: bool = True, - dtype: torch.dtype = torch.bfloat16, - cache_enabled: bool = True): - if xm.xla_device_hw(device) == 'GPU': - super().__init__( - "cuda", - enabled=enabled, - dtype=torch.float16, - cache_enabled=cache_enabled) - elif xm.xla_device_hw(device) == 'TPU': - super().__init__( - "xla", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled) - else: - print( - 'Warning: AMP only supported for XLA:TPU and XLA:GPU. Ignoring autocast.' - ) - - def __enter__(self): - return super().__enter__() - - def __exit__(self, exc_type: Any, exc_val: Any, - exc_tb: Any): # type: ignore[override] - return super().__exit__(exc_type, exc_val, exc_tb) - - def __call__(self, func): - return super().__call__(func) +autocast = torch.cuda.amp.autocast +custom_fwd = torch.cuda.amp.custom_fwd +custom_bwd = torch.cuda.amp.custom_bwd diff --git a/torch_xla/csrc/BUILD b/torch_xla/csrc/BUILD index b5b6efb6775..f92bc2b84af 100644 --- a/torch_xla/csrc/BUILD +++ b/torch_xla/csrc/BUILD @@ -34,7 +34,6 @@ ptxla_cc_library( "aten_autograd_ops.cpp", "aten_xla_bridge.cpp", "aten_xla_type.cpp", - "autocast_mode.cpp", "batch_norm.cpp", "convert_ops.cpp", "convolution.cpp", diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index e8c1bb14ef6..854eaaff121 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -437,6 +437,10 @@ void XLANativeFunctions::_amp_foreach_non_finite_check_and_unscale_( at::TensorList self, at::Tensor& found_inf, const at::Tensor& inv_scale) { TORCH_LAZY_FN_COUNTER("xla::"); XLATensorPtr found_inf_tensor = bridge::GetXlaTensor(found_inf); + XlaDeviceType hw_type = + static_cast(found_inf_tensor->GetDevice().type()); + XLA_CHECK(hw_type == XlaDeviceType::GPU || hw_type == XlaDeviceType::CPU) + << "AMP should be used with XLA:GPU"; tensor_methods::_amp_foreach_non_finite_check_and_unscale_( bridge::GetXlaTensors(self), found_inf_tensor, bridge::GetXlaTensor(inv_scale)); @@ -451,6 +455,10 @@ at::Tensor& XLANativeFunctions::_amp_update_scale_(at::Tensor& current_scale, TORCH_LAZY_FN_COUNTER("xla::"); XLATensorPtr growth_tracker_tensor = bridge::GetXlaTensor(growth_tracker); XLATensorPtr current_scale_tensor = bridge::GetXlaTensor(current_scale); + XlaDeviceType hw_type = + static_cast(growth_tracker_tensor->GetDevice().type()); + XLA_CHECK(hw_type == XlaDeviceType::GPU || hw_type == XlaDeviceType::CPU) + << "AMP should be used with XLA:GPU"; tensor_methods::_amp_update_scale_( growth_tracker_tensor, current_scale_tensor, bridge::GetXlaTensor(found_inf), scale_growth_factor, diff --git a/torch_xla/csrc/autocast_mode.cpp b/torch_xla/csrc/autocast_mode.cpp deleted file mode 100644 index 91ff7999e10..00000000000 --- a/torch_xla/csrc/autocast_mode.cpp +++ /dev/null @@ -1,160 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include - -namespace at { -namespace autocast { -namespace { - -#define KERNEL_XLA(OP, POLICY) KERNEL(c10::DeviceType::XLA, OP, POLICY) - -#define KERNEL_XLA2(OP, OVERLOAD, POLICY) \ - KERNEL2(c10::DeviceType::XLA, OP, OVERLOAD, POLICY) - -#define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_XLA( \ - REDISPATCH_FUNC, REGISTER_NAME, REGISTER_SIGNATURE, REDISPATCH_SIGNATURE, \ - POLICY) \ - KERNEL_DIFFERENT_REDISPATCH_SIGNATURE(c10::DeviceType::XLA, REDISPATCH_FUNC, \ - REGISTER_NAME, REGISTER_SIGNATURE, \ - REDISPATCH_SIGNATURE, POLICY) - -TORCH_LIBRARY_IMPL(_, AutocastXLA, m) { - m.fallback(torch::CppFunction::makeFallthrough()); -} - -TORCH_LIBRARY_IMPL(aten, 
-TORCH_LIBRARY_IMPL(aten, AutocastXLA, m) {
-  // lower_precision_fp cast policy
-  KERNEL_XLA(conv1d, lower_precision_fp)
-  KERNEL_XLA2(conv1d, padding, lower_precision_fp)
-  KERNEL_XLA(conv2d, lower_precision_fp)
-  KERNEL_XLA2(conv2d, padding, lower_precision_fp)
-  KERNEL_XLA(conv3d, lower_precision_fp)
-  KERNEL_XLA2(conv3d, padding, lower_precision_fp)
-  KERNEL_XLA(bmm, lower_precision_fp)
-  KERNEL_XLA(mm, lower_precision_fp)
-  KERNEL_XLA(baddbmm, lower_precision_fp)
-  KERNEL_XLA(addmm, lower_precision_fp)
-  KERNEL_XLA(addbmm, lower_precision_fp)
-  KERNEL_XLA(linear, lower_precision_fp)
-  KERNEL_XLA(matmul, lower_precision_fp)
-  KERNEL_XLA(conv_tbc, lower_precision_fp)
-  KERNEL_XLA(conv_transpose1d, lower_precision_fp)
-  KERNEL_XLA2(conv_transpose2d, input, lower_precision_fp)
-  KERNEL_XLA2(conv_transpose3d, input, lower_precision_fp)
-  KERNEL_XLA(prelu, lower_precision_fp)
-  KERNEL_XLA(relu, lower_precision_fp)
-  KERNEL_XLA(max_pool2d, lower_precision_fp)
-
-  // fp32 cast policy
-  // Commented out ops are included in the AutoCastCPU Policy,
-  // but not lowered. Enable if op is lowered.
-  KERNEL_XLA(batch_norm, fp32)
-  KERNEL_XLA2(log_softmax, int, fp32)
-  KERNEL_XLA2(log_softmax, Dimname, fp32)
-  KERNEL_XLA(binary_cross_entropy, fp32)
-  // KERNEL_XLA(grid_sampler, fp32)
-  // KERNEL_XLA(polar, fp32)
-  KERNEL_XLA(prod, fp32)
-  KERNEL_XLA2(prod, dim_int, fp32)
-  KERNEL_XLA2(prod, dim_Dimname, fp32)
-  // KERNEL_XLA(quantile, fp32)
-  // KERNEL_XLA2(quantile, scalar, fp32)
-  // KERNEL_XLA(nanquantile, fp32)
-  // KERNEL_XLA2(nanquantile, scalar, fp32)
-  // KERNEL_XLA(stft, fp32)
-  // KERNEL_XLA2(stft, center, fp32)
-  KERNEL_XLA(cdist, fp32)
-  // KERNEL_XLA(grid_sampler_2d, fp32)
-  // KERNEL_XLA(grid_sampler_3d, fp32)
-  KERNEL_XLA(trace, fp32)
-  // KERNEL_XLA(view_as_complex, fp32)
-  KERNEL_XLA(cholesky, fp32)
-  KERNEL_XLA(cholesky_inverse, fp32)
-  KERNEL_XLA(cholesky_solve, fp32)
-  KERNEL_XLA(inverse, fp32)
-  // KERNEL_XLA(lu_solve, fp32)
-  // KERNEL_XLA(orgqr, fp32)
-  // KERNEL_XLA(ormqr, fp32)
-  // KERNEL_XLA(pinverse, fp32)
-  KERNEL_XLA(reflection_pad1d, fp32)
-  KERNEL_XLA(reflection_pad2d, fp32)
-  KERNEL_XLA(replication_pad1d, fp32)
-  KERNEL_XLA(replication_pad2d, fp32)
-  KERNEL_XLA(replication_pad3d, fp32)
-  KERNEL_XLA(mse_loss, fp32)
-  KERNEL_XLA(cosine_embedding_loss, fp32)
-  KERNEL_XLA(nll_loss, fp32)
-  KERNEL_XLA(nll_loss2d, fp32)
-  KERNEL_XLA(hinge_embedding_loss, fp32)
-  // KERNEL_XLA(poisson_nll_loss, fp32)
-  KERNEL_XLA(smooth_l1_loss, fp32)
-  // KERNEL_XLA(cross_entropy_loss, fp32)
-  KERNEL_XLA(l1_loss, fp32)
-  // KERNEL_XLA(huber_loss, fp32)
-  KERNEL_XLA(margin_ranking_loss, fp32)
-  KERNEL_XLA(soft_margin_loss, fp32)
-  KERNEL_XLA(triplet_margin_loss, fp32)
-  KERNEL_XLA(multi_margin_loss, fp32)
-  KERNEL_XLA2(ctc_loss, IntList, fp32)
-  KERNEL_XLA2(ctc_loss, Tensor, fp32)
-  KERNEL_XLA(kl_div, fp32)
-  KERNEL_XLA(multilabel_margin_loss, fp32)
-  KERNEL_XLA(binary_cross_entropy_with_logits, fp32)
-  // KERNEL_XLA(fft_fft, fp32)
-  // KERNEL_XLA(fft_ifft, fp32)
-  // KERNEL_XLA(fft_fft2, fp32)
-  // KERNEL_XLA(fft_ifft2, fp32)
-  // KERNEL_XLA(fft_fftn, fp32)
-  // KERNEL_XLA(fft_ifftn, fp32)
-  // KERNEL_XLA(fft_rfft, fp32)
-  // KERNEL_XLA(fft_irfft, fp32)
-  // KERNEL_XLA(fft_rfft2, fp32)
-  // KERNEL_XLA(fft_irfft2, fp32)
-  // KERNEL_XLA(fft_rfftn, fp32)
-  // KERNEL_XLA(fft_irfftn, fp32)
-  // KERNEL_XLA(fft_hfft, fp32)
-  // KERNEL_XLA(fft_ihfft, fp32)
-  // KERNEL_XLA(linalg_cond, fp32)
-  // KERNEL_XLA2(linalg_cond, p_str, fp32)
-  // KERNEL_XLA(linalg_matrix_rank, fp32)
-  // KERNEL_XLA2(linalg_matrix_rank, tol_tensor, fp32)
-  // KERNEL_XLA2(linalg_matrix_rank, atol_rtol_tensor, fp32)
-  // KERNEL_XLA2(linalg_matrix_rank, atol_rtol_float, fp32)
-  // KERNEL_XLA(linalg_solve, fp32)
-  // KERNEL_XLA(linalg_cholesky, fp32)
-  // KERNEL_XLA(linalg_svdvals, fp32)
-  // KERNEL_XLA(linalg_eigvals, fp32)
-  // KERNEL_XLA(linalg_eigvalsh, fp32)
-  // KERNEL_XLA(linalg_inv, fp32)
-  // KERNEL_XLA(linalg_householder_product, fp32)
-  // KERNEL_XLA(linalg_tensorinv, fp32)
-  // KERNEL_XLA(linalg_tensorsolve, fp32)
-  // KERNEL_XLA(fake_quantize_per_tensor_affine, fp32)
-  // KERNEL_XLA(geqrf, fp32)
-  // KERNEL_XLA(_lu_with_info, fp32)
-  KERNEL_XLA(qr, fp32)
-  KERNEL_XLA(svd, fp32)
-  KERNEL_XLA(triangular_solve, fp32)
-  KERNEL_XLA(multilabel_margin_loss_forward, fp32)
-  // KERNEL_XLA(linalg_qr, fp32)
-  // KERNEL_XLA(linalg_cholesky_ex, fp32)
-  KERNEL_XLA(linalg_svd, fp32)
-  // KERNEL_XLA(linalg_eig, fp32)
-  // KERNEL_XLA(linalg_eigh, fp32)
-  // KERNEL_XLA(linalg_lstsq, fp32)
-  KERNEL_XLA(linalg_inv_ex, fp32)
-
-  // promote
-  KERNEL_XLA(stack, promote)
-  KERNEL_XLA(cat, promote)
-  KERNEL_XLA(index_copy, promote)
-  KERNEL_XLA2(index_copy, dimname, promote)
-}
-
-}  // namespace
-}  // namespace autocast
-}  // namespace at
diff --git a/torch_xla/csrc/tensor_impl.cpp b/torch_xla/csrc/tensor_impl.cpp
index bacccdeb85b..83f4e19c7d3 100644
--- a/torch_xla/csrc/tensor_impl.cpp
+++ b/torch_xla/csrc/tensor_impl.cpp
@@ -2,7 +2,6 @@
 
 #include
 #include
-#include
 #include
 #include
 #include
@@ -63,15 +62,6 @@ XLATensorImpl::XLATensorImpl(XLATensor&& tensor)
           GetTypeMeta(tensor),
           bridge::XlaDeviceToAtenDevice(tensor.GetDevice())),
       tensor_(c10::make_intrusive<XLATensor>(std::move(tensor))) {
-  // Update the Autocast key based off the backend device.
-  // Upstream TensorImpl cannot differentiate between XLA:TPU and XLA:GPU
-  // so we must manually update Autocast to AutocastCUDA on XLA:GPU.
-  torch::lazy::BackendDevice current_device = GetCurrentDevice();
-  if (static_cast<XlaDeviceType>(current_device.type()) == XlaDeviceType::GPU) {
-    auto autocast_cuda_ks = c10::DispatchKeySet(c10::DispatchKey::AutocastCUDA);
-    auto autocast_xla_ks = c10::DispatchKeySet(c10::DispatchKey::AutocastXLA);
-    key_set_ = (key_set_ - autocast_xla_ks) | autocast_cuda_ks;
-  }
   is_non_overlapping_and_dense_ = false;
   set_custom_sizes_strides(SizesStridesPolicy::CustomSizes);
 }
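
Illustrative usage after this patch (a minimal sketch, not part of the diff): torch_xla.amp.autocast now aliases torch.cuda.amp.autocast, so the AMP step exercised by the modified MNIST and ImageNet test scripts follows the standard GradScaler pattern on XLA:GPU, with an explicit all-reduce of gradients before the optimizer step. The train_step function and the model, loss_fn, optimizer, scaler, data, and target names below are placeholders for illustration only.

import torch
import torch_xla.core.xla_model as xm
from torch_xla.amp import autocast, GradScaler

def train_step(model, loss_fn, optimizer, scaler, data, target):
  # Assumes an XLA:GPU device; this patch removes TPU autocast support.
  optimizer.zero_grad()
  with autocast():  # same object as torch.cuda.amp.autocast after this change
    output = model(data)
    loss = loss_fn(output, target)
  # Scale the loss, backprop, then reduce gradients across replicas
  # before the (unscaled) optimizer step, mirroring the test scripts.
  scaler.scale(loss).backward()
  gradients = xm._fetch_gradients(optimizer)
  xm.all_reduce('sum', gradients, scale=1.0 / xm.xrt_world_size())
  scaler.step(optimizer)
  scaler.update()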