In [1]:
!pip install datasets

[0m

In [2]:
!pip install git+https://github.com/openai/CLIP.git

Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-p0_n0bhg
  Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git /tmp/pip-req-build-p0_n0bhg
  Resolved https://github.com/openai/CLIP.git to commit dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1
  Preparing metadata (setup.py) ... [?25ldone
[0m

In [3]:
import matplotlib.pyplot as plt
import pandas as pd
import clip
import torch
import torch.nn.functional as F
from PIL import Image

In [4]:
from datasets import load_dataset

In [5]:
dataset = load_dataset("UCSC-Admire/idiom-dataset-100-2024-11-11_14-37-58")

dataset

DatasetDict({
    train: Dataset({
        features: ['id', 'language', 'compound', 'sentence_type', 'sentence', 'style', 'image_1_prompt', 'image_1', 'image_2_prompt', 'image_2', 'image_3_prompt', 'image_3', 'image_4_prompt', 'image_4', 'image_5_prompt', 'image_5'],
        num_rows: 600
    })
})

In [6]:
dataset['train']

Dataset({
    features: ['id', 'language', 'compound', 'sentence_type', 'sentence', 'style', 'image_1_prompt', 'image_1', 'image_2_prompt', 'image_2', 'image_3_prompt', 'image_3', 'image_4_prompt', 'image_4', 'image_5_prompt', 'image_5'],
    num_rows: 600
})

In [7]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

In [12]:
preprocess

Compose(
    Resize(size=224, interpolation=bicubic, max_size=None, antialias=warn)
    CenterCrop(size=(224, 224))
    <function _convert_image_to_rgb at 0x7f10f20ad080>
    ToTensor()
    Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
)

In [14]:
def get_text_features(sentence):
    text = clip.tokenize([sentence]).to(device)  # Tokenize the text
    with torch.no_grad():
        text_features = model.encode_text(text)  # Extract text features
    return text_features

def get_image_features(image_data):
    # image_data is expected to be a PIL image directly from the dataset
    image = preprocess(image_data).unsqueeze(0).to(device)  # Preprocess the image
    with torch.no_grad():
        image_features = model.encode_image(image)  # Extract image features
    return image_features


def cosine_similarity(text_features, image_features):
    similarity = F.cosine_similarity(text_features, image_features)
    return similarity.item()

In [15]:
ranked_images = []

for idx in range(len(dataset['train'])):
    row = dataset['train'][idx]
    sentence = row['sentence']
    image_columns = ['image_1', 'image_2', 'image_3', 'image_4', 'image_5']
    
    text_features = get_text_features(sentence)
    
    similarities = []
    for img_col in image_columns:
        image_data = row[img_col]  # Get the image directly from the dataset (assumed to be in PIL format)
        image = Image.open(io.BytesIO(image_data)) if isinstance(image_data, bytes) else image_data
        image_features = get_image_features(image)
        similarity = cosine_similarity(text_features, image_features)
        similarities.append((img_col, similarity))
    
    # Rank images by similarity score in descending order
    ranked_images.append(sorted(similarities, key=lambda x: x[1], reverse=True))

In [16]:
ranked_images

[[('image_1', 0.30419921875),
  ('image_2', 0.233642578125),
  ('image_4', 0.169921875),
  ('image_5', 0.1478271484375),
  ('image_3', 0.146240234375)],
 [('image_4', 0.279052734375),
  ('image_3', 0.265380859375),
  ('image_2', 0.2490234375),
  ('image_5', 0.2325439453125),
  ('image_1', 0.23193359375)],
 [('image_1', 0.314208984375),
  ('image_2', 0.252685546875),
  ('image_4', 0.173828125),
  ('image_5', 0.1717529296875),
  ('image_3', 0.14013671875)],
 [('image_4', 0.289306640625),
  ('image_3', 0.269775390625),
  ('image_2', 0.24609375),
  ('image_5', 0.2415771484375),
  ('image_1', 0.23681640625)],
 [('image_1', 0.297607421875),
  ('image_2', 0.25634765625),
  ('image_4', 0.1763916015625),
  ('image_5', 0.1539306640625),
  ('image_3', 0.1336669921875)],
 [('image_4', 0.2646484375),
  ('image_3', 0.25927734375),
  ('image_2', 0.243896484375),
  ('image_1', 0.2415771484375),
  ('image_5', 0.231201171875)],
 [('image_1', 0.261474609375),
  ('image_5', 0.2440185546875),
  ('image_2',