In [None]:
import os
import json
import torch
import numpy as np
from tqdm.notebook import tqdm
from collections import defaultdict
from utils.utils_exp import run_sae_intervention_on_small_models
from utils.utils_model import get_pythia_70m, get_gpt2_small

import random

# Run on GPU when available; every model/experiment below receives this device.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {DEVICE}")

# Seed every RNG this notebook may touch. torch.manual_seed covers all torch
# devices; Python's `random` is seeded as well in case helper code (e.g. the
# utils.* modules) draws from it.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
# Every JSON dump produced by the experiment cells lands in this directory.
results_dir = "results/pythia_gpt2"
os.makedirs(name=results_dir, exist_ok=True)
print("Results will be saved to: " + results_dir)

# Helper function to save results
def save_experiment_results(results, model_name, suffix="", output_dir=None):
    """Serialize an experiment-results structure to a UTF-8 JSON file.

    The file is named ``{model_name}_sae_intervention_results{suffix}.json``.

    Args:
        results: Nested dicts/lists/tuples that may contain torch tensors,
            numpy arrays and numpy scalars (as values or as dict keys).
        model_name: Filename prefix (e.g. 'pythia', 'gpt2').
        suffix: Inserted before '.json' (e.g. '_0.2_<timestamp>').
        output_dir: Destination directory. Defaults to the notebook-level
            ``results_dir`` when omitted.

    Returns:
        str: Path of the written file.
    """
    if output_dir is None:
        output_dir = results_dir
    filename = f"{model_name}_sae_intervention_results{suffix}.json"
    filepath = os.path.join(output_dir, filename)

    def convert_to_serializable(obj):
        """Recursively turn tensors/arrays/numpy scalars into JSON-native values."""
        if isinstance(obj, torch.Tensor):
            # detach() + tolist() also works for grad-tracking and bfloat16
            # tensors, where .numpy() would raise.
            return obj.detach().cpu().tolist()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.generic):
            # Covers every numpy scalar type (float16..64, int8..uint64, bool_, ...),
            # none of which json.dump accepts directly.
            return obj.item()
        if isinstance(obj, dict):
            # json only accepts str/int/float/bool/None keys, so numpy-scalar
            # keys (e.g. np.int64 feature ids) must be converted as well.
            return {
                (key.item() if isinstance(key, np.generic) else key): convert_to_serializable(value)
                for key, value in obj.items()
            }
        if isinstance(obj, (list, tuple)):
            return [convert_to_serializable(item) for item in obj]
        return obj

    serializable_results = convert_to_serializable(results)

    # ensure_ascii=False keeps token strings readable; hence the explicit UTF-8.
    with open(filepath, 'w', encoding='utf-8') as f:
        json.dump(serializable_results, f, indent=2, ensure_ascii=False)

    print(f"Results saved to: {filepath}")
    return filepath

In [3]:
# Pythia-70m setup: load the model/tokenizer pair and bundle it with the
# experiment hyper-parameters passed to run_sae_intervention_on_small_models.
pythia_model, pythia_tokenizer = get_pythia_70m(DEVICE)
PYTHIA_CONFIG = dict(
    model_name='pythia',
    layer_types=['att', 'mlp', 'res'],
    layer_range=range(0, 6),  # layers 0-5
    model=pythia_model,
    tokenizer=pythia_tokenizer,
    similarity_threshold=0.2,
    n_target_features=10,
    n_interference_features=3,
    n_top_tokens=5,
    n_test_sentences=3,
    seed=50,
    device=DEVICE,
    use_auto_model=True,
)

# Intervention scales: the positive magnitudes below mirrored into negatives,
# giving an ascending, symmetric sweep from -20 to 20 (zero excluded).
scale_values = [0.5, 1, 1.5, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 14, 17, 20]
SCALE_RANGE = [-magnitude for magnitude in reversed(scale_values)]
SCALE_RANGE.extend(scale_values)

In [None]:
# Run the Pythia experiment: all PYTHIA_CONFIG settings plus the scale sweep.
pythia_results = run_sae_intervention_on_small_models(**PYTHIA_CONFIG, scale_range=SCALE_RANGE)

In [None]:
# Persist the Pythia results, then report how many features/tests they contain.
import time

timestamp = str(int(time.time()))
thres = PYTHIA_CONFIG['similarity_threshold']
suffix = f'_{thres}_{timestamp}'
pythia_filepath = save_experiment_results(pythia_results, 'pythia', suffix)

# layer_type -> {layer_idx: number of features tested in that layer}
layer_stats = {
    layer_type: {layer_idx: len(features) for layer_idx, features in layers.items()}
    for layer_type, layers in pythia_results.items()
}
total_features = sum(count for per_layer in layer_stats.values() for count in per_layer.values())

# Count total tests (tokens x sentences): one per entry in a feature's 'tests'.
total_tests = sum(
    len(feature_data.get('tests', {}))
    for layers in pythia_results.values()
    for features in layers.values()
    for feature_data in features.values()
)

summary_lines = [
    "\nPythia Results Summary:",
    f"- Total features tested: {total_features}",
    f"- Total token tests: {total_tests}",
    f"- Layer types: {list(pythia_results.keys())}",
]
print("\n".join(summary_lines))

for layer_type, layers in layer_stats.items():
    print(f"\n{layer_type.upper()} layers:")
    for layer_idx, feature_count in layers.items():
        print(f"  Layer {layer_idx}: {feature_count} features")

print(f"\nFull results saved to: {pythia_filepath}")

In [None]:
import time
import traceback

gpt2_model, gpt2_tokenizer = get_gpt2_small(DEVICE)

GPT2_CONFIG = {
    'model_name': 'gpt2',
    'layer_types': ['att', 'res_mid', 'mlp', 'res_post'],
    'layer_range': range(0, 12),  # Layers 0-11
    'model': gpt2_model,
    'tokenizer': gpt2_tokenizer,
    'n_target_features': 10,
    'n_interference_features': 3,
    'n_top_tokens': 5,
    'n_test_sentences': 3,
    'seed': 53,
    'device': DEVICE,
    'use_auto_model': True
}

# One full experiment is run (and saved separately) per similarity threshold.
semantic_thresholds = [0.4, 0.3, 0.2, 0.15]

scale_values = [0.5, 1, 1.5, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 14, 17, 20]
SCALE_RANGE = [-x for x in scale_values[::-1]] + scale_values


def _summarize_intervention_results(results):
    """Count features and token tests in one threshold's results.

    Expects the nested layout used throughout this notebook,
    ``results[layer_type][layer_idx][feature_id]``, where each feature entry may
    hold a ``'tests'`` mapping whose entries are counted as token tests.

    Returns:
        tuple: ``(total_features, total_tests, layer_stats)`` where
        ``layer_stats[layer_type][layer_idx]`` is the feature count of that layer.
    """
    layer_stats = {}
    total_features = 0
    total_tests = 0
    for layer_type, layers in results.items():
        layer_stats[layer_type] = {}
        for layer_idx, features in layers.items():
            layer_stats[layer_type][layer_idx] = len(features)
            total_features += len(features)
            total_tests += sum(len(feature_data.get('tests', {})) for feature_data in features.values())
    return total_features, total_tests, layer_stats


print(f"Starting GPT-2 SAE experiments with semantic thresholds: {semantic_thresholds}")
print("Configuration:")
print(f"- Layer types: {GPT2_CONFIG['layer_types']}")
print(f"- Layer range: {list(GPT2_CONFIG['layer_range'])}")
print(f"- Target features: {GPT2_CONFIG['n_target_features']}")
print(f"- Interference features: {GPT2_CONFIG['n_interference_features']}")
print(f"- Top tokens: {GPT2_CONFIG['n_top_tokens']}")
print(f"- Test sentences: {GPT2_CONFIG['n_test_sentences']}")
print(f"- Scale range: {len(SCALE_RANGE)} values from {min(SCALE_RANGE)} to {max(SCALE_RANGE)}")

gpt2_results = {}
total_start_time = time.time()

for i, threshold in enumerate(semantic_thresholds):
    print(f"\n{'='*80}")
    print(f"Running GPT-2 SAE experiment {i+1}/{len(semantic_thresholds)} with similarity threshold: {threshold}")
    print(f"{'='*80}")

    config = GPT2_CONFIG.copy()
    config['similarity_threshold'] = threshold

    start_time = time.time()

    # Only a failure of the experiment itself marks this threshold as failed.
    try:
        results = run_sae_intervention_on_small_models(
            scale_range=SCALE_RANGE,
            **config
        )
    except Exception as e:
        print(f"Error running experiment with threshold {threshold}: {e}")
        traceback.print_exc()
        gpt2_results[threshold] = {'error': str(e)}
        continue

    gpt2_results[threshold] = results

    elapsed_time = time.time() - start_time
    total_elapsed = time.time() - total_start_time

    # Summary/saving errors are reported but must not replace the (expensive)
    # computed results with an error marker.
    try:
        total_features, total_tests, layer_stats = _summarize_intervention_results(results)

        print(f"\nCompleted threshold {threshold} in {elapsed_time:.2f} seconds")
        print(f"Total elapsed time: {total_elapsed:.2f} seconds")
        print(f"Statistics for threshold {threshold}:")
        print(f"- Total features tested: {total_features}")
        print(f"- Total token tests: {total_tests}")
        print(f"- Layer types: {list(results.keys())}")

        for layer_type, layers in layer_stats.items():
            layer_feature_count = sum(layers.values())
            print(f"  {layer_type.upper()}: {layer_feature_count} features across {len(layers)} layers")

        timestamp = str(int(time.time()))
        filepath = save_experiment_results(results, 'gpt2', f'_{threshold}_{timestamp}')
        print(f"- Results saved to: {filepath}")
    except Exception as e:
        print(f"Error summarizing/saving results for threshold {threshold} (results kept in gpt2_results): {e}")
        traceback.print_exc()

    remaining_thresholds = len(semantic_thresholds) - (i + 1)
    if remaining_thresholds > 0:
        avg_time_per_threshold = total_elapsed / (i + 1)
        estimated_remaining = avg_time_per_threshold * remaining_thresholds
        print(f"- Estimated remaining time: {estimated_remaining:.1f} seconds ({estimated_remaining/60:.1f} minutes)")

total_elapsed_time = time.time() - total_start_time

print(f"\n{'='*80}")
print("All GPT-2 SAE experiments completed!")
print(f"Total execution time: {total_elapsed_time:.2f} seconds ({total_elapsed_time/60:.1f} minutes)")
print(f"Successfully completed thresholds: {[t for t, r in gpt2_results.items() if 'error' not in r]}")
print(f"Failed thresholds: {[t for t, r in gpt2_results.items() if 'error' in r]}")

print("\nOverall GPT-2 SAE Results Summary:")
for threshold in semantic_thresholds:
    threshold_results = gpt2_results.get(threshold)
    if threshold_results is not None and 'error' not in threshold_results:
        total_features, total_tests, _ = _summarize_intervention_results(threshold_results)
        print(f"Threshold {threshold}:")
        print(f"  - Features: {total_features}")
        print(f"  - Tests: {total_tests}")
        print(f"  - Layer types: {list(threshold_results.keys())}")
    else:
        print(f"Threshold {threshold}: FAILED")

print(f"\nAll result files are saved in the '{results_dir}' directory.")

In [None]:
import time
import traceback

# NOTE: smaller re-run of the GPT-2 sweep (3 target features, 3 top tokens,
# seed 42). It rebinds GPT2_CONFIG, SCALE_RANGE and gpt2_results, so running it
# replaces the in-memory results of the preceding GPT-2 cell. Its saved files
# use the same naming scheme as that cell and differ only by timestamp.

# Load GPT-2 model
gpt2_model, gpt2_tokenizer = get_gpt2_small(DEVICE)

GPT2_CONFIG = {
    'model_name': 'gpt2',
    'layer_types': ['att', 'res_mid', 'mlp', 'res_post'],
    'layer_range': range(0, 12),  # Layers 0-11
    'model': gpt2_model,
    'tokenizer': gpt2_tokenizer,
    'n_target_features': 3,
    'n_interference_features': 3,
    'n_top_tokens': 3,
    'n_test_sentences': 3,
    'seed': 42,
    'device': DEVICE,
    'use_auto_model': True
}

semantic_thresholds = [0.4, 0.3, 0.2, 0.15]

scale_values = [0.5, 1, 1.5, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 14, 17, 20]
SCALE_RANGE = [-x for x in scale_values[::-1]] + scale_values

print(f"Starting GPT-2 experiments with semantic thresholds: {semantic_thresholds}")
print(f"Layer types: {GPT2_CONFIG['layer_types']}")
print(f"Layer range: {list(GPT2_CONFIG['layer_range'])}")
print(f"Scale range: {len(SCALE_RANGE)} values from {min(SCALE_RANGE)} to {max(SCALE_RANGE)}")

gpt2_results = {}
# Thresholds whose experiment raised before producing any results.
failed_thresholds = []

for threshold in semantic_thresholds:
    print(f"\n{'='*60}")
    print(f"Running GPT-2 experiment with similarity threshold: {threshold}")
    print(f"{'='*60}")

    config = GPT2_CONFIG.copy()
    config['similarity_threshold'] = threshold

    start_time = time.time()

    try:
        results = run_sae_intervention_on_small_models(
            scale_range=SCALE_RANGE,
            **config
        )

        # Stored before summarizing/saving so a later failure does not lose the run.
        gpt2_results[threshold] = results

        total_features = 0
        total_tests = 0
        layer_stats = {}

        for layer_type, layers in results.items():
            layer_stats[layer_type] = {}
            for layer_idx, features in layers.items():
                layer_stats[layer_type][layer_idx] = len(features)
                total_features += len(features)
                for feature_data in features.values():
                    total_tests += len(feature_data.get('tests', {}))

        elapsed_time = time.time() - start_time

        print(f"\nCompleted threshold {threshold} in {elapsed_time:.2f} seconds")
        print(f"- Total features tested: {total_features}")
        print(f"- Total token tests: {total_tests}")

        timestamp = str(int(time.time()))
        filepath = save_experiment_results(results, 'gpt2', f'_{threshold}_{timestamp}')

        print(f"- Results saved to: {filepath}")

    except Exception as e:
        print(f"Error running experiment with threshold {threshold}: {e}")
        traceback.print_exc()
        if threshold not in gpt2_results:
            failed_thresholds.append(threshold)

print(f"\n{'='*60}")
print("All GPT-2 experiments completed!")
print(f"Successfully completed thresholds: {list(gpt2_results.keys())}")
print(f"Failed thresholds: {failed_thresholds}")

print("\nOverall GPT-2 Results Summary:")
for threshold, results in gpt2_results.items():
    total_features = sum(len(results[lt][li]) for lt in results for li in results[lt])
    total_tests = sum(len(results[lt][li][fid].get('tests', {}))
                      for lt in results for li in results[lt] for fid in results[lt][li])

    print(f"Threshold {threshold}:")
    print(f"  - Features: {total_features}")
    print(f"  - Tests: {total_tests}")
    print(f"  - Layer types: {list(results.keys())}")

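## Plot absolute improvement

Mean absolute gain of each metric over its baseline, per interference level (only cases that improved are averaged).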

In [None]:
import os
import json
import matplotlib.pyplot as plt
import numpy as np
from collections import defaultdict
import glob

def plot_sae_intervention_results(data_path, output_dir="analysis_results", model_name="Model"):
    """
    Plot the mean *absolute* improvement over baseline for each interference level.

    For every sentence-level entry in the results JSON
    (``results[layer_type][layer][feature]['tests'][token]['sentences'][sentence]``)
    and every level other than ``'baseline'``, the improvement of each metric is
    ``level['best_<metric>'] - baseline['<metric>']`` (missing values count as 0).
    Only strictly positive improvements are recorded, so the bars show the mean
    improvement *among improved cases*, not over all cases.

    Note: the relative-improvement plotting cell of this notebook defines a
    function with this same name; whichever cell ran last is the one in scope.

    Args:
        data_path: Path to a results JSON written by ``save_experiment_results``,
            or a glob pattern containing '*', in which case the most recently
            modified matching file is used.
        output_dir: Directory for the PNG/PDF figure and the statistics JSON
            (created if missing).
        model_name: Used in the figure title and in the output filenames.

    Returns:
        dict: metric -> level -> {'mean', 'std', 'sem', 'count', 'min', 'max', 'median'}.

    Raises:
        FileNotFoundError: If the pattern matches no file or the path does not exist.
    """

    os.makedirs(output_dir, exist_ok=True)

    if '*' in data_path:
        files = glob.glob(data_path)
        if not files:
            raise FileNotFoundError(f"No files found matching pattern: {data_path}")
        # Choose by modification time: filenames embed the threshold before the
        # timestamp, so the lexicographically last name is not the newest run.
        data_path = max(files, key=os.path.getmtime)
        print(f"Using latest file: {data_path}")

    # Results are written as UTF-8 with ensure_ascii=False; read them the same
    # way regardless of the platform's default encoding.
    with open(data_path, 'r', encoding='utf-8') as f:
        results = json.load(f)

    print(f"Loaded results from: {data_path}")

    metrics_data = {
        'weighted_cosine_similarity': defaultdict(list),
        'spearman_correlation': defaultdict(list),
        'kendall_correlation': defaultdict(list),
        'weighted_overlap': defaultdict(list)
    }

    metric_names = {
        'weighted_cosine_similarity': 'Weighted Cosine Similarity',
        'spearman_correlation': 'Spearman Correlation',
        'kendall_correlation': 'Kendall Correlation',
        'weighted_overlap': 'Weighted Overlap'
    }

    print("Processing results...")

    total_processed = 0  # number of sentence-level entries visited
    for layer_type in results:
        for layer_index in results[layer_type]:
            for feature_id in results[layer_type][layer_index]:
                feature_data = results[layer_type][layer_index][feature_id]

                for token_idx in feature_data.get('tests', {}):
                    token_data = feature_data['tests'][token_idx]

                    for sent_idx in token_data.get('sentences', {}):
                        sent_data = token_data['sentences'][sent_idx]

                        if 'baseline' not in sent_data:
                            continue

                        baseline = sent_data['baseline']

                        # Every key other than 'baseline' is an interference level.
                        for level in sent_data:
                            if level == 'baseline':
                                continue

                            level_data = sent_data[level]

                            for metric in metrics_data.keys():
                                baseline_val = baseline.get(metric, 0)
                                best_val = level_data.get(f'best_{metric}', 0)

                                improvement = best_val - baseline_val

                                # Only improved cases are averaged (see docstring).
                                if improvement > 0:
                                    metrics_data[metric][level].append(improvement)

                        total_processed += 1

    print(f"Processed {total_processed} test cases")

    all_levels = set()
    for metric in metrics_data:
        all_levels.update(metrics_data[metric].keys())

    # Bar order: 'self' first, then similarity ranges from high to low, then
    # 'rand', then anything else alphabetically.
    level_order = ['self']

    range_levels = [l for l in all_levels if '-' in l and l != 'self']
    range_levels.sort(key=lambda x: float(x.split('-')[0]) if x.split('-')[0].replace('.','').isdigit() else 0, reverse=True)
    level_order.extend(range_levels)

    if 'rand' in all_levels:
        level_order.append('rand')

    other_levels = [l for l in all_levels if l not in level_order]
    level_order.extend(sorted(other_levels))

    print(f"Interference levels found: {level_order}")

    level_labels = {}
    for level in level_order:
        if level == 'self':
            level_labels[level] = 'Self'
        elif level == 'rand':
            level_labels[level] = 'Random'
        elif '-' in level and level.count('-') == 1:
            try:
                parts = level.split('-')
                min_val = float(parts[0])
                if min_val >= 0.4:
                    level_labels[level] = f"High\n({level})"
                elif min_val >= 0.3:
                    level_labels[level] = f"H-Mid\n({level})"
                elif min_val >= 0.2:
                    level_labels[level] = f"Mid\n({level})"
                elif min_val >= 0.1:
                    level_labels[level] = f"L-Mid\n({level})"
                else:
                    level_labels[level] = f"Low\n({level})"
            except ValueError:
                # Lower bound is not numeric; fall back to the raw level name.
                level_labels[level] = level.title()
        else:
            level_labels[level] = level.title()

    fig, axs = plt.subplots(2, 2, figsize=(15, 12), dpi=300)
    axs = axs.flatten()

    colors = plt.cm.Blues(np.linspace(0.3, 0.9, len(level_order)))

    metrics = ['weighted_cosine_similarity', 'spearman_correlation', 'kendall_correlation', 'weighted_overlap']

    for i, metric in enumerate(metrics):
        ax = axs[i]

        means = []
        sems = []
        labels = []

        for level in level_order:
            if level in metrics_data[metric] and len(metrics_data[metric][level]) > 0:
                values = metrics_data[metric][level]
                means.append(np.mean(values))
                sems.append(np.std(values) / np.sqrt(len(values)))
                labels.append(level_labels[level])
            else:
                # Keep a zero-height placeholder so every panel shares the same x axis.
                means.append(0)
                sems.append(0)
                labels.append(level_labels.get(level, level))

        x_positions = np.arange(len(means))

        bars = ax.bar(x_positions, means, width=0.6, color=colors,
                      yerr=sems, capsize=3,
                      error_kw={'ecolor': 'black', 'linewidth': 1, 'capthick': 1})

        for idx, (bar, mean_val, sem_val) in enumerate(zip(bars, means, sems)):
            if mean_val > 0:
                height = bar.get_height()
                ax.text(bar.get_x() + bar.get_width()/2., height + sem_val + max(means)*0.01,
                        f'{mean_val:.3f}', ha='center', va='bottom', fontsize=9, fontweight='bold')

        ax.set_title(metric_names[metric], fontsize=14, pad=15, fontweight='bold')
        ax.set_xlabel('Interference Level', fontsize=12)
        ax.set_ylabel('Mean Improvement', fontsize=12)
        ax.set_xticks(x_positions)
        ax.set_xticklabels(labels, rotation=45, ha='right', fontsize=10)
        ax.grid(axis='y', linestyle='-', alpha=0.3)

        ax.set_ylim(bottom=0)

        ax.tick_params(axis='both', which='major', labelsize=10)

        total_points = sum(len(metrics_data[metric][level]) for level in level_order
                           if level in metrics_data[metric])
        ax.text(0.02, 0.98, f'n={total_points}', transform=ax.transAxes,
                fontsize=9, verticalalignment='top',
                bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

    plt.suptitle(f'{model_name} SAE Intervention Results', fontsize=16, y=0.95, fontweight='bold')
    plt.tight_layout()

    model_lower = model_name.lower().replace(' ', '_')
    png_path = os.path.join(output_dir, f"{model_lower}_sae_intervention_results.png")
    pdf_path = os.path.join(output_dir, f"{model_lower}_sae_intervention_results.pdf")

    plt.savefig(png_path, bbox_inches='tight', dpi=300)
    plt.savefig(pdf_path, bbox_inches='tight')
    plt.show()

    statistics = {}
    for metric in metrics_data:
        statistics[metric] = {}
        metric_data = metrics_data[metric]

        for level in metric_data:
            if len(metric_data[level]) > 0:
                values = metric_data[level]
                statistics[metric][level] = {
                    'mean': float(np.mean(values)),
                    'std': float(np.std(values)),
                    'sem': float(np.std(values) / np.sqrt(len(values))),
                    'count': len(values),
                    'min': float(np.min(values)),
                    'max': float(np.max(values)),
                    'median': float(np.median(values))
                }
            else:
                statistics[metric][level] = {
                    'mean': 0, 'std': 0, 'sem': 0, 'count': 0,
                    'min': 0, 'max': 0, 'median': 0
                }

    stats_path = os.path.join(output_dir, f"{model_lower}_statistics.json")
    with open(stats_path, 'w', encoding='utf-8') as f:
        json.dump(statistics, f, indent=2)

    print("\nResults saved:")
    print(f"- Plot (PNG): {png_path}")
    print(f"- Plot (PDF): {pdf_path}")
    print(f"- Statistics: {stats_path}")

    return statistics


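## Plot relative improvement

Mean percentage gain of each metric over its baseline, per interference level. Note: this cell redefines `plot_sae_intervention_results`, so the visualization cell below uses whichever of the two plotting cells ran last.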

In [None]:
import os
import json
import matplotlib.pyplot as plt
import numpy as np
from collections import defaultdict
import glob
import matplotlib.font_manager as fm
import matplotlib as mpl

def plot_sae_intervention_results(data_path, output_dir="analysis_results", model_name="Model"):
    """
    Plot the mean *relative* (percentage) improvement over baseline per interference level.

    For every sentence-level entry in the results JSON and every level other than
    ``'baseline'``, the improvement of each metric is
    ``(best - baseline) / |baseline| * 100`` with ``best = level['best_<metric>']``
    (missing values count as 0). A case is recorded only when ``best > baseline``
    and ``|baseline| > 1e-8`` (a relative change is undefined for a zero baseline),
    so the bars average over improved cases only.

    Side effect: if ``./AvenirLTStd-Roman.otf`` exists, it is registered with
    matplotlib and set as the global ``font.family``.

    Note: the absolute-improvement plotting cell of this notebook defines a
    function with this same name; whichever cell ran last is the one in scope.

    Args:
        data_path: Path to a results JSON written by ``save_experiment_results``,
            or a glob pattern containing '*', in which case the most recently
            modified matching file is used.
        output_dir: Directory for the PNG/PDF figure and the statistics JSON
            (created if missing).
        model_name: Used as the figure title and in the output filenames.

    Returns:
        dict: metric -> level -> percentage statistics, plus a ``'summary'`` entry.

    Raises:
        FileNotFoundError: If the pattern matches no file or the path does not exist.
    """

    font_path = "./AvenirLTStd-Roman.otf"
    try:
        if os.path.exists(font_path):
            avenir_font = fm.FontProperties(fname=font_path)
            fm.fontManager.addfont(font_path)
            mpl.rcParams['font.family'] = avenir_font.get_name()
    except Exception as e:
        # Best-effort styling: keep matplotlib's default font, but say why.
        print(f"Could not load custom font {font_path}: {e}")

    os.makedirs(output_dir, exist_ok=True)

    if '*' in data_path:
        files = glob.glob(data_path)
        if not files:
            raise FileNotFoundError(f"No files found matching pattern: {data_path}")
        # Choose by modification time: filenames embed the threshold before the
        # timestamp, so the lexicographically last name is not the newest run.
        data_path = max(files, key=os.path.getmtime)
        print(f"Using latest file: {data_path}")

    # Results are written as UTF-8 with ensure_ascii=False; read them the same
    # way regardless of the platform's default encoding.
    with open(data_path, 'r', encoding='utf-8') as f:
        results = json.load(f)

    print(f"Loaded results from: {data_path}")

    metrics_data = {
        'weighted_cosine_similarity': defaultdict(list),
        'spearman_correlation': defaultdict(list),
        'kendall_correlation': defaultdict(list),
        'weighted_overlap': defaultdict(list)
    }

    metric_names = {
        'weighted_cosine_similarity': 'Weighted Cosine Similarity',
        'spearman_correlation': 'Spearman Correlation',
        'kendall_correlation': 'Kendall Correlation',
        'weighted_overlap': 'Weighted Overlap'
    }

    print("Processing results...")

    total_processed = 0  # number of sentence-level entries visited
    # Incremented once per recorded (metric, level) improvement, so it can
    # exceed total_processed.
    total_improved = 0
    improvement_counts = {metric: defaultdict(int) for metric in metrics_data.keys()}

    for layer_type in results:
        for layer_index in results[layer_type]:
            for feature_id in results[layer_type][layer_index]:
                feature_data = results[layer_type][layer_index][feature_id]

                for token_idx in feature_data.get('tests', {}):
                    token_data = feature_data['tests'][token_idx]

                    for sent_idx in token_data.get('sentences', {}):
                        sent_data = token_data['sentences'][sent_idx]

                        if 'baseline' not in sent_data:
                            continue

                        baseline = sent_data['baseline']

                        # Every key other than 'baseline' is an interference level.
                        for level in sent_data:
                            if level == 'baseline':
                                continue

                            level_data = sent_data[level]

                            for metric in metrics_data.keys():
                                baseline_val = baseline.get(metric, 0)
                                best_val = level_data.get(f'best_{metric}', 0)

                                # Same rule for all four metrics: record only strict
                                # improvements over a non-(near-)zero baseline.
                                if not (best_val > baseline_val and abs(baseline_val) > 1e-8):
                                    continue

                                percentage_improvement = (best_val - baseline_val) / abs(baseline_val) * 100

                                if percentage_improvement > 0:
                                    metrics_data[metric][level].append(percentage_improvement)
                                    improvement_counts[metric][level] += 1
                                    total_improved += 1

                        total_processed += 1

    print(f"Processed {total_processed} test cases")
    print(f"Total improvements recorded: {total_improved}")

    print("\nImprovement counts per metric and level:")
    for metric in metrics_data.keys():
        print(f"\n{metric_names[metric]}:")
        for level in sorted(improvement_counts[metric].keys()):
            count = improvement_counts[metric][level]
            print(f"  {level:15s}: {count:3d} improvements")

    all_levels = set()
    for metric in metrics_data:
        all_levels.update(metrics_data[metric].keys())

    # Bar order: 'self' first, then similarity ranges from high to low, then
    # 'rand', then anything else alphabetically.
    level_order = ['self']

    range_levels = [l for l in all_levels if '-' in l and l != 'self']
    range_levels.sort(key=lambda x: float(x.split('-')[0]) if x.split('-')[0].replace('.','').isdigit() else 0, reverse=True)
    level_order.extend(range_levels)

    if 'rand' in all_levels:
        level_order.append('rand')

    other_levels = [l for l in all_levels if l not in level_order]
    level_order.extend(sorted(other_levels))

    print(f"\nInterference levels found: {level_order}")

    level_labels = {}
    for level in level_order:
        if level == 'self':
            level_labels[level] = 'Original'
        elif level == 'rand':
            level_labels[level] = 'Random'
        elif '-' in level and level.count('-') == 1:
            try:
                parts = level.split('-')
                min_val = float(parts[0])
                if min_val >= 0.4:
                    level_labels[level] = "High"
                elif min_val >= 0.3:
                    level_labels[level] = "Mid High"
                elif min_val >= 0.2:
                    level_labels[level] = "Mid"
                elif min_val >= 0.1:
                    level_labels[level] = "Mid Low"
                else:
                    level_labels[level] = "Low"
            except ValueError:
                # Lower bound is not numeric; fall back to the raw level name.
                level_labels[level] = level.title()
        else:
            level_labels[level] = level.title()

    fig, axs = plt.subplots(1, 4, figsize=(17, 4.5), dpi=300)
    axs = axs.flatten()

    # Dark-to-light blues; matplotlib cycles them if there are more bars.
    colors = ['#0D47A1', '#1565C0', '#1976D2', '#1E88E5', '#2196F3', '#42A5F5', '#64B5F6']

    metrics = ['weighted_cosine_similarity', 'spearman_correlation', 'kendall_correlation', 'weighted_overlap']
    titles = ['Weighted Cosine Similarity', 'Spearman Correlation', 'Kendall Correlation', 'Weighted Overlap']
    y_labels = ['Mean % Improvement'] * len(metrics)

    for i, (metric, title, y_label) in enumerate(zip(metrics, titles, y_labels)):
        ax = axs[i]

        means = []
        sems = []
        labels = []
        counts = []

        for level in level_order:
            if level in metrics_data[metric] and len(metrics_data[metric][level]) > 0:
                values = metrics_data[metric][level]
                means.append(np.mean(values))
                sems.append(np.std(values) / np.sqrt(len(values)))
                labels.append(level_labels[level])
                counts.append(len(values))
            else:
                # Keep a zero-height placeholder so every panel shares the same x axis.
                means.append(0)
                sems.append(0)
                labels.append(level_labels.get(level, level))
                counts.append(0)

        x_positions = np.arange(len(means)) + 1

        bars = ax.bar(x_positions, means, width=0.6, color=colors[:len(means)],
                      yerr=sems, capsize=5,
                      error_kw={'ecolor': 'black', 'linewidth': 1, 'capthick': 1})

        for idx, bar in enumerate(bars):
            height = bar.get_height()
            if height > 0:
                ax.text(bar.get_x() + bar.get_width()/2., height + sems[idx] + max(means)*0.01,
                        f'{height:.1f}%', ha='center', va='bottom', fontsize=8)

        ax.set_title(title, fontsize=12, pad=10)
        ax.set_xlabel('Interference Level', fontsize=10)
        ax.set_ylabel(y_label, fontsize=10)
        ax.set_xticks(x_positions)
        ax.set_xticklabels(labels, rotation=45)
        ax.grid(axis='y', linestyle='-', alpha=0.2)

        ax.set_ylim(bottom=0)

        ax.tick_params(axis='both', which='major', labelsize=8)

    plt.suptitle(model_name, fontsize=14, y=1.05)
    plt.tight_layout()

    model_lower = model_name.lower().replace(' ', '_')
    png_path = os.path.join(output_dir, f"{model_lower}_sae_intervention_percentage.png")
    pdf_path = os.path.join(output_dir, f"{model_lower}_sae_intervention_percentage.pdf")

    plt.savefig(png_path, bbox_inches='tight', dpi=300)
    plt.savefig(pdf_path, bbox_inches='tight')
    plt.show()

    statistics = {}
    for metric in metrics_data:
        statistics[metric] = {}
        metric_data = metrics_data[metric]

        for level in metric_data:
            if len(metric_data[level]) > 0:
                values = metric_data[level]
                statistics[metric][level] = {
                    'mean_percentage': float(np.mean(values)),
                    'std_percentage': float(np.std(values)),
                    'sem_percentage': float(np.std(values) / np.sqrt(len(values))),
                    'count': len(values),
                    'min_percentage': float(np.min(values)),
                    'max_percentage': float(np.max(values)),
                    'median_percentage': float(np.median(values))
                }
            else:
                statistics[metric][level] = {
                    'mean_percentage': 0, 'std_percentage': 0, 'sem_percentage': 0, 'count': 0,
                    'min_percentage': 0, 'max_percentage': 0, 'median_percentage': 0
                }

    statistics['summary'] = {
        'total_test_cases': total_processed,
        'total_improvements': total_improved,
        # NOTE: improvements summed over all metrics and levels per test case, so
        # this "rate" is really improvements per test case and can exceed 1.
        'improvement_rate': float(total_improved / total_processed) if total_processed > 0 else 0,
        'improvements_per_metric': {metric: sum(len(metrics_data[metric][level]) for level in metrics_data[metric])
                                    for metric in metrics_data}
    }

    stats_path = os.path.join(output_dir, f"{model_lower}_percentage_statistics.json")
    with open(stats_path, 'w', encoding='utf-8') as f:
        json.dump(statistics, f, indent=2)

    print("\nResults saved:")
    print(f"- Plot (PNG): {png_path}")
    print(f"- Plot (PDF): {pdf_path}")
    print(f"- Statistics: {stats_path}")

    return statistics

## Simple visualization of the Pythia results

In [None]:
# Pinned to one saved run. If it is not present locally, fall back to the newest
# Pythia run at the same similarity threshold; the plotting function resolves
# '*' patterns itself.
data_file = "results/pythia_gpt2/pythia_sae_intervention_results_0.2_1756957339.json"
data_file_pattern = "results/pythia_gpt2/pythia_sae_intervention_results_0.2_*.json"
try:
    stats = plot_sae_intervention_results(
        data_path=data_file,
        output_dir="analysis_results",
        model_name="Pythia"
    )
    print("Plotting completed successfully!")
except FileNotFoundError:
    print(f"File not found: {data_file}")
    print(f"Falling back to wildcard pattern: {data_file_pattern}")
    try:
        stats = plot_sae_intervention_results(
            data_path=data_file_pattern,
            output_dir="analysis_results",
            model_name="Pythia"
        )
        print("Plotting completed successfully!")
    except FileNotFoundError as err:
        print(f"Fallback failed: {err}")
        print("Please check the file path or run the Pythia experiment cells first.")