# **AML Challenge: Model Stitching**

In [None]:
# @title
from PIL import Image
import os
from urllib.request import urlopen
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torchvision import transforms

from sentence_transformers import SentenceTransformer, util
from diffusers import AutoencoderKL

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## **Downloading the Text Encoder**

In [None]:
print("🫁 Downloading the model...")
text_encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

### **Demo for the Text Encoder:**

In [None]:
text = "A cat is hiding under the table"

# We obtain the "embedding vector" using the encode() function:
emb = text_encoder.encode(text, convert_to_tensor=True, show_progress_bar=False)

# This is only for clarity:
preview = emb.tolist()[:3] + ["..."] + emb.tolist()[-3:]
print(f"The embedding looks like this: {preview}")

# This is the shape of our embedding:
print(f"The shape of the embedding is: {emb.shape}")

## **Downloading the VAE**

In [None]:
print("🫁 Downloading VAE from Hugging Face...")
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(device)

### **Demo for the VAE:**

In [None]:
IMG_URL = "https://images.unsplash.com/photo-1574144611937-0df059b5ef3e?ixlib=rb-4.1.0&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D&auto=format&fit=crop&q=80&w=764"
img = Image.open(urlopen(IMG_URL)).convert("RGB")

preprocess = transforms.Compose([
    transforms.Resize((256, 256)), # --> keep this size fixed
    # The VAE works also for 512x512, but it will require more compute
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])
img_tensor = preprocess(img).unsqueeze(0).to(device)


# Encode --> latent
with torch.no_grad():
    latents = vae.encode(img_tensor).latent_dist.sample() * 0.18215
print("Latent shape:", latents.shape)

# Decode --> reconstruct
with torch.no_grad():
    recon = vae.decode(latents / 0.18215).sample

recon = (recon.clamp(-1, 1) + 1) / 2
recon_img = transforms.ToPILImage()(recon.squeeze().cpu())

In [None]:
# Visualize input vs output:
fig, axes = plt.subplots(1, 2, figsize=(8, 4))
axes[0].imshow(img.resize((256, 256)))
axes[0].set_title("Original")
axes[0].axis("off")

axes[1].imshow(recon_img)
axes[1].set_title("Reconstructed")
axes[1].axis("off")

plt.tight_layout()
plt.show()

#**Frankenstein Model:**

In [None]:
class Translator(nn.Module):
    """
    This will be the Translator Model you have to design for the challenge.
    You have (almost) complete freedom on this. Your creativity will be rewarded.

    Some ideas might be:
    - Zero shot stitching (see https://arxiv.org/pdf/2209.15430)
    - Linear, Affine, Orthognal solutions (see https://arxiv.org/pdf/2311.00664)
    - Diffusion Priors (see https://arxiv.org/pdf/2204.06125)
    - Flow Matching (see https://arxiv.org/pdf/2412.15213)
    - CKA / Procrustes Analysis
    - Adversarial Trainings
    - AutoEncoding Solutions
    """

    def __init__(self):
        super().__init__()
        # Here is where *you* come into play:
        self.fc = nn.Linear(384, 4 * 32 * 32)
        # This is the most trivial thing you can do (spoiler: it doesn't work)

    def forward(self, x):
        x = self.fc(x)
        return x.view(1, 4, 32, 32)

translator = Translator().to(device)
translator.eval()

In [None]:
# Part 1: encoding the text prompt
text = "Frankestein's Monster writing code on Google Colab"
print(f"Prompt: {text}")
emb = text_encoder.encode(text, convert_to_tensor=True).to(device)
print("Text embedding shape:", emb.shape)

# Part 2: translating the embedding
with torch.no_grad():
    latent = translator(emb)
print(f"Translated latent shape: {latent.shape}\n\n")

# Part 3: feed the VAE with the translation
with torch.no_grad():
    recon = vae.decode(latent / 0.18215).sample

# Part 4: visualizing the output
recon = (recon.clamp(-1, 1) + 1) / 2
recon_img = transforms.ToPILImage()(recon.squeeze().cpu())

plt.imshow(recon_img)
plt.axis("off")
plt.show()

## Baseline

In [1]:
#!mkdir data
#!gdown 1CVAQDuPOiwm8h9LJ8a_oOs6zOWS6EgkB
#!gdown 1ykZ9fjTxUwdiEwqagoYZiMcD5aG-7rHe
#!unzip -o test.zip -d data
#!unzip -o train.zip -d data
from google.colab import drive
drive.mount('/content/drive')
!git clone https://github.com/Mamiglia/challenge.git

Mounted at /content/drive
Cloning into 'challenge'...
remote: Enumerating objects: 98, done.[K
remote: Counting objects: 100% (98/98), done.[K
remote: Compressing objects: 100% (69/69), done.[K
remote: Total 98 (delta 39), reused 72 (delta 26), pack-reused 0 (from 0)[K
Receiving objects: 100% (98/98), 21.03 MiB | 17.81 MiB/s, done.
Resolving deltas: 100% (39/39), done.


In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from pathlib import Path
from tqdm import tqdm

from challenge.src.common import load_data, prepare_train_data, generate_submission

In [3]:
class TransformerTranslator(nn.Module):
    """
    Transformer-style translator from text embedding -> image embedding
    """
    def __init__(self, text_dim=1024, img_dim=1536, n_heads=8, n_layers=2, dim_feedforward=2048, dropout=0.2):
        super().__init__()
        self.input_ln = nn.LayerNorm(text_dim)
        self.proj_in = nn.Linear(text_dim, img_dim)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=img_dim,
            nhead=n_heads,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation='gelu',
            batch_first=True  # for (B, Seq, Dim)
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        self.output_ln = nn.LayerNorm(img_dim)

    def forward(self, x):
        if x.dim() == 2:
            x = x.unsqueeze(1)  # (B, 1, text_dim)
        x = self.input_ln(x)
        x = self.proj_in(x)  # project to model dim
        out = self.encoder(x)  # Transformer encoder
        out = out.squeeze(1)   # remove sequence dim
        return self.output_ln(out)

In [4]:
class ResidualMLPTranslator(nn.Module):
    def __init__(self, text_dim=1024, img_dim=1536, hidden_dim=2048, num_layers=3, dropout=0.2):
        super().__init__()
        assert num_layers >= 2
        self.input_ln = nn.LayerNorm(text_dim)
        self.dropout = nn.Dropout(dropout)

        # first layer: text_dim -> hidden_dim (no residual yet)
        self.first_layer = nn.Sequential(
            nn.Linear(text_dim, hidden_dim),
            nn.GELU(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout)
        )

        # hidden residual blocks (hidden_dim -> hidden_dim)
        self.blocks = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.GELU(),
                nn.LayerNorm(hidden_dim),
                nn.Dropout(dropout)
            )
            for _ in range(num_layers - 2)
        ])

        # final projection to image space
        self.final_proj = nn.Linear(hidden_dim, img_dim)
        self.output_ln = nn.LayerNorm(img_dim)

        # input residual to output
        if text_dim != img_dim:
            self.res_proj = nn.Linear(text_dim, img_dim)
        else:
            self.res_proj = nn.Identity()

    def forward(self, x):
        x_in = self.input_ln(x)
        out = self.first_layer(x_in)
        for block in self.blocks:
            out = out + block(out)  # residual only between same-dim layers
        out = self.final_proj(out)
        out = out + self.res_proj(x_in)
        return self.output_ln(out)


In [5]:
class LatentSpaceTranslator(nn.Module):
    """
    MLP translator from text embedding -> image embedding
    Input: text_emb (batch, text_dim) or (batch, 1, text_dim)
    Output: (batch, img_dim)
    Regularization: dropout, LayerNorm, GELU, residual (optional projector)
    """
    def __init__(self,
                 text_dim=1024,
                 img_dim=1536,
                 hidden_dim=2048,
                 num_layers=3,
                 dropout=0.2,
                 use_residual=True):
        super().__init__()
        assert num_layers >= 2, "num_layers should be >= 2 (including final proj)"
        self.use_residual = use_residual
        self.input_ln = nn.LayerNorm(text_dim)
        layers = []
        in_dim = text_dim
        for i in range(num_layers - 1):
            layers.append(nn.Linear(in_dim, hidden_dim))
            layers.append(nn.GELU())
            layers.append(nn.LayerNorm(hidden_dim))
            layers.append(nn.Dropout(dropout))
            in_dim = hidden_dim
        # final projection to image space
        layers.append(nn.Linear(in_dim, img_dim))
        self.net = nn.Sequential(*layers)

        # if using residual, project input to img_dim to add it at the end
        if self.use_residual:
            if text_dim != img_dim:
                self.res_proj = nn.Linear(text_dim, img_dim)
            else:
                self.res_proj = nn.Identity()

        # final layer norm in image space
        self.output_ln = nn.LayerNorm(img_dim)

    def forward(self, text_emb):
        if text_emb.dim() == 3:
            x = text_emb.squeeze(1)
        else:
            x = text_emb
        x = self.input_ln(x)
        out = self.net(x)  # (B, img_dim)
        if self.use_residual:
            res = self.res_proj(x)
            out = out + res
        return self.output_ln(out)


In [6]:
# ---------- Training loop with Procrustes + InfoNCE ----------
def training(model, train_loader, val_loader, device, epochs, lr, MODEL_PATH,
             use_procrustes_init=True, procrustes_subset=10000, temperature=0.07,
             w_nce = 0.2, w_mse = 0.2, w_cos = 0.2):
    """Train LatentSpaceTranslator with optional Procrustes init + InfoNCE loss."""
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=5e-3)
    best_val_loss = float('inf')

    # --- Optional: Procrustes initialization ---
    if use_procrustes_init:
        print("Computing Procrustes initialization...")
        text_list, img_list = [], []
        for i, (X, y) in enumerate(train_loader):
            text_list.append(X.cpu())
            img_list.append(y.cpu())
            if sum(t.shape[0] for t in text_list) >= procrustes_subset:
                break
        text_sample = torch.cat(text_list, dim=0)[:procrustes_subset]
        img_sample = torch.cat(img_list, dim=0)[:procrustes_subset]
        model = apply_procrustes_init_to_final(model, text_sample, img_sample)

    # --- Training ---
    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        for X_batch, y_batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            optimizer.zero_grad()

            pred = model(X_batch)
            # Contrastive InfoNCE loss (can mix with cosine or MSE)
            loss = w_nce * info_nce_loss(pred, y_batch, temperature=temperature)
            # Optional mixed objective:
            loss += w_cos * (1 - F.cosine_similarity(pred, y_batch).mean())
            loss += w_mse * F.mse_loss(pred, y_batch)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        train_loss /= len(train_loader)

        # --- Validation ---
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for X_batch, y_batch in val_loader:
                X_batch, y_batch = X_batch.to(device), y_batch.to(device)
                pred = model(X_batch)
                loss = w_nce * info_nce_loss(pred, y_batch, temperature=temperature)
                # Optional mixed objective:
                loss += w_cos * (1 - F.cosine_similarity(pred, y_batch).mean())
                loss += w_mse * F.mse_loss(pred, y_batch)
                val_loss += loss.item()

        val_loss /= len(val_loader)
        print(f"Epoch {epoch+1}: Train Loss = {train_loss:.6f}, Val Loss = {val_loss:.6f}")

        # --- Save best model ---
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            Path(MODEL_PATH).parent.mkdir(parents=True, exist_ok=True)
            torch.save(model.state_dict(), MODEL_PATH)
            print(f"  ✓ Saved best model (val_loss={val_loss:.6f})")

    return model

In [19]:
# ====== Procrustes initialization ======
def procrustes_init(text_embs, img_embs):
    """
    text_embs: (N, d_text)
    img_embs:  (N, d_img)
    returns: weight matrix (d_img, d_text)
    """
    # Center both
    X = text_embs - text_embs.mean(0, keepdim=True)
    Y = img_embs - img_embs.mean(0, keepdim=True)

    # Compute SVD of cross-covariance
    U, _, Vt = torch.linalg.svd(X.T @ Y, full_matrices=False)
    W = U @ Vt  # orthogonal map d_text→d_img
    return W.T   # shape (d_img, d_text) for nn.Linear weight

# ====== InfoNCE (CLIP-style) loss ======
def info_nce_loss(pred_img_emb, true_img_emb, temperature=0.07):
    """
    pred_img_emb: (B, D)
    true_img_emb: (B, D)
    """
    zt = F.normalize(pred_img_emb, dim=1)
    zi = F.normalize(true_img_emb, dim=1)
    logits = zt @ zi.t() / temperature
    labels = torch.arange(len(zt), device=zt.device)
    loss_i2t = F.cross_entropy(logits, labels)
    loss_t2i = F.cross_entropy(logits.t(), labels)
    return 0.5 * (loss_i2t + loss_t2i)

import torch
import torch.nn as nn

def apply_procrustes_init_to_final(model, text_sample, img_sample):
    """
    Apply Procrustes initialization to a model.
    - For MLP / ResidualMLP: apply to final Linear layer (hidden -> img_dim)
    - For TransformerTranslator: apply to first projection (text_dim -> img_dim)
    """
    with torch.no_grad():
        # Compute Procrustes matrix
        W = procrustes_init(text_embs=text_sample, img_embs=img_sample)

        # Apply to the appropriate layer
        applied = False
        for name, m in model.named_modules():
            if isinstance(m, nn.Linear):
                # Transformer: apply to first projection (proj_in)
                if isinstance(model, TransformerTranslator) and name.endswith("proj_in"):
                    print(m.weight.shape, W.shape)
                    if m.weight.shape == W.shape:
                        m.weight.copy_(W)
                        applied = True
                        break
                # MLP / ResidualMLP: apply to final_proj
                elif isinstance(model, LatentSpaceTranslator) and name.endswith("res_proj"):
                    print(m.weight.shape, W.shape)
                    if m.weight.shape == W.shape:
                        m.weight.copy_(W)
                        applied = True
                        break

                elif isinstance(model, ResidualMLPTranslator) and name.endswith("res_proj"):
                    print(m.weight.shape, W.shape)
                    if m.weight.shape == W.shape:
                        m.weight.copy_(W)
                        applied = True
                        break

        if not applied:
            print("⚠️ Warning: Could not find matching layer for Procrustes init")
    return model



In [8]:
# 1. Learning with a basic simple transformer
# 2. Learning with a basic deep neural network
# 3. TODO: Creat a hyperparameter optimization for the transformer aswell as
# the lost function nce loss etc. and perform it on the first layer of the transformer
# Configuration
EPOCHS = 60
BATCH_SIZE = 256
LR = 0.001
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load data
train_data = load_data("drive/MyDrive/data/train/train.npz")
X, y, label = prepare_train_data(train_data)
DATASET_SIZE = len(X)
# Split train/val
# This is done only to measure generalization capabilities, you don't have to
# use a validation set (though we encourage this)
n_train = int(0.9 * len(X))
TRAIN_SPLIT = torch.zeros(len(X), dtype=torch.bool)
TRAIN_SPLIT[:n_train] = 1
X_train, X_val = X[TRAIN_SPLIT], X[~TRAIN_SPLIT]
y_train, y_val = y[TRAIN_SPLIT], y[~TRAIN_SPLIT]


train_dataset = TensorDataset(X_train, y_train)
val_dataset = TensorDataset(X_val, y_val)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)
y_train.shape, X_train.shape, train_loader.batch_size, val_loader.batch_size

(125000,)
Train data: 125000 captions, 125000 images


(torch.Size([112500, 1536]), torch.Size([112500, 1024]), 256, 256)

## Hyperparameter Optimization

In [9]:
!pip install optuna

Collecting optuna
  Downloading optuna-4.5.0-py3-none-any.whl.metadata (17 kB)
Collecting colorlog (from optuna)
  Downloading colorlog-6.10.1-py3-none-any.whl.metadata (11 kB)
Downloading optuna-4.5.0-py3-none-any.whl (400 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m400.9/400.9 kB[0m [31m11.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading colorlog-6.10.1-py3-none-any.whl (11 kB)
Installing collected packages: colorlog, optuna
Successfully installed colorlog-6.10.1 optuna-4.5.0


In [42]:
import optuna

def objective_extended(arch, trial, train_dataloader, val_dataloader, device, MODEL_PATH_BASE):

    # --- Common hyperparameters ---
    dropout = trial.suggest_float("dropout", 0.1, 0.5)
    lr = trial.suggest_loguniform("lr", 1e-5, 1e-3)
    batch_size = trial.suggest_categorical("batch_size", [64, 128, 256])
    weight_decay = trial.suggest_loguniform("weight_decay", 1e-5, 1e-3)

    # --- New hyperparameters ---
    temperature = trial.suggest_float("temperature", 0.01, 0.2)
    w_infonce = trial.suggest_float("w_infonce", 0.6, 0.8)
    w_cos = trial.suggest_float("w_cos", 0.4, 1.0)
    w_mse = trial.suggest_float("w_mse", 1.0 - w_cos, 1.0)
    procrustes_subset = 10000

    # --- Architecture-specific hyperparameters ---
    if arch in ["MLP", "ResidualMLP"]:
        hidden_dim = trial.suggest_categorical("hidden_dim", [1024, 2048, 4096])
        num_layers = trial.suggest_int("num_layers", 2, 6)
        if arch == "MLP":
            model = LatentSpaceTranslator(
                text_dim=1024, img_dim=1536, hidden_dim=hidden_dim,
                num_layers=num_layers, dropout=dropout
            ).to(device)
        else:
            model = ResidualMLPTranslator(
                text_dim=1024, img_dim=1536, hidden_dim=hidden_dim,
                num_layers=num_layers, dropout=dropout
            ).to(device)
    elif arch == "Transformer":
        n_layers = trial.suggest_int("n_layers", 2, 6)
        n_heads = trial.suggest_categorical("n_heads", [4, 8, 12])
        dim_feedforward = trial.suggest_categorical("dim_feedforward", [1024, 2048, 4096])
        model = TransformerTranslator(
            text_dim=1024, img_dim=1536,
            n_heads=n_heads, n_layers=n_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout
        ).to(device)


    # --- Apply Procrustes initialization ---
    if procrustes_subset > 0:
        # Get subset from train_loader
        text_list, img_list = [], []
        for i, (X, y) in enumerate(train_loader):
            text_list.append(X.cpu())
            img_list.append(y.cpu())
            if sum(t.shape[0] for t in text_list) >= procrustes_subset:
                break
        text_sample = torch.cat(text_list, dim=0)[:procrustes_subset]
        img_sample = torch.cat(img_list, dim=0)[:procrustes_subset]
        model = apply_procrustes_init_to_final(model, text_sample, img_sample)

    # --- Training loop (short run) ---
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    model.train()
    for epoch in range(5):  # short training
        for X_batch, y_batch in train_loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            optimizer.zero_grad()
            pred = model(X_batch)
            pred = F.normalize(pred, dim=-1)
            y_batch = F.normalize(y_batch, dim=-1)

            # Weighted combination of losses
            loss = 0
            if w_infonce > 0:
                loss += w_infonce * info_nce_loss(pred, y_batch, temperature=temperature)
            if w_mse > 0:
                loss += w_mse * F.mse_loss(pred, y_batch)
            if w_cos > 0:
                loss += w_cos * (1 - F.cosine_similarity(pred, y_batch).mean())

            loss.backward()
            optimizer.step()

    # --- Evaluate on validation ---
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for X_batch, y_batch in val_loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            pred = model(X_batch)
            pred = F.normalize(pred, dim=-1)
            y_batch = F.normalize(y_batch, dim=-1)
            # Use combined loss for evaluation
            loss = 0
            if w_infonce > 0:
                loss += w_infonce * info_nce_loss(pred, y_batch, temperature=temperature)
            if w_mse > 0:
                loss += w_mse * F.mse_loss(pred, y_batch)
            if w_cos > 0:
                loss += w_cos * (1 - F.cosine_similarity(pred, y_batch).mean())
            val_loss += loss.item()
    val_loss /= len(val_loader)

    return val_loss


def run_optuna_extended(arch, train_dataloader, val_dataloader, device, MODEL_PATH_BASE, n_trials=30):
    study = optuna.create_study(direction="minimize")
    study.optimize(lambda trial: objective_extended(arch, trial, train_dataloader, val_dataloader, device, MODEL_PATH_BASE),
                   n_trials=n_trials)

    print("Best trial:")
    trial = study.best_trial
    print(f"Val loss: {trial.value}")
    print("Best hyperparameters:")
    for key, value in trial.params.items():
        print(f"  {key}: {value}")

    return trial.params

In [None]:
archs = ['MLP', 'ResidualMLP', 'Transformer']
choosen_arch = archs[2]
best_params = run_optuna_extended(
    arch = choosen_arch,
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    device=DEVICE,
    MODEL_PATH_BASE="models/translator_optuna"
)

if choosen_arch == 'Transformer':
    model = TransformerTranslator(
        text_dim=1024,
        img_dim=1536,
        n_heads = best_params['n_heads'],
        n_layers=best_params['n_layers'],
        dim_feedforward=best_params['dim_feedforward'],
        dropout=best_params['dropout']
    ).to(DEVICE)
    MODEL_PATH = "drive/MyDrive/data//models/transformer.pth"

elif choosen_arch == 'MLP':
    model = LatentSpaceTranslator(
    text_dim=1024,
    img_dim=1536,
    hidden_dim=best_params["hidden_dim"],
    num_layers=best_params["num_layers"],
    dropout=best_params["dropout"]).to(DEVICE)
    MODEL_PATH = "drive/MyDrive/data/models/latent_space.pth"

else:
    model = ResidualMLPTranslator(
    text_dim=1024,
    img_dim=1536,
    hidden_dim=best_params["hidden_dim"],
    num_layers=best_params["num_layers"],
    dropout=best_params["dropout"]).to(DEVICE)
    MODEL_PATH = "drive/MyDrive/data/models/residual.pth"

[I 2025-11-01 11:50:00,120] A new study created in memory with name: no-name-14f9d160-561b-403f-9b18-97eedbf7146b
  lr = trial.suggest_loguniform("lr", 1e-5, 1e-3)
  weight_decay = trial.suggest_loguniform("weight_decay", 1e-5, 1e-3)


torch.Size([1536, 1024]) torch.Size([1536, 1024])


[I 2025-11-01 11:51:08,482] Trial 0 finished with value: 1.9839396160476062 and parameters: {'dropout': 0.44983997621252014, 'lr': 4.2286853291969677e-05, 'batch_size': 256, 'weight_decay': 2.701630494266463e-05, 'temperature': 0.03324723057875822, 'w_infonce': 0.7209312836976648, 'w_cos': 0.6142175661338521, 'w_mse': 0.45920387201388335, 'n_layers': 2, 'n_heads': 12, 'dim_feedforward': 1024}. Best is trial 0 with value: 1.9839396160476062.


torch.Size([1536, 1024]) torch.Size([1536, 1024])


In [38]:
print(f"   Parameters: {sum(p.numel() for p in model.parameters()):,}")

# Train
print("\n3. Training...")
model = training(model,
                 train_loader,
                 val_loader,
                 DEVICE,
                 EPOCHS,
                 best_params["lr"],
                 MODEL_PATH,
                 True,
                 10000,
                 best_params["temperature"],
                 best_params["w_infonce"],
                 best_params["w_mse"],
                 best_params["w_cos"])

   Parameters: 14,177,280

3. Training...
Computing Procrustes initialization...
torch.Size([1536, 1024]) torch.Size([1536, 1024])


Epoch 1/45: 100%|██████████| 440/440 [00:08<00:00, 54.43it/s]


Epoch 1: Train Loss = 2.667712, Val Loss = 2.665481
  ✓ Saved best model (val_loss=2.665481)


Epoch 2/45: 100%|██████████| 440/440 [00:08<00:00, 54.27it/s]


Epoch 2: Train Loss = 2.306717, Val Loss = 2.514931
  ✓ Saved best model (val_loss=2.514931)


Epoch 3/45: 100%|██████████| 440/440 [00:08<00:00, 52.28it/s]


Epoch 3: Train Loss = 2.169662, Val Loss = 2.443939
  ✓ Saved best model (val_loss=2.443939)


Epoch 4/45: 100%|██████████| 440/440 [00:08<00:00, 54.37it/s]


Epoch 4: Train Loss = 2.084842, Val Loss = 2.399220
  ✓ Saved best model (val_loss=2.399220)


Epoch 5/45: 100%|██████████| 440/440 [00:08<00:00, 54.73it/s]


Epoch 5: Train Loss = 2.021432, Val Loss = 2.363772
  ✓ Saved best model (val_loss=2.363772)


Epoch 6/45: 100%|██████████| 440/440 [00:07<00:00, 55.05it/s]


Epoch 6: Train Loss = 1.971583, Val Loss = 2.337160
  ✓ Saved best model (val_loss=2.337160)


Epoch 7/45: 100%|██████████| 440/440 [00:08<00:00, 53.36it/s]


Epoch 7: Train Loss = 1.930606, Val Loss = 2.311470
  ✓ Saved best model (val_loss=2.311470)


Epoch 8/45: 100%|██████████| 440/440 [00:08<00:00, 54.46it/s]


Epoch 8: Train Loss = 1.894185, Val Loss = 2.291556
  ✓ Saved best model (val_loss=2.291556)


Epoch 9/45: 100%|██████████| 440/440 [00:07<00:00, 55.76it/s]


Epoch 9: Train Loss = 1.864162, Val Loss = 2.272109
  ✓ Saved best model (val_loss=2.272109)


Epoch 10/45: 100%|██████████| 440/440 [00:08<00:00, 54.71it/s]


Epoch 10: Train Loss = 1.837341, Val Loss = 2.253808
  ✓ Saved best model (val_loss=2.253808)


Epoch 11/45: 100%|██████████| 440/440 [00:08<00:00, 54.34it/s]


Epoch 11: Train Loss = 1.812634, Val Loss = 2.239229
  ✓ Saved best model (val_loss=2.239229)


Epoch 12/45: 100%|██████████| 440/440 [00:08<00:00, 52.33it/s]


Epoch 12: Train Loss = 1.789839, Val Loss = 2.225551
  ✓ Saved best model (val_loss=2.225551)


Epoch 13/45: 100%|██████████| 440/440 [00:07<00:00, 56.41it/s]


Epoch 13: Train Loss = 1.768866, Val Loss = 2.211253
  ✓ Saved best model (val_loss=2.211253)


Epoch 14/45: 100%|██████████| 440/440 [00:08<00:00, 53.77it/s]


Epoch 14: Train Loss = 1.748797, Val Loss = 2.199894
  ✓ Saved best model (val_loss=2.199894)


Epoch 15/45: 100%|██████████| 440/440 [00:08<00:00, 54.25it/s]


Epoch 15: Train Loss = 1.731735, Val Loss = 2.188339
  ✓ Saved best model (val_loss=2.188339)


Epoch 16/45: 100%|██████████| 440/440 [00:08<00:00, 54.76it/s]


Epoch 16: Train Loss = 1.713526, Val Loss = 2.178325
  ✓ Saved best model (val_loss=2.178325)


Epoch 17/45: 100%|██████████| 440/440 [00:08<00:00, 54.15it/s]


Epoch 17: Train Loss = 1.698392, Val Loss = 2.167685
  ✓ Saved best model (val_loss=2.167685)


Epoch 18/45: 100%|██████████| 440/440 [00:07<00:00, 55.19it/s]


Epoch 18: Train Loss = 1.682995, Val Loss = 2.157980
  ✓ Saved best model (val_loss=2.157980)


Epoch 19/45: 100%|██████████| 440/440 [00:07<00:00, 55.43it/s]


Epoch 19: Train Loss = 1.668335, Val Loss = 2.149243
  ✓ Saved best model (val_loss=2.149243)


Epoch 20/45: 100%|██████████| 440/440 [00:08<00:00, 52.92it/s]


Epoch 20: Train Loss = 1.653344, Val Loss = 2.140630
  ✓ Saved best model (val_loss=2.140630)


Epoch 21/45: 100%|██████████| 440/440 [00:08<00:00, 54.70it/s]


Epoch 21: Train Loss = 1.639856, Val Loss = 2.133352
  ✓ Saved best model (val_loss=2.133352)


Epoch 22/45: 100%|██████████| 440/440 [00:07<00:00, 55.42it/s]


Epoch 22: Train Loss = 1.626496, Val Loss = 2.124883
  ✓ Saved best model (val_loss=2.124883)


Epoch 23/45: 100%|██████████| 440/440 [00:08<00:00, 53.96it/s]


Epoch 23: Train Loss = 1.615575, Val Loss = 2.117537
  ✓ Saved best model (val_loss=2.117537)


Epoch 24/45: 100%|██████████| 440/440 [00:08<00:00, 53.99it/s]


Epoch 24: Train Loss = 1.602767, Val Loss = 2.109036
  ✓ Saved best model (val_loss=2.109036)


Epoch 25/45: 100%|██████████| 440/440 [00:08<00:00, 54.94it/s]


Epoch 25: Train Loss = 1.589925, Val Loss = 2.102382
  ✓ Saved best model (val_loss=2.102382)


Epoch 26/45: 100%|██████████| 440/440 [00:08<00:00, 54.05it/s]


Epoch 26: Train Loss = 1.579613, Val Loss = 2.096099
  ✓ Saved best model (val_loss=2.096099)


Epoch 27/45: 100%|██████████| 440/440 [00:07<00:00, 55.33it/s]


Epoch 27: Train Loss = 1.568842, Val Loss = 2.089372
  ✓ Saved best model (val_loss=2.089372)


Epoch 28/45: 100%|██████████| 440/440 [00:07<00:00, 56.39it/s]


Epoch 28: Train Loss = 1.557857, Val Loss = 2.082269
  ✓ Saved best model (val_loss=2.082269)


Epoch 29/45: 100%|██████████| 440/440 [00:08<00:00, 52.52it/s]


Epoch 29: Train Loss = 1.547650, Val Loss = 2.077723
  ✓ Saved best model (val_loss=2.077723)


Epoch 30/45: 100%|██████████| 440/440 [00:07<00:00, 55.35it/s]


Epoch 30: Train Loss = 1.537787, Val Loss = 2.071424
  ✓ Saved best model (val_loss=2.071424)


Epoch 31/45: 100%|██████████| 440/440 [00:07<00:00, 55.93it/s]


Epoch 31: Train Loss = 1.527890, Val Loss = 2.064726
  ✓ Saved best model (val_loss=2.064726)


Epoch 32/45: 100%|██████████| 440/440 [00:08<00:00, 54.98it/s]


Epoch 32: Train Loss = 1.517575, Val Loss = 2.059332
  ✓ Saved best model (val_loss=2.059332)


Epoch 33/45: 100%|██████████| 440/440 [00:09<00:00, 48.84it/s]


Epoch 33: Train Loss = 1.509626, Val Loss = 2.054208
  ✓ Saved best model (val_loss=2.054208)


Epoch 34/45: 100%|██████████| 440/440 [00:07<00:00, 56.44it/s]


Epoch 34: Train Loss = 1.499840, Val Loss = 2.049011
  ✓ Saved best model (val_loss=2.049011)


Epoch 35/45: 100%|██████████| 440/440 [00:08<00:00, 54.85it/s]


Epoch 35: Train Loss = 1.491968, Val Loss = 2.044377
  ✓ Saved best model (val_loss=2.044377)


Epoch 36/45: 100%|██████████| 440/440 [00:08<00:00, 54.13it/s]


Epoch 36: Train Loss = 1.482622, Val Loss = 2.038578
  ✓ Saved best model (val_loss=2.038578)


Epoch 37/45: 100%|██████████| 440/440 [00:08<00:00, 54.35it/s]


Epoch 37: Train Loss = 1.474448, Val Loss = 2.034065
  ✓ Saved best model (val_loss=2.034065)


Epoch 38/45: 100%|██████████| 440/440 [00:08<00:00, 52.48it/s]


Epoch 38: Train Loss = 1.465391, Val Loss = 2.029970
  ✓ Saved best model (val_loss=2.029970)


Epoch 39/45: 100%|██████████| 440/440 [00:08<00:00, 54.54it/s]


Epoch 39: Train Loss = 1.458482, Val Loss = 2.024600
  ✓ Saved best model (val_loss=2.024600)


Epoch 40/45: 100%|██████████| 440/440 [00:07<00:00, 55.94it/s]


Epoch 40: Train Loss = 1.450942, Val Loss = 2.020618
  ✓ Saved best model (val_loss=2.020618)


Epoch 41/45: 100%|██████████| 440/440 [00:08<00:00, 53.91it/s]


Epoch 41: Train Loss = 1.443161, Val Loss = 2.016528
  ✓ Saved best model (val_loss=2.016528)


Epoch 42/45: 100%|██████████| 440/440 [00:08<00:00, 52.63it/s]


Epoch 42: Train Loss = 1.435650, Val Loss = 2.011917
  ✓ Saved best model (val_loss=2.011917)


Epoch 43/45: 100%|██████████| 440/440 [00:07<00:00, 56.17it/s]


Epoch 43: Train Loss = 1.428820, Val Loss = 2.008884
  ✓ Saved best model (val_loss=2.008884)


Epoch 44/45: 100%|██████████| 440/440 [00:07<00:00, 55.30it/s]


Epoch 44: Train Loss = 1.421862, Val Loss = 2.005336
  ✓ Saved best model (val_loss=2.005336)


Epoch 45/45: 100%|██████████| 440/440 [00:08<00:00, 54.47it/s]


Epoch 45: Train Loss = 1.413865, Val Loss = 2.001262
  ✓ Saved best model (val_loss=2.001262)


In [39]:
model.load_state_dict(torch.load(MODEL_PATH))

<All keys matched successfully>

In [None]:
from challenge.src.eval import visualize_retrieval
import numpy as np
import torch

val_caption_text = train_data['captions/text'][~TRAIN_SPLIT]
val_text_embd = X_val
img_VAL_SPLIT = label[~TRAIN_SPLIT].sum(dim=0) > 0
val_img_file = train_data['images/names'][img_VAL_SPLIT]
val_img_embd = torch.from_numpy(train_data['images/embeddings'][img_VAL_SPLIT])
val_label = np.nonzero(train_data['captions/label'][~TRAIN_SPLIT][:,img_VAL_SPLIT])[1]

# Sample and visualize
for i in range(5):
    idx = np.random.randint(0, 100)
    caption_embd = val_text_embd[idx]
    caption_text = val_caption_text[idx]
    gt_index = val_label[idx]

    with torch.no_grad():
        pred_embds = model(caption_embd.to(DEVICE)).cpu()

        visualize_retrieval(
            pred_embds,
            gt_index,
            val_img_file,
            caption_text, val_img_embd, k=5,
            dataset_path = 'drive/MyDrive/data/train')

KeyboardInterrupt: 

In [40]:
test_data = load_data("drive/MyDrive/data/test/test.clean.npz")

test_embds = test_data['captions/embeddings']
test_embds = torch.from_numpy(test_embds).float()

with torch.no_grad():
    pred_embds = model(test_embds.to(DEVICE)).cpu()

submission = generate_submission(test_data['captions/ids'], pred_embds, f'{choosen_arch}_submission.csv')
print(f"Model saved to: {MODEL_PATH}")

Generating submission file...
✓ Saved submission to Transformer_submission.csv
Model saved to: drive/MyDrive/data//models/transformer.pth
