In [None]:
# forecast_uq.py
import math
import time
from copy import deepcopy
from typing import List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset

# ---------------------------
# Utils / Data generation
# ---------------------------
def generate_sine_dataset(n_series=1000, length=300, seed=0):
    """Build a bank of noisy, linearly-trending sinusoids.

    Each row is ``sin(freq * t + phase) + slope * t + eps`` with
    ``freq ~ U(0.01, 0.1)``, ``phase ~ U(0, 2*pi)``, ``slope ~ U(-0.01, 0.01)``
    and ``eps ~ N(0, 0.2**2)`` drawn independently per step.

    Returns:
        float32 array of shape (n_series, length); identical for identical ``seed``.
    """
    rng = np.random.RandomState(seed)
    steps = np.arange(length)
    rows = []
    for _ in range(n_series):
        # Draw order (freq, phase, slope, noise) fixes the RNG stream per seed.
        freq = rng.uniform(0.01, 0.1)
        phase = rng.uniform(0, 2*np.pi)
        slope = rng.uniform(-0.01, 0.01)
        eps = rng.normal(scale=0.2, size=length)
        rows.append((np.sin(steps * freq + phase) + slope * steps + eps).astype(np.float32))
    return np.stack(rows)

class TimeSeriesDataset(Dataset):
    """
    Sliding-window dataset over a bank of equal-length series.

    Every (series index, start offset) pair for which a full window fits becomes one
    item, returned as ``(encoder_seq, decoder_target)``:

    * encoder_seq: the ``encoder_len`` values starting at the offset, shape (encoder_len, 1)
    * decoder_target: the ``horizon`` values right after them, shape (horizon, 1)

    Series too short for a single window contribute no items.
    """
    def __init__(self, series: np.ndarray, encoder_len: int=48, horizon: int=12):
        self.series = series  # (n_series, series_len)
        self.enc_len = encoder_len
        self.horizon = horizon
        num_series, series_len = series.shape
        n_starts = series_len - encoder_len - horizon + 1
        # Ordered by series, then by start offset.
        self.items = [(row, start) for row in range(num_series) for start in range(n_starts)]

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        row, start = self.items[idx]
        split = start + self.enc_len
        encoder_seq = self.series[row, start:split]
        future = self.series[row, split:split + self.horizon]
        # Add a trailing feature axis: (enc_len, 1) and (horizon, 1).
        return (torch.from_numpy(encoder_seq).unsqueeze(-1),
                torch.from_numpy(future).unsqueeze(-1))

# ---------------------------
# Model: Transformer Encoder -> MLP head
# Simple temporal Transformer that outputs multi-step forecast.
# Includes dropout layers we can use for MC dropout.
# ---------------------------
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a (batch, seq_len, d_model) input.

    Column ``2k`` holds ``sin(pos * w_k)`` and column ``2k + 1`` holds ``cos(pos * w_k)``,
    with ``w_k = 10000 ** (-2k / d_model)``. Odd ``d_model`` is supported: the last
    column is then a sine with no matching cosine.

    Args:
        d_model: feature size of the input.
        max_len: number of positions precomputed in the table.
    """
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = pos * div_term  # (max_len, ceil(d_model / 2))
        pe[:, 0::2] = torch.sin(angles)
        # There are only floor(d_model / 2) odd columns. For odd d_model the last
        # frequency has no cosine partner, so drop it instead of failing on a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        # A buffer, not a parameter: it follows .to(device) and is saved in state_dict
        # but is never trained.
        self.register_buffer('pe', pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions."""
        # x shape: (batch, seq_len, d_model)
        return x + self.pe[:, :x.size(1), :].to(x.device)

class TransformerForecaster(nn.Module):
    """Transformer-encoder forecaster.

    Projects the input window to ``d_model``, adds positional encodings, runs a stack
    of encoder layers, mean-pools over time and maps the pooled vector to a
    multi-step forecast with a small MLP head. Dropout on the pooled vector (and inside
    the encoder layers) is what MC-dropout sampling relies on.
    """
    def __init__(self, input_dim=1, d_model=64, nhead=4, num_layers=3,
                 dim_feedforward=128, dropout=0.1, horizon=12,
                 use_quantiles: Optional[List[float]] = None):
        """
        Args:
            input_dim: number of features per time step.
            d_model: Transformer hidden size.
            nhead: attention heads per encoder layer.
            num_layers: number of stacked encoder layers.
            dim_feedforward: hidden size of each layer's feed-forward block.
            dropout: dropout rate for the encoder layers and the pooled vector.
            horizon: number of future steps to forecast.
            use_quantiles: if None (or empty), the model predicts one value per horizon step.
                If a list like [0.1, 0.5, 0.9], the head predicts len(use_quantiles) values
                per horizon step, in the list's order.
        """
        super().__init__()
        self.d_model = d_model
        self.input_proj = nn.Linear(input_dim, d_model)
        self.pos_enc = PositionalEncoding(d_model)
        # batch_first=True: the encoder consumes (batch, seq, d_model) directly, so forward()
        # needs no transposes and PyTorch can use its fused inference fast path in eval mode.
        # The flag adds no parameters, so state_dict keys match the sequence-first layout.
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead,
                                                   dim_feedforward=dim_feedforward,
                                                   dropout=dropout, activation='gelu',
                                                   batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.dropout = nn.Dropout(dropout)  # this dropout will be used for MC-dropout
        self.horizon = horizon
        self.use_quantiles = use_quantiles
        # One output per horizon step per quantile (a single channel for point forecasts).
        out_feats = horizon * (len(use_quantiles) if use_quantiles else 1)
        self.head = nn.Sequential(
            nn.Linear(d_model, d_model//2),
            nn.GELU(),
            nn.Linear(d_model//2, out_feats)
        )

    def forward(self, x):
        """Map a (batch, seq_len, input_dim) window to a (batch, horizon, q) forecast.

        ``q`` is 1 for point forecasts, otherwise ``len(use_quantiles)``.
        """
        x = self.input_proj(x) * math.sqrt(self.d_model)
        x = self.pos_enc(x)
        xf = self.transformer(x)  # (batch, seq_len, d_model)
        # Mean-pool across time into one summary vector per window.
        pooled = self.dropout(xf.mean(dim=1))  # (batch, d_model)
        out = self.head(pooled)  # (batch, out_feats)
        return out.view(out.size(0), self.horizon, -1)  # (batch, horizon, q)

# ---------------------------
# Losses: MSE and Pinball (quantile) loss
# ---------------------------
def mse_loss(pred, target):
    """Mean squared error, averaged over every element of ``pred - target``."""
    residual = pred - target
    return residual.pow(2).mean()

def pinball_loss(pred, target, q):
    """Mean pinball (quantile) loss at level ``q``.

    With ``e = target - pred`` each element contributes ``q * e`` when the target is
    above the prediction and ``(q - 1) * e`` otherwise. The result is averaged over all
    elements. ``q`` may be a float or a tensor broadcastable against ``pred``.
    """
    err = target - pred
    over = q * err
    under = (q - 1.0) * err
    return torch.maximum(over, under).mean()

def multi_quantile_loss(preds, target, quantiles: List[float]):
    """Mean of the per-quantile pinball losses.

    Args:
        preds: (batch, horizon, len(quantiles)); channel ``k`` predicts ``quantiles[k]``.
        target: (batch, horizon, 1); its trailing singleton axis is dropped before scoring.
        quantiles: quantile levels, in the same order as the channels of ``preds``.
    """
    flat_target = target.squeeze(-1)
    per_level = [pinball_loss(preds[:, :, k], flat_target, level)
                 for k, level in enumerate(quantiles)]
    return sum(per_level) / len(quantiles)

# ---------------------------
# Training / Eval helpers
# ---------------------------
def train_one_epoch(model, loader, optimizer, device, use_quantiles=None):
    """Run one optimisation pass over ``loader``.

    With ``use_quantiles=None`` the model is trained as a point forecaster with MSE.
    Otherwise it is trained with the pinball loss averaged over ``use_quantiles``.

    Returns:
        Per-sample mean training loss (batch losses weighted by batch size).
    """
    model.train()
    loss_sum, seen = 0.0, 0
    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)  # (B, enc_len, 1), (B, horizon, 1)
        optimizer.zero_grad()
        forecast = model(inputs)  # (B, horizon, q)
        if use_quantiles is not None:
            loss = multi_quantile_loss(forecast, targets, use_quantiles)
        else:
            loss = mse_loss(forecast.squeeze(-1), targets.squeeze(-1))
        loss.backward()
        optimizer.step()
        batch = inputs.size(0)
        loss_sum += loss.item() * batch
        seen += batch
    return loss_sum / seen

def evaluate(model, loader, device, use_quantiles=None):
    """Score a model's point forecasts on ``loader`` in eval mode, without gradients.

    For a point model (``use_quantiles is None``) the single output channel is scored.
    For a quantile model the 0.5-quantile channel is scored if present, otherwise the
    middle entry of ``use_quantiles``.

    Args:
        model: forecaster returning (batch, horizon, q).
        loader: iterable of ``(x, y)`` batches with ``y`` of shape (batch, horizon, 1).
        device: device to run on.
        use_quantiles: quantile levels of the model's output channels, or None.

    Returns:
        ``(rmse, mae, preds, trues)``. RMSE and MAE are averaged over every
        (sample, horizon step) element. ``preds`` and ``trues`` are the stacked
        predictions and targets as numpy arrays of shape (n_samples, horizon).

    Raises:
        ValueError: if ``loader`` yields no data.
    """
    model.eval()
    # The scored channel is the same for every batch, so resolve it once up front.
    if use_quantiles is None:
        q_idx = None
    elif 0.5 in use_quantiles:
        q_idx = use_quantiles.index(0.5)
    else:
        q_idx = len(use_quantiles) // 2
    total_mse = 0.0
    total_mae = 0.0
    n = 0
    preds = []
    trues = []
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            target = y.to(device).squeeze(-1)  # (B, horizon)
            out = model(x)  # (B, horizon, q)
            pred = out.squeeze(-1) if q_idx is None else out[:, :, q_idx]  # (B, horizon)
            err = pred - target
            total_mse += (err ** 2).sum().item()
            total_mae += err.abs().sum().item()
            n += y.numel()
            preds.append(pred.cpu().numpy())
            trues.append(target.cpu().numpy())
    if n == 0:
        raise ValueError("evaluate() got an empty loader; nothing to score")
    rmse = math.sqrt(total_mse / n)
    mae = total_mae / n
    return rmse, mae, np.concatenate(preds, axis=0), np.concatenate(trues, axis=0)

# ---------------------------
# MC Dropout prediction
# ---------------------------
def mc_dropout_predict(model, x, device, mc_samples=50):
    """
    Monte-Carlo dropout: run ``mc_samples`` stochastic forward passes with dropout active.

    The whole model is put in train mode so every dropout in the network is sampled, not
    just the head's ``nn.Dropout``. Gradients are not tracked. The model's previous
    train/eval mode is restored afterwards, even if a forward pass raises.

    NOTE: train mode would also make any BatchNorm layers use batch statistics. The
    forecaster in this file has no BatchNorm layers.

    Args:
        model: forecaster returning (batch, horizon, q).
        x: input batch; moved to ``device`` once.
        device: device the model lives on.
        mc_samples: number of stochastic passes (must be >= 1).

    Returns:
        ``(mean_pred, std_pred)`` as numpy arrays over the sample axis. Shape is
        (batch, horizon) when q == 1, else (batch, horizon, q). ``std`` is the population
        std (ddof=0).

    Raises:
        ValueError: if ``mc_samples < 1``.
    """
    if mc_samples < 1:
        raise ValueError(f"mc_samples must be >= 1, got {mc_samples}")
    was_training = model.training
    x = x.to(device)  # loop-invariant: transfer once rather than once per sample
    samples = []
    model.train()
    try:
        with torch.no_grad():
            for _ in range(mc_samples):
                out = model(x)  # (B, horizon, q)
                samples.append(out.squeeze(-1).cpu().numpy())
    finally:
        model.train(was_training)
    preds = np.stack(samples, axis=0)  # (mc, batch, horizon)
    return preds.mean(axis=0), preds.std(axis=0)

# ---------------------------
# Example: train & demonstrate
# ---------------------------
def _interval_coverage(y, lower, upper):
    """Return the fraction of entries of ``y`` inside the closed interval ``[lower, upper]``."""
    return ((y >= lower) & (y <= upper)).mean()


def main():
    """Train and compare three uncertainty-quantification approaches on synthetic sine data.

    1. A point-forecast Transformer trained with MSE (best epoch by validation RMSE),
       then MC dropout on it for a Gaussian-approximate 95% interval.
    2. A quantile-regression Transformer predicting the 0.1 / 0.5 / 0.9 quantiles.
    3. A small deep ensemble of point models, aggregated by mean and std.

    Coverage figures use only the first test batch. The test loader is not shuffled, so
    that batch holds consecutive windows from a single series. Treat the figures as a
    sanity check rather than a calibrated estimate.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Seed torch as well as the data generator, so weight init, dropout masks and
    # loader shuffling are reproducible between runs.
    torch.manual_seed(42)
    # generate data
    series = generate_sine_dataset(n_series=500, length=300, seed=42)
    # train / val / test split along series, so no window is shared between splits
    train_series = series[:380]
    val_series = series[380:440]
    test_series = series[440:]
    encoder_len = 48
    horizon = 12

    train_ds = TimeSeriesDataset(train_series, encoder_len, horizon)
    val_ds = TimeSeriesDataset(val_series, encoder_len, horizon)
    test_ds = TimeSeriesDataset(test_series, encoder_len, horizon)

    batch_size = 64
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, drop_last=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size)
    test_loader = DataLoader(test_ds, batch_size=batch_size)

    # ---------------------------
    # 1) Baseline: point forecast (MSE) with dropout (for MC)
    # ---------------------------
    model = TransformerForecaster(input_dim=1, d_model=64, nhead=4, num_layers=3,
                                  dim_feedforward=128, dropout=0.2, horizon=horizon,
                                  use_quantiles=None)
    model.to(device)
    opt = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
    epochs = 12
    print("Training point-forecast model (MSE)...")
    # Start from the initial weights so a "best" state always exists, even if every
    # validation RMSE is NaN (NaN < best_val is never true).
    best_val = math.inf
    best_model_state = deepcopy(model.state_dict())
    for ep in range(epochs):
        t0 = time.time()
        tr_loss = train_one_epoch(model, train_loader, opt, device, use_quantiles=None)
        rmse_val, mae_val, _, _ = evaluate(model, val_loader, device, use_quantiles=None)
        if rmse_val < best_val:
            best_val = rmse_val
            best_model_state = deepcopy(model.state_dict())
        print(f"Epoch {ep+1}/{epochs}  train_loss={tr_loss:.4f} val_rmse={rmse_val:.4f} val_mae={mae_val:.4f} time={time.time()-t0:.1f}s")
    model.load_state_dict(best_model_state)

    # Evaluate on test set
    rmse_test, mae_test, _, _ = evaluate(model, test_loader, device, use_quantiles=None)
    print(f"Point model test RMSE={rmse_test:.4f}, MAE={mae_test:.4f}")

    # ---------------------------
    # 2) MC Dropout for uncertainty
    # ---------------------------
    # Take a batch from test_loader and run MC dropout
    x_batch, y_batch = next(iter(test_loader))
    mean_mc, std_mc = mc_dropout_predict(model, x_batch, device, mc_samples=100)
    # 95% interval using a Gaussian approximation of the MC samples
    lower_95 = mean_mc - 1.96 * std_mc
    upper_95 = mean_mc + 1.96 * std_mc
    # empirical coverage on this batch (how many true values fall into the interval)
    coverage_95 = _interval_coverage(y_batch.squeeze(-1).numpy(), lower_95, upper_95)
    print(f"MC Dropout batch coverage ~95% interval: {coverage_95*100:.2f}% (on single batch)")

    # ---------------------------
    # 3) Quantile regression model (predict multiple quantiles directly)
    # ---------------------------
    quantiles = [0.1, 0.5, 0.9]
    q_model = TransformerForecaster(input_dim=1, d_model=64, nhead=4, num_layers=3,
                                    dim_feedforward=128, dropout=0.2, horizon=horizon,
                                    use_quantiles=quantiles)
    q_model.to(device)
    q_opt = torch.optim.Adam(q_model.parameters(), lr=1e-3)
    epochs_q = 12
    print("Training quantile regression model...")
    # Model selection uses the validation RMSE of the median channel (as scored by evaluate()).
    best_val = math.inf
    best_model_state = deepcopy(q_model.state_dict())
    for ep in range(epochs_q):
        tr_loss = train_one_epoch(q_model, train_loader, q_opt, device, use_quantiles=quantiles)
        rmse_val, mae_val, _, _ = evaluate(q_model, val_loader, device, use_quantiles=quantiles)
        if rmse_val < best_val:
            best_val = rmse_val
            best_model_state = deepcopy(q_model.state_dict())
        print(f"Epoch {ep+1}/{epochs_q}  q_train_loss={tr_loss:.4f} val_rmse={rmse_val:.4f}")
    q_model.load_state_dict(best_model_state)

    # Evaluate quantile coverage on a batch
    q_model.eval()
    xb, yb = next(iter(test_loader))
    with torch.no_grad():
        out = q_model(xb.to(device)).cpu().numpy()  # (B, horizon, q)
    # Look channels up by level so editing `quantiles` cannot silently misalign them.
    lower_q = out[:, :, quantiles.index(0.1)]
    upper_q = out[:, :, quantiles.index(0.9)]
    cov_q = _interval_coverage(yb.squeeze(-1).numpy(), lower_q, upper_q)
    print(f"Quantile model empirical coverage 10-90 interval: {cov_q*100:.2f}% (on single batch)")

    # ---------------------------
    # 4) Deep Ensembles: train multiple point-forecast models and aggregate
    # ---------------------------
    n_ensembles = 3
    ensemble_models = []
    print("Training deep ensemble of point models...")
    for _ in range(n_ensembles):
        m_model = TransformerForecaster(input_dim=1, d_model=64, nhead=4, num_layers=3,
                                       dim_feedforward=128, dropout=0.2, horizon=horizon,
                                       use_quantiles=None)
        m_model.to(device)
        m_opt = torch.optim.Adam(m_model.parameters(), lr=1e-3)
        # quick train (few epochs) for demo
        for _ep in range(6):
            train_one_epoch(m_model, train_loader, m_opt, device, use_quantiles=None)
        # Keep each member on `device`. The aggregation below feeds it inputs on `device`,
        # so a CPU copy would fail with a device mismatch on CUDA machines.
        ensemble_models.append(m_model)
    # Aggregate predictions on batch
    xb, yb = next(iter(test_loader))
    xb = xb.to(device)
    preds_ens = []
    with torch.no_grad():
        for member in ensemble_models:
            member.eval()
            preds_ens.append(member(xb).squeeze(-1).cpu().numpy())  # (B, horizon)
    preds_ens = np.stack(preds_ens, axis=0)  # (n_ens, B, horizon)
    ens_mean = preds_ens.mean(axis=0)
    ens_std = preds_ens.std(axis=0)
    # Example coverage using mean +/- 1.96*std
    cov_ens = _interval_coverage(yb.squeeze(-1).numpy(),
                                 ens_mean - 1.96 * ens_std, ens_mean + 1.96 * ens_std)
    print(f"Ensemble empirical coverage (approx 95%): {cov_ens*100:.2f}% (on single batch)")

    # ---------------------------
    # Final test RMSE for quantile & point models
    # ---------------------------
    # The point model's weights have not changed since its test evaluation above
    # (MC dropout runs without gradients or optimizer steps), so reuse that score
    # instead of running another full test pass.
    rmse_point = rmse_test
    rmse_q, _, _, _ = evaluate(q_model, test_loader, device, use_quantiles=quantiles)
    print(f"Final test RMSE: point_model={rmse_point:.4f}, quantile_model(median)={rmse_q:.4f}")

if __name__ == "__main__":
    # Run the full train-and-compare demo only when executed as a script, not on import.
    main()



Training point-forecast model (MSE)...
Epoch 1/12  train_loss=0.1489 val_rmse=0.2759 val_mae=0.2168 time=275.8s
Epoch 2/12  train_loss=0.0878 val_rmse=0.2574 val_mae=0.2034 time=266.4s
Epoch 3/12  train_loss=0.0834 val_rmse=0.2636 val_mae=0.2076 time=268.7s
Epoch 4/12  train_loss=0.0808 val_rmse=0.2626 val_mae=0.2071 time=277.2s
