In [None]:
# 1: Import packages for large model experiments (Updated)
import os
import json
import torch
import random
import numpy as np
import torch.nn.functional as F
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from collections import defaultdict
from pathlib import Path
import gc
from typing import Dict, List, Set, Tuple, Optional

# Import model utilities and sentences
import sys
sys.path.append(os.path.dirname(os.path.abspath('__file__')))
from utils.utils_model import get_hooked_pythia_70m, get_hooked_gpt2_small, get_llama_3_8B, get_gemma_2_9B
from utils.utils_data import load_type_dicts, get_token_id_type, get_token_str_type
from corpus.type_sentences import SENTENCES_OF_TYPE
from scipy.stats import ttest_ind

# --- Device -------------------------------------------------------------------
# Use the GPU whenever CUDA is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# --- Experiment parameters ----------------------------------------------------
# Interference buckets used when building token sets (Cell 2): a feature whose
# max interference_value is >= the model's HIGH threshold feeds the "high" set;
# otherwise, >= BASELINE feeds the "medium" set.
PYTHIA_HIGH_INTERFERENCE_THRESHOLD = 0.5
GPT2_HIGH_INTERFERENCE_THRESHOLD = 0.5
BASELINE_INTERFERENCE_THRESHOLD = 0.2

N_TRIALS = 100              # injection trials per sentence
N_TEST_SENTENCES = 100
TOP_K = 30                  # next-token predictions inspected per prompt
INJECTION_SIZE = 30         # tokens injected per trial
N_SELECTED_ACT_TOKENS = 10  # Number of high_activation_tokens to select per interference feature
TOKEN_SET_TYPE = 'union'    # 'union' or 'overlap' token sets (Cell 3)
RANDOM_SEED = 42
MED_DATA_DIR = Path("./med_data")  # intermediate JSON inputs/outputs

# --- Reproducibility ----------------------------------------------------------
# Seed the PyTorch, NumPy and Python RNGs.
torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
random.seed(RANDOM_SEED)

# --- Echo the configuration ---------------------------------------------------
print(
    f"Parameters set with thresholds: Pythia={PYTHIA_HIGH_INTERFERENCE_THRESHOLD}, "
    f"GPT-2={GPT2_HIGH_INTERFERENCE_THRESHOLD}"
)
print("N_SELECTED_ACT_TOKENS:", N_SELECTED_ACT_TOKENS)
print("TOKEN_SET_TYPE:", TOKEN_SET_TYPE)
print("Available sentence types:", list(SENTENCES_OF_TYPE))
for token_type, sentences in SENTENCES_OF_TYPE.items():
    print(f"  {token_type}: {len(sentences)} sentences")

In [2]:
# Cell 2: Define functions to generate both union and overlap token sets
def _print_banner(title: str) -> None:
    """Print ``title`` between two 60-character '=' rules (the log style of this cell)."""
    rule = '=' * 60
    print(f"\n{rule}")
    print(title)
    print(rule)


def _release_memory() -> None:
    """Collect garbage and, when CUDA is available, release PyTorch's cached GPU memory."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def _token_sets_path(set_type: str, token_type: str) -> Path:
    """
    Cache file for one generated token-set dict.

    Used by both the generators (write) and ``load_token_sets`` (read) so the
    naming scheme cannot drift between them. The name encodes
    N_SELECTED_ACT_TOKENS and both high-interference thresholds.

    NOTE(review): BASELINE_INTERFERENCE_THRESHOLD and the high/medium
    tie-breaking rule are NOT part of the name. Files written under other
    settings -- including files written before shared high/medium tokens were
    kept in the high tier -- are loaded as-is; delete them to regenerate.
    """
    filename = (
        f"{set_type}_{token_type}_token_sets_{N_SELECTED_ACT_TOKENS}tk_"
        f"{PYTHIA_HIGH_INTERFERENCE_THRESHOLD}ph_{GPT2_HIGH_INTERFERENCE_THRESHOLD}gh.json"
    )
    return MED_DATA_DIR / filename


def _collect_interference_tokens(
    data: dict,
    high_threshold: float,
    baseline_threshold: float,
    n_tokens: int,
    keep_target_feature_token=None,
) -> Tuple[Set[str], Set[str]]:
    """
    Bucket the high-activation tokens of interference features by interference strength.

    Args:
        data: parsed ``<model>_<type>_tokens_*.json``. Every top-level entry except
            'summary' whose value is a dict maps (presumably) layer ids to dicts
            holding 'interference_features' and optionally 'target_features'.
        high_threshold: a feature whose max ``interference_value`` is >= this
            contributes its tokens to the high set.
        baseline_threshold: otherwise, a max >= this contributes to the medium set.
        n_tokens: number of leading ``high_activation_tokens`` taken per feature.
        keep_target_feature_token: optional predicate ``str -> bool``. When given,
            the first ``n_tokens`` high-activation tokens of each target feature
            that satisfy it are added to the high set as well.

    Returns:
        ``(high, medium)`` token sets. They may overlap; see ``_split_high_medium``.
    """
    high: Set[str] = set()
    medium: Set[str] = set()
    for layer_type, layers in data.items():
        if layer_type == 'summary' or not isinstance(layers, dict):
            continue
        for layer_data in layers.values():
            if not isinstance(layer_data, dict):
                continue

            if keep_target_feature_token is not None:
                for feature in layer_data.get('target_features', []):
                    for token_str in feature.get('high_activation_tokens', [])[:n_tokens]:
                        if keep_target_feature_token(token_str):
                            high.add(token_str)

            for feature in layer_data.get('interference_features', []):
                interferences = feature.get('interferences', [])
                if not interferences:
                    continue
                max_interference = max(item['interference_value'] for item in interferences)
                top_tokens = feature.get('high_activation_tokens', [])[:n_tokens]
                if max_interference >= high_threshold:
                    high.update(top_tokens)
                elif max_interference >= baseline_threshold:
                    medium.update(top_tokens)
    return high, medium


def _split_high_medium(high: Set[str], medium: Set[str]) -> Tuple[Set[str], Set[str]]:
    """
    Make the high/medium tiers disjoint, keeping tokens found in both in HIGH only.

    Shared tokens are not dropped from both tiers. Inputs are not mutated.
    """
    high_tier = set(high)
    return high_tier, set(medium) - high_tier


def _clean_token_list(tokens) -> List[str]:
    """
    Drop empty / whitespace-only tokens and return the rest sorted.

    Sorting makes the saved JSON -- and therefore seeded draws from these lists
    such as ``random.sample`` in the injection experiment -- independent of
    Python's per-process string-hash randomisation, which otherwise changes set
    iteration order between kernel restarts.
    """
    return sorted(t for t in tokens if t and t.strip())


def _finalize_token_sets(target, high, medium, random_tokens) -> Dict[str, List[str]]:
    """Assemble the four categories, in the fixed key order the rest of the notebook iterates."""
    return {
        'target': _clean_token_list(target),
        'high_interference': _clean_token_list(high),
        'medium_interference': _clean_token_list(medium),
        'random': _clean_token_list(random_tokens),
    }


# Non-blank vocabulary strings per small model, memoised until this cell is re-run.
# Decoding the vocabulary needs the model loaded, and every token type uses the same sets.
_VOCAB_CACHE: Dict[str, frozenset] = {}


def _vocab_strings(model_key: str, load_model) -> frozenset:
    """
    Return every non-blank decoded token string of a small model's vocabulary.

    ``load_model(device)`` is called only on a cache miss; the model is released
    right after the decode loop so only one small model is resident at a time.
    """
    if model_key not in _VOCAB_CACHE:
        model = load_model(device)
        try:
            vocab: Set[str] = set()
            for token_id in range(model.cfg.d_vocab):
                try:
                    token_str = model.to_string(token_id)
                    if token_str and token_str.strip():
                        vocab.add(token_str)
                except Exception:
                    # Some ids may fail to decode (NOTE(review): presumably d_vocab is
                    # padded beyond the tokenizer's vocabulary); skip them.
                    continue
            _VOCAB_CACHE[model_key] = frozenset(vocab)
        finally:
            del model
            _release_memory()
    return _VOCAB_CACHE[model_key]


def _read_model_token_sets(
    model_key: str,
    display_name: str,
    token_type: str,
    high_threshold: float,
    include_target_features: bool,
) -> Tuple[Set[str], Set[str], Set[str]]:
    """
    Load one small model's interference JSON and return ``(target, high, medium)``.

    Target tokens are the vocabulary strings that the model's type dict labels
    ``token_type``. When the JSON file is missing, all three sets are empty and a
    warning is printed.

    Args:
        model_key: 'pythia' or 'gpt2' (file prefix and ``load_type_dicts`` key).
        display_name: name used in log lines.
        token_type: token type name.
        high_threshold: this model's high-interference threshold.
        include_target_features: also add the non-``token_type`` high-activation
            tokens of target features to the high set.
    """
    print(f"Processing {display_name} {token_type} data...")
    _, token_str_to_type = load_type_dicts(model_key)
    data_file = MED_DATA_DIR / f"{model_key}_{token_type}_tokens_r0.8_i0.2_s0.3.json"
    if not data_file.exists():
        print(f"  WARNING: {data_file} not found; {display_name} {token_type} sets are empty")
        return set(), set(), set()

    with open(data_file, 'r', encoding='utf-8') as f:
        data = json.load(f)

    target = {token_str for token_str, label in token_str_to_type.items() if label == token_type}

    keep_target_feature_token = (
        (lambda tok: get_token_str_type(model_key, tok, token_str_to_type) != token_type)
        if include_target_features
        else None
    )
    high, medium = _collect_interference_tokens(
        data,
        high_threshold,
        BASELINE_INTERFERENCE_THRESHOLD,
        N_SELECTED_ACT_TOKENS,
        keep_target_feature_token,
    )
    return target, high, medium


def _report_and_save(token_sets: Dict[str, List[str]], set_type: str, token_type: str) -> None:
    """Print per-category sizes and write ``token_sets`` to its cache file."""
    print(f"\nFinal {set_type} token sets:")
    for category, tokens in token_sets.items():
        print(f"  {category}: {len(tokens)} tokens")

    output_file = _token_sets_path(set_type, token_type)
    output_file.parent.mkdir(parents=True, exist_ok=True)
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(token_sets, f, ensure_ascii=False, indent=2)
    print(f"\nSaved {set_type} token sets to: {output_file}")


def generate_union_token_sets(token_type: str) -> Dict[str, List[str]]:
    """
    Generate union token sets from Pythia and GPT-2 interference data.

    Within each model, tokens that fall in both the high and medium tiers are
    kept in high only; each category is then unioned across the two models.
    The random set is each model's non-blank vocabulary minus all of that
    model's target/high/medium tokens, unioned across models.

    Args:
        token_type: token type name.

    Returns:
        Dict with sorted 'target', 'high_interference', 'medium_interference'
        and 'random' token lists. Also written to the cache file read by
        ``load_token_sets(token_type, 'union')``.
    """
    _print_banner(f"GENERATING UNION TOKEN SETS FOR {token_type.upper()}")

    # NOTE(review): unlike the overlap variant, target-feature tokens are not
    # folded into the high tier here (that loop was commented out). Confirm the
    # asymmetry is intended.
    pythia_target, pythia_high, pythia_medium = _read_model_token_sets(
        'pythia', 'Pythia', token_type, PYTHIA_HIGH_INTERFERENCE_THRESHOLD,
        include_target_features=False,
    )
    print(f"  Pythia before dedup - Target: {len(pythia_target)}, High: {len(pythia_high)}, Medium: {len(pythia_medium)}")
    gpt2_target, gpt2_high, gpt2_medium = _read_model_token_sets(
        'gpt2', 'GPT-2', token_type, GPT2_HIGH_INTERFERENCE_THRESHOLD,
        include_target_features=False,
    )
    print(f"  GPT-2 before dedup - Target: {len(gpt2_target)}, High: {len(gpt2_high)}, Medium: {len(gpt2_medium)}")

    # STEP 1: resolve high/medium overlaps within each model (shared tokens stay in high).
    print("\nSTEP 1: Removing high/medium overlaps within each model...")
    print(f"  Pythia high/medium overlap: {len(pythia_high & pythia_medium)} tokens")
    print(f"  GPT-2 high/medium overlap: {len(gpt2_high & gpt2_medium)} tokens")
    pythia_high_clean, pythia_medium_clean = _split_high_medium(pythia_high, pythia_medium)
    gpt2_high_clean, gpt2_medium_clean = _split_high_medium(gpt2_high, gpt2_medium)
    print("  After within-model dedup:")
    print(f"    Pythia - High: {len(pythia_high_clean)}, Medium: {len(pythia_medium_clean)}")
    print(f"    GPT-2 - High: {len(gpt2_high_clean)}, Medium: {len(gpt2_medium_clean)}")

    # Random sets: vocabulary minus every token the model already uses (pre-dedup sets).
    print("\nCalculating individual random sets...")
    pythia_random = _vocab_strings('pythia', get_hooked_pythia_70m) - (pythia_target | pythia_high | pythia_medium)
    gpt2_random = _vocab_strings('gpt2', get_hooked_gpt2_small) - (gpt2_target | gpt2_high | gpt2_medium)
    print(f"  Pythia random: {len(pythia_random)} tokens")
    print(f"  GPT-2 random: {len(gpt2_random)} tokens")

    # STEP 2: union the cleaned per-model sets.
    # NOTE(review): a token can be high in one model and medium in the other, so the
    # union tiers are not guaranteed to be disjoint.
    print("\nSTEP 2: Calculating unions of cleaned interference sets...")
    token_sets = _finalize_token_sets(
        pythia_target | gpt2_target,
        pythia_high_clean | gpt2_high_clean,
        pythia_medium_clean | gpt2_medium_clean,
        pythia_random | gpt2_random,
    )
    _report_and_save(token_sets, 'union', token_type)
    return token_sets


def generate_overlap_token_sets(token_type: str) -> Dict[str, List[str]]:
    """
    Generate overlap token sets from Pythia and GPT-2 interference data.

    High and medium interference sets are intersected across the two models,
    then made disjoint (shared tokens stay in high). Target tokens are the
    union of both models' targets; random tokens are the intersection of both
    models' random sets. Non-``token_type`` high-activation tokens of target
    features are also counted as high interference.

    Args:
        token_type: token type name.

    Returns:
        Dict with sorted 'target', 'high_interference', 'medium_interference'
        and 'random' token lists. Also written to the cache file read by
        ``load_token_sets(token_type, 'overlap')``.
    """
    _print_banner(f"GENERATING OVERLAP TOKEN SETS FOR {token_type.upper()}")

    pythia_target, pythia_high, pythia_medium = _read_model_token_sets(
        'pythia', 'Pythia', token_type, PYTHIA_HIGH_INTERFERENCE_THRESHOLD,
        include_target_features=True,
    )
    print(f"  Pythia - Target: {len(pythia_target)}, High: {len(pythia_high)}, Medium: {len(pythia_medium)}")
    gpt2_target, gpt2_high, gpt2_medium = _read_model_token_sets(
        'gpt2', 'GPT-2', token_type, GPT2_HIGH_INTERFERENCE_THRESHOLD,
        include_target_features=True,
    )
    print(f"  GPT-2 - Target: {len(gpt2_target)}, High: {len(gpt2_high)}, Medium: {len(gpt2_medium)}")

    print("\nCalculating individual random sets...")
    pythia_random = _vocab_strings('pythia', get_hooked_pythia_70m) - (pythia_target | pythia_high | pythia_medium)
    gpt2_random = _vocab_strings('gpt2', get_hooked_gpt2_small) - (gpt2_target | gpt2_high | gpt2_medium)
    print(f"  Pythia random: {len(pythia_random)} tokens")
    print(f"  GPT-2 random: {len(gpt2_random)} tokens")

    print("\nCalculating overlaps and unions...")
    target_union = pythia_target | gpt2_target
    random_overlap = pythia_random & gpt2_random
    overlap_high = pythia_high & gpt2_high
    overlap_medium = pythia_medium & gpt2_medium
    print(f"  Target union: {len(target_union)} tokens")
    print(f"  High interference overlap: {len(overlap_high)} tokens")
    print(f"  Medium interference overlap: {len(overlap_medium)} tokens")
    print(f"  Random overlap: {len(random_overlap)} tokens")

    print("\nRemoving high/medium overlap...")
    print(f"  High-Medium overlap kept in high only: {len(overlap_high & overlap_medium)} tokens")
    final_high, final_medium = _split_high_medium(overlap_high, overlap_medium)

    token_sets = _finalize_token_sets(target_union, final_high, final_medium, random_overlap)
    _report_and_save(token_sets, 'overlap', token_type)
    return token_sets


def load_token_sets(token_type: str, set_type: str = 'union') -> Optional[Dict[str, List[str]]]:
    """
    Load token sets previously written by ``generate_<set_type>_token_sets``.

    Returns:
        The cached dict, or None when no file exists for the current configuration.
    """
    token_sets_file = _token_sets_path(set_type, token_type)
    if not token_sets_file.exists():
        return None
    with open(token_sets_file, 'r', encoding='utf-8') as f:
        return json.load(f)

In [None]:
# Cell 3: Generate token sets for all types (Updated)
print(f"Generating {TOKEN_SET_TYPE} token sets for all token types...")

# Token categories to build, or to reload from the MED_DATA_DIR cache when a
# file for the current configuration already exists.
token_types = [
    'location', 'person', 'emotion', 'color', 'animal', 'number', 'science', 'time',
]
token_sets = {}

for token_type in token_types:
    try:
        existing_sets = load_token_sets(token_type, TOKEN_SET_TYPE)
        if existing_sets:
            print(f"\nLoaded existing {TOKEN_SET_TYPE} sets for {token_type}")
            token_sets[token_type] = existing_sets
            continue

        # No cached file (or an empty one): build the sets from the small-model data.
        if TOKEN_SET_TYPE == 'overlap':
            generated_sets = generate_overlap_token_sets(token_type)
        elif TOKEN_SET_TYPE == 'union':
            generated_sets = generate_union_token_sets(token_type)
        else:
            raise ValueError(f"Unknown TOKEN_SET_TYPE: {TOKEN_SET_TYPE}")
        token_sets[token_type] = generated_sets
    except Exception as e:
        # Best effort: report the failure and carry on with the remaining types.
        print(f"Error processing {token_type}: {e}")

print("\n" + "=" * 60)
print(f"{TOKEN_SET_TYPE.upper()} TOKEN SETS GENERATION COMPLETE")
print("=" * 60)

# Per-type summary of the category sizes the experiments will draw from.
for token_type, token_set in token_sets.items():
    print(f"\n{token_type.upper()}:")
    for category, tokens in token_set.items():
        print(f"  {category}: {len(tokens)} tokens")

In [None]:
# Cell 4: Define utility functions for large model experiments (Updated with large model token type checking)
from utils.utils_data import get_large_model_token_type, load_large_model_token_type_dict
# Path, relative to the notebook's working directory, of the large-model token-type
# lookup. NOTE(review): presumably maps Llama/Gemma token strings to type labels --
# confirm against load_large_model_token_type_dict / get_large_model_token_type.
LARGE_MODEL_TOKEN_TYPE_PATH = "./dataset/large_model_token_type.json"
# Loaded once when this cell runs; the type-check helpers below receive it as an argument.
large_model_token_type_dict = load_large_model_token_type_dict(LARGE_MODEL_TOKEN_TYPE_PATH)

def inject_tokens_large_model(sentence: str, tokens: List[str]) -> str:
    """
    Prefix ``sentence`` with the injected tokens wrapped in parentheses.

    Tokens are concatenated verbatim with no separator, so any leading spaces
    carried by the token strings are the only spacing between them. Nothing is
    inserted between the closing parenthesis and the sentence.

    Example: ``inject_tokens_large_model("Hi", [" a", "b"]) == "( ab)Hi"``.
    """
    injected = "".join(tokens)
    return "(" + injected + ")" + sentence

def get_next_token_probs_large_model(model, tokenizer, text: str, top_k: int = 10) -> List[Tuple[int, str, float]]:
    """
    Return the ``top_k`` most likely next tokens after ``text``.

    Args:
        model: causal LM called as ``model(**inputs)``; its output's ``.logits``
            is indexed ``[0, -1, :]`` (presumably (batch, seq, vocab) -- the last
            position of the single prompt is used).
        tokenizer: callable producing PyTorch tensors (``return_tensors="pt"``)
            and providing ``decode``.
        text: prompt to condition on.
        top_k: number of candidates to return.

    Returns:
        ``(token_id, decoded_string, probability)`` tuples in descending
        probability order. Special tokens decode to text with
        ``skip_special_tokens=True``. Candidates whose id fails to decode are
        skipped, so fewer than ``top_k`` entries may be returned.
    """
    inputs = tokenizer(text, return_tensors="pt").to(model.device)

    with torch.no_grad():
        logits = model(**inputs).logits[0, -1, :]
        # Normalise in float32: if the model runs in half precision, a bf16/fp16
        # softmax loses precision, and these probabilities are summed downstream.
        probs = torch.softmax(logits.float(), dim=-1)
        top_k_probs, top_k_indices = torch.topk(probs, top_k)

    # Single device->host transfer instead of two `.item()` syncs per candidate.
    results = []
    for token_id, prob in zip(top_k_indices.tolist(), top_k_probs.tolist()):
        try:
            token_str = tokenizer.decode([token_id], skip_special_tokens=True)
        except Exception:
            continue
        results.append((token_id, token_str, prob))
    return results

def check_token_type_from_small_models(
    token_str: str, 
    target_type: str,
    pythia_token_str_to_type: Dict[str, str],
    gpt2_token_str_to_type: Dict[str, str]
) -> bool:
    """
    Return True when either small-model mapping labels ``token_str`` as ``target_type``.

    The Pythia mapping is consulted first; the GPT-2 mapping is queried only
    when Pythia does not already match.
    """
    lookups = (
        ('pythia', pythia_token_str_to_type),
        ('gpt2', gpt2_token_str_to_type),
    )
    return any(
        get_token_str_type(model_key, token_str, mapping) == target_type
        for model_key, mapping in lookups
    )

def check_token_type_from_large_model(
    token_str: str,
    target_type: str,
    large_model_token_type_dict: dict
) -> bool:
    """Return whether the large-model type lookup labels ``token_str`` as ``target_type``."""
    return get_large_model_token_type(token_str, large_model_token_type_dict) == target_type

def _score_target_predictions(predictions, is_target) -> Tuple[float, float]:
    """
    Count the predictions whose decoded string satisfies ``is_target`` and sum their probabilities.

    Args:
        predictions: ``(token_id, token_str, prob)`` tuples.
        is_target: predicate on the decoded token string.

    Returns:
        ``(count, total_probability)``, both as floats, matching how they are
        stored in the per-trial metrics.
    """
    count = 0
    total_prob = 0.0
    for _token_id, token_str, prob in predictions:
        if is_target(token_str):
            count += 1
            total_prob += prob
    return float(count), float(total_prob)


def run_injection_experiment_large_model_with_type_check(
    model,
    tokenizer,
    sentence: str,
    injection_tokens: List[str],
    target_type: str,
    pythia_token_str_to_type: Dict[str, str],
    gpt2_token_str_to_type: Dict[str, str],
    n_trials: Optional[int] = None,
    top_k: Optional[int] = None,
    injection_size: Optional[int] = None,
    use_large_model_type: bool = False,
    large_model_token_type_dict: Optional[dict] = None
) -> Dict:
    """
    Measure whether injecting tokens before ``sentence`` raises the number, or the
    probability mass, of ``target_type`` tokens in the model's top-k predictions.

    Each trial samples ``injection_size`` tokens from ``injection_tokens`` (without
    replacement when enough are available, with replacement otherwise). It injects
    them with ``inject_tokens_large_model`` and compares the injected top-k against
    the un-injected baseline.

    Args:
        model, tokenizer: passed through to ``get_next_token_probs_large_model``.
        sentence: prompt whose next token is predicted.
        injection_tokens: pool of token strings to inject. If empty, no trials are run.
        target_type: token type counted as a hit.
        pythia_token_str_to_type, gpt2_token_str_to_type: small-model type
            mappings, used when ``use_large_model_type`` is False.
        n_trials, top_k, injection_size: default to the config-cell values
            N_TRIALS, TOP_K and INJECTION_SIZE. They are read at call time, so
            edits to the config cell take effect without re-running this cell.
        use_large_model_type: whether to use the large model token type dict for checking.
        large_model_token_type_dict: large model token type dict.

    Returns:
        Dict with these keys:
        - 'count_increases' / 'prob_increases': number of trials in which the
          injected target count / probability strictly exceeded the baseline.
        - 'total_trials': trials *requested*. Failing trials are logged and
          skipped, so ``len(results['detailed_metrics'])`` is the number completed.
        - 'detailed_metrics': one record per completed trial.
    """
    if n_trials is None:
        n_trials = N_TRIALS
    if top_k is None:
        top_k = TOP_K
    if injection_size is None:
        injection_size = INJECTION_SIZE

    results = {
        'count_increases': 0,
        'prob_increases': 0,
        'total_trials': n_trials,
        'detailed_metrics': []
    }

    # random.sample/choices cannot draw from an empty pool, so every trial would
    # fail; report it once instead of logging n_trials identical errors.
    if not injection_tokens:
        print("    No injection tokens provided; skipping trials")
        return results

    def is_target(token_str: str) -> bool:
        if use_large_model_type:
            return check_token_type_from_large_model(token_str, target_type, large_model_token_type_dict)
        return check_token_type_from_small_models(
            token_str, target_type, pythia_token_str_to_type, gpt2_token_str_to_type
        )

    # Baseline predictions for the un-injected sentence, computed once.
    original_top_k = get_next_token_probs_large_model(model, tokenizer, sentence, top_k)
    original_count, original_prob = _score_target_predictions(original_top_k, is_target)

    for trial in range(n_trials):
        try:
            if len(injection_tokens) >= injection_size:
                selected_tokens = random.sample(injection_tokens, injection_size)
            else:
                # Fewer distinct tokens than requested: sample with replacement.
                selected_tokens = random.choices(injection_tokens, k=injection_size)

            injected_sentence = inject_tokens_large_model(sentence, selected_tokens)
            injected_top_k = get_next_token_probs_large_model(model, tokenizer, injected_sentence, top_k)
            injected_count, injected_prob = _score_target_predictions(injected_top_k, is_target)

            count_increase = injected_count > original_count
            prob_increase = injected_prob > original_prob
            results['detailed_metrics'].append({
                'trial': trial,
                'selected_tokens': selected_tokens,
                'original_count': original_count,
                'injected_count': injected_count,
                'original_prob': original_prob,
                'injected_prob': injected_prob,
                'count_increase': bool(count_increase),
                'prob_increase': bool(prob_increase),
                'original_predictions': [tuple(p) for p in original_top_k],
                'injected_predictions': [tuple(p) for p in injected_top_k],
            })
            if count_increase:
                results['count_increases'] += 1
            if prob_increase:
                results['prob_increases'] += 1
        except Exception as e:
            print(f"    Error in trial {trial}: {e}")
            continue
    return results

def clear_memory():
    """Run the Python garbage collector, then release PyTorch's cached CUDA memory if a GPU is present."""
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()

In [None]:
# 5: Load the Llama model
print("Loading large models...")

try:
    print("  Loading Llama 3.1 8B...")
    llama_model, llama_tokenizer = get_llama_3_8B()
    print(f"    Llama loaded on {llama_model.device}")
except Exception as e:
    # Keep the notebook running; both names are None so later cells can test availability.
    print(f"    Error loading Llama: {e}")
    llama_model, llama_tokenizer = None, None

# Only report success when the model actually loaded.
if llama_model is not None:
    print("Large models loaded!")
else:
    print("Llama failed to load; llama_model and llama_tokenizer are None.")

In [None]:
# 6: Load Gemma model
print("Loading Gemma model...")

try:
    print("  Loading Gemma 2 9B...")
    gemma_model, gemma_tokenizer = get_gemma_2_9B()
    print(f"    Gemma loaded on {gemma_model.device}")
except Exception as e:
    # Keep the notebook running; both names are None so later cells can test availability.
    print(f"    Error loading Gemma: {e}")
    gemma_model, gemma_tokenizer = None, None

# Release any memory the load left unreferenced before the experiments start.
clear_memory()

# Only report success when the model actually loaded.
if gemma_model is not None:
    print("Gemma model loaded!")
else:
    print("Gemma failed to load; gemma_model and gemma_tokenizer are None.")

In [None]:
# 7: Run experiments on large models (Updated with t-test and large model type checking)
def _load_small_model_type_dict(model_key: str, label: str) -> Dict:
    """Return the token-string -> type mapping from ``load_type_dicts(model_key)``.

    On any loading error the error is printed and an empty dict is returned, so
    type checking degrades gracefully instead of aborting the experiment.
    """
    try:
        _, token_str_to_type = load_type_dicts(model_key)
    except Exception as e:
        print(f"  Error loading {label} token types: {e}")
        return {}
    print(f"  Loaded {label} token types: {len(token_str_to_type)} tokens")
    return token_str_to_type


def run_large_model_experiment_with_type_check(
    model,
    tokenizer,
    model_name: str,
    token_type: str,
    max_sentences: int = 50
):
    """
    Run the prompt-injection experiment for one (model, token_type) pair.

    For each category in ('target', 'high_interference', 'medium_interference',
    'random') that has tokens in ``token_sets[token_type]``, calls
    ``run_injection_experiment_large_model_with_type_check`` on every sampled
    test sentence, then analyzes and prints the aggregated results.

    Args:
        model, tokenizer: the loaded large model and its tokenizer.
        model_name: display name; lowercased 'llama' / 'gemma' sets
            ``use_large_model_type=True`` for the injection helper.
        token_type: key into ``token_sets`` and ``SENTENCES_OF_TYPE``.
        max_sentences: upper bound on the number of test sentences sampled.

    Returns:
        ``(results, analysis)``: ``results`` maps category -> list of
        per-sentence records; ``analysis`` comes from
        ``analyze_large_model_results``. Returns ``(None, None)`` when
        ``token_sets`` has no entry for ``token_type``.
    """
    print(f"\n{'='*60}")
    print(f"RUNNING {model_name.upper()} {token_type.upper()} EXPERIMENT")
    print(f"Using {TOKEN_SET_TYPE} token sets")
    print(f"{'='*60}")

    # Get token sets and sentences
    if token_type not in token_sets:
        print(f"Error: No token sets for {token_type}")
        return None, None

    current_token_sets = token_sets[token_type]
    sentences = SENTENCES_OF_TYPE[token_type]
    # Sample from a dedicated RNG keyed on (RANDOM_SEED, token_type) instead of the
    # global `random` state: every model then sees the same sentence subset for a
    # given token type, and re-running a single cell reproduces it.
    sentence_rng = random.Random(f"{RANDOM_SEED}:{token_type}")
    test_sentences = sentence_rng.sample(sentences, min(max_sentences, len(sentences)))

    # Load token type mappings for type checking
    print("Loading token type mappings for type checking...")
    pythia_token_str_to_type = _load_small_model_type_dict('pythia', 'Pythia')
    gpt2_token_str_to_type = _load_small_model_type_dict('gpt2', 'GPT-2')

    print(f"\nToken set sizes:")
    for category, tokens in current_token_sets.items():
        print(f"  {category}: {len(tokens)}")
    print(f"Test sentences: {len(test_sentences)}")

    use_large_model_type = model_name.lower() in ['llama', 'gemma']

    # Run experiments
    results = {}
    categories = ['target', 'high_interference', 'medium_interference', 'random']

    for category in categories:
        # `.get` so a token-set dict lacking a category skips it instead of
        # raising KeyError and losing the whole token type.
        injection_tokens = current_token_sets.get(category)
        if not injection_tokens:
            print(f"\nSkipping {category} - no tokens available")
            continue

        print(f"\nTesting {category} tokens...")
        category_results = []

        for sent_idx, sentence in enumerate(tqdm(test_sentences, desc=f"{category}")):
            # Periodically free memory during long sentence loops.
            if sent_idx % 20 == 19:
                clear_memory()
            try:
                experiment_result = run_injection_experiment_large_model_with_type_check(
                    model=model,
                    tokenizer=tokenizer,
                    sentence=sentence,
                    injection_tokens=injection_tokens,
                    target_type=token_type,
                    pythia_token_str_to_type=pythia_token_str_to_type,
                    gpt2_token_str_to_type=gpt2_token_str_to_type,
                    n_trials=N_TRIALS,
                    top_k=TOP_K,
                    injection_size=INJECTION_SIZE,
                    use_large_model_type=use_large_model_type,
                    large_model_token_type_dict=large_model_token_type_dict
                )
                category_results.append({
                    'sentence_idx': sent_idx,
                    'sentence': sentence,
                    'results': experiment_result
                })
            except Exception as e:
                print(f"    Error in sentence {sent_idx}: {e}")
                continue

        results[category] = category_results

    # Analyze results
    analysis = analyze_large_model_results(results, model_name, token_type)
    print_large_model_summary(analysis, model_name, token_type)

    return results, analysis

def _welch_ttest_or_default(sample_a, sample_b, label: str, category: str) -> Tuple[float, float]:
    """Run Welch's t-test (``equal_var=False``) of ``sample_a`` vs ``sample_b``.

    Returns ``(t_statistic, p_value)`` as Python floats. Falls back to the
    neutral ``(0.0, 1.0)`` when scipy raises (the error is printed) or returns a
    NaN p-value. NaN arises for degenerate samples, e.g. both groups constant
    with equal means (every trial failed in both) or a group with a single
    observation. Left as-is, NaN would be written to the results JSON as the
    non-standard ``NaN`` token.
    """
    try:
        stat, pval = ttest_ind(sample_a, sample_b, equal_var=False)
    except Exception as e:
        print(f"    Error in {label} t-test for {category}: {e}")
        return 0.0, 1.0
    stat, pval = float(stat), float(pval)
    if np.isnan(pval):
        return 0.0, 1.0
    return stat, pval


def analyze_large_model_results(results: Dict, model_name: str, token_type: str) -> Dict:
    """Aggregate per-trial outcomes into success rates and compare each category to 'random'.

    ``results`` maps category -> list of per-sentence records whose
    ``['results']['detailed_metrics']`` entries carry truthy/falsy
    ``count_increase`` and ``prob_increase`` flags.

    Returns a dict holding, for each category with at least one trial, its
    totals, success rates and raw 0/1 outcome lists. When 'random' is present it
    also holds ``'<category>_vs_random'`` entries with rate differences, Welch
    t-statistics, p-values and a ``p < 0.05`` significance flag.
    """
    print(f"\n{'='*50}")
    print(f"{model_name.upper()} {token_type.upper()} ANALYSIS")
    print(f"{'='*50}")

    analysis = {}
    categories = ['target', 'high_interference', 'medium_interference', 'random']

    for category in categories:
        if category not in results or not results[category]:
            continue

        # Collect per-trial 0/1 outcomes across all sentences
        count_increases = []
        prob_increases = []
        for sentence_result in results[category]:
            for trial_data in sentence_result['results']['detailed_metrics']:
                count_increases.append(1 if trial_data['count_increase'] else 0)
                prob_increases.append(1 if trial_data['prob_increase'] else 0)

        total_trials = len(count_increases)
        if total_trials == 0:
            continue

        n_count = sum(count_increases)
        n_prob = sum(prob_increases)
        count_success_rate = n_count / total_trials
        prob_success_rate = n_prob / total_trials

        analysis[category] = {
            'total_trials': total_trials,
            'count_increases': n_count,
            'prob_increases': n_prob,
            'count_success_rate': count_success_rate,
            'prob_success_rate': prob_success_rate,
            'count_increases_raw': count_increases,
            'prob_increases_raw': prob_increases
        }

        print(f"\n{category.upper()}:")
        print(f"  Total trials: {total_trials}")
        print(f"  Count increases: {n_count} ({count_success_rate:.3f})")
        print(f"  Prob increases: {n_prob} ({prob_success_rate:.3f})")

    # Statistical comparisons using Welch's t-test.
    # NOTE(review): trials from the same sentence are pooled as if independent;
    # a per-sentence (clustered) test would be more conservative.
    if 'random' in analysis:
        random_data = analysis['random']
        print(f"\nSTATISTICAL COMPARISONS (t-test):")

        for category in ['target', 'high_interference', 'medium_interference']:
            if category not in analysis:
                continue

            category_data = analysis[category]
            count_stat, count_pval = _welch_ttest_or_default(
                category_data['count_increases_raw'],
                random_data['count_increases_raw'],
                'count', category
            )
            prob_stat, prob_pval = _welch_ttest_or_default(
                category_data['prob_increases_raw'],
                random_data['prob_increases_raw'],
                'prob', category
            )

            comparison = {
                'count_diff': category_data['count_success_rate'] - random_data['count_success_rate'],
                'count_pvalue': count_pval,
                'count_significant': count_pval < 0.05,
                'count_tstat': count_stat,
                'prob_diff': category_data['prob_success_rate'] - random_data['prob_success_rate'],
                'prob_pvalue': prob_pval,
                'prob_significant': prob_pval < 0.05,
                'prob_tstat': prob_stat
            }
            analysis[f'{category}_vs_random'] = comparison

            print(f"\n{category} vs random:")
            print(f"  Count diff: {comparison['count_diff']:+.3f} (t={count_stat:.3f}, p={count_pval:.4f})")
            print(f"  Prob diff: {comparison['prob_diff']:+.3f} (t={prob_stat:.3f}, p={prob_pval:.4f})")

    return analysis

def print_large_model_summary(analysis: Dict, model_name: str, token_type: str):
    """Print a fixed-width table of per-category success rates and their deltas vs. 'random'.

    Rows appear only for categories present in ``analysis``. The last column shows
    ``count_diff/prob_diff`` (suffixed with ``**`` when significant), ``baseline``
    for the random row, or blank when no comparison exists.
    """
    banner = '=' * 60
    print(f"\n{banner}")
    print(f"{model_name.upper()} {token_type.upper()} SUMMARY")
    print(banner)

    print(f"{'Category':<20} {'Count Rate':<12} {'Prob Rate':<12} {'vs Random':<15}")
    print(" ".join('-' * width for width in (20, 12, 12, 15)))

    for name in ('target', 'high_interference', 'medium_interference', 'random'):
        if name not in analysis:
            continue

        stats = analysis[name]
        comparison_key = f'{name}_vs_random'
        if comparison_key in analysis:
            cmp = analysis[comparison_key]
            count_flag = "**" if cmp['count_significant'] else ""
            prob_flag = "**" if cmp['prob_significant'] else ""
            delta = f"{cmp['count_diff']:+.3f}{count_flag}/{cmp['prob_diff']:+.3f}{prob_flag}"
        else:
            delta = "baseline" if name == 'random' else ""

        print(f"{name:<20} {stats['count_success_rate']:<12.3f} {stats['prob_success_rate']:<12.3f} {delta:<15}")

    print("\n** = statistically significant (p < 0.05)")

## Llama Exp using Large Model Token Type Dict

In [None]:
# 8: Run experiments on Llama
import datetime
import json
import numpy as np
from pathlib import Path

def convert_for_json(obj):
    """Recursively convert numpy/torch values into JSON-serializable Python types.

    - numpy scalars (``np.bool_``, integers, floats) -> ``bool`` / ``int`` / ``float``
    - numpy arrays and torch tensors -> nested lists (a 0-d tensor -> Python scalar)
    - dicts, lists and tuples are converted element-wise (tuples stay tuples;
      ``json`` writes them as arrays); dict keys are left unchanged
    - sets / frozensets -> lists sorted by ``repr`` so the output is deterministic

    Anything else is returned unchanged.
    """
    if isinstance(obj, np.bool_):
        return bool(obj)
    elif isinstance(obj, np.integer):
        return int(obj)
    elif isinstance(obj, np.floating):
        return float(obj)
    elif isinstance(obj, np.ndarray):
        return obj.tolist()
    elif isinstance(obj, torch.Tensor):
        # Tensors (e.g. counts/probabilities left on the GPU) are not JSON
        # serializable; tolist() yields a Python scalar for 0-d tensors.
        return obj.detach().cpu().tolist()
    elif isinstance(obj, dict):
        return {key: convert_for_json(value) for key, value in obj.items()}
    elif isinstance(obj, list):
        return [convert_for_json(item) for item in obj]
    elif isinstance(obj, tuple):
        return tuple(convert_for_json(item) for item in obj)
    elif isinstance(obj, (set, frozenset)):
        return [convert_for_json(item) for item in sorted(obj, key=repr)]
    else:
        return obj

def safe_json_dump(data, file_path):
    """Write ``data`` to ``file_path`` as pretty-printed UTF-8 JSON.

    Values are first made JSON-friendly with ``convert_for_json``. The whole
    document is serialized to a string *before* the file is opened, so a
    serialization error (e.g. an unsupported type) raises without truncating or
    leaving a half-written file at ``file_path``. Accepts ``str`` or ``Path``.
    """
    converted_data = convert_for_json(data)
    payload = json.dumps(converted_data, ensure_ascii=False, indent=2)
    Path(file_path).write_text(payload, encoding='utf-8')

import traceback

# Create results directory
RESULTS_DIR = Path("./results/llama_gemma_prompt_injection")
RESULTS_DIR.mkdir(parents=True, exist_ok=True)

# Create timestamp for this experiment run
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
experiment_id = f"{TOKEN_SET_TYPE}_{N_SELECTED_ACT_TOKENS}tk_{PYTHIA_HIGH_INTERFERENCE_THRESHOLD}ph_{GPT2_HIGH_INTERFERENCE_THRESHOLD}gh_{timestamp}"

print(f"Experiment ID: {experiment_id}")
print(f"Results will be saved to: {RESULTS_DIR}")

experiment_results = {}

# This cell runs Llama only; a model that failed to load (None) is skipped.
models = []
if llama_model is not None and llama_tokenizer is not None:
    models.append(('llama', llama_model, llama_tokenizer))

test_token_types = [
    'location', 'person', 'emotion', 'color', 'animal', 'number', 'science', 'time'
]

# Save experiment metadata. `max_sentences_per_type` must match the
# `max_sentences` argument passed to the experiment below.
experiment_metadata = {
    'experiment_id': experiment_id,
    'timestamp': timestamp,
    'parameters': {
        'TOKEN_SET_TYPE': TOKEN_SET_TYPE,
        'N_SELECTED_ACT_TOKENS': N_SELECTED_ACT_TOKENS,
        'PYTHIA_HIGH_INTERFERENCE_THRESHOLD': float(PYTHIA_HIGH_INTERFERENCE_THRESHOLD),
        'GPT2_HIGH_INTERFERENCE_THRESHOLD': float(GPT2_HIGH_INTERFERENCE_THRESHOLD),
        'BASELINE_INTERFERENCE_THRESHOLD': float(BASELINE_INTERFERENCE_THRESHOLD),
        'N_TRIALS': int(N_TRIALS),
        'N_TEST_SENTENCES': int(N_TEST_SENTENCES),
        'TOP_K': int(TOP_K),
        'INJECTION_SIZE': int(INJECTION_SIZE),
        'RANDOM_SEED': int(RANDOM_SEED)
    },
    'models': [name for name, _, _ in models],
    'token_types': test_token_types,
    'max_sentences_per_type': int(N_TEST_SENTENCES)
}

# Save metadata
metadata_file = RESULTS_DIR / f"{experiment_id}_metadata.json"
safe_json_dump(experiment_metadata, metadata_file)
print(f"Saved metadata to: {metadata_file}")

for model_name, model, tokenizer in models:
    experiment_results[model_name] = {}

    for token_type in test_token_types:
        print(f"\n{'#'*80}")
        print(f"STARTING {model_name.upper()} {token_type.upper()} EXPERIMENT")
        print(f"{'#'*80}")

        try:
            results, analysis = run_large_model_experiment_with_type_check(
                model=model,
                tokenizer=tokenizer,
                model_name=model_name,
                token_type=token_type,
                max_sentences=N_TEST_SENTENCES
            )

            experiment_results[model_name][token_type] = {
                'results': results,
                'analysis': analysis
            }

            # Save individual result file for each model-token_type combination
            individual_result = {
                'experiment_id': experiment_id,
                'model_name': model_name,
                'token_type': token_type,
                'timestamp': timestamp,
                'parameters': experiment_metadata['parameters'],
                'token_sets': token_sets[token_type],
                'results': results,
                'analysis': analysis
            }

            individual_file = RESULTS_DIR / f"{experiment_id}_{model_name}_{token_type}.json"
            safe_json_dump(individual_result, individual_file)
            print(f"Saved {model_name} {token_type} results to: {individual_file}")

        except Exception as e:
            print(f"Error in {model_name} {token_type} experiment: {e}")

            # Save error info, including type and traceback so failures in long
            # runs can be diagnosed from the file alone.
            error_result = {
                'experiment_id': experiment_id,
                'model_name': model_name,
                'token_type': token_type,
                'timestamp': timestamp,
                'error': str(e),
                'error_type': type(e).__name__,
                'traceback': traceback.format_exc(),
                'parameters': experiment_metadata['parameters']
            }

            error_file = RESULTS_DIR / f"{experiment_id}_{model_name}_{token_type}_ERROR.json"
            safe_json_dump(error_result, error_file)
            print(f"Saved error info to: {error_file}")

        finally:
            # Clear memory between experiments, including after a failure such as
            # CUDA OOM, when reclaiming cached memory matters most.
            clear_memory()

# Save complete experiment results
complete_results = {
    'experiment_id': experiment_id,
    'timestamp': timestamp,
    'metadata': experiment_metadata,
    'results': experiment_results
}

complete_file = RESULTS_DIR / f"{experiment_id}_complete.json"
safe_json_dump(complete_results, complete_file)

print(f"\n{'='*80}")
print("ALL LARGE MODEL EXPERIMENTS COMPLETE")
print(f"Saved complete results to: {complete_file}")
print(f"{'='*80}")

## Gemma Exp using Large Model Token Type Dict

In [None]:
# 9: Run experiments on Gemma
import datetime
import json
import numpy as np
from pathlib import Path

def convert_for_json(obj):
    """Recursively convert numpy/torch values into JSON-serializable Python types.

    - numpy scalars (``np.bool_``, integers, floats) -> ``bool`` / ``int`` / ``float``
    - numpy arrays and torch tensors -> nested lists (a 0-d tensor -> Python scalar)
    - dicts, lists and tuples are converted element-wise (tuples stay tuples;
      ``json`` writes them as arrays); dict keys are left unchanged
    - sets / frozensets -> lists sorted by ``repr`` so the output is deterministic

    Anything else is returned unchanged.
    """
    if isinstance(obj, np.bool_):
        return bool(obj)
    elif isinstance(obj, np.integer):
        return int(obj)
    elif isinstance(obj, np.floating):
        return float(obj)
    elif isinstance(obj, np.ndarray):
        return obj.tolist()
    elif isinstance(obj, torch.Tensor):
        # Tensors (e.g. counts/probabilities left on the GPU) are not JSON
        # serializable; tolist() yields a Python scalar for 0-d tensors.
        return obj.detach().cpu().tolist()
    elif isinstance(obj, dict):
        return {key: convert_for_json(value) for key, value in obj.items()}
    elif isinstance(obj, list):
        return [convert_for_json(item) for item in obj]
    elif isinstance(obj, tuple):
        return tuple(convert_for_json(item) for item in obj)
    elif isinstance(obj, (set, frozenset)):
        return [convert_for_json(item) for item in sorted(obj, key=repr)]
    else:
        return obj

def safe_json_dump(data, file_path):
    """Write ``data`` to ``file_path`` as pretty-printed UTF-8 JSON.

    Values are first made JSON-friendly with ``convert_for_json``. The whole
    document is serialized to a string *before* the file is opened, so a
    serialization error (e.g. an unsupported type) raises without truncating or
    leaving a half-written file at ``file_path``. Accepts ``str`` or ``Path``.
    """
    converted_data = convert_for_json(data)
    payload = json.dumps(converted_data, ensure_ascii=False, indent=2)
    Path(file_path).write_text(payload, encoding='utf-8')

import traceback

# Create results directory
RESULTS_DIR = Path("./results/llama_gemma_prompt_injection")
RESULTS_DIR.mkdir(parents=True, exist_ok=True)

# Create timestamp for this experiment run
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
experiment_id = f"{TOKEN_SET_TYPE}_{N_SELECTED_ACT_TOKENS}tk_{PYTHIA_HIGH_INTERFERENCE_THRESHOLD}ph_{GPT2_HIGH_INTERFERENCE_THRESHOLD}gh_{timestamp}"

print(f"Experiment ID: {experiment_id}")
print(f"Results will be saved to: {RESULTS_DIR}")

experiment_results = {}

# This cell runs Gemma only; a model that failed to load (None) is skipped.
models = []
if gemma_model is not None and gemma_tokenizer is not None:
    models.append(('gemma', gemma_model, gemma_tokenizer))

test_token_types = [
    'location', 'person', 'emotion', 'color', 'animal', 'number', 'science', 'time'
]

# Save experiment metadata. `max_sentences_per_type` must match the
# `max_sentences` argument passed to the experiment below.
experiment_metadata = {
    'experiment_id': experiment_id,
    'timestamp': timestamp,
    'parameters': {
        'TOKEN_SET_TYPE': TOKEN_SET_TYPE,
        'N_SELECTED_ACT_TOKENS': N_SELECTED_ACT_TOKENS,
        'PYTHIA_HIGH_INTERFERENCE_THRESHOLD': float(PYTHIA_HIGH_INTERFERENCE_THRESHOLD),
        'GPT2_HIGH_INTERFERENCE_THRESHOLD': float(GPT2_HIGH_INTERFERENCE_THRESHOLD),
        'BASELINE_INTERFERENCE_THRESHOLD': float(BASELINE_INTERFERENCE_THRESHOLD),
        'N_TRIALS': int(N_TRIALS),
        'N_TEST_SENTENCES': int(N_TEST_SENTENCES),
        'TOP_K': int(TOP_K),
        'INJECTION_SIZE': int(INJECTION_SIZE),
        'RANDOM_SEED': int(RANDOM_SEED)
    },
    'models': [name for name, _, _ in models],
    'token_types': test_token_types,
    'max_sentences_per_type': int(N_TEST_SENTENCES)
}

# Save metadata
metadata_file = RESULTS_DIR / f"{experiment_id}_metadata.json"
safe_json_dump(experiment_metadata, metadata_file)
print(f"Saved metadata to: {metadata_file}")

for model_name, model, tokenizer in models:
    experiment_results[model_name] = {}

    for token_type in test_token_types:
        print(f"\n{'#'*80}")
        print(f"STARTING {model_name.upper()} {token_type.upper()} EXPERIMENT")
        print(f"{'#'*80}")

        try:
            results, analysis = run_large_model_experiment_with_type_check(
                model=model,
                tokenizer=tokenizer,
                model_name=model_name,
                token_type=token_type,
                max_sentences=N_TEST_SENTENCES
            )

            experiment_results[model_name][token_type] = {
                'results': results,
                'analysis': analysis
            }

            # Save individual result file for each model-token_type combination
            individual_result = {
                'experiment_id': experiment_id,
                'model_name': model_name,
                'token_type': token_type,
                'timestamp': timestamp,
                'parameters': experiment_metadata['parameters'],
                'token_sets': token_sets[token_type],
                'results': results,
                'analysis': analysis
            }

            individual_file = RESULTS_DIR / f"{experiment_id}_{model_name}_{token_type}.json"
            safe_json_dump(individual_result, individual_file)
            print(f"Saved {model_name} {token_type} results to: {individual_file}")

        except Exception as e:
            print(f"Error in {model_name} {token_type} experiment: {e}")

            # Save error info, including type and traceback so failures in long
            # runs can be diagnosed from the file alone.
            error_result = {
                'experiment_id': experiment_id,
                'model_name': model_name,
                'token_type': token_type,
                'timestamp': timestamp,
                'error': str(e),
                'error_type': type(e).__name__,
                'traceback': traceback.format_exc(),
                'parameters': experiment_metadata['parameters']
            }

            error_file = RESULTS_DIR / f"{experiment_id}_{model_name}_{token_type}_ERROR.json"
            safe_json_dump(error_result, error_file)
            print(f"Saved error info to: {error_file}")

        finally:
            # Clear memory between experiments, including after a failure such as
            # CUDA OOM, when reclaiming cached memory matters most.
            clear_memory()

# Save complete experiment results
complete_results = {
    'experiment_id': experiment_id,
    'timestamp': timestamp,
    'metadata': experiment_metadata,
    'results': experiment_results
}

complete_file = RESULTS_DIR / f"{experiment_id}_complete.json"
safe_json_dump(complete_results, complete_file)

print(f"\n{'='*80}")
print("ALL LARGE MODEL EXPERIMENTS COMPLETE")
print(f"Saved complete results to: {complete_file}")
print(f"{'='*80}")