# Adaptive Rollout

Model-based RL builds a predicrive model of the environment and uses it to generate synthetic experience, which significantly improves sample efficiency. One of the most well known model-based RL algorithm is MBPO. MBPO enhances sample efficiency by training a dynamics model of the environment and using it to create rollouts, simulated sequences of state transitions. These rollouts are then used to train a policy, typically through an actor-critic algorithm.
However, MBPO suffers from an important limitation: fixed rollout lengths. This means that ince the learned dynamics model is sufficiently trained, it is used to simulate future transitions for a fixed number of steps. But this approach is not adaptive to the accuracy of the model:
- If rollouts are too short, the model is not utilized effectively, reducing the benefit of MBRL.
- If rollouts are too long, compounding model errors degrade the training data, harming policy learning.

To address this problem, we introduce Adaptive Rollout Lengths, where the number of steps simulated by the model is dynamically adjusted based on model uncertainty.
this approach aims to balance:
- Early-stage training, where short rollouts prevent excessive model bias.
- Late-stage training, where longer rollouts maximize sample efficiency when the model is accurate.

## Background

A rollout refers to a sequence of state transitions generated by a learned environment model. Instead of interacting with the real environment, the RL agent queries the model for predictions of what would happen given a state and action.
In standard MBPO, rollouts are fixed at a predetermined lenght. But this introduces a fundamental problem: the model is not equally accurate at all stages of training. Early on, the model is poorly trained, so long rollouts introduce high errors. Later, as the model improves, using short rollouts wastes valuable synthetic data.

## Theory

The idea behind adaptive rollouts is that the rollout lenght should be proportional to model confidence. If the model is accurate, we can trust it to generate longer sequences. If it is uncertain, shorter rollouts are safer.
This means that the rollout length, $H_t$, should be a function of model uncertainty:
$$H_t = \max(H_{min}, \min(H_{max}, H_{base}\cdot \exp^{-\beta \cdot Var(s_t)}))$$
where:
- $H_{min}$ and $H_{max}$ are the minimum and maximum rollout lengths.
- $H_{base}$ is the base rollout length.
- $\beta$ is a sensitivity parameter that controls how aggressively the rollout length changes
- $Var(s_t)$ is the variance of the model, the model uncertainty.

### Estimating Model Uncertainty
One way to estimate model uncertainty is to train an ensemble of N models, each predicting the next state:
$$\hat{s}_{t+1}^{(i)} = f_{\theta_i}(s_t, a_t), i \in {1,2,...,N}$$
We then compute the variance of these predictions:
$$Var(\hat{s}_{t+1}) = \frac{1}{N-1} \sum_{i=1}^{N} (\hat{s}_{t+1}^{(i)} - \bar{\hat{s}}_{t+1})^2$$
where $\bar{\hat{s}}_{t+1}$ is the mean of the predictions.
The variance serves as a measure of model confidence, if predictions vary widely, it indicates high uncertainty, meaning shorter rollouts should be used.

## Implementation

In [1]:
import numpy as np
import torch
from collections import deque
import random

class ReplayBuffer:
    """ Stores transitions for real and model rollouts """
    def __init__(self, max_size=int(1e6)):
        self.buffer = deque(maxlen=max_size)

    def add(self, state, action, reward, next_state, done):
        state = np.array(state, dtype=np.float32).flatten()
        action = np.array(action, dtype=np.float32).flatten()
        next_state = np.array(next_state, dtype=np.float32).flatten()
        reward = np.float32(reward)
        done = np.float32(done)
        
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        return (torch.tensor(np.stack(states), dtype=torch.float32),
                torch.tensor(np.stack(actions), dtype=torch.float32),
                torch.tensor(rewards, dtype=torch.float32).unsqueeze(1),
                torch.tensor(np.stack(next_states), dtype=torch.float32),
                torch.tensor(dones, dtype=torch.float32).unsqueeze(1))

    def __len__(self):
        return len(self.buffer)


In [2]:
import torch.nn as nn

class EnsembleDynamicsModel:
    """ An ensemble of neural networks for uncertainty-aware model-based rollouts """
    def __init__(self, state_dim, action_dim, num_models=5):
        super(EnsembleDynamicsModel, self).__init__()
        self.models = [self._build_model(state_dim, action_dim) for _ in range(num_models)]
        self.num_models = num_models

    def _build_model(self, state_dim, action_dim):
        """ Create a single dynamics model """
        return nn.Sequential(
            nn.Linear(state_dim + action_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, state_dim + 1)  # Predict next state and reward
        )

    def predict(self, state, action):
        """ Predict next state and reward with model ensemble """
        inputs = torch.cat([state, action], dim=-1)
        predictions = [model(inputs) for model in self.models]
        
        predictions = torch.stack(predictions)  # Shape: (num_models, batch, output_dim)
        mean_prediction = predictions.mean(dim=0)  # Mean over ensemble models
        uncertainty = predictions.var(dim=0).mean().item()  # Variance as uncertainty
        
        return mean_prediction[..., :-1], mean_prediction[..., -1], uncertainty


In [36]:
def generate_adaptive_model_rollouts(model, policy, real_buffer, model_buffer, adaptive_rollout):
    """ Generate synthetic rollouts with dynamically adjusted rollout length """
    for model_instance in model.models:
        model_instance.eval()
    policy.eval()

    states, _, _, _, _ = real_buffer.sample(500)  # Sample states from real buffer
    states = states.to(torch.float32)

    for state in states:
        uncertainty_sum = 0  # Track uncertainty to adjust rollout length
        for _ in range(adaptive_rollout.rollout_length):
            action, _ = policy.sample_action(state.unsqueeze(0))
            action = torch.tensor(action, dtype=torch.float32).unsqueeze(0)  # Ensure action has the same dimension as state
            next_state, reward, uncertainty = model.predict(state.unsqueeze(0), action)
            model_buffer.add(state.numpy(), action.squeeze(0).numpy(), reward.numpy(), next_state.numpy(), False)
            state = next_state.squeeze(0)  # Remove batch dimension for the next iteration
            uncertainty_sum += uncertainty
        
        avg_uncertainty = uncertainty_sum / adaptive_rollout.rollout_length
        adaptive_rollout.update(avg_uncertainty)  # Adjust rollout length

    model.train()
    policy.train()


In [37]:
class AdaptiveRollout:
    """ Manages the adaptive rollout length based on model uncertainty """
    def __init__(self, rollout_min=1, rollout_max=10, beta=0.1):
        self.rollout_min = rollout_min
        self.rollout_max = rollout_max
        self.beta = beta
        self.rollout_length = rollout_min

    def update(self, uncertainty):
        """ Adjust rollout length based on uncertainty """
        self.rollout_length = int(max(self.rollout_min, min(self.rollout_max, 
                               self.rollout_max * np.exp(-self.beta * uncertainty))))
        return self.rollout_length

In [38]:
import torch.nn.functional as F
import torch
import torch.nn as nn
import torch.optim as optim

class Actor(nn.Module):
    """ Stochastic Policy Network for SAC """
    def __init__(self, state_dim, action_dim, max_action, hidden_dim=256):
        super(Actor, self).__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.mean = nn.Linear(hidden_dim, action_dim)  # Mean of action distribution
        self.log_std = nn.Linear(hidden_dim, action_dim)  # Log standard deviation
        self.max_action = max_action

    def forward(self, state):
        x = torch.relu(self.fc1(state))
        x = torch.relu(self.fc2(x))
        mean = self.mean(x)
        log_std = torch.clamp(self.log_std(x), -20, 2)  # Keep std in reasonable range
        return mean, log_std

    def sample_action(self, state):
        """ Samples action using reparameterization trick (for SAC) """
        mean, log_std = self.forward(state)
        std = log_std.exp()
        dist = torch.distributions.Normal(mean, std)
        z = dist.rsample()  # Reparameterization trick
        action = torch.tanh(z) * self.max_action  # Squash output to action space
        log_prob = dist.log_prob(z) - torch.log(1 - action.pow(2) + 1e-6)
        return action.detach().numpy(), log_prob.sum(dim=-1, keepdim=True)

class Critic(nn.Module):
    """ Twin Q-Value Network for SAC """
    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super(Critic, self).__init__()
        # First Q-Network
        self.fc1 = nn.Linear(state_dim + action_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, 1)  # Q-value output

        # Second Q-Network
        self.fc4 = nn.Linear(state_dim + action_dim, hidden_dim)
        self.fc5 = nn.Linear(hidden_dim, hidden_dim)
        self.fc6 = nn.Linear(hidden_dim, 1)  # Second Q-value output

    def forward(self, state, action):
        """ Forward pass for both Q-networks """
        x1 = torch.cat([state, action], dim=-1)
        x1 = torch.relu(self.fc1(x1))
        x1 = torch.relu(self.fc2(x1))
        q1 = self.fc3(x1)

        x2 = torch.cat([state, action], dim=-1)
        x2 = torch.relu(self.fc4(x2))
        x2 = torch.relu(self.fc5(x2))
        q2 = self.fc6(x2)

        return q1, q2

In [39]:
class AdaptiveMBPO:
    def __init__(self, state_dim, action_dim, max_action):
        self.actor = Actor(state_dim, action_dim, max_action)
        self.critic1 = Critic(state_dim, action_dim)
        self.critic2 = Critic(state_dim, action_dim)
        self.target_critic1 = Critic(state_dim, action_dim)
        self.target_critic2 = Critic(state_dim, action_dim)
        self.dynamics_model = EnsembleDynamicsModel(state_dim, action_dim)
        self.adaptive_rollout = AdaptiveRollout(rollout_min=1, rollout_max=10, beta=0.1)
        
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=3e-4)
        self.critic_optimizer = torch.optim.Adam(list(self.critic1.parameters()) + list(self.critic2.parameters()), lr=3e-4)

        self.gamma = 0.99
        self.tau = 0.005
        self.alpha = 0.2

    def train_model(self, real_buffer, model_buffer):
        """ Train dynamics model and generate adaptive-length rollouts """
        batch_size = 256
        states, actions, rewards, next_states, dones = real_buffer.sample(batch_size)

        predicted_next_states, predicted_rewards, _ = self.dynamics_model.predict(states, actions)
        model_loss = nn.MSELoss()(predicted_next_states, next_states) + nn.MSELoss()(predicted_rewards, rewards)

        optimizer = torch.optim.Adam(
            [param for model in self.dynamics_model.models for param in model.parameters()], 
            lr=1e-3
        )
        optimizer.zero_grad()
        model_loss.backward()
        optimizer.step()

        generate_adaptive_model_rollouts(self.dynamics_model, self.actor, real_buffer, model_buffer, self.adaptive_rollout)

    def train_sac(self, replay_buffer):
        """ Train SAC policy using both real and synthetic experiences """
        states, actions, rewards, next_states, dones = replay_buffer.sample(256)

        with torch.no_grad():
            next_actions, next_log_probs = self.actor.sample_action(next_states)
            q1_next, q2_next = self.target_critic1(next_states, next_actions), self.target_critic2(next_states, next_actions)
            min_q_next = torch.min(q1_next, q2_next) - self.alpha * next_log_probs
            target_q = rewards + self.gamma * (1 - dones) * min_q_next

        q1, q2 = self.critic1(states, actions), self.critic2(states, actions)
        critic_loss = nn.MSELoss()(q1, target_q) + nn.MSELoss()(q2, target_q)

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()



In [40]:
import gymnasium as gym
import matplotlib.pyplot as plt

# Initialize environment
env = gym.make("HalfCheetah-v5")
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
max_action = float(env.action_space.high[0])

# Initialize agent, replay buffers
agent = AdaptiveMBPO(state_dim, action_dim, max_action)
real_buffer = ReplayBuffer(max_size=int(1e6))
model_buffer = ReplayBuffer(max_size=int(1e6))

num_episodes = 500
adaptive_mbpo_rewards = []

# Training loop
for episode in range(num_episodes):
    state, _ = env.reset()
    total_reward = 0

    for step in range(1000):  # Maximum steps per episode
        action, _ = agent.actor.sample_action(torch.tensor(state, dtype=torch.float32))
        next_state, reward, done, _, _ = env.step(action)
        real_buffer.add(state, action, reward, next_state, done)
        total_reward += reward
        state = next_state if not done else env.reset()

        # Train adaptive MBPO once enough real data is collected
        if len(real_buffer) > 10000:
            agent.train_model(real_buffer, model_buffer)
            agent.train_sac(real_buffer)

    adaptive_mbpo_rewards.append(total_reward)
    print(f"Episode {episode}: Total Reward = {total_reward}")

# Plot results
plt.plot(adaptive_mbpo_rewards, label="Adaptive Rollout MBPO")
plt.xlabel("Episodes")
plt.ylabel("Total Reward")
plt.title("Adaptive Rollout MBPO on HalfCheetah-v4")
plt.legend()
plt.show()

Episode 0: Total Reward = -312.44342318839784
Episode 1: Total Reward = -344.8832943453535
Episode 2: Total Reward = -274.44077136748234
Episode 3: Total Reward = -301.9231116595517
Episode 4: Total Reward = -275.7856033364171
Episode 5: Total Reward = -471.3448935713538
Episode 6: Total Reward = -238.518042016767
Episode 7: Total Reward = -338.72639772565344
Episode 8: Total Reward = -300.12976323873215
Episode 9: Total Reward = -323.26175059823134


RuntimeError: Tensors must have same number of dimensions: got 2 and 3