In [3]:
from pathlib import Path

import numpy as np
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

# Load the CLIP checkpoint and its paired processor (tokenizer + image preprocessor).
model_id = "patrickjohncyh/fashion-clip"
model = CLIPModel.from_pretrained(model_id)
processor = CLIPProcessor.from_pretrained(model_id)

# Pick the fastest available backend: NVIDIA CUDA, then Apple-silicon MPS, then CPU,
# so the notebook runs unchanged on any machine.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Inference only: eval() disables dropout. from_pretrained already returns the model in
# eval mode; the call is kept explicit so the intent is visible.
model = model.to(device).eval()

In [4]:
# Load the image and the post text that will be compared against it.
DATA_DIR = Path("../data")

# PIL opens files lazily and keeps the handle open until pixel data is read.
# The `with` block closes the file; convert("RGB") forces a full decode first and
# yields a 3-channel image for the CLIP image preprocessor.
with Image.open(DATA_DIR / "picture.jpeg") as img:
    image = img.convert("RGB")

# Explicit encoding so the result does not depend on the platform's default locale.
text = (DATA_DIR / "post.txt").read_text(encoding="utf-8")

# CLIP's text encoder accepts at most 77 tokens per input (start/end tokens included);
# longer text must be truncated or split into chunks before encoding.

In [5]:
# Count tokens (start/end tokens included) and fail early if the post does not fit in
# CLIP's text context window. The joint forward pass below cannot encode longer input.
tokens = processor.tokenizer(text, return_tensors="pt")
n_tokens = tokens["input_ids"].shape[1]
max_tokens = model.config.text_config.max_position_embeddings  # 77 for CLIP
assert n_tokens <= max_tokens, (
    f"post.txt is {n_tokens} tokens long; CLIP accepts at most {max_tokens}"
)
n_tokens


43

In [6]:
# Encode the post text and the image together in one forward pass of the full CLIP model.
inputs = processor(text=[text], images=image, return_tensors="pt", padding=True)
inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

with torch.no_grad():
    outputs = model(**inputs)

# L2-normalise each embedding row so that dot products equal cosine similarities.
image_embeds, text_embeds = (
    torch.nn.functional.normalize(emb)
    for emb in (outputs.image_embeds, outputs.text_embeds)
)

In [7]:
# Both embeddings must share the joint space (same shape) and be unit-length for the
# cosine-similarity comparisons below to be meaningful.
assert image_embeds.shape == text_embeds.shape, (image_embeds.shape, text_embeds.shape)
for emb in (image_embeds, text_embeds):
    norms = emb.norm(dim=-1)
    assert torch.allclose(norms, torch.ones_like(norms), atol=1e-5), norms
image_embeds.shape, text_embeds.shape

(torch.Size([1, 512]), torch.Size([1, 512]))

In [14]:
# Encode a hand-written outfit description with the text tower only (no image involved).
compare_text = "blue shirt and brown long pants"
compare_inputs = processor(
    text=[compare_text],
    return_tensors="pt",
    padding=True,
    # Guard against descriptions longer than the tokenizer's model_max_length
    # (CLIP's 77-token limit); without it an over-long string would fail in the model.
    truncation=True,
)
compare_inputs = {k: v.to(device) for k, v in compare_inputs.items()}

with torch.no_grad():
    # get_text_features returns the projected text embedding WITHOUT normalisation,
    # unlike outputs.text_embeds from the joint forward pass.
    compare_text_embeds = model.get_text_features(**compare_inputs)

# Normalise so this embedding is directly comparable with image_embeds / text_embeds.
compare_text_embeds = torch.nn.functional.normalize(compare_text_embeds, dim=-1)

In [17]:
# Pairwise cosine similarities between the image, the post text, and the manual
# description. All three embeddings are unit-normalised, so these equal dot products.
comparisons = [
    ("Image to post.txt similarity", image_embeds, text_embeds),
    ("Image to manual description similarity", image_embeds, compare_text_embeds),
    ("Manual description to post.txt similarity", text_embeds, compare_text_embeds),
]
for label, lhs, rhs in comparisons:
    similarity = torch.nn.functional.cosine_similarity(lhs, rhs)
    print(f"{label}: {similarity[0]}")

Image to post.txt similarity: 0.2717254161834717
Image to manual description similarity: 0.3059011697769165
Manual description to post.txt similarity: 0.4487333595752716
