<a href="https://colab.research.google.com/github/peremartra/Rearchitecting-LLMs/blob/main/CH05/CH05_NB02_data_sms_wiki.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setting up notebook

In [1]:
!pip install -q \
      "torch" \
      "transformers==4.55.4" \
      "accelerate==1.10.1" \
      "lm_eval==0.4.9.1" \
      "sentencepiece==0.2.1" \
      "datasets" \
      "langdetect"\
      "optipfair==0.2.1"

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.0/42.0 kB[0m [31m3.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m53.6/53.6 kB[0m [31m5.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m981.5/981.5 kB[0m [31m66.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m51.8/51.8 kB[0m [31m4.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m11.3/11.3 MB[0m [31m69.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m374.9/374.9 kB[0m [31m32.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0

In [2]:
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
from lm_eval import evaluator
from torch import nn
from lm_eval.models.huggingface import HFLM
import os
import json
import copy
import gc
import time
from copy import deepcopy
import pandas as pd

In [3]:
# Set seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Check device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

Using device: cuda
GPU: Tesla T4


Download helper functions from the repository.

In [4]:
# Download utils.py from GitHub repository
!wget -q https://raw.githubusercontent.com/peremartra/Rearchitecting-LLMs/main/utils.py

# Verify download
import os
if os.path.exists('utils.py'):
    print("✅ utils.py downloaded successfully")
else:
    print("❌ Failed to download utils.py")

from utils import (
  model_evaluation, # Evals with lm_eval
  evaluate_metrics, # Loss & Perpelexity
  generate_text, #test inference model
  clear_gpu_cache
)

✅ utils.py downloaded successfully


In [5]:
def measure_detailed_performance(model, tokenizer, data_source, num_runs=3, max_new_tokens=50, max_samples=None):
    """
    Measures inference performance metrics.

    Args:
        model: Model to evaluate
        tokenizer: Tokenizer
        data_source: DataLoader to sample from
        num_runs: Number of runs per sample for averaging
        max_new_tokens: Tokens to generate
        max_samples: Limit number of samples (None = all)

    Returns:
        dict with timing statistics
    """
    model.eval()
    times = []
    tokens_generated = []

    samples = []
    for batch in data_source:
        for i in range(len(batch['input_ids'])):
            samples.append(batch['input_ids'][i])
            if max_samples and len(samples) >= max_samples:
                break
        if max_samples and len(samples) >= max_samples:
            break

    if max_samples:
        samples = samples[:max_samples]

    print(f"Measuring performance on {len(samples)} samples...")

    with torch.no_grad():
        for sample in tqdm(samples, desc="Performance test"):
            input_ids = sample.unsqueeze(0).to(device)

            for _ in range(num_runs):
                start_time = time.time()
                outputs = model.generate(
                    input_ids,
                    max_new_tokens=max_new_tokens,
                    do_sample=False,
                    pad_token_id=tokenizer.pad_token_id
                )
                end_time = time.time()

                elapsed = end_time - start_time
                times.append(elapsed)
                tokens_generated.append(outputs.shape[1] - input_ids.shape[1])

    avg_time = np.mean(times)
    std_time = np.std(times)
    avg_tokens = np.mean(tokens_generated)
    tokens_per_sec = avg_tokens / avg_time if avg_time > 0 else 0

    return {
        'avg_time_sec': avg_time,
        'std_time_sec': std_time,
        'avg_tokens': avg_tokens,
        'tokens_per_sec': tokens_per_sec,
        'num_samples': len(samples),
        'num_runs': num_runs
    }

## Configuration Parameters

In [6]:
# Model configuration
MODEL_NAME = 'meta-llama/Llama-3.2-1B'

# Dataset configuration
RECOVERY_SAMPLES = 100  # Calibration samples per dataset
MAX_LENGTH = 1024
BATCH_SIZE = 4

# Pruning configuration
PRUNE_PERCENT = 0.2  # 20% of neurons will be pruned

# Generation configuration
MAX_NEW_TOKENS = 50



In [None]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16,
    device_map="auto"
)
model.eval()
model.generation_config.temperature = None
model.generation_config.top_p = None
model.generation_config.top_k = None

print(f"✓ Loaded {MODEL_NAME}")
print(f"  Layers: {len(model.model.layers)}")
print(f"  Hidden size: {model.config.hidden_size}")
print(f"  Intermediate size: {model.config.intermediate_size}")

In [8]:
# Test the original model
prompt = "Paris is the capital of"
generated_base = generate_text(model, tokenizer, prompt, device)
print(f"Base model generation: {generated_base}")

Base model generation: Paris is the capital of France and the largest city in the country. It is located on the River Seine and is one of the most popular tourist destinations in Europe. The city has a rich history and culture, and it is home to many famous landmarks, including the E


# Load and Prepare Calibration Datasets

We'll load two contrasting datasets:

1. **WikiText-2**: Long-form, encyclopedic text with complex language patterns
2. **SMS Spam**: Short conversational messages with informal language

These datasets will serve as calibration sources for our two pruned models.

In [None]:
# Load datasets
datawiki = load_dataset('wikitext', 'wikitext-2-raw-v1', split=f'train[:{RECOVERY_SAMPLES}]')
datasms = load_dataset('sms_spam', split=f'train[:{RECOVERY_SAMPLES}]')

print(f"✓ WikiText samples: {len(datawiki)}")
print(f"✓ SMS samples: {len(datasms)}")

In [10]:
def prepare_dataset(dataset, text_field='text'):
    """
    Tokenizes and prepares a dataset for calibration.

    Handles different dataset formats (WikiText uses 'text', SMS uses 'sms' field).
    """
    def tokenize_function(examples):
        if text_field in examples:
            texts = examples[text_field]
        elif 'sms' in examples:  # SMS dataset specific
            texts = examples['sms']
        elif 'text' in examples:
            texts = examples['text']
        else:
            texts = examples[list(examples.keys())[0]]

        return tokenizer(
            texts,
            truncation=True,
            padding='max_length',
            max_length=MAX_LENGTH,
            return_tensors='pt'
        )

    tokenized = dataset.map(tokenize_function, batched=True, remove_columns=dataset.column_names)
    tokenized.set_format(type='torch', columns=['input_ids', 'attention_mask'])
    return DataLoader(tokenized, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Create dataloaders
dataloaderwiki = prepare_dataset(datawiki)
dataloadersms = prepare_dataset(datasms)

print(f"✓ Created dataloaders")
print(f"  Wiki batches: {len(dataloaderwiki)}")
print(f"  SMS batches: {len(dataloadersms)}")

In [12]:
metrics_base_wiki = evaluate_metrics(model, dataloaderwiki)


Evaluating: 100%|██████████| 25/25 [00:14<00:00,  1.71it/s]


In [13]:
metrics_base_sms = evaluate_metrics(model, dataloadersms)


Evaluating: 100%|██████████| 25/25 [00:14<00:00,  1.69it/s]


In [14]:
base_wiki_timing = measure_detailed_performance(model, tokenizer, dataloaderwiki, max_samples=10)

Measuring performance on 10 samples...


Performance test:   0%|          | 0/10 [00:00<?, ?it/s]The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Performance test: 100%|██████████| 10/10 [00:20<00:00,  2.04s/it]


In [15]:
base_sms_timing = measure_detailed_performance(model, tokenizer, dataloadersms, max_samples=10)

Measuring performance on 10 samples...


Performance test: 100%|██████████| 10/10 [00:32<00:00,  3.29s/it]


In [16]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters())


In [17]:
original_params = count_parameters(model)

In [18]:
clear_gpu_cache()


# Data-Driven Pruning Functions

These functions implement the CFSP-inspired methodology from CH05_NB02:
- **Activation capture**: PyTorch hooks on down_proj to record runtime behavior
- **Hybrid importance**: Combines weight magnitudes with activation norms
- **Neuron pair pruning**: Removes least important neurons from gate_proj, up_proj, and down_proj

## Activation Capture with PyTorch Hooks

In [19]:
# Global storage for accumulated activation norms
_accumulated_act_norms = {}

def setup_mlp_hooks_for_importance(model, device):
    """
    Registers hooks on down_proj inputs (X_d) to calculate L2 norms
    for each neuron, following CFSP Equation 8.

    Accumulates norms across multiple calibration batches.

    Returns:
        handles: List of hook handles (for removal after calibration)
    """
    global _accumulated_act_norms
    _accumulated_act_norms.clear()

    # Free memory before starting
    gc.collect()
    torch.cuda.empty_cache()

    handles = []

    # Initialize storage on CPU to save VRAM
    for idx, layer in enumerate(model.model.layers):
        intermediate_size = layer.mlp.down_proj.in_features
        _accumulated_act_norms[idx] = torch.zeros(
            intermediate_size,
            dtype=torch.float32,
            device='cpu'
        )

    def make_hook(layer_idx):
        def hook(module, input, output):
            """
            Captures X_d (input to down_proj) and calculates its L2 norm.

            X_d shape: [batch_size, seq_len, intermediate_size]
            Output: [intermediate_size] with ||X_d^i|| for each neuron i
            """
            X_d = input[0].detach()  # [B, S, I]

            # Calculate L2 norm (Equation 8 from CFSP paper)
            act_norms_L2 = torch.norm(
                X_d.to(torch.float32),
                p=2,
                dim=(0, 1)  # Sum over batch and sequence
            )

            # Accumulate on CPU to save VRAM
            _accumulated_act_norms[layer_idx] += act_norms_L2.cpu()

        return hook

    # Register hooks
    for idx, layer in enumerate(model.model.layers):
        handle = layer.mlp.down_proj.register_forward_hook(
            make_hook(idx)
        )
        handles.append(handle)

    print(f"✓ Registered {len(handles)} hooks on down_proj")

    return handles

def get_activation_norms():
    """
    Returns the accumulated L2 norms in a format ready for pruning.

    Returns:
        Dict[int, torch.Tensor]: {layer_idx: norms_L2 [intermediate_size]}
    """
    return {
        layer_idx: norms.clone()
        for layer_idx, norms in _accumulated_act_norms.items()
    }

## Hybrid Importance Scoring

In [20]:
def compute_neuron_pair_importance(gate_weight, up_weight, down_weight, X_d_norm):
    """
    Hybrid CFSP-inspired importance: Static magnitude + Dynamic activation
    """
    gate_weight = gate_weight.float()
    up_weight = up_weight.float()
    down_weight = down_weight.float()
    X_d_norm = X_d_norm.float().to(gate_weight.device)

    # Static component (L2 norms)
    gate_score = torch.norm(gate_weight, p=2, dim=1)
    up_score = torch.norm(up_weight, p=2, dim=1)
    down_score = torch.norm(down_weight, p=2, dim=0)

    # Normalize to [0, 1] to equalize scales
    gate_norm = gate_score / (gate_score.max() + 1e-8)
    up_norm = up_score / (up_score.max() + 1e-8)
    down_norm = down_score / (down_score.max() + 1e-8)

    # Weighted combination (down_proj gets more weight)
    #structural_score = 0.4 * down_norm + 0.4 * gate_norm + 0.2 * up_norm
    structural_score = down_norm + gate_norm + up_norm

    # Dynamic fusion (multiply by actual activations)
    importance_scores = structural_score * X_d_norm

    return importance_scores

## Neuron Pair Pruning

In [21]:
def prune_neuron_pairs(mlp, prune_percent, X_d_norm, layer_idx):
    """
    Prunes neuron pairs from MLP block using hybrid importance scores.

    Reduces dimensions of gate_proj, up_proj, and down_proj layers.

    Args:
        mlp: LlamaMLP module to prune
        prune_percent: Fraction of neurons to remove (e.g., 0.2 for 20%)
        X_d_norm: Tensor [intermediate_size] with accumulated L2 norms
        layer_idx: Layer index (for logging)

    Returns:
        new_gate_proj, new_up_proj, new_down_proj, k (new intermediate size)
    """

    # Extract weights
    gate_weight = mlp.gate_proj.weight.data
    up_weight = mlp.up_proj.weight.data
    down_weight = mlp.down_proj.weight.data

    original_intermediate_size = gate_weight.size(0)

    # Compute importance scores
    importance_scores = compute_neuron_pair_importance(
        gate_weight=gate_weight,
        up_weight=up_weight,
        down_weight=down_weight,
        X_d_norm=X_d_norm
    )

    # Determine how many neurons to keep
    num_to_prune = min(
        int(prune_percent * original_intermediate_size),
        original_intermediate_size - 1
    )
    k = original_intermediate_size - num_to_prune

    if k <= 0:
        raise ValueError(f"Invalid k={k} for layer {layer_idx}")

    # Select top-k most important neurons
    _, indices_to_keep = torch.topk(
        importance_scores,
        k,
        largest=True,
        sorted=True
    )

    indices_to_keep = indices_to_keep.sort().values

    # Create new pruned layers
    new_gate_proj = nn.Linear(
        mlp.gate_proj.in_features,
        k,
        bias=False
    ).to(device)

    new_up_proj = nn.Linear(
        mlp.up_proj.in_features,
        k,
        bias=False
    ).to(device)

    new_down_proj = nn.Linear(
        k,
        mlp.down_proj.out_features,
        bias=False
    ).to(device)

    # Copy weights for kept neurons
    new_gate_proj.weight.data = gate_weight[indices_to_keep, :]
    new_up_proj.weight.data = up_weight[indices_to_keep, :]
    new_down_proj.weight.data = down_weight[:, indices_to_keep]

    return new_gate_proj, new_up_proj, new_down_proj, k

In [22]:
def update_model(model, prune_percent, activation_norms):
    """
    Applies pruning to all MLP layers in the model.

    Args:
        model: LlamaForCausalLM model to prune
        prune_percent: Fraction of neurons to remove
        activation_norms: Dict mapping layer_idx -> X_d_norm tensor

    Returns:
        model: Pruned model with updated configuration
    """

    new_intermediate_size = None
    pruning_stats = []

    print(f"\n{'='*60}")
    print(f"Starting pruning with {prune_percent*100:.1f}% width pruning")
    print(f"{'='*60}\n")

    for idx, layer in enumerate(model.model.layers):
        mlp = layer.mlp

        if idx not in activation_norms:
            raise KeyError(f"No activation norms for layer {idx}")

        X_d_norm = activation_norms[idx]
        original_size = mlp.gate_proj.out_features

        # Prune neuron pairs
        new_gate_proj, new_up_proj, new_down_proj, new_size = prune_neuron_pairs(
            mlp=mlp,
            prune_percent=prune_percent,
            X_d_norm=X_d_norm,
            layer_idx=idx
        )

        # Replace layers
        mlp.gate_proj = new_gate_proj
        mlp.up_proj = new_up_proj
        mlp.down_proj = new_down_proj

        pruning_stats.append({
            'layer': idx,
            'original_size': original_size,
            'new_size': new_size,
            'pruned': original_size - new_size,
            'kept_percent': (new_size / original_size) * 100
        })

        if new_intermediate_size is None:
            new_intermediate_size = new_size

        if (idx + 1) % 4 == 0:
            print(f"  Pruned layers {idx-3:2d}-{idx:2d}: "
                  f"{original_size} → {new_size} neurons "
                  f"({(new_size/original_size)*100:.1f}% kept)")

    # Update model configuration
    model.config.intermediate_size = new_intermediate_size

    print(f"\n{'='*60}")
    print(f"Pruning completed!")
    print(f"{'='*60}")
    print(f"  Layers pruned: {len(pruning_stats)}")
    print(f"  Original intermediate size: {original_size}")
    print(f"  New intermediate size: {new_intermediate_size}")
    print(f"  Neurons pruned per layer: {original_size - new_intermediate_size}")
    print(f"  Effective width pruning: {((original_size - new_intermediate_size) / original_size) * 100:.2f}%")
    print(f"{'='*60}\n")

    return model

# Calibration and Pruning

Now we'll create two pruned models using different calibration datasets:
1. **Wiki-pruned**: Calibrated on WikiText-2
2. **SMS-pruned**: Calibrated on SMS Spam

## Wiki-Calibrated Model

In [23]:
print("="*60)
print("WIKI CALIBRATION")
print("="*60)

# Step 1: Setup hooks
print("\nSetting up activation hooks...")
handles_wiki = setup_mlp_hooks_for_importance(model, device)

# Step 2: Run calibration forward passes
print("\nRunning calibration forward passes on WikiText...")
model.eval()

with torch.no_grad():
    for batch_idx, batch in enumerate(tqdm(dataloaderwiki, desc="Wiki Calibration")):
        inputs = {
            'input_ids': batch['input_ids'].to(device),
            'attention_mask': batch['attention_mask'].to(device)
        }
        outputs = model(**inputs)

        if (batch_idx + 1) % 10 == 0:
            torch.cuda.empty_cache()

print(f"\n✓ Processed {len(dataloaderwiki)} batches")

# Step 3: Clean up hooks
print("Removing hooks...")
for handle in handles_wiki:
    handle.remove()

# Step 4: Get activation norms
print("Extracting activation statistics...")
activation_norms_wiki = get_activation_norms()
print(f"✓ Collected activation norms for {len(activation_norms_wiki)} layers")

WIKI CALIBRATION

Setting up activation hooks...
✓ Registered 16 hooks on down_proj

Running calibration forward passes on WikiText...


Wiki Calibration: 100%|██████████| 25/25 [00:15<00:00,  1.59it/s]


✓ Processed 25 batches
Removing hooks...
Extracting activation statistics...
✓ Collected activation norms for 16 layers





In [24]:
# Prune the model using Wiki activations
wiki_model = update_model(copy.deepcopy(model), PRUNE_PERCENT, activation_norms_wiki)


Starting pruning with 20.0% width pruning

  Pruned layers  0- 3: 8192 → 6554 neurons (80.0% kept)
  Pruned layers  4- 7: 8192 → 6554 neurons (80.0% kept)
  Pruned layers  8-11: 8192 → 6554 neurons (80.0% kept)
  Pruned layers 12-15: 8192 → 6554 neurons (80.0% kept)

Pruning completed!
  Layers pruned: 16
  Original intermediate size: 8192
  New intermediate size: 6554
  Neurons pruned per layer: 1638
  Effective width pruning: 20.00%



In [25]:
del(model)
clear_gpu_cache()

In [26]:
# Test wiki model
generated_wiki = generate_text(wiki_model, tokenizer, prompt, device)
print(f"Wiki-pruned model: {generated_wiki}")

Wiki-pruned model: Paris is the capital of France and is located in the north-east of the country. The city has a population of 2.1 million people, making it the 3rd largest city in France. It is also the largest in terms of surface area, with 100


In [27]:
clear_gpu_cache()
metrics_wiki_wiki = evaluate_metrics(wiki_model, dataloaderwiki)


Evaluating: 100%|██████████| 25/25 [00:16<00:00,  1.50it/s]


In [28]:
metrics_wiki_sms = evaluate_metrics(wiki_model, dataloadersms)


Evaluating: 100%|██████████| 25/25 [00:16<00:00,  1.47it/s]


In [29]:
wiki_wiki_timing = measure_detailed_performance(wiki_model, tokenizer, dataloaderwiki, max_samples=10)

Measuring performance on 10 samples...


Performance test: 100%|██████████| 10/10 [00:39<00:00,  3.93s/it]


In [30]:
wiki_params = count_parameters(wiki_model)

In [31]:
wiki_sms_timing = measure_detailed_performance(wiki_model, tokenizer, dataloadersms, max_samples=10)

Measuring performance on 10 samples...


Performance test: 100%|██████████| 10/10 [00:38<00:00,  3.85s/it]


In [32]:
del(wiki_model)
clear_gpu_cache

## SMS-Calibrated Model

In [33]:
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16,
    device_map="auto"
)
model.eval()

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 2048)
    (layers): ModuleList(
      (0-15): 16 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=512, bias=False)
          (v_proj): Linear(in_features=2048, out_features=512, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (up_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((2048,), eps=1e-05)
    (rotary_emb):

In [34]:
# Clear GPU cache before SMS calibration
clear_gpu_cache()

print("="*60)
print("SMS CALIBRATION")
print("="*60)

# Step 1: Setup hooks
print("\nSetting up activation hooks...")
handles_sms = setup_mlp_hooks_for_importance(model, device)

# Step 2: Run calibration forward passes
print("\nRunning calibration forward passes on SMS...")
model.eval()

with torch.no_grad():
    for batch_idx, batch in enumerate(tqdm(dataloadersms, desc="SMS Calibration")):
        inputs = {
            'input_ids': batch['input_ids'].to(device),
            'attention_mask': batch['attention_mask'].to(device)
        }
        outputs = model(**inputs)

        if (batch_idx + 1) % 10 == 0:
            torch.cuda.empty_cache()

print(f"\n✓ Processed {len(dataloadersms)} batches")

# Step 3: Clean up hooks
print("Removing hooks...")
for handle in handles_sms:
    handle.remove()

# Step 4: Get activation norms
print("Extracting activation statistics...")
activation_norms_sms = get_activation_norms()
print(f"✓ Collected activation norms for {len(activation_norms_sms)} layers")

SMS CALIBRATION

Setting up activation hooks...
✓ Registered 16 hooks on down_proj

Running calibration forward passes on SMS...


SMS Calibration: 100%|██████████| 25/25 [00:15<00:00,  1.61it/s]


✓ Processed 25 batches
Removing hooks...
Extracting activation statistics...
✓ Collected activation norms for 16 layers





In [35]:
# Prune the model using SMS activations
sms_model = update_model(copy.deepcopy(model), PRUNE_PERCENT, activation_norms_sms)


Starting pruning with 20.0% width pruning

  Pruned layers  0- 3: 8192 → 6554 neurons (80.0% kept)
  Pruned layers  4- 7: 8192 → 6554 neurons (80.0% kept)
  Pruned layers  8-11: 8192 → 6554 neurons (80.0% kept)
  Pruned layers 12-15: 8192 → 6554 neurons (80.0% kept)

Pruning completed!
  Layers pruned: 16
  Original intermediate size: 8192
  New intermediate size: 6554
  Neurons pruned per layer: 1638
  Effective width pruning: 20.00%



In [36]:
del(model)
clear_gpu_cache()

In [37]:
# Test SMS model
generated_sms = generate_text(sms_model, tokenizer, prompt, device)
print(f"SMS-pruned model: {generated_sms}")

The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


SMS-pruned model: Paris is the capital of the French department of Paris. It is located on the Seine River, which flows through the city. The city has a population of 1.8 million people. Paris is a major city in the world. Its population is estimated to be 


In [38]:
metrics_sms_wiki = evaluate_metrics(sms_model, dataloaderwiki)
metrics_sms_sms = evaluate_metrics(sms_model, dataloadersms)

Evaluating: 100%|██████████| 25/25 [00:16<00:00,  1.49it/s]
Evaluating: 100%|██████████| 25/25 [00:17<00:00,  1.45it/s]


In [39]:
sms_wiki_timing = measure_detailed_performance(sms_model, tokenizer, dataloaderwiki, max_samples=10)

Measuring performance on 10 samples...


Performance test:   0%|          | 0/10 [00:00<?, ?it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Performance test:  10%|█         | 1/10 [00:04<00:37,  4.19s/it]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Performance test:  20

In [40]:
sms_sms_timing = measure_detailed_performance(sms_model, tokenizer, dataloadersms, max_samples=10)

Measuring performance on 10 samples...


Performance test:   0%|          | 0/10 [00:00<?, ?it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Performance test:  10%|█         | 1/10 [00:04<00:37,  4.13s/it]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Performance test:  20

## Parameter Comparison

In [41]:
sms_params = count_parameters(sms_model)

print("\n" + "="*60)
print("PARAMETER COUNTS")
print("="*60)
print(f"Original model:     {original_params:,} parameters")
print(f"Wiki-pruned model:  {wiki_params:,} parameters ({((original_params - wiki_params) / original_params * 100):.2f}% reduction)")
print(f"SMS-pruned model:   {sms_params:,} parameters ({((original_params - sms_params) / original_params * 100):.2f}% reduction)")
print("="*60)


PARAMETER COUNTS
Original model:     1,235,814,400 parameters
Wiki-pruned model:  1,074,792,448 parameters (13.03% reduction)
SMS-pruned model:   1,074,792,448 parameters (13.03% reduction)


# Cross-Evaluation Matrix

Now we evaluate all three models (base, wiki-pruned, sms-pruned) on both datasets (Wiki and SMS) to understand:
- How well each pruning strategy preserves quality on its calibration dataset
- How well pruning decisions transfer to the other dataset
- Whether domain-specific calibration provides meaningful benefits

In [42]:
# Clear original model to save memory
clear_gpu_cache()

## Loss and Perplexity Evaluation

In [43]:
print("\n" + "="*60)
print("EVALUATING ON WIKITEXT-2")
print("="*60)

print("\nBase model on Wiki:", metrics_base_wiki)
print("Wiki-pruned on Wiki:", metrics_wiki_wiki)
print("SMS-pruned on Wiki:", metrics_sms_wiki)


EVALUATING ON WIKITEXT-2

Base model on Wiki: {'loss': 3.2460288122350294, 'perplexity': np.float64(25.688124727046315)}
Wiki-pruned on Wiki: {'loss': 3.5913135137696126, 'perplexity': np.float64(36.28170115548473)}
SMS-pruned on Wiki: {'loss': 3.8846293323963836, 'perplexity': np.float64(48.64890654342502)}


In [44]:
print("\n" + "="*60)
print("EVALUATING ON SMS SPAM")
print("="*60)

print("\nBase model on SMS:", metrics_base_sms)
print("Wiki-pruned on SMS:", metrics_wiki_sms)
print("SMS-pruned on SMS:", metrics_sms_sms)


EVALUATING ON SMS SPAM

Base model on SMS: {'loss': 4.8114213257195955, 'perplexity': np.float64(122.90618314977401)}
Wiki-pruned on SMS: {'loss': 5.206973014926048, 'perplexity': np.float64(182.540673179065)}
SMS-pruned on SMS: {'loss': 5.109189337836538, 'perplexity': np.float64(165.53610660777312)}


## Text Generation Comparison

## Performance Measurement (Inference Speed)

In [45]:
print("\n" + "="*60)
print("PERFORMANCE MEASUREMENT ON WIKI")
print("="*60)


print(f"\nBase model:       {base_wiki_timing['avg_time_sec']:.4f}s ({base_wiki_timing['tokens_per_sec']:.2f} tok/s)")
print(f"Wiki-pruned:      {wiki_wiki_timing['avg_time_sec']:.4f}s ({wiki_wiki_timing['tokens_per_sec']:.2f} tok/s)")
print(f"SMS-pruned:       {sms_wiki_timing['avg_time_sec']:.4f}s ({sms_wiki_timing['tokens_per_sec']:.2f} tok/s)")


PERFORMANCE MEASUREMENT ON WIKI

Base model:       0.6791s (29.74 tok/s)
Wiki-pruned:      1.3088s (35.45 tok/s)
SMS-pruned:       1.2069s (34.88 tok/s)


In [46]:
print("\n" + "="*60)
print("PERFORMANCE MEASUREMENT ON SMS")
print("="*60)



print(f"\nBase model:       {base_sms_timing['avg_time_sec']:.4f}s ({base_sms_timing['tokens_per_sec']:.2f} tok/s)")
print(f"Wiki-pruned:      {wiki_sms_timing['avg_time_sec']:.4f}s ({wiki_sms_timing['tokens_per_sec']:.2f} tok/s)")
print(f"SMS-pruned:       {sms_sms_timing['avg_time_sec']:.4f}s ({sms_sms_timing['tokens_per_sec']:.2f} tok/s)")


PERFORMANCE MEASUREMENT ON SMS

Base model:       1.0957s (34.32 tok/s)
Wiki-pruned:      1.2819s (35.88 tok/s)
SMS-pruned:       1.3891s (35.99 tok/s)


# Summary and Results Analysis

In [47]:
# Create comprehensive results table
results = {
    'Model': ['Base', 'Wiki-pruned', 'SMS-pruned'],
    'Parameters': [original_params, wiki_params, sms_params],
    'Param Reduction %': [
        0,
        ((original_params - wiki_params) / original_params * 100),
        ((original_params - sms_params) / original_params * 100)
    ],
    'PPL Wiki': [
        metrics_base_wiki['perplexity'],
        metrics_wiki_wiki['perplexity'],
        metrics_sms_wiki['perplexity']
    ],
    'PPL SMS': [
        metrics_base_sms['perplexity'],
        metrics_wiki_sms['perplexity'],
        metrics_sms_sms['perplexity']
    ],
    'Loss Wiki': [
        metrics_base_wiki['loss'],
        metrics_wiki_wiki['loss'],
        metrics_sms_wiki['loss']
    ],
    'Loss SMS': [
        metrics_base_sms['loss'],
        metrics_wiki_sms['loss'],
        metrics_sms_sms['loss']
    ],
    'Time Wiki (s)': [
        base_wiki_timing['avg_time_sec'],
        wiki_wiki_timing['avg_time_sec'],
        sms_wiki_timing['avg_time_sec']
    ],
    'Time SMS (s)': [
        base_sms_timing['avg_time_sec'],
        wiki_sms_timing['avg_time_sec'],
        sms_sms_timing['avg_time_sec']
    ],
    'Tok/s Wiki': [
        base_wiki_timing['tokens_per_sec'],
        wiki_wiki_timing['tokens_per_sec'],
        sms_wiki_timing['tokens_per_sec']
    ],
    'Tok/s SMS': [
        base_sms_timing['tokens_per_sec'],
        wiki_sms_timing['tokens_per_sec'],
        sms_sms_timing['tokens_per_sec']
    ],
}

df_results = pd.DataFrame(results)

print("\n" + "="*80)
print("COMPREHENSIVE RESULTS")
print("="*80)
print(df_results.to_string(index=False))
print("="*80)


COMPREHENSIVE RESULTS
      Model  Parameters  Param Reduction %  PPL Wiki    PPL SMS  Loss Wiki  Loss SMS  Time Wiki (s)  Time SMS (s)  Tok/s Wiki  Tok/s SMS
       Base  1235814400           0.000000 25.688125 122.906183   3.246029  4.811421       0.679144      1.095701   29.743339  34.315922
Wiki-pruned  1074792448          13.029623 36.281701 182.540673   3.591314  5.206973       1.308820      1.281878   35.451768  35.884851
 SMS-pruned  1074792448          13.029623 48.648907 165.536107   3.884629  5.109189       1.206911      1.389136   34.882451  35.993593


In [48]:
print(generated_base)
print(generated_wiki)
print(generated_sms)

Paris is the capital of France and the largest city in the country. It is located on the River Seine and is one of the most popular tourist destinations in Europe. The city has a rich history and culture, and it is home to many famous landmarks, including the E
Paris is the capital of France and is located in the north-east of the country. The city has a population of 2.1 million people, making it the 3rd largest city in France. It is also the largest in terms of surface area, with 100
Paris is the capital of the French department of Paris. It is located on the Seine River, which flows through the city. The city has a population of 1.8 million people. Paris is a major city in the world. Its population is estimated to be 


40% pruning expansion rate.
```
================================================================================
COMPREHENSIVE RESULTS
================================================================================
      Model  Parameters  Param Reduction %  PPL Wiki    PPL SMS  Loss Wiki  Loss SMS  Time Wiki (s)  Time SMS (s)  Tok/s Wiki  Tok/s SMS
       Base  1235814400           0.000000 25.688125 122.906183   3.246029  4.811421       0.679144      1.095701   29.743339  34.315922
Wiki-pruned  1074792448          13.029623 36.281701 182.540673   3.591314  5.206973       1.308820      1.281878   35.451768  35.884851
 SMS-pruned  1074792448          13.029623 48.648907 165.536107   3.884629  5.109189       1.206911      1.389136   34.882451  35.993593
================================================================================
````
