In [None]:
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import DataLoader, TensorDataset, Dataset
from pathlib import Path
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

from src.common import TRAIN_DATA, TEST_DATA, DEVICE
from src.common import load_data, prepare_train_data, generate_submission
from CrossFlow.diffusion.flow_matching import ClipLoss, SigLipLoss

In [2]:
def compute_procrustes_with_padding(X, Y, allow_reflection=False):
    d1, d2 = X.shape[1], Y.shape[1]
    d_max = max(d1, d2)

    if d1 < d_max:
        X = np.pad(X, ((0, 0), (0, d_max - d1)), mode='constant')
    if d2 < d_max:
        Y = np.pad(Y, ((0, 0), (0, d_max - d2)), mode='constant')
    
    H = X.T @ Y
    U, S, Vt = np.linalg.svd(H, full_matrices=False)
    R = U @ Vt
    
    if not allow_reflection and np.linalg.det(R) < 0:
        U[:, -1] *= -1
        R = U @ Vt
    
    return R

In [3]:
def normalize_embeddings(X, method='standard', stats=None):
    if method == 'none':
        return X, {'method': 'none', 'dim': X.shape[1]}
    
    elif method == 'l2':

        norms = np.linalg.norm(X, axis=1, keepdims=True)
        norms = np.where(norms == 0, 1, norms)
        X_norm = X / norms
        
        if stats is None:
            stats = {
                'method': 'l2',
                'dim': X.shape[1]
            }
        
        return X_norm, stats
    
    elif method == 'standard':
        if stats is None:
            mean = X.mean(axis=0)
            std = X.std(axis=0)
            std = np.where(std == 0, 1, std)
            
            stats = {
                'method': 'standard',
                'mean': mean,
                'std': std,
                'dim': X.shape[1]
            }
        else:
            mean = stats['mean']
            std = stats['std']
        
        X_norm = (X - mean) / std
        
        return X_norm, stats
    
def denormalize_embeddings(X_norm, method='standard', stats=None):
    if method == 'none':
        return X_norm
    
    elif method == 'l2':
        return X_norm
    
    elif method == 'standard':
        if stats is None:
            raise ValueError("Need stats for denormalization")
        
        mean = stats['mean']
        std = stats['std']
        
        if X_norm.shape[1] > len(mean):
            X_norm = X_norm[:, :len(mean)]
        
        X = X_norm * std + mean
        
        return X

In [4]:
def select_anchors_diverse(train_data, n_anchors, method='uniform'):
    caption_embeddings = train_data['captions/embeddings']
    label_matrix = train_data['captions/label']
    
    n_captions = len(caption_embeddings)
    
    gt_indices = np.argmax(label_matrix, axis=1)
    
    if method == 'uniform':
        caption_indices = np.linspace(0, n_captions - 1, n_anchors, dtype=int)
        
    elif method == 'random':
        caption_indices = np.random.choice(n_captions, n_anchors, replace=False)
        
    else:
        raise ValueError(f"Unknown method: {method}")
    
    image_indices = gt_indices[caption_indices]
    
    print(f"   Selected {len(caption_indices)} anchor pairs")
    print(f"   Caption indices range: {caption_indices.min()} - {caption_indices.max()}")
    print(f"   Image indices range: {image_indices.min()} - {image_indices.max()}")
    
    return caption_indices, image_indices

In [13]:
class ProcrustesTranslator:
    """
    Zero-shot translator using Procrustes analysis.
    """
    def __init__(self, normalization='standard', method='ortho'):
        self.normalization = normalization
        self.method = method
        self.R = None
        self.scale = None
        self.bias = None
        self.source_stats = None
        self.target_stats = None
        self.orthogonality_error = None
    
    def fit(self, X_source, Y_target):
        if torch.is_tensor(X_source):
            X_source = X_source.cpu().numpy()
        if torch.is_tensor(Y_target):
            Y_target = Y_target.cpu().numpy()
        
        print(f"  Input dimensions: {X_source.shape[1]} -> {Y_target.shape[1]}")
        print(f"  Method: {self.method}")

        X_norm, self.source_stats = normalize_embeddings(
            X_source, self.normalization
        )
        Y_norm, self.target_stats = normalize_embeddings(
            Y_target, self.normalization
        )

        if self.method == 'ortho':
            self.R, self.orthogonality_error = compute_procrustes_with_padding(
                X_norm, Y_norm, allow_reflection=False
            )
            self.scale = 1.0
            self.bias = None
            
        elif self.method == 'lortho':
            self._fit_linear(X_norm, Y_norm)
        
        else:
            raise ValueError(f"Unknown method: {self.method}")
        
        print(f"  Transformation matrix shape: {self.R.shape}")
        if self.orthogonality_error is not None:
            print(f"  Orthogonality check: {self.orthogonality_error:.6e}")
    
    def _fit_linear(self, X_norm, Y_norm):
        """Fit linear transformation (least squares)."""
        d_max = max(X_norm.shape[1], Y_norm.shape[1])

        X_padded = np.pad(X_norm, ((0, 0), (0, d_max - X_norm.shape[1])))
        Y_padded = np.pad(Y_norm, ((0, 0), (0, d_max - Y_norm.shape[1])))

        T = np.linalg.lstsq(X_padded, Y_padded, rcond=None)[0]

        U, S, Vt = np.linalg.svd(T, full_matrices=False)
        self.R = U @ Vt
        self.scale = np.mean(S)
        
        self.orthogonality_error = np.linalg.norm(self.R.T @ self.R - np.eye(self.R.shape[0]))
    
    def transform(self, X_source):

        was_tensor = torch.is_tensor(X_source)
        if was_tensor:
            device = X_source.device
            X_source = X_source.cpu().numpy()

        X_norm, _ = normalize_embeddings(
            X_source, self.normalization, self.source_stats
        )

        d_in = X_norm.shape[1]
        d_R = self.R.shape[0]
        
        if d_in < d_R:
            X_norm = np.pad(X_norm, ((0, 0), (0, d_R - d_in)), mode='constant')

        if self.method == 'ortho':
            Y_norm = X_norm @ self.R
        elif self.method == 'lortho':
            Y_norm = (X_norm @ self.R) * self.scale
        elif self.method == 'affine':
            Y_norm = (X_norm @ self.R) * self.scale + self.bias
        
        if self.normalization == 'standard':
            d_target = len(self.target_stats['mean'])
        elif self.normalization == 'l2':
            d_target = self.target_stats['dim']
        else:
            d_target = d_R
        
        if Y_norm.shape[1] > d_target:
            Y_norm = Y_norm[:, :d_target]
        
        Y = denormalize_embeddings(
            Y_norm, self.normalization, self.target_stats
        )

        if was_tensor:
            Y = torch.from_numpy(Y).float().to(device)
        
        return Y

In [6]:
import torch.nn.functional as F

class RefinedProcrustesTranslator(nn.Module):
    def __init__(self, d_in=1024, d_out=1536, procrustes_R=None, 
                 hidden_dim=None, refinement_type='residual'):
        super().__init__()
        
        if hidden_dim is None:
            hidden_dim = d_out
        
        self.d_in = d_in
        self.d_out = d_out
        self.refinement_type = refinement_type
        
        if procrustes_R is not None:
            R_sub = procrustes_R[:d_out, :d_in]
            
            self.linear = nn.Linear(d_in, d_out, bias=False)
            with torch.no_grad():
                self.linear.weight.data = torch.from_numpy(R_sub).float()
        
        # Refinement architecture
        if refinement_type == 'residual':
            self.refinement = nn.Sequential(
                nn.Linear(d_out, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.GELU(),
                nn.Dropout(0.1),
                nn.Linear(hidden_dim, hidden_dim // 2),
                nn.LayerNorm(hidden_dim // 2),
                nn.GELU(),
                nn.Dropout(0.1),
                nn.Linear(hidden_dim // 2, d_out)
            )
            nn.init.zeros_(self.refinement[-1].weight)
            nn.init.zeros_(self.refinement[-1].bias)
            
            self.residual_weight = nn.Parameter(torch.tensor(0.1))
        
        else:
            raise ValueError(f"Unknown refinement_type: {refinement_type}")
    
    def forward(self, x):

        x_proj = self.linear(x)

        if self.refinement_type == 'residual':
            residual = self.refinement(x_proj)
            alpha = torch.sigmoid(self.residual_weight) * 0.3
            output = x_proj + alpha * residual
            
        return output

In [7]:
def train_refined_model_improved(
    procrustes_translator,
    train_data,
    val_data=None,
    epochs=20,
    lr=1e-4,
    batch_size=256,
    hidden_dim=None,
    patience=10,
    device='cuda'
):
    text_embeddings = torch.from_numpy(train_data['captions/embeddings']).float()
    image_embeddings = torch.from_numpy(train_data['images/embeddings']).float().to(device)
    label_matrix = torch.from_numpy(train_data['captions/label']).bool()
    gt_indices = torch.argmax(label_matrix.long(), dim=1)
    
    print(f"Training data: {len(text_embeddings)} captions, {len(image_embeddings)} images")
    
    if val_data is not None:
        val_text = torch.from_numpy(val_data['captions/embeddings']).float()
        val_image = torch.from_numpy(val_data['images/embeddings']).float().to(device)
        val_label_matrix = torch.from_numpy(val_data['captions/label']).bool()
        val_gt_indices = torch.argmax(val_label_matrix.long(), dim=1)
        print(f"Validation data: {len(val_text)} captions")
    
    model = RefinedProcrustesTranslator(
        d_in=text_embeddings.shape[1],
        d_out=image_embeddings.shape[1],
        procrustes_R=procrustes_translator.R,
        hidden_dim=hidden_dim
    ).to(device)
    
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
    
    dataset = TensorDataset(text_embeddings, gt_indices)
    loader = DataLoader(
        dataset, 
        batch_size=batch_size, 
        shuffle=True,
        num_workers=2, 
        pin_memory=True
    )
    
    siglip_loss_fn = SigLipLoss().to(device)
    
    logit_scale = torch.tensor(np.exp(np.log(1.0 / 0.07)), device=device, dtype=torch.float32)
    logit_bias = torch.tensor(0.0, device=device, dtype=torch.float32)
    
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)

    warmup_epochs = max(1, epochs // 10)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=epochs - warmup_epochs, eta_min=lr/10
    )
    
    best_loss = float('inf')
    best_model_state = None
    patience_counter = 0
    
    history = {
        'train_loss': [],
        'val_loss': [] if val_data else None,
        'lr': []
    }
    
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        num_batches = 0
        
        pbar = tqdm(loader, desc=f"Epoch {epoch+1}/{epochs}")
        
        for batch_text, batch_gt_idx in pbar:
            batch_text = batch_text.to(device)
            batch_gt_idx = batch_gt_idx.to(device)

            pred_embeddings = model(batch_text)
            target_embeddings = image_embeddings[batch_gt_idx]

            loss = siglip_loss_fn(target_embeddings, pred_embeddings, logit_scale, logit_bias)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

            pbar.set_postfix({'loss': f'{loss.item():.4f}'})

        if epoch >= warmup_epochs:
            scheduler.step()

        avg_loss = total_loss / num_batches
        history['train_loss'].append(avg_loss)
        history['lr'].append(optimizer.param_groups[0]['lr'])
        
        print(f"\nEpoch {epoch+1} Summary:")
        print(f"  Train Loss: {avg_loss:.4f}")

        if val_data is not None:
            model.eval()
            with torch.no_grad():
                val_pred = model(val_text.to(device))
                val_target = val_image[val_gt_indices]
                val_loss = siglip_loss_fn(val_target, val_pred, logit_scale, logit_bias).item()
                
                history['val_loss'].append(val_loss)
                print(f"  Val Loss: {val_loss:.4f}")
                current_loss = val_loss
        else:
            current_loss = avg_loss
        
        print(f"  LR: {optimizer.param_groups[0]['lr']:.6f}")
        
        # Save best model
        if current_loss < best_loss:
            best_loss = current_loss
            best_model_state = model.state_dict().copy()
            print(f"  New best model! (Loss: {best_loss:.4f})")
            patience_counter = 0
        else:
            patience_counter += 1
            print(f"  No improvement ({patience_counter}/{patience})")
            
            if patience_counter >= patience:
                print(f"\n  Early stopping at epoch {epoch+1}")
                break
    
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    
    print(f"Training Complete! Best Loss: {best_loss:.4f}")
    
    return model, history

In [21]:
def evaluate_retrieval(model, query_embeddings, gallery_embeddings, gt_indices, device, k=10):
    # Apply model if provided
    if model is not None:
        model.eval()
        with torch.no_grad():
            if not torch.is_tensor(query_embeddings):
                query_embeddings = torch.from_numpy(query_embeddings).float()
            query_embeddings = model(query_embeddings.to(device))
    
    query_embeddings = query_embeddings.to(device)
    gallery_embeddings = gallery_embeddings.to(device)
    gt_indices = gt_indices.to(device)
    
    n_queries = query_embeddings.shape[0]
    
    # Normalize for cosine similarity
    query_norm = torch.nn.functional.normalize(query_embeddings, p=2, dim=1)
    gallery_norm = torch.nn.functional.normalize(gallery_embeddings, p=2, dim=1)
    
    # Compute similarities in batches
    batch_size = 1000
    all_top_k_indices = []
    all_distances = []
    
    for i in range(0, n_queries, batch_size):
        batch_queries = query_norm[i:i+batch_size]
        
        # Cosine similarity
        similarities = batch_queries @ gallery_norm.T  # [batch_size, N_gallery]
        
        # Get top-k predictions
        top_k_indices = torch.topk(similarities, k=min(k, similarities.shape[1]), dim=1).indices
        all_top_k_indices.append(top_k_indices)
        
        # Compute L2 distance to ground truth
        batch_gt = gt_indices[i:i+batch_size]
        gt_embeddings = gallery_embeddings[batch_gt]
        distances = torch.norm(query_embeddings[i:i+batch_size] - gt_embeddings, dim=1)
        all_distances.append(distances)
    
    all_top_k_indices = torch.cat(all_top_k_indices, dim=0)
    all_distances = torch.cat(all_distances, dim=0)
    
    # Calculate Recall@k
    gt_indices_expanded = gt_indices.unsqueeze(1).expand(-1, all_top_k_indices.shape[1])
    correct = (all_top_k_indices == gt_indices_expanded).any(dim=1)
    recall_at_k = correct.float().mean().item()
    
    # Calculate MRR (Mean Reciprocal Rank)
    reciprocal_ranks = []
    for i in range(n_queries):
        pred_indices = all_top_k_indices[i]
        gt_idx = gt_indices[i]
        
        if gt_idx in pred_indices:
            rank = (pred_indices == gt_idx).nonzero(as_tuple=True)[0].item() + 1
            reciprocal_ranks.append(1.0 / rank)
        else:
            reciprocal_ranks.append(0.0)
    
    mrr = np.mean(reciprocal_ranks)
    mean_l2_dist = all_distances.mean().item()
    
    return recall_at_k, mrr, mean_l2_dist

In [26]:
N_ANCHORS = 120000
ANCHOR_SELECTION = 'uniform'

NORMALIZATIONS = ['standard', 'l2']

USE_REFINEMENT = True
REFINEMENT_EPOCHS = 40
REFINEMENT_LR = 5e-5
REFINEMENT_BATCH_SIZE = 256
REFINEMENT_PATIENCE = 8

PROCRUSTES_METHOD = 'lortho'
USE_ENSEMBLE = False

print("\n1. Loading training data")
train_data = load_data(TRAIN_DATA)

text_embeddings, image_embeddings, label_matrix = prepare_train_data(train_data)

text_embeddings = torch.from_numpy(train_data['captions/embeddings']).float()
image_embeddings = torch.from_numpy(train_data['images/embeddings']).float()
label_matrix = torch.from_numpy(train_data['captions/label']).bool()
gt_indices = torch.argmax(label_matrix.long(), dim=1)


n_captions = len(text_embeddings)
n_val = min(5000, n_captions // 10)

caption_indices = torch.randperm(n_captions)
val_caption_indices = caption_indices[:n_val]
train_caption_indices = caption_indices[n_val:]

print(f"   Training captions: {len(train_caption_indices)}")
print(f"   Validation captions: {len(val_caption_indices)}")

val_data = {
    'captions/embeddings': train_data['captions/embeddings'][val_caption_indices.numpy()],
    'images/embeddings': train_data['images/embeddings'],
    'captions/label': train_data['captions/label'][val_caption_indices.numpy()]
}

train_data_split = {
    'captions/embeddings': train_data['captions/embeddings'][train_caption_indices.numpy()],
    'images/embeddings': train_data['images/embeddings'],
    'captions/label': train_data['captions/label'][train_caption_indices.numpy()]
}

np.savez_compressed(
    'data/train/validation_split.npz',
    captions_embeddings=val_data['captions/embeddings'],
    images_embeddings=val_data['images/embeddings'],
    captions_label=val_data['captions/label'],
    val_caption_indices=val_caption_indices.numpy(),
    train_caption_indices=train_caption_indices.numpy()
)

print(f"   ✓ Saved validation split to data/splits/validation_split.npz")

print(f"\n2. Selecting {N_ANCHORS} anchor pairs ({ANCHOR_SELECTION})...")
caption_anchor_idx, image_anchor_idx = select_anchors_diverse(train_data_split, N_ANCHORS, method=ANCHOR_SELECTION)

X_anchors_full, Y_anchors_full, _ = prepare_train_data(train_data_split)

X_anchors = X_anchors_full[caption_anchor_idx]
Y_anchors = Y_anchors_full[caption_anchor_idx]

print(f"   Anchor text embeddings: {X_anchors.shape}")
print(f"   Anchor image embeddings: {Y_anchors.shape}")

print(f"\n3. Training Procrustes models")
print(f"   Method: {PROCRUSTES_METHOD}")

translators = {}
results = {}
best_translator = None
best_mrr = 0
best_config = None

norm = 'l2'

print(f"\nTesting normalization: {norm}")

translator = ProcrustesTranslator(
    normalization=norm,
    method=PROCRUSTES_METHOD
)
translator.fit(X_anchors, Y_anchors)

val_text = text_embeddings[val_caption_indices]
val_gt = gt_indices[val_caption_indices]
val_image = image_embeddings

translated = translator.transform(val_text)

recall_1, mrr, l2_dist = evaluate_retrieval(
    None,
    translated,
    val_image,
    val_gt,
    DEVICE,
    k=1
)

recall_5, _, _ = evaluate_retrieval(
    None,
    translated,
    val_image,
    val_gt,
    DEVICE,
    k=5
)

recall_10, _, _ = evaluate_retrieval(
    None,
    translated,
    val_image,
    val_gt,
    DEVICE,
    k=10
)

print(f"     Validation Results:")
print(f"       Recall@1:  {recall_1:.4f}")
print(f"       Recall@5:  {recall_5:.4f}")
print(f"       Recall@10: {recall_10:.4f}")
print(f"       MRR:       {mrr:.4f}")
print(f"       L2 Distance: {l2_dist:.4f}")

translators[norm] = translator
results[norm] = {
    'recall_1': recall_1,
    'recall_5': recall_5,
    'recall_10': recall_10,
    'mrr': mrr,
    'l2_dist': l2_dist
}

if mrr > best_mrr:
    best_mrr = mrr
    best_translator = translator
    best_config = norm

print(f"\n   → Best Procrustes: {best_config} with MRR = {best_mrr:.4f}")

print(f"Procrustes normalization method: {best_translator.normalization}")
print(f"Procrustes method: {best_translator.method}")


1. Loading training data
(125000,)
Train data: 125000 captions, 125000 images
   Training captions: 120000
   Validation captions: 5000
   ✓ Saved validation split to data/splits/validation_split.npz

2. Selecting 120000 anchor pairs (uniform)...
   Selected 120000 anchor pairs
   Caption indices range: 0 - 119999
   Image indices range: 0 - 24999
(120000,)
Train data: 120000 captions, 120000 images
   Anchor text embeddings: torch.Size([120000, 1024])
   Anchor image embeddings: torch.Size([120000, 1536])

3. Training Procrustes models
   Method: lortho

Testing normalization: l2
  Input dimensions: 1024 -> 1536
  Method: lortho
  Transformation matrix shape: (1536, 1536)
  Orthogonality check: 1.295021e-05
     Validation Results:
       Recall@1:  0.1190
       Recall@5:  0.2860
       Recall@10: 0.3778
       MRR:       0.1190
       L2 Distance: 25.9078

   → Best Procrustes: l2 with MRR = 0.1190
Procrustes normalization method: l2
Procrustes method: lortho


In [24]:
import matplotlib.pyplot as plt

def plot_training_history(history, save_path='training_curves.png'):
    has_val = history.get('val_loss') is not None
    has_components = any(k.startswith('train_') and k != 'train_loss' for k in history.keys())
    
    n_plots = 2 if has_components else 1
    fig, axes = plt.subplots(1, n_plots, figsize=(6*n_plots, 5))
    if n_plots == 1:
        axes = [axes]
    
    ax = axes[0]
    epochs = range(1, len(history['train_loss']) + 1)
    
    ax.plot(epochs, history['train_loss'], 'b-', label='Train Loss', linewidth=2)
    if has_val:
        ax.plot(epochs, history['val_loss'], 'r-', label='Val Loss', linewidth=2)
    
    ax.set_xlabel('Epoch', fontsize=12)
    ax.set_ylabel('Loss', fontsize=12)
    ax.set_title('Training Progress', fontsize=14, fontweight='bold')
    ax.legend()
    ax.grid(True, alpha=0.3)

    if has_components:
        ax = axes[1]
        
        for key in ['train_clip', 'train_mse', 'train_cosine', 'train_contrastive', 
                    'train_similarity', 'train_group_consistency']:
            if key in history and history[key]:
                label = key.replace('train_', '').title()
                ax.plot(epochs, history[key], label=label, linewidth=2)
        
        ax.set_xlabel('Epoch', fontsize=12)
        ax.set_ylabel('Component Loss', fontsize=12)
        ax.set_title('Loss Components', fontsize=14, fontweight='bold')
        ax.legend()
        ax.grid(True, alpha=0.3)
    
    plt.tight_layout()

    Path(save_path).parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()

In [None]:
final_model = None

if USE_REFINEMENT:
    print(f"\n4. Neural refinement of best Procrustes model")
    print(f"   Initial validation MRR: {best_mrr:.4f}")

    refined_model, history = train_refined_model_improved(
        procrustes_translator=best_translator,
        train_data=train_data_split,
        val_data=val_data,
        epochs=REFINEMENT_EPOCHS,
        lr=REFINEMENT_LR,
        batch_size=REFINEMENT_BATCH_SIZE,
        hidden_dim=2048,
        patience=REFINEMENT_PATIENCE,
        loss_weights=None,
        device=DEVICE
    )

    try:
        plot_training_history(history, save_path='training_curves.png')
        print("Training curves saved to training_curves.png")
    except Exception as e:
        print(f" Could not plot training curves: {e}")

    print("\n   Evaluating refined model on validation set")
    
    recall_1, mrr_refined, l2_dist = evaluate_retrieval(
        refined_model,
        val_text,
        image_embeddings,
        val_gt,
        DEVICE,
        1
    )
    
    recall_5, _, _ = evaluate_retrieval(
        refined_model,
        val_text,
        image_embeddings,
        val_gt,
        DEVICE,
        5
    )
    
    recall_10, _, _ = evaluate_retrieval(
        refined_model,
        val_text,
        image_embeddings,
        val_gt,
        DEVICE,
        10
    )
    
    print(f"   Refined Model Validation Results:")
    print(f"     Recall@1:  {recall_1:.4f}")
    print(f"     Recall@5:  {recall_5:.4f}")
    print(f"     Recall@10: {recall_10:.4f}")
    print(f"     MRR:       {mrr_refined:.4f}")
    print(f"     L2 Distance: {l2_dist:.4f}")
    print(f"   Improvement: {mrr_refined - best_mrr:+.4f}")
    
    if mrr_refined > best_mrr:
        print("Refinement improved performance!")
        final_model = refined_model
        best_mrr = mrr_refined
    else:
        print("Refinement did not improve, using Procrustes only")
        final_model = best_translator
else:
    final_model = best_translator

print(f"\n6. Final evaluation on full training set")
    
if isinstance(final_model, nn.Module):
    final_model.eval()
    with torch.no_grad():
        full_translated = final_model(text_embeddings.to(DEVICE)).cpu()
else:
    full_translated = final_model.transform(text_embeddings)
    if not torch.is_tensor(full_translated):
        full_translated = torch.from_numpy(full_translated).float()

recall_1, mrr_train, _ = evaluate_retrieval(
    None,
    full_translated,
    image_embeddings,
    gt_indices,
    DEVICE,
    k=1
)

recall_5, _, _ = evaluate_retrieval(
    None,
    full_translated,
    image_embeddings,
    gt_indices,
    DEVICE,
    k=5
)

recall_10, _, _ = evaluate_retrieval(
    None,
    full_translated,
    image_embeddings,
    gt_indices,
    DEVICE,
    k=10
)

print(f"   Full Training Set Results:")
print(f"     Recall@1:  {recall_1:.4f}")
print(f"     Recall@5:  {recall_5:.4f}")
print(f"     Recall@10: {recall_10:.4f}")
print(f"     MRR:       {mrr_train:.4f}")


4. Neural refinement of best Procrustes model...
   Initial validation MRR: 0.1036

NEURAL REFINEMENT TRAINING (SigLipLoss)
Training data: 120000 captions, 25000 images
Validation data: 5000 captions
  Initialized with Procrustes solution (extracted 1536×1024 from (1536, 1536))
Model parameters: 8,399,361


Epoch 1/40: 100%|██████████| 469/469 [00:10<00:00, 44.34it/s, loss=118.4070] 



Epoch 1 Summary:
  Train Loss: 413.7385
  Val Loss: 813.1583
  LR: 0.000050
  ✓ New best model! (Loss: 813.1583)


Epoch 2/40: 100%|██████████| 469/469 [00:07<00:00, 61.58it/s, loss=75.1144] 



Epoch 2 Summary:
  Train Loss: 94.4805
  Val Loss: 667.2427
  LR: 0.000050
  ✓ New best model! (Loss: 667.2427)


Epoch 3/40: 100%|██████████| 469/469 [00:07<00:00, 58.70it/s, loss=56.1781]



Epoch 3 Summary:
  Train Loss: 71.9515
  Val Loss: 621.2647
  LR: 0.000050
  ✓ New best model! (Loss: 621.2647)


Epoch 4/40: 100%|██████████| 469/469 [00:07<00:00, 60.38it/s, loss=39.6368] 



Epoch 4 Summary:
  Train Loss: 59.8015
  Val Loss: 439.6286
  LR: 0.000050
  ✓ New best model! (Loss: 439.6286)


Epoch 5/40: 100%|██████████| 469/469 [00:07<00:00, 61.00it/s, loss=49.9999] 



Epoch 5 Summary:
  Train Loss: 51.3397
  Val Loss: 390.6216
  LR: 0.000050
  ✓ New best model! (Loss: 390.6216)


Epoch 6/40: 100%|██████████| 469/469 [00:07<00:00, 60.10it/s, loss=36.0675] 



Epoch 6 Summary:
  Train Loss: 44.9923
  Val Loss: 456.5824
  LR: 0.000050
  No improvement (1/8)


Epoch 7/40: 100%|██████████| 469/469 [00:07<00:00, 61.05it/s, loss=40.6797] 



Epoch 7 Summary:
  Train Loss: 40.2761
  Val Loss: 372.0241
  LR: 0.000049
  ✓ New best model! (Loss: 372.0241)


Epoch 8/40: 100%|██████████| 469/469 [00:07<00:00, 62.51it/s, loss=30.2124] 



Epoch 8 Summary:
  Train Loss: 36.3128
  Val Loss: 423.8181
  LR: 0.000049
  No improvement (1/8)


Epoch 9/40: 100%|██████████| 469/469 [00:07<00:00, 59.81it/s, loss=28.4344] 



Epoch 9 Summary:
  Train Loss: 32.7927
  Val Loss: 316.8164
  LR: 0.000048
  ✓ New best model! (Loss: 316.8164)


Epoch 10/40: 100%|██████████| 469/469 [00:07<00:00, 60.10it/s, loss=22.2474] 



Epoch 10 Summary:
  Train Loss: 29.6091
  Val Loss: 253.2075
  LR: 0.000047
  ✓ New best model! (Loss: 253.2075)


Epoch 11/40: 100%|██████████| 469/469 [00:07<00:00, 59.31it/s, loss=20.9568] 



Epoch 11 Summary:
  Train Loss: 27.2778
  Val Loss: 252.5251
  LR: 0.000046
  ✓ New best model! (Loss: 252.5251)


Epoch 12/40: 100%|██████████| 469/469 [00:08<00:00, 55.47it/s, loss=26.2147]



Epoch 12 Summary:
  Train Loss: 25.0491
  Val Loss: 297.9079
  LR: 0.000045
  No improvement (1/8)


Epoch 13/40: 100%|██████████| 469/469 [00:08<00:00, 56.90it/s, loss=18.8137] 



Epoch 13 Summary:
  Train Loss: 23.0415
  Val Loss: 275.7688
  LR: 0.000043
  No improvement (2/8)


Epoch 14/40: 100%|██████████| 469/469 [00:08<00:00, 53.59it/s, loss=17.8078]



Epoch 14 Summary:
  Train Loss: 20.9818
  Val Loss: 216.5830
  LR: 0.000042
  ✓ New best model! (Loss: 216.5830)


Epoch 15/40: 100%|██████████| 469/469 [00:08<00:00, 55.54it/s, loss=16.1023] 



Epoch 15 Summary:
  Train Loss: 19.5392
  Val Loss: 247.8447
  LR: 0.000040
  No improvement (1/8)


Epoch 16/40: 100%|██████████| 469/469 [00:08<00:00, 58.03it/s, loss=15.4613] 



Epoch 16 Summary:
  Train Loss: 18.0184
  Val Loss: 241.5046
  LR: 0.000039
  No improvement (2/8)


Epoch 17/40: 100%|██████████| 469/469 [00:08<00:00, 55.79it/s, loss=16.3782]



Epoch 17 Summary:
  Train Loss: 16.7204
  Val Loss: 196.0521
  LR: 0.000037
  ✓ New best model! (Loss: 196.0521)


Epoch 18/40: 100%|██████████| 469/469 [00:08<00:00, 56.21it/s, loss=16.6711] 



Epoch 18 Summary:
  Train Loss: 15.3249
  Val Loss: 191.5651
  LR: 0.000035
  ✓ New best model! (Loss: 191.5651)


Epoch 19/40: 100%|██████████| 469/469 [00:08<00:00, 57.88it/s, loss=12.7421] 



Epoch 19 Summary:
  Train Loss: 14.3898
  Val Loss: 179.6608
  LR: 0.000033
  ✓ New best model! (Loss: 179.6608)


Epoch 20/40: 100%|██████████| 469/469 [00:08<00:00, 53.65it/s, loss=12.8027] 



Epoch 20 Summary:
  Train Loss: 13.4520
  Val Loss: 170.9258
  LR: 0.000031
  ✓ New best model! (Loss: 170.9258)


Epoch 21/40: 100%|██████████| 469/469 [00:08<00:00, 52.48it/s, loss=11.0181] 



Epoch 21 Summary:
  Train Loss: 12.5144
  Val Loss: 188.3420
  LR: 0.000029
  No improvement (1/8)


Epoch 22/40: 100%|██████████| 469/469 [00:07<00:00, 58.69it/s, loss=9.7003]  



Epoch 22 Summary:
  Train Loss: 11.5666
  Val Loss: 208.8699
  LR: 0.000028
  No improvement (2/8)


Epoch 23/40: 100%|██████████| 469/469 [00:08<00:00, 54.27it/s, loss=11.3013]



Epoch 23 Summary:
  Train Loss: 10.6282
  Val Loss: 142.0755
  LR: 0.000026
  ✓ New best model! (Loss: 142.0755)


Epoch 24/40: 100%|██████████| 469/469 [00:08<00:00, 53.61it/s, loss=9.8441]  



Epoch 24 Summary:
  Train Loss: 9.9709
  Val Loss: 158.5658
  LR: 0.000024
  No improvement (1/8)


Epoch 25/40: 100%|██████████| 469/469 [00:09<00:00, 51.86it/s, loss=7.4736] 



Epoch 25 Summary:
  Train Loss: 9.3681
  Val Loss: 138.3880
  LR: 0.000022
  ✓ New best model! (Loss: 138.3880)


Epoch 26/40: 100%|██████████| 469/469 [00:08<00:00, 56.62it/s, loss=8.4007]  



Epoch 26 Summary:
  Train Loss: 8.7406
  Val Loss: 152.0197
  LR: 0.000020
  No improvement (1/8)


Epoch 27/40: 100%|██████████| 469/469 [00:07<00:00, 62.07it/s, loss=7.7045]  



Epoch 27 Summary:
  Train Loss: 8.1955
  Val Loss: 145.7334
  LR: 0.000018
  No improvement (2/8)


Epoch 28/40: 100%|██████████| 469/469 [00:08<00:00, 55.18it/s, loss=5.1429] 



Epoch 28 Summary:
  Train Loss: 7.7516
  Val Loss: 125.5144
  LR: 0.000016
  ✓ New best model! (Loss: 125.5144)


Epoch 29/40: 100%|██████████| 469/469 [00:08<00:00, 55.71it/s, loss=5.5451] 



Epoch 29 Summary:
  Train Loss: 7.2293
  Val Loss: 123.8620
  LR: 0.000015
  ✓ New best model! (Loss: 123.8620)


Epoch 30/40: 100%|██████████| 469/469 [00:08<00:00, 57.48it/s, loss=5.9224] 



Epoch 30 Summary:
  Train Loss: 6.8397
  Val Loss: 130.9074
  LR: 0.000013
  No improvement (1/8)


Epoch 31/40: 100%|██████████| 469/469 [00:08<00:00, 53.16it/s, loss=4.6979] 



Epoch 31 Summary:
  Train Loss: 6.4856
  Val Loss: 123.6391
  LR: 0.000012
  ✓ New best model! (Loss: 123.6391)


Epoch 32/40: 100%|██████████| 469/469 [00:08<00:00, 52.88it/s, loss=6.2285] 



Epoch 32 Summary:
  Train Loss: 6.1216
  Val Loss: 111.4521
  LR: 0.000010
  ✓ New best model! (Loss: 111.4521)


Epoch 33/40: 100%|██████████| 469/469 [00:09<00:00, 48.94it/s, loss=5.0770]



Epoch 33 Summary:
  Train Loss: 5.8964
  Val Loss: 110.4191
  LR: 0.000009
  ✓ New best model! (Loss: 110.4191)


Epoch 34/40: 100%|██████████| 469/469 [00:09<00:00, 48.48it/s, loss=4.9640]



Epoch 34 Summary:
  Train Loss: 5.5723
  Val Loss: 116.8246
  LR: 0.000008
  No improvement (1/8)


Epoch 35/40: 100%|██████████| 469/469 [00:09<00:00, 48.37it/s, loss=4.2857]



Epoch 35 Summary:
  Train Loss: 5.3662
  Val Loss: 109.5287
  LR: 0.000007
  ✓ New best model! (Loss: 109.5287)


Epoch 36/40: 100%|██████████| 469/469 [00:09<00:00, 50.89it/s, loss=3.7749]



Epoch 36 Summary:
  Train Loss: 5.1741
  Val Loss: 108.4450
  LR: 0.000006
  ✓ New best model! (Loss: 108.4450)


Epoch 37/40: 100%|██████████| 469/469 [00:08<00:00, 54.79it/s, loss=3.3972] 



Epoch 37 Summary:
  Train Loss: 4.9984
  Val Loss: 105.1929
  LR: 0.000006
  ✓ New best model! (Loss: 105.1929)


Epoch 38/40: 100%|██████████| 469/469 [00:08<00:00, 53.04it/s, loss=5.7518]



Epoch 38 Summary:
  Train Loss: 4.9203
  Val Loss: 107.5419
  LR: 0.000005
  No improvement (1/8)


Epoch 39/40: 100%|██████████| 469/469 [00:08<00:00, 55.46it/s, loss=3.2061] 



Epoch 39 Summary:
  Train Loss: 4.8475
  Val Loss: 104.9796
  LR: 0.000005
  ✓ New best model! (Loss: 104.9796)


Epoch 40/40: 100%|██████████| 469/469 [00:07<00:00, 62.05it/s, loss=2.9647] 



Epoch 40 Summary:
  Train Loss: 4.7352
  Val Loss: 102.4799
  LR: 0.000005
  ✓ New best model! (Loss: 102.4799)

Training Complete! Best Loss: 102.4799
   ✓ Training curves saved to training_curves.png

   Evaluating refined model on validation set...
   Refined Model Validation Results:
     Recall@1:  0.1378
     Recall@5:  0.3528
     Recall@10: 0.4750
     MRR:       0.1378
     L2 Distance: 31.9781
   Improvement: +0.0342
   ✓ Refinement improved performance!
After neural refinement norms: min=13.2261, max=23.4800, mean=18.6310
Target (DINOv2) norms: min=22.3201, max=35.4587, mean=25.9392


RuntimeError: The size of tensor a (5000) must match the size of tensor b (125000) at non-singleton dimension 0

In [None]:
print("\n7. Generating submission for test set")
test_data = load_data(TEST_DATA)
test_text_embeddings = torch.from_numpy(test_data['captions/embeddings']).float()

if isinstance(final_model, nn.Module):
    final_model.eval()
    with torch.no_grad():
        test_translated = final_model(test_text_embeddings.to(DEVICE))
        test_translated = test_translated.cpu()
elif hasattr(final_model, 'transform'):
    test_translated = final_model.transform(test_text_embeddings)
    if not torch.is_tensor(test_translated):
        test_translated = torch.from_numpy(test_translated).float()
else:
    raise ValueError(f"Unknown model type: {type(final_model)}")

print(f"   Test translated embeddings: {test_translated.shape}")

sample_ids = test_data['captions/ids']
submission_file = f"submission_{best_config.replace(' ', '_').lower()}.csv"
submission = generate_submission(
    sample_ids,
    test_translated.cpu().numpy(),
    output_file=submission_file
)

# Save checkpoint
checkpoint = {
    'config': {
        'n_anchors': N_ANCHORS,
        'anchor_selection': ANCHOR_SELECTION,
        'normalization': best_config,
        'procrustes_method': PROCRUSTES_METHOD,
        'use_refinement': USE_REFINEMENT,
        'use_ensemble': USE_ENSEMBLE,
    },
    'results': {
        'validation_mrr': best_mrr,
        'training_mrr': mrr_train,
        'all_results': results
    }
}

if isinstance(final_model, nn.Module):
    checkpoint['model_state_dict'] = final_model.state_dict()
    torch.save(checkpoint, 'final_model_checkpoint.pt')
    print(f"Model checkpoint saved to final_model_checkpoint.pt")

print("COMPLETE!")
print(f"Configuration: {best_config}")
print(f"Validation MRR: {best_mrr:.4f}")
print(f"Training MRR: {mrr_train:.4f}")
print(f"Method: {'Refined ' if USE_REFINEMENT and isinstance(final_model, nn.Module) else ''}"
        f"{'Ensemble ' if hasattr(final_model, 'translators') else ''}"
        f"Procrustes ({PROCRUSTES_METHOD})")
print(f"Anchors: {N_ANCHORS} ({ANCHOR_SELECTION})")
print(f"Submission: {submission_file}")