diff --git a/test/cpp/test_aten_xla_tensor.cpp b/test/cpp/test_aten_xla_tensor.cpp
index 11d81c69c253..6eeb1849a25d 100644
--- a/test/cpp/test_aten_xla_tensor.cpp
+++ b/test/cpp/test_aten_xla_tensor.cpp
@@ -1101,6 +1101,20 @@ TEST_F(AtenXlaTensorTest, TestTopK) {
   }
 }
 
+TEST_F(AtenXlaTensorTest, TestTopKSymIntStatic) {
+  torch::Tensor a = torch::rand({10, 10}, torch::TensorOptions(torch::kFloat));
+  auto results = torch::topk(a, 5);
+  ForEachDevice([&](const torch::Device& device) {
+    torch::Tensor xla_a = CopyToDevice(a, device);
+    auto xla_results = torch::topk_symint(xla_a, c10::SymInt(5));
+    AllClose(std::get<0>(results), std::get<0>(xla_results));
+    AllClose(std::get<1>(results), std::get<1>(xla_results));
+  });
+
+  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::topk_symint", cpp_test::GetIgnoredCounters());
+}
+
 TEST_F(AtenXlaTensorTest, TestSort) {
   torch::Tensor a = torch::rand({4, 5, 3}, torch::TensorOptions(torch::kFloat));
   for (int k = 1; k <= 3; ++k) {
diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py
index 068954900934..3bf242ec860c 100644
--- a/test/test_dynamic_shapes.py
+++ b/test/test_dynamic_shapes.py
@@ -178,6 +178,50 @@ def test_expand_symint_correctness(self):
     self.assertEqual(t3.shape[0], 2)
     self.assertEqual(expand_out_aten.cpu(), expand_out_xla.cpu())
 
+  def test_topk_symint_ir_1(self):
+    t1 = torch.zeros([5, 2], device=dev)
+    t1[3][0] = 1
+    t1[3][1] = 1
+    t2 = torch.nonzero(t1)
+    t3 = torch.zeros([10, 2], device=dev)
+    values, indices = torch.topk(t3, t2.shape[0], dim=0)
+    self.assertIsInstance(t2.shape[0], torch.SymInt)
+    self.assertIsInstance(values.shape[0], torch.SymInt)
+    self.assertIsInstance(indices.shape[0], torch.SymInt)
+    self.assertEqual(str(t2.shape[0]), '<=10')
+    self.assertEqual(str(values.shape[0]), '<=10')
+    self.assertEqual(str(indices.shape[0]), '<=10')
+    self.assertEqual(t3.shape[1], 2)
+    self.assertEqual(values.shape[1], 2)
+    self.assertEqual(indices.shape[1], 2)
+
+  def test_topk_symint_ir_2(self):
+    t1 = torch.ones(20, device=dev)
+    t1[::3] = 0
+    t2 = torch.ones(10, device=dev)
+    t2[::2] = 0
+    k = torch.nonzero(t2).shape[0]
+    values, indices = torch.topk(t1, k, dim=0)
+    self.assertIsInstance(values.shape[0], torch.SymInt)
+    self.assertIsInstance(indices.shape[0], torch.SymInt)
+    self.assertEqual(str(values.shape[0]), '<=10')
+    self.assertEqual(str(indices.shape[0]), '<=10')
+
+  def test_topk_symint_correctness(self):
+
+    def test_fn(*tensors):
+      torch.manual_seed(0)
+      x = torch.rand(10, 10).to(tensors[0].device)
+      results = []
+      for tensor in tensors:
+        k = torch.nonzero(tensor).shape[0]
+        for dim in range(2):
+          results += list(torch.topk(x, k=k, dim=dim))
+      return results
+
+    self.runAtenTest([torch.randint(0, 2, size=(10,)) for _ in range(5)],
+                     test_fn)
+
 
 if __name__ == '__main__':
   assert os.environ['XLA_EXPERIMENTAL'] != ''
diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp
index 479297348368..94ab60e011a6 100644
--- a/torch_xla/csrc/aten_xla_type.cpp
+++ b/torch_xla/csrc/aten_xla_type.cpp
@@ -2842,13 +2842,22 @@ at::Tensor XLANativeFunctions::threshold_backward(const at::Tensor& grad_output,
                                      threshold.to()));
 }
 
-std::tuple<at::Tensor, at::Tensor> XLANativeFunctions::topk(
-    const at::Tensor& self, int64_t k, int64_t dim, bool largest, bool sorted) {
-  TORCH_LAZY_FN_COUNTER("xla::");
-  auto results = tensor_methods::topk(bridge::GetXlaTensor(self), k, dim,
-                                      largest, sorted, /*stable=*/false);
-  return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(results)),
-                         bridge::AtenFromXlaTensor(std::get<1>(results)));
+std::tuple<at::Tensor, at::Tensor> XLANativeFunctions::topk_symint(
+    const at::Tensor& self, c10::SymInt k, int64_t dim, bool largest,
+    bool sorted) {
+  TORCH_LAZY_FN_COUNTER("xla::");
+  if (!k.is_symbolic()) {
+    auto results =
+        tensor_methods::topk(bridge::GetXlaTensor(self), k.expect_int(), dim,
+                             largest, sorted, /*stable=*/false);
+    return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(results)),
+                           bridge::AtenFromXlaTensor(std::get<1>(results)));
+  } else {
+    auto results = tensor_methods::topk_symint(
+        bridge::GetXlaTensor(self), k, dim, largest, sorted, /*stable=*/false);
+    return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(results)),
+                           bridge::AtenFromXlaTensor(std::get<1>(results)));
+  }
 }
 
 at::Tensor XLANativeFunctions::trace(const at::Tensor& self) {
diff --git a/torch_xla/csrc/ops/topk_symint.cpp b/torch_xla/csrc/ops/topk_symint.cpp
new file mode 100644
index 000000000000..3156aa1b8934
--- /dev/null
+++ b/torch_xla/csrc/ops/topk_symint.cpp
@@ -0,0 +1,67 @@
+#include "torch_xla/csrc/ops/topk_symint.h"
+
+#include "torch_xla/csrc/lowering_context.h"
+#include "torch_xla/csrc/ops/infer_output_shape.h"
+#include "torch_xla/csrc/tensor_util.h"
+#include "torch_xla/csrc/xla_lower_util.h"
+
+namespace torch_xla {
+namespace {
+
+xla::Shape NodeOutputShapeSymInt(const torch::lazy::Value& input,
+                                 int64_t k_upper_bound, int64_t dim,
+                                 bool largest, bool sorted, bool stable) {
+  xla::Shape input_shape = GetXlaShape(input);
+  std::vector<int64_t> dimensions(input_shape.dimensions().begin(),
+                                  input_shape.dimensions().end());
+  XLA_CHECK_LT(dim, input_shape.rank());
+  dimensions[dim] = k_upper_bound;
+  xla::Shape values_shape =
+      xla::ShapeUtil::MakeShape(input_shape.element_type(), dimensions);
+  xla::Shape indices_shape =
+      xla::ShapeUtil::MakeShape(xla::PrimitiveType::S64, dimensions);
+  values_shape.set_dynamic_dimension(dim, true);
+  indices_shape.set_dynamic_dimension(dim, true);
+  return xla::ShapeUtil::MakeTupleShape({values_shape, indices_shape});
+}
+
+}  // namespace
+
+TopKSymInt::TopKSymInt(const torch::lazy::Value& input, const SymIntElements& k,
+                       int64_t dim, bool largest, bool sorted, bool stable)
+    : XlaNode(torch::lazy::OpKind(at::aten::topk),
+              {input, torch::lazy::Value(k.GetSizeNodes().front())},
+              [&]() {
+                return NodeOutputShapeSymInt(input, k.GetUpperBounds().front(),
+                                             dim, largest, sorted, stable);
+              },
+              /*num_outputs=*/2,
+              torch::lazy::MHash(k.GetUpperBounds().front(), dim, largest,
+                                 sorted, stable)),
+      k_upper_bound_(k.GetUpperBounds().front()),
+      dim_(dim),
+      largest_(largest),
+      sorted_(sorted),
+      stable_(stable) {}
+
+XlaOpVector TopKSymInt::Lower(LoweringContext* loctx) const {
+  xla::XlaOp input = loctx->GetOutputOp(operand(0));
+  xla::XlaOp size_op = loctx->GetOutputOp(operand(1));
+  std::vector<xla::XlaOp> results =
+      CreateTopK(input, k_upper_bound_, dim_, largest_, stable_);
+  std::vector<xla::XlaOp> resized_results;
+  std::transform(
+      results.begin(), results.end(), std::back_inserter(resized_results),
+      [&](xla::XlaOp op) { return xla::SetDimensionSize(op, size_op, dim_); });
+  return ReturnOps(resized_results, loctx);
+}
+
+std::string TopKSymInt::ToString() const {
+  std::stringstream ss;
+  ss << XlaNode::ToString() << ", k<=" << k_upper_bound_ << ", dim=" << dim_
+     << ", largest=" << largest_ << ", sorted=" << sorted_
+     << ", stable=" << stable_;
+  return ss.str();
+}
+
+}  // namespace torch_xla
diff --git a/torch_xla/csrc/ops/topk_symint.h b/torch_xla/csrc/ops/topk_symint.h
new file mode 100644
index 000000000000..e0f7854c111b
--- /dev/null
+++ b/torch_xla/csrc/ops/topk_symint.h
@@ -0,0 +1,25 @@
+#pragma once
+
+#include "torch_xla/csrc/ir.h"
+#include "torch_xla/csrc/torch_util.h"
+
+namespace torch_xla {
+
+class TopKSymInt : public XlaNode {
+ public:
+  TopKSymInt(const torch::lazy::Value& input, const SymIntElements& k,
+             int64_t dim, bool largest, bool sorted, bool stable);
+
+  std::string ToString() const override;
+
+  XlaOpVector Lower(LoweringContext* loctx) const override;
+
+ private:
+  int64_t k_upper_bound_;
+  int64_t dim_;
+  bool largest_;
+  bool sorted_;
+  bool stable_;
+};
+
+}  // namespace torch_xla
diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp
index 76fcedc8158b..1bb661ba8903 100644
--- a/torch_xla/csrc/tensor_methods.cpp
+++ b/torch_xla/csrc/tensor_methods.cpp
@@ -121,6 +121,7 @@
 #include "torch_xla/csrc/ops/threshold.h"
 #include "torch_xla/csrc/ops/threshold_backward.h"
 #include "torch_xla/csrc/ops/topk.h"
+#include "torch_xla/csrc/ops/topk_symint.h"
 #include "torch_xla/csrc/ops/triangular_solve.h"
 #include "torch_xla/csrc/ops/uniform.h"
 #include "torch_xla/csrc/ops/unsqueeze.h"
@@ -2505,6 +2506,20 @@ std::tuple<XLATensorPtr, XLATensorPtr> topk(const XLATensorPtr& input,
       input->CreateFrom(torch::lazy::Value(node, 1), at::ScalarType::Long));
 }
 
+std::tuple<XLATensorPtr, XLATensorPtr> topk_symint(const XLATensorPtr& input,
+                                                   c10::SymInt k, int64_t dim,
+                                                   bool largest, bool sorted,
+                                                   bool stable) {
+  SymIntElements k_symint = SymIntElements(k);
+  torch::lazy::NodePtr node = torch::lazy::MakeNode<TopKSymInt>(
+      input->GetIrValue(), k_symint,
+      torch::lazy::GetCanonicalDimensionIndex(dim, input->shape().get().rank()),
+      largest, sorted, stable);
+  return std::make_tuple(
+      input->CreateFrom(torch::lazy::Value(node, 0)),
+      input->CreateFrom(torch::lazy::Value(node, 1), at::ScalarType::Long));
+}
+
 XLATensorPtr trace(const XLATensorPtr& input) {
   auto input_shape_ref = input->shape();
   XLA_CHECK_EQ((*input_shape_ref).rank(), 2)
diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h
index 2ce51c3e4d2e..bbc6b3ef58fc 100644
--- a/torch_xla/csrc/tensor_methods.h
+++ b/torch_xla/csrc/tensor_methods.h
@@ -828,6 +828,11 @@ std::tuple<XLATensorPtr, XLATensorPtr> topk(const XLATensorPtr& input,
                                             bool largest, bool sorted,
                                             bool stable);
 
+std::tuple<XLATensorPtr, XLATensorPtr> topk_symint(const XLATensorPtr& input,
+                                                   c10::SymInt k, int64_t dim,
+                                                   bool largest, bool sorted,
+                                                   bool stable);
+
 // Returns the sum of the elements of the diagonal of the input 2-D matrix.
 XLATensorPtr trace(const XLATensorPtr& input);
 
diff --git a/xla_native_functions.yaml b/xla_native_functions.yaml
index dd583280f392..21eb690878d3 100644
--- a/xla_native_functions.yaml
+++ b/xla_native_functions.yaml
@@ -364,6 +364,7 @@ symint:
   - diagonal_backward
   - narrow_copy
   - select_backward
+  - topk
 autograd:
   - einsum
   - max_pool2d
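
For context, a minimal Python sketch (not part of the patch) of the end-to-end path these changes enable, mirroring test_topk_symint_ir_2 above. It assumes an XLA device obtained via xm.xla_device() and the XLA_EXPERIMENTAL flag that test_dynamic_shapes.py already requires; everything else follows the tests in this diff.

    # Sketch only; assumes torch_xla is installed and XLA_EXPERIMENTAL is set
    # the way test_dynamic_shapes.py requires.
    import torch
    import torch_xla.core.xla_model as xm

    dev = xm.xla_device()

    t = torch.ones(10, device=dev)
    t[::2] = 0
    # nonzero() on XLA yields a bounded dynamic size: shape[0] is a
    # torch.SymInt whose upper bound is t.numel(), printed as '<=10'.
    k = torch.nonzero(t).shape[0]

    x = torch.rand(10, 10, device=dev)
    # A symbolic k routes torch.topk through XLANativeFunctions::topk_symint;
    # a plain Python int takes the static tensor_methods::topk path instead.
    values, indices = torch.topk(x, k, dim=0)
    print(values.shape[0])  # '<=10': outputs keep the bounded symbolic size

On the lowering side, TopKSymInt::Lower runs a static top-k at the upper bound k_upper_bound_ and then applies xla::SetDimensionSize with the runtime size operand, so the outputs' sorted dimension carries the true dynamic extent while the underlying buffer stays padded to the bound.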