# In [None]:
# ============================================================================
# THE AUTONOMOUS COLONY - PART 4: ADVANCED RL CONCEPTS
# Meta-Learning, World Models, Hierarchical RL, and Curiosity
# ============================================================================

"""
# 🧠 The Autonomous Colony - Advanced RL

**RL Concepts Covered:**
1. Meta-RL / Learning to Learn (MAML-style)
2. World Models (simplified DreamerV3)
3. Hierarchical RL (Options Framework)
4. Intrinsic Motivation (Curiosity/ICM)
5. Offline RL / Imitation Learning
6. Curriculum Learning
7. Model-based planning

**Prerequisites:**
- Parts 1-3 (environment, agents, multi-agent)
"""

# ============================================================================
# SETUP
# ============================================================================

# !pip install torch numpy matplotlib -q

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from typing import Dict, List, Tuple, Optional
from collections import deque
import copy

# Confirms in the notebook output that the imports above succeeded.
print("✓ Advanced RL modules loaded")

# ============================================================================
# 1. WORLD MODEL (Model-Based RL)
# ============================================================================

class WorldModel(nn.Module):
    """
    Simplified world model for planning.
    Learns to predict next state, reward, and a termination probability
    given the current state and a one-hot encoded action.

    RL Concepts:
    - Model-based RL: learn dynamics model
    - Planning: use model to simulate trajectories
    - Dyna-Q style: combine model learning with policy learning

    Args:
        state_dim: Size of the flat state vector.
        action_dim: Number of discrete actions (width of the one-hot encoding).
        hidden_dim: Width of the hidden layers.
    """

    def __init__(self, state_dim: int, action_dim: int, hidden_dim: int = 128):
        super().__init__()
        # Remembered so imagine_trajectory one-hot encodes actions with the same
        # width the networks below were built for.
        self.action_dim = action_dim

        # Transition model: predicts next state
        self.transition_net = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, state_dim)
        )

        # Reward model: predicts immediate reward
        self.reward_net = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )

        # Done prediction (probability via sigmoid)
        self.done_net = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )

    def forward(self, state, action_onehot):
        """
        Predict next state, reward, and done probability.

        Args:
            state: State tensor whose last dimension is state_dim.
            action_onehot: One-hot action tensor whose last dimension is
                action_dim; concatenated with ``state`` on the last axis.

        Returns:
            (next_state, reward, done_prob) from the three heads.
        """
        x = torch.cat([state, action_onehot], dim=-1)

        next_state = self.transition_net(x)
        reward = self.reward_net(x)
        done_prob = self.done_net(x)

        return next_state, reward, done_prob

    def imagine_trajectory(self, initial_state, policy_fn, horizon: int = 10):
        """
        Imagine a trajectory by rolling the learned model forward.
        Used for planning or auxiliary training.

        Args:
            initial_state: State tensor to start from; presumably shaped
                (batch, state_dim) -- TODO confirm against callers.
            policy_fn: Callable mapping a state tensor to a LongTensor of
                action indices in [0, action_dim).
            horizon: Maximum number of imagined steps.

        Returns:
            (states, actions, rewards) lists. ``states`` holds the initial
            state plus one predicted state per step, so it is one longer than
            ``actions`` and ``rewards``. Gradients are not disabled, so the
            rollout remains differentiable.
        """
        states = [initial_state]
        actions = []
        rewards = []

        state = initial_state
        for _ in range(horizon):
            # Get action from policy
            action = policy_fn(state)
            actions.append(action)

            # Predict next state
            action_onehot = F.one_hot(action, num_classes=self.action_dim).float()
            next_state, reward, done_prob = self.forward(state, action_onehot)

            states.append(next_state)
            rewards.append(reward)

            state = next_state

            # Stop once every rollout in the batch is likely terminal; `.all()`
            # keeps this check valid for batches larger than one.
            if bool((done_prob > 0.5).all()):
                break

        return states, actions, rewards

class WorldModelAgent:
    """
    Agent that learns a WorldModel from replayed transitions.

    Callers append transitions to ``replay_buffer`` as
    ``(state, action, reward, next_state, done)`` tuples; the buffer keeps
    the most recent 10,000.
    """

    def __init__(self, state_dim: int, action_dim: int, learning_rate: float = 1e-3):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Remembered so actions are one-hot encoded with the width the model
        # was built for, instead of a hard-coded constant.
        self.action_dim = action_dim
        self.world_model = WorldModel(state_dim, action_dim).to(self.device)
        self.optimizer = optim.Adam(self.world_model.parameters(), lr=learning_rate)

        self.replay_buffer = deque(maxlen=10000)
        print("✓ World Model Agent initialized")

    def train_world_model(self, batch_size: int = 64):
        """
        Run one gradient step on a random minibatch from the replay buffer.

        Args:
            batch_size: Number of transitions to sample (with replacement).

        Returns:
            The summed state/reward/done loss as a float, or None when the
            buffer holds fewer than ``batch_size`` transitions.
        """
        if len(self.replay_buffer) < batch_size:
            return None

        # Sample batch (with replacement)
        indices = np.random.choice(len(self.replay_buffer), batch_size)
        batch = [self.replay_buffer[i] for i in indices]
        states, actions, rewards, next_states, dones = zip(*batch)

        # Stack into contiguous arrays first: building a tensor directly from a
        # tuple of ndarrays is very slow in PyTorch.
        states_t = torch.FloatTensor(np.asarray(states, dtype=np.float32)).to(self.device)
        actions_t = torch.LongTensor(np.asarray(actions, dtype=np.int64)).to(self.device)
        rewards_t = torch.FloatTensor(np.asarray(rewards, dtype=np.float32)).unsqueeze(1).to(self.device)
        next_states_t = torch.FloatTensor(np.asarray(next_states, dtype=np.float32)).to(self.device)
        dones_t = torch.FloatTensor(np.asarray(dones, dtype=np.float32)).unsqueeze(1).to(self.device)

        # One-hot encode actions
        actions_onehot = F.one_hot(actions_t, num_classes=self.action_dim).float()

        # Predict
        pred_next_state, pred_reward, pred_done = self.world_model(states_t, actions_onehot)

        # Losses: regression for state/reward, BCE for the sigmoid done head
        state_loss = F.mse_loss(pred_next_state, next_states_t)
        reward_loss = F.mse_loss(pred_reward, rewards_t)
        done_loss = F.binary_cross_entropy(pred_done, dones_t)

        total_loss = state_loss + reward_loss + done_loss

        # Optimize
        self.optimizer.zero_grad()
        total_loss.backward()
        self.optimizer.step()

        return total_loss.item()

# ============================================================================
# 2. CURIOSITY MODULE (Intrinsic Motivation)
# ============================================================================

class CuriosityModule(nn.Module):
    """
    Intrinsic Curiosity Module (ICM) for exploration.

    RL Concept: Intrinsic motivation
    - Reward agents for exploring novel states
    - Prediction error as curiosity signal
    - Helps with sparse reward environments

    Args:
        state_dim: Size of the flat state vector.
        action_dim: Number of discrete actions; width of the one-hot action
            fed to the forward model and of the inverse model's logits.
        hidden_dim: Width of the hidden layers (the feature size is fixed at 64).
    """

    def __init__(self, state_dim: int, action_dim: int, hidden_dim: int = 128):
        super().__init__()
        # Remembered so forward() one-hot encodes with the width the forward
        # model was built for, instead of a hard-coded constant.
        self.action_dim = action_dim

        # Feature encoder (for state representation)
        self.feature_encoder = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 64)
        )

        # Forward model: predict next state features from current + action
        self.forward_model = nn.Sequential(
            nn.Linear(64 + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 64)
        )

        # Inverse model: predict action from state features
        self.inverse_model = nn.Sequential(
            nn.Linear(64 + 64, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim)
        )

    def forward(self, state, next_state, action):
        """
        Compute the intrinsic reward from forward-model prediction error.

        Args:
            state, next_state: State tensors; last dimension is state_dim.
            action: LongTensor of action indices in [0, action_dim).

        Returns:
            intrinsic_reward: Mean squared feature-prediction error over the
                last axis; shape (batch,) for 2-D (batch, state_dim) inputs.
            predicted_action: Inverse-model logits over actions.
        """
        # Encode states
        state_features = self.feature_encoder(state)
        next_state_features = self.feature_encoder(next_state)

        # Forward model prediction
        action_onehot = F.one_hot(action, num_classes=self.action_dim).float()
        predicted_next_features = self.forward_model(
            torch.cat([state_features, action_onehot], dim=-1)
        )

        # Intrinsic reward = prediction error
        intrinsic_reward = F.mse_loss(
            predicted_next_features, next_state_features, reduction='none'
        ).mean(dim=-1)

        # Inverse model prediction (for training)
        predicted_action = self.inverse_model(
            torch.cat([state_features, next_state_features], dim=-1)
        )

        return intrinsic_reward, predicted_action

class CuriosityAgent:
    """
    Agent with curiosity-driven exploration.

    Holds ``base_agent`` (stored, not called here) alongside an ICM whose
    prediction error serves as an intrinsic reward. ``curiosity_weight``
    scales that reward when it is combined with the extrinsic reward via
    ``augment_reward``.
    """

    def __init__(self, state_dim: int, action_dim: int,
                 base_agent, curiosity_weight: float = 0.5):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.curiosity = CuriosityModule(state_dim, action_dim).to(self.device)
        self.curiosity_optimizer = optim.Adam(self.curiosity.parameters(), lr=1e-4)

        self.base_agent = base_agent
        self.curiosity_weight = curiosity_weight

        print(f"✓ Curiosity Agent initialized (weight={curiosity_weight})")

    def _to_tensors(self, state, next_state, action):
        """Convert one transition into batch-of-one tensors on self.device."""
        state_t = torch.FloatTensor(state).unsqueeze(0).to(self.device)
        next_state_t = torch.FloatTensor(next_state).unsqueeze(0).to(self.device)
        action_t = torch.LongTensor([action]).to(self.device)
        return state_t, next_state_t, action_t

    def compute_intrinsic_reward(self, state, next_state, action):
        """
        Return the raw curiosity reward for a single transition.

        The value is NOT scaled by ``curiosity_weight`` and NOT added to any
        extrinsic reward; use ``augment_reward`` for the combined signal.
        """
        state_t, next_state_t, action_t = self._to_tensors(state, next_state, action)

        with torch.no_grad():
            intrinsic_reward, _ = self.curiosity(state_t, next_state_t, action_t)

        return intrinsic_reward.item()

    def augment_reward(self, extrinsic_reward, state, next_state, action):
        """Return ``extrinsic_reward + curiosity_weight * intrinsic_reward``."""
        bonus = self.compute_intrinsic_reward(state, next_state, action)
        return extrinsic_reward + self.curiosity_weight * bonus

    def train_curiosity(self, state, next_state, action):
        """
        Take one optimizer step on the ICM for a single transition.

        Returns:
            The combined forward + inverse loss as a float.
        """
        state_t, next_state_t, action_t = self._to_tensors(state, next_state, action)

        intrinsic_reward, predicted_action = self.curiosity(state_t, next_state_t, action_t)

        # Loss: forward model (prediction error) + inverse model (action classification)
        forward_loss = intrinsic_reward.mean()
        inverse_loss = F.cross_entropy(predicted_action, action_t)

        loss = forward_loss + inverse_loss

        self.curiosity_optimizer.zero_grad()
        loss.backward()
        self.curiosity_optimizer.step()

        return loss.item()

# ============================================================================
# 3. HIERARCHICAL RL (Options Framework)
# ============================================================================

class Option(nn.Module):
    """
    One option (a temporally extended sub-policy) for hierarchical RL.

    RL Concept: Temporal abstraction
    - Options = skills that execute over multiple timesteps
    - Learn when to initiate and terminate options
    - Enables hierarchical decision making

    Args:
        state_dim: Size of the flat state vector.
        action_dim: Number of discrete actions the option can emit.
        hidden_dim: Width of the single hidden layer in each head.
    """

    def __init__(self, state_dim: int, action_dim: int, hidden_dim: int = 64):
        super().__init__()

        # Intra-option policy head: unnormalized action logits.
        policy_layers = [
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
        ]
        self.policy = nn.Sequential(*policy_layers)

        # Termination head: sigmoid probability that the option ends here.
        termination_layers = [
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        ]
        self.termination = nn.Sequential(*termination_layers)

    def forward(self, state):
        """Return ``(action_logits, termination_prob)`` for ``state``."""
        return self.policy(state), self.termination(state)

class HierarchicalAgent:
    """
    Agent with hierarchical policy (meta-controller + options).

    Two-level hierarchy:
    - Meta-controller: selects which option to use
    - Options: execute low-level actions

    Args:
        state_dim: Size of the flat state vector.
        action_dim: Number of discrete low-level actions.
        n_options: Number of options the meta-controller chooses between.
        max_option_duration: The active option is dropped once it has run for
            more than this many steps (default 20, i.e. at most 21 steps).
        termination_threshold: The active option is dropped when its
            termination probability exceeds this value.
    """

    def __init__(self, state_dim: int, action_dim: int, n_options: int = 4,
                 max_option_duration: int = 20, termination_threshold: float = 0.5):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.n_options = n_options
        self.max_option_duration = max_option_duration
        self.termination_threshold = termination_threshold

        # Options (sub-policies)
        self.options = nn.ModuleList([
            Option(state_dim, action_dim) for _ in range(n_options)
        ]).to(self.device)

        # Meta-controller: selects option
        self.meta_controller = nn.Sequential(
            nn.Linear(state_dim, 128),
            nn.ReLU(),
            nn.Linear(128, n_options)
        ).to(self.device)

        self.optimizer = optim.Adam(
            list(self.options.parameters()) + list(self.meta_controller.parameters()),
            lr=1e-4
        )

        # Currently active option index (None = meta-controller must pick one)
        self.current_option = None
        self.option_duration = 0

        print(f"✓ Hierarchical Agent initialized ({n_options} options)")

    def select_action(self, state, force_new_option=False):
        """
        Select an action using the hierarchical policy.

        Args:
            state: Flat state vector for one agent.
            force_new_option: If True, the meta-controller re-selects an
                option even if one is still active.

        Returns:
            (action, option): the sampled action index and the index of the
            option that produced it. The option is captured before the
            termination check, so it is never None; after a terminating step
            ``self.current_option`` is None.
        """
        state_t = torch.FloatTensor(state).unsqueeze(0).to(self.device)

        # Check if we need to select a new option
        if self.current_option is None or force_new_option:
            with torch.no_grad():
                option_logits = self.meta_controller(state_t)
                option_probs = torch.softmax(option_logits, dim=-1)
                self.current_option = torch.multinomial(option_probs, 1).item()
            self.option_duration = 0

        active_option = self.current_option

        # Execute current option
        with torch.no_grad():
            action_logits, termination_prob = self.options[active_option](state_t)
            action_probs = torch.softmax(action_logits, dim=-1)
            action = torch.multinomial(action_probs, 1).item()

        self.option_duration += 1

        # Check termination (learned probability or hard duration cap)
        if (termination_prob.item() > self.termination_threshold
                or self.option_duration > self.max_option_duration):
            self.current_option = None

        return action, active_option

# ============================================================================
# 4. META-RL (Learning to Learn)
# ============================================================================

class MetaLearner:
    """
    Meta-RL using MAML-style adaptation (first-order approximation).

    RL Concept: Meta-learning
    - Learn initialization that adapts quickly to new tasks
    - Few-shot learning in RL
    - Test on distribution of environments (seasons)

    The inner loop adapts a deep copy of ``base_network`` with plain SGD.
    That copy is disconnected from the base parameters in the autograd
    graph, so the meta-gradient is computed to first order (FOMAML): the
    gradient of the post-adaptation loss with respect to the adapted weights
    is accumulated onto the corresponding base weights.

    Subclasses provide the real objective by overriding ``_compute_task_loss``.
    """

    def __init__(self, base_network, inner_lr: float = 0.01, meta_lr: float = 1e-3):
        self.base_network = base_network
        self.inner_lr = inner_lr
        self.meta_optimizer = optim.Adam(base_network.parameters(), lr=meta_lr)

        print(f"✓ Meta-Learner initialized (inner_lr={inner_lr}, meta_lr={meta_lr})")

    def inner_loop_update(self, network, task_data, n_steps: int = 5):
        """
        Inner loop: adapt a copy of ``network`` to one task.

        Args:
            network: Module to copy; it is not modified.
            task_data: ``(states, actions, rewards, next_states, dones)``.
            n_steps: Number of SGD steps on the task loss.

        Returns:
            The adapted deep copy.
        """
        adapted_network = copy.deepcopy(network)
        inner_optimizer = optim.SGD(adapted_network.parameters(), lr=self.inner_lr)
        # The same batch is reused for every adaptation step.
        states, actions, rewards, _next_states, _dones = task_data

        for _ in range(n_steps):
            loss = self._compute_task_loss(adapted_network, states, actions, rewards)

            inner_optimizer.zero_grad()
            loss.backward()
            inner_optimizer.step()

        return adapted_network

    def meta_update(self, task_batch):
        """
        Outer loop: update the base (meta) parameters across tasks.

        Args:
            task_batch: Sequence of task_data tuples (see inner_loop_update).

        Returns:
            Mean post-adaptation task loss as a float.

        Raises:
            ValueError: if ``task_batch`` is empty.
        """
        n_tasks = len(task_batch)
        if n_tasks == 0:
            raise ValueError("meta_update requires at least one task")

        base_params = list(self.base_network.parameters())
        self.meta_optimizer.zero_grad()
        total_loss = 0.0

        for task_data in task_batch:
            # Inner loop adaptation
            adapted_network = self.inner_loop_update(self.base_network, task_data)

            # Evaluate on task. NOTE: this reuses the adaptation data; full MAML
            # would evaluate on a held-out query set.
            states, actions, rewards, _next_states, _dones = task_data
            task_loss = self._compute_task_loss(adapted_network, states, actions, rewards)
            total_loss += task_loss.item()

            # A loss without a grad_fn was not computed from the adapted
            # weights, so it contributes no meta-gradient.
            if task_loss.grad_fn is None:
                continue

            # deepcopy preserves parameter order, so zip pairs base/adapted weights.
            pairs = [
                (base_p, adapted_p)
                for base_p, adapted_p in zip(base_params, adapted_network.parameters())
                if adapted_p.requires_grad
            ]
            if not pairs:
                continue
            grads = torch.autograd.grad(
                task_loss, [adapted_p for _, adapted_p in pairs], allow_unused=True
            )

            # First-order MAML: apply adapted-weight gradients to base weights.
            for (base_p, _), grad in zip(pairs, grads):
                if grad is None:
                    continue
                scaled = grad.detach() / n_tasks
                if base_p.grad is None:
                    base_p.grad = scaled
                else:
                    base_p.grad.add_(scaled)

        # Meta-optimization
        self.meta_optimizer.step()

        return total_loss / n_tasks

    def _compute_task_loss(self, network, states, actions, rewards):
        """
        Compute the loss for a task (placeholder; override in subclasses).

        The default returns a constant that does not depend on ``network``,
        so it produces no parameter updates.
        """
        # This would be your actual RL loss (policy gradient, Q-learning, etc.)
        # Simplified for demonstration
        return torch.tensor(0.0, requires_grad=True)

# ============================================================================
# 5. CURRICULUM LEARNING
# ============================================================================

class CurriculumManager:
    """
    Manages curriculum for progressive training.

    RL Concept: Curriculum learning
    - Start with easy tasks, gradually increase difficulty
    - Automatic difficulty adjustment based on performance
    - Accelerates learning on complex tasks

    After at least ``window`` episodes, difficulty rises by ``step`` (capped
    at 1.0) when more than 70% of the last ``window`` episodes succeeded,
    and falls by ``step`` (floored at 0.1) when fewer than 30% did.

    Args:
        initial_difficulty: Starting difficulty.
        window: Number of most recent episodes used to judge performance.
        step: Amount added to / subtracted from difficulty per adjustment.
        reset_on_change: If True, clear the performance history after each
            adjustment so the next decision uses only episodes played at the
            new difficulty. Otherwise the same stale window keeps triggering
            an adjustment on every subsequent episode.
    """

    def __init__(self, initial_difficulty: float = 0.1, window: int = 20,
                 step: float = 0.05, reset_on_change: bool = True):
        self.difficulty = initial_difficulty
        self.window = window
        self.step = step
        self.reset_on_change = reset_on_change
        # Keep the original capacity of 50, but never less than the window.
        self.performance_history = deque(maxlen=max(50, window))

        print(f"✓ Curriculum Manager initialized (difficulty={initial_difficulty})")

    def get_environment_config(self):
        """Generate environment config based on current difficulty."""
        config = {
            'grid_size': int(10 + self.difficulty * 30),  # 10 to 40
            'n_agents': int(2 + self.difficulty * 6),      # 2 to 8
            'food_spawn_rate': 0.03 * (1 - self.difficulty * 0.5),  # Decrease spawn rate
            'obstacle_density': 0.05 + self.difficulty * 0.15,       # More obstacles
            'max_steps': int(200 + self.difficulty * 300)            # Longer episodes
        }
        return config

    def update_difficulty(self, episode_reward: float, success: bool):
        """
        Record one episode outcome and adjust difficulty if warranted.

        Args:
            episode_reward: Episode return (kept in the history only).
            success: Whether the episode counts as a success.
        """
        self.performance_history.append((episode_reward, success))

        if len(self.performance_history) < self.window:
            return

        # Success rate over the most recent window
        recent = list(self.performance_history)[-self.window:]
        recent_success_rate = sum(s for _, s in recent) / self.window

        # Increase difficulty if performing well
        if recent_success_rate > 0.7 and self.difficulty < 1.0:
            self.difficulty = min(1.0, self.difficulty + self.step)
            print(f"📈 Difficulty increased to {self.difficulty:.2f}")

        # Decrease if struggling
        elif recent_success_rate < 0.3 and self.difficulty > 0.1:
            self.difficulty = max(0.1, self.difficulty - self.step)
            print(f"📉 Difficulty decreased to {self.difficulty:.2f}")

        else:
            return

        if self.reset_on_change:
            self.performance_history.clear()

    def get_difficulty(self):
        """Return the current difficulty."""
        return self.difficulty

# ============================================================================
# 6. OFFLINE RL / IMITATION LEARNING
# ============================================================================

class OfflineRLAgent:
    """
    Learn a policy from a fixed dataset (offline RL).

    RL Concepts:
    - Batch RL: learn without environment interaction
    - Behavioral cloning: imitate expert demonstrations (implemented here)

    Conservative Q-Learning (penalizing out-of-distribution actions) is a
    natural extension but is NOT implemented by this class.
    """

    def __init__(self, state_dim: int, action_dim: int):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Policy network: state -> action logits
        self.policy = nn.Sequential(
            nn.Linear(state_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, action_dim)
        ).to(self.device)

        self.optimizer = optim.Adam(self.policy.parameters(), lr=1e-3)

        print("✓ Offline RL Agent initialized")

    def behavioral_cloning(self, expert_states, expert_actions, n_epochs: int = 100):
        """
        Learn to imitate an expert policy with full-batch cross-entropy.

        Args:
            expert_states: Array-like of state vectors, one per demonstration.
            expert_actions: Array-like of integer action indices, same length.
            n_epochs: Number of full-batch gradient steps.

        Returns:
            List of per-epoch training losses.

        Raises:
            ValueError: if there are no demonstrations, or the states and
                actions differ in length.
        """
        states_np = np.asarray(expert_states, dtype=np.float32)
        actions_np = np.asarray(expert_actions, dtype=np.int64)
        if len(states_np) == 0:
            raise ValueError("behavioral_cloning needs at least one demonstration")
        if len(states_np) != len(actions_np):
            raise ValueError(
                f"got {len(states_np)} states but {len(actions_np)} actions"
            )

        states_t = torch.FloatTensor(states_np).to(self.device)
        actions_t = torch.LongTensor(actions_np).to(self.device)

        losses = []

        print(f"Training behavioral cloning on {len(expert_states)} demonstrations...")

        for epoch in range(n_epochs):
            # Forward pass
            logits = self.policy(states_t)

            # Cross-entropy loss against the expert's actions
            loss = F.cross_entropy(logits, actions_t)

            # Optimize
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            losses.append(loss.item())

            if (epoch + 1) % 20 == 0:
                print(f"  Epoch {epoch+1}/{n_epochs}, Loss: {loss.item():.4f}")

        return losses

    def select_action(self, state, deterministic: bool = False):
        """
        Select an action using the learned policy.

        Args:
            state: A single state vector.
            deterministic: If True, return the most likely action instead of
                sampling from the policy's softmax distribution.

        Returns:
            The action index as a Python int.
        """
        state_t = torch.FloatTensor(state).unsqueeze(0).to(self.device)
        with torch.no_grad():
            logits = self.policy(state_t)
            if deterministic:
                return int(torch.argmax(logits, dim=-1).item())
            probs = torch.softmax(logits, dim=-1)
            action = torch.multinomial(probs, 1).item()
        return action

# ============================================================================
# 7. INTEGRATED TRAINING UTILITIES
# ============================================================================

def train_with_curriculum(env_class, agent, curriculum_manager, n_iterations: int = 100,
                          train_episode_fn=None):
    """
    Drive a curriculum loop, adapting difficulty after each episode.

    Args:
        env_class: Environment factory, forwarded to ``train_episode_fn``.
        agent: Agent being trained, forwarded to ``train_episode_fn``.
        curriculum_manager: Object providing ``get_environment_config()``,
            ``update_difficulty(reward, success)`` and ``get_difficulty()``.
        n_iterations: Number of curriculum iterations (one episode each).
        train_episode_fn: Optional callable
            ``(env_class, agent, config_dict) -> (episode_reward, success)``
            that builds an environment from the current config and trains on
            one episode. When None, a synthetic placeholder reward
            (N(50, 10^2), success if > 40) is used and nothing is trained.

    Returns:
        (rewards_history, difficulty_history), one entry per iteration.
    """
    rewards_history = []
    difficulty_history = []

    print("\n🎓 Training with Curriculum Learning...")

    for iteration in range(n_iterations):
        # Get current difficulty config
        config_dict = curriculum_manager.get_environment_config()

        if train_episode_fn is None:
            # Placeholder: synthetic reward so the curriculum loop can be demoed
            episode_reward = np.random.randn() * 10 + 50
            success = episode_reward > 40
        else:
            episode_reward, success = train_episode_fn(env_class, agent, config_dict)

        # Update curriculum
        curriculum_manager.update_difficulty(episode_reward, success)

        rewards_history.append(episode_reward)
        difficulty_history.append(curriculum_manager.get_difficulty())

        if (iteration + 1) % 10 == 0:
            print(f"Iteration {iteration+1}: Reward={episode_reward:.1f}, Difficulty={curriculum_manager.get_difficulty():.2f}")

    return rewards_history, difficulty_history

def collect_expert_demonstrations(env, expert_agent, n_episodes: int = 50):
    """
    Collect demonstrations from an expert for offline learning.

    Only agent 0 is recorded: the expert acts on ``obs[0]``, the stored state
    is ``obs[0]['state']``, and every agent in the environment is given that
    same action. An episode ends when agent 0 is done or truncated.

    Args:
        env: Environment exposing ``reset()``, ``step(actions)`` and
            ``config.n_agents``.
        expert_agent: Object with ``select_action(observation)``.
        n_episodes: Number of episodes to roll out.

    Returns:
        (states, actions) as numpy arrays, one row per recorded step.
    """
    states = []
    actions = []

    print(f"Collecting {n_episodes} expert demonstrations...")

    for _ in range(n_episodes):
        obs, _ = env.reset()
        done = False

        while not done:
            # Expert action (assumes single agent)
            action = expert_agent.select_action(obs[0])
            states.append(obs[0]['state'])  # Simplified: just use state vector
            actions.append(action)

            # Step: all agents replay agent 0's action
            next_obs, rewards, dones, truncated, _ = env.step([action] * env.config.n_agents)
            obs = next_obs
            done = dones[0] or truncated[0]

    print(f"✓ Collected {len(states)} state-action pairs")
    return np.array(states), np.array(actions)

# ============================================================================
# VISUALIZATION
# ============================================================================

def plot_advanced_metrics(results_dict):
    """
    Plot results from advanced RL experiments on a 2x2 grid.

    Recognized keys, each optional: 'world_model_loss', 'intrinsic_rewards'
    and 'curriculum_difficulty' (line plots) and 'option_usage' (per-option
    counts, drawn as bars). Panels whose key is absent are hidden rather than
    shown as empty axes.
    """
    import matplotlib.pyplot as plt

    fig, axes = plt.subplots(2, 2, figsize=(16, 12))

    # World model loss
    ax = axes[0, 0]
    if 'world_model_loss' in results_dict:
        ax.plot(results_dict['world_model_loss'])
        ax.set_title('World Model Training Loss')
        ax.set_xlabel('Update Step')
        ax.set_ylabel('Loss')
        ax.grid(True, alpha=0.3)
    else:
        ax.set_axis_off()

    # Intrinsic rewards
    ax = axes[0, 1]
    if 'intrinsic_rewards' in results_dict:
        ax.plot(results_dict['intrinsic_rewards'], alpha=0.5)
        ax.set_title('Curiosity: Intrinsic Rewards')
        ax.set_xlabel('Step')
        ax.set_ylabel('Intrinsic Reward')
        ax.grid(True, alpha=0.3)
    else:
        ax.set_axis_off()

    # Curriculum difficulty
    ax = axes[1, 0]
    if 'curriculum_difficulty' in results_dict:
        ax.plot(results_dict['curriculum_difficulty'], linewidth=2)
        ax.set_title('Curriculum Learning Progress')
        ax.set_xlabel('Iteration')
        ax.set_ylabel('Difficulty')
        ax.set_ylim([0, 1.1])
        ax.grid(True, alpha=0.3)
    else:
        ax.set_axis_off()

    # Option usage (hierarchical)
    ax = axes[1, 1]
    if 'option_usage' in results_dict:
        option_counts = results_dict['option_usage']
        ax.bar(range(len(option_counts)), option_counts)
        ax.set_title('Hierarchical RL: Option Usage')
        ax.set_xlabel('Option ID')
        ax.set_ylabel('Frequency')
        ax.grid(True, alpha=0.3, axis='y')
    else:
        ax.set_axis_off()

    plt.suptitle('Advanced RL Metrics', fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.show()

# ============================================================================
# DEMO & EXPERIMENTS
# ============================================================================

# Summary banner printed when the cell runs.
print("\n" + "="*80)
print("ADVANCED RL MODULES READY!")
print("="*80)
print("\n🧠 Implemented Concepts:")
print("  ✓ World Models (model-based planning)")
print("  ✓ Curiosity-driven exploration (ICM)")
print("  ✓ Hierarchical RL (Options framework)")
print("  ✓ Meta-learning (MAML-style)")
print("  ✓ Curriculum learning")
print("  ✓ Offline RL / Behavioral cloning")
print("\n🔬 Experiment Ideas:")
print("  1. Compare model-free vs model-based agents")
print("  2. Measure exploration with/without curiosity")
print("  3. Emergent skills in hierarchical agents")
print("  4. Fast adaptation with meta-learning")
print("  5. Learning curves with curriculum")
print("="*80)

# Example usage:
"""
# 1. World Model
state_dim = 5
action_dim = 9
wm_agent = WorldModelAgent(state_dim, action_dim)

# 2. Curiosity
from part2_agents import PPOAgent  # Your base agent
base_ppo = PPOAgent(...)
curious_agent = CuriosityAgent(state_dim, action_dim, base_ppo)

# 3. Hierarchical RL
h_agent = HierarchicalAgent(state_dim, action_dim, n_options=4)

# 4. Curriculum
curriculum = CurriculumManager(initial_difficulty=0.1)
# rewards, difficulties = train_with_curriculum(ColonyEnvironment, agent, curriculum)

# 5. Offline RL / Imitation
# expert_states, expert_actions = collect_expert_demonstrations(env, expert_agent)
# offline_agent = OfflineRLAgent(state_dim, action_dim)
# offline_agent.behavioral_cloning(expert_states, expert_actions)
"""