diff --git a/test/cpp/test_aten_xla_tensor.cpp b/test/cpp/test_aten_xla_tensor.cpp
index 4faed8dd679..d6fb5428266 100644
--- a/test/cpp/test_aten_xla_tensor.cpp
+++ b/test/cpp/test_aten_xla_tensor.cpp
@@ -936,16 +936,20 @@ TEST_F(AtenXlaTensorTest, TestSort) {
   for (int k = 1; k <= 3; ++k) {
     for (int dim = 0; dim < 3; ++dim) {
       for (bool descending : {false, true}) {
-        auto b = torch::sort(a, dim, descending);
-        ForEachDevice([&](const torch::Device& device) {
-          torch::Tensor xla_a = CopyToDevice(a, device);
-          auto xla_b = torch::sort(xla_a, dim, descending);
-          AllClose(std::get<0>(b), std::get<0>(xla_b));
-          AllEqual(std::get<1>(b), std::get<1>(xla_b));
-        });
+        for (bool stable : {false, true}) {
+          auto b = torch::sort(a, dim, descending, stable);
+          ForEachDevice([&](const torch::Device& device) {
+            torch::Tensor xla_a = CopyToDevice(a, device);
+            auto xla_b = torch::sort(xla_a, dim, descending, stable);
+            AllClose(std::get<0>(b), std::get<0>(xla_b));
+            AllEqual(std::get<1>(b), std::get<1>(xla_b));
+          });
+        }
       }
     }
   }
+  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::sort", cpp_test::GetIgnoredCounters());
 }
 
 TEST_F(AtenXlaTensorTest, TestSortDescWithMinValue) {
diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp
index b6ee491417f..eae60092ad3 100644
--- a/torch_xla/csrc/aten_xla_type.cpp
+++ b/torch_xla/csrc/aten_xla_type.cpp
@@ -2982,10 +2982,12 @@ at::Tensor AtenXlaType::softshrink_backward(const at::Tensor& grad_out,
 
 std::tuple<at::Tensor, at::Tensor> AtenXlaType::sort(const at::Tensor& self,
                                                      int64_t dim,
-                                                     bool descending) {
+                                                     bool descending,
+                                                     bool stable) {
   XLA_FN_COUNTER("xla::");
-  auto results = XLATensor::topk(bridge::GetXlaTensor(self), self.size(dim),
-                                 dim, descending, true);
+  auto results =
+      XLATensor::topk(bridge::GetXlaTensor(self), self.size(dim), dim,
+                      descending, /*sorted=*/false, /*stable=*/stable);
   return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(results)),
                          bridge::AtenFromXlaTensor(std::get<1>(results)));
 }
diff --git a/torch_xla/csrc/aten_xla_type.h b/torch_xla/csrc/aten_xla_type.h
index d34745fb492..207c13a2c6a 100644
--- a/torch_xla/csrc/aten_xla_type.h
+++ b/torch_xla/csrc/aten_xla_type.h
@@ -918,7 +918,8 @@ class AtenXlaType {
                                       at::Scalar lambda);
 
   static std::tuple<at::Tensor, at::Tensor> sort(const at::Tensor& self,
-                                                 int64_t dim, bool descending);
+                                                 int64_t dim, bool descending,
+                                                 bool stable);
 
   static std::vector<at::Tensor> split(const at::Tensor& self,
                                        int64_t split_size, int64_t dim);
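Note on the dispatcher change above: `stable` only affects how ties are broken in the returned indices, since a stable sort must keep equal keys in their input order. The snippet below is a minimal standalone sketch (plain C++, not part of this PR) of that guarantee, using `std::stable_sort` over an index array in the same spirit as the sorted-iota trick the lowering uses:

```cpp
#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  // Values with ties: indices 1, 3 and 4 all hold the same key (1).
  const std::vector<int> values = {2, 1, 3, 1, 1};
  std::vector<int> indices = {0, 1, 2, 3, 4};

  // Compare by value only. std::sort leaves the order of ties unspecified;
  // std::stable_sort preserves their input order -- the same guarantee
  // torch::sort(..., /*stable=*/true) makes for its index output.
  auto by_value = [&](int a, int b) { return values[a] < values[b]; };

  std::stable_sort(indices.begin(), indices.end(), by_value);
  for (int i : indices) std::printf("%d ", i);  // prints: 1 3 4 0 2
  std::printf("\n");
  return 0;
}
```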
diff --git a/torch_xla/csrc/ops/topk.cpp b/torch_xla/csrc/ops/topk.cpp
index ebd5484587c..c633d32a34d 100644
--- a/torch_xla/csrc/ops/topk.cpp
+++ b/torch_xla/csrc/ops/topk.cpp
@@ -11,11 +11,11 @@ namespace ops {
 namespace {
 
 xla::Shape NodeOutputShape(const Value& input, xla::int64 k, xla::int64 dim,
-                           bool largest, bool sorted) {
+                           bool largest, bool sorted, bool stable) {
   auto lower_for_shape_fn =
       [&](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
     return xla::Tuple(operands[0].builder(),
-                      CreateTopK(operands[0], k, dim, largest, sorted));
+                      CreateTopK(operands[0], k, dim, largest, sorted, stable));
   };
   return InferOutputShape({input.shape()}, lower_for_shape_fn);
 }
@@ -23,28 +23,34 @@ xla::Shape NodeOutputShape(const Value& input, xla::int64 k, xla::int64 dim,
 }  // namespace
 
 TopK::TopK(const Value& input, xla::int64 k, xla::int64 dim, bool largest,
-           bool sorted)
+           bool sorted, bool stable)
     : Node(ir::OpKind(at::aten::topk), {input},
-           [&]() { return NodeOutputShape(input, k, dim, largest, sorted); },
-           /*num_outputs=*/2, xla::util::MHash(k, dim, largest, sorted)),
+           [&]() {
+             return NodeOutputShape(input, k, dim, largest, sorted, stable);
+           },
+           /*num_outputs=*/2,
+           xla::util::MHash(k, dim, largest, sorted, stable)),
       k_(k),
       dim_(dim),
       largest_(largest),
-      sorted_(sorted) {}
+      sorted_(sorted),
+      stable_(stable) {}
 
 NodePtr TopK::Clone(OpList operands) const {
-  return MakeNode<TopK>(operands.at(0), k_, dim_, largest_, sorted_);
+  return MakeNode<TopK>(operands.at(0), k_, dim_, largest_, sorted_, stable_);
 }
 
 XlaOpVector TopK::Lower(LoweringContext* loctx) const {
   xla::XlaOp input = loctx->GetOutputOp(operand(0));
-  return ReturnOps(CreateTopK(input, k_, dim_, largest_, sorted_), loctx);
+  return ReturnOps(CreateTopK(input, k_, dim_, largest_, sorted_, stable_),
+                   loctx);
 }
 
 std::string TopK::ToString() const {
   std::stringstream ss;
   ss << Node::ToString() << ", k=" << k_ << ", dim=" << dim_
-     << ", largest=" << largest_ << ", sorted=" << sorted_;
+     << ", largest=" << largest_ << ", sorted=" << sorted_
+     << ", stable=" << stable_;
   return ss.str();
 }
diff --git a/torch_xla/csrc/ops/topk.h b/torch_xla/csrc/ops/topk.h
index c5592eba6e2..7a81f98d3d5 100644
--- a/torch_xla/csrc/ops/topk.h
+++ b/torch_xla/csrc/ops/topk.h
@@ -9,7 +9,7 @@ namespace ops {
 class TopK : public Node {
  public:
   TopK(const Value& input, xla::int64 k, xla::int64 dim, bool largest,
-       bool sorted);
+       bool sorted, bool stable);
 
   std::string ToString() const override;
 
@@ -25,11 +25,14 @@ class TopK : public Node {
 
   bool sorted() const { return sorted_; }
 
+  bool stable() const { return stable_; }
+
  private:
   xla::int64 k_;
   xla::int64 dim_;
   bool largest_;
   bool sorted_;
+  bool stable_;
 };
 
 }  // namespace ops
diff --git a/torch_xla/csrc/tensor.h b/torch_xla/csrc/tensor.h
index 1c22dcb35e5..79c961f66eb 100644
--- a/torch_xla/csrc/tensor.h
+++ b/torch_xla/csrc/tensor.h
@@ -1072,7 +1072,8 @@ class XLATensor {
 
   static std::tuple<XLATensor, XLATensor> topk(const XLATensor& input,
                                                xla::int64 k, xla::int64 dim,
-                                               bool largest, bool sorted);
+                                               bool largest, bool sorted,
+                                               bool stable = false);
 
   // Returns the sum of the elements of the diagonal of the input 2-D matrix.
   static XLATensor trace(const XLATensor& input);
diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp
index 8abeac567ac..7ffb9f37e5e 100644
--- a/torch_xla/csrc/tensor_methods.cpp
+++ b/torch_xla/csrc/tensor_methods.cpp
@@ -2683,11 +2683,12 @@ XLATensor XLATensor::to(XLATensor& input, c10::optional<Device> device,
 
 std::tuple<XLATensor, XLATensor> XLATensor::topk(const XLATensor& input,
                                                  xla::int64 k, xla::int64 dim,
-                                                 bool largest, bool sorted) {
+                                                 bool largest, bool sorted,
+                                                 bool stable) {
   ir::NodePtr node = ir::MakeNode<ir::ops::TopK>(
       input.GetIrValue(), k,
       XlaHelpers::GetCanonicalDimensionIndex(dim, input.shape().get().rank()),
-      largest, sorted);
+      largest, sorted, stable);
   return std::make_tuple(
       input.CreateFrom(ir::Value(node, 0)),
       input.CreateFrom(ir::Value(node, 1), at::ScalarType::Long));
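One easy-to-miss detail above: `stable` is folded into the node hash via `xla::util::MHash(k, dim, largest, sorted, stable)`. The IR deduplicates nodes by this hash, so a field that changes lowering but is left out of the hash would let a stable and an unstable topk collapse into one node. Below is a minimal sketch of that principle using a hypothetical `HashCombine` helper; the real `xla::util::MHash` internals differ:

```cpp
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iostream>

// Hypothetical stand-in for xla::util::MHash: fold each field into a seed.
size_t HashCombine(size_t seed, uint64_t value) {
  return seed ^ (std::hash<uint64_t>{}(value) + 0x9e3779b97f4a7c15ULL +
                 (seed << 6) + (seed >> 2));
}

size_t TopKHash(int64_t k, int64_t dim, bool largest, bool sorted,
                bool stable) {
  size_t h = 0;
  h = HashCombine(h, k);
  h = HashCombine(h, dim);
  h = HashCombine(h, largest);
  h = HashCombine(h, sorted);
  h = HashCombine(h, stable);  // without this line, stable and unstable
                               // nodes would hash identically and could be
                               // deduplicated into the same computation
  return h;
}

int main() {
  std::cout << (TopKHash(3, 0, true, true, false) !=
                TopKHash(3, 0, true, true, true))
            << "\n";  // expect 1: the flag changes the hash
}
```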
diff --git a/torch_xla/csrc/xla_lower_util.cpp b/torch_xla/csrc/xla_lower_util.cpp
index e1a90226de9..617b75f7b54 100644
--- a/torch_xla/csrc/xla_lower_util.cpp
+++ b/torch_xla/csrc/xla_lower_util.cpp
@@ -348,7 +348,7 @@ std::vector<xla::XlaOp> CreateKthValue(xla::XlaOp input, xla::int64 k,
 
 std::vector<xla::XlaOp> CreateTopK(xla::XlaOp input, xla::int64 k,
                                    xla::int64 dim, bool largest,
-                                   bool /* sorted */) {
+                                   bool /* sorted */, bool stable) {
   // Here 'k' is 1 based (1...).
   const xla::Shape& shape = XlaHelpers::ShapeOfXlaOp(input);
   XLA_CHECK_LE(k, shape.dimensions(dim));
@@ -362,7 +362,7 @@ std::vector<xla::XlaOp> CreateTopK(xla::XlaOp input, xla::int64 k,
           : xla::CreateScalarLtComputation(
                 {shape.element_type(), xla::PrimitiveType::S32},
                 input.builder());
-  xla::XlaOp sort_result = xla::Sort({input, iota}, comparator, dim);
+  xla::XlaOp sort_result = xla::Sort({input, iota}, comparator, dim, stable);
 
   std::vector<xla::int64> start_indices(shape.rank(), 0);
   std::vector<xla::int64> limit_indices(shape.dimensions().begin(),
diff --git a/torch_xla/csrc/xla_lower_util.h b/torch_xla/csrc/xla_lower_util.h
index e63efe08135..ee586b81239 100644
--- a/torch_xla/csrc/xla_lower_util.h
+++ b/torch_xla/csrc/xla_lower_util.h
@@ -16,7 +16,8 @@ std::vector<xla::XlaOp> CreateKthValue(xla::XlaOp input, xla::int64 k,
                                        xla::int64 dim, bool keepdim);
 
 std::vector<xla::XlaOp> CreateTopK(xla::XlaOp input, xla::int64 k,
-                                   xla::int64 dim, bool largest, bool sorted);
+                                   xla::int64 dim, bool largest, bool sorted,
+                                   bool stable);
 
 xla::XlaOp CreateMatMul(xla::XlaOp lhs, xla::XlaOp rhs);
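The lowering above implements top-k as "sort the values together with an iota of positions, then slice the first k along `dim`"; the new `stable` argument is passed straight through as `xla::Sort`'s `is_stable` parameter. Below is a standalone sketch of the same value-plus-iota scheme on a plain 1-D vector, with a hypothetical `TopK1D` helper rather than the PR's actual code:

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <numeric>
#include <utility>
#include <vector>

// Top-k over a 1-D input, mirroring CreateTopK's strategy: pair each value
// with its position (the "iota" operand), sort by value, slice the first k.
std::pair<std::vector<float>, std::vector<int64_t>> TopK1D(
    const std::vector<float>& input, int64_t k, bool largest, bool stable) {
  std::vector<int64_t> order(input.size());
  std::iota(order.begin(), order.end(), 0);  // the iota operand
  auto cmp = [&](int64_t a, int64_t b) {
    return largest ? input[a] > input[b] : input[a] < input[b];
  };
  // xla::Sort(..., /*is_stable=*/stable) corresponds to this choice:
  if (stable) {
    std::stable_sort(order.begin(), order.end(), cmp);
  } else {
    std::sort(order.begin(), order.end(), cmp);
  }
  std::vector<float> values;
  std::vector<int64_t> indices;
  for (int64_t i = 0; i < k; ++i) {  // the slice to the first k entries
    values.push_back(input[order[i]]);
    indices.push_back(order[i]);
  }
  return {values, indices};
}

int main() {
  auto [v, idx] = TopK1D({1.f, 3.f, 3.f, 2.f}, 2, /*largest=*/true,
                         /*stable=*/true);
  // With stable=true the tie between the two 3s resolves to the earlier
  // position first: values {3, 3}, indices {1, 2}.
  std::cout << v[0] << " " << v[1] << " | " << idx[0] << " " << idx[1] << "\n";
}
```

Note that the comparator in `CreateTopK` compares only the value operand; the index operand is just carried along, so tie order among equal values is exactly what the sort's stability (or lack of it) dictates.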