# v213: Stronger Transformer for Beignet Public

**Current Transformer (TF):** h=256, L=4, heads=4, dropout=0.2, val MSE ~53-56K (vs. TCN ~47-50K)

**Goal:** Close the gap to TCN quality. Better TF = better ensemble.

**Improvements to try:**
1. Pre-LayerNorm (more stable training)
2. Longer training (up to 400 epochs with early stopping, patience=50)
3. Learning rate warmup (10 epochs)
4. Deeper model (L=6)
5. More heads (8)

**Plan:** Sweep configs, then train 5 seeds with best config

In [None]:
# Mount Google Drive so the project folders below are reachable from this Colab session.
from google.colab import drive
drive.mount('/content/drive')

# Root of the project on the mounted drive; every other path is derived from it.
PROJECT_ROOT = '/content/drive/MyDrive/Hackathon_NSF_Neural_Forecasting'
# Folder holding the training .npz archive read in the data-loading cell.
TRAIN_DIR = f'{PROJECT_ROOT}/1_data/raw/train_data_neuro'
# Folder holding the masked test .npz archive read in the data-loading cell.
TEST_DIR = f'{PROJECT_ROOT}/1_data/raw/test_dev_input'

import os, time, torch, numpy as np, math
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

# Pick the compute device once: CUDA when a GPU is visible, CPU otherwise.
# The GPU name is reported only in the CUDA case.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Device: {device}')
if device.type == 'cuda':
    print(f'GPU: {torch.cuda.get_device_name(0)}')

In [None]:
# ============================================================================
# Cell 2: Architecture + CORAL + Augmentation
# ============================================================================

def coral_loss(source, target):
    """Penalise the gap between source and target feature covariances.

    Each input is treated as a 2-D ``(samples, features)`` matrix. The
    covariance of each is estimated with an ``n - 1`` denominator (plus a
    ``1e-8`` guard so a single-sample batch does not divide by zero), and
    the summed squared difference of the two covariance matrices is
    divided by ``4 * d`` where ``d`` is the feature count of ``source``.
    """
    def _covariance(feats):
        # Centre over the sample axis, then form the feature-by-feature product.
        centered = feats - feats.mean(0, keepdim=True)
        return centered.T @ centered / (feats.size(0) - 1 + 1e-8)

    n_dims = source.size(1)
    cov_gap = _covariance(source) - _covariance(target)
    return (cov_gap ** 2).sum() / (4 * n_dims)

def mean_alignment_loss(source, target):
    """Mean squared gap between the sample-averaged source and target features.

    Both inputs are averaged over axis 0; the element-wise squared
    difference of the two averages is then reduced to a scalar mean.
    """
    mean_gap = source.mean(dim=0) - target.mean(dim=0)
    return mean_gap.pow(2).mean()

class TransformerForecaster(nn.Module):
    """Pre-LayerNorm Transformer forecaster.

    Every channel is encoded as its own sequence: the input features are
    concatenated with a learned per-channel embedding, projected to width
    ``h``, given a learned positional offset and passed through a
    ``norm_first=True`` Transformer encoder. The last time step of each
    channel is then mixed across channels with multi-head attention and a
    small MLP head emits 10 values per channel.

    The positional table has 10 slots, so inputs longer than 10 time steps
    are not supported.
    """

    def __init__(self, n_ch, n_feat=1, h=256, n_layers=4, n_heads=4, dropout=0.2):
        super().__init__()
        # NOTE: modules are created in a fixed order so that seeded
        # initialisation and state-dict keys stay stable.
        embed_dim = h // 4
        self.channel_embed = nn.Embedding(n_ch, embed_dim)
        self.input_proj = nn.Linear(n_feat + embed_dim, h)
        self.pos_embed = nn.Parameter(torch.randn(1, 10, h) * 0.02)
        block = nn.TransformerEncoderLayer(
            d_model=h,
            nhead=n_heads,
            dim_feedforward=h * 4,
            dropout=dropout,
            batch_first=True,
            activation='gelu',
            norm_first=True,  # Pre-LN: LayerNorm before attention / feed-forward
        )
        self.transformer = nn.TransformerEncoder(block, num_layers=n_layers)
        self.cross_attn = nn.MultiheadAttention(h, n_heads, dropout=dropout, batch_first=True)
        self.attn_norm = nn.LayerNorm(h)
        self.pred_head = nn.Sequential(
            nn.Linear(h, h),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(h, 10),
        )

    def forward(self, x, return_features=False):
        """Forecast from ``x`` of shape (batch, time, channel, feature).

        Returns the prediction with shape (batch, 10, channel). When
        ``return_features`` is true, also returns the channel-averaged
        hidden state of shape (batch, h).
        """
        batch, steps, n_chan, _ = x.shape

        # Broadcast one embedding per channel over batch and time.
        chan_ids = torch.arange(n_chan, device=x.device)
        chan_vecs = self.channel_embed(chan_ids)[None, None].expand(batch, steps, -1, -1)

        # Fold channels into the batch axis: one sequence per (sample, channel).
        tokens = torch.cat([x, chan_vecs], -1)
        tokens = tokens.permute(0, 2, 1, 3).reshape(batch * n_chan, steps, -1)
        tokens = self.input_proj(tokens) + self.pos_embed[:, :steps, :]
        encoded = self.transformer(tokens)

        # Keep the final time step and restore the channel axis.
        summary = encoded[:, -1, :].view(batch, n_chan, -1)

        # Residual attention across channels, followed by LayerNorm.
        mixed = self.cross_attn(summary, summary, summary)[0]
        summary = self.attn_norm(summary + mixed)

        pred = self.pred_head(summary).transpose(1, 2)
        if not return_features:
            return pred
        return pred, summary.mean(dim=1)

# Original (Post-LN) version for comparison
class TransformerForecasterPostLN(nn.Module):
    """Post-LayerNorm baseline forecaster.

    Same layout as the Pre-LN model in this file, except that the encoder
    layers are built without ``norm_first`` (the ``nn.TransformerEncoderLayer``
    default): per-channel sequences are embedded, encoded, reduced to their
    last time step, mixed across channels with attention and mapped to 10
    outputs per channel.
    """

    def __init__(self, n_ch, n_feat=1, h=256, n_layers=4, n_heads=4, dropout=0.2):
        super().__init__()
        # NOTE: creation order is kept fixed so seeded initialisation and
        # state-dict keys do not change.
        ch_dim = h // 4
        self.channel_embed = nn.Embedding(n_ch, ch_dim)
        self.input_proj = nn.Linear(n_feat + ch_dim, h)
        self.pos_embed = nn.Parameter(torch.randn(1, 10, h) * 0.02)
        layer = nn.TransformerEncoderLayer(
            d_model=h,
            nhead=n_heads,
            dim_feedforward=h * 4,
            dropout=dropout,
            batch_first=True,
            activation='gelu',
        )
        self.transformer = nn.TransformerEncoder(layer, num_layers=n_layers)
        self.cross_attn = nn.MultiheadAttention(h, n_heads, dropout=dropout, batch_first=True)
        self.attn_norm = nn.LayerNorm(h)
        self.pred_head = nn.Sequential(
            nn.Linear(h, h),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(h, 10),
        )

    def forward(self, x, return_features=False):
        """Forecast from ``x`` of shape (batch, time, channel, feature).

        Returns a (batch, 10, channel) prediction, plus the channel-averaged
        hidden features of shape (batch, h) when ``return_features`` is true.
        """
        n_batch, n_steps, n_chan, _ = x.shape

        # One learned vector per channel, repeated over batch and time.
        ids = torch.arange(n_chan, device=x.device)
        ch_emb = self.channel_embed(ids)[None, None].expand(n_batch, n_steps, -1, -1)

        # Treat each (sample, channel) pair as an independent sequence.
        seq = torch.cat([x, ch_emb], -1).permute(0, 2, 1, 3)
        seq = seq.reshape(n_batch * n_chan, n_steps, -1)
        seq = self.transformer(self.input_proj(seq) + self.pos_embed[:, :n_steps, :])

        # Last time step per channel, then residual attention across channels.
        last = seq[:, -1, :].view(n_batch, n_chan, -1)
        attended, _ = self.cross_attn(last, last, last)
        last = self.attn_norm(last + attended)

        pred = self.pred_head(last).transpose(1, 2)
        if return_features:
            return pred, last.mean(dim=1)
        return pred

def augment_batch(x, y):
    """Apply random training-time perturbations and return new ``(x, y)``.

    Assumes ``x`` is (batch, time, channel, feature) and ``y`` is
    (batch, horizon, channel); axis 2 of ``x`` sizes the per-channel noise.

    Three independent draws decide what is applied:
      * with probability 0.5, a per-channel offset (0.15 * N(0, 1)) is added
        to feature 0 of ``x`` and to ``y``;
      * with probability 0.5, a per-channel gain (1 + 0.08 * N(0, 1))
        multiplies every feature of ``x`` and ``y``;
      * with probability 0.3, element-wise noise (0.03 * N(0, 1)) is added
        to ``x`` only.

    The caller's tensors are never modified in place.
    """
    n_channels = x.shape[2]

    # Per-channel additive offset on the forecast feature.
    if torch.rand(1).item() < 0.5:
        offset = 0.15 * torch.randn(1, 1, n_channels, 1, device=x.device)
        x = x.clone()  # keep the caller's tensor untouched
        x[..., 0:1] += offset
        y = y + offset.squeeze(-1).squeeze(0)

    # Per-channel multiplicative gain on inputs and targets.
    if torch.rand(1).item() < 0.5:
        gain = 1.0 + 0.08 * torch.randn(1, 1, n_channels, 1, device=x.device)
        x = x * gain
        y = y * gain.squeeze(-1).squeeze(0)

    # Small element-wise jitter on the inputs only.
    if torch.rand(1).item() < 0.3:
        x = x + 0.03 * torch.randn_like(x)

    return x, y

# Sanity check: build each candidate architecture once (89 channels, 9 input
# features, dropout 0.2) and report how many parameters it has.
size_grid = [(256, 4, 4), (256, 6, 4), (256, 4, 8), (384, 4, 4)]
variants = {'Pre-LN': TransformerForecaster, 'Post-LN': TransformerForecasterPostLN}
for label, model_cls in variants.items():
    for width, depth, heads in size_grid:
        probe = model_cls(89, 9, width, depth, heads, 0.2)
        total = sum(weight.numel() for weight in probe.parameters())
        print(f'{label} h={width} L={depth} heads={heads}: {total:,} params ({total/1e6:.2f}M)')
    print()  # blank line between the two variants
print('Architecture defined.')

In [None]:
# ============================================================================
# Cell 3: Load Data
# ============================================================================

# np.load on an .npz archive returns a lazy NpzFile that keeps the file open.
# Read the arrays inside a context manager so the handles are released
# as soon as the data is in memory.
with np.load(f'{TRAIN_DIR}/train_data_beignet.npz') as train_npz:
    train_data = train_npz['arr_0']
with np.load(f'{TEST_DIR}/test_data_beignet_masked.npz') as test_npz:
    test_public = test_npz['arr_0']

n_features = 9
# The first 10 steps along axis 1 (all channels, first n_features features) are
# the model input; the remaining steps of feature 0 are the forecast target.
X_train = train_data[:, :10, :, :n_features].astype(np.float32)
Y_train = train_data[:, 10:, :, 0].astype(np.float32)
# Unlabelled test inputs, used only for feature alignment during training.
X_target = test_public[:, :10, :, :n_features].astype(np.float32)

# Normalisation statistics: one mean/std per (channel, feature), pooled over
# samples and time steps.
# NOTE(review): these statistics are computed on all of X_train, i.e. they
# include the samples held out for validation below.
mean = X_train.mean(axis=(0,1), keepdims=True)
std = X_train.std(axis=(0,1), keepdims=True) + 1e-8  # epsilon avoids division by zero

X_train_n = (X_train - mean) / std
# Targets are feature 0, so they use feature 0's statistics.
Y_train_n = (Y_train - mean[...,0]) / std[...,0]
X_target_n = (X_target - mean) / std

# Hold out the last n_val samples for validation (no shuffling).
n_val = 100
X_tr, X_val = X_train_n[:-n_val], X_train_n[-n_val:]
Y_tr, Y_val = Y_train_n[:-n_val], Y_train_n[-n_val:]

print(f'Train: {len(X_tr)}, Val: {len(X_val)}, Target: {len(X_target_n)}')

In [None]:
# ============================================================================
# Cell 4: Training Function (with warmup + longer training)
# ============================================================================

def train_transformer(model_cls, h, n_layers, n_heads, dropout,
                      coral_w, mean_w, seed,
                      epochs=400, patience=50, batch_size=32,
                      lr=5e-4, warmup_epochs=10):
    """Train one forecaster with feature alignment, LR warmup and early stopping.

    Reads the notebook globals ``X_tr``/``Y_tr`` (training set),
    ``X_val``/``Y_val`` (validation set), ``X_target_n`` (unlabelled
    target-domain inputs), ``std`` (normalisation scale) and ``device``.

    Args:
        model_cls: Model class; instantiated as
            ``model_cls(89, 9, h, n_layers, n_heads, dropout)``.
        h, n_layers, n_heads, dropout: Architecture hyper-parameters.
        coral_w: Weight of ``coral_loss`` between source and target features.
        mean_w: Weight of ``mean_alignment_loss`` between the same features.
        seed: Seed for torch (CPU and CUDA) and NumPy.
        epochs: Maximum number of epochs.
        patience: Stop after this many consecutive epochs without a new best
            validation score.
        batch_size: Batch size for all three data loaders.
        lr: Peak learning rate for AdamW.
        warmup_epochs: Number of epochs of linear LR warmup before cosine decay.

    Returns:
        ``(best_val, best_state)``. ``best_val`` is a Python float: the summed
        squared validation error per sample (summed over all predicted steps
        and channels, not a per-element mean) multiplied by the mean of
        ``std[..., 0] ** 2``. ``best_state`` is a CPU copy of the state dict
        at that epoch, or ``None`` if no epoch ever improved on ``inf``
        (e.g. every validation score was NaN, or ``epochs`` is 0).
    """
    torch.manual_seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    train_ds = TensorDataset(torch.FloatTensor(X_tr), torch.FloatTensor(Y_tr))
    val_ds = TensorDataset(torch.FloatTensor(X_val), torch.FloatTensor(Y_val))
    target_ds = TensorDataset(torch.FloatTensor(X_target_n))
    train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, drop_last=False)
    val_dl = DataLoader(val_ds, batch_size=batch_size)
    target_dl = DataLoader(target_ds, batch_size=batch_size, shuffle=True)

    model = model_cls(89, 9, h, n_layers, n_heads, dropout).to(device)
    n_params = sum(p.numel() for p in model.parameters())
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)

    # Warmup + cosine schedule. The decay length is clamped to at least one
    # epoch so that epochs == warmup_epochs does not divide by zero.
    decay_epochs = max(1, epochs - warmup_epochs)

    def lr_lambda(epoch):
        if epoch < warmup_epochs:
            return (epoch + 1) / warmup_epochs
        progress = (epoch - warmup_epochs) / decay_epochs
        return 0.5 * (1 + math.cos(math.pi * progress))
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    # Loop-invariant factor that rescales the normalised error to raw units.
    var_scale = (std[..., 0] ** 2).mean()

    best_val, best_state, no_improve = float('inf'), None, 0
    epochs_run = 0  # stays 0 when epochs == 0, so the summary print is always valid
    t0 = time.time()

    for epoch in range(epochs):
        epochs_run = epoch + 1
        model.train()
        target_iter = iter(target_dl)
        for xb, yb in train_dl:
            xb, yb = xb.to(device), yb.to(device)
            # Cycle through the target loader when it is exhausted.
            try:
                (xt,) = next(target_iter)
            except StopIteration:
                target_iter = iter(target_dl)
                (xt,) = next(target_iter)
            xt = xt.to(device)
            xb, yb = augment_batch(xb, yb)

            optimizer.zero_grad()
            pred, feat_src = model(xb, return_features=True)
            _, feat_tgt = model(xt, return_features=True)
            # Forecast MSE plus the two source/target feature-alignment penalties.
            loss = (
                ((pred - yb) ** 2).mean()
                + coral_w * coral_loss(feat_src, feat_tgt)
                + mean_w * mean_alignment_loss(feat_src, feat_tgt)
            )
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
        scheduler.step()

        model.eval()
        val_loss = 0
        with torch.no_grad():
            for xb, yb in val_dl:
                xb, yb = xb.to(device), yb.to(device)
                val_loss += ((model(xb) - yb)**2).sum().item()
        # Same arithmetic as before; float() only strips the NumPy scalar type so
        # the score can be stored in checkpoints as a plain Python number.
        val_mse = float((val_loss / len(X_val)) * var_scale)

        if val_mse < best_val:
            best_val = val_mse
            # Snapshot on CPU so later updates do not overwrite the best weights.
            best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
            no_improve = 0
        else:
            no_improve += 1
        if no_improve >= patience:
            break

    elapsed = time.time() - t0
    print(f'  seed={seed} h={h} L={n_layers} heads={n_heads} do={dropout} '
          f'CORAL={coral_w} -> Val MSE: {best_val:.0f} ({n_params:,} params, ep {epochs_run}, {elapsed:.0f}s)')
    return best_val, best_state

# Progress marker printed once this cell has run and train_transformer is defined.
print('Training function defined (warmup + cosine + augmentation)')

In [None]:
# ============================================================================
# Cell 5: Config Sweep
# ============================================================================

# Candidate configurations, all trained with seed 42.
# Tuple layout: (cls, h, L, heads, dropout, coral_w, mean_w, lr, name)
configs = [
    # Baseline (v206 config, but with longer training + warmup + augmentation)
    (TransformerForecasterPostLN, 256, 4, 4, 0.20, 3.0, 1.0, 5e-4, 'postLN_h256_L4_do20'),
    # Pre-LN variants
    (TransformerForecaster, 256, 4, 4, 0.20, 3.0, 1.0, 5e-4, 'preLN_h256_L4_do20'),
    (TransformerForecaster, 256, 6, 4, 0.20, 3.0, 1.0, 5e-4, 'preLN_h256_L6_do20'),
    (TransformerForecaster, 256, 4, 8, 0.20, 3.0, 1.0, 5e-4, 'preLN_h256_L4_h8_do20'),
    # More dropout
    (TransformerForecaster, 256, 4, 4, 0.30, 3.0, 1.0, 5e-4, 'preLN_h256_L4_do30'),
    # Lower LR
    (TransformerForecaster, 256, 4, 4, 0.20, 3.0, 1.0, 3e-4, 'preLN_h256_L4_lr3e4'),
    # Stronger CORAL
    (TransformerForecaster, 256, 4, 4, 0.20, 5.0, 2.0, 5e-4, 'preLN_h256_L4_c5'),
]

# Each results row is (name, cls, h, L, heads, dropout, coral_w, mean_w, lr,
# val_mse, state); index 9 is therefore the validation score.
results = []
print(f'=== Transformer Config Sweep ({len(configs)} configs) ===')
for cfg in configs:
    cls, h, nl, nh, do, cw, mw, lr, name = cfg
    print(f'\n--- {name} ---')
    val_mse, state = train_transformer(
        cls, h, nl, nh, do, cw, mw,
        seed=42, epochs=400, patience=50, lr=lr,
    )
    results.append((name, *cfg[:-1], val_mse, state))

# Rank by validation score, lowest (best) first.
results = sorted(results, key=lambda row: row[9])
print('\n=== Sweep Results (sorted) ===')
for rank, row in enumerate(results):
    marker = '' if rank else ' *** BEST'
    print(f'  {row[0]}: Val={row[9]:,.0f}{marker}')

BEST = results[0]
print(f'\nBest: {BEST[0]} (Val={BEST[9]:,.0f})')
print(f'\nFor reference, v206 TF had val ~65M (shorter training, no aug, no warmup)')

In [None]:
# ============================================================================
# Cell 6: Multi-seed Training with Best Config
# ============================================================================

# Unpack the winning sweep row; positions 1..8 hold
# (cls, h, L, heads, dropout, coral_w, mean_w, lr).
(BEST_CLS, BEST_H, BEST_NL, BEST_NH,
 BEST_DO, BEST_CW, BEST_MW, BEST_LR) = BEST[1:9]

SEEDS = [42, 123, 456, 789, 2024]
print(f'=== Training {len(SEEDS)} seeds with best config: {BEST[0]} ===')

# One (seed, val_mse, state) triple per seed, trained in the order listed.
seed_results = [
    (s, *train_transformer(BEST_CLS, BEST_H, BEST_NL, BEST_NH, BEST_DO,
                           BEST_CW, BEST_MW, seed=s,
                           epochs=400, patience=50, lr=BEST_LR))
    for s in SEEDS
]

print(f'\n=== Multi-seed Results ===')
for entry in seed_results:
    print(f'  seed={entry[0]}: Val MSE = {entry[1]:,.0f}')
vals = [score for _, score, _ in seed_results]
print(f'Mean: {np.mean(vals):,.0f}, Std: {np.std(vals):,.0f}')

In [None]:
# ============================================================================
# Cell 7: Save
# ============================================================================

out_dir = f'{PROJECT_ROOT}/4_models/v213_stronger_transformer'
os.makedirs(out_dir, exist_ok=True)

# Check if best config uses Pre-LN or Post-LN (class identity, so use `is`).
is_preln = BEST_CLS is TransformerForecaster
# Hyper-parameters stored next to the weights; epochs/patience/warmup repeat
# the values passed to train_transformer in the cells above.
config = {
    'h': BEST_H, 'n_layers': BEST_NL, 'n_heads': BEST_NH,
    'dropout': BEST_DO, 'coral_weight': BEST_CW, 'mean_weight': BEST_MW,
    'lr': BEST_LR, 'n_features': 9, 'pre_ln': is_preln,
    'epochs': 400, 'patience': 50, 'warmup': 10,
}
print(f'Config: {config}')

for s, val_mse, state in seed_results:
    path = f'{out_dir}/model_transformer_seed{s}.pth'
    # val_mse can be a NumPy scalar (it is scaled by the NumPy `std` array).
    # Store a plain float so the checkpoint holds only tensors and built-in
    # types and can be read back with torch.load(weights_only=True).
    torch.save({'model_state_dict': state, 'val_mse': float(val_mse), 'config': config}, path)
    fsize = os.path.getsize(path) / (1024*1024)
    print(f'Saved seed {s}: Val={val_mse:,.0f}, size={fsize:.1f}MB')

# Save sweep results
with open(f'{out_dir}/sweep_results.txt', 'w', encoding='utf-8') as f:
    for name, *_, val_mse, _ in results:
        f.write(f'{name}: Val={val_mse:.0f}\n')
    f.write(f'\nBest: {BEST[0]}\n')
    f.write(f'Pre-LN: {is_preln}\n')
    f.write(f'\nMulti-seed:\n')
    for s, val_mse, _ in seed_results:
        f.write(f'  seed={s}: Val={val_mse:.0f}\n')

print(f'\nSaved to {out_dir}')
print(f'\nIMPORTANT: If Pre-LN wins, model.py TransformerForecaster needs norm_first=True')
print(f'If Post-LN wins, model.py stays the same (just swap .pth files)')