indices_or_sections parameter of tensor_split() doesn't work while sections, indices or tensor_indices_or_sections parameter works #127010

Open
hyperkai opened this issue May 23, 2024 · 0 comments
Labels: module: python frontend · topic: fuzzer · triaged


hyperkai commented May 23, 2024

📚 The doc issue

The documentation of tensor_split() lists an indices_or_sections parameter, as shown below, but the function does not accept a keyword argument by that name:

indices_or_sections (Tensor, int or list or tuple of ints) –
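The documented description does match positional calls, though: when the second argument is passed positionally, the binding appears to pick the overload from the argument's type. A minimal sketch, assuming a PyTorch version with the same overloads as in this report:

```python
import torch

my_tensor = torch.tensor([[0, 1, 2, 3],
                          [4, 5, 6, 7],
                          [8, 9, 10, 11]])

# Positional calls: no keyword name is involved, so the argument's type
# decides which overload is used.
by_int = torch.tensor_split(my_tensor, 1)                        # int -> sections
by_tuple = torch.tensor_split(my_tensor, (1, 2))                 # tuple -> indices
by_tensor = torch.tensor_split(my_tensor, torch.tensor([1, 2]))  # tensor -> tensor_indices_or_sections

print(len(by_int), len(by_tuple), len(by_tensor))
```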

In practice, passing indices_or_sections as a keyword argument raises an error, while the keyword arguments sections (int), indices (tuple or list of ints) and tensor_indices_or_sections (tensor of ints) all work, as shown below:

import torch

my_tensor = torch.tensor([[0, 1, 2, 3],
                          [4, 5, 6, 7],
                          [8, 9, 10, 11]])
torch.tensor_split(input=my_tensor, indices_or_sections=1) # Error
torch.tensor_split(input=my_tensor, indices_or_sections=(1, 2)) # Error
torch.tensor_split(input=my_tensor, indices_or_sections=torch.tensor(1)) # Error
torch.tensor_split(input=my_tensor, indices_or_sections=torch.tensor([1, 2])) # Error

torch.tensor_split(input=my_tensor, sections=1)
# (tensor([[0, 1, 2, 3],
#          [4, 5, 6, 7],
#          [8, 9, 10, 11]]),)

torch.tensor_split(input=my_tensor, indices=(1, 2))
# (tensor([[0, 1, 2, 3]]),
#  tensor([[4, 5, 6, 7]]),
#  tensor([[8, 9, 10, 11]]))

torch.tensor_split(input=my_tensor, tensor_indices_or_sections=torch.tensor(1))
# (tensor([[0, 1, 2, 3],
#          [4, 5, 6, 7],
#          [8, 9, 10, 11]]),)

torch.tensor_split(input=my_tensor, tensor_indices_or_sections=torch.tensor([1, 2]))
# (tensor([[0, 1, 2, 3]]),
#  tensor([[4, 5, 6, 7]]),
#  tensor([[8, 9, 10, 11]]))
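Because the first four calls above stop the script at the first error, here is a hedged sketch that tries each keyword name in turn and records which ones are accepted. It assumes an unrecognized keyword surfaces as a TypeError, which is how PyTorch's Python argument parser reports an invalid combination of arguments; the result for indices_or_sections may differ on other PyTorch versions.

```python
import torch

my_tensor = torch.arange(12).reshape(3, 4)  # same values as in the report

# Keyword name -> an argument of the type that name is meant to take.
candidates = {
    "indices_or_sections": 1,                            # name used in the docs
    "sections": 1,                                       # int overload
    "indices": (1, 2),                                   # tuple/list overload
    "tensor_indices_or_sections": torch.tensor([1, 2]),  # tensor overload
}

accepted = {}
for name, value in candidates.items():
    try:
        torch.tensor_split(input=my_tensor, **{name: value})
        accepted[name] = True
    except TypeError:
        # No overload matched this keyword name.
        accepted[name] = False

print(accepted)
```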

Suggest a potential alternative/fix

So the documentation of tensor_split() should list the parameters as follows:

indices (tuple of int or list of int), sections (int) or tensor_indices_or_sections (tensor of int) –
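Until the docs change, a caller who wants a single argument whose type varies could route it to the matching keyword. tensor_split_by below is a hypothetical helper, not part of PyTorch; it only mirrors the type-to-keyword mapping proposed above:

```python
import torch
from typing import Sequence, Tuple, Union


def tensor_split_by(
    input: torch.Tensor,
    indices_or_sections: Union[torch.Tensor, int, Sequence[int]],
    dim: int = 0,
) -> Tuple[torch.Tensor, ...]:
    """Hypothetical helper: send the documented argument to the keyword
    that tensor_split() actually accepts for its type."""
    if isinstance(indices_or_sections, torch.Tensor):
        # Tensor of ints -> tensor_indices_or_sections
        return torch.tensor_split(
            input, tensor_indices_or_sections=indices_or_sections, dim=dim
        )
    if isinstance(indices_or_sections, int):
        # Single int -> sections
        return torch.tensor_split(input, sections=indices_or_sections, dim=dim)
    # Tuple or list of ints -> indices
    return torch.tensor_split(input, indices=list(indices_or_sections), dim=dim)
```

For example, tensor_split_by(my_tensor, (1, 2)) should return the same three one-row tensors as the indices=(1, 2) call in the report.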

cc @albanD

bdhirsh added the triaged and module: python frontend labels May 23, 2024