In [None]:
# Download Codebase from GitHub
import os
import subprocess
import sys
from pathlib import Path

# Git checkout settings used by download_codebase()
REPO_DIR = "homepage"  # local directory the clone is placed in
BRANCH = "neuro-predator"  # branch that is cloned / kept checked out
REPO_URL = "https://github.com/yimianxyz/homepage.git"  # remote to clone from

def download_codebase():
    """Clone the project repository, or refresh an existing checkout, and enter it.

    The kernel's working directory is changed to the checkout so that later
    cells can import the project's packages relative to it.

    Behaviour by situation:
      * Already inside the checkout (the current directory is named REPO_DIR
        and contains ``.git``), e.g. because this cell was re-run: update in
        place instead of cloning a nested copy.
      * REPO_DIR exists below the current directory: enter it, switch to
        BRANCH if necessary and pull.
      * Otherwise: clone BRANCH of REPO_URL into REPO_DIR and enter it.

    A failed update is reported but is not fatal; the existing files are still
    verified so an offline checkout stays usable. A failed clone returns False.

    Returns:
        bool: True if every file in the key-file list exists afterwards,
        False if cloning failed or any key file is missing.
    """
    # subprocess.run raises FileNotFoundError (not CalledProcessError) when the
    # git executable itself is missing, so both have to be handled.
    git_failures = (subprocess.CalledProcessError, FileNotFoundError)

    # After a successful first run the kernel already sits inside the checkout,
    # where REPO_DIR is no longer visible as a sub-directory.
    already_inside = Path.cwd().name == REPO_DIR and Path('.git').exists()

    if already_inside or os.path.exists(REPO_DIR):
        if already_inside:
            print(f"Already inside repository directory '{REPO_DIR}'.")
        else:
            print(f"Repository directory '{REPO_DIR}' already exists.")
            os.chdir(REPO_DIR)

        try:
            # Check current branch
            result = subprocess.run(['git', 'branch', '--show-current'],
                                    capture_output=True, text=True, check=True)
            current_branch = result.stdout.strip()

            if current_branch != BRANCH:
                print(f"Switching to branch '{BRANCH}'...")
                subprocess.run(['git', 'checkout', BRANCH], check=True)

            # Pull latest changes
            print("Updating repository...")
            subprocess.run(['git', 'pull', 'origin', BRANCH], check=True)

            print("✅ Repository updated successfully!")

        except git_failures as e:
            # Best effort: fall through to the key-file check below.
            print(f"❌ Error updating repository: {e}")
            print("Repository directory exists but may not be a valid git repository.")

    else:
        print(f"Cloning repository from {REPO_URL} (branch: {BRANCH})...")

        try:
            # Clone the specific branch
            subprocess.run(['git', 'clone', '-b', BRANCH, REPO_URL, REPO_DIR], check=True)
        except git_failures as e:
            print(f"❌ Error cloning repository: {e}")
            print("Make sure you have git installed and internet connection.")
            return False

        print("✅ Repository cloned successfully!")

        # Change to repository directory
        os.chdir(REPO_DIR)

    # Verify key files exist (paths are relative to the checkout root)
    key_files = [
        'config/constants.py',
        'simulation/processors/input_processor.py',
        'simulation/runtime/simulation_runtime.py',
        'simulation/state_manager/state_manager.py',
        'policy/human_prior/closest_pursuit_policy.py'
    ]

    missing_files = [file_path for file_path in key_files if not os.path.exists(file_path)]

    if missing_files:
        print("⚠️  Warning: Some key files are missing:")
        for file_path in missing_files:
            print(f"  - {file_path}")
        return False

    print("✅ All key files found!")
    print(f"📁 Working directory: {os.getcwd()}")

    return True

# Fetch (or refresh) the repository, then report the outcome in one line
success = download_codebase()

print(
    "\n🎉 Setup complete! Ready for RL training."
    if success
    else "❌ Setup failed. Please check the errors above and try again."
)


In [None]:
# Imports and Configuration
import sys
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from collections import deque, defaultdict
import random
import time
from datetime import datetime
import json
from pathlib import Path
import math
from typing import Dict, List, Any, Tuple, Optional

# Treat the current working directory as the project root and make it importable
project_root = Path.cwd()
if project_root.name != 'homepage':
    print(f"⚠️  Warning: Current directory is '{project_root.name}', expected 'homepage'")
    print("Make sure the first cell downloaded the repository correctly.")

if not any(entry == str(project_root) for entry in sys.path):
    sys.path.append(str(project_root))

# Pull in the simulation / policy modules that live inside the repository.
# A failed import is reported and then re-raised so the cell stops here.
try:
    from config.constants import CONSTANTS
    from simulation.processors import InputProcessor, ActionProcessor
    from simulation.state_manager import StateManager
    from simulation.random_state_generator import RandomStateGenerator
    from policy.human_prior.closest_pursuit_policy import create_closest_pursuit_policy
except ImportError as e:
    print(f"❌ Failed to import modules: {e}")
    print("Make sure the repository was downloaded correctly in the first cell.")
    raise
else:
    print("✅ Successfully imported all simulation modules")
    print(f"📁 Project root: {project_root}")
    print(f"🔧 Key constants: MAX_DISTANCE={CONSTANTS.MAX_DISTANCE}, BOID_MAX_SPEED={CONSTANTS.BOID_MAX_SPEED}")

# Single source of truth for every tunable used by the RL cells below
RL_CONFIG = dict(
    # --- PPO hyperparameters ---
    learning_rate=3e-4,
    clip_epsilon=0.2,
    entropy_coef=0.01,
    value_coef=0.5,
    max_grad_norm=0.5,
    ppo_epochs=4,
    mini_batch_size=64,
    rollout_steps=2048,

    # --- Environment randomisation ranges ---
    min_boids=1,
    max_boids=50,
    min_canvas_width=400,
    max_canvas_width=1600,
    min_canvas_height=400,
    max_canvas_height=1200,
    # Adaptive timeout: canvas_area * initial_boids * multiplier / 10000
    timeout_multiplier=3.0,

    # --- Reward shaping ---
    catch_reward=1.0,
    reward_decay_rate=0.1,  # exponential decay rate inside the attribution window
    reward_window=50,  # number of steps before a catch that may receive reward

    # --- Training schedule ---
    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    num_episodes=5000,
    log_interval=100,
    save_interval=500,
    eval_interval=200,
    eval_episodes=50,
)

# Echo the settings that matter most when skimming the notebook output
print("\n".join([
    "\n🚀 RL Training Configuration:",
    f"  Device: {RL_CONFIG['device']}",
    f"  Episodes: {RL_CONFIG['num_episodes']}",
    f"  Rollout steps: {RL_CONFIG['rollout_steps']}",
    f"  Environment: {RL_CONFIG['min_boids']}-{RL_CONFIG['max_boids']} boids, {RL_CONFIG['min_canvas_width']}x{RL_CONFIG['min_canvas_height']} to {RL_CONFIG['max_canvas_width']}x{RL_CONFIG['max_canvas_height']}",
    f"  Reward: {RL_CONFIG['catch_reward']} per catch, {RL_CONFIG['reward_window']}-step attribution window",
]))


In [None]:
# Load Supervised Learning Model and Create Architecture
from transformer_training import GEGLU, TransformerLayer, TransformerPredictor

# Load SL checkpoint and extract architecture
def load_sl_checkpoint(checkpoint_path: str = "checkpoints/best_model.pt"):
    """Read the supervised-learning checkpoint and determine the model architecture.

    Args:
        checkpoint_path: Location of the ``.pt`` file to load.

    Returns:
        tuple: ``(checkpoint, arch)``. ``checkpoint`` is whatever ``torch.load``
        returned and ``arch`` is the dict of architecture settings taken from
        ``checkpoint['architecture']`` (or a built-in default when that key is
        absent). ``(None, None)`` is returned when the file does not exist.
    """
    if not os.path.exists(checkpoint_path):
        print(f"❌ Checkpoint not found: {checkpoint_path}")
        print("Available checkpoints:")
        if not os.path.exists("checkpoints/"):
            print("  No checkpoints directory found")
        else:
            for found in Path("checkpoints/").glob("*.pt"):
                print(f"  - {found}")
        return None, None

    print(f"Loading SL checkpoint from {checkpoint_path}...")
    # NOTE(review): torch.load unpickles the file; only load checkpoints from a
    # trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=RL_CONFIG['device'])

    # Prefer the architecture recorded in the checkpoint, else fall back
    if 'architecture' in checkpoint:
        arch = checkpoint['architecture']
        print("✅ Found architecture in checkpoint:")
        for field in ('d_model', 'n_heads', 'n_layers', 'ffn_hidden', 'max_boids'):
            print(f"  {field}: {arch[field]}")
    else:
        print("❌ No architecture found in checkpoint - using default values")
        arch = dict(d_model=128, n_heads=8, n_layers=4, ffn_hidden=512, max_boids=50)

    # Training provenance, when the checkpoint recorded it
    print(f"  Checkpoint epoch: {checkpoint.get('epoch', 'unknown')}")
    print(f"  Best validation loss: {checkpoint.get('best_val_loss', 'unknown')}")

    return checkpoint, arch

# Create Actor Network (uses transformer from SL)
class ActorNetwork(nn.Module):
    """Gaussian policy whose mean is produced by the pretrained transformer.

    The transformer output is used as the mean of a diagonal Normal
    distribution with a fixed standard deviation (``action_std``). PPO needs
    (a) actions that are actually sampled from that distribution and (b) the
    ability to re-score a previously taken action under the current weights;
    ``get_action_and_log_prob`` provides both.

    Args:
        checkpoint: Loaded SL checkpoint; when not None its
            ``'model_state_dict'`` is loaded into the transformer.
        architecture: Dict with ``d_model``, ``n_heads``, ``n_layers``,
            ``ffn_hidden`` and ``max_boids``.
        action_std: Fixed standard deviation of the exploration noise.
    """

    def __init__(self, checkpoint, architecture, action_std=0.01):
        super().__init__()

        # Create transformer with same architecture as SL model
        self.transformer = TransformerPredictor(
            d_model=architecture['d_model'],
            n_heads=architecture['n_heads'],
            n_layers=architecture['n_layers'],
            ffn_hidden=architecture['ffn_hidden'],
            max_boids=architecture['max_boids'],
            dropout=0.1
        )

        # Load pretrained weights
        if checkpoint is not None:
            self.transformer.load_state_dict(checkpoint['model_state_dict'])
            print("✅ Loaded pretrained SL weights into actor")

        # Store architecture info (saved again with RL checkpoints)
        self.architecture = architecture
        self.action_std = action_std

    def forward(self, structured_inputs):
        """Return the transformer output for ``structured_inputs`` (the action mean)."""
        # Original author's note: the transformer output is already tanh-bounded
        # to (-1, 1). TransformerPredictor is not visible here - TODO confirm.
        return self.transformer(structured_inputs)

    def get_action_and_log_prob(self, structured_inputs, action=None, deterministic=False):
        """Sample (or re-score) an action and return its log-probability.

        Args:
            structured_inputs: Observation(s) in the format the transformer accepts.
            action: Optional action to evaluate. When given, nothing is sampled
                and the returned log-probability is that of ``action`` under the
                current policy - this is what the PPO ratio needs. Takes
                precedence over ``deterministic``.
            deterministic: When True (and ``action`` is None) return the mean
                action instead of a sample, e.g. for evaluation.

        Returns:
            tuple: ``(action, log_prob)`` where ``log_prob`` is summed over the
            last (action) dimension.
        """
        action_mean = self.forward(structured_inputs)
        dist = torch.distributions.Normal(action_mean, self.action_std)

        if action is None:
            # Scoring the mean itself makes (action - mean) identically zero, so
            # the log-prob is constant and carries no gradient; PPO then cannot
            # learn. Sampling supplies the exploration noise the fixed std is
            # meant to provide. The result is detached: the action is data.
            # NOTE(review): a sample can fall marginally outside the mean's
            # range - confirm the action processor tolerates that.
            action = action_mean.detach() if deterministic else dist.sample()

        log_prob = dist.log_prob(action).sum(dim=-1)

        return action, log_prob

# Create Critic Network (separate value function)
class CriticNetwork(nn.Module):
    """MLP value function: maps a state-feature vector to one scalar value.

    Each hidden layer is Linear -> ReLU -> Dropout(0.1); a final Linear layer
    produces the single value output.

    Args:
        input_dim: Length of the state-feature vector.
        hidden_dims: Widths of the hidden layers, in order. Any iterable of
            ints; an empty one yields a single Linear(input_dim, 1) layer.
    """

    # Tuple default instead of a list: a mutable default is shared between
    # calls and is easy to corrupt by accident.
    def __init__(self, input_dim=128, hidden_dims=(256, 256)):
        super().__init__()

        layers = []
        prev_dim = input_dim

        for hidden_dim in hidden_dims:
            layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.ReLU(),
                nn.Dropout(0.1)
            ])
            prev_dim = hidden_dim

        layers.append(nn.Linear(prev_dim, 1))  # Single value output

        self.network = nn.Sequential(*layers)

    def forward(self, state_features):
        """Return the value estimate with the trailing size-1 dimension removed."""
        return self.network(state_features).squeeze(-1)

# Layout of the critic's input vector produced by extract_state_features:
#   2 context + 2 predator + 18 boid-aggregate features.
# The boid part is: boid count (1), mean/std/min/max of relX, relY, velX, velY
# (16) and the distance to the closest boid (1).
BOID_FEATURE_DIM = 18
STATE_FEATURE_DIM = 2 + 2 + BOID_FEATURE_DIM


def extract_state_features(structured_inputs):
    """Turn structured observations into fixed-size feature rows for the critic.

    Args:
        structured_inputs: One observation dict, or a sequence of them. Each
            dict provides ``context`` (``canvasWidth``, ``canvasHeight``),
            ``predator`` (``velX``, ``velY``) and ``boids`` (a list of dicts
            with ``relX``, ``relY``, ``velX``, ``velY``).

    Returns:
        torch.Tensor: float32 tensor of shape ``(num_samples, STATE_FEATURE_DIM)``
        on ``RL_CONFIG['device']``. Observations without boids get zeros for
        the whole boid part, so every row has the same length.
    """
    if isinstance(structured_inputs, dict):
        structured_inputs = [structured_inputs]

    features = []

    for sample in structured_inputs:
        # Context features (2)
        context_feat = [sample['context']['canvasWidth'], sample['context']['canvasHeight']]

        # Predator features (2)
        predator_feat = [sample['predator']['velX'], sample['predator']['velY']]

        # Boid features (aggregate statistics, BOID_FEATURE_DIM values)
        boids = sample['boids']
        if len(boids) > 0:
            boid_feat = [len(boids)]  # Number of boids
            for key in ('relX', 'relY', 'velX', 'velY'):
                values = [b[key] for b in boids]
                boid_feat.extend([np.mean(values), np.std(values), np.min(values), np.max(values)])
            # Distance to closest boid
            boid_feat.append(np.min([math.sqrt(b['relX']**2 + b['relY']**2) for b in boids]))
        else:
            # No boids remaining: keep the row length identical to the branch
            # above (it previously came out one element short).
            boid_feat = [0.0] * BOID_FEATURE_DIM

        features.append(context_feat + predator_feat + boid_feat)

    return torch.tensor(features, dtype=torch.float32, device=RL_CONFIG['device'])


# Load checkpoint and create networks
checkpoint, architecture = load_sl_checkpoint()

if checkpoint is not None:
    # Create actor and critic networks. The critic's input width must equal the
    # feature-vector length; it used to be hard-coded to 21 while 22 values
    # were produced, which made the first forward pass fail.
    actor = ActorNetwork(checkpoint, architecture).to(RL_CONFIG['device'])
    critic = CriticNetwork(input_dim=STATE_FEATURE_DIM).to(RL_CONFIG['device'])

    # Count parameters
    actor_params = sum(p.numel() for p in actor.parameters())
    critic_params = sum(p.numel() for p in critic.parameters())

    print(f"\n🧠 Networks created:")
    print(f"  Actor parameters: {actor_params:,} (pretrained transformer)")
    print(f"  Critic parameters: {critic_params:,} (new MLP)")
    print(f"  Total parameters: {actor_params + critic_params:,}")

    # Test forward pass
    test_input = {
        'context': {'canvasWidth': 0.8, 'canvasHeight': 0.6},
        'predator': {'velX': 0.1, 'velY': -0.2},
        'boids': [
            {'relX': 0.1, 'relY': 0.3, 'velX': 0.5, 'velY': -0.1},
            {'relX': -0.2, 'relY': 0.1, 'velX': -0.3, 'velY': 0.4}
        ]
    }

    with torch.no_grad():
        test_action, test_log_prob = actor.get_action_and_log_prob(test_input)
        test_features = extract_state_features(test_input)
        test_value = critic(test_features)

    print(f"\n🧪 Test forward pass:")
    print(f"  Input: {len(test_input['boids'])} boids")
    print(f"  Actor output: [{test_action[0].item():.4f}, {test_action[1].item():.4f}]")
    print(f"  Log prob: {test_log_prob.item():.4f}")
    print(f"  State features shape: {test_features.shape}")
    print(f"  Critic value: {test_value.item():.4f}")

    print(f"\n✅ Networks ready for RL training!")
else:
    print("❌ Cannot proceed without SL checkpoint")
    actor = critic = None


In [None]:
# RL Environment Wrapper
class PredatorEnvironment:
    """RL environment wrapper around the predator-boids simulation.

    ``reset()`` builds a randomly sized episode and returns
    ``(observation, info)``; ``step(action)`` advances the simulation by one
    tick with the given predator action and returns
    ``(observation, reward, done, info)``. Observations are the structured
    inputs produced by ``InputProcessor``.

    Args:
        config: Settings dict; this class reads the boid/canvas ranges,
            ``timeout_multiplier``, ``catch_reward``, ``reward_decay_rate``
            and ``reward_window``.
    """

    def __init__(self, config=RL_CONFIG):
        self.config = config

        # Initialize simulation components
        self.state_manager = StateManager()
        self.random_generator = RandomStateGenerator()

        # Episode parameters (filled in by reset())
        self.max_steps = 0
        self.initial_boids_count = 0

        # Per-episode bookkeeping: current_step, episode_catches, step_history
        self.reset_episode_stats()

    def reset_episode_stats(self):
        """Reset episode-level statistics"""
        self.episode_catches = []  # one entry (the step index) per caught boid
        self.step_history = []  # per-step records kept for reward attribution
        self.current_step = 0

    def calculate_adaptive_timeout(self, canvas_width, canvas_height, num_boids):
        """Return the step limit for an episode, scaled by canvas area and boid count.

        Computed as ``canvas_area * num_boids * timeout_multiplier / 10000``,
        truncated to an int and never lower than 200 steps.
        """
        canvas_area = canvas_width * canvas_height
        base_timeout = (canvas_area * num_boids * self.config['timeout_multiplier']) / 10000
        return max(int(base_timeout), 200)  # Minimum 200 steps

    def reset(self):
        """Start a new episode with a random boid count and canvas size.

        Returns:
            tuple: ``(observation, info)`` where ``info`` holds
            ``canvas_width``, ``canvas_height``, ``initial_boids`` and
            ``max_steps``.
        """
        self.reset_episode_stats()

        # Generate random environment parameters (both bounds inclusive)
        num_boids = random.randint(self.config['min_boids'], self.config['max_boids'])
        canvas_width = random.randint(self.config['min_canvas_width'], self.config['max_canvas_width'])
        canvas_height = random.randint(self.config['min_canvas_height'], self.config['max_canvas_height'])

        # Generate random initial state
        initial_state = self.random_generator.generate_scattered_state(
            num_boids, canvas_width, canvas_height
        )

        # The state manager requires a policy at init time; step() bypasses it
        # and feeds the agent's action to the simulation directly.
        dummy_policy = create_closest_pursuit_policy()
        self.state_manager.init(initial_state, dummy_policy)

        # Set episode parameters
        self.initial_boids_count = num_boids
        self.max_steps = self.calculate_adaptive_timeout(canvas_width, canvas_height, num_boids)
        self.current_step = 0

        # Get initial observation
        current_state = self.state_manager.get_state()
        observation = self._state_to_structured_inputs(current_state)

        return observation, {
            'canvas_width': canvas_width,
            'canvas_height': canvas_height,
            'initial_boids': num_boids,
            'max_steps': self.max_steps
        }

    def _state_to_structured_inputs(self, state):
        """Convert a simulation state dict into the structured-input observation."""
        input_processor = InputProcessor()

        # Extract data from state
        boids = state['boids_states']
        predator_pos = state['predator_state']['position']
        predator_vel = state['predator_state']['velocity']
        canvas_width = state['canvas_width']
        canvas_height = state['canvas_height']

        # Convert to structured inputs
        structured_inputs = input_processor.process_inputs(
            boids, predator_pos, predator_vel, canvas_width, canvas_height
        )

        return structured_inputs

    def step(self, action):
        """Advance the simulation by one tick using ``action`` for the predator.

        Args:
            action: Tensor whose first two elements are the x and y components
                of the policy output.

        Returns:
            tuple: ``(observation, reward, done, info)``. ``done`` is True when
            no boids remain (``termination_reason == "all_caught"``) or the
            step limit is reached (``"timeout"``).
        """
        # Convert action to the format expected by the simulation
        action_processor = ActionProcessor()
        game_actions = action_processor.process_action([action[0].item(), action[1].item()])

        # Snapshot the state before stepping
        current_state = self.state_manager.get_state()
        boids_before = len(current_state['boids_states'])

        # Manual step (bypass state manager's policy)
        predator_action = {
            'force_x': game_actions[0],
            'force_y': game_actions[1]
        }

        # Imported lazily so the class can be defined before the runtime module loads
        from simulation.runtime.simulation_runtime import simulation_step

        # Run simulation step
        step_result = simulation_step(
            current_state['boids_states'],
            current_state['predator_state'],
            predator_action,
            current_state['canvas_width'],
            current_state['canvas_height']
        )

        new_boids_states = step_result['boids_states']

        # Remove caught boids from the highest index down, so earlier removals
        # never shift the position of a boid still waiting to be removed.
        # Sorting and de-duplicating guarantees that order even if the
        # simulation reports the indices unsorted or repeated; plain reversal
        # would then delete the wrong boids.
        caught_boids = sorted(set(step_result['caught_boids']), reverse=True)
        for index in caught_boids:
            new_boids_states.pop(index)

        # Update state manager's internal state
        new_state = {
            'boids_states': new_boids_states,
            'predator_state': step_result['predator_state'],
            'canvas_width': current_state['canvas_width'],
            'canvas_height': current_state['canvas_height']
        }

        # NOTE(review): writes the state manager's attribute directly; assumes
        # get_state() reads current_state - confirm against StateManager.
        self.state_manager.current_state = new_state

        # Record step information
        step_info = {
            'step': self.current_step,
            'action': action.clone(),
            'boids_before': boids_before,
            'boids_after': len(new_boids_states),
            'catches': len(caught_boids)
        }
        self.step_history.append(step_info)

        # Track catches: one entry per caught boid, stamped with this step
        self.episode_catches.extend([self.current_step] * len(caught_boids))

        self.current_step += 1

        # Check termination conditions
        done = False
        termination_reason = None

        if len(new_boids_states) == 0:
            done = True
            termination_reason = "all_caught"
        elif self.current_step >= self.max_steps:
            done = True
            termination_reason = "timeout"

        # Calculate reward (for the step that was just executed)
        reward = self._calculate_reward()

        # Get next observation
        observation = self._state_to_structured_inputs(new_state)

        # Info dictionary
        info = {
            'step': self.current_step,
            'catches_this_step': len(caught_boids),
            'total_catches': len(self.episode_catches),
            'boids_remaining': len(new_boids_states),
            'done': done,
            'termination_reason': termination_reason,
            'reward': reward
        }

        return observation, reward, done, info

    def _calculate_reward(self):
        """Return the reward for the step that was just executed.

        Every recorded catch contributes ``catch_reward * exp(-decay * k)`` if
        the executed step lies in the ``reward_window`` steps ending at the
        catch, ``k`` being the number of steps between the two.

        NOTE(review): this runs during the step itself, so only catches that
        have already happened are known. The window test can therefore only
        succeed for catches made on this very step (``k == 0``), and the result
        is ``catch_reward`` times the number of catches this step. Rewarding
        the steps *before* a catch would need retroactive assignment after the
        catch is observed.
        """
        total_reward = 0.0

        # Only calculate reward if we have catches
        if len(self.episode_catches) == 0:
            return 0.0

        # step() has already incremented current_step
        executed_step = self.current_step - 1

        for catch_step in self.episode_catches:
            # Define the reward window (steps before the catch)
            reward_start = max(0, catch_step - self.config['reward_window'] + 1)
            reward_end = catch_step + 1

            # Only attribute reward to the executed step if it's within the window
            if reward_start <= executed_step < reward_end:
                # Calculate steps before catch (for decay)
                steps_before_catch = catch_step - executed_step

                # Exponential decay: more recent actions get higher reward
                decay_factor = math.exp(-self.config['reward_decay_rate'] * steps_before_catch)
                total_reward += self.config['catch_reward'] * decay_factor

        return total_reward

# Smoke-test the environment: one reset followed by up to three policy-driven steps
if actor is None:
    print("❌ Skipping environment test - actor not available")
else:
    print("🌍 Testing RL Environment...")

    # Fresh environment with the default configuration
    test_env = PredatorEnvironment()

    # Reset and show what kind of episode was generated
    obs, info = test_env.reset()
    print("  Environment reset:")
    print(f"    Canvas: {info['canvas_width']}x{info['canvas_height']}")
    print(f"    Boids: {info['initial_boids']}")
    print(f"    Max steps: {info['max_steps']}")
    print(f"    Observation boids: {len(obs['boids'])}")

    # Drive the predator with the actor for a few ticks
    print("  Testing steps:")
    for i in range(3):
        with torch.no_grad():
            action, _ = actor.get_action_and_log_prob(obs)

        obs, reward, done, info = test_env.step(action)

        print(f"    Step {i+1}: action=[{action[0]:.3f}, {action[1]:.3f}], "
              f"reward={reward:.3f}, boids={info['boids_remaining']}, "
              f"catches={info['catches_this_step']}")

        if not done:
            continue
        print(f"    Episode done: {info['termination_reason']}")
        break

    print("✅ Environment test completed!")


In [None]:
# PPO Implementation
class PPOTrainer:
    """PPO trainer for the predator agent.

    Holds the actor/critic pair, one Adam optimizer per network, an on-policy
    rollout buffer and a running log of training losses.

    Args:
        actor: Policy network. Calling it on observation(s) must return the
            action mean; ``get_action_and_log_prob(obs)`` must return
            ``(action, log_prob)``.
        critic: Value network fed with ``extract_state_features(obs)``.
        config: Settings dict (learning rate, clip epsilon, PPO epochs,
            mini-batch size, gradient-norm limit, entropy coefficient, device;
            optional ``gamma``, default 0.99).
    """

    def __init__(self, actor, critic, config=RL_CONFIG):
        self.actor = actor
        self.critic = critic
        self.config = config

        # Optimizers
        self.actor_optimizer = optim.Adam(actor.parameters(), lr=config['learning_rate'])
        self.critic_optimizer = optim.Adam(critic.parameters(), lr=config['learning_rate'])

        # Experience buffer
        self.reset_buffer()

        # Training metrics (one entry per update_policy call)
        self.training_stats = defaultdict(list)

    def reset_buffer(self):
        """Reset experience buffer"""
        self.buffer = {
            'observations': [],
            'actions': [],
            'rewards': [],
            'values': [],
            'log_probs': [],
            'dones': [],
            'returns': [],
            'advantages': []
        }

    def collect_rollout(self, env, num_steps):
        """Run the current policy for ``num_steps`` environment steps.

        The environment is reset at the start and again whenever an episode
        ends. Returns and advantages are computed and stored in the buffer
        before returning.

        Returns:
            dict: ``episode_rewards``, ``episode_lengths`` and
            ``catches_per_episode`` (one entry per *completed* episode) plus
            ``episodes_completed``.
        """
        self.reset_buffer()

        # Reset environment for new rollout
        obs, info = env.reset()

        rollout_stats = {
            'episode_rewards': [],
            'episode_lengths': [],
            'catches_per_episode': [],
            'episodes_completed': 0
        }

        current_episode_reward = 0
        current_episode_length = 0
        current_episode_catches = 0

        for step in range(num_steps):
            # Get action and value
            with torch.no_grad():
                action, log_prob = self.actor.get_action_and_log_prob(obs)
                state_features = extract_state_features(obs)
                value = self.critic(state_features)

            # Take step in environment
            next_obs, reward, done, step_info = env.step(action)

            # Store experience
            self.buffer['observations'].append(obs)
            self.buffer['actions'].append(action)
            self.buffer['rewards'].append(reward)
            self.buffer['values'].append(value)
            self.buffer['log_probs'].append(log_prob)
            self.buffer['dones'].append(done)

            # Update episode stats
            current_episode_reward += reward
            current_episode_length += 1
            current_episode_catches += step_info['catches_this_step']

            obs = next_obs

            if done:
                # Episode finished
                rollout_stats['episode_rewards'].append(current_episode_reward)
                rollout_stats['episode_lengths'].append(current_episode_length)
                rollout_stats['catches_per_episode'].append(current_episode_catches)
                rollout_stats['episodes_completed'] += 1

                # Reset for next episode
                obs, info = env.reset()
                current_episode_reward = 0
                current_episode_length = 0
                current_episode_catches = 0

        # Calculate returns and advantages (obs is the state after the last step)
        self._calculate_returns_and_advantages(obs)

        return rollout_stats

    def _calculate_returns_and_advantages(self, final_obs):
        """Fill the buffer with discounted returns and normalised advantages.

        Returns are plain discounted sums (no GAE), bootstrapped with the
        critic's value of ``final_obs`` and restarted from zero at every step
        flagged ``done`` (timeouts are treated like terminal states).
        Advantages are ``returns - values``, standardised over the rollout when
        it holds more than one step.
        """
        gamma = self.config.get('gamma', 0.99)  # Discount factor

        # Value of the state following the last stored step
        with torch.no_grad():
            final_value = self.critic(extract_state_features(final_obs)).item()

        rewards = self.buffer['rewards']
        dones = self.buffer['dones']

        # Backward pass in plain Python numbers; avoids per-element tensor ops
        discounted = [0.0] * len(rewards)
        running_return = final_value
        for t in reversed(range(len(rewards))):
            if dones[t]:
                running_return = 0.0
            running_return = rewards[t] + gamma * running_return
            discounted[t] = running_return

        returns_tensor = torch.tensor(discounted, dtype=torch.float32, device=self.config['device'])

        # Flatten explicitly so a one-step rollout keeps its length-1 shape
        values = torch.stack(self.buffer['values']).reshape(-1)

        # Calculate advantages (simple baseline)
        advantages_tensor = returns_tensor - values

        # Normalize advantages
        if len(advantages_tensor) > 1:
            advantages_tensor = (advantages_tensor - advantages_tensor.mean()) / (advantages_tensor.std() + 1e-8)

        self.buffer['returns'] = returns_tensor
        self.buffer['advantages'] = advantages_tensor

    def _log_probs_of(self, observations, actions):
        """Log-probability of the stored ``actions`` under the current actor.

        Uses a Normal distribution centred on the actor's output with the
        actor's ``action_std`` attribute when it has one, else 0.01 (the fixed
        std the actor uses when producing rollout log-probs).
        """
        action_mean = self.actor(observations)
        action_std = getattr(self.actor, 'action_std', 0.01)
        dist = torch.distributions.Normal(action_mean, action_std)
        return dist.log_prob(actions).sum(dim=-1)

    def update_policy(self):
        """Run the PPO epochs over the collected rollout.

        Returns:
            dict: ``actor_loss``, ``critic_loss`` and ``entropy``, each
            averaged over all mini-batches of this update. The same values are
            appended to ``self.training_stats``.
        """
        # Convert buffer to tensors
        observations = self.buffer['observations']
        actions = torch.stack(self.buffer['actions'])
        old_log_probs = torch.stack(self.buffer['log_probs'])
        returns = self.buffer['returns']
        advantages = self.buffer['advantages']

        clip_epsilon = self.config['clip_epsilon']
        mini_batch_size = self.config['mini_batch_size']
        batch_stats = defaultdict(list)

        # NOTE(review): the networks are left in whatever train/eval mode the
        # caller set; active dropout adds noise to the probability ratio.
        for epoch in range(self.config['ppo_epochs']):
            # Shuffle data
            indices = torch.randperm(len(observations))

            # Mini-batch training
            for start in range(0, len(observations), mini_batch_size):
                batch_indices = indices[start:start + mini_batch_size]

                # Get batch data
                batch_obs = [observations[idx] for idx in batch_indices.tolist()]
                batch_actions = actions[batch_indices]
                batch_old_log_probs = old_log_probs[batch_indices]
                batch_returns = returns[batch_indices]
                batch_advantages = advantages[batch_indices]

                # The PPO ratio compares the SAME action under the old and the
                # new policy, so the stored actions are re-scored here instead
                # of drawing fresh ones.
                new_log_probs = self._log_probs_of(batch_obs, batch_actions)
                batch_state_features = extract_state_features(batch_obs)
                new_values = self.critic(batch_state_features)

                # PPO loss calculations
                ratio = torch.exp(new_log_probs - batch_old_log_probs)

                # Actor loss (PPO clipping)
                surr1 = ratio * batch_advantages
                surr2 = torch.clamp(ratio, 1.0 - clip_epsilon, 1.0 + clip_epsilon) * batch_advantages
                actor_loss = -torch.min(surr1, surr2).mean()

                # Critic loss
                critic_loss = F.mse_loss(new_values, batch_returns)

                # Entropy bonus (for exploration)
                entropy = -new_log_probs.mean()  # Simple entropy approximation
                entropy_loss = -self.config['entropy_coef'] * entropy

                # Total losses
                total_actor_loss = actor_loss + entropy_loss
                total_critic_loss = critic_loss

                # Update actor
                self.actor_optimizer.zero_grad()
                total_actor_loss.backward()
                torch.nn.utils.clip_grad_norm_(self.actor.parameters(), self.config['max_grad_norm'])
                self.actor_optimizer.step()

                # Update critic
                self.critic_optimizer.zero_grad()
                total_critic_loss.backward()
                torch.nn.utils.clip_grad_norm_(self.critic.parameters(), self.config['max_grad_norm'])
                self.critic_optimizer.step()

                batch_stats['actor_loss'].append(actor_loss.item())
                batch_stats['critic_loss'].append(critic_loss.item())
                batch_stats['entropy'].append(entropy.item())

        # Average over every mini-batch rather than reporting only the last one
        summary = {
            name: sum(batch_stats[name]) / len(batch_stats[name])
            for name in ('actor_loss', 'critic_loss', 'entropy')
        }

        # Store training statistics
        for name, value in summary.items():
            self.training_stats[name].append(value)

        return summary

    def save_checkpoint(self, filepath, episode, stats):
        """Save networks, optimizers, config and statistics to ``filepath``.

        Args:
            filepath: Destination passed to ``torch.save``.
            episode: Episode counter stored under ``'episode'``.
            stats: Caller-provided episode statistics stored under
                ``'episode_stats'``.
        """
        checkpoint = {
            'episode': episode,
            'actor_state_dict': self.actor.state_dict(),
            'critic_state_dict': self.critic.state_dict(),
            'actor_optimizer_state_dict': self.actor_optimizer.state_dict(),
            'critic_optimizer_state_dict': self.critic_optimizer.state_dict(),
            'training_stats': dict(self.training_stats),
            'config': self.config,
            'architecture': self.actor.architecture,
            'episode_stats': stats,
            'timestamp': datetime.now().isoformat()
        }

        torch.save(checkpoint, filepath)
        print(f"✅ Saved RL checkpoint: {filepath}")

# Build the trainer only when both networks were created successfully
if actor is None or critic is None:
    print("❌ Cannot initialize PPO trainer - networks not available")
    ppo_trainer = None
else:
    ppo_trainer = PPOTrainer(actor, critic, RL_CONFIG)
    print("\n".join([
        "🎯 PPO Trainer initialized!",
        f"  Actor learning rate: {RL_CONFIG['learning_rate']}",
        f"  Critic learning rate: {RL_CONFIG['learning_rate']}",
        f"  Clip epsilon: {RL_CONFIG['clip_epsilon']}",
        f"  Mini-batch size: {RL_CONFIG['mini_batch_size']}",
        "✅ Ready for RL training!",
    ]))


In [None]:
# Training Loop with Metrics
class TrainingMetrics:
    """Collect per-episode and per-update statistics during RL training.

    Everything public lives in ``self.metrics``: lists with one entry per
    episode (reward, length, catches, success flag, catch efficiency,
    environment stats, timeout flag), lists with one entry per policy update
    (actor loss, critic loss, entropy), the steps-to-first-catch of the
    episodes that had a catch, and three running totals.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Discard everything recorded so far and start from empty series."""
        self.metrics = {
            'episode_rewards': [],
            'episode_lengths': [],
            'catches_per_episode': [],
            'success_rate': [],  # 1 for episodes with at least 1 catch, else 0
            'catch_efficiency': [],  # Catches per step
            'steps_to_first_catch': [],  # Only episodes that had a catch
            'total_episodes': 0,
            'total_steps': 0,
            'total_catches': 0,

            # Training losses
            'actor_losses': [],
            'critic_losses': [],
            'entropy': [],

            # Environment stats
            'canvas_sizes': [],
            'initial_boids': [],
            'timeout_rate': []  # 1 for episodes that ended by timeout, else 0
        }
        # One entry per episode (None when nothing was caught).  The public
        # 'steps_to_first_catch' list skips catch-less episodes, so slicing it
        # with an episode window would reach back past that window; this
        # episode-aligned copy keeps windowed statistics consistent.
        self._first_catch_by_episode = []

    def update_episode(self, reward, length, catches, first_catch_step, canvas_size, initial_boids, termination_reason):
        """Record one finished episode.

        Args:
            reward: Total reward of the episode.
            length: Number of steps the episode lasted.
            catches: Number of catches made during the episode.
            first_catch_step: Step of the first catch, or None if there was none.
            canvas_size: Environment size value stored as-is.
            initial_boids: Initial boid count stored as-is.
            termination_reason: The string ``'timeout'`` marks a timed-out
                episode; any other value counts as not timed out.
        """
        self.metrics['episode_rewards'].append(reward)
        self.metrics['episode_lengths'].append(length)
        self.metrics['catches_per_episode'].append(catches)

        # Success flag (episode with at least 1 catch)
        self.metrics['success_rate'].append(1 if catches > 0 else 0)

        # Catch efficiency (catches per step); zero-length episodes count as 0
        self.metrics['catch_efficiency'].append(catches / length if length > 0 else 0)

        # Steps to first catch: public list only grows on a catch, the private
        # list always grows so it stays aligned with the episode series.
        if first_catch_step is not None:
            self.metrics['steps_to_first_catch'].append(first_catch_step)
        self._first_catch_by_episode.append(first_catch_step)

        # Environment stats
        self.metrics['canvas_sizes'].append(canvas_size)
        self.metrics['initial_boids'].append(initial_boids)
        self.metrics['timeout_rate'].append(1 if termination_reason == 'timeout' else 0)

        # Global counters
        self.metrics['total_episodes'] += 1
        self.metrics['total_steps'] += length
        self.metrics['total_catches'] += catches

    def update_training(self, actor_loss, critic_loss, entropy):
        """Record the losses and entropy reported by one policy update."""
        self.metrics['actor_losses'].append(actor_loss)
        self.metrics['critic_losses'].append(critic_loss)
        self.metrics['entropy'].append(entropy)

    def get_recent_stats(self, window=100):
        """Summarize the most recent ``window`` episodes.

        Args:
            window: Number of trailing episodes to aggregate.

        Returns:
            A dict with the keys ``episodes``, ``avg_reward``, ``avg_length``,
            ``avg_catches``, ``success_rate``, ``avg_efficiency``,
            ``avg_first_catch``, ``timeout_rate`` and ``total_catches``.
            When no episode has been recorded yet, ``episodes`` and
            ``total_catches`` are 0 and every average is None, so callers can
            always test ``stats['episodes'] > 0`` without a KeyError.
            ``avg_first_catch`` is None when no episode in the window had a
            catch.
        """
        if len(self.metrics['episode_rewards']) == 0:
            return {
                'episodes': 0,
                'avg_reward': None,
                'avg_length': None,
                'avg_catches': None,
                'success_rate': None,
                'avg_efficiency': None,
                'avg_first_catch': None,
                'timeout_rate': None,
                'total_catches': 0
            }

        recent_rewards = self.metrics['episode_rewards'][-window:]
        recent_lengths = self.metrics['episode_lengths'][-window:]
        recent_catches = self.metrics['catches_per_episode'][-window:]
        recent_success = self.metrics['success_rate'][-window:]
        recent_efficiency = self.metrics['catch_efficiency'][-window:]
        recent_timeout = self.metrics['timeout_rate'][-window:]
        # Take the same episode window as the other series, then drop the
        # episodes that had no catch.
        recent_first_catch = [
            step for step in self._first_catch_by_episode[-window:]
            if step is not None
        ]

        return {
            'episodes': len(recent_rewards),
            'avg_reward': np.mean(recent_rewards),
            'avg_length': np.mean(recent_lengths),
            'avg_catches': np.mean(recent_catches),
            'success_rate': np.mean(recent_success),
            'avg_efficiency': np.mean(recent_efficiency),
            'avg_first_catch': np.mean(recent_first_catch) if recent_first_catch else None,
            'timeout_rate': np.mean(recent_timeout),
            'total_catches': sum(recent_catches)
        }

def train_rl_agent(ppo_trainer, num_updates=1000):
    """Run the PPO training loop and return the collected metrics.

    Each update collects ``RL_CONFIG['rollout_steps']`` environment steps,
    performs one policy update, records the results in a ``TrainingMetrics``
    instance, logs every ``RL_CONFIG['log_interval']`` updates and writes a
    checkpoint every ``RL_CONFIG['save_interval']`` updates into
    ``rl_checkpoints/``.

    Args:
        ppo_trainer: Trainer providing ``collect_rollout(env, steps)``,
            ``update_policy()`` and ``save_checkpoint(path, episode, stats)``.
            When None, training is skipped.
        num_updates: Number of rollout/update iterations to run.

    Returns:
        The populated ``TrainingMetrics``, or None when ``ppo_trainer`` is None.
    """

    if ppo_trainer is None:
        print("❌ Cannot start training - PPO trainer not available")
        return

    print(f"🚀 Starting RL training for {num_updates} updates...")
    print(f"  Rollout steps per update: {RL_CONFIG['rollout_steps']}")
    print(f"  Total environment steps: {num_updates * RL_CONFIG['rollout_steps']:,}")

    # Initialize environment and metrics
    env = PredatorEnvironment(RL_CONFIG)
    metrics = TrainingMetrics()

    # Create checkpoints directory
    os.makedirs("rl_checkpoints", exist_ok=True)

    start_time = time.time()

    for update in range(num_updates):
        update_start = time.time()

        # Collect rollout
        rollout_stats = ppo_trainer.collect_rollout(env, RL_CONFIG['rollout_steps'])

        # Update policy
        training_stats = ppo_trainer.update_policy()

        # Update metrics with every episode that finished during this rollout
        for reward, length, catches in zip(
            rollout_stats['episode_rewards'],
            rollout_stats['episode_lengths'],
            rollout_stats['catches_per_episode']
        ):
            # NOTE(review): the values below are estimates/placeholders, not
            # measurements -- the rollout stats read here do not carry them.
            # Estimate first catch step (simplified)
            first_catch = length // 2 if catches > 0 else None

            # Environment stats (simplified)
            canvas_size = 800 * 600  # Placeholder
            initial_boids = 25  # Placeholder
            termination_reason = 'timeout' if catches == 0 else 'caught'

            metrics.update_episode(
                reward, length, catches, first_catch,
                canvas_size, initial_boids, termination_reason
            )

        metrics.update_training(
            training_stats['actor_loss'],
            training_stats['critic_loss'],
            training_stats['entropy']
        )

        update_time = time.time() - update_start

        # Logging
        if (update + 1) % RL_CONFIG['log_interval'] == 0:
            recent_stats = metrics.get_recent_stats(RL_CONFIG['log_interval'])
            elapsed_time = time.time() - start_time

            print(f"\n📊 Update {update+1}/{num_updates} (Episode {metrics.metrics['total_episodes']})")
            print(f"   Time: {elapsed_time:.1f}s, Update time: {update_time:.2f}s")
            print(f"   Episodes: {rollout_stats['episodes_completed']}, Total steps: {metrics.metrics['total_steps']:,}")

            # .get(): the stats dict may be empty when no episode has finished
            # yet; indexing it directly would raise KeyError.
            if recent_stats.get('episodes', 0) > 0:
                print(f"   Recent {recent_stats['episodes']} episodes:")
                print(f"     Avg reward: {recent_stats['avg_reward']:.3f}")
                print(f"     Avg length: {recent_stats['avg_length']:.1f} steps")
                print(f"     Avg catches: {recent_stats['avg_catches']:.1f}")
                print(f"     Success rate: {recent_stats['success_rate']:.1%}")
                print(f"     Catch efficiency: {recent_stats['avg_efficiency']:.4f} catches/step")
                print(f"     Timeout rate: {recent_stats['timeout_rate']:.1%}")

            print(f"   Training:")
            print(f"     Actor loss: {training_stats['actor_loss']:.4f}")
            print(f"     Critic loss: {training_stats['critic_loss']:.4f}")
            print(f"     Entropy: {training_stats['entropy']:.4f}")

        # Save checkpoint
        if (update + 1) % RL_CONFIG['save_interval'] == 0:
            # Compute the stats here instead of reusing the logging branch's
            # variable: that one is unbound until the first log and stale
            # whenever save_interval is not a multiple of log_interval.
            checkpoint_stats = metrics.get_recent_stats(RL_CONFIG['log_interval'])

            checkpoint_path = f"rl_checkpoints/rl_checkpoint_{update+1}.pt"
            ppo_trainer.save_checkpoint(checkpoint_path, update + 1, checkpoint_stats)

            # Also save as latest
            latest_path = "rl_checkpoints/latest.pt"
            ppo_trainer.save_checkpoint(latest_path, update + 1, checkpoint_stats)

    total_time = time.time() - start_time

    print(f"\n🎉 RL Training completed!")
    print(f"  Total time: {total_time:.1f}s ({total_time/60:.1f}m)")
    print(f"  Total episodes: {metrics.metrics['total_episodes']}")
    print(f"  Total steps: {metrics.metrics['total_steps']:,}")
    print(f"  Total catches: {metrics.metrics['total_catches']}")

    final_stats = metrics.get_recent_stats(200)
    if final_stats.get('episodes', 0) > 0:
        print(f"\n📈 Final performance (last 200 episodes):")
        print(f"  Average reward: {final_stats['avg_reward']:.3f}")
        print(f"  Average catches: {final_stats['avg_catches']:.1f}")
        print(f"  Success rate: {final_stats['success_rate']:.1%}")
        print(f"  Catch efficiency: {final_stats['avg_efficiency']:.4f}")

    return metrics

# Quick training test (small scale)
if ppo_trainer is not None:
    print("🧪 Quick training test (5 updates)...")

    # Very small test
    test_metrics = train_rl_agent(ppo_trainer, num_updates=5)

    print("✅ Training test completed! Ready for full training.")

    # Uncomment below for full training
    # print("\n🚀 Starting full RL training...")
    # full_metrics = train_rl_agent(ppo_trainer, num_updates=500)
else:
    print("❌ Skipping training test - PPO trainer not available")
    # Keep the name bound so later cells can test `test_metrics is not None`.
    test_metrics = None


In [None]:
# Evaluation, Visualization and Model Export
def evaluate_rl_agent(actor, critic, num_episodes=50, verbose=True):
    """Roll out ``actor`` in a fresh ``PredatorEnvironment`` and summarize it.

    Args:
        actor: Policy exposing ``get_action_and_log_prob(obs)``; only the
            returned action is used. When None, evaluation is skipped.
        critic: Accepted for call-site symmetry; not used by this function.
        num_episodes: Number of episodes to run.
        verbose: When True, print a progress line after every 10 episodes.

    Returns:
        A 2-tuple ``(eval_stats, summary)``: ``eval_stats`` is a list with one
        dict per episode, ``summary`` a dict of aggregates over all episodes.
        When ``actor`` is None the result is ``(None, None)`` so that callers
        unpacking two values do not crash on the skip path.
    """

    if actor is None:
        print("❌ Cannot evaluate - actor not available")
        # Two values, because every caller unpacks `stats, summary = ...`.
        return None, None

    print(f"📊 Evaluating RL agent for {num_episodes} episodes...")

    env = PredatorEnvironment(RL_CONFIG)
    eval_stats = []

    for episode in range(num_episodes):
        obs, info = env.reset()

        episode_reward = 0
        episode_catches = 0
        episode_length = 0
        done = False

        while not done:
            # Inference only: no gradients are needed while acting.
            with torch.no_grad():
                action, _ = actor.get_action_and_log_prob(obs)

            obs, reward, done, step_info = env.step(action)

            episode_reward += reward
            episode_catches += step_info['catches_this_step']
            episode_length += 1

        eval_stats.append({
            'reward': episode_reward,
            'length': episode_length,
            'catches': episode_catches,
            'success': episode_catches > 0,
            'efficiency': episode_catches / episode_length if episode_length > 0 else 0,
            # step_info still holds the info dict of the episode's final step.
            'termination': step_info['termination_reason'],
            'canvas_size': info['canvas_width'] * info['canvas_height'],
            'initial_boids': info['initial_boids']
        })

        if verbose and (episode + 1) % 10 == 0:
            recent_stats = eval_stats[-10:]
            avg_catches = np.mean([s['catches'] for s in recent_stats])
            success_rate = np.mean([s['success'] for s in recent_stats])
            # episode is 0-based, so episode-8 .. episode+1 is the 1-based
            # range of the ten episodes just summarized.
            print(f"  Episodes {episode-8}-{episode+1}: {avg_catches:.1f} catches/ep, {success_rate:.1%} success")

    # Calculate summary statistics
    summary = {
        'total_episodes': len(eval_stats),
        'avg_reward': np.mean([s['reward'] for s in eval_stats]),
        'avg_catches': np.mean([s['catches'] for s in eval_stats]),
        'avg_length': np.mean([s['length'] for s in eval_stats]),
        'success_rate': np.mean([s['success'] for s in eval_stats]),
        'avg_efficiency': np.mean([s['efficiency'] for s in eval_stats]),
        'timeout_rate': np.mean([1 if s['termination'] == 'timeout' else 0 for s in eval_stats]),
        'total_catches': sum([s['catches'] for s in eval_stats])
    }

    print(f"\n📈 Evaluation Results:")
    print(f"  Episodes: {summary['total_episodes']}")
    print(f"  Average reward: {summary['avg_reward']:.3f}")
    print(f"  Average catches: {summary['avg_catches']:.2f}")
    print(f"  Success rate: {summary['success_rate']:.1%}")
    print(f"  Average efficiency: {summary['avg_efficiency']:.4f} catches/step")
    print(f"  Average episode length: {summary['avg_length']:.1f} steps")
    print(f"  Timeout rate: {summary['timeout_rate']:.1%}")
    print(f"  Total catches: {summary['total_catches']}")

    return eval_stats, summary

def plot_training_metrics(metrics):
    """Draw a 2x3 grid of training curves from a ``TrainingMetrics``-like object.

    Panels: episode rewards, catches per episode, rolling success rate (only
    when more than 10 episodes exist), actor/critic losses (only when at
    least one update was recorded), catch efficiency and episode lengths.
    Prints a notice and returns without plotting when ``metrics`` is None or
    holds no episodes.
    """

    nothing_recorded = metrics is None or len(metrics.metrics['episode_rewards']) == 0
    if nothing_recorded:
        print("❌ No metrics to plot")
        return

    history = metrics.metrics
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))

    def _label(panel, title, xlabel, ylabel):
        # Title and axis captions shared by every panel.
        panel.set_title(title)
        panel.set_xlabel(xlabel)
        panel.set_ylabel(ylabel)

    def _curve(panel, key, title, xlabel, ylabel):
        # Single raw series, labelled, with a grid.
        panel.plot(history[key])
        _label(panel, title, xlabel, ylabel)
        panel.grid(True)

    _curve(axes[0, 0], 'episode_rewards', 'Episode Rewards', 'Episode', 'Reward')
    _curve(axes[0, 1], 'catches_per_episode', 'Catches per Episode', 'Episode', 'Catches')

    # Success rate (rolling average); the window grows with the history, up to 50
    episodes_seen = len(history['success_rate'])
    if episodes_seen > 10:
        window = min(50, episodes_seen // 10)
        kernel = np.ones(window) / window
        success_smooth = np.convolve(history['success_rate'], kernel, mode='valid')
        smooth_panel = axes[0, 2]
        smooth_panel.plot(success_smooth)
        _label(smooth_panel, f'Success Rate (rolling avg, window={window})',
               'Episode', 'Success Rate')
        smooth_panel.grid(True)

    # Training losses share one panel, hence the legend
    if len(history['actor_losses']) > 0:
        loss_panel = axes[1, 0]
        loss_panel.plot(history['actor_losses'], label='Actor Loss')
        loss_panel.plot(history['critic_losses'], label='Critic Loss')
        _label(loss_panel, 'Training Losses', 'Update', 'Loss')
        loss_panel.legend()
        loss_panel.grid(True)

    _curve(axes[1, 1], 'catch_efficiency', 'Catch Efficiency', 'Episode', 'Catches per Step')
    _curve(axes[1, 2], 'episode_lengths', 'Episode Lengths', 'Episode', 'Steps')

    plt.tight_layout()
    plt.show()

    print("📊 Training metrics plotted!")

def export_rl_model(actor, output_path="policy/transformer/models/rl_model.js"):
    """Export the RL-trained transformer to JavaScript via ``export_to_js.py``.

    The actor's transformer weights are first written to
    ``rl_checkpoints/export_checkpoint.pt``; ``export_to_js.py`` is then run
    with the current Python interpreter to convert that checkpoint.

    Args:
        actor: Object exposing ``transformer.state_dict()`` and
            ``architecture``. When None, nothing is exported.
        output_path: Destination path passed to the export script.

    Returns:
        True when the export script exits with status 0, otherwise False
        (including when ``actor`` is None).
    """

    if actor is None:
        print("❌ Cannot export - actor not available")
        return False

    print(f"🔄 Exporting RL-trained model to JavaScript format...")

    # Save PyTorch checkpoint first
    checkpoint_path = "rl_checkpoints/export_checkpoint.pt"
    # The directory is otherwise only created by the training loop; make sure
    # it exists so exporting works even if training did not run here.
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
    checkpoint = {
        'model_state_dict': actor.transformer.state_dict(),
        'episode': 'rl_export',
        'architecture': actor.architecture,
        'timestamp': datetime.now().isoformat(),
        'training_type': 'reinforcement_learning'
    }

    torch.save(checkpoint, checkpoint_path)
    print(f"✅ Saved RL checkpoint for export: {checkpoint_path}")

    # Use export_to_js.py script (resolved relative to the working directory)
    import subprocess
    result = subprocess.run([
        sys.executable, "export_to_js.py",
        "--checkpoint", checkpoint_path,
        "--output", output_path
    ], capture_output=True, text=True)

    if result.returncode == 0:
        print(f"✅ RL model exported to: {output_path}")
        print("🎉 RL-trained model ready for browser deployment!")
        print(result.stdout)
        return True
    else:
        print(f"❌ Export failed:")
        print(result.stderr)
        return False

def compare_sl_vs_rl(sl_actor, rl_actor, num_episodes=20):
    """Evaluate an SL baseline and the RL actor and print a comparison table.

    Args:
        sl_actor: Only checked for None. The baseline that is evaluated is a
            fresh ``ActorNetwork`` built from the module-level ``checkpoint``
            and ``architecture`` (presumably so that it carries the original
            SL weights -- confirm against the cell that defines them).
        rl_actor: RL-trained actor to evaluate.
        num_episodes: Episodes to run for each model.

    Returns:
        ``(sl_stats, rl_stats)``, the per-episode stats lists of both
        evaluations, or None when either actor is None.
    """

    if sl_actor is None or rl_actor is None:
        print("❌ Cannot compare - need both SL and RL actors")
        return

    print(f"⚔️  Comparing SL vs RL performance ({num_episodes} episodes each)...")

    # Create fresh actors (reload SL weights)
    sl_eval_actor = ActorNetwork(checkpoint, architecture).to(RL_CONFIG['device'])

    # Evaluate both
    print("\n📊 Evaluating SL model...")
    sl_stats, sl_summary = evaluate_rl_agent(sl_eval_actor, None, num_episodes, verbose=False)

    print("\n📊 Evaluating RL model...")
    rl_stats, rl_summary = evaluate_rl_agent(rl_actor, None, num_episodes, verbose=False)

    def _pct_change(metric, floor):
        # Relative change of RL over SL in percent; the SL value is clamped
        # to `floor` so a zero baseline cannot cause a division by zero.
        return ((rl_summary[metric] / max(sl_summary[metric], floor)) - 1) * 100

    # Comparison
    print(f"\n⚔️  SL vs RL Comparison:")
    print(f"  Metric                | SL Model | RL Model | Improvement")
    print(f"  ---------------------|----------|----------|------------")
    print(f"  Avg Catches          | {sl_summary['avg_catches']:8.2f} | {rl_summary['avg_catches']:8.2f} | {_pct_change('avg_catches', 0.001):+7.1f}%")
    print(f"  Success Rate         | {sl_summary['success_rate']:7.1%} | {rl_summary['success_rate']:7.1%} | {_pct_change('success_rate', 0.001):+7.1f}%")
    print(f"  Catch Efficiency     | {sl_summary['avg_efficiency']:8.4f} | {rl_summary['avg_efficiency']:8.4f} | {_pct_change('avg_efficiency', 0.0001):+7.1f}%")
    print(f"  Timeout Rate         | {sl_summary['timeout_rate']:7.1%} | {rl_summary['timeout_rate']:7.1%} | {_pct_change('timeout_rate', 0.001):+7.1f}%")
    print(f"  Avg Episode Length   | {sl_summary['avg_length']:8.1f} | {rl_summary['avg_length']:8.1f} | {_pct_change('avg_length', 0.1):+7.1f}%")

    return sl_stats, rl_stats

# Run evaluation and visualization (needs the trainer from the cells above)
if ppo_trainer is not None and hasattr(ppo_trainer, 'training_stats'):
    print("📊 Running evaluation and generating plots...")

    # Evaluate current model
    eval_stats, eval_summary = evaluate_rl_agent(actor, critic, num_episodes=20)

    # Plot metrics if the quick training test produced any
    if test_metrics is not None:
        plot_training_metrics(test_metrics)

    print("✅ Evaluation completed!")

    # Uncomment below to export RL model
    # print("\n🔄 Exporting RL model...")
    # export_success = export_rl_model(actor)

    # Uncomment below for SL vs RL comparison
    # print("\n⚔️  Running SL vs RL comparison...")
    # sl_stats, rl_stats = compare_sl_vs_rl(actor, actor, num_episodes=10)

else:
    print("❌ Skipping evaluation - training not completed or data not available")

# Closing banner. Only claim readiness when the trainer actually exists;
# otherwise the cells above were skipped and nothing was tested.
print("\n" + "="*60)
if ppo_trainer is not None:
    print("🎯 RL TRAINING PIPELINE READY!")
    print("="*60)
    print("✅ All components initialized and tested")
    print("✅ Environment, PPO trainer, and metrics ready")
    print("✅ Quick test completed successfully")
else:
    print("❌ RL TRAINING PIPELINE NOT READY")
    print("="*60)
    print("❌ PPO trainer not available - fix the errors reported above and re-run")
print("")
print("🚀 To start full training, uncomment and run:")
print("   # full_metrics = train_rl_agent(ppo_trainer, num_updates=500)")
print("")
print("📊 Key metrics to monitor during training:")
print("  - Success rate (episodes with ≥1 catch)")
print("  - Average catches per episode")
print("  - Catch efficiency (catches per step)")
print("  - Episode length and timeout rate")
print("")
print("🎮 After training, export model with:")
print("   # export_rl_model(actor)")
print("="*60)
