Add dynamic support for topk #4644
base: master
Conversation
torch_xla/csrc/ops/topk.cpp
Outdated
xla::Shape values_shape =
    xla::ShapeUtil::MakeShape(input_shape.element_type(), dimensions);
xla::Shape indices_shape =
    xla::ShapeUtil::MakeShape(xla::PrimitiveType::S64, dimensions);
The index shape is actually s32. Can you use https://github.com/pytorch/xla/blob/master/torch_xla/csrc/tensor_util.cpp#L1254? For context, SetDimensionSize and GetDimensionSize currently only handle s32; this is where we have a slight difference with upstream.
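For reference, a minimal sketch of the suggestion applied to the quoted snippet, i.e. building the indices shape with s32 rather than s64 (an illustration only, not necessarily the code that was ultimately committed):

// Sketch: keep the indices shape in s32 so it matches what
// SetDimensionSize/GetDimensionSize currently handle.
xla::Shape values_shape =
    xla::ShapeUtil::MakeShape(input_shape.element_type(), dimensions);
xla::Shape indices_shape =
    xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, dimensions);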
Actually, the static version also returns s64 type, because of an s32 to s64 cast here:
xla/torch_xla/csrc/xla_lower_util.cpp, lines 380 to 383 in 1bbe4da:
// aten::topk() wants Long tensors as indices.
return {values, xla::ConvertElementType(
                    indices, GetDevicePrimitiveType(xla::PrimitiveType::S64,
                                                    /*device=*/nullptr))};
Do I also need to disable this cast?
Thanks for the contribution!
      input->CreateFrom(torch::lazy::Value(node, 1), at::ScalarType::Long));
}

std::tuple<XLATensorPtr, XLATensorPtr> topk_symint(const XLATensorPtr& input,
Quick question: now that we have a symint version of topk, do we still need the original topk?
I basically followed the expand_symint design, which keeps the static expand implementation. I guess the original topk implementation is still needed because the codegen for handling SymInt is slightly different (e.g. SetDimensionSize).
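For context, a rough sketch of the kind of XLA builder call the SymInt path involves; xla::SetDimensionSize is the builder op, but the variable names below (values, indices, dynamic_k, dim) are illustrative and not taken from this PR:

// Sketch: mark the topk output dimension as dynamic, bounded by the static
// size. `values` and `indices` are the xla::XlaOp results of the static topk
// lowering, `dynamic_k` is an s32 scalar xla::XlaOp holding the runtime k,
// and `dim` is the dimension topk was taken over.
xla::XlaOp dynamic_values = xla::SetDimensionSize(values, dynamic_k, dim);
xla::XlaOp dynamic_indices = xla::SetDimensionSize(indices, dynamic_k, dim);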
Yeah, let's keep the static version there.
t1[3][1] = 1
t2 = torch.nonzero(t1)
t3 = torch.zeros([10, 2], device=dev)
values, indices = torch.topk(t3, t2.shape[0], dim=0)
I wonder if your change can handle the case where t3 is dynamic, e.g. t3 = torch.nonzero(t1).
I think this case is already supported by XLA through the dynamic padder (static topk also works with a dynamic t3). I'll add a test case for it later.
                         largest, sorted, stable);
return std::make_tuple(
    input->CreateFrom(torch::lazy::Value(node, 0)),
    input->CreateFrom(torch::lazy::Value(node, 1), at::ScalarType::Long));
Some PyTorch tests require the indices returned by static topk to be int64, otherwise an error is thrown in PyTorch here. I saw nonzero returns int32 indices. Why not just do an int64 (PyTorch) -> int32 (XLA) -> int64 (PyTorch) conversion for handling indices?
So what happens here is that the underlying XLA type will still be xla::s32, but we mark the logical element type as at::ScalarType::Long. This creates a problem because the underlying s32 cannot represent every value of long. We run into issues when we try to cast this kind of tensor (long logical type, real s32 type) to other dtypes, because the Cast op doesn't know how to handle such cases. The right thing to do is probably a manual s64 -> s32 -> s64 cast instead of playing the logical_element_type trick here.
Thanks for the clarification! Do we want to add something like res = torch::lazy::MakeNode<Cast>(res, at::ScalarType::Long) here, or use xla::ConvertElementType(indices, GetDevicePrimitiveType(xla::PrimitiveType::S64, /*device=*/nullptr)) in CreateTopK like before?
Let's handle it inside CreateTopK so it is cleaner.
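A possible sketch of that, reusing the conversion the static path already performs at the end of CreateTopK (only a sketch; the surrounding values/indices variables are those of the existing lowering):

// Sketch: convert the s32 indices produced by the XLA sort/slice back to s64
// inside CreateTopK, instead of tagging the tensor with a Long logical
// element type.
return {values, xla::ConvertElementType(
                    indices, GetDevicePrimitiveType(xla::PrimitiveType::S64,
                                                    /*device=*/nullptr))};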
This PR requires removing torch_pin. Let me ping someone for review.
@ymwangg can the upstream PR merge without this one? If so, we can merge the upstream one first; otherwise we would have to do the two-way pinning.
@JackCaoG Yes, the upstream PR does not depend on anything on xla. |
Companion PR for pytorch/xla#4644. Pull Request resolved: pytorch/pytorch#95015. Approved by: https://github.com/ezyang
Sorry, I overlooked this PR. @vanbasten23 can you take over this PR and rebase? I think we should merge it since it is already here.
This PR adds dynamic shape support for topk.