# Multi-Task DreamerV2

Traditional RL models learn one task at a time and struggle to transfer knowledge across different tasks. Multi-Task RL (MTRL) enables a single model to learn and adapt to multiple environments.

## Background

- **Shared and task-specific components**
    A common encoder extracts general features shared by all tasks, while each task has its own policy and dynamics model; knowledge is transferred between tasks through the shared latent representation.
- **Multi-Task RL**
    A shared encoder learns a task-agnostic latent representation, and each task has its own RSSM and actor-critic model.

## Theory

- Shared Encoder:
$$z_t = \mathrm{ConvEncoder}(x_t)$$
Extracts task-agnostic latent features from the input observation.
- Task-Specific RSSM:
$$h_t = f(h_{t-1}, z_{t-1}, a_{t-1}, \text{task\_id})$$
Each task gets its own recurrent dynamics model.
- Task-Specific Actor-Critic:
$$L_{\text{actor}}^{\text{task}} = -\mathbb{E}\left[V_{\text{task}}(z_t)\right]$$
$$L_{\text{critic}}^{\text{task}} = \mathbb{E}\left[\left(V_{\text{task}}(z_t) - \left(r_t + \gamma V_{\text{task}}(z_{t+1})\right)\right)^2\right]$$

## Implementation

In [7]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import gymnasium as gym

# Use a CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [8]:
class MultiTaskConvEncoder(nn.Module):
    """Convolutional encoder shared by all tasks.

    Maps a batch of images to a latent vector of size ``latent_dim`` through
    three strided conv layers followed by a linear projection.

    Args:
        image_shape: Observation shape ``(channels, height, width)``. The
            channel count sets the first conv layer's input channels.
        latent_dim: Size of the output latent vector.
    """

    def __init__(self, image_shape, latent_dim):
        super().__init__()
        in_channels = image_shape[0]

        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=4, stride=2), nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2), nn.ReLU(),
        )

        # Infer the flattened conv output size with one dummy forward pass.
        # The dummy must live where the freshly built layers live (callers only
        # move the module with ``.to(device)`` after construction); creating it
        # on a global GPU device would crash with a device mismatch.
        layer_device = self.encoder[0].weight.device
        dummy_input = torch.zeros(1, *image_shape, device=layer_device)
        with torch.no_grad():
            conv_out_size = self.encoder(dummy_input).flatten(1).shape[1]

        self.fc = nn.Linear(conv_out_size, latent_dim)

    def forward(self, x):
        """Encode a batch of images.

        Args:
            x: Tensor of shape ``(batch, channels, height, width)`` holding raw
                pixel intensities (divided by 255 here, so presumably in
                ``[0, 255]``).

        Returns:
            Tensor of shape ``(batch, latent_dim)``.
        """
        x = x / 255.0
        features = self.encoder(x)
        return self.fc(features.flatten(1))


In [9]:
class MultiTaskRSSM(nn.Module):
    """A bank of independent ``RSSM`` dynamics models, one per task.

    Calls are routed to the model selected by ``task_id``; no parameters are
    shared between tasks.
    """

    def __init__(self, latent_dim, action_dim, num_tasks):
        super().__init__()
        per_task_models = [RSSM(latent_dim, action_dim) for _ in range(num_tasks)]
        self.task_specific_rssms = nn.ModuleList(per_task_models)

    def forward(self, h, z, a, task_id):
        """Run the RSSM belonging to ``task_id`` on ``(h, z, a)`` and return its output."""
        selected_model = self.task_specific_rssms[task_id]
        return selected_model(h, z, a)

In [10]:
class MultiTaskActor(nn.Module):
    """Per-task policy heads.

    Holds one ``Actor(latent_dim, action_dim)`` per task and forwards each call
    to the actor chosen by ``task_id``.
    """

    def __init__(self, latent_dim, action_dim, num_tasks):
        super().__init__()
        self.task_specific_actors = nn.ModuleList(
            Actor(latent_dim, action_dim) for _ in range(num_tasks)
        )

    def forward(self, z, task_id):
        """Apply the actor for ``task_id`` to the latent ``z``."""
        policy = self.task_specific_actors[task_id]
        return policy(z)

class MultiTaskCritic(nn.Module):
    """Per-task value heads.

    Holds one ``Critic(latent_dim)`` per task and forwards each call to the
    critic chosen by ``task_id``.
    """

    def __init__(self, latent_dim, num_tasks):
        super().__init__()
        self.task_specific_critics = nn.ModuleList(
            Critic(latent_dim) for _ in range(num_tasks)
        )

    def forward(self, z, task_id):
        """Apply the critic for ``task_id`` to the latent ``z``."""
        value_head = self.task_specific_critics[task_id]
        return value_head(z)


In [11]:
class MultiTaskDreamerV2Agent:
    """Multi-task DreamerV2-style agent.

    One convolutional encoder is shared by every task. Each task has its own
    RSSM, actor and critic, selected by ``task_id``. Only the actor and the
    critic own optimizers; :meth:`train` uses the encoder and RSSMs but does
    not update them.

    Args:
        image_shape: Observation shape ``(channels, height, width)``.
        action_dim: Size of the action vector.
        num_tasks: Number of tasks (one RSSM/actor/critic each).
        latent_dim: Size of the encoder's latent vector.
        lr: Adam learning rate for both the actor and the critic.
        gamma: Discount factor for the critic's TD targets.
    """

    def __init__(self, image_shape, action_dim, num_tasks, latent_dim=32, lr=3e-4, gamma=0.99):
        self.encoder = MultiTaskConvEncoder(image_shape, latent_dim).to(device)
        self.rssm = MultiTaskRSSM(latent_dim, action_dim, num_tasks).to(device)
        self.actor = MultiTaskActor(latent_dim, action_dim, num_tasks).to(device)
        self.critic = MultiTaskCritic(latent_dim, num_tasks).to(device)
        self.gamma = gamma

        self.optim_actor = optim.Adam(self.actor.parameters(), lr=lr)
        self.optim_critic = optim.Adam(self.critic.parameters(), lr=lr)

    def train(self, obs_seq, action_seq, reward_seq, task_id):
        """Run one actor update and one critic update on a trajectory segment.

        Args:
            obs_seq: Indexable sequence of at least ``len(action_seq) + 1``
                observation batches; ``obs_seq[t]`` is seen before
                ``action_seq[t]`` is taken.
            action_seq: Indexable sequence of action batches.
            reward_seq: Indexable sequence of rewards; ``reward_seq[t]`` is the
                reward for ``action_seq[t]``.
            task_id: Index of the task whose actor/critic are trained.

        Returns:
            ``None`` when ``action_seq`` is empty. Otherwise a dict with the
            scalar ``"actor_loss"`` and ``"value_loss"``.

        Notes:
            The critic uses TD(0) targets ``r_t + gamma * V(s_{t+1})`` for every
            step in the segment, bootstrapping from the final observation
            (no done flags are passed in). The actor maximises the critic's
            value of a one-step imagined latent: it proposes an action, the
            task's RSSM predicts the next latent, and the critic scores it.
        """
        horizon = len(action_seq)
        if horizon == 0:
            # Nothing to learn from; avoid calling backward() on a plain 0.
            return None

        # The encoder and the RSSM have no optimizer here, so encode and filter
        # without autograd. This also keeps the actor and critic losses from
        # sharing (and double-freeing) one encoder graph.
        with torch.no_grad():
            zs = [self.encoder(obs_seq[t]) for t in range(horizon + 1)]
            # hs[t] is the recurrent state paired with zs[t]:
            # h_{t+1} = f(h_t, z_t, a_t), starting from zeros.
            hs = [torch.zeros_like(zs[0])]
            for t in range(horizon - 1):
                h_next, _, _, _ = self.rssm(hs[t], zs[t], action_seq[t], task_id)
                hs.append(h_next)

        # Actor update: gradients flow through the RSSM and critic into the
        # actor. ``inputs=`` restricts accumulation to the actor's parameters,
        # so the critic and RSSM are left untouched.
        imagined_values = []
        for h_t, z_t in zip(hs, zs):  # stops after `horizon` pairs
            imagined_action = self.actor(z_t, task_id)
            _, z_imagined, _, _ = self.rssm(h_t, z_t, imagined_action, task_id)
            imagined_values.append(self.critic(z_imagined, task_id).mean())
        actor_loss = -torch.stack(imagined_values).sum()

        self.optim_actor.zero_grad()
        actor_loss.backward(inputs=list(self.actor.parameters()))
        self.optim_actor.step()

        # Critic update: V(s_t) is regressed onto r_t + gamma * V(s_{t+1}),
        # with the bootstrap value treated as a constant target.
        values = [self.critic(z, task_id) for z in zs]
        value_loss = sum(
            (values[t] - (reward_seq[t] + self.gamma * values[t + 1].detach())).pow(2).mean()
            for t in range(horizon)
        )

        self.optim_critic.zero_grad()
        value_loss.backward()
        self.optim_critic.step()

        return {"actor_loss": actor_loss.item(), "value_loss": value_loss.item()}



In [None]:
# Every task must share one image observation shape and one action shape: the
# agent has a single shared encoder (built for one image shape) and sizes all
# per-task actors with one action_dim. Walker2d-v4 returns state vectors, not
# images, so the second task here is a domain-randomised CarRacing variant.
task_specs = [
    # (label used in logs, Gymnasium env id, extra gym.make kwargs)
    ("CarRacing-v2", "CarRacing-v2", {}),
    ("CarRacing-v2-randomized", "CarRacing-v2", {"domain_randomize": True}),
]
tasks = [label for label, _, _ in task_specs]
num_tasks = len(tasks)

envs = [gym.make(env_id, **kwargs) for _, env_id, kwargs in task_specs]

# Fail fast with a clear message rather than an IndexError / shape mismatch later.
obs_shapes = {env.observation_space.shape for env in envs}
action_shapes = {env.action_space.shape for env in envs}
if len(obs_shapes) != 1 or len(next(iter(obs_shapes))) != 3:
    raise ValueError(f"All tasks need the same (H, W, C) image observation shape; got {obs_shapes}")
if len(action_shapes) != 1:
    raise ValueError(f"All tasks need the same action shape; got {action_shapes}")

image_shapes = [env.observation_space.shape for env in envs]
image_shapes = [(shape[2], shape[0], shape[1]) for shape in image_shapes]  # (H,W,C) -> (C,H,W)

agent = MultiTaskDreamerV2Agent(image_shapes[0], action_dim=envs[0].action_space.shape[0], num_tasks=num_tasks)

# NOTE(review): this loop only rolls out the current policy. agent.train(...)
# is never called, so the networks stay at their random initialisation.
# Collect (obs, action, reward) sequences here and pass them to agent.train
# to actually learn.
try:
    for ep in range(200):
        for task_id, env in enumerate(envs):
            obs, _ = env.reset()
            ep_reward = 0.0

            for step in range(1000):
                # (H, W, C) frame -> (1, C, H, W) float batch on the agent's device.
                obs_tensor = (
                    torch.as_tensor(obs, dtype=torch.float32, device=device)
                    .permute(2, 0, 1)
                    .unsqueeze(0)
                )
                with torch.no_grad():  # acting only; no autograd graph needed
                    action = agent.actor(agent.encoder(obs_tensor), task_id).cpu().numpy()[0]
                # Keep the action inside the environment's valid box.
                action = np.clip(action, env.action_space.low, env.action_space.high)

                obs, reward, terminated, truncated, _ = env.step(action)
                ep_reward += reward
                # Gymnasium ends an episode on either termination or truncation.
                if terminated or truncated:
                    break

            print(f"Task {tasks[task_id]} | Episode {ep} | Reward: {ep_reward:.2f}")
finally:
    for env in envs:
        env.close()


## Next Steps

### Improve the RSSM with Transformer-based models

Traditional GRUs struggle to capture long-horizon dependencies. Transformers provide better sequence modeling, allowing:
- Cross-task generalization
- More stable latent representations
- Better memory for long-term planning

$$h_t = \mathrm{Transformer}(h_{t-1}, z_{t-1}, a_{t-1}, \text{task\_id})$$

### Implementation

In [12]:
class TransformerRSSM(nn.Module):
    """Recurrent state-space model whose transition is a Transformer encoder.

    At each step the previous latent ``z``, the previous hidden state ``h``,
    the projected action and a learned task embedding form a 4-token sequence
    (each token of size ``latent_dim``). Self-attention mixes the tokens, and
    the output at the ``h`` position becomes the new hidden state. A diagonal
    Gaussian over the next latent is predicted from that hidden state.

    Args:
        latent_dim: Size of ``z``, ``h`` and every token (the Transformer's
            ``d_model``). Must be divisible by ``num_heads``.
        action_dim: Size of the action vector.
        num_tasks: Number of tasks (rows of the task-embedding table).
        num_layers: Number of Transformer encoder layers.
        num_heads: Attention heads per layer.
    """

    # Position of the previous hidden state in the per-step token sequence
    # [z, h, action, task].
    _H_TOKEN = 1

    def __init__(self, latent_dim, action_dim, num_tasks, num_layers=4, num_heads=4):
        super().__init__()
        self.task_embeddings = nn.Embedding(num_tasks, latent_dim)
        self.action_proj = nn.Linear(action_dim, latent_dim)
        encoder_layer = nn.TransformerEncoderLayer(d_model=latent_dim, nhead=num_heads, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        self.mu_layer = nn.Linear(latent_dim, latent_dim)
        self.logvar_layer = nn.Linear(latent_dim, latent_dim)

    @property
    def taks_specific_embeddings(self):
        """Backward-compatible alias for the previously misspelled attribute name."""
        return self.task_embeddings

    def forward(self, h, z, a, task_id):
        """Advance the model by one step.

        Args:
            h: Previous hidden state, shape ``(batch, latent_dim)``.
            z: Previous latent, shape ``(batch, latent_dim)``.
            a: Previous action, shape ``(batch, action_dim)``.
            task_id: Either an int shared by the whole batch or an integer
                tensor of shape ``(batch,)`` with one task index per sample.

        Returns:
            Tuple ``(h_next, z_next, mu, logvar)``, each ``(batch, latent_dim)``.
            ``z_next`` is a reparameterised sample ``mu + exp(logvar / 2) * eps``.
        """
        task_ids = torch.as_tensor(task_id, dtype=torch.long, device=z.device)
        if task_ids.dim() == 0:
            # One task for the whole batch: give every sample the same id.
            task_ids = task_ids.expand(z.size(0))
        task_embedding = self.task_embeddings(task_ids)
        a_proj = self.action_proj(a)

        # Stack (not concatenate) the inputs as tokens: (batch, 4, latent_dim),
        # so the feature size matches the Transformer's d_model.
        tokens = torch.stack([z, h, a_proj, task_embedding], dim=1)
        encoded = self.transformer_encoder(tokens)
        h_next = encoded[:, self._H_TOKEN]

        mu, logvar = self.mu_layer(h_next), self.logvar_layer(h_next)
        z_next = mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)

        return h_next, z_next, mu, logvar