# In [None]:
# ============================================================================
# THE AUTONOMOUS COLONY - PART 2: RL AGENTS
# From Tabular Q-Learning to Deep RL (DQN, PPO, SAC)
# ============================================================================

"""
# 🤖 The Autonomous Colony - RL Agents

This notebook implements multiple RL algorithms for the colony environment.

**RL Concepts Covered:**
1. Tabular Q-Learning (value-based, discrete)
2. DQN with experience replay (deep Q-learning)
3. PPO (policy gradient, actor-critic)
4. Exploration strategies (ε-greedy, entropy bonus)
5. Reward shaping and sparse rewards
6. Training curves and evaluation

**Prerequisites:**
- Run Part 1 first to set up the environment
"""

# ============================================================================
# SETUP
# ============================================================================

# !pip install gymnasium stable-baselines3 torch tensorboard -q

import copy
import random
from collections import defaultdict, deque
from typing import Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Import environment from Part 1 (assumes it's in the same notebook or imported)
# from part1_environment import ColonyEnvironment, EnvironmentConfig, ActionType

# Reaching this line means every import above succeeded.
print("✓ Dependencies loaded")

# ============================================================================
# 1. TABULAR Q-LEARNING AGENT
# ============================================================================

class TabularQLearningAgent:
    """
    Classic Q-Learning with table lookup.
    
    RL Concepts:
    - Value-based learning
    - Temporal Difference (TD) learning
    - ε-greedy exploration
    - Q-learning update (a sampled Bellman-optimality backup):
      Q(s,a) ← Q(s,a) + α[r + γ·max_a' Q(s',a') − Q(s,a)]
    """
    
    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        learning_rate: float = 0.1,
        gamma: float = 0.99,
        epsilon: float = 1.0,
        epsilon_decay: float = 0.995,
        epsilon_min: float = 0.01
    ):
        """
        Args:
            state_dim: Length of the observation's 'state' vector. Not used by
                this class (states are discretized on the fly); kept so all
                agents share a constructor shape.
            action_dim: Number of discrete actions.
            learning_rate: Step size α of the TD update.
            gamma: Discount factor γ.
            epsilon: Initial exploration probability for ε-greedy.
            epsilon_decay: Multiplicative factor applied by decay_epsilon().
            epsilon_min: Lower bound that decay_epsilon() never goes below.
        """
        self.action_dim = action_dim
        self.lr = learning_rate
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_decay = epsilon_decay
        self.epsilon_min = epsilon_min
        
        # Q-table: discretized state key -> array of action values.
        # States never seen before start with all-zero Q-values.
        self.q_table = defaultdict(lambda: np.zeros(action_dim))
        # Number of updates made from each discretized state.
        self.state_visits = defaultdict(int)
        
    def _discretize_state(self, observation: Dict) -> tuple:
        """
        Map an observation dict to a hashable discrete state key.

        The key is the 'state' vector bucketed into 10 bins per dimension
        (scaled by 10, truncated toward zero, clipped to 0..9), followed by the
        summed values of grid channels 1 and 2 inside the 3x3 window around
        the grid's centre cell.

        Args:
            observation: Dict with 'state' (1-D array-like; presumably values
                in [0, 1] — TODO confirm against the environment) and 'grid'
                (array of shape (H, W, C), C >= 3). Channels 1 and 2 are
                counted as "food" and "water" here — NOTE(review): confirm
                the channel layout against Part 1.

        Returns:
            Tuple of Python ints.
        """
        # asarray makes list-valued states work (list * 10 would repeat the list).
        state_vec = np.asarray(observation['state'], dtype=float)
        bins = (state_vec * 10).astype(int).clip(0, 9)
        discrete = tuple(int(b) for b in bins)
        
        grid = np.asarray(observation['grid'])
        # Centre each axis separately so non-square grids use the right
        # column, and never let the window start at a negative index.
        row_c, col_c = grid.shape[0] // 2, grid.shape[1] // 2
        local_view = grid[max(row_c - 1, 0):row_c + 2, max(col_c - 1, 0):col_c + 2, :]
        food_count = int(local_view[:, :, 1].sum())
        water_count = int(local_view[:, :, 2].sum())
        
        return discrete + (food_count, water_count)
    
    def select_action(self, observation: Dict, training: bool = True) -> int:
        """
        ε-greedy action selection.

        When training, returns a uniformly random action with probability
        epsilon; otherwise returns the greedy action from the Q-table (ties go
        to the lowest action index, as np.argmax does).
        """
        state = self._discretize_state(observation)
        
        # Exploration
        if training and random.random() < self.epsilon:
            return random.randint(0, self.action_dim - 1)
        
        # Exploitation
        q_values = self.q_table[state]
        return int(np.argmax(q_values))
    
    def update(self, state: Dict, action: int, reward: float, 
               next_state: Dict, done: bool) -> float:
        """
        Apply one Q-learning update for the transition (s, a, r, s').

        Args:
            done: True if next_state is terminal; the target is then just r.

        Returns:
            The TD error r + γ·max_a' Q(s',a') − Q(s,a) (or r − Q(s,a) when
            done), computed before the update.
        """
        s = self._discretize_state(state)
        s_next = self._discretize_state(next_state)
        
        q_current = self.q_table[s][action]
        
        if done:
            q_target = reward
        else:
            q_target = reward + self.gamma * np.max(self.q_table[s_next])
        
        td_error = q_target - q_current
        
        self.q_table[s][action] += self.lr * td_error
        self.state_visits[s] += 1
        
        return float(td_error)
    
    def decay_epsilon(self):
        """Multiply epsilon by epsilon_decay, never going below epsilon_min."""
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

# ============================================================================
# 2. DEEP Q-NETWORK (DQN)
# ============================================================================

class QNetwork(nn.Module):
    """
    Q-value approximator over a two-part observation.

    A two-layer CNN encodes the spatial grid, a linear layer encodes the
    internal state vector, and an MLP maps their concatenation to one
    Q-value per action.

    Args:
        grid_shape: (height, width, channels) of the grid observation. The
            forward pass expects the grid channel-first: (B, channels, H, W).
        state_dim: Length of the internal state vector.
        action_dim: Number of discrete actions (width of the output).
    """

    def __init__(self, grid_shape: tuple, state_dim: int, action_dim: int):
        super().__init__()
        height, width, channels = grid_shape[0], grid_shape[1], grid_shape[2]

        # 3x3 kernels with padding=1 preserve H and W, so the flattened
        # feature map holds 64 * height * width values.
        self.conv1 = nn.Conv2d(channels, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        flat_grid_features = 64 * height * width

        # Embedding of the internal state vector.
        self.state_fc = nn.Linear(state_dim, 64)

        # Head over [grid features | state features].
        self.fc1 = nn.Linear(flat_grid_features + 64, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, action_dim)

    def forward(self, grid_obs, state_obs):
        """Return Q-values of shape (B, action_dim)."""
        grid_features = F.relu(self.conv2(F.relu(self.conv1(grid_obs)))).flatten(1)
        state_features = F.relu(self.state_fc(state_obs))

        hidden = torch.cat([grid_features, state_features], dim=1)
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        return self.fc3(hidden)

class ReplayBuffer:
    """
    Fixed-capacity FIFO store of transitions for experience replay.

    Once `capacity` transitions are held, every new push evicts the oldest.
    """

    def __init__(self, capacity: int = 10000):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Append one (state, action, reward, next_state, done) transition."""
        transition = (state, action, reward, next_state, done)
        self.buffer.append(transition)

    def sample(self, batch_size: int):
        """
        Draw `batch_size` distinct stored transitions uniformly at random and
        return them column-wise as five tuples:
        (states, actions, rewards, next_states, dones).

        random.sample raises ValueError if batch_size exceeds the number of
        stored transitions.
        """
        columns = zip(*random.sample(self.buffer, batch_size))
        states, actions, rewards, next_states, dones = columns
        return (states, actions, rewards, next_states, dones)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buffer)

class DQNAgent:
    """
    Deep Q-Network with experience replay and target network.
    
    RL Concepts:
    - Function approximation (neural networks)
    - Experience replay (decorrelation)
    - Target network (stability)
    - Double DQN (optional, enable with ``double_dqn=True``)
    """
    
    def __init__(
        self,
        grid_shape: tuple,
        state_dim: int,
        action_dim: int,
        learning_rate: float = 1e-4,
        gamma: float = 0.99,
        epsilon: float = 1.0,
        epsilon_decay: float = 0.995,
        epsilon_min: float = 0.01,
        buffer_size: int = 10000,
        batch_size: int = 64,
        target_update_freq: int = 100,
        double_dqn: bool = False,
    ):
        """
        Args:
            grid_shape: (H, W, C) of the grid observation.
            state_dim: Length of the internal state vector.
            action_dim: Number of discrete actions.
            learning_rate: Adam learning rate.
            gamma: Discount factor.
            epsilon, epsilon_decay, epsilon_min: ε-greedy schedule; see
                select_action() and decay_epsilon().
            buffer_size: Replay buffer capacity.
            batch_size: Minibatch size; no training happens until the buffer
                holds at least this many transitions.
            target_update_freq: Gradient steps between hard copies of the
                online weights into the target network.
            double_dqn: If True, the online network selects the next action
                and the target network evaluates it (Double DQN). If False
                (default), the vanilla max over target Q-values is used.
        """
        self.action_dim = action_dim
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_decay = epsilon_decay
        self.epsilon_min = epsilon_min
        self.batch_size = batch_size
        self.target_update_freq = target_update_freq
        self.double_dqn = double_dqn
        
        # Online and target networks start with identical weights.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.q_network = QNetwork(grid_shape, state_dim, action_dim).to(self.device)
        self.target_network = QNetwork(grid_shape, state_dim, action_dim).to(self.device)
        self.target_network.load_state_dict(self.q_network.state_dict())
        
        self.optimizer = optim.Adam(self.q_network.parameters(), lr=learning_rate)
        
        self.replay_buffer = ReplayBuffer(buffer_size)
        
        # Training stats
        self.update_count = 0
        self.loss_history = []
        
        print(f"✓ DQN Agent initialized on {self.device}")
    
    def _batch_to_tensors(self, observations):
        """
        Stack observation dicts into network inputs.

        np.asarray(..., float32) accepts lists and any numeric dtype (e.g.
        integer grids), and stacking in NumPy first avoids one small tensor
        per sample.

        Returns:
            grids: (B, C, H, W) float32 tensor (each (H, W, C) grid made
                channel-first).
            states: (B, state_dim) float32 tensor.
        """
        grids = np.stack([np.asarray(o['grid'], dtype=np.float32) for o in observations])
        states = np.stack([np.asarray(o['state'], dtype=np.float32) for o in observations])
        grid_t = torch.as_tensor(grids, device=self.device).permute(0, 3, 1, 2).contiguous()
        state_t = torch.as_tensor(states, device=self.device)
        return grid_t, state_t
    
    def _obs_to_tensor(self, obs: Dict):
        """Convert one observation dict to (1, C, H, W) grid and (1, state_dim) state tensors."""
        return self._batch_to_tensors([obs])
    
    def select_action(self, observation: Dict, training: bool = True) -> int:
        """
        ε-greedy action selection.

        When training, returns a uniformly random action with probability
        epsilon; otherwise the argmax of the online network's Q-values.
        """
        if training and random.random() < self.epsilon:
            return random.randint(0, self.action_dim - 1)
        
        grid, state = self._obs_to_tensor(observation)
        with torch.no_grad():
            q_values = self.q_network(grid, state)
        return int(q_values.argmax().item())
    
    def update(self, state: Dict, action: int, reward: float,
               next_state: Dict, done: bool):
        """
        Store a transition and, once enough are buffered, take one gradient step.

        Args:
            done: True if next_state is terminal; this removes the bootstrap
                term. Pass False for a pure time-limit truncation so the
                target still bootstraps from next_state.

        Returns:
            The MSE TD loss as a float, or None while the buffer holds fewer
            than batch_size transitions.
        """
        self.replay_buffer.push(state, action, reward, next_state, done)
        
        if len(self.replay_buffer) < self.batch_size:
            return None
        
        states, actions, rewards, next_states, dones = self.replay_buffer.sample(self.batch_size)
        
        grids, state_vecs = self._batch_to_tensors(states)
        next_grids, next_state_vecs = self._batch_to_tensors(next_states)
        actions_t = torch.as_tensor(np.asarray(actions, dtype=np.int64), device=self.device)
        rewards_t = torch.as_tensor(np.asarray(rewards, dtype=np.float32), device=self.device)
        dones_t = torch.as_tensor(np.asarray(dones, dtype=np.float32), device=self.device)
        
        # Q(s, a) for the actions actually taken.
        current_q = self.q_network(grids, state_vecs).gather(1, actions_t.unsqueeze(1)).squeeze(1)
        
        with torch.no_grad():
            next_q_all = self.target_network(next_grids, next_state_vecs)
            if self.double_dqn:
                # Double DQN: choose a' with the online net, score it with the target net.
                next_actions = self.q_network(next_grids, next_state_vecs).argmax(dim=1, keepdim=True)
                next_q = next_q_all.gather(1, next_actions).squeeze(1)
            else:
                next_q = next_q_all.max(dim=1).values
            # (1 - done) drops the bootstrap term on terminal transitions.
            target_q = rewards_t + self.gamma * next_q * (1.0 - dones_t)
        
        loss = F.mse_loss(current_q, target_q)
        
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.q_network.parameters(), 1.0)
        self.optimizer.step()
        
        # Hard-sync the target network every target_update_freq gradient steps.
        self.update_count += 1
        if self.update_count % self.target_update_freq == 0:
            self.target_network.load_state_dict(self.q_network.state_dict())
        
        loss_value = loss.item()
        self.loss_history.append(loss_value)
        return loss_value
    
    def decay_epsilon(self):
        """Multiply epsilon by epsilon_decay, never going below epsilon_min."""
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

# ============================================================================
# 3. PPO AGENT (Policy Gradient)
# ============================================================================

class ActorCritic(nn.Module):
    """
    Actor-Critic network for PPO.

    A shared encoder (two-layer CNN over the grid plus a linear layer over the
    internal state) feeds two heads: the actor outputs unnormalized action
    logits and the critic outputs a scalar state-value estimate.

    Args:
        grid_shape: (H, W, C) of the grid; forward() expects (B, C, H, W).
        state_dim: Length of the internal state vector.
        action_dim: Number of discrete actions.
    """
    
    def __init__(self, grid_shape: tuple, state_dim: int, action_dim: int):
        super().__init__()
        
        # Shared feature extractor; padding=1 with 3x3 kernels keeps H and W.
        self.conv1 = nn.Conv2d(grid_shape[2], 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        conv_out_size = grid_shape[0] * grid_shape[1] * 64
        
        self.state_fc = nn.Linear(state_dim, 64)
        self.shared_fc = nn.Linear(conv_out_size + 64, 256)
        
        # Actor head (policy)
        self.actor_fc = nn.Linear(256, 128)
        self.actor_out = nn.Linear(128, action_dim)
        
        # Critic head (value)
        self.critic_fc = nn.Linear(256, 128)
        self.critic_out = nn.Linear(128, 1)
    
    def forward(self, grid_obs, state_obs):
        """Return (logits of shape (B, action_dim), value of shape (B, 1))."""
        # Shared features
        x = F.relu(self.conv1(grid_obs))
        x = F.relu(self.conv2(x))
        x = x.flatten(1)
        s = F.relu(self.state_fc(state_obs))
        features = F.relu(self.shared_fc(torch.cat([x, s], dim=1)))
        
        # Actor (policy logits)
        actor_x = F.relu(self.actor_fc(features))
        logits = self.actor_out(actor_x)
        
        # Critic (value)
        critic_x = F.relu(self.critic_fc(features))
        value = self.critic_out(critic_x)
        
        return logits, value
    
    def get_action_and_value(self, grid_obs, state_obs, action=None):
        """
        Evaluate the policy and the value function.

        Args:
            grid_obs: (B, C, H, W) float tensor.
            state_obs: (B, state_dim) float tensor.
            action: Optional (B,) long tensor of actions to score. If None, one
                action per row is sampled from the current policy.

        Returns:
            (action, log_prob, entropy, value) with shapes (B,), (B,), (B,), (B, 1).
        """
        logits, value = self(grid_obs, state_obs)
        # Parameterize by logits: log-probs come from log-softmax, so they stay
        # exact (and keep their gradient) for very unlikely actions instead of
        # saturating at log(eps) as log(clamp(softmax)) does.
        dist = torch.distributions.Categorical(logits=logits)
        
        if action is None:
            action = dist.sample()
        
        log_prob = dist.log_prob(action)
        entropy = dist.entropy()
        
        return action, log_prob, entropy, value

class PPOAgent:
    """
    Proximal Policy Optimization - modern policy gradient method.
    
    RL Concepts:
    - Policy gradient theorem
    - Actor-Critic architecture
    - Clipped surrogate objective
    - Generalized Advantage Estimation (GAE)
    - Entropy regularization for exploration

    Transitions are accumulated with store_transition(). update() runs once at
    least batch_size transitions are stored: it takes n_epochs full-batch
    gradient steps over the whole rollout (no minibatching) and then clears it.
    """
    
    def __init__(
        self,
        grid_shape: tuple,
        state_dim: int,
        action_dim: int,
        learning_rate: float = 3e-4,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        clip_epsilon: float = 0.2,
        entropy_coef: float = 0.01,
        value_coef: float = 0.5,
        n_epochs: int = 4,
        batch_size: int = 64
    ):
        """
        Args:
            grid_shape: (H, W, C) of the grid observation.
            state_dim: Length of the internal state vector.
            action_dim: Number of discrete actions.
            learning_rate: Adam learning rate.
            gamma: Discount factor.
            gae_lambda: GAE λ.
            clip_epsilon: Probability-ratio clipping range [1-ε, 1+ε].
            entropy_coef: Weight of the entropy bonus.
            value_coef: Weight of the critic's MSE loss.
            n_epochs: Gradient steps per update() call.
            batch_size: Rollout length required before update() trains.
        """
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.clip_epsilon = clip_epsilon
        self.entropy_coef = entropy_coef
        self.value_coef = value_coef
        self.n_epochs = n_epochs
        self.batch_size = batch_size
        
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.network = ActorCritic(grid_shape, state_dim, action_dim).to(self.device)
        self.optimizer = optim.Adam(self.network.parameters(), lr=learning_rate)
        
        # Rollout storage: (state, action, reward, log_prob, value, done) tuples.
        self.rollout_buffer = []
        self.loss_history = []
        
        print(f"✓ PPO Agent initialized on {self.device}")
    
    def _batch_to_tensors(self, observations):
        """
        Stack observation dicts into (B, C, H, W) grid and (B, state_dim)
        state float32 tensors. np.asarray accepts lists and any numeric dtype.
        """
        grids = np.stack([np.asarray(o['grid'], dtype=np.float32) for o in observations])
        states = np.stack([np.asarray(o['state'], dtype=np.float32) for o in observations])
        grid_t = torch.as_tensor(grids, device=self.device).permute(0, 3, 1, 2).contiguous()
        state_t = torch.as_tensor(states, device=self.device)
        return grid_t, state_t
    
    def _obs_to_tensor(self, obs: Dict):
        """Convert one observation dict to (1, C, H, W) grid and (1, state_dim) state tensors."""
        return self._batch_to_tensors([obs])
    
    def select_action(self, observation: Dict, training: bool = True):
        """
        Sample an action from the current policy.

        `training` is accepted for interface parity with the other agents but
        is ignored: the action is always sampled.

        Returns:
            (action, log_prob, value) as (int, float, float).
        """
        grid, state = self._obs_to_tensor(observation)
        
        with torch.no_grad():
            action, log_prob, _, value = self.network.get_action_and_value(grid, state)
        
        return int(action.item()), log_prob.item(), value.item()
    
    def store_transition(self, state, action, reward, log_prob, value, done):
        """Store transition for batch update"""
        self.rollout_buffer.append((state, action, reward, log_prob, value, done))
    
    def compute_gae(self, rewards, values, dones, last_value: float = 0.0):
        """
        Compute Generalized Advantage Estimation (GAE-λ).

        Args:
            rewards, values, dones: Per-step sequences of equal length, in time
                order. dones[t] marks an episode boundary after step t and
                stops both bootstrapping and the GAE recursion there.
            last_value: V(s) of the state following the final step, used when
                that step is not done. The default 0.0 treats the rollout end
                as terminal (the original behavior).

        Returns:
            (advantages, returns) as lists, with returns[t] = advantages[t] + values[t].
        """
        advantages = [0.0] * len(rewards)
        gae = 0.0
        next_value = last_value
        
        for t in reversed(range(len(rewards))):
            # float() so bools and numpy bools mask the same way.
            not_done = 1.0 - float(dones[t])
            delta = rewards[t] + self.gamma * next_value * not_done - values[t]
            gae = delta + self.gamma * self.gae_lambda * not_done * gae
            advantages[t] = gae
            next_value = values[t]
        
        returns = [adv + val for adv, val in zip(advantages, values)]
        return advantages, returns
    
    def update(self, next_observation: Optional[Dict] = None):
        """
        PPO update using the collected rollout.

        Args:
            next_observation: Optional observation that follows the last
                stored transition. If given and that transition is not done,
                its critic value bootstraps the GAE tail instead of 0.

        Returns:
            Mean total loss over the n_epochs passes, or None if fewer than
            batch_size transitions are stored (the buffer is left untouched).
        """
        if len(self.rollout_buffer) < self.batch_size:
            return None
        
        states, actions, rewards, old_log_probs, values, dones = zip(*self.rollout_buffer)
        
        last_value = 0.0
        if next_observation is not None and not dones[-1]:
            grid, state = self._obs_to_tensor(next_observation)
            with torch.no_grad():
                _, next_v = self.network(grid, state)
            last_value = float(next_v.item())
        
        advantages, returns = self.compute_gae(list(rewards), list(values), list(dones), last_value)
        
        grids, state_vecs = self._batch_to_tensors(states)
        actions_t = torch.as_tensor(np.asarray(actions, dtype=np.int64), device=self.device)
        old_log_probs_t = torch.as_tensor(np.asarray(old_log_probs, dtype=np.float32), device=self.device)
        advantages_t = torch.as_tensor(np.asarray(advantages, dtype=np.float32), device=self.device)
        returns_t = torch.as_tensor(np.asarray(returns, dtype=np.float32), device=self.device)
        
        # Normalize advantages; the unbiased std of a single element is NaN,
        # so a one-step rollout is only centred.
        if advantages_t.numel() > 1:
            advantages_t = (advantages_t - advantages_t.mean()) / (advantages_t.std() + 1e-8)
        else:
            advantages_t = advantages_t - advantages_t.mean()
        
        total_loss = 0.0
        for _ in range(self.n_epochs):
            _, new_log_probs, entropy, values_pred = self.network.get_action_and_value(
                grids, state_vecs, actions_t
            )
            
            # Probability ratio π_new(a|s) / π_old(a|s).
            ratio = torch.exp(new_log_probs - old_log_probs_t)
            
            # Clipped surrogate objective (negated for gradient descent).
            surr1 = ratio * advantages_t
            surr2 = torch.clamp(ratio, 1 - self.clip_epsilon, 1 + self.clip_epsilon) * advantages_t
            policy_loss = -torch.min(surr1, surr2).mean()
            
            # squeeze(-1) keeps a (B,) shape even when B == 1.
            value_loss = F.mse_loss(values_pred.squeeze(-1), returns_t)
            
            # Maximizing entropy == minimizing its negative.
            entropy_loss = -entropy.mean()
            
            loss = policy_loss + self.value_coef * value_loss + self.entropy_coef * entropy_loss
            
            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.network.parameters(), 0.5)
            self.optimizer.step()
            
            total_loss += loss.item()
        
        avg_loss = total_loss / self.n_epochs
        self.loss_history.append(avg_loss)
        
        # On-policy: the rollout is stale after this update.
        self.rollout_buffer = []
        
        return avg_loss

# ============================================================================
# TRAINING UTILITIES
# ============================================================================

def train_agent(env, agent, n_episodes: int = 100, agent_type: str = "dqn"):
    """
    Universal training loop for the agents in this notebook.

    Only the first agent of the (multi-agent) environment is learned from.
    Every agent in the environment is given that same action each step.

    Args:
        env: Environment with reset() -> (observations, info),
            step(actions) -> (observations, rewards, terminated, truncated, info)
            holding per-agent sequences, plus config.max_steps and
            config.n_agents.
        agent: Object with select_action(obs). For agent_type "ppo" it also
            needs store_transition(...), update(), rollout_buffer and
            batch_size. Otherwise it needs update(obs, action, reward, next_obs,
            done). decay_epsilon() is called once per episode if present.
        n_episodes: Number of episodes to run.
        agent_type: "ppo" selects the on-policy path; anything else uses the
            per-step (DQN / tabular) update.

    Returns:
        Dict with per-episode 'rewards' and 'lengths', and 'losses' holding the
        mean update result of each episode that performed at least one update.
    """
    episode_rewards = []
    episode_lengths = []
    losses = []
    is_ppo = agent_type == "ppo"
    
    print(f"\n🚀 Training {agent_type.upper()} agent for {n_episodes} episodes...")
    
    for episode in range(n_episodes):
        observations, _ = env.reset()
        episode_reward = 0
        episode_length = 0
        episode_loss = []
        
        # Single-agent learning: follow the first agent only.
        obs = observations[0]
        done = False
        
        while not done and episode_length < env.config.max_steps:
            if is_ppo:
                action, log_prob, value = agent.select_action(obs)
            else:
                action = agent.select_action(obs)
            
            # All agents take the same action for simplicity.
            actions = [action] * env.config.n_agents
            next_observations, rewards, terminated, truncated, _ = env.step(actions)
            
            next_obs = next_observations[0]
            reward = rewards[0]
            is_terminal = bool(terminated[0])
            done = is_terminal or bool(truncated[0])
            
            if is_ppo:
                # The rollout buffer spans episodes, so any episode boundary
                # (termination OR truncation) must stop the GAE recursion.
                agent.store_transition(obs, action, reward, log_prob, value, done)
                if len(agent.rollout_buffer) >= agent.batch_size:
                    loss = agent.update()
                    if loss is not None:
                        episode_loss.append(loss)
            else:
                # Only real termination removes the bootstrap term; a
                # time-limit truncation should still bootstrap from next_obs.
                loss = agent.update(obs, action, reward, next_obs, is_terminal)
                if loss is not None:
                    episode_loss.append(loss)
            
            obs = next_obs
            episode_reward += reward
            episode_length += 1
        
        if hasattr(agent, 'decay_epsilon'):
            agent.decay_epsilon()
        
        episode_rewards.append(episode_reward)
        episode_lengths.append(episode_length)
        if episode_loss:
            losses.append(np.mean(episode_loss))
        
        if (episode + 1) % 10 == 0:
            avg_reward = np.mean(episode_rewards[-10:])
            avg_length = np.mean(episode_lengths[-10:])
            eps = agent.epsilon if hasattr(agent, 'epsilon') else 0
            print(f"Episode {episode+1}/{n_episodes} | Avg Reward: {avg_reward:.2f} | Avg Length: {avg_length:.1f} | ε: {eps:.3f}")
    
    return {
        'rewards': episode_rewards,
        'lengths': episode_lengths,
        'losses': losses
    }

def _plot_smoothed_series(ax, values, ylabel: str, title: str, window: int) -> None:
    """Draw one per-episode series (raw plus trailing moving average) on `ax`."""
    ax.plot(values, alpha=0.3, label='Raw')
    if window > 1 and len(values) > window:
        smoothed = np.convolve(values, np.ones(window) / window, mode='valid')
        # 'valid' output k averages entries k..k+window-1; plot it at the
        # window's last episode so the curve lags rather than leads.
        ax.plot(range(window - 1, len(values)), smoothed, linewidth=2,
                label=f'Smoothed ({window})')
    ax.set_xlabel('Episode')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)


def plot_training_results(results: Dict, title: str, window: int = 10):
    """
    Plot episode rewards, episode lengths and training loss side by side.

    Args:
        results: Dict with 'rewards' and 'lengths' (one value per episode) and
            'losses'. As produced by train_agent, 'losses' holds one mean loss
            per episode that performed at least one update, and may be empty.
        title: Figure super-title.
        window: Moving-average window for the reward and length curves.
            Smoothing is drawn only when a series is longer than `window`.
    """
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    
    _plot_smoothed_series(axes[0], results['rewards'], 'Total Reward', 'Episode Rewards', window)
    _plot_smoothed_series(axes[1], results['lengths'], 'Episode Length', 'Episode Lengths', window)
    
    loss_ax = axes[2]
    loss_ax.set_title('Training Loss')
    if results['losses']:
        loss_ax.plot(results['losses'], alpha=0.6)
        # train_agent stores one averaged loss per episode that trained, so
        # the x-axis counts those episodes, not individual gradient updates.
        loss_ax.set_xlabel('Episode (only those with updates)')
        loss_ax.set_ylabel('Loss')
        loss_ax.grid(True, alpha=0.3)
    else:
        loss_ax.text(0.5, 0.5, 'No loss recorded', ha='center', va='center',
                     transform=loss_ax.transAxes)
    
    fig.suptitle(title, fontsize=14, fontweight='bold')
    fig.tight_layout()
    plt.show()

# ============================================================================
# DEMO: Train Different Agents
# ============================================================================

# Summary banner listing the agent classes defined above; a single print of
# the newline-joined lines writes exactly what one print per line would.
print("\n".join([
    "\n" + "=" * 80,
    "READY TO TRAIN RL AGENTS!",
    "=" * 80,
    "\nAvailable agents:",
    "1. TabularQLearningAgent - Classic Q-learning",
    "2. DQNAgent - Deep Q-Network",
    "3. PPOAgent - Proximal Policy Optimization",
    "\nUncomment the training blocks below to run experiments!",
    "=" * 80,
]))