In [1]:
# Prep & dependencies

import html
import os
import time

import requests
import torch
import torch.nn.functional as F  # Import softmax from PyTorch
from IPython.display import display, Image, HTML
from transformers import CLIPModel, CLIPProcessor

In [None]:
# Run to load base model
# The model weights and the matching preprocessor are both created from the
# same pretrained checkpoint name.
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Load base beto
# Three artifacts are read from the local 'search/' directory. The search cells
# multiply `beto_normalized` by a unit-length query embedding and look up the
# URL of each hit in `beto_idx` by row position.
# NOTE(review): presumably `beto` holds the raw embeddings and
# `beto_normalized` their L2-normalised copy -- confirm against the script
# that wrote these files.
# NOTE(review): torch.load unpickles its input; only load files you trust.
beto = torch.load('search/beto.pt')
beto_idx = torch.load('search/beto_idx.pt')
beto_normalized = torch.load('search/beto_normalized.pt') 
print(beto.shape)  # quick sanity check on the dimensions of what was loaded

In [3]:
# DEPRECATED
# Run to load fine-tuned model
def load_fine_tuned_model(model_path):
    """Create a CLIP model and replace its weights with a fine-tuned checkpoint.

    Args:
        model_path: Path to a file readable by ``torch.load`` that contains a
            mapping with a ``'model_state_dict'`` entry.

    Returns:
        The model with the checkpoint weights loaded, switched to eval mode.
    """
    # Map every stored tensor to the CPU while reading the checkpoint.
    cpu = torch.device('cpu')
    checkpoint_data = torch.load(model_path, map_location=cpu)

    # Start from the pretrained architecture, then overwrite its parameters.
    clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32", state_dict=None)
    clip.load_state_dict(checkpoint_data['model_state_dict'])

    # Module.eval() returns the module itself, so it can be returned directly.
    return clip.eval()

# Load fine-tuned model
# Rebinds the notebook-level `model` and `processor`, so running this cell
# replaces whatever model an earlier cell loaded.
model_path = "fine-tuning/last_fine_tuned_clip.pt"  # Adjust this to your model's file path
model = load_fine_tuned_model(model_path)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Load fine-tuned beto
# Rebinds `beto`, `beto_idx` and `beto_normalized` to the '_ft' files.
# NOTE(review): these paths are relative to the working directory and have no
# 'search/' prefix, unlike the base files loaded earlier -- confirm where the
# '_ft' files actually live.
# NOTE(review): torch.load unpickles its input; only load files you trust.
beto = torch.load('beto_ft.pt')
beto_idx = torch.load('beto_idx_ft.pt')
beto_normalized = torch.load('beto_normalized_ft.pt') 
print(beto.shape)  # quick sanity check on the dimensions of what was loaded

torch.Size([408896, 512])


# Text search

In [None]:
# Text search: embed a text query and rank every row of `beto_normalized`
# against it.
# Make sure 'processor', 'model', 'beto_normalized', and 'beto_idx' are predefined!
results = int(input("Enter the number of results to display: "))
query = input("Enter the query to search for: ")
if "map" not in query:
    query += " map"
print(query)

# torch.topk raises for a negative k or a k larger than the number of rows,
# so clamp the request to what the index can actually return.
results = max(0, min(results, beto_normalized.shape[0]))

start = time.time()

# Pure inference: no_grad avoids building an autograd graph for the text
# encoder and the similarity product.
with torch.no_grad():
    # Embed the query and scale it to unit length.
    text_preprocess = processor(text=query, return_tensors="pt", padding=True)
    text_embeds = model.get_text_features(**text_preprocess)
    text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
    logit_scale = model.logit_scale.exp()

    # One column of scaled similarities, one row per indexed embedding.
    logits = torch.matmul(beto_normalized, text_embeds.t()) * logit_scale

    # Keep only the k best rows.
    top_logits, top_indices = torch.topk(logits, k=results, dim=0)

# There is a single query, hence a single column: pull it out as plain Python
# numbers so the same integer positions are used for printing and for lookup.
hits = list(zip(top_indices[:, 0].tolist(), top_logits[:, 0].tolist()))

# Display the results
for rank, (idx, score) in enumerate(hits, start=1):
    print(f"Result {rank}: {beto_idx[idx]} with similarity {score:.3f}")

image_urls = [beto_idx[idx] for idx, _ in hits]
# Escape each URL before placing it in the src attribute so a quote or angle
# bracket inside a URL cannot break (or inject into) the generated markup.
images_html = "".join(
    f"<img style='width: 400px; margin: 0px; float: left; border: 1px solid black;' src='{html.escape(str(url), quote=True)}' />"
    for url in image_urls
)
display(HTML(images_html))

print("Time: {:.3f} seconds".format(time.time()-start))

# Image search

In [None]:
from PIL import Image

# Image search: embed an image fetched from a URL and rank every row of
# `beto_normalized` against it.
# Assuming 'processor', 'model', 'beto_normalized', and 'beto_idx' are predefined
results = int(input("Enter the number of results: "))
query = input("Enter the URL of the image to search for: ")

# torch.topk raises for a negative k or a k larger than the number of rows,
# so clamp the request to what the index can actually return.
results = max(0, min(results, beto_normalized.shape[0]))

start = time.time()

# Download the query image (qimg stands for query image).
# requests waits forever unless a timeout is given, and without
# raise_for_status() an HTTP error page would be passed on to PIL and fail
# with a misleading "cannot identify image" error.
response = requests.get(query, stream=True, timeout=30)
response.raise_for_status()
image = Image.open(response.raw)

# Pure inference: no_grad avoids building an autograd graph for the image
# encoder and the similarity product.
with torch.no_grad():
    # Embed the query image and scale it to unit length.
    qimg_preprocess = processor(images=image, return_tensors="pt", padding=True)
    qimg_embeds = model.get_image_features(**qimg_preprocess)
    qimg_embeds = qimg_embeds / qimg_embeds.norm(p=2, dim=-1, keepdim=True)
    logit_scale = model.logit_scale.exp()

    # One column of scaled similarities, one row per indexed embedding.
    logits = torch.matmul(beto_normalized, qimg_embeds.t()) * logit_scale

    # Directly find top-k values and indices for the query
    top_values, top_indices = torch.topk(logits, k=results, dim=0)

# There is a single query, hence a single column: pull it out as plain Python
# numbers so the same integer positions are used for printing and for lookup.
hits = list(zip(top_indices[:, 0].tolist(), top_values[:, 0].tolist()))

# Display the results
for rank, (idx, score) in enumerate(hits, start=1):
    print(f"Result {rank}: {beto_idx[idx]} with similarity {score:.3f}")

image_urls = [beto_idx[idx] for idx, _ in hits]
# Escape each URL before placing it in the src attribute so a quote or angle
# bracket inside a URL cannot break (or inject into) the generated markup.
images_html = "".join(
    f"<img style='width: 400px; margin: 0px; float: left; border: 1px solid black;' src='{html.escape(str(url), quote=True)}' />"
    for url in image_urls
)
display(HTML(images_html))

print("Time: {:.3f} seconds".format(time.time()-start))

# Text + image search

In [None]:
from PIL import Image
import numpy as np

# Text + image search: blend an image embedding and a text embedding into a
# single query vector and rank every row of `beto_normalized` against it.
# Requires 'processor', 'model', 'beto_normalized', and 'beto_idx' to be predefined.

# Input for both image and text
results = int(input("Enter the number of results: "))
img_query = input("Enter the URL of the image to search for: ")
text_query = input("Enter the text query to search for: ")
if "map" not in text_query:
    text_query = "a " + text_query + " map"
# With the weights used below, -1 means image only, 0 weights both equally,
# and +1 means text only.
scaling = float(input("Enter the scaling factor for the combined query: "))

# torch.topk raises for a negative k or a k larger than the number of rows,
# so clamp the request to what the index can actually return.
results = max(0, min(results, beto_normalized.shape[0]))

start = time.time()

# Download the query image (qimg stands for query image).
# requests waits forever unless a timeout is given, and without
# raise_for_status() an HTTP error page would be passed on to PIL and fail
# with a misleading "cannot identify image" error.
response = requests.get(img_query, stream=True, timeout=30)
response.raise_for_status()
qimg = Image.open(response.raw)

# Pure inference: no_grad avoids building an autograd graph for the encoders
# and the similarity product.
with torch.no_grad():
    # Image embedding, scaled to unit length.
    qimg_preprocess = processor(images=qimg, return_tensors="pt", padding=True)
    qimg_embeds = model.get_image_features(**qimg_preprocess)
    qimg_embeds = qimg_embeds / qimg_embeds.norm(p=2, dim=-1, keepdim=True)

    # Text embedding (qtext stands for query text), scaled to unit length.
    text_preprocess = processor(text=text_query, return_tensors="pt", padding=True)
    qtext_embeds = model.get_text_features(**text_preprocess)
    qtext_embeds = qtext_embeds / qtext_embeds.norm(p=2, dim=-1, keepdim=True)

    # Weighted contributions of each modality.
    qimg_embeds_input = ((1 - scaling) * qimg_embeds) / 2
    qtext_embeds_input = ((1 + scaling) * qtext_embeds) / 2

    # Create a tensor for the combined query. A blend of two unit vectors is
    # not itself unit length, so re-normalise it: the ranking is unchanged, but
    # the reported similarity is then on the same scale as a single-modality
    # query. clamp_min guards the division if the two parts cancel out.
    combined_embeds = torch.add(qimg_embeds_input, qtext_embeds_input)
    combined_norm = combined_embeds.norm(p=2, dim=-1, keepdim=True).clamp_min(1e-12)
    combined_embeds = combined_embeds / combined_norm
    logit_scale = model.logit_scale.exp()

    # One column of scaled similarities, one row per indexed embedding.
    logits = torch.matmul(beto_normalized, combined_embeds.t()) * logit_scale

    # Directly find top-k values and indices for the query
    top_values, top_indices = torch.topk(logits, k=results, dim=0)

# There is a single query, hence a single column: pull it out as plain Python
# numbers so the same integer positions are used for printing and for lookup.
hits = list(zip(top_indices[:, 0].tolist(), top_values[:, 0].tolist()))

# Display the results
for rank, (idx, score) in enumerate(hits, start=1):
    print(f"Result {rank}: {beto_idx[idx]} with similarity {score:.3f}")

image_urls = [beto_idx[idx] for idx, _ in hits]
# Escape each URL before placing it in the src attribute so a quote or angle
# bracket inside a URL cannot break (or inject into) the generated markup.
images_html = "".join(
    f"<img style='width: 400px; margin: 0px; float: left; border: 1px solid black;' src='{html.escape(str(url), quote=True)}' />"
    for url in image_urls
)
display(HTML(images_html))

print("Time: {:.3f} seconds".format(time.time()-start))