diff --git a/test/cpp/test_aten_xla_tensor.cpp b/test/cpp/test_aten_xla_tensor.cpp
index c28d8af2c637..d39790494629 100644
--- a/test/cpp/test_aten_xla_tensor.cpp
+++ b/test/cpp/test_aten_xla_tensor.cpp
@@ -960,6 +960,16 @@ TEST_F(AtenXlaTensorTest, TestMean) {
   });
 }
 
+TEST_F(AtenXlaTensorTest, TestMeanCast) {
+  at::Tensor a = at::rand({4, 3, 4}, at::TensorOptions(at::kFloat));
+  at::Tensor b = at::mean(a, at::kDouble);
+  ForEachDevice([&](const Device& device) {
+    at::Tensor xla_a = bridge::CreateXlaTensor(a, device);
+    at::Tensor xla_b = at::mean(xla_a, at::kDouble);
+    AllClose(b, xla_b);
+  });
+}
+
 TEST_F(AtenXlaTensorTest, TestMeanInDim) {
   at::Tensor a = at::rand({4, 3, 4}, at::TensorOptions(at::kFloat));
   int rank = a.dim();
@@ -985,6 +995,18 @@ TEST_F(AtenXlaTensorTest, TestMeanInDims) {
   }
 }
 
+TEST_F(AtenXlaTensorTest, TestMeanInDimsKeepCast) {
+  at::Tensor a = at::rand({4, 3, 4}, at::TensorOptions(at::kFloat));
+  for (auto dims : std::vector<std::vector<int64_t>>{{0, 1}, {-3, -2}}) {
+    at::Tensor b = at::mean(a, dims, true, at::kDouble);
+    ForEachDevice([&](const Device& device) {
+      at::Tensor xla_a = bridge::CreateXlaTensor(a, device);
+      at::Tensor xla_b = at::mean(xla_a, dims, true, at::kDouble);
+      AllClose(b, xla_b);
+    });
+  }
+}
+
 TEST_F(AtenXlaTensorTest, TestSum) {
   at::Tensor a = at::rand({4, 3, 4}, at::TensorOptions(at::kFloat));
   at::Tensor b = at::sum(a);
@@ -995,6 +1017,16 @@ TEST_F(AtenXlaTensorTest, TestSum) {
   });
 }
 
+TEST_F(AtenXlaTensorTest, TestSumCast) {
+  at::Tensor a = at::rand({4, 3, 4}, at::TensorOptions(at::kFloat));
+  at::Tensor b = at::sum(a, at::kDouble);
+  ForEachDevice([&](const Device& device) {
+    at::Tensor xla_a = bridge::CreateXlaTensor(a, device);
+    at::Tensor xla_b = at::sum(xla_a, at::kDouble);
+    AllClose(b, xla_b);
+  });
+}
+
 TEST_F(AtenXlaTensorTest, TestSumU8) {
   at::Tensor a = at::ones({256}, at::TensorOptions(at::kByte));
   at::Tensor b = at::sum(a);
@@ -1042,6 +1074,18 @@ TEST_F(AtenXlaTensorTest, TestSumInDimsKeep) {
   }
 }
 
+TEST_F(AtenXlaTensorTest, TestSumInDimsKeepCast) {
+  at::Tensor a = at::rand({4, 3, 4}, at::TensorOptions(at::kFloat));
+  for (auto dims : std::vector<std::vector<int64_t>>{{0, 1}, {-3, -2}}) {
+    at::Tensor b = at::sum(a, dims, /*keepdim=*/true, at::kDouble);
+    ForEachDevice([&](const Device& device) {
+      at::Tensor xla_a = bridge::CreateXlaTensor(a, device);
+      at::Tensor xla_b = at::sum(xla_a, dims, /*keepdim=*/true, at::kDouble);
+      AllClose(b, xla_b);
+    });
+  }
+}
+
 TEST_F(AtenXlaTensorTest, TestMaxInDim) {
   at::Tensor input = at::rand({4, 3, 4}, at::TensorOptions(at::kFloat));
   int rank = input.dim();
@@ -1435,6 +1479,16 @@ TEST_F(AtenXlaTensorTest, TestProd) {
   });
 }
 
+TEST_F(AtenXlaTensorTest, TestProdCast) {
+  at::Tensor a = at::rand({4, 3, 4}, at::TensorOptions(at::kFloat));
+  at::Tensor b = at::prod(a, at::kDouble);
+  ForEachDevice([&](const Device& device) {
+    at::Tensor xla_a = bridge::CreateXlaTensor(a, device);
+    at::Tensor xla_b = at::prod(xla_a, at::kDouble);
+    AllClose(b, xla_b);
+  });
+}
+
 TEST_F(AtenXlaTensorTest, TestProdInDim) {
   at::Tensor a = at::rand({4, 3, 4}, at::TensorOptions(at::kFloat));
   int rank = a.dim();
@@ -1448,6 +1502,19 @@ TEST_F(AtenXlaTensorTest, TestProdInDim) {
   }
 }
 
+TEST_F(AtenXlaTensorTest, TestProdInDimKeepCast) {
+  at::Tensor a = at::rand({4, 3, 4}, at::TensorOptions(at::kFloat));
+  int rank = a.dim();
+  for (int dim = -rank; dim < rank; ++dim) {
+    at::Tensor b = at::prod(a, dim, /*keepdim=*/true, at::kDouble);
+    ForEachDevice([&](const Device& device) {
+      at::Tensor xla_a = bridge::CreateXlaTensor(a, device);
+      at::Tensor xla_b = at::prod(xla_a, dim, /*keepdim=*/true, at::kDouble);
+      AllClose(b, xla_b);
+    });
+  }
+}
+
 TEST_F(AtenXlaTensorTest, TestProdInDimKeep) {
   at::Tensor a = at::rand({4, 3, 4}, at::TensorOptions(at::kFloat));
   int rank = a.dim();
@@ -1478,10 +1545,10 @@ TEST_F(AtenXlaTensorTest, TestCumSumCast) {
   at::Tensor input = at::rand({4, 3, 4}, at::TensorOptions(at::kFloat));
   int rank = input.dim();
   for (int dim = -rank; dim < rank; ++dim) {
-    at::Tensor result = at::cumsum(input, dim, at::ScalarType::Int);
+    at::Tensor result = at::cumsum(input, dim, at::kDouble);
     ForEachDevice([&](const Device& device) {
       at::Tensor xla_input = bridge::CreateXlaTensor(input, device);
-      at::Tensor xla_result = at::cumsum(xla_input, dim, at::ScalarType::Int);
+      at::Tensor xla_result = at::cumsum(xla_input, dim, at::kDouble);
       EXPECT_TRUE(EqualValues(result, xla_result));
     });
   }
@@ -1531,10 +1599,10 @@ TEST_F(AtenXlaTensorTest, TestCumProdCast) {
       at::mul(at::rand({4, 3, 4}, at::TensorOptions(at::kFloat)), 10);
   int rank = input.dim();
   for (int dim = -rank; dim < rank; ++dim) {
-    at::Tensor result = at::cumprod(input, dim, at::ScalarType::Int);
+    at::Tensor result = at::cumprod(input, dim, at::kDouble);
     ForEachDevice([&](const Device& device) {
       at::Tensor xla_input = bridge::CreateXlaTensor(input, device);
-      at::Tensor xla_result = at::cumprod(xla_input, dim, at::ScalarType::Int);
+      at::Tensor xla_result = at::cumprod(xla_input, dim, at::kDouble);
       EXPECT_TRUE(EqualValues(result, xla_result));
     });
   }
@@ -4307,6 +4375,19 @@ TEST_F(AtenXlaTensorTest, TestLogSoftmax) {
   });
 }
 
+TEST_F(AtenXlaTensorTest, TestLogSoftmaxCast) {
+  at::Tensor input = at::rand({5, 3, 4, 2}, at::TensorOptions(at::kFloat));
+  ForEachDevice([&](const Device& device) {
+    at::Tensor xla_input = bridge::CreateXlaTensor(input, device);
+    int rank = input.dim();
+    for (int dim = -rank; dim < rank; ++dim) {
+      at::Tensor output = at::log_softmax(input, dim, at::kDouble);
+      at::Tensor xla_output = at::log_softmax(xla_input, dim, at::kDouble);
+      AllClose(output, xla_output, /*rtol=*/1e-3);
+    }
+  });
+}
+
 TEST_F(AtenXlaTensorTest, TestSoftmax) {
   at::Tensor input = at::rand({10, 8, 24, 16}, at::TensorOptions(at::kFloat));
   ForEachDevice([&](const Device& device) {
@@ -4320,6 +4401,19 @@ TEST_F(AtenXlaTensorTest, TestSoftmax) {
   });
 }
 
+TEST_F(AtenXlaTensorTest, TestSoftmaxCast) {
+  at::Tensor input = at::rand({10, 8, 24, 16}, at::TensorOptions(at::kFloat));
+  ForEachDevice([&](const Device& device) {
+    at::Tensor xla_input = bridge::CreateXlaTensor(input, device);
+    int rank = input.dim();
+    for (int dim = -rank; dim < rank; ++dim) {
+      at::Tensor output = at::softmax(input, dim, at::kDouble);
+      at::Tensor xla_output = at::softmax(xla_input, dim, at::kDouble);
+      AllClose(output, xla_output, /*rtol=*/1e-3);
+    }
+  });
+}
+
 TEST_F(AtenXlaTensorTest, TestSoftmaxWrapper) {
   at::Tensor input = at::rand({10, 8, 24, 16}, at::TensorOptions(at::kFloat));
   ForEachDevice([&](const Device& device) {
diff --git a/test/cpp/test_tensor.cpp b/test/cpp/test_tensor.cpp
index e2e251312e8b..c6a16894893d 100644
--- a/test/cpp/test_tensor.cpp
+++ b/test/cpp/test_tensor.cpp
@@ -243,7 +243,7 @@ TEST_F(TensorTest, TestLogSoftmax) {
     auto dev_input = XLATensor::Create(input, device);
     for (int dim = 0; dim < input.dim(); ++dim) {
       auto output = input.log_softmax(dim);
-      auto dev_output = XLATensor::log_softmax(dev_input, dim);
+      auto dev_output = XLATensor::log_softmax(dev_input, dim, c10::nullopt);
       AllClose(output, dev_output, /*rtol=*/1e-3);
     }
   });
diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp
index 27aefcb337e0..d5a6bcc12fe8 100644
--- a/torch_xla/csrc/aten_xla_type.cpp
+++ b/torch_xla/csrc/aten_xla_type.cpp
@@ -276,7 +276,7 @@ at::Tensor& AtenXlaType::_index_put_impl_(at::Tensor& self,
 at::Tensor AtenXlaType::_log_softmax(const at::Tensor& self, int64_t dim,
                                      bool /* half_to_float */) {
   return bridge::AtenFromXlaTensor(
-      XLATensor::log_softmax(bridge::GetXlaTensor(self), dim));
+      XLATensor::log_softmax(bridge::GetXlaTensor(self), dim, c10::nullopt));
 }
 
 at::Tensor AtenXlaType::_log_softmax_backward_data(
@@ -288,7 +288,7 @@ at::Tensor AtenXlaType::_log_softmax_backward_data(
 
 at::Tensor AtenXlaType::_softmax(const at::Tensor& self, int64_t dim,
                                  bool /* half_to_float */) {
-  return softmax(self, dim);
+  return softmax(self, dim, c10::nullopt);
 }
 
 at::Tensor AtenXlaType::_softmax_backward_data(const at::Tensor& grad_output,
@@ -909,7 +909,7 @@ at::Tensor AtenXlaType::cross(const at::Tensor& self, const at::Tensor& other,
 }
 
 at::Tensor AtenXlaType::cumprod(const at::Tensor& self, int64_t dim,
-                                at::ScalarType dtype) {
+                                c10::optional<at::ScalarType> dtype) {
   XLATensor self_tensor = bridge::GetXlaTensor(self);
   if (dtype == at::ScalarType::Long &&
       self_tensor.GetDevice().hw_type == DeviceType::TPU) {
@@ -919,18 +919,8 @@ at::Tensor AtenXlaType::cumprod(const at::Tensor& self, int64_t dim,
   return bridge::AtenFromXlaTensor(XLATensor::cumprod(self_tensor, dim, dtype));
 }
 
-at::Tensor AtenXlaType::cumprod(const at::Tensor& self, int64_t dim) {
-  XLATensor self_tensor = bridge::GetXlaTensor(self);
-  if (self_tensor.dtype() == at::ScalarType::Long &&
-      self_tensor.GetDevice().hw_type == DeviceType::TPU) {
-    return AtenXlaTypeDefault::cumprod(self, dim);
-  }
-  return bridge::AtenFromXlaTensor(
-      XLATensor::cumprod(self_tensor, dim, c10::nullopt));
-}
-
 at::Tensor AtenXlaType::cumsum(const at::Tensor& self, int64_t dim,
-                               at::ScalarType dtype) {
+                               c10::optional<at::ScalarType> dtype) {
   XLATensor self_tensor = bridge::GetXlaTensor(self);
   if (dtype == at::ScalarType::Long &&
       self_tensor.GetDevice().hw_type == DeviceType::TPU) {
@@ -940,17 +930,6 @@ at::Tensor AtenXlaType::cumsum(const at::Tensor& self, int64_t dim,
   return bridge::AtenFromXlaTensor(XLATensor::cumsum(self_tensor, dim, dtype));
 }
 
-at::Tensor AtenXlaType::cumsum(const at::Tensor& self, int64_t dim) {
-  XLATensor self_tensor = bridge::GetXlaTensor(self);
-  if (self_tensor.dtype() == at::ScalarType::Long &&
-      self_tensor.GetDevice().hw_type == DeviceType::TPU) {
-    // XLA reduce-window does not support S64 mode.
-    return AtenXlaTypeDefault::cumsum(self, dim);
-  }
-  return bridge::AtenFromXlaTensor(
-      XLATensor::cumsum(self_tensor, dim, c10::nullopt));
-}
-
 at::Tensor AtenXlaType::diag(const at::Tensor& self, int64_t diagonal) {
   return bridge::AtenFromXlaTensor(
       XLATensor::diag(bridge::GetXlaTensor(self), diagonal));
@@ -1063,7 +1042,8 @@ at::Tensor AtenXlaType::embedding_dense_backward(const at::Tensor& grad_output,
 }
 
 at::Tensor AtenXlaType::empty(at::IntArrayRef size,
-                              const at::TensorOptions& options) {
+                              const at::TensorOptions& options,
+                              c10::optional<at::MemoryFormat> memory_format) {
   // PT empty*() are optimizations to avoid initializing the data when it is
   // known it will be completely rewritten. But since for us doing a zero*()
   // does not actually end up doing any memory initialization, we use that and
@@ -1076,8 +1056,9 @@ at::Tensor AtenXlaType::empty_like(const at::Tensor& self) {
   return full_like(self, 0);
 }
 
-at::Tensor AtenXlaType::empty_like(const at::Tensor& self,
-                                   const at::TensorOptions& options) {
+at::Tensor AtenXlaType::empty_like(
+    const at::Tensor& self, const at::TensorOptions& options,
+    c10::optional<at::MemoryFormat> memory_format) {
   return full_like(self, 0, options);
 }
 
@@ -1674,9 +1655,10 @@ std::tuple<at::Tensor, at::Tensor> AtenXlaType::log_sigmoid_forward(
       bridge::AtenFromXlaTensor(std::get<1>(result_tuple)));
 }
 
-at::Tensor AtenXlaType::log_softmax(const at::Tensor& self, int64_t dim) {
+at::Tensor AtenXlaType::log_softmax(const at::Tensor& self, int64_t dim,
+                                    c10::optional<at::ScalarType> dtype) {
   return bridge::AtenFromXlaTensor(
-      XLATensor::log_softmax(bridge::GetXlaTensor(self), dim));
+      XLATensor::log_softmax(bridge::GetXlaTensor(self), dim, dtype));
 }
 
 at::Tensor AtenXlaType::lt(const at::Tensor& self, at::Scalar other) {
@@ -1897,7 +1879,8 @@ std::tuple<at::Tensor, at::Tensor> AtenXlaType::max_pool3d_with_indices(
       bridge::AtenFromXlaTensor(indices_not_supported));
 }
 
-at::Tensor AtenXlaType::mean(const at::Tensor& self, at::ScalarType dtype) {
+at::Tensor AtenXlaType::mean(const at::Tensor& self,
+                             c10::optional<at::ScalarType> dtype) {
   XLATensor self_tensor = bridge::GetXlaTensor(self);
   return bridge::AtenFromXlaTensor(XLATensor::mean(
       self_tensor,
@@ -1905,35 +1888,14 @@ at::Tensor AtenXlaType::mean(const at::Tensor& self, at::ScalarType dtype) {
       /*keep_reduced_dimensions*/ false, dtype));
 }
 
-at::Tensor AtenXlaType::mean(const at::Tensor& self) {
-  XLATensor self_tensor = bridge::GetXlaTensor(self);
-  return bridge::AtenFromXlaTensor(XLATensor::mean(
-      self_tensor,
-      xla::util::Iota<xla::int64>(self_tensor.shape().get().rank()),
-      /*keep_reduced_dimensions*/ false, c10::nullopt));
-}
-
 at::Tensor AtenXlaType::mean(const at::Tensor& self, at::IntArrayRef dim,
-                             bool keepdim, at::ScalarType dtype) {
+                             bool keepdim,
+                             c10::optional<at::ScalarType> dtype) {
   return bridge::AtenFromXlaTensor(XLATensor::mean(
       bridge::GetXlaTensor(self), xla::util::ToVector<xla::int64>(dim),
       /*keep_reduced_dimensions*/ keepdim, dtype));
 }
 
-at::Tensor AtenXlaType::mean(const at::Tensor& self, at::IntArrayRef dim,
-                             bool keepdim) {
-  return bridge::AtenFromXlaTensor(XLATensor::mean(
-      bridge::GetXlaTensor(self), xla::util::ToVector<xla::int64>(dim),
-      /*keep_reduced_dimensions*/ keepdim, c10::nullopt));
-}
-
-at::Tensor AtenXlaType::mean(const at::Tensor& self, at::IntArrayRef dim,
-                             at::ScalarType dtype) {
-  return bridge::AtenFromXlaTensor(XLATensor::mean(
-      bridge::GetXlaTensor(self), xla::util::ToVector<xla::int64>(dim),
-      /*keep_reduced_dimensions*/ false, dtype));
-}
-
 std::vector<at::Tensor> AtenXlaType::meshgrid(at::TensorList tensors) {
   return at::native::meshgrid(tensors);
 }
@@ -2208,7 +2170,8 @@ at::Tensor& AtenXlaType::pow_(at::Tensor& self, const at::Tensor& exponent) {
   return self;
 }
 
-at::Tensor AtenXlaType::prod(const at::Tensor& self, at::ScalarType dtype) {
+at::Tensor AtenXlaType::prod(const at::Tensor& self,
+                             c10::optional<at::ScalarType> dtype) {
   XLATensor self_tensor = bridge::GetXlaTensor(self);
   return bridge::AtenFromXlaTensor(XLATensor::prod(
       self_tensor,
@@ -2216,33 +2179,12 @@ at::Tensor AtenXlaType::prod(const at::Tensor& self, at::ScalarType dtype) {
       /*keep_reduced_dimensions=*/false, dtype));
 }
 
-at::Tensor AtenXlaType::prod(const at::Tensor& self) {
-  XLATensor self_tensor = bridge::GetXlaTensor(self);
-  return bridge::AtenFromXlaTensor(XLATensor::prod(
-      self_tensor,
-      xla::util::Iota<xla::int64>(self_tensor.shape().get().rank()),
-      /*keep_reduced_dimensions=*/false, c10::nullopt));
-}
-
 at::Tensor AtenXlaType::prod(const at::Tensor& self, int64_t dim, bool keepdim,
-                             at::ScalarType dtype) {
+                             c10::optional<at::ScalarType> dtype) {
   return bridge::AtenFromXlaTensor(
       XLATensor::prod(bridge::GetXlaTensor(self), {dim}, keepdim, dtype));
 }
 
-at::Tensor AtenXlaType::prod(const at::Tensor& self, int64_t dim,
-                             bool keepdim) {
-  return bridge::AtenFromXlaTensor(XLATensor::prod(
-      bridge::GetXlaTensor(self), {dim}, keepdim, c10::nullopt));
-}
-
-at::Tensor AtenXlaType::prod(const at::Tensor& self, int64_t dim,
-                             at::ScalarType dtype) {
-  return bridge::AtenFromXlaTensor(
-      XLATensor::prod(bridge::GetXlaTensor(self), {dim},
-                      /*keep_reduced_dimensions=*/false, dtype));
-}
-
 std::tuple<at::Tensor, at::Tensor> AtenXlaType::qr(const at::Tensor& self,
                                                    bool some) {
   auto results = XLATensor::qr(bridge::GetXlaTensor(self), some);
@@ -2482,9 +2424,10 @@ at::Tensor AtenXlaType::smooth_l1_loss_backward(const at::Tensor& grad_output,
                                 bridge::GetXlaTensor(target), reduction));
 }
 
-at::Tensor AtenXlaType::softmax(const at::Tensor& self, int64_t dim) {
+at::Tensor AtenXlaType::softmax(const at::Tensor& self, int64_t dim,
+                                c10::optional<at::ScalarType> dtype) {
   return bridge::AtenFromXlaTensor(
-      XLATensor::softmax(bridge::GetXlaTensor(self), dim));
+      XLATensor::softmax(bridge::GetXlaTensor(self), dim, dtype));
 }
 
 at::Tensor AtenXlaType::softplus(const at::Tensor& self, at::Scalar beta,
@@ -2604,7 +2547,8 @@ at::Tensor& AtenXlaType::sub_(at::Tensor& self, at::Scalar other,
   return self;
 }
 
-at::Tensor AtenXlaType::sum(const at::Tensor& self, at::ScalarType dtype) {
+at::Tensor AtenXlaType::sum(const at::Tensor& self,
+                            c10::optional<at::ScalarType> dtype) {
   XLATensor self_tensor = bridge::GetXlaTensor(self);
   return bridge::AtenFromXlaTensor(XLATensor::sum(
       self_tensor,
@@ -2612,35 +2556,13 @@ at::Tensor AtenXlaType::sum(const at::Tensor& self, at::ScalarType dtype) {
       /*keep_reduced_dimensions=*/false, dtype));
 }
 
-at::Tensor AtenXlaType::sum(const at::Tensor& self) {
-  XLATensor self_tensor = bridge::GetXlaTensor(self);
-  return bridge::AtenFromXlaTensor(XLATensor::sum(
-      self_tensor,
-      xla::util::Iota<xla::int64>(self_tensor.shape().get().rank()),
-      /*keep_reduced_dimensions=*/false, c10::nullopt));
-}
-
 at::Tensor AtenXlaType::sum(const at::Tensor& self, at::IntArrayRef dim,
-                            bool keepdim, at::ScalarType dtype) {
+                            bool keepdim, c10::optional<at::ScalarType> dtype) {
   return bridge::AtenFromXlaTensor(
       XLATensor::sum(bridge::GetXlaTensor(self),
                      xla::util::ToVector<xla::int64>(dim), keepdim, dtype));
 }
 
-at::Tensor AtenXlaType::sum(const at::Tensor& self, at::IntArrayRef dim,
-                            bool keepdim) {
-  return bridge::AtenFromXlaTensor(XLATensor::sum(
-      bridge::GetXlaTensor(self), xla::util::ToVector<xla::int64>(dim), keepdim,
-      c10::nullopt));
-}
-
-at::Tensor AtenXlaType::sum(const at::Tensor& self, at::IntArrayRef dim,
-                            at::ScalarType dtype) {
-  return bridge::AtenFromXlaTensor(XLATensor::sum(
-      bridge::GetXlaTensor(self), xla::util::ToVector<xla::int64>(dim),
-      /*keep_reduced_dimensions=*/false, dtype));
-}
-
 at::Tensor AtenXlaType::sum_to_size(const at::Tensor& self,
                                     at::IntArrayRef size) {
   return at::native::sum_to_size(self, size);
diff --git a/torch_xla/csrc/aten_xla_type.h b/torch_xla/csrc/aten_xla_type.h
index 858981186436..82b4b68f6b69 100644
--- a/torch_xla/csrc/aten_xla_type.h
+++ b/torch_xla/csrc/aten_xla_type.h
@@ -326,14 +326,10 @@ class AtenXlaType {
                           c10::optional<int64_t> dim);
 
   static at::Tensor cumprod(const at::Tensor& self, int64_t dim,
-                            at::ScalarType dtype);
-
-  static at::Tensor cumprod(const at::Tensor& self, int64_t dim);
+                            c10::optional<at::ScalarType> dtype);
 
   static at::Tensor cumsum(const at::Tensor& self, int64_t dim,
-                           at::ScalarType dtype);
-
-  static at::Tensor cumsum(const at::Tensor& self, int64_t dim);
+                           c10::optional<at::ScalarType> dtype);
 
   static at::Tensor diag(const at::Tensor& self, int64_t diagonal);
 
@@ -380,12 +376,14 @@ class AtenXlaType {
                                  bool scale_grad_by_freq);
 
   static at::Tensor empty(at::IntArrayRef size,
-                          const at::TensorOptions& options);
+                          const at::TensorOptions& options,
+                          c10::optional<at::MemoryFormat> memory_format);
 
   static at::Tensor empty_like(const at::Tensor& self);
 
   static at::Tensor empty_like(const at::Tensor& self,
-                               const at::TensorOptions& options);
+                               const at::TensorOptions& options,
+                               c10::optional<at::MemoryFormat> memory_format);
 
   static at::Tensor eq(const at::Tensor& self, at::Scalar other);
 
@@ -628,7 +626,8 @@ class AtenXlaType {
   static std::tuple<at::Tensor, at::Tensor> log_sigmoid_forward(
      const at::Tensor& self);
 
-  static at::Tensor log_softmax(const at::Tensor& self, int64_t dim);
+  static at::Tensor log_softmax(const at::Tensor& self, int64_t dim,
+                                c10::optional<at::ScalarType> dtype);
 
   static at::Tensor lt(const at::Tensor& self, at::Scalar other);
 
@@ -701,18 +700,11 @@ class AtenXlaType {
                                          at::IntArrayRef padding, at::IntArrayRef dilation,
                                          bool ceil_mode, const at::Tensor& indices);
 
-  static at::Tensor mean(const at::Tensor& self, at::ScalarType dtype);
-
-  static at::Tensor mean(const at::Tensor& self);
+  static at::Tensor mean(const at::Tensor& self,
+                         c10::optional<at::ScalarType> dtype);
 
   static at::Tensor mean(const at::Tensor& self, at::IntArrayRef dim,
-                         bool keepdim, at::ScalarType dtype);
-
-  static at::Tensor mean(const at::Tensor& self, at::IntArrayRef dim,
-                         bool keepdim);
-
-  static at::Tensor mean(const at::Tensor& self, at::IntArrayRef dim,
-                         at::ScalarType dtype);
+                         bool keepdim, c10::optional<at::ScalarType> dtype);
 
   static std::vector<at::Tensor> meshgrid(at::TensorList tensors);
 
@@ -827,17 +819,11 @@ class AtenXlaType {
 
   static at::Tensor& pow_(at::Tensor& self, const at::Tensor& exponent);
 
-  static at::Tensor prod(const at::Tensor& self, at::ScalarType dtype);
-
-  static at::Tensor prod(const at::Tensor& self);
+  static at::Tensor prod(const at::Tensor& self,
+                         c10::optional<at::ScalarType> dtype);
 
   static at::Tensor prod(const at::Tensor& self, int64_t dim, bool keepdim,
-                         at::ScalarType dtype);
-
-  static at::Tensor prod(const at::Tensor& self, int64_t dim, bool keepdim);
-
-  static at::Tensor prod(const at::Tensor& self, int64_t dim,
-                         at::ScalarType dtype);
+                         c10::optional<at::ScalarType> dtype);
 
   static std::tuple<at::Tensor, at::Tensor> qr(const at::Tensor& self,
                                                bool some);
@@ -935,7 +921,8 @@ class AtenXlaType {
                                    const at::Tensor& target,
                                    int64_t reduction);
 
-  static at::Tensor softmax(const at::Tensor& self, int64_t dim);
+  static at::Tensor softmax(const at::Tensor& self, int64_t dim,
+                            c10::optional<at::ScalarType> dtype);
 
   static at::Tensor softplus(const at::Tensor& self, at::Scalar beta,
                              at::Scalar threshold);
@@ -986,18 +973,11 @@ class AtenXlaType {
   static at::Tensor& sub_(at::Tensor& self, at::Scalar other,
                           at::Scalar alpha);
 
-  static at::Tensor sum(const at::Tensor& self, at::ScalarType dtype);
-
-  static at::Tensor sum(const at::Tensor& self);
-
-  static at::Tensor sum(const at::Tensor& self, at::IntArrayRef dim,
-                        bool keepdim, at::ScalarType dtype);
-
-  static at::Tensor sum(const at::Tensor& self, at::IntArrayRef dim,
-                        bool keepdim);
+  static at::Tensor sum(const at::Tensor& self,
+                        c10::optional<at::ScalarType> dtype);
 
   static at::Tensor sum(const at::Tensor& self, at::IntArrayRef dim,
-                        at::ScalarType dtype);
+                        bool keepdim, c10::optional<at::ScalarType> dtype);
 
   static at::Tensor sum_to_size(const at::Tensor& self, at::IntArrayRef size);
 
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index a16c1b39f174..79b071aa4dd9 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -242,52 +242,47 @@ void InitXlaModuleBindings(py::module m) {
   m.def("_xla_set_default_device",
         [](const std::string& device) { return SetCurrentDevice(device); });
   m.def("_xla_get_default_device", []() { return GetCurrentDevice(); });
-  m.def(
-      "_xla_sync_multi",
-      [](const std::vector<at::Tensor>& tensors,
-         const std::vector<std::string>& devices, bool wait,
-         bool sync_xla_data) {
-        NoGilSection nogil;
-        SyncTensors(tensors, devices, wait, sync_xla_data);
-      },
-      py::arg("tensors"), py::arg("devices"), py::arg("wait") = true,
-      py::arg("sync_xla_data") = true);
-  m.def(
-      "_xla_sync_live_tensors",
-      [](const std::string& device, const std::vector<std::string>& devices,
-         bool wait) {
-        NoGilSection nogil;
-        SyncLiveTensors(device, devices, wait);
-      },
-      py::arg("device") = "", py::arg("devices"), py::arg("wait") = true);
-  m.def(
-      "_xla_step_marker",
-      [](const std::string& device, const std::vector<std::string>& devices,
-         bool wait) {
-        NoGilSection nogil;
-        StepMarker(device, devices, wait);
-      },
-      py::arg("device") = "", py::arg("devices"), py::arg("wait") = true);
+  m.def("_xla_sync_multi",
+        [](const std::vector<at::Tensor>& tensors,
+           const std::vector<std::string>& devices, bool wait,
+           bool sync_xla_data) {
+          NoGilSection nogil;
+          SyncTensors(tensors, devices, wait, sync_xla_data);
+        },
+        py::arg("tensors"), py::arg("devices"), py::arg("wait") = true,
+        py::arg("sync_xla_data") = true);
+  m.def("_xla_sync_live_tensors",
+        [](const std::string& device, const std::vector<std::string>& devices,
+           bool wait) {
+          NoGilSection nogil;
+          SyncLiveTensors(device, devices, wait);
+        },
+        py::arg("device") = "", py::arg("devices"), py::arg("wait") = true);
+  m.def("_xla_step_marker",
+        [](const std::string& device, const std::vector<std::string>& devices,
+           bool wait) {
+          NoGilSection nogil;
+          StepMarker(device, devices, wait);
+        },
+        py::arg("device") = "", py::arg("devices"), py::arg("wait") = true);
   m.def("_xla_counter_value", [](const std::string& name) -> py::object {
     xla::metrics::CounterData* data = xla::metrics::GetCounter(name);
     return data != nullptr ? py::cast(data->Value()) : py::none();
   });
   m.def("_xla_metrics_report",
         []() { return xla::metrics::CreateMetricReport(); });
-  m.def(
-      "_xla_tensors_report",
-      [](size_t nodes_threshold, const std::string& device) {
-        return GetLiveTensorsReport(nodes_threshold, device);
-      },
-      py::arg("nodes_threshold") = 100, py::arg("device") = "");
-  m.def(
-      "_xla_set_use_full_mat_mul_precision",
-      [](bool use_full_mat_mul_precision) {
-        XlaHelpers::set_mat_mul_precision(use_full_mat_mul_precision
-                                              ? xla::PrecisionConfig::HIGHEST
-                                              : xla::PrecisionConfig::DEFAULT);
-      },
-      py::arg("use_full_mat_mul_precision") = true);
+  m.def("_xla_tensors_report",
+        [](size_t nodes_threshold, const std::string& device) {
+          return GetLiveTensorsReport(nodes_threshold, device);
+        },
+        py::arg("nodes_threshold") = 100, py::arg("device") = "");
+  m.def("_xla_set_use_full_mat_mul_precision",
+        [](bool use_full_mat_mul_precision) {
+          XlaHelpers::set_mat_mul_precision(
+              use_full_mat_mul_precision ? xla::PrecisionConfig::HIGHEST
+                                         : xla::PrecisionConfig::DEFAULT);
+        },
+        py::arg("use_full_mat_mul_precision") = true);
 }
 
 }  // namespace
diff --git a/torch_xla/csrc/ops/log_softmax.cpp b/torch_xla/csrc/ops/log_softmax.cpp
index a797f8dc0692..bcc8a400e25c 100644
--- a/torch_xla/csrc/ops/log_softmax.cpp
+++ b/torch_xla/csrc/ops/log_softmax.cpp
@@ -1,30 +1,55 @@
 #include "torch_xla/csrc/ops/log_softmax.h"
 
 #include "tensorflow/compiler/xla/xla_client/util.h"
+#include "torch_xla/csrc/convert_ops.h"
 #include "torch_xla/csrc/lowering_context.h"
 #include "torch_xla/csrc/softmax_builder.h"
+#include "torch_xla/csrc/tensor_util.h"
+#include "torch_xla/csrc/torch_util.h"
 
 namespace torch_xla {
 namespace ir {
 namespace ops {
+namespace {
 
-LogSoftmax::LogSoftmax(const Value& input, xla::int64 dim)
-    : Node(ir::OpKind(at::aten::log_softmax), {input}, input.shape(),
-           /*num_outputs=*/1, xla::util::MHash(dim)),
-      dim_(dim) {}
+xla::XlaOp LowerLogSoftmax(const xla::XlaOp& input,
+                           xla::int64 dim,
+                           c10::optional<at::ScalarType> dtype) {
+  xla::XlaOp casted_input = CastToScalarType(input, dtype);
+  return BuildLogSoftmax(casted_input, dim);
+}
+
+xla::Shape NodeOutputShape(const Value& input,
+                           c10::optional<at::ScalarType> dtype) {
+  if (dtype) {
+    return xla::ShapeUtil::ChangeElementType(
+        input.shape(), MakeXlaPrimitiveType(*dtype, /*device=*/nullptr));
+  }
+  return input.shape();
+}
+
+}  // namespace
+
+
+LogSoftmax::LogSoftmax(const Value& input, xla::int64 dim, c10::optional<at::ScalarType> dtype)
+    : Node(ir::OpKind(at::aten::log_softmax), {input},
+           [&]() { return NodeOutputShape(input, dtype); },
+           /*num_outputs=*/1, xla::util::MHash(dim, OptionalOr<int>(dtype, -1))),
+      dim_(dim),
+      dtype_(dtype) {}
 
 NodePtr LogSoftmax::Clone(OpList operands) const {
-  return MakeNode<LogSoftmax>(operands.at(0), dim_);
+  return MakeNode<LogSoftmax>(operands.at(0), dim_, dtype_);
 }
 
 XlaOpVector LogSoftmax::Lower(LoweringContext* loctx) const {
   xla::XlaOp input = loctx->GetOutputOp(operand(0));
-  return ReturnOp(BuildLogSoftmax(input, dim_), loctx);
+  return ReturnOp(LowerLogSoftmax(input, dim_, dtype_), loctx);
 }
 
 std::string LogSoftmax::ToString() const {
   std::stringstream ss;
-  ss << Node::ToString() << ", dim=" << dim_;
+  ss << Node::ToString() << ", dim=" << dim_
+     << ", dtype=" << OptionalOr<int>(dtype_, -1);
   return ss.str();
 }
diff --git a/torch_xla/csrc/ops/log_softmax.h b/torch_xla/csrc/ops/log_softmax.h
index cdf3d0e72137..a319383328b8 100644
--- a/torch_xla/csrc/ops/log_softmax.h
+++ b/torch_xla/csrc/ops/log_softmax.h
@@ -1,5 +1,8 @@
 #pragma once
 
+#include <c10/core/ScalarType.h>
+#include <c10/util/Optional.h>
+
 #include "torch_xla/csrc/ir.h"
 
 namespace torch_xla {
@@ -9,7 +12,7 @@ namespace ops {
 // IR node for log(softmax) operation.
 class LogSoftmax : public Node {
  public:
-  LogSoftmax(const Value& input, xla::int64 dim);
+  LogSoftmax(const Value& input, xla::int64 dim, c10::optional<at::ScalarType> dtype);
 
   NodePtr Clone(OpList operands) const override;
 
@@ -19,9 +22,12 @@ class LogSoftmax : public Node {
 
   xla::int64 dim() const { return dim_; }
 
+  const c10::optional<at::ScalarType>& dtype() const { return dtype_; }
+
  private:
   // The dimension along which the result is computed.
   xla::int64 dim_;
+  c10::optional<at::ScalarType> dtype_;
 };
 
 }  // namespace ops
diff --git a/torch_xla/csrc/ops/softmax.cpp b/torch_xla/csrc/ops/softmax.cpp
index 94afe7b2b126..b430e25e43cd 100644
--- a/torch_xla/csrc/ops/softmax.cpp
+++ b/torch_xla/csrc/ops/softmax.cpp
@@ -1,30 +1,54 @@
 #include "torch_xla/csrc/ops/softmax.h"
 
+#include "torch_xla/csrc/convert_ops.h"
 #include "tensorflow/compiler/xla/xla_client/util.h"
 #include "torch_xla/csrc/lowering_context.h"
 #include "torch_xla/csrc/softmax_builder.h"
+#include "torch_xla/csrc/tensor_util.h"
+#include "torch_xla/csrc/torch_util.h"
 
 namespace torch_xla {
 namespace ir {
 namespace ops {
+namespace {
 
-Softmax::Softmax(const Value& input, xla::int64 dim)
-    : Node(ir::OpKind(at::aten::softmax), {input}, input.shape(),
-           /*num_outputs=*/1, xla::util::MHash(dim)),
-      dim_(dim) {}
+xla::XlaOp LowerSoftmax(const xla::XlaOp& input,
+                        xla::int64 dim,
+                        c10::optional<at::ScalarType> dtype) {
+  xla::XlaOp casted_input = CastToScalarType(input, dtype);
+  return BuildSoftmax(casted_input, dim);
+}
+
+xla::Shape NodeOutputShape(const Value& input,
+                           c10::optional<at::ScalarType> dtype) {
+  if (dtype) {
+    return xla::ShapeUtil::ChangeElementType(
+        input.shape(), MakeXlaPrimitiveType(*dtype, /*device=*/nullptr));
+  }
+  return input.shape();
+}
+
+}  // namespace
+
+Softmax::Softmax(const Value& input, xla::int64 dim, c10::optional<at::ScalarType> dtype)
+    : Node(ir::OpKind(at::aten::softmax), {input},
+           [&]() { return NodeOutputShape(input, dtype); },
+           /*num_outputs=*/1, xla::util::MHash(dim, OptionalOr<int>(dtype, -1))),
+      dim_(dim),
+      dtype_(dtype) {}
 
 NodePtr Softmax::Clone(OpList operands) const {
-  return MakeNode<Softmax>(operands.at(0), dim_);
+  return MakeNode<Softmax>(operands.at(0), dim_, dtype_);
 }
 
 XlaOpVector Softmax::Lower(LoweringContext* loctx) const {
   xla::XlaOp input = loctx->GetOutputOp(operand(0));
-  return ReturnOp(BuildSoftmax(input, dim_), loctx);
+  return ReturnOp(LowerSoftmax(input, dim_, dtype_), loctx);
 }
 
 std::string Softmax::ToString() const {
   std::stringstream ss;
-  ss << Node::ToString() << ", dim=" << dim_;
+  ss << Node::ToString() << ", dim=" << dim_
+     << ", dtype=" << OptionalOr<int>(dtype_, -1);
   return ss.str();
 }
diff --git a/torch_xla/csrc/ops/softmax.h b/torch_xla/csrc/ops/softmax.h
index dfc6b5673be7..90cd76336a76 100644
--- a/torch_xla/csrc/ops/softmax.h
+++ b/torch_xla/csrc/ops/softmax.h
@@ -1,5 +1,8 @@
 #pragma once
 
+#include <c10/core/ScalarType.h>
+#include <c10/util/Optional.h>
+
 #include "torch_xla/csrc/ir.h"
 
 namespace torch_xla {
@@ -8,7 +11,7 @@ namespace ops {
 
 class Softmax : public Node {
  public:
-  Softmax(const Value& input, xla::int64 dim);
+  Softmax(const Value& input, xla::int64 dim, c10::optional<at::ScalarType> dtype);
 
   NodePtr Clone(OpList operands) const override;
 
@@ -18,8 +21,11 @@ class Softmax : public Node {
 
   xla::int64 dim() const { return dim_; }
 
+  const c10::optional<at::ScalarType>& dtype() const { return dtype_; }
+
  private:
   xla::int64 dim_;
+  c10::optional<at::ScalarType> dtype_;
 };
 
 }  // namespace ops
diff --git a/torch_xla/csrc/tensor.h b/torch_xla/csrc/tensor.h
index defb2819486b..8eb7b49d919b 100644
--- a/torch_xla/csrc/tensor.h
+++ b/torch_xla/csrc/tensor.h
@@ -557,7 +557,8 @@ class XLATensor {
                                          const XLATensor& input,
                                          const XLATensor& buffer);
 
-  static XLATensor log_softmax(const XLATensor& input, xla::int64 dim);
+  static XLATensor log_softmax(const XLATensor& input, xla::int64 dim,
+                               c10::optional<at::ScalarType> dtype);
 
   static XLATensor log_softmax_backward(const XLATensor& grad_output,
                                         const XLATensor& output,
@@ -763,7 +764,8 @@ class XLATensor {
                                          const XLATensor& target,
                                          xla::int64 reduction);
 
-  static XLATensor softmax(const XLATensor& input, xla::int64 dim);
+  static XLATensor softmax(const XLATensor& input, xla::int64 dim,
+                           c10::optional<at::ScalarType> dtype);
 
   static XLATensor softmax_backward(const XLATensor& grad_output,
                                     const XLATensor& output, xla::int64 dim);
diff --git a/torch_xla/csrc/tensor_impl.cpp b/torch_xla/csrc/tensor_impl.cpp
index e739f9ca6ecb..eafe7ce1b0ae 100644
--- a/torch_xla/csrc/tensor_impl.cpp
+++ b/torch_xla/csrc/tensor_impl.cpp
@@ -59,7 +59,7 @@ c10::intrusive_ptr<c10::TensorImpl> XLATensorImpl::shallow_copy_and_detach(
     const c10::VariableVersion& version_counter,
     bool allow_tensor_metadata_change) const {
   auto impl = c10::make_intrusive<XLATensorImpl>(tensor_);
-  copy_tensor_data(
+  copy_tensor_metadata(
       /*src_impl=*/this,
       /*dest_impl=*/impl.get(),
      /*version_counter=*/version_counter,
@@ -70,7 +70,7 @@ void XLATensorImpl::shallow_copy_from(
     const c10::intrusive_ptr<c10::TensorImpl>& impl) {
   XLATensorImpl* xla_impl = dynamic_cast<XLATensorImpl*>(impl.get());
-  copy_tensor_data(
+  copy_tensor_metadata(
      /*src_impl=*/xla_impl,
      /*dest_impl=*/this,
      /*version_counter=*/version_counter(),
diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp
index 403af16d0a77..a638b30034cb 100644
--- a/torch_xla/csrc/tensor_methods.cpp
+++ b/torch_xla/csrc/tensor_methods.cpp
@@ -1259,10 +1259,14 @@ XLATensor XLATensor::log_sigmoid_backward(const XLATensor& grad_output,
       grad_output.GetIrValue(), input.GetIrValue(), buffer.GetIrValue()));
 }
 
-XLATensor XLATensor::log_softmax(const XLATensor& input, xla::int64 dim) {
-  return input.CreateFrom(ir::MakeNode<ir::ops::LogSoftmax>(
-      input.GetIrValue(),
-      XlaHelpers::GetCanonicalDimensionIndex(dim, input.shape().get().rank())));
+XLATensor XLATensor::log_softmax(const XLATensor& input, xla::int64 dim,
+                                 c10::optional<at::ScalarType> dtype) {
+  return input.CreateFrom(
+      ir::MakeNode<ir::ops::LogSoftmax>(input.GetIrValue(),
+                                        XlaHelpers::GetCanonicalDimensionIndex(
+                                            dim, input.shape().get().rank()),
+                                        dtype),
+      dtype);
 }
 
 XLATensor XLATensor::log_softmax_backward(const XLATensor& grad_output,
@@ -1821,10 +1825,14 @@ XLATensor XLATensor::smooth_l1_loss_backward(const XLATensor& grad_output,
                                              reduction);
 }
 
-XLATensor XLATensor::softmax(const XLATensor& input, xla::int64 dim) {
-  return input.CreateFrom(ir::MakeNode<ir::ops::Softmax>(
-      input.GetIrValue(),
-      XlaHelpers::GetCanonicalDimensionIndex(dim, input.shape().get().rank())));
+XLATensor XLATensor::softmax(const XLATensor& input, xla::int64 dim,
+                             c10::optional<at::ScalarType> dtype) {
+  return input.CreateFrom(
+      ir::MakeNode<ir::ops::Softmax>(input.GetIrValue(),
+                                     XlaHelpers::GetCanonicalDimensionIndex(
+                                         dim, input.shape().get().rank()),
+                                     dtype),
+      dtype);
 }
 
 XLATensor XLATensor::softmax_backward(const XLATensor& grad_output,
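
For reference, a minimal sketch (not part of the patch) of the optional-dtype call pattern this change enables; the function and tensor names are illustrative only, while at::sum and at::softmax are the standard ATen dtype-taking overloads the patch routes through to XLA:

#include <ATen/ATen.h>

// An explicit dtype casts the result; c10::nullopt keeps the input dtype
// (the behavior the removed single-argument overloads used to provide).
void OptionalDtypeSketch() {
  at::Tensor a = at::rand({4, 3, 4}, at::TensorOptions(at::kFloat));
  at::Tensor s = at::sum(a, at::kDouble);                  // result is kDouble
  at::Tensor m = at::softmax(a, /*dim=*/1, at::kDouble);   // input cast first
  at::Tensor n = at::softmax(a, /*dim=*/1, c10::nullopt);  // stays kFloat
}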