# 0. Env

In [None]:
import numpy as np

import torch

from transformers import CLIPProcessor, CLIPModel
from datasets import load_dataset

In [None]:
# Disable gradient tracking globally.
# This notebook only runs PyTorch to inspect model behavior, so gradient computation is switched off.
torch.set_grad_enabled(False)

# 1. CLIP

In [None]:
# Load the "train" split of the image/text demo dataset.
data = load_dataset("jamescalam/image-text-demo", split="train")
data  # last expression of the cell: displays the loaded dataset object

In [None]:
# Load the pretrained CLIP checkpoint: the input processor and the model itself,
# both resolved from the same checkpoint identifier.
model_id = "openai/clip-vit-base-patch32"

processor = CLIPProcessor.from_pretrained(model_id)
model = CLIPModel.from_pretrained(model_id)

In [None]:
# Choose the compute device: CUDA when a GPU is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Move the model onto the chosen device (the cell also displays the returned model).
model.to(device)

In [None]:
# Convert the captions into token tensors for the CLIP text encoder.
text = data['text']

tokens = processor(
    text=text,
    padding=True,        # pad captions so the batch stacks into one tensor
    truncation=True,     # clip over-long captions to the tokenizer's maximum length
                         # instead of failing inside the text encoder
    images=None,         # text-only call: no image preprocessing here
    return_tensors='pt'  # return PyTorch tensors
).to(device)

tokens

In [None]:
# Text feature vectors: run the tokenized captions through the model's text branch
text_emb = model.get_text_features(**tokens)

In [None]:
# Inspect the text embeddings: shape, then smallest and largest value
print(text_emb.shape)
print(text_emb.min(), text_emb.max())

In [None]:
# Convert the torch tensor to a NumPy array (detached and moved to the CPU first)
text_emb = text_emb.detach().cpu().numpy()

# L2 norm along axis 1, i.e. one norm per row
norm_factor = np.linalg.norm(text_emb, axis=1)
norm_factor.shape

In [None]:
# Rescale every row to a unit vector (L2 norm of 1).
# norm_factor holds one norm per row; viewing it as a column lets NumPy
# broadcast the division across each row, so no transposing is needed.
text_emb = text_emb / norm_factor[:, np.newaxis]
print(text_emb.shape)
print(text_emb.min(), text_emb.max())

In [None]:
# Image preprocessing: run the dataset images through the processor and keep
# only the pixel tensor, placed on the same device as the model.
image_batch = data['image']

image_inputs = processor(text=None, images=image_batch, return_tensors='pt')
images = image_inputs['pixel_values'].to(device)

images.shape

In [None]:
# Image feature vectors: run the preprocessed pixels through the model's image branch,
# then inspect the resulting shape and value range.
img_emb = model.get_image_features(pixel_values=images)
print(img_emb.shape)
print(img_emb.min(), img_emb.max())

In [None]:
# Convert the torch tensor to a NumPy array (detached and moved to the CPU first)
img_emb = img_emb.detach().cpu().numpy()

# L2 norm along axis 1, i.e. one norm per row
norm_factor = np.linalg.norm(img_emb, axis=1)
norm_factor.shape

In [None]:
# Rescale every row to a unit vector (L2 norm of 1).
# norm_factor holds one norm per row; viewing it as a column lets NumPy
# broadcast the division across each row, so no transposing is needed.
img_emb = img_emb / norm_factor[:, np.newaxis]
print(img_emb.shape)
print(img_emb.min(), img_emb.max())

In [None]:
# Cosine similarity between every text embedding (rows) and every image
# embedding (columns).
# Entry (i, j) must be divided by ||text_i|| * ||img_j||, so the denominator is
# the outer product of the two norm vectors. Multiplying the 1-D norms
# element-wise would instead scale column j by ||text_j|| * ||img_j|| and would
# break when the number of texts and images differ.
text_norms = np.linalg.norm(text_emb, axis=1, keepdims=True)  # one norm per text row, kept as a column
img_norms = np.linalg.norm(img_emb, axis=1, keepdims=True)    # one norm per image row, kept as a column
cos_sim = np.dot(text_emb, img_emb.T) / (text_norms * img_norms.T)
cos_sim.shape

In [None]:
import matplotlib.pyplot as plt

# Render the cosine-similarity matrix as an image
plt.imshow(cos_sim)
plt.show()

In [None]:
# Plain dot-product similarity of the text and image embeddings (no division
# by norms), rendered as an image for comparison with the cosine version.
dot_sim = text_emb @ img_emb.T

plt.imshow(dot_sim)
plt.show()

In [None]:
# Element-wise difference between the two similarity matrices;
# its min and max show how far the two measures diverge.
diff = cos_sim - dot_sim
diff.min(), diff.max()