Why don't the ViT-L/14 models in (blip2 pretrain_vitL) and (blip2_t5 pretrain_flant5xl_vitL) have the same number of layers as when instantiating a BLIP2 model with vit_model = 'clip_L'? #609

Open
ReinforcedKnowledge opened this issue Dec 10, 2023 · 0 comments


Hi!

I have checked the vision part of both blip2 pretrain_vitL and blip2_t5 pretrain_flant5xl_vitL and noticed that each contains 21 residual attention blocks.
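In case it helps reproduce the count, here is roughly how I read the number of blocks off the parameter names. It assumes the residual attention blocks show up in the state dict under a "resblocks.<i>" prefix, which I believe is the naming used by the CLIP ViT implementation in LAVIS; adjust the pattern if your keys differ:

import re
from lavis.models import load_model

model = load_model("blip2", "pretrain_vitL")

# Collect the distinct block indices that appear in the visual encoder's parameter names
block_ids = set()
for k in model.state_dict():
    m = re.search(r"resblocks\.(\d+)\.", k)
    if m is not None and k.startswith("visual_encoder."):
        block_ids.add(int(m.group(1)))

print(len(block_ids))  # 21 for blip2 pretrain_vitL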

Meanwhile, when instantiating a base BLIP2 model, we can specify the vision encoder to be a "clip_L" model, which is done by init_vision_encoder. As you can see, it creates the ViT-L model by calling the function create_clip_vit_L, whose code builds 23 layers. I could confirm this by downloading the model's .pth file and loading it.
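For reference, this is the same kind of count applied to the raw checkpoint. The path below is just a placeholder for wherever the downloaded clip_vit_L .pth file lives, and the "resblocks.<i>" key pattern is the same assumption as above:

import re
import torch

clip_vit_l_path = "/path/to/clip_vit_L.pth"  # placeholder: path to the downloaded checkpoint
clip_vitL_state_dict = torch.load(clip_vit_l_path, map_location="cpu")

# Collect the distinct layer indices from the checkpoint's parameter names
layer_ids = set()
for k in clip_vitL_state_dict:
    m = re.search(r"resblocks\.(\d+)\.", k)
    if m is not None:
        layer_ids.add(int(m.group(1)))

print(len(layer_ids))  # 23 with the checkpoint I downloaded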

I ran the following code to check if the weights in the vision encoder of blip2 pretrain_vitL are the same as their counterparts in the clip ViT-L:

import torch
from lavis.models import load_model

model = load_model("blip2", "pretrain_vitL")
blip2_pretrain_vitl_state_dict = model.state_dict()

# clip_vit_l_path points to the downloaded clip_vit_L .pth checkpoint
clip_vitL_state_dict = torch.load(clip_vit_l_path, map_location="cpu")

# Verify that every visual encoder parameter name also exists in the CLIP ViT-L checkpoint
ve_blip2_pretrain_vitl_keys = [
    k.replace("visual_encoder.", "")
    for k in blip2_pretrain_vitl_state_dict
    if k.startswith("visual_encoder.")
]
all_keys_in = sum(1 for k in ve_blip2_pretrain_vitl_keys if k in clip_vitL_state_dict)
print(len(ve_blip2_pretrain_vitl_keys) == all_keys_in)

# Verify that the shared weights are identical; print any key whose tensors differ
for k in ve_blip2_pretrain_vitl_keys:
    if not torch.equal(blip2_pretrain_vitl_state_dict[f"visual_encoder.{k}"], clip_vitL_state_dict[k]):
        print(k)
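For what it's worth, a quick set difference over the same two state dicts (reusing the variables from the snippet above) shows which parameters exist only on one side:

# Parameters present in only one of the two state dicts
only_in_clip = sorted(set(clip_vitL_state_dict) - set(ve_blip2_pretrain_vitl_keys))
only_in_blip2 = sorted(set(ve_blip2_pretrain_vitl_keys) - set(clip_vitL_state_dict))
print("keys only in clip_vit_L:", only_in_clip)
print("keys only in blip2 pretrain_vitL visual encoder:", only_in_blip2)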

I was wondering why there is such a difference.
