
torch.cuda.init() unstacks existing CUDA contexts #75025

@QuiteAFoxtrot

Description

🐛 Describe the bug

I am using Nvidia's 'cuda-python' bindings, and I have discovered that torch expects to be the only CUDA client in a given process. In my code I was doing something like the following:

import torch
from cuda import cuda  # nvidia cuda-python package

device_num = 5
err, = cuda.cuInit(0)
err, device = cuda.cuDeviceGet(device_num)
err, cuda_context = cuda.cuCtxCreate(0, device)  # flags=0
...
# some stuff with the currently active context, which works fine
err, = some_func_with_my(cuda_context)
err, cuda_context = cuda.cuCtxPopCurrent()  # pop my context off the stack so it's clean for PyTorch's turn
...
# Then I need to do some PyTorch stuff
torch_dev = torch.device("cuda:" + str(device_num))
torch.cuda.init()  # or any torch function which initializes CUDA

...
# Then later I go to use my CUDA context
err, = cuda.cuCtxPushCurrent(cuda_context)  # push my context back onto the stack, since other code need not be aware of it
err, = some_other_func_which_uses_my(cuda_context)  # crashes with CUDA error: context destroyed

As far as I can tell, there is only one programming paradigm which Nvidia/CUDA makes safe for multiple residents within a single process that all use CUDA, as described here:
https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#context
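
For reference, the paradigm in that section boils down to a per-thread stack of current contexts. Below is a minimal plain-Python sketch of those stack semantics; it is a conceptual model only (the class and its contents are invented for illustration, and no GPU or CUDA library is involved). The method names merely mirror the driver-API calls they stand in for:

```python
# Conceptual model of the per-thread CUDA context stack described in the
# CUDA programming guide. Plain Python, not a real CUDA binding.
class ContextStackModel:
    def __init__(self):
        self.stack = []  # bottom ... top; the top entry is the current context

    def cuCtxCreate(self, name):
        # cuCtxCreate makes the newly created context current (pushes it)
        self.stack.append(name)
        return name

    def cuCtxPopCurrent(self):
        # cuCtxPopCurrent detaches the current context from the stack;
        # the context remains valid and can be pushed again later
        return self.stack.pop()

    def cuCtxPushCurrent(self, ctx):
        # cuCtxPushCurrent makes an existing context current again
        self.stack.append(ctx)

    def current(self):
        return self.stack[-1] if self.stack else None
```

Under this model, the pop/push sequence in the report should round-trip cleanly: pop before handing control to another CUDA client, push later to get the same, still-valid context back.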

Under the hood, torch.cuda.init() must be calling something like cudaDeviceReset(), or doing something otherwise funky; the point is that after calling torch.cuda.init(), the context I created was destroyed. The workaround is to call torch.cuda.init() BEFORE I create my context, but it would be nice if torch could adhere to the Nvidia-provided paradigm and "play nice in the sandbox" with others, instead of destroying things.
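
To make the ordering issue concrete, here is a hypothetical plain-Python model of the reported behavior (the function names and stack contents are invented for illustration; this does not claim to be what torch actually does internally):

```python
# Hypothetical model of the reported behavior: in this sketch,
# torch_cuda_init_model() clears the thread's context stack before
# installing torch's own primary context, which is what the report
# suggests torch.cuda.init() effectively does.
def torch_cuda_init_model(stack):
    stack.clear()                  # reported: existing contexts are unstacked
    stack.append("torch-primary")  # torch's context becomes current

def init_torch_after_creating_context():
    stack = ["mine"]               # user context created first and current
    torch_cuda_init_model(stack)
    return "mine" in stack         # the user context is gone from the stack

def init_torch_before_creating_context():
    stack = []
    torch_cuda_init_model(stack)   # workaround: initialize torch first ...
    stack.append("mine")           # ... then create the user context on top
    return "mine" in stack         # both contexts coexist on the stack
```

In this model the workaround ordering leaves both contexts on the stack, matching the observation that initializing torch before creating the user context avoids the problem.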

Update: as per the example below, the context is not destroyed, just removed from the CUDA context stack.

At minimum, this behavior should be explicitly documented in torch.cuda.init() (though that still wouldn't be playing nicely in the sandbox). The current documentation states "Does nothing if the CUDA state is already initialized.", which is quite false here.

Versions

python 3.8
pytorch 1.10.1
cuda_toolkit 11.2
cuda-python 11.6.1

cc @ngimel


    Labels

    module: cuda (Related to torch.cuda, and CUDA support in general), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
