# Data preperation

In [1]:
# --------------------- Imports ---------------------
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import DataLoader, TensorDataset
import torch.nn as nn
import torch.optim as optim
from google.colab import drive

# --------------------- Matplotlib Setup ---------------------
mpl.rcParams.update({
    'font.size': 14,
    'axes.titlesize': 15,
    'axes.labelsize': 12,
    'xtick.labelsize': 11,
    'ytick.labelsize': 11,
    'legend.fontsize': 11,
    'figure.dpi': 300,
    'savefig.dpi': 300,
    'figure.autolayout': True,
})

# --------------------- Load Data ---------------------
print("Mounting Google Drive and loading dataset...")
drive.mount('/content/drive')
total_capture_7k = pd.read_csv('drive/My Drive/correlation_wide.csv')
print(f"Loaded dataset with shape: {total_capture_7k.shape}")

# --------------------- Identify Unique Static Parameter Sets ---------------------
static_cols = [
    'MikeSorghum', 'Quartz', 'Plagioclase', 'Apatite', 'Ilmenite',
    'Diopside_Mn', 'Diopside', 'Olivine', 'Alkali-feldspar',
    'Montmorillonite', 'Glass', 'temp', 'shift', 'year'
]

# Add timestep count per file_id
file_lengths = total_capture_7k.groupby('file_id').size().rename("num_timesteps").reset_index()
static_rows = total_capture_7k.groupby('file_id')[static_cols].first().reset_index()
static_rows = static_rows.merge(file_lengths, on='file_id')

# Filter only unique static parameter sets
unique_static_rows = static_rows.drop_duplicates(subset=static_cols)
unique_file_ids = unique_static_rows['file_id'].tolist()

# --------------------- Extract Time Series Data ---------------------
filtered_df = total_capture_7k[total_capture_7k['file_id'].isin(unique_file_ids)].copy()

# Truncate each group to 101 timesteps
filtered_df = filtered_df.groupby('file_id').head(101).reset_index(drop=True)

# --------------------- Static Feature Table ---------------------
Input_Link_Table = filtered_df.groupby('file_id').agg({col: 'first' for col in static_cols}).reset_index()
print(f"Static feature table created: Input_Link_Table.shape = {Input_Link_Table.shape}")

# --------------------- Time Series Structuring ---------------------
result = filtered_df[['Total_CO2_capture', 'year', 'file_id']]
file_ids = result['file_id'].unique()
num_file_ids = len(file_ids)
max_timesteps = 101
relevant_data = np.zeros((num_file_ids, max_timesteps))
file_id_order = np.zeros(num_file_ids)

for i, file_id in enumerate(file_ids):
    file_data = result[result['file_id'] == file_id]['Total_CO2_capture'].values
    relevant_data[i, :len(file_data)] = file_data
    file_id_order[i] = file_id
print(f"Time series matrix constructed: relevant_data.shape = {relevant_data.shape}")

# --------------------- Clustering ---------------------
scaler = StandardScaler()
normalized_data = scaler.fit_transform(relevant_data)
kmeans = KMeans(n_clusters=8, random_state=42)
clusters = kmeans.fit_predict(normalized_data)
print("Performed KMeans clustering into 8 clusters")

# Compute boundary stats
cluster_boundaries = []
for cluster_id in range(8):
    cluster_data = normalized_data[clusters == cluster_id]
    min_v = scaler.inverse_transform(np.min(cluster_data, axis=0).reshape(1, -1)).flatten()
    median_v = scaler.inverse_transform(np.median(cluster_data, axis=0).reshape(1, -1)).flatten()
    mean_v = scaler.inverse_transform(np.mean(cluster_data, axis=0).reshape(1, -1)).flatten()
    max_v = scaler.inverse_transform(np.max(cluster_data, axis=0).reshape(1, -1)).flatten()
    cluster_boundaries.append((min_v, median_v, mean_v, max_v))
cluster_boundaries = np.array(cluster_boundaries)
print(f"Cluster boundary stats calculated: cluster_boundaries.shape = {cluster_boundaries.shape}")

# --------------------- Merge Static Features with Clusters ---------------------
Clustering_link_table = pd.DataFrame({'file_id': file_id_order.astype(int), 'cluster': clusters})
Clustering_link_table = Clustering_link_table.sort_values(by='file_id').reset_index(drop=True)
merged_df = pd.merge(Input_Link_Table, Clustering_link_table, on='file_id')
print(f"Final input features (static + cluster): merged_df.shape = {merged_df.shape}")

# --------------------- Create Output Time Series DataFrame ---------------------
data = [[file_id_order[i].astype(int), t, relevant_data[i, t]] for i in range(len(file_id_order)) for t in range(max_timesteps)]
df_output = pd.DataFrame(data, columns=['file_id', 'timestep', 'CO2']).sort_values(by=['file_id', 'timestep'])
print(f"Final output time series: df_output.shape = {df_output.shape}")

# --------------------- Summary ---------------------
print("Data Preparation Summary:")
print(f"Static Input Table: merged_df [{merged_df.shape[0]} rows × {merged_df.shape[1]} columns]")
print(f"Time Series Output: df_output [{df_output.shape[0]} rows × 3 columns]")
print(f"Cluster Boundaries: cluster_boundaries [{cluster_boundaries.shape}]")

Mounting Google Drive and loading dataset...
Mounted at /content/drive
Loaded dataset with shape: (1192157, 17)
Static feature table created: Input_Link_Table.shape = (2703, 15)
Time series matrix constructed: relevant_data.shape = (2703, 101)
Performed KMeans clustering into 8 clusters
Cluster boundary stats calculated: cluster_boundaries.shape = (8, 4, 101)
Final input features (static + cluster): merged_df.shape = (2703, 16)
Final output time series: df_output.shape = (273003, 3)
Data Preparation Summary:
Static Input Table: merged_df [2703 rows × 16 columns]
Time Series Output: df_output [273003 rows × 3 columns]
Cluster Boundaries: cluster_boundaries [(8, 4, 101)]


# Model definition

In [2]:
# Model Definition
class AdvancedDSSMDeepState(nn.Module):
    def __init__(self, input_dim, static_dim, hidden_dim, output_dim):
        super(AdvancedDSSMDeepState, self).__init__()

        # Static Data Path (Fully connected layers for static features)
        self.fc_static1 = nn.Linear(static_dim, 512)
        self.fc_static2 = nn.Linear(512, 256)
        self.fc_static3 = nn.Linear(256, 128)
        self.fc_static4 = nn.Linear(128, 64)

        # Time-series Path (Conv1D for feature extraction)
        self.conv1 = nn.Conv1d(in_channels=1, out_channels=hidden_dim, kernel_size=3, padding=1)
        self.relu = nn.ReLU()

        # Deep State Dynamics (LSTM for latent state transitions)
        self.lstm_state = nn.LSTM(hidden_dim + 64, hidden_dim, batch_first=True)

        # Observation Model (Mapping latent states to outputs)
        self.fc1 = nn.Linear(hidden_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, time_series_input, static_input):
        # Static Data Path
        static_out = self.relu(self.fc_static1(static_input))
        static_out = self.relu(self.fc_static2(static_out))
        static_out = self.relu(self.fc_static3(static_out))
        static_out = self.relu(self.fc_static4(static_out))  # Shape: [batch_size, 64]

        # Time-Series Data Path
        if len(time_series_input.shape) == 2:  # [batch_size, seq_len]
            time_series_input = time_series_input.unsqueeze(1)  # Add channel dimension: [batch_size, 1, seq_len]

        conv_out = self.conv1(time_series_input)  # Conv1D layer
        conv_out = self.relu(conv_out)
        conv_out = conv_out.transpose(1, 2)  # Shape: [batch_size, seq_len, hidden_dim]

        # Expand static features to match the sequence length
        static_expanded = static_out.unsqueeze(1).expand(-1, conv_out.size(1), -1)  # Shape: [batch_size, seq_len, 64]

        # Combine Conv1D features and static features
        lstm_input = torch.cat([conv_out, static_expanded], dim=2)  # Shape: [batch_size, seq_len, hidden_dim + 64]

        # Latent State Dynamics (LSTM for state transitions)
        lstm_out, _ = self.lstm_state(lstm_input)  # Shape: [batch_size, seq_len, hidden_dim]

        # Observation Model
        lstm_out_final = lstm_out[:, -1, :]  # Use the last state for prediction
        x = self.fc1(lstm_out_final)
        x = self.relu(x)
        output = self.fc2(x)  # Final prediction

        return output

# Visualization

In [3]:
def plot_boundary_cases_with_input(inputs, Boundary_case_actuals, Boundary_case_predicted, model_name, input_length):
    case_names = ["Best", "Average", "Worst"]
    x_range = input_length
    y_range = Boundary_case_actuals.shape[1]
    total_timesteps = x_range + y_range

    for i in range(3):
        plt.figure(figsize=(7.5, 3.2))

        # Plot input (X)
        plt.plot(range(x_range), inputs[i], color='black', alpha=0.5, label='Input')

        # Plot output actual vs predicted (Y)
        plt.plot(range(x_range, total_timesteps), Boundary_case_actuals[i], color='blue', alpha=0.8, label='Actual')
        plt.plot(range(x_range, total_timesteps), Boundary_case_predicted[i], color='red', alpha=0.8, label='Predicted')

        plt.xlabel("Time Steps")
        plt.ylabel("CO₂ Sequestration")
        plt.title(f"{case_names[i]} Case – {model_name}")
        plt.legend()
        plt.tight_layout(pad=2.5)

        filename = f"drive/My Drive/DSSM-Figures-final/{model_name}_{case_names[i]}.pdf"
        plt.savefig(filename, format='pdf', bbox_inches='tight')
        plt.close()


def calculate_metrics(actuals, predictions,model_name):
    # Mean Absolute Error (MAE)
    mae = np.mean(np.abs(actuals - predictions), axis=1)
    mae_mean = np.mean(mae)
    mae_std = np.std(mae)

    # Mean Squared Error (MSE)
    mse = np.mean((actuals - predictions) ** 2, axis=1)
    mse_mean = np.mean(mse)
    mse_std = np.std(mse)

    # Symmetric Mean Absolute Percentage Error (SMAPE)
    smape = np.mean(2 * np.abs(actuals - predictions) / (np.abs(actuals) + np.abs(predictions) + 1e-8), axis=1) * 100
    smape_mean = np.mean(smape)
    smape_std = np.std(smape)

    # Root Mean Squared Error (RMSE)
    rmse = np.sqrt(mse)
    rmse_mean = np.mean(rmse)
    rmse_std = np.std(rmse)

    # R-squared (R²)
    ss_res = np.sum((actuals - predictions) ** 2, axis=1)
    ss_tot = np.sum((actuals - np.mean(actuals, axis=1, keepdims=True)) ** 2, axis=1)
    r2 = 1 - (ss_res / ss_tot)
    r2_mean = np.mean(r2)
    r2_std = np.std(r2)

    # Finding indices for the lowest, average, and highest RMSE
    min_rmse_index = np.argmin(rmse)
    max_rmse_index = np.argmax(rmse)
    avg_rmse_index = np.argsort(rmse)[len(rmse) // 2]  # median RMSE as the average case

    # Boundary case actuals and predictions
    Boundary_case_actuals = np.vstack([
        actuals[min_rmse_index],
        actuals[avg_rmse_index],
        actuals[max_rmse_index]
    ])

    Boundary_case_predicted = np.vstack([
        predictions[min_rmse_index],
        predictions[avg_rmse_index],
        predictions[max_rmse_index]
    ])

    # Return metrics and their standard deviations
    return {
        f'{model_name} MAE': mae_mean,
        f'{model_name} MAE_std': mae_std,
        f'{model_name} MSE': mse_mean,
        f'{model_name} MSE_std': mse_std,
        f'{model_name} SMAPE': smape_mean,
        f'{model_name} SMAPE_std': smape_std,
        f'{model_name} RMSE': rmse_mean,
        f'{model_name} RMSE_std': rmse_std,
        f'{model_name} R2': r2_mean,
        f'{model_name} R2_std': r2_std,
        f'Boundary_case_actuals': Boundary_case_actuals,
        f'Boundary_case_predicted': Boundary_case_predicted
    }

# Plotting Function (PDF & LaTeX-Ready)
def plot_discretized_validation(inputs, actuals, predictions, clusters, cluster_boundaries, model_name):
    file_id_to_index = {fid: idx for idx, fid in enumerate(Clustering_link_table['file_id'])}
    test_file_ids = list(Clustering_link_table['file_id'][Clustering_link_table['file_id'].isin(test_ids)])
    cluster_ids = np.array(Clustering_link_table['cluster'][Clustering_link_table['file_id'].isin(test_ids)])
    unique_clusters = np.unique(cluster_ids)
    y_length = predictions.shape[1]
    x_length = inputs.shape[1]
    num_timesteps = x_length + y_length

    for cluster_id in unique_clusters:
        cluster_file_ids = [fid for i, fid in enumerate(test_file_ids) if cluster_ids[i] == cluster_id]
        test_pos_in_preds = [i for i, fid in enumerate(test_file_ids) if cluster_ids[i] == cluster_id]

        plt.figure(figsize=(7.5, 3.2))
        min_x = cluster_boundaries[cluster_id, 0, :x_length]
        max_x = cluster_boundaries[cluster_id, 3, :x_length]
        min_y = cluster_boundaries[cluster_id, 0, -y_length:]
        max_y = cluster_boundaries[cluster_id, 3, -y_length:]

        for i, pred_idx in enumerate(test_pos_in_preds):
            plt.plot(range(x_length), inputs[pred_idx], color='black', alpha=0.3, label='Input' if i == 0 else "")
            plt.plot(range(x_length, num_timesteps), actuals[pred_idx], color='blue', alpha=0.5, label='Actual' if i == 0 else "")
            plt.plot(range(x_length, num_timesteps), predictions[pred_idx], color='red', alpha=0.5, label='Predicted' if i == 0 else "")

        plt.fill_between(range(x_length), min_x, max_x, color='orange', alpha=0.2, label='Input Boundary')
        plt.fill_between(range(x_length, num_timesteps), min_y, max_y, color='grey', alpha=0.3, label='Output Boundary')

        plt.xlabel('Time Steps')
        plt.ylabel('CO₂ Sequestration')
        plt.title(f'{model_name} – Cluster {cluster_id}')
        plt.legend()
        plt.tight_layout(pad=2.5)
        filename = f"drive/My Drive/DSSM-Figures-final/{model_name}_cluster_{cluster_id}.pdf"
        plt.savefig(filename, format='pdf', bbox_inches='tight')
        plt.close()

# Experiment

In [4]:
# Training Loop
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
splits = [(20, 80)]
for train_pct, test_pct in splits:
    split_name = f"{train_pct}_{test_pct}"
    train_ids, test_ids = train_test_split(df_output['file_id'].unique(), test_size=0.2, random_state=42)
    df_train = df_output[df_output['file_id'].isin(train_ids)]
    df_test = df_output[df_output['file_id'].isin(test_ids)]
    train_timestep = int(train_pct / 100 * 101)
    X_train = df_train.pivot(index='file_id', columns='timestep', values='CO2').values[:, :train_timestep]
    Y_train = df_train.pivot(index='file_id', columns='timestep', values='CO2').values[:, train_timestep:]
    X_test = df_test.pivot(index='file_id', columns='timestep', values='CO2').values[:, :train_timestep]
    Y_test = df_test.pivot(index='file_id', columns='timestep', values='CO2').values[:, train_timestep:]
    static_train = merged_df[merged_df['file_id'].isin(train_ids)].drop(columns=['file_id', 'cluster']).values
    static_test = merged_df[merged_df['file_id'].isin(test_ids)].drop(columns=['file_id', 'cluster']).values

    X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
    Y_train_tensor = torch.tensor(Y_train, dtype=torch.float32)
    static_train_tensor = torch.tensor(static_train, dtype=torch.float32)
    X_test_tensor = torch.tensor(X_test, dtype=torch.float32)
    Y_test_tensor = torch.tensor(Y_test, dtype=torch.float32)
    static_test_tensor = torch.tensor(static_test, dtype=torch.float32)

    train_dataset = TensorDataset(X_train_tensor, static_train_tensor, Y_train_tensor)
    test_dataset = TensorDataset(X_test_tensor, static_test_tensor, Y_test_tensor)
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=64)

    model = AdvancedDSSMDeepState(input_dim=train_timestep, static_dim=static_train.shape[1], hidden_dim=101, output_dim=Y_train.shape[1])
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.MSELoss()

    model.train()
    print(f"Training for {split_name}")
    for epoch in range(500):
        for X_batch, static_batch, Y_batch in train_loader:
            optimizer.zero_grad()
            preds = model(X_batch, static_batch)
            loss = criterion(preds, Y_batch)
            loss.backward()
            optimizer.step()
        print(f"Epoch {epoch + 1}, Loss: {loss.item()}")

    model.eval()
    with torch.no_grad():
        predictions = []
        actuals = []
        inputs = []
        for X_batch, static_batch, Y_batch in test_loader:
            outputs = model(X_batch, static_batch)
            predictions.append(outputs.numpy())
            actuals.append(Y_batch.numpy())
            inputs.append(X_batch.numpy())

    predictions = np.concatenate(predictions, axis=0)
    actuals = np.concatenate(actuals, axis=0)
    inputs = np.concatenate(inputs, axis=0)
    model_name = f"DSSM_{split_name}"

    plot_discretized_validation(inputs, actuals, predictions, clusters, cluster_boundaries, model_name)
    metrics = calculate_metrics(actuals, predictions, model_name)
    Boundary_case_actuals = metrics['Boundary_case_actuals']
    Boundary_case_predicted = metrics['Boundary_case_predicted']


    input_length = inputs.shape[1]
    rmse = np.sqrt(np.mean((actuals - predictions) ** 2, axis=1))
    min_rmse_index = np.argmin(rmse)
    max_rmse_index = np.argmax(rmse)
    avg_rmse_index = np.argsort(rmse)[len(rmse) // 2]

    selected_inputs = np.vstack([
        inputs[min_rmse_index],
        inputs[avg_rmse_index],
        inputs[max_rmse_index]
    ])

    plot_boundary_cases_with_input(selected_inputs, Boundary_case_actuals, Boundary_case_predicted, model_name, input_length)
     # --- New Visualization 1: Timestep-based Error Plot ---
    timestep_mse = np.mean((actuals - predictions) ** 2, axis=0)

    plt.figure(figsize=(7.5, 3.2))
    plt.plot(range(1, len(timestep_mse) + 1), timestep_mse, color='purple', marker='o')
    plt.xlabel("Predicted Timestep (t)")
    plt.ylabel("Average MSE")
    plt.title(f"Timestep-based-error – {model_name}")
    plt.tight_layout(pad=2.5)
    filename = f"drive/My Drive/DSSM-Figures-final/{model_name}_timestep_error.pdf"
    plt.savefig(filename, format='pdf', bbox_inches='tight')
    plt.close()

    # --- New Visualization 2: Input-based Error Plot + CSV Table ---
    test_mse = np.mean((actuals - predictions) ** 2, axis=1)

    # --- Sort the test MSEs in ascending order ---
    sorted_indices = np.argsort(test_mse)
    sorted_mse = np.array(test_mse)[sorted_indices]

    # --- Plot MSE per test case, sorted ---
    plt.figure(figsize=(7.5, 3.2))
    plt.plot(range(len(sorted_mse)), sorted_mse, color='darkgreen', marker='.')
    plt.xlabel("Renumbered Test Case Index (Sorted by Error)")
    plt.ylabel("MSE")
    plt.title(f"Input-based-error – {model_name}")
    plt.tight_layout(pad=2.5)

    # --- Save to Drive ---
    filename = f"drive/My Drive/DSSM-Figures-final/{model_name}_input_error.pdf"
    plt.savefig(filename, format='pdf', bbox_inches='tight')
    plt.close()

    # Save static features and MSE with new indices
    test_static_df = pd.DataFrame(static_test, columns=merged_df.drop(columns=['file_id', 'cluster']).columns)
    test_static_df.insert(0, "Test_Case_Index", np.arange(len(test_static_df)))
    test_static_df.insert(1, "MSE", test_mse)
    test_static_df.to_csv(f"drive/My Drive/DSSM-Figures-final/{model_name}_test_input_features.csv", index=False)

Training for 20_80
Epoch 1, Loss: 0.15828680992126465
Epoch 2, Loss: 0.13744263350963593
Epoch 3, Loss: 0.10406629741191864
Epoch 4, Loss: 0.05085765942931175
Epoch 5, Loss: 0.025389686226844788
Epoch 6, Loss: 0.025762517005205154
Epoch 7, Loss: 0.025303257629275322
Epoch 8, Loss: 0.01531542930752039
Epoch 9, Loss: 0.026277482509613037
Epoch 10, Loss: 0.03640454262495041
Epoch 11, Loss: 0.01880188286304474
Epoch 12, Loss: 0.030176423490047455
Epoch 13, Loss: 0.012344791553914547
Epoch 14, Loss: 0.01922675222158432
Epoch 15, Loss: 0.020823560655117035
Epoch 16, Loss: 0.014164267107844353
Epoch 17, Loss: 0.015640851110219955
Epoch 18, Loss: 0.00956644769757986
Epoch 19, Loss: 0.02409621886909008
Epoch 20, Loss: 0.026111586019396782
Epoch 21, Loss: 0.01685100980103016
Epoch 22, Loss: 0.01571664772927761
Epoch 23, Loss: 0.010084589943289757
Epoch 24, Loss: 0.016782918944954872
Epoch 25, Loss: 0.030276738107204437
Epoch 26, Loss: 0.015180359594523907
Epoch 27, Loss: 0.0226997509598732
Epoch

In [6]:
    # --- Sort the test MSEs in ascending order ---
    sorted_indices = np.argsort(test_mse)
    sorted_mse = np.array(test_mse)[sorted_indices]

    # --- Plot MSE per test case, sorted ---
    plt.figure(figsize=(7.5, 3.2))
    plt.plot(range(len(sorted_mse)), sorted_mse, color='darkgreen', marker='.')
    plt.xlabel("Sorted Test set error (per timeseries)")
    plt.ylabel("MSE")
    plt.title(f"Input-based-error – {model_name}")
    plt.tight_layout(pad=2.5)

    # --- Save to Drive ---
    filename = f"drive/My Drive/DSSM-Figures-final/{model_name}_input_error.pdf"
    plt.savefig(filename, format='pdf', bbox_inches='tight')
    plt.close()