
NCCL error using DDP and PyTorch 1.7 #4420

Closed
ohmeow opened this issue Oct 29, 2020 · 56 comments
Labels: 3rd party (Related to a 3rd-party), bug (Something isn't working), distributed (Generic distributed-related topic), help wanted (Open to be worked on), priority: 0 (High priority task)

Comments

@ohmeow

ohmeow commented Oct 29, 2020

🐛 Bug

Getting this error when attempting to use ddp with the "getting started" autoencoder example:

Stack Trace:

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2]
LOCAL_RANK: 1 - CUDA_VISIBLE_DEVICES: [0,1,2]
initializing ddp: GLOBAL_RANK: 1, MEMBER: 2/2
initializing ddp: GLOBAL_RANK: 0, MEMBER: 1/2
Both DDP processes raise the same error; the traceback from one process:

Traceback (most recent call last):
  File "01_getting_started_autoencoder.py", line 66, in <module>
    modle, trainer = cli_main()
  File "01_getting_started_autoencoder.py", line 60, in cli_main
    trainer.fit(model, train_dl)
  File "/home/user/anaconda3/envs/playground-pl/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 440, in fit
    results = self.accelerator_backend.train()
  File "/home/user/anaconda3/envs/playground-pl/lib/python3.7/site-packages/pytorch_lightning/accelerators/ddp_accelerator.py", line 138, in train
    results = self.ddp_train(process_idx=self.task_idx, model=model)
  File "/home/user/anaconda3/envs/playground-pl/lib/python3.7/site-packages/pytorch_lightning/accelerators/ddp_accelerator.py", line 231, in ddp_train
    self.trainer.is_slurm_managing_tasks
  File "/home/user/anaconda3/envs/playground-pl/lib/python3.7/site-packages/pytorch_lightning/accelerators/accelerator.py", line 213, in init_ddp_connection
    torch_backend, rank=global_rank, world_size=world_size
  File "/home/user/anaconda3/envs/playground-pl/lib/python3.7/site-packages/torch/distributed/distributed_c10d.py", line 442, in init_process_group
    barrier()
  File "/home/user/anaconda3/envs/playground-pl/lib/python3.7/site-packages/torch/distributed/distributed_c10d.py", line 1947, in barrier
    work = _default_pg.barrier()
RuntimeError: NCCL error in: /pytorch/torch/lib/c10d/ProcessGroupNCCL.cpp:784, invalid usage, NCCL version 2.7.8

To Reproduce

Follow the code in the "getting started" guide with these parameters to Trainer:

model = LitAutoEncoder()
trainer = pl.Trainer(gpus='1,2', distributed_backend='ddp')
trainer.fit(model, train_dl)

Expected behavior

For it to train on multiple GPUs :)

Environment

  • PyTorch version: 1.7
  • OS (e.g., Linux): Ubuntu 18.04
  • How you installed PyTorch (conda, pip, source): pip
  • Build command you used (if compiling from source): n/a
  • Python version: 3.7
  • CUDA/cuDNN version: 10.2/7.6.5
  • GPU models and configuration: 2x 1080 Ti
  • Any other relevant information: n/a
@ohmeow added the bug (Something isn't working) and help wanted (Open to be worked on) labels Oct 29, 2020
@awaelchli
Member

Hi, thanks for reporting.
The autoencoder example runs fine for me.
Could you please let me know the Lightning version you are using?
We recently fixed a bug, please use 1.0.4 or newer.

@awaelchli added the information needed and distributed (Generic distributed-related topic) labels Oct 29, 2020
@ohmeow
Author

ohmeow commented Oct 29, 2020

Yeah, I'm using 1.0.4.

Here's the full source for my .py file:

import os
import torch
from torch import nn
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from torch.utils.data import random_split


# define pl module
class LitAutoEncoder(pl.LightningModule):

    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(28*28, 64),
            nn.ReLU(),
            nn.Linear(64, 3)
        )
        self.decoder = nn.Sequential(
            nn.Linear(3, 64),
            nn.ReLU(),
            nn.Linear(64, 28*28)
        )

    def forward(self, x):
        # in lightning, forward defines the prediction/inference actions
        embedding = self.encoder(x)
        return embedding

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        # It is independent of forward
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)

        # Logging to TensorBoard by default
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer


# define datasets/dataloaders
dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
train_dl = DataLoader(dataset)


# train
model = LitAutoEncoder()
trainer = pl.Trainer(gpus='0,1', distributed_backend='ddp')
trainer.fit(model, train_dl)

@awaelchli
Member

OK, I can confirm this only happens on PyTorch 1.7.

@maxjeblick
Contributor

I have the same issue on a 1080 Ti; with V100 GPUs everything works fine.

@edenlightning added the priority: 0 (High priority task) label Oct 29, 2020
@edenlightning added this to the 1.0.x milestone Oct 29, 2020
@s-rog
Contributor

s-rog commented Oct 30, 2020

@maxjeblick sounds like a driver issue?

Edit:
Certainly very odd that NCCL is bugging out only with 1080ti GPUs...

@awaelchli
Member

I tested the following with our examples:
ddp, 1080 Ti, PyTorch 1.7: error
ddp, 1080 Ti, PyTorch 1.6: good
ddp, 2080 Ti, PyTorch 1.7: good
ddp, 2080 Ti, PyTorch 1.6: good

So far I have not been able to reproduce it with the plain PyTorch examples :( need to dig deeper.

@jgbos
Contributor

jgbos commented Oct 31, 2020

I can confirm the same error with the latest Lightning and PyTorch on Tesla V100s. It does not happen on a single node with 2 GPUs, but once I go to multiple nodes the error appears.

@edenlightning changed the title from "NCCL error ... invalid usage, NCCL version 2.7.8 when attempting to use ddp and 2 1080ti GPUs" to "NCCL error using DDP and PyTorch 1.7" Nov 3, 2020
@ildoonet

ildoonet commented Nov 4, 2020

Same error with A100 GPUs.

RuntimeError: NCCL error in: /pytorch/torch/lib/c10d/ProcessGroupNCCL.cpp:784, invalid usage, NCCL version 2.7.8

@stillwalker1234

I have the same issue with 2x 2080 Ti on Ubuntu 20.04 using PyTorch 1.7 and CUDA 11.
Downgrading to PyTorch 1.6 and CUDA 10.2 fixes the issue.

@awaelchli
Member

Could it be this fix in pytorch?
pytorch/pytorch#47257
Exact same error (line number and stack messages).

@Borda added the 3rd party (Related to a 3rd-party) label and removed the information needed label Nov 10, 2020
@tchaton closed this as completed Nov 10, 2020
@julian3xl

PyTorch closed their issue because this issue exists, and you closed this issue because their issue exists...

@awaelchli
Member

@julian3xl are you referring to the one I posted? I was under the impression that the fix was merged into PyTorch master.
I will check whether it's fixed.

@ksopyla

ksopyla commented Nov 30, 2020

I have the same issue on a single node with 2x RTX 3090 on Ubuntu 18.04, using PyTorch 1.7, driver version 455.45.01, CUDA version 11.1, pytorch-lightning 1.0.8.

@min-xu-ai

I am also hitting this, and I am not even using Lightning. :-(

@awaelchli
Member

@min-xu-ai Can you confirm that it is fixed in the PyTorch 1.8 nightly?

@min-xu-ai

@min-xu-ai Can you confirm that it is fixed in the PyTorch 1.8 nightly?

Great suggestion! I installed the latest 1.8.0dev version and it still fails, but the error message seems more helpful than before. @awaelchli do you think the underlying error should have been fixed in 1.8.0dev?

>       raise ProcessRaisedException(msg, error_index, failed_process.pid)
E       torch.multiprocessing.spawn.ProcessRaisedException:
E
E       -- Process 0 terminated with the following error:
E       Traceback (most recent call last):
E         File "/home/owen/e/py38_fs/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 59, in _wrap
E           fn(i, *args)
E         File "/home/owen/git/fairscale/tests/optim/test_oss.py", line 180, in run_test_add_param_group
E           dist_init(rank, world_size, tempfile_name)
E         File "/home/owen/git/fairscale/tests/optim/test_oss.py", line 33, in dist_init
E           dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
E         File "/home/owen/e/py38_fs/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 486, in init_process_group
E           barrier()
E         File "/home/owen/e/py38_fs/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 2197, in barrier
E           work = default_pg.barrier()
E       RuntimeError: NCCL error in: /pytorch/torch/lib/c10d/ProcessGroupNCCL.cpp:882, invalid usage, NCCL version 2.7.8
E       ncclInvalidUsage: This usually reflects invalid usage of NCCL library (such as too many async ops, too many collectives at once, mixing streams in a group, etc).

../../e/py38_fs/lib/python3.8/site-packages/torch/multiprocessing/spawn.py:166: ProcessRaisedException

My versions:

torch             1.8.0.dev20201204+cu110
torchtext         0.6.0
torchvision       0.9.0.dev20201204+cu110

@min-xu-ai

Actually, I found the reason: my unit test was trying to start a world_size=3 job on 2 GPUs. The error message is definitely hard to parse; it would be nice if dist.init_process_group just checked the world_size.

FWIW, the gloo backend works fine in this case.
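For what it's worth, here is a minimal sketch (not the actual fairscale test; the port and sizes are made up) of the kind of guard that surfaces this mistake before NCCL does, since NCCL refuses to place two ranks on the same CUDA device:

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank: int, world_size: int):
    # One GPU per rank; NCCL reports "Duplicate GPU detected" otherwise.
    torch.cuda.set_device(rank)
    dist.init_process_group(
        backend="nccl",
        init_method="tcp://127.0.0.1:29500",  # assumed free port, illustration only
        rank=rank,
        world_size=world_size,
    )
    dist.barrier()
    dist.destroy_process_group()

if __name__ == "__main__":
    world_size = 3  # the value my test used
    if world_size > torch.cuda.device_count():
        raise RuntimeError(
            f"world_size={world_size} exceeds {torch.cuda.device_count()} visible GPUs"
        )
    mp.spawn(worker, args=(world_size,), nprocs=world_size)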

@mpaepper

mpaepper commented Dec 9, 2020

For others who might run into this:

In previous PyTorch Lightning versions, the Trainer received an argument distributed_backend. You now need to rename it to accelerator.
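For illustration, the rename amounts to the following (gpus=2 is just a placeholder; exact behaviour depends on your Lightning version, and in later releases this argument moved again, to strategy="ddp"):

import pytorch_lightning as pl

# Older Lightning releases:
# trainer = pl.Trainer(gpus=2, distributed_backend="ddp")

# Newer releases deprecate distributed_backend in favor of accelerator:
trainer = pl.Trainer(gpus=2, accelerator="ddp")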

@mathpluscode

Hi, I'm also getting this error when using multiple GPUs with mixed precision. The package versions are:

python=3.7
pytorch==1.8.0
torchvision==0.9.0
torchaudio==0.8.0
cudatoolkit=11.1

The GPUs are A100s with NVIDIA driver version 450.51.06, CUDA version 11.0.

The error message is

RuntimeError: NCCL error in: /opt/conda/conda-bld/pytorch_1614378098133/work/torch/lib/c10d/ProcessGroupNCCL.cpp:825, unhandled cuda error, NCCL version 2.7.8

ncclUnhandledCudaError: Call to CUDA function failed.

I tried the following env vars but it didn't work:

export NCCL_SOCKET_IFNAME=eth0
export NCCL_IB_DISABLE=1

My trainer call is

trainer = pl.Trainer(
    gpus=args.gpus,
    amp_level="O2",
    precision=16,
    accumulate_grad_batches=args.acc_batch_size // args.batch_size,
    accelerator="ddp",
    plugins=DDPPlugin(find_unused_parameters=True),
    max_epochs=args.epochs,
)

Meanwhile, it works with native PyTorch using torch.cuda.amp.autocast() for mixed-precision and nn.DataParallel for multi-GPU support.

Are there any suggestions for fixing this error?

I have exactly the same issue in the same environment as you. How did you fix it? I get this error when working with Apex.

Hi @NoahDrisort, I fixed it by using raw PyTorch: with DataParallel I do not get this error, but it occurred again with DistributedDataParallel. I then fixed that by using the Docker image 'nvcr.io/nvidia/pytorch:21.05-py3', but I haven't checked this image together with PyTorch Lightning.
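For reference, a rough sketch of the raw-PyTorch fallback mentioned above (torch.cuda.amp.autocast for mixed precision plus nn.DataParallel for multi-GPU); the model, optimizer, and shapes are placeholders, not from the original post:

import torch
from torch import nn
import torch.nn.functional as F

model = nn.DataParallel(nn.Linear(512, 10).cuda())  # placeholder model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()

def train_step(x, y):
    # x, y are assumed to already live on the default CUDA device.
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = F.cross_entropy(model(x), y)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    return loss.item()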

@awaelchli
Member

If you land here on this thread because you got an NCCL error and it looks like this (not exactly what OP posted):

RuntimeError: NCCL error in: /pytorch/torch/lib/c10d/ProcessGroupNCCL.cpp:825, unhandled system error, NCCL version 2.7.8 E ncclSystemError: System call (socket, malloc, munmap, etc) failed

It may be because you have too little shared memory. The solution is to increase the shared memory (google it for your operating system) or, if you use Docker, to set --shm-size="1G" or some other acceptable value.

General advice for NCCL errors: Run your command with the environment variable NCCL_DEBUG=INFO and collect all the messages it prints.
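One way to do that from a Python entry point, as a sketch (setting the variable in the shell before launching works just as well):

import os

# Must be set before the first NCCL communicator is created,
# i.e. before init_process_group / trainer.fit.
os.environ["NCCL_DEBUG"] = "INFO"        # or "WARN" for errors only
os.environ["NCCL_DEBUG_SUBSYS"] = "ALL"  # optional: more verbose subsystem logs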

@chzhan

chzhan commented Jul 7, 2021

The package versions:
pytorch: 1.8.1

In /usr/local/lib/python3.8/dist-packages/torch/distributed/launch.py (around line 270), change

        if not args.use_env:
            cmd.append("--local_rank={}".format(local_rank))
        cmd.extend(args.training_script_args)

to

        cmd.extend(args.training_script_args)
        if not args.use_env:
            cmd.append("--local_rank={}".format(local_rank))

In my case, the error was caused by the local_rank parameter not being passed in. It looks like a bug.
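For context, a minimal sketch (not the script from this thread) of how a training script consumes the --local_rank argument that torch.distributed.launch passes when --use_env is not set:

import argparse
import torch
import torch.distributed as dist

parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int, default=0)  # injected by the launcher
args, _ = parser.parse_known_args()

torch.cuda.set_device(args.local_rank)   # bind this process to its GPU
dist.init_process_group(backend="nccl", init_method="env://")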

@zheyuanWang

What does that do?

My understanding is that NCCL_IB_DISABLE=1 stops NCCL from establishing the connection over InfiniBand and forces it to use the network socket.

https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#nccl-ib-disable

Would it slow down the training process?

@srijandas07

For those with A100s,
export NCCL_P2P_DISABLE=1
export NCCL_IB_DISABLE=1
works like a charm.

@universome

For those with A100s,
export NCCL_P2P_DISABLE=1
export NCCL_IB_DISABLE=1
works like a charm.

This disables some important NCCL features; you could simply use the gloo backend instead (which works fine by default). Disabling these features can potentially decrease performance (especially if you use NVLink). However, I just tried your suggestion, and my StyleGAN2-ADA training speed on 4x A6000s not only did not decrease, it even improved slightly (by 2%). Note, though, that I do not have NVLink.
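Illustration only: in plain PyTorch, switching to the gloo backend is a one-argument change (how to select gloo from Lightning depends on the Lightning version):

import torch.distributed as dist

dist.init_process_group(
    backend="gloo",        # instead of "nccl"
    init_method="env://",  # assumes MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE are set
)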

@sjeaugey

sjeaugey commented Sep 24, 2021

I realize this bug is very misleading, as it seems to be the landing point of every NCCL error for PyTorch and DDP. NCCL errors are varied, since NCCL encompasses CUDA, NVLink, networking (sockets and InfiniBand/RoCE), and other mechanisms like shared memory, as well as performing topology detection to optimize communication between GPUs. So different users will have very different problems which need to be solved in different ways.

The first thing to do whenever a NCCL error happens, as suggested by the NCCL troubleshooting page is to run again with NCCL_DEBUG=WARN. That will give a precise error message of why NCCL failed and hopefully help fix the problem. If that message isn't clear enough, feel free to report the issue to the NCCL github project: https://github.com/nvidia/nccl.

Now, rewinding the bug to try to categorize the different issues...

In the first part of the issue (@min-xu-ai and @ohmeow) the error reported by NCCL is ncclInvalidUsage. With NCCL_DEBUG=WARN, there would be a message like the one below, which would hopefully have helped.

NCCL WARN Duplicate GPU detected : rank 0 and rank 3 both on CUDA device 0

Then, @mhpfuchs probably got a ncclSystemError in the Infiniband code (not ncclInvalidUsage) and "fixed" it by disabling IB. It would have been good to understand what that error was and perhaps fix the IB setup to make it functional, as sockets have much lower performance than IB, and a much higher CPU usage. In many cases, that's due to ulimit -l not being high enough to allow proper IB operation. Now sometimes it happens that there is an IB interface which is active yet not functional due to the network fabric not being properly setup, in which case NCCL_IB_DISABLE=1 is the proper workaround.

After that, @brando90 got a ncclUnhandledCudaError (not ncclSystemError, nor ncclInvalidUsage), like @MInner, @v-nhandt21 and @universome as well I guess.
At least in one case, the error was:

NCCL WARN failed to open CUDA IPC handle : 711 peer mapping resources exhausted

Setting NCCL_P2P_DISABLE=1 is a proper workaround in that case, but not all CUDA errors are this one and the solution could be very different depending on what CUDA issue we encountered.

Finally, @awaelchli got a ncclSystemError due to too little shared memory being available; setting NCCL_DEBUG=WARN would have probably printed something like:

NCCL WARN Error while creating shared memory segment /dev/shm/nccl-... (size ...)

and helped fix the problem as well.

@TiankaiHang

Same bug on V100 32GB, torch 1.8, CUDA 10.1.

@haofanwang

As suggested above,

export NCCL_SOCKET_IFNAME=eth0
export NCCL_IB_DISABLE=1

This works for me on 2 V100 servers.

@yiyele

yiyele commented Mar 18, 2022

export NCCL_SHM_DISABLE=1
This works for me on V100.

@superzrx

NCCL_IB_DISABLE=1

Fixed it for me with PyTorch 1.9, CUDA 11.1, NCCL 2.7.8, V100.

@sjeaugey

sjeaugey commented Aug 22, 2022

@haofanwang @superzrx

export NCCL_IB_DISABLE=1

If you're fine leaving performance on the table, that's OK, but performance using RDMA is much higher than using TCP/IP, and it puts a much lower load on the CPU. The most common issue when using RDMA is the memlock limit.

@yiyele

export NCCL_SHM_DISABLE=1

This is usually due to the container not reserving enough space for /dev/shm. It can be fixed by launching the container with --shm-size=1G.

Also, as a reminder, make sure you run with NCCL_DEBUG=WARN and look for NCCL WARN lines explaining what went wrong.

@kvenkman
Contributor

Could you point me to the location of the NCCL logs?

FWIW, I'm working on a single node running CentOS with 2 GPUs and export NCCL_SOCKET_IFNAME=lo is what worked for me.

@sjeaugey

Could you point me to the location of the NCCL logs?

NCCL logs will be printed to the standard output if NCCL_DEBUG is set.

@Hanpx20

Hanpx20 commented Aug 18, 2023

@sjeaugey I'm experiencing the same issue. With 2 nodes, using NCCL_IB_DISABLE makes training extremely slow, but without the flag NCCL reports an error. I'm not using Docker; the error messages look like:
babel-0-31:1552539:1553866 [7] misc/ibvwrap.cc:262 NCCL WARN Call to ibv_reg_mr failed with error Cannot allocate memory
babel-0-31:1552539:1553866 [7] proxy.cc:1119 NCCL WARN [Proxy Service 7] Failed to execute operation Connect from rank 7, retcode 2
babel-0-31:1552539:1553866 [7] misc/ibvwrap.cc:299 NCCL WARN Call to ibv_create_cq failed with error Cannot allocate memory
babel-0-31:1552539:1553866 [7] proxy.cc:1119 NCCL WARN [Proxy Service 7] Failed to execute operation Connect from rank 7, retcode 2
babel-0-31:1552539:1553850 [7] misc/socket.cc:538 NCCL WARN Net : Connection closed by remote peer babel-0-31.eth<38555>
babel-0-31:1552539:1553850 [7] proxy.cc:884 NCCL WARN Proxy Call to rank 7 failed (Connect)
babel-0-31:1552539:1552539 [7] bootstrap.cc:439 NCCL WARN Unexpected connections are not empty
babel-0-31:1552539:1552539 [7] init.cc:1428 NCCL WARN commReclaim: cleanup comm 0x565407de18d0 rank 7 failed in destroy/abort, error 3

Do you have any idea about the reason?

@sjeaugey

@Hanpx20 this is a very different problem.

See how to increase your memory limits for RDMA operations here:

https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/troubleshooting.html#infiniband

Another reason for that failure can be a bad mismatch between the container IB stack and the host.
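Not from the thread, but a quick sanity check: the memlock (RLIMIT_MEMLOCK) limit is the usual culprit when ibv_reg_mr fails with "Cannot allocate memory", and it can be inspected from inside the training process:

import resource

# Print the locked-memory limit seen by this process (Linux only).
# RDMA memory registration generally wants this to be unlimited or very large.
soft, hard = resource.getrlimit(resource.RLIMIT_MEMLOCK)
fmt = lambda v: "unlimited" if v == resource.RLIM_INFINITY else f"{v} bytes"
print(f"RLIMIT_MEMLOCK: soft={fmt(soft)}, hard={fmt(hard)}")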

@JerryDaHeLian

Fixed for me with NCCL_IB_DISABLE=1 lightning run model ...
Thank you to all of you above!

@danielajisafe

Issue solved after adding export NCCL_IB_DISABLE=1. I will add the logs here for 2 GPUs (V100); take note of the 2nd and 3rd lines.

cdr2584:98188:98188 [0] NCCL INFO Bootstrap : Using ib0:172.19.146.21<0>
cdr2584:98188:98188 [0] NCCL INFO NET/Plugin : No plugin found (libnccl-net.so), using internal implementation
cdr2584:98188:98188 [0] NCCL INFO NCCL_IB_DISABLE set by environment to 1.
cdr2584:98188:98188 [0] NCCL INFO NET/Socket : Using [0]ib0:172.19.146.21<0>
cdr2584:98188:98188 [0] NCCL INFO Using network Socket
NCCL version 2.12.12+cuda11.7
cdr2584:98188:98759 [0] NCCL INFO Setting affinity for GPU 0 to 5555
cdr2584:98188:98760 [1] NCCL INFO Setting affinity for GPU 1 to aaaa00
cdr2584:98188:98759 [0] NCCL INFO Channel 00/04 :    0   1
cdr2584:98188:98759 [0] NCCL INFO Channel 01/04 :    0   1
cdr2584:98188:98759 [0] NCCL INFO Channel 02/04 :    0   1
cdr2584:98188:98759 [0] NCCL INFO Channel 03/04 :    0   1
cdr2584:98188:98759 [0] NCCL INFO Trees [0] 1/-1/-1->0->-1 [1] 1/-1/-1->0->-1 [2] 1/-1/-1->0->-1 [3] 1/-1/-1->0->-1
cdr2584:98188:98759 [0] NCCL INFO Channel 00 : 0[18000] -> 1[af000] via P2P/direct pointer
cdr2584:98188:98759 [0] NCCL INFO Channel 01 : 0[18000] -> 1[af000] via P2P/direct pointer
cdr2584:98188:98759 [0] NCCL INFO Channel 02 : 0[18000] -> 1[af000] via P2P/direct pointer
cdr2584:98188:98759 [0] NCCL INFO Channel 03 : 0[18000] -> 1[af000] via P2P/direct pointer
cdr2584:98188:98760 [1] NCCL INFO Trees [0] -1/-1/-1->1->0 [1] -1/-1/-1->1->0 [2] -1/-1/-1->1->0 [3] -1/-1/-1->1->0
cdr2584:98188:98760 [1] NCCL INFO Channel 00 : 1[af000] -> 0[18000] via P2P/direct pointer
cdr2584:98188:98760 [1] NCCL INFO Channel 01 : 1[af000] -> 0[18000] via P2P/direct pointer
cdr2584:98188:98760 [1] NCCL INFO Channel 02 : 1[af000] -> 0[18000] via P2P/direct pointer
cdr2584:98188:98760 [1] NCCL INFO Channel 03 : 1[af000] -> 0[18000] via P2P/direct pointer
cdr2584:98188:98760 [1] NCCL INFO Connected all rings
cdr2584:98188:98760 [1] NCCL INFO Connected all trees
cdr2584:98188:98759 [0] NCCL INFO Connected all rings
cdr2584:98188:98760 [1] NCCL INFO threadThresholds 8/8/64 | 16/8/64 | 8/8/512
cdr2584:98188:98760 [1] NCCL INFO 4 coll channels, 4 p2p channels, 4 p2p channels per peer
cdr2584:98188:98759 [0] NCCL INFO Connected all trees
cdr2584:98188:98759 [0] NCCL INFO threadThresholds 8/8/64 | 16/8/64 | 8/8/512
cdr2584:98188:98759 [0] NCCL INFO 4 coll channels, 4 p2p channels, 4 p2p channels per peer
cdr2584:98188:98760 [1] NCCL INFO comm 0x2af5bb402180 rank 1 nranks 2 cudaDev 1 busId af000 - Init COMPLETE
cdr2584:98188:98759 [0] NCCL INFO comm 0x2af5ac00a870 rank 0 nranks 2 cudaDev 0 busId 18000 - Init COMPLETE
cdr2584:98188:98188 [0] NCCL INFO Launch mode Parallel

@sjeaugey

sjeaugey commented Feb 6, 2024

If you have an IB fabric, you probably don't want to disable IB, as it would use TCP instead and that may affect performance by an order of magnitude. So, disabling IB is not a solution for everyone, only for those who actually don't need IB but happen to have a misconfigured active IB NIC on their system (or a NIC they don't want to use for various reasons). The real solution is usually to fix the IB configuration and use it.
