From d180f98512089508acdcb7d904464678d6a66f3a Mon Sep 17 00:00:00 2001
From: Yanming Wang
Date: Thu, 16 Feb 2023 06:18:38 +0000
Subject: [PATCH 1/4] Initial implementation of topk_symint

---
 test/cpp/test_aten_xla_tensor.cpp | 14 ++++++++
 test/test_dynamic_shapes.py       | 31 ++++++++++++++++++
 torch_xla/csrc/aten_xla_type.cpp  | 23 +++++++++----
 torch_xla/csrc/ops/topk.cpp       | 54 +++++++++++++++++++++++++++++++
 torch_xla/csrc/ops/topk.h         | 18 +++++++++++
 torch_xla/csrc/tensor_methods.cpp | 14 ++++++++
 torch_xla/csrc/tensor_methods.h   |  5 +++
 xla_native_functions.yaml         |  1 +
 8 files changed, 153 insertions(+), 7 deletions(-)

diff --git a/test/cpp/test_aten_xla_tensor.cpp b/test/cpp/test_aten_xla_tensor.cpp
index 11d81c69c253..6eeb1849a25d 100644
--- a/test/cpp/test_aten_xla_tensor.cpp
+++ b/test/cpp/test_aten_xla_tensor.cpp
@@ -1101,6 +1101,20 @@ TEST_F(AtenXlaTensorTest, TestTopK) {
   }
 }
 
+TEST_F(AtenXlaTensorTest, TestTopKSymIntStatic) {
+  torch::Tensor a = torch::rand({10, 10}, torch::TensorOptions(torch::kFloat));
+  auto results = torch::topk(a, 5);
+  ForEachDevice([&](const torch::Device& device) {
+    torch::Tensor xla_a = CopyToDevice(a, device);
+    auto xla_results = torch::topk_symint(xla_a, c10::SymInt(5));
+    AllClose(std::get<0>(results), std::get<0>(xla_results));
+    AllClose(std::get<1>(results), std::get<1>(xla_results));
+  });
+
+  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::topk_symint", cpp_test::GetIgnoredCounters());
+}
+
 TEST_F(AtenXlaTensorTest, TestSort) {
   torch::Tensor a = torch::rand({4, 5, 3}, torch::TensorOptions(torch::kFloat));
   for (int k = 1; k <= 3; ++k) {
diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py
index 068954900934..d53df1bdf183 100644
--- a/test/test_dynamic_shapes.py
+++ b/test/test_dynamic_shapes.py
@@ -178,6 +178,37 @@ def test_expand_symint_correctness(self):
     self.assertEqual(t3.shape[0], 2)
     self.assertEqual(expand_out_aten.cpu(), expand_out_xla.cpu())
 
+  def test_topk_symint(self):
+    dev = xm.xla_device()
+    t1 = torch.zeros([5, 2], device=dev)
+    t1[3][0] = 1
+    t1[3][1] = 1
+    t2 = torch.nonzero(t1)
+    t3 = torch.zeros([10, 2], device=dev)
+    values, indices = torch.topk(t3, t2.shape[0], dim=0)
+    self.assertIsInstance(t2.shape[0], torch.SymInt)
+    self.assertIsInstance(values.shape[0], torch.SymInt)
+    self.assertIsInstance(indices.shape[0], torch.SymInt)
+    self.assertEqual(str(t2.shape[0]), '<=10')
+    self.assertEqual(str(values.shape[0]), '<=10')
+    self.assertEqual(str(indices.shape[0]), '<=10')
+    self.assertEqual(t3.shape[1], 2)
+    self.assertEqual(values.shape[1], 2)
+    self.assertEqual(indices.shape[1], 2)
+
+    def test_fn(*tensors):
+      torch.manual_seed(0)
+      x = torch.rand(10, 10).to(tensors[0].device)
+      results = []
+      for tensor in tensors:
+        k = torch.nonzero(tensor).shape[0]
+        for dim in range(2):
+          results += list(torch.topk(x, k=k, dim=dim))
+      return results
+
+    self.runAtenTest([torch.randint(0, 2, size=(10,)) for _ in range(5)],
+                     test_fn)
+
 
 if __name__ == '__main__':
   assert os.environ['XLA_EXPERIMENTAL'] != ''
diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp
index 479297348368..94ab60e011a6 100644
--- a/torch_xla/csrc/aten_xla_type.cpp
+++ b/torch_xla/csrc/aten_xla_type.cpp
@@ -2842,13 +2842,22 @@ at::Tensor XLANativeFunctions::threshold_backward(const at::Tensor& grad_output,
                       threshold.to<double>()));
 }
 
-std::tuple<at::Tensor, at::Tensor> XLANativeFunctions::topk(
-    const at::Tensor& self, int64_t k, int64_t dim, bool largest, bool sorted) {
-  TORCH_LAZY_FN_COUNTER("xla::");
-  auto results = tensor_methods::topk(bridge::GetXlaTensor(self), k, dim,
-                                      largest, sorted, /*stable=*/false);
-  return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(results)),
-                         bridge::AtenFromXlaTensor(std::get<1>(results)));
+std::tuple<at::Tensor, at::Tensor> XLANativeFunctions::topk_symint(
+    const at::Tensor& self, c10::SymInt k, int64_t dim, bool largest,
+    bool sorted) {
+  TORCH_LAZY_FN_COUNTER("xla::");
+  if (!k.is_symbolic()) {
+    auto results =
+        tensor_methods::topk(bridge::GetXlaTensor(self), k.expect_int(), dim,
+                             largest, sorted, /*stable=*/false);
+    return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(results)),
+                           bridge::AtenFromXlaTensor(std::get<1>(results)));
+  } else {
+    auto results = tensor_methods::topk_symint(
+        bridge::GetXlaTensor(self), k, dim, largest, sorted, /*stable=*/false);
+    return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(results)),
+                           bridge::AtenFromXlaTensor(std::get<1>(results)));
+  }
 }
 
 at::Tensor XLANativeFunctions::trace(const at::Tensor& self) {
diff --git a/torch_xla/csrc/ops/topk.cpp b/torch_xla/csrc/ops/topk.cpp
index e47b6a33ef82..ddc6ab6ee909 100644
--- a/torch_xla/csrc/ops/topk.cpp
+++ b/torch_xla/csrc/ops/topk.cpp
@@ -18,6 +18,23 @@ xla::Shape NodeOutputShape(const torch::lazy::Value& input, int64_t k,
   return InferOutputShape({GetXlaShape(input)}, lower_for_shape_fn);
 }
 
+xla::Shape NodeOutputShapeSymInt(const torch::lazy::Value& input,
+                                 int64_t k_upper_bound, int64_t dim,
+                                 bool largest, bool sorted, bool stable) {
+  xla::Shape input_shape = GetXlaShape(input);
+  std::vector<int64_t> dimensions(input_shape.dimensions().begin(),
+                                  input_shape.dimensions().end());
+  XLA_CHECK_LT(dim, input_shape.rank());
+  dimensions[dim] = k_upper_bound;
+  xla::Shape values_shape =
+      xla::ShapeUtil::MakeShape(input_shape.element_type(), dimensions);
+  xla::Shape indices_shape =
+      xla::ShapeUtil::MakeShape(xla::PrimitiveType::S64, dimensions);
+  values_shape.set_dynamic_dimension(dim, true);
+  indices_shape.set_dynamic_dimension(dim, true);
+  return xla::ShapeUtil::MakeTupleShape({values_shape, indices_shape});
+}
+
 }  // namespace
 
 TopK::TopK(const torch::lazy::Value& input, int64_t k, int64_t dim,
@@ -52,4 +69,41 @@ std::string TopK::ToString() const {
   return ss.str();
 }
 
+TopKSymInt::TopKSymInt(const torch::lazy::Value& input, const SymIntElements& k,
+                       int64_t dim, bool largest, bool sorted, bool stable)
+    : XlaNode(torch::lazy::OpKind(at::aten::topk),
+              {input, torch::lazy::Value(k.GetSizeNodes().front())},
+              [&]() {
+                return NodeOutputShapeSymInt(input, k.GetUpperBounds().front(),
+                                             dim, largest, sorted, stable);
+              },
+              /*num_outputs=*/2,
+              torch::lazy::MHash(k.GetUpperBounds().front(), dim, largest,
+                                 sorted, stable)),
+      k_upper_bound_(k.GetUpperBounds().front()),
+      dim_(dim),
+      largest_(largest),
+      sorted_(sorted),
+      stable_(stable) {}
+
+XlaOpVector TopKSymInt::Lower(LoweringContext* loctx) const {
+  xla::XlaOp input = loctx->GetOutputOp(operand(0));
+  xla::XlaOp size_op = loctx->GetOutputOp(operand(1));
+  std::vector<xla::XlaOp> results =
+      CreateTopK(input, k_upper_bound_, dim_, largest_, stable_);
+  std::vector<xla::XlaOp> resized_results;
+  std::transform(
+      results.begin(), results.end(), std::back_inserter(resized_results),
+      [&](xla::XlaOp op) { return xla::SetDimensionSize(op, size_op, dim_); });
+  return ReturnOps(resized_results, loctx);
+}
+
+std::string TopKSymInt::ToString() const {
+  std::stringstream ss;
+  ss << XlaNode::ToString() << ", k<=" << k_upper_bound_ << ", dim=" << dim_
+     << ", largest=" << largest_ << ", sorted=" << sorted_
+     << ", stable=" << stable_;
+  return ss.str();
+}
+
 }  // namespace torch_xla
diff --git a/torch_xla/csrc/ops/topk.h b/torch_xla/csrc/ops/topk.h
index 39598b048a7e..f1b899f97c96 100644
--- a/torch_xla/csrc/ops/topk.h
+++ b/torch_xla/csrc/ops/topk.h
@@ -1,6 +1,7 @@
 #pragma once
 
 #include "torch_xla/csrc/ir.h"
+#include "torch_xla/csrc/torch_util.h"
 
 namespace torch_xla {
 
@@ -33,4 +34,21 @@ class TopK : public XlaNode {
   bool stable_;
 };
 
+class TopKSymInt : public XlaNode {
+ public:
+  TopKSymInt(const torch::lazy::Value& input, const SymIntElements& k,
+             int64_t dim, bool largest, bool sorted, bool stable);
+
+  std::string ToString() const override;
+
+  XlaOpVector Lower(LoweringContext* loctx) const override;
+
+ private:
+  int64_t k_upper_bound_;
+  int64_t dim_;
+  bool largest_;
+  bool sorted_;
+  bool stable_;
+};
+
 }  // namespace torch_xla
diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp
index 76fcedc8158b..79348dcb7428 100644
--- a/torch_xla/csrc/tensor_methods.cpp
+++ b/torch_xla/csrc/tensor_methods.cpp
@@ -2505,6 +2505,20 @@ std::tuple<XLATensorPtr, XLATensorPtr> topk(const XLATensorPtr& input,
       input->CreateFrom(torch::lazy::Value(node, 1), at::ScalarType::Long));
 }
 
+std::tuple<XLATensorPtr, XLATensorPtr> topk_symint(const XLATensorPtr& input,
+                                                   c10::SymInt k, int64_t dim,
+                                                   bool largest, bool sorted,
+                                                   bool stable) {
+  SymIntElements k_symint = SymIntElements(k);
+  torch::lazy::NodePtr node = torch::lazy::MakeNode<TopKSymInt>(
+      input->GetIrValue(), k_symint,
+      torch::lazy::GetCanonicalDimensionIndex(dim, input->shape().get().rank()),
+      largest, sorted, stable);
+  return std::make_tuple(
+      input->CreateFrom(torch::lazy::Value(node, 0)),
+      input->CreateFrom(torch::lazy::Value(node, 1), at::ScalarType::Long));
+}
+
 XLATensorPtr trace(const XLATensorPtr& input) {
   auto input_shape_ref = input->shape();
   XLA_CHECK_EQ((*input_shape_ref).rank(), 2)
diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h
index 2ce51c3e4d2e..bbc6b3ef58fc 100644
--- a/torch_xla/csrc/tensor_methods.h
+++ b/torch_xla/csrc/tensor_methods.h
@@ -828,6 +828,11 @@ std::tuple<XLATensorPtr, XLATensorPtr> topk(const XLATensorPtr& input,
                                             bool largest, bool sorted,
                                             bool stable);
 
+std::tuple<XLATensorPtr, XLATensorPtr> topk_symint(const XLATensorPtr& input,
+                                                   c10::SymInt k, int64_t dim,
+                                                   bool largest, bool sorted,
+                                                   bool stable);
+
 // Returns the sum of the elements of the diagonal of the input 2-D matrix.
 XLATensorPtr trace(const XLATensorPtr& input);
diff --git a/xla_native_functions.yaml b/xla_native_functions.yaml
index dd583280f392..21eb690878d3 100644
--- a/xla_native_functions.yaml
+++ b/xla_native_functions.yaml
@@ -364,6 +364,7 @@ symint:
 - diagonal_backward
 - narrow_copy
 - select_backward
+- topk
 autograd:
 - einsum
 - max_pool2d

From 2cb242cd2b94f957df0fa4794f065ff8d2645304 Mon Sep 17 00:00:00 2001
From: Yanming Wang
Date: Thu, 16 Feb 2023 23:55:31 +0000
Subject: [PATCH 2/4] Add .torch_pin

---
 torch_patches/.torch_pin | 1 +
 1 file changed, 1 insertion(+)
 create mode 100644 torch_patches/.torch_pin

diff --git a/torch_patches/.torch_pin b/torch_patches/.torch_pin
new file mode 100644
index 000000000000..813d6b42c541
--- /dev/null
+++ b/torch_patches/.torch_pin
@@ -0,0 +1 @@
+#95015

From f4156d44856c3ceb9bf46567241b06f376ad4738 Mon Sep 17 00:00:00 2001
From: Yanming Wang
Date: Mon, 20 Feb 2023 20:37:09 +0000
Subject: [PATCH 3/4] Address CR comments

---
 test/test_dynamic_shapes.py        | 17 +++++++-
 torch_xla/csrc/ops/topk.cpp        | 54 ------------------------
 torch_xla/csrc/ops/topk.h          | 18 --------
 torch_xla/csrc/ops/topk_symint.cpp | 67 ++++++++++++++++++++++++++++++
 torch_xla/csrc/ops/topk_symint.h   | 25 +++++++++++
 torch_xla/csrc/tensor_methods.cpp  |  1 +
 6 files changed, 108 insertions(+), 74 deletions(-)
 create mode 100644 torch_xla/csrc/ops/topk_symint.cpp
 create mode 100644 torch_xla/csrc/ops/topk_symint.h

diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py
index d53df1bdf183..3bf242ec860c 100644
--- a/test/test_dynamic_shapes.py
+++ b/test/test_dynamic_shapes.py
@@ -178,8 +178,7 @@ def test_expand_symint_correctness(self):
     self.assertEqual(t3.shape[0], 2)
     self.assertEqual(expand_out_aten.cpu(), expand_out_xla.cpu())
 
-  def test_topk_symint(self):
-    dev = xm.xla_device()
+  def test_topk_symint_ir_1(self):
     t1 = torch.zeros([5, 2], device=dev)
     t1[3][0] = 1
     t1[3][1] = 1
@@ -196,6 +195,20 @@ def test_topk_symint(self):
     self.assertEqual(values.shape[1], 2)
     self.assertEqual(indices.shape[1], 2)
 
+  def test_topk_symint_ir_2(self):
+    t1 = torch.ones(20, device=dev)
+    t1[::3] = 0
+    t2 = torch.ones(10, device=dev)
+    t2[::2] = 0
+    k = torch.nonzero(t2).shape[0]
+    values, indices = torch.topk(t1, k, dim=0)
+    self.assertIsInstance(values.shape[0], torch.SymInt)
+    self.assertIsInstance(indices.shape[0], torch.SymInt)
+    self.assertEqual(str(values.shape[0]), '<=10')
+    self.assertEqual(str(indices.shape[0]), '<=10')
+
+  def test_topk_symint_correctness(self):
+
     def test_fn(*tensors):
       torch.manual_seed(0)
       x = torch.rand(10, 10).to(tensors[0].device)
diff --git a/torch_xla/csrc/ops/topk.cpp b/torch_xla/csrc/ops/topk.cpp
index ddc6ab6ee909..e47b6a33ef82 100644
--- a/torch_xla/csrc/ops/topk.cpp
+++ b/torch_xla/csrc/ops/topk.cpp
@@ -18,23 +18,6 @@ xla::Shape NodeOutputShape(const torch::lazy::Value& input, int64_t k,
   return InferOutputShape({GetXlaShape(input)}, lower_for_shape_fn);
 }
 
-xla::Shape NodeOutputShapeSymInt(const torch::lazy::Value& input,
-                                 int64_t k_upper_bound, int64_t dim,
-                                 bool largest, bool sorted, bool stable) {
-  xla::Shape input_shape = GetXlaShape(input);
-  std::vector<int64_t> dimensions(input_shape.dimensions().begin(),
-                                  input_shape.dimensions().end());
-  XLA_CHECK_LT(dim, input_shape.rank());
-  dimensions[dim] = k_upper_bound;
-  xla::Shape values_shape =
-      xla::ShapeUtil::MakeShape(input_shape.element_type(), dimensions);
-  xla::Shape indices_shape =
-      xla::ShapeUtil::MakeShape(xla::PrimitiveType::S64, dimensions);
-  values_shape.set_dynamic_dimension(dim, true);
-  indices_shape.set_dynamic_dimension(dim, true);
-  return xla::ShapeUtil::MakeTupleShape({values_shape, indices_shape});
-}
-
 }  // namespace
 
 TopK::TopK(const torch::lazy::Value& input, int64_t k, int64_t dim,
@@ -69,41 +52,4 @@ std::string TopK::ToString() const {
   return ss.str();
 }
 
-TopKSymInt::TopKSymInt(const torch::lazy::Value& input, const SymIntElements& k,
-                       int64_t dim, bool largest, bool sorted, bool stable)
-    : XlaNode(torch::lazy::OpKind(at::aten::topk),
-              {input, torch::lazy::Value(k.GetSizeNodes().front())},
-              [&]() {
-                return NodeOutputShapeSymInt(input, k.GetUpperBounds().front(),
-                                             dim, largest, sorted, stable);
-              },
-              /*num_outputs=*/2,
-              torch::lazy::MHash(k.GetUpperBounds().front(), dim, largest,
-                                 sorted, stable)),
-      k_upper_bound_(k.GetUpperBounds().front()),
-      dim_(dim),
-      largest_(largest),
-      sorted_(sorted),
-      stable_(stable) {}
-
-XlaOpVector TopKSymInt::Lower(LoweringContext* loctx) const {
-  xla::XlaOp input = loctx->GetOutputOp(operand(0));
-  xla::XlaOp size_op = loctx->GetOutputOp(operand(1));
-  std::vector<xla::XlaOp> results =
-      CreateTopK(input, k_upper_bound_, dim_, largest_, stable_);
-  std::vector<xla::XlaOp> resized_results;
-  std::transform(
-      results.begin(), results.end(), std::back_inserter(resized_results),
-      [&](xla::XlaOp op) { return xla::SetDimensionSize(op, size_op, dim_); });
-  return ReturnOps(resized_results, loctx);
-}
-
-std::string TopKSymInt::ToString() const {
-  std::stringstream ss;
-  ss << XlaNode::ToString() << ", k<=" << k_upper_bound_ << ", dim=" << dim_
-     << ", largest=" << largest_ << ", sorted=" << sorted_
-     << ", stable=" << stable_;
-  return ss.str();
-}
-
 }  // namespace torch_xla
diff --git a/torch_xla/csrc/ops/topk.h b/torch_xla/csrc/ops/topk.h
index f1b899f97c96..39598b048a7e 100644
--- a/torch_xla/csrc/ops/topk.h
+++ b/torch_xla/csrc/ops/topk.h
@@ -1,7 +1,6 @@
 #pragma once
 
 #include "torch_xla/csrc/ir.h"
-#include "torch_xla/csrc/torch_util.h"
 
 namespace torch_xla {
 
@@ -34,21 +33,4 @@ class TopK : public XlaNode {
   bool stable_;
 };
 
-class TopKSymInt : public XlaNode {
- public:
-  TopKSymInt(const torch::lazy::Value& input, const SymIntElements& k,
-             int64_t dim, bool largest, bool sorted, bool stable);
-
-  std::string ToString() const override;
-
-  XlaOpVector Lower(LoweringContext* loctx) const override;
-
- private:
-  int64_t k_upper_bound_;
-  int64_t dim_;
-  bool largest_;
-  bool sorted_;
-  bool stable_;
-};
-
 }  // namespace torch_xla
diff --git a/torch_xla/csrc/ops/topk_symint.cpp b/torch_xla/csrc/ops/topk_symint.cpp
new file mode 100644
index 000000000000..3156aa1b8934
--- /dev/null
+++ b/torch_xla/csrc/ops/topk_symint.cpp
@@ -0,0 +1,67 @@
+#include "torch_xla/csrc/ops/topk_symint.h"
+
+#include "torch_xla/csrc/lowering_context.h"
+#include "torch_xla/csrc/ops/infer_output_shape.h"
+#include "torch_xla/csrc/tensor_util.h"
+#include "torch_xla/csrc/xla_lower_util.h"
+
+namespace torch_xla {
+namespace {
+
+xla::Shape NodeOutputShapeSymInt(const torch::lazy::Value& input,
+                                 int64_t k_upper_bound, int64_t dim,
+                                 bool largest, bool sorted, bool stable) {
+  xla::Shape input_shape = GetXlaShape(input);
+  std::vector<int64_t> dimensions(input_shape.dimensions().begin(),
+                                  input_shape.dimensions().end());
+  XLA_CHECK_LT(dim, input_shape.rank());
+  dimensions[dim] = k_upper_bound;
+  xla::Shape values_shape =
+      xla::ShapeUtil::MakeShape(input_shape.element_type(), dimensions);
+  xla::Shape indices_shape =
+      xla::ShapeUtil::MakeShape(xla::PrimitiveType::S64, dimensions);
+  values_shape.set_dynamic_dimension(dim, true);
+  indices_shape.set_dynamic_dimension(dim, true);
+  return xla::ShapeUtil::MakeTupleShape({values_shape, indices_shape});
+}
+
+}  // namespace
+
+TopKSymInt::TopKSymInt(const torch::lazy::Value& input, const SymIntElements& k,
+                       int64_t dim, bool largest, bool sorted, bool stable)
+    : XlaNode(torch::lazy::OpKind(at::aten::topk),
+              {input, torch::lazy::Value(k.GetSizeNodes().front())},
+              [&]() {
+                return NodeOutputShapeSymInt(input, k.GetUpperBounds().front(),
+                                             dim, largest, sorted, stable);
+              },
+              /*num_outputs=*/2,
+              torch::lazy::MHash(k.GetUpperBounds().front(), dim, largest,
+                                 sorted, stable)),
+      k_upper_bound_(k.GetUpperBounds().front()),
+      dim_(dim),
+      largest_(largest),
+      sorted_(sorted),
+      stable_(stable) {}
+
+XlaOpVector TopKSymInt::Lower(LoweringContext* loctx) const {
+  xla::XlaOp input = loctx->GetOutputOp(operand(0));
+  xla::XlaOp size_op = loctx->GetOutputOp(operand(1));
+  std::vector<xla::XlaOp> results =
+      CreateTopK(input, k_upper_bound_, dim_, largest_, stable_);
+  std::vector<xla::XlaOp> resized_results;
+  std::transform(
+      results.begin(), results.end(), std::back_inserter(resized_results),
+      [&](xla::XlaOp op) { return xla::SetDimensionSize(op, size_op, dim_); });
+  return ReturnOps(resized_results, loctx);
+}
+
+std::string TopKSymInt::ToString() const {
+  std::stringstream ss;
+  ss << XlaNode::ToString() << ", k<=" << k_upper_bound_ << ", dim=" << dim_
+     << ", largest=" << largest_ << ", sorted=" << sorted_
+     << ", stable=" << stable_;
+  return ss.str();
+}
+
+}  // namespace torch_xla
diff --git a/torch_xla/csrc/ops/topk_symint.h b/torch_xla/csrc/ops/topk_symint.h
new file mode 100644
index 000000000000..e0f7854c111b
--- /dev/null
+++ b/torch_xla/csrc/ops/topk_symint.h
@@ -0,0 +1,25 @@
+#pragma once
+
+#include "torch_xla/csrc/ir.h"
+#include "torch_xla/csrc/torch_util.h"
+
+namespace torch_xla {
+
+class TopKSymInt : public XlaNode {
+ public:
+  TopKSymInt(const torch::lazy::Value& input, const SymIntElements& k,
+             int64_t dim, bool largest, bool sorted, bool stable);
+
+  std::string ToString() const override;
+
+  XlaOpVector Lower(LoweringContext* loctx) const override;
+
+ private:
+  int64_t k_upper_bound_;
+  int64_t dim_;
+  bool largest_;
+  bool sorted_;
+  bool stable_;
+};
+
+}  // namespace torch_xla
diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp
index 79348dcb7428..1bb661ba8903 100644
--- a/torch_xla/csrc/tensor_methods.cpp
+++ b/torch_xla/csrc/tensor_methods.cpp
@@ -121,6 +121,7 @@
 #include "torch_xla/csrc/ops/threshold.h"
 #include "torch_xla/csrc/ops/threshold_backward.h"
 #include "torch_xla/csrc/ops/topk.h"
+#include "torch_xla/csrc/ops/topk_symint.h"
 #include "torch_xla/csrc/ops/triangular_solve.h"
 #include "torch_xla/csrc/ops/uniform.h"
 #include "torch_xla/csrc/ops/unsqueeze.h"

From 1f39f884f2d604d1e054992baf29848234c04d9b Mon Sep 17 00:00:00 2001
From: Yanming Wang
Date: Tue, 28 Feb 2023 04:49:11 +0000
Subject: [PATCH 4/4] Remove torch_pin

---
 torch_patches/.torch_pin | 1 -
 1 file changed, 1 deletion(-)
 delete mode 100644 torch_patches/.torch_pin

diff --git a/torch_patches/.torch_pin b/torch_patches/.torch_pin
deleted file mode 100644
index 813d6b42c541..000000000000
--- a/torch_patches/.torch_pin
+++ /dev/null
@@ -1 +0,0 @@
-#95015
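
A minimal end-to-end sketch of the dynamic topk path this series adds, modeled
on test_topk_symint_ir_1/ir_2 above. This is a sketch under assumptions, not
part of the patches: the XLA_EXPERIMENTAL value shown is illustrative (the test
suite only asserts that the variable is non-empty), and the exact printed form
comes from the '<=10' assertions in the tests.

    import os
    # Assumed flag value for illustration only; dynamic-shape ops are gated
    # behind XLA_EXPERIMENTAL (test_dynamic_shapes.py asserts it is set).
    os.environ.setdefault('XLA_EXPERIMENTAL', 'nonzero:masked_select')

    import torch
    import torch_xla.core.xla_model as xm

    dev = xm.xla_device()
    t = torch.zeros(10, device=dev)
    t[::2] = 1
    # nonzero() yields a tensor whose first dimension is a SymInt bounded
    # by the input size (10 here).
    k = torch.nonzero(t).shape[0]
    # Passing the SymInt k dispatches to XLANativeFunctions::topk_symint;
    # the outputs keep the dynamic dimension via xla::SetDimensionSize.
    values, indices = torch.topk(torch.rand(10).to(dev), k)
    print(values.shape[0])  # a SymInt that prints as '<=10'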