# Image Search Engine using OpenAI CLIP

This notebook demonstrates how to build a text-to-image search engine with OpenAI's CLIP model, using Salesforce's BLIP model to generate a caption for each image.


In [None]:
!pip3 install torch torchvision torchaudio matplotlib transformers

In [None]:
import os

import numpy as np
import torch
from PIL import Image
from matplotlib import pyplot as plt
from transformers import CLIPProcessor, CLIPModel, BlipProcessor, BlipForConditionalGeneration

In [None]:
# Load the pre-trained CLIP model (shared text/image embedding space) and the
# BLIP captioning model, each with its matching processor.

CLIP_CHECKPOINT = 'openai/clip-vit-base-patch32'
BLIP_CHECKPOINT = "Salesforce/blip-image-captioning-base"

clip_model = CLIPModel.from_pretrained(CLIP_CHECKPOINT)
clip_processor = CLIPProcessor.from_pretrained(CLIP_CHECKPOINT)

blip_processor = BlipProcessor.from_pretrained(BLIP_CHECKPOINT)
blip_model = BlipForConditionalGeneration.from_pretrained(BLIP_CHECKPOINT)

In [None]:
# Load images from the specified directory and generate captions using the BLIP model.
# Then, prepare the images and their generated captions for processing with the CLIP model.

# TODO: Replace the image_dir with the path to your images
# You can find several images in week-1/images
image_dir = '/'
allowed_extensions = {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff'}


def _load_rgb_image(path):
    """Open an image file, decode it fully, and return it as an RGB image.

    A bare ``Image.open`` is lazy and keeps the file handle open until the
    image object is garbage-collected. Opening it in a ``with`` block and
    converting (which forces a full decode into a new in-memory image) closes
    the file right away. Converting to RGB also gives every image the same
    3-channel mode, whatever the source format (palette GIF, RGBA PNG,
    grayscale, ...).
    """
    with Image.open(path) as img:
        return img.convert('RGB')


images = [
    _load_rgb_image(os.path.join(image_dir, filename))
    for filename in sorted(os.listdir(image_dir))
    if os.path.splitext(filename)[-1].lower() in allowed_extensions
    # Skip directories that happen to have an image-like extension.
    and os.path.isfile(os.path.join(image_dir, filename))
]

# Fail early with an actionable message; later cells cannot handle an empty list.
if not images:
    raise FileNotFoundError(
        f"No images with extensions {sorted(allowed_extensions)} found in {image_dir!r}; "
        "set image_dir to a folder that contains images."
    )

def _caption_image(pil_image):
    """Run BLIP on one image and return its decoded caption text."""
    blip_inputs = blip_processor(images=pil_image, return_tensors="pt")
    token_ids = blip_model.generate(**blip_inputs)
    # generate() returns a batch of sequences; there is only one image here.
    return blip_processor.decode(token_ids[0], skip_special_tokens=True)


# One BLIP caption per loaded image, in the same order as `images`.
generated_captions = [_caption_image(picture) for picture in images]

# Score every image against every generated caption with CLIP.
# logits_per_image is (num_images, num_captions); argmax over dim=1 picks, for
# each image, the index of the caption CLIP rates most similar. Despite its
# name, `probs` holds those caption *indices*, not probabilities.
# max_length=77 is CLIP's text context length (the same value search_image uses).
inputs = clip_processor(
    text=generated_captions,
    images=images,
    return_tensors='pt',
    padding=True,
    truncation=True,
    max_length=77,
)

# Inference only: skip autograd bookkeeping to save memory.
with torch.no_grad():
    outputs = clip_model(**inputs)
probs = outputs.logits_per_image.argmax(dim=1)

In [None]:
# Generate image embeddings using the CLIP model.
# Each image is encoded on its own; the per-image feature rows are concatenated
# into one (num_images, D) tensor that search_image compares text queries against.
# (`torch` comes from the imports cell at the top of the notebook.)

image_feature_rows = []

# A single no_grad scope covers the whole loop: this is inference only.
with torch.no_grad():
    for image in images:
        pixel_inputs = clip_processor(images=image, return_tensors="pt")
        # get_image_features returns a (1, D) batch for a single image.
        image_feature_rows.append(clip_model.get_image_features(**pixel_inputs))

image_embeddings = torch.cat(image_feature_rows, dim=0)

In [None]:
# Display each image, titled with the caption CLIP ranked highest for it.
# probs[i] is the index (into generated_captions) of the best-matching caption
# for image i, so the caption shown may be one BLIP generated for a different image.

for i, image in enumerate(images):
    best_caption = generated_captions[probs[i].item()]

    fig, ax = plt.subplots()
    ax.imshow(np.asarray(image))
    ax.set_title(f"Caption: {best_caption}")
    ax.axis('off')
    plt.show()
    # Release the figure; otherwise figures pile up in memory inside the loop.
    plt.close(fig)

In [None]:
# Function to search for an image based on a text description

def search_image(query_text):
    """Return the loaded image whose CLIP embedding best matches ``query_text``.

    The query is encoded with CLIP's text encoder and scored against every row
    of the precomputed ``image_embeddings`` by cosine similarity. The element of
    ``images`` at the highest-scoring position is returned.

    Args:
        query_text: Free-form text describing the wanted image.

    Returns:
        The best-matching element of the notebook-level ``images`` list.
    """
    encoded_query = clip_processor(
        text=[query_text],
        return_tensors='pt',
        padding=True,
        truncation=True,
        max_length=77,
    )

    with torch.no_grad():
        query_features = clip_model.get_text_features(**encoded_query)

    # (1, D) query broadcasts against (N, D) images -> one score per image.
    scores = torch.nn.functional.cosine_similarity(query_features, image_embeddings)
    return images[int(scores.argmax())]

# Get user input for the image search.
# NOTE: input() pauses "Run All" until a response is entered.
# Surrounding whitespace is stripped, so a blank or whitespace-only entry counts as no input.
input_text = input("Enter a description to search for an image: ").strip()

if input_text:
    matched_image = search_image(input_text)

    fig, ax = plt.subplots()
    ax.imshow(np.asarray(matched_image))
    ax.set_title(f"Best match for: {input_text!r}")
    ax.axis('off')
    plt.show()
else:
    print("No input provided.")
