
[In Gunicorn & multiprocessing environment] Cannot re-initialize CUDA in forked subprocess #68861

Open
yakirs57 opened this issue Nov 24, 2021 · 4 comments
Labels
module: multiprocessing Related to torch.multiprocessing · Stale · triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

yakirs57 commented Nov 24, 2021

I am getting: RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method

I have developed a REST API (Gunicorn, Gevent, Flask, Python) that runs a model loaded with PyTorch and has multiple workers and threads to support parallel execution.
I receive the above error when I call the model, e.g. model_name(imgs)

Versions:
Python=3.8.8
Gunicorn=20.1.0
Gevent=21.1.2
Flask=2.0.0
torch=1.8.1

I have tried the following:

  1. Add torch.multiprocessing.set_start_method('spawn'). (Note: I added prints before calling the model to verify that the start method still has the value 'spawn', and it does.)
  2. Add torch.set_num_threads(1)
  3. Add .share_memory() and .eval() to the models when they are loaded
Any help will be appreciated.
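For what it's worth, step 1 above cannot take effect here: gunicorn forks its workers with os.fork() directly rather than through the multiprocessing module, so the multiprocessing start method is never consulted. A minimal sketch of this (POSIX only, illustrative):

```python
import multiprocessing
import os

# Setting the start method only affects processes created via multiprocessing.
multiprocessing.set_start_method('spawn', force=True)

# gunicorn does not go through multiprocessing; it calls os.fork() directly,
# so the 'spawn' setting above never applies to its workers.
pid = os.fork()
if pid == 0:
    os._exit(0)      # child exits immediately
os.waitpid(pid, 0)

print(multiprocessing.get_start_method())  # still 'spawn', but irrelevant to the fork above
```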

cc @VitalyFedyunin

ejguan commented Nov 24, 2021

I'm not familiar with Gunicorn. You may want to set spawn as the default start method for Gunicorn.

icannotnamemyself commented Sep 11, 2022

I looked through the source code of gunicorn and torch. The reason this error happens is that gunicorn forks its worker processes, and torch marks any forked child process as in_bad_fork.

Explanation

  1. In gunicorn, worker processes are forked using os.fork():
    def spawn_worker(self):
        self.worker_age += 1
        worker = self.worker_class(self.worker_age, self.pid, self.LISTENERS,
                                   self.app, self.timeout / 2.0,
                                   self.cfg, self.log)
        self.cfg.pre_fork(self, worker)
        pid = os.fork()
        if pid != 0:
            worker.pid = pid
            self.WORKERS[pid] = worker
            return pid

        # Do not inherit the temporary files of other workers
        for sibling in self.WORKERS.values():
            sibling.tmp.close()

        # Process Child
        worker.pid = os.getpid()
        try:
            util._setproctitle("worker [%s]" % self.proc_name)
            self.log.info("Booting worker with pid: %s", worker.pid)
            self.cfg.post_fork(self, worker)
            worker.init_process()
            sys.exit(0)
        except SystemExit:
            raise
        except AppImportError as e:
            self.log.debug("Exception while loading the application",
                           exc_info=True)
            print("%s" % e, file=sys.stderr)
            sys.stderr.flush()
            sys.exit(self.APP_LOAD_ERROR)
        except Exception:
            self.log.exception("Exception in worker process")
            if not worker.booted:
                sys.exit(self.WORKER_BOOT_ERROR)
            sys.exit(-1)
        finally:
            self.log.info("Worker exiting (pid: %s)", worker.pid)
            try:
                worker.tmp.close()
                self.cfg.worker_exit(self, worker)
            except Exception:
                self.log.warning("Exception during worker exit:\n%s",
                                 traceback.format_exc())

  2. In torch, a forked child process is marked as being in a bad-fork state:
#ifndef WIN32
// Called in the forked child if cuda has already been initialized
static void forked_child() {
  in_bad_fork = true;
  torch::utils::set_requires_cuda_init(true);
}
#endif

// Should be called before the first cuda call.
// Note: This is distinct from initExtension because a stub cuda implementation
// has some working functions (e.g. device_count) but cannot fully initialize.
static void poison_fork() {
#ifndef WIN32
  static c10::once_flag flag;
  c10::call_once(flag, [] { pthread_atfork(nullptr, nullptr, forked_child); });
#endif
}

I made some modifications to gunicorn to dig into it, and indeed, after the fork the child process's torch.cuda._is_in_bad_fork is True while the parent process's torch.cuda._is_in_bad_fork is False.

Solution

What we have to do is avoid this torch.cuda._is_in_bad_fork flag ever being set. Since it is a global variable in PyTorch's underlying C++ code, we can use Python multiprocessing with the 'spawn' start method to create a completely new process that does not inherit most of the parent process's state.

change your code in your flask app from

task()

to

    import multiprocessing

    ctx = multiprocessing.get_context('spawn')
    p = ctx.Process(target=task, args=(...))  # pass your task's arguments here
    p.start()
    p.join()
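Put together, a runnable sketch of this workaround (task here is a hypothetical stand-in for the real model call; note that the 'spawn' method requires the target to be importable and the entry point to be guarded):

```python
import multiprocessing

def task(x, y):
    # hypothetical stand-in for the real work, e.g. model_name(imgs);
    # this body runs in a freshly spawned interpreter, free of the bad-fork flag
    print(x + y)

def run_in_spawned_process():
    ctx = multiprocessing.get_context('spawn')
    p = ctx.Process(target=task, args=(2, 3))
    p.start()
    p.join()
    return p.exitcode

if __name__ == "__main__":
    run_in_spawned_process()  # the spawned child prints 5
```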

@Davidnet
Hi @wayneleif, would you mind expanding on the above discussion? I have no idea what task could be; is it the initialization of the model?

Sanster commented Jan 11, 2023

For anyone struggling with gunicorn + pytorch, this might be useful.

The key is to not run any CUDA-related code in the master process (including torch.cuda.device_count()). Executing any of the following lines in the master process will result in the Cannot re-initialize CUDA in forked subprocess error:

torch.cuda.device_count()
torch.cuda.is_available()
torch.tensor(1).cuda()
torch.autocast("cuda")

If you still get the error after checking your code, this issue might be helpful: #17199

Full code:

import torch
from gunicorn.app.base import BaseApplication


class StandaloneApplication(BaseApplication):
    def __init__(self, app, options=None):
        self.options = options or {}
        self.application = app
        super().__init__()

    def load_config(self):
        config = {key: value for key, value in self.options.items()
                  if key in self.cfg.settings and value is not None}
        for key, value in config.items():
            self.cfg.set(key.lower(), value)

    def load(self):
        return self.application

# Executing any of the following lines will result in Cannot re-initialize CUDA in forked subprocess error
# torch.cuda.device_count()
# torch.cuda.is_available()
# torch.tensor(1).cuda()

def post_worker_init(worker):
    conv = torch.nn.Conv2d(330, 330, 3, 3, 3, 3)
    conv.cuda()
    print("finish post_worker_init")


def handler_app(environ, start_response):
    response_body = b'Works fine'
    status = '200 OK'

    response_headers = [
        ('Content-Type', 'text/plain'),
    ]

    start_response(status, response_headers)

    return [response_body]


def main():
    host = "0.0.0.0"
    port = 7860

    options = {
        'bind': f"{host}:{port}",
        'workers': 1,
        'worker_class': 'uvicorn.workers.UvicornWorker',
        'post_worker_init': post_worker_init,
        'timeout': 120,
    }
    StandaloneApplication(handler_app, options).run()


if __name__ == "__main__":
    main()
