In [5]:
# Three hard-coded groups of 8-character identifier strings.
# NOTE(review): presumably W&B run IDs for the "shoggoth", "no penalty" and
# "penalty" experiment conditions — TODO confirm. None of the visible cells
# reads these lists yet.
SHOGGOTH_IDS = ["w4tt1jsu", "zfl5u3fd", "1sde8jph", "8fgq0wro", "w3euhlb8"]

NO_PENALTY_IDS = [
    "9f3aqg2s", "x4ycjhu3", "tp2lfh4y", "1acq8chp",
    "put7hoex", "w6gsb6ok", "ip9tix4l", "5n394zv2",
]

PENALTY_IDS = [
    "zmhndno2", "sz2dldwu", "knaehdyp", "tty1nb3s",
    "z5f8wxgt", "eckpvxsz", "a11zhjue", "9wyk1cnx",
]

In [7]:
import wandb
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Initialize wandb API
api = wandb.Api()

# Project configuration
project_name = "qwen3-reinforce-reasongym"
entity = None  # Set this to your wandb entity if needed

# Construct project path: "<entity>/<project>" when an entity is set,
# otherwise just the project name.
project_path = f"{entity}/{project_name}" if entity else project_name

print(f"Loading runs from project: {project_path}")

# Get all runs from the project
runs = api.runs(project_path)

print(f"Found {len(runs)} total runs in the project")

# Keep runs whose name contains "acre" (case-insensitive); unnamed runs are
# skipped. A separate "acre_shoggoth" test is unnecessary: every name that
# contains "acre_shoggoth" also contains "acre".
acre_runs = [run for run in runs if run.name and "acre" in run.name.lower()]

print(f"Found {len(acre_runs)} acre-related runs:")
for run in acre_runs:
    print(f"  - {run.name} (state: {run.state}, created: {run.created_at})")


Loading runs from project: qwen3-reinforce-reasongym
Found 183 total runs in the project
Found 150 acre-related runs:
  - acre-1 (state: killed, created: 2025-07-09T01:08:28Z)
  - acre-2 (state: killed, created: 2025-07-09T01:09:49Z)
  - acre-3 (state: killed, created: 2025-07-09T04:34:46Z)
  - acre-4 (state: killed, created: 2025-07-11T21:00:27Z)
  - acre-5 (state: failed, created: 2025-07-11T21:12:23Z)
  - acre-6 (state: failed, created: 2025-07-11T21:25:44Z)
  - acre-7 (state: killed, created: 2025-07-11T21:40:53Z)
  - acre-8 (state: killed, created: 2025-07-11T22:45:12Z)
  - acre-9 (state: killed, created: 2025-07-11T22:52:48Z)
  - acre-10 (state: killed, created: 2025-07-11T23:03:11Z)
  - acre-11 (state: killed, created: 2025-07-12T00:04:56Z)
  - acre-12 (state: killed, created: 2025-07-12T00:29:43Z)
  - acre-13 (state: killed, created: 2025-07-12T00:40:40Z)
  - acre face shoggoth (state: killed, created: 2025-07-12T02:14:47Z)
  - acre face shog short (state: finished, created: 20

In [None]:
# Gather config, summary and history for every run in `acre_runs` into
# `run_data`, a list with one dict per run.
run_data = []

for run in acre_runs:
    print(f"\nLoading data for run: {run.name}")

    config = run.config    # run configuration
    summary = run.summary  # run summary (final metrics)
    # NOTE(review): per the W&B public API docs, Run.history() returns a
    # down-sampled DataFrame by default rather than every logged row —
    # confirm that is acceptable for the plots below.
    history = run.history()

    # Keep the run object itself too, so later cells can query its artifacts.
    run_info = dict(
        run=run,
        name=run.name,
        id=run.id,
        state=run.state,
        config=config,
        summary=summary,
        history=history,
    )
    run_data.append(run_info)

    # Per-run overview of what was fetched.
    print(f"  - Config keys: {list(config.keys())}")
    print(f"  - Summary keys: {list(summary.keys())}")
    print(f"  - History shape: {history.shape}")
    print(f"  - History columns: {list(history.columns)}")

print(f"\nLoaded data for {len(run_data)} acre runs")


In [None]:
# Example: Plot training curves for all acre runs
def _plot_runs(ax, runs_info, series_specs, ylabel, title):
    """Draw one panel of training curves on `ax`.

    Args:
        ax: Matplotlib Axes to draw on.
        runs_info: Iterable of dicts with a 'name' entry and a 'history'
            DataFrame (the structure built in `run_data`).
        series_specs: List of (column, label_suffix, linestyle) tuples. A run
            is drawn only if its history has *all* listed columns; its series
            are then drawn in the given order.
        ylabel: Y-axis label.
        title: Panel title.
    """
    required_columns = [column for column, _, _ in series_specs]
    for info in runs_info:
        hist = info['history']
        if not all(column in hist.columns for column in required_columns):
            continue
        for column, label_suffix, linestyle in series_specs:
            # Drop NaN rows (presumably steps where this metric was not
            # logged — TODO confirm): matplotlib breaks a line at every NaN,
            # so sparse metrics would otherwise be fragmented or invisible.
            # The original index is kept, so x positions are unchanged.
            values = hist[column].dropna()
            ax.plot(
                values.index,
                values.to_numpy(),
                label=f"{info['name']}{label_suffix}",
                alpha=0.7,
                linestyle=linestyle,
            )
    # NOTE(review): x values are the history DataFrame's row index; confirm
    # that it actually corresponds to episodes.
    ax.set_xlabel('Episode')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    # Only add a legend when something was drawn; legend() on an empty panel
    # just emits a "no artists with labels" warning.
    handles, _ = ax.get_legend_handles_labels()
    if handles:
        ax.legend()
    ax.grid(True, alpha=0.3)


fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# Plot reward progression for each run
_plot_runs(axes[0, 0], run_data, [('reward_mean', '', '-')],
           'Reward Mean', 'Training Reward Progression')

# Plot loss progression
_plot_runs(axes[0, 1], run_data, [('loss', '', '-')],
           'Loss', 'Training Loss Progression')

# Plot penalty progression
_plot_runs(axes[1, 0], run_data, [('penalty', '', '-')],
           'Penalty', 'Penalty Progression')

# Plot penalized reward vs base reward (only runs that logged both)
_plot_runs(
    axes[1, 1],
    run_data,
    [('reward_mean', ' (base)', '-'), ('penalized_reward', ' (penalized)', '--')],
    'Reward',
    'Base vs Penalized Rewards',
)

plt.tight_layout()
plt.show()

# Show summary statistics
# Keys to report; defined once because they do not change between runs.
metrics_to_show = ['reward_mean', 'penalized_reward', 'penalty', 'loss']
config_to_show = ['learning_rate', 'num_episodes', 'batch_size', 'penalty_per_word']

print("\n=== Run Summary Statistics ===")
for run_info in run_data:
    print(f"\nRun: {run_info['name']}")
    summary = run_info['summary']
    config = run_info['config']

    # Show key metrics if available
    for metric in metrics_to_show:
        if metric in summary:
            value = summary[metric]
            try:
                formatted = f"{value:.4f}"
            except (TypeError, ValueError):
                # Summary values are not guaranteed to be numeric (e.g. a
                # string, None or a nested mapping); print them unformatted
                # instead of aborting the whole report.
                formatted = f"{value}"
            print(f"  Final {metric}: {formatted}")

    # Show key config parameters
    for param in config_to_show:
        if param in config:
            print(f"  {param}: {config[param]}")

    # Show word penalties if configured. The entry is only inspected when it
    # is a dict, so a None/scalar 'word_penalty' no longer raises.
    word_penalty = config['word_penalty'] if 'word_penalty' in config else None
    if isinstance(word_penalty, dict) and word_penalty.get('enabled'):
        words = word_penalty.get('words', [])
        penalty_per_word = word_penalty.get('penalty_per_word', 0)
        print(f"  Penalized words: {words} (penalty: {penalty_per_word})")


In [None]:
# Optional: list the artifacts each loaded run has logged (e.g. rollout data)
print("=== Available Artifacts ===")
for run_info in run_data:
    run = run_info['run']
    artifacts = run.logged_artifacts()

    # Nothing logged for this run: say so and move on to the next one.
    if not artifacts:
        print(f"\nRun {run_info['name']} has no artifacts")
        continue

    print(f"\nRun {run_info['name']} has {len(artifacts)} artifacts:")
    for artifact in artifacts:
        print(f"  - {artifact.name} (type: {artifact.type}, size: {artifact.size})")

# Example: Download and examine a rollout artifact (if available)
def examine_rollout_artifact(run_info, max_rollouts=3):
    """Download a run's rollout artifact and return its first few JSONL records.

    NOTE(review): this notebook defines `examine_rollout_artifact` again in a
    later cell; after a top-to-bottom run the later definition is the one in
    effect.

    Args:
        run_info: Mapping with a 'run' entry (an object exposing
            `logged_artifacts()`) and a 'name' entry used in progress messages.
        max_rollouts: Maximum number of records to load. Defaults to 3, the
            value that was previously hard-coded.

    Returns:
        A list of parsed JSON records (at most `max_rollouts`), or None when
        the run has no artifact of type "dataset" with "rollouts" in its name,
        or when the downloaded artifact directory has no top-level `.jsonl`
        file.
    """
    import json
    import os

    run = run_info['run']
    rollout_artifacts = [
        a for a in run.logged_artifacts()
        if a.type == "dataset" and "rollouts" in a.name
    ]
    if not rollout_artifacts:
        return None

    print(f"\nDownloading rollout data for {run_info['name']}...")
    artifact = rollout_artifacts[0]  # Use the first rollout artifact
    artifact_dir = artifact.download()

    # os.listdir() returns entries in arbitrary order; sort so the file that
    # gets examined is the same on every run.
    jsonl_files = sorted(f for f in os.listdir(artifact_dir) if f.endswith('.jsonl'))
    if not jsonl_files:
        return None

    jsonl_path = os.path.join(artifact_dir, jsonl_files[0])
    print(f"Found rollout file: {jsonl_path}")

    # Load the first few rollouts as an example
    rollouts = []
    with open(jsonl_path, 'r', encoding='utf-8') as f:
        for line in f:
            if len(rollouts) >= max_rollouts:
                break
            if not line.strip():
                # Blank lines (e.g. a trailing newline) are not JSON records.
                continue
            rollouts.append(json.loads(line))

    print(f"Loaded {len(rollouts)} example rollouts")
    if rollouts and isinstance(rollouts[0], dict):
        print(f"Example rollout keys: {list(rollouts[0].keys())}")
        nested = rollouts[0].get('rollout')
        if isinstance(nested, dict):
            print(f"Rollout data keys: {list(nested.keys())}")

    return rollouts

# Uncomment the following lines to download and examine rollout data:
# if run_data:
#     rollouts = examine_rollout_artifact(run_data[0])  # Examine first run's rollouts


In [None]:
# Three hard-coded groups of 8-character identifier strings.
# NOTE(review): the first cell of this notebook assigns identical lists to
# these same names; this cell re-assigns them with the same values.
# Presumably W&B run IDs for the "shoggoth", "no penalty" and "penalty"
# experiment conditions — TODO confirm. None of the visible cells reads them.
SHOGGOTH_IDS = ["w4tt1jsu", "zfl5u3fd", "1sde8jph", "8fgq0wro", "w3euhlb8"]

NO_PENALTY_IDS = [
    "9f3aqg2s", "x4ycjhu3", "tp2lfh4y", "1acq8chp",
    "put7hoex", "w6gsb6ok", "ip9tix4l", "5n394zv2",
]

PENALTY_IDS = [
    "zmhndno2", "sz2dldwu", "knaehdyp", "tty1nb3s",
    "z5f8wxgt", "eckpvxsz", "a11zhjue", "9wyk1cnx",
]

In [None]:
# Gather config, summary and history for every run in `acre_runs` into
# `run_data`, a list with one dict per run.
# NOTE(review): an identical cell appears earlier in this notebook; running
# both fetches every run's data twice and rebuilds `run_data` from scratch.
run_data = []

for run in acre_runs:
    print(f"\nLoading data for run: {run.name}")

    config = run.config    # run configuration
    summary = run.summary  # run summary (final metrics)
    # NOTE(review): per the W&B public API docs, Run.history() returns a
    # down-sampled DataFrame by default rather than every logged row —
    # confirm that is acceptable for the plots below.
    history = run.history()

    # Keep the run object itself too, so later cells can query its artifacts.
    run_info = dict(
        run=run,
        name=run.name,
        id=run.id,
        state=run.state,
        config=config,
        summary=summary,
        history=history,
    )
    run_data.append(run_info)

    # Per-run overview of what was fetched.
    print(f"  - Config keys: {list(config.keys())}")
    print(f"  - Summary keys: {list(summary.keys())}")
    print(f"  - History shape: {history.shape}")
    print(f"  - History columns: {list(history.columns)}")

print(f"\nLoaded data for {len(run_data)} acre runs")


In [None]:
# Example: Plot training curves for all acre runs
# NOTE(review): an equivalent plotting cell appears earlier in this notebook.
def _plot_runs(ax, runs_info, series_specs, ylabel, title):
    """Draw one panel of training curves on `ax`.

    Args:
        ax: Matplotlib Axes to draw on.
        runs_info: Iterable of dicts with a 'name' entry and a 'history'
            DataFrame (the structure built in `run_data`).
        series_specs: List of (column, label_suffix, linestyle) tuples. A run
            is drawn only if its history has *all* listed columns; its series
            are then drawn in the given order.
        ylabel: Y-axis label.
        title: Panel title.
    """
    required_columns = [column for column, _, _ in series_specs]
    for info in runs_info:
        hist = info['history']
        if not all(column in hist.columns for column in required_columns):
            continue
        for column, label_suffix, linestyle in series_specs:
            # Drop NaN rows (presumably steps where this metric was not
            # logged — TODO confirm): matplotlib breaks a line at every NaN,
            # so sparse metrics would otherwise be fragmented or invisible.
            # The original index is kept, so x positions are unchanged.
            values = hist[column].dropna()
            ax.plot(
                values.index,
                values.to_numpy(),
                label=f"{info['name']}{label_suffix}",
                alpha=0.7,
                linestyle=linestyle,
            )
    # NOTE(review): x values are the history DataFrame's row index; confirm
    # that it actually corresponds to episodes.
    ax.set_xlabel('Episode')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    # Only add a legend when something was drawn; legend() on an empty panel
    # just emits a "no artists with labels" warning.
    handles, _ = ax.get_legend_handles_labels()
    if handles:
        ax.legend()
    ax.grid(True, alpha=0.3)


fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# Plot reward progression for each run
_plot_runs(axes[0, 0], run_data, [('reward_mean', '', '-')],
           'Reward Mean', 'Training Reward Progression')

# Plot loss progression
_plot_runs(axes[0, 1], run_data, [('loss', '', '-')],
           'Loss', 'Training Loss Progression')

# Plot penalty progression
_plot_runs(axes[1, 0], run_data, [('penalty', '', '-')],
           'Penalty', 'Penalty Progression')

# Plot penalized reward vs base reward (only runs that logged both)
_plot_runs(
    axes[1, 1],
    run_data,
    [('reward_mean', ' (base)', '-'), ('penalized_reward', ' (penalized)', '--')],
    'Reward',
    'Base vs Penalized Rewards',
)

plt.tight_layout()
plt.show()

# Show summary statistics
# Keys to report; defined once because they do not change between runs.
metrics_to_show = ['reward_mean', 'penalized_reward', 'penalty', 'loss']
config_to_show = ['learning_rate', 'num_episodes', 'batch_size', 'penalty_per_word']

print("\n=== Run Summary Statistics ===")
for run_info in run_data:
    print(f"\nRun: {run_info['name']}")
    summary = run_info['summary']
    config = run_info['config']

    # Show key metrics if available
    for metric in metrics_to_show:
        if metric in summary:
            value = summary[metric]
            try:
                formatted = f"{value:.4f}"
            except (TypeError, ValueError):
                # Summary values are not guaranteed to be numeric (e.g. a
                # string, None or a nested mapping); print them unformatted
                # instead of aborting the whole report.
                formatted = f"{value}"
            print(f"  Final {metric}: {formatted}")

    # Show key config parameters
    for param in config_to_show:
        if param in config:
            print(f"  {param}: {config[param]}")

    # Show word penalties if configured. The entry is only inspected when it
    # is a dict, so a None/scalar 'word_penalty' no longer raises.
    word_penalty = config['word_penalty'] if 'word_penalty' in config else None
    if isinstance(word_penalty, dict) and word_penalty.get('enabled'):
        words = word_penalty.get('words', [])
        penalty_per_word = word_penalty.get('penalty_per_word', 0)
        print(f"  Penalized words: {words} (penalty: {penalty_per_word})")


In [None]:
# Optional: list the artifacts each loaded run has logged (e.g. rollout data)
# NOTE(review): an identical listing cell appears earlier in this notebook.
print("=== Available Artifacts ===")
for run_info in run_data:
    run = run_info['run']
    artifacts = run.logged_artifacts()

    # Nothing logged for this run: say so and move on to the next one.
    if not artifacts:
        print(f"\nRun {run_info['name']} has no artifacts")
        continue

    print(f"\nRun {run_info['name']} has {len(artifacts)} artifacts:")
    for artifact in artifacts:
        print(f"  - {artifact.name} (type: {artifact.type}, size: {artifact.size})")

# Example: Download and examine a rollout artifact (if available)
def examine_rollout_artifact(run_info, max_rollouts=3):
    """Download a run's rollout artifact and return its first few JSONL records.

    NOTE(review): an earlier cell of this notebook also defines
    `examine_rollout_artifact`; after a top-to-bottom run this later
    definition is the one in effect.

    Args:
        run_info: Mapping with a 'run' entry (an object exposing
            `logged_artifacts()`) and a 'name' entry used in progress messages.
        max_rollouts: Maximum number of records to load. Defaults to 3, the
            value that was previously hard-coded.

    Returns:
        A list of parsed JSON records (at most `max_rollouts`), or None when
        the run has no artifact of type "dataset" with "rollouts" in its name,
        or when the downloaded artifact directory has no top-level `.jsonl`
        file.
    """
    import json
    import os

    run = run_info['run']
    rollout_artifacts = [
        a for a in run.logged_artifacts()
        if a.type == "dataset" and "rollouts" in a.name
    ]
    if not rollout_artifacts:
        return None

    print(f"\nDownloading rollout data for {run_info['name']}...")
    artifact = rollout_artifacts[0]  # Use the first rollout artifact
    artifact_dir = artifact.download()

    # os.listdir() returns entries in arbitrary order; sort so the file that
    # gets examined is the same on every run.
    jsonl_files = sorted(f for f in os.listdir(artifact_dir) if f.endswith('.jsonl'))
    if not jsonl_files:
        return None

    jsonl_path = os.path.join(artifact_dir, jsonl_files[0])
    print(f"Found rollout file: {jsonl_path}")

    # Load the first few rollouts as an example
    rollouts = []
    with open(jsonl_path, 'r', encoding='utf-8') as f:
        for line in f:
            if len(rollouts) >= max_rollouts:
                break
            if not line.strip():
                # Blank lines (e.g. a trailing newline) are not JSON records.
                continue
            rollouts.append(json.loads(line))

    print(f"Loaded {len(rollouts)} example rollouts")
    if rollouts and isinstance(rollouts[0], dict):
        print(f"Example rollout keys: {list(rollouts[0].keys())}")
        nested = rollouts[0].get('rollout')
        if isinstance(nested, dict):
            print(f"Rollout data keys: {list(nested.keys())}")

    return rollouts

# Uncomment the following lines to download and examine rollout data:
# if run_data:
#     rollouts = examine_rollout_artifact(run_data[0])  # Examine first run's rollouts
