
fix cuda device not found error when LLM is initialized in ray actor #3198

Closed

Conversation

wuxibin89

After #2221, when tensor_parallel_size > 1, the driver process's CUDA_VISIBLE_DEVICES is set manually after the RayWorkerVllm actors are up. When the LLM engine is initialized inside a Ray actor with num_gpus=0, Ray sets that actor's CUDA_VISIBLE_DEVICES to '', so torch.cuda.is_available() returns False and any subsequent update to CUDA_VISIBLE_DEVICES has no effect. The root cause is that PyTorch initializes the device count only once:
https://github.com/pytorch/pytorch/blob/main/c10/cuda/CUDAFunctions.cpp#L96-L113

We should be very careful not to call any torch.cuda.* function before CUDA_VISIBLE_DEVICES is set.

import ray
from vllm import LLM


# num_gpus defaults to 0 for this actor, so Ray sets CUDA_VISIBLE_DEVICES=""
# inside it before the LLM engine starts.
@ray.remote
class LLMDeployment:
    def __init__(self, *args, **kwargs) -> None:
        # LLM() checks torch.cuda.is_available() while CUDA_VISIBLE_DEVICES is
        # still "", so the cached device count stays at 0.
        self.llm = LLM(*args, **kwargs)


actor = LLMDeployment.remote("facebook/opt-13b", tensor_parallel_size=4)
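
A minimal sketch of the caching behavior outside Ray, assuming a machine that actually has GPUs: once torch.cuda sees an empty CUDA_VISIBLE_DEVICES, restoring the variable later in the same process has no effect.

import os

# Hide all GPUs before the first torch.cuda call.
os.environ["CUDA_VISIBLE_DEVICES"] = ""

import torch

print(torch.cuda.is_available())  # False; the device count of 0 is now cached

# Restoring the variable afterwards does not help in this process,
# because PyTorch initializes the device count only once.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
print(torch.cuda.is_available())  # still False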

@njhill
Member

njhill commented Mar 5, 2024

@wuxibin89 I think this was actually introduced by #2569. I included a similar fix here.

I think we can keep auto but just not call torch.cuda.is_available() there.
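
As an illustration only (a sketch of the idea, not the actual vLLM patch; resolve_device_type is a hypothetical helper): the device type for "auto" can be inferred from the build metadata instead of torch.cuda.is_available(), so nothing touches torch.cuda before CUDA_VISIBLE_DEVICES is finalized.

import torch

def resolve_device_type(device: str = "auto") -> str:
    # Hypothetical helper, not vLLM's API: choose a device type without
    # initializing torch.cuda, so the device count is not cached while
    # CUDA_VISIBLE_DEVICES may still be empty.
    if device != "auto":
        return device
    # torch.version.cuda only reflects how the wheel was built and does not
    # call into the CUDA runtime; the real availability check can happen
    # later in the workers, after CUDA_VISIBLE_DEVICES has been set.
    return "cuda" if torch.version.cuda is not None else "cpu"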

@zhuohan123
Collaborator


@njhill can you separate the fix into another small PR, or @wuxibin89 can you modify this PR accordingly?

@wuxibin89
Author

I think @njhill 's fix is better :)

wuxibin89 closed this Mar 7, 2024