Conversation

@dlibenzi
Collaborator

No description provided.

@dlibenzi dlibenzi requested a review from asuhan February 25, 2019 00:42
xla::Shape iota_shape =
xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, shape.dimensions());
xla::XlaOp iota = xla::Iota(input.builder(), iota_shape, dim);
// TODO: Remember to add is_stable=true as last Sort() argument when fetching
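The iota-plus-sort pattern above can be emulated outside XLA. A minimal Python sketch (a hypothetical helper, not the XLA API) pairs each value with its generated index so that sorting the values carries the indices along, which is exactly what sorting the iota operand alongside the input achieves:

```python
def sort_with_indices(values, descending=False):
    # Pair each value with its "iota" index, mirroring the XLA pattern
    # of sorting an index tensor together with the value tensor.
    pairs = sorted(enumerate(values), key=lambda p: p[1], reverse=descending)
    sorted_values = [v for _, v in pairs]
    indices = [i for i, _ in pairs]
    return sorted_values, indices

vals, idx = sort_with_indices([3.0, 1.0, 2.0], descending=True)
# vals == [3.0, 2.0, 1.0], idx == [0, 2, 1]
```

Because Python's sorted() is stable, tied values keep their original index order here, which is the property the is_stable TODO is about.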
Contributor

Why would we need stability here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To return indices in a predictable fashion.
Though, I am not sure what the PyTorch semantics are on ties, so it might not matter.
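A small illustration of the point about ties, using Python's sorted() (which is guaranteed stable, even with reverse=True) as a stand-in for a stable XLA Sort: with stability, tied values always surface their lower original index first, so the returned indices are deterministic.

```python
values = [2.0, 1.0, 2.0, 1.0]
# Python's sorted() is stable: equal keys retain their original
# relative order, so among the tied 2.0s index 0 precedes index 2.
order = sorted(range(len(values)), key=lambda i: values[i], reverse=True)
# order == [0, 2, 1, 3]
```

An unstable sort would be free to return [2, 0, 3, 1] instead, which is only a problem if the framework promises a particular tie order.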

Contributor

The documentation doesn't say anything about it, so I'd rather not do it; stable sort is slower.

Collaborator Author

Sure.

@dlibenzi dlibenzi changed the title Added aten::kthvalue() operation. Added aten::kthvalue() and aten::topk operations Feb 25, 2019
@dlibenzi dlibenzi changed the title Added aten::kthvalue() and aten::topk operations Added aten::kthvalue and aten::topk operations Feb 25, 2019

std::vector<xla::XlaOp> CreateTopK(const xla::XlaOp& input, xla::int64 k,
xla::int64 dim, bool largest,
bool /* sorted */) {
Contributor

What is sorted supposed to do?

Collaborator Author

I assumed it was an optimization to return results in order, or in any order.
But now that I think about it, it could be in-order (the default), or in the native order (sorted=False).
Native order is a pain because we have to either throw away the sort result and do a gather on the sorted indices, or shuffle the sort result back into native order.
Let me punt to CPU in the not-sorted case until I work out the best approach.
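The "shuffle the sort result back into native order" option mentioned above can be sketched in plain Python (a hypothetical helper, not the actual lowering): select the top-k positions by value, then re-sort those winners by original index so the output follows the input's order.

```python
def topk_native_order(values, k):
    # Rank positions by value, descending, and keep the top-k winners.
    winners = sorted(range(len(values)), key=lambda i: values[i], reverse=True)[:k]
    # Re-sort the winning positions by original index to restore
    # the input's "native" order.
    winners.sort()
    return [values[i] for i in winners], winners

vals, idx = topk_native_order([-1.3, 0.88, 0.54, 0.22], 2)
# vals == [0.88, 0.54], idx == [1, 2]
```

The extra re-sort is the overhead being weighed here against simply punting to CPU.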

Contributor

Yikes, yes, let's not handle not-sorted until we figure it out.

Collaborator Author

Actually, it seems a bit "whatever".

>>> x = torch.randn(10)
>>> print(x)
tensor([-1.3195, -0.3020,  0.2261, -2.0767, -1.5631,  0.3084, -1.4165, -0.2338,
         0.8787,  0.5441])
>>> y = torch.topk(x, 2, 0, largest=True, sorted=True)
>>> print(y[0])
tensor([0.8787, 0.5441])
>>> y = torch.topk(x, 2, 0, largest=True, sorted=False)
>>> print(y[0])
tensor([0.5441, 0.8787])

With sorted=False I would have expected to see [0.8787, 0.5441] (the native order).
I think sorted=False really is "whatever", and we might be able to remove the punt-to-CPU.

@dlibenzi dlibenzi merged commit a258155 into master Feb 25, 2019