# v200: Beignet Larger TCN + Multi-Seed + Augmentation

**Strategy:**
1. Scale up Beignet TCN: h=64 → h=128, 3 layers → 4 layers (~104K → ~505K params)
2. Training-time augmentation: channel shift, amplitude scale, noise (zero inference cost)
3. Multi-seed ensemble: train 5 seeds, average predictions at inference

**Baseline:** v193d/v199b Beignet Public = 55,367 (CORAL λ=5.0, h=64, 3 layers)
**Target:** < 52,000 Beignet Public → Total MSE < 40,000

In [None]:
# Mount Google Drive so the project's data and model folders are reachable from Colab.
from google.colab import drive
drive.mount('/content/drive')

# Project layout on Drive: raw training windows and the masked dev-test inputs.
PROJECT_ROOT = '/content/drive/MyDrive/Hackathon_NSF_Neural_Forecasting'
_RAW_DIR = f'{PROJECT_ROOT}/1_data/raw'
TRAIN_DIR = f'{_RAW_DIR}/train_data_neuro'
TEST_DIR = f'{_RAW_DIR}/test_dev_input'

import os, time, torch, numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

# Use the GPU when one is visible; later cells move models and tensors to `device`.
_has_gpu = torch.cuda.is_available()
device = torch.device('cuda' if _has_gpu else 'cpu')
print(f'Device: {device}')
if _has_gpu:
    print(f'GPU: {torch.cuda.get_device_name(0)}')

In [None]:
# ============================================================================
# CORAL Losses
# ============================================================================

def coral_loss(source, target):
    """CORAL loss between two (n_samples, d) feature batches.

    Returns the squared Frobenius distance between the two unbiased feature
    covariance matrices, divided by 4 * d.

    NOTE(review): the Deep CORAL paper normalizes by 4 * d**2. Dividing by
    4 * d makes this term grow linearly with d relative to the paper. The
    CORAL weights swept in this notebook are relative to this scaling, so it
    is kept as-is.
    """
    def _cov(feats):
        centered = feats - feats.mean(0, keepdim=True)
        # The 1e-8 keeps a single-row batch from dividing by zero (its covariance is 0).
        return centered.T @ centered / (feats.size(0) - 1 + 1e-8)

    dim = source.size(1)
    gap = _cov(source) - _cov(target)
    return (gap ** 2).sum() / (4 * dim)

def mean_alignment_loss(source, target):
    """Mean squared gap between the per-dimension feature means of two batches."""
    mean_gap = source.mean(0) - target.mean(0)
    return mean_gap.pow(2).mean()

# ============================================================================
# TCN Architecture (same as v193d, parameterized h and n_layers)
# ============================================================================

class CausalConv1d(nn.Module):
    """1-D convolution whose output at step t depends only on inputs at steps <= t.

    ``nn.Conv1d`` pads (kernel_size - 1) * dilation zeros on both sides. Trimming
    that many steps from the right leaves an output as long as the input, with
    no look-ahead. Expects (batch, channels, time) tensors.
    """

    def __init__(self, in_channels, out_channels, kernel_size, dilation=1):
        super().__init__()
        pad = dilation * (kernel_size - 1)
        self.padding = pad
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size,
                              padding=pad, dilation=dilation)

    def forward(self, x):
        y = self.conv(x)
        if self.padding == 0:
            return y
        keep = y.shape[-1] - self.padding
        return y[:, :, :keep]

class TCNBlock(nn.Module):
    """Residual block made of two dilated causal convolutions.

    Each convolution is followed by BatchNorm, GELU and dropout; the same
    ``nn.Dropout`` module is reused for both. The skip path is a 1x1
    convolution when the channel count changes and the identity otherwise.
    Operates on (batch, channels, time) tensors.
    """

    def __init__(self, in_ch, out_ch, kernel_size, dilation, dropout=0.2):
        super().__init__()
        # Keep this construction order: seeded weight initialization depends on it.
        self.conv1 = CausalConv1d(in_ch, out_ch, kernel_size, dilation)
        self.conv2 = CausalConv1d(out_ch, out_ch, kernel_size, dilation)
        self.norm1 = nn.BatchNorm1d(out_ch)
        self.norm2 = nn.BatchNorm1d(out_ch)
        self.activation = nn.GELU()
        self.dropout = nn.Dropout(dropout)
        if in_ch == out_ch:
            self.residual = nn.Identity()
        else:
            self.residual = nn.Conv1d(in_ch, out_ch, 1)

    def forward(self, x):
        skip = self.residual(x)
        out = x
        for conv, norm in ((self.conv1, self.norm1), (self.conv2, self.norm2)):
            out = self.dropout(self.activation(norm(conv(out))))
        return out + skip

class TCNEncoder(nn.Module):
    """Stack of TCN blocks with exponentially growing dilation (1, 2, 4, ...).

    Maps (N, T, in_size) sequences to (N, T, h_size). Each block contains two
    causal convolutions, so with kernel size k and n_layers blocks the
    receptive field is 1 + 2 * (k - 1) * (2**n_layers - 1) steps.
    """

    def __init__(self, in_size, h_size, n_layers=4, k_size=3, dropout=0.2):
        super().__init__()
        self.input_proj = nn.Conv1d(in_size, h_size, 1)
        blocks = []
        for depth in range(n_layers):
            blocks.append(TCNBlock(h_size, h_size, k_size, 2 ** depth, dropout))
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        # Conv1d wants (N, channels, time), so swap the last two axes in and out.
        h = self.input_proj(x.transpose(1, 2))
        for block in self.layers:
            h = block(h)
        return h.transpose(1, 2)

class TCNForecaster(nn.Module):
    """Multi-channel forecaster: a shared per-channel TCN plus cross-channel attention.

    One shared TCN encodes every channel's history independently, with a
    learned channel embedding appended to its input features. The
    final-timestep encodings of all channels are mixed with multi-head
    self-attention and mapped to a ``horizon``-step forecast per channel.

    Args:
        n_ch: size of the channel-embedding table; inputs may use at most this
            many channels.
        n_feat: input features per channel and timestep.
        h: hidden size; must be divisible by ``n_heads``.
        n_layers: number of TCN blocks.
        dropout: dropout used in the TCN, the attention and the prediction head.
        n_heads: attention heads (default 4, the previously hard-coded value).
        horizon: forecast length (default 10, the previously hard-coded value).
    """

    def __init__(self, n_ch, n_feat=1, h=64, n_layers=3, dropout=0.3, n_heads=4, horizon=10):
        super().__init__()
        # Submodules are created in the original order, so seeded initialization
        # and state_dict keys match checkpoints trained with the defaults.
        self.channel_embed = nn.Embedding(n_ch, h // 4)
        self.input_proj = nn.Linear(n_feat + h // 4, h)
        self.tcn = TCNEncoder(h, h, n_layers, 3, dropout)
        self.cross_attn = nn.MultiheadAttention(h, n_heads, dropout=dropout, batch_first=True)
        self.attn_norm = nn.LayerNorm(h)
        self.pred_head = nn.Sequential(nn.Linear(h, h), nn.GELU(), nn.Dropout(dropout), nn.Linear(h, horizon))

    def forward(self, x, return_features=False):
        """Forecast from ``x`` of shape (batch, time, channels, n_feat).

        Returns predictions of shape (batch, horizon, channels). When
        ``return_features`` is True, it also returns the channel-averaged
        post-attention encoding, shape (batch, h), used for domain alignment.
        """
        if x.dim() != 4:
            raise ValueError(
                f'expected (batch, time, channels, features) input, got shape {tuple(x.shape)}')
        # These names replace B,T,C,F: a local `F` would shadow torch.nn.functional.
        n_batch, n_steps, n_chan, _ = x.shape
        ch_emb = self.channel_embed(torch.arange(n_chan, device=x.device))
        ch_emb = ch_emb.unsqueeze(0).unsqueeze(0).expand(n_batch, n_steps, -1, -1)
        z = torch.cat([x, ch_emb], -1)
        # One independent sequence per (sample, channel) pair.
        z = z.permute(0, 2, 1, 3).reshape(n_batch * n_chan, n_steps, -1)
        z = self.tcn(self.input_proj(z))
        # Keep only the last timestep of each channel's encoding.
        z = z[:, -1, :].view(n_batch, n_chan, -1)
        # Residual self-attention across channels.
        z = self.attn_norm(z + self.cross_attn(z, z, z)[0])
        pred = self.pred_head(z).transpose(1, 2)
        if return_features:
            return pred, z.mean(dim=1)
        return pred

# Parameter budget of each candidate (hidden size, depth) for 89 channels x 9 features.
for hidden, depth in ((64, 3), (128, 4), (192, 4)):
    candidate = TCNForecaster(89, 9, hidden, depth, 0.25)
    n_params = sum(p.numel() for p in candidate.parameters())
    print(f'h={hidden}, layers={depth}: {n_params:,} params ({n_params/1e6:.2f}M)')

In [None]:
# ============================================================================
# Load Data
# ============================================================================

# Load the Beignet training windows and the masked public test inputs.
# An NpzFile keeps its file handle open until closed, so read each array inside `with`.
with np.load(f'{TRAIN_DIR}/train_data_beignet.npz') as npz:
    train_data = npz['arr_0']
with np.load(f'{TEST_DIR}/test_data_beignet_masked.npz') as npz:
    test_public = npz['arr_0']

# Assumed layout (confirm against the data spec): (samples, time, channels, features).
# The first 10 steps are the model input; feature 0 of the remaining steps is the target.
n_features = 9
X_train = train_data[:, :10, :, :n_features].astype(np.float32)
Y_train = train_data[:, 10:, :, 0].astype(np.float32)
X_target = test_public[:, :10, :, :n_features].astype(np.float32)

# Per-channel, per-feature statistics over samples and time; shape (1, 1, C, F).
# NOTE(review): computed before the train/val split below, so the validation
# windows also contribute to these statistics.
mean = X_train.mean(axis=(0,1), keepdims=True)
std = X_train.std(axis=(0,1), keepdims=True) + 1e-8

X_train_n = (X_train - mean) / std
# Targets are feature 0, so they reuse feature 0's statistics.
Y_train_n = (Y_train - mean[...,0]) / std[...,0]
X_target_n = (X_target - mean) / std

# Hold out the last n_val windows for validation.
n_val = 100
X_tr, X_val = X_train_n[:-n_val], X_train_n[-n_val:]
Y_tr, Y_val = Y_train_n[:-n_val], Y_train_n[-n_val:]

print(f'Train: {len(X_tr)}, Val: {len(X_val)}, Target: {len(X_target_n)}')
print(f'X shape: {X_tr.shape}, Y shape: {Y_tr.shape}')

In [None]:
# ============================================================================
# Training-Time Augmentation (zero inference cost)
# ============================================================================

def augment_batch(x, y, *, p_shift=0.5, shift_std=0.15, p_scale=0.5,
                  scale_std=0.08, p_noise=0.3, noise_std=0.03):
    """Randomly perturb a normalized source batch during training.

    Each augmentation fires independently with its own probability, and one
    random draw is shared by every sample in the batch. The keyword-only
    parameters default to the values used so far, so ``augment_batch(x, y)``
    behaves as before. Set a probability to 0 to ablate that augmentation.

    Args:
        x: (B, T, C, F) normalized inputs.
        y: (B, T_out, C) normalized targets (feature 0 of the forecast window).
        p_shift, shift_std: probability and std of a per-channel offset added
            to feature 0 of ``x`` and to ``y``.
        p_scale, scale_std: probability and std of a per-channel gain
            ``1 + N(0, scale_std)`` applied to all features of ``x`` and to ``y``.
        p_noise, noise_std: probability and std of i.i.d. Gaussian noise added
            to ``x`` only.

    Returns:
        The augmented ``(x, y)``. The input tensors are never modified in place.
    """
    n_ch = x.shape[2]

    # 1. Per-channel mean shift (simulates baseline drift between domains).
    #    Only feature 0 is shifted, and y gets the same offset so the
    #    input -> target relationship is preserved.
    if torch.rand(1).item() < p_shift:
        shift = shift_std * torch.randn(1, 1, n_ch, 1, device=x.device)
        x = x.clone()  # do not write into the caller's tensor
        x[..., 0:1] = x[..., 0:1] + shift
        y = y + shift[..., 0].squeeze(0)  # (1, C) broadcasts over (B, T_out, C)

    # 2. Per-channel amplitude scaling (simulates gain drift). x is normalized,
    #    so this scales deviations from the training mean, not raw amplitude.
    if torch.rand(1).item() < p_scale:
        scale = 1.0 + scale_std * torch.randn(1, 1, n_ch, 1, device=x.device)
        x = x * scale
        y = y * scale[..., 0].squeeze(0)

    # 3. Additive Gaussian noise on the inputs only (regularization).
    if torch.rand(1).item() < p_noise:
        x = x + noise_std * torch.randn_like(x)

    return x, y

print('Augmentation defined')

In [None]:
# ============================================================================
# Training Function with Augmentation
# ============================================================================

def _raw_val_mse(model, val_dl, std_sq):
    """Return the per-element validation MSE of `model` in raw (de-normalized) units.

    Each normalized squared error is rescaled by the per-channel variance of
    feature 0 (`std_sq`, broadcastable to the (batch, step, channel) targets)
    and then averaged over every element. This is the quantity
    `evaluate_on_val` reports.
    """
    model.eval()
    sq_err_sum, n_elems = 0.0, 0
    with torch.no_grad():
        for xb, yb in val_dl:
            xb, yb = xb.to(device), yb.to(device)
            sq_err = (model(xb) - yb) ** 2 * std_sq
            sq_err_sum += sq_err.sum().item()
            n_elems += sq_err.numel()
    return sq_err_sum / max(n_elems, 1)


def train_coral_model(h_size, n_layers, dropout, coral_w, mean_w, seed,
                      epochs=250, patience=30, batch_size=32, use_aug=True):
    """Train one TCNForecaster with CORAL + mean-alignment domain adaptation.

    Labelled source batches come from the globals ``X_tr``/``Y_tr``. Unlabelled
    target batches come from ``X_target_n`` and enter the loss only through
    the feature-alignment terms. Early stopping monitors ``X_val``/``Y_val``.

    Args:
        h_size, n_layers, dropout: TCNForecaster architecture.
        coral_w: weight of the CORAL (covariance-alignment) term.
        mean_w: weight of the feature-mean alignment term.
        seed: seed for torch and numpy (and CUDA when available).
        epochs: maximum number of epochs; also the cosine-annealing period.
        patience: stop after this many epochs without a validation improvement.
        batch_size: batch size of the source, target and validation loaders.
        use_aug: apply `augment_batch` to source batches.

    Returns:
        (best_val, best_state): the lowest validation MSE in raw units
        (per-element mean, same definition as `evaluate_on_val`) and a CPU copy
        of the matching state_dict.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    train_ds = TensorDataset(torch.FloatTensor(X_tr), torch.FloatTensor(Y_tr))
    val_ds = TensorDataset(torch.FloatTensor(X_val), torch.FloatTensor(Y_val))
    target_ds = TensorDataset(torch.FloatTensor(X_target_n))
    train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, drop_last=False)
    val_dl = DataLoader(val_ds, batch_size=batch_size)
    target_dl = DataLoader(target_ds, batch_size=batch_size, shuffle=True)

    # Per-channel variance of feature 0, used to report the val MSE in raw units.
    std_sq = torch.as_tensor(std[..., 0] ** 2, device=device)

    model = TCNForecaster(89, 9, h_size, n_layers, dropout).to(device)
    n_params = sum(p.numel() for p in model.parameters())
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    best_val, best_state, no_improve = float('inf'), None, 0
    epochs_run = 0
    t0 = time.time()

    for epoch in range(epochs):
        epochs_run = epoch + 1
        model.train()
        target_iter = iter(target_dl)
        for xb, yb in train_dl:
            xb, yb = xb.to(device), yb.to(device)
            # Restart the target loader whenever it runs out before the source loader.
            try:
                (xt,) = next(target_iter)
            except StopIteration:
                target_iter = iter(target_dl)
                (xt,) = next(target_iter)
            xt = xt.to(device)

            if use_aug:
                xb, yb = augment_batch(xb, yb)

            optimizer.zero_grad()
            pred, feat_src = model(xb, return_features=True)
            # NOTE: the target batch also passes through BatchNorm in train mode,
            # so the running statistics mix source and target batches.
            _, feat_tgt = model(xt, return_features=True)
            loss = (((pred - yb) ** 2).mean()
                    + coral_w * coral_loss(feat_src, feat_tgt)
                    + mean_w * mean_alignment_loss(feat_src, feat_tgt))
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
        scheduler.step()

        val_mse = _raw_val_mse(model, val_dl, std_sq)
        if val_mse < best_val:
            best_val = val_mse
            best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
            no_improve = 0
        else:
            no_improve += 1
        if no_improve >= patience:
            break

    if best_state is None:
        # Happens when every val score was NaN (diverged training) or epochs == 0;
        # return the final weights instead of None so callers can still load them.
        print(f'  WARNING: seed={seed} produced no finite val MSE; keeping final weights')
        best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}

    elapsed = time.time() - t0
    print(f'  seed={seed} h={h_size} L={n_layers} CORAL={coral_w} Mean={mean_w} '
          f'-> Val MSE: {best_val:.0f} ({n_params:,} params, ep {epochs_run}, {elapsed:.0f}s)')
    return best_val, best_state

print('Training function defined')

In [None]:
# ============================================================================
# Step 1: Architecture Sweep (find best h/layers with single seed)
# ============================================================================

# Each entry is (h_size, n_layers, dropout, coral_w, mean_w). Every config is
# trained once with seed 42 and ranked by validation MSE.
arch_configs = [
    (64, 3, 0.30, 5.0, 2.0),    # baseline (v193d config)
    (128, 3, 0.25, 5.0, 2.0),   # wider
    (128, 4, 0.25, 5.0, 2.0),   # wider + deeper
    (192, 4, 0.25, 5.0, 2.0),   # even wider + deeper
]

print('=== Architecture Sweep (seed=42) ===')
arch_results = []
for cfg in arch_configs:
    cfg_val, cfg_state = train_coral_model(*cfg, seed=42)
    arch_results.append((*cfg, cfg_val, cfg_state))

arch_results = sorted(arch_results, key=lambda row: row[5])
top_val = arch_results[0][5]
print('\n=== Architecture Results ===')
for h, nl, do, cw, mw, val_mse, _ in arch_results:
    marker = ' *** BEST' if val_mse == top_val else ''
    print(f'h={h} L={nl} do={do} CORAL={cw} Mean={mw} -> Val={val_mse:.0f}{marker}')

best_h, best_nl, best_do = arch_results[0][:3]
print(f'\nBest architecture: h={best_h}, layers={best_nl}, dropout={best_do}')

In [None]:
# ============================================================================
# Step 2: CORAL Lambda Sweep with Best Architecture
# ============================================================================

# (coral_w, mean_w) pairs tried on top of the selected architecture, with seed 42.
coral_configs = [
    (3.0, 1.0),
    (5.0, 2.0),   # v193d config
    (5.0, 3.0),
    (8.0, 2.0),
    (8.0, 3.0),
]

print(f'=== CORAL Sweep with h={best_h}, L={best_nl} (seed=42) ===')
coral_results = []
for cw, mw in coral_configs:
    run_val, run_state = train_coral_model(best_h, best_nl, best_do, cw, mw, seed=42)
    coral_results.append((cw, mw, run_val, run_state))

coral_results = sorted(coral_results, key=lambda row: row[2])
lowest = coral_results[0][2]
print('\n=== CORAL Results ===')
for cw, mw, val_mse, _ in coral_results:
    suffix = ' *** BEST' if val_mse == lowest else ''
    print(f'CORAL={cw} Mean={mw} -> Val={val_mse:.0f}{suffix}')

best_cw, best_mw = coral_results[0][:2]
print(f'\nBest CORAL config: λ={best_cw}, Mean={best_mw}')

In [None]:
# ============================================================================
# Step 3: Multi-Seed Training with Best Config
# ============================================================================

# Retrain the selected config under several seeds; the spread shows how much
# of a single validation score is seed noise.
seeds = [42, 123, 456, 789, 2024]

print(f'=== Multi-Seed Training: h={best_h}, L={best_nl}, CORAL={best_cw}, Mean={best_mw} ===')
seed_results = []
for s in seeds:
    s_val, s_state = train_coral_model(best_h, best_nl, best_do, best_cw, best_mw, seed=s)
    seed_results.append((s, s_val, s_state))

print('\n=== Seed Results ===')
val_mses = [val for _, val, _ in seed_results]
for s, val_mse, _ in seed_results:
    print(f'seed={s} -> Val MSE: {val_mse:.0f}')
# Population std; the last line reports the coefficient of variation (std / mean).
mse_mean, mse_std = np.mean(val_mses), np.std(val_mses)
print(f'\nMean: {mse_mean:.0f}, Std: {mse_std:.0f}')
print(f'Variance across seeds shows {mse_std/mse_mean*100:.1f}% variation')

In [None]:
# ============================================================================
# Step 4: Save All Seed Checkpoints
# ============================================================================

out_dir = f'{PROJECT_ROOT}/4_models/v200_beignet_multiseed'
os.makedirs(out_dir, exist_ok=True)

# Hyper-parameters needed to rebuild TCNForecaster. The config is printed here
# and embedded in every checkpoint below; there is no separate config file.
# Values are cast to plain Python types so the checkpoints stay loadable with
# torch.load(weights_only=True), which rejects numpy scalars.
config = {
    'h_size': int(best_h),
    'n_layers': int(best_nl),
    'dropout': float(best_do),
    'coral_weight': float(best_cw),
    'mean_weight': float(best_mw),
    'n_features': int(n_features),
    'seeds': [int(s) for s in seeds],
}
print(f'Config: {config}')

# One checkpoint per seed: best-validation weights, their score and the shared config.
for s, val_mse, state in seed_results:
    path = f'{out_dir}/model_tcn_seed{s}.pth'
    torch.save({
        'model_state_dict': state,
        'val_mse': float(val_mse),
        'config': config,
    }, path)
    fsize = os.path.getsize(path) / 1024
    print(f'Saved seed {s}: Val={val_mse:.0f}, size={fsize:.0f}KB -> {path}')

# Normalization statistics (shape (1, 1, C, F)), needed to normalize inputs the same way later.
np.savez(f'{out_dir}/normalization_beignet_tcn.npz', mean=mean, std=std)
print(f'\nSaved normalization to {out_dir}')

# Best single seed, saved in the same format as the per-seed checkpoints (plus
# its seed) so the same loading code works for it.
best_seed_result = min(seed_results, key=lambda x: x[1])
best_seed, best_seed_val, best_seed_state = best_seed_result
torch.save({
    'model_state_dict': best_seed_state,
    'val_mse': float(best_seed_val),
    'config': config,
    'seed': int(best_seed),
}, f'{out_dir}/model_tcn_best_single.pth')
print(f'Best single seed: {best_seed} (Val={best_seed_val:.0f})')

In [None]:
# ============================================================================
# Step 5: Local Validation - Compare Single vs Multi-Seed Ensemble
# ============================================================================

def evaluate_on_val(models, X_val_t, Y_val_t):
    """Return the raw-unit MSE of the averaged prediction of `models`.

    Predictions and targets are mapped back to raw units with the feature-0
    slices of the global ``mean``/``std``. The squared error is then averaged
    over every element. Each model is switched to eval mode as a side effect.
    """
    with torch.no_grad():
        stacked = []
        for model in models:
            model.eval()
            stacked.append(model(X_val_t).cpu().numpy())
    ensemble = np.mean(stacked, axis=0)
    mu, sigma = mean[..., 0], std[..., 0]
    pred_raw = ensemble * sigma + mu
    target_raw = Y_val_t.cpu().numpy() * sigma + mu
    return ((pred_raw - target_raw) ** 2).mean()

# Normalized validation tensors, moved to the training device once.
X_val_t = torch.FloatTensor(X_val).to(device)
Y_val_t = torch.FloatTensor(Y_val).to(device)

# Rebuild one model per seed from the in-memory best states.
all_models = []
for _, _, seed_state in seed_results:
    net = TCNForecaster(89, 9, best_h, best_nl, best_do).to(device)
    net.load_state_dict(seed_state)
    net.eval()
    all_models.append(net)

print('=== Single Model Validation ===')
for (s, _, _), net in zip(seed_results, all_models):
    mse = evaluate_on_val([net], X_val_t, Y_val_t)
    print(f'seed={s}: Val MSE = {mse:.0f}')

# Ensembles of the first k seeds, for k = 2 .. min(5, number of models).
print('\n=== Ensemble Validation ===')
for k in range(2, min(5, len(all_models)) + 1):
    mse = evaluate_on_val(all_models[:k], X_val_t, Y_val_t)
    print(f'{k}-seed ensemble: Val MSE = {mse:.0f}')

print('\nDone! Check which ensemble size gives best result.')

In [None]:
# ============================================================================
# Step 6: Download Models
# ============================================================================

from google.colab import files

# Download every artifact: one checkpoint per seed, then the normalization stats.
to_download = [f'{out_dir}/model_tcn_seed{s}.pth' for s in seeds]
to_download.append(f'{out_dir}/normalization_beignet_tcn.npz')
for artifact in to_download:
    files.download(artifact)

print('\nAll files downloaded!')
print('\nNext steps:')
for line in (
    '1. Put model_tcn_seed*.pth in submission folder',
    f'2. Update model.py to load {len(seeds)} TCN seeds with h={best_h}, L={best_nl}',
    '3. Average predictions in model.py',
    '4. Submit to Codabench',
):
    print(line)

In [None]:
# ============================================================================
# Step 7 (Optional): Augmentation Ablation
# Compare with vs without augmentation to see its contribution
# ============================================================================

# Same config and seed trained with and without augmentation. A negative effect
# means augmentation lowered the validation MSE.
# NOTE: the use_aug=True run repeats the seed-42 run from the multi-seed step.
print('=== Augmentation Ablation (seed=42) ===')
ablation = {
    flag: train_coral_model(best_h, best_nl, best_do, best_cw, best_mw,
                            seed=42, use_aug=flag)[0]
    for flag in (True, False)
}
val_aug, val_noaug = ablation[True], ablation[False]
print(f'\nWith augmentation: {val_aug:.0f}')
print(f'Without augmentation: {val_noaug:.0f}')
print(f'Augmentation effect: {val_aug - val_noaug:+.0f}')