In [None]:
# Notebook to be executed entirely in Google Colab
# (the `google.colab` import below is what ties it to that environment).

from google.colab import drive
# Mount Google Drive at /content/drive; later cells read and write files under this mount point.
drive.mount('/content/drive')

In [None]:
!pip install torch transformers pillow numpy scikit-learn matplotlib

In [None]:
# Colab Cell 3: Authenticate Kaggle API and Download Dataset

# Create the directory the Kaggle CLI reads its credentials from
!mkdir -p ~/.kaggle

# Copy the uploaded kaggle.json into that directory.
# 'kaggle.json' must already be in /content/ (e.g. uploaded via the Colab file browser).
!cp /content/kaggle.json ~/.kaggle/

# Restrict the API key to owner read/write only
!chmod 600 ~/.kaggle/kaggle.json

# --- Download the Sketchy Dataset from Kaggle ---
# Download target. This path is outside the Drive mount (/content/drive),
# so the data is not written to Google Drive.
DOWNLOAD_PATH = "/content/data" # Also the root that later cells build dataset paths from

# Create the download directory if it doesn't exist
!mkdir -p "$DOWNLOAD_PATH"

# Change the working directory so the archive is downloaded and unzipped in place
%cd "$DOWNLOAD_PATH"

# Replace with the slug of YOUR chosen dataset if different
# (on the dataset's Kaggle page: '...' -> 'Copy API command').
KAGGLE_DATASET_SLUG = "dhananjayapaliwal/fulldataset" # presumably a Sketchy mirror; later cells expect a 256x256/ layout under temp_extraction/ - TODO confirm
!kaggle datasets download -d {KAGGLE_DATASET_SLUG} --unzip

print("\n--- Listing extracted contents to verify path ---")
# List the download directory so the extracted folder names can be checked by eye
!ls -F "$DOWNLOAD_PATH"

In [None]:
# Colab Cell 3: Authenticate Kaggle API and Download Dataset

# Create a directory for Kaggle configuration
!mkdir -p ~/.kaggle

# Copy the uploaded kaggle.json to the correct directory
# Make sure 'kaggle.json' is in /content/ (uploaded via the file browser)
!cp /content/kaggle.json ~/.kaggle/

# Set appropriate permissions for the Kaggle API key (important for security)
!chmod 600 ~/.kaggle/kaggle.json

# --- Download the Sketchy Dataset from Kaggle ---
# Define a directory in Colab's temporary storage for downloads
# This is *not* on your Google Drive initially, but faster for downloads
DOWNLOAD_PATH = "/content/data" # Ensure this variable is clean

# Change current directory to where we want to download
%cd "$DOWNLOAD_PATH"

# Replace with the exact Kaggle API command for YOUR chosen Sketchy dataset if different
# You can find this command on the dataset's Kaggle page under '...' -> 'Copy API command'
KAGGLE_DATASET_SLUG = "rishikashili/tuberlin" # This is a common Sketchy dataset ID
!kaggle datasets download -d {KAGGLE_DATASET_SLUG} --unzip

print("\n--- Listing extracted contents to verify path ---")
# List contents of the download directory to confirm extraction
!ls -F "$DOWNLOAD_PATH"

In [None]:
import os
import random
import json

# --- Set seed for reproducibility ---
random.seed(42)

# --- Paths ---
KAGGLE_EXTRACTED_FOLDER_NAME = "temp_extraction"
BASE_DATA_DIR = os.path.join("/content/data", KAGGLE_EXTRACTED_FOLDER_NAME)
TRAIN_SKETCH_PATH = os.path.join(BASE_DATA_DIR, "256x256", "splitted_sketches", "train", "tx_000100000000")
TRAIN_PHOTO_PATH = os.path.join(BASE_DATA_DIR, "256x256", "photo", "tx_000100000000")
GOOGLE_DRIVE_ROOT = "/content/drive/My Drive"
EMBEDDINGS_DIR_ON_DRIVE = os.path.join(GOOGLE_DRIVE_ROOT, "ML_Embeddings_SBIR_PyTorch")
os.makedirs(EMBEDDINGS_DIR_ON_DRIVE, exist_ok=True)
PROJECTION_NET_SAVE_PATH = os.path.join(EMBEDDINGS_DIR_ON_DRIVE, "sketch_projection_net_category_infonce.pth")

# A category is any sub-directory of the training sketch folder. Sorting makes the
# seeded sampling below independent of the order in which the OS lists entries.
ALL_CATEGORIES = sorted(
    cat for cat in os.listdir(TRAIN_SKETCH_PATH)
    if os.path.isdir(os.path.join(TRAIN_SKETCH_PATH, cat))
)

print(f"Total categories found: {len(ALL_CATEGORIES)}")

# --- Split into train categories and held-out test categories ---
NUM_TRAIN_CATEGORIES = 100  # every remaining category becomes a test category
if len(ALL_CATEGORIES) < NUM_TRAIN_CATEGORIES:
    # random.sample would raise a ValueError here anyway; say what is actually wrong.
    raise ValueError(
        f"Need at least {NUM_TRAIN_CATEGORIES} categories under {TRAIN_SKETCH_PATH}, "
        f"found {len(ALL_CATEGORIES)}. Check that the dataset was extracted correctly."
    )

train_categories = random.sample(ALL_CATEGORIES, NUM_TRAIN_CATEGORIES)
# Set lookup keeps the complement computation linear in the number of categories.
train_category_set = set(train_categories)
test_categories = [cat for cat in ALL_CATEGORIES if cat not in train_category_set]

print(f"Train categories: {len(train_categories)}")
print(f"Test categories: {len(test_categories)}")

# --- Save to JSON for later use ---
# The training and evaluation cells read this file back instead of re-sampling.
SPLIT_SAVE_PATH = os.path.join(EMBEDDINGS_DIR_ON_DRIVE, "category_split.json")
with open(SPLIT_SAVE_PATH, "w") as f:
    json.dump({
        "train_categories": train_categories,
        "test_categories": test_categories
    }, f, indent=2)

print(f"Category split saved at {SPLIT_SAVE_PATH}")


In [None]:
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
from PIL import Image
import os
from tqdm import tqdm
import glob
import torch.nn.functional as F
from transformers import CLIPProcessor, CLIPModel
import json

# --- Device Configuration ---
# Prefer the GPU when PyTorch can see one.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# --- Load Pre-trained CLIP Model (Frozen) ---
# Width of the vectors fed into (and produced by) the projection network below.
CLIP_EMBEDDING_DIM = 512

processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
model.eval()
# CLIP only serves as a fixed feature extractor: switch off gradients for every weight.
for param in model.parameters():
    param.requires_grad = False

# --- Dataset Class for Category-level Training ---
class CategorySketchPhotoDataset(Dataset):
    """Category-level sketch/photo pairs.

    Each item is one sketch together with a photo drawn at random from the
    same category, plus the integer category label.

    Args:
        sketch_dir: directory containing one sub-directory of sketches per category.
        photo_dir: directory containing one sub-directory of photos per category.
        category_list: category (sub-directory) names to load; a category's
            label is its position in this list.

    Attributes set on the instance: ``sketch_files``/``sketch_labels``,
    ``photo_files``/``photo_labels``, ``category_to_idx`` and
    ``photo_by_category`` (label -> list of photo paths).

    Sketches whose category has no photo at all are dropped during
    construction, because no positive pair could ever be built for them
    (indexing them would otherwise raise ``KeyError``).
    """

    def __init__(self, sketch_dir, photo_dir, category_list):
        self.category_to_idx = {cat: idx for idx, cat in enumerate(category_list)}

        self.sketch_files, self.sketch_labels = self._scan(sketch_dir, category_list)
        self.photo_files, self.photo_labels = self._scan(photo_dir, category_list)

        # Build category index for faster sampling
        self.photo_by_category = {}
        for file, label in zip(self.photo_files, self.photo_labels):
            self.photo_by_category.setdefault(label, []).append(file)

        # Keep only sketches that have at least one photo to be paired with.
        paired = [
            (file, label)
            for file, label in zip(self.sketch_files, self.sketch_labels)
            if label in self.photo_by_category
        ]
        dropped = len(self.sketch_files) - len(paired)
        if dropped:
            print(f"Warning: dropped {dropped} sketches whose category has no photos in {photo_dir}")
        self.sketch_files = [file for file, _ in paired]
        self.sketch_labels = [label for _, label in paired]

    def _scan(self, root_dir, category_list):
        """Return (files, labels) for every entry found under root_dir/<category>."""
        files, labels = [], []
        for category in category_list:
            cat_path = os.path.join(root_dir, category)
            if not os.path.isdir(cat_path):
                continue
            # Sorted so the sample order does not depend on the filesystem.
            found = sorted(glob.glob(os.path.join(cat_path, '*')))
            files.extend(found)
            labels.extend([self.category_to_idx[category]] * len(found))
        return files, labels

    def __len__(self):
        return len(self.sketch_files)

    def __getitem__(self, idx):
        """Return (sketch PIL image, photo PIL image, int label), both images in RGB."""
        sketch_path = self.sketch_files[idx]
        label = self.sketch_labels[idx]

        # Sample a random photo from the same category. Indexing the list avoids
        # converting it to a NumPy array on every call.
        candidates = self.photo_by_category[label]
        photo_path = candidates[np.random.randint(len(candidates))]

        # convert() returns a fully loaded copy, so the file handles can be closed.
        with Image.open(sketch_path) as raw_sketch:
            sketch_img = raw_sketch.convert("RGB")
        with Image.open(photo_path) as raw_photo:
            photo_img = raw_photo.convert("RGB")

        return sketch_img, photo_img, label

# --- Faster Collate Function ---
def fast_collate_fn(batch):
    """Collate (sketch image, photo image, label) samples into batched tensors.

    Runs the module-level ``processor`` once over all sketches and once over
    all photos of the batch and returns
    ``(sketch pixel_values, photo pixel_values, labels tensor)``.
    """
    sketch_imgs, photo_imgs, category_ids = zip(*batch)

    def _pixel_values(images):
        # One processor call per modality for the whole batch.
        encoded = processor(images=list(images), return_tensors="pt", padding=True)
        return encoded['pixel_values']

    sketch_pixels = _pixel_values(sketch_imgs)
    photo_pixels = _pixel_values(photo_imgs)
    return sketch_pixels, photo_pixels, torch.tensor(category_ids)

# --- Setup Dataset and Dataloader ---
# Read back the category split written by the split cell.
SPLIT_SAVE_PATH = os.path.join(EMBEDDINGS_DIR_ON_DRIVE, "category_split.json")
with open(SPLIT_SAVE_PATH, "r") as split_file:
    split = json.load(split_file)

train_categories = split["train_categories"]

# Only the training categories are loaded; the test categories stay unseen.
train_dataset = CategorySketchPhotoDataset(
    TRAIN_SKETCH_PATH,
    TRAIN_PHOTO_PATH,
    train_categories,
)
print(f"Total sketch-photo category pairs loaded: {len(train_dataset)}")
dataloader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=2,
    collate_fn=fast_collate_fn,
)

# --- Define the Sketch Projection Network ---
class SketchProjectionNet(nn.Module):
    """Two-layer MLP (Linear -> ReLU -> Linear) applied to sketch embeddings.

    Args:
        input_dim: size of the incoming embedding; also the hidden width.
        output_dim: size of the projected embedding.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        hidden_layer = nn.Linear(input_dim, input_dim)
        output_layer = nn.Linear(input_dim, output_dim)
        # The attribute must stay `net` (a Sequential) so that checkpoint keys
        # remain net.0.* / net.2.*.
        self.net = nn.Sequential(hidden_layer, nn.ReLU(), output_layer)

    def forward(self, x):
        projected = self.net(x)
        return projected

# Trainable projection head with equal input and output width (CLIP_EMBEDDING_DIM).
projection_net = SketchProjectionNet(CLIP_EMBEDDING_DIM, CLIP_EMBEDDING_DIM).to(device)
# Only the projection head's parameters are handed to the optimizer.
optimizer = optim.Adam(projection_net.parameters(), lr=1e-4)

# --- InfoNCE Loss ---
def info_nce_loss(sketch_embeds, photo_embeds, temperature=0.07):
    """InfoNCE loss with in-batch negatives.

    Row ``i`` of ``sketch_embeds`` is treated as the positive partner of row
    ``i`` of ``photo_embeds``; every other photo in the batch is a negative.

    Args:
        sketch_embeds: 2-D tensor, one embedding per row.
        photo_embeds: 2-D tensor with the same number of rows and columns.
        temperature: divisor applied to the cosine-similarity logits.

    Returns:
        Scalar tensor: mean cross-entropy over the rows of the similarity matrix.
    """
    sketch_embeds = F.normalize(sketch_embeds, p=2, dim=-1)
    photo_embeds = F.normalize(photo_embeds, p=2, dim=-1)

    # After L2 normalisation the dot product is the cosine similarity.
    logits = torch.matmul(sketch_embeds, photo_embeds.T) / temperature

    # The matching photo sits on the diagonal. Create the targets on the same
    # device as the logits instead of relying on a global `device` variable.
    targets = torch.arange(logits.size(0), device=logits.device)
    return F.cross_entropy(logits, targets)

# --- Training Loop ---
NUM_EPOCHS = 5
for epoch in range(NUM_EPOCHS):
    projection_net.train()
    total_loss = 0
    epoch_batches = tqdm(dataloader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}")
    for sketch_pixels, photo_pixels, _ in epoch_batches:
        sketch_pixels = sketch_pixels.to(device)
        photo_pixels = photo_pixels.to(device)

        # CLIP is frozen, so its forward passes need no autograd graph.
        with torch.no_grad():
            sketch_embeds = model.get_image_features(sketch_pixels)
            photo_embeds = model.get_image_features(photo_pixels)

        # Only the projection head is updated: projected sketches are pulled
        # towards the (unprojected) photo embeddings.
        optimizer.zero_grad()
        loss = info_nce_loss(projection_net(sketch_embeds), photo_embeds)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch+1} completed. Avg Loss: {total_loss / len(dataloader):.4f}")

# --- Save Projection Network ---
# Only the projection head's weights are stored; CLIP itself is not saved.
trained_state = projection_net.state_dict()
torch.save(trained_state, PROJECTION_NET_SAVE_PATH)
print(f"Model saved at {PROJECTION_NET_SAVE_PATH}")


In [None]:
# Colab Cell 5: Evaluate with the Trained Projection Network

import torch
# --- CORRECTED IMPORTS HERE ---
import torch.nn as nn # <-- Added
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from transformers import CLIPProcessor, CLIPModel
from PIL import Image # <-- Added
import os
from tqdm import tqdm # <-- Added (though not used in eval, good practice if modified)
import random # <-- Added
import glob # <-- Added (for finding random sketch)
import matplotlib.pyplot as plt

# --- Device Configuration ---
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# --- Load Pre-trained CLIP Model (Frozen for inference) ---
print("\nLoading CLIP model for inference...")
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
model.to(device)
# Put the model in evaluation mode explicitly, as the other cells that load CLIP do.
model.eval()
for param in model.parameters(): # Ensure it's frozen again if running independently
    param.requires_grad = False
print("CLIP model loaded.")

# --- Define the Sketch Projection Network (Must be defined again in this cell) ---
class SketchProjectionNet(nn.Module):
    """Projection head for sketch embeddings: Linear -> ReLU -> Linear.

    Must keep the same layer layout as the network that was trained, otherwise
    the saved state dict cannot be loaded.

    Args:
        input_dim: size of the incoming embedding; also the hidden width.
        output_dim: size of the projected embedding.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        layers = [
            nn.Linear(input_dim, input_dim),
            nn.ReLU(),
            nn.Linear(input_dim, output_dim),
        ]
        # Stored as `net` so the parameter names are net.0.* and net.2.*.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        out = self.net(x)
        return out

CLIP_EMBEDDING_DIM = 512
projection_net = SketchProjectionNet(CLIP_EMBEDDING_DIM, CLIP_EMBEDDING_DIM).to(device)

# --- Define Dataset and Embedding Paths (GLOBAL TO THIS CELL) ---
KAGGLE_EXTRACTED_FOLDER_NAME = "temp_extraction" # Must be consistent with Cell 3 & 4
BASE_DATA_DIR = os.path.join("/content/data", KAGGLE_EXTRACTED_FOLDER_NAME)

# Paths for saving/loading embeddings on Google Drive for persistence
GOOGLE_DRIVE_ROOT = "/content/drive/My Drive"
EMBEDDINGS_DIR_ON_DRIVE = os.path.join(GOOGLE_DRIVE_ROOT, "ML_Embeddings_SBIR_PyTorch")
os.makedirs(EMBEDDINGS_DIR_ON_DRIVE, exist_ok=True) # Ensure this directory exists
# Must be the file name the training cell saves to; the previous name
# ("sketch_projection_net.pth") is not written anywhere in this notebook.
PROJECTION_NET_SAVE_PATH = os.path.join(EMBEDDINGS_DIR_ON_DRIVE, "sketch_projection_net_category_infonce.pth")
GALLERY_EMBEDDINGS_FILE = os.path.join(EMBEDDINGS_DIR_ON_DRIVE, "gallery_embeddings.npy")
GALLERY_IMAGE_PATHS_FILE = os.path.join(EMBEDDINGS_DIR_ON_DRIVE, "gallery_image_filepaths.npy")

if not os.path.exists(PROJECTION_NET_SAVE_PATH):
    # Raise instead of calling exit(): in a notebook exit() shuts the kernel down
    # and discards all state, whereas an exception only stops this cell.
    raise FileNotFoundError(
        f"Error: Trained Sketch Projection Network not found at {PROJECTION_NET_SAVE_PATH}\n"
        "Please run the training cell (Cell 4) first and ensure it completes successfully."
    )

projection_net.load_state_dict(torch.load(PROJECTION_NET_SAVE_PATH, map_location=device))
projection_net.eval() # Set to evaluation mode (important for dropout/batchnorm layers if present)
print(f"Sketch Projection Network loaded from: {PROJECTION_NET_SAVE_PATH}")

# --- Load Your Image Gallery Embeddings (using original CLIP Image Encoder) ---
# The gallery is encoded with the frozen CLIP model only (no projection network)
# and cached on Drive so the encoding pass is paid once.
gallery_embeddings = None
gallery_image_filepaths = None


if os.path.exists(GALLERY_EMBEDDINGS_FILE) and os.path.exists(GALLERY_IMAGE_PATHS_FILE):
    # Cache hit: both the embedding matrix and the parallel list of paths exist.
    print(f"\nLoading pre-saved gallery embeddings from {EMBEDDINGS_DIR_ON_DRIVE}...")
    gallery_embeddings = np.load(GALLERY_EMBEDDINGS_FILE)
    # NOTE(review): allow_pickle=True unpickles whatever the file contains; only
    # acceptable because the file is written by this notebook.
    gallery_image_filepaths = np.load(GALLERY_IMAGE_PATHS_FILE, allow_pickle=True)
    print(f"Loaded gallery embeddings shape: {gallery_embeddings.shape}")
else:
    print("\nGallery embeddings not found. Attempting to build them now (this may take a while)...")
    gallery_embeddings_list = []
    gallery_image_filepaths_list = []

    IMAGE_GALLERY_PATH = os.path.join(BASE_DATA_DIR, "256x256", "photo", "tx_000100000000") # Ensure this is defined
    if not os.path.exists(IMAGE_GALLERY_PATH):
        # Raise instead of exit(): exit() would shut the notebook kernel down.
        raise FileNotFoundError(
            f"Error: Image gallery path does not exist for building: {IMAGE_GALLERY_PATH}\n"
            "Please ensure your dataset is downloaded/unzipped correctly and paths are set correctly."
        )

    with torch.no_grad():
        for category_name in tqdm(os.listdir(IMAGE_GALLERY_PATH), desc="Processing gallery images"):
            category_path = os.path.join(IMAGE_GALLERY_PATH, category_name)
            if not os.path.isdir(category_path):
                continue

            for image_filename in os.listdir(category_path):
                if image_filename.lower().endswith(('.png', '.jpg', '.jpeg')):
                    image_filepath = os.path.join(category_path, image_filename)
                    try:
                        image = Image.open(image_filepath).convert("RGB")
                        inputs = processor(images=image, return_tensors="pt")
                        inputs = {k: v.to(device) for k, v in inputs.items()}
                        image_feature = model.get_image_features(**inputs).detach().cpu().numpy()
                        gallery_embeddings_list.append(image_feature)
                        gallery_image_filepaths_list.append(image_filepath)
                    except Exception as e:
                        # Best effort: one unreadable image must not abort the whole pass.
                        print(f"Could not process image {image_filepath}: {e}")

    if not gallery_embeddings_list:
        # Raise instead of exit() so the kernel and its state survive.
        raise RuntimeError(
            "No images processed in the gallery. Check your IMAGE_GALLERY_PATH and dataset contents."
        )

    # Stack the per-image feature arrays into one matrix, row i matching filepath i.
    gallery_embeddings = np.vstack(gallery_embeddings_list)

    gallery_image_filepaths = np.array(gallery_image_filepaths_list)
    print(f"Encoded {len(gallery_embeddings_list)} gallery images. Embeddings shape: {gallery_embeddings.shape}")

    np.save(GALLERY_EMBEDDINGS_FILE, gallery_embeddings)
    np.save(GALLERY_IMAGE_PATHS_FILE, gallery_image_filepaths)
    print(f"Gallery embeddings saved to {EMBEDDINGS_DIR_ON_DRIVE}")


In [None]:
import os
import glob
import numpy as np
from PIL import Image
from tqdm import tqdm
from sklearn.metrics import average_precision_score
from sklearn.metrics.pairwise import cosine_similarity
import torch
import torch.nn as nn
import torch.nn.functional as F
import json
from transformers import CLIPProcessor, CLIPModel

# --- Device ---
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# --- Paths (update these for your environment) ---
# BASE_DATA_DIR and EMBEDDINGS_DIR_ON_DRIVE are not defined in this cell;
# a cell that defines them has to be run first.
TEST_SKETCH_PATH = os.path.join(BASE_DATA_DIR, "256x256", "splitted_sketches", "test", "tx_000100000000")
TEST_PHOTO_PATH = os.path.join(BASE_DATA_DIR, "256x256", "photo", "tx_000100000000")
SPLIT_SAVE_PATH = os.path.join(EMBEDDINGS_DIR_ON_DRIVE, "category_split.json")
PROJECTION_NET_SAVE_PATH = os.path.join(EMBEDDINGS_DIR_ON_DRIVE, "sketch_projection_net_category_infonce.pth")

# --- Load category split ---
# Only the held-out categories are evaluated in this cell.
with open(SPLIT_SAVE_PATH, "r") as split_file:
    split = json.load(split_file)
test_categories = split["test_categories"]

# --- Load CLIP model ---
CLIP_EMBEDDING_DIM = 512

processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
model.eval()
# Inference only: no gradients are needed for any CLIP weight.
for param in model.parameters():
    param.requires_grad = False

# --- Define SketchProjectionNet ---
class SketchProjectionNet(nn.Module):
    """Sketch projection head: Linear(input_dim, input_dim) -> ReLU -> Linear(input_dim, output_dim).

    The layout mirrors the network whose weights are loaded right after this
    definition, so the parameter names have to stay ``net.0.*`` / ``net.2.*``.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        first = nn.Linear(input_dim, input_dim)
        activation = nn.ReLU()
        second = nn.Linear(input_dim, output_dim)
        self.net = nn.Sequential(first, activation, second)

    def forward(self, x):
        y = self.net(x)
        return y

# --- Load Projection Network ---
# Rebuild the head, then restore the weights saved by the training cell.
projection_net = SketchProjectionNet(CLIP_EMBEDDING_DIM, CLIP_EMBEDDING_DIM).to(device)
trained_weights = torch.load(PROJECTION_NET_SAVE_PATH, map_location=device)
projection_net.load_state_dict(trained_weights)
projection_net.eval()

# --- Build gallery embeddings ---
# Every photo of the held-out categories is encoded with frozen CLIP and
# L2-normalised; gallery_embeddings[i] belongs to gallery_image_filepaths[i].
print("Precomputing gallery embeddings for test categories...")

gallery_image_filepaths = []
gallery_embeddings = []

for category in tqdm(test_categories, desc="Encoding gallery"):
    cat_path = os.path.join(TEST_PHOTO_PATH, category)
    if not os.path.isdir(cat_path):
        continue

    # Sorted so the gallery order does not depend on the filesystem.
    image_files = sorted(glob.glob(os.path.join(cat_path, '*')))
    for img_file in image_files:
        try:
            # convert() returns a loaded copy, so the file can be closed right away.
            with Image.open(img_file) as raw_img:
                img = raw_img.convert("RGB")
            inputs = processor(images=img, return_tensors="pt").to(device)

            with torch.no_grad():
                embed = model.get_image_features(**inputs)
                embed = F.normalize(embed, p=2, dim=-1).cpu().numpy()

            gallery_embeddings.append(embed)
            gallery_image_filepaths.append(img_file)

        except Exception as e:
            # Best effort: skip unreadable files but report them.
            print(f"Error processing gallery image {img_file}: {e}")

if not gallery_embeddings:
    # np.concatenate on an empty list fails with an unhelpful message; explain the cause.
    raise RuntimeError(
        f"No gallery images could be encoded under {TEST_PHOTO_PATH} for the test categories; "
        "check the dataset path and the category split."
    )

gallery_embeddings = np.concatenate(gallery_embeddings, axis=0)

# --- Evaluation ---
# Each test sketch is projected, compared against the whole gallery and scored
# with Rank-1 accuracy and average precision (a gallery photo is relevant when
# its parent folder name equals the sketch's category).
print("\nRunning full benchmark on test split...")

correct_top1 = 0
all_aps = []
total_sketches = 0

# The category of a gallery image is the name of its parent folder. Computed once
# here instead of once per sketch.
gallery_categories = np.array([
    os.path.basename(os.path.dirname(img_path)) for img_path in gallery_image_filepaths
])

for category in tqdm(test_categories, desc="Evaluating sketches"):
    cat_path = os.path.join(TEST_SKETCH_PATH, category)
    if not os.path.isdir(cat_path):
        continue

    sketch_files = glob.glob(os.path.join(cat_path, "*.png")) + \
                   glob.glob(os.path.join(cat_path, "*.jpg")) + \
                   glob.glob(os.path.join(cat_path, "*.jpeg"))

    # Relevance vector for this category; identical for all of its sketches.
    true_labels = (gallery_categories == category).astype(int)

    for sketch_file in sketch_files:
        try:
            sketch_image = Image.open(sketch_file).convert("RGB")
            inputs = processor(images=sketch_image, return_tensors="pt").to(device)

            with torch.no_grad():
                sketch_embed = model.get_image_features(**inputs)
                projected_embed = projection_net(sketch_embed)
                projected_embed = F.normalize(projected_embed, p=2, dim=-1).cpu().numpy()

            # Cosine similarity
            sims = cosine_similarity(projected_embed, gallery_embeddings)[0]
            sorted_indices = np.argsort(sims)[::-1]

            # Top-1 prediction
            top1_category = gallery_categories[sorted_indices[0]]

            # mAP calculation
            ap = average_precision_score(true_labels, sims)

        except Exception as e:
            print(f"Error processing sketch {sketch_file}: {e}")
            continue

        # Counters are updated only after every step succeeded, so a failure
        # cannot count a hit without counting the sketch.
        if top1_category == category:
            correct_top1 += 1
        all_aps.append(ap)
        total_sketches += 1

# --- Final Metrics ---
rank1_acc = correct_top1 / total_sketches * 100 if total_sketches > 0 else 0
mean_ap = np.mean(all_aps) * 100 if all_aps else 0

print("\nBenchmark Results:")
print(f"Total Test Sketches Evaluated: {total_sketches}")
print(f"Rank-1 Accuracy: {rank1_acc:.2f}%")
print(f"Mean Average Precision (mAP): {mean_ap:.2f}%")


In [None]:
!pip install faiss-cpu

In [None]:
import os
import glob
import numpy as np
from PIL import Image
from tqdm import tqdm
from sklearn.metrics import average_precision_score
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import CLIPProcessor, CLIPModel
import faiss

# Cache file for the TU-Berlin gallery embeddings. Later in this cell the file is
# loaded whenever it exists, so delete it to force the gallery to be re-encoded.
GALLERY_EMBEDDINGS_SAVE_PATH = "/content/drive/My Drive/gallery_embeddings.npz"

# --- Device Configuration ---
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# --- Define SketchProjectionNet ---
class SketchProjectionNet(nn.Module):
    """MLP projection head (Linear -> ReLU -> Linear) for sketch embeddings.

    Args:
        input_dim: size of the incoming embedding; also the hidden width.
        output_dim: size of the projected embedding.

    The Sequential is kept under the attribute ``net`` so that the saved
    checkpoint (keys ``net.0.*`` and ``net.2.*``) loads into it.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        to_hidden = nn.Linear(input_dim, input_dim)
        to_output = nn.Linear(input_dim, output_dim)
        self.net = nn.Sequential(to_hidden, nn.ReLU(), to_output)

    def forward(self, x):
        result = self.net(x)
        return result

# --- Load Pre-trained CLIP Model (Frozen) ---
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
model = model.to(device)
model.eval()
for param in model.parameters():
    param.requires_grad = False

CLIP_EMBEDDING_DIM = 512

# --- Paths ---
TU_BERLIN_SKETCH_PATH = os.path.join("/content/data/TUBerlin", "png_ready")
TU_BERLIN_IMAGE_PATH = os.path.join("/content/data/TUBerlin", "ImageResized_ready")
# Built from the same Drive root ("My Drive") and folder that the training cell
# saves to, so the checkpoint is looked up exactly where it was written.
PROJECTION_NET_PATH = os.path.join(
    "/content/drive/My Drive",
    "ML_Embeddings_SBIR_PyTorch",
    "sketch_projection_net_category_infonce.pth",
)

# --- Load TU-Berlin categories ---
# Category names are the entries of the sketch folder; non-directories are
# skipped later by the loops that iterate over this list.
tu_berlin_categories = sorted(os.listdir(TU_BERLIN_SKETCH_PATH))

# --- Load Projection Network ---
projection_net = SketchProjectionNet(CLIP_EMBEDDING_DIM, CLIP_EMBEDDING_DIM).to(device)
projection_net.load_state_dict(torch.load(PROJECTION_NET_PATH, map_location=device))
projection_net.eval()

# --- Build or Load gallery embeddings ---
# gallery_embeddings[i] is the L2-normalised CLIP embedding of gallery_image_filepaths[i].
if os.path.exists(GALLERY_EMBEDDINGS_SAVE_PATH):
    print("Loading precomputed gallery embeddings...")
    # NOTE(review): allow_pickle=True unpickles whatever the file contains; only
    # acceptable because the cache is written by this notebook.
    saved = np.load(GALLERY_EMBEDDINGS_SAVE_PATH, allow_pickle=True)
    gallery_embeddings = saved['embeddings']
    gallery_image_filepaths = saved['filepaths'].tolist()

    # Ensure shape is (N,512)
    if len(gallery_embeddings.shape) == 1:
        gallery_embeddings = np.vstack(gallery_embeddings)

    # A cache whose two arrays disagree would silently mislabel retrieved images.
    if len(gallery_embeddings) != len(gallery_image_filepaths):
        raise ValueError(
            f"Corrupt gallery cache at {GALLERY_EMBEDDINGS_SAVE_PATH}: "
            f"{len(gallery_embeddings)} embeddings vs {len(gallery_image_filepaths)} file paths. "
            "Delete the file to rebuild it."
        )

    print("Loaded gallery embeddings with shape:", gallery_embeddings.shape)

else:
    print("Precomputing gallery embeddings for TU-Berlin...")

    gallery_image_filepaths = []
    gallery_embeddings = []

    for category in tqdm(tu_berlin_categories, desc="Encoding gallery"):
        cat_path = os.path.join(TU_BERLIN_IMAGE_PATH, category)
        if not os.path.isdir(cat_path):
            continue

        image_files = glob.glob(os.path.join(cat_path, '*'))
        for img_file in image_files:
            try:
                img = Image.open(img_file).convert("RGB")
                inputs = processor(images=img, return_tensors="pt").to(device)

                with torch.no_grad():
                    embed = model.get_image_features(**inputs)
                    embed = F.normalize(embed, p=2, dim=-1).cpu().numpy()

                gallery_embeddings.append(embed)
                gallery_image_filepaths.append(img_file)

            except Exception as e:
                # Best effort: skip unreadable files but report them.
                print(f"Error processing gallery image {img_file}: {e}")

    if not gallery_embeddings:
        # np.concatenate on an empty list fails with an unhelpful message; explain the cause.
        raise RuntimeError(
            f"No gallery images could be encoded under {TU_BERLIN_IMAGE_PATH}; "
            "check that the TU-Berlin dataset was downloaded and extracted."
        )

    gallery_embeddings = np.concatenate(gallery_embeddings, axis=0)
    print("Final gallery embeddings shape:", gallery_embeddings.shape)

    np.savez(GALLERY_EMBEDDINGS_SAVE_PATH, embeddings=gallery_embeddings, filepaths=np.array(gallery_image_filepaths))
    print(f"Saved gallery embeddings to {GALLERY_EMBEDDINGS_SAVE_PATH}")

# --- Build FAISS Index for Exact Search ---
# Flat inner-product index. Gallery vectors are L2-normalised when they are built
# in this cell, so the inner product is their cosine similarity
# (presumably also true for vectors loaded from the cache - verify if the cache is old).
index = faiss.IndexFlatIP(CLIP_EMBEDDING_DIM)
# The embeddings are cast to float32 before being added to the index.
index.add(gallery_embeddings.astype(np.float32))

# --- Evaluation with batching ---
# Every TU-Berlin sketch is projected and searched against the gallery index.
# Rank-1 accuracy is measured over ALL evaluated sketches; average precision is
# computed over the top-MAX_K retrieved images only (i.e. an mAP@MAX_K estimate).
print("\nRunning zero-shot benchmark on TU-Berlin with batching and FAISS...")

correct_top1 = 0
all_aps = []
total_sketches = 0
BATCH_SIZE = 32  # Reduced for stability
MAX_K = min(1000, len(gallery_embeddings))  # Limit k to avoid RAM explosion

# The category of a gallery image is the name of its parent folder. Computed once
# here instead of for every retrieved image of every sketch.
gallery_categories = np.array([
    os.path.basename(os.path.dirname(img_path)) for img_path in gallery_image_filepaths
])

for category in tqdm(tu_berlin_categories, desc="Evaluating sketches"):
    cat_path = os.path.join(TU_BERLIN_SKETCH_PATH, category)
    if not os.path.isdir(cat_path):
        continue

    sketch_files = glob.glob(os.path.join(cat_path, "*.png")) + \
                   glob.glob(os.path.join(cat_path, "*.jpg")) + \
                   glob.glob(os.path.join(cat_path, "*.jpeg"))

    # Process in batches
    for i in range(0, len(sketch_files), BATCH_SIZE):
        batch_files = sketch_files[i:i+BATCH_SIZE]
        images = [Image.open(f).convert("RGB") for f in batch_files]

        inputs = processor(images=images, return_tensors="pt", padding=True).to(device)

        with torch.no_grad():
            sketch_embeds = model.get_image_features(**inputs)
            projected_embeds = projection_net(sketch_embeds)
            projected_embeds = F.normalize(projected_embeds, p=2, dim=-1).cpu().numpy()

        # FAISS search with limited k
        sims, indices = index.search(projected_embeds.astype(np.float32), k=MAX_K)

        for j, sims_j in enumerate(sims):
            retrieved_categories = gallery_categories[indices[j]]

            # Count the sketch before any early exit. Previously sketches without a
            # positive in the top-k were left out of the denominator, which
            # inflated both reported metrics.
            total_sketches += 1

            if retrieved_categories[0] == category:
                correct_top1 += 1

            # mAP calculation (approximate over top-k for efficiency)
            true_labels = (retrieved_categories == category).astype(int)
            if np.sum(true_labels) == 0:
                # Nothing relevant within the top-k: the retrieval failed, so it
                # scores 0 (average_precision_score is undefined without positives).
                all_aps.append(0.0)
                continue

            ap = average_precision_score(true_labels, sims_j)
            all_aps.append(ap)

# --- Final Metrics ---
rank1_acc = correct_top1 / total_sketches * 100 if total_sketches > 0 else 0
mean_ap = np.mean(all_aps) * 100 if all_aps else 0

print("\nTU-Berlin Zero-Shot Benchmark Results (Batched + FAISS):")
print(f"Total Test Sketches Evaluated: {total_sketches}")
print(f"Rank-1 Accuracy: {rank1_acc:.2f}%")
print(f"Mean Average Precision (mAP): {mean_ap:.2f}%")
