[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/prashantkul/weak-to-strong-gen/blob/main/Prashant_Kulkarni_final_icl_w2s.ipynb)

# Weak-to-Strong Generalization Experiments

**Click the badge above to run this notebook in Google Colab!**

---


This notebook demonstrates how to clone the repository from git and run all experiments.

**What this notebook does:**
- Clones the code repository from GitHub
- Installs all dependencies
- Sets up environment and configuration
- Runs all 6 experiments (baseline, disclaimer, CoT × 2 models)
- Generates comprehensive visualization

**Estimated runtime:** 75-90 minutes total

**Requirements:**
- OpenRouter API key (get from https://openrouter.ai/settings/keys)
- Python 3.10+
- ~$150-200 in API credits (if running all experiments)

---
## Part 1: Repository Setup

This section clones the repository and installs dependencies.

In [None]:
# Cell 1: Clone Repository from GitHub
# This will download all the code from git
# Replace prashantkul/weak-to-strong-gen with your actual GitHub repository

import os
from pathlib import Path

# Check if we're in Colab
IN_COLAB = 'COLAB_GPU' in os.environ

if IN_COLAB:
    print("Running in Google Colab")
    # Clone repository to Colab
    !git clone https://github.com/prashantkul/weak-to-strong-gen.git
    %cd astra
    REPO_PATH = Path('/content/astra')
else:
    print("Running in local Jupyter")
    # For local Jupyter, assume we're already in the repo directory
    # If not, uncomment the lines below and update the path
    # !git clone https://github.com/prashantkul/weak-to-strong-gen.git
    # %cd astra
    REPO_PATH = Path.cwd()

print(f"Repository path: {REPO_PATH}")
print(f"Repository exists: {REPO_PATH.exists()}")

In [None]:
# Cell 2: Install Dependencies
# This installs all required Python packages

print("Installing dependencies...")
print("This may take 2-3 minutes")
print("="*70)

# Install from requirements.txt
!pip install -q -r requirements.txt

print("="*70)
print("✓ Dependencies installed successfully")

In [None]:
# Cell 3: Add Repository to Python Path
# This allows us to import the src modules

import sys
from pathlib import Path

# Add repository to path if not already there
repo_path = str(REPO_PATH)
if repo_path not in sys.path:
    sys.path.insert(0, repo_path)
    print(f"✓ Added {repo_path} to Python path")
else:
    print(f"✓ {repo_path} already in Python path")

# Verify we can import the modules
try:
    from src import Config, DatasetManager, ExperimentRunner
    from notebook_experiments import run_baseline_sweep, run_disclaimer_sweep, run_cot_sweep
    print("✓ Successfully imported all modules")
except ImportError as e:
    print(f"✗ Import failed: {e}")
    print("Make sure you're in the repository directory")

---
## Part 2: Environment Configuration

This section sets up API keys and configuration.

In [None]:
# Cell 4: Set API Keys
# For Google Colab: Use Colab Secrets (RECOMMENDED)

import os

# Check if running in Colab
IN_COLAB = 'COLAB_GPU' in os.environ

if IN_COLAB:
    # COLAB METHOD: Use Secrets
    # 1. Click the key icon (🔑) in the left sidebar
    # 2. Add two secrets:
    #    - Name: OPENROUTER_API_KEY, Value: your primary key
    #    - Name: OPENROUTER_API_KEY_BACKUP, Value: your backup key
    # 3. Enable notebook access for both secrets
    # 4. Run this cell
    
    from google.colab import userdata
    os.environ['OPENROUTER_API_KEY'] = userdata.get('OPENROUTER_API_KEY')
    os.environ['OPENROUTER_API_KEY_BACKUP'] = userdata.get('OPENROUTER_API_KEY_BACKUP')
    print("✓ Loaded API keys from Colab Secrets")
    
else:
    # LOCAL JUPYTER: Use .env file
    from dotenv import load_dotenv
    if load_dotenv():
        print("✓ Loaded API keys from .env file")
    else:
        print("⚠️ No .env file found")
        print("\nTo use locally:")
        print("  1. Copy .env.example to .env")
        print("  2. Edit .env and add your API keys")
        print("  3. Restart kernel and re-run this cell")

# Verify keys are set
if os.getenv('OPENROUTER_API_KEY'):
    key_preview = os.getenv('OPENROUTER_API_KEY')[:15] + "..."
    print(f"✓ Primary API key set: {key_preview}")
else:
    print("✗ OPENROUTER_API_KEY not set!")
    print("\n⚠️ Please set up your API keys before continuing")

In [None]:
# Cell 5: Initialize Configuration
# Load configuration and set up the environment

from src import Config

# Load configuration from environment variables
config = Config.from_env()
config.setup_environment()

print("="*70)
print("CONFIGURATION")
print("="*70)
print(f"Weak Model:     {config.weak_model}")
print(f"Strong Model:   {config.strong_model}")
print(f"Temperature:    {config.temperature}")
print(f"Max Parallel:   {config.max_parallel_requests}")
print(f"Cache Dir:      {config.cache_dir}")
print("="*70)
print("✓ Configuration loaded successfully")

---
## Part 3: Dataset Loading

Load and verify the TruthfulQA dataset.

In [None]:
# Cell 6: Load Dataset
# This loads the TruthfulQA dataset and splits it into test and few-shot pool

from src import DatasetManager

print("Loading TruthfulQA dataset...")
print("This may take 30-60 seconds on first load")
print("="*70)

dm = DatasetManager()
test_data, few_shot_pool, split = dm.load_split()

print("="*70)
print("DATASET LOADED")
print("="*70)
print(f"Test Set Size:      {len(test_data)} questions")
print(f"Few-Shot Pool Size: {len(few_shot_pool)} questions")
print(f"Split Used:         {split}")
print("="*70)

# Show example question
print("\nExample Question:")
print(test_data[0].question)
print(f"\nCorrect Answer: {test_data[0].answer}")
print("\n✓ Dataset loaded successfully")

In [None]:
# Cell 7: Quick Test (Optional)
# Run a quick test to verify everything is working
# This will make one API call to test the setup

from test_notebook_functions import test_notebook_setup

print("Running quick test...")
print("This will make one API call to verify the setup")
print("="*70)

test_passed = await test_notebook_setup()

if test_passed:
    print("\n" + "="*70)
    print("✓ ALL TESTS PASSED")
    print("✓ Ready to run experiments!")
    print("="*70)
else:
    print("\n" + "="*70)
    print("✗ TESTS FAILED")
    print("Please check the errors above and fix them before proceeding")
    print("="*70)

---
## Part 4: Baseline Experiments

Run baseline experiments for both model pairs (8B→405B and 8B→70B) across K∈{0,2,5,10}.

In [None]:
# Cell 8: Run 405B Baseline Experiments
# Tests how 405B performs with weak supervision at different K values
# Expected runtime: 15-20 minutes

from notebook_experiments import run_baseline_sweep

print("="*70)
print("EXPERIMENT 1: 405B BASELINE")
print("="*70)
print("Testing 8B→405B with K={0,2,5,10}")
print("Expected runtime: 15-20 minutes")
print("="*70)

results_405b_baseline = await run_baseline_sweep(
    model_pair="8b_to_405b",
    k_values=[0, 2, 5, 10],
    save_results=True
)

print("\n" + "="*70)
print("✓ 405B BASELINE COMPLETE")
print("="*70)
print("\nResults Summary:")
for k, result in results_405b_baseline['pgr_results'].items():
    print(f"  K={k}: PGR = {result.pgr:.3f} ({result.pgr_percentage})")
print("="*70)

In [None]:
# Cell 9: Run 70B Baseline Experiments
# Tests how 70B performs with weak supervision at different K values
# Expected runtime: 10-15 minutes

print("="*70)
print("EXPERIMENT 2: 70B BASELINE")
print("="*70)
print("Testing 8B→70B with K={0,2,5,10}")
print("Expected runtime: 10-15 minutes")
print("="*70)

results_70b_baseline = await run_baseline_sweep(
    model_pair="8b_to_70b",
    k_values=[0, 2, 5, 10],
    save_results=True
)

print("\n" + "="*70)
print("✓ 70B BASELINE COMPLETE")
print("="*70)
print("\nResults Summary:")
for k, result in results_70b_baseline['pgr_results'].items():
    print(f"  K={k}: PGR = {result.pgr:.3f} ({result.pgr_percentage})")
print("="*70)

# Compare the two models
print("\nComparison at K=10:")
pgr_405b = results_405b_baseline['pgr_results'][10].pgr
pgr_70b = results_70b_baseline['pgr_results'][10].pgr
print(f"  405B: {pgr_405b:.3f} ({pgr_405b*100:.1f}%)")
print(f"  70B:  {pgr_70b:.3f} ({pgr_70b*100:.1f}%)")
print(f"  Gap:  {(pgr_405b - pgr_70b)*100:.1f} percentage points")

---
## Part 5: Disclaimer Experiments

Run disclaimer experiments that add a metacognitive warning about weak label quality.

In [None]:
# Cell 10: Run 405B Disclaimer Experiments
# Tests if warning about weak labels helps 405B
# Expected runtime: 10 minutes

from notebook_experiments import run_disclaimer_sweep

print("="*70)
print("EXPERIMENT 3: 405B DISCLAIMER")
print("="*70)
print("Testing metacognitive prompt for 405B")
print("Expected runtime: 10 minutes")
print("="*70)

# Find the baseline experiment path
baseline_405b_path = results_405b_baseline['experiment_path']
print(f"Using baseline from: {baseline_405b_path}\n")

results_405b_disclaimer = await run_disclaimer_sweep(
    model_pair="8b_to_405b",
    k_values=[0, 2, 5, 10],
    baseline_exp_path=baseline_405b_path,
    save_results=True
)

print("\n" + "="*70)
print("✓ 405B DISCLAIMER COMPLETE")
print("="*70)
print("\nResults Summary (Delta from Baseline):")
for k, result in results_405b_disclaimer['pgr_results'].items():
    baseline_pgr = results_405b_baseline['pgr_results'][k].pgr
    delta = result.pgr - baseline_pgr
    sign = "+" if delta >= 0 else ""
    print(f"  K={k}: PGR = {result.pgr:.3f} (Δ = {sign}{delta:.3f})")
print("="*70)

In [None]:
# Cell 11: Run 70B Disclaimer Experiments
# Tests if warning about weak labels helps 70B
# Expected runtime: 8 minutes

print("="*70)
print("EXPERIMENT 4: 70B DISCLAIMER")
print("="*70)
print("Testing metacognitive prompt for 70B")
print("Expected runtime: 8 minutes")
print("="*70)

# Find the baseline experiment path
baseline_70b_path = results_70b_baseline['experiment_path']
print(f"Using baseline from: {baseline_70b_path}\n")

results_70b_disclaimer = await run_disclaimer_sweep(
    model_pair="8b_to_70b",
    k_values=[0, 2, 5, 10],
    baseline_exp_path=baseline_70b_path,
    save_results=True
)

print("\n" + "="*70)
print("✓ 70B DISCLAIMER COMPLETE")
print("="*70)
print("\nResults Summary (Delta from Baseline):")
for k, result in results_70b_disclaimer['pgr_results'].items():
    baseline_pgr = results_70b_baseline['pgr_results'][k].pgr
    delta = result.pgr - baseline_pgr
    sign = "+" if delta >= 0 else ""
    print(f"  K={k}: PGR = {result.pgr:.3f} (Δ = {sign}{delta:.3f})")
print("="*70)

# Show K-dependent reversal
print("\n⚠️ Notice the K-dependent reversal:")
for k in [0, 2, 5, 10]:
    baseline_pgr = results_70b_baseline['pgr_results'][k].pgr
    disc_pgr = results_70b_disclaimer['pgr_results'][k].pgr
    delta = disc_pgr - baseline_pgr
    effect = "helps" if delta > 0 else "hurts" if delta < 0 else "neutral"
    print(f"  K={k}: {effect} ({delta:+.3f})")

---
## Part 6: Chain-of-Thought Label Generation

Generate reasoning demonstrations for CoT experiments.

In [None]:
# Cell 12: Generate Weak CoT Labels (8B with Reasoning)
# Generate reasoning demonstrations from the weak 8B model
# Expected runtime: 5 minutes

from src import ModelEvaluator, get_model_pair
from src.model_evaluator import ModelResponse
import json
from datetime import datetime
from pathlib import Path

print("="*70)
print("GENERATING WEAK COT LABELS (8B)")
print("="*70)
print("Generating reasoning demonstrations from 8B model")
print("Expected runtime: 5 minutes")
print("="*70)

# Get weak model
pair = get_model_pair("8b_to_405b")
weak_model_id = pair.weak_model

# Create evaluator with CoT enabled
evaluator = ModelEvaluator(config, use_cot=True)

# Generate labels for first 20 examples from few-shot pool
questions_to_label = few_shot_pool[:20]
questions = [(q.question_id, q.question) for q in questions_to_label]

print(f"\nGenerating reasoning for {len(questions)} questions...\n")

weak_cot_responses = await evaluator.evaluate_batch(
    questions=questions,
    model_id=weak_model_id,
    few_shot_prompt=None,
    verbose=True
)

# Calculate accuracy
gt_map = {q.question_id: q.answer for q in questions_to_label}
num_correct = sum(1 for r in weak_cot_responses if r.answer == gt_map[r.question_id])
accuracy = num_correct / len(weak_cot_responses)

print(f"\n8B CoT Accuracy: {accuracy:.1%} ({num_correct}/{len(weak_cot_responses)})")

# Save to data directory
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_dir = Path("data/cot_weak_labels")
output_dir.mkdir(exist_ok=True, parents=True)
output_file = output_dir / f"8b_cot_weak_labels_{timestamp}.json"

data = {
    "metadata": {
        "timestamp": timestamp,
        "weak_model": weak_model_id,
        "temperature": config.temperature,
        "num_labels": len(weak_cot_responses),
        "accuracy": accuracy,
        "use_cot": True
    },
    "weak_labels": [r.model_dump() for r in weak_cot_responses]
}

with open(output_file, 'w') as f:
    json.dump(data, f, indent=2)

weak_cot_path = str(output_file)
print(f"\n✓ Saved to: {weak_cot_path}")
print("="*70)

In [None]:
# Cell 13: Generate Gold CoT Labels (405B with Reasoning)
# Generate reasoning demonstrations from the strong 405B model
# Expected runtime: 10 minutes

print("="*70)
print("GENERATING GOLD COT LABELS (405B)")
print("="*70)
print("Generating reasoning demonstrations from 405B model")
print("Expected runtime: 10 minutes")
print("="*70)

# Get strong model
strong_model_id = pair.strong_model

print(f"\nGenerating reasoning for {len(questions)} questions...\n")

gold_cot_responses = await evaluator.evaluate_batch(
    questions=questions,
    model_id=strong_model_id,
    few_shot_prompt=None,
    verbose=True
)

# Calculate accuracy
num_correct = sum(1 for r in gold_cot_responses if r.answer == gt_map[r.question_id])
accuracy = num_correct / len(gold_cot_responses)

print(f"\n405B CoT Accuracy: {accuracy:.1%} ({num_correct}/{len(gold_cot_responses)})")

# Save to data directory
output_dir = Path("data/cot_gold_labels")
output_dir.mkdir(exist_ok=True, parents=True)
output_file = output_dir / f"405b_cot_gold_labels_{timestamp}.json"

data = {
    "metadata": {
        "timestamp": timestamp,
        "strong_model": strong_model_id,
        "temperature": config.temperature,
        "num_labels": len(gold_cot_responses),
        "accuracy": accuracy,
        "use_cot": True
    },
    "gold_labels": [r.model_dump() for r in gold_cot_responses]
}

with open(output_file, 'w') as f:
    json.dump(data, f, indent=2)

gold_cot_path = str(output_file)
print(f"\n✓ Saved to: {gold_cot_path}")
print("="*70)

---
## Part 7: Chain-of-Thought Experiments

Run CoT experiments using the reasoning demonstrations generated above.

In [None]:
# Cell 14: Run 405B CoT Experiments
# Tests if reasoning demonstrations help 405B
# Expected runtime: 10 minutes

from notebook_experiments import run_cot_sweep

print("="*70)
print("EXPERIMENT 5: 405B CHAIN-OF-THOUGHT")
print("="*70)
print("Testing CoT with reasoning demonstrations for 405B")
print("Expected runtime: 10 minutes")
print("="*70)

results_405b_cot = await run_cot_sweep(
    model_pair="8b_to_405b",
    k_values=[0, 2, 5, 10],
    baseline_exp_path=baseline_405b_path,
    weak_cot_labels_path=weak_cot_path,
    gold_cot_labels_path=gold_cot_path,
    save_results=True
)

print("\n" + "="*70)
print("✓ 405B COT COMPLETE")
print("="*70)
print("\nResults Summary (Delta from Baseline):")
for k, result in results_405b_cot['pgr_results'].items():
    baseline_pgr = results_405b_baseline['pgr_results'][k].pgr
    delta = result.pgr - baseline_pgr
    sign = "+" if delta >= 0 else ""
    print(f"  K={k}: PGR = {result.pgr:.3f} (Δ = {sign}{delta:.3f})")
print("="*70)

# Show crossover effect
print("\n⚠️ Notice the crossover effect:")
for k in [0, 2, 5, 10]:
    baseline_pgr = results_405b_baseline['pgr_results'][k].pgr
    cot_pgr = results_405b_cot['pgr_results'][k].pgr
    delta = cot_pgr - baseline_pgr
    effect = "helps" if delta > 0.02 else "hurts" if delta < -0.02 else "neutral"
    print(f"  K={k}: {effect} ({delta:+.3f})")

In [None]:
# Cell 15: Run 70B CoT Experiments
# Tests if reasoning demonstrations help 70B
# Expected runtime: 8 minutes

print("="*70)
print("EXPERIMENT 6: 70B CHAIN-OF-THOUGHT")
print("="*70)
print("Testing CoT with reasoning demonstrations for 70B")
print("Expected runtime: 8 minutes")
print("="*70)

results_70b_cot = await run_cot_sweep(
    model_pair="8b_to_70b",
    k_values=[0, 2, 5, 10],
    baseline_exp_path=baseline_70b_path,
    weak_cot_labels_path=weak_cot_path,
    gold_cot_labels_path=gold_cot_path,
    save_results=True
)

print("\n" + "="*70)
print("✓ 70B COT COMPLETE")
print("="*70)
print("\nResults Summary (Delta from Baseline):")
for k, result in results_70b_cot['pgr_results'].items():
    baseline_pgr = results_70b_baseline['pgr_results'][k].pgr
    delta = result.pgr - baseline_pgr
    sign = "+" if delta >= 0 else ""
    print(f"  K={k}: PGR = {result.pgr:.3f} (Δ = {sign}{delta:.3f})")
print("="*70)

# Show that CoT consistently hurts 70B
print("\n⚠️ Notice CoT consistently hurts 70B at K>0:")
for k in [0, 2, 5, 10]:
    baseline_pgr = results_70b_baseline['pgr_results'][k].pgr
    cot_pgr = results_70b_cot['pgr_results'][k].pgr
    delta = cot_pgr - baseline_pgr
    effect = "helps" if delta > 0 else "hurts" if delta < 0 else "neutral"
    print(f"  K={k}: {effect} ({delta:+.3f})")

---
## Part 8: Comprehensive Visualization

Create final 6-panel comparison visualization.

In [None]:
# Cell 16: Generate Comprehensive Comparison Visualization
# Creates a 6-panel visualization comparing all experiments

print("="*70)
print("CREATING COMPREHENSIVE VISUALIZATION")
print("="*70)

# Run the visualization script
!python create_final_comparison.py

print("\n✓ Visualization created: results/final_comprehensive_comparison.png")
print("="*70)

In [None]:
# Cell 17: Display Visualization
# Show the comprehensive comparison

from IPython.display import Image, display

print("Displaying comprehensive comparison...\n")
display(Image('results/final_comprehensive_comparison.png', width=1200))

---
## Part 9: Results Summary

Display final results table with all experiments.

In [None]:
# Cell 18: Create Results Summary Table
# Display all results in a comprehensive table

import pandas as pd

print("="*80)
print("COMPLETE RESULTS SUMMARY")
print("="*80)

# Create summary data
summary_data = []
for k in [0, 2, 5, 10]:
    summary_data.append({
        "K": k,
        "405B Baseline": f"{results_405b_baseline['pgr_results'][k].pgr:.3f}",
        "405B Disclaimer": f"{results_405b_disclaimer['pgr_results'][k].pgr:.3f}",
        "405B CoT": f"{results_405b_cot['pgr_results'][k].pgr:.3f}",
        "70B Baseline": f"{results_70b_baseline['pgr_results'][k].pgr:.3f}",
        "70B Disclaimer": f"{results_70b_disclaimer['pgr_results'][k].pgr:.3f}",
        "70B CoT": f"{results_70b_cot['pgr_results'][k].pgr:.3f}",
    })

summary_df = pd.DataFrame(summary_data)
print(summary_df.to_string(index=False))
print("="*80)

# Display deltas from baseline
print("\nDELTA FROM BASELINE (Intervention Effect)")
print("="*80)

delta_data = []
for k in [0, 2, 5, 10]:
    delta_data.append({
        "K": k,
        "405B Disclaimer": f"{results_405b_disclaimer['pgr_results'][k].pgr - results_405b_baseline['pgr_results'][k].pgr:+.3f}",
        "405B CoT": f"{results_405b_cot['pgr_results'][k].pgr - results_405b_baseline['pgr_results'][k].pgr:+.3f}",
        "70B Disclaimer": f"{results_70b_disclaimer['pgr_results'][k].pgr - results_70b_baseline['pgr_results'][k].pgr:+.3f}",
        "70B CoT": f"{results_70b_cot['pgr_results'][k].pgr - results_70b_baseline['pgr_results'][k].pgr:+.3f}",
    })

delta_df = pd.DataFrame(delta_data)
print(delta_df.to_string(index=False))
print("="*80)

---
## Part 10: Key Findings

Summary of novel research findings.

In [None]:
# Cell 19: Display Key Findings

print("="*80)
print("KEY FINDINGS")
print("="*80)

print("""
🔬 FINDING 1: Scaling Threshold for Robustness
   - 405B maintains PGR ≥ 0.984 across all K values
   - 70B degrades to PGR = 0.864 at K=10
   - Threshold exists between 70B and 405B parameters (~6× difference)

🔬 FINDING 2: K-Dependent Reversal Effect (Disclaimer)
   - Disclaimer helps at low K (+0.080 at K=2 for 70B)
   - Disclaimer hurts at high K (-0.017 at K=10 for 70B)
   - First documented reversal in metacognitive prompting!

🔬 FINDING 3: CoT Crossover Effect
   - 405B: CoT hurts at K=0 (-0.164), helps at K>0 (+0.095 at K=2)
   - 70B: CoT consistently hurts at K>0
   - Large models can filter noisy reasoning, medium models cannot

🔬 FINDING 4: Superelicitation
   - 405B with CoT at K=10: PGR = 1.048 (>100%)
   - Weak supervision + reasoning can exceed gold supervision quality

""")
print("="*80)
print("\n✅ All experiments complete!")
print("\nFor detailed findings, see: KEY_FINDINGS.md")
print("For visualization: results/final_comprehensive_comparison.png")
print("="*80)

---
## Updating Code from Git

To pull the latest changes from the repository in the future:

In [None]:
# Cell 20: Pull Latest Changes from Git (Optional)
# Run this cell to update the code to the latest version

print("Pulling latest changes from git...")
!git pull origin main

print("\n⚠️ After pulling updates, restart the kernel to reload modules:")
print("   Jupyter: Kernel > Restart")
print("   Colab: Runtime > Restart runtime")