<a href="https://colab.research.google.com/github/peremartra/llama-glu-expansion-pruning/blob/main/notebooks/00_Neuron_Selection_Method_Comparison.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# GLU Pruning Research
## 00 - Neuron Selection Method Comparison: MAW vs VOW vs PON

### Exploring GLU Expansion Ratios in Llama-3.2 Models
by [Pere Martra](https://github.com/peremartra)

[![Paper](https://img.shields.io/badge/OSF-Paper-blue?logo=osf&logoColor=white)](https://doi.org/10.31219/osf.io/qgxea)
[![GitHub](https://img.shields.io/badge/⭐_Star-OptiPFair-orange?logo=github&logoColor=white)](https://github.com/peremartra/optipfair)
[![PyPI](https://img.shields.io/pypi/v/optipfair?logo=python&logoColor=white&label=v)](https://pypi.org/project/optipfair/)

**Repository:** [github.com/peremartra/llama-glu-expansion-pruning](https://github.com/peremartra/llama-glu-expansion-pruning)

---

**Colab Environment:** GPU L4

**Models:**
* Llama-3.2-1B (base)

**Benchmarks (in order):**
* BoolQ (0-shot) - ~15 min
* Lambada (0-shot) - ~15 min
* IFEval (0-shot) - ~20 min
* GSM8K (5-shot CoT) - ~30 min *(if needed)*

---

## 📋 Notebook Objective

This preliminary experiment compares three neuron importance metrics for GLU-based pruning:
- **MAW** (Maximum Absolute Weight) - Default method in OptiPFair
- **VOW** (Variance of Weights)
- **PON** (Product of Norms)

**Goal:** Empirically determine which method best preserves model capabilities under pruning before conducting the main experiments. We will create three versions of Llama-3.2-1B pruned at 10% (one per method) and evaluate them sequentially on progressively demanding benchmarks.

**Evaluation Strategy:** We'll start with the fastest benchmarks (BoolQ, Lambada) to quickly identify clear winners or losers. If methods show similar performance, we'll proceed to more demanding benchmarks (IFEval, GSM8K) for differentiation.

**Why this matters:** The choice of neuron selection method fundamentally affects which neurons are removed during pruning. This "tournament" ensures our main experiments use the most effective approach, with results documented in the paper's methodology section.

**Output:** The winning method will be used exclusively for all subsequent pruning experiments (01 and 02 notebooks).

---

## 📦 Required Libraries

```python
# Install OptiPFair for GLU pruning
!pip install optipfair

# Install LM Evaluation Harness for benchmarking
!pip install lm-eval
```

---

**Note:** This notebook is part of a larger research project on width pruning in GLU architectures. See the full paper and codebase at the links above.

---

# Setup & Install

In [1]:
# Install required libraries.
!pip install -q optipfair
!pip install -q lm-eval
!pip install -q langdetect

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.9/44.9 kB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m53.6/53.6 kB[0m [31m3.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m51.8/51.8 kB[0m [31m4.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.5/7.5 MB[0m [31m112.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m491.5/491.5 kB[0m [31m31.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m7.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m293.6/293.6 kB[0m [31m25.0 MB/s[0m eta [36m0:00:

In [2]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from optipfair import prune_model
from datasets import load_dataset
import copy
import json
import gc
from datetime import datetime
from lm_eval import evaluator
from lm_eval.models.huggingface import HFLM

In [3]:
def clear_gpu_cache():
    """Limpia completamente la cache de GPU"""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
    gc.collect()


In [4]:
MODEL_ID = "meta-llama/Llama-3.2-1B"
PRUNING_PERCENTAGE = 10
PERPLEXITY_SAMPLES=500
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load Base Model

In [None]:
base_model = AutoModelForCausalLM.from_pretrained(
  MODEL_ID,
  dtype=torch.float16,
  device_map="auto"
)

In [None]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
if tokenizer.pad_token is None:
  tokenizer.pad_token = tokenizer.eos_token

# Evaluations

## Evaluations Code

In [7]:
def model_evaluation(model_obj, tokenizer, tasks, limit=None):
    """
    Runs lm-eval on a model and tokenizer object already in memory.

    Args:
        model_obj: The PyTorch model object to evaluate.
        tokenizer_obj: The tokenizer object.
        tasks (list): A list of task names for lm-eval.
        limit (int): The number of samples per task.
    """
    model_name = getattr(model_obj.config, '_name_or_path', 'unknown')
    limit_str = f"(limit={limit})" if limit else "(full dataset)"
    print(f"\n{'='*70}")
    print(f"Starting lm-eval on model '{model_name}'")
    print(f"Tasks: {tasks} {limit_str}")
    print(f"{'='*70}\n")

    # Wrap the local model object and tokenizer for lm-eval
    model_wrapper = HFLM(
        pretrained=model_obj,
        tokenizer=tokenizer,
        device=str(DEVICE)
    )

    # Run evaluation
    results = evaluator.simple_evaluate(
        model=model_wrapper,
        tasks=tasks,
        num_fewshot=0,
        limit=limit,
        device=str(DEVICE)
    )

    # Format results for clean display
    formatted_results = {}
    for task_name, res in results["results"].items():
        # Extract relevant metrics based on task type
        if 'perplexity,none' in res:
            # For perplexity tasks (wikitext)
            formatted_results[task_name] = {
                'perplexity': f"{res.get('perplexity,none', 0):.2f}",
                'word_perplexity': f"{res.get('word_perplexity,none', 0):.2f}",
                'bits_per_byte': f"{res.get('bits_per_byte,none', 0):.4f}"
            }
        elif 'acc,none' in res:
            # For accuracy tasks (boolq, etc.)
            formatted_results[task_name] = {
                'accuracy': f"{res.get('acc,none', 0):.4f}",
                'acc_norm': f"{res.get('acc_norm,none', 0):.4f}"
            }
        else:
            # Fallback: store all available metrics
            formatted_results[task_name] = {k: f"{v:.4f}" for k, v in res.items() if isinstance(v, (int, float))}

    return formatted_results

In [8]:
TASKS = ["wikitext", "boolq", "lambada_openai"]


In [None]:
BASE_results = model_evaluation(
  model_obj=base_model,
  tokenizer=tokenizer,
  tasks=TASKS,
  limit=None  # Full dataset
)

In [10]:
BASE_results

{'boolq': {'accuracy': '0.6391', 'acc_norm': '0.0000'},
 'lambada_openai': {'perplexity': '5.72',
  'word_perplexity': '0.00',
  'bits_per_byte': '0.0000'},
 'wikitext': {'word_perplexity,none': '11.5688',
  'byte_perplexity,none': '1.5807',
  'bits_per_byte,none': '0.6605'}}

## MAW selecion model


In [11]:
# Create Pruned Model with MAW
pruned_model_maw, stats = prune_model(
    model=copy.deepcopy(base_model),
    pruning_type="MLP_GLU",
    neuron_selection_method="MAW",
    pruning_percentage=PRUNING_PERCENTAGE,
    show_progress=True,
    return_stats=True
)
stats

Pruning layers: 100%|██████████| 16/16 [00:06<00:00,  2.59it/s]


{'original_parameters': 1235814400,
 'pruned_parameters': 1155303424,
 'reduction': 80510976,
 'percentage_reduction': 6.5148112855781575,
 'expansion_rate': 360.009765625}

In [None]:
MAW_results = model_evaluation(
  model_obj=pruned_model_maw,
  tokenizer=tokenizer,
  tasks=TASKS,
  limit=None  # Full dataset
)

In [13]:
MAW_results

{'boolq': {'accuracy': '0.6251', 'acc_norm': '0.0000'},
 'lambada_openai': {'perplexity': '20.54',
  'word_perplexity': '0.00',
  'bits_per_byte': '0.0000'},
 'wikitext': {'word_perplexity,none': '17.4945',
  'byte_perplexity,none': '1.7078',
  'bits_per_byte,none': '0.7721'}}

In [14]:
#Delete MAW model
del pruned_model_maw
clear_gpu_cache()

## PON Selection Model


In [15]:
# Create Pruned Model with PON
pruned_model_pon, stats = prune_model(
    model=copy.deepcopy(base_model),
    pruning_type="MLP_GLU",
    neuron_selection_method="PON",
    pruning_percentage=PRUNING_PERCENTAGE,
    show_progress=True,
    return_stats=True
)
stats

Pruning layers: 100%|██████████| 16/16 [00:06<00:00,  2.56it/s]


{'original_parameters': 1235814400,
 'pruned_parameters': 1155303424,
 'reduction': 80510976,
 'percentage_reduction': 6.5148112855781575,
 'expansion_rate': 360.009765625}

In [None]:
PON_results = model_evaluation(
  model_obj=pruned_model_pon,
  tokenizer=tokenizer,
  tasks=TASKS,
  limit=None  # Full dataset
)

In [17]:
PON_results

{'boolq': {'accuracy': '0.6220', 'acc_norm': '0.0000'},
 'lambada_openai': {'perplexity': '2032.80',
  'word_perplexity': '0.00',
  'bits_per_byte': '0.0000'},
 'wikitext': {'word_perplexity,none': '72.5170',
  'byte_perplexity,none': '2.2280',
  'bits_per_byte,none': '1.1557'}}

In [18]:
del pruned_model_pon
clear_gpu_cache()

## VOW Selection Model


In [19]:
# Create Pruned Model with PON
pruned_model_vow, stats = prune_model(
    model=copy.deepcopy(base_model),
    pruning_type="MLP_GLU",
    neuron_selection_method="VOW",
    pruning_percentage=PRUNING_PERCENTAGE,
    show_progress=True,
    return_stats=True
)
stats

Pruning layers: 100%|██████████| 16/16 [00:05<00:00,  2.69it/s]


{'original_parameters': 1235814400,
 'pruned_parameters': 1155303424,
 'reduction': 80510976,
 'percentage_reduction': 6.5148112855781575,
 'expansion_rate': 360.009765625}

In [None]:
VOW_results = model_evaluation(
  model_obj=pruned_model_vow,
  tokenizer=tokenizer,
  tasks=TASKS,
  limit=None  # Full dataset
)

In [21]:
VOW_results

{'boolq': {'accuracy': '0.6193', 'acc_norm': '0.0000'},
 'lambada_openai': {'perplexity': '532.36',
  'word_perplexity': '0.00',
  'bits_per_byte': '0.0000'},
 'wikitext': {'word_perplexity,none': '50.5592',
  'byte_perplexity,none': '2.0827',
  'bits_per_byte,none': '1.0584'}}

In [22]:
del pruned_model_vow
clear_gpu_cache()

# 📊 Method Comparison Results

| Model | WikiText PPL ↓ | Lambada PPL ↓ | BoolQ Acc ↑ | Status |
|-------|----------------|---------------|-------------|--------|
| **Base** | **11.57** | **5.72** | **0.6391** | Baseline |
| **MAW** | **17.49** (+51%) | **20.54** (+259%) | **0.6251** (-2.2%) | ✅ **SELECTED** |
| VOW | 50.56 (+337%) | 532.36 (+9,207%) | 0.6193 (-3.1%) | ❌ Rejected |
| PON | 72.52 (+527%) | 2032.80 (+35,440%) | 0.6220 (-2.7%) | ❌ Rejected |

**Conclusion**: MAW is the clear winner with acceptable degradation (~50% PPL increase) while VOW/PON show catastrophic failures (10,000%+ increases). MAW will be used for all experiments.

The results are conclusive. While all pruning methods introduce some performance degradation compared to the base model, the MAW (Maximum Absolute Weight) method preserves the model's capabilities far more effectively than the alternatives.

MAW shows only a minor increase in perplexity and a negligible drop in accuracy, demonstrating its ability to remove parameters without causing significant damage to the model's core fluency and comprehension.

VOW and PON both result in a catastrophic loss of performance, especially visible in the wikitext and lambada_openai perplexity scores. A Lambada perplexity score above 2000 for the PON method indicates that the model has become almost unusable for text generation tasks.

This data provides a strong empirical justification for selecting MAW as the sole neuron selection method for all subsequent, large-scale pruning experiments in this research project.

These experiments were powered by **OptiPFair**. If this research helps your work, consider:
- ⭐ Starring [the repo](https://github.com/peremartra/optipfair)
- 📖 Reading the [documentation](https://peremartra.github.io/optipfair/)
- 🐛 Reporting issues or suggesting features