<a href="https://colab.research.google.com/github/peremartra/Rearchitecting-LLMs/blob/main/CH04/CH04_NB0x_Cosine_Similarity.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Rearchitecting LLMs
## Surgical Optimization for Hyper-Efficient Models


### Chapter 4: Depth Pruning: Building Smaller and Faster Models
by [Pere Martra](https://github.com/peremartra)
_____
Colab Environment: GPU T4

Models:
* Qwen3-0.6B
_____

In this notebook we explore how to evaluate the contribution of different transformer blocks to the LLM’s objective using a dataset.

To do this, we use cosine similarity between the input and the output of the transformer block. The lower the similarity, the greater the modification that block has introduced to the data.

Blocks with higher similarity between input and output will be the candidates to be removed from the model.
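
As a reference for the rest of the notebook, the similarity used here can be written as follows. The symbols $x$ and $y$ are our own notation (an assumption, not taken from any library) for the flattened input and output hidden states of one block:

$$
\mathrm{sim}(x, y) = \frac{x \cdot y}{\lVert x \rVert \, \lVert y \rVert}
$$

The value lies between $-1$ and $1$, and it equals $1$ when the block only rescales its input without changing its direction.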


# Setting up notebook

In [42]:
!pip install -q \
      "torch==2.8.0+cu126" \
      "transformers==4.55.4" \
      "accelerate==1.10.1" \
      "lm_eval==0.4.9.1" \
      "sentencepiece==0.2.1" \
      "sentence-transformers==5.1.0" \
      "optipfair==0.1.4"

In [43]:
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np

In [44]:
# Check device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

Using device: cuda
GPU: Tesla T4


## Load Model

In [45]:
MODEL_NAME = 'Qwen/Qwen3-0.6B'
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16,
    device_map="auto"
)
model.eval()

Qwen3ForCausalLM(
  (model): Qwen3Model(
    (embed_tokens): Embedding(151936, 1024)
    (layers): ModuleList(
      (0-27): 28 x Qwen3DecoderLayer(
        (self_attn): Qwen3Attention(
          (q_proj): Linear(in_features=1024, out_features=2048, bias=False)
          (k_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (v_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (o_proj): Linear(in_features=2048, out_features=1024, bias=False)
          (q_norm): Qwen3RMSNorm((128,), eps=1e-06)
          (k_norm): Qwen3RMSNorm((128,), eps=1e-06)
        )
        (mlp): Qwen3MLP(
          (gate_proj): Linear(in_features=1024, out_features=3072, bias=False)
          (up_proj): Linear(in_features=1024, out_features=3072, bias=False)
          (down_proj): Linear(in_features=3072, out_features=1024, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen3RMSNorm((1024,), eps=1e-06)
        (post_attention_layernorm): Qwe

## Load Datasets

In [62]:
RECOVERY_SAMPLES = 100
BATCH_SIZE = 8
MAX_LENGTH = 512

We’re going to use two different datasets to visualize how some layers are more important than others depending on the data being used.

* **Wikitext**: Contains highly complex text. To process this kind of text, the model needs to rely on its deeper layers to understand context, semantic relations, and complex grammatical structures.
* **SMS Spam**: A completely different dataset, made up of short sentences with simple and direct language. It doesn’t require deep semantic understanding.


In [71]:
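print("Loading WikiText-2 (long, encyclopedic sentences)...")
dataset1 = load_dataset('wikitext', 'wikitext-2-raw-v1', split=f'train[:{RECOVERY_SAMPLES}]')

print("Loading SMS Spam (very short sentences, messages)...")
dataset2 = load_dataset('sms_spam', split=f'train[:{RECOVERY_SAMPLES}]')

Loading WikiText-2 (long, encyclopedic sentences)...
Loading SMS Spam (very short sentences, messages)...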


In [65]:
def prepare_dataset(dataset, text_field='text'):
  def tokenize_function(examples):
      if text_field in examples:
          texts = examples[text_field]
      elif 'sms' in examples:  # SMS dataset
          texts = examples['sms']
      elif 'text' in examples:
          texts = examples['text']
      else:
          texts = examples[list(examples.keys())[0]]  # First available field

      return tokenizer(
          texts,
          truncation=True,
          padding='max_length',
          max_length=MAX_LENGTH,
          return_tensors='pt'
      )

  tokenized = dataset.map(tokenize_function, batched=True, remove_columns=dataset.column_names)
  tokenized.set_format(type='torch', columns=['input_ids', 'attention_mask'])
  return DataLoader(tokenized, batch_size=BATCH_SIZE, shuffle=False)


In [None]:
# Create dataloaders
dataloader1 = prepare_dataset(dataset1)  # WikiText (long)
dataloader2 = prepare_dataset(dataset2)  # SMS (short)

# Calculate Layer Importance using Cosine Similarity

To decide which layers to remove, we measure their contribution using cosine similarity. This metric suits the task because it measures the change in direction between the input and output vectors of a layer while ignoring their magnitude.

This gives us a normalized similarity that we convert into an importance score (1 - similarity).

A score close to zero identifies a “passive” layer that barely alters the information, making it a strong candidate for removal.
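
To make the score concrete, here is a minimal sketch in plain Python, using only the standard library. The helper names (`cosine_similarity`, `importance`) and the toy vectors are illustrative assumptions and do not appear elsewhere in the notebook. It shows that rescaling a vector leaves the importance near zero, while changing its direction raises it.

```python
import math

def cosine_similarity(u, v):
    # Dot product divided by the product of the vector norms.
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

def importance(layer_input, layer_output):
    # Importance as defined in the text: 1 - similarity.
    return 1.0 - cosine_similarity(layer_input, layer_output)

# Toy vectors standing in for a block's input and output (hypothetical values).
block_input = [1.0, 2.0, 3.0]
rescaled = [10.0, 20.0, 30.0]  # same direction, larger magnitude
rotated = [3.0, -2.0, 1.0]     # same magnitude, different direction

passive_score = importance(block_input, rescaled)
active_score = importance(block_input, rotated)
print(f"rescaled: {passive_score:.4f}  rotated: {active_score:.4f}")
```

Because cosine similarity ranges from -1 to 1, the importance score can in principle range from 0 to 2, although values above 1 only occur when a block reverses the direction of its input.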


## Setup Model Hooks


To capture the input and output of each layer we use PyTorch hooks, which let us observe the model’s internal behavior during a forward pass.
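
Before applying hooks to the real model, it may help to see the mechanism on a toy example. The sketch below uses a small stack of `nn.Linear` layers as a hypothetical stand-in for transformer blocks. The variable and function names are ours, chosen for illustration only. It registers a forward pre-hook to capture each layer's input and a forward hook to capture its output.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy stand-in for a stack of transformer blocks (hypothetical, not Qwen3).
toy_layers = nn.ModuleList([nn.Linear(4, 4) for _ in range(3)])

captured_inputs, captured_outputs = {}, {}

def make_pre_hook(idx):
    # A forward pre-hook is called with (module, args) before forward() runs.
    def pre_hook(module, args):
        captured_inputs[idx] = args[0].detach()
    return pre_hook

def make_post_hook(idx):
    # A forward hook is called with (module, args, output) after forward() runs.
    def post_hook(module, args, output):
        captured_outputs[idx] = output.detach()
    return post_hook

handles = []
for idx, layer in enumerate(toy_layers):
    handles.append(layer.register_forward_pre_hook(make_pre_hook(idx)))
    handles.append(layer.register_forward_hook(make_post_hook(idx)))

x = torch.randn(2, 4)
with torch.no_grad():
    hidden = x
    for layer in toy_layers:
        hidden = layer(hidden)

# Remove the hooks so later forward passes are not intercepted.
for handle in handles:
    handle.remove()

print(sorted(captured_inputs), sorted(captured_outputs))
```

In this toy stack the output captured for one layer is the input captured for the next, which is the same chain of hidden states we want to inspect in the real model.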


In [50]:
def setup_layer_hooks(model):
    """
    Register hooks to capture input/output of each transformer layer
    Returns: hooks list and storage dictionaries
    """
    num_layers = len(model.model.layers)
    layer_inputs = {}
    layer_outputs = {}
    hooks = []

    def create_input_hook(layer_idx):
        def hook(module, input):
            if isinstance(input, tuple) and len(input) > 0:
                layer_inputs[layer_idx] = input[0].detach()
        return hook

    def create_output_hook(layer_idx):
        def hook(module, input, output):
            if isinstance(output, tuple) and len(output) > 0:
                layer_outputs[layer_idx] = output[0].detach()
            else:
                layer_outputs[layer_idx] = output.detach()
        return hook

    # Register hooks for each layer
    for i, layer in enumerate(model.model.layers):
        hooks.append(layer.register_forward_pre_hook(create_input_hook(i)))
        hooks.append(layer.register_forward_hook(create_output_hook(i)))

    return hooks, layer_inputs, layer_outputs, num_layers

## Calculate Cosine Similarity

In [51]:
def calculate_cosine_importance(input_tensor, output_tensor, layer_idx, is_first_batch=False):
    """
    Calculate importance score using cosine similarity between input and output tensors
    Returns: importance score, 1 - mean cosine similarity (0.0 if no valid samples)
    """
    # Validate tensor dimensions
    if input_tensor.numel() == 0 or output_tensor.numel() == 0:
        return 0.0

    try:
        # Flatten tensors: [batch_size, features]
        input_flat = input_tensor.view(input_tensor.size(0), -1)
        output_flat = output_tensor.view(output_tensor.size(0), -1)

        # Filter out non-finite values
        input_valid_mask = torch.all(torch.isfinite(input_flat), dim=1)
        output_valid_mask = torch.all(torch.isfinite(output_flat), dim=1)
        valid_mask = input_valid_mask & output_valid_mask

        if not valid_mask.any():
            if is_first_batch:
                print(f"Warning: Layer {layer_idx} has all inf/nan samples")
            return 0.0

        # Use only valid samples
        input_valid = input_flat[valid_mask]
        output_valid = output_flat[valid_mask]

        # Calculate cosine similarity
        similarity = F.cosine_similarity(input_valid, output_valid, dim=1)

        # Filter finite similarities and calculate importance
        finite_similarities = similarity[torch.isfinite(similarity)]
        if len(finite_similarities) == 0:
            return 0.0

        importance = 1 - finite_similarities.mean().item()

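        # Debug info for first batch only
        if is_first_batch:
            valid_samples = valid_mask.sum().item()
            avg_similarity = finite_similarities.mean().item()
            print(f"Layer {layer_idx}: {valid_samples}/{input_flat.size(0)} valid samples, "
                  f"avg_similarity={avg_similarity:.4f}, importance={importance:.6f}")

        return importance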

    except Exception as e:
        if is_first_batch:
            print(f"Error in layer {layer_idx}: {e}")
        return 0.0


Next, we aggregate the per-batch results.

In [52]:
def aggregate_importance_scores(layer_scores):
    """
    Aggregate importance scores across all batches
    Returns: dictionary with final averaged scores per layer
    """
    final_scores = {}
    for layer_idx, scores in layer_scores.items():
        if scores:
            # Filter out invalid scores
            valid_scores = [s for s in scores if not (np.isnan(s) or np.isinf(s))]
            final_scores[layer_idx] = np.mean(valid_scores) if valid_scores else 0.0
        else:
            final_scores[layer_idx] = 0.0

    return final_scores


This function takes the importance scores collected from all data batches for each layer. Then, it computes the average of these scores to get a single final consolidated importance score for each layer of the model.
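
The averaging step can be sketched without NumPy. The example below is a simplified illustration with hypothetical per-batch scores. The helper `mean_of_finite` is our own name, not a function from the notebook. It skips NaN and infinite values and falls back to 0.0 when nothing valid remains.

```python
import math

def mean_of_finite(scores):
    # Keep only finite values, then average; fall back to 0.0 if none remain.
    finite = [s for s in scores if math.isfinite(s)]
    return sum(finite) / len(finite) if finite else 0.0

# Hypothetical per-batch scores for three layers.
per_batch_scores = {
    0: [0.90, 0.88, float("nan")],  # one invalid batch is ignored
    1: [float("inf")],              # no valid batch at all
    2: [],                          # layer never captured
}

final = {layer: mean_of_finite(s) for layer, s in per_batch_scores.items()}
print(final)
```

One consequence of the 0.0 fallback is that a layer with no valid measurements looks the same as a genuinely passive layer, so it is worth checking the debug output before pruning based on a score of exactly zero.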

In [53]:
def calculate_layer_importance_cosine(model, dataloader, device):
    """
    Calculate layer importance using cosine similarity between input/output representations

    Args:
        model: Transformer model
        dataloader: DataLoader with tokenized text data
        device: torch device (cuda/cpu)

    Returns:
        dict: Layer importance scores {layer_idx: importance_score}
    """
    # Setup hooks and storage
    hooks, layer_inputs, layer_outputs, num_layers = setup_layer_hooks(model)
    layer_importance_scores = {i: [] for i in range(num_layers)}
    print(f"Calculating cosine-based importance for {num_layers} layers...")

    with torch.no_grad():
        for batch_idx, batch in enumerate(tqdm(dataloader, desc="Processing batches")):
            inputs = {k: v.to(device) for k, v in batch.items()}

            # Forward pass to trigger hooks
            model(**inputs)

            # Calculate importance for each layer
            for layer_idx in range(num_layers):
                if layer_idx not in layer_inputs or layer_idx not in layer_outputs:
                    layer_importance_scores[layer_idx].append(0.0)
                    continue

                input_tensor = layer_inputs[layer_idx]
                output_tensor = layer_outputs[layer_idx]

                importance = calculate_cosine_importance(
                    input_tensor, output_tensor, layer_idx,
                    is_first_batch=(batch_idx == 0)
                )
                layer_importance_scores[layer_idx].append(importance)

            # Clear storage for next batch
            layer_inputs.clear()
            layer_outputs.clear()

    # Cleanup hooks
    for hook in hooks:
        hook.remove()

    # Aggregate final scores
    final_scores = aggregate_importance_scores(layer_importance_scores)

    return final_scores

# Obtaining & Studying results

In [54]:
def print_sorted_importance(scores):
    for i, (layer, score) in enumerate(sorted(scores.items(), key=lambda x: float(x[1]), reverse=True), 1):
        print(f"{i:2d}. Layer {layer:2d}: {float(score):.6f}")

In [73]:
wiki_importance = calculate_layer_importance_cosine(model, dataloader1, device)

Calculating cosine-based importance for 28 layers...


Processing batches:   8%|▊         | 1/13 [00:00<00:06,  1.90it/s]

Layer 0: 8/8 valid samples, avg_similarity=0.1332, importance=0.866821
Layer 1: 8/8 valid samples, avg_similarity=0.6338, importance=0.366211
Layer 2: 8/8 valid samples, avg_similarity=0.3281, importance=0.671875
Layer 3: 8/8 valid samples, avg_similarity=0.9316, importance=0.068359
Layer 4: 8/8 valid samples, avg_similarity=0.9204, importance=0.079590
Layer 5: 8/8 valid samples, avg_similarity=0.9268, importance=0.073242
Layer 6: 8/8 valid samples, avg_similarity=0.9277, importance=0.072266
Layer 7: 8/8 valid samples, avg_similarity=0.9492, importance=0.050781
Layer 8: 8/8 valid samples, avg_similarity=0.9390, importance=0.061035
Layer 9: 8/8 valid samples, avg_similarity=0.9243, importance=0.075684
Layer 10: 8/8 valid samples, avg_similarity=0.9336, importance=0.066406
Layer 11: 8/8 valid samples, avg_similarity=0.9224, importance=0.077637
Layer 12: 8/8 valid samples, avg_similarity=0.9331, importance=0.066895
Layer 13: 8/8 valid samples, avg_similarity=0.9263, importance=0.073730
La

Processing batches: 100%|██████████| 13/13 [00:05<00:00,  2.19it/s]


In [74]:
print_sorted_importance(wiki_importance)

 1. Layer  0: 0.890395
 2. Layer  2: 0.771541
 3. Layer  1: 0.307580
 4. Layer 27: 0.173190
 5. Layer 23: 0.082933
 6. Layer 25: 0.074669
 7. Layer 24: 0.072416
 8. Layer 22: 0.069261
 9. Layer 26: 0.063664
10. Layer 21: 0.062763
11. Layer 17: 0.060885
12. Layer 19: 0.054763
13. Layer 16: 0.051645
14. Layer  4: 0.051382
15. Layer 11: 0.051194
16. Layer  9: 0.049692
17. Layer 13: 0.048528
18. Layer  5: 0.047476
19. Layer  6: 0.047138
20. Layer 14: 0.045335
21. Layer 15: 0.045335
22. Layer  3: 0.044283
23. Layer 12: 0.044246
24. Layer 10: 0.044171
25. Layer 20: 0.042405
26. Layer  8: 0.040152
27. Layer 18: 0.037861
28. Layer  7: 0.033391


In [75]:
sms_importance = calculate_layer_importance_cosine(model, dataloader2, device)

Calculating cosine-based importance for 28 layers...


Processing batches:   8%|▊         | 1/13 [00:00<00:06,  1.96it/s]

Layer 0: 8/8 valid samples, avg_similarity=0.0981, importance=0.901855
Layer 1: 8/8 valid samples, avg_similarity=0.7793, importance=0.220703
Layer 2: 8/8 valid samples, avg_similarity=0.0504, importance=0.949585
Layer 3: 8/8 valid samples, avg_similarity=0.9985, importance=0.001465
Layer 4: 8/8 valid samples, avg_similarity=0.9980, importance=0.001953
Layer 5: 8/8 valid samples, avg_similarity=0.9980, importance=0.001953
Layer 6: 8/8 valid samples, avg_similarity=0.9976, importance=0.002441
Layer 7: 8/8 valid samples, avg_similarity=0.9971, importance=0.002930
Layer 8: 8/8 valid samples, avg_similarity=0.9966, importance=0.003418
Layer 9: 8/8 valid samples, avg_similarity=0.9961, importance=0.003906
Layer 10: 8/8 valid samples, avg_similarity=0.9956, importance=0.004395
Layer 11: 8/8 valid samples, avg_similarity=0.9951, importance=0.004883
Layer 12: 8/8 valid samples, avg_similarity=0.9961, importance=0.003906
Layer 13: 8/8 valid samples, avg_similarity=0.9961, importance=0.003906
La

Processing batches: 100%|██████████| 13/13 [00:05<00:00,  2.21it/s]


In [76]:
print_sorted_importance(sms_importance)

 1. Layer  2: 0.948648
 2. Layer  0: 0.896963
 3. Layer  1: 0.277306
 4. Layer 27: 0.147085
 5. Layer 21: 0.025203
 6. Layer 24: 0.023287
 7. Layer 25: 0.020959
 8. Layer 26: 0.017353
 9. Layer 22: 0.017315
10. Layer 20: 0.016752
11. Layer 19: 0.015925
12. Layer 23: 0.014836
13. Layer 18: 0.008977
14. Layer 17: 0.006686
15. Layer 16: 0.006197
16. Layer 15: 0.005334
17. Layer 14: 0.005258
18. Layer 11: 0.004845
19. Layer 10: 0.004770
20. Layer 12: 0.004094
21. Layer  9: 0.003906
22. Layer 13: 0.003906
23. Layer  8: 0.003418
24. Layer  7: 0.002967
25. Layer  6: 0.002817
26. Layer  5: 0.002329
27. Layer  4: 0.001953
28. Layer  3: 0.001465


In [79]:
def compare_importance(scores1, scores2, name1="Dataset1", name2="Dataset2"):
    print(f"{'Layer':<5} {name1:<10} {name2:<10} {'Diff':<8}")
    print("-" * 35)
    for layer in sorted(scores1.keys()):
        s1, s2 = float(scores1[layer]), float(scores2[layer])
        diff = abs(s1 - s2)
        print(f"{layer:<5} {s1:<10.4f} {s2:<10.4f} {diff:<8.4f}")

In [85]:
compare_importance(wiki_importance, sms_importance, "wiki", "SMS")

Layer wiki       SMS        Diff    
-----------------------------------
0     0.8904     0.8970     0.0066  
1     0.3076     0.2773     0.0303  
2     0.7715     0.9486     0.1771  
3     0.0443     0.0015     0.0428  
4     0.0514     0.0020     0.0494  
5     0.0475     0.0023     0.0451  
6     0.0471     0.0028     0.0443  
7     0.0334     0.0030     0.0304  
8     0.0402     0.0034     0.0367  
9     0.0497     0.0039     0.0458  
10    0.0442     0.0048     0.0394  
11    0.0512     0.0048     0.0463  
12    0.0442     0.0041     0.0402  
13    0.0485     0.0039     0.0446  
14    0.0453     0.0053     0.0401  
15    0.0453     0.0053     0.0400  
16    0.0516     0.0062     0.0454  
17    0.0609     0.0067     0.0542  
18    0.0379     0.0090     0.0289  
19    0.0548     0.0159     0.0388  
20    0.0424     0.0168     0.0257  
21    0.0628     0.0252     0.0376  
22    0.0693     0.0173     0.0519  
23    0.0829     0.0148     0.0681  
24    0.0724     0.0233     0.0491  
25

The comparison reveals a clear pattern. The first layers (0–2) and the last one (27) are important in both datasets, suggesting that they perform fundamental functions, such as the initial processing of the input and the consolidation of the output.

The key difference lies in the intermediate layers (roughly 3–26). On the complex Wikitext text these layers show a small but measurable contribution (importance around 0.03–0.08), whereas on the simple SMS text their contribution is close to zero (roughly 0.001–0.025), so they behave as “passive” layers. This indicates that the importance of a layer depends on the data being processed, which supports “depth pruning” as a strategy for building more efficient models for specialized tasks.


## Create the pruned model for the SMS dataset by removing the 4 least relevant layers

We can use this information to remove the least important Transformer blocks from the model for use with the SMS dataset.
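
The layers to remove can be read directly from the SMS ranking. As a minimal sketch, the snippet below copies the SMS scores printed earlier in this notebook into a dictionary and picks the lowest-scoring layers. The helper `least_important_layers` is a hypothetical name introduced here for illustration.

```python
# SMS importance scores copied from the sorted output above.
sms_scores = {
    0: 0.896963, 1: 0.277306, 2: 0.948648, 3: 0.001465,
    4: 0.001953, 5: 0.002329, 6: 0.002817, 7: 0.002967,
    8: 0.003418, 9: 0.003906, 10: 0.004770, 11: 0.004845,
    12: 0.004094, 13: 0.003906, 14: 0.005258, 15: 0.005334,
    16: 0.006197, 17: 0.006686, 18: 0.008977, 19: 0.015925,
    20: 0.016752, 21: 0.025203, 22: 0.017315, 23: 0.014836,
    24: 0.023287, 25: 0.020959, 26: 0.017353, 27: 0.147085,
}

def least_important_layers(scores, k):
    # Rank layers by ascending importance; ties are broken by layer index.
    ranked = sorted(scores, key=lambda layer: (scores[layer], layer))
    # Return the k lowest-scoring layers in index order.
    return sorted(ranked[:k])

layers_to_remove = least_important_layers(sms_scores, k=4)
print(layers_to_remove)  # [3, 4, 5, 6]
```

The same helper could be applied to `wiki_importance` to see how the candidate list changes with the dataset.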

In [89]:
from optipfair import prune_model

sms_model, stats = prune_model(
    model=model,
    pruning_type="DEPTH",
    layer_indices=[3, 4, 5, 6],
    show_progress=True,
    return_stats=True
)

Removing layers: 100%|██████████| 20/20 [00:00<00:00, 193732.29it/s]


The new model has only 24 transformer blocks.

In [93]:
print(stats['percentage_reduction'])

13.38227543762604


In [87]:
print(sms_model)

Qwen3ForCausalLM(
  (model): Qwen3Model(
    (embed_tokens): Embedding(151936, 1024)
    (layers): ModuleList(
      (0-23): 24 x Qwen3DecoderLayer(
        (self_attn): Qwen3Attention(
          (q_proj): Linear(in_features=1024, out_features=2048, bias=False)
          (k_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (v_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (o_proj): Linear(in_features=2048, out_features=1024, bias=False)
          (q_norm): Qwen3RMSNorm((128,), eps=1e-06)
          (k_norm): Qwen3RMSNorm((128,), eps=1e-06)
        )
        (mlp): Qwen3MLP(
          (gate_proj): Linear(in_features=1024, out_features=3072, bias=False)
          (up_proj): Linear(in_features=1024, out_features=3072, bias=False)
          (down_proj): Linear(in_features=3072, out_features=1024, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen3RMSNorm((1024,), eps=1e-06)
        (post_attention_layernorm): Qwe