diff --git a/scripts/gen.py b/scripts/gen.py
index 5efdab4cf71a..b541899391a3 100755
--- a/scripts/gen.py
+++ b/scripts/gen.py
@@ -80,7 +80,7 @@ class ArgTemplate(string.Template):
 _XPARSER = lark.Lark(
     _GRAMMAR, parser='lalr', propagate_positions=True, keep_all_tokens=True)
 
-# _FN_WHITELIST/_FN_FULL_OVERRIDE/_FN_BLACKLIST takes either name or mapsig.
+# _FN_FULL_OVERRIDE/_FN_BLACKLIST takes either name or mapsig.
 _FN_BLACKLIST = set([])
 
 # List of non-leaf ops we want to override both forward + backward.
@@ -90,12 +90,6 @@ class ArgTemplate(string.Template):
     'max_pool3d(Tensor, IntArrayRef, IntArrayRef, IntArrayRef, IntArrayRef, bool) -> Tensor',
 ])
 
-# List of non-leaf ops we want to override forward.
-# TODO(#1362, #1364)
-_FN_WHITELIST = _FN_FULL_OVERRIDE | set([
-    'copy_(Tensor, Tensor, bool) -> Tensor',
-])
-
 _FN_BLACKLIST_REGEX = [
     # ATEN functions
     r'[^(]*cudnn',
@@ -1033,7 +1027,7 @@ def generate_registrations(fgens, overrides):
 
   # XLA is only able to override leaf ops and whitelisted non-leaf ops.
   def is_overrideable(fgen):
-    return fgen.leaf or fgen.mapsig in _FN_WHITELIST or fgen.func in _FN_WHITELIST
+    return fgen.leaf or fgen.mapsig in _FN_FULL_OVERRIDE or fgen.func in _FN_FULL_OVERRIDE
 
 
 def generate_functions(fgens):
diff --git a/test/cpp/test_aten_xla_tensor.cpp b/test/cpp/test_aten_xla_tensor.cpp
index 56c857b216a3..fc2db5c91d22 100644
--- a/test/cpp/test_aten_xla_tensor.cpp
+++ b/test/cpp/test_aten_xla_tensor.cpp
@@ -48,7 +48,7 @@ TEST_F(AtenXlaTensorTest, TestTo) {
   });
 
   ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
-  ExpectCounterChanged("xla::copy_", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::_copy_from", cpp_test::GetIgnoredCounters());
 }
 
 TEST_F(AtenXlaTensorTest, TestIsFloatingPoint) {
@@ -2571,7 +2571,7 @@ TEST_F(AtenXlaTensorTest, TestZerosLikeOptions) {
 
   ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
   ExpectCounterChanged("xla::empty", cpp_test::GetIgnoredCounters());
-  ExpectCounterChanged("xla::copy_", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::_copy_from", cpp_test::GetIgnoredCounters());
 }
 
 TEST_F(AtenXlaTensorTest, TestZeros) {
@@ -2626,7 +2626,7 @@ TEST_F(AtenXlaTensorTest, TestOnesLikeOptions) {
 
   ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
   ExpectCounterChanged("xla::empty", cpp_test::GetIgnoredCounters());
-  ExpectCounterChanged("xla::copy_", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::_copy_from", cpp_test::GetIgnoredCounters());
 }
 
 TEST_F(AtenXlaTensorTest, TestFull) {
@@ -2670,7 +2670,7 @@ TEST_F(AtenXlaTensorTest, TestFullLikeOptions) {
 
   ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
   ExpectCounterChanged("xla::empty", cpp_test::GetIgnoredCounters());
-  ExpectCounterChanged("xla::copy_", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::_copy_from", cpp_test::GetIgnoredCounters());
 }
 
 TEST_F(AtenXlaTensorTest, TestARange) {
@@ -3075,6 +3075,23 @@ TEST_F(AtenXlaTensorTest, TestEinsumOuter) {
   ExpectCounterChanged("xla::view", cpp_test::GetIgnoredCounters());
 }
 
+TEST_F(AtenXlaTensorTest, TestEinsumOuterBackward) {
+  torch::Tensor a =
+      torch::rand({5}, torch::TensorOptions(torch::kFloat).requires_grad(true));
+  torch::Tensor b =
+      torch::rand({5}, torch::TensorOptions(torch::kFloat).requires_grad(true));
+  std::string equation = "i,j->ij";
+  auto testfn = [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
+    return torch::einsum(equation, inputs);
+  };
+  ForEachDevice([&](const torch::Device& device) {
+    TestBackward({a, b}, device, testfn, /*rtol=*/1e-3, /*atol=*/1e-4);
+  });
+
+  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::view", cpp_test::GetIgnoredCounters());
+}
+
 TEST_F(AtenXlaTensorTest, TestEinsumBatchMatMul) {
   torch::Tensor a = torch::rand({3, 2, 5}, torch::TensorOptions(torch::kFloat));
   torch::Tensor b = torch::rand({3, 5, 4}, torch::TensorOptions(torch::kFloat));
@@ -7100,7 +7117,7 @@ TEST_F(AtenXlaTensorTest, TestContiguous) {
   });
 
   ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
-  ExpectCounterChanged("xla::copy_", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::_copy_from", cpp_test::GetIgnoredCounters());
 }
 
 TEST_F(AtenXlaTensorTest, TestSqueezeAll) {
diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp
index d3e5ba369fcb..a189b68a06ba 100644
--- a/torch_xla/csrc/aten_xla_type.cpp
+++ b/torch_xla/csrc/aten_xla_type.cpp
@@ -223,7 +223,22 @@ at::Tensor AtenXlaType::_adaptive_avg_pool2d_backward(
 at::Tensor AtenXlaType::_copy_from(const at::Tensor& self,
                                    const at::Tensor& dst, bool non_blocking) {
   XLA_FN_COUNTER("xla::");
-  copy_(const_cast<at::Tensor&>(dst), self, non_blocking);
+  auto dst_tensor = bridge::TryGetXlaTensor(dst);
+  auto self_tensor = bridge::TryGetXlaTensor(self);
+  if (!self_tensor) {
+    static bool sync_update =
+        xla::sys_util::GetEnvBool("XLA_TENSOR_UPDATE_SYNC", true);
+    XLA_CHECK(dst_tensor);
+    dst_tensor->UpdateFromTensor(self, /*sync=*/sync_update);
+  } else if (!dst_tensor) {
+    at::Tensor tensor = self_tensor->ToTensor(/*detached=*/true);
+    at::Tensor typed_tensor =
+        CopyTensor(tensor, dst.scalar_type(), /*copy=*/false);
+    dst.resize_as_(typed_tensor).copy_(typed_tensor);
+  } else {
+    XLATensor::copy_(*dst_tensor, *self_tensor);
+    bridge::ReplaceXlaTensor(dst, *dst_tensor);
+  }
   return dst;
 }
 
@@ -907,28 +922,6 @@ AtenXlaType::convolution_backward_overrideable(
                    : at::Tensor());
 }
 
-at::Tensor& AtenXlaType::copy_(at::Tensor& self, const at::Tensor& src,
-                               bool non_blocking) {
-  XLA_FN_COUNTER("xla::");
-  auto self_tensor = bridge::TryGetXlaTensor(self);
-  auto src_tensor = bridge::TryGetXlaTensor(src);
-  if (!src_tensor) {
-    static bool sync_update =
-        xla::sys_util::GetEnvBool("XLA_TENSOR_UPDATE_SYNC", true);
-    XLA_CHECK(self_tensor);
-    self_tensor->UpdateFromTensor(src, /*sync=*/sync_update);
-  } else if (!self_tensor) {
-    at::Tensor tensor = src_tensor->ToTensor(/*detached=*/true);
-    at::Tensor typed_tensor =
-        CopyTensor(tensor, self.scalar_type(), /*copy=*/false);
-    self.resize_as_(typed_tensor).copy_(typed_tensor);
-  } else {
-    XLATensor::copy_(*self_tensor, *src_tensor);
-    bridge::ReplaceXlaTensor(self, *self_tensor);
-  }
-  return self;
-}
-
 at::Tensor AtenXlaType::cos(const at::Tensor& self) {
   XLA_FN_COUNTER("xla::");
   return bridge::AtenFromXlaTensor(XLATensor::cos(bridge::GetXlaTensor(self)));
diff --git a/torch_xla/csrc/aten_xla_type.h b/torch_xla/csrc/aten_xla_type.h
index 3e3294f289a0..b3f6c02beaf1 100644
--- a/torch_xla/csrc/aten_xla_type.h
+++ b/torch_xla/csrc/aten_xla_type.h
@@ -261,9 +261,6 @@ class AtenXlaType {
      at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed,
      at::IntArrayRef output_padding, int64_t groups);
 
-  static at::Tensor& copy_(at::Tensor& self, const at::Tensor& src,
-                           bool non_blocking);
-
  static at::Tensor cos(const at::Tensor& self);
 
  static at::Tensor& cos_(at::Tensor& self);
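
Note: the rewritten AtenXlaType::_copy_from above dispatches on which side of the copy is backed by an XLATensor. The sketch below is not part of the diff; it shows how each branch is reached from user-level libtorch code. It assumes torch_xla is loaded and that the device string "xla:0" is valid in this build (both assumptions for illustration).

#include <torch/torch.h>

// Hedged sketch: each statement should land in one branch of _copy_from,
// assuming a working torch_xla install and an addressable "xla:0" device.
void CopyFromBranches() {
  torch::Device xla_device("xla:0");
  torch::Tensor cpu_src = torch::rand({2, 3});

  // CPU -> XLA: self has no XLATensor but dst does, so _copy_from takes the
  // UpdateFromTensor branch (sync behavior gated by XLA_TENSOR_UPDATE_SYNC).
  torch::Tensor xla_a = cpu_src.to(xla_device);

  // XLA -> CPU: dst has no XLATensor, so _copy_from materializes the XLA
  // value via ToTensor and copies it into dst on the host.
  torch::Tensor cpu_dst = torch::empty({2, 3});
  cpu_dst.copy_(xla_a);

  // XLA -> XLA: both sides are XLATensors, so the copy stays on-device via
  // XLATensor::copy_ and bridge::ReplaceXlaTensor, with no host round trip.
  torch::Tensor xla_b = torch::zeros({2, 3}).to(xla_device);
  xla_b.copy_(xla_a);
}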
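
The new TestEinsumOuterBackward covers the non-leaf backward path that previously required copy_ on _FN_WHITELIST. A minimal standalone version of the same check in plain libtorch follows; it runs on CPU as written, and moving a and b to an XLA device would exercise the XLA lowering instead.

#include <torch/torch.h>

#include <cassert>

// Minimal sketch of the outer-product einsum backward the new test covers;
// the gradient is verified against its closed form.
void EinsumOuterBackwardExample() {
  torch::Tensor a = torch::rand({5}, torch::requires_grad());
  torch::Tensor b = torch::rand({5}, torch::requires_grad());
  // out[i][j] = a[i] * b[j]
  torch::Tensor out = torch::einsum("i,j->ij", {a, b});
  out.sum().backward();
  // d(sum_ij a_i * b_j) / d a_i = sum_j b_j, so every entry of a.grad
  // equals b.sum().
  assert(torch::allclose(a.grad(), b.sum().expand({5})));
}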