# In [None]:


import os
import math
import random
from typing import Tuple, Dict, Any

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import trange, tqdm

from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import statsmodels.api as sm
import shap
from captum.attr import IntegratedGradients

# -----------------------------
# 1) DATA GENERATION
# -----------------------------
def generate_multiseasonal_multivariate(n_samples: int = 3000, random_seed: int = 42) -> pd.DataFrame:
    """
    Build a synthetic three-channel time series of length ``n_samples``.

    A shared base signal is the *sum* of:
    - a slowly accelerating trend (``0.001 * t**1.2``),
    - a period-24 seasonality plus its period-12 harmonic,
    - a period-168 seasonality plus its period-56 harmonic,
    - two level shifts (+1.5 from ``n_samples//3``, -0.8 from ``2*n_samples//3``).

    Each feature is a different transformation of that base with its own
    Gaussian noise, so the channels are coupled but not identical.

    Note: this seeds NumPy's *global* random state with ``random_seed``.

    Returns a DataFrame with columns: ['t','feat_0','feat_1','feat_2']
    """
    np.random.seed(random_seed)
    t = np.arange(n_samples)
    two_pi_t = 2 * np.pi * t

    # non-stationary component
    trend = 0.001 * (t**1.2)

    # short cycle (period 24) with a half-period harmonic
    short_period = 24
    short_cycle = (1.5 * np.sin(two_pi_t / short_period)
                   + 0.3 * np.sin(two_pi_t / (short_period/2)))

    # long cycle (period 168) with a third-period harmonic
    long_period = 168
    long_cycle = (2.0 * np.sin(two_pi_t / long_period)
                  + 0.5 * np.sin(two_pi_t / (long_period/3)))

    # first random draw: noise used by feat_0
    noise0 = 0.5 * np.random.randn(n_samples)

    # structural breaks at one third and two thirds of the series
    first_break = n_samples//3
    second_break = 2*n_samples//3
    level = np.zeros(n_samples)
    level[first_break:] += 1.5
    level[second_break:] -= 0.8

    base = (trend + short_cycle + long_cycle + level)

    # The order of the random draws below is part of the function's
    # reproducible output: feat_1 noise, feat_2 noise, then one scalar.
    f0 = base + 0.3 * np.sin(0.05 * t) + 0.2 * noise0
    f1 = 0.5 * base + 0.2 * np.cos(0.02 * t) + 0.3 * np.random.randn(n_samples)
    # feat_2 mixes the previous step of the base with a saturating transform
    f2 = 0.3 * np.roll(base, 1) + 0.7 * np.tanh(base) + 0.25 * np.random.randn(n_samples)
    # np.roll wraps the last value to position 0; overwrite that element
    f2[0] = f0[0] * 0.2 + 0.1 * np.random.randn()

    return pd.DataFrame({
        't': t,
        'feat_0': f0,
        'feat_1': f1,
        'feat_2': f2
    })

# -----------------------------
# 2) PREPROCESSING: SCALER + WINDOWING
# -----------------------------
class TimeSeriesWindowDataset(Dataset):
    """Sliding-window view over a 2-D series for multi-step forecasting."""

    def __init__(self, data: np.ndarray, input_width: int, output_width: int, stride: int = 1):
        """
        data: shape (n_samples, n_features); stored as float32.
        Each item is a pair (x, y):
         - x: the ``input_width`` rows starting at a window start, all features
         - y: the following ``output_width`` rows of column 0 only, kept 2-D
        Window starts advance by ``stride``; only windows that fit entirely
        inside ``data`` are generated (possibly none).
        """
        self.data = data.astype(np.float32)
        self.input_width = input_width
        self.output_width = output_width
        self.stride = stride
        self.n_samples, self.n_features = data.shape
        # exclusive upper bound on window starts so x and y both fit
        start_limit = self.n_samples - (input_width + output_width) + 1
        self.indices = list(range(0, start_limit, stride))

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        start = self.indices[idx]
        split = start + self.input_width
        stop = split + self.output_width
        history = self.data[start:split]        # (input_width, n_features)
        target = self.data[split:stop, 0:1]     # (output_width, 1), column 0 only
        return history, target

# -----------------------------
# 3) MODEL: PyTorch LSTM
# -----------------------------
class LSTMForecast(nn.Module):
    """LSTM encoder followed by a linear head that emits all forecast steps at once."""

    def __init__(self, input_dim: int, hidden_dim: int, num_layers: int, output_width: int, dropout: float = 0.0):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout,
        )
        # One linear output per forecast step; each step is a single scalar.
        self.fc = nn.Linear(in_features=hidden_dim, out_features=output_width)

    def forward(self, x):
        """Map x of shape (batch, seq_len, input_dim) to (batch, output_width, 1)."""
        sequence_out, _ = self.lstm(x)          # (batch, seq_len, hidden_dim)
        final_step = sequence_out[:, -1, :]     # top-layer output at the last timestep
        forecast = self.fc(final_step)          # (batch, output_width)
        # trailing singleton axis so the shape matches the dataset's targets
        return forecast.unsqueeze(-1)

# -----------------------------
# 4) TRAIN / EVAL UTIL
# -----------------------------
def train_model(model: nn.Module, train_loader: DataLoader, val_loader: DataLoader,
                n_epochs: int, lr: float, device: torch.device, verbose=True) -> Dict[str, Any]:
    """
    Train ``model`` with Adam on an MSE objective and keep the best-validation weights.

    Args:
        model: network to train; moved to ``device`` and modified in place.
        train_loader: yields ``(x, y)`` batches used for gradient updates.
        val_loader: yields ``(x, y)`` batches used only for model selection.
        n_epochs: number of passes over ``train_loader``.
        lr: Adam learning rate.
        device: device on which batches and the model are placed.
        verbose: when true, print train/val MSE roughly ten times over the run.

    Returns:
        dict with keys
        - 'model': the same ``model`` object, loaded with the weights from the
          epoch that had the lowest validation MSE,
        - 'history': ``{'train_loss': [...], 'val_loss': [...]}``, one entry per epoch,
        - 'best_val_mse': the lowest validation MSE observed.

    Both losses are mean squared error per target element, so the train and
    validation numbers are directly comparable.
    """
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()
    sum_sq_error = nn.MSELoss(reduction='sum')
    history = {'train_loss': [], 'val_loss': []}
    best_val = float('inf')
    best_state = None
    log_every = max(1, n_epochs // 10)

    for epoch in range(1, n_epochs + 1):
        model.train()
        running = 0.0
        for xb, yb in train_loader:
            xb = xb.to(device)
            yb = yb.to(device)
            optimizer.zero_grad()
            pred = model(xb)
            loss = criterion(pred, yb)
            loss.backward()
            optimizer.step()
            # weight the batch mean by batch size so a short last batch is not over-counted
            running += loss.item() * xb.size(0)
        train_loss = running / len(train_loader.dataset)

        # validation: accumulate squared error and element count, then average
        # per element so the value is a true MSE (same scale as train_loss)
        model.eval()
        val_sq_error = 0.0
        val_elements = 0
        with torch.no_grad():
            for xb, yb in val_loader:
                xb = xb.to(device)
                yb = yb.to(device)
                pred = model(xb)
                val_sq_error += sum_sq_error(pred, yb).item()
                val_elements += yb.numel()
        val_loss = val_sq_error / val_elements

        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        if verbose and epoch % log_every == 0:
            print(f"Epoch {epoch}/{n_epochs}  train_mse={train_loss:.6f}  val_mse={val_loss:.6f}")
        if val_loss < best_val:
            best_val = val_loss
            # clone() is required: state_dict() tensors share storage with the live
            # parameters and .cpu() is a no-op for CPU tensors, so without a copy the
            # snapshot would be overwritten by later optimizer steps.
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
    if best_state is not None:
        model.load_state_dict(best_state)
    return {'model': model, 'history': history, 'best_val_mse': best_val}

def predict_numpy(model: nn.Module, X: np.ndarray, device: torch.device) -> np.ndarray:
    """
    Run ``model`` on the whole array ``X`` in one forward pass and return NumPy output.

    The model is moved to ``device`` and switched to eval mode. ``X`` is cast to
    float32 before inference. The model output must have a trailing axis of
    size 1, which is removed from the returned array.
    """
    model.to(device)
    model.eval()
    inputs = torch.from_numpy(X.astype(np.float32)).to(device)
    with torch.no_grad():
        outputs = model(inputs)  # (batch, output_width, 1)
    return np.squeeze(outputs.cpu().numpy(), axis=-1)

# -----------------------------
# 5) METRICS
# -----------------------------
def rmse(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Return the root mean squared error over all elements of the two arrays."""
    truth_1d = y_true.flatten()
    pred_1d = y_pred.flatten()
    mse = mean_squared_error(truth_1d, pred_1d)
    return float(np.sqrt(mse))

def mape(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """
    Return the mean absolute percentage error, in percent.

    True values with magnitude below 1e-6 use 1e-6 as the denominator so the
    ratio never divides by zero.
    """
    magnitude = np.abs(y_true)
    # floor the denominator at 1e-6
    safe_denom = np.where(magnitude < 1e-6, 1e-6, magnitude)
    relative_error = np.abs((y_true - y_pred) / safe_denom)
    return float(np.mean(relative_error) * 100.0)

# -----------------------------
# 6) STATISTICAL BASELINE: SARIMA
# -----------------------------
def sarima_forecast(train_series: pd.Series, test_len: int, order=(1,1,1), seasonal_order=(1,1,1,24)):
    """
    Fit a SARIMAX model to ``train_series`` and forecast ``test_len`` steps
    past its end.

    Stationarity and invertibility constraints are not enforced during fitting.

    Returns a tuple ``(mean, res)``:
    - mean: array of the ``test_len`` predicted mean values
    - res: the fitted statsmodels results object
    """
    sarimax = sm.tsa.statespace.SARIMAX(
        train_series,
        order=order,
        seasonal_order=seasonal_order,
        enforce_stationarity=False,
        enforce_invertibility=False,
    )
    res = sarimax.fit(disp=False)  # disp=False silences optimizer output
    forecast = res.get_forecast(steps=test_len)
    return forecast.predicted_mean.values, res

# -----------------------------
# 7) MAIN PIPELINE
# -----------------------------
def main_pipeline():
    """
    Run the full demo: generate data, train an LSTM forecaster, compare it with
    a SARIMA baseline, compute SHAP and Integrated Gradients attributions, and
    write every artifact into ``./output/``.

    Files written: full_series.csv, train.csv, val.csv, test.csv, lstm_best.pth,
    lstm_test_preds.csv, sarima_preds.csv, shap_lag_feature_importance.csv,
    ig_attributions_example5.csv, recent_feat0.png, summary_metrics.csv.

    Takes no arguments and returns nothing; progress is printed to stdout.
    """
    seed = 123
    np.random.seed(seed)
    torch.manual_seed(seed)
    random.seed(seed)
    feature_cols = ['feat_0', 'feat_1', 'feat_2']

    # 1) generate
    n = 3000  # >= 1000 as required
    df = generate_multiseasonal_multivariate(n_samples=n, random_seed=seed)
    print("Generated data sample:\n", df.head())

    # Split: train/val/test (60/20/20), chronological
    train_end = int(0.6 * n)
    val_end = int(0.8 * n)
    train_df = df.iloc[:train_end].reset_index(drop=True)
    val_df = df.iloc[train_end:val_end].reset_index(drop=True)
    test_df = df.iloc[val_end:].reset_index(drop=True)

    # Save CSVs
    os.makedirs('output', exist_ok=True)
    df.to_csv('output/full_series.csv', index=False)
    train_df.to_csv('output/train.csv', index=False)
    val_df.to_csv('output/val.csv', index=False)
    test_df.to_csv('output/test.csv', index=False)

    # 2) scale: statistics come from the training split only. Plain arrays are
    # used for both fit and transform so the scaler sees a consistent input type.
    scaler = StandardScaler()
    scaler.fit(train_df[feature_cols].values)
    train_scaled = scaler.transform(train_df[feature_cols].values)
    val_scaled = scaler.transform(val_df[feature_cols].values)
    test_scaled = scaler.transform(test_df[feature_cols].values)

    # 3) create window datasets
    input_width = 168  # use one weekly range as input
    output_width = 24  # predict next day (multi-step)
    batch_size = 64

    train_ds = TimeSeriesWindowDataset(train_scaled, input_width, output_width, stride=3)
    # val/test windows are prefixed with the tail of the preceding split so the
    # first targets of each split already have a full input window
    val_ds = TimeSeriesWindowDataset(np.vstack([train_scaled[-input_width:], val_scaled]),
                                     input_width, output_width, stride=1)
    test_ds = TimeSeriesWindowDataset(np.vstack([val_scaled[-input_width:], test_scaled]),
                                      input_width, output_width, stride=1)

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)

    # 4) hyperparameter search (simple grid)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    hyperparams = [
        {'hidden_dim': 64, 'num_layers': 1, 'lr': 1e-3, 'dropout': 0.0, 'n_epochs': 20},
        {'hidden_dim': 128, 'num_layers': 2, 'lr': 5e-4, 'dropout': 0.1, 'n_epochs': 30},
    ]
    best = None
    best_val = float('inf')
    for hp in hyperparams:
        print("Training with hp:", hp)
        model = LSTMForecast(input_dim=len(feature_cols), hidden_dim=hp['hidden_dim'],
                             num_layers=hp['num_layers'],
                             output_width=output_width, dropout=hp['dropout'])
        result = train_model(model, train_loader, val_loader, n_epochs=hp['n_epochs'], lr=hp['lr'], device=device)
        val_mse = result['best_val_mse']
        if val_mse < best_val:
            best_val = val_mse
            best = {'hp': hp, 'model': result['model'], 'history': result['history']}

    print("Best hyperparams:", best['hp'], "val_mse:", best_val)
    # save model
    torch.save(best['model'].state_dict(), 'output/lstm_best.pth')

    # 5) Evaluate on test set: one forecast block per sliding window, compared
    # against the true block that follows that window.
    X_test = []
    y_test_true = []
    for x, y in test_ds:
        X_test.append(x)
        y_test_true.append(y)
    X_test = np.stack(X_test)  # (num_windows, input_width, n_features)
    y_test_true = np.stack(y_test_true)  # (num_windows, output_width, 1)

    y_test_pred_scaled = predict_numpy(best['model'], X_test, device=device)  # (num_windows, output_width)
    # Undo the scaling for the target only: it is column 0 of the scaler.
    target_std = scaler.scale_[0]
    target_mean = scaler.mean_[0]
    y_test_pred = y_test_pred_scaled * target_std + target_mean
    y_test_true_unscaled = y_test_true.squeeze(-1) * target_std + target_mean

    test_rmse = rmse(y_test_true_unscaled, y_test_pred)
    test_mape = mape(y_test_true_unscaled, y_test_pred)
    print(f"LSTM Test RMSE: {test_rmse:.4f}, MAPE: {test_mape:.2f}%")

    # Save sample predictions to CSV
    pd.DataFrame({
        'y_true_flat': y_test_true_unscaled.flatten(),
        'y_pred_flat': y_test_pred.flatten()
    }).to_csv('output/lstm_test_preds.csv', index=False)

    # 6) SARIMA baseline on feat_0.
    # The model is fit on everything before the test split (train + val) so that
    # its forecast horizon starts exactly at the first test timestamp. Fitting on
    # the training split alone would make the forecast cover the validation
    # period while being scored against test values.
    # Note this is a single long-horizon forecast, whereas the LSTM is scored on
    # rolling 24-step windows, so the two metrics are only roughly comparable.
    history_series = df['feat_0'].iloc[:val_end]
    test_len = len(test_df)  # number of time steps in test set
    # Seasonal period 24 (the shorter of the two cycles in the data).
    sarima_pred_mean, _ = sarima_forecast(history_series, test_len=test_len,
                                          order=(1,1,1), seasonal_order=(1,1,1,24))
    y_true_test_series = test_df['feat_0'].values
    sarima_rmse = rmse(y_true_test_series, sarima_pred_mean)
    sarima_mape = mape(y_true_test_series, sarima_pred_mean)
    print(f"SARIMA Test RMSE: {sarima_rmse:.4f}, MAPE: {sarima_mape:.2f}%")
    # Save baseline
    pd.DataFrame({'sarima_pred': sarima_pred_mean, 'y_true': y_true_test_series}).to_csv('output/sarima_preds.csv', index=False)

    # 7) Explainability
    # 7a) SHAP KernelExplainer (model-agnostic). KernelExplainer is slow, so the
    # background set is a random sample of test windows.
    background_idx = np.random.choice(len(X_test), size=min(50, len(X_test)), replace=False)
    background = X_test[background_idx]  # shape (B, input_width, n_features)

    # KernelExplainer works on 2-D inputs, so windows are flattened and restored here.
    def model_flattened(x_flat_np):
        # x_flat_np: (m, input_width * n_features)
        x = x_flat_np.reshape((-1, X_test.shape[1], X_test.shape[2])).astype(np.float32)
        with torch.no_grad():
            t = torch.from_numpy(x).to(device)
            pred = best['model'](t).cpu().numpy()
        # explain only the first forecast step so the output is one scalar per row
        return pred[:, 0, 0]

    print("Preparing SHAP KernelExplainer (this can be slow)...")
    shap_background_flat = background.reshape(background.shape[0], -1)
    explainer = shap.KernelExplainer(model_flattened, shap_background_flat)
    # explain the first windows; bounded so a short test set cannot break the reshape
    n_explain = min(20, len(X_test))
    sample_to_explain = X_test[:n_explain].reshape(n_explain, -1)
    shap_values = explainer.shap_values(sample_to_explain, nsamples=100)  # reduces runtime by using nsamples
    # shap_values shape: (n_explain, input_width*n_features)
    shap_vals_arr = np.array(shap_values)
    shap_vals_arr = shap_vals_arr.reshape(n_explain, X_test.shape[1], X_test.shape[2])  # (n_explain, input_width, n_features)
    # average absolute importance across windows
    mean_abs_shap = np.mean(np.abs(shap_vals_arr), axis=0)  # (input_width, n_features)

    # Save shap importance aggregated by lag and feature
    lag = np.arange(X_test.shape[1])
    shap_df = pd.DataFrame(mean_abs_shap, columns=feature_cols)
    shap_df['lag'] = lag
    shap_df.to_csv('output/shap_lag_feature_importance.csv', index=False)
    print("Saved SHAP aggregated importances to output/shap_lag_feature_importance.csv")

    # 7b) Integrated Gradients via Captum — attribute one example's first
    # forecast step to every (time step, feature) cell of its input window.
    ig = IntegratedGradients(best['model'])
    # pick an example
    example_idx = 5
    x_example = torch.from_numpy(X_test[example_idx:example_idx+1].astype(np.float32)).to(device)
    x_example.requires_grad = True
    # baseline: zeros
    baseline = torch.zeros_like(x_example).to(device)
    # The model output is (batch, output_width, 1), so Captum needs a target that
    # selects one scalar per example; (0, 0) is the first forecast step, the same
    # quantity the SHAP analysis explains.
    # cuDNN is disabled for this call because its RNN backward pass only runs in
    # training mode, and the model is in eval mode here.
    with torch.backends.cudnn.flags(enabled=False):
        attr, delta = ig.attribute(x_example, baselines=baseline, target=(0, 0),
                                   return_convergence_delta=True, n_steps=50)
    # attr shape (1, input_width, n_features)
    attr_np = attr.detach().cpu().numpy().squeeze(0)
    pd.DataFrame(attr_np, columns=feature_cols).to_csv('output/ig_attributions_example5.csv', index=False)
    print("Saved Integrated Gradients attributions for example 5 to output/ig_attributions_example5.csv")

    # 8) Simple plotting for convenience (saved to files)
    plt.figure(figsize=(10,4))
    plt.plot(df['t'][-500:], df['feat_0'][-500:], label='feat_0 (true recent)')
    plt.title('Recent feat_0')
    plt.legend()  # the line carries a label, so show it
    plt.savefig('output/recent_feat0.png')
    plt.close()

    # 9) Summary metrics CSV
    summary = {
        'model': 'LSTM',
        'lstm_rmse': test_rmse,
        'lstm_mape': test_mape,
        'sarima_rmse': sarima_rmse,
        'sarima_mape': sarima_mape,
        'n_samples': n,
        'input_width': input_width,
        'output_width': output_width
    }
    pd.DataFrame([summary]).to_csv('output/summary_metrics.csv', index=False)
    print("Pipeline completed. Outputs in ./output/")

# Run the whole pipeline only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main_pipeline()