# Analysis of DQN vs. PPO on Atari Pong

This notebook loads the training logs written by `scripts/train.py` and plots learning curves comparing the agents (vanilla DQN, enhanced DQN and PPO), reproducing the figures in the paper.

In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Global seaborn theme for every figure below: default "notebook" sizing on a white grid.
sns.set_theme(context="notebook", style="whitegrid")

## 1. Load Experiment Data

For every configured agent/seed pair, we look in the `data/` directory for the run directory whose name starts with `{agent}__{seed}`, load its `log.csv`, tag the rows with the agent and seed, and combine everything into a single DataFrame.

In [None]:
DATA_DIR = '../data'
SEEDS = [42, 84, 126]
AGENTS = ['dqn_vanilla', 'dqn_enhanced', 'ppo']


def _find_run_dirs(entries, agent, seed):
    """Return, sorted, the entries whose names start with `{agent}__{seed}`.

    A name only matches when the seed is not immediately followed by another
    digit, so looking up seed 42 does not also pick up a run for seed 420.
    """
    prefix = f"{agent}__{seed}"
    return sorted(
        name for name in entries
        if name.startswith(prefix) and not name[len(prefix):len(prefix) + 1].isdigit()
    )


# List the data directory once. Sorting makes the choice among duplicate runs
# deterministic, because os.listdir returns entries in arbitrary order.
data_entries = sorted(os.listdir(DATA_DIR))

all_data = []

for agent in AGENTS:
    for seed in SEEDS:
        run_dirs = _find_run_dirs(data_entries, agent, seed)
        if not run_dirs:
            print(f"Warning: No data found for {agent} with seed {seed}")
            continue
        if len(run_dirs) > 1:
            print(f"Warning: {len(run_dirs)} runs found for {agent} with seed {seed}; using {run_dirs[0]}")

        run_dir = run_dirs[0]
        log_path = os.path.join(DATA_DIR, run_dir, 'log.csv')

        try:
            run_df = pd.read_csv(log_path)
        except FileNotFoundError:
            print(f"Warning: log.csv not found in {run_dir}")
            continue
        all_data.append(run_df.assign(agent=agent, seed=seed))

if not all_data:
    # Stop here with a clear message instead of a NameError on `results_df` in later cells.
    raise RuntimeError("No data was loaded. Please run training first.")

results_df = pd.concat(all_data, ignore_index=True)
print("Data loaded successfully!")
print(f"Total rows: {len(results_df)}")
# Bare last expression of the cell so Jupyter renders the preview.
results_df.head()

## 2. Plot Learning Curves

We use Seaborn's `lineplot` to plot the mean evaluation return against training frames (in millions). The shaded band is a bootstrapped 95% confidence interval across seeds at each evaluation step. It indicates how consistent training is from run to run, but with only three seeds it is a coarse estimate.

In [None]:
def plot_learning_curves(df, title, x='step', y='eval/mean_return', hue='agent',
                         xlabel=None, ylabel='Mean Episode Return'):
    """Plot mean learning curves with a 95% confidence band, one line per `hue` level.

    Parameters
    ----------
    df : pandas.DataFrame
        Long-format data containing the columns named by `x`, `y` and `hue`.
    title : str
        Figure title.
    x, y, hue : str
        Column names for the x axis, the y axis and the line grouping.
    xlabel : str, optional
        X-axis label. When omitted, it is derived from `x`, so that
        `x='step_in_millions'` is labelled in millions rather than raw frames.
    ylabel : str
        Y-axis label.
    """
    if xlabel is None:
        # NOTE(review): assumes `step` counts environment frames (not agent steps
        # after frame-skip). Confirm against scripts/train.py.
        default_xlabels = {
            'step': 'Environment Frames',
            'step_in_millions': 'Environment Frames (millions)',
        }
        xlabel = default_xlabels.get(x, x)

    fig, ax = plt.subplots(figsize=(12, 8))

    # Seaborn aggregates all rows sharing an x value: the line is their mean and the
    # band a bootstrapped 95% CI. Assuming each seed logs one evaluation per step,
    # the band therefore spans seeds.
    sns.lineplot(data=df, x=x, y=y, hue=hue, errorbar=('ci', 95), ax=ax)

    ax.set_title(title, fontsize=16)
    ax.set_xlabel(xlabel, fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.legend(title=hue.replace('_', ' ').title())
    ax.grid(True)
    fig.tight_layout()
    plt.show()

# Keep only rows that carry an evaluation result, and add an x-axis in millions of steps.
# `.assign` builds a new frame. Assigning a column directly onto the filtered result
# of `dropna` can instead raise SettingWithCopyWarning.
eval_df = (
    results_df
    .dropna(subset=['eval/mean_return'])
    .assign(step_in_millions=lambda frame: frame['step'] / 1_000_000)
)
assert not eval_df.empty, "No evaluation rows found: 'eval/mean_return' is empty in every log."

# Plot 1: Ablation Study (Vanilla vs. Enhanced DQN)
ablation_df = eval_df[eval_df['agent'].isin(['dqn_vanilla', 'dqn_enhanced'])]
plot_learning_curves(ablation_df, 'Ablation Study: Vanilla DQN vs. Enhanced DQN', x='step_in_millions')

# Plot 2: Main Comparison (Enhanced DQN vs. PPO)
main_comparison_df = eval_df[eval_df['agent'].isin(['dqn_enhanced', 'ppo'])]
plot_learning_curves(main_comparison_df, 'Main Comparison: Enhanced DQN vs. PPO', x='step_in_millions')

## 3. Quantitative Analysis

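Here we compute the metrics reported in the paper:
- 'Frames to +15 Score': the first evaluation step at which the target return is reached.
- 'Final Score': the mean of each seed's last few evaluations, aggregated across seeds.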

In [None]:
TARGET_SCORE = 15
FINAL_EVALS = 5

summary = []

for agent in AGENTS:
    agent_df = eval_df[eval_df['agent'] == agent]

    # Both metrics are computed per seed first and then aggregated. Pooling the seeds
    # and taking the global minimum step would instead report the luckiest seed.
    frames_per_seed = []
    final_scores_per_seed = []
    for seed in SEEDS:
        seed_df = agent_df[agent_df['seed'] == seed]
        if seed_df.empty:
            continue
        # Earliest evaluation step at which this seed reached the target (NaN if never).
        frames_per_seed.append(
            seed_df.loc[seed_df['eval/mean_return'] >= TARGET_SCORE, 'step'].min()
        )
        # Final score for this seed: mean of its last FINAL_EVALS evaluations by step.
        final_scores_per_seed.append(
            seed_df.nlargest(FINAL_EVALS, 'step')['eval/mean_return'].mean()
        )

    frames_reached = [frames for frames in frames_per_seed if pd.notna(frames)]

    summary.append({
        'Agent': agent,
        # Mean over the seeds that reached the target; see the "Seeds Reaching" column.
        f'Frames to +{TARGET_SCORE} Score': np.mean(frames_reached) if frames_reached else np.nan,
        'Final Score (Mean)': np.mean(final_scores_per_seed) if final_scores_per_seed else np.nan,
        # np.std defaults to ddof=0 (population std over seeds); pass ddof=1 for sample std.
        'Final Score (Std)': np.std(final_scores_per_seed) if final_scores_per_seed else np.nan,
        f'Seeds Reaching +{TARGET_SCORE}': f'{len(frames_reached)}/{len(frames_per_seed)}',
    })

summary_df = pd.DataFrame(summary)
summary_df