diff --git a/.circleci/common.sh b/.circleci/common.sh index a3715e6032e..f4d840f87d8 100755 --- a/.circleci/common.sh +++ b/.circleci/common.sh @@ -131,8 +131,6 @@ function run_torch_xla_tests() { chmod -R 755 ~/htmlcov else ./test/run_tests.sh - # only run test_autocast for cpu and gpu on circleCI. - python test/test_autocast.py # GPU tests if [ -x "$(command -v nvidia-smi)" ]; then diff --git a/test/run_tests.sh b/test/run_tests.sh index 1346ef91d3e..cefe9711c12 100755 --- a/test/run_tests.sh +++ b/test/run_tests.sh @@ -168,6 +168,7 @@ function run_xla_op_tests { run_test "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY run_test_without_functionalization "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY run_test "$CDIR/test_async_closures.py" + run_test "$CDIR/test_autocast.py" run_test "$CDIR/test_xla_dist.py" run_test "$CDIR/test_profiler.py" run_test "$CDIR/test_ops.py" diff --git a/test/test_autocast.py b/test/test_autocast.py index e8f41d9519c..3c801068df8 100644 --- a/test/test_autocast.py +++ b/test/test_autocast.py @@ -15,7 +15,144 @@ from torch_xla.amp import autocast, GradScaler -class AutocastTestUnsupportedLists(object): +class AutocastTPUTestLists: + # Supplies ops and arguments for TPU autocast tests + def __init__(self, dev): + super().__init__() + n = 8 + # Utility arguments, created as one-element tuples + pointwise0_bf16 = (torch.randn(n, dtype=torch.bfloat16, device=dev),) + pointwise1_bf16 = (torch.randn(n, dtype=torch.bfloat16, device=dev),) + pointwise2_bf16 = (torch.randn(n, dtype=torch.bfloat16, device=dev),) + mat0_bf16 = (torch.randn((n, n), dtype=torch.bfloat16, device=dev),) + mat1_bf16 = (torch.randn((n, n), dtype=torch.bfloat16, device=dev),) + mat2_bf16 = (torch.randn((n, n), dtype=torch.bfloat16, device=dev),) + + dummy_dimsets = ((n,), (n, n), (n, n, n), (n, n, n, n), (n, n, n, n, n)) + + dummy_bf16 = [(torch.randn(dimset, dtype=torch.bfloat16, device=dev),) + for dimset in dummy_dimsets] + + dimsets = ((n, n, n), (n, n, n, n), (n, n, n, n, n)) + conv_args_bf16 = [(torch.randn(dimset, dtype=torch.bfloat16, device=dev), + torch.randn(dimset, dtype=torch.bfloat16, device=dev)) + for dimset in dimsets] + conv_args_fp32 = [(torch.randn(dimset, dtype=torch.float32, device=dev), + torch.randn(dimset, dtype=torch.float32, device=dev)) + for dimset in dimsets] + + bias_fp32 = (torch.randn((n,), dtype=torch.float32, device=dev),) + element0_fp32 = (torch.randn(1, dtype=torch.float32, device=dev),) + pointwise0_fp32 = (torch.randn(n, dtype=torch.float32, device=dev),) + pointwise1_fp32 = (torch.randn(n, dtype=torch.float32, device=dev),) + mat0_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),) + mat1_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),) + mat2_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),) + mat3_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),) + + dummy_fp32 = [(torch.randn(dimset, dtype=torch.float32, device=dev),) + for dimset in dummy_dimsets] + # The lists below organize ops that autocast needs to test. + # self.list_name corresponds to test_autocast_list_name . + # Each op is associated with a tuple of valid arguments. + + # Some ops implement built-in type promotion. These don't need autocasting, + # but autocasting relies on their promotion, so we include tests to double-check. 
+ self.torch_expect_builtin_promote = [ + ("eq", pointwise0_fp32 + pointwise1_bf16, torch.bool), + ("ge", pointwise0_fp32 + pointwise1_bf16, torch.bool), + ("gt", pointwise0_fp32 + pointwise1_bf16, torch.bool), + ("le", pointwise0_fp32 + pointwise1_bf16, torch.bool), + ("lt", pointwise0_fp32 + pointwise1_bf16, torch.bool), + ("ne", pointwise0_fp32 + pointwise1_bf16, torch.bool), + ("add", pointwise0_fp32 + pointwise1_bf16, torch.float32), + ("div", pointwise0_fp32 + pointwise1_bf16, torch.float32), + ("mul", pointwise0_fp32 + pointwise1_bf16, torch.float32), + ] + self.methods_expect_builtin_promote = [ + ("__eq__", pointwise0_fp32 + pointwise1_bf16, torch.bool), + ("__ge__", pointwise0_fp32 + pointwise1_bf16, torch.bool), + ("__gt__", pointwise0_fp32 + pointwise1_bf16, torch.bool), + ("__le__", pointwise0_fp32 + pointwise1_bf16, torch.bool), + ("__lt__", pointwise0_fp32 + pointwise1_bf16, torch.bool), + ("__ne__", pointwise0_fp32 + pointwise1_bf16, torch.bool), + ("__add__", pointwise0_fp32 + pointwise1_bf16, torch.float32), + ("__div__", pointwise0_fp32 + pointwise1_bf16, torch.float32), + ("__mul__", pointwise0_fp32 + pointwise1_bf16, torch.float32), + ] + # The remaining lists organize ops that autocast treats explicitly. + self.torch_bf16 = [ + ("conv1d", conv_args_fp32[0]), + ("conv2d", conv_args_fp32[1]), + ("conv3d", conv_args_fp32[2]), + ("bmm", (torch.randn((n, n, n), device=dev, dtype=torch.float32), + torch.randn((n, n, n), device=dev, dtype=torch.float32))), + ("mm", mat0_fp32 + mat1_fp32), + ("matmul", mat0_fp32 + mat1_fp32), + ("baddbmm", (torch.randn((n, n, n), device=dev, dtype=torch.float32), + torch.randn((n, n, n), device=dev, dtype=torch.float32), + torch.randn((n, n, n), device=dev, dtype=torch.float32))), + ("addmm", mat1_fp32 + mat2_fp32 + mat3_fp32), + ("addbmm", + mat0_fp32 + (torch.randn((n, n, n), device=dev, dtype=torch.float32), + torch.randn((n, n, n), device=dev, dtype=torch.float32))), + ("conv_tbc", (torch.randn((10, 7, 3), device=dev, dtype=torch.float32), + torch.randn((5, 3, 5), device=dev, dtype=torch.float32), + torch.randn(5, device=dev, dtype=torch.float32), 0)), + ("conv_transpose1d", conv_args_fp32[0]), + ("conv_transpose2d", conv_args_fp32[1]), + ("conv_transpose3d", conv_args_fp32[2]), + ("prelu", pointwise0_fp32 + element0_fp32), + ("relu", pointwise0_fp32 + element0_fp32), + ] + self.torch_fp32 = [ + ("cosine_embedding_loss", (torch.tensor([[1, 2, 3]], + device=dev, + dtype=torch.bfloat16), + torch.tensor([[1, 3, 4]], + device=dev, + dtype=torch.bfloat16), + torch.tensor([1], + device=dev, + dtype=torch.int))), + ("hinge_embedding_loss", + mat0_bf16 + (torch.ones(n, device=dev, dtype=torch.int),)), + ("margin_ranking_loss", mat0_bf16 + mat1_bf16 + (torch.ones( + (n,), device=dev, dtype=torch.bfloat16),)), + ("triplet_margin_loss", mat0_bf16 + mat1_bf16 + mat2_bf16), + ("binary_cross_entropy_with_logits", mat0_bf16 + (torch.rand( + (n, n), device=dev, dtype=torch.bfloat16),)), + ] + self.nn_bf16 = [ + ("linear", mat0_fp32 + mat1_fp32, {}), + ] + self.nn_fp32 = [ + ("nll_loss", (torch.rand((n, n), device=dev, dtype=torch.float), + torch.zeros((n,), device=dev, dtype=torch.long))), + ("nll_loss2d", (torch.rand((n, n, n, n), + device=dev, + dtype=torch.bfloat16), + torch.zeros((n, n, n), device=dev, dtype=torch.long))), + ("l1_loss", mat0_bf16 + mat1_bf16), + ("smooth_l1_loss", mat0_bf16 + mat1_bf16), + ("mse_loss", mat0_bf16 + mat1_bf16), + ("multilabel_margin_loss", mat0_bf16 + (torch.ones( + (n, n), device=dev, dtype=torch.long),)), + 
("soft_margin_loss", mat0_bf16 + (torch.ones( + (n, n), device=dev, dtype=torch.long),)), + ("multi_margin_loss", mat0_bf16 + (torch.ones( + (n,), device=dev, dtype=torch.long),)), + ] + self.torch_need_autocast_promote = [ + ("cat", (pointwise0_bf16 + pointwise1_fp32,)), + ("stack", (pointwise0_bf16 + pointwise1_fp32,)), + ] + self.methods_fp32 = [] + + self.methods_bf16 = [("__matmul__", mat0_bf16 + mat1_fp32)] + + +class AutocastCudaTestUnsupportedLists(object): def __init__(self): super().__init__() @@ -30,7 +167,6 @@ def __init__(self): # The remaining lists organize ops that autocast treats explicitly. self.torch_fp16 = [ "_convolution_nogroup", # need lowering - "prelu", # need lowering "addmv", # need lowering ] self.torch_fp32 = [ @@ -51,19 +187,22 @@ class TestAutocastBase(unittest.TestCase): def setUp(self): super(TestAutocastBase, self).setUp() - self.autocast_lists = AutocastTestLists(torch.device(xm.xla_device())) - self.autocast_unsupported_lists = AutocastTestUnsupportedLists() - self.skip_list = [] + self.is_autocast_enabled = None + self.autocast_lists = None + self.autocast_unsupported_lists = None def tearDown(self): del self.autocast_lists super(TestAutocastBase, self).tearDown() def get_autocast_list(self, list_name): - return [ - tp for tp in getattr(self.autocast_lists, list_name) - if tp[0] not in getattr(self.autocast_unsupported_lists, list_name) - ] + if self.autocast_unsupported_lists: + return [ + tp for tp in getattr(self.autocast_lists, list_name) + if tp[0] not in getattr(self.autocast_unsupported_lists, list_name) + ] + else: + return [tp for tp in getattr(self.autocast_lists, list_name)] def args_maybe_kwargs(self, op_with_args): if len(op_with_args) == 2: @@ -90,9 +229,9 @@ def cast(val, to_type): if add_kwargs is None: add_kwargs = {} - self.assertFalse(torch.is_autocast_enabled()) - with autocast(): - self.assertTrue(torch.is_autocast_enabled()) + self.assertFalse(self.is_autocast_enabled()) + with autocast(xm.xla_device()): + self.assertTrue(self.is_autocast_enabled()) out_type = out_type if out_type is not None else run_as_type output = output_method = None @@ -103,7 +242,7 @@ def cast(val, to_type): if isinstance(output, torch.Tensor): self.assertTrue( out_type == output.dtype, - "autocast for torch.{} produced {}, should produce {}".format( + "autocast for {} produced {}, should produce {}".format( op, output.dtype, out_type)) # Try Tensor.* variant: @@ -112,8 +251,8 @@ def cast(val, to_type): if isinstance(output_method, torch.Tensor): self.assertTrue( out_type == output_method.dtype, - "autocast for torch.{} produced {}, should produce torch.{}". - format(op, output_method.dtype, out_type)) + "autocast for {} produced {}, should produce torch.{}".format( + op, output_method.dtype, out_type)) self.assertTrue((output is not None) or ( output_method is not None @@ -141,8 +280,8 @@ def compare(first, second): # Compare numerics to Python-side "autocasting" that (we expect) does the same thing # as the C++-side autocasting, and should be bitwise accurate. 
output_to_compare = output if output is not None else output_method - with autocast(enabled=False): - self.assertFalse(torch.is_autocast_enabled()) + with autocast(xm.xla_device(), enabled=False): + self.assertFalse(self.is_autocast_enabled()) if module is not None and hasattr(module, op): control = getattr(module, op)(*cast(args, run_as_type), **add_kwargs) @@ -153,11 +292,41 @@ def compare(first, second): comparison = compare(output_to_compare, control) self.assertTrue(comparison, "torch.{} result did not match control".format(op)) - self.assertTrue(torch.is_autocast_enabled()) - self.assertFalse(torch.is_autocast_enabled()) + self.assertTrue(self.is_autocast_enabled()) + self.assertFalse(self.is_autocast_enabled()) + + +@unittest.skipIf(not xm.get_xla_supported_devices("GPU"), f"GPU autocast test.") +class TestAutocastCuda(TestAutocastBase): + + def setUp(self): + super(TestAutocastCuda, self).setUp() + self.is_autocast_enabled = torch.is_autocast_enabled + self.autocast_lists = AutocastTestLists(torch.device(xm.xla_device())) + self.autocast_unsupported_lists = AutocastCudaTestUnsupportedLists() + def test_autocast_nn_fp16(self): + with torch.backends.cudnn.flags(enabled=True, deterministic=True): + for op, args in self.get_autocast_list('nn_fp16'): + self._run_autocast_outofplace( + op, args, torch.float16, module=torch._C._nn) -class TestAutocast(TestAutocastBase): + def test_autocast_linalg_fp16(self): + with torch.backends.cudnn.flags(enabled=True, deterministic=True): + for op, args in self.get_autocast_list('linalg_fp16'): + self._run_autocast_outofplace( + op, args, torch.float16, module=torch._C._linalg) + + def test_autocast_methods_fp16(self): + with torch.backends.cudnn.flags(enabled=True, deterministic=True): + for op, args in self.get_autocast_list('methods_fp16'): + self._run_autocast_outofplace(op, args, torch.float16, module=None) + + def test_autocast_banned(self): + with torch.cuda.amp.autocast(): + for op, args, module in self.get_autocast_list('banned'): + with self.assertRaises(RuntimeError): + getattr(module, op)(*args) def test_autocast_torch_fp32(self): for op_with_args in self.get_autocast_list('torch_fp32'): @@ -174,30 +343,58 @@ def test_autocast_torch_expect_builtin_promote(self): 'torch_expect_builtin_promote'): self._run_autocast_outofplace(op, args, torch.float32, out_type=out_type) - def test_autocast_nn_fp16(self): - with torch.backends.cudnn.flags(enabled=True, deterministic=True): - for op, args in self.get_autocast_list('nn_fp16'): - self._run_autocast_outofplace( - op, args, torch.float16, module=torch._C._nn) - def test_autocast_nn_fp32(self): for op, args in self.get_autocast_list('nn_fp32'): self._run_autocast_outofplace( op, args, torch.float32, module=torch._C._nn) - def test_autocast_linalg_fp16(self): - with torch.backends.cudnn.flags(enabled=True, deterministic=True): - for op, args in self.get_autocast_list('linalg_fp16'): - self._run_autocast_outofplace( - op, args, torch.float16, module=torch._C._linalg) + def test_autocast_methods_fp32(self): + for op, args in self.get_autocast_list('methods_fp32'): + print("autocast fp32", op) + self._run_autocast_outofplace(op, args, torch.float32, module=None) - def test_autocast_methods_fp16(self): - with torch.backends.cudnn.flags(enabled=True, deterministic=True): - for op, args in self.get_autocast_list('methods_fp16'): - self._run_autocast_outofplace(op, args, torch.float16, module=None) + def test_autocast_methods_expect_builtin_promote(self): + for op, args, out_type in self.get_autocast_list( 
+ 'methods_expect_builtin_promote'): + self._run_autocast_outofplace( + op, args, torch.float32, module=None, out_type=out_type) + + +@unittest.skipIf(not xm.get_xla_supported_devices("TPU"), f"TPU autocast test.") +class TestAutocastTPU(TestAutocastBase): + + def setUp(self): + super(TestAutocastTPU, self).setUp() + self.is_autocast_enabled = torch.is_autocast_xla_enabled + self.autocast_lists = AutocastTPUTestLists(torch.device(xm.xla_device())) + + def test_autocast_methods_bf16(self): + for op, args in self.get_autocast_list('methods_bf16'): + self._run_autocast_outofplace(op, args, torch.bfloat16, module=None) + + def test_autocast_torch_fp32(self): + for op_with_args in self.get_autocast_list('torch_fp32'): + op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args) + self._run_autocast_outofplace( + op, args, torch.float32, add_kwargs=maybe_kwargs) + + def test_autocast_torch_need_autocast_promote(self): + for op, args in self.get_autocast_list('torch_need_autocast_promote'): + self._run_autocast_outofplace(op, args, torch.float32) + + def test_autocast_torch_expect_builtin_promote(self): + for op, args, out_type in self.get_autocast_list( + 'torch_expect_builtin_promote'): + self._run_autocast_outofplace(op, args, torch.float32, out_type=out_type) + + def test_autocast_nn_fp32(self): + for op, args in self.get_autocast_list('nn_fp32'): + self._run_autocast_outofplace( + op, args, torch.float32, module=torch._C._nn) def test_autocast_methods_fp32(self): for op, args in self.get_autocast_list('methods_fp32'): + print("autocast fp32", op) self._run_autocast_outofplace(op, args, torch.float32, module=None) def test_autocast_methods_expect_builtin_promote(self): @@ -206,12 +403,6 @@ def test_autocast_methods_expect_builtin_promote(self): self._run_autocast_outofplace( op, args, torch.float32, module=None, out_type=out_type) - def test_autocast_banned(self): - with torch.cuda.amp.autocast(): - for op, args, module in self.get_autocast_list('banned'): - with self.assertRaises(RuntimeError): - getattr(module, op)(*args) - if __name__ == "__main__": test = unittest.main(verbosity=FLAGS.verbosity, exit=False) diff --git a/test/test_train_mp_imagenet_amp.py b/test/test_train_mp_imagenet_amp.py index 734d9c693fc..bec112a9378 100644 --- a/test/test_train_mp_imagenet_amp.py +++ b/test/test_train_mp_imagenet_amp.py @@ -27,7 +27,6 @@ '--test_only_at_end': { 'action': 'store_true', }, - # AMP only works with XLA:GPU '--amp': { 'action': 'store_true', }, @@ -197,6 +196,7 @@ def train_imagenet(): torch.manual_seed(42) device = xm.xla_device() + device_hw = xm.xla_device_hw(device) model = get_model_property('model_fn')().to(device) writer = None if xm.is_master_ordinal(): @@ -219,7 +219,10 @@ def train_imagenet(): summary_writer=writer) loss_fn = nn.CrossEntropyLoss() if FLAGS.amp: - scaler = GradScaler(use_zero_grad=FLAGS.use_zero_grad) + if device_hw == 'TPU': + scaler = None + elif device_hw == 'GPU': + scaler = GradScaler(use_zero_grad=FLAGS.use_zero_grad) def train_loop_fn(loader, epoch): tracker = xm.RateTracker() @@ -227,15 +230,18 @@ def train_loop_fn(loader, epoch): for step, (data, target) in enumerate(loader): optimizer.zero_grad() if FLAGS.amp: - with autocast(): + with autocast(xm.xla_device()): output = model(data) loss = loss_fn(output, target) - - scaler.scale(loss).backward() - gradients = xm._fetch_gradients(optimizer) - xm.all_reduce('sum', gradients, scale=1.0 / xm.xrt_world_size()) - scaler.step(optimizer) - scaler.update() + if scaler: + scaler.scale(loss).backward() + 
gradients = xm._fetch_gradients(optimizer) + xm.all_reduce('sum', gradients, scale=1.0 / xm.xrt_world_size()) + scaler.step(optimizer) + scaler.update() + else: + loss.backward() + xm.optimizer_step(optimizer) else: output = model(data) loss = loss_fn(output, target) diff --git a/test/test_train_mp_mnist_amp.py b/test/test_train_mp_mnist_amp.py index 2f773f4d032..ae4db118300 100644 --- a/test/test_train_mp_mnist_amp.py +++ b/test/test_train_mp_mnist_amp.py @@ -66,10 +66,10 @@ def forward(self, x): return F.log_softmax(x, dim=1) -def _train_update(device, x, loss, tracker, writer): +def _train_update(device, step, loss, tracker, writer): test_utils.print_training_update( device, - x, + step, loss.item(), tracker.rate(), tracker.global_rate(), @@ -130,28 +130,42 @@ def train_mnist(flags, **kwargs): lr = flags.lr * xm.xrt_world_size() device = xm.xla_device() + device_hw = xm.xla_device_hw(device) model = MNIST().to(device) + writer = None if xm.is_master_ordinal(): writer = test_utils.get_summary_writer(flags.logdir) optim_cls = syncfree.SGD if FLAGS.use_syncfree_optim else optim.SGD optimizer = optim_cls(model.parameters(), lr=lr, momentum=flags.momentum) loss_fn = nn.NLLLoss() - scaler = GradScaler(use_zero_grad=FLAGS.use_zero_grad) + + if device_hw == 'TPU': + scaler = None + elif device_hw == 'GPU': + # GradScaler only used for GPU + scaler = GradScaler(use_zero_grad=FLAGS.use_zero_grad) + else: + print("Only TPU or GPU supported for AMP.") + sys.exit(1) def train_loop_fn(loader): tracker = xm.RateTracker() model.train() for step, (data, target) in enumerate(loader): optimizer.zero_grad() - with autocast(): + with autocast(device): output = model(data) loss = loss_fn(output, target) - scaler.scale(loss).backward() - gradients = xm._fetch_gradients(optimizer) - xm.all_reduce('sum', gradients, scale=1.0 / xm.xrt_world_size()) - scaler.step(optimizer) - scaler.update() + if scaler: + scaler.scale(loss).backward() + gradients = xm._fetch_gradients(optimizer) + xm.all_reduce('sum', gradients, scale=1.0 / xm.xrt_world_size()) + scaler.step(optimizer) + scaler.update() + else: + loss.backward() + xm.optimizer_step(optimizer) tracker.add(flags.batch_size) if step % flags.log_steps == 0: xm.add_step_closure( diff --git a/torch_patches/.torch_pin b/torch_patches/.torch_pin new file mode 100644 index 00000000000..7e0a6082f13 --- /dev/null +++ b/torch_patches/.torch_pin @@ -0,0 +1 @@ +#96370 diff --git a/torch_xla/amp/__init__.py b/torch_xla/amp/__init__.py index 1c0ecd08876..739f55cc0dc 100644 --- a/torch_xla/amp/__init__.py +++ b/torch_xla/amp/__init__.py @@ -1,2 +1,2 @@ -from .autocast_mode import autocast, custom_fwd, custom_bwd # noqa: F401 +from .autocast_mode import autocast # noqa: F401 from .grad_scaler import GradScaler # noqa: F401 diff --git a/torch_xla/amp/autocast_mode.py b/torch_xla/amp/autocast_mode.py index 2f7a0a83555..999a099b777 100644 --- a/torch_xla/amp/autocast_mode.py +++ b/torch_xla/amp/autocast_mode.py @@ -1,5 +1,40 @@ import torch +import torch_xla.core.xla_model as xm +from typing import Any -autocast = torch.cuda.amp.autocast -custom_fwd = torch.cuda.amp.custom_fwd -custom_bwd = torch.cuda.amp.custom_bwd + +class autocast(torch.amp.autocast_mode.autocast): + r""" + See :class:`torch.autocast`. + ``torch_xla.amp.autocast(device, args...)`` is equivalent to ``torch.autocast("xla", args...)`` for TPUs + ``torch.autocast("cuda", args...)`` for GPUs. 
+ """ + + def __init__(self, + device, + enabled: bool = True, + dtype: torch.dtype = torch.bfloat16, + cache_enabled: bool = True): + if xm.xla_device_hw(device) == 'GPU': + super().__init__( + "cuda", + enabled=enabled, + dtype=torch.float16, + cache_enabled=cache_enabled) + elif xm.xla_device_hw(device) == 'TPU': + super().__init__( + "xla", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled) + else: + print( + 'Warning: AMP only supported for XLA:TPU and XLA:GPU. Ignoring autocast.' + ) + + def __enter__(self): + return super().__enter__() + + def __exit__(self, exc_type: Any, exc_val: Any, + exc_tb: Any): # type: ignore[override] + return super().__exit__(exc_type, exc_val, exc_tb) + + def __call__(self, func): + return super().__call__(func) diff --git a/torch_xla/csrc/BUILD b/torch_xla/csrc/BUILD index 7a016153790..b839acc5692 100644 --- a/torch_xla/csrc/BUILD +++ b/torch_xla/csrc/BUILD @@ -34,6 +34,7 @@ ptxla_cc_library( "aten_autograd_ops.cpp", "aten_xla_bridge.cpp", "aten_xla_type.cpp", + "autocast_mode.cpp", "batch_norm.cpp", "convert_ops.cpp", "convolution.cpp", diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index 854eaaff121..e8c1bb14ef6 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -437,10 +437,6 @@ void XLANativeFunctions::_amp_foreach_non_finite_check_and_unscale_( at::TensorList self, at::Tensor& found_inf, const at::Tensor& inv_scale) { TORCH_LAZY_FN_COUNTER("xla::"); XLATensorPtr found_inf_tensor = bridge::GetXlaTensor(found_inf); - XlaDeviceType hw_type = - static_cast(found_inf_tensor->GetDevice().type()); - XLA_CHECK(hw_type == XlaDeviceType::GPU || hw_type == XlaDeviceType::CPU) - << "AMP should be used with XLA:GPU"; tensor_methods::_amp_foreach_non_finite_check_and_unscale_( bridge::GetXlaTensors(self), found_inf_tensor, bridge::GetXlaTensor(inv_scale)); @@ -455,10 +451,6 @@ at::Tensor& XLANativeFunctions::_amp_update_scale_(at::Tensor& current_scale, TORCH_LAZY_FN_COUNTER("xla::"); XLATensorPtr growth_tracker_tensor = bridge::GetXlaTensor(growth_tracker); XLATensorPtr current_scale_tensor = bridge::GetXlaTensor(current_scale); - XlaDeviceType hw_type = - static_cast(growth_tracker_tensor->GetDevice().type()); - XLA_CHECK(hw_type == XlaDeviceType::GPU || hw_type == XlaDeviceType::CPU) - << "AMP should be used with XLA:GPU"; tensor_methods::_amp_update_scale_( growth_tracker_tensor, current_scale_tensor, bridge::GetXlaTensor(found_inf), scale_growth_factor, diff --git a/torch_xla/csrc/autocast_mode.cpp b/torch_xla/csrc/autocast_mode.cpp new file mode 100644 index 00000000000..91ff7999e10 --- /dev/null +++ b/torch_xla/csrc/autocast_mode.cpp @@ -0,0 +1,160 @@ +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace autocast { +namespace { + +#define KERNEL_XLA(OP, POLICY) KERNEL(c10::DeviceType::XLA, OP, POLICY) + +#define KERNEL_XLA2(OP, OVERLOAD, POLICY) \ + KERNEL2(c10::DeviceType::XLA, OP, OVERLOAD, POLICY) + +#define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_XLA( \ + REDISPATCH_FUNC, REGISTER_NAME, REGISTER_SIGNATURE, REDISPATCH_SIGNATURE, \ + POLICY) \ + KERNEL_DIFFERENT_REDISPATCH_SIGNATURE(c10::DeviceType::XLA, REDISPATCH_FUNC, \ + REGISTER_NAME, REGISTER_SIGNATURE, \ + REDISPATCH_SIGNATURE, POLICY) + +TORCH_LIBRARY_IMPL(_, AutocastXLA, m) { + m.fallback(torch::CppFunction::makeFallthrough()); +} + +TORCH_LIBRARY_IMPL(aten, AutocastXLA, m) { + // lower_precision_fp cast policy + KERNEL_XLA(conv1d, lower_precision_fp) + KERNEL_XLA2(conv1d, 
padding, lower_precision_fp) + KERNEL_XLA(conv2d, lower_precision_fp) + KERNEL_XLA2(conv2d, padding, lower_precision_fp) + KERNEL_XLA(conv3d, lower_precision_fp) + KERNEL_XLA2(conv3d, padding, lower_precision_fp) + KERNEL_XLA(bmm, lower_precision_fp) + KERNEL_XLA(mm, lower_precision_fp) + KERNEL_XLA(baddbmm, lower_precision_fp) + KERNEL_XLA(addmm, lower_precision_fp) + KERNEL_XLA(addbmm, lower_precision_fp) + KERNEL_XLA(linear, lower_precision_fp) + KERNEL_XLA(matmul, lower_precision_fp) + KERNEL_XLA(conv_tbc, lower_precision_fp) + KERNEL_XLA(conv_transpose1d, lower_precision_fp) + KERNEL_XLA2(conv_transpose2d, input, lower_precision_fp) + KERNEL_XLA2(conv_transpose3d, input, lower_precision_fp) + KERNEL_XLA(prelu, lower_precision_fp) + KERNEL_XLA(relu, lower_precision_fp) + KERNEL_XLA(max_pool2d, lower_precision_fp) + + // fp32 cast policy + // Commented out ops are included in the AutoCastCPU Policy, + // but not lowered. Enable if op is lowered. + KERNEL_XLA(batch_norm, fp32) + KERNEL_XLA2(log_softmax, int, fp32) + KERNEL_XLA2(log_softmax, Dimname, fp32) + KERNEL_XLA(binary_cross_entropy, fp32) + // KERNEL_XLA(grid_sampler, fp32) + // KERNEL_XLA(polar, fp32) + KERNEL_XLA(prod, fp32) + KERNEL_XLA2(prod, dim_int, fp32) + KERNEL_XLA2(prod, dim_Dimname, fp32) + // KERNEL_XLA(quantile, fp32) + // KERNEL_XLA2(quantile, scalar, fp32) + // KERNEL_XLA(nanquantile, fp32) + // KERNEL_XLA2(nanquantile, scalar, fp32) + // KERNEL_XLA(stft, fp32) + // KERNEL_XLA2(stft, center, fp32) + KERNEL_XLA(cdist, fp32) + // KERNEL_XLA(grid_sampler_2d, fp32) + // KERNEL_XLA(grid_sampler_3d, fp32) + KERNEL_XLA(trace, fp32) + // KERNEL_XLA(view_as_complex, fp32) + KERNEL_XLA(cholesky, fp32) + KERNEL_XLA(cholesky_inverse, fp32) + KERNEL_XLA(cholesky_solve, fp32) + KERNEL_XLA(inverse, fp32) + // KERNEL_XLA(lu_solve, fp32) + // KERNEL_XLA(orgqr, fp32) + // KERNEL_XLA(ormqr, fp32) + // KERNEL_XLA(pinverse, fp32) + KERNEL_XLA(reflection_pad1d, fp32) + KERNEL_XLA(reflection_pad2d, fp32) + KERNEL_XLA(replication_pad1d, fp32) + KERNEL_XLA(replication_pad2d, fp32) + KERNEL_XLA(replication_pad3d, fp32) + KERNEL_XLA(mse_loss, fp32) + KERNEL_XLA(cosine_embedding_loss, fp32) + KERNEL_XLA(nll_loss, fp32) + KERNEL_XLA(nll_loss2d, fp32) + KERNEL_XLA(hinge_embedding_loss, fp32) + // KERNEL_XLA(poisson_nll_loss, fp32) + KERNEL_XLA(smooth_l1_loss, fp32) + // KERNEL_XLA(cross_entropy_loss, fp32) + KERNEL_XLA(l1_loss, fp32) + // KERNEL_XLA(huber_loss, fp32) + KERNEL_XLA(margin_ranking_loss, fp32) + KERNEL_XLA(soft_margin_loss, fp32) + KERNEL_XLA(triplet_margin_loss, fp32) + KERNEL_XLA(multi_margin_loss, fp32) + KERNEL_XLA2(ctc_loss, IntList, fp32) + KERNEL_XLA2(ctc_loss, Tensor, fp32) + KERNEL_XLA(kl_div, fp32) + KERNEL_XLA(multilabel_margin_loss, fp32) + KERNEL_XLA(binary_cross_entropy_with_logits, fp32) + // KERNEL_XLA(fft_fft, fp32) + // KERNEL_XLA(fft_ifft, fp32) + // KERNEL_XLA(fft_fft2, fp32) + // KERNEL_XLA(fft_ifft2, fp32) + // KERNEL_XLA(fft_fftn, fp32) + // KERNEL_XLA(fft_ifftn, fp32) + // KERNEL_XLA(fft_rfft, fp32) + // KERNEL_XLA(fft_irfft, fp32) + // KERNEL_XLA(fft_rfft2, fp32) + // KERNEL_XLA(fft_irfft2, fp32) + // KERNEL_XLA(fft_rfftn, fp32) + // KERNEL_XLA(fft_irfftn, fp32) + // KERNEL_XLA(fft_hfft, fp32) + // KERNEL_XLA(fft_ihfft, fp32) + // KERNEL_XLA(linalg_cond, fp32) + // KERNEL_XLA2(linalg_cond, p_str, fp32) + // KERNEL_XLA(linalg_matrix_rank, fp32) + // KERNEL_XLA2(linalg_matrix_rank, tol_tensor, fp32) + // KERNEL_XLA2(linalg_matrix_rank, atol_rtol_tensor, fp32) + // KERNEL_XLA2(linalg_matrix_rank, 
atol_rtol_float, fp32)
+  // KERNEL_XLA(linalg_solve, fp32)
+  // KERNEL_XLA(linalg_cholesky, fp32)
+  // KERNEL_XLA(linalg_svdvals, fp32)
+  // KERNEL_XLA(linalg_eigvals, fp32)
+  // KERNEL_XLA(linalg_eigvalsh, fp32)
+  // KERNEL_XLA(linalg_inv, fp32)
+  // KERNEL_XLA(linalg_householder_product, fp32)
+  // KERNEL_XLA(linalg_tensorinv, fp32)
+  // KERNEL_XLA(linalg_tensorsolve, fp32)
+  // KERNEL_XLA(fake_quantize_per_tensor_affine, fp32)
+  // KERNEL_XLA(geqrf, fp32)
+  // KERNEL_XLA(_lu_with_info, fp32)
+  KERNEL_XLA(qr, fp32)
+  KERNEL_XLA(svd, fp32)
+  KERNEL_XLA(triangular_solve, fp32)
+  KERNEL_XLA(multilabel_margin_loss_forward, fp32)
+  // KERNEL_XLA(linalg_qr, fp32)
+  // KERNEL_XLA(linalg_cholesky_ex, fp32)
+  KERNEL_XLA(linalg_svd, fp32)
+  // KERNEL_XLA(linalg_eig, fp32)
+  // KERNEL_XLA(linalg_eigh, fp32)
+  // KERNEL_XLA(linalg_lstsq, fp32)
+  KERNEL_XLA(linalg_inv_ex, fp32)
+
+  // promote
+  KERNEL_XLA(stack, promote)
+  KERNEL_XLA(cat, promote)
+  KERNEL_XLA(index_copy, promote)
+  KERNEL_XLA2(index_copy, dimname, promote)
+}
+
+}  // namespace
+}  // namespace autocast
+}  // namespace at
diff --git a/torch_xla/csrc/tensor_impl.cpp b/torch_xla/csrc/tensor_impl.cpp
index 83f4e19c7d3..bacccdeb85b 100644
--- a/torch_xla/csrc/tensor_impl.cpp
+++ b/torch_xla/csrc/tensor_impl.cpp
@@ -2,6 +2,7 @@
 #include
 #include
+#include
 #include
 #include
 #include
@@ -62,6 +63,15 @@ XLATensorImpl::XLATensorImpl(XLATensor&& tensor)
           GetTypeMeta(tensor),
           bridge::XlaDeviceToAtenDevice(tensor.GetDevice())),
       tensor_(c10::make_intrusive<XLATensor>(std::move(tensor))) {
+  // Update the Autocast key based off the backend device.
+  // Upstream TensorImpl cannot differentiate between XLA:TPU and XLA:GPU
+  // so we must manually update Autocast to AutocastCUDA on XLA:GPU.
+  torch::lazy::BackendDevice current_device = GetCurrentDevice();
+  if (static_cast<XlaDeviceType>(current_device.type()) == XlaDeviceType::GPU) {
+    auto autocast_cuda_ks = c10::DispatchKeySet(c10::DispatchKey::AutocastCUDA);
+    auto autocast_xla_ks = c10::DispatchKeySet(c10::DispatchKey::AutocastXLA);
+    key_set_ = (key_set_ - autocast_xla_ks) | autocast_cuda_ks;
+  }
   is_non_overlapping_and_dense_ = false;
   set_custom_sizes_strides(SizesStridesPolicy::CustomSizes);
 }
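
For reference, a minimal usage sketch of the device-aware autocast API introduced by this patch (illustrative only, not part of the diff; model, loss_fn, optimizer, data and target are placeholders). It mirrors the training-loop pattern the patch applies in test/test_train_mp_mnist_amp.py: bfloat16 autocast with no GradScaler on XLA:TPU, float16 autocast plus GradScaler on XLA:GPU.

    import torch_xla.core.xla_model as xm
    from torch_xla.amp import autocast, GradScaler

    device = xm.xla_device()
    # GradScaler is only needed on XLA:GPU (float16); XLA:TPU autocasts to
    # bfloat16 and needs no loss scaling.
    scaler = GradScaler() if xm.xla_device_hw(device) == 'GPU' else None


    def train_step(model, loss_fn, optimizer, data, target):
      optimizer.zero_grad()
      # Dispatches to torch.autocast("xla") on TPU and torch.autocast("cuda") on GPU.
      with autocast(device):
        loss = loss_fn(model(data), target)
      if scaler:
        scaler.scale(loss).backward()
        gradients = xm._fetch_gradients(optimizer)
        xm.all_reduce('sum', gradients, scale=1.0 / xm.xrt_world_size())
        scaler.step(optimizer)
        scaler.update()
      else:
        loss.backward()
        xm.optimizer_step(optimizer)
      return loss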