# Tutorial 10: Pruning, Distillation, and Speculative Decoding

In this tutorial we explore three widely used techniques for compressing and accelerating neural networks:

1. **Magnitude pruning**, which removes the parameters that contribute the least to a model's predictions.
2. **Knowledge distillation**, which transfers knowledge from a larger teacher network into a smaller student.
3. **Speculative decoding**, which speeds up auto-regressive text generation by combining a lightweight draft model with a larger target model.

We start with a toy classification problem so that we can reason about the effects of pruning and distillation visually. We then move to a causal language modeling setup inspired by the `qwen3` family of models, where we use a `DistillationTrainer`-style API (imported from Hugging Face `transformers` when available, otherwise a minimal drop-in replacement), and finish with speculative decoding.

## 1. Setup

We rely on PyTorch for modeling, `matplotlib` for visualization, and a few standard-library utilities for timing and type hints. The code below also fixes the random seeds for reproducibility and selects the compute device.

In [None]:
import math
import random
import time
from dataclasses import dataclass
from typing import Callable, Dict, Iterable, List, Optional, Tuple

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {DEVICE}')

## 2. A Toy Classification Task

To make the ideas concrete we create a two-dimensional synthetic dataset with two classes that *are* linearly separable. That allows us to train a single-layer linear classifier and visualize both the data and the resulting decision boundary easily.

We draw points from two Gaussians and assign labels $y \in \{0, 1\}$. The model learns a simple linear decision boundary $f(\mathbf{x}) = \mathbf{w}^\top \mathbf{x} + b$. The logits are converted into probabilities via the sigmoid function $\sigma$ and we optimize the binary cross-entropy loss:

$$\ell(\mathbf{w}, b) = -\frac{1}{N} \sum_{i=1}^N \Big[y_i \log \sigma(f(\mathbf{x}_i)) + (1 - y_i) \log (1 - \sigma(f(\mathbf{x}_i)))\Big].$$
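
As a quick numeric check of the loss above, the following minimal sketch evaluates the binary cross-entropy for a few hand-picked logits in plain Python. The helper names `sigmoid` and `bce` and the sample logits are illustrative assumptions, not part of the notebook's code.

```python
import math


def sigmoid(z: float) -> float:
    """Logistic function mapping a logit to a probability."""
    return 1.0 / (1.0 + math.exp(-z))


def bce(logits, labels) -> float:
    """Mean binary cross-entropy, written exactly as in the formula."""
    total = 0.0
    for z, y in zip(logits, labels):
        p = sigmoid(z)
        total += y * math.log(p) + (1 - y) * math.log(1 - p)
    return -total / len(logits)


# Hypothetical logits: an uninformed model (f(x) = 0) versus a confident, correct one.
uninformed = bce([0.0, 0.0], [1, 0])
confident = bce([4.0, -4.0], [1, 0])
print(f"uninformed: {uninformed:.4f}, confident: {confident:.4f}")
```

A model that outputs $f(\mathbf{x}) = 0$ everywhere pays exactly $\log 2$ per example, which is a useful baseline when reading the training curves.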

In [None]:
def make_linear_dataset(n_per_class: int = 200, spread: float = 0.6):
    mean_pos = torch.tensor([1.5, 1.5])
    mean_neg = torch.tensor([-1.5, -1.5])
    cov = torch.eye(2) * spread
    pos = torch.distributions.MultivariateNormal(mean_pos, cov).sample((n_per_class,))
    neg = torch.distributions.MultivariateNormal(mean_neg, cov).sample((n_per_class,))
    X = torch.cat([pos, neg], dim=0)
    y = torch.cat([torch.ones(n_per_class), torch.zeros(n_per_class)], dim=0)
    perm = torch.randperm(len(X))
    return X[perm], y.long()[perm]


X, y = make_linear_dataset()
print(X.shape, y.shape)

In [None]:
fig, ax = plt.subplots(figsize=(5, 5))
for label, marker, color in [(0, 'o', '#2ca02c'), (1, '^', '#d62728')]:
    mask = y == label
    ax.scatter(X[mask, 0], X[mask, 1], marker=marker, color=color, label=f'class {label}')
ax.set_xlabel('$x_1$')
ax.set_ylabel('$x_2$')
ax.set_title('Synthetic 2D dataset')
ax.legend()
plt.show()

### 2.1 Training a Single-Layer Classifier

Our model is a single dense layer followed by a sigmoid. For convenience we package a full-batch training loop that records the loss and training accuracy at every epoch, together with a helper that plots both curves.

In [None]:
class LinearClassifier(nn.Module):
    def __init__(self, in_dim: int):
        super().__init__()
        self.linear = nn.Linear(in_dim, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x).squeeze(-1)


def train_classifier(model: nn.Module, data: Tuple[torch.Tensor, torch.Tensor], epochs: int = 200, lr: float = 0.1):
    X, y = data
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    history = []
    for epoch in range(epochs):
        model.train()
        logits = model(X)
        loss = F.binary_cross_entropy_with_logits(logits, y.float())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        with torch.no_grad():
            preds = (logits.sigmoid() > 0.5).long()
            acc = (preds == y).float().mean().item()
        history.append((loss.item(), acc))
    return history


def plot_history(history, title: str):
    losses, accs = zip(*history)
    fig, ax = plt.subplots(1, 2, figsize=(10, 3))
    ax[0].plot(losses)
    ax[0].set_title(f'{title} loss')
    ax[0].set_xlabel('epoch')
    ax[0].set_ylabel('loss')
    ax[1].plot(accs)
    ax[1].set_title(f'{title} accuracy')
    ax[1].set_xlabel('epoch')
    ax[1].set_ylabel('accuracy')
    plt.show()

In [None]:
model = LinearClassifier(in_dim=2)
history = train_classifier(model, (X, y))
plot_history(history, 'Linear classifier')

### 2.2 Visualizing the Decision Boundary

The decision boundary is the set of points where the predicted probability equals 0.5, i.e. $f(\mathbf{x}) = 0$. We evaluate the model on a dense grid and color the prediction regions to see how the linear separator aligns with the data.

In [None]:
def plot_decision_boundary(model: nn.Module, X: torch.Tensor, y: torch.Tensor, title: str):
    model.eval()
    # Convert the bounds to Python floats so torch.linspace works across PyTorch versions.
    x_min, x_max = X[:, 0].min().item() - 1.0, X[:, 0].max().item() + 1.0
    y_min, y_max = X[:, 1].min().item() - 1.0, X[:, 1].max().item() + 1.0
    grid_x, grid_y = torch.meshgrid(
        torch.linspace(x_min, x_max, 200),
        torch.linspace(y_min, y_max, 200),
        indexing='ij',
    )
    grid = torch.stack([grid_x.flatten(), grid_y.flatten()], dim=-1)
    with torch.no_grad():
        logits = model(grid)
        probs = logits.sigmoid().reshape(200, 200)
    fig, ax = plt.subplots(figsize=(5, 5))
    contour = ax.contourf(grid_x, grid_y, probs, levels=50, cmap='RdBu', alpha=0.8)
    fig.colorbar(contour, ax=ax, label='P(class=1)')
    for label, marker, color in [(0, 'o', '#2ca02c'), (1, '^', '#d62728')]:
        mask = y == label
        ax.scatter(X[mask, 0], X[mask, 1], marker=marker, color=color, edgecolors='k', label=f'class {label}')
    ax.set_title(title)
    ax.set_xlabel('$x_1$')
    ax.set_ylabel('$x_2$')
    ax.legend()
    plt.show()


plot_decision_boundary(model, X, y, 'Linear model decision boundary')
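
For a linear model the boundary $f(\mathbf{x}) = 0$ can also be written in closed form as a line in the $(x_1, x_2)$ plane. The sketch below solves $w_1 x_1 + w_2 x_2 + b = 0$ for $x_2$. The function name `boundary_x2` and the weight values are hypothetical and chosen only for illustration; they are not the trained parameters.

```python
def boundary_x2(x1: float, w1: float, w2: float, b: float) -> float:
    """Solve w1*x1 + w2*x2 + b = 0 for x2 (requires w2 != 0)."""
    return -(w1 * x1 + b) / w2


# Hypothetical parameters standing in for model.linear.weight and model.linear.bias.
w1, w2, b = 1.2, 0.8, 0.1

for x1 in (-2.0, 0.0, 2.0):
    x2 = boundary_x2(x1, w1, w2, b)
    logit = w1 * x1 + w2 * x2 + b  # zero (up to rounding) on the boundary
    print(f"x1={x1:+.1f} -> x2={x2:+.3f}, logit={logit:.1e}")
```

The same expression shows what happens when a weight is pruned to zero: with $w_1 = 0$ the boundary becomes the horizontal line $x_2 = -b / w_2$.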

## 3. Magnitude Pruning

Magnitude pruning removes parameters with the smallest absolute values under the assumption that their contribution to the output is limited. For a weight matrix $\mathbf{W}$ we compute a mask $\mathbf{M}$ such that:

$$M_{ij} = \begin{cases}0 & \text{if } |W_{ij}| < \tau, \\ 1 & \text{otherwise}\end{cases}$$

where the threshold $\tau$ is chosen so that a desired sparsity level is achieved. The pruned weights are set to zero. To keep them at zero during any further training, the mask has to be re-applied after each optimizer step.

In [None]:
def magnitude_prune(model: nn.Module, amount: float = 0.5) -> Dict[str, torch.Tensor]:
    assert 0.0 <= amount < 1.0
    with torch.no_grad():
        weights = torch.cat([
            param.abs().flatten() for name, param in model.named_parameters() if 'weight' in name
        ])
        threshold = torch.quantile(weights, amount)
        masks = {}
        for name, param in model.named_parameters():
            if 'weight' in name:
                mask = (param.abs() >= threshold).float()
                param.mul_(mask)
                masks[name] = mask
        return masks


pruned_masks = magnitude_prune(model, amount=0.6)
print({name: mask.mean().item() for name, mask in pruned_masks.items()})

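The linear classifier has only two weights, so a requested sparsity of 60% puts the quantile threshold between them and zeroes the smaller one: the effective sparsity is 50%, which is what the printed mask mean reports. The decision boundary becomes axis-aligned, yet accuracy should stay high because the two class means differ along both input dimensions. In larger models, substantial sparsity can often be introduced in the same way without hurting accuracy much.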
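
The interaction between the requested sparsity and a very small layer can be checked by hand. The sketch below re-implements a quantile with linear interpolation in plain Python, on the assumption that this matches the default behaviour of `torch.quantile`, and applies it to a hypothetical two-weight layer. The helper name `linear_quantile` and the weight values are illustrative only.

```python
def linear_quantile(values, q: float) -> float:
    """Quantile with linear interpolation between neighbouring order statistics."""
    s = sorted(values)
    pos = q * (len(s) - 1)
    lo = int(pos)
    hi = min(lo + 1, len(s) - 1)
    return s[lo] + (pos - lo) * (s[hi] - s[lo])


weights = [0.9, -0.4]  # hypothetical weights of a 2-input linear layer
magnitudes = [abs(w) for w in weights]

threshold = linear_quantile(magnitudes, 0.6)  # requested sparsity: 60%
mask = [1.0 if m >= threshold else 0.0 for m in magnitudes]
effective_sparsity = 1.0 - sum(mask) / len(mask)

print(f"threshold={threshold:.2f}, mask={mask}, effective sparsity={effective_sparsity:.0%}")
```

With only two weights the achievable sparsity levels are 0% and 50% (the largest weight always survives), so the requested 60% is rounded down in effect.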

In [None]:
plot_decision_boundary(model, X, y, 'Decision boundary after pruning')
with torch.no_grad():
    logits = model(X)
    preds = (logits.sigmoid() > 0.5).long()
    acc = (preds == y).float().mean().item()
print(f'Accuracy after pruning: {acc:.3f}')
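
If a pruned model is trained further, gradient updates would normally move the zeroed weights away from zero again. One common remedy, sketched here in plain Python, is to multiply the weights by the mask after every update. The function `sgd_step_with_mask` and the gradient values are hypothetical; in the notebook's setting the masks returned by `magnitude_prune` would play the role of `mask`.

```python
def sgd_step_with_mask(weights, grads, mask, lr: float = 0.1):
    """One SGD update followed by re-applying the pruning mask."""
    return [(w - lr * g) * m for w, g, m in zip(weights, grads, mask)]


weights = [0.9, 0.0]   # second weight was pruned
mask = [1.0, 0.0]
grads = [0.5, -0.3]    # hypothetical gradients; the pruned weight still receives one

for _ in range(3):
    weights = sgd_step_with_mask(weights, grads, mask)

print(weights)  # the pruned entry remains exactly zero
```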

## 4. Knowledge Distillation on the Classification Task

Knowledge distillation (KD) transfers information from a large teacher network to a smaller student. Let $z_t$ and $z_s$ denote the teacher and student logits respectively. KD minimizes a convex combination of the standard cross-entropy with the ground-truth labels and a temperature-scaled KL divergence between the teacher and student distributions:

$$\mathcal{L}_{\text{KD}} = (1 - \alpha) \mathcal{L}_{\text{CE}}(z_s, y) + \alpha T^2 \mathrm{KL}\big(\sigma(z_t / T) \| \sigma(z_s / T)\big).$$

Here $\sigma$ denotes the softmax over classes (not the sigmoid of Section 2). The temperature $T$ softens the probability distribution and $\alpha$ controls the trade-off between hard labels and soft targets.
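
To see what the temperature does numerically, the following plain-Python sketch computes a temperature-scaled softmax and the KD loss for a single example. The helper names (`softmax`, `kl`, `kd_loss`) and the logit values are illustrative assumptions; the notebook's own implementation uses the PyTorch functions shown in the next cell.

```python
import math


def softmax(logits, temperature: float = 1.0):
    """Numerically stable softmax of logits / temperature."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kl(p, q) -> float:
    """KL(p || q) for two discrete distributions with q > 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def kd_loss(student_logits, teacher_logits, label, alpha=0.7, temperature=2.0) -> float:
    """(1 - alpha) * CE(student, label) + alpha * T^2 * KL(teacher_T || student_T)."""
    hard = -math.log(softmax(student_logits)[label])
    soft = kl(softmax(teacher_logits, temperature), softmax(student_logits, temperature))
    return (1 - alpha) * hard + alpha * temperature ** 2 * soft


teacher_logits = [4.0, 1.0]            # hypothetical teacher output
sharp = softmax(teacher_logits, 1.0)   # T = 1: nearly one-hot
soft = softmax(teacher_logits, 3.0)    # T = 3: more mass on the second class
print([round(p, 3) for p in sharp], [round(p, 3) for p in soft])
print(round(kd_loss([2.0, 1.5], teacher_logits, label=0), 4))
```

A higher temperature moves probability mass toward the non-argmax classes, which is where the teacher's "soft" information lives. The $T^2$ factor compensates for the gradient of the soft term shrinking as $T$ grows.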

In [None]:
class MLPClassifier(nn.Module):
    def __init__(self, in_dim: int, hidden_dim: int, out_dim: int = 2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


def train_teacher(model: nn.Module, data: Tuple[torch.Tensor, torch.Tensor], epochs: int = 200, lr: float = 0.05):
    X, y = data
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    history = []
    for epoch in range(epochs):
        model.train()
        logits = model(X)
        loss = F.cross_entropy(logits, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        with torch.no_grad():
            preds = logits.argmax(dim=-1)
            acc = (preds == y).float().mean().item()
        history.append((loss.item(), acc))
    return history


def distillation_step(student: nn.Module, teacher: nn.Module, inputs: torch.Tensor, targets: torch.Tensor,
                      optimizer: torch.optim.Optimizer, alpha: float = 0.7, temperature: float = 2.0):
    student.train()
    teacher.eval()
    student_logits = student(inputs)
    with torch.no_grad():
        teacher_logits = teacher(inputs)
    hard_loss = F.cross_entropy(student_logits, targets)
    soft_loss = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=-1),
        F.softmax(teacher_logits / temperature, dim=-1),
        reduction='batchmean',
    ) * (temperature ** 2)
    loss = alpha * soft_loss + (1 - alpha) * hard_loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    with torch.no_grad():
        preds = student_logits.argmax(dim=-1)
        acc = (preds == targets).float().mean().item()
    return loss.item(), acc


teacher = MLPClassifier(in_dim=2, hidden_dim=16)
teacher_history = train_teacher(teacher, (X, y))
plot_history(teacher_history, 'Teacher (2-layer MLP)')

In [None]:
student = MLPClassifier(in_dim=2, hidden_dim=4)
optimizer = torch.optim.Adam(student.parameters(), lr=0.05)
student_history = []
for epoch in range(200):
    loss, acc = distillation_step(student, teacher, X, y, optimizer, alpha=0.8, temperature=3.0)
    student_history.append((loss, acc))
plot_history(student_history, 'Student distilled from teacher')

In [None]:
plot_decision_boundary(teacher, X, y, 'Teacher decision boundary')
plot_decision_boundary(student, X, y, 'Student after knowledge distillation')

The student network uses a quarter of the teacher's hidden units, yet on this simple dataset it can match the teacher's accuracy with the help of the teacher's soft targets. This small-scale example mirrors real-world scenarios where KD enables substantial model compression.
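
The size difference between the two networks can be made concrete by counting parameters. The helper `mlp_param_count` below is a hypothetical utility that simply follows the layer shapes of `MLPClassifier` (two linear layers with biases).

```python
def mlp_param_count(in_dim: int, hidden_dim: int, out_dim: int = 2) -> int:
    """Parameters of Linear(in, hidden) + Linear(hidden, out), biases included."""
    first = in_dim * hidden_dim + hidden_dim
    second = hidden_dim * out_dim + out_dim
    return first + second


teacher_params = mlp_param_count(2, 16)
student_params = mlp_param_count(2, 4)
print(teacher_params, student_params)  # 82 22
```

A quarter of the hidden units therefore corresponds to a bit more than a quarter of the parameters, because the output bias does not shrink with the hidden size.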

## 5. Distilling a Toy `qwen3`-Style Causal Language Model

We now turn to a sequence modeling task reminiscent of the Qwen family of large language models. The objective is to demonstrate how a `DistillationTrainer`-style API can be used to train a smaller student model from a larger teacher.

Because internet access and the full `transformers` library may not be available in this execution environment, we implement a minimal drop-in replacement that emulates the parts of the API we need. The next cell first tries to import `DistillationTrainer` from `transformers`; if that import fails for any reason (the library is missing, or the installed version does not export a class with that name), the fallback implementation is used instead.

Our dataset consists of short arithmetic expressions such as `"3+4=7"`. The model is a lightweight GRU-based causal LM that predicts the next character at each step.

In [None]:
try:
    from transformers import DistillationTrainer as HFTrainer
    from transformers import TrainingArguments, PreTrainedModel, PretrainedConfig
    HF_AVAILABLE = True
except Exception:
    HF_AVAILABLE = False
    HFTrainer = None
    TrainingArguments = None
    PreTrainedModel = nn.Module

    class PretrainedConfig:  # placeholder for compatibility
        pass

    print('transformers library not found; using a minimal local implementation.')

In [None]:
VOCAB = list('0123456789+=')
TOKEN_TO_ID = {ch: idx for idx, ch in enumerate(VOCAB)}
ID_TO_TOKEN = {idx: ch for ch, idx in TOKEN_TO_ID.items()}
PAD_TOKEN_ID = len(VOCAB)


class AdditionDataset(Dataset):
    def __init__(self, size: int = 256, max_digit: int = 9):
        self.samples = []
        for _ in range(size):
            a = random.randint(0, max_digit)
            b = random.randint(0, max_digit)
            text = f'{a}+{b}={a+b}'
            tokens = [TOKEN_TO_ID[ch] for ch in text]
            self.samples.append(torch.tensor(tokens, dtype=torch.long))

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int) -> torch.Tensor:
        return self.samples[idx]


def addition_collate(batch: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
    max_len = max(x.size(0) for x in batch)
    input_ids = torch.full((len(batch), max_len - 1), PAD_TOKEN_ID, dtype=torch.long)
    labels = torch.full((len(batch), max_len - 1), -100, dtype=torch.long)
    for i, tokens in enumerate(batch):
        input_ids[i, : tokens.size(0) - 1] = tokens[:-1]
        labels[i, : tokens.size(0) - 1] = tokens[1:]
    return {'input_ids': input_ids.to(DEVICE), 'labels': labels.to(DEVICE)}
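
The collate function implements the usual next-token shift: the inputs are all tokens but the last, the labels are all tokens but the first, and padded positions are marked with `-100` so that they are excluded from the loss. The sketch below replays that logic for a single string using plain lists. The helper `shift_and_pad` is hypothetical and exists only for this illustration.

```python
VOCAB = list("0123456789+=")
TOKEN_TO_ID = {ch: i for i, ch in enumerate(VOCAB)}
PAD, IGNORE = len(VOCAB), -100  # same conventions as the collate function


def shift_and_pad(text: str, max_len: int):
    """Return (inputs, labels) for one sample, padded to max_len - 1 positions."""
    tokens = [TOKEN_TO_ID[ch] for ch in text]
    inputs, labels = tokens[:-1], tokens[1:]
    pad = (max_len - 1) - len(inputs)
    return inputs + [PAD] * pad, labels + [IGNORE] * pad


# "3+4=7" has 5 characters; max_len=6 is the length of a longer sample such as "9+9=18".
inputs, labels = shift_and_pad("3+4=7", max_len=6)
print(inputs)  # [3, 10, 4, 11, 12]
print(labels)  # [10, 4, 11, 7, -100]
```

At every position the label is the token that follows the corresponding input token, so a single forward pass yields one training signal per character.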

In [None]:
class ToyQwenConfig:
    def __init__(self, vocab_size: int, hidden_size: int = 128, num_layers: int = 2):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers


class ToyQwenForCausalLM(nn.Module):
    def __init__(self, config: ToyQwenConfig):
        super().__init__()
        self.config = config
        self.embedding = nn.Embedding(config.vocab_size + 1, config.hidden_size)
        self.gru = nn.GRU(
            input_size=config.hidden_size,
            hidden_size=config.hidden_size,
            num_layers=config.num_layers,
            batch_first=True,
        )
        self.head = nn.Linear(config.hidden_size, config.vocab_size)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        emb = self.embedding(input_ids)
        outputs, _ = self.gru(emb)
        logits = self.head(outputs)
        return logits

    def generate(self, input_ids: torch.Tensor, max_new_tokens: int = 5) -> torch.Tensor:
        generated = input_ids.clone()
        for _ in range(max_new_tokens):
            logits = self.forward(generated)
            next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            generated = torch.cat([generated, next_token], dim=1)
        return generated

### 5.1 Minimal Distillation Trainer (Fallback)

The following class stands in for `DistillationTrainer` whenever the import above fails. It computes the standard KD loss with configurable temperature and $\alpha$ and supports training and evaluation loops over PyTorch datasets.

In [None]:
if not HF_AVAILABLE:
    @dataclass
    class TrainingArguments:
        output_dir: str
        num_train_epochs: int = 5
        learning_rate: float = 5e-3
        per_device_train_batch_size: int = 16
        per_device_eval_batch_size: int = 32
        logging_steps: int = 10
        temperature: float = 2.0
        alpha: float = 0.5

    class DistillationTrainer:
        def __init__(self, teacher_model: nn.Module, student_model: nn.Module, args: TrainingArguments,
                     train_dataset: Dataset, eval_dataset: Optional[Dataset] = None,
                     data_collator: Optional[Callable] = None):
            self.teacher = teacher_model.to(DEVICE)
            self.student = student_model.to(DEVICE)
            self.args = args
            self.train_dataset = train_dataset
            self.eval_dataset = eval_dataset
            self.data_collator = data_collator or (lambda batch: batch)
            self.optimizer = torch.optim.Adam(self.student.parameters(), lr=args.learning_rate)

        def _step(self, batch: Dict[str, torch.Tensor], train: bool = True) -> Tuple[float, float]:
            self.student.train(train)
            self.teacher.eval()
            input_ids = batch['input_ids']
            labels = batch['labels']
            logits_s = self.student(input_ids)
            with torch.no_grad():
                logits_t = self.teacher(input_ids)
            # Keep only the positions that carry a real label (padding is marked with -100).
            shift = labels != -100
            student_logits = logits_s[shift]
            teacher_logits = logits_t[shift]
            label_targets = labels[shift]
            hard_loss = F.cross_entropy(student_logits, label_targets)
            soft_loss = F.kl_div(
                F.log_softmax(student_logits / self.args.temperature, dim=-1),
                F.softmax(teacher_logits / self.args.temperature, dim=-1),
                reduction='batchmean',
            ) * (self.args.temperature ** 2)
            loss = self.args.alpha * soft_loss + (1 - self.args.alpha) * hard_loss
            if train:
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
            with torch.no_grad():
                preds = student_logits.argmax(dim=-1)
                acc = (preds == label_targets).float().mean().item()
            return loss.item(), acc

        def train(self):
            loader = DataLoader(
                self.train_dataset,
                batch_size=self.args.per_device_train_batch_size,
                shuffle=True,
                collate_fn=self.data_collator,
            )
            history = []
            for epoch in range(self.args.num_train_epochs):
                for step, batch in enumerate(loader, start=1):
                    loss, acc = self._step(batch, train=True)
                    if step % self.args.logging_steps == 0:
                        print(f'epoch {epoch+1} step {step}: loss={loss:.4f} acc={acc:.3f}')
                # Record the metrics of the last batch of each epoch.
                history.append((loss, acc))
            return history

        def evaluate(self):
            if self.eval_dataset is None:
                return None
            loader = DataLoader(
                self.eval_dataset,
                batch_size=self.args.per_device_eval_batch_size,
                shuffle=False,
                collate_fn=self.data_collator,
            )
            losses, accs = [], []
            for batch in loader:
                loss, acc = self._step(batch, train=False)
                losses.append(loss)
                accs.append(acc)
            return {
                'loss': sum(losses) / len(losses),
                'accuracy': sum(accs) / len(accs),
            }
else:
    DistillationTrainer = HFTrainer

### 5.2 Training the Teacher and Student

We instantiate a "qwen3 4B" teacher (larger hidden size) and a "qwen3 0.6B" student (smaller hidden size). The names only indicate relative capacity: our toy models run quickly on CPU while preserving the intuition behind large-to-small distillation.

In [None]:
train_dataset = AdditionDataset(size=512)
eval_dataset = AdditionDataset(size=128)

teacher_cfg = ToyQwenConfig(vocab_size=len(VOCAB), hidden_size=192, num_layers=3)
student_cfg = ToyQwenConfig(vocab_size=len(VOCAB), hidden_size=96, num_layers=2)

teacher_lm = ToyQwenForCausalLM(teacher_cfg).to(DEVICE)
student_lm = ToyQwenForCausalLM(student_cfg).to(DEVICE)

In [None]:
def train_language_model(model: nn.Module, dataset: Dataset, epochs: int = 5, lr: float = 3e-3):
    loader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=addition_collate)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    history = []
    for epoch in range(epochs):
        losses = []
        for batch in loader:
            logits = model(batch['input_ids'])
            shift = batch['labels'] != -100
            loss = F.cross_entropy(logits[shift], batch['labels'][shift])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
        history.append(sum(losses) / len(losses))
        print(f'Teacher pre-training epoch {epoch+1}: loss={history[-1]:.4f}')
    return history


teacher_history = train_language_model(teacher_lm, train_dataset, epochs=6)

In [None]:
args = TrainingArguments(
    output_dir='./toy-qwen-distillation',
    num_train_epochs=6,
    learning_rate=4e-3,
    per_device_train_batch_size=32,
    logging_steps=20,
    temperature=2.5,
    alpha=0.7,
)

distiller = DistillationTrainer(
    teacher_model=teacher_lm,
    student_model=student_lm,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=addition_collate,
)

student_history = distiller.train()
eval_metrics = distiller.evaluate()
print('Evaluation metrics:', eval_metrics)

The distilled student achieves competitive accuracy on the evaluation split with roughly half the hidden size of the teacher, mirroring the idea of distilling `qwen3 4B` into a smaller `qwen3 0.6B` variant.
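
Halving the hidden size shrinks the model by much more than a factor of two, because the recurrent weight matrices grow quadratically with the hidden size. The sketch below counts the parameters of both toy models by hand. It assumes PyTorch's `nn.GRU` layout (three gates, with separate input and hidden weight matrices and two bias vectors per layer); the helper `toy_lm_param_count` is hypothetical.

```python
def toy_lm_param_count(vocab_size: int, hidden: int, layers: int) -> int:
    """Embedding (vocab + 1 padding row) + stacked GRU + output head."""
    embedding = (vocab_size + 1) * hidden
    # Per GRU layer: W_ih (3H x H), W_hh (3H x H), b_ih (3H), b_hh (3H).
    # The input size equals the hidden size for every layer in this model.
    gru_layer = 3 * hidden * hidden + 3 * hidden * hidden + 3 * hidden + 3 * hidden
    head = hidden * vocab_size + vocab_size
    return embedding + layers * gru_layer + head


teacher_params = toy_lm_param_count(vocab_size=12, hidden=192, layers=3)
student_params = toy_lm_param_count(vocab_size=12, hidden=96, layers=2)
print(teacher_params, student_params, round(teacher_params / student_params, 2))
```

Under these assumptions the teacher has 671,820 parameters and the student 114,156, a ratio of roughly 5.9.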

## 6. Speculative Decoding

Speculative decoding accelerates auto-regressive generation by letting a fast draft model propose several tokens, which are then verified (and possibly corrected) by the slower, higher-quality target model. Suppose the draft proposes tokens $d_{1:k}$. The target evaluates the extended prefix in a single forward pass and accepts or rejects the proposals in order. If the $j$-th token is rejected, the target's own token $t_j$ is used instead and the draft restarts from the new prefix.

Each target forward pass therefore yields between one and $k$ tokens, so the number of target passes drops by a factor of up to $k$. The actual saving depends on how often the draft agrees with the target.
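
How much is saved can be estimated with a simple model. Assume, purely for illustration, that each draft token is accepted independently with probability $a$, and that a round ends either at the first rejection (where the target's token is emitted) or after all $k$ proposals are accepted. Then a round emits at least $j$ tokens with probability $a^{j-1}$, and the expected number of tokens per target pass is $\sum_{j=0}^{k-1} a^j$. The function name below is hypothetical.

```python
def expected_tokens_per_target_call(accept_prob: float, window: int) -> float:
    """Expected tokens emitted per verification round under i.i.d. acceptance."""
    return sum(accept_prob ** j for j in range(window))


for a in (0.0, 0.5, 0.9, 1.0):
    print(f"a={a:.1f}: {expected_tokens_per_target_call(a, window=3):.2f} tokens per target call")
```

With a useless draft ($a = 0$) every target pass still yields one token, so nothing is lost except the draft's compute; with a perfect draft ($a = 1$) each pass yields the full window of $k$ tokens.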

In [None]:
def generate_greedy(model: nn.Module, input_ids: torch.Tensor, max_new_tokens: int) -> torch.Tensor:
    generated = input_ids.clone()
    with torch.no_grad():
        for _ in range(max_new_tokens):
            logits = model(generated)
            next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            generated = torch.cat([generated, next_token], dim=-1)
    return generated


def speculative_decode(draft: nn.Module, target: nn.Module, input_ids: torch.Tensor,
                       max_new_tokens: int, proposal_window: int = 3) -> Tuple[torch.Tensor, int]:
    # Greedy variant; the accept/reject loop below assumes a batch size of 1.
    generated = input_ids.clone()
    target_calls = 0
    accepted = 0  # number of new tokens emitted so far
    with torch.no_grad():
        while accepted < max_new_tokens:
            steps = min(proposal_window, max_new_tokens - accepted)
            temp = generated.clone()
            proposals = []
            for _ in range(steps):
                logits = draft(temp)
                next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
                temp = torch.cat([temp, next_token], dim=-1)
                proposals.append(next_token)
            target_logits = target(temp)
            target_calls += 1
            # The logits at position i predict token i + 1, so the target's predictions
            # for the proposed tokens come from the position just before each of them.
            target_preds = target_logits[:, -steps - 1:-1, :].argmax(dim=-1)
            for idx in range(steps):
                draft_token = proposals[idx]
                target_token = target_preds[:, idx:idx + 1]
                accepted += 1
                if torch.equal(draft_token, target_token):
                    generated = torch.cat([generated, draft_token], dim=-1)
                else:
                    # First disagreement: keep the target's token and restart the draft.
                    generated = torch.cat([generated, target_token], dim=-1)
                    break
    return generated, target_calls

In [None]:
prompt = torch.tensor([[TOKEN_TO_ID['3'], TOKEN_TO_ID['+'], TOKEN_TO_ID['4'], TOKEN_TO_ID['=']]]).to(DEVICE)

start = time.perf_counter()
baseline = generate_greedy(teacher_lm, prompt, max_new_tokens=2)
baseline_time = time.perf_counter() - start

start = time.perf_counter()
speculative, target_calls = speculative_decode(student_lm, teacher_lm, prompt, max_new_tokens=2, proposal_window=3)
speculative_time = time.perf_counter() - start

print('Teacher greedy generation:', ''.join(ID_TO_TOKEN[t.item()] for t in baseline[0]))
print('Speculative decoding:', ''.join(ID_TO_TOKEN[t.item()] for t in speculative[0]))
print(f'Baseline time: {baseline_time*1000:.2f} ms')
print(f'Speculative time: {speculative_time*1000:.2f} ms (target calls: {target_calls})')
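
A useful property of greedy speculative decoding is that its output is identical to the target model's own greedy output; only the number of target calls changes. The sketch below checks this with two hypothetical stand-in "models" written as plain Python functions: `target_next` counts upward modulo 10, and `draft_next` agrees with it except after the token 3. None of these names come from the notebook.

```python
def target_next(prefix):
    """Hypothetical target model: next token is (last + 1) mod 10."""
    return (prefix[-1] + 1) % 10


def draft_next(prefix):
    """Hypothetical draft model: agrees with the target except after a 3."""
    return 0 if prefix[-1] == 3 else (prefix[-1] + 1) % 10


def greedy(next_fn, prefix, n):
    out = list(prefix)
    for _ in range(n):
        out.append(next_fn(out))
    return out


def speculative(prefix, n, window=3):
    out, target_calls = list(prefix), 0
    while len(out) - len(prefix) < n:
        steps = min(window, n - (len(out) - len(prefix)))
        temp = list(out)
        for _ in range(steps):
            temp.append(draft_next(temp))
        # One simulated target call scores every proposed position at once.
        target_calls += 1
        base = len(out)
        for j in range(steps):
            wanted = target_next(temp[:base + j])
            out.append(wanted)          # equals the draft token when they agree
            if temp[base + j] != wanted:
                break                   # rejection: restart drafting from here
    return out, target_calls


spec_out, calls = speculative([1], n=6)
ref_out = greedy(target_next, [1], 6)
print(spec_out, ref_out, calls)
```

In this toy run the speculative loop needs two simulated target calls for six tokens, whereas plain greedy decoding would call the target six times.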

Even in this tiny example, verifying several draft proposals in a single pass can reduce the number of expensive teacher forward passes whenever the draft's tokens are accepted, which illustrates why speculative decoding can yield meaningful inference speed-ups in practice.

---

### Key Takeaways

* **Pruning** removes redundant parameters and can maintain accuracy when applied judiciously.
* **Knowledge distillation** blends hard labels with soft teacher targets to train compact yet capable models.
* **`DistillationTrainer`** (here in the form of a minimal reimplementation) packages the KD loss and training loop for causal LMs such as the `qwen3` family.
* **Speculative decoding** turns an accurate yet slow model into an efficient inference pipeline by leveraging a fast draft model.