# In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import torch.nn as nn
import os
import gc
from typing import Dict, List, Tuple, Optional
import re

def select_model():
    """
    Ask the user on stdin for a Hugging Face model identifier.
    Returns the text exactly as it was typed.
    """
    return input("Enter Hugging Face model name (e.g., 'gpt2', 'EleutherAI/gpt-neo-125M'): ")

def get_total_size(model):
    """
    Return the storage taken by the model's parameters, in units of 1024**3 bytes.
    Only tensors yielded by ``model.parameters()`` are counted.
    """
    bytes_per_gb = 1024 ** 3
    total_bytes = 0
    for param in model.parameters():
        # element count times bytes per element gives this tensor's footprint
        total_bytes += param.nelement() * param.element_size()
    return total_bytes / bytes_per_gb

# Substrings (matched case-insensitively) that mark a module path as part of an
# attention block. 'self_attn' is covered by 'attn'.
_ATTENTION_NAME_HINTS = ('attn', 'attention')
# Substrings (matched case-sensitively) that mark a projection as Q, K, V or fused QKV.
_QKV_NAME_HINTS = ('q_proj', 'k_proj', 'v_proj', 'qkv', 'c_attn', 'query', 'key', 'value')
# Substrings that mark a projection as a single fused [Q; K; V] matrix.
_FUSED_QKV_HINTS = ('qkv', 'c_attn')
# Layer index inside a dotted module path. Accepts singular and plural container
# names ("layer.3", "layers.3", "h.3", "block.3", "blocks.3"); the look-behind
# stops "h." from matching at the end of an unrelated word.
_LAYER_INDEX_RE = re.compile(r'(?<![A-Za-z0-9])(?:layers?|h|blocks?)\.(\d+)')


def _is_hf_conv1d(module: nn.Module) -> bool:
    """
    Return True for Hugging Face ``Conv1D``-style modules (e.g. GPT-2's c_attn).
    Detected by class name so that transformers does not have to be imported here.
    Such modules store their weight as [in_features, out_features].
    """
    weight = getattr(module, 'weight', None)
    return type(module).__name__ == 'Conv1D' and weight is not None and weight.dim() == 2


def _out_major_weight(module: nn.Module) -> torch.Tensor:
    """
    Return the module weight laid out as [out_features, in_features].
    The result shares storage with the parameter, so in-place writes prune the model.
    """
    weight = module.weight.data
    return weight.t() if _is_hf_conv1d(module) else weight


def _layer_index(name: str) -> Optional[int]:
    """Return the transformer layer index encoded in a module path, or None."""
    match = _LAYER_INDEX_RE.search(name)
    return int(match.group(1)) if match else None


def _projection_sections(name: str, out_features: int, hidden_size: int) -> Optional[int]:
    """
    Return how many hidden_size-wide sections the projection output holds:
    3 for a fused QKV matrix, 1 for a separate Q/K/V matrix, None when the
    output width matches neither layout (such projections are skipped).
    """
    if any(hint in name for hint in _FUSED_QKV_HINTS):
        return 3 if out_features == 3 * hidden_size else None
    return 1 if out_features == hidden_size else None


def _head_row_slices(head_idx: int, head_dim: int, hidden_size: int, sections: int) -> List[slice]:
    """Return the output-row slice owned by ``head_idx`` in each section."""
    first = head_idx * head_dim
    return [
        slice(section * hidden_size + first, section * hidden_size + first + head_dim)
        for section in range(sections)
    ]


def find_attention_layers(model) -> List[Tuple[str, nn.Module]]:
    """
    Find the query/key/value projection modules of the model by name.

    A module is returned when it is an ``nn.Linear`` (or a Hugging Face
    ``Conv1D``), its dotted path contains an attention marker, and the path
    contains one of the Q/K/V projection markers.

    Returns list of (layer_name, layer_module) tuples in ``named_modules()`` order.
    """
    attention_layers = []

    for name, module in model.named_modules():
        if not (isinstance(module, nn.Linear) or _is_hf_conv1d(module)):
            continue
        lowered = name.lower()
        if not any(hint in lowered for hint in _ATTENTION_NAME_HINTS):
            continue
        if any(hint in name for hint in _QKV_NAME_HINTS):
            attention_layers.append((name, module))

    return attention_layers


def get_model_architecture_info(model) -> Dict:
    """
    Extract architecture-specific information from ``model.config``.

    Reads the head count, hidden size and layer count under either naming
    scheme (``num_attention_heads``/``n_head``, ``hidden_size``/``n_embd``,
    ``num_hidden_layers``/``n_layer``), falling back to 12 / 768 / 12 when
    neither attribute exists. ``head_dim`` is hidden_size // num_attention_heads.
    """
    config = model.config

    def _first_available(primary: str, legacy: str, default):
        # Same lookup order as a nested getattr: primary, then legacy, then default.
        return getattr(config, primary, getattr(config, legacy, default))

    arch_info = {
        'num_attention_heads': _first_available('num_attention_heads', 'n_head', 12),
        'hidden_size': _first_available('hidden_size', 'n_embd', 768),
        'num_layers': _first_available('num_hidden_layers', 'n_layer', 12),
        'model_type': getattr(config, 'model_type', 'unknown')
    }

    arch_info['head_dim'] = arch_info['hidden_size'] // arch_info['num_attention_heads']
    return arch_info


def compute_head_importance_universal(model, arch_info: Dict) -> Dict[int, torch.Tensor]:
    """
    Compute per-head importance for each layer as the sum of the L2 norms of
    that head's rows in every supported Q/K/V projection of the layer.

    Args:
        model: module tree searched with ``find_attention_layers``.
        arch_info: dict with 'num_attention_heads', 'head_dim' and 'hidden_size'.

    Returns:
        {layer_index: float32 CPU tensor of length num_attention_heads}. Layers
        whose projections were all skipped (or scored exactly zero) are omitted.

    Raises:
        NotImplementedError: if no attention projection is found at all.
    """
    attention_layers = find_attention_layers(model)

    if not attention_layers:
        raise NotImplementedError("No attention layers found in this model")

    num_heads = arch_info['num_attention_heads']
    head_dim = arch_info['head_dim']
    hidden_size = arch_info['hidden_size']

    # Group projections by the layer index found in their module path.
    layer_groups: Dict[int, List[Tuple[str, nn.Module]]] = {}
    for name, module in attention_layers:
        layer_idx = _layer_index(name)
        if layer_idx is not None:
            layer_groups.setdefault(layer_idx, []).append((name, module))

    head_importance = {}
    for layer_idx, layer_modules in layer_groups.items():
        layer_importance = torch.zeros(num_heads)

        for name, module in layer_modules:
            weight = _out_major_weight(module)
            sections = _projection_sections(name, weight.shape[0], hidden_size)
            if sections is None:
                # Output width is neither hidden_size nor 3 * hidden_size.
                continue

            # Upcast so the reduction is done in float32 even for fp16 weights,
            # then view as [sections, num_heads, head_dim, in_features].
            per_head = weight.float().reshape(sections, num_heads, head_dim, -1)
            # Explicit sum of squares: Tensor.norm's default Frobenius mode does
            # not accept a reduction over three axes.
            head_norms = per_head.pow(2).sum(dim=(0, 2, 3)).sqrt()
            layer_importance += head_norms.cpu()

        if layer_importance.sum() > 0:
            head_importance[layer_idx] = layer_importance

    return head_importance


def prune_attention_weights_universal(model, head_importance: Dict[int, torch.Tensor],
                                    prune_percent: float, arch_info: Dict) -> Tuple[int, int]:
    """
    Universal attention pruning by zeroing out least important heads.

    For every layer in ``head_importance`` the ``prune_percent`` fraction of
    lowest-scoring heads is selected (at least one, always leaving at least one
    head), and the matching rows of each supported Q/K/V projection weight and
    bias are set to zero in place. Tensor shapes are not changed.

    Returns:
        (total number of heads selected across layers, heads per layer from arch_info)
    """
    original_heads = arch_info['num_attention_heads']
    head_dim = arch_info['head_dim']
    hidden_size = arch_info['hidden_size']

    # Determine which heads to prune
    heads_to_prune = {}
    total_heads_pruned = 0
    for layer_idx, importance in head_importance.items():
        num_heads = len(importance)
        n_prune = max(1, int(prune_percent * num_heads))
        n_prune = min(n_prune, num_heads - 1)  # Leave at least 1 head

        prune_idxs = torch.topk(importance, k=n_prune, largest=False).indices.tolist()
        heads_to_prune[layer_idx] = prune_idxs
        total_heads_pruned += n_prune
        print(f"Layer {layer_idx}: pruning {n_prune}/{num_heads} heads (indices: {prune_idxs})")

    # Apply pruning by zeroing weights
    with torch.no_grad():
        for name, module in find_attention_layers(model):
            prune_idxs = heads_to_prune.get(_layer_index(name))
            if not prune_idxs:
                continue

            weight = _out_major_weight(module)
            sections = _projection_sections(name, weight.shape[0], hidden_size)
            if sections is None:
                continue

            bias = getattr(module, 'bias', None)
            if bias is not None and bias.shape[0] != weight.shape[0]:
                bias = None  # unexpected bias length: leave it untouched

            for head_idx in prune_idxs:
                for rows in _head_row_slices(head_idx, head_dim, hidden_size, sections):
                    weight[rows, :] = 0
                    if bias is not None:
                        bias.data[rows] = 0

    return total_heads_pruned, original_heads

def prune_model_universal(model, prune_percent: float = 0.1):
    """
    Run the full pruning pipeline on ``model``: read its architecture, score
    the attention heads, then prune the lowest-scoring ones in every layer.

    Progress is reported with ``print``. Returns a tuple of
    (total_heads_pruned, original_heads, arch_info).
    """
    print("Analyzing model architecture...")
    arch_info = get_model_architecture_info(model)
    summary_lines = (
        f"Model type: {arch_info['model_type']}",
        f"Number of layers: {arch_info['num_layers']}",
        f"Number of attention heads: {arch_info['num_attention_heads']}",
        f"Head dimension: {arch_info['head_dim']}",
    )
    for line in summary_lines:
        print(line)

    print("\nComputing head importance...")
    importance_by_layer = compute_head_importance_universal(model, arch_info)

    print(f"\nPruning {prune_percent*100:.0f}% of heads per layer...")
    pruned_count, heads_per_layer = prune_attention_weights_universal(
        model, importance_by_layer, prune_percent, arch_info
    )

    return pruned_count, heads_per_layer, arch_info

# ---- Main execution ----
# Interactive driver: ask for a model name, load it in float16, prune attention
# heads in place, report statistics, and save the result next to the script.
try:
    model_name = select_model()
    print(f"Loading model {model_name}...")
    
    # Load tokenizer and model with error handling
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
    except Exception as e:
        # Chain explicitly so the printed traceback names the original failure as the cause.
        raise RuntimeError(f"Failed to load model/tokenizer: {str(e)}") from e
    
    # Device management
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    print(f"Model loaded on {device}")
    
    # Get original size
    original_size = get_total_size(model)
    print(f"\nOriginal size: {original_size:.2f} GB")
    
    # Pruning parameters
    PRUNE_PERCENT = 0.20
    
    # Apply universal pruning
    total_heads_pruned, original_heads, arch_info = prune_model_universal(model, PRUNE_PERCENT)
    
    # Calculate final size and statistics
    # NOTE: pruning zeroes weight values in place and get_total_size counts
    # parameter storage, so the reported reduction is expected to be zero.
    pruned_size = get_total_size(model)
    print(f"\nPruned size: {pruned_size:.2f} GB")
    print(f"Reduction: {original_size - pruned_size:.2f} GB ({(1 - pruned_size/original_size)*100:.1f}%)")
    print(f"Total heads pruned: {total_heads_pruned}")
    print(f"Pruning percentage: {PRUNE_PERCENT*100:.0f}% per layer")
    print(f"Original heads per layer: {original_heads}")
    
    # Clean up memory
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    
    # Save the pruned model
    # Path separators in the model name are replaced so it forms a single directory name.
    safe_model_name = model_name.replace("/", "_").replace("\\", "_")
    output_dir = f"./pruned_model_{safe_model_name}"
    
    try:
        os.makedirs(output_dir, exist_ok=True)
        model.save_pretrained(output_dir, safe_serialization=True)
        tokenizer.save_pretrained(output_dir)
        print(f"\nPruned model saved to: {output_dir}")
    except Exception as e:
        raise RuntimeError(f"Failed to save model: {str(e)}") from e
    
except KeyboardInterrupt:
    print("\nOperation cancelled by user")
except Exception as e:
    # Top-level boundary: report the error and its traceback instead of crashing the session.
    print(f"Error: {str(e)}")
    import traceback
    traceback.print_exc()