[gemma 3] multimodal checkpoints + AutoModelForCausalLM #36741
Conversation
Under the hood, multimodal models mapped by AutoModelForCausalLM assume the text decoder receives its own
config, rather than the config for the whole model.
"""
return config.get_text_config(decoder=True)
(This is a no-op on models without nested configs, and catches all known nested names for text decoders.) See get_text_config:

def get_text_config(self, decoder=False) -> "PretrainedConfig":
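For illustration, a minimal sketch of how get_text_config behaves on nested vs. flat configs (the checkpoint names are assumptions and the repos are gated, so access may require authentication):

from transformers import AutoConfig

# Nested (multimodal) config: returns the inner text decoder config.
config = AutoConfig.from_pretrained("google/gemma-3-4b-it")
text_config = config.get_text_config(decoder=True)  # the nested text config, not the full config

# Flat (text-only) config: the call is a no-op and returns the config itself.
config = AutoConfig.from_pretrained("google/gemma-3-1b-it")
print(config.get_text_config(decoder=True) is config)  # expected: True, nothing nested to extract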
Thanks for the PR! I just got to my laptop and tried to reproduce the issue, and it seems like it is already working on the main branch. At least for me it maps the config to the TextConfig, since we have both config types in the auto-map.
Also checked on the release branch, same thing, works smoothly. Am I missing something?
@zucchini-nlp that's curious, this means the issue is possibly setup- or commit-dependent 🤔 On my machine, the test script above fails on both branches. What was the commit you tried it on?
Ah I see, the test here uses a 4b checkpoint and I used a 1b checkpoint. The config for the 1b was saved through CausalLM, so it has the correct "model_type", while the 4b has a multimodal config type.
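To make the difference concrete, a quick way to compare the two configs is to inspect their model_type fields (checkpoint names are assumptions; the exact strings are illustrative, based on this thread):

from transformers import AutoConfig

# 1b: config was saved through the CausalLM class, so it is a flat text config.
print(AutoConfig.from_pretrained("google/gemma-3-1b-it").model_type)  # e.g. "gemma3_text"

# 4b: multimodal config type, with the text decoder nested under text_config.
print(AutoConfig.from_pretrained("google/gemma-3-4b-it").model_type)  # e.g. "gemma3"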
My bad! Now I see the reason, and I don't agree we should be letting users load a multimodal model with the AutoModelForCausalLM handle. Even if the model can work with text only, which is true for all VLMs, the user is expected to load AutoModelForImageTextToText and feed the type of inputs they want (any combination of text or image). We will be getting more modalities per model soon, with models like MiniCPM, and imo it doesn't scale to let those be loadable with auto classes for any subset of supported modalities.
For example, an any-to-text model with "audio+image+text" will not be expected to map with the "image+text" auto class, if that makes sense. Though currently we only have the image+text mapping, I believe we'll expand that list soon, starting with video, which I have planned for this year.
So imo, that is user error and it has been so for all VLMs up to now. If for some reason Gemma3ForConditionalGeneration doesn't accept text-only input, we can fix that. Lmk if you think otherwise 🤗
There are two angles to this question: 1) should a multimodal checkpoint be loadable with AutoModelForCausalLM for text-only use, and 2) should we do this for all combinations of modalities?
Regarding 1: In the release notes we write "For text-only inputs use Gemma3ForCausalLM for generation to avoid loading the vision tower.", so we expect the model to be used as text-only. It is not compatible at the moment, so we must choose whether to enable it or update other parts accordingly (perhaps including exceptions) 🤗
Regarding 2: I don't think we should do it for all combinations of modalities, not unless there is demand for it. However, text-only generation is the most used task in our library and on the Hub, and it is expected that new models will often be text + other modalities. As such, if we don't enable this, downstream libraries/users will have to pick between model-specific imports and additional memory usage. For me, the advantages are very clear :)
Regarding the communication, I see how there might have been some confusion. The idea was that Gemma3 has a 1B checkpoint trained on text-only, so users are expected to load that one with Gemma3ForCausalLM.
I understand the reasoning behind this, but I think we should also consider consistency with other VLMs. Gemma3 and Mllama are the only models that have a separate text-only class. As a reference point, LLaVa users have been running the model in text-only mode by loading the full multimodal class.
Valid concern, though in practice the additional memory overhead is relatively small (200-500M params with SigLIP). While not ideal, this may not significantly impact users, especially since they are intentionally choosing a multimodal model for text-only tasks.
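As a rough way to see that overhead in practice, one could compare parameter counts for the two loading paths (a sketch, assuming the google/gemma-3-4b-it checkpoint and the AutoModelForImageTextToText auto class):

from transformers import AutoModelForCausalLM, AutoModelForImageTextToText

text_only = AutoModelForCausalLM.from_pretrained("google/gemma-3-4b-it")    # text decoder only
full = AutoModelForImageTextToText.from_pretrained("google/gemma-3-4b-it")  # decoder + vision tower

print(sum(p.numel() for p in text_only.parameters()))
print(sum(p.numel() for p in full.parameters()))  # larger by roughly the SigLIP vision tower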
I think this is the strongest argument for adding it. If we decide to add it, I would love to hear thoughts on this! If you agree with the above, I would be pro shaping our communication better next time and being more explicit in the docs. I feel like in general the docs for multimodality lack a lot of information, but I wanted first to make sure we have a consistent API before adding a section on it.
@@ -539,6 +544,10 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        if kwargs_orig.get("quantization_config", None) is not None:
            kwargs["quantization_config"] = kwargs_orig["quantization_config"]

        # AutoClass-specific config manipulation
        config = copy.deepcopy(config)
        config = cls._prepare_config_for_auto_class(config)
We could also just get the text config here, no?
(I mean, no one-liner!)
We only want to pull the text config in AutoModelForCausalLM, and this function is shared across all AutoModel classes. This means that we either:
- do inheritance (the commit you reviewed)
- add an if self.__class__.__name__ == "AutoModelForCausalLM" branch with the logic
I think inheritance is cleaner, but I'm happy to change it to the alternative version :)
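For context, a minimal sketch of the inheritance option (class names other than get_text_config are illustrative assumptions, not the PR's exact code):

class _BaseAutoModelClass:
    @classmethod
    def _prepare_config_for_auto_class(cls, config):
        # Default: most auto classes use the checkpoint config as-is.
        return config

class _BaseAutoModelForCausalLM(_BaseAutoModelClass):
    @classmethod
    def _prepare_config_for_auto_class(cls, config):
        # Causal LM auto class: use the (possibly nested) text decoder config,
        # so a multimodal checkpoint resolves to its text-only decoder model.
        return config.get_text_config(decoder=True)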
No worries, this is fine! I just want whatever is simpler!
Thanks a lot for this PR 🙏. Any chance this can be added to the v4.49.0-Gemma-3 tag? Or is 4.50 around the corner anyway?
yes, it is 😉
What does this PR do?
Fixes #36683
Enables AutoModelForCausalLM with multimodal checkpoints, so that it loads the decoder-only text model. E.g. the following now works:
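A minimal sketch of such a snippet, assuming the google/gemma-3-4b-it multimodal checkpoint:

from transformers import AutoModelForCausalLM

# Multimodal (image+text) checkpoint: AutoModelForCausalLM now resolves it to the
# text decoder, so the vision tower is not loaded.
model = AutoModelForCausalLM.from_pretrained("google/gemma-3-4b-it")
print(type(model).__name__)  # expected: Gemma3ForCausalLM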