Conversation

@ymwangg (Contributor) commented Feb 16, 2023

This PR adds dynamic shape support for topk.
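
For context, a minimal usage sketch of the kind of program this enables, adapted from the test snippet discussed further down; tensor names and sizes are illustrative, and it assumes torch_xla's experimental dynamic-shape mode is enabled. The k passed to topk comes from the dynamic shape of a nonzero result, so it is a SymInt.

import torch
import torch_xla.core.xla_model as xm

dev = xm.xla_device()
t1 = torch.zeros(5, 2, device=dev)
t1[3][1] = 1
t2 = torch.nonzero(t1)               # dynamic (bounded) number of rows
t3 = torch.zeros(10, 2, device=dev)
# t2.shape[0] is a SymInt, so this exercises the new topk_symint path
values, indices = torch.topk(t3, t2.shape[0], dim=0)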

@JackCaoG JackCaoG added the dynamism Dynamic Shape Features label Feb 17, 2023
@miladm miladm requested review from miladm and vanbasten23 February 17, 2023 01:40
xla::Shape values_shape =
xla::ShapeUtil::MakeShape(input_shape.element_type(), dimensions);
xla::Shape indices_shape =
xla::ShapeUtil::MakeShape(xla::PrimitiveType::S64, dimensions);
Collaborator:

The index shape is actually s32. Can you use https://github.com/pytorch/xla/blob/master/torch_xla/csrc/tensor_util.cpp#L1254? For context, SetDimensionSize and GetDimensionSize currently only handle s32; this is where we have a slight difference with upstream.

Contributor Author:

Actually, the static version also returns s64 indices, because of an s32-to-s64 cast here:

// aten::topk() wants Long tensors as indices.
return {values, xla::ConvertElementType(
indices, GetDevicePrimitiveType(xla::PrimitiveType::S64,
/*device=*/nullptr))};

Do I also need to disable this cast?

@vanbasten23 (Collaborator):
Thanks for the contribution!

input->CreateFrom(torch::lazy::Value(node, 1), at::ScalarType::Long));
}

std::tuple<XLATensorPtr, XLATensorPtr> topk_symint(const XLATensorPtr& input,
Collaborator:

quick question, now that we have a symint version of topk, do we still need the original topk?

Contributor Author:

I basically followed the expand_symint design, which keeps the static expand implementation. I guess the original topk implementation is still needed because the codegen for handling SymInt is slightly different (e.g. SetDimensionSize).

Collaborator:

Yeah, let's keep the static version there.

t1[3][1] = 1
t2 = torch.nonzero(t1)
t3 = torch.zeros([10, 2], device=dev)
values, indices = torch.topk(t3, t2.shape[0], dim=0)
Collaborator:

I wonder if your change can handle the case where t3 is also dynamic, e.g. t3 = torch.nonzero(t1).

Contributor Author:

I think this case is already supported by XLA through the dynamic padder (the static topk also works with a dynamic t3). I'll add a test case for it later.
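
A hedged sketch of such a test (not the final test code; it assumes the experimental dynamic-shape mode is enabled): here the input to topk is itself produced by nonzero, so both the input and k are dynamically shaped.

import torch
import torch_xla.core.xla_model as xm

dev = xm.xla_device()
t1 = torch.zeros(5, 2, device=dev)
t1[3][1] = 1
t2 = torch.nonzero(t1)            # dynamic number of rows
t3 = torch.nonzero(t1).float()    # the topk input is now dynamically shaped too
values, indices = torch.topk(t3, t2.shape[0], dim=0)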

largest, sorted, stable);
return std::make_tuple(
input->CreateFrom(torch::lazy::Value(node, 0)),
input->CreateFrom(torch::lazy::Value(node, 1), at::ScalarType::Long));
Contributor Author:

Some pytorch tests require the indices returned by the static topk to be int64, otherwise an error is thrown in pytorch here. I saw that nonzero returns int32 indices. Why not just do an int64 (pytorch) -> int32 (xla) -> int64 (pytorch) conversion for handling the indices?

Collaborator:

What happens here is that the underlying XLA type will still be xla::s32, but we mark the logical element type as at::ScalarType::Long. This creates a problem, since the underlying s32 cannot represent all values of long. We also run into issues when we try to cast this kind of tensor (long logical type, real s32 type) to other dtypes, because the Cast op doesn't know how to handle such cases.

The right thing to do is probably a manual s64->s32->s64 cast instead of playing the logical_element_type trick here.
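
A small Python-level illustration of the constraint under discussion (an assumed repro, not code from this PR): callers expect int64 indices from aten::topk, and casting those indices has to work even if the on-device representation is s32.

import torch
import torch_xla.core.xla_model as xm

dev = xm.xla_device()
t = torch.randn(10, 2, device=dev)
values, indices = torch.topk(t, 3, dim=0)
assert indices.dtype == torch.int64  # what the upstream pytorch tests check
casted = indices.to(torch.float32)   # the Cast issue described above shows up here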

Contributor Author:

Thanks for the clarification! Do we want to add something like res = torch::lazy::MakeNode<Cast>(res, at::ScalarType::Long) here, or use xla::ConvertElementType(indices, GetDevicePrimitiveType(xla::PrimitiveType::S64, /*device=*/nullptr)) in CreateTopK like before?

Collaborator:

Let's handle it inside CreateTopK so it is cleaner.

@ymwangg (Contributor Author) commented Feb 23, 2023

This PR requires removing torch_pin. Let me ping someone for review.

@JackCaoG (Collaborator):
@ymwangg can the upstream PR merge without this one? If so, we can merge the upstream one first; otherwise we would have to do the two-way pinning.

@ymwangg (Contributor Author) commented Feb 23, 2023

@JackCaoG Yes, the upstream PR does not depend on anything in xla.

pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request Feb 24, 2023
Companion PR for pytorch/xla#4644.

Pull Request resolved: #95015
Approved by: https://github.com/ezyang
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Feb 25, 2023
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Feb 25, 2023
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Mar 5, 2023
@JackCaoG (Collaborator):
Sorry, I overlooked this PR. @vanbasten23 can you take over this PR and rebase? I think we should merge it since it is already here.

jhavukainen pushed a commit to kulinseth/pytorch that referenced this pull request Mar 15, 2024
Companion PR for pytorch/xla#4644.

Pull Request resolved: pytorch#95015
Approved by: https://github.com/ezyang