DDP multi host with single GPU each. #78047

Open
spyroot opened this issue May 21, 2022 · 3 comments
Labels
oncall: distributed · triaged

Comments

@spyroot

spyroot commented May 21, 2022

🐛 Describe the bug

Folks,

I have two hosts, and each host has a single GPU. I'm using this example:
https://github.com/sudomaze/ttorch/blob/main/examples/ddp/run.py

If I use:
master node: rank 0, world_size 2
worker: rank 1, world_size 2

If I use the following, the master starts the training loop but the worker never connects:
rank 0, world_size 1
rank 1, world_size 1
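
For context, roughly the init path in my script, reconstructed from the traceback below (the rendezvous details and the address/port are placeholders, not my exact code; 192.168.254.205 is the master's address per the NCCL log):

import os
import torch
import torch.distributed as dist

def init_process(rank, world_size, fn, backend="nccl"):
    # Placeholder rendezvous settings.
    os.environ.setdefault("MASTER_ADDR", "192.168.254.205")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(backend, rank=rank, world_size=world_size)
    fn(rank, world_size)

def run(rank, world_size):
    # On the worker this is called with rank=1 even though the node has a
    # single GPU (only ordinal 0 exists), which is what blows up below.
    torch.cuda.set_device(rank)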

Stack trace for the first case. Note that the master proceeds and waits only if world_size is 2.

gpu10:17709:17709 [0] NCCL INFO NET/Socket : Using [0]eth0:172.16.80.231<0>
gpu10:17709:17709 [0] NCCL INFO Using network Socket
gpu10:17709:17729 [0] NCCL INFO Trees [0] -1/-1/-1->1->0 [1] 0/-1/-1->1->-1
gpu10:17709:17729 [0] NCCL INFO Channel 00 : 0[b000] -> 1[6000] [receive] via NET/Socket/0
gpu10:17709:17729 [0] NCCL INFO Channel 01 : 0[b000] -> 1[6000] [receive] via NET/Socket/0
gpu10:17709:17729 [0] NCCL INFO Channel 00 : 1[6000] -> 0[b000] [send] via NET/Socket/0
gpu10:17709:17729 [0] NCCL INFO Channel 01 : 1[6000] -> 0[b000] [send] via NET/Socket/0
gpu10:17709:17729 [0] NCCL INFO Connected all rings
gpu10:17709:17729 [0] NCCL INFO Connected all trees
gpu10:17709:17729 [0] NCCL INFO threadThresholds 8/8/64 | 16/8/64 | 8/8/512
gpu10:17709:17729 [0] NCCL INFO 2 coll channels, 2 p2p channels, 1 p2p channels per peer
gpu10:17709:17729 [0] NCCL INFO comm 0x7f7f2c002fc0 rank 1 nranks 2 cudaDev 0 busId 6000 - Init COMPLETE
barrier released
Traceback (most recent call last):
  File "/root/git/dtc_latest/dtc/ddp_test/ddp_sample.worker.py", line 235, in <module>
    init_process(1, world_size, run)
  File "/root/git/dtc_latest/dtc/ddp_test/ddp_sample.worker.py", line 227, in init_process
    fn(rank, world_size)
  File "/root/git/dtc_latest/dtc/ddp_test/ddp_sample.worker.py", line 158, in run
    torch.cuda.set_device(rank)
  File "/usr/local/lib/python3.10/dist-packages/torch/cuda/__init__.py", line 313, in set_device
    torch._C._cuda_setDevice(device)
RuntimeError: CUDA error: invalid device ordinal
CUDA kernel errors might be asynchronously reported at some other API call, so the stack trace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.

Master node log:

win11:15538:15538 [0] NCCL INFO Bootstrap : Using eth0:192.168.254.205<0>
win11:15538:15538 [0] NCCL INFO NET/Plugin : No plugin found (libnccl-net.so), using internal implementation

win11:15538:15538 [0] misc/ibvwrap.cc:63 NCCL WARN Failed to open libibverbs.so[.1]
win11:15538:15538 [0] NCCL INFO NET/Socket : Using [0]eth0:192.168.254.205<0>
win11:15538:15538 [0] NCCL INFO Using network Socket
NCCL version 2.10.3+cuda11.5
win11:15538:15568 [0] NCCL INFO Channel 00/02 :    0   1
win11:15538:15568 [0] NCCL INFO Channel 01/02 :    0   1
win11:15538:15568 [0] NCCL INFO Trees [0] 1/-1/-1->0->-1 [1] -1/-1/-1->0->1
win11:15538:15568 [0] NCCL INFO Channel 00 : 1[6000] -> 0[b000] [receive] via NET/Socket/0
win11:15538:15568 [0] NCCL INFO Channel 01 : 1[6000] -> 0[b000] [receive] via NET/Socket/0
win11:15538:15568 [0] NCCL INFO Channel 00 : 0[b000] -> 1[6000] [send] via NET/Socket/0
win11:15538:15568 [0] NCCL INFO Channel 01 : 0[b000] -> 1[6000] [send] via NET/Socket/0
win11:15538:15568 [0] NCCL INFO Connected all rings
win11:15538:15568 [0] NCCL INFO Connected all trees
win11:15538:15568 [0] NCCL INFO threadThresholds 8/8/64 | 16/8/64 | 8/8/512
win11:15538:15568 [0] NCCL INFO 2 coll channels, 2 p2p channels, 1 p2p channels per peer
win11:15538:15568 [0] NCCL INFO comm 0x7fa13c002fb0 rank 0 nranks 2 cudaDev 0 busId b000 - Init COMPLETE
win11:15538:15538 [0] NCCL INFO Launch mode Parallel
releasing
2 64

win11:15538:15570 [0] include/socket.h:423 NCCL WARN Net : Connection closed by remote peer 172.16.80.231<36406>
win11:15538:15570 [0] NCCL INFO transport/net_socket.cc:414 -> 2
win11:15538:15570 [0] NCCL INFO include/net.h:28 -> 2
win11:15538:15570 [0] NCCL INFO transport/net.cc:459 -> 2
win11:15538:15570 [0] NCCL INFO proxy.cc:351 -> 2
win11:15538:15570 [0] NCCL INFO proxy.cc:452 -> 2 [Proxy Thread]

I managed to narrow it down a bit.

This is the case:

master
rank 0, local rank 0, world size 2
device = "cuda:0"
model.to(device)
DDP(model, device_ids=[0], output_device=0)

worker
device = "cuda:0"
model.to(device)
DDP(model, device_ids=[0], output_device=0)
RuntimeError: CUDA error: invalid device ordinal
CUDA kernel errors might be asynchronously reported at some other API call, so the stack trace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.

Versions

Collecting environment information...
PyTorch version: 1.11.0+cu115
Is debug build: False
CUDA used to build PyTorch: 11.5
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04 LTS (x86_64)
GCC version: (Ubuntu 11.2.0-19ubuntu1) 11.2.0
Clang version: Could not collect
CMake version: version 3.22.1
Libc version: glibc-2.35

Python version: 3.10.4 (main, Apr 2 2022, 09:04:19) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.10.102.1-microsoft-standard-WSL2-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3090
Nvidia driver version: 512.15
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.21.6
[pip3] torch==1.11.0+cu115
[pip3] torchaudio==0.11.0+cu115
[pip3] torchvision==0.12.0+cu115
[conda] Could not collect

cc @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @SciPioneer @H-Huang @kwen2501

@spyroot
Author

spyroot commented May 21, 2022

I think the device id lookup is really messed up somewhere.

Traceback (most recent call last):
  File "main.py", line 324, in <module>
    main(args)
  File "main.py", line 273, in main
    train(spec=trainer_spec, cmd_args=cmd_args, device=_device)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/parallel/distributed.py", line 962, in forward
    inputs, kwargs = self.to_kwargs(inputs, kwargs, self.device_ids[0])
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/parallel/distributed.py", line 1086, in to_kwargs
    inputs = self._recursive_to(inputs, device_id) if inputs else []
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/parallel/distributed.py", line 1080, in _recursive_to
    res = to_map(inputs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/parallel/distributed.py", line 1059, in to_map
    return list(zip(*map(to_map, obj)))
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/parallel/distributed.py", line 1059, in to_map
    return list(zip(*map(to_map, obj)))
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/parallel/distributed.py", line 1044, in to_map
    stream = _get_stream(target_gpu)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/parallel/_functions.py", line 122, in _get_stream
    if _streams[device] is None:
IndexError: list index out of range

@mrshenli added the oncall: distributed label May 22, 2022
@kwen2501
Contributor

Hi, as the first error log indicates, this call in user code ddp_sample.worker.py:
torch.cuda.set_device(rank)
probably sets an invalid device ordinal.

This could happen when rank = 1 and you only have 1 GPU per node (so the only valid device ordinal is 0 on both nodes). Consider changing rank to local_rank in the call above (if the latter exists); the local_rank value is expected to be 0 on both nodes in your case.
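
For example, a minimal sketch of that change, assuming the script is launched with RANK, LOCAL_RANK, and WORLD_SIZE set in the environment (e.g. by torchrun); MyModel is just a placeholder here:

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Global rank identifies the process across all nodes (0 on the master, 1 on the
# worker); local rank identifies the GPU within a node (0 on both nodes here,
# since each node has a single GPU).
rank = int(os.environ["RANK"])
local_rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ["WORLD_SIZE"])

dist.init_process_group("nccl", rank=rank, world_size=world_size)

torch.cuda.set_device(local_rank)          # not set_device(rank)
device = torch.device("cuda", local_rank)

model = MyModel().to(device)               # MyModel stands in for your model
ddp_model = DDP(model, device_ids=[local_rank], output_device=local_rank)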

@kumpera added the triaged label May 31, 2022
@spyroot
Author

spyroot commented May 31, 2022

Folks,

So if you have two nodes and each node has only 1 GPU, then based on that logic:

Node A - rank 0 (master)
Node B - rank 1 (worker), local rank 0 --> this value is read somewhere in the code from the OS environment.

I managed to make it work, but what is strange is that I had to set all the environment variables and values explicitly,
on top of everything mentioned above, i.e. re-populate os.environ. (I think somewhere down the line
the values are read from a different source.)

2) If you check all the examples, sometimes rank is passed to .to(device) and sometimes rank is passed in device_ids.

os.environ["NCCL_DEBUG"] = "INFO"
os.environ["NCCL_IB_DISABLE"] = "1"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["LOCAL_RANK"] = "0" ----> this value read somewhere either in torch or cuda/nccl code
os.environ["RANK"] = "0". ----> this value as well.
os.environ["WORLD_SIZE"] = "2"
os.environ["TUNE_DISABLE_AUTO_CALLBACK_SYNCER"] = "1"

I think it makes more sense to be concrete about the types,
i.e. the device representation in torch: device id and CUDA device, local rank vs. rank,

e.g. rank 0, local rank 0, cuda device id 0....

For example, torch already has an abstract device, so it makes sense to always use
the same abstraction. If you check the torch examples, people do something very strange:
they send models to a rank, .to(rank), or to a local rank. (Local rank and the number of GPUs
per node don't make much sense from an abstraction point of view.)

If you think about it logically, it is very strange.
In essence it should be .to(some_device), where some_device is abstract, i.e. a local
GPU or a remote GPU.

In essence, it makes much more sense to have some sort of priority instead
of a local rank.

Something like this, where it is explicit how the pool was created and which devices it abstracts:

DDP(GPU_Pool(devices=list[torch.devices()]))

or maybe:

DDP(RoundRobinGpuSampler(devices=list[torch.devices()]))

Imagine you have 4 servers, each with 1 GPU, so every process has local rank 0 and GPU device id 0. ;-)
Or imagine the master node has no GPU at all: it just splits the batch, sends it to all the remotes,
aggregates on CPU, never computes anything itself, and only combines gradients.

Right now device_ids, cuda device id, and local_rank are IMHO very ambiguous and can be passed
in different forms.
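
Purely as an illustration of the kind of abstraction I mean (GPU_Pool is a hypothetical name; nothing like it exists in torch today):

import torch

class GPU_Pool:
    # Hypothetical wrapper: it owns the concrete devices, so user code never
    # has to juggle rank vs. local rank vs. CUDA ordinal.
    def __init__(self, devices):
        self.devices = list(devices)

pool = GPU_Pool(devices=[torch.device("cuda", i) for i in range(torch.cuda.device_count())])
# DDP(model, device_pool=pool)  # hypothetical signature, not a real DDP argument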
