# PlaNet — Learning Latent Dynamics for Planning from Pixels
**Paper:** Hafner et al. (2019) · [arxiv.org/abs/1811.04551](https://arxiv.org/abs/1811.04551)

PlaNet learns a **world model** that predicts the future in a compressed latent space, then plans inside that space using the Cross-Entropy Method (CEM) — never planning in pixel space.

```
Pixel obs ──► Encoder ──► Posterior q(s|h,e) ──► (h, s)
                                                     │
                            GRU  h_t = f(s_{t-1}, a_{t-1}, h_{t-1})
                                                     │
                              Prior p(s|h) ◄── planning / dreaming
                                                     │
                            Decoder + Reward ──► reconstruct obs & reward
```

**File layout**
```
tutorial/
├── utils.py        ← PixelWrapper, EpisodeStorage, checkpoints,
│                       seed_storage, collect_episode, run_eval, visualize_dream
└── tutorial.ipynb  ← (this file) Config, all model architecture, training loop
```

In [None]:
# ── Cell 0: Install ────────────────────────────────────────────────────────────
!pip install -q gymnasium imageio imageio-ffmpeg matplotlib opencv-python-headless

In [None]:
# ── Cell 1: Mount Google Drive ────────────────────────────────────────────────
# All episodes and checkpoints are saved to Drive so training survives
# Colab runtime restarts.  On a local machine this block is skipped.
# The google.colab import only succeeds inside a Colab runtime, so its
# ImportError doubles as the "running locally" signal.
try:
    from google.colab import drive
    drive.mount('/content/drive', force_remount=False)
    IN_COLAB = True  # read by Cell 2 to choose the sys.path entry for utils.py
except ImportError:
    IN_COLAB = False
    print('Not in Colab — will save locally.')

In [None]:
# ── Cell 2: Imports ───────────────────────────────────────────────────────────
import sys
sys.path.insert(0, '/content/tutorial' if IN_COLAB else '.')

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import Dataset, DataLoader
import gymnasium as gym

# Plumbing (file I/O, env wrapper, data collection, visualizer)
from utils import (PixelWrapper, EpisodeStorage,
                   save_checkpoint, load_checkpoint,
                   seed_storage, collect_episode, run_eval,
                   visualize_dream)

In [None]:
# ── Cell 3: Config ────────────────────────────────────────────────────────────
class Config:
    """Hyperparameters and storage paths shared by every cell below.

    Everything is a class attribute, so ``Config.lr`` and ``config.lr`` are
    interchangeable.  Storage lives on Google Drive inside Colab and in a
    local ``./PlaNet`` folder otherwise (as Cell 1's message promises).
    """
    # Hardware
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Environment — Pendulum-v1 has a single continuous torque action
    action_dim = 1
    obs_size   = 64   # pixel H = W

    # Training schedule
    total_iterations     = 100  # outer loop: train → collect → repeat
    train_steps_per_iter = 50   # gradient steps per iteration
    batch_size           = 16
    seq_len              = 25   # timesteps per training sequence
    lr                   = 1e-4

    # Data collection
    seed_episodes    = 5   # random episodes before CEM is used
    collect_episodes = 1   # CEM episodes collected per iteration

    # Persistence — Google Drive in Colab, a local folder elsewhere.
    # IN_COLAB is set by Cell 1; globals().get() falls back to local storage
    # (instead of raising NameError) if that cell was not run.
    _in_colab      = globals().get('IN_COLAB', False)
    drive_base     = '/content/drive/MyDrive/PlaNet' if _in_colab else './PlaNet'
    episode_dir    = f'{drive_base}/episodes'
    checkpoint_dir = f'{drive_base}/checkpoints'
    viz_dir        = f'{drive_base}/visualizations'
    checkpoint_every  = 10   # save every N iterations (see the training loop)
    keep_checkpoints  = 5

config = Config()
print(f'Device: {config.device}')

---
## Model Architecture

| Module | Input → Output | Role |
|--------|---------------|------|
| **Encoder** | `(B,64,64,3)` → `(B,1024)` | Compress pixels to embedding |
| **GRU** | `(B,30+1)` + `h_{t-1}` `(B,200)` → `(B,200)` | Track deterministic history `h_t` |
| **Prior** | `(B,200)` → `(B,30)×2` | Predict `s_t` from `h_t` only |
| **Posterior** | `(B,1224)` → `(B,30)×2` | Refine `s_t` using observation |
| **Decoder** | `(B,230)` → `(B,64,64,3)` | Reconstruct pixels from `(h,s)` |
| **Reward** | `(B,230)` → `(B,1)` | Predict reward from `(h,s)` |

In [None]:
# ── Cell 4: Encoder ──────────────────────────────────────────────────────────
# Four strided convolutions progressively halve spatial resolution:
#   64×64 → 32×32 → 16×16 → 8×8 → 4×4  (each with kernel=4, stride=2, pad=1)
# A final linear layer flattens 256×4×4=4096 features → 1024-D embedding.
# This embedding is passed to the Posterior to infer the stochastic state s_t.
class Encoder(nn.Module):
    """Pixel encoder: channels-last images (B, 64, 64, C) → (B, 1024) embeddings.

    Each kernel-4 / stride-2 / pad-1 convolution halves the spatial size
    (64 → 32 → 16 → 8 → 4); a linear layer then maps the flattened
    256 × 4 × 4 = 4096 features to the 1024-D embedding.
    """
    def __init__(self, in_channels=3):
        super().__init__()
        # Submodule names become state_dict keys; keep them stable so
        # previously saved weights still line up.
        self.cv1 = nn.Conv2d(in_channels, 32, kernel_size=4, stride=2, padding=1)
        self.cv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1)
        self.cv3 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)
        self.cv4 = nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1)
        self.fc = nn.Linear(4096, 1024)

    def forward(self, x):
        feats = x.permute(0, 3, 1, 2)   # Conv2d wants (B, C, H, W)
        for conv in (self.cv1, self.cv2, self.cv3, self.cv4):
            feats = F.relu(conv(feats))
        return self.fc(torch.flatten(feats, start_dim=1))   # (B, 1024)

In [None]:
# ── Cell 5: GRU · Prior · Posterior · RSSM ───────────────────────────────────
#
# The RSSM maintains two parallel state components:
#
#   Deterministic  h_t  — computed by the GRU from (s_{t-1}, a_{t-1}, h_{t-1})
#                         Carries long-range temporal information.
#
#   Stochastic     s_t  — sampled from either:
#       Posterior q(s|h,e): uses encoded obs during *training* (more accurate)
#       Prior     p(s|h):   no obs required, used for *planning / dreaming*
#
# The KL loss KL[q || p] forces the prior to track the posterior, making the
# prior reliable enough to plan with.

class GRU(nn.Module):
    """Deterministic transition h_t = GRUCell([s_{t-1}, a_{t-1}], h_{t-1}).

    Args:
        state_size: size of the stochastic state s (default 30).
        action_dim: size of the action a (default 1).
        hidden_size: size of the deterministic state h (default 200).
    """
    def __init__(self, state_size=30, action_dim=1, hidden_size=200):
        super().__init__()
        in_features = state_size + action_dim   # s and a are concatenated
        self.cell = nn.GRUCell(in_features, hidden_size)

    def forward(self, s, a, h):
        state_action = torch.cat((s, a), dim=-1)   # (B, state_size + action_dim)
        return self.cell(state_action, h)          # (B, hidden_size)


class Prior(nn.Module):
    """Transition prior p(s_t | h_t): a diagonal Gaussian predicted from h_t alone.

    forward(h) returns (mean, std), each (B, out).  std = softplus(raw) + 0.1,
    so it never falls below 0.1; the floor keeps the KL terms (which divide
    by the prior variance) finite.
    """
    def __init__(self, hidden=200, out=30):
        super().__init__()
        self.fc  = nn.Linear(hidden, 256)
        self.mu  = nn.Linear(256, out)
        self.std = nn.Linear(256, out)   # raw scale; softplus is applied in forward

    def forward(self, h):
        feat = torch.relu(self.fc(h))
        mean = self.mu(feat)
        std = F.softplus(self.std(feat)) + 0.1
        return mean, std


class Posterior(nn.Module):
    """Filtering posterior q(s_t | h_t, e_t).

    Concatenates the observation embedding e with the deterministic state h
    and predicts a diagonal Gaussian over s_t.  Returns (mean, std), each
    (B, out); std carries the same +0.1 floor as Prior.

    Args:
        out: size of the stochastic state s (default 30).
        embed: size of the encoder embedding e (default 1024, the Encoder output).
        hidden: size of the deterministic state h (default 200, the GRU output).
    """
    def __init__(self, out=30, embed=1024, hidden=200):
        super().__init__()
        self.fc  = nn.Linear(embed + hidden, 256)  # (B, embed+hidden) → (B, 256)
        self.mu  = nn.Linear(256, out)
        self.std = nn.Linear(256, out)             # raw scale; softplus in forward

    def forward(self, e, h):
        # e comes first, then h — this order must match the fc weight layout.
        x = F.relu(self.fc(torch.cat([e, h], dim=-1)))
        return self.mu(x), F.softplus(self.std(x)) + 0.1


class RSSM(nn.Module):
    """Recurrent State-Space Model: Encoder + GRU transition + Prior/Posterior.

    obs_step     — filtering step used in training; conditions on an observation.
    imagine_step — open-loop step for planning/dreaming; prior only, no Encoder.
    Both draw s with the reparameterisation s = mean + std * eps, eps ~ N(0, I),
    so gradients flow through the sample.
    """
    def __init__(self):
        super().__init__()
        self.encoder   = Encoder()
        self.gru       = GRU()
        self.prior     = Prior()
        self.posterior = Posterior()

    @staticmethod
    def _sample(mean, std):
        """Reparameterised draw from N(mean, std²)."""
        return mean + std * torch.randn_like(mean)

    def obs_step(self, h, s, obs, a_prev):
        """Advance one step using the real observation.

        Args:
            h, s: previous deterministic (B, 200) and stochastic (B, 30) states.
            obs: current observation batch, fed to the Encoder.
            a_prev: previous action, (B, action_dim).

        Returns:
            (prior_mean, prior_std, post_mean, post_std, h_new, s_new), where
            s_new is sampled from the posterior.
        """
        h_new = self.gru(s, a_prev, h)
        embed = self.encoder(obs)
        prior_mean, prior_std = self.prior(h_new)
        post_mean, post_std = self.posterior(embed, h_new)
        s_new = self._sample(post_mean, post_std)
        return prior_mean, prior_std, post_mean, post_std, h_new, s_new

    def imagine_step(self, h, s, a):
        """Advance one step without an observation (prior only).

        Returns (prior_mean, prior_std, h_new, s_new), with s_new sampled
        from the prior.
        """
        h_new = self.gru(s, a, h)
        prior_mean, prior_std = self.prior(h_new)
        return prior_mean, prior_std, h_new, self._sample(prior_mean, prior_std)

In [None]:
# ── Cell 6: Decoder · Reward ─────────────────────────────────────────────────
#
# Decoder mirrors the Encoder architecture using transposed convolutions.
# Input: concat(h, s) → (B, 230) → linear → (B, 4096) → reshape → (B,256,4,4)
# Then progressively upscale: 4×4 → 8×8 → 16×16 → 32×32 → 64×64
# Sigmoid output keeps pixel values in [0, 1] matching the normalised input.
#
# Reward head is a small 3-layer MLP over (h, s) → scalar.
# Training the reward head ensures the latent state encodes task-relevant info.

class Decoder(nn.Module):
    """Pixel decoder: latent (h, s) → (B, 64, 64, 3) image with values in (0, 1).

    A linear layer lifts concat(h, s) to 4096 = 256·4·4 features, which are
    reshaped into a 4×4 map and upsampled ×2 four times by transposed convs.
    """
    def __init__(self, state=30, hidden=200):
        super().__init__()
        self.fc = nn.Linear(state + hidden, 4096)
        self.d1 = nn.ConvTranspose2d(256, 128, 4, 2, 1)  # 4×4   → 8×8
        self.d2 = nn.ConvTranspose2d(128, 64,  4, 2, 1)  # 8×8   → 16×16
        self.d3 = nn.ConvTranspose2d(64,  32,  4, 2, 1)  # 16×16 → 32×32
        self.d4 = nn.ConvTranspose2d(32,  3,   4, 2, 1)  # 32×32 → 64×64

    def forward(self, h, s):
        latent = torch.cat([h, s], dim=-1)
        x = F.relu(self.fc(latent)).reshape(-1, 256, 4, 4)
        for upsample in (self.d1, self.d2, self.d3):
            x = F.relu(upsample(x))
        img = torch.sigmoid(self.d4(x))    # (B, 3, 64, 64)
        return img.permute(0, 2, 3, 1)     # channels-last, like the observations


class Reward(nn.Module):
    """Reward head: 3-layer MLP mapping the latent (h, s) to a scalar, (B, 1)."""
    def __init__(self, state=30, hidden=200, dim=400):
        super().__init__()
        layers = [
            nn.Linear(state + hidden, dim),
            nn.ReLU(),
            nn.Linear(dim, dim),
            nn.ReLU(),
            nn.Linear(dim, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, h, s):
        # NOTE: features are concatenated as [s, h] here, whereas Decoder uses
        # [h, s].  Shapes agree either way; the order is kept because changing
        # it would silently scramble the first layer of saved weights.
        features = torch.cat([s, h], dim=-1)
        return self.net(features)

In [None]:
# ── Cell 7: World Model ───────────────────────────────────────────────────────
#
# Wraps RSSM + Decoder + Reward and processes full (T, B) sequences.
# Returns everything needed to compute the loss:
#   recons       : (T, B, 64, 64, 3) — pixel reconstructions
#   pred_rewards : (T, B, 1)         — predicted rewards
#   prior params : (T, B, 30) each   — p_mean, p_std
#   post  params : (T, B, 30) each   — q_mean, q_std
#   overshoot_kl : scalar            — latent overshooting regulariser
#
# Latent Overshooting (Section 3.3 of the paper)
# ─────────────────────────────────────────────────
# A standard 1-step KL [ q(s_t | h_t, e_t) || p(s_t | h_t) ] only teaches
# the prior to be accurate 1 step ahead — but CEM plans 12 steps out.
#
# Overshooting forces the prior to also match the posterior at depths 1…D:
#   KL[ q(s_{t+d} | h_{t+d}, e_{t+d}) || p(s_{t+d} | imagine_d(h_t, s_t)) ]
# This regularises the prior across multiple steps, so the model's
# imagination stays reliable long into the future.

class WorldModel(nn.Module):
    """RSSM + Decoder + Reward head, unrolled over a full (T, B) sequence.

    Args:
        overshoot_d: maximum latent-overshooting depth D (default 5).
    """
    def __init__(self, overshoot_d=5):
        super().__init__()
        self.rssm        = RSSM()
        self.decoder     = Decoder()
        self.reward_head = Reward()
        self.overshoot_d = overshoot_d

    def forward(self, obs_seq, action_seq):
        """Filter a sequence with the posterior and compute the overshooting KL.

        Args:
            obs_seq:    (T, B, 64, 64, 3) float32 observations in [0, 1].
            action_seq: (T, B, action_dim); action_seq[t] drives the
                        transition from step t to step t+1.

        Returns:
            Tuple of (recons (T,B,64,64,3), pred_rewards (T,B,1),
            p_means, p_stds, q_means, q_stds — (T,B,30) each,
            overshoot_kl — scalar tensor).
        """
        T, B   = obs_seq.shape[:2]
        device = obs_seq.device

        # Initialise hidden states to zero — no prior episode history
        h = torch.zeros(B, 200, device=device)  # deterministic state
        s = torch.zeros(B,  30, device=device)  # stochastic  state

        recons, rewards         = [], []
        p_means, p_stds         = [], []
        q_means, q_stds         = [], []
        h_all,   s_all          = [], []

        for t in range(T):
            # h_t is computed from a_{t-1} = action_seq[t-1]; zeros at t=0
            a_prev = action_seq[t-1] if t > 0 else torch.zeros_like(action_seq[0])
            p_m, p_s, q_m, q_s, h, s = self.rssm.obs_step(h, s, obs_seq[t], a_prev)

            recons.append(self.decoder(h, s))       # (B,64,64,3)
            rewards.append(self.reward_head(h, s))  # (B,1)
            p_means.append(p_m); p_stds.append(p_s)
            q_means.append(q_m); q_stds.append(q_s)
            h_all.append(h);     s_all.append(s)

        # Latent overshooting: from each filtered (h_t, s_t), roll the prior
        # forward d = 1..D steps and compare with the posterior at t+d.
        os_kl = []
        for t in range(T - 1):
            # Start states are detached, so overshooting gradients reach the
            # GRU/prior only through the imagined steps, not the filtered path.
            hi = h_all[t].detach()
            si = s_all[t].detach()
            D  = min(self.overshoot_d, T - 1 - t)
            for d in range(1, D + 1):
                # Step (t+d-1) → (t+d) is driven by action_seq[t+d-1], the same
                # convention as a_prev in the filtering loop above.
                im_m, im_s, hi, si = self.rssm.imagine_step(hi, si, action_seq[t + d - 1])
                tgt_m = q_means[t + d].detach()
                tgt_s = q_stds[t + d].detach()
                # Closed-form KL[ posterior_target || imagined ] for diagonal
                # Gaussians — the same q || p direction as the 1-step KL.
                kl = (torch.log(im_s / tgt_s)
                      + (tgt_s**2 + (tgt_m - im_m)**2) / (2 * im_s**2) - 0.5)
                os_kl.append(kl.sum(-1).mean())

        overshoot_kl = (torch.stack(os_kl).mean()
                        if os_kl else torch.tensor(0., device=device))

        return (torch.stack(recons), torch.stack(rewards),
                torch.stack(p_means), torch.stack(p_stds),
                torch.stack(q_means), torch.stack(q_stds),
                overshoot_kl)

In [None]:
# ── Cell 8: Loss + Train Step ─────────────────────────────────────────────────
#
# Three loss terms (Eq. 1 of the paper):
#
#  1. Reconstruction loss  — MSE between decoded and real pixels.
#     Summed over pixels, averaged over batch+time.
#     This is the *primary* signal that teaches the encoder/decoder to
#     compress and reconstruct the visual world.
#
#  2. Reward loss  — MSE between predicted and real reward scalars.
#     Keeps the latent state task-relevant (not just visually faithful).
#
#  3. KL divergence  — Encourages the prior p(s|h) to track the posterior
#     q(s|h,e).  This is what makes the prior usable at planning time.
#     beta scales the KL to prevent it from dominating early training.
#     beta_os applies the same weighting to the overshooting KL.

def planet_loss(recon, obs, pred_r, true_r, p_m, p_s, q_m, q_s, os_kl,
                beta=0.1, beta_os=0.1):
    """Total PlaNet objective: recon + reward + beta·KL[q‖p] + beta_os·os_kl.

    Args:
        recon, obs: reconstructed / real frames, (T, B, H, W, C).
        pred_r: predicted rewards, (T, B, 1).
        true_r: real rewards, (T, B).
        p_m, p_s: prior mean / std, (T, B, state).
        q_m, q_s: posterior mean / std, (T, B, state).
        os_kl: scalar latent-overshooting KL.
        beta, beta_os: weights of the 1-step and overshooting KL terms.

    Returns:
        Scalar loss tensor.
    """
    # Reconstruction: squared error summed over each frame's H, W, C,
    # then averaged over every (T, B) frame.
    per_frame = F.mse_loss(obs, recon, reduction='none').sum(dim=[-1, -2, -3])
    recon_loss = per_frame.mean()

    # Reward: plain MSE; true_r gets a trailing unit dim to match pred_r.
    reward_loss = F.mse_loss(true_r.unsqueeze(-1), pred_r)

    # KL[q || p] between diagonal Gaussians, summed over latent dims and
    # averaged over (T, B).
    scaled_term = (q_s ** 2 + (q_m - p_m) ** 2) / (2 * p_s ** 2)
    kl_per_dim = torch.log(p_s / q_s) + scaled_term - 0.5
    kl = kl_per_dim.sum(-1).mean()

    return recon_loss + reward_loss + beta * kl + beta_os * os_kl


def train_step(model, optimizer, obs, action, reward, device):
    """Run one optimisation step of the world model and return its loss.

    Args:
        model: WorldModel, called as model(obs, action).
        optimizer: optimizer over the model's parameters.
        obs, action, reward: time-major (T, B, ...) batches already on the
            target device.
        device: unused; kept so existing call sites stay valid.

    Returns:
        The scalar loss as a Python float.
    """
    # WorldModel returns (recon, pred_r, p_m, p_s, q_m, q_s, os_kl); the tail
    # lines up positionally with planet_loss's parameters after true_r.
    recon, pred_r, *latent_terms = model(obs, action)
    loss = planet_loss(recon, obs, pred_r, reward, *latent_terms)

    optimizer.zero_grad()
    loss.backward()
    # Cap the global gradient norm at 100 to guard against exploding
    # gradients through the recurrent unroll.
    clip_grad_norm_(model.parameters(), max_norm=100.0)
    optimizer.step()
    return loss.item()

In [None]:
# ── Cell 9: CEM Planner ───────────────────────────────────────────────────────
#
# Cross-Entropy Method (CEM) planning — no policy gradient, no value function.
#
# Algorithm (per action selection):
#   1. Sample num_candidates action sequences from a Gaussian N(mu, std)
#   2. Roll each sequence forward in LATENT SPACE (imagine_step × n_steps)
#   3. Score each sequence by summing predicted rewards
#   4. Refit mu and std to the top_k scoring sequences
#   5. Repeat n_iter times (iterative refinement)
#   6. Execute only the FIRST action of the refined mean sequence (receding horizon)
#
# CEM operates entirely in latent space, so each evaluation costs only a
# forward pass through the GRU + Prior + Reward — no pixel rendering needed.

class CEMPlanner:
    """Cross-Entropy-Method planner that searches action sequences in latent space.

    Args:
        model: world model exposing ``rssm.imagine_step`` and ``reward_head``.
        num_candidates: action sequences sampled per CEM iteration.
        top_k: elite sequences used to refit the sampling Gaussian; must be in
            [2, num_candidates] (a single elite has no defined std).
        n_steps: planning horizon (imagined steps per sequence).
        n_iter: CEM refinement iterations per call to plan().
        action_dim: size of one action vector.
        action_low, action_high: box bounds that sampled actions are clamped
            to.  The initial Gaussian is centred on the box with std equal to
            half its width, so the defaults reproduce N(0, 1) clamped to [-1, 1].

    Raises:
        ValueError: if top_k is out of range or action_low >= action_high.
    """
    def __init__(self, model, num_candidates=1000, top_k=100,
                 n_steps=12, n_iter=10, action_dim=1,
                 action_low=-1.0, action_high=1.0):
        if not 2 <= top_k <= num_candidates:
            raise ValueError(
                f'top_k must be in [2, num_candidates={num_candidates}], got {top_k}')
        if action_low >= action_high:
            raise ValueError(
                f'action_low ({action_low}) must be < action_high ({action_high})')
        self.model          = model
        self.num_candidates = num_candidates  # random action sequences per step
        self.top_k          = top_k           # elite count for refitting
        self.n_steps        = n_steps         # planning horizon
        self.n_iter         = n_iter          # CEM refinement iterations
        self.action_dim     = action_dim
        self.action_low     = action_low
        self.action_high    = action_high

    @torch.no_grad()
    def plan(self, h, s, device):
        """Return the first action of the CEM-refined mean plan from (h, s).

        Args:
            h: deterministic state expandable to (num_candidates, hidden), e.g. (1, 200).
            s: stochastic state expandable to (num_candidates, state), e.g. (1, 30).
            device: device on which candidate actions are sampled.

        Returns:
            Tensor of shape (action_dim,).
        """
        low, high = self.action_low, self.action_high
        # Gaussian belief over action sequences, centred on the action box.
        mu  = torch.full((self.n_steps, self.action_dim), (low + high) / 2, device=device)
        std = torch.full_like(mu, (high - low) / 2)

        # The start state is identical every iteration, so broadcast it once.
        H0 = h.expand(self.num_candidates, -1)
        S0 = s.expand(self.num_candidates, -1)

        for _ in range(self.n_iter):
            # Sample: (num_candidates, n_steps, action_dim), kept inside the box
            noise = torch.randn(self.num_candidates, self.n_steps,
                                self.action_dim, device=device)
            acts = (mu + std * noise).clamp(low, high)

            H, S = H0, S0
            G = torch.zeros(self.num_candidates, device=device)  # cumulative reward
            for t in range(self.n_steps):
                # Imagine one step forward using the prior (no obs)
                _, _, H, S = self.model.rssm.imagine_step(H, S, acts[:, t])
                G += self.model.reward_head(H, S).squeeze(-1)

            # Refit the Gaussian to the top_k highest-return sequences
            elite = acts[G.topk(self.top_k).indices]
            mu  = elite.mean(0)
            std = elite.std(0).clamp(min=1e-4)

        return mu[0]  # first action of the refined mean sequence

In [None]:
# ── Cell 10: Episode Dataset ──────────────────────────────────────────────────
#
# Why cache episodes in RAM instead of reading from Drive each step?
# ──────────────────────────────────────────────────────────────────
# With virtual_len = batch_size × train_steps (e.g. 16 × 50 = 800 samples),
# reading every sample from Drive would cost 800 file-reads per iteration at
# ~10 ms/read, i.e. ~8 seconds of pure I/O spread across the gradient steps.
# Loading every episode once at iteration start costs one read per episode
# instead, and all 800 samples are then served from RAM.
#
# virtual_len trick
# ──────────────────
# The DataLoader needs __len__ to know how many batches to produce.  If we
# returned len(episodes) the loader would produce no batches at all when only
# 5 episodes exist (5 < batch_size=16, and drop_last=True discards the partial
# batch).  By reporting a virtual length equal to batch_size × train_steps we
# always get exactly train_steps batches.
# __getitem__ ignores the DataLoader index and samples randomly instead.

class EpisodeDataset(Dataset):
    """Random fixed-length windows drawn from episodes cached in RAM.

    Args:
        storage: object whose ``episode_paths`` lists ``.npz`` episode files
            holding time-major arrays 'obs', 'action', 'reward', 'terminal'.
        seq_len: length of every sampled window.
        virtual_len: value reported by __len__; the DataLoader index is ignored.

    Each item is (obs, action, reward, terminal) as float32 tensors; obs is
    rescaled from uint8 [0, 255] to [0, 1].

    Raises:
        RuntimeError: if no episode has at least seq_len steps.
    """
    def __init__(self, storage, seq_len, virtual_len):
        self.seq_len = seq_len
        self.virtual_len = virtual_len
        self.eps = []
        for path in storage.episode_paths:
            with np.load(path) as ep:
                if len(ep['reward']) >= seq_len:
                    # np.array() materialises an in-memory copy that stays
                    # valid after the NpzFile is closed by the `with` block.
                    self.eps.append({k: np.array(ep[k]) for k in ep})
        if not self.eps:
            raise RuntimeError('No episodes >= seq_len. Collect more data first.')
        print(f'Dataset: {len(self.eps)} episodes cached in RAM')

    def __len__(self):  return self.virtual_len

    def __getitem__(self, _):
        # Two levels of randomness: which episode, and which start frame
        ep = self.eps[np.random.randint(len(self.eps))]
        T  = len(ep['reward'])
        i  = np.random.randint(0, T - self.seq_len + 1)
        window = slice(i, i + self.seq_len)
        # Normalise obs from uint8 [0,255] → float32 [0,1]
        obs = ep['obs'][window].astype(np.float32) / 255.0
        # The model is float32; casting here means float64 actions/rewards in
        # storage cannot cause dtype mismatches in GRUCell or mse_loss.
        # astype() also returns fresh, writeable arrays for torch.from_numpy.
        return (torch.from_numpy(obs),
                torch.from_numpy(ep['action'][window].astype(np.float32)),
                torch.from_numpy(ep['reward'][window].astype(np.float32)),
                torch.from_numpy(ep['terminal'][window].astype(np.float32)))


def collate_fn(batch):
    """Collate (obs, act, rew, term) samples into time-major (T, B, ...) batches.

    WorldModel.forward loops over dim 0 as time, so each field is stacked to
    (B, T, ...) and then has its first two dims swapped.  transpose(0, 1)
    works for any trailing shape, so the same function handles image obs
    (T,B,64,64,3), actions (T,B,action_dim), per-step scalars (T,B) and also
    flat vector observations.

    Returns:
        Tuple (obs, act, rew, term) of time-major tensors.
    """
    obs, act, rew, term = zip(*batch)
    return tuple(torch.stack(field).transpose(0, 1)
                 for field in (obs, act, rew, term))

In [None]:
# ── Cell 11: Instantiate Everything ──────────────────────────────────────────
# Pixel-observation Pendulum: PixelWrapper (utils.py) wraps the rendered env.
# render_size=64 should agree with config.obs_size and the 64×64 Encoder/Decoder.
base_env = gym.make('Pendulum-v1', render_mode='rgb_array')
env      = PixelWrapper(base_env, render_size=64)

model     = WorldModel(overshoot_d=5).to(config.device)
# No weight_decay argument is passed, so AdamW uses its library default.
optimizer = optim.AdamW(model.parameters(), lr=config.lr)
# NOTE(review): CEMPlanner clamps sampled actions to [-1, 1]; Pendulum-v1's
# torque range is documented as [-2, 2], so presumably PixelWrapper or
# collect_episode rescales actions — confirm in utils.py.
planner   = CEMPlanner(model, action_dim=config.action_dim)
storage   = EpisodeStorage(config.episode_dir)

print(f'Parameters: {sum(p.numel() for p in model.parameters()):,}')

In [None]:
# ── Cell 12: Training Loop ────────────────────────────────────────────────────
#
# Each iteration follows the PlaNet cycle:
#   a. Load all episodes from Drive into RAM
#   b. Run train_steps gradient updates on random subsequences
#   c. Collect one new episode with the improved CEM planner
#   d. Optionally checkpoint to Drive
#
# load_checkpoint resumes automatically if a .pt file exists on Drive;
# returns 0 (start from scratch) otherwise — no explicit if-guard needed.

# Resume point from utils.load_checkpoint.
# NOTE(review): save_checkpoint below stores `iteration` (the index of the
# iteration just completed); if load_checkpoint returns that value unchanged,
# that iteration is repeated on resume — confirm the convention in utils.py.
start_iter = load_checkpoint(model, optimizer, config)
device     = config.device

# Seed with random episodes only on the very first run
if len(storage) == 0:
    seed_storage(env, storage, config.seed_episodes)

for iteration in range(start_iter, config.total_iterations):
    print(f'\nIteration {iteration+1}/{config.total_iterations}  | Episodes: {len(storage)}')

    # ── a. Build DataLoader from current Drive episodes ───────────────────
    # Rebuilt every iteration; EpisodeDataset reads storage.episode_paths at
    # construction, so episodes added in step c are presumably picked up here.
    virtual_len = config.batch_size * config.train_steps_per_iter
    dataset = EpisodeDataset(storage, config.seq_len, virtual_len)
    loader  = DataLoader(
        dataset, batch_size=config.batch_size,
        shuffle=False,    # randomness is in __getitem__
        num_workers=0,    # 0 avoids fork-deadlock in Colab
        collate_fn=collate_fn, pin_memory=False, drop_last=True,
    )

    # ── b. Gradient updates ───────────────────────────────────────────────
    model.train()
    total_loss, steps = 0.0, 0
    # The 4th field (terminal flags) is not used by the loss.
    for obs_b, act_b, rew_b, _ in loader:
        # Safety cap only: with drop_last=True and virtual_len as above, the
        # loader already yields exactly train_steps_per_iter batches.
        if steps >= config.train_steps_per_iter:
            break
        total_loss += train_step(model, optimizer,
                                 obs_b.to(device), act_b.to(device), rew_b.to(device), device)
        steps += 1
    print(f'  Avg Loss: {total_loss / max(steps,1):.4f}  ({steps} steps)')

    # ── c. Collect one new CEM episode ────────────────────────────────────
    # Called once per iteration; config.collect_episodes is not read here
    # (whether utils.collect_episode uses it is not visible from this file).
    collect_episode(env, model, planner, storage, config)

    # ── d. Checkpoint ─────────────────────────────────────────────────────
    if (iteration + 1) % config.checkpoint_every == 0:
        save_checkpoint(model, optimizer, iteration, config)

In [None]:
# ── Cell 13: Evaluate ─────────────────────────────────────────────────────────
# run_eval runs 5 episodes with the CEM planner and prints per-episode rewards.
# Reward range for Pendulum-v1:  ~-1600 (random)  →  ~-150 (well-trained)
# Final evaluation with the CEM planner over 5 episodes; the returned value is
# stored as avg_reward (presumably the mean episode return — see utils.run_eval).
avg_reward = run_eval(env, model, planner, n_episodes=5, config=config)
# Save a final checkpoint tagged with the last iteration index.
save_checkpoint(model, optimizer, config.total_iterations - 1, config)

In [None]:
# ── Cell 14: Dream Visualizer ─────────────────────────────────────────────────
# visualize_dream shows how well the model can imagine the future:
#   • Green (CONTEXT): real frames fed via encoder — warms up latent state
#   • Red   (DREAM):   prior rolls out with NO real pixels — pure imagination
#
# Good model: dream frames match real frames for many steps
# Weak model: dream frames blur or drift quickly after context ends
print('Generating dream visualisations...')
# Visualise up to the first 3 stored episodes (fewer if storage holds less).
for ep_idx in range(min(3, len(storage))):
    # 5 real context frames, then 50 steps imagined by the prior.
    visualize_dream(model, storage, config,
                    episode_idx=ep_idx, context_frames=5, dream_steps=50)
print(f'Saved to: {config.viz_dir}')