# 🔍 CLIP Image Search

Search through images using natural language with OpenAI's CLIP model. This notebook runs in Colab's free tier and requires no API key.

## Model Information
CLIP (Contrastive Language-Image Pre-training) embeds images and text in a shared vector space, which makes zero-shot image search and classification possible without task-specific training. To improve results further, you can:

1. **Fine-tune CLIP on your dataset**:
   - Visit [CLIP Fine-tuning Guide](https://huggingface.co/docs/transformers/main/en/model_doc/clip#training) on Hugging Face
   - Customize the model for your specific domain
   - Improve accuracy for specialized tasks

2. **Use alternative or domain-specific CLIP variants**:
   - [OpenCLIP](https://github.com/mlfoundations/open_clip): Open-source CLIP trained on larger datasets
   - [DomainCLIP](https://github.com/alibaba/EasyNLP/tree/master/examples/clip): Specialized for e-commerce
   - [MultiCLIP](https://github.com/OpenGVLab/Multi-Modality-Arena): Enhanced multilingual support

## Features
- Natural language image search
- Support for custom image datasets
- Real-time search results
- Adjustable similarity threshold

## Setup
First, let's install the required packages:

In [None]:
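# The `clip` module imported below comes from OpenAI's GitHub repository, not from PyPI
!pip install -q torch torchvision ftfy regex tqdm gradio Pillow open_clip_torch git+https://github.com/openai/CLIP.git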

## Import Dependencies

In [None]:
import torch
import clip
import open_clip
from PIL import Image
import gradio as gr
import os
from torchvision.datasets import CIFAR100
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
import numpy as np

## Load Model

In [None]:
# Available CLIP variants
MODELS = {
    "CLIP (ViT-B/32)": {"type": "clip", "name": "ViT-B/32"},
    "OpenCLIP (ViT-B-32)": {"type": "open_clip", "name": "ViT-B-32", "pretrained": "laion2b_s34b_b79k"},
    "OpenCLIP (ViT-L-14)": {"type": "open_clip", "name": "ViT-L-14", "pretrained": "laion2b_s32b_b82k"}
}

# Select the device before defining load_model, which relies on it
device = "cuda" if torch.cuda.is_available() else "cpu"

def load_model(model_choice="CLIP (ViT-B/32)"):
    """Load the chosen CLIP variant and its image preprocessing transform."""
    model_config = MODELS[model_choice]
    
    if model_config["type"] == "clip":
        model, preprocess = clip.load(model_config["name"], device=device)
    else:
        model, _, preprocess = open_clip.create_model_and_transforms(
            model_config["name"],
            pretrained=model_config["pretrained"]
        )
        model = model.to(device)
    
    # Inference only: switch off training-time behavior such as dropout
    model.eval()
    return model, preprocess

# Initialize with default model
model, preprocess = load_model()

print(f"Model loaded on: {device}")
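
Loading a CLIP variant is slow, so an interface that switches between variants benefits from keeping recently used models in memory. The following is a minimal sketch of that idea using `functools.lru_cache`. `expensive_load` is a hypothetical stand-in for `load_model`, and the cache size of 2 is an arbitrary assumption; on a GPU with limited memory a smaller cache may be the safer choice.

```python
from functools import lru_cache

load_calls = []  # records which variants were actually loaded

def expensive_load(model_choice):
    """Hypothetical stand-in for load_model(); returns a placeholder object."""
    load_calls.append(model_choice)
    return {"name": model_choice}

@lru_cache(maxsize=2)
def get_model(model_choice):
    """Return the cached model for this variant, loading it on first use."""
    return expensive_load(model_choice)

first = get_model("CLIP (ViT-B/32)")
second = get_model("CLIP (ViT-B/32)")   # served from the cache, no reload
third = get_model("OpenCLIP (ViT-B-32)")
```

Here `first` and `second` are the same object, and `load_calls` shows that each variant was loaded only once.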

## Load Sample Dataset
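We'll use the CIFAR100 test split (10,000 images at 32×32 pixels) as a sample dataset, but you can replace this with your own images. Because these images are tiny and get upscaled during preprocessing, matches can be coarse.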

In [None]:
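# Load the CIFAR100 test split
dataset = CIFAR100(root=os.path.expanduser("~/.cache"), download=True, train=False)

# Keep the PIL images (converted to RGB) so they can be displayed as results
processed_images = [image.convert('RGB') for image, _ in dataset]

# Preprocess in batches and compute features (much faster than one image at a time)
batch_size = 256
image_features = []

for start in range(0, len(processed_images), batch_size):
    batch = processed_images[start:start + batch_size]
    image_input = torch.stack([preprocess(image) for image in batch]).to(device)
    with torch.no_grad():
        features = model.encode_image(image_input).float()
    
    # L2-normalize so that a dot product with a unit vector is a cosine similarity
    features = features / features.norm(dim=-1, keepdim=True)
    image_features.append(features.cpu())

# Stack all features
image_features = torch.cat(image_features).numpy()

print(f"Loaded {len(processed_images)} images")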
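
To search your own images instead of CIFAR100, you first need a list of image files. The following is a minimal sketch using only the standard library. The helper name `list_image_paths` and the extension set are assumptions for illustration, and the demo creates empty placeholder files in a temporary folder rather than real images.

```python
import tempfile
from pathlib import Path

# Assumed set of extensions; adjust to the formats in your collection
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp", ".webp"}

def list_image_paths(folder):
    """Return sorted paths of image files found under `folder`, recursively."""
    root = Path(folder)
    return sorted(
        path for path in root.rglob("*")
        if path.is_file() and path.suffix.lower() in IMAGE_EXTENSIONS
    )

# Demo on a temporary folder with placeholder files
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "cat.JPG").touch()
    (Path(tmp) / "notes.txt").touch()
    (Path(tmp) / "sub").mkdir()
    (Path(tmp) / "sub" / "dog.png").touch()
    found = [path.name for path in list_image_paths(tmp)]
```

With real files, each path can be opened with `Image.open(path).convert('RGB')` and appended to `processed_images` in place of the CIFAR100 images.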

## Create Search Function

In [None]:
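current_model = "CLIP (ViT-B/32)"  # variant that produced the cached image_features

def encode_images(images, batch_size=256):
    """Encode PIL images with the active model; returns L2-normalized features."""
    chunks = []
    for start in range(0, len(images), batch_size):
        batch = torch.stack([preprocess(img) for img in images[start:start + batch_size]]).to(device)
        with torch.no_grad():
            features = model.encode_image(batch).float()
        chunks.append((features / features.norm(dim=-1, keepdim=True)).cpu())
    return torch.cat(chunks).numpy()

def search_images(query, model_choice="CLIP (ViT-B/32)", num_results=9, similarity_threshold=0.2):
    global model, preprocess, current_model, image_features
    
    # Features from one variant are not comparable with another (sizes can differ),
    # so switching models means reloading and re-encoding the dataset
    if model_choice != current_model:
        model, preprocess = load_model(model_choice)
        image_features = encode_images(processed_images)
        current_model = model_choice
    
    # Encode text query; clip.tokenize is reused for every variant, assuming the
    # OpenCLIP models listed above share CLIP's tokenizer
    with torch.no_grad():
        tokens = clip.tokenize([query], truncate=True).to(device)
        text_features = model.encode_text(tokens).float()
    text_features = (text_features / text_features.norm(dim=-1, keepdim=True)).cpu().numpy()
    
    # Cosine similarity: normalize the image side too (a no-op if already unit length)
    image_unit = image_features.astype(np.float32)
    image_unit = image_unit / np.linalg.norm(image_unit, axis=1, keepdims=True)
    similarities = (image_unit @ text_features.T).squeeze(1)
    
    # Get top matches above threshold
    best_photo_idx = (-similarities).argsort()[:num_results]
    matches = [processed_images[i] for i in best_photo_idx if similarities[i] > similarity_threshold]
    
    # Return matching images, or None when nothing clears the threshold
    return matches or None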
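
The ranking step can be illustrated without any model. The following is a small sketch, in plain Python, of scoring items by cosine similarity, keeping the top `num_results`, and dropping anything at or below the threshold. The 2-D vectors are made up for illustration and are not real CLIP features; `cosine_similarity` and `top_matches` are hypothetical helper names.

```python
import math

def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

def top_matches(query_vec, item_vecs, num_results=3, threshold=0.2):
    """Return (index, score) pairs, best first, keeping only scores above threshold."""
    scored = [(i, cosine_similarity(query_vec, vec)) for i, vec in enumerate(item_vecs)]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return [(i, score) for i, score in scored[:num_results] if score > threshold]

# Toy "embeddings": same direction, orthogonal, diagonal, and opposite to the query
items = [[1.0, 0.0], [0.0, 1.0], [3.0, 3.0], [-1.0, 0.0]]
query = [2.0, 0.0]
matches = top_matches(query, items, num_results=3, threshold=0.2)
```

Item 0 scores 1.0 even though its length differs from the query's, which is why normalizing features matters: cosine similarity depends only on direction. Item 2 scores about 0.707, and the remaining items fall below the threshold.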

## Create Gradio Interface

In [None]:
def gradio_search(query, model_choice, similarity_threshold):
    results = search_images(query, model_choice, similarity_threshold=similarity_threshold)
    if results is None:
        return [None] * 9  # Return empty grid if no results
    
    # Pad results to fill grid
    while len(results) < 9:
        results.append(None)
    
    return results

interface = gr.Interface(
    fn=gradio_search,
    inputs=[
        gr.Textbox(
            label="Search Query",
            placeholder="Describe what you're looking for..."
        ),
        gr.Dropdown(
            choices=list(MODELS.keys()),
            value="CLIP (ViT-B/32)",
            label="Model",
            info="Choose between different CLIP variants"
        ),
        gr.Slider(
            minimum=0.1,
            maximum=0.9,
            value=0.2,
            step=0.1,
            label="Similarity Threshold",
            info="Higher values = stricter matching"
        )
    ],
    outputs=[
        gr.Image(label="Result 1"),
        gr.Image(label="Result 2"),
        gr.Image(label="Result 3"),
        gr.Image(label="Result 4"),
        gr.Image(label="Result 5"),
        gr.Image(label="Result 6"),
        gr.Image(label="Result 7"),
        gr.Image(label="Result 8"),
        gr.Image(label="Result 9")
    ],
    title="CLIP Image Search",
    description="Search through images using natural language descriptions. Choose between different CLIP variants for optimal results.",
    examples=[
        ["a photo of a dog", "CLIP (ViT-B/32)", 0.2],
        ["beautiful landscape with mountains", "OpenCLIP (ViT-L-14)", 0.2],
        ["red flowers in a garden", "OpenCLIP (ViT-B-32)", 0.2]
    ]
)

interface.launch(share=True)