# HW11: Meta-Learning in Reinforcement Learning

> - Full Name: **[Your Full Name]**
> - Student ID: **[Your Student ID]**

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/DeepRLCourse/Homework-11-Questions/blob/main/HW11_Notebook.ipynb)
[![Open In kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://raw.githubusercontent.com/DeepRLCourse/Homework-11-Questions/main/HW11_Notebook.ipynb)

## Overview
This assignment focuses on **Meta-Learning in Reinforcement Learning**, exploring algorithms that enable agents to quickly adapt to new tasks by leveraging experience from related tasks. We'll implement and experiment with:

1. **MAML (Model-Agnostic Meta-Learning)** - Gradient-based meta-learning for RL
2. **RL² (Recurrent Meta-RL)** - Black-box meta-learning using recurrent networks
3. **PEARL (Probabilistic Embeddings for Actor-critic RL)** - Context-based meta-RL with probabilistic task embeddings
4. **Few-Shot Adaptation** - Rapid learning on new tasks with minimal samples
5. **Task Distributions** - Learning across families of related RL tasks

The goal is to understand how meta-learning enables "learning to learn" and achieves fast adaptation to new tasks.


In [None]:
# @title Imports and Setup

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import gymnasium as gym
from gymnasium import spaces
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import trange
from collections import defaultdict, deque
import random
from typing import Tuple, List, Dict, Optional
import warnings
warnings.filterwarnings('ignore')

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Set plotting style
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")


## 1. Meta-Learning Environment Setup

First, let's create a task distribution for meta-learning. We'll use a parameterized CartPole environment in which tasks differ in their dynamics (pole length, pole and cart mass, gravity) and in a reward scale, so we can test how quickly an agent adapts to a task it has not seen before.
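Few-shot adaptation is only meaningful when it is measured on tasks that were not used for meta-training. Here is a minimal sketch, assuming we want reproducible task sets that do not depend on the global NumPy seed. The hypothetical helper `sample_task_params` draws parameter dicts from the same ranges that `TaskDistribution('pole_length')` uses, with its own `np.random.default_rng` generator, and sets aside a separate batch for meta-testing.

```python
import numpy as np

# Hypothetical helper (not part of the notebook): draws task-parameter dicts from
# uniform ranges using a dedicated, seeded Generator so task sets are reproducible
def sample_task_params(param_ranges, n_tasks, rng):
    return [
        {name: float(rng.uniform(low, high)) for name, (low, high) in param_ranges.items()}
        for _ in range(n_tasks)
    ]

# Same ranges as TaskDistribution('pole_length')
ranges = {'pole_length': (0.3, 0.8), 'reward_scale': (0.8, 1.2)}

rng = np.random.default_rng(0)
meta_train_params = sample_task_params(ranges, n_tasks=20, rng=rng)
# Separate draws, held out and never used for meta-training
meta_test_params = sample_task_params(ranges, n_tasks=5, rng=rng)

print(meta_test_params[0])
```

Each dict can then be passed to `ParameterizedCartPoleEnv(task_params)` to build the corresponding environment.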


In [None]:
class ParameterizedCartPoleEnv(gym.Env):
    """
    Parameterized CartPole environment for meta-learning.
    Tasks differ in pole length, mass, or reward structure.
    """
    
    def __init__(self, task_params=None):
        super().__init__()
        
        # Default task parameters
        self.default_params = {
            'pole_length': 0.5,
            'pole_mass': 0.1,
            'cart_mass': 1.0,
            'gravity': 9.8,
            'reward_scale': 1.0,
            'success_threshold': 195.0
        }
        
        # Set task parameters
        self.params = self.default_params.copy()
        if task_params:
            self.params.update(task_params)
        
        # State space: [cart_pos, cart_vel, pole_angle, pole_vel]
        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf, shape=(4,), dtype=np.float32
        )
        
        # Action space: [push_left, push_right]
        self.action_space = spaces.Discrete(2)
        
        # Environment state
        self.state = None
        self.steps = 0
        self.max_steps = 200
        
    def reset(self, **kwargs):
        """Reset environment with random initial state."""
        # Random initial state
        self.state = np.array([
            np.random.uniform(-0.1, 0.1),  # cart position
            np.random.uniform(-0.1, 0.1),  # cart velocity
            np.random.uniform(-0.1, 0.1),  # pole angle
            np.random.uniform(-0.1, 0.1)   # pole velocity
        ], dtype=np.float32)
        
        self.steps = 0
        return self.state.copy(), {}
    
    def step(self, action):
        """Execute action and return next state."""
        if self.state is None:
            raise RuntimeError("Environment not reset")
        
        # Convert action to force
        force = 10.0 if action == 1 else -10.0
        
        # CartPole dynamics with task parameters
        x, x_dot, theta, theta_dot = self.state
        
        # Physical constants
        g = self.params['gravity']
        m_cart = self.params['cart_mass']
        m_pole = self.params['pole_mass']
        l = self.params['pole_length']
        
        # Total mass
        total_mass = m_cart + m_pole
        
        # Dynamics
        temp = (force + m_pole * l * theta_dot**2 * np.sin(theta)) / total_mass
        theta_acc = (g * np.sin(theta) - np.cos(theta) * temp) / \
                   (l * (4/3 - m_pole * np.cos(theta)**2 / total_mass))
        x_acc = temp - m_pole * l * theta_acc * np.cos(theta) / total_mass
        
        # Update state
        x = x + x_dot * 0.02
        x_dot = x_dot + x_acc * 0.02
        theta = theta + theta_dot * 0.02
        theta_dot = theta_dot + theta_acc * 0.02
        
        self.state = np.array([x, x_dot, theta, theta_dot], dtype=np.float32)
        self.steps += 1
        
        # Compute reward
        reward = self._compute_reward()
        
        # Check termination
        terminated = self._is_terminated()
        truncated = self.steps >= self.max_steps
        
        return self.state.copy(), reward, terminated, truncated, {}
    
    def _compute_reward(self):
        """Compute reward based on task parameters."""
        x, x_dot, theta, theta_dot = self.state
        
        # Base reward: stay upright and centered
        reward = 1.0
        
        # Penalty for being far from center
        reward -= abs(x) * 0.1
        
        # Penalty for large angle
        reward -= abs(theta) * 0.1
        
        # Scale reward
        reward *= self.params['reward_scale']
        
        return reward
    
    def _is_terminated(self):
        """Check if episode should terminate."""
        x, x_dot, theta, theta_dot = self.state
        
        # Terminate if cart goes too far
        if abs(x) > 2.4:
            return True
        
        # Terminate if pole falls too far
        if abs(theta) > 0.2095:  # ~12 degrees
            return True
        
        return False
    
    def set_task_params(self, task_params):
        """Set new task parameters."""
        self.params.update(task_params)


class TaskDistribution:
    """
    Distribution over parameterized tasks for meta-learning.
    """
    
    def __init__(self, task_type='pole_length'):
        self.task_type = task_type
        
        if task_type == 'pole_length':
            self.param_ranges = {
                'pole_length': (0.3, 0.8),
                'reward_scale': (0.8, 1.2)
            }
        elif task_type == 'mass':
            self.param_ranges = {
                'pole_mass': (0.05, 0.2),
                'cart_mass': (0.8, 1.5),
                'reward_scale': (0.8, 1.2)
            }
        elif task_type == 'gravity':
            self.param_ranges = {
                'gravity': (8.0, 12.0),
                'reward_scale': (0.8, 1.2)
            }
        elif task_type == 'mixed':
            self.param_ranges = {
                'pole_length': (0.3, 0.8),
                'pole_mass': (0.05, 0.2),
                'gravity': (8.0, 12.0),
                'reward_scale': (0.8, 1.2)
            }
        else:
            raise ValueError(f"Unknown task_type: {task_type!r}")
    
    def sample_task(self):
        """Sample a random task from the distribution."""
        task_params = {}
        for param, (low, high) in self.param_ranges.items():
            task_params[param] = np.random.uniform(low, high)
        
        return ParameterizedCartPoleEnv(task_params)
    
    def sample_tasks(self, n_tasks):
        """Sample multiple tasks."""
        return [self.sample_task() for _ in range(n_tasks)]


# Test the environment
print("Testing Parameterized CartPole Environment...")

# Test different task types
task_types = ['pole_length', 'mass', 'gravity', 'mixed']
for task_type in task_types:
    print(f"\nTesting {task_type} tasks:")
    task_dist = TaskDistribution(task_type)
    
    # Sample a few tasks
    for i in range(3):
        task = task_dist.sample_task()
        obs, _ = task.reset()
        
        # Run a few steps
        total_reward = 0
        for step in range(10):
            action = task.action_space.sample()
            obs, reward, terminated, truncated, _ = task.step(action)
            total_reward += reward
            
            if terminated or truncated:
                break
        
        print(f"  Task {i+1}: Reward = {total_reward:.2f}, Params = {task.params}")

print("\nEnvironment test completed!")


## 2. MAML (Model-Agnostic Meta-Learning) Implementation

MAML learns an initialization of the network parameters from which one or a few gradient steps on a new task already yield a good policy. It uses a two-level optimization process:
- **Inner Loop**: Adapt to specific task using gradient descent
- **Outer Loop**: Optimize initialization for fast adaptation across tasks
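For reference, write $\alpha$ for the inner step size (`inner_lr`), $\beta$ for the meta step size (`meta_lr`) and $\mathcal{L}_{\mathcal{T}_i}$ for the loss on task $\mathcal{T}_i$. One inner step and the meta-update over a batch of $N$ tasks are then

$$\theta_i' = \theta - \alpha \nabla_\theta \mathcal{L}_{\mathcal{T}_i}(\theta), \qquad \theta \leftarrow \theta - \beta \nabla_\theta \frac{1}{N} \sum_{i=1}^{N} \mathcal{L}_{\mathcal{T}_i}(\theta_i').$$

The outer gradient differentiates through the inner step, which is why `inner_loop_update` calls `torch.autograd.grad` with `create_graph=True`. The sketch below is a toy illustration under our own assumptions: 1-D regression tasks $y = a x$ and a single scalar weight instead of a policy. The names `task_loss` and `inner_step` are hypothetical.

```python
import torch

torch.manual_seed(0)

# Hypothetical toy task family: y = a * x, each task has its own slope a
x = torch.linspace(-1.0, 1.0, 20)
task_slopes = [1.0, 2.0, 3.0]

def task_loss(w, a):
    """Mean squared error of the model y_hat = w * x on the task with slope a."""
    return ((w * x - a * x) ** 2).mean()

def inner_step(w, a, alpha):
    """One inner-loop SGD step; create_graph=True keeps the step differentiable."""
    (grad,) = torch.autograd.grad(task_loss(w, a), w, create_graph=True)
    return w - alpha * grad

w = torch.zeros((), requires_grad=True)     # meta-initialization theta
meta_opt = torch.optim.SGD([w], lr=1.0)     # meta step size beta
alpha = 0.5                                 # inner step size

for _ in range(100):
    # Outer objective: average post-adaptation loss over the task batch
    meta_loss = torch.stack([task_loss(inner_step(w, a, alpha), a) for a in task_slopes]).mean()
    meta_opt.zero_grad()
    meta_loss.backward()   # backpropagates through inner_step, second-order term included
    meta_opt.step()

print(f"meta-learned w = {w.item():.4f}")
```

Every task loss here is quadratic in $w$, so the meta-learned initialization settles at the average slope (2.0). That is the point from which one inner step minimizes the average post-adaptation loss. In the RL setting the pattern is the same, except that the squared errors are replaced by policy-gradient surrogate losses.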


In [None]:
class PolicyNetwork(nn.Module):
    """
    Policy network for MAML.
    """
    
    def __init__(self, state_dim, action_dim, hidden_dim=64):
        super().__init__()
        
        self.state_dim = state_dim
        self.action_dim = action_dim
        
        # Policy network
        self.policy = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim)
        )
        
        # Value network
        self.value = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
    
    def forward(self, state):
        """Forward pass through policy and value networks."""
        policy_logits = self.policy(state)
        value = self.value(state)
        return policy_logits, value
    
    def get_action(self, state, deterministic=False):
        """Get action from policy."""
        policy_logits, value = self.forward(state)
        
        if deterministic:
            action = torch.argmax(policy_logits, dim=-1)
        else:
            action_probs = F.softmax(policy_logits, dim=-1)
            action = torch.multinomial(action_probs, 1).squeeze(-1)
        
        return action, value
    
    def log_prob(self, state, action):
        """Get log probability of action."""
        policy_logits, _ = self.forward(state)
        log_probs = F.log_softmax(policy_logits, dim=-1)
        return log_probs.gather(1, action.unsqueeze(1)).squeeze(1)


class MAMLAgent:
    """
    MAML agent for meta-learning in RL.
    """
    
    def __init__(self, state_dim, action_dim, meta_lr=0.001, inner_lr=0.01, 
                 inner_steps=1, hidden_dim=64):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.meta_lr = meta_lr
        self.inner_lr = inner_lr
        self.inner_steps = inner_steps
        
        # Meta-network (initialization)
        self.meta_network = PolicyNetwork(state_dim, action_dim, hidden_dim)
        self.meta_optimizer = optim.Adam(self.meta_network.parameters(), lr=meta_lr)
        
        # Copy of network for inner loop updates
        self.inner_network = PolicyNetwork(state_dim, action_dim, hidden_dim)
        
    def collect_trajectories(self, env, network, num_episodes=5, max_steps=200):
        """Collect trajectories using given network."""
        trajectories = []
        
        for episode in range(num_episodes):
            obs, _ = env.reset()
            episode_data = {
                'states': [],
                'actions': [],
                'rewards': [],
                'values': [],
                'log_probs': []
            }
            
            for step in range(max_steps):
                state_tensor = torch.FloatTensor(obs).unsqueeze(0)
                
                # Get action, value, and log-probability (sampling only; no gradients needed)
                with torch.no_grad():
                    action, value = network.get_action(state_tensor)
                    log_prob = network.log_prob(state_tensor, action)
                
                # Execute action
                next_obs, reward, terminated, truncated, _ = env.step(action.item())
                done = terminated or truncated
                
                # Store data
                episode_data['states'].append(obs)
                episode_data['actions'].append(action.item())
                episode_data['rewards'].append(reward)
                episode_data['values'].append(value.item())
                episode_data['log_probs'].append(log_prob.item())
                
                obs = next_obs
                
                if done:
                    break
            
            trajectories.append(episode_data)
        
        return trajectories
    
    def compute_returns(self, rewards, gamma=0.99):
        """Compute discounted returns."""
        returns = []
        G = 0
        
        for reward in reversed(rewards):
            G = reward + gamma * G
            returns.insert(0, G)
        
        return returns
    
    def compute_advantages(self, returns, values):
        """Compute advantages A_t = G_t - V(s_t) from returns and stored value estimates."""
        return returns - values
    
    def _forward(self, network, states, params=None):
        """Run `network` normally, or with the tensors in `params` via torch.func.functional_call
        (PyTorch >= 2.0), so the loss stays differentiable w.r.t. adapted parameters."""
        if params is None:
            return network(states)
        return torch.func.functional_call(network, params, (states,))
    
    def compute_loss(self, trajectories, network, params=None):
        """Compute the actor-critic (policy gradient + value) loss, optionally under `params`."""
        total_loss = 0
        total_samples = 0
        
        for trajectory in trajectories:
            states = torch.as_tensor(np.array(trajectory['states']), dtype=torch.float32)
            actions = torch.as_tensor(trajectory['actions'], dtype=torch.long)
            values = torch.as_tensor(trajectory['values'], dtype=torch.float32)
            
            # Discounted returns and advantages
            returns = torch.as_tensor(self.compute_returns(trajectory['rewards']), dtype=torch.float32)
            advantages = self.compute_advantages(returns, values)
            
            # Normalize advantages (std is undefined for a single sample)
            if len(advantages) > 1:
                advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
            
            # Policy loss
            policy_logits, predicted_values = self._forward(network, states, params)
            log_probs = F.log_softmax(policy_logits, dim=-1).gather(1, actions.unsqueeze(1)).squeeze(1)
            policy_loss = -(log_probs * advantages).mean()
            
            # Value loss
            value_loss = F.mse_loss(predicted_values.squeeze(-1), returns)
            
            # Total loss, weighted by trajectory length
            loss = policy_loss + 0.5 * value_loss
            total_loss = total_loss + loss * len(states)
            total_samples += len(states)
        
        return total_loss / total_samples if total_samples > 0 else torch.tensor(0.0)
    
    def _sampling_network(self, params):
        """Load detached copies of `params` into self.inner_network for rollouts.
        Sampling needs no gradients; the differentiable path goes through compute_loss."""
        self.inner_network.load_state_dict({name: p.detach() for name, p in params.items()})
        return self.inner_network
    
    def inner_loop_update(self, task, params):
        """Adapt `params` to `task` with `inner_steps` SGD steps, keeping the graph
        (create_graph=True) so the meta-gradient can flow back to the initialization."""
        for _ in range(self.inner_steps):
            trajectories = self.collect_trajectories(task, self._sampling_network(params), num_episodes=3)
            loss = self.compute_loss(trajectories, self.meta_network, params)
            gradients = torch.autograd.grad(loss, list(params.values()), create_graph=True)
            params = {name: p - self.inner_lr * g
                      for (name, p), g in zip(params.items(), gradients)}
        return params, loss
    
    def meta_update(self, tasks, num_tasks=4):
        """Perform a meta-update across multiple tasks."""
        tasks = tasks[:num_tasks]
        meta_loss = 0
        
        for task in tasks:
            # Start from the meta-parameters themselves (not a copy), so gradients reach them
            params = dict(self.meta_network.named_parameters())
            
            # Inner loop adaptation
            adapted_params, _ = self.inner_loop_update(task, params)
            
            # Collect test trajectories with the adapted policy, then evaluate the
            # post-adaptation loss through adapted_params (a function of the meta-parameters)
            test_trajectories = self.collect_trajectories(
                task, self._sampling_network(adapted_params), num_episodes=2)
            meta_loss = meta_loss + self.compute_loss(test_trajectories, self.meta_network, adapted_params)
        
        # Meta-gradient update
        meta_loss = meta_loss / len(tasks)
        self.meta_optimizer.zero_grad()
        meta_loss.backward()
        self.meta_optimizer.step()
        
        return meta_loss.item()
    
    def adapt_to_task(self, task, num_adaptation_steps=5):
        """Adapt to a new task by fine-tuning a copy of the meta-initialization with SGD,
        the same update rule used in the inner loop."""
        hidden_dim = self.meta_network.policy[0].out_features  # width used at construction
        adapted_network = PolicyNetwork(self.state_dim, self.action_dim, hidden_dim)
        adapted_network.load_state_dict(self.meta_network.state_dict())
        optimizer = optim.SGD(adapted_network.parameters(), lr=self.inner_lr)
        
        for step in range(num_adaptation_steps):
            # Collect trajectories and compute loss
            trajectories = self.collect_trajectories(task, adapted_network, num_episodes=2)
            loss = self.compute_loss(trajectories, adapted_network)
            
            # Update parameters
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        
        return adapted_network


# Test MAML implementation
print("Testing MAML Implementation...")

# Create task distribution
task_dist = TaskDistribution('pole_length')
state_dim = 4  # CartPole state dimension
action_dim = 2  # CartPole action dimension

# Create MAML agent
maml_agent = MAMLAgent(state_dim, action_dim, meta_lr=0.001, inner_lr=0.01)

# Test meta-update
print("Testing meta-update...")
tasks = task_dist.sample_tasks(4)
meta_loss = maml_agent.meta_update(tasks)
print(f"Meta-loss: {meta_loss:.4f}")

# Test adaptation
print("Testing task adaptation...")
new_task = task_dist.sample_task()
adapted_network = maml_agent.adapt_to_task(new_task, num_adaptation_steps=3)

# Test adapted network
obs, _ = new_task.reset()
state_tensor = torch.FloatTensor(obs).unsqueeze(0)
action, value = adapted_network.get_action(state_tensor)
print(f"Adapted network action: {action.item()}, value: {value.item():.4f}")

print("MAML test completed!")


## 3. RL² (Recurrent Meta-RL) Implementation

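RL² encodes task information implicitly in a recurrent network. At each step the LSTM receives the current state together with the previous action, the previous reward, and a termination flag. Its hidden state is carried across the episodes of a task, so adaptation happens in the hidden-state dynamics rather than through an explicit inner-loop optimization.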
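A single RL² step can be sketched as follows. The hypothetical helper `build_rl2_input` mirrors the input layout of `RL2Network` below (state, one-hot previous action, previous reward, done flag). A bare `nn.LSTM` stands in for the full network, and the sizes are illustrative assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical helper: builds one RL² input row
# [state, one-hot previous action, previous reward, previous done flag]
def build_rl2_input(state, prev_action, prev_reward, prev_done, action_dim):
    one_hot = F.one_hot(torch.tensor([prev_action]), num_classes=action_dim).float()
    return torch.cat([
        torch.as_tensor(state, dtype=torch.float32).unsqueeze(0),
        one_hot,
        torch.tensor([[prev_reward]], dtype=torch.float32),
        torch.tensor([[prev_done]], dtype=torch.float32),
    ], dim=-1)

state_dim, action_dim, hidden_dim = 4, 2, 16
lstm = nn.LSTM(state_dim + action_dim + 2, hidden_dim, batch_first=True)

x = build_rl2_input([0.0, 0.1, -0.02, 0.0], prev_action=1, prev_reward=1.0,
                    prev_done=0.0, action_dim=action_dim)
hidden = None                              # nn.LSTM treats None as a zero initial state
out, hidden = lstm(x.unsqueeze(1), hidden) # input shape: (batch=1, seq=1, features)
# `hidden` is fed back at the next step, including across episode boundaries of the same task
print(x.shape, out.shape)
```

The key design choice is that the hidden state is reset only when a new task is sampled, not at every episode boundary. This is what lets the recurrent policy carry information about the current task from one episode into the next.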


In [None]:
class RL2Network(nn.Module):
    """
    RL² network with LSTM for task encoding.
    """
    
    def __init__(self, state_dim, action_dim, hidden_dim=128, lstm_layers=2):
        super().__init__()
        
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.hidden_dim = hidden_dim
        self.lstm_layers = lstm_layers
        
        # Input: state + previous_action + previous_reward + done_flag
        input_dim = state_dim + action_dim + 1 + 1
        
        # LSTM for task encoding
        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=lstm_layers,
            batch_first=True
        )
        
        # Policy head
        self.policy = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim)
        )
        
        # Value head
        self.value = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
    
    def forward(self, state, prev_action, prev_reward, done, hidden):
        """
        Forward pass through RL² network.
        
        Args:
            state: Current state
            prev_action: Previous action (one-hot)
            prev_reward: Previous reward
            done: Episode termination flag
            hidden: LSTM hidden state
        """
        # Concatenate inputs
        x = torch.cat([state, prev_action, prev_reward.unsqueeze(-1), done.unsqueeze(-1)], dim=-1)
        
        # LSTM forward pass
        output, hidden_new = self.lstm(x.unsqueeze(1), hidden)
        
        # Get policy and value
        policy_logits = self.policy(output.squeeze(1))
        value = self.value(output.squeeze(1))
        
        return policy_logits, value, hidden_new
    
    def reset_hidden(self, batch_size=1):
        """Reset LSTM hidden state."""
        h0 = torch.zeros(self.lstm_layers, batch_size, self.hidden_dim)
        c0 = torch.zeros(self.lstm_layers, batch_size, self.hidden_dim)
        return (h0, c0)
    
    def get_action(self, state, prev_action, prev_reward, done, hidden, deterministic=False):
        """Get action from policy."""
        policy_logits, value, hidden_new = self.forward(state, prev_action, prev_reward, done, hidden)
        
        if deterministic:
            action = torch.argmax(policy_logits, dim=-1)
        else:
            action_probs = F.softmax(policy_logits, dim=-1)
            action = torch.multinomial(action_probs, 1).squeeze(-1)
        
        return action, value, hidden_new
    
    def log_prob(self, state, prev_action, prev_reward, done, action, hidden):
        """Get log probability of action."""
        policy_logits, _, _ = self.forward(state, prev_action, prev_reward, done, hidden)
        log_probs = F.log_softmax(policy_logits, dim=-1)
        return log_probs.gather(1, action.unsqueeze(1)).squeeze(1)


class RL2Agent:
    """
    RL² agent for recurrent meta-learning.
    """
    
    def __init__(self, state_dim, action_dim, lr=3e-4, hidden_dim=128, lstm_layers=2):
        self.state_dim = state_dim
        self.action_dim = action_dim
        
        # RL² network
        self.network = RL2Network(state_dim, action_dim, hidden_dim, lstm_layers)
        self.optimizer = optim.Adam(self.network.parameters(), lr=lr)
        
    def collect_episode(self, env, hidden, max_steps=200):
        """Collect a single episode."""
        obs, _ = env.reset()
        
        episode_data = {
            'states': [],
            'actions': [],
            'rewards': [],
            'values': [],
            'log_probs': [],
            'prev_actions': [],
            'prev_rewards': [],
            'dones': []
        }
        
        prev_action = torch.zeros(self.action_dim)
        prev_reward = torch.tensor(0.0)
        done = torch.tensor(0.0)
        
        for step in range(max_steps):
            state_tensor = torch.FloatTensor(obs).unsqueeze(0)
            
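            # Single forward pass from the pre-step hidden state: sample the action and
            # score it with the same logits (calling log_prob after get_action would
            # condition on the already-updated hidden state)
            with torch.no_grad():
                policy_logits, value, hidden = self.network(
                    state_tensor, prev_action.unsqueeze(0), prev_reward.unsqueeze(0),
                    done.unsqueeze(0), hidden
                )
                dist = torch.distributions.Categorical(logits=policy_logits)
                action = dist.sample()
                log_prob = dist.log_prob(action)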
            
            # Execute action
            next_obs, reward, terminated, truncated, _ = env.step(action.item())
            episode_done = terminated or truncated
            
            # Store data
            episode_data['states'].append(obs)
            episode_data['actions'].append(action.item())
            episode_data['rewards'].append(reward)
            episode_data['values'].append(value.item())
            episode_data['log_probs'].append(log_prob.item())
            episode_data['prev_actions'].append(prev_action.clone())
            episode_data['prev_rewards'].append(prev_reward.item())
            episode_data['dones'].append(done.item())
            
            # Update for next step
            prev_action = F.one_hot(action, num_classes=self.action_dim).float().squeeze(0)
            prev_reward = torch.tensor(reward, dtype=torch.float32)  # match the network's float32 dtype
            done = torch.tensor(1.0 if episode_done else 0.0)
            obs = next_obs
            
            if episode_done:
                break
        
        return episode_data, hidden
    
    def collect_trajectories(self, env, num_episodes=10, max_steps=200):
        """Collect multiple episodes from the same task."""
        trajectories = []
        hidden = self.network.reset_hidden()
        
        for episode in range(num_episodes):
            episode_data, hidden = self.collect_episode(env, hidden, max_steps)
            trajectories.append(episode_data)
            
            # RL² keeps the hidden state across episodes of the same task; uncommenting
            # the line below removes that cross-episode memory
            # hidden = self.network.reset_hidden()
        
        return trajectories
    
    def compute_returns(self, rewards, gamma=0.99):
        """Compute discounted returns."""
        returns = []
        G = 0
        
        for reward in reversed(rewards):
            G = reward + gamma * G
            returns.insert(0, G)
        
        return returns
    
    def compute_advantages(self, returns, values):
        """Compute advantages."""
        advantages = []
        for ret, val in zip(returns, values):
            advantages.append(ret - val)
        return advantages
    
    def compute_loss(self, trajectories):
        """Compute an advantage actor-critic loss (policy gradient + value regression) for RL²."""
        total_loss = 0
        total_samples = 0
        
        # Replay the trial with a single hidden state (batch size 1) carried across
        # episodes, matching how collect_trajectories generated the data
        hidden = self.network.reset_hidden(batch_size=1)
        
        for trajectory in trajectories:
            states = torch.as_tensor(np.array(trajectory['states']), dtype=torch.float32)
            actions = torch.as_tensor(trajectory['actions'], dtype=torch.long)
            values = torch.as_tensor(trajectory['values'], dtype=torch.float32)
            prev_actions = torch.stack(trajectory['prev_actions'])
            prev_rewards = torch.as_tensor(trajectory['prev_rewards'], dtype=torch.float32)
            dones = torch.as_tensor(trajectory['dones'], dtype=torch.float32)
            
            # Compute returns
            returns = torch.as_tensor(self.compute_returns(trajectory['rewards']), dtype=torch.float32)
            
            # Compute advantages from the stored (detached) value estimates
            advantages = returns - values
            
            # Normalize advantages (std is undefined for a single sample)
            if len(advantages) > 1:
                advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
            
            # One forward pass per step: log-probability of the taken action and value
            log_probs = []
            predicted_values = []
            
            for i in range(len(states)):
                policy_logits, value, hidden = self.network(
                    states[i:i+1], prev_actions[i:i+1], prev_rewards[i:i+1], dones[i:i+1], hidden
                )
                log_prob = F.log_softmax(policy_logits, dim=-1).gather(1, actions[i:i+1].unsqueeze(1)).squeeze(1)
                log_probs.append(log_prob)
                predicted_values.append(value.squeeze(-1))
            
            log_probs = torch.cat(log_probs)
            predicted_values = torch.cat(predicted_values)
            
            # Policy loss
            policy_loss = -(log_probs * advantages).mean()
            
            # Value loss
            value_loss = F.mse_loss(predicted_values, returns)
            
            # Total loss
            loss = policy_loss + 0.5 * value_loss
            total_loss = total_loss + loss * len(states)
            total_samples += len(states)
        
        return total_loss / total_samples if total_samples > 0 else torch.tensor(0.0)
    
    def update(self, trajectories):
        """Update RL² network."""
        loss = self.compute_loss(trajectories)
        
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        
        return loss.item()
    
    def evaluate_task(self, env, num_episodes=5, max_steps=200):
        """Evaluate performance on a task."""
        total_rewards = []
        
        # One hidden state for the whole trial, as in collect_trajectories, so the
        # agent can carry what it inferred about the task into later episodes
        hidden = self.network.reset_hidden()
        
        for episode in range(num_episodes):
            obs, _ = env.reset()
            episode_reward = 0
            
            prev_action = torch.zeros(self.action_dim)
            prev_reward = torch.tensor(0.0)
            done = torch.tensor(0.0)
            
            for step in range(max_steps):
                state_tensor = torch.FloatTensor(obs).unsqueeze(0)
                
                # Get action (no gradients needed during evaluation)
                with torch.no_grad():
                    action, value, hidden = self.network.get_action(
                        state_tensor, prev_action.unsqueeze(0), prev_reward.unsqueeze(0),
                        done.unsqueeze(0), hidden, deterministic=True
                    )
                
                # Execute action
                next_obs, reward, terminated, truncated, _ = env.step(action.item())
                episode_done = terminated or truncated
                
                episode_reward += reward
                
                # Update for next step (float32 reward to match the network's dtype)
                prev_action = F.one_hot(action, num_classes=self.action_dim).float().squeeze(0)
                prev_reward = torch.tensor(reward, dtype=torch.float32)
                done = torch.tensor(1.0 if episode_done else 0.0)
                obs = next_obs
                
                if episode_done:
                    break
            
            total_rewards.append(episode_reward)
        
        return np.mean(total_rewards), np.std(total_rewards)


# Test RL² implementation
print("Testing RL² Implementation...")

# Create task distribution
task_dist = TaskDistribution('pole_length')
state_dim = 4
action_dim = 2

# Create RL² agent
rl2_agent = RL2Agent(state_dim, action_dim, lr=3e-4)

# Test on a single task
print("Testing RL² on single task...")
task = task_dist.sample_task()

# Collect trajectories
trajectories = rl2_agent.collect_trajectories(task, num_episodes=3)
print(f"Collected {len(trajectories)} episodes")

# Update network
loss = rl2_agent.update(trajectories)
print(f"Training loss: {loss:.4f}")

# Evaluate performance
mean_reward, std_reward = rl2_agent.evaluate_task(task, num_episodes=3)
print(f"Evaluation reward: {mean_reward:.2f} ± {std_reward:.2f}")

print("RL² test completed!")


## 4. PEARL (Probabilistic Embeddings) Implementation

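PEARL infers a probabilistic latent task variable from a context of collected transitions (state, action, reward) and conditions both the policy and the Q-function on a sample of it. Adapting to a new task therefore amounts to inferring its context embedding rather than taking gradient steps.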
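In the original PEARL formulation, the posterior over the task variable is a product of per-transition Gaussian factors. This makes the posterior independent of the order in which transitions are seen, and it is regularized toward a standard normal prior by a KL term. The notebook's `PEARLNetwork` below instead mean-pools the outputs of an LSTM, which is order-dependent. The sketch shows the product-of-Gaussians aggregation and the KL term. The names `product_of_gaussians` and `kl_to_standard_normal` are hypothetical, and the idea that `PEARLAgent`'s `beta` weights a KL term of this kind is an assumption.

```python
import torch

def product_of_gaussians(means, variances):
    """Combine per-transition Gaussian factors N(mu_n, var_n) into one Gaussian.

    means, variances: tensors of shape (num_transitions, context_dim).
    The result does not depend on the order of the transitions.
    """
    variances = variances.clamp(min=1e-7)
    precision = (1.0 / variances).sum(dim=0)     # precisions add up
    var = 1.0 / precision
    mean = var * (means / variances).sum(dim=0)  # precision-weighted mean
    return mean, var

def kl_to_standard_normal(mean, var):
    """KL( N(mean, diag(var)) || N(0, I) ), summed over latent dimensions."""
    return 0.5 * (var + mean ** 2 - 1.0 - torch.log(var)).sum()

# Two transitions, 2-D latent: factor means differ in the first dimension only
means = torch.tensor([[1.0, 0.0], [3.0, 0.0]])
variances = torch.tensor([[1.0, 1.0], [1.0, 1.0]])

mean, var = product_of_gaussians(means, variances)   # mean -> [2., 0.], var -> [0.5, 0.5]
z = mean + var.sqrt() * torch.randn_like(var)        # reparameterized sample of the task variable
kl = kl_to_standard_normal(mean, var)
print(mean, var, kl.item())
```

Each additional transition adds precision, so the posterior variance shrinks as more context is collected. This gives PEARL a principled way to represent uncertainty about the task early in an episode.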


In [None]:
class PEARLNetwork(nn.Module):
    """
    PEARL network with probabilistic context encoding.
    """
    
    def __init__(self, state_dim, action_dim, context_dim=5, hidden_dim=128):
        super().__init__()
        
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.context_dim = context_dim
        
        # Context encoder (variational)
        self.context_encoder = nn.LSTM(
            input_size=state_dim + action_dim + 1,  # state + action + reward
            hidden_size=hidden_dim,
            batch_first=True
        )
        
        self.context_mean = nn.Linear(hidden_dim, context_dim)
        self.context_logstd = nn.Linear(hidden_dim, context_dim)
        
        # Policy conditioned on context
        self.policy = nn.Sequential(
            nn.Linear(state_dim + context_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim)
        )
        
        # Q-function conditioned on context
        self.q_function = nn.Sequential(
            nn.Linear(state_dim + action_dim + context_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
    
    def encode_context(self, transitions):
        """Encode task from transitions."""
        states, actions, rewards = transitions
        
        # Prepare inputs
        inputs = torch.cat([
            states, 
            F.one_hot(actions, num_classes=self.action_dim).float(),
            rewards.unsqueeze(-1)
        ], dim=-1)
        
        # LSTM encoding
        output, _ = self.context_encoder(inputs.unsqueeze(0))
        pooled = output.mean(dim=1)  # Aggregate over transitions
        
        # Variational encoding
        mean = self.context_mean(pooled)
        logstd = self.context_logstd(pooled)
        
        return mean, logstd
    
    def sample_context(self, mean, logstd):
        """Sample context vector."""
        std = torch.exp(logstd)
        return mean + std * torch.randn_like(std)
    
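    def forward(self, state, context):
        """Forward pass conditioned on context; a single context row is broadcast over the batch."""
        context = context.expand(state.shape[0], -1)  # (1, context_dim) -> (batch, context_dim)
        x = torch.cat([state, context], dim=-1)
        return self.policy(x)
    
    def get_q_value(self, state, action, context):
        """Get Q-value conditioned on context; a single context row is broadcast over the batch."""
        context = context.expand(state.shape[0], -1)  # (1, context_dim) -> (batch, context_dim)
        action_one_hot = F.one_hot(action, num_classes=self.action_dim).float()
        x = torch.cat([state, action_one_hot, context], dim=-1)
        return self.q_function(x)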


class PEARLAgent:
    """
    PEARL agent for context-based meta-learning.
    """
    
    def __init__(self, state_dim, action_dim, context_dim=5, lr=3e-4, beta=1.0):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.context_dim = context_dim
        self.beta = beta
        
        # PEARL network
        self.network = PEARLNetwork(state_dim, action_dim, context_dim)
        self.optimizer = optim.Adam(self.network.parameters(), lr=lr)
        
        # Experience buffer
        self.buffer = []
        
    def collect_transitions(self, env, context, num_transitions=20):
        """Collect transitions for context encoding."""
        transitions = {
            'states': [],
            'actions': [],
            'rewards': []
        }
        
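        obs, _ = env.reset()
        
        # No task evidence yet: sample z from the N(0, I) prior, as PEARL does before any context is collected
        if context is None:
            context = torch.randn(1, self.context_dim)
        
        for _ in range(num_transitions):
            state_tensor = torch.FloatTensor(obs).unsqueeze(0)
            
            # Sample an action from the context-conditioned policy (no gradients needed for data collection)
            with torch.no_grad():
                action_logits = self.network(state_tensor, context)
                action_probs = F.softmax(action_logits, dim=-1)
                action = torch.multinomial(action_probs, 1).item()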
            
            # Execute action
            next_obs, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            
            # Store transition
            transitions['states'].append(obs)
            transitions['actions'].append(action)
            transitions['rewards'].append(reward)
            
            obs = next_obs
            
            if done:
                obs, _ = env.reset()
        
        return transitions
    
    def infer_context(self, transitions):
        """Infer task context from transitions."""
        states = torch.FloatTensor(np.array(transitions['states']))  # stack first; a tensor from a list of arrays is slow
        actions = torch.LongTensor(transitions['actions'])
        rewards = torch.FloatTensor(transitions['rewards'])
        
        mean, logstd = self.network.encode_context((states, actions, rewards))
        context = self.network.sample_context(mean, logstd)
        
        return context, mean, logstd
    
    def compute_kl_loss(self, mean, logstd):
        """KL( N(mean, exp(logstd)^2) || N(0, I) ); the variance is exp(2 * logstd)."""
        kl_loss = -0.5 * torch.sum(1 + 2 * logstd - mean.pow(2) - torch.exp(2 * logstd))
        return kl_loss
    
    def update(self, batch_size=32):
        """Update PEARL network."""
        if len(self.buffer) < batch_size:
            return
        
        # Sample batch
        batch = random.sample(self.buffer, batch_size)
        
        total_loss = 0
        
        for item in batch:
            # Re-encode the context with the current encoder so gradients flow through a fresh graph
            # (tensors cached at collection time point to graphs that are freed or stale after an update)
            context, mean, logstd = self.infer_context(item)
            
            states = torch.FloatTensor(np.array(item['states']))
            actions = torch.LongTensor(item['actions'])
            rewards = torch.FloatTensor(item['rewards'])
            
            # Policy loss (simplified: maximizes the log-likelihood of the actions taken)
            action_logits = self.network(states, context)
            log_probs = F.log_softmax(action_logits, dim=-1)
            policy_loss = -log_probs.gather(1, actions.unsqueeze(1)).mean()
            
            # Q-learning loss (simplified: regresses Q onto the immediate reward, no bootstrapping)
            q_values = self.network.get_q_value(states, actions, context)
            q_targets = rewards.unsqueeze(-1)
            q_loss = F.mse_loss(q_values, q_targets)
            
            # KL loss: pull the posterior q(z|c) toward the N(0, I) prior
            kl_loss = self.compute_kl_loss(mean, logstd)
            
            # Total loss
            loss = policy_loss + q_loss + self.beta * kl_loss
            total_loss += loss
        
        # Update
        total_loss = total_loss / batch_size
        self.optimizer.zero_grad()
        total_loss.backward()
        self.optimizer.step()
        
        return total_loss.item()
    
    def adapt_to_task(self, env, num_context_transitions=10):
        """Adapt to new task using context inference."""
        # Collect context transitions
        context_transitions = self.collect_transitions(env, None, num_context_transitions)
        
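        # Infer context without building a graph: graphs cached in the buffer would go stale after the next optimizer step
        with torch.no_grad():
            context, mean, logstd = self.infer_context(context_transitions)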
        
        # Store in buffer
        self.buffer.append({
            'states': context_transitions['states'],
            'actions': context_transitions['actions'],
            'rewards': context_transitions['rewards'],
            'context': context,
            'mean': mean,
            'logstd': logstd
        })
        
        return context


# Test PEARL implementation
print("Testing PEARL Implementation...")

# Create task distribution
task_dist = TaskDistribution('pole_length')
state_dim = 4
action_dim = 2

# Create PEARL agent
pearl_agent = PEARLAgent(state_dim, action_dim, context_dim=5)

# Test context inference
print("Testing context inference...")
task = task_dist.sample_task()
context = pearl_agent.adapt_to_task(task, num_context_transitions=5)
print(f"Inferred context shape: {context.shape}")

# Test action selection
obs, _ = task.reset()
state_tensor = torch.FloatTensor(obs).unsqueeze(0)
action_logits = pearl_agent.network(state_tensor, context)
action = torch.argmax(action_logits, dim=-1)
print(f"Action with context: {action.item()}")

print("PEARL test completed!")
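
`sample_context` uses the reparameterization trick (z = μ + σ·ε with ε ~ N(0, I)), and `compute_kl_loss` penalizes the divergence between the inferred posterior and the N(0, I) prior. Because the encoder outputs a log *standard deviation*, the variance in the closed-form KL is exp(2·logstd), not exp(logstd). The NumPy sketch below checks that closed form against a Monte Carlo estimate built from reparameterized samples; the mean and logstd values are arbitrary, and `kl_to_standard_normal` and `log_normal_pdf` are illustrative helpers, not part of the notebook.

```python
import numpy as np

def kl_to_standard_normal(mean, logstd):
    """Closed-form KL( N(mean, exp(logstd)^2) || N(0, I) ), summed over dimensions."""
    mean, logstd = np.asarray(mean, float), np.asarray(logstd, float)
    # logstd is log(sigma), so the variance is exp(2 * logstd)
    return -0.5 * np.sum(1.0 + 2.0 * logstd - mean**2 - np.exp(2.0 * logstd))


def log_normal_pdf(x, mean, std):
    """Log-density of a diagonal Gaussian, summed over the last axis."""
    return np.sum(-0.5 * ((x - mean) / std) ** 2 - np.log(std) - 0.5 * np.log(2 * np.pi), axis=-1)


rng = np.random.default_rng(0)
mean = np.array([0.5, -1.0])
logstd = np.array([-0.3, 0.2])
std = np.exp(logstd)

# Reparameterization trick: z = mean + std * eps with eps ~ N(0, I)
eps = rng.standard_normal((200_000, 2))
z = mean + std * eps

# Monte Carlo estimate of KL = E_q[log q(z) - log p(z)]
kl_mc = np.mean(log_normal_pdf(z, mean, std) - log_normal_pdf(z, 0.0, 1.0))
kl_exact = kl_to_standard_normal(mean, logstd)
print(f"closed form: {kl_exact:.4f}, Monte Carlo: {kl_mc:.4f}")
```

Using exp(logstd) as the variance instead would understate σ² whenever σ > 1 and overstate it when σ < 1, so the KL term would no longer pull the posterior toward the prior correctly.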


## 5. Training and Evaluation Functions

Let's implement training and evaluation functions for all meta-learning methods.
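`evaluate_few_shot_adaptation` averages rewards in two stages: first over evaluation episodes within each test task, then over test tasks. The reported ± value is therefore the standard deviation *across tasks* (population std, NumPy's default `ddof=0`), not across individual episodes. A small sketch of that aggregation with made-up episode rewards; `aggregate_few_shot` is an illustrative helper, not a function in the notebook:

```python
import numpy as np

def aggregate_few_shot(episode_rewards_per_task):
    """Mirror the two-stage averaging used in evaluate_few_shot_adaptation.

    episode_rewards_per_task: list with one list of episode rewards per test task.
    """
    task_means = [np.mean(r) for r in episode_rewards_per_task]  # stage 1: average episodes within a task
    return {
        'mean_reward': np.mean(task_means),  # stage 2: average across tasks
        'std_reward': np.std(task_means),    # spread across tasks (ddof=0)
        'rewards': task_means,
    }


# Made-up rewards: 2 test tasks, 2 evaluation episodes each
summary = aggregate_few_shot([[10.0, 20.0], [30.0, 50.0]])
print(summary['mean_reward'], summary['std_reward'])
```

With only 5 to 10 test tasks these error bars are wide; the standard error of the mean, `std_reward / sqrt(num_test_tasks)`, is a more direct indication of whether two methods' averages actually differ.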


In [None]:
def train_maml(task_dist, agent, num_meta_iterations=100, tasks_per_iteration=4):
    """Train MAML agent."""
    meta_losses = []
    
    for iteration in trange(num_meta_iterations, desc="Training MAML"):
        # Sample tasks
        tasks = task_dist.sample_tasks(tasks_per_iteration)
        
        # Meta-update
        meta_loss = agent.meta_update(tasks)
        meta_losses.append(meta_loss)
    
    return meta_losses


def train_rl2(task_dist, agent, num_meta_iterations=100, episodes_per_task=5):
    """Train RL² agent."""
    losses = []
    
    for iteration in trange(num_meta_iterations, desc="Training RL²"):
        # Sample task
        task = task_dist.sample_task()
        
        # Collect trajectories
        trajectories = agent.collect_trajectories(task, num_episodes=episodes_per_task)
        
        # Update network
        loss = agent.update(trajectories)
        losses.append(loss)
    
    return losses


def train_pearl(task_dist, agent, num_meta_iterations=100, context_transitions=10):
    """Train PEARL agent."""
    losses = []
    
    for iteration in trange(num_meta_iterations, desc="Training PEARL"):
        # Sample task
        task = task_dist.sample_task()
        
        # Adapt to task (this also appends the collected transitions to agent.buffer)
        agent.adapt_to_task(task, num_context_transitions=context_transitions)
        
        # Update network
        if len(agent.buffer) >= 32:
            loss = agent.update()
            losses.append(loss)
    
    return losses


def evaluate_few_shot_adaptation(task_dist, agents, num_test_tasks=10, 
                                 adaptation_steps=5, eval_episodes=5):
    """Evaluate few-shot adaptation performance."""
    results = {}
    
    for method_name, agent in agents.items():
        print(f"Evaluating {method_name}...")
        
        task_rewards = []
        
        for task_idx in range(num_test_tasks):
            task = task_dist.sample_task()
            
            if method_name == 'MAML':
                # MAML adaptation
                adapted_network = agent.adapt_to_task(task, adaptation_steps)
                
                # Evaluate adapted network
                episode_rewards = []
                for episode in range(eval_episodes):
                    obs, _ = task.reset()
                    episode_reward = 0
                    
                    for step in range(200):
                        state_tensor = torch.FloatTensor(obs).unsqueeze(0)
                        action, _ = adapted_network.get_action(state_tensor, deterministic=True)
                        
                        next_obs, reward, terminated, truncated, _ = task.step(action.item())
                        episode_reward += reward
                        
                        obs = next_obs
                        if terminated or truncated:
                            break
                    
                    episode_rewards.append(episode_reward)
                
                task_rewards.append(np.mean(episode_rewards))
            
            elif method_name == 'RL²':
                # RL² evaluation (no explicit adaptation)
                mean_reward, _ = agent.evaluate_task(task, num_episodes=eval_episodes)
                task_rewards.append(mean_reward)
            
            elif method_name == 'PEARL':
                # PEARL adaptation
                context = agent.adapt_to_task(task, num_context_transitions=adaptation_steps)
                
                # Evaluate with inferred context
                episode_rewards = []
                for episode in range(eval_episodes):
                    obs, _ = task.reset()
                    episode_reward = 0
                    
                    for step in range(200):
                        state_tensor = torch.FloatTensor(obs).unsqueeze(0)
                        action_logits = agent.network(state_tensor, context)
                        action = torch.argmax(action_logits, dim=-1)
                        
                        next_obs, reward, terminated, truncated, _ = task.step(action.item())
                        episode_reward += reward
                        
                        obs = next_obs
                        if terminated or truncated:
                            break
                    
                    episode_rewards.append(episode_reward)
                
                task_rewards.append(np.mean(episode_rewards))
        
        results[method_name] = {
            'mean_reward': np.mean(task_rewards),
            'std_reward': np.std(task_rewards),
            'rewards': task_rewards
        }
    
    return results


def plot_meta_learning_results(results, title="Meta-Learning Results"):
    """Plot meta-learning results."""
    fig, axes = plt.subplots(1, 2, figsize=(15, 5))
    
    # Plot mean rewards
    ax1 = axes[0]
    methods = list(results.keys())
    mean_rewards = [results[method]['mean_reward'] for method in methods]
    std_rewards = [results[method]['std_reward'] for method in methods]
    
    bars = ax1.bar(methods, mean_rewards, yerr=std_rewards, capsize=5)
    ax1.set_title('Few-Shot Adaptation Performance')
    ax1.set_ylabel('Mean Reward')
    ax1.tick_params(axis='x', rotation=45)
    
    # Add value labels on bars
    for bar, mean, std in zip(bars, mean_rewards, std_rewards):
        height = bar.get_height()
        ax1.text(bar.get_x() + bar.get_width()/2., height + std + 1,
                f'{mean:.2f}±{std:.2f}', ha='center', va='bottom')
    
    # Plot reward distributions
    ax2 = axes[1]
    for method in methods:
        rewards = results[method]['rewards']
        ax2.hist(rewards, alpha=0.7, label=method, bins=10)
    
    ax2.set_title('Reward Distributions')
    ax2.set_xlabel('Reward')
    ax2.set_ylabel('Frequency')
    ax2.legend()
    
    fig.suptitle(title)
    plt.tight_layout()
    plt.show()


# Test training functions
print("Testing Training Functions...")

# Create task distribution
task_dist = TaskDistribution('pole_length')
state_dim = 4
action_dim = 2

# Create agents
agents = {
    'MAML': MAMLAgent(state_dim, action_dim),
    'RL²': RL2Agent(state_dim, action_dim),
    'PEARL': PEARLAgent(state_dim, action_dim)
}

# Test few-shot adaptation
print("Testing few-shot adaptation...")
results = evaluate_few_shot_adaptation(task_dist, agents, num_test_tasks=5, 
                                     adaptation_steps=3, eval_episodes=3)

print("\nFew-shot adaptation results:")
for method, result in results.items():
    print(f"{method}: {result['mean_reward']:.2f} ± {result['std_reward']:.2f}")

print("Training functions test completed!")


## 6. Experiments and Analysis

Let's run comprehensive experiments comparing different meta-learning methods on various task distributions.


In [None]:
# @title Run Meta-Learning Experiments

# Set experiment parameters
NUM_META_ITERATIONS = 50  # Reduced for faster execution
NUM_TEST_TASKS = 10
ADAPTATION_STEPS = 3
EVAL_EPISODES = 5

print("Starting Meta-Learning Experiments...")
print(f"Meta-iterations: {NUM_META_ITERATIONS}")
print(f"Test tasks: {NUM_TEST_TASKS}")
print(f"Adaptation steps: {ADAPTATION_STEPS}")
print(f"Evaluation episodes: {EVAL_EPISODES}")
print()

# Test different task distributions
task_types = ['pole_length', 'mass', 'gravity', 'mixed']
all_results = {}

for task_type in task_types:
    print(f"\n{'='*60}")
    print(f"EXPERIMENT: {task_type.upper()} TASKS")
    print(f"{'='*60}")
    
    # Create task distribution
    task_dist = TaskDistribution(task_type)
    state_dim = 4
    action_dim = 2
    
    # Create agents
    agents = {
        'MAML': MAMLAgent(state_dim, action_dim, meta_lr=0.001, inner_lr=0.01),
        'RL²': RL2Agent(state_dim, action_dim, lr=3e-4),
        'PEARL': PEARLAgent(state_dim, action_dim, context_dim=5, lr=3e-4)
    }
    
    # Train agents
    print("Training agents...")
    
    # Train MAML
    print("Training MAML...")
    maml_losses = train_maml(task_dist, agents['MAML'], 
                           num_meta_iterations=NUM_META_ITERATIONS//2, 
                           tasks_per_iteration=4)
    
    # Train RL²
    print("Training RL²...")
    rl2_losses = train_rl2(task_dist, agents['RL²'], 
                          num_meta_iterations=NUM_META_ITERATIONS, 
                          episodes_per_task=3)
    
    # Train PEARL
    print("Training PEARL...")
    pearl_losses = train_pearl(task_dist, agents['PEARL'], 
                              num_meta_iterations=NUM_META_ITERATIONS, 
                              context_transitions=5)
    
    # Evaluate few-shot adaptation
    print("Evaluating few-shot adaptation...")
    results = evaluate_few_shot_adaptation(task_dist, agents, 
                                         num_test_tasks=NUM_TEST_TASKS,
                                         adaptation_steps=ADAPTATION_STEPS,
                                         eval_episodes=EVAL_EPISODES)
    
    all_results[task_type] = results
    
    # Print results
    print(f"\nResults for {task_type} tasks:")
    for method, result in results.items():
        print(f"  {method}: {result['mean_reward']:.2f} ± {result['std_reward']:.2f}")

print("\n" + "="*60)
print("EXPERIMENTAL RESULTS SUMMARY")
print("="*60)

# Print summary table
print(f"\n{'Task Type':<12} | {'MAML':>6} | {'RL²':>6} | {'PEARL':>6}")
print("-" * 39)
for task_type, results in all_results.items():
    maml_reward = results['MAML']['mean_reward']
    rl2_reward = results['RL²']['mean_reward']
    pearl_reward = results['PEARL']['mean_reward']
    print(f"{task_type:<12} | {maml_reward:6.1f} | {rl2_reward:6.1f} | {pearl_reward:6.1f}")

print("\nAll experiments completed!")


In [None]:
# @title Plot Results and Analysis

# Plot results for each task type
for task_type, results in all_results.items():
    print(f"\nPlotting results for {task_type} tasks...")
    plot_meta_learning_results(results, f"Meta-Learning Results - {task_type.title()}")

# Overall analysis
print("\n" + "="*60)
print("ANALYSIS AND INSIGHTS")
print("="*60)

# Find best performing method for each task type
print("\nBest performing method by task type:")
for task_type, results in all_results.items():
    best_method = max(results.keys(), key=lambda k: results[k]['mean_reward'])
    best_reward = results[best_method]['mean_reward']
    print(f"  {task_type}: {best_method} ({best_reward:.2f})")

# Overall best method: highest mean reward averaged across all task types
method_names = list(next(iter(all_results.values())).keys())
avg_scores = {
    method: np.mean([results[method]['mean_reward'] for results in all_results.values()])
    for method in method_names
}
overall_best = max(avg_scores, key=avg_scores.get)
overall_best_score = avg_scores[overall_best]

print(f"\nOverall best method (averaged over task types): {overall_best} ({overall_best_score:.2f})")

# Method comparison
print("\nMethod Characteristics:")
print("- MAML: Gradient-based meta-learning with explicit inner/outer loops")
print("  - Pros: Theoretically grounded, good for similar tasks")
print("  - Cons: Computationally expensive, requires second-order derivatives")
print()
print("- RL²: Recurrent meta-learning with implicit adaptation")
print("  - Pros: Fast adaptation, no explicit inner loop needed")
print("  - Cons: Black-box adaptation, requires many episodes per task")
print()
print("- PEARL: Context-based meta-learning with probabilistic embeddings")
print("  - Pros: Fast adaptation, off-policy meta-training, explicit uncertainty over the task")
print("  - Cons: Requires context inference, sensitive to context quality")

print("\nKey Insights:")
print("1. Meta-learning enables rapid adaptation to new tasks")
print("2. Different methods excel in different scenarios:")
print("   - MAML: Good for gradient-based adaptation")
print("   - RL²: Good for sequential task learning")
print("   - PEARL: Good for context-based adaptation")
print("3. Performance can vary with the task distribution (compare the summary table above)")
print("4. Few-shot adaptation is achievable with proper meta-training")

print("\nMeta-learning provides a powerful framework for 'learning to learn'!")


## 7. Analysis and Discussion Questions

### Key Concepts Demonstrated

1. **Meta-Learning Fundamentals**: Learning to learn across task distributions
2. **MAML**: Gradient-based meta-learning with two-level optimization
3. **RL²**: Recurrent meta-learning with implicit task adaptation
4. **PEARL**: Context-based meta-learning with probabilistic embeddings
5. **Few-Shot Adaptation**: Rapid learning on new tasks with minimal samples
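
For reference, the two-level optimization behind MAML and the variational objective behind PEARL are usually written as follows. This assumes one inner gradient step; α and η are the inner and outer step sizes, which presumably correspond to `inner_lr` and `meta_lr` in `MAMLAgent`:

$$
\theta'_i = \theta - \alpha \nabla_\theta \mathcal{L}_{\mathcal{T}_i}(\theta),
\qquad
\theta \leftarrow \theta - \eta \nabla_\theta \sum_{\mathcal{T}_i \sim p(\mathcal{T})} \mathcal{L}_{\mathcal{T}_i}(\theta'_i)
$$

By the chain rule, each meta-gradient term contains a Hessian, which is the second-order cost that first-order variants drop:

$$
\nabla_\theta \mathcal{L}_{\mathcal{T}_i}(\theta'_i) = \big(I - \alpha \nabla^2_\theta \mathcal{L}_{\mathcal{T}_i}(\theta)\big)\, \nabla_{\theta'} \mathcal{L}_{\mathcal{T}_i}(\theta'_i)
$$

PEARL instead trains the context encoder $q_\phi(z \mid c^{\mathcal{T}})$ to minimize

$$
\mathbb{E}_{\mathcal{T}}\Big[\mathbb{E}_{z \sim q_\phi(z \mid c^{\mathcal{T}})}\big[\mathcal{R}(\mathcal{T}, z)\big] + \beta\, D_{\mathrm{KL}}\big(q_\phi(z \mid c^{\mathcal{T}}) \,\|\, p(z)\big)\Big],
\qquad p(z) = \mathcal{N}(0, I),
$$

where $\mathcal{R}$ is a task loss (the critic's Bellman error in the original method). In this notebook's simplified `PEARLAgent.update`, $\mathcal{R}$ is replaced by `policy_loss + q_loss` and $\beta$ is `beta`.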

### Discussion Questions

**Answer the following questions based on your experiments:**

1. **Meta-Learning vs Transfer Learning**: How does MAML differ from standard transfer learning? What are the advantages of meta-learning?

2. **Adaptation Mechanisms**: Compare the adaptation mechanisms of MAML, RL², and PEARL. Which approach is most interpretable?

3. **Task Distribution Effects**: How does the task distribution affect meta-learning performance? Which method is most robust to distribution shifts?

4. **Sample Efficiency**: Which method requires the most samples during meta-training? Which is most sample-efficient during adaptation?

5. **Computational Complexity**: Compare the computational requirements of each method. When would you choose each approach?

### Extensions and Future Work

- **First-Order MAML**: Implement FOMAML to reduce computational cost
- **Meta-World Benchmark**: Apply these methods to the Meta-World benchmark
- **Sim-to-Real Transfer**: Use meta-learning for sim-to-real transfer
- **Multi-Task Learning**: Compare meta-learning with multi-task learning
- **Automatic Hyperparameter Selection**: Learn meta-learning hyperparameters
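
As a starting point for the first-order MAML extension: FOMAML treats the Jacobian of the inner step as the identity, dropping the Hessian term from the meta-gradient. On one-dimensional quadratic task losses, L_i(θ) = ½·a_i·(θ − c_i)², both meta-gradients have closed forms, which makes the difference easy to see. The NumPy sketch below works under those toy assumptions only; the task coefficients `a`, `c`, the step size, and the helper names are made up for illustration and are not the notebook's `MAMLAgent`.

```python
import numpy as np

def inner_step(theta, a, c, alpha):
    """One inner-loop gradient step on L(theta) = 0.5 * a * (theta - c)^2."""
    return theta - alpha * a * (theta - c)


def task_loss(theta, a, c):
    return 0.5 * a * (theta - c) ** 2


def meta_gradient(theta, a, c, alpha, first_order=False):
    """Gradient of sum_i L_i(theta_i') w.r.t. theta, where theta_i' = inner_step(theta)."""
    theta_adapted = inner_step(theta, a, c, alpha)
    outer_grad = a * (theta_adapted - c)   # dL_i/dtheta' at the adapted parameters
    if first_order:
        return np.sum(outer_grad)          # FOMAML: treat d(theta')/d(theta) as 1
    jacobian = 1.0 - alpha * a             # d(theta')/d(theta) = 1 - alpha * L_i''(theta)
    return np.sum(outer_grad * jacobian)   # full MAML: chain rule through the inner step


def meta_objective(theta, a, c, alpha):
    """Post-adaptation loss summed over tasks."""
    return np.sum(task_loss(inner_step(theta, a, c, alpha), a, c))


a = np.array([1.0, 4.0, 9.0])   # per-task curvature (made up)
c = np.array([-1.0, 0.5, 2.0])  # per-task optimum (made up)
alpha, theta = 0.05, 0.0

g_maml = meta_gradient(theta, a, c, alpha)
g_fomaml = meta_gradient(theta, a, c, alpha, first_order=True)

# A central finite difference on the true meta-objective should match the full MAML gradient
h = 1e-5
g_fd = (meta_objective(theta + h, a, c, alpha) - meta_objective(theta - h, a, c, alpha)) / (2 * h)
print(f"MAML: {g_maml:.4f}  FOMAML: {g_fomaml:.4f}  finite diff: {g_fd:.4f}")
```

With these made-up tasks the first-order gradient is almost twice as large in magnitude, because it ignores the factor 1 - α·a_i that damps the high-curvature tasks; as α shrinks, that factor approaches 1 and the two gradients coincide.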

### Conclusion

Meta-learning in reinforcement learning provides a powerful framework for achieving rapid adaptation to new tasks. The methods explored in this assignment demonstrate different approaches to the "learning to learn" paradigm:

- **MAML** provides theoretically grounded gradient-based adaptation
- **RL²** offers fast black-box adaptation through recurrent networks  
- **PEARL** adapts by inferring a probabilistic task embedding from a few transitions, with no gradient steps at test time

The key insight is that **meta-learning enables agents to leverage experience from related tasks to quickly adapt to new situations**, making it particularly valuable for applications requiring rapid deployment in new environments or with new objectives.

The choice of meta-learning method depends on the application: computational budget, interpretability needs, task similarity, and required adaptation speed.
