RuntimeError: CUDA error: initialization error when calling torch.distributed.init_process_group using torch multiprocessing #68256

Open · ParamsRaman opened this issue Nov 12, 2021 · 4 comments
Labels: oncall: distributed


ParamsRaman commented Nov 12, 2021

❓ Questions and Help

I created a pytest fixture (implemented as a decorator) that uses torch multiprocessing to spawn multiple processes for running model-parallel distributed unit tests with torch.distributed. While fixing some unit-test logic, I suddenly started hitting the CUDA initialization error below. Since then all of my unit tests have been failing, and I traced the failure back to the fixture's call to torch.distributed.init_process_group(..).

Error traceback:

$ python3 -m pytest test/test_distributed.py::test_dummy
Process Process-1:
Traceback (most recent call last):
  File "/usr/lib64/python3.7/multiprocessing/process.py", line 297, in _bootstrap
    self.run()
  File "/usr/lib64/python3.7/multiprocessing/process.py", line 99, in run
    self._target(*self._args, **self._kwargs)
  File "/fsx-dev/FSxLustre20201016T182138Z/prraman/home/workspace/ws_M5_meg/src/M5ModelParallelism/test_script/commons_debug.py", line 34, in dist_init
    torch.distributed.init_process_group(backend, rank=rank, world_size=world_size, init_method=init_method)
  File "/usr/local/lib64/python3.7/site-packages/torch/distributed/distributed_c10d.py", line 480, in init_process_group
    barrier()
  File "/usr/local/lib64/python3.7/site-packages/torch/distributed/distributed_c10d.py", line 2186, in barrier
    work = _default_pg.barrier()
RuntimeError: CUDA error: initialization error

Below is the pytest fixture I created:

# file: test_distributed.py
import os
import time
import torch
import torch.distributed as dist
from torch.multiprocessing import Process, set_start_method
import pytest

# Worker timeout *after* the first worker has completed.
WORKER_TIMEOUT = 120


def distributed_test_debug(world_size=2, backend='nccl'):
    """A decorator for executing a function (e.g., a unit test) in a distributed manner.
        This decorator manages the spawning and joining of processes, initialization of
        torch.distributed, and catching of errors.

        Usage example:
        @distributed_test_debug(world_size=[2,3])
        def my_test():
            rank = dist.get_rank()
            world_size = dist.get_world_size()
            assert(rank < world_size)

    Arguments:
        world_size (int or list): number of ranks to spawn. Can be a list to spawn
        multiple tests.
    """
    def dist_wrap(run_func):
        """Second-level decorator for dist_test. This actually wraps the function. """
        def dist_init(local_rank,
                      num_procs,
                      *func_args, **func_kwargs):
            """Initialize torch.distributed and execute the user function. """
            os.environ['MASTER_ADDR'] = '127.0.0.1'
            os.environ['MASTER_PORT'] = '29503'
            os.environ['LOCAL_RANK'] = str(local_rank)
            # NOTE: unit tests don't support multi-node so local_rank == global rank
            os.environ['RANK'] = str(local_rank)
            os.environ['WORLD_SIZE'] = str(num_procs)
            master_addr = os.environ['MASTER_ADDR']
            master_port = os.environ['MASTER_PORT']
            rank = local_rank

            # Initializes the default distributed process group, and this will also initialize the distributed package.
            init_method = "tcp://"
            init_method += master_addr + ":" + master_port
            print('inside dist_init, world_size: ', world_size)
            torch.distributed.init_process_group(backend, rank=rank, world_size=world_size, init_method=init_method)
            print("rank={} init complete".format(rank))

            #torch.distributed.destroy_process_group()
            # print("rank={} destroy complete".format(rank))

            if torch.distributed.get_rank() == 0:
                print('> testing initialize_model_parallel with size {} ...'.format(
                    2))

            if torch.cuda.is_available():
                torch.cuda.set_device(local_rank)

            run_func(*func_args, **func_kwargs)

        def dist_launcher(num_procs,
                          *func_args, **func_kwargs):
            """Launch processes and gracefully handle failures. """

            # Spawn all workers on subprocesses.
            #set_start_method('spawn')
            processes = []
            for local_rank in range(num_procs):
                p = Process(target=dist_init,
                            args=(local_rank,
                                  num_procs,
                                  *func_args),
                            kwargs=func_kwargs)
                p.start()
                processes.append(p)

            # Now loop and wait for a test to complete. The spin-wait here isn't a big
            # deal because the number of processes will be O(#GPUs) << O(#CPUs).
            any_done = False
            while not any_done:
                for p in processes:
                    if not p.is_alive():
                        any_done = True
                        break

            # Wait for all other processes to complete
            for p in processes:
                p.join(WORKER_TIMEOUT)

            failed = [(rank, p) for rank, p in enumerate(processes) if p.exitcode != 0]
            for rank, p in failed:
                # If it still hasn't terminated, kill it because it hung.
                if p.exitcode is None:
                    p.terminate()
                    pytest.fail(f'Worker {rank} hung.', pytrace=False)
                if p.exitcode < 0:
                    pytest.fail(f'Worker {rank} killed by signal {-p.exitcode}',
                                pytrace=False)
                if p.exitcode > 0:
                    pytest.fail(f'Worker {rank} exited with code {p.exitcode}',
                                pytrace=False)

        def run_func_decorator(*func_args, **func_kwargs):
            """Entry point for @distributed_test(). """

            if isinstance(world_size, int):
                dist_launcher(world_size, *func_args, **func_kwargs)
            elif isinstance(world_size, list):
                for procs in world_size:
                    dist_launcher(procs, *func_args, **func_kwargs)
                    time.sleep(0.5)
            else:
                raise TypeError(f'world_size must be an integer or a list of integers.')

        return run_func_decorator

    return dist_wrap

Below is how I call the pytest fixture:

@distributed_test_debug(world_size=2)
def test_dummy():
    assert 1 == 1

I have seen issues raised in the past about torch multiprocessing and CUDA not working well together, but I am not sure whether this is related. Should I be creating the processes in a different way to avoid this problem? Any help is appreciated.

I am using pytorch version: 1.8.0a0+ae5c2fe

cc @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @SciPioneer @H-Huang

facebook-github-bot added the oncall: distributed label Nov 12, 2021
pritamdamania87 (Contributor) commented

@ParamsRaman Can you try with PyTorch 1.9? A bunch of the init_process_group logic changed, and we no longer use a barrier as part of init_process_group.
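
For reference, a quick way to confirm which PyTorch build the test environment actually picks up (the reporter's traceback was produced on 1.8.0a0+ae5c2fe):

import torch
print(torch.__version__)  # 1.9+ no longer issues a barrier inside init_process_group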

rohan-varma (Member) commented

I see set_start_method('spawn') is commented out. Does it work if we ensure all subprocesses are spawned (instead of forked)? CUDA has a number of known issues with fork-based multiprocessing.
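
For example, here is a minimal, untested sketch of the launcher using a 'spawn' multiprocessing context (reusing the dist_init and WORKER_TIMEOUT names from the snippet above; failure handling omitted for brevity):

import torch.multiprocessing as mp

def dist_launcher(num_procs, *func_args, **func_kwargs):
    # Use a 'spawn' context so workers start with a clean CUDA state instead of
    # inheriting a forked (possibly already-initialized) CUDA context from pytest.
    ctx = mp.get_context('spawn')
    processes = []
    for local_rank in range(num_procs):
        p = ctx.Process(target=dist_init,
                        args=(local_rank, num_procs, *func_args),
                        kwargs=func_kwargs)
        p.start()
        processes.append(p)
    for p in processes:
        p.join(WORKER_TIMEOUT)

Note that with 'spawn' (or 'forkserver') the target must be picklable, so a dist_init nested inside the decorator would need to be moved to module level or restructured (as the class-based approach below does).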


mrwyattii commented Jul 28, 2022

@ParamsRaman if you are still looking for a solution that works with the latest PyTorch, here is what we came up with for DeepSpeed unit testing:

We replace the distributed_test decorator with a DistributedTest class:

import inspect
import os
import time
from abc import ABC

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.multiprocessing import Process

import pytest

DEEPSPEED_UNIT_WORKER_TIMEOUT = 120

class DistributedTest(ABC):
    is_dist_test = True
    world_size = 2
    backend = "nccl"

    def _run_test(self, request):
        self.current_test = self._get_current_test_func(request)
        self.test_kwargs = self._get_test_kwargs(request)
        if isinstance(self.world_size, int):
            self.world_size = [self.world_size]
        for procs in self.world_size:
            self._launch_procs(procs)
            time.sleep(0.5)

    def _get_current_test_func(self, request):
        # DistributedTest subclasses may have multiple test methods
        func_name = request.function.__name__
        return getattr(self, func_name)

    def _get_test_kwargs(self, request):
        # Grab fixture / parametrize kwargs from pytest request object
        test_kwargs = {}
        params = inspect.getfullargspec(self.current_test).args
        params.remove("self")
        for p in params:
            test_kwargs[p] = request.getfixturevalue(p)
        return test_kwargs

    def _launch_procs(self, num_procs):
        mp.set_start_method('forkserver', force=True)
        processes = []
        for local_rank in range(num_procs):
            p = Process(target=self._dist_init, args=(local_rank, num_procs))
            p.start()
            processes.append(p)

        # Now loop and wait for a test to complete. The spin-wait here isn't a big
        # deal because the number of processes will be O(#GPUs) << O(#CPUs).
        any_done = False
        while not any_done:
            for p in processes:
                if not p.is_alive():
                    any_done = True
                    break

        # Wait for all other processes to complete
        for p in processes:
            p.join(DEEPSPEED_UNIT_WORKER_TIMEOUT)

        failed = [(rank, p) for rank, p in enumerate(processes) if p.exitcode != 0]
        for rank, p in failed:
            # If it still hasn't terminated, kill it because it hung.
            if p.exitcode is None:
                p.terminate()
                pytest.fail(f'Worker {rank} hung.', pytrace=False)
            if p.exitcode < 0:
                pytest.fail(f'Worker {rank} killed by signal {-p.exitcode}',
                            pytrace=False)
            if p.exitcode > 0:
                pytest.fail(f'Worker {rank} exited with code {p.exitcode}',
                            pytrace=False)

    def _dist_init(self, local_rank, num_procs):
        """Initialize deepspeed.comm and execute the user function. """
        os.environ['MASTER_ADDR'] = '127.0.0.1'
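        # NOTE: get_master_port() and set_cuda_visibile() below are helper utilities from
        # the DeepSpeed test suite (not included in this snippet); substitute your own
        # master-port selection and device-visibility handling when reusing this code.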
        os.environ['MASTER_PORT'] = get_master_port()
        os.environ['LOCAL_RANK'] = str(local_rank)
        # NOTE: unit tests don't support multi-node so local_rank == global rank
        os.environ['RANK'] = str(local_rank)
        os.environ['WORLD_SIZE'] = str(num_procs)

        # turn off NCCL logging if set
        os.environ.pop('NCCL_DEBUG', None)

        set_cuda_visibile()

        dist.init_process_group(backend=self.backend, rank=local_rank, world_size=num_procs)
        dist.barrier()

        if torch.cuda.is_available():
            torch.cuda.set_device(local_rank)

        self.current_test(**self.test_kwargs)

        # make sure all ranks finish at the same time
        dist.barrier()
        # tear down after test completes
        dist.destroy_process_group()

You will need to add the following to your conftest.py:

# Override of pytest "runtest" for DistributedTest class
# This hook is run before the default pytest_runtest_call
@pytest.hookimpl(tryfirst=True)
def pytest_runtest_call(item):
    # We want to use our own launching function for distributed tests
    if getattr(item.cls, "is_dist_test", False):
        dist_test_class = item.cls()
        dist_test_class._run_test(item._request)
        item.runtest = lambda: True  # Dummy function so test is not run twice

and then your distributed tests will need to be refactored:

# OLD TEST
@pytest.mark.parametrize('foo,bar', [(1,2),(3,4)])
def test_example(foo, bar):
    @distributed_test(world_size=[1,4])
    def _go():
        assert foo < bar
    _go()

# NEW TEST
@pytest.mark.parametrize('foo,bar', [(1,2),(3,4)])
class TestExample(DistributedTest):
    world_size=[1,4]
    def test(self, foo, bar):
        assert foo < bar

ParamsRaman (Author) commented

@mrwyattii Thanks for the snippet. Will try this out!
