# HW11: Meta-Learning in Reinforcement Learning

> - Full Name: **[Your Full Name]**
> - Student ID: **[Your Student ID]**

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/DeepRLCourse/Homework-11-Questions/blob/main/HW11_Notebook.ipynb)
[![Open In kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://raw.githubusercontent.com/DeepRLCourse/Homework-11-Questions/main/HW11_Notebook.ipynb)

## Overview
This assignment focuses on **Meta-Learning in Reinforcement Learning**, exploring algorithms that enable agents to quickly adapt to new tasks by leveraging experience from related tasks. We'll implement and experiment with:

1. **MAML (Model-Agnostic Meta-Learning)** - Gradient-based meta-learning for RL
2. **RL² (Recurrent Meta-RL)** - Black-box meta-learning using recurrent networks
3. **PEARL (Probabilistic Embeddings)** - Context-based meta-RL with task embeddings
4. **Few-Shot Adaptation** - Rapid learning on new tasks with minimal samples
5. **Task Distributions** - Learning across families of related RL tasks

The goal is to understand how meta-learning enables "learning to learn" and achieves fast adaptation to new tasks.


In [None]:
# @title Imports and Setup

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import gymnasium as gym
from gymnasium import spaces
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import trange
from collections import defaultdict, deque
import random
from typing import Tuple, List, Dict, Optional
import warnings
warnings.filterwarnings('ignore')

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Set plotting style
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")


## 1. Meta-Learning Environment Setup

First, let's create a task distribution for meta-learning. We use a parameterized CartPole environment whose tasks differ in their dynamics (pole length, pole and cart mass, gravity) and in their reward scale, which lets us test few-shot adaptation.
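To see why these parameters define genuinely different tasks, the sketch below evaluates the pole's angular acceleration using the same CartPole equations of motion that `ParameterizedCartPoleEnv.step` uses, at a small tilt with no push and no angular velocity. The helper name `pole_angular_acceleration` and the chosen state are illustrative assumptions.

```python
import numpy as np

def pole_angular_acceleration(theta, theta_dot, force, pole_length,
                              pole_mass=0.1, cart_mass=1.0, gravity=9.8):
    """Pole angular acceleration from the CartPole equations of motion (hypothetical helper)."""
    total_mass = cart_mass + pole_mass
    temp = (force + pole_mass * pole_length * theta_dot**2 * np.sin(theta)) / total_mass
    return (gravity * np.sin(theta) - np.cos(theta) * temp) / \
        (pole_length * (4.0 / 3.0 - pole_mass * np.cos(theta)**2 / total_mass))

# Same small tilt, no push: only the task parameter (pole length) changes
for length in (0.3, 0.5, 0.8):
    acc = pole_angular_acceleration(theta=0.05, theta_dot=0.0, force=0.0, pole_length=length)
    print(f"pole_length={length}: theta_acc={acc:.3f} rad/s^2")
```

With zero force and zero angular velocity the acceleration scales as `1 / pole_length`, so shorter poles tip over faster and a controller tuned for one length over- or under-corrects on another. That variation is what the meta-learners below must adapt to.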


In [None]:
class ParameterizedCartPoleEnv(gym.Env):
    """
    Parameterized CartPole environment for meta-learning.
    Tasks differ in pole length, masses, gravity, and reward scale.
    """
    
    def __init__(self, task_params=None):
        super().__init__()
        
        # Default task parameters
        self.default_params = {
            'pole_length': 0.5,
            'pole_mass': 0.1,
            'cart_mass': 1.0,
            'gravity': 9.8,
            'reward_scale': 1.0,
            'success_threshold': 195.0
        }
        
        # Set task parameters
        self.params = self.default_params.copy()
        if task_params:
            self.params.update(task_params)
        
        # State space: [cart_pos, cart_vel, pole_angle, pole_vel]
        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf, shape=(4,), dtype=np.float32
        )
        
        # Action space: [push_left, push_right]
        self.action_space = spaces.Discrete(2)
        
        # Environment state
        self.state = None
        self.steps = 0
        self.max_steps = 200
        
    def reset(self, **kwargs):
        """Reset environment with random initial state."""
        # Random initial state
        self.state = np.array([
            np.random.uniform(-0.1, 0.1),  # cart position
            np.random.uniform(-0.1, 0.1),  # cart velocity
            np.random.uniform(-0.1, 0.1),  # pole angle
            np.random.uniform(-0.1, 0.1)   # pole velocity
        ], dtype=np.float32)
        
        self.steps = 0
        return self.state.copy(), {}
    
    def step(self, action):
        """Execute action and return next state."""
        if self.state is None:
            raise RuntimeError("Environment not reset")
        
        # Convert action to force
        force = 10.0 if action == 1 else -10.0
        
        # CartPole dynamics with task parameters
        x, x_dot, theta, theta_dot = self.state
        
        # Physical constants
        g = self.params['gravity']
        m_cart = self.params['cart_mass']
        m_pole = self.params['pole_mass']
        l = self.params['pole_length']
        
        # Total mass
        total_mass = m_cart + m_pole
        
        # Dynamics
        temp = (force + m_pole * l * theta_dot**2 * np.sin(theta)) / total_mass
        theta_acc = (g * np.sin(theta) - np.cos(theta) * temp) / \
                   (l * (4/3 - m_pole * np.cos(theta)**2 / total_mass))
        x_acc = temp - m_pole * l * theta_acc * np.cos(theta) / total_mass
        
        # Update state
        x = x + x_dot * 0.02
        x_dot = x_dot + x_acc * 0.02
        theta = theta + theta_dot * 0.02
        theta_dot = theta_dot + theta_acc * 0.02
        
        self.state = np.array([x, x_dot, theta, theta_dot], dtype=np.float32)
        self.steps += 1
        
        # Compute reward
        reward = self._compute_reward()
        
        # Check termination
        terminated = self._is_terminated()
        truncated = self.steps >= self.max_steps
        
        return self.state.copy(), reward, terminated, truncated, {}
    
    def _compute_reward(self):
        """Compute reward based on task parameters."""
        x, x_dot, theta, theta_dot = self.state
        
        # Base reward: stay upright and centered
        reward = 1.0
        
        # Penalty for being far from center
        reward -= abs(x) * 0.1
        
        # Penalty for large angle
        reward -= abs(theta) * 0.1
        
        # Scale reward
        reward *= self.params['reward_scale']
        
        return reward
    
    def _is_terminated(self):
        """Check if episode should terminate."""
        x, x_dot, theta, theta_dot = self.state
        
        # Terminate if cart goes too far
        if abs(x) > 2.4:
            return True
        
        # Terminate if pole falls too far
        if abs(theta) > 0.2095:  # ~12 degrees
            return True
        
        return False
    
    def set_task_params(self, task_params):
        """Set new task parameters."""
        self.params.update(task_params)


class TaskDistribution:
    """
    Distribution over parameterized tasks for meta-learning.
    """
    
    def __init__(self, task_type='pole_length'):
        self.task_type = task_type
        
        if task_type == 'pole_length':
            self.param_ranges = {
                'pole_length': (0.3, 0.8),
                'reward_scale': (0.8, 1.2)
            }
        elif task_type == 'mass':
            self.param_ranges = {
                'pole_mass': (0.05, 0.2),
                'cart_mass': (0.8, 1.5),
                'reward_scale': (0.8, 1.2)
            }
        elif task_type == 'gravity':
            self.param_ranges = {
                'gravity': (8.0, 12.0),
                'reward_scale': (0.8, 1.2)
            }
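        elif task_type == 'mixed':
            self.param_ranges = {
                'pole_length': (0.3, 0.8),
                'pole_mass': (0.05, 0.2),
                'gravity': (8.0, 12.0),
                'reward_scale': (0.8, 1.2)
            }
        else:
            # Fail early instead of with an AttributeError in sample_task
            raise ValueError(f"Unknown task_type: {task_type!r}")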
    
    def sample_task(self):
        """Sample a random task from the distribution."""
        task_params = {}
        for param, (low, high) in self.param_ranges.items():
            task_params[param] = np.random.uniform(low, high)
        
        return ParameterizedCartPoleEnv(task_params)
    
    def sample_tasks(self, n_tasks):
        """Sample multiple tasks."""
        return [self.sample_task() for _ in range(n_tasks)]


# Test the environment
print("Testing Parameterized CartPole Environment...")

# Test different task types
task_types = ['pole_length', 'mass', 'gravity', 'mixed']
for task_type in task_types:
    print(f"\nTesting {task_type} tasks:")
    task_dist = TaskDistribution(task_type)
    
    # Sample a few tasks
    for i in range(3):
        task = task_dist.sample_task()
        obs, _ = task.reset()
        
        # Run a few steps
        total_reward = 0
        for step in range(10):
            action = task.action_space.sample()
            obs, reward, terminated, truncated, _ = task.step(action)
            total_reward += reward
            
            if terminated or truncated:
                break
        
        print(f"  Task {i+1}: Reward = {total_reward:.2f}, Params = {task.params}")

print("\nEnvironment test completed!")


## 2. MAML (Model-Agnostic Meta-Learning) Implementation

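MAML learns an initialization from which a few gradient steps on a new task already produce a good policy. It uses a two-level optimization process:
- **Inner Loop**: Starting from the initialization, take gradient steps on data from one task to obtain task-specific parameters
- **Outer Loop**: Update the initialization so that the *adapted* parameters perform well across tasks, differentiating through the inner-loop steps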
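In symbols (notation assumed here, not taken from the code): for a task $\mathcal{T}_i$ with loss $\mathcal{L}_i$, one inner step with learning rate $\alpha$ (`inner_lr`) gives $\theta_i' = \theta - \alpha \nabla_\theta \mathcal{L}_i(\theta)$, and the outer loop updates $\theta$ to minimize $\frac{1}{N}\sum_i \mathcal{L}_i(\theta_i')$. The minimal sketch below illustrates only this two-level structure, on a hypothetical toy problem rather than RL: a scalar $\theta$ and task losses $\mathcal{L}_i(\theta) = (\theta - c_i)^2$. Then $\theta_i' - c_i = (1 - 2\alpha)(\theta - c_i)$, so the meta-gradient is the mean of $2(1-2\alpha)^2(\theta - c_i)$ and the best initialization is the mean of the $c_i$, which makes the autograd result easy to check. The values of `centers` and `inner_lr` are arbitrary.

```python
import torch

def maml_meta_loss(theta, centers, inner_lr):
    """Post-adaptation loss averaged over toy tasks L_i(theta) = (theta - c_i)^2."""
    meta_loss = 0.0
    for c in centers:
        inner_loss = (theta - c) ** 2
        # Inner loop: one SGD step; create_graph=True keeps it differentiable
        (grad,) = torch.autograd.grad(inner_loss, theta, create_graph=True)
        adapted_theta = theta - inner_lr * grad
        # Outer objective: loss of the adapted parameter on the same task
        meta_loss = meta_loss + (adapted_theta - c) ** 2
    return meta_loss / len(centers)

centers = [1.0, 3.0]   # hypothetical task parameters c_i
inner_lr = 0.1         # alpha

# Check the meta-gradient at theta = 0 against the closed form
theta = torch.tensor(0.0, requires_grad=True)
(meta_grad,) = torch.autograd.grad(maml_meta_loss(theta, centers, inner_lr), theta)
closed_form = sum(2 * (1 - 2 * inner_lr) ** 2 * (0.0 - c) for c in centers) / len(centers)
print(f"autograd meta-gradient: {meta_grad.item():.4f}, closed form: {closed_form:.4f}")

# Outer loop: plain SGD on the initialization itself
meta_optimizer = torch.optim.SGD([theta], lr=0.5)
for _ in range(50):
    meta_optimizer.zero_grad()
    maml_meta_loss(theta, centers, inner_lr).backward()
    meta_optimizer.step()
print(f"learned initialization: {theta.item():.4f}")  # approaches mean(centers) = 2.0
```

Passing `create_graph=True` to the inner `torch.autograd.grad` call is what lets the outer gradient flow through the inner step. If the inner gradient were detached, the factor $(1-2\alpha)^2$ above would become $(1-2\alpha)$, i.e. the first-order approximation.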


In [None]:
class PolicyNetwork(nn.Module):
    """
    Policy network for MAML.
    """
    
    def __init__(self, state_dim, action_dim, hidden_dim=64):
        super().__init__()
        
        self.state_dim = state_dim
        self.action_dim = action_dim
        
        # Policy network
        self.policy = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim)
        )
        
        # Value network
        self.value = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
    
    def forward(self, state):
        """Forward pass through policy and value networks."""
        policy_logits = self.policy(state)
        value = self.value(state)
        return policy_logits, value
    
    def get_action(self, state, deterministic=False):
        """Get action from policy."""
        policy_logits, value = self.forward(state)
        
        if deterministic:
            action = torch.argmax(policy_logits, dim=-1)
        else:
            action_probs = F.softmax(policy_logits, dim=-1)
            action = torch.multinomial(action_probs, 1).squeeze(-1)
        
        return action, value
    
    def log_prob(self, state, action):
        """Get log probability of action."""
        policy_logits, _ = self.forward(state)
        log_probs = F.log_softmax(policy_logits, dim=-1)
        return log_probs.gather(1, action.unsqueeze(1)).squeeze(1)


class MAMLAgent:
    """
    MAML agent for meta-learning in RL.
    
    The inner loop adapts the meta-parameters with `inner_steps` differentiable
    SGD steps; the outer loop backpropagates the post-adaptation loss through
    those steps into the meta-parameters. As in many simple MAML-RL
    implementations, the dependence of the sampled trajectories on the
    parameters is ignored in the meta-gradient.
    """
    
    def __init__(self, state_dim, action_dim, meta_lr=0.001, inner_lr=0.01, 
                 inner_steps=1, hidden_dim=64):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.hidden_dim = hidden_dim
        self.meta_lr = meta_lr
        self.inner_lr = inner_lr
        self.inner_steps = inner_steps
        
        # Meta-network: its parameters are the learned initialization
        self.meta_network = PolicyNetwork(state_dim, action_dim, hidden_dim)
        self.meta_optimizer = optim.Adam(self.meta_network.parameters(), lr=meta_lr)
    
    def _forward(self, network, states, params=None):
        """Run `network` with its own weights, or with an explicit parameter dict."""
        if params is None:
            return network(states)
        # functional_call (PyTorch >= 2.0) lets adapted, graph-connected tensors
        # stand in for the module's weights without breaking autograd
        return torch.func.functional_call(network, params, (states,))
    
    def collect_trajectories(self, env, network, num_episodes=5, max_steps=200, params=None):
        """Collect trajectories with `network`, optionally using adapted `params`."""
        trajectories = []
        
        for episode in range(num_episodes):
            obs, _ = env.reset()
            episode_data = {
                'states': [],
                'actions': [],
                'rewards': [],
                'values': [],
                'log_probs': []
            }
            
            for step in range(max_steps):
                state_tensor = torch.as_tensor(obs, dtype=torch.float32).unsqueeze(0)
                
                # Acting needs no autograd graph; the loss recomputes log-probs later
                with torch.no_grad():
                    logits, value = self._forward(network, state_tensor, params)
                    dist = torch.distributions.Categorical(logits=logits)
                    action = dist.sample()
                    log_prob = dist.log_prob(action)
                
                # Execute action
                next_obs, reward, terminated, truncated, _ = env.step(action.item())
                
                # Store data
                episode_data['states'].append(obs)
                episode_data['actions'].append(action.item())
                episode_data['rewards'].append(reward)
                episode_data['values'].append(value.item())
                episode_data['log_probs'].append(log_prob.item())
                
                obs = next_obs
                if terminated or truncated:
                    break
            
            trajectories.append(episode_data)
        
        return trajectories
    
    def compute_returns(self, rewards, gamma=0.99):
        """Compute discounted returns."""
        returns = []
        G = 0
        
        for reward in reversed(rewards):
            G = reward + gamma * G
            returns.insert(0, G)
        
        return returns
    
    def compute_advantages(self, returns, values):
        """Compute advantages A_t = G_t - V(s_t); values are constants from sampling time."""
        return returns - values
    
    def compute_loss(self, trajectories, network, params=None):
        """Policy-gradient + value loss over all trajectories, optionally under `params`."""
        # Flatten all trajectories into one batch
        states = torch.as_tensor(
            np.concatenate([np.asarray(t['states'], dtype=np.float32) for t in trajectories]))
        actions = torch.as_tensor(
            np.concatenate([t['actions'] for t in trajectories]), dtype=torch.long)
        returns = torch.as_tensor(
            np.concatenate([self.compute_returns(t['rewards']) for t in trajectories]),
            dtype=torch.float32)
        values = torch.as_tensor(
            np.concatenate([t['values'] for t in trajectories]), dtype=torch.float32)
        
        # Normalize advantages over the whole batch (std is undefined for one sample)
        advantages = self.compute_advantages(returns, values)
        if advantages.numel() > 1:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
        
        # Policy loss
        logits, predicted_values = self._forward(network, states, params)
        log_probs = F.log_softmax(logits, dim=-1).gather(1, actions.unsqueeze(1)).squeeze(1)
        policy_loss = -(log_probs * advantages).mean()
        
        # Value loss (squeeze(-1) keeps shape (N,) even when N == 1)
        value_loss = F.mse_loss(predicted_values.squeeze(-1), returns)
        
        return policy_loss + 0.5 * value_loss
    
    def inner_loop_update(self, task, params):
        """Adapt the parameter dict `params` to `task` with differentiable SGD steps."""
        for _ in range(self.inner_steps):
            trajectories = self.collect_trajectories(task, self.meta_network,
                                                     num_episodes=3, params=params)
            loss = self.compute_loss(trajectories, self.meta_network, params)
            
            # create_graph=True keeps this step in the graph for the meta-gradient
            grads = torch.autograd.grad(loss, list(params.values()), create_graph=True)
            params = {name: p - self.inner_lr * g
                      for (name, p), g in zip(params.items(), grads)}
        
        return params, loss
    
    def meta_update(self, tasks, num_tasks=4):
        """Perform one meta-update across up to `num_tasks` tasks."""
        tasks = tasks[:num_tasks]
        meta_loss = 0.0
        
        for task in tasks:
            # Start from the meta-parameters themselves so gradients reach them
            params = dict(self.meta_network.named_parameters())
            
            # Inner loop adaptation
            adapted_params, _ = self.inner_loop_update(task, params)
            
            # Evaluate the adapted parameters on fresh trajectories from the same task
            test_trajectories = self.collect_trajectories(task, self.meta_network,
                                                          num_episodes=2, params=adapted_params)
            meta_loss = meta_loss + self.compute_loss(test_trajectories, self.meta_network,
                                                      adapted_params)
        
        # Meta-gradient update (backpropagates through the inner-loop steps)
        meta_loss = meta_loss / len(tasks)
        self.meta_optimizer.zero_grad()
        meta_loss.backward()
        self.meta_optimizer.step()
        
        return meta_loss.item()
    
    def adapt_to_task(self, task, num_adaptation_steps=5):
        """Adapt to a new task by fine-tuning a copy of the meta-network."""
        adapted_network = PolicyNetwork(self.state_dim, self.action_dim, self.hidden_dim)
        adapted_network.load_state_dict(self.meta_network.state_dict())
        
        # Plain SGD with the inner learning rate, matching the meta-training inner loop
        optimizer = optim.SGD(adapted_network.parameters(), lr=self.inner_lr)
        
        for step in range(num_adaptation_steps):
            trajectories = self.collect_trajectories(task, adapted_network, num_episodes=2)
            loss = self.compute_loss(trajectories, adapted_network)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        
        return adapted_network


# Test MAML implementation
print("Testing MAML Implementation...")

# Create task distribution
task_dist = TaskDistribution('pole_length')
state_dim = 4  # CartPole state dimension
action_dim = 2  # CartPole action dimension

# Create MAML agent
maml_agent = MAMLAgent(state_dim, action_dim, meta_lr=0.001, inner_lr=0.01)

# Test meta-update
print("Testing meta-update...")
tasks = task_dist.sample_tasks(4)
meta_loss = maml_agent.meta_update(tasks)
print(f"Meta-loss: {meta_loss:.4f}")

# Test adaptation
print("Testing task adaptation...")
new_task = task_dist.sample_task()
adapted_network = maml_agent.adapt_to_task(new_task, num_adaptation_steps=3)

# Test adapted network
obs, _ = new_task.reset()
state_tensor = torch.FloatTensor(obs).unsqueeze(0)
action, value = adapted_network.get_action(state_tensor)
print(f"Adapted network action: {action.item()}, value: {value.item():.4f}")

print("MAML test completed!")


## 3. RL² (Recurrent Meta-RL) Implementation

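RL² uses a recurrent network to encode task information implicitly. At each step the policy receives the current observation together with the previous action, reward, and done flag, and its LSTM hidden state is carried across episode boundaries within a task (it is reset only when a new task is sampled). Adaptation therefore happens in the forward pass, as the hidden state accumulates evidence about the current task, with no explicit inner-loop optimization.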
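A minimal sketch of the RL² input and hidden-state handling, assuming an LSTM core and an input made of the current observation, the one-hot previous action, the previous reward, and a done flag. The class name `RL2PolicySketch`, the layer sizes, and the random stand-in observations are hypothetical; this illustrates the interface only and is not this notebook's implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class RL2PolicySketch(nn.Module):
    """Hypothetical minimal RL^2 policy: an LSTM over (obs, prev action, prev reward, done)."""

    def __init__(self, obs_dim, n_actions, hidden_dim=64):
        super().__init__()
        self.n_actions = n_actions
        # Input = observation + one-hot previous action + previous reward + done flag
        self.lstm = nn.LSTM(obs_dim + n_actions + 2, hidden_dim, batch_first=True)
        self.policy_head = nn.Linear(hidden_dim, n_actions)
        self.value_head = nn.Linear(hidden_dim, 1)

    def forward(self, obs, prev_action, prev_reward, prev_done, hidden=None):
        # obs: (B, obs_dim); prev_action: (B,) long; prev_reward, prev_done: (B,) float
        x = torch.cat([
            obs,
            F.one_hot(prev_action, self.n_actions).float(),
            prev_reward.unsqueeze(-1),
            prev_done.unsqueeze(-1),
        ], dim=-1).unsqueeze(1)              # one time step: (B, 1, input_dim)
        out, hidden = self.lstm(x, hidden)   # hidden carries task information forward
        out = out.squeeze(1)
        return self.policy_head(out), self.value_head(out).squeeze(-1), hidden


torch.manual_seed(0)
policy = RL2PolicySketch(obs_dim=4, n_actions=2)
hidden = None                                  # reset only when a NEW task is sampled
prev_action = torch.zeros(1, dtype=torch.long)
prev_reward = torch.zeros(1)
prev_done = torch.zeros(1)

for t in range(5):
    obs = torch.randn(1, 4)                    # stand-in for an environment observation
    logits, value, hidden = policy(obs, prev_action, prev_reward, prev_done, hidden)
    prev_action = torch.distributions.Categorical(logits=logits).sample()
    prev_reward = torch.ones(1)                # stand-in reward
    prev_done = torch.tensor([1.0 if t == 2 else 0.0])  # an episode ends at t = 2,
    # but `hidden` is NOT reset, so the next episode of the same task reuses it
```

The key design choice is where `hidden` gets reset: resetting it at every episode boundary would discard what the network has inferred about the task, which is exactly the information RL² relies on for fast adaptation.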
