🐛 Describe the bug
I am using PyTorch 1.10. This is not the code I am working on, but it reproduces the error.
import torch
import os
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import numpy as np
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP

def example(rank, world_size):
    # create default process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    # create local model with complex-valued weights
    model = nn.Linear(10, 10, dtype=torch.cfloat).to(rank)
    # construct DDP model
    ddp_model = DDP(model, device_ids=[rank])
    # define loss function and optimizer
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
    # forward pass (input is complex to match the layer's dtype)
    outputs = ddp_model(torch.randn(20, 10, dtype=torch.cfloat).to(rank))
    labels = torch.randn(20, 10).to(rank)
    # backward pass (abs() maps the complex outputs to real values for MSELoss)
    loss_fn(torch.abs(outputs), labels).backward()
    # update parameters
    optimizer.step()
    print(outputs)

def main():
    world_size = 2
    mp.spawn(example,
             args=(world_size,),
             nprocs=world_size,
             join=True)

if __name__ == "__main__":
    os.environ['MASTER_ADDR'] = 'localhost'
    # pick a random rendezvous port to avoid collisions between runs
    os.environ['MASTER_PORT'] = str(np.random.randint(10000, 20000))
    main()
Output
RuntimeError: Input tensor data type is not supported for NCCL process group: ComplexFloat
The NCCL backend should support complex dtypes, since the gradients of complex-valued models need to be synchronized during backpropagation.
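Until then, a possible user-level workaround is to skip DDP's reducer and average gradients by hand, reinterpreting each complex gradient as a real tensor with torch.view_as_real before the all-reduce (view_as_real returns a float view that shares storage with the complex tensor). A minimal sketch; the helper name average_complex_grads is mine, not a PyTorch API:

import torch
import torch.distributed as dist

def average_complex_grads(model, world_size):
    # Average gradients across ranks after backward(), bypassing DDP's reducer.
    for p in model.parameters():
        if p.grad is None:
            continue
        # NCCL rejects ComplexFloat, so reduce the real view instead:
        # view_as_real() shares storage with the complex gradient and exposes
        # a trailing (real, imag) dimension of dtype float32.
        buf = torch.view_as_real(p.grad) if p.grad.is_complex() else p.grad
        dist.all_reduce(buf, op=dist.ReduceOp.SUM)
        buf /= world_size

The same trick would presumably be needed at startup, broadcasting the initial parameters from rank 0 (dist.broadcast on torch.view_as_real(p.data)) so that all ranks start from identical weights; this sketch assumes contiguous complex64 parameters and the collectives as they exist in 1.10.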
Versions
PyTorch version: 1.10.1
Is debug build: False
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A
OS: NVIDIA DGX Station (x86_64)
GCC version: (GCC) 4.8.5 20150623 (Red Hat 4.8.5-44)
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.17
Python version: 3.8.12 (default, Oct 12 2021, 13:49:34) [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-3.10.0-1160.15.2.el7.x86_64-x86_64-with-glibc2.17
Is CUDA available: True
CUDA runtime version: Could not collect
Nvidia driver version: 450.102.04
cuDNN version: Probably one of the following:
/usr/lib64/libcudnn.so.7.6.3
/usr/lib64/libcudnn.so.8.0.5
/usr/lib64/libcudnn_adv_infer.so.8.0.5
/usr/lib64/libcudnn_adv_train.so.8.0.5
/usr/lib64/libcudnn_cnn_infer.so.8.0.5
/usr/lib64/libcudnn_cnn_train.so.8.0.5
/usr/lib64/libcudnn_ops_infer.so.8.0.5
/usr/lib64/libcudnn_ops_train.so.8.0.5
HIP runtime version: N/A
MIOpen runtime version: N/A
Versions of relevant libraries:
[pip3] numpy==1.21.2
[pip3] torch==1.10.1
[pip3] torchaudio==0.10.1
[pip3] torchvision==0.11.2
[conda] blas 1.0 mkl
[conda] cudatoolkit 10.2.89 hfd86e86_1
[conda] ffmpeg 4.3 hf484d3e_0 pytorch
[conda] mkl 2021.4.0 h06a4308_640
[conda] mkl-service 2.4.0 py38h7f8727e_0
[conda] mkl_fft 1.3.1 py38hd3c417c_0
[conda] mkl_random 1.2.2 py38h51133e4_0
[conda] numpy 1.21.2 py38h20f2e39_0
[conda] numpy-base 1.21.2 py38h79a1101_0
[conda] pytorch 1.10.1 py3.8_cuda10.2_cudnn7.6.5_0 pytorch
[conda] pytorch-mutex 1.0 cuda pytorch
[conda] torchaudio 0.10.1 py38_cu102 pytorch
[conda] torchvision 0.11.2 py38_cu102 pytorch
cc @ezyang @gchanan @zou3519 @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @SciPioneer @H-Huang @anjali411 @dylanbespalko @mruberry @lezcano @nikitaved