# 3. Sharpness-aware minimization (SAM) with Sparse Networks - 25 points

## 3.1 Get a sparse network through pruning
**Pruning** is a technique used to reduce the size and complexity of a neural network model by removing (setting to zero) less important parameters. The goal is to create a more efficient model that retains its predictive accuracy while being smaller, which can improve both inference speed and memory usage.

To learn about pruning, let's first train a simple model on the MNIST dataset. We use only 10% of the dataset for both training and testing.

### 3.1.1 Train a dense network with SGD
Let us first train a dense model with SGD. We reuse the model for the discriminator of the GAN in Homework 2 and name it 'Classifier'.

In [1]:
from lib.part3.utils import *
max_epochs = 10
device = "cpu" # Change this if you can and want to use a GPU device
model = Classifier().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.05)

We define the optimizing process of SGD.

In [2]:
def optimize_sgd(model, optimizer, img, label):
    optimizer.zero_grad()
    output = model(img)
    loss = cross_entropy(output, label)
    loss.backward()
    optimizer.step()

The following cell runs the training loop. This might take a few minutes.

In [3]:
train_model(model, optimizer, optimize_sgd, max_epochs)

Epoch 0 with 0.932 accuracy on the validation set.
Epoch 1 with 0.948 accuracy on the validation set.
Epoch 2 with 0.956 accuracy on the validation set.
Epoch 3 with 0.958 accuracy on the validation set.
Epoch 4 with 0.964 accuracy on the validation set.
Epoch 5 with 0.963 accuracy on the validation set.
Epoch 6 with 0.965 accuracy on the validation set.
Epoch 7 with 0.966 accuracy on the validation set.
Epoch 8 with 0.968 accuracy on the validation set.
Epoch 9 with 0.967 accuracy on the validation set.


#### Evaluate model
Evaluate the model on the test set.

In [4]:
acc = evaluate(model)
print(f"Accuracy of {round(acc, 4)} on the test set.")

Accuracy of 0.9768 on the test set.


### 3.1.2 Sparse network with magnitude-based pruning

Magnitude-based pruning specifically focuses on **removing weights that have the smallest absolute values**, under the assumption that weights with smaller magnitudes contribute less to the model's output.

**(1)** (6 points) Implement magnitude-based pruning below, which zeroes out the given fraction of weights with the smallest absolute values in each weight tensor.

In [None]:
def magnitude_prune(model, prune_fraction):
    for name, param in model.named_parameters():
        if "weight" in name and param.requires_grad:
            # FILL: Get weight's absolute values
            abs_weight = param.data.abs()
            # FILL: Compute the threshold
            threshold = torch.quantile(abs_weight.view(-1), prune_fraction)
            # FILL: Prune weights below the threshold
            mask = abs_weight >= threshold
            param.data.mul_(mask)   # zero out pruned weights
    return model
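As a sanity check of the pruning rule, the sketch below applies it to a short list of numbers in plain Python (no PyTorch). The helper name `magnitude_prune_list` and the toy values are hypothetical. It selects the `k` smallest entries directly instead of calling `torch.quantile`, so its behavior at ties or interpolated thresholds may differ from the cell above.

```python
def magnitude_prune_list(weights, prune_fraction):
    """Zero out the `prune_fraction` of entries with the smallest absolute value."""
    k = int(len(weights) * prune_fraction)  # number of weights to remove
    if k == 0:
        return list(weights)
    # Indices ordered by absolute value, smallest first
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    pruned_idx = set(order[:k])
    return [0.0 if i in pruned_idx else w for i, w in enumerate(weights)]


toy = [0.8, -0.05, 0.3, -1.2, 0.01, 0.4]
pruned = magnitude_prune_list(toy, 0.5)
sparsity = pruned.count(0.0) / len(pruned)  # fraction of zeroed weights

print(pruned)    # [0.8, 0.0, 0.0, -1.2, 0.0, 0.4]
print(sparsity)  # 0.5
```

The three smallest magnitudes (0.01, 0.05, 0.3) are removed and the three largest survive, which is the behavior the mask `abs_weight >= threshold` is meant to produce.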

In [6]:
import copy
# Copy a model for pruning
sparse_model = copy.deepcopy(model)
# Get a sparse model by pruning 50% of the weights in each layer
sparse_model = magnitude_prune(sparse_model, prune_fraction=0.5)

Keep a copy of the sparse model for fine-tuning with SAM in Section 3.2.

In [7]:
sparse_model_sam = copy.deepcopy(sparse_model)

Evaluation after pruning

In [8]:
acc = evaluate(sparse_model)
print(f"Accuracy of {round(acc, 4)} on the test set.")

Accuracy of 0.9422 on the test set.


### 3.1.3 Finetune the sparse model

We finetune the sparse model after pruning with SGD to recover its performance.

In [9]:
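finetune_epoch = 3
# Caution: `optimizer` was built from `model.parameters()`, not from the deep copy `sparse_model`,
# so it does not step the sparse weights; that needs an optimizer over `sparse_model.parameters()`.
train_model(sparse_model, optimizer, optimize_sgd, finetune_epoch)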

Epoch 0 with 0.93 accuracy on the validation set.
Epoch 1 with 0.93 accuracy on the validation set.
Epoch 2 with 0.93 accuracy on the validation set.


Evaluate the sparse model after finetuning.

In [10]:
acc = evaluate(sparse_model)
print(f"Accuracy of {round(acc, 4)} on the test set.")

Accuracy of 0.9468 on the test set.
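The notebook does not show whether `train_model` re-applies the pruning mask during fine-tuning. If it does not, ordinary gradient updates can move pruned weights away from zero, so the fine-tuned model may no longer be 50% sparse. The plain-Python sketch below illustrates one common remedy, a masked update. The function `masked_sgd_step`, the toy weights, and the gradients are all hypothetical and are not part of `lib.part3.utils`.

```python
def masked_sgd_step(weights, grads, keep, lr=0.05):
    """One SGD step that leaves pruned positions (keep == False) at exactly zero."""
    return [w - lr * g if k else 0.0 for w, g, k in zip(weights, grads, keep)]


lr = 0.05
weights = [0.8, 0.0, 0.0, -1.2]        # toy weights after pruning
keep = [w != 0.0 for w in weights]     # mask recorded at pruning time
grads = [0.1, -0.3, 0.2, 0.4]          # hypothetical gradients

plain = [w - lr * g for w, g in zip(weights, grads)]   # unmasked update
masked = masked_sgd_step(weights, grads, keep, lr)     # masked update

plain_zeros = sum(1 for w in plain if w == 0.0)
masked_zeros = sum(1 for w in masked if w == 0.0)
print(plain_zeros, masked_zeros)  # 0 2
```

In the unmasked update every pruned weight becomes nonzero again, while the masked update preserves the sparsity pattern and still trains the surviving weights.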


**(2)** (2 points) What are the pros and cons of sparse networks?

Sparse networks reduce the number of nonzero parameters, which can lower memory usage and (when the sparsity pattern is actually exploited by the hardware/software stack) reduce computation and speed up inference. Sparsity can also act as a form of regularization, sometimes improving generalization by removing weak weights.  

In practice, unstructured sparsity often yields little or no GPU speedup because dense kernels are highly optimized, while sparse kernels can suffer from overhead and irregular memory access. Accuracy may degrade after pruning, especially at high sparsity levels, so fine-tuning is usually required to recover performance. It also adds engineering complexity, since masking, sparse storage formats, and sparsity-aware tooling complicate both training and inference pipelines. Finally, the gains are hardware-dependent, since meaningful acceleration typically requires structured sparsity or specialized libraries and accelerators.
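To make the memory trade-off concrete, here is a rough back-of-the-envelope sketch. It assumes 32-bit values and a simple coordinate-style sparse format that stores one 32-bit index per nonzero weight. Both are illustrative assumptions; real formats such as CSR or bitmask encodings have different overheads.

```python
def dense_bytes(num_weights, value_bytes=4):
    """Storage for a dense tensor: one value per weight."""
    return num_weights * value_bytes


def sparse_bytes(num_weights, sparsity, value_bytes=4, index_bytes=4):
    """Storage for an index+value sparse format (assumed layout, see lead-in)."""
    nnz = round(num_weights * (1.0 - sparsity))  # number of nonzero weights
    return nnz * (value_bytes + index_bytes)


n = 1_000_000
for s in (0.5, 0.75, 0.9):
    print(f"sparsity={s}: dense={dense_bytes(n)} B, sparse={sparse_bytes(n, s)} B")
```

Under these assumptions, 50% unstructured sparsity saves no memory at all, because each surviving weight costs twice as much to store. Savings only appear at higher sparsity levels, which is one reason the practical gains depend on the storage format and hardware.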


## 3.2 Train the sparse model with SAM

Sharpness-aware minimization (SAM) is an optimization technique that is not satisfied with a low loss at a single point; instead, it seeks a neighborhood with uniformly low loss. SAM is motivated by the link between the geometry of the loss landscape and generalization: intuitively, a low loss inside a uniformly low-loss neighborhood should generalize better than a low loss inside a region where the loss varies sharply.

To be specific, consider a model with weight vector $\mathbf{w}$ and training loss $L_S$. SAM aims to minimize the maximum loss within a small region around $\mathbf{w}$, usually an $\ell_2$ ball of radius $\rho$, where $\rho$ is a small positive value. SAM can therefore be formulated as a minimax optimization problem:
$$\min_{\mathbf{w}} \max_{\boldsymbol{\epsilon}: \|\boldsymbol{\epsilon}\|_2\leq \rho} L_S (\mathbf{w} + \boldsymbol{\epsilon})$$

**(3)** (3 points) Please solve the inner maximization problem using a first-order Taylor expansion.

Using a first-order Taylor expansion around $\mathbf{w}$,
$$
L_S(\mathbf{w}+\boldsymbol{\epsilon})
\approx
L_S(\mathbf{w}) + \nabla_{\mathbf{w}}L_S(\mathbf{w})^\top \boldsymbol{\epsilon},
$$
so the inner problem reduces to maximizing the linear form $\nabla_{\mathbf{w}}L_S(\mathbf{w})^\top \boldsymbol{\epsilon}$ over the $\ell_2$ ball $\{\boldsymbol{\epsilon}:\|\boldsymbol{\epsilon}\|_2\le \rho\}$.

Hence the inner maximization becomes
$$
\max_{\|\boldsymbol{\epsilon}\|_2\le \rho} L_S(\mathbf{w}+\boldsymbol{\epsilon})
\approx
L_S(\mathbf{w}) + \max_{\|\boldsymbol{\epsilon}\|_2\le \rho}\nabla_{\mathbf{w}}L_S(\mathbf{w})^\top \boldsymbol{\epsilon}.
$$
By Cauchy–Schwarz,
$$
\nabla_{\mathbf{w}}L_S(\mathbf{w})^\top \boldsymbol{\epsilon}
\le \|\nabla_{\mathbf{w}}L_S(\mathbf{w})\|_2\,\|\boldsymbol{\epsilon}\|_2
\le \rho \|\nabla_{\mathbf{w}}L_S(\mathbf{w})\|_2,
$$
and both inequalities hold with equality when $\boldsymbol{\epsilon}$ points along the gradient and has norm $\rho$ (assuming $\nabla_{\mathbf{w}}L_S(\mathbf{w}) \neq \mathbf{0}$):
$$
\boldsymbol{\epsilon}^\star
=
\rho \frac{\nabla_{\mathbf{w}}L_S(\mathbf{w})}{\|\nabla_{\mathbf{w}}L_S(\mathbf{w})\|_2}.
$$
Therefore,
$$
\max_{\|\boldsymbol{\epsilon}\|_2\le \rho} L_S(\mathbf{w} + \boldsymbol{\epsilon})
\approx
L_S(\mathbf{w}) + \rho \|\nabla_{\mathbf{w}}L_S(\mathbf{w})\|_2.
$$
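The closed form can be checked numerically. The sketch below uses a hypothetical two-dimensional quadratic, $L(\mathbf{w}) = w_0^2 + 3w_1^2$, as a stand-in for $L_S$. It compares the linear gain of $\boldsymbol{\epsilon}^\star$ with other directions on the boundary of the ball, and then compares the first-order estimate with the true perturbed loss.

```python
import math


def loss(w):
    # Toy loss standing in for L_S (an assumption for illustration only)
    return w[0] ** 2 + 3.0 * w[1] ** 2


def grad(w):
    # Analytic gradient of the toy loss
    return [2.0 * w[0], 6.0 * w[1]]


rho = 0.05
w = [1.0, 0.5]
g = grad(w)
g_norm = math.sqrt(sum(gi * gi for gi in g))

# Closed-form maximizer of the linearized problem: eps* = rho * g / ||g||
eps_star = [rho * gi / g_norm for gi in g]
linear_gain_star = sum(gi * ei for gi, ei in zip(g, eps_star))

# Linear gain g^T eps for 360 other directions with ||eps|| = rho
angles = [2.0 * math.pi * k / 360 for k in range(360)]
best_on_sphere = max(rho * (g[0] * math.cos(t) + g[1] * math.sin(t)) for t in angles)

first_order_estimate = loss(w) + rho * g_norm
true_value = loss([wi + ei for wi, ei in zip(w, eps_star)])
gap = true_value - first_order_estimate  # second-order remainder

print(linear_gain_star, best_on_sphere)
print(first_order_estimate, true_value, gap)
```

No sampled direction beats $\boldsymbol{\epsilon}^\star$ on the linearized objective, and the gap between the estimate and the true value is small because it is second order in $\rho$.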


**(4)** (8 points) Now we will train the same model using the SAM optimizer.
Please implement SAM in the two steps below. The first step is the maximizer, which computes the perturbation $\boldsymbol{\epsilon}$ derived in question (3). The second step is the usual minimizer step: $\mathbf{w}_{t+1} = \mathbf{w}_{t} - \eta_t \nabla L_S (\mathbf{w}_t + \boldsymbol{\epsilon}_t)$, where $\eta_t$ is the step size. We set $\rho=0.05$.

Hint: be careful about weight updates.

In [None]:
class SAM(torch.optim.Optimizer):
    def __init__(self, params, base_optimizer, lr=0.01, rho=0.05):
        assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"

        defaults = dict(rho=rho)
        super(SAM, self).__init__(params, defaults)

        self.base_optimizer = base_optimizer(self.param_groups, lr)
        self.param_groups = self.base_optimizer.param_groups
        self.defaults.update(self.base_optimizer.defaults)

    def _grad_norm(self):
        # Note that p.grad gets the gradient; p.data gets the weight.
        norm = torch.norm(
                    torch.stack([
                        p.grad.norm(p=2)
                        for group in self.param_groups for p in group["params"]
                        if p.grad is not None
                    ]),
                    p=2
               )
        norm += 1e-12 # Avoid zero norm
        return norm

    @torch.no_grad()
    def first_step(self):
        # Add the perturbation on the weight.
        grad_norm = self._grad_norm()

        for group in self.param_groups:
            rho = group["rho"]
            scale = rho / grad_norm

            for p in group["params"]:
                if p.grad is None:
                    continue

                # save current weights
                self.state[p]["old_p"] = p.data.clone()

                # ascent step: w <- w + epsilon
                e_w = p.grad * scale
                p.add_(e_w)

        self.zero_grad()

    @torch.no_grad()
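    def second_step(self, zero_grad=False):
        # Restore the weights saved in first_step, so the update starts from w, not w + epsilon.
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                p.data.copy_(self.state[p]["old_p"])

        # Descent step: the base optimizer uses the gradients computed at w + epsilon.
        self.base_optimizer.step()

        # Gradients are always cleared here; the `zero_grad` argument is currently unused.
        self.zero_grad()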
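The hint about weight updates can be illustrated on a toy problem. The plain-Python sketch below uses a hypothetical quadratic loss in place of $L_S$ and a hypothetical helper `sam_update`. It compares the intended update, which restores $\mathbf{w}$ before the descent step, with a faulty variant that applies the step to the perturbed weights.

```python
import math


def grad(w):
    # Gradient of the toy loss L(w) = w0^2 + 3*w1^2 (stand-in for L_S)
    return [2.0 * w[0], 6.0 * w[1]]


def sam_update(w, lr=0.05, rho=0.05, restore=True):
    g = grad(w)
    g_norm = math.sqrt(sum(gi * gi for gi in g)) + 1e-12  # avoid division by zero
    eps = [rho * gi / g_norm for gi in g]
    w_adv = [wi + ei for wi, ei in zip(w, eps)]   # first step: move to w + epsilon
    g_adv = grad(w_adv)                           # gradient at the perturbed point
    start = w if restore else w_adv               # restore=True mirrors second_step
    return [si - lr * gi for si, gi in zip(start, g_adv)]


w0 = [1.0, 0.5]
correct = sam_update(w0, restore=True)    # w - lr * grad(w + eps)
faulty = sam_update(w0, restore=False)    # (w + eps) - lr * grad(w + eps)

# Without the restore, every step carries an extra offset of norm rho.
offset = [b - a for a, b in zip(correct, faulty)]
offset_norm = math.sqrt(sum(o * o for o in offset))
print(correct, faulty, offset_norm)
```

The two results differ by exactly $\boldsymbol{\epsilon}$, so skipping the restore adds a drift of norm $\rho$ at every iteration. This is why `first_step` saves `old_p` and `second_step` copies it back before calling the base optimizer.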

Define a `SAM` optimizer for the model. We recommend using `SGD` as the base optimizer with a learning rate of $0.05$ (the same as in the SGD run above).

In [12]:
base_optimizer = torch.optim.SGD
sam_optimizer = SAM(sparse_model_sam.parameters(), base_optimizer, lr=0.05)

**(5)** (4 points) Please define the optimizing process of SAM.

In [None]:
def optimize_sam(model, optimizer, img, label):

    enable_running_stats(model)
    # First forward-backward pass
    output = model(img)
    loss = cross_entropy(output, label)
    loss.backward()              # fills p.grad at w
    optimizer.first_step()       # w <- w + eps, and zero_grad()


    disable_running_stats(model)
    # Second forward-backward pass
    output_perturbed = model(img)
    loss_perturbed = cross_entropy(output_perturbed, label)
    loss_perturbed.backward()    # fills p.grad at w+eps
    optimizer.second_step()      # restore w, then base_optimizer.step(), then zero_grad()

    return loss_perturbed.item()


In [14]:
train_model(sparse_model_sam, sam_optimizer, optimize_sam, finetune_epoch)

Epoch 0 with 0.963 accuracy on the validation set.
Epoch 1 with 0.965 accuracy on the validation set.
Epoch 2 with 0.966 accuracy on the validation set.


#### Evaluate model
Evaluate the sparse model finetuned with SAM on the test set.

In [15]:
acc = evaluate(sparse_model_sam)
print(f"Accuracy of {round(acc, 4)} on the test set.")

Accuracy of 0.9778 on the test set.


**(6)** (2 points) Give a conclusion comparing SAM with SGD. Is there any drawback of SAM?

In our MNIST pruning experiment, the dense model trained with SGD reached a test accuracy of $0.9768$. After $50\%$ magnitude-based pruning, accuracy dropped to $0.9422$. Finetuning the sparse model with SGD for $3$ epochs only slightly improved performance to $0.9468$ (recovering $0.0046$), whereas finetuning with SAM for $3$ epochs achieved $0.9778$ (recovering $0.0356$), which is $0.0310$ higher than the SGD-finetuned sparse model and even $0.0010$ higher than the original dense baseline.

One plausible explanation is that, after pruning, the network has less redundancy and becomes more sensitive: small changes to the remaining weights can cause larger changes in the output, so SGD can settle in a sharp minimum that fits the training data but is not robust. This would be consistent with the limited recovery from $0.9422$ to $0.9468$, with the caveat that the SGD fine-tuning call reuses the `optimizer` created for the dense `model`, so this baseline may understate what SGD fine-tuning can recover. In contrast, SAM updates parameters using gradients evaluated at an approximate worst-case perturbed point $w+\epsilon$ within an $\ell_2$ ball of radius $\rho$, encouraging solutions where the loss stays low not only at $w$ but throughout a neighborhood around it. This biases training toward flatter, more robust regions of the loss landscape, which often generalize better. In our experiment, SAM recovers accuracy from $0.9422$ to $0.9778$ and slightly surpasses the dense baseline ($0.9768$), although a single run on 10% of MNIST is limited evidence.

The main drawbacks of SAM are practical. First, it is substantially more expensive than SGD because each update requires two forward-backward passes (one to compute the ascent direction $\epsilon$ and one to compute the gradient at $w+\epsilon$), which roughly doubles training time and adds memory overhead for the saved copy of the weights. Second, SAM introduces extra hyperparameter sensitivity: the neighborhood size $\rho$ interacts with the learning rate, weight decay, and momentum, so achieving good performance often requires retuning, and convergence can slow if $\rho$ is too large (over-regularization) while the benefit vanishes if $\rho$ is too small. Third, the implementation is more complex, because parameters must be perturbed and then exactly restored before applying the base optimizer step, gradients must be carefully zeroed between the two passes, and BatchNorm running statistics typically need special handling so they are not updated on the perturbed forward pass.
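The cost argument can be made concrete by counting gradient evaluations. The sketch below is a plain-Python toy with a hypothetical quadratic loss; `CountingGrad`, `run_sgd`, and `run_sam` are illustrative names, not part of the notebook's library.

```python
import math


class CountingGrad:
    """Gradient of the toy loss L(w) = w0^2 + 3*w1^2 that counts its own calls."""

    def __init__(self):
        self.calls = 0

    def __call__(self, w):
        self.calls += 1
        return [2.0 * w[0], 6.0 * w[1]]


def run_sgd(grad_fn, w, steps, lr=0.05):
    for _ in range(steps):
        g = grad_fn(w)                                   # one gradient per step
        w = [wi - lr * gi for wi, gi in zip(w, g)]
    return w


def run_sam(grad_fn, w, steps, lr=0.05, rho=0.05):
    for _ in range(steps):
        g = grad_fn(w)                                   # gradient at w (ascent direction)
        g_norm = math.sqrt(sum(gi * gi for gi in g)) + 1e-12
        w_adv = [wi + rho * gi / g_norm for wi, gi in zip(w, g)]
        g_adv = grad_fn(w_adv)                           # gradient at w + epsilon
        w = [wi - lr * gi for wi, gi in zip(w, g_adv)]   # update the original w
    return w


steps = 100
sgd_grad, sam_grad = CountingGrad(), CountingGrad()
run_sgd(sgd_grad, [1.0, 0.5], steps)
run_sam(sam_grad, [1.0, 0.5], steps)
print(sgd_grad.calls, sam_grad.calls)  # 100 200
```

For the same number of parameter updates, SAM evaluates the gradient twice as often. In a neural network each evaluation is a full forward-backward pass, which is where the roughly doubled training time comes from.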