@@ -168,31 +168,29 @@ std::vector<xla::XlaOp> CreateKthValue(const xla::XlaOp& input, xla::int64 k,
168
168
std::vector<xla::XlaOp> CreateTopK (const xla::XlaOp& input, xla::int64 k,
169
169
xla::int64 dim, bool largest,
170
170
bool /* sorted */ ) {
171
- auto identity = [](const xla::XlaOp& op) -> xla::XlaOp { return op; };
172
- auto neg = [](const xla::XlaOp& op) -> xla::XlaOp { return xla::Neg (op); };
173
- auto input_transform = largest ? neg : identity;
174
-
175
171
// Here 'k' is 1 based (1...).
176
172
xla::Shape shape = XlaHelpers::ShapeOfXlaOp (input);
177
173
XLA_CHECK_LE (k, shape.dimensions (dim));
178
174
xla::Shape iota_shape =
179
175
xla::ShapeUtil::MakeShape (xla::PrimitiveType::S32, shape.dimensions ());
180
176
xla::XlaOp iota = xla::Iota (input.builder (), iota_shape, dim);
181
- xla::XlaOp sort_result = xla::Sort (
182
- {input_transform (input), iota},
183
- xla::CreateScalarLtComputation (
184
- {shape.element_type (), xla::PrimitiveType::S32}, input.builder ()),
185
- dim);
177
+ xla::XlaComputation comparator =
178
+ largest ? xla::CreateScalarGtComputation (
179
+ {shape.element_type (), xla::PrimitiveType::S32},
180
+ input.builder ())
181
+ : xla::CreateScalarLtComputation (
182
+ {shape.element_type (), xla::PrimitiveType::S32},
183
+ input.builder ());
184
+ xla::XlaOp sort_result = xla::Sort ({input, iota}, comparator, dim);
186
185
187
186
std::vector<xla::int64> start_indices (shape.rank (), 0 );
188
187
std::vector<xla::int64> limit_indices (shape.dimensions ().begin (),
189
188
shape.dimensions ().end ());
190
189
limit_indices[dim] = k;
191
190
std::vector<xla::int64> strides (shape.rank (), 1 );
192
191
193
- xla::XlaOp values =
194
- input_transform (xla::Slice (xla::GetTupleElement (sort_result, 0 ),
195
- start_indices, limit_indices, strides));
192
+ xla::XlaOp values = xla::Slice (xla::GetTupleElement (sort_result, 0 ),
193
+ start_indices, limit_indices, strides);
196
194
xla::XlaOp indices = xla::Slice (xla::GetTupleElement (sort_result, 1 ),
197
195
start_indices, limit_indices, strides);
198
196
// aten::topk() wants Long tensors as indices.
0 commit comments