# Tutorial 10 — Pruning, Distillation, and Speculative Decoding

In this tutorial we explore three widely used techniques for compressing and accelerating neural networks:

1. **Magnitude pruning**, which removes the parameters that contribute the least to a model's predictions.
2. **Knowledge distillation**, which transfers knowledge from a larger teacher network into a smaller student.
3. **Speculative decoding**, which speeds up auto-regressive text generation by combining a lightweight draft model with a larger target model.

We start with a toy classification problem so that we can reason about the effects of pruning and distillation visually. We then distil models from the `qwen3` family with Hugging Face SetFit, following its knowledge distillation recipe, and finish with speculative decoding on the corresponding causal language models.


## 1. Setup

We rely on PyTorch for modeling and `matplotlib` for visualization. The cell below imports the dependencies, seeds the Python and PyTorch random number generators for reproducibility, and selects the compute device.

In [None]:
import math
import random
import time
from dataclasses import dataclass
from typing import Callable, Dict, Iterable, List, Optional, Tuple

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {DEVICE}')

## 2. A Toy Classification Task

To make the ideas concrete we create a two-dimensional synthetic dataset with two classes that are (nearly) linearly separable. That allows us to train a single-layer linear classifier and easily visualize both the data and the resulting decision boundary.

We draw points from two Gaussians and assign labels $y \in \{0, 1\}$. The model learns a linear decision function $f(\mathbf{x}) = \mathbf{w}^\top \mathbf{x} + b$. The logits are converted into probabilities via the sigmoid function $\sigma$, and we optimize the binary cross-entropy loss:

$$\ell(\mathbf{w}, b) = -\frac{1}{N} \sum_{i=1}^N \Big[y_i \log \sigma(f(\mathbf{x}_i)) + (1 - y_i) \log \big(1 - \sigma(f(\mathbf{x}_i))\big)\Big].$$
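As a quick sanity check, the loss above can be compared numerically with PyTorch's `F.binary_cross_entropy_with_logits`, which the training loop below relies on. The random logits and labels are illustrative stand-ins, not the tutorial's dataset:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(8)                      # f(x_i) for N = 8 illustrative examples
labels = torch.randint(0, 2, (8,)).float()   # y_i in {0, 1}
probs = torch.sigmoid(logits)

# Direct transcription of the binary cross-entropy formula.
manual = -(labels * torch.log(probs) + (1 - labels) * torch.log(1 - probs)).mean()
# Fused, numerically stable version that takes raw logits.
fused = F.binary_cross_entropy_with_logits(logits, labels)
print(manual.item(), fused.item())
```

The fused version is preferred in practice because it avoids taking the log of probabilities that round to 0 or 1.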

In [None]:
def make_linear_dataset(n_per_class: int = 200, spread: float = 0.6):
    mean_pos = torch.tensor([1.5, 1.5])
    mean_neg = torch.tensor([-1.5, -1.5])
    cov = torch.eye(2) * spread
    pos = torch.distributions.MultivariateNormal(mean_pos, cov).sample((n_per_class,))
    neg = torch.distributions.MultivariateNormal(mean_neg, cov).sample((n_per_class,))
    X = torch.cat([pos, neg], dim=0)
    y = torch.cat([torch.ones(n_per_class), torch.zeros(n_per_class)], dim=0)
    perm = torch.randperm(len(X))
    return X[perm], y.long()[perm]


X, y = make_linear_dataset()
print(X.shape, y.shape)

In [None]:
fig, ax = plt.subplots(figsize=(5, 5))
for label, marker, color in [(0, 'o', '#2ca02c'), (1, '^', '#d62728')]:
    mask = y == label
    ax.scatter(X[mask, 0], X[mask, 1], marker=marker, color=color, label=f'class {label}')
ax.set_xlabel('$x_1$')
ax.set_ylabel('$x_2$')
ax.set_title('Synthetic 2D dataset')
ax.legend()
plt.show()

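### 2.1 Training a Single-Layer Classifier

Our model is a single dense layer that outputs one logit; the sigmoid is folded into the loss via `F.binary_cross_entropy_with_logits`. For convenience we package a full-batch training loop that records the loss and accuracy at every epoch, plus a helper that plots both curves. Both metrics are measured on the training data, since this toy example has no held-out split.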
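For a linear model, the gradient of the binary cross-entropy from Section 2 has a simple closed form, $\nabla_{\mathbf{w}} \ell = \frac{1}{N}\sum_{i} \big(\sigma(f(\mathbf{x}_i)) - y_i\big)\,\mathbf{x}_i$ and $\partial \ell / \partial b = \frac{1}{N}\sum_{i} \big(\sigma(f(\mathbf{x}_i)) - y_i\big)$. Each SGD step in a loop like `train_classifier` follows this direction. A small sketch comparing the formula with autograd on random data (the `_demo` tensors are illustrative and deliberately named so they do not overwrite the notebook's `X` and `y`):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
X_demo = torch.randn(16, 2)                       # illustrative inputs, N = 16
y_demo = torch.randint(0, 2, (16,)).float()       # illustrative labels
w = torch.randn(2, requires_grad=True)
b = torch.zeros(1, requires_grad=True)

loss = F.binary_cross_entropy_with_logits(X_demo @ w + b, y_demo)
loss.backward()

# Closed-form gradients: each input is weighted by its residual sigma(f(x_i)) - y_i.
with torch.no_grad():
    residual = torch.sigmoid(X_demo @ w + b) - y_demo
    grad_w = (residual[:, None] * X_demo).mean(dim=0)
    grad_b = residual.mean()
print(w.grad, grad_w)
```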

In [None]:
class LinearClassifier(nn.Module):
    def __init__(self, in_dim: int):
        super().__init__()
        self.linear = nn.Linear(in_dim, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x).squeeze(-1)


def train_classifier(model: nn.Module, data: Tuple[torch.Tensor, torch.Tensor], epochs: int = 200, lr: float = 0.1):
    X, y = data
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    history = []
    for epoch in range(epochs):
        model.train()
        logits = model(X)
        loss = F.binary_cross_entropy_with_logits(logits, y.float())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        with torch.no_grad():
            preds = (logits.sigmoid() > 0.5).long()
            acc = (preds == y).float().mean().item()
        history.append((loss.item(), acc))
    return history


def plot_history(history, title: str):
    losses, accs = zip(*history)
    fig, ax = plt.subplots(1, 2, figsize=(10, 3))
    ax[0].plot(losses)
    ax[0].set_title(f'{title} loss')
    ax[0].set_xlabel('epoch')
    ax[0].set_ylabel('loss')
    ax[1].plot(accs)
    ax[1].set_title(f'{title} accuracy')
    ax[1].set_xlabel('epoch')
    ax[1].set_ylabel('accuracy')
    plt.show()

In [None]:
model = LinearClassifier(in_dim=2)
history = train_classifier(model, (X, y))
plot_history(history, 'Linear classifier')

### 2.2 Visualizing the Decision Boundary

The decision boundary is the set of points where the predicted probability equals 0.5, i.e. $f(\mathbf{x}) = 0$. We evaluate the model on a dense grid and color the prediction regions to see how the linear separator aligns with the data.
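For a linear model the boundary can also be written in closed form: solving $w_1 x_1 + w_2 x_2 + b = 0$ gives $x_2 = -(w_1 x_1 + b)/w_2$ whenever $w_2 \neq 0$. The sketch below uses hypothetical parameter values and checks that points on this line receive probability 0.5:

```python
import torch

# Hypothetical parameters of f(x) = w^T x + b (not the trained model's values).
w = torch.tensor([2.0, 1.0])
b = torch.tensor(-1.0)

# Solve w1 * x1 + w2 * x2 + b = 0 for x2 (requires w2 != 0).
x1 = torch.linspace(-3.0, 3.0, 5)
x2 = -(w[0] * x1 + b) / w[1]
points = torch.stack([x1, x2], dim=-1)

# Every point on the line has logit 0, hence probability sigmoid(0) = 0.5.
probs = torch.sigmoid(points @ w + b)
print(probs)
```

With the trained classifier, the parameters can be read from `model.linear.weight[0]` and `model.linear.bias[0]`.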

In [None]:
def plot_decision_boundary(model: nn.Module, X: torch.Tensor, y: torch.Tensor, title: str):
    model.eval()
    x_min, x_max = X[:, 0].min().item() - 1.0, X[:, 0].max().item() + 1.0
    y_min, y_max = X[:, 1].min().item() - 1.0, X[:, 1].max().item() + 1.0
    grid_x, grid_y = torch.meshgrid(
        torch.linspace(x_min, x_max, 200),
        torch.linspace(y_min, y_max, 200),
        indexing='ij',
    )
    grid = torch.stack([grid_x.flatten(), grid_y.flatten()], dim=-1)
    with torch.no_grad():
        logits = model(grid)
        # LinearClassifier returns one logit per point (sigmoid); MLPClassifier returns
        # two logits per point (softmax), so take the probability of class 1 in that case.
        if logits.dim() == 2:
            probs = logits.softmax(dim=-1)[:, 1]
        else:
            probs = logits.sigmoid()
        probs = probs.reshape(200, 200)
    fig, ax = plt.subplots(figsize=(5, 5))
    contour = ax.contourf(grid_x.numpy(), grid_y.numpy(), probs.numpy(), levels=50, cmap='RdBu', alpha=0.8)
    fig.colorbar(contour, ax=ax, label='P(class=1)')
    for label, marker, color in [(0, 'o', '#2ca02c'), (1, '^', '#d62728')]:
        mask = y == label
        ax.scatter(X[mask, 0], X[mask, 1], marker=marker, color=color, edgecolors='k', label=f'class {label}')
    ax.set_title(title)
    ax.set_xlabel('$x_1$')
    ax.set_ylabel('$x_2$')
    ax.legend()
    plt.show()


plot_decision_boundary(model, X, y, 'Linear model decision boundary')

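## 3. Magnitude Pruning

Magnitude pruning removes parameters with the smallest absolute values, under the assumption that their contribution to the output is limited. For a weight matrix $\mathbf{W}$ we compute a mask $\mathbf{M}$ such that

$$M_{ij} = \begin{cases} 0 & \text{if } |W_{ij}| < \tau, \\ 1 & \text{otherwise,} \end{cases}$$

where the threshold $\tau$ is chosen so that a desired sparsity level is achieved. The implementation below uses a single global $\tau$, the `amount`-quantile of all weight magnitudes (biases are left untouched), and multiplies each weight tensor by its mask. The pruned weights are then zero, but they only stay zero during further training if the mask is re-applied after every optimizer step.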
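PyTorch also ships pruning utilities in `torch.nn.utils.prune` that keep pruned weights at zero automatically: after pruning, each module holds a trainable `weight_orig` and a `weight_mask` buffer, and a forward pre-hook recomputes `weight = weight_orig * weight_mask` on every forward pass. A minimal sketch of global magnitude (L1) pruning on a hypothetical toy MLP, followed by one ordinary training step (the `_demo` tensors are illustrative):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.prune as prune

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 2))  # hypothetical toy MLP
params = [(net[0], 'weight'), (net[2], 'weight')]

# Global magnitude pruning: zero the 60% of weights with the smallest |w| across both layers.
prune.global_unstructured(params, pruning_method=prune.L1Unstructured, amount=0.6)

# One ordinary training step; gradients reach weight_orig, and the mask is re-applied on forward.
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
x_demo, y_demo = torch.randn(32, 8), torch.randint(0, 2, (32,))
loss = F.cross_entropy(net(x_demo), y_demo)
optimizer.zero_grad()
loss.backward()
optimizer.step()
net(x_demo)  # forward pass refreshes `weight` from the updated weight_orig and the mask

total = sum(m.weight.numel() for m, _ in params)
zeros = sum(int((m.weight == 0).sum()) for m, _ in params)
print(f'sparsity after training step: {zeros / total:.2f}')
```

Calling `prune.remove(module, 'weight')` afterwards makes the pruning permanent by folding the mask into a regular `weight` parameter.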

In [None]:
def magnitude_prune(model: nn.Module, amount: float = 0.5) -> Dict[str, torch.Tensor]:
    assert 0.0 <= amount < 1.0
    with torch.no_grad():
        weights = torch.cat([
            param.abs().flatten() for name, param in model.named_parameters() if 'weight' in name
        ])
        threshold = torch.quantile(weights, amount)
        masks = {}
        for name, param in model.named_parameters():
            if 'weight' in name:
                mask = (param.abs() >= threshold).float()
                param.mul_(mask)
                masks[name] = mask
        return masks


pruned_masks = magnitude_prune(model, amount=0.6)
print({name: mask.mean().item() for name, mask in pruned_masks.items()})

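The linear classifier has only two weights, so the global 60% quantile falls strictly between their magnitudes and `magnitude_prune` zeroes exactly one of them: the printed mask density is 0.5 rather than 0.4. Removing one input weight turns the diagonal decision boundary into an axis-aligned one, yet accuracy stays high because the two class means are well separated along both axes. Even here, substantial sparsity costs little accuracy, although the exact trade-off depends on the model and the data.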
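The achieved sparsity can differ from `amount` when a model has very few weights, because the threshold can only remove whole weights. A small worked example with two illustrative weight values, mirroring the two-input linear classifier; `torch.quantile` interpolates linearly between the sorted magnitudes by default:

```python
import torch

weights = torch.tensor([0.3, -1.2])                # illustrative values, not the trained weights
threshold = torch.quantile(weights.abs(), 0.6)     # 0.3 + 0.6 * (1.2 - 0.3) = 0.84
mask = (weights.abs() >= threshold).float()        # keeps only the larger-magnitude weight
print(threshold.item(), mask.tolist(), 1 - mask.mean().item())
```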

In [None]:
plot_decision_boundary(model, X, y, 'Decision boundary after pruning')
with torch.no_grad():
    logits = model(X)
    preds = (logits.sigmoid() > 0.5).long()
    acc = (preds == y).float().mean().item()
print(f'Accuracy after pruning: {acc:.3f}')

## 4. Knowledge Distillation on the Classification Task

Knowledge distillation (KD) transfers information from a large teacher network to a smaller student. Let $z_t$ and $z_s$ denote the teacher and student logits, respectively, and let $\sigma$ denote the softmax (rather than the sigmoid of Section 2). KD minimizes a convex combination of the standard cross-entropy with the ground-truth labels and a temperature-scaled KL divergence between the teacher and student distributions:

$$\mathcal{L}_{\text{KD}} = (1 - \alpha)\, \mathcal{L}_{\text{CE}}(z_s, y) + \alpha\, T^2\, \mathrm{KL}\big(\sigma(z_t / T) \,\|\, \sigma(z_s / T)\big).$$

The temperature $T$ softens both probability distributions, $\alpha \in [0, 1]$ controls the trade-off between hard labels and soft targets, and the $T^2$ factor keeps the gradient magnitude of the soft term roughly independent of $T$.
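The soft term can be checked against a direct transcription of the KL divergence. `F.kl_div` expects log-probabilities as its first argument and probabilities as its target, and `reduction='batchmean'` sums over classes and averages over the batch, which matches the mathematical definition. The random logits below are illustrative:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
T = 2.0
z_t = torch.randn(4, 3)   # illustrative teacher logits: batch of 4, 3 classes
z_s = torch.randn(4, 3)   # illustrative student logits

p_t = F.softmax(z_t / T, dim=-1)
log_p_s = F.log_softmax(z_s / T, dim=-1)

# KL(p_t || p_s), summed over classes and averaged over the batch.
manual = (p_t * (torch.log(p_t) - log_p_s)).sum(dim=-1).mean()
builtin = F.kl_div(log_p_s, p_t, reduction='batchmean')
soft_term = T ** 2 * builtin   # the alpha-weighted soft term before multiplying by alpha
print(manual.item(), builtin.item(), soft_term.item())
```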

In [None]:
class MLPClassifier(nn.Module):
    def __init__(self, in_dim: int, hidden_dim: int, out_dim: int = 2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


def train_teacher(model: nn.Module, data: Tuple[torch.Tensor, torch.Tensor], epochs: int = 200, lr: float = 0.05):
    X, y = data
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    history = []
    for epoch in range(epochs):
        model.train()
        logits = model(X)
        loss = F.cross_entropy(logits, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        with torch.no_grad():
            preds = logits.argmax(dim=-1)
            acc = (preds == y).float().mean().item()
        history.append((loss.item(), acc))
    return history


def distillation_step(student: nn.Module, teacher: nn.Module, inputs: torch.Tensor, targets: torch.Tensor,
                      optimizer: torch.optim.Optimizer, alpha: float = 0.7, temperature: float = 2.0):
    student.train()
    teacher.eval()
    student_logits = student(inputs)
    with torch.no_grad():
        teacher_logits = teacher(inputs)
    hard_loss = F.cross_entropy(student_logits, targets)
    soft_loss = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=-1),
        F.softmax(teacher_logits / temperature, dim=-1),
        reduction='batchmean',
    ) * (temperature ** 2)
    loss = alpha * soft_loss + (1 - alpha) * hard_loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    with torch.no_grad():
        preds = student_logits.argmax(dim=-1)
        acc = (preds == targets).float().mean().item()
    return loss.item(), acc


teacher = MLPClassifier(in_dim=2, hidden_dim=16)
teacher_history = train_teacher(teacher, (X, y))
plot_history(teacher_history, 'Teacher (2-layer MLP)')

In [None]:
student = MLPClassifier(in_dim=2, hidden_dim=4)
optimizer = torch.optim.Adam(student.parameters(), lr=0.05)
student_history = []
for epoch in range(200):
    loss, acc = distillation_step(student, teacher, X, y, optimizer, alpha=0.8, temperature=3.0)
    student_history.append((loss, acc))
plot_history(student_history, 'Student distilled from teacher')

In [None]:
plot_decision_boundary(teacher, X, y, 'Teacher decision boundary')
plot_decision_boundary(student, X, y, 'Student after knowledge distillation')

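The student network uses a quarter of the teacher's hidden units (4 versus 16) and ends up with a similar decision boundary. Because this toy problem is nearly linearly separable, a student trained on hard labels alone would likely do just as well, so the example illustrates the mechanics of KD rather than proving its benefit; soft targets pay off mainly on harder tasks, where the teacher's relative class probabilities carry information that hard labels lack. At small scale, this mirrors how KD is used for model compression in practice.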
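Counting parameters makes the compression ratio concrete. This sketch rebuilds the same `Linear -> ReLU -> Linear` architecture as `MLPClassifier` with a hypothetical helper `mlp`, so it runs on its own:

```python
import torch.nn as nn


def mlp(hidden_dim: int) -> nn.Module:
    """Hypothetical helper mirroring MLPClassifier(in_dim=2, hidden_dim, out_dim=2)."""
    return nn.Sequential(nn.Linear(2, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, 2))


def n_params(module: nn.Module) -> int:
    return sum(p.numel() for p in module.parameters())


teacher_params = n_params(mlp(16))   # (2*16 + 16) + (16*2 + 2)
student_params = n_params(mlp(4))    # (2*4 + 4) + (4*2 + 2)
print(teacher_params, student_params)
```

This gives 82 parameters for the teacher and 22 for the student, roughly a 3.7× reduction.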

## 5. Distilling Qwen Models with SetFit-Style Knowledge Distillation

The official [SetFit knowledge distillation guide](https://huggingface.co/docs/setfit/en/how_to/knowledge_distillation) demonstrates how to distill a compact student from a larger teacher by matching the teacher's soft predictions.
We follow the same recipe with `qwen3 4B` acting as the teacher and `qwen3 0.6B` as the student, keeping the code minimal while still exposing every step of the workflow.

In [None]:
try:
    from datasets import Dataset, load_dataset
except ImportError as err:
    raise ImportError('Please install the `datasets` library to run the distillation section.') from err
try:
    from transformers import AutoModelForCausalLM, AutoTokenizer
except ImportError as err:
    raise ImportError('Please install `transformers` to access Qwen checkpoints.') from err
try:
    from setfit import SetFitModel, SetFitTrainer
    from setfit.losses import CosineSimilarityLoss
except ImportError as err:
    raise ImportError('Please install `setfit` to reproduce the SetFit knowledge distillation pipeline.') from err

### 5.1 Loading a FineWeb-Edu Slice

We adopt a compact slice of [FineWeb-Edu](https://huggingface.co/datasets/HuggingFaceFW/fineweb-edu) and derive pseudo-labels by splitting examples into two buckets: passages with more than forty whitespace-separated words versus shorter snippets.
This mirrors the SetFit tutorial's use of lightweight classification problems. If the dataset cannot be downloaded, the code falls back to a handful of handcrafted sentences so the notebook still runs; every fallback sentence is shorter than forty words, so the fallback produces a single class and cannot yield a meaningful classifier.

In [None]:
teacher_model_id = 'Qwen/Qwen3-4B'
student_model_id = 'Qwen/Qwen3-0.6B'
NUM_LABELS = 2
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def _clean_text(text: str) -> str:
    return ' '.join(text.strip().split())

try:
    # FineWeb-Edu has no 'sample-10K' config; stream the 'sample-10BT' sample and keep
    # only the first 2,000 records instead of downloading the full config.
    raw_fineweb = load_dataset('HuggingFaceFW/fineweb-edu', name='sample-10BT', split='train', streaming=True)
    raw_texts = [row['text'] for row in raw_fineweb.take(2000)]
    print('Loaded FineWeb-Edu sample with', len(raw_texts), 'records.')
    texts = [_clean_text(t) for t in raw_texts if t and t.strip()]
except Exception as err:
    print('Falling back to a handcrafted FineWeb-style corpus:', err)
    fallback_texts = [
        'Machine learning enables computers to solve tasks using data gathered in classrooms.',
        'Education technology leverages AI to personalize student experiences across subjects and grade levels.',
        'Knowledge distillation transfers the behaviour of a larger teacher model into a smaller student without requiring fresh labels.',
        'Pruning removes redundant parameters to create efficient neural networks that run on modest hardware.',
        'Speculative decoding accelerates generation by validating batches of draft tokens in fewer target forward passes.',
        'Short study guides can still convey essential facts when they are distilled from longer lessons.',
    ]
    texts = [_clean_text(t) for t in fallback_texts]

if not texts:
    raise ValueError('No text examples are available for distillation.')

dataset = Dataset.from_dict({'text': texts})

def _length_bin(example):
    # Pseudo-label: 1 for passages longer than 40 whitespace-separated words, else 0.
    token_count = len(example['text'].split())
    return {'label': int(token_count > 40)}

dataset = dataset.filter(lambda example: bool(example['text'].strip()))
dataset = dataset.map(_length_bin)
dataset = dataset.shuffle(seed=42)
splits = dataset.train_test_split(test_size=0.1, seed=42)
train_dataset = splits['train']
eval_dataset = splits['test']

label_distribution = {}
for example in train_dataset:
    label = int(example['label'])
    label_distribution[label] = label_distribution.get(label, 0) + 1
print('Train size:', len(train_dataset), '| Eval size:', len(eval_dataset), '| Label distribution:', label_distribution)

# The tokenizer is only needed for generation in Section 6.
tokenizer = AutoTokenizer.from_pretrained(student_model_id, trust_remote_code=True)
tokenizer.padding_side = 'left'
tokenizer.truncation_side = 'left'

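### 5.2 Fine-Tuning the Teacher with SetFit

The teacher is a pretrained `qwen3 4B` checkpoint equipped with a linear classification head. SetFit loads the checkpoint through sentence-transformers, which mean-pools the hidden states into a sentence embedding when the checkpoint has no pooling configuration of its own.
Following the SetFit recipe, training proceeds in two stages. First, the embedding body is fine-tuned contrastively with a cosine-similarity objective on sentence pairs, pulling same-label pairs together and pushing different-label pairs apart. This contrastive stage is not itself a classification loss; the second stage supplies one by fitting the head on the frozen embeddings, minimising the supervised cross-entropy

$$\mathcal{L}_{\text{sup}} = -\frac{1}{N} \sum_{i=1}^N y_i^\top \log \hat{y}_i,$$

where $y_i$ is the one-hot label and $\hat{y}_i$ is the head's predicted class distribution.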
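The first SetFit stage can be illustrated without any model. The sketch below uses a hypothetical helper, `make_contrastive_pairs`, that enumerates every pair exhaustively (SetFit itself generates pairs by sampling, controlled by `num_iterations` or its sampling strategy) and labels a pair 1.0 when both texts share a class and 0.0 otherwise. It then evaluates a `CosineSimilarityLoss`-style objective, the mean squared error between the cosine similarity of two embeddings and the pair label, on random stand-in embeddings:

```python
from itertools import combinations
from typing import List, Tuple

import torch
import torch.nn.functional as F


def make_contrastive_pairs(texts: List[str], labels: List[int]) -> List[Tuple[str, str, float]]:
    """Hypothetical helper: every unordered pair, labelled 1.0 if same class else 0.0."""
    return [
        (text_a, text_b, 1.0 if label_a == label_b else 0.0)
        for (text_a, label_a), (text_b, label_b) in combinations(zip(texts, labels), 2)
    ]


demo_texts = ['a short note', 'another short note', 'a much longer passage about photosynthesis']
demo_labels = [0, 0, 1]
pairs = make_contrastive_pairs(demo_texts, demo_labels)

# Random stand-ins for the sentence embeddings produced by the model body.
torch.manual_seed(0)
embeddings = {text: torch.randn(8) for text in demo_texts}
cosines = torch.stack([F.cosine_similarity(embeddings[a], embeddings[b], dim=0) for a, b, _ in pairs])
pair_labels = torch.tensor([label for _, _, label in pairs])

# CosineSimilarityLoss-style objective: MSE between cosine similarity and pair label.
loss = F.mse_loss(cosines, pair_labels)
print(pairs)
print(loss.item())
```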

In [None]:
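# SetFit's default head is a scikit-learn LogisticRegression, i.e. a linear classifier.
# `head_params={'out_features': ...}` only applies with `use_differentiable_head=True`;
# on its own it would be passed to LogisticRegression, which does not accept it.
teacher_model = SetFitModel.from_pretrained(teacher_model_id)
# `SetFitTrainer` is the pre-1.0 trainer name; setfit >= 1.0 keeps it as a deprecated alias.
teacher_trainer = SetFitTrainer(
    model=teacher_model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss_class=CosineSimilarityLoss,
    batch_size=16,
    num_iterations=20,
    num_epochs=1,
    metric='accuracy',
)
teacher_trainer.train()
teacher_metrics = teacher_trainer.evaluate()
teacher_metrics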

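### 5.3 Distilling the Student with SetFit

With the teacher trained, we distil it into the smaller `qwen3 0.6B` model.
Section 4 used a logit-level objective which, written with the supervised loss defined above, reads

$$\mathcal{L}_{\text{KD}} = (1 - \alpha)\, \mathcal{L}_{\text{sup}} + \alpha\, T^2 \, \mathrm{KL}\big(\sigma(z_t / T) \,\|\, \sigma(z_s / T)\big),$$

where $z_t$ and $z_s$ are the teacher and student logits, $T$ controls softness, and $\alpha$ balances the two terms.
SetFit's distillation trainer works at the embedding level instead. Rather than matching softened class probabilities, it trains the student body with `CosineSimilarityLoss` to reproduce the cosine similarity that the teacher's embeddings assign to each sentence pair, so no gold labels are required, and it then fits the student's classification head on the teacher's predicted labels.
This mirrors the SetFit knowledge distillation tutorial while swapping in Qwen checkpoints.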
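A minimal sketch of the embedding-level distillation target, assuming SetFit's distillation trainer labels each sentence pair with the cosine similarity of the teacher's embeddings and regresses the student's cosine similarity onto it with a mean-squared-error loss. Random tensors stand in for teacher and student embeddings; their widths may differ because only pairwise cosines are compared:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n_texts = 4
teacher_emb = torch.randn(n_texts, 16)                       # stand-in teacher embeddings
student_emb = torch.randn(n_texts, 8, requires_grad=True)    # stand-in student embeddings (narrower)


def cosine_matrix(emb: torch.Tensor) -> torch.Tensor:
    """Pairwise cosine similarities between the rows of `emb`."""
    normed = F.normalize(emb, dim=-1)
    return normed @ normed.T


targets = cosine_matrix(teacher_emb)                   # soft pair labels from the teacher
i, j = torch.triu_indices(n_texts, n_texts, offset=1)  # all unordered pairs with i < j
loss = F.mse_loss(cosine_matrix(student_emb)[i, j], targets[i, j])
loss.backward()                                        # gradients flow only into the student
print(loss.item())
```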

In [None]:
# Requires setfit >= 1.0, which provides `DistillationTrainer` and `TrainingArguments`.
# `SetFitTrainer` has no distillation mode (older releases used `DistillationSetFitTrainer`).
from setfit import DistillationTrainer, TrainingArguments

student_model = SetFitModel.from_pretrained(student_model_id)
distillation_args = TrainingArguments(
    batch_size=16,
    num_iterations=20,
    num_epochs=1,
)
distillation_trainer = DistillationTrainer(
    teacher_model=teacher_model,
    student_model=student_model,
    args=distillation_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    metric='accuracy',
)
distillation_trainer.train()
student_metrics = distillation_trainer.evaluate()
student_metrics

Comparing `student_metrics` with `teacher_metrics` shows how much of the teacher's accuracy the student recovers despite its much smaller base model.
The distilled SetFit model (embedding body plus classification head) can then serve as a lightweight classifier or as a building block for downstream tasks.

### 5.4 Preparing Language Models for Speculative Decoding

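For generation we rely on the causal language modelling heads of the same Qwen checkpoints, loaded as `AutoModelForCausalLM` objects.
These hold the original pretrained weights: the SetFit training above only updated the separate embedding models wrapped by `teacher_model` and `student_model`, so the benchmark below compares the pretrained teacher and student rather than the distilled classifiers.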

In [None]:
teacher_lm = AutoModelForCausalLM.from_pretrained(
    teacher_model_id,
    trust_remote_code=True,
).to(DEVICE)
student_lm = AutoModelForCausalLM.from_pretrained(
    student_model_id,
    trust_remote_code=True,
).to(DEVICE)
teacher_lm.eval()
student_lm.eval()

## 6. Speculative Decoding

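Speculative decoding accelerates auto-regressive generation by letting a fast draft model propose multiple tokens, which are then verified (and possibly corrected) by the slower, high-quality target model.
Suppose the draft proposes tokens $d_{1:k}$.
The target scores the extended prefix in a single forward pass, which yields its own prediction $t_j$ for each of the $k$ proposal positions plus one extra prediction $t_{k+1}$.
In the greedy variant implemented below, proposals are checked left to right and $d_j$ is accepted if it equals $t_j$, the target's argmax at that position.
At the first rejection, the target's token $t_j$ is appended instead and the draft restarts from the new prefix; if all $k$ proposals are accepted, $t_{k+1}$ is appended as a bonus token.
With greedy acceptance, the output matches greedy decoding with the target alone (up to floating-point effects).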
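The acceptance rule can be traced without any neural network. In this sketch, hypothetical toy "models" are plain functions from a token prefix to the next greedy token: the target counts up modulo 10, and the draft agrees except right after token 4. The list comprehension stands in for the single target forward pass, which in a real model yields all $k + 1$ predictions at once:

```python
from typing import Callable, List, Tuple

NextToken = Callable[[List[int]], int]  # a greedy "model": token prefix -> next token


def greedy_speculative(draft: NextToken, target: NextToken, prefix: List[int],
                       max_new_tokens: int, k: int) -> Tuple[List[int], int]:
    out = list(prefix)
    target_passes = 0
    while len(out) - len(prefix) < max_new_tokens:
        # The draft proposes k tokens autoregressively.
        proposal: List[int] = []
        for _ in range(k):
            proposal.append(draft(out + proposal))
        # Stand-in for ONE target forward pass: its greedy choice at each of the k + 1 positions.
        target_passes += 1
        verified = [target(out + proposal[:j]) for j in range(k + 1)]
        # Accept the longest prefix on which draft and target agree.
        n_ok = 0
        while n_ok < k and proposal[n_ok] == verified[n_ok]:
            n_ok += 1
        # Append the accepted tokens plus the target's correction (or bonus) token.
        out += verified[: n_ok + 1]
    return out[: len(prefix) + max_new_tokens], target_passes


def toy_target(seq: List[int]) -> int:
    return (seq[-1] + 1) % 10


def toy_draft(seq: List[int]) -> int:
    return 0 if seq[-1] == 4 else (seq[-1] + 1) % 10


tokens, passes = greedy_speculative(toy_draft, toy_target, [0], max_new_tokens=8, k=3)
print(tokens, passes)
```

Here three target passes produce eight new tokens, identical to what the target alone would produce greedily with eight passes.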

In [None]:
def generate_greedy(model: AutoModelForCausalLM, input_ids: torch.Tensor, max_new_tokens: int) -> torch.Tensor:
    generated = input_ids.clone()
    for _ in range(max_new_tokens):
        with torch.no_grad():
            outputs = model(input_ids=generated)
            logits = outputs.logits
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        generated = torch.cat([generated, next_token], dim=-1)
    return generated

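def speculative_decode(
    draft: AutoModelForCausalLM,
    target: AutoModelForCausalLM,
    input_ids: torch.Tensor,
    max_new_tokens: int,
    proposal_window: int = 3,
) -> Tuple[torch.Tensor, int]:
    """Greedy speculative decoding for a single sequence (batch size 1), without a KV cache."""
    assert input_ids.shape[0] == 1, 'this implementation only supports batch size 1'
    prompt_len = input_ids.shape[1]
    generated = input_ids.clone()
    target_calls = 0
    while generated.shape[1] - prompt_len < max_new_tokens:
        steps = min(proposal_window, max_new_tokens - (generated.shape[1] - prompt_len))
        # 1) The draft proposes `steps` tokens greedily.
        temp = generated
        with torch.no_grad():
            for _ in range(steps):
                draft_logits = draft(input_ids=temp).logits
                next_token = draft_logits[:, -1, :].argmax(dim=-1, keepdim=True)
                temp = torch.cat([temp, next_token], dim=-1)
            # 2) A single target pass scores the prefix plus all proposals.
            target_logits = target(input_ids=temp).logits
        target_calls += 1
        proposals = temp[:, -steps:]  # (1, steps)
        # Logits at position p predict token p + 1, so the target's choices for the
        # proposal slots (plus one bonus slot) come from the last steps + 1 positions.
        target_choices = target_logits[:, -steps - 1:, :].argmax(dim=-1)  # (1, steps + 1)
        matches = (proposals == target_choices[:, :steps])[0].long()
        n_accepted = int(matches.cumprod(dim=0).sum())  # length of the agreeing prefix
        # Accepted draft tokens equal the target's choices, so append the target's first
        # n_accepted + 1 choices: the accepted prefix plus a correction or bonus token.
        generated = torch.cat([generated, target_choices[:, : n_accepted + 1]], dim=-1)
    return generated[:, : prompt_len + max_new_tokens], target_calls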

In [None]:
prompt_text = 'machine learning '
prompt_inputs = tokenizer(prompt_text, return_tensors='pt', add_special_tokens=False)
prompt = prompt_inputs['input_ids'].to(DEVICE)

start = time.perf_counter()
baseline = generate_greedy(teacher_lm, prompt, max_new_tokens=12)
baseline_time = time.perf_counter() - start

start = time.perf_counter()
speculative, target_calls = speculative_decode(
    draft=student_lm,
    target=teacher_lm,
    input_ids=prompt,
    max_new_tokens=12,
    proposal_window=4,
)
speculative_time = time.perf_counter() - start

print(f'Teacher greedy generation took {baseline_time:.3f}s.')
print(f'Speculative decoding took {speculative_time:.3f}s with {target_calls} target forward passes.')
print('Teacher output:', tokenizer.decode(baseline[0], skip_special_tokens=True))
print('Speculative output:', tokenizer.decode(speculative[0], skip_special_tokens=True))

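The number of target forward passes is the more reliable comparison here: greedy decoding needs one target pass per new token (12 in this example), while speculative decoding needs one pass per verified block, each producing between one and `proposal_window + 1` tokens. Because neither loop uses a KV cache and the draft's own passes also take time, the wall-clock numbers from this tiny example are only indicative; in practice the speed-up depends on how often the target accepts the draft's proposals.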
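A common back-of-the-envelope model (Leviathan et al., 2023) assumes each draft token is accepted independently with probability $\alpha$. With a window of $k$ proposals, the expected number of tokens produced per target pass is then $\frac{1 - \alpha^{k+1}}{1 - \alpha}$. The sketch below evaluates this formula for a few illustrative acceptance rates; real acceptance events are not independent, so measured rates should be used in practice:

```python
def expected_tokens_per_target_pass(alpha: float, k: int) -> float:
    """Expected tokens per target pass when each of k draft tokens is accepted
    independently with probability alpha (a simplifying assumption)."""
    if alpha == 1.0:
        return float(k + 1)  # every proposal accepted, plus the bonus token
    return (1 - alpha ** (k + 1)) / (1 - alpha)


for alpha in (0.5, 0.8, 0.9):
    print(alpha, round(expected_tokens_per_target_pass(alpha, k=4), 3))
```

With $k = 4$ and $\alpha = 0.8$, this gives about 3.36 tokens per target pass, versus exactly one for plain greedy decoding.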

---
### Key Takeaways
* **Pruning** removes redundant parameters and can maintain accuracy when applied judiciously.
* **Knowledge distillation** blends hard labels with soft teacher targets to train compact yet capable models.
* The SetFit-style workflow shows how to distil a `qwen3 0.6B` student from a `qwen3 4B` teacher.
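* **Speculative decoding** lets the small `qwen3 0.6B` model draft tokens that the `qwen3 4B` target verifies in blocks, reducing the number of expensive target forward passes while, with greedy verification, reproducing the target's own output. Hugging Face Transformers exposes the same idea as assisted generation via `generate(..., assistant_model=...)`.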