# Cross-Entropy Method (CEM) for Planning

CEM is a gradient-free, population-based optimization algorithm widely used in:
- planning in latent dynamics models
- continuous control without differentiable models
- high-dimensional policy search

CEM is useful in model-based RL because:
- it works directly with learned models (it only needs to evaluate them, not differentiate through them)
- it optimizes action sequences over a planning horizon
- it is simple, parallelizable, and effective in practice

The fundamental goal in many control and decision-making problems is to find an action sequence $a_{1:H}$ that maximizes the expected return:
$$a_{1:H}^* = \arg\max_{a_{1:H}} \mathbb{E}\left[\sum_{t=1}^H r(s_t, a_t)\right]$$
This objective is often non-convex or non-differentiable, or it can only be evaluated by rolling out a model, which makes gradient-based optimization hard to apply.

Instead of searching for a solution directly, CEM turns the problem into an optimization over probability distributions:\
"Search over distributions rather than solutions."\
CEM maintains a parameterized distribution over candidate solutions, samples from it, evaluates those samples, and refines the distribution to increase the likelihood of generating good solutions.
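
The sample, evaluate, refit loop described above can be sketched in a few lines of plain Python. This is a minimal one-dimensional illustration, not a reference implementation: the function name `cem_maximize`, the toy objective, and all hyperparameter values are assumptions chosen for the example.

```python
import random

def cem_maximize(f, mu=0.0, sigma=5.0, n_samples=200, n_elite=20, n_iters=30, seed=0):
    """Toy 1-D CEM: fit a Gaussian to the best samples and repeat."""
    rng = random.Random(seed)
    for _ in range(n_iters):
        # Sample candidates from the current distribution.
        xs = [rng.gauss(mu, sigma) for _ in range(n_samples)]
        # Keep the candidates with the highest objective value.
        elites = sorted(xs, key=f, reverse=True)[:n_elite]
        # Refit the distribution to the elites (small floor keeps sigma positive).
        mu = sum(elites) / n_elite
        sigma = (sum((x - mu) ** 2 for x in elites) / n_elite) ** 0.5 + 1e-6
    return mu, sigma

# Hypothetical objective with a single maximum at x = 3.
best_x, final_sigma = cem_maximize(lambda x: -(x - 3.0) ** 2)
```

The mean moves toward the maximizer while the standard deviation shrinks, so the distribution concentrates on good solutions without ever using a gradient of `f`.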

## Algorithm

Let $p(x;\theta)$ be a distribution over solutions $x\in\mathbb{R}^d$, parameterized by $\theta$.\
The Cross-Entropy Method solves:
$$\theta^* = \arg\min_{\theta} D_{KL}(\pi^*(x)\,\|\,p(x;\theta))$$
where:
- $\pi^*(x)$ is the target distribution (e.g., the distribution of good solutions)
- $D_{KL}$ is the Kullback-Leibler divergence, which measures how far $p(x;\theta)$ is from the reference distribution $\pi^*(x)$

Since $\pi^*(x)$ is unknown, we approximate it with the elite samples (the top $K$) drawn from the current distribution.
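
To see why elite samples are enough, it helps to expand the divergence. The steps below are the standard derivation; here $\mathcal{E}$ denotes the set of $K$ elite samples, which stand in for draws from $\pi^*$.

$$D_{KL}(\pi^*(x)\,\|\,p(x;\theta)) = \mathbb{E}_{x\sim\pi^*}[\log \pi^*(x)] - \mathbb{E}_{x\sim\pi^*}[\log p(x;\theta)]$$

The first term does not depend on $\theta$, so minimizing the divergence is the same as minimizing the second term, the cross-entropy between $\pi^*$ and $p(\cdot;\theta)$:

$$\theta^* = \arg\max_{\theta} \mathbb{E}_{x\sim\pi^*}[\log p(x;\theta)] \approx \arg\max_{\theta} \frac{1}{K}\sum_{x\in\mathcal{E}} \log p(x;\theta)$$

In other words, each CEM update is a maximum-likelihood fit of $p(\cdot;\theta)$ to the elites, which is where the method gets its name.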

## Math

1. Sampling: Draw $N$ samples $x_1, \ldots, x_N \sim p(\cdot;\theta)$.
2. Scoring: Evaluate each sample under the objective function: $R_i = f(x_i)$, the cumulative predicted reward.
3. Elite Selection: Let $\mathcal{E} = \{x_{i_1}, \ldots, x_{i_K}\}$ be the elite set, the $K$ samples with the highest reward.
4. Update Distribution Parameters: For a Gaussian $\mathcal{N}(\mu, \sigma^2)$, use the elite set to update the mean and variance (per dimension when $x$ is a vector):
$$\mu' = \frac{1}{K}\sum_{x\in\mathcal{E}} x$$
$$\sigma'^2 = \frac{1}{K}\sum_{x\in\mathcal{E}}(x - \mu')^2$$
This update is the maximum-likelihood Gaussian fit to the elites, which minimizes the KL divergence from the empirical elite distribution to the new distribution.
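
A single pass through the four steps can be checked by hand on a tiny example. The sketch below uses made-up scalar samples and a toy objective (both are assumptions for illustration) in place of real draws from $p(\cdot;\theta)$, so the arithmetic is easy to follow.

```python
# Hypothetical scalar candidates (stand-ins for samples from the current distribution).
samples = [-2.0, -0.5, 0.0, 0.5, 3.0, 1.0]
# Toy objective with its maximum at x = 1.
scores = [-(x - 1.0) ** 2 for x in samples]
K = 3

# Elite selection: indices of the K highest scores.
elite_idx = sorted(range(len(samples)), key=lambda i: scores[i], reverse=True)[:K]
elites = [samples[i] for i in elite_idx]

# Gaussian update with the 1/K denominator used in the formulas above.
mu_new = sum(elites) / K
var_new = sum((x - mu_new) ** 2 for x in elites) / K
```

Here the elites are `1.0`, `0.5`, and `0.0`, so the new mean is `0.5` and the new variance is `1/6`: the distribution has already shifted toward the maximizer and narrowed.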

CEM originates from rare-event simulation in Monte Carlo methods. In its original form, it was used to estimate the probability of rare events of the form:
$$\mathbb{P}_{x\sim p(\cdot;\theta)}(f(x) > \gamma)$$
It was later repurposed for general optimization, especially for control.

In model-based RL planning, we interpret $x$ as an action sequence:
$$x=(a_1, a_2, \ldots, a_H)$$
CEM searches over action sequences to maximize predicted future rewards using a learned model $\hat{T}(s,a)$ and reward predictor $\hat{r}(s,a)$.
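
To make the scoring step concrete, here is a sketch of how a batch of candidate action sequences might be evaluated by rolling them out through a model. The names `evaluate_rollout`, `toy_dynamics`, and `toy_reward` are hypothetical; the two toy functions are hand-written stand-ins for the learned $\hat{T}$ and $\hat{r}$, not trained networks.

```python
import torch

def evaluate_rollout(dynamics, reward, s0, actions):
    """Score a batch of action sequences by rolling them out through a model.

    dynamics(s, a) -> next state and reward(s, a) -> reward per candidate are
    hypothetical stand-ins for the learned T_hat and r_hat.
    s0: (state_dim,) initial state; actions: (N, H, action_dim).
    Returns a tensor of shape (N,) with the predicted return of each candidate.
    """
    num_candidates, horizon, _ = actions.shape
    s = s0.unsqueeze(0).expand(num_candidates, -1)  # same start state for every candidate
    returns = torch.zeros(num_candidates)
    for t in range(horizon):
        a = actions[:, t]
        returns = returns + reward(s, a)  # accumulate r(s_t, a_t)
        s = dynamics(s, a)                # step the model to s_{t+1}
    return returns

# Toy stand-ins: a point that moves by the action and is rewarded for being near the origin.
toy_dynamics = lambda s, a: s + a
toy_reward = lambda s, a: -(s ** 2).sum(dim=-1)

s0 = torch.tensor([1.0, -1.0])
actions = torch.zeros(3, 4, 2)              # three candidates, horizon 4
actions[1, 0] = torch.tensor([-1.0, 1.0])   # candidate 1 moves to the origin at t = 0
returns = evaluate_rollout(toy_dynamics, toy_reward, s0, actions)
```

Candidates 0 and 2 never move and collect a reward of -2 at each of the four steps, while candidate 1 pays -2 once and then sits at the origin, so it would be selected as the elite.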


## Implementation

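```python
import torch

class DummyModel:
    """Toy model: ignores h and s and scores a plan by its negative squared action magnitude."""
    def evaluate_rollout(self, h, s, actions):
        # actions: (num_candidates, plan_horizon, action_dim) -> returns: (num_candidates,)
        return -(actions ** 2).sum(dim=(1, 2))

class CEMPlanner:
    def __init__(self, model, plan_horizon=10, action_dim=2, num_candidates=1000,
                 top_k=100, num_iters=5, action_bounds=(-1.0, 1.0)):
        self.model = model
        self.plan_horizon = plan_horizon
        self.action_dim = action_dim
        self.num_candidates = num_candidates
        self.top_k = top_k
        self.num_iters = num_iters
        self.action_bounds = action_bounds

    def plan(self, h, s):
        # Initial sampling distribution: independent Gaussians per time step and action dimension.
        mean = torch.zeros(self.plan_horizon, self.action_dim)
        std = torch.ones_like(mean) * 0.5
        for _ in range(self.num_iters):
            # Sample candidate action sequences: (num_candidates, plan_horizon, action_dim).
            actions = torch.normal(mean.unsqueeze(0).expand(self.num_candidates, -1, -1),
                                   std.unsqueeze(0).expand(self.num_candidates, -1, -1))
            actions = actions.clamp(*self.action_bounds)

            # Score every candidate with the model.
            returns = self.model.evaluate_rollout(h, s, actions)

            # Keep the top_k candidates with the highest predicted return.
            elite_inds = returns.topk(self.top_k).indices
            elites = actions[elite_inds]

            # Refit the Gaussian to the elites. torch's std uses the K-1 denominator by
            # default, which differs only slightly from the 1/K formula for K = 100.
            mean = elites.mean(dim=0)
            std = elites.std(dim=0) + 1e-5  # small floor keeps std strictly positive

        # Return only the first action of the refined mean sequence.
        return mean[0]

model = DummyModel()
planner = CEMPlanner(model, plan_horizon=10, action_dim=2)

# Placeholder state inputs; DummyModel ignores them.
h = torch.randn(1, 200)
s = torch.randn(1, 30)

best_action = planner.plan(h, s)
print("Best action:", best_action)
```

Output (values vary between runs because no random seed is set):

```
Best action: tensor([-0.0028,  0.0378])
```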
