In [None]:
import os
import sys
import json
import torch
import numpy as np
from tqdm.notebook import tqdm
from collections import defaultdict
# Make project-local packages such as `utils` (imported below) resolvable from the
# notebook's working directory; only append the entry once.
current_dir = os.getcwd()
if all(entry != current_dir for entry in sys.path):
    sys.path.append(current_dir)
from utils.utils_exp import run_gradient_intervention_on_small_models
from utils.utils_model import get_hooked_pythia_70m, get_hooked_gpt2_small

# Prefer the GPU when one is visible to torch.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
print(f"Using device: {DEVICE}")

# Fix both RNGs so repeated runs of the notebook are reproducible.
np.random.seed(42)
torch.manual_seed(42)

In [2]:
results_dir = "results/gradient_pythia_gpt2"
os.makedirs(results_dir, exist_ok=True)


def _to_json_serializable(obj):
    """Recursively convert tensors / numpy values into JSON-encodable Python objects."""
    if isinstance(obj, torch.Tensor):
        # detach(): .numpy() raises on tensors that require grad. tolist() also copes
        # with dtypes numpy cannot represent (e.g. bfloat16).
        return obj.detach().cpu().tolist()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        # Any numpy scalar (float16/32/64, int8..int64, bool_, ...) -> Python scalar.
        return obj.item()
    if isinstance(obj, dict):
        # json only accepts str/int/float/bool/None keys, so unwrap numpy scalar keys too.
        return {
            (key.item() if isinstance(key, np.generic) else key): _to_json_serializable(value)
            for key, value in obj.items()
        }
    if isinstance(obj, (list, tuple, set)):
        return [_to_json_serializable(item) for item in obj]
    return obj


def save_experiment_results(results, model_name, suffix="", output_dir=None):
    """Serialize experiment results to a UTF-8 JSON file and return its path.

    Args:
        results: Nested dict/list structure. It may contain torch tensors, numpy
            arrays and numpy scalars; these are converted to plain Python values.
        model_name: Prefix of the file name (e.g. 'pythia', 'gpt2').
        suffix: Appended to the file stem before '.json' (e.g. '_0.2_<timestamp>').
        output_dir: Target directory. Defaults to the module-level `results_dir`
            and is created if missing.

    Returns:
        '<output_dir>/<model_name>_gradient_intervention_results<suffix>.json'
    """
    target_dir = results_dir if output_dir is None else output_dir
    os.makedirs(target_dir, exist_ok=True)

    filename = f"{model_name}_gradient_intervention_results{suffix}.json"
    filepath = os.path.join(target_dir, filename)

    serializable_results = _to_json_serializable(results)

    with open(filepath, 'w', encoding='utf-8') as f:
        # ensure_ascii=False keeps non-ASCII token strings readable in the file.
        json.dump(serializable_results, f, indent=2, ensure_ascii=False)

    return filepath

In [None]:
# Load Pythia-70M as a HookedTransformer and assemble the keyword arguments for
# run_gradient_intervention_on_small_models.
pythia_model = get_hooked_pythia_70m(DEVICE)
PYTHIA_CONFIG = dict(
    model_name='pythia',
    layer_types=['att', 'mlp'],
    layer_range=range(6),
    model=pythia_model,
    tokenizer=None,
    similarity_threshold=0.2,
    n_target_features=4,
    n_interference_features=3,
    n_top_tokens=3,
    n_test_sentences=3,
    seed=50,
    device=DEVICE,
    use_hooked_model=True,
)

# Symmetric intervention scales: negated magnitudes (largest first), then the positives.
scale_values = [0.5, 1, 1.5, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 14, 17, 20]
SCALE_RANGE = [-x for x in reversed(scale_values)] + scale_values

print("Pythia Gradient Intervention Configuration:")
for label, value in (
    ("Model", PYTHIA_CONFIG['model_name']),
    ("Layer types", PYTHIA_CONFIG['layer_types']),
    ("Layer range", list(PYTHIA_CONFIG['layer_range'])),
    ("Similarity threshold", PYTHIA_CONFIG['similarity_threshold']),
    ("Target features", PYTHIA_CONFIG['n_target_features']),
    ("Interference features", PYTHIA_CONFIG['n_interference_features']),
    ("Top tokens", PYTHIA_CONFIG['n_top_tokens']),
    ("Test sentences", PYTHIA_CONFIG['n_test_sentences']),
    ("Scale range", f"{len(SCALE_RANGE)} values from {min(SCALE_RANGE)} to {max(SCALE_RANGE)}"),
    ("Using HookedTransformer", PYTHIA_CONFIG['use_hooked_model']),
):
    print(f"- {label}: {value}")

In [None]:
# Run the Pythia gradient intervention experiment and time it. On failure the
# traceback is printed and `pythia_results` falls back to an empty dict, so the
# save cell below reports "No Pythia results to save."
import time

print("Starting Pythia gradient intervention experiment...")
print("=" * 60)

start_time = time.time()
try:
    pythia_results = run_gradient_intervention_on_small_models(scale_range=SCALE_RANGE, **PYTHIA_CONFIG)
except Exception as e:
    print(f"Error running Pythia gradient intervention experiment: {e}")
    import traceback
    traceback.print_exc()
    pythia_results = {}
else:
    elapsed_time = time.time() - start_time
    print(f"\nPythia gradient intervention experiment completed in {elapsed_time:.2f} seconds")

In [None]:
# Save the Pythia run (if it produced anything) and print per-layer statistics.
if pythia_results:
    import time
    # File name carries the similarity threshold and a Unix timestamp.
    timestamp = str(int(time.time()))
    thres = PYTHIA_CONFIG['similarity_threshold']
    pythia_filepath = save_experiment_results(pythia_results, 'pythia', f'_{thres}_{timestamp}')

    # results[layer_type][layer_idx][feature_id]['tests'][token]['sentences'][sentence]
    total_features = 0
    total_tests = 0           # one per (feature, token): the size of each 'tests' dict
    total_sentence_tests = 0  # one per (feature, token, sentence)
    layer_stats = {}

    for layer_type, layers in pythia_results.items():
        layer_stats[layer_type] = {}
        for layer_idx, features in layers.items():
            layer_stats[layer_type][layer_idx] = len(features)
            total_features += len(features)

            for feature_data in features.values():
                token_tests = feature_data.get('tests', {})
                total_tests += len(token_tests)
                total_sentence_tests += sum(
                    len(token_data.get('sentences', {})) for token_data in token_tests.values()
                )

    print("\nPythia Gradient Intervention Results Summary:")
    print(f"- Total features tested: {total_features}")
    print(f"- Total token tests: {total_tests}")
    print(f"- Total sentence tests: {total_sentence_tests}")
    print(f"- Layer types: {list(pythia_results.keys())}")

    for layer_type, layers in layer_stats.items():
        print(f"\n{layer_type.upper()} layers:")
        for layer_idx, feature_count in layers.items():
            print(f"  Layer {layer_idx}: {feature_count} features")

    print(f"\nFull results saved to: {pythia_filepath}")
else:
    print("No Pythia results to save.")

In [None]:
# GPT-2 gradient intervention experiment: one run per similarity threshold, each
# saved to its own JSON file. Failed runs are recorded as {'error': <message>}.
import time
import traceback

# Load GPT-2 HookedTransformer model
gpt2_model = get_hooked_gpt2_small(DEVICE)

GPT2_CONFIG = {
    'model_name': 'gpt2',
    'layer_types': ['att', 'mlp'],
    'layer_range': range(0, 12),
    'model': gpt2_model,
    'tokenizer': None,
    'n_target_features': 4,
    'n_interference_features': 3,
    'n_top_tokens': 5,
    'n_test_sentences': 3,
    'seed': 53,
    'device': DEVICE,
    'use_hooked_model': True
}

# Each value is injected as config['similarity_threshold'] for one run.
semantic_thresholds = [0.4, 0.3, 0.2, 0.15]

# Symmetric intervention scales: negated magnitudes (largest first), then the positives.
scale_values = [0.5, 1, 1.5, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 14, 17, 20]
SCALE_RANGE = [-x for x in scale_values[::-1]] + scale_values


def _summarize_results(results):
    """Count features and per-token tests in one run's results.

    Expects results[layer_type][layer_idx][feature_id] -> feature dict with an
    optional 'tests' mapping (one entry per tested token).

    Returns:
        (total_features, total_tests, layer_stats), where layer_stats maps
        layer_type -> {layer_idx: number of features in that layer}.
    """
    layer_stats = {
        layer_type: {layer_idx: len(features) for layer_idx, features in layers.items()}
        for layer_type, layers in results.items()
    }
    total_features = sum(sum(counts.values()) for counts in layer_stats.values())
    total_tests = sum(
        len(feature_data.get('tests', {}))
        for layers in results.values()
        for features in layers.values()
        for feature_data in features.values()
    )
    return total_features, total_tests, layer_stats


print(f"Starting GPT-2 gradient intervention experiments with semantic thresholds: {semantic_thresholds}")
print("Configuration:")
print(f"- Layer types: {GPT2_CONFIG['layer_types']}")
print(f"- Layer range: {list(GPT2_CONFIG['layer_range'])}")
print(f"- Target features: {GPT2_CONFIG['n_target_features']}")
print(f"- Interference features: {GPT2_CONFIG['n_interference_features']}")
print(f"- Top tokens: {GPT2_CONFIG['n_top_tokens']}")
print(f"- Test sentences: {GPT2_CONFIG['n_test_sentences']}")
print(f"- Scale range: {len(SCALE_RANGE)} values from {min(SCALE_RANGE)} to {max(SCALE_RANGE)}")
print(f"- Using HookedTransformer: {GPT2_CONFIG['use_hooked_model']}")

# run experiments for each threshold
gpt2_results = {}
total_start_time = time.time()

for i, threshold in enumerate(semantic_thresholds):
    print(f"\n{'='*80}")
    print(f"Running GPT-2 gradient intervention experiment {i+1}/{len(semantic_thresholds)} with similarity threshold: {threshold}")
    print(f"{'='*80}")

    config = GPT2_CONFIG.copy()
    config['similarity_threshold'] = threshold

    start_time = time.time()

    try:
        results = run_gradient_intervention_on_small_models(
            scale_range=SCALE_RANGE,
            **config
        )
    except Exception as e:
        print(f"Error running experiment with threshold {threshold}: {e}")
        traceback.print_exc()
        gpt2_results[threshold] = {'error': str(e)}
        continue

    # Keep the finished run before any post-processing: a failure while saving or
    # summarizing must not replace hours of computed results with an error record.
    gpt2_results[threshold] = results

    try:
        # Save before summarizing so malformed stats cannot prevent persistence.
        timestamp = str(int(time.time()))
        filepath = save_experiment_results(results, 'gpt2', f'_{threshold}_{timestamp}')

        total_features, total_tests, layer_stats = _summarize_results(results)

        elapsed_time = time.time() - start_time
        total_elapsed = time.time() - total_start_time

        print(f"\nCompleted threshold {threshold} in {elapsed_time:.2f} seconds")
        print(f"Total elapsed time: {total_elapsed:.2f} seconds")
        print(f"Statistics for threshold {threshold}:")
        print(f"- Total features tested: {total_features}")
        print(f"- Total token tests: {total_tests}")
        print(f"- Layer types: {list(results.keys())}")

        for layer_type, layers in layer_stats.items():
            layer_feature_count = sum(layers.values())
            print(f"  {layer_type.upper()}: {layer_feature_count} features across {len(layers)} layers")

        print(f"- Results saved to: {filepath}")

        remaining_thresholds = len(semantic_thresholds) - (i + 1)
        if remaining_thresholds > 0:
            avg_time_per_threshold = total_elapsed / (i + 1)
            estimated_remaining = avg_time_per_threshold * remaining_thresholds
            print(f"- Estimated remaining time: {estimated_remaining:.1f} seconds ({estimated_remaining/60:.1f} minutes)")
    except Exception as e:
        print(f"Error summarizing or saving results for threshold {threshold}: {e}")
        traceback.print_exc()

total_elapsed_time = time.time() - total_start_time

print(f"\n{'='*80}")
print("All GPT-2 gradient intervention experiments completed!")
print(f"Total execution time: {total_elapsed_time:.2f} seconds ({total_elapsed_time/60:.1f} minutes)")
print(f"Successfully completed thresholds: {[t for t in gpt2_results.keys() if 'error' not in gpt2_results[t]]}")
print(f"Failed thresholds: {[t for t in gpt2_results.keys() if 'error' in gpt2_results[t]]}")

print("\nOverall GPT-2 Gradient Intervention Results Summary:")
for threshold in semantic_thresholds:
    if threshold in gpt2_results and 'error' not in gpt2_results[threshold]:
        results = gpt2_results[threshold]
        total_features, total_tests, layer_stats = _summarize_results(results)

        print(f"Threshold {threshold}:")
        print(f"  - Features: {total_features}")
        print(f"  - Tests: {total_tests}")
        print(f"  - Layer types: {list(results.keys())}")
    else:
        print(f"Threshold {threshold}: FAILED")

print(f"\nAll result files are saved in the '{results_dir}' directory.")

In [None]:
import os
import json
import matplotlib.pyplot as plt
import numpy as np
from collections import defaultdict
import glob
import matplotlib.font_manager as fm
import matplotlib as mpl

# Metric keys (in plotting order) and their panel titles.
_METRIC_TITLES = {
    'weighted_cosine_similarity': 'Weighted Cosine Similarity',
    'spearman_correlation': 'Spearman Correlation',
    'kendall_correlation': 'Kendall Correlation',
    'weighted_overlap': 'Weighted Overlap',
}


def _resolve_results_path(data_path):
    """Return `data_path`, or the most recently modified match if it contains '*'."""
    if '*' not in data_path:
        return data_path
    files = glob.glob(data_path)
    if not files:
        raise FileNotFoundError(f"No files found matching pattern: {data_path}")
    # Choose by modification time: names embed '_<threshold>_<timestamp>', so lexicographic
    # order is not chronological when a pattern spans several thresholds.
    latest = max(files, key=os.path.getmtime)
    print(f"Using latest file: {latest}")
    return latest


def _collect_improvements(results, metrics):
    """Gather % improvements of each level's best metric value over the sentence baseline.

    Walks results[layer_type][layer_idx][feature_id]['tests'][token]['sentences'][sentence].
    Each sentence dict holds a 'baseline' entry plus one entry per interference level
    carrying 'best_<metric>' values. Only strict improvements over a baseline with
    |baseline| > 1e-8 are recorded, because a relative change is undefined near zero.

    Returns:
        (improvements, n_processed, n_improved):
        - improvements maps metric -> defaultdict(level -> list of % improvements).
        - n_processed counts sentences that have a baseline.
        - n_improved counts the recorded (metric, level) improvements.
    """
    improvements = {metric: defaultdict(list) for metric in metrics}
    n_processed = 0
    n_improved = 0

    for layers in results.values():
        for features in layers.values():
            for feature_data in features.values():
                for token_data in feature_data.get('tests', {}).values():
                    for sent_data in token_data.get('sentences', {}).values():
                        if 'baseline' not in sent_data:
                            continue
                        baseline = sent_data['baseline']

                        for level, level_data in sent_data.items():
                            if level == 'baseline':
                                continue
                            for metric in metrics:
                                baseline_val = baseline.get(metric, 0)
                                best_val = level_data.get(f'best_{metric}', 0)
                                if best_val <= baseline_val or abs(baseline_val) <= 1e-8:
                                    continue
                                pct = (best_val - baseline_val) / abs(baseline_val) * 100
                                if pct > 0:
                                    improvements[metric][level].append(pct)
                                    n_improved += 1

                        n_processed += 1

    return improvements, n_processed, n_improved


def _order_levels(all_levels):
    """Order levels as: 'self', ranges (highest lower bound first), 'rand', then the rest.

    'self' is always included, so it appears as a zero bar when it has no improvements.
    """
    def lower_bound(level):
        head = level.split('-')[0]
        if head.replace('.', '').isdigit():
            try:
                return float(head)
            except ValueError:  # e.g. '1.2.3' passes isdigit() after stripping dots
                pass
        return 0

    # Pre-sort by name so ties in lower_bound get a deterministic order (set order is not).
    range_levels = sorted(
        sorted(level for level in all_levels if '-' in level and level != 'self'),
        key=lower_bound,
        reverse=True,
    )
    order = ['self'] + range_levels
    if 'rand' in all_levels:
        order.append('rand')
    order.extend(sorted(level for level in all_levels if level not in order))
    return order


def _level_label(level):
    """Human-readable x-axis label for an interference level key."""
    if level == 'self':
        return 'Original'
    if level == 'rand':
        return 'Random'
    if level.count('-') == 1:
        try:
            low = float(level.split('-')[0])
        except ValueError:
            return level.title()
        for bound, label in ((0.4, "High"), (0.3, "Mid High"), (0.2, "Mid"), (0.1, "Mid Low")):
            if low >= bound:
                return label
        return "Low"
    return level.title()


def plot_gradient_intervention_results(data_path, output_dir="analysis_results", model_name="Model"):
    """Plot mean % improvement of best-vs-baseline metrics per interference level.

    Loads a results JSON, presumably one written by `save_experiment_results`. For every
    sentence test and interference level it collects the relative improvement of each
    `best_<metric>` over the sentence baseline. It then draws one bar panel per metric,
    showing the mean ± SEM over the improved cases only.

    Args:
        data_path: Path to a results JSON file. If it contains '*', it is treated as a
            glob pattern and the most recently modified match is used.
        output_dir: Directory for the PNG/PDF figures (created if missing).
        model_name: Used in the figure title and the output file names.

    Returns:
        dict mapping metric name -> defaultdict(level -> list of % improvements).

    Raises:
        FileNotFoundError: if a glob pattern matches nothing or the file does not exist.
    """
    os.makedirs(output_dir, exist_ok=True)

    data_path = _resolve_results_path(data_path)
    # The results are written as UTF-8 with ensure_ascii=False, so read them back as UTF-8
    # explicitly rather than with the platform default encoding.
    with open(data_path, 'r', encoding='utf-8') as f:
        results = json.load(f)
    print(f"Loaded results from: {data_path}")

    metrics = list(_METRIC_TITLES)

    print("Processing gradient intervention results...")
    metrics_data, total_processed, total_improved = _collect_improvements(results, metrics)
    print(f"Processed {total_processed} test cases")
    print(f"Total improvements recorded: {total_improved}")

    all_levels = set()
    for per_level in metrics_data.values():
        all_levels.update(per_level.keys())
    level_order = _order_levels(all_levels)
    print(f"Gradient interference levels found: {level_order}")
    labels = [_level_label(level) for level in level_order]

    fig, axs = plt.subplots(1, len(metrics), figsize=(17, 4.5), dpi=300)
    axs = np.atleast_1d(axs).ravel()

    colors = ['#B71C1C', '#C62828', '#D32F2F', '#E53935', '#F44336', '#EF5350', '#E57373']

    for ax, metric in zip(axs, metrics):
        per_level = metrics_data[metric]
        means, sems = [], []
        for level in level_order:
            # .get() avoids inserting empty keys into the returned defaultdict.
            values = per_level.get(level, [])
            if values:
                means.append(np.mean(values))
                sems.append(np.std(values) / np.sqrt(len(values)))
            else:
                means.append(0)
                sems.append(0)

        x_positions = np.arange(len(means)) + 1
        bars = ax.bar(x_positions, means, width=0.6, color=colors[:len(means)],
                      yerr=sems, capsize=5,
                      error_kw={'ecolor': 'black', 'linewidth': 1, 'capthick': 1})

        # Annotate non-empty bars just above their error bar.
        label_pad = max(means) * 0.01
        for bar, sem in zip(bars, sems):
            height = bar.get_height()
            if height > 0:
                ax.text(bar.get_x() + bar.get_width() / 2., height + sem + label_pad,
                        f'{height:.1f}%', ha='center', va='bottom', fontsize=8)

        ax.set_title(_METRIC_TITLES[metric], fontsize=12, pad=10)
        ax.set_xlabel('Gradient Interference Level', fontsize=10)
        ax.set_ylabel('Mean % Improvement', fontsize=10)
        ax.set_xticks(x_positions)
        ax.set_xticklabels(labels, rotation=45)
        ax.grid(axis='y', linestyle='-', alpha=0.2)
        ax.set_ylim(bottom=0)
        ax.tick_params(axis='both', which='major', labelsize=8)

    plt.suptitle(f'{model_name} Gradient Intervention', fontsize=14, y=1.05)
    plt.tight_layout()

    model_lower = model_name.lower().replace(' ', '_')
    png_path = os.path.join(output_dir, f"{model_lower}_gradient_intervention_results.png")
    pdf_path = os.path.join(output_dir, f"{model_lower}_gradient_intervention_results.pdf")

    plt.savefig(png_path, bbox_inches='tight', dpi=300)
    plt.savefig(pdf_path, bbox_inches='tight')
    plt.show()
    # Release the 300-dpi figure; repeated calls would otherwise keep it in memory.
    plt.close(fig)

    print("\nResults saved:")
    print(f"- Plot (PNG): {png_path}")
    print(f"- Plot (PDF): {pdf_path}")

    return metrics_data

In [None]:
# Plot the newest GPT-2 results saved for the 0.15 similarity threshold. The '*'
# lets plot_gradient_intervention_results pick the latest matching file, so no
# timestamp has to be pasted in by hand.
gpt2_plot_pattern = os.path.join(results_dir, "gpt2_gradient_intervention_results_0.15_*.json")
if glob.glob(gpt2_plot_pattern):
    print("Plotting GPT2-Small gradient intervention results...")
    plot_gradient_intervention_results(
        data_path=gpt2_plot_pattern,
        output_dir="analysis_results",
        model_name="GPT2-Small"
    )
else:
    print(f"No GPT-2 results found matching '{gpt2_plot_pattern}'; run the GPT-2 experiment cell first.")