# v192d: NHiTS with CORAL Domain Adaptation

NHiTS 佔 Beignet Public ensemble 的 30% 權重。
如果 NHiTS 也用 CORAL 訓練，可能進一步改善。

NHiTS 用 1 feature (feature 0)，和 v176 完全相同架構。

In [None]:
# Mount Google Drive so the project directory below is reachable from Colab.
from google.colab import drive
drive.mount('/content/drive')

# Project locations on the mounted Drive; TRAIN_DIR and TEST_DIR are the
# folders the data-loading cell reads its .npz files from.
PROJECT_ROOT = '/content/drive/MyDrive/Hackathon_NSF_Neural_Forecasting'
TRAIN_DIR = f'{PROJECT_ROOT}/1_data/raw/train_data_neuro'
TEST_DIR = f'{PROJECT_ROOT}/1_data/raw/test_dev_input'

import os, torch, numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from datetime import datetime

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device: {device}')

In [None]:
def coral_loss(source, target):
    """Return the CORAL penalty between two feature batches.

    Both inputs are indexed as (samples, features): ``size(0)`` is used as the
    sample count and ``size(1)`` of ``source`` as the feature dimension ``d``.
    The result is the sum of squared differences between the two feature
    covariance matrices, divided by ``4 * d``.
    """

    def _covariance(feats):
        # Centre every feature column, then form the (d, d) covariance matrix.
        # The 1e-8 keeps the denominator non-zero for a single-row batch.
        centred = feats - feats.mean(0, keepdim=True)
        return centred.T @ centred / (feats.size(0) - 1 + 1e-8)

    n_features = source.size(1)
    gap = _covariance(source) - _covariance(target)
    return (gap ** 2).sum() / (4 * n_features)

def mean_alignment_loss(source, target):
    """Return the mean squared gap between the two batches' per-feature means.

    Each input is averaged over dim 0; the two mean vectors are subtracted,
    squared element-wise and averaged into a scalar.
    """
    source_centroid = source.mean(0)
    target_centroid = target.mean(0)
    offset = source_centroid - target_centroid
    return offset.pow(2).mean()

In [None]:
# NHiTS Architecture (和 v176 完全相同)

class NHiTSBlock(nn.Module):
    """Single N-HiTS block: max-pool the input, run an MLP, emit backcast and forecast.

    The input to ``forward`` is treated as 2-D ``(rows, input_size)``. It is
    max-pooled along the last axis, passed through a two-layer ReLU MLP, and
    projected to

    * a backcast of length ``input_size``, and
    * ``n_coeffs = max(1, output_steps // n_freq_downsample)`` forecast
      coefficients, which are linearly interpolated up to ``output_steps``
      when there are fewer coefficients than output steps.
    """

    def __init__(self, input_size, hidden_size, output_steps, pool_kernel_size, n_freq_downsample, dropout=0.1):
        super().__init__()
        self.output_steps = output_steps
        self.pool_kernel_size = pool_kernel_size
        self.pooling = nn.MaxPool1d(
            kernel_size=pool_kernel_size,
            stride=pool_kernel_size,
            ceil_mode=True,
        )
        # Ceiling division: the pooled length the first Linear layer is sized for.
        self.pooled_size = -(-input_size // pool_kernel_size)
        self.mlp = nn.Sequential(
            nn.Linear(self.pooled_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout),
        )
        self.n_coeffs = max(1, output_steps // n_freq_downsample)
        self.backcast_proj = nn.Linear(hidden_size, input_size)
        self.forecast_proj = nn.Linear(hidden_size, self.n_coeffs)

    def forward(self, x, return_hidden=False):
        """Return ``(backcast, forecast)``, plus the MLP output when ``return_hidden`` is true."""
        # MaxPool1d wants a channel axis, so add one and remove it afterwards.
        pooled = self.pooling(x.unsqueeze(1)).squeeze(1)

        # Force the pooled width to the size the MLP expects:
        # zero-pad on the right when too short, truncate when too long.
        width = pooled.shape[1]
        if width < self.pooled_size:
            pooled = F.pad(pooled, (0, self.pooled_size - width))
        elif width > self.pooled_size:
            pooled = pooled[:, :self.pooled_size]

        h = self.mlp(pooled)
        backcast = self.backcast_proj(h)
        coeffs = self.forecast_proj(h)

        if self.n_coeffs >= self.output_steps:
            forecast = coeffs[:, :self.output_steps]
        else:
            # Fewer coefficients than output steps: upsample them linearly.
            forecast = F.interpolate(
                coeffs.unsqueeze(1),
                size=self.output_steps,
                mode='linear',
                align_corners=False,
            ).squeeze(1)

        if return_hidden:
            return backcast, forecast, h
        return backcast, forecast

class NHiTSStack(nn.Module):
    """Chain of ``NHiTSBlock`` modules applied to a progressively reduced residual.

    Each block receives the current residual, its backcast is subtracted from
    that residual, and its forecast is added to a running total. Block ``i``
    uses ``pool_kernel_sizes[i % len(...)]`` and
    ``n_freq_downsamples[i % len(...)]``; both lists default to
    ``[1, 2][:n_blocks]`` when not given (or empty).
    """

    def __init__(self, input_size, hidden_size, output_steps, n_blocks=2, pool_kernel_sizes=None, n_freq_downsamples=None, dropout=0.1):
        super().__init__()
        if not pool_kernel_sizes:
            pool_kernel_sizes = [1, 2][:n_blocks]
        if not n_freq_downsamples:
            n_freq_downsamples = [1, 2][:n_blocks]

        blocks = []
        for idx in range(n_blocks):
            kernel = pool_kernel_sizes[idx % len(pool_kernel_sizes)]
            downsample = n_freq_downsamples[idx % len(n_freq_downsamples)]
            blocks.append(NHiTSBlock(input_size, hidden_size, output_steps, kernel, downsample, dropout))
        self.blocks = nn.ModuleList(blocks)

    def forward(self, x, return_hidden=False):
        """Return ``(total_backcast, total_forecast)``; with ``return_hidden`` also the list of block hiddens."""
        remaining = x
        forecast_sum = 0
        hidden_states = []
        for block in self.blocks:
            if return_hidden:
                backcast, forecast, hidden = block(remaining, return_hidden=True)
                hidden_states.append(hidden)
            else:
                backcast, forecast = block(remaining)
            remaining = remaining - backcast
            forecast_sum = forecast_sum + forecast

        # Input minus what is left over equals the sum of all block backcasts.
        explained = x - remaining
        if return_hidden:
            return explained, forecast_sum, hidden_states
        return explained, forecast_sum

class NHiTSForecaster(nn.Module):
    """Channel-independent N-HiTS forecaster that reads only feature 0.

    The input is unpacked as ``(batch, time, channels, features)``. Every
    (sample, channel) pair becomes one univariate series of length ``time``,
    all stacks are run on those series and their forecasts are summed.
    Stack ``i`` uses pool kernel sizes and frequency downsampling factors
    ``[1 * (i + 1), 2 * (i + 1)]``.

    ``n_channels`` is only stored and ``n_features`` is accepted but not used;
    the channel count actually used comes from the input shape.
    """

    def __init__(self, n_channels, n_features=1, hidden_size=128, num_layers=2, n_stacks=2, seq_len=10, dropout=0.1, output_steps=10):
        super().__init__()
        self.n_channels = n_channels
        self.output_steps = output_steps
        self.hidden_size = hidden_size

        stacks = []
        for stack_idx in range(n_stacks):
            scale = stack_idx + 1
            stacks.append(
                NHiTSStack(
                    seq_len,
                    hidden_size,
                    output_steps,
                    num_layers,
                    [1 * scale, 2 * scale],
                    [1 * scale, 2 * scale],
                    dropout,
                )
            )
        self.stacks = nn.ModuleList(stacks)

    def forward(self, x, return_features=False):
        """Return the forecast ``(batch, output_steps, channels)``.

        With ``return_features=True`` a tuple ``(forecast, features)`` is
        returned, where ``features`` is ``(batch, hidden)``: the block hidden
        states averaged over all blocks and then over channels.
        """
        batch, steps, channels, _ = x.shape
        # Keep feature 0 only and fold channels into the batch axis: (batch * channels, steps).
        series = x[:, :, :, 0].transpose(1, 2).reshape(batch * channels, steps)

        if not return_features:
            forecast_sum = sum(stack(series)[1] for stack in self.stacks)
            return forecast_sum.view(batch, channels, self.output_steps).transpose(1, 2)

        block_hiddens = []
        forecast_sum = 0
        for stack in self.stacks:
            _, forecast, hiddens = stack(series, return_hidden=True)
            forecast_sum = forecast_sum + forecast
            block_hiddens.extend(hiddens)

        pred = forecast_sum.view(batch, channels, self.output_steps).transpose(1, 2)
        # Average over blocks -> (batch * channels, hidden), then over channels -> (batch, hidden).
        block_mean = torch.stack(block_hiddens, dim=0).mean(dim=0)
        sample_features = block_mean.view(batch, channels, -1).mean(dim=1)
        return pred, sample_features

# Verify: build a throwaway model with the same constructor arguments as the
# training cell and run one forward pass as a shape smoke test.
m = NHiTSForecaster(89, 1, 256, 2, 2, 10, 0.1)
print(f'NHiTS parameters: {sum(p.numel() for p in m.parameters()):,}')
# Random input laid out the way forward() unpacks it: (batch, time, channels, features).
x_test = torch.randn(2, 10, 89, 1)
pred, feat = m(x_test, return_features=True)
print(f'Pred: {pred.shape}, Feat: {feat.shape}')
# Drop the temporary model and input; pred and feat stay defined.
del m, x_test

In [None]:
# Load data (NHiTS uses 1 feature only)
# Both archives are read through their default 'arr_0' key.
train_data = np.load(f'{TRAIN_DIR}/train_data_beignet.npz')['arr_0']
test_public = np.load(f'{TEST_DIR}/test_data_beignet_masked.npz')['arr_0']

# Arrays are indexed as 4-D; the model unpacks this layout as
# (sample, time, channel, feature). The first 10 steps along axis 1 are the
# input window and everything after step 10 is the forecast target.
X_train = train_data[:, :10, :, 0:1].astype(np.float32)  # feature 0 only, feature axis kept
Y_train = train_data[:, 10:, :, 0].astype(np.float32)  # feature 0 only, feature axis dropped
# The public test inputs act as the unlabeled target domain for CORAL.
X_target = test_public[:, :10, :, 0:1].astype(np.float32)

# Normalization statistics over axes 0 and 1 of the training inputs
# (validation rows included); 1e-8 guards against a zero std.
mean = X_train.mean(axis=(0,1), keepdims=True)
std = X_train.std(axis=(0,1), keepdims=True) + 1e-8

# Inputs, targets and target-domain inputs are all scaled with the training statistics.
X_train_norm = (X_train - mean) / std
# [..., 0] drops the trailing feature axis so the stats broadcast against Y_train.
Y_train_norm = (Y_train - mean[...,0]) / std[...,0]
X_target_norm = (X_target - mean) / std

# Hold out the last n_val rows of the training array for validation.
n_val = 100
X_tr, X_val = X_train_norm[:-n_val], X_train_norm[-n_val:]
Y_tr, Y_val = Y_train_norm[:-n_val], Y_train_norm[-n_val:]

batch_size = 32
train_ds = TensorDataset(torch.FloatTensor(X_tr), torch.FloatTensor(Y_tr))
val_ds = TensorDataset(torch.FloatTensor(X_val), torch.FloatTensor(Y_val))
# The target dataset has inputs only (no labels), so it yields 1-tuples.
target_ds = TensorDataset(torch.FloatTensor(X_target_norm))
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
val_dl = DataLoader(val_ds, batch_size=batch_size)
target_dl = DataLoader(target_ds, batch_size=batch_size, shuffle=True)

print(f'Train: {len(X_tr)}, Val: {len(X_val)}, Target: {len(X_target_norm)}')

In [None]:
# Training: supervised MSE on the source (training) batches plus CORAL and
# mean-alignment penalties between source and target (public test) features.
CORAL_WEIGHT = 1.0  # weight of the covariance-alignment term
MEAN_WEIGHT = 0.5  # weight of the mean-alignment term
EPOCHS = 200
PATIENCE = 30  # epochs without validation improvement before stopping

model = NHiTSForecaster(89, 1, 256, 2, 2, 10, 0.1).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

# Best validation score so far, a CPU copy of the matching weights, and the
# number of consecutive epochs without improvement.
best_val = float('inf')
best_state = None
no_improve = 0

print(f'Training NHiTS with CORAL={CORAL_WEIGHT}, Mean={MEAN_WEIGHT}')
print('-' * 70)

for epoch in range(EPOCHS):
    model.train()
    train_mse, train_coral, n_batches = 0, 0, 0
    target_iter = iter(target_dl)

    for xb, yb in train_dl:
        xb, yb = xb.to(device), yb.to(device)
        # Pair every source batch with a target batch; restart the target
        # loader whenever it runs out before the source loader does.
        try: (xt,) = next(target_iter)
        except StopIteration:
            target_iter = iter(target_dl)
            (xt,) = next(target_iter)
        xt = xt.to(device)

        optimizer.zero_grad()
        pred, feat_src = model(xb, return_features=True)
        # Only the features of the target batch are used; it has no labels.
        _, feat_tgt = model(xt, return_features=True)

        mse = ((pred - yb)**2).mean()
        coral = coral_loss(feat_src, feat_tgt)
        mean_a = mean_alignment_loss(feat_src, feat_tgt)
        loss = mse + CORAL_WEIGHT * coral + MEAN_WEIGHT * mean_a

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        train_mse += mse.item()
        train_coral += coral.item()
        n_batches += 1

    # The cosine schedule advances once per epoch.
    scheduler.step()
    train_mse /= n_batches
    train_coral /= n_batches

    model.eval()
    val_loss = 0
    with torch.no_grad():
        for xb, yb in val_dl:
            xb, yb = xb.to(device), yb.to(device)
            val_loss += ((model(xb) - yb)**2).sum().item()
    # Summed squared error divided by the number of validation samples, then
    # rescaled by the mean of std**2.
    # NOTE(review): this is a per-sample sum, not a per-element mean, and the
    # result is a NumPy scalar because std is a NumPy array - confirm the
    # scale is what downstream comparisons expect.
    val_mse = (val_loss / len(X_val)) * (std[...,0]**2).mean()

    if val_mse < best_val:
        best_val = val_mse
        # Snapshot the weights on the CPU so later epochs cannot overwrite them.
        best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
        no_improve = 0
        print(f'Epoch {epoch+1:3d}: MSE={train_mse:.4f}, CORAL={train_coral:.6f}, Val={val_mse:.0f} ***')
    else:
        no_improve += 1
        # Periodic progress line for epochs that did not improve.
        if epoch % 20 == 0:
            print(f'Epoch {epoch+1:3d}: MSE={train_mse:.4f}, CORAL={train_coral:.6f}, Val={val_mse:.0f}')
    if no_improve >= PATIENCE:
        print(f'Early stopping at epoch {epoch+1}')
        break

print(f'\nBest Val MSE: {best_val:.0f}')

In [None]:
# Save the best checkpoint and the normalization statistics to Drive, then
# trigger browser downloads of both files.
out_dir = f'{PROJECT_ROOT}/4_models/v192d_nhits_coral'
os.makedirs(out_dir, exist_ok=True)

# best_state is still None when no epoch ever improved on the initial best_val
# (for example when the validation loss was NaN throughout). Refuse to write a
# checkpoint that contains no weights.
if best_state is None:
    raise RuntimeError('No best model state was recorded during training; nothing to save.')

torch.save({
    'model_state_dict': best_state,
    # best_val is derived from the NumPy std array, so it is normally a NumPy
    # scalar. Store a plain Python float so the checkpoint contains only
    # tensors and builtins and can be read by torch.load(weights_only=True).
    'val_mse': float(best_val),
}, f'{out_dir}/model_nhits_coral.pth')

# The same mean/std used to normalize the training data are needed at inference time.
np.savez(f'{out_dir}/normalization_beignet_nhits.npz', mean=mean, std=std)

print(f'Saved to {out_dir}')
print('model_nhits_coral.pth → 替換 v176 的 model_nhits.pth')
print('normalization_beignet_nhits.npz → 替換 v176 的 normalization_beignet_nhits.npz')

from google.colab import files
files.download(f'{out_dir}/model_nhits_coral.pth')
files.download(f'{out_dir}/normalization_beignet_nhits.npz')