In [None]:
import os
import json
import torch
import clip
from PIL import Image
import numpy as np
from tqdm import tqdm

# Choose the compute device: CUDA when torch reports it available, else CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the CLIP ViT-B/32 model on that device, together with the
# preprocessing callable that clip.load returns alongside it.
model, preprocess = clip.load("ViT-B/32", device=device)

def load_images_from_folder(folder_path):
    """Load and preprocess every image file found directly in a folder.

    Files are selected by extension (case-insensitive) and visited in sorted
    filename order, so the returned lists are reproducible across runs and
    platforms; ``os.listdir`` on its own yields entries in arbitrary order.

    Args:
        folder_path: Directory to scan. Sub-directories are not traversed.

    Returns:
        A ``(image_paths, images)`` tuple of two parallel lists: the path of
        each image that loaded successfully, and the result of applying the
        module-level ``preprocess`` callable to its RGB conversion.

    Raises:
        OSError: If ``folder_path`` cannot be listed (e.g. it does not exist).
    """
    image_extensions = ('.png', '.jpg', '.jpeg', '.bmp', '.gif')
    image_paths = []
    images = []
    for filename in sorted(os.listdir(folder_path)):
        if not filename.lower().endswith(image_extensions):
            continue
        path = os.path.join(folder_path, filename)
        try:
            # The context manager releases the underlying file handle, which
            # Pillow may otherwise keep open (notably for multi-frame files).
            with Image.open(path) as raw:
                image = preprocess(raw.convert("RGB"))
        except Exception as e:
            # Best effort: report the unreadable file and keep scanning.
            print(f"Error loading image {path}: {e}")
        else:
            # Append to both lists together so indices stay aligned.
            images.append(image)
            image_paths.append(path)
    return image_paths, images

# Folders holding the gallery (search pool) and query images.
gallery_folder = "test/gallery"
query_folder = "test/query"

# Read and preprocess both sets; each call returns (paths, images).
gallery_paths, gallery_images = load_images_from_folder(gallery_folder)
query_paths, query_images = load_images_from_folder(query_folder)

def _embed_images(images, desc):
    """Encode preprocessed images into L2-normalised embeddings.

    Each image is encoded on its own with the module-level ``model`` on
    ``device``, divided by its norm along the last axis, and moved to the CPU
    as a NumPy array; the per-image arrays are then stacked vertically.

    Args:
        images: List of preprocessed images as produced by
            ``load_images_from_folder`` (each must support
            ``.unsqueeze(0).to(device)``).
        desc: Label shown on the tqdm progress bar.

    Returns:
        A NumPy array built by ``np.vstack`` over the per-image embeddings.

    Raises:
        ValueError: If ``images`` is empty. ``np.vstack`` would raise the
            same exception type, but without saying which set was empty.
    """
    if not images:
        raise ValueError(f"No images to embed ({desc})")
    embeddings = []
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        for image in tqdm(images, desc=desc):
            batch = image.unsqueeze(0).to(device)  # add a leading batch axis
            embedding = model.encode_image(batch)
            # Normalise to unit length so a dot product is cosine similarity.
            embedding /= embedding.norm(dim=-1, keepdim=True)
            embeddings.append(embedding.cpu().numpy())
    return np.vstack(embeddings)


# Compute embeddings for gallery images
gallery_embeddings = _embed_images(gallery_images, "Processing gallery images")

# Compute embeddings for query images
query_embeddings = _embed_images(query_images, "Processing query images")

# Rank the gallery against each query and keep the best matches.
# `res` maps a query filename to its gallery filenames, best match first.
top_k = 10  # Adjust as needed
res = {}

# Bare filenames (no folder path) for every gallery entry, computed once.
gallery_names = [os.path.basename(p) for p in gallery_paths]

for query_path, query_vector in zip(query_paths, query_embeddings):
    # One score per gallery row; the embeddings were divided by their norm
    # upstream, so this dot product is the cosine similarity.
    scores = (gallery_embeddings @ query_vector.T).ravel()

    # Ascending argsort, reversed, then truncated: highest scores first.
    ranked = np.argsort(scores)[::-1][:top_k]

    res[os.path.basename(query_path)] = [gallery_names[i] for i in ranked]