In [None]:
!pip -q install ftfy regex tqdm
!pip -q install git+https://github.com/openai/CLIP.git


In [None]:
# Download the private dataset archive (addressed by its Google Drive file ID)
# and extract it quietly into ./ENTRep_Private_Dataset_update/
!gdown 1ttmGZdAZJ-4pA9Kz5SMfvff-G_-Xn0uM
!unzip -q ./ENTRep_Private_Dataset_Update.zip -d ./ENTRep_Private_Dataset_update/

# Download the Track 2 archive (CSV and related private image data) and extract
# it into the current working directory.
# NOTE(review): presumably this archive provides the i2i.csv read by the
# retrieval cell — confirm against the archive contents.
!gdown 1d66ZMIef0HN8kTfsLzLKlgoAA5NXsI2I
!unzip ./ENTRep_Track2_Private_Data.zip

In [None]:
# Download the pretrained Vector Field checkpoint (by Google Drive file ID).
# NOTE(review): presumably this is the ./vf_model.pth loaded in the setup cell —
# confirm the downloaded file name.
!gdown 1KgzoCoaDoFsLYReWHtX-2MdvflrxSGTQ

In [None]:
# Download the trained rerank (classifier ensemble) archive and extract it into
# ./convnextbase-ensemble-metalearner, the directory RerankModel is pointed at
# in the setup cell.
!gdown 1d-JhNGHCKGEIc_9vJYJtHwUPGeShJoOC
!unzip -q Rerank_model.zip -d convnextbase-ensemble-metalearner

In [1]:
import os
import torch
import clip
from PIL import Image
import torch.nn.functional as F
import random
import numpy as np

import torch.nn as nn

from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from PIL import Image
import json
import csv
from pathlib import Path
from tqdm import tqdm
import timm
import pickle
import warnings
import torch.nn.init as init


In [2]:
# Dimensionality of the image embeddings handed to VectorField.
embed_dim = 512


class GaussianFourierProjection(nn.Module):
    """Random Fourier feature encoder for a scalar time value.

    A fixed (non-trainable) row of Gaussian frequencies ``W`` is drawn once at
    construction; ``forward`` maps each time value ``t`` to
    ``[sin(t * W), cos(t * W)]``.

    Args:
        embed_dim: Size of the produced encoding (``embed_dim // 2`` sine
            features followed by ``embed_dim // 2`` cosine features).
        scale: Standard deviation multiplier applied to the random frequencies.
    """

    def __init__(self, embed_dim, scale=10.0):
        super().__init__()
        half_dim = embed_dim // 2
        frequencies = torch.randn(1, half_dim) * scale
        # Stored as a Parameter so it is saved in the state dict, but frozen.
        self.W = nn.Parameter(frequencies, requires_grad=False)

    def forward(self, t):
        """Encode ``t`` (shape ``[B]`` or ``[B, 1]``) into ``[B, embed_dim]`` features."""
        # A flat batch of times gets a trailing axis so it broadcasts against W.
        t = t.unsqueeze(-1) if t.ndim == 1 else t
        angles = t * self.W  # [B, embed_dim // 2]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)

class VectorField(nn.Module):
    """Time-conditioned vector field over embedding vectors.

    The input embedding is layer-normalised, concatenated with a Fourier
    encoding of the time value, and fed through ``n_heads`` independent MLP
    heads whose outputs are averaged. A residual term ``res_weight * x`` (using
    the un-normalised input) is added to the result.

    Args:
        dim: Dimensionality of the input/output embeddings.
        t_dim: Size of the Fourier time encoding.
        hidden_dim: Hidden width of each MLP head.
        n_heads: Number of parallel MLP heads.
        dropout_prob: Dropout probability used inside each head.
    """

    def __init__(self, dim, t_dim=32, hidden_dim=256, n_heads=4, dropout_prob=0.1):
        super().__init__()
        self.x_norm = nn.LayerNorm(dim)  # Normalises the input embeddings
        self.time_encoder = GaussianFourierProjection(t_dim)  # Time embedding module
        # Not used in forward(); kept so the module layout stays unchanged.
        self.dropout = nn.Dropout(dropout_prob)

        # Independent MLP heads; their outputs are averaged in forward().
        self.heads = nn.ModuleList([
            nn.Sequential(
                nn.Linear(dim + t_dim, hidden_dim),  # Project [embedding, time encoding]
                nn.LayerNorm(hidden_dim),
                nn.SiLU(),
                nn.Dropout(dropout_prob),
                nn.Linear(hidden_dim, dim),          # Back to the embedding dimension
            ) for _ in range(n_heads)
        ])

        self.res_weight = nn.Parameter(torch.tensor(1.0))  # Learnable residual scaling
        # Not applied in forward(). NOTE(review): presumably saved checkpoints
        # contain its parameters, so removing it could break strict
        # load_state_dict — confirm before deleting.
        self.out_norm = nn.LayerNorm(dim)
        self.initialize_weights()

    def initialize_weights(self):
        """Kaiming-uniform initialise every Linear weight and zero its bias."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                init.kaiming_uniform_(m.weight, nonlinearity='relu')
                if m.bias is not None:
                    init.zeros_(m.bias)

    def forward(self, x, t):
        """Evaluate the field at embeddings ``x`` and time ``t``.

        Args:
            x: Embeddings of shape ``[B, dim]``.
            t: Time value. Accepts a Python scalar, a 0-dim tensor, a ``[B]``
                tensor, a ``[B, 1]`` tensor, or a single time of shape ``[1]`` /
                ``[1, 1]`` that is shared by the whole batch.

        Returns:
            Tensor of shape ``[B, dim]``: mean head output plus ``res_weight * x``.
        """
        # Bring every supported form of t to a 2-D [*, 1] tensor.
        if not isinstance(t, torch.Tensor):
            t = torch.full((x.shape[0], 1), t, device=x.device)
        elif t.ndim == 0:
            t = t.expand(x.shape[0], 1)
        elif t.ndim == 1:
            t = t.unsqueeze(-1)

        x_normed = self.x_norm(x)
        t_encoded = self.time_encoder(t.to(x.device))

        # A single time value given as a [1, 1] tensor previously only worked
        # for batch size 1 (the concat below failed otherwise); share it across
        # the batch, matching how scalar and 0-dim times are already treated.
        if (
            t_encoded.ndim == 2
            and x_normed.ndim == 2
            and t_encoded.shape[0] == 1
            and x_normed.shape[0] != 1
        ):
            t_encoded = t_encoded.expand(x_normed.shape[0], -1)

        inp = torch.cat([x_normed, t_encoded], dim=-1)  # [B, dim + t_dim]

        # Average the outputs of all heads.
        head_outs = [head(inp) for head in self.heads]
        out = torch.mean(torch.stack(head_outs), dim=0)

        # Residual connection on the raw (un-normalised) input.
        return out + self.res_weight * x




In [3]:
# Silence every Python warning for the rest of the session (notebook-wide effect).
warnings.filterwarnings('ignore')


class RerankModel:
    """Weighted ensemble of image classifiers used to label retrieval candidates.

    The ensemble definition (``ensemble_info.pkl``) and the per-model
    checkpoints (``ensemble_model_<i>.pt``) are read from ``model_dir``.

    Args:
        model_dir: Directory containing the ensemble files.
        device: Requested torch device; CPU is used whenever CUDA is unavailable.
    """

    def __init__(self, model_dir, device='cuda'):
        self.model_dir = Path(model_dir)
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')

        # Index in this list == class id produced by the classifiers.
        self.class_names = [
            "nose-right", "nose-left", "ear-right",
            "ear-left", "vc-open", "vc-closed", "throat"
        ]
        self.num_classes = len(self.class_names)

        print(f"Using device: {self.device}")

    def load_ensemble_models(self):
        """Load all models from the ensemble directory.

        Returns:
            List of dicts with keys ``'model'`` (timm model in eval mode on
            ``self.device``), ``'weight'`` (ensemble weight) and ``'name'``.
        """
        print("Loading ensemble models...")

        # SECURITY: pickle.load can execute arbitrary code embedded in the file.
        # Only use an ensemble_info.pkl that comes from a trusted source.
        with open(self.model_dir / 'ensemble_info.pkl', 'rb') as f:
            ensemble_info = pickle.load(f)

        model_names = ensemble_info['models']
        weights = ensemble_info['weights']

        models = []
        for i, model_name in enumerate(model_names):
            print(f"Loading model {i+1}/{len(model_names)}: {model_name}")

            # Only names mentioning efficientnet (and not convnext) use the
            # EfficientNet backbone; everything else, including unrecognised
            # names, falls back to ConvNeXt-Base.
            lowered = model_name.lower()
            if 'efficientnet' in lowered and 'convnext' not in lowered:
                base_name = "efficientnet_b4"
            else:
                base_name = "convnext_base.fb_in22k_ft_in1k"

            model = timm.create_model(base_name, pretrained=False, num_classes=self.num_classes)
            state_dict = torch.load(self.model_dir / f"ensemble_model_{i}.pt", map_location=self.device)
            model.load_state_dict(state_dict)
            model.to(self.device)
            model.eval()

            models.append({'model': model, 'weight': weights[i], 'name': model_name})

        print(f"Loaded {len(models)} models successfully.")
        return models

    def load_test_data(self, csv_path, img_dir):
        """Load test image paths from a CSV file.

        The first column of every non-empty row is taken as an image file name
        relative to ``img_dir``. Missing files are reported and skipped.

        Returns:
            List of existing image paths (as strings), in CSV order.
        """
        test_files = []
        # newline='' is the mode the csv module expects for its input files.
        with open(csv_path, 'r', newline='') as f:
            for row in csv.reader(f):
                if not row:
                    continue
                img_path = Path(img_dir) / row[0].strip()
                if img_path.exists():
                    test_files.append(str(img_path))
                else:
                    print(f"Warning: Image not found: {img_path}")

        print(f"Loaded {len(test_files)} test images.")
        return test_files

    def get_tta_transforms(self, img_size=224, n_aug=5):
        """Generate a list of transforms for Test Time Augmentation (TTA).

        The first transform is a plain resize + normalise. The remaining
        ``n_aug - 1`` transforms are randomised (their parameters are drawn from
        the global ``random`` module on every call, and several of them also
        sample at application time), so results are not reproducible unless the
        RNGs are seeded by the caller.

        Returns:
            List of ``max(1, n_aug)`` torchvision transform pipelines.
        """
        class ResizeOrPad:
            """Upscale images whose shorter side is below ``min_size`` (no padding is done)."""

            def __init__(self, min_size):
                self.min_size = min_size

            def __call__(self, img):
                w, h = img.size
                if w < self.min_size or h < self.min_size:
                    scale = self.min_size / min(w, h)
                    new_w = int(w * scale)
                    new_h = int(h * scale)
                    return T.functional.resize(img, (new_h, new_w))
                return img

        transforms = [
            T.Compose([
                ResizeOrPad(img_size),
                T.Resize((img_size, img_size)),
                T.ToTensor(),
                T.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225])
            ])
        ]

        for i in range(n_aug - 1):
            resize_delta = random.choice([-20, -10, 0, 10, 20])
            # Negative deltas are clamped, so the target never drops below img_size.
            target_size = max(img_size, img_size + resize_delta)

            transforms.append(
                T.Compose([
                    ResizeOrPad(target_size + 20),
                    T.Resize((target_size + 10, target_size + 10)),
                    # Alternate deterministic and random crops.
                    T.CenterCrop(img_size) if i % 2 == 0 else T.RandomCrop(img_size),
                    T.RandomApply([
                        T.ColorJitter(
                            brightness=random.uniform(0.1, 0.2),
                            contrast=random.uniform(0.1, 0.2),
                            saturation=random.uniform(0.05, 0.15),
                            hue=random.uniform(0.02, 0.05)
                        )
                    ], p=0.8),
                    T.RandomChoice([
                        T.GaussianBlur(3, sigma=(0.1, 0.5)),
                        nn.Identity()
                    ]),
                    T.ToTensor(),
                    T.Normalize(mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225])
                ])
            )
        return transforms

    def predict_single_image_tta(self, model, img_path, n_aug=5):
        """Predict a single image using Test Time Augmentation (TTA).

        Returns:
            1-D numpy array of class probabilities: the weighted mean over all
            augmented views, with the un-augmented view weighted 1.5x.
        """
        # Decode inside a context manager so the file handle is always released.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert('RGB')
        tta_transforms = self.get_tta_transforms(n_aug=n_aug)

        aug_probs = []
        with torch.no_grad():
            for transform in tta_transforms:
                img_tensor = transform(image).unsqueeze(0).to(self.device)
                aug_probs.append(F.softmax(model(img_tensor), dim=1))

        # One weight per produced view; the first (plain) view counts 1.5x.
        weights = torch.tensor([1.5] + [1.0] * (len(aug_probs) - 1)).to(self.device)
        weights = weights / weights.sum()

        final_probs = torch.zeros_like(aug_probs[0])
        for view_weight, probs in zip(weights, aug_probs):
            final_probs += probs * view_weight

        return final_probs.cpu().numpy().squeeze()

    def ensemble_predict_batch(self, models, test_files, use_tta=True, batch_size=32):
        """Predict a class id for each image by ensembling multiple models.

        Args:
            models: List of dicts as returned by ``load_ensemble_models``.
            test_files: Iterable of image paths.
            use_tta: When True each model uses 3-view TTA; otherwise a single
                224x224 resized view is classified.
            batch_size: Unused; images are processed one at a time. Kept for
                interface compatibility.

        Returns:
            Dict mapping image file name (basename of the path) to class id.
        """
        # The blend weights (squared, then normalised) depend only on the
        # ensemble, so compute them once. Forcing a float dtype keeps the
        # in-place division valid even when the stored weights are integers.
        model_weights = np.array([info['weight'] ** 2 for info in models], dtype=np.float64)
        model_weights /= model_weights.sum()

        # The plain (non-TTA) preprocessing is identical for every image/model.
        plain_transform = None
        if not use_tta:
            plain_transform = T.Compose([
                T.Resize((224, 224)),
                T.ToTensor(),
                T.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225])
            ])

        predictions = {}
        for img_path in test_files:
            img_name = Path(img_path).name

            if use_tta:
                all_model_probs = [
                    self.predict_single_image_tta(info['model'], img_path, n_aug=3)
                    for info in models
                ]
            else:
                # Decode and preprocess once, then reuse the tensor for all models.
                with Image.open(img_path) as raw_image:
                    image = raw_image.convert('RGB')
                img_tensor = plain_transform(image).unsqueeze(0).to(self.device)
                with torch.no_grad():
                    all_model_probs = [
                        F.softmax(info['model'](img_tensor), dim=1).cpu().numpy().squeeze()
                        for info in models
                    ]

            final_probs = np.zeros_like(all_model_probs[0])
            for blend_weight, probs in zip(model_weights, all_model_probs):
                final_probs += probs * blend_weight

            predictions[img_name] = int(np.argmax(final_probs))

        return predictions

    def predict_and_save(self, csv_path, img_dir, output_path, use_tta=True):
        """Run predictions for the images listed in a CSV and save them as JSON.

        Returns:
            The ``{image name: class id}`` dict that was written to ``output_path``.
        """
        models = self.load_ensemble_models()
        test_files = self.load_test_data(csv_path, img_dir)

        print("\nStarting prediction...")
        predictions = self.ensemble_predict_batch(models, test_files, use_tta=use_tta)

        print(f"\nSaving results to {output_path}")
        with open(output_path, 'w') as f:
            json.dump(predictions, f, indent=2)

        print(f"Saved {len(predictions)} predictions")

        print("\nSample predictions:")
        for img_name, pred in list(predictions.items())[:5]:
            print(f"  {img_name}: {pred} ({self.class_names[pred]})")

        return predictions


In [5]:

def preprocess_and_embed(image_path):
    """Embed one image with CLIP and pass the embedding through the vector field.

    Relies on the notebook-level globals ``preprocess``, ``model`` (CLIP),
    ``vf`` (vector field) and ``device``.

    Args:
        image_path: Path of the image file to embed.

    Returns:
        1-D CPU tensor: the vector-field output for the L2-normalised CLIP embedding.

    Raises:
        RuntimeError: If the image cannot be opened, or if an intermediate
            tensor does not have the expected rank.
    """
    try:
        rgb_image = Image.open(image_path).convert("RGB")
    except Exception as e:
        raise RuntimeError(f"Image open failed: {e}")

    # CLIP preprocessing, plus a leading batch axis, on the target device.
    clip_input = preprocess(rgb_image).unsqueeze(0).to(device)

    with torch.no_grad():
        clip_emb = model.encode_image(clip_input)

        # The encoder is expected to return a [batch, features] matrix.
        if clip_emb.ndim != 2:
            raise RuntimeError(f"Unexpected embedding shape: {clip_emb.shape}")

        # Unit-normalise, then match the dtype of the vector field's parameters.
        unit_emb = clip_emb / clip_emb.norm(dim=-1, keepdim=True)
        vf_dtype = next(vf.parameters()).dtype
        unit_emb = unit_emb.to(vf_dtype)

        # Evaluate the vector field at time t = 0.0 and drop the batch axis.
        t_zero = torch.tensor([[0.0]], device=unit_emb.device)
        emb_vf = vf(unit_emb, t=t_zero).squeeze(0).cpu()

        if emb_vf.ndim != 1:
            raise RuntimeError(f"VectorField output is not 1D: {emb_vf.shape}")

    return emb_vf
    
def rerank_topk_by_class(candidate_names, candidate_scores, predictor, models, img_dir,
                         bonus=0.01, class_cache=None):
    """
    Re-rank a list of top-k image candidates by promoting those that share the dominant class.

    Every candidate is classified (without TTA); the most frequent class among
    the candidates is the dominant class, and candidates of that class get
    ``bonus`` added to their score before sorting. Ties in vote count go to the
    class that was seen first; ties in score keep the input order.

    Args:
        candidate_names (List[str]): Top-k image file names to be re-ranked.
        candidate_scores (List[float]): Corresponding similarity scores.
        predictor (RerankModel): Object exposing ``ensemble_predict_batch``.
        models (List): Loaded classification models, forwarded to the predictor.
        img_dir (str): Path to image directory.
        bonus (float): Bonus score to add for images in the dominant class.
        class_cache (dict | None): Optional ``{image name: class id}`` mapping.
            Names found in it are not classified again, and new predictions are
            stored in it, so passing the same dict across calls avoids
            re-classifying images that appear as candidates for many queries.

    Returns:
        reranked_names (List[str]): Candidate names sorted by adjusted score
        (descending). Empty when no candidates are given.
    """
    candidate_names = list(candidate_names)
    candidate_scores = list(candidate_scores)

    # Nothing to vote on: previously this crashed in max() on an empty dict.
    if not candidate_names:
        return []

    if class_cache is None:
        class_cache = {}

    # Classify each distinct candidate once and tally the class votes.
    class_votes = {}
    for img_name in candidate_names:
        if img_name not in class_cache:
            img_path = os.path.join(img_dir, img_name)
            class_cache[img_name] = predictor.ensemble_predict_batch(
                models, [img_path], use_tta=False
            )[img_name]
        pred_class = class_cache[img_name]
        class_votes[pred_class] = class_votes.get(pred_class, 0) + 1

    # Majority class among the candidates.
    main_class = max(class_votes, key=class_votes.get)

    # Add the bonus to candidates in the majority class.
    adjusted = [
        (name, score + (bonus if class_cache[name] == main_class else 0.0))
        for name, score in zip(candidate_names, candidate_scores)
    ]

    # Stable descending sort by adjusted score.
    adjusted.sort(key=lambda pair: pair[1], reverse=True)
    return [name for name, _ in adjusted]


## **Retrieval Pipeline Description**

Image-to-image retrieval is performed in 4 main steps:

1. **Image Loading**: Read all image file names from a CSV file.

2. **Embedding Extraction**: Use CLIP + VectorField to compute embeddings for all images.

3. **Similarity Search**: For each query image, retrieve top-5 similar images using cosine similarity.

4. **Class-based Reranking**: Use an ensemble classifier to predict classes of the top-5 candidates. Images sharing the dominant class receive a score bonus (+0.01) before reranking.

The top-1 result after reranking is stored for each query image.


In [6]:
# Global setup: the names defined here (device, model, preprocess, vf) are read
# as globals by preprocess_and_embed; predictor and models are passed to the
# reranking step.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Load the CLIP model (ViT-B/32) and its matching image preprocessing pipeline.
# Note: `model` is the CLIP encoder, whereas `models` (below) is the list of
# classifier ensemble members.
model, preprocess = clip.load("ViT-B/32", device=device)

# Build the time-dependent vector field with float32 parameters on the chosen device
vf = VectorField(embed_dim).to(device).float()

# Load the trained vector field weights from ./vf_model.pth
vf.load_state_dict(torch.load("./vf_model.pth", map_location=device))
vf.eval()  # Evaluation mode (disables the dropout layers inside the heads)

# Initialize the reranking model used for predicting anatomical regions
predictor = RerankModel(model_dir="./convnextbase-ensemble-metalearner")

# Load the classifier ensemble once; the returned list is reused for every query
models = predictor.load_ensemble_models()

Using device: cuda
Loading ensemble models...
Loading model 1/6: convnext_base.fb_in22k_ft_in1k_full
Loading model 2/6: convnext_base.fb_in22k_ft_in1k_fold3
Loading model 3/6: convnext_base.fb_in22k_ft_in1k_fold1
Loading model 4/6: convnext_base.fb_in22k_ft_in1k_fold4
Loading model 5/6: convnext_base.fb_in22k_ft_in1k_fold5
Loading model 6/6: convnext_base.fb_in22k_ft_in1k_fold2
Loaded 6 models successfully.


In [7]:
# Step 1: Read image names from CSV
image_dir = "./ENTRep_Private_Dataset_update/imgs"  # Path to directory containing test images
csv_path = "i2i.csv"  # CSV file with image names (1 name per line)

with open(csv_path, "r", newline="") as f:
    reader = csv.reader(f)
    image_list = [row[0].strip() for row in reader if row]  # List of image file names


# Step 2: Compute embeddings for all images
# Output: all_embeddings[img_name] = torch.Tensor of shape [D]

all_embeddings = {}
failed_images = []  # Names whose embedding raised; reported below

for img_name in image_list:
    img_path = os.path.join(image_dir, img_name)  # Full path to the image
    try:
        all_embeddings[img_name] = preprocess_and_embed(img_path)  # CLIP + VectorField embedding
    except Exception as e:
        # Best effort: report the failure and continue with the remaining images.
        failed_images.append(img_name)
        print(f"❌ Error with {img_name}: {e}")

# Only claim success when every image was actually embedded.
if failed_images:
    print(f"⚠️ Embedded {len(all_embeddings)}/{len(image_list)} images ({len(failed_images)} failed)")
else:
    print(" Embedded all images")

# Step 3: Filter valid embeddings and stack into a tensor
# Output: embeddings: torch.Tensor [N, D], img_names: List[str]
valid_img_names = []
valid_embeddings = []

for name, emb in all_embeddings.items():
    # Keep only 1D torch.Tensor embeddings (e.g., shape = [512])
    if isinstance(emb, torch.Tensor) and emb.ndim == 1:
        valid_img_names.append(name)
        valid_embeddings.append(emb)
    else:
        print(f"⚠️ Invalid embedding: {name} → {type(emb)}, shape = {getattr(emb, 'shape', None)}")

if not valid_embeddings:
    raise ValueError("❌ No valid embeddings found!")

# Stack embeddings into a tensor and normalize
img_names = valid_img_names
embeddings = torch.stack(valid_embeddings)  # Shape: [N, D]
embeddings = F.normalize(embeddings, dim=-1)  # Unit vectors: dot product == cosine similarity


# Step 4: Image Retrieval with Class-based Reranking
# For each image, find the top-k most similar other images, then re-rank by majority class
TOP_K = 5
num_images = embeddings.shape[0]
# Never ask for more neighbours than there are other images.
top_k = min(TOP_K, num_images - 1)
if top_k < 1:
    raise ValueError("❌ Need at least two valid embeddings to run retrieval!")

retrieval_results = {}

for i, query_name in enumerate(img_names):
    # Cosine similarity of the query against every image: Shape [N]
    sims = embeddings @ embeddings[i]
    # Mask the query itself so it can never be retrieved; indices stay aligned
    # with img_names, so no index re-adjustment is needed.
    sims[i] = float("-inf")

    # Retrieve the top-k most similar images
    topk = torch.topk(sims, k=top_k)
    candidate_names = [img_names[idx] for idx in topk.indices.tolist()]
    candidate_scores = topk.values.tolist()

    # Re-rank candidates using class prediction
    reranked = rerank_topk_by_class(
        candidate_names, candidate_scores,
        predictor, models, image_dir, bonus=0.01
    )

    # Save only the top-1 match
    retrieval_results[query_name] = reranked[0]
print("Retrieval completed for all images")


 Embedded all images
Retrieval completed for all images


In [9]:
# Destination file for the retrieval output
output_json = "rerank003.json"
# Write retrieval_results (query image name -> top-1 reranked match) as
# indented JSON; an existing file with the same name is overwritten.
with open(output_json, "w") as f:
    json.dump(retrieval_results, f, indent=2)
print(f"Saved top-1 retrieval results with rerank to: {output_json}")

Saved top-1 retrieval results with rerank to: rerank003.json
