Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BugFix] Avoid initializing CUDA too early #3487

Merged
merged 1 commit into from
Mar 19, 2024

Conversation

njhill
Copy link
Collaborator

@njhill njhill commented Mar 19, 2024

Care is taken in the code to avoid initializing CUDA prior to CUDA_VISIBLE_DEVICES being set in the worker, but an instance of this was inadvertently introduced in #2569.

Care is taken in the code to avoid initializing CUDA prior to CUDA_VISIBLE_DEVICES being set in the worker, but an instance of this was inadvertently introduced in vllm-project#2569.
@Yard1
Copy link
Collaborator

Yard1 commented Mar 19, 2024

Actually is it possible to somehow validate DeviceConfig inside the worker, after we have set CUDA_VISIBLE_DEVICES?

@youkaichao
Copy link
Member

Can we use something like in the setup.py?

def _is_cuda() -> bool:
    return torch.version.cuda is not None

This should not initialize cuda context, either.

It is not safe to assume cuda if it is not neuron.

@njhill
Copy link
Collaborator Author

njhill commented Mar 19, 2024

Can we use something like in the setup.py?

@youkaichao I assume this just indicates whether a cuda version of pytorch is in use and so would always return true.

It is not safe to assume cuda if it is not neuron.

I'm not sure what "safe" means here. If cuda/gpu isn't found then it the server will fail to start either way. Just that if it can be checked here it would fail slightly earlier with a nicer message.

Actually is it possible to somehow validate DeviceConfig inside the worker, after we have set CUDA_VISIBLE_DEVICES?

@Yard1 I'm not sure what an easy way to do that would be without nontrivial restructuring. Here all I'm doing is reverting something introduced when the neuron changes were added. We could contemplate that further as a separate improvement?

@Yard1
Copy link
Collaborator

Yard1 commented Mar 19, 2024

Ok sounds good, no blockers from my side

@youkaichao
Copy link
Member

@youkaichao I assume this just indicates whether a cuda version of pytorch is in use and so would always return true.

I'm not familiar with neuron. When we use neuron, is torch.version.cuda also set?

@rkooo567
Copy link
Collaborator

Also @njhill do you happen to know how this DeviceConfig has initialized before forking happens in the CI? I wonder if there's a way to restructure CI to avoid the same problem

Copy link
Collaborator

@zhuohan123 zhuohan123 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! I think this is a good temporary fix. Should be changed if we would like to use CPU in the future.

@zhuohan123
Copy link
Collaborator

Let me know when this is ready to be merged!

@njhill
Copy link
Collaborator Author

njhill commented Mar 19, 2024

Also @njhill do you happen to know how this DeviceConfig has initialized before forking happens in the CI? I wonder if there's a way to restructure CI to avoid the same problem

@rkooo567 I was just speculating that this bug might be the cause of that, but if you're referring to the 19 failing tests in Model Tests then it doesn't look like that's the case since they still fail in the CI for this branch.

@zhuohan123 from my pov it's ready to be merged, thanks!

@zhuohan123 zhuohan123 merged commit 7341c77 into vllm-project:main Mar 19, 2024
31 checks passed
@njhill njhill deleted the fix-cuda-init branch March 19, 2024 19:13
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

5 participants