# HW12: Hierarchical Reinforcement Learning

**Course:** Deep Reinforcement Learning  
**Assignment:** Homework 12 - Hierarchical RL  
**Date:** 2024

---

## Overview

Hierarchical Reinforcement Learning (HRL) structures policies across multiple levels of abstraction, enabling agents to solve complex, long-horizon tasks by decomposing them into simpler subtasks. This assignment explores temporal abstraction, the options framework, feudal (manager-worker) architectures, and goal-conditioned policies.

## Learning Objectives

1. **Temporal Abstraction**: Understand multi-scale decision making
2. **Options Framework**: Master semi-Markov decision processes
3. **Feudal Hierarchies**: Learn manager-worker architectures
4. **Goal-Conditioned RL**: Train a single policy to reach many different goals
5. **Skill Discovery**: Learn reusable primitives automatically
6. **Credit Assignment**: Address challenges across temporal scales

## Table of Contents

1. [Introduction to Hierarchical RL](#introduction)
2. [Options Framework](#options-framework)
3. [Feudal Hierarchies](#feudal-hierarchies)
4. [Goal-Conditioned RL](#goal-conditioned-rl)
5. [Skill Discovery](#skill-discovery)
6. [HAM Framework](#ham-framework)
7. [Evaluation and Comparison](#evaluation)
8. [Conclusion](#conclusion)


In [None]:
# Import necessary libraries
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict, deque
import random
from typing import Dict, List, Tuple, Optional
import warnings
warnings.filterwarnings('ignore')

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Set plotting style
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")


## 1. Introduction to Hierarchical RL

### Motivation for Hierarchy

**Challenges in Flat RL:**
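- **Long Horizons**: Credit assignment is difficult when rewards arrive 1000+ steps after the actions that caused them
- **Sparse Rewards**: Random exploration rarely stumbles on the reward
- **Complex Tasks**: Planning only over atomic actions scales poorly
- **Transfer**: Learned behaviors are hard to reuse in new tasks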

**Benefits of Hierarchy:**
- Temporal abstraction (plan at multiple scales)
- Reusable skills/subpolicies
- Exploration structure
- Transfer learning
- Compositional generalization

### Human Example:
```
Task: Make dinner
├─ Shop for ingredients
│  ├─ Drive to store
│  ├─ Find items
│  └─ Checkout
├─ Prepare food
│  ├─ Chop vegetables
│  ├─ Cook proteins
│  └─ Mix ingredients
└─ Serve meal
```


In [None]:
# Exercise 1.1: Create a Simple GridWorld Environment for HRL
class GridWorld:
    """
    A simple grid world environment for testing hierarchical RL algorithms.
    The agent must navigate from start to goal, potentially using hierarchical actions.
    """
    
    def __init__(self, width=10, height=10, start=(0, 0), goal=(9, 9)):
        self.width = width
        self.height = height
        self.start = start
        self.goal = goal
        self.state = start
        
        # Define atomic actions: up, down, left, right
        self.atomic_actions = [(0, 1), (0, -1), (-1, 0), (1, 0)]
        self.action_names = ['up', 'down', 'left', 'right']
        
        # Define hierarchical actions (options)
        self.options = {}
        
    def reset(self):
        """Reset environment to initial state"""
        self.state = self.start
        return self.state
        
    def step(self, action):
        """Take a step in the environment"""
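        if isinstance(action, (int, np.integer)):
            # Atomic action: move one cell, clipped to the grid boundaries
            dx, dy = self.atomic_actions[action]
            new_x = max(0, min(self.width-1, self.state[0] + dx))
            new_y = max(0, min(self.height-1, self.state[1] + dy))
            self.state = (new_x, new_y)
        else:
            # Simplified "hierarchical action": treat `action` as a target cell and
            # jump there directly (a stand-in for running an option to completion)
            self.state = tuple(action)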
            
        # Calculate reward
        reward = 1.0 if self.state == self.goal else -0.01
        
        # Check if done
        done = self.state == self.goal
        
        return self.state, reward, done, {}
    
    def get_distance_to_goal(self, state):
        """Calculate Manhattan distance to goal"""
        return abs(state[0] - self.goal[0]) + abs(state[1] - self.goal[1])
    
    def render(self):
        """Render the current state"""
        grid = np.zeros((self.height, self.width))
        grid[self.start[1], self.start[0]] = 1  # Start
        grid[self.goal[1], self.goal[0]] = 2     # Goal
        grid[self.state[1], self.state[0]] = 3   # Current position
        
        plt.figure(figsize=(8, 8))
        plt.imshow(grid, cmap='viridis', origin='lower')  # y grows upward, matching the 'up' action
        plt.title(f"GridWorld - Current: {self.state}, Goal: {self.goal}")
        plt.show()

# Test the environment
env = GridWorld()
print("GridWorld Environment Created!")
print(f"Start: {env.start}, Goal: {env.goal}")
print(f"Atomic actions: {env.action_names}")

# Test a few steps
state = env.reset()
print(f"Initial state: {state}")

for i in range(5):
    action = np.random.randint(4)  # Random atomic action
    next_state, reward, done, _ = env.step(action)
    print(f"Step {i+1}: Action {env.action_names[action]} -> State {next_state}, Reward {reward:.2f}")
    if done:
        print("Goal reached!")
        break


## 2. Options Framework

### Formal Definition

An **Option** is a temporally extended action defined as:
```
Option ω = (I_ω, π_ω, β_ω)

where:
- I_ω ⊆ S: Initiation set (where option can start)
- π_ω: S × A → [0,1]: Option policy
- β_ω: S → [0,1]: Termination function
```
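To make the tuple concrete, here is a minimal sketch in plain Python. `SimpleOption`, `run_option`, and the corridor environment are hypothetical names for illustration only (the notebook's own `Option` class appears in Exercise 2.1). It assumes the usual call-and-return execution: check that the start state is in I_ω, act with π_ω, and after each step stop with probability β_ω(s), returning the landing state, the discounted reward collected, and the duration k.

```python
import random
from dataclasses import dataclass
from typing import Callable, Set


@dataclass
class SimpleOption:
    """Option ω = (I_ω, π_ω, β_ω) on integer states (hypothetical illustrative class)."""
    initiation_set: Set[int]             # I_ω: states where ω may start
    policy: Callable[[int], int]         # π_ω: deterministic here, maps state -> action
    termination: Callable[[int], float]  # β_ω: probability of stopping in a state


def run_option(option, state, step_fn, gamma=0.99, max_steps=100, rng=random):
    """Execute ω from `state` until β_ω fires; return (s', discounted return R, duration k)."""
    if state not in option.initiation_set:
        raise ValueError("option cannot be initiated in this state")
    total, k = 0.0, 0
    while k < max_steps:
        action = option.policy(state)
        state, reward = step_fn(state, action)
        total += (gamma ** k) * reward  # R = sum over i < k of γ^i * r_{t+i+1}
        k += 1
        if rng.random() < option.termination(state):
            break
    return state, total, k


# Hypothetical 1D corridor: states 0..9, actions +1 / -1, reward -1 per step
def corridor_step(s, a):
    s_next = max(0, min(9, s + a))
    return s_next, -1.0


go_right = SimpleOption(
    initiation_set=set(range(9)),
    policy=lambda s: +1,
    termination=lambda s: 1.0 if s == 9 else 0.0,  # stop deterministically at the wall
)

s_end, R, k = run_option(go_right, 2, corridor_step, gamma=1.0)
print(s_end, R, k)  # 9 -7.0 7
```

The triple (s', R, k) returned here is exactly what an SMDP learner consumes: the option is treated as one long action whose reward is R and whose duration is k.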

### Semi-Markov Decision Process (SMDP)

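Instead of choosing a primitive action at every step, the agent chooses an option and executes it until the option terminates. Because options last a variable number of steps, the resulting decision process is semi-Markov.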

**Option-Value Functions:**
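- Q(s, ω): expected discounted return from executing option ω in state s and then continuing with the policy over options
- Intra-option learning: Q can be updated from every primitive step taken while an option executes, not only when the option terminates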
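Because an option runs for a variable number of steps k, SMDP Q-learning discounts the bootstrap term by γ^k instead of γ: Q(s, ω) <- Q(s, ω) + α[R + γ^k max_ω' Q(s', ω') - Q(s, ω)], where R is the discounted reward accumulated while ω ran. A minimal tabular sketch, assuming a dictionary keyed by (state, option); `smdp_q_update` is a hypothetical helper, not part of the exercises below.

```python
from collections import defaultdict


def smdp_q_update(Q, s, option, R, k, s_next, options, alpha=0.1, gamma=0.99, done=False):
    """One SMDP Q-learning step on a tabular Q[(state, option)] (illustrative sketch).

    R is the discounted reward collected while the option ran and k is its duration,
    so the bootstrap value at s_next is discounted by gamma ** k.
    """
    bootstrap = 0.0 if done else max(Q[(s_next, o)] for o in options)
    target = R + (gamma ** k) * bootstrap
    Q[(s, option)] += alpha * (target - Q[(s, option)])
    return Q[(s, option)]


Q = defaultdict(float)
Q[(9, "stay")] = 2.0  # pretend value of the best option at the landing state
# Option "go_right" ran for k=7 steps from state 2, collected R=-7.0, and landed in state 9
new_value = smdp_q_update(Q, 2, "go_right", R=-7.0, k=7, s_next=9,
                          options=["go_right", "stay"], alpha=0.5, gamma=1.0)
print(new_value)  # 0.5 * (-7.0 + 1.0 * 2.0 - 0.0) = -2.5
```

Intra-option learning differs in that it would apply a one-step update at every primitive transition inside the option, for every option consistent with the action taken.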

In [None]:
# Exercise 2.1: Implement Handcrafted Options
class Option:
    """
    A handcrafted option for navigation tasks.
    Each option represents a skill like "move towards goal" or "explore area".
    """
    
    def __init__(self, name, initiation_set, policy_func, termination_func):
        self.name = name
        self.initiation_set = initiation_set
        self.policy_func = policy_func
        self.termination_func = termination_func
        
    def can_initiate(self, state):
        """Check if option can be initiated in given state"""
        return state in self.initiation_set
    
    def get_action(self, state):
        """Get action from option policy"""
        return self.policy_func(state)
    
    def should_terminate(self, state):
        """Check if option should terminate"""
        return self.termination_func(state)

class NavigateToGoalOption(Option):
    """Option that navigates towards a specific goal location"""
    
    def __init__(self, goal_location, env):
        self.goal = goal_location
        self.env = env
        
        # Can initiate from any state
        initiation_set = set()
        for x in range(env.width):
            for y in range(env.height):
                initiation_set.add((x, y))
        
        super().__init__(
            name=f"NavigateTo{goal_location}",
            initiation_set=initiation_set,
            policy_func=self._navigate_policy,
            termination_func=self._termination_condition
        )
    
    def _navigate_policy(self, state):
        """Policy: move towards goal using Manhattan distance"""
        current_x, current_y = state
        goal_x, goal_y = self.goal
        
        # Calculate direction to goal
        dx = goal_x - current_x
        dy = goal_y - current_y
        
        # Choose action that moves towards goal
        if abs(dx) > abs(dy):
            return 3 if dx > 0 else 2  # right or left
        else:
            return 0 if dy > 0 else 1  # up or down
    
    def _termination_condition(self, state):
        """Terminate when within one step of this option's own target (not the env goal)"""
        distance = abs(state[0] - self.goal[0]) + abs(state[1] - self.goal[1])
        return distance <= 1

class ExploreOption(Option):
    """Option that explores the environment randomly"""
    
    def __init__(self, env, exploration_steps=5):
        self.env = env
        self.exploration_steps = exploration_steps
        self.steps_taken = 0
        
        # Can initiate from any state
        initiation_set = set()
        for x in range(env.width):
            for y in range(env.height):
                initiation_set.add((x, y))
        
        super().__init__(
            name="Explore",
            initiation_set=initiation_set,
            policy_func=self._explore_policy,
            termination_func=self._termination_condition
        )
    
    def _explore_policy(self, state):
        """Policy: random exploration"""
        self.steps_taken += 1
        return np.random.randint(4)  # Random atomic action
    
    def _termination_condition(self, state):
        """Terminate after exploration_steps"""
        if self.steps_taken >= self.exploration_steps:
            self.steps_taken = 0  # Reset for next use
            return True
        return False

# Test handcrafted options
env = GridWorld()
goal_option = NavigateToGoalOption((9, 9), env)
explore_option = ExploreOption(env, exploration_steps=3)

print("Handcrafted Options Created:")
print(f"1. {goal_option.name}")
print(f"2. {explore_option.name}")

# Test goal navigation option
state = env.reset()
print(f"\nTesting {goal_option.name}:")
print(f"Initial state: {state}")

for step in range(10):
    if goal_option.can_initiate(state):
        action = goal_option.get_action(state)
        next_state, reward, done, _ = env.step(action)
        print(f"Step {step+1}: Action {env.action_names[action]} -> State {next_state}")
        
        if goal_option.should_terminate(next_state):
            print("Option terminated!")
            break
        state = next_state
    else:
        print("Cannot initiate option from current state")
        break


In [None]:
# Exercise 2.2: Implement Option-Critic Architecture
class OptionCritic(nn.Module):
    """
    Option-Critic architecture for learning options automatically.
    Learns option policies, termination functions, and option-value functions.
    """
    
    def __init__(self, state_dim, num_options, action_dim, hidden_dim=64):
        super().__init__()
        
        self.state_dim = state_dim
        self.num_options = num_options
        self.action_dim = action_dim
        
        # Shared representation
        self.encoder = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )
        
        # Option policies
        self.option_policies = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, action_dim)
            ) for _ in range(num_options)
        ])
        
        # Termination functions
        self.terminations = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_options),
            nn.Sigmoid()
        )
        
        # Q-value over options
        self.q_omega = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_options)
        )
        
        # Intra-option Q-values
        self.q_intra = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, action_dim)
            ) for _ in range(num_options)
        ])
    
    def forward(self, state, current_option=None):
        """
        Forward pass through the network
        
        Args:
            state: Current state
            current_option: Current active option (if any)
        
        Returns:
            If current_option is provided:
                - action_logits: Action probabilities for current option
                - beta: Termination probability for current option
            Else:
                - q_omega: Q-values for all options
        """
        features = self.encoder(state)
        
        if current_option is not None:
            # Get action from current option
            action_logits = self.option_policies[current_option](features)
            
            # Termination probability
            beta = self.terminations(features)[:, current_option]
            
            return action_logits, beta
        else:
            # Select option
            q_omega = self.q_omega(features)
            return q_omega
    
    def get_intra_option_q(self, state, option):
        """Get intra-option Q-values"""
        features = self.encoder(state)
        return self.q_intra[option](features)
    
    def select_option(self, state, epsilon=0.1):
        """Select option using epsilon-greedy policy"""
        with torch.no_grad():
            q_values = self.forward(state)
            if np.random.random() < epsilon:
                return np.random.randint(self.num_options)
            else:
                return q_values.argmax().item()
    
    def get_action(self, state, option):
        """Get action from option policy"""
        with torch.no_grad():
            action_logits, _ = self.forward(state, option)
            action_probs = F.softmax(action_logits, dim=-1)
            action = torch.multinomial(action_probs, 1).item()
            return action

class OptionCriticAgent:
    """Agent that uses Option-Critic architecture"""
    
    def __init__(self, state_dim, num_options, action_dim, lr=1e-3):
        self.model = OptionCritic(state_dim, num_options, action_dim)
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
        
        self.num_options = num_options
        self.current_option = None
        self.option_steps = 0
        
    def reset(self):
        """Reset agent state"""
        self.current_option = None
        self.option_steps = 0
    
    def act(self, state, epsilon=0.1):
        """Select action using current option or select new option"""
        state_tensor = torch.FloatTensor(state).unsqueeze(0)
        
        # Check if we need to select a new option
        if self.current_option is None:
            self.current_option = self.model.select_option(state_tensor, epsilon)
            self.option_steps = 0
        
        # Get action from current option
        action = self.model.get_action(state_tensor, self.current_option)
        
        # Check termination
        with torch.no_grad():
            _, beta = self.model.forward(state_tensor, self.current_option)
            if np.random.random() < beta.item() or self.option_steps > 20:
                self.current_option = None  # Terminate option
        
        self.option_steps += 1
        return action
    
    def update(self, batch):
        """Update the model using a batch of experiences"""
        states, actions, rewards, next_states, options, dones = batch
        
        # Convert to tensors
        states = torch.FloatTensor(states)
        actions = torch.LongTensor(actions)
        rewards = torch.FloatTensor(rewards)
        next_states = torch.FloatTensor(next_states)
        options = torch.LongTensor(options)
        dones = torch.BoolTensor(dones)
        
        # Compute losses
        q_loss = self._compute_q_loss(states, actions, rewards, next_states, options, dones)
        policy_loss = self._compute_policy_loss(states, actions, options)
        termination_loss = self._compute_termination_loss(states, next_states, options)
        
        # Total loss
        total_loss = q_loss + policy_loss + termination_loss
        
        # Optimize
        self.optimizer.zero_grad()
        total_loss.backward()
        self.optimizer.step()
        
        return {
            'q_loss': q_loss.item(),
            'policy_loss': policy_loss.item(),
            'termination_loss': termination_loss.item(),
            'total_loss': total_loss.item()
        }
    
    def _compute_q_loss(self, states, actions, rewards, next_states, options, dones, gamma=0.99):
        """Critic loss: regress Q_U(s, ω, a) and Q_Ω(s, ω) toward r + γ U(s', ω)"""
        idx = options.unsqueeze(1)
        # Q_U(s, ω, a) for the option and action actually taken
        q_u = torch.stack([
            self.model.get_intra_option_q(states[i:i+1], int(o))[0, actions[i]]
            for i, o in enumerate(options)
        ])
        q_omega = self.model(states).gather(1, idx).squeeze(1)  # Q_Ω(s, ω)

        with torch.no_grad():
            q_next = self.model(next_states)  # Q_Ω(s', ·)
            beta_next = self.model.terminations(self.model.encoder(next_states)).gather(1, idx).squeeze(1)
            # U(s', ω): keep ω with prob 1 - β, otherwise switch to the greedy option
            u_next = (1 - beta_next) * q_next.gather(1, idx).squeeze(1) + beta_next * q_next.max(dim=1).values
            targets = rewards + gamma * (~dones).float() * u_next

        return F.mse_loss(q_u, targets) + F.mse_loss(q_omega, targets)

    def _compute_policy_loss(self, states, actions, options):
        """Intra-option policy gradient: -log π_ω(a|s) * (Q_U(s, ω, a) - Q_Ω(s, ω))"""
        policy_loss = 0
        for i, option in enumerate(options):
            o = option.item()
            action_logits, _ = self.model.forward(states[i:i+1], o)
            log_probs = F.log_softmax(action_logits, dim=-1)
            with torch.no_grad():
                # Advantage of the taken action, using the option value as baseline
                advantage = (self.model.get_intra_option_q(states[i:i+1], o)[0, actions[i]]
                             - self.model(states[i:i+1])[0, o])
            policy_loss -= log_probs[0, actions[i]] * advantage

        return policy_loss / len(options)

    def _compute_termination_loss(self, states, next_states, options, xi=0.01):
        """Termination gradient: minimize β_ω(s') * (Q_Ω(s', ω) - V_Ω(s') + ξ)"""
        termination_loss = 0
        for i, option in enumerate(options):
            o = option.item()
            _, beta = self.model.forward(next_states[i:i+1], o)
            with torch.no_grad():
                q_next = self.model(next_states[i:i+1])[0]
                # Negative (pushes β up) when ω is worse than the best option at s' by more than ξ
                advantage = q_next[o] - q_next.max() + xi
            termination_loss += beta[0] * advantage

        return termination_loss / len(options)

# Test Option-Critic
print("Option-Critic Architecture Created!")
print("Testing with GridWorld environment...")

# Create agent
state_dim = 2  # (x, y) coordinates
num_options = 3
action_dim = 4  # up, down, left, right

agent = OptionCriticAgent(state_dim, num_options, action_dim)
env = GridWorld()

# Test agent
state = env.reset()
print(f"Initial state: {state}")

for step in range(10):
    action = agent.act(state)
    next_state, reward, done, _ = env.step(action)
    print(f"Step {step+1}: Action {env.action_names[action]} -> State {next_state}, Reward {reward:.2f}")
    
    if done:
        print("Goal reached!")
        break
    state = next_state

agent.reset()
print("Agent reset for next episode.")


## 3. Feudal Hierarchies

### Key Idea: Manager-Worker Architecture

**FeudalNet (Feudal Networks):**
- **Manager**: Sets goals at high level
- **Worker**: Achieves goals at low level
- **Communication**: Manager provides goals to worker

### Architecture Components:

1. **Perception Module**: Shared state representation
2. **Manager**: LSTM that sets goals every c timesteps
3. **Worker**: LSTM that receives goals and produces actions
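4. **Reward Structure**:
   - Manager: trained on the extrinsic reward; its goals are scored by the cosine similarity between the goal and the state change actually produced over the next c steps
   - Worker: extrinsic reward plus an intrinsic reward for moving in the directions set by the manager's recent goals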
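In the original FeUdal Networks paper (Vezhnevets et al., 2017), the worker's intrinsic reward averages the cosine similarity between recent state changes and the goals that requested them: r^I_t = (1/c) * sum_{i=1..c} cos(s_t - s_{t-i}, g_{t-i}), and the worker is trained on r_t + α r^I_t. The sketch below computes r^I_t in pure Python. It assumes raw state vectors where FuN uses a learned latent space, and `worker_intrinsic_reward` is a hypothetical helper, not part of the `FeudalAgent` below.

```python
import math


def cosine(u, v, eps=1e-8):
    """Cosine similarity; returns 0 when either vector is (near) zero."""
    nu = math.sqrt(sum(x * x for x in u))
    nv = math.sqrt(sum(x * x for x in v))
    if nu < eps or nv < eps:
        return 0.0
    return sum(a * b for a, b in zip(u, v)) / (nu * nv)


def worker_intrinsic_reward(states, goals, t, c):
    """r^I_t = (1/c) * sum_{i=1..c} cos(s_t - s_{t-i}, g_{t-i}); missing history counts as 0."""
    total = 0.0
    for i in range(1, c + 1):
        if t - i < 0:
            continue  # not enough history yet at the start of an episode
        delta = [a - b for a, b in zip(states[t], states[t - i])]
        total += cosine(delta, goals[t - i])
    return total / c


# The agent walks right along x while the manager keeps asking for +x
states = [(0, 0), (1, 0), (2, 0), (3, 0)]
goals = [(1.0, 0.0)] * 4
print(worker_intrinsic_reward(states, goals, t=3, c=3))  # 1.0
```

Because only the direction of the state change matters, the goal acts as a heading rather than a target location, which is why FuN normalizes goals.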


In [None]:
# Exercise 3.1: Implement FeudalNet Architecture
class FeudalNet(nn.Module):
    """
    Feudal Networks implementation with Manager-Worker hierarchy.
    Manager sets goals, Worker achieves them.
    """
    
    def __init__(self, state_dim, action_dim, goal_dim=8, c=10, hidden_dim=256):
        super().__init__()
        
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.goal_dim = goal_dim
        self.c = c  # Manager horizon
        
        # Perception module (shared)
        self.perception = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )
        
        # Manager (sets goals)
        self.manager = nn.LSTM(hidden_dim, goal_dim, batch_first=True)
        
        # Worker (achieves goals)
        self.worker = nn.LSTM(hidden_dim + goal_dim, hidden_dim, batch_first=True)
        self.worker_policy = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim)
        )
        
        # Initialize hidden states
        self.manager_hidden = None
        self.worker_hidden = None
        
    def forward(self, state, t):
        """
        Forward pass through FeudalNet
        
        Args:
            state: Current state
            t: Current timestep
        
        Returns:
            action_logits: Action probabilities
            goal: Current goal from manager
            manager_hidden: Updated manager hidden state
            worker_hidden: Updated worker hidden state
        """
        batch_size = state.shape[0]
        
        # Shared perception
        z = self.perception(state)
        
        # Manager operates every c timesteps
        if t % self.c == 0:
            # Manager sets goal
            if self.manager_hidden is None:
                self.manager_hidden = self._init_hidden(batch_size, self.manager)
            
            z_reshaped = z.unsqueeze(1)  # Add sequence dimension
            goal_output, self.manager_hidden = self.manager(z_reshaped, self.manager_hidden)
            goal = F.normalize(goal_output.squeeze(1), dim=-1)  # Normalize goal
        else:
            # Reuse the previous goal; fall back to zeros if none exists yet (e.g. right after reset)
            goal = getattr(self, 'current_goal', None)
            if goal is None:
                goal = torch.zeros(batch_size, self.goal_dim)
        
        self.current_goal = goal
        
        # Worker receives goal and state
        if self.worker_hidden is None:
            self.worker_hidden = self._init_hidden(batch_size, self.worker)
        
        w_input = torch.cat([z, goal], dim=-1)
        w_input_reshaped = w_input.unsqueeze(1)  # Add sequence dimension
        w_output, self.worker_hidden = self.worker(w_input_reshaped, self.worker_hidden)
        
        # Worker action
        action_logits = self.worker_policy(w_output.squeeze(1))
        
        return action_logits, goal, self.manager_hidden, self.worker_hidden
    
    def _init_hidden(self, batch_size, lstm):
        """Initialize LSTM hidden states"""
        h0 = torch.zeros(1, batch_size, lstm.hidden_size)
        c0 = torch.zeros(1, batch_size, lstm.hidden_size)
        return (h0, c0)
    
    def reset_hidden(self):
        """Reset hidden states for new episode"""
        self.manager_hidden = None
        self.worker_hidden = None
        self.current_goal = None

class FeudalAgent:
    """Agent using FeudalNet architecture"""
    
    def __init__(self, state_dim, action_dim, goal_dim=8, c=10, lr=1e-3):
        self.model = FeudalNet(state_dim, action_dim, goal_dim, c)
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
        
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.goal_dim = goal_dim
        self.c = c
        
        # Experience buffer
        self.buffer = []
        self.max_buffer_size = 10000
        
    def act(self, state, t, epsilon=0.1):
        """Select action using FeudalNet"""
        state_tensor = torch.FloatTensor(state).unsqueeze(0)
        
        with torch.no_grad():
            action_logits, goal, _, _ = self.model(state_tensor, t)
            
            if np.random.random() < epsilon:
                action = np.random.randint(self.action_dim)
            else:
                action_probs = F.softmax(action_logits, dim=-1)
                action = torch.multinomial(action_probs, 1).item()
        
        return action, goal.squeeze(0).numpy()
    
    def store_experience(self, state, action, reward, next_state, goal, t):
        """Store experience in buffer"""
        experience = {
            'state': state,
            'action': action,
            'reward': reward,
            'next_state': next_state,
            'goal': goal,
            't': t
        }
        
        self.buffer.append(experience)
        if len(self.buffer) > self.max_buffer_size:
            self.buffer.pop(0)
    
    def compute_feudal_rewards(self, trajectory):
        """Compute manager and worker rewards for an ordered trajectory.

        Returns len(trajectory) - c rewards of each kind. The goal is compared with the
        raw state change over c steps, so goal_dim must equal state_dim here.
        """
        manager_rewards = []
        worker_rewards = []

        for i in range(len(trajectory) - self.c):
            # State change over the manager horizon: s_{i+c} - s_i
            transition = (np.array(trajectory[i + self.c]['state'], dtype=float)
                          - np.array(trajectory[i]['state'], dtype=float))
            goal = np.asarray(trajectory[i]['goal'], dtype=float)
            if goal.shape != transition.shape:
                raise ValueError(f"goal_dim ({goal.shape[0]}) must equal state_dim ({transition.shape[0]})")

            # Manager reward: cosine similarity between goal direction and state change
            norm = np.linalg.norm(transition) * np.linalg.norm(goal)
            manager_reward = float(np.dot(transition, goal) / norm) if norm > 0 else 0.0
            manager_rewards.append(manager_reward)

            # Worker reward: extrinsic reward plus a scaled intrinsic (goal-alignment) bonus
            worker_rewards.append(trajectory[i]['reward'] + 0.1 * manager_reward)

        return manager_rewards, worker_rewards
    
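    def update(self, batch_size=32):
        """Update the worker policy with feudal rewards.

        Feudal rewards compare step i with step i + c, so they must be computed on an
        ordered segment of experience, not on a random sample. This uses the most recent
        batch_size + c steps (which may span an episode boundary; the buffer stores no
        episode markers).
        """
        if len(self.buffer) < batch_size + self.c:
            return {}

        trajectory = self.buffer[-(batch_size + self.c):]
        _, worker_rewards = self.compute_feudal_rewards(trajectory)  # batch_size rewards
        batch = trajectory[:batch_size]

        # Convert to tensors
        states = torch.tensor([exp['state'] for exp in batch], dtype=torch.float32)
        actions = torch.tensor([exp['action'] for exp in batch], dtype=torch.long)
        rewards = torch.tensor(worker_rewards, dtype=torch.float32)

        # Forward pass with fresh batch-sized LSTM states, then restore the acting
        # agent's recurrent state so the episode in progress is not disturbed
        saved = (self.model.manager_hidden, self.model.worker_hidden,
                 getattr(self.model, 'current_goal', None))
        self.model.reset_hidden()
        action_logits, _, _, _ = self.model(states, 0)
        self.model.manager_hidden, self.model.worker_hidden, self.model.current_goal = saved

        # REINFORCE-style worker loss, weighted by the one-step feudal reward
        log_probs = F.log_softmax(action_logits, dim=-1)
        policy_loss = -(log_probs.gather(1, actions.unsqueeze(1)).squeeze(1) * rewards).mean()
        total_loss = policy_loss

        # Optimize
        self.optimizer.zero_grad()
        total_loss.backward()
        self.optimizer.step()

        return {
            'policy_loss': policy_loss.item(),
            'total_loss': total_loss.item()
        }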
    
    def reset(self):
        """Reset agent for new episode"""
        self.model.reset_hidden()

# Test FeudalNet
print("FeudalNet Architecture Created!")
print("Testing with GridWorld environment...")

# Create agent
state_dim = 2  # (x, y) coordinates
action_dim = 4  # up, down, left, right
goal_dim = 2  # Goals live in (x, y) space so cosine rewards against state changes are defined
c = 5  # Manager horizon

agent = FeudalAgent(state_dim, action_dim, goal_dim, c)
env = GridWorld()

# Test agent
state = env.reset()
print(f"Initial state: {state}")

for step in range(15):
    action, goal = agent.act(state, step)
    next_state, reward, done, _ = env.step(action)
    
    print(f"Step {step+1}: Action {env.action_names[action]} -> State {next_state}, Reward {reward:.2f}")
    print(f"  Goal: {np.round(goal, 3)}")
    
    # Store experience
    agent.store_experience(state, action, reward, next_state, goal, step)
    
    if done:
        print("Goal reached!")
        break
    state = next_state

# Update model
losses = agent.update()
print(f"\nTraining losses: {losses}")

agent.reset()
print("Agent reset for next episode.")


## 4. Goal-Conditioned RL

### Key Idea: Train a policy to reach any goal state

**Universal Value Function Approximators (UVFA):**
- Policy conditioned on current state and goal: π(a|s,g)
- Q-function conditioned on state, action, and goal: Q(s,a,g)
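The point of conditioning on g is that a single function serves every goal. The sketch below is a tabular stand-in for a UVFA: a table Q[s][a][g] on a hypothetical 5-state chain, filled by synchronous Bellman sweeps with an assumed sparse reward (0 on reaching g, -1 otherwise). A real UVFA replaces the table with a network over (s, g), as `GoalConditionedPolicy` does in Exercise 4.1.

```python
def goal_conditioned_q(n_states=5, gamma=0.9, sweeps=50):
    """Tabular Q(s, a, g) on a 1D chain, filled by synchronous Bellman sweeps over every goal.

    Actions: 0 = left, 1 = right. Reward 0 on reaching g (episode ends), -1 otherwise.
    """
    def step(s, a):
        return max(0, min(n_states - 1, s + (1 if a == 1 else -1)))

    Q = [[[0.0] * n_states for _ in range(2)] for _ in range(n_states)]  # Q[s][a][g]
    for _ in range(sweeps):
        new_Q = [[[0.0] * n_states for _ in range(2)] for _ in range(n_states)]
        for s in range(n_states):
            for a in range(2):
                for g in range(n_states):
                    s_next = step(s, a)
                    if s_next == g:
                        new_Q[s][a][g] = 0.0  # goal reached: terminal, no bootstrap
                    else:
                        new_Q[s][a][g] = -1.0 + gamma * max(Q[s_next][b][g] for b in range(2))
        Q = new_Q
    return Q


def greedy_action(Q, s, g):
    """Best action in state s for goal g under the learned table."""
    return max(range(2), key=lambda a: Q[s][a][g])


Q = goal_conditioned_q()
print(greedy_action(Q, 0, 3), greedy_action(Q, 4, 1))  # 1 0
```

The same table steers the agent right when g = 3 and left when g = 1; with function approximation, this sharing across goals is what lets a UVFA generalize to goals it has rarely seen.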

**Hindsight Experience Replay (HER):**
- Augment failed trajectories with alternative goals
- "What if the goal was different?"
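- Substantially improves sample efficiency under sparse, binary rewards: even episodes that miss the original goal become successful examples for the goals they did reach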
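A minimal sketch of the "future" relabeling strategy in pure Python on hashable states. `her_relabel` is a hypothetical helper; it differs from the agent's `hindsight_experience_replay` in Exercise 4.1 in two assumed choices: it adds k relabeled copies per transition, and it uses a sparse 0/-1 success reward (the setting HER targets) rather than negative distance.

```python
import random


def her_relabel(trajectory, k=4, rng=random):
    """HER 'future' strategy (sketch): keep each transition and add k copies whose goal
    is a state actually achieved later in the same episode, with the reward recomputed.

    trajectory: list of (state, action, next_state, goal) tuples with hashable states.
    Returns a list of (state, action, reward, next_state, goal) tuples.
    """
    def reward(achieved, goal):
        return 0.0 if achieved == goal else -1.0  # sparse: success or not

    out = []
    T = len(trajectory)
    for t, (s, a, s_next, g) in enumerate(trajectory):
        out.append((s, a, reward(s_next, g), s_next, g))  # original goal
        for _ in range(k):
            future = rng.randint(t, T - 1)          # inclusive: some step at or after t
            new_goal = trajectory[future][2]        # a state the agent actually reached
            out.append((s, a, reward(s_next, new_goal), s_next, new_goal))
    return out


# A failed episode: the agent wanted to reach 5 but only got to 3
episode = [(0, +1, 1, 5), (1, +1, 2, 5), (2, +1, 3, 5)]
relabeled = her_relabel(episode, k=2, rng=random.Random(0))
successes = sum(r == 0.0 for _, _, r, _, _ in relabeled)
print(successes, "of", len(relabeled), "transitions now carry a success reward")
```

Every original transition still has reward -1, yet the relabeled copies of the last step always succeed, so the replay buffer contains non-trivial learning signal even though the episode failed.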


In [None]:
# Exercise 4.1: Implement Goal-Conditioned Policy
class GoalConditionedPolicy(nn.Module):
    """
    Goal-conditioned policy that learns to reach any goal state.
    Uses Universal Value Function Approximators (UVFA).
    """
    
    def __init__(self, state_dim, goal_dim, action_dim, hidden_dim=256):
        super().__init__()
        
        self.state_dim = state_dim
        self.goal_dim = goal_dim
        self.action_dim = action_dim
        
        # Policy network: π(a|s,g)
        self.policy = nn.Sequential(
            nn.Linear(state_dim + goal_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim)
        )
        
        # Q-function: Q(s,a,g)
        self.q_function = nn.Sequential(
            nn.Linear(state_dim + action_dim + goal_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
        
    def forward(self, state, goal):
        """
        Forward pass through goal-conditioned policy
        
        Args:
            state: Current state
            goal: Target goal
        
        Returns:
            action_logits: Action probabilities
        """
        x = torch.cat([state, goal], dim=-1)
        return self.policy(x)
    
    def get_q_value(self, state, action, goal):
        """Get Q-value for state-action-goal"""
        x = torch.cat([state, action, goal], dim=-1)
        return self.q_function(x)
    
    def act(self, state, goal, epsilon=0.1):
        """Select action using epsilon-greedy policy"""
        with torch.no_grad():
            action_logits = self.forward(state, goal)
            
            if np.random.random() < epsilon:
                action = np.random.randint(self.action_dim)
            else:
                action_probs = F.softmax(action_logits, dim=-1)
                action = torch.multinomial(action_probs, 1).item()
        
        return action

class GoalConditionedAgent:
    """Agent using goal-conditioned policy with HER"""
    
    def __init__(self, state_dim, goal_dim, action_dim, lr=1e-3):
        self.model = GoalConditionedPolicy(state_dim, goal_dim, action_dim)
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
        
        self.state_dim = state_dim
        self.goal_dim = goal_dim
        self.action_dim = action_dim
        
        # Experience buffer
        self.buffer = []
        self.max_buffer_size = 10000
        
    def act(self, state, goal, epsilon=0.1):
        """Select action"""
        state_tensor = torch.FloatTensor(state).unsqueeze(0)
        goal_tensor = torch.FloatTensor(goal).unsqueeze(0)
        
        return self.model.act(state_tensor, goal_tensor, epsilon)
    
    def store_experience(self, state, action, reward, next_state, goal, done):
        """Store experience in buffer"""
        experience = {
            'state': state,
            'action': action,
            'reward': reward,
            'next_state': next_state,
            'goal': goal,
            'done': done
        }
        
        self.buffer.append(experience)
        if len(self.buffer) > self.max_buffer_size:
            self.buffer.pop(0)
    
    def hindsight_experience_replay(self, trajectory, strategy='future'):
        """
        Augment trajectory with hindsight goals
        
        Args:
            trajectory: List of experiences
            strategy: 'future', 'final', or 'random'
        """
        augmented_experiences = []
        
        for i, exp in enumerate(trajectory):
            # Original experience
            augmented_experiences.append(exp)
            
            # Hindsight: "what if goal was different?"
            if strategy == 'future':
                # Use a state achieved at or after step i as the new goal
                # (randint's upper bound is exclusive, so i == len - 1 always yields i)
                future_idx = np.random.randint(i, len(trajectory))
                new_goal = trajectory[future_idx]['next_state']
            elif strategy == 'final':
                new_goal = trajectory[-1]['next_state']
            elif strategy == 'random':
                # Uniform in [-1, 1]^goal_dim; only sensible if goals are normalized to that range
                new_goal = np.random.uniform(-1, 1, self.goal_dim)
            else:
                raise ValueError(f"Unknown HER strategy: {strategy!r}")
            
            # Recompute reward with new goal
            new_reward = self.compute_reward(exp['next_state'], new_goal)
            
            # Create modified experience
            modified_exp = exp.copy()
            modified_exp['goal'] = new_goal
            modified_exp['reward'] = new_reward
            
            augmented_experiences.append(modified_exp)
        
        return augmented_experiences
    
    def compute_reward(self, state, goal):
        """Compute reward for reaching goal"""
        # Simple reward: negative distance to goal
        distance = np.linalg.norm(np.array(state) - np.array(goal))
        return -distance
    
    def update(self, batch_size=32):
        """Update the model using goal-conditioned learning"""
        if len(self.buffer) < batch_size:
            return {}
        
        # Sample batch
        batch = random.sample(self.buffer, batch_size)
        
        # Convert to tensors
        states = torch.FloatTensor([exp['state'] for exp in batch])
        actions = torch.LongTensor([exp['action'] for exp in batch])
        rewards = torch.FloatTensor([exp['reward'] for exp in batch])
        next_states = torch.FloatTensor([exp['next_state'] for exp in batch])
        goals = torch.FloatTensor([exp['goal'] for exp in batch])
        dones = torch.BoolTensor([exp['done'] for exp in batch])
        
        # Compute Q-values
        q_values = self.model.get_q_value(states, F.one_hot(actions, self.action_dim).float(), goals)
        
        # Compute target Q-values
        with torch.no_grad():
            # Get next actions from policy
            next_action_logits = self.model.forward(next_states, goals)
            next_action_probs = F.softmax(next_action_logits, dim=-1)
            next_actions = torch.multinomial(next_action_probs, 1).squeeze(1)
            
            # Compute next Q-values
            next_q_values = self.model.get_q_value(
                next_states, 
                F.one_hot(next_actions, self.action_dim).float(), 
                goals
            )
            
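            # squeeze(-1) gives next_q_values shape [B] so it lines up with rewards and dones
            # (a [B, 1] tensor times a [B] mask would silently broadcast to [B, B])
            targets = rewards + 0.99 * next_q_values.squeeze(-1) * (~dones).float()
        
        # TD loss (SARSA-style target: the next action is sampled from the current policy)
        q_loss = F.mse_loss(q_values.squeeze(-1), targets)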
        
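        # Policy loss: Q-values weight the log-probabilities (no baseline is subtracted).
        # Detach so this term does not push gradients into the Q-function.
        action_logits = self.model.forward(states, goals)
        log_probs = F.log_softmax(action_logits, dim=-1)
        policy_loss = -(log_probs.gather(1, actions.unsqueeze(1)) * q_values.detach()).mean()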
        
        # Total loss
        total_loss = q_loss + policy_loss
        
        # Optimize
        self.optimizer.zero_grad()
        total_loss.backward()
        self.optimizer.step()
        
        return {
            'q_loss': q_loss.item(),
            'policy_loss': policy_loss.item(),
            'total_loss': total_loss.item()
        }

# Test Goal-Conditioned RL
print("Goal-Conditioned Policy Created!")
print("Testing with GridWorld environment...")

# Create agent
state_dim = 2  # (x, y) coordinates
goal_dim = 2   # (x, y) goal coordinates
action_dim = 4  # up, down, left, right

agent = GoalConditionedAgent(state_dim, goal_dim, action_dim)
env = GridWorld()

# Test agent with different goals
goals = [(5, 5), (9, 9), (0, 9), (9, 0)]

for goal in goals:
    print(f"\nTesting with goal: {goal}")
    state = env.reset()
    
    for step in range(10):
        action = agent.act(state, goal)
        next_state, reward, done, _ = env.step(action)
        
        # Compute goal-conditioned reward
        goal_reward = agent.compute_reward(next_state, goal)
        
        print(f"Step {step+1}: Action {env.action_names[action]} -> State {next_state}, Reward {goal_reward:.2f}")
        
        # Store experience
        agent.store_experience(state, action, goal_reward, next_state, goal, done)
        
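        if np.array_equal(np.asarray(next_state), np.asarray(goal)):
            print("Commanded goal reached!")
            break
        if done:
            print("Episode ended by the environment.")
            break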
        state = next_state

# Update model
losses = agent.update()
print(f"\nTraining losses: {losses}")

# Test HER
print("\nTesting Hindsight Experience Replay...")
trajectory = agent.buffer[-10:]  # Last 10 stored transitions (may span several goals; for illustration only)
augmented = agent.hindsight_experience_replay(trajectory, strategy='future')
print(f"Original experiences: {len(trajectory)}")
print(f"Augmented experiences: {len(augmented)}")


## 5. Skill Discovery

### Diversity is All You Need (DIAYN)

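**Objective:** Learn diverse, distinguishable skills without any extrinsic reward

**Key Components:**
1. **Skill-conditioned policy**: π(a|s,z), where z is a discrete skill ID
2. **Discriminator**: q_φ(z|s) predicts the skill from the states the policy visits
3. **Intrinsic reward**: r_z(s) = log q_φ(z|s) - log p(z), which is high when the visited state makes the skill easy to identify

**Training Process:**
1. Sample a skill z ~ p(z) (uniform over skills)
2. Execute π(a|s,z) for an episode
3. Compute the intrinsic reward for each visited state using the discriminator
4. Update the policy (to maximize the intrinsic reward) and the discriminator (to classify z from s)

The original DIAYN trains the policy with soft actor-critic, whose entropy bonus keeps each skill as random as possible while it stays distinguishable; the implementation below uses a simpler one-step policy-gradient update.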
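
As a quick sanity check on the intrinsic reward formula, the sketch below evaluates r = log q(z|s) - log p(z) with NumPy for a uniform skill prior over K skills. The function name `diayn_reward` and its raw-logit input are illustrative assumptions, not part of the notebook's classes.

```python
import numpy as np


def diayn_reward(disc_logits, skill):
    """DIAYN-style reward log q(z|s) - log p(z), with p(z) uniform over K skills.

    `disc_logits` are the discriminator's raw scores for one state (length K).
    """
    logits = np.asarray(disc_logits, dtype=float)
    num_skills = logits.shape[-1]
    # Numerically stable log-softmax: subtract the max before exponentiating
    shifted = logits - logits.max()
    log_q = shifted - np.log(np.exp(shifted).sum())
    return float(log_q[skill] - np.log(1.0 / num_skills))


print(diayn_reward([0.0, 0.0, 0.0, 0.0], 2))   # uninformative discriminator
print(diayn_reward([10.0, 0.0, 0.0, 0.0], 0))  # confident and correct
print(diayn_reward([10.0, 0.0, 0.0, 0.0], 1))  # confident but wrong skill
```

An uninformative discriminator gives a reward of 0, a confident and correct one approaches log K (about 1.39 for the four skills used below), and a confident but wrong prediction gives a large negative reward, so the reward is bounded above by log K.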


In [None]:
# Exercise 5.1: Implement DIAYN (Diversity is All You Need)
class SkillConditionedPolicy(nn.Module):
    """Policy conditioned on skill ID"""
    
    def __init__(self, state_dim, num_skills, action_dim, hidden_dim=256):
        super().__init__()
        
        self.state_dim = state_dim
        self.num_skills = num_skills
        self.action_dim = action_dim
        
        # Skill embedding
        self.skill_embedding = nn.Embedding(num_skills, hidden_dim)
        
        # Policy network
        self.policy = nn.Sequential(
            nn.Linear(state_dim + hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim)
        )
        
    def forward(self, state, skill):
        """Forward pass through skill-conditioned policy"""
        skill_emb = self.skill_embedding(skill)
        x = torch.cat([state, skill_emb], dim=-1)
        return self.policy(x)
    
    def act(self, state, skill, epsilon=0.1):
        """Select action using epsilon-greedy policy"""
        with torch.no_grad():
            action_logits = self.forward(state, skill)
            
            if np.random.random() < epsilon:
                action = np.random.randint(self.action_dim)
            else:
                action_probs = F.softmax(action_logits, dim=-1)
                action = torch.multinomial(action_probs, 1).item()
        
        return action

class SkillDiscriminator(nn.Module):
    """Discriminator that predicts skill from state"""
    
    def __init__(self, state_dim, num_skills, hidden_dim=256):
        super().__init__()
        
        self.discriminator = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_skills)
        )
    
    def forward(self, state):
        """Predict skill from state"""
        return self.discriminator(state)
    
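    def get_skill_prob(self, state, skill):
        """Get q(skill | state) for each (state, skill) pair in the batch"""
        logits = self.forward(state)
        probs = F.softmax(logits, dim=-1)
        # gather picks, for row i, the probability of skill[i] -> shape [B]
        return probs.gather(1, skill.view(-1, 1)).squeeze(1)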

class DIAYN:
    """DIAYN algorithm for unsupervised skill discovery"""
    
    def __init__(self, state_dim, num_skills, action_dim, lr=1e-3):
        self.state_dim = state_dim
        self.num_skills = num_skills
        self.action_dim = action_dim
        
        # Models
        self.policy = SkillConditionedPolicy(state_dim, num_skills, action_dim)
        self.discriminator = SkillDiscriminator(state_dim, num_skills)
        
        # Optimizers
        self.policy_optimizer = optim.Adam(self.policy.parameters(), lr=lr)
        self.discriminator_optimizer = optim.Adam(self.discriminator.parameters(), lr=lr)
        
        # Experience buffer
        self.buffer = []
        self.max_buffer_size = 10000
        
    def act(self, state, skill, epsilon=0.1):
        """Select action using skill-conditioned policy"""
        state_tensor = torch.FloatTensor(state).unsqueeze(0)
        skill_tensor = torch.LongTensor([skill])
        
        return self.policy.act(state_tensor, skill_tensor, epsilon)
    
    def store_experience(self, state, action, next_state, skill):
        """Store experience in buffer"""
        experience = {
            'state': state,
            'action': action,
            'next_state': next_state,
            'skill': skill
        }
        
        self.buffer.append(experience)
        if len(self.buffer) > self.max_buffer_size:
            self.buffer.pop(0)
    
    def compute_intrinsic_reward(self, state, skill):
        """Compute intrinsic reward for skill discovery"""
        state_tensor = torch.FloatTensor(state).unsqueeze(0)
        skill_tensor = torch.LongTensor([skill])
        
        with torch.no_grad():
            # Get skill probability from discriminator
            skill_prob = self.discriminator.get_skill_prob(state_tensor, skill_tensor)
            
            # Intrinsic reward: log p(skill|state) - log(1/num_skills)
            # This encourages states that are predictive of the skill
            intrinsic_reward = torch.log(skill_prob + 1e-8) - np.log(1.0 / self.num_skills)
            
            return intrinsic_reward.item()
    
    def update_policy(self, batch_size=32):
        """Update skill-conditioned policy"""
        if len(self.buffer) < batch_size:
            return {}
        
        # Sample batch
        batch = random.sample(self.buffer, batch_size)
        
        # Convert to tensors
        states = torch.FloatTensor([exp['state'] for exp in batch])
        actions = torch.LongTensor([exp['action'] for exp in batch])
        skills = torch.LongTensor([exp['skill'] for exp in batch])
        
        # Intrinsic rewards r = log q(z|s') - log p(z), computed for the whole batch at once
        next_states = torch.FloatTensor([exp['next_state'] for exp in batch])
        with torch.no_grad():
            log_q = F.log_softmax(self.discriminator(next_states), dim=-1)
            intrinsic_rewards = log_q.gather(1, skills.unsqueeze(1)).squeeze(1) - np.log(1.0 / self.num_skills)
        
        # Policy loss
        action_logits = self.policy.forward(states, skills)
        log_probs = F.log_softmax(action_logits, dim=-1)
        policy_loss = -(log_probs.gather(1, actions.unsqueeze(1)) * intrinsic_rewards.unsqueeze(1)).mean()
        
        # Optimize policy
        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()
        
        return {'policy_loss': policy_loss.item()}
    
    def update_discriminator(self, batch_size=32):
        """Update skill discriminator"""
        if len(self.buffer) < batch_size:
            return {}
        
        # Sample batch
        batch = random.sample(self.buffer, batch_size)
        
        # Convert to tensors
        states = torch.FloatTensor([exp['next_state'] for exp in batch])  # Use next_state
        skills = torch.LongTensor([exp['skill'] for exp in batch])
        
        # Discriminator loss
        logits = self.discriminator.forward(states)
        discriminator_loss = F.cross_entropy(logits, skills)
        
        # Optimize discriminator
        self.discriminator_optimizer.zero_grad()
        discriminator_loss.backward()
        self.discriminator_optimizer.step()
        
        return {'discriminator_loss': discriminator_loss.item()}
    
    def train_episode(self, env, max_steps=50):
        """Train for one episode"""
        # Sample random skill
        skill = np.random.randint(self.num_skills)
        
        state = env.reset()
        episode_reward = 0
        
        for step in range(max_steps):
            # Select action
            action = self.act(state, skill)
            
            # Execute action
            next_state, reward, done, _ = env.step(action)
            
            # Store experience
            self.store_experience(state, action, next_state, skill)
            
            # Compute intrinsic reward
            intrinsic_reward = self.compute_intrinsic_reward(next_state, skill)
            episode_reward += intrinsic_reward
            
            if done:
                break
            
            state = next_state
        
        return episode_reward, skill

# Test DIAYN
print("DIAYN Algorithm Created!")
print("Testing skill discovery...")

# Create DIAYN agent
state_dim = 2  # (x, y) coordinates
num_skills = 4  # Number of skills to discover
action_dim = 4  # up, down, left, right

diayn = DIAYN(state_dim, num_skills, action_dim)
env = GridWorld()

# Train for several episodes
print("Training DIAYN for skill discovery...")

for episode in range(10):
    episode_reward, skill = diayn.train_episode(env)
    
    # Update models
    policy_losses = diayn.update_policy()
    discriminator_losses = diayn.update_discriminator()
    
    print(f"Episode {episode+1}: Skill {skill}, Reward {episode_reward:.2f}")
    print(f"  Policy Loss: {policy_losses.get('policy_loss', 0):.4f}")
    print(f"  Discriminator Loss: {discriminator_losses.get('discriminator_loss', 0):.4f}")

# Test discovered skills
print("\nTesting discovered skills...")
for skill in range(num_skills):
    print(f"\nSkill {skill}:")
    state = env.reset()
    
    for step in range(10):
        action = diayn.act(state, skill, epsilon=0.0)  # No uniform-random actions; actions are still sampled from the policy
        next_state, reward, done, _ = env.step(action)
        
        intrinsic_reward = diayn.compute_intrinsic_reward(next_state, skill)
        print(f"  Step {step+1}: Action {env.action_names[action]} -> State {next_state}, Intrinsic Reward {intrinsic_reward:.2f}")
        
        if done:
            print("  Goal reached!")
            break
        state = next_state
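
DIAYN's objective includes maximizing the mutual information I(S; Z) between states and skills. For a small discrete environment like GridWorld, one rough way to check whether the discovered skills are distinguishable is a plug-in estimate of I(S; Z) from a skill-by-state visitation count matrix. The sketch below assumes such a count matrix has already been collected (for example by rolling out each skill with `diayn.act` and mapping each (x, y) state to a column index, a hypothetical step not shown here); the function name is illustrative.

```python
import numpy as np


def skill_state_mutual_information(counts):
    """Plug-in estimate of I(S; Z) in nats from a (num_skills, num_states) count matrix."""
    joint = np.asarray(counts, dtype=float)
    joint = joint / joint.sum()                 # empirical p(z, s)
    p_z = joint.sum(axis=1, keepdims=True)      # p(z), shape (K, 1)
    p_s = joint.sum(axis=0, keepdims=True)      # p(s), shape (1, S)
    independent = p_z @ p_s                     # p(z) p(s), shape (K, S)
    mask = joint > 0                            # 0 * log 0 is taken as 0
    return float(np.sum(joint[mask] * np.log(joint[mask] / independent[mask])))


# Each of 4 skills visits only its own state: maximally distinguishable
print(skill_state_mutual_information(np.eye(4) * 10))
# All skills visit all states equally often: indistinguishable
print(skill_state_mutual_information(np.ones((4, 4))))
```

The estimate is 0 when every skill visits states with the same frequencies and at most log(num_skills) when each state is visited by only one skill. Plug-in estimates from a few short rollouts are biased upward, so compare values collected with the same number of steps.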
