Drop to float16 if bfloat16 is not supported #1901
Conversation
I personally prefer to ask users to explicitly specify dtype in this case, since otherwise it can silently affect the accuracy of the model. WDYT @zhuohan123 @simon-mo?
I think the fallback is fine if we explicitly print a warning?
+1. In this case a warning is fine. The accuracy difference between bfloat16 and float16 should not be too crazy.
@acebot712 You can simply do this: llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.1", dtype="half")
I agree with @WoosukKwon. In general we should avoid doing any "magic" that can change the outputs, even if only slightly. I would suggest instead modifying the exception message to tell the user to set the dtype to float16 themselves.
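A minimal sketch of the exception-message approach suggested above (function name and message wording are illustrative, not vLLM's actual implementation):

```python
# Illustrative sketch only -- not vLLM's actual code. Raises an error whose
# message tells the user to pass dtype="float16" themselves, rather than
# silently changing the dtype behind their back.
import torch

def verify_bfloat16_support() -> None:
    """Raise a descriptive error if the current GPU cannot run bfloat16."""
    major, minor = torch.cuda.get_device_capability()
    if major < 8:
        raise ValueError(
            "bfloat16 is only supported on GPUs with compute capability "
            f">= 8.0; your GPU has compute capability {major}.{minor}. "
            "Please pass dtype='float16' (or dtype='half') instead."
        )
```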
We can check the device capability outside vLLM and choose the dtype depending on the device. The code could be:

import torch
from vllm import LLM

llm = LLM(
    model="mistralai/Mistral-7B-Instruct-v0.1",
    dtype="float16" if torch.cuda.get_device_capability()[0] < 8 else "bfloat16",
)
Hi @acebot712, thanks for bringing up this issue and submitting the PR! We've decided to keep the current behavior: to avoid silent accuracy changes, vLLM will ask users to set the dtype explicitly when bfloat16 is not supported.
As a beginner, I just ran into this little setback. In the end, of course, adding the parameter solved it: --dtype="half"
@chuanzhubin A better error message is indeed a great idea! Would you be interested in submitting a PR?
Issue: #1157
Instead of throwing an error when the GPU's compute capability does not support bfloat16, vLLM should print a warning and fall back to float16. This helps in Colab notebooks, where the T4 instance has compute capability 7.5 rather than 8.0, so vLLM does not work out of the box.
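A minimal sketch of the fallback proposed here (illustrative only; this is not the behavior vLLM ultimately adopted, see the discussion above):

```python
# Illustrative sketch of the proposed fallback: warn and drop to float16
# when the GPU cannot run bfloat16, instead of raising an error.
import warnings
import torch

def resolve_dtype(requested: torch.dtype) -> torch.dtype:
    if requested is torch.bfloat16 and torch.cuda.get_device_capability()[0] < 8:
        warnings.warn(
            "bfloat16 is not supported on GPUs with compute capability < 8.0; "
            "falling back to float16.",
            stacklevel=2,
        )
        return torch.float16
    return requested
```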