From 0df4743981dde418eae0769ae8d9e36fbcd8a502 Mon Sep 17 00:00:00 2001 From: Meghan Cowan Date: Wed, 8 Mar 2023 23:27:43 +0000 Subject: [PATCH 01/20] tpu amp --- test/test_train_mp_imagenet_amp.py | 27 ++++++++++++++++++--------- test/test_train_mp_mnist_amp.py | 23 +++++++++++++++++------ torch_xla/csrc/aten_xla_type.cpp | 8 -------- 3 files changed, 35 insertions(+), 23 deletions(-) diff --git a/test/test_train_mp_imagenet_amp.py b/test/test_train_mp_imagenet_amp.py index 734d9c693fc..da909564a58 100644 --- a/test/test_train_mp_imagenet_amp.py +++ b/test/test_train_mp_imagenet_amp.py @@ -27,7 +27,6 @@ '--test_only_at_end': { 'action': 'store_true', }, - # AMP only works with XLA:GPU '--amp': { 'action': 'store_true', }, @@ -68,7 +67,7 @@ import torch_xla.core.xla_model as xm import torch_xla.distributed.xla_multiprocessing as xmp import torch_xla.test.test_utils as test_utils -from torch_xla.amp import autocast, GradScaler +from torch_xla.amp import GradScaler try: from torch_xla.amp import syncfree except ImportError: @@ -197,6 +196,7 @@ def train_imagenet(): torch.manual_seed(42) device = xm.xla_device() + device_hw = xm.xla_device_hw(device) model = get_model_property('model_fn')().to(device) writer = None if xm.is_master_ordinal(): @@ -219,7 +219,13 @@ def train_imagenet(): summary_writer=writer) loss_fn = nn.CrossEntropyLoss() if FLAGS.amp: - scaler = GradScaler(use_zero_grad=FLAGS.use_zero_grad) + if device_hw == 'TPU': + autocast = torch.xla.amp.autocast + scaler = None + elif device_hw == 'GPU': + autocast = torch.cuda.amp.autocast + # GradScaler only used for GPU + scaler = GradScaler(use_zero_grad=FLAGS.use_zero_grad) def train_loop_fn(loader, epoch): tracker = xm.RateTracker() @@ -230,12 +236,15 @@ def train_loop_fn(loader, epoch): with autocast(): output = model(data) loss = loss_fn(output, target) - - scaler.scale(loss).backward() - gradients = xm._fetch_gradients(optimizer) - xm.all_reduce('sum', gradients, scale=1.0 / xm.xrt_world_size()) - scaler.step(optimizer) - scaler.update() + if scaler: + scaler.scale(loss).backward() + gradients = xm._fetch_gradients(optimizer) + xm.all_reduce('sum', gradients, scale=1.0 / xm.xrt_world_size()) + scaler.step(optimizer) + scaler.update() + else: + loss.backward() + xm.optimizer_step(optimizer) else: output = model(data) loss = loss_fn(output, target) diff --git a/test/test_train_mp_mnist_amp.py b/test/test_train_mp_mnist_amp.py index 2f773f4d032..5eebc2858c1 100644 --- a/test/test_train_mp_mnist_amp.py +++ b/test/test_train_mp_mnist_amp.py @@ -130,6 +130,7 @@ def train_mnist(flags, **kwargs): lr = flags.lr * xm.xrt_world_size() device = xm.xla_device() + device_hw = xm.xla_device_hw(device) model = MNIST().to(device) writer = None if xm.is_master_ordinal(): @@ -137,7 +138,13 @@ def train_mnist(flags, **kwargs): optim_cls = syncfree.SGD if FLAGS.use_syncfree_optim else optim.SGD optimizer = optim_cls(model.parameters(), lr=lr, momentum=flags.momentum) loss_fn = nn.NLLLoss() - scaler = GradScaler(use_zero_grad=FLAGS.use_zero_grad) + if device_hw == 'TPU': + autocast = torch.xla.amp.autocast + scaler = None + elif device_hw == 'GPU': + autocast = torch.cuda.amp.autocast + # GradScaler only used for GPU + scaler = GradScaler(use_zero_grad=FLAGS.use_zero_grad) def train_loop_fn(loader): tracker = xm.RateTracker() @@ -147,11 +154,15 @@ def train_loop_fn(loader): with autocast(): output = model(data) loss = loss_fn(output, target) - scaler.scale(loss).backward() - gradients = xm._fetch_gradients(optimizer) - xm.all_reduce('sum', gradients, 
scale=1.0 / xm.xrt_world_size()) - scaler.step(optimizer) - scaler.update() + if scaler: + scaler.scale(loss).backward() + gradients = xm._fetch_gradients(optimizer) + xm.all_reduce('sum', gradients, scale=1.0 / xm.xrt_world_size()) + scaler.step(optimizer) + scaler.update() + else: + loss.backward() + xm.optimizer_step(optimizer) tracker.add(flags.batch_size) if step % flags.log_steps == 0: xm.add_step_closure( diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index c6e7abc1ff8..ba128ca5bf7 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -426,10 +426,6 @@ void XLANativeFunctions::_amp_foreach_non_finite_check_and_unscale_( at::TensorList self, at::Tensor& found_inf, const at::Tensor& inv_scale) { TORCH_LAZY_FN_COUNTER("xla::"); XLATensorPtr found_inf_tensor = bridge::GetXlaTensor(found_inf); - XlaDeviceType hw_type = - static_cast(found_inf_tensor->GetDevice().type()); - XLA_CHECK(hw_type == XlaDeviceType::GPU || hw_type == XlaDeviceType::CPU) - << "AMP should be used with XLA:GPU"; tensor_methods::_amp_foreach_non_finite_check_and_unscale_( bridge::GetXlaTensors(self), found_inf_tensor, bridge::GetXlaTensor(inv_scale)); @@ -444,10 +440,6 @@ at::Tensor& XLANativeFunctions::_amp_update_scale_(at::Tensor& current_scale, TORCH_LAZY_FN_COUNTER("xla::"); XLATensorPtr growth_tracker_tensor = bridge::GetXlaTensor(growth_tracker); XLATensorPtr current_scale_tensor = bridge::GetXlaTensor(current_scale); - XlaDeviceType hw_type = - static_cast(growth_tracker_tensor->GetDevice().type()); - XLA_CHECK(hw_type == XlaDeviceType::GPU || hw_type == XlaDeviceType::CPU) - << "AMP should be used with XLA:GPU"; tensor_methods::_amp_update_scale_( growth_tracker_tensor, current_scale_tensor, bridge::GetXlaTensor(found_inf), scale_growth_factor, From 7997da0fcaa27b5a11b1129b0c7e46e680ee0068 Mon Sep 17 00:00:00 2001 From: Meghan Cowan Date: Thu, 16 Mar 2023 16:45:17 +0000 Subject: [PATCH 02/20] Add torch pin --- torch_patches/.torch_pin | 1 + 1 file changed, 1 insertion(+) create mode 100644 torch_patches/.torch_pin diff --git a/torch_patches/.torch_pin b/torch_patches/.torch_pin new file mode 100644 index 00000000000..7e0a6082f13 --- /dev/null +++ b/torch_patches/.torch_pin @@ -0,0 +1 @@ +#96370 From 401eaf61efa3acd1750acabfacf0ea7b0cae7360 Mon Sep 17 00:00:00 2001 From: Meghan Date: Tue, 11 Apr 2023 19:42:08 +0000 Subject: [PATCH 03/20] updates --- test/test_fsdp_auto_wrap_amp.py | 65 +++++++++ test/test_train_mp_mnist_amp.py | 13 +- torch_xla/csrc/autocast_mode.cpp | 229 +++++++++++++++++++++++++++++++ 3 files changed, 304 insertions(+), 3 deletions(-) create mode 100644 test/test_fsdp_auto_wrap_amp.py create mode 100644 torch_xla/csrc/autocast_mode.cpp diff --git a/test/test_fsdp_auto_wrap_amp.py b/test/test_fsdp_auto_wrap_amp.py new file mode 100644 index 00000000000..27a1fa0f9ae --- /dev/null +++ b/test/test_fsdp_auto_wrap_amp.py @@ -0,0 +1,65 @@ +import torch +import torch_xla +import torch_xla.utils.utils as xu +import torch_xla.core.xla_model as xm +import torch_xla.distributed.xla_multiprocessing as xmp +import test_utils + +from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel +from torch_xla.distributed.fsdp.wrap import always_wrap_policy + +import sys +import unittest + + +class TestNoBackwardModule(test_utils.XlaTestCase): + # Test the FSDP autowrap feature with a module containing a submodule + # that is only used in forward (fc2 below), to make sure it doesn't + # fail by the hook assertion. 
+ class MyModel(torch.nn.Module): + + def __init__(self, input_size, hidden_size): + super().__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.fc1 = torch.nn.Linear(self.input_size, self.hidden_size) + self.fc2 = torch.nn.Linear(self.input_size, self.hidden_size) + + def forward(self, x): + hidden1 = self.fc1(x) + hidden2 = self.fc2(x) + return hidden1, hidden2 + + def test(self): + dev = xm.xla_device() + device_hw = xm.xla_device_hw(dev) + if device_hw == 'TPU': + autocast = torch.xla.amp.autocast + elif device_hw == 'GPU': + autocast = torch.cuda.amp.autocast + + input = torch.zeros([16, 16], device=dev) + model = self.MyModel(input_size=16, hidden_size=4) + model = XlaFullyShardedDataParallel( + model, auto_wrap_policy=always_wrap_policy) + model.to(dev) + with autocast(): + hid1, hid2 = model(input) + loss = hid1.sum() + loss.backward() + xm.mark_step() + + +def _mp_fn(index): + device = xm.xla_device() + if xm.xla_device_hw(device) in ('TPU', 'GPU'): + test = unittest.main(exit=False) + sys.exit(0 if test.result.wasSuccessful() else 1) + else: + print( + 'Default device {} is not a TPU or GPU device'.format(device), + file=sys.stderr) + + +if __name__ == '__main__': + xmp.spawn(_mp_fn, args=()) \ No newline at end of file diff --git a/test/test_train_mp_mnist_amp.py b/test/test_train_mp_mnist_amp.py index 5eebc2858c1..2e7c9c72f32 100644 --- a/test/test_train_mp_mnist_amp.py +++ b/test/test_train_mp_mnist_amp.py @@ -139,21 +139,28 @@ def train_mnist(flags, **kwargs): optimizer = optim_cls(model.parameters(), lr=lr, momentum=flags.momentum) loss_fn = nn.NLLLoss() if device_hw == 'TPU': - autocast = torch.xla.amp.autocast + device = "xla" + dtype = torch.bfloat16 scaler = None + print("Setting autocast device to xla") elif device_hw == 'GPU': - autocast = torch.cuda.amp.autocast + device = "cuda" + dtype = torch.float16 # GradScaler only used for GPU scaler = GradScaler(use_zero_grad=FLAGS.use_zero_grad) + print("Setting autocast device to cuda") def train_loop_fn(loader): tracker = xm.RateTracker() model.train() for step, (data, target) in enumerate(loader): optimizer.zero_grad() - with autocast(): + print("Entering step", step) + with torch.autocast(device, dtype=dtype): output = model(data) + print(output.dtype) loss = loss_fn(output, target) + print("Exiting autocast region") if scaler: scaler.scale(loss).backward() gradients = xm._fetch_gradients(optimizer) diff --git a/torch_xla/csrc/autocast_mode.cpp b/torch_xla/csrc/autocast_mode.cpp new file mode 100644 index 00000000000..a6be6929c4f --- /dev/null +++ b/torch_xla/csrc/autocast_mode.cpp @@ -0,0 +1,229 @@ +#include +#include +#include +#include +#include + +#include +#include + +namespace at { +namespace autocast { + +// Note: Copied over from autocast_mode.cpp +// Policies correspond to op categories that need code-divergent handling. +// Wrapper templates below are specialized based on a policy template parameter. +enum class CastPolicy : uint8_t { + lower_precision_fp = 0, // Cast all inputs to lower_precision_fp before running the op. + // Currently, lower_precision_fp is fp16 for AutocastCUDA, and is defined by user(default bf16) for AutocastCPU. + fp32, // Cast all inputs to at::kFloat before running the op. + // TODO: fp32_set_opt_dtype seems to only be for CUDA devices? + fp32_set_opt_dtype, // Treats functions (like softmax) that + // 1. we'd like to run in fp32 and + // 2. have a c10::optional arg that controls the output type. 
+ // fp32_set_opt_dtype wrappers' policy is: if the output type is already set, + // don't touch it, otherwise, set it to at::kFloat. + fp32_append_dtype, // Treats functions (like norm) that + // 1. we'd like to run in fp32 and + // 2. have some overloads that accept an output type and other overloads that don't. + // fp32_append_dtype wrappers wrap the overloads that don't have an output dtype. + // The wrapper policy is: append at::kFloat to the args, and redispatch to the + // type-aware overload. + promote, // Run in the widest dtype among several args. +}; + +// Base template for WrapFunction_, which is specialized to contain a "call" method each CastPolicy +template struct WrapFunction_ {}; + +// CastPolicy::lower_precision_fp General_DeviceType +template +struct WrapFunction_> { + static Ret call(Args... args) { + c10::impl::ExcludeDispatchKeyGuard no_autocast(get_autocast_dispatch_key_from_device_type(device_type)); + return (*F)(cached_cast(get_lower_precision_fp_from_device_type(device_type), args, device_type)...); + } +}; + +// CastPolicy::fp32 General_DeviceType +template +struct WrapFunction_> { + static Ret call(Args... args) { + c10::impl::ExcludeDispatchKeyGuard no_autocast(get_autocast_dispatch_key_from_device_type(device_type)); + return (*F)(cached_cast(at::kFloat, args, device_type)...); + } +}; + +// CastPolicy::promote General_DeviceType +template +struct WrapFunction_> { + static Ret call(Args... args) { + c10::impl::ExcludeDispatchKeyGuard no_autocast(get_autocast_dispatch_key_from_device_type(device_type)); + auto to_type = promote_type(get_lower_precision_fp_from_device_type(device_type), device_type, args...); + return (*F)(cached_cast(to_type, args, device_type)...); + } +}; + +// Wrapper to infer return_type and parameter_types for WrapFunction_ (imitating core/boxing/impl/WrapFunctionIntoFunctor.h) +template // The actual function we're redispatching to. +struct WrapFunction final { + using type = WrapFunction_::return_type, + typename guts::function_traits::parameter_types>; +}; + +namespace { + +// KERNEL_XLA registration for AutocastXLA +#define KERNEL_XLA(OP, POLICY) \ + m.impl(TORCH_SELECTIVE_NAME("aten::" #OP), \ + &WrapFunction::type::call); +#define KERNEL_XLA2(OP, OVERLOAD, POLICY) \ + m.impl(TORCH_SELECTIVE_NAME("aten::" #OP "." 
#OVERLOAD), \ + &WrapFunction::type::call); + + +TORCH_LIBRARY_IMPL(_, AutocastXLA, m) { + m.fallback(torch::CppFunction::makeFallthrough()); +} + + +TORCH_LIBRARY_IMPL(aten, AutocastXLA, m) { + // lower_precision_fp cast policy + KERNEL_XLA(conv1d, lower_precision_fp) + KERNEL_XLA2(conv1d, padding, lower_precision_fp) + KERNEL_XLA(conv2d, lower_precision_fp) + KERNEL_XLA2(conv2d, padding, lower_precision_fp) + KERNEL_XLA(conv3d, lower_precision_fp) + KERNEL_XLA2(conv3d, padding, lower_precision_fp) + KERNEL_XLA(bmm, lower_precision_fp) + KERNEL_XLA(mm, lower_precision_fp) + KERNEL_XLA(baddbmm, lower_precision_fp) + KERNEL_XLA(addmm, lower_precision_fp) + KERNEL_XLA(addbmm, lower_precision_fp) + KERNEL_XLA(linear, lower_precision_fp) + KERNEL_XLA(matmul, lower_precision_fp) + KERNEL_XLA(conv_tbc, lower_precision_fp) + KERNEL_XLA(conv_transpose1d, lower_precision_fp) + KERNEL_XLA2(conv_transpose2d, input, lower_precision_fp) + KERNEL_XLA2(conv_transpose3d, input, lower_precision_fp) + KERNEL_XLA(prelu, lower_precision_fp) + KERNEL_XLA(relu, lower_precision_fp) + + // fp32 cast policy + KERNEL_XLA(batch_norm, fp32) + KERNEL_XLA(binary_cross_entropy, fp32) + KERNEL_XLA(grid_sampler, fp32) + KERNEL_XLA(polar, fp32) + KERNEL_XLA(prod, fp32) + KERNEL_XLA2(prod, dim_int, fp32) + KERNEL_XLA2(prod, dim_Dimname, fp32) + KERNEL_XLA(quantile, fp32) + KERNEL_XLA2(quantile, scalar, fp32) + KERNEL_XLA(nanquantile, fp32) + KERNEL_XLA2(nanquantile, scalar, fp32) + KERNEL_XLA(stft, fp32) + KERNEL_XLA2(stft, center, fp32) + KERNEL_XLA(cdist, fp32) + KERNEL_XLA(grid_sampler_2d, fp32) + KERNEL_XLA(grid_sampler_3d, fp32) + KERNEL_XLA(trace, fp32) + KERNEL_XLA(view_as_complex, fp32) + KERNEL_XLA(cholesky, fp32) + KERNEL_XLA(cholesky_inverse, fp32) + KERNEL_XLA(cholesky_solve, fp32) + KERNEL_XLA(inverse, fp32) + KERNEL_XLA(lu_solve, fp32) + KERNEL_XLA(orgqr, fp32) + KERNEL_XLA(ormqr, fp32) + KERNEL_XLA(pinverse, fp32) + KERNEL_XLA(reflection_pad1d, fp32) + KERNEL_XLA(reflection_pad2d, fp32) + KERNEL_XLA(replication_pad1d, fp32) + KERNEL_XLA(replication_pad2d, fp32) + KERNEL_XLA(replication_pad3d, fp32) + KERNEL_XLA(mse_loss, fp32) + KERNEL_XLA(cosine_embedding_loss, fp32) + KERNEL_XLA(nll_loss, fp32) + KERNEL_XLA(nll_loss2d, fp32) + KERNEL_XLA(hinge_embedding_loss, fp32) + KERNEL_XLA(poisson_nll_loss, fp32) + KERNEL_XLA(smooth_l1_loss, fp32) + KERNEL_XLA(cross_entropy_loss, fp32) + KERNEL_XLA(l1_loss, fp32) + KERNEL_XLA(huber_loss, fp32) + KERNEL_XLA(margin_ranking_loss, fp32) + KERNEL_XLA(soft_margin_loss, fp32) + KERNEL_XLA(triplet_margin_loss, fp32) + KERNEL_XLA(multi_margin_loss, fp32) + KERNEL_XLA2(ctc_loss, IntList, fp32) + KERNEL_XLA2(ctc_loss, Tensor, fp32) + KERNEL_XLA(kl_div, fp32) + KERNEL_XLA(multilabel_margin_loss, fp32) + KERNEL_XLA(binary_cross_entropy_with_logits, fp32) + KERNEL_XLA(fft_fft, fp32) + KERNEL_XLA(fft_ifft, fp32) + KERNEL_XLA(fft_fft2, fp32) + KERNEL_XLA(fft_ifft2, fp32) + KERNEL_XLA(fft_fftn, fp32) + KERNEL_XLA(fft_ifftn, fp32) + KERNEL_XLA(fft_rfft, fp32) + KERNEL_XLA(fft_irfft, fp32) + KERNEL_XLA(fft_rfft2, fp32) + KERNEL_XLA(fft_irfft2, fp32) + KERNEL_XLA(fft_rfftn, fp32) + KERNEL_XLA(fft_irfftn, fp32) + KERNEL_XLA(fft_hfft, fp32) + KERNEL_XLA(fft_ihfft, fp32) + KERNEL_XLA(linalg_cond, fp32) + KERNEL_XLA2(linalg_cond, p_str, fp32) + KERNEL_XLA(linalg_matrix_rank, fp32) + KERNEL_XLA2(linalg_matrix_rank, tol_tensor, fp32) + KERNEL_XLA2(linalg_matrix_rank, atol_rtol_tensor, fp32) + KERNEL_XLA2(linalg_matrix_rank, atol_rtol_float, fp32) + KERNEL_XLA(linalg_solve, fp32) + 
KERNEL_XLA(linalg_cholesky, fp32) + KERNEL_XLA(linalg_svdvals, fp32) + KERNEL_XLA(linalg_eigvals, fp32) + KERNEL_XLA(linalg_eigvalsh, fp32) + KERNEL_XLA(linalg_inv, fp32) + KERNEL_XLA(linalg_householder_product, fp32) + KERNEL_XLA(linalg_tensorinv, fp32) + KERNEL_XLA(linalg_tensorsolve, fp32) + KERNEL_XLA(fake_quantize_per_tensor_affine, fp32) + KERNEL_XLA(geqrf, fp32) + KERNEL_XLA(_lu_with_info, fp32) + KERNEL_XLA(qr, fp32) + KERNEL_XLA(svd, fp32) + KERNEL_XLA(triangular_solve, fp32) + KERNEL_XLA(multilabel_margin_loss_forward, fp32) + KERNEL_XLA(linalg_qr, fp32) + KERNEL_XLA(linalg_cholesky_ex, fp32) + KERNEL_XLA(linalg_svd, fp32) + KERNEL_XLA(linalg_eig, fp32) + KERNEL_XLA(linalg_eigh, fp32) + KERNEL_XLA(linalg_lstsq, fp32) + KERNEL_XLA(linalg_inv_ex, fp32) + + // promote + KERNEL_XLA(stack, promote) + KERNEL_XLA(cat, promote) + KERNEL_XLA(index_copy, promote) + KERNEL_XLA2(index_copy, dimname, promote) +} + +} // namespace +} // namespace autocast +} // namespace at From 041fce290485038d5c7d91a86cb6a26e8ea9dee3 Mon Sep 17 00:00:00 2001 From: Meghan Date: Wed, 12 Apr 2023 16:35:28 +0000 Subject: [PATCH 04/20] Clean up --- test/test_train_mp_imagenet_amp.py | 12 +++++++----- test/test_train_mp_mnist_amp.py | 31 +++++++++++++++--------------- torch_xla/csrc/autocast_mode.cpp | 26 ++----------------------- 3 files changed, 25 insertions(+), 44 deletions(-) diff --git a/test/test_train_mp_imagenet_amp.py b/test/test_train_mp_imagenet_amp.py index da909564a58..bf81a52a73a 100644 --- a/test/test_train_mp_imagenet_amp.py +++ b/test/test_train_mp_imagenet_amp.py @@ -67,6 +67,8 @@ import torch_xla.core.xla_model as xm import torch_xla.distributed.xla_multiprocessing as xmp import torch_xla.test.test_utils as test_utils +import torch.xla.amp as xla_amp +import torch.cuda.amp as xla_cuda from torch_xla.amp import GradScaler try: from torch_xla.amp import syncfree @@ -220,12 +222,12 @@ def train_imagenet(): loss_fn = nn.CrossEntropyLoss() if FLAGS.amp: if device_hw == 'TPU': - autocast = torch.xla.amp.autocast - scaler = None + autocast = xla_amp.autocast + scaler = None elif device_hw == 'GPU': - autocast = torch.cuda.amp.autocast - # GradScaler only used for GPU - scaler = GradScaler(use_zero_grad=FLAGS.use_zero_grad) + autocast = cuda_amp.autocast + # GradScaler only used for GPU + scaler = GradScaler(use_zero_grad=FLAGS.use_zero_grad) def train_loop_fn(loader, epoch): tracker = xm.RateTracker() diff --git a/test/test_train_mp_mnist_amp.py b/test/test_train_mp_mnist_amp.py index 2e7c9c72f32..652c79c89b5 100644 --- a/test/test_train_mp_mnist_amp.py +++ b/test/test_train_mp_mnist_amp.py @@ -42,6 +42,8 @@ from torch_xla.amp import syncfree except ImportError: assert False, "Missing package syncfree; the package is available in torch-xla>=1.11" +import torch.xla.amp as xla_amp +import torch.cuda.amp as xla_cuda class MNIST(nn.Module): @@ -66,10 +68,10 @@ def forward(self, x): return F.log_softmax(x, dim=1) -def _train_update(device, x, loss, tracker, writer): +def _train_update(device, step, loss, tracker, writer): test_utils.print_training_update( device, - x, + step, loss.item(), tracker.rate(), tracker.global_rate(), @@ -132,35 +134,34 @@ def train_mnist(flags, **kwargs): device = xm.xla_device() device_hw = xm.xla_device_hw(device) model = MNIST().to(device) + writer = None if xm.is_master_ordinal(): writer = test_utils.get_summary_writer(flags.logdir) optim_cls = syncfree.SGD if FLAGS.use_syncfree_optim else optim.SGD optimizer = optim_cls(model.parameters(), lr=lr, momentum=flags.momentum) 
loss_fn = nn.NLLLoss() + if device_hw == 'TPU': - device = "xla" - dtype = torch.bfloat16 - scaler = None - print("Setting autocast device to xla") + autocast = xla_amp.autocast + scaler = None elif device_hw == 'GPU': - device = "cuda" - dtype = torch.float16 - # GradScaler only used for GPU - scaler = GradScaler(use_zero_grad=FLAGS.use_zero_grad) - print("Setting autocast device to cuda") + autocast = cuda_amp.autocast + # GradScaler only used for GPU + scaler = GradScaler(use_zero_grad=FLAGS.use_zero_grad) + else: + print("Only TPU or GPU supported for AMP.") + sys.exit(1) + scaler = None def train_loop_fn(loader): tracker = xm.RateTracker() model.train() for step, (data, target) in enumerate(loader): optimizer.zero_grad() - print("Entering step", step) - with torch.autocast(device, dtype=dtype): + with autocast(): output = model(data) - print(output.dtype) loss = loss_fn(output, target) - print("Exiting autocast region") if scaler: scaler.scale(loss).backward() gradients = xm._fetch_gradients(optimizer) diff --git a/torch_xla/csrc/autocast_mode.cpp b/torch_xla/csrc/autocast_mode.cpp index a6be6929c4f..3081392cf18 100644 --- a/torch_xla/csrc/autocast_mode.cpp +++ b/torch_xla/csrc/autocast_mode.cpp @@ -10,28 +10,6 @@ namespace at { namespace autocast { -// Note: Copied over from autocast_mode.cpp -// Policies correspond to op categories that need code-divergent handling. -// Wrapper templates below are specialized based on a policy template parameter. -enum class CastPolicy : uint8_t { - lower_precision_fp = 0, // Cast all inputs to lower_precision_fp before running the op. - // Currently, lower_precision_fp is fp16 for AutocastCUDA, and is defined by user(default bf16) for AutocastCPU. - fp32, // Cast all inputs to at::kFloat before running the op. - // TODO: fp32_set_opt_dtype seems to only be for CUDA devices? - fp32_set_opt_dtype, // Treats functions (like softmax) that - // 1. we'd like to run in fp32 and - // 2. have a c10::optional arg that controls the output type. - // fp32_set_opt_dtype wrappers' policy is: if the output type is already set, - // don't touch it, otherwise, set it to at::kFloat. - fp32_append_dtype, // Treats functions (like norm) that - // 1. we'd like to run in fp32 and - // 2. have some overloads that accept an output type and other overloads that don't. - // fp32_append_dtype wrappers wrap the overloads that don't have an output dtype. - // The wrapper policy is: append at::kFloat to the args, and redispatch to the - // type-aware overload. - promote, // Run in the widest dtype among several args. 
-}; - // Base template for WrapFunction_, which is specialized to contain a "call" method each CastPolicy template struct WrapFunction_ {}; @@ -63,7 +41,7 @@ struct WrapFunction_ Date: Wed, 12 Apr 2023 09:49:19 -0700 Subject: [PATCH 05/20] Delete test_fsdp_auto_wrap_amp.py --- test/test_fsdp_auto_wrap_amp.py | 65 --------------------------------- 1 file changed, 65 deletions(-) delete mode 100644 test/test_fsdp_auto_wrap_amp.py diff --git a/test/test_fsdp_auto_wrap_amp.py b/test/test_fsdp_auto_wrap_amp.py deleted file mode 100644 index 27a1fa0f9ae..00000000000 --- a/test/test_fsdp_auto_wrap_amp.py +++ /dev/null @@ -1,65 +0,0 @@ -import torch -import torch_xla -import torch_xla.utils.utils as xu -import torch_xla.core.xla_model as xm -import torch_xla.distributed.xla_multiprocessing as xmp -import test_utils - -from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel -from torch_xla.distributed.fsdp.wrap import always_wrap_policy - -import sys -import unittest - - -class TestNoBackwardModule(test_utils.XlaTestCase): - # Test the FSDP autowrap feature with a module containing a submodule - # that is only used in forward (fc2 below), to make sure it doesn't - # fail by the hook assertion. - class MyModel(torch.nn.Module): - - def __init__(self, input_size, hidden_size): - super().__init__() - self.input_size = input_size - self.hidden_size = hidden_size - self.fc1 = torch.nn.Linear(self.input_size, self.hidden_size) - self.fc2 = torch.nn.Linear(self.input_size, self.hidden_size) - - def forward(self, x): - hidden1 = self.fc1(x) - hidden2 = self.fc2(x) - return hidden1, hidden2 - - def test(self): - dev = xm.xla_device() - device_hw = xm.xla_device_hw(dev) - if device_hw == 'TPU': - autocast = torch.xla.amp.autocast - elif device_hw == 'GPU': - autocast = torch.cuda.amp.autocast - - input = torch.zeros([16, 16], device=dev) - model = self.MyModel(input_size=16, hidden_size=4) - model = XlaFullyShardedDataParallel( - model, auto_wrap_policy=always_wrap_policy) - model.to(dev) - with autocast(): - hid1, hid2 = model(input) - loss = hid1.sum() - loss.backward() - xm.mark_step() - - -def _mp_fn(index): - device = xm.xla_device() - if xm.xla_device_hw(device) in ('TPU', 'GPU'): - test = unittest.main(exit=False) - sys.exit(0 if test.result.wasSuccessful() else 1) - else: - print( - 'Default device {} is not a TPU or GPU device'.format(device), - file=sys.stderr) - - -if __name__ == '__main__': - xmp.spawn(_mp_fn, args=()) \ No newline at end of file From 2157e2c0119fcb71d6a0c5239d09921670150057 Mon Sep 17 00:00:00 2001 From: Meghan Date: Thu, 13 Apr 2023 04:53:36 +0000 Subject: [PATCH 06/20] updates --- test/test_train_mp_mnist_amp.py | 1 - torch_xla/csrc/autocast_mode.cpp | 4 +++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/test/test_train_mp_mnist_amp.py b/test/test_train_mp_mnist_amp.py index 652c79c89b5..550a0d6ae77 100644 --- a/test/test_train_mp_mnist_amp.py +++ b/test/test_train_mp_mnist_amp.py @@ -152,7 +152,6 @@ def train_mnist(flags, **kwargs): else: print("Only TPU or GPU supported for AMP.") sys.exit(1) - scaler = None def train_loop_fn(loader): tracker = xm.RateTracker() diff --git a/torch_xla/csrc/autocast_mode.cpp b/torch_xla/csrc/autocast_mode.cpp index 3081392cf18..6e0fcc58d42 100644 --- a/torch_xla/csrc/autocast_mode.cpp +++ b/torch_xla/csrc/autocast_mode.cpp @@ -98,10 +98,12 @@ TORCH_LIBRARY_IMPL(aten, AutocastXLA, m) { KERNEL_XLA2(conv_transpose3d, input, lower_precision_fp) KERNEL_XLA(prelu, lower_precision_fp) KERNEL_XLA(relu, lower_precision_fp) + 
KERNEL_XLA(max_pool2d, lower_precision_fp) // fp32 cast policy KERNEL_XLA(batch_norm, fp32) - KERNEL_XLA(_log_softmax, fp32) + KERNEL_XLA2(log_softmax, int, fp32) + KERNEL_XLA2(log_softmax, Dimname, fp32) KERNEL_XLA(binary_cross_entropy, fp32) KERNEL_XLA(grid_sampler, fp32) KERNEL_XLA(polar, fp32) From e247b43c445a39ecc259f0c48345b33301cd4563 Mon Sep 17 00:00:00 2001 From: Meghan Date: Fri, 14 Apr 2023 16:51:55 +0000 Subject: [PATCH 07/20] Update autocast key to pick between Cuda and Xla. Unit tests --- test/test_autocast.py | 276 ++++++++++++++++++++++++++----- torch_xla/csrc/autocast_mode.cpp | 227 ++++++++++++++----------- torch_xla/csrc/tensor_impl.cpp | 11 ++ 3 files changed, 372 insertions(+), 142 deletions(-) diff --git a/test/test_autocast.py b/test/test_autocast.py index e8f41d9519c..77cf680287a 100644 --- a/test/test_autocast.py +++ b/test/test_autocast.py @@ -13,9 +13,147 @@ import unittest from torch.testing._internal.autocast_test_lists import AutocastTestLists from torch_xla.amp import autocast, GradScaler +import torch.xla.amp -class AutocastTestUnsupportedLists(object): +class AutocastTPUTestLists: + # Supplies ops and arguments for TPU autocast tests + def __init__(self, dev): + super().__init__() + n = 8 + # Utility arguments, created as one-element tuples + pointwise0_bf16 = (torch.randn(n, dtype=torch.bfloat16, device=dev),) + pointwise1_bf16 = (torch.randn(n, dtype=torch.bfloat16, device=dev),) + pointwise2_bf16 = (torch.randn(n, dtype=torch.bfloat16, device=dev),) + mat0_bf16 = (torch.randn((n, n), dtype=torch.bfloat16, device=dev),) + mat1_bf16 = (torch.randn((n, n), dtype=torch.bfloat16, device=dev),) + mat2_bf16 = (torch.randn((n, n), dtype=torch.bfloat16, device=dev),) + + dummy_dimsets = ((n,), (n, n), (n, n, n), (n, n, n, n), (n, n, n, n, n)) + + dummy_bf16 = [(torch.randn(dimset, dtype=torch.bfloat16, device=dev),) + for dimset in dummy_dimsets] + + dimsets = ((n, n, n), (n, n, n, n), (n, n, n, n, n)) + conv_args_bf16 = [(torch.randn(dimset, dtype=torch.bfloat16, device=dev), + torch.randn(dimset, dtype=torch.bfloat16, device=dev)) + for dimset in dimsets] + conv_args_fp32 = [(torch.randn(dimset, dtype=torch.float32, device=dev), + torch.randn(dimset, dtype=torch.float32, device=dev)) + for dimset in dimsets] + + bias_fp32 = (torch.randn((n,), dtype=torch.float32, device=dev),) + element0_fp32 = (torch.randn(1, dtype=torch.float32, device=dev),) + pointwise0_fp32 = (torch.randn(n, dtype=torch.float32, device=dev),) + pointwise1_fp32 = (torch.randn(n, dtype=torch.float32, device=dev),) + mat0_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),) + mat1_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),) + mat2_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),) + mat3_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),) + + dummy_fp32 = [(torch.randn(dimset, dtype=torch.float32, device=dev),) + for dimset in dummy_dimsets] + # The lists below organize ops that autocast needs to test. + # self.list_name corresponds to test_autocast_list_name . + # Each op is associated with a tuple of valid arguments. + + # Some ops implement built-in type promotion. These don't need autocasting, + # but autocasting relies on their promotion, so we include tests to double-check. 
+ self.torch_expect_builtin_promote = [ + ("eq", pointwise0_fp32 + pointwise1_bf16, torch.bool), + ("ge", pointwise0_fp32 + pointwise1_bf16, torch.bool), + ("gt", pointwise0_fp32 + pointwise1_bf16, torch.bool), + ("le", pointwise0_fp32 + pointwise1_bf16, torch.bool), + ("lt", pointwise0_fp32 + pointwise1_bf16, torch.bool), + ("ne", pointwise0_fp32 + pointwise1_bf16, torch.bool), + ("add", pointwise0_fp32 + pointwise1_bf16, torch.float32), + ("div", pointwise0_fp32 + pointwise1_bf16, torch.float32), + ("mul", pointwise0_fp32 + pointwise1_bf16, torch.float32), + ] + self.methods_expect_builtin_promote = [ + ("__eq__", pointwise0_fp32 + pointwise1_bf16, torch.bool), + ("__ge__", pointwise0_fp32 + pointwise1_bf16, torch.bool), + ("__gt__", pointwise0_fp32 + pointwise1_bf16, torch.bool), + ("__le__", pointwise0_fp32 + pointwise1_bf16, torch.bool), + ("__lt__", pointwise0_fp32 + pointwise1_bf16, torch.bool), + ("__ne__", pointwise0_fp32 + pointwise1_bf16, torch.bool), + ("__add__", pointwise0_fp32 + pointwise1_bf16, torch.float32), + ("__div__", pointwise0_fp32 + pointwise1_bf16, torch.float32), + ("__mul__", pointwise0_fp32 + pointwise1_bf16, torch.float32), + ] + # The remaining lists organize ops that autocast treats explicitly. + self.torch_bf16 = [ + ("conv1d", conv_args_fp32[0]), + ("conv2d", conv_args_fp32[1]), + ("conv3d", conv_args_fp32[2]), + ("bmm", (torch.randn((n, n, n), device=dev, dtype=torch.float32), + torch.randn((n, n, n), device=dev, dtype=torch.float32))), + ("mm", mat0_fp32 + mat1_fp32), + ("matmul", mat0_fp32 + mat1_fp32), + ("baddbmm", (torch.randn((n, n, n), device=dev, dtype=torch.float32), + torch.randn((n, n, n), device=dev, dtype=torch.float32), + torch.randn((n, n, n), device=dev, dtype=torch.float32))), + ("addmm", mat1_fp32 + mat2_fp32 + mat3_fp32), + ("addbmm", + mat0_fp32 + (torch.randn((n, n, n), device=dev, dtype=torch.float32), + torch.randn((n, n, n), device=dev, dtype=torch.float32))), + ("conv_tbc", (torch.randn((10, 7, 3), device=dev, dtype=torch.float32), + torch.randn((5, 3, 5), device=dev, dtype=torch.float32), + torch.randn(5, device=dev, dtype=torch.float32), 0)), + ("conv_transpose1d", conv_args_fp32[0]), + ("conv_transpose2d", conv_args_fp32[1]), + ("conv_transpose3d", conv_args_fp32[2]), + ("prelu", pointwise0_fp32 + element0_fp32), + ("relu", pointwise0_fp32 + element0_fp32), + ] + self.torch_fp32 = [ + ("cosine_embedding_loss", (torch.tensor([[1, 2, 3]], + device=dev, + dtype=torch.bfloat16), + torch.tensor([[1, 3, 4]], + device=dev, + dtype=torch.bfloat16), + torch.tensor([1], + device=dev, + dtype=torch.int))), + ("hinge_embedding_loss", + mat0_bf16 + (torch.ones(n, device=dev, dtype=torch.int),)), + ("margin_ranking_loss", mat0_bf16 + mat1_bf16 + (torch.ones( + (n,), device=dev, dtype=torch.bfloat16),)), + ("triplet_margin_loss", mat0_bf16 + mat1_bf16 + mat2_bf16), + ("binary_cross_entropy_with_logits", mat0_bf16 + (torch.rand( + (n, n), device=dev, dtype=torch.bfloat16),)), + ] + self.nn_bf16 = [ + ("linear", mat0_fp32 + mat1_fp32, {}), + ] + self.nn_fp32 = [ + ("nll_loss", (torch.rand((n, n), device=dev, dtype=torch.float), + torch.zeros((n,), device=dev, dtype=torch.long))), + ("nll_loss2d", (torch.rand((n, n, n, n), + device=dev, + dtype=torch.bfloat16), + torch.zeros((n, n, n), device=dev, dtype=torch.long))), + ("l1_loss", mat0_bf16 + mat1_bf16), + ("smooth_l1_loss", mat0_bf16 + mat1_bf16), + ("mse_loss", mat0_bf16 + mat1_bf16), + ("multilabel_margin_loss", mat0_bf16 + (torch.ones( + (n, n), device=dev, dtype=torch.long),)), + 
("soft_margin_loss", mat0_bf16 + (torch.ones( + (n, n), device=dev, dtype=torch.long),)), + ("multi_margin_loss", mat0_bf16 + (torch.ones( + (n,), device=dev, dtype=torch.long),)), + ] + self.torch_need_autocast_promote = [ + ("cat", (pointwise0_bf16 + pointwise1_fp32,)), + ("stack", (pointwise0_bf16 + pointwise1_fp32,)), + ] + self.methods_fp32 = [] + + self.methods_bf16 = [("__matmul__", mat0_bf16 + mat1_fp32)] + + +class AutocastCudaTestUnsupportedLists(object): def __init__(self): super().__init__() @@ -30,7 +168,6 @@ def __init__(self): # The remaining lists organize ops that autocast treats explicitly. self.torch_fp16 = [ "_convolution_nogroup", # need lowering - "prelu", # need lowering "addmv", # need lowering ] self.torch_fp32 = [ @@ -51,19 +188,23 @@ class TestAutocastBase(unittest.TestCase): def setUp(self): super(TestAutocastBase, self).setUp() - self.autocast_lists = AutocastTestLists(torch.device(xm.xla_device())) - self.autocast_unsupported_lists = AutocastTestUnsupportedLists() - self.skip_list = [] + self.autocast = None + self.is_autocast_enabled = None + self.autocast_lists = None + self.autocast_unsupported_lists = None def tearDown(self): del self.autocast_lists super(TestAutocastBase, self).tearDown() def get_autocast_list(self, list_name): - return [ - tp for tp in getattr(self.autocast_lists, list_name) - if tp[0] not in getattr(self.autocast_unsupported_lists, list_name) - ] + if self.autocast_unsupported_lists: + return [ + tp for tp in getattr(self.autocast_lists, list_name) + if tp[0] not in getattr(self.autocast_unsupported_lists, list_name) + ] + else: + return [tp for tp in getattr(self.autocast_lists, list_name)] def args_maybe_kwargs(self, op_with_args): if len(op_with_args) == 2: @@ -90,9 +231,9 @@ def cast(val, to_type): if add_kwargs is None: add_kwargs = {} - self.assertFalse(torch.is_autocast_enabled()) - with autocast(): - self.assertTrue(torch.is_autocast_enabled()) + self.assertFalse(self.is_autocast_enabled()) + with self.autocast(): + self.assertTrue(self.is_autocast_enabled()) out_type = out_type if out_type is not None else run_as_type output = output_method = None @@ -103,7 +244,7 @@ def cast(val, to_type): if isinstance(output, torch.Tensor): self.assertTrue( out_type == output.dtype, - "autocast for torch.{} produced {}, should produce {}".format( + "autocast for {} produced {}, should produce {}".format( op, output.dtype, out_type)) # Try Tensor.* variant: @@ -112,8 +253,8 @@ def cast(val, to_type): if isinstance(output_method, torch.Tensor): self.assertTrue( out_type == output_method.dtype, - "autocast for torch.{} produced {}, should produce torch.{}". - format(op, output_method.dtype, out_type)) + "autocast for {} produced {}, should produce torch.{}".format( + op, output_method.dtype, out_type)) self.assertTrue((output is not None) or ( output_method is not None @@ -141,8 +282,8 @@ def compare(first, second): # Compare numerics to Python-side "autocasting" that (we expect) does the same thing # as the C++-side autocasting, and should be bitwise accurate. 
output_to_compare = output if output is not None else output_method - with autocast(enabled=False): - self.assertFalse(torch.is_autocast_enabled()) + with self.autocast(enabled=False): + self.assertFalse(self.is_autocast_enabled()) if module is not None and hasattr(module, op): control = getattr(module, op)(*cast(args, run_as_type), **add_kwargs) @@ -153,11 +294,42 @@ def compare(first, second): comparison = compare(output_to_compare, control) self.assertTrue(comparison, "torch.{} result did not match control".format(op)) - self.assertTrue(torch.is_autocast_enabled()) - self.assertFalse(torch.is_autocast_enabled()) + self.assertTrue(self.is_autocast_enabled()) + self.assertFalse(self.is_autocast_enabled()) + + +@unittest.skipIf(not torch.cuda.is_available(), "requires cuda") +class TestAutocastCuda(TestAutocastBase): + + def setUp(self): + super(TestAutocastCuda, self).setUp() + self.autocast = torch.xla.cuda.autocast + self.is_autocast_enabled = torch.is_autocast_enabled + self.autocast_lists = AutocastTestLists(torch.device(xm.xla_device())) + self.autocast_unsupported_lists = AutocastCudaTestUnsupportedLists() + def test_autocast_nn_fp16(self): + with torch.backends.cudnn.flags(enabled=True, deterministic=True): + for op, args in self.get_autocast_list('nn_fp16'): + self._run_autocast_outofplace( + op, args, torch.float16, module=torch._C._nn) -class TestAutocast(TestAutocastBase): + def test_autocast_linalg_fp16(self): + with torch.backends.cudnn.flags(enabled=True, deterministic=True): + for op, args in self.get_autocast_list('linalg_fp16'): + self._run_autocast_outofplace( + op, args, torch.float16, module=torch._C._linalg) + + def test_autocast_methods_fp16(self): + with torch.backends.cudnn.flags(enabled=True, deterministic=True): + for op, args in self.get_autocast_list('methods_fp16'): + self._run_autocast_outofplace(op, args, torch.float16, module=None) + + def test_autocast_banned(self): + with torch.cuda.amp.autocast(): + for op, args, module in self.get_autocast_list('banned'): + with self.assertRaises(RuntimeError): + getattr(module, op)(*args) def test_autocast_torch_fp32(self): for op_with_args in self.get_autocast_list('torch_fp32'): @@ -174,30 +346,58 @@ def test_autocast_torch_expect_builtin_promote(self): 'torch_expect_builtin_promote'): self._run_autocast_outofplace(op, args, torch.float32, out_type=out_type) - def test_autocast_nn_fp16(self): - with torch.backends.cudnn.flags(enabled=True, deterministic=True): - for op, args in self.get_autocast_list('nn_fp16'): - self._run_autocast_outofplace( - op, args, torch.float16, module=torch._C._nn) - def test_autocast_nn_fp32(self): for op, args in self.get_autocast_list('nn_fp32'): self._run_autocast_outofplace( op, args, torch.float32, module=torch._C._nn) - def test_autocast_linalg_fp16(self): - with torch.backends.cudnn.flags(enabled=True, deterministic=True): - for op, args in self.get_autocast_list('linalg_fp16'): - self._run_autocast_outofplace( - op, args, torch.float16, module=torch._C._linalg) + def test_autocast_methods_fp32(self): + for op, args in self.get_autocast_list('methods_fp32'): + print("autocast fp32", op) + self._run_autocast_outofplace(op, args, torch.float32, module=None) - def test_autocast_methods_fp16(self): - with torch.backends.cudnn.flags(enabled=True, deterministic=True): - for op, args in self.get_autocast_list('methods_fp16'): - self._run_autocast_outofplace(op, args, torch.float16, module=None) + def test_autocast_methods_expect_builtin_promote(self): + for op, args, out_type in 
self.get_autocast_list( + 'methods_expect_builtin_promote'): + self._run_autocast_outofplace( + op, args, torch.float32, module=None, out_type=out_type) + + +class TestAutocastTPU(TestAutocastBase): + + def setUp(self): + super(TestAutocastTPU, self).setUp() + self.autocast = torch.xla.amp.autocast + self.is_autocast_enabled = torch.is_autocast_xla_enabled + self.autocast_lists = AutocastTPUTestLists(torch.device(xm.xla_device())) + + def test_autocast_methods_bf16(self): + for op, args in self.get_autocast_list('methods_bf16'): + self._run_autocast_outofplace(op, args, torch.bfloat16, module=None) + + def test_autocast_torch_fp32(self): + for op_with_args in self.get_autocast_list('torch_fp32'): + op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args) + self._run_autocast_outofplace( + op, args, torch.float32, add_kwargs=maybe_kwargs) + + def test_autocast_torch_need_autocast_promote(self): + for op, args in self.get_autocast_list('torch_need_autocast_promote'): + self._run_autocast_outofplace(op, args, torch.float32) + + def test_autocast_torch_expect_builtin_promote(self): + for op, args, out_type in self.get_autocast_list( + 'torch_expect_builtin_promote'): + self._run_autocast_outofplace(op, args, torch.float32, out_type=out_type) + + def test_autocast_nn_fp32(self): + for op, args in self.get_autocast_list('nn_fp32'): + self._run_autocast_outofplace( + op, args, torch.float32, module=torch._C._nn) def test_autocast_methods_fp32(self): for op, args in self.get_autocast_list('methods_fp32'): + print("autocast fp32", op) self._run_autocast_outofplace(op, args, torch.float32, module=None) def test_autocast_methods_expect_builtin_promote(self): @@ -206,12 +406,6 @@ def test_autocast_methods_expect_builtin_promote(self): self._run_autocast_outofplace( op, args, torch.float32, module=None, out_type=out_type) - def test_autocast_banned(self): - with torch.cuda.amp.autocast(): - for op, args, module in self.get_autocast_list('banned'): - with self.assertRaises(RuntimeError): - getattr(module, op)(*args) - if __name__ == "__main__": test = unittest.main(verbosity=FLAGS.verbosity, exit=False) diff --git a/torch_xla/csrc/autocast_mode.cpp b/torch_xla/csrc/autocast_mode.cpp index 6e0fcc58d42..d756e5edd0f 100644 --- a/torch_xla/csrc/autocast_mode.cpp +++ b/torch_xla/csrc/autocast_mode.cpp @@ -1,78 +1,101 @@ #include -#include #include -#include #include - -#include +#include #include +#include +#include namespace at { namespace autocast { -// Base template for WrapFunction_, which is specialized to contain a "call" method each CastPolicy -template struct WrapFunction_ {}; +// Base template for WrapFunction_, which is specialized to contain a "call" +// method each CastPolicy +template +struct WrapFunction_ {}; // CastPolicy::lower_precision_fp General_DeviceType -template -struct WrapFunction_> { +template +struct WrapFunction_> { static Ret call(Args... args) { - c10::impl::ExcludeDispatchKeyGuard no_autocast(get_autocast_dispatch_key_from_device_type(device_type)); - return (*F)(cached_cast(get_lower_precision_fp_from_device_type(device_type), args, device_type)...); + c10::impl::ExcludeDispatchKeyGuard no_autocast( + get_autocast_dispatch_key_from_device_type(device_type)); + return (*F)( + cached_cast(get_lower_precision_fp_from_device_type(device_type), args, + device_type)...); } }; // CastPolicy::fp32 General_DeviceType -template -struct WrapFunction_> { +template +struct WrapFunction_> { static Ret call(Args... 
args) { - c10::impl::ExcludeDispatchKeyGuard no_autocast(get_autocast_dispatch_key_from_device_type(device_type)); + c10::impl::ExcludeDispatchKeyGuard no_autocast( + get_autocast_dispatch_key_from_device_type(device_type)); return (*F)(cached_cast(at::kFloat, args, device_type)...); } }; // CastPolicy::promote General_DeviceType -template -struct WrapFunction_> { +template +struct WrapFunction_> { static Ret call(Args... args) { - c10::impl::ExcludeDispatchKeyGuard no_autocast(get_autocast_dispatch_key_from_device_type(device_type)); - auto to_type = promote_type(get_lower_precision_fp_from_device_type(device_type), device_type, args...); + c10::impl::ExcludeDispatchKeyGuard no_autocast( + get_autocast_dispatch_key_from_device_type(device_type)); + auto to_type = + promote_type(get_lower_precision_fp_from_device_type(device_type), + device_type, args...); return (*F)(cached_cast(to_type, args, device_type)...); } }; -// Wrapper to infer return_type and parameter_types for WrapFunction_ (imitating pytorch/core/boxing/impl/WrapFunctionIntoFunctor.h) -template // The actual function we're redispatching to. +// Wrapper to infer return_type and parameter_types for WrapFunction_ (imitating +// pytorch/core/boxing/impl/WrapFunctionIntoFunctor.h) +template < + CastPolicy policy, DeviceType device_type, + class Registered, // The signature for which we're registering. The + // dispatcher's calling code invokes our registered + // functions with arguments matching Registered, so we + // register WrapFunction_::call methods with a matching + // signature to properly field those arguments. + // guts::function_traits below extracts return_type and + // parameter_types from Registered, which WrapFunction_ + // templates above use to declare their call methods. + class Redispatch, // The signature for the function we're redispatching to. + // In most cases this is the same as Registered, but for + // some ops (for example, ops where we append a dtype) + // it's useful to redispatch to a function with a + // different signature. + Redispatch* F> // The actual function we're redispatching to. struct WrapFunction final { - using type = WrapFunction_::return_type, - typename guts::function_traits::parameter_types>; + using type = WrapFunction_< + policy, device_type, Redispatch, F, + typename guts::function_traits::return_type, + typename guts::function_traits::parameter_types>; }; namespace { // KERNEL_XLA registration for AutocastXLA -#define KERNEL_XLA(OP, POLICY) \ - m.impl(TORCH_SELECTIVE_NAME("aten::" #OP), \ - &WrapFunction::type::call); -#define KERNEL_XLA2(OP, OVERLOAD, POLICY) \ - m.impl(TORCH_SELECTIVE_NAME("aten::" #OP "." #OVERLOAD), \ - &WrapFunction::type::call); +#define KERNEL_XLA(OP, POLICY) \ + m.impl(TORCH_SELECTIVE_NAME("aten::" #OP), \ + &WrapFunction::type::call); +#define KERNEL_XLA2(OP, OVERLOAD, POLICY) \ + m.impl(TORCH_SELECTIVE_NAME("aten::" #OP "." #OVERLOAD), \ + &WrapFunction::type::call); - TORCH_LIBRARY_IMPL(_, AutocastXLA, m) { m.fallback(torch::CppFunction::makeFallthrough()); } @@ -96,39 +119,41 @@ TORCH_LIBRARY_IMPL(aten, AutocastXLA, m) { KERNEL_XLA(conv_transpose1d, lower_precision_fp) KERNEL_XLA2(conv_transpose2d, input, lower_precision_fp) KERNEL_XLA2(conv_transpose3d, input, lower_precision_fp) - KERNEL_XLA(prelu, lower_precision_fp) + KERNEL_XLA(prelu, lower_precision_fp) KERNEL_XLA(relu, lower_precision_fp) KERNEL_XLA(max_pool2d, lower_precision_fp) // fp32 cast policy + // Commented out ops are included in the AutoCastCPU Policy, + // but not lowered. 
Enable if op is lowered. KERNEL_XLA(batch_norm, fp32) KERNEL_XLA2(log_softmax, int, fp32) KERNEL_XLA2(log_softmax, Dimname, fp32) KERNEL_XLA(binary_cross_entropy, fp32) - KERNEL_XLA(grid_sampler, fp32) - KERNEL_XLA(polar, fp32) + // KERNEL_XLA(grid_sampler, fp32) + // KERNEL_XLA(polar, fp32) KERNEL_XLA(prod, fp32) KERNEL_XLA2(prod, dim_int, fp32) KERNEL_XLA2(prod, dim_Dimname, fp32) - KERNEL_XLA(quantile, fp32) - KERNEL_XLA2(quantile, scalar, fp32) - KERNEL_XLA(nanquantile, fp32) - KERNEL_XLA2(nanquantile, scalar, fp32) - KERNEL_XLA(stft, fp32) - KERNEL_XLA2(stft, center, fp32) + // KERNEL_XLA(quantile, fp32) + // KERNEL_XLA2(quantile, scalar, fp32) + // KERNEL_XLA(nanquantile, fp32) + // KERNEL_XLA2(nanquantile, scalar, fp32) + // KERNEL_XLA(stft, fp32) + // KERNEL_XLA2(stft, center, fp32) KERNEL_XLA(cdist, fp32) - KERNEL_XLA(grid_sampler_2d, fp32) - KERNEL_XLA(grid_sampler_3d, fp32) + // KERNEL_XLA(grid_sampler_2d, fp32) + // KERNEL_XLA(grid_sampler_3d, fp32) KERNEL_XLA(trace, fp32) - KERNEL_XLA(view_as_complex, fp32) + // KERNEL_XLA(view_as_complex, fp32) KERNEL_XLA(cholesky, fp32) KERNEL_XLA(cholesky_inverse, fp32) KERNEL_XLA(cholesky_solve, fp32) KERNEL_XLA(inverse, fp32) - KERNEL_XLA(lu_solve, fp32) - KERNEL_XLA(orgqr, fp32) - KERNEL_XLA(ormqr, fp32) - KERNEL_XLA(pinverse, fp32) + // KERNEL_XLA(lu_solve, fp32) + // KERNEL_XLA(orgqr, fp32) + // KERNEL_XLA(ormqr, fp32) + // KERNEL_XLA(pinverse, fp32) KERNEL_XLA(reflection_pad1d, fp32) KERNEL_XLA(reflection_pad2d, fp32) KERNEL_XLA(replication_pad1d, fp32) @@ -139,11 +164,11 @@ TORCH_LIBRARY_IMPL(aten, AutocastXLA, m) { KERNEL_XLA(nll_loss, fp32) KERNEL_XLA(nll_loss2d, fp32) KERNEL_XLA(hinge_embedding_loss, fp32) - KERNEL_XLA(poisson_nll_loss, fp32) + // KERNEL_XLA(poisson_nll_loss, fp32) KERNEL_XLA(smooth_l1_loss, fp32) - KERNEL_XLA(cross_entropy_loss, fp32) + // KERNEL_XLA(cross_entropy_loss, fp32) KERNEL_XLA(l1_loss, fp32) - KERNEL_XLA(huber_loss, fp32) + // KERNEL_XLA(huber_loss, fp32) KERNEL_XLA(margin_ranking_loss, fp32) KERNEL_XLA(soft_margin_loss, fp32) KERNEL_XLA(triplet_margin_loss, fp32) @@ -153,48 +178,48 @@ TORCH_LIBRARY_IMPL(aten, AutocastXLA, m) { KERNEL_XLA(kl_div, fp32) KERNEL_XLA(multilabel_margin_loss, fp32) KERNEL_XLA(binary_cross_entropy_with_logits, fp32) - KERNEL_XLA(fft_fft, fp32) - KERNEL_XLA(fft_ifft, fp32) - KERNEL_XLA(fft_fft2, fp32) - KERNEL_XLA(fft_ifft2, fp32) - KERNEL_XLA(fft_fftn, fp32) - KERNEL_XLA(fft_ifftn, fp32) - KERNEL_XLA(fft_rfft, fp32) - KERNEL_XLA(fft_irfft, fp32) - KERNEL_XLA(fft_rfft2, fp32) - KERNEL_XLA(fft_irfft2, fp32) - KERNEL_XLA(fft_rfftn, fp32) - KERNEL_XLA(fft_irfftn, fp32) - KERNEL_XLA(fft_hfft, fp32) - KERNEL_XLA(fft_ihfft, fp32) - KERNEL_XLA(linalg_cond, fp32) - KERNEL_XLA2(linalg_cond, p_str, fp32) - KERNEL_XLA(linalg_matrix_rank, fp32) - KERNEL_XLA2(linalg_matrix_rank, tol_tensor, fp32) - KERNEL_XLA2(linalg_matrix_rank, atol_rtol_tensor, fp32) - KERNEL_XLA2(linalg_matrix_rank, atol_rtol_float, fp32) - KERNEL_XLA(linalg_solve, fp32) - KERNEL_XLA(linalg_cholesky, fp32) - KERNEL_XLA(linalg_svdvals, fp32) - KERNEL_XLA(linalg_eigvals, fp32) - KERNEL_XLA(linalg_eigvalsh, fp32) - KERNEL_XLA(linalg_inv, fp32) - KERNEL_XLA(linalg_householder_product, fp32) - KERNEL_XLA(linalg_tensorinv, fp32) - KERNEL_XLA(linalg_tensorsolve, fp32) - KERNEL_XLA(fake_quantize_per_tensor_affine, fp32) - KERNEL_XLA(geqrf, fp32) - KERNEL_XLA(_lu_with_info, fp32) + // KERNEL_XLA(fft_fft, fp32) + // KERNEL_XLA(fft_ifft, fp32) + // KERNEL_XLA(fft_fft2, fp32) + // KERNEL_XLA(fft_ifft2, fp32) + // 
KERNEL_XLA(fft_fftn, fp32) + // KERNEL_XLA(fft_ifftn, fp32) + // KERNEL_XLA(fft_rfft, fp32) + // KERNEL_XLA(fft_irfft, fp32) + // KERNEL_XLA(fft_rfft2, fp32) + // KERNEL_XLA(fft_irfft2, fp32) + // KERNEL_XLA(fft_rfftn, fp32) + // KERNEL_XLA(fft_irfftn, fp32) + // KERNEL_XLA(fft_hfft, fp32) + // KERNEL_XLA(fft_ihfft, fp32) + // KERNEL_XLA(linalg_cond, fp32) + // KERNEL_XLA2(linalg_cond, p_str, fp32) + // KERNEL_XLA(linalg_matrix_rank, fp32) + // KERNEL_XLA2(linalg_matrix_rank, tol_tensor, fp32) + // KERNEL_XLA2(linalg_matrix_rank, atol_rtol_tensor, fp32) + // KERNEL_XLA2(linalg_matrix_rank, atol_rtol_float, fp32) + // KERNEL_XLA(linalg_solve, fp32) + // KERNEL_XLA(linalg_cholesky, fp32) + // KERNEL_XLA(linalg_svdvals, fp32) + // KERNEL_XLA(linalg_eigvals, fp32) + // KERNEL_XLA(linalg_eigvalsh, fp32) + // KERNEL_XLA(linalg_inv, fp32) + // KERNEL_XLA(linalg_householder_product, fp32) + // KERNEL_XLA(linalg_tensorinv, fp32) + // KERNEL_XLA(linalg_tensorsolve, fp32) + // KERNEL_XLA(fake_quantize_per_tensor_affine, fp32) + // KERNEL_XLA(geqrf, fp32) + // KERNEL_XLA(_lu_with_info, fp32) KERNEL_XLA(qr, fp32) KERNEL_XLA(svd, fp32) KERNEL_XLA(triangular_solve, fp32) KERNEL_XLA(multilabel_margin_loss_forward, fp32) - KERNEL_XLA(linalg_qr, fp32) - KERNEL_XLA(linalg_cholesky_ex, fp32) + // KERNEL_XLA(linalg_qr, fp32) + // KERNEL_XLA(linalg_cholesky_ex, fp32) KERNEL_XLA(linalg_svd, fp32) - KERNEL_XLA(linalg_eig, fp32) - KERNEL_XLA(linalg_eigh, fp32) - KERNEL_XLA(linalg_lstsq, fp32) + // KERNEL_XLA(linalg_eig, fp32) + // KERNEL_XLA(linalg_eigh, fp32) + // KERNEL_XLA(linalg_lstsq, fp32) KERNEL_XLA(linalg_inv_ex, fp32) // promote @@ -204,6 +229,6 @@ TORCH_LIBRARY_IMPL(aten, AutocastXLA, m) { KERNEL_XLA2(index_copy, dimname, promote) } -} // namespace -} // namespace autocast -} // namespace at +} // namespace +} // namespace autocast +} // namespace at diff --git a/torch_xla/csrc/tensor_impl.cpp b/torch_xla/csrc/tensor_impl.cpp index 7e9d1cf19a5..7f64a5c770e 100644 --- a/torch_xla/csrc/tensor_impl.cpp +++ b/torch_xla/csrc/tensor_impl.cpp @@ -2,6 +2,7 @@ #include #include +#include #include #include "third_party/xla_client/computation_client.h" @@ -62,6 +63,16 @@ XLATensorImpl::XLATensorImpl(XLATensor&& tensor) GetTypeMeta(tensor), bridge::XlaDeviceToAtenDevice(tensor.GetDevice())), tensor_(c10::make_intrusive(std::move(tensor))) { + // Update the Autocast key based off the backend device. + // Upstream TensorImpl cannot differentiate between XLA:TPU and XLA:GPU + // so we must manually update Autocast to AutocastCUDA on XLA:GPU. 
+ c10::DispatchKeySet key_set = c10::TensorImpl::key_set(); + torch::lazy::BackendDevice current_device = GetCurrentDevice(); + if (static_cast(current_device.type()) == XlaDeviceType::GPU) { + auto autocast_cuda_ks = c10::DispatchKeySet(c10::DispatchKey::AutocastCUDA); + auto autocast_xla_ks = c10::DispatchKeySet(c10::DispatchKey::AutocastXLA); + key_set = (key_set - autocast_xla_ks) | autocast_cuda_ks; + } is_non_overlapping_and_dense_ = false; set_custom_sizes_strides(SizesStridesPolicy::CustomSizes); } From 7d215d1bcba26e76124dee4cfec0c0381b40c016 Mon Sep 17 00:00:00 2001 From: Meghan Date: Fri, 14 Apr 2023 16:55:01 +0000 Subject: [PATCH 08/20] lint --- torch_xla/csrc/autocast_mode.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_xla/csrc/autocast_mode.cpp b/torch_xla/csrc/autocast_mode.cpp index d756e5edd0f..8cf3569ca76 100644 --- a/torch_xla/csrc/autocast_mode.cpp +++ b/torch_xla/csrc/autocast_mode.cpp @@ -73,7 +73,7 @@ template < // some ops (for example, ops where we append a dtype) // it's useful to redispatch to a function with a // different signature. - Redispatch* F> // The actual function we're redispatching to. + Redispatch* F> // The actual function we're redispatching to. struct WrapFunction final { using type = WrapFunction_< policy, device_type, Redispatch, F, @@ -119,7 +119,7 @@ TORCH_LIBRARY_IMPL(aten, AutocastXLA, m) { KERNEL_XLA(conv_transpose1d, lower_precision_fp) KERNEL_XLA2(conv_transpose2d, input, lower_precision_fp) KERNEL_XLA2(conv_transpose3d, input, lower_precision_fp) - KERNEL_XLA(prelu, lower_precision_fp) + KERNEL_XLA(prelu, lower_precision_fp) KERNEL_XLA(relu, lower_precision_fp) KERNEL_XLA(max_pool2d, lower_precision_fp) From 661759ad3c19e7abd30e8625d59a4439ed0be85e Mon Sep 17 00:00:00 2001 From: Meghan Date: Fri, 14 Apr 2023 22:26:11 +0000 Subject: [PATCH 09/20] moving code from pt to ptxla --- test/test_autocast.py | 11 ++++----- test/test_train_mp_imagenet_amp.py | 9 ++------ test/test_train_mp_mnist_amp.py | 6 +---- torch_xla/amp/__init__.py | 2 +- torch_xla/amp/autocast_mode.py | 37 +++++++++++++++++++++++++++--- 5 files changed, 42 insertions(+), 23 deletions(-) diff --git a/test/test_autocast.py b/test/test_autocast.py index 77cf680287a..8b33aed2e11 100644 --- a/test/test_autocast.py +++ b/test/test_autocast.py @@ -13,7 +13,6 @@ import unittest from torch.testing._internal.autocast_test_lists import AutocastTestLists from torch_xla.amp import autocast, GradScaler -import torch.xla.amp class AutocastTPUTestLists: @@ -188,7 +187,6 @@ class TestAutocastBase(unittest.TestCase): def setUp(self): super(TestAutocastBase, self).setUp() - self.autocast = None self.is_autocast_enabled = None self.autocast_lists = None self.autocast_unsupported_lists = None @@ -232,7 +230,7 @@ def cast(val, to_type): add_kwargs = {} self.assertFalse(self.is_autocast_enabled()) - with self.autocast(): + with autocast(xm.xla_device()): self.assertTrue(self.is_autocast_enabled()) out_type = out_type if out_type is not None else run_as_type @@ -282,7 +280,7 @@ def compare(first, second): # Compare numerics to Python-side "autocasting" that (we expect) does the same thing # as the C++-side autocasting, and should be bitwise accurate. 
output_to_compare = output if output is not None else output_method - with self.autocast(enabled=False): + with autocast(xm.xla_device(), enabled=False): self.assertFalse(self.is_autocast_enabled()) if module is not None and hasattr(module, op): @@ -298,12 +296,12 @@ def compare(first, second): self.assertFalse(self.is_autocast_enabled()) -@unittest.skipIf(not torch.cuda.is_available(), "requires cuda") +@unittest.skipIf(not xm.get_xla_supported_devices("GPU"), + f"GPU specific autocast test.") class TestAutocastCuda(TestAutocastBase): def setUp(self): super(TestAutocastCuda, self).setUp() - self.autocast = torch.xla.cuda.autocast self.is_autocast_enabled = torch.is_autocast_enabled self.autocast_lists = AutocastTestLists(torch.device(xm.xla_device())) self.autocast_unsupported_lists = AutocastCudaTestUnsupportedLists() @@ -367,7 +365,6 @@ class TestAutocastTPU(TestAutocastBase): def setUp(self): super(TestAutocastTPU, self).setUp() - self.autocast = torch.xla.amp.autocast self.is_autocast_enabled = torch.is_autocast_xla_enabled self.autocast_lists = AutocastTPUTestLists(torch.device(xm.xla_device())) diff --git a/test/test_train_mp_imagenet_amp.py b/test/test_train_mp_imagenet_amp.py index bf81a52a73a..bec112a9378 100644 --- a/test/test_train_mp_imagenet_amp.py +++ b/test/test_train_mp_imagenet_amp.py @@ -67,9 +67,7 @@ import torch_xla.core.xla_model as xm import torch_xla.distributed.xla_multiprocessing as xmp import torch_xla.test.test_utils as test_utils -import torch.xla.amp as xla_amp -import torch.cuda.amp as xla_cuda -from torch_xla.amp import GradScaler +from torch_xla.amp import autocast, GradScaler try: from torch_xla.amp import syncfree except ImportError: @@ -222,11 +220,8 @@ def train_imagenet(): loss_fn = nn.CrossEntropyLoss() if FLAGS.amp: if device_hw == 'TPU': - autocast = xla_amp.autocast scaler = None elif device_hw == 'GPU': - autocast = cuda_amp.autocast - # GradScaler only used for GPU scaler = GradScaler(use_zero_grad=FLAGS.use_zero_grad) def train_loop_fn(loader, epoch): @@ -235,7 +230,7 @@ def train_loop_fn(loader, epoch): for step, (data, target) in enumerate(loader): optimizer.zero_grad() if FLAGS.amp: - with autocast(): + with autocast(xm.xla_device()): output = model(data) loss = loss_fn(output, target) if scaler: diff --git a/test/test_train_mp_mnist_amp.py b/test/test_train_mp_mnist_amp.py index 550a0d6ae77..ae4db118300 100644 --- a/test/test_train_mp_mnist_amp.py +++ b/test/test_train_mp_mnist_amp.py @@ -42,8 +42,6 @@ from torch_xla.amp import syncfree except ImportError: assert False, "Missing package syncfree; the package is available in torch-xla>=1.11" -import torch.xla.amp as xla_amp -import torch.cuda.amp as xla_cuda class MNIST(nn.Module): @@ -143,10 +141,8 @@ def train_mnist(flags, **kwargs): loss_fn = nn.NLLLoss() if device_hw == 'TPU': - autocast = xla_amp.autocast scaler = None elif device_hw == 'GPU': - autocast = cuda_amp.autocast # GradScaler only used for GPU scaler = GradScaler(use_zero_grad=FLAGS.use_zero_grad) else: @@ -158,7 +154,7 @@ def train_loop_fn(loader): model.train() for step, (data, target) in enumerate(loader): optimizer.zero_grad() - with autocast(): + with autocast(device): output = model(data) loss = loss_fn(output, target) if scaler: diff --git a/torch_xla/amp/__init__.py b/torch_xla/amp/__init__.py index 1c0ecd08876..739f55cc0dc 100644 --- a/torch_xla/amp/__init__.py +++ b/torch_xla/amp/__init__.py @@ -1,2 +1,2 @@ -from .autocast_mode import autocast, custom_fwd, custom_bwd # noqa: F401 +from .autocast_mode import autocast 
# noqa: F401 from .grad_scaler import GradScaler # noqa: F401 diff --git a/torch_xla/amp/autocast_mode.py b/torch_xla/amp/autocast_mode.py index 2f7a0a83555..0e2803b311a 100644 --- a/torch_xla/amp/autocast_mode.py +++ b/torch_xla/amp/autocast_mode.py @@ -1,5 +1,36 @@ import torch +import torch_xla.core.xla_model as xm +from typing import Any -autocast = torch.cuda.amp.autocast -custom_fwd = torch.cuda.amp.custom_fwd -custom_bwd = torch.cuda.amp.custom_bwd + +class autocast(torch.amp.autocast_mode.autocast): + r""" + See :class:`torch.autocast`. + ``torch_xla.amp.autocast(device, args...)`` is equivalent to ``torch.autocast("xla", args...)`` for TPUs + ``torch.autocast("cuda", args...)`` for GPUs. + """ + + def __init__(self, + device, + enabled: bool = True, + dtype: torch.dtype = torch.bfloat16, + cache_enabled: bool = True): + if xm.xla_device_hw(device) == 'GPU': + super().__init__( + "cuda", + enabled=enabled, + dtype=torch.float16, + cache_enabled=cache_enabled) + else: + super().__init__( + "xla", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled) + + def __enter__(self): + return super().__enter__() + + def __exit__(self, exc_type: Any, exc_val: Any, + exc_tb: Any): # type: ignore[override] + return super().__exit__(exc_type, exc_val, exc_tb) + + def __call__(self, func): + return super().__call__(func) From e339506c749a16bab017282189953709d1e6a79e Mon Sep 17 00:00:00 2001 From: Meghan Date: Mon, 17 Apr 2023 18:30:16 +0000 Subject: [PATCH 10/20] fixes --- test/run_tests.sh | 1 + test/test_autocast.py | 5 +++-- torch_xla/csrc/tensor_impl.cpp | 3 +-- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/test/run_tests.sh b/test/run_tests.sh index e63eb162c90..0cd6f4db498 100755 --- a/test/run_tests.sh +++ b/test/run_tests.sh @@ -168,6 +168,7 @@ function run_op_tests { run_test "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY run_test_without_functionalization "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY run_test "$CDIR/test_async_closures.py" + run_test "$CDIR/test_autocast.py" run_test "$CDIR/test_xla_dist.py" run_test "$CDIR/test_profiler.py" run_test "$CDIR/test_ops.py" diff --git a/test/test_autocast.py b/test/test_autocast.py index 8b33aed2e11..67479eed1be 100644 --- a/test/test_autocast.py +++ b/test/test_autocast.py @@ -297,7 +297,7 @@ def compare(first, second): @unittest.skipIf(not xm.get_xla_supported_devices("GPU"), - f"GPU specific autocast test.") + f"GPU autocast test.") class TestAutocastCuda(TestAutocastBase): def setUp(self): @@ -360,7 +360,8 @@ def test_autocast_methods_expect_builtin_promote(self): self._run_autocast_outofplace( op, args, torch.float32, module=None, out_type=out_type) - +@unittest.skipIf(xm.get_xla_supported_devices("GPU"), + f"TPU autocast test.") class TestAutocastTPU(TestAutocastBase): def setUp(self): diff --git a/torch_xla/csrc/tensor_impl.cpp b/torch_xla/csrc/tensor_impl.cpp index 7f64a5c770e..1cc214cc375 100644 --- a/torch_xla/csrc/tensor_impl.cpp +++ b/torch_xla/csrc/tensor_impl.cpp @@ -66,12 +66,11 @@ XLATensorImpl::XLATensorImpl(XLATensor&& tensor) // Update the Autocast key based off the backend device. // Upstream TensorImpl cannot differentiate between XLA:TPU and XLA:GPU // so we must manually update Autocast to AutocastCUDA on XLA:GPU. 
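For reference, an illustrative sketch of the AMP usage pattern that the changes above converge on (not part of the patch; model, loader, optimizer, and loss_fn are placeholders, and the flag-driven GradScaler options from the training scripts are omitted):

import torch
import torch_xla.core.xla_model as xm
from torch_xla.amp import autocast, GradScaler


def train_step_sketch(model, loader, optimizer, loss_fn):
  # Placeholder loop mirroring test_train_mp_imagenet_amp.py and
  # test_train_mp_mnist_amp.py after this series.
  device = xm.xla_device()
  # GradScaler is only needed on XLA:GPU, where autocast(device) takes the
  # CUDA/float16 path; XLA:TPU runs bfloat16 and needs no loss scaling.
  scaler = GradScaler() if xm.xla_device_hw(device) == 'GPU' else None
  for data, target in loader:
    optimizer.zero_grad()
    with autocast(device):
      output = model(data)
      loss = loss_fn(output, target)
    if scaler:
      scaler.scale(loss).backward()
      gradients = xm._fetch_gradients(optimizer)
      xm.all_reduce('sum', gradients, scale=1.0 / xm.xrt_world_size())
      scaler.step(optimizer)
      scaler.update()
    else:
      loss.backward()
      xm.optimizer_step(optimizer)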
- c10::DispatchKeySet key_set = c10::TensorImpl::key_set(); torch::lazy::BackendDevice current_device = GetCurrentDevice(); if (static_cast(current_device.type()) == XlaDeviceType::GPU) { auto autocast_cuda_ks = c10::DispatchKeySet(c10::DispatchKey::AutocastCUDA); auto autocast_xla_ks = c10::DispatchKeySet(c10::DispatchKey::AutocastXLA); - key_set = (key_set - autocast_xla_ks) | autocast_cuda_ks; + key_set_ = (key_set_ - autocast_xla_ks) | autocast_cuda_ks; } is_non_overlapping_and_dense_ = false; set_custom_sizes_strides(SizesStridesPolicy::CustomSizes); From 03da7ec794ed782106a28a1a4af16631d0b7c774 Mon Sep 17 00:00:00 2001 From: Meghan Date: Mon, 17 Apr 2023 18:47:53 +0000 Subject: [PATCH 11/20] lint --- test/test_autocast.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/test/test_autocast.py b/test/test_autocast.py index 67479eed1be..f6ceae0fd43 100644 --- a/test/test_autocast.py +++ b/test/test_autocast.py @@ -296,8 +296,7 @@ def compare(first, second): self.assertFalse(self.is_autocast_enabled()) -@unittest.skipIf(not xm.get_xla_supported_devices("GPU"), - f"GPU autocast test.") +@unittest.skipIf(not xm.get_xla_supported_devices("GPU"), f"GPU autocast test.") class TestAutocastCuda(TestAutocastBase): def setUp(self): @@ -360,8 +359,8 @@ def test_autocast_methods_expect_builtin_promote(self): self._run_autocast_outofplace( op, args, torch.float32, module=None, out_type=out_type) -@unittest.skipIf(xm.get_xla_supported_devices("GPU"), - f"TPU autocast test.") + +@unittest.skipIf(xm.get_xla_supported_devices("GPU"), f"TPU autocast test.") class TestAutocastTPU(TestAutocastBase): def setUp(self): From e4a0daae9a0fd6268fa8ce76f4f710a7f31753f4 Mon Sep 17 00:00:00 2001 From: Meghan Date: Mon, 17 Apr 2023 21:20:26 +0000 Subject: [PATCH 12/20] move autocast+test from common.sh to run_tests.sh --- .circleci/common.sh | 2 -- 1 file changed, 2 deletions(-) diff --git a/.circleci/common.sh b/.circleci/common.sh index 6daa7e54b12..e9721e7ab3e 100755 --- a/.circleci/common.sh +++ b/.circleci/common.sh @@ -140,8 +140,6 @@ function run_torch_xla_tests() { chmod -R 755 ~/htmlcov else ./test/run_tests.sh - # only run test_autocast for cpu and gpu on circleCI. 
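For reference, the device gating used in test_autocast.py generalizes to new tests; the example below is hypothetical and not part of the patch, and its bfloat16 expectation assumes mm keeps the lower_precision_fp registration added in this series:

import unittest

import torch
import torch_xla.core.xla_model as xm
from torch_xla.amp import autocast


@unittest.skipIf(not xm.get_xla_supported_devices("TPU"), "TPU autocast test.")
class TpuAutocastSmokeTest(unittest.TestCase):
  # Hypothetical test following the skipIf gating pattern from
  # test_autocast.py.

  def test_mm_autocasts_to_bf16(self):
    device = xm.xla_device()
    a = torch.rand(4, 4, device=device)
    b = torch.rand(4, 4, device=device)
    with autocast(device):
      c = torch.mm(a, b)
    # mm is registered with the lower_precision_fp cast policy for XLA,
    # so its output should come back as bfloat16 under autocast.
    self.assertEqual(c.dtype, torch.bfloat16)


if __name__ == '__main__':
  unittest.main()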
- python test/test_autocast.py # GPU tests if [ -x "$(command -v nvidia-smi)" ]; then From ff8f10bb38fd95cd313fbea2944e4dd753f0cd86 Mon Sep 17 00:00:00 2001 From: Meghan Date: Mon, 24 Apr 2023 18:14:46 +0000 Subject: [PATCH 13/20] updates --- test/test_autocast.py | 2 +- torch_xla/amp/autocast_mode.py | 4 +- torch_xla/csrc/autocast_mode.cpp | 122 +++++++++++++++---------------- 3 files changed, 65 insertions(+), 63 deletions(-) diff --git a/test/test_autocast.py b/test/test_autocast.py index f6ceae0fd43..3c801068df8 100644 --- a/test/test_autocast.py +++ b/test/test_autocast.py @@ -360,7 +360,7 @@ def test_autocast_methods_expect_builtin_promote(self): op, args, torch.float32, module=None, out_type=out_type) -@unittest.skipIf(xm.get_xla_supported_devices("GPU"), f"TPU autocast test.") +@unittest.skipIf(not xm.get_xla_supported_devices("TPU"), f"TPU autocast test.") class TestAutocastTPU(TestAutocastBase): def setUp(self): diff --git a/torch_xla/amp/autocast_mode.py b/torch_xla/amp/autocast_mode.py index 0e2803b311a..9459eadfe05 100644 --- a/torch_xla/amp/autocast_mode.py +++ b/torch_xla/amp/autocast_mode.py @@ -21,9 +21,11 @@ def __init__(self, enabled=enabled, dtype=torch.float16, cache_enabled=cache_enabled) - else: + elif xm.xla_device_hw(device) == 'TPU': super().__init__( "xla", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled) + else: + print('Warning: AMP only supported for XLA:TPU and XLA:GPU. Ignoring autocast.') def __enter__(self): return super().__enter__() diff --git a/torch_xla/csrc/autocast_mode.cpp b/torch_xla/csrc/autocast_mode.cpp index 8cf3569ca76..0c0cc8062b1 100644 --- a/torch_xla/csrc/autocast_mode.cpp +++ b/torch_xla/csrc/autocast_mode.cpp @@ -11,75 +11,75 @@ namespace autocast { // Base template for WrapFunction_, which is specialized to contain a "call" // method each CastPolicy -template -struct WrapFunction_ {}; +// template +// struct WrapFunction_ {}; -// CastPolicy::lower_precision_fp General_DeviceType -template -struct WrapFunction_> { - static Ret call(Args... args) { - c10::impl::ExcludeDispatchKeyGuard no_autocast( - get_autocast_dispatch_key_from_device_type(device_type)); - return (*F)( - cached_cast(get_lower_precision_fp_from_device_type(device_type), args, - device_type)...); - } -}; +// // CastPolicy::lower_precision_fp General_DeviceType +// template +// struct WrapFunction_> { +// static Ret call(Args... args) { +// c10::impl::ExcludeDispatchKeyGuard no_autocast( +// get_autocast_dispatch_key_from_device_type(device_type)); +// return (*F)( +// cached_cast(get_lower_precision_fp_from_device_type(device_type), args, +// device_type)...); +// } +// }; // CastPolicy::fp32 General_DeviceType -template -struct WrapFunction_> { - static Ret call(Args... args) { - c10::impl::ExcludeDispatchKeyGuard no_autocast( - get_autocast_dispatch_key_from_device_type(device_type)); - return (*F)(cached_cast(at::kFloat, args, device_type)...); - } -}; +// template +// struct WrapFunction_> { +// static Ret call(Args... args) { +// c10::impl::ExcludeDispatchKeyGuard no_autocast( +// get_autocast_dispatch_key_from_device_type(device_type)); +// return (*F)(cached_cast(at::kFloat, args, device_type)...); +// } +// }; // CastPolicy::promote General_DeviceType -template -struct WrapFunction_> { - static Ret call(Args... 
args) { - c10::impl::ExcludeDispatchKeyGuard no_autocast( - get_autocast_dispatch_key_from_device_type(device_type)); - auto to_type = - promote_type(get_lower_precision_fp_from_device_type(device_type), - device_type, args...); - return (*F)(cached_cast(to_type, args, device_type)...); - } -}; +// template +// struct WrapFunction_> { +// static Ret call(Args... args) { +// c10::impl::ExcludeDispatchKeyGuard no_autocast( +// get_autocast_dispatch_key_from_device_type(device_type)); +// auto to_type = +// promote_type(get_lower_precision_fp_from_device_type(device_type), +// device_type, args...); +// return (*F)(cached_cast(to_type, args, device_type)...); +// } +// }; // Wrapper to infer return_type and parameter_types for WrapFunction_ (imitating // pytorch/core/boxing/impl/WrapFunctionIntoFunctor.h) -template < - CastPolicy policy, DeviceType device_type, - class Registered, // The signature for which we're registering. The - // dispatcher's calling code invokes our registered - // functions with arguments matching Registered, so we - // register WrapFunction_::call methods with a matching - // signature to properly field those arguments. - // guts::function_traits below extracts return_type and - // parameter_types from Registered, which WrapFunction_ - // templates above use to declare their call methods. - class Redispatch, // The signature for the function we're redispatching to. - // In most cases this is the same as Registered, but for - // some ops (for example, ops where we append a dtype) - // it's useful to redispatch to a function with a - // different signature. - Redispatch* F> // The actual function we're redispatching to. -struct WrapFunction final { - using type = WrapFunction_< - policy, device_type, Redispatch, F, - typename guts::function_traits::return_type, - typename guts::function_traits::parameter_types>; -}; +// template < +// CastPolicy policy, DeviceType device_type, +// class Registered, // The signature for which we're registering. The +// // dispatcher's calling code invokes our registered +// // functions with arguments matching Registered, so we +// // register WrapFunction_::call methods with a matching +// // signature to properly field those arguments. +// // guts::function_traits below extracts return_type and +// // parameter_types from Registered, which WrapFunction_ +// // templates above use to declare their call methods. +// class Redispatch, // The signature for the function we're redispatching to. +// // In most cases this is the same as Registered, but for +// // some ops (for example, ops where we append a dtype) +// // it's useful to redispatch to a function with a +// // different signature. +// Redispatch* F> // The actual function we're redispatching to. 
+// struct WrapFunction final { +// using type = WrapFunction_< +// policy, device_type, Redispatch, F, +// typename guts::function_traits::return_type, +// typename guts::function_traits::parameter_types>; +// }; namespace { From 410912ebb2f029fdd5440509b03e45c7edf92869 Mon Sep 17 00:00:00 2001 From: Meghan Date: Mon, 24 Apr 2023 19:06:46 +0000 Subject: [PATCH 14/20] updates with pytorch --- torch_xla/csrc/autocast_mode.cpp | 102 ++++++------------------------- 1 file changed, 18 insertions(+), 84 deletions(-) diff --git a/torch_xla/csrc/autocast_mode.cpp b/torch_xla/csrc/autocast_mode.cpp index 0c0cc8062b1..fee1b90ae15 100644 --- a/torch_xla/csrc/autocast_mode.cpp +++ b/torch_xla/csrc/autocast_mode.cpp @@ -8,93 +8,27 @@ namespace at { namespace autocast { +namespace { -// Base template for WrapFunction_, which is specialized to contain a "call" -// method each CastPolicy -// template -// struct WrapFunction_ {}; - -// // CastPolicy::lower_precision_fp General_DeviceType -// template -// struct WrapFunction_> { -// static Ret call(Args... args) { -// c10::impl::ExcludeDispatchKeyGuard no_autocast( -// get_autocast_dispatch_key_from_device_type(device_type)); -// return (*F)( -// cached_cast(get_lower_precision_fp_from_device_type(device_type), args, -// device_type)...); -// } -// }; - -// CastPolicy::fp32 General_DeviceType -// template -// struct WrapFunction_> { -// static Ret call(Args... args) { -// c10::impl::ExcludeDispatchKeyGuard no_autocast( -// get_autocast_dispatch_key_from_device_type(device_type)); -// return (*F)(cached_cast(at::kFloat, args, device_type)...); -// } -// }; - -// CastPolicy::promote General_DeviceType -// template -// struct WrapFunction_> { -// static Ret call(Args... args) { -// c10::impl::ExcludeDispatchKeyGuard no_autocast( -// get_autocast_dispatch_key_from_device_type(device_type)); -// auto to_type = -// promote_type(get_lower_precision_fp_from_device_type(device_type), -// device_type, args...); -// return (*F)(cached_cast(to_type, args, device_type)...); -// } -// }; - -// Wrapper to infer return_type and parameter_types for WrapFunction_ (imitating -// pytorch/core/boxing/impl/WrapFunctionIntoFunctor.h) -// template < -// CastPolicy policy, DeviceType device_type, -// class Registered, // The signature for which we're registering. The -// // dispatcher's calling code invokes our registered -// // functions with arguments matching Registered, so we -// // register WrapFunction_::call methods with a matching -// // signature to properly field those arguments. -// // guts::function_traits below extracts return_type and -// // parameter_types from Registered, which WrapFunction_ -// // templates above use to declare their call methods. -// class Redispatch, // The signature for the function we're redispatching to. -// // In most cases this is the same as Registered, but for -// // some ops (for example, ops where we append a dtype) -// // it's useful to redispatch to a function with a -// // different signature. -// Redispatch* F> // The actual function we're redispatching to. 
-// struct WrapFunction final { -// using type = WrapFunction_< -// policy, device_type, Redispatch, F, -// typename guts::function_traits::return_type, -// typename guts::function_traits::parameter_types>; -// }; +#define KERNEL_XLA(OP, POLICY) \ + KERNEL(c10::DeviceType::XLA, OP, POLICY) -namespace { +#define KERNEL_XLA2(OP, OVERLOAD, POLICY) \ + KERNEL2(c10::DeviceType::XLA, OP, OVERLOAD, POLICY) -// KERNEL_XLA registration for AutocastXLA -#define KERNEL_XLA(OP, POLICY) \ - m.impl(TORCH_SELECTIVE_NAME("aten::" #OP), \ - &WrapFunction::type::call); -#define KERNEL_XLA2(OP, OVERLOAD, POLICY) \ - m.impl(TORCH_SELECTIVE_NAME("aten::" #OP "." #OVERLOAD), \ - &WrapFunction::type::call); +#define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_XLA( \ + REDISPATCH_FUNC, \ + REGISTER_NAME, \ + REGISTER_SIGNATURE, \ + REDISPATCH_SIGNATURE, \ + POLICY) \ + KERNEL_DIFFERENT_REDISPATCH_SIGNATURE( \ + c10::DeviceType::XLA, \ + REDISPATCH_FUNC, \ + REGISTER_NAME, \ + REGISTER_SIGNATURE, \ + REDISPATCH_SIGNATURE, \ + POLICY) TORCH_LIBRARY_IMPL(_, AutocastXLA, m) { m.fallback(torch::CppFunction::makeFallthrough()); From aabbd128b9f168c3388b5ad054d99b7734fb5fe0 Mon Sep 17 00:00:00 2001 From: Meghan Date: Mon, 24 Apr 2023 19:21:03 +0000 Subject: [PATCH 15/20] lint --- torch_xla/amp/autocast_mode.py | 4 +++- torch_xla/csrc/autocast_mode.cpp | 22 +++++++--------------- 2 files changed, 10 insertions(+), 16 deletions(-) diff --git a/torch_xla/amp/autocast_mode.py b/torch_xla/amp/autocast_mode.py index 9459eadfe05..999a099b777 100644 --- a/torch_xla/amp/autocast_mode.py +++ b/torch_xla/amp/autocast_mode.py @@ -25,7 +25,9 @@ def __init__(self, super().__init__( "xla", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled) else: - print('Warning: AMP only supported for XLA:TPU and XLA:GPU. Ignoring autocast.') + print( + 'Warning: AMP only supported for XLA:TPU and XLA:GPU. Ignoring autocast.' 
+ ) def __enter__(self): return super().__enter__() diff --git a/torch_xla/csrc/autocast_mode.cpp b/torch_xla/csrc/autocast_mode.cpp index fee1b90ae15..91ff7999e10 100644 --- a/torch_xla/csrc/autocast_mode.cpp +++ b/torch_xla/csrc/autocast_mode.cpp @@ -10,25 +10,17 @@ namespace at { namespace autocast { namespace { -#define KERNEL_XLA(OP, POLICY) \ - KERNEL(c10::DeviceType::XLA, OP, POLICY) +#define KERNEL_XLA(OP, POLICY) KERNEL(c10::DeviceType::XLA, OP, POLICY) #define KERNEL_XLA2(OP, OVERLOAD, POLICY) \ KERNEL2(c10::DeviceType::XLA, OP, OVERLOAD, POLICY) -#define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_XLA( \ - REDISPATCH_FUNC, \ - REGISTER_NAME, \ - REGISTER_SIGNATURE, \ - REDISPATCH_SIGNATURE, \ - POLICY) \ - KERNEL_DIFFERENT_REDISPATCH_SIGNATURE( \ - c10::DeviceType::XLA, \ - REDISPATCH_FUNC, \ - REGISTER_NAME, \ - REGISTER_SIGNATURE, \ - REDISPATCH_SIGNATURE, \ - POLICY) +#define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_XLA( \ + REDISPATCH_FUNC, REGISTER_NAME, REGISTER_SIGNATURE, REDISPATCH_SIGNATURE, \ + POLICY) \ + KERNEL_DIFFERENT_REDISPATCH_SIGNATURE(c10::DeviceType::XLA, REDISPATCH_FUNC, \ + REGISTER_NAME, REGISTER_SIGNATURE, \ + REDISPATCH_SIGNATURE, POLICY) TORCH_LIBRARY_IMPL(_, AutocastXLA, m) { m.fallback(torch::CppFunction::makeFallthrough()); From 52753340e6aa0004f867ffaf02a509abec9bde89 Mon Sep 17 00:00:00 2001 From: Meghan Date: Wed, 26 Apr 2023 20:35:19 +0000 Subject: [PATCH 16/20] build autocast_mode --- torch_xla/csrc/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/torch_xla/csrc/BUILD b/torch_xla/csrc/BUILD index edd009967f6..17733dafc38 100644 --- a/torch_xla/csrc/BUILD +++ b/torch_xla/csrc/BUILD @@ -71,6 +71,7 @@ ptxla_cc_library( "softmax_builder.cpp", "batch_norm.cpp", "resize_ops.cpp", + "autocast_mode.cpp", ] + glob(["ops/*.cpp"]), hdrs = [ ":LazyIr.h", From 643691c1b5ee47a9ff575efd595b64afb25fb0a1 Mon Sep 17 00:00:00 2001 From: Meghan Cowan Date: Wed, 7 Jun 2023 12:57:43 -0700 Subject: [PATCH 17/20] Disable bazel remote cache --- .github/workflows/_build.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/_build.yml b/.github/workflows/_build.yml index 3b3abd96f60..7563c47d65b 100644 --- a/.github/workflows/_build.yml +++ b/.github/workflows/_build.yml @@ -78,7 +78,7 @@ jobs: echo "declare -x CC=clang-8 CXX=clang++-8" | docker exec -i "${pid}" sh -c "cat >> xla_env" echo "declare -x XLA_USE_XRT=1" | docker exec -i "${pid}" sh -c "cat >> xla_env" echo "declare -x XLA_CUDA=1" | docker exec -i "${pid}" sh -c "cat >> xla_env" - echo "declare -x BAZEL_REMOTE_CACHE=1" | docker exec -i "${pid}" sh -c "cat >> xla_env" + echo "declare -x BAZEL_REMOTE_CACHE=0" | docker exec -i "${pid}" sh -c "cat >> xla_env" echo "${GCLOUD_SERVICE_KEY}" | docker exec -i "${pid}" sh -c "cat >> default_credentials.json" - name: Build From 4eba81f7fdad5c7c56ff995e429d9aa20f55a0d0 Mon Sep 17 00:00:00 2001 From: Meghan Date: Thu, 8 Jun 2023 17:45:53 +0000 Subject: [PATCH 18/20] experiment with no new files --- .github/workflows/_build.yml | 2 +- torch_xla/csrc/BUILD | 2 +- torch_xla/csrc/tensor.cpp | 162 +++++++++++++++++++++++++++++++++++ 3 files changed, 164 insertions(+), 2 deletions(-) diff --git a/.github/workflows/_build.yml b/.github/workflows/_build.yml index 7563c47d65b..3b3abd96f60 100644 --- a/.github/workflows/_build.yml +++ b/.github/workflows/_build.yml @@ -78,7 +78,7 @@ jobs: echo "declare -x CC=clang-8 CXX=clang++-8" | docker exec -i "${pid}" sh -c "cat >> xla_env" echo "declare -x XLA_USE_XRT=1" | docker exec -i "${pid}" sh -c "cat >> 
xla_env" echo "declare -x XLA_CUDA=1" | docker exec -i "${pid}" sh -c "cat >> xla_env" - echo "declare -x BAZEL_REMOTE_CACHE=0" | docker exec -i "${pid}" sh -c "cat >> xla_env" + echo "declare -x BAZEL_REMOTE_CACHE=1" | docker exec -i "${pid}" sh -c "cat >> xla_env" echo "${GCLOUD_SERVICE_KEY}" | docker exec -i "${pid}" sh -c "cat >> default_credentials.json" - name: Build diff --git a/torch_xla/csrc/BUILD b/torch_xla/csrc/BUILD index b839acc5692..2275bec5623 100644 --- a/torch_xla/csrc/BUILD +++ b/torch_xla/csrc/BUILD @@ -34,7 +34,7 @@ ptxla_cc_library( "aten_autograd_ops.cpp", "aten_xla_bridge.cpp", "aten_xla_type.cpp", - "autocast_mode.cpp", + -- "autocast_mode.cpp", "batch_norm.cpp", "convert_ops.cpp", "convolution.cpp", diff --git a/torch_xla/csrc/tensor.cpp b/torch_xla/csrc/tensor.cpp index 04d782bb850..1733c79c752 100644 --- a/torch_xla/csrc/tensor.cpp +++ b/torch_xla/csrc/tensor.cpp @@ -47,6 +47,14 @@ #include "torch_xla/csrc/xla_graph_executor.h" #include "torch_xla/csrc/xla_sharding_util.h" +#include +#include +#include +#include +#include +#include +#include + namespace torch_xla { XLATensor::Data::~Data() { XLAGraphExecutor::Get()->UnregisterTensor(this); } @@ -875,3 +883,157 @@ int64_t XLATensor::GetOpaqueHandle() const { } } // namespace torch_xla + +// TODO: TEMPORARY +namespace at { +namespace autocast { +namespace { + +#define KERNEL_XLA(OP, POLICY) KERNEL(c10::DeviceType::XLA, OP, POLICY) + +#define KERNEL_XLA2(OP, OVERLOAD, POLICY) \ + KERNEL2(c10::DeviceType::XLA, OP, OVERLOAD, POLICY) + +#define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_XLA( \ + REDISPATCH_FUNC, REGISTER_NAME, REGISTER_SIGNATURE, REDISPATCH_SIGNATURE, \ + POLICY) \ + KERNEL_DIFFERENT_REDISPATCH_SIGNATURE(c10::DeviceType::XLA, REDISPATCH_FUNC, \ + REGISTER_NAME, REGISTER_SIGNATURE, \ + REDISPATCH_SIGNATURE, POLICY) + +TORCH_LIBRARY_IMPL(_, AutocastXLA, m) { + m.fallback(torch::CppFunction::makeFallthrough()); +} + +TORCH_LIBRARY_IMPL(aten, AutocastXLA, m) { + // lower_precision_fp cast policy + KERNEL_XLA(conv1d, lower_precision_fp) + KERNEL_XLA2(conv1d, padding, lower_precision_fp) + KERNEL_XLA(conv2d, lower_precision_fp) + KERNEL_XLA2(conv2d, padding, lower_precision_fp) + KERNEL_XLA(conv3d, lower_precision_fp) + KERNEL_XLA2(conv3d, padding, lower_precision_fp) + KERNEL_XLA(bmm, lower_precision_fp) + KERNEL_XLA(mm, lower_precision_fp) + KERNEL_XLA(baddbmm, lower_precision_fp) + KERNEL_XLA(addmm, lower_precision_fp) + KERNEL_XLA(addbmm, lower_precision_fp) + KERNEL_XLA(linear, lower_precision_fp) + KERNEL_XLA(matmul, lower_precision_fp) + KERNEL_XLA(conv_tbc, lower_precision_fp) + KERNEL_XLA(conv_transpose1d, lower_precision_fp) + KERNEL_XLA2(conv_transpose2d, input, lower_precision_fp) + KERNEL_XLA2(conv_transpose3d, input, lower_precision_fp) + KERNEL_XLA(prelu, lower_precision_fp) + KERNEL_XLA(relu, lower_precision_fp) + KERNEL_XLA(max_pool2d, lower_precision_fp) + + // fp32 cast policy + // Commented out ops are included in the AutoCastCPU Policy, + // but not lowered. Enable if op is lowered. 
+ KERNEL_XLA(batch_norm, fp32) + KERNEL_XLA2(log_softmax, int, fp32) + KERNEL_XLA2(log_softmax, Dimname, fp32) + KERNEL_XLA(binary_cross_entropy, fp32) + // KERNEL_XLA(grid_sampler, fp32) + // KERNEL_XLA(polar, fp32) + KERNEL_XLA(prod, fp32) + KERNEL_XLA2(prod, dim_int, fp32) + KERNEL_XLA2(prod, dim_Dimname, fp32) + // KERNEL_XLA(quantile, fp32) + // KERNEL_XLA2(quantile, scalar, fp32) + // KERNEL_XLA(nanquantile, fp32) + // KERNEL_XLA2(nanquantile, scalar, fp32) + // KERNEL_XLA(stft, fp32) + // KERNEL_XLA2(stft, center, fp32) + KERNEL_XLA(cdist, fp32) + // KERNEL_XLA(grid_sampler_2d, fp32) + // KERNEL_XLA(grid_sampler_3d, fp32) + KERNEL_XLA(trace, fp32) + // KERNEL_XLA(view_as_complex, fp32) + KERNEL_XLA(cholesky, fp32) + KERNEL_XLA(cholesky_inverse, fp32) + KERNEL_XLA(cholesky_solve, fp32) + KERNEL_XLA(inverse, fp32) + // KERNEL_XLA(lu_solve, fp32) + // KERNEL_XLA(orgqr, fp32) + // KERNEL_XLA(ormqr, fp32) + // KERNEL_XLA(pinverse, fp32) + KERNEL_XLA(reflection_pad1d, fp32) + KERNEL_XLA(reflection_pad2d, fp32) + KERNEL_XLA(replication_pad1d, fp32) + KERNEL_XLA(replication_pad2d, fp32) + KERNEL_XLA(replication_pad3d, fp32) + KERNEL_XLA(mse_loss, fp32) + KERNEL_XLA(cosine_embedding_loss, fp32) + KERNEL_XLA(nll_loss, fp32) + KERNEL_XLA(nll_loss2d, fp32) + KERNEL_XLA(hinge_embedding_loss, fp32) + // KERNEL_XLA(poisson_nll_loss, fp32) + KERNEL_XLA(smooth_l1_loss, fp32) + // KERNEL_XLA(cross_entropy_loss, fp32) + KERNEL_XLA(l1_loss, fp32) + // KERNEL_XLA(huber_loss, fp32) + KERNEL_XLA(margin_ranking_loss, fp32) + KERNEL_XLA(soft_margin_loss, fp32) + KERNEL_XLA(triplet_margin_loss, fp32) + KERNEL_XLA(multi_margin_loss, fp32) + KERNEL_XLA2(ctc_loss, IntList, fp32) + KERNEL_XLA2(ctc_loss, Tensor, fp32) + KERNEL_XLA(kl_div, fp32) + KERNEL_XLA(multilabel_margin_loss, fp32) + KERNEL_XLA(binary_cross_entropy_with_logits, fp32) + // KERNEL_XLA(fft_fft, fp32) + // KERNEL_XLA(fft_ifft, fp32) + // KERNEL_XLA(fft_fft2, fp32) + // KERNEL_XLA(fft_ifft2, fp32) + // KERNEL_XLA(fft_fftn, fp32) + // KERNEL_XLA(fft_ifftn, fp32) + // KERNEL_XLA(fft_rfft, fp32) + // KERNEL_XLA(fft_irfft, fp32) + // KERNEL_XLA(fft_rfft2, fp32) + // KERNEL_XLA(fft_irfft2, fp32) + // KERNEL_XLA(fft_rfftn, fp32) + // KERNEL_XLA(fft_irfftn, fp32) + // KERNEL_XLA(fft_hfft, fp32) + // KERNEL_XLA(fft_ihfft, fp32) + // KERNEL_XLA(linalg_cond, fp32) + // KERNEL_XLA2(linalg_cond, p_str, fp32) + // KERNEL_XLA(linalg_matrix_rank, fp32) + // KERNEL_XLA2(linalg_matrix_rank, tol_tensor, fp32) + // KERNEL_XLA2(linalg_matrix_rank, atol_rtol_tensor, fp32) + // KERNEL_XLA2(linalg_matrix_rank, atol_rtol_float, fp32) + // KERNEL_XLA(linalg_solve, fp32) + // KERNEL_XLA(linalg_cholesky, fp32) + // KERNEL_XLA(linalg_svdvals, fp32) + // KERNEL_XLA(linalg_eigvals, fp32) + // KERNEL_XLA(linalg_eigvalsh, fp32) + // KERNEL_XLA(linalg_inv, fp32) + // KERNEL_XLA(linalg_householder_product, fp32) + // KERNEL_XLA(linalg_tensorinv, fp32) + // KERNEL_XLA(linalg_tensorsolve, fp32) + // KERNEL_XLA(fake_quantize_per_tensor_affine, fp32) + // KERNEL_XLA(geqrf, fp32) + // KERNEL_XLA(_lu_with_info, fp32) + KERNEL_XLA(qr, fp32) + KERNEL_XLA(svd, fp32) + KERNEL_XLA(triangular_solve, fp32) + KERNEL_XLA(multilabel_margin_loss_forward, fp32) + // KERNEL_XLA(linalg_qr, fp32) + // KERNEL_XLA(linalg_cholesky_ex, fp32) + KERNEL_XLA(linalg_svd, fp32) + // KERNEL_XLA(linalg_eig, fp32) + // KERNEL_XLA(linalg_eigh, fp32) + // KERNEL_XLA(linalg_lstsq, fp32) + KERNEL_XLA(linalg_inv_ex, fp32) + + // promote + KERNEL_XLA(stack, promote) + KERNEL_XLA(cat, promote) + KERNEL_XLA(index_copy, 
promote) + KERNEL_XLA2(index_copy, dimname, promote) +} + +} // namespace +} // namespace autocast +} // namespace at From ffac43f58de7b84488af0ff01349364ee5632b08 Mon Sep 17 00:00:00 2001 From: Meghan Date: Thu, 8 Jun 2023 19:13:54 +0000 Subject: [PATCH 19/20] revert back --- torch_xla/csrc/BUILD | 2 +- torch_xla/csrc/tensor.cpp | 160 -------------------------------------- 2 files changed, 1 insertion(+), 161 deletions(-) diff --git a/torch_xla/csrc/BUILD b/torch_xla/csrc/BUILD index 2275bec5623..b839acc5692 100644 --- a/torch_xla/csrc/BUILD +++ b/torch_xla/csrc/BUILD @@ -34,7 +34,7 @@ ptxla_cc_library( "aten_autograd_ops.cpp", "aten_xla_bridge.cpp", "aten_xla_type.cpp", - -- "autocast_mode.cpp", + "autocast_mode.cpp", "batch_norm.cpp", "convert_ops.cpp", "convolution.cpp", diff --git a/torch_xla/csrc/tensor.cpp b/torch_xla/csrc/tensor.cpp index 1733c79c752..3cd7ef9e5e2 100644 --- a/torch_xla/csrc/tensor.cpp +++ b/torch_xla/csrc/tensor.cpp @@ -47,13 +47,6 @@ #include "torch_xla/csrc/xla_graph_executor.h" #include "torch_xla/csrc/xla_sharding_util.h" -#include -#include -#include -#include -#include -#include -#include namespace torch_xla { @@ -884,156 +877,3 @@ int64_t XLATensor::GetOpaqueHandle() const { } // namespace torch_xla -// TODO: TEMPORARY -namespace at { -namespace autocast { -namespace { - -#define KERNEL_XLA(OP, POLICY) KERNEL(c10::DeviceType::XLA, OP, POLICY) - -#define KERNEL_XLA2(OP, OVERLOAD, POLICY) \ - KERNEL2(c10::DeviceType::XLA, OP, OVERLOAD, POLICY) - -#define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_XLA( \ - REDISPATCH_FUNC, REGISTER_NAME, REGISTER_SIGNATURE, REDISPATCH_SIGNATURE, \ - POLICY) \ - KERNEL_DIFFERENT_REDISPATCH_SIGNATURE(c10::DeviceType::XLA, REDISPATCH_FUNC, \ - REGISTER_NAME, REGISTER_SIGNATURE, \ - REDISPATCH_SIGNATURE, POLICY) - -TORCH_LIBRARY_IMPL(_, AutocastXLA, m) { - m.fallback(torch::CppFunction::makeFallthrough()); -} - -TORCH_LIBRARY_IMPL(aten, AutocastXLA, m) { - // lower_precision_fp cast policy - KERNEL_XLA(conv1d, lower_precision_fp) - KERNEL_XLA2(conv1d, padding, lower_precision_fp) - KERNEL_XLA(conv2d, lower_precision_fp) - KERNEL_XLA2(conv2d, padding, lower_precision_fp) - KERNEL_XLA(conv3d, lower_precision_fp) - KERNEL_XLA2(conv3d, padding, lower_precision_fp) - KERNEL_XLA(bmm, lower_precision_fp) - KERNEL_XLA(mm, lower_precision_fp) - KERNEL_XLA(baddbmm, lower_precision_fp) - KERNEL_XLA(addmm, lower_precision_fp) - KERNEL_XLA(addbmm, lower_precision_fp) - KERNEL_XLA(linear, lower_precision_fp) - KERNEL_XLA(matmul, lower_precision_fp) - KERNEL_XLA(conv_tbc, lower_precision_fp) - KERNEL_XLA(conv_transpose1d, lower_precision_fp) - KERNEL_XLA2(conv_transpose2d, input, lower_precision_fp) - KERNEL_XLA2(conv_transpose3d, input, lower_precision_fp) - KERNEL_XLA(prelu, lower_precision_fp) - KERNEL_XLA(relu, lower_precision_fp) - KERNEL_XLA(max_pool2d, lower_precision_fp) - - // fp32 cast policy - // Commented out ops are included in the AutoCastCPU Policy, - // but not lowered. Enable if op is lowered. 
- KERNEL_XLA(batch_norm, fp32) - KERNEL_XLA2(log_softmax, int, fp32) - KERNEL_XLA2(log_softmax, Dimname, fp32) - KERNEL_XLA(binary_cross_entropy, fp32) - // KERNEL_XLA(grid_sampler, fp32) - // KERNEL_XLA(polar, fp32) - KERNEL_XLA(prod, fp32) - KERNEL_XLA2(prod, dim_int, fp32) - KERNEL_XLA2(prod, dim_Dimname, fp32) - // KERNEL_XLA(quantile, fp32) - // KERNEL_XLA2(quantile, scalar, fp32) - // KERNEL_XLA(nanquantile, fp32) - // KERNEL_XLA2(nanquantile, scalar, fp32) - // KERNEL_XLA(stft, fp32) - // KERNEL_XLA2(stft, center, fp32) - KERNEL_XLA(cdist, fp32) - // KERNEL_XLA(grid_sampler_2d, fp32) - // KERNEL_XLA(grid_sampler_3d, fp32) - KERNEL_XLA(trace, fp32) - // KERNEL_XLA(view_as_complex, fp32) - KERNEL_XLA(cholesky, fp32) - KERNEL_XLA(cholesky_inverse, fp32) - KERNEL_XLA(cholesky_solve, fp32) - KERNEL_XLA(inverse, fp32) - // KERNEL_XLA(lu_solve, fp32) - // KERNEL_XLA(orgqr, fp32) - // KERNEL_XLA(ormqr, fp32) - // KERNEL_XLA(pinverse, fp32) - KERNEL_XLA(reflection_pad1d, fp32) - KERNEL_XLA(reflection_pad2d, fp32) - KERNEL_XLA(replication_pad1d, fp32) - KERNEL_XLA(replication_pad2d, fp32) - KERNEL_XLA(replication_pad3d, fp32) - KERNEL_XLA(mse_loss, fp32) - KERNEL_XLA(cosine_embedding_loss, fp32) - KERNEL_XLA(nll_loss, fp32) - KERNEL_XLA(nll_loss2d, fp32) - KERNEL_XLA(hinge_embedding_loss, fp32) - // KERNEL_XLA(poisson_nll_loss, fp32) - KERNEL_XLA(smooth_l1_loss, fp32) - // KERNEL_XLA(cross_entropy_loss, fp32) - KERNEL_XLA(l1_loss, fp32) - // KERNEL_XLA(huber_loss, fp32) - KERNEL_XLA(margin_ranking_loss, fp32) - KERNEL_XLA(soft_margin_loss, fp32) - KERNEL_XLA(triplet_margin_loss, fp32) - KERNEL_XLA(multi_margin_loss, fp32) - KERNEL_XLA2(ctc_loss, IntList, fp32) - KERNEL_XLA2(ctc_loss, Tensor, fp32) - KERNEL_XLA(kl_div, fp32) - KERNEL_XLA(multilabel_margin_loss, fp32) - KERNEL_XLA(binary_cross_entropy_with_logits, fp32) - // KERNEL_XLA(fft_fft, fp32) - // KERNEL_XLA(fft_ifft, fp32) - // KERNEL_XLA(fft_fft2, fp32) - // KERNEL_XLA(fft_ifft2, fp32) - // KERNEL_XLA(fft_fftn, fp32) - // KERNEL_XLA(fft_ifftn, fp32) - // KERNEL_XLA(fft_rfft, fp32) - // KERNEL_XLA(fft_irfft, fp32) - // KERNEL_XLA(fft_rfft2, fp32) - // KERNEL_XLA(fft_irfft2, fp32) - // KERNEL_XLA(fft_rfftn, fp32) - // KERNEL_XLA(fft_irfftn, fp32) - // KERNEL_XLA(fft_hfft, fp32) - // KERNEL_XLA(fft_ihfft, fp32) - // KERNEL_XLA(linalg_cond, fp32) - // KERNEL_XLA2(linalg_cond, p_str, fp32) - // KERNEL_XLA(linalg_matrix_rank, fp32) - // KERNEL_XLA2(linalg_matrix_rank, tol_tensor, fp32) - // KERNEL_XLA2(linalg_matrix_rank, atol_rtol_tensor, fp32) - // KERNEL_XLA2(linalg_matrix_rank, atol_rtol_float, fp32) - // KERNEL_XLA(linalg_solve, fp32) - // KERNEL_XLA(linalg_cholesky, fp32) - // KERNEL_XLA(linalg_svdvals, fp32) - // KERNEL_XLA(linalg_eigvals, fp32) - // KERNEL_XLA(linalg_eigvalsh, fp32) - // KERNEL_XLA(linalg_inv, fp32) - // KERNEL_XLA(linalg_householder_product, fp32) - // KERNEL_XLA(linalg_tensorinv, fp32) - // KERNEL_XLA(linalg_tensorsolve, fp32) - // KERNEL_XLA(fake_quantize_per_tensor_affine, fp32) - // KERNEL_XLA(geqrf, fp32) - // KERNEL_XLA(_lu_with_info, fp32) - KERNEL_XLA(qr, fp32) - KERNEL_XLA(svd, fp32) - KERNEL_XLA(triangular_solve, fp32) - KERNEL_XLA(multilabel_margin_loss_forward, fp32) - // KERNEL_XLA(linalg_qr, fp32) - // KERNEL_XLA(linalg_cholesky_ex, fp32) - KERNEL_XLA(linalg_svd, fp32) - // KERNEL_XLA(linalg_eig, fp32) - // KERNEL_XLA(linalg_eigh, fp32) - // KERNEL_XLA(linalg_lstsq, fp32) - KERNEL_XLA(linalg_inv_ex, fp32) - - // promote - KERNEL_XLA(stack, promote) - KERNEL_XLA(cat, promote) - KERNEL_XLA(index_copy, 
promote) - KERNEL_XLA2(index_copy, dimname, promote) -} - -} // namespace -} // namespace autocast -} // namespace at From 4097324dd4807492581a2d7f7d3f75173b91b9aa Mon Sep 17 00:00:00 2001 From: Meghan Date: Thu, 8 Jun 2023 22:08:40 +0000 Subject: [PATCH 20/20] lint --- torch_xla/csrc/tensor.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/torch_xla/csrc/tensor.cpp b/torch_xla/csrc/tensor.cpp index 3cd7ef9e5e2..04d782bb850 100644 --- a/torch_xla/csrc/tensor.cpp +++ b/torch_xla/csrc/tensor.cpp @@ -47,7 +47,6 @@ #include "torch_xla/csrc/xla_graph_executor.h" #include "torch_xla/csrc/xla_sharding_util.h" - namespace torch_xla { XLATensor::Data::~Data() { XLAGraphExecutor::Get()->UnregisterTensor(this); } @@ -876,4 +875,3 @@ int64_t XLATensor::GetOpaqueHandle() const { } } // namespace torch_xla -
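For reference, an illustrative way to observe what this series wires up end to end (not part of the patch; it assumes a torch build that exposes torch.is_autocast_xla_enabled, which the tests above rely on):

import torch
import torch_xla.core.xla_model as xm
from torch_xla.amp import autocast

device = xm.xla_device()
with autocast(device):
  if xm.xla_device_hw(device) == 'GPU':
    # XLA:GPU tensors are tagged with the AutocastCUDA dispatch key
    # (tensor_impl.cpp), so the CUDA autocast flag is the one that flips
    # and the region runs in float16.
    print('cuda autocast enabled:', torch.is_autocast_enabled())
  else:
    # XLA:TPU keeps the AutocastXLA key and defaults to bfloat16.
    print('xla autocast enabled:', torch.is_autocast_xla_enabled())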