
DataParallel (broadcast_coalesced) with complex tensors yields real views #55375

Open
FlorentinGuth opened this issue Apr 6, 2021 · 11 comments
Labels
high priority · module: complex (Related to complex number support in PyTorch) · module: data parallel · oncall: distributed (Add this issue/PR to distributed oncall triage queue) · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@FlorentinGuth

FlorentinGuth commented Apr 6, 2021

🐛 Bug

Using DataParallel on complex tensors (either parameters or inputs/outputs) yields real views. The expected behavior would be to obtain complex tensors on each replica. Casting the views back to complex leads to an exception during BroadcastBackward. See the original issue (#47330) and PR (#48686); I believe the problem is that broadcast_coalesced calls torch.view_as_real but the corresponding torch.view_as_complex is never called.
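
For reference, the round trip that appears to be missing is straightforward at the tensor level; a minimal sketch of the view_as_real / view_as_complex pairing (not the actual broadcast_coalesced code) looks like this:

import torch

z = torch.rand(3, dtype=torch.complex64)
r = torch.view_as_real(z)       # real view: dtype float32, shape (3, 2)
z2 = torch.view_as_complex(r)   # back to complex: dtype complex64, shape (3,)
assert z2.dtype == torch.complex64 and z2.shape == (3,)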

Example (here only the model parameters are complex, but there are additional problems if the inputs/outputs are as well):

import torch
import torch.nn as nn
import torch.nn.parallel

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        param = torch.rand((3, 2))
        param = torch.view_as_complex(param)
        self.param = nn.Parameter(param)  # complex, shape (3,)

    def forward(self, x):
        print(self.param.dtype, self.param.shape)   # prints "torch.float32 torch.Size([3, 2])" instead of "torch.complex64 torch.Size([3])"
        return torch.abs(x - torch.view_as_complex(self.param))  # view_as_complex necessary to have a complex tensor

model = nn.parallel.DataParallel(Model()).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

x = torch.rand(100, 3).cuda()
y = model(x)

loss = torch.mean(y ** 2)
loss.backward()
optimizer.step()  # Function BroadcastBackward returned an invalid gradient at index 0 - got [3, 2] but expected shape compatible with [3]

Environment

PyTorch version: 1.8.1+cu111
Is debug build: False
CUDA used to build PyTorch: 11.1
ROCM used to build PyTorch: N/A

OS: CentOS Linux release 7.9.2009 (Core) (x86_64)
GCC version: (GCC) 4.8.5 20150623 (Red Hat 4.8.5-44)
Clang version: Could not collect
CMake version: version 2.8.12.2

Python version: 3.8 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration: 
GPU 0: Tesla V100-SXM2-32GB
GPU 1: Tesla V100-SXM2-32GB

Nvidia driver version: 460.32.03
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.19.2
[pip3] numpydoc==1.1.0
[pip3] torch==1.8.1+cu111
[pip3] torchaudio==0.8.1
[pip3] torchvision==0.9.1+cu111
[conda] blas                      1.0                         mkl  
[conda] mkl                       2020.2                      256  
[conda] mkl-service               2.3.0            py38he904b0f_0  
[conda] mkl_fft                   1.2.0            py38h23d657b_0  
[conda] mkl_random                1.1.1            py38h0573a6f_0  
[conda] numpy                     1.19.2           py38h54aff64_0  
[conda] numpy-base                1.19.2           py38hfa32c7d_0  
[conda] numpydoc                  1.1.0                      py_0  
[conda] torch                     1.8.1+cu111              pypi_0    pypi
[conda] torchaudio                0.8.1                    pypi_0    pypi
[conda] torchvision               0.9.1+cu111              pypi_0    pypi

Workaround

A temporary workaround is to ensure that all parameters and model inputs/outputs are real with an additional last dimension of size 2, and to call torch.view_as_complex on them in each forward pass. In the above example, one simply removes param = torch.view_as_complex(param) from Model.__init__.
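
Applied to the model above, the workaround looks roughly like this (a sketch following the description above, not an official fix):

import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        # Keep the parameter real, with a trailing dimension of size 2,
        # so that broadcast_coalesced never sees a complex tensor.
        self.param = nn.Parameter(torch.rand((3, 2)))

    def forward(self, x):
        # Reinterpret as complex only inside forward, after broadcasting.
        param = torch.view_as_complex(self.param)  # complex64, shape (3,)
        return torch.abs(x - param)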

cc @ezyang @gchanan @zou3519 @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @SciPioneer @H-Huang @kwen2501 @anjali411 @dylanbespalko @mruberry @lezcano @nikitaved @bdhirsh @jbschlosser @agolynski @mrzzd @cbalioglu

@mrshenli mrshenli added the module: complex, module: data parallel, triaged, and oncall: distributed labels Apr 7, 2021
@FlorentinGuth
Author

Any progress on this issue (and the related #57534)? This is currently a big problem for us, and it is becoming increasingly hard to work around...

@mrshenli mrshenli removed the oncall: distributed label Jun 30, 2021
@mrshenli
Contributor

cc @vitaly-fedyunin @ngimel for DataParallel

@ngimel
Collaborator

ngimel commented Jun 30, 2021

It is recommended to use DistributedDataParallel instead of DataParallel (https://pytorch.org/docs/stable/notes/cuda.html#cuda-nn-ddp-instead). DataParallel is in maintenance mode, and bug fixes for it are limited.
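
For completeness, the recommended wrapper would look roughly like this (a minimal sketch of the usual DistributedDataParallel setup, launched with torchrun; this is the suggested alternative, not a fix for the complex-dtype handling discussed here):

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# launched with: torchrun --nproc_per_node=2 script.py
dist.init_process_group("nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

model = Model().cuda(local_rank)             # Model as defined in the issue
model = DDP(model, device_ids=[local_rank])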

@ezyang
Contributor

ezyang commented Jul 6, 2021

FWIW, I do want to fix this particular bug; it doesn't feel like it should be hard to fix.

@anjali411
Contributor

related issue:

pytorch/torch/_utils.py, lines 509 to 514 in 80797d0:

def _handle_complex(tensor):
    """
    Returns a real view of a tensor if complex dtype else just the tensor
    need to check if a UninitializedParameter because otherwise checking is_complex is an error for a LazyModule
    """
    return torch.view_as_real(tensor) if not isinstance(tensor,

This doesn't work with conjugate-view tensors.
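
The conjugate-view problem can be seen directly with view_as_real (a small illustration, assuming a PyTorch version with lazy conjugation; the exact error message may differ):

import torch

z = torch.rand(3, dtype=torch.complex64)
zc = z.conj()                      # lazy conjugate view (conj bit set), still complex64
# torch.view_as_real(zc)           # raises a RuntimeError on unresolved conjugated tensors
r = torch.view_as_real(zc.resolve_conj())  # materializing the conjugate first works, but no longer aliases z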

@soulitzer
Contributor

@anjali411 @mrshenli to discuss

@anjali411
Contributor

anjali411 commented Oct 19, 2021

import torch
import torch.nn as nn
import torch.nn.parallel

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        param = torch.rand((3, 2))
        param = torch.view_as_complex(param)
        self.param = nn.Parameter(param)  # complex, shape (3,)

    def forward(self, x):
        print(self.param.dtype, self.param.shape)   # correctly prints "torch.complex64 torch.Size([3])" 
        return torch.abs(x - self.param)  

model = nn.parallel.DataParallel(Model())
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

x = torch.rand(100, 3)
y = model(x)

loss = torch.mean(y ** 2)
loss.backward()
optimizer.step()  # works ok

@Kegnarok looks like this is fixed on the latest master!

@anjali411
Contributor

@Kegnarok Please feel free to reopen the issue or open a new issue if there are other bugs, but this seems to be fixed on the latest master

@fschiffers

[Quoting @anjali411's example and "looks like this is fixed on the latest master!" comment from above.]

This code no longer works for me. I'm on PyTorch 1.11.0+cu113

The error continues to be:

RuntimeError: Function BroadcastBackward returned an invalid gradient at index 0 - got [3, 2] but expected shape compatible with [3]

@anjali411
Contributor

cc. @osalpekar

@osalpekar osalpekar reopened this Jun 20, 2022
@osalpekar osalpekar added the oncall: distributed Add this issue/PR to distributed oncall triage queue label Jun 20, 2022
@jeremyfix

Hi, I know you mentioned earlier that DataParallel will have limited support, but in case this is helpful:

I did some experiments with torch 2.0.0 using slightly modified code compared to the above (mainly setting the dtype of the tensors directly to torch.complex64):

  • If I omit DataParallel (just commenting out the line model = nn.parallel.DataParallel(model)), I get the following output:
Constructor: param (torch.complex64)
Forward pass: param(torch.complex64) with shape torch.Size([3])
Forward pass: input(torch.complex64) with shape torch.Size([100, 3])
  • If I use DataParallel, keeping the code below with the line model = nn.parallel.DataParallel(model), I get:
Constructor: param (torch.complex64)
Forward pass: param(torch.float32) with shape torch.Size([3, 2])
Forward pass: input(torch.float32) with shape torch.Size([50, 3, 2])
Forward pass: param(torch.float32) with shape torch.Size([3, 2])
Forward pass: input(torch.float32) with shape torch.Size([50, 3, 2])
[...]
RuntimeError: Function BroadcastBackward returned an invalid gradient at index 0 - got [3, 2] but expected shape compatible with [3]

The issue, which seems to me related to the above, is that for some reason the tensors get cast to real views of the complex dtype.

The code is given below:

import torch
import torch.nn as nn
import torch.nn.parallel

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        param = torch.rand((3, ), dtype=torch.complex64)
        self.param = nn.Parameter(param)
        print(f"Constructor: param ({self.param.dtype})")

    def forward(self, x):
        print(f"Forward pass: param({self.param.dtype}) with shape {self.param.shape}")
        print(f"Forward pass: input({x.dtype} with shape {x.shape}")
        return torch.abs(x - self.param)  

model = Model()
model = nn.parallel.DataParallel(model)  # <---- For the experiments, this line is either present or commented out
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

device = torch.device('cuda:0')
model.to(device)

x = torch.rand(100, 3, dtype=torch.complex64).to(device)
y = model(x)

loss = torch.mean(y ** 2)
loss.backward()
optimizer.step()
