# Initialize

In [None]:
import os
from google.colab import userdata
# Copy the Colab secrets into environment variables for the shell commands in
# the next cell (`$GIT_TOKEN` is read there; the kaggle CLI presumably reads the
# KAGGLE_* variables).
for _secret in ('KAGGLE_USERNAME', 'KAGGLE_KEY', 'GIT_TOKEN'):
    os.environ[_secret] = userdata.get(_secret)
del _secret

In [None]:
# NOTE(review): `repo` is not defined anywhere in this notebook, so `$repo`
# presumably reaches the shell unexpanded, becomes an empty string and this
# first clone fails - define e.g. `repo = "<host>/<owner>/<name>.git"` in a
# Python cell beforehand. The token will presumably also be stored in the
# cloned repo's remote URL (.git/config).
!git clone https://$GIT_TOKEN@$repo
# Download and extract the competition data into ./data, then fetch the
# challenge helper package imported below.
!kaggle competitions download -c aml-competition
!unzip -o aml-competition.zip -d data
!git clone https://github.com/Mamiglia/challenge.git

In [None]:
from typing import List
import sys
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from pathlib import Path
from tqdm.auto import tqdm
import torch.nn.functional as F
import numpy as np
from enum import Enum
import math
import random
import matplotlib.pyplot as plt

In [None]:
from challenge.src.common import load_data, prepare_train_data, generate_submission
from challenge.src.eval import visualize_retrieval, evaluate_retrieval
from challenge.src.eval.metrics import recall_at_k, ndcg,mrr

In [None]:
%pip install optuna

In [None]:
import optuna

In [None]:
!pip install torchdiffeq

In [None]:
!git clone https://github.com/qihao067/CrossFlow.git

In [None]:
def set_seed(seed=42):
    """Seed PyTorch's, NumPy's global and Python's `random` RNGs with `seed`.

    This makes random draws repeatable; it does not change cuDNN or
    `torch.use_deterministic_algorithms` settings.
    """
    for seeder in (torch.manual_seed, np.random.seed, random.seed):
        seeder(seed)


set_seed()

In [None]:
# Root for all relative paths in this notebook: the kernel's current directory.
WORKING_DIR = Path(os.getcwd())

In [None]:
# Checkpoints live under the working directory rather than the filesystem root:
# writing to "/" needs root privileges and does not carry over outside Colab.
MODELS_DIR = WORKING_DIR / "models"
MODELS_DIR.mkdir(parents=True, exist_ok=True)
MODEL_PATH = MODELS_DIR / "default.pth"

In [None]:
# Extraction target of the Kaggle archive (`unzip ... -d data`).
DATA_PATH = WORKING_DIR.joinpath("data")

In [None]:
# Default training hyper-parameters; individual experiments below override
# some of them (e.g. their own batch size and epoch count).
EPOCHS, BATCH_SIZE = 20, 256
LR = 1e-3
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Make the cloned CrossFlow repo importable. NOTE(review): the import below goes
# through the `CrossFlow` package path (resolved from the working directory);
# appending the repo directory itself presumably lets CrossFlow's own internal
# absolute imports resolve - confirm. `SigLipLoss` is not used in this notebook.
sys.path.append(str(WORKING_DIR / "CrossFlow"))
from CrossFlow.diffusion.flow_matching import ClipLoss, SigLipLoss

# Actual code

## Basic definitions

In [None]:
class Statistics():
  """Training-curve bookkeeping for one run.

  The loss lists are created in ``__init__`` so that each instance owns its
  own lists; declared as class attributes (``losses = []`` in the class body)
  they would be a single list object shared, and appended to, by every
  ``Statistics`` instance.
  """
  # Scalars are immutable, so class-level defaults are safe to share.
  best_loss = float("inf")
  best_epoch_index = -1

  def __init__(self):
    self.losses = []
    self.mse_losses = []
    self.cos_losses = []
    self.contrastive_losses = []

In [None]:
class DataKeeper():
  """Holds the train/validation split and the DataLoaders built from it.

  Populated by ``get_train_data``; every attribute starts as ``None``.
  """
  train_data = None        # raw object returned by ``load_data``
  train_loader = None
  val_loader = None
  train_dataset = None     # TensorDataset of (input, target) pairs, presumably (text, image) embeddings
  val_dataset = None
  val_caption_text = None  # caption strings of the validation split
  val_text_embd = None     # validation model inputs
  val_img_file = None      # file names of the images referenced by validation captions
  val_img_embd = None      # embeddings of those images (the retrieval gallery)
  val_label = None         # per validation caption: column index of its image in the gallery
  img_VAL_SPLIT = None     # boolean mask over all images that selects the gallery

  def create_loaders(self, batch_size = BATCH_SIZE):
    """(Re)build ``train_loader`` (shuffled) and ``val_loader`` (in order) from the stored datasets."""
    self.train_loader = DataLoader(self.train_dataset, batch_size=batch_size, shuffle=True)
    self.val_loader = DataLoader(self.val_dataset, batch_size=batch_size)

In [None]:
def get_train_data(data_path=DATA_PATH, split_ratio = 0.8, batch_size = BATCH_SIZE):
  """Load the training archive and split it into train / validation parts.

  The split is positional, not random: the first ``split_ratio`` fraction of
  captions is used for training and the rest for validation. The validation
  image gallery consists of every image referenced by at least one
  validation caption.

  Args:
    data_path: directory containing ``train/train/train.npz``.
    split_ratio: fraction of captions used for training, in (0, 1].
      ``1.0`` trains on everything and leaves the validation split empty.
    batch_size: batch size of both DataLoaders.

  Returns:
    A populated ``DataKeeper``.

  Raises:
    ValueError: if ``split_ratio`` is outside (0, 1].
  """
  if not 0.0 < split_ratio <= 1.0:
    raise ValueError(f"split_ratio must be in (0, 1], got {split_ratio}")

  data_keeper = DataKeeper()
  data_keeper.train_data = load_data(data_path/"train/train/train.npz")
  X, y, label = prepare_train_data(data_keeper.train_data)

  train_size = int(len(X) * split_ratio)
  train_mask = torch.zeros(len(X), dtype=torch.bool)
  train_mask[:train_size] = True
  val_mask = ~train_mask

  X_train, X_val = X[train_mask], X[val_mask]
  y_train, y_val = y[train_mask], y[val_mask]
  data_keeper.train_dataset = TensorDataset(X_train, y_train)
  data_keeper.val_dataset = TensorDataset(X_val, y_val)
  data_keeper.create_loaders(batch_size=batch_size)

  # Images referenced by at least one validation caption form the retrieval gallery.
  img_val_mask = label[val_mask].sum(dim=0) > 0
  data_keeper.img_VAL_SPLIT = img_val_mask
  data_keeper.val_caption_text = data_keeper.train_data['captions/text'][val_mask]
  data_keeper.val_text_embd = X_val
  data_keeper.val_img_file = data_keeper.train_data['images/names'][img_val_mask]
  data_keeper.val_img_embd = torch.from_numpy(data_keeper.train_data['images/embeddings'][img_val_mask])
  # Column index of each validation caption's image inside the gallery
  # (assumes exactly one positive label per caption - TODO confirm).
  data_keeper.val_label = np.nonzero(data_keeper.train_data['captions/label'][val_mask][:, img_val_mask])[1]
  return data_keeper


In [None]:
def save_model(model, path=MODEL_PATH, verbose=True):
  """Save ``model.state_dict()`` to ``path``.

  Args:
    model: module whose parameters and buffers are saved.
    path: destination file (or file-like object); for filesystem paths,
      missing parent directories are created.
    verbose: print a confirmation line when True (same as ``load_model``).
  """
  if isinstance(path, (str, os.PathLike)):
    Path(path).parent.mkdir(parents=True, exist_ok=True)
  torch.save(model.state_dict(), path)
  if verbose:
    print(f"Model saved to {path}")

In [None]:
def load_model(model, path=MODEL_PATH, verbose=True, map_location="cpu"):
  """Load a state dict saved by ``save_model`` into ``model`` (in place).

  Args:
    model: module with the same architecture as the saved one.
    path: checkpoint file. Note that ``torch.load`` unpickles it, so only
      load checkpoints from trusted sources.
    verbose: print a confirmation line when True.
    map_location: where ``torch.load`` materialises the tensors before they
      are copied into ``model``. The CPU default lets a checkpoint saved on a
      GPU load on a CPU-only runtime; ``load_state_dict`` copies the values
      into the model's existing parameters, which keep their own device.

  Returns:
    The same ``model`` instance.
  """
  state_dict = torch.load(path, map_location=map_location)
  model.load_state_dict(state_dict)
  if verbose:
    print(f"Model loaded from {path}")
  return model

In [None]:
def sample_and_visualize(model, data_keeper, device=DEVICE, number_of_indices=5, dataset_path=WORKING_DIR/"data/train/train"):
  """Show retrieval results for a few randomly chosen validation captions.

  For each sampled caption, the model's predicted image embedding is passed to
  ``visualize_retrieval`` together with the validation gallery.

  Args:
    model: translator returning either the embedding or a tuple whose first
      element is the embedding (``VAEAdapter`` returns ``(out, mu, logvar)``).
    data_keeper: populated ``DataKeeper`` with a non-empty validation split.
    device: device the model lives on.
    number_of_indices: how many captions to visualise.
    dataset_path: image root passed through to ``visualize_retrieval``.

  Raises:
    ValueError: if the validation split is empty.
  """
  n_val = len(data_keeper.val_text_embd)
  if n_val == 0:
    raise ValueError("validation split is empty; use get_train_data(split_ratio < 1)")
  model.eval()
  for _ in range(number_of_indices):
    # Sample across the whole validation split.
    index = np.random.randint(0, n_val)
    caption_embd = data_keeper.val_text_embd[index]
    caption_text = data_keeper.val_caption_text[index]
    gt_index = data_keeper.val_label[index]
    with torch.no_grad():
      output = model(caption_embd.to(device))
      # A tuple has no .to(); keep only the embedding.
      pred_embds = output[0] if isinstance(output, tuple) else output
      visualize_retrieval(
        pred_embds,
        gt_index,
        data_keeper.val_img_file,
        caption_text, data_keeper.val_img_embd, k=5, dataset_path=dataset_path)


In [None]:
@torch.inference_mode()
def evaluate_retrieval(translated_embd, image_embd, gt_indices, max_indices = 99, batch_size=100):
    """
    Rank gallery images for each query embedding and compute retrieval metrics.

    [FIXED VERSION of challenge/src/eval/eval.py]
    The repo version had a bug: it called ``.numpy()`` on GPU top-k indices.
    This version also moves the gallery to the queries' device and dtype
    before the similarity product, and caps ``k`` at the gallery size.

    Args:
        translated_embd: (n_queries, d) predicted image embeddings (torch or numpy).
        image_embd: (n_images, d) gallery embeddings (torch or numpy).
        gt_indices: for each query, the gallery index of its correct image.
        max_indices: number of top-ranked images kept per query (capped at n_images).
        batch_size: number of queries scored per similarity product.

    Returns:
        dict mapping 'mrr', 'ndcg' and 'recall_at_{1,3,5,10,50}' to the value
        returned by the corresponding challenge metric.
    """
    if isinstance(translated_embd, np.ndarray):
        translated_embd = torch.from_numpy(translated_embd).float()
    if isinstance(image_embd, np.ndarray):
        image_embd = torch.from_numpy(image_embd).float()

    # The similarity product needs both operands on one device with one dtype
    # (e.g. GPU predictions against a CPU gallery would otherwise raise).
    image_embd = image_embd.to(device=translated_embd.device, dtype=translated_embd.dtype)
    # topk raises if k exceeds the number of gallery images.
    k = min(max_indices, image_embd.shape[0])

    n_queries = translated_embd.shape[0]
    all_sorted_indices = []

    for start_idx in range(0, n_queries, batch_size):
        batch_translated = translated_embd[start_idx:start_idx + batch_size]
        batch_similarity = batch_translated @ image_embd.T
        # .cpu() before .numpy(): .numpy() is not available for CUDA tensors.
        batch_indices = batch_similarity.topk(k=k, dim=1, sorted=True).indices.cpu().numpy()
        all_sorted_indices.append(batch_indices)

    sorted_indices = np.concatenate(all_sorted_indices, axis=0)
    metrics = {
        'mrr': mrr,
        'ndcg': ndcg,
        'recall_at_1': lambda preds, gt: recall_at_k(preds, gt, 1),
        'recall_at_3': lambda preds, gt: recall_at_k(preds, gt, 3),
        'recall_at_5': lambda preds, gt: recall_at_k(preds, gt, 5),
        'recall_at_10': lambda preds, gt: recall_at_k(preds, gt, 10),
        'recall_at_50': lambda preds, gt: recall_at_k(preds, gt, 50),
    }

    return {
        name: func(sorted_indices, gt_indices)
        for name, func in metrics.items()
    }

## Model definitions

In [None]:
class VAEAdapter(nn.Module):
    """
    A simple VAE-based translator from text embeddings to image embeddings.

    The input is L2-normalised and encoded into a diagonal Gaussian posterior
    (``mu``, ``logvar``). A latent from it is decoded, and the output is
    L2-normalised again. In training mode the latent is sampled with the
    reparameterization trick. In eval mode the posterior mean is used, so
    predictions (validation metrics, submissions) are deterministic.

    Args:
        input_dim (int): Dimension of the input text embeddings.
        output_dim (int): Dimension of the output image embeddings.
        hidden_dim (int): Hidden layer width.
        latent_dim (int): Latent space dimensionality.
    """
    def __init__(self, input_dim, output_dim, hidden_dim, latent_dim):
        super().__init__()
        # Outputs 2 * latent_dim values, split into mu and logvar in forward().
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, latent_dim * 2)
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.GELU(),
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, output_dim)
        )

    def reparameterize(self, mu, logvar):
        """Differentiable sample z = mu + eps * exp(0.5 * logvar), eps ~ N(0, I)."""
        return mu + torch.randn_like(mu) * torch.exp(0.5 * logvar)

    def forward(self, x):
        """Encode -> sample (training) or take the mean (eval) -> decode.

        Returns:
            (L2-normalised output embedding, mu, logvar)
        """
        x = F.normalize(x, dim=-1)
        mu, logvar = self.encoder(x).chunk(2, dim=-1)
        z = self.reparameterize(mu, logvar) if self.training else mu
        out = self.decoder(z)
        return F.normalize(out, dim=-1), mu, logvar


def init_weights(m):
    """Give ``nn.Linear`` modules Kaiming-normal weights and zero biases.

    Intended for ``model.apply(init_weights)``; other module types are left untouched.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.kaiming_normal_(m.weight)
    if m.bias is None:
        return
    nn.init.constant_(m.bias, 0)

# Maps a config name to its model class.
MODEL_REGISTRY = dict(vae_adapter=VAEAdapter)


In [None]:
class VAELoss:
    """
    Training objective for ``VAEAdapter``. It combines two terms:

    * the CrossFlow ``ClipLoss`` between predictions and L2-normalised targets
      (presumably a CLIP-style contrastive loss);
    * a ``kld_weight``-scaled KL divergence of N(mu, exp(logvar)) from
      N(0, I), summed over latent dimensions and averaged over the batch.

    Args:
        kld_weight: multiplier of the KL term.
        device: device the wrapped ``ClipLoss`` module is moved to.
    """
    def __init__(self, kld_weight=1e-3, device=DEVICE):
        self.kld_weight = kld_weight
        self.clip = ClipLoss().to(device)

    def __call__(self, pred, mu, logvar, y, logit_scale):
        """Return clip(pred, y) + kld_weight * KL.

        ``logit_scale`` is forwarded unchanged to ``ClipLoss``. ``train_model``
        passes ``logit_scale.exp().clamp(1, 100)``.
        """
        target = F.normalize(y, dim=-1)
        contrastive = self.clip(image_features=target, text_features=pred, logit_scale=logit_scale)
        kl_terms = torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
        kl = -0.5 * kl_terms.mean()
        return contrastive + self.kld_weight * kl



In [None]:
@torch.inference_mode()
def validate_repo_metrics(model, data_keeper, val_gallery=None, val_labels=None, device=DEVICE):
    """
    Evaluate using the challenge’s official retrieval metrics (MRR, Recall@k, etc.).

    Args:
        model: translator returning ``(embedding, mu, logvar)``.
        data_keeper: provides ``val_loader`` (queries) and the default gallery/labels.
        val_gallery: gallery image embeddings; defaults to ``data_keeper.val_img_embd``.
        val_labels: gallery index of each query's image; defaults to ``data_keeper.val_label``.
        device: device the model lives on.

    Returns:
        dict of metric name -> value from ``evaluate_retrieval``.
    """
    if val_gallery is None:
        val_gallery = data_keeper.val_img_embd
    if val_labels is None:
        val_labels = data_keeper.val_label
    model.eval()
    # Predictions come back to the CPU, where the default gallery tensor
    # (built with torch.from_numpy) lives, so the similarity product does not mix devices.
    preds = [model(X.to(device))[0].cpu() for X, _ in tqdm(data_keeper.val_loader, desc="[Val: Metrics]")]
    preds = torch.cat(preds)
    return evaluate_retrieval(preds, val_gallery, val_labels, max_indices=100)


In [None]:
@torch.inference_mode()
def validate_cliploss(model, data_keeper, loss_fn, logit_scale, device=DEVICE):
    """
    Average ``loss_fn`` over the batches of ``data_keeper.val_loader``.

    Args:
        model: translator returning ``(pred, mu, logvar)``.
        data_keeper: provides ``val_loader``.
        loss_fn: callable ``(pred, mu, logvar, y, logit_scale)`` returning a
            scalar tensor, e.g. ``VAELoss``.
        logit_scale: forwarded unchanged to ``loss_fn``. ``train_model`` passes
            ``logit_scale.exp().clamp(1, 100)`` during training, so pass that
            transformed value here for comparable numbers.
        device: device the model lives on.

    Returns:
        ``{"val_loss": float}``
    """
    model.eval()
    batch_losses = []
    for inputs, targets in tqdm(data_keeper.val_loader, desc="[Val: ClipLoss]"):
        inputs = inputs.to(device)
        targets = targets.to(device)
        pred, mu, logvar = model(inputs)
        batch_losses.append(loss_fn(pred, mu, logvar, targets, logit_scale).item())
    return {"val_loss": sum(batch_losses) / len(data_keeper.val_loader)}

In [None]:
def train_model(model, data_keeper, epochs, loss_fn,lr=LR, device=DEVICE, weight_decay=1e-4):
    """
    Train ``model`` on ``data_keeper.train_loader`` jointly with a learnable
    CLIP temperature.

    The temperature is the log-space parameter ``logit_scale``, initialised to
    log(1/0.07). It is handed to ``loss_fn`` as ``logit_scale.exp().clamp(1, 100)``.
    Adam updates the model and the temperature, with the learning rate
    cosine-annealed once per epoch. Weight decay applies to the model
    parameters only. Adam's L2 penalty on ``logit_scale`` would keep pulling it
    towards 0 (scale 1) independently of the loss, so CLIP-style training
    excludes the temperature.

    Args:
        model: module already on ``device``, returning ``(pred, mu, logvar)``.
        data_keeper: provides ``train_loader``.
        epochs: number of passes over the training loader.
        loss_fn: callable ``(pred, mu, logvar, y, logit_scale)`` returning a scalar tensor.
        lr: Adam learning rate.
        device: device batches are moved to.
        weight_decay: L2 penalty applied to the model parameters.

    Returns:
        ``(model, logit_scale)``, where ``logit_scale`` is the raw log-space parameter.
    """
    logit_scale = nn.Parameter(torch.ones([], device=device) * np.log(1 / 0.07))
    optimizer = torch.optim.Adam(
        [
            {"params": list(model.parameters()), "weight_decay": weight_decay},
            {"params": [logit_scale], "weight_decay": 0.0},
        ],
        lr=lr,
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    n_batches = len(data_keeper.train_loader)

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for X, y in tqdm(data_keeper.train_loader, desc=f"Epoch {epoch+1}/{epochs} [Train]"):
            X, y = X.to(device), y.to(device)
            optimizer.zero_grad()
            pred, mu, logvar = model(X)
            loss = loss_fn(pred, mu, logvar, y, logit_scale.exp().clamp(1, 100))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        avg_loss = total_loss / n_batches
        scheduler.step()
        tqdm.write(f"Epoch {epoch+1}: Average Train Loss = {avg_loss:.6f}")
    return model, logit_scale


In [None]:
def evaluate_model_VAE(model, data_keeper, logit_scale, loss_function, batch_size=BATCH_SIZE, device=DEVICE):
  """Return the batch-averaged ``loss_function`` value over ``data_keeper.val_loader``.

  Args:
    model: translator returning ``(pred, mu, logvar)``; it is moved to ``device``.
    data_keeper: provides ``val_loader``.
    logit_scale: raw log-space temperature as returned by ``train_model``. It is
      transformed with ``.exp().clamp(1, 100)``, exactly as during training.
    loss_function: callable ``(pred, mu, logvar, y, logit_scale)``, e.g. ``VAELoss``.
    batch_size: unused; the loader's own batch size applies.
    device: device for the model and the batches.
  """
  model.eval()
  model.to(device)
  val_loss = 0
  with torch.no_grad():
    scale = logit_scale.exp().clamp(1, 100)
    for X, y in data_keeper.val_loader:
      X_batch, y_batch = X.to(device), y.to(device)
      pred, mu, logvar = model(X_batch)
      # Use the loss passed by the caller, not a notebook-global one.
      loss = loss_function(pred, mu, logvar, y_batch, scale)
      val_loss += loss.item()
  return val_loss/len(data_keeper.val_loader)

In [None]:
def model_metrics_reflow(model, data_keeper, batch_size=BATCH_SIZE, device=DEVICE):
  """Retrieval metrics of ``model`` on the validation split.

  Predictions for every validation caption are gathered on the CPU, then
  ranked against ``data_keeper.val_img_embd`` by ``evaluate_retrieval``
  (default ``max_indices``). ``batch_size`` is unused; the loader's own batch
  size applies.

  Returns:
    dict of metric name -> value.
  """
  model.eval()
  model.to(device)
  predictions = []
  with torch.no_grad():
    for inputs, _ in data_keeper.val_loader:
      pred_embds, _, _ = model(inputs.to(device))
      predictions.append(pred_embds.cpu())
  stacked = torch.cat(predictions, dim=0)
  return evaluate_retrieval(stacked, data_keeper.val_img_embd, data_keeper.val_label)

In [None]:
def create_submisio_VAE(model, data_keeper, batch_size=64, data_path=DATA_PATH, device=DEVICE, submission_path='submission.csv'):
  """Predict image embeddings for the test captions and write the submission file.

  Args:
    model: translator returning ``(pred, mu, logvar)``, already on ``device``.
    data_keeper: unused; kept for call compatibility.
    batch_size: number of test captions per forward pass.
    data_path: dataset root containing ``test/test/test.clean.npz``.
    device: device the model lives on.
    submission_path: output file name passed to ``generate_submission``.

  Returns:
    CPU tensor of predicted embeddings, one row per test caption.
  """
  test_data = load_data(data_path/"test/test/test.clean.npz")
  test_embds = torch.from_numpy(test_data['captions/embeddings']).float()
  model.eval()
  batch_preds = []
  with torch.no_grad():
    for start in range(0, len(test_embds), batch_size):
      batch = test_embds[start:start + batch_size].to(device)
      pred_embds, _, _ = model(batch)
      batch_preds.append(pred_embds.cpu())
  # Concatenate once at the end; the output width comes from the model itself.
  pred_embds = torch.cat(batch_preds, dim=0)

  generate_submission(test_data['captions/ids'], pred_embds, submission_path)
  return pred_embds

## Model training

### Best model training

In [None]:
# With split_ratio=1.0 every caption goes to training and the validation split
# is empty, so the validation-based helpers have nothing to evaluate here.
data_keeper = get_train_data(split_ratio=1.0)  # using the whole dataset for training

In [None]:
# Hyper-parameters of the final model (presumably chosen from the Optuna study
# further below - confirm against the saved results CSV).
hidden_dim = 1024
latent_dim = 2560
kld_weight = 0.0001

# NOTE(review): input_dim / output_dim are hard-coded text / image embedding
# widths; confirm they match the data's embedding sizes.
model = VAEAdapter(input_dim=1024, output_dim=1536, hidden_dim=hidden_dim, latent_dim=latent_dim).to(DEVICE)
model.apply(init_weights)
batch_size =352
data_keeper.create_loaders(batch_size=batch_size)
epochs = 10
loss_fn = VAELoss(kld_weight=kld_weight, device=DEVICE)

In [None]:
# Returns the trained model and the learned log-space temperature parameter.
model, logit_scale = train_model(model, data_keeper=data_keeper,epochs=epochs, loss_fn=loss_fn)


In [None]:
# Builds submission.csv via generate_submission; the returned prediction tensor
# is shown as this cell's output.
create_submisio_VAE(model, data_keeper)

# Optuna experiments used to find the best model hyperparameters

In [None]:
def center_tensor(t):
    """Subtract the mean over dim 0 (the rows), so every column has zero mean."""
    return t - torch.mean(t, dim=0, keepdim=True)

def linear_cka(X, Y, eps=1e-12):
    """Linear CKA between two row-aligned sets of representations.

    CKA = ||Xc^T Yc||_F^2 / (||Xc^T Xc||_F * ||Yc^T Yc||_F), where Xc and Yc
    are the column-centred inputs. The denominator is floored at ``eps`` and
    the result clipped to [0, 1].
    """
    Xc = X - X.mean(dim=0, keepdim=True)
    Yc = Y - Y.mean(dim=0, keepdim=True)
    cross = torch.norm(Xc.T @ Yc, p='fro').pow(2)
    self_x = torch.norm(Xc.T @ Xc, p='fro')
    self_y = torch.norm(Yc.T @ Yc, p='fro')
    return (cross / (self_x * self_y).clamp(min=eps)).clamp(min=0.0, max=1.0)

def rbf_kernel(X, sigma=None):
    """Gaussian kernel matrix K[i, j] = exp(-||x_i - x_j||^2 / (2 * sigma^2)).

    If ``sigma`` is None it is chosen by the median heuristic, so that
    2 * sigma^2 equals the median non-zero squared pairwise distance. When
    there is no non-zero distance (a single row, or identical rows) the median
    falls back to 1. Such a kernel is all ones whatever the bandwidth.
    """
    pairwise_sq_dists = torch.cdist(X, X, p=2) ** 2
    if sigma is None:
        positive = pairwise_sq_dists[pairwise_sq_dists > 0]
        # torch.median raises on an empty tensor (e.g. a last batch of one row).
        if positive.numel() > 0:
            median = torch.median(positive)
        else:
            median = pairwise_sq_dists.new_tensor(1.0)
        sigma = torch.sqrt(median / 2)
    gamma = 1 / (2 * sigma ** 2)
    return torch.exp(-gamma * pairwise_sq_dists)

def center_kernel_nonlienar(K):
    """Double-centre a square kernel matrix: H @ K @ H with H = I - (1/n) * ones."""
    eye = torch.eye(K.shape[0], device=K.device)
    centering = eye - torch.ones_like(eye) / K.shape[0]
    return centering @ K @ centering

def nonlinear_cka_torch(X, Y, sigma=None, eps=1e-12):
    """RBF-kernel CKA: tr(Kx Ky) / (sqrt(tr(Kx Kx)) * sqrt(tr(Ky Ky)) + eps), clipped to [0, 1].

    Both inputs use the same ``sigma`` (the median heuristic is applied per
    input when it is None). Kernels are double-centred before the traces.
    """
    Kx, Ky = (center_kernel_nonlienar(rbf_kernel(Z, sigma)) for Z in (X, Y))
    cross = torch.trace(Kx @ Ky)
    scale = torch.trace(Kx @ Kx).sqrt() * torch.trace(Ky @ Ky).sqrt()
    return (cross / (scale + eps)).clamp(0, 1)

In [None]:
def evaluate_model_CKA(model, data_keeper, batch_size=BATCH_SIZE, device=DEVICE):
  """Print and return the mean linear and RBF CKA between predictions and targets.

  CKA is computed per validation batch and averaged over batches.

  Args:
    model: translator returning ``(pred, mu, logvar)``, already on ``device``.
    data_keeper: provides ``val_loader``.
    batch_size: unused; the loader's own batch size applies.
    device: device the batches are moved to.

  Returns:
    ``{"linear_cka": float, "nonlinear_cka": float}`` (both NaN when the
    validation loader is empty).
  """
  model.eval()
  n_batches = len(data_keeper.val_loader)
  lin_cka = 0.0
  non_lin_cka = 0.0
  with torch.no_grad():
      for X_batch, y_batch in data_keeper.val_loader:
          # Use the caller-supplied device for the batches.
          X_batch, y_batch = X_batch.to(device), y_batch.to(device)
          pred_embds, _, _ = model(X_batch)  # VAEAdapter returns (out, mu, logvar)
          lin_cka += linear_cka(pred_embds, y_batch).item()
          non_lin_cka += nonlinear_cka_torch(pred_embds, y_batch).item()
  if n_batches == 0:
    results = {"linear_cka": float("nan"), "nonlinear_cka": float("nan")}
  else:
    results = {"linear_cka": lin_cka / n_batches, "nonlinear_cka": non_lin_cka / n_batches}
  print(f"non linear_cka: {results['nonlinear_cka']}")
  print(f"linear_cka: {results['linear_cka']}")
  return results

In [None]:
data_keeper = get_train_data(split_ratio=0.9)  # train on the first 90% of captions; the remaining 10% are held out for validation

In [None]:
def objective(trial):
    """Optuna objective: train a VAEAdapter with sampled hyper-parameters and
    return its validation MRR (the study maximises it).

    Uses the notebook-level ``data_keeper``, whose loaders are rebuilt with the
    sampled batch size. It also prints the CKA diagnostics of the trained model.
    """
    hidden_dim = trial.suggest_int('hidden_dim', 512, 4096, step=256)
    latent_dim = trial.suggest_int('latent_dim', 512, 4096, step=256)
    # Sample log-uniformly. The weight spans three orders of magnitude, so a
    # linear grid with step 5e-4 (which also does not divide the range) puts
    # almost no trials near the small values.
    kld_weight = trial.suggest_float('kld_weight', 1e-4, 1e-1, log=True)

    model = VAEAdapter(input_dim=1024, output_dim=1536, hidden_dim=hidden_dim, latent_dim=latent_dim).to(DEVICE)
    model.apply(init_weights)
    batch_size = trial.suggest_int('batch_size', 256, 1024, step=32)
    data_keeper.create_loaders(batch_size=batch_size)
    epochs = 10
    loss_fn = VAELoss(kld_weight=kld_weight, device=DEVICE)

    # train_model returns (trained model, log-space temperature).
    model, logit_scale = train_model(model, data_keeper=data_keeper, epochs=epochs, loss_fn=loss_fn)
    evaluate_model_CKA(model, data_keeper)
    metrics = model_metrics_reflow(model, data_keeper)
    return float(metrics['mrr'])

In [None]:
# A seeded TPE sampler makes the sequence of suggested hyper-parameters
# reproducible (each trial's score still depends on training randomness).
study = optuna.create_study(direction="maximize", sampler=optuna.samplers.TPESampler(seed=42))
study.optimize(objective, n_trials=50)

In [None]:
# Trial with the highest objective value (validation MRR, since the study maximises).
best_trial = study.best_trial

In [None]:
# Display the best trial (its parameters and objective value).
best_trial

In [None]:
# Persist every trial (one row per trial) for later analysis and plotting.
df = study.trials_dataframe()
df.to_csv(path_or_buf=WORKING_DIR.joinpath("optuna_results_VAE_10_epochs.csv"), index=False)

In [None]:
# One scatter panel per searched hyper-parameter against the trial's MRR.
fig, ax = plt.subplots(ncols=2, nrows=2, figsize=(12, 5))

panels = [
    # (dataframe column, x-axis label, rotate x tick labels?)
    ('params_hidden_dim', 'hidden_dim', True),
    ('params_latent_dim', 'latent_dim', False),
    ('params_kld_weight', 'kld_weight', True),
    ('params_batch_size', 'batch_size', True),
]
for panel_ax, (column, label, rotate_ticks) in zip(ax.flat, panels):
    panel_ax.scatter(df[column], df['value'])
    panel_ax.set_xlabel(label)
    panel_ax.set_ylabel('MRR')
    if rotate_ticks:
        panel_ax.tick_params(axis='x', rotation=45)

plt.tight_layout()
plt.show()