# Model Predictive Control (MPC) with RL

Model Predictive Control is a planning-based approach: at each time step the agent optimizes a sequence of future actions over a finite horizon to maximize expected reward, executes only the first action of that sequence, and re-plans at the next step. This "plan-at-every-step" strategy lets the agent incorporate new state feedback at each step, which mitigates errors from an imperfect model. MPC requires a dynamics model of the environment to predict how actions affect future states and rewards.
Instead of learning a policy directly, MPC optimizes actions online using a learned dynamics model. It is used in self-driving cars, robots, and drones. It addresses long-horizon planning in RL by actively searching for good actions at each timestep.
Rather than training an actor network, MPC selects the action at each step by solving:
$$a_t^*,a_{t+1}^*,...,a_{t+H}^* = \arg\max_{a_t,a_{t+1},...,a_{t+H}} \sum_{i=0}^{H}\gamma^i R(s_{t+i},a_{t+i})$$
where:
 - $H$ is the planning horizon
 - $R(s,a)$ is the reward function
 - $\gamma$ is the discount factor
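
To make the objective concrete, here is a minimal sketch of receding-horizon planning on a toy problem. The dynamics `step`, the reward `reward`, and the discrete action set `ACTIONS` are assumptions made up for this illustration and are not part of the notebook's environment. The planner enumerates every sequence $a_t, \dots, a_{t+H}$, scores it with the discounted sum above, and only the first action of the best sequence is executed before re-planning.

```python
import itertools

ACTIONS = (-1.0, 0.0, 1.0)  # hypothetical discrete action set (assumption)

def step(s, a):
    """Toy known dynamics (assumption): the action shifts the state."""
    return s + a

def reward(s, a):
    """Toy reward R(s, a) (assumption): penalize squared distance from 0 after acting."""
    return -(s + a) ** 2

def plan(state, horizon=3, gamma=0.99):
    """Enumerate every sequence a_t..a_{t+H}; return the one with the highest discounted return."""
    best_seq, best_return = None, float("-inf")
    # The sum in the objective runs over i = 0..H, i.e. H + 1 actions.
    for seq in itertools.product(ACTIONS, repeat=horizon + 1):
        s, total = state, 0.0
        for i, a in enumerate(seq):
            total += (gamma ** i) * reward(s, a)
            s = step(s, a)
        if total > best_return:
            best_seq, best_return = seq, total
    return best_seq, best_return

def run_mpc(state, steps=5):
    """Receding-horizon loop: plan, execute only the first action, then re-plan."""
    trajectory = [state]
    for _ in range(steps):
        seq, _ = plan(state)
        state = step(state, seq[0])
        trajectory.append(state)
    return trajectory

trajectory = run_mpc(3.0)
print(trajectory)
```

Exhaustive enumeration is only feasible here because the action set is tiny; with continuous actions the search has to be approximated by sampling.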

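MPC differs from MBPO (Model-Based Policy Optimization) in four ways:
 - Action selection: MPC optimizes an action sequence at every step, whereas MBPO queries a trained policy network.
 - Use of the model: MPC predicts future states online from the current state, whereas MBPO uses the model to generate synthetic rollouts for policy training.
 - When the computation happens: MPC plans online at every step, whereas MBPO does its optimization during training.
 - Where it works best: MPC is best suited to structured environments, whereas MBPO is the better fit for high-dimensional action spaces.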

## Random Shooting
Random shooting is the simplest MPC planning method. It samples many random action sequences, simulates each one with the model to compute its total reward, and executes the first action of the best sequence. It is easy to implement and parallelize but can be inefficient: it wastes samples on poor sequences and may need a very large number of samples to find a good plan.
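
The sample-efficiency point can be illustrated with a small NumPy sketch. It assumes toy, known dynamics `s' = s + a` and a reward of `-(s')**2`; both are hypothetical stand-ins for a learned model. All candidate sequences are rolled out in one vectorized pass, and `best_so_far[k - 1]` is the best return found using only the first `k` samples.

```python
import numpy as np

def rollout_returns(s0, action_seqs, gamma=0.99):
    """Discounted return of each row of action_seqs under toy dynamics s' = s + a."""
    num_samples, horizon = action_seqs.shape
    s = np.full(num_samples, s0, dtype=float)
    total = np.zeros(num_samples)
    for t in range(horizon):
        s = s + action_seqs[:, t]          # toy dynamics (assumption)
        total += (gamma ** t) * -(s ** 2)  # toy reward (assumption)
    return total

rng = np.random.default_rng(0)
# Sample all candidate sequences uniformly in [-1, 1], as random shooting does.
seqs = rng.uniform(-1.0, 1.0, size=(10_000, 5))
returns = rollout_returns(2.0, seqs)

# Running maximum: best return available after evaluating the first k samples.
best_so_far = np.maximum.accumulate(returns)
for k in (10, 100, 1_000, 10_000):
    print(f"best of first {k:>6} samples: {best_so_far[k - 1]:.4f}")

# MPC executes only the first action of the best sequence.
first_action = seqs[np.argmax(returns), 0]
```

The running maximum can only improve as samples are added, but each improvement tends to cost more samples than the last, which is the inefficiency described above.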

### Implementation

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np

class MPCRandomShooting:
    def __init__(self, model, action_dim, horizon=5, num_samples=500, discount=0.99):
        self.model = model
        self.action_dim = action_dim
        self.horizon = horizon
        self.num_samples = num_samples
        self.discount = discount

    def predict_trajectory(self, state, actions):
        state = state.clone()
        total_rewards = np.zeros(self.num_samples)

        for t in range(self.horizon):
            action = actions[:, t, :]
            next_state, reward = self.model.predict(state, action)
            total_rewards += (self.discount ** t) * reward.squeeze().detach().cpu().numpy()
            state = next_state

        return total_rewards

    def select_action(self, state):
        state = state.unsqueeze(0).repeat(self.num_samples,1)
        actions = torch.rand(self.num_samples, self.horizon, self.action_dim) * 2 - 1
        rewards = self.predict_trajectory(state, actions)
        best_idx = np.argmax(rewards)
        return actions[best_idx, 0, :].unsqueeze(0)

In [None]:
import torch.nn as nn
class EnsembleDynamicsModel(nn.Module):
    def __init__(self, state_dim, action_dim, num_models=5, hidden_dim=256):
        super(EnsembleDynamicsModel, self).__init__()
        # Each ensemble member maps (state, action) to (next_state, reward).
        self.models = nn.ModuleList([nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, state_dim + 1)
        ) for _ in range(num_models)])
        self.num_models = num_models

    def forward(self, state, action):
        # Returns the predictions of every ensemble member as two tuples.
        x = torch.cat([state, action], dim=-1)
        outputs = [model(x) for model in self.models]
        next_states, rewards = zip(*[(out[..., :-1], out[..., -1:]) for out in outputs])
        return next_states, rewards

    def predict(self, state, action):
        # Uses one randomly chosen ensemble member for the prediction.
        model_idx = np.random.randint(self.num_models)
        x = torch.cat([state, action], dim=-1)
        output = self.models[model_idx](x)
        next_state, reward = output[..., :-1], output[..., -1:]
        return next_state, reward

In [None]:
import gymnasium as gym
env = gym.make('HalfCheetah-v5')
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]

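# Note: this model is randomly initialized and never trained in this cell, so the
# planner's predictions are not meaningful until the model is fit to collected transitions.
dynamics_model = EnsembleDynamicsModel(state_dim, action_dim)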
mpc_controller = MPCRandomShooting(dynamics_model, action_dim)

num_episodes = 100
for episode in range(num_episodes):
    state, _ = env.reset()
    episode_reward = 0

    for step in range(1000):
        # env.step expects a NumPy array, so convert the planner's tensor output.
        action = mpc_controller.select_action(torch.tensor(state, dtype=torch.float32)).squeeze(0).numpy()
        next_state, reward, terminated, truncated, _ = env.step(action)
        state = next_state
        episode_reward += reward

        done = terminated or truncated
        if done:
            break

    print(f"Episode {episode+1}, Reward: {episode_reward}")
env.close()



## Cross-Entropy Method (CEM)
CEM improves on random shooting through iterative refinement. Instead of sampling action sequences from a fixed distribution, it updates the sampling distribution to focus on high-performing sequences. In practice, CEM usually assumes a parameterized distribution over action sequences, such as the per-timestep Gaussian used in the implementation below. The process works as follows:
1. Sample a population of $M$ random action sequences from a proposal distribution $p(A)$.
2. Evaluate each action sequence by simulating it on the model from the current state and summing the rewards (compute $J(A_i)$ for each sequence $A_i$).
3. Select the elites: the top $K$ sequences with the highest returns.
4. Update the distribution $p(A)$ by fitting it to these elite samples.
5. Repeat steps 1-4 until convergence or for a fixed number of iterations.
6. Output the best sequence found (or the mean of the final elite distribution) as the plan, and execute the first action of this plan.

This algorithm concentrates the search around promising regions of the action space, which typically makes planning more efficient than one-shot random shooting. Its advantage is that it usually needs far fewer samples to find good action sequences, because it iteratively zooms in on high-reward areas rather than repeatedly sampling uniformly at random. CEM is still simple and highly parallelizable, and it has been used successfully in MPC-based RL.
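
The refinement loop can be seen in isolation with a minimal NumPy sketch. It assumes toy, known dynamics `s' = s + a` with reward `-(s')**2` (hypothetical, not the learned model used in this notebook) and a diagonal Gaussian proposal. The function name `cem_plan` and its parameters are illustrative only.

```python
import numpy as np

def rollout_returns(s0, action_seqs, gamma=0.99):
    """Discounted return of each row of action_seqs under toy dynamics s' = s + a."""
    s = np.full(action_seqs.shape[0], s0, dtype=float)
    total = np.zeros(action_seqs.shape[0])
    for t in range(action_seqs.shape[1]):
        s = s + action_seqs[:, t]          # toy dynamics (assumption)
        total += (gamma ** t) * -(s ** 2)  # toy reward (assumption)
    return total

def cem_plan(s0, horizon=5, pop_size=200, elite_num=20, iterations=5, seed=0):
    rng = np.random.default_rng(seed)
    mean, std = np.zeros(horizon), np.ones(horizon)
    history = []  # mean population return per iteration
    for _ in range(iterations):
        # 1. Sample a population from the current proposal and clip to the action bounds.
        seqs = np.clip(rng.normal(mean, std, size=(pop_size, horizon)), -1.0, 1.0)
        # 2. Evaluate every sequence on the model.
        returns = rollout_returns(s0, seqs)
        # 3. Keep the elite_num sequences with the highest returns.
        elites = seqs[np.argsort(returns)[-elite_num:]]
        # 4. Refit the Gaussian to the elites; the small constant keeps std positive.
        mean, std = elites.mean(axis=0), elites.std(axis=0) + 1e-6
        history.append(returns.mean())
    return mean, std, history

mean, std, history = cem_plan(2.0)
print("planned first action:", mean[0])
print("mean population return per iteration:", np.round(history, 3))
```

Starting from state 2.0, the proposal shifts toward negative first actions and its standard deviation shrinks, so later populations score much better on average than the initial one.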

### Implementation

In [None]:
class CEMPlanner:
    def __init__(self, act_dim, horizon, pop_size=500, elite_frac=0.1, cem_iterations=4, action_bounds=None, init_mean=None, init_std=None):
        self.act_dim = act_dim
        self.horizon = horizon
        self.pop_size = pop_size
        self.elite_num = int(pop_size * elite_frac)
        self.cem_iterations = cem_iterations
        self.action_bounds = action_bounds

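        # Compare against None explicitly: the truth value of a NumPy array is ambiguous.
        self.init_mean = init_mean if init_mean is not None else np.zeros((horizon, act_dim))
        self.init_std = init_std if init_std is not None else np.ones((horizon, act_dim))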

        self.curr_mean = self.init_mean.copy()
        self.curr_std = self.init_std.copy()
    
    def plan(self, initial_state, model):
        mean = self.curr_mean.copy()
        std = self.curr_std.copy()
        best_return = -np.inf
        best_sequence = None

        for _ in range(self.cem_iterations):
            action_sequences = np.random.normal(loc=mean, scale=std, size=(self.pop_size, self.horizon, self.act_dim))
            if self.action_bounds is not None:
                low, high = self.action_bounds
                action_sequences = np.clip(action_sequences, low, high)

            returns = np.zeros(self.pop_size)
            for i in range(self.pop_size):
                seq = action_sequences[i]
                total_reward = 0.0
                state = torch.tensor(initial_state, dtype=torch.float32)
                for a in seq:
                    action = torch.tensor(a, dtype=torch.float32)
                    next_state, reward = model.predict(state, action)
                    total_reward += reward.item()
                    state = next_state

                returns[i] = total_reward

            elite_indices = returns.argsort()[-self.elite_num:]
            elite_seqs = action_sequences[elite_indices]

            mean = np.mean(elite_seqs, axis=0)
            std = np.std(elite_seqs, axis=0)

            max_idx = elite_indices[np.argmax(returns[elite_indices])]
            if returns[max_idx] > best_return:
                best_return = returns[max_idx]
                best_sequence = action_sequences[max_idx]

        return best_sequence



In [9]:
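# Re-create the environment: the previous cell closed it with env.close().
env = gym.make('HalfCheetah-v5')
obs, _ = env.reset()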
done = False

act_dim = env.action_space.shape[0]
horizon = 20
planner = CEMPlanner(act_dim, horizon, pop_size=500, elite_frac=0.1, cem_iterations=4, action_bounds=(env.action_space.low, env.action_space.high))
total_reward = 0.0
t = 0
while not done:
    current_state = obs
    plan_sequence = planner.plan(current_state, dynamics_model)
    action = plan_sequence[0]
    obs, reward, terminated, truncated, _ = env.step(action)
    total_reward += reward
    done = terminated or truncated
    t += 1
    if t % horizon == 0:
        planner.curr_mean = planner.init_mean.copy()
        planner.curr_std = planner.init_std.copy()

print("Episode finished with total reward: ", total_reward)
env.close()



## Comparison: MPC vs MBPO vs SAC

In [10]:
import random  # needed by ReplayBuffer.sample
from collections import deque


class ReplayBuffer:
    """ Stores transitions for model training and policy optimization """
    def __init__(self, max_size=int(1e6)):
        self.buffer = deque(maxlen=max_size)

    def add(self, state, action, reward, next_state, done):
        state = np.array(state, dtype=np.float32).flatten()
        action = np.array(action, dtype=np.float32).flatten()
        next_state = np.array(next_state, dtype=np.float32).flatten()
        reward = np.float32(reward)
        done = np.float32(done)
        
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        return (torch.tensor(np.stack(states), dtype=torch.float32),
                torch.tensor(np.stack(actions), dtype=torch.float32),
                torch.tensor(rewards, dtype=torch.float32).unsqueeze(1),
                torch.tensor(np.stack(next_states), dtype=torch.float32),
                torch.tensor(dones, dtype=torch.float32).unsqueeze(1))

    def __len__(self):
        return len(self.buffer)

In [11]:
#MPC with CEM
class CEMPlanner:
    def __init__(self, act_dim, horizon=5, pop_size=500, elite_frac=0.1, cem_iters=4, action_bounds=None):
        self.act_dim = act_dim
        self.horizon = horizon
        self.pop_size = pop_size
        self.elite_num = int(pop_size * elite_frac)
        self.cem_iters = cem_iters
        self.action_bounds = action_bounds

    def plan(self, initial_state, model):
        mean = np.zeros((self.horizon, self.act_dim))
        std = np.ones((self.horizon, self.act_dim))

        for _ in range(self.cem_iters):
            actions = np.random.normal(mean, std, (self.pop_size, self.horizon, self.act_dim))
            if self.action_bounds:
                actions = np.clip(actions, self.action_bounds[0], self.action_bounds[1])

            returns = np.zeros(self.pop_size)
            for i in range(self.pop_size):
                returns[i] = self.evaluate_sequence(initial_state, actions[i], model)
            
            elite_indices = returns.argsort()[-self.elite_num:]
            elite_actions = actions[elite_indices]

            mean = elite_actions.mean(axis=0)
            std = elite_actions.std(axis=0)

        return mean[0]
    
    def evaluate_sequence(self, state, actions, model):
        total_reward = 0.0
        for action in actions:
            state, reward = model.predict(state, action)
            total_reward += reward
        return total_reward


In [12]:
#MBPO
class DynamicsModel(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super(DynamicsModel, self).__init__()
        self.fc1 = nn.Linear(state_dim + action_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, state_dim + 1)

    def forward(self, state, action):
        x = torch.cat([state, action], dim=-1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        next_state, reward = x[..., :-1], x[..., -1:]
        return next_state, reward

    def predict(self, state, action):
        # Accepts NumPy arrays or tensors, batched or unbatched.
        state_tensor = torch.as_tensor(state, dtype=torch.float32)
        action_tensor = torch.as_tensor(action, dtype=torch.float32)
        with torch.no_grad():
            next_state, reward = self(state_tensor, action_tensor)
        # Drop only the trailing reward dimension; squeeze(0) raises on unbatched states.
        return next_state.numpy(), reward.squeeze(-1).numpy()

In [13]:
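def generate_model_rollouts(model, policy, real_buffer, model_buffer, rollout_length=5, num_rollouts=500):
    """Branch short model rollouts from real states and store them in model_buffer."""
    model.eval()
    policy.eval()

    # Start every rollout from a real visited state (needs len(real_buffer) >= num_rollouts).
    states, _, _, _, _ = real_buffer.sample(num_rollouts)
    with torch.no_grad():
        for _ in range(rollout_length):
            # Step all rollouts as one batch so state and action shapes stay consistent.
            actions, _ = policy.sample_action(states)
            next_states, rewards = model(states, actions)
            for s, a, r, ns in zip(states, actions, rewards, next_states):
                model_buffer.add(s.numpy(), a.numpy(), r.item(), ns.numpy(), False)
            states = next_states

    model.train()
    policy.train()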

In [14]:

class Actor(nn.Module):
    def __init__(self, state_dim, action_dim, max_action, hidden_dim=256):
        super(Actor, self).__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.mean = nn.Linear(hidden_dim, action_dim)
        self.log_std = nn.Linear(hidden_dim, action_dim)
        self.max_action = max_action

    def forward(self, state):
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        mean = self.mean(x)
        log_std = torch.clamp(self.log_std(x), -20, 2)
        return mean, log_std
    
    def sample_action(self, state):
        mean, log_std = self.forward(state)
        std = log_std.exp()
        dist = torch.distributions.Normal(mean, std)
        z = dist.rsample()
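        tanh_z = torch.tanh(z)
        action = tanh_z * self.max_action
        # Change-of-variables correction for tanh squashing, computed on the unscaled
        # tanh output (the constant log(max_action) term is omitted).
        log_prob = dist.log_prob(z) - torch.log(1 - tanh_z.pow(2) + 1e-6)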
        return action, log_prob.sum(-1, keepdim=True)

In [15]:
class Critic(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=256):
        # nn.Module must be initialized before submodules can be assigned.
        super(Critic, self).__init__()
        self.fc1 = nn.Linear(state_dim + action_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, 1)

    def forward(self, state, action):
        x = torch.cat([state, action], dim=-1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [None]:
class MBPOAgent:
    def __init__(self, state_dim, action_dim, max_action, model_rollouts=1000):
        self.actor = Actor(state_dim, action_dim, max_action)
        self.critic = Critic(state_dim, action_dim)
        self.critic2 = Critic(state_dim, action_dim)
        self.target_critic = Critic(state_dim, action_dim)
        self.target_critic2 = Critic(state_dim, action_dim)

        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=3e-4)
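        # One optimizer over both critics; otherwise critic2 would never be updated.
        self.critic_optimizer = torch.optim.Adam(
            list(self.critic.parameters()) + list(self.critic2.parameters()), lr=3e-4)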

        self.target_critic.load_state_dict(self.critic.state_dict())
        self.target_critic2.load_state_dict(self.critic2.state_dict())

        self.gamma = 0.99
        self.tau = 0.005
        self.alpha = 0.2
        self.model_rollouts = model_rollouts
        self.model = DynamicsModel(state_dim, action_dim)

    def select_action(self, state):
        state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
        action, _ = self.actor.sample_action(state_tensor)
        return action.detach().numpy().squeeze(0)
    
    def train_model(self, real_buffer, model_buffer):
        batch_size = 256
        # Fit the dynamics model on real transitions; model_buffer only holds synthetic rollouts.
        states, actions, rewards, next_states, dones = real_buffer.sample(batch_size)

        predicted_next_states, predicted_rewards = self.model(states, actions)
        model_loss = F.mse_loss(predicted_next_states, next_states) + F.mse_loss(predicted_rewards, rewards)

        # Create the optimizer once and reuse it so Adam's moment estimates persist across calls.
        if not hasattr(self, "model_optimizer"):
            self.model_optimizer = optim.Adam(self.model.parameters(), lr=3e-4)
        self.model_optimizer.zero_grad()
        model_loss.backward()
        self.model_optimizer.step()

        generate_model_rollouts(self.model, self.actor, real_buffer, model_buffer, rollout_length=5, num_rollouts=self.model_rollouts)
    
    def train_sac(self, replay_buffer):
        batch_size = 256
        states, actions, rewards, next_states, dones = replay_buffer.sample(batch_size)

        #Compute Q-values:
        with torch.no_grad():
            next_actions, next_log_probs = self.actor.sample_action(next_states)
            q1_next, q2_next = self.target_critic(next_states, next_actions), self.target_critic2(next_states, next_actions)
            min_q_next = torch.min(q1_next, q2_next) - self.alpha * next_log_probs
            target_q = rewards + self.gamma * (1-dones) * min_q_next

        q1, q2 = self.critic(states, actions), self.critic2(states, actions)
        critic_loss = F.mse_loss(q1, target_q) + F.mse_loss(q2, target_q)

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        #Update actor
        actions, log_probs = self.actor.sample_action(states)
        q1_new = self.critic(states, actions)
        actor_loss = (self.alpha * log_probs - q1_new).mean()

        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        for param, target_param in zip(self.critic.parameters(), self.target_critic.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

        for param, target_param in zip(self.critic2.parameters(), self.target_critic2.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

In [None]:
# SAC

class SACAgent:
    def __init__(self, state_dim, action_dim, max_action):
        self.actor = Actor(state_dim, action_dim, max_action)
        self.critic = Critic(state_dim, action_dim)
        self.critic2 = Critic(state_dim, action_dim)
        self.target_critic = Critic(state_dim, action_dim)
        self.target_critic2 = Critic(state_dim, action_dim)

        self.target_critic.load_state_dict(self.critic.state_dict())
        self.target_critic2.load_state_dict(self.critic2.state_dict())

        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=3e-4)
        # One optimizer over both critics; otherwise critic2 would never be updated.
        self.critic_optimizer = torch.optim.Adam(
            list(self.critic.parameters()) + list(self.critic2.parameters()), lr=3e-4)

        self.gamma = 0.99
        self.tau = 0.005
        self.alpha = 0.2

    def select_action(self, state):
        state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
        action, _ = self.actor.sample_action(state_tensor)
        return action.detach().numpy().squeeze(0)

    def train(self, replay_buffer):
        batch_size = 256
        states, actions, rewards, next_states, dones = replay_buffer.sample(batch_size)

        # Compute the entropy-regularized target Q-values
        with torch.no_grad():
            next_actions, next_log_probs = self.actor.sample_action(next_states)
            q1_next, q2_next = self.target_critic(next_states, next_actions), self.target_critic2(next_states, next_actions)
            min_q_next = torch.min(q1_next, q2_next) - self.alpha * next_log_probs
            target_q = rewards + self.gamma * (1 - dones) * min_q_next

        q1, q2 = self.critic(states, actions), self.critic2(states, actions)
        critic_loss = F.mse_loss(q1, target_q) + F.mse_loss(q2, target_q)

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # Update actor
        actions, log_probs = self.actor.sample_action(states)
        q1_new = self.critic(states, actions)
        actor_loss = (self.alpha * log_probs - q1_new).mean()

        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # Polyak (soft) update of both target critics
        for param, target_param in zip(self.critic.parameters(), self.target_critic.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

        for param, target_param in zip(self.critic2.parameters(), self.target_critic2.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

In [18]:
## Comparison
import matplotlib.pyplot as plt

# Each agent gets its own environment so that its return reflects only its own actions.
env_mpc = gym.make('HalfCheetah-v5')
env_mbpo = gym.make('HalfCheetah-v5')
env_sac = gym.make('HalfCheetah-v5')
state_dim = env_mpc.observation_space.shape[0]
action_dim = env_mpc.action_space.shape[0]
max_action = float(env_mpc.action_space.high[0])
action_bounds = (env_mpc.action_space.low, env_mpc.action_space.high)

mpc_controller = CEMPlanner(action_dim, horizon=20, pop_size=500, action_bounds=action_bounds)
mbpo_agent = MBPOAgent(state_dim, action_dim, max_action)
sac_agent = SACAgent(state_dim, action_dim, max_action)

# MBPO and SAC keep separate real-data buffers; MBPO also stores model rollouts.
mbpo_buffer = ReplayBuffer(max_size=int(1e6))
model_buffer = ReplayBuffer(max_size=int(1e6))
sac_buffer = ReplayBuffer(max_size=int(1e6))

num_episodes = 500
mpc_rewards, mbpo_rewards, sac_rewards = [], [], []
for episode in range(num_episodes):
    state_mpc, _ = env_mpc.reset()
    state_mbpo, _ = env_mbpo.reset()
    state_sac, _ = env_sac.reset()
    mpc_total, mbpo_total, sac_total = 0.0, 0.0, 0.0

    for step in range(1000):
        # MPC action (plans with the dynamics model that MBPO trains)
        action_mpc = mpc_controller.plan(state_mpc, mbpo_agent.model)
        state_mpc, reward, term_mpc, trunc_mpc, _ = env_mpc.step(action_mpc)
        mpc_total += reward

        # MBPO action; store `terminated` so time-limit truncation still bootstraps
        action_mbpo = mbpo_agent.select_action(state_mbpo)
        next_state, reward, term_mbpo, trunc_mbpo, _ = env_mbpo.step(action_mbpo)
        mbpo_buffer.add(state_mbpo, action_mbpo, reward, next_state, term_mbpo)
        mbpo_total += reward
        state_mbpo = next_state

        # SAC action
        action_sac = sac_agent.select_action(state_sac)
        next_state, reward, term_sac, trunc_sac, _ = env_sac.step(action_sac)
        sac_buffer.add(state_sac, action_sac, reward, next_state, term_sac)
        sac_total += reward
        state_sac = next_state

        # Train MBPO: fit the model on real data, generate rollouts, update the policy on them
        if len(mbpo_buffer) > 10000:
            mbpo_agent.train_model(mbpo_buffer, model_buffer)
            mbpo_agent.train_sac(model_buffer)

        # Train SAC on its own real transitions
        if len(sac_buffer) > 10000:
            sac_agent.train(sac_buffer)

        # End the episode for all agents as soon as any environment finishes
        if term_mpc or trunc_mpc or term_mbpo or trunc_mbpo or term_sac or trunc_sac:
            break

    mpc_rewards.append(mpc_total)
    mbpo_rewards.append(mbpo_total)
    sac_rewards.append(sac_total)

    print(f"Episode {episode+1}, MPC Reward: {mpc_total}, MBPO Reward: {mbpo_total}, SAC Reward: {sac_total}")

plt.figure(figsize=(12, 8))
plt.plot(mpc_rewards, label='MPC')
plt.plot(mbpo_rewards, label='MBPO')
plt.plot(sac_rewards, label='SAC')
plt.xlabel('Episode')
plt.ylabel('Total Reward')
plt.title("MPC vs MBPO vs SAC on HalfCheetah-v5")
plt.legend()
plt.show()
for e in (env_mpc, env_mbpo, env_sac):
    e.close()


