Skip to content

Commit e0eff17

Browse files
committed
Fix descending sort when minimum values are present
Applying an unary minus transformation on the input overflows when the minimum values are present. Use a greater than comparator instead. Fixes #865.
1 parent f7fc05a commit e0eff17

File tree

2 files changed

+23
-12
lines changed

2 files changed

+23
-12
lines changed

test/cpp/test_aten_xla_tensor.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -851,6 +851,19 @@ TEST_F(AtenXlaTensorTest, TestSort) {
851851
}
852852
}
853853

854+
TEST_F(AtenXlaTensorTest, TestSortDescWithMinValue) {
855+
std::vector<int8_t> values{-128, 100};
856+
torch::Tensor input =
857+
torch::tensor(values, torch::TensorOptions(torch::kChar));
858+
auto output = torch::sort(input, /*dim=*/0, /*descending=*/true);
859+
ForEachDevice([&](const torch::Device& device) {
860+
torch::Tensor xla_input = CopyToDevice(input, device);
861+
auto xla_output = torch::sort(xla_input, /*dim=*/0, /*descending=*/true);
862+
AllEqual(std::get<0>(output), std::get<0>(xla_output));
863+
AllEqual(std::get<1>(output), std::get<1>(xla_output));
864+
});
865+
}
866+
854867
TEST_F(AtenXlaTensorTest, TestArgSort) {
855868
torch::Tensor a = torch::rand({4, 5, 3}, torch::TensorOptions(torch::kFloat));
856869
for (int k = 1; k <= 3; ++k) {

torch_xla/csrc/xla_lower_util.cpp

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -168,31 +168,29 @@ std::vector<xla::XlaOp> CreateKthValue(const xla::XlaOp& input, xla::int64 k,
168168
std::vector<xla::XlaOp> CreateTopK(const xla::XlaOp& input, xla::int64 k,
169169
xla::int64 dim, bool largest,
170170
bool /* sorted */) {
171-
auto identity = [](const xla::XlaOp& op) -> xla::XlaOp { return op; };
172-
auto neg = [](const xla::XlaOp& op) -> xla::XlaOp { return xla::Neg(op); };
173-
auto input_transform = largest ? neg : identity;
174-
175171
// Here 'k' is 1 based (1...).
176172
xla::Shape shape = XlaHelpers::ShapeOfXlaOp(input);
177173
XLA_CHECK_LE(k, shape.dimensions(dim));
178174
xla::Shape iota_shape =
179175
xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, shape.dimensions());
180176
xla::XlaOp iota = xla::Iota(input.builder(), iota_shape, dim);
181-
xla::XlaOp sort_result = xla::Sort(
182-
{input_transform(input), iota},
183-
xla::CreateScalarLtComputation(
184-
{shape.element_type(), xla::PrimitiveType::S32}, input.builder()),
185-
dim);
177+
xla::XlaComputation comparator =
178+
largest ? xla::CreateScalarGtComputation(
179+
{shape.element_type(), xla::PrimitiveType::S32},
180+
input.builder())
181+
: xla::CreateScalarLtComputation(
182+
{shape.element_type(), xla::PrimitiveType::S32},
183+
input.builder());
184+
xla::XlaOp sort_result = xla::Sort({input, iota}, comparator, dim);
186185

187186
std::vector<xla::int64> start_indices(shape.rank(), 0);
188187
std::vector<xla::int64> limit_indices(shape.dimensions().begin(),
189188
shape.dimensions().end());
190189
limit_indices[dim] = k;
191190
std::vector<xla::int64> strides(shape.rank(), 1);
192191

193-
xla::XlaOp values =
194-
input_transform(xla::Slice(xla::GetTupleElement(sort_result, 0),
195-
start_indices, limit_indices, strides));
192+
xla::XlaOp values = xla::Slice(xla::GetTupleElement(sort_result, 0),
193+
start_indices, limit_indices, strides);
196194
xla::XlaOp indices = xla::Slice(xla::GetTupleElement(sort_result, 1),
197195
start_indices, limit_indices, strides);
198196
// aten::topk() wants Long tensors as indices.

0 commit comments

Comments
 (0)