In [None]:
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath('__file__'))))

from utils.utils_data import *
from tqdm.notebook import tqdm
import torch
import json
import numpy as np
import gc

def cosine_similarity_matrix(X, Y):
    """
    Pairwise cosine similarity between the rows of two 2-D tensors.

    X: [n, d], Y: [m, d]
    Returns: [n, m] tensor whose (i, j) entry is cos(X[i], Y[j]).

    Rows are L2-normalised with torch.nn.functional.normalize, which divides by
    max(norm, eps), so an all-zero row yields similarity 0 instead of NaN.
    """
    normalize = torch.nn.functional.normalize
    x_unit = normalize(X, p=2, dim=1)
    y_unit = normalize(Y, p=2, dim=1)
    # Dot products of unit vectors are the cosine similarities.
    return x_unit @ y_unit.T

def create_range_masks(decoder_sim_matrix, similarity_ranges):
    """
    Build one boolean mask per similarity range over a square similarity matrix.

    Args:
        decoder_sim_matrix: square [n, n] tensor of pairwise similarities
            (fill_diagonal_ is applied to each mask, so it must be 2-D).
        similarity_ranges: iterable of (range_min, range_max) pairs. Each range
            is half-open [range_min, range_max), except a range whose upper
            bound is >= 1.0. That range has no upper bound, so cosine
            similarities that rounding pushes slightly above 1.0 are kept
            instead of silently falling outside every range.

    Returns:
        dict mapping a "min-max" label to an [n, n] bool tensor that is True
        where the similarity lies in the range. The diagonal (self-pairs) is
        always False. Labels use one decimal place (e.g. "0.1-0.2") unless a
        bound needs more digits to stay distinct (e.g. "0.1-0.15"). This
        prevents two different ranges from overwriting each other's mask.
    """
    def _label(bound):
        # Keep the historical one-decimal label whenever it is exact; fall back
        # to %g only when one decimal would lose information (0.15 -> "0.1").
        short = f"{bound:.1f}"
        return short if float(short) == bound else f"{bound:g}"

    range_masks = {}
    for range_min, range_max in similarity_ranges:
        range_mask = decoder_sim_matrix >= range_min
        if range_max < 1.0:
            range_mask = range_mask & (decoder_sim_matrix < range_max)

        # Exclude diagonal elements (a feature never interferes with itself)
        range_mask.fill_diagonal_(False)
        range_masks[f"{_label(range_min)}-{_label(range_max)}"] = range_mask

    return range_masks

def extract_interference_features_on_cpu(combined_masks_cpu, qualified_feature_positions_cpu, 
                                        range_keys, valid_indices, feature_explanations):
    """
    Translate mask positions of qualifying features into feature IDs (CPU side).

    Args:
        combined_masks_cpu: array indexed as [range_idx, row, col]; a truthy
            entry means feature `col` interferes with feature `row` in that range.
        qualified_feature_positions_cpu: row positions to extract.
        range_keys: range labels, in the same order as the first mask axis.
        valid_indices: maps a matrix position to its feature ID.
        feature_explanations: maps a feature ID to its explanation.

    Returns:
        {feature_id: {'explanation': ..., 'interference_features': {range_key: [feature_id, ...]}}}
        which is empty when there are no qualifying positions.
    """
    if len(qualified_feature_positions_cpu) == 0:
        return {}

    print(f"    Moving data to CPU for feature ID extraction...")

    extracted = {}
    for row in tqdm(qualified_feature_positions_cpu, desc="Extracting interference features"):
        owner_id = valid_indices[row]
        # For every range, collect the IDs of the columns flagged in this row.
        partners_by_range = {
            key: [valid_indices[col] for col in np.nonzero(combined_masks_cpu[k, row])[0]]
            for k, key in enumerate(range_keys)
        }
        extracted[owner_id] = {
            'explanation': feature_explanations[owner_id],
            'interference_features': partners_by_range
        }

    return extracted

def _extract_qualified_rows(qualified_rows_cpu, qualified_positions_cpu, range_keys,
                            valid_indices, feature_explanations):
    """
    Map the combined-mask rows of qualifying features to feature IDs.

    qualified_rows_cpu holds only the k qualifying rows, indexed as
    [range_idx, i, col]. qualified_positions_cpu[i] is the position (into
    valid_indices) of the feature whose row is qualified_rows_cpu[:, i].
    Columns still index the full valid_indices list.
    """
    qualified_features = {}
    for row, pos in enumerate(tqdm(qualified_positions_cpu, desc="Extracting interference features")):
        feature_idx = valid_indices[pos]
        qualified_features[feature_idx] = {
            'explanation': feature_explanations[feature_idx],
            'interference_features': {
                range_key: [valid_indices[p] for p in np.nonzero(qualified_rows_cpu[range_idx, row])[0]]
                for range_idx, range_key in enumerate(range_keys)
            }
        }
    return qualified_features


def fully_vectorized_interference_analysis(decoder_sim_matrix, explanation_sim_matrix, 
                                         semantic_thresholds, similarity_ranges, 
                                         valid_indices, feature_explanations, device):
    """
    Fully vectorized interference feature analysis (GPU filtering + CPU extraction).

    For each semantic threshold, feature i qualifies when, for EVERY range in
    `similarity_ranges`, there is some j != i that meets both conditions:
      - its decoder similarity to i falls in that range, and
      - its explanation similarity to i is below the threshold.

    Args:
        decoder_sim_matrix: square [n, n] decoder-similarity tensor.
        explanation_sim_matrix: [n, n] explanation-similarity tensor on the same
            device as `decoder_sim_matrix` (the masks are combined element-wise).
        semantic_thresholds: iterable of explanation-similarity cut-offs.
        similarity_ranges: non-empty list of (min, max) decoder-similarity ranges.
        valid_indices: maps a matrix position to its feature ID.
        feature_explanations: maps a feature ID to its explanation.
        device: unused; kept for signature compatibility (the input tensors'
            own device is used).

    Returns:
        {threshold: {feature_id: {'explanation': ...,
                                  'interference_features': {range_key: [feature_id, ...]}}}}

    Raises:
        ValueError: if `similarity_ranges` is empty.
    """
    print("    Starting GPU vectorized filtering...")
    if not similarity_ranges:
        raise ValueError("similarity_ranges must contain at least one (min, max) range")

    # 1-2. One bool mask per range, stacked into [n_ranges, n_features, n_features]
    range_masks = create_range_masks(decoder_sim_matrix, similarity_ranges)
    range_keys = list(range_masks.keys())
    range_masks_tensor = torch.stack([range_masks[key] for key in range_keys])

    # The dict still references the per-range masks; drop it before the big loop
    del range_masks
    torch.cuda.empty_cache()

    print(f"    Range mask tensor shape: {range_masks_tensor.shape}")

    results_by_threshold = {}
    for threshold in semantic_thresholds:
        print(f"    Processing threshold {threshold} on GPU...")

        # Semantically dissimilar pairs (self-pairs excluded)
        semantic_mask = explanation_sim_matrix < threshold
        semantic_mask.fill_diagonal_(False)

        # [n_ranges, n, n]: in the decoder range AND semantically dissimilar
        combined_masks = range_masks_tensor & semantic_mask.unsqueeze(0)
        del semantic_mask

        # A feature qualifies if its row has at least one hit in every range
        has_interference_all_ranges = combined_masks.any(dim=2).all(dim=0)  # [n_features]
        qualified_feature_positions = torch.nonzero(has_interference_all_ranges, as_tuple=True)[0]
        del has_interference_all_ranges

        print(f"    GPU filtering complete, found {len(qualified_feature_positions)} qualifying features")

        if len(qualified_feature_positions) > 0:
            print("    Moving data to CPU for feature ID extraction...")
            # Copy only the qualifying rows ([n_ranges, k, n]) instead of the whole
            # [n_ranges, n, n] mask. Host memory then scales with the number of
            # hits rather than n^2.
            qualified_rows_cpu = combined_masks[:, qualified_feature_positions].cpu().numpy()
            qualified_features = _extract_qualified_rows(
                qualified_rows_cpu, qualified_feature_positions.cpu().numpy(),
                range_keys, valid_indices, feature_explanations
            )
            del qualified_rows_cpu
        else:
            qualified_features = {}

        results_by_threshold[threshold] = qualified_features
        print(f"    Threshold {threshold}: final confirmed {len(qualified_features)} features")

        # Clean up temporary data for current threshold
        del combined_masks, qualified_feature_positions
        torch.cuda.empty_cache()

    del range_masks_tensor
    torch.cuda.empty_cache()

    return results_by_threshold

def save_single_file(model_name, layer_type, layer_idx, threshold, features, output_dir="./interference_results"):
    """
    Save one threshold's results as `{model}_{layer_type}_{layer_idx}_{threshold:.2f}.json`.

    Args:
        features: {feature_id: {'explanation': ..., 'interference_features': ...}}.
            Only those two fields are written, and feature IDs become string keys.
        output_dir: directory to write into; created if missing.

    Returns:
        The bare filename that was written, or None (nothing written, no
        directory created) when `features` is empty.
    """
    if len(features) == 0:
        return None

    os.makedirs(output_dir, exist_ok=True)

    # JSON object keys must be strings; keep only the serialisable fields.
    payload = {
        str(fid): {
            'explanation': entry['explanation'],
            'interference_features': entry['interference_features']
        }
        for fid, entry in features.items()
    }

    filename = f"{model_name}_{layer_type}_{layer_idx}_{threshold:.2f}.json"
    with open(os.path.join(output_dir, filename), 'w', encoding='utf-8') as fh:
        json.dump(payload, fh, indent=2, ensure_ascii=False)

    print(f"    Saved file: {filename} ({len(features)} features)")
    return filename

def _process_layer_type(model_name, layer_type, layer_idx, similarity_ranges,
                        semantic_thresholds, output_dir, device, saved_files):
    """
    Run the interference analysis for one (layer_idx, layer_type) and save results.

    Each written filename is appended to `saved_files` as soon as it is written,
    so files saved before a later failure are still reported by the caller.
    Every tensor built here is a local of this frame. They become unreachable
    once this call returns, or once the caller has finished handling an
    exception raised from here.
    """
    # 1. Get all necessary data
    print("    Loading data...")
    active_features = get_sae_features_by_layer(model_name, layer_type, layer_idx, active_only=True)
    if not active_features:
        print("    No active features found")
        return

    decoder_weights = get_sae_decoder_weights_from_local(model_name, layer_type, layer_idx)
    explanation_embeddings_dict = get_feature_explanation_embeddings_by_layer(model_name, layer_type, layer_idx)
    feature_explanations = get_feature_explanations_by_layer(model_name, layer_type, layer_idx)

    # 2. Keep only features that have both an embedding and an explanation
    valid_indices = [
        idx for idx in active_features.keys()
        if idx in explanation_embeddings_dict and idx in feature_explanations
    ]
    if not valid_indices:
        print("    No features with complete information found")
        return

    print(f"    Valid features count: {len(valid_indices)}")

    # 3. Build matrices (on `device`, float32)
    print("    Building weight matrices (float32)...")
    valid_decoder_weights = decoder_weights[valid_indices].to(device).float()
    explanation_embeddings = torch.stack(
        [explanation_embeddings_dict[idx] for idx in valid_indices]
    ).to(device).float()
    del decoder_weights
    torch.cuda.empty_cache()

    # 4. Calculate similarity matrices
    print("    Computing similarity matrices (float32)...")
    decoder_sim_matrix = cosine_similarity_matrix(valid_decoder_weights, valid_decoder_weights)
    explanation_sim_matrix = cosine_similarity_matrix(explanation_embeddings, explanation_embeddings)
    del valid_decoder_weights, explanation_embeddings
    torch.cuda.empty_cache()

    print(f"    Matrix size: {decoder_sim_matrix.shape}, data type: {decoder_sim_matrix.dtype}")

    # 5. GPU filtering + CPU extraction
    threshold_results = fully_vectorized_interference_analysis(
        decoder_sim_matrix, explanation_sim_matrix, semantic_thresholds,
        similarity_ranges, valid_indices, feature_explanations, device
    )
    del decoder_sim_matrix, explanation_sim_matrix
    torch.cuda.empty_cache()

    # 6. Save one file per threshold. Iterating the results (not
    # semantic_thresholds) visits each distinct threshold once, so duplicate
    # thresholds in the input no longer cause a KeyError.
    for threshold, features in threshold_results.items():
        if len(features) > 0:
            filename = save_single_file(model_name, layer_type, layer_idx, threshold, features, output_dir)
            if filename:
                saved_files.append(filename)


def _collect_file_stats(saved_files, output_dir):
    """
    Re-read each saved file and group (filename, feature_count) by threshold.

    Returns:
        (file_stats, total_features), where file_stats maps the threshold string
        taken from the filename to a list of (filename, feature_count).
        Unreadable or malformed files are reported and skipped.
    """
    file_stats = {}
    total_features = 0

    for filename in saved_files:
        filepath = os.path.join(output_dir, filename)
        try:
            with open(filepath, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (OSError, ValueError) as e:  # ValueError covers JSON/Unicode decode errors
            print(f"Error reading file {filename}: {e}")
            continue

        feature_count = len(data)
        total_features += feature_count

        # Filenames are "<model>_<layer_type>_<layer_idx>_<threshold>.json"; the
        # threshold is the last "_" field even when layer_type contains "_".
        threshold = filename.replace('.json', '').split('_')[-1]
        file_stats.setdefault(threshold, []).append((filename, feature_count))

    return file_stats, total_features


def analyze_model_interference(model_name, layer_types, num_layers, 
                              similarity_ranges=None, semantic_thresholds=None,
                              output_dir="./interference_results"):
    """
    Analyze interference features for a given model
    
    Args:
        model_name: 'pythia' or 'gpt2'
        layer_types: list of layer types to analyze
        num_layers: number of layers in the model (layers 0..num_layers-1 are processed)
        similarity_ranges: list of (min, max) tuples for decoder-similarity ranges
        semantic_thresholds: list of semantic similarity thresholds
        output_dir: directory to save results

    Returns:
        (saved_files, file_stats): the filenames written, and a dict mapping each
        threshold string (from the filename) to a list of (filename, feature_count).

    A failure in one (layer, layer_type) is printed with its traceback and the
    analysis continues with the next one.
    """
    import traceback

    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Default parameters if not provided
    if similarity_ranges is None:
        similarity_ranges = [
            (0.0, 0.1),
            (0.1, 0.2),
            (0.2, 0.3),
            (0.3, 0.4),
            (0.4, 1.0)
        ]

    if semantic_thresholds is None:
        semantic_thresholds = [0.4, 0.3, 0.2, 0.15, 0.1]

    print(f"Starting interference feature analysis for {model_name.upper()} model")
    print("="*80)
    print(f"Model: {model_name}")
    print(f"Layer types: {layer_types}")
    print(f"Number of layers: {num_layers} (0-{num_layers-1})")
    print(f"Similarity ranges: {similarity_ranges}")
    print(f"Semantic thresholds: {semantic_thresholds}")
    print(f"Output directory: {output_dir}")

    saved_files = []

    for layer_idx in range(num_layers):
        print(f"\nProcessing layer {layer_idx}...")

        for layer_type in layer_types:
            print(f"  Processing {layer_type} type...")

            # Clear GPU memory
            torch.cuda.empty_cache()
            gc.collect()

            try:
                _process_layer_type(
                    model_name, layer_type, layer_idx, similarity_ranges,
                    semantic_thresholds, output_dir, device, saved_files
                )
            except Exception as e:
                # Top-level best-effort boundary: report and move on.
                print(f"    Error processing layer {layer_idx} {layer_type}: {e}")
                traceback.print_exc()

            # Runs after the helper frame (and any exception referencing it) is
            # gone, so its tensors can actually be released.
            torch.cuda.empty_cache()
            gc.collect()

        # Memory cleanup after each layer
        print(f"\n  Layer {layer_idx} processing complete, cleaning memory...")
        torch.cuda.empty_cache()
        gc.collect()

    # Display final results
    print("\n" + "="*80)
    print("Analysis complete! File statistics:")
    print("="*80)

    file_stats, total_features = _collect_file_stats(saved_files, output_dir)

    # Display statistics by threshold
    for threshold in sorted(file_stats.keys()):
        print(f"\nThreshold {threshold}:")
        threshold_total = 0
        for filename, count in file_stats[threshold]:
            print(f"  {filename}: {count} features")
            threshold_total += count
        print(f"  Subtotal: {threshold_total} features")

    print("\nTotal summary:")
    print(f"  Number of files: {len(saved_files)}")
    print(f"  Total features: {total_features}")

    return saved_files, file_stats

In [None]:
# Semantic-similarity cut-offs shared by both model sweeps.
common_thresholds = [0.4, 0.3, 0.2, 0.15]

# Pythia: 6 layers, three layer types per layer.
analyze_model_interference(
    'pythia',
    ['att', 'mlp', 'res'],
    6,
    semantic_thresholds=common_thresholds,
    output_dir="./interference_results/pythia",
)

# GPT-2: 12 layers, four layer types per layer. This stays the cell's last
# expression so the notebook still displays its return value.
analyze_model_interference(
    'gpt2',
    ['att', 'res_mid', 'mlp', 'res_post'],
    12,
    semantic_thresholds=common_thresholds,
    output_dir="./interference_results/gpt2",
)