In [5]:
import torch
import torch.nn as nn

from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
from open_clip import create_model_from_pretrained, get_tokenizer 

from PIL import Image
import requests

In [2]:
model = "openai/clip-vit-large-patch14-336"

image_processor = CLIPImageProcessor.from_pretrained(model)
vision_tower = CLIPVisionModel.from_pretrained(model)

In [15]:
select_layer = -2
select_feature = 'patch'
unfreeze_mm_vision_tower = False

device = vision_tower.device
dtype = vision_tower.dtype

def feature_select(image_forward_outs):
    image_features = image_forward_outs.hidden_states[select_layer]
    if select_feature == 'patch':
        image_features = image_features[:, 1:]
    elif select_feature == 'cls_patch':
        image_features = image_features
    else:
        raise ValueError(f'Unexpected select feature: {select_feature}')
    return image_features

def _forward(images):
    with torch.set_grad_enabled(unfreeze_mm_vision_tower):
        image_forward_outs = vision_tower(images.to(device=device, dtype=dtype), output_hidden_states=True)
        image_features = feature_select(image_forward_outs).to(images.dtype)
        return image_features


In [16]:
img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")

In [17]:
inputs = image_processor(raw_image, return_tensors="pt")
inputs

{'pixel_values': tensor([[[[-0.5806, -0.5806, -0.5806,  ..., -0.7850, -0.7850, -0.7704],
          [ 0.5581,  0.5581,  0.5435,  ..., -0.4054, -0.4054, -0.4200],
          [ 0.8355,  0.8501,  0.8355,  ..., -0.2010, -0.2156, -0.2156],
          ...,
          [ 0.8792,  0.9084,  0.9230,  ...,  1.0106,  0.9960,  0.9814],
          [ 0.9084,  0.8938,  0.9522,  ...,  1.0252,  1.0106,  1.0398],
          [ 0.7187,  0.6895,  0.7479,  ...,  0.6457,  0.6603,  0.6749]],

         [[-0.6715, -0.6715, -0.6715,  ..., -0.8516, -0.8516, -0.8516],
          [-0.1613, -0.1313, -0.1463,  ..., -1.0317, -1.0317, -1.0317],
          [ 0.0488,  0.0789,  0.0789,  ..., -0.8967, -0.8967, -0.8967],
          ...,
          [ 0.9343,  0.9643,  0.9793,  ...,  1.0994,  1.0844,  1.0694],
          [ 0.9643,  0.9493,  1.0093,  ...,  1.1144,  1.0994,  1.1294],
          [ 0.7842,  0.7392,  0.7992,  ...,  0.7242,  0.7392,  0.7542]],

         [[-0.4422, -0.4422, -0.4422,  ..., -0.5986, -0.5986, -0.5986],
          [ 0

In [11]:
out = _forward(inputs.pixel_values)
out

tensor([[[ 0.4228, -0.9447,  0.2154,  ..., -0.0067,  1.1019,  0.6007],
         [-0.0308,  0.3509,  0.7939,  ...,  0.6656,  0.9412, -0.7258],
         [-0.1157,  0.1328,  1.3532,  ...,  0.9491, -0.3748,  0.7546],
         ...,
         [ 0.5602,  0.2668,  0.6287,  ...,  0.1729,  0.4725, -0.0097],
         [ 0.6587, -1.0553, -0.8035,  ..., -0.1626,  0.2400,  1.1361],
         [ 0.4978,  0.1702, -0.3934,  ...,  0.4429,  0.3947, -0.2092]]])

In [12]:
out.shape

torch.Size([1, 576, 1024])

In [20]:
images = inputs.pixel_values

image_forward_outs = vision_tower(images.to(device=device, dtype=dtype), output_hidden_states=True)

In [21]:
type(image_forward_outs)

transformers.modeling_outputs.BaseModelOutputWithPooling

In [22]:
image_forward_outs.keys()

odict_keys(['last_hidden_state', 'pooler_output', 'hidden_states'])

In [23]:
image_forward_outs.last_hidden_state.shape

torch.Size([1, 577, 1024])

In [30]:
for i in range(len(image_forward_outs.hidden_states)):
    print(i, image_forward_outs.hidden_states[i].shape)

0 torch.Size([1, 577, 1024])
1 torch.Size([1, 577, 1024])
2 torch.Size([1, 577, 1024])
3 torch.Size([1, 577, 1024])
4 torch.Size([1, 577, 1024])
5 torch.Size([1, 577, 1024])
6 torch.Size([1, 577, 1024])
7 torch.Size([1, 577, 1024])
8 torch.Size([1, 577, 1024])
9 torch.Size([1, 577, 1024])
10 torch.Size([1, 577, 1024])
11 torch.Size([1, 577, 1024])
12 torch.Size([1, 577, 1024])
13 torch.Size([1, 577, 1024])
14 torch.Size([1, 577, 1024])
15 torch.Size([1, 577, 1024])
16 torch.Size([1, 577, 1024])
17 torch.Size([1, 577, 1024])
18 torch.Size([1, 577, 1024])
19 torch.Size([1, 577, 1024])
20 torch.Size([1, 577, 1024])
21 torch.Size([1, 577, 1024])
22 torch.Size([1, 577, 1024])
23 torch.Size([1, 577, 1024])
24 torch.Size([1, 577, 1024])


In [34]:
(image_forward_outs.hidden_states[-2] == image_forward_outs.last_hidden_state)

tensor([[[False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         ...,
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False]]])