Allow compile with bnb #38886
Conversation
When generating by default, is …
Any idea why the test is not passing @matthewdouglas? Getting the following traceback: …
Same for 8bit.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thanks for adding this! For 4bit this is just a graph break, so it does seem to be using … I think this default comes from transformers/src/transformers/generation/configuration_utils.py, lines 1564 to 1579 in c27f628.
If that's what we expect by default, it might be simplest to guard on … for torch >= 2.8 for 4bit. Otherwise we might want to find a way to improve the UX, like catching this and providing a better trace/error message, or disabling the fullgraph like here: transformers/src/transformers/generation/utils.py, lines 3581 to 3586 in 9cd7570.
Regarding 8bit, it's expected that you need …
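A minimal sketch of the guard suggested above: only request fullgraph=True when the running torch is new enough for 4bit, and fall back to fullgraph=False otherwise. The 2.8 threshold comes from this thread; the helper functions themselves are illustrative, not transformers API.

```python
def _release_tuple(v: str) -> tuple:
    # Parse "2.8.0" or "2.8.0.dev20250601+cu124" into (2, 8, 0),
    # ignoring local/pre-release segments.
    return tuple(int(p) for p in v.split("+")[0].split(".") if p.isdigit())

def choose_fullgraph(torch_version: str) -> bool:
    # fullgraph=True for 4bit reportedly needs torch >= 2.8 (per this thread);
    # fullgraph=False works on earlier versions with a graph break.
    return _release_tuple(torch_version) >= (2, 8, 0)
```

In real code the version string would come from `torch.__version__`; it is passed in here so the logic is easy to exercise without torch installed.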
A question about torch versions, but in general LGTM!
def is_compileable(self) -> bool:
    # Compatible with PyTorch 2.4+ for fullgraph=False.
    # Requires PyTorch 2.8 nightly for fullgraph=True.
    return version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse("0.46.0")
Should we also do torch version checks here? To simplify the logic, we should require torch>=2.8.0.
I agree that to keep this simple we can go ahead and just require torch>=2.8.0 here. Longer term, in a separate PR, maybe we can do some refactoring to this and let is_compileable consider additional context from a CompileConfig too.
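The simplification agreed above could be sketched like this: gate compileability on both the bitsandbytes version already checked in the PR snippet (>= 0.46.0) and torch >= 2.8.0. Versions are passed in as strings so the logic is easy to exercise; the real property would read them via importlib.metadata and torch.__version__.

```python
def _release_tuple(v: str) -> tuple:
    # Naive parser for release versions like "0.46.0"; ignores pre-release tags.
    return tuple(int(p) for p in v.split("+")[0].split(".") if p.isdigit())

def is_compileable(bnb_version: str, torch_version: str) -> bool:
    # Thresholds taken from the review discussion above.
    return (_release_tuple(bnb_version) >= (0, 46, 0)
            and _release_tuple(torch_version) >= (2, 8, 0))
```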
@@ -314,3 +314,7 @@ def _dequantize(self, model):
            model, self.modules_to_not_convert, quantization_config=self.quantization_config
        )
        return model

    @property
    def is_compileable(self) -> bool:
Does this one also have minimum torch requirements, or does it work with all torch versions?
Good point. The requirement here is torch>=2.4.0.
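To recap the minimum torch versions mentioned across this review (taken from the comments in this thread, not from official documentation), a small illustrative lookup:

```python
# Minimum torch versions per mode, as discussed in this review thread.
MIN_TORCH = {
    "8bit": (2, 4, 0),             # 8bit quantizer requirement stated above
    "4bit": (2, 4, 0),             # 4bit with fullgraph=False
    "4bit_fullgraph": (2, 8, 0),   # 4bit with fullgraph=True
}

def _release_tuple(v: str) -> tuple:
    # Naive parser for release versions like "2.4.0".
    return tuple(int(p) for p in v.split("+")[0].split(".") if p.isdigit())

def torch_ok(mode: str, torch_version: str) -> bool:
    return _release_tuple(torch_version) >= MIN_TORCH[mode]
```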
Thanks for everyone's advice! I will check that later and will probably merge it after torch 2.8 is out.
What does this PR do?
This PR enables compilation when generating with bnb models. This is supported with the latest bnb release.