# v223: Affi Pre-LN Transformer with CORAL

- Architecture: h=384, L=4, heads=8, Pre-LN (norm_first=True)
- Training: CORAL + Mean alignment, mixed precision
- Goal: add Transformer (TF) diversity to the Affi ensemble, which is currently TCN-only
- Precedent: adding a TF model to the Beignet ensemble improved its score by 1,411 (−1,411)
- Train both CORAL and non-CORAL versions (for 4-model blend)

In [None]:
# Make the project's Google Drive folder visible to this Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

# Project locations on the mounted Drive (plain strings; joined with f-strings below).
PROJECT_ROOT = '/content/drive/MyDrive/Hackathon_NSF_Neural_Forecasting'
TRAIN_DIR = f'{PROJECT_ROOT}/1_data/raw/train_data_neuro'  # holds train_data_affi.npz
TEST_DIR = f'{PROJECT_ROOT}/1_data/raw/test_dev_input'  # holds test_data_affi_masked.npz

import os, time, torch, numpy as np, math
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from torch.cuda.amp import autocast, GradScaler

# Prefer the GPU when one is visible; all tensors/models are moved to `device`.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Device: {device}')
if device.type == 'cuda':
    print(f'GPU: {torch.cuda.get_device_name(0)}')
    total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f'Memory: {total_gb:.1f} GB')

In [None]:
# ============================================================================
# Cell 2: Losses + Architecture + Augmentation
# ============================================================================

def coral_loss(source, target):
    """CORAL loss: squared Frobenius distance between two feature covariances.

    Args:
        source: 2-D feature matrix (n_s, d); rows are samples.
        target: 2-D feature matrix (n_t, d); n_t may differ from n_s.

    Returns:
        Scalar tensor ``||C_s - C_t||_F^2 / (4 * d)``. Each covariance uses the
        unbiased ``n - 1`` normalizer plus ``1e-8``.

    NOTE: the Deep CORAL paper scales by ``1 / (4 d^2)``. This implementation
    uses ``1 / (4 d)``, and the ``coral_w`` weights used elsewhere are tuned
    against that scale.
    """
    def _covariance(feats):
        # Center once and reuse the centered matrix for both matmul operands.
        centered = feats - feats.mean(0, keepdim=True)
        return centered.T @ centered / (feats.size(0) - 1 + 1e-8)

    dim = source.size(1)
    gap = _covariance(source) - _covariance(target)
    return (gap ** 2).sum() / (4 * dim)

def mean_alignment_loss(source, target):
    """Mean squared gap between the per-feature means (over dim 0) of two batches."""
    mean_gap = source.mean(dim=0) - target.mean(dim=0)
    return mean_gap.pow(2).mean()

class TransformerForecasterLarge(nn.Module):
    """Pre-LN Transformer forecaster (used here for Affi, 239 channels).

    Each channel's history is encoded independently by a shared Pre-LN
    Transformer encoder. The channel's identity is injected via a learned
    channel embedding concatenated to its input features. The last-timestep
    embeddings of all channels then attend to each other through cross-channel
    self-attention. Finally, a shared MLP head predicts ``horizon`` future
    values per channel.

    Args:
        n_ch: Size of the channel-embedding table (max number of channels).
        n_feat: Input features per (timestep, channel).
        h: Model width; ``h // 4`` is used for the channel embedding.
        n_layers: Number of Transformer encoder layers.
        n_heads: Attention heads (must divide ``h``).
        dropout: Dropout for the encoder, cross-attention and prediction head.
        max_len: Length of the learned positional embedding, i.e. the maximum
            number of input timesteps (default 10).
        horizon: Number of future timesteps predicted (default 10).
    """
    def __init__(self, n_ch, n_feat=1, h=384, n_layers=4, n_heads=8, dropout=0.2,
                 max_len=10, horizon=10):
        super().__init__()
        # Parameter creation order is unchanged so seeded initialization matches.
        self.channel_embed = nn.Embedding(n_ch, h // 4)
        self.input_proj = nn.Linear(n_feat + h // 4, h)
        self.pos_embed = nn.Parameter(torch.randn(1, max_len, h) * 0.02)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=h, nhead=n_heads, dim_feedforward=h * 4,
            dropout=dropout, batch_first=True, activation='gelu',
            norm_first=True  # Pre-LN
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        self.cross_attn = nn.MultiheadAttention(h, n_heads, dropout=dropout, batch_first=True)
        self.attn_norm = nn.LayerNorm(h)
        self.pred_head = nn.Sequential(
            nn.Linear(h, h), nn.GELU(), nn.Dropout(dropout), nn.Linear(h, horizon)
        )

    def forward(self, x, return_features=False):
        """Forecast future steps for every channel.

        Args:
            x: Tensor indexed as (B, T, C, n_feat). T must not exceed
                ``max_len``, and C must not exceed ``n_ch`` (channel ids are 0..C-1).
            return_features: Also return the channel-averaged embedding
                (used by the training code for domain alignment).

        Returns:
            ``pred`` of shape (B, horizon, C), or ``(pred, feats)`` with
            ``feats`` of shape (B, h) when ``return_features`` is True.

        Raises:
            ValueError: if T exceeds the positional-embedding length.
        """
        # Descriptive names (no `F`) so torch.nn.functional's alias is not shadowed.
        batch, n_steps, n_chan, _ = x.shape
        max_len = self.pos_embed.size(1)
        if n_steps > max_len:
            raise ValueError(
                f'input has {n_steps} timesteps but pos_embed supports at most {max_len}')
        # Learned per-channel identity, broadcast to every (sample, timestep).
        ch_ids = torch.arange(n_chan, device=x.device)
        ch_emb = self.channel_embed(ch_ids).unsqueeze(0).unsqueeze(0).expand(batch, n_steps, -1, -1)
        # Fold channels into the batch so each channel's series is encoded on its own.
        tokens = torch.cat([x, ch_emb], -1).permute(0, 2, 1, 3).reshape(batch * n_chan, n_steps, -1)
        tokens = self.input_proj(tokens) + self.pos_embed[:, :n_steps, :]
        tokens = self.transformer(tokens)
        # Summarize each channel by its last-timestep embedding -> (B, C, h).
        summary = tokens[:, -1, :].view(batch, n_chan, -1)
        # Channels attend to each other; residual followed by LayerNorm.
        summary = self.attn_norm(summary + self.cross_attn(summary, summary, summary)[0])
        pred = self.pred_head(summary).transpose(1, 2)
        if return_features:
            return pred, summary.mean(dim=1)
        return pred

def augment_batch(x, y):
    """Randomly augment a training batch, keeping inputs and targets consistent.

    Three independent augmentations are applied, each gated by its own coin flip:
      1. p=0.5: add a per-channel offset (0.15 * N(0, 1)) to feature 0 of ``x``
         and the same offset to ``y``.
      2. p=0.5: multiply every feature of ``x`` by a per-channel gain
         (1 + 0.08 * N(0, 1)) and ``y`` by the same gain.
      3. p=0.3: add i.i.d. Gaussian noise (std 0.03) to ``x`` only.

    Args:
        x: Input batch, indexed as (B, T, C, F) here; axis 2 is the channel axis.
        y: Target batch; assumed (B, T_out, C) so that a (1, C) per-channel
           term broadcasts over it — TODO confirm for other callers.

    Returns:
        The augmented ``(x, y)``. The caller's tensors are never modified in place.
    """
    n_ch = x.shape[2]
    per_channel = (1, 1, n_ch, 1)

    # 1) Per-channel offset on feature 0, mirrored onto the targets.
    if torch.rand(1).item() < 0.5:
        offset = 0.15 * torch.randn(per_channel, device=x.device)
        x = x.clone()  # the next line writes in place; protect the caller's tensor
        x[..., 0:1] += offset
        y = y + offset.reshape(1, n_ch)

    # 2) Per-channel gain on all input features, mirrored onto the targets.
    if torch.rand(1).item() < 0.5:
        gain = 1.0 + 0.08 * torch.randn(per_channel, device=x.device)
        x = x * gain
        y = y * gain.reshape(1, n_ch)

    # 3) Small jitter on the inputs only.
    if torch.rand(1).item() < 0.3:
        x = x + 0.03 * torch.randn_like(x)

    return x, y

# Build a throwaway instance with the Affi configuration (239 channels,
# 3 input features) purely to report the model size, then free it.
_probe = TransformerForecasterLarge(239, 3, 384, 4, 8, 0.2)
n_params = sum(param.numel() for param in _probe.parameters())
del _probe
print(f'TransformerForecasterLarge(239ch, 3feat, h=384, L=4, heads=8): {n_params:,} params ({n_params/1e6:.2f}M)')
print('Architecture defined.')

In [None]:
# ============================================================================
# Cell 3: Load Data
# ============================================================================

# Load Affi data. Only the 3-feature variant is built in this notebook;
# the forecasting target is always feature 0.
# np.load on an .npz returns an NpzFile that keeps the file handle open until it
# is garbage-collected. Reading inside `with` closes the handle immediately.
with np.load(f'{TRAIN_DIR}/train_data_affi.npz') as npz:
    train_data = npz['arr_0']
with np.load(f'{TEST_DIR}/test_data_affi_masked.npz') as npz:
    test_public = npz['arr_0']

# 3-feat normalization.
# Arrays are indexed as (sample, timestep, channel, feature). The first 10
# timesteps are the model input; the remaining timesteps of feature 0 are the target.
N_FEAT_3 = 3
X_train_3 = train_data[:, :10, :, :N_FEAT_3].astype(np.float32)
Y_train = train_data[:, 10:, :, 0].astype(np.float32)
X_target_3 = test_public[:, :10, :, :N_FEAT_3].astype(np.float32)

# Per-(channel, feature) statistics over samples and input timesteps.
# NOTE(review): these include the last n_val samples that are held out below
# for validation, so the validation split is not fully independent of them.
mean_3 = X_train_3.mean(axis=(0, 1), keepdims=True)
std_3 = X_train_3.std(axis=(0, 1), keepdims=True) + 1e-8

X_train_3n = (X_train_3 - mean_3) / std_3
# Targets are feature 0, so they are normalized with feature 0's statistics.
Y_train_n = (Y_train - mean_3[..., 0]) / std_3[..., 0]
X_target_3n = (X_target_3 - mean_3) / std_3

# Hold out the last n_val samples (no shuffling) for validation.
n_val = 100
X_tr_3, X_val_3 = X_train_3n[:-n_val], X_train_3n[-n_val:]
Y_tr, Y_val = Y_train_n[:-n_val], Y_train_n[-n_val:]

print(f'Train: {len(X_tr_3)}, Val: {len(X_val_3)}, Target: {len(X_target_3n)}')
print(f'Channels: {X_train_3.shape[2]}, Features: {N_FEAT_3}')

In [None]:
# ============================================================================
# Cell 4: Training Function
# ============================================================================

def _warmup_cosine_lambda(warmup_epochs, epochs):
    """Build the per-epoch LR multiplier: linear warmup, then half-cosine decay.

    The multiplier rises linearly to 1.0 over ``warmup_epochs`` epochs and then
    decays along a half cosine that reaches 0 at ``epochs``. The decay length is
    clamped to at least one epoch, so ``epochs == warmup_epochs`` no longer
    raises ZeroDivisionError when the scheduler steps past the warmup boundary.
    """
    decay_epochs = max(1, epochs - warmup_epochs)

    def lr_lambda(epoch):
        if epoch < warmup_epochs:
            return (epoch + 1) / warmup_epochs
        progress = (epoch - warmup_epochs) / decay_epochs
        return 0.5 * (1 + math.cos(math.pi * progress))

    return lr_lambda


def _affi_val_score(model, val_dl, n_samples, std):
    """Return the validation score used for early stopping, printing and saving.

    NOTE(review): this is not a per-element MSE.
      - Squared errors (in normalized units) are summed over every element of
        every validation batch.
      - The sum is divided by the number of validation samples.
      - The result is multiplied by the channel-averaged variance of feature 0
        (``std[..., 0] ** 2``).
    It is therefore roughly a per-sample sum of squared errors over all horizon
    steps and channels. The de-normalization is exact only if every channel
    shares one std. It is left as-is deliberately: redefining it would change
    every printed and saved ``val_mse``.
    """
    model.eval()
    sq_err_sum = 0
    with torch.no_grad():
        for xb, yb in val_dl:
            xb, yb = xb.to(device), yb.to(device)
            with autocast():
                sq_err_sum += ((model(xb) - yb) ** 2).sum().item()
    return (sq_err_sum / n_samples) * (std[..., 0] ** 2).mean()


def train_affi_transformer(n_feat, X_tr, X_val, Y_tr, Y_val, X_target,
                           mean, std, h=384, n_layers=4, n_heads=8, dropout=0.2,
                           coral_w=3.0, mean_w=1.0, seed=42,
                           epochs=200, patience=30, batch_size=16,
                           lr=5e-4, warmup_epochs=10, n_ch=239):
    """Train one TransformerForecasterLarge on Affi with optional CORAL alignment.

    Each training step pairs an augmented, labelled source batch with an
    unlabelled batch from ``X_target``; the target loader is restarted when it
    runs out. The loss is:
      - the MSE on the source batch, plus
      - when ``coral_w > 0``: ``coral_w * CORAL + mean_w * mean-gap`` between
        the channel-averaged features of the two batches.
    Training details:
      - AMP with a GradScaler, and gradient clipping at norm 1.0.
      - AdamW (weight decay 1e-4) with a warmup + cosine LR schedule.
      - Early stopping on the validation score.

    Args:
        n_feat: Input features per (timestep, channel).
        X_tr, Y_tr: Normalized training inputs / targets.
        X_val, Y_val: Normalized validation inputs / targets.
        X_target: Normalized unlabelled target-domain inputs, used only for
            feature alignment.
        mean: Not used by this function (inputs arrive already normalized);
            kept for signature compatibility.
        std: Normalization std; ``std[..., 0]`` scales the validation score.
        h, n_layers, n_heads, dropout: Model hyper-parameters.
        coral_w, mean_w: Weights of the CORAL and mean-alignment terms.
            ``coral_w <= 0`` disables both.
        seed: Seed for torch (CPU and CUDA) and numpy.
        epochs, patience: Max epochs and early-stopping patience (in epochs).
        batch_size, lr, warmup_epochs: Optimization settings.
        n_ch: Size of the model's channel-embedding table (239 for Affi).

    Returns:
        ``(best_val, best_state)``: the best validation score and a CPU copy of
        the matching state dict. ``best_state`` is ``None`` if the score never
        improved (e.g. it was NaN); a warning is printed in that case.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    train_ds = TensorDataset(torch.FloatTensor(X_tr), torch.FloatTensor(Y_tr))
    val_ds = TensorDataset(torch.FloatTensor(X_val), torch.FloatTensor(Y_val))
    target_ds = TensorDataset(torch.FloatTensor(X_target))
    train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, drop_last=True)
    val_dl = DataLoader(val_ds, batch_size=batch_size)
    target_dl = DataLoader(target_ds, batch_size=batch_size, shuffle=True)

    model = TransformerForecasterLarge(n_ch, n_feat, h, n_layers, n_heads, dropout).to(device)
    n_params = sum(p.numel() for p in model.parameters())
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    scaler = GradScaler()
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, _warmup_cosine_lambda(warmup_epochs, epochs))

    best_val, best_state, no_improve = float('inf'), None, 0
    epochs_run = 0  # stays defined even when epochs == 0
    t0 = time.time()

    for epoch in range(epochs):
        epochs_run = epoch + 1
        model.train()
        target_iter = iter(target_dl)
        for xb, yb in train_dl:
            xb, yb = xb.to(device), yb.to(device)
            # The target loader can be shorter than the train loader: restart it
            # when exhausted so every source batch gets a target batch.
            try:
                (xt,) = next(target_iter)
            except StopIteration:
                target_iter = iter(target_dl)
                (xt,) = next(target_iter)
            xt = xt.to(device)
            xb, yb = augment_batch(xb, yb)

            optimizer.zero_grad()
            with autocast():
                pred, feat_src = model(xb, return_features=True)
                _, feat_tgt = model(xt, return_features=True)
                mse_loss = ((pred - yb) ** 2).mean()
                if coral_w > 0:
                    # Inside autocast, the covariance matmuls in coral_loss would
                    # still run in fp16 despite .float(); disable autocast here
                    # so the alignment statistics really are computed in fp32.
                    with autocast(enabled=False):
                        src32, tgt32 = feat_src.float(), feat_tgt.float()
                        c_loss = coral_loss(src32, tgt32)
                        m_loss = mean_alignment_loss(src32, tgt32)
                    loss = mse_loss + coral_w * c_loss + mean_w * m_loss
                else:
                    loss = mse_loss
            scaler.scale(loss).backward()
            # Unscale first so the clipping threshold applies to the true gradients.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
        scheduler.step()

        val_mse = _affi_val_score(model, val_dl, len(X_val), std)
        if val_mse < best_val:
            best_val = val_mse
            best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
            no_improve = 0
        else:
            no_improve += 1
        if no_improve >= patience:
            break

    elapsed = time.time() - t0
    tag = 'CORAL' if coral_w > 0 else 'orig'
    print(f'  [{tag}] seed={seed} feat={n_feat} h={h} L={n_layers} heads={n_heads} '
          f'-> Val MSE: {best_val:,.0f} ({n_params:,} params, ep {epochs_run}, {elapsed:.0f}s)')
    if best_state is None:
        print(f'  [{tag}] seed={seed}: WARNING - validation score never improved '
              '(NaN?); best_state is None')
    del model
    torch.cuda.empty_cache()
    return best_val, best_state

print('Training function defined')

In [None]:
# ============================================================================
# Cell 5: Train 3-feat CORAL + orig (5 seeds each)
# ============================================================================

SEEDS = [42, 123, 456, 789, 2024]


def _train_3feat_seed_sweep(coral_w, mean_w):
    """Train one 3-feature Affi transformer per seed in SEEDS.

    Returns a list of ``(seed, best_val_mse, best_state_dict)`` tuples in SEEDS order.
    """
    runs = []
    for seed in SEEDS:
        best_val, best_state = train_affi_transformer(
            3, X_tr_3, X_val_3, Y_tr, Y_val, X_target_3n,
            mean_3, std_3, h=384, n_layers=4, n_heads=8, dropout=0.2,
            coral_w=coral_w, mean_w=mean_w, seed=seed, epochs=200, patience=30)
        runs.append((seed, best_val, best_state))
    return runs


print('=== Training 3-feat CORAL (5 seeds) ===')
coral_3f_results = _train_3feat_seed_sweep(coral_w=3.0, mean_w=1.0)

print('\n=== Training 3-feat orig/no-CORAL (5 seeds) ===')
orig_3f_results = _train_3feat_seed_sweep(coral_w=0.0, mean_w=0.0)

print('\n=== 3-feat Results ===')
for seed, val, _ in coral_3f_results:
    print(f'  CORAL seed={seed}: Val={val:,.0f}')
for seed, val, _ in orig_3f_results:
    print(f'  orig  seed={seed}: Val={val:,.0f}')

In [None]:
# ============================================================================
# Cell 6: Save models
# ============================================================================

out_dir = f'{PROJECT_ROOT}/4_models/v223_affi_preln_transformer'
os.makedirs(out_dir, exist_ok=True)

# Settings stored inside every checkpoint; they mirror the hyper-parameters
# used for training in Cell 5.
config = {
    'n_ch': 239, 'n_feat': 3, 'h': 384, 'n_layers': 4, 'n_heads': 8,
    'dropout': 0.2, 'pre_ln': True, 'coral_weight': 3.0, 'mean_weight': 1.0,
}
# The no-CORAL runs share everything except the alignment weights.
config_orig = {**config, 'coral_weight': 0.0, 'mean_weight': 0.0}

BYTES_PER_MB = 1024 * 1024

for s, val_mse, state in coral_3f_results:
    ckpt_path = f'{out_dir}/model_affi_tf_3feat_coral_seed{s}.pth'
    torch.save({'model_state_dict': state, 'val_mse': val_mse, 'config': config}, ckpt_path)
    size_mb = os.path.getsize(ckpt_path) / BYTES_PER_MB
    print(f'CORAL seed {s}: Val={val_mse:,.0f}, size={size_mb:.1f}MB')

for s, val_mse, state in orig_3f_results:
    ckpt_path = f'{out_dir}/model_affi_tf_3feat_orig_seed{s}.pth'
    torch.save({'model_state_dict': state, 'val_mse': val_mse, 'config': config_orig}, ckpt_path)
    size_mb = os.path.getsize(ckpt_path) / BYTES_PER_MB
    print(f'orig  seed {s}: Val={val_mse:,.0f}, size={size_mb:.1f}MB')

# Normalization statistics needed to reproduce the model inputs at inference.
np.savez(f'{out_dir}/normalization_affi_3feat.npz', mean=mean_3, std=std_3)

# Human-readable results summary, assembled first and written in one go.
summary_lines = ['v223 Affi Pre-LN Transformer Results\n', f'Config: {config}\n\n', '3-feat CORAL:\n']
summary_lines += [f'  seed={s}: Val={v:.0f}\n' for s, v, _ in coral_3f_results]
summary_lines.append('\n3-feat orig:\n')
summary_lines += [f'  seed={s}: Val={v:.0f}\n' for s, v, _ in orig_3f_results]
with open(f'{out_dir}/results.txt', 'w') as fh:
    fh.writelines(summary_lines)

print(f'\nSaved to {out_dir}')

In [None]:
# ============================================================================
# Cell 7: Summarize best models (checkpoints were already saved in Cell 6)
# ============================================================================

# Rank each variant's runs by validation score (ascending = best first); show the top 3.
coral_sorted = sorted(coral_3f_results, key=lambda run: run[1])
orig_sorted = sorted(orig_3f_results, key=lambda run: run[1])

for header, ranked in (('=== Best 3 CORAL ===', coral_sorted),
                       ('\n=== Best 3 orig ===', orig_sorted)):
    print(header)
    for s, v, _ in ranked[:3]:
        print(f'  seed={s}: Val={v:,.0f}')

print(f'\nAll models saved to: {out_dir}')
# Checklist for wiring these checkpoints into model.py.
for hint in ('\nFor model.py integration:',
             '  - Add TransformerForecasterLarge class (Pre-LN version)',
             '  - Load TF models alongside existing TCN models',
             '  - Ensemble: alpha * TCN_4model_blend + (1-alpha) * TF_ensemble',
             '  - Optimize alpha on val data'):
    print(hint)