## Notebook for the Crossflow Paper

### Import Libraries

In [None]:
# Setup: make the data and the helper code available to the notebook.
# Data and checkpoints are later read from / written to "drive/MyDrive/data/...",
# a path relative to the working directory (assumed to be /content, the Colab
# default — confirm if running elsewhere).
# Alternative to Google Drive: uncomment the lines below to download and unzip
# the archives into ./data instead (presumably train.zip / test.zip — verify the IDs).
#!mkdir data
#!gdown 1CVAQDuPOiwm8h9LJ8a_oOs6zOWS6EgkB
#!gdown 1ykZ9fjTxUwdiEwqagoYZiMcD5aG-7rHe
#!unzip -o test.zip -d data
#!unzip -o train.zip -d data
from google.colab import drive
drive.mount('/content/drive')
# Provides challenge.src.common (load_data, prepare_train_data, generate_submission).
!git clone https://github.com/Mamiglia/challenge.git

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from pathlib import Path
from tqdm import tqdm

from challenge.src.common import load_data, prepare_train_data, generate_submission

### Create Neural Network Architectures

In [None]:
import math
class TimeEmbedding(nn.Module):
    """Sinusoidal embedding of a per-sample scalar time.

    Uses ``dim // 2`` geometrically spaced frequencies ``10000 ** (-k / (dim // 2))``
    and concatenates their sines and cosines. For an odd ``dim`` one zero column
    is appended, so the output always has exactly ``dim`` features. The
    ``nn.Linear(dim, ...)`` layers that consume this embedding rely on that.

    Args:
        dim: output feature size; must be >= 2.

    Raises:
        ValueError: if ``dim < 2``. There would be no frequencies, and the
            frequency spacing would divide by zero.
    """

    def __init__(self, dim):
        super().__init__()
        if dim < 2:
            raise ValueError(f"TimeEmbedding dim must be >= 2, got {dim}")
        self.dim = dim

    def forward(self, t):
        """Embed ``t`` of shape (B,) into a tensor of shape (B, dim).

        Callers in this notebook pass t in [0, 1].
        """
        half_dim = self.dim // 2
        freqs = torch.exp(-torch.arange(half_dim, device=t.device) * (math.log(10000) / half_dim))
        angles = t.unsqueeze(1) * freqs.unsqueeze(0)  # (B, half_dim)
        emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)  # (B, 2 * half_dim)
        if self.dim % 2:
            # Odd dim: pad one zero feature so the width matches self.dim.
            emb = F.pad(emb, (0, 1))
        return emb

In [None]:
class LatentFlowMLP(nn.Module):
    """MLP velocity field for text->image latent flow matching.

    The text embedding is mapped into image space by a bias-free linear layer
    (``proj``, intended to be overwritten by a Procrustes initialisation). A
    sinusoidal time embedding, projected to ``img_dim``, is added to it. The sum
    is passed through an MLP whose output is a velocity in image space.

    Args:
        text_dim: size of the text embeddings.
        img_dim: size of the image embeddings (and of the predicted velocity).
        hidden_dim: width of the hidden MLP layers.
        num_layers: number of ``nn.Linear`` layers in the MLP; must be >= 1.
        time_emb_dim: size of the sinusoidal time embedding.

    Raises:
        ValueError: if ``num_layers < 1``.
    """

    def __init__(self, text_dim, img_dim, hidden_dim=1024, num_layers=3, time_emb_dim=32):
        super().__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        self.text_dim = text_dim
        self.img_dim = img_dim
        self.hidden_dim = hidden_dim
        self.time_emb_dim = time_emb_dim

        # Procrustes linear projection (initialized later)
        self.proj = nn.Linear(text_dim, img_dim, bias=False)

        # Time embedding. Its projection is added to z_proj, so it must live in
        # image space (img_dim), not hidden_dim. Otherwise the addition in
        # forward() fails whenever img_dim != hidden_dim.
        self.time_emb = TimeEmbedding(time_emb_dim)
        self.time_proj = nn.Linear(time_emb_dim, img_dim)

        # MLP: img_dim -> hidden_dim (x num_layers-1, with GELU) -> img_dim
        layers = []
        in_dim = img_dim  # after projection
        for _ in range(num_layers - 1):
            layers.append(nn.Linear(in_dim, hidden_dim))
            layers.append(nn.GELU())
            in_dim = hidden_dim
        # Use in_dim (not hidden_dim) so num_layers == 1 maps img_dim -> img_dim.
        layers.append(nn.Linear(in_dim, img_dim))
        self.mlp = nn.Sequential(*layers)

    def forward(self, t, z_text):
        """Return ``(velocity, z_proj)`` for times ``t`` (B,) and text embeddings ``z_text``."""
        # 1. Project text embeddings into image space
        z_proj = self.proj(z_text)

        # 2. Time embedding, lifted to image space
        t_emb = self.time_proj(self.time_emb(t))

        # 3. Condition on time by addition (both are (B, img_dim))
        h = z_proj + t_emb

        # 4. Velocity in image space
        v = self.mlp(h)  # (B, img_dim)
        return v, z_proj


In [None]:
class ResidualFlowTransformer(nn.Module):
    """Transformer-based velocity field for text->image latent flow matching.

    Pipeline: the text embedding is mapped to image space by ``proj`` (z_proj),
    then to the transformer width by ``input_proj``. A projected sinusoidal time
    embedding is added, the sum runs through a ``TransformerEncoder``, and the
    result is mapped back to image space and scaled by ``temperature``.

    Each sample is fed to the encoder as its own length-1 sequence (default
    sequence-first layout), so there is no attention across batch samples.

    ``forward(t, z_text)`` returns ``(velocity, z_proj)``.
    """

    def __init__(self, text_dim, img_dim, hidden_dim=512, num_layers=4, nhead=8, time_emb_dim=128, temperature=1.0):
        super().__init__()
        self.text_dim, self.img_dim = text_dim, img_dim
        self.hidden_dim = hidden_dim
        self.temperature = temperature

        # Bias-free text -> image map; meant to be overwritten by Procrustes init.
        self.proj = nn.Linear(text_dim, img_dim, bias=False)

        # Sinusoidal time features, lifted to the transformer width.
        self.time_emb = TimeEmbedding(time_emb_dim)
        self.time_proj = nn.Linear(time_emb_dim, hidden_dim)

        # Image space -> transformer width.
        self.input_proj = nn.Linear(img_dim, hidden_dim)

        # Encoder stack built from one template layer.
        layer_template = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=nhead,
            dim_feedforward=4 * hidden_dim,
            activation='gelu',
        )
        self.transformer = nn.TransformerEncoder(layer_template, num_layers=num_layers)

        # Transformer width -> image-space velocity.
        self.output_proj = nn.Linear(hidden_dim, img_dim)

    def forward(self, t, z_text):
        z_proj = self.proj(z_text)
        time_features = self.time_proj(self.time_emb(t))
        hidden = self.input_proj(z_proj) + time_features
        # (B, H) -> (1, B, H): one length-1 sequence per sample, then back to (B, H).
        hidden = self.transformer(hidden.unsqueeze(0)).squeeze(0)
        velocity = self.output_proj(hidden) * self.temperature
        return velocity, z_proj


### Flow Matching Loss and Euler Inference

In [None]:
def flow_matching_loss(v_net, z_text, z_img):
    """Flow-matching regression loss for a text->image latent flow.

    Samples one time ``t ~ U[0, 1)`` per example and runs the velocity network
    once. The prediction is regressed onto the constant velocity of the straight
    path from the projected text embedding ``z_proj`` to ``z_img``:
    ``d/dt[(1 - t) * z_proj + t * z_img] = z_img - z_proj``.

    Args:
        v_net: callable ``v_net(t, z_text) -> (velocity, z_proj)``.
        z_text: batch of text embeddings (first dimension is the batch).
        z_img: matching batch of target image embeddings, same shape as ``z_proj``.

    Returns:
        Scalar mean-squared-error tensor.
    """
    t = torch.rand(z_text.size(0), device=z_text.device)

    # Exactly one forward pass. v_pred and z_proj must come from the same pass
    # (dropout may be active in train mode), and a second pass would double the cost.
    v_pred, z_proj = v_net(t, z_text)

    # Target: velocity of the linear interpolation path. The interpolant itself
    # is not needed, because the networks here are conditioned on z_text, not on z_t.
    target = z_img - z_proj
    return F.mse_loss(v_pred, target)

In [None]:
def integrate_euler(z_text, v_net, steps=50):
    """Integrate the learned velocity field with explicit Euler from t=0 to t=1.

    Starts at ``v_net.proj(z_text)`` (the text embedding projected into image
    space). It takes ``steps`` equal steps of size ``1 / steps``, evaluating the
    velocity at ``t = 0, 1/steps, ..., (steps - 1)/steps``.

    Args:
        z_text: text embeddings. They are moved to the device of
            ``v_net.proj.weight`` (if it has one), so CPU inputs work with a GPU model.
        v_net: module exposing ``proj`` and ``forward(t, z_text) -> (velocity, z_proj)``.
        steps: number of Euler steps; must be >= 1.

    Returns:
        The integrated embeddings, on the same device as the input ``z_text``.

    Raises:
        ValueError: if ``steps < 1``.
    """
    if steps < 1:
        raise ValueError(f"steps must be >= 1, got {steps}")

    input_device = z_text.device
    proj_weight = getattr(v_net.proj, "weight", None)
    if proj_weight is not None:
        z_text = z_text.to(proj_weight.device)

    dt = 1.0 / steps
    z = v_net.proj(z_text)  # start in image space
    for i in range(steps):
        t = torch.full((z.size(0),), i * dt, device=z.device)
        dz, _ = v_net(t, z_text)
        z = z + dz * dt
    return z.to(input_device)

### Procrustes Initialization

In [None]:
# ====== Procrustes initialization ======
def procrustes_init(text_embs, img_embs):
    """
    Solve the orthogonal Procrustes problem between paired embedding sets.

    text_embs: (N, d_text)
    img_embs:  (N, d_img)
    returns: weight matrix (d_img, d_text), laid out for direct copying into
             the weight of an nn.Linear(d_text, d_img)

    Both sets are mean-centred first; the returned map has orthonormal columns
    (a rotation when d_text == d_img).
    """
    text_centered = text_embs - text_embs.mean(dim=0, keepdim=True)
    img_centered = img_embs - img_embs.mean(dim=0, keepdim=True)

    # Cross-covariance (d_text, d_img); its thin SVD gives the optimal map U V^T.
    cross_cov = text_centered.transpose(0, 1) @ img_centered
    left, _singular_values, right_t = torch.linalg.svd(cross_cov, full_matrices=False)
    mapping = torch.matmul(left, right_t)  # (d_text, d_img)

    # nn.Linear stores (out_features, in_features), hence the transpose.
    return mapping.transpose(0, 1)

def apply_procrustes_init_to_final(model, text_sample, img_sample):
    """Initialise the model's text->image projection with the Procrustes solution.

    Fits ``procrustes_init`` on the given paired samples, then copies the result
    into the first ``nn.Linear`` submodule whose own attribute name is exactly
    ``proj`` and whose weight shape matches. Both ``LatentFlowMLP`` and
    ``ResidualFlowTransformer`` define such a layer. The match is on the name,
    not the model class. A bare suffix test on "proj" would also hit
    ``time_proj`` / ``input_proj`` / ``output_proj``.

    Args:
        model: module to initialise (modified in place).
        text_sample: (N, d_text) text embeddings.
        img_sample: (N, d_img) paired image embeddings.

    Returns:
        The same ``model`` object. A warning is printed if no layer matched.
    """
    with torch.no_grad():
        W = procrustes_init(text_embs=text_sample, img_embs=img_sample)

        applied = False
        for name, m in model.named_modules():
            if not isinstance(m, nn.Linear) or name.rsplit(".", 1)[-1] != "proj":
                continue
            if m.weight.shape == W.shape:
                # copy_ handles any device/dtype difference between W and the weight.
                m.weight.copy_(W)
                applied = True
                print(f"Procrustes init applied to '{name}' {tuple(W.shape)}")
                break
        if not applied:
            print("⚠️ Warning: Could not find matching layer for Procrustes init")
    return model


### Load Data

In [None]:
# Notebook roadmap, as listed by the author:
#   3. Crossflow  4. Data Augmentation  5. Zero Shot Stitching  6. Diffusion Priors

# ---- Configuration ----
EPOCHS = 60
BATCH_SIZE = 64
LR = 0.001
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---- Load data ----
train_data = load_data("drive/MyDrive/data/train/train.npz")
X, y, label = prepare_train_data(train_data)
DATASET_SIZE = len(X)

# ---- Train / validation split ----
# Deterministic 90/10 split: the first 90% of rows train, the rest validate.
# The validation set only measures generalisation and is optional.
# NOTE(review): no shuffle before splitting — if rows are ordered (e.g. grouped
# by image), the validation set may not be representative; confirm ordering.
n_train = int(0.9 * len(X))
TRAIN_SPLIT = torch.arange(len(X)) < n_train  # bool mask, True for training rows
X_train, y_train = X[TRAIN_SPLIT], y[TRAIN_SPLIT]
X_val, y_val = X[~TRAIN_SPLIT], y[~TRAIN_SPLIT]

train_dataset, val_dataset = TensorDataset(X_train, y_train), TensorDataset(X_val, y_val)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

# Shape / batch-size sanity check (displayed as the cell output).
y_train.shape, X_train.shape, train_loader.batch_size, val_loader.batch_size

### Training and Hyperparameter Optimization

In [None]:
def training_crossflow(model, train_loader, val_loader, device, epochs, lr, MODEL_PATH,
             use_procrustes_init=True, procrustes_subset=10000, temperature=0.07):
    """Train a crossflow velocity network with the flow-matching objective.

    Steps:
      1. Optionally initialise the model's text->image ``proj`` layer with an
         orthogonal Procrustes map, fitted on up to ``procrustes_subset``
         training pairs.
      2. Each epoch, minimise ``flow_matching_loss`` with AdamW (weight decay
         5e-3) and gradient-norm clipping at 1.0.
      3. Compute the mean validation loss. Whenever it improves, save
         ``model.state_dict()`` to ``MODEL_PATH``. The loss samples t at random,
         so the validation number is a noisy estimate.

    Args:
        model: velocity network, called as ``model(t, z_text)`` by the loss.
        train_loader: yields ``(text_emb, img_emb)`` training batches.
        val_loader: yields ``(text_emb, img_emb)`` validation batches.
        device: device the model and batches are moved to.
        epochs: number of training epochs.
        lr: AdamW learning rate.
        MODEL_PATH: where the best checkpoint is written; parent dirs are created.
        use_procrustes_init: whether to run step 1.
        procrustes_subset: maximum number of pairs used for the Procrustes fit.
        temperature: unused by this function; accepted for backward compatibility.

    Returns:
        The model after the last epoch. This is not necessarily the best-validation
        model; that one is the checkpoint stored at ``MODEL_PATH``.

    Raises:
        ValueError: if either loader yields no batches (the per-epoch means
            would otherwise divide by zero).
    """
    if len(train_loader) == 0 or len(val_loader) == 0:
        raise ValueError("train_loader and val_loader must each yield at least one batch")

    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=5e-3)
    best_val_loss = float('inf')

    # --- Optional: Procrustes initialization ---
    if use_procrustes_init:
        print("Computing Procrustes initialization...")
        text_list, img_list = [], []
        n_collected = 0  # running count; avoids re-summing every collected batch
        for text_batch, img_batch in train_loader:
            text_list.append(text_batch.cpu())
            img_list.append(img_batch.cpu())
            n_collected += text_batch.shape[0]
            if n_collected >= procrustes_subset:
                break
        text_sample = torch.cat(text_list, dim=0)[:procrustes_subset]
        img_sample = torch.cat(img_list, dim=0)[:procrustes_subset]
        model = apply_procrustes_init_to_final(model, text_sample, img_sample)

    # --- Training ---
    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        for X_batch, y_batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            optimizer.zero_grad()

            loss = flow_matching_loss(model, X_batch, y_batch)
            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            train_loss += loss.item()

        train_loss /= len(train_loader)

        # --- Validation ---
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for X_batch, y_batch in val_loader:
                X_batch, y_batch = X_batch.to(device), y_batch.to(device)
                val_loss += flow_matching_loss(model, X_batch, y_batch).item()

        val_loss /= len(val_loader)
        print(f"Epoch {epoch+1}: Train Loss = {train_loss:.6f}, Val Loss = {val_loss:.6f}")

        # --- Save best model ---
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            Path(MODEL_PATH).parent.mkdir(parents=True, exist_ok=True)
            torch.save(model.state_dict(), MODEL_PATH)
            print(f"  ✓ Saved best model (val_loss={val_loss:.6f})")

    return model


In [None]:

# Size of the sinusoidal time embedding. This is a dedicated constant so that
# changing BATCH_SIZE does not silently change the architecture (and break
# loading of previously saved checkpoints).
TIME_EMB_DIM = 64

# Alternative velocity network — uncomment to train the transformer instead of the MLP:
#model = ResidualFlowTransformer(text_dim=1024, img_dim=1536, time_emb_dim=TIME_EMB_DIM, temperature=0.2).to(DEVICE)
model = LatentFlowMLP(text_dim=1024, img_dim=1536, time_emb_dim=TIME_EMB_DIM).to(DEVICE)

# Best-validation checkpoint is written here by training_crossflow.
MODEL_PATH = "drive/MyDrive/data/models/crossflow.pth"
model = training_crossflow(
    model,
    train_loader,
    val_loader,
    DEVICE,
    EPOCHS,
    LR,
    MODEL_PATH,
)

### Inference

In [None]:

test_data = load_data("drive/MyDrive/data/test/test.clean.npz")

# Caption embeddings as a float32 CPU tensor.
test_embds = torch.from_numpy(test_data['captions/embeddings']).float()

# Predict with the best-validation checkpoint saved during training, rather than
# the last-epoch weights returned by training_crossflow.
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.eval()

with torch.no_grad():
    # The model's forward needs (t, z_text) and returns a velocity. The image
    # embedding prediction is the Euler-integrated flow, computed on the
    # model's device and brought back to CPU for the submission.
    img_pred = integrate_euler(test_embds.to(DEVICE), model, steps=50)
    pred_embds = img_pred.cpu()

submission = generate_submission(test_data['captions/ids'], pred_embds, 'crossflow_submission.csv')
print(f"Predictions generated with checkpoint: {MODEL_PATH}")