## Notebook for the Crossflow Paper

### Import Libraries

In [1]:
# Optional one-time setup (left disabled): create a data directory, download the
# train/test archives with gdown and unpack them into ./data.
#!mkdir data
#!gdown 1CVAQDuPOiwm8h9LJ8a_oOs6zOWS6EgkB
#!gdown 1ykZ9fjTxUwdiEwqagoYZiMcD5aG-7rHe
#!unzip -o test.zip -d data
#!unzip -o train.zip -d data
# Mount Google Drive; later cells read and write files under drive/MyDrive/data.
from google.colab import drive
drive.mount('/content/drive')
# Clone the repository that provides the `challenge` package imported in the next cell.
!git clone https://github.com/Mamiglia/challenge.git

Mounted at /content/drive
Cloning into 'challenge'...
remote: Enumerating objects: 98, done.[K
remote: Counting objects: 100% (98/98), done.[K
remote: Compressing objects: 100% (69/69), done.[K
remote: Total 98 (delta 39), reused 72 (delta 26), pack-reused 0 (from 0)[K
Receiving objects: 100% (98/98), 21.03 MiB | 15.85 MiB/s, done.
Resolving deltas: 100% (39/39), done.


In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from pathlib import Path
from tqdm import tqdm

from challenge.src.common import load_data, prepare_train_data, generate_submission

### Create Neural Network Architectures

- VAE encoder (1024) -> latent space (1536) -> VAE decoder (1024): the VAE is trained jointly with the CrossFlow network.
- CrossFlow takes the VAE latent as its input:
  input (1024) -> VAE encoder (1536) -> CrossFlow input -> CrossFlow transformer -> CrossFlow output (1536)
- Text embeddings are the input to the VAE; image embeddings are the targets of the CrossFlow network.
![image.png](attachment:image.png)

In [3]:
class MLPBlock(nn.Module):
    """Feed-forward unit: LayerNorm -> Linear -> activation -> Dropout.

    Args:
        in_dim: size of the last dimension of the input.
        out_dim: size of the last dimension of the output.
        dropout: dropout probability applied after the activation.
        activation: zero-argument callable (e.g. a module class) that builds
            the activation layer.
    """
    def __init__(self, in_dim, out_dim, dropout=0.0, activation=nn.GELU):
        super().__init__()
        stages = [
            nn.LayerNorm(in_dim),
            nn.Linear(in_dim, out_dim),
            activation(),
            nn.Dropout(dropout),
        ]
        # Registered as `block`, so parameter names are `block.<index>.*`.
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the four stages to `x` and return the result."""
        out = self.block(x)
        return out


class ContextMapperVAE(nn.Module):
    """
    MLP variational autoencoder over context embeddings.

    Constructor arguments:
    - input_dim: width of the context embeddings (also the reconstruction width)
    - latent_dim: width of the sampled latent
    - num_layers: how many MLPBlock layers the encoder and the decoder each stack
    - hidden_dim: width of those hidden layers
    - dropout: dropout probability used inside every MLPBlock
    """
    def __init__(self, input_dim, latent_dim, num_layers=2, hidden_dim=512, dropout=0.1):
        super().__init__()
        # Encoder: input_dim -> hidden_dim -> ... -> hidden_dim, then a linear
        # head that emits mean and log-std side by side.
        enc_widths = [input_dim] + [hidden_dim] * num_layers
        self.encoder_backbone = nn.Sequential(*[
            MLPBlock(w_in, w_out, dropout)
            for w_in, w_out in zip(enc_widths[:-1], enc_widths[1:])
        ])
        self.encoder_head = nn.Linear(hidden_dim, latent_dim * 2)

        # Decoder: latent_dim -> hidden_dim -> ... -> hidden_dim -> input_dim.
        dec_widths = [latent_dim] + [hidden_dim] * num_layers
        self.decoder_backbone = nn.Sequential(*[
            MLPBlock(w_in, w_out, dropout)
            for w_in, w_out in zip(dec_widths[:-1], dec_widths[1:])
        ])
        self.decoder_head = nn.Linear(hidden_dim, input_dim)

    def forward(self, x):
        """Encode `x`, draw a latent sample, and decode it.

        Returns the tuple (z0, mu, log_sigma, x_recon). A fresh Gaussian sample
        is drawn on every call, regardless of train/eval mode.
        """
        # First half of the head output is the mean, second half the log-std.
        mu, log_sigma = self.encoder_head(self.encoder_backbone(x)).chunk(2, dim=-1)

        # Reparameterised sample: z0 = mu + sigma * noise.
        noise = torch.randn_like(mu)
        z0 = mu + torch.exp(log_sigma) * noise

        x_recon = self.decoder_head(self.decoder_backbone(z0))
        return z0, mu, log_sigma, x_recon

    def kl_loss(self, mu, log_sigma):
        """KL(q(z|x) || N(0, 1)): summed over the last dimension, averaged over the rest."""
        log_var = 2 * log_sigma
        per_sample = torch.sum(1 + log_var - mu.pow(2) - torch.exp(log_var), dim=-1)
        return (-0.5 * per_sample).mean()


In [4]:
class TransformerFlow(nn.Module):
    """
    Configurable Transformer velocity model v(z_t, t) for flow matching.

    Constructor arguments:
    - latent_dim: dimension of z_t (used as the transformer d_model, so it must
      be divisible by num_heads)
    - num_layers: number of transformer encoder layers
    - num_heads: number of attention heads
    - ff_dim: feed-forward hidden dimension
    - dropout: dropout in transformer layers

    Each latent vector is treated as a sequence of length one.
    """
    def __init__(self, latent_dim, num_layers=4, num_heads=4, ff_dim=512, dropout=0.1):
        super().__init__()
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=latent_dim,
            nhead=num_heads,
            dim_feedforward=ff_dim,
            dropout=dropout,
            activation='gelu',
            batch_first=True,
            norm_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.output = nn.Linear(latent_dim, latent_dim)

    def forward(self, z_t, t):
        """Predict the velocity at latent `z_t` and time `t`.

        Args:
            z_t: tensor of shape (B, latent_dim).
            t: time values as a tensor of shape (B,), (B, 1), a 0-dim tensor or
               a Python number (scalars are broadcast over the batch).

        Returns:
            Tensor of shape (B, latent_dim).
        """
        t = torch.as_tensor(t, device=z_t.device)
        # Add sinusoidal time embedding; cast so the sum keeps z_t's dtype.
        t_embed = self.time_embedding(t, z_t.size(-1)).to(dtype=z_t.dtype)
        x = z_t + t_embed
        x = x.unsqueeze(1)  # transformer expects (batch, seq, feature); seq length is 1
        x = self.transformer(x)
        return self.output(x.squeeze(1))

    def time_embedding(self, t, dim):
        """Sinusoidal embedding of `t` with exactly `dim` features.

        Layout is [sin(t * f_0..f_{h-1}), cos(t * f_0..f_{h-1})] with h = dim // 2
        and frequencies spaced geometrically from 1 down to 1/10000. For odd
        `dim` a trailing zero column is appended so the result can be added to a
        `dim`-wide latent.

        Args:
            t: tensor with one time value per row; any shape is flattened to (N,).
            dim: number of embedding features.

        Returns:
            Tensor of shape (N, dim).
        """
        half_dim = dim // 2
        # max(..., 1) avoids a division by zero (NaN frequencies) when half_dim == 1.
        scale = -torch.log(torch.tensor(10000.0)) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=t.device) * scale)
        angles = t.reshape(-1, 1) * freqs.unsqueeze(0)
        emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
        if dim % 2 == 1:
            # sin/cos pairs only cover 2 * half_dim features; pad the last one.
            emb = torch.cat([emb, emb.new_zeros(emb.size(0), 1)], dim=-1)
        return emb


In [5]:
@torch.no_grad()
def integrate_flow(flow, z0, n_steps=20):
    """Integrate dz/dt = flow(z, t) from t=0 to t=1 with explicit Euler steps.

    The velocity is evaluated at the left endpoint of each of the `n_steps`
    equal sub-intervals, i.e. at t_k = k / n_steps for k = 0 .. n_steps - 1, so
    the evaluation times are consistent with the step size dt = 1 / n_steps.

    Args:
        flow: callable `flow(z, t)` taking z of shape (B, ...) and t of shape (B,)
            and returning a velocity shaped like z.
        z0: starting state; it is not modified.
        n_steps: number of Euler steps (positive integer).

    Returns:
        The state at t=1, same shape as `z0`.

    Raises:
        ValueError: if `n_steps` is smaller than 1.
    """
    if n_steps < 1:
        raise ValueError(f"n_steps must be a positive integer, got {n_steps}")

    z = z0.clone()
    dt = 1.0 / n_steps
    batch_size = z.size(0)
    for k in range(n_steps):
        # One time value per batch row, on the same device/dtype as the state.
        t = torch.full((batch_size,), k * dt, device=z.device, dtype=z.dtype)
        z = z + dt * flow(z, t)
    return z


In [6]:
import torch
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np

def train_epoch(train_loader, vae, flow, optimizer,
                lambda_kl=1e-2,
                temperature=0.07, queue_size=4098, device="cuda", epoch=0,
                criterion=None):
    """Run one training epoch of the joint VAE + flow-matching objective.

    Per batch the loss is `L_FM + L_Enc + lambda_kl * L_KL`, where
    - L_FM  is the MSE between the predicted velocity at a random time t and
      the straight-line velocity z1 - z0,
    - L_Enc is the queue-based InfoNCE loss between the VAE latent z0 and the
      image embedding z1,
    - L_KL  is the VAE KL term.
    The VAE reconstruction output is not part of the loss.

    Args:
        train_loader: iterable of (context, image) batches.
        vae: module whose forward returns (z0, mu, log_sigma, recon) and that
            exposes `kl_loss(mu, log_sigma)`.
        flow: velocity model called as `flow(z_t, t)` with t of shape (B,).
        optimizer: optimizer over the parameters of `vae` and `flow`.
        lambda_kl: weight of the KL term.
        temperature: InfoNCE temperature, used only when `criterion` is None.
        queue_size: InfoNCE queue length, used only when `criterion` is None.
            NOTE(review): the default 4098 is kept for compatibility; 4096 may
            have been intended - confirm.
        device: device the batches are moved to.
        epoch: epoch index, shown in the progress bar only.
        criterion: optional contrastive loss called as `criterion(z0, z1)` and
            exposing `_enqueue(keys=...)`. Pass the same instance on every call
            to keep its negative queue across epochs. When None, a fresh
            QueueInfoNCELoss is built from the first batch, sized to the image
            embedding width.

    Returns:
        Dict with the per-batch averages of "loss", "L_FM", "L_Enc", "L_KL"
        (all zeros if the loader yields no batches).
    """
    vae.train()
    flow.train()
    totals = {"loss": 0.0, "L_FM": 0.0, "L_Enc": 0.0, "L_KL": 0.0}
    n_batches = 0

    for X_batch, y_batch in tqdm(train_loader, desc=f"Epoch: {epoch}"):
        context = X_batch.to(device)
        image = y_batch.to(device)

        if criterion is None:
            # Size the queue from the data instead of assuming a fixed width.
            criterion = QueueInfoNCELoss(
                dim=image.size(-1), temperature=temperature, queue_size=queue_size
            ).to(device)

        z0, mu, log_sigma, _ = vae(context)
        z1 = image

        # Random point on the straight path between z0 (t=0) and z1 (t=1).
        t = torch.rand(z0.size(0), 1, device=device)
        z_t = (1 - t) * z0 + t * z1
        v_target = z1 - z0
        # squeeze(1) keeps shape (B,) even when the batch holds a single sample.
        v_pred = flow(z_t, t.squeeze(1))

        L_Enc = criterion(z0, z1)
        L_FM = F.mse_loss(v_pred, v_target)
        L_KL = vae.kl_loss(mu, log_sigma)

        loss = L_FM + L_Enc + lambda_kl * L_KL

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        totals["loss"] += loss.item()
        totals["L_FM"] += L_FM.item()
        totals["L_Enc"] += L_Enc.item()
        totals["L_KL"] += L_KL.item()
        n_batches += 1

        # Only after backward(): this batch's images become negatives for the
        # following batches. _enqueue normalises the keys itself.
        with torch.no_grad():
            criterion._enqueue(keys=image)

    n = max(n_batches, 1)
    return {name: total / n for name, total in totals.items()}


@torch.no_grad()
def validate_epoch(val_loader, vae, flow, device="cuda", n_steps=20):
    """Score predicted image embeddings against the true ones by cosine similarity.

    Both models are switched to eval mode. For every batch the context is
    encoded by `vae`, the resulting latent is pushed through `integrate_flow`,
    and the prediction is compared with the image embedding. If `vae` samples
    noise in its forward pass, the metrics vary from run to run.

    Args:
        val_loader: iterable of (context, image) batches.
        vae: module whose forward returns a tuple starting with the latent z0.
        flow: velocity model handed to `integrate_flow`.
        device: device the batches are moved to.
        n_steps: number of integration steps.

    Returns:
        Dict with "mean_cosine" (mean similarity) and "acc@0.8" / "acc@0.9"
        (fraction of samples whose similarity exceeds the threshold).
    """
    vae.eval()
    flow.eval()

    per_batch = []
    for X_batch, y_batch in tqdm(val_loader, desc="Validation"):
        context = X_batch.to(device)
        target = y_batch.to(device)

        # Only the sampled latent is needed from the VAE outputs.
        start = vae(context)[0]
        predicted = integrate_flow(flow, start, n_steps=n_steps)

        similarity = F.cosine_similarity(predicted, target, dim=-1)
        per_batch.append(similarity.cpu().numpy())

    similarities = np.concatenate(per_batch)
    return {
        "mean_cosine": np.mean(similarities),
        "acc@0.8": np.mean(similarities > 0.8),  # share of samples above 0.8
        "acc@0.9": np.mean(similarities > 0.9),
    }


In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class QueueInfoNCELoss(nn.Module):
    """
    One-directional (text → image) InfoNCE loss with a circular queue of negatives.

    The positive for each text embedding is the image embedding in the same
    row; the negatives are the L2-normalised image embeddings currently stored
    in the queue. The queue starts out filled with random unit vectors and is
    overwritten through `_enqueue`.

    Constructor arguments:
    - dim: embedding width
    - temperature: divisor applied to the cosine-similarity logits
    - queue_size: number of negatives kept (must be at least 1)
    - use_time_embedding: accepted for backward compatibility; not used
    """
    def __init__(self, dim, temperature=0.07, queue_size=4096, use_time_embedding=True):
        super().__init__()
        if queue_size < 1:
            raise ValueError(f"queue_size must be at least 1, got {queue_size}")
        self.temperature = temperature
        self.queue_size = queue_size

        # Single queue for image embeddings, plus the next write position.
        self.register_buffer("queue", F.normalize(torch.randn(queue_size, dim), dim=1))
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    # ------------------------------
    # Queue Management
    # ------------------------------
    @torch.no_grad()
    def _enqueue(self, keys):
        """Write image keys (B, dim) into the circular queue; call after backward().

        Keys are L2-normalised and written starting at the current pointer,
        wrapping around at the end of the queue. The result is the same as
        inserting the rows one at a time: an empty batch is a no-op, and when
        B exceeds the queue size only the last `queue_size` rows survive.
        """
        bsz = keys.shape[0]
        if bsz == 0:
            return
        keys = F.normalize(keys, dim=1).to(device=self.queue.device, dtype=self.queue.dtype)

        ptr = int(self.queue_ptr.item())
        if bsz > self.queue_size:
            # Earlier rows would be overwritten within this same call: skip
            # them and start where the surviving rows would have landed.
            ptr = (ptr + bsz - self.queue_size) % self.queue_size
            keys = keys[-self.queue_size:]
            bsz = self.queue_size

        # Modular indices handle the wrap-around; they are unique because bsz <= queue_size.
        slots = (ptr + torch.arange(bsz, device=self.queue.device)) % self.queue_size
        self.queue[slots] = keys
        self.queue_ptr[0] = (ptr + bsz) % self.queue_size

    # ------------------------------
    # Forward Pass
    # ------------------------------
    def forward(self, z_text, z_img):
        """
        z_text: (B, dim) predicted text→image latent
        z_img:  (B, dim) target image latent

        Returns the mean cross-entropy of picking the paired image (logit 0)
        against all queued negatives.
        """
        # Normalize embeddings
        z_text = F.normalize(z_text, dim=1)
        z_img = F.normalize(z_img, dim=1)

        # Positive logits: (B, 1)
        l_pos = torch.sum(z_text * z_img, dim=-1, keepdim=True)

        # Negatives: (B, queue_size). Use a snapshot so a later in-place
        # _enqueue cannot invalidate tensors saved for backward.
        negatives = self.queue.clone().detach()
        l_neg = torch.matmul(z_text, negatives.T)

        # Combine and scale by temperature; the positive sits at index 0.
        logits = torch.cat([l_pos, l_neg], dim=1) / self.temperature
        labels = torch.zeros(logits.size(0), dtype=torch.long, device=z_text.device)

        loss = F.cross_entropy(logits, labels)
        return loss


In [8]:
# ====== Procrustes initialization ======
def procrustes_init(text_embs, img_embs):
    """
    Orthogonal Procrustes map from text space to image space.

    text_embs: (N, d_text)
    img_embs:  (N, d_img)
    returns: weight matrix (d_img, d_text), laid out like an nn.Linear weight
    """
    # Remove the per-feature mean from each modality.
    text_centered = text_embs - text_embs.mean(dim=0, keepdim=True)
    img_centered = img_embs - img_embs.mean(dim=0, keepdim=True)

    # Thin SVD of the (d_text, d_img) cross-covariance; the singular values are discarded.
    cross_cov = text_centered.T @ img_centered
    left, _singular, right_t = torch.linalg.svd(cross_cov, full_matrices=False)

    # left @ right_t maps d_text -> d_img; transpose it for nn.Linear.
    rotation = left @ right_t
    return rotation.T


def apply_procrustes_init_to_final(model, text_sample, img_sample):
    """
    Copy a Procrustes solution into the final `output` linear layer of a model.

    The matrix from `procrustes_init(text_sample, img_sample)` is written into
    the first nn.Linear that satisfies all of: `model` is a TransformerFlow,
    the submodule name ends with "output", and the weight shape equals the
    matrix shape. If no layer qualifies, a warning is printed and the model is
    left untouched. The model is returned either way.
    """
    with torch.no_grad():
        W = procrustes_init(text_embs=text_sample, img_embs=img_sample)

        applied = False
        for name, m in model.named_modules():
            if not isinstance(m, nn.Linear):
                continue
            # Only the output projection of a TransformerFlow is a candidate.
            if not (isinstance(model, TransformerFlow) and name.endswith("output")):
                continue
            print(m.weight.shape, W.shape)
            if m.weight.shape != W.shape:
                continue
            m.weight.copy_(W)
            applied = True
            break

        if not applied:
            print("⚠️ Warning: Could not find matching layer for Procrustes init")
    return model


### Load Data

In [13]:
# Notes left by the authors (appears to be a numbered list of approaches - TODO confirm):
# 3. Crossflow
# 4. Data Augmentation
# 5. Zero Shot Stitching
# 6. Diffusion Priors
# Configuration
EPOCHS = 60
BATCH_SIZE = 128
# NOTE(review): LR is not read by the cells shown in this notebook; the
# optimizer cell passes its own lr=1e-4.
LR = 0.001
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load data
train_data = load_data("drive/MyDrive/data/train/train.npz")
# X is fed to the VAE as context and y is used as the image target in training;
# per the recorded cell output X_train is (112500, 1024) and y_train (112500, 1536).
X, y, label = prepare_train_data(train_data)
DATASET_SIZE = len(X)
# Split train/val
# This is done only to measure generalization capabilities, you don't have to
# use a validation set (though we encourage this)
# The split is positional: the first 90% of rows are train, the rest validation
# (no shuffling before the split).
n_train = int(0.9 * len(X))
TRAIN_SPLIT = torch.zeros(len(X), dtype=torch.bool)
TRAIN_SPLIT[:n_train] = 1
X_train, X_val = X[TRAIN_SPLIT], X[~TRAIN_SPLIT]
y_train, y_val = y[TRAIN_SPLIT], y[~TRAIN_SPLIT]


train_dataset = TensorDataset(X_train, y_train)
val_dataset = TensorDataset(X_val, y_val)

# Only the training loader shuffles; validation keeps dataset order.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)
# Displayed as the cell result: tensor shapes and loader batch sizes.
y_train.shape, X_train.shape, train_loader.batch_size, val_loader.batch_size

(125000,)
Train data: 125000 captions, 125000 images


(torch.Size([112500, 1536]), torch.Size([112500, 1024]), 128, 128)

In [14]:

# Text-side VAE: 1024-d context embeddings -> 1536-d latent.
vae = ContextMapperVAE(
    input_dim=1024, latent_dim=1536,
    num_layers=3, hidden_dim=1024, dropout=0.1
).to(DEVICE)

# Velocity model operating in the 1536-d latent space.
flow = TransformerFlow(
    latent_dim=1536, num_layers=6,
    num_heads=8, ff_dim=1024, dropout=0.1
).to(DEVICE)

# One optimizer over both networks so they are trained jointly.
optimizer = torch.optim.AdamW(
    list(vae.parameters()) + list(flow.parameters()),
    lr=1e-4, weight_decay=0.01
)

# The flag has its own name so it does not rebind the procrustes_init()
# function, which apply_procrustes_init_to_final() calls.
USE_PROCRUSTES_INIT = False
PROCRUSTES_SAMPLES = 20000  # number of (text, image) pairs used for the fit
if USE_PROCRUSTES_INIT:
    print("Computing Procrustes initialization...")
    text_list, img_list = [], []
    n_collected = 0
    # Loop variables are named so they do not overwrite the dataset tensors X and y.
    for text_batch, img_batch in train_loader:
        text_list.append(text_batch.cpu())
        img_list.append(img_batch.cpu())
        n_collected += text_batch.shape[0]
        if n_collected >= PROCRUSTES_SAMPLES:
            break
    text_sample = torch.cat(text_list, dim=0)[:PROCRUSTES_SAMPLES]
    img_sample = torch.cat(img_list, dim=0)[:PROCRUSTES_SAMPLES]
    flow = apply_procrustes_init_to_final(flow, text_sample, img_sample)

# Train, then validate, once per epoch and print both metric dicts.
for epoch in range(EPOCHS):
    train_metrics = train_epoch(train_loader, vae, flow, optimizer, device=DEVICE, epoch=epoch)
    val_metrics = validate_epoch(val_loader, vae, flow, device=DEVICE)
    print(f"\nEpoch {epoch+1}/{EPOCHS}")
    print(f"Train: {train_metrics}")
    print(f"Val: {val_metrics}")

Epoch: 0: 100%|██████████| 879/879 [00:44<00:00, 19.84it/s]
Validation: 100%|██████████| 98/98 [00:09<00:00, 10.15it/s]



Epoch 1/60
Train: {'loss': 9.143248772865268, 'L_FM': 1.0541024583055976, 'L_Enc': 7.68200744893635, 'L_KL': 40.71389145791463}
Val: {'mean_cosine': np.float32(0.48039296), 'acc@0.8': np.float64(0.0), 'acc@0.9': np.float64(0.0)}


Epoch: 1: 100%|██████████| 879/879 [00:44<00:00, 19.86it/s]
Validation: 100%|██████████| 98/98 [00:09<00:00, 10.11it/s]



Epoch 2/60
Train: {'loss': 8.60262702837739, 'L_FM': 0.7771246161341532, 'L_Enc': 7.287264568406974, 'L_KL': 53.82378841395807}
Val: {'mean_cosine': np.float32(0.53193843), 'acc@0.8': np.float64(0.0), 'acc@0.9': np.float64(0.0)}


Epoch: 2: 100%|██████████| 879/879 [00:44<00:00, 19.78it/s]
Validation: 100%|██████████| 98/98 [00:09<00:00, 10.09it/s]



Epoch 3/60
Train: {'loss': 8.440742568622541, 'L_FM': 0.6993255115477569, 'L_Enc': 7.158354883443509, 'L_KL': 58.30622033996929}
Val: {'mean_cosine': np.float32(0.5685683), 'acc@0.8': np.float64(0.0), 'acc@0.9': np.float64(0.0)}


Epoch: 3: 100%|██████████| 879/879 [00:44<00:00, 19.79it/s]
Validation: 100%|██████████| 98/98 [00:09<00:00, 10.06it/s]



Epoch 4/60
Train: {'loss': 8.335648655484562, 'L_FM': 0.6490890732939876, 'L_Enc': 7.073069471005558, 'L_KL': 61.3490133882247}
Val: {'mean_cosine': np.float32(0.5966239), 'acc@0.8': np.float64(0.0), 'acc@0.9': np.float64(0.0)}


Epoch: 4: 100%|██████████| 879/879 [00:44<00:00, 19.75it/s]
Validation: 100%|██████████| 98/98 [00:09<00:00, 10.04it/s]



Epoch 5/60
Train: {'loss': 8.259847978954294, 'L_FM': 0.6164084380668449, 'L_Enc': 7.003061036336679, 'L_KL': 64.03785351437513}
Val: {'mean_cosine': np.float32(0.61390096), 'acc@0.8': np.float64(0.0), 'acc@0.9': np.float64(0.0)}


Epoch: 5: 100%|██████████| 879/879 [00:44<00:00, 19.76it/s]
Validation: 100%|██████████| 98/98 [00:09<00:00, 10.05it/s]



Epoch 6/60
Train: {'loss': 8.196227139309133, 'L_FM': 0.5907615651720891, 'L_Enc': 6.941787271640678, 'L_KL': 66.36783254838234}
Val: {'mean_cosine': np.float32(0.63032216), 'acc@0.8': np.float64(0.0), 'acc@0.9': np.float64(0.0)}


Epoch: 6: 100%|██████████| 879/879 [00:44<00:00, 19.74it/s]
Validation: 100%|██████████| 98/98 [00:09<00:00, 10.05it/s]



Epoch 7/60
Train: {'loss': 8.143728085345376, 'L_FM': 0.5698591362223012, 'L_Enc': 6.888723041959726, 'L_KL': 68.51459301806419}
Val: {'mean_cosine': np.float32(0.64150774), 'acc@0.8': np.float64(0.00032), 'acc@0.9': np.float64(0.0)}


Epoch: 7: 100%|██████████| 879/879 [00:44<00:00, 19.76it/s]
Validation: 100%|██████████| 98/98 [00:09<00:00, 10.05it/s]



Epoch 8/60
Train: {'loss': 8.0970810509379, 'L_FM': 0.5535012581920732, 'L_Enc': 6.839859079853531, 'L_KL': 70.37207347188522}
Val: {'mean_cosine': np.float32(0.6497806), 'acc@0.8': np.float64(0.00112), 'acc@0.9': np.float64(0.0)}


Epoch: 8: 100%|██████████| 879/879 [00:44<00:00, 19.77it/s]
Validation: 100%|██████████| 98/98 [00:09<00:00, 10.02it/s]



Epoch 9/60
Train: {'loss': 8.056312880554026, 'L_FM': 0.5402128165153377, 'L_Enc': 6.794677210342355, 'L_KL': 72.1422865415189}
Val: {'mean_cosine': np.float32(0.65958875), 'acc@0.8': np.float64(0.00256), 'acc@0.9': np.float64(0.0)}


Epoch: 9: 100%|██████████| 879/879 [00:44<00:00, 19.77it/s]
Validation: 100%|██████████| 98/98 [00:09<00:00, 10.05it/s]



Epoch 10/60
Train: {'loss': 8.016995757648697, 'L_FM': 0.5271031620559432, 'L_Enc': 6.752683394326827, 'L_KL': 73.72092118333767}
Val: {'mean_cosine': np.float32(0.66612226), 'acc@0.8': np.float64(0.00448), 'acc@0.9': np.float64(0.0)}


Epoch: 10: 100%|██████████| 879/879 [00:44<00:00, 19.80it/s]
Validation: 100%|██████████| 98/98 [00:09<00:00, 10.02it/s]



Epoch 11/60
Train: {'loss': 7.983970317688683, 'L_FM': 0.5161192126580608, 'L_Enc': 6.716485377193447, 'L_KL': 75.13657474626318}
Val: {'mean_cosine': np.float32(0.6712961), 'acc@0.8': np.float64(0.00672), 'acc@0.9': np.float64(0.0)}


Epoch: 11: 100%|██████████| 879/879 [00:44<00:00, 19.77it/s]
Validation: 100%|██████████| 98/98 [00:09<00:00, 10.04it/s]



Epoch 12/60
Train: {'loss': 7.956691023161523, 'L_FM': 0.5083538368429069, 'L_Enc': 6.6831863737486055, 'L_KL': 76.51508363413458}
Val: {'mean_cosine': np.float32(0.6760038), 'acc@0.8': np.float64(0.0084), 'acc@0.9': np.float64(0.0)}


Epoch: 12: 100%|██████████| 879/879 [00:44<00:00, 19.82it/s]
Validation: 100%|██████████| 98/98 [00:09<00:00, 10.04it/s]



Epoch 13/60
Train: {'loss': 7.928457873673162, 'L_FM': 0.5003014675450678, 'L_Enc': 6.650388831571507, 'L_KL': 77.77675948072483}
Val: {'mean_cosine': np.float32(0.6811767), 'acc@0.8': np.float64(0.012), 'acc@0.9': np.float64(0.0)}


Epoch: 13: 100%|██████████| 879/879 [00:44<00:00, 19.80it/s]
Validation: 100%|██████████| 98/98 [00:09<00:00, 10.04it/s]



Epoch 14/60
Train: {'loss': 7.904451550341575, 'L_FM': 0.4936102330006566, 'L_Enc': 6.621774338753693, 'L_KL': 78.90670003880142}
Val: {'mean_cosine': np.float32(0.6838866), 'acc@0.8': np.float64(0.01352), 'acc@0.9': np.float64(0.0)}


Epoch: 14: 100%|██████████| 879/879 [00:44<00:00, 19.77it/s]
Validation: 100%|██████████| 98/98 [00:09<00:00, 10.03it/s]



Epoch 15/60
Train: {'loss': 7.8818564078642375, 'L_FM': 0.486890540018559, 'L_Enc': 6.594749437121672, 'L_KL': 80.02164633689073}
Val: {'mean_cosine': np.float32(0.6878283), 'acc@0.8': np.float64(0.01552), 'acc@0.9': np.float64(0.0)}


Epoch: 15: 100%|██████████| 879/879 [00:44<00:00, 19.80it/s]
Validation: 100%|██████████| 98/98 [00:09<00:00, 10.04it/s]



Epoch 16/60
Train: {'loss': 7.859283630622803, 'L_FM': 0.4802339755091922, 'L_Enc': 6.568176665539356, 'L_KL': 81.0873015298507}
Val: {'mean_cosine': np.float32(0.69098634), 'acc@0.8': np.float64(0.02096), 'acc@0.9': np.float64(0.0)}


Epoch: 16: 100%|██████████| 879/879 [00:44<00:00, 19.83it/s]
Validation: 100%|██████████| 98/98 [00:09<00:00, 10.06it/s]



Epoch 17/60
Train: {'loss': 7.8404501908468305, 'L_FM': 0.4753021180155063, 'L_Enc': 6.544250637462386, 'L_KL': 82.08974659022482}
Val: {'mean_cosine': np.float32(0.6905968), 'acc@0.8': np.float64(0.02096), 'acc@0.9': np.float64(0.0)}


Epoch: 17: 100%|██████████| 879/879 [00:44<00:00, 19.82it/s]
Validation: 100%|██████████| 98/98 [00:09<00:00, 10.04it/s]



Epoch 18/60
Train: {'loss': 7.822560866945026, 'L_FM': 0.47201829375667376, 'L_Enc': 6.519710877107136, 'L_KL': 83.08317208317224}
Val: {'mean_cosine': np.float32(0.6939065), 'acc@0.8': np.float64(0.02368), 'acc@0.9': np.float64(0.0)}


Epoch: 18: 100%|██████████| 879/879 [00:44<00:00, 19.83it/s]
Validation: 100%|██████████| 98/98 [00:09<00:00, 10.04it/s]



Epoch 19/60
Train: {'loss': 7.804838669856118, 'L_FM': 0.465625451320403, 'L_Enc': 6.499716423477329, 'L_KL': 83.94968317962751}
Val: {'mean_cosine': np.float32(0.69627064), 'acc@0.8': np.float64(0.028), 'acc@0.9': np.float64(0.0)}


Epoch: 19: 100%|██████████| 879/879 [00:44<00:00, 19.80it/s]
Validation: 100%|██████████| 98/98 [00:09<00:00, 10.06it/s]



Epoch 20/60
Train: {'loss': 7.791570451886174, 'L_FM': 0.4633723386861087, 'L_Enc': 6.479377518741751, 'L_KL': 84.8820620939322}
Val: {'mean_cosine': np.float32(0.6983289), 'acc@0.8': np.float64(0.03072), 'acc@0.9': np.float64(0.0)}


Epoch: 20: 100%|██████████| 879/879 [00:44<00:00, 19.85it/s]
Validation: 100%|██████████| 98/98 [00:09<00:00, 10.03it/s]



Epoch 21/60
Train: {'loss': 7.775984715818681, 'L_FM': 0.45842292952320546, 'L_Enc': 6.460955616014673, 'L_KL': 85.6606200633738}
Val: {'mean_cosine': np.float32(0.70159143), 'acc@0.8': np.float64(0.03328), 'acc@0.9': np.float64(0.0)}


Epoch: 21: 100%|██████████| 879/879 [00:44<00:00, 19.82it/s]
Validation: 100%|██████████| 98/98 [00:09<00:00, 10.02it/s]



Epoch 22/60
Train: {'loss': 7.762532782093521, 'L_FM': 0.45591033227207717, 'L_Enc': 6.441729748479607, 'L_KL': 86.48927245503535}
Val: {'mean_cosine': np.float32(0.7027368), 'acc@0.8': np.float64(0.03848), 'acc@0.9': np.float64(0.0)}


Epoch: 22: 100%|██████████| 879/879 [00:44<00:00, 19.83it/s]
Validation: 100%|██████████| 98/98 [00:09<00:00, 10.04it/s]



Epoch 23/60
Train: {'loss': 7.749809590622833, 'L_FM': 0.45204003502755713, 'L_Enc': 6.42624861062997, 'L_KL': 87.15209829007128}
Val: {'mean_cosine': np.float32(0.7049603), 'acc@0.8': np.float64(0.03896), 'acc@0.9': np.float64(0.0)}


Epoch: 23: 100%|██████████| 879/879 [00:44<00:00, 19.80it/s]
Validation: 100%|██████████| 98/98 [00:09<00:00, 10.04it/s]



Epoch 24/60
Train: {'loss': 7.734932790979726, 'L_FM': 0.44759833551646633, 'L_Enc': 6.408891253639542, 'L_KL': 87.84432260385282}
Val: {'mean_cosine': np.float32(0.705894), 'acc@0.8': np.float64(0.04296), 'acc@0.9': np.float64(0.0)}


Epoch: 24: 100%|██████████| 879/879 [00:44<00:00, 19.82it/s]
Validation: 100%|██████████| 98/98 [00:09<00:00, 10.04it/s]



Epoch 25/60
Train: {'loss': 7.726180695565216, 'L_FM': 0.4456165247470717, 'L_Enc': 6.39523260097048, 'L_KL': 88.53315972685135}
Val: {'mean_cosine': np.float32(0.7078662), 'acc@0.8': np.float64(0.04704), 'acc@0.9': np.float64(0.0)}


Epoch: 25: 100%|██████████| 879/879 [00:44<00:00, 19.86it/s]
Validation: 100%|██████████| 98/98 [00:09<00:00, 10.06it/s]



Epoch 26/60
Train: {'loss': 7.715458120773541, 'L_FM': 0.4425601926001808, 'L_Enc': 6.381912112642881, 'L_KL': 89.09858591646058}
Val: {'mean_cosine': np.float32(0.7103044), 'acc@0.8': np.float64(0.052), 'acc@0.9': np.float64(0.0)}


Epoch: 26: 100%|██████████| 879/879 [00:44<00:00, 19.80it/s]
Validation: 100%|██████████| 98/98 [00:09<00:00, 10.02it/s]



Epoch 27/60
Train: {'loss': 7.703561726419321, 'L_FM': 0.4389608974342867, 'L_Enc': 6.367936195638264, 'L_KL': 89.66646619325883}
Val: {'mean_cosine': np.float32(0.71117526), 'acc@0.8': np.float64(0.05464), 'acc@0.9': np.float64(0.0)}


Epoch: 27: 100%|██████████| 879/879 [00:44<00:00, 19.79it/s]
Validation: 100%|██████████| 98/98 [00:09<00:00, 10.04it/s]



Epoch 28/60
Train: {'loss': 7.693980132353591, 'L_FM': 0.4377976039231705, 'L_Enc': 6.354378825568502, 'L_KL': 90.18037283488373}
Val: {'mean_cosine': np.float32(0.71266633), 'acc@0.8': np.float64(0.05648), 'acc@0.9': np.float64(0.0)}


Epoch: 28: 100%|██████████| 879/879 [00:44<00:00, 19.82it/s]
Validation: 100%|██████████| 98/98 [00:09<00:00, 10.06it/s]



Epoch 29/60
Train: {'loss': 7.683264529474495, 'L_FM': 0.4341377517604719, 'L_Enc': 6.341152547026928, 'L_KL': 90.79742552287479}
Val: {'mean_cosine': np.float32(0.71296096), 'acc@0.8': np.float64(0.06112), 'acc@0.9': np.float64(0.0)}


Epoch: 29: 100%|██████████| 879/879 [00:44<00:00, 19.85it/s]
Validation: 100%|██████████| 98/98 [00:09<00:00, 10.04it/s]



Epoch 30/60
Train: {'loss': 7.674263653087942, 'L_FM': 0.4315088958870429, 'L_Enc': 6.330144842601338, 'L_KL': 91.2609937144899}
Val: {'mean_cosine': np.float32(0.7143393), 'acc@0.8': np.float64(0.06168), 'acc@0.9': np.float64(0.0)}


Epoch: 30: 100%|██████████| 879/879 [00:44<00:00, 19.79it/s]
Validation: 100%|██████████| 98/98 [00:09<00:00, 10.06it/s]



Epoch 31/60
Train: {'loss': 7.666393678207311, 'L_FM': 0.4282534469313182, 'L_Enc': 6.320570390243444, 'L_KL': 91.75698611245356}
Val: {'mean_cosine': np.float32(0.7142351), 'acc@0.8': np.float64(0.06248), 'acc@0.9': np.float64(0.0)}


Epoch: 31: 100%|██████████| 879/879 [00:44<00:00, 19.77it/s]
Validation: 100%|██████████| 98/98 [00:09<00:00, 10.03it/s]



Epoch 32/60
Train: {'loss': 7.656889075712132, 'L_FM': 0.42595090522972257, 'L_Enc': 6.308683554873938, 'L_KL': 92.22546498685973}
Val: {'mean_cosine': np.float32(0.7146803), 'acc@0.8': np.float64(0.06392), 'acc@0.9': np.float64(0.0)}


Epoch: 32: 100%|██████████| 879/879 [00:44<00:00, 19.80it/s]
Validation: 100%|██████████| 98/98 [00:09<00:00, 10.03it/s]



Epoch 33/60
Train: {'loss': 7.649169050657166, 'L_FM': 0.4237673271434706, 'L_Enc': 6.298999958342117, 'L_KL': 92.64017973499493}
Val: {'mean_cosine': np.float32(0.71725744), 'acc@0.8': np.float64(0.07064), 'acc@0.9': np.float64(0.0)}


Epoch: 33: 100%|██████████| 879/879 [00:44<00:00, 19.81it/s]
Validation: 100%|██████████| 98/98 [00:09<00:00, 10.06it/s]



Epoch 34/60
Train: {'loss': 7.6401299160902, 'L_FM': 0.4206893199268597, 'L_Enc': 6.288448492685954, 'L_KL': 93.09921466015847}
Val: {'mean_cosine': np.float32(0.718435), 'acc@0.8': np.float64(0.07648), 'acc@0.9': np.float64(0.0)}


Epoch: 34: 100%|██████████| 879/879 [00:44<00:00, 19.78it/s]
Validation: 100%|██████████| 98/98 [00:09<00:00, 10.03it/s]



Epoch 35/60
Train: {'loss': 7.63313534533747, 'L_FM': 0.4189027388144679, 'L_Enc': 6.2798113963981, 'L_KL': 93.44212428972854}
Val: {'mean_cosine': np.float32(0.7180379), 'acc@0.8': np.float64(0.0744), 'acc@0.9': np.float64(0.0)}


Epoch: 35: 100%|██████████| 879/879 [00:44<00:00, 19.78it/s]
Validation: 100%|██████████| 98/98 [00:09<00:00, 10.04it/s]



Epoch 36/60
Train: {'loss': 7.626776054586296, 'L_FM': 0.4173604235714206, 'L_Enc': 6.27058565901408, 'L_KL': 93.8830015162966}
Val: {'mean_cosine': np.float32(0.71888655), 'acc@0.8': np.float64(0.07832), 'acc@0.9': np.float64(0.0)}


Epoch: 36: 100%|██████████| 879/879 [00:44<00:00, 19.77it/s]
Validation: 100%|██████████| 98/98 [00:09<00:00, 10.04it/s]



Epoch 37/60
Train: {'loss': 7.620087976748626, 'L_FM': 0.4152092989597711, 'L_Enc': 6.2630341489702905, 'L_KL': 94.18445813696539}
Val: {'mean_cosine': np.float32(0.71950984), 'acc@0.8': np.float64(0.0808), 'acc@0.9': np.float64(0.0)}


Epoch: 37: 100%|██████████| 879/879 [00:44<00:00, 19.77it/s]
Validation: 100%|██████████| 98/98 [00:09<00:00, 10.05it/s]



Epoch 38/60
Train: {'loss': 7.611228298951064, 'L_FM': 0.4119479765731998, 'L_Enc': 6.254315348615418, 'L_KL': 94.49650174467502}
Val: {'mean_cosine': np.float32(0.7209269), 'acc@0.8': np.float64(0.08336), 'acc@0.9': np.float64(8e-05)}


Epoch: 38: 100%|██████████| 879/879 [00:44<00:00, 19.82it/s]
Validation: 100%|██████████| 98/98 [00:09<00:00, 10.04it/s]



Epoch 39/60
Train: {'loss': 7.606903798208573, 'L_FM': 0.4113230551333422, 'L_Enc': 6.247157337180041, 'L_KL': 94.84234352675992}
Val: {'mean_cosine': np.float32(0.7219612), 'acc@0.8': np.float64(0.09184), 'acc@0.9': np.float64(0.00016)}


Epoch: 39: 100%|██████████| 879/879 [00:44<00:00, 19.81it/s]
Validation: 100%|██████████| 98/98 [00:09<00:00, 10.04it/s]



Epoch 40/60
Train: {'loss': 7.601530265482619, 'L_FM': 0.40890661573654147, 'L_Enc': 6.240888339531978, 'L_KL': 95.17353372931888}
Val: {'mean_cosine': np.float32(0.7235259), 'acc@0.8': np.float64(0.09456), 'acc@0.9': np.float64(8e-05)}


Epoch: 40: 100%|██████████| 879/879 [00:44<00:00, 19.80it/s]
Validation: 100%|██████████| 98/98 [00:09<00:00, 10.04it/s]



Epoch 41/60
Train: {'loss': 7.595108848111757, 'L_FM': 0.4071969252261964, 'L_Enc': 6.232819449779524, 'L_KL': 95.50925014179043}
Val: {'mean_cosine': np.float32(0.72330946), 'acc@0.8': np.float64(0.09344), 'acc@0.9': np.float64(8e-05)}


Epoch: 41: 100%|██████████| 879/879 [00:44<00:00, 19.79it/s]
Validation: 100%|██████████| 98/98 [00:09<00:00, 10.07it/s]



Epoch 42/60
Train: {'loss': 7.5873125860715485, 'L_FM': 0.4048506285279006, 'L_Enc': 6.224539677030804, 'L_KL': 95.79223162376569}
Val: {'mean_cosine': np.float32(0.7244998), 'acc@0.8': np.float64(0.10088), 'acc@0.9': np.float64(0.00016)}


Epoch: 42: 100%|██████████| 879/879 [00:44<00:00, 19.78it/s]
Validation: 100%|██████████| 98/98 [00:09<00:00, 10.03it/s]



Epoch 43/60
Train: {'loss': 7.582003927068092, 'L_FM': 0.40365271071932013, 'L_Enc': 6.216731500571362, 'L_KL': 96.16197558718737}
Val: {'mean_cosine': np.float32(0.72389233), 'acc@0.8': np.float64(0.09424), 'acc@0.9': np.float64(8e-05)}


Epoch: 43: 100%|██████████| 879/879 [00:44<00:00, 19.80it/s]
Validation: 100%|██████████| 98/98 [00:09<00:00, 10.05it/s]



Epoch 44/60
Train: {'loss': 7.5769696935451885, 'L_FM': 0.4012463430495799, 'L_Enc': 6.212297222584452, 'L_KL': 96.34261620600748}
Val: {'mean_cosine': np.float32(0.7235843), 'acc@0.8': np.float64(0.0948), 'acc@0.9': np.float64(0.00016)}


Epoch: 44: 100%|██████████| 879/879 [00:44<00:00, 19.83it/s]
Validation: 100%|██████████| 98/98 [00:09<00:00, 10.03it/s]



Epoch 45/60
Train: {'loss': 7.569884255748831, 'L_FM': 0.39846958391224296, 'L_Enc': 6.204688200229128, 'L_KL': 96.67265067246994}
Val: {'mean_cosine': np.float32(0.7243731), 'acc@0.8': np.float64(0.0976), 'acc@0.9': np.float64(0.0)}


Epoch: 45: 100%|██████████| 879/879 [00:44<00:00, 19.82it/s]
Validation: 100%|██████████| 98/98 [00:09<00:00, 10.03it/s]



Epoch 46/60
Train: {'loss': 7.566512099711968, 'L_FM': 0.39779314849699454, 'L_Enc': 6.199773472730617, 'L_KL': 96.89455099615763}
Val: {'mean_cosine': np.float32(0.7244087), 'acc@0.8': np.float64(0.09752), 'acc@0.9': np.float64(0.00016)}


Epoch: 46:  53%|█████▎    | 463/879 [00:23<00:21, 19.79it/s]


KeyboardInterrupt: 

### Training and Hyperparameter Optimization

### Inference

In [17]:

test_data = load_data("drive/MyDrive/data/test/test.clean.npz")

test_embds = test_data['captions/embeddings']
test_embds = torch.from_numpy(test_embds).float().to(DEVICE)

# The modules stay in whatever mode the last training/validation call left
# them in (train mode if training was interrupted mid-epoch), so switch to
# eval explicitly to disable dropout for inference.
vae.eval()
flow.eval()

with torch.no_grad():
    # Encode to z0
    z0, _, _, _ = vae(test_embds)
    # Integrate flow to predict target embedding
    z1_pred = integrate_flow(flow, z0, n_steps=50)

submission = generate_submission(test_data['captions/ids'], z1_pred, 'drive/MyDrive/data/crossflow_submission.csv')

# Persist both networks so the message below is true; the target directory is
# created if it does not exist yet.
MODEL_PATH = "drive/MyDrive/data//models/crossflow.pth"
Path(MODEL_PATH).parent.mkdir(parents=True, exist_ok=True)
torch.save({"vae": vae.state_dict(), "flow": flow.state_dict()}, MODEL_PATH)
print(f"Model saved to: {MODEL_PATH}")

Generating submission file...
✓ Saved submission to drive/MyDrive/data/crossflow_submission.csv
Model saved to: drive/MyDrive/data//models/crossflow.pth
