In [5]:
class LLMLayerImportance:
    def __init__(self, model_name="gpt2", device=None):
        """
        Initialize the LLMLayerImportance class to analyze layer importance in LLMs.
        
        Parameters:
        -----------
        model_name : str
            The name of the pretrained model to load (e.g., "gpt2", "llama-7b", "mistral-7b")
        device : torch.device or None
            Device to load the model on. If None, uses CUDA if available, otherwise CPU.
        """
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer
        
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
        self.model = AutoModelForCausalLM.from_pretrained(model_name).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        
        # Get model architecture information
        self.num_layers = self.model.config.num_hidden_layers
        self.num_attention_heads = self.model.config.num_attention_heads
        
        # Store reference to whether model uses multi-head attention
        if hasattr(self.model.config, "is_decoder") and self.model.config.is_decoder:
            self.has_cross_attention = True
        else:
            self.has_cross_attention = False
            
        # Initialize layer mask attributes
        self.layer_masks = None
        
    def get_dataloader(self, path, name=None, split="validation", batch_size=8, shuffle=False):
        """
        Create a dataloader for evaluating layer importance.
        
        Parameters:
        -----------
        path : str
            Dataset path or name from HuggingFace datasets
        name : str or None
            Specific dataset configuration name
        split : str
            Dataset split to use (e.g., "train", "validation", "test")
        batch_size : int
            Batch size for dataloader
        shuffle : bool
            Whether to shuffle the dataset
            
        Returns:
        --------
        torch.utils.data.DataLoader
            DataLoader containing preprocessed dataset samples
        """
        from datasets import load_dataset
        from torch.utils.data import DataLoader
        
        dataset = load_dataset(path, name, split=split)
        dataset = self._preprocess_dataset(path, dataset)
        dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
    
    def _preprocess_dataset(self, path, dataset):
        """
        Preprocess different datasets based on their type.
        
        Parameters:
        -----------
        path : str
            Dataset path or name
        dataset : datasets.Dataset
            Dataset object to preprocess
            
        Returns:
        --------
        datasets.Dataset
            Preprocessed dataset
        """
        if path == "wikitext":
            return self._preprocess_wikitext(dataset)
        elif path == "glue":
            return self._preprocess_glue(dataset)
        elif path == "openai/webtext":
            return self._preprocess_webtext(dataset)
        elif path == "imdb":
            return self._preprocess_imdb(dataset)
        else:
            raise ValueError(f"Preprocessing for dataset {path} is not implemented.")
            
    def _preprocess_wikitext(self, dataset):
        """Preprocess WikiText dataset for language modeling"""
        def preprocess(batch):
            inputs = self.tokenizer(batch["text"], padding="max_length", 
                                    truncation=True, max_length=512)
            inputs["labels"] = inputs["input_ids"].copy()
            return inputs
            
        return dataset.map(preprocess, batched=True, remove_columns=["text"])
    
    def _preprocess_glue(self, dataset):
        """Preprocess GLUE dataset for classification"""
        def preprocess(batch):
            if "sentence" in batch:
                texts = batch["sentence"]
            elif "sentence1" in batch and "sentence2" in batch:
                texts = [s1 + " [SEP] " + s2 for s1, s2 in zip(batch["sentence1"], batch["sentence2"])]
            else:
                raise ValueError("Unsupported GLUE task format")
                
            inputs = self.tokenizer(texts, padding="max_length", truncation=True, max_length=128)
            inputs["labels"] = batch["label"]
            return inputs
            
        cols_to_remove = [col for col in dataset.column_names if col not in ["label"]]
        return dataset.map(preprocess, batched=True, remove_columns=cols_to_remove)
    
    def _preprocess_webtext(self, dataset):
        """Preprocess WebText dataset for language modeling"""
        def preprocess(batch):
            inputs = self.tokenizer(batch["text"], padding="max_length", 
                                    truncation=True, max_length=512)
            inputs["labels"] = inputs["input_ids"].copy()
            return inputs
            
        return dataset.map(preprocess, batched=True, remove_columns=["text"])
    
    def _preprocess_imdb(self, dataset):
        """Preprocess IMDB dataset for sentiment classification"""
        def preprocess(batch):
            inputs = self.tokenizer(batch["text"], padding="max_length", 
                                   truncation=True, max_length=512)
            inputs["labels"] = batch["label"]
            return inputs
            
        return dataset.map(preprocess, batched=True, remove_columns=["text"])
    
    def compute_layer_importance(self, dataloader, layer_mask=None, num_batches=50):
        """
        Compute importance scores for layers in the LLM by measuring their influence on model predictions.
        
        Layer importance quantifies how much each layer contributes to the model's performance/loss.
        Higher importance values indicate layers that have greater impact on the model's predictions.
        
        Parameters:
        -----------
        dataloader : torch.utils.data.DataLoader
            DataLoader containing evaluation data
        layer_mask : torch.Tensor or None
            Optional initial mask for model layers of shape (num_layers,)
            Contains 0s and 1s, where 1 indicates the layer is active
        num_batches : int
            Number of batches to use for importance computation
            
        Returns:
        --------
        torch.Tensor
            Normalized importance scores for each layer of shape (num_layers,)
        """
        import torch
        import torch.nn as nn
        
        self.model.eval()
        
        # Initialize layer masks if not provided
        if layer_mask is None:
            layer_mask = torch.ones(self.num_layers, device=self.device, requires_grad=True)
        else:
            # Ensure mask has requires_grad=True for gradient computation
            layer_mask = layer_mask.clone().detach().requires_grad_(True)
            
        # Store the mask for potential later use
        self.layer_masks = layer_mask
        
        # Initialize importance scores
        layer_importance = torch.zeros(self.num_layers, device=self.device)
        
        # Register hooks for each transformer layer to apply masks
        hooks = []
        
        def get_layer_hook(layer_idx):
            def hook(module, input, output):
                # Apply layer mask through scaling
                mask_value = layer_mask[layer_idx]
                # Scale the output by the mask value
                return output * mask_value
            return hook
        
        # Attach hooks to transformer layers
        for i in range(self.num_layers):
            if hasattr(self.model, "transformer"):
                # GPT-2 style models
                layer = self.model.transformer.h[i]
            elif hasattr(self.model, "model"):
                # Some models have nested structure
                if hasattr(self.model.model, "layers"):
                    layer = self.model.model.layers[i]
                else:
                    layer = self.model.model.decoder.layers[i]
            elif hasattr(self.model, "decoder"):
                # Decoder-only models
                layer = self.model.decoder.layers[i]
            elif hasattr(self.model, "layers"):
                # Direct layers attribute
                layer = self.model.layers[i]
            else:
                raise ValueError(f"Unsupported model architecture: {type(self.model)}")
                
            hook = layer.register_forward_hook(get_layer_hook(i))
            hooks.append(hook)
        
        # Process batches
        batch_count = 0
        for batch in dataloader:
            if batch_count >= num_batches:
                break
                
            input_ids = batch["input_ids"].to(self.device)
            attention_mask = batch["attention_mask"].to(self.device) if "attention_mask" in batch else None
            labels = batch["labels"].to(self.device)
            
            # Reset gradients
            self.model.zero_grad()
            
            # Forward pass with layer masking applied through hooks
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )
            
            loss = outputs.loss
            loss.backward()
            
            # Accumulate gradients (absolute values) for importance scores
            layer_importance += torch.abs(layer_mask.grad.detach())
            
            # Reset the gradient for the next iteration
            layer_mask.grad.zero_()
            
            batch_count += 1
        
        # Remove hooks
        for hook in hooks:
            hook.remove()
        
        # Normalize importance scores to [0, 1]
        max_importance = layer_importance.max()
        if max_importance > 0:
            layer_importance = layer_importance / max_importance
        
        return layer_importance
    
    def visualize_layer_importance(self, layer_importance, save_path="layer_importance.png"):
        """
        Visualize layer importance scores as a bar chart.
        
        Parameters:
        -----------
        layer_importance : torch.Tensor
            Tensor of layer importance scores from compute_layer_importance
        save_path : str
            Path to save the visualization
        """
        import matplotlib.pyplot as plt
        import seaborn as sns
        import numpy as np
        
        plt.figure(figsize=(12, 6))
        
        # Convert to numpy for plotting
        importance_np = layer_importance.cpu().numpy()
        
        # Create bar plot with layer indices
        sns.barplot(x=np.arange(self.num_layers), y=importance_np)
        
        plt.title(f"Layer Importance Scores for {self.model.config.model_type}")
        plt.xlabel("Layer Index")
        plt.ylabel("Importance Score")
        plt.xticks(np.arange(0, self.num_layers, max(1, self.num_layers // 10)))
        plt.grid(axis='y', linestyle='--', alpha=0.7)
        
        # Add importance values as text above bars
        for i, val in enumerate(importance_np):
            if i % max(1, self.num_layers // 10) == 0:  # Only show some values to avoid clutter
                plt.text(i, val + 0.02, f"{val:.2f}", ha='center', va='bottom', rotation=45)
        
        plt.tight_layout()
        plt.savefig(save_path)
        plt.close()
        
    def prune_layers(self, layer_importance, pruning_ratio=0.3):
        """
        Create a pruning mask based on layer importance scores.
        
        Parameters:
        -----------
        layer_importance : torch.Tensor
            Tensor of layer importance scores
        pruning_ratio : float
            Ratio of layers to prune (0.0 to 1.0)
            
        Returns:
        --------
        torch.Tensor
            Binary mask indicating which layers to keep (1) and which to prune (0)
        """
        import torch
        
        # Determine number of layers to prune
        num_to_prune = int(self.num_layers * pruning_ratio)
        
        # Create full mask (all 1s initially)
        pruning_mask = torch.ones_like(layer_importance)
        
        if num_to_prune > 0:
            # Get indices of least important layers
            _, indices = torch.topk(layer_importance, self.num_layers - num_to_prune, largest=True)
            
            # Create mask (0 for pruned layers, 1 for kept layers)
            pruning_mask = torch.zeros_like(layer_importance)
            pruning_mask[indices] = 1.0
            
        return pruning_mask
        
    def evaluate_with_mask(self, dataloader, layer_mask, num_eval_batches=50):
        """
        Evaluate model performance with a given layer mask.
        
        Parameters:
        -----------
        dataloader : torch.utils.data.DataLoader
            DataLoader containing evaluation data
        layer_mask : torch.Tensor
            Binary mask indicating which layers to use
        num_eval_batches : int
            Number of batches to use for evaluation
            
        Returns:
        --------
        float
            Average loss on the evaluation data
        """
        import torch
        
        self.model.eval()
        total_loss = 0.0
        batch_count = 0
        
        # Register hooks for each transformer layer to apply masks
        hooks = []
        
        def get_layer_hook(layer_idx):
            def hook(module, input, output):
                # Apply layer mask - if mask is 0, effectively disables the layer
                mask_value = layer_mask[layer_idx]
                if mask_value == 0:
                    # For pruned layer, pass through the input directly
                    return input[0]  # Most transformers return a tuple where first element is the layer input
                else:
                    return output
            return hook
        
        # Attach hooks to transformer layers
        for i in range(self.num_layers):
            if hasattr(self.model, "transformer"):
                layer = self.model.transformer.h[i]
            elif hasattr(self.model, "model"):
                if hasattr(self.model.model, "layers"):
                    layer = self.model.model.layers[i]
                else:
                    layer = self.model.model.decoder.layers[i]
            elif hasattr(self.model, "decoder"):
                layer = self.model.decoder.layers[i]
            elif hasattr(self.model, "layers"):
                layer = self.model.layers[i]
            else:
                raise ValueError(f"Unsupported model architecture: {type(self.model)}")
                
            hook = layer.register_forward_hook(get_layer_hook(i))
            hooks.append(hook)
        
        with torch.no_grad():
            for batch in dataloader:
                if batch_count >= num_eval_batches:
                    break
                    
                input_ids = batch["input_ids"].to(self.device)
                attention_mask = batch["attention_mask"].to(self.device) if "attention_mask" in batch else None
                labels = batch["labels"].to(self.device)
                
                outputs = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels
                )
                
                total_loss += outputs.loss.item()
                batch_count += 1
        
        # Remove hooks
        for hook in hooks:
            hook.remove()
        
        return total_loss / batch_count if batch_count > 0 else float('inf')
    
    def find_optimal_pruning(self, dataloader, importance_scores, pruning_ratios=None, tolerance=0.1):
        """
        Find the optimal pruning ratio by testing different ratios.
        
        Parameters:
        -----------
        dataloader : torch.utils.data.DataLoader
            DataLoader containing evaluation data
        importance_scores : torch.Tensor
            Layer importance scores
        pruning_ratios : list or None
            List of pruning ratios to try. If None, uses default values
        tolerance : float
            Maximum acceptable performance degradation ratio
            
        Returns:
        --------
        tuple
            (optimal_pruning_ratio, optimal_mask, performance_results)
        """
        import torch
        
        if pruning_ratios is None:
            pruning_ratios = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
        
        # First evaluate the model with no pruning to get baseline performance
        baseline_mask = torch.ones(self.num_layers, device=self.device)
        baseline_loss = self.evaluate_with_mask(dataloader, baseline_mask)
        
        results = []
        optimal_ratio = 0.0
        optimal_mask = baseline_mask
        
        for ratio in pruning_ratios:
            # Generate mask based on importance scores
            mask = self.prune_layers(importance_scores, ratio)
            
            # Evaluate with this mask
            loss = self.evaluate_with_mask(dataloader, mask)
            
            # Calculate relative performance degradation
            relative_degradation = (loss - baseline_loss) / baseline_loss
            
            results.append({
                'pruning_ratio': ratio,
                'loss': loss,
                'relative_degradation': relative_degradation,
                'mask': mask.clone()
            })
            
            # Update optimal if within tolerance and most aggressive pruning so far
            if relative_degradation <= tolerance and ratio > optimal_ratio:
                optimal_ratio = ratio
                optimal_mask = mask.clone()
        
        return optimal_ratio, optimal_mask, results
    
    def visualize_pruning_results(self, results, save_path="pruning_results.png"):
        """
        Visualize the impact of different pruning ratios on model performance.
        
        Parameters:
        -----------
        results : list
            List of dictionaries with pruning results from find_optimal_pruning
        save_path : str
            Path to save the visualization
        """
        import matplotlib.pyplot as plt
        
        ratios = [r['pruning_ratio'] for r in results]
        losses = [r['loss'] for r in results]
        degradations = [r['relative_degradation'] for r in results]
        
        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 8), sharex=True)
        
        # Plot loss
        ax1.plot(ratios, losses, 'o-', color='blue')
        ax1.set_ylabel('Loss')
        ax1.set_title('Effect of Layer Pruning on Model Performance')
        ax1.grid(True)
        
        # Plot relative degradation
        ax2.plot(ratios, degradations, 'o-', color='red')
        ax2.axhline(y=0.1, color='green', linestyle='--', label='10% Degradation Threshold')
        ax2.set_xlabel('Pruning Ratio')
        ax2.set_ylabel('Relative Performance Degradation')
        ax2.grid(True)
        ax2.legend()
        
        plt.tight_layout()
        plt.savefig(save_path)
        plt.close()
    
    def apply_permanent_pruning(self, layer_mask):
        """
        Apply permanent pruning to the model by modifying its architecture.
        
        WARNING: This permanently modifies the model architecture and cannot be undone.
        
        Parameters:
        -----------
        layer_mask : torch.Tensor
            Binary mask indicating which layers to keep (1) and which to prune (0)
            
        Returns:
        --------
        model
            The pruned model
        """
        import torch
        import copy
        
        # Create a copy of the model to avoid modifying the original
        pruned_model = copy.deepcopy(self.model)
        
        # Get indices of layers to keep
        keep_indices = torch.nonzero(layer_mask).squeeze().tolist()
        
        # Handle the case where only one layer is kept
        if isinstance(keep_indices, int):
            keep_indices = [keep_indices]
        
        # Modify the layer structure based on model type
        if hasattr(pruned_model, "transformer"):
            # GPT-2 style
            pruned_model.transformer.h = torch.nn.ModuleList(
                [pruned_model.transformer.h[i] for i in keep_indices]
            )
            pruned_model.config.num_hidden_layers = len(keep_indices)
            
        elif hasattr(pruned_model, "model"):
            if hasattr(pruned_model.model, "layers"):
                pruned_model.model.layers = torch.nn.ModuleList(
                    [pruned_model.model.layers[i] for i in keep_indices]
                )
            else:
                pruned_model.model.decoder.layers = torch.nn.ModuleList(
                    [pruned_model.model.decoder.layers[i] for i in keep_indices]
                )
            pruned_model.config.num_hidden_layers = len(keep_indices)
            
        elif hasattr(pruned_model, "decoder"):
            pruned_model.decoder.layers = torch.nn.ModuleList(
                [pruned_model.decoder.layers[i] for i in keep_indices]
            )
            pruned_model.config.num_hidden_layers = len(keep_indices)
            
        elif hasattr(pruned_model, "layers"):
            pruned_model.layers = torch.nn.ModuleList(
                [pruned_model.layers[i] for i in keep_indices]
            )
            pruned_model.config.num_hidden_layers = len(keep_indices)
            
        # Update model attributes
        self.model = pruned_model
        self.num_layers = len(keep_indices)
        
        return pruned_model

In [None]:
import torch
import matplotlib.pyplot as plt
from tqdm import tqdm

# Import the LLMLayerImportance class we defined earlier
# For this example, assume the class is in a file called llm_layer_importance.py
# from llm_layer_importance import LLMLayerImportance


In [7]:

def main():
    print("Starting GPT-2 layer importance analysis...")
    
    # Initialize the LLMLayerImportance with GPT-2 model
    # Using the small GPT-2 model for faster computation
    print("Loading GPT-2 model...")
    lli = LLMLayerImportance(model_name="gpt2", device=None)  # None will use CUDA if available
    
    # Print model details
    print(f"Model: {lli.model.config.model_type}")
    print(f"Number of layers: {lli.num_layers}")
    print(f"Number of attention heads per layer: {lli.num_attention_heads}")
    
    # Load and preprocess a small dataset for evaluation
    # Using WikiText for language modeling evaluation
    print("Loading WikiText dataset...")
    dataloader = lli.get_dataloader(
        path="wikitext",
        name="wikitext-2-raw-v1",
        split="test",
        batch_size=4,
        shuffle=False
    )
    
    # Compute layer importance
    print("Computing layer importance scores...")
    layer_importance = lli.compute_layer_importance(
        dataloader=dataloader,
        num_batches=10  # Using a small number of batches for demonstration
    )
    
    # Print the layer importance scores
    print("\nLayer Importance Scores:")
    for i, score in enumerate(layer_importance.cpu().numpy()):
        print(f"Layer {i}: {score:.4f}")
    
    # Visualize the importance scores
    print("\nVisualizing layer importance...")
    lli.visualize_layer_importance(layer_importance, save_path="gpt2_layer_importance.png")
    print("Visualization saved to gpt2_layer_importance.png")
    
    # Find optimal pruning ratio
    print("\nFinding optimal pruning configuration...")
    pruning_ratios = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5]
    optimal_ratio, optimal_mask, results = lli.find_optimal_pruning(
        dataloader=dataloader,
        importance_scores=layer_importance,
        pruning_ratios=pruning_ratios,
        tolerance=0.1  # Allow up to 10% performance degradation
    )
    
    # Visualize pruning results
    lli.visualize_pruning_results(results, save_path="gpt2_pruning_results.png")
    print("Pruning results visualization saved to gpt2_pruning_results.png")
    
    # Print optimal pruning information
    print(f"\nOptimal pruning ratio: {optimal_ratio:.2f}")
    print(f"Number of layers to keep: {int(lli.num_layers * (1 - optimal_ratio))}")
    print(f"Layers to keep: {torch.nonzero(optimal_mask).squeeze().tolist()}")
    
    # Evaluate the pruned model
    print("\nEvaluating model performance before and after pruning...")
    
    # Baseline performance (no pruning)
    baseline_mask = torch.ones(lli.num_layers, device=lli.device)
    baseline_loss = lli.evaluate_with_mask(dataloader, baseline_mask, num_eval_batches=10)
    print(f"Original model loss: {baseline_loss:.4f}")
    
    # Optimal pruning performance
    pruned_loss = lli.evaluate_with_mask(dataloader, optimal_mask, num_eval_batches=10)
    print(f"Pruned model loss: {pruned_loss:.4f}")
    print(f"Performance impact: {((pruned_loss - baseline_loss) / baseline_loss) * 100:.2f}%")
    
    # Apply permanent pruning (optional)
    if optimal_ratio > 0:
        print("\nDemonstrating permanent pruning...")
        original_num_layers = lli.num_layers
        pruned_model = lli.apply_permanent_pruning(optimal_mask)
        print(f"Model permanently pruned from {original_num_layers} to {lli.num_layers} layers")
        
        # Final evaluation of the permanently pruned model
        final_dataloader = lli.get_dataloader(
            path="wikitext",
            name="wikitext-2-raw-v1",
            split="test",
            batch_size=4,
            shuffle=False
        )
        
        final_loss = lli.evaluate_with_mask(
            final_dataloader, 
            torch.ones(lli.num_layers, device=lli.device),  # All layers active in pruned model
            num_eval_batches=10
        )
        print(f"Final pruned model loss: {final_loss:.4f}")
    
    print("\nAnalysis complete!")



In [8]:
if __name__ == "__main__":
    # Set random seed for reproducibility
    torch.manual_seed(42)
    
    try:
        main()
    except Exception as e:
        print(f"Error occurred: {e}")

Starting GPT-2 layer importance analysis...
Loading GPT-2 model...
Model: gpt2
Number of layers: 12
Number of attention heads per layer: 12
Loading WikiText dataset...
Computing layer importance scores...
Error occurred: index out of range in self
