# Image Similarity Search using CLIP

This notebook demonstrates how to use an OpenCLIP checkpoint of the CLIP (Contrastive Language-Image Pre-training) model, loaded through the Hugging Face `transformers` library, to perform image similarity searches. We'll cover two types of search:
1. Text-to-Image: Find images that match a given text description
2. Image-to-Image: Find images similar to a given image

Let's start by importing the necessary libraries and setting up our environment.

In [None]:
import os
import torch
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Set up environment variables (you can modify these as needed)
os.environ['IMAGE_SOURCE_FOLDER'] = './data/images'
os.environ['EMBEDDINGS_FILE'] = './embeddings/CLIPembeddings.pt'

print("Libraries imported and environment variables set.")

## Setting up the CLIP model

Now, let's set up the CLIP model. This function loads the model and its processor, then picks the best available device (CUDA GPU, Apple MPS, or CPU) for processing.

In [None]:
def setup_model(model_name):
    model = CLIPModel.from_pretrained(model_name)
    # Prefer CUDA, then Apple Silicon (MPS), then fall back to CPU.
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"
    model.to(device)
    processor = CLIPProcessor.from_pretrained(model_name)
    return model, processor, device

model_name = "laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K"
model, processor, device = setup_model(model_name)
print(f"Model loaded and set to use device: {device}")

## Preparing the Image Dataset

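Next, we'll collect the paths of all images in our dataset and load their pre-computed embeddings. The search code matches each row of the embeddings tensor to an image by position, so the embeddings must be stored in the same order as the sorted list of paths.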

In [None]:
def get_image_paths(source_folder):
    image_paths = []
    for root, _, files in os.walk(source_folder):
        for file in files:
            if file.lower().endswith(('.png', '.jpg', '.jpeg')):
                image_paths.append(os.path.join(root, file))
    return sorted(image_paths)

def load_embeddings(file_path):
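    # Load onto the CPU so embeddings saved on a GPU still work on machines
    # without one, and match the query embeddings (which are moved to the CPU).
    return torch.load(file_path, map_location="cpu")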

source_folder = os.environ.get('IMAGE_SOURCE_FOLDER', './data/images')
embeddings_file = os.environ.get('EMBEDDINGS_FILE', './embeddings/CLIPembeddings.pt')

image_paths = get_image_paths(source_folder)
image_embeddings = load_embeddings(embeddings_file)

print(f"Loaded {len(image_paths)} image paths and their embeddings.")
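
`retrieve_top_images` later indexes `image_paths` by row position in `image_embeddings`, so the two must have the same length and the same order. Below is a minimal sanity check, assuming the embeddings file holds a single 2-D tensor of shape `(num_images, embedding_dim)`. The notebook does not state the file layout, and `check_alignment` is a hypothetical helper.

```python
import torch

def check_alignment(image_paths, image_embeddings):
    """Raise ValueError if paths and embedding rows cannot correspond one-to-one."""
    # Assumption: embeddings are stored as one 2-D tensor, one row per image.
    if not torch.is_tensor(image_embeddings) or image_embeddings.dim() != 2:
        raise ValueError("expected a 2-D tensor of shape (num_images, embedding_dim)")
    if image_embeddings.shape[0] != len(image_paths):
        raise ValueError(
            f"{len(image_paths)} image paths but {image_embeddings.shape[0]} embedding rows"
        )
    return image_embeddings.shape[1]  # embedding dimension

# Toy stand-ins for the real paths and embeddings.
demo_paths = ["a.jpg", "b.jpg", "c.jpg"]
demo_embeddings = torch.zeros(3, 4)
embedding_dim = check_alignment(demo_paths, demo_embeddings)
print(embedding_dim)
```

A matching count does not prove that the order matches; that depends on how the embeddings file was generated.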

## Utility Functions

Now, let's define some utility functions that we'll use for our similarity searches.

In [None]:
def get_text_embedding(text, processor, model, device):
    inputs = processor(text=[text], return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        text_embedding = model.get_text_features(**inputs)
    return text_embedding.cpu()

def get_image_embedding(image_path, processor, model, device):
    try:
        image = Image.open(image_path).convert("RGB")
        inputs = processor(images=image, return_tensors="pt").to(device)
        with torch.no_grad():
            image_embedding = model.get_image_features(**inputs)
        return image_embedding.cpu()
    except Exception as e:
        print(f"Error processing image {image_path}: {e}")
        return None

def compute_similarity(embedding1, embedding2):
    embedding1 = embedding1 / embedding1.norm(dim=-1, keepdim=True)
    embedding2 = embedding2 / embedding2.norm(dim=-1, keepdim=True)
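    # squeeze(0) drops only the query dimension, so a dataset with a single
    # image still yields a 1-D tensor that len() and topk() can handle.
    return torch.matmul(embedding1, embedding2.T).squeeze(0)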

def retrieve_top_images(query_embedding, image_paths, image_embeddings, top_k=5):
    similarities = compute_similarity(query_embedding, image_embeddings)
    top_k_indices = torch.topk(similarities, min(top_k, len(similarities))).indices
    top_k_image_paths = [image_paths[i] for i in top_k_indices]
    top_k_similarities = [similarities[i].item() for i in top_k_indices]
    return top_k_image_paths, top_k_similarities

def display_images(image_paths, similarities, image_size=(5, 5), font_size=10, show_search_image=None):
    def display_single_image(img_path, title):
        try:
            image = Image.open(img_path).convert("RGB")
            plt.figure(figsize=image_size)
            plt.imshow(np.array(image))
            plt.axis('off')
            plt.title(f"{title}\nPath: {img_path}", fontsize=font_size, wrap=True)
            plt.tight_layout()
            plt.show()
        except Exception as e:
            print(f"Error displaying image {img_path}: {e}")
    
    if show_search_image:
        display_single_image(show_search_image, "Search Image")
    
    for img_path, sim in zip(image_paths, similarities):
        display_single_image(img_path, f"Similarity: {sim:.4f}")

print("Utility functions defined.")
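
To make `compute_similarity` and the top-k step concrete, here is a small worked example on hand-made 2-D vectors rather than real CLIP embeddings (the vectors and the `toy_` names are illustrative assumptions). Normalizing both sides makes the dot product equal to the cosine similarity, so vector length does not affect the ranking.

```python
import torch

def compute_similarity(embedding1, embedding2):
    # Mirrors the notebook's helper: L2-normalize, then take dot products.
    embedding1 = embedding1 / embedding1.norm(dim=-1, keepdim=True)
    embedding2 = embedding2 / embedding2.norm(dim=-1, keepdim=True)
    return torch.matmul(embedding1, embedding2.T).squeeze(0)

toy_query = torch.tensor([[2.0, 0.0]])      # shape (1, 2)
toy_images = torch.tensor([[5.0, 0.0],      # same direction as the query
                           [0.0, 3.0],      # orthogonal to the query
                           [1.0, 1.0]])     # 45 degrees from the query

toy_similarities = compute_similarity(toy_query, toy_images)
toy_top = torch.topk(toy_similarities, k=2)

print(toy_similarities)           # approximately [1.0, 0.0, 0.7071]
print(toy_top.indices.tolist())   # [0, 2]
```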

## Text-to-Image Search

Now, let's perform a text-to-image search. This process involves:
1. Converting the text query into an embedding using the CLIP model
2. Comparing this text embedding with all the pre-computed image embeddings
3. Retrieving the top matching images based on similarity scores

You can modify `text_query` to search for different concepts, and change `top_k` to retrieve more or fewer images.

In [None]:
text_query = "a telegraph office"
top_k = 5
text_embedding = get_text_embedding(text_query, processor, model, device)
top_images, top_similarities = retrieve_top_images(text_embedding, image_paths, image_embeddings, top_k=top_k)

print(f"Top images matching the query '{text_query}':")
display_images(top_images, top_similarities, image_size=(8, 8), font_size=10)
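
`retrieve_top_images` always returns `top_k` results, even when none of them matches the query well. A minimal sketch of a score cutoff follows. `filter_by_score`, the demo data, and the threshold value are all hypothetical; a useful cutoff would need tuning on your own dataset.

```python
def filter_by_score(image_paths, similarities, min_similarity):
    """Keep only (path, score) pairs whose score is at least min_similarity."""
    kept = [(p, s) for p, s in zip(image_paths, similarities) if s >= min_similarity]
    kept_paths = [p for p, _ in kept]
    kept_scores = [s for _, s in kept]
    return kept_paths, kept_scores

# Made-up results in the shape returned by retrieve_top_images.
demo_paths = ["office.jpg", "station.jpg", "beach.jpg"]
demo_scores = [0.31, 0.27, 0.12]

# 0.2 is an arbitrary illustrative cutoff, not a recommended value.
strong_paths, strong_scores = filter_by_score(demo_paths, demo_scores, min_similarity=0.2)
print(strong_paths)  # ['office.jpg', 'station.jpg']
```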

## Image-to-Image Search

Now, let's perform an image-to-image search. This process involves:
1. Loading a search image and converting it into an embedding using the CLIP model
2. Comparing this image embedding with all the pre-computed image embeddings
3. Retrieving the top matching images based on similarity scores

You can modify `search_image_path` to find images similar to a different reference image.

In [None]:
search_image_path = './testimages/parliament.jpg'
top_k = 5
search_image_embedding = get_image_embedding(search_image_path, processor, model, device)

if search_image_embedding is not None:
    similar_images, similarities = retrieve_top_images(search_image_embedding, image_paths, image_embeddings, top_k=top_k)
    
    print(f"Top images similar to {search_image_path}:")
    display_images(similar_images, similarities, image_size=(8, 8), font_size=10, show_search_image=search_image_path)
else:
    print("Failed to process the search image.")
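
If the search image is itself part of the indexed dataset, it will normally come back as the top hit with a similarity close to 1.0. The sketch below shows one way to drop it, assuming paths can be compared after `os.path.abspath` normalization. `drop_query_image` and the demo paths are hypothetical and not part of the notebook.

```python
import os

def drop_query_image(query_path, image_paths, similarities):
    """Remove the query image from a result list, comparing absolute paths."""
    query_abs = os.path.abspath(query_path)
    kept = [(p, s) for p, s in zip(image_paths, similarities)
            if os.path.abspath(p) != query_abs]
    return [p for p, _ in kept], [s for _, s in kept]

# Made-up results in the shape returned by retrieve_top_images.
demo_paths = ["./data/images/parliament.jpg", "./data/images/castle.jpg"]
demo_scores = [1.0, 0.83]

paths, scores = drop_query_image("data/images/parliament.jpg", demo_paths, demo_scores)
print(paths)  # ['./data/images/castle.jpg']
```

Request `top_k + 1` results before filtering if you still want `top_k` images afterwards.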

## Conclusion

This notebook demonstrates how to use the CLIP model for both text-to-image and image-to-image similarity searches. You can modify the text queries or search images to explore different results within your image dataset.

Remember that the quality of results depends on the diversity and relevance of your image dataset, as well as the pre-computed embeddings. If you want to use this notebook with your own image collection, you'll need to update the `IMAGE_SOURCE_FOLDER` and `EMBEDDINGS_FILE` environment variables, and ensure you have pre-computed CLIP embeddings for your images.
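
The notebook does not show how the embeddings file was created. The following is a minimal sketch of one way to build it, assuming the file holds a single tensor with one row per image, in the same order as the sorted path list. `build_embeddings` is a hypothetical helper, and `embed_fn` stands for any function that returns a `(1, D)` tensor or `None`, such as a wrapper around `get_image_embedding`.

```python
import torch

def build_embeddings(image_paths, embed_fn):
    """Embed each image; return (kept_paths, tensor of shape (N, D))."""
    kept_paths, rows = [], []
    for path in image_paths:
        embedding = embed_fn(path)
        if embedding is None:
            # Skip unreadable images so paths and rows stay aligned.
            continue
        kept_paths.append(path)
        rows.append(embedding.reshape(1, -1))
    if not rows:
        raise ValueError("no images could be embedded")
    return kept_paths, torch.cat(rows, dim=0)

# Stand-in embedder for demonstration only. In the notebook you would pass
# lambda p: get_image_embedding(p, processor, model, device)
def fake_embed(path):
    return None if path.endswith("broken.jpg") else torch.ones(1, 4)

demo_paths, demo_embeddings = build_embeddings(
    ["a.jpg", "broken.jpg", "b.jpg"], fake_embed
)
print(demo_paths, tuple(demo_embeddings.shape))  # ['a.jpg', 'b.jpg'] (2, 4)
# torch.save(demo_embeddings, "./embeddings/CLIPembeddings.pt")  # write EMBEDDINGS_FILE
```

Because failed images are skipped, the kept path list should be saved alongside the tensor (or the failing files removed from the folder) so that `get_image_paths` and the embeddings stay in step.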