In [None]:
# --------------------- Imports ---------------------
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import DataLoader, TensorDataset
import torch.nn as nn
import torch.optim as optim
from google.colab import drive

# --------------------- Matplotlib Setup ---------------------
# Global plotting defaults for every figure in this notebook: larger fonts,
# 300-dpi rendering and saving, and automatic layout so labels are not clipped.
mpl.rcParams.update([
    ('font.size', 14),
    ('axes.titlesize', 15),
    ('axes.labelsize', 12),
    ('xtick.labelsize', 11),
    ('ytick.labelsize', 11),
    ('legend.fontsize', 11),
    ('figure.dpi', 300),
    ('savefig.dpi', 300),
    ('figure.autolayout', True),
])

# --------------------- Load Data ---------------------
# Colab-only: mount Google Drive, then read the wide CSV through the mount.
# NOTE(review): the CSV path is relative, so it presumably resolves against the
# Colab working directory (/content) where the Drive is mounted — confirm if
# this is ever run outside Colab.
DATA_CSV_PATH = 'drive/My Drive/C02 project/correlation_wide.csv'

print("Mounting Google Drive and loading dataset...")
drive.mount('/content/drive')
total_capture_7k = pd.read_csv(DATA_CSV_PATH)
print(f"Loaded dataset with shape: {total_capture_7k.shape}")

# --------------------- Identify Unique Static Parameter Sets ---------------------
# Columns treated as per-file (per-simulation) inputs; each file's first row
# supplies their values.
static_cols = [
    'MikeSorghum', 'Quartz', 'Plagioclase', 'Apatite', 'Ilmenite',
    'Diopside_Mn', 'Diopside', 'Olivine', 'Alkali-feldspar',
    'Montmorillonite', 'Glass', 'temp', 'shift', 'year',
]

# Row count per file_id, joined onto that file's first static values.
file_lengths = (
    total_capture_7k.groupby('file_id')
    .size()
    .rename("num_timesteps")
    .reset_index()
)
static_rows = (
    total_capture_7k.groupby('file_id')[static_cols]
    .first()
    .reset_index()
    .merge(file_lengths, on='file_id')
)

# Keep a single file per distinct static-parameter combination
# (drop_duplicates keeps the first occurrence).
unique_static_rows = static_rows.drop_duplicates(subset=static_cols)
unique_file_ids = unique_static_rows['file_id'].tolist()

# --------------------- Extract Time Series Data ---------------------
# Rows of the kept files only, each file cut to its first 101 rows (timesteps).
keep_rows = total_capture_7k['file_id'].isin(unique_file_ids)
filtered_df = (
    total_capture_7k.loc[keep_rows]
    .groupby('file_id')
    .head(101)
    .reset_index(drop=True)
)

# --------------------- Static Feature Table ---------------------
# One row per kept file_id (sorted by file_id) holding its first static values.
Input_Link_Table = filtered_df.groupby('file_id')[static_cols].first().reset_index()
print(f"Static feature table created: Input_Link_Table.shape = {Input_Link_Table.shape}")

# --------------------- Time Series Structuring ---------------------
# Dense matrix: relevant_data[i, t] is Total_CO2_capture of the i-th file_id
# (order of first appearance in filtered_df) at its t-th row. Series shorter
# than max_timesteps are right-padded with zeros.
result = filtered_df[['Total_CO2_capture', 'year', 'file_id']]
file_ids = result['file_id'].unique()
num_file_ids = len(file_ids)
max_timesteps = 101
relevant_data = np.zeros((num_file_ids, max_timesteps))
file_id_order = np.zeros(num_file_ids)

# One grouped pass (row order inside each group is preserved). This replaces a
# full boolean scan of `result` for every file_id, which was O(num_files * num_rows).
co2_by_file = {
    fid: grp.to_numpy()
    for fid, grp in result.groupby('file_id', sort=False)['Total_CO2_capture']
}

num_padded = 0
for i, file_id in enumerate(file_ids):
    file_data = co2_by_file.get(file_id, np.empty(0))
    relevant_data[i, :len(file_data)] = file_data
    file_id_order[i] = file_id
    num_padded += int(len(file_data) < max_timesteps)
if num_padded:
    # Padding zeros look like genuine zero capture to everything downstream
    # (clustering, df_output), so report them instead of hiding them.
    print(f"Warning: {num_padded} series had fewer than {max_timesteps} timesteps and were zero-padded")
print(f"Time series matrix constructed: relevant_data.shape = {relevant_data.shape}")

# --------------------- Clustering ---------------------
# Single source of truth for the cluster count. It was previously hard-coded
# separately in KMeans, the stats loop and the log message.
N_CLUSTERS = 8

scaler = StandardScaler()
normalized_data = scaler.fit_transform(relevant_data)  # standardized per column (timestep)
kmeans = KMeans(n_clusters=N_CLUSTERS, random_state=42)
clusters = kmeans.fit_predict(normalized_data)
print(f"Performed KMeans clustering into {N_CLUSTERS} clusters")

# Compute boundary stats: for each cluster, the per-timestep (min, median, mean, max)
# mapped back to the original scale -> shape [N_CLUSTERS, 4, relevant_data.shape[1]].
cluster_boundaries = []
for cluster_id in range(N_CLUSTERS):
    cluster_data = normalized_data[clusters == cluster_id]
    stats_z = np.stack([
        np.min(cluster_data, axis=0),
        np.median(cluster_data, axis=0),
        np.mean(cluster_data, axis=0),
        np.max(cluster_data, axis=0),
    ])  # [4, T] in scaled units
    # inverse_transform works row-wise, so one call equals four single-row calls.
    cluster_boundaries.append(scaler.inverse_transform(stats_z))
cluster_boundaries = np.array(cluster_boundaries)
print(f"Cluster boundary stats calculated: cluster_boundaries.shape = {cluster_boundaries.shape}")

# --------------------- Merge Static Features with Clusters ---------------------
Clustering_link_table = pd.DataFrame({'file_id': file_id_order.astype(int), 'cluster': clusters})
Clustering_link_table = Clustering_link_table.sort_values(by='file_id').reset_index(drop=True)
# Inner join on file_id: only ids present in both tables are kept.
merged_df = pd.merge(Input_Link_Table, Clustering_link_table, on='file_id')
print(f"Final input features (static + cluster): merged_df.shape = {merged_df.shape}")

# --------------------- Create Output Time Series DataFrame ---------------------
# Long format, one row per (file_id, timestep). This is built with vectorized
# repeat/tile/ravel instead of a Python list of per-cell lists. C-order ravel
# matches the original "for each series, for each timestep" ordering.
num_series, num_steps = relevant_data.shape
df_output = pd.DataFrame({
    'file_id': np.repeat(file_id_order.astype(int), num_steps),
    'timestep': np.tile(np.arange(num_steps), num_series),
    'CO2': relevant_data.ravel(),
}).sort_values(by=['file_id', 'timestep'])
print(f"Final output time series: df_output.shape = {df_output.shape}")

# --------------------- Summary ---------------------
print("Data Preparation Summary:")
print(f"Static Input Table: merged_df [{merged_df.shape[0]} rows × {merged_df.shape[1]} columns]")
print(f"Time Series Output: df_output [{df_output.shape[0]} rows × {df_output.shape[1]} columns]")
print(f"Cluster Boundaries: cluster_boundaries [{cluster_boundaries.shape}]")

Mounting Google Drive and loading dataset...
Mounted at /content/drive
Loaded dataset with shape: (1192157, 17)
Static feature table created: Input_Link_Table.shape = (2703, 15)
Time series matrix constructed: relevant_data.shape = (2703, 101)
Performed KMeans clustering into 8 clusters
Cluster boundary stats calculated: cluster_boundaries.shape = (8, 4, 101)
Final input features (static + cluster): merged_df.shape = (2703, 16)
Final output time series: df_output.shape = (273003, 3)
Data Preparation Summary:
Static Input Table: merged_df [2703 rows × 16 columns]
Time Series Output: df_output [273003 rows × 3 columns]
Cluster Boundaries: cluster_boundaries [(8, 4, 101)]


In [None]:
# model_tvlssm.py
# Time-Varying Linear SSM (A_t, B_t, C_t, D_t) predicted by a Conv1D+LSTM backbone.
# For a horizon H, the model predicts H distinct (A_t, B_t, C_t, D_t) and rolls the linear SSM forward.

import math
import torch
import torch.nn as nn


class TimeVaryingLSSMHead(nn.Module):
    """
    Predicts time-varying linear state-space parameters for each forecast step t = 1..H
    and rolls the resulting linear system forward:
        z_t = A_t z_{t-1} + B_t u
        y_t = C_t z_t + D_t u
    where u is the static embedding (same for all steps) and z_0 = init_net(context_vec).

    Parameterization:
      - A_t = diag(tanh(a_t)) + uv_scale * (U_t @ V_t^T),
        with U_t, V_t in R^{Z x r}, a_t in R^Z.
        The diagonal part is bounded in (-1, 1). The low-rank term is NOT
        constrained, so A_t is not guaranteed to be contractive. A small
        uv_scale only keeps it close to the bounded diagonal.
      - B_t in R^{Z x u_dim}, C_t in R^{1 x Z}, D_t in R^{1 x u_dim}

    The per-step parameters depend only on context_vec and a learnable per-step
    embedding self.step_embed[t] (shape [H, E]). They do not depend on z.

    forward(context_vec, u):
      - context_vec: [B, C]   (e.g., concat of backbone summary h_T and static embedding s)
      - u:           [B, u_dim] static control vector (same at every t)
      returns y_hat: [B, H]
    """
    def __init__(self, latent_dim, u_dim, horizon, context_dim, step_emb_dim=64, low_rank=4, uv_scale=0.05):
        super().__init__()
        self.Z = latent_dim
        self.u_dim = u_dim
        self.H = horizon
        self.Cdim = context_dim
        self.E = step_emb_dim
        self.r = low_rank
        self.uv_scale = uv_scale

        # Per-step learnable embeddings (one per horizon step)
        self.step_embed = nn.Parameter(torch.randn(self.H, self.E) * 0.02)

        # Small MLP producing a compact "param context" per step
        in_dim = self.Cdim + self.E
        hidden = max(128, self.Z)
        self.param_ctx = nn.Sequential(
            nn.Linear(in_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
        )
        self._init_mlp(self.param_ctx)

        # Heads mapping the param context to raw parameter tensors.
        # A_t parts
        self.head_a = nn.Linear(hidden, self.Z)                     # diag part
        self.head_U = nn.Linear(hidden, self.Z * self.r)            # U_t flattened
        self.head_V = nn.Linear(hidden, self.Z * self.r)            # V_t flattened

        # B_t, C_t, D_t
        self.head_B = nn.Linear(hidden, self.Z * self.u_dim)        # B_t flattened
        self.head_C = nn.Linear(hidden, self.Z)                     # C_t row (1 x Z)
        self.head_D = nn.Linear(hidden, self.u_dim)                 # D_t row (1 x u_dim)

        # z0 from the same context (no step embedding)
        self.init_net = nn.Sequential(
            nn.Linear(self.Cdim, 2 * self.Z), nn.ReLU(),
            nn.Linear(2 * self.Z, self.Z)
        )
        self._init_mlp(self.init_net)

    @staticmethod
    def _init_mlp(m):
        """Kaiming-uniform weights and zero biases for every nn.Linear in m."""
        for mod in m:
            if isinstance(mod, nn.Linear):
                nn.init.kaiming_uniform_(mod.weight, a=math.sqrt(5))
                if mod.bias is not None:
                    nn.init.zeros_(mod.bias)

    def _build_A(self, a_raw, U_raw, V_raw):
        """
        Dense A = diag(tanh(a_raw)) + uv_scale * U_raw @ V_raw^T.

        a_raw: [..., Z]
        U_raw: [..., Z, r]
        V_raw: [..., Z, r]
        Returns A: [..., Z, Z]   (any leading batch dims, e.g. [B] or [B, H])
        """
        diag = torch.diag_embed(torch.tanh(a_raw))
        low_rank = U_raw @ V_raw.transpose(-1, -2)
        return diag + self.uv_scale * low_rank

    def _step_params(self, context_vec):
        """
        Raw parameters for all H steps in one batched pass.

        context_vec: [B, C]
        Returns (a, U, V, Bm, C, D) shaped
            [B, H, Z], [B, H, Z, r], [B, H, Z, r], [B, H, Z, u_dim], [B, H, Z], [B, H, u_dim]
        """
        B = context_vec.size(0)
        Z, H, r, u_dim = self.Z, self.H, self.r, self.u_dim
        ctx = context_vec.unsqueeze(1).expand(-1, H, -1)            # [B, H, C]
        emb = self.step_embed.unsqueeze(0).expand(B, -1, -1)        # [B, H, E]
        pc = self.param_ctx(torch.cat([ctx, emb], dim=-1))          # [B, H, hidden]
        return (
            self.head_a(pc),
            self.head_U(pc).view(B, H, Z, r),
            self.head_V(pc).view(B, H, Z, r),
            self.head_B(pc).view(B, H, Z, u_dim),
            self.head_C(pc),
            self.head_D(pc),
        )

    def forward(self, context_vec, u):
        """
        context_vec: [B, C]
        u:           [B, u_dim]
        Returns:
          y_hat: [B, H]   (an empty [B, 0] tensor when H == 0)
        """
        B = context_vec.size(0)
        if self.H == 0:
            return context_vec.new_zeros(B, 0)

        # Parameters don't depend on the state, so compute every step at once;
        # only the state recurrence itself has to be sequential.
        a, U, V, Bm, C, D = self._step_params(context_vec)
        diag = torch.tanh(a)                                        # [B, H, Z], |diag| < 1
        Bu = torch.einsum('bhzk,bk->bhz', Bm, u)                    # B_t u: [B, H, Z]
        Du = torch.einsum('bhk,bk->bh', D, u)                       # D_t u: [B, H]

        z_t = self.init_net(context_vec)                            # z_0: [B, Z]
        y_steps = []
        for t in range(self.H):
            # A_t z applied in factored form, diag*z + uv_scale * U (V^T z),
            # avoiding a dense [B, Z, Z] matrix per step.
            Vz = torch.einsum('bzr,bz->br', V[:, t], z_t)           # [B, r]
            UVz = torch.einsum('bzr,br->bz', U[:, t], Vz)           # [B, Z]
            z_t = diag[:, t] * z_t + self.uv_scale * UVz + Bu[:, t]
            y_steps.append((C[:, t] * z_t).sum(dim=-1))             # C_t z_t: [B]

        return torch.stack(y_steps, dim=1) + Du                     # [B, H]


class TVLSSMForecastNet(nn.Module):
    """
    Backbone: Static MLP + Conv1D + LSTM encoder over the observed window.
    Head:     TimeVaryingLSSMHead that predicts (A_t, B_t, C_t, D_t) per step and rolls the linear SSM.

    time_series_input: [B, L] or [B, 1, L] observed window
    static_input:      [B, S] static features
    returns:           [B, H] forecast

    Note: input_len is stored for reference only. forward() never reads it,
    because the Conv1D (padding=1) + LSTM encoder handles any window length.
    """
    def __init__(
        self,
        input_len,                 # L (observed window length)
        static_dim,                # number of static features
        hidden_dim,                # conv channels and LSTM hidden size (also latent_dim default)
        horizon,                   # H (forecast length)
        latent_dim=None,           # latent state z size (default: hidden_dim)
        step_emb_dim=64,
        low_rank=4,
        uv_scale=0.05,
        dropout=0.0,
    ):
        super().__init__()
        if latent_dim is None:
            latent_dim = hidden_dim
        self.input_len = input_len
        self.static_dim = static_dim
        self.hidden_dim = hidden_dim
        self.horizon = horizon
        self.latent_dim = latent_dim
        self.u_dim = 64  # static embedding size
        self.dropout_p = dropout

        # ----- Static encoder -----
        self.fc_s1 = nn.Linear(static_dim, 512)
        self.fc_s2 = nn.Linear(512, 256)
        self.fc_s3 = nn.Linear(256, 128)
        self.fc_s4 = nn.Linear(128, self.u_dim)
        self.relu = nn.ReLU()
        self._init_linear(self.fc_s1)
        self._init_linear(self.fc_s2)
        self._init_linear(self.fc_s3)
        self._init_linear(self.fc_s4)
        self.drop = nn.Dropout(self.dropout_p) if self.dropout_p > 0 else nn.Identity()

        # ----- Temporal encoder: Conv1D → LSTM -----
        self.conv1 = nn.Conv1d(in_channels=1, out_channels=hidden_dim, kernel_size=3, padding=1)
        nn.init.kaiming_uniform_(self.conv1.weight, a=math.sqrt(5))
        if self.conv1.bias is not None:
            nn.init.zeros_(self.conv1.bias)

        self.lstm = nn.LSTM(hidden_dim + self.u_dim, hidden_dim, batch_first=True)

        # ----- Time-Varying LSSM Head -----
        context_dim = hidden_dim + self.u_dim  # using [h_T ⊕ s] as context
        self.tvlssm = TimeVaryingLSSMHead(
            latent_dim=self.latent_dim,
            u_dim=self.u_dim,
            horizon=horizon,
            context_dim=context_dim,
            step_emb_dim=step_emb_dim,
            low_rank=low_rank,
            uv_scale=uv_scale,
        )

    @staticmethod
    def _init_linear(layer):
        """Kaiming-uniform weights and zero bias for a single nn.Linear."""
        nn.init.kaiming_uniform_(layer.weight, a=math.sqrt(5))
        if layer.bias is not None:
            nn.init.zeros_(layer.bias)

    def _encode(self, time_series_input, static_input):
        """
        Shared backbone for forward() and the debug inspector.

        Returns (context_vec [B, Hc + u_dim], s [B, u_dim]), where s is the static
        embedding and context_vec = [h_T ⊕ s]. Dropout on s is a no-op in eval
        mode or when dropout == 0.
        """
        # Static embedding u
        s = self.relu(self.fc_s1(static_input))
        s = self.relu(self.fc_s2(s))
        s = self.relu(self.fc_s3(s))
        s = self.drop(self.fc_s4(s))                             # [B, u_dim]

        # Conv over observed window
        x = time_series_input
        if x.dim() == 2:                                         # [B, L] -> [B, 1, L]
            x = x.unsqueeze(1)
        conv_out = self.relu(self.conv1(x)).transpose(1, 2)      # [B, L, Hc]

        # LSTM with static conditioning at each step of the observed window
        s_exp = s.unsqueeze(1).expand(-1, conv_out.size(1), -1)  # [B, L, u_dim]
        lstm_out, _ = self.lstm(torch.cat([conv_out, s_exp], dim=2))
        h_T = lstm_out[:, -1, :]                                 # summary [B, Hc]
        return torch.cat([h_T, s], dim=1), s

    def forward(self, time_series_input, static_input):
        context_vec, s = self._encode(time_series_input, static_input)
        # Forecast by rolling the time-varying linear SSM
        return self.tvlssm(context_vec, s)                       # [B, horizon]

    @torch.no_grad()
    def debug_inspect_single_step_params(self, time_series_input, static_input, step_idx=0):
        """
        Return (A_t, B_t, C_t, D_t) for step `step_idx`, for the first batch element, on CPU.

        Runs in eval mode and restores the caller's train/eval mode afterwards.
        Previously the model was left in eval mode, which silently disabled
        dropout if this was called mid-training.
        """
        was_training = self.training
        self.eval()
        try:
            context_vec, _ = self._encode(time_series_input, static_input)
            head = self.tvlssm

            # Build per-step params (single step)
            B = context_vec.size(0)
            e_t = head.step_embed[step_idx].unsqueeze(0).expand(B, -1)
            pc = head.param_ctx(torch.cat([context_vec, e_t], dim=1))

            Z, r, u_dim = self.latent_dim, head.r, self.u_dim
            a_t = head.head_a(pc)
            U_t = head.head_U(pc).view(B, Z, r)
            V_t = head.head_V(pc).view(B, Z, r)
            B_t = head.head_B(pc).view(B, Z, u_dim)
            C_t = head.head_C(pc).view(B, 1, Z)
            D_t = head.head_D(pc).view(B, 1, u_dim)
            A_t = head._build_A(a_t, U_t, V_t)
        finally:
            self.train(was_training)
        return A_t[0].cpu(), B_t[0].cpu(), C_t[0].cpu(), D_t[0].cpu()

In [None]:
# experiment_tvlssm.py
# Train/evaluate the TV-LSSM model over multiple X–Y splits using your existing df_output & merged_df.

import os, math, numpy as np, pandas as pd, torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler


# --------------------- Config ---------------------
# Hyper-parameters and run settings shared by train_all_splits().
cfg = dict(
    hidden_dim=101,
    epochs=500,
    batch_size=124,
    learning_rate=1e-3,
    grad_clip=1.0,
    device="cuda" if torch.cuda.is_available() else "cpu",
    low_rank=4,
    uv_scale=0.05,
    dropout=0.0,
    seed=42,
    checkpoint_dir="/content/drive/MyDrive/DSSM-Figures",
    log_every=10,
)

# Make sure the checkpoint folder exists, then seed torch and NumPy RNGs.
os.makedirs(cfg["checkpoint_dir"], exist_ok=True)
torch.manual_seed(cfg["seed"])
np.random.seed(cfg["seed"])

# --------------------- Data helpers (unchanged logic) ---------------------
def pivot_matrix(df_output, ids):
    """Reshape long (file_id, timestep, CO2) rows into a dense matrix.

    Only rows whose file_id is in `ids` are used. Returns (matrix, file_ids):
    matrix rows are sorted by file_id and columns are one per timestep. A
    (file_id, timestep) cell absent from the data becomes NaN.
    """
    selected = df_output.loc[df_output["file_id"].isin(ids)]
    wide = selected.pivot(index="file_id", columns="timestep", values="CO2")
    wide.sort_index(inplace=True)
    return wide.to_numpy(), wide.index.to_numpy()

def extract_X_Y(df_output, ids, pct):
    """Split each complete series into an observed prefix X and a forecast target Y.

    Args:
        df_output: long-format frame with file_id, timestep and CO2 columns.
        ids: file_ids to include.
        pct: percentage (0-100) of each series' timesteps placed in X.

    Returns:
        X: [N, split_idx], Y: [N, T - split_idx], kept_ids: [N].
        T is the number of timestep columns actually present in the pivoted
        matrix. Rows containing any NaN/Inf (e.g. a missing timestep) are dropped.
        Row order follows pivot_matrix (sorted by file_id).
    """
    mat, ordered_ids = pivot_matrix(df_output, ids)
    finite_mask = np.all(np.isfinite(mat), axis=1)
    mat = mat[finite_mask]
    kept_ids = ordered_ids[finite_mask]
    # Derive the series length from the data instead of assuming 101 steps.
    # The arithmetic order (pct / 100 * T) is unchanged, so results are
    # identical whenever T == 101.
    total_steps = mat.shape[1]
    split_idx = int(pct / 100 * total_steps)
    X = mat[:, :split_idx]
    Y = mat[:, split_idx:]
    return X, Y, kept_ids

def align_static(merged_df, kept_ids):
    """Return the static-feature matrix for `kept_ids`, in that row order.

    The "cluster" column is excluded when present. NaN, +Inf and -Inf are all
    replaced with 0.0. Raises KeyError if an id is missing from merged_df.
    """
    table = merged_df.set_index("file_id").loc[kept_ids]
    if "cluster" in table.columns:
        table = table.drop(columns="cluster")
    return np.nan_to_num(table.values, nan=0.0, posinf=0.0, neginf=0.0)

def mse_np(a, b):
    """Mean squared error over all elements of `a - b`, as a Python float."""
    squared_error = (a - b) ** 2
    return float(np.mean(squared_error))

# --------------------- Training pipeline ---------------------
def _train_one_epoch(model, loader, optimizer, criterion, epoch):
    """Run one optimisation pass over `loader` and return the mean per-batch train loss.

    Batches with a NaN/Inf loss are skipped (no backward/step) and excluded from
    the running sum. The mean still divides by the full batch count.
    """
    model.train()
    running = 0.0
    for Xb, Sb, Yb in loader:
        optimizer.zero_grad()
        loss = criterion(model(Xb, Sb), Yb)                  # preds: [B, H]
        if torch.isnan(loss) or torch.isinf(loss):
            print(f"[Epoch {epoch+1}] NaN/Inf loss. Skipping step.")
            continue
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), cfg["grad_clip"])
        optimizer.step()
        running += loss.item()
    return running / max(1, len(loader))


def _eval_loss(model, loader, criterion):
    """Return the sample-weighted mean of `criterion` over `loader` (no grad).

    Each batch's mean loss is weighted by its batch size, so a smaller final
    batch no longer counts as much as a full one.
    """
    model.eval()
    total, count = 0.0, 0
    with torch.no_grad():
        for Xb, Sb, Yb in loader:
            n = Yb.size(0)
            total += criterion(model(Xb, Sb), Yb).item() * n
            count += n
    return total / max(1, count)


def train_all_splits(df_output, merged_df, splits=None):
    """
    Train and evaluate one TVLSSMForecastNet per observed/forecast (X–Y) split.

    The file_id partition into train / val / test (64% / 16% / 20%, seeded by
    cfg["seed"]) is fixed across all splits. Only the cut point inside each
    series changes.

    Args:
        df_output: long-format series (file_id, timestep, CO2).
        merged_df: static features per file_id; a "cluster" column is dropped.
        splits: iterable of (train_pct, test_pct). train_pct is the percentage
            of each series given as input X; the rest is the forecast target Y.
            test_pct is used only in the split name.

    Returns:
        DataFrame with columns ["Split", "Test_MSE"], one row per split. Test_MSE
        is measured on the original (unscaled) CO2 scale using the checkpoint
        with the lowest validation loss.
    """
    if splits is None:
        splits = [(20,80),(10,90),(5,95),(3,97),(1,99),(80,20),(60,40),(50,50),(40,60)]

    device = cfg["device"]
    rows = []

    def to_tensor(arr):
        return torch.tensor(arr, dtype=torch.float32, device=device)

    # Fixed series ID split across all X–Y splits
    file_ids = df_output["file_id"].unique()
    trainval_ids, test_ids = train_test_split(file_ids, test_size=0.2, random_state=cfg["seed"])
    train_ids, val_ids = train_test_split(trainval_ids, test_size=0.2, random_state=cfg["seed"])

    for train_pct, test_pct in splits:
        split_name = f"{train_pct}_{test_pct}"
        print(f"\n==== Running Split: {split_name} ====")

        # Build matrices (rows with any NaN/Inf in the series are dropped)
        X_train, Y_train, kept_train = extract_X_Y(df_output, train_ids, train_pct)
        X_val,   Y_val,   kept_val   = extract_X_Y(df_output, val_ids,   train_pct)
        X_test,  Y_test,  kept_test  = extract_X_Y(df_output, test_ids,  train_pct)

        S_train = align_static(merged_df, kept_train)
        S_val   = align_static(merged_df, kept_val)
        S_test  = align_static(merged_df, kept_test)

        assert X_train.shape[0] == Y_train.shape[0] == S_train.shape[0]
        assert X_val.shape[0]   == Y_val.shape[0]   == S_val.shape[0]
        assert X_test.shape[0]  == Y_test.shape[0]  == S_test.shape[0]

        print(f"Split {split_name} — X_len={X_train.shape[1]} | Y_len={Y_train.shape[1]}")
        print(f"Train: {X_train.shape}, Val: {X_val.shape}, Test inputs: {X_test.shape}")

        # Standardize with train-only stats; Y_test stays on the original scale.
        x_scaler = StandardScaler()
        y_scaler = StandardScaler()
        s_scaler = StandardScaler()
        X_train_z = x_scaler.fit_transform(X_train)
        Y_train_z = y_scaler.fit_transform(Y_train)
        S_train_z = s_scaler.fit_transform(S_train)

        train_loader = DataLoader(
            TensorDataset(to_tensor(X_train_z), to_tensor(S_train_z), to_tensor(Y_train_z)),
            batch_size=cfg["batch_size"], shuffle=True, drop_last=False)
        val_loader = DataLoader(
            TensorDataset(to_tensor(x_scaler.transform(X_val)),
                          to_tensor(s_scaler.transform(S_val)),
                          to_tensor(y_scaler.transform(Y_val))),
            batch_size=cfg["batch_size"], shuffle=False, drop_last=False)

        # ----- Model -----
        model = TVLSSMForecastNet(
            input_len=X_train.shape[1],
            static_dim=S_train.shape[1],
            hidden_dim=cfg["hidden_dim"],
            horizon=Y_train.shape[1],
            latent_dim=cfg["hidden_dim"],
            step_emb_dim=64,
            low_rank=cfg["low_rank"],
            uv_scale=cfg["uv_scale"],
            dropout=cfg["dropout"],
        ).to(device)

        optimizer = optim.Adam(model.parameters(), lr=cfg["learning_rate"])
        criterion = nn.MSELoss()
        best_val = math.inf
        saved_this_run = False
        ckpt_path = os.path.join(cfg["checkpoint_dir"], f"tvlssm_best_{split_name}.pt")

        print("Training started...")
        for epoch in range(cfg["epochs"]):
            trn_loss = _train_one_epoch(model, train_loader, optimizer, criterion, epoch)
            val_loss = _eval_loss(model, val_loader, criterion)

            if (epoch + 1) % cfg["log_every"] == 0:
                print(f"Epoch {epoch+1:3d} | Train {trn_loss:.6f} | Val {val_loss:.6f}")

            # NaN compares False here, so a NaN validation loss is never "best".
            if val_loss < best_val:
                best_val = val_loss
                torch.save(model.state_dict(), ckpt_path)
                saved_this_run = True

        # ----- Test on original scale -----
        # Only reload a checkpoint written by THIS run. A file left at ckpt_path
        # by an earlier run must not be evaluated in place of this model.
        if saved_this_run:
            model.load_state_dict(torch.load(ckpt_path, map_location=device))
        else:
            print(f"Warning: no finite validation loss for split {split_name}; "
                  "evaluating final weights.")
        model.eval()
        with torch.no_grad():
            test_pred_z = model(to_tensor(x_scaler.transform(X_test)),
                                to_tensor(s_scaler.transform(S_test))).cpu().numpy()
        test_pred = y_scaler.inverse_transform(test_pred_z)

        # Y_test from the extraction above is already aligned with kept_test.
        test_mse = mse_np(test_pred, Y_test)
        print(f"Split {split_name} — Final Test MSE (orig scale): {test_mse:.6e}")
        rows.append((split_name, test_mse))

    return pd.DataFrame(rows, columns=["Split", "Test_MSE"])

# --------------------- Run training ---------------------
# Requires df_output (file_id, timestep, CO2) and merged_df from the data-prep
# cell above (or e.g. `from data_prep_module import df_output, merged_df`).
DSSM_TV_LSSM_mse = train_all_splits(df_output, merged_df)
print("\nAll splits complete:", DSSM_TV_LSSM_mse, sep="\n")


==== Running Split: 20_80 ====
Split 20_80 — X_len=20 | Y_len=81
Train: (1729, 20), Val: (433, 20), Test inputs: (541, 20)
Training started...
Epoch  10 | Train 0.060541 | Val 0.076558
Epoch  20 | Train 0.036317 | Val 0.049876
Epoch  30 | Train 0.017108 | Val 0.031021
Epoch  40 | Train 0.022286 | Val 0.036981
Epoch  50 | Train 0.008802 | Val 0.025435
Epoch  60 | Train 0.006566 | Val 0.020934
Epoch  70 | Train 0.006582 | Val 0.022517
Epoch  80 | Train 0.005030 | Val 0.018471
Epoch  90 | Train 0.009743 | Val 0.026042
Epoch 100 | Train 0.004754 | Val 0.019574
Epoch 110 | Train 0.004677 | Val 0.020491
Epoch 120 | Train 0.003083 | Val 0.019450
Epoch 130 | Train 0.002830 | Val 0.015038
Epoch 140 | Train 0.003164 | Val 0.016135
Epoch 150 | Train 0.003408 | Val 0.015367
Epoch 160 | Train 0.002346 | Val 0.015412
Epoch 170 | Train 0.002015 | Val 0.017589
Epoch 180 | Train 0.001960 | Val 0.017558
Epoch 190 | Train 0.002627 | Val 0.018281
Epoch 200 | Train 0.007710 | Val 0.026046
Epoch 210 | Trai