<a href="https://colab.research.google.com/github/peremartra/Rearchitecting-LLMs/blob/main/CH02/CH02_NB01_Depth_pruning_evaluation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Rearchitecting LLMs
## Surgical Optimization for Hyper-Efficient Models


### Chapter 2: Rearchitecting an LLM: A hands-on introduction
by [Pere Martra](https://github.com/peremartra)
_____
Colab Environment: GPU T4

- Recommended model: google/gemma-3-270m

- Tested with: meta-llama/Llama-3.2-1B
_____

Welcome to your first hands-on model tailoring project. In this notebook, we will follow the first three steps of a typical optimization workflow:

* Establish a Baseline: We will measure the performance of a standard, pre-trained model on a few key metrics.

* Perform the Surgery: We will surgically remove entire layers from the model's architecture (a technique known as depth pruning).

* Evaluate the Impact: We will re-run our baseline evaluation to precisely quantify the effect of our intervention.

This notebook sets the stage for the next step, covered in the book and the following notebook: recovering the model's lost knowledge through fine-tuning.

# 2.1 Setting up the project and establishing the pipeline

## Libraries

In [1]:
!pip install -q \
      "torch==2.8.0+cu126" \
      "transformers==4.55.4" \
      "accelerate==1.10.1" \
      "lm_eval==0.4.9.1" \
      "sentencepiece==0.2.1" \
      "sentence-transformers==5.1.0" \
      "optipfair==0.1.4"

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m53.6/53.6 kB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m51.8/51.8 kB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.5/7.5 MB[0m [31m73.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m40.0/40.0 kB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m491.5/491.5 kB[0m [31m31.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m7.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m293.6/293.6 kB[0m [31m22.6 MB/s[0m eta [36m0:00:0

In [2]:
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer
from optipfair.pruning.depth import prune_model_depth
from lm_eval import evaluator
from lm_eval.models.huggingface import HFLM
import time
import json
import numpy as np


# Check device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

Using device: cuda
GPU: Tesla T4


## Load model

In [None]:
# Model configuration
MODEL_NAME = "google/gemma-3-270m"
#MODEL_NAME = "meta-llama/Llama-3.2-1B"
MAX_NEW_TOKENS = 50
LAYERS_TO_REMOVE = 2
TEST_PROMPT = "Paris is the capital of"

print(f"Loading model: {MODEL_NAME}")

# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto"
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

## Model structure

In [4]:
original_params= sum(p.numel() for p in model.parameters())
print(f"Total parameters: {original_params}")
print(f"Original layers: {len(model.model.layers)}")

print("=" * 20)
print(model)

Total parameters: 268098176
Original layers: 18
Gemma3ForCausalLM(
  (model): Gemma3TextModel(
    (embed_tokens): Gemma3TextScaledWordEmbedding(262144, 640, padding_idx=0)
    (layers): ModuleList(
      (0-17): 18 x Gemma3DecoderLayer(
        (self_attn): Gemma3Attention(
          (q_proj): Linear(in_features=640, out_features=1024, bias=False)
          (k_proj): Linear(in_features=640, out_features=256, bias=False)
          (v_proj): Linear(in_features=640, out_features=256, bias=False)
          (o_proj): Linear(in_features=1024, out_features=640, bias=False)
          (q_norm): Gemma3RMSNorm((256,), eps=1e-06)
          (k_norm): Gemma3RMSNorm((256,), eps=1e-06)
        )
        (mlp): Gemma3MLP(
          (gate_proj): Linear(in_features=640, out_features=2048, bias=False)
          (up_proj): Linear(in_features=640, out_features=2048, bias=False)
          (down_proj): Linear(in_features=2048, out_features=640, bias=False)
          (act_fn): PytorchGELUTanh()
        )
    

## Basic generation test

### Support functions

In [5]:
# Clean conflicting configuration after loading the model
from transformers import GenerationConfig
clean_config = GenerationConfig(
    max_length=model.generation_config.max_length,
    pad_token_id=model.generation_config.pad_token_id,
    eos_token_id=model.generation_config.eos_token_id,
    do_sample=False,
    num_beams=1,
    early_stopping=False
)
model.generation_config = clean_config

print("✓ Model generation config cleaned")
print(f"New config: {model.generation_config}")

✓ Model generation config cleaned
New config: GenerationConfig {}



In [6]:
def count_parameters(model):
    """Count total parameters in model"""
    return sum(p.numel() for p in model.parameters())

def generate_text(model, tokenizer, prompt: str, max_new_tokens: int = MAX_NEW_TOKENS) -> str:
    """Generate text with the model"""
    inputs = tokenizer(prompt, return_tensors='pt').to(device)

    with torch.no_grad():
        outputs = model.generate(
            inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            max_new_tokens=max_new_tokens,
            num_return_sequences=1,
            pad_token_id=tokenizer.pad_token_id,
            do_sample=False,
            num_beams=3,
            #early_stopping=True,
            no_repeat_ngram_size=2
        )

    return tokenizer.decode(outputs[0], skip_special_tokens=True)
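
`generate_text` decodes with beam search (`num_beams=3`) and `no_repeat_ngram_size=2`, which forbids any two-token sequence from appearing twice in the output. This matters when comparing the original and pruned generations, because both are constrained by the same anti-repetition rule. The rule can be sketched in plain Python as below; `banned_next_tokens` is a hypothetical helper and a simplified single-sequence illustration, not the `transformers` implementation (which applies the same idea to every beam hypothesis).

```python
from typing import List, Set


def banned_next_tokens(tokens: List[int], n: int = 2) -> Set[int]:
    """Tokens that would repeat an n-gram already present in `tokens`.

    Simplified illustration of the no_repeat_ngram_size rule (assumes n >= 1):
    if the last n-1 tokens already occurred earlier followed by token t,
    then t is banned as the next token.
    """
    if len(tokens) + 1 < n:  # not enough context to complete any n-gram
        return set()
    followers = {}
    # Record which tokens followed every (n-1)-token prefix seen so far
    for i in range(len(tokens) - n + 1):
        prefix = tuple(tokens[i:i + n - 1])
        followers.setdefault(prefix, set()).add(tokens[i + n - 1])
    current = tuple(tokens[len(tokens) - (n - 1):]) if n > 1 else ()
    return followers.get(current, set())


# Toy token ids: the bigram (5, 7) already occurred, so 7 cannot follow the final 5
print(banned_next_tokens([5, 7, 9, 5], n=2))  # {7}
```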

### Generating response

In [7]:
# A simple prompt to test basic knowledge and coherence
prompt = TEST_PROMPT
generated_text = generate_text(model, tokenizer, prompt)

print(f"Prompt: '{prompt}'")
print(f"Generated Text: '{generated_text}'")

Prompt: 'Paris is the capital of'
Generated Text: 'Paris is the capital of France. It is located in the middle of the country and has a population of about 10 million people. Paris is a city with a rich history and culture. The city is known for its beautiful architecture, art, and history. There are'


## Baseline Evaluation

In [8]:
def model_evaluation(model_obj, tokenizer, tasks, limit=100):
    """
    Runs lm-eval on a PyTorch model object already in memory.

    Args:
        model_obj: The PyTorch model object to evaluate.
        tokenizer: The tokenizer object.
        tasks (list): A list of task names.
        limit (int): The number of samples per task.
    """
    print(f"Starting lm-eval on model '{model_obj.config._name_or_path}' for tasks: {tasks}")

    # Wrap the local model object and tokenizer for lm-eval
    model_wrapper = HFLM(
        pretrained=model_obj,
        tokenizer=tokenizer,
        device=str(device)
    )

    results = evaluator.simple_evaluate(
        model=model_wrapper,
        tasks=tasks,
        num_fewshot=0,
        limit=limit,
        device=str(device),
    )

    # Format results for clean display
    formatted_results = {}
    for task_name, res in results["results"].items():
        # Look for accuracy ('acc') first, then perplexity (lm-eval reports it as 'perplexity')
        if 'acc,none' in res:
            metric_val = res['acc,none']
        elif 'perplexity,none' in res:
            metric_val = res['perplexity,none']
        else:
            metric_val = 0  # Fallback: no known metric found

        formatted_results[task_name] = f"{metric_val:.4f}"

    print(json.dumps(formatted_results, indent=2))
    return formatted_results
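
The formatting loop above relies on how lm-eval 0.4.x names its metrics: each task's result dictionary uses `"<metric>,<filter>"` keys such as `"acc,none"` or `"perplexity,none"`, alongside `"<metric>_stderr,<filter>"` entries. A minimal, standalone sketch of that lookup follows; `pick_metric` is a hypothetical helper, and the dictionaries in the usage lines are illustrative placeholders rather than real evaluation output.

```python
import math
from typing import Dict, Optional, Tuple


def pick_metric(task_result: Dict[str, object],
                preferred: Tuple[str, ...] = ("acc,none", "perplexity,none"),
                ) -> Tuple[Optional[str], float]:
    """Return (key, value) for the first preferred metric found.

    Assumes lm-eval 0.4.x style keys: "<metric>,<filter>", e.g. "acc,none".
    """
    for key in preferred:
        if key in task_result:
            return key, float(task_result[key])
    return None, math.nan  # NaN cannot be mistaken for a real 0.0 score


# Illustrative result dicts (values are placeholders, not measurements)
print(pick_metric({"alias": "lambada_openai", "acc,none": 0.42, "perplexity,none": 10.0}))
print(pick_metric({"perplexity,none": 10.0}))
```

Returning the key as well as the value makes it obvious in the comparison table whether a number is an accuracy (higher is better) or a perplexity (lower is better).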


In [None]:
# Define the benchmark suite for our diagnostic
benchmark_tasks = ['arc_easy', 'winogrande', 'boolq', 'lambada_openai']

# Run the evaluation
baseline_results = model_evaluation(model, tokenizer, benchmark_tasks, limit=100)


In [10]:
baseline_results

{'arc_easy': '0.5500',
 'boolq': '0.6600',
 'lambada_openai': '0.4200',
 'winogrande': '0.6000'}

## Inference time measurement

In [11]:
#### Inference Speed Measurement

import gc
def clear_gpu_cache():
    """Fully clear the GPU cache and run Python garbage collection"""
    gc.collect()  # Free unreferenced Python objects first so their memory can be released
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

def measure_inference_time(model, tokenizer, prompts, num_runs=5):
    """Measure average inference time across multiple runs"""
    times = []

    # Warm up GPU
    print("Warming up base model...")
    clear_gpu_cache()
    for _ in range(20):
        _ = generate_text(model, tokenizer, "warmup", max_new_tokens=50)
    if torch.cuda.is_available():
        torch.cuda.synchronize()

    for _ in range(num_runs):
        start_time = time.perf_counter()

        for prompt in prompts:
            inputs = tokenizer(prompt, return_tensors='pt').to(device)
            with torch.no_grad():
                _ = model.generate(
                    inputs['input_ids'],
                    attention_mask=inputs['attention_mask'],
                    max_new_tokens=50,
                    do_sample=False,
                    pad_token_id=tokenizer.pad_token_id
                )

        # Wait for any pending GPU work before stopping the clock
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        end_time = time.perf_counter()
        times.append(end_time - start_time)

    return {
        'mean_time': np.mean(times),
        'std_time': np.std(times),
        'all_times': times
    }

# Test prompts for speed measurement
speed_test_prompts = [
    "The capital of France is",
    "Machine learning is",
    "Climate change refers to"
]
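
The timing recipe in `measure_inference_time` (untimed warm-up, repeated runs, mean and standard deviation) can be factored into a device-agnostic helper. The sketch below assumes nothing about the model: `benchmark` is a hypothetical function that times any zero-argument callable and accepts an optional `sync` callback, which on a GPU would be `torch.cuda.synchronize` so that asynchronous kernels are counted inside the measured interval.

```python
import statistics
import time
from typing import Callable, Dict, List, Optional


def benchmark(fn: Callable[[], object], warmup: int = 3, runs: int = 5,
              sync: Optional[Callable[[], None]] = None) -> Dict[str, object]:
    """Time `fn` after some warm-up calls, syncing the device around each run.

    `sync` should block until queued device work has finished
    (for example torch.cuda.synchronize on a GPU); by default it does nothing.
    """
    sync = sync or (lambda: None)
    for _ in range(warmup):          # warm-up calls are not timed
        fn()
    sync()
    times: List[float] = []
    for _ in range(runs):
        start = time.perf_counter()  # monotonic, high-resolution clock
        fn()
        sync()                       # count asynchronous device work too
        times.append(time.perf_counter() - start)
    return {
        "mean_time": statistics.mean(times),
        "std_time": statistics.pstdev(times),  # population std, like np.std
        "all_times": times,
    }


# Usage with a CPU-only stand-in workload
result = benchmark(lambda: sum(range(100_000)), warmup=1, runs=3)
print(f"{result['mean_time']:.4f}s ± {result['std_time']:.4f}s")
```

In this notebook the workload could be a lambda that loops over `speed_test_prompts` and calls `model.generate`, with `sync=torch.cuda.synchronize` when a GPU is available; keeping the workload identical for both models is what makes the comparison fair.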

In [12]:
# Measure original model
print("Measuring original model...")
original_timing = measure_inference_time(model, tokenizer, speed_test_prompts)

Measuring original model...
Warming up base model...


In [13]:
print(f"\n📊 Inference Speed Results:")
print(f"   Original model: {original_timing['mean_time']:.3f}s ± {original_timing['std_time']:.3f}s")


📊 Inference Speed Results:
   Original model: 4.748s ± 0.291s


# 2.2 Applying depth pruning to the model
Let's remove the last layers of the model, based on the results from the paper:
* Kim, B.-K., Kim, G., Kim, T.-H., Castells, T., Choi, S., Shin, J., & Song, H.-K. (2024). Shortened LLaMA: Depth Pruning for Large Language Models with Comparison of Retraining Methods. http://arxiv.org/abs/2402.02834

In [14]:
# Create a copy of the original model for manual pruning
import copy
pruned_model = copy.deepcopy(model)
print(f"Original model structure:")
print(f"  - Total layers: {len(pruned_model.model.layers)}")
print(f"  - Layer indices: 0 to {len(pruned_model.model.layers)-1}")
print(f"  - Parameters: {count_parameters(pruned_model):,}")

Original model structure:
  - Total layers: 18
  - Layer indices: 0 to 17
  - Parameters: 268,098,176


In [15]:
# Manual layer removal - remove last N layers
print(f"\nManually removing last {LAYERS_TO_REMOVE} layers...")
original_layers_count = len(pruned_model.model.layers)
new_layers_count = original_layers_count - LAYERS_TO_REMOVE
# Create new layer list excluding the last LAYERS_TO_REMOVE layers
new_layers = pruned_model.model.layers[:new_layers_count]
pruned_model.model.layers = nn.ModuleList(new_layers)

# Update model configuration to reflect the change
pruned_model.config.num_hidden_layers = len(pruned_model.model.layers)



Manually removing last 2 layers...
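
Slicing off the *last* layers is the simplest case because every surviving layer keeps its original position. Removing layers from the middle of the stack (as importance-based criteria may suggest) needs a little more bookkeeping. The sketch below plans such a removal in plain Python; `plan_layer_removal` and `prune_sequence` are hypothetical helpers, and the `layer_types` list is an illustrative stand-in for a per-layer attention pattern.

```python
from typing import Dict, Iterable, List, Sequence, Tuple


def plan_layer_removal(num_layers: int, remove: Iterable[int]) -> Tuple[List[int], Dict[int, int]]:
    """Indices of layers to keep plus an old -> new index mapping."""
    to_remove = set(remove)
    if not to_remove <= set(range(num_layers)):
        raise ValueError(f"layer indices must be in [0, {num_layers - 1}]")
    keep = [i for i in range(num_layers) if i not in to_remove]
    remap = {old: new for new, old in enumerate(keep)}
    return keep, remap


def prune_sequence(items: Sequence, keep: List[int]) -> list:
    """Keep only the per-layer entries (modules, layer types, ...) listed in `keep`."""
    return [items[i] for i in keep]


# Trailing removal, as in this notebook: surviving layers keep their indices
keep, remap = plan_layer_removal(18, [16, 17])
print(keep == list(range(16)), all(remap[i] == i for i in keep))  # True True

# Middle removal: layers after the gap shift down and must be renumbered
keep, remap = plan_layer_removal(6, [2, 3])
print(keep, remap)  # [0, 1, 4, 5] {0: 0, 1: 1, 4: 2, 5: 3}

# Hypothetical per-layer attention pattern kept in sync with the layers
layer_types = ["sliding_attention", "sliding_attention", "full_attention"] * 2
print(prune_sequence(layer_types, keep))
```

Applying the plan to a real model would mean wrapping `prune_sequence(pruned_model.model.layers, keep)` in `nn.ModuleList`, updating `config.num_hidden_layers`, and, as an assumption to verify against your `transformers` version, renumbering each layer's `self_attn.layer_idx` (used to index the KV cache) and slicing `config.layer_types` if the config defines it.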


In [16]:
# Verify the manual pruning
manual_params = count_parameters(pruned_model)
manual_layers = len(pruned_model.model.layers)

print(f"Manual pruning results:")
print(f"  - New layer count: {manual_layers}")
print(f"  - New parameter count: {manual_params:,}")
print(f"  - Parameters removed: {original_params - manual_params:,}")
print(f"  - Reduction: {((original_params - manual_params) / original_params * 100):.2f}%")

Manual pruning results:
  - New layer count: 16
  - New parameter count: 256,950,912
  - Parameters removed: 11,147,264
  - Reduction: 4.16%
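
Removing 2 of 18 layers drops 11.1% of the decoder layers but only 4.16% of the parameters. The arithmetic below reproduces the printed counts from the shapes in the model printout and shows why: the 262,144-token embedding matrix dominates this small model. Three details are assumptions, since the printout above is truncated: each Gemma 3 decoder layer has four 640-wide RMSNorms (input, post-attention, pre-feedforward, post-feedforward), there is one final 640-wide norm, and `lm_head` shares its weights with `embed_tokens`. The fact that the total matches the printed 268,098,176 supports them.

```python
# Shapes copied from the printed Gemma3ForCausalLM structure above
hidden, q_out, kv_out, mlp = 640, 1024, 256, 2048
vocab, head_dim, n_layers = 262_144, 256, 18

attn = hidden * q_out + 2 * hidden * kv_out + q_out * hidden  # q, k, v, o projections
qk_norms = 2 * head_dim                                       # q_norm + k_norm
mlp_params = 3 * hidden * mlp                                 # gate, up, down
layer_norms = 4 * hidden     # assumption: four RMSNorms per decoder layer
per_layer = attn + qk_norms + mlp_params + layer_norms

embedding = vocab * hidden   # assumption: lm_head tied to embed_tokens
final_norm = hidden          # assumption: one final RMSNorm
total = embedding + n_layers * per_layer + final_norm
removed = 2 * per_layer

print(f"Per layer:       {per_layer:,}")                         # 5,573,632
print(f"Total:           {total:,}")                             # 268,098,176
print(f"Removed:         {removed:,} ({removed / total:.2%})")   # 11,147,264 (4.16%)
print(f"Embedding share: {embedding / total:.1%}")               # 62.6%
```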


In [17]:
pruned_model

Gemma3ForCausalLM(
  (model): Gemma3TextModel(
    (embed_tokens): Gemma3TextScaledWordEmbedding(262144, 640, padding_idx=0)
    (layers): ModuleList(
      (0-15): 16 x Gemma3DecoderLayer(
        (self_attn): Gemma3Attention(
          (q_proj): Linear(in_features=640, out_features=1024, bias=False)
          (k_proj): Linear(in_features=640, out_features=256, bias=False)
          (v_proj): Linear(in_features=640, out_features=256, bias=False)
          (o_proj): Linear(in_features=1024, out_features=640, bias=False)
          (q_norm): Gemma3RMSNorm((256,), eps=1e-06)
          (k_norm): Gemma3RMSNorm((256,), eps=1e-06)
        )
        (mlp): Gemma3MLP(
          (gate_proj): Linear(in_features=640, out_features=2048, bias=False)
          (up_proj): Linear(in_features=640, out_features=2048, bias=False)
          (down_proj): Linear(in_features=2048, out_features=640, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): Gemma3RMSNorm((640,), eps

## Basic generation test on the pruned model

In [17]:
print("Unloading the original model from VRAM to ensure isolated measurement...")
del model
clear_gpu_cache()

Unloading the original model from VRAM to ensure isolated measurement...


In [18]:
# A simple prompt to test basic knowledge and coherence
generated_text_pruned = generate_text(pruned_model, tokenizer, TEST_PROMPT)
print(f"Prompt: '{TEST_PROMPT}'")
print(f"Generated Text Pruned: '{generated_text_pruned}'")
print(f"Generated Text Base  : '{generated_text}'")

Prompt: 'Paris is the capital of'
Generated Text Pruned: 'Paris is the capital of France and one of the largest cities in Europe. It occupies approximately 2.5 million hectares of land surrounded by mountains and forests. Parisians love to travel abroad because they enjoy sightseeing tours abroad. Tourists visiting Paris visit museums, monuments, theaters,'
Generated Text Base  : 'Paris is the capital of France. It is located in the middle of the country and has a population of about 10 million people. Paris is a city with a rich history and culture. The city is known for its beautiful architecture, art, and history. There are'


## Pruned Evaluation

In [19]:
## Evaluation with lm-eval
# Run evaluation on the pruned model object we have in memory
print("--- Evaluating Pruned Model ---")

# The 'pruned_model' variable holds the model we modified in Section 2.2
pruned_results = model_evaluation(pruned_model, tokenizer, benchmark_tasks, limit=100)



--- Evaluating Pruned Model ---
Starting lm-eval on model 'google/gemma-3-270m' for tasks: ['arc_easy', 'winogrande', 'boolq', 'lambada_openai']


{
  "arc_easy": "0.4600",
  "boolq": "0.4500",
  "lambada_openai": "0.3400",
  "winogrande": "0.4800"
}


In [20]:
from IPython.display import display, Markdown

# Calculate parameter reduction
pruned_params = count_parameters(pruned_model)
param_reduction_pct = (original_params - pruned_params) / original_params

# Helper function to calculate percentage change
def calculate_change(old, new):
    old, new = float(old), float(new)
    if old == 0:
        return "N/A"
    return f"{(new - old) / old:+.2%}"

# Create the comparison table
model_id = MODEL_NAME.split('/')[-1]  # Get just the model name without org

markdown_table = f"""
## Performance Impact Analysis

| Metric                  | Original Model (`{model_id}`) | Pruned Model (-{LAYERS_TO_REMOVE} Layers) | Change          |
| :---------------------- | :----------------------------- | :----------------------- | :-------------- |
| **Parameters**          | {original_params:,}              | {pruned_params:,}          | **-{param_reduction_pct:.2%}** |
| **arc_easy** (acc)      | {baseline_results['arc_easy']}   | {pruned_results['arc_easy']} | {calculate_change(baseline_results['arc_easy'], pruned_results['arc_easy'])} |
| **winogrande** (acc)    | {baseline_results['winogrande']} | {pruned_results['winogrande']} | {calculate_change(baseline_results['winogrande'], pruned_results['winogrande'])} |
| **boolq** (acc)         | {baseline_results['boolq']}      | {pruned_results['boolq']}      | {calculate_change(baseline_results['boolq'], pruned_results['boolq'])} |
| **lambada_openai** (acc)| {baseline_results['lambada_openai']} | {pruned_results['lambada_openai']} | {calculate_change(baseline_results['lambada_openai'], pruned_results['lambada_openai'])} |

### Key Insights:
- **Efficiency Gain**: {param_reduction_pct:.1%} parameter reduction → faster inference & lower memory
- **Performance Impact**: See individual benchmark changes above
- **Next Step**: Knowledge recovery through fine-tuning (Section 2.3)
"""

display(Markdown(markdown_table))


## Performance Impact Analysis

| Metric                  | Original Model (`gemma-3-270m`) | Pruned Model (-2 Layers) | Change          |
| :---------------------- | :----------------------------- | :----------------------- | :-------------- |
| **Parameters**          | 268,098,176              | 256,950,912          | **-4.16%** |
| **arc_easy** (acc)      | 0.5500   | 0.4600 | -16.36% |
| **winogrande** (acc)    | 0.6000 | 0.4800 | -20.00% |
| **boolq** (acc)         | 0.6600      | 0.4500      | -31.82% |
| **lambada_openai** (acc)| 0.4200 | 0.3400 | -19.05% |

### Key Insights:
- **Efficiency Gain**: 4.2% parameter reduction → faster inference & lower memory
- **Performance Impact**: See individual benchmark changes above
- **Next Step**: Knowledge recovery through fine-tuning (Section 2.3)


In [23]:
print(f"Measuring inference speed across {len(speed_test_prompts)} prompts, 5 runs each...")

# Measure pruned model
pruned_timing = measure_inference_time(pruned_model, tokenizer, speed_test_prompts)

# Calculate speedup
speedup = original_timing['mean_time'] / pruned_timing['mean_time']
time_reduction_pct = (original_timing['mean_time'] - pruned_timing['mean_time']) / original_timing['mean_time']

print(f"\n📊 Inference Speed Results:")
print(f"   Original model: {original_timing['mean_time']:.3f}s ± {original_timing['std_time']:.3f}s")
print(f"   Pruned model:   {pruned_timing['mean_time']:.3f}s ± {pruned_timing['std_time']:.3f}s")
print(f"   Speedup:        {speedup:.2f}x ({time_reduction_pct:+.1%} faster)")

Measuring inference speed across 3 prompts, 5 runs each...
Warming up base model...

📊 Inference Speed Results:
   Original model: 4.651s ± 0.213s
   Pruned model:   4.181s ± 0.163s
   Speedup:        1.11x (+10.1% faster)
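
The same latency measurement can be expressed in more than one way, and the numbers differ slightly. The printed "+10.1%" is the share of wall-clock time saved; the 1.11x speedup corresponds to roughly 11% more generations per second. The sketch below (`speed_summary` is a hypothetical helper) derives all three views from the two mean times printed above, which are rounded to milliseconds.

```python
def speed_summary(t_original: float, t_pruned: float) -> dict:
    """Express the same latency measurement in three common ways."""
    speedup = t_original / t_pruned
    return {
        "speedup": speedup,                           # times faster
        "time_reduction": 1 - t_pruned / t_original,  # share of wall-clock time saved
        "throughput_gain": speedup - 1,               # extra generations per unit time
    }


s = speed_summary(4.651, 4.181)  # mean times printed above
print(f"Speedup:         {s['speedup']:.2f}x")        # 1.11x
print(f"Time reduction:  {s['time_reduction']:.1%}")  # 10.1%
print(f"Throughput gain: {s['throughput_gain']:.1%}") # 11.2%
```

Reporting the time reduction keeps the Change column comparable with the parameter reduction; reporting the speedup is more natural when the goal is serving more requests on the same GPU.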


In [24]:
# Update the comparison table to include measured speed
markdown_table = f"""
## Performance Impact Analysis

| Metric                  | Original Model (`{model_id}`) | Pruned Model (-{LAYERS_TO_REMOVE} Layers) | Change          |
| :---------------------- | :----------------------------- | :----------------------- | :-------------- |
| **Parameters**          | {original_params:,}              | {pruned_params:,}          | **-{param_reduction_pct:.2%}** |
| **Inference Time**      | {original_timing['mean_time']:.3f}s | {pruned_timing['mean_time']:.3f}s | **{time_reduction_pct:+.1%}** |
| **arc_easy** (acc)      | {baseline_results['arc_easy']}   | {pruned_results['arc_easy']} | {calculate_change(baseline_results['arc_easy'], pruned_results['arc_easy'])} |
| **winogrande** (acc)    | {baseline_results['winogrande']} | {pruned_results['winogrande']} | {calculate_change(baseline_results['winogrande'], pruned_results['winogrande'])} |
| **boolq** (acc)         | {baseline_results['boolq']}      | {pruned_results['boolq']}      | {calculate_change(baseline_results['boolq'], pruned_results['boolq'])} |
| **lambada_openai** (acc)| {baseline_results['lambada_openai']} | {pruned_results['lambada_openai']} | {calculate_change(baseline_results['lambada_openai'], pruned_results['lambada_openai'])} |

### Key Insights:
- **Efficiency Gain**: {param_reduction_pct:.1%} parameter reduction + {speedup:.2f}x speedup
- **Performance Impact**: See individual benchmark changes above
- **Trade-off**: {time_reduction_pct:.1%} faster inference vs accuracy changes
"""
display(Markdown(markdown_table))


## Performance Impact Analysis

| Metric                  | Original Model (`gemma-3-270m`) | Pruned Model (-2 Layers) | Change          |
| :---------------------- | :----------------------------- | :----------------------- | :-------------- |
| **Parameters**          | 268,098,176              | 256,950,912          | **-4.16%** |
| **Inference Time**      | 4.651s | 4.181s | **+10.1%** |
| **arc_easy** (acc)      | 0.5500   | 0.4600 | -16.36% |
| **winogrande** (acc)    | 0.6000 | 0.4800 | -20.00% |
| **boolq** (acc)         | 0.6600      | 0.4500      | -31.82% |
| **lambada_openai** (acc)| 0.4200 | 0.3400 | -19.05% |

### Key Insights:
- **Efficiency Gain**: 4.2% parameter reduction + 1.11x speedup
- **Performance Impact**: See individual benchmark changes above
- **Trade-off**: 10.1% faster inference vs accuracy changes



## Llama-3.2-1B Performance Impact Analysis

| Metric                  | Original Model (`Llama-3.2-1B`) | Pruned Model (-2 Layers) | Change          |
| :---------------------- | :----------------------------- | :----------------------- | :-------------- |
| **Parameters**          | 1,235,814,400              | 1,114,171,392          | **-9.84%** |
| **Inference Time**      | 6.635s | 4.752s | **+28.4%** |
| **arc_easy** (acc)      | 0.6600   | 0.4800 | -27.27% |
| **winogrande** (acc)    | 0.6000 | 0.5400 | -10.00% |
| **boolq** (acc)         | 0.6700      | 0.7000      | +4.48% |
| **lambada_openai** (acc)| 0.5700 | 0.1700 | -70.18% |

### Key Insights:
- **Efficiency Gain**: 9.8% parameter reduction + 1.40x speedup
- **Performance Impact**: See individual benchmark changes above
- **Trade-off**: 28.4% faster inference vs accuracy changes

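Removing two layers costs Llama-3.2-1B a much larger share of its parameters than gemma-3-270m, because its 128,256 × 2048 embedding matrix is only about 21% of the model. The sketch below recomputes the counts for a Llama-style decoder; `decoder_only_param_count` is a hypothetical helper, and the configuration values are assumed from the public `meta-llama/Llama-3.2-1B` config (hidden size 2048, intermediate size 8192, 16 layers, 32 attention heads, 8 key/value heads, tied embeddings, no projection biases). It reproduces the original 1,235,814,400 parameters and the 9.84% reduction listed in the table.

```python
def decoder_only_param_count(vocab: int, hidden: int, intermediate: int, n_layers: int,
                             n_heads: int, n_kv_heads: int) -> tuple:
    """(total, per_layer) for a Llama-style model with tied embeddings."""
    head_dim = hidden // n_heads
    attn = 2 * hidden * hidden + 2 * hidden * n_kv_heads * head_dim  # q/o + k/v projections
    mlp = 3 * hidden * intermediate                                  # gate, up, down
    norms = 2 * hidden                                               # two RMSNorms per layer
    per_layer = attn + mlp + norms
    total = vocab * hidden + n_layers * per_layer + hidden           # embeddings + layers + final norm
    return total, per_layer


# Assumed values from the public meta-llama/Llama-3.2-1B configuration
total, per_layer = decoder_only_param_count(vocab=128_256, hidden=2048, intermediate=8192,
                                            n_layers=16, n_heads=32, n_kv_heads=8)
pruned = total - 2 * per_layer
print(f"Original:  {total:,}")                        # 1,235,814,400
print(f"Pruned:    {pruned:,}")                       # 1,114,171,392
print(f"Reduction: {(total - pruned) / total:.2%}")   # 9.84%
```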