# Refusal Directions Analysis

This notebook provides analysis and visualization tools for refusal directions experiments.

In [None]:
import sys
sys.path.insert(0, '..')

import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import json
from pathlib import Path

# Notebook-wide plotting defaults: seaborn grid theme and a wide default canvas
sns.set_style("whitegrid")
plt.rcParams.update({'figure.figsize': (12, 6)})

## 1. Load and Analyze Directions

In [None]:
# Load extracted directions
# The checkpoint is read as a mapping with "layer" and "direction" entries and
# an optional "analysis" entry (defaults to an empty dict when absent).
directions_path = "../results/best_direction.pt"

# map_location="cpu" remaps tensors that were saved from a GPU so the file
# also loads on a machine without CUDA; later cells only need CPU data.
# NOTE(review): torch.load unpickles the file, so only open checkpoints that
# were produced by a trusted run.
directions_data = torch.load(directions_path, map_location="cpu")

best_layer = directions_data["layer"]
direction = directions_data["direction"]
analysis = directions_data.get("analysis", {})

print(f"Best Layer: {best_layer}")
print(f"Direction Shape: {direction.shape}")
print(f"Direction Magnitude: {direction.norm().item():.4f}")

In [None]:
# Visualize direction magnitudes
# detach() drops autograd history and float() upcasts half/bfloat16 values;
# without them Tensor.numpy() raises for such tensors. ravel() flattens any
# extra leading axis so the histogram and the ranking see one 1-D sample.
direction_np = direction.detach().float().cpu().numpy().ravel()

fig, axes = plt.subplots(1, 2, figsize=(15, 5))

# Histogram of values
axes[0].hist(direction_np, bins=50, alpha=0.7)
axes[0].set_xlabel('Direction Component Value')
axes[0].set_ylabel('Frequency')
axes[0].set_title('Distribution of Direction Components')
axes[0].axvline(0, color='red', linestyle='--', alpha=0.5)

# Top dimensions
# Clamp to the number of components so the bar positions and the bar heights
# always have the same length, even for a direction with fewer than 20 entries.
top_k = min(20, direction_np.size)
abs_values = np.abs(direction_np)
# argsort is ascending: take the last top_k indices and reverse for largest-first
top_indices = np.argsort(abs_values)[-top_k:][::-1]
top_values = direction_np[top_indices]

# Sign-coded bars: red for negative components, blue otherwise
bar_colors = ['red' if v < 0 else 'blue' for v in top_values]
axes[1].bar(range(top_k), top_values, color=bar_colors, alpha=0.7)
axes[1].set_xlabel('Rank')
axes[1].set_ylabel('Component Value')
axes[1].set_title(f'Top {top_k} Direction Components (by magnitude)')
axes[1].axhline(0, color='black', linestyle='-', linewidth=0.5)

plt.tight_layout()
plt.show()

# "Sparsity" here = share of components whose magnitude is below 1% of the largest one
print(f"\nSparsity: {(abs_values < 0.01 * abs_values.max()).sum() / abs_values.size:.2%}")
print(f"Max absolute value: {abs_values.max():.4f}")
print(f"Mean absolute value: {abs_values.mean():.4f}")

## 2. Analyze Ablation Results

In [None]:
# Load ablation results
# Reads the "baseline", "ablation" and "attack_success_rate" entries of the
# results file and prints a short before/after summary.
with open("../results/ablation_results.json", 'r', encoding="utf-8") as f:
    ablation_results = json.load(f)

print("Baseline vs Ablation Performance:")
print("="*50)

baseline = ablation_results["baseline"]
ablation = ablation_results["ablation"]

print(f"Baseline Refusal Rate: {baseline['refusal_rate']:.2%}")
print(f"Ablation Refusal Rate: {ablation['refusal_rate']:.2%}")
print(f"\nReduction: {baseline['refusal_rate'] - ablation['refusal_rate']:.2%}")

# Safety rate is optional. Check each condition on its own so a file that has
# it for only one of them does not raise KeyError.
safety_lines = [
    f"{label} Safety Rate: {metrics['safety_rate']:.2%}"
    for label, metrics in (("Baseline", baseline), ("Ablation", ablation))
    if 'safety_rate' in metrics
]
if safety_lines:
    print("\n" + "\n".join(safety_lines))

print(f"\nAttack Success Rate: {ablation_results['attack_success_rate']:.2%}")

In [None]:
# Visualize by category
# Reads ablation_results["by_category"], a mapping of category name to a dict
# holding that category's "refusal_rate".
by_category = ablation_results["by_category"]
categories = list(by_category.keys())
refusal_rates = [by_category[cat]["refusal_rate"] for cat in categories]
category_labels = [c.replace('_', ' ').title() for c in categories]

fig, ax = plt.subplots(figsize=(12, 6))
bars = ax.bar(range(len(categories)), refusal_rates, alpha=0.7)

# Color bars by success (lower refusal = more success).
# The array is forced to float: a colormap called with an integer array treats
# the values as lookup-table indices rather than positions in [0, 1], which
# would mis-color the bars if every rate in the JSON happened to be 0 or 1.
bar_colors = plt.cm.RdYlGn_r(np.asarray(refusal_rates, dtype=float))
for bar, color in zip(bars, bar_colors):
    bar.set_color(color)

ax.set_xlabel('Harm Category')
ax.set_ylabel('Refusal Rate After Ablation')
ax.set_title('Refusal Rates by Harm Category (After Directional Ablation)')
ax.set_xticks(range(len(categories)))
ax.set_xticklabels(category_labels, rotation=45, ha='right')
ax.axhline(0.5, color='red', linestyle='--', alpha=0.5, label='50% threshold')
ax.legend()
plt.tight_layout()
plt.show()

# Text listing, lowest refusal rate first
print("\nCategory-specific results:")
for cat in sorted(categories, key=lambda x: by_category[x]["refusal_rate"]):
    rate = by_category[cat]["refusal_rate"]
    print(f"  {cat.replace('_', ' ').title():30s}: {rate:.2%}")

## 3. Analyze Addition Results

In [None]:
# Load addition results
# Only the file read sits inside the try: a FileNotFoundError raised later
# (for example while rendering the figure) must not be reported as
# "results not found". The else branch runs only when the load succeeded.
try:
    with open("../results/addition_results.json", 'r', encoding="utf-8") as f:
        addition_results = json.load(f)
except FileNotFoundError:
    print("Addition results not found. Run test_addition.py first.")
else:
    print("Activation Addition Performance:")
    print("="*50)

    baseline_add = addition_results["baseline"]
    addition = addition_results["addition"]

    print(f"Baseline Refusal Rate (harmless): {baseline_add['refusal_rate']:.2%}")
    print(f"With Addition Refusal Rate:       {addition['refusal_rate']:.2%}")
    print(f"\nInduced Increase: {addition_results['induced_refusal_increase']:.2%}")

    # Visualize: one bar per condition, green = no intervention, red = addition
    fig, ax = plt.subplots(figsize=(8, 6))
    conditions = ['Baseline\n(No Intervention)', 'Activation\nAddition']
    rates = [baseline_add['refusal_rate'], addition['refusal_rate']]

    bars = ax.bar(conditions, rates, alpha=0.7)
    bars[0].set_color('green')
    bars[1].set_color('red')

    ax.set_ylabel('Refusal Rate on Harmless Instructions')
    ax.set_title('Effect of Activation Addition on Harmless Instructions')
    ax.set_ylim([0, 1])

    # Add value labels, centered just above each bar
    for bar in bars:
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width()/2., height,
                f'{height:.1%}', ha='center', va='bottom')

    plt.tight_layout()
    plt.show()

## 4. Compare All Interventions

In [None]:
# Load all results
# Condition label -> results file that holds its metrics. Baseline and
# Ablation are two sections of the same file.
results_files = {
    'Baseline': '../results/ablation_results.json',
    'Ablation': '../results/ablation_results.json',
    'Orthogonalization': '../results/ortho_results.json'
}

# Top-level keys tried, in this order, for every non-baseline condition.
intervention_keys = ('ablation', 'orthogonalized')

comparison_data = {}

try:
    loaded_files = {}  # path -> parsed JSON, so a shared file is read only once
    for name, path in results_files.items():
        if not Path(path).exists():
            continue
        if path not in loaded_files:
            with open(path, 'r', encoding="utf-8") as f:
                loaded_files[path] = json.load(f)
        data = loaded_files[path]

        if name == 'Baseline':
            comparison_data[name] = data['baseline']
            continue

        # Resolve the key lazily: 'baseline' is only touched when no
        # intervention key exists, so a file without a 'baseline' section no
        # longer raises KeyError when its intervention metrics are present.
        metrics_key = next((key for key in intervention_keys if key in data), None)
        if metrics_key is not None:
            comparison_data[name] = data[metrics_key]
        elif 'baseline' in data:
            # Keep the fallback, but say so: the bar would otherwise look like
            # an intervention result while showing baseline numbers.
            print(f"Note: {path} has no intervention metrics; showing its baseline for '{name}'.")
            comparison_data[name] = data['baseline']
        else:
            print(f"Note: skipping '{name}': no usable metrics in {path}.")

    if len(comparison_data) > 1:
        # Create comparison plot
        fig, axes = plt.subplots(1, 2, figsize=(15, 6))

        conditions = list(comparison_data.keys())
        refusal_rates = [comparison_data[c]['refusal_rate'] for c in conditions]

        # Refusal rates: first bar blue, every later bar red
        bars1 = axes[0].bar(conditions, refusal_rates, alpha=0.7)
        bars1[0].set_color('blue')
        for bar in bars1[1:]:
            bar.set_color('red')

        axes[0].set_ylabel('Refusal Rate')
        axes[0].set_title('Refusal Rates Across Interventions')
        axes[0].set_ylim([0, 1])

        for bar in bars1:
            height = bar.get_height()
            axes[0].text(bar.get_x() + bar.get_width()/2., height,
                        f'{height:.1%}', ha='center', va='bottom')

        # Safety rates (if available)
        if all('safety_rate' in comparison_data[c] for c in conditions):
            safety_rates = [comparison_data[c]['safety_rate'] for c in conditions]

            bars2 = axes[1].bar(conditions, safety_rates, alpha=0.7)
            bars2[0].set_color('green')
            for bar in bars2[1:]:
                bar.set_color('orange')

            axes[1].set_ylabel('Safety Rate')
            axes[1].set_title('Safety Rates Across Interventions')
            axes[1].set_ylim([0, 1])

            for bar in bars2:
                height = bar.get_height()
                axes[1].text(bar.get_x() + bar.get_width()/2., height,
                            f'{height:.1%}', ha='center', va='bottom')
        else:
            # Hide the unused panel instead of rendering an empty set of axes
            axes[1].set_visible(False)

        plt.tight_layout()
        plt.show()

        # Print summary table
        print("\nSummary Table:")
        print("="*60)
        print(f"{'Condition':<20} {'Refusal Rate':<15} {'Safety Rate':<15}")
        print("-"*60)
        for cond in conditions:
            refusal = comparison_data[cond]['refusal_rate']
            safety = comparison_data[cond].get('safety_rate', 'N/A')
            # JSON may store a rate of exactly 0 or 1 as an int; format it as a
            # percentage as well
            if isinstance(safety, (int, float)):
                print(f"{cond:<20} {refusal:<15.2%} {safety:<15.2%}")
            else:
                print(f"{cond:<20} {refusal:<15.2%} {safety:<15}")
        print("="*60)
    else:
        print("Not enough results to compare. Run all experiments first to see full comparison.")

# Only failures that relate to reading the result files are reported here;
# anything else is a genuine bug and should surface with its traceback.
except (OSError, json.JSONDecodeError, KeyError) as e:
    print(f"Error loading comparison data: {e}")
    print("Run all experiments first to see full comparison.")

## 5. Export Figures for Paper

In [None]:
# Create figures directory
figures_dir = Path("../results/figures")
figures_dir.mkdir(exist_ok=True, parents=True)

# This cell writes no figure: it only makes sure the output directory exists,
# so the messages describe that instead of claiming figures were saved.
print(f"Figures directory ready: {figures_dir}")
print("\nTo export a figure, add plt.savefig(figures_dir / '<name>.png', dpi=300, bbox_inches='tight') "
      "before plt.show() in a plotting cell above and re-run that cell.")