# Pruning, Quantization and Finetuning on ResNet-20 model

### Luopeiwen Yi

### 1. Introduction
Deep neural networks (DNNs) often require significant computation and memory resources, making them challenging to deploy on edge or embedded systems. To address this, model compression techniques such as pruning and fixed-point quantization are widely used. In this experiment, we explore different pruning strategies and quantization methods, both individually and in combination, to evaluate their effectiveness on the ResNet-20 model trained on CIFAR-10.

---

### 2. Experiment Design

#### 2.1. Baseline
We start with a pretrained floating-point ResNet-20 model and evaluate its baseline performance:
- **Test Accuracy:** 0.9151
- **Test Loss:** 0.3231

#### 2.2. Pruning Methods
We evaluate three pruning strategies, each aiming for 80% sparsity:
- **One-Shot Layer-Wise Pruning:** Prune 80% of the weights in each layer in a single step, then fine-tune for 20 epochs with the sparsity mask fixed.
- **Iterative Layer-Wise Pruning:** Gradually prune over the first 10 epochs, increasing sparsity by 8% each epoch, followed by 10 epochs of fine-tuning.
- **Global Iterative Pruning:** Use a single magnitude threshold computed across all layers (same overall sparsity, variable per-layer sparsity), applied with the same 8%-per-epoch schedule as the iterative layer-wise method.

#### 2.3. Quantization Methods
We apply fixed-point quantization to the residual blocks of the ResNet-20 model using:
- **Asymmetric Quantization**
- **Symmetric Quantization**
- **Both with and without fine-tuning**

Bit-widths (Nbits) tested: 6, 5, 4, 3, 2.
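
To make the two schemes concrete, here is a minimal NumPy sketch of uniform quantization on a weight array. The asymmetric branch follows the min/max scaling used by the `STE` function later in this notebook. The symmetric branch is an assumption (a common max-absolute-value formulation), because the symmetric implementation is not shown in this report. The function names are hypothetical and introduced only for illustration.

```python
import numpy as np

def quantize_asymmetric(w, nbits):
    """Uniform grid spanning [w.min(), w.max()] with 2**nbits levels."""
    w_min, w_max = w.min(), w.max()
    alpha = w_max - w_min                  # dynamic range of the weights
    steps = 2 ** nbits - 1                 # intervals between min and max
    ws = (w - w_min) / (alpha + 1e-8)      # rescale to [0, 1]
    return np.round(ws * steps) / steps * alpha + w_min

def quantize_symmetric(w, nbits):
    """ASSUMED formulation: signed grid centered at zero, scaled by max |w|.

    Requires nbits >= 2 and at least one non-zero weight.
    """
    q_max = 2 ** (nbits - 1) - 1           # e.g. 7 for 4 bits
    scale = np.abs(w).max() / q_max
    return np.clip(np.round(w / scale), -q_max, q_max) * scale

# Synthetic weights, not taken from the trained model
rng = np.random.default_rng(0)
w = rng.normal(0.0, 0.05, size=1000)
for nbits in (6, 4, 2):
    err_a = np.abs(w - quantize_asymmetric(w, nbits)).max()
    err_s = np.abs(w - quantize_symmetric(w, nbits)).max()
    print(f"Nbits={nbits}: max error asymmetric={err_a:.5f}, symmetric={err_s:.5f}")
```

Under this assumed symmetric formulation the grid has only `2**Nbits - 1` levels and must cover `[-max|w|, max|w|]`, so part of the range can go unused when the weight distribution is skewed. That is one plausible reason for the larger accuracy loss of symmetric quantization at 3 and 4 bits, though this report does not test it directly.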

#### 2.4. Combined Pruning + Quantization
We examine the performance of applying fixed-point quantization (Nbits=5 to 2) on a model pruned to 80% sparsity using the best-performing pruning method, and evaluate both before and after finetuning.
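
One detail matters when quantizing a pruned model: asymmetric min/max quantization does not generally place a grid point exactly at zero, so pruned weights would turn into small non-zero values and the sparsity would be lost. The `STE` function later in this notebook handles this by multiplying the quantized weights with a zero mask. The following is a minimal NumPy sketch of that effect, using a made-up 50%-sparse weight vector and a hypothetical helper name.

```python
import numpy as np

def quantize_asymmetric(w, nbits, keep_zeros=True):
    """Min/max uniform quantization; optionally restore exact zeros afterwards."""
    zero_mask = (w != 0)                   # True where the weight survived pruning
    w_min, w_max = w.min(), w.max()
    alpha = w_max - w_min
    steps = 2 ** nbits - 1
    ws = (w - w_min) / (alpha + 1e-8)      # rescale to [0, 1]
    wq = np.round(ws * steps) / steps * alpha + w_min
    return wq * zero_mask if keep_zeros else wq

# Illustrative values only: half of the entries are already pruned to zero
w = np.array([-0.31, 0.0, 0.12, 0.0, 0.0, 0.47, 0.0, -0.08, 0.0, 0.25])

for keep in (False, True):
    wq = quantize_asymmetric(w, nbits=3, keep_zeros=keep)
    print(f"keep_zeros={keep}: sparsity after quantization = {np.mean(wq == 0):.2f}")
```

Without the mask every pruned entry lands on the nearest non-zero grid point, so the quantized tensor is fully dense again; with the mask the original 50% sparsity is kept.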

---

### 3. Results and Observations

#### 3.1. Pruning + Fine-Tuning Accuracy
| Method                           | Test Accuracy | Test Loss |
|----------------------------------|----------------|------------|
| Floating-point Baseline          | 0.9151         | 0.3231     |
| One-Shot Pruning + Fine-tuning   | 0.8794         | 0.3664     |
| Iterative Pruning + Fine-tuning  | 0.8769         | 0.3750     |
| Global Iterative Pruning + FT    | **0.8841**     | **0.3483** |

**Observation:**
- All pruning methods incur accuracy drops from the FP baseline.
- Global iterative pruning performs best, balancing sparsity and accuracy.
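- One-shot pruning removes all 80% of the weights in a single step, yet it finishes slightly ahead of iterative layer-wise pruning (0.8794 vs. 0.8769); the gap is small enough that it may be within run-to-run variation.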

#### 3.2. Quantization without Finetuning
| Nbits | Asymmetric Acc | Symmetric Acc |
|-------|----------------|----------------|
| 6     | 0.9144         | 0.9134         |
| 5     | 0.9113         | 0.9071         |
| 4     | 0.8973         | 0.8532         |
| 3     | 0.7660         | 0.7151         |
| 2     | 0.0899         | 0.1000         |

**Observation:**
- Both quantization types perform similarly at higher bit-widths.
- Symmetric quantization tends to degrade more quickly under low bit-width.
- Accuracy drops sharply at 3 bits and collapses to chance level (about 10% for the 10-class CIFAR-10 task) at 2 bits, which motivates finetuning.

#### 3.3. Asymmetric Quantization with Finetuning
| Nbits | Accuracy After Finetune |
|-------|--------------------------|
| 5     | 0.9156                   |
| 4     | 0.9140                   |
| 3     | 0.9058                   |
| 2     | 0.8597                   |

**Observation:**
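- Finetuning recovers most of the lost accuracy, especially at low precision: 3-bit accuracy rises from 0.7660 to 0.9058 and 2-bit accuracy from 0.0899 to 0.8597.
- At 5 bits the accuracy (0.9156) matches the floating-point baseline (0.9151), and at 4 bits (0.9140) it is within about 0.1 percentage points.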

#### 3.4. Quantized Pruned Model
| Nbits | Acc Before FT | Acc After FT |
|-------|---------------|---------------|
| 5     | 0.8778        | 0.9032        |
| 4     | 0.8603        | 0.8994        |
| 3     | 0.7186        | 0.8715        |
| 2     | 0.1000        | 0.3348        |

**Observation:**
- Finetuning is crucial when combining pruning and quantization.
- Accuracy recovers well up to 3-bit precision even after pruning.
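- Performance degrades drastically at 2 bits: even after finetuning, accuracy reaches only 0.3348, compared with 0.8597 for the unpruned 2-bit model.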

---

### 4. Comparative Analysis

| Feature                      | One-Shot Pruning | Iterative Pruning | Global Iterative | Quantization Only | Prune + Quantize |
|-----------------------------|------------------|-------------------|------------------|-------------------|------------------|
| Accuracy vs. FP Baseline (relative change) | ↓ 3.9% | ↓ 4.2% | **↓ 3.4%** | ↓ varies by Nbits | ↓ more at low bits |
| Sparsity Flexibility        | Low              | Medium            | High             | N/A               | High             |
| Bit-width Impact            | N/A              | N/A               | N/A              | High              | High             |
| Best Tradeoff               | No               | No                | **Yes**          | Yes (with FT)     | Yes (with FT)    |
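
The percentage drops in the first row are relative to the baseline accuracy rather than absolute percentage points. A small sketch of that arithmetic, using the accuracies reported in Section 3.1; the helper name `relative_drop_pct` is introduced here for illustration only.

```python
baseline = 0.9151
pruned = {
    "One-Shot": 0.8794,
    "Iterative": 0.8769,
    "Global Iterative": 0.8841,
}

def relative_drop_pct(acc, ref=baseline):
    """Accuracy loss expressed as a percentage of the reference accuracy."""
    return 100.0 * (ref - acc) / ref

for name, acc in pruned.items():
    points = 100.0 * (baseline - acc)      # absolute drop in percentage points
    print(f"{name}: -{points:.2f} points absolute, "
          f"-{relative_drop_pct(acc):.1f}% relative")
```

In absolute terms the three methods lose between roughly 3.1 and 3.8 percentage points of test accuracy.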

---

### 5. Conclusion
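This report shows that pruning and quantization can compress ResNet-20 substantially at a modest cost in accuracy (about 3 to 4 percentage points at 80% sparsity, and under 1 point for 3- to 5-bit asymmetric quantization with finetuning). Key takeaways include: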
- Global iterative pruning outperforms other sparsity strategies under high compression.
- Without finetuning, asymmetric quantization outperforms symmetric quantization at 3 and 4 bits (by about 4 to 5 percentage points); at 2 bits both collapse to chance level.
- Finetuning is essential for recovering performance, especially after aggressive pruning or low-bit quantization.
- A combined pruning + quantization pipeline, with proper fine-tuning, can significantly reduce model size while maintaining acceptable accuracy.

These findings support the feasibility of deploying compressed ResNet-20 models in resource-constrained environments without sacrificing much performance.

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
import sys
import os

In [3]:
# Change this to the absolute path where dataset.py and utils.py are stored
CODE_PATH = "/content/drive/MyDrive/ECE661 Assignment4"

# Add this path to sys.path so Python can find it
sys.path.append(CODE_PATH)

# Check if Colab can see the files
print("Files in directory:", os.listdir(CODE_PATH))

Files in directory: ['train_util.py', 'resnet20.py', 'ECE661__Pruning_and_Fixed_Point_Quantization_HW4.pdf', 'pretrained_model.pt', '__pycache__', 'net_after_finetune.pt', 'net_after_iterative_prune.pt', 'net_after_global_iterative_prune.pt', 'FP_layers_template.py', 'quantized_net_after_finetune_Nbits_5.pt', 'quantized_net_after_finetune_Nbits_4.pt', 'quantized_net_after_finetune_Nbits_3.pt', 'quantized_net_after_finetune_Nbits_2.pt', 'pruned_quantized_net_Nbits_5.pt', 'pruned_quantized_net_Nbits_4.pt', 'pruned_quantized_net_Nbits_3.pt', 'pruned_quantized_net_Nbits_2.pt', 'FP_layers.py', 'FP_layers_asymmetric.py', 'Pruning, Fixed-point quantization and finetuning on ResNet-20 model.ipynb']


### Model preparation

In [4]:
from resnet20 import ResNetCIFAR
from train_util import train, finetune, test
import torch
import numpy as np
import random

import time

import torchvision.transforms as transforms
import torchvision
import torch.nn as nn
import torch.optim as optim

from FP_layers import *

device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [5]:
seed = 42
random.seed(seed)  # Python's random module
np.random.seed(seed)  # NumPy's random module
torch.manual_seed(seed)  # PyTorch's random seed for CPU
torch.cuda.manual_seed(seed)  # PyTorch's random seed for the current GPU
torch.cuda.manual_seed_all(seed)  # PyTorch's random seed for all GPUs (if using multi-GPU)

# Ensure deterministic behavior on GPU (optional, may slow down training)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Optional: Set environment variables for further reproducibility
os.environ['PYTHONHASHSEED'] = str(seed)

In [6]:
net = ResNetCIFAR(num_layers=20, Nbits=None)
net = net.to(device)

In [7]:
# Load the best weight parameters
net.load_state_dict(torch.load(os.path.join(CODE_PATH, "pretrained_model.pt")))
test(net)

Test Loss=0.3231, Test accuracy=0.9151


Test accuracy of the floating-point pretrained model:  91.51%

## Pruning

### Prune by percentage

Pruning a single layer's weights (like FP_Conv or FP_Linear) by:

- Taking the absolute value of weights

- Finding the q-th percentile (e.g., q=70 → prune bottom 70% smallest-magnitude weights)

- Creating a binary mask to zero out those small weights
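
As a toy illustration of these three steps on a plain NumPy array (not a network layer), the sketch below uses ten made-up weights and `q=70`:

```python
import numpy as np

# Ten illustrative weights, not taken from the model
w = np.array([0.05, -0.40, 0.10, -0.02, 0.30, -0.25, 0.01, 0.60, -0.15, 0.08])
q = 70.0

threshold = np.percentile(np.abs(w), q)   # q-th percentile of |w|
mask = np.abs(w) >= threshold             # True for weights that are kept
pruned = w * mask                         # small-magnitude weights become 0

print("threshold:", threshold)
print("kept weights:", pruned[mask])
print("sparsity:", 1.0 - mask.mean())
```

Here 7 of the 10 entries fall below the interpolated threshold of 0.265, so only the three largest-magnitude weights (0.30, -0.40, 0.60) survive and the sparsity is 70%.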



In [26]:
def prune_by_percentage(layer, q=70.0):
    """
    Prune the weight parameters of a layer by zeroing out the
    bottom-q percent smallest magnitude weights.
    """
    # Convert weight to numpy array (detach from graph)
    weight = layer.weight.data.cpu().numpy()

    # Calculate threshold at q-th percentile of absolute weight values
    threshold = np.percentile(np.abs(weight), q)

    # Create mask: keep weights >= threshold
    mask = np.abs(weight) >= threshold

    # Convert mask to torch tensor and move to same device as weight
    mask_tensor = torch.tensor(mask, dtype=torch.float32, device=layer.weight.device)

    # Apply mask (in-place pruning)
    layer.weight.data.mul_(mask_tensor)

In [27]:
qs = [0.2, 0.4, 0.6, 0.7, 0.8]  # pruning ratios

for q in qs:
    print(f"\n--- Pruning q = {q} ---")
    net.load_state_dict(torch.load(os.path.join(CODE_PATH, "pretrained_model.pt")))

    for name, layer in net.named_modules():
        if (isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear)) and 'id_mapping' not in name:
            # Convert pruning ratio (0.2) to percentile (20.0)
            prune_by_percentage(layer, q=q * 100)

            np_weight = layer.weight.data.cpu().numpy()
            zeros = np.sum(np_weight == 0)
            total = np_weight.size
            print(f"Sparsity of {name}: {zeros}/{total} = {zeros/total:.4f}")

    # Evaluate pruned model
    test(net)


--- Pruning q = 0.2 ---
Sparsity of head_conv.0.conv: 87/432 = 0.2014
Sparsity of body_op.0.conv1.0.conv: 461/2304 = 0.2001
Sparsity of body_op.0.conv2.0.conv: 461/2304 = 0.2001
Sparsity of body_op.1.conv1.0.conv: 461/2304 = 0.2001
Sparsity of body_op.1.conv2.0.conv: 461/2304 = 0.2001
Sparsity of body_op.2.conv1.0.conv: 461/2304 = 0.2001
Sparsity of body_op.2.conv2.0.conv: 461/2304 = 0.2001
Sparsity of body_op.3.conv1.0.conv: 922/4608 = 0.2001
Sparsity of body_op.3.conv2.0.conv: 1843/9216 = 0.2000
Sparsity of body_op.4.conv1.0.conv: 1843/9216 = 0.2000
Sparsity of body_op.4.conv2.0.conv: 1843/9216 = 0.2000
Sparsity of body_op.5.conv1.0.conv: 1843/9216 = 0.2000
Sparsity of body_op.5.conv2.0.conv: 1843/9216 = 0.2000
Sparsity of body_op.6.conv1.0.conv: 3687/18432 = 0.2000
Sparsity of body_op.6.conv2.0.conv: 7373/36864 = 0.2000
Sparsity of body_op.7.conv1.0.conv: 7373/36864 = 0.2000
Sparsity of body_op.7.conv2.0.conv: 7373/36864 = 0.2000
Sparsity of body_op.8.conv1.0.conv: 7373/36864 = 0.2

### Finetune pruned model

#### One-Shot Pruning + Fine-Tuning

Methodology:

- Prune once: Remove a large percentage of weights (e.g., 80%) in one go.

- Fix the sparsity mask: Ensure pruned weights are permanently zero.

- Fine-tune the model with this fixed sparsity structure for multiple epochs to recover performance.

Intuition:

- This is a brute-force strategy.

- The model suddenly loses a large portion of its parameters.

- It then tries to re-adapt using only the remaining weights.

- Downside: Pruning so many weights all at once may cause irreversible accuracy loss, especially if some important weights are accidentally removed.



In [28]:
def finetune_after_prune(net, trainloader, criterion, optimizer, prune=True):
    """
    Finetune the pruned model for a single epoch.
    Ensures pruned weights remain zero throughout training.
    """
    weight_mask = {}
    for name, layer in net.named_modules():
        if (isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear)) and 'id_mapping' not in name:
            mask = (layer.weight.data != 0).float().to(layer.weight.device)
            weight_mask[name] = mask

    global_steps = 0
    train_loss = 0
    correct = 0
    total = 0
    start = time.time()

    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # Keep pruned weights at 0
        if prune:
            for name, layer in net.named_modules():
                if (isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear)) and 'id_mapping' not in name:
                    layer.weight.data.mul_(weight_mask[name])  # apply mask again

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        global_steps += 1

        if global_steps % 50 == 0:
            end = time.time()
            batch_size = 256
            num_examples_per_second = 50 * batch_size / (end - start)
            print("[Step=%d]\tLoss=%.4f\tacc=%.4f\t%.1f examples/second"
                 % (global_steps, train_loss / (batch_idx + 1), (correct / total), num_examples_per_second))
            start = time.time()

In [29]:
# Get pruned model
net.load_state_dict(torch.load(os.path.join(CODE_PATH, "pretrained_model.pt")))
for name,layer in net.named_modules():
    if (isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear)) and 'id_mapping' not in name:
        prune_by_percentage(layer, q=80.0)

# Training setup, do not change
batch_size=256
lr=0.002
reg=1e-4

print('==> Preparing data..')
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
best_acc = 0  # best test accuracy
start_epoch = 0  # start from epoch 0 or last checkpoint epoch
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=16)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.875, weight_decay=reg, nesterov=False)

==> Preparing data..




In [30]:
# Model finetuning
for epoch in range(20):
    print('\nEpoch: %d' % epoch)
    net.train()
    finetune_after_prune(net, trainloader, criterion, optimizer)
    #Start the testing code.
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    num_val_steps = len(testloader)
    val_acc = correct / total
    print("Test Loss=%.4f, Test acc=%.4f" % (test_loss / (num_val_steps), val_acc))

    if val_acc > best_acc:
        best_acc = val_acc
        print("Saving...")
        torch.save(net.state_dict(), os.path.join(CODE_PATH, "net_after_finetune.pt"))


Epoch: 0
[Step=50]	Loss=0.9780	acc=0.6847	6640.2 examples/second
[Step=100]	Loss=0.8098	acc=0.7359	9687.6 examples/second
[Step=150]	Loss=0.7194	acc=0.7633	9641.5 examples/second
Test Loss=0.5586, Test acc=0.8190
Saving...

Epoch: 1
[Step=50]	Loss=0.4706	acc=0.8413	6606.8 examples/second
[Step=100]	Loss=0.4605	acc=0.8437	10447.3 examples/second
[Step=150]	Loss=0.4504	acc=0.8470	10463.5 examples/second
Test Loss=0.4865, Test acc=0.8383
Saving...

Epoch: 2
[Step=50]	Loss=0.4058	acc=0.8622	6693.9 examples/second
[Step=100]	Loss=0.4013	acc=0.8627	9733.1 examples/second
[Step=150]	Loss=0.3965	acc=0.8645	10157.9 examples/second
Test Loss=0.4539, Test acc=0.8495
Saving...

Epoch: 3
[Step=50]	Loss=0.3748	acc=0.8708	6856.6 examples/second
[Step=100]	Loss=0.3741	acc=0.8712	10171.4 examples/second
[Step=150]	Loss=0.3651	acc=0.8751	10560.8 examples/second
Test Loss=0.4343, Test acc=0.8568
Saving...

Epoch: 4
[Step=50]	Loss=0.3523	acc=0.8836	6387.7 examples/second
[Step=100]	Loss=0.3455	acc=0.8835

In [31]:
net.load_state_dict(torch.load(os.path.join(CODE_PATH, "net_after_finetune.pt")))

for name, layer in net.named_modules():
    if (isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear)) and 'id_mapping' not in name:
        np_weight = layer.weight.data.cpu().numpy()
        zeros = np.sum(np_weight == 0)
        total = np_weight.size
        print('Sparsity of '+name+': '+str(zeros / total))

test(net)

Sparsity of head_conv.0.conv: 0.7986111111111112
Sparsity of body_op.0.conv1.0.conv: 0.7999131944444444
Sparsity of body_op.0.conv2.0.conv: 0.7999131944444444
Sparsity of body_op.1.conv1.0.conv: 0.7999131944444444
Sparsity of body_op.1.conv2.0.conv: 0.7999131944444444
Sparsity of body_op.2.conv1.0.conv: 0.7999131944444444
Sparsity of body_op.2.conv2.0.conv: 0.7999131944444444
Sparsity of body_op.3.conv1.0.conv: 0.7999131944444444
Sparsity of body_op.3.conv2.0.conv: 0.7999131944444444
Sparsity of body_op.4.conv1.0.conv: 0.7999131944444444
Sparsity of body_op.4.conv2.0.conv: 0.7999131944444444
Sparsity of body_op.5.conv1.0.conv: 0.7999131944444444
Sparsity of body_op.5.conv2.0.conv: 0.7999131944444444
Sparsity of body_op.6.conv1.0.conv: 0.7999674479166666
Sparsity of body_op.6.conv2.0.conv: 0.7999945746527778
Sparsity of body_op.7.conv1.0.conv: 0.7999945746527778
Sparsity of body_op.7.conv2.0.conv: 0.7999945746527778
Sparsity of body_op.8.conv1.0.conv: 0.7999945746527778
Sparsity of body

Test Accuracy of One-Shot Pruning + Fine-tuning at 80% sparsity level: 87.94%

#### Iterative Pruning + Fine-tuning

Methodology:

- Start from a dense (unpruned) pretrained model.

- For each of the first 10 epochs:

  - Gradually increase pruning percentage, e.g., prune 8%, 16%, ..., up to 80%.

  - Allow the network to adjust (fine-tune) after each pruning step, and even recover pruned weights.
  
  - After reaching 80% pruning at epoch 10, freeze the zero weights and continue fine-tuning with fixed sparsity for the next 10 epochs.

Intuition:

- This is a progressive or gradual approach.

- The model has a chance to redistribute importance among weights during early training.

- Weights that get pruned early can still come back (until sparsity is fixed).

- This encourages the network to naturally evolve toward sparse representations.

- Often expected to give better performance and stability at high sparsity, although in this experiment iterative layer-wise pruning (87.69%) finished slightly below one-shot pruning (87.94%).
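
The schedule described above can be written out explicitly. The sketch below mirrors the epoch logic of the training cell that follows (prune to `8 * (epoch + 1)` percent during the first 10 epochs, enforce the mask from the epoch that reaches 80% onward); the function name `pruning_schedule` and its parameters are hypothetical.

```python
def pruning_schedule(num_epochs=20, ramp_epochs=10, step_pct=8):
    """List (epoch, prune_target_pct, mask_enforced) for each epoch.

    prune_target_pct is None once the ramp is over (no further pruning).
    """
    schedule = []
    for epoch in range(num_epochs):
        target = step_pct * (epoch + 1) if epoch < ramp_epochs else None
        enforced = epoch >= ramp_epochs - 1   # mask is fixed from the 80% step on
        schedule.append((epoch, target, enforced))
    return schedule

for epoch, target, enforced in pruning_schedule():
    step = f"prune to {target}%" if target is not None else "no pruning"
    print(f"epoch {epoch:2d}: {step:>13}, mask enforced: {enforced}")
```

During epochs 0 to 8 the mask is not enforced, so weights zeroed by one pruning step can regrow before the next, larger pruning step is applied.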

In [32]:
net.load_state_dict(torch.load(os.path.join(CODE_PATH, "pretrained_model.pt")))
best_acc = 0.

for epoch in range(20):
    print('\nEpoch: %d' % epoch)
    net.train()

    if epoch < 10:
      q = 8 * (epoch + 1)
      print(f"Pruning {q:.1f}% of weights before Epoch {epoch}")

      for name, layer in net.named_modules():
          if (isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear)) and 'id_mapping' not in name:
              prune_by_percentage(layer, q=q)

    if epoch < 9:
        finetune_after_prune(net, trainloader, criterion, optimizer, prune=False)  # pruning not enforced yet
    else:
        finetune_after_prune(net, trainloader, criterion, optimizer, prune=True)   # enforce pruned weights remain 0

    # Evaluation
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    num_val_steps = len(testloader)
    val_acc = correct / total
    print("Test Loss = %.4f, Test acc = %.4f" % (test_loss / num_val_steps, val_acc))

    if epoch >= 10 and val_acc > best_acc:
        best_acc = val_acc
        print("Saving...")
        torch.save(net.state_dict(), os.path.join(CODE_PATH, "net_after_iterative_prune.pt"))


Epoch: 0
Pruning 8.0% of weights before Epoch 0
[Step=50]	Loss=0.0465	acc=0.9862	6832.9 examples/second
[Step=100]	Loss=0.0474	acc=0.9853	10718.8 examples/second
[Step=150]	Loss=0.0479	acc=0.9844	10429.4 examples/second
Test Loss = 0.3248, Test acc = 0.9129

Epoch: 1
Pruning 16.0% of weights before Epoch 1
[Step=50]	Loss=0.0468	acc=0.9844	6686.1 examples/second
[Step=100]	Loss=0.0476	acc=0.9846	9990.2 examples/second
[Step=150]	Loss=0.0479	acc=0.9844	9537.6 examples/second
Test Loss = 0.3275, Test acc = 0.9149

Epoch: 2
Pruning 24.0% of weights before Epoch 2
[Step=50]	Loss=0.0522	acc=0.9842	6586.6 examples/second
[Step=100]	Loss=0.0530	acc=0.9838	9894.4 examples/second
[Step=150]	Loss=0.0526	acc=0.9836	10124.1 examples/second
Test Loss = 0.3273, Test acc = 0.9131

Epoch: 3
Pruning 32.0% of weights before Epoch 3
[Step=50]	Loss=0.0571	acc=0.9816	6578.9 examples/second
[Step=100]	Loss=0.0576	acc=0.9808	9579.8 examples/second
[Step=150]	Loss=0.0567	acc=0.9813	10278.6 examples/second
Tes

In [33]:
net.load_state_dict(torch.load(os.path.join(CODE_PATH, "net_after_iterative_prune.pt")))

for name, layer in net.named_modules():
    if (isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear)) and 'id_mapping' not in name:
        np_weight = layer.weight.data.cpu().numpy()
        zeros = np.sum(np_weight == 0)
        total = np_weight.size
        print(f"Sparsity of {name}: {zeros}/{total} = {zeros/total:.4f}")

# Final test
test(net)

Sparsity of head_conv.0.conv: 345/432 = 0.7986
Sparsity of body_op.0.conv1.0.conv: 1843/2304 = 0.7999
Sparsity of body_op.0.conv2.0.conv: 1843/2304 = 0.7999
Sparsity of body_op.1.conv1.0.conv: 1843/2304 = 0.7999
Sparsity of body_op.1.conv2.0.conv: 1843/2304 = 0.7999
Sparsity of body_op.2.conv1.0.conv: 1843/2304 = 0.7999
Sparsity of body_op.2.conv2.0.conv: 1843/2304 = 0.7999
Sparsity of body_op.3.conv1.0.conv: 3686/4608 = 0.7999
Sparsity of body_op.3.conv2.0.conv: 7372/9216 = 0.7999
Sparsity of body_op.4.conv1.0.conv: 7372/9216 = 0.7999
Sparsity of body_op.4.conv2.0.conv: 7372/9216 = 0.7999
Sparsity of body_op.5.conv1.0.conv: 7372/9216 = 0.7999
Sparsity of body_op.5.conv2.0.conv: 7372/9216 = 0.7999
Sparsity of body_op.6.conv1.0.conv: 14745/18432 = 0.8000
Sparsity of body_op.6.conv2.0.conv: 29491/36864 = 0.8000
Sparsity of body_op.7.conv1.0.conv: 29491/36864 = 0.8000
Sparsity of body_op.7.conv2.0.conv: 29491/36864 = 0.8000
Sparsity of body_op.8.conv1.0.conv: 29491/36864 = 0.8000
Sparsity

Test Accuracy of Iterative Pruning + Fine-tuning at 80% sparsity level: 87.69%

#### Global iterative pruning + Fine-tuning

Methodology:

- Instead of computing thresholds per layer, we:

  - Collect all weights across all layers.

  - Compute the global q-th percentile threshold.

  - Apply this same threshold to all layers — regardless of their own distribution.

Intuition:

- We assume that importance of weights should be measured across the whole model, not just locally in each layer.

- Some layers may end up more sparse than others, depending on their weight distributions.

Tradeoffs:

- More flexible: allows sensitive layers to retain more weights.

- Potentially better accuracy at same sparsity budget.

- Slightly more complex to implement and tune.
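
To see why a global threshold produces unequal per-layer sparsity, the toy sketch below compares layer-wise and global thresholds on two synthetic "layers" whose weights differ in scale by a factor of ten. The layer names and distributions are made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
layers = {
    "layer_small": rng.normal(0.0, 0.01, size=1000),   # small-magnitude weights
    "layer_large": rng.normal(0.0, 0.10, size=1000),   # large-magnitude weights
}
q = 50.0

# One threshold computed over all weights of all layers
all_abs = np.abs(np.concatenate(list(layers.values())))
global_thr = np.percentile(all_abs, q)

sparsity = {}
for name, w in layers.items():
    local_thr = np.percentile(np.abs(w), q)            # per-layer threshold
    sparsity[name] = {
        "layer_wise": float(np.mean(np.abs(w) < local_thr)),
        "global": float(np.mean(np.abs(w) < global_thr)),
    }
    print(name, sparsity[name])

overall = float(np.mean(all_abs < global_thr))
print("overall sparsity under the global threshold:", overall)
```

Layer-wise pruning removes half of each layer, while the global threshold removes most of the small-magnitude layer and little of the large-magnitude one, with the overall sparsity still at 50%. The same effect appears in the measured results in this notebook, where the first convolution ends up about 31% sparse while some later layers exceed 80%.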

In [34]:
def global_prune_by_percentage(net, q=70.0):
    """
    Perform global pruning by thresholding weights globally across layers.
    :param q: percentile value, e.g., 70.0 means prune bottom 70% smallest-magnitude weights.
    """
    flattened_weights = []

    # Gather all weights from prunable layers
    for name, layer in net.named_modules():
        if isinstance(layer, (nn.Conv2d, nn.Linear)) and 'id_mapping' not in name:
            flattened_weights.append(np.abs(layer.weight.data.cpu().numpy()).flatten())

    # Compute global threshold
    all_weights = np.concatenate(flattened_weights)
    threshold = np.percentile(all_weights, q)

    # Apply mask globally to each layer
    for name, layer in net.named_modules():
        if isinstance(layer, (nn.Conv2d, nn.Linear)) and 'id_mapping' not in name:
            weight_np = layer.weight.data.cpu().numpy()
            mask = np.abs(weight_np) >= threshold
            mask_tensor = torch.tensor(mask, dtype=torch.float32, device=layer.weight.device)
            layer.weight.data.mul_(mask_tensor)

In [35]:
net.load_state_dict(torch.load(os.path.join(CODE_PATH, "pretrained_model.pt")))
best_acc = 0.

for epoch in range(20):
    print(f"\nEpoch: {epoch}")
    net.train()

    if epoch < 10:
        q = 8 * (epoch + 1)
        print(f"Global Pruning {q:.1f}% of weights before Epoch {epoch}")
        global_prune_by_percentage(net, q=q)

    if epoch<9:
        finetune_after_prune(net, trainloader, criterion, optimizer,prune=False)
    else:
        finetune_after_prune(net, trainloader, criterion, optimizer)

    #Start the testing code.
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    num_val_steps = len(testloader)
    val_acc = correct / total
    print("Test Loss=%.4f, Test acc=%.4f" % (test_loss / (num_val_steps), val_acc))

    if epoch>=10:
        if val_acc > best_acc:
            best_acc = val_acc
            print("Saving...")
            torch.save(net.state_dict(), os.path.join(CODE_PATH, "net_after_global_iterative_prune.pt"))


Epoch: 0
Global Pruning 8.0% of weights before Epoch 0
[Step=50]	Loss=0.0430	acc=0.9872	6701.8 examples/second
[Step=100]	Loss=0.0476	acc=0.9848	10210.5 examples/second
[Step=150]	Loss=0.0481	acc=0.9844	9593.2 examples/second
Test Loss=0.3232, Test acc=0.9151

Epoch: 1
Global Pruning 16.0% of weights before Epoch 1
[Step=50]	Loss=0.0465	acc=0.9854	6708.4 examples/second
[Step=100]	Loss=0.0478	acc=0.9849	10075.0 examples/second
[Step=150]	Loss=0.0492	acc=0.9839	9992.1 examples/second
Test Loss=0.3241, Test acc=0.9155

Epoch: 2
Global Pruning 24.0% of weights before Epoch 2
[Step=50]	Loss=0.0523	acc=0.9825	6739.3 examples/second
[Step=100]	Loss=0.0510	acc=0.9832	10000.3 examples/second
[Step=150]	Loss=0.0501	acc=0.9838	9596.6 examples/second
Test Loss=0.3289, Test acc=0.9139

Epoch: 3
Global Pruning 32.0% of weights before Epoch 3
[Step=50]	Loss=0.0555	acc=0.9816	6694.5 examples/second
[Step=100]	Loss=0.0539	acc=0.9825	10436.5 examples/second
[Step=150]	Loss=0.0539	acc=0.9826	10438.6 ex

In [38]:
net.load_state_dict(torch.load(os.path.join(CODE_PATH, "net_after_global_iterative_prune.pt")))

zeros_sum = 0
total_sum = 0
for name,layer in net.named_modules():
    if (isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear)) and 'id_mapping' not in name:
        np_weight = layer.weight.data.cpu().numpy()
        zeros = np.sum(np_weight == 0)
        total = np_weight.size
        zeros_sum += zeros
        total_sum += total
        print('Sparsity of '+name+': '+str(zeros/total))
print('Total sparsity of: '+str(zeros_sum/total_sum))
test(net)

Sparsity of head_conv.0.conv: 0.3101851851851852
Sparsity of body_op.0.conv1.0.conv: 0.6566840277777778
Sparsity of body_op.0.conv2.0.conv: 0.6393229166666666
Sparsity of body_op.1.conv1.0.conv: 0.6271701388888888
Sparsity of body_op.1.conv2.0.conv: 0.6484375
Sparsity of body_op.2.conv1.0.conv: 0.6315104166666666
Sparsity of body_op.2.conv2.0.conv: 0.6671006944444444
Sparsity of body_op.3.conv1.0.conv: 0.6245659722222222
Sparsity of body_op.3.conv2.0.conv: 0.6885850694444444
Sparsity of body_op.4.conv1.0.conv: 0.7253689236111112
Sparsity of body_op.4.conv2.0.conv: 0.7825520833333334
Sparsity of body_op.5.conv1.0.conv: 0.7243923611111112
Sparsity of body_op.5.conv2.0.conv: 0.8129340277777778
Sparsity of body_op.6.conv1.0.conv: 0.732421875
Sparsity of body_op.6.conv2.0.conv: 0.7647569444444444
Sparsity of body_op.7.conv1.0.conv: 0.7768825954861112
Sparsity of body_op.7.conv2.0.conv: 0.8261176215277778
Sparsity of body_op.8.conv1.0.conv: 0.852783203125
Sparsity of body_op.8.conv2.0.conv: 

Test Accuracy of Global Iterative Pruning + Fine-tuning at 80% sparsity level: 88.41%

### Comparison of Pruning Strategies

| Feature                  | One-Shot Layer-wise Pruning      | Iterative Layer-wise Pruning      | Global Iterative Pruning                     |
|--------------------------|----------------------------------|-----------------------------------|-----------------------------------------------|
| When pruning happens     | Once before fine-tuning          | Gradually before each epoch       | Gradually before each epoch                   |
| Threshold basis          | Per-layer q-th percentile        | Per-layer q-th percentile         | Global q-th percentile across all layers      |
| Sparsity per layer       | Equal (80% each in this experiment) | Equal (8% × epoch number, up to 80%) | Unequal (depends on global threshold)     |
| Flexibility              | Low                              | Moderate                          | High                                          |
| Implementation           | Simple                           | Moderate                          | More complex but globally aware               |
| Use-case benefit         | Fast benchmark                   | Controlled gradual sparsity       | Smarter pruning, better overall performance   |


## Quantization

Besides pruning, fixed-point quantization is another important technique for compressing deep neural networks. In this part, we convert the ResNet-20 model used above into a quantized model, evaluate its performance, and apply finetuning to it.

#### Implement STE function (FP_layers.py)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np

device = "cuda" if torch.cuda.is_available() else "cpu"

class STE(torch.autograd.Function):
    @staticmethod
    def forward(ctx, w, bit, symmetric=False):
        '''
        symmetric: True for symmetric quantization, False for asymmetric quantization
        '''
        if bit is None:
            wq = w
        elif bit == 0:
            wq = w * 0
        else:
            # Build a mask to record position of zero weights
            weight_mask = (w != 0).float()

            if not symmetric:
                # Compute alpha (scale) for dynamic scaling
                w_min = w.min()
                w_max = w.max()
                alpha = w_max - w_min
                beta = w_min

                # Scale w with alpha and beta so that all elements in ws are between 0 and 1
                ws = (w - beta) / (alpha + 1e-8)

                step = 2 ** bit - 1
                # Quantize ws with a linear quantizer to "bit" bits
                R = torch.round(ws * step) / step

                # Scale the quantized weight R back with alpha and beta
                wq = R * alpha + beta

            else:
                # Symmetric quantization (not implemented here)
                wq = w  # Placeholder, for Lab 4

            # Restore zero elements in wq
            wq = wq * weight_mask

        return wq

    @staticmethod
    def backward(ctx, g):
        return g, None, None

class FP_Linear(nn.Module):
    def __init__(self, in_features, out_features, Nbits=None, symmetric=False):
        super(FP_Linear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.linear = nn.Linear(in_features, out_features)
        self.Nbits = Nbits
        self.symmetric = symmetric

        # Initialization
        m = self.in_features
        n = self.out_features
        self.linear.weight.data.normal_(0, math.sqrt(2. / (m + n)))

    def forward(self, x):
        return F.linear(x, STE.apply(self.linear.weight, self.Nbits, self.symmetric), self.linear.bias)

class FP_Conv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=False, Nbits=None, symmetric=False):
        super(FP_Conv, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias)
        self.Nbits = Nbits
        self.symmetric = symmetric

        # Initialization
        n = self.kernel_size * self.kernel_size * self.out_channels
        m = self.kernel_size * self.kernel_size * self.in_channels
        self.conv.weight.data.normal_(0, math.sqrt(2. / (n + m)))
        self.sparsity = 1.0

    def forward(self, x):
        return F.conv2d(x, STE.apply(self.conv.weight, self.Nbits, self.symmetric), self.conv.bias, self.conv.stride, self.conv.padding, self.conv.dilation, self.conv.groups)

#### Asymmetric Fixed-point quantization

In [8]:
bit_list = [6, 5, 4, 3, 2]
for Nbits in bit_list:
    print(f"\nTesting ResNet-20 with Nbits = {Nbits} (residual blocks only)")
    # Only residual blocks are quantized; first conv and final FC are still FP
    net = ResNetCIFAR(num_layers=20, Nbits=Nbits)
    net.load_state_dict(torch.load(os.path.join(CODE_PATH, "pretrained_model.pt")))
    net = net.to(device)
    test(net)


Testing ResNet-20 with Nbits = 6 (residual blocks only)
Test Loss=0.3365, Test accuracy=0.9144

Testing ResNet-20 with Nbits = 5 (residual blocks only)
Test Loss=0.3391, Test accuracy=0.9113

Testing ResNet-20 with Nbits = 4 (residual blocks only)
Test Loss=0.3862, Test accuracy=0.8973

Testing ResNet-20 with Nbits = 3 (residual blocks only)
Test Loss=0.9869, Test accuracy=0.7660

Testing ResNet-20 with Nbits = 2 (residual blocks only)
Test Loss=9.6141, Test accuracy=0.0899


In [10]:
finetune_bits = [5, 4, 3, 2]

for Nbits in finetune_bits:
    print(f"\n=== Finetuning Quantized Model (Nbits = {Nbits}) ===")

    # Create quantized model
    net = ResNetCIFAR(num_layers=20, Nbits=Nbits)
    net.load_state_dict(torch.load(os.path.join(CODE_PATH, "pretrained_model.pt")))
    net = net.to(device)

    # Finetune for 20 epochs
    finetune(net, epochs=20, batch_size=256, lr=0.002, reg=1e-4)

    # Load best model after finetune
    net.load_state_dict(torch.load("quantized_net_after_finetune.pt"))

    # save this to local path
    torch.save(net.state_dict(), os.path.join(CODE_PATH, f"quantized_net_after_finetune_Nbits_{Nbits}.pt"))

    print(f"Test accuracy after finetuning (Nbits = {Nbits}):")
    test(net)


=== Finetuning Quantized Model (Nbits = 5) ===
==> Preparing data..

Epoch: 0
[Step=50]	Loss=0.0488	acc=0.9841	6203.2 examples/second
[Step=100]	Loss=0.0511	acc=0.9830	8847.3 examples/second
[Step=150]	Loss=0.0516	acc=0.9829	9115.7 examples/second
Test Loss=0.3284, Test acc=0.9132
Saving...

Epoch: 1
[Step=200]	Loss=0.0453	acc=0.9863	3766.7 examples/second
[Step=250]	Loss=0.0503	acc=0.9837	8454.6 examples/second
[Step=300]	Loss=0.0521	acc=0.9824	9200.8 examples/second
[Step=350]	Loss=0.0517	acc=0.9825	8874.7 examples/second
Test Loss=0.3265, Test acc=0.9134
Saving...

Epoch: 2
[Step=400]	Loss=0.0562	acc=0.9834	3796.1 examples/second
[Step=450]	Loss=0.0488	acc=0.9841	9310.7 examples/second
[Step=500]	Loss=0.0506	acc=0.9831	8655.9 examples/second
[Step=550]	Loss=0.0516	acc=0.9832	8640.5 examples/second
Test Loss=0.3263, Test acc=0.9142
Saving...

Epoch: 3
[Step=600]	Loss=0.0533	acc=0.9837	3719.4 examples/second
[Step=650]	Loss=0.0501	acc=0.9839	9488.1 examples/second
[Step=700]	Loss=0.0

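As precision decreases, accuracy drops as well: without finetuning, test accuracy falls from 0.9144 at 6 bits to 0.8973 at 4 bits, 0.7660 at 3 bits, and 0.0899 (roughly chance level for CIFAR-10) at 2 bits. Finetuning helps recover test accuracy; at 5 bits, for example, it rises from 0.9113 to 0.9142 within the first three epochs shown above.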
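
The trend can be related to the resolution of the quantization grid. The following is a minimal, list-based sketch that mirrors the asymmetric branch of `STE.forward` in plain Python, assuming a small set of made-up weights (`toy_weights` is hypothetical and not taken from the ResNet-20 model). It shows that the worst-case rounding error is about half the spacing between levels, which roughly doubles each time one bit is removed.

```python
def quantize_asymmetric(weights, bits):
    """List-based mirror of the asymmetric branch of STE.forward (toy sketch)."""
    w_min, w_max = min(weights), max(weights)
    alpha = w_max - w_min            # dynamic range of the weights
    beta = w_min                     # offset so the scaled weights start at 0
    step = 2 ** bits - 1             # number of intervals between the 2**bits levels
    quantized = []
    for w in weights:
        ws = (w - beta) / (alpha + 1e-8)   # scale to [0, 1]
        r = round(ws * step) / step        # snap to the nearest level
        wq = r * alpha + beta              # scale back to the original range
        quantized.append(wq if w != 0 else 0.0)  # keep exact zeros, as the mask does
    return quantized


def max_rounding_error(weights, bits):
    """Largest absolute difference between a weight and its quantized value."""
    quantized = quantize_asymmetric(weights, bits)
    return max(abs(w - q) for w, q in zip(weights, quantized))


# Hypothetical toy weights, not taken from the ResNet-20 model.
toy_weights = [-0.50, -0.20, -0.05, 0.03, 0.10, 0.40, 0.70]

for bits in [6, 5, 4, 3, 2]:
    half_spacing = (max(toy_weights) - min(toy_weights)) / (2 * (2 ** bits - 1))
    print(f"Nbits={bits}: levels={2 ** bits}, "
          f"max error={max_rounding_error(toy_weights, bits):.4f}, "
          f"half spacing={half_spacing:.4f}")
```

At 2 bits only four levels remain for the whole weight range, which is consistent with the collapse to chance-level accuracy seen in the log, although the toy example does not by itself explain the network-level behavior.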

#### Quantize pruned model

In [11]:
# Finetuning on pruned & quantized model
finetune_bits = [5, 4, 3, 2]

for Nbits in finetune_bits:
    print(f"\n=== Pruned + Quantized Model (Nbits = {Nbits}) ===")

    # Load pruned model first
    net = ResNetCIFAR(num_layers=20, Nbits=Nbits)
    net.load_state_dict(torch.load(os.path.join(CODE_PATH, "net_after_global_iterative_prune.pt")))
    net = net.to(device)

    print("Accuracy BEFORE finetuning:")
    test(net)

    # Finetune the pruned + quantized model
    finetune(net, epochs=20, batch_size=256, lr=0.002, reg=1e-4)

    # Load the best finetuned version
    net.load_state_dict(torch.load("quantized_net_after_finetune.pt"))

    print("Accuracy AFTER finetuning:")
    test(net)

    # Save the combined model for documentation
    torch.save(net.state_dict(), os.path.join(CODE_PATH, f"pruned_quantized_net_Nbits_{Nbits}.pt"))


=== Pruned + Quantized Model (Nbits = 5) ===
Accuracy BEFORE finetuning:
Test Loss=0.3596, Test accuracy=0.8778
==> Preparing data..

Epoch: 0
[Step=50]	Loss=0.3572	acc=0.8783	6143.1 examples/second
[Step=100]	Loss=0.3099	acc=0.8933	9034.0 examples/second
[Step=150]	Loss=0.2847	acc=0.9007	8874.3 examples/second
Test Loss=0.3866, Test acc=0.8789
Saving...

Epoch: 1
[Step=200]	Loss=0.2337	acc=0.9170	3837.0 examples/second
[Step=250]	Loss=0.2124	acc=0.9244	9051.0 examples/second
[Step=300]	Loss=0.2112	acc=0.9240	8862.6 examples/second
[Step=350]	Loss=0.2094	acc=0.9251	8964.5 examples/second
Test Loss=0.3600, Test acc=0.8860
Saving...

Epoch: 2
[Step=400]	Loss=0.1908	acc=0.9355	3822.6 examples/second
[Step=450]	Loss=0.1923	acc=0.9337	8611.1 examples/second
[Step=500]	Loss=0.1916	acc=0.9329	8659.0 examples/second
[Step=550]	Loss=0.1944	acc=0.9322	9482.1 examples/second
Test Loss=0.3588, Test acc=0.8845

Epoch: 3
[Step=600]	Loss=0.1794	acc=0.9333	3873.2 examples/second
[Step=650]	Loss=0.187

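Finetuning improves the performance of the pruned + quantized model: at 5 bits, test accuracy rises from 0.8778 before finetuning to 0.8860 after the second epoch in the log above. Lower bit-widths degrade more when combined with pruning.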
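
The `weight_mask` step in `STE.forward` is what keeps the pruned model sparse after quantization. The sketch below is a plain-Python illustration under assumed toy values (the list `pruned` and the helper names are hypothetical): without the mask, a pruned weight at zero is snapped to the nearest quantization level, which is generally not zero, so the sparsity would be lost.

```python
def quantize(weights, bits, keep_zeros=True):
    """Asymmetric linear quantizer on a list; optionally restores exact zeros."""
    w_min, w_max = min(weights), max(weights)
    alpha, beta = w_max - w_min, w_min
    step = 2 ** bits - 1
    out = []
    for w in weights:
        r = round((w - beta) / (alpha + 1e-8) * step) / step
        wq = r * alpha + beta
        if keep_zeros and w == 0:
            wq = 0.0               # same effect as multiplying by weight_mask
        out.append(wq)
    return out


def sparsity(weights):
    """Fraction of weights that are exactly zero."""
    return sum(1 for w in weights if w == 0) / len(weights)


# Hypothetical pruned layer: 8 of 10 weights are zero (80% sparsity).
pruned = [0.0, 0.0, -0.31, 0.0, 0.52, 0.0, 0.0, 0.0, 0.0, 0.0]

print("sparsity before quantization:", sparsity(pruned))
print("with mask:   ", sparsity(quantize(pruned, bits=2, keep_zeros=True)))
print("without mask:", sparsity(quantize(pruned, bits=2, keep_zeros=False)))
```

Because the mask is recomputed from the current weights on every forward pass, this only preserves zeros that are still exactly zero; whether finetuning keeps them at zero depends on how `finetune` handles pruned weights, which is defined elsewhere in the lab.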

#### Symmetric Fixed-point quantization
Symmetric quantization is a commonly used and hardware-friendly quantization approach. In symmetric quantization, the quantization levels are symmetric about zero. Implement symmetric quantization in FP_layers.py.

##### Symmetric Quantization in FP_layers.py

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np

device = "cuda" if torch.cuda.is_available() else "cpu"

class STE(torch.autograd.Function):
    @staticmethod
    def forward(ctx, w, bit, symmetric=False):
        '''
        symmetric: True for symmetric quantization, False for asymmetric quantization
        '''
        if bit is None:
            wq = w
        elif bit == 0:
            wq = w * 0
        else:
            # Build a mask to record position of zero weights
            weight_mask = (w != 0).float()

            if symmetric == False:
                # Asymmetric quantization
                w_min = w.min()
                w_max = w.max()
                alpha = w_max - w_min
                beta = w_min
                ws = (w - beta) / (alpha + 1e-8)
                step = 2 ** bit - 1
                R = torch.round(ws * step) / step
                wq = R * alpha + beta

            else:
                # Symmetric quantization
                w_absmax = torch.max(torch.abs(w))
                alpha = 2 * w_absmax
                ws = (w + w_absmax) / (alpha + 1e-8)  # Scale to [0, 1]
                step = 2 ** bit - 1
                R = torch.round(ws * step) / step
                wq = R * alpha - w_absmax

            # Restore zero elements in wq
            wq = wq * weight_mask

        return wq

    @staticmethod
    def backward(ctx, g):
        return g, None, None

class FP_Linear(nn.Module):
    def __init__(self, in_features, out_features, Nbits=None, symmetric=False):
        super(FP_Linear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.linear = nn.Linear(in_features, out_features)
        self.Nbits = Nbits
        self.symmetric = symmetric
        m = self.in_features
        n = self.out_features
        self.linear.weight.data.normal_(0, math.sqrt(2. / (m + n)))

    def forward(self, x):
        return F.linear(x, STE.apply(self.linear.weight, self.Nbits, self.symmetric), self.linear.bias)

class FP_Conv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=False, Nbits=None, symmetric=False):
        super(FP_Conv, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias)
        self.Nbits = Nbits
        self.symmetric = symmetric
        n = self.kernel_size * self.kernel_size * self.out_channels
        m = self.kernel_size * self.kernel_size * self.in_channels
        self.conv.weight.data.normal_(0, math.sqrt(2. / (n + m)))
        self.sparsity = 1.0

    def forward(self, x):
        return F.conv2d(x, STE.apply(self.conv.weight, self.Nbits, self.symmetric), self.conv.bias, self.conv.stride, self.conv.padding, self.conv.dilation, self.conv.groups)

In [7]:
# check the performance of symmetric quantization with 6, 5, 4, 3, 2 bits
bit_list = [6, 5, 4, 3, 2]

for Nbits in bit_list:
    print(f"\nTesting Symmetric Quantized ResNet-20 with Nbits = {Nbits} (residual blocks only)")

    # Enable symmetric quantization in all residual blocks
    net = ResNetCIFAR(num_layers=20, Nbits=Nbits, symmetric=True)
    net.load_state_dict(torch.load(os.path.join(CODE_PATH, "pretrained_model.pt")))
    net = net.to(device)

    test(net)


Testing Symmetric Quantized ResNet-20 with Nbits = 6 (residual blocks only)
Test Loss=0.3275, Test accuracy=0.9134

Testing Symmetric Quantized ResNet-20 with Nbits = 5 (residual blocks only)
Test Loss=0.3419, Test accuracy=0.9071

Testing Symmetric Quantized ResNet-20 with Nbits = 4 (residual blocks only)
Test Loss=0.5688, Test accuracy=0.8532

Testing Symmetric Quantized ResNet-20 with Nbits = 3 (residual blocks only)
Test Loss=1.2517, Test accuracy=0.7151

Testing Symmetric Quantized ResNet-20 with Nbits = 2 (residual blocks only)
Test Loss=96.5832, Test accuracy=0.1000


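Symmetric quantization generally yields slightly lower accuracy than asymmetric quantization at the same bit-width (0.9134 vs. 0.9144 at 6 bits, 0.9071 vs. 0.9113 at 5 bits, 0.8532 vs. 0.8973 at 4 bits, and 0.7151 vs. 0.7660 at 3 bits). The exception is Nbits=2, where both models collapse to roughly chance level and symmetric quantization is nominally higher (0.1000 vs. 0.0899).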
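
The two schemes place their quantization levels differently, which may contribute to the accuracy differences in the logs above. The sketch below compares the two grids as defined by the two branches of `STE.forward`, using hypothetical skewed weights (`skewed` and the helper names are assumptions for illustration). The asymmetric grid spans only `[w_min, w_max]`, so its levels are never coarser than those of the symmetric grid, which spans `[-absmax, +absmax]`.

```python
def grid_step(weights, bits, symmetric):
    """Spacing between adjacent levels, following the two STE branches."""
    if symmetric:
        span = 2 * max(abs(w) for w in weights)   # grid covers [-absmax, +absmax]
    else:
        span = max(weights) - min(weights)        # grid covers [w_min, w_max]
    return span / (2 ** bits - 1)


def grid_levels(weights, bits, symmetric):
    """All 2**bits representable values for the given weights and scheme."""
    low = -max(abs(w) for w in weights) if symmetric else min(weights)
    step = grid_step(weights, bits, symmetric)
    return [low + k * step for k in range(2 ** bits)]


# Hypothetical skewed weights: the positive tail is much longer than the negative one.
skewed = [-0.10, -0.02, 0.05, 0.20, 0.45, 0.70]

for bits in [6, 5, 4, 3, 2]:
    print(f"Nbits={bits}: "
          f"asymmetric step={grid_step(skewed, bits, False):.4f}, "
          f"symmetric step={grid_step(skewed, bits, True):.4f}")
```

Note also that, in this particular formulation, the symmetric grid has an even number of levels (`2 ** bit`) spread evenly over `[-absmax, +absmax]`, so zero itself is not one of the levels; exact zeros are only restored by the mask. Whether this accounts for the measured gap on ResNet-20 is not verified here.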