In [2]:
import pandas as pd
import numpy as np
import torch
from torch import nn 
from torch.utils.data import DataLoader     # 데이터로더는 데이터셋을 iterable하게 감싸는 역할
from torchvision import datasets            # 데이터셋은 샘플과 정답을 저장함
from torchvision.transforms import ToTensor
import clip
from PIL import Image

In [4]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

image_title = 'A horse in the space.png'

image = preprocess(Image.open(image_title)).unsqueeze(0).to(device)
text = clip.tokenize("A A horse in the space").to(device)

In [13]:
model.visual.proj

Parameter containing:
tensor([[-2.6264e-03,  5.0962e-05,  2.7496e-02,  ..., -1.0025e-02,
         -1.2222e-02,  5.8403e-03],
        [-1.9852e-02,  7.1182e-03,  8.9788e-04,  ...,  1.1528e-02,
         -1.9485e-02, -8.0185e-03],
        [-8.6288e-03,  1.9226e-03, -2.1725e-03,  ...,  3.9330e-03,
         -1.1269e-02,  1.5345e-03],
        ...,
        [-1.1993e-02,  1.2955e-02,  2.5848e-02,  ..., -9.8038e-03,
         -4.2076e-03,  1.5211e-04],
        [-1.2871e-02, -9.5673e-03, -1.0826e-02,  ..., -7.0610e-03,
         -4.3182e-03, -4.9353e-04],
        [-4.4098e-03,  3.3588e-03, -1.2054e-02,  ...,  6.1073e-03,
          3.9940e-03, -3.0861e-03]], device='cuda:0', dtype=torch.float16,
       requires_grad=True)

In [5]:
img_embeddings = [torch.zeros([1, 512]) for i in range(model.visual.transformer.layers)]

#%% model.encode_image

# image -> tokens
x = model.visual.conv1(image.type(model.visual.conv1.weight.dtype))
x = x.reshape(x.shape[0], x.shape[1], -1)   # shape = [*, width, grid ** 2]
x = x.permute(0, 2, 1)                      # shape = [*, grid ** 2, width]
x = torch.cat([model.visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
x = x + model.visual.positional_embedding.to(x.dtype)
x = model.visual.ln_pre(x)

# tokens -> transformer -> feature_embeddings
x = x.permute(1, 0, 2)  # NLD -> LND

for i in range(model.visual.transformer.layers):
    x = model.visual.transformer.resblocks[i](x)
    tmp = x.permute(1, 0, 2)
    tmp = model.visual.ln_post(tmp[:, 0, :])
    if model.visual.proj is not None:
        tmp = tmp @ model.visual.proj
    img_embeddings[i].copy_(tmp)

x = x.permute(1, 0, 2)  # LND -> NLD
print(x.shape)  

x = model.visual.ln_post(x[:, 0, :])    # [CLS] token의 임베딩을 사용
print(x.shape)  

if model.visual.proj is not None:
    x = x @ model.visual.proj
print(x.shape)

torch.Size([1, 50, 768])
torch.Size([1, 768])
torch.Size([1, 512])


In [6]:
txt_embeddings = [torch.zeros([1, 512]) for i in range(model.transformer.layers)]

#%% model.encode_image

x = model.token_embedding(text).type(model.dtype)  # [batch_size, n_ctx, d_model]

x = x + model.positional_embedding.type(model.dtype)
x = x.permute(1, 0, 2)  # NLD -> LND

for i in range(model.transformer.layers):
    x = model.transformer.resblocks[i](x)
    tmp = x.permute(1, 0, 2)
    tmp = model.ln_final(tmp).type(model.dtype)
    tmp = tmp[torch.arange(tmp.shape[0]), text.argmax(dim=-1)] @ model.text_projection
    txt_embeddings[i].copy_(tmp)


x = x.permute(1, 0, 2)  # LND -> NLD
x = model.ln_final(x).type(model.dtype)
print(x.shape)

x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ model.text_projection    
    # x.shape[0] : 들어온 단어 토큰의 개수
    # x[[(which_word_token)], (which_end_token)] : [CLS] token이 아니라 end token의 임베딩을 사용
    # text.argmax(dim=-1) : end token의 위치
print(x.shape)


torch.Size([1, 77, 512])
torch.Size([1, 512])


In [7]:
len(img_embeddings), len(txt_embeddings)

(12, 12)

In [8]:
img_similarity_by_layer = [[(img_embeddings[i] @ txt_embeddings[j].T).item() for j in range(12)] for i in range(12)]
img_similarity_by_layer

[[7.615795612335205,
  9.797560691833496,
  11.156327247619629,
  9.027064323425293,
  6.2413458824157715,
  6.277929782867432,
  11.906933784484863,
  15.012986183166504,
  14.751118659973145,
  16.606121063232422,
  15.018819808959961,
  24.93141746520996],
 [6.330994129180908,
  8.5187406539917,
  9.775116920471191,
  7.487305641174316,
  4.8748860359191895,
  4.953120231628418,
  10.725042343139648,
  13.944134712219238,
  13.571572303771973,
  15.294862747192383,
  13.928719520568848,
  23.036231994628906],
 [6.784273147583008,
  8.728285789489746,
  10.201212882995605,
  7.814774036407471,
  5.212159633636475,
  5.04819917678833,
  10.780756950378418,
  14.003941535949707,
  13.499654769897461,
  15.335693359375,
  13.966981887817383,
  23.18608856201172],
 [7.7387003898620605,
  9.74960708618164,
  11.19079303741455,
  8.873252868652344,
  6.261541366577148,
  6.105512619018555,
  11.513084411621094,
  14.71274471282959,
  14.375870704650879,
  16.073898315429688,
  14.695353507

In [9]:
txt_similarity_by_layer = [[(txt_embeddings[i] @ img_embeddings[j].T).item() for j in range(12)] for i in range(12)]
txt_similarity_by_layer

[[7.615795612335205,
  6.330994129180908,
  6.784273147583008,
  7.7387003898620605,
  8.214287757873535,
  8.431873321533203,
  8.650493621826172,
  8.574219703674316,
  9.553253173828125,
  9.85700511932373,
  10.145522117614746,
  13.722640037536621],
 [9.797560691833496,
  8.5187406539917,
  8.728285789489746,
  9.74960708618164,
  9.829022407531738,
  9.94196605682373,
  10.241815567016602,
  10.119183540344238,
  10.961342811584473,
  11.301011085510254,
  11.80378532409668,
  15.046772956848145],
 [11.156327247619629,
  9.775116920471191,
  10.201212882995605,
  11.19079303741455,
  11.209044456481934,
  10.736769676208496,
  11.223325729370117,
  10.903237342834473,
  11.61048698425293,
  12.231568336486816,
  12.988690376281738,
  14.040470123291016],
 [9.027064323425293,
  7.487305641174316,
  7.814774036407471,
  8.873252868652344,
  9.299650192260742,
  8.945768356323242,
  9.289751052856445,
  8.73231315612793,
  9.479312896728516,
  10.547597885131836,
  11.53284263610839

In [10]:
img_sum_similarity_by_layer = [sum(x) for x in img_similarity_by_layer]
txt_sum_similarity_by_layer = [sum(x) for x in txt_similarity_by_layer]
print([round(x/sum(img_sum_similarity_by_layer), 2) for x in img_sum_similarity_by_layer])
print([round(x/sum(txt_sum_similarity_by_layer), 2) for x in txt_sum_similarity_by_layer])
# 이 결과값을 각 레이어에 대한 가중치로 준다면 어떨까?

[0.08, 0.07, 0.07, 0.08, 0.08, 0.08, 0.08, 0.08, 0.08, 0.08, 0.1, 0.12]
[0.06, 0.07, 0.07, 0.06, 0.04, 0.04, 0.08, 0.1, 0.1, 0.11, 0.11, 0.16]
