# PlaNet (Deep Planning Network)

A latent dynamics model for pixel-based continuous control that:
- Learns from high-dimensional observations (images)
- Encodes them into a latent state space
- Predicts latent dynamics using a Recurrent State-Space Model (RSSM)
- Plans sequences of actions using Cross-Entropy Method (CEM) in the latent space

## Background

### World Model
A world model learns to predict future observations and rewards from past observations and actions.\
It consists of three components:
- Encoder $e(o_t)$: maps each observation to an embedding from which the latent state is inferred
- RSSM: models transitions in the latent space
- Decoders $d_o(s_t, h_t)$, $d_r(s_t, h_t)$: reconstruct observations and predict rewards from the latent state

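### RSSM (Recurrent State-Space Model)
Combines deterministic and stochastic transitions:
- $h_t$: deterministic hidden state, updated by an RNN (a GRU in the implementation below)
- $s_t$: stochastic latent variable, sampled from a Gaussian conditioned on $h_t$

The deterministic path lets the model carry information reliably across many time steps, while the stochastic path captures uncertainty in the dynamics.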
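To make the split concrete, here is a minimal NumPy sketch of one prior transition. It assumes a plain tanh RNN in place of the GRU and random, untrained weights; `W_h`, `W_in`, `W_prior` and `rssm_step` are hypothetical names. It only illustrates that $h_t$ is computed deterministically and $s_t$ is then sampled from a Gaussian whose parameters depend on $h_t$.

```python
import numpy as np

rng = np.random.default_rng(0)
hidden_dim, latent_dim, action_dim = 8, 4, 2

# Hypothetical random weights; a real RSSM learns these and uses a GRU, not this tanh RNN.
W_h = rng.normal(scale=0.1, size=(hidden_dim, hidden_dim))
W_in = rng.normal(scale=0.1, size=(hidden_dim, latent_dim + action_dim))
W_prior = rng.normal(scale=0.1, size=(2 * latent_dim, hidden_dim))

def softplus(x):
    return np.log1p(np.exp(x))

def rssm_step(h_prev, s_prev, a_prev):
    """One prior step: deterministic update of h, then a sample of s from p(s_t | h_t)."""
    # Deterministic path: h_t = f(h_{t-1}, s_{t-1}, a_{t-1})
    h = np.tanh(W_h @ h_prev + W_in @ np.concatenate([s_prev, a_prev]))
    # Stochastic path: split the output into mean and raw std, then sample
    mean, raw_std = np.split(W_prior @ h, 2)
    std = softplus(raw_std) + 1e-4  # keep the std strictly positive
    s = mean + std * rng.standard_normal(latent_dim)
    return h, s

h, s = np.zeros(hidden_dim), np.zeros(latent_dim)
for _ in range(3):
    h, s = rssm_step(h, s, np.ones(action_dim))
print(h.shape, s.shape)  # (8,) (4,)
```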

### Variational Inference
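Since the true posterior $p(s_t|o_{1:t}, a_{1:t})$ is intractable, we approximate it with a learned inference network (encoder) $q(s_t|h_t, o_t)$.

### Planning
Once the model is trained, actions are selected with CEM in the latent space. Because imagined rollouts never decode images, evaluating many candidate action sequences is cheap, and because the model is learned from collected experience, the agent needs relatively few environment interactions.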

## Math

### Notation
- $o_t\in\mathbb{R}^{H\times W\times C}$: observation (image)
- $a_t\in\mathbb{R}^{n}$: action
- $r_t\in\mathbb{R}$: reward
- $h_t\in\mathbb{R}^{d}$: deterministic hidden state
- $s_t\in\mathbb{R}^{z}$: stochastic latent state

### Generative Model

The full model defines:
$$p(o_{1:T}, r_{1:T}, s_{1:T}, h_{1:T}|a_{1:T}) = \prod_{t=1}^T p(o_t|s_t, h_t) \cdot p(r_t|s_t, h_t) \cdot p(s_t|h_t) \cdot p(h_t|h_{t-1}, s_{t-1}, a_{t-1})$$
Where:
- Observation decoder: $p(o_t|s_t, h_t)$ - reconstructs image
- Reward model: $p(r_t|s_t, h_t)$ - predicts scalar reward
- Stochastic transition: $p(s_t|h_t)$ - samples latent state from prior
- Deterministic transition: $h_t = f(h_{t-1}, s_{t-1}, a_{t-1})$ - GRU/RNN update
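The implementation below trains both decoders with mean squared error. A minimal sketch of why that matches the log-likelihood terms, assuming a Gaussian decoder with fixed unit variance (an assumption; the section does not state the decoder variance): the log-density is $-\tfrac{1}{2}\lVert o_t - \hat{o}_t\rVert^2$ minus a constant, so maximizing it is the same as minimizing squared error. `gaussian_log_likelihood` is a hypothetical helper.

```python
import math

def gaussian_log_likelihood(x, mean, std=1.0):
    """Sum over dimensions of log N(x_i; mean_i, std^2)."""
    total = 0.0
    for xi, mi in zip(x, mean):
        total += -0.5 * ((xi - mi) / std) ** 2 - math.log(std) - 0.5 * math.log(2 * math.pi)
    return total

x = [0.2, 0.5, 0.9]     # toy "observation"
mean = [0.1, 0.5, 1.0]  # toy decoder output
sse = sum((a - b) ** 2 for a, b in zip(x, mean))
ll = gaussian_log_likelihood(x, mean)

# With std = 1: log-likelihood = -0.5 * SSE - (D / 2) * log(2 * pi), here D = 3
print(ll, -0.5 * sse - 1.5 * math.log(2 * math.pi))
```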

### Inference Model
Since the posterior is intractable, we approximate it with a diagonal Gaussian:
$$q(s_t|h_t, o_t) = \mathcal{N}(\mu_t, \sigma_t^2)$$
where $\mu_t, \sigma_t$ are outputs of the encoder (posterior) network, which takes $h_t$ and $o_t$ as input.
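Sampling from $q$ uses the reparameterization trick: $s_t = \mu_t + \sigma_t \epsilon$ with $\epsilon \sim \mathcal{N}(0, I)$, so gradients can flow through $\mu_t$ and $\sigma_t$. The NumPy sketch below mirrors the split-then-softplus parameterization used by `get_dist` in the RSSM code; NumPy has no autograd, so it only shows the sampling arithmetic. `sample_posterior` and the `stats` values are hypothetical.

```python
import numpy as np

def softplus(x):
    return np.log1p(np.exp(x))

def sample_posterior(stats, rng):
    """Reparameterized sample s = mu + sigma * eps from N(mu, sigma^2)."""
    mu, raw = np.split(stats, 2, axis=-1)
    sigma = softplus(raw) + 1e-4         # map raw output to a strictly positive std
    eps = rng.standard_normal(mu.shape)  # noise that does not depend on the parameters
    return mu + sigma * eps, mu, sigma

rng = np.random.default_rng(0)
stats = np.array([0.5, -1.0, 0.0, 2.0])  # hypothetical network output: [mu (2), raw std (2)]
samples = np.stack([sample_posterior(stats, rng)[0] for _ in range(20000)])
_, mu, sigma = sample_posterior(stats, rng)
print(samples.mean(axis=0), mu)      # empirical mean is close to mu
print(samples.std(axis=0), sigma)    # empirical std is close to sigma
```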

### Learning Objective: ELBO
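Training maximizes the Evidence Lower Bound (ELBO) on the log-likelihood of a sequence (the loss minimized in practice is its negative):
$$\mathcal{L}_{ELBO} = \sum_{t=1}^T \Big( \mathbb{E}_{q(s_t)}\big[\log p(o_t|s_t, h_t) + \log p(r_t|s_t, h_t)\big] - D_{KL}\big[q(s_t|h_t, o_t)\,\|\,p(s_t|h_t)\big] \Big)$$
Intuition: the reconstruction terms encourage the latent state to explain the observation and reward, while the KL term penalizes divergence between the inferred posterior and the prior dynamics, which trains the prior to predict latent states without seeing $o_t$ (what planning relies on).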
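For two diagonal Gaussians the KL term has a closed form, per dimension:
$\log\frac{\sigma_p}{\sigma_q} + \frac{\sigma_q^2 + (\mu_q - \mu_p)^2}{2\sigma_p^2} - \frac{1}{2}$, summed over latent dimensions. The sketch below (`diag_gaussian_kl` is a hypothetical helper) computes it in NumPy. `torch.distributions.kl.kl_divergence` returns this per-dimension value for two `Normal` distributions; summing over dimensions matches $D_{KL}$ of the full diagonal Gaussian, whereas averaging rescales the term.

```python
import numpy as np

def diag_gaussian_kl(mu_q, std_q, mu_p, std_p):
    """KL[q || p] for diagonal Gaussians, summed over the last axis."""
    mu_q, std_q, mu_p, std_p = map(np.asarray, (mu_q, std_q, mu_p, std_p))
    per_dim = (np.log(std_p / std_q)
               + (std_q ** 2 + (mu_q - mu_p) ** 2) / (2 * std_p ** 2)
               - 0.5)
    return per_dim.sum(axis=-1)

# Posterior identical to the prior: zero penalty.
print(diag_gaussian_kl([0.0, 1.0], [1.0, 2.0], [0.0, 1.0], [1.0, 2.0]))  # 0.0
# Posterior mean shifted by 1 against a standard-normal prior: KL = 0.5 * 1^2.
print(diag_gaussian_kl([1.0], [1.0], [0.0], [1.0]))  # 0.5
```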


## Planning in Latent Space using CEM
Once the RSSM is trained, we can simulate the future and plan actions without real interactions.
**Procedure**:\
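1. From the current $(h_t, s_t)$, sample a batch of action sequences:
$$a_{t:t+H-1} \sim \mathcal{N}(\mu, \sigma^2)$$
2. Roll out the latent model (using the prior) for each sequence to get predicted rewards
3. Select the top $K$ sequences with the highest predicted total reward
4. Refit $\mu, \sigma$ to these top-$K$ (elite) sequences
5. Repeat steps 1 to 4 for a fixed number of iterations

This yields an action plan that maximizes predicted reward. Only the first action of the plan is executed; the agent then replans at the next step (model-predictive control).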
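The procedure above can be sketched in NumPy on a toy problem. Here `reward_fn` is a hypothetical stand-in for rolling out the latent model and summing predicted rewards; it simply peaks when every action equals 0.5. `cem` and its defaults are illustrative, and `std` uses the population estimate (`ddof=0`), whereas `torch.std` defaults to the unbiased one.

```python
import numpy as np

def cem(reward_fn, horizon, action_dim, candidates=1000, top_k=100, iters=10, seed=0):
    """Cross-entropy method over action sequences of shape (horizon, action_dim)."""
    rng = np.random.default_rng(seed)
    mean = np.zeros((horizon, action_dim))
    std = np.ones((horizon, action_dim))
    for _ in range(iters):
        # 1. Sample candidates and clip them to the action bounds [-1, 1]
        noise = rng.standard_normal((candidates, horizon, action_dim))
        actions = np.clip(mean + std * noise, -1.0, 1.0)
        # 2. Score each sequence (stand-in for a latent-model rollout)
        returns = reward_fn(actions)  # shape (candidates,)
        # 3. Keep the top-K sequences
        elites = actions[np.argsort(returns)[-top_k:]]
        # 4. Refit the sampling distribution to the elites
        mean, std = elites.mean(axis=0), elites.std(axis=0) + 1e-5
    return mean  # an MPC agent would execute mean[0] and replan

target = 0.5
reward_fn = lambda A: -((A - target) ** 2).sum(axis=(1, 2))  # hypothetical reward
plan = cem(reward_fn, horizon=4, action_dim=1)
print(np.round(plan.ravel(), 3))
```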

### CEM Math
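Sample $N$ action sequences of length $H$:
$$A^i = (a^i_t, a^i_{t+1}, \ldots, a^i_{t+H-1}), \quad i = 1, \ldots, N$$
Predicted return of each sequence:
$$R^i = \sum_{j=0}^{H-1} \hat{r}^i_{t+j}$$
Select the top $K$ sequences based on $R^i$.\
Update the mean and standard deviation (elementwise over time steps and action dimensions):
$$\mu_{new} = \frac{1}{K} \sum_{i\in \text{top}K} A^i, \quad \sigma_{new} = \sqrt{\frac{1}{K} \sum_{i\in \text{top}K} \left(A^i - \mu_{new}\right)^2}$$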
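A single update can be checked by hand. The sketch below uses four hypothetical candidate sequences with horizon $H = 2$, scalar actions, made-up per-step predicted rewards $\hat{r}$, and $K = 2$; it evaluates $R^i$, picks the elites, and applies the two update formulas.

```python
import numpy as np

# Four hypothetical candidate sequences A^i, horizon H = 2, scalar actions.
A = np.array([[0.1, 0.2],
              [0.9, 0.8],
              [0.5, 0.4],
              [0.7, 1.0]])
# Hypothetical predicted per-step rewards r_hat^i_{t+j}.
r_hat = np.array([[0.0, 0.1],
                  [1.0, 0.9],
                  [0.4, 0.2],
                  [0.8, 0.7]])

R = r_hat.sum(axis=1)             # R^i -> [0.1, 1.9, 0.6, 1.5]
top_k = np.argsort(R)[-2:]        # K = 2 best sequences: indices 3 and 1
mu_new = A[top_k].mean(axis=0)    # elementwise mean over the elites
sigma_new = A[top_k].std(axis=0)  # elementwise (1/K) std over the elites
print(mu_new, sigma_new)          # [0.8 0.9] [0.1 0.1]
```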

## Implementation

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np


In [2]:
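class RSSM(nn.Module):
    """Recurrent State-Space Model: deterministic GRU state h_t plus stochastic latent s_t."""
    def __init__(self, action_dim, latent_dim=30, hidden_dim=200, obs_embed_dim=1024):
        super().__init__()
        self.latent_dim = latent_dim
        self.action_dim = action_dim
        self.hidden_dim = hidden_dim  # size of h_t; used by init_state and expected by the decoders

        self.gru = nn.GRUCell(latent_dim + action_dim, hidden_dim)  # h_t = f(h_{t-1}, s_{t-1}, a_{t-1})

        self.fc_prior = nn.Sequential(  # prior p(s_t | h_t) -> [mean, raw std]
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2*latent_dim)
        )

        self.fc_posterior = nn.Sequential(  # posterior q(s_t | h_t, o_t) -> [mean, raw std]
            nn.Linear(hidden_dim + obs_embed_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2*latent_dim)
        )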

    def init_state(self, batch_size):
        h = torch.zeros(batch_size, self.hidden_dim)
        s = torch.zeros(batch_size, self.latent_dim)
        return h, s
    
    def get_dist(self,stats):
        mean, std = torch.chunk(stats, 2, dim=-1)
        std = F.softplus(std)+ 1e-4
        return torch.distributions.Normal(mean, std)
    
    def forward(self, prev_state, action, embed_obs=None):
        h_prev, s_prev = prev_state
        x = torch.cat([s_prev, action], dim=-1)
        h = self.gru(x, h_prev)

        prior_stats = self.fc_prior(h)
        prior_dist = self.get_dist(prior_stats)

        if embed_obs is not None:
            x_post = torch.cat([h, embed_obs], dim=-1)
            post_stats = self.fc_posterior(x_post)
            post_dist = self.get_dist(post_stats)
            s = post_dist.rsample()
        else:  # no observation (imagination / planning): sample s_t from the prior
            post_stats = prior_stats
            post_dist = prior_dist
            s = prior_dist.rsample()

        return (h, s), prior_dist, post_dist, prior_stats, post_stats

In [3]:
class ObsEncoder(nn.Module):
    def __init__(self, obs_shape=(3,64,64), embed_dim=1024):
        super().__init__()
        C,H,W = obs_shape
        self.encoder = nn.Sequential(  # spatial size for a 64x64 input: 64 -> 31 -> 14 -> 6 -> 2
            nn.Conv2d(C,32, 4, stride=2), nn.ReLU(),
            nn.Conv2d(32,64,4, stride=2), nn.ReLU(),
            nn.Conv2d(64,128,4, stride=2), nn.ReLU(),
            nn.Conv2d(128,256,4, stride=2), nn.ReLU(),
            nn.Flatten(),
            nn.Linear(256*2*2, embed_dim),
            nn.ReLU()
        )

    def forward(self, obs):
        return self.encoder(obs)
    
class ObsDecoder(nn.Module):
    def __init__(self, latent_dim=30, hidden_dim=200, obs_shape=(3,64,64)):
        super().__init__()
        C,H,W = obs_shape
        self.fc = nn.Linear(latent_dim+hidden_dim, 1024)  # project [s, h] to a 1024-dim feature

        self.deconv=nn.Sequential(  # spatial size: 1 -> 5 -> 13 -> 30 -> 64
            nn.ConvTranspose2d(1024,128,5, stride=2), nn.ReLU(),
            nn.ConvTranspose2d(128,64,5, stride=2), nn.ReLU(),
            nn.ConvTranspose2d(64,32,6, stride=2), nn.ReLU(),
            nn.ConvTranspose2d(32, C, 6, stride=2)
        )

    def forward(self, s,h):
        x = torch.cat([s,h], dim=-1)
        x = self.fc(x)
        x = x.view(-1, 1024, 1, 1)  # treat the feature vector as a 1x1 map with 1024 channels
        return self.deconv(x)
    
class RewardDecoder(nn.Module):
    def __init__(self, latent_dim=30, hidden_dim=200):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(latent_dim+hidden_dim, 256), nn.ReLU(),
            nn.Linear(256, 1)
        )
    
    def forward(self, s,h):
        x = torch.cat([s,h], dim=-1)
        return self.fc(x)
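The `nn.Linear` sizes in the encoder and decoder must agree with the convolution arithmetic, which is easy to get wrong. The helpers below (hypothetical names) implement the output-size formulas from the PyTorch `nn.Conv2d` and `nn.ConvTranspose2d` documentation for dilation 1. They show that four 4x4 stride-2 convolutions shrink 64x64 to 2x2 (so the flattened size is 256 * 2 * 2 = 1024), and that transposed convolutions with kernels 5, 5, 6, 6 grow a 1x1 map to 64x64 but a 3x3 map to 96x96.

```python
def conv_out(size, kernel, stride, padding=0):
    """Output size of nn.Conv2d along one spatial axis (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1

def deconv_out(size, kernel, stride, padding=0, output_padding=0):
    """Output size of nn.ConvTranspose2d along one spatial axis (dilation 1)."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding

def encoder_size(size, kernels=(4, 4, 4, 4), stride=2):
    for k in kernels:
        size = conv_out(size, k, stride)
    return size

def decoder_size(size, kernels=(5, 5, 6, 6), stride=2):
    for k in kernels:
        size = deconv_out(size, k, stride)
    return size

print(encoder_size(64))  # 2
print(decoder_size(1))   # 64
print(decoder_size(3))   # 96
```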

In [4]:
def compute_elbo_loss(obs_seq, action_seq, reward_seq, encoder, decoder_obs, decoder_r, rssm, beta=1.0):
    """Negative ELBO over a batch of sequences; obs_seq: (B, T, C, H, W), reward_seq: (B, T)."""
    B,T,C,H,W = obs_seq.shape
    loss_obs = 0.0
    loss_reward = 0.0
    loss_kl = 0.0

    h, s = rssm.init_state(B)

    for t in range(T):
        o_t = obs_seq[:,t]
        a_t = action_seq[:,t]  # fed into the step that produces the state for o_t
        r_t = reward_seq[:,t]  # shape (B,)

        emb_o = encoder(o_t)
        (h,s), prior, posterior, prior_stats, post_stats = rssm((h,s), a_t, embed_obs=emb_o)

        o_pred = decoder_obs(s,h)
        r_pred = decoder_r(s,h).squeeze(-1)  # (B, 1) -> (B,) so it matches r_t without broadcasting

        recon_loss = F.mse_loss(o_pred, o_t, reduction='mean')
        reward_loss = F.mse_loss(r_pred, r_t, reduction='mean')
        kl_div = torch.distributions.kl.kl_divergence(posterior, prior).mean()

        loss_obs += recon_loss
        loss_reward += reward_loss
        loss_kl += kl_div

    total_loss = loss_obs + loss_reward + beta * loss_kl  # MSE terms = Gaussian NLL up to constants and scale
    return total_loss, loss_obs, loss_reward, loss_kl
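In the generative model the transition uses the previous action, $h_t = f(h_{t-1}, s_{t-1}, a_{t-1})$, while the loop above passes `action_seq[:, t]` in the same step as `o_t`. If a replay buffer stores each action next to the observation it was taken from (a common layout, but an assumption here), shifting the actions one step gives the alignment the equations describe. `shift_actions` is a hypothetical helper, and the zero action before the first step is an assumed placeholder.

```python
import numpy as np

def shift_actions(actions):
    """Map a_{0:T-1} to a_{-1:T-2} so each o_t is paired with the action that preceded it.

    actions: array of shape (B, T, action_dim). A zero action is assumed before step 0.
    """
    prev = np.zeros_like(actions)
    prev[:, 1:] = actions[:, :-1]
    return prev

acts = np.arange(1, 7, dtype=float).reshape(1, 3, 2)  # B=1, T=3, action_dim=2
print(shift_actions(acts)[0])
# [[0. 0.]
#  [1. 2.]
#  [3. 4.]]
```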

In [5]:
class CEMPlanner:
    def __init__(self, rssm, reward_model, action_dim, plan_horizon=12, optim_iters=5, candidates=1000, top_k=100, device='cpu'):
        self.rssm = rssm
        self.reward_model = reward_model
        self.action_dim = action_dim
        self.plan_horizon = plan_horizon
        self.optim_iters = optim_iters
        self.candidates = candidates
        self.top_k = top_k
        self.device = device

    @torch.no_grad()  # planning only needs forward passes through the learned model
    def plan(self, h,s):
        B = h.size(0)
        mean = torch.zeros(B, self.plan_horizon, self.action_dim, device=self.device)
        std = torch.ones_like(mean)*0.3

        for _ in range(self.optim_iters):
            actions = torch.normal(mean.unsqueeze(1).expand(-1, self.candidates, -1,-1),  # (B, C, H, A)
                                   std.unsqueeze(1).expand(-1, self.candidates, -1,-1))
            actions = actions.clamp(-1, 1)

            B,C,H,A = actions.shape
            actions = actions.reshape(B*C, H, A)  # merge batch and candidate dimensions
            hs = h.repeat_interleave(C, dim=0)
            ss = s.repeat_interleave(C, dim=0)

            total_reward = torch.zeros(B*C, device=self.device)

            for t in range(H):
                at = actions[:,t]
                (hs, ss), _, _,_,_ = self.rssm((hs, ss), at, embed_obs=None)  # imagine one step with the prior
                reward = self.reward_model(ss, hs).squeeze(-1)  # RewardDecoder expects (s, h)
                total_reward += reward

            total_reward = total_reward.view(B,C)
            topk = torch.topk(total_reward, self.top_k, dim=-1).indices  # (B, top_k)

            elites = []
            for i in range(B):
                elites.append(actions.view(B,C,H,A)[i, topk[i]])
            elites = torch.stack(elites)  # (B, top_k, H, A)

            mean = elites.mean(dim=1)
            std = elites.std(dim=1)+1e-5

        return mean[:,0]  # execute only the first action of the refined plan
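The per-batch Python loop that collects elites can be replaced by a single gather. The sketch below shows the idea in NumPy with `np.take_along_axis`, broadcasting a `(B, top_k, 1, 1)` index over the horizon and action axes; in PyTorch the counterpart would be `torch.gather` along `dim=1` with the index expanded to `(B, top_k, H, A)`. `gather_elites` and the random inputs are hypothetical.

```python
import numpy as np

def gather_elites(actions, returns, top_k):
    """Select the top_k candidate sequences per batch row without a Python loop.

    actions: (B, C, H, A) candidate action sequences
    returns: (B, C) predicted total reward per candidate
    """
    idx = np.argsort(returns, axis=1)[:, -top_k:]  # (B, top_k) indices of the best candidates
    return np.take_along_axis(actions, idx[:, :, None, None], axis=1)  # (B, top_k, H, A)

rng = np.random.default_rng(0)
actions = rng.standard_normal((2, 5, 3, 1))
returns = rng.standard_normal((2, 5))
elites = gather_elites(actions, returns, top_k=2)
print(elites.shape)  # (2, 2, 3, 1)
```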