Added aten::kthvalue and aten::topk operations #342
Conversation
torch_xla/csrc/xla_lower_util.cpp
xla::Shape iota_shape =
    xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, shape.dimensions());
xla::XlaOp iota = xla::Iota(input.builder(), iota_shape, dim);
// TODO: Remember to add is_stable=true as last Sort() argument when fetching
Why would we need stability here?
To return indices in a predictable fashion.
Though, I am not sure what the PT semantics are on ties, so it might not matter.
The documentation doesn't say anything about it, so I'd rather not do it; stable sort is slower.
Sure.
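To make the ties point concrete, here is a minimal Python sketch (plain lists, not the actual XLA lowering) of sorting values together with an iota of positions, which is how the lowering recovers indices. With tied values, only a stable sort guarantees the tied positions come back in their original order.

# Minimal sketch: sort the values together with their positions (an "iota"),
# analogous to sorting the input operand alongside xla::Iota.
values = [3.0, 1.0, 3.0, 2.0]
iota = list(range(len(values)))

# Python's sorted() is stable, so the two tied 3.0s keep their original
# relative order and the indices come out as [0, 2, ...]. An unstable sort
# could legally emit [2, 0, ...] instead.
pairs = sorted(zip(values, iota), key=lambda p: p[0], reverse=True)
print([v for v, _ in pairs])  # [3.0, 3.0, 2.0, 1.0]
print([i for _, i in pairs])  # [0, 2, 3, 1]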
torch_xla/csrc/xla_lower_util.cpp
std::vector<xla::XlaOp> CreateTopK(const xla::XlaOp& input, xla::int64 k,
                                   xla::int64 dim, bool largest,
                                   bool /* sorted */) {
What is sorted supposed to do?
I assumed it was an optimization: return the results in sorted order, or in any order.
But now that I think about it, it could mean either sorted order (the default) or the native order of the input (sorted=false).
Native order sucks because we have to either throw away the sort result and do a gather on the sorted indices, or shuffle the sort result back into native order.
Let me pun to CPU for the not-sorted case until I work out the best approach.
Yikes, yes, let's not handle not-sorted until we figure it out.
Actually, it seems a bit "whatever":
>>> x = torch.randn(10)
>>> print(x)
tensor([-1.3195, -0.3020, 0.2261, -2.0767, -1.5631, 0.3084, -1.4165, -0.2338,
0.8787, 0.5441])
>>> y = torch.topk(x, 2, 0, largest=True, sorted=True)
>>> print(y[0])
tensor([0.8787, 0.5441])
>>> y = torch.topk(x, 2, 0, largest=True, sorted=False)
>>> print(y[0])
tensor([0.5441, 0.8787])
With sorted=False I would have expected to see [0.8787, 0.5441] (the native order).
I think with sorted=False the order is really whatever, and we might be able to remove the pun-to-CPU.
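For reference, here is a minimal sketch (1-D input, plain PyTorch, not the actual XLA lowering) of emulating topk with a full sort followed by a slice, which is roughly what CreateTopK does with xla::Sort plus a slice along dim; if the sorted=False order really is arbitrary, returning this sorted output in both cases would be acceptable.

import torch

# Minimal sketch: top-k as a full sort along the dimension followed by a slice.
def topk_via_sort(x, k, largest=True):
    values, indices = torch.sort(x, descending=largest)
    return values[:k], indices[:k]

x = torch.randn(10)
print(topk_via_sort(x, 2))            # always comes back in sorted order
print(torch.topk(x, 2, sorted=True))  # values match the sketch above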