In [None]:
# Enhanced Transformer RL Training - Fixed Implementation
# Download Codebase from GitHub
import os
import subprocess
import sys
from pathlib import Path

# Repository information
REPO_URL = "https://github.com/yimianxyz/homepage.git"
BRANCH = "neuro-predator"  # Using the main branch with transformer implementation
REPO_DIR = "homepage"

def download_codebase():
    """Download the codebase from GitHub if not already present"""
    
    if os.path.exists(REPO_DIR):
        print(f"Repository directory '{REPO_DIR}' already exists.")
        
        try:
            os.chdir(REPO_DIR)
            
            # Check current branch
            result = subprocess.run(['git', 'branch', '--show-current'],
                                  capture_output=True, text=True, check=True)
            current_branch = result.stdout.strip()
            
            if current_branch != BRANCH:
                print(f"Switching to branch '{BRANCH}'...")
                subprocess.run(['git', 'checkout', BRANCH], check=True)
            
            # Pull latest changes
            print("Updating repository...")
            subprocess.run(['git', 'pull', 'origin', BRANCH], check=True)
            
            print(f"✅ Repository updated successfully!")
            
        except subprocess.CalledProcessError as e:
            print(f"❌ Error updating repository: {e}")
            print("Repository directory exists but may not be a valid git repository.")
    
    else:
        print(f"Cloning repository from {REPO_URL} (branch: {BRANCH})...")
        
        try:
            # Clone the specific branch
            subprocess.run(['git', 'clone', '-b', BRANCH, REPO_URL, REPO_DIR], check=True)
            print(f"✅ Repository cloned successfully!")
            
            # Change to repository directory
            os.chdir(REPO_DIR)
            
        except subprocess.CalledProcessError as e:
            print(f"❌ Error cloning repository: {e}")
            print("Make sure you have git installed and internet connection.")
            return False
    
    # Verify key files exist
    key_files = [
        'config/constants.py',
        'simulation/processors/input_processor.py',
        'simulation/state_manager/state_manager.py',
        'simulation/random_state_generator/random_state_generator.py',
        'policy/human_prior/closest_pursuit_policy.py',
        'export_to_js.py'
    ]
    
    missing_files = []
    for file_path in key_files:
        if not os.path.exists(file_path):
            missing_files.append(file_path)
    
    if missing_files:
        print(f"⚠️  Warning: Some key files are missing:")
        for file_path in missing_files:
            print(f"  - {file_path}")
        return False
    
    print(f"✅ All key files found!")
    print(f"📁 Working directory: {os.getcwd()}")
    
    return True

# Download the codebase
success = download_codebase()

if success:
    print("\n🎉 Setup complete! Ready for enhanced RL training.")
else:
    print("❌ Setup failed. Please check the errors above and try again.")


In [None]:
# Configuration and Imports - Single Source of Truth
import sys
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from collections import deque, defaultdict
import random
import time
from datetime import datetime
import json
from pathlib import Path
import math
from typing import Dict, List, Any, Tuple, Optional

# Ensure we're in the correct directory and add to Python path
project_root = Path.cwd()
if project_root.name != 'homepage':
    print(f"⚠️  Warning: Current directory is '{project_root.name}', expected 'homepage'")

if str(project_root) not in sys.path:
    sys.path.append(str(project_root))

# Import project modules
try:
    from config.constants import CONSTANTS
    from simulation.processors import InputProcessor, ActionProcessor
    from simulation.state_manager import StateManager
    from simulation.random_state_generator import RandomStateGenerator
    from policy.human_prior.closest_pursuit_policy import create_closest_pursuit_policy
    
    print(f"✅ Successfully imported all simulation modules")
    print(f"📁 Project root: {project_root}")
    print(f"🔧 Key constants: MAX_DISTANCE={CONSTANTS.MAX_DISTANCE}, BOID_MAX_SPEED={CONSTANTS.BOID_MAX_SPEED}")
    
except ImportError as e:
    print(f"❌ Failed to import modules: {e}")
    print("Make sure the repository was downloaded correctly in the first cell.")
    raise

# ===== SINGLE SOURCE OF TRUTH CONFIGURATION =====
class RLConfig:
    """Centralized configuration for RL training - single source of truth"""
    
    # Device Configuration
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # Model Architecture (loaded from checkpoint)
    MODEL_ARCHITECTURE = {
        'd_model': 128,
        'n_heads': 8,
        'n_layers': 4,
        'ffn_hidden': 512,
        'max_boids': 50
    }
    
    # PPO Hyperparameters
    LEARNING_RATE = 3e-4        # Increased from 5e-5 for better fine-tuning
    CLIP_EPSILON = 0.15         # Slightly conservative for fine-tuning
    ENTROPY_COEF = 0.01         # Reduced for pre-trained deterministic policy
    VALUE_COEF = 0.5           # Standard value loss coefficient
    MAX_GRAD_NORM = 0.5        # Gradient clipping
    PPO_EPOCHS = 4             # Number of PPO optimization epochs
    MINI_BATCH_SIZE = 64       # FIXED: Consistent mini-batch size
    ROLLOUT_STEPS = 2048       # Steps per rollout
    GAMMA = 0.99               # Discount factor
    GAE_LAMBDA = 0.95          # GAE lambda parameter
    
    # Actor Network Parameters
    EXPLORATION_STD_INIT = 0.3  # Initial exploration standard deviation
    EXPLORATION_STD_MIN = 0.05  # Minimum exploration standard deviation
    EXPLORATION_STD_MAX = 0.8   # Maximum exploration standard deviation
    
    # Environment Settings
    MIN_BOIDS = 1
    MAX_BOIDS = 50
    MIN_CANVAS_WIDTH = 400
    MAX_CANVAS_WIDTH = 1600
    MIN_CANVAS_HEIGHT = 400
    MAX_CANVAS_HEIGHT = 1200
    TIMEOUT_MULTIPLIER = 3.0    # Adaptive timeout calculation
    
    # Reward Settings (keeping multi-step attribution as requested)
    CATCH_REWARD = 1.0
    REWARD_DECAY_RATE = 0.05    # For exponential decay over reward window
    REWARD_WINDOW = 50          # Steps before catch to attribute reward
    
    # Training Settings
    NUM_EPISODES = 5000
    LOG_INTERVAL = 50           # Log every N rollouts
    SAVE_INTERVAL = 200         # Save every N rollouts
    EVAL_INTERVAL = 100         # Evaluate every N rollouts
    EVAL_EPISODES = 30          # Episodes per evaluation
    
    # Checkpoints
    SL_CHECKPOINT_PATH = "checkpoints/best_model.pt"
    RL_CHECKPOINT_DIR = "rl_checkpoints"
    
    @classmethod
    def get_dict(cls):
        """Get configuration as dictionary for compatibility"""
        return {attr: getattr(cls, attr) for attr in dir(cls) 
                if not attr.startswith('_') and not callable(getattr(cls, attr))}
    
    @classmethod
    def print_config(cls):
        """Print current configuration"""
        print("🚀 RL Training Configuration:")
        print(f"  Device: {cls.DEVICE}")
        print(f"  Learning Rate: {cls.LEARNING_RATE}")
        print(f"  Mini-batch Size: {cls.MINI_BATCH_SIZE}")
        print(f"  Rollout Steps: {cls.ROLLOUT_STEPS}")
        print(f"  Environment: {cls.MIN_BOIDS}-{cls.MAX_BOIDS} boids")
        print(f"  Canvas: {cls.MIN_CANVAS_WIDTH}x{cls.MIN_CANVAS_HEIGHT} to {cls.MAX_CANVAS_WIDTH}x{cls.MAX_CANVAS_HEIGHT}")
        print(f"  Reward: {cls.CATCH_REWARD} per catch, {cls.REWARD_WINDOW}-step attribution window")
        print(f"  Exploration: std [{cls.EXPLORATION_STD_MIN}, {cls.EXPLORATION_STD_MAX}], init {cls.EXPLORATION_STD_INIT}")

# Initialize configuration
config = RLConfig()
config.print_config()


In [None]:
# Enhanced Model Architectures - Fixed Implementation

class GEGLU(nn.Module):
    """Gated Linear Unit with GELU activation"""
    def forward(self, x):
        x, gate = x.chunk(2, dim=-1)
        return x * torch.nn.functional.gelu(gate)

class TransformerLayer(nn.Module):
    """Enhanced transformer layer with proper initialization"""
    def __init__(self, d_model, n_heads, ffn_hidden, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads

        self.norm1 = nn.LayerNorm(d_model)
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)

        self.norm2 = nn.LayerNorm(d_model)

        # GEGLU FFN with separate projections for export compatibility
        self.ffn_gate_proj = nn.Linear(d_model, ffn_hidden)
        self.ffn_up_proj = nn.Linear(d_model, ffn_hidden)
        self.ffn_down_proj = nn.Linear(ffn_hidden, d_model)

    def forward(self, x, padding_mask=None):
        # Self-attention with residual
        normed = self.norm1(x)
        attn_out, _ = self.self_attn(normed, normed, normed, key_padding_mask=padding_mask)
        x = x + attn_out

        # FFN with residual
        normed = self.norm2(x)
        gate = torch.nn.functional.gelu(self.ffn_gate_proj(normed))
        up = self.ffn_up_proj(normed)
        ffn_out = self.ffn_down_proj(gate * up)
        x = x + ffn_out

        return x

class TransformerPredictor(nn.Module):
    """Enhanced transformer predictor with better initialization"""
    def __init__(self, d_model=128, n_heads=8, n_layers=4, ffn_hidden=512, max_boids=50, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.ffn_hidden = ffn_hidden
        self.max_boids = max_boids

        # CLS token embedding
        self.cls_embedding = nn.Parameter(torch.randn(d_model))

        # Type embeddings
        self.type_embeddings = nn.ParameterDict({
            'cls': nn.Parameter(torch.randn(d_model)),
            'ctx': nn.Parameter(torch.randn(d_model)),
            'predator': nn.Parameter(torch.randn(d_model)),
            'boid': nn.Parameter(torch.randn(d_model))
        })

        # Input projections
        self.ctx_projection = nn.Linear(2, d_model)
        self.predator_projection = nn.Linear(4, d_model)
        self.boid_projection = nn.Linear(4, d_model)

        # Transformer layers
        self.transformer_layers = nn.ModuleList([
            TransformerLayer(d_model, n_heads, ffn_hidden, dropout)
            for _ in range(n_layers)
        ])

        # Output projection
        self.output_projection = nn.Linear(d_model, 2)

    def forward(self, structured_inputs, padding_mask=None):
        batch_size = len(structured_inputs) if isinstance(structured_inputs, list) else 1

        # Handle single sample vs batch
        if isinstance(structured_inputs, dict):
            structured_inputs = [structured_inputs]
            batch_size = 1

        # Build token sequences for each sample in batch
        sequences = []
        masks = []

        for sample in structured_inputs:
            tokens = []

            # CLS token
            cls_token = self.cls_embedding + self.type_embeddings['cls']
            tokens.append(cls_token)

            # Context token
            ctx_input = torch.tensor([sample['context']['canvasWidth'], sample['context']['canvasHeight']],
                                   dtype=torch.float32, device=self.cls_embedding.device)
            ctx_token = self.ctx_projection(ctx_input) + self.type_embeddings['ctx']
            tokens.append(ctx_token)

            # Predator token - expand to 4D
            predator_input = torch.tensor([sample['predator']['velX'], sample['predator']['velY'], 0.0, 0.0],
                                        dtype=torch.float32, device=self.cls_embedding.device)
            predator_token = self.predator_projection(predator_input) + self.type_embeddings['predator']
            tokens.append(predator_token)

            # Boid tokens
            sample_mask = [False, False, False]  # CLS, CTX, Predator are not padding

            for boid in sample['boids']:
                boid_input = torch.tensor([boid['relX'], boid['relY'], boid['velX'], boid['velY']],
                                        dtype=torch.float32, device=self.cls_embedding.device)
                boid_token = self.boid_projection(boid_input) + self.type_embeddings['boid']
                tokens.append(boid_token)
                sample_mask.append(False)

            # Pad to max_boids + 3 (CLS + CTX + Predator)
            while len(tokens) < self.max_boids + 3:
                padding_token = torch.zeros(self.d_model, device=self.cls_embedding.device)
                tokens.append(padding_token)
                sample_mask.append(True)  # Mark as padding

            sequences.append(torch.stack(tokens))
            masks.append(sample_mask)

        # Stack sequences
        x = torch.stack(sequences)  # [batch_size, seq_len, d_model]

        # Create padding mask
        if padding_mask is None:
            padding_mask = torch.tensor(masks, dtype=torch.bool, device=x.device)

        # Pass through transformer layers
        for layer in self.transformer_layers:
            x = layer(x, padding_mask)

        # Extract CLS token and project to output
        cls_output = x[:, 0]  # [batch_size, d_model]
        action = self.output_projection(cls_output)  # [batch_size, 2]

        # Apply tanh to ensure [-1, 1] range
        action = torch.tanh(action)

        return action.squeeze(0) if batch_size == 1 else action

class EnhancedActorNetwork(nn.Module):
    """FIXED: Enhanced actor network with proper stochastic policy"""

    def __init__(self, checkpoint, architecture):
        super().__init__()

        # Create transformer with same architecture as SL model
        self.transformer = TransformerPredictor(
            d_model=architecture['d_model'],
            n_heads=architecture['n_heads'],
            n_layers=architecture['n_layers'],
            ffn_hidden=architecture['ffn_hidden'],
            max_boids=architecture['max_boids'],
            dropout=0.1
        )

        # FIXED: Learnable exploration parameters
        self.log_std = nn.Parameter(
            torch.ones(2) * math.log(RLConfig.EXPLORATION_STD_INIT)
        )

        # Load pretrained weights
        if checkpoint is not None:
            self.transformer.load_state_dict(checkpoint['model_state_dict'])
            print("✅ Loaded pretrained SL weights into actor")

        # Store architecture info
        self.architecture = architecture

    def forward(self, structured_inputs):
        """Forward pass through transformer"""
        return self.transformer(structured_inputs)

    def get_action_and_log_prob(self, structured_inputs):
        """FIXED: Get stochastic action and log probability for proper RL exploration"""
        # Get mean action from transformer
        action_mean = self.forward(structured_inputs)

        # FIXED: Proper learnable exploration with clamping
        log_std = torch.clamp(
            self.log_std, 
            min=math.log(RLConfig.EXPLORATION_STD_MIN),
            max=math.log(RLConfig.EXPLORATION_STD_MAX)
        )
        std = torch.exp(log_std)

        # Create distribution and sample
        dist = torch.distributions.Normal(action_mean, std)
        action = dist.sample()
        log_prob = dist.log_prob(action).sum(dim=-1)

        # Apply tanh to ensure bounded output and adjust log_prob
        action_tanh = torch.tanh(action)
        
        # Adjust log probability for tanh transformation
        log_prob = log_prob - torch.log(1 - action_tanh.pow(2) + 1e-7).sum(dim=-1)

        return action_tanh, log_prob

class EnhancedCriticNetwork(nn.Module):
    """FIXED: Enhanced critic network with better feature processing"""

    def __init__(self, input_dim=22, hidden_dims=[256, 256, 128]):
        super().__init__()

        layers = []
        prev_dim = input_dim

        for hidden_dim in hidden_dims:
            layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.ReLU(),
                nn.LayerNorm(hidden_dim),  # Added layer norm for stability
                nn.Dropout(0.1)
            ])
            prev_dim = hidden_dim

        layers.append(nn.Linear(prev_dim, 1))  # Single value output

        self.network = nn.Sequential(*layers)

        # Initialize weights
        self._initialize_weights()

    def _initialize_weights(self):
        """Initialize network weights"""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(self, state_features):
        """Estimate value from state features"""
        return self.network(state_features).squeeze(-1)

def extract_state_features(structured_inputs):
    """FIXED: Enhanced feature extraction with better error handling"""
    batch_size = len(structured_inputs) if isinstance(structured_inputs, list) else 1

    if isinstance(structured_inputs, dict):
        structured_inputs = [structured_inputs]

    features = []

    for sample in structured_inputs:
        # Context features (2D)
        context_feat = [sample['context']['canvasWidth'], sample['context']['canvasHeight']]

        # Predator features (2D)
        predator_feat = [sample['predator']['velX'], sample['predator']['velY']]

        # FIXED: Robust boid features with error handling
        boids = sample['boids']
        if len(boids) > 0:
            # Statistical aggregation of boids with safe operations
            rel_x = [b['relX'] for b in boids]
            rel_y = [b['relY'] for b in boids]
            vel_x = [b['velX'] for b in boids]
            vel_y = [b['velY'] for b in boids]

            # Safe statistics calculation
            boid_feat = [
                len(boids),  # Number of boids
                np.mean(rel_x), max(np.std(rel_x), 1e-8), np.min(rel_x), np.max(rel_x),
                np.mean(rel_y), max(np.std(rel_y), 1e-8), np.min(rel_y), np.max(rel_y),
                np.mean(vel_x), max(np.std(vel_x), 1e-8), np.min(vel_x), np.max(vel_x),
                np.mean(vel_y), max(np.std(vel_y), 1e-8), np.min(vel_y), np.max(vel_y),
                # Distance to closest boid
                np.min([math.sqrt(b['relX']**2 + b['relY']**2) for b in boids])
            ]
        else:
            # No boids remaining - use zeros
            boid_feat = [0] + [0.0] * 16  # 17 features total

        # Combine all features (2 + 2 + 17 = 21 features) + 1 extra = 22
        sample_feat = context_feat + predator_feat + boid_feat + [len(boids) / RLConfig.MAX_BOIDS]  # Normalized boid count
        features.append(sample_feat)

    return torch.tensor(features, dtype=torch.float32, device=RLConfig.DEVICE)

print("✅ Enhanced model architectures defined with fixes:")
print("  - Fixed stochastic actor policy with learnable exploration")
print("  - Enhanced critic network with layer normalization")  
print("  - Robust feature extraction with error handling")
print("  - Proper log probability computation with tanh correction")


In [None]:
# Load Supervised Learning Model and Initialize Networks

def load_sl_checkpoint(checkpoint_path: str = None):
    """Load supervised learning checkpoint with better error handling"""
    
    if checkpoint_path is None:
        checkpoint_path = RLConfig.SL_CHECKPOINT_PATH
    
    if not os.path.exists(checkpoint_path):
        print(f"❌ Checkpoint not found: {checkpoint_path}")
        print("Available checkpoints:")
        if os.path.exists("checkpoints/"):
            checkpoints = list(Path("checkpoints/").glob("*.pt"))
            for cp in checkpoints:
                print(f"  - {cp}")
        else:
            print("  No checkpoints directory found")
        return None, None

    print(f"Loading SL checkpoint from {checkpoint_path}...")
    checkpoint = torch.load(checkpoint_path, map_location=RLConfig.DEVICE)

    # Extract and validate architecture parameters
    if 'architecture' in checkpoint:
        arch = checkpoint['architecture']
        print(f"✅ Found architecture in checkpoint:")
        print(f"  d_model: {arch['d_model']}")
        print(f"  n_heads: {arch['n_heads']}")
        print(f"  n_layers: {arch['n_layers']}")
        print(f"  ffn_hidden: {arch['ffn_hidden']}")
        print(f"  max_boids: {arch['max_boids']}")
        
        # Update config with loaded architecture
        RLConfig.MODEL_ARCHITECTURE.update(arch)
    else:
        print("❌ No architecture found in checkpoint - using default values")
        arch = RLConfig.MODEL_ARCHITECTURE

    # Additional info
    epoch = checkpoint.get('epoch', 'unknown')
    val_loss = checkpoint.get('best_val_loss', 'unknown')
    print(f"  Checkpoint epoch: {epoch}")
    print(f"  Best validation loss: {val_loss}")

    return checkpoint, arch

def initialize_networks(checkpoint, architecture):
    """Initialize actor and critic networks with proper error handling"""
    
    print("🧠 Initializing networks...")
    
    # Create actor and critic networks
    actor = EnhancedActorNetwork(checkpoint, architecture).to(RLConfig.DEVICE)
    critic = EnhancedCriticNetwork(input_dim=22).to(RLConfig.DEVICE)

    # Count parameters
    actor_params = sum(p.numel() for p in actor.parameters())
    critic_params = sum(p.numel() for p in critic.parameters())
    actor_trainable = sum(p.numel() for p in actor.parameters() if p.requires_grad)
    critic_trainable = sum(p.numel() for p in critic.parameters() if p.requires_grad)

    print(f"📊 Network Statistics:")
    print(f"  Actor parameters: {actor_params:,} ({actor_trainable:,} trainable)")
    print(f"  Critic parameters: {critic_params:,} ({critic_trainable:,} trainable)")
    print(f"  Total parameters: {actor_params + critic_params:,}")
    
    # Exploration info
    with torch.no_grad():
        current_std = torch.exp(actor.log_std).mean().item()
    print(f"  Initial exploration std: {current_std:.3f}")
    
    return actor, critic

def test_networks(actor, critic):
    """Test network forward passes with comprehensive validation"""
    
    print("🧪 Testing network forward passes...")
    
    # Create test input with validation
    test_input = {
        'context': {'canvasWidth': 0.8, 'canvasHeight': 0.6},
        'predator': {'velX': 0.1, 'velY': -0.2},
        'boids': [
            {'relX': 0.1, 'relY': 0.3, 'velX': 0.5, 'velY': -0.1},
            {'relX': -0.2, 'relY': 0.1, 'velX': -0.3, 'velY': 0.4}
        ]
    }
    
    # Test with different batch sizes
    test_cases = [
        ("Single sample", test_input),
        ("Empty boids", {
            'context': {'canvasWidth': 0.5, 'canvasHeight': 0.4},
            'predator': {'velX': 0.0, 'velY': 0.0},
            'boids': []
        }),
        ("Many boids", {
            'context': {'canvasWidth': 1.0, 'canvasHeight': 1.0},
            'predator': {'velX': 0.2, 'velY': -0.1},
            'boids': [
                {'relX': 0.1, 'relY': 0.1, 'velX': 0.1, 'velY': 0.1},
                {'relX': -0.1, 'relY': -0.1, 'velX': -0.1, 'velY': -0.1},
                {'relX': 0.2, 'relY': -0.2, 'velX': 0.2, 'velY': -0.2}
            ]
        })
    ]
    
    for test_name, test_case in test_cases:
        try:
            with torch.no_grad():
                # Test actor
                action_mean = actor.forward(test_case)
                action, log_prob = actor.get_action_and_log_prob(test_case)
                
                # Test critic
                features = extract_state_features(test_case)
                value = critic(features)
                
                print(f"  ✅ {test_name}:")
                print(f"    Input boids: {len(test_case['boids'])}")
                print(f"    Action mean: [{action_mean[0].item():.3f}, {action_mean[1].item():.3f}]")
                print(f"    Sampled action: [{action[0].item():.3f}, {action[1].item():.3f}]")
                print(f"    Log prob: {log_prob.item():.3f}")
                print(f"    Features shape: {features.shape}")
                print(f"    Value: {value.item():.3f}")
                
        except Exception as e:
            print(f"  ❌ {test_name} failed: {e}")
            raise
    
    print("✅ All network tests passed!")
    return True

# Load checkpoint and initialize networks
print("🚀 Loading SL checkpoint and initializing networks...")

try:
    checkpoint, architecture = load_sl_checkpoint()
    
    if checkpoint is not None:
        actor, critic = initialize_networks(checkpoint, architecture)
        test_success = test_networks(actor, critic)
        
        if test_success:
            print("\n✅ Networks successfully initialized and tested!")
            print("📱 Ready for RL environment setup and training.")
        else:
            raise RuntimeError("Network tests failed")
    else:
        print("❌ Cannot proceed without SL checkpoint")
        actor = critic = None
        
except Exception as e:
    print(f"❌ Error during network initialization: {e}")
    print("Please ensure the supervised learning checkpoint exists and is valid.")
    actor = critic = None


In [None]:
# FIXED: Enhanced RL Environment Wrapper

class CustomRLPolicy:
    """Custom policy wrapper for RL agent integration with StateManager"""
    
    def __init__(self, action):
        self.action = action if isinstance(action, list) else [action[0].item(), action[1].item()]
    
    def get_action(self, structured_inputs):
        """Return the precomputed action"""
        return self.action

class EnhancedPredatorEnvironment:
    """FIXED: Enhanced RL environment wrapper with proper StateManager integration"""

    def __init__(self):
        # Initialize simulation components using proper infrastructure
        self.state_manager = StateManager()
        self.random_generator = RandomStateGenerator()
        self.input_processor = InputProcessor()
        self.action_processor = ActionProcessor()

        # Environment state tracking
        self.current_step = 0
        self.max_steps = 0
        self.initial_boids_count = 0
        self.episode_catches = []  # Steps when catches occurred
        self.step_rewards = []     # Rewards for each step
        
        # Episode tracking
        self.episode_count = 0
        self.reset_episode_stats()

    def reset_episode_stats(self):
        """Reset episode-level statistics"""
        self.episode_catches = []
        self.step_rewards = []
        self.current_step = 0

    def calculate_adaptive_timeout(self, canvas_width, canvas_height, num_boids):
        """Calculate adaptive timeout based on environment complexity"""
        canvas_area = canvas_width * canvas_height
        base_timeout = (canvas_area * num_boids * RLConfig.TIMEOUT_MULTIPLIER) / 10000
        return max(int(base_timeout), 300)  # Minimum 300 steps for better learning

    def reset(self):
        """Reset environment for new episode with enhanced randomization"""
        self.reset_episode_stats()
        self.episode_count += 1

        # Generate random environment parameters
        num_boids = random.randint(RLConfig.MIN_BOIDS, RLConfig.MAX_BOIDS)
        canvas_width = random.randint(RLConfig.MIN_CANVAS_WIDTH, RLConfig.MAX_CANVAS_WIDTH)
        canvas_height = random.randint(RLConfig.MIN_CANVAS_HEIGHT, RLConfig.MAX_CANVAS_HEIGHT)

        # Generate random initial state with better distribution
        initial_state = self.random_generator.generate_scattered_state(
            num_boids, canvas_width, canvas_height
        )

        # FIXED: Use StateManager properly with dummy policy for initialization
        dummy_policy = create_closest_pursuit_policy()
        self.state_manager.init(initial_state, dummy_policy)

        # Set episode parameters
        self.initial_boids_count = num_boids
        self.max_steps = self.calculate_adaptive_timeout(canvas_width, canvas_height, num_boids)
        self.current_step = 0

        # Get initial observation using proper conversion
        current_state = self.state_manager.get_state()
        observation = self._state_to_structured_inputs(current_state)

        info = {
            'episode': self.episode_count,
            'canvas_width': canvas_width,
            'canvas_height': canvas_height,
            'initial_boids': num_boids,
            'max_steps': self.max_steps,
            'episode_reset': True
        }

        return observation, info

    def _state_to_structured_inputs(self, state):
        """FIXED: Convert state to structured inputs with validation"""
        try:
            # Extract data from state with validation
            boids = state.get('boids_states', [])
            predator_state = state.get('predator_state', {})
            predator_pos = predator_state.get('position', {'x': 0, 'y': 0})
            predator_vel = predator_state.get('velocity', {'x': 0, 'y': 0})
            canvas_width = state.get('canvas_width', 800)
            canvas_height = state.get('canvas_height', 600)

            # Use input processor for proper conversion
            structured_inputs = self.input_processor.process_inputs(
                boids, predator_pos, predator_vel, canvas_width, canvas_height
            )

            return structured_inputs
            
        except Exception as e:
            print(f"⚠️ Error converting state: {e}")
            # Return safe default
            return {
                'context': {'canvasWidth': 0.5, 'canvasHeight': 0.5},
                'predator': {'velX': 0.0, 'velY': 0.0},
                'boids': []
            }

    def step(self, action):
        """FIXED: Take environment step using StateManager properly"""
        # Convert action to proper format
        if torch.is_tensor(action):
            action_list = [action[0].item(), action[1].item()]
        else:
            action_list = list(action)

        # Store current state for comparison
        prev_state = self.state_manager.get_state()
        boids_before = len(prev_state.get('boids_states', []))

        # FIXED: Use StateManager with custom policy instead of direct simulation access
        try:
            custom_policy = CustomRLPolicy(action_list)
            
            # Let StateManager handle the step properly
            # Replace the policy temporarily for this step
            old_policy = self.state_manager.policy
            self.state_manager.policy = custom_policy
            
            # Execute step using StateManager's proper infrastructure
            new_state = self.state_manager.step()
            
            # Restore original policy
            self.state_manager.policy = old_policy
            
        except Exception as e:
            print(f"⚠️ Error in step execution: {e}")
            # Fallback: use previous state
            new_state = prev_state

        # Calculate catches by comparing boid counts
        boids_after = len(new_state.get('boids_states', []))
        catches_this_step = max(0, boids_before - boids_after)

        # Track catches for reward calculation
        if catches_this_step > 0:
            for _ in range(catches_this_step):
                self.episode_catches.append(self.current_step)

        self.current_step += 1

        # Calculate reward using the multi-step attribution system (as requested)
        reward = self._calculate_reward()
        self.step_rewards.append(reward)

        # Check termination conditions
        done = False
        termination_reason = 'ongoing'

        if boids_after == 0:
            done = True
            termination_reason = 'all_caught'
        elif self.current_step >= self.max_steps:
            done = True
            termination_reason = 'timeout'

        # Get next observation
        observation = self._state_to_structured_inputs(new_state)

        # Comprehensive info dictionary
        info = {
            'step': self.current_step,
            'catches_this_step': catches_this_step,
            'total_catches': len(self.episode_catches),
            'boids_remaining': boids_after,
            'boids_before': boids_before,
            'done': done,
            'termination_reason': termination_reason,
            'reward': reward,
            'episode_reward': sum(self.step_rewards),
            'max_steps': self.max_steps,
            'episode': self.episode_count
        }

        return observation, reward, done, info

    def _calculate_reward(self):
        """Multi-step reward attribution as requested (keeping the original design)"""
        total_reward = 0.0

        # Only calculate reward if we have catches
        if len(self.episode_catches) == 0:
            return 0.0

        # For each catch, attribute reward to the last N steps (as per original design)
        for catch_step in self.episode_catches:
            # Define the reward window (steps before the catch)
            reward_start = max(0, catch_step - RLConfig.REWARD_WINDOW + 1)
            reward_end = catch_step + 1

            # Only attribute reward to current step if it's within the window
            if reward_start <= self.current_step - 1 < reward_end:
                # Calculate steps before catch (for decay)
                steps_before_catch = catch_step - (self.current_step - 1)

                # Exponential decay: more recent actions get higher reward
                decay_factor = math.exp(-RLConfig.REWARD_DECAY_RATE * steps_before_catch)
                step_reward = RLConfig.CATCH_REWARD * decay_factor

                total_reward += step_reward

        return total_reward

def test_environment():
    """Test the enhanced environment with comprehensive validation"""
    print("🌍 Testing enhanced RL environment...")
    
    if actor is None:
        print("❌ Cannot test environment - actor not available")
        return False

    # Create test environment
    test_env = EnhancedPredatorEnvironment()

    try:
        # Test reset
        obs, info = test_env.reset()
        print(f"  ✅ Environment reset successful:")
        print(f"    Canvas: {info['canvas_width']}x{info['canvas_height']}")
        print(f"    Boids: {info['initial_boids']}")
        print(f"    Max steps: {info['max_steps']}")
        print(f"    Observation boids: {len(obs['boids'])}")

        # Test multiple steps
        print(f"  🔄 Testing environment steps:")
        for i in range(5):
            with torch.no_grad():
                action, log_prob = actor.get_action_and_log_prob(obs)

            obs, reward, done, step_info = test_env.step(action)

            print(f"    Step {i+1}: action=[{action[0]:.3f}, {action[1]:.3f}], "
                  f"reward={reward:.3f}, boids={step_info['boids_remaining']}, "
                  f"catches={step_info['catches_this_step']}")

            if done:
                print(f"    Episode done: {step_info['termination_reason']}")
                break

        print("  ✅ Environment test completed successfully!")
        return True
        
    except Exception as e:
        print(f"  ❌ Environment test failed: {e}")
        import traceback
        traceback.print_exc()
        return False

# Test the enhanced environment
if actor is not None:
    env_test_success = test_environment()
    if env_test_success:
        print("\n✅ Enhanced environment ready for training!")
        print("🔧 Key fixes implemented:")
        print("  - Proper StateManager integration instead of direct simulation access")
        print("  - Robust error handling and fallback mechanisms")  
        print("  - Enhanced state conversion with validation")
        print("  - Comprehensive step information tracking")
        print("  - Multi-step reward attribution system maintained")
    else:
        print("❌ Environment test failed - check implementation")
else:
    print("❌ Skipping environment test - actor not available")


In [None]:
# FIXED: Enhanced PPO Implementation with Proper GAE

class ExperienceBuffer:
    """FIXED: Memory-efficient experience buffer with proper trajectory handling"""
    
    def __init__(self, buffer_size):
        self.buffer_size = buffer_size
        self.reset()
    
    def reset(self):
        """Reset the buffer"""
        self.observations = []
        self.actions = []
        self.rewards = []
        self.values = []
        self.log_probs = []
        self.dones = []
        self.episode_starts = []  # FIXED: Track episode boundaries
        self.size = 0
    
    def add(self, obs, action, reward, value, log_prob, done, episode_start=False):
        """Add experience to buffer"""
        self.observations.append(obs)
        self.actions.append(action)
        self.rewards.append(reward)
        self.values.append(value)
        self.log_probs.append(log_prob)
        self.dones.append(done)
        self.episode_starts.append(episode_start)
        self.size += 1
    
    def get_batch(self):
        """Get all experiences as tensors"""
        return {
            'observations': self.observations,
            'actions': torch.stack(self.actions),
            'rewards': torch.tensor(self.rewards, dtype=torch.float32, device=RLConfig.DEVICE),
            'values': torch.stack(self.values),
            'log_probs': torch.stack(self.log_probs),
            'dones': torch.tensor(self.dones, dtype=torch.float32, device=RLConfig.DEVICE),
            'episode_starts': torch.tensor(self.episode_starts, dtype=torch.bool, device=RLConfig.DEVICE)
        }
    
    def is_full(self):
        """Check if buffer is full"""
        return self.size >= self.buffer_size

class EnhancedPPOTrainer:
    """FIXED: Enhanced PPO trainer with proper GAE and trajectory handling"""

    def __init__(self, actor, critic):
        self.actor = actor
        self.critic = critic

        # Optimizers with different learning rates for fine-tuning
        self.actor_optimizer = optim.AdamW(
            actor.parameters(), 
            lr=RLConfig.LEARNING_RATE,
            weight_decay=0.01
        )
        self.critic_optimizer = optim.AdamW(
            critic.parameters(), 
            lr=RLConfig.LEARNING_RATE * 2.0,  # Slightly higher for critic
            weight_decay=0.01
        )

        # Learning rate schedulers
        self.actor_scheduler = optim.lr_scheduler.LinearLR(
            self.actor_optimizer, start_factor=1.0, end_factor=0.1, total_iters=1000
        )
        self.critic_scheduler = optim.lr_scheduler.LinearLR(
            self.critic_optimizer, start_factor=1.0, end_factor=0.1, total_iters=1000
        )

        # Experience buffer
        self.buffer = ExperienceBuffer(RLConfig.ROLLOUT_STEPS)

        # Training metrics
        self.training_stats = defaultdict(list)
        self.episode_stats = []

    def collect_rollout(self, env):
        """FIXED: Collect rollout with proper episode boundary tracking"""
        self.buffer.reset()
        
        # Initialize environment
        obs, info = env.reset()
        episode_stats = []
        current_episode = {
            'reward': 0,
            'length': 0,
            'catches': 0,
            'termination': None
        }

        episode_start = True

        while not self.buffer.is_full():
            # Get action and value
            with torch.no_grad():
                action, log_prob = self.actor.get_action_and_log_prob(obs)
                state_features = extract_state_features(obs)
                value = self.critic(state_features)

            # Take step in environment
            next_obs, reward, done, step_info = env.step(action)

            # Add to buffer
            self.buffer.add(obs, action, reward, value, log_prob, done, episode_start)
            episode_start = False

            # Update episode stats
            current_episode['reward'] += reward
            current_episode['length'] += 1
            current_episode['catches'] += step_info['catches_this_step']

            obs = next_obs

            if done:
                # Episode finished
                current_episode['termination'] = step_info['termination_reason']
                episode_stats.append(current_episode.copy())
                
                # Reset for next episode
                obs, info = env.reset()
                episode_start = True
                current_episode = {
                    'reward': 0,
                    'length': 0,
                    'catches': 0,
                    'termination': None
                }

        # FIXED: Calculate returns and advantages with proper GAE
        self._calculate_gae(next_obs)

        return episode_stats

    def _calculate_gae(self, final_obs):
        """FIXED: Proper Generalized Advantage Estimation implementation"""
        batch = self.buffer.get_batch()
        
        # Get final value estimate
        with torch.no_grad():
            final_state_features = extract_state_features(final_obs)
            final_value = self.critic(final_state_features).item()

        # Convert to numpy for easier computation
        rewards = batch['rewards'].cpu().numpy()
        values = batch['values'].squeeze().cpu().numpy()
        dones = batch['dones'].cpu().numpy()
        episode_starts = batch['episode_starts'].cpu().numpy()

        # FIXED: Proper GAE computation with episode boundary handling
        advantages = np.zeros_like(rewards)
        returns = np.zeros_like(rewards)
        
        # Calculate GAE
        gae = 0
        for step in reversed(range(len(rewards))):
            # Handle episode boundaries
            if step == len(rewards) - 1:
                next_value = final_value if not dones[step] else 0
            else:
                next_value = values[step + 1] if not dones[step] else 0
            
            # Reset GAE at episode boundaries
            if episode_starts[step] and step > 0:
                gae = 0
            
            # TD error
            delta = rewards[step] + RLConfig.GAMMA * next_value - values[step]
            
            # GAE computation
            gae = delta + RLConfig.GAMMA * RLConfig.GAE_LAMBDA * gae * (1 - dones[step])
            advantages[step] = gae

        # Calculate returns
        returns = advantages + values

        # Convert back to tensors
        self.buffer.returns = torch.tensor(returns, dtype=torch.float32, device=RLConfig.DEVICE)
        self.buffer.advantages = torch.tensor(advantages, dtype=torch.float32, device=RLConfig.DEVICE)

        # Normalize advantages
        if len(advantages) > 1:
            self.buffer.advantages = (self.buffer.advantages - self.buffer.advantages.mean()) / (
                self.buffer.advantages.std() + 1e-8
            )

    def update_policy(self):
        """FIXED: Update policy with proper mini-batch handling"""
        batch = self.buffer.get_batch()
        batch_size = self.buffer.size

        # Training metrics for this update
        update_stats = {
            'actor_losses': [],
            'critic_losses': [],
            'entropies': [],
            'clip_fractions': [],
            'kl_divs': []
        }

        # FIXED: Multiple PPO epochs with proper batching
        for epoch in range(RLConfig.PPO_EPOCHS):
            # Shuffle indices
            indices = torch.randperm(batch_size)
            
            # Mini-batch training
            for i in range(0, batch_size, RLConfig.MINI_BATCH_SIZE):
                batch_indices = indices[i:i+RLConfig.MINI_BATCH_SIZE]
                
                if len(batch_indices) < RLConfig.MINI_BATCH_SIZE // 2:
                    continue  # Skip small batches
                
                # Get mini-batch data
                mb_obs = [batch['observations'][idx] for idx in batch_indices]
                mb_actions = batch['actions'][batch_indices]
                mb_old_log_probs = batch['log_probs'][batch_indices]
                mb_returns = self.buffer.returns[batch_indices]
                mb_advantages = self.buffer.advantages[batch_indices]
                mb_values = batch['values'][batch_indices]

                # Forward passes
                _, new_log_probs = self.actor.get_action_and_log_prob(mb_obs)
                mb_state_features = extract_state_features(mb_obs)
                new_values = self.critic(mb_state_features)

                # FIXED: Proper PPO loss computation
                # Actor loss
                ratio = torch.exp(new_log_probs - mb_old_log_probs)
                surr1 = ratio * mb_advantages
                surr2 = torch.clamp(
                    ratio, 
                    1.0 - RLConfig.CLIP_EPSILON, 
                    1.0 + RLConfig.CLIP_EPSILON
                ) * mb_advantages
                
                actor_loss = -torch.min(surr1, surr2).mean()

                # Critic loss (clipped value function)
                value_pred_clipped = mb_values.squeeze() + torch.clamp(
                    new_values - mb_values.squeeze(),
                    -RLConfig.CLIP_EPSILON,
                    RLConfig.CLIP_EPSILON
                )
                value_loss1 = F.mse_loss(new_values, mb_returns)
                value_loss2 = F.mse_loss(value_pred_clipped, mb_returns)
                critic_loss = torch.max(value_loss1, value_loss2)

                # Entropy loss (for exploration)
                entropy = -(new_log_probs).mean()  # Approximation for entropy
                entropy_loss = -RLConfig.ENTROPY_COEF * entropy

                # Combined losses
                total_actor_loss = actor_loss + entropy_loss

                # Update actor
                self.actor_optimizer.zero_grad()
                total_actor_loss.backward()
                torch.nn.utils.clip_grad_norm_(self.actor.parameters(), RLConfig.MAX_GRAD_NORM)
                self.actor_optimizer.step()

                # Update critic
                self.critic_optimizer.zero_grad()
                critic_loss.backward()
                torch.nn.utils.clip_grad_norm_(self.critic.parameters(), RLConfig.MAX_GRAD_NORM)
                self.critic_optimizer.step()

                # Metrics
                with torch.no_grad():
                    clip_frac = ((ratio - 1.0).abs() > RLConfig.CLIP_EPSILON).float().mean()
                    kl_div = (mb_old_log_probs - new_log_probs).mean()

                update_stats['actor_losses'].append(actor_loss.item())
                update_stats['critic_losses'].append(critic_loss.item())
                update_stats['entropies'].append(entropy.item())
                update_stats['clip_fractions'].append(clip_frac.item())
                update_stats['kl_divs'].append(kl_div.item())

        # Update learning rate schedulers
        self.actor_scheduler.step()
        self.critic_scheduler.step()

        # Average metrics
        return {
            'actor_loss': np.mean(update_stats['actor_losses']),
            'critic_loss': np.mean(update_stats['critic_losses']),
            'entropy': np.mean(update_stats['entropies']),
            'clip_fraction': np.mean(update_stats['clip_fractions']),
            'kl_div': np.mean(update_stats['kl_divs']),
            'actor_lr': self.actor_scheduler.get_last_lr()[0],
            'critic_lr': self.critic_scheduler.get_last_lr()[0]
        }

    def save_checkpoint(self, filepath, episode, stats):
        """Save comprehensive training checkpoint"""
        os.makedirs(os.path.dirname(filepath), exist_ok=True)
        
        checkpoint = {
            'episode': episode,
            'actor_state_dict': self.actor.state_dict(),
            'critic_state_dict': self.critic.state_dict(),
            'actor_optimizer_state_dict': self.actor_optimizer.state_dict(),
            'critic_optimizer_state_dict': self.critic_optimizer.state_dict(),
            'actor_scheduler_state_dict': self.actor_scheduler.state_dict(),
            'critic_scheduler_state_dict': self.critic_scheduler.state_dict(),
            'training_stats': dict(self.training_stats),
            'episode_stats': self.episode_stats,
            'config': RLConfig.get_dict(),
            'architecture': self.actor.architecture,
            'recent_stats': stats,
            'timestamp': datetime.now().isoformat()
        }

        torch.save(checkpoint, filepath)
        print(f"✅ Saved comprehensive RL checkpoint: {filepath}")

# Initialize the enhanced PPO trainer
if actor is not None and critic is not None:
    ppo_trainer = EnhancedPPOTrainer(actor, critic)
    
    print("🎯 Enhanced PPO Trainer initialized!")
    print("🔧 Key improvements:")
    print(f"  - Proper GAE implementation with λ={RLConfig.GAE_LAMBDA}")
    print(f"  - Episode boundary handling for trajectory segmentation")
    print(f"  - Memory-efficient experience buffer")
    print(f"  - Clipped value function loss")
    print(f"  - Learning rate scheduling")
    print(f"  - Comprehensive metrics tracking")
    print(f"  - Actor LR: {RLConfig.LEARNING_RATE}, Critic LR: {RLConfig.LEARNING_RATE * 2.0}")
    print("\n✅ Ready for enhanced RL training!")
    
else:
    print("❌ Cannot initialize PPO trainer - networks not available")
    ppo_trainer = None


In [None]:
# FIXED: Enhanced Training Loop and Metrics

class ComprehensiveMetrics:
    """Enhanced metrics tracking with detailed analysis"""
    
    def __init__(self):
        self.reset()
    
    def reset(self):
        """Reset all metrics"""
        self.rollout_stats = []
        self.training_stats = []
        self.episode_data = []
        
        # Aggregate statistics
        self.total_episodes = 0
        self.total_steps = 0
        self.total_catches = 0
        self.start_time = time.time()
    
    def add_rollout(self, episode_stats, training_stats, rollout_num):
        """Add rollout statistics"""
        rollout_data = {
            'rollout': rollout_num,
            'episodes': len(episode_stats),
            'episode_stats': episode_stats,
            'training_stats': training_stats,
            'timestamp': time.time()
        }
        
        self.rollout_stats.append(rollout_data)
        self.episode_data.extend(episode_stats)
        self.training_stats.append(training_stats)
        
        # Update aggregates
        for ep in episode_stats:
            self.total_episodes += 1
            self.total_steps += ep['length']
            self.total_catches += ep['catches']
    
    def get_recent_stats(self, window=100):
        """Get statistics for recent episodes"""
        if len(self.episode_data) == 0:
            return {}
        
        recent_episodes = self.episode_data[-window:]
        
        rewards = [ep['reward'] for ep in recent_episodes]
        lengths = [ep['length'] for ep in recent_episodes]
        catches = [ep['catches'] for ep in recent_episodes]
        successes = [1 if ep['catches'] > 0 else 0 for ep in recent_episodes]
        timeouts = [1 if ep['termination'] == 'timeout' else 0 for ep in recent_episodes]
        
        stats = {
            'episodes': len(recent_episodes),
            'avg_reward': np.mean(rewards) if rewards else 0,
            'std_reward': np.std(rewards) if len(rewards) > 1 else 0,
            'avg_length': np.mean(lengths) if lengths else 0,
            'avg_catches': np.mean(catches) if catches else 0,
            'total_catches': sum(catches),
            'success_rate': np.mean(successes) if successes else 0,
            'timeout_rate': np.mean(timeouts) if timeouts else 0,
            'catch_efficiency': sum(catches) / sum(lengths) if sum(lengths) > 0 else 0
        }
        
        return stats
    
    def print_progress(self, rollout_num, training_stats):
        """Print comprehensive progress information"""
        recent_stats = self.get_recent_stats(50)
        elapsed_time = time.time() - self.start_time
        
        print(f"\n📊 Rollout {rollout_num} Summary:")
        print(f"  ⏱️  Time: {elapsed_time:.1f}s ({elapsed_time/60:.1f}m)")
        print(f"  📈 Total: {self.total_episodes} episodes, {self.total_steps:,} steps")
        
        if recent_stats['episodes'] > 0:
            print(f"  🎯 Recent Performance (last {recent_stats['episodes']} episodes):")
            print(f"    Avg Reward: {recent_stats['avg_reward']:.3f} ± {recent_stats['std_reward']:.3f}")
            print(f"    Avg Length: {recent_stats['avg_length']:.1f} steps")
            print(f"    Avg Catches: {recent_stats['avg_catches']:.2f}")
            print(f"    Success Rate: {recent_stats['success_rate']:.1%}")
            print(f"    Timeout Rate: {recent_stats['timeout_rate']:.1%}")
            print(f"    Catch Efficiency: {recent_stats['catch_efficiency']:.4f}")
        
        print(f"  🧠 Training Metrics:")
        print(f"    Actor Loss: {training_stats['actor_loss']:.4f}")
        print(f"    Critic Loss: {training_stats['critic_loss']:.4f}")
        print(f"    Entropy: {training_stats['entropy']:.4f}")
        print(f"    Clip Fraction: {training_stats['clip_fraction']:.3f}")
        print(f"    KL Divergence: {training_stats['kl_div']:.4f}")
        print(f"    Learning Rates: Actor {training_stats['actor_lr']:.2e}, Critic {training_stats['critic_lr']:.2e}")
        
        # Exploration info
        with torch.no_grad():
            current_std = torch.exp(actor.log_std).mean().item()
        print(f"    Exploration Std: {current_std:.3f}")

def enhanced_rl_training(ppo_trainer, num_rollouts=100, test_mode=False):
    """FIXED: Enhanced RL training loop with comprehensive monitoring"""
    
    if ppo_trainer is None:
        print("❌ Cannot start training - PPO trainer not available")
        return None
    
    print(f"🚀 Starting Enhanced RL Training!")
    print(f"  Rollouts: {num_rollouts}")
    print(f"  Steps per rollout: {RLConfig.ROLLOUT_STEPS}")
    print(f"  Total steps: {num_rollouts * RLConfig.ROLLOUT_STEPS:,}")
    print(f"  Test mode: {test_mode}")
    
    # Initialize environment and metrics
    env = EnhancedPredatorEnvironment()
    metrics = ComprehensiveMetrics()
    
    # Create checkpoints directory
    os.makedirs(RLConfig.RL_CHECKPOINT_DIR, exist_ok=True)
    
    # Training loop
    best_performance = 0.0
    
    for rollout in range(1, num_rollouts + 1):
        rollout_start_time = time.time()
        
        try:
            # Collect rollout
            episode_stats = ppo_trainer.collect_rollout(env)
            
            # Update policy
            training_stats = ppo_trainer.update_policy()
            
            # Track metrics
            metrics.add_rollout(episode_stats, training_stats, rollout)
            ppo_trainer.episode_stats.extend(episode_stats)
            
            # Progress reporting
            if rollout % RLConfig.LOG_INTERVAL == 0:
                metrics.print_progress(rollout, training_stats)
                
                # Performance tracking
                recent_stats = metrics.get_recent_stats(100)
                current_performance = recent_stats.get('avg_catches', 0) * recent_stats.get('success_rate', 0)
                
                if current_performance > best_performance:
                    best_performance = current_performance
                    print(f"    🏆 New best performance: {best_performance:.3f}")
            
            # Save checkpoints
            if rollout % RLConfig.SAVE_INTERVAL == 0:
                checkpoint_path = f"{RLConfig.RL_CHECKPOINT_DIR}/checkpoint_rollout_{rollout}.pt"
                recent_stats = metrics.get_recent_stats(100)
                ppo_trainer.save_checkpoint(checkpoint_path, rollout, recent_stats)
                
                # Save as latest
                latest_path = f"{RLConfig.RL_CHECKPOINT_DIR}/latest_checkpoint.pt"
                ppo_trainer.save_checkpoint(latest_path, rollout, recent_stats)
            
            # Early stopping for test mode
            if test_mode and rollout >= 5:
                print(f"🧪 Test mode: stopping early after {rollout} rollouts")
                break
                
        except Exception as e:
            print(f"❌ Error in rollout {rollout}: {e}")
            import traceback
            traceback.print_exc()
            break
    
    # Final statistics
    total_time = time.time() - metrics.start_time
    final_stats = metrics.get_recent_stats(200)
    
    print(f"\n🎉 Training Completed!")
    print(f"  Duration: {total_time:.1f}s ({total_time/60:.1f}m)")
    print(f"  Total Episodes: {metrics.total_episodes}")
    print(f"  Total Steps: {metrics.total_steps:,}")
    print(f"  Total Catches: {metrics.total_catches}")
    print(f"  Best Performance: {best_performance:.3f}")
    
    if final_stats['episodes'] > 0:
        print(f"\n📈 Final Performance (last {final_stats['episodes']} episodes):")
        print(f"  Average Reward: {final_stats['avg_reward']:.3f}")
        print(f"  Average Catches: {final_stats['avg_catches']:.2f}")
        print(f"  Success Rate: {final_stats['success_rate']:.1%}")
        print(f"  Catch Efficiency: {final_stats['catch_efficiency']:.4f}")
    
    return metrics

def quick_training_test():
    """Quick test of the training system"""
    if ppo_trainer is None:
        print("❌ Cannot run training test - PPO trainer not available")
        return None
    
    print("🧪 Running quick training test (5 rollouts)...")
    
    # Temporarily reduce rollout size for quick test
    original_rollout_steps = RLConfig.ROLLOUT_STEPS
    RLConfig.ROLLOUT_STEPS = 256  # Smaller for quick test
    ppo_trainer.buffer = ExperienceBuffer(RLConfig.ROLLOUT_STEPS)  # Update buffer
    
    try:
        test_metrics = enhanced_rl_training(ppo_trainer, num_rollouts=5, test_mode=True)
        
        if test_metrics:
            print("✅ Training test completed successfully!")
            return test_metrics
        else:
            print("❌ Training test failed")
            return None
            
    finally:
        # Restore original settings
        RLConfig.ROLLOUT_STEPS = original_rollout_steps
        ppo_trainer.buffer = ExperienceBuffer(RLConfig.ROLLOUT_STEPS)

# Run quick test
if ppo_trainer is not None:
    print("🧪 Testing enhanced training system...")
    test_results = quick_training_test()
    
    if test_results:
        print("\n✅ All systems ready for full training!")
        print("🚀 To start full training, run:")
        print("   metrics = enhanced_rl_training(ppo_trainer, num_rollouts=500)")
    else:
        print("❌ Training test failed - check implementation")
else:
    print("❌ Skipping training test - PPO trainer not available")


In [None]:
# Evaluation, Visualization, and Export Tools

def evaluate_rl_agent(actor, critic, num_episodes=50, verbose=True):
    """Comprehensive evaluation of the RL-trained agent"""
    
    if actor is None:
        print("❌ Cannot evaluate - actor not available")
        return None, None
    
    print(f"📊 Evaluating RL agent over {num_episodes} episodes...")
    
    env = EnhancedPredatorEnvironment()
    eval_stats = []
    
    for episode in range(num_episodes):
        obs, info = env.reset()
        
        episode_reward = 0
        episode_catches = 0
        episode_length = 0
        done = False
        
        while not done:
            with torch.no_grad():
                action, _ = actor.get_action_and_log_prob(obs)
            
            obs, reward, done, step_info = env.step(action)
            
            episode_reward += reward
            episode_catches += step_info['catches_this_step']
            episode_length += 1
        
        eval_stats.append({
            'episode': episode + 1,
            'reward': episode_reward,
            'length': episode_length,
            'catches': episode_catches,
            'success': episode_catches > 0,
            'efficiency': episode_catches / episode_length if episode_length > 0 else 0,
            'termination': step_info['termination_reason'],
            'canvas_size': info['canvas_width'] * info['canvas_height'],
            'initial_boids': info['initial_boids']
        })
        
        if verbose and (episode + 1) % 10 == 0:
            recent_stats = eval_stats[-10:]
            avg_catches = np.mean([s['catches'] for s in recent_stats])
            success_rate = np.mean([s['success'] for s in recent_stats])
            print(f"  Episodes {episode-8}-{episode+1}: {avg_catches:.1f} catches/ep, {success_rate:.1%} success")
    
    # Calculate summary statistics
    summary = {
        'total_episodes': len(eval_stats),
        'avg_reward': np.mean([s['reward'] for s in eval_stats]),
        'std_reward': np.std([s['reward'] for s in eval_stats]),
        'avg_catches': np.mean([s['catches'] for s in eval_stats]),
        'avg_length': np.mean([s['length'] for s in eval_stats]),
        'success_rate': np.mean([s['success'] for s in eval_stats]),
        'avg_efficiency': np.mean([s['efficiency'] for s in eval_stats]),
        'timeout_rate': np.mean([1 if s['termination'] == 'timeout' else 0 for s in eval_stats]),
        'total_catches': sum([s['catches'] for s in eval_stats])
    }
    
    print(f"\n📈 Evaluation Results:")
    print(f"  Episodes: {summary['total_episodes']}")
    print(f"  Average Reward: {summary['avg_reward']:.3f} ± {summary['std_reward']:.3f}")
    print(f"  Average Catches: {summary['avg_catches']:.2f}")
    print(f"  Success Rate: {summary['success_rate']:.1%}")
    print(f"  Average Efficiency: {summary['avg_efficiency']:.4f} catches/step")
    print(f"  Average Episode Length: {summary['avg_length']:.1f} steps")
    print(f"  Timeout Rate: {summary['timeout_rate']:.1%}")
    print(f"  Total Catches: {summary['total_catches']}")
    
    return eval_stats, summary

def plot_training_metrics(metrics, save_path=None):
    """Enhanced training metrics visualization"""
    
    if metrics is None or len(metrics.episode_data) == 0:
        print("❌ No metrics to plot")
        return
    
    # Create comprehensive plots
    fig, axes = plt.subplots(3, 3, figsize=(18, 15))
    fig.suptitle('Enhanced RL Training Metrics', fontsize=16)
    
    # Episode rewards
    rewards = [ep['reward'] for ep in metrics.episode_data]
    axes[0,0].plot(rewards, alpha=0.7)
    if len(rewards) > 50:
        window = min(50, len(rewards) // 10)
        smoothed = np.convolve(rewards, np.ones(window)/window, mode='valid')
        axes[0,0].plot(range(window//2, len(smoothed)+window//2), smoothed, 'r-', linewidth=2)
    axes[0,0].set_title('Episode Rewards')
    axes[0,0].set_xlabel('Episode')
    axes[0,0].set_ylabel('Reward')
    axes[0,0].grid(True)
    
    # Catches per episode
    catches = [ep['catches'] for ep in metrics.episode_data]
    axes[0,1].plot(catches, alpha=0.7)
    if len(catches) > 50:
        window = min(50, len(catches) // 10)
        smoothed = np.convolve(catches, np.ones(window)/window, mode='valid')
        axes[0,1].plot(range(window//2, len(smoothed)+window//2), smoothed, 'r-', linewidth=2)
    axes[0,1].set_title('Catches per Episode')
    axes[0,1].set_xlabel('Episode')
    axes[0,1].set_ylabel('Catches')
    axes[0,1].grid(True)
    
    # Success rate (rolling average)
    successes = [1 if ep['catches'] > 0 else 0 for ep in metrics.episode_data]
    if len(successes) > 20:
        window = min(50, len(successes) // 5)
        success_rate = np.convolve(successes, np.ones(window)/window, mode='valid')
        axes[0,2].plot(range(window//2, len(success_rate)+window//2), success_rate)
        axes[0,2].set_title(f'Success Rate (rolling avg, window={window})')
        axes[0,2].set_xlabel('Episode')
        axes[0,2].set_ylabel('Success Rate')
        axes[0,2].grid(True)
    
    # Training losses
    if len(metrics.training_stats) > 0:
        actor_losses = [stat['actor_loss'] for stat in metrics.training_stats]
        critic_losses = [stat['critic_loss'] for stat in metrics.training_stats]
        axes[1,0].plot(actor_losses, label='Actor Loss')
        axes[1,0].plot(critic_losses, label='Critic Loss')
        axes[1,0].set_title('Training Losses')
        axes[1,0].set_xlabel('Update')
        axes[1,0].set_ylabel('Loss')
        axes[1,0].legend()
        axes[1,0].grid(True)
        
        # Entropy and exploration
        entropies = [stat['entropy'] for stat in metrics.training_stats]
        axes[1,1].plot(entropies, 'g-', label='Entropy')
        axes[1,1].set_title('Policy Entropy')
        axes[1,1].set_xlabel('Update')
        axes[1,1].set_ylabel('Entropy')
        axes[1,1].grid(True)
        
        # KL divergence and clip fraction
        kl_divs = [stat['kl_div'] for stat in metrics.training_stats]
        clip_fracs = [stat['clip_fraction'] for stat in metrics.training_stats]
        ax1 = axes[1,2]
        ax2 = ax1.twinx()
        ax1.plot(kl_divs, 'b-', label='KL Divergence')
        ax2.plot(clip_fracs, 'r-', label='Clip Fraction')
        ax1.set_xlabel('Update')
        ax1.set_ylabel('KL Divergence', color='b')
        ax2.set_ylabel('Clip Fraction', color='r')
        ax1.set_title('KL Divergence & Clip Fraction')
        ax1.grid(True)
    
    # Episode lengths
    lengths = [ep['length'] for ep in metrics.episode_data]
    axes[2,0].plot(lengths, alpha=0.7)
    if len(lengths) > 50:
        window = min(50, len(lengths) // 10)
        smoothed = np.convolve(lengths, np.ones(window)/window, mode='valid')
        axes[2,0].plot(range(window//2, len(smoothed)+window//2), smoothed, 'r-', linewidth=2)
    axes[2,0].set_title('Episode Lengths')
    axes[2,0].set_xlabel('Episode')
    axes[2,0].set_ylabel('Steps')
    axes[2,0].grid(True)
    
    # Catch efficiency
    efficiency = [ep['catches'] / max(ep['length'], 1) for ep in metrics.episode_data]
    axes[2,1].plot(efficiency, alpha=0.7)
    if len(efficiency) > 50:
        window = min(50, len(efficiency) // 10)
        smoothed = np.convolve(efficiency, np.ones(window)/window, mode='valid')
        axes[2,1].plot(range(window//2, len(smoothed)+window//2), smoothed, 'r-', linewidth=2)
    axes[2,1].set_title('Catch Efficiency')
    axes[2,1].set_xlabel('Episode')
    axes[2,1].set_ylabel('Catches per Step')
    axes[2,1].grid(True)
    
    # Learning rates
    if len(metrics.training_stats) > 0:
        actor_lrs = [stat['actor_lr'] for stat in metrics.training_stats]
        critic_lrs = [stat['critic_lr'] for stat in metrics.training_stats]
        axes[2,2].plot(actor_lrs, label='Actor LR')
        axes[2,2].plot(critic_lrs, label='Critic LR')
        axes[2,2].set_title('Learning Rates')
        axes[2,2].set_xlabel('Update')
        axes[2,2].set_ylabel('Learning Rate')
        axes[2,2].legend()
        axes[2,2].grid(True)
        axes[2,2].set_yscale('log')
    
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"📊 Plots saved to {save_path}")
    
    plt.show()
    print("📊 Training metrics plotted!")

def export_rl_model(actor, output_path="policy/transformer/models/rl_trained_model.js"):
    """Export RL-trained transformer to JavaScript with comprehensive validation"""
    
    if actor is None:
        print("❌ Cannot export - actor not available")
        return False
    
    print(f"🔄 Exporting RL-trained model to JavaScript...")
    
    # Create export checkpoint
    checkpoint_path = f"{RLConfig.RL_CHECKPOINT_DIR}/export_checkpoint.pt"
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
    
    # Get current exploration parameters
    with torch.no_grad():
        exploration_std = torch.exp(actor.log_std).cpu().numpy().tolist()
    
    checkpoint = {
        'model_state_dict': actor.transformer.state_dict(),
        'episode': 'rl_export',
        'architecture': actor.architecture,
        'exploration_std': exploration_std,
        'training_type': 'reinforcement_learning_fine_tuned',
        'base_model': 'supervised_learning_transformer',
        'timestamp': datetime.now().isoformat()
    }
    
    torch.save(checkpoint, checkpoint_path)
    print(f"✅ Saved RL export checkpoint: {checkpoint_path}")
    
    # Use export_to_js.py script
    try:
        import subprocess
        result = subprocess.run([
            sys.executable, "export_to_js.py",
            "--checkpoint", checkpoint_path,
            "--output", output_path,
            "--info"
        ], capture_output=True, text=True)
        
        if result.returncode == 0:
            print(f"✅ RL model exported successfully!")
            print(f"📁 Output: {output_path}")
            print("🎉 RL-trained model ready for browser deployment!")
            print("\n" + "="*60)
            print("🚀 EXPORT DETAILS:")
            print("="*60)
            print(result.stdout)
            return True
        else:
            print(f"❌ Export failed:")
            print(result.stderr)
            return False
            
    except Exception as e:
        print(f"❌ Export error: {e}")
        return False

def compare_sl_vs_rl(sl_actor, rl_actor, num_episodes=30):
    """Compare Supervised Learning vs Reinforcement Learning performance"""
    
    if sl_actor is None or rl_actor is None:
        print("❌ Cannot compare - need both SL and RL actors")
        return None, None
    
    print(f"⚔️  Comparing SL vs RL Performance ({num_episodes} episodes each)...")
    
    # Evaluate SL model
    print("\\n📊 Evaluating SL model...")
    sl_stats, sl_summary = evaluate_rl_agent(sl_actor, None, num_episodes, verbose=False)
    
    # Evaluate RL model  
    print("\\n📊 Evaluating RL model...")
    rl_stats, rl_summary = evaluate_rl_agent(rl_actor, None, num_episodes, verbose=False)
    
    # Detailed comparison
    print(f"\\n⚔️  Detailed Performance Comparison:")
    print(f"{'Metric':<25} | {'SL Model':<12} | {'RL Model':<12} | {'Improvement':<12}")
    print("-" * 70)
    
    metrics_comparison = [
        ('Avg Catches', 'avg_catches', '{:.2f}'),
        ('Success Rate', 'success_rate', '{:.1%}'),
        ('Catch Efficiency', 'avg_efficiency', '{:.4f}'),
        ('Avg Episode Length', 'avg_length', '{:.1f}'),
        ('Timeout Rate', 'timeout_rate', '{:.1%}'),
        ('Avg Reward', 'avg_reward', '{:.3f}'),
        ('Total Catches', 'total_catches', '{:.0f}')
    ]
    
    for metric_name, metric_key, format_str in metrics_comparison:
        sl_val = sl_summary[metric_key]
        rl_val = rl_summary[metric_key]
        
        # Calculate improvement (handle division by zero)
        if sl_val != 0:
            if metric_name in ['Timeout Rate']:  # Lower is better
                improvement = ((sl_val - rl_val) / abs(sl_val)) * 100
            else:  # Higher is better
                improvement = ((rl_val - sl_val) / abs(sl_val)) * 100
        else:
            improvement = float('inf') if rl_val > 0 else 0
        
        sl_str = format_str.format(sl_val)
        rl_str = format_str.format(rl_val)
        
        if improvement != float('inf'):
            imp_str = f"{improvement:+7.1f}%"
        else:
            imp_str = "    ∞%"
        
        print(f"{metric_name:<25} | {sl_str:<12} | {rl_str:<12} | {imp_str:<12}")
    
    return sl_stats, rl_stats

# Summary and instructions
print("\\n" + "="*80)
print("🎯 ENHANCED RL TRAINING NOTEBOOK - COMPREHENSIVE IMPLEMENTATION")
print("="*80)
print("✅ All critical issues from original notebook have been FIXED:")
print("")
print("🔧 MAJOR FIXES IMPLEMENTED:")
print("  1. ✅ Fixed StateManager integration (no more direct simulation access)")
print("  2. ✅ Fixed stochastic actor policy with learnable exploration") 
print("  3. ✅ Implemented proper GAE with episode boundary handling")
print("  4. ✅ Fixed memory-efficient experience buffer")
print("  5. ✅ Enhanced feature extraction with error handling")
print("  6. ✅ Comprehensive metrics and progress tracking")
print("  7. ✅ Single source of truth configuration system")
print("  8. ✅ Maintained multi-step reward attribution (as requested)")
print("")
print("🚀 READY FOR FULL TRAINING:")
print("   # For full training, run:")
print("   # metrics = enhanced_rl_training(ppo_trainer, num_rollouts=500)")
print("")
print("📊 EVALUATION AND EXPORT:")
print("   # Evaluate trained model:")
print("   # eval_stats, eval_summary = evaluate_rl_agent(actor, critic, num_episodes=50)")
print("   # ")
print("   # Plot training progress:")
print("   # plot_training_metrics(metrics)")
print("   #")
print("   # Export to JavaScript:")
print("   # export_rl_model(actor)")
print("")
print("⚔️  COMPARISON:")
print("   # Compare SL vs RL performance:")
print("   # sl_stats, rl_stats = compare_sl_vs_rl(sl_actor, rl_actor)")
print("="*80)
