# MPC Surrogate Training Pipeline

This notebook implements the complete training and evaluation pipeline for approximating MPC policies using neural networks.

## Dataset Structure
- **States**: Joint positions and velocities (6D) - [q1, q2, q3, q̇1, q̇2, q̇3]
- **Targets**: End-effector target positions (3D) - [x, y, z]
- **Actions**: MPC torques (3D) - [τ1, τ2, τ3]

The goal is to learn a mapping: (state, target) → MPC torques

In [None]:
# Install dependencies
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
!pip install h5py numpy matplotlib scikit-learn tqdm

In [None]:
import json
import math
import os
import random
import time

import h5py
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import mean_squared_error, mean_absolute_error
from sklearn.preprocessing import StandardScaler
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader, Subset
from tqdm import tqdm

# Seed the PyTorch and NumPy global RNGs. The stdlib `random` module (used
# further down to pick the episode that gets plotted) is not seeded here.
torch.manual_seed(42)
np.random.seed(42)

# Path of the dataset file.
DATA_PATH = 'data/robot_mpc_dataset.h5'
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {DEVICE}")

## Data Preparation - Visualization

In [None]:
class MPCDataset(Dataset):
    """Dataset of MPC demonstrations read from an HDF5 file.

    Each episode group under ``episodes/`` must provide ``states``,
    ``targets`` and ``actions`` arrays that share their first (time) axis.
    States and targets are concatenated along the last axis to form the model
    input; the actions are the regression target.

    Args:
        filepath: Path of the HDF5 file.
        episode_keys: Names of the episode groups to load.
        mode: ``"mlp"`` flattens every timestep of every episode into one
            sample per timestep; any other value (``"rnn"`` by convention)
            yields one whole episode per sample.
        augment: When true, Gaussian noise is added to the state part of the
            input and to the action every time a sample is fetched.
    """

    # Number of leading input features that belong to the state; the
    # remaining features are the target and are never perturbed.
    STATE_DIM = 6
    # Standard deviations of the augmentation noise.
    STATE_NOISE_STD = 0.01
    ACTION_NOISE_STD = 0.005

    def __init__(self, filepath, episode_keys, mode="mlp", augment=False):
        """
        mode: 'mlp' (flattens trajectories) or 'rnn' (keeps trajectories intact)
        """
        super().__init__()
        self.augment = augment
        self.mode = mode
        self.data = []  # (inputs, actions) tuples, one per episode

        with h5py.File(filepath, "r") as f:
            grp_eps = f["episodes"]
            for ep in episode_keys:
                # Load raw data
                s = torch.from_numpy(grp_eps[ep]["states"][:]).float()
                t = torch.from_numpy(grp_eps[ep]["targets"][:]).float()
                a = torch.from_numpy(grp_eps[ep]["actions"][:]).float()

                # states + targets as input, concatenated on the feature axis
                inp = torch.cat([s, t], dim=-1)
                self.data.append((inp, a))

        # if MLP, flatten all steps from these specific episodes into one tensor
        if self.mode == "mlp":
            self.inputs = torch.cat([x[0] for x in self.data], dim=0)
            self.actions = torch.cat([x[1] for x in self.data], dim=0)

    def __len__(self):
        """Number of timesteps in MLP mode, number of episodes otherwise."""
        return len(self.inputs) if self.mode == "mlp" else len(self.data)

    def __getitem__(self, idx):
        """Return the ``(input, action)`` pair at ``idx``.

        The stored tensors are never modified: augmentation always produces
        new tensors, so the noise does not accumulate across epochs.
        """
        if self.mode == "mlp":
            x, y = self.inputs[idx], self.actions[idx]
            if self.augment:
                # Indexing yields views into the stored tensors, so the noise
                # must be added out-of-place; an in-place `+=` would corrupt
                # the dataset a little more on every fetch.
                noise_x = torch.zeros_like(x)
                noise_x[: self.STATE_DIM] = (
                    torch.randn(self.STATE_DIM) * self.STATE_NOISE_STD
                )  # noise on state only
                x = x + noise_x
                y = y + torch.randn(3) * self.ACTION_NOISE_STD  # noise on action
            return x, y

        # RNN: one whole trajectory per sample
        x, y = self.data[idx]
        if self.augment:
            noise_x = torch.randn_like(x)
            noise_x[:, self.STATE_DIM:] = 0  # no noise on targets
            x = x + (noise_x * self.STATE_NOISE_STD)
            y = y + (torch.randn_like(y) * self.ACTION_NOISE_STD)
        return x, y


def collate_rnn(batch):
    """Collate ``(input, action)`` sequence pairs of differing length.

    Every sequence is zero-padded at the end up to the longest sequence in
    the batch. Returns ``(padded_inputs, padded_actions, lengths)`` where
    ``lengths`` holds the unpadded length of each input sequence.
    """
    input_seqs, action_seqs = map(list, zip(*batch))

    # Remember the true lengths before padding hides them.
    seq_lengths = torch.tensor([seq.size(0) for seq in input_seqs])

    return (
        pad_sequence(input_seqs, batch_first=True),
        pad_sequence(action_seqs, batch_first=True),
        seq_lengths,
    )


def create_dataloaders(filepath, train_ratio=0.8, batch_size=32):
    """Build train/validation loaders for both dataset modes.

    Episode names are shuffled with NumPy's global RNG and split by episode,
    so all timesteps of an episode end up on the same side of the split.
    Training datasets use augmentation and shuffled loading; validation
    datasets use neither.

    Returns:
        ``(train_mlp, val_mlp, train_rnn, val_rnn)`` data loaders. The two
        RNN loaders batch through ``collate_rnn``.
    """
    with h5py.File(filepath, "r") as h5_file:
        episode_names = np.array(sorted(h5_file["episodes"].keys()))

    # split by episode idx
    np.random.shuffle(episode_names)
    n_train = int(len(episode_names) * train_ratio)
    splits = {"train": episode_names[:n_train], "val": episode_names[n_train:]}

    # Datasets are created first, in (mlp train, mlp val, rnn train, rnn val) order.
    datasets = {
        (mode, part): MPCDataset(filepath, splits[part], mode=mode, augment=(part == "train"))
        for mode in ("mlp", "rnn")
        for part in ("train", "val")
    }

    loaders = []
    for mode in ("mlp", "rnn"):
        for part in ("train", "val"):
            loader_kwargs = {"batch_size": batch_size, "shuffle": part == "train"}
            if mode == "rnn":
                # Episodes differ in length, so they need the padding collate.
                loader_kwargs["collate_fn"] = collate_rnn
            loaders.append(DataLoader(datasets[(mode, part)], **loader_kwargs))

    return tuple(loaders)

# Data visualization

In [None]:
# Load every episode as a (states, targets, actions) tuple of NumPy arrays.
# The file is opened through DATA_PATH so that this cell reads the dataset
# location configured at the top of the notebook instead of a second,
# hard-coded path.
episodes = []
with h5py.File(DATA_PATH, "r") as f:
    keys = np.array(sorted(f["episodes"].keys()))
    for ep in keys:
        grp = f["episodes"][ep]
        states = grp["states"][:]
        targets = grp["targets"][:]
        actions = grp["actions"][:]
        episodes.append((states, targets, actions))

# Fail with a readable message instead of a randint range error.
if not episodes:
    raise ValueError(f"No episodes found in {DATA_PATH!r}")

# pick a random episode
ep_idx = random.randint(0, len(episodes) - 1)
episode = episodes[ep_idx]
ep_states, ep_targets, ep_actions = episode

print(f"Visualizing episode {ep_idx}")

# One subplot per signal group: (title, columns to plot, legend labels).
# State columns 0-2 are plotted as joint positions and 3-5 as joint velocities.
panels = [
    ("Joint Positions", ep_states[:, 0:3], ("q1", "q2", "q3")),
    ("Joint Velocities", ep_states[:, 3:6], ("dq1", "dq2", "dq3")),
    ("MPC Torques", ep_actions[:, 0:3], ("tau1", "tau2", "tau3")),
]

plt.figure(figsize=(12, 6))
for row, (title, series, labels) in enumerate(panels, start=1):
    plt.subplot(3, 1, row)
    for col, label in enumerate(labels):
        plt.plot(series[:, col], label=label)
    plt.title(title)
    plt.legend()

plt.tight_layout()
plt.show()


: 

# Scikit-learn Baseline

In [None]:
import os
import glob
import numpy as np
from tqdm import tqdm

from sklearn.linear_model import LinearRegression, Ridge, Lasso
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
from sklearn.neighbors import KNeighborsRegressor
from sklearn.svm import SVR
from sklearn.multioutput import MultiOutputRegressor

from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split


# ============================================================
# 1. LOAD DATASET
# episodes/ep_xxxx/{states, targets, actions}.npy
# ============================================================


def load_dataset(root="episodes"):
    """Load all episodes stored as ``.npy`` files into flat arrays.

    Every directory matching ``<root>/ep_*`` must contain ``states.npy``,
    ``targets.npy`` and ``actions.npy``. States and targets are concatenated
    column-wise into the feature matrix; actions form the label matrix. Rows
    from all episodes are stacked in sorted directory order.

    Args:
        root: Directory that contains the ``ep_*`` episode folders.

    Returns:
        ``(X, y)``: stacked feature rows and the matching action rows.

    Raises:
        ValueError: If no episode directory is found under ``root``.
    """
    episode_dirs = sorted(glob.glob(os.path.join(root, "ep_*")))
    if not episode_dirs:
        # Without this check np.vstack([]) fails with a message that does not
        # mention the missing data at all.
        raise ValueError(f"No episode directories matching 'ep_*' found under {root!r}")

    X_list, y_list = [], []

    for ep in tqdm(episode_dirs):
        states = np.load(os.path.join(ep, "states.npy"))
        targets = np.load(os.path.join(ep, "targets.npy"))
        actions = np.load(os.path.join(ep, "actions.npy"))

        # One feature row per timestep: state columns followed by target columns.
        X_list.append(np.concatenate([states, targets], axis=1))
        y_list.append(actions)

    X = np.vstack(X_list)
    y = np.vstack(y_list)

    print("Loaded dataset:")
    print("X:", X.shape, "  y:", y.shape)
    return X, y


# Flat (timestep-level) feature and label matrices for the sklearn baselines.
X, y = load_dataset("episodes")


# ============================================================
# 2. Split + Normalize
# ============================================================

# NOTE(review): this split shuffles individual rows (timesteps), so timesteps
# of one episode can land in both the train and the test set, whereas
# create_dataloaders() splits by episode. Scores from the two pipelines are
# therefore probably not directly comparable -- confirm this is intended.
# NOTE(review): no random_state is passed, so the split depends on NumPy's
# global RNG state at the time this cell runs.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, shuffle=True)

# Standardize the features using statistics of the training rows only;
# the labels (y) are left unscaled.
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)


# ============================================================
# 3. Baseline Models
# ============================================================

# Display name -> unfitted regressor. SVR is wrapped in MultiOutputRegressor,
# which fits one SVR per output column.
models = {
    "Linear Regression": LinearRegression(),
    "Ridge": Ridge(alpha=1.0),
    "Lasso": Lasso(alpha=1e-3),
    "Random Forest": RandomForestRegressor(n_estimators=200, n_jobs=-1),
    "Gradient Boosting": GradientBoostingRegressor(),
    "KNN": KNeighborsRegressor(n_neighbors=5),
    "SVM (RBF)": MultiOutputRegressor(SVR(kernel="rbf", C=10, gamma="scale")),
}

# Display name -> test-set MSE.
results = {}

print("\n===== TRAINING BASELINES =====")
# NOTE(review): GradientBoostingRegressor is passed a multi-column y directly,
# unlike SVR -- verify that it accepts 2-D targets in the installed sklearn.
for name, model in models.items():
    print(f"\nTraining: {name}")
    model.fit(X_train, y_train)
    preds = model.predict(X_test)
    mse = mean_squared_error(y_test, preds)
    results[name] = mse
    print(f"{name} MSE: {mse:.6f}")


# ============================================================
# 4. Summary
# ============================================================

print("\n==== FINAL RESULTS ====")
for name, mse in results.items():
    print(f"{name:20s}: MSE = {mse:.6f}")


## Model Architectures

In [None]:
class MLP(nn.Module):
    """Simple Multi-Layer Perceptron.

    Each hidden layer is ``Linear -> ReLU -> BatchNorm1d -> Dropout(0.1)``;
    a final ``Linear`` maps the last hidden width to ``output_dim``.

    Args:
        input_dim: Number of input features.
        hidden_dims: Widths of the hidden layers, in order. Any iterable of
            ints is accepted; the default is an immutable tuple so that it
            cannot be altered between instantiations.
        output_dim: Number of outputs.
    """

    def __init__(self, input_dim=9, hidden_dims=(128, 64), output_dim=3):
        super().__init__()
        layers = []
        prev_dim = input_dim

        for hidden_dim in hidden_dims:
            layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.ReLU(),
                nn.BatchNorm1d(hidden_dim),
                nn.Dropout(0.1),
            ])
            prev_dim = hidden_dim

        layers.append(nn.Linear(prev_dim, output_dim))
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the stacked layers to ``x`` (features on the last axis)."""
        return self.network(x)

class ResidualBlock(nn.Module):
    """Two-layer fully connected block with an identity skip connection.

    Computes ``relu(linear2(dropout(relu(linear1(x)))) + x)``; input and
    output share the same feature width ``dim``.
    """

    def __init__(self, dim):
        super().__init__()
        self.linear1 = nn.Linear(dim, dim)
        self.linear2 = nn.Linear(dim, dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        hidden = self.dropout(self.relu(self.linear1(x)))
        # Skip connection: add the untouched input before the final activation.
        return self.relu(self.linear2(hidden) + x)

class ResNet(nn.Module):
    """Residual Network for MPC surrogate.

    A linear input projection followed by ReLU, ``num_blocks`` residual
    blocks of width ``hidden_dim``, and a linear output layer.
    """

    def __init__(self, input_dim=9, hidden_dim=128, num_blocks=3, output_dim=3):
        super().__init__()
        self.input_layer = nn.Linear(input_dim, hidden_dim)

        stacked_blocks = []
        for _ in range(num_blocks):
            stacked_blocks.append(ResidualBlock(hidden_dim))
        self.res_blocks = nn.ModuleList(stacked_blocks)

        self.output_layer = nn.Linear(hidden_dim, output_dim)
        self.relu = nn.ReLU()

    def forward(self, x):
        features = self.relu(self.input_layer(x))
        for res_block in self.res_blocks:
            features = res_block(features)
        return self.output_layer(features)

class TransformerBlock(nn.Module):
    """Post-norm transformer encoder block.

    Self-attention and a two-layer feed-forward network, each wrapped in a
    residual connection followed by LayerNorm. Inputs are batch-first
    (``batch_first=True``); ``ff_dim`` defaults to ``4 * dim``.
    """

    def __init__(self, dim, num_heads=4, ff_dim=None):
        super().__init__()
        inner_dim = 4 * dim if ff_dim is None else ff_dim

        self.attention = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        self.feed_forward = nn.Sequential(
            nn.Linear(dim, inner_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(inner_dim, dim),
        )

    def forward(self, x):
        # Self-attention sub-layer (query = key = value = x); weights are discarded.
        attended = self.attention(x, x, x)[0]
        hidden = self.norm1(x + attended)

        # Feed-forward sub-layer
        return self.norm2(hidden + self.feed_forward(hidden))

class Transformer(nn.Module):
    """Transformer-based MPC surrogate.

    A linear embedding, ``num_blocks`` transformer blocks and a linear output
    layer applied to every position.

    Input may be ``(batch, features)`` or ``(batch, seq, features)``. A 2-D
    input is treated as a length-1 sequence per sample, so attention then
    operates over a single token. The output has the same rank as the input:
    ``(batch, output_dim)`` or ``(batch, seq, output_dim)``.
    """

    def __init__(self, input_dim=9, embed_dim=64, num_heads=4, num_blocks=2, output_dim=3):
        super().__init__()
        self.input_embedding = nn.Linear(input_dim, embed_dim)
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads) for _ in range(num_blocks)
        ])
        self.output_layer = nn.Linear(embed_dim, output_dim)

    def forward(self, x):
        # Add a sequence dimension for the transformer when given flat samples.
        added_seq_dim = x.dim() == 2
        if added_seq_dim:
            x = x.unsqueeze(1)  # [batch, 1, features]

        x = self.input_embedding(x)
        for block in self.transformer_blocks:
            x = block(x)

        # Only drop the sequence dimension if it was added above; a genuine
        # (batch, 1, features) input keeps its sequence axis.
        if added_seq_dim:
            x = x.squeeze(1)
        return self.output_layer(x)

# Model configurations
# Maps a display name to the model class and the keyword arguments it is
# constructed with (`class(**params)` in the training cell); constructor
# parameters that are not listed keep their defaults.
model_configs = {
    'MLP': {'class': MLP, 'params': {'hidden_dims': [128, 64]}},
    'ResNet': {'class': ResNet, 'params': {'hidden_dim': 128, 'num_blocks': 3}},
    'Transformer': {'class': Transformer, 'params': {'embed_dim': 64, 'num_heads': 4, 'num_blocks': 2}}
}

print("Available models:", list(model_configs.keys()))

## Training and Evaluation Functions

In [None]:
def train_epoch(model, train_loader, criterion, optimizer):
    """Run one optimisation pass over ``train_loader``.

    The model is switched to training mode, every batch is moved to DEVICE
    and used for one optimizer step. Returns the mean of the per-batch
    losses (each batch counts equally, whatever its size).
    """
    model.train()
    batch_losses = []

    for batch_inputs, batch_targets in train_loader:
        batch_inputs = batch_inputs.to(DEVICE)
        batch_targets = batch_targets.to(DEVICE)

        optimizer.zero_grad()
        batch_loss = criterion(model(batch_inputs), batch_targets)
        batch_loss.backward()
        optimizer.step()

        batch_losses.append(batch_loss.item())

    return sum(batch_losses) / len(batch_losses)

def evaluate(model, test_loader, criterion):
    """Score ``model`` on ``test_loader`` without tracking gradients.

    Returns a dict with:
        'loss': mean of the per-batch ``criterion`` values,
        'mse' / 'mae': sklearn metrics over all collected samples,
        'predictions' / 'targets': concatenated NumPy arrays of every batch.
    """
    model.eval()
    batch_losses = []
    prediction_chunks = []
    target_chunks = []

    with torch.no_grad():
        for batch_inputs, batch_targets in test_loader:
            batch_inputs = batch_inputs.to(DEVICE)
            batch_targets = batch_targets.to(DEVICE)

            batch_outputs = model(batch_inputs)
            batch_losses.append(criterion(batch_outputs, batch_targets).item())

            # Collect on the CPU so metrics can be computed with NumPy/sklearn.
            prediction_chunks.append(batch_outputs.cpu().numpy())
            target_chunks.append(batch_targets.cpu().numpy())

    predictions = np.concatenate(prediction_chunks)
    targets = np.concatenate(target_chunks)

    return {
        'loss': sum(batch_losses) / len(batch_losses),
        'mse': mean_squared_error(targets, predictions),
        'mae': mean_absolute_error(targets, predictions),
        'predictions': predictions,
        'targets': targets,
    }

def train_model(model_name, model, train_loader, test_loader, num_epochs=100, patience=10):
    """Complete training pipeline for a model.

    Every epoch runs one training pass with the MSE loss followed by one with
    the L1 (MAE) loss on the same model and optimizer, evaluating on
    ``test_loader`` after each pass. The weights with the lowest MSE-pass and
    MAE-pass test loss are written to ``{model_name}_best_mse.pth`` and
    ``{model_name}_best_mae.pth``.

    Args:
        model_name: Label used in log output and checkpoint file names.
        model: Network to train (expected to live on DEVICE already).
        train_loader: Loader yielding ``(inputs, targets)`` training batches.
        test_loader: Loader yielding ``(inputs, targets)`` evaluation batches.
        num_epochs: Maximum number of epochs.
        patience: Stop after this many consecutive epochs without a new best
            MSE-pass test loss.

    Returns:
        Dict with 'model_name', 'hyperparameters', 'training_history' (one
        entry per epoch) and 'best_mse_results' / 'best_mae_results' (the
        ``evaluate`` output of the best epoch, or None if never set).
    """
    print(f"\n=== Training {model_name} ===")

    # Loss functions
    mse_criterion = nn.MSELoss()
    mae_criterion = nn.L1Loss()

    # Optimizer
    optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)

    # Schedulers
    # NOTE(review): both schedulers drive the same optimizer, so a plateau in
    # either metric halves the shared learning rate -- confirm this is wanted.
    mse_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)
    mae_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

    results = {
        'model_name': model_name,
        'hyperparameters': {
            'lr': 1e-3,
            'weight_decay': 1e-4,
            'batch_size': train_loader.batch_size,
            'num_epochs': num_epochs
        },
        'training_history': [],
        'best_mse_results': None,
        'best_mae_results': None
    }

    best_mse_loss = float('inf')
    best_mae_loss = float('inf')
    patience_counter = 0

    for epoch in range(num_epochs):
        start_time = time.time()

        # Train with MSE
        train_loss_mse = train_epoch(model, train_loader, mse_criterion, optimizer)
        mse_results = evaluate(model, test_loader, mse_criterion)

        # Checkpoint right away: the MAE pass below changes the weights, so
        # saving later would store a model that does not match mse_results.
        mse_improved = mse_results['loss'] < best_mse_loss
        if mse_improved:
            best_mse_loss = mse_results['loss']
            results['best_mse_results'] = mse_results.copy()
            torch.save(model.state_dict(), f'{model_name}_best_mse.pth')

        # Train with MAE
        train_loss_mae = train_epoch(model, train_loader, mae_criterion, optimizer)
        mae_results = evaluate(model, test_loader, mae_criterion)

        if mae_results['loss'] < best_mae_loss:
            best_mae_loss = mae_results['loss']
            results['best_mae_results'] = mae_results.copy()
            torch.save(model.state_dict(), f'{model_name}_best_mae.pth')

        # Update schedulers
        mse_scheduler.step(mse_results['loss'])
        mae_scheduler.step(mae_results['loss'])

        epoch_time = time.time() - start_time

        # Log results. Metrics are cast to built-in floats so the history can
        # be dumped to JSON whatever scalar type the metric functions return.
        epoch_results = {
            'epoch': epoch + 1,
            'train_loss_mse': train_loss_mse,
            'test_loss_mse': mse_results['loss'],
            'test_mse': float(mse_results['mse']),
            'test_mae': float(mse_results['mae']),
            'train_loss_mae': train_loss_mae,
            'test_loss_mae': mae_results['loss'],
            'epoch_time': epoch_time
        }
        results['training_history'].append(epoch_results)

        # Early stopping: count consecutive epochs without a new best MSE
        # loss. The decision uses the flag computed before best_mse_loss was
        # updated; comparing against the updated value would never reset the
        # counter.
        if mse_improved:
            patience_counter = 0
        else:
            patience_counter += 1

        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1}/{num_epochs}, MSE Loss: {mse_results['loss']:.6f}, MAE Loss: {mae_results['loss']:.6f}, Time: {epoch_time:.2f}s")

        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

    return results

# ACT Transformers

In [None]:
class ACT(nn.Module):
    """Goal-conditioned transformer encoder that maps state sequences to actions.

    States and targets are embedded separately, summed, given a positional
    encoding and passed through a ``nn.TransformerEncoder``; an MLP head then
    predicts one action per timestep.

    Args:
        state_dim: Size of the per-timestep state vector.
        target_dim: Size of the per-timestep target vector.
        action_dim: Size of the per-timestep action vector.
        latent_dim: Embedding / transformer model width.
        num_layers: Number of encoder layers.
        nhead: Number of attention heads.
        dim_feedforward: Hidden width of the encoder feed-forward layers.
        dropout: Dropout rate for the positional encoding and encoder layers.
        temporal_chunk_size: Default chunk length used by ``get_chunked_loss``.
    """

    def __init__(
        self,
        state_dim=6,  # qpos + qvel (e.g., 3 pos + 3 vel)
        target_dim=3,  # Cartesian goal (x,y,z)
        action_dim=3,  # torque output (3 DoF)
        latent_dim=512,
        num_layers=6,
        nhead=8,
        dim_feedforward=1024,
        dropout=0.1,
        temporal_chunk_size=32,  # how many timesteps per chunk (paper uses 32–100)
    ):
        super().__init__()

        self.temporal_chunk_size = temporal_chunk_size
        self.latent_dim = latent_dim

        # Input embedding
        self.state_encoder = nn.Linear(state_dim, latent_dim)
        self.target_encoder = nn.Linear(target_dim, latent_dim)

        # Positional encoding (sinusoidal)
        self.pos_encoder = PositionalEncoding(latent_dim, dropout)

        # Transformer Encoder (ACT uses encoder-only)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=latent_dim,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation="gelu",
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Action head
        self.action_head = nn.Sequential(
            nn.Linear(latent_dim, latent_dim), nn.GELU(), nn.Linear(latent_dim, action_dim)
        )

        # Re-initialise every nn.Linear created above.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Xavier-uniform weights and zero biases for ``nn.Linear`` modules."""
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                torch.nn.init.zeros_(m.bias)

    def forward(self, states, targets):
        """
        states : (B, T, state_dim)      -> current joint pos/vel
        targets: (B, T, target_dim)     -> goal for every timestep
        Returns:
            actions: (B, T, action_dim)

        No attention or padding mask is applied, so every position attends to
        every other position of its sequence.
        """
        # Encode inputs
        state_emb = self.state_encoder(states)  # (B,T,latent)
        target_emb = self.target_encoder(targets)  # (B,T,latent)

        x = state_emb + target_emb  # goal-conditioned embedding
        x = self.pos_encoder(x)

        # Apply transformer
        x = self.transformer(x)  # (B,T,latent)

        # Predict actions
        return self.action_head(x)  # (B,T,action_dim)

    def get_chunked_loss(self, states, targets, actions_true, chunk_size=None):
        """MSE loss on one random contiguous time window per batch element.

        Args:
            states, targets, actions_true: Batch-first sequences that share
                their batch and time axes.
            chunk_size: Window length; defaults to ``temporal_chunk_size``.

        Sequences that are not longer than ``chunk_size`` are used in full.
        Start indices are drawn independently for every batch element.
        """
        if chunk_size is None:
            chunk_size = self.temporal_chunk_size

        B, T, _ = states.shape
        max_start = T - chunk_size

        if max_start <= 0:
            return nn.functional.mse_loss(self(states, targets), actions_true)

        # Random temporal chunk per batch element. Both index tensors are
        # created on the input's device; mixing a CPU start index with a CUDA
        # offset tensor raises a device-mismatch error.
        device = states.device
        start_idx = torch.randint(0, max_start + 1, (B,), device=device)
        offsets = torch.arange(chunk_size, device=device)
        idxs = start_idx.unsqueeze(1) + offsets.unsqueeze(0)  # (B, chunk_size)

        def take_window(seq):
            # Gather the same time indices for every feature column.
            return torch.gather(seq, 1, idxs.unsqueeze(-1).expand(-1, -1, seq.size(-1)))

        s_chunk = take_window(states)
        t_chunk = take_window(targets)
        a_chunk = take_window(actions_true)

        a_pred = self(s_chunk, t_chunk)
        return nn.functional.mse_loss(a_pred, a_chunk)


# Positional encoding (same as ACT / original Transformer)
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding for batch-first sequences.

    A table of shape ``(1, max_len, d_model)`` is precomputed and stored as
    the non-trainable buffer ``pe``: even feature columns hold sines, odd
    columns hold cosines. ``forward`` adds the first ``seq_len`` rows to the
    input and applies dropout.

    Args:
        d_model: Feature width of the sequences (even or odd).
        dropout: Dropout probability applied after adding the encoding.
        max_len: Longest sequence length the table supports.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # For an odd d_model there is one odd column fewer than even columns,
        # so the cosine part uses only the first d_model // 2 frequencies.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        """Add the encoding to ``x`` of shape ``(batch, seq_len, d_model)``."""
        x = x + self.pe[:, : x.size(1), :]
        return self.dropout(x)

In [None]:
# Build the ACT model on the notebook-wide DEVICE (instead of a hard-coded
# .cuda()) so this cell also runs on machines without a GPU.
model = ACT(
    state_dim=6, target_dim=3, action_dim=3, latent_dim=512, num_layers=6, nhead=8, temporal_chunk_size=50
).to(DEVICE)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)

# Make sure dropout is active while optimising.
model.train()

# In training loop
# NOTE(review): `dataloader` is not defined in this notebook. It has to yield
# (states, targets, actions) batches; the loaders built by
# create_dataloaders() with collate_rnn yield (inputs, actions, lengths)
# instead, so they cannot be plugged in unchanged.
for states_batch, targets_batch, actions_batch in dataloader:
    states_batch = states_batch.to(DEVICE)  # (B, T, 6)
    targets_batch = targets_batch.to(DEVICE)  # (B, T, 3)  ← repeated goal
    actions_batch = actions_batch.to(DEVICE)  # (B, T, 3)

    loss = model.get_chunked_loss(states_batch, targets_batch, actions_batch, chunk_size=50)

    optimizer.zero_grad()
    loss.backward()
    # Clip the global gradient norm to 1.0 before the update.
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()

## Run Training Pipeline

In [None]:
# Train all models
all_results = {}

# `train_loader` / `test_loader` are not created anywhere else in this
# notebook. Build the per-timestep (MLP-mode) loaders from the dataset unless
# an earlier cell already defined them.
if "train_loader" not in globals() or "test_loader" not in globals():
    loaders = create_dataloaders(DATA_PATH)
    train_loader, test_loader = loaders[0], loaders[1]

for model_name, config in model_configs.items():
    # Create model
    model_class = config['class']
    model_params = config['params']
    model = model_class(**model_params).to(DEVICE)

    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    print(f"\n{model_name} - Parameters: {total_params:,}")

    # Train model
    results = train_model(model_name, model, train_loader, test_loader, num_epochs=50)
    all_results[model_name] = results

print("\n=== Training Complete ===")

## Export Results

In [None]:
# Create results directory
os.makedirs('results', exist_ok=True)

# Keys of an evaluation dict that hold raw arrays and are left out of the JSON.
_ARRAY_KEYS = ('predictions', 'targets')


def _without_arrays(eval_results):
    """Return a copy of an evaluation dict without its array entries.

    Falsy values (None, empty dict) are returned unchanged.
    """
    if not eval_results:
        return eval_results
    return {key: value for key, value in eval_results.items() if key not in _ARRAY_KEYS}


def _json_default(obj):
    """Fallback for json.dump: convert NumPy scalars and arrays to built-ins."""
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


# Export comprehensive results
for model_name, results in all_results.items():
    # Build the JSON payload from copies. Popping keys from the nested dicts
    # of a shallow copy would also delete the predictions/targets stored in
    # all_results.
    json_results = dict(results)
    for best_key in ('best_mse_results', 'best_mae_results'):
        if best_key in json_results:
            json_results[best_key] = _without_arrays(json_results[best_key])

    # Save detailed results as JSON
    results_file = f'results/{model_name}_results.json'
    with open(results_file, 'w') as f:
        json.dump(json_results, f, indent=2, default=_json_default)

    print(f"Results saved to: {results_file}")

# Export summary table
summary_data = []
for model_name, results in all_results.items():
    # The best_* entries exist but are None when no epoch set them, so a
    # plain .get(key, {}) default is not enough.
    best_mse = results.get('best_mse_results') or {}
    best_mae = results.get('best_mae_results') or {}
    row = {
        'model_type': model_name,
        'hyperparameters': results['hyperparameters'],
        'best_mse_loss': best_mse.get('loss', 'N/A'),
        'best_mae_loss': best_mae.get('loss', 'N/A'),
        'best_mse_mse': best_mse.get('mse', 'N/A'),
        'best_mae_mae': best_mae.get('mae', 'N/A'),
        'epochs_trained': len(results['training_history'])
    }
    summary_data.append(row)

with open('results/training_summary.json', 'w') as f:
    json.dump(summary_data, f, indent=2, default=_json_default)

print("Summary saved to: results/training_summary.json")

# Export model weights (already saved during training)
print("Model weights saved as .pth files")

## Results Visualization

In [None]:
# Plot training curves
# One curve per model in each panel; only the first two cells of the 2x2
# subplot grid are used.
plt.figure(figsize=(15, 10))

# NOTE(review): the enumerate index `i` is never used.
for i, (model_name, results) in enumerate(all_results.items()):
    history = results['training_history']

    # Per-epoch test losses recorded by train_model.
    epochs = [h['epoch'] for h in history]
    mse_losses = [h['test_loss_mse'] for h in history]
    mae_losses = [h['test_loss_mae'] for h in history]

    # Left panel: test loss after the MSE training pass.
    plt.subplot(2, 2, 1)
    plt.plot(epochs, mse_losses, label=model_name)
    plt.title('MSE Loss vs Epoch')
    plt.xlabel('Epoch')
    plt.ylabel('MSE Loss')
    plt.legend()

    # Right panel: test loss after the MAE training pass.
    plt.subplot(2, 2, 2)
    plt.plot(epochs, mae_losses, label=model_name)
    plt.title('MAE Loss vs Epoch')
    plt.xlabel('Epoch')
    plt.ylabel('MAE Loss')
    plt.legend()

plt.tight_layout()
# Requires the results/ directory to exist already (the export cell creates it).
plt.savefig('results/training_curves.png', dpi=300, bbox_inches='tight')
plt.show()

# Print final results summary
print("\n=== Final Results Summary ===")
print(f"{'Model':<12} {'MSE Loss':<10} {'MAE Loss':<10} {'MSE':<10} {'MAE':<10}")
print("-" * 60)

for model_name, results in all_results.items():
    mse_results = results.get('best_mse_results', {})
    mae_results = results.get('best_mae_results', {})

    # A missing/None/empty result dict prints as 'N/A'.
    # NOTE(review): for a non-empty dict that lacks the key, the 'N/A'
    # fallback would be formatted with ':.6f' and raise ValueError.
    mse_loss = f"{mse_results.get('loss', 'N/A'):.6f}" if mse_results else 'N/A'
    mae_loss = f"{mae_results.get('loss', 'N/A'):.6f}" if mae_results else 'N/A'
    mse_val = f"{mse_results.get('mse', 'N/A'):.6f}" if mse_results else 'N/A'
    mae_val = f"{mae_results.get('mae', 'N/A'):.6f}" if mae_results else 'N/A'

    print(f"{model_name:<12} {mse_loss:<10} {mae_loss:<10} {mse_val:<10} {mae_val:<10}")

print("\nModel weights saved as .pth files")
print("Detailed results saved in results/ directory")

## Download Results

In [None]:
# Create zip file for easy download
# Shell command (notebook `!` magic): archives the results/ directory together
# with every .pth file in the current working directory.
!zip -r mpc_surrogate_results.zip results/ *.pth

# Colab-only import; presumably offers the archive as a browser download --
# this cell will fail outside Google Colab.
from google.colab import files
files.download('mpc_surrogate_results.zip')