
Add kwargs for timm.create_model in TimmWrapper #38860


Merged · 6 commits merged into huggingface:main on Jun 20, 2025

Conversation

@qubvel (Member) commented Jun 17, 2025

What does this PR do?

Add kwargs for timm.create_model in TimmWrapper.

requested in:

related to:

The kwargs are added to the config so they get saved with it. timm will read them at create time, and for models loaded without extra kwargs nothing breaks.

To init from a config:

from transformers import TimmWrapperConfig, TimmWrapperModel

config = TimmWrapperConfig.from_pretrained(
  "timm/vit_base_patch32_clip_448.laion2b_ft_in12k_in1k",
  model_args={"depth": 3},
)
model = TimmWrapperModel(config)

To init a model from pretrained weights:

model = TimmWrapperModel.from_pretrained(
  "timm/vit_base_patch32_clip_448.laion2b_ft_in12k_in1k",
  model_args={"depth": 3},
)

cc @zucchini-nlp @rwightman

@qubvel qubvel marked this pull request as ready for review June 17, 2025 12:08
@qubvel (Member, Author) commented Jun 17, 2025

run-slow: timm_wrapper

Contributor

This comment contains run-slow, running the specified jobs:

models: ['models/timm_wrapper']
quantizations: [] ...

@qubvel qubvel requested a review from zucchini-nlp June 17, 2025 12:10
@qubvel qubvel added the Vision label Jun 17, 2025
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@zucchini-nlp (Member) left a comment

Super fast, wow. Thanks a lot!

I don't fully understand how we load timm models, so just to make sure: if we do something like AutoModel.from_config(config) where the config has many attributes, do we pick up all the attributes?

I see that the timm config internally saves them in model_init_kwargs, but I don't know where that comes from when doing auto mapping.

@qubvel (Member, Author) commented Jun 17, 2025

Not sure I got the question, but if the config has model_init_kwargs defined, then yes. However, original timm models don't have this keyword defined by default, so model_init_kwargs has to be passed explicitly or saved in the config in advance.

e.g. for a VLM we would have TimmWrapperConfig as a sub-config, which has to define model_init_kwargs (such as depth, for example)
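For illustration, a minimal sketch of the "saved in the config in advance" path (hypothetical; it uses the model_init_kwargs name as it stood at this point in the thread, later renamed to model_args, and the save directory name is illustrative):

from transformers import TimmWrapperConfig

# Hypothetical sketch: persist the extra timm kwargs on the config itself,
# so a composite (e.g. VLM) config can carry them as its vision sub-config
# without the user re-passing them at load time.
config = TimmWrapperConfig.from_pretrained(
    "timm/vit_base_patch32_clip_448.laion2b_ft_in12k_in1k",
    model_init_kwargs={"depth": 3},
)
config.save_pretrained("vision_tower_config")  # kwargs are now stored in config.json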

@zucchini-nlp (Member) commented

I mean, if we have a config like config(hidden_dim=128, activation="gelu") and want to use it to initialize a timm model, will that work? Or does the config have to be config(model_init_kwargs={"hidden_dim": 128, "activation": "gelu"})?

@qubvel (Member, Author) commented Jun 17, 2025

Yes, the config has to be

TimmWrapperConfig(
   ...
   model_init_kwargs={"hidden_dim": 128, activation:"gelu"}
)

because kwargs differ across timm models, e.g. the depth argument is supported by vit but not by mobilenet
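To illustrate the difference, a minimal sketch against timm directly (model names are illustrative and the exact error may vary by timm version):

import timm

# depth is a constructor argument for ViT, so overriding it at create time works
vit = timm.create_model("vit_base_patch16_224", depth=3)

# MobileNet builders have no depth kwarg, so the same override fails
try:
    timm.create_model("mobilenetv3_small_100", depth=3)
except TypeError as err:
    print(err)  # unexpected keyword argument 'depth'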

@zucchini-nlp (Member) left a comment

Okay, maybe that makes sense. I actually had the idea that TimmWrapperConfig is an abstraction and each model would inherit from it and add its own attributes.

Thanks!

@qubvel qubvel requested a review from rwightman June 17, 2025 17:35
@rwightman (Contributor) commented

@qubvel a few thoughts/questions:

Does this also work with AutoModel*.from_pretrained?

Does this save back to the config? I feel that's something that would make sense, but maybe it doesn't in some contexts?

model = TimmWrapperModel.from_pretrained(
  "timm/vit_base_patch32_clip_448.laion2b_ft_in12k_in1k",
  model_init_kwargs={"depth": 3},
)

model.save_pretrained()  # does this save with args so the result can be loaded without the args? 

timm's config.json has a 'model_args' key which can store a dict of kwarg overrides that is applied at model create time (for users to create custom / altered architectures). I feel this should align with that somehow?

@qubvel (Member, Author) commented Jun 18, 2025

@rwightman, thanks for the review, great questions

Does this also work with AutoModel*.from_pretrained?

yes
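For example, a minimal sketch (using the model_args name adopted later in this PR and the same checkpoint as above):

from transformers import AutoModel

# extra timm kwargs are forwarded through the auto class as well
model = AutoModel.from_pretrained(
    "timm/vit_base_patch32_clip_448.laion2b_ft_in12k_in1k",
    model_args={"depth": 3},
)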

Does this save back to a config, I feel that's something which would make sense but maybe it doesn't in some contexts?

yes, that was the idea behind adding this to the config rather than passing it as **kwargs to TimmWrapper.__init__

(I also added a test to check that save-restore with extra args works as expected)

timm's config.json has a 'model_args' key which can store a dict of kwarg overrides that is applied at model create time (for users to create custom / altered architectures). I feel this should align with that somehow?

Perfect, I didn't know that, thanks for sharing 👍 I changed model_init_kwargs -> model_args, and now the model can be loaded back into timm with the extra args applied:

import timm
import tempfile
from transformers import TimmWrapperConfig, TimmWrapperModel

config = TimmWrapperConfig.from_pretrained(
    "timm/vit_base_patch32_clip_448.laion2b_ft_in12k_in1k",
    model_args={"depth": 3},
)
model = TimmWrapperModel(config)

with tempfile.TemporaryDirectory() as tmpdirname:
    model.save_pretrained(tmpdirname)
    restored_hf_model = TimmWrapperModel.from_pretrained(tmpdirname)
    assert len(restored_hf_model.timm_model.blocks) == 3

    restored_timm_model = timm.create_model(f"local-dir:{tmpdirname}")
    assert len(restored_timm_model.blocks) == 3

@rwightman (Contributor) commented

@qubvel great... looks good with the model_init_kwargs -> model_args change

@qubvel qubvel enabled auto-merge (squash) June 20, 2025 11:48
@qubvel qubvel merged commit 9120567 into huggingface:main Jun 20, 2025
15 checks passed