Request for `torch.split` to accept a tensor for input `split_size_or_sections` #47479
Comments
The issue is that there's some shady conversion going on here: Line 509 in 3549141

What's requested can be safely applied.
We also have a NumPy-compatible `tensor_split` that should currently throw an error when passed a tensor. We could avoid a BC concern by documenting this function as treating a CPU tensor as a list. This raises more general questions: how should tensors be interpreted when passed where a list is expected, and should this kind of tracing support be added to more functions?
I think it'd be OK (from a UX perspective) to consistently interpret CPU tensors as lists when an operand can take a list, and as a scalar when an operand can only be a scalar. Device types like XLA will probably want the tensor to be passed as an XLA tensor, however; see #31558 and cc @ailzhang. Unfortunately I don't think we can always match these tensors to the device type that will run the operation, because passing a CUDA tensor and converting it to a list would cause cross-device data movement.

For the second question, I suppose tracing will be around long enough that adding this support to a few functions would be OK, especially since if we used `tensor_split` I don't think we'd have any BC concerns. Do you think `tensor_split` would work for you, @edqwerty10?
Hi @mruberry, this is great, thank you for the breakdown. Yes.
OK. We would accept a PR updating `tensor_split` to accept a CPU tensor in place of a list. Note this is consistent with NumPy.

The fastest way would be for you to submit a PR implementing the behavior. If that's a pain you can ping me internally and we can prioritize the request.
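For context, here is a minimal sketch of the NumPy behavior being matched, alongside the analogous `torch.tensor_split` call this thread asks for (assuming a PyTorch build that includes the change; the split positions are arbitrary example values):

```python
import numpy as np
import torch

a = np.arange(8)
# NumPy already accepts an array of split indices in place of a list:
np_parts = np.array_split(a, np.array([2, 5]))  # pieces of length 2, 3, 3

t = torch.arange(8)
# The analogous call, passing a CPU long tensor of indices:
torch_parts = torch.tensor_split(t, torch.tensor([2, 5]))
```

Both calls split at positions 2 and 5 along the first dimension, so the list and tensor forms are interchangeable on CPU.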
Any chance we can allow a CUDA tensor for indices?
No, because the indices tensor is used to define the tensor outputs, and the metadata for tensors lives on the CPU, so the CPU has to be able to access the data in indices.
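Given that constraint, one workaround is simply to copy the indices tensor host-side before the call; only the data being split needs to stay on the device. A hedged sketch (the helper name is hypothetical):

```python
import torch

def split_by_indices(x, idx):
    # The values in `idx` determine the output shapes, and shape
    # metadata lives on the CPU, so tensor_split needs a CPU indices
    # tensor; .cpu() is a no-op if idx is already on the CPU.
    return torch.tensor_split(x, idx.cpu())
```

Only the (small) indices tensor is copied; `x` itself can remain a CUDA tensor.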
…rgument (pytorch#49169)

Summary: Pull Request resolved: pytorch#49169. Trying to solve PR request pytorch#47479. This diff overloads the method `torch.tensor_split` to also accept a tensor for the argument `split_size_or_sections`, which currently accepts a Python list or int. The motivation is to avoid converting a tensor to a list, so that when tracing a model/module the tensor operations can be recorded. The implementation follows the diff that originally added the `tensor_split` method, D24166164 (pytorch@ef4817f).

Test Plan: `buck test caffe2/test:torch -- tensor_split` (https://www.internalfb.com/intern/testinfra/testconsole/testrun/5910974550563805/) and `buck test caffe2/test:others -- tensor_split` (https://www.internalfb.com/intern/testinfra/testconsole/testrun/1688849905082678/)

Reviewed By: mruberry
Differential Revision: D25440885
fbshipit-source-id: ca5d134cfb91fa0efc3dec5257dbc97532eb2d74
Closing since we resolved this request by implementing the functionality in torch.tensor_split, and the request for this functionality in torch.split is a dupe of #16703.
🚀 Feature
For the PyTorch operator `torch.split(tensor, split_size_or_sections, dim)` we would like the input `split_size_or_sections` to also handle tensors. Right now it handles a list of ints or a single int.

Motivation
When tracing a model, we currently do `split_size_or_sections=tensor.tolist()` when calling `torch.split`, but tracing can't record this conversion of a tensor to a list of ints, so the traced model fails on other inputs. This is currently blocking a diff that tries to improve runtime performance by using tensor operations, D24595761.

Pitch
Have `torch.split(tensor, tensor, int)` work so that tracing can properly record the operations for a traced model.

Alternatives
We currently don't have alternatives (I think). We are migrating to a scripted model (which works in this case), but support for tracing is still requested.
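The tracing failure described under Motivation can be sketched roughly as follows (names and values are illustrative): `.tolist()` produces Python ints, which the tracer bakes into the graph as constants, so a trace captured with one set of split sizes silently reuses them for every later input:

```python
import torch

def split_with_list(x, sizes):
    # tracing cannot record .tolist(); the values become constants
    return torch.split(x, sizes.tolist(), dim=0)

traced = torch.jit.trace(split_with_list,
                         (torch.arange(6.0), torch.tensor([2, 4])))
# Calling with different sizes still splits as [2, 4], not [3, 3]:
out = traced(torch.arange(6.0), torch.tensor([3, 3]))
```

By contrast, passing the tensor directly to `torch.tensor_split` lets tracing record the data dependence (noting that its tensor argument holds split indices, not chunk sizes).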
Additional context
Here is a diff, D24595761, that tries to improve performance by using torch operators only but fails for a traced model.