# MPC Surrogate Training Pipeline

This notebook implements the complete training and evaluation pipeline for approximating MPC policies using neural networks.

## Dataset Structure
- **States**: Joint positions and velocities (6D) - [q1, q2, q3, q̇1, q̇2, q̇3]
- **Targets**: End-effector target positions (3D) - [x, y, z]
- **Actions**: MPC torques (3D) - [τ1, τ2, τ3]

The goal is to learn a mapping: (state, target) → MPC torques

In [None]:
# Install dependencies
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
!pip install h5py numpy matplotlib scikit-learn tqdm

In [None]:
import json
import os
import time

import h5py
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import mean_squared_error, mean_absolute_error
from sklearn.preprocessing import StandardScaler
from torch.utils.data import Dataset, DataLoader, Subset, random_split
from tqdm import tqdm

# Seed PyTorch and NumPy with the same value so reruns are repeatable
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Using device: {device}')

## Data Preparation - Visualization

In [None]:
class MPCDataset(Dataset):
    """Dataset of (state, target) -> action samples for the MPC surrogate.

    Each item is a pair ``(x, action)`` where ``x`` is the sample's state
    and target concatenated along dimension 0.

    Args:
        states: Per-sample states, converted with ``torch.FloatTensor``.
        targets: Per-sample targets, converted with ``torch.FloatTensor``.
        actions: Per-sample actions, converted with ``torch.FloatTensor``.
        augment: When True, Gaussian noise is added to the state and the
            target every time an item is fetched. The flag is read on each
            ``__getitem__`` call, so it can be toggled after construction.
    """

    def __init__(self, states, targets, actions, augment=False):
        self.states = torch.FloatTensor(states)
        self.targets = torch.FloatTensor(targets)
        self.actions = torch.FloatTensor(actions)
        self.augment = augment

    def __len__(self):
        """Return the number of samples (length of ``states``)."""
        return len(self.states)

    def __getitem__(self, idx):
        """Return ``(concat(state, target), action)`` for sample ``idx``."""
        state, target = self.states[idx], self.targets[idx]

        if self.augment:
            # Additive Gaussian jitter with a fixed standard deviation
            # (0.01 for states, 0.005 for targets). The scale is absolute,
            # not relative to the magnitude of the values.
            state = state + torch.randn_like(state) * 0.01
            target = target + torch.randn_like(target) * 0.005

        # Network input is the state followed by the target
        return torch.cat((state, target), dim=0), self.actions[idx]

def load_data(filename='robot_mpc_dataset.h5'):
    """Read the ``states``, ``targets`` and ``actions`` datasets from an HDF5 file.

    Args:
        filename: Path of the HDF5 file to open read-only.

    Returns:
        Tuple ``(states, targets, actions)``; each entry is the full
        contents of the dataset with that name, read into memory.
    """
    with h5py.File(filename, 'r') as h5_file:
        # `[:]` materialises each dataset before the file is closed
        return tuple(h5_file[name][:] for name in ('states', 'targets', 'actions'))

def create_dataloaders(states, targets, actions, batch_size=32, train_ratio=0.8, augment=True):
    """Create train and test dataloaders from a random sample-level split.

    The two subsets returned by ``random_split`` wrap the *same* underlying
    dataset object, so flipping ``augment`` through the test subset would
    also disable augmentation for training. To keep the two independent,
    the train and test subsets are backed by separate ``MPCDataset``
    instances that share the same split indices.

    Args:
        states: Per-sample states passed to ``MPCDataset``.
        targets: Per-sample targets passed to ``MPCDataset``.
        actions: Per-sample actions passed to ``MPCDataset``.
        batch_size: Batch size used by both loaders.
        train_ratio: Fraction of samples assigned to the training split;
            the training size is ``int(train_ratio * N)``.
        augment: Whether the *training* samples are augmented. The test
            split is never augmented.

    Returns:
        Tuple ``(train_loader, test_loader)``. The training loader
        shuffles; the test loader does not.
    """
    train_source = MPCDataset(states, targets, actions, augment=augment)
    # Separate, never-augmented view of the same data for evaluation
    eval_source = MPCDataset(states, targets, actions, augment=False)

    train_size = int(train_ratio * len(train_source))
    test_size = len(train_source) - train_size

    # NOTE(review): the split is per sample. If the file stores consecutive
    # timesteps of episodes, neighbouring steps of one episode can land in
    # both splits — confirm whether an episode-level split is wanted.
    train_dataset, held_out = random_split(train_source, [train_size, test_size])
    test_dataset = Subset(eval_source, held_out.indices)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, test_loader

# Read the dataset from disk and report the shape of each array
states, targets, actions = load_data('data/robot_mpc_dataset.h5')
print("Dataset shapes:")
print("\n".join(
    f"{label}: {data.shape}"
    for label, data in (("States", states), ("Targets", targets), ("Actions", actions))
))

# Build the loaders used by every model in the training cell below
train_loader, test_loader = create_dataloaders(states, targets, actions, batch_size=32)

In [None]:
import h5py
import numpy as np
import matplotlib.pyplot as plt


def load_episode_indices(targets, tol=1e-6):
    """Split a sequence of targets into episodes where the target changes.

    A new episode starts at row ``i`` whenever the Euclidean distance
    between ``targets[i]`` and ``targets[i - 1]`` exceeds ``tol``. This
    assumes each episode has a constant target; two consecutive episodes
    that share the same target are therefore reported as one.

    Args:
        targets: Array-like with one target per row.
        tol: Distance threshold above which the target counts as changed.

    Returns:
        List of ints ``[0, s1, ..., len(targets)]``. Episode ``k`` spans
        rows ``boundaries[k]:boundaries[k + 1]``. For empty input the
        result is ``[0, 0]``.
    """
    targets = np.asarray(targets)
    num_rows = len(targets)
    if num_rows == 0:
        # No rows to compare; keep the [start, end] sentinel pair.
        return [0, 0]

    # One distance per consecutive pair of rows, computed in a single
    # vectorised pass instead of a Python-level loop.
    rows = targets.reshape(num_rows, -1)
    step_distances = np.linalg.norm(np.diff(rows, axis=0), axis=1)

    # step_distances[j] compares rows j and j + 1, so the episode begins at j + 1
    change_points = np.flatnonzero(step_distances > tol) + 1
    return [0, *change_points.tolist(), num_rows]


def visualize_random_episode(h5_path="data/robot_mpc_dataset.h5"):
    """Plot the states and actions of one randomly chosen episode.

    The file is read in full, split into episodes with
    ``load_episode_indices`` and one episode index is drawn with
    ``np.random.randint``. Three stacked panels are shown: state columns
    0-2, state columns 3-5 and action columns 0-2.

    Args:
        h5_path: Path of the HDF5 file holding ``states``, ``targets``
            and ``actions`` datasets.
    """
    with h5py.File(h5_path, "r") as h5_file:
        states, targets, actions = (
            h5_file[name][:] for name in ("states", "targets", "actions")
        )

    # Episode k covers rows boundaries[k]:boundaries[k + 1]
    boundaries = load_episode_indices(targets)
    num_episodes = len(boundaries) - 1
    print(f"Found {num_episodes} episodes in dataset.")

    # Draw one episode uniformly at random
    ep_idx = np.random.randint(num_episodes)
    start, end = boundaries[ep_idx], boundaries[ep_idx + 1]
    episode = slice(start, end)

    print(f"Visualizing episode {ep_idx} ({end - start} timesteps).")
    print(f"Target = {targets[episode][0]}")

    # (title, data with one column per curve, legend labels) for each row
    panels = (
        ("Joint Positions", states[episode, 0:3], ("q1", "q2", "q3")),
        ("Joint Velocities", states[episode, 3:6], ("dq1", "dq2", "dq3")),
        ("MPC Torques", actions[episode], ("tau1", "tau2", "tau3")),
    )

    plt.figure(figsize=(12, 6))
    for row, (title, series, labels) in enumerate(panels, start=1):
        plt.subplot(3, 1, row)
        for column, label in enumerate(labels):
            plt.plot(series[:, column], label=label)
        plt.title(title)
        plt.legend()

    plt.tight_layout()
    plt.show()

## Model Architectures

In [None]:
class MLP(nn.Module):
    """Fully connected network: repeated hidden stages plus a linear head.

    Every hidden stage is ``Linear -> ReLU -> BatchNorm1d -> Dropout(0.1)``.

    Args:
        input_dim: Number of input features.
        hidden_dims: Width of each hidden stage, in order. The default
            list is only read, never modified.
        output_dim: Number of output features.
    """
    def __init__(self, input_dim=9, hidden_dims=[128, 64], output_dim=3):
        super().__init__()
        # Consecutive entries of `widths` are the (in, out) sizes of each stage
        widths = [input_dim, *hidden_dims]

        modules = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules += [
                nn.Linear(fan_in, fan_out),
                nn.ReLU(),
                nn.BatchNorm1d(fan_out),
                nn.Dropout(0.1),
            ]

        # Final projection to the output has no activation
        modules.append(nn.Linear(widths[-1], output_dim))
        self.network = nn.Sequential(*modules)

    def forward(self, x):
        """Apply the stacked layers to ``x``."""
        return self.network(x)

class ResidualBlock(nn.Module):
    """Two linear layers with a skip connection around them.

    Computes ``relu(linear2(dropout(relu(linear1(x)))) + x)``; input and
    output both have ``dim`` features.

    Args:
        dim: Feature width of the block.
    """
    def __init__(self, dim):
        super().__init__()
        self.linear1 = nn.Linear(dim, dim)
        self.linear2 = nn.Linear(dim, dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        """Return the activated sum of the transformed input and ``x``."""
        hidden = self.dropout(self.relu(self.linear1(x)))
        # Skip connection: add the untouched input back before the last ReLU
        return self.relu(self.linear2(hidden) + x)

class ResNet(nn.Module):
    """Residual MLP: input projection, residual blocks, linear output head.

    Args:
        input_dim: Number of input features.
        hidden_dim: Feature width used by the projection and every block.
        num_blocks: Number of ``ResidualBlock`` modules applied in sequence.
        output_dim: Number of output features.
    """
    def __init__(self, input_dim=9, hidden_dim=128, num_blocks=3, output_dim=3):
        super().__init__()
        self.input_layer = nn.Linear(input_dim, hidden_dim)
        self.res_blocks = nn.ModuleList(
            ResidualBlock(hidden_dim) for _ in range(num_blocks)
        )
        self.output_layer = nn.Linear(hidden_dim, output_dim)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Project ``x``, run it through each block, then map to the output."""
        features = self.relu(self.input_layer(x))
        for res_block in self.res_blocks:
            features = res_block(features)
        return self.output_layer(features)

class TransformerBlock(nn.Module):
    """Self-attention and feed-forward sub-layers, each followed by LayerNorm.

    Both sub-layers use a residual connection and normalise *after* the
    addition. Inputs are batch-first, as configured on the attention module.

    Args:
        dim: Embedding width of the block.
        num_heads: Number of attention heads.
        ff_dim: Hidden width of the feed-forward sub-layer; ``4 * dim``
            when left as ``None``.
    """
    def __init__(self, dim, num_heads=4, ff_dim=None):
        super().__init__()
        ff_dim = 4 * dim if ff_dim is None else ff_dim

        self.attention = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        self.feed_forward = nn.Sequential(
            nn.Linear(dim, ff_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(ff_dim, dim),
        )

    def forward(self, x):
        """Apply attention then feed-forward, each as residual + LayerNorm."""
        # Query, key and value are all `x`; the attention weights are discarded
        attended = self.attention(x, x, x)[0]
        x = self.norm1(x + attended)

        # Position-wise feed-forward sub-layer
        return self.norm2(x + self.feed_forward(x))

class Transformer(nn.Module):
    """Transformer-style MPC surrogate operating on a single token.

    A 2-D input ``[batch, features]`` is given a sequence axis of length 1
    before the blocks run, so self-attention sees exactly one position per
    sample in that case.

    Args:
        input_dim: Number of input features.
        embed_dim: Embedding width used by every block.
        num_heads: Attention heads per block.
        num_blocks: Number of ``TransformerBlock`` modules applied in order.
        output_dim: Number of output features.
    """
    def __init__(self, input_dim=9, embed_dim=64, num_heads=4, num_blocks=2, output_dim=3):
        super().__init__()
        self.input_embedding = nn.Linear(input_dim, embed_dim)
        self.transformer_blocks = nn.ModuleList(
            TransformerBlock(embed_dim, num_heads) for _ in range(num_blocks)
        )
        self.output_layer = nn.Linear(embed_dim, output_dim)

    def forward(self, x):
        """Embed ``x``, apply every block, and project to the output size."""
        # Insert a length-1 sequence axis: [batch, features] -> [batch, 1, features]
        tokens = x.unsqueeze(1) if x.dim() == 2 else x

        tokens = self.input_embedding(tokens)
        for transformer_block in self.transformer_blocks:
            tokens = transformer_block(tokens)

        # Drops axis 1 only when it has size 1; longer sequences pass through
        return self.output_layer(tokens.squeeze(1))

# Registry of the architectures to train: display name -> class and the
# keyword arguments used to construct it (insertion order = training order)
model_configs = {}
model_configs['MLP'] = {
    'class': MLP,
    'params': {'hidden_dims': [128, 64]},
}
model_configs['ResNet'] = {
    'class': ResNet,
    'params': {'hidden_dim': 128, 'num_blocks': 3},
}
model_configs['Transformer'] = {
    'class': Transformer,
    'params': {'embed_dim': 64, 'num_heads': 4, 'num_blocks': 2},
}

print("Available models:", [*model_configs])

## Training and Evaluation Functions

In [None]:
def train_epoch(model, train_loader, criterion, optimizer, device):
    """Run one optimisation pass over ``train_loader``.

    Args:
        model: Module to train; switched to training mode.
        train_loader: Iterable yielding ``(inputs, targets)`` batches.
        criterion: Loss callable applied to ``(outputs, targets)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.

    Returns:
        Unweighted mean of the per-batch loss values (every batch counts
        equally, whatever its size).
    """
    model.train()
    batch_losses = []

    for batch_inputs, batch_targets in train_loader:
        batch_inputs = batch_inputs.to(device)
        batch_targets = batch_targets.to(device)

        optimizer.zero_grad()
        batch_loss = criterion(model(batch_inputs), batch_targets)
        batch_loss.backward()
        optimizer.step()

        batch_losses.append(batch_loss.item())

    return sum(batch_losses) / len(batch_losses)

def evaluate(model, test_loader, criterion, device):
    """Score ``model`` on ``test_loader`` without tracking gradients.

    Args:
        model: Module to evaluate; switched to eval mode.
        test_loader: Iterable yielding ``(inputs, targets)`` batches.
        criterion: Loss callable applied to ``(outputs, targets)``.
        device: Device the batches are moved to.

    Returns:
        Dict with keys:
            ``loss``: unweighted mean of the per-batch criterion values,
            ``mse`` / ``mae``: scikit-learn metrics over all samples,
            ``predictions`` / ``targets``: NumPy arrays of every batch
            concatenated in loader order.
    """
    model.eval()
    batch_losses = []
    prediction_chunks, target_chunks = [], []

    with torch.no_grad():
        for batch_inputs, batch_targets in test_loader:
            batch_inputs = batch_inputs.to(device)
            batch_targets = batch_targets.to(device)

            batch_outputs = model(batch_inputs)
            batch_losses.append(criterion(batch_outputs, batch_targets).item())

            # Keep CPU copies so the metrics can be computed over the full set
            prediction_chunks.append(batch_outputs.cpu().numpy())
            target_chunks.append(batch_targets.cpu().numpy())

    predictions = np.concatenate(prediction_chunks)
    targets = np.concatenate(target_chunks)

    return {
        'loss': sum(batch_losses) / len(batch_losses),
        'mse': mean_squared_error(targets, predictions),
        'mae': mean_absolute_error(targets, predictions),
        'predictions': predictions,
        'targets': targets,
    }

def train_model(model_name, model, train_loader, test_loader, num_epochs=100, patience=10):
    """Train ``model`` with alternating MSE and MAE phases and early stopping.

    Every epoch runs two passes over ``train_loader`` on the same model
    and optimizer: one minimising ``nn.MSELoss`` and one minimising
    ``nn.L1Loss``. The model is evaluated on ``test_loader`` after each
    pass. Uses the module-level ``device``.

    Checkpoints are written to the working directory:
        ``{model_name}_best_mse.pth``: weights right after the MSE phase
            that produced the lowest test MSE loss,
        ``{model_name}_best_mae.pth``: weights right after the MAE phase
            that produced the lowest test MAE loss.

    Args:
        model_name: Label used in log output and checkpoint file names.
        model: Module to train (already on ``device``).
        train_loader: Loader of training batches.
        test_loader: Loader of held-out batches.
        num_epochs: Maximum number of epochs.
        patience: Stop after this many consecutive epochs without a new
            best test MSE loss.

    Returns:
        Dict with ``model_name``, ``hyperparameters``, per-epoch
        ``training_history`` and the ``evaluate`` results of the best MSE
        and best MAE epochs (``None`` if no epoch ever improved on
        infinity, e.g. when the loss is NaN).
    """
    print(f"\n=== Training {model_name} ===")

    # Loss functions
    mse_criterion = nn.MSELoss()
    mae_criterion = nn.L1Loss()

    # Optimizer
    optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)

    # Both schedulers act on the same optimizer, so either plateau can
    # halve the shared learning rate.
    mse_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)
    mae_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

    results = {
        'model_name': model_name,
        'hyperparameters': {
            'lr': 1e-3,
            'weight_decay': 1e-4,
            'batch_size': train_loader.batch_size,
            'num_epochs': num_epochs
        },
        'training_history': [],
        'best_mse_results': None,
        'best_mae_results': None
    }

    best_mse_loss = float('inf')
    best_mae_loss = float('inf')
    patience_counter = 0

    for epoch in range(num_epochs):
        start_time = time.time()

        # Train with MSE
        train_loss_mse = train_epoch(model, train_loader, mse_criterion, optimizer, device)
        mse_results = evaluate(model, test_loader, mse_criterion, device)

        # Checkpoint the MSE-best weights now: the MAE phase below keeps
        # updating the same model, so saving later would store weights
        # that no longer match `mse_results`.
        mse_improved = mse_results['loss'] < best_mse_loss
        if mse_improved:
            best_mse_loss = mse_results['loss']
            results['best_mse_results'] = mse_results.copy()
            torch.save(model.state_dict(), f'{model_name}_best_mse.pth')

        # Train with MAE
        train_loss_mae = train_epoch(model, train_loader, mae_criterion, optimizer, device)
        mae_results = evaluate(model, test_loader, mae_criterion, device)

        if mae_results['loss'] < best_mae_loss:
            best_mae_loss = mae_results['loss']
            results['best_mae_results'] = mae_results.copy()
            torch.save(model.state_dict(), f'{model_name}_best_mae.pth')

        # Update schedulers
        mse_scheduler.step(mse_results['loss'])
        mae_scheduler.step(mae_results['loss'])

        epoch_time = time.time() - start_time

        # Log results
        epoch_results = {
            'epoch': epoch + 1,
            'train_loss_mse': train_loss_mse,
            'test_loss_mse': mse_results['loss'],
            'test_mse': mse_results['mse'],
            'test_mae': mse_results['mae'],
            'train_loss_mae': train_loss_mae,
            'test_loss_mae': mae_results['loss'],
            'epoch_time': epoch_time
        }
        results['training_history'].append(epoch_results)

        # Early stopping on the MSE test loss. The decision uses the flag
        # computed *before* best_mse_loss was updated; comparing against
        # the already-updated best would count every epoch as a failure.
        if mse_improved:
            patience_counter = 0
        else:
            patience_counter += 1

        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1}/{num_epochs}, MSE Loss: {mse_results['loss']:.6f}, MAE Loss: {mae_results['loss']:.6f}, Time: {epoch_time:.2f}s")

        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

    return results

## Run Training Pipeline

In [None]:
# Train every configured architecture; results are keyed by model name
all_results = {}

for model_name, config in model_configs.items():
    # Build the network from its registry entry and move it to the device
    model = config['class'](**config['params']).to(device)

    # Report the total number of parameters
    total_params = sum(param.numel() for param in model.parameters())
    print(f"\n{model_name} - Parameters: {total_params:,}")

    # Run the training pipeline and keep its result dict
    results = train_model(model_name, model, train_loader, test_loader, num_epochs=50)
    all_results[model_name] = results

print("\n=== Training Complete ===")

## Export Results

In [None]:
# Create results directory
os.makedirs('results', exist_ok=True)


def _json_default(value):
    """Fallback encoder for values the json module cannot serialise.

    Metric values may be NumPy scalars rather than Python floats
    (depending on the library version that produced them), and NumPy
    types are not JSON-serialisable by default.
    """
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, np.ndarray):
        return value.tolist()
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")


def _without_arrays(eval_results):
    """Return a copy of an evaluation dict without the bulky per-sample arrays."""
    return {
        key: value
        for key, value in eval_results.items()
        if key not in ('predictions', 'targets')
    }


# Export comprehensive results
for model_name, results in all_results.items():
    results_file = f'results/{model_name}_results.json'

    # Build the exported dict from copies so `all_results` keeps its
    # predictions/targets and this cell can be re-run safely.
    json_results = dict(results)
    for best_key in ('best_mse_results', 'best_mae_results'):
        if json_results.get(best_key):
            json_results[best_key] = _without_arrays(json_results[best_key])

    # Serialise before opening the file so a failure cannot leave a
    # truncated JSON file behind.
    payload = json.dumps(json_results, indent=2, default=_json_default)
    with open(results_file, 'w') as f:
        f.write(payload)

    print(f"Results saved to: {results_file}")

# Export summary table
summary_data = []
for model_name, results in all_results.items():
    # The best-result entries exist but are None when no epoch improved,
    # so `or {}` is needed in addition to the .get default.
    best_mse = results.get('best_mse_results') or {}
    best_mae = results.get('best_mae_results') or {}
    row = {
        'model_type': model_name,
        'hyperparameters': results['hyperparameters'],
        'best_mse_loss': best_mse.get('loss', 'N/A'),
        'best_mae_loss': best_mae.get('loss', 'N/A'),
        'best_mse_mse': best_mse.get('mse', 'N/A'),
        'best_mae_mae': best_mae.get('mae', 'N/A'),
        'epochs_trained': len(results['training_history'])
    }
    summary_data.append(row)

summary_payload = json.dumps(summary_data, indent=2, default=_json_default)
with open('results/training_summary.json', 'w') as f:
    f.write(summary_payload)

print("Summary saved to: results/training_summary.json")

# Export model weights (already saved during training)
print("Model weights saved as .pth files")

## Results Visualization

In [None]:
# Plot the per-epoch test losses of every model on shared axes
plt.figure(figsize=(15, 10))

# (subplot position, history key, panel title, y-axis label)
curve_panels = (
    (1, 'test_loss_mse', 'MSE Loss vs Epoch', 'MSE Loss'),
    (2, 'test_loss_mae', 'MAE Loss vs Epoch', 'MAE Loss'),
)

for model_name, results in all_results.items():
    history = results['training_history']
    epochs = [entry['epoch'] for entry in history]

    for position, history_key, panel_title, y_label in curve_panels:
        plt.subplot(2, 2, position)
        plt.plot(epochs, [entry[history_key] for entry in history], label=model_name)
        plt.title(panel_title)
        plt.xlabel('Epoch')
        plt.ylabel(y_label)
        plt.legend()

plt.tight_layout()
plt.savefig('results/training_curves.png', dpi=300, bbox_inches='tight')
plt.show()


def _format_metric(eval_results, key):
    """Format ``eval_results[key]`` to six decimals, or 'N/A' without results."""
    return f"{eval_results.get(key, 'N/A'):.6f}" if eval_results else 'N/A'


# Print final results summary
print("\n=== Final Results Summary ===")
print(f"{'Model':<12} {'MSE Loss':<10} {'MAE Loss':<10} {'MSE':<10} {'MAE':<10}")
print("-" * 60)

for model_name, results in all_results.items():
    mse_results = results.get('best_mse_results', {})
    mae_results = results.get('best_mae_results', {})

    mse_loss = _format_metric(mse_results, 'loss')
    mae_loss = _format_metric(mae_results, 'loss')
    mse_val = _format_metric(mse_results, 'mse')
    mae_val = _format_metric(mae_results, 'mae')

    print(f"{model_name:<12} {mse_loss:<10} {mae_loss:<10} {mse_val:<10} {mae_val:<10}")

print("\nModel weights saved as .pth files")
print("Detailed results saved in results/ directory")

## Download Results

In [None]:
# Create zip file for easy download
!zip -r mpc_surrogate_results.zip results/ *.pth

from google.colab import files
files.download('mpc_surrogate_results.zip')