Add kwargs for timm.create_model in TimmWrapper #38860
Conversation
run-slow: timm_wrapper

This comment contains run-slow, running the specified jobs: models: ['models/timm_wrapper']
Super fast, wow. Thanks a lot!

I don't fully understand how we load timm models, so just to make sure: if we do something like AutoModel.from_config(config) where the config has many attributes, do we pick up all of them? I see that the timm config internally saves them in model_init_kwargs, but I don't know where that comes from when doing auto mapping.
Not sure I got the question, but in case the config has sub-configs, e.g. for a VLM we would have TimmWrapperConfig as a sub-config, which has to define them.

I mean if we have a config as ...

Yes, the config has to be

```python
TimmWrapperConfig(
    ...,
    model_init_kwargs={"hidden_dim": 128, "activation": "gelu"},
)
```

because kwargs are different for different timm models.
Okay, maybe that makes sense. I actually had the idea that TimmWrapperConfig is an abstraction and each model would inherit from it and add its own attributes.

Thanks!
@qubvel a few thoughts/questions:

- Does this also work with AutoModel*.from_pretrained?
- Does this save back to a config? I feel that's something which would make sense, but maybe it doesn't in some contexts?
- timm's config.json has a 'model_args' key which can store a dict of kwarg overrides that's applied at model create time (for users to create custom/altered architectures). I feel this should align with that somehow? See the sketch after this list.
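For concreteness, a hypothetical illustration of that mechanism; the keys and values here are made up for the example, not read from a real checkpoint:

```python
import json

# A timm-style config.json can carry a "model_args" dict of kwarg overrides
# that timm applies on top of the architecture defaults at create time.
config = {
    "architecture": "vit_base_patch32_224",
    "num_classes": 1000,
    "model_args": {"depth": 3},  # overrides forwarded to timm.create_model
}
print(json.dumps(config, indent=2))
```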
@rwightman, thanks for the review, great questions

> Does this also work with AutoModel*.from_pretrained?

Yes

> Does this save back to a config?

Yes, that was the idea of adding it to the config rather than passing it as plain kwargs (also added a test to check that save/restore with extra args works as expected)

> timm's config.json has a 'model_args' key ... I feel this should align with that somehow?

Perfect, I didn't know that, thanks for sharing 👍 Changed `model_init_kwargs` -> `model_args`:

```python
import timm
import tempfile
from transformers import TimmWrapperConfig, TimmWrapperModel

config = TimmWrapperConfig.from_pretrained(
    "timm/vit_base_patch32_clip_448.laion2b_ft_in12k_in1k",
    model_args={"depth": 3},
)
model = TimmWrapperModel(config)

with tempfile.TemporaryDirectory() as tmpdirname:
    model.save_pretrained(tmpdirname)

    # extra args survive the save/load round-trip on the transformers side
    restored_hf_model = TimmWrapperModel.from_pretrained(tmpdirname)
    assert len(restored_hf_model.timm_model.blocks) == 3

    # ... and the saved config is readable directly by timm as well
    restored_timm_model = timm.create_model(f"local-dir:{tmpdirname}")
    assert len(restored_timm_model.blocks) == 3
```
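To make the AutoModel answer above concrete, a minimal sketch; that the auto mapping resolves a timm checkpoint to TimmWrapperModel and forwards `model_args` this way is my reading of the discussion, not a snippet from the PR:

```python
from transformers import AutoModel

# Assumed behavior: the Auto* mapping routes this timm checkpoint to
# TimmWrapperModel and passes model_args through to its config.
model = AutoModel.from_pretrained(
    "timm/vit_base_patch32_clip_448.laion2b_ft_in12k_in1k",
    model_args={"depth": 3},
)
assert len(model.timm_model.blocks) == 3
```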
@qubvel great... looks good w/ the `model_init_kwargs` -> `model_args` change
What does this PR do?
Add kwargs for timm.create_model in TimmWrapper.
Requested in:
Related to:
Kwargs are added to the config so they get saved with it. timm will read them, and for models without extra kwargs nothing breaks.

To init from config (see the sketch below):
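A minimal sketch of this path, reusing the checkpoint and override from the discussion above:

```python
from transformers import TimmWrapperConfig, TimmWrapperModel

# model_args is stored on the config and forwarded to timm.create_model
config = TimmWrapperConfig.from_pretrained(
    "timm/vit_base_patch32_clip_448.laion2b_ft_in12k_in1k",
    model_args={"depth": 3},
)
model = TimmWrapperModel(config)
assert len(model.timm_model.blocks) == 3
```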
To init a model from pretrained (see the sketch below):
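A sketch of the pretrained path; that `from_pretrained` forwards `model_args` down to the config like this is my assumption based on the discussion, not a line quoted from the PR:

```python
from transformers import TimmWrapperModel

# Assumed: extra kwargs passed to from_pretrained land in the config's
# model_args and are applied when the underlying timm model is created.
model = TimmWrapperModel.from_pretrained(
    "timm/vit_base_patch32_clip_448.laion2b_ft_in12k_in1k",
    model_args={"depth": 3},
)
assert len(model.timm_model.blocks) == 3
```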
cc @zucchini-nlp @rwightman