torch.nn.parallel.scatter_gather.gather cannot handle namedtuple as output #50510
Comments
cc @rohan-varma, since this seems like an enhancement on top of #44220.
Hello, I reproduced the error, and I can confirm the author's fix works. However, there is a deprecation warning when moving the scatter output to the CPU.
I did just as the warning instructed (passing `'cpu'` instead of `-1`), and I'm writing a unittest as per the example.
Reproduced code using random tensors:

```python
import torch
import collections

fields = ['a', 'b', 'c']
MyTuple = collections.namedtuple('MyTuple', fields)

a = torch.rand(2 * 3, device=0)
b = torch.rand(2 * 3, device=0)

a_tensors_for_gpu = [a[2 * i : 2 * i + 2].to(0) for i in range(3)]
out3 = MyTuple(a_tensors_for_gpu[0], a_tensors_for_gpu[1], a_tensors_for_gpu[2])
print(out3)

b_tensors_for_gpu = [b[2 * i : 2 * i + 2].to(0) for i in range(3)]
out4 = MyTuple(b_tensors_for_gpu[0], b_tensors_for_gpu[1], b_tensors_for_gpu[2])
print(out4)

outputs = [out3, out4]

from torch.nn.parallel.scatter_gather import gather
out = gather(outputs, 'cpu')
for i, x in enumerate(out):
    print(i, x)
```
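As a torch-free illustration of why a repro like the one above fails before the fix: `gather` rebuilds containers with `type(out)(generator)`, which works for plain tuples and lists but not for a namedtuple, whose constructor takes one positional argument per field. This is a sketch of the failure mode, not PyTorch's actual code.

```python
import collections

MyTuple = collections.namedtuple('MyTuple', ['a', 'b', 'c'])
out = MyTuple(1, 2, 3)

# A namedtuple constructor wants one argument per field, so the
# generic `type(out)(iterable)` rebuild raises a TypeError:
try:
    type(out)(map(lambda pair: pair, zip(out, out)))
except TypeError as e:
    print('TypeError:', e)

# The working rebuild unpacks the fields positionally:
print(type(out)(*(x + y for x, y in zip(out, out))))  # MyTuple(a=2, b=4, c=6)
```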
@ajsanjoaquin Awesome, thanks for working on this! Feel free to put up a PR with the proposed fix. As far as testing on multi-GPU scenarios, when you create the PR, the relevant CI tests will be triggered, and we can also test locally on a multi-GPU machine if need be. Thanks!
@rohan-varma I've made a PR for you to review, in case you haven't seen it. Thanks!
@ajsanjoaquin Thanks! I'll take a look at the PR.
… handle moving output to CPU (#51104)

Summary: Fixes #50510. Allows `torch.nn.parallel.scatter_gather.gather` to accept a list of NamedTuples as input and return a NamedTuple whose elements are tensors. I added the author's fix using the `is_namedtuple` function. While testing this fix, I encountered a deprecation warning instructing me to use `'cpu'` instead of `-1` to move the outputs to the CPU. However, doing this causes an assertion error in the `_get_device_index` function. I solved this by handling the CPU case in the affected `forward` function.

Pull Request resolved: #51104
Reviewed By: albanD
Differential Revision: D26395578
Pulled By: rohan-varma
fbshipit-source-id: 6e98c9ce1d9f1725973c18d24a6554c1bceae465
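The CPU-handling half of the change can be sketched without torch. Per the summary above, `_get_device_index` asserts on non-CUDA targets, so `'cpu'` (and the deprecated `-1`) has to be routed around the index lookup. The function below is an illustrative stand-in under that assumption, not PyTorch's actual source:

```python
def resolve_gather_target(target_device, get_device_index):
    """Illustrative dispatch: route the CPU target around the strict
    CUDA index lookup that would otherwise raise on it.

    `get_device_index` stands in for torch's `_get_device_index`.
    """
    if target_device == 'cpu' or target_device == -1:
        # -1 is the deprecated spelling of the CPU target; a real
        # implementation would emit a DeprecationWarning for it.
        return 'cpu'
    return get_device_index(target_device)

print(resolve_gather_target('cpu', int))  # cpu
print(resolve_gather_target(0, int))      # 0
```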
Done by #51104
@rohan-varma
🐛 Bug
`torch.nn.parallel.scatter_gather.gather` can't gather outputs of type `namedtuple`.

To Reproduce
Steps to reproduce the behavior:
Error message:
Copied from execution on Google Colab.
Expected behavior
A successful gather should return an object of type `MyTuple` with the following values:

Environment
I tested on my server as well as on Google Colab.
Google Colab:
My server:
Additional context
I think it's related to #41327, but probably more general, as `namedtuple` is a built-in Python class. I personally feel that since `scatter_gather.py` already has the `is_namedtuple` function, this should be relatively easy to fix. For example, just add something like below (not thoroughly tested though):

cc @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @agolynski @SciPioneer @H-Huang @mrzzd