<a href="https://colab.research.google.com/github/xinyuezhang-shirley/cs229FinalProject/blob/main/CS229_ProjectionLayer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt

# Compute device: CUDA when available, otherwise CPU. Feature tensors and
# models in this notebook are moved here.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Training hyperparameters
BATCH_SIZE = 128          # pairs per batch; clip_loss treats the other pairs in a batch as negatives
EPOCHS = 500              # maximum epochs (early stop may cut short)
LR = 2e-3                 # base learning rate (after warmup)
TEMP = 0.7                # InfoNCE temperature (default of clip_loss; logits are divided by it)
SAMPLES_PER_EPOCH = 2500  # positive pairs PairDataset draws (with replacement) per epoch

# Learning rate schedule / early stopping
WARMUP_EPOCHS = 5         # linear warmup epochs
MIN_LR = 1e-3             # final minimum LR for cosine schedule
                          # NOTE(review): this is half of LR, so the cosine decay is shallow -- confirm intended
USE_COSINE_SCHEDULE = True  # when False, train_single_config uses a constant LR (no warmup, no cosine)
EARLY_STOP_PATIENCE = 25  # epochs (post-warmup) with no sufficient improvement
                          # NOTE(review): train_single_config also counts warmup epochs toward patience
EARLY_STOP_DELTA = 0.002  # required loss decrease to reset patience
MOVING_AVG_WINDOW = 10    # for smoothed loss (mean of the last N epoch losses)

# Modality weights for the combined poem-song cosine matrix used to mine pairs.
# NOTE(review): these sum to 1.10, while every GRID_WEIGHT_CONFIGS preset sums
# to 1.00, so EASY_THRESHOLD acts on a slightly different scale here -- confirm.
ALPHA = 0.6       # MPNet branch
BETA_EMO = 0.1    # emotion semantics
BETA_THEME = 0.15 # theme semantics
BETA_OTHER = 0.1  # other semantics (sentiment, subjectivity, concreteness, energy, narrative, imagery)
GAMMA = 0.15      # structural/lexical branch

# Unsupervised pair construction hyperparameters
POS_TOPK = 5        # positives per poem from similarity
HARD_TOPK = 0       # hard negatives per poem (near misses)
EASY_THRESHOLD = 0.25  # combined-cosine cutoff: songs at or below it are easy-negative candidates
# NOTE(review): hard and easy negative pairs are built, but PairDataset only
# samples pos_pairs, so they do not currently influence training.

In [74]:
# MPNet embeddings (raw, not yet filtered)
poem_vecs = np.load("data/processed/mpnet_embeddings_poems.npy")
song_vecs = np.load("data/processed/mpnet_embeddings_songs.npy")

# Load all features from full_features.npz
full = np.load("data/processed/full_features.npz", allow_pickle=True)

# Structural + lexical features (concatenated)
poem_struct = full["poem_struct"]  # (3413, 3)
poem_lexical = full["poem_lexical"]  # (3413, 3)
poem_feats = np.concatenate([poem_struct, poem_lexical], axis=1)  # (3413, 6)

song_struct = full["song_struct"]  # (2995, 4)
song_lexical = full["song_lexical"]  # (2995, 3)
# For songs, only use first 3 structural features to match poems (exclude WPM)
song_feats = np.concatenate([song_struct[:, :3], song_lexical], axis=1)  # (2995, 6)

# Semantic features
poem_sem_all = full["poem_semantic"]  # (3413, 36)
song_sem_all = full["song_semantic"]  # (2995, 36)

# Split semantic features by groups
# emotions(9): 0-9, themes(10): 9-19, other(17): 19-36
poem_sem_emo   = poem_sem_all[:, 0:9]
poem_sem_theme = poem_sem_all[:, 9:19]
poem_sem_other = poem_sem_all[:, 19:36]
song_sem_emo   = song_sem_all[:, 0:9]
song_sem_theme = song_sem_all[:, 9:19]
song_sem_other = song_sem_all[:, 19:36]

# Align song embeddings to match cleaned features
idx_map = full["song_source_indexes"]  # (2995,) maps cleaned songs -> raw embedding indices
song_vecs = song_vecs[idx_map]  # reorder raw embeddings to match cleaned data

print(f"Poems: {poem_vecs.shape[0]} items")
print(f"Songs: {song_vecs.shape[0]} items")
print(f"poem_vecs: {poem_vecs.shape}, song_vecs: {song_vecs.shape}")
print(f"poem_feats: {poem_feats.shape}, song_feats: {song_feats.shape}")
print(f"poem_sem (emo/theme/other): {poem_sem_emo.shape}, {poem_sem_theme.shape}, {poem_sem_other.shape}")
print(f"song_sem (emo/theme/other): {song_sem_emo.shape}, {song_sem_theme.shape}, {song_sem_other.shape}")


Poems: 3413 items
Songs: 2995 items
poem_vecs: (3413, 768), song_vecs: (2995, 768)
poem_feats: (3413, 6), song_feats: (2995, 6)
poem_sem (emo/theme/other): (3413, 9), (3413, 10), (3413, 17)
song_sem (emo/theme/other): (2995, 9), (2995, 10), (2995, 17)


In [None]:
# Normalize MPNet embeddings per row to balance scales
poem_vecs = poem_vecs / (np.linalg.norm(poem_vecs, axis=1, keepdims=True) + 1e-8)
song_vecs = song_vecs / (np.linalg.norm(song_vecs, axis=1, keepdims=True) + 1e-8)

# Build branch inputs
poem_in = {
    "mpnet": poem_vecs.astype(np.float32),
    "sem_emo":   poem_sem_emo.astype(np.float32),
    "sem_theme": poem_sem_theme.astype(np.float32),
    "sem_other": poem_sem_other.astype(np.float32),
    "feat":  poem_feats.astype(np.float32),
}
song_in = {
    "mpnet": song_vecs.astype(np.float32),
    "sem_emo":   song_sem_emo.astype(np.float32),
    "sem_theme": song_sem_theme.astype(np.float32),
    "sem_other": song_sem_other.astype(np.float32),
    "feat":  song_feats.astype(np.float32),
}

def _row_norm(x):
    return x / (np.linalg.norm(x, axis=1, keepdims=True) + 1e-8)

# Show each branch's shape in the fixed order mpnet, emo, theme, other, feat.
print("poem branches:", *(poem_in[key].shape for key in ("mpnet", "sem_emo", "sem_theme", "sem_other", "feat")))
print("song  branches:", *(song_in[key].shape for key in ("mpnet", "sem_emo", "sem_theme", "sem_other", "feat")))

# Compute or load precomputed pairwise cosine matrices per modality and combine with weights
import torch
from pathlib import Path

cosine_cache_dir = Path("data/processed/cosine_mats")
cosine_cache_dir.mkdir(parents=True, exist_ok=True)

def get_cosine_matrix(p_feat, s_feat, name):
    """Return the poem x song cosine-similarity matrix for one modality, on DEVICE.

    Rows of ``p_feat`` (n_poems, d) and ``s_feat`` (n_songs, d) are
    L2-normalised with ``_row_norm`` and multiplied, so entry ``[i, j]`` is the
    cosine similarity of poem i and song j. The float32 result is cached at
    ``cosine_cache_dir / f"cosine_{name}.npy"`` and reused on later calls.

    The cache is keyed by ``name`` only. A cached matrix whose shape does not
    match the current inputs is treated as stale and recomputed. A cache of the
    right shape built from *different* feature values cannot be detected here;
    delete the file when the features change.

    Args:
        p_feat: 2-D poem feature array.
        s_feat: 2-D song feature array with the same number of columns.
        name: modality name used in the cache filename and log messages.

    Returns:
        torch.float32 tensor of shape (n_poems, n_songs) on DEVICE.
    """
    cache_path = cosine_cache_dir / f"cosine_{name}.npy"
    expected_shape = (p_feat.shape[0], s_feat.shape[0])
    if cache_path.exists():
        print(f"Loading {name} cosine matrix from {cache_path}...")
        cached = np.load(cache_path)
        if cached.shape == expected_shape:
            return torch.from_numpy(cached).to(torch.float32).to(DEVICE)
        # Shape mismatch means the data changed since the cache was written.
        print(f"Cached {name} cosine matrix has shape {cached.shape}, expected {expected_shape}; recomputing...")
    print(f"Computing {name} cosine matrix on {DEVICE}...")
    p_norm = torch.from_numpy(_row_norm(p_feat)).to(torch.float32).to(DEVICE)
    s_norm = torch.from_numpy(_row_norm(s_feat)).to(torch.float32).to(DEVICE)
    mat = torch.matmul(p_norm, s_norm.T)
    np.save(cache_path, mat.cpu().numpy())
    print(f"Saved {name} cosine matrix to {cache_path}")
    return mat

# One cosine-similarity matrix per modality, loaded from / saved to the disk cache.
cos_mpnet, cos_emo, cos_theme, cos_other, cos_feat = [
    get_cosine_matrix(poem_in[branch], song_in[branch], branch)
    for branch in ("mpnet", "sem_emo", "sem_theme", "sem_other", "feat")
]

# Weighted blend of the modalities; this combined matrix drives pair mining.
cos_matrix_t = (ALPHA * cos_mpnet + BETA_EMO * cos_emo + BETA_THEME * cos_theme
                + BETA_OTHER * cos_other + GAMMA * cos_feat)

print("Combined cosine matrix shape:", cos_matrix_t.shape)

# Build pos/hard/neg pairs from current hyperparameters (always recompute based on thresholds)
print(f"Building pairs with POS_TOPK={POS_TOPK}, HARD_TOPK={HARD_TOPK}, EASY_THRESHOLD={EASY_THRESHOLD}...")
P, S = cos_matrix_t.shape
pos_pairs = []
hard_pairs = []
neg_pairs = []

with torch.no_grad():
    # Rank songs for every poem once and copy the indices to the host in a
    # single transfer (indexing the device tensor per row forced P syncs).
    ranked = torch.topk(
        cos_matrix_t, k=min(S, POS_TOPK + HARD_TOPK), dim=1, largest=True, sorted=True
    ).indices.cpu().tolist()

    # The POS_TOPK most similar songs are positives; the next HARD_TOPK are hard negatives.
    for i, row in enumerate(ranked):
        pos_pairs.extend((i, j) for j in row[:POS_TOPK])
        hard_pairs.extend((i, j) for j in row[POS_TOPK:POS_TOPK + HARD_TOPK])

    # Easy negatives: up to 5 songs per poem with combined cosine <= EASY_THRESHOLD,
    # drawn with NumPy's global RNG. The mask is moved to the host once, not per row.
    easy_mask = (cos_matrix_t <= EASY_THRESHOLD).cpu().numpy()
    for i in range(P):
        low_idxs = np.flatnonzero(easy_mask[i])
        if low_idxs.size > 0:
            sample_ct = min(5, low_idxs.size)
            choice = np.random.choice(low_idxs, size=sample_ct, replace=False)
            neg_pairs.extend((i, int(j)) for j in choice)

print(f"Built pairs -> pos: {len(pos_pairs)} hard: {len(hard_pairs)} easy: {len(neg_pairs)}")

poem branches: (3413, 768) (3413, 9) (3413, 10) (3413, 17) (3413, 6)
song  branches: (2995, 768) (2995, 9) (2995, 10) (2995, 17) (2995, 6)
Loading precomputed cosine matrix from data/processed/mpnet_pairwise_cosine_matrix.npy...
Loaded matrix shape: torch.Size([3413, 2995])
Building pairs with POS_TOPK=5, HARD_TOPK=10, EASY_THRESHOLD=0.25...
Built pairs -> pos: 17065 hard: 34130 easy: 17065


In [76]:
# Materialise every branch array on DEVICE once, so a training batch is just an
# index lookup into these tensors (no per-batch host->device copies).
poem_gpu = {
    key: torch.from_numpy(arr).to(device=DEVICE, dtype=torch.float32)
    for key, arr in poem_in.items()
}
song_gpu = {
    key: torch.from_numpy(song_in[key]).to(device=DEVICE, dtype=torch.float32)
    for key in poem_in
}

# Epoch-sized dataset of positive pairs and a shuffling loader over it.
# NOTE(review): PairDataset is defined in a later cell of this notebook, so this
# cell only works once that cell has been executed (execution-order dependency).
dataset = PairDataset(pos_pairs, neg_pairs, hard_pairs, size=SAMPLES_PER_EPOCH)
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=False)

print(f"Dataset size: {len(dataset)} samples/epoch, loader batches: {len(loader)}")

Dataset size: 5000 samples/epoch, loader batches: 40


In [77]:
def clip_loss(poem_emb, song_emb, temperature=TEMP):
    """
    Symmetric InfoNCE (CLIP-style) loss over a batch of matched poem/song pairs.

    Row i of ``poem_emb`` and row i of ``song_emb`` form the positive pair;
    every other song (resp. poem) in the batch acts as a negative.

    Args:
        poem_emb: [B, D] poem embeddings (L2-normalised again here)
        song_emb: [B, D] song embeddings (L2-normalised again here)
        temperature: divisor applied to the cosine logits

    Returns:
        Scalar loss: mean of the poem->song and song->poem cross-entropies.
    """
    sim = F.normalize(poem_emb, dim=1) @ F.normalize(song_emb, dim=1).T
    sim = sim / temperature

    # The matching song for poem i sits in column i.
    targets = torch.arange(poem_emb.shape[0], device=poem_emb.device)

    p2s = F.cross_entropy(sim, targets)
    s2p = F.cross_entropy(sim.T, targets)
    return (p2s + s2p) / 2.0

In [80]:
class PairDataset(Dataset):
    """Epoch-sized dataset that yields random positive (poem, song) index pairs.

    ``__getitem__`` ignores its index and draws a uniformly random pair from
    ``pos_pairs`` using NumPy's global RNG, so one "epoch" is ``size`` pairs
    sampled with replacement. ``neg_pairs`` and ``hard_pairs`` are stored but
    not sampled here; the training loop's ``clip_loss`` uses the other pairs in
    a batch as negatives.
    """

    def __init__(self, pos_pairs, neg_pairs, hard_pairs, size):
        """
        Args:
            pos_pairs: sequence of (poem_idx, song_idx) positive pairs; must be non-empty.
            neg_pairs: sequence of easy-negative pairs (kept for reference).
            hard_pairs: sequence of hard-negative pairs (kept for reference).
            size: number of samples per epoch (the value ``len()`` reports).

        Raises:
            ValueError: if ``pos_pairs`` is empty (nothing could ever be sampled).
        """
        # len() rather than truthiness so NumPy arrays of pairs are accepted too.
        if len(pos_pairs) == 0:
            raise ValueError("PairDataset needs at least one positive pair; pos_pairs is empty.")
        self.pos_pairs  = pos_pairs
        self.neg_pairs  = neg_pairs
        self.hard_pairs = hard_pairs
        self.size = size

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        # idx is intentionally ignored: sampling is random with replacement.
        i_poem, j_song = self.pos_pairs[np.random.randint(len(self.pos_pairs))]

        # Return indices only (training loop will index the actual data)
        return i_poem, j_song


In [None]:
class ProjectionModel(nn.Module):
    """Two-tower projection head mapping poems and songs into one embedding space.

    Each tower has five branches:
      - MPNet, projected to 128 dims;
      - emotion, theme and "other" semantics, 64 dims each;
      - structural/lexical features, 64 dims.

    The branch outputs are concatenated (384 dims), passed through LayerNorm
    and a Linear layer to ``proj_dim``, and L2-normalised.

    Args:
        p_dims: ``(mpnet, emo, theme, other, feat)`` input widths for poems.
        s_dims: the same tuple for songs.
        proj_dim: width of the shared output embedding.

    A semantic width of 0 marks a missing group: its branch is built with a
    1-wide input and receives a zero column at forward time. The branch
    attribute names define the ``state_dict`` keys and are kept stable for
    checkpoint compatibility.
    """

    def __init__(self, p_dims, s_dims, proj_dim):
        super().__init__()
        p_mp, p_emo, p_theme, p_other, p_ft = p_dims
        s_mp, s_emo, s_theme, s_other, s_ft = s_dims
        # Semantic widths consulted by the forward passes. These were previously
        # read from module globals (p_dim_emo, s_dim_emo, ...) that are never
        # defined, which made forward_poem/forward_song raise NameError.
        self.p_sem_dims = (p_emo, p_theme, p_other)
        self.s_sem_dims = (s_emo, s_theme, s_other)
        concat_dim = 128 + 64 + 64 + 64 + 64
        # poem branches (created in the original order so seeded init is unchanged)
        self.poem_mp = nn.Sequential(nn.Linear(p_mp, 256), nn.ReLU(), nn.Linear(256, 128))
        self.poem_emo = self._mlp64(max(p_emo, 1))
        self.poem_theme = self._mlp64(max(p_theme, 1))
        self.poem_other = self._mlp64(max(p_other, 1))
        self.poem_ft = self._mlp64(p_ft)
        self.poem_proj = nn.Sequential(nn.LayerNorm(concat_dim), nn.Linear(concat_dim, proj_dim))
        # song branches
        self.song_mp = nn.Sequential(nn.Linear(s_mp, 256), nn.ReLU(), nn.Linear(256, 128))
        self.song_emo = self._mlp64(max(s_emo, 1))
        self.song_theme = self._mlp64(max(s_theme, 1))
        self.song_other = self._mlp64(max(s_other, 1))
        self.song_ft = self._mlp64(s_ft)
        self.song_proj = nn.Sequential(nn.LayerNorm(concat_dim), nn.Linear(concat_dim, proj_dim))

    @staticmethod
    def _mlp64(in_dim):
        """Two-layer ReLU MLP: in_dim -> 64 -> 64."""
        return nn.Sequential(nn.Linear(in_dim, 64), nn.ReLU(), nn.Linear(64, 64))

    @staticmethod
    def _sem_input(batch, key, width):
        """Return ``batch[key]``, or a (B, 1) zero column when ``width`` is 0."""
        if width > 0:
            return batch[key]
        ref = batch["mpnet"]
        return torch.zeros(ref.shape[0], 1, device=ref.device)

    def forward_poem(self, p):
        """Encode a dict of poem branch tensors into L2-normalised embeddings."""
        emo_w, theme_w, other_w = self.p_sem_dims
        comb = torch.cat(
            [
                self.poem_mp(p["mpnet"]),
                self.poem_emo(self._sem_input(p, "sem_emo", emo_w)),
                self.poem_theme(self._sem_input(p, "sem_theme", theme_w)),
                self.poem_other(self._sem_input(p, "sem_other", other_w)),
                self.poem_ft(p["feat"]),
            ],
            dim=1,
        )
        return F.normalize(self.poem_proj(comb), dim=1)

    def forward_song(self, s):
        """Encode a dict of song branch tensors into L2-normalised embeddings."""
        emo_w, theme_w, other_w = self.s_sem_dims
        comb = torch.cat(
            [
                self.song_mp(s["mpnet"]),
                self.song_emo(self._sem_input(s, "sem_emo", emo_w)),
                self.song_theme(self._sem_input(s, "sem_theme", theme_w)),
                self.song_other(self._sem_input(s, "sem_other", other_w)),
                self.song_ft(s["feat"]),
            ],
            dim=1,
        )
        return F.normalize(self.song_proj(comb), dim=1)

In [None]:
# Helper utilities for grid search experiments (pair building, training, evaluation)
from itertools import product
import time
import math

def _branch_widths(tensors):
    """Return the (mpnet, emo, theme, other, feat) input widths for ProjectionModel.

    Semantic groups stored as 1-D tensors are reported as width 0.
    """
    def sem_width(key):
        t = tensors[key]
        return t.shape[1] if t.ndim > 1 else 0

    return (
        tensors["mpnet"].shape[1],
        sem_width("sem_emo"),
        sem_width("sem_theme"),
        sem_width("sem_other"),
        tensors["feat"].shape[1],
    )


P_DIMS_GRID = _branch_widths(poem_gpu)
S_DIMS_GRID = _branch_widths(song_gpu)

# Output width: an explicit `proj_dim` global wins; otherwise reuse the width of
# an existing `model` (falling back to 128). The fallback is evaluated first,
# exactly as the default argument of globals().get was before.
if "model" in globals():
    _proj_dim_fallback = getattr(model.poem_proj[-1], "out_features", 128)
else:
    _proj_dim_fallback = 128
PROJ_DIM_GRID = int(globals().get("proj_dim", _proj_dim_fallback))
del _proj_dim_fallback

def combine_cosine_matrix(weights):
    """Blend the cached per-modality cosine matrices into one similarity matrix.

    Args:
        weights: mapping with keys "alpha" (MPNet), "beta_emo", "beta_theme",
            "beta_other" and "gamma" (structural/lexical features).

    Returns:
        sum of weights[key] * matching cosine matrix, accumulated in the order above.
    """
    weighted = (
        ("alpha", cos_mpnet),
        ("beta_emo", cos_emo),
        ("beta_theme", cos_theme),
        ("beta_other", cos_other),
        ("gamma", cos_feat),
    )
    first_key, first_mat = weighted[0]
    total = weights[first_key] * first_mat
    for key, mat in weighted[1:]:
        total = total + weights[key] * mat
    return total

def build_pairs_from_matrix(
    cos_matrix,
    pos_topk=POS_TOPK,
    hard_topk=HARD_TOPK,
    easy_threshold=EASY_THRESHOLD,
    easy_samples=5,
    rng=None,
):
    """Return positive / hard / easy pairs based on similarity thresholds.

    For each poem (row of ``cos_matrix``, shape (n_poems, n_songs)):
      - positives: the ``pos_topk`` most similar songs;
      - hard negatives: the next ``hard_topk`` songs in that ranking;
      - easy negatives: up to ``easy_samples`` songs drawn without replacement
        from those with similarity <= ``easy_threshold``. This step is skipped
        when ``easy_threshold`` is None.

    Args:
        cos_matrix: 2-D torch tensor of poem-song similarities (any device).
        rng: object with a NumPy-style ``choice`` (e.g. ``np.random.default_rng``);
            defaults to the global ``np.random`` module.

    Returns:
        (pos_pairs, hard_pairs, neg_pairs): lists of (poem_idx, song_idx) int tuples.
    """
    rng = rng or np.random
    n_poems, n_songs = cos_matrix.shape
    pos_pairs, hard_pairs, neg_pairs = [], [], []
    with torch.no_grad():
        k = min(n_songs, pos_topk + hard_topk)
        if k > 0:
            # One device->host copy of the whole ranking instead of one per row.
            ranked = torch.topk(cos_matrix, k=k, dim=1, largest=True, sorted=True).indices.cpu().tolist()
            for i, row in enumerate(ranked):
                if pos_topk > 0:
                    pos_pairs.extend((i, j) for j in row[:pos_topk])
                if hard_topk > 0:
                    hard_pairs.extend((i, j) for j in row[pos_topk : pos_topk + hard_topk])
        if easy_threshold is not None:
            # Move the boolean mask to the host once; per-row nonzero + .cpu()
            # previously forced a device sync for every poem.
            easy_mask = (cos_matrix <= easy_threshold).cpu().numpy()
            for i in range(n_poems):
                low_idxs = np.flatnonzero(easy_mask[i])
                if low_idxs.size == 0:
                    continue
                sample_ct = min(easy_samples, low_idxs.size)
                choice = rng.choice(low_idxs, size=sample_ct, replace=False)
                neg_pairs.extend((i, int(j)) for j in choice)
    return pos_pairs, hard_pairs, neg_pairs

def make_loader_for_config(pos_pairs, neg_pairs, hard_pairs, batch_size, samples_per_epoch):
    """Build a PairDataset of ``samples_per_epoch`` samples and a shuffling loader.

    Returns:
        (dataset, loader); the loader keeps the final partial batch.
    """
    ds = PairDataset(pos_pairs, neg_pairs, hard_pairs, size=samples_per_epoch)
    return ds, DataLoader(ds, batch_size=batch_size, shuffle=True, drop_last=False)

def train_single_config(
    loader,
    temperature,
    lr,
    epochs,
    patience,
    delta,
    warmup_epochs,
    min_lr,
    use_cosine_schedule,
    moving_avg_window,
    seed=None,
):
    """Train a ProjectionModel with the provided loader and hyperparameters.

    Each batch from ``loader`` is a pair of index tensors (poem indices, song
    indices). They gather rows from the global ``poem_gpu`` / ``song_gpu``
    dicts, and the two towers are trained with ``clip_loss`` using Adam.

    Args:
        loader: iterable yielding ``(pidxs, sidxs)`` index batches.
        temperature: temperature passed to ``clip_loss``.
        lr: base Adam learning rate.
        epochs: maximum number of epochs.
        patience: stop after this many consecutive epochs without the smoothed
            loss improving by more than ``delta``.
        delta: minimum decrease of the smoothed loss that counts as improvement.
        warmup_epochs: linear warmup length (used only with the cosine schedule).
        min_lr: ``eta_min`` of the cosine schedule.
        use_cosine_schedule: when False no scheduler is created and the LR
            stays at ``lr``.
        moving_avg_window: number of recent epoch losses averaged into the
            smoothed loss; <= 0 averages the whole history.
        seed: optional seed for torch and NumPy's global RNG.

    Returns:
        dict with keys:
          - ``model``: loaded with the weights snapshotted at the last
            improvement of the smoothed loss;
          - ``loss_history``, ``lr_history``;
          - ``best_loss``: the best smoothed loss;
          - ``epochs_trained``;
          - ``train_time``: in seconds.
    """
    if seed is not None:
        # NumPy's global RNG is what PairDataset.__getitem__ samples from.
        torch.manual_seed(seed)
        np.random.seed(seed)
    model_local = ProjectionModel(P_DIMS_GRID, S_DIMS_GRID, PROJ_DIM_GRID).to(DEVICE)
    optimizer = torch.optim.Adam(model_local.parameters(), lr=lr)
    # Warmup can never be longer than the whole run.
    warm_epochs = min(warmup_epochs, epochs)
    warmup_scheduler = None
    scheduler_main = None
    if use_cosine_schedule and epochs > 0:
        # At least one warmup epoch; this also keeps lr_lambda from dividing by zero.
        warm_epochs = max(1, warm_epochs)
        def lr_lambda(epoch):
            # Linear ramp of the LR multiplier: (epoch + 1) / warm_epochs, capped at 1.
            return min(1.0, (epoch + 1) / warm_epochs)
        warmup_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
        # Cosine decay towards min_lr over the post-warmup epochs.
        t_max = max(1, epochs - warm_epochs)
        # NOTE(review): both schedulers are attached to the same optimizer; the
        # effective LR curve depends on how they interact -- inspect lr_history.
        scheduler_main = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=t_max, eta_min=min_lr)
    loss_history, lr_history = [], []
    best_loss = math.inf
    best_state = None
    patience_counter = 0
    start_time = time.time()
    for epoch in range(epochs):
        epoch_losses = []
        for pidxs, sidxs in loader:
            optimizer.zero_grad()
            # Gather this batch's rows from the pre-built device tensors.
            batch_poem = {k: poem_gpu[k][pidxs] for k in poem_gpu}
            batch_song = {k: song_gpu[k][sidxs] for k in song_gpu}
            poem_out = model_local.forward_poem(batch_poem)
            song_out = model_local.forward_song(batch_song)
            # Row i of poem_out pairs with row i of song_out; other rows are negatives.
            loss = clip_loss(poem_out, song_out, temperature=temperature)
            loss.backward()
            optimizer.step()
            epoch_losses.append(loss.item())
        # An empty loader yields no losses; stop rather than average nothing.
        if not epoch_losses:
            break
        # Unweighted mean over batches (a smaller final batch counts fully).
        avg_loss = float(np.mean(epoch_losses))
        loss_history.append(avg_loss)
        # Smoothed loss = mean of the last `window` epoch losses.
        window = moving_avg_window if moving_avg_window > 0 else len(loss_history)
        smooth_loss = float(np.mean(loss_history[-window:]))
        # Improvement must exceed `delta`; on improvement snapshot the weights to CPU.
        # NOTE(review): the counter also advances during warmup, although the
        # EARLY_STOP_PATIENCE comment describes patience as post-warmup -- confirm.
        if smooth_loss < best_loss - delta:
            best_loss = smooth_loss
            patience_counter = 0
            best_state = {k: v.detach().cpu().clone() for k, v in model_local.state_dict().items()}
        else:
            patience_counter += 1
        # Warmup scheduler for the first warm_epochs epochs, cosine afterwards.
        if warmup_scheduler and epoch < warm_epochs:
            warmup_scheduler.step()
        elif scheduler_main:
            scheduler_main.step()
        # Recorded after this epoch's scheduler step, i.e. the LR for the next epoch.
        lr_history.append(optimizer.param_groups[0]["lr"])
        if patience_counter >= patience:
            break
    train_time = time.time() - start_time
    # Restore the snapshot taken at the best smoothed loss, if any was taken.
    if best_state:
        model_local.load_state_dict(best_state)
    return {
        "model": model_local,
        "loss_history": loss_history,
        "lr_history": lr_history,
        "best_loss": best_loss,
        "epochs_trained": len(loss_history),
        "train_time": train_time,
    }

def evaluate_on_triplets(model_local, batch_size=1024):
    """Return the fraction of human-labelled triplets the model ranks correctly.

    Reads the global ``human_triplets``; each entry is unpacked as
    ``(p_idx, s1_idx, s2_idx, label)``. The prediction is 1 when the poem
    embedding has a strictly larger dot product with song ``s1_idx`` than with
    song ``s2_idx``, and 2 otherwise (ties go to 2). A triplet is correct when
    the prediction equals ``label``.

    Triplets are encoded in chunks of ``batch_size``, i.e. 3 forward passes
    per chunk instead of 3 per triplet. Scores can differ from single-row
    passes only by floating-point rounding.

    Args:
        model_local: module exposing ``forward_poem`` / ``forward_song``.
        batch_size: triplets encoded per forward pass (values < 1 are treated as 1).

    Returns:
        Accuracy in [0, 1]; 0.0 when there are no triplets.

    Raises:
        RuntimeError: if ``human_triplets`` is not defined.
    """
    if "human_triplets" not in globals():
        raise RuntimeError("human_triplets is not defined in the workspace.")
    triplets = list(human_triplets)
    if not triplets:
        return 0.0
    batch_size = max(1, int(batch_size))
    # Evaluate in eval mode, then restore whatever mode the caller had.
    was_training = model_local.training
    model_local.eval()
    correct = 0
    try:
        with torch.no_grad():
            for start in range(0, len(triplets), batch_size):
                chunk = triplets[start : start + batch_size]
                p_idx = torch.tensor([int(t[0]) for t in chunk], dtype=torch.long, device=DEVICE)
                s1_idx = torch.tensor([int(t[1]) for t in chunk], dtype=torch.long, device=DEVICE)
                s2_idx = torch.tensor([int(t[2]) for t in chunk], dtype=torch.long, device=DEVICE)
                p_z = model_local.forward_poem({k: poem_gpu[k][p_idx] for k in poem_gpu})
                s1_z = model_local.forward_song({k: song_gpu[k][s1_idx] for k in song_gpu})
                s2_z = model_local.forward_song({k: song_gpu[k][s2_idx] for k in song_gpu})
                sim1 = (p_z * s1_z).sum(dim=1)
                sim2 = (p_z * s2_z).sum(dim=1)
                for first_wins, (_, _, _, label) in zip((sim1 > sim2).tolist(), chunk):
                    pred = 1 if first_wins else 2
                    if pred == label:
                        correct += 1
    finally:
        model_local.train(was_training)
    return correct / len(triplets)

In [None]:
# Grid search runner (toggle RUN_GRID_SEARCH to True to execute all configurations)
import json

RUN_GRID_SEARCH = True  # Notebook default: execute grid search when this cell runs
SAVE_ALL_MODELS = True  # When True, store every trained checkpoint under results/
GRID_BATCH_SIZES = [96, 128]
GRID_TEMPERATURES = [0.5, 0.7, 0.9]
# Modality-weight presets; each one sums to 1.0.
GRID_WEIGHT_CONFIGS = [
    {"name": "mpnet_heavy", "alpha": 0.65, "beta_emo": 0.08, "beta_theme": 0.12, "beta_other": 0.05, "gamma": 0.10},
    {"name": "balanced", "alpha": 0.55, "beta_emo": 0.12, "beta_theme": 0.15, "beta_other": 0.08, "gamma": 0.10},
    {"name": "semantic_push", "alpha": 0.45, "beta_emo": 0.18, "beta_theme": 0.20, "beta_other": 0.07, "gamma": 0.10},
    {"name": "structure_boost", "alpha": 0.50, "beta_emo": 0.10, "beta_theme": 0.12, "beta_other": 0.08, "gamma": 0.20},
    {"name": "other_focus", "alpha": 0.50, "beta_emo": 0.10, "beta_theme": 0.12, "beta_other": 0.18, "gamma": 0.10},
]
GRID_POS_TOPK = [POS_TOPK]
GRID_EASY_THRESH = [EASY_THRESHOLD]
GRID_SAMPLES_PER_EPOCH = [4000, 6000]
MAX_CONFIGS = 60  # cap on runs; the full product above is 2 * 3 * 5 * 1 * 1 * 2 = 60

# Expand the Cartesian product into per-run config dicts.
# NOTE(review): the label (and therefore the checkpoint filename) omits pos_topk
# and easy_threshold; widening those grids would make checkpoints overwrite each other.
grid_configs = []
for bs, temp, weights, pos_k, easy_thr, spe in product(
    GRID_BATCH_SIZES,
    GRID_TEMPERATURES,
    GRID_WEIGHT_CONFIGS,
    GRID_POS_TOPK,
    GRID_EASY_THRESH,
    GRID_SAMPLES_PER_EPOCH,
):
    cfg = {
        "label": f"{weights['name']}_bs{bs}_t{temp:.2f}_spe{spe}",
        "batch_size": bs,
        "temperature": temp,
        "weights": weights,
        "pos_topk": pos_k,
        "hard_topk": HARD_TOPK,
        "easy_threshold": easy_thr,
        "samples_per_epoch": spe,
        "lr": LR,
        "epochs": min(EPOCHS, 150),
        "patience": min(EARLY_STOP_PATIENCE, 20),
        "delta": EARLY_STOP_DELTA,
        "warmup_epochs": min(WARMUP_EPOCHS, 5),
        "min_lr": MIN_LR,
        "use_cosine_schedule": USE_COSINE_SCHEDULE,
        "moving_avg_window": MOVING_AVG_WINDOW,
        "seed": 1337,
    }
    grid_configs.append(cfg)
    if len(grid_configs) >= MAX_CONFIGS:
        break

if not RUN_GRID_SEARCH:
    print(f"Grid search configured for {len(grid_configs)} runs. Set RUN_GRID_SEARCH = True to execute.")
else:
    # Fail fast: every run ends with evaluate_on_triplets, which needs the global
    # human_triplets. Checking up front avoids training a full config and only
    # then crashing.
    if "human_triplets" not in globals():
        raise RuntimeError("human_triplets is not defined in the workspace.")
    results = []
    best_result = None
    result_dir = Path("results")
    result_dir.mkdir(parents=True, exist_ok=True)
    for idx, cfg in enumerate(grid_configs, start=1):
        print("\n" + "=" * 80)
        print(f"Config {idx}/{len(grid_configs)} :: {cfg['label']}")
        weights = cfg["weights"]
        cos_matrix = combine_cosine_matrix(weights)
        # Per-config RNG so easy-negative sampling is reproducible and differs between runs.
        rng = np.random.default_rng(cfg.get("seed", 0) + idx)
        pos_pairs, hard_pairs, neg_pairs = build_pairs_from_matrix(
            cos_matrix,
            pos_topk=cfg["pos_topk"],
            hard_topk=cfg["hard_topk"],
            easy_threshold=cfg["easy_threshold"],
            easy_samples=5,
            rng=rng,
        )
        _, loader_local = make_loader_for_config(
            pos_pairs,
            neg_pairs,
            hard_pairs,
            batch_size=cfg["batch_size"],
            samples_per_epoch=cfg["samples_per_epoch"],
        )
        train_out = train_single_config(
            loader_local,
            temperature=cfg["temperature"],
            lr=cfg["lr"],
            epochs=cfg["epochs"],
            patience=cfg["patience"],
            delta=cfg["delta"],
            warmup_epochs=cfg["warmup_epochs"],
            min_lr=cfg["min_lr"],
            use_cosine_schedule=cfg["use_cosine_schedule"],
            moving_avg_window=cfg["moving_avg_window"],
            seed=cfg.get("seed"),
        )
        triplet_acc = evaluate_on_triplets(train_out["model"])
        result_row = {
            "label": cfg["label"],
            "batch_size": cfg["batch_size"],
            "temperature": cfg["temperature"],
            "weights": weights,
            "triplet_acc": triplet_acc,
            "best_loss": train_out["best_loss"],
            "epochs_trained": train_out["epochs_trained"],
            "train_time_sec": train_out["train_time"],
            "pos_pairs": len(pos_pairs),
        }
        results.append(result_row)
        checkpoint_name = f"grid_{cfg['label']}.pt"
        checkpoint_path = result_dir / checkpoint_name
        if SAVE_ALL_MODELS:
            torch.save(train_out["model"].state_dict(), checkpoint_path)
        print(
            f"{cfg['label']} → acc {triplet_acc*100:.2f}% | best loss {train_out['best_loss']:.4f} | checkpoint {checkpoint_name if SAVE_ALL_MODELS else 'n/a'}"
        )
        if best_result is None or result_row["triplet_acc"] > best_result["triplet_acc"]:
            # Detached CPU copy, matching train_single_config's snapshots: on a
            # CPU device `.cpu()` alone would return the live parameter tensors.
            best_state = {k: v.detach().cpu().clone() for k, v in train_out["model"].state_dict().items()}
            best_result = {**result_row, "state_dict": best_state}
    results_sorted = sorted(results, key=lambda r: r["triplet_acc"], reverse=True)
    summary_path = result_dir / "grid_search_results.json"
    with summary_path.open("w", encoding="utf-8") as f:
        json.dump(results_sorted, f, indent=2)
    print(f"\nSaved summary to {summary_path}")
    if best_result:
        best_path = result_dir / f"grid_best_{best_result['label']}.pt"
        torch.save(best_result["state_dict"], best_path)
        print(
            f"Best config: {best_result['label']} | acc {best_result['triplet_acc']*100:.2f}% | checkpoint → {best_path}",
        )
    print("Top-5 configs:")
    for row in results_sorted[:5]:
        print(
            f"  {row['label']}: acc {row['triplet_acc']*100:.2f}% | loss {row['best_loss']:.4f} | epochs {row['epochs_trained']}"
        )