In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np
import yaml
import pandas as pd
from pathlib import Path
import random
from typing import List, Dict, Tuple
import seaborn as sns
from transformers import AutoTokenizer, AutoModelForCausalLM
import re

# Global plotting defaults, applied once when this cell runs: matplotlib's
# seaborn-v0_8 style sheet first, then seaborn's "husl" colour palette on top.
plt.style.use(style='seaborn-v0_8')
sns.set_palette(palette="husl")

class ModelAnalyzer:
    """Analyzer for trained prover-verifier models using real checkpoint data.

    Checkpoints are ``*.pt`` files in ``checkpoint_dir`` read with ``torch.load``.
    A checkpoint may carry the keys ``'replay_buffer'``, ``'round'``,
    ``'pure_stackelberg'`` and ``'config'``. Replay-buffer items are read
    positionally: ``item[2]`` is treated as the solution text, ``item[3]`` as
    the reward and ``item[5]`` as the role label (``'helpful'`` or
    ``'sneaky'``). Items with five or fewer fields are ignored for all
    role-specific statistics.
    """

    # (role label stored in item[5] / column prefix, legend label, plot colour)
    _ROLES = (('helpful', 'Helpful', 'green'), ('sneaky', 'Sneaky', 'red'))

    # Pulls the round number out of checkpoint file names such as "round_007.pt".
    _ROUND_RE = re.compile(r'round_(\d+)')

    def __init__(self, checkpoint_dir: str, config_path: str = None):
        """Set up the analyzer and list the ``*.pt`` files it can see.

        Args:
            checkpoint_dir: Directory containing the checkpoint files.
            config_path: Optional YAML config; ``{}`` is used when omitted or unreadable.
        """
        self.checkpoint_dir = Path(checkpoint_dir)
        self.config = self.load_config(config_path) if config_path else {}
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Loaded checkpoints keyed by round index so repeated analyses don't re-read files.
        self.checkpoints_cache = {}

        print(f"Initialized analyzer for {checkpoint_dir}")
        if not self.checkpoint_dir.is_dir():
            # A relative path resolves against the current working directory,
            # which is the usual reason a notebook finds nothing.
            print(f"Warning: checkpoint directory not found: {self.checkpoint_dir.resolve()}")
        available_checkpoints = sorted(self.checkpoint_dir.glob('*.pt'))
        print(f"Found {len(available_checkpoints)} checkpoints")
        for cp in available_checkpoints:
            print(f"  - {cp.name}")

    def load_config(self, config_path: str):
        """Load a YAML configuration file.

        Returns:
            The parsed mapping, or ``{}`` if the file is empty, missing or invalid.
        """
        try:
            with open(config_path, encoding='utf-8') as f:
                # safe_load returns None for an empty document; normalise to {}.
                return yaml.safe_load(f) or {}
        except (OSError, ValueError, yaml.YAMLError) as e:
            print(f"Could not load config: {e}")
            return {}

    def load_checkpoint(self, round_idx: int):
        """Load (and cache) the checkpoint for ``round_idx``.

        File names are tried in order: ``kirchner_round_NNN.pt``,
        ``pure_stackelberg_round_NNN.pt``, ``round_NNN.pt`` (NNN zero-padded to 3 digits).

        Returns:
            Whatever ``torch.load`` returns, or ``None`` if no file matches or loading fails.
        """
        if round_idx in self.checkpoints_cache:
            return self.checkpoints_cache[round_idx]

        patterns = [
            f"kirchner_round_{round_idx:03d}.pt",
            f"pure_stackelberg_round_{round_idx:03d}.pt",
            f"round_{round_idx:03d}.pt",
        ]

        checkpoint_file = None
        for pattern in patterns:
            candidate = self.checkpoint_dir / pattern
            if candidate.exists():
                checkpoint_file = candidate
                break

        if checkpoint_file is None:
            print(f"No checkpoint found for round {round_idx}")
            return None

        try:
            # NOTE(security): torch.load unpickles the file; only point this
            # analyzer at checkpoints you produced or otherwise trust.
            checkpoint = torch.load(checkpoint_file, map_location='cpu')
        except Exception as e:  # torch surfaces many error types (pickle, EOF, runtime)
            print(f"Error loading checkpoint for round {round_idx}: {e}")
            return None

        self.checkpoints_cache[round_idx] = checkpoint
        print(f"Loaded checkpoint from round {round_idx}: {checkpoint_file.name}")
        return checkpoint

    @staticmethod
    def _split_by_role(replay_buffer) -> Tuple[List[Tuple], List[Tuple]]:
        """Return ``(helpful_items, sneaky_items)``; items without a role field are dropped."""
        helpful, sneaky = [], []
        for item in replay_buffer:
            if len(item) <= 5:
                continue
            role = item[5]
            if role == 'helpful':
                helpful.append(item)
            elif role == 'sneaky':
                sneaky.append(item)
        return helpful, sneaky

    @staticmethod
    def _numeric_rewards(items) -> List[float]:
        """Return ``item[3]`` for every item whose reward is an int or float."""
        return [item[3] for item in items if isinstance(item[3], (int, float))]

    @staticmethod
    def _word_counts(items) -> List[int]:
        """Return the whitespace-token count of ``item[2]`` for items whose solution is a str."""
        return [len(item[2].split()) for item in items if isinstance(item[2], str)]

    def extract_replay_buffer_stats(self, replay_buffer: List[Tuple]) -> Dict:
        """Summarise a replay buffer.

        Returns:
            ``{}`` for an empty buffer. Otherwise a dict with ``total_items``,
            ``helpful_count`` and ``sneaky_count``. When there is data for a
            role it also has ``<role>_reward_mean`` / ``<role>_reward_std``
            (numeric rewards only) and ``<role>_solution_length_mean``
            (in words).
        """
        if not replay_buffer:
            return {}

        helpful_items, sneaky_items = self._split_by_role(replay_buffer)
        by_role = (('helpful', helpful_items), ('sneaky', sneaky_items))

        stats = {
            'total_items': len(replay_buffer),
            'helpful_count': len(helpful_items),
            'sneaky_count': len(sneaky_items),
        }

        for role, items in by_role:
            rewards = self._numeric_rewards(items)
            if rewards:
                stats[f'{role}_reward_mean'] = np.mean(rewards)
                stats[f'{role}_reward_std'] = np.std(rewards)

        for role, items in by_role:
            lengths = self._word_counts(items)
            if lengths:
                stats[f'{role}_solution_length_mean'] = np.mean(lengths)

        return stats

    def _detect_rounds(self) -> List[int]:
        """Return the sorted, de-duplicated round numbers parsed from ``*.pt`` file names."""
        rounds = set()
        for path in self.checkpoint_dir.glob('*.pt'):
            match = self._ROUND_RE.search(path.name)
            if match:
                rounds.add(int(match.group(1)))
        return sorted(rounds)

    def analyze_training_progression(self, max_rounds: int = None) -> pd.DataFrame:
        """Build a one-row-per-round DataFrame of checkpoint statistics.

        Args:
            max_rounds: Visit rounds ``0 .. max_rounds - 1``. When ``None``,
                visit only the round numbers found in checkpoint file names.

        Returns:
            A DataFrame with a ``round`` column (the index used to locate the
            file) plus the replay-buffer stats. ``checkpoint_round`` holds the
            checkpoint's own ``'round'`` entry. ``pure_stackelberg`` is copied
            as-is. The ``config_*`` columns hold scalar values from
            ``config['training']``. The frame is empty (no columns) when no
            checkpoint could be loaded.
        """
        if max_rounds is None:
            round_indices = self._detect_rounds()
        else:
            round_indices = range(max_rounds)

        progression_data = []
        for round_idx in round_indices:
            checkpoint = self.load_checkpoint(round_idx)
            if checkpoint is None:
                continue

            round_data = {'round': round_idx}

            if 'replay_buffer' in checkpoint:
                round_data.update(self.extract_replay_buffer_stats(checkpoint['replay_buffer']))

            if 'round' in checkpoint:
                # Keep 'round' as the file-derived index: other methods pass
                # df['round'] back into load_checkpoint, so it must not be clobbered.
                round_data['checkpoint_round'] = checkpoint['round']

            if 'pure_stackelberg' in checkpoint:
                round_data['pure_stackelberg'] = checkpoint['pure_stackelberg']

            if 'config' in checkpoint:
                config = checkpoint['config']
                if isinstance(config, dict):
                    training = config.get('training')
                    if isinstance(training, dict):
                        round_data.update({f"config_{k}": v for k, v in training.items()
                                           if isinstance(v, (int, float, str))})
                else:
                    round_data['config'] = config

            progression_data.append(round_data)

        return pd.DataFrame(progression_data)

    @staticmethod
    def _style_axis(ax, xlabel: str, ylabel: str, title: str, legend: bool = True):
        """Apply the labels, title, optional legend and light grid shared by every panel."""
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        if legend:
            ax.legend()
        ax.grid(True, alpha=0.3)

    def plot_kirchner_style_training(self, max_rounds: int = None):
        """Plot per-round training dynamics in Kirchner-paper style.

        Draws a 2x2 grid:
        - episode counts;
        - mean reward with a +/- one-std band;
        - replay-buffer size;
        - mean solution length in words.

        Each panel is split into helpful vs. sneaky where applicable. A panel
        is left empty when its columns are missing.

        Args:
            max_rounds: Forwarded to :meth:`analyze_training_progression`.

        Returns:
            The progression DataFrame, or ``None`` when no checkpoint data was found.
        """
        df = self.analyze_training_progression(max_rounds)

        if df.empty:
            print("No training data found in checkpoints")
            return None

        print(f"Loaded data for {len(df)} rounds")
        print("Available columns:", df.columns.tolist())

        columns = set(df.columns)
        fig, axes = plt.subplots(2, 2, figsize=(12, 8))

        # Panel 1: episode counts per round
        if {'helpful_count', 'sneaky_count'} <= columns:
            ax = axes[0, 0]
            for role, label, color in self._ROLES:
                ax.plot(df['round'], df[f'{role}_count'], 'o-', label=label, color=color, linewidth=2)
            self._style_axis(ax, 'Round', 'Episode Count', 'Episodes Generated per Round')

        # Panel 2: reward means, with a one-std band where the std column exists
        if {'helpful_reward_mean', 'sneaky_reward_mean'} <= columns:
            ax = axes[0, 1]
            for role, label, color in self._ROLES:
                ax.plot(df['round'], df[f'{role}_reward_mean'], 'o-', label=label, color=color, linewidth=2)
            for role, _, color in self._ROLES:
                std_col = f'{role}_reward_std'
                if std_col in columns:
                    mean = df[f'{role}_reward_mean']
                    ax.fill_between(df['round'], mean - df[std_col], mean + df[std_col],
                                    alpha=0.2, color=color)
            self._style_axis(ax, 'Round', 'Average Reward', 'Reward Evolution Over Training')

        # Panel 3: replay-buffer size (per round, not a running total)
        if 'total_items' in columns:
            ax = axes[1, 0]
            buffer_sizes = df['total_items'].fillna(0)
            ax.plot(df['round'], buffer_sizes, 'o-', color='blue', linewidth=2)
            self._style_axis(ax, 'Round', 'Total Items in Replay Buffer', 'Data Accumulation',
                             legend=False)

        # Panel 4: mean solution length in words
        if {'helpful_solution_length_mean', 'sneaky_solution_length_mean'} <= columns:
            ax = axes[1, 1]
            for role, label, color in self._ROLES:
                ax.plot(df['round'], df[f'{role}_solution_length_mean'], 'o-',
                        label=label, color=color, linewidth=2)
            self._style_axis(ax, 'Round', 'Average Solution Length (words)',
                             'Solution Complexity Evolution')

        plt.tight_layout()
        plt.show()

        return df

    def sample_real_solutions(self, round_idx: int, num_samples: int = 3) -> Dict:
        """Randomly sample replay-buffer items of each role from one round.

        Returns:
            ``{'helpful': [...], 'sneaky': [...]}`` with at most ``num_samples``
            items each, or ``{}`` if the round has no replay buffer.
        """
        checkpoint = self.load_checkpoint(round_idx)
        if not checkpoint or 'replay_buffer' not in checkpoint:
            print(f"No replay buffer data for round {round_idx}")
            return {}

        helpful_items, sneaky_items = self._split_by_role(checkpoint['replay_buffer'])
        # Clamp to [0, len] so a negative num_samples yields [] instead of a ValueError.
        k = max(num_samples, 0)
        return {
            'helpful': random.sample(helpful_items, min(k, len(helpful_items))),
            'sneaky': random.sample(sneaky_items, min(k, len(sneaky_items))),
        }

    def display_solution_samples(self, round_idx: int, num_samples: int = 2):
        """Print sampled helpful and sneaky solutions from one round.

        Each solution text is truncated to 500 characters. The reward
        (``item[3]``) is shown, or ``N/A`` when the item has no reward field.
        """
        samples = self.sample_real_solutions(round_idx, num_samples)

        if not samples:
            print(f"No samples available for round {round_idx}")
            return

        print(f"=== SOLUTION SAMPLES FROM ROUND {round_idx} ===\n")

        for role in ('helpful', 'sneaky'):
            items = samples.get(role)
            if not items:
                continue

            print(f"{role.upper()} SOLUTIONS:")
            print("=" * 50)

            for i, item in enumerate(items, start=1):
                if len(item) > 2 and isinstance(item[2], str):
                    solution = item[2]
                    reward = item[3] if len(item) > 3 else "N/A"

                    print(f"\nSample {i} (Reward: {reward}):")
                    print("-" * 30)
                    print(solution[:500] + ("..." if len(solution) > 500 else ""))
                    print()

            print("=" * 50)
            print()

    def compare_rounds_detailed(self, round1: int, round2: int):
        """Print a per-metric comparison of two rounds.

        Every numeric progression column except the round indices is printed
        with both values, the absolute change and the percent change. The
        percent change is reported as 0 when the round-1 value is 0.

        Returns:
            ``(data1, data2)`` row Series, or ``None`` if either round has no data.
        """
        df = self.analyze_training_progression()

        # An empty frame has no 'round' column at all, so test emptiness first.
        if df.empty or round1 not in df['round'].values or round2 not in df['round'].values:
            print(f"Data not available for rounds {round1} or {round2}")
            return None

        data1 = df[df['round'] == round1].iloc[0]
        data2 = df[df['round'] == round2].iloc[0]

        print(f"\n=== DETAILED COMPARISON: ROUND {round1} vs ROUND {round2} ===")

        index_columns = {'round', 'checkpoint_round'}
        for col in df.select_dtypes(include=[np.number]).columns:
            if col in index_columns:
                continue
            val1 = data1[col]
            val2 = data2[col]
            change = val2 - val1
            pct_change = (change / val1 * 100) if val1 != 0 else 0

            print(f"{col}:")
            print(f"  Round {round1}: {val1:.3f}")
            print(f"  Round {round2}: {val2:.3f}")
            print(f"  Change: {change:+.3f} ({pct_change:+.1f}%)")
            print()

        return data1, data2

    @staticmethod
    def _summarize(values) -> Dict:
        """Return mean/std/count of ``values``; mean and std are NaN for an empty list (no warning)."""
        if not values:
            return {'mean': float('nan'), 'std': float('nan'), 'count': 0}
        return {'mean': np.mean(values), 'std': np.std(values), 'count': len(values)}

    def analyze_reward_distribution(self, max_rounds: int = None):
        """Plot reward histograms and per-round mean rewards for both roles.

        Args:
            max_rounds: Forwarded to :meth:`analyze_training_progression`.

        Returns:
            ``None`` when no checkpoint data was found. Otherwise a dict with
            the raw ``helpful_rewards`` / ``sneaky_rewards`` lists and
            ``helpful_stats`` / ``sneaky_stats`` (mean, std, count).
        """
        df = self.analyze_training_progression(max_rounds)

        if df.empty:
            return None

        all_helpful_rewards, all_sneaky_rewards = [], []
        helpful_by_round, sneaky_by_round = {}, {}

        for round_idx in df['round']:
            checkpoint = self.load_checkpoint(round_idx)
            if not checkpoint or 'replay_buffer' not in checkpoint:
                continue

            helpful_items, sneaky_items = self._split_by_role(checkpoint['replay_buffer'])
            helpful_rewards = self._numeric_rewards(helpful_items)
            sneaky_rewards = self._numeric_rewards(sneaky_items)

            all_helpful_rewards.extend(helpful_rewards)
            all_sneaky_rewards.extend(sneaky_rewards)
            if helpful_rewards:
                helpful_by_round[round_idx] = helpful_rewards
            if sneaky_rewards:
                sneaky_by_round[round_idx] = sneaky_rewards

        fig, axes = plt.subplots(1, 2, figsize=(12, 5))

        # Left: pooled reward histogram across all rounds
        axes[0].hist(all_helpful_rewards, alpha=0.6, label='Helpful', bins=30, color='green')
        axes[0].hist(all_sneaky_rewards, alpha=0.6, label='Sneaky', bins=30, color='red')
        self._style_axis(axes[0], 'Reward', 'Frequency', 'Overall Reward Distribution')

        # Right: per-round means. Rounds missing a role plot as NaN, which
        # matplotlib renders as a gap rather than a misleading value.
        rounds = sorted(set(helpful_by_round) | set(sneaky_by_round))
        if len(rounds) > 1:
            helpful_means = [np.mean(helpful_by_round[r]) if r in helpful_by_round else np.nan
                             for r in rounds]
            sneaky_means = [np.mean(sneaky_by_round[r]) if r in sneaky_by_round else np.nan
                            for r in rounds]

            axes[1].plot(rounds, helpful_means, 'o-', label='Helpful', color='green', linewidth=2)
            axes[1].plot(rounds, sneaky_means, 'o-', label='Sneaky', color='red', linewidth=2)
            self._style_axis(axes[1], 'Round', 'Mean Reward', 'Reward Evolution by Round')

        plt.tight_layout()
        plt.show()

        return {
            'helpful_rewards': all_helpful_rewards,
            'sneaky_rewards': all_sneaky_rewards,
            'helpful_stats': self._summarize(all_helpful_rewards),
            'sneaky_stats': self._summarize(all_sneaky_rewards),
        }

# Quick analysis functions
def quick_analysis(checkpoint_dir: str, config_path: str = None):
    """Run the standard analysis pipeline on a checkpoint directory.

    Steps:
    1. Plot the training progression.
    2. Plot reward distributions and print summary statistics.
    3. Print solution samples from the latest available round.

    Args:
        checkpoint_dir: Directory holding the ``*.pt`` checkpoints.
        config_path: Optional YAML config path passed to ``ModelAnalyzer``.

    Returns:
        ``(analyzer, df)`` where ``df`` is the training-progression DataFrame
        (an empty DataFrame when no checkpoint data was found).
    """
    analyzer = ModelAnalyzer(checkpoint_dir, config_path)

    print("1. Training Progression Analysis")
    df = analyzer.plot_kirchner_style_training()
    if df is None:
        # plot_kirchner_style_training returns None when there is no data;
        # normalise to an empty frame so `df.empty` works here and for callers.
        df = pd.DataFrame()

    print("\n2. Reward Distribution Analysis")
    reward_stats = analyzer.analyze_reward_distribution()
    if reward_stats:
        print(f"Helpful rewards: μ={reward_stats['helpful_stats']['mean']:.3f}, σ={reward_stats['helpful_stats']['std']:.3f}")
        print(f"Sneaky rewards: μ={reward_stats['sneaky_stats']['mean']:.3f}, σ={reward_stats['sneaky_stats']['std']:.3f}")

    print("\n3. Sample Solutions from Latest Round")
    if df.empty:
        print("No rounds with checkpoint data; skipping solution samples")
    else:
        analyzer.display_solution_samples(int(df['round'].max()))

    return analyzer, df

def compare_early_vs_late(analyzer, early_round: int = 0, late_round: int = None):
    """Compare an early training round against a late one.

    Args:
        analyzer: A ``ModelAnalyzer`` (or any object with the same methods).
        early_round: Round used as the baseline.
        late_round: Round compared against the baseline. When ``None``, the
            largest round in the training progression is used, or 5 if the
            progression is empty.
    """
    if late_round is None:
        progression = analyzer.analyze_training_progression()
        late_round = 5 if progression.empty else progression['round'].max()

    print(f"Comparing Round {early_round} vs Round {late_round}")
    analyzer.compare_rounds_detailed(early_round, late_round)

    # One solution sample from each side of the comparison, early first.
    for round_idx in (early_round, late_round):
        print(f"\nSolution samples from Round {round_idx}:")
        analyzer.display_solution_samples(round_idx, 1)

if __name__ == "__main__":
    # NOTE(review): both paths are relative, so they resolve against the
    # current working directory; run from the directory that contains them
    # (or substitute absolute paths).
    analyzer, df = quick_analysis(
        checkpoint_dir="runs/pure_stackelberg_experiment/checkpoints",
        config_path="pmv/configs/config_pure_stackelberg.yaml",
    )

    # Usage hints for interactive follow-up.
    print("\n".join([
        "\nother commands:",
        "- analyzer.display_solution_samples(round_idx)",
        "- analyzer.compare_rounds_detailed(round1, round2)",
        "- compare_early_vs_late(analyzer)",
    ]))

Could not load config: [Errno 2] No such file or directory: 'pmv/configs/config_pure_stackelberg.yaml'
Initialized analyzer for runs/pure_stackelberg_experiment/checkpoints
Found 0 checkpoints
1. Training Progression Analysis
No checkpoint found for round 0
No checkpoint found for round 1
No checkpoint found for round 2
No checkpoint found for round 3
No checkpoint found for round 4
No checkpoint found for round 5
No checkpoint found for round 6
No checkpoint found for round 7
No checkpoint found for round 8
No checkpoint found for round 9
No training data found in checkpoints

2. Reward Distribution Analysis
No checkpoint found for round 0
No checkpoint found for round 1
No checkpoint found for round 2
No checkpoint found for round 3
No checkpoint found for round 4
No checkpoint found for round 5
No checkpoint found for round 6
No checkpoint found for round 7
No checkpoint found for round 8
No checkpoint found for round 9

3. Sample Solutions from Latest Round


AttributeError: 'NoneType' object has no attribute 'empty'

/Users/saranya/Desktop/PMV/pmv
