# Meta-RL with DreamerV2

Meta-Reinforcement Learning (Meta-RL), also known as "learning to learn", addresses a fundamental challenge in traditional RL: the difficulty of efficiently adapting to new, unseen tasks after training. While traditional RL typically requires extensive training for each new task, Meta-RL algorithms learn general-purpose policies or learning rules capable of rapid adaptation.
Integrating Meta-RL with DreamerV2 allows an agent to adapt rapidly to new tasks from very few interactions (few-shot adaptation) and to generalize learned policies efficiently to previously unseen scenarios.

## Background

Meta-Learning involves two nested levels of learning:
- Inner-loop (Adaptation): Fast learning that occurs within a new task, typically using a small amount of data.
- Outer-loop (Meta-learning): Slower learning process that trains a meta-learner to generalize well and adapt quickly.

## Theory

Meta-RL aims to find parameters $\theta$ (of the model or policy) that minimize the expected loss after a gradient-based adaptation step on a task $T$:
$$\min_\theta \mathbb{E}_{T \sim p(T)} [L_T(\theta - \alpha \nabla_\theta L_T(\theta))]$$
Where:
- $T \sim p(T)$: Task distribution
- $L_T$: Loss on task T
- $\alpha$: Step size for adaptation
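
To make the objective concrete, the following minimal sketch evaluates it on a toy problem. Everything in it is an illustrative assumption rather than part of the notebook: each task is a scalar quadratic loss $L_T(\theta) = \tfrac{1}{2}(\theta - c_T)^2$ with a hypothetical task center $c_T$, and the names `task_loss`, `task_grad`, and `post_adaptation_objective` are made up for this example.

```python
# Toy illustration (assumed setup): each task T has loss L_T(theta) = 0.5 * (theta - c_T)^2.

def task_loss(theta, c):
    return 0.5 * (theta - c) ** 2

def task_grad(theta, c):
    # Analytic gradient of the quadratic loss with respect to theta.
    return theta - c

def post_adaptation_objective(theta, centers, alpha):
    """Average loss after one inner gradient step, taken over the listed tasks."""
    total = 0.0
    for c in centers:
        theta_adapted = theta - alpha * task_grad(theta, c)  # inner step
        total += task_loss(theta_adapted, c)                 # loss after adaptation
    return total / len(centers)

centers = [-1.0, 3.0]   # two hypothetical tasks
theta, alpha = 0.0, 0.5

before = sum(task_loss(theta, c) for c in centers) / len(centers)
after = post_adaptation_objective(theta, centers, alpha)
print(before, after)  # 2.5 0.625
```

For this quadratic family the post-adaptation loss is $(1 - \alpha)^2$ times the pre-adaptation loss, which is why a single step with $\alpha = 0.5$ reduces it by a factor of four.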

## Math

### Inner-Loop Adaptation Step (Task-Specific)
Given task-specific data $D_T$, compute adapted parameters $\theta'_T$:
$$\theta'_T = \theta - \alpha \nabla_\theta L_T(D_T, \theta)$$

### Outer-Loop Meta-Update (Meta-Learning)
Update meta-parameters $\theta$ using adapted task parameters $\theta'_T$:
$$\theta \leftarrow \theta - \beta \nabla_\theta \sum_{T \sim p(T)} L_T(D^{val}_T, \theta'_T)$$
Where:
- $D^{val}_T$: Validation data for task T
- $\beta$: Meta-learning rate
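
Applying the chain rule to the outer objective shows where the cost of the meta-update comes from. The derivation below uses only the definitions above. The first-order variant on the last line is a commonly used approximation and is included as context, not as something the text above prescribes.

$$
\begin{aligned}
\theta'_T &= \theta - \alpha \nabla_\theta L_T(D_T, \theta) \\
\frac{\partial \theta'_T}{\partial \theta} &= I - \alpha \nabla^2_\theta L_T(D_T, \theta) \\
\nabla_\theta L_T(D^{val}_T, \theta'_T)
  &= \left(I - \alpha \nabla^2_\theta L_T(D_T, \theta)\right)^\top \nabla_{\theta'} L_T(D^{val}_T, \theta')\Big|_{\theta' = \theta'_T} \\
  &\approx \nabla_{\theta'} L_T(D^{val}_T, \theta')\Big|_{\theta' = \theta'_T}
  \quad \text{(first-order approximation: the Hessian term is dropped)}
\end{aligned}
$$

The exact meta-gradient therefore requires second derivatives of the inner loss, while the first-order form only needs the gradient evaluated at the adapted parameters.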

## Implementation

### Pseudocode
``` pseudocode
for each meta-training epoch:
    for each task T sampled from task distribution:
        1. Collect a small adaptation set D_train for task T.
        2. Compute adapted parameters \theta'_T from D_train (inner-loop adaptation).
        3. Evaluate the adapted parameters on D_val (task validation set).
    4. Update meta-parameters \theta using the aggregated gradient from all tasks.
```
**Inner-Loop Adaptation**

``` python
def adapt(theta, D_train, alpha, adapt_steps=1):
    # Schematic: L_train is the task's training loss and grad(...) its gradient w.r.t. theta
    for step in range(adapt_steps):
        theta = theta - alpha * grad(L_train(theta, D_train))
    return theta
```
**Outer-Loop Optimization**

```python
for epoch in range(meta_learning_epochs):
    meta_loss = 0
    for task in tasks:
        theta_task = adapt(theta, D_train[task], alpha)
        meta_loss += L_val(theta_task, D_val[task])
    theta = theta - beta * grad(meta_loss)
```
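
The two schematic loops above can be made runnable on a toy problem. The sketch below is an illustration under stated assumptions, not the notebook's training code: each task's data is a list of scalars, the loss is the mean of $\tfrac{1}{2}(\theta - x)^2$ over those scalars, and gradients are written analytically. The names `loss_grad`, `meta_gradient`, and the task data are hypothetical.

```python
# Toy MAML-style loop on scalar tasks (assumed setup, analytic gradients).

def mean(xs):
    return sum(xs) / len(xs)

def loss_grad(theta, data):
    # Gradient of L(theta, D) = mean over x in D of 0.5 * (theta - x)^2.
    return theta - mean(data)

def adapt(theta, d_train, alpha, adapt_steps=1):
    """Inner loop: a few gradient steps on the task's training data."""
    for _ in range(adapt_steps):
        theta = theta - alpha * loss_grad(theta, d_train)
    return theta

def meta_gradient(theta, d_train, d_val, alpha):
    """Gradient of the validation loss at the one-step adapted parameters w.r.t. theta.

    For this quadratic loss d(theta')/d(theta) = 1 - alpha, so the chain rule
    gives (1 - alpha) * dL_val/dtheta'.
    """
    theta_adapted = adapt(theta, d_train, alpha, adapt_steps=1)
    return (1.0 - alpha) * loss_grad(theta_adapted, d_val)

# Two hypothetical tasks, each with a training and a validation split.
tasks = {
    "task_a": {"train": [-1.2, -0.8], "val": [-1.0]},
    "task_b": {"train": [2.9, 3.1], "val": [3.0]},
}

theta, alpha, beta = 0.0, 0.5, 0.5
for epoch in range(100):
    # Outer loop: sum the per-task meta-gradients, then take one meta-step.
    total_grad = sum(
        meta_gradient(theta, t["train"], t["val"], alpha) for t in tasks.values()
    )
    theta = theta - beta * total_grad

print(round(theta, 4))  # 1.0
```

With these two symmetric tasks the meta-parameter settles midway between the task optima, the starting point from which one inner step does equally well on both.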

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import gymnasium as gym
import numpy as np
import random
import matplotlib.pyplot as plt

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
class ConvEncoder(nn.Module):
    def __init__(self, image_shape, latent_dim):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=4, stride=2), nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2), nn.ReLU(),
        )

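        # Infer the flattened feature size by passing a dummy image through the conv stack
        dummy_input = torch.zeros(1, *image_shape)
        with torch.no_grad():
            flat_dim = self.encoder(dummy_input).view(1, -1).shape[1]

        # Project the flattened conv features to the latent dimension
        self.fc = nn.Linear(flat_dim, latent_dim)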

    def forward(self, x):
        x = x / 255.0
        x = self.encoder(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)
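
The dummy forward pass in `ConvEncoder` avoids computing the flattened size by hand. For reference, the same number follows from the standard output-size formula for an unpadded convolution, $\lfloor (n - k)/s \rfloor + 1$. The helper `conv_out` below is a hypothetical name used only for this check, and the 96×96 input size is taken from the `(3, 96, 96)` shape passed to the agent later in the notebook.

```python
def conv_out(n, kernel_size, stride):
    # Output length of a convolution with no padding and dilation 1.
    return (n - kernel_size) // stride + 1

size = 96
for _ in range(3):            # three conv layers, each with kernel_size=4, stride=2
    size = conv_out(size, kernel_size=4, stride=2)

flat_dim = 128 * size * size  # 128 output channels in the last conv layer
print(size, flat_dim)         # 10 12800
```

The spatial size shrinks 96 → 47 → 22 → 10, so the linear layer sees 12800 input features for this image shape.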


In [3]:
class RSSM(nn.Module):
    def __init__(self, latent_dim, action_dim):
        super().__init__()
        self.gru = nn.GRUCell(latent_dim + action_dim, latent_dim)
        self.mu_layer = nn.Linear(latent_dim, latent_dim)
        self.logvar_layer = nn.Linear(latent_dim, latent_dim)

    def forward(self, h, z, a):
        if a.dim() == 1:
            a = a.unsqueeze(0)
        if h.dim() == 1:
            h = h.unsqueeze(0)

        x = torch.cat([z, a], dim=-1)
        h_next = self.gru(x, h)
        mu, logvar = self.mu_layer(h_next), self.logvar_layer(h_next)
        z_next = mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)

        return h_next, z_next, mu, logvar
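
The sampling line in `RSSM.forward` uses the reparameterization trick: noise is drawn from a standard normal and then shifted and scaled, so `z_next` remains a differentiable function of `mu` and `logvar`. A small standard-library sketch of the same formula follows. The function `sample_latent` and the numeric values are illustrative assumptions, not part of the notebook.

```python
import math
import random

def sample_latent(mu, logvar, rng):
    # z = mu + sigma * eps, with sigma = exp(0.5 * logvar) and eps ~ N(0, 1).
    eps = rng.gauss(0.0, 1.0)
    return mu + math.exp(0.5 * logvar) * eps

rng = random.Random(0)
mu, logvar = 2.0, math.log(0.25)   # variance 0.25, so sigma = 0.5

samples = [sample_latent(mu, logvar, rng) for _ in range(20000)]
sample_mean = sum(samples) / len(samples)
sample_std = math.sqrt(sum((s - sample_mean) ** 2 for s in samples) / len(samples))
print(sample_mean, sample_std)     # close to mu and sigma, up to sampling noise
```

Parameterizing the log-variance rather than the standard deviation keeps the scale positive without constraining the output of the linear layer.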

In [4]:
class Actor(nn.Module):
    def __init__(self, latent_dim, action_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim, 128), nn.ReLU(),
            nn.Linear(128, action_dim), nn.Tanh()
        )

    def forward(self, z):
        return self.net(z)


In [5]:
class Critic(nn.Module):
    def __init__(self, latent_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim, 128), nn.ReLU(),
            nn.Linear(128, 1)
        )

    def forward(self, z):
        return self.net(z)


In [6]:
class IntrinsicReward(nn.Module):
    def __init__(self, latent_dim, intrinsic_scale=0.1):
        super().__init__()
        self.intrinsic_scale = intrinsic_scale
        self.fc = nn.Linear(latent_dim, 1)

    def forward(self, z_pred, z_next):
        error = (z_next - z_pred)**2  # squared prediction error, shape: [batch_size, latent_dim]
        intrinsic_reward = self.fc(error)  # learned linear projection, shape: [batch_size, 1]
        return self.intrinsic_scale * intrinsic_reward


In [7]:
class MetaDreamerV2Agent:

    def __init__(self, image_shape, action_dim, latent_dim=32):
        self.image_shape = image_shape
        self.action_dim = action_dim
        self.latent_dim = latent_dim
        self.encoder = ConvEncoder(image_shape, latent_dim).to(device)
        self.rssm = RSSM(latent_dim, action_dim).to(device)
        self.actor = Actor(latent_dim, action_dim).to(device)
        self.critic = Critic(latent_dim).to(device)
        self.intrinsic_reward = IntrinsicReward(latent_dim).to(device)

        self.optimizer = optim.Adam(self.parameters(), lr=1e-4)

    def parameters(self):
        return (
            list(self.encoder.parameters())
            + list(self.rssm.parameters())
            + list(self.actor.parameters())
            + list(self.critic.parameters())
            + list(self.intrinsic_reward.parameters())
        )
    
    def clone(self):
        clone = MetaDreamerV2Agent(self.image_shape, self.action_dim, self.latent_dim)
        clone.load_state_dict(self.state_dict())
        return clone
    
    def state_dict(self):
        return {
            'encoder': self.encoder.state_dict(),
            'rssm': self.rssm.state_dict(),
            'actor': self.actor.state_dict(),
            'critic': self.critic.state_dict(),
            'intrinsic_reward': self.intrinsic_reward.state_dict()
        }
    
    def load_state_dict(self, state_dict):
        self.encoder.load_state_dict(state_dict['encoder'])
        self.rssm.load_state_dict(state_dict['rssm'])
        self.actor.load_state_dict(state_dict['actor'])
        self.critic.load_state_dict(state_dict['critic'])
        self.intrinsic_reward.load_state_dict(state_dict['intrinsic_reward'])

    def adapt(self, obs_seq, action_seq, reward_seq, steps=1, lr=1e-3):
        adapted = self.clone()
        optimizer = optim.Adam(adapted.parameters(), lr=lr)

        for _ in range(steps):
            loss = adapted.compute_loss(obs_seq, action_seq, reward_seq)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        return adapted
    
    def compute_loss(self, obs_seq, action_seq, reward_seq):
        z = self.encoder(obs_seq[0])
        h = torch.zeros_like(z).to(device)

        zs, rewards = [], []
        for t in range(len(action_seq)):
            # Predict the next latent from the current latent and the action taken
            h, z_pred, _, _ = self.rssm(h, z, action_seq[t])
            z_next = self.encoder(obs_seq[t+1])

            intrinsic_reward = self.intrinsic_reward(z_pred, z_next)

            combined_reward = reward_seq[t] + intrinsic_reward

            zs.append(z_pred)
            rewards.append(combined_reward)

            z = z_next

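        # One-step TD targets: r_i + gamma * V(z_{i+1}), with no bootstrap on the last step
        gamma = 0.99
        values = [self.critic(z_i) for z_i in zs]
        targets = [
            rewards[i] + gamma * values[i + 1].detach() if i + 1 < len(values) else rewards[i]
            for i in range(len(values))
        ]
        value_loss = sum((values[i] - targets[i]).pow(2).mean() for i in range(len(values)))

        # Actor term as written: maximize the critic's value of the predicted latents
        actor_loss = -sum(self.critic(z_i).mean() for z_i in zs)
        total_loss = value_loss + actor_loss
        return total_loss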
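
The value targets in `compute_loss` are one-step temporal-difference targets: each step bootstraps from the critic's estimate at the next step, and the final step uses the reward alone. The plain-Python sketch below reproduces that indexing on made-up numbers. The helper `td_targets` is a hypothetical name, and the rewards and values are arbitrary.

```python
def td_targets(rewards, values, gamma=0.99):
    """One-step TD targets: r_i + gamma * V_{i+1}, with no bootstrap at the last step."""
    targets = []
    for i in range(len(values)):
        if i + 1 < len(values):
            targets.append(rewards[i] + gamma * values[i + 1])
        else:
            targets.append(rewards[i])
    return targets

rewards = [1.0, 0.0, 2.0]
values = [0.5, 1.0, 0.25]
targets = td_targets(rewards, values)
print([round(t, 4) for t in targets])  # [1.99, 0.2475, 2.0]
```

In the agent the bootstrapped value is detached, so the target acts as a fixed regression label for the critic at that step.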


In [12]:
tasks = ["CarRacing-v3"]


agent = MetaDreamerV2Agent((3,96,96),3)

meta_epoch = 100

for epoch in range(meta_epoch):
    meta_loss = 0
    for task in tasks:
        env = gym.make(task)

        obs, _ = env.reset()
        obs_seq = [torch.tensor(obs, dtype=torch.float32).permute(2,0,1).unsqueeze(0).to(device)]
        action_seq, reward_seq = [], []

        for step in range(10):
            action = env.action_space.sample()
            next_obs, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated

            # Append next_obs here directly after each step
            obs_seq.append(torch.tensor(next_obs, dtype=torch.float32).permute(2,0,1).unsqueeze(0).to(device))
            action_seq.append(torch.tensor(action, dtype=torch.float32).unsqueeze(0).to(device))
            reward_seq.append(torch.tensor([reward], dtype=torch.float32).to(device))

            if done: break

        adapted = agent.adapt(obs_seq, action_seq, reward_seq, steps=5, lr=1e-3)

        # Reset environment for evaluation
        eval_obs, _ = env.reset()
        eval_obs_tensor = torch.tensor(eval_obs, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0).to(device)

        # Obtain latent embedding using the adapted agent's encoder
        z = adapted.encoder(eval_obs_tensor)

        # Evaluate using the critic network to ensure differentiable meta-loss
        value_estimate = adapted.critic(z)

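        # Task meta-loss: negative value estimate under the adapted agent
        task_loss = -value_estimate.mean()
        meta_loss += task_loss.detach()

        # `adapted` is a separate copy, so backpropagating through it never reaches
        # `agent`. As a first-order approximation, copy the adapted copy's gradients
        # onto the matching meta-parameters (parameters unused by the loss are skipped).
        grads = torch.autograd.grad(task_loss, adapted.parameters(), allow_unused=True)
        for p, g in zip(agent.parameters(), grads):
            if g is not None:
                g = g / len(tasks)
                p.grad = g if p.grad is None else p.grad + g
        env.close()

    meta_loss /= len(tasks)
    agent.optimizer.step()
    agent.optimizer.zero_grad()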

    print(f"Epoch: {epoch}, Meta Loss: {meta_loss.item():.3f}")


Epoch: 0, Meta Loss: -0.270
Epoch: 1, Meta Loss: -0.425
Epoch: 2, Meta Loss: -0.761
Epoch: 3, Meta Loss: 0.974
Epoch: 4, Meta Loss: 0.349
Epoch: 5, Meta Loss: -0.163
Epoch: 6, Meta Loss: 0.318
Epoch: 7, Meta Loss: -0.780
Epoch: 8, Meta Loss: 0.735
Epoch: 9, Meta Loss: -0.160
Epoch: 10, Meta Loss: -0.022
Epoch: 11, Meta Loss: 1.300
Epoch: 12, Meta Loss: 0.190
Epoch: 13, Meta Loss: -0.108
Epoch: 14, Meta Loss: 0.009
Epoch: 15, Meta Loss: 0.150
Epoch: 16, Meta Loss: 0.104
Epoch: 17, Meta Loss: 0.233
Epoch: 18, Meta Loss: 0.151
Epoch: 19, Meta Loss: -0.303
Epoch: 20, Meta Loss: 1.548
Epoch: 21, Meta Loss: -0.566
Epoch: 22, Meta Loss: -0.096
Epoch: 23, Meta Loss: -0.097
Epoch: 24, Meta Loss: -0.074
Epoch: 25, Meta Loss: -0.195
Epoch: 26, Meta Loss: -0.021
Epoch: 27, Meta Loss: 0.000
Epoch: 28, Meta Loss: -0.024
Epoch: 29, Meta Loss: 0.356
Epoch: 30, Meta Loss: 0.170
Epoch: 31, Meta Loss: 0.394
Epoch: 32, Meta Loss: -0.133
Epoch: 33, Meta Loss: -0.725
Epoch: 34, Meta Loss: 0.005
Epoch: 35, M