# Pixel-Based DreamerV2: Vision-Based RL

Many real-world applications involve visual perception rather than structured, low-dimensional state inputs. However, learning directly from high-dimensional pixel observations is computationally expensive and sample-inefficient with traditional RL approaches.
Pixel-Based DreamerV2 addresses this by:
- Using a convolutional encoder to compress high-dimensional images into compact latent representations.
- Performing planning and policy optimization in the latent space, which is far cheaper than operating on raw pixels.
- Learning a latent transition model that enables efficient, imagination-based planning.

**Real-World Applications:**
- Robotics: Manipulation, Navigation, and Control
- Autonomous Driving: Perception, Planning, and Control
- Games and simulated control: Atari, MuJoCo, and similar benchmarks

## Background

- **Latent Representation Learning**
A CNN-based encoder maps raw images $x_t$ to a low-dimensional latent vector $z_t$. The policy is trained on latent-space trajectories, so it never has to learn directly from raw pixels.
- **Recurrent State-Space Model (RSSM)**
Combines a stochastic latent state $z_t$ with a deterministic recurrent state $h_t$ to predict future states. The stochastic part models uncertainty, while the recurrent part (a GRU in DreamerV2) carries long-term dependencies.
- **Policy Optimization in Latent Space**
Instead of acting directly on images, the policy and value function operate on latent states. Because they can be trained on imagined latent trajectories, the agent needs far fewer environment interactions (better sample efficiency), and the policy does not have to fit pixel-level detail.

## Theoretical Explanation

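### Latent-Space Encoding
At each timestep $t$, the image observation $x_t$ is encoded into a latent vector $z_t$:
$$z_t = \mathrm{CNNEncoder}(x_t)$$
The encoder should capture the task-relevant information in the image, such as object positions, colors, and textures. In DreamerV2 it is trained jointly with the rest of the world model (image reconstruction, reward prediction, and KL terms).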
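As a worked example of how much the encoder compresses its input, the sketch below applies the standard output-size formula for an unpadded convolution, $\lfloor (n - k)/s \rfloor + 1$, to the three kernel-4, stride-2 layers used in the `ConvEncoder` cell. It assumes a 96×96×3 frame (CarRacing's observation size) and a 32-dimensional latent; the helper `conv_out` is illustrative.

```python
def conv_out(n: int, kernel: int = 4, stride: int = 2) -> int:
    """Output length of an unpadded convolution along one spatial axis."""
    return (n - kernel) // stride + 1

# Assumed input: one 96x96 RGB frame, encoded to a 32-dimensional latent.
height, width, channels, latent_dim = 96, 96, 3, 32

size = height
for _ in range(3):          # three conv layers with kernel 4, stride 2
    size = conv_out(size)   # 96 -> 47 -> 22 -> 10

flat_features = 128 * size * size   # the last conv layer has 128 channels
pixels = height * width * channels

print(f"final feature map: 128 x {size} x {size} = {flat_features} features")
print(f"{pixels} pixel values -> {latent_dim} latent values ({pixels // latent_dim}x smaller)")
```

The final `nn.Linear` therefore maps 12,800 convolutional features down to the 32-dimensional $z_t$.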


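### RSSM Latent Transition
The deterministic and stochastic state updates are:
$$h_t = f(h_{t-1}, z_{t-1}, a_{t-1})$$
$$z_t \sim q(z_t \mid h_t, x_t)$$
The deterministic state $h_t$ is computed by a recurrent network $f$ (a GRU in DreamerV2) from the previous state $h_{t-1}$, the previous latent $z_{t-1}$, and the previous action $a_{t-1}$. The stochastic state $z_t$ is then sampled from the representation model (posterior) $q(z_t \mid h_t, x_t)$, which also sees the current image. During imagination no image is available, so DreamerV2 instead samples from the transition predictor (prior) $\hat{z}_t \sim p(\hat{z}_t \mid h_t)$.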
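A minimal NumPy sketch of one RSSM step, under simplifying assumptions: a single `tanh` layer stands in for DreamerV2's GRU, both distributions are diagonal Gaussians (DreamerV2 itself uses vectors of categorical latents), and all weights are random. The names `rssm_step`, `W_h`, `W_prior`, and `W_post` are hypothetical. The point it illustrates is structural: the prior depends only on $h_t$, while the posterior also conditions on an image embedding.

```python
import numpy as np

rng = np.random.default_rng(0)
H, Z, A, E = 8, 4, 2, 6   # hypothetical sizes: deterministic, stochastic, action, image embedding

# Random stand-in weights (a trained model would learn these)
W_h = rng.normal(scale=0.3, size=(H, H + Z + A))
W_prior = rng.normal(scale=0.3, size=(2 * Z, H))
W_post = rng.normal(scale=0.3, size=(2 * Z, H + E))

def gaussian_sample(params, rng):
    """Split params into (mean, log-variance) and draw a reparameterized sample."""
    mu, logvar = params[:Z], params[Z:]
    return mu + np.exp(0.5 * logvar) * rng.standard_normal(Z)

def rssm_step(h_prev, z_prev, a_prev, embed=None, rng=rng):
    # Deterministic path: h_t = f(h_{t-1}, z_{t-1}, a_{t-1})
    h = np.tanh(W_h @ np.concatenate([h_prev, z_prev, a_prev]))
    if embed is None:
        # Imagination: sample from the prior p(z_t | h_t); no image needed
        z = gaussian_sample(W_prior @ h, rng)
    else:
        # Observation: the posterior q(z_t | h_t, x_t) also sees the image embedding
        z = gaussian_sample(W_post @ np.concatenate([h, embed]), rng)
    return h, z

h0, z0, a0 = np.zeros(H), np.zeros(Z), np.array([0.5, -0.5])
h1, z_post = rssm_step(h0, z0, a0, embed=rng.standard_normal(E))  # filtering with an image
h1_again, z_prior = rssm_step(h0, z0, a0)                         # dreaming without one
```

Both calls produce the same $h_1$ because the deterministic path never looks at the image; only the stochastic sample differs.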

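### Imagination-Based Policy Learning
The world model is learned from real trajectories, but the actor and critic are trained on imagined rollouts within the latent space. Simplified forms of the actor and critic objectives on imagined latent sequences are:
$$L_{actor} = -\mathbb{E}\left[\sum_t V(z_t)\right]$$
$$L_{critic} = \mathbb{E}\left[\left(V(z_t) - \left(r_t + \gamma V(z_{t+1})\right)\right)^2\right]$$
Here $\gamma$ is the discount factor and the target $r_t + \gamma V(z_{t+1})$ is treated as a constant (no gradient). The actor loss depends on the actor only through the imagined states $z_t$, which the RSSM produces from the actor's actions, so its gradient must flow back through the world model. DreamerV2 itself uses multi-step $\lambda$-returns and an entropy bonus instead of these one-step forms.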
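A small worked example of the two simplified losses on a three-step imagined rollout. The rewards and critic outputs are made-up numbers, and $\gamma = 0.99$ matches the value used in the implementation below.

```python
import numpy as np

gamma = 0.99
rewards = np.array([1.0, 0.0, 2.0])       # r_0, r_1, r_2 along an imagined rollout
values = np.array([0.5, 0.4, 0.3, 0.2])   # critic outputs V(z_0) ... V(z_3)

# One-step TD targets r_t + gamma * V(z_{t+1}), treated as constants
targets = rewards + gamma * values[1:]    # approx. [1.396, 0.297, 2.198]

# Critic: mean squared TD error over the rollout
critic_loss = np.mean((values[:-1] - targets) ** 2)

# Actor: negative sum of the values of the imagined states z_1..z_3
actor_loss = -values[1:].sum()

print(targets, critic_loss, actor_loss)
```

Here the critic underestimates $V(z_0)$ and $V(z_2)$ relative to their targets, so minimizing `critic_loss` pushes those two values up.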

## Implementation

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import gymnasium as gym
import numpy as np
import matplotlib.pyplot as plt

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [47]:
class ConvEncoder(nn.Module):
    def __init__(self, image_shape, latent_dim):
        """
        Convolutional encoder for pixel-based DreamerV2.
        
        Args:
            image_shape (tuple): Image shape, either (C, H, W) or (H, W, C).
            latent_dim (int): Latent space dimension.
        """
        super().__init__()

        # Normalize the expected input shape to (C, H, W)
        if len(image_shape) == 3 and image_shape[0] in [1, 3]:
            self.input_shape = tuple(image_shape)
        elif len(image_shape) == 3:
            self.input_shape = (image_shape[2], image_shape[0], image_shape[1])
        else:
            raise ValueError(f"Invalid image_shape: {image_shape}")
        in_channels = self.input_shape[0]

        # Three unpadded stride-2 convolutions, each roughly halving height and width
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=4, stride=2), nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2), nn.ReLU()
        )

        # Infer the flattened feature size with a dummy forward pass
        with torch.no_grad():
            dummy_output = self.encoder(torch.zeros(1, *self.input_shape))
        self.fc = nn.Linear(dummy_output.numel(), latent_dim)

    def forward(self, x):
        x = x / 255.0  # Normalize input images
        x = self.encoder(x)
        x = x.view(x.size(0), -1)  # Flatten feature maps
        return self.fc(x)
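`ConvEncoder` expects channel-first batches, whereas Gymnasium image environments such as CarRacing return channel-last `uint8` frames. The NumPy sketch below mirrors the preprocessing the training code performs with `permute(2, 0, 1).unsqueeze(0)`, plus the `x / 255.0` scaling done inside `forward`. The random frame stands in for a real observation, and `preprocess` is a hypothetical helper.

```python
import numpy as np

def preprocess(frame: np.ndarray) -> np.ndarray:
    """(H, W, C) uint8 frame -> (1, C, H, W) float32 array scaled to [0, 1]."""
    chw = np.transpose(frame, (2, 0, 1))        # like torch's permute(2, 0, 1)
    batched = chw[np.newaxis]                    # like unsqueeze(0)
    return batched.astype(np.float32) / 255.0    # the scaling that forward() applies

frame = np.random.default_rng(0).integers(0, 256, size=(96, 96, 3), dtype=np.uint8)
x = preprocess(frame)
print(x.shape, x.min() >= 0.0, x.max() <= 1.0)
```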


In [56]:
class RSSM(nn.Module):
    def __init__(self, latent_dim, action_dim):
        """
        Recurrent State-Space Model (RSSM) for DreamerV2.

        Args:
            latent_dim (int): Latent state size.
            action_dim (int): Number of action dimensions.
        """
        super().__init__()
        self.gru = nn.GRUCell(latent_dim + action_dim, latent_dim)
        self.mu_layer = nn.Linear(latent_dim, latent_dim)
        self.logvar_layer = nn.Linear(latent_dim, latent_dim)

    def forward(self, h, z, a):
        """
        RSSM forward pass: Processes hidden state, latent state, and action.

        Args:
            h (torch.Tensor): Previous deterministic hidden state (batch_size, latent_dim).
            z (torch.Tensor): Previous stochastic latent state (batch_size, latent_dim).
            a (torch.Tensor): Action taken (batch_size, action_dim) or (action_dim,).

        Returns:
            h_next, z_next, mu, logvar
        """
        # Ensure `a` has batch dimension
        if a.dim() == 1:
            a = a.unsqueeze(0)  # Convert from (action_dim,) → (1, action_dim)
        
        # Ensure a matches batch size of z
        if a.shape[0] != z.shape[0]:
            a = a.expand(z.shape[0], -1)  # Expand to match batch size
        
        h_next = self.gru(torch.cat([z, a], dim=-1), h)  # Ensure correct shape
        mu, logvar = self.mu_layer(h_next), self.logvar_layer(h_next)
        z_next = mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)

        return h_next, z_next, mu, logvar
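The last lines of `RSSM.forward` use the reparameterization trick: $z = \mu + \sigma \epsilon$ with $\sigma = \exp(\tfrac{1}{2}\log\sigma^2)$ and $\epsilon \sim \mathcal{N}(0, I)$, which keeps the sample differentiable with respect to `mu` and `logvar`. The NumPy check below, using arbitrary example values, confirms that samples drawn this way have the intended mean and standard deviation.

```python
import numpy as np

rng = np.random.default_rng(42)
mu = np.array([0.0, 1.0, -2.0])
logvar = np.array([0.0, np.log(4.0), np.log(0.25)])   # variances 1, 4, 0.25

eps = rng.standard_normal((200_000, 3))
z = mu + np.exp(0.5 * logvar) * eps                   # same formula as RSSM.forward

# Empirical means should be close to mu, standard deviations close to [1, 2, 0.5]
print(z.mean(axis=0).round(2), z.std(axis=0).round(2))
```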

In [57]:
class Actor(nn.Module):
    def __init__(self, latent_dim, action_dim):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 128), nn.ReLU(),
            nn.Linear(128, action_dim), nn.Tanh()
        )

    def forward(self, z):
        return self.model(z)

class Critic(nn.Module):
    def __init__(self, latent_dim):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 128), nn.ReLU(),
            nn.Linear(128, 1)
        )

    def forward(self, z):
        return self.model(z)
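The `Actor` ends in `Tanh`, so every action dimension lies in $[-1, 1]$. CarRacing's continuous action space is assumed here to be steering in $[-1, 1]$ and gas and brake in $[0, 1]$ (check `env.action_space` to confirm), so negative gas or brake values fall outside the valid range. A NumPy sketch of clipping an actor output to hard-coded bounds:

```python
import numpy as np

# Assumed CarRacing bounds: [steering, gas, brake]
low = np.array([-1.0, 0.0, 0.0], dtype=np.float32)
high = np.array([1.0, 1.0, 1.0], dtype=np.float32)

raw_action = np.tanh(np.array([0.3, -1.2, 2.0], dtype=np.float32))  # stand-in for Actor output
action = np.clip(raw_action, low, high)   # negative gas -> 0, other values unchanged
print(raw_action.round(3), action.round(3))
```

Clipping discards the negative half of the `Tanh` range for gas and brake; an affine rescale such as `low + (raw_action + 1) / 2 * (high - low)` is an alternative that keeps the whole range usable.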

In [63]:
class IntrinsicReward(nn.Module):
    def __init__(self, latent_dim, intrinsic_scale=0.1):
        """
        Computes intrinsic (curiosity) rewards from prediction error in latent space.

        Args:
            latent_dim (int): Size of latent space (stored for reference; the reward has no learned weights).
            intrinsic_scale (float): Scale factor for intrinsic reward.
        """
        super().__init__()
        self.latent_dim = latent_dim
        self.intrinsic_scale = intrinsic_scale

    def forward(self, z_pred, z_next):
        """
        Computes intrinsic reward from latent state prediction error.

        Args:
            z_pred (torch.Tensor): Predicted latent state, shape (batch_size, latent_dim).
            z_next (torch.Tensor): True latent state, shape (batch_size, latent_dim).

        Returns:
            torch.Tensor: Intrinsic reward, shape (batch_size, 1), detached from the graph.
        """
        # Mean squared error over latent dimensions -> (batch_size, 1).
        # The error is already one scalar per sample, so it is scaled directly (no learned projection).
        prediction_error = ((z_next - z_pred) ** 2).mean(dim=-1, keepdim=True)
        # Rewards are constants for the critic, so no gradient flows back through the bonus
        return self.intrinsic_scale * prediction_error.detach()
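A small numeric example of the curiosity bonus: `intrinsic_scale` times the mean squared error across latent dimensions, computed per batch row as in `IntrinsicReward.forward`. The latent values are made up, and the scale is the default `0.1`.

```python
import numpy as np

intrinsic_scale = 0.1
z_pred = np.array([[0.0, 1.0, 2.0, 3.0],    # row 0: prediction off by 1 in every dimension
                   [0.5, 0.5, 0.5, 0.5]])   # row 1: perfect prediction
z_next = np.array([[1.0, 2.0, 3.0, 4.0],
                   [0.5, 0.5, 0.5, 0.5]])

prediction_error = ((z_next - z_pred) ** 2).mean(axis=-1, keepdims=True)  # shape (2, 1)
bonus = intrinsic_scale * prediction_error
print(bonus)  # about 0.1 for row 0 and 0 for row 1
```

Transitions the model already predicts well earn no bonus, which is what steers exploration toward unfamiliar states.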


In [64]:
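class DreamerV2PixelAgent:
    def __init__(self, image_shape, action_dim, latent_dim=32, intrinsic_scale=0.1,
                 gamma=0.99, horizon=15):
        self.encoder = ConvEncoder(image_shape, latent_dim).to(device)
        self.rssm = RSSM(latent_dim, action_dim).to(device)
        self.actor = Actor(latent_dim, action_dim).to(device)
        self.critic = Critic(latent_dim).to(device)
        self.intrinsic_reward = IntrinsicReward(latent_dim, intrinsic_scale).to(device)
        self.gamma = gamma      # discount factor
        self.horizon = horizon  # length of imagined rollouts for the actor

        # This simplified agent has no decoder/reconstruction loss, so the encoder receives no
        # training signal and acts as a fixed random feature extractor (DreamerV2 trains it jointly
        # with the world model). The RSSM is trained to predict the next embedding.
        self.optim_model = optim.Adam(self.rssm.parameters(), lr=3e-4)
        self.optim_actor = optim.Adam(self.actor.parameters(), lr=3e-4)
        self.optim_critic = optim.Adam(self.critic.parameters(), lr=3e-4)

    def train(self, obs_seq, action_seq, reward_seq):
        """
        One update from a real episode.

        Args:
            obs_seq: T+1 observations o_0..o_T, as (1, C, H, W) tensors or (H, W, C) arrays.
            action_seq: T actions a_0..a_{T-1}.
            reward_seq: T rewards r_0..r_{T-1}.
        """
        with torch.no_grad():
            obs = torch.cat([o.to(device) if isinstance(o, torch.Tensor) else
                             torch.tensor(o, dtype=torch.float32, device=device).permute(2, 0, 1).unsqueeze(0)
                             for o in obs_seq])
            embeds = self.encoder(obs)  # (T+1, latent_dim)
        actions = torch.stack([torch.as_tensor(a, dtype=torch.float32, device=device).reshape(-1)
                               for a in action_seq])  # (T, action_dim)
        ext_rewards = torch.stack([torch.as_tensor(r, dtype=torch.float32, device=device).reshape(1)
                                   for r in reward_seq])  # (T, 1)

        # 1) World model: the RSSM prior N(mu, exp(logvar)) should predict the next embedding.
        h = torch.zeros(1, embeds.shape[1], device=device)
        model_loss, hs, intrinsic = 0.0, [], []
        for t in range(len(action_seq)):
            h, _, mu, logvar = self.rssm(h, embeds[t:t + 1], actions[t:t + 1])
            target = embeds[t + 1:t + 2]
            model_loss = model_loss + 0.5 * (logvar + (target - mu).pow(2) / logvar.exp()).mean()
            intrinsic.append(self.intrinsic_reward(mu.detach(), target))  # curiosity bonus, (1, 1)
            hs.append(h.detach())
        self.optim_model.zero_grad()
        model_loss.backward()
        self.optim_model.step()

        # 2) Critic: one-step TD target r_t + gamma * V(z_{t+1}); the last step bootstraps with 0.
        z_real = embeds[1:]                                 # latent reached after each action
        total_rewards = ext_rewards + torch.cat(intrinsic)  # (T, 1)
        values = self.critic(z_real)                        # (T, 1)
        next_values = torch.cat([values[1:].detach(), torch.zeros_like(values[:1])])
        value_loss = (values - (total_rewards + self.gamma * next_values)).pow(2).mean()
        self.optim_critic.zero_grad()
        value_loss.backward()
        self.optim_critic.step()

        # 3) Actor: imagine `horizon` steps from the observed states and maximize the critic's value.
        #    The actor's gradient flows through a -> RSSM -> imagined z -> V(z).
        h, z = torch.cat(hs), z_real
        actor_loss = 0.0
        for _ in range(self.horizon):
            h, z, _, _ = self.rssm(h, z, self.actor(z))
            actor_loss = actor_loss - self.critic(z).mean()
        self.optim_actor.zero_grad()
        actor_loss.backward()  # also writes RSSM/critic grads; their optimizers zero them before use
        self.optim_actor.step()

        return model_loss.item(), value_loss.item(), actor_loss.item()

In [65]:
env = gym.make("CarRacing-v3")
obs_dim = env.observation_space.shape   # (96, 96, 3), channel-last
action_dim = env.action_space.shape[0]  # 3: steering, gas, brake

agent = DreamerV2PixelAgent(obs_dim, action_dim)

def to_tensor(obs):
    """(H, W, C) frame -> (1, C, H, W) float tensor on `device`."""
    return torch.tensor(obs, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0).to(device)

episodes = 200
rewards = []

for ep in range(episodes):
    obs, _ = env.reset()
    ep_reward = 0.0
    # Keep the initial frame so each episode has T+1 observations for T actions
    real_obs, real_actions, real_rewards = [to_tensor(obs)], [], []

    for step in range(1000):
        with torch.no_grad():  # acting does not need gradients
            action = agent.actor(agent.encoder(real_obs[-1])).cpu().numpy()[0]
        # Tanh outputs lie in [-1, 1]; clip to the env's bounds (gas and brake are non-negative)
        action = np.clip(action, env.action_space.low, env.action_space.high)
        next_obs, reward, terminated, truncated, _ = env.step(action)
        ep_reward += reward

        real_obs.append(to_tensor(next_obs))
        real_actions.append(torch.tensor(action, dtype=torch.float32, device=device))
        real_rewards.append(torch.tensor([reward], dtype=torch.float32, device=device))

        if terminated or truncated:
            break

    agent.train(real_obs, real_actions, real_rewards)
    rewards.append(ep_reward)
    print(f"Episode: {ep}, Reward: {ep_reward:.1f}")

env.close()

plt.plot(rewards)
plt.xlabel("Episode")
plt.ylabel("Episode return")
plt.show()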
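Per-episode CarRacing returns tend to be noisy, so a moving average often makes the trend in the `rewards` list easier to read. A minimal sketch; `moving_average` and the window size are illustrative choices, and the sample returns are made up:

```python
import numpy as np

def moving_average(values, window=10):
    """Trailing mean over `window` consecutive values (the raw values if there are fewer)."""
    values = np.asarray(values, dtype=float)
    if len(values) < window:
        return values.copy()
    kernel = np.ones(window) / window
    return np.convolve(values, kernel, mode="valid")  # len(values) - window + 1 points

returns = [-50.0, -40.0, -45.0, -30.0, -20.0, -25.0]
print(moving_average(returns, window=3))  # approx. [-45, -38.33, -31.67, -25]
```

After training, `plt.plot(moving_average(rewards))` could be plotted alongside the raw curve.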