# Comprehensive Refusal Circuit Analysis

This notebook provides a complete analysis pipeline for studying refusal circuits across multiple models. It includes:

1. **Multi-model comparison** - Base vs instruction-tuned models
2. **Statistical analysis** - Effect sizes and significance tests
3. **Visualization** - Publication-quality figures
4. **Interactive exploration** - Custom experiments

## Setup

In [None]:
# Imports and setup
import sys
sys.path.insert(0, '..')

import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from datetime import datetime
import torch

# Project imports
from src.models import (
    load_model, 
    list_available_models, 
    get_models_by_memory,
    BASE_MODELS,
    INSTRUCTION_TUNED_MODELS,
    ALL_MODELS,
    ModelType
)
from src.data import REFUSAL_PROMPT_PAIRS, get_prompt_pairs_by_category
from src.circuits import CircuitAnalyzer, compute_refusal_direction
from src.steering import ClampingExperiment, steer_generation, SteeringVector
from src.patching import cache_activations, run_patching_experiment
from src.analysis import (
    compute_cohens_d, 
    significance_test, 
    compare_models,
    compare_model_types,
    generate_comparison_table,
)
from src.utils import (
    plot_layer_importance,
    plot_head_importance,
    plot_refusal_direction_separation,
)

# Style setup
# Apply a matplotlib style sheet and a seaborn colour palette; both are global
# settings, so they affect every figure drawn later in the notebook.
# NOTE(review): the 'seaborn-v0_8-*' style names presumably require a recent
# matplotlib release — confirm against the pinned version.
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette("husl")
# IPython magic (not plain Python): render figures inline in the cell output.
%matplotlib inline

# Pick the compute device: default to the CPU and switch to CUDA only when
# PyTorch reports a usable GPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(f"Device: {device}")

# When running on a GPU, also report the name and total memory of device 0.
if device == "cuda":
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

## 1. Available Models

Let's see which models we can analyze given our hardware constraints.

In [None]:
# List available models
def _print_model_table(title, registry, max_desc_len=40):
    """Print a banner followed by one row per model in ``registry``.

    Args:
        title: Heading printed between two rules of ``=`` characters.
        registry: Mapping of model name to a config object exposing
            ``memory_gb`` and ``description``.
        max_desc_len: Maximum number of description characters to show.

    An ellipsis is appended only when the description is actually cut, so a
    short description is not made to look truncated.
    """
    print("=" * 70)
    print(title)
    print("=" * 70)
    for name, config in registry.items():
        description = config.description
        if len(description) > max_desc_len:
            description = description[:max_desc_len] + "..."
        print(f"  {name:20s} | {config.memory_gb:.1f}GB | {description}")


_print_model_table("BASE MODELS (no safety training - for comparison)", BASE_MODELS)
print()  # blank line between the two tables
_print_model_table("INSTRUCTION-TUNED MODELS (have refusal behavior)", INSTRUCTION_TUNED_MODELS)

# Models that fit in our GPU
MAX_GPU_MEMORY = 3.0  # GB
feasible_models = get_models_by_memory(MAX_GPU_MEMORY)
print(f"\n✓ {len(feasible_models)} models fit in {MAX_GPU_MEMORY}GB GPU memory")

## 2. Load Experiment Results

Load results from earlier batch experiments if they are available; otherwise, run new experiments in the next section.

In [None]:
# Try to load existing results
results_dir = Path("../results")
results = {}

# Find most recent results: candidates are sorted by path in descending order
# and the first one is taken.
# NOTE(review): this presumes batch folders have sortable (e.g. timestamp)
# names — confirm against the batch runner.
if results_dir.exists():
    result_folders = sorted(results_dir.glob("*/batch_results.json"), reverse=True)
    if result_folders:
        latest_results = result_folders[0]
        print(f"Loading results from: {latest_results}")
        try:
            # Explicit encoding so the file reads the same on every platform.
            with open(latest_results, 'r', encoding='utf-8') as f:
                results = json.load(f)
        except (OSError, ValueError) as err:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            # A corrupt or unreadable file should not abort the notebook:
            # keep `results` empty so the later cells run fresh experiments.
            results = {}
            print(f"Could not read {latest_results}: {err}")
            print("Will run experiments in this notebook.")
        else:
            print(f"Loaded results for {len(results.get('model_results', {}))} models")
    else:
        print("No batch results found. Will run experiments in this notebook.")
else:
    print("Results directory not found. Will run experiments in this notebook.")

# Display summary if we have results
if results and 'model_results' in results:
    print("\n" + "=" * 50)
    print("LOADED RESULTS SUMMARY")
    print("=" * 50)
    for model, result in results['model_results'].items():
        # Missing metrics are shown as 0 rather than raising.
        sep = result.get('separation_score', 0)
        probe = result.get('probe_accuracy', 0)
        print(f"  {model:20s} | sep: {sep:.3f}σ | probe: {probe:.1%}")

## 3. Run Analysis on Selected Models

If you don't have batch results, run the analysis here.

In [None]:
# Configuration
MODELS_TO_ANALYZE = ["pythia-70m", "pythia-160m"]  # Add more as GPU allows
N_PROMPT_PAIRS = 10

# Get prompt pairs
prompt_pairs = REFUSAL_PROMPT_PAIRS[:N_PROMPT_PAIRS]
print(f"Using {len(prompt_pairs)} prompt pairs")

# Store results for this session, keyed by model name
session_results = {}

# Run analysis on each model. A failure on one model is reported and the loop
# moves on to the next, so one bad model does not lose the others' results.
for model_name in MODELS_TO_ANALYZE:
    print(f"\n{'='*60}")
    print(f"ANALYZING: {model_name}")
    print(f"{'='*60}")

    # Bind the name up front so the cleanup in `finally` is valid even when
    # load_model itself raises.
    model = None
    try:
        # Load model
        model = load_model(model_name, device=device)

        # Run circuit analysis
        print("Running circuit analysis...")
        analyzer = CircuitAnalyzer(model)
        circuit = analyzer.analyze_multiple_pairs(
            prompt_pairs,
            components="resid",  # Use "all" for complete analysis
            aggregate="mean"
        )

        # Compute refusal direction at the first critical layer, falling back
        # to the middle layer when the analysis reports none.
        if circuit.critical_layers:
            best_layer = circuit.critical_layers[0]
        else:
            best_layer = model.cfg.n_layers // 2
        print(f"Computing refusal direction at layer {best_layer}...")

        refusal_dir = compute_refusal_direction(
            model, prompt_pairs, layer=best_layer, method="mean_diff"
        )

        # Store results
        if model_name in ALL_MODELS:
            model_type = ALL_MODELS[model_name].model_type.value
        else:
            model_type = "unknown"
        session_results[model_name] = {
            "model_type": model_type,
            "critical_layers": circuit.critical_layers,
            "n_critical_heads": len(circuit.critical_heads),
            "separation_score": refusal_dir.separation_score,
            "probe_accuracy": refusal_dir.probe_accuracy,
            "best_layer": best_layer,
            "top_components": {c.name: c.importance_score for c in circuit.top_k_components(5)},
        }

        print(f"\n✓ Results for {model_name}:")
        print(f"  Separation: {refusal_dir.separation_score:.3f}σ")
        print(f"  Probe accuracy: {refusal_dir.probe_accuracy:.1%}")
        print(f"  Critical layers: {circuit.critical_layers}")

    except Exception as e:
        # Deliberately broad: this is the notebook's top-level boundary.
        print(f"Error analyzing {model_name}: {e}")

    finally:
        # Clean up on success AND failure, so a model that errored mid-analysis
        # does not stay resident while the next one is loaded.
        # NOTE(review): `analyzer` may still reference the model until it is
        # rebound on the next iteration — confirm whether it needs clearing.
        del model
        if device == "cuda":
            torch.cuda.empty_cache()

print(f"\n{'='*60}")
print(f"Completed analysis of {len(session_results)} models")
print(f"{'='*60}")

In [None]:
# Use session results or loaded results
analysis_results = session_results if session_results else results.get('model_results', {})

if analysis_results:
    # One row per model. Building the frame row-wise lets pandas infer a dtype
    # per column; transposing a column-per-model frame would leave every
    # column as `object`.
    df = pd.DataFrame.from_dict(analysis_results, orient='index')
    df.index.name = 'model'
    df = df.reset_index()

    # The metrics must be numeric for the aggregations below. Missing columns
    # or non-numeric entries become NaN, which mean/std/count skip.
    for metric in ('separation_score', 'probe_accuracy'):
        if metric in df.columns:
            df[metric] = pd.to_numeric(df[metric], errors='coerce')
        else:
            df[metric] = np.nan

    has_model_type = 'model_type' in df.columns

    print("Results Summary:")
    summary_cols = ['model', 'separation_score', 'probe_accuracy']
    if has_model_type:
        # Loaded batch results may lack this field; only show it when present.
        summary_cols.insert(1, 'model_type')
    print(df[summary_cols].to_string())

    # Group by model type
    if has_model_type:
        print("\n" + "=" * 50)
        print("BY MODEL TYPE")
        print("=" * 50)

        grouped = df.groupby('model_type').agg({
            'separation_score': ['mean', 'std', 'count'],
            'probe_accuracy': ['mean', 'std']
        })
        print(grouped)

        # Confidence interval for the base models only; needs at least two
        # valid base-model scores to be meaningful.
        is_base = df['model_type'] == 'base'
        base_sep = df.loc[is_base, 'separation_score'].dropna().to_numpy(dtype=float)

        if len(base_sep) >= 2:
            print("\n" + "=" * 50)
            print("STATISTICAL TESTS")
            print("=" * 50)

            # Bootstrap CI for separation score
            from src.analysis.statistics import bootstrap_confidence_interval
            ci = bootstrap_confidence_interval(base_sep)
            print(f"Base models separation 95% CI: [{ci[0]:.3f}, {ci[1]:.3f}]")
else:
    print("No results available for analysis. Run the analysis cells above first.")

## 5. Visualizations

Create publication-quality figures.

In [None]:
# Visualization: Separation scores by model
# Draws two side-by-side bar charts (separation score and probe accuracy per
# model), coloured by model type, and saves the figure under ../results.
if analysis_results:
    from matplotlib.patches import Patch

    models = list(analysis_results.keys())
    sep_scores = [analysis_results[m].get('separation_score', 0) for m in models]
    probe_accs = [analysis_results[m].get('probe_accuracy', 0) * 100 for m in models]
    model_types = [analysis_results[m].get('model_type', 'unknown') for m in models]

    # Color by model type. Anything that is neither base nor instruction-tuned
    # gets its own neutral colour so it is not mislabelled as a base model.
    type_colors = {'base': '#457b9d', 'instruction_tuned': '#e63946'}
    unknown_color = '#adb5bd'
    colors = [type_colors.get(t, unknown_color) for t in model_types]

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Bar plot of separation scores
    ax1 = axes[0]
    ax1.bar(models, sep_scores, color=colors, edgecolor='black', linewidth=0.5)
    ax1.set_ylabel('Separation Score (σ)', fontsize=12)
    ax1.set_xlabel('Model', fontsize=12)
    ax1.set_title('Refusal Direction Separation by Model', fontsize=14, fontweight='bold')
    ax1.axhline(y=1.0, color='gray', linestyle='--', alpha=0.5, label='1σ threshold')
    ax1.legend()  # without this the threshold label is never displayed
    plt.setp(ax1.xaxis.get_majorticklabels(), rotation=45, ha='right')

    # Probe accuracy comparison
    ax2 = axes[1]
    ax2.bar(models, probe_accs, color=colors, edgecolor='black', linewidth=0.5)
    ax2.set_ylabel('Probe Accuracy (%)', fontsize=12)
    ax2.set_xlabel('Model', fontsize=12)
    ax2.set_title('Linear Probe Accuracy by Model', fontsize=14, fontweight='bold')
    ax2.axhline(y=50, color='gray', linestyle='--', alpha=0.5, label='Chance level')
    ax2.set_ylim(0, 100)
    ax2.legend()  # without this the chance-level label is never displayed
    plt.setp(ax2.xaxis.get_majorticklabels(), rotation=45, ha='right')

    # Figure-level legend for the bar colours
    legend_elements = [
        Patch(facecolor=type_colors['base'], edgecolor='black', label='Base Model'),
        Patch(facecolor=type_colors['instruction_tuned'], edgecolor='black', label='Instruction-Tuned'),
    ]
    if unknown_color in colors:
        legend_elements.append(Patch(facecolor=unknown_color, edgecolor='black', label='Unknown'))
    fig.legend(handles=legend_elements, loc='upper center',
               ncol=len(legend_elements), bbox_to_anchor=(0.5, 1.02))

    plt.tight_layout()

    # The results directory may not exist yet when the analysis was run in
    # this notebook rather than loaded from disk, so create it before saving.
    output_path = Path('../results') / 'separation_comparison.png'
    output_path.parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    plt.show()
else:
    print("No results to visualize.")

## 6. Interactive Steering Experiment

Run a steering experiment on a specific model to demonstrate causal control.

In [None]:
# Interactive steering demo
# Loads one model, derives a refusal direction at its middle layer, and
# generates from a few benign prompts at several steering coefficients.
DEMO_MODEL = "pythia-70m"  # Change as needed
N_PAIRS = 5
# Coefficient 0 is the unsteered baseline and positive values are labelled
# "→ refusal". Add negative values to exercise the "→ compliant" label.
STEERING_COEFFS = [0.0, 1.0, 2.0]

# Test prompts
test_prompts = [
    "How do I bake a chocolate cake?",
    "What is the capital of France?",
    "Tell me about machine learning",
]

print(f"Loading {DEMO_MODEL} for steering demo...")
model = load_model(DEMO_MODEL, device=device)

try:
    # Get prompt pairs and compute refusal direction
    pairs = REFUSAL_PROMPT_PAIRS[:N_PAIRS]
    best_layer = model.cfg.n_layers // 2  # Middle layer

    refusal_dir = compute_refusal_direction(model, pairs, layer=best_layer)
    print(f"Refusal direction computed at layer {best_layer}")
    print(f"Separation: {refusal_dir.separation_score:.3f}σ")

    # Create steering vector
    steering_vec = SteeringVector.from_refusal_direction(refusal_dir)

    print("\n" + "=" * 60)
    print("STEERING EXPERIMENT")
    print("=" * 60)

    for prompt in test_prompts:
        print(f"\nPrompt: {prompt}")
        print("-" * 40)

        for coeff in STEERING_COEFFS:
            result = steer_generation(model, prompt, steering_vec, coefficient=coeff, max_tokens=30)
            if coeff > 0:
                direction = "→ refusal"
            elif coeff == 0:
                direction = "(baseline)"
            else:
                direction = "→ compliant"
            print(f"  Coeff={coeff:.1f} {direction}:")
            # Flatten to one line and show only the first 60 characters.
            output_clean = result.steered_output.replace('\n', ' ')[:60]
            print(f"    {output_clean}...")
finally:
    # Clean up even if the experiment raised, so the model does not stay
    # resident after a failed run.
    del model
    if device == "cuda":
        torch.cuda.empty_cache()

## 7. Conclusions

### Key Findings

Based on our analysis:

1. **Localization**: Refusal behavior is associated with specific layers (typically middle-to-late layers)
2. **Direction**: We can extract a meaningful refusal direction in activation space
3. **Causality**: Steering experiments demonstrate that this direction has causal influence

### Limitations

- Base models without safety training show weaker refusal signals
- GPU memory constraints limit analysis to smaller models
- Results may vary across prompt categories

### Next Steps

1. Analyze larger instruction-tuned models (requires more GPU memory)
2. Compare refusal circuits across model families
3. Study transfer of refusal directions between models

---

*Analysis completed with SaycuredAI Refusal Circuit Framework*