In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
import torch.optim as optim
from torch.nn.utils import clip_grad_norm_

In [None]:
import gym
from gym import spaces
import cv2
import numpy as np

In [10]:

class Encoder(nn.Module):
  """Convolutional encoder: channels-last image batch -> 1024-d embedding.

  Four stride-2 convolutions (kernel 4, padding 1) each halve the spatial
  size, following O = (I - K + 2P) / S + 1. For a 64 x 64 input the final
  feature map is 256 x 4 x 4, which is flattened and projected to 1024.
  """

  def __init__(self, in_channels=3):
    super().__init__()
    # Shapes in the comments are (channels x height x width) for a 64 x 64 input.
    self.cv1 = nn.Conv2d(in_channels, 32, kernel_size=4, stride=2, padding=1)  # -> 32 x 32 x 32
    self.cv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1)           # -> 64 x 16 x 16
    self.cv3 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)          # -> 128 x 8 x 8
    self.cv4 = nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1)         # -> 256 x 4 x 4
    self.fn = nn.Linear(256 * 4 * 4, 1024)

  def forward(self, x):
    """Encode `x` of shape (batch, H, W, C) into a (batch, 1024) tensor."""
    feats = x.permute(0, 3, 1, 2)  # channels-last -> channels-first for Conv2d
    for conv in (self.cv1, self.cv2, self.cv3, self.cv4):
      feats = F.relu(conv(feats))
    flat = feats.reshape(feats.size(0), -1)
    return self.fn(flat)

In [11]:
# Smoke test: one channels-last 64 x 64 RGB image should encode to a 1024-d vector.
img = torch.rand([1, 64, 64, 3])
enc = Encoder(3)
print(enc(img).shape)

torch.Size([1, 1024])


### GRU at Time Step $t$

**Inputs:**
*   $x_t$: Current input
*   $h_{t-1}$: Previous hidden state

**Equations:**

*   **Reset gate ($r_t$):** Controls how much past state is used to form new content.
    $$r_t = \sigma(W_r x_t + U_r h_{t-1})$$

*   **Update gate ($z_t$):** Controls how much of the old state is kept.
    $$z_t = \sigma(W_z x_t + U_z h_{t-1})$$

*   **Candidate hidden state ($\tilde{h}_t$):** An RNN-like state, gated by the reset gate.
    $$\tilde{h}_t = \tanh(W_h x_t + U_h (r_t \odot h_{t-1}))$$

*   **Final hidden state ($h_t$):** Interpolation between the old and new candidate states.
    $$h_t = z_t \odot h_{t-1} + (1 - z_t) \odot \tilde{h}_t$$

In [None]:

# """
# The reset gate determines how much of the old state is used to construct new information.
# The update gate determines how much of the old state is kept versus replaced.
# """
# class GRU(nn.Module):
#   def __init__(self, input_size, hidden_size=200):
#     super().__init__()
#     hidden_size = hidden_size
#     self.wr = nn.Linear(input_size, hidden_size)
#     self.ur = nn.Linear(hidden_size, hidden_size)

#     self.wz = nn.Linear(input_size, hidden_size)
#     self.uz = nn.Linear(hidden_size, hidden_size)


#     self.wh = nn.Linear(input_size, hidden_size)
#     self.uh = nn.Linear(hidden_size, hidden_size)
#     self.tanh = nn.Tanh()
#     self.sigmoid = nn.Sigmoid()

#   def forward(self, x, h_old):
#     r = self.sigmoid(self.wr(x) + self.ur(h_old))
#     z = self.sigmoid(self.wz(x) + self.uz(h_old))
#     h_beta = self.tanh(self.wh(x) + self.uh(r*h_old))
#     h = z* h_old + (1 - z)*h_beta
#     return h



In [None]:
class GRU(nn.Module):
  """Deterministic transition cell: (s_t, a_t, h_old) -> next hidden state.

  Wraps `nn.GRUCell`, whose input is the stochastic state concatenated with
  the action along the last dimension.
  """

  def __init__(self, state_size=30, action_dim=4, hidden_size=200):
    super().__init__()
    self.gru = nn.GRUCell(state_size + action_dim, hidden_size)

  def forward(self, s_t, a_t, h_old):
    """Return the hidden state produced from `[s_t, a_t]` and `h_old`."""
    state_action = torch.cat((s_t, a_t), dim=-1)
    next_hidden = self.gru(state_action, h_old)
    return next_hidden

In [22]:
class Posterior(nn.Module):
  """Posterior network q(s_t | e_t, h_t): mean and std of a diagonal Gaussian.

  Args:
    input_size: size of the concatenated `[e_t, h_t]` input. The default 1224
      is 1024 (encoder embedding) + 200 (GRU hidden state).
    output: dimensionality of the stochastic state.
  """

  def __init__(self, input_size=1224, output = 30):
    super().__init__()
    # Honour the caller-supplied input_size; it used to be overwritten with a
    # hard-coded 1024 + 200, which made the parameter a no-op.
    self.fc = nn.Linear(input_size, 256)
    self.fc_mu = nn.Linear(256, output)
    self.fc_std = nn.Linear(256, output)

  def forward(self, e_t, h_t):
    """Return `(mean, std)` computed from embedding `e_t` and hidden state `h_t`."""
    eps = 0.1 # std floor: keeps the model from becoming 100% confident
    x = torch.cat([e_t, h_t], dim=-1)
    x = F.relu(self.fc(x))
    mean = self.fc_mu(x)
    std = F.softplus(self.fc_std(x)) + eps
    return mean, std

In [25]:
class Prior(nn.Module):
  """Prior network p(s_t | h_t): mean and std of a diagonal Gaussian."""

  def __init__(self, hidden_size=200, output_size=30):
    super().__init__()
    self.fc = nn.Linear(hidden_size, 256)
    self.fc_mu = nn.Linear(256, output_size)
    self.fc_std = nn.Linear(256, output_size)

  def forward(self, h):
    """Return `(mu, std)` for hidden state `h`; std is softplus output plus 0.1."""
    min_std = 0.1  # floor added to the softplus output
    features = F.relu(self.fc(h))
    return self.fc_mu(features), F.softplus(self.fc_std(features)) + min_std


In [16]:
class RSSM(nn.Module):
  """Recurrent state-space model: encoder, GRU transition, prior and posterior.

  Args:
    action_dim: size of the action vector concatenated with the stochastic
      state before it enters the GRU. Defaults to 4 (the GRU's own default),
      so `RSSM()` behaves as before; pass the environment's real action size
      when it differs.
  """

  def __init__(self, action_dim=4):
    super().__init__()
    self.encoder = Encoder()
    self.gru = GRU(action_dim=action_dim)
    self.posterior = Posterior()
    self.prior = Prior()

  @staticmethod
  def _sample(mean, std):
    """Reparameterised draw: mean + std * eps with eps ~ N(0, I)."""
    return mean + std * torch.randn_like(mean)

  def obs_step(self, h_old, s_old, obs, a_t):
    """Advance one step using a real observation.

    Returns `(m_pr, std_pr, m_po, std_po, h, s)`: prior and posterior
    parameters, the new hidden state, and a sample `s` from the posterior.
    """
    h = self.gru(s_old, a_t, h_old)
    e = self.encoder(obs)
    m_pr, std_pr = self.prior(h)
    m_po, std_po = self.posterior(e, h)
    s = self._sample(m_po, std_po)

    return m_pr, std_pr, m_po, std_po, h, s

  def imagine_step(self, h_old, s_old, a_t):
    """Advance one step without an observation.

    Returns `(m_pr, std_pr, h, s)` where `s` is sampled from the prior.
    """
    h = self.gru(s_old, a_t, h_old)
    m_pr, std_pr = self.prior(h)
    s = self._sample(m_pr, std_pr)

    return m_pr, std_pr, h, s



In [None]:
class Decoder(nn.Module):
  """Observation decoder: (h, s) -> channels-last image with values in (0, 1).

  A linear layer lifts `[h, s]` to a 256 x 4 x 4 feature map; four stride-2
  transposed convolutions (kernel 4, padding 1) then double the spatial size
  each time, following O = (I - 1) * S - 2P + K, to reach 3 x 64 x 64.
  """

  def __init__(self, state_size=30, hidden_size=200):
    super().__init__()
    self.fc1 = nn.Linear(state_size + hidden_size, 4096)  # 4096 = 256 * 4 * 4
    self.dec1 = nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1)  # -> 128 x 8 x 8
    self.dec2 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1)   # -> 64 x 16 x 16
    self.dec3 = nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1)    # -> 32 x 32 x 32
    self.dec4 = nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1)     # -> 3 x 64 x 64

  def forward(self, h, s):
    """Decode hidden state `h` and stochastic state `s` into (N, 64, 64, 3)."""
    latent = torch.cat([h, s], dim=-1)
    # reshape(-1, ...) also gives unbatched inputs a leading batch dim of 1.
    feat = F.relu(self.fc1(latent)).reshape(-1, 256, 4, 4)
    for upsample in (self.dec1, self.dec2, self.dec3):
      feat = F.relu(upsample(feat))
    image = torch.sigmoid(self.dec4(feat))
    return image.permute(0, 2, 3, 1)  # channels-first -> channels-last


In [18]:
# Smoke test with unbatched inputs: Decoder.forward reshapes to (-1, 256, 4, 4),
# so a single (h, s) pair comes back with a leading batch dimension of 1.
h = torch.rand(200)
s = torch.rand(30)
dec = Decoder()
img = dec(h, s)
print(img.shape)

torch.Size([1, 64, 64, 3])


In [19]:
class Reward(nn.Module):
  """Reward head: three-layer MLP mapping (h, s) to a scalar reward."""

  def __init__(self, s_size=30, hidden_size=200, hidden_dim=400):
    super().__init__()
    self.fc1 = nn.Linear(s_size + hidden_size, hidden_dim)
    self.fc2 = nn.Linear(hidden_dim, hidden_dim)
    self.fc3 = nn.Linear(hidden_dim, 1)

  def forward(self, h, s):
    """Return the predicted reward, shape (..., 1), for hidden `h` and state `s`."""
    # The stochastic state comes first in the concatenation.
    features = torch.cat([s, h], dim=-1)
    for layer in (self.fc1, self.fc2):
      features = F.relu(layer(features))
    return self.fc3(features)


In [None]:
class WorldModel(nn.Module):
  """RSSM world model with observation decoder, reward head and latent overshooting.

  Args:
    overshoot_d: maximum number of imagined steps used for the latent
      overshooting KL term.
  """

  # Latent sizes used for the initial states; they match the defaults that
  # RSSM(), Decoder() and Reward() are constructed with below.
  HIDDEN_SIZE = 200
  STATE_SIZE = 30

  def __init__(self, overshoot_d = 5):
    super().__init__()
    self.rssm = RSSM() # h_old, s_old, obs, a_t -> m_pr, std_pr, m_po, std_po, h, s
    self.decoder = Decoder() # h, s -> obs
    self.reward = Reward() # h, s -> r
    self.overshoot_d = overshoot_d

  @staticmethod
  def _gaussian_kl(q_mean, q_std, p_mean, p_std):
    """Element-wise KL(N(q_mean, q_std) || N(p_mean, p_std)) for diagonal Gaussians."""
    return (
      torch.log(p_std / q_std)
      + (q_std**2 + (q_mean - p_mean)**2) / (2 * p_std**2)
      - 0.5
    )

  def forward(self, obs_seq, action_seq):
    """Filter a sequence and compute the latent-overshooting KL.

    Args:
      obs_seq: (T, B, H, W, C) channels-last observations; the RSSM encoder
        permutes them to channels-first itself.
      action_seq: (T, B, action_dim), where `action_seq[t]` is the action
        taken after observing `obs_seq[t]`.

    Returns:
      Tuple of stacked-over-time tensors `(recon_img, pred_reward, prior_mean,
      prior_std, post_mean, post_std)` followed by the scalar `overshoot_kl`.
    """
    T, B = obs_seq.shape[:2]
    device = obs_seq.device

    h = torch.zeros(B, self.HIDDEN_SIZE, device=device)
    s = torch.zeros(B, self.STATE_SIZE, device=device)

    recon_img, pred_reward = [], []
    prior_mean, prior_std = [], []
    post_mean, post_std = [], []
    h_all, s_all = [], []

    for t in range(T):
      # Transition convention: h_t = f(h_{t-1}, s_{t-1}, a_{t-1}); the first
      # step has no previous action, so a zero action is used.
      prev_action = action_seq[t - 1] if t > 0 else torch.zeros_like(action_seq[0])
      p_m, p_s, q_m, q_s, h, s = self.rssm.obs_step(h, s, obs_seq[t], prev_action)

      recon_img.append(self.decoder(h, s))
      pred_reward.append(self.reward(h, s))

      prior_mean.append(p_m)
      prior_std.append(p_s)
      post_mean.append(q_m)
      post_std.append(q_s)

      h_all.append(h)
      s_all.append(s)

    # Latent overshooting: roll the prior forward from each filtered state and
    # pull the imagined prior towards the (detached) posterior at that step.
    overshoot_kl_terms = []
    for t in range(T - 1):
      h_im = h_all[t].detach()
      s_im = s_all[t].detach()

      D = min(self.overshoot_d, T - 1 - t)

      for d in range(1, D + 1):
        # Moving from step t+d-1 to t+d consumes the action taken at step
        # t+d-1, the same convention as the filtering loop above. The previous
        # index t+d was one step too late.
        im_m, im_s, h_im, s_im = self.rssm.imagine_step(
          h_im, s_im, action_seq[t + d - 1]
        )
        kl = self._gaussian_kl(
          post_mean[t + d].detach(), post_std[t + d].detach(), im_m, im_s
        )
        overshoot_kl_terms.append(kl.sum(dim=-1).mean())

    if overshoot_kl_terms:
      overshoot_kl = torch.stack(overshoot_kl_terms).mean()
    else:
      # T == 1: there is nothing to overshoot.
      overshoot_kl = torch.tensor(0.0, device=device)

    return (torch.stack(recon_img),
            torch.stack(pred_reward),
            torch.stack(prior_mean),
            torch.stack(prior_std),
            torch.stack(post_mean),
            torch.stack(post_std),
            overshoot_kl
    )



In [None]:
# Smoke test with T=1, B=1 and a 4-D action. With a single time step the
# overshooting loop in WorldModel.forward (range(T - 1)) never runs, so the
# returned overshoot_kl is the constant 0.0.
obs_seq = torch.rand(1, 1, 64, 64, 3)
action_seq = torch.rand(1, 1, 4)
print(obs_seq.shape, action_seq.shape)
model = WorldModel()
print(model(obs_seq, action_seq))

$$D_{KL}(Q \parallel P) = \log \left( \frac{\sigma_p}{\sigma_q} \right) + \frac{\sigma_q^2 + (\mu_q - \mu_p)^2}{2\sigma_p^2} - \frac{1}{2}$$

In [None]:
def calculate_loss(recon_img, img, reward, pred_reward, p_m, p_s, q_m, q_s,overshoot_kl, beta=0.1, beta_overshoot=0.1):
  """Total world-model loss: reconstruction + reward + weighted KL terms.

  The reconstruction error is summed over the last three (image) dimensions
  and averaged over the rest; the reward error is a plain mean; the KL is
  KL(posterior || prior) for diagonal Gaussians, summed over the last
  dimension and averaged. `beta` scales the KL and `beta_overshoot` scales
  the precomputed `overshoot_kl`.
  """
  pixel_sq_err = F.mse_loss(img, recon_img, reduction='none')
  recon_loss = pixel_sq_err.sum(dim=[-1, -2, -3]).mean()

  # `reward` lacks the trailing singleton dimension that `pred_reward` has.
  reward_sq_err = F.mse_loss(reward.unsqueeze(-1), pred_reward, reduction='none')
  pred_loss = reward_sq_err.mean()

  log_ratio = torch.log(p_s/q_s)
  quadratic = (q_s**2 + (q_m - p_m)**2)/(2*p_s**2)
  kl_loss = (log_ratio + quadratic - 0.5).sum(dim=-1).mean()

  return recon_loss + pred_loss + beta * kl_loss  + beta_overshoot * overshoot_kl

In [None]:

class ReplayBuffer:
  """Fixed-capacity circular buffer of (obs, action, reward, terminal) transitions.

  Observations are stored as uint8; `sample` returns time-major float tensors
  with observations scaled by 1/255.
  """

  def __init__(self, capacity, obs_shape, action_dim):
    self.capacity = capacity
    self.idx = 0          # next slot to write
    self.is_full = False  # set once the write head has wrapped around

    self.obs_buffer = np.empty((capacity, *obs_shape), dtype=np.uint8)
    self.action_buffer = np.empty((capacity, action_dim), dtype=np.float32)
    self.reward_buffer = np.empty((capacity,), dtype=np.float32)
    self.terminal_buffer = np.empty((capacity,), dtype=bool)

  def add(self, obs, action, reward, terminal):
    """Write one transition at the write head and advance it (with wrap-around)."""
    slot = self.idx
    self.obs_buffer[slot] = obs
    self.action_buffer[slot] = action
    self.reward_buffer[slot] = reward
    self.terminal_buffer[slot] = terminal

    self.idx = (slot + 1) % self.capacity
    if self.idx == 0:
      self.is_full = True

  def _is_valid(self, start, seq_len, current_capacity):
    """Return True when `[start, start + seq_len)` is a usable training clip.

    A clip is rejected when it
      1. runs past the filled part of the arrays,
      2. straddles the write head of a full buffer (old and new data mixed), or
      3. has a terminal flag anywhere except its final frame.
    """
    end = start + seq_len

    if end > current_capacity:
      return False

    crosses_write_head = self.is_full and start < self.idx < end
    if crosses_write_head:
      return False

    # The last frame (index end - 1) is allowed to be terminal.
    has_inner_terminal = self.terminal_buffer[start:end-1].any()
    return not has_inner_terminal

  def sample(self, batch_size, seq_len, device):
    """Draw `batch_size` valid clips of length `seq_len`.

    Returns `(obs, action, reward, terminal)` as time-major tensors on
    `device`, or None when the buffer holds fewer than `seq_len` transitions
    or not enough valid clips were found within `batch_size * 100` attempts.
    """
    current_capacity = self.capacity if self.is_full else self.idx
    if current_capacity < seq_len:
      return None

    starts = []
    for _ in range(batch_size * 100):
      if len(starts) >= batch_size:
        break
      candidate = np.random.randint(0, current_capacity)
      if self._is_valid(candidate, seq_len, current_capacity):
        starts.append(candidate)

    if len(starts) < batch_size:
      return None # Not enough valid sequences found

    obs_batch = np.stack([self.obs_buffer[i:i + seq_len] for i in starts])
    action_batch = np.stack([self.action_buffer[i:i + seq_len] for i in starts])
    reward_batch = np.stack([self.reward_buffer[i:i + seq_len] for i in starts])
    terminal_batch = np.stack([self.terminal_buffer[i:i + seq_len] for i in starts])

    return self._post_process(obs_batch, action_batch, reward_batch, terminal_batch, device)

  def _post_process(self, obs, action, reward, terminal, device):
    """Convert batch-major numpy arrays to time-major float tensors on `device`."""
    def to_float(array):
      return torch.as_tensor(array, device=device).float()

    obs_t = to_float(obs) / 255.0  # uint8 pixels -> [0, 1]
    action_t = to_float(action)
    reward_t = to_float(reward)
    terminal_t = to_float(terminal)

    # Swap (batch, time, ...) -> (time, batch, ...).
    return (
      obs_t.permute(1, 0, 2, 3, 4),
      action_t.transpose(0, 1),
      reward_t.transpose(0, 1),
      terminal_t.transpose(0, 1),
    )




In [None]:
# Scratch model/optimizer for this cell. Note that the final cell of the
# notebook re-creates both `model` and `optimizer` (with lr=1e-4), so the
# objects built here are not the ones used for the actual training run.
model = WorldModel()
optimizer = optim.AdamW(model.parameters(), lr=0.01)

def train_step(model, optimizer, obs, action, reward, device):
  """Run one gradient update on a batch and return the scalar loss.

  `device` is accepted for call-site symmetry but not used here; the tensors
  are expected to already live where the model does.
  """
  optimizer.zero_grad()

  outputs = model(obs, action)
  recon_image, pred_reward, p_m, p_s, q_m, q_s, overshoot_kl = outputs

  loss = calculate_loss(recon_image, obs, reward, pred_reward, p_m, p_s, q_m, q_s,overshoot_kl)
  loss.backward()

  # Clip the global gradient norm before the parameter update.
  clip_grad_norm_(model.parameters(), max_norm=100.0)
  optimizer.step()

  return loss.item()




In [None]:
class CEMPlanner:
    """Cross-entropy-method planner that searches action sequences in latent space.

    Args:
        model: object exposing `rssm.imagine_step(h, s, a)` and `reward(h, s)`.
        num_candidates: action sequences sampled per CEM iteration.
        top_k: number of best sequences used to refit the sampling distribution.
        n_steps: planning horizon.
        iteration: number of CEM refinement iterations.
        action_dim: dimensionality of a single action.
        action_low: lower clamp applied to sampled actions.
        action_high: upper clamp applied to sampled actions. The defaults
            (-2, 2) are the Pendulum-v1 torque range used by this notebook.
    """

    def __init__(self, model, num_candidates=1000, top_k=100,  n_steps=12, iteration=10, action_dim=4,
                 action_low=-2.0, action_high=2.0):
        self.model = model
        self.num_candidates = num_candidates
        self.top_k = top_k
        self.n_steps = n_steps
        self.iteration = iteration
        self.action_dim = action_dim
        self.action_low = action_low
        self.action_high = action_high
        self.epsillon = 0.1  # std floor so the search distribution never collapses

    @torch.no_grad()
    def plan(self, h, s):
        """Return the first action, shape (action_dim,), of the refined mean plan."""
        if (h.ndim == 1):
            h = h.unsqueeze(0)
            s = s.unsqueeze(0)
        device = h.device

        mean = torch.zeros(self.n_steps, self.action_dim, device=device)
        std = torch.ones(self.n_steps, self.action_dim, device=device)

        # Loop-invariant: every iteration starts its rollouts from the same state.
        # (1, hidden) -> (num_candidates, hidden)
        h_start = h.expand(self.num_candidates, -1)
        s_start = s.expand(self.num_candidates, -1)

        # topk raises if asked for more elements than exist.
        n_elite = min(self.top_k, self.num_candidates)

        for _ in range(self.iteration):
            noise = torch.randn(self.num_candidates, self.n_steps, self.action_dim, device=device)
            actions = mean.unsqueeze(0) + std.unsqueeze(0)*noise
            actions = actions.clamp(self.action_low, self.action_high)

            h_im, s_im = h_start, s_start
            total_rewards = torch.zeros(self.num_candidates, device=device)

            for t in range(self.n_steps):
                # actions[:, t] -> (num_candidates, action_dim): every candidate's action at step t
                _, _, h_im, s_im = self.model.rssm.imagine_step(h_im, s_im, actions[:, t])
                reward = self.model.reward(h_im, s_im)  # (num_candidates, 1)
                total_rewards += reward.squeeze(-1)

            top_indices = total_rewards.topk(n_elite).indices
            elite_actions = actions[top_indices]  # (n_elite, n_steps, action_dim)

            mean = elite_actions.mean(dim=0)  # (n_steps, action_dim)
            if n_elite > 1:
                std = elite_actions.std(dim=0) + self.epsillon
            else:
                # The sample std of a single elite is NaN; fall back to the floor.
                std = torch.zeros_like(mean) + self.epsillon

        return mean[0]  # (action_dim,)





In [None]:
def collect_experience(env, model, planner, replay_buffer, num_episodes, device):
    """
    Collects experience by interacting with the environment using the
    CEM planner and stores it in the replay buffer.

    For each episode the latent state is initialised from the first real
    observation, then the loop plans an action, steps the environment, stores
    the transition `(obs, action, reward, done)` and updates the latent state
    from the next observation. The model is put in eval mode and left there.
    """
    model.eval()

    for _ in range(num_episodes):
        obs, _ = env.reset()
        done = False

        h = torch.zeros(1, 200, device=device)
        s = torch.zeros(1, 30, device=device)

        # Encode the first real observation before planning, using a zero
        # action as the "previous action" of the first step.
        with torch.no_grad():
            init_obs = torch.tensor(obs, dtype=torch.float32, device=device).unsqueeze(0) / 255.0
            dummy_a = torch.zeros(1, planner.action_dim, device=device)
            _, _, _, _, h, s = model.rssm.obs_step(h, s, init_obs, dummy_a)

        while not done:
            action = planner.plan(h, s)
            action_np = action.cpu().numpy()

            next_obs, reward, terminated, truncated, _ = env.step(action_np)
            done = terminated or truncated

            # Store the observation the action was chosen from, not the result.
            replay_buffer.add(obs, action_np, reward, done)

            with torch.no_grad():
                # Add a batch dimension to both the observation and the action.
                obs_tensor = torch.tensor(next_obs, dtype=torch.float32, device=device).unsqueeze(0) / 255.0
                action_tensor = action.unsqueeze(0)

                # Update h and s from the real observation that followed `action`.
                _, _, _, _, h, s = model.rssm.obs_step(h, s, obs_tensor, action_tensor)

            obs = next_obs


In [None]:
def train_planner(env, model, optimizer, planner, replay_buffer, config):
    """Alternate world-model training with planner-driven data collection.

    1. Seeds the replay buffer with `config.seed_episodes` episodes of
       uniformly random actions (the untrained model cannot plan yet).
    2. For `config.total_iterations` iterations: runs up to
       `config.train_steps_per_iteration` gradient updates on sampled
       sequences, then collects `config.collect_episodes` new episodes with
       the planner.
    """
    device = config.device

    print("Collecting initial random seed experience...")

    for _ in range(config.seed_episodes):
        obs, _ = env.reset()
        done = False
        while not done:
            action = env.action_space.sample()
            next_obs, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            replay_buffer.add(obs, action, reward, done)
            obs = next_obs

    for iteration in range(config.total_iterations):
        print(f"Iteration {iteration + 1}/{config.total_iterations}")
        model.train()
        total_loss = 0.0
        updates_done = 0

        for _ in range(config.train_steps_per_iteration):
            batch = replay_buffer.sample(config.batch_size, config.seq_len, device)

            # The buffer returns None when it cannot supply enough valid clips.
            if batch is None:
                continue

            obs_batch, action_batch, reward_batch, _terminal_batch = batch

            total_loss += train_step(model, optimizer, obs_batch, action_batch, reward_batch, device)
            updates_done += 1

        # Average over the updates that actually ran; skipped batches used to
        # deflate the reported loss, and a zero step count divided by zero.
        if updates_done:
            print(f"Average Loss: {total_loss / updates_done:.4f}")
        else:
            print("Average Loss: n/a (no batch could be sampled)")

        collect_experience(env, model, planner, replay_buffer, config.collect_episodes, device)

            




In [None]:
class Config:
    """Hyperparameters for the PlaNet-style training run in this notebook."""

    def __init__(self):
        # --- Hardware: prefer the GPU when one is visible ---
        use_cuda = torch.cuda.is_available()
        self.device = torch.device('cuda' if use_cuda else 'cpu')

        # --- Outer loop: seed, then alternate training and collection ---
        self.seed_episodes = 5                 # random-action episodes collected up front
        self.total_iterations = 100            # number of train + collect cycles
        self.train_steps_per_iteration = 1000  # gradient updates per cycle
        self.collect_episodes = 1              # planner episodes gathered after each cycle

        # --- Replay buffer / batching ---
        self.batch_size = 50                   # sequences per batch
        self.seq_len = 50                      # time steps per sampled sequence
        self.capacity = 100000                 # transitions kept in the buffer

        # --- Environment ---
        self.obs_shape = (64, 64, 3)           # channels-last image observations
        self.action_dim = 1                    # Pendulum-v1 has a single continuous action
        
# Usage: build one Config and pass it to train_planner. The call is left
# commented out here because env, planner and replay_buffer are only created
# in the final cell of the notebook.
config = Config()
# train_planner(env, model, optimizer, planner, replay_buffer, config)


In [None]:
class PixelWrapper(gym.Wrapper):
    """Gym wrapper that replaces the observation with a rendered RGB frame.

    The wrapped environment's own observations are discarded; `reset` and
    `step` return `render()` output resized to `render_size` x `render_size`.
    """

    def __init__(self, env, render_size=64):
        super().__init__(env)
        self.render_size = render_size

        # Advertise image observations instead of the wrapped env's space.
        frame_shape = (render_size, render_size, 3)
        self.observation_space = spaces.Box(
            low=0, high=255,
            shape=frame_shape,
            dtype=np.uint8
        )

    def _get_pixels(self):
        """Render the wrapped env and return a (render_size, render_size, 3) frame."""
        frame = self.env.render()
        target = (self.render_size, self.render_size)

        # Only resize when the renderer did not already produce the target size.
        if frame.shape[:2] != target:
            frame = cv2.resize(frame, target, interpolation=cv2.INTER_AREA)

        return frame

    def reset(self, **kwargs):
        """Reset the wrapped env and return `(pixels, {})`; its info dict is dropped."""
        self.env.reset(**kwargs)
        return self._get_pixels(), {}

    def step(self, action):
        """Step the wrapped env, swapping its observation for the rendered frame."""
        _state, reward, terminated, truncated, info = self.env.step(action)
        pixels = self._get_pixels()
        return pixels, reward, terminated, truncated, info

In [None]:
def evaluate_planer(env, model, planner, num_episode, device):
    """
    Run the planner for `num_episode` episodes and report the returns.

    Each episode initialises the latent state from the first observation,
    then repeatedly plans, steps the environment and updates the latent state
    from the new observation. Prints each episode's return and the mean, and
    returns the mean. The model is put in eval mode.
    """
    model.eval()
    episode_returns = []

    for ep in range(num_episode):
        obs, _ = env.reset()
        episode_reward = 0

        with torch.no_grad():
            h = torch.zeros(1, 200, device=device)
            s = torch.zeros(1, 30, device=device)

            # Encode the first real observation (with a zero "previous action")
            # so planning starts from an informed latent state.
            first_frame = torch.tensor(obs, dtype=torch.float32, device=device).unsqueeze(0) / 255.0
            zero_action = torch.zeros(1, planner.action_dim, device=device)
            _, _, _, _, h, s = model.rssm.obs_step(h, s, first_frame, zero_action)

            done = False
            while not done:
                action = planner.plan(h, s)
                next_obs, reward, terminated, truncated, _ = env.step(action.cpu().numpy())
                episode_reward += reward

                # Add a batch dimension, then fold the new frame into the latent state.
                frame = torch.tensor(next_obs, dtype=torch.float32, device=device).unsqueeze(0) / 255.0
                _, _, _, _, h, s = model.rssm.obs_step(h, s, frame, action.unsqueeze(0))

                done = terminated or truncated

        episode_returns.append(episode_reward)
        print(f"Eval Episode {ep+1}: Reward = {episode_reward:.2f}")

    avg_reward = np.mean(episode_returns)
    print(f"Average Evaluation Reward: {avg_reward:.2f}")
    return avg_reward



In [None]:
# End-to-end run: pixel-observation Pendulum, PlaNet-style training, then evaluation.
base_env = gym.make('Pendulum-v1', render_mode='rgb_array')

env = PixelWrapper(base_env, render_size=64)

config = Config()
device = config.device

model = WorldModel()
# WorldModel() builds its RSSM transition cell with the GRU default
# action_dim=4, but this run uses config.action_dim (1 for Pendulum). Rebuild
# the cell before moving the model and creating the optimizer, so the
# state+action concatenation matches the GRU input size and the new
# parameters are the ones being optimised.
model.rssm.gru = GRU(action_dim=config.action_dim)
model = model.to(device)
optimizer = optim.AdamW(model.parameters(), lr=1e-4)

planner = CEMPlanner(model, action_dim=config.action_dim)
replay_buffer = ReplayBuffer(config.capacity, config.obs_shape, config.action_dim)
print("Starting PlaNet Training...")
train_planner(env, model, optimizer, planner, replay_buffer, config)
# The arguments used to be swallowed by an inline comment, leaving an
# unclosed call; the name matches the `evaluate_planer` definition.
evaluate_planer(env, model, planner, 5, device)