
Allow compile with bnb #38886


Open · wants to merge 4 commits into main

Conversation

@SunMarc (Member) commented Jun 18, 2025

What does this PR do?

This PR enables compilation when generating with bnb models. This is supported with the latest bnb release.
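For context, a minimal sketch of the intended usage, assuming `CompileConfig` is exposed at the top level of transformers and that `generate` accepts `compile_config` as a generation-config override; the checkpoint and generation arguments below are only illustrative:

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, CompileConfig

model_id = "meta-llama/Llama-3.2-1B"  # illustrative checkpoint
quant_config = BitsAndBytesConfig(load_in_4bit=True)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, quantization_config=quant_config, device_map="auto"
)

inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)
# The static cache is what triggers the automatic torch.compile path in generate().
out = model.generate(
    **inputs,
    max_new_tokens=20,
    cache_implementation="static",
    compile_config=CompileConfig(fullgraph=False),
)
print(tokenizer.decode(out[0], skip_special_tokens=True))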

@SunMarc requested review from gante and matthewdouglas on June 18, 2025 14:27
@SunMarc (Member, Author) commented Jun 18, 2025

When generating, is fullgraph set to False by default? From the codebase it looks like it is, but I wanted to be sure. cc @gante

@SunMarc (Member, Author) commented Jun 18, 2025

Any idea why the test is not passing, @matthewdouglas?

RUN_SLOW=True pytest tests/quantization/bnb/test_4bit.py::Bnb4bitCompile -s -vvvvv

I'm using torch 2.7.1.

Getting the following traceback:

FAILED tests/quantization/bnb/test_4bit.py::Bnb4bitCompile::test_generate_compile - torch._dynamo.exc.Unsupported: Unsupported method call
  Explanation: Dynamo does not know how to trace method `t` of class `Params4bit`
  Hint: Avoid calling `Params4bit.t` in your code.
  Hint: Please report an issue to PyTorch.

  Developer debug context: call_method UserDefinedObjectVariable(Params4bit) t [] {}


from user code:
   File "/admin/home/marc/transformers/src/transformers/utils/generic.py", line 943, in wrapper
    output = func(self, *args, **kwargs)
  File "/admin/home/marc/transformers/src/transformers/models/llama/modeling_llama.py", line 555, in forward
    outputs: BaseModelOutputWithPast = self.model(
  File "/admin/home/marc/transformers/src/transformers/utils/generic.py", line 943, in wrapper
    output = func(self, *args, **kwargs)
  File "/admin/home/marc/transformers/src/transformers/models/llama/modeling_llama.py", line 443, in forward
    layer_outputs = decoder_layer(
  File "/admin/home/marc/transformers/src/transformers/modeling_layers.py", line 48, in __call__
    return super().__call__(*args, **kwargs)
  File "/admin/home/marc/miniconda3/envs/hf/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/admin/home/marc/miniconda3/envs/hf/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "/admin/home/marc/transformers/src/transformers/models/llama/modeling_llama.py", line 294, in forward
    hidden_states, self_attn_weights = self.self_attn(
  File "/admin/home/marc/transformers/src/transformers/models/llama/modeling_llama.py", line 235, in forward
    query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  File "/admin/home/marc/miniconda3/envs/hf/lib/python3.10/site-packages/bitsandbytes/nn/modules.py", line 496, in forward
    return bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state).to(inp_dtype)

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

Same for 8bit:


FAILED tests/quantization/bnb/test_mixed_int8.py::Bnb8bitCompile::test_generate_compile - torch._dynamo.exc.Unsupported: Dynamic shape operator
  Explanation: Operator `bitsandbytes.int8_vectorwise_quant.default`'s output shape depends on input Tensor data.
  Hint: Enable tracing of dynamic shape operators with `torch._dynamo.config.capture_dynamic_output_shape_ops = True`

  Developer debug context: bitsandbytes.int8_vectorwise_quant.default


from user code:
   File "/admin/home/marc/transformers/src/transformers/utils/generic.py", line 943, in wrapper
    output = func(self, *args, **kwargs)
  File "/admin/home/marc/transformers/src/transformers/models/llama/modeling_llama.py", line 555, in forward
    outputs: BaseModelOutputWithPast = self.model(
  File "/admin/home/marc/transformers/src/transformers/utils/generic.py", line 943, in wrapper
    output = func(self, *args, **kwargs)
  File "/admin/home/marc/transformers/src/transformers/models/llama/modeling_llama.py", line 443, in forward
    layer_outputs = decoder_layer(
  File "/admin/home/marc/transformers/src/transformers/modeling_layers.py", line 48, in __call__
    return super().__call__(*args, **kwargs)
  File "/admin/home/marc/miniconda3/envs/hf/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/admin/home/marc/miniconda3/envs/hf/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "/admin/home/marc/transformers/src/transformers/models/llama/modeling_llama.py", line 294, in forward
    hidden_states, self_attn_weights = self.self_attn(
  File "/admin/home/marc/transformers/src/transformers/models/llama/modeling_llama.py", line 235, in forward
    query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  File "/admin/home/marc/miniconda3/envs/hf/lib/python3.10/site-packages/bitsandbytes/nn/modules.py", line 1010, in forward
    out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
  File "/admin/home/marc/miniconda3/envs/hf/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py", line 369, in matmul
    return MatMul8bitLt.apply(A, B, out, bias, state)
  File "/admin/home/marc/miniconda3/envs/hf/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py", line 196, in forward
    CA, SCA, outlier_cols = F.int8_vectorwise_quant(A.to(torch.float16), threshold=state.threshold)
  File "/admin/home/marc/miniconda3/envs/hf/lib/python3.10/site-packages/bitsandbytes/functional.py", line 2245, in int8_vectorwise_quant
    return torch.ops.bitsandbytes.int8_vectorwise_quant.default(A, threshold)

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@matthewdouglas (Member)

Thanks for adding this!

For 4bit this is just a graph break so it does seem to be using fullgraph=True. That test would pass on torch >= 2.8.

I think this default comes from CompileConfig here?

class CompileConfig:
    """
    Class that holds arguments relative to `torch.compile` behavior, when using automatic compilation in `generate`.
    See [`torch.compile`](https://pytorch.org/docs/stable/generated/torch.compile.html) for more details on the arguments.

    Args:
        fullgraph (`bool`, *optional*, defaults to `True`):
            If `True`, requires that the whole forward be capturable in a single graph.
        dynamic (`bool` or `None`, *optional*):
            Whether to try to use dynamic shape graphs.
        backend (`str` or `Callable`, *optional*, defaults to `"inductor"`):
            Backend to be used.
        mode (`str`, *optional*, defaults to `"reduce-overhead"`):
            Controls balance between performance and overhead.
        options (`dict`, *optional*):
            A dictionary of options to pass to the backend.

If that's what we expect by default, it might be simplest to guard on torch >= 2.8 for 4bit. Otherwise we might want to find a way to improve UX, like catching this and providing a better trace/error message, or disabling fullgraph like here:

elif generation_config.compile_config.fullgraph:
    logger.warning_once(
        "When using Flash Attention 2 and a static cache, you cannot use the option `CompileConfig(fullgraph=True)` as "
        "FA2 introduces graph breaks. We overrode the option with `fullgraph=False`."
    )
    generation_config.compile_config.fullgraph = False
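By analogy, a guard for bnb 4-bit on older torch could look roughly like the following; this is a purely hypothetical sketch, and `model_is_bnb_4bit` and `torch_is_at_least_2_8` are illustrative placeholders, not existing helpers:

# Hypothetical sketch, not part of this PR: override fullgraph for bnb 4-bit on torch < 2.8.
elif model_is_bnb_4bit and not torch_is_at_least_2_8 and generation_config.compile_config.fullgraph:
    logger.warning_once(
        "bitsandbytes 4-bit models introduce graph breaks on torch < 2.8, so the option "
        "`CompileConfig(fullgraph=True)` cannot be used. We overrode the option with `fullgraph=False`."
    )
    generation_config.compile_config.fullgraph = False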

Regarding 8bit, it's expected that you need torch._dynamo.config.capture_dynamic_output_shape_ops = True unless threshold=0.0. Fortunately the error message here isn't too bad and has the right advice, but I think this is also something we should just check for in is_compileable.
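For anyone reproducing the 8-bit failure locally, either workaround below should apply; this is a sketch, and the checkpoint name is illustrative. Setting `llm_int8_threshold=0.0` disables the outlier path that produces the data-dependent output shapes:

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Option 1: allow dynamo to trace ops whose output shape depends on tensor data,
# as the error message suggests.
torch._dynamo.config.capture_dynamic_output_shape_ops = True

# Option 2: avoid the data-dependent int8_vectorwise_quant path altogether by
# disabling the outlier threshold (may change accuracy on models with outliers).
quant_config = BitsAndBytesConfig(load_in_8bit=True, llm_int8_threshold=0.0)
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",  # illustrative checkpoint
    quantization_config=quant_config,
    device_map="auto",
)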

@gante (Member) left a comment

A question about torch versions, but in general LGTM!

def is_compileable(self) -> bool:
    # Compatible with PyTorch 2.4+ for fullgraph=False.
    # Requires PyTorch 2.8 nightly for fullgraph=True.
    return version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse("0.46.0")
Member

Should we also do torch version checks here? To simplify logic, we should require torch>=2.8.0

Member

I agree that to keep this simple we can go ahead and just require torch>=2.8.0 here.

Longer term, in a separate PR, maybe we can do some refactoring here and let is_compileable consider additional context from a CompileConfig too.
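For illustration, such a combined gate might look roughly like this for the 4-bit quantizer; a sketch only, not necessarily what gets merged:

# Sketch only: combined bitsandbytes / torch version gate for the 4-bit quantizer,
# requiring torch >= 2.8.0 for fullgraph=True as discussed above.
import importlib.metadata

from packaging import version


@property
def is_compileable(self) -> bool:
    bnb_ok = version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse("0.46.0")
    torch_ok = version.parse(importlib.metadata.version("torch")) >= version.parse("2.8.0")
    return bnb_ok and torch_ok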

@@ -314,3 +314,7 @@ def _dequantize(self, model):
             model, self.modules_to_not_convert, quantization_config=self.quantization_config
         )
         return model
+
+    @property
+    def is_compileable(self) -> bool:
Member

Does this one also have minimum torch requirements, or does it work with all torch versions?

Member

Good point. The requirement here is torch>=2.4.0.
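For illustration, by the same pattern as the 4-bit sketch above, the 8-bit check would differ only in the torch floor (again a sketch, reusing the same version-parsing approach):

# Sketch only: 8-bit quantizer gate, where torch >= 2.4.0 suffices (together with
# capture_dynamic_output_shape_ops or threshold=0.0, per the discussion above).
@property
def is_compileable(self) -> bool:
    bnb_ok = version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse("0.46.0")
    torch_ok = version.parse(importlib.metadata.version("torch")) >= version.parse("2.4.0")
    return bnb_ok and torch_ok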

@SunMarc (Member, Author) commented Jul 3, 2025

Thanks for everyone's advice! I will check that later and probably merge it after torch 2.8 is out.

Labels: None yet
Projects: None yet
4 participants