In [3]:
import pandas as pd
import numpy as np
import torch
from torch import nn 
from torch.utils.data import DataLoader     # 데이터로더는 데이터셋을 iterable하게 감싸는 역할
from torchvision import datasets            # 데이터셋은 샘플과 정답을 저장함
from torchvision.transforms import ToTensor
import clip
from PIL import Image

In [4]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

image_title = 'A dog is sitting on a couch with a stuffed animal.png'

image = preprocess(Image.open(image_title)).unsqueeze(0).to(device)
text = clip.tokenize("A dog is sitting on a couch with a stuffed animal.png").to(device)

In [5]:
img_embeddings = [torch.zeros([1, 512]) for i in range(model.visual.transformer.layers)]

#%% model.encode_image

# image -> tokens
x = model.visual.conv1(image.type(model.visual.conv1.weight.dtype))
x = x.reshape(x.shape[0], x.shape[1], -1)   # shape = [*, width, grid ** 2]
x = x.permute(0, 2, 1)                      # shape = [*, grid ** 2, width]
x = torch.cat([model.visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
x = x + model.visual.positional_embedding.to(x.dtype)
x = model.visual.ln_pre(x)

# tokens -> transformer -> feature_embeddings
x = x.permute(1, 0, 2)  # NLD -> LND

for i in range(model.visual.transformer.layers):
    x = model.visual.transformer.resblocks[i](x)
    tmp = x.permute(1, 0, 2)
    tmp = model.visual.ln_post(tmp[:, 0, :])
    if model.visual.proj is not None:
        tmp = tmp @ model.visual.proj
    img_embeddings[i].copy_(tmp)

x = x.permute(1, 0, 2)  # LND -> NLD
print(x.shape)  

x = model.visual.ln_post(x[:, 0, :])    # [CLS] token의 임베딩을 사용
print(x.shape)  

if model.visual.proj is not None:
    x = x @ model.visual.proj
print(x.shape)

torch.Size([1, 50, 768])
torch.Size([1, 768])
torch.Size([1, 512])


In [6]:
txt_embeddings = [torch.zeros([1, 512]) for i in range(model.transformer.layers)]

#%% model.encode_image

x = model.token_embedding(text).type(model.dtype)  # [batch_size, n_ctx, d_model]

x = x + model.positional_embedding.type(model.dtype)
x = x.permute(1, 0, 2)  # NLD -> LND

for i in range(model.transformer.layers):
    x = model.transformer.resblocks[i](x)
    tmp = x.permute(1, 0, 2)
    tmp = model.ln_final(tmp).type(model.dtype)
    tmp = tmp[torch.arange(tmp.shape[0]), text.argmax(dim=-1)] @ model.text_projection
    txt_embeddings[i].copy_(tmp)


x = x.permute(1, 0, 2)  # LND -> NLD
x = model.ln_final(x).type(model.dtype)
print(x.shape)

x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ model.text_projection    
    # x.shape[0] : 들어온 단어 토큰의 개수
    # x[[(which_word_token)], (which_end_token)] : [CLS] token이 아니라 end token의 임베딩을 사용
    # text.argmax(dim=-1) : end token의 위치
print(x.shape)


torch.Size([1, 77, 512])
torch.Size([1, 512])


In [8]:
len(img_embeddings), len(txt_embeddings)

(12, 12)

In [20]:
img_similarity_by_layer = [[(img_embeddings[i] @ txt_embeddings[j].T).item() for j in range(12)] for i in range(12)]
img_similarity_by_layer

[[2.501817464828491,
  4.849160671234131,
  4.716002941131592,
  4.787253379821777,
  1.205403208732605,
  0.48790448904037476,
  5.24664306640625,
  6.5477986335754395,
  4.552478790283203,
  7.057314395904541,
  4.469366073608398,
  9.94657039642334],
 [1.3229795694351196,
  3.3995361328125,
  3.2204341888427734,
  3.1518073081970215,
  -0.6400675773620605,
  -1.2784724235534668,
  3.543281078338623,
  5.218833923339844,
  3.367579460144043,
  6.245415210723877,
  4.461362361907959,
  10.032344818115234],
 [1.6270900964736938,
  3.4631729125976562,
  3.4350616931915283,
  3.2851505279541016,
  -0.35106775164604187,
  -1.0539891719818115,
  3.8040571212768555,
  5.455502033233643,
  3.350917100906372,
  6.090952396392822,
  4.210789203643799,
  9.992056846618652],
 [2.112712860107422,
  3.9352617263793945,
  3.8565807342529297,
  3.5443649291992188,
  -0.12104439735412598,
  -0.9199535846710205,
  3.926861047744751,
  5.421072959899902,
  3.50542950630188,
  6.1039605140686035,
  3.99

In [21]:
txt_similarity_by_layer = [[(txt_embeddings[i] @ img_embeddings[j].T).item() for j in range(12)] for i in range(12)]
txt_similarity_by_layer

[[2.501817464828491,
  1.3229795694351196,
  1.6270900964736938,
  2.112712860107422,
  2.214921712875366,
  1.3777868747711182,
  1.9260956048965454,
  2.565786123275757,
  3.8193743228912354,
  4.937297344207764,
  5.343286037445068,
  10.72659683227539],
 [4.849160671234131,
  3.3995361328125,
  3.4631729125976562,
  3.9352617263793945,
  3.8491246700286865,
  3.3112974166870117,
  4.2125139236450195,
  5.002923488616943,
  5.804977893829346,
  6.562867641448975,
  7.014270782470703,
  11.778487205505371],
 [4.716002941131592,
  3.2204341888427734,
  3.4350616931915283,
  3.8565807342529297,
  3.9205925464630127,
  3.8110437393188477,
  4.888388633728027,
  5.617875099182129,
  6.628039360046387,
  7.444637775421143,
  7.986630439758301,
  14.422587394714355],
 [4.787253379821777,
  3.1518073081970215,
  3.2851505279541016,
  3.5443649291992188,
  3.755575656890869,
  3.643217086791992,
  5.0851216316223145,
  6.1066365242004395,
  6.924951553344727,
  7.939623832702637,
  9.2752084

In [27]:
img_sum_similarity_by_layer = [sum(x) for x in img_similarity_by_layer]
txt_sum_similarity_by_layer = [sum(x) for x in txt_similarity_by_layer]
print([round(x/sum(img_sum_similarity_by_layer), 2) for x in img_sum_similarity_by_layer])
print([round(x/sum(txt_sum_similarity_by_layer), 2) for x in txt_sum_similarity_by_layer])
# 이 결과값을 각 레이어에 대한 가중치로 준다면 어떨까?

[0.06, 0.05, 0.05, 0.05, 0.05, 0.05, 0.07, 0.08, 0.09, 0.1, 0.13, 0.21]
[0.05, 0.07, 0.08, 0.08, 0.03, 0.02, 0.09, 0.11, 0.09, 0.12, 0.1, 0.18]
