18 changes: 11 additions & 7 deletions test/cpp/test_aten_xla_tensor.cpp
@@ -936,16 +936,20 @@ TEST_F(AtenXlaTensorTest, TestSort) {
   for (int k = 1; k <= 3; ++k) {
     for (int dim = 0; dim < 3; ++dim) {
       for (bool descending : {false, true}) {
-        auto b = torch::sort(a, dim, descending);
-        ForEachDevice([&](const torch::Device& device) {
-          torch::Tensor xla_a = CopyToDevice(a, device);
-          auto xla_b = torch::sort(xla_a, dim, descending);
-          AllClose(std::get<0>(b), std::get<0>(xla_b));
-          AllEqual(std::get<1>(b), std::get<1>(xla_b));
-        });
+        for (bool stable : {false, true}) {
+          auto b = torch::sort(a, dim, descending, stable);
+          ForEachDevice([&](const torch::Device& device) {
+            torch::Tensor xla_a = CopyToDevice(a, device);
+            auto xla_b = torch::sort(xla_a, dim, descending, stable);
+            AllClose(std::get<0>(b), std::get<0>(xla_b));
+            AllEqual(std::get<1>(b), std::get<1>(xla_b));
+          });
+        }
       }
     }
   }
   ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
   ExpectCounterChanged("xla::sort", cpp_test::GetIgnoredCounters());
 }
 
 TEST_F(AtenXlaTensorTest, TestSortDescWithMinValue) {
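What the new `stable` flag buys, as an illustration only (it uses the four-argument `torch::sort` overload the test above exercises; the literal values are a made-up example, not output taken from the test):

  torch::Tensor a = torch::tensor({2, 1, 2, 1});
  auto result = torch::sort(a, /*dim=*/0, /*descending=*/false, /*stable=*/true);
  torch::Tensor values = std::get<0>(result);   // {1, 1, 2, 2}
  torch::Tensor indices = std::get<1>(result);  // {1, 3, 0, 2}: equal keys keep
                                                // their original relative order

With stable=false the sorted values are the same, but the relative order of equal elements (and hence their indices) is unspecified.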
8 changes: 5 additions & 3 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -2982,10 +2982,12 @@ at::Tensor AtenXlaType::softshrink_backward(const at::Tensor& grad_out,
 
 std::tuple<at::Tensor, at::Tensor> AtenXlaType::sort(const at::Tensor& self,
                                                      int64_t dim,
-                                                     bool descending) {
+                                                     bool descending,
+                                                     bool stable) {
   XLA_FN_COUNTER("xla::");
-  auto results = XLATensor::topk(bridge::GetXlaTensor(self), self.size(dim),
-                                 dim, descending, true);
+  auto results =
+      XLATensor::topk(bridge::GetXlaTensor(self), self.size(dim), dim,
+                      descending, /*sorted=*/false, /*stable=*/stable);
   return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(results)),
                          bridge::AtenFromXlaTensor(std::get<1>(results)));
 }
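The lowering above reuses XLATensor::topk with k equal to the full size of the sorted dimension: taking the top size(dim) elements of a dimension is the same as sorting it, with `largest` picking the direction. A quick plain-CPU sanity check of that reduction (a sketch, not this PR's code path):

  torch::Tensor x = torch::rand({4, 6});
  auto s = torch::sort(x, /*dim=*/1, /*descending=*/true);
  auto t = torch::topk(x, /*k=*/x.size(1), /*dim=*/1, /*largest=*/true,
                       /*sorted=*/true);
  // std::get<0>(s) and std::get<0>(t) hold the same values; with random
  // floats, ties are vanishingly unlikely, so the indices match as well.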
3 changes: 2 additions & 1 deletion torch_xla/csrc/aten_xla_type.h
@@ -918,7 +918,8 @@ class AtenXlaType {
                               at::Scalar lambda);
 
   static std::tuple<at::Tensor, at::Tensor> sort(const at::Tensor& self,
-                                                 int64_t dim, bool descending);
+                                                 int64_t dim, bool descending,
+                                                 bool stable);
 
   static std::vector<at::Tensor> split(const at::Tensor& self,
                                        int64_t split_size, int64_t dim);
24 changes: 15 additions & 9 deletions torch_xla/csrc/ops/topk.cpp
@@ -11,40 +11,46 @@ namespace ops {
 namespace {
 
 xla::Shape NodeOutputShape(const Value& input, xla::int64 k, xla::int64 dim,
-                           bool largest, bool sorted) {
+                           bool largest, bool sorted, bool stable) {
   auto lower_for_shape_fn =
       [&](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
     return xla::Tuple(operands[0].builder(),
-                      CreateTopK(operands[0], k, dim, largest, sorted));
+                      CreateTopK(operands[0], k, dim, largest, sorted, stable));
   };
   return InferOutputShape({input.shape()}, lower_for_shape_fn);
 }
 
 }  // namespace
 
 TopK::TopK(const Value& input, xla::int64 k, xla::int64 dim, bool largest,
-           bool sorted)
+           bool sorted, bool stable)
     : Node(ir::OpKind(at::aten::topk), {input},
-           [&]() { return NodeOutputShape(input, k, dim, largest, sorted); },
-           /*num_outputs=*/2, xla::util::MHash(k, dim, largest, sorted)),
+           [&]() {
+             return NodeOutputShape(input, k, dim, largest, sorted, stable);
+           },
+           /*num_outputs=*/2,
+           xla::util::MHash(k, dim, largest, sorted, stable)),
       k_(k),
       dim_(dim),
       largest_(largest),
-      sorted_(sorted) {}
+      sorted_(sorted),
+      stable_(stable) {}
 
 NodePtr TopK::Clone(OpList operands) const {
-  return MakeNode<TopK>(operands.at(0), k_, dim_, largest_, sorted_);
+  return MakeNode<TopK>(operands.at(0), k_, dim_, largest_, sorted_, stable_);
 }
 
 XlaOpVector TopK::Lower(LoweringContext* loctx) const {
   xla::XlaOp input = loctx->GetOutputOp(operand(0));
-  return ReturnOps(CreateTopK(input, k_, dim_, largest_, sorted_), loctx);
+  return ReturnOps(CreateTopK(input, k_, dim_, largest_, sorted_, stable_),
+                   loctx);
 }
 
 std::string TopK::ToString() const {
   std::stringstream ss;
   ss << Node::ToString() << ", k=" << k_ << ", dim=" << dim_
-     << ", largest=" << largest_ << ", sorted=" << sorted_;
+     << ", largest=" << largest_ << ", sorted=" << sorted_
+     << ", stable=" << stable_;
   return ss.str();
 }
 
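Note that `stable` is folded into the node hash as well as into ToString. That matters because torch_xla identifies and deduplicates IR nodes partly by this hash, so two TopK nodes differing only in stability must stay distinct. A minimal check written against the same xla::util::MHash call the constructor uses (a hypothetical helper, not part of the PR):

  #include <cassert>
  // Hypothetical: nodes differing only in `stable` must hash differently.
  void CheckTopKHashesDiffer(xla::int64 k, xla::int64 dim, bool largest,
                             bool sorted) {
    assert(xla::util::MHash(k, dim, largest, sorted, /*stable=*/false) !=
           xla::util::MHash(k, dim, largest, sorted, /*stable=*/true));
  }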
5 changes: 4 additions & 1 deletion torch_xla/csrc/ops/topk.h
@@ -9,7 +9,7 @@ namespace ops {
 class TopK : public Node {
  public:
   TopK(const Value& input, xla::int64 k, xla::int64 dim, bool largest,
-       bool sorted);
+       bool sorted, bool stable);
 
   std::string ToString() const override;
 
@@ -25,11 +25,14 @@ class TopK : public Node {
 
   bool sorted() const { return sorted_; }
 
+  bool stable() const { return stable_; }
+
  private:
   xla::int64 k_;
   xla::int64 dim_;
   bool largest_;
   bool sorted_;
+  bool stable_;
 };
 
 }  // namespace ops
3 changes: 2 additions & 1 deletion torch_xla/csrc/tensor.h
@@ -1072,7 +1072,8 @@ class XLATensor {
 
   static std::tuple<XLATensor, XLATensor> topk(const XLATensor& input,
                                                xla::int64 k, xla::int64 dim,
-                                               bool largest, bool sorted);
+                                               bool largest, bool sorted,
+                                               bool stable = false);
 
   // Returns the sum of the elements of the diagonal of the input 2-D matrix.
   static XLATensor trace(const XLATensor& input);
5 changes: 3 additions & 2 deletions torch_xla/csrc/tensor_methods.cpp
@@ -2683,11 +2683,12 @@ XLATensor XLATensor::to(XLATensor& input, c10::optional<Device> device,
 
 std::tuple<XLATensor, XLATensor> XLATensor::topk(const XLATensor& input,
                                                  xla::int64 k, xla::int64 dim,
-                                                 bool largest, bool sorted) {
+                                                 bool largest, bool sorted,
+                                                 bool stable) {
   ir::NodePtr node = ir::MakeNode<ir::ops::TopK>(
       input.GetIrValue(), k,
       XlaHelpers::GetCanonicalDimensionIndex(dim, input.shape().get().rank()),
-      largest, sorted);
+      largest, sorted, stable);
   return std::make_tuple(
       input.CreateFrom(ir::Value(node, 0)),
       input.CreateFrom(ir::Value(node, 1), at::ScalarType::Long));
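GetCanonicalDimensionIndex wraps negative dimensions before the value reaches the IR node, so ir::ops::TopK only ever sees non-negative dimensions. A minimal equivalent of that convention (an assumption that it matches PyTorch's usual dim wrapping; this helper is illustrative, not the torch_xla implementation):

  // e.g. dim=-1 with rank=3 canonicalizes to 2.
  xla::int64 CanonicalDim(xla::int64 dim, xla::int64 rank) {
    return dim < 0 ? dim + rank : dim;
  }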
4 changes: 2 additions & 2 deletions torch_xla/csrc/xla_lower_util.cpp
@@ -348,7 +348,7 @@ std::vector<xla::XlaOp> CreateKthValue(xla::XlaOp input, xla::int64 k,
 
 std::vector<xla::XlaOp> CreateTopK(xla::XlaOp input, xla::int64 k,
                                    xla::int64 dim, bool largest,
-                                   bool /* sorted */) {
+                                   bool /* sorted */, bool stable) {
   // Here 'k' is 1 based (1...).
   const xla::Shape& shape = XlaHelpers::ShapeOfXlaOp(input);
   XLA_CHECK_LE(k, shape.dimensions(dim));
@@ -362,7 +362,7 @@ std::vector<xla::XlaOp> CreateTopK(xla::XlaOp input, xla::int64 k,
                 : xla::CreateScalarLtComputation(
                       {shape.element_type(), xla::PrimitiveType::S32},
                       input.builder());
-  xla::XlaOp sort_result = xla::Sort({input, iota}, comparator, dim);
+  xla::XlaOp sort_result = xla::Sort({input, iota}, comparator, dim, stable);
 
   std::vector<xla::int64> start_indices(shape.rank(), 0);
   std::vector<xla::int64> limit_indices(shape.dimensions().begin(),
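For context, the pattern the two hunks above touch: CreateTopK pairs each value with its position via an Iota of the same shape, sorts values and positions together with one comparator, and now passes `stable` straight through to xla::Sort. A condensed sketch under that reading (not the exact helper; the slicing of the top-k prefix is omitted):

  std::pair<xla::XlaOp, xla::XlaOp> SortWithIndices(xla::XlaOp input,
                                                    const xla::Shape& shape,
                                                    xla::int64 dim, bool largest,
                                                    bool stable) {
    xla::XlaBuilder* builder = input.builder();
    // S32 index tensor of the same shape, counting along the sort dimension.
    xla::XlaOp iota = xla::Iota(
        builder,
        xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, shape.dimensions()),
        dim);
    // Comparator over (value, index) pairs; only the value column is compared,
    // so `stable` is what keeps equal values in their original index order.
    xla::XlaComputation comparator =
        largest ? xla::CreateScalarGtComputation(
                      {shape.element_type(), xla::PrimitiveType::S32}, builder)
                : xla::CreateScalarLtComputation(
                      {shape.element_type(), xla::PrimitiveType::S32}, builder);
    xla::XlaOp sorted = xla::Sort({input, iota}, comparator, dim, stable);
    return {xla::GetTupleElement(sorted, 0), xla::GetTupleElement(sorted, 1)};
  }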
3 changes: 2 additions & 1 deletion torch_xla/csrc/xla_lower_util.h
@@ -16,7 +16,8 @@ std::vector<xla::XlaOp> CreateKthValue(xla::XlaOp input, xla::int64 k,
                                        xla::int64 dim, bool keepdim);
 
 std::vector<xla::XlaOp> CreateTopK(xla::XlaOp input, xla::int64 k,
-                                   xla::int64 dim, bool largest, bool sorted);
+                                   xla::int64 dim, bool largest, bool sorted,
+                                   bool stable);
 
 xla::XlaOp CreateMatMul(xla::XlaOp lhs, xla::XlaOp rhs);