[gemma 3] multimodal checkpoints + AutoModelForCausalLM #36741

Merged: 4 commits into huggingface:main on Mar 19, 2025

Conversation

@gante (Member) commented Mar 15, 2025:

What does this PR do?

Fixes #36683

Enables AutoModelForCausalLM with multimodal checkpoints, to load a decoder-only text model.

E.g. the following now works (before this PR, it failed with the AttributeError: 'Gemma3Config' object has no attribute 'vocab_size' reported in #36683):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-4b-it")
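# With this PR, this loads a decoder-only text model from the multimodal checkpoint.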
model = AutoModelForCausalLM.from_pretrained("google/gemma-3-4b-it", device_map="auto")

inputs = tokenizer(["The quick brown"], return_tensors="pt").to(model.device)
gen_out = model.generate(**inputs, do_sample=False)
print(tokenizer.batch_decode(gen_out, skip_special_tokens=True))

@github-actions bot marked this pull request as draft March 15, 2025 13:17

Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. When it is ready for review, please click the Ready for review button (at the bottom of the PR page).

@gante marked this pull request as ready for review March 15, 2025 13:17
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@gante requested review from LysandreJik, zucchini-nlp and ArthurZucker, and removed the request for Rocketknight1 and ArthurZucker, March 15, 2025 14:12
    Under the hood, multimodal models mapped by AutoModelForCausalLM assume the text decoder receives its own
    config, rather than the config for the whole model.
    """
    return config.get_text_config(decoder=True)
@gante (Member Author):

(this is a no-op on models without nested configs, and catches all known nested names for text decoders. See get_text_config:

def get_text_config(self, decoder=False) -> "PretrainedConfig":
)
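
For context, a minimal sketch of what get_text_config roughly does; the attribute names below are illustrative, see the actual method on PretrainedConfig for the full logic:

def get_text_config(self, decoder=False):
    # Sketch only: attribute names under which a text decoder config
    # may be nested inside a composite (e.g. multimodal) config.
    possible_text_config_names = ("decoder", "generator", "text_config")
    found = [
        name for name in possible_text_config_names
        if getattr(self, name, None) is not None
    ]
    if len(found) == 1:
        # Nested text config found: return the decoder's own config.
        return getattr(self, found[0])
    # No nested config: the config itself is the text config (no-op).
    return self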

@zucchini-nlp (Member) left a comment:

Thanks for the PR! I just got to my laptop and tried to reproduce the issue; it seems to already work on the main branch. At least for me, it maps the config to the TextConfig, since we have both config types in the auto map.

I also checked on the release branch; same thing, it works smoothly. Am I missing something?

@gante (Member Author) commented Mar 17, 2025:

@zucchini-nlp that's curious; this means the issue is possibly setup- or commit-dependent 🤔

On my machine, the test script above fails both on main (c53d53da89c0617f7dd5a69a2a08e6b1232b35fd) and on the gemma3 release commit (46350f5eae87ac1d168ddfdc57a0b39b64b9a029), both on CPU and GPU, with the same error as reported by users.

Which commit did you try it on?

@zucchini-nlp (Member) left a comment:

Ah, I see: the test here uses a 4b checkpoint, and I used a 1b checkpoint. The config for 1b was saved through CausalLM, so it has the correct "model_type", while 4b has a multimodal config type.

My bad! Now I see the reason, but I don't agree that we should let users load a multimodal model with the AutoModelForCausalLM handle. Even if the model can work with text only, which is true for all VLMs, the user is expected to load AutoModelForImageTextToText and feed whatever inputs they want (any combination of text and image). We will be getting more modalities per model soon, with models like MiniCPM, and imo it doesn't scale to make those loadable with auto classes for any subset of supported modalities.

For example, an any-to-text model with "audio+image+text" would not be expected to map to the "image+text" auto class, if that makes sense. Currently we only have the image+text mapping, but I believe we'll expand that list soon, starting with video, which I have planned for this year.

So imo this is user error, and it has been so for all VLMs up to now. If for some reason Gemma3ForConditionalGeneration doesn't accept text-only input, we can fix that. Lmk if you think otherwise 🤗
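
(For concreteness, the loading path described above is roughly the following sketch, assuming the Gemma3 processor accepts text-only input:)

from transformers import AutoModelForImageTextToText, AutoProcessor

processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it")
model = AutoModelForImageTextToText.from_pretrained("google/gemma-3-4b-it", device_map="auto")

# Text-only usage of the multimodal class: no images are passed.
inputs = processor(text=["The quick brown"], return_tensors="pt").to(model.device)
gen_out = model.generate(**inputs, do_sample=False)
print(processor.batch_decode(gen_out, skip_special_tokens=True))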

@gante (Member Author) commented Mar 17, 2025:

There are two angles to this question:

  1. What we communicate to our users;
  2. Whether we should enable it or not.

Regarding 1:

In the release notes we write "For text-only inputs use Gemma3ForCausalLM for generation to avoid loading the vision tower.", so we expect the model to be used as text-only. Both gemma3 and gemma3_text are in the auto map for AutoModelForCausalLM. As such, from our side, we indicate to the user that any gemma3 model is compatible with AutoModelForCausalLM, implicitly (can be used as text-only model) and explicitly (it's part of the auto map).

It is not compatible at the moment, so we must choose whether to enable it or to update the other parts accordingly (including, perhaps, adding exceptions) 🤗


Regarding 2:

If gemma3 supports text-only, I'm deeply in favor of enabling it with AutoModelForCausalLM:

  • Gemma3ForCausalLM is supported, and it is the same (minus the ease of use); see the sketch below
  • Many downstream libraries use auto classes and don't want to bother with model-specific imports. It makes their development cycle much easier.
  • Without AutoModelForCausalLM, the alternative auto class loads the vision tower (= waste of memory)
  • The changes are minimal, as we can see in this PR

I don't think we should do it for all combinations of modalities, not unless there is a demand for it. However, text-only generation is the most used task in our library and on the hub, and it is expected that new models will often be text + other modalities. As such, if we don't enable this, downstream libraries/users will have to pick between model-specific imports and additional memory usage.
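
(Sketch of the first bullet's model-specific route, assuming direct loading from the multimodal checkpoint works as the release notes describe:)

from transformers import AutoTokenizer, Gemma3ForCausalLM

tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-4b-it")
# Same text-only decoder as the AutoModelForCausalLM snippet above,
# but tied to a model-specific import.
model = Gemma3ForCausalLM.from_pretrained("google/gemma-3-4b-it", device_map="auto")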

For me, the advantages are very clear :)

@zucchini-nlp (Member) commented:

Regarding the communication: I see how there might have been some confusion. The idea was that Gemma3 has a 1B checkpoint trained text-only, so users are expected to load that one with CausalLM.

If Gemma3 supports text-only, I'm deeply in favor of enabling it with AutoModelForCausalLM

I understand the reasoning behind this, but I think we should also consider consistency with other VLMs. Gemma3 and Mllama are the only models that have a separate CausalLM class. Even though any other VLM can work in a text-only setting with decent quality, we didn't add a CausalLM class for them, since the released models didn't have a text-only checkpoint. If we decide to introduce CausalLM for all models that support text, it would make sense to extend the same approach to other VLMs as well.

As a reference point, LLaVA users have been running the model in text-only mode by loading the ForConditionalGeneration model, and we've enabled it by ensuring the processor doesn't raise errors when images=None, which is standard by now for all processors AFAIR.
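
(As a sketch of that pattern, with a common llava-hf checkpoint assumed for illustration:)

from transformers import AutoProcessor, LlavaForConditionalGeneration

processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")

# Text-only mode: the processor tolerates images=None instead of raising.
inputs = processor(text="The quick brown", images=None, return_tensors="pt")
out = model.generate(**inputs, do_sample=False)
print(processor.batch_decode(out, skip_special_tokens=True))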

Without AutoModelForCausalLM, the alternative auto class loads the vision tower (= waste of memory)

Valid concern, though in practice the additional memory overhead is relatively small (200-500M parameters with SigLIP). While not ideal, this may not significantly impact users, especially since they are intentionally choosing a multimodal model for text-only tasks.

Many downstream libraries use auto classes and don't want to bother with model-specific imports. It makes their development cycle much easier

I think this is the strongest argument for adding CausalLM. However, I'd like to clarify whether this is an actual limitation of vLLM (i.e. it cannot generate from text-only inputs with Gemma3ForConditionalGeneration), or whether this is more about user expectations shaped by our documentation. Given that previous multimodal releases didn't raise similar concerns, I suspect this is more of a communication issue than a technical limitation on vLLM's side.

If we decide to add CausalLM for gemma3 multimodal models, we would need to ensure consistency across other models. This shouldn't be too difficult, since part of the standardization process for vLLM already involves separating the BaseModel and the LM head for all multimodal models, though it may add some redundancy imho.

Would love to hear thoughts on this! If you agree with the above, I would be in favor of shaping our communication better next time and being more explicit in the docs. I feel that, in general, the docs for multimodality lack a lot of information, but I wanted to make sure we have a consistent API first before adding a section for multimodal LLMs.

@@ -539,6 +544,10 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        if kwargs_orig.get("quantization_config", None) is not None:
            kwargs["quantization_config"] = kwargs_orig["quantization_config"]

        # AutoClass-specific config manipulation
        config = copy.deepcopy(config)
        config = cls._prepare_config_for_auto_class(config)
Collaborator:

we could also just get the text config here, no?

Collaborator:

(I mean, a one-liner!)

@gante (Member Author):

We only want to pull the text config in AutoModelForCausalLM, and this function is shared across all AutoModel classes. This means that we either:

  • use inheritance (the commit you reviewed), or
  • add an if self.__class__.__name__ == "AutoModelForCausalLM" check with the logic.

I think inheritance is cleaner, but I'm happy to change it to the alternative version :)
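
(Roughly, the two options; the class and hook names follow this PR's diff, the bodies are simplified sketches:)

# Option 1, inheritance (the approach in the reviewed commit): the shared
# hook is a no-op, and AutoModelForCausalLM overrides it.
class _BaseAutoModelClass:
    @classmethod
    def _prepare_config_for_auto_class(cls, config):
        return config  # no-op for most auto classes

class AutoModelForCausalLM(_BaseAutoModelClass):
    @classmethod
    def _prepare_config_for_auto_class(cls, config):
        # Multimodal checkpoint: hand the text decoder its own config.
        return config.get_text_config(decoder=True)

# Option 2, one shared hook with an explicit class-name check:
#     if cls.__name__ == "AutoModelForCausalLM":
#         config = config.get_text_config(decoder=True)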

Collaborator:

No worries, this is fine! I just want whatever is simpler!


@gante added the for patch label (Tag issues / labels that should be included in the next patch) Mar 19, 2025
@gante merged commit 7c23398 into huggingface:main Mar 19, 2025
24 checks passed
@graftim commented Mar 20, 2025:

Thanks a lot for this PR 🙏 . Any chance this can be added to the v4.49.0-Gemma-3 tag? Or is 4.50 around the corner anyway?

@gante (Member Author) commented Mar 20, 2025:

Or is 4.50 around the corner anyway?

yes, it is 😉

@gante deleted the gemma3_automodelforcausallm branch March 20, 2025 10:55
ArthurZucker pushed a commit that referenced this pull request Mar 25, 2025

* deprecate the prev fix

* reword warning and update docs

* reword warning

* tests

* dont bloat `get_text_config()`
Labels
for patch (Tag issues / labels that should be included in the next patch)

Development
Successfully merging this pull request may close these issues:
AttributeError: 'Gemma3Config' object has no attribute 'vocab_size'

5 participants