# Hybrid MPC + SAC

MPC (Model Predictive Control) is strong at short-term optimization but is limited to its finite planning horizon, so it is weak at long-term strategy. SAC (Soft Actor-Critic) can learn long-term strategies but requires many environment interactions.
The idea is to combine the two: MPC handles short-term planning and provides good actions early in training, while SAC learns a policy that improves over time and generalizes to new situations.
MPC-generated rollouts are also used to train SAC, reducing the number of real environment interactions it needs (its sample complexity).

- **Early training:** MPC helps explore good actions while SAC is still learning.
- **Later training:** SAC learns a long-term strategy and gradually replaces MPC.
- **Faster convergence:** MPC rollouts speed up SAC training.



## Background

MPC solves an optimization problem at every step to find the best action sequence:
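$$\min_{a_{t:t+H-1}}\sum_{k=t}^{t+H-1}\left[C(s_k,a_k)+\lambda\|a_k\|\right]$$
subject to:
$$s_{k+1}=f(s_k,a_k)$$
where $H$ is the planning horizon, $C(s_k,a_k)$ is the cost function, $\lambda$ weights a penalty on control effort, and $f(s_k,a_k)$ is the dynamics function. Only the first action $a_t$ of the optimized sequence is applied; the problem is then re-solved from the new state at the next step.
However, MPC does not learn from past experience: each step's optimization starts from scratch, and its quality is bounded by the horizon $H$ and the accuracy of $f$.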
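The receding-horizon loop can be sketched with random shooting: sample many candidate action sequences, roll each through $f$, and apply only the first action of the cheapest one. This is a minimal sketch assuming a known 1-D double-integrator model (position, velocity), a quadratic stage cost, and a uniform sampler; the model, cost weights, and helper names (`f`, `cost`, `mpc_random_shooting`) are illustrative assumptions. In the hybrid method $f$ would be the learned dynamics model.

```python
import numpy as np

DT = 0.1  # integration step of the assumed double-integrator model

def f(s, a):
    """Assumed dynamics s_{k+1} = f(s_k, a_k) for state (position, velocity)."""
    pos, vel = s[..., 0], s[..., 1]
    return np.stack([pos + DT * vel, vel + DT * a], axis=-1)

def cost(s, a, lam=0.01):
    """Stage cost C(s, a) + lam * |a|: drive position and velocity to zero."""
    return s[..., 0] ** 2 + 0.1 * s[..., 1] ** 2 + lam * np.abs(a)

def mpc_random_shooting(s0, horizon=10, num_seqs=500, a_max=1.0, rng=None):
    """Sample action sequences, roll each out through f, return the first action of the cheapest."""
    rng = np.random.default_rng() if rng is None else rng
    seqs = rng.uniform(-a_max, a_max, size=(num_seqs, horizon))
    s = np.repeat(s0[None, :], num_seqs, axis=0)  # one copy of the state per sequence
    total = np.zeros(num_seqs)
    for k in range(horizon):
        total += cost(s, seqs[:, k])
        s = f(s, seqs[:, k])
    return seqs[np.argmin(total), 0]  # receding horizon: apply only the first action

rng = np.random.default_rng(0)
s = np.array([1.0, 0.0])  # start one unit away from the target, at rest
for t in range(100):
    a = mpc_random_shooting(s, rng=rng)
    s = f(s, a)  # re-plan from the new state at every step
print(s)
```

Random shooting is the simplest optimizer for this problem; gradient-based or cross-entropy-method planners are common alternatives when the action space is larger.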

SAC is an off-policy actor-critic algorithm that learns a stochastic policy by maximizing:
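$$J(\pi)=\sum_{t}\mathbb{E}_{(s_t,a_t)\sim\rho_\pi}\left[r(s_t,a_t)+\alpha\,\mathcal{H}(\pi(\cdot\mid s_t))\right]$$
where $\rho_\pi$ is the state-action distribution induced by $\pi$, $\mathcal{H}(\pi(\cdot\mid s_t))$ is the policy entropy, which encourages exploration, and the temperature $\alpha$ trades off reward against entropy. SAC uses two critic networks $Q_1, Q_2$ that estimate the Q-value function (the minimum of the two is used in the target to reduce overestimation), and an actor network $\pi$ that samples continuous actions from a Gaussian policy squashed by $\tanh$ to the action bounds.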
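A minimal PyTorch sketch of the two SAC pieces just described: a tanh-squashed Gaussian actor (with the log-probability correction for the squashing) and the soft Bellman target built from the minimum of two target critics. Unlike the deterministic `Actor` class in the Implementation section, this actor is stochastic. Layer sizes, the fixed temperature `alpha=0.2`, and the names `GaussianActor` and `soft_target` are illustrative assumptions.

```python
import torch
import torch.nn as nn

class GaussianActor(nn.Module):
    """Samples a = max_action * tanh(u), with u ~ N(mu(s), sigma(s)^2)."""
    def __init__(self, state_dim, action_dim, max_action=1.0, hidden=256):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(state_dim, hidden), nn.ReLU(),
                                 nn.Linear(hidden, hidden), nn.ReLU())
        self.mu = nn.Linear(hidden, action_dim)
        self.log_std = nn.Linear(hidden, action_dim)
        self.max_action = max_action

    def forward(self, state):
        h = self.net(state)
        mu = self.mu(h)
        log_std = self.log_std(h).clamp(-20, 2)  # keep the std in a numerically safe range
        dist = torch.distributions.Normal(mu, log_std.exp())
        u = dist.rsample()  # reparameterized sample, so gradients flow into the actor
        a = torch.tanh(u)
        # Change of variables for tanh: log pi(a|s) = log N(u) - sum log(1 - tanh(u)^2)
        log_prob = (dist.log_prob(u) - torch.log(1 - a.pow(2) + 1e-6)).sum(-1, keepdim=True)
        # Rescaling by max_action only shifts log_prob by a constant, omitted here
        return self.max_action * a, log_prob

def soft_target(reward, done, next_state, actor, q1_target, q2_target, gamma=0.99, alpha=0.2):
    """y = r + gamma * (1 - done) * (min(Q1', Q2')(s', a') - alpha * log pi(a'|s')), a' ~ pi(.|s')."""
    with torch.no_grad():
        next_action, next_log_prob = actor(next_state)
        sa = torch.cat([next_state, next_action], dim=-1)
        q_next = torch.min(q1_target(sa), q2_target(sa))  # clipped double-Q
        return reward + gamma * (1 - done) * (q_next - alpha * next_log_prob)

# Usage on a random batch
state_dim, action_dim, batch = 3, 1, 4
actor = GaussianActor(state_dim, action_dim, max_action=2.0)
q1 = nn.Sequential(nn.Linear(state_dim + action_dim, 64), nn.ReLU(), nn.Linear(64, 1))
q2 = nn.Sequential(nn.Linear(state_dim + action_dim, 64), nn.ReLU(), nn.Linear(64, 1))
s = torch.randn(batch, state_dim)
a, logp = actor(s)
y = soft_target(torch.zeros(batch, 1), torch.zeros(batch, 1), s, actor, q1, q2)
print(a.shape, logp.shape, y.shape)
```

The `- alpha * log pi` term is the entropy bonus from $J(\pi)$ evaluated at the next state; in many implementations $\alpha$ is also learned rather than fixed.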

For the hybrid approach we define a weighting function $w(t) \in [0, 1]$ that determines how much each controller contributes to the executed action:
$$a_t = w(t) \cdot a_{MPC} + (1 - w(t)) \cdot a_{SAC}$$
$w(t)$ starts high, favoring MPC early on, and decreases over training so that SAC gradually takes control.

## Theory

The hybrid approach works as follows:
1. At the beginning of training, the SAC policy is essentially random and unreliable. MPC, using a dynamics model, provides higher-quality actions through direct optimization (provided the model is reasonably accurate).
2. Over time, SAC learns from both real and MPC-generated experience and gradually improves; because it optimizes the long-term (discounted) return rather than a finite horizon, it can eventually outperform MPC.
3. A blending coefficient $w(t)$ controls the balance between MPC and SAC actions. Initially $w(t) \approx 1$, relying on MPC; gradually $w(t) \rightarrow 0$, transitioning entirely to the learned SAC policy.

At timestep $t$, the selected action $a_t$ is given by:
$$a_t = w(t) \cdot a_{MPC,t} + (1-w(t)) \cdot a_{SAC}(s_t)$$
where $a_{MPC,t}$ is the action optimized by MPC, $a_{SAC}(s_t) = \pi(s_t)$ is the action suggested by the SAC policy, and $w(t)$ is a time-dependent weight that decreases from 1 to 0 over training.
Two common schedules for $w(t)$ are:
1. Linear decay: $w(t) = \max(0, 1-\lambda t)$ with a small $\lambda > 0$, which reaches 0 at $t = 1/\lambda$.
2. Exponential decay: $w(t) = e^{-\lambda t}$ for some $\lambda > 0$, which approaches but never reaches 0.
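Both schedules and the blending rule fit in a few lines. In this sketch $\lambda = 10^{-5}$ is chosen to match the `1 - t/100000` decay used in `select_action` in the Implementation section; the helper names `linear_weight`, `exp_weight`, and `blend` are hypothetical.

```python
import math
import numpy as np

def linear_weight(t, lam=1e-5):
    """w(t) = max(0, 1 - lam * t); reaches exactly 0 at t = 1 / lam."""
    return max(0.0, 1.0 - lam * t)

def exp_weight(t, lam=1e-5):
    """w(t) = exp(-lam * t); decays smoothly but never reaches 0."""
    return math.exp(-lam * t)

def blend(a_mpc, a_sac, w):
    """a_t = w * a_MPC + (1 - w) * a_SAC."""
    return w * np.asarray(a_mpc) + (1.0 - w) * np.asarray(a_sac)

for t in (0, 50_000, 100_000, 200_000):
    print(t, linear_weight(t), round(exp_weight(t), 3))
```

Because the exponential schedule never hits 0, MPC keeps a small influence forever (and its cost is still paid each step) unless $w(t)$ is floored to 0 below some threshold; the linear schedule hands over fully at a known step.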


## Mathematical formulation

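MPC solves an optimization problem at each timestep:
$$\min_{a_{t:t+H-1}} \mathbb{E}_{\hat{s}_{t:t+H}}\left[\sum_{k=0}^{H-1} C(s_{t+k}, a_{t+k})\right]$$
subject to:
$$s_{k+1} = f(s_k, a_k)$$
Here $f(s,a)$ is a learned dynamics model, and the expectation is taken over the predicted states $\hat{s}_{t:t+H}$ because a learned model's predictions are uncertain.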
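One way to approximate this expectation, assuming the learned model is an ensemble, is to roll the same action sequence through each member and average the resulting $H$-step costs. In this sketch, small linear models $f_i(s,a) = A_i s + B_i a$ stand in for the networks; the models, the quadratic cost, and the names `expected_cost` and `make_linear_model` are hypothetical.

```python
import numpy as np

def expected_cost(s_t, actions, models, cost):
    """Estimate E[sum_{k=0}^{H-1} C(s_{t+k}, a_{t+k})] by averaging over ensemble rollouts."""
    totals = []
    for f in models:                # one rollout per ensemble member
        s, total = s_t, 0.0
        for a in actions:           # actions = (a_t, ..., a_{t+H-1})
            total += cost(s, a)
            s = f(s, a)             # s_{k+1} = f_i(s_k, a_k)
        totals.append(total)
    return float(np.mean(totals))

def make_linear_model(A, B):
    return lambda s, a: A @ s + B @ a

cost = lambda s, a: float(s @ s + 0.1 * a @ a)
rng = np.random.default_rng(0)
A0, B0 = np.array([[1.0, 0.1], [0.0, 1.0]]), np.array([[0.0], [0.1]])
# Hypothetical ensemble: the nominal model plus small random perturbations of A
models = [make_linear_model(A0 + 0.01 * rng.standard_normal((2, 2)), B0) for _ in range(5)]
s_t = np.array([1.0, 0.0])
actions = [np.array([-1.0])] * 10   # H = 10
print(expected_cost(s_t, actions, models, cost))
```

When all members agree, the estimate collapses to the single-model cost; disagreement between members is exactly what the uncertainty measure below quantifies.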
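We measure model uncertainty with an ensemble of $N$ predictive models $f_{\theta_1}, \dots, f_{\theta_N}$, as the average squared distance of each member's prediction from the ensemble mean $\bar{f}(s_t,a_t) = \frac{1}{N}\sum_{i=1}^{N} f_{\theta_i}(s_t,a_t)$:
$$\mathrm{Uncertainty}(s_t,a_t) = \frac{1}{N} \sum_{i=1}^{N} \|f_{\theta_i}(s_t,a_t) - \bar{f}(s_t,a_t)\|^2$$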
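A small NumPy check of the formula, with made-up two-dimensional predictions from the ensemble members; the function name `ensemble_uncertainty` is hypothetical.

```python
import numpy as np

def ensemble_uncertainty(preds):
    """preds: (N, state_dim) next-state predictions f_{theta_i}(s, a) from N ensemble members.
    Returns (1/N) * sum_i ||f_i - f_bar||^2, where f_bar is the ensemble mean."""
    preds = np.asarray(preds, dtype=float)
    f_bar = preds.mean(axis=0)
    return float(np.mean(np.sum((preds - f_bar) ** 2, axis=-1)))

print(ensemble_uncertainty([[0.0, 0.0], [2.0, 0.0]]))  # mean (1, 0); each member is 1 away -> 1.0
print(ensemble_uncertainty([[1.0, 2.0]] * 5))          # all members agree -> 0.0
```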


## Implementation

In [2]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import gymnasium as gym
import matplotlib.pyplot as plt
from collections import deque
import random
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
class ReplayBuffer:
    def __init__(self, capacity=10000000):
        self.buffer = deque(maxlen=capacity)

    def add(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        # Uniformly sample a minibatch and stack each field into a float tensor
        batch = random.sample(self.buffer, batch_size)
        state, action, reward, next_state, done = zip(*batch)
        return (torch.tensor(np.array(state), dtype=torch.float),
                torch.tensor(np.array(action), dtype=torch.float),
                torch.tensor(np.array(reward), dtype=torch.float).unsqueeze(1),
                torch.tensor(np.array(next_state), dtype=torch.float),
                # done flags as floats so they can be used directly in (1 - done) masks
                torch.tensor(np.array(done), dtype=torch.float).unsqueeze(1))
    
    def __len__(self):
        return len(self.buffer)

In [5]:
class Actor(nn.Module):
    def __init__(self, state_dim, action_dim, max_action):
        super().__init__()
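        # Deterministic policy network (SAC's actor is stochastic); tanh is applied once, in forward()
        self.model = nn.Sequential(
            nn.Linear(state_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, action_dim),
        )
        self.max_action = max_action

    def forward(self, state):
        # Squash to [-1, 1] and rescale to the action bounds
        return self.max_action * torch.tanh(self.model(state))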

In [6]:
class Critic(nn.Module):
    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.q_net = nn.Sequential(
            nn.Linear(state_dim + action_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 1)
        )

    def forward(self, state, action):
        return self.q_net(torch.cat([state, action], dim=-1))

In [7]:
class EnsembleDynamicsModel(nn.Module):
    def __init__(self, state_dim, action_dim, num_models=5):
        super().__init__()
        self.models = nn.ModuleList([
            nn.Sequential(
                nn.Linear(state_dim + action_dim, 256), nn.ReLU(),
                nn.Linear(256, 256), nn.ReLU(),
                nn.Linear(256, state_dim)
            ) for _ in range(num_models)
        ])

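    def predict(self, state, action):
        sa = torch.cat([state, action], dim=-1)
        # Member predictions stacked along a new leading axis: (num_models, ..., state_dim)
        preds = torch.stack([m(sa) for m in self.models])
        mean_pred = preds.mean(0)  # ensemble mean f_bar(s, a)
        # (1/N) * sum_i ||f_i(s, a) - f_bar(s, a)||^2, averaged over the batch
        uncertainty = (preds - mean_pred).pow(2).sum(-1).mean(0).mean().item()
        return mean_pred, uncertainty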

In [None]:
class HybridMPCSACAgent:
    def __init__(self, env):
        self.env = env
        self.state_dim = env.observation_space.shape[0]
        self.action_dim = env.action_space.shape[0]
        # Assumes a symmetric Box action space with bounds [-max_action, max_action]
        self.max_action = float(env.action_space.high[0])

        self.actor = Actor(self.state_dim, self.action_dim, self.max_action)
        self.critic = Critic(self.state_dim, self.action_dim)
        self.dynamics_model = EnsembleDynamicsModel(self.state_dim, self.action_dim)
        self.buffer = ReplayBuffer()
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=3e-4)
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=3e-4)

        self.horizon = 5          # planning horizon H (unused by the one-step search below)
        self.mpc_weight = 1.0     # current blending weight w(t)
        self.mpc_decay = 1e-5     # linear decay rate: w(t) reaches 0 after 100,000 steps

    @torch.no_grad()
    def mpc_action(self, state, num_candidates=100):
        # Random shooting: sample candidate actions uniformly within the action bounds
        s = torch.as_tensor(state, dtype=torch.float).expand(num_candidates, -1)
        candidates = torch.empty(num_candidates, self.action_dim).uniform_(-self.max_action, self.max_action)
        # Predict next states with the ensemble mean (the uncertainty estimate is not used here)
        next_s, _ = self.dynamics_model.predict(s, candidates)
        # Score each candidate by the critic's value of the predicted next state under the
        # current policy, Q(s', pi(s')); there is no reward model, so r(s, a) is not included
        scores = self.critic(next_s, self.actor(next_s)).squeeze(-1)
        return candidates[scores.argmax()].numpy()

    def select_action(self, state, t):
        # Linear decay: w(t) = max(0, 1 - mpc_decay * t)
        w = max(0.0, 1.0 - self.mpc_decay * t)
        self.mpc_weight = w
        with torch.no_grad():
            sac_action = self.actor(torch.as_tensor(state, dtype=torch.float)).numpy()
        if w == 0.0:
            return sac_action  # SAC has full control; skip the MPC search
        action = w * self.mpc_action(state) + (1 - w) * sac_action
        return np.clip(action, -self.max_action, self.max_action)
    
    def train(self):
        states, actions, rewards, next_states, dones = self.buffer.sample(128)