In [1]:
# !pip install -q -U bitsandbytes
# !pip install -q -U git+https://github.com/huggingface/transformers.git
# !pip install -q -U git+https://github.com/huggingface/peft.git
# !pip install -q -U git+https://github.com/huggingface/accelerate.git
# !pip install datasets

In [2]:
import time
import torch
import numpy as np
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

import torch
import time
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from datasets import load_dataset

import torch
import numpy as np
import time
import os
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from datasets import load_dataset




  from .autonotebook import tqdm as notebook_tqdm


### Studying Quantization Effect

In [3]:
import os
import time
import torch
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

class QuantizedModelPipeline:
    def __init__(self, model_name, quantization_config=None, max_tokens=64):
        self.model_name = model_name
        self.quantization_config = quantization_config
        self.max_tokens = max_tokens
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = None
        self.tokenizer = None
        self.model_size_mb = None  # Track model size

    def load_quantized_model(self):
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=True)

        if self.quantization_config is None:
            # Full precision baseline
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                device_map="auto",
                torch_dtype=torch.float32
            )
            print(f"Loaded full-precision baseline model: {self.model_name}")
        else:
            # Quantized model
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                quantization_config=self.quantization_config,
                device_map="auto",
                torch_dtype=torch.bfloat16
            )
            print(f"Loaded quantized model: {self.model_name}")

        self.model.eval()
        self.model_size_mb = self._calculate_model_size()
        print(f"Model size after loading: {self.model_size_mb:.2f} MB")

    def _calculate_model_size(self):
        param_size = 0
        for param in self.model.parameters():
            param_size += param.nelement() * param.element_size()
        return param_size / (1024 ** 2)  # Convert to MB

    def calculate_perplexity(self, text):
        inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(self.device)
        with torch.no_grad():
            start_time = time.time()
            outputs = self.model(**inputs, labels=inputs["input_ids"])
            end_time = time.time()
        loss = outputs.loss.item()
        return np.exp(loss), end_time - start_time

    def generate_and_evaluate_bleu(self, input_text, reference_text):
        inputs = self.tokenizer.encode(input_text, return_tensors="pt", truncation=True, max_length=512).to(self.device)

        start_time = time.time()
        outputs = self.model.generate(inputs, max_new_tokens=self.max_tokens)
        end_time = time.time()

        generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

        smoothie = SmoothingFunction().method4
        reference = [reference_text.split()]
        candidate = generated_text.split()
        bleu_score = sentence_bleu(reference, candidate, smoothing_function=smoothie)

        return generated_text, bleu_score, end_time - start_time

    def evaluate_sample(self, input_text, reference_text):
        ppl, ppl_time = self.calculate_perplexity(input_text)
        generated_text, bleu, gen_time = self.generate_and_evaluate_bleu(input_text, reference_text)

        return {
            "generated": generated_text,
            "perplexity": ppl,
            "perplexity_time": ppl_time,
            "bleu_score": bleu,
            "generation_time": gen_time,
            "model_size_mb": self.model_size_mb
        }

    def get_transformer_layers(self):
        if hasattr(self.model, "model"):
            if hasattr(self.model.model, "layers"):
                return self.model.model.layers
            elif hasattr(self.model.model, "decoder") and hasattr(self.model.model.decoder, "layers"):
                return self.model.model.decoder.layers
        elif hasattr(self.model, "transformer") and hasattr(self.model.transformer, "h"):
            return self.model.transformer.h
        raise ValueError("Unsupported model architecture: can't find transformer layers")

    def compute_layerwise_importance(self, text):
        print("\n🔍 Computing layerwise importance (Ablation)...")
        inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(self.device)
        labels = inputs["input_ids"]

        with torch.no_grad():
            full_output = self.model(**inputs, labels=labels)
            base_loss = full_output.loss.item()
            base_perplexity = np.exp(base_loss)

        importance_scores = {}

        layers = self.get_transformer_layers()

        for i, layer in enumerate(layers):
            def forward_hook(module, input, output):
                if isinstance(output, tuple):
                    return tuple(torch.zeros_like(t) if isinstance(t, torch.Tensor) else t for t in output)
                return torch.zeros_like(output) if isinstance(output, torch.Tensor) else output

            handle = layer.register_forward_hook(forward_hook)

            with torch.no_grad():
                ablated_output = self.model(**inputs, labels=labels)
                ablated_loss = ablated_output.loss.item()
                ablated_perplexity = np.exp(ablated_loss)

            delta_loss = ablated_loss - base_loss
            delta_perplexity = ablated_perplexity - base_perplexity

            importance_scores[i] = {
                "delta_loss": delta_loss,
                "delta_perplexity": delta_perplexity,
                "ablated_loss": ablated_loss,
                "ablated_perplexity": ablated_perplexity
            }

            handle.remove()
            print(f"Layer {i}: ΔLoss = {delta_loss:.4f}, ΔPPL = {delta_perplexity:.4f}")

        return importance_scores


    def print_model_layers(self):
        print(f"\n📚 Model Layers in: {self.model_name}")
        print("=" * 50)
        for name, module in self.model.named_modules():
            print(name)


### Change model name to create experiments

In [4]:
import json
import os
import torch
from transformers import BitsAndBytesConfig

# Your model (e.g., "gpt2", "facebook/opt-125m", "meta-llama/Llama-2-7b-chat-hf")
MODEL_NAME =   "facebook/opt-350m" # "facebook/opt-1.3b" #"microsoft/phi-2" # "gpt2"
RESULTS_FILE = f"quantization_experiment_opt350_results.json"

# Example input (can be replaced by any SQuAD sample)
INPUT_TEXT = "What is the capital of France?"
REFERENCE_TEXT = "The capital of France is Paris."

# Different quantization settings  Uncomment the other configs when using gpu
experiments = {
    "baseline": None,  # Full precision
    # "8bit": BitsAndBytesConfig(load_in_8bit=True),
    # "4bit": BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_quant_type="nf4")
}

# Storage
all_results = {}

for exp_name, quant_config in experiments.items():
    print("\n" + "="*60)
    print(f"🚀 Running experiment: {exp_name}")
    print("="*60)

    # Initialize pipeline
    pipeline = QuantizedModelPipeline(model_name=MODEL_NAME, quantization_config=quant_config)
    pipeline.load_quantized_model()

    # Evaluate basic metrics
    metrics = pipeline.evaluate_sample(INPUT_TEXT, REFERENCE_TEXT)

    # Compute layerwise importance (ΔLoss + ΔPerplexity Ablation)
    layerwise_importance = pipeline.compute_layerwise_importance(INPUT_TEXT)

    # Organize and save all collected metrics
    all_results[exp_name] = {
        "evaluation_metrics": {
            "perplexity": metrics["perplexity"],
            "perplexity_time": metrics["perplexity_time"],
            "bleu_score": metrics["bleu_score"],
            "generation_time": metrics["generation_time"],
            "model_size_mb": metrics["model_size_mb"],
        },
        "layerwise_importance": layerwise_importance,  # Now includes both ΔLoss and ΔPerplexity
    }

    # Cleanup memory
    del pipeline
    torch.cuda.empty_cache()

# Save final results
with open(RESULTS_FILE, "w") as f:
    json.dump(all_results, f, indent=2)

print(f"\n✅ Finished all experiments. Results saved to {RESULTS_FILE}")



🚀 Running experiment: baseline
Loaded full-precision baseline model: facebook/opt-350m
Model size after loading: 1263.41 MB

🔍 Computing layerwise importance (Ablation)...
Layer 0: ΔLoss = 5.9505, ΔPPL = 9596.5166
Layer 1: ΔLoss = 5.4893, ΔPPL = 6041.4043
Layer 2: ΔLoss = 4.9524, ΔPPL = 3521.3421
Layer 3: ΔLoss = 6.9565, ΔPPL = 26286.5573
Layer 4: ΔLoss = 6.6137, ΔPPL = 18649.9737
Layer 5: ΔLoss = 7.2763, ΔPPL = 36202.3390
Layer 6: ΔLoss = 6.3859, ΔPPL = 14846.0663
Layer 7: ΔLoss = 7.9664, ΔPPL = 72206.1709
Layer 8: ΔLoss = 5.9905, ΔPPL = 9989.3998
Layer 9: ΔLoss = 7.6145, ΔPPL = 50783.2130
Layer 10: ΔLoss = 8.1441, ΔPPL = 86254.4057
Layer 11: ΔLoss = 6.6694, ΔPPL = 19721.2947
Layer 12: ΔLoss = 4.9775, ΔPPL = 3611.4027
Layer 13: ΔLoss = 7.5007, ΔPPL = 45316.4481
Layer 14: ΔLoss = 6.1870, ΔPPL = 12163.6349
Layer 15: ΔLoss = 5.7090, ΔPPL = 7532.4272
Layer 16: ΔLoss = 5.5097, ΔPPL = 6166.4491
Layer 17: ΔLoss = 6.8885, ΔPPL = 24556.6605
Layer 18: ΔLoss = 6.7126, ΔPPL = 20592.8840
Layer 19

In [5]:
import itertools
import torch
import time
import json
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM

class LayerwiseAblationStudy:
    def __init__(self, model_name, quantization_config=None, max_combination_size=2, output_file="ablation_results.json"):
        self.model_name = model_name
        self.quantization_config = quantization_config
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = None
        self.tokenizer = None
        self.max_combination_size = max_combination_size
        self.output_file = output_file

    def load_model(self):
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=True)
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            quantization_config=self.quantization_config,
            device_map="auto",
            torch_dtype=torch.bfloat16 if self.quantization_config else torch.float32
        )
        self.model.eval()
        print(f"✅ Loaded model {self.model_name} with quantization: {self.quantization_config}")

    def compute_loss_and_perplexity(self, input_text):
        """Returns loss and perplexity"""
        inputs = self.tokenizer(input_text, return_tensors="pt", truncation=True, max_length=512).to(self.device)
        with torch.no_grad():
            outputs = self.model(**inputs, labels=inputs["input_ids"])
            loss = outputs.loss.item()
        perplexity = np.exp(loss)
        return loss, perplexity
    
    
    def get_transformer_layers(self):
        if hasattr(self.model, "model"):
            if hasattr(self.model.model, "layers"):
                return self.model.model.layers
            elif hasattr(self.model.model, "decoder") and hasattr(self.model.model.decoder, "layers"):
                return self.model.model.decoder.layers
        elif hasattr(self.model, "transformer") and hasattr(self.model.transformer, "h"):
            return self.model.transformer.h
        raise ValueError("Unsupported model architecture: can't find transformer layers")

    

    def ablation_study(self, input_text):
        print(f"🔍 Starting ablation study (max_combination_size={self.max_combination_size})...")

        base_loss, base_perplexity = self.compute_loss_and_perplexity(input_text)
        print(f"🎯 Base loss: {base_loss:.4f} | Base perplexity: {base_perplexity:.4f}")

        layers = self.get_transformer_layers()
        num_layers = len(layers)
        print(f"📚 Total layers: {num_layers}")

        ablation_results = {}

        # Initialize output file
        self.initialize_output_file(base_loss, base_perplexity)

        # Go through combinations
        for r in range(1, self.max_combination_size + 1):
            for layer_indices in itertools.combinations(range(num_layers), r):
                print(f"⚡ Masking layers: {layer_indices}")

                handles = []
                for idx in layer_indices:
                    handle = layers[idx].register_forward_hook(self.mask_hook)
                    handles.append(handle)

                ablated_loss, ablated_perplexity = self.compute_loss_and_perplexity(input_text)

                delta_loss = ablated_loss - base_loss
                delta_perplexity = ablated_perplexity - base_perplexity

                combo_key = str(layer_indices)
                result = {
                    "ablated_loss": ablated_loss,
                    "ablated_perplexity": ablated_perplexity,
                    "delta_loss": delta_loss,
                    "delta_perplexity": delta_perplexity
                }
                ablation_results[combo_key] = result

                # Append result immediately to file
                self.append_result_to_file(combo_key, result)

                # Remove hooks
                for handle in handles:
                    handle.remove()

        print("✅ Ablation study completed.")
        return ablation_results

    def initialize_output_file(self, base_loss, base_perplexity):
        """Initialize output JSON file with base loss and perplexity."""
        try:
            with open(self.output_file, "w") as f:
                json.dump({
                    "base_loss": base_loss,
                    "base_perplexity": base_perplexity,
                    "results": {}
                }, f, indent=2)
        except Exception as e:
            print(f"❌ Failed to initialize output file: {e}")

    def append_result_to_file(self, combo_key, result):
        """Append one result to output file safely."""
        try:
            with open(self.output_file, "r") as f:
                data = json.load(f)
        except (FileNotFoundError, json.JSONDecodeError):
            data = {"base_loss": None, "base_perplexity": None, "results": {}}

        data["results"][combo_key] = result

        with open(self.output_file, "w") as f:
            json.dump(data, f, indent=2)

    @staticmethod
    def mask_hook(module, input, output):
        if isinstance(output, tuple):
            return tuple(torch.zeros_like(t) if isinstance(t, torch.Tensor) else t for t in output)
        return torch.zeros_like(output) if isinstance(output, torch.Tensor) else output


In [6]:
# from transformers import BitsAndBytesConfig

# model_name = "microsoft/phi-2" #"gpt2"
# input_text = "The theory of relativity was proposed by"

# # Example quantization config (optional)
# quant_config = None #BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)

# # Create the study object
# ablation_study = LayerwiseAblationStudy(
#     model_name=model_name,
#     quantization_config=quant_config,
#     max_combination_size=1,   # (careful, combinations grow fast!)
#     output_file="my_ablation_results-phi-1.json"
# )

# ablation_study.load_model()
# results = ablation_study.ablation_study(input_text)


In [7]:
import torch
from transformers import BitsAndBytesConfig
from pathlib import Path
import json

# Assuming LayerwiseAblationStudy is already defined and imported
# from your module

model_name = "facebook/opt-350m" # "microsoft/phi-2"
input_text = "The theory of relativity was proposed by"

# Define quantization configurations to test
quantization_setups = {
    "fp32": None,
    # "8bit": BitsAndBytesConfig(load_in_8bit=True),
    # "4bit": BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16),
    # "nf4": BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16)
}

# Run ablation study for each configuration
for config_name, quant_config in quantization_setups.items():
    print(f"\n🧪 Running ablation study with quantization: {config_name}")

    output_file = f"ablation_opt350_{config_name}.json"

    ablation_study = LayerwiseAblationStudy(
        model_name=model_name,
        quantization_config=quant_config,
        max_combination_size=2,   ## change to number between 3-5
        output_file=output_file
    )

    ablation_study.load_model()
    results = ablation_study.ablation_study(input_text)

    print(f"✅ Saved results to {output_file}")



🧪 Running ablation study with quantization: fp32
✅ Loaded model facebook/opt-350m with quantization: None
🔍 Starting ablation study (max_combination_size=2)...
🎯 Base loss: 3.8323 | Base perplexity: 46.1703
📚 Total layers: 24
⚡ Masking layers: (0,)
⚡ Masking layers: (1,)
⚡ Masking layers: (2,)
⚡ Masking layers: (3,)
⚡ Masking layers: (4,)
⚡ Masking layers: (5,)
⚡ Masking layers: (6,)
⚡ Masking layers: (7,)
⚡ Masking layers: (8,)
⚡ Masking layers: (9,)
⚡ Masking layers: (10,)
⚡ Masking layers: (11,)
⚡ Masking layers: (12,)
⚡ Masking layers: (13,)
⚡ Masking layers: (14,)
⚡ Masking layers: (15,)
⚡ Masking layers: (16,)
⚡ Masking layers: (17,)
⚡ Masking layers: (18,)
⚡ Masking layers: (19,)
⚡ Masking layers: (20,)
⚡ Masking layers: (21,)
⚡ Masking layers: (22,)
⚡ Masking layers: (23,)
⚡ Masking layers: (0, 1)
⚡ Masking layers: (0, 2)
⚡ Masking layers: (0, 3)
⚡ Masking layers: (0, 4)
⚡ Masking layers: (0, 5)
⚡ Masking layers: (0, 6)
⚡ Masking layers: (0, 7)
⚡ Masking layers: (0, 8)
⚡ Maski

### Compute Shapely Values

In [8]:
import json
import itertools
import numpy as np

def compute_shapley_values(result_file):
    # Load ablation results
    with open(result_file, "r") as f:
        data = json.load(f)

    results = data["results"]

    # Extract layers and their associated delta_loss and delta_perplexity
    layer_combinations_loss = {}
    layer_combinations_ppl = {}

    for key, value in results.items():
        layers = eval(key)  # Convert string representation to tuple (layer indices)
        delta_loss = value["delta_loss"]
        delta_perplexity = value.get("delta_perplexity", 0)  # In case old files don't have this
        layer_combinations_loss[layers] = delta_loss
        layer_combinations_ppl[layers] = delta_perplexity

    # Get all layers involved
    all_layers = set()
    for layers in layer_combinations_loss.keys():
        all_layers.update(layers)
    all_layers = sorted(list(all_layers))

    # Initialize dictionaries to store Shapley values
    shapley_values_loss = {layer: 0 for layer in all_layers}
    shapley_values_ppl = {layer: 0 for layer in all_layers}

    # Total number of combinations to consider
    num_combinations = 0

    # Loop over all non-empty subsets of layers
    for r in range(1, len(all_layers) + 1):
        for comb in itertools.combinations(all_layers, r):
            num_combinations += 1

            comb_loss = layer_combinations_loss.get(comb, 0)
            comb_ppl = layer_combinations_ppl.get(comb, 0)

            for layer in comb:
                reduced_comb = tuple(l for l in comb if l != layer)
                reduced_loss = layer_combinations_loss.get(reduced_comb, 0)
                reduced_ppl = layer_combinations_ppl.get(reduced_comb, 0)

                marginal_contribution_loss = comb_loss - reduced_loss
                marginal_contribution_ppl = comb_ppl - reduced_ppl

                shapley_values_loss[layer] += marginal_contribution_loss
                shapley_values_ppl[layer] += marginal_contribution_ppl

    # Normalize
    for layer in all_layers:
        shapley_values_loss[layer] /= num_combinations
        shapley_values_ppl[layer] /= num_combinations

    # Return both
    return {
        "shapley_values_loss": shapley_values_loss,
        "shapley_values_perplexity": shapley_values_ppl
    }


In [9]:
shapley_values = compute_shapley_values(f"ablation_opt350_{config_name}.json")
print(shapley_values)


{'shapley_values_loss': {0: -0.00011433942142352709, 1: -0.00011440076354789966, 2: -0.00011459831217536247, 3: -0.00011384989644399576, 4: -0.00011466258548764672, 5: -0.00011556469400934273, 6: -0.00011446226153693603, 7: -0.0001146758118151356, 8: -0.00011509506888875915, 9: -0.00011384022157984214, 10: -0.00011433392261823241, 11: -0.00011357689021986069, 12: -0.00011346386003057663, 13: -0.00011193019243861007, 14: -0.00011341662473830358, 15: -0.00011188384072049246, 16: -0.00011479701608853906, 17: -0.00011395638304588421, 18: -0.00011030072094278516, 19: -0.00011464380612623835, 20: -0.00011415482536563161, 21: -0.00011264590910501493, 22: -0.00012001999845381816, 23: -0.0001159405327128104}, 'shapley_values_perplexity': {0: -2.1757349277803013, 1: -2.182954429580831, 2: -2.195041815759652, 3: -2.0951577070340455, 4: -2.2194428016635666, 5: -2.246380054443724, 6: -2.2128858102147966, 7: -2.224112593741975, 8: -2.239762408006018, 9: -2.1689484220865043, 10: -2.2076437597276035, 

#### Utils for Pruning

In [10]:
import torch.nn as nn
# class TransformerBlockIdentity(nn.Module):
#     def __init__(self):
#         super().__init__()

#     def forward(self, hidden_states, *args, **kwargs):
#         # Return in the same structure that original transformer blocks return
#         # In most HuggingFace models: output is (hidden_states, present/past_key_values)
#         return (hidden_states, None)


In [15]:
# class TransformerBlockIdentity(nn.Module):
#     def __init__(self, return_tuple=True, use_cache=False):
#         """
#         return_tuple: Whether to return (hidden_states, past_key_values) or just hidden_states.
#         use_cache: Whether to simulate past_key_values for decoder models like OPT.
#         """
#         super().__init__()
#         self.return_tuple = return_tuple
#         self.use_cache = use_cache

#     def forward(self, hidden_states, *args, **kwargs):
#         if self.use_cache:
#             # Dummy past_key_values: return same structure as decoder block
#             batch_size, seq_len, hidden_dim = hidden_states.shape
#             dummy_cache = (
#                 torch.zeros(batch_size, 1, seq_len, hidden_dim // 1, device=hidden_states.device, dtype=hidden_states.dtype),
#                 torch.zeros(batch_size, 1, seq_len, hidden_dim // 1, device=hidden_states.device, dtype=hidden_states.dtype)
#             )
#             return (hidden_states, dummy_cache)

#         if self.return_tuple:
#             return (hidden_states, None)  # matches models that return tuple
#         else:
#             return hidden_states

In [21]:
import torch.nn as nn
import torch

class TransformerBlockIdentity(nn.Module):
    def __init__(self, return_tuple=True, use_cache=True):
        super().__init__()
        self.return_tuple = return_tuple
        self.use_cache = use_cache

    def forward(self, hidden_states, *args, **kwargs):
        # Handle optional keyword inputs like `use_cache`, `past_key_values`
        if self.return_tuple:
            dummy_cache = self._make_dummy_cache(hidden_states) if self.use_cache else None
            return (hidden_states, dummy_cache)
        return hidden_states

    def _make_dummy_cache(self, hidden_states):
        # Mimic a cache object with right structure: 2 tensors per layer (key & value)
        batch_size, seq_len, hidden_dim = hidden_states.shape
        head_dim = hidden_dim // 16  # typical assumption
        shape = (batch_size, 16, seq_len, head_dim)  # [batch, num_heads, seq_len, head_dim]
        dummy_tensor = torch.zeros(shape, dtype=hidden_states.dtype, device=hidden_states.device)
        return (dummy_tensor, dummy_tensor)


## Comparing against randomised pruning

## Final Eval

In [22]:
import random

class ModelPruner:
    def __init__(self, model_name, shapley_loss_values, shapley_ppl_values,
                 prune_mode="loss", prune_threshold=-1.0, max_layers_to_prune=None, quantization_config=None):
        """
        prune_mode: "loss" or "ppl" - based on which Shapley to prune
        prune_threshold: prune layers with shapley below this threshold
        max_layers_to_prune: prune top-k worst layers
        """
        self.model_name = model_name
        self.shapley_loss_values = shapley_loss_values
        self.shapley_ppl_values = shapley_ppl_values
        self.prune_mode = prune_mode
        self.prune_threshold = prune_threshold
        self.max_layers_to_prune = max_layers_to_prune
        self.quantization_config = quantization_config

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = None
        self.tokenizer = None

    def load_model(self):
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=True)
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            device_map="auto",
            torch_dtype=torch.bfloat16 if self.quantization_config else torch.float32,
            quantization_config=self.quantization_config,
        )
        self.model.eval()


    def get_transformer_layers(self):
        if hasattr(self.model, "model"):
            if hasattr(self.model.model, "layers"):
                return self.model.model.layers
            elif hasattr(self.model.model, "decoder") and hasattr(self.model.model.decoder, "layers"):
                return self.model.model.decoder.layers
        elif hasattr(self.model, "transformer") and hasattr(self.model.transformer, "h"):
            return self.model.transformer.h
        raise ValueError("Unsupported model architecture: can't find transformer layers")

    

    def select_layers_to_prune(self):
        shapley_values = self.shapley_loss_values if self.prune_mode == "loss" else self.shapley_ppl_values

        # Sort by Shapley (ascending: more negative = worse)
        sorted_layers = sorted(shapley_values.items(), key=lambda x: x[1])

        if self.max_layers_to_prune:
            layers_to_prune = [layer for layer, _ in sorted_layers[:self.max_layers_to_prune]]
        else:
            layers_to_prune = [layer for layer, score in sorted_layers if score < self.prune_threshold]

        return layers_to_prune


    def select_random_layers_to_prune(self, total_layers):
        """Randomly select layers to prune."""
        all_layer_indices = list(range(total_layers))
        layers_to_prune = random.sample(all_layer_indices, self.max_layers_to_prune)
        return layers_to_prune


    def prune_layers(self, layers_to_prune):
        print(f"🔪 Pruning layers: {layers_to_prune}")

        layer_list = self.get_transformer_layers()

        # Determine model type for correct identity behavior
        model_type = type(self.model).__name__.lower()
        use_cache = "opt" in model_type or "phi" in model_type  # decoder-style models use kv cache
        return_tuple = True  # most HF models return a tuple (hidden_states, cache)

        for idx in layers_to_prune:
            print(f"Zeroing Layer {idx}")
            layer_list[idx] = TransformerBlockIdentity(return_tuple=return_tuple, use_cache=use_cache)

                
    def evaluate_model(self, input_text):
        inputs = self.tokenizer(input_text, return_tensors="pt", truncation=True, max_length=512).to(self.device)

        start_time = time.time()
        with torch.no_grad():
            outputs = self.model(**inputs, labels=inputs["input_ids"])
            loss = outputs.loss.item()
        end_time = time.time()

        inference_time = end_time - start_time
        perplexity = torch.exp(outputs.loss).item()
        return loss, perplexity, inference_time
    

    def generate_text(self, input_text, max_new_tokens=50):
        inputs = self.tokenizer(input_text, return_tensors="pt").to(self.device)
        with torch.no_grad():
            generated_ids = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                top_k=50,
                top_p=0.95,
                temperature=1.0,
                pad_token_id=self.tokenizer.eos_token_id
            )
        return self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    
    def run_full_evaluation(self, input_text):
        # Load and evaluate original model
        self.load_model()
        original_loss, original_perplexity, original_time = self.evaluate_model(input_text)
        original_output = self.generate_text(input_text)
        print(f"\n🏋️ Before Pruning:\nLoss = {original_loss:.4f}, Perplexity = {original_perplexity:.2f}, Inference Time = {original_time:.4f} sec")
        print(f"Generated Text:\n{original_output}")

        # -----------------------
        # Shapley-based pruning
        # -----------------------
        shapley_layers_to_prune = self.select_layers_to_prune()
        self.prune_layers(shapley_layers_to_prune)

        shapley_loss, shapley_perplexity, shapley_time = self.evaluate_model(input_text)
        shapley_output = self.generate_text(input_text)
        print(f"\n📊 After Shapley Pruning:\nLoss = {shapley_loss:.4f}, Perplexity = {shapley_perplexity:.2f}, Inference Time = {shapley_time:.4f} sec")
        print(f"Generated Text:\n{shapley_output}")

        # -----------------------
        # Random pruning
        # -----------------------
        self.load_model()  # Reload fresh model

        # try:
        #     total_layers = len(self.model.model.layers)
        # except AttributeError:
        #     total_layers = len(self.model.transformer.h)
        layer_list = self.get_transformer_layers()
        total_layers = len(layer_list)

        random_layers_to_prune = self.select_random_layers_to_prune(total_layers)
        self.prune_layers(random_layers_to_prune)

        random_loss, random_perplexity, random_time = self.evaluate_model(input_text)
        random_output = self.generate_text(input_text)
        print(f"\n🎲 After Random Pruning:\nLoss = {random_loss:.4f}, Perplexity = {random_perplexity:.2f}, Inference Time = {random_time:.4f} sec")
        print(f"Generated Text:\n{random_output}")

        return {
            "before_pruning": {
                "loss": original_loss,
                "perplexity": original_perplexity,
                "inference_time_sec": original_time,
                "generated_text": original_output
            },
            "shapley_pruning": {
                "loss": shapley_loss,
                "perplexity": shapley_perplexity,
                "inference_time_sec": shapley_time,
                "generated_text": shapley_output,
                "layers_pruned": shapley_layers_to_prune
            },
            "random_pruning": {
                "loss": random_loss,
                "perplexity": random_perplexity,
                "inference_time_sec": random_time,
                "generated_text": random_output,
                "layers_pruned": random_layers_to_prune
            }
        }





In [23]:
model_name = "facebook/opt-350m" # "microsoft/phi-2"
input_text = "The theory of relativity was proposed by"

pruner = ModelPruner(
    model_name=model_name,
    shapley_loss_values=shapley_values['shapley_values_loss'],
    shapley_ppl_values=shapley_values['shapley_values_perplexity'],
    prune_mode="loss",
    prune_threshold=-1.0,
    max_layers_to_prune=3,
    quantization_config=None
)

results = pruner.run_full_evaluation(input_text)

print("\n📈 Results Comparison:")

# Before pruning
print(f"\n🏋️ Before Pruning:")
print(f"Loss: {results['before_pruning']['loss']:.4f}, Perplexity: {results['before_pruning']['perplexity']:.2f}, Inference Time: {results['before_pruning']['inference_time_sec']:.4f} sec")
print(f"Generated Text:\n{results['before_pruning']['generated_text']}")

# After Shapley pruning
print(f"\n📊 After Shapley Pruning (layers {results['shapley_pruning']['layers_pruned']}):")
print(f"Loss: {results['shapley_pruning']['loss']:.4f}, Perplexity: {results['shapley_pruning']['perplexity']:.2f}, Inference Time: {results['shapley_pruning']['inference_time_sec']:.4f} sec")
print(f"Generated Text:\n{results['shapley_pruning']['generated_text']}")

# After Random pruning
print(f"\n🎲 After Random Pruning (layers {results['random_pruning']['layers_pruned']}):")
print(f"Loss: {results['random_pruning']['loss']:.4f}, Perplexity: {results['random_pruning']['perplexity']:.2f}, Inference Time: {results['random_pruning']['inference_time_sec']:.4f} sec")
print(f"Generated Text:\n{results['random_pruning']['generated_text']}")



🏋️ Before Pruning:
Loss = 3.8323, Perplexity = 46.17, Inference Time = 0.0743 sec
Generated Text:
The theory of relativity was proposed by Einstein and later developed by Newton. The theory was first established by Einstein in 1905.

The physics of relativity are related to the idea of a series of moving objects that are the result of gravity. The concept of relativity is the idea of the
🔪 Pruning layers: [22, 23, 5]
Zeroing Layer 22
Zeroing Layer 23
Zeroing Layer 5


AttributeError: 'tuple' object has no attribute 'to_legacy_cache'