
torch.nn.parallel.scatter_gather.gather cannot handle namedtuple as output #50510

Closed
zhmd opened this issue Jan 13, 2021 · 7 comments
Labels: better-engineering (Relatively self-contained tasks for better engineering contributors) · module: ddp (Issues/PRs related to distributed data parallel training) · oncall: distributed (Add this issue/PR to distributed oncall triage queue)

zhmd commented Jan 13, 2021

🐛 Bug

torch.nn.parallel.scatter_gather.gather can't gather outputs of type namedtuple.

To Reproduce

Steps to reproduce the behavior:

import torch
import collections

MyTuple = collections.namedtuple('MyTuple', ['a', 'b', 'c'])
out1 = MyTuple(torch.tensor(1), torch.tensor(2), torch.tensor(3))
out2 = MyTuple(torch.tensor(4), torch.tensor(5), torch.tensor(6))
outputs = [out1, out2]

from torch.nn.parallel.scatter_gather import gather
gather(outputs, 0)  # second argument is target_device; the signature is gather(outputs, target_device, dim=0)

Error message (copied from execution on Google Colab):

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-139-9cc42055abf8> in <module>()
      8 
      9 from torch.nn.parallel.scatter_gather import gather
---> 10 gather(outputs, 0)

1 frames
/usr/local/lib/python3.6/dist-packages/torch/nn/parallel/scatter_gather.py in gather_map(outputs)
     61             return type(out)(((k, gather_map([d[k] for d in outputs]))
     62                               for k in out))
---> 63         return type(out)(map(gather_map, zip(*outputs)))
     64 
     65     # Recursive function calls like this create reference cycles.

TypeError: __new__() missing 2 required positional arguments: 'b' and 'c'

Expected behavior

A successful gather should return an object of type MyTuple with the following values:

MyTuple(a=tensor([1, 4]), b=tensor([2, 5]), c=tensor([3, 6]))
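
For what it's worth, the failure is plain namedtuple behavior rather than anything torch-specific: the line type(out)(map(gather_map, zip(*outputs))) hands the whole map object to the namedtuple's __new__ as the single argument a, leaving b and c unfilled. A namedtuple can only be rebuilt from one iterable via its _make classmethod. A minimal sketch outside of torch:

import collections

MyTuple = collections.namedtuple('MyTuple', ['a', 'b', 'c'])

# Calling the namedtuple type with one iterable binds the whole iterable
# to field 'a' and leaves 'b' and 'c' without arguments:
try:
    MyTuple(zip((1, 2, 3), (4, 5, 6)))
except TypeError as e:
    print(e)  # __new__() missing 2 required positional arguments: 'b' and 'c'

# _make is the namedtuple classmethod that builds an instance from an iterable:
print(MyTuple._make([1, 2, 3]))  # MyTuple(a=1, b=2, c=3)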

Environment

I tested on my server as well as on Google Colab.

Google Colab:

Collecting environment information...
PyTorch version: 1.7.0+cu101
Is debug build: True
CUDA used to build PyTorch: 10.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.5 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: 6.0.0-1ubuntu2 (tags/RELEASE_600/final)
CMake version: version 3.12.0

Python version: 3.6 (64-bit runtime)
Is CUDA available: False
CUDA runtime version: 10.1.243
GPU models and configuration: Could not collect
Nvidia driver version: Could not collect
cuDNN version: /usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.5
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.19.5
[pip3] torch==1.7.0+cu101
[pip3] torchsummary==1.5.1
[pip3] torchtext==0.3.1
[pip3] torchvision==0.8.1+cu101
[conda] Could not collect

My server:

Collecting environment information...
PyTorch version: 1.8.0.dev20210113+cu101
Is debug build: False
CUDA used to build PyTorch: 10.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.4 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: Could not collect
CMake version: version 3.10.2

Python version: 3.8 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: 10.0.130
GPU models and configuration:
GPU 0: Quadro RTX 5000
GPU 1: Quadro RTX 5000
GPU 2: Quadro RTX 5000
GPU 3: Quadro RTX 5000
GPU 4: Quadro RTX 5000
GPU 5: Quadro RTX 5000
GPU 6: Quadro RTX 5000
GPU 7: Quadro RTX 5000

Nvidia driver version: 440.44
cuDNN version: /usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.5
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.18.5
[pip3] torch==1.8.0.dev20210113+cu101
[pip3] torchaudio==0.8.0.dev20210113
[pip3] torchvision==0.9.0.dev20210113+cu101
[conda] blas                      1.0                         mkl
[conda] cudatoolkit               10.1.243             h6bb024c_0
[conda] mkl                       2020.2                      256
[conda] mkl-service               2.3.0            py38he904b0f_0
[conda] mkl_fft                   1.2.0            py38h23d657b_0
[conda] mkl_random                1.1.1            py38h0573a6f_0
[conda] numpy                     1.18.5                   pypi_0    pypi
[conda] torch                     1.8.0.dev20210113+cu101          pypi_0    pypi
[conda] torchaudio                0.8.0.dev20210113          pypi_0    pypi
[conda] torchvision               0.9.0.dev20210113+cu101          pypi_0    pypi

Additional context

I think it's related to #41327, but probably more general, as namedtuple is a built-in Python class. I personally feel that since scatter_gather.py already has the is_namedtuple function, this should be relatively easy to fix. For example, just add something like the following (not thoroughly tested, though):

def gather_map(outputs):
    out = outputs[0]
    if isinstance(out, torch.Tensor):
        return Gather.apply(target_device, dim, *outputs)
    if out is None:
        return None
    if isinstance(out, dict):
        if not all((len(out) == len(d) for d in outputs)):
            raise ValueError('All dicts must have the same number of keys')
        return type(out)(((k, gather_map([d[k] for d in outputs]))
                          for k in out))
    if is_namedtuple(out):
        # _make rebuilds the namedtuple from a single iterable, unlike __new__
        return type(out)._make(map(gather_map, zip(*outputs)))
    return type(out)(map(gather_map, zip(*outputs)))
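
With this change, the repro above returns the expected MyTuple(a=tensor([1, 4]), b=tensor([2, 5]), c=tensor([3, 6])); type(out)._make accepts the single iterable produced by zip(*outputs), whereas calling the namedtuple type directly requires one positional argument per field.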

cc @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @agolynski @SciPioneer @H-Huang @mrzzd

@mrshenli mrshenli added the oncall: distributed label Jan 15, 2021
@pritamdamania87
Contributor

cc @rohan-varma since this seems like an enhancement on top of #44220.

@rohan-varma rohan-varma self-assigned this Jan 22, 2021
@rohan-varma rohan-varma added the better-engineering and module: ddp labels Jan 22, 2021
@ajsanjoaquin
Contributor

Hello, I reproduced the error, and I can confirm the author's fix works. However, there is a deprecation warning when gathering the outputs to the CPU.

pytorch\torch\nn\parallel\comm.py:231: UserWarning: Using -1 to represent CPU tensor is deprecated. 
Please use a device object or string instead, e.g., "cpu".

I did just as the warning instructed (using out = gather(outputs, 'cpu') instead of out = gather(outputs, -1)), but this threw an error. I fixed the affected function by handling the CPU case.
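
For illustration, a minimal sketch of the kind of guard this needs; resolve_gather_target is a hypothetical name and this is not the merged code, just the shape of the special case:

import torch

# Hypothetical sketch: short-circuit the CPU case before resolving a CUDA
# device index, rather than relying on the deprecated "-1 means CPU" sentinel.
def resolve_gather_target(target_device):
    if target_device == 'cpu' or (isinstance(target_device, torch.device)
                                  and target_device.type == 'cpu'):
        return 'cpu'
    if target_device == -1:  # legacy sentinel for CPU
        return 'cpu'
    # Everything else resolves to a CUDA ordinal, e.g. 0 or 'cuda:1' -> 1.
    return torch.device(target_device).index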

I'm writing a unit test following the example in test_cuda.py, but I'm not sure I can run it myself given that my local machine has only 1 GPU... perhaps I can open a draft PR with the changes and someone can run the test on their end?

Reproduced Code using random tensors

import torch
import collections

fields = ['a', 'b', 'c']
MyTuple = collections.namedtuple('MyTuple', fields)

a = torch.rand(2 * 3, device=0)
b = torch.rand(2 * 3, device=0)

a_tensors_for_gpu = [a[2 * i : 2 * i + 2].to(0) for i in range(3)]
out3 = MyTuple(a_tensors_for_gpu[0], a_tensors_for_gpu[1], a_tensors_for_gpu[2])
print(out3)
b_tensors_for_gpu = [b[2 * i : 2 * i + 2].to(0) for i in range(3)]
out4 = MyTuple(b_tensors_for_gpu[0], b_tensors_for_gpu[1], b_tensors_for_gpu[2])
print(out4)
outputs = [out3, out4]

from torch.nn.parallel.scatter_gather import gather
out = gather(outputs, 'cpu')
for i, x in enumerate(out):
    print(i, x)

Output:
MyTuple(a=tensor([0.8345, 0.3308], device='cuda:0'), b=tensor([0.3901, 0.1534], device='cuda:0'), c=tensor([0.2477, 0.7246], device='cuda:0'))
MyTuple(a=tensor([0.2016, 0.4089], device='cuda:0'), b=tensor([0.6584, 0.5463], device='cuda:0'), c=tensor([0.0309, 0.0097], device='cuda:0'))
0 tensor([0.8345, 0.3308, 0.2016, 0.4089])
1 tensor([0.3901, 0.1534, 0.6584, 0.5463])
2 tensor([0.2477, 0.7246, 0.0309, 0.0097])

@rohan-varma
Member

@ajsanjoaquin Awesome, thanks for working on this! Feel free to put up a PR with the proposed fix. As for testing multi-GPU scenarios, the relevant CI tests will be triggered when you create the PR, and we can also test locally on a multi-GPU machine if need be. Thanks!

@ajsanjoaquin
Contributor

@rohan-varma I made a PR for you to review, in case you haven't seen it. Thanks!

@rohan-varma
Member

@ajsanjoaquin Thanks! I'll take a look at the PR

facebook-github-bot pushed a commit that referenced this issue Feb 11, 2021
… handle moving output to CPU (#51104)

Summary:
Fixes #50510

Allows torch.nn.parallel.scatter_gather.gather to accept a list of NamedTuples as input and return a NamedTuple whose elements are tensors. I added the author's fix using the is_namedtuple function.

While testing this fix, I encountered a deprecation warning instructing me to use 'cpu' instead of -1 to move the outputs to the CPU. However, doing this caused an assertion error in the _get_device_index function. I solved this by handling the CPU case in the affected forward function.

Pull Request resolved: #51104

Reviewed By: albanD

Differential Revision: D26395578

Pulled By: rohan-varma

fbshipit-source-id: 6e98c9ce1d9f1725973c18d24a6554c1bceae465
@rohan-varma
Member

Done by #51104

xsacha pushed a commit to xsacha/pytorch that referenced this issue Mar 31, 2021
… handle moving output to CPU (pytorch#51104)

adizhol commented Dec 22, 2021

@rohan-varma
Hi, just a quick question - why is there no handling of regular tuples in gather_map?
