Closed
Labels: better-engineering, high priority, module: c10d, oncall: distributed
Description
🐛 Describe the bug
I'm testing TORCH_DISTRIBUTED_DEBUG by following this documentation: https://pytorch.org/docs/master/distributed.html#debugging-torch-distributed-applications
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp


class TwoLinLayerNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.a = torch.nn.Linear(10, 10, bias=False)
        self.b = torch.nn.Linear(10, 1, bias=False)

    def forward(self, x):
        a = self.a(x)
        b = self.b(x)
        return (a, b)


def worker(rank):
    dist.init_process_group("nccl", rank=rank, world_size=2)
    torch.cuda.set_device(rank)
    print("init model")
    model = TwoLinLayerNet().cuda()
    print("init ddp")
    ddp_model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank])
    inp = torch.randn(10, 10).cuda()
    print("train")
    for _ in range(20):
        output = ddp_model(inp)
        loss = output[0] + output[1]
        loss.sum().backward()


if __name__ == "__main__":
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29501"
    os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"  # set to DETAIL for runtime logging.
    mp.spawn(worker, nprocs=2, args=())
The code above is copied verbatim from the documentation, but the output contains no DDP debug information:
init model
init model
init ddp
init ddp
i39a12200:2427:2427 [0] NCCL INFO Bootstrap : Using bond0:11.164.100.100<0>
i39a12200:2427:2427 [0] NCCL INFO NET/Plugin : No plugin found (libnccl-net.so), using internal implementation
i39a12200:2427:2427 [0] misc/ibvwrap.cc:63 NCCL WARN Failed to open libibverbs.so[.1]
i39a12200:2427:2427 [0] NCCL INFO NET/Socket : Using [0]bond0:11.164.100.100<0>
i39a12200:2427:2427 [0] NCCL INFO Using network Socket
NCCL version 2.10.3+cuda11.3
i39a12200:2428:2428 [1] NCCL INFO Bootstrap : Using bond0:11.164.100.100<0>
i39a12200:2428:2428 [1] NCCL INFO NET/Plugin : No plugin found (libnccl-net.so), using internal implementation
i39a12200:2428:2428 [1] misc/ibvwrap.cc:63 NCCL WARN Failed to open libibverbs.so[.1]
i39a12200:2428:2428 [1] NCCL INFO NET/Socket : Using [0]bond0:11.164.100.100<0>
i39a12200:2428:2428 [1] NCCL INFO Using network Socket
i39a12200:2427:2577 [0] NCCL INFO NCCL_MAX_NCHANNELS set by environment to 2.
i39a12200:2428:2578 [1] NCCL INFO NCCL_MAX_NCHANNELS set by environment to 2.
i39a12200:2427:2577 [0] NCCL INFO NCCL_MIN_NCHANNELS set by environment to 2.
i39a12200:2427:2577 [0] NCCL INFO Channel 00/02 : 0 1
i39a12200:2428:2578 [1] NCCL INFO NCCL_MIN_NCHANNELS set by environment to 2.
i39a12200:2428:2578 [1] NCCL INFO Trees [0] -1/-1/-1->1->0 [1] -1/-1/-1->1->0
i39a12200:2427:2577 [0] NCCL INFO Channel 01/02 : 0 1
i39a12200:2428:2578 [1] NCCL INFO Setting affinity for GPU 1 to ff,ffffffff,ffffffff,ffffffff
i39a12200:2427:2577 [0] NCCL INFO Trees [0] 1/-1/-1->0->-1 [1] 1/-1/-1->0->-1
i39a12200:2427:2577 [0] NCCL INFO Setting affinity for GPU 0 to ff,ffffffff,ffffffff,ffffffff
i39a12200:2428:2578 [1] NCCL INFO Channel 00 : 1[57000] -> 0[52000] via P2P/IPC/read
i39a12200:2427:2577 [0] NCCL INFO Channel 00 : 0[52000] -> 1[57000] via P2P/IPC/read
i39a12200:2428:2578 [1] NCCL INFO Channel 01 : 1[57000] -> 0[52000] via P2P/IPC/read
i39a12200:2427:2577 [0] NCCL INFO Channel 01 : 0[52000] -> 1[57000] via P2P/IPC/read
i39a12200:2428:2578 [1] NCCL INFO Connected all rings
i39a12200:2428:2578 [1] NCCL INFO Connected all trees
i39a12200:2428:2578 [1] NCCL INFO threadThresholds 8/8/64 | 16/8/64 | 8/8/512
i39a12200:2428:2578 [1] NCCL INFO 2 coll channels, 2 p2p channels, 2 p2p channels per peer
i39a12200:2427:2577 [0] NCCL INFO Connected all rings
i39a12200:2427:2577 [0] NCCL INFO Connected all trees
i39a12200:2427:2577 [0] NCCL INFO threadThresholds 8/8/64 | 16/8/64 | 8/8/512
i39a12200:2427:2577 [0] NCCL INFO 2 coll channels, 2 p2p channels, 2 p2p channels per peer
i39a12200:2428:2578 [1] NCCL INFO comm 0x7fdb5c002f70 rank 1 nranks 2 cudaDev 1 busId 57000 - Init COMPLETE
i39a12200:2427:2577 [0] NCCL INFO comm 0x7fbfdc002f70 rank 0 nranks 2 cudaDev 0 busId 52000 - Init COMPLETE
i39a12200:2427:2427 [0] NCCL INFO Launch mode Parallel
train
train
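Before looking at log levels, one cause worth ruling out is the variable never reaching the worker processes. `torch.multiprocessing.spawn` starts each worker as a fresh interpreter, and a child process inherits the parent's environment when it starts, so a value assigned through `os.environ` before the spawn call should be visible in every worker. The stdlib-only sketch below checks that property, using `subprocess` as a stand-in for `mp.spawn` (an assumption made so the check runs without torch or GPUs). It uses `stdout=subprocess.PIPE` and `universal_newlines=True` rather than `capture_output`/`text`, so it also runs on the Python 3.6 listed in the versions below.

```python
import os
import subprocess
import sys

# Set the variable in the parent, the same way the reproducer does before mp.spawn.
os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"

# Start a brand-new interpreter (a stand-in for a spawned worker) and
# print the value it sees in its inherited environment.
result = subprocess.run(
    [sys.executable, "-c",
     "import os; print(os.environ.get('TORCH_DISTRIBUTED_DEBUG'))"],
    stdout=subprocess.PIPE,
    universal_newlines=True,
    check=True,
)
child_value = result.stdout.strip()
print(child_value)  # DETAIL
```

If this prints `DETAIL`, the setting reaches the workers. The missing output is then more likely a logging-threshold question than an environment one.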
Do I need to enable USE_GLOG at build time, or set CAFFE2_LOG_THRESHOLD low enough to include INFO-level messages, for these messages to be printed? If so, I suggest refining the documentation here.
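The question comes down to log-level threshold filtering. If a logger's threshold sits above INFO, INFO-level records are dropped silently, with no error and no warning. The sketch below shows that behavior by analogy, using Python's stdlib `logging` module. It does not touch c10 or glog and does not show how PyTorch actually configures its C++ logger. The logger names and message strings are hypothetical placeholders.

```python
import io
import logging


def emitted_messages(threshold):
    """Log one INFO and one WARNING record under `threshold`; return what was written."""
    stream = io.StringIO()
    handler = logging.StreamHandler(stream)
    handler.setFormatter(logging.Formatter("%(levelname)s %(message)s"))
    # Hypothetical, isolated logger: no propagation, so root-logger config can't interfere.
    logger = logging.getLogger("threshold_demo_" + logging.getLevelName(threshold))
    logger.propagate = False
    logger.handlers = [handler]
    logger.setLevel(threshold)
    logger.info("an INFO-level message")       # dropped unless threshold <= INFO
    logger.warning("a WARNING-level message")  # kept at either threshold below
    return stream.getvalue().splitlines()


print(emitted_messages(logging.WARNING))
print(emitted_messages(logging.INFO))
```

With a WARNING threshold only the warning line appears. Lowering the threshold to INFO makes the INFO line appear as well, which matches the symptom in the output above, where the INFO-level debug output never shows up.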
The PyTorch build used:
pip3 install torch==1.10.1+cu113 torchvision==0.11.2+cu113 torchaudio==0.10.1+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
Versions
PyTorch version: 1.10.1+cu113
Is debug build: False
CUDA used to build PyTorch: 11.3
ROCM used to build PyTorch: N/A
OS: Ubuntu 18.04.5 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: Could not collect
CMake version: version 3.19.1
Libc version: glibc-2.9
Python version: 3.6.12 |Anaconda, Inc.| (default, Sep 8 2020, 23:10:56) [GCC 7.3.0] (64-bit runtime)
Python platform: Linux-3.10.0-327.ali2019.alios7.x86_64-x86_64-with-debian-buster-sid
Is CUDA available: True
CUDA runtime version: 11.3.109
GPU models and configuration:
GPU 0: NVIDIA A100-SXM4-40GB
GPU 1: NVIDIA A100-SXM4-40GB
GPU 2: NVIDIA A100-SXM4-40GB
GPU 3: NVIDIA A100-SXM4-40GB
GPU 4: NVIDIA A100-SXM4-40GB
GPU 5: NVIDIA A100-SXM4-40GB
GPU 6: NVIDIA A100-SXM4-40GB
GPU 7: NVIDIA A100-SXM4-40GB
Nvidia driver version: 470.82.01
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.2.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.2.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.2.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.2.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.2.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.2.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.2.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Versions of relevant libraries:
[pip3] numpy==1.16.6
[pip3] torch==1.10.1+cu113
[pip3] torch-cluster==1.5.9
[pip3] torch-geometric==2.0.3
[pip3] torch-scatter==2.0.9
[pip3] torch-sparse==0.6.12
[pip3] torch-spline-conv==1.2.1
[pip3] torch-xla==1.10
[pip3] torchaudio==0.10.1+cu113
[pip3] torchtext==0.11.1
[pip3] torchvision==0.11.2+cu113
[conda] magma-cuda113 2.5.2 1 pytorch
[conda] mkl 2021.4.0 h06a4308_640 defaults
[conda] mkl-include 2022.0.1 h8d4b97c_803 conda-forge
[conda] numpy 1.16.6 pypi_0 pypi
[conda] torch 1.10.1+cu113 pypi_0 pypi
[conda] torch-cluster 1.5.9 pypi_0 pypi
[conda] torch-geometric 2.0.3 pypi_0 pypi
[conda] torch-scatter 2.0.9 pypi_0 pypi
[conda] torch-sparse 0.6.12 pypi_0 pypi
[conda] torch-spline-conv 1.2.1 pypi_0 pypi
[conda] torch-xla 1.10 pypi_0 pypi
[conda] torchaudio 0.10.1+cu113 pypi_0 pypi
[conda] torchtext 0.11.1 pypi_0 pypi
[conda] torchvision 0.11.2+cu113 pypi_0 pypi
cc @ezyang @gchanan @zou3519 @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @SciPioneer @H-Huang