<a href="https://colab.research.google.com/github/Mamiglia/challenge/blob/master/baseline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup

❕ Before running the setup cell, upload your Kaggle API token (`kaggle.json`) to the notebook's working directory; the cell copies it into `~/.kaggle/`.

In [None]:
import os, subprocess

def run(cmd):
    """
    Echo a shell command, then execute it.

    The command string is handed to the shell verbatim (``shell=True``), so
    ``&&`` chains and ``~`` expansion work. A non-zero exit status raises
    ``subprocess.CalledProcessError``, which aborts the cell.
    """
    announcement = f"▶ {cmd}"
    print(announcement)
    options = {"shell": True, "check": True}
    subprocess.run(cmd, **options)

# --- One-time environment setup ---
# Each step is (marker path, shell commands); the commands run, in order, only
# when the marker path does not exist yet, so re-running the cell is cheap.
_SETUP_STEPS = (
    # Kaggle credentials: copy the uploaded kaggle.json into ~/.kaggle
    (os.path.expanduser("~/.kaggle/kaggle.json"),
     ["mkdir -p ~/.kaggle && cp kaggle.json ~/.kaggle/ && chmod 600 ~/.kaggle/kaggle.json"]),
    # Competition data: download and extract into ./data
    ("data/train",
     ["kaggle competitions download -c aml-competition -p data",
      "unzip -qo data/aml-competition.zip -d data"]),
    # Baseline repository
    ("challenge",
     ["git clone https://github.com/Mamiglia/challenge.git"]),
)

for marker, commands in _SETUP_STEPS:
    if os.path.exists(marker):
        continue
    for command in commands:
        run(command)

# --- Install dependencies ---
# These are IPython shell escapes (`!`): the cell only runs inside Jupyter/Colab,
# not as a plain Python script.
# NOTE(review): no versions are pinned, so results may vary with whatever pip
# resolves at run time — consider pinning for reproducibility.
!pip install -q torch torchvision torchaudio
!pip install -q openai-clip scikit-learn opencv-python torchdiffeq \
    beautifulsoup4 open_clip_torch scikit-image cython matplotlib accelerate \
    absl-py ml_collections einops wandb ftfy transformers timm tensorboard pycocotools

In [None]:
import torch, torch.nn as nn, torch.nn.functional as F, numpy as np
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from pathlib import Path
from tqdm import tqdm
import sys, random, logging, math
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from IPython.display import clear_output

# Extend path to local repositories
sys.path.extend(["challenge/src"])

# Project imports
from challenge.src.common import load_data, prepare_train_data, generate_submission
from challenge.src.eval import visualize_retrieval

# Configure logging: drop whatever handlers the kernel already attached to the
# root logger, then send INFO-and-above records to stdout as bare messages.
_root_logger = logging.getLogger()
for handler in list(_root_logger.handlers):
    _root_logger.removeHandler(handler)
logging.basicConfig(
    handlers=[logging.StreamHandler(sys.stdout)],
    format="%(message)s",
    level=logging.INFO,
)

In [None]:
# Configuration dictionary: every hyper-parameter of the run, read by the cells below.
CFG = dict(
    MODEL_PATH="models/crossflow2.pth",  # best checkpoint (lowest val loss) is written here
    SEED=42,
    DEVICE=torch.device("cuda" if torch.cuda.is_available() else "cpu"),

    # --- Data Dimensions ---
    TEXT_DIM=1024,    # width of the text embeddings fed to the variational encoder
    LATENT_DIM=1536,  # width of the target image embeddings = transformer d_model

    # --- Model Architecture ---
    TIME_EMB_DIM=256,  # sinusoidal time-embedding width inside the flow transformer
    N_LAYERS_VE=3,     # Variational Encoder layers
    N_HEADS_VE=8,
    N_LAYERS_FLOW=6,   # Flow Matching layers
    N_HEADS_FLOW=8,

    # --- Training ---
    EPOCHS=2,
    BATCH_SIZE=16,
    LR=1e-4,
    WEIGHT_DECAY=0.03,  # passed to torch.optim.Adam in the training loop
    KL_WEIGHT=1e-4,     # multiplier on the KL term of the loss
    UNCOND_RATE=0.1,    # fraction of training samples given the unconditional token

    # --- Inference ---
    INFERENCE_STEPS=10,   # Euler steps when integrating the flow
    GUIDANCE_SCALE=1.7,   # classifier-free guidance weight for visualisation/submission
)

In [None]:
def set_seed(seed=42):
    """Seed torch, NumPy's global RNG and Python's `random` with the same value."""
    for seeder in (torch.manual_seed, np.random.seed, random.seed):
        seeder(seed)

set_seed(CFG["SEED"])
# Disable the cuDNN autotuner and request deterministic cuDNN kernels so repeated
# runs pick the same algorithms (at some speed cost).
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

import os
# NOTE(review): PYTORCH_CUDA_ALLOC_CONF is presumably read when the CUDA caching
# allocator first initialises; setting it here only takes effect if no CUDA memory
# has been allocated yet in this kernel session. Consider setting it before
# `import torch` in the first cells — confirm.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Model Definitions

In [None]:
# ============================================================
# Utility Functions
# ============================================================

def reparametrize(mu, logvar):
    """
    Reparametrization trick: return mu + sigma * eps with sigma = exp(logvar / 2)
    and eps ~ N(0, 1) drawn with the same shape as sigma.
    """
    sigma = torch.exp(logvar * 0.5)
    return mu + sigma * torch.randn_like(sigma)

def compute_z_t_and_vhat(z0, z1, t, sigma_min=1e-5):
    """
    Flow interpolation & target velocity from CrossFlow paper (Eq. 1)
    """
    B, _, D = z0.shape
    t = t.view(B, 1, 1)  # ensure proper broadcast shape
    z_t = t * z1 + (1.0 - (1.0 - sigma_min) * t) * z0
    v_hat = z1 - (1.0 - sigma_min) * z0
    return z_t, v_hat

def sinusoidal_time_embedding(t, dim):
    """
    Sinusoidal embedding of the flow time.

    t:   per-sample times with values in [0,1]; shape [B,1]. A [B] vector or a
         scalar is also accepted and reshaped to [B,1].
    dim: output width.
    returns: [B, dim] time embedding: `dim // 2` sines followed by `dim // 2`
             cosines of 2*pi*t*freq, plus one zero column when dim is odd.
    """
    # Force [B, 1]: a bare [B] would otherwise broadcast against the [1, half]
    # frequency row and raise (or silently mix samples when B == half).
    t = t.reshape(-1, 1)
    half = dim // 2
    # Geometric frequency ladder from 1 down to 1e-4. Clamp the denominator so a
    # single frequency (dim 2 or 3) gives freq = 1 instead of 0/0 = NaN.
    denom = max(half - 1, 1)
    freqs = torch.exp(
        -math.log(1e4) * torch.arange(0, half, dtype=torch.float32, device=t.device) / denom
    )
    # follow common diffusion embeddings: scale by 2pi
    args = t * 2.0 * math.pi * freqs.view(1, -1)  # [B, half]
    emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
    if dim % 2 == 1:
        # pad one zero column if odd
        emb = F.pad(emb, (0, 1))
    return emb  # [B, dim]

In [None]:
# ============================================================
# Components: Variational Encoder & Flow Transformer
# ============================================================

class VariationalEncoder(nn.Module):
    """
    Lightweight Transformer-based Variational Encoder.
    Input:  text embedding  [B, N, text_dim]
    Output: mu, logvar each [B, latent_dim]

    Tokens are projected to latent_dim, passed through a Transformer encoder,
    mean-pooled over the N tokens, and mapped by two linear heads to the mean
    and log-variance of a diagonal Gaussian.
    """
    def __init__(self, text_dim, latent_dim, n_layers=3, n_heads=8):
        super().__init__()
        # Sub-module creation order is unchanged, so seeded initialisation and
        # state_dict keys are the same as before (batch_first adds no parameters).
        self.proj_in = nn.Linear(text_dim, latent_dim)
        # batch_first=True keeps tensors as [B, N, E] throughout, which removes the
        # manual permutes, matches the FlowTransformer convention, and avoids the
        # "enable_nested_tensor ... batch_first was not True" warning that
        # nn.TransformerEncoder emits for sequence-first layers.
        enc_layer = nn.TransformerEncoderLayer(
            d_model=latent_dim,
            nhead=n_heads,
            dim_feedforward=latent_dim * 4,
            activation="gelu",
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=n_layers)
        self.to_mu = nn.Linear(latent_dim, latent_dim)
        self.to_logvar = nn.Linear(latent_dim, latent_dim)

    def forward(self, x):  # x: [B, N, text_dim]
        """Return (mu, logvar), each [B, latent_dim], for text embeddings x [B, N, text_dim]."""
        h = self.encoder(self.proj_in(x))  # [B, N, latent_dim]
        pooled = h.mean(dim=1)             # [B, latent_dim]
        return self.to_mu(pooled), self.to_logvar(pooled)

class FlowTransformer(nn.Module):
    """
    Predict velocity per sample. Supports per-sample indicator mask (boolean tensor [B]),
    and richer sinusoidal time embedding injection.

    forward(zt, t, indicator_mask):
        zt:             [B, S, latent_dim] latent tokens (callers in this notebook use S == 1).
        t:              flow times, passed to `sinusoidal_time_embedding` (documented there
                        as [B, 1]).
        indicator_mask: bool [B]; True selects the learned conditional token `g_c`,
                        False the unconditional token `g_uc` (classifier-free guidance).
    Returns:
        [B, S, latent_dim] velocity for the S input tokens; the appended indicator
        token's output is discarded.
    """
    def __init__(self, latent_dim, n_layers=6, n_heads=8, time_emb_dim=256):
        super().__init__()
        # NOTE: the order in which sub-modules/parameters are created below fixes the
        # random-initialisation sequence under a fixed seed; keep it stable.
        self.latent_dim = latent_dim
        self.input_proj = nn.Linear(latent_dim, latent_dim)

        layer = nn.TransformerEncoderLayer(
            d_model=latent_dim,
            nhead=n_heads,
            dim_feedforward=latent_dim * 4,
            activation="gelu",
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(layer, num_layers=n_layers)
        # Maps the sinusoidal embedding [B, time_emb_dim] to a per-sample [B, latent_dim] offset.
        self.time_mlp = nn.Sequential(
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.SiLU(),
            nn.Linear(time_emb_dim, latent_dim)
        )

        # CFG tokens: two learnable tokens (conditional & unconditional),
        # each [1, latent_dim], initialised from a standard normal.
        self.g_c = nn.Parameter(torch.randn(1, latent_dim))
        self.g_uc = nn.Parameter(torch.randn(1, latent_dim))

        # final output projection to velocity (per-token)
        self.out_proj = nn.Linear(latent_dim, latent_dim)
        self.time_emb_dim = time_emb_dim

    def forward(self, zt, t, indicator_mask):
        B, S, D = zt.shape
        h = self.input_proj(zt)  # [B,S,D]

        # Append per-sample indicator token
        # expand turns each [1, D] token into a [B, 1, D] view (no copy).
        token_cond = self.g_c.expand(B, -1, -1)
        token_uncond = self.g_uc.expand(B, -1, -1)
        # Per-sample choice: conditional token where indicator_mask is True.
        token = torch.where(indicator_mask.view(B, 1, 1), token_cond, token_uncond)  # [B,1,D]
        h = torch.cat([h, token], dim=1)  # [B, S+1, D]

        # add time embedding (broadcast to tokens)
        # The same offset is added to every position, including the indicator token.
        te = sinusoidal_time_embedding(t, self.time_emb_dim)  # [B, time_emb_dim]
        te = self.time_mlp(te)  # [B, D]
        h = h + te.unsqueeze(1)  # broadcast to [B, S+1, D]

        h = self.transformer(h)  # [B, S+1, D]

        # drop the token output and map to velocity for tokens only
        h_tokens = h[:, :S, :]  # [B, S, D]
        v = self.out_proj(h_tokens)  # [B, S, D]
        return v

In [None]:
# ============================================================
# Composite CrossFlow model
# ============================================================

class CrossFlowModel(nn.Module):
    """
    Text-embedding -> image-embedding CrossFlow model.

    * `ve`   (VariationalEncoder): text [B, N, text_dim] -> (mu, logvar), each [B, latent_dim].
    * `flow` (FlowTransformer): velocity v(z_t, t, indicator) on single-token latents
      [B, 1, latent_dim].

    Layer/head counts and the time-embedding width are read from the global CFG dict.
    """
    def __init__(self, text_dim, latent_dim):
        super().__init__()
        self.ve = VariationalEncoder(text_dim, latent_dim,
                                     n_layers=CFG["N_LAYERS_VE"], n_heads=CFG["N_HEADS_VE"])
        self.flow = FlowTransformer(latent_dim,
                                    n_layers=CFG["N_LAYERS_FLOW"], n_heads=CFG["N_HEADS_FLOW"],
                                    time_emb_dim=CFG["TIME_EMB_DIM"])

    def encoding_mu_logvar(self, text_emb):
        """Encode text_emb [B, N, text_dim] into (mu, logvar), each [B, latent_dim]."""
        mu, logvar = self.ve(text_emb)
        return mu, logvar

    def sample_z0(self, mu, logvar):
        """Draw z0 = mu + exp(logvar / 2) * eps with the reparametrization trick."""
        return reparametrize(mu, logvar)

    def predict_velocity(self, zt, t, indicator_mask):
        """Velocity [B, S, D] for latents zt [B, S, D] at times t [B, 1] (see FlowTransformer)."""
        return self.flow(zt, t, indicator_mask)

    def forward_flow_training(self, text_emb, target_latent, t, indicator_mask):
        """
        Training forward: returns v_pred (per-sample), v_hat, mu, logvar

        text_emb:       [B, N, text_dim]
        target_latent:  z1, [B, latent_dim]
        t:              [B, 1] flow times in [0, 1]
        indicator_mask: bool [B]; True = conditional token, False = unconditional
        Returns v_pred and v_hat shaped [B, 1, latent_dim]; mu and logvar [B, latent_dim].
        """
        mu, logvar = self.ve(text_emb)
        z0 = reparametrize(mu, logvar)  # [B, latent_dim]

        # Treat each latent as a length-1 token sequence [B, 1, latent_dim];
        # compute_z_t_and_vhat preserves that shape, so no further reshaping is needed.
        z_t, v_hat = compute_z_t_and_vhat(z0.unsqueeze(1), target_latent.unsqueeze(1), t)

        v_pred = self.predict_velocity(z_t, t, indicator_mask)
        return v_pred, v_hat, mu, logvar

    def predict_z1_from_z0(self, z0, n_steps=10, guidance_scale=1.0, indicator_mask=None):
        """
        Integrate dz/dt = v(z, t) from t = 0 (z = z0) to t = 1 with explicit Euler steps.

        Args:
            z0: [B, latent_dim] starting latents.
            n_steps: number of Euler steps (step size 1 / n_steps).
            guidance_scale: 1.0 -> a single velocity evaluation per step using
                `indicator_mask`; otherwise classifier-free guidance
                v = s * v_cond + (1 - s) * v_uncond from two evaluations per step
                (the mask argument is then not used).
            indicator_mask: optional bool [B] for the unguided path; defaults to
                all-True (every sample conditioned).
        Returns:
            [B, latent_dim] predicted z1.
        """
        B = z0.shape[0]
        device = z0.device
        if indicator_mask is None:
            indicator_mask = torch.ones(B, dtype=torch.bool, device=device)
        cond_mask = torch.ones(B, dtype=torch.bool, device=device)
        uncond_mask = torch.zeros(B, dtype=torch.bool, device=device)

        z = z0.unsqueeze(1)  # [B,1,D]
        for i in range(n_steps):
            # Forward Euler evaluates the velocity at the START of each interval,
            # t_i = i / n_steps, so the first step queries (z0, t=0) — the regime
            # training samples from U[0, 1) — rather than the interval end.
            t = torch.full((B, 1), i / n_steps, device=device)
            if guidance_scale == 1.0:
                v = self.predict_velocity(z, t, indicator_mask)
            else:
                # Two separate forward passes: conditional and unconditional token.
                v_cond = self.predict_velocity(z, t, cond_mask)
                v_uncond = self.predict_velocity(z, t, uncond_mask)
                v = guidance_scale * v_cond + (1.0 - guidance_scale) * v_uncond

            z = z + v / float(n_steps)  # Euler step
        z1_pred = z.squeeze(1)  # [B,D]
        return z1_pred

# Visualization Helpers

In [None]:
# Color palette: blue for the source panel, orange-red for the target panel
source_color = [0/255, 114/255, 178/255]
target_color = [213/255, 94/255, 0/255]
plt.style.use("default")

def plot_distributions(dist1, dist2, title1="Source", title2="Target"):
    """
    Show two embedding sets side by side, each reduced to 2-D by its own PCA.

    dist1 / dist2 are moved to CPU before fitting; the figure is shown and then
    closed.
    """
    # Fit both projections before any figure exists (dist1 first, then dist2).
    projections = [PCA(n_components=2).fit_transform(points.cpu()) for points in (dist1, dist2)]
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    panels = zip(axes, projections, (source_color, target_color))
    for ax, xy, color in panels:
        ax.scatter(xy[:, 0], xy[:, 1], color=color, alpha=0.6, s=8)
    axes[0].set_title(title1)
    axes[1].set_title(title2)
    for ax in axes:
        ax.set_aspect("equal")
        ax.axis("off")
    plt.tight_layout()
    plt.show()
    plt.close()

# Data

In [None]:
"""
Prepare training and validation sets
"""
# `load_data` / `prepare_train_data` come from challenge.src.common. Judging from how
# they are used below, X holds the text (caption) embeddings fed to the encoder, y the
# paired image embeddings used as flow targets, and `label` a caption x image matching
# matrix — TODO confirm against the helper.
train_data = load_data("data/train/train/train.npz")
X, y, label = prepare_train_data(train_data)

# Split train/val
# Deterministic split in file order: the first 90% of rows go to train, the rest to val.
# NOTE(review): rows are not shuffled first; if captions of one image are stored
# contiguously, an image can end up with captions on both sides of the cut — verify.
n_train = int(0.9 * len(X))
TRAIN_SPLIT = torch.zeros(len(X), dtype=torch.bool)
TRAIN_SPLIT[:n_train] = 1
X_train, X_val = X[TRAIN_SPLIT], X[~TRAIN_SPLIT]
y_train, y_val = y[TRAIN_SPLIT], y[~TRAIN_SPLIT]

# Data Loaders
# Training batches are reshuffled each epoch; validation order is fixed.
train_dataset = TensorDataset(X_train, y_train)
val_dataset = TensorDataset(X_val, y_val)
train_loader = DataLoader(train_dataset, batch_size=CFG["BATCH_SIZE"], shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=CFG["BATCH_SIZE"])

# Pre-compute masks for validation/viz
# img_VAL_SPLIT: boolean mask over images, True for images matched by at least one
# validation caption (column sums of the validation rows of `label`).
img_VAL_SPLIT = label[~TRAIN_SPLIT].sum(dim=0) > 0
val_img_file = train_data['images/names'][img_VAL_SPLIT]
val_img_embd = torch.from_numpy(train_data['images/embeddings'][img_VAL_SPLIT])
# For each validation caption: the column of its matching image inside the
# img_VAL_SPLIT subset, i.e. a row index into val_img_file / val_img_embd.
# Assumes exactly one nonzero per caption row; otherwise entries misalign with captions.
val_label = np.nonzero(train_data['captions/label'][~TRAIN_SPLIT][:,img_VAL_SPLIT])[1]
val_caption_text = train_data['captions/text'][~TRAIN_SPLIT]

print(f"Train data: {len(X_train)} samples. Val data: {len(X_val)} samples.")

In [None]:
# Visual sanity check: 2-D PCA of the first 2000 training sources (X_train)
# next to the first 2000 training targets (y_train).
print("Plotting data distributions:")
plot_distributions(
    dist1=X_train[:2000],
    dist2=y_train[:2000],
    title1="Source",
    title2="Target",
)

# Training

In [None]:
def _validate_crossflow(model, val_loader, device, n_steps):
    """Mean over val batches of 1 - cosine(z1_pred, target), rolling the flow out from mu."""
    model.eval()
    total = 0.0
    with torch.no_grad():
        for Xb, Yb in val_loader:
            Xb, Yb = Xb.to(device), Yb.to(device)
            mu, _logvar = model.encoding_mu_logvar(Xb.unsqueeze(1))
            # Deterministic rollout: start from the encoder mean, no guidance.
            z1_pred = model.predict_z1_from_z0(mu, n_steps=n_steps, guidance_scale=1.0)
            sim = F.cosine_similarity(z1_pred, Yb, dim=-1)  # [B]
            total += (1.0 - sim.mean()).item()
    return total / len(val_loader)


def train_crossflow_model(model, train_loader, val_loader, n_epochs, lr):
    """
    Train a CrossFlowModel, checkpointing the weights with the best validation loss.

    Per-batch objective: L_FM + L_enc + CFG["KL_WEIGHT"] * L_KL, where
      * L_FM  is the MSE between predicted and target flow velocity,
      * L_enc is 1 - cosine(mu, target image embedding),
      * L_KL  is the element-averaged KL(N(mu, exp(logvar)) || N(0, I)).
    Validation integrates the flow from the encoder mean with CFG["INFERENCE_STEPS"]
    Euler steps (no guidance) and scores 1 - cosine(z1_pred, target).

    Args:
        model: CrossFlowModel; moved to CFG["DEVICE"] in place.
        train_loader, val_loader: DataLoaders yielding (text_emb, image_emb) batches
            (presumably [B, TEXT_DIM] and [B, LATENT_DIM] — confirm).
        n_epochs: number of passes over train_loader.
        lr: learning rate.
    Returns:
        The model with its last-epoch weights; the best weights are saved at
        CFG["MODEL_PATH"].
    Raises:
        ValueError: if either loader yields no batches (the epoch averages would
            divide by zero, or a meaningless 0.0 val loss would be checkpointed).
    """
    if len(train_loader) == 0 or len(val_loader) == 0:
        raise ValueError("train_loader and val_loader must each yield at least one batch")

    device = CFG["DEVICE"]
    # Move the model before building the optimizer so it tracks device-resident params.
    model.to(device)
    # NOTE(review): torch.optim.Adam folds weight_decay into the gradient as an L2 term;
    # if decoupled decay was intended for WEIGHT_DECAY, torch.optim.AdamW is the match.
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=CFG["WEIGHT_DECAY"])

    best_val = float("inf")
    Path(CFG["MODEL_PATH"]).parent.mkdir(parents=True, exist_ok=True)

    for epoch in range(n_epochs):
        model.train()
        train_loss = 0.0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{n_epochs}", leave=False)
        for Xb, Yb in pbar:
            Xb, Yb = Xb.to(device), Yb.to(device)
            B = Xb.size(0)
            t = torch.rand(B, 1, device=device)  # flow times ~ U[0, 1)

            # boolean mask: True = conditioned, False = unconditioned (dropped)
            indicator_mask = torch.rand(B, device=device) > CFG["UNCOND_RATE"]

            v_pred, v_hat, mu, logvar = model.forward_flow_training(Xb.unsqueeze(1), Yb, t, indicator_mask)

            # Losses
            L_FM = F.mse_loss(v_pred, v_hat)
            L_KL = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
            # Cosine alignment of the encoder mean with the target (no negatives, so
            # not a contrastive loss).
            L_enc = 1 - F.cosine_similarity(mu, Yb).mean()

            loss = L_FM + L_enc + CFG["KL_WEIGHT"] * L_KL

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            batch_loss = loss.item()
            train_loss += batch_loss
            pbar.set_postfix(loss=f"{batch_loss:.4f}")

        train_loss /= len(train_loader)

        # Validation
        retrieval_loss = _validate_crossflow(model, val_loader, device, CFG["INFERENCE_STEPS"])

        print(f"> Epoch {epoch+1}: Train Loss {train_loss:.6f} | Val Loss {retrieval_loss:.6f}")

        if retrieval_loss < best_val:
            best_val = retrieval_loss
            torch.save(model.state_dict(), CFG["MODEL_PATH"])
            print(f"  ✓ Saved best model (val={retrieval_loss:.6f})")

    return model

In [None]:
# Initialize and Train
print("Initializing CrossFlow Model...")
model = CrossFlowModel(CFG["TEXT_DIM"], CFG["LATENT_DIM"])
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")

print("\nStarting training...")
model = train_crossflow_model(model, train_loader, val_loader, CFG["EPOCHS"], CFG["LR"])

# Load best model for inference.
# map_location restores the tensors directly onto CFG["DEVICE"] (so a GPU-written
# checkpoint also loads on a CPU-only runtime). weights_only=True restricts
# unpickling to tensors/containers, which is all a state_dict needs, and silences
# the FutureWarning newer torch versions emit when the flag is omitted.
state_dict = torch.load(CFG["MODEL_PATH"], map_location=CFG["DEVICE"], weights_only=True)
model.load_state_dict(state_dict)
model.to(CFG["DEVICE"])
print("Training complete.")

## Evaluation

In [None]:
def predict_embeddings(model, text_embs, n_steps=10, guidance_scale=1.0, deterministic=True, device=CFG["DEVICE"]):
    """
    Map text embeddings to predicted image embeddings with a trained CrossFlow model.

    text_embs: a single [text_dim] vector, a [B, text_dim] batch, or an already
               tokenised [B, N, text_dim] tensor.
    deterministic: start the flow from the encoder mean (True) or from a
                   reparametrized sample (False).
    Returns: [B, latent_dim] predictions as a CPU tensor (model is left in eval mode).
    """
    model.eval()
    with torch.no_grad():
        # Promote to the [B, N, text_dim] layout the encoder expects: a lone vector
        # becomes a batch of one, and a 2-D batch gains a singleton token axis.
        if text_embs.ndim == 1:
            text_embs = text_embs.unsqueeze(0)
        tokens = text_embs.unsqueeze(1) if text_embs.ndim == 2 else text_embs
        text_in = tokens.to(device)

        mu, logvar = model.encoding_mu_logvar(text_in)
        z0 = mu if deterministic else model.sample_z0(mu, logvar)

        prediction = model.predict_z1_from_z0(z0, n_steps=n_steps, guidance_scale=guidance_scale)
        return prediction.cpu()

In [None]:
# Sample and visualize retrieval results
print("Visualizing retrieval examples...")
model.eval()
n_val = len(val_dataset)
for _sample in range(5):
    # Draw from the whole validation split so the index is always in range,
    # whatever the split size.
    idx = np.random.randint(0, n_val)
    caption_embd = val_dataset[idx][0]  # text embedding of the caption
    caption_text = val_caption_text[idx]
    # Row of the matching image in val_img_file / val_img_embd
    # (assumes one image per caption — see how val_label is built).
    gt_index = val_label[idx]

    # Predict
    pred_embds = predict_embeddings(
        model,
        caption_embd.to(CFG["DEVICE"]),
        n_steps=CFG["INFERENCE_STEPS"],
        guidance_scale=CFG["GUIDANCE_SCALE"],
    ).squeeze(0)

    visualize_retrieval(
        pred_embds,
        gt_index,
        val_img_file,
        caption_text,
        val_img_embd,
        k=5,
        dataset_path="data/train/train",
    )

## Submission

In [None]:
print("Generating submission file...")
test_data = load_data("data/test/test/test.clean.npz")
test_embds = torch.from_numpy(test_data['captions/embeddings']).float()

# Predict in consecutive chunks of BS captions, then stitch the chunks back in order.
BS = 64
all_preds = [
    predict_embeddings(
        model,
        test_embds[start:start + BS],
        n_steps=CFG["INFERENCE_STEPS"],
        guidance_scale=CFG["GUIDANCE_SCALE"],
        device=CFG["DEVICE"],
    )
    for start in range(0, len(test_embds), BS)
]

pred_embds = torch.cat(all_preds, dim=0)
print(f"Predicted test embeddings shape: {pred_embds.shape}")

generate_submission(test_data['captions/ids'], pred_embds, 'submission.csv')
print("submission.csv generated successfully.")