1. Imports and Configuration Setup

In [None]:
# === IMPORTS ===
import os
import random
import numpy as np
from tqdm import tqdm
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms, datasets, models
from torch.optim import AdamW
from torch.cuda.amp import autocast

# === CONFIGURATION ===
DATA_DIR      = "your_dataset_path"
TRAIN_DIR     = os.path.join(DATA_DIR, "train")
QUERY_DIR     = os.path.join(DATA_DIR, "test", "query")
GALLERY_DIR   = os.path.join(DATA_DIR, "test", "gallery")

BATCH_SIZE    = 32
NUM_EPOCHS    = 5
LR            = 1e-4
WEIGHT_DECAY  = 5e-4
TOP_K         = 10
EMB_SIZE      = 768  # for DINOv2-base embeddings
VAL_SPLIT     = 0.2
NUM_WORKERS   = 2

# Automatically select GPU if available
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# === RANDOM SEED FOR REPRODUCIBILITY ===
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)


  from .autonotebook import tqdm as notebook_tqdm


2. Data Preparation: Transforms, Splits, Triplet Dataset

In [2]:
# === IMAGE TRANSFORMATIONS ===
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(518),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# Optional: More aggressive augmentation (disabled for now)
# train_transform = transforms.Compose([
#     transforms.RandomResizedCrop(518),
#     transforms.RandomHorizontalFlip(),
#     transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
#     transforms.RandomGrayscale(p=0.1),
#     transforms.ToTensor(),
#     transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
# ])

val_transform = transforms.Compose([
    transforms.Resize(540),
    transforms.CenterCrop(518),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# === TRAIN/VALIDATION SPLIT (STRATIFIED) ===
full_dataset = datasets.ImageFolder(TRAIN_DIR)

# Group indices by class
class_indices = [[] for _ in range(len(full_dataset.classes))]
for idx, (_, label) in enumerate(full_dataset.samples):
    class_indices[label].append(idx)

# Stratified split
train_indices, val_indices = [], []
for indices in class_indices:
    n_total = len(indices)
    n_val = int(n_total * VAL_SPLIT)
    random.shuffle(indices)
    val_indices.extend(indices[:n_val])
    train_indices.extend(indices[n_val:])

# Create subsets for DataLoader
train_subset = torch.utils.data.Subset(full_dataset, train_indices)
val_subset   = torch.utils.data.Subset(full_dataset, val_indices)

# Assign transforms to subsets
train_subset.dataset.transform = train_transform
val_subset.dataset.transform   = val_transform

# === TRIPLET DATASET DEFINITION ===
class TripletDataset(Dataset):
    def __init__(self, subset):
        self.subset = subset
        self.targets = [self.subset.dataset.samples[i][1] for i in self.subset.indices]
        
        # Build dictionary: label → list of indices
        self.label_to_indices = {}
        for idx, label in zip(self.subset.indices, self.targets):
            self.label_to_indices.setdefault(label, []).append(idx)

        self.all_indices = self.subset.indices

    def __getitem__(self, idx):
        anchor_img, anchor_label = self.subset[idx]
        
        # Select a positive sample
        pos_idx = idx
        while pos_idx == idx:
            pos_idx = random.choice(self.label_to_indices[anchor_label])
        positive_img, _ = self.subset[self.subset.indices.index(pos_idx)]
        
        # Select a negative sample
        neg_label = anchor_label
        while neg_label == anchor_label:
            neg_label = random.choice(list(self.label_to_indices.keys()))
        neg_idx = random.choice(self.label_to_indices[neg_label])
        negative_img, _ = self.subset[self.subset.indices.index(neg_idx)]

        return anchor_img, positive_img, negative_img

    def __len__(self):
        return len(self.subset)

# === DATALOADERS ===
train_loader = DataLoader(TripletDataset(train_subset), batch_size=BATCH_SIZE, shuffle=True,  num_workers=NUM_WORKERS)
val_loader   = DataLoader(TripletDataset(val_subset),   batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)


3. Model Definition: DINOv2 Encoder

In [3]:
from transformers import AutoModel, AutoImageProcessor

class DinoV2Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        # Load pretrained model and image processor
        self.model = AutoModel.from_pretrained("facebook/dinov2-base")
        
        # Freeze all parameters
        for param in self.model.parameters():
            param.requires_grad = False
        
        # Unfreeze only the last layers and final normalization layer
        for name, param in self.model.named_parameters():
            if any(x in name for x in ["encoder.layer.10", "encoder.layer.11", "layernorm"]):
                param.requires_grad = True

    def forward(self, x):
        # x is expected to be pre-normalized (handled in the transform pipeline)
        outputs = self.model(pixel_values=x)
        cls_token = outputs.last_hidden_state[:, 0, :]  # Extract [CLS] token
        return F.normalize(cls_token, dim=-1)            # Return L2-normalized embedding

# Instantiate model
model = DinoV2Encoder().to(DEVICE)


4. Loss Function, Training Loop, and Validation

In [4]:
# === LOSS FUNCTION AND OPTIMIZER ===
margin = 0.3
criterion = nn.TripletMarginLoss(margin=margin, p=2)

optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()),
                  lr=LR, weight_decay=WEIGHT_DECAY)

best_val_map10 = 0.0

# === VALIDATION METRIC: MEAN AVERAGE PRECISION @10 ===
from torch.nn.functional import cosine_similarity

def evaluate_map10_on_val(val_subset, model):
    model.eval()
    loader = DataLoader(val_subset, batch_size=32, shuffle=False)

    features = []
    labels = []

    with torch.no_grad():
        for imgs, lbls in tqdm(loader, desc="Val mAP@10 - extracting"):
            imgs = imgs.to(DEVICE)
            emb = model(imgs)
            features.append(F.normalize(emb, dim=1).cpu())
            labels.append(lbls)

    features = torch.cat(features, dim=0)  # (N, 768)
    labels = torch.cat(labels, dim=0)      # (N,)
    N = features.size(0)

    sims = cosine_similarity(features.unsqueeze(1), features.unsqueeze(0), dim=-1)  # (N, N)
    sims.masked_fill_(torch.eye(N, dtype=torch.bool), -float('inf'))

    ap_total = 0
    for i in range(N):
        target_label = labels[i]
        scores = sims[i]
        topk = scores.topk(k=10).indices
        hits = (labels[topk] == target_label).float()
        precision_at_k = hits.cumsum(0) / torch.arange(1, 11)
        ap = (precision_at_k * hits).sum() / hits.sum().clamp(min=1)
        ap_total += ap.item()

    mean_ap10 = ap_total / N
    return mean_ap10

# === TRAINING LOOP WITH AMP (AUTOMATIC MIXED PRECISION) ===
for epoch in range(NUM_EPOCHS):
    model.train()
    running_loss = 0.0
    for anchor, positive, negative in tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS} - Train"):
        anchor   = anchor.to(DEVICE)
        positive = positive.to(DEVICE)
        negative = negative.to(DEVICE)

        optimizer.zero_grad()

        with autocast():  # Enable AMP
            emb_a = model(anchor)
            emb_p = model(positive)
            emb_n = model(negative)
            loss = criterion(emb_a, emb_p, emb_n)

        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    avg_train_loss = running_loss / len(train_loader)

    # === VALIDATION PHASE ===
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for anchor, positive, negative in tqdm(val_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS} - Val"):
            anchor   = anchor.to(DEVICE)
            positive = positive.to(DEVICE)
            negative = negative.to(DEVICE)

            with autocast():
                emb_a = model(anchor)
                emb_p = model(positive)
                emb_n = model(negative)
                loss = criterion(emb_a, emb_p, emb_n)

            val_loss += loss.item()

    avg_val_loss = val_loss / len(val_loader)

    # === VALIDATION RETRIEVAL METRIC ===
    val_map10 = evaluate_map10_on_val(val_subset, model)

    print(f"[Epoch {epoch+1}] Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | Val mAP@10: {val_map10:.4f}")

    # === SAVE BEST MODEL ===
    if val_map10 > best_val_map10:
        best_val_map10 = val_map10
        torch.save(model.state_dict(), "best_model.pt")
        print(">> Best model saved (improved mAP@10)!")


Epoch 1/5 - Train: 100%|██████████| 43/43 [00:51<00:00,  1.19s/it]
Epoch 1/5 - Val: 100%|██████████| 11/11 [00:09<00:00,  1.12it/s]
Val mAP@10 - extracting: 100%|██████████| 11/11 [00:11<00:00,  1.09s/it]


[Epoch 1] Train Loss: 0.0109 | Val Loss: 0.0114 | Val mAP@10: 0.9271
>> Best model saved (improved mAP@10)!


Epoch 2/5 - Train: 100%|██████████| 43/43 [00:47<00:00,  1.11s/it]
Epoch 2/5 - Val: 100%|██████████| 11/11 [00:09<00:00,  1.13it/s]
Val mAP@10 - extracting: 100%|██████████| 11/11 [00:11<00:00,  1.08s/it]


[Epoch 2] Train Loss: 0.0081 | Val Loss: 0.0037 | Val mAP@10: 0.9456
>> Best model saved (improved mAP@10)!


Epoch 3/5 - Train: 100%|██████████| 43/43 [00:47<00:00,  1.11s/it]
Epoch 3/5 - Val: 100%|██████████| 11/11 [00:09<00:00,  1.13it/s]
Val mAP@10 - extracting: 100%|██████████| 11/11 [00:11<00:00,  1.08s/it]


[Epoch 3] Train Loss: 0.0071 | Val Loss: 0.0077 | Val mAP@10: 0.9429


Epoch 4/5 - Train: 100%|██████████| 43/43 [00:47<00:00,  1.11s/it]
Epoch 4/5 - Val: 100%|██████████| 11/11 [00:09<00:00,  1.12it/s]
Val mAP@10 - extracting: 100%|██████████| 11/11 [00:11<00:00,  1.09s/it]


[Epoch 4] Train Loss: 0.0095 | Val Loss: 0.0073 | Val mAP@10: 0.9435


Epoch 5/5 - Train: 100%|██████████| 43/43 [00:47<00:00,  1.11s/it]
Epoch 5/5 - Val: 100%|██████████| 11/11 [00:09<00:00,  1.13it/s]
Val mAP@10 - extracting: 100%|██████████| 11/11 [00:11<00:00,  1.09s/it]

[Epoch 5] Train Loss: 0.0050 | Val Loss: 0.0050 | Val mAP@10: 0.9358





5. Image Retrieval

In [5]:
# === RETRIEVAL DATASETS ===
class SimpleImageDataset(Dataset):
    def __init__(self, root, transform):
        self.root = root
        self.paths = [os.path.join(root, fname) for fname in os.listdir(root)
                      if fname.lower().endswith((".png", ".jpg", ".jpeg"))]
        self.transform = transform

    def __getitem__(self, idx):
        img = Image.open(self.paths[idx]).convert("RGB")
        return self.transform(img), self.paths[idx]

    def __len__(self):
        return len(self.paths)

# Load query and gallery datasets
query_ds   = SimpleImageDataset(QUERY_DIR, val_transform)
gallery_ds = SimpleImageDataset(GALLERY_DIR, val_transform)

query_loader   = DataLoader(query_ds, batch_size=1, shuffle=False)
gallery_loader = DataLoader(gallery_ds, batch_size=32, shuffle=False)

# Reload the best model
model = DinoV2Encoder().to(DEVICE)
model.load_state_dict(torch.load("best_model.pt"))
model.eval()

# === FEATURE EXTRACTION FUNCTIONS ===
def extract_features(dataloader):
    features = []
    paths = []
    with torch.no_grad():
        for imgs, img_paths in tqdm(dataloader, desc="Extracting features"):
            imgs = imgs.to(DEVICE)
            emb = model(imgs)
            features.append(emb.cpu())
            paths.extend(img_paths)
    return torch.cat(features, dim=0), paths

def extract_features_with_tta(dataloader):
    """Applies Test Time Augmentation (TTA) via horizontal flip"""
    features = []
    paths = []
    with torch.no_grad():
        for imgs, img_paths in tqdm(dataloader, desc="Extracting features (TTA)"):
            imgs = imgs.to(DEVICE)
            emb_normal = model(imgs)
            imgs_flipped = torch.flip(imgs, dims=[3])  # Horizontal flip
            emb_flipped = model(imgs_flipped)
            emb_avg = (emb_normal + emb_flipped) / 2
            features.append(emb_avg.cpu())
            paths.extend(img_paths)
    return torch.cat(features, dim=0), paths

# Extract features from both query and gallery sets using TTA
query_feats, query_paths     = extract_features_with_tta(query_loader)
gallery_feats, gallery_paths = extract_features_with_tta(gallery_loader)

# === SIMILARITY CALCULATION AND OUTPUT DICTIONARY ===
from torch.nn.functional import cosine_similarity

results = {}

for q_feat, q_path in zip(query_feats, query_paths):
    similarities = cosine_similarity(q_feat.unsqueeze(0), gallery_feats)
    topk_indices = torch.topk(similarities, k=TOP_K).indices
    topk_paths = [gallery_paths[i] for i in topk_indices]
    results[q_path] = topk_paths


Extracting features (TTA): 100%|██████████| 240/240 [00:19<00:00, 12.62it/s]
Extracting features (TTA): 100%|██████████| 15/15 [00:31<00:00,  2.12s/it]


6. Save Retrieval Results to JSON

In [6]:
import json
import os

# === FORMAT RESULTS FOR JSON EXPORT ===
formatted_results = []

for q_path, topk_paths in results.items():
    formatted_results.append({
        "filename": os.path.basename(q_path),
        "samples": [os.path.basename(p) for p in topk_paths]
    })

# === SAVE TO JSON FILE ===
with open("retrieval_results.json", "w") as f:
    json.dump(formatted_results, f, indent=2)

print("✅ Retrieval completed. Output saved to 'retrieval_results.json'")


✅ Retrieval completed. Output saved to 'retrieval_results.json'
