# Initialize

In [None]:
import os
from google.colab import userdata
# Copy the Kaggle and Git credentials from Colab's `userdata` store into
# environment variables; the `!git clone` cell below expands $GIT_TOKEN.
# NOTE(review): presumably the kaggle CLI reads KAGGLE_USERNAME/KAGGLE_KEY — confirm.
os.environ['KAGGLE_USERNAME'] = userdata.get('KAGGLE_USERNAME')
os.environ['KAGGLE_KEY'] = userdata.get('KAGGLE_KEY')
os.environ['GIT_TOKEN'] = userdata.get('GIT_TOKEN')

In [None]:
import copy

In [None]:
# Clone the repository named by `repo`, authenticating with GIT_TOKEN.
# NOTE(review): `repo` is not defined in any visible cell — confirm it is set before running.
!git clone https://$GIT_TOKEN@$repo
# Download the competition archive and unpack it into ./data.
!kaggle competitions download -c aml-competition
!unzip -o aml-competition.zip -d data
# Clone the challenge helper code that later cells import as `challenge.src`.
!git clone https://github.com/Mamiglia/challenge.git

In [None]:
from typing import List
import sys
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from pathlib import Path
from tqdm.auto import tqdm
import torch.nn.functional as F
import numpy as np
from enum import Enum
import math
import random
import matplotlib.pyplot as plt

In [None]:
from challenge.src.common import load_data, prepare_train_data, generate_submission
from challenge.src.eval import visualize_retrieval, evaluate_retrieval
from challenge.src.eval.metrics import recall_at_k, ndcg,mrr

In [None]:
# Install Optuna, which is imported in the next cell.
%pip install optuna

In [None]:
import optuna

In [None]:
# Install torchdiffeq.
# NOTE(review): no visible cell imports it directly — presumably a CrossFlow dependency; confirm.
!pip install torchdiffeq

In [None]:
def set_seed(seed=42):
    """Seed the torch, NumPy and stdlib ``random`` generators with ``seed``."""
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(seed)


# Seed once with the default value as soon as the cell runs.
set_seed()

In [None]:
# Directory the notebook runs in; data and cloned repositories are addressed relative to it.
WORKING_DIR = Path.cwd()

In [None]:
# Checkpoints are written to the filesystem root.
# NOTE(review): "/" may not be writable outside Colab — confirm this location is intended.
MODELS_DIR = Path("/").absolute()
MODEL_PATH = MODELS_DIR / "default.pth"  # default path used by save_model / load_model

In [None]:
# Directory the competition archive was unzipped into (`unzip ... -d data`).
DATA_PATH = WORKING_DIR / "data"

In [None]:
# Default training hyper-parameters; used as default argument values by the helpers below.
EPOCHS = 20
BATCH_SIZE = 256
LR = 0.001
# Prefer the GPU when one is available.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Clone CrossFlow; the next cell imports ClipLoss / SigLipLoss from it.
!git clone https://github.com/qihao067/CrossFlow.git

In [None]:
# Put the cloned CrossFlow checkout on sys.path before importing its loss classes.
# NOTE(review): presumably needed for CrossFlow's own absolute imports — confirm.
sys.path.append(str(WORKING_DIR / "CrossFlow"))
from CrossFlow.diffusion.flow_matching import ClipLoss, SigLipLoss

# Actual code

## Basic definitions

In [None]:
class Statistics():
  """Container for the statistics collected during one training run.

  Every attribute is created per instance in ``__init__``. The lists used to be
  class-level attributes, which made all ``Statistics`` objects append to the
  same shared lists, so losses from earlier runs leaked into later ones.

  Attributes:
    losses: per-epoch training losses.
    best_loss: lowest loss seen (``inf`` until set).
    best_epoch_index: index of the epoch with the lowest loss (``-1`` until set).
    mse_losses, cos_losses, contrastive_losses: optional per-component loss histories.
  """

  def __init__(self):
    self.losses = []
    self.best_loss = float("inf")
    self.best_epoch_index = -1
    self.mse_losses = []
    self.cos_losses = []
    self.contrastive_losses = []

In [None]:
class DataKeeper():
  """Bundle of the raw data, the train/validation datasets and loaders, and
  the validation-side arrays used for retrieval evaluation.

  All fields start as ``None`` and are filled in by ``get_train_data``.
  """
  # Raw arrays returned by load_data.
  train_data = None

  # Torch datasets and the loaders built on top of them.
  train_dataset = None
  val_dataset = None
  train_loader = None
  val_loader = None

  # Validation captions: raw text, embeddings, and the index of each caption's image.
  val_caption_text = None
  val_text_embd = None
  val_label = None

  # Validation image gallery: file names and embeddings.
  val_img_file = None
  val_img_embd = None

  def create_loaders(self, batch_size = BATCH_SIZE):
    """Rebuild both loaders from the stored datasets using ``batch_size``.

    Only the training loader shuffles.
    """
    shared = dict(batch_size=batch_size)
    self.train_loader = DataLoader(self.train_dataset, shuffle=True, **shared)
    self.val_loader = DataLoader(self.val_dataset, **shared)

In [None]:
def get_train_data(data_path=DATA_PATH, split_ratio = 0.8, batch_size = BATCH_SIZE):
  """Load the training archive and split it sequentially into train/validation.

  The first ``int(len(X) * split_ratio)`` rows form the training split, the
  rest the validation split (no shuffling before the split).

  Args:
    data_path: directory containing ``train/train/train.npz``.
    split_ratio: fraction of rows assigned to the training split.
    batch_size: batch size for both loaders.

  Returns:
    A populated ``DataKeeper``.
  """
  keeper = DataKeeper()
  raw = load_data(data_path/"train/train/train.npz")
  keeper.train_data = raw
  X, y, label = prepare_train_data(raw)

  # Boolean row masks for the two splits.
  n_train = int(len(X) * split_ratio)
  in_train = torch.zeros(len(X), dtype=bool)
  in_train[:n_train] = True
  in_val = ~in_train

  X_val = X[in_val]
  keeper.train_dataset = TensorDataset(X[in_train], y[in_train])
  keeper.val_dataset = TensorDataset(X_val, y[in_val])
  keeper.create_loaders(batch_size)

  # Images referenced by at least one validation caption form the validation gallery.
  gallery_mask = label[in_val].sum(dim=0) > 0
  keeper.img_VAL_SPLIT = gallery_mask

  keeper.val_caption_text = raw['captions/text'][in_val]
  keeper.val_text_embd = X_val
  keeper.val_img_file = raw['images/names'][gallery_mask]
  keeper.val_img_embd = torch.from_numpy(raw['images/embeddings'][gallery_mask])
  # Column index (within the gallery) of every validation caption's image.
  keeper.val_label = np.nonzero(raw['captions/label'][in_val][:, gallery_mask])[1]
  return keeper


In [None]:
def save_model(model, path=MODEL_PATH, verbose=True):
  """Write ``model.state_dict()`` to ``path``.

  Args:
    model: module whose state dict is saved.
    path: destination file.
    verbose: print a confirmation when True (previously the message was
      printed unconditionally and this flag was ignored).
  """
  torch.save(model.state_dict(), path)
  if verbose:
    print(f"Model saved to {path}")

In [None]:
def load_model(model, path=MODEL_PATH, verbose=True):
  """Load a state dict from ``path`` into ``model`` and return the model.

  Args:
    model: module with the same architecture as the saved one.
    path: checkpoint file written by ``save_model``.
    verbose: print a confirmation when True.

  Returns:
    ``model`` with the loaded weights.
  """
  # Deserialize onto the CPU so a checkpoint saved on a GPU also loads on a
  # CPU-only host; load_state_dict copies the values onto the device the
  # model's parameters already live on.
  state_dict = torch.load(path, map_location="cpu")
  model.load_state_dict(state_dict)
  if verbose:
    print(f"Model loaded from {path}")
  return model

In [None]:
def sample_and_visualize(model, data_keeper, device=DEVICE, number_of_indices=5, dataset_path=WORKING_DIR/"data/train/train"):
  """Visualize retrieval results for a few randomly chosen validation captions.

  Args:
    model: translator mapping a caption embedding to the image-embedding space.
    data_keeper: ``DataKeeper`` with the validation fields populated.
    device: device the caption embedding is moved to before inference.
    number_of_indices: how many captions to sample.
    dataset_path: forwarded to ``visualize_retrieval``.

  Raises:
    ValueError: if the validation split is empty.
  """
  n_val = len(data_keeper.val_text_embd)
  if n_val == 0:
    raise ValueError("Validation split is empty; nothing to visualize")
  # Sample among the first 100 validation captions as before, but never past
  # the end of a smaller validation split.
  upper = min(100, n_val)

  model.eval()  # loop-invariant, so set once
  for _ in range(number_of_indices):
    index = np.random.randint(0, upper)
    caption_embd = data_keeper.val_text_embd[index]
    caption_text = data_keeper.val_caption_text[index]
    gt_index = data_keeper.val_label[index]
    with torch.no_grad():
      pred_embds = model(caption_embd.to(device)).to(device)
      visualize_retrieval(
        pred_embds,
        gt_index,
        data_keeper.val_img_file,
        caption_text, data_keeper.val_img_embd, k=5, dataset_path=dataset_path)


In [None]:
@torch.inference_mode()
def evaluate_retrieval(translated_embd, image_embd, gt_indices, max_indices = 99, batch_size=100):
    """Rank gallery images for every translated caption and compute retrieval metrics.

    Local replacement for the version in challenge/src/eval/eval.py: ranked
    indices are moved to the CPU before being converted to NumPy.

    Args:
        translated_embd: query embeddings (tensor or ndarray), one row per caption.
        image_embd: gallery embeddings (tensor or ndarray), one row per image.
        gt_indices: ground-truth gallery index of every query.
        max_indices: number of top-ranked images kept per query; clamped to the
            gallery size so ``topk`` cannot fail on a small gallery.
        batch_size: number of queries scored per matrix product.

    Returns:
        dict mapping metric name to value (mrr, ndcg, recall_at_{1,3,5,10,50}).
    """
    if isinstance(translated_embd, np.ndarray):
        translated_embd = torch.from_numpy(translated_embd).float()
    if isinstance(image_embd, np.ndarray):
        image_embd = torch.from_numpy(image_embd).float()

    # Score on the queries' device/dtype so a CPU gallery works with GPU predictions.
    image_embd = image_embd.to(device=translated_embd.device, dtype=translated_embd.dtype)
    gallery_t = image_embd.T
    k = min(max_indices, image_embd.shape[0])

    n_queries = translated_embd.shape[0]
    all_sorted_indices = []

    for start_idx in range(0, n_queries, batch_size):
        batch_translated = translated_embd[start_idx:start_idx + batch_size]
        batch_similarity = batch_translated @ gallery_t
        # .cpu() is required before .numpy() when the scores live on the GPU.
        batch_indices = batch_similarity.topk(k=k, dim=1, sorted=True).indices.cpu().numpy()
        all_sorted_indices.append(batch_indices)

    sorted_indices = np.concatenate(all_sorted_indices, axis=0)
    metrics = {
        'mrr': mrr,
        'ndcg': ndcg,
        'recall_at_1': lambda preds, gt: recall_at_k(preds, gt, 1),
        'recall_at_3': lambda preds, gt: recall_at_k(preds, gt, 3),
        'recall_at_5': lambda preds, gt: recall_at_k(preds, gt, 5),
        'recall_at_10': lambda preds, gt: recall_at_k(preds, gt, 10),
        'recall_at_50': lambda preds, gt: recall_at_k(preds, gt, 50),
    }

    return {
        name: func(sorted_indices, gt_indices)
        for name, func in metrics.items()
    }

## Model definitions

In [None]:
class Normalization(Enum):
  """Normalization schemes handled by normalize_embeddings / denormalize_embeddings."""
  NONE = 0      # embeddings are passed through unchanged
  STANDARD = 1  # per-dimension standardization: (X - mean) / std
  L2 = 2        # every row is divided by its L2 norm

In [None]:
class AnchorSelectionStrategy(Enum):
  """How select_anchors_diverse picks caption indices."""
  UNIFORM=1  # evenly spaced indices (np.linspace)
  RANDOM=2   # indices drawn with replacement (np.random.choice)

In [None]:
class RefinementType(Enum):
  """Refinement head used by RefinedProcrustesTranslator."""
  # The original `NONE = 0,` had a trailing comma, which made the member's
  # value the tuple (0,) instead of the integer 0.
  NONE = 0      # linear projection only
  RESIDUAL = 1  # linear projection plus a gated residual MLP

In [None]:
# Defaults for anchor selection (used as default arguments of train_translator).
N_ANCHORS = 1536
ANCHOR_SELECTION =  AnchorSelectionStrategy.UNIFORM

# Normalization schemes the experiment cells index into (NORMALIZATIONS[0], NORMALIZATIONS[1]).
NORMALIZATIONS = [Normalization.STANDARD, Normalization.L2]

# Refinement settings.
# NOTE(review): no visible cell reads these — confirm whether they are still needed.
USE_REFINEMENT = True
REFINEMENT_EPOCHS = 30
REFINEMENT_LR = 5e-5
REFINEMENT_BATCH_SIZE = 256
REFINEMENT_PATIENCE = 8

PROCRUSTES_METHOD = RefinementType.RESIDUAL
USE_ENSEMBLE = False

In [None]:
def normalize_embeddings(X, method=Normalization.STANDARD, stats=None):
    """Normalize a 2-D batch of embeddings.

    Args:
        X: tensor of shape (n, d); the code indexes ``X.shape[1]``.
        method: a ``Normalization`` member.
        stats: statistics returned by an earlier call. For STANDARD the stored
            ``mean``/``std`` are reused instead of being re-estimated from ``X``.

    Returns:
        ``(X_norm, stats)``. ``stats`` always carries ``'method'`` and ``'dim'``
        when created here, plus ``'mean'``/``'std'`` for STANDARD.

    Raises:
        ValueError: if ``method`` is not a known ``Normalization`` member
            (previously the function fell through and returned ``None``).
    """
    if method == Normalization.NONE:
        return X, {'method': 'none', 'dim': X.shape[1]}

    if method == Normalization.L2:
        norms = torch.linalg.norm(X, dim=1, keepdim=True)
        # Leave all-zero rows untouched instead of dividing by zero.
        norms = torch.where(norms == 0, torch.ones_like(norms), norms)
        if stats is None:
            stats = {
                'method': 'l2',
                'dim': X.shape[1]
            }
        return X / norms, stats

    if method == Normalization.STANDARD:
        if stats is None:
            mean = X.mean(dim=0)
            std = X.std(dim=0)
            # Constant dimensions keep their centred value instead of dividing by zero.
            std = torch.where(std == 0, torch.ones_like(std), std)
            stats = {
                'method': 'standard',
                'mean': mean,
                'std': std,
                'dim': X.shape[1]
            }
        return (X - stats['mean']) / stats['std'], stats

    raise ValueError(f"Unknown normalization method: {method}")


In [None]:
def select_anchors_diverse(data_keeper, n_anchors, method=AnchorSelectionStrategy.UNIFORM):
    """Pick ``n_anchors`` caption rows and the image row each one is labelled with.

    Args:
        data_keeper: object whose ``train_data`` holds ``'captions/embeddings'``
            and ``'captions/label'``.
        n_anchors: number of anchor pairs to select.
        method: UNIFORM for evenly spaced caption indices, RANDOM for indices
            drawn with replacement.

    Returns:
        ``(caption_indices, image_indices)`` as NumPy integer arrays.

    Raises:
        ValueError: if ``method`` is not a known strategy.

    NOTE(review): indices are drawn from all of ``train_data``, which appears to
    include rows held out for validation by ``get_train_data`` — confirm this is intended.
    """
    raw = data_keeper.train_data
    captions = raw['captions/embeddings']
    labels = raw['captions/label']
    total = len(captions)

    # Image index of every caption: position of the maximum in its label row.
    image_of_caption = np.argmax(labels, axis=1)

    if method == AnchorSelectionStrategy.UNIFORM:
        caption_indices = np.linspace(0, total - 1, n_anchors, dtype=int)
    elif method == AnchorSelectionStrategy.RANDOM:
        caption_indices = np.random.choice(total, n_anchors, replace=True)
    else:
        raise ValueError(f"Unknown method: {method}")

    image_indices = image_of_caption[caption_indices]

    print(f"   Selected {len(caption_indices)} anchor pairs")
    print(f"   Caption indices range: {caption_indices.min()} - {caption_indices.max()}")
    print(f"   Image indices range: {image_indices.min()} - {image_indices.max()}")

    return caption_indices, image_indices

In [None]:
def denormalize_embeddings(X_norm, method=Normalization.STANDARD, stats=None):
    """Undo ``normalize_embeddings`` where that is possible.

    Args:
        X_norm: normalized tensor of shape (n, d).
        method: the ``Normalization`` member used for normalization.
        stats: statistics dict holding ``'mean'`` and ``'std'`` (STANDARD only).

    Returns:
        For NONE and L2 the input unchanged (the code does not restore row
        norms); for STANDARD ``X_norm * std + mean``, after dropping any columns
        beyond ``len(mean)``.

    Raises:
        ValueError: if STANDARD is requested without ``stats``, or if ``method``
            is unknown (previously the function fell through and returned ``None``).
    """
    if method in (Normalization.NONE, Normalization.L2):
        return X_norm

    if method == Normalization.STANDARD:
        if stats is None:
            raise ValueError("Need stats for denormalization")
        mean = stats['mean']
        std = stats['std']
        # Drop padded columns that have no statistics.
        if X_norm.shape[1] > len(mean):
            X_norm = X_norm[:, :len(mean)]
        return X_norm * std + mean

    raise ValueError(f"Unknown normalization method: {method}")

In [None]:
def compute_R(R, allow_reflection=True):
    """Project ``R`` onto the nearest matrix with orthonormal rows/columns (``U @ Vt``).

    Args:
        R: 2-D tensor; it may be rectangular.
        allow_reflection: when False and the result is square with a negative
            determinant, the last left singular vector is flipped so the result
            is a pure rotation. The check is skipped for rectangular matrices,
            which have no determinant.

    Returns:
        Tensor with the same shape as ``R``.
    """
    U, S, Vt = torch.linalg.svd(R, full_matrices=False)
    new_R = U @ Vt
    print(f"Xpad.shape {R.shape}")
    print(f"U.shape {U.shape}")
    print(f"Vt.shape {Vt.shape}")
    print(f"new_R.shape {new_R.shape}")

    is_square = new_R.shape[0] == new_R.shape[1]
    # torch.linalg.det keeps the computation on the tensor's device;
    # np.linalg.det cannot consume a CUDA tensor.
    if not allow_reflection and is_square and torch.linalg.det(new_R) < 0:
        print("  ⚠ Reflection detected, correcting to pure rotation")
        U[:, -1] *= -1
        new_R = U @ Vt

    return new_R

In [None]:
def compute_procrustes_with_padding(X, Y, allow_reflection=True):
    """Solve the orthogonal Procrustes problem between ``X`` and ``Y``.

    The narrower input is zero-padded on the right so both have
    ``max(d1, d2)`` columns, then ``R = U @ Vt`` is taken from the SVD of ``X.T @ Y``.

    Args:
        X: source tensor of shape (n, d1).
        Y: target tensor of shape (n, d2).
        allow_reflection: when False, a solution with negative determinant is
            corrected to a pure rotation by flipping the last left singular vector.

    Returns:
        Square tensor of shape (max(d1, d2), max(d1, d2)), applied as ``X_padded @ R``.
    """
    d1, d2 = X.shape[1], Y.shape[1]
    d_max = max(d1, d2)
    print(f"X.shape {X.shape}")
    print(f"Y.shape {Y.shape}")
    if d1 < d_max:
        X = F.pad(X, (0, d_max - d1, 0, 0), mode='constant', value=0)
    if d2 < d_max:
        Y = F.pad(Y, (0, d_max - d2, 0, 0), mode='constant', value=0)

    H = X.T @ Y
    U, S, Vt = torch.linalg.svd(H, full_matrices=False)
    R = U @ Vt
    # torch.linalg.det keeps the computation on the tensor's device;
    # np.linalg.det cannot consume a CUDA tensor.
    if not allow_reflection and torch.linalg.det(R) < 0:
        print("  ⚠ Reflection detected, correcting to pure rotation")
        U[:, -1] *= -1
        R = U @ Vt

    return R

In [None]:
class ProcrustesTranslator(nn.Module):
    """Closed-form translator applying an orthogonal map between embedding spaces.

    ``fit`` normalizes both anchor sets and solves for ``R`` with
    ``compute_procrustes_with_padding``. ``forward`` runs
    normalize -> zero-pad -> ``X @ R`` -> crop -> denormalize.

    ``R`` and the statistics are plain attributes, not parameters or buffers, so
    they are not part of ``state_dict()`` and are not moved by ``.to(device)``.
    """

    def __init__(self, normalization=Normalization.STANDARD, allow_reflection=True):
        super().__init__()
        self.normalization = normalization
        self.R = None
        self.source_stats = None
        self.target_stats = None
        self.allow_reflection = allow_reflection

    def fit(self, X_source, Y_target):
        """Estimate normalization statistics and the orthogonal map from anchor pairs.

        Args:
            X_source: source anchors, shape (n, d_source).
            Y_target: target anchors, shape (n, d_target), row-aligned with ``X_source``.
        """
        print(f"  Input dimensions: {X_source.shape[1]} -> {Y_target.shape[1]}")

        X_norm, self.source_stats = normalize_embeddings(
            X_source, self.normalization
        )
        Y_norm, self.target_stats = normalize_embeddings(
            Y_target, self.normalization
        )

        self.R = compute_procrustes_with_padding(
            X_norm, Y_norm, allow_reflection=self.allow_reflection
        )

        print(f"  Transformation matrix shape: {self.R.shape}")

    def forward(self, X_source):
        """Translate source embeddings of shape (n, d_source) into the target space.

        Raises:
            ValueError: if ``fit`` has not been called.
        """
        if self.R is None:
            raise ValueError("Must call fit() before transform()")

        X_norm, _ = normalize_embeddings(
            X_source, self.normalization, self.source_stats
        )

        # Zero-pad the input up to the size of the square matrix R.
        d_in = X_norm.shape[1]
        d_R = self.R.shape[0]
        if d_in < d_R:
            X_norm = F.pad(X_norm, (0, d_R - d_in, 0, 0), mode='constant')

        Y_norm = X_norm @ self.R

        # Crop padded output columns. The stats produced in fit() carry 'dim' for
        # every normalization, so this also covers Normalization.NONE, which
        # previously skipped the crop.
        d_target = self.target_stats['dim']
        if Y_norm.shape[1] > d_target:
            Y_norm = Y_norm[:, :d_target]

        return denormalize_embeddings(
            Y_norm, self.normalization, self.target_stats
        )

In [None]:
class AffineTranslator(nn.Module):
  """Trainable affine map (one ``nn.Linear``) applied to normalized inputs.

  Args:
    input_dim: source embedding dimension.
    output_dim: target embedding dimension.
    normalization: ``Normalization`` applied to every input batch. The argument
      is now honoured; it used to be ignored in favour of a hard-coded STANDARD.
  """
  def __init__(self, input_dim=1024, output_dim=1536, normalization=Normalization.STANDARD):
    super().__init__()
    self.normalization = normalization
    self.source_stats = None
    self.net = nn.Linear(input_dim, output_dim)

  def forward(self, X):
    """Normalize ``X`` and apply the linear layer.

    No stored statistics are passed to ``normalize_embeddings``, so for
    STANDARD the mean/std are re-estimated from each batch.
    """
    X_norm, self.source_stats = normalize_embeddings(
      X, self.normalization
    )
    return self.net(X_norm)

In [None]:
class LinearTranslator(nn.Module):
  """Least-squares linear map between normalized embedding spaces.

  ``input_dim`` and ``output_dim`` are accepted for interface parity with the
  other translators but are not used: the shape of ``R`` follows from the data
  passed to ``fit``.
  """
  def __init__(self, input_dim=1024, output_dim=1536, normalization=Normalization.STANDARD):
    super().__init__()
    self.R = None
    self.source_stats = None
    self.target_stats = None
    self.normalization = normalization

  def fit(self, X, Y):
    """Solve ``min_R ||X_norm @ R - Y_norm||`` and store the solution.

    Args:
      X: source anchors, shape (n, d_source).
      Y: target anchors, shape (n, d_target), row-aligned with ``X``.

    Returns:
      The fitted matrix ``R`` of shape (d_source, d_target).
    """
    X_norm, self.source_stats = normalize_embeddings(
        X, self.normalization
    )
    # NOTE: it may be worth trying L2 normalization for Y regardless of self.normalization.
    Y_norm, self.target_stats = normalize_embeddings(
        Y, self.normalization
    )
    self.R = torch.linalg.lstsq(X_norm, Y_norm).solution
    return self.R

  def forward(self, X):
    """Translate ``X`` using the fitted map and the statistics stored by ``fit``.

    Raises:
      ValueError: if ``fit`` has not been called.
    """
    if self.R is None:
      raise ValueError("Must call fit() before forward()")
    X_norm, _ = normalize_embeddings(
        X, self.normalization, self.source_stats
    )
    Y_norm = X_norm @ self.R
    return denormalize_embeddings(
        Y_norm, self.normalization, self.target_stats
    )

In [None]:
class LOrthoTranslator(nn.Module):
    """Translator that orthogonalizes an existing linear map with ``compute_R``.

    Args:
        normalization: ``Normalization`` applied to inputs and undone on outputs.
        target_stats: target-space statistics (for example
            ``LinearTranslator.target_stats``) used for cropping and denormalization.
        allow_reflection: forwarded to ``compute_R``.
        source_stats: optional source-space statistics (for example
            ``LinearTranslator.source_stats``). When left as None, STANDARD
            normalization re-estimates mean/std from every input batch.
    """

    def __init__(self, normalization=Normalization.STANDARD, target_stats=None,
                 allow_reflection=True, source_stats=None):
        super().__init__()
        self.normalization = normalization
        self.R = None
        self.source_stats = source_stats
        self.target_stats = target_stats
        self.allow_reflection = allow_reflection

    def fit(self, R):
        """Store the orthogonalized version of ``R``."""
        self.R = compute_R(
            R, allow_reflection=self.allow_reflection
        )

        print(f"  Transformation matrix shape: {self.R.shape}")

    def forward(self, X_source):
        """Translate source embeddings of shape (n, d_source) into the target space.

        Raises:
            ValueError: if ``fit`` has not been called.
        """
        if self.R is None:
            raise ValueError("Must call fit() before transform()")

        X_norm, _ = normalize_embeddings(
            X_source, self.normalization, self.source_stats
        )

        # Zero-pad the input when R expects more input dimensions.
        d_in = X_norm.shape[1]
        d_R = self.R.shape[0]
        if d_in < d_R:
            X_norm = F.pad(X_norm, (0, d_R - d_in, 0, 0), mode='constant')

        Y_norm = X_norm @ self.R

        if self.normalization == Normalization.STANDARD:
            d_target = len(self.target_stats['mean'])
        elif self.normalization == Normalization.L2:
            d_target = self.target_stats['dim']
        else:
            # Without statistics keep every output column. The old fallback
            # used R.shape[0], which truncated the output of a rectangular R.
            d_target = Y_norm.shape[1]

        if Y_norm.shape[1] > d_target:
            Y_norm = Y_norm[:, :d_target]

        return denormalize_embeddings(
            Y_norm, self.normalization, self.target_stats
        )

In [None]:
class RefinedProcrustesTranslator(nn.Module):
    """Linear projection, optionally initialized from a Procrustes matrix, plus
    an optional gated residual MLP.

    Args:
        d_in: source embedding dimension.
        d_out: target embedding dimension.
        procrustes_R: optional matrix used to initialize the projection. It is
            assumed to follow the row-vector convention of
            ``ProcrustesTranslator`` (``y = x_padded @ R``) — TODO confirm for
            other sources.
        hidden_dim: width of the refinement MLP (defaults to ``d_out``).
        refinement_type: ``RefinementType.RESIDUAL`` enables the residual MLP.
    """

    def __init__(self, d_in=1024, d_out=1536, procrustes_R=None,
                 hidden_dim=None, refinement_type=RefinementType.RESIDUAL):
        super().__init__()

        if hidden_dim is None:
            hidden_dim = d_out

        self.d_in = d_in
        self.d_out = d_out
        self.refinement_type = refinement_type

        # Always create the projection so forward() works without procrustes_R
        # (it then starts from nn.Linear's default initialization).
        self.linear = nn.Linear(d_in, d_out, bias=False)

        if procrustes_R is not None:
            # nn.Linear computes x @ W.T with W of shape (d_out, d_in), so W must
            # be the transpose of the (d_in, d_out) block of R that x @ R uses.
            R_sub = procrustes_R[:d_in, :d_out].T
            with torch.no_grad():
                # copy_ keeps the parameter's own storage instead of aliasing R.
                self.linear.weight.copy_(R_sub)

            print(f"  Initialized with Procrustes solution (extracted {d_out}×{d_in} from {procrustes_R.shape})")

        if refinement_type == RefinementType.RESIDUAL:
            self.refinement = nn.Sequential(
                nn.Linear(d_out, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.GELU(),
                nn.Dropout(0.1),
                nn.Linear(hidden_dim, hidden_dim // 2),
                nn.LayerNorm(hidden_dim // 2),
                nn.GELU(),
                nn.Dropout(0.1),
                nn.Linear(hidden_dim // 2, d_out)
            )
            # Zero the last layer so the residual branch initially contributes nothing.
            nn.init.zeros_(self.refinement[-1].weight)
            nn.init.zeros_(self.refinement[-1].bias)
            self.residual_weight = nn.Parameter(torch.tensor(0.1))

    def forward(self, x):
        """Project ``x`` and, for RESIDUAL, add the gated refinement."""
        x_proj = self.linear(x)

        if self.refinement_type != RefinementType.RESIDUAL:
            return x_proj

        residual = self.refinement(x_proj)
        # The sigmoid gate is scaled so the residual weight stays below 0.3.
        # NOTE: it is an open question whether the 0.3 cap is needed.
        alpha = torch.sigmoid(self.residual_weight) * 0.3
        return x_proj + alpha * residual

In [None]:
def train_translator(translator, data_keeper, n_anchors=N_ANCHORS, method=ANCHOR_SELECTION):
  """Fit ``translator`` on anchor pairs selected from ``data_keeper.train_data``.

  Args:
    translator: object exposing ``fit(X_anchors, Y_anchors)``.
    data_keeper: ``DataKeeper`` holding the raw training arrays.
    n_anchors: number of caption/image anchor pairs.
    method: ``AnchorSelectionStrategy`` passed to ``select_anchors_diverse``.
  """
  caption_rows, image_rows = select_anchors_diverse(data_keeper, n_anchors, method=method)
  raw = data_keeper.train_data

  def as_device_tensor(array):
    # NumPy rows -> float32 tensor on DEVICE.
    return torch.from_numpy(array).float().to(DEVICE)

  source = as_device_tensor(raw['captions/embeddings'][caption_rows])
  target = as_device_tensor(raw['images/embeddings'][image_rows])
  translator.fit(source, target)


In [None]:
def model_metrics(model, data_keeper, batch_size=BATCH_SIZE, device=DEVICE):
  """Translate the validation captions and compute the retrieval metrics.

  Args:
    model: translator; when it returns a tuple, its first element is taken as
      the prediction (the same convention the ensemble classes use).
    data_keeper: ``DataKeeper`` with ``val_loader``, ``val_img_embd`` and ``val_label``.
    batch_size: unused; kept for interface compatibility (batching comes from ``val_loader``).
    device: device used for inference and scoring.

  Returns:
    The metrics dict produced by ``evaluate_retrieval``.
  """
  # Loop-invariant setup, done once instead of once per batch.
  model.to(device)
  model.eval()

  all_embds = []
  with torch.no_grad():
    for X, _ in data_keeper.val_loader:
      pred_embds = model(X.to(device))
      if isinstance(pred_embds, tuple):
        pred_embds = pred_embds[0]
      all_embds.append(pred_embds)

  all_embds = torch.cat(all_embds, dim=0)
  val_img_embd_on_device = data_keeper.val_img_embd.to(device)
  return evaluate_retrieval(all_embds, val_img_embd_on_device, data_keeper.val_label)

In [None]:
def create_submision(model, data_keeper, batch_size=64, data_path=DATA_PATH, device=DEVICE):
  """Translate the test captions and write ``submission.csv``.

  Args:
    model: translator; when it returns a tuple, its first element is used.
    data_keeper: unused; kept for interface compatibility.
    batch_size: number of captions translated per forward pass.
    data_path: directory containing ``test/test/test.clean.npz``.
    device: device used for inference.

  Returns:
    CPU tensor with one predicted embedding per test caption.
  """
  test_data = load_data(data_path/"test/test/test.clean.npz")
  test_embds = torch.from_numpy(test_data['captions/embeddings']).to(device)

  model.eval()
  batch_preds = []
  with torch.no_grad():
    for start in range(0, len(test_embds), batch_size):
      out = model(test_embds[start:start + batch_size])
      if isinstance(out, tuple):
        out = out[0]
      batch_preds.append(out)

  if batch_preds:
    # Concatenate once (the old code re-concatenated after every batch) and move
    # to the CPU; the legacy torch.Tensor(...) constructor used before is CPU-only.
    pred_embds = torch.cat(batch_preds, dim=0).cpu()
  else:
    # Empty test set: keep the (0, 1536) shape the previous implementation produced.
    pred_embds = torch.empty((0, 1536))

  # NOTE(review): assumes generate_submission accepts a CPU tensor — confirm.
  generate_submission(test_data['captions/ids'], pred_embds, 'submission.csv')
  return pred_embds

In [None]:

def train_model_affine(model, train_loader, epochs=EPOCHS, device=DEVICE, lr=LR, loss_function=F.mse_loss, verbose = True, collect_statistiscs = True):
  """Train ``model`` with Adam on ``(X, y)`` batches.

  Args:
    model: module whose ``forward(X)`` returns predictions comparable to ``y``.
    train_loader: iterable of ``(X_batch, y_batch)`` pairs.
    epochs: number of passes over ``train_loader``.
    device: device the batches are moved to.
    lr: Adam learning rate.
    loss_function: callable ``(outputs, targets) -> scalar loss``.
    verbose: show per-batch progress bars and print the per-epoch loss.
    collect_statistiscs: record per-epoch losses in a ``Statistics`` object.

  Returns:
    ``(model, statistics)``; ``statistics`` is None when collection is disabled.
  """
  optimizer = optim.Adam(model.parameters(), lr=lr)
  training_statistic = None
  if collect_statistiscs:
    training_statistic = Statistics()
    # Give this run its own list so losses cannot accumulate across runs, even
    # when Statistics keeps its lists at class level.
    training_statistic.losses = []

  for epoch in tqdm(range(epochs), desc="Trainnig"):
    model.train()
    train_loss = 0
    batches = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}") if verbose else train_loader
    for X_batch, y_batch in batches:
      X_batch, y_batch = X_batch.to(device), y_batch.to(device)
      optimizer.zero_grad()
      outputs = model(X_batch)
      loss = loss_function(outputs, y_batch)
      loss.backward()
      optimizer.step()
      train_loss += loss.item()
    # max(..., 1) avoids a ZeroDivisionError on an empty loader.
    train_loss /= max(len(train_loader), 1)
    if verbose:
      print(f"Epoch {epoch+1}/{epochs} - Train loss: {train_loss}")
    if collect_statistiscs:
      training_statistic.losses.append(train_loss)

  # With epochs == 0 there is no loss to summarize; keep the Statistics defaults.
  if collect_statistiscs and training_statistic.losses:
    training_statistic.best_loss = min(training_statistic.losses)
    training_statistic.best_epoch_index = training_statistic.losses.index(training_statistic.best_loss)
  return model, training_statistic

## Model training

In [None]:
# Sequential split: the first 90% of the rows train, the remaining 10% validate.
data_keeper = get_train_data(split_ratio=0.9)

In [None]:
# Orthogonal Procrustes baseline with L2 normalization (NORMALIZATIONS[1]),
# fitted on 120000 uniformly spaced anchor pairs.
ptranslator = ProcrustesTranslator(
    normalization=NORMALIZATIONS[1]
)
train_translator(ptranslator, data_keeper, n_anchors=120000)

In [None]:
# Retrieval metrics of the Procrustes translator on the validation split.
metrics = model_metrics(ptranslator, data_keeper)
print(metrics)

In [None]:
# Unconstrained least-squares map with L2 normalization, fitted on the same number of anchors.
ltranslator = LinearTranslator(
    normalization=NORMALIZATIONS[1]
)
train_translator(ltranslator, data_keeper, n_anchors=120000)

In [None]:
# Retrieval metrics of the least-squares translator on the validation split.
metrics = model_metrics(ltranslator, data_keeper)
print(metrics)

In [None]:
# Orthogonalize the least-squares solution: reuse ltranslator's matrix and
# target statistics, then project the matrix with compute_R.
lotranslator = LOrthoTranslator(
    normalization=NORMALIZATIONS[1],
    target_stats=ltranslator.target_stats
)
lotranslator.fit(ltranslator.R)

In [None]:
# Retrieval metrics of the orthogonalized translator on the validation split.
metrics = model_metrics(lotranslator, data_keeper)
print(metrics)

In [None]:
# Trainable affine baseline; NORMALIZATIONS[0] is Normalization.STANDARD.
atranslator = AffineTranslator(normalization=NORMALIZATIONS[0]).to(DEVICE)

In [None]:
# Train with the defaults of train_model_affine (EPOCHS, LR, MSE loss) and no progress output.
model, stats = train_model_affine(atranslator, data_keeper.train_loader, verbose=False)

In [None]:
# Retrieval metrics of the affine translator on the validation split.
metrics = model_metrics(atranslator, data_keeper)
print(metrics)

### Best model training

In [None]:
# NOTE(review): the cells below reuse the already fitted ltranslator.R, so
# rebuilding data_keeper here does not retrain anything — confirm the intent.
data_keeper = get_train_data(split_ratio=1.0)  # using the whole dataset for training

In [None]:
# Rebuild the orthogonalized translator from ltranslator's matrix and target statistics.
lotranslator = LOrthoTranslator(
    normalization=NORMALIZATIONS[1],
    target_stats=ltranslator.target_stats
)
lotranslator.fit(ltranslator.R)

In [None]:
# Translate the test captions with the orthogonalized translator and write submission.csv.
create_submision(lotranslator, data_keeper=data_keeper)

# Optuna experiments used to find the best model parameters

In [None]:
# Restore the 90/10 split so the Optuna objective has a validation set to score on.
data_keeper = get_train_data(split_ratio=0.9)

In [None]:
def objective(trial):
    """Optuna objective: validation MRR of the orthogonalized least-squares translator.

    Searched parameters: number of anchors, normalization scheme and anchor
    selection strategy. Reflections are always allowed.

    Args:
        trial: the ``optuna`` trial supplying the parameter suggestions.

    Returns:
        The ``'mrr'`` entry of the validation metrics (to be maximized).
    """
    n_anchors = trial.suggest_int('n_anchors', 6000, 200000, step=1536)
    normalization = trial.suggest_categorical(
        'normalization', [Normalization.STANDARD, Normalization.L2])
    anchor_selection = trial.suggest_categorical(
        'anchor_selection',
        [AnchorSelectionStrategy.RANDOM, AnchorSelectionStrategy.UNIFORM])
    allow_reflection = True

    ltranslator = LinearTranslator(normalization=normalization)
    train_translator(ltranslator, data_keeper, n_anchors=n_anchors, method=anchor_selection)

    lotranslator = LOrthoTranslator(
        normalization=normalization,
        target_stats=ltranslator.target_stats,
        allow_reflection=allow_reflection
    )
    # Reuse the source statistics estimated on the anchors. Without them,
    # STANDARD normalization re-estimates mean/std from every validation batch,
    # which does not match the statistics R was fitted with.
    lotranslator.source_stats = ltranslator.source_stats
    lotranslator.fit(ltranslator.R)

    metrics = model_metrics(lotranslator, data_keeper)
    print(metrics)
    return metrics['mrr']

In [None]:
# Maximize the objective's return value (validation MRR) over 60 trials.
study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=60)

In [None]:
# Export every trial (parameters and value) to a CSV file in the working directory.
df = study.trials_dataframe()
df.to_csv(WORKING_DIR/"optuna_results_lortho.csv", index=False)

In [None]:
import matplotlib.pyplot as plt
# 2x2 grid of scatter plots of MRR against each searched parameter;
# the bottom-right axes is left empty.
fig, ax = plt.subplots(ncols=2, nrows=2, figsize=(12, 5))
# The categorical parameters hold Enum members, so plot them by member name.
ax[0][0].scatter(df['params_normalization'].apply(lambda x: x.name), df['value'])
ax[0][0].set_xlabel('Normalization Type')
ax[0][0].set_ylabel('MRR')
ax[0][0].tick_params(axis='x', rotation=45)

ax[0][1].scatter(df['params_n_anchors'], df['value'])
ax[0][1].set_xlabel('Number of Anchors')
ax[0][1].set_ylabel('MRR')

ax[1][0].scatter(df['params_anchor_selection'].apply(lambda x: x.name), df['value'])
ax[1][0].set_xlabel('anchor_selection Type')
ax[1][0].set_ylabel('MRR')
ax[1][0].tick_params(axis='x', rotation=45)

plt.tight_layout()
plt.show()

# Ensembling models with VAE

### VAE

In [None]:
class VAEAdapter(nn.Module):
    """VAE-style translator from text embeddings to image embeddings.

    The encoder emits ``2 * latent_dim`` values per input, split into ``mu``
    and ``logvar``; a latent sample is decoded and L2-normalized.

    Args:
        input_dim (int): dimension of the input text embeddings.
        output_dim (int): dimension of the output image embeddings.
        hidden_dim (int): width of the hidden layer in encoder and decoder.
        latent_dim (int): dimension of the latent space.

    NOTE(review): ``forward`` samples noise regardless of train/eval mode, so
    inference is stochastic — confirm this is intended.
    """
    def __init__(self, input_dim, output_dim, hidden_dim, latent_dim):
        super().__init__()

        def two_layer(n_in, n_out):
            # Linear -> GELU -> LayerNorm -> Linear, shared by encoder and decoder.
            return nn.Sequential(
                nn.Linear(n_in, hidden_dim),
                nn.GELU(),
                nn.LayerNorm(hidden_dim),
                nn.Linear(hidden_dim, n_out),
            )

        self.encoder = two_layer(input_dim, latent_dim * 2)
        self.decoder = two_layer(latent_dim, output_dim)

    def reparameterize(self, mu, logvar):
        """Draw ``mu + eps * std`` with ``eps ~ N(0, I)`` and ``std = exp(0.5 * logvar)``."""
        eps = torch.randn_like(mu)
        std = torch.exp(0.5 * logvar)
        return mu + eps * std

    def forward(self, x):
        """Encode, sample and decode.

        Returns:
            ``(out, mu, logvar)`` where ``out`` is L2-normalized along the last dim.
        """
        encoded = self.encoder(F.normalize(x, dim=-1))
        mu, logvar = encoded.chunk(2, dim=-1)
        latent = self.reparameterize(mu, logvar)
        decoded = self.decoder(latent)
        return F.normalize(decoded, dim=-1), mu, logvar


def init_weights(m):
    """Kaiming-normal weights and zero bias for ``nn.Linear`` modules.

    Any other module type is left untouched, so the function is safe to pass
    to ``model.apply``.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.kaiming_normal_(m.weight)
    if m.bias is not None:
        nn.init.zeros_(m.bias)

# Name -> model class lookup.
# NOTE(review): no visible cell reads this registry — confirm it is still needed.
MODEL_REGISTRY = {"vae_adapter": VAEAdapter}

class VAELoss:
    """Contrastive reconstruction loss (``ClipLoss``) plus a weighted KL term.

    Args:
        kld_weight: multiplier applied to the KL divergence term.
        device: device the ``ClipLoss`` module is moved to.
    """
    def __init__(self, kld_weight=1e-3, device=DEVICE):
        self.kld_weight = kld_weight
        self.clip = ClipLoss().to(device)

    def __call__(self, pred, mu, logvar, y, logit_scale):
        """Return ``clip_loss + kld_weight * kld``.

        ``y`` is L2-normalized before the contrastive loss; ``pred`` is used as
        given. The KL term is summed over dim 1 and averaged over the batch.
        """
        target = F.normalize(y, dim=-1)
        contrastive = self.clip(
            image_features=target,
            text_features=pred,
            logit_scale=logit_scale,
        )
        per_sample = torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
        kl_term = -0.5 * per_sample.mean()
        return contrastive + self.kld_weight * kl_term

@torch.inference_mode()
def validate_repo_metrics(model, data_keeper, val_gallery, val_labels, device=DEVICE):
    """Compute the retrieval metrics (MRR, Recall@k, ...) of a tuple-returning model.

    Args:
        model: module whose forward returns a tuple with the prediction first.
        data_keeper: ``DataKeeper`` with ``val_loader``, ``val_img_embd`` and ``val_label``.
        val_gallery: unused; the gallery is read from ``data_keeper``.
        val_labels: unused; the labels are read from ``data_keeper``.
        device: device the input batches are moved to.

    Returns:
        The metrics dict produced by ``evaluate_retrieval``.
    """
    model.eval()
    preds = [model(X.to(device))[0] for X, _ in tqdm(data_keeper.val_loader, desc="[Val: Metrics]")]
    preds = torch.cat(preds)
    # Score against a gallery on the predictions' device; the stored gallery is
    # built with torch.from_numpy and therefore lives on the CPU.
    gallery = data_keeper.val_img_embd.to(preds.device)
    return evaluate_retrieval(preds, gallery, data_keeper.val_label, max_indices=100)

@torch.inference_mode()
def validate_cliploss(model, data_keeper, loss_fn, logit_scale, device=DEVICE):
    """Average ``loss_fn`` over the validation loader.

    Args:
        model: module returning ``(pred, mu, logvar)``.
        data_keeper: ``DataKeeper`` providing ``val_loader``.
        loss_fn: callable ``(pred, mu, logvar, y, logit_scale) -> scalar tensor``.
        logit_scale: passed to ``loss_fn`` unchanged.
        device: device the batches are moved to.

    Returns:
        ``{"val_loss": mean per-batch loss}``.
    """
    model.eval()
    loader = data_keeper.val_loader
    running = 0
    for X, y in tqdm(loader, desc="[Val: ClipLoss]"):
        X = X.to(device)
        y = y.to(device)
        pred, mu, logvar = model(X)
        batch_loss = loss_fn(pred, mu, logvar, y, logit_scale)
        running += batch_loss.item()
    return {"val_loss": running / len(data_keeper.val_loader)}

def train_model(model, data_keeper, epochs, loss_fn,lr=LR, device=DEVICE):
    """Train a tuple-returning model together with a learnable logit scale.

    Adam (weight decay 1e-4) optimizes the model parameters and ``logit_scale``
    under a cosine-annealed learning rate stepped once per epoch.

    Args:
        model: module returning ``(pred, mu, logvar)``.
        data_keeper: ``DataKeeper`` providing ``train_loader``.
        epochs: number of passes over the training loader.
        loss_fn: callable ``(pred, mu, logvar, y, logit_scale) -> scalar tensor``.
        lr: initial learning rate.
        device: device for the batches and for ``logit_scale``.

    Returns:
        ``(model, logit_scale)``; ``logit_scale`` is the raw log-space parameter.
    """
    # Stored in log space and initialized to log(1 / 0.07).
    initial_log_scale = np.log(1 / 0.07)
    logit_scale = nn.Parameter(torch.ones([], device=device) * initial_log_scale)

    trainable = list(model.parameters()) + [logit_scale]
    optimizer = torch.optim.Adam(trainable, lr=lr, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    for epoch in range(epochs):
        model.train()
        running = 0
        progress = tqdm(data_keeper.train_loader, desc=f"Epoch {epoch+1}/{epochs} [Train]")
        for X, y in progress:
            X = X.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            pred, mu, logvar = model(X)
            # The loss receives the exponentiated scale, clamped to [1, 100].
            scale = logit_scale.exp().clamp(1, 100)
            loss = loss_fn(pred, mu, logvar, y, scale)
            loss.backward()
            optimizer.step()
            running += loss.item()

        avg_loss = running / len(data_keeper.train_loader)
        scheduler.step()
        tqdm.write(f"Epoch {epoch+1}: Average Train Loss = {avg_loss:.6f}")
    return model, logit_scale


In [None]:
# Train the VAE adapter on the 90/10 split with the parameter set recorded below.
data_keeper = get_train_data(split_ratio=0.9)

# {'hidden_dim': 1024, 'latent_dim': 2560, 'kld_weight': 0.0001, 'batch_size': 352}

# Rebuild the loaders with the recorded batch size.
data_keeper.create_loaders(352)

model = VAEAdapter(input_dim=1024, output_dim=1536, hidden_dim=1024, latent_dim=2560).to(DEVICE)
model.apply(init_weights)
loss_fn = VAELoss(kld_weight=0.0001, device=DEVICE)
model, logit_scale = train_model(model, data_keeper=data_keeper,epochs=10, loss_fn=loss_fn)

### ensembling

In [None]:
class DynamicEnsemble(nn.Module):
    """Input-dependent convex combination of two frozen models.

    A small gate network maps each input to two softmax weights, one per model;
    each weight scales the whole output vector of its model. Both models are
    deep-copied and frozen, so only the gate is trainable.

    Args:
        model1: first model; it may return a tuple, whose first element is used.
        model2: second model; it must return a tensor.
        input_dim: dimension of the inputs fed to the gate.
        hidden_dim: width of the gate's hidden layer.
    """
    def __init__(self, model1, model2, input_dim, hidden_dim=32):
        super().__init__()
        self.model1 = copy.deepcopy(model1)
        self.model2 = copy.deepcopy(model2)
        self.gate = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2)
        )
        # Freeze both copies.
        for frozen in (self.model1, self.model2):
            for p in frozen.parameters():
                p.requires_grad = False

    def forward(self, x):
        """Return ``w1(x) * model1(x) + w2(x) * model2(x)``."""
        first = self.model1(x)
        y1 = first[0] if isinstance(first, tuple) else first
        y2 = self.model2(x)

        mix = F.softmax(self.gate(x), dim=1)
        w1, w2 = mix[:, 0:1], mix[:, 1:]
        return w1 * y1 + w2 * y2

In [None]:
class DynamicEnsembleV2(nn.Module):
    """Per-dimension, input-dependent convex combination of two frozen models.

    Two gate networks each emit one logit per output dimension; a softmax across
    the two gates yields, for every output dimension, weights that sum to one.
    Both models are deep-copied and frozen.

    Args:
        model1: first model; it may return a tuple, whose first element is used.
        model2: second model; it must return a tensor.
        input_dim: dimension of the inputs fed to the gates.
        output_dim: dimension of the model outputs.
        hidden_dim: width of each gate's hidden layer.
    """
    def __init__(self, model1, model2, input_dim=1024, output_dim=1536, hidden_dim=2048):
        super().__init__()
        self.model1 = copy.deepcopy(model1)
        self.model2 = copy.deepcopy(model2)

        def make_gate():
            return nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, output_dim)
            )

        self.gate_1 = make_gate()
        self.gate_2 = make_gate()

        # Freeze both copies.
        for frozen in (self.model1, self.model2):
            for p in frozen.parameters():
                p.requires_grad = False

    def forward(self, x):
        """Return the per-dimension weighted sum of both model outputs."""
        first = self.model1(x)
        y1 = first[0] if isinstance(first, tuple) else first
        y2 = self.model2(x)

        # Shape (batch, 2, output_dim); softmax over the two-gate axis.
        logits = torch.stack((self.gate_1(x), self.gate_2(x)), dim=1)
        mix = F.softmax(logits, dim=1)
        return mix[:, 0, :] * y1 + mix[:, 1, :] * y2

In [None]:
class DynamicEnsembleSimple(nn.Module):
    """Convex combination of two frozen models with two learnable scalar logits.

    The weights are the softmax of a 2-element parameter initialized to
    ``[0.5, 0.5]``, so the ensemble starts as a plain average. Both models are
    deep-copied and frozen.

    Args:
        model1: first model; it may return a tuple, whose first element is used.
        model2: second model; it must return a tensor.
    """
    def __init__(self, model1, model2):
        super().__init__()
        self.model1 = copy.deepcopy(model1)
        self.model2 = copy.deepcopy(model2)
        self.weights = nn.Parameter(torch.tensor([0.5, 0.5]))
        # Freeze both copies.
        for frozen in (self.model1, self.model2):
            for p in frozen.parameters():
                p.requires_grad = False

    def forward(self, x):
        """Return ``w[0] * model1(x) + w[1] * model2(x)`` with ``w = softmax(weights)``."""
        first = self.model1(x)
        y1 = first[0] if isinstance(first, tuple) else first
        y2 = self.model2(x)

        w1, w2 = F.softmax(self.weights, dim=0)
        return w1 * y1 + w2 * y2

In [None]:
class DynamicEnsembleSimpleV2(nn.Module):
    """Convex combination of two frozen models with one learnable weight pair
    per output dimension.

    ``weights`` has shape ``(2, output_dim)`` and starts at 0.5 everywhere, so
    the softmax over its first axis initially gives a plain average. Both models
    are deep-copied and frozen.

    Args:
        model1: first model; it may return a tuple, whose first element is used.
        model2: second model; it must return a tensor.
        output_dim: dimension of the model outputs.
    """
    def __init__(self, model1, model2, output_dim=1536):
        super().__init__()
        self.model1 = copy.deepcopy(model1)
        self.model2 = copy.deepcopy(model2)
        self.weights = nn.Parameter(torch.ones(2, output_dim) * 0.5)
        # Freeze both copies.
        for frozen in (self.model1, self.model2):
            for p in frozen.parameters():
                p.requires_grad = False

    def forward(self, x):
        """Return the per-dimension weighted sum of both model outputs."""
        first = self.model1(x)
        y1 = first[0] if isinstance(first, tuple) else first
        y2 = self.model2(x)

        mix = F.softmax(self.weights, dim=0)
        return mix[0] * y1 + mix[1] * y2

In [None]:
# Keep a dedicated name for the trained VAE adapter; `model` is a generic name other cells rebind.
model_vae = model

In [None]:
# Gated ensemble of the VAE adapter and the orthogonalized translator (both deep-copied and frozen).
ensenmble_model = DynamicEnsemble(model_vae, lotranslator, input_dim=1024).to(DEVICE)


In [None]:
# Train for 10 epochs with the default MSE loss. The second return value is the
# training Statistics object, bound to `metrics` here and overwritten by the next cell.
ensenmble_model, metrics = train_model_affine(ensenmble_model, data_keeper.train_loader, epochs=10, verbose=True)

In [None]:
# Retrieval metrics of the trained gated ensemble on the validation split.
metrics = model_metrics(ensenmble_model, data_keeper)

In [None]:
# Show the metrics of the gated ensemble.
print(metrics)

In [None]:
# Per-dimension gated ensemble.
# NOTE(review): no training cell follows, so it is evaluated with randomly initialized gates — confirm.
ensenmble_model = DynamicEnsembleV2(model_vae, lotranslator, input_dim=1024, hidden_dim=320).to(DEVICE)


In [None]:
# Retrieval metrics of the DynamicEnsembleV2 instance built above.
metrics = model_metrics(ensenmble_model, data_keeper)

In [None]:
# Show the metrics of the DynamicEnsembleV2 instance.
print(metrics)

In [None]:
# Per-dimension learnable weights; untrained, they amount to a plain average of both models.
ensenmble_model = DynamicEnsembleSimpleV2(model_vae, lotranslator).to(DEVICE)


In [None]:
# Retrieval metrics of the untrained DynamicEnsembleSimpleV2.
# NOTE(review): the result is never printed before being overwritten below.
metrics = model_metrics(ensenmble_model, data_keeper)

In [None]:
# Ensemble with two learnable scalar weights (softmax-normalized).
ensenmble_model = DynamicEnsembleSimple(model_vae, lotranslator).to(DEVICE)


In [None]:
# Train the two scalar weights for 10 epochs with the default MSE loss;
# `metrics` temporarily holds the training Statistics object.
ensenmble_model, metrics = train_model_affine(ensenmble_model, data_keeper.train_loader, epochs=10, verbose=True)

In [None]:
# Retrieval metrics of the trained scalar-weight ensemble on the validation split.
metrics = model_metrics(ensenmble_model, data_keeper)

In [None]:
# Show the metrics of the scalar-weight ensemble.
print(metrics)

#### Comment

Ensembling these models doesn't seem to bring any positive result.