In [86]:
import os

import torch
from transformers import (
    CLIPTokenizerFast,
    CLIPTextModel,
    CLIPImageProcessor,
    CLIPModel,
    CLIPProcessor,
)
from PIL import Image
import torch.nn.functional as F
from IPython.display import display
import psycopg
import numpy as np
from pydantic import BaseModel
import requests
from io import BytesIO

# Hugging Face checkpoint name passed to every `from_pretrained` call below.
MODEL = "openai/clip-vit-base-patch32"
# Postgres connection string. Set the DATABASE_URL environment variable to
# point at another instance (and to keep real credentials out of the notebook);
# the fallback is the local development database.
DATABASE_URL = os.environ.get(
    "DATABASE_URL", "postgresql://postgres:postgres@0.0.0.0:5433/postgres"
)
# Directory that holds the product images, stored as <asin>.jpg.
IMAGE_DIR = "./data"

# Download all the images

In [None]:
import pandas as pd

# Product catalogue: the cells below read each row's `asin` and `product_photo`.
df = pd.read_csv("data/amazon_product.csv")
# Preview only the two columns used later, for the first three products.
df[["asin", "product_photo"]].head(3)

In [91]:
# Download every product photo and store it as <IMAGE_DIR>/<asin>.jpg.
for i, row in df.iterrows():
    url = row["product_photo"]
    asin = row["asin"]
    # Bound the wait so a single stalled host cannot hang the whole cell.
    response = requests.get(url, timeout=30)
    # Fail loudly on 4xx/5xx instead of handing an error page to PIL.
    response.raise_for_status()
    img = Image.open(BytesIO(response.content))
    # JPEG cannot store alpha or palette data (RGBA, LA, P, ...), so convert
    # anything that is not already RGB before saving.
    if img.mode != "RGB":
        img = img.convert("RGB")

    img.save(os.path.join(IMAGE_DIR, f"{asin}.jpg"))

# Instantiate CLIP models

We'll use the `transformers` library to load the CLIP tokenizer, image processor, and model, which we then use to generate embeddings for the images and for text queries.

In [None]:
# Text-side components.
tokenizer = CLIPTokenizerFast.from_pretrained(MODEL)
text_model = CLIPTextModel.from_pretrained(MODEL)

# Pre-processing: `image_processor` handles images only, `processor` is the
# combined processor used for text queries.
image_processor = CLIPImageProcessor.from_pretrained(MODEL)
processor = CLIPProcessor.from_pretrained(MODEL)

# Load the joint CLIP weights once and expose them under both names used in
# this notebook, instead of holding two identical copies in memory.
clip_model = CLIPModel.from_pretrained(MODEL)
image_model = clip_model

In [93]:

class ImageEmbedding(BaseModel):
    """An image identifier paired with its embedding vector."""

    # Identifier of the image; get_image_embeddings derives it from the file name.
    image_id: str
    # Embedding vector as plain Python floats.
    embeddings: list[float]


def get_image_embeddings(
    image_paths: list[str], normalize=True, batch_size: int = 32
) -> list[ImageEmbedding]:
    """Embed image files with the CLIP image encoder.

    Args:
        image_paths: Paths of the image files to embed.
        normalize: When true, L2-normalize every embedding vector.
        batch_size: Number of images sent through the model per forward
            pass, which bounds peak memory for large directories.

    Returns:
        One ImageEmbedding per input path, in input order. `image_id` is
        the file name without its directory and extension. An empty input
        yields an empty list.

    Raises:
        ValueError: If `batch_size` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    image_embeddings: list[ImageEmbedding] = []
    for start in range(0, len(image_paths), batch_size):
        batch_paths = image_paths[start : start + batch_size]

        images = []
        for path in batch_paths:
            # convert() returns a fully loaded copy, so the file handle can be
            # released right away instead of staying open for every image.
            with Image.open(path) as img:
                images.append(img.convert("RGB"))

        inputs = image_processor(images=images, return_tensors="pt")
        with torch.no_grad():
            outputs = image_model.get_image_features(**inputs)
        if normalize:
            outputs = F.normalize(outputs, p=2, dim=-1)

        for path, embedding in zip(batch_paths, outputs):
            image_embeddings.append(
                ImageEmbedding(
                    # basename/splitext also handle OS-specific separators.
                    image_id=os.path.splitext(os.path.basename(path))[0],
                    embeddings=embedding.tolist(),
                )
            )

    return image_embeddings

def list_jpg_files(directory):
    """Return the paths of the `.jpg` files directly inside `directory`.

    Only regular files whose name ends with ".jpg" (case-sensitive) are
    returned; sub-directories are not searched. The result is sorted so
    repeated runs process the images in the same order, which
    `os.listdir` alone does not guarantee.

    Args:
        directory: Directory to scan.

    Returns:
        Sorted list of `os.path.join(directory, name)` paths.
    """
    return sorted(
        os.path.join(directory, filename)
        for filename in os.listdir(directory)
        if filename.endswith(".jpg")
        # Skip directories that merely have a ".jpg" name.
        and os.path.isfile(os.path.join(directory, filename))
    )


def pg_insert_embeddings(images: list[ImageEmbedding]):
    """Upsert image embeddings into the `image_embeddings` Postgres table.

    Ensures the `vectorize` extension and the table exist, then inserts one
    row per image; a row whose `image_id` already exists has its embedding
    overwritten.

    Args:
        images: Embeddings to store. An empty list still creates the
            extension and table but writes no rows.
    """
    init_pg_vectorize = "CREATE EXTENSION IF NOT EXISTS vectorize CASCADE;"
    # NOTE(review): 512 is assumed to match the embedding width of the CLIP
    # checkpoint configured in MODEL -- confirm if the checkpoint changes.
    init_table = """
        CREATE TABLE IF NOT EXISTS image_embeddings (image_id TEXT PRIMARY KEY, embeddings VECTOR(512));
    """
    insert_query = """
        INSERT INTO image_embeddings (image_id, embeddings)
        VALUES (%s, %s)
        ON CONFLICT (image_id)
        DO UPDATE SET embeddings = EXCLUDED.embeddings
        ;
    """
    rows = [(image.image_id, image.embeddings) for image in images]
    # psycopg's connection context manager commits on a clean exit and rolls
    # back if an exception escapes the block.
    with psycopg.connect(DATABASE_URL) as conn:
        with conn.cursor() as cur:
            cur.execute(init_pg_vectorize)
            cur.execute(init_table)

            if rows:
                # One batched call instead of a separate execute() per image.
                cur.executemany(insert_query, rows)




In [94]:
# Collect every .jpg under IMAGE_DIR, embed them with CLIP, and upsert the
# vectors into Postgres.
images = list_jpg_files(IMAGE_DIR)
image_embeddings = get_image_embeddings(images)
pg_insert_embeddings(image_embeddings)


# Transform Text to Embeddings

In [95]:
def get_text_embeddings(text, normalize=True) -> list[float]:
    """Embed a text query with the CLIP text encoder.

    Args:
        text: The query string to embed.
        normalize: When true, scale the vector to unit L2 norm.

    Returns:
        The embedding as a list of floats.
    """
    # CLIP's text encoder has a fixed maximum sequence length; truncating
    # keeps an over-long query from failing inside the model.
    inputs = processor(
        text=[text], return_tensors="pt", padding=True, truncation=True
    )
    # Inference only: skip autograd bookkeeping, as the image path does.
    with torch.no_grad():
        text_features = clip_model.get_text_features(**inputs)
    text_embedding = text_features[0].detach().numpy()

    if normalize:
        embeds = text_embedding / np.linalg.norm(text_embedding)
    else:
        embeds = text_embedding
    return embeds.tolist()

In [96]:
def similarity_search(
    txt_embedding: list[float], limit: int = 5
) -> list[tuple[str, float]]:
    """Find the stored images closest to a query embedding.

    Args:
        txt_embedding: Query vector. Despite the name, any embedding from
            the same CLIP space works, including an image embedding.
        limit: Maximum number of rows to return (default 5).

    Returns:
        `(image_id, similarity_score)` tuples, most similar first, where
        the score is `1 - (embeddings <=> query)`.
    """
    # Order by the distance operator itself (ascending) rather than by the
    # derived score: the ranking is the same, and this is the query shape
    # pgvector can serve from a vector index when one exists.
    query = """
        SELECT
            image_id,
            1 - (embeddings <=> %(query)s::vector) AS similarity_score
        FROM image_embeddings
        ORDER BY embeddings <=> %(query)s::vector
        LIMIT %(limit)s;
    """
    with psycopg.connect(DATABASE_URL) as conn:
        with conn.cursor() as cur:
            cur.execute(query, {"query": txt_embedding, "limit": limit})
            rows = cur.fetchall()

            return [(row[0], row[1]) for row in rows]


# Search Images with a Text Query

In [None]:
# Embed the text query and look up the closest product images.
text_embeddings = get_text_embeddings("arts and crafts")
results = similarity_search(text_embeddings)

# Show the three best matches with their similarity scores.
for r, score in results[:3]:
    print("Image ID:", r, "Score:", score)
    image_path = os.path.join(IMAGE_DIR, r + ".jpg")

    # resize() returns a new, fully loaded image, so the file handle can be
    # closed before the thumbnail is displayed.
    with Image.open(image_path) as image:
        img_resized = image.resize((300, 300))
    display(img_resized)

# Image Similarity Search

Download a Photo of Cher from Wikipedia

In [None]:
# Get an image of Cher from Wikipedia.
url = "https://upload.wikimedia.org/wikipedia/commons/thumb/1/1d/Cher_in_2019_cropped_1.jpg/752px-Cher_in_2019_cropped_1.jpg"
# Bound the wait, and fail loudly on an HTTP error instead of handing an error
# page to PIL.
response = requests.get(url, timeout=30)
response.raise_for_status()
img = Image.open(BytesIO(response.content))
img.save('./cher_wikipedia.jpg')
original_width, original_height = img.size
new_width = int(original_width * 0.25)
new_height = int(original_height * 0.25)

# Rescale the image to a quarter of its size for display only; the file saved
# above keeps the original resolution.
resized_img = img.resize((new_width, new_height))
display(resized_img)


# Find images similar to Cher's Wikipedia image

In [None]:
# Embed the downloaded photo; from here on `image_embeddings` holds a single
# query vector (a list of floats).
image_embeddings = get_image_embeddings(["./cher_wikipedia.jpg"])[0].embeddings
results = similarity_search(image_embeddings)

# Show the three best matches with their similarity scores.
for r, score in results[:3]:
    print("Image ID:", r, "Score:", score)
    image_path = os.path.join(IMAGE_DIR, r + ".jpg")

    # resize() returns a new, fully loaded image, so the file handle can be
    # closed before the thumbnail is displayed.
    with Image.open(image_path) as image:
        img_resized = image.resize((250, 250))
    display(img_resized)