diff --git a/test/cpp/test_aten_xla_tensor.cpp b/test/cpp/test_aten_xla_tensor.cpp index 6cc9bf1fcfa..94b2b353166 100644 --- a/test/cpp/test_aten_xla_tensor.cpp +++ b/test/cpp/test_aten_xla_tensor.cpp @@ -413,7 +413,7 @@ TEST_F(AtenXlaTensorTest, TestDivInPlace) { } } ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::div_", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::div", cpp_test::GetIgnoredCounters()); } TEST_F(AtenXlaTensorTest, TestDivInPlaceWithRoundingMode) { @@ -443,7 +443,7 @@ TEST_F(AtenXlaTensorTest, TestDivInPlaceWithRoundingMode) { } } ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::div_", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::div", cpp_test::GetIgnoredCounters()); } TEST_F(AtenXlaTensorTest, TestDivScalar) { @@ -485,7 +485,7 @@ TEST_F(AtenXlaTensorTest, TestDivScalarInPlace) { } } ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::div_", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::div", cpp_test::GetIgnoredCounters()); } TEST_F(AtenXlaTensorTest, TestDivOut) { @@ -1920,7 +1920,7 @@ TEST_F(AtenXlaTensorTest, TestCosineSimilarity) { ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); ExpectCounterChanged("xla::sum", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::clamp_min_", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::clamp_min", cpp_test::GetIgnoredCounters()); } } @@ -1944,7 +1944,7 @@ TEST_F(AtenXlaTensorTest, TestCosineEmbeddingLoss) { AllClose(output, xla_output); }); ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::clamp_min_", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::clamp_min", cpp_test::GetIgnoredCounters()); } } } @@ -1968,7 +1968,7 @@ TEST_F(AtenXlaTensorTest, TestHingeEmbeddingLoss) { }); ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::clamp_min_", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::clamp_min", cpp_test::GetIgnoredCounters()); } } } @@ -2063,7 +2063,7 @@ TEST_F(AtenXlaTensorTest, TestMarginRankingLoss) { }); ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::clamp_min_", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::clamp_min", cpp_test::GetIgnoredCounters()); } } } @@ -2476,7 +2476,7 @@ TEST_F(AtenXlaTensorTest, TestAsinhInPlace) { AllClose(b, xla_b, /*rtol=*/1e-3, /*atol=*/1e-5); }); ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::asinh_", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::asinh", cpp_test::GetIgnoredCounters()); } TEST_F(AtenXlaTensorTest, TestSin) { @@ -2533,7 +2533,7 @@ TEST_F(AtenXlaTensorTest, TestAcoshInPlace) { AllClose(b, xla_b, /*rtol=*/1e-3, /*atol=*/1e-5); }); ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::acosh_", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::acosh", cpp_test::GetIgnoredCounters()); } TEST_F(AtenXlaTensorTest, TestCos) { @@ -2588,7 +2588,7 @@ TEST_F(AtenXlaTensorTest, TestAtanhInPlace) { AllClose(b, xla_b, /*rtol=*/1e-3, /*atol=*/1e-5); }); ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::atanh_", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::atanh", cpp_test::GetIgnoredCounters()); } TEST_F(AtenXlaTensorTest, TestAtan2) { @@ -2777,7 +2777,7 @@ TEST_F(AtenXlaTensorTest, TestClampMinExplicitInPlace) { AllClose(b, xla_b); }); ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::clamp_min_", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::clamp_min", cpp_test::GetIgnoredCounters()); } TEST_F(AtenXlaTensorTest, TestClampMaxExplicitInPlace) { @@ -2791,7 +2791,7 @@ TEST_F(AtenXlaTensorTest, TestClampMaxExplicitInPlace) { AllClose(b, xla_b); }); ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::clamp_max_", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::clamp_max", cpp_test::GetIgnoredCounters()); } TEST_F(AtenXlaTensorTest, TestCeil) { @@ -3192,7 +3192,7 @@ TEST_F(AtenXlaTensorTest, TestBlackmanWindow) { ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); ExpectCounterChanged("xla::arange_out", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::cos_", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::cos", cpp_test::GetIgnoredCounters()); } } @@ -3213,7 +3213,7 @@ TEST_F(AtenXlaTensorTest, TestHammingWindow) { ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); ExpectCounterChanged("xla::arange_out", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::cos_", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::cos", cpp_test::GetIgnoredCounters()); } } @@ -3231,7 +3231,7 @@ TEST_F(AtenXlaTensorTest, TestHannWindow) { ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); ExpectCounterChanged("xla::arange_out", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::cos_", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::cos", cpp_test::GetIgnoredCounters()); } } @@ -3475,7 +3475,7 @@ TEST_F(AtenXlaTensorTest, TestBatchAddBatchMatMulInPlace) { }); ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::baddbmm_", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::baddbmm", cpp_test::GetIgnoredCounters()); } TEST_F(AtenXlaTensorTest, TestBatchMatMul) { @@ -5513,7 +5513,7 @@ TEST_F(AtenXlaTensorTest, TestHardSigmoidInPlace) { AllClose(output, xla_output); }); ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::hardsigmoid_", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::hardsigmoid", cpp_test::GetIgnoredCounters()); } TEST_F(AtenXlaTensorTest, TestHardSigmoidBackward) { diff --git a/test/cpp/test_tensor.cpp b/test/cpp/test_tensor.cpp index 3c293a8a282..67fd3547ccf 100644 --- a/test/cpp/test_tensor.cpp +++ b/test/cpp/test_tensor.cpp @@ -6,6 +6,7 @@ #include "cpp_test_util.h" #include "torch/csrc/autograd/variable.h" +#include "torch_xla/csrc/aten_xla_bridge.h" #include "torch_xla/csrc/tensor.h" #include "torch_xla/csrc/tensor_util.h" #include "torch_xla_test.h" @@ -205,13 +206,14 @@ TEST_F(TensorTest, TestViewMod) { output.add_(one, 1.0); input.add_(one, 1.0); ForEachDevice([&](const Device& device) { - at::Tensor xinput = - at::zeros({32, 20, 4, 4}, at::TensorOptions(at::kFloat)); - XLATensor dev_input = XLATensor::Create(xinput, device); - XLATensor dev_one = XLATensor::Create(one, device); - XLATensor dev_output = XLATensor::view(dev_input, {-1, 320}); - XLATensor::add_(dev_output, dev_one, 1.0); - XLATensor::add_(dev_input, dev_one, 1.0); + at::Tensor dev_input = + at::zeros({32, 20, 4, 4}, + at::TensorOptions(bridge::XlaDeviceToAtenDevice(device))); + at::Tensor dev_one = at::tensor( + 1.0, at::TensorOptions(bridge::XlaDeviceToAtenDevice(device))); + at::Tensor dev_output = dev_input.view({-1, 320}); + dev_output.add_(dev_one, 1.0); + dev_input.add_(dev_one, 1.0); AllClose(output, dev_output); AllClose(input, dev_input); }); @@ -225,14 +227,15 @@ TEST_F(TensorTest, TestViewModComplex) { at::Tensor output2 = input.view({-1, 160}); output2.add_(one, 1.0); ForEachDevice([&](const Device& device) { - at::Tensor xinput = - at::zeros({32, 20, 4, 4}, at::TensorOptions(at::kFloat)); - XLATensor dev_input = XLATensor::Create(xinput, device); - XLATensor dev_one = XLATensor::Create(one, device); - XLATensor dev_output1 = XLATensor::view(dev_input, {-1, 320}); - XLATensor::add_(dev_output1, dev_one, 1.0); - XLATensor dev_output2 = XLATensor::view(dev_input, {-1, 160}); - XLATensor::add_(dev_output2, dev_one, 1.0); + at::Tensor dev_input = + at::zeros({32, 20, 4, 4}, + at::TensorOptions(bridge::XlaDeviceToAtenDevice(device))); + at::Tensor dev_one = at::tensor( + 1.0, at::TensorOptions(bridge::XlaDeviceToAtenDevice(device))); + at::Tensor dev_output1 = dev_input.view({-1, 320}); + dev_output1.add_(dev_one, 1.0); + at::Tensor dev_output2 = dev_input.view({-1, 160}); + dev_output2.add_(dev_one, 1.0); AllClose(output1, dev_output1); AllClose(output2, dev_output2); }); @@ -246,14 +249,15 @@ TEST_F(TensorTest, TestViewOfViewMod) { at::Tensor output2 = output1.view({-1, 160}); output2.add_(one, 1.0); ForEachDevice([&](const Device& device) { - at::Tensor xinput = - at::zeros({32, 20, 4, 4}, at::TensorOptions(at::kFloat)); - XLATensor dev_input = XLATensor::Create(xinput, device); - XLATensor dev_one = XLATensor::Create(one, device); - XLATensor dev_output1 = XLATensor::view(dev_input, {-1, 320}); - XLATensor::add_(dev_output1, dev_one, 1.0); - XLATensor dev_output2 = XLATensor::view(dev_output1, {-1, 160}); - XLATensor::add_(dev_output2, dev_one, 1.0); + at::Tensor dev_input = + at::zeros({32, 20, 4, 4}, + at::TensorOptions(bridge::XlaDeviceToAtenDevice(device))); + at::Tensor dev_one = at::tensor( + 1.0, at::TensorOptions(bridge::XlaDeviceToAtenDevice(device))); + at::Tensor dev_output1 = dev_input.view({-1, 320}); + dev_output1.add_(dev_one, 1.0); + at::Tensor dev_output2 = dev_input.view({-1, 160}); + dev_output2.add_(dev_one, 1.0); AllClose(output1, dev_output1); AllClose(output2, dev_output2); }); diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index 18199acd0f2..e1a63c195ad 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -444,38 +444,17 @@ at::Tensor abs(const at::Tensor& self) { return bridge::AtenFromXlaTensor(XLATensor::abs(bridge::GetXlaTensor(self))); } -at::Tensor& abs_(at::Tensor& self) { - XLA_FN_COUNTER("xla::"); - XLATensor self_tensor = bridge::GetXlaTensor(self); - XLATensor::abs_(self_tensor); - return self; -} - at::Tensor acos(const at::Tensor& self) { XLA_FN_COUNTER("xla::"); return bridge::AtenFromXlaTensor(XLATensor::acos(bridge::GetXlaTensor(self))); } -at::Tensor& acos_(at::Tensor& self) { - XLA_FN_COUNTER("xla::"); - XLATensor self_tensor = bridge::GetXlaTensor(self); - XLATensor::acos_(self_tensor); - return self; -} - at::Tensor acosh(const at::Tensor& self) { XLA_FN_COUNTER("xla::"); return bridge::AtenFromXlaTensor( XLATensor::acosh(bridge::GetXlaTensor(self))); } -at::Tensor& acosh_(at::Tensor& self) { - XLA_FN_COUNTER("xla::"); - XLATensor self_tensor = bridge::GetXlaTensor(self); - XLATensor::acosh_(self_tensor); - return self; -} - at::Tensor add(const at::Tensor& self, const at::Tensor& other, const at::Scalar& alpha) { XLA_FN_COUNTER("xla::"); @@ -497,15 +476,6 @@ at::Tensor add(const at::Tensor& self, const at::Scalar& other, }); } -at::Tensor& add_(at::Tensor& self, const at::Scalar& other, - const at::Scalar& alpha) { - XLA_FN_COUNTER("xla::"); - CheckBinaryOpTypePromotion(self, self, other); - XLATensor self_tensor = bridge::GetXlaTensor(self); - XLATensor::add_(self_tensor, other, alpha); - return self; -} - at::Tensor addcdiv(const at::Tensor& self, const at::Tensor& tensor1, const at::Tensor& tensor2, const at::Scalar& value) { XLA_FN_COUNTER("xla::"); @@ -531,15 +501,6 @@ at::Tensor addcmul(const at::Tensor& self, const at::Tensor& tensor1, bridge::GetXlaTensor(tensor2))); } -at::Tensor& addcmul_(at::Tensor& self, const at::Tensor& tensor1, - const at::Tensor& tensor2, const at::Scalar& value) { - XLA_FN_COUNTER("xla::"); - XLATensor self_tensor = bridge::GetXlaTensor(self); - XLATensor::addcmul_(self_tensor, value, bridge::GetXlaTensor(tensor1), - bridge::GetXlaTensor(tensor2)); - return self; -} - at::Tensor addmm(const at::Tensor& self, const at::Tensor& mat1, const at::Tensor& mat2, const at::Scalar& beta, const at::Scalar& alpha) { @@ -655,26 +616,12 @@ at::Tensor asin(const at::Tensor& self) { return bridge::AtenFromXlaTensor(XLATensor::asin(bridge::GetXlaTensor(self))); } -at::Tensor& asin_(at::Tensor& self) { - XLA_FN_COUNTER("xla::"); - XLATensor self_tensor = bridge::GetXlaTensor(self); - XLATensor::asin_(self_tensor); - return self; -} - at::Tensor asinh(const at::Tensor& self) { XLA_FN_COUNTER("xla::"); return bridge::AtenFromXlaTensor( XLATensor::asinh(bridge::GetXlaTensor(self))); } -at::Tensor& asinh_(at::Tensor& self) { - XLA_FN_COUNTER("xla::"); - XLATensor self_tensor = bridge::GetXlaTensor(self); - XLATensor::asinh_(self_tensor); - return self; -} - at::Tensor atan(const at::Tensor& self) { XLA_FN_COUNTER("xla::"); return bridge::AtenFromXlaTensor(XLATensor::atan(bridge::GetXlaTensor(self))); @@ -699,32 +646,6 @@ at::Tensor atan2(const at::Tensor& self, const at::Tensor& other) { }); } -at::Tensor& atan2_(at::Tensor& self, const at::Tensor& other) { - XLA_FN_COUNTER("xla::"); - // xla::Atan2 doesn't support integer types. - if (!self.is_floating_point() || !other.is_floating_point()) { - return AtenXlaTypeDefault::atan2_(self, other); - } - CheckBinaryOpTypePromotion(self, self, other); - XLATensor self_tensor = bridge::GetXlaTensor(self); - XLATensor::atan2_(self_tensor, bridge::GetXlaTensor(other)); - return self; -} - -at::Tensor& atan_(at::Tensor& self) { - XLA_FN_COUNTER("xla::"); - XLATensor self_tensor = bridge::GetXlaTensor(self); - XLATensor::atan_(self_tensor); - return self; -} - -at::Tensor& atanh_(at::Tensor& self) { - XLA_FN_COUNTER("xla::"); - XLATensor self_tensor = bridge::GetXlaTensor(self); - XLATensor::atanh_(self_tensor); - return self; -} - at::Tensor avg_pool2d(const at::Tensor& self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, @@ -809,21 +730,6 @@ at::Tensor baddbmm(const at::Tensor& self, const at::Tensor& batch1, bridge::GetXlaTensor(batch2), beta, alpha)); } -at::Tensor& baddbmm_(at::Tensor& self, const at::Tensor& batch1, - const at::Tensor& batch2, const at::Scalar& beta, - const at::Scalar& alpha) { - XLA_FN_COUNTER("xla::"); - // xla::dot doesn't support integer types. - if (!at::native::is_floating_point(batch1) || - !at::native::is_floating_point(batch2)) { - return AtenXlaTypeDefault::baddbmm_(self, batch1, batch2, beta, alpha); - } - XLATensor self_tensor = bridge::GetXlaTensor(self); - XLATensor::baddbmm_(self_tensor, bridge::GetXlaTensor(batch1), - bridge::GetXlaTensor(batch2), beta, alpha); - return self; -} - at::Tensor bernoulli(const at::Tensor& self, c10::optional generator) { XLA_FN_COUNTER("xla::"); @@ -978,13 +884,6 @@ at::Tensor ceil(const at::Tensor& self) { return bridge::AtenFromXlaTensor(XLATensor::ceil(bridge::GetXlaTensor(self))); } -at::Tensor& ceil_(at::Tensor& self) { - XLA_FN_COUNTER("xla::"); - XLATensor self_tensor = bridge::GetXlaTensor(self); - XLATensor::ceil_(self_tensor); - return self; -} - at::Tensor cholesky(const at::Tensor& self, bool upper) { XLA_FN_COUNTER("xla::"); return bridge::AtenFromXlaTensor( @@ -1005,27 +904,12 @@ at::Tensor clamp(const at::Tensor& self, const c10::optional& min, XLATensor::clamp(bridge::GetXlaTensor(self), min, max)); } -at::Tensor& clamp_(at::Tensor& self, const c10::optional& min, - const c10::optional& max) { - XLA_FN_COUNTER("xla::"); - XLATensor self_tensor = bridge::GetXlaTensor(self); - XLATensor::clamp_(self_tensor, min, max); - return self; -} - at::Tensor clamp_max(const at::Tensor& self, const at::Scalar& max) { XLA_FN_COUNTER("xla::"); return bridge::AtenFromXlaTensor( XLATensor::clamp(bridge::GetXlaTensor(self), c10::nullopt, max)); } -at::Tensor& clamp_max_(at::Tensor& self, const at::Scalar& max) { - XLA_FN_COUNTER("xla::"); - XLATensor self_tensor = bridge::GetXlaTensor(self); - XLATensor::clamp_(self_tensor, c10::nullopt, max); - return self; -} - at::Tensor& clamp_max_out(const at::Tensor& self, const at::Tensor& max, at::Tensor& out) { XLA_FN_COUNTER("xla::"); @@ -1041,13 +925,6 @@ at::Tensor clamp_min(const at::Tensor& self, const at::Scalar& min) { XLATensor::clamp(bridge::GetXlaTensor(self), min, c10::nullopt)); } -at::Tensor& clamp_min_(at::Tensor& self, const at::Scalar& min) { - XLA_FN_COUNTER("xla::"); - XLATensor self_tensor = bridge::GetXlaTensor(self); - XLATensor::clamp_(self_tensor, min, c10::nullopt); - return self; -} - at::Tensor& clamp_min_out(const at::Tensor& self, const at::Tensor& min, at::Tensor& out) { XLA_FN_COUNTER("xla::"); @@ -1120,25 +997,11 @@ at::Tensor cos(const at::Tensor& self) { return bridge::AtenFromXlaTensor(XLATensor::cos(bridge::GetXlaTensor(self))); } -at::Tensor& cos_(at::Tensor& self) { - XLA_FN_COUNTER("xla::"); - XLATensor self_tensor = bridge::GetXlaTensor(self); - XLATensor::cos_(self_tensor); - return self; -} - at::Tensor cosh(const at::Tensor& self) { XLA_FN_COUNTER("xla::"); return bridge::AtenFromXlaTensor(XLATensor::cosh(bridge::GetXlaTensor(self))); } -at::Tensor& cosh_(at::Tensor& self) { - XLA_FN_COUNTER("xla::"); - XLATensor self_tensor = bridge::GetXlaTensor(self); - XLATensor::cosh_(self_tensor); - return self; -} - at::Tensor cross(const at::Tensor& self, const at::Tensor& other, c10::optional dim) { XLA_FN_COUNTER("xla::"); @@ -1205,29 +1068,6 @@ at::Tensor div(const at::Tensor& self, const at::Scalar& other) { XLATensor::div(bridge::GetXlaTensor(self), other)); } -at::Tensor& div_(at::Tensor& self, const at::Tensor& other) { - return div_(self, other, /*rounding_mode=*/c10::nullopt); -} - -at::Tensor& div_(at::Tensor& self, const at::Tensor& other, - c10::optional rounding_mode) { - XLA_FN_COUNTER("xla::"); - CheckBinaryOpTypePromotion(self, self, other); - XLATensor self_tensor = bridge::GetXlaTensor(self); - XLATensor::div_(self_tensor, - bridge::GetOrCreateXlaTensor(other, self_tensor.GetDevice()), - rounding_mode); - return self; -} - -at::Tensor& div_(at::Tensor& self, const at::Scalar& other) { - XLA_FN_COUNTER("xla::"); - CheckBinaryOpTypePromotion(self, self, other); - XLATensor self_tensor = bridge::GetXlaTensor(self); - XLATensor::div_(self_tensor, other); - return self; -} - at::Tensor dot(const at::Tensor& self, const at::Tensor& tensor) { XLA_FN_COUNTER("xla::"); XLA_CHECK_EQ(self.dim(), 1) @@ -1328,69 +1168,27 @@ at::Tensor eq(const at::Tensor& self, const at::Tensor& other) { XLATensor::eq(bridge::GetXlaTensor(self), bridge::GetXlaTensor(other))); } -at::Tensor& eq_(at::Tensor& self, const at::Scalar& other) { - XLA_FN_COUNTER("xla::"); - XLATensor self_tensor = bridge::GetXlaTensor(self); - XLATensor::eq_(self_tensor, other); - return self; -} - -at::Tensor& eq_(at::Tensor& self, const at::Tensor& other) { - XLA_FN_COUNTER("xla::"); - XLATensor self_tensor = bridge::GetXlaTensor(self); - XLATensor::eq_(self_tensor, bridge::GetXlaTensor(other)); - return self; -} - at::Tensor erf(const at::Tensor& self) { XLA_FN_COUNTER("xla::"); return bridge::AtenFromXlaTensor(XLATensor::erf(bridge::GetXlaTensor(self))); } -at::Tensor& erf_(at::Tensor& self) { - XLA_FN_COUNTER("xla::"); - XLATensor self_tensor = bridge::GetXlaTensor(self); - XLATensor::erf_(self_tensor); - return self; -} - at::Tensor erfc(const at::Tensor& self) { XLA_FN_COUNTER("xla::"); return bridge::AtenFromXlaTensor(XLATensor::erfc(bridge::GetXlaTensor(self))); } -at::Tensor& erfc_(at::Tensor& self) { - XLA_FN_COUNTER("xla::"); - XLATensor self_tensor = bridge::GetXlaTensor(self); - XLATensor::erfc_(self_tensor); - return self; -} - at::Tensor erfinv(const at::Tensor& self) { XLA_FN_COUNTER("xla::"); return bridge::AtenFromXlaTensor( XLATensor::erfinv(bridge::GetXlaTensor(self))); } -at::Tensor& erfinv_(at::Tensor& self) { - XLA_FN_COUNTER("xla::"); - XLATensor self_tensor = bridge::GetXlaTensor(self); - XLATensor::erfinv_(self_tensor); - return self; -} - at::Tensor exp(const at::Tensor& self) { XLA_FN_COUNTER("xla::"); return bridge::AtenFromXlaTensor(XLATensor::exp(bridge::GetXlaTensor(self))); } -at::Tensor& exp_(at::Tensor& self) { - XLA_FN_COUNTER("xla::"); - XLATensor self_tensor = bridge::GetXlaTensor(self); - XLATensor::exp_(self_tensor); - return self; -} - at::Tensor expand(const at::Tensor& self, at::IntArrayRef size, bool implicit) { XLA_FN_COUNTER("xla::"); return bridge::AtenFromXlaTensor(XLATensor::expand( @@ -1403,13 +1201,6 @@ at::Tensor expm1(const at::Tensor& self) { XLATensor::expm1(bridge::GetXlaTensor(self))); } -at::Tensor& expm1_(at::Tensor& self) { - XLA_FN_COUNTER("xla::"); - XLATensor self_tensor = bridge::GetXlaTensor(self); - XLATensor::expm1_(self_tensor); - return self; -} - at::Tensor& exponential_(at::Tensor& self, double lambd, c10::optional generator) { XLA_FN_COUNTER("xla::"); @@ -1463,13 +1254,6 @@ at::Tensor floor(const at::Tensor& self) { XLATensor::floor(bridge::GetXlaTensor(self))); } -at::Tensor& floor_(at::Tensor& self) { - XLA_FN_COUNTER("xla::"); - XLATensor self_tensor = bridge::GetXlaTensor(self); - XLATensor::floor_(self_tensor); - return self; -} - at::Tensor fmod(const at::Tensor& self, const at::Tensor& other) { XLA_FN_COUNTER("xla::"); return DoBinaryOp(self, other, @@ -1488,34 +1272,11 @@ at::Tensor fmod(const at::Tensor& self, const at::Scalar& other) { }); } -at::Tensor& fmod_(at::Tensor& self, const at::Tensor& other) { - XLA_FN_COUNTER("xla::"); - CheckBinaryOpTypePromotion(self, self, other); - XLATensor self_tensor = bridge::GetXlaTensor(self); - XLATensor::fmod_(self_tensor, bridge::GetXlaTensor(other)); - return self; -} - -at::Tensor& fmod_(at::Tensor& self, const at::Scalar& other) { - XLA_FN_COUNTER("xla::"); - CheckBinaryOpTypePromotion(self, self, other); - XLATensor self_tensor = bridge::GetXlaTensor(self); - XLATensor::fmod_(self_tensor, other); - return self; -} - at::Tensor frac(const at::Tensor& self) { XLA_FN_COUNTER("xla::"); return bridge::AtenFromXlaTensor(XLATensor::frac(bridge::GetXlaTensor(self))); } -at::Tensor& frac_(at::Tensor& self) { - XLA_FN_COUNTER("xla::"); - XLATensor self_tensor = bridge::GetXlaTensor(self); - XLATensor::frac_(self_tensor); - return self; -} - at::Tensor gather(const at::Tensor& self, int64_t dim, const at::Tensor& index, bool /* sparse_grad */) { XLA_FN_COUNTER("xla::"); @@ -1535,20 +1296,6 @@ at::Tensor ge(const at::Tensor& self, const at::Tensor& other) { XLATensor::ge(bridge::GetXlaTensor(self), bridge::GetXlaTensor(other))); } -at::Tensor& ge_(at::Tensor& self, const at::Scalar& other) { - XLA_FN_COUNTER("xla::"); - XLATensor self_tensor = bridge::GetXlaTensor(self); - XLATensor::ge_(self_tensor, other); - return self; -} - -at::Tensor& ge_(at::Tensor& self, const at::Tensor& other) { - XLA_FN_COUNTER("xla::"); - XLATensor self_tensor = bridge::GetXlaTensor(self); - XLATensor::ge_(self_tensor, bridge::GetXlaTensor(other)); - return self; -} - at::Tensor gelu(const at::Tensor& self) { XLA_FN_COUNTER("xla::"); return bridge::AtenFromXlaTensor(XLATensor::gelu(bridge::GetXlaTensor(self))); @@ -1578,20 +1325,6 @@ at::Tensor gt(const at::Tensor& self, const at::Tensor& other) { XLATensor::gt(bridge::GetXlaTensor(self), bridge::GetXlaTensor(other))); } -at::Tensor& gt_(at::Tensor& self, const at::Scalar& other) { - XLA_FN_COUNTER("xla::"); - XLATensor self_tensor = bridge::GetXlaTensor(self); - XLATensor::gt_(self_tensor, other); - return self; -} - -at::Tensor& gt_(at::Tensor& self, const at::Tensor& other) { - XLA_FN_COUNTER("xla::"); - XLATensor self_tensor = bridge::GetXlaTensor(self); - XLATensor::gt_(self_tensor, bridge::GetXlaTensor(other)); - return self; -} - at::Tensor hardshrink(const at::Tensor& self, const at::Scalar& lambda) { XLA_FN_COUNTER("xla::"); return bridge::AtenFromXlaTensor( @@ -1604,13 +1337,6 @@ at::Tensor hardsigmoid(const at::Tensor& self) { XLATensor::hardsigmoid(bridge::GetXlaTensor(self))); } -at::Tensor& hardsigmoid_(at::Tensor& self) { - XLA_FN_COUNTER("xla::"); - XLATensor self_tensor = bridge::GetXlaTensor(self); - XLATensor::hardsigmoid_(self_tensor); - return self; -} - at::Tensor hardsigmoid_backward(const at::Tensor& grad_output, const at::Tensor& self) { XLA_FN_COUNTER("xla::"); @@ -1633,14 +1359,6 @@ at::Tensor hardtanh(const at::Tensor& self, const at::Scalar& min_val, XLATensor::clamp(bridge::GetXlaTensor(self), min_val, max_val)); } -at::Tensor& hardtanh_(at::Tensor& self, const at::Scalar& min_val, - const at::Scalar& max_val) { - XLA_FN_COUNTER("xla::"); - XLATensor self_tensor = bridge::GetXlaTensor(self); - XLATensor::clamp_(self_tensor, min_val, max_val); - return self; -} - at::Tensor hardtanh_backward(const at::Tensor& grad_output, const at::Tensor& self, const at::Scalar& min_val, const at::Scalar& max_val) { @@ -1777,20 +1495,6 @@ at::Tensor le(const at::Tensor& self, const at::Tensor& other) { XLATensor::le(bridge::GetXlaTensor(self), bridge::GetXlaTensor(other))); } -at::Tensor& le_(at::Tensor& self, const at::Scalar& other) { - XLA_FN_COUNTER("xla::"); - XLATensor self_tensor = bridge::GetXlaTensor(self); - XLATensor::le_(self_tensor, other); - return self; -} - -at::Tensor& le_(at::Tensor& self, const at::Tensor& other) { - XLA_FN_COUNTER("xla::"); - XLATensor self_tensor = bridge::GetXlaTensor(self); - XLATensor::le_(self_tensor, bridge::GetXlaTensor(other)); - return self; -} - at::Tensor leaky_relu(const at::Tensor& self, const at::Scalar& negative_slope) { XLA_FN_COUNTER("xla::"); @@ -1798,13 +1502,6 @@ at::Tensor leaky_relu(const at::Tensor& self, bridge::GetXlaTensor(self), negative_slope.to())); } -at::Tensor& leaky_relu_(at::Tensor& self, const at::Scalar& negative_slope) { - XLA_FN_COUNTER("xla::"); - XLATensor self_tensor = bridge::GetXlaTensor(self); - XLATensor::leaky_relu_(self_tensor, negative_slope.to()); - return self; -} - at::Tensor leaky_relu_backward(const at::Tensor& grad_output, const at::Tensor& self, const at::Scalar& negative_slope, @@ -1827,46 +1524,18 @@ at::Tensor log10(const at::Tensor& self) { bridge::GetXlaTensor(self), ir::OpKind(at::aten::log10), 10.0)); } -at::Tensor& log10_(at::Tensor& self) { - XLA_FN_COUNTER("xla::"); - XLATensor self_tensor = bridge::GetXlaTensor(self); - XLATensor::log_base_(self_tensor, ir::OpKind(at::aten::log10), 10.0); - return self; -} - at::Tensor log1p(const at::Tensor& self) { XLA_FN_COUNTER("xla::"); return bridge::AtenFromXlaTensor( XLATensor::log1p(bridge::GetXlaTensor(self))); } -at::Tensor& log1p_(at::Tensor& self) { - XLA_FN_COUNTER("xla::"); - XLATensor self_tensor = bridge::GetXlaTensor(self); - XLATensor::log1p_(self_tensor); - return self; -} - at::Tensor log2(const at::Tensor& self) { XLA_FN_COUNTER("xla::"); return bridge::AtenFromXlaTensor(XLATensor::log_base( bridge::GetXlaTensor(self), ir::OpKind(at::aten::log2), 2.0)); } -at::Tensor& log2_(at::Tensor& self) { - XLA_FN_COUNTER("xla::"); - XLATensor self_tensor = bridge::GetXlaTensor(self); - XLATensor::log_base_(self_tensor, ir::OpKind(at::aten::log2), 2.0); - return self; -} - -at::Tensor& log_(at::Tensor& self) { - XLA_FN_COUNTER("xla::"); - XLATensor self_tensor = bridge::GetXlaTensor(self); - XLATensor::log_(self_tensor); - return self; -} - at::Tensor log_sigmoid_backward(const at::Tensor& grad_output, const at::Tensor& self, const at::Tensor& buffer) { @@ -1910,20 +1579,6 @@ at::Tensor lt(const at::Tensor& self, const at::Tensor& other) { XLATensor::lt(bridge::GetXlaTensor(self), bridge::GetXlaTensor(other))); } -at::Tensor& lt_(at::Tensor& self, const at::Scalar& other) { - XLA_FN_COUNTER("xla::"); - XLATensor self_tensor = bridge::GetXlaTensor(self); - XLATensor::lt_(self_tensor, other); - return self; -} - -at::Tensor& lt_(at::Tensor& self, const at::Tensor& other) { - XLA_FN_COUNTER("xla::"); - XLATensor self_tensor = bridge::GetXlaTensor(self); - XLATensor::lt_(self_tensor, bridge::GetXlaTensor(other)); - return self; -} - at::Tensor& masked_fill_(at::Tensor& self, const at::Tensor& mask, const at::Scalar& value) { XLA_FN_COUNTER("xla::"); @@ -2220,23 +1875,6 @@ at::Tensor mul(const at::Tensor& self, const at::Scalar& other) { }); } -at::Tensor& mul_(at::Tensor& self, const at::Tensor& other) { - XLA_FN_COUNTER("xla::"); - CheckBinaryOpTypePromotion(self, self, other); - XLATensor self_tensor = bridge::GetXlaTensor(self); - XLATensor::mul_(self_tensor, - bridge::GetOrCreateXlaTensor(other, self_tensor.GetDevice())); - return self; -} - -at::Tensor& mul_(at::Tensor& self, const at::Scalar& other) { - XLA_FN_COUNTER("xla::"); - CheckBinaryOpTypePromotion(self, self, other); - XLATensor self_tensor = bridge::GetXlaTensor(self); - XLATensor::mul_(self_tensor, other); - return self; -} - at::Tensor mv(const at::Tensor& self, const at::Tensor& vec) { XLA_FN_COUNTER("xla::"); // xla::dot doesn't support integer types. @@ -2322,20 +1960,6 @@ at::Tensor ne(const at::Tensor& self, const at::Tensor& other) { XLATensor::ne(bridge::GetXlaTensor(self), bridge::GetXlaTensor(other))); } -at::Tensor& ne_(at::Tensor& self, const at::Scalar& other) { - XLA_FN_COUNTER("xla::"); - XLATensor self_tensor = bridge::GetXlaTensor(self); - XLATensor::ne_(self_tensor, other); - return self; -} - -at::Tensor& ne_(at::Tensor& self, const at::Tensor& other) { - XLA_FN_COUNTER("xla::"); - XLATensor self_tensor = bridge::GetXlaTensor(self); - XLATensor::ne_(self_tensor, bridge::GetXlaTensor(other)); - return self; -} - at::Tensor neg(const at::Tensor& self) { XLA_FN_COUNTER("xla::"); XLA_CHECK(self.scalar_type() != at::kBool) @@ -2345,13 +1969,6 @@ at::Tensor neg(const at::Tensor& self) { return bridge::AtenFromXlaTensor(XLATensor::neg(bridge::GetXlaTensor(self))); } -at::Tensor& neg_(at::Tensor& self) { - XLA_FN_COUNTER("xla::"); - XLATensor self_tensor = bridge::GetXlaTensor(self); - XLATensor::neg_(self_tensor); - return self; -} - at::Tensor nll_loss2d_backward(const at::Tensor& grad_output, const at::Tensor& self, const at::Tensor& target, const c10::optional& weight, @@ -2558,28 +2175,6 @@ at::Tensor pow(const at::Scalar& self, const at::Tensor& exponent) { XLATensor::pow(self, bridge::GetXlaTensor(exponent))); } -at::Tensor& pow_(at::Tensor& self, const at::Scalar& exponent) { - XLA_FN_COUNTER("xla::"); - // xla::Pow() doesn't support integer types. - if (!at::native::is_floating_point(self)) { - return AtenXlaTypeDefault::pow_(self, exponent); - } - XLATensor self_tensor = bridge::GetXlaTensor(self); - XLATensor::pow_(self_tensor, exponent); - return self; -} - -at::Tensor& pow_(at::Tensor& self, const at::Tensor& exponent) { - XLA_FN_COUNTER("xla::"); - // xla::Pow() doesn't support integer types. - if (!at::native::is_floating_point(self)) { - return AtenXlaTypeDefault::pow_(self, exponent); - } - XLATensor self_tensor = bridge::GetXlaTensor(self); - XLATensor::pow_(self_tensor, bridge::GetXlaTensor(exponent)); - return self; -} - at::Tensor prod(const at::Tensor& self, c10::optional dtype) { XLA_FN_COUNTER("xla::"); XLATensor self_tensor = bridge::GetXlaTensor(self); @@ -2665,13 +2260,6 @@ at::Tensor reciprocal(const at::Tensor& self) { XLATensor::reciprocal(bridge::GetXlaTensor(self))); } -at::Tensor& reciprocal_(at::Tensor& self) { - XLA_FN_COUNTER("xla::"); - XLATensor self_tensor = bridge::GetXlaTensor(self); - XLATensor::reciprocal_(self_tensor); - return self; -} - at::Tensor reflection_pad2d(const at::Tensor& self, at::IntArrayRef padding) { XLA_FN_COUNTER("xla::"); return bridge::AtenFromXlaTensor(XLATensor::reflection_pad2d( @@ -2711,20 +2299,6 @@ at::Tensor remainder(const at::Tensor& self, const at::Scalar& other) { XLATensor::remainder(bridge::GetXlaTensor(self), other)); } -at::Tensor& remainder_(at::Tensor& self, const at::Tensor& other) { - XLA_FN_COUNTER("xla::"); - XLATensor self_tensor = bridge::GetXlaTensor(self); - XLATensor::remainder_(self_tensor, bridge::GetXlaTensor(other)); - return self; -} - -at::Tensor& remainder_(at::Tensor& self, const at::Scalar& other) { - XLA_FN_COUNTER("xla::"); - XLATensor self_tensor = bridge::GetXlaTensor(self); - XLATensor::remainder_(self_tensor, other); - return self; -} - at::Tensor repeat(const at::Tensor& self, at::IntArrayRef repeats) { XLA_FN_COUNTER("xla::"); return bridge::AtenFromXlaTensor(XLATensor::repeat( @@ -2775,13 +2349,6 @@ at::Tensor round(const at::Tensor& self) { XLATensor::round(bridge::GetXlaTensor(self))); } -at::Tensor& round_(at::Tensor& self) { - XLA_FN_COUNTER("xla::"); - XLATensor self_tensor = bridge::GetXlaTensor(self); - XLATensor::round_(self_tensor); - return self; -} - at::Tensor rrelu_with_noise(const at::Tensor& self, const at::Tensor& noise, const at::Scalar& lower, const at::Scalar& upper, bool training, @@ -2819,13 +2386,6 @@ at::Tensor rsqrt(const at::Tensor& self) { XLATensor::rsqrt(bridge::GetXlaTensor(self))); } -at::Tensor& rsqrt_(at::Tensor& self) { - XLA_FN_COUNTER("xla::"); - XLATensor self_tensor = bridge::GetXlaTensor(self); - XLATensor::rsqrt_(self_tensor); - return self; -} - at::Tensor rsub(const at::Tensor& self, const at::Tensor& other, const at::Scalar& alpha) { XLA_FN_COUNTER("xla::"); @@ -2891,13 +2451,6 @@ at::Tensor sigmoid(const at::Tensor& self) { XLATensor::sigmoid(bridge::GetXlaTensor(self))); } -at::Tensor& sigmoid_(at::Tensor& self) { - XLA_FN_COUNTER("xla::"); - XLATensor self_tensor = bridge::GetXlaTensor(self); - XLATensor::sigmoid_(self_tensor); - return self; -} - at::Tensor sigmoid_backward(const at::Tensor& grad_output, const at::Tensor& output) { XLA_FN_COUNTER("xla::"); @@ -2910,37 +2463,16 @@ at::Tensor sign(const at::Tensor& self) { return bridge::AtenFromXlaTensor(XLATensor::sign(bridge::GetXlaTensor(self))); } -at::Tensor& sign_(at::Tensor& self) { - XLA_FN_COUNTER("xla::"); - XLATensor self_tensor = bridge::GetXlaTensor(self); - XLATensor::sign_(self_tensor); - return self; -} - at::Tensor sin(const at::Tensor& self) { XLA_FN_COUNTER("xla::"); return bridge::AtenFromXlaTensor(XLATensor::sin(bridge::GetXlaTensor(self))); } -at::Tensor& sin_(at::Tensor& self) { - XLA_FN_COUNTER("xla::"); - XLATensor self_tensor = bridge::GetXlaTensor(self); - XLATensor::sin_(self_tensor); - return self; -} - at::Tensor sinh(const at::Tensor& self) { XLA_FN_COUNTER("xla::"); return bridge::AtenFromXlaTensor(XLATensor::sinh(bridge::GetXlaTensor(self))); } -at::Tensor& sinh_(at::Tensor& self) { - XLA_FN_COUNTER("xla::"); - XLATensor self_tensor = bridge::GetXlaTensor(self); - XLATensor::sinh_(self_tensor); - return self; -} - at::Tensor slice(const at::Tensor& self, int64_t dim, c10::optional start, c10::optional end, int64_t step) { @@ -3031,13 +2563,6 @@ at::Tensor sqrt(const at::Tensor& self) { return bridge::AtenFromXlaTensor(XLATensor::sqrt(bridge::GetXlaTensor(self))); } -at::Tensor& sqrt_(at::Tensor& self) { - XLA_FN_COUNTER("xla::"); - XLATensor self_tensor = bridge::GetXlaTensor(self); - XLATensor::sqrt_(self_tensor); - return self; -} - at::Tensor squeeze(const at::Tensor& self) { XLA_FN_COUNTER("xla::"); return bridge::AtenFromXlaTensor( @@ -3121,29 +2646,6 @@ at::Tensor sub(const at::Tensor& self, const at::Scalar& other, }); } -at::Tensor& sub_(at::Tensor& self, const at::Tensor& other, - const at::Scalar& alpha) { - XLA_FN_COUNTER("xla::"); - CheckBinaryOpTypePromotion(self, self, other); - at::native::alpha_check(at::result_type(self, other), alpha); - CheckSubOperandTypes(self.scalar_type(), other.scalar_type()); - XLATensor self_tensor = bridge::GetXlaTensor(self); - XLATensor::sub_(self_tensor, - bridge::GetOrCreateXlaTensor(other, self_tensor.GetDevice()), - alpha); - return self; -} - -at::Tensor& sub_(at::Tensor& self, const at::Scalar& other, - const at::Scalar& alpha) { - XLA_FN_COUNTER("xla::"); - CheckBinaryOpTypePromotion(self, self, other); - CheckSubOperandTypes(self.scalar_type(), GetScalarType(other)); - XLATensor self_tensor = bridge::GetXlaTensor(self); - XLATensor::sub_(self_tensor, other, alpha); - return self; -} - at::Tensor sum(const at::Tensor& self, c10::optional dtype) { XLA_FN_COUNTER("xla::"); XLATensor self_tensor = bridge::GetXlaTensor(self); @@ -3203,25 +2705,11 @@ at::Tensor tan(const at::Tensor& self) { return bridge::AtenFromXlaTensor(XLATensor::tan(bridge::GetXlaTensor(self))); } -at::Tensor& tan_(at::Tensor& self) { - XLA_FN_COUNTER("xla::"); - XLATensor self_tensor = bridge::GetXlaTensor(self); - XLATensor::tan_(self_tensor); - return self; -} - at::Tensor tanh(const at::Tensor& self) { XLA_FN_COUNTER("xla::"); return bridge::AtenFromXlaTensor(XLATensor::tanh(bridge::GetXlaTensor(self))); } -at::Tensor& tanh_(at::Tensor& self) { - XLA_FN_COUNTER("xla::"); - XLATensor self_tensor = bridge::GetXlaTensor(self); - XLATensor::tanh_(self_tensor); - return self; -} - at::Tensor tanh_backward(const at::Tensor& grad_output, const at::Tensor& output) { XLA_FN_COUNTER("xla::"); @@ -3236,15 +2724,6 @@ at::Tensor threshold(const at::Tensor& self, const at::Scalar& threshold, bridge::GetXlaTensor(self), threshold.to(), value.to())); } -at::Tensor& threshold_(at::Tensor& self, const at::Scalar& threshold, - const at::Scalar& value) { - XLA_FN_COUNTER("xla::"); - XLATensor self_tensor = bridge::GetXlaTensor(self); - XLATensor::threshold_(self_tensor, threshold.to(), - value.to()); - return self; -} - at::Tensor threshold_backward(const at::Tensor& grad_output, const at::Tensor& self, const at::Scalar& threshold) { @@ -3329,13 +2808,6 @@ at::Tensor trunc(const at::Tensor& self) { XLATensor::trunc(bridge::GetXlaTensor(self))); } -at::Tensor& trunc_(at::Tensor& self) { - XLA_FN_COUNTER("xla::"); - XLATensor self_tensor = bridge::GetXlaTensor(self); - XLATensor::trunc_(self_tensor); - return self; -} - std::vector unbind(const at::Tensor& self, int64_t dim) { XLA_FN_COUNTER("xla::"); return bridge::AtenFromXlaTensors( diff --git a/torch_xla/csrc/tensor.h b/torch_xla/csrc/tensor.h index 5cffbb1784e..ed15b502c23 100644 --- a/torch_xla/csrc/tensor.h +++ b/torch_xla/csrc/tensor.h @@ -259,24 +259,17 @@ class XLATensor { int growth_interval); static XLATensor abs(const XLATensor& input); - static void abs_(XLATensor& input); static XLATensor acos(const XLATensor& input); - static void acos_(XLATensor& input); static XLATensor acosh(const XLATensor& input); - static void acosh_(XLATensor& input); static XLATensor add( const XLATensor& input, const XLATensor& other, const at::Scalar& alpha, c10::optional logical_element_type = c10::nullopt); - static void add_(XLATensor& input, const XLATensor& other, - const at::Scalar& alpha); static XLATensor add( const XLATensor& input, const at::Scalar& other, const at::Scalar& alpha, c10::optional logical_element_type = c10::nullopt); - static void add_(XLATensor& input, const at::Scalar& other, - const at::Scalar& alpha); static XLATensor addcdiv(const XLATensor& input, const at::Scalar& value, const XLATensor& tensor1, const XLATensor& tensor2); @@ -285,8 +278,6 @@ class XLATensor { static XLATensor addcmul(const XLATensor& input, const at::Scalar& value, const XLATensor& tensor1, const XLATensor& tensor2); - static void addcmul_(XLATensor& input, const at::Scalar& value, - const XLATensor& tensor1, const XLATensor& tensor2); static XLATensor addmm(const XLATensor& input, const XLATensor& weight, const XLATensor& bias); @@ -322,21 +313,16 @@ class XLATensor { c10::optional storage_offset); static XLATensor asin(const XLATensor& input); - static void asin_(XLATensor& input); static XLATensor asinh(const XLATensor& input); - static void asinh_(XLATensor& input); static XLATensor atan(const XLATensor& input); - static void atan_(XLATensor& input); static XLATensor atanh(const XLATensor& input); - static void atanh_(XLATensor& input); static XLATensor atan2( const XLATensor& input, const XLATensor& other, c10::optional logical_element_type = c10::nullopt); - static void atan2_(XLATensor& input, const XLATensor& other); static XLATensor avg_pool_nd(const XLATensor& input, xla::int64 spatial_dim_count, @@ -356,9 +342,6 @@ class XLATensor { static XLATensor baddbmm(const XLATensor& input, const XLATensor& batch1, const XLATensor& batch2, const at::Scalar& beta, const at::Scalar& alpha); - static void baddbmm_(XLATensor& input, const XLATensor& batch1, - const XLATensor& batch2, const at::Scalar& beta, - const at::Scalar& alpha); static XLATensor bernoulli(const XLATensor& input, double probability); static XLATensor bernoulli(const XLATensor& input); @@ -408,7 +391,6 @@ class XLATensor { static XLATensor cat(absl::Span tensors, xla::int64 dim); static XLATensor ceil(const XLATensor& input); - static void ceil_(XLATensor& input); static XLATensor cholesky(const XLATensor& input, bool upper); @@ -418,8 +400,6 @@ class XLATensor { static XLATensor clamp(const XLATensor& input, const c10::optional& min, const c10::optional& max); - static void clamp_(XLATensor& input, const c10::optional& min, - const c10::optional& max); static void clamp_out(XLATensor& out, const XLATensor& input, const c10::optional& min, const c10::optional& max); @@ -453,10 +433,8 @@ class XLATensor { std::vector output_padding, xla::int64 groups); static XLATensor cos(const XLATensor& input); - static void cos_(XLATensor& input); static XLATensor cosh(const XLATensor& input); - static void cosh_(XLATensor& input); // Returns the cross product of the two input tensors in the given dimension. // If the dimension is not given, it defaults to the first dimension found @@ -487,10 +465,6 @@ class XLATensor { const c10::optional& rounding_mode = c10::nullopt, c10::optional logical_element_type = c10::nullopt); static XLATensor div(const XLATensor& input, const at::Scalar& other); - static void div_( - XLATensor& input, const XLATensor& other, - const c10::optional& rounding_mode = c10::nullopt); - static void div_(XLATensor& input, const at::Scalar& other); // A generalized contraction between tensors of arbitrary dimension defined by // the given equation and applied to the input tensors. @@ -514,27 +488,20 @@ class XLATensor { bool scale_grad_by_freq); static XLATensor eq(const XLATensor& input, const at::Scalar& other); - static void eq_(XLATensor& input, const at::Scalar& other); static XLATensor eq(const XLATensor& input, const XLATensor& other); - static void eq_(XLATensor& input, const XLATensor& other); static XLATensor erf(const XLATensor& input); - static void erf_(XLATensor& input); static XLATensor erfc(const XLATensor& input); - static void erfc_(XLATensor& input); static XLATensor erfinv(const XLATensor& input); - static void erfinv_(XLATensor& input); static XLATensor exp(const XLATensor& input); - static void exp_(XLATensor& input); static XLATensor expand(const XLATensor& input, std::vector size); static XLATensor expm1(const XLATensor& input); - static void expm1_(XLATensor& input); static void exponential_(XLATensor& input, double lambd); @@ -552,7 +519,6 @@ class XLATensor { absl::Span dims); static XLATensor floor(const XLATensor& input); - static void floor_(XLATensor& input); static XLATensor fmod( const XLATensor& input, const XLATensor& other, @@ -560,11 +526,8 @@ class XLATensor { static XLATensor fmod( const XLATensor& input, const at::Scalar& other, c10::optional logical_element_type = c10::nullopt); - static void fmod_(XLATensor& input, const XLATensor& other); - static void fmod_(XLATensor& input, const at::Scalar& other); static XLATensor frac(const XLATensor& input); - static void frac_(XLATensor& input); static XLATensor full(absl::Span size, const at::Scalar& fill_value, const Device& device, @@ -577,10 +540,8 @@ class XLATensor { const XLATensor& index); static XLATensor ge(const XLATensor& input, const at::Scalar& other); - static void ge_(XLATensor& input, const at::Scalar& other); static XLATensor ge(const XLATensor& input, const XLATensor& other); - static void ge_(XLATensor& input, const XLATensor& other); static XLATensor gelu(const XLATensor& input); static XLATensor gelu_backward(const XLATensor& grad, const XLATensor& input); @@ -588,10 +549,8 @@ class XLATensor { static XLATensor ger(const XLATensor& input, const XLATensor& vec2); static XLATensor gt(const XLATensor& input, const at::Scalar& other); - static void gt_(XLATensor& input, const at::Scalar& other); static XLATensor gt(const XLATensor& input, const XLATensor& other); - static void gt_(XLATensor& input, const XLATensor& other); // Gather slices from input into a result with shape specified by indices. The // shape of the indices are first made consistent using broadcast semantics. @@ -667,10 +626,8 @@ class XLATensor { xla::int64 reduction); static XLATensor le(const XLATensor& input, const at::Scalar& other); - static void le_(XLATensor& input, const at::Scalar& other); static XLATensor le(const XLATensor& input, const XLATensor& other); - static void le_(XLATensor& input, const XLATensor& other); static XLATensor hardshrink(const XLATensor& input, const at::Scalar& lambda); static XLATensor hardshrink_backward(const XLATensor& grad_out, @@ -679,8 +636,6 @@ class XLATensor { static XLATensor hardsigmoid(const XLATensor& input); - static void hardsigmoid_(XLATensor& input); - static XLATensor hardsigmoid_backward(const XLATensor& grad_output, const XLATensor& input); @@ -693,13 +648,10 @@ class XLATensor { static XLATensor leaky_relu_backward(const XLATensor& grad_output, const XLATensor& input, double negative_slope); - static void leaky_relu_(XLATensor& input, double negative_slope); static XLATensor log(const XLATensor& input); - static void log_(XLATensor& input); static XLATensor log_base(const XLATensor& input, ir::OpKind op, double base); - static void log_base_(XLATensor& input, ir::OpKind op, double base); static XLATensor log_sigmoid(const XLATensor& input); static std::tuple log_sigmoid_forward( @@ -725,10 +677,8 @@ class XLATensor { bool keep_reduced_dimensions); static XLATensor lt(const XLATensor& input, const at::Scalar& other); - static void lt_(XLATensor& input, const at::Scalar& other); static XLATensor lt(const XLATensor& input, const XLATensor& other); - static void lt_(XLATensor& input, const XLATensor& other); // In-place version of the method above. static void masked_fill_(XLATensor& input, const XLATensor& mask, @@ -807,8 +757,6 @@ class XLATensor { static XLATensor mul( const XLATensor& input, const at::Scalar& other, c10::optional logical_element_type = c10::nullopt); - static void mul_(XLATensor& input, const XLATensor& other); - static void mul_(XLATensor& input, const at::Scalar& other); static XLATensor mv(const XLATensor& input, const XLATensor& vec); static void mv_out(XLATensor& out, const XLATensor& input, @@ -833,13 +781,10 @@ class XLATensor { const XLATensor& save_invstd, bool training, double eps); static XLATensor ne(const XLATensor& input, const at::Scalar& other); - static void ne_(XLATensor& input, const at::Scalar& other); static XLATensor ne(const XLATensor& input, const XLATensor& other); - static void ne_(XLATensor& input, const XLATensor& other); static XLATensor neg(const XLATensor& input); - static void neg_(XLATensor& input); static XLATensor nll_loss(const XLATensor& input, const XLATensor& target, const XLATensor& weight, xla::int64 reduction, @@ -894,8 +839,6 @@ class XLATensor { static XLATensor pow(const XLATensor& input, const at::Scalar& exponent); static XLATensor pow(const XLATensor& input, const XLATensor& exponent); static XLATensor pow(const at::Scalar& input, const XLATensor& exponent); - static void pow_(XLATensor& input, const at::Scalar& exponent); - static void pow_(XLATensor& input, const XLATensor& exponent); static XLATensor prod(const XLATensor& input, std::vector dimensions, @@ -913,7 +856,6 @@ class XLATensor { at::ScalarType scalar_type); static XLATensor reciprocal(const XLATensor& input); - static void reciprocal_(XLATensor& input); static XLATensor reflection_pad2d(const XLATensor& input, std::vector padding); @@ -927,8 +869,6 @@ class XLATensor { static XLATensor remainder(const XLATensor& input, const XLATensor& other); static XLATensor remainder(const XLATensor& input, const at::Scalar& other); - static void remainder_(XLATensor& input, const XLATensor& other); - static void remainder_(XLATensor& input, const at::Scalar& other); // Repeats the input tensor along each dimension by the given number of // repeats. @@ -950,7 +890,6 @@ class XLATensor { static void resize_(XLATensor& input, std::vector size); static XLATensor round(const XLATensor& input); - static void round_(XLATensor& input); static XLATensor rrelu_with_noise(const XLATensor& input, XLATensor& noise, const at::Scalar& lower, @@ -964,7 +903,6 @@ class XLATensor { bool training); static XLATensor rsqrt(const XLATensor& input); - static void rsqrt_(XLATensor& input); static XLATensor rsub( const XLATensor& input, const XLATensor& other, const at::Scalar& alpha, @@ -988,18 +926,14 @@ class XLATensor { static void silu_out(XLATensor& input, XLATensor& out); static XLATensor sigmoid(const XLATensor& input); - static void sigmoid_(XLATensor& input); static XLATensor sigmoid_backward(const XLATensor& grad_output, const XLATensor& output); static XLATensor sign(const XLATensor& input); - static void sign_(XLATensor& input); static XLATensor sin(const XLATensor& input); - static void sin_(XLATensor& input); static XLATensor sinh(const XLATensor& input); - static void sinh_(XLATensor& input); static XLATensor slice(const XLATensor& input, xla::int64 dim, xla::int64 start, xla::int64 end, xla::int64 step); @@ -1042,7 +976,6 @@ class XLATensor { xla::int64 dim); static XLATensor sqrt(const XLATensor& input); - static void sqrt_(XLATensor& input); // Squeeze out all trivial (size 1) dimensions. static XLATensor squeeze(const XLATensor& input); @@ -1064,13 +997,9 @@ class XLATensor { static XLATensor sub( const XLATensor& input, const XLATensor& other, const at::Scalar& alpha, c10::optional logical_element_type = c10::nullopt); - static void sub_(XLATensor& input, const XLATensor& other, - const at::Scalar& alpha); static XLATensor sub( const XLATensor& input, const at::Scalar& other, const at::Scalar& alpha, c10::optional logical_element_type = c10::nullopt); - static void sub_(XLATensor& input, const at::Scalar& other, - const at::Scalar& alpha); static XLATensor sum(const XLATensor& input, std::vector dimensions, @@ -1087,16 +1016,13 @@ class XLATensor { static XLATensor take(const XLATensor& input, const XLATensor& index); static XLATensor tan(const XLATensor& input); - static void tan_(XLATensor& input); static XLATensor tanh(const XLATensor& input); - static void tanh_(XLATensor& input); static XLATensor tanh_backward(const XLATensor& grad_output, const XLATensor& output); static XLATensor threshold(const XLATensor& input, float threshold, float value); - static void threshold_(XLATensor& input, float threshold, float value); static XLATensor threshold_backward(const XLATensor& grad_output, const XLATensor& input, float threshold); @@ -1137,7 +1063,6 @@ class XLATensor { static void triu_(XLATensor& input, xla::int64 diagonal); static XLATensor trunc(const XLATensor& input); - static void trunc_(XLATensor& input); // Returns a tuple of all slices along a given dimension with that dimension // removed. diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index 2e3acd261d1..83694adc906 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -502,26 +502,14 @@ XLATensor XLATensor::abs(const XLATensor& input) { return input.CreateFrom(ir::ops::Abs(input.GetIrValue())); } -void XLATensor::abs_(XLATensor& input) { - input.SetInPlaceIrValue(ir::ops::Abs(input.GetIrValue())); -} - XLATensor XLATensor::acos(const XLATensor& input) { return input.CreateFrom(ir::ops::Acos(input.GetIrValue())); } -void XLATensor::acos_(XLATensor& input) { - input.SetInPlaceIrValue(ir::ops::Acos(input.GetIrValue())); -} - XLATensor XLATensor::acosh(const XLATensor& input) { return input.CreateFrom(ir::ops::Acosh(input.GetIrValue())); } -void XLATensor::acosh_(XLATensor& input) { - input.SetInPlaceIrValue(ir::ops::Acosh(input.GetIrValue())); -} - XLATensor XLATensor::add(const XLATensor& input, const XLATensor& other, const at::Scalar& alpha, c10::optional logical_element_type) { @@ -531,13 +519,6 @@ XLATensor XLATensor::add(const XLATensor& input, const XLATensor& other, logical_element_type); } -void XLATensor::add_(XLATensor& input, const XLATensor& other, - const at::Scalar& alpha) { - ir::Value constant = - GetIrValueForScalar(alpha, other.shape(), input.GetDevice()); - input.SetInPlaceIrValue(input.GetIrValue() + other.GetIrValue() * constant); -} - XLATensor XLATensor::add(const XLATensor& input, const at::Scalar& other, const at::Scalar& alpha, c10::optional logical_element_type) { @@ -549,23 +530,6 @@ XLATensor XLATensor::add(const XLATensor& input, const at::Scalar& other, logical_element_type); } -void XLATensor::add_(XLATensor& input, const at::Scalar& other, - const at::Scalar& alpha) { - ir::Value other_constant = - GetIrValueForScalar(other, input.shape(), input.GetDevice()); - ir::Value alpha_constant = - GetIrValueForScalar(alpha, input.shape(), input.GetDevice()); - input.SetInPlaceIrValue(input.GetIrValue() + other_constant * alpha_constant); -} - -void XLATensor::addcmul_(XLATensor& input, const at::Scalar& value, - const XLATensor& tensor1, const XLATensor& tensor2) { - ir::Value constant = GetIrValueForScalar( - value, tensor1.shape().get().element_type(), input.GetDevice()); - ir::Value mul = tensor1.GetIrValue() * tensor2.GetIrValue(); - input.SetInPlaceIrValue(input.GetIrValue() + mul * constant); -} - XLATensor XLATensor::addcdiv(const XLATensor& input, const at::Scalar& value, const XLATensor& tensor1, const XLATensor& tensor2) { @@ -690,34 +654,18 @@ XLATensor XLATensor::asin(const XLATensor& input) { return input.CreateFrom(ir::ops::Asin(input.GetIrValue())); } -void XLATensor::asin_(XLATensor& input) { - input.SetInPlaceIrValue(ir::ops::Asin(input.GetIrValue())); -} - XLATensor XLATensor::asinh(const XLATensor& input) { return input.CreateFrom(ir::ops::Asinh(input.GetIrValue())); } -void XLATensor::asinh_(XLATensor& input) { - input.SetInPlaceIrValue(ir::ops::Asinh(input.GetIrValue())); -} - XLATensor XLATensor::atan(const XLATensor& input) { return input.CreateFrom(ir::ops::Atan(input.GetIrValue())); } -void XLATensor::atan_(XLATensor& input) { - input.SetInPlaceIrValue(ir::ops::Atan(input.GetIrValue())); -} - XLATensor XLATensor::atanh(const XLATensor& input) { return input.CreateFrom(ir::ops::Atanh(input.GetIrValue())); } -void XLATensor::atanh_(XLATensor& input) { - input.SetInPlaceIrValue(ir::ops::Atanh(input.GetIrValue())); -} - XLATensor XLATensor::atan2(const XLATensor& input, const XLATensor& other, c10::optional logical_element_type) { return input.CreateFrom( @@ -725,11 +673,6 @@ XLATensor XLATensor::atan2(const XLATensor& input, const XLATensor& other, logical_element_type); } -void XLATensor::atan2_(XLATensor& input, const XLATensor& other) { - input.SetInPlaceIrValue( - ir::ops::Atan2(input.GetIrValue(), other.GetIrValue())); -} - XLATensor XLATensor::avg_pool_nd(const XLATensor& input, xla::int64 spatial_dim_count, std::vector kernel_size, @@ -771,19 +714,6 @@ XLATensor XLATensor::baddbmm(const XLATensor& input, const XLATensor& batch1, product_multiplier, bias_multiplier)); } -void XLATensor::baddbmm_(XLATensor& input, const XLATensor& batch1, - const XLATensor& batch2, const at::Scalar& beta, - const at::Scalar& alpha) { - CheckBmmDimension(/*tag=*/"baddbmm_", batch1, batch2); - ir::Value product_multiplier = XLATensor::GetIrValueForScalar( - alpha, batch1.shape().get().element_type(), batch1.GetDevice()); - ir::Value bias_multiplier = XLATensor::GetIrValueForScalar( - beta, input.shape().get().element_type(), input.GetDevice()); - input.SetInPlaceIrValue(ir::ops::BaddBmm( - batch1.GetIrValue(), batch2.GetIrValue(), input.GetIrValue(), - product_multiplier, bias_multiplier)); -} - XLATensor XLATensor::bernoulli(const XLATensor& input, double probability) { auto input_shape = input.shape(); return input.CreateFrom(ir::MakeNode( @@ -925,10 +855,6 @@ XLATensor XLATensor::ceil(const XLATensor& input) { return input.CreateFrom(ir::ops::Ceil(input.GetIrValue())); } -void XLATensor::ceil_(XLATensor& input) { - input.SetInPlaceIrValue(ir::ops::Ceil(input.GetIrValue())); -} - XLATensor XLATensor::cholesky(const XLATensor& input, bool upper) { // Cholesky takes lower instead of upper, hence the negation. return input.CreateFrom( @@ -958,13 +884,6 @@ XLATensor XLATensor::clamp(const XLATensor& input, return input.CreateFrom(res); } -void XLATensor::clamp_(XLATensor& input, const c10::optional& min, - const c10::optional& max) { - MinMaxValues min_max = GetMinMaxValues(input, min, max); - input.SetInPlaceIrValue( - ir::ops::Clamp(input.GetIrValue(), min_max.min, min_max.max)); -} - void XLATensor::clamp_out(XLATensor& out, const XLATensor& input, const c10::optional& min, const c10::optional& max) { @@ -1039,18 +958,10 @@ XLATensor XLATensor::cos(const XLATensor& input) { return input.CreateFrom(ir::ops::Cos(input.GetIrValue())); } -void XLATensor::cos_(XLATensor& input) { - input.SetInPlaceIrValue(ir::ops::Cos(input.GetIrValue())); -} - XLATensor XLATensor::cosh(const XLATensor& input) { return input.CreateFrom(ir::ops::Cosh(input.GetIrValue())); } -void XLATensor::cosh_(XLATensor& input) { - input.SetInPlaceIrValue(ir::ops::Cosh(input.GetIrValue())); -} - XLATensor XLATensor::cross(const XLATensor& input, const XLATensor& other, c10::optional dim) { return tensor_ops::Cross(input, other, dim); @@ -1166,56 +1077,14 @@ XLATensor XLATensor::div(const XLATensor& input, const at::Scalar& other) { return input.CreateFrom(input_value / other_value, scalar_type); } -void XLATensor::div_(XLATensor& input, const XLATensor& other, - const c10::optional& rounding_mode) { - at::ScalarType scalar_type = - at::typeMetaToScalarType(c10::get_default_dtype()); - ir::Value input_value = GetFloatingIrValue(input, scalar_type); - ir::Value other_value = GetFloatingIrValue(other, scalar_type); - ir::Value res = input_value / other_value; - if (rounding_mode.has_value()) { - if (*rounding_mode == "trunc") { - res = ir::ops::Trunc(res); - } else if (*rounding_mode == "floor") { - res = ir::ops::Floor(res); - } else { - XLA_CHECK(false) - << "rounding_mode must be one of None, 'trunc', or 'floor'"; - } - } - input.SetInPlaceIrValue(res); -} - -void XLATensor::div_(XLATensor& input, const at::Scalar& other) { - at::ScalarType scalar_type = - at::typeMetaToScalarType(c10::get_default_dtype()); - ir::Value input_value = GetFloatingIrValue(input, scalar_type); - ir::Value other_value = GetIrValueForScalar( - other, input_value.shape().element_type(), input.GetDevice()); - input.SetInPlaceIrValue(input_value / other_value); -} - XLATensor XLATensor::eq(const XLATensor& input, const at::Scalar& other) { return DispatchComparisonOp(at::aten::eq, input, other); } -void XLATensor::eq_(XLATensor& input, const at::Scalar& other) { - ir::NodePtr cmp_result = - ir::ops::ComparisonOp(at::aten::eq, input.GetIrValue(), - GetIrValueForScalar(other, input.GetDevice())); - input.SetIrValue(ir::MakeNode(cmp_result, input.dtype())); -} - XLATensor XLATensor::eq(const XLATensor& input, const XLATensor& other) { return DispatchComparisonOp(at::aten::eq, input, other); } -void XLATensor::eq_(XLATensor& input, const XLATensor& other) { - ir::NodePtr cmp_result = ir::ops::ComparisonOp( - at::aten::eq, input.GetIrValue(), other.GetIrValue()); - input.SetIrValue(ir::MakeNode(cmp_result, input.dtype())); -} - XLATensor XLATensor::elu(const XLATensor& input, const at::Scalar& alpha, const at::Scalar& scale, const at::Scalar& input_scale) { @@ -1252,34 +1121,18 @@ XLATensor XLATensor::erf(const XLATensor& input) { return input.CreateFrom(ir::ops::Erf(input.GetIrValue())); } -void XLATensor::erf_(XLATensor& input) { - input.SetInPlaceIrValue(ir::ops::Erf(input.GetIrValue())); -} - XLATensor XLATensor::erfc(const XLATensor& input) { return input.CreateFrom(ir::ops::Erfc(input.GetIrValue())); } -void XLATensor::erfc_(XLATensor& input) { - input.SetInPlaceIrValue(ir::ops::Erfc(input.GetIrValue())); -} - XLATensor XLATensor::erfinv(const XLATensor& input) { return input.CreateFrom(ir::ops::Erfinv(input.GetIrValue())); } -void XLATensor::erfinv_(XLATensor& input) { - input.SetInPlaceIrValue(ir::ops::Erfinv(input.GetIrValue())); -} - XLATensor XLATensor::exp(const XLATensor& input) { return input.CreateFrom(ir::ops::Exp(input.GetIrValue())); } -void XLATensor::exp_(XLATensor& input) { - input.SetInPlaceIrValue(ir::ops::Exp(input.GetIrValue())); -} - XLATensor XLATensor::expand(const XLATensor& input, std::vector size) { auto input_shape = input.shape(); @@ -1292,10 +1145,6 @@ XLATensor XLATensor::expm1(const XLATensor& input) { return input.CreateFrom(ir::ops::Expm1(input.GetIrValue())); } -void XLATensor::expm1_(XLATensor& input) { - input.SetInPlaceIrValue(ir::ops::Expm1(input.GetIrValue())); -} - void XLATensor::exponential_(XLATensor& input, double lambd) { auto input_shape = input.shape(); input.SetInPlaceIrValue(ir::MakeNode( @@ -1339,10 +1188,6 @@ XLATensor XLATensor::floor(const XLATensor& input) { return input.CreateFrom(ir::ops::Floor(input.GetIrValue())); } -void XLATensor::floor_(XLATensor& input) { - input.SetInPlaceIrValue(ir::ops::Floor(input.GetIrValue())); -} - XLATensor XLATensor::fmod(const XLATensor& input, const XLATensor& other, c10::optional logical_element_type) { return input.CreateFrom(ir::ops::Fmod(input.GetIrValue(), other.GetIrValue()), @@ -1357,25 +1202,10 @@ XLATensor XLATensor::fmod(const XLATensor& input, const at::Scalar& other, logical_element_type); } -void XLATensor::fmod_(XLATensor& input, const XLATensor& other) { - input.SetInPlaceIrValue( - ir::ops::Fmod(input.GetIrValue(), other.GetIrValue())); -} - -void XLATensor::fmod_(XLATensor& input, const at::Scalar& other) { - ir::Value constant = - GetIrValueForScalar(other, input.shape(), input.GetDevice()); - input.SetInPlaceIrValue(ir::ops::Fmod(input.GetIrValue(), constant)); -} - XLATensor XLATensor::frac(const XLATensor& input) { return input.CreateFrom(ir::ops::FracOp(input.GetIrValue())); } -void XLATensor::frac_(XLATensor& input) { - input.SetInPlaceIrValue(ir::ops::FracOp(input.GetIrValue())); -} - XLATensor XLATensor::full(absl::Span size, const at::Scalar& fill_value, const Device& device, at::ScalarType scalar_type) { @@ -1413,23 +1243,10 @@ XLATensor XLATensor::ge(const XLATensor& input, const at::Scalar& other) { return DispatchComparisonOp(at::aten::ge, input, other); } -void XLATensor::ge_(XLATensor& input, const at::Scalar& other) { - ir::NodePtr cmp_result = - ir::ops::ComparisonOp(at::aten::ge, input.GetIrValue(), - GetIrValueForScalar(other, input.GetDevice())); - input.SetIrValue(ir::MakeNode(cmp_result, input.dtype())); -} - XLATensor XLATensor::ge(const XLATensor& input, const XLATensor& other) { return DispatchComparisonOp(at::aten::ge, input, other); } -void XLATensor::ge_(XLATensor& input, const XLATensor& other) { - ir::NodePtr cmp_result = ir::ops::ComparisonOp( - at::aten::ge, input.GetIrValue(), other.GetIrValue()); - input.SetIrValue(ir::MakeNode(cmp_result, input.dtype())); -} - XLATensor XLATensor::gelu(const XLATensor& input) { return input.CreateFrom(ir::ops::Gelu(input.GetIrValue())); } @@ -1448,23 +1265,10 @@ XLATensor XLATensor::gt(const XLATensor& input, const at::Scalar& other) { return DispatchComparisonOp(at::aten::gt, input, other); } -void XLATensor::gt_(XLATensor& input, const at::Scalar& other) { - ir::NodePtr cmp_result = - ir::ops::ComparisonOp(at::aten::gt, input.GetIrValue(), - GetIrValueForScalar(other, input.GetDevice())); - input.SetIrValue(ir::MakeNode(cmp_result, input.dtype())); -} - XLATensor XLATensor::gt(const XLATensor& input, const XLATensor& other) { return DispatchComparisonOp(at::aten::gt, input, other); } -void XLATensor::gt_(XLATensor& input, const XLATensor& other) { - ir::NodePtr cmp_result = ir::ops::ComparisonOp( - at::aten::gt, input.GetIrValue(), other.GetIrValue()); - input.SetIrValue(ir::MakeNode(cmp_result, input.dtype())); -} - XLATensor XLATensor::index(const XLATensor& input, absl::Span indices, xla::int64 start_dim) { @@ -1601,23 +1405,10 @@ XLATensor XLATensor::le(const XLATensor& input, const at::Scalar& other) { return DispatchComparisonOp(at::aten::le, input, other); } -void XLATensor::le_(XLATensor& input, const at::Scalar& other) { - ir::NodePtr cmp_result = - ir::ops::ComparisonOp(at::aten::le, input.GetIrValue(), - GetIrValueForScalar(other, input.GetDevice())); - input.SetIrValue(ir::MakeNode(cmp_result, input.dtype())); -} - XLATensor XLATensor::le(const XLATensor& input, const XLATensor& other) { return DispatchComparisonOp(at::aten::le, input, other); } -void XLATensor::le_(XLATensor& input, const XLATensor& other) { - ir::NodePtr cmp_result = ir::ops::ComparisonOp( - at::aten::le, input.GetIrValue(), other.GetIrValue()); - input.SetIrValue(ir::MakeNode(cmp_result, input.dtype())); -} - XLATensor XLATensor::hardshrink(const XLATensor& input, const at::Scalar& lambda) { return input.CreateFrom( @@ -1636,10 +1427,6 @@ XLATensor XLATensor::hardsigmoid(const XLATensor& input) { return input.CreateFrom(ir::ops::HardSigmoid(input.GetIrValue())); } -void XLATensor::hardsigmoid_(XLATensor& input) { - input.SetIrValue(ir::ops::HardSigmoid(input.GetIrValue())); -} - XLATensor XLATensor::hardsigmoid_backward(const XLATensor& grad_output, const XLATensor& input) { return input.CreateFrom(ir::ops::HardSigmoidBackward(grad_output.GetIrValue(), @@ -1666,28 +1453,15 @@ XLATensor XLATensor::leaky_relu_backward(const XLATensor& grad_output, grad_output.GetIrValue(), input.GetIrValue(), negative_slope)); } -void XLATensor::leaky_relu_(XLATensor& input, double negative_slope) { - input.SetInPlaceIrValue( - ir::MakeNode(input.GetIrValue(), negative_slope)); -} - XLATensor XLATensor::log(const XLATensor& input) { return input.CreateFrom(ir::ops::Log(input.GetIrValue())); } -void XLATensor::log_(XLATensor& input) { - input.SetInPlaceIrValue(ir::ops::Log(input.GetIrValue())); -} - XLATensor XLATensor::log_base(const XLATensor& input, ir::OpKind op, double base) { return input.CreateFrom(ir::ops::LogBase(input.GetIrValue(), op, base)); } -void XLATensor::log_base_(XLATensor& input, ir::OpKind op, double base) { - input.SetInPlaceIrValue(ir::ops::LogBase(input.GetIrValue(), op, base)); -} - XLATensor XLATensor::log_sigmoid(const XLATensor& input) { return input.CreateFrom(std::get<0>(ir::ops::LogSigmoid(input.GetIrValue()))); } @@ -1752,23 +1526,10 @@ XLATensor XLATensor::lt(const XLATensor& input, const at::Scalar& other) { return DispatchComparisonOp(at::aten::lt, input, other); } -void XLATensor::lt_(XLATensor& input, const at::Scalar& other) { - ir::NodePtr cmp_result = - ir::ops::ComparisonOp(at::aten::lt, input.GetIrValue(), - GetIrValueForScalar(other, input.GetDevice())); - input.SetIrValue(ir::MakeNode(cmp_result, input.dtype())); -} - XLATensor XLATensor::lt(const XLATensor& input, const XLATensor& other) { return DispatchComparisonOp(at::aten::lt, input, other); } -void XLATensor::lt_(XLATensor& input, const XLATensor& other) { - ir::NodePtr cmp_result = ir::ops::ComparisonOp( - at::aten::lt, input.GetIrValue(), other.GetIrValue()); - input.SetIrValue(ir::MakeNode(cmp_result, input.dtype())); -} - void XLATensor::masked_fill_(XLATensor& input, const XLATensor& mask, const at::Scalar& value) { ir::ScopePusher ir_scope(at::aten::masked_fill.toQualString()); @@ -1953,16 +1714,6 @@ XLATensor XLATensor::mul(const XLATensor& input, const at::Scalar& other, return input.CreateFrom(input.GetIrValue() * constant, logical_element_type); } -void XLATensor::mul_(XLATensor& input, const XLATensor& other) { - input.SetInPlaceIrValue(input.GetIrValue() * other.GetIrValue()); -} - -void XLATensor::mul_(XLATensor& input, const at::Scalar& other) { - ir::Value constant = - GetIrValueForScalar(other, input.shape(), input.GetDevice()); - input.SetInPlaceIrValue(input.GetIrValue() * constant); -} - XLATensor XLATensor::mv(const XLATensor& input, const XLATensor& vec) { return input.CreateFrom(ir::ops::Dot(input.GetIrValue(), vec.GetIrValue())); } @@ -2048,31 +1799,14 @@ XLATensor XLATensor::ne(const XLATensor& input, const at::Scalar& other) { return DispatchComparisonOp(at::aten::ne, input, other); } -void XLATensor::ne_(XLATensor& input, const at::Scalar& other) { - ir::NodePtr cmp_result = - ir::ops::ComparisonOp(at::aten::ne, input.GetIrValue(), - GetIrValueForScalar(other, input.GetDevice())); - input.SetIrValue(ir::MakeNode(cmp_result, input.dtype())); -} - XLATensor XLATensor::ne(const XLATensor& input, const XLATensor& other) { return DispatchComparisonOp(at::aten::ne, input, other); } -void XLATensor::ne_(XLATensor& input, const XLATensor& other) { - ir::NodePtr cmp_result = ir::ops::ComparisonOp( - at::aten::ne, input.GetIrValue(), other.GetIrValue()); - input.SetIrValue(ir::MakeNode(cmp_result, input.dtype())); -} - XLATensor XLATensor::neg(const XLATensor& input) { return input.CreateFrom(ir::ops::Neg(input.GetIrValue())); } -void XLATensor::neg_(XLATensor& input) { - input.SetInPlaceIrValue(ir::ops::Neg(input.GetIrValue())); -} - XLATensor XLATensor::nll_loss(const XLATensor& input, const XLATensor& target, const XLATensor& weight, xla::int64 reduction, int ignore_index) { @@ -2203,17 +1937,6 @@ XLATensor XLATensor::pow(const at::Scalar& input, const XLATensor& exponent) { return exponent.CreateFrom(ir::ops::Pow(input_node, exponent.GetIrValue())); } -void XLATensor::pow_(XLATensor& input, const at::Scalar& exponent) { - ir::Value exponent_node = - GetIrValueForScalar(exponent, input.shape(), input.GetDevice()); - input.SetInPlaceIrValue(ir::ops::Pow(input.GetIrValue(), exponent_node)); -} - -void XLATensor::pow_(XLATensor& input, const XLATensor& exponent) { - input.SetInPlaceIrValue( - ir::ops::Pow(input.GetIrValue(), exponent.GetIrValue())); -} - XLATensor XLATensor::prod(const XLATensor& input, std::vector dimensions, bool keep_reduced_dimensions, @@ -2255,10 +1978,6 @@ XLATensor XLATensor::reciprocal(const XLATensor& input) { return input.CreateFrom(ir::ops::ReciprocalOp(input.GetIrValue())); } -void XLATensor::reciprocal_(XLATensor& input) { - input.SetInPlaceIrValue(ir::ops::ReciprocalOp(input.GetIrValue())); -} - XLATensor XLATensor::reflection_pad2d(const XLATensor& input, std::vector padding) { return input.CreateFrom(ir::MakeNode( @@ -2292,17 +2011,6 @@ XLATensor XLATensor::remainder(const XLATensor& input, return input.CreateFrom(ir::ops::Remainder(input.GetIrValue(), constant)); } -void XLATensor::remainder_(XLATensor& input, const XLATensor& other) { - input.SetInPlaceIrValue( - ir::ops::Remainder(input.GetIrValue(), other.GetIrValue())); -} - -void XLATensor::remainder_(XLATensor& input, const at::Scalar& other) { - ir::Value constant = - GetIrValueForScalar(other, input.shape(), input.GetDevice()); - input.SetInPlaceIrValue(ir::ops::Remainder(input.GetIrValue(), constant)); -} - XLATensor XLATensor::repeat(const XLATensor& input, std::vector repeats) { return input.CreateFrom( @@ -2353,10 +2061,6 @@ XLATensor XLATensor::round(const XLATensor& input) { return input.CreateFrom(ir::ops::Round(input.GetIrValue())); } -void XLATensor::round_(XLATensor& input) { - input.SetInPlaceIrValue(ir::ops::Round(input.GetIrValue())); -} - XLATensor XLATensor::rrelu_with_noise(const XLATensor& input, XLATensor& noise, const at::Scalar& lower, const at::Scalar& upper, bool training) { @@ -2382,10 +2086,6 @@ XLATensor XLATensor::rsqrt(const XLATensor& input) { return input.CreateFrom(ir::ops::Rsqrt(input.GetIrValue())); } -void XLATensor::rsqrt_(XLATensor& input) { - input.SetInPlaceIrValue(ir::ops::Rsqrt(input.GetIrValue())); -} - XLATensor XLATensor::rsub(const XLATensor& input, const XLATensor& other, const at::Scalar& alpha, c10::optional logical_element_type) { @@ -2463,10 +2163,6 @@ XLATensor XLATensor::sigmoid(const XLATensor& input) { return input.CreateFrom(ir::ops::Sigmoid(input.GetIrValue())); } -void XLATensor::sigmoid_(XLATensor& input) { - input.SetInPlaceIrValue(ir::ops::Sigmoid(input.GetIrValue())); -} - XLATensor XLATensor::sigmoid_backward(const XLATensor& grad_output, const XLATensor& output) { return grad_output.CreateFrom( @@ -2477,26 +2173,14 @@ XLATensor XLATensor::sign(const XLATensor& input) { return input.CreateFrom(ir::ops::SignOp(input.GetIrValue())); } -void XLATensor::sign_(XLATensor& input) { - input.SetInPlaceIrValue(ir::ops::SignOp(input.GetIrValue())); -} - XLATensor XLATensor::sin(const XLATensor& input) { return input.CreateFrom(ir::ops::Sin(input.GetIrValue())); } -void XLATensor::sin_(XLATensor& input) { - input.SetInPlaceIrValue(ir::ops::Sin(input.GetIrValue())); -} - XLATensor XLATensor::sinh(const XLATensor& input) { return input.CreateFrom(ir::ops::Sinh(input.GetIrValue())); } -void XLATensor::sinh_(XLATensor& input) { - input.SetInPlaceIrValue(ir::ops::Sinh(input.GetIrValue())); -} - XLATensor XLATensor::slice(const XLATensor& input, xla::int64 dim, xla::int64 start, xla::int64 end, xla::int64 step) { auto input_shape = input.shape(); @@ -2616,10 +2300,6 @@ XLATensor XLATensor::sqrt(const XLATensor& input) { return input.CreateFrom(ir::ops::Sqrt(input.GetIrValue())); } -void XLATensor::sqrt_(XLATensor& input) { - input.SetInPlaceIrValue(ir::ops::Sqrt(input.GetIrValue())); -} - XLATensor XLATensor::squeeze(const XLATensor& input) { auto input_shape = input.shape(); auto output_dimensions = BuildSqueezedDimensions( @@ -2678,13 +2358,6 @@ XLATensor XLATensor::sub(const XLATensor& input, const XLATensor& other, logical_element_type); } -void XLATensor::sub_(XLATensor& input, const XLATensor& other, - const at::Scalar& alpha) { - ir::Value constant = - GetIrValueForScalar(alpha, other.shape(), other.GetDevice()); - input.SetInPlaceIrValue(input.GetIrValue() - other.GetIrValue() * constant); -} - XLATensor XLATensor::sub(const XLATensor& input, const at::Scalar& other, const at::Scalar& alpha, c10::optional logical_element_type) { @@ -2696,15 +2369,6 @@ XLATensor XLATensor::sub(const XLATensor& input, const at::Scalar& other, logical_element_type); } -void XLATensor::sub_(XLATensor& input, const at::Scalar& other, - const at::Scalar& alpha) { - ir::Value other_constant = - GetIrValueForScalar(other, input.shape(), input.GetDevice()); - ir::Value alpha_constant = - GetIrValueForScalar(alpha, input.shape(), input.GetDevice()); - input.SetInPlaceIrValue(input.GetIrValue() - other_constant * alpha_constant); -} - XLATensor XLATensor::sum(const XLATensor& input, std::vector dimensions, bool keep_reduced_dimensions, @@ -2750,18 +2414,10 @@ XLATensor XLATensor::tan(const XLATensor& input) { return input.CreateFrom(ir::ops::Tan(input.GetIrValue())); } -void XLATensor::tan_(XLATensor& input) { - input.SetInPlaceIrValue(ir::ops::Tan(input.GetIrValue())); -} - XLATensor XLATensor::tanh(const XLATensor& input) { return input.CreateFrom(ir::ops::Tanh(input.GetIrValue())); } -void XLATensor::tanh_(XLATensor& input) { - input.SetInPlaceIrValue(ir::ops::Tanh(input.GetIrValue())); -} - XLATensor XLATensor::tanh_backward(const XLATensor& grad_output, const XLATensor& output) { return XLATensor::mul(grad_output, @@ -2774,11 +2430,6 @@ XLATensor XLATensor::threshold(const XLATensor& input, float threshold, ir::MakeNode(input.GetIrValue(), threshold, value)); } -void XLATensor::threshold_(XLATensor& input, float threshold, float value) { - input.SetInPlaceIrValue( - ir::MakeNode(input.GetIrValue(), threshold, value)); -} - XLATensor XLATensor::threshold_backward(const XLATensor& grad_output, const XLATensor& input, float threshold) { @@ -2875,10 +2526,6 @@ XLATensor XLATensor::trunc(const XLATensor& input) { return input.CreateFrom(ir::ops::Trunc(input.GetIrValue())); } -void XLATensor::trunc_(XLATensor& input) { - input.SetInPlaceIrValue(ir::ops::Trunc(input.GetIrValue())); -} - std::vector XLATensor::unbind(const XLATensor& input, xla::int64 dim) { dim = XlaHelpers::GetCanonicalDimensionIndex(dim, input.shape().get().rank()); diff --git a/xla_native_functions.yaml b/xla_native_functions.yaml index 18b949c058f..07d7020bc92 100644 --- a/xla_native_functions.yaml +++ b/xla_native_functions.yaml @@ -2,31 +2,22 @@ backend: XLA cpp_namespace: torch_xla supported: - abs - - abs_ - acos - - acos_ - add.Tensor - add.Scalar - - add_.Scalar - all.dim - any.dim - arange.start_out - argmax - argmin - acosh - - acosh_ - asinh - - asinh_ - atanh - - atanh_ - as_strided - as_strided_ - asin - - asin_ - atan - - atan_ - baddbmm - - baddbmm_ - bernoulli - bernoulli_.Tensor - bernoulli_.float @@ -37,16 +28,12 @@ supported: - bmm - cat - ceil - - ceil_ - clamp - clamp.Tensor - - clamp_ - clamp_max - clamp_max.Tensor_out - - clamp_max_ - clamp_min - clamp_min.Tensor_out - - clamp_min_ - constant_pad_nd - convolution_overrideable - convolution_backward_overrideable @@ -54,18 +41,13 @@ supported: - _copy_from_and_resize - _to_cpu - cos - - cos_ - cosh - - cosh_ - cumprod - cumsum - diagonal - div.Tensor - - div_.Tensor - div.Tensor_mode - - div_.Tensor_mode - div.Scalar - - div_.Scalar - dot - embedding - embedding_dense_backward @@ -73,22 +55,16 @@ supported: - resize_ - empty_strided - erf - - erf_ - erfc - - erfc_ - exp - - exp_ - expm1 - - expm1_ - expand - eye.out - eye.m_out - fill_.Scalar - fill_.Tensor - floor - - floor_ - frac - - frac_ - index.Tensor - index_copy_ - index_put_ @@ -98,13 +74,9 @@ supported: - kl_div_backward - kthvalue - log - - log_ - log10 - - log10_ - log1p - - log1p_ - log2 - - log2_ - logdet - _log_softmax - _log_softmax_backward_data @@ -117,21 +89,16 @@ supported: - min.dim_min - mm - mul.Tensor - - mul_.Tensor - mul.Scalar - - mul_.Scalar - mv - mv.out - native_batch_norm - native_batch_norm_backward - permute - reciprocal - - reciprocal_ - neg - - neg_ - repeat - round - - round_ - relu - relu_ - gelu @@ -139,15 +106,11 @@ supported: - hardshrink - hardshrink_backward - rsqrt - - rsqrt_ - select.int - silu.out - sigmoid - - sigmoid_ - sin - - sin_ - sinh - - sinh_ - slice.Tensor - _softmax - _softmax_backward_data @@ -161,7 +124,6 @@ supported: - sum - sum.dim_IntList - sqrt - - sqrt_ - std - std.dim - std.correction @@ -170,18 +132,14 @@ supported: - t - t_ - tan - - tan_ - tanh - - tanh_ - threshold - - threshold_ - threshold_backward - transpose.int - transpose_ - flip - _trilinear - trunc - - trunc_ - _unsafe_view - unsqueeze - unsqueeze_ @@ -196,9 +154,7 @@ supported: - clone - zero_ - sub.Tensor - - sub_.Tensor - sub.Scalar - - sub_.Scalar - rsub.Tensor - rsub.Scalar - addmm @@ -216,8 +172,6 @@ supported: - scatter_.src - scatter_.value - scatter_add_ - - eq_.Scalar - - eq_.Tensor - bitwise_and.Tensor_out - bitwise_and.Scalar_out - bitwise_or.Tensor_out @@ -232,13 +186,8 @@ supported: - __rshift__.Tensor - __irshift__.Scalar - __irshift__.Tensor - - atan2_ - tril_ - triu_ - - fmod_.Scalar - - fmod_.Tensor - - remainder_.Scalar - - remainder_.Tensor - addcdiv_ - random_.from - random_.to @@ -252,33 +201,22 @@ supported: - trace - ne.Scalar - ne.Tensor - - ne_.Scalar - - ne_.Tensor - eq.Scalar - eq.Tensor - ge.Scalar - ge.Tensor - - ge_.Scalar - - ge_.Tensor - le.Scalar - le.Tensor - - le_.Scalar - - le_.Tensor - gt.Scalar - gt.Tensor - - gt_.Scalar - - gt_.Tensor - lt.Scalar - lt.Tensor - - lt_.Scalar - - lt_.Tensor - take - index_select - masked_select - nonzero - gather - addcmul - - addcmul_ - addcdiv - triangular_solve - symeig @@ -286,9 +224,7 @@ supported: - cholesky - qr - erfinv - - erfinv_ - sign - - sign_ - atan2 - fmod.Scalar - fmod.Tensor @@ -305,8 +241,6 @@ supported: - pow.Tensor_Tensor - pow.Scalar - pow.Tensor_Scalar - - pow_.Scalar - - pow_.Tensor - normal_ - normal.Tensor_float - normal.float_Tensor @@ -328,14 +262,11 @@ supported: - elu_backward - elu_ - hardsigmoid - - hardsigmoid_ - hardsigmoid_backward - hardtanh - hardtanh_backward - - hardtanh_ - leaky_relu - leaky_relu_backward - - leaky_relu_ - log_sigmoid_forward - log_sigmoid_backward - rrelu_with_noise