Add dynamic support for topk #4644
New file torch_xla/csrc/ops/topk_symint.cpp:
@@ -0,0 +1,67 @@
#include "torch_xla/csrc/ops/topk_symint.h"

#include "torch_xla/csrc/lowering_context.h"
#include "torch_xla/csrc/ops/infer_output_shape.h"
#include "torch_xla/csrc/tensor_util.h"
#include "torch_xla/csrc/xla_lower_util.h"

namespace torch_xla {
namespace {

xla::Shape NodeOutputShapeSymInt(const torch::lazy::Value& input,
                                 int64_t k_upper_bound, int64_t dim,
                                 bool largest, bool sorted, bool stable) {
  xla::Shape input_shape = GetXlaShape(input);
  std::vector<int64_t> dimensions(input_shape.dimensions().begin(),
                                  input_shape.dimensions().end());
  XLA_CHECK_LT(dim, input_shape.rank());
  dimensions[dim] = k_upper_bound;
  xla::Shape values_shape =
      xla::ShapeUtil::MakeShape(input_shape.element_type(), dimensions);
  xla::Shape indices_shape =
      xla::ShapeUtil::MakeShape(xla::PrimitiveType::S64, dimensions);
  values_shape.set_dynamic_dimension(dim, true);
  indices_shape.set_dynamic_dimension(dim, true);
  return xla::ShapeUtil::MakeTupleShape({values_shape, indices_shape});
}

}  // namespace

TopKSymInt::TopKSymInt(const torch::lazy::Value& input, const SymIntElements& k,
                       int64_t dim, bool largest, bool sorted, bool stable)
    : XlaNode(torch::lazy::OpKind(at::aten::topk),
              {input, torch::lazy::Value(k.GetSizeNodes().front())},
              [&]() {
                return NodeOutputShapeSymInt(input, k.GetUpperBounds().front(),
                                             dim, largest, sorted, stable);
              },
              /*num_outputs=*/2,
              torch::lazy::MHash(k.GetUpperBounds().front(), dim, largest,
                                 sorted, stable)),
      k_upper_bound_(k.GetUpperBounds().front()),
      dim_(dim),
      largest_(largest),
      sorted_(sorted),
      stable_(stable) {}

XlaOpVector TopKSymInt::Lower(LoweringContext* loctx) const {
  xla::XlaOp input = loctx->GetOutputOp(operand(0));
  xla::XlaOp size_op = loctx->GetOutputOp(operand(1));
  std::vector<xla::XlaOp> results =
      CreateTopK(input, k_upper_bound_, dim_, largest_, stable_);
  std::vector<xla::XlaOp> resized_results;
  std::transform(
      results.begin(), results.end(), std::back_inserter(resized_results),
      [&](xla::XlaOp op) { return xla::SetDimensionSize(op, size_op, dim_); });
  return ReturnOps(resized_results, loctx);
}

std::string TopKSymInt::ToString() const {
  std::stringstream ss;
  ss << XlaNode::ToString() << ", k<=" << k_upper_bound_ << ", dim=" << dim_
     << ", largest=" << largest_ << ", sorted=" << sorted_
     << ", stable=" << stable_;
  return ss.str();
}

}  // namespace torch_xla
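The core of Lower is to compute top-k up to the static upper bound and then stamp the real runtime size onto the outputs with xla::SetDimensionSize. As a reader aid (not part of the PR), here is a minimal standalone sketch of that bounded-dynamic builder pattern; the include paths and helper name are assumptions and vary across XLA versions.

// Reader-aid sketch, not part of the PR: the bounded-dynamic pattern used in
// TopKSymInt::Lower, reduced to a standalone XLA builder example.
#include "xla/client/xla_builder.h"  // include path is an assumption
#include "xla/shape_util.h"          // include path is an assumption

xla::XlaOp BoundedDynamicSketch(xla::XlaBuilder* builder) {
  // A result with a static upper bound of 8 elements along dimension 0
  // (stands in for the top-k values computed up to k_upper_bound_).
  xla::XlaOp values = xla::Parameter(
      builder, 0, xla::ShapeUtil::MakeShape(xla::F32, {8}), "values");
  // The runtime size as an S32 scalar (stands in for the lowered SymInt size node).
  xla::XlaOp size = xla::Parameter(
      builder, 1, xla::ShapeUtil::MakeShape(xla::S32, {}), "size");
  // Keep the static bound but mark dimension 0 as dynamic, with runtime size `size`.
  return xla::SetDimensionSize(values, size, /*dimension=*/0);
}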
New file torch_xla/csrc/ops/topk_symint.h:
@@ -0,0 +1,25 @@
#pragma once

#include "torch_xla/csrc/ir.h"
#include "torch_xla/csrc/torch_util.h"

namespace torch_xla {

class TopKSymInt : public XlaNode {
 public:
  TopKSymInt(const torch::lazy::Value& input, const SymIntElements& k,
             int64_t dim, bool largest, bool sorted, bool stable);

  std::string ToString() const override;

  XlaOpVector Lower(LoweringContext* loctx) const override;

 private:
  int64_t k_upper_bound_;
  int64_t dim_;
  bool largest_;
  bool sorted_;
  bool stable_;
};

}  // namespace torch_xla
Modified file, two hunks (the file containing the existing static topk implementation):
@@ -121,6 +121,7 @@
#include "torch_xla/csrc/ops/threshold.h"
#include "torch_xla/csrc/ops/threshold_backward.h"
#include "torch_xla/csrc/ops/topk.h"
#include "torch_xla/csrc/ops/topk_symint.h"
#include "torch_xla/csrc/ops/triangular_solve.h"
#include "torch_xla/csrc/ops/uniform.h"
#include "torch_xla/csrc/ops/unsqueeze.h"
|
|
@@ -2505,6 +2506,20 @@ std::tuple<XLATensorPtr, XLATensorPtr> topk(const XLATensorPtr& input,
      input->CreateFrom(torch::lazy::Value(node, 1), at::ScalarType::Long));
}

std::tuple<XLATensorPtr, XLATensorPtr> topk_symint(const XLATensorPtr& input,
                                                   c10::SymInt k, int64_t dim,
                                                   bool largest, bool sorted,
                                                   bool stable) {
  SymIntElements k_symint = SymIntElements(k);
  torch::lazy::NodePtr node = torch::lazy::MakeNode<TopKSymInt>(
      input->GetIrValue(), k_symint,
      torch::lazy::GetCanonicalDimensionIndex(dim, input->shape().get().rank()),
      largest, sorted, stable);
  return std::make_tuple(
      input->CreateFrom(torch::lazy::Value(node, 0)),
      input->CreateFrom(torch::lazy::Value(node, 1), at::ScalarType::Long));
}

XLATensorPtr trace(const XLATensorPtr& input) {
  auto input_shape_ref = input->shape();
  XLA_CHECK_EQ((*input_shape_ref).rank(), 2)

Review thread on topk_symint:
Comment: Quick question, now that we have a symint version of …
Reply: I basically followed the …
Reply: Yeah, let's keep the static version there.

Review thread on the indices dtype (the CreateFrom with at::ScalarType::Long):
Comment: Some PyTorch tests require the static topk's returned indices type to be int64, otherwise it throws an error in PyTorch here. I saw …
Reply: So what happened here is that the underlying XLA type will still be … The right thing to do is probably a manual s64->s32->s64 cast instead of playing the trick of …
Reply: Thanks for the clarification! Do we want to add something like …
Reply: Let's handle it inside …
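For reference, here is a sketch of the "manual s64->s32->s64 cast" mentioned in the thread above, expressed at the XLA builder level. This is illustrative only: the thread leaves the actual fix to be handled elsewhere, and the helper name is made up.

// Illustrative only, not code from the PR: the manual s64->s32->s64 round trip
// the review thread mentions, so that the XLA element type of the indices
// really is S64 rather than only being labeled Long on the PyTorch side.
xla::XlaOp RoundTripIndicesThroughS32(xla::XlaOp s64_indices) {
  // Narrow to S32, the underlying type the thread refers to ...
  xla::XlaOp s32 = xla::ConvertElementType(s64_indices, xla::PrimitiveType::S32);
  // ... then widen back to S64 to match at::ScalarType::Long.
  return xla::ConvertElementType(s32, xla::PrimitiveType::S64);
}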
|
Review thread on handling dynamic inputs:
Comment: I wonder if your change can handle the case where t3 is dynamic, e.g. t3 = torch.nonzero(t1).
Reply: I think this feature is already supported by XLA through the dynamic padder (static topk also works with a dynamic t3). I'll add a test case for it later.
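To make the scenario in this thread concrete, here is a rough sketch of such a case at the ATen level. This is a hedged sketch only: the helper name, sample values, and device handling are assumptions, and the eventual test may well live in the Python dynamic-shape test suite instead.

// Rough sketch of the case discussed above, not an actual test from the PR:
// top-k over an input whose leading dimension is data-dependent because it
// comes from nonzero(). Device setup is an assumption.
#include <torch/torch.h>
#include <tuple>

std::tuple<torch::Tensor, torch::Tensor> TopkOverDynamicInputSketch(
    const torch::Device& xla_device) {
  torch::Tensor t1 =
      torch::tensor({0, 3, 0, 5, 7}, torch::kLong).to(xla_device);
  // t3's first dimension is dynamic: it depends on how many entries of t1
  // are non-zero (here 3, but that is only known at runtime).
  torch::Tensor t3 = torch::nonzero(t1);
  // Run top-k along the dynamically sized dimension; k itself is static here.
  return torch::topk(t3.squeeze(1), /*k=*/2, /*dim=*/0,
                     /*largest=*/true, /*sorted=*/true);
}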