# Adaptive MPC

Model Predictive Control (MPC) is a model-based control strategy used in robotics, autonomous driving, and reinforcement learning (RL). Unlike model-free RL methods, which learn through trial and error, MPC explicitly optimizes future actions at every step by solving an optimization problem.
MPC works by:
1. Predicting future states using a learned dynamics model.
2. Optimizing a sequence of actions over a planning horizon.
3. Executing only the first action, then repeating the process at the next timestep.

The problem with standard MPC is that the horizon $H$ stays fixed throughout training, which is suboptimal:
- Short horizons lead to myopic (short-sighted) behavior.
- Long horizons increase computation time and amplify model inaccuracies, because prediction errors compound at every step of the rollout.
- Model uncertainty changes over time, but a fixed horizon does not adapt to how accurate the model currently is.
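To see why long horizons amplify model error, here is a minimal sketch on a hypothetical scalar system $s_{k+1} = a \cdot s_k$ whose learned model overestimates the true coefficient by 5% ($\hat{a} = 1.05$ versus $a = 1.0$). The function name `rollout` and all numbers are illustrative assumptions, not part of the method.

```python
def rollout(a, s0, horizon):
    """Roll the scalar linear system s_{k+1} = a * s_k forward for `horizon` steps."""
    states = [s0]
    for _ in range(horizon):
        states.append(a * states[-1])
    return states

# Hypothetical setup: the true coefficient is 1.0, the learned model predicts 1.05
true_a, model_a = 1.0, 1.05
for h in (1, 5, 20):
    # Open-loop error after h steps: the model is never corrected by real observations
    err = abs(rollout(model_a, 1.0, h)[-1] - rollout(true_a, 1.0, h)[-1])
    print(f"H={h:2d}: prediction error = {err:.3f}")
```

Starting from $s_0 = 1$, the error is $1.05^H - 1$: about 0.05 at $H=1$, 0.28 at $H=5$, and 1.65 at $H=20$, so a small one-step error grows quickly with the horizon.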

Adaptive MPC dynamically adjusts the planning horizon H based on model uncertainty.
- Low uncertainty -> Long horizon (the model's predictions can be trusted further ahead)
- High uncertainty -> Short horizon (reduces error propagation)

## Background

MPC is an optimal control algorithm that repeatedly:
1. Solves an optimization problem over a finite horizon.
2. Finds the best control sequence minimizing a cost function.
3. Executes the first control action, then repeats the process.

At each timestep t, MPC solves:
$$\min_{a_{t:t+H-1}} \sum_{k=t}^{t+H-1} \left[ C(s_k, a_k) + \lambda \|a_k\| \right]$$
subject to:
$$s_{k+1} = f(s_k, a_k)$$

where:
- $H$ is the planning horizon.
- $C(s_k,a_k)$ is the per-step cost function.
- $\lambda$ is the regularization weight that penalizes large actions.
- $f(s_k, a_k)$ is the dynamics model.
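The receding-horizon loop and the objective above can be sketched on a toy problem. This assumes a hypothetical scalar integrator $f(s, a) = s + a$, $C(s, a) = s^2$, $\lambda = 0.1$, $H = 3$, and a small discrete action set so the minimization can be solved by brute force; a real controller would use a learned $f$ and a sampling-based optimizer like the one in the Implementation section.

```python
import itertools

def dynamics(s, a):
    # Hypothetical toy model: a scalar integrator s_{k+1} = s_k + a_k
    return s + a

def cost(s, a, lam=0.1):
    # C(s, a) = s^2 plus the action penalty lambda * |a|
    return s ** 2 + lam * abs(a)

def plan(s, horizon=3, action_set=(-1.0, 0.0, 1.0)):
    """Minimize sum_{k=t}^{t+H-1} [C(s_k, a_k) + lambda*|a_k|] by exhaustive search."""
    best_seq, best_cost = None, float("inf")
    for seq in itertools.product(action_set, repeat=horizon):
        total, x = 0.0, s
        for a in seq:
            total += cost(x, a)   # cost of the state we are in and the action we take
            x = dynamics(x, a)    # predict the next state with the model
        if total < best_cost:
            best_seq, best_cost = seq, total
    return best_seq

s = 4.0
trajectory = [s]
for t in range(6):
    seq = plan(s)             # optimize a full action sequence over the horizon
    s = dynamics(s, seq[0])   # execute only the first action, then re-plan
    trajectory.append(s)
print(trajectory)
```

Each iteration plans three steps ahead but applies only the first action, driving the state from 4 to 0 one unit per step and then holding it there. The last action of every plan is 0: it only affects $s_{t+H}$, which the objective does not score, so any nonzero value just adds to the penalty.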

**Uncertainty-aware planning**:
To make MPC adaptive, we define an uncertainty score:
$$\text{Uncertainty}(s_k) = \text{Var}(\hat{s}_{k+1})$$
where $\text{Var}(\hat{s}_{k+1})$ is the variance of the next-state predictions $\hat{s}_{k+1}$ made by an ensemble of learned dynamics models.
The adaptive planning horizon is then defined as:
$$H_t = \max\left(H_{min}, \min\left(H_{max}, H_{base} \cdot e^{-\beta \cdot \text{Uncertainty}(s_t)}\right)\right)$$
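where:
- $H_{base}$ is the nominal horizon, used when uncertainty is zero.
- $H_{min}$ is the minimum horizon.
- $H_{max}$ is the maximum horizon.
- $\beta > 0$ is a sensitivity parameter that controls how quickly the horizon shrinks as uncertainty grows.

This makes MPC plan further ahead in familiar situations and plan shorter when uncertainty is high. Because $e^{-\beta \cdot \text{Uncertainty}(s_t)} \le 1$, $H_t$ never exceeds $H_{base}$, so $H_{max}$ only takes effect when $H_{base} > H_{max}$. In practice $H_t$ is also rounded down to an integer number of steps.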
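The horizon rule can be checked numerically with a small sketch. The defaults ($H_{base}=5$, $H_{min}=1$, $H_{max}=20$, $\beta=0.1$) mirror the constants hard-coded in `update_horizon` in the implementation below; the function name `adaptive_horizon` is hypothetical, and truncating to an integer follows that code's `int(...)` call.

```python
import math

def adaptive_horizon(uncertainty, h_base=5, h_min=1, h_max=20, beta=0.1):
    """H_t = max(H_min, min(H_max, H_base * exp(-beta * U))), truncated to an integer."""
    h = h_base * math.exp(-beta * uncertainty)
    return int(max(h_min, min(h_max, h)))

for u in (0.0, 1.0, 5.0, 10.0, 20.0):
    print(f"uncertainty={u:5.1f} -> horizon={adaptive_horizon(u)}")

# Only when H_base exceeds H_max does the upper clip matter
print(adaptive_horizon(0.0, h_base=30))
```

With these defaults the horizon falls from 5 at zero uncertainty to 4 at an uncertainty of 1, 3 at 5, and reaches the floor of 1 by 10.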

## Theory

**Adaptive Horizon Computation**:
At each timestep:
1. Compute the model uncertainty using an ensemble of models.
2. Adjust the horizon $H_t$ accordingly:
   - If uncertainty is low, increase the horizon.
   - If uncertainty is high, decrease the horizon.

**Ensemble-Based Uncertainty Estimation**:
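We use $N$ learned dynamics models, each predicting the next state:
$$\hat{s}_{t+1}^{(i)} = f_{\theta_i}(s_t, a_t), \quad i \in \{1, 2, \dots, N\}$$
The variance across the models' predictions gives the uncertainty:
$$\text{Uncertainty}(s_t) = \frac{1}{N d} \sum_{i=1}^{N} \sum_{j=1}^{d} \left(\hat{s}_{t+1,j}^{(i)} - \bar{s}_{t+1,j}\right)^2$$
where $\bar{s}_{t+1} = \frac{1}{N} \sum_{i=1}^{N} \hat{s}_{t+1}^{(i)}$ is the mean prediction across models, $d$ is the state dimension, and $j$ indexes state dimensions, so the per-dimension variances are averaged into a single scalar.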
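The uncertainty formula takes only a few lines of NumPy. This sketch assumes each of the $N$ predictions is a $d$-dimensional vector and averages the per-dimension population variance over the dimensions; the helper name `ensemble_uncertainty` and the example predictions are hypothetical. Note that PyTorch's `Tensor.var` divides by $N-1$ by default (Bessel's correction), so it gives slightly larger values than this $1/N$ version.

```python
import numpy as np

def ensemble_uncertainty(predictions):
    """predictions: array of shape (N, d), one next-state prediction per model.

    Returns (1 / (N d)) * sum_i sum_j (s_hat[i, j] - s_bar[j])^2.
    """
    predictions = np.asarray(predictions, dtype=float)
    mean = predictions.mean(axis=0)               # s_bar_{t+1}, shape (d,)
    return float(((predictions - mean) ** 2).mean())  # average over all N * d entries

# Three hypothetical models predicting a 2-D next state
agree = [[1.0, 2.0], [1.0, 2.0], [1.0, 2.0]]
disagree = [[0.0, 2.0], [1.0, 2.0], [2.0, 2.0]]
print(ensemble_uncertainty(agree), ensemble_uncertainty(disagree))
```

Identical predictions give zero uncertainty; spreading the first coordinate over 0, 1 and 2 gives $(1 + 0 + 1)/(3 \cdot 2) = 1/3$.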

## Implementation

```python
import numpy as np
import torch
from collections import deque
import random

class ReplayBuffer:
    """FIFO buffer of real environment transitions used to train the dynamics ensemble."""

    def __init__(self, max_size=int(1e6)):
        # The oldest transitions are dropped automatically once max_size is reached
        self.buffer = deque(maxlen=max_size)

    def add(self, state, action, reward, next_state, done):
        # Store everything as flat float32 values so batches stack cleanly
        state = np.array(state, dtype=np.float32).flatten()
        action = np.array(action, dtype=np.float32).flatten()
        next_state = np.array(next_state, dtype=np.float32).flatten()
        reward = np.float32(reward)
        done = np.float32(done)
        
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        # Uniformly sample a batch; rewards and dones are returned with shape (B, 1)
        batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        return (torch.tensor(np.stack(states), dtype=torch.float32),
                torch.tensor(np.stack(actions), dtype=torch.float32),
                torch.tensor(rewards, dtype=torch.float32).unsqueeze(1),
                torch.tensor(np.stack(next_states), dtype=torch.float32),
                torch.tensor(dones, dtype=torch.float32).unsqueeze(1))

    def __len__(self):
        return len(self.buffer)
```

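```python
import numpy as np
import torch
import torch.nn as nn

class AdaptiveMPC:
    def __init__(self, state_dim, action_dim, num_models=5, beta=0.1,
                 h_min=1, h_max=20, h_base=5, max_action=1.0):
        self.models = [self._build_model(state_dim, action_dim) for _ in range(num_models)]
        self.num_models = num_models
        self.beta = beta
        self.h_min, self.h_max, self.h_base = h_min, h_max, h_base
        self.rollout_horizon = h_base  # initial value
        self.action_dim = action_dim
        self.max_action = max_action

    def _build_model(self, state_dim, action_dim):
        # Maps (s, a) to (s', r): state_dim outputs for the next state plus 1 for the reward
        return nn.Sequential(
            nn.Linear(state_dim + action_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, state_dim + 1),
        )

    def predict(self, state, action):
        """Return the ensemble-mean next state, reward of shape (B, 1), and ensemble disagreement."""
        inputs = torch.cat([state, action], dim=-1)
        predictions = torch.stack([model(inputs) for model in self.models])  # (N, B, state_dim + 1)

        mean_prediction = predictions.mean(dim=0)
        # Variance of the next-state predictions across models, averaged over batch and state dims
        uncertainty = predictions[..., :-1].var(dim=0).mean().item()
        # No horizon update here: predict() is also called during training and inside rollouts
        return mean_prediction[..., :-1], mean_prediction[..., -1:], uncertainty

    @torch.no_grad()
    def plan(self, state, num_samples=500, num_iters=5, elite_frac=0.1):
        """Cross-entropy method over action sequences of the adaptive length H_t."""
        # Estimate Uncertainty(s_t) from one-step predictions under random probe actions,
        # then fix the horizon for this call so every sampled sequence has the same length
        state_t = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
        probe = torch.empty(64, self.action_dim).uniform_(-self.max_action, self.max_action)
        _, _, uncertainty = self.predict(state_t.expand(64, -1), probe)
        self.update_horizon(uncertainty)
        H = self.rollout_horizon

        mean = np.zeros((H, self.action_dim))
        std = np.ones((H, self.action_dim))
        num_elites = max(1, int(elite_frac * num_samples))

        for _ in range(num_iters):
            actions = np.random.normal(mean, std, (num_samples, H, self.action_dim))
            actions = np.clip(actions, -self.max_action, self.max_action)
            returns = self.evaluate_sequences(state, actions)

            # Refit the sampling distribution to the highest-return sequences
            elite_actions = actions[returns.argsort()[-num_elites:]]
            mean = elite_actions.mean(axis=0)
            std = elite_actions.std(axis=0) + 1e-6  # keep the distribution from collapsing

        return mean[0]

    @torch.no_grad()
    def evaluate_sequences(self, state, actions):
        """Predicted return of every candidate sequence, rolled out as one batch."""
        num_samples, horizon, _ = actions.shape
        states = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0).repeat(num_samples, 1)
        actions_t = torch.as_tensor(actions, dtype=torch.float32)
        total_pred_reward = torch.zeros(num_samples)

        for k in range(horizon):
            states, reward_pred, _ = self.predict(states, actions_t[:, k])
            total_pred_reward += reward_pred.squeeze(-1)

        return total_pred_reward.numpy()

    def update_horizon(self, uncertainty):
        # H_t = max(H_min, min(H_max, H_base * exp(-beta * U))), truncated to an integer.
        # Since exp(-beta * U) <= 1, the horizon never exceeds h_base.
        h = self.h_base * np.exp(-self.beta * uncertainty)
        self.rollout_horizon = int(max(self.h_min, min(self.h_max, h)))
```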
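`plan` uses the cross-entropy method (CEM): sample action sequences from a Gaussian, keep the top 10% by predicted return, refit the Gaussian to them, and repeat. The optimizer is easier to inspect on its own, so here is a minimal sketch on a hypothetical 2-D quadratic objective; `cem_maximize` and `target` are illustrative names, and the seeded `numpy.random.default_rng` is used only to make the run reproducible.

```python
import numpy as np

def cem_maximize(objective, dim, num_samples=200, num_iters=10, elite_frac=0.1, seed=0):
    """Cross-entropy method: sample from a Gaussian, refit it to the top-scoring samples."""
    rng = np.random.default_rng(seed)
    mean, std = np.zeros(dim), np.ones(dim)
    num_elites = max(1, int(elite_frac * num_samples))
    for _ in range(num_iters):
        samples = rng.normal(mean, std, size=(num_samples, dim))
        scores = np.array([objective(x) for x in samples])
        elites = samples[np.argsort(scores)[-num_elites:]]  # highest-scoring samples
        mean, std = elites.mean(axis=0), elites.std(axis=0) + 1e-6
    return mean

# Hypothetical objective with its maximum at x = (0.5, -0.3)
target = np.array([0.5, -0.3])
best = cem_maximize(lambda x: -np.sum((x - target) ** 2), dim=2)
print(best)
```

In the planner, `x` is an entire action sequence of shape `(H, action_dim)` and the objective is the ensemble's predicted return, but the sample, select, and refit loop is the same.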

```python
import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn

# Initialize environment
env = gym.make("HalfCheetah-v5")
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
max_action = float(env.action_space.high[0])

# Initialize components
adaptive_mpc = AdaptiveMPC(state_dim, action_dim)
real_buffer = ReplayBuffer(max_size=int(1e6))

# Create the optimizer once so Adam's moment estimates persist across updates
optimizer = torch.optim.Adam(
    [param for model in adaptive_mpc.models for param in model.parameters()], lr=1e-3
)
mse = nn.MSELoss()

num_episodes = 500
adaptive_mpc_rewards = []

# Training loop
for episode in range(num_episodes):
    state, _ = env.reset()
    total_reward = 0.0

    for step in range(1000):  # Max steps per episode
        action = np.clip(adaptive_mpc.plan(state), -max_action, max_action).astype(np.float32)
        # Gymnasium's step() returns (obs, reward, terminated, truncated, info)
        next_state, reward, terminated, truncated, _ = env.step(action)
        real_buffer.add(state, action, reward, next_state, terminated)
        total_reward += reward
        state = next_state

        if len(real_buffer) > 10000:
            # Fit each ensemble member on its own prediction error; training only the
            # ensemble mean would give every member the same error signal
            states, actions, rewards, next_states, _ = real_buffer.sample(256)
            inputs = torch.cat([states, actions], dim=-1)
            model_loss = 0.0
            for model in adaptive_mpc.models:
                pred = model(inputs)
                model_loss = model_loss + mse(pred[:, :-1], next_states) + mse(pred[:, -1:], rewards)

            optimizer.zero_grad()
            model_loss.backward()
            optimizer.step()

        if terminated or truncated:
            break

    adaptive_mpc_rewards.append(total_reward)
    print(f"Episode {episode}: Total Reward = {total_reward:.1f}")

# Plot results
plt.plot(adaptive_mpc_rewards, label="Adaptive MPC")
plt.xlabel("Episodes")
plt.ylabel("Total Reward")
plt.title("Adaptive MPC on HalfCheetah-v5")
plt.legend()
plt.show()
```
