# Model-Based Policy Optimization

Model-free RL algorithms such as SAC and PPO are flexible and general. However, they often require millions of interactions with the environment to learn effective policies.
Each interaction is typically used for only a few gradient updates before being discarded, which makes learning slow and data-inefficient.
In real-world applications such as robotics and autonomous driving, where every interaction is costly, model-free methods can be impractical.

Model-Based Policy Optimization (MBPO) improves sample efficiency by introducing a learned model of the environment dynamics.
- Instead of relying solely on costly real-world interactions, MBPO trains a neural network to predict state transitions and rewards from state-action inputs.
- This learned model is then used to simulate imaginary rollouts (hypothetical experience that supplements real data).
- By treating the learned model as a simulator, MBPO lets the RL agent learn from both real and synthetic experience, substantially reducing the number of real interactions needed.

A major problem with using a learned model is model bias: if the model makes inaccurate predictions, the policy trained on those predictions may become unreliable.

MBPO solves this problem by:
1. Learning an approximate dynamics model $p_\theta(s_{t+1}|s_t, a_t)$ that predicts the next state and reward.
2. Using short rollouts (1-5 steps) instead of long-term simulations to limit compounding errors.
3. Maintaining an ensemble of models to improve the reliability of predictions and estimate uncertainty.
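
The compounding-error argument behind short rollouts can be illustrated with a toy sketch. This is not MBPO itself: it assumes a hypothetical 1-D linear system whose learned model has a slightly wrong coefficient, and the names `rollout_error`, `true_coef`, and `model_coef` are made up for illustration.

```python
def rollout_error(true_coef, model_coef, s0, horizon):
    """Return |model state - true state| after each of `horizon` steps."""
    true_s, model_s = s0, s0
    errors = []
    for _ in range(horizon):
        true_s = true_coef * true_s      # real (toy) dynamics
        model_s = model_coef * model_s   # learned model with a small bias
        errors.append(abs(model_s - true_s))
    return errors

# A 2% one-step error (assumed value) grows with every simulated step
errors = rollout_error(true_coef=1.00, model_coef=1.02, s0=1.0, horizon=20)
print(f"error after  1 step : {errors[0]:.3f}")
print(f"error after  5 steps: {errors[4]:.3f}")
print(f"error after 20 steps: {errors[19]:.3f}")
```

In this toy setting the gap between model and reality grows at every step, which is why a rollout of a few steps stays much closer to the real dynamics than a long one.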

## Mathematical Formulation


The RL agent aims to maximize the expected cumulative reward:
$$J(\pi) = \mathbb{E}_{\pi}\left[\sum_{t=0}^{\infty} \gamma^t r_t\right]$$
MBPO modifies this by incorporating synthetic experience into training. The learned model is trained using Maximum Likelihood Estimation (MLE):
$$L_{model}(\theta) = - \mathbb{E}_{(s,a,s')\sim D_{env}}\left[\log p_\theta(s'|s,a)\right]$$
where $D_{env}$ is the real environment dataset of state transitions.
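
As a rough illustration of this objective, the sketch below assumes the model outputs a diagonal Gaussian over the next state (an assumption; the implementation later in this notebook uses deterministic networks trained with MSE). The helper name `gaussian_nll` is hypothetical. It shows that, with a fixed unit variance, the negative log-likelihood is half the squared error plus a constant, which is why MSE training can be read as a special case of MLE.

```python
import math

def gaussian_nll(mean, log_var, target):
    """Negative log-likelihood of `target` under a diagonal Gaussian.

    mean, log_var, target: equal-length lists of floats, one entry per state dimension.
    """
    nll = 0.0
    for m, lv, y in zip(mean, log_var, target):
        var = math.exp(lv)
        nll += 0.5 * (math.log(2 * math.pi) + lv + (y - m) ** 2 / var)
    return nll

# With unit variance (log_var = 0) the NLL equals half the squared error plus a constant
pred, target = [0.5, -1.0], [0.0, 1.0]
nll = gaussian_nll(pred, [0.0, 0.0], target)
half_sq_err = 0.5 * sum((y - m) ** 2 for m, y in zip(pred, target))
const = 0.5 * len(pred) * math.log(2 * math.pi)
print(nll, half_sq_err + const)
```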

The rollout process starts from a real state $s_0$ and uses the model to simulate $k$ steps:
$$s_0 \xrightarrow{a_0} s_1 \xrightarrow{a_1} s_2 \xrightarrow{a_2} ... \xrightarrow{a_{k-1}} s_k$$

The hybrid return estimate is:
$$R_k(s_0;\pi) = \mathbb{E}_{a_t\sim\pi}[\sum_{t=0}^{k-1} \gamma^t r(s_t,a_t)+\gamma^kV^\pi(s_k)]$$
where $V^\pi(s_k)$ is the estimated value function.
By limiting rollouts to a short horizon k, MBPO prevents model errors from accumulating too much.
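
A minimal sketch of the hybrid return for a single sampled rollout (so the expectation over actions is replaced by one sample). The function name `k_step_return` and the numbers are illustrative assumptions, not part of the notebook's implementation.

```python
def k_step_return(rewards, bootstrap_value, gamma=0.99):
    """Discounted sum of k model rewards plus the discounted value of the final state."""
    k = len(rewards)
    discounted = sum(gamma ** t * r for t, r in enumerate(rewards))
    return discounted + gamma ** k * bootstrap_value

# Three simulated steps with reward 1.0 each, then bootstrap from V(s_k) = 10.0
ret = k_step_return(rewards=[1.0, 1.0, 1.0], bootstrap_value=10.0, gamma=0.9)
print(ret)
```

In this example a constant reward of 1 with $\gamma = 0.9$ has a true value of $1/(1-\gamma) = 10$, so bootstrapping from an accurate $V^\pi(s_k)$ recovers the full return regardless of $k$. The horizon only controls how much of the estimate comes from the model versus the value function.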

## MBPO vs. Model-Free RL
- SAC and PPO require a large number of environment steps to achieve good performance.
- MBPO can reach comparable performance in roughly 5-10 times fewer environment steps by training on synthetic data.
- PPO is especially sample-inefficient because it is on-policy and cannot reuse past data.
- MBPO maintains high final performance while significantly improving sample efficiency.

## MBPO vs. Other Model-Based Methods

Previous model-based RL struggled with:
- Long-horizon model rollouts leading to severe model bias.
- Overreliance on model predictions, causing poor final policy performance.
- Complex computational planning techniques that were inefficient and hard to scale.

MBPO addresses these issues by:
- Limiting rollouts to a few steps to reduce error accumulation.
- Using an ensemble of models to improve prediction reliability.
- Integrating synthetic and real data in a balanced way for stable learning.
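
The last point, mixing real and synthetic data, can be sketched as follows. This is an illustrative helper under stated assumptions: `sample_mixed_batch` and its `real_ratio` parameter are hypothetical names, and plain Python lists stand in for the replay buffers.

```python
import random

def sample_mixed_batch(env_data, model_data, batch_size, real_ratio=0.5):
    """Draw one batch that mixes real and model-generated transitions."""
    n_real = min(int(batch_size * real_ratio), len(env_data))
    n_model = min(batch_size - n_real, len(model_data))
    return random.sample(env_data, n_real) + random.sample(model_data, n_model)

# Tagged placeholder transitions so the mix is easy to inspect
env_data = [("real", i) for i in range(100)]
model_data = [("model", i) for i in range(1000)]
batch = sample_mixed_batch(env_data, model_data, batch_size=8, real_ratio=0.25)
n_real = sum(1 for source, _ in batch if source == "real")
print(n_real, len(batch) - n_real)
```

The ratio is a tuning knob: more real data keeps updates grounded in the true dynamics, while more synthetic data extracts more updates from each environment step.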

# Implementation Details

## Dynamics model
Predicts the next state and reward given a state-action pair.

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

class DynamicModel(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super(DynamicModel, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, state_dim+1) #predicts next state and reward
        )

    def forward(self, state, action):
        x = torch.cat([state, action], dim=-1)
        output = self.net(x)
        next_state, reward = output[..., :-1], output[..., -1:]
        return next_state, reward

## Experience Replay Buffer

In [None]:
from collections import deque
import random
import torch
import numpy as np

class ReplayBuffer:
    """ A replay buffer to store real and model-generated experiences. """
    def __init__(self, max_size=int(1e6)):
        self.buffer = deque(maxlen=max_size)

    def add(self, state, action, reward, next_state, done):
        """ Convert inputs to NumPy arrays before storing. """
        state = np.array(state, dtype=np.float32).flatten()
        action = np.array(action, dtype=np.float32).flatten()
        next_state = np.array(next_state, dtype=np.float32).flatten()
        reward = np.float32(reward)  # Scalar
        done = np.float32(done)  # Scalar
        
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        """ Sample a batch of transitions and return PyTorch tensors. """
        batch = random.sample(self.buffer, min(batch_size, len(self.buffer)))
        states, actions, rewards, next_states, dones = zip(*batch)

        return (torch.tensor(np.stack(states), dtype=torch.float32),  
                torch.tensor(np.stack(actions), dtype=torch.float32),
                torch.tensor(rewards, dtype=torch.float32).unsqueeze(1),  
                torch.tensor(np.stack(next_states), dtype=torch.float32),
                torch.tensor(dones, dtype=torch.float32).unsqueeze(1))  

    def __len__(self):
        return len(self.buffer)



## SAC

In [None]:
class Actor(nn.Module):
    """ Stochastic policy network with Tanh squashing. """
    def __init__(self, state_dim, action_dim, max_action, hidden_dim=256):
        super(Actor, self).__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.mean = nn.Linear(hidden_dim, action_dim)
        self.log_std = nn.Linear(hidden_dim, action_dim)
        self.max_action = max_action

    def forward(self, state):
        x = torch.relu(self.fc1(state))
        x = torch.relu(self.fc2(x))
        mean = self.mean(x)
        log_std = self.log_std(x).clamp(-20, 2)
        return mean, log_std

    def sample_action(self, state):
        mean, log_std = self.forward(state)
        std = log_std.exp()
        normal_dist = torch.distributions.Normal(mean, std)
        z = normal_dist.rsample()  # reparameterized sample, keeps gradients
        tanh_z = torch.tanh(z)
        action = tanh_z * self.max_action
        # Change-of-variables correction: d(action)/dz = max_action * (1 - tanh(z)^2)
        log_prob = normal_dist.log_prob(z) - torch.log(self.max_action * (1 - tanh_z.pow(2)) + 1e-6)
        log_prob = log_prob.sum(dim=-1, keepdim=True)
        return action, log_prob


In [4]:
class Critic(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super(Critic, self).__init__()
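        def make_q_net():
            return nn.Sequential(
                nn.Linear(state_dim + action_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, 1)
            )

        # Two independently initialized Q-networks; sharing one network would make q1 and q2 identical
        self.q1 = make_q_net()
        self.q2 = make_q_net()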
    
    def forward(self, state, action):
        x = torch.cat([state, action], dim=-1)
        q1 = self.q1(x)
        q2 = self.q2(x)
        return q1, q2

## Generate model Rollouts

In [None]:
def generate_model_rollouts(model, policy, env_buffer, model_buffer, rollout_length=5, num_rollouts=400):
    """Generate short synthetic rollouts that branch from real states in env_buffer."""
    model.eval()
    policy.eval()

    # Use model.predict when it exists (e.g. an ensemble); otherwise call the model directly
    predict = getattr(model, "predict", model)

    states, _, _, _, _ = env_buffer.sample(num_rollouts)

    with torch.no_grad():
        for state in states:
            for rollout_step in range(rollout_length):
                state_batch = state.unsqueeze(0)  # (1, state_dim)
                action, _ = policy.sample_action(state_batch)  # (1, action_dim)
                next_state, reward = predict(state_batch, action)

                model_buffer.add(state.numpy(), action.squeeze(0).numpy(), reward.item(),
                                 next_state.squeeze(0).numpy(), False)
                state = next_state.squeeze(0)  # stay a tensor for the next step

    model.train()
    policy.train()

## Train Loop

In [6]:
import gymnasium as gym
import numpy as np
import matplotlib.pyplot as plt

env = gym.make('HalfCheetah-v5')
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
max_action = float(env.action_space.high[0])

env_model = DynamicModel(state_dim, action_dim)
actor = Actor(state_dim, action_dim, max_action)
critic = Critic(state_dim, action_dim)
target_critic = Critic(state_dim, action_dim)

target_critic.load_state_dict(critic.state_dict())

env_model_optimizer = optim.Adam(env_model.parameters(), lr=1e-3)
actor_optimizer = optim.Adam(actor.parameters(), lr=1e-3)
critic_optimizer = optim.Adam(critic.parameters(), lr=1e-3)

env_buffer = ReplayBuffer()
model_buffer = ReplayBuffer()

num_env_steps = 50000
env_steps_per_iter = 1000
policy_updates_per_iter = 200
rollout_length = 5
batch_size = 256
gamma = 0.99
tau = 0.005

state,_ = env.reset()
for _ in range(env_steps_per_iter):
    action = env.action_space.sample()
    next_state, reward, terminated, truncated, _ = env.step(action)
    done = terminated or truncated
    env_buffer.add(state, action, reward, next_state, done)
    state = next_state if not done else env.reset()[0]  # reset() returns (observation, info)

    if len(env_buffer.buffer) >= batch_size:
        for _ in range(50):
            states, actions, rewards, next_states, dones = env_buffer.sample(batch_size)
            pred_next_states, pred_rewards = env_model(states, actions)
            # pred_rewards and rewards are both (batch, 1); squeezing would trigger unintended broadcasting
            env_model_loss = F.mse_loss(pred_next_states, next_states) + F.mse_loss(pred_rewards, rewards)
            env_model_optimizer.zero_grad()
            env_model_loss.backward()
            env_model_optimizer.step()

        generate_model_rollouts(env_model, actor, env_buffer, model_buffer, rollout_length=rollout_length, num_rollouts=100)

print("Training Env model completed")


Training Env model completed


In [None]:
# Run SAC
alpha = 0.2

state,_ = env.reset()
for _ in range(policy_updates_per_iter):
    batch_half = batch_size // 2
    batch_half_env = min(batch_half, len(env_buffer))
    batch_half_model = min(batch_half, len(model_buffer))

    states_env, actions_env, rewards_env, next_states_env, dones_env = env_buffer.sample(batch_half_env)
    states_model, actions_model, rewards_model, next_states_model, dones_model = model_buffer.sample(batch_half_model)

    states = torch.cat([states_env, states_model], dim=0)
    actions = torch.cat([actions_env, actions_model], dim=0)
    rewards = torch.cat([rewards_env, rewards_model], dim=0)
    next_states = torch.cat([next_states_env, next_states_model], dim=0)
    dones = torch.cat([dones_env, dones_model], dim=0)

    with torch.no_grad():
        next_actions, next_log_probs = actor.sample_action(next_states)
        target_q = torch.min(*target_critic(next_states, next_actions))-alpha * next_log_probs
        y = rewards + gamma * (1-dones) * target_q

    q1, q2 = critic(states, actions)
    critic_loss = F.mse_loss(q1, y) + F.mse_loss(q2, y)

    critic_optimizer.zero_grad()
    critic_loss.backward()
    critic_optimizer.step()

    if _ % 2 == 0:
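        # Sample fresh actions: the loss needs log-probabilities (not log_std) and the min of both Q-values
        new_actions, log_probs = actor.sample_action(states)
        actor_loss = torch.mean(alpha * log_probs - torch.min(*critic(states, new_actions)))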
        actor_optimizer.zero_grad()
        actor_loss.backward()
        actor_optimizer.step()

        for param, target_param in zip(critic.parameters(), target_critic.parameters()):
            target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)

    if _ % 10 == 0:
        print(f"Step: {_}, Actor Loss: {actor_loss.item()}, Critic Loss: {critic_loss.item()}")

print("Training SAC completed")
        


# Next Steps

## Improve Dynamics Model with Ensembles

A single dynamics model can make inaccurate predictions, leading to model bias. Instead, we train an ensemble of models and randomly choose one for each synthetic step.

In [None]:
class EnsembleDynamicsModel(nn.Module):
    """ Ensemble of neural networks to predict next state and reward. """
    def __init__(self, state_dim, action_dim, num_models=5, hidden_dim=256):
        super(EnsembleDynamicsModel, self).__init__()
        self.models = nn.ModuleList([
            nn.Sequential(
                nn.Linear(state_dim + action_dim, hidden_dim), nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
                nn.Linear(hidden_dim, state_dim + 1)  # Predicts next state (state_dim) and reward (1)
            ) for _ in range(num_models)
        ])
        self.num_models = num_models

    def forward(self, state, action):
        x = torch.cat([state, action], dim=-1)
        outputs = [model(x) for model in self.models]
        next_states, rewards = zip(*[(out[..., :-1], out[..., -1:]) for out in outputs])
        return next_states, rewards  

    def predict(self, state, action):
        """ Randomly select one model from the ensemble for a prediction """
        model_idx = np.random.randint(self.num_models)
        x = torch.cat([state, action], dim=-1)
        output = self.models[model_idx](x)
        next_state, reward = output[..., :-1], output[..., -1:]
        return next_state, reward


In [None]:
env_model = EnsembleDynamicsModel(state_dim, action_dim, num_models=5)
actor = Actor(state_dim, action_dim, max_action)
critic = Critic(state_dim, action_dim)
target_critic = Critic(state_dim, action_dim)

target_critic.load_state_dict(critic.state_dict())

env_model_optimizer = optim.Adam(env_model.parameters(), lr=1e-3)
actor_optimizer = optim.Adam(actor.parameters(), lr=1e-3)
critic_optimizer = optim.Adam(critic.parameters(), lr=1e-3)

env_buffer = ReplayBuffer()
model_buffer = ReplayBuffer()

num_env_steps = 50000
env_steps_per_iter = 1000
policy_updates_per_iter = 200
rollout_length = 5
batch_size = 256
gamma = 0.99
tau = 0.005

state,_ = env.reset()
for _ in range(env_steps_per_iter):
    action = env.action_space.sample()
    next_state, reward, terminated, truncated, _ = env.step(action)
    done = terminated or truncated
    env_buffer.add(state, action, reward, next_state, done)
    state = next_state if not done else env.reset()[0]  # reset() returns (observation, info)

    if len(env_buffer.buffer) >= batch_size:
        for _ in range(50):
            states, actions, rewards, next_states, dones = env_buffer.sample(batch_size)
            pred_next_states, pred_rewards = env_model(states, actions)
            # forward() returns one prediction per ensemble member; sum the per-member losses
            env_model_loss = sum(
                F.mse_loss(pred_ns, next_states) + F.mse_loss(pred_r, rewards)
                for pred_ns, pred_r in zip(pred_next_states, pred_rewards)
            )
            env_model_optimizer.zero_grad()
            env_model_loss.backward()
            env_model_optimizer.step()

        generate_model_rollouts(env_model, actor, env_buffer, model_buffer, rollout_length=rollout_length, num_rollouts=100)

print("Training Env model completed")

In [None]:
# Run SAC
alpha = 0.2

state,_ = env.reset()
for _ in range(policy_updates_per_iter):
    batch_half = batch_size // 2
    batch_half_env = min(batch_half, len(env_buffer))
    batch_half_model = min(batch_half, len(model_buffer))

    states_env, actions_env, rewards_env, next_states_env, dones_env = env_buffer.sample(batch_half_env)
    states_model, actions_model, rewards_model, next_states_model, dones_model = model_buffer.sample(batch_half_model)

    states = torch.cat([states_env, states_model], dim=0)
    actions = torch.cat([actions_env, actions_model], dim=0)
    rewards = torch.cat([rewards_env, rewards_model], dim=0)
    next_states = torch.cat([next_states_env, next_states_model], dim=0)
    dones = torch.cat([dones_env, dones_model], dim=0)

    with torch.no_grad():
        next_actions, next_log_probs = actor.sample_action(next_states)
        target_q = torch.min(*target_critic(next_states, next_actions))-alpha * next_log_probs
        y = rewards + gamma * (1-dones) * target_q

    q1, q2 = critic(states, actions)
    critic_loss = F.mse_loss(q1, y) + F.mse_loss(q2, y)

    critic_optimizer.zero_grad()
    critic_loss.backward()
    critic_optimizer.step()

    if _ % 2 == 0:
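        # Sample fresh actions: the loss needs log-probabilities (not log_std) and the min of both Q-values
        new_actions, log_probs = actor.sample_action(states)
        actor_loss = torch.mean(alpha * log_probs - torch.min(*critic(states, new_actions)))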
        actor_optimizer.zero_grad()
        actor_loss.backward()
        actor_optimizer.step()

        for param, target_param in zip(critic.parameters(), target_critic.parameters()):
            target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)

    if _ % 10 == 0:
        print(f"Step: {_}, Actor Loss: {actor_loss.item()}, Critic Loss: {critic_loss.item()}")

print("Training SAC completed")

## Adaptive Rollout Length Based on Model Uncertainty

The rollouts above use a fixed length, but longer rollouts accumulate model errors: if the model is highly uncertain, long rollouts can lead to bad policies.
The solution is to measure the ensemble disagreement and stop a rollout when uncertainty is high.
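
Before the PyTorch version, here is a small plain-Python sketch of the disagreement measure. It assumes disagreement is the variance across ensemble members averaged over state dimensions; the names `ensemble_disagreement` and `should_stop` are hypothetical, and the population variance used here differs slightly from `torch.var`, which applies Bessel's correction by default.

```python
def ensemble_disagreement(predictions):
    """Variance across ensemble members, averaged over state dimensions.

    predictions: list with one next-state prediction (list of floats) per model.
    """
    num_models = len(predictions)
    state_dim = len(predictions[0])
    total = 0.0
    for d in range(state_dim):
        values = [p[d] for p in predictions]
        mean = sum(values) / num_models
        total += sum((v - mean) ** 2 for v in values) / num_models  # population variance
    return total / state_dim

def should_stop(predictions, threshold=0.05):
    """Stop the rollout when the members disagree more than `threshold`."""
    return ensemble_disagreement(predictions) > threshold

agree = [[1.00, 2.00], [1.01, 2.01], [0.99, 1.99]]   # members nearly identical
disagree = [[1.0, 2.0], [2.0, 1.0], [0.0, 3.0]]      # members far apart
print(should_stop(agree), should_stop(disagree))
```

The threshold of 0.05 matches the value used in the cell below; in practice it would depend on the scale of the state variables.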

In [None]:
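def generate_model_rollouts(model, policy, env_buffer, model_buffer, max_rollout_length=5, num_rollouts=400,
                            uncertainty_threshold=0.05):
    """Generate synthetic rollouts, stopping early when the ensemble members disagree."""
    model.eval()
    policy.eval()

    states, _, _, _, _ = env_buffer.sample(num_rollouts)

    with torch.no_grad():
        for state in states:
            for rollout_step in range(max_rollout_length):
                state_batch = state.unsqueeze(0)  # (1, state_dim)
                action, _ = policy.sample_action(state_batch)  # (1, action_dim)

                # forward() returns one prediction per ensemble member
                next_states, rewards = model(state_batch, action)
                next_states = torch.stack(next_states)  # (num_models, 1, state_dim)
                rewards = torch.stack(rewards)          # (num_models, 1, 1)

                # Ensemble disagreement: variance across members, averaged over state dimensions
                model_uncertainty = torch.var(next_states, dim=0).mean()
                if model_uncertainty.item() > uncertainty_threshold:
                    break

                # Store the prediction of one randomly chosen member
                idx = np.random.randint(next_states.shape[0])
                next_state = next_states[idx].squeeze(0)
                reward = rewards[idx].item()

                model_buffer.add(state.numpy(), action.squeeze(0).numpy(), reward, next_state.numpy(), False)
                state = next_state

    model.train()
    policy.train()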

In [None]:
env_model = EnsembleDynamicsModel(state_dim, action_dim, num_models=5)
actor = Actor(state_dim, action_dim, max_action)
critic = Critic(state_dim, action_dim)
target_critic = Critic(state_dim, action_dim)

target_critic.load_state_dict(critic.state_dict())

env_model_optimizer = optim.Adam(env_model.parameters(), lr=1e-3)
actor_optimizer = optim.Adam(actor.parameters(), lr=1e-3)
critic_optimizer = optim.Adam(critic.parameters(), lr=1e-3)

env_buffer = ReplayBuffer()
model_buffer = ReplayBuffer()

num_env_steps = 50000
env_steps_per_iter = 1000
policy_updates_per_iter = 200
rollout_length = 5
batch_size = 256
gamma = 0.99
tau = 0.005

state,_ = env.reset()
for _ in range(env_steps_per_iter):
    action = env.action_space.sample()
    next_state, reward, terminated, truncated, _ = env.step(action)
    done = terminated or truncated
    env_buffer.add(state, action, reward, next_state, done)
    state = next_state if not done else env.reset()[0]  # reset() returns (observation, info)

    if len(env_buffer.buffer) >= batch_size:
        for _ in range(50):
            states, actions, rewards, next_states, dones = env_buffer.sample(batch_size)
            pred_next_states, pred_rewards = env_model(states, actions)
            # forward() returns one prediction per ensemble member; sum the per-member losses
            env_model_loss = sum(
                F.mse_loss(pred_ns, next_states) + F.mse_loss(pred_r, rewards)
                for pred_ns, pred_r in zip(pred_next_states, pred_rewards)
            )
            env_model_optimizer.zero_grad()
            env_model_loss.backward()
            env_model_optimizer.step()

        generate_model_rollouts(env_model, actor, env_buffer, model_buffer, max_rollout_length=rollout_length, num_rollouts=100)

print("Training Env model completed")
# Run SAC
alpha = 0.2

state,_ = env.reset()
for _ in range(policy_updates_per_iter):
    batch_half = batch_size // 2
    batch_half_env = min(batch_half, len(env_buffer.buffer))
    batch_half_model = batch_size-batch_half_env

    states_env, actions_env, rewards_env, next_states_env, dones_env = env_buffer.sample(batch_half_env)
    states_model, actions_model, rewards_model, next_states_model, dones_model = model_buffer.sample(batch_half_model)

    states = torch.cat([states_env, states_model], dim=0)
    actions = torch.cat([actions_env, actions_model], dim=0)
    rewards = torch.cat([rewards_env, rewards_model], dim=0)
    next_states = torch.cat([next_states_env, next_states_model], dim=0)
    dones = torch.cat([dones_env, dones_model], dim=0)

    with torch.no_grad():
        next_actions, next_log_probs = actor.sample_action(next_states)
        target_q = torch.min(*target_critic(next_states, next_actions))-alpha * next_log_probs
        y = rewards + gamma * (1-dones) * target_q

    q1, q2 = critic(states, actions)
    critic_loss = F.mse_loss(q1, y) + F.mse_loss(q2, y)

    critic_optimizer.zero_grad()
    critic_loss.backward()
    critic_optimizer.step()

    if _ % 2 == 0:
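        # Sample fresh actions: the loss needs log-probabilities (not log_std) and the min of both Q-values
        new_actions, log_probs = actor.sample_action(states)
        actor_loss = torch.mean(alpha * log_probs - torch.min(*critic(states, new_actions)))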
        actor_optimizer.zero_grad()
        actor_loss.backward()
        actor_optimizer.step()

        for param, target_param in zip(critic.parameters(), target_critic.parameters()):
            target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)

    if _ % 10 == 0:
        print(f"Step: {_}, Actor Loss: {actor_loss.item()}, Critic Loss: {critic_loss.item()}")

print("Training SAC completed")