In [None]:
# 1: Import packages (Updated)
import os
import json
import torch
import random
import numpy as np
import torch.nn.functional as F
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from collections import defaultdict
from pathlib import Path
import gc
from typing import Dict, List, Set, Tuple, Optional

# Import model utilities and sentences
import sys
sys.path.append(os.path.dirname(os.path.abspath('__file__')))
from utils.utils_model import get_hooked_pythia_70m, get_hooked_gpt2_small
from utils.utils_data import load_type_dicts, get_token_id_type, get_token_str_type
from corpus.type_sentences import SENTENCES_OF_TYPE

# Pick the compute device: CUDA when PyTorch reports it as available, otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device: {}".format(device))

In [None]:
# 2: Define utility functions
def inject_tokens(sentence: str, tokens: List[str]) -> str:
    """Prefix `sentence` with the given tokens wrapped in one pair of parentheses.

    The tokens are concatenated with no separator, so any spacing must already
    be part of the token strings themselves.
    """
    injected_prefix = "".join(tokens)
    return "({}){}".format(injected_prefix, sentence)

def get_next_token_probs(model, text: str, top_k: int = 10) -> List[Tuple[int, str, float]]:
    """Return the `top_k` most probable next tokens after `text`.

    The text is tokenized with `model.to_tokens`, moved to the notebook-level
    `device`, and run through the model with gradients disabled. The logits at
    the last position of the first batch row are turned into probabilities
    with a softmax over the vocabulary axis.

    Returns:
        A list of `(token_id, token_string, probability)` tuples in the order
        produced by `torch.topk` (highest probability first).
    """
    input_ids = model.to_tokens(text).to(device)
    with torch.no_grad():
        last_position_logits = model(input_ids)[0, -1, :]
    distribution = torch.softmax(last_position_logits, dim=-1)
    best = torch.topk(distribution, top_k)

    predictions = []
    for token_id, probability in zip(best.indices.tolist(), best.values.tolist()):
        predictions.append((token_id, model.to_string(token_id), probability))
    return predictions

def get_weighted_cosine_similarity(embeddings: torch.Tensor, 
                                 token_ids: List[int], 
                                 token_probs: List[float], 
                                 target_ids: List[int]) -> float:
    """Probability-weighted cosine similarity between predicted and target tokens.

    For every predicted token, the cosine similarity of its embedding to each
    target embedding is computed and the maximum over targets is kept. These
    per-token maxima are multiplied by the token's probability and summed.

    Args:
        embeddings: 2-D embedding matrix indexed by token id along dim 0.
        token_ids: ids of the predicted tokens (list or 1-D tensor).
        token_probs: probability of each predicted token, aligned with
            `token_ids` (list or 1-D tensor).
        target_ids: ids of the target tokens (list or 1-D tensor).

    Returns:
        The weighted similarity as a Python float, or 0.0 when either id
        collection is empty or None.
    """
    if token_ids is None or target_ids is None:
        return 0.0

    # Normalise inputs to index tensors. The previous `not token_ids` check
    # raised "Boolean value of Tensor ... is ambiguous" for multi-element
    # tensors, so tensor callers never got a similarity value.
    token_index = torch.as_tensor(token_ids, dtype=torch.long, device=embeddings.device).reshape(-1)
    target_index = torch.as_tensor(target_ids, dtype=torch.long, device=embeddings.device).reshape(-1)
    if token_index.numel() == 0 or target_index.numel() == 0:
        return 0.0

    # Unit-normalise rows so the matrix product below yields cosine similarities
    output_embeddings = F.normalize(embeddings[token_index], p=2, dim=1)
    target_embeddings = F.normalize(embeddings[target_index], p=2, dim=1)

    # (n_tokens, n_targets) similarity matrix; keep the best target per token
    similarity_matrix = torch.matmul(output_embeddings, target_embeddings.T)
    max_similarities = torch.max(similarity_matrix, dim=1)[0]

    # Weight each token's best similarity by its predicted probability
    token_probs_tensor = torch.as_tensor(
        token_probs, dtype=torch.float32, device=max_similarities.device
    ).reshape(-1)
    weighted_similarities = max_similarities * token_probs_tensor

    return torch.sum(weighted_similarities).item()

def get_weighted_overlap(token_ids: List[int], target_ids: List[int], token_probs: List[float]) -> float:
    """Sum the probabilities of predicted tokens that are also target tokens.

    `token_probs` is indexed by the position of each matching entry in
    `token_ids`; positions whose token is not in `target_ids` contribute nothing.
    """
    total = 0.0
    for position, token_id in enumerate(token_ids):
        if token_id not in target_ids:
            continue
        total = total + token_probs[position]
    return total

def get_best_rank(token_ids: List[int], target_ids: List[int]) -> Optional[int]:
    """Return the 0-based position of the first predicted token that is a target.

    Returns None when no entry of `token_ids` appears in `target_ids`.
    """
    matching_positions = (
        position
        for position, token_id in enumerate(token_ids)
        if token_id in target_ids
    )
    return next(matching_positions, None)

def clear_memory():
    """Run Python garbage collection and, when CUDA is available, empty PyTorch's CUDA cache."""
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()

In [None]:
# 3: Define parameters

# --- Interference thresholds ---
# A feature whose maximum interference value reaches the model-specific "high"
# threshold feeds the high-interference token set; one that only reaches the
# baseline threshold feeds the medium-interference set.
PYTHIA_HIGH_INTERFERENCE_THRESHOLD = 0.5
GPT2_HIGH_INTERFERENCE_THRESHOLD = 0.3
BASELINE_INTERFERENCE_THRESHOLD = 0.2

# --- Experiment parameters ---
N_SELECTED_TOKENS = 10   # cap on high-activation tokens taken from each feature
N_TRIALS = 100           # injection trials per sentence and token category
TOP_K = 30               # number of next-token predictions inspected
RANDOM_SEED = 43
N_INJECTION = 30         # tokens injected into the sentence per trial

# --- File paths ---
MED_DATA_DIR = Path("./med_data")

# Echo the main settings so they are recorded next to the cell output
print("\n".join([
    "Parameters:",
    f"  Pythia high interference threshold: {PYTHIA_HIGH_INTERFERENCE_THRESHOLD}",
    f"  GPT-2 high interference threshold: {GPT2_HIGH_INTERFERENCE_THRESHOLD}",
    f"  Baseline interference threshold: {BASELINE_INTERFERENCE_THRESHOLD}",
    f"  Number of trials per experiment: {N_TRIALS}",
    f"  Top-k predictions: {TOP_K}",
]))

Parameters:
  Pythia high interference threshold: 0.5
  GPT-2 high interference threshold: 0.3
  Baseline interference threshold: 0.2
  Number of trials per experiment: 100
  Top-k predictions: 30


In [None]:
# 4: Load models
# Both hooked models are created through the project helpers in
# utils.utils_model and are given the notebook-level `device`. Later cells
# refer to them through the global names `pythia` and `gpt2`.
print("Loading models...")

# Load Pythia
print("  Loading Pythia-70M...")
pythia = get_hooked_pythia_70m(device)
# Token-embedding matrix, detached so it carries no autograd history
pythia_embedding = pythia.embed.W_E.detach()
print(f"    Pythia loaded on {next(pythia.parameters()).device}")

# Load GPT-2
print("  Loading GPT-2 Small...")
gpt2 = get_hooked_gpt2_small(device)
# Token-embedding matrix, detached so it carries no autograd history
gpt2_embedding = gpt2.embed.W_E.detach()
print(f"    GPT-2 loaded on {next(gpt2.parameters()).device}")

# Release any temporary allocations made while loading
clear_memory()
print("Models loaded successfully!")

In [None]:
# 5: Define function to get interference token sets (Updated to remove intersection)
def get_interference_token_sets(
    model_name: str,
    token_type: str,
    high_threshold: float,
    baseline_threshold: float,
    med_data_dir: Path = MED_DATA_DIR
) -> Dict[str, List[str]]:
    """
    Build the token pools used for injection for one model and token type.

    Args:
        model_name: 'pythia' or 'gpt2'
        token_type: Target token type (e.g., 'location', 'person', 'emotion')
        high_threshold: A feature whose maximum interference value is at least
            this contributes its tokens to the high-interference pool
        baseline_threshold: A feature below `high_threshold` but at least this
            contributes its tokens to the medium-interference pool
        med_data_dir: Directory containing interference data files

    Returns:
        Dict with keys 'target', 'high_interference', 'medium_interference'
        and 'random', each mapping to a *sorted list* of token strings.
        Sorting matters: set iteration order for strings changes between
        Python processes, so unsorted lists would make seeded sampling from
        these pools non-reproducible across kernel restarts.

    Raises:
        FileNotFoundError: if the interference JSON file does not exist.
        ValueError: if `model_name` is neither 'pythia' nor 'gpt2'.
    """
    print(f"\nGetting interference token sets for {model_name} - {token_type}")
    
    # Load interference data
    interference_file = med_data_dir / f"{model_name}_{token_type}_tokens_r0.8_i0.2_s0.3.json"
    
    if not interference_file.exists():
        raise FileNotFoundError(f"Interference file not found: {interference_file}")
    
    print(f"  Loading: {interference_file}")
    with open(interference_file, 'r', encoding='utf-8') as f:
        interference_data = json.load(f)
    
    # Load token type dictionaries using utils_data (only the str -> type map is needed)
    print(f"  Loading token type dictionaries...")
    _, token_str_to_type_dict = load_type_dicts(model_name)
    print(f"    Loaded {len(token_str_to_type_dict)} token type mappings")
    
    # Get model for token operations
    if model_name == 'pythia':
        model = pythia
    elif model_name == 'gpt2':
        model = gpt2
    else:
        raise ValueError(f"Unknown model: {model_name}")
    
    # Get target tokens from vocabulary (all tokens of target type)
    print(f"  Getting target tokens from vocabulary...")
    target_tokens = {
        token_str
        for token_str, token_type_label in token_str_to_type_dict.items()
        if token_type_label == token_type
    }
    print(f"    Found {len(target_tokens)} target tokens in vocabulary")
    
    # Extract interference tokens from the interference features.
    # Only 'interference_features' entries are used; 'target_features' entries
    # in the data file are intentionally not consulted here.
    print(f"  Processing interference data...")
    high_interference_tokens = set()
    medium_interference_tokens = set()
    
    for layer_type, layers in interference_data.items():
        # 'summary' and any non-dict section carry no per-layer feature data
        if layer_type == 'summary' or not isinstance(layers, dict):
            continue
            
        for layer_idx, layer_data in layers.items():
            if not isinstance(layer_data, dict):
                continue
            
            for feature in layer_data.get('interference_features', []):
                interferences = feature.get('interferences', [])
                if not interferences:
                    continue
                
                # A feature is classified by its strongest interference value
                max_interference_value = max(
                    interference['interference_value'] 
                    for interference in interferences
                )
                
                # Only the first N_SELECTED_TOKENS tokens of each feature are used
                selected_tokens = feature.get('high_activation_tokens', [])[:N_SELECTED_TOKENS]
                
                if max_interference_value >= high_threshold:
                    high_interference_tokens.update(selected_tokens)
                elif max_interference_value >= baseline_threshold:
                    medium_interference_tokens.update(selected_tokens)
    
    # Get all vocabulary tokens for random set
    print(f"  Getting random tokens from vocabulary...")
    vocab_size = model.cfg.d_vocab
    all_tokens = set()
    for token_id in range(vocab_size):
        try:
            token_str = model.to_string(token_id)
        except Exception:
            # Best effort: skip ids the tokenizer cannot decode, but let
            # KeyboardInterrupt / SystemExit propagate.
            continue
        # Only add if it's a valid (non-empty, non-whitespace) string token
        if token_str and token_str.strip():
            all_tokens.add(token_str)
    
    # The random pool is computed BEFORE intersection removal, so tokens that
    # are dropped from both interference pools below do not end up in it.
    used_tokens = target_tokens | high_interference_tokens | medium_interference_tokens
    random_tokens = all_tokens - used_tokens
    
    print(f"  Initial token set sizes (before intersection removal):")
    print(f"    target: {len(target_tokens)}")
    print(f"    high_interference: {len(high_interference_tokens)}")
    print(f"    medium_interference: {len(medium_interference_tokens)}")
    print(f"    random: {len(random_tokens)}")
    
    # Remove intersection between high and medium interference tokens
    print(f"  Removing intersections between high and medium interference...")
    intersection = high_interference_tokens & medium_interference_tokens
    print(f"    Found {len(intersection)} tokens in intersection")
    
    if intersection:
        # Show sample intersection tokens (sorted so the sample is stable)
        sample_intersection = sorted(intersection)[:5]
        print(f"    Sample intersection tokens: {sample_intersection}")
        
        # Ambiguous tokens (present in both pools) are removed from BOTH pools
        medium_interference_tokens = medium_interference_tokens - intersection
        high_interference_tokens = high_interference_tokens - intersection
        print(f"    Removed intersection from both high and medium interference sets")
    
    # Convert to sorted lists and remove empty/invalid strings
    token_sets = {
        'target': sorted(t for t in target_tokens if t and t.strip()),
        'high_interference': sorted(t for t in high_interference_tokens if t and t.strip()),
        'medium_interference': sorted(t for t in medium_interference_tokens if t and t.strip()),
        'random': sorted(t for t in random_tokens if t and t.strip())
    }
    
    # Print final statistics
    print(f"  Final token set sizes (after intersection removal):")
    for set_name, token_list in token_sets.items():
        print(f"    {set_name}: {len(token_list)}")
    
    # Verify no intersection remains
    high_set = set(token_sets['high_interference'])
    medium_set = set(token_sets['medium_interference'])
    remaining_intersection = high_set & medium_set
    print(f"  Verification: remaining intersection = {len(remaining_intersection)}")
    
    return token_sets

def get_target_token_ids(model, token_strs: List[str]) -> List[int]:
    """Convert token strings to a sorted list of unique token ids.

    Each string is tokenized with `model.to_tokens(..., prepend_bos=False)`
    and the id of its FIRST token is kept, so a string that splits into
    several tokens is represented by its leading piece. Strings that fail to
    tokenize, or tokenize to nothing, are skipped.

    Args:
        model: object exposing `to_tokens(text, prepend_bos=False)` that
            returns a 2-D tensor whose first row holds the token ids.
        token_strs: token strings to convert.

    Returns:
        Unique token ids in ascending order.
    """
    unique_ids = set()
    for token_str in token_strs:
        try:
            piece_ids = model.to_tokens(token_str, prepend_bos=False)[0]
            if len(piece_ids) == 0:
                continue
            unique_ids.add(int(piece_ids[0].item()))
        except Exception:
            # Best effort per token; unlike a bare `except:` this still lets
            # KeyboardInterrupt / SystemExit stop a long conversion loop.
            continue
    return sorted(unique_ids)

In [None]:
# 6: Get interference token sets for all experiments (Updated with intersection info)
# Seed every RNG used by the notebook (Python, NumPy, PyTorch)
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

# Define experiment configurations: one entry per token type, each carrying
# its test sentences and the per-model high-interference thresholds.
token_types = [
    'location', 'person', 'emotion', 'color',
    'animal', 'number', 'science', 'time'
]

EXPERIMENTS = [
    {
        'token_type': token_type,
        'sentences': SENTENCES_OF_TYPE[token_type],
        'pythia_threshold': PYTHIA_HIGH_INTERFERENCE_THRESHOLD,
        'gpt2_threshold': GPT2_HIGH_INTERFERENCE_THRESHOLD
    }
    for token_type in token_types
]

# Get token sets for all experiments
print("="*60)
print("LOADING INTERFERENCE TOKEN SETS")
print("="*60)

# Maps "<model>_<token_type>" -> {'token_sets', 'target_ids', 'sentences'}
experiment_token_sets = {}

# Pre-load token type mappings using utils_data. These are only used for the
# verification counts printed below; a load failure falls back to an empty map.
print("Loading token type mappings...")
try:
    pythia_token_id_to_type, pythia_token_str_to_type = load_type_dicts('pythia')
    print(f"  Loaded Pythia token types: {len(pythia_token_str_to_type)} tokens")
except Exception as e:
    print(f"  Error loading Pythia token types: {e}")
    pythia_token_str_to_type = {}

try:
    gpt2_token_id_to_type, gpt2_token_str_to_type = load_type_dicts('gpt2')
    print(f"  Loaded GPT-2 token types: {len(gpt2_token_str_to_type)} tokens")
except Exception as e:
    print(f"  Error loading GPT-2 token types: {e}")
    gpt2_token_str_to_type = {}

for exp_config in EXPERIMENTS:
    token_type = exp_config['token_type']
    
    print(f"\n--- {token_type.upper()} EXPERIMENT ---")
    
    # Get Pythia token sets
    try:
        pythia_sets = get_interference_token_sets(
            model_name='pythia',
            token_type=token_type,
            high_threshold=exp_config['pythia_threshold'],
            baseline_threshold=BASELINE_INTERFERENCE_THRESHOLD
        )
        
        # Convert target tokens to IDs
        pythia_target_ids = get_target_token_ids(pythia, pythia_sets['target'])
        
        # Verify target tokens are actually of target type using utils_data
        target_type_count = 0
        for token in pythia_sets['target']:
            token_type_result = get_token_str_type('pythia', token, pythia_token_str_to_type)
            if token_type_result == token_type:
                target_type_count += 1
        
        print(f"  Pythia target verification: {target_type_count}/{len(pythia_sets['target'])} are {token_type} type")
        
        # Verify no intersection between high and medium interference
        pythia_high_set = set(pythia_sets['high_interference'])
        pythia_medium_set = set(pythia_sets['medium_interference'])
        pythia_intersection = pythia_high_set & pythia_medium_set
        print(f"  Pythia intersection verification: {len(pythia_intersection)} tokens overlap")
        
        experiment_token_sets[f'pythia_{token_type}'] = {
            'token_sets': pythia_sets,
            'target_ids': pythia_target_ids,
            'sentences': exp_config['sentences']
        }
        
    except Exception as e:
        # Report and fall through: the GPT-2 data for this token type is loaded
        # from its own file, so a Pythia failure must not skip it.
        print(f"  Error loading Pythia {token_type}: {e}")
    
    # Get GPT-2 token sets  
    try:
        gpt2_sets = get_interference_token_sets(
            model_name='gpt2',
            token_type=token_type,
            high_threshold=exp_config['gpt2_threshold'],
            baseline_threshold=BASELINE_INTERFERENCE_THRESHOLD
        )
        
        # Convert target tokens to IDs
        gpt2_target_ids = get_target_token_ids(gpt2, gpt2_sets['target'])
        
        # Verify target tokens are actually of target type using utils_data
        target_type_count = 0
        for token in gpt2_sets['target']:
            token_type_result = get_token_str_type('gpt2', token, gpt2_token_str_to_type)
            if token_type_result == token_type:
                target_type_count += 1
        
        print(f"  GPT-2 target verification: {target_type_count}/{len(gpt2_sets['target'])} are {token_type} type")
        
        # Verify no intersection between high and medium interference
        gpt2_high_set = set(gpt2_sets['high_interference'])
        gpt2_medium_set = set(gpt2_sets['medium_interference'])
        gpt2_intersection = gpt2_high_set & gpt2_medium_set
        print(f"  GPT-2 intersection verification: {len(gpt2_intersection)} tokens overlap")
        
        experiment_token_sets[f'gpt2_{token_type}'] = {
            'token_sets': gpt2_sets,
            'target_ids': gpt2_target_ids,
            'sentences': exp_config['sentences']
        }
        
    except Exception as e:
        print(f"  Error loading GPT-2 {token_type}: {e}")

print("\n" + "="*60)
print("TOKEN SETS LOADING COMPLETE")
print("="*60)

# Print detailed summary
print(f"\nDETAILED SUMMARY:")
for exp_name, exp_data in experiment_token_sets.items():
    print(f"\n{exp_name.upper()}:")
    token_sets = exp_data['token_sets']
    target_ids = exp_data['target_ids']
    sentences = exp_data['sentences']
    
    print(f"  Target tokens: {len(token_sets['target'])} -> {len(target_ids)} IDs")
    print(f"  High interference: {len(token_sets['high_interference'])}")
    print(f"  Medium interference: {len(token_sets['medium_interference'])}")
    print(f"  Random tokens: {len(token_sets['random'])}")
    print(f"  Test sentences: {len(sentences)}")
    
    # Show the tokens of the high-interference category only
    for category, tokens in token_sets.items():
        if tokens and category == 'high_interference':
            sample_tokens = tokens  # Show all tokens
            print(f"    {category} samples: {sample_tokens}")

clear_memory()

In [None]:
# 7: Define experiment functions (Fixed)
def run_injection_experiment(
    model,
    sentence: str,
    injection_tokens: List[str],
    target_ids: List[int],
    n_trials: int = N_TRIALS,
    top_k: int = TOP_K,
    injection_size: int = 10
) -> Dict:
    """
    Run injection experiments with specified tokens

    For each trial, `injection_size` tokens are sampled from
    `injection_tokens` (without replacement when enough are available,
    otherwise with replacement), prepended to `sentence`, and the model's
    top-k next-token predictions are compared with those for the unmodified
    sentence.

    Args:
        model: HookedTransformer model
        sentence: input sentence
        injection_tokens: tokens to inject 
        target_ids: target token ids to measure against
        n_trials: number of trials
        top_k: number of top predictions to consider
        injection_size: number of tokens to inject per trial
    
    Returns:
        dict with keys 'count_increases', 'prob_increases', 'total_trials'
        (the requested `n_trials`) and 'detailed_metrics' (one dict per
        successfully completed trial)
    """
    results = {
        'count_increases': 0,
        'prob_increases': 0,
        'total_trials': n_trials,
        'detailed_metrics': []
    }
    
    if not injection_tokens:
        # Sampling from an empty pool would fail in every trial; report once instead
        print("    Warning: no injection tokens provided, skipping trials")
        return results
    
    # Plain list for tensor indexing, set for fast membership tests
    target_id_list = list(target_ids) if target_ids else []
    target_ids_set = set(target_id_list)
    
    # Get original predictions
    original_top_k = get_next_token_probs(model, sentence, top_k)
    original_ids = [t[0] for t in original_top_k] 
    original_probs = [t[2] for t in original_top_k]
    
    # Count original target tokens and their total probability
    original_count = float(sum(1 for tid in original_ids if tid in target_ids_set))
    original_prob = float(
        sum(p for tid, p in zip(original_ids, original_probs) if tid in target_ids_set)
    )
    
    # Baseline metrics depend only on the unmodified sentence, so compute them
    # once instead of once per trial.
    # NOTE: ids/probs are passed as plain lists. Passing tensors made the
    # helper's emptiness check raise, which was swallowed and silently forced
    # similarity_change to 0.0 for every trial.
    embedding_matrix = None
    original_sim = None
    try:
        embedding_matrix = model.embed.W_E.detach()
        original_sim = get_weighted_cosine_similarity(
            embedding_matrix, original_ids, original_probs, target_id_list
        )
    except Exception as e:
        print(f"    Warning: baseline similarity failed, similarity_change set to 0.0: {e}")
    
    original_overlap = get_weighted_overlap(original_ids, target_ids_set, original_probs)
    original_rank = get_best_rank(original_ids, target_ids_set)
    
    # Run injection trials
    for trial in range(n_trials):
        try:
            # Randomly sample tokens for injection
            if len(injection_tokens) >= injection_size:
                selected_tokens = random.sample(injection_tokens, injection_size)
            else:
                # If not enough tokens, sample with replacement
                selected_tokens = random.choices(injection_tokens, k=injection_size)
            
            # Create injected sentence
            injected_sentence = inject_tokens(sentence, selected_tokens)
            
            # Get predictions for injected sentence
            injected_top_k = get_next_token_probs(model, injected_sentence, top_k)
            injected_ids = [t[0] for t in injected_top_k]
            injected_probs = [t[2] for t in injected_top_k]
            
            # Count injected target tokens and their total probability
            injected_count = float(sum(1 for tid in injected_ids if tid in target_ids_set))
            injected_prob = float(
                sum(p for tid, p in zip(injected_ids, injected_probs) if tid in target_ids_set)
            )
            
            # Similarity change (best effort: stays 0.0 if it cannot be computed)
            sim_change = 0.0
            if original_sim is not None:
                try:
                    injected_sim = get_weighted_cosine_similarity(
                        embedding_matrix, injected_ids, injected_probs, target_id_list
                    )
                    sim_change = float(injected_sim - original_sim)
                except Exception as e:
                    print(f"    Warning: similarity failed in trial {trial}: {e}")
            
            # Change in probability mass that falls on target tokens
            injected_overlap = get_weighted_overlap(injected_ids, target_ids_set, injected_probs)
            overlap_change = float(injected_overlap - original_overlap)
            
            # Rank change; stays None when the injected run has no target in its top-k
            injected_rank = get_best_rank(injected_ids, target_ids_set)
            rank_change = None
            if injected_rank is not None:
                if original_rank is not None:
                    rank_change = float(original_rank - injected_rank)  # Positive = improvement (lower rank)
                else:
                    rank_change = float(top_k - injected_rank)  # Improvement from no rank to some rank
            
            count_increased = bool(injected_count > original_count)
            prob_increased = bool(injected_prob > original_prob)
            
            # Record trial data - ensure all values are Python primitives
            results['detailed_metrics'].append({
                'trial': trial,
                'selected_tokens': selected_tokens,
                'original_count': original_count,
                'injected_count': injected_count,
                'original_prob': original_prob,
                'injected_prob': injected_prob,
                'count_increase': count_increased,
                'prob_increase': prob_increased,
                'similarity_change': sim_change,
                'overlap_change': overlap_change,
                'rank_change': rank_change,
                'injected_sentence': injected_sentence,
                'original_predictions': original_top_k,
                'injected_predictions': injected_top_k
            })
            
            # Update success counters
            if count_increased:
                results['count_increases'] += 1
            if prob_increased:
                results['prob_increases'] += 1
                
        except Exception as e:
            # A failed trial is reported and skipped; it leaves no entry in detailed_metrics
            print(f"    Error in trial {trial}: {e}")
            continue
    
    return results

def run_sentence_experiments(
    model,
    model_name: str,
    token_sets: Dict[str, List[str]],
    target_ids: List[int],
    sentences: List[str],
    max_sentences: int = 50
) -> Dict:
    """
    Run the injection experiment on each sentence for every token category.

    Args:
        model: HookedTransformer model
        model_name: label used in log lines and the progress bar
        token_sets: token pools keyed by category name
        target_ids: target token IDs
        sentences: candidate test sentences
        max_sentences: only the first `max_sentences` sentences are used

    Returns:
        dict keyed by 'target', 'high_interference', 'medium_interference'
        and 'random'; each value is a list of
        {'sentence_idx', 'sentence', 'results'} entries, one per sentence
        whose experiment completed without raising.
    """
    category_order = ('target', 'high_interference', 'medium_interference', 'random')
    selected_sentences = sentences[:max_sentences]
    results = {name: [] for name in category_order}

    print(f"\nRunning experiments for {model_name}")
    print(f"Testing {len(selected_sentences)} sentences with {N_TRIALS} trials each")
    print(f"Token set sizes:")
    for category, tokens in token_sets.items():
        if not tokens:
            continue  # empty pools are not listed
        print(f"  {category}: {len(tokens)}")

    progress = tqdm(selected_sentences, desc=f"{model_name} sentences")
    for sent_idx, sentence in enumerate(progress):
        # Free cached memory before starting each sentence
        clear_memory()

        for category in category_order:
            pool = token_sets.get(category)
            if not pool:
                # Missing or empty pool: warn on the first sentence only
                if sent_idx == 0:
                    print(f"    Warning: No {category} tokens available, skipping")
                continue

            try:
                experiment_result = run_injection_experiment(
                    model=model,
                    sentence=sentence,
                    injection_tokens=pool,
                    target_ids=target_ids,
                    n_trials=N_TRIALS,
                    top_k=TOP_K,
                    injection_size=N_INJECTION
                )
            except Exception as e:
                print(f"    Error in {category} experiment for sentence {sent_idx}: {e}")
                continue

            results[category].append({
                'sentence_idx': sent_idx,
                'sentence': sentence,
                'results': experiment_result
            })

    return results

In [None]:
# 8: Statistical analysis functions
from scipy.stats import ttest_ind

def analyze_experiment_results(
    results: Dict,
    model_name: str,
    token_type: str
) -> Dict:
    """
    Summarise per-category success rates and compare each category with 'random'.

    Every trial contributes one 0/1 value for "target count increased" and one
    for "target probability increased". When the 'random' category has trials,
    each other category's 0/1 sequences are compared with it using Welch's
    t-test (`ttest_ind(..., equal_var=False)`), with 0.05 as the significance
    cut-off.

    Args:
        results: experiment results from run_sentence_experiments
        model_name: name of the model (used in printed headings)
        token_type: type of tokens being tested (used in printed headings)

    Returns:
        dict with one entry per category that had trials, plus a
        '<category>_vs_random' entry for each compared category
    """
    banner = '=' * 60
    print(f"\n{banner}")
    print(f"{model_name.upper()} {token_type.upper()} EXPERIMENT ANALYSIS")
    print(f"{banner}")

    analysis = {}

    # ---- per-category success rates ----
    for category in ('target', 'high_interference', 'medium_interference', 'random'):
        if not results.get(category):
            print(f"\nWarning: No results for {category}")
            continue

        # Flatten the trials of every sentence into one list
        trials = [
            trial
            for sentence_result in results[category]
            for trial in sentence_result['results']['detailed_metrics']
        ]
        total_trials = len(trials)
        if total_trials == 0:
            continue

        count_flags = [1 if trial['count_increase'] else 0 for trial in trials]
        prob_flags = [1 if trial['prob_increase'] else 0 for trial in trials]
        n_count_up = sum(count_flags)
        n_prob_up = sum(prob_flags)
        count_success_rate = n_count_up / total_trials
        prob_success_rate = n_prob_up / total_trials

        analysis[category] = {
            'total_trials': total_trials,
            'count_increases': n_count_up,
            'prob_increases': n_prob_up,
            'count_success_rate': count_success_rate,
            'prob_success_rate': prob_success_rate,
            'count_binary_sequence': count_flags,
            'prob_binary_sequence': prob_flags
        }

        print(f"\n{category.upper()}:")
        print(f"  Total trials: {total_trials}")
        print(f"  Count increases: {n_count_up} ({count_success_rate:.3f})")
        print(f"  Prob increases: {n_prob_up} ({prob_success_rate:.3f})")

    # ---- comparisons against the random baseline ----
    random_data = analysis.get('random')
    if random_data is None:
        return analysis

    sub_banner = '=' * 40
    print(f"\n{sub_banner}")
    print("STATISTICAL COMPARISONS vs RANDOM")
    print(f"{sub_banner}")

    random_count_rate = random_data['count_success_rate']
    random_prob_rate = random_data['prob_success_rate']

    for category in ('target', 'high_interference', 'medium_interference'):
        category_data = analysis.get(category)
        if category_data is None:
            continue

        print(f"\n{category.upper()} vs RANDOM:")

        # Welch's t-test on the 0/1 outcome sequences
        count_stat, count_pval = ttest_ind(
            category_data['count_binary_sequence'],
            random_data['count_binary_sequence'],
            equal_var=False
        )
        prob_stat, prob_pval = ttest_ind(
            category_data['prob_binary_sequence'],
            random_data['prob_binary_sequence'],
            equal_var=False
        )

        category_count_rate = category_data['count_success_rate']
        category_prob_rate = category_data['prob_success_rate']

        comparison = {
            'count_diff': category_count_rate - random_count_rate,
            'count_tstat': count_stat,
            'count_pvalue': count_pval,
            'count_significant': count_pval < 0.05,
            'prob_diff': category_prob_rate - random_prob_rate,
            'prob_tstat': prob_stat,
            'prob_pvalue': prob_pval,
            'prob_significant': prob_pval < 0.05
        }
        analysis[f'{category}_vs_random'] = comparison

        print(f"  Count increase rate difference: {category_count_rate:.3f} - {random_count_rate:.3f} = {comparison['count_diff']:.3f}")
        print(f"    T-statistic: {count_stat:.3f}, p-value: {count_pval:.6f}")
        print(f"    Significant: {'YES' if count_pval < 0.05 else 'NO'}")

        print(f"  Prob increase rate difference: {category_prob_rate:.3f} - {random_prob_rate:.3f} = {comparison['prob_diff']:.3f}")
        print(f"    T-statistic: {prob_stat:.3f}, p-value: {prob_pval:.6f}")
        print(f"    Significant: {'YES' if prob_pval < 0.05 else 'NO'}")

    return analysis

def print_experiment_summary(analysis: Dict, model_name: str, token_type: str):
    """Print a fixed-width table of success rates per category.

    Each row shows the count and probability success rates and, where a
    '<category>_vs_random' entry exists, the signed rate differences against
    the random category ('**' marks a significant difference). The 'random'
    row is labelled "baseline". Categories absent from `analysis` are omitted.
    """
    banner = '=' * 60
    print(f"\n{banner}")
    print(f"{model_name.upper()} {token_type.upper()} SUMMARY")
    print(f"{banner}")

    print("{:<20} {:<12} {:<12} {:<15}".format('Category', 'Count Rate', 'Prob Rate', 'vs Random'))
    print("{} {} {} {}".format('-' * 20, '-' * 12, '-' * 12, '-' * 15))

    for category in ('target', 'high_interference', 'medium_interference', 'random'):
        if category not in analysis:
            continue

        stats = analysis[category]
        comparison = analysis.get(f'{category}_vs_random') if f'{category}_vs_random' in analysis else None

        if f'{category}_vs_random' in analysis:
            count_marker = "**" if comparison['count_significant'] else ""
            prob_marker = "**" if comparison['prob_significant'] else ""
            vs_random_str = "{:+.3f}{}/{:+.3f}{}".format(
                comparison['count_diff'], count_marker,
                comparison['prob_diff'], prob_marker,
            )
        elif category == 'random':
            vs_random_str = "baseline"
        else:
            vs_random_str = ""

        print("{:<20} {:<12.3f} {:<12.3f} {:<15}".format(
            category,
            stats['count_success_rate'],
            stats['prob_success_rate'],
            vs_random_str,
        ))

    print("\n** = statistically significant (p < 0.05)")

In [None]:
# 9: Run single experiment example (template for each token type)
# Example for location tokens - modify token_type and model as needed

def run_single_token_type_experiment(
    model_name: str, 
    token_type: str, 
    max_sentences: int = 50
):
    """
    Run, analyse and summarise the experiment for one model / token-type pair.

    The token sets, target ids and sentences are read from the notebook-level
    `experiment_token_sets` under the key "<model_name>_<token_type>".

    Args:
        model_name: 'pythia' or 'gpt2'
        token_type: token type the experiment was prepared for
            (e.g. 'location', 'person', 'emotion')
        max_sentences: maximum number of sentences to test

    Returns:
        tuple of (results, analysis); (None, None) when no token sets were
        prepared for this pair

    Raises:
        ValueError: if token sets exist but `model_name` is not a known model
    """
    exp_key = f"{model_name}_{token_type}"
    exp_data = experiment_token_sets.get(exp_key) if exp_key in experiment_token_sets else None
    if exp_key not in experiment_token_sets:
        print(f"Error: {exp_key} not found in experiment_token_sets")
        return None, None

    # Resolve the model object from its name
    if model_name not in ('pythia', 'gpt2'):
        raise ValueError(f"Unknown model: {model_name}")
    model = pythia if model_name == 'pythia' else gpt2

    target_ids = exp_data['target_ids']
    sentences = exp_data['sentences']

    print(f"Starting {model_name} {token_type} experiment...")
    print(f"Target IDs: {len(target_ids)}")
    print(f"Available sentences: {len(sentences)}")

    # Run the trials, then analyse and print the summary table
    results = run_sentence_experiments(
        model=model,
        model_name=model_name,
        token_sets=exp_data['token_sets'],
        target_ids=target_ids,
        sentences=sentences,
        max_sentences=max_sentences
    )
    analysis = analyze_experiment_results(
        results=results,
        model_name=model_name,
        token_type=token_type
    )
    print_experiment_summary(analysis, model_name, token_type)

    return results, analysis

# Template usage - uncomment and modify as needed:
# results, analysis = run_single_token_type_experiment('pythia', 'location', max_sentences=50)

In [None]:
# Run every (token type, model) combination.
# The per-experiment analysis is kept in `all_experiment_analyses`, keyed by
# "<model>_<token_type>"; previously each iteration overwrote `results` and
# `analysis`, so only the last experiment's output was still available after
# the loop. Raw `results` are not retained across iterations because they hold
# the full prediction lists of every trial and grow large.
all_experiment_analyses = {}

for token_type in token_types:
    for model_name in ['pythia', 'gpt2']:
        results, analysis = run_single_token_type_experiment(model_name, token_type, max_sentences=100)
        if analysis is not None:
            all_experiment_analyses[f"{model_name}_{token_type}"] = analysis