
---

# Model Pruning for Pretrained LLMs

## 1. Introduction: What is Model Pruning?

**Model pruning** is a compression technique that reduces the size and computational complexity of a neural network by removing parameters (typically weights or neurons) that contribute little to the model's overall performance. The key idea is that many parameters in deep neural networks are redundant or underutilized, and can be pruned to create smaller, faster, and more efficient models.

Types of pruning:

* **Unstructured Pruning**: Removes individual weights regardless of their position in the model. Leads to sparse weight matrices.
* **Structured Pruning**: Removes entire units such as neurons, attention heads, or even layers. More hardware-friendly than unstructured pruning.
* **Global vs. Layer-wise Pruning**: Global pruning ranks all weights across the model, while layer-wise pruning does this within each layer.
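The three variants above can be sketched in a few lines of NumPy. This is a toy illustration under simple assumptions: two small hand-written weight matrices, magnitude as the importance score, and hypothetical helper names. It is not how any particular library implements them.

```python
import numpy as np

def unstructured_prune(w, sparsity):
    """Zero the smallest-|w| fraction of entries; the matrix keeps its shape."""
    k = int(round(sparsity * w.size))
    flat = w.astype(float).ravel()               # astype returns a copy
    flat[np.argsort(np.abs(flat))[:k]] = 0.0     # zero the k smallest magnitudes
    return flat.reshape(w.shape)

def structured_prune_rows(w, keep):
    """Keep the `keep` (>= 1) rows, e.g. output neurons, with the largest L2 norm; the matrix shrinks."""
    norms = np.linalg.norm(w, axis=1)
    rows = np.sort(np.argsort(norms)[-keep:])    # preserve original row order
    return w[rows]

def global_prune(weights, sparsity):
    """One magnitude threshold shared by all layers (ties at the threshold are also pruned)."""
    mags = np.sort(np.concatenate([np.abs(w).ravel() for w in weights]))
    k = int(round(sparsity * mags.size))
    if k == 0:
        return [w.copy() for w in weights]
    threshold = mags[k - 1]
    return [np.where(np.abs(w) <= threshold, 0.0, w) for w in weights]

layer_a = np.array([[0.1, -2.0], [0.3, 4.0]])
layer_b = np.array([[5.0, -6.0], [0.2, 7.0]])

# Layer-wise unstructured pruning at 50%: each layer loses exactly 2 of its 4 weights.
print(unstructured_prune(layer_a, 0.5))
print(unstructured_prune(layer_b, 0.5))
# Structured: keep 1 of 2 rows, so the weight matrix itself becomes 1 x 2.
print(structured_prune_rows(layer_a, keep=1))
# Global at 50%: layer_a (small weights) loses 3 entries, layer_b only 1.
print(global_prune([layer_a, layer_b], 0.5))
```

Only the structured variant changes the matrix shape, which is why it maps more directly onto smaller dense computations at inference time.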

---

## 2. Model Pruning in the Context of Pretrained LLMs

### Why Prune LLMs?

Pretrained large language models like GPT, BERT, LLaMA, and T5 have hundreds of millions to billions of parameters. Pruning helps in:

* **Reducing inference latency**
* **Lowering memory footprint**
* **Enabling edge deployment**
* **Reducing energy usage**
* **Serving more models on limited GPU resources**

However, pruning LLMs is non-trivial due to:

* Their **layered architecture** (transformer blocks),
* **Interdependencies** between layers and attention heads,
* Sensitivity to performance degradation.

---

## 3. Best Practices in LLM Pruning

### Pruning Strategy

* **Magnitude-Based Pruning**: Remove weights with the smallest absolute values. Simple but effective.
* **Gradient-Based Pruning**: Scores weights using gradient information, e.g., SNIP (connection sensitivity) or GraSP (gradient signal preservation).
* **Attention Head Pruning**: Identify and remove redundant attention heads (e.g., via attention entropy or importance scores).
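* **LayerDrop or Layer Pruning**: Remove entire transformer layers, chosen carefully by layer importance. LayerDrop applies structured dropout to whole layers during training so that layers can later be dropped at inference with little loss.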
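To make the difference between the first two criteria concrete, here is a toy NumPy sketch that scores the same weights by magnitude and by a SNIP-style sensitivity |w · ∂L/∂w|. The least-squares "layer", the random data, and the helper names are assumptions for illustration. Real SNIP computes these scores at initialization on a mini-batch of training data, and GraSP uses a different, Hessian-based criterion that is not shown here.

```python
import numpy as np

def magnitude_scores(w):
    """Magnitude criterion: importance = |w|. Needs no data."""
    return np.abs(w)

def snip_scores(w, grad):
    """SNIP-style connection sensitivity: |w * dL/dw|, normalized to sum to 1."""
    s = np.abs(w * grad)
    return s / s.sum()

def keep_mask(scores, sparsity):
    """Boolean mask that drops the lowest-scoring `sparsity` fraction of entries."""
    k = int(round(sparsity * scores.size))
    keep = np.ones(scores.size, dtype=bool)
    keep[np.argsort(scores, axis=None)[:k]] = False
    return keep.reshape(scores.shape)

# Toy least-squares "layer" (illustrative assumption):
# loss = 0.5 * ||X @ w - y||^2, so dL/dw = X.T @ (X @ w - y).
rng = np.random.default_rng(0)
X = rng.normal(size=(64, 4))
y = X @ np.array([0.0, 1.0, 0.0, 3.0])
w = np.array([0.5, 0.1, 2.0, 2.5])
grad = X.T @ (X @ w - y)

mag_mask = keep_mask(magnitude_scores(w), sparsity=0.5)
snip_mask = keep_mask(snip_scores(w, grad), sparsity=0.5)
print("magnitude keeps:", mag_mask)   # [False False  True  True]: the two largest |w|
print("SNIP-style keeps:", snip_mask)
```

The magnitude mask depends only on the weights, while the SNIP-style mask also depends on the data through the gradient, which is why gradient-based methods need a dataset.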

### Iterative Pruning and Fine-tuning

* Prefer **iterative pruning + fine-tuning** over one-shot pruning, especially at high sparsity:

  1. Prune a small percentage of weights.
  2. Fine-tune to recover accuracy.
  3. Repeat until target sparsity is reached.
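One simple way to choose how much to prune per round is a geometric schedule: remove the same fraction p of the *remaining* weights each round, so after n rounds the surviving fraction is (1 − p)^n. Setting that equal to 1 − target gives p = 1 − (1 − target)^(1/n). The sketch below assumes this schedule (linear or cubic ramps are also common) and takes the pruning and fine-tuning steps as hypothetical caller-supplied callbacks.

```python
def per_round_fraction(target_sparsity, rounds):
    """Fraction of the *remaining* weights to prune each round so that
    (1 - p) ** rounds == 1 - target_sparsity."""
    return 1.0 - (1.0 - target_sparsity) ** (1.0 / rounds)

def sparsity_schedule(target_sparsity, rounds):
    """Cumulative sparsity reached after each prune step."""
    p = per_round_fraction(target_sparsity, rounds)
    return [1.0 - (1.0 - p) ** r for r in range(1, rounds + 1)]

def iterative_prune(model, target_sparsity, rounds, prune_to, fine_tune):
    """Prune -> fine-tune loop. `prune_to(model, s)` and `fine_tune(model)` are
    hypothetical callbacks (e.g., magnitude pruning and a short training run)."""
    for s in sparsity_schedule(target_sparsity, rounds):
        prune_to(model, s)   # step 1: raise cumulative sparsity to s
        fine_tune(model)     # step 2: recover accuracy
    return model             # step 3: stop once the target is reached

print([round(s, 3) for s in sparsity_schedule(0.9, 5)])
# [0.369, 0.602, 0.749, 0.842, 0.9]
```

More rounds make each step smaller (closer to "a small percentage"), at the cost of more fine-tuning runs.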

### Prune Ratio

* Common practice: start with **20–30% sparsity**, then scale up.
* 80–90% sparsity is achievable, but usually at a noticeable accuracy cost unless more sophisticated techniques are used (e.g., iterative magnitude pruning with weight rewinding, as in the Lottery Ticket Hypothesis experiments, or movement pruning).

### Metrics to Track

* Perplexity (for language generation)
* Accuracy or F1 (classification), EM/F1 (extractive QA), or BLEU (translation and other generation tasks)
* Inference time (latency, throughput)
* FLOPs reduction
* Memory usage
* Sparsity %
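Sparsity % is the simplest of these metrics to compute yourself. A minimal sketch, assuming the weights are available as a list of NumPy arrays (for a Keras model, `model.get_weights()` returns such a list; for PyTorch, `[p.detach().cpu().numpy() for p in model.parameters()]` produces one):

```python
import numpy as np

def sparsity_percent(arrays):
    """Percentage of exactly-zero entries across a list of weight arrays."""
    total = sum(int(np.size(a)) for a in arrays)
    zeros = sum(int(np.size(a)) - int(np.count_nonzero(a)) for a in arrays)
    return 100.0 * zeros / total if total else 0.0

# Two small "layers": 4 of their 8 entries are zero.
weights = [np.array([[0.0, 1.0], [2.0, 0.0]]), np.array([0.0, 0.0, 3.0, 4.0])]
print(sparsity_percent(weights))  # 50.0
```

Counting every parameter (embeddings, biases, LayerNorm) dilutes the figure, so it is worth also reporting it over the prunable kernels alone.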

---

## 4. How to Know if the Model is Properly Pruned

A pruned model is considered “proper” when:

* **It maintains accuracy within acceptable loss bounds** (e.g., <1% drop in F1 or BLEU).
* **It exhibits desired sparsity levels** (e.g., 80% of weights pruned).
* **It improves inference efficiency** (faster or lower memory use).
* **It behaves consistently across datasets and tasks.**
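The criteria above can be turned into an explicit pass/fail check. The sketch below is hypothetical: the `PruningResult` fields and default thresholds are illustrative choices, not a standard. It treats `max_drop` as an absolute drop and requires every task to pass, which is one way to encode "consistent across datasets and tasks". The numbers in the example are made up.

```python
from dataclasses import dataclass

@dataclass
class PruningResult:
    task: str                # dataset or task name
    baseline_metric: float   # dense model's score (e.g., accuracy or F1, in 0..1)
    pruned_metric: float     # pruned model's score on the same data
    sparsity: float          # fraction of zero weights, 0..1
    baseline_latency: float  # seconds per batch, same batch and hardware
    pruned_latency: float

def is_properly_pruned(results, max_drop=0.01, target_sparsity=0.8):
    """Apply the checks to every task; `max_drop` is absolute (0.01 = 1 point)."""
    return all(
        r.baseline_metric - r.pruned_metric <= max_drop   # accuracy within bounds
        and r.sparsity >= target_sparsity                  # sparsity reached
        and r.pruned_latency < r.baseline_latency          # efficiency improved
        for r in results                                   # consistent across tasks
    )

results = [
    PruningResult("task_a", 0.910, 0.905, 0.82, 0.100, 0.070),
    PruningResult("task_b", 0.880, 0.850, 0.82, 0.100, 0.070),
]
print(is_properly_pruned(results))  # False: task_b drops 3 points
```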

### Tools:

* **Hugging Face Transformers + Optimum + ONNX Runtime**: Evaluate sparsity and performance.
* **TorchPruner, PyTorch Lightning, or SparseML**: Visualize and validate pruning effectiveness.
* **nn-Meter, DeepSparse, TensorRT, or TVM**: Predict (nn-Meter) or measure real-world speedups.

---

## 5. Dataset Considerations

### Do You Need a Dataset to Prune?

Usually. Pure magnitude pruning can score weights without any data, but **most pruning pipelines still require a dataset** for:

* **Scoring weight/feature importance**
* **Validating loss/performance drops**
* **Fine-tuning after pruning**

### Choosing the Right Dataset

* Use **task-aligned data**: If pruning a QA model, use SQuAD or Natural Questions.
* For general LLM pruning: Use a **diverse corpus** (e.g., WikiText-103, the Pile, OpenWebText).
* For language-specific tasks: Choose data in the target language or domain.
* You can use a **subset** (even 1–5%) of the training set for iterative pruning/fine-tuning.

---

## 6. Resource Requirements

Pruning is less demanding than pretraining, but still non-trivial:

* **Memory**: The full pretrained model must fit in memory, plus gradients and optimizer states whenever you fine-tune.
* **Compute** (rough estimates; actual cost depends heavily on model size, hardware, and fine-tuning budget):

  * Unstructured pruning: \~1–2 GPU hours per iteration.
  * Structured pruning + fine-tuning: \~10–100 GPU hours depending on model size and strategy.
* **Frameworks**:

  * Hugging Face Transformers + PyTorch/TensorFlow
  * SparseML, DeepSpeed, and OpenVINO for deployment
  * Intel Neural Compressor or NVIDIA TensorRT for inference speed-up
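A back-of-the-envelope estimate shows why the memory requirement matters. Assuming fp32 weights and the Adam optimizer, fine-tuning needs roughly 16 bytes per parameter before activations are counted: 4 for the weight, 4 for its gradient, and 8 for Adam's two moment estimates. This is a rule of thumb, not a measured figure, and GB here means 10^9 bytes.

```python
def training_memory_gb(num_params, bytes_per_param=16):
    """Rough memory for weights + gradients + Adam states during fine-tuning.

    16 bytes/param assumes fp32 weights (4), fp32 gradients (4) and Adam's two
    fp32 moment buffers (4 + 4). Activations, pruning masks, the data pipeline
    and framework overhead are NOT included, so treat this as a lower bound.
    """
    return num_params * bytes_per_param / 1e9

def weights_memory_gb(num_params, bytes_per_param=4):
    """Memory for the weights alone (4 bytes = fp32, 2 bytes = fp16/bf16)."""
    return num_params * bytes_per_param / 1e9

# Example: a hypothetical 7-billion-parameter model.
print(training_memory_gb(7e9))    # 112.0
print(weights_memory_gb(7e9, 2))  # 14.0
```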

---

## 7. Concerns and Pitfalls

### Common Pitfalls

* **Loss of performance**: Especially if pruning is too aggressive without fine-tuning.
* **Unstructured pruning ≠ speedup**: Sparse weights don’t speed up inference unless the backend supports sparse computation (e.g., DeepSparse).
* **Layer collapse**: Removing critical layers can destabilize the model.
* **Poor generalization**: Over-pruned models may overfit to narrow domains.
* **Deployment compatibility**: Some runtimes do not support sparse models natively.
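The first two pitfalls come down to storage and compute formats. As a rough illustration, assume a sparse matrix is stored in CSR (compressed sparse row) format with 32-bit values and 32-bit column indices. At 50% unstructured sparsity the CSR copy is slightly *larger* than the dense one, and memory savings only appear at higher sparsity; speedups additionally require kernels that actually skip the zeros. The matrix size below is an arbitrary example.

```python
def dense_bytes(rows, cols, value_bytes=4):
    """Dense fp32 matrix: every entry is stored, zero or not."""
    return rows * cols * value_bytes

def csr_bytes(rows, cols, sparsity, value_bytes=4, index_bytes=4):
    """CSR: one value + one column index per nonzero, plus rows + 1 row pointers."""
    nnz = round(rows * cols * (1.0 - sparsity))
    return nnz * (value_bytes + index_bytes) + (rows + 1) * index_bytes

# A 4096 x 4096 fp32 weight matrix.
for s in (0.5, 0.9):
    print(s, dense_bytes(4096, 4096), csr_bytes(4096, 4096, s))
# 0.5 67108864 67125252
# 0.9 67108864 13438164
```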

### Things to Avoid

* One-shot high-rate pruning
* Pruning without validation metrics
* Ignoring batch norm / layer norm interactions
* Using task-irrelevant data during fine-tuning

---

## 8. Success Stories and Applications

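* **Movement Pruning (Hugging Face, 2020)**: Showed that pruning by weight *movement* during fine-tuning significantly outperforms magnitude pruning on BERT in high-sparsity regimes; combined with distillation, it reached minimal accuracy loss with only 3% of the parameters remaining.
* **DistilBERT**: Used knowledge distillation (initializing the student from every other layer of BERT) to produce a model 40% smaller and 60% faster than BERT-base while retaining about 97% of its language-understanding performance; distillation is often combined with pruning for further compression.
* **SparseGPT (IST Austria, 2023)**: A one-shot pruning method for GPT-family models such as OPT-175B and BLOOM-176B, reaching at least 50% (and up to 60%) unstructured sparsity without retraining and with a negligible increase in perplexity.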
* **DeepSparse + Neural Magic**: Used for production-level sparse inference acceleration, especially in NLP tasks.
* **Transformer Head Pruning in BERT**: Showed that 30–40% of attention heads can be pruned without major accuracy loss.

---

## 9. Summary

| Aspect           | Recommendation                                                                  |
| ---------------- | ------------------------------------------------------------------------------- |
| **Goal**         | Reduce size/inference cost with minimal performance loss                        |
| **Strategy**     | Iterative pruning + fine-tuning preferred                                       |
| **Pruning type** | Structured (e.g., heads/layers) for deployability; unstructured for compression |
| **Dataset**      | Required for scoring/fine-tuning; should match task                             |
| **Toolkits**     | Hugging Face + SparseML, DeepSparse, Neural Compressor, TorchPruner             |
| **Verification** | Track sparsity, accuracy, latency, memory                                       |
| **Pitfalls**     | Overpruning, unstructured-only pruning, lack of fine-tuning                     |

---

## 10. Further Reading & Resources

* **Papers**

  * ["The Lottery Ticket Hypothesis"](https://arxiv.org/abs/1803.03635)
  * ["SparseGPT"](https://arxiv.org/abs/2301.00774)
  * ["Movement Pruning"](https://arxiv.org/abs/2005.07683)

* **Libraries**

  * [SparseML](https://github.com/neuralmagic/sparseml)
  * [Hugging Face Optimum](https://huggingface.co/docs/optimum/index)
  * [Intel Neural Compressor](https://github.com/intel/neural-compressor)
  * [NVIDIA TensorRT](https://developer.nvidia.com/tensorrt)

---

# Lab: Pruning an LLM Using TensorFlow
**Dataset:** SST-2 (GLUE)

**Constraints:**

- ✅ Runs on CPU
- ✅ Uses TF-compatible Hugging Face models

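```python
import os
# On TensorFlow 2.16+, make `tf.keras` use Keras 2 (`pip install tf-keras`), which
# Transformers' TF models require. This must be set before TensorFlow is imported.
os.environ["TF_USE_LEGACY_KERAS"] = "1"

import time
import numpy as np
import tensorflow as tf
from datasets import load_dataset
from transformers import AutoTokenizer, TFAutoModel
```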

### 1. Load and Preprocess SST-2 Dataset

```python
dataset = load_dataset("glue", "sst2")
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

MAX_LEN = 128
BATCH_SIZE = 32

def tokenize_function(example):
    return tokenizer(
        example["sentence"], padding="max_length", truncation=True, max_length=MAX_LEN
    )

tokenized = dataset.map(tokenize_function, batched=True)
tokenized.set_format(type='tensorflow', columns=['input_ids', 'attention_mask', 'label'])
```

#### Convert to tf.data.Dataset

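```python
def convert_to_tf_dataset(tokenized_dataset, shuffle=False):
    ds = tf.data.Dataset.from_tensor_slices((
        {
            "input_ids": tokenized_dataset["input_ids"],
            "attention_mask": tokenized_dataset["attention_mask"]
        },
        tokenized_dataset["label"]
    ))
    if shuffle:
        ds = ds.shuffle(buffer_size=len(tokenized_dataset))  # reshuffled every epoch
    return ds.batch(BATCH_SIZE)

# A 5,000-example random training subset keeps CPU training short.
# Note: the first argument of `Dataset.shuffle` is the random seed, not a buffer size.
train_ds = convert_to_tf_dataset(tokenized['train'].shuffle(seed=42).select(range(5000)), shuffle=True)
val_ds = convert_to_tf_dataset(tokenized['validation'])
```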

### Load Pretrained Model

```python
base_model = TFAutoModel.from_pretrained("distilbert-base-uncased")

# Freeze the DistilBERT encoder: only the new classification head is trained,
# which keeps training fast on CPU.
base_model.trainable = False
```

### Build Classifier Model

```python
inputs = {
    "input_ids": tf.keras.Input(shape=(MAX_LEN,), dtype=tf.int32, name="input_ids"),
    "attention_mask": tf.keras.Input(shape=(MAX_LEN,), dtype=tf.int32, name="attention_mask")
}

outputs = base_model(inputs)[0][:, 0, :]  # [CLS] token
outputs = tf.keras.layers.Dense(64, activation='relu', name="dense_1")(outputs)
outputs = tf.keras.layers.Dense(1, activation='sigmoid', name="classifier")(outputs)

model = tf.keras.Model(inputs=inputs, outputs=outputs)

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=3e-5),
    loss=tf.keras.losses.BinaryCrossentropy(),
    metrics=["accuracy"]
)

model.summary()
```

### Train Baseline Model

```python
print("Training baseline model...")
model.fit(train_ds, validation_data=val_ds, epochs=2)
# Note: this can take up to 10 minutes per epoch on CPU.
```

#### Evaluate baseline

```python
baseline_loss, baseline_acc = model.evaluate(val_ds)
print(f"\nBaseline Accuracy: {baseline_acc:.4f}")
```

#### Measure inference time

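```python
sample_batch = next(iter(val_ds))
_ = model.predict(sample_batch[0], verbose=0)  # warm-up: exclude one-time graph tracing
start_time = time.perf_counter()
_ = model.predict(sample_batch[0], verbose=0)
print(f"Baseline inference time: {time.perf_counter() - start_time:.4f} sec")
```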

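### Weight Pruning

The function below applies unstructured magnitude pruning to every top-level `Dense` layer of the Keras model. Because the DistilBERT encoder is wrapped as a single layer, only the two classification-head layers (`dense_1` and `classifier`) are pruned; the transformer weights themselves are left untouched.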

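```python
def prune_weights_by_magnitude(model, sparsity=0.5):
    """Zero the smallest-magnitude `sparsity` fraction of each top-level Dense kernel, in place.

    Only layers listed directly in `model.layers` are visited; biases are left untouched.
    """
    for layer in model.layers:
        if isinstance(layer, tf.keras.layers.Dense):
            weights, biases = layer.get_weights()
            # Per-layer threshold: the `sparsity`-th percentile of absolute kernel values.
            threshold = np.percentile(np.abs(weights), sparsity * 100)
            pruned_weights = np.where(np.abs(weights) < threshold, 0.0, weights)
            layer.set_weights([pruned_weights, biases])
    return model

# Apply pruning. This modifies `model` in place, so `pruned_model` is the same object
# as `model`; the baseline metrics must be recorded before this step.
print("\nPruning weights manually...")
pruned_model = prune_weights_by_magnitude(model, sparsity=0.5)
```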
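The Dense-layer loop above never reaches the encoder. One way to prune DistilBERT's own attention and feed-forward kernels is to operate on its variables directly. This sketch is an extension under stated assumptions, not part of the original lab: it assumes the encoder's `Dense` kernels have "kernel" in their variable names while embeddings, biases and LayerNorm parameters do not, so print and inspect the returned names before trusting the filter. It only relies on each variable exposing `.name`, `.numpy()` and `.assign()`, as `tf.Variable` does, and a frozen model's variables can still be modified with `assign`. If you use it, call it before the evaluation step below; without fine-tuning, the accuracy drop is likely larger than with head-only pruning.

```python
import numpy as np

def prune_kernels_by_magnitude(variables, sparsity=0.5, name_filter="kernel"):
    """Zero the smallest-|w| fraction of each 2-D variable whose name contains `name_filter`.

    Intended for `base_model.weights`. Returns the names of the pruned variables.
    """
    pruned = []
    for var in variables:
        values = var.numpy()
        if values.ndim != 2 or name_filter not in var.name:
            continue  # skip biases, LayerNorm, and (by assumption) embeddings
        threshold = np.percentile(np.abs(values), sparsity * 100)
        var.assign(np.where(np.abs(values) < threshold, 0.0, values).astype(values.dtype))
        pruned.append(var.name)
    return pruned

# Hypothetical usage in the lab:
# pruned_names = prune_kernels_by_magnitude(base_model.weights, sparsity=0.5)
# print(len(pruned_names), pruned_names[:3])
```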

#### Evaluate pruned model

```python
pruned_loss, pruned_acc = pruned_model.evaluate(val_ds)
print(f"\nPruned Accuracy: {pruned_acc:.4f}")
```

#### Measure inference time after pruning

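```python
# Zeroed weights keep their dense shapes, so a dense backend does the same work:
# expect little or no speedup from unstructured pruning here.
start_time = time.perf_counter()
_ = pruned_model.predict(sample_batch[0], verbose=0)
print(f"Pruned inference time: {time.perf_counter() - start_time:.4f} sec")
```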

#### Save Pruned Model

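```python
# With Keras 2 / tf-keras this writes a TensorFlow SavedModel directory; Keras 3
# would instead require a file name ending in ".keras".
pruned_model.save("pruned_model")
print("\nPruned model saved as 'pruned_model'")
```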