DataParallel (broadcast_coalesced) with complex tensors yields real views #55375

Comments
Any progress on this issue (and the related #57534)? This is currently a big problem for us and it is increasingly hard to work around it...
cc @vitaly-fedyunin @ngimel
It is recommended to use Distributed Data Parallel instead of DataParallel https://pytorch.org/docs/stable/notes/cuda.html#cuda-nn-ddp-instead. DataParallel is in maintenance mode, and bugfixes for it are limited.
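For reference, a minimal hedged sketch of the suggested DistributedDataParallel setup (assumes a single-node torchrun launch, e.g. torchrun --nproc_per_node=2 script.py, and uses a stand-in real-valued model; whether complex parameters are fully supported under DDP is what #57534 tracks):

import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

def main():
    dist.init_process_group("nccl")              # torchrun sets MASTER_ADDR/PORT, RANK, WORLD_SIZE
    local_rank = int(os.environ["LOCAL_RANK"])   # also set by torchrun
    torch.cuda.set_device(local_rank)

    model = nn.Linear(3, 3).to(local_rank)       # stand-in model; see the complex-parameter repros in this thread
    model = DDP(model, device_ids=[local_rank])  # one process per GPU instead of DataParallel's single process

    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    x = torch.rand(100, 3, device=f"cuda:{local_rank}")
    loss = model(x).pow(2).mean()
    loss.backward()                              # DDP all-reduces gradients across processes
    optimizer.step()

    dist.destroy_process_group()

if __name__ == "__main__":
    main()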
fwiw I do want to fix this particular bug, it doesn't feel like it should be hard to fix |
related issue: Lines 509 to 514 in 80797d0
@anjali411 @mrshenli to discuss
import torch
import torch.nn as nn
import torch.nn.parallel

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        param = torch.rand((3, 2))
        param = torch.view_as_complex(param)
        self.param = nn.Parameter(param)  # complex, shape (3,)

    def forward(self, x):
        print(self.param.dtype, self.param.shape)  # correctly prints "torch.complex64 torch.Size([3])"
        return torch.abs(x - self.param)

model = nn.parallel.DataParallel(Model())
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
x = torch.rand(100, 3)
y = model(x)
loss = torch.mean(y ** 2)
loss.backward()
optimizer.step()  # works ok

@Kegnarok looks like this is fixed on the latest master!
@Kegnarok Please feel free to reopen the issue or open a new issue if there are other bugs, but this seems to be fixed on the latest master
This code no longer works for me. I'm on PyTorch 1.11.0+cu113. The error continues to be:

cc @osalpekar
Hi, I know you mentioned earlier that DataParallel will have limited support, but in case this is helpful: I did some experiments with torch 2.0.0 using slightly modified code compared to the above (mainly setting the dtype of the tensors directly to torch.complex64). The issue, which seems to me related to the above, is that for some reason the tensors get cast as real views of the complex dtype. The code is given below:

import torch
import torch.nn as nn
import torch.nn.parallel
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        param = torch.rand((3, ), dtype=torch.complex64)
        self.param = nn.Parameter(param)
        print(f"Constructor: param ({self.param.dtype})")

    def forward(self, x):
        print(f"Forward pass: param ({self.param.dtype}) with shape {self.param.shape}")
        print(f"Forward pass: input ({x.dtype}) with shape {x.shape}")
        return torch.abs(x - self.param)

model = Model()
model = nn.parallel.DataParallel(model)  # <---- For the experiments, this line is either present or commented out
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
device = torch.device('cuda:0')
model.to(device)
x = torch.rand(100, 3, dtype=torch.complex64).to(device)
y = model(x)
loss = torch.mean(y ** 2)
loss.backward()
optimizer.step()
🐛 Bug
Using DataParallel on complex tensors (either parameters or inputs/outputs) yields real views. The expected behavior would be to obtain complex tensors on each replica. Casting the views back to complex leads to an exception during BroadcastBackward. See the original issue (#47330) and PR (#48686); I believe the problem is that broadcast_coalesced calls torch.view_as_real but the corresponding torch.view_as_complex is never called.

Example (here only the model parameters are complex, but there are additional problems if the inputs/outputs are as well):
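The full model reproduction is the code shown earlier in this thread; below is only a minimal hedged sketch of the suspected view_as_real / view_as_complex asymmetry, i.e. how viewing a complex tensor as real changes its dtype and shape unless the inverse call is made:

import torch

p = torch.rand(3, dtype=torch.complex64)
r = torch.view_as_real(p)               # what broadcast_coalesced reportedly does
print(r.dtype, r.shape)                 # torch.float32 torch.Size([3, 2])

restored = torch.view_as_complex(r)     # the inverse call that apparently never happens
print(restored.dtype, restored.shape)   # torch.complex64 torch.Size([3])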
Environment
Workaround
A temporary workaround is to ensure all parameters and model inputs/outputs are real with an additional last dimension of size 2, and to call torch.view_as_complex on them at each forward. In the above example, one simply removes param = torch.view_as_complex(param) in Model.__init__.
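A minimal sketch of that workaround, assuming the same toy Model as above: the parameter is stored as a real (3, 2) tensor so DataParallel only ever broadcasts a real tensor, and it is reinterpreted as complex inside forward.

import torch
import torch.nn as nn

class WorkaroundModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Real storage, shape (3, 2); no complex tensor is ever broadcast.
        self.param = nn.Parameter(torch.rand(3, 2))

    def forward(self, x):
        # Reinterpret the real (3, 2) parameter as a complex tensor of shape (3,).
        param_c = torch.view_as_complex(self.param)
        return torch.abs(x - param_c)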
cc @ezyang @gchanan @zou3519 @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @SciPioneer @H-Huang @kwen2501 @anjali411 @dylanbespalko @mruberry @lezcano @nikitaved @bdhirsh @jbschlosser @agolynski @mrzzd @cbalioglu