
distributed data parallel, gloo backend works, but nccl deadlock #17745

Open
youkaichao opened this issue Mar 7, 2019 · 5 comments
Labels
module: deadlock - Problems related to deadlocks (hang without exiting)
module: nccl - Problems related to nccl support
oncall: distributed - Add this issue/PR to distributed oncall triage queue
triaged - This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@youkaichao
Collaborator

youkaichao commented Mar 7, 2019

I have run into the same problem as #14870. Since I cannot reopen that issue, I have opened a new one. Note that my GPUs are essentially idle: only 11MB of memory is in use and there are no processes running on them.

The problem is easy to reproduce with the following code:

import torch
import torch.distributed
import sys

# Distributed mode is detected from the --local_rank argument that
# torch.distributed.launch passes to each worker process.
dist = '--local_rank' in ''.join(sys.argv)

if dist:
    torch.distributed.init_process_group(backend='nccl')

def get_available_GPUs(N, max_utilization=.5, max_memory_usage=.5):
    '''
    get `N` available GPU ids with *utilization* less than `max_utilization` and *memory usage* less than max_memory_usage
    Arguments:
        N (int): How many GPUs you want to select
        max_utilization (float): GPU with utilization higher than `max_utilization` is considered as not available.
        max_memory_usage (float): GPU with memory usage higher than `max_memory_usage` is considered as not available.

    Returns:
        list containing IDs of available GPUs
    '''
    from subprocess import Popen, PIPE
    cmd = ["nvidia-smi",
           "--query-gpu=index,utilization.gpu,memory.total,memory.used",
           "--format=csv,noheader,nounits"]
    p = Popen(cmd, stdout=PIPE)
    output = p.stdout.read().decode('UTF-8')
    gpus = [[int(x) for x in line.split(',')] for line in output.splitlines()]
    gpu_ids = []
    for (index, utilization, total, used) in gpus:
        if utilization / 100.0 < max_utilization:
            if used * 1.0 / total < max_memory_usage:
                gpu_ids.append(index)
    if len(gpu_ids) < N:
        raise Exception("Only %s GPU(s) available but %s GPU(s) are required!" % (len(gpu_ids), N))
    available = gpu_ids[:N]
    return list(available)

def select_GPUs(N_per_process, max_utilization=.5, max_memory_usage=.5):
    '''
    select `N_per_process` GPUs.
    If distributed training is enabled, GPUs will be assigned properly among different processes.
    Arguments:
        N_per_process (int): How many GPUs you want to select for each process
        max_utilization (float): GPU with utilization higher than `max_utilization` is considered as not available.
        max_memory_usage (float): GPU with memory usage higher than `max_memory_usage` is considered as not available.

    Returns:
        list containing IDs of selected GPUs
    '''
    if not dist:
        return get_available_GPUs(N_per_process, max_utilization, max_memory_usage)
    rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()
    tensor = torch.zeros(world_size * N_per_process, dtype=torch.int).cuda()
    if rank == 0:
        device_ids = get_available_GPUs(world_size * N_per_process)
        tensor = torch.tensor(device_ids, dtype=torch.int).cuda()
    torch.distributed.broadcast(tensor, 0)
    ids = list(tensor.cpu().numpy().tolist())
    return ids[N_per_process * rank : N_per_process * rank + N_per_process]

import torch.nn as nn
device_ids = select_GPUs(2)
output_device = torch.device(device_ids[0])

x = torch.tensor([5.0, -1.0], dtype=torch.float).cuda(output_device).view(-1, 1)

model = nn.Linear(in_features=1, out_features=1, bias=False).cuda(output_device)

rank = torch.distributed.get_rank()
model.weight.data.zero_()
model.weight.data.add_(rank)

print(f'at rank {rank}, before init, the weight is {model.weight.data.item()}')

model = torch.nn.parallel.DistributedDataParallel(model, device_ids=device_ids, output_device=output_device)

print(f'at rank {rank}, after init, the weight is {model.module.weight.data.item()}')

y = model(x)

label = torch.zeros(2, 1, dtype=torch.float).cuda(output_device)

loss = torch.sum((y - label)**2)

loss.backward()

# print(model.weight.grad)

After printing

at rank 0, before init, the weight is 0.0
at rank 1, before init, the weight is 1.0
at rank 0, after init, the weight is 0.0

it hangs.

The same script works fine with the gloo backend.
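When diagnosing NCCL hangs like this, a common first step is to enable NCCL's own logging before the process group is initialized. A minimal sketch; NCCL_DEBUG and NCCL_DEBUG_SUBSYS are standard NCCL environment variables, not something specific to this issue:

import os

# Must be set before torch.distributed.init_process_group(backend='nccl')
# so that NCCL prints initialization, topology and error information.
os.environ["NCCL_DEBUG"] = "INFO"
# Optionally restrict the (very verbose) output to the init subsystem.
os.environ["NCCL_DEBUG_SUBSYS"] = "INIT"

With torch.distributed.launch, exporting these variables in the shell before launching works as well.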

@pytorchbot pytorchbot added the oncall: distributed Add this issue/PR to distributed oncall triage queue label Mar 7, 2019
@chenhuiji

I have encountered this problem too. In my case it was caused by mismatched CUDA versions.
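If a CUDA or NCCL version mismatch is suspected, the versions a given PyTorch build ships with can be printed and compared across machines. A minimal sketch using standard torch APIs, not taken from this issue:

import torch

print("torch:", torch.__version__)
print("CUDA (build):", torch.version.cuda)           # CUDA version PyTorch was compiled against
print("cuDNN:", torch.backends.cudnn.version())
# NCCL version bundled with this build; the return type varies between
# releases (an int such as 2408, or a tuple such as (2, 10, 3)).
print("NCCL:", torch.cuda.nccl.version())
print("visible GPUs:", torch.cuda.device_count())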

@ezyang ezyang added the module: nccl Problems related to nccl support label Apr 3, 2019
@ezyang
Contributor

ezyang commented Apr 3, 2019

Can you please use https://github.com/pytorch/pytorch/blob/master/torch/utils/collect_env.py to report information about your system?

@ezyang ezyang added module: deadlock Problems related to deadlocks (hang without exiting) triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module labels Apr 3, 2019
@jerinphilip

I had a similar issue. Mine was due to not calling torch.cuda.set_device(args.local_rank) at the start of each process. I found each process occupying some memory on the first GPU (seen with gpustat and process IDs), which went away once I set the device. NCCL requires that everything belonging to one process be stored on that process's own GPU(s).
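For reference, a minimal sketch of that fix for a script launched with torch.distributed.launch; the argparse handling of --local_rank is illustrative and not taken from the original code:

import argparse
import torch
import torch.distributed

parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=0)  # injected by torch.distributed.launch
args = parser.parse_args()

# Pin this process to its own GPU *before* creating the process group,
# so that NCCL state never lands on GPU 0 from every rank.
torch.cuda.set_device(args.local_rank)
torch.distributed.init_process_group(backend='nccl')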

@xbcReal

xbcReal commented Feb 1, 2023

I got a similar problem when running the demo DDP code.

import os
import sys
import tempfile
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp

from torch.nn.parallel import DistributedDataParallel as DDP

# On Windows platform, the torch.distributed package only
# supports Gloo backend, FileStore and TcpStore.
# For FileStore, set init_method parameter in init_process_group
# to a local file. Example as follow:
# init_method="file:///f:/libtmp/some_file"
# dist.init_process_group(
#    "gloo",
#    rank=rank,
#    init_method=init_method,
#    world_size=world_size)
# For TcpStore, same way as on Linux.

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def demo_basic(rank, world_size):
    print(f"Running basic DDP example on rank {rank}.")
    setup(rank, world_size)

    # create model and move it to GPU with id rank
    model = ToyModel().to(rank)
    ddp_model = DDP(model, device_ids=[rank])

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10))
    labels = torch.randn(20, 5).to(rank)
    loss_fn(outputs, labels).backward()
    optimizer.step()

    cleanup()

    

if __name__ == '__main__':
    world_size = 2
    mp.spawn(demo_basic,
             args=(world_size,),
             nprocs=world_size,
             join=True)

It works well when the backend is gloo, but fails when the backend is nccl.
My environment information collected by collect_env.py is:

[two screenshots of the collect_env.py output]

When I set NCCL_DEBUG=INFO, the log is as below:

[two screenshots of the NCCL_DEBUG=INFO log]

So does anybody know why this happens?
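Following jerinphilip's remark above, one thing worth checking with this demo is whether each spawned process is pinned to its own GPU before the process group is created. A minimal sketch of that change to setup(), offered as an assumption about the cause rather than a confirmed fix:

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # Pin this process to GPU `rank` so that NCCL never allocates
    # anything on GPU 0 from every rank.
    torch.cuda.set_device(rank)

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)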

@xbcReal

xbcReal commented Feb 7, 2023

I fixed this by switching to another machine, so I guess something was wrong with the previous machine rather than with the software.
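For anyone hitting a similar machine-specific hang, a small NCCL sanity check that exercises only the process group, independent of DDP and of any model code, can help separate hardware or driver problems from script problems. This is a sketch assuming two local GPUs, not something taken from this issue:

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def sanity_check(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12356'
    torch.cuda.set_device(rank)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

    # One tensor per rank on its own GPU; after all_reduce every rank
    # should hold the sum 0 + 1 + ... + (world_size - 1).
    t = torch.full((1,), float(rank), device=f'cuda:{rank}')
    dist.all_reduce(t)
    print(f'rank {rank}: all_reduce result = {t.item()}')

    dist.destroy_process_group()

if __name__ == '__main__':
    world_size = 2
    mp.spawn(sanity_check, args=(world_size,), nprocs=world_size, join=True)

If this also hangs, the problem is likely below PyTorch (NCCL, the driver, or the hardware), which would match the outcome described above.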
