# Sparsifying DenseNet121 from Scratch (Flower102)

In this example, we demonstrate how to sparsify an image classification model from scratch using SparseML's PyTorch integration. We train and prune [DenseNet121](https://pytorch.org/vision/main/models/generated/torchvision.models.densenet121.html) on the [Oxford Flower 102 dataset](https://pytorch.org/vision/main/generated/torchvision.datasets.Flowers102.html) using the Global Magnitude Pruning algorithm, then quantize it with quantization-aware training (QAT).

## Agenda

There are a few steps:

 1. Set up the dataset
 2. Set up the PyTorch training loop
 3. Train a dense version of DenseNet121
 4. Run the GMP pruning algorithm and QAT quantization algorithm on the dense model
 
## Installation

Install SparseML with `pip`:

```
pip install sparseml[torchvision]
```

In [1]:
import torch
import sparseml
import torchvision
from sparseml.pytorch.optim import ScheduledModifierManager
from sparseml.pytorch.utils import TensorBoardLogger, ModuleExporter, get_prunable_layers, tensor_sparsity

from torch.utils.data import DataLoader
from torch.nn import CrossEntropyLoss
from torch.optim import Adam

from torchvision import transforms
from tqdm.auto import tqdm
import math

In [2]:
print(torch.__version__)

1.12.1+cu116


## Step 1: Set Up Dataset

Oxford 102 Flower is an image classification dataset consisting of 102 flower categories. The flowers were chosen to be flowers commonly occurring in the United Kingdom. Each class consists of between 40 and 258 images. The images have large scale, pose and light variations. In addition, there are categories that have large variations within the category, and several very similar categories.

We use the standard PyTorch `datasets` and `dataloaders` to manage the dataset.

In [None]:
NUM_LABELS = 102
BATCH_SIZE = 16

# imagenet transforms
imagenet_transform = transforms.Compose([
   transforms.Resize(size=256, interpolation=transforms.InterpolationMode.BILINEAR, max_size=None, antialias=None),
   transforms.CenterCrop(size=(224, 224)),
   transforms.ToTensor(),
   transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# datasets
train_dataset = torchvision.datasets.Flowers102(
    root="./data",
    split="train",
    transform=imagenet_transform,
    download=True
)

val_dataset = torchvision.datasets.Flowers102(
    root="./data",
    split="val",
    transform=imagenet_transform,
    download=True
)

# dataloaders
train_loader = DataLoader(train_dataset, BATCH_SIZE, shuffle=True, pin_memory=True, num_workers=16)
val_loader = DataLoader(val_dataset, BATCH_SIZE, shuffle=False, pin_memory=True, num_workers=16)

Downloading https://thor.robots.ox.ac.uk/datasets/flowers-102/102flowers.tgz to data/flowers-102/102flowers.tgz


## Step 2: Set Up PyTorch Training Loop

We will use the training loop below, which is standard PyTorch. The same function handles training and evaluation, depending on the `train` flag.

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

def run_model_one_epoch(model, data_loader, criterion, device, train=False, optimizer=None):
    if train:
        model.train()
    else:
        model.eval()

    running_loss = 0.0
    total_correct = 0
    total_predictions = 0

    # loop through batches
    for step, (inputs, labels) in tqdm(enumerate(data_loader), total=len(data_loader)):
        inputs = inputs.to(device)
        labels = labels.to(device)

        if train:
            optimizer.zero_grad()

        # forward pass and loss (gradient tracking is disabled during evaluation)
        with torch.set_grad_enabled(train):
            outputs = model(inputs)  # model returns logits
            loss = criterion(outputs, labels)

        # run backpropagation and update the weights
        if train:
            loss.backward()
            optimizer.step()

        running_loss += loss.item()

        # count correct top-1 predictions
        predictions = outputs.argmax(dim=1)
        total_correct += torch.sum(predictions == labels).item()
        total_predictions += inputs.size(0)

    # return the mean per-batch loss and the top-1 accuracy
    loss = running_loss / (step + 1.0)
    accuracy = total_correct / total_predictions
    return loss, accuracy

## Step 3: Train DenseNet121 on Flowers102

First, we will train a dense version of DenseNet121 on the Flowers dataset.

In [None]:
# download pre-trained model, setup classification head
model = torchvision.models.densenet121(weights=torchvision.models.DenseNet121_Weights.DEFAULT)
model.classifier = torch.nn.Linear(model.classifier.in_features, NUM_LABELS)
model.to(device)

# setup loss function and optimizer
criterion = CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=8e-3)  # lr will be overridden by the SparseML recipe

Next, we will use a SparseML recipe to set the hyperparameters of the training loop. In this case, we will use the following recipe:

```yaml
# Epoch and Learning-Rate variables
num_epochs: 15.0
init_lr: 0.001

training_modifiers:
  - !EpochRangeModifier
    start_epoch: 0.0
    end_epoch: eval(num_epochs)

  - !LearningRateFunctionModifier
    final_lr: 0.0
    init_lr: eval(init_lr)
    lr_func: cosine
    start_epoch: 0.0
    end_epoch: eval(num_epochs)
```

As you can see, the recipe includes an `!EpochRangeModifier` and a `!LearningRateFunctionModifier`. These modifiers set the number of training epochs (15) and a cosine learning rate schedule that decays from 0.001 to 0. Because the recipe contains no pruning or quantization modifiers, the final model will be dense.
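
To make the learning rate schedule concrete, here is a minimal sketch of a half-cosine decay using the values from the recipe above. It assumes `lr_func: cosine` follows the usual half-cosine shape. The helper `cosine_lr` is hypothetical and is not part of SparseML, whose exact interpolation may differ in detail.

```python
import math


def cosine_lr(epoch, init_lr=0.001, final_lr=0.0, start_epoch=0.0, end_epoch=15.0):
    """Half-cosine decay from init_lr at start_epoch to final_lr at end_epoch.

    Hypothetical helper for illustration only; not a SparseML function.
    """
    # fraction of the schedule completed, clamped to [0, 1]
    progress = (epoch - start_epoch) / (end_epoch - start_epoch)
    progress = min(max(progress, 0.0), 1.0)
    # cos goes from 1 to -1 over [0, pi], so the weight on init_lr goes from 1 to 0
    weight = 0.5 * (1.0 + math.cos(math.pi * progress))
    return final_lr + (init_lr - final_lr) * weight


for epoch in (0, 5, 7.5, 10, 15):
    print(f"epoch {epoch:>4}: lr = {cosine_lr(epoch):.6f}")
```

Under this assumption, the learning rate starts at 0.001, passes through 0.0005 at the midpoint (epoch 7.5), and reaches 0 at epoch 15.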

In [None]:
dense_recipe_path = "./recipe.dense.yaml"

In [None]:
!cat ./recipe.dense.yaml

Next, we use SparseML's `ScheduledModifierManager` to parse and apply the recipe. The `manager.modify` function modifies and wraps the `model` and `optimizer` with the instructions from the recipe. You can use the `model` and `optimizer` just like standard PyTorch objects.

In [None]:
# create ScheduledModifierManager and Optimizer wrapper
manager = ScheduledModifierManager.from_yaml(dense_recipe_path)
optimizer = manager.modify(model, optimizer, steps_per_epoch=len(train_loader))
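
The `steps_per_epoch` argument tells the manager how many batches make up one epoch, which it needs to map the epoch values in a recipe onto optimizer steps. As a rough worked example, assuming the torchvision `train` split of Flowers102 contains 1,020 images (10 per class), the arithmetic for this notebook looks like this:

```python
import math

TRAIN_IMAGES = 1020  # assumption: size of the torchvision Flowers102 "train" split
BATCH_SIZE = 16      # from the DataLoader setup in Step 1
NUM_EPOCHS = 15      # num_epochs in the dense recipe

# DataLoader keeps the last partial batch by default (drop_last=False), hence ceil
steps_per_epoch = math.ceil(TRAIN_IMAGES / BATCH_SIZE)
total_steps = steps_per_epoch * NUM_EPOCHS

print(f"steps per epoch: {steps_per_epoch}")
print(f"total optimizer steps: {total_steps}")
```

In the notebook itself, `len(train_loader)` provides this value directly, so nothing needs to be hard-coded.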

Kick off the transfer learning loop. Our run reached ~91% validation accuracy after 15 epochs.

In [None]:
epoch = 0
for epoch in range(manager.max_epochs):
    # run training loop
    epoch_name = f"{epoch + 1}/{manager.max_epochs}"
    
    print(f"Running Training Epoch {epoch_name}")
    train_loss, train_acc = run_model_one_epoch(model, train_loader, criterion, device, train=True, optimizer=optimizer)
    print(f"Training Epoch: {epoch_name}\nTraining Loss: {train_loss}\nTop 1 Acc: {train_acc}\n")

    # run validation loop
    print(f"Running Validation Epoch {epoch_name}")
    val_loss, val_acc = run_model_one_epoch(model, val_loader, criterion, device)
    print(f"Validation Epoch: {epoch_name}\nVal Loss: {val_loss}\nTop 1 Acc: {val_acc}\n")

# clean up
manager.finalize(model)

Export the trained model so that we can reload it later without rerunning training.

In [10]:
save_dir = "densenet-models"
exporter = ModuleExporter(model, output_dir=save_dir)
exporter.export_pytorch(name="dense-model.pth")

In [11]:
torch.cuda.empty_cache()

## Step 4: Prune The Model

With a model trained on Flowers102, we are now ready to apply the GMP algorithm to prune the model. GMP is an iterative pruning algorithm. At regular intervals during training (every half epoch in the recipe below), we identify the lowest-magnitude weights (those closest to 0) and remove them from the network, starting from an initial level of sparsity and increasing until a final level of sparsity is reached. The remaining nonzero weights are then fine-tuned on the training dataset.

After we prune the model, we will apply QAT to convert the weights from FP32 to INT8.

In [12]:
# first, load the trained model from Step 3
checkpoint = torch.load("./densenet-models/training/dense-model.pth")
model = torchvision.models.densenet121()
model.classifier = torch.nn.Linear(model.classifier.in_features, NUM_LABELS)
model.load_state_dict(checkpoint['state_dict'])
model.to(device)

# set up loss function and optimizer; the LR will be overridden by SparseML
criterion = CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=8e-3)

Next, we need to create a SparseML recipe which includes the GMP algorithm. The `!GlobalMagnitudePruningModifier` modifier instructs SparseML to apply the GMP algorithm at a global level (pruning the lowest magnitude weights across all layers).
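
To illustrate what pruning "at a global level" means, here is a minimal pure-Python sketch of a single pruning step. The helper names (`global_magnitude_prune`, `sparsity_of`) and the toy layer values are hypothetical, and this is not SparseML's implementation. It only shows the idea of ranking all weights together and applying one shared magnitude threshold.

```python
def global_magnitude_prune(layers, sparsity):
    """Zero the `sparsity` fraction of weights with the smallest |w| across ALL layers.

    Hypothetical helper for illustration; `layers` maps layer name -> list of floats.
    """
    all_magnitudes = sorted(abs(w) for weights in layers.values() for w in weights)
    num_to_prune = int(sparsity * len(all_magnitudes))
    if num_to_prune == 0:
        return {name: list(weights) for name, weights in layers.items()}
    # one threshold shared by every layer: this is what makes the pruning "global"
    # (ties at the threshold are all pruned, which can slightly exceed the target)
    threshold = all_magnitudes[num_to_prune - 1]
    return {
        name: [0.0 if abs(w) <= threshold else w for w in weights]
        for name, weights in layers.items()
    }


def sparsity_of(weights):
    """Fraction of entries that are exactly zero."""
    return sum(1 for w in weights if w == 0.0) / len(weights)


toy_layers = {
    "conv_a": [0.9, -0.8, 0.7, -0.1],    # mostly large weights
    "conv_b": [0.05, -0.02, 0.65, 0.01], # mostly small weights
}
pruned = global_magnitude_prune(toy_layers, sparsity=0.5)
for name, weights in pruned.items():
    print(f"{name}: {weights} sparsity={sparsity_of(weights):.2f}")
```

In this toy case the overall sparsity is 50%, but `conv_a` ends up 25% sparse and `conv_b` 75% sparse. A global threshold removes more weights from layers whose weights are small, which is why per-layer sparsity varies in the final model.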

First, we need to identify which parameters of the model to apply the GMP algorithm to. We can use the `get_prunable_layers` function to inspect them:

In [13]:
# print parameters
for (name, layer) in get_prunable_layers(model):
    print(f"{name}")

features.conv0
features.denseblock1.denselayer1.conv1
features.denseblock1.denselayer1.conv2
features.denseblock1.denselayer2.conv1
features.denseblock1.denselayer2.conv2
features.denseblock1.denselayer3.conv1
features.denseblock1.denselayer3.conv2
features.denseblock1.denselayer4.conv1
features.denseblock1.denselayer4.conv2
features.denseblock1.denselayer5.conv1
features.denseblock1.denselayer5.conv2
features.denseblock1.denselayer6.conv1
features.denseblock1.denselayer6.conv2
features.transition1.conv
features.denseblock2.denselayer1.conv1
features.denseblock2.denselayer1.conv2
features.denseblock2.denselayer2.conv1
features.denseblock2.denselayer2.conv2
features.denseblock2.denselayer3.conv1
features.denseblock2.denselayer3.conv2
features.denseblock2.denselayer4.conv1
features.denseblock2.denselayer4.conv2
features.denseblock2.denselayer5.conv1
features.denseblock2.denselayer5.conv2
features.denseblock2.denselayer6.conv1
features.denseblock2.denselayer6.conv2
features.denseblock2.de

We will apply GMP to all layers with `__ALL_PRUNABLE__`. Here is what the recipe looks like:

```yaml
# Epoch hyperparams
stabilization_epochs: 1.0
pruning_epochs: 9.0
finetuning_epochs: 5.0
quantization_epochs: 3.0

# Learning rate hyperparams
init_lr: 0.0001
final_lr: 0.00005

# Pruning hyperparams
init_sparsity: 0.05
final_sparsity: 0.9

# Stabilization Stage
training_modifiers:
  - !EpochRangeModifier
    start_epoch: 0.0
    end_epoch: eval(stabilization_epochs + pruning_epochs + finetuning_epochs + quantization_epochs)

  - !SetLearningRateModifier
    start_epoch: 0.0
    learning_rate: eval(init_lr)

# Pruning Stage
pruning_modifiers:
  - !LearningRateFunctionModifier
    init_lr: eval(init_lr)
    final_lr: eval(final_lr)
    lr_func: cosine
    start_epoch: eval(stabilization_epochs)
    end_epoch: eval(stabilization_epochs + pruning_epochs)

  - !GlobalMagnitudePruningModifier
    init_sparsity: eval(init_sparsity)
    final_sparsity: eval(final_sparsity)
    start_epoch: eval(stabilization_epochs)
    end_epoch: eval(stabilization_epochs + pruning_epochs)
    update_frequency: 0.5
    params: __ALL_PRUNABLE__
    leave_enabled: True

# Finetuning Stage
finetuning_modifiers:
  - !LearningRateFunctionModifier
    init_lr: eval(init_lr)
    final_lr: eval(final_lr)
    lr_func: cosine
    start_epoch: eval(stabilization_epochs + pruning_epochs)
    end_epoch: eval(stabilization_epochs + pruning_epochs + finetuning_epochs)

# Quantization Stage
quantization_modifiers:
  - !QuantizationModifier
    start_epoch: eval(stabilization_epochs + pruning_epochs + finetuning_epochs)
```

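This recipe keeps the model dense for one stabilization epoch and then runs the GMP algorithm for 9 epochs (from epoch 1 to epoch 10). We start at an `init_sparsity` level of 5% and gradually increase sparsity to a `final_sparsity` level of 90% following a `cubic` curve. Because the pruning is global, the 90% target applies to the prunable weights as a whole, so individual layers can end up more or less sparse than 90%.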
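
The sparsity ramp can be sketched as follows, assuming the cubic curve commonly used for gradual magnitude pruning. The helper `cubic_sparsity` is hypothetical and SparseML's interpolation may differ in detail. The defaults mirror the recipe: pruning starts after the 1 stabilization epoch and ends at epoch 10.

```python
def cubic_sparsity(epoch, init_sparsity=0.05, final_sparsity=0.9,
                   start_epoch=1.0, end_epoch=10.0):
    """Target sparsity at a given (fractional) epoch under a cubic ramp.

    Hypothetical helper for illustration only; not a SparseML function.
    """
    # fraction of the pruning window completed, clamped to [0, 1]
    progress = (epoch - start_epoch) / (end_epoch - start_epoch)
    progress = min(max(progress, 0.0), 1.0)
    # rises quickly at first, then flattens as it approaches final_sparsity
    return final_sparsity + (init_sparsity - final_sparsity) * (1.0 - progress) ** 3


# sample the curve every 1.5 epochs across the pruning window
for half_step in range(0, 19, 3):
    epoch = 1.0 + 0.5 * half_step
    print(f"epoch {epoch:4.1f}: target sparsity = {cubic_sparsity(epoch):.3f}")
```

Under this curve, the target is already about 79% halfway through the pruning window (epoch 5.5). Most weights are removed early, which leaves the network more training time to recover before the final sparsity is reached.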

Over the next 5 epochs, we fine-tune the 90% pruned model further. Since we set `leave_enabled: True`, the sparsity level is maintained while fine-tuning occurs.
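
Conceptually, maintaining sparsity during fine-tuning means the pruning mask is re-applied after every weight update, so pruned weights cannot grow back. The sketch below shows this with a plain SGD step on a toy weight vector. The helper `masked_sgd_step` is hypothetical, and the notebook itself uses Adam with SparseML handling the masks internally.

```python
def masked_sgd_step(weights, grads, mask, lr):
    """One plain SGD update followed by re-applying the pruning mask.

    Hypothetical helper for illustration; a mask value of 0 marks a pruned weight.
    """
    updated = [w - lr * g for w, g in zip(weights, grads)]
    # pruned positions are forced back to zero even if their gradient is nonzero
    return [w if m else 0.0 for w, m in zip(updated, mask)]


weights = [0.5, 0.0, -0.3, 0.0]
mask = [1, 0, 1, 0]             # positions 1 and 3 were pruned
grads = [0.1, 0.2, -0.1, 0.4]   # pruned positions still receive gradients

new_weights = masked_sgd_step(weights, grads, mask, lr=0.1)
print(new_weights)
```

Only the unmasked weights change, so the fraction of zeros stays the same after the update.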

Over the final 3 epochs, we apply QAT to quantize the model.
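
QAT simulates INT8 rounding during training so that the weights can adapt to it. The sketch below shows the underlying quantize/dequantize round trip for a handful of FP32 values, assuming a symmetric per-tensor scheme. The helper `fake_quantize` is hypothetical, and the observers and quantization schemes that SparseML and PyTorch actually apply are not shown here.

```python
def fake_quantize(values, num_bits=8):
    """Quantize floats to signed integers and map them back to floats.

    Hypothetical helper for illustration: symmetric, per-tensor quantization.
    """
    qmax = 2 ** (num_bits - 1) - 1  # 127 for INT8
    max_abs = max(abs(v) for v in values)
    scale = max_abs / qmax if max_abs > 0 else 1.0
    # round to the nearest integer level and clamp to the representable range
    quantized = [max(-qmax, min(qmax, round(v / scale))) for v in values]
    dequantized = [q * scale for q in quantized]
    return quantized, dequantized, scale


weights = [0.50, -0.25, 0.10, 0.0, -1.27]
q, dq, scale = fake_quantize(weights)

print("scale:", scale)
print("int8 values:", q)
print("max round-trip error:", max(abs(w - d) for w, d in zip(weights, dq)))
```

Each value moves by at most half a quantization step (`scale / 2`). Under this symmetric scheme an exact 0.0 maps back to 0.0, so pruned weights stay zero.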

In [14]:
pruning_recipe_path = "./recipe.prune_quant.yaml"

In [15]:
!cat ./recipe.prune_quant.yaml

In [16]:
# create ScheduledModifierManager and Optimizer wrapper
manager = ScheduledModifierManager.from_yaml(pruning_recipe_path)
logger = TensorBoardLogger(log_path="./tensorboard_outputs/densenet/pruning-run")
optimizer = manager.modify(model, optimizer, loggers=[logger], steps_per_epoch=len(train_loader))

Next, kick off the GMP training loop.

As you can see, we use the wrapped `optimizer` and `model` in the same way as above. SparseML parsed the recipe and updated the `optimizer` with the logic of the GMP algorithm. This allows you to use the `optimizer` and `model` as usual, with all of the pruning-related logic handled by SparseML.

Our 90% pruned model reaches ~91% validation accuracy (vs ~91% for the dense model).

In [17]:
# run GMP algorithm
epoch = 0
for epoch in range(manager.max_epochs):
    # run training loop
    epoch_name = f"{epoch + 1}/{manager.max_epochs}"
    
    print(f"Running Training Epoch {epoch_name}")
    train_loss, train_acc = run_model_one_epoch(model, train_loader, criterion, device, train=True, optimizer=optimizer)
    print(f"Training Epoch: {epoch_name}\nTraining Loss: {train_loss}\nTop 1 Acc: {train_acc}\n")

    # run validation loop
    print(f"Running Validation Epoch {epoch_name}")
    val_loss, val_acc = run_model_one_epoch(model, val_loader, criterion, device)
    print(f"Validation Epoch: {epoch_name}\nVal Loss: {val_loss}\nTop 1 Acc: {val_acc}\n")
    
    logger.log_scalar("Metrics/Loss (Train)", train_loss, epoch)
    logger.log_scalar("Metrics/Accuracy (Train)", train_acc, epoch)
    logger.log_scalar("Metrics/Loss (Validation)", val_loss, epoch)
    logger.log_scalar("Metrics/Accuracy (Validation)", val_acc, epoch)

manager.finalize(model)

Running Training Epoch 1/18


Training Epoch: 1/18
Training Loss: 0.01275764120509848
Top 1 Acc: 1.0

Running Validation Epoch 1/18


Validation Epoch: 1/18
Val Loss: 0.3881779580115108
Top 1 Acc: 0.9009803921568628

Running Training Epoch 2/18


Training Epoch: 2/18
Training Loss: 0.006604243415495148
Top 1 Acc: 1.0

Running Validation Epoch 2/18


Validation Epoch: 2/18
Val Loss: 0.36990982507995795
Top 1 Acc: 0.9019607843137255

Running Training Epoch 3/18


Training Epoch: 3/18
Training Loss: 0.007901180022599874
Top 1 Acc: 0.9990196078431373

Running Validation Epoch 3/18


Validation Epoch: 3/18
Val Loss: 0.36820574732155364
Top 1 Acc: 0.9029411764705882

Running Training Epoch 4/18


Training Epoch: 4/18
Training Loss: 0.00951768000231823
Top 1 Acc: 1.0

Running Validation Epoch 4/18


Validation Epoch: 4/18
Val Loss: 0.35841650055590435
Top 1 Acc: 0.9137254901960784

Running Training Epoch 5/18


Training Epoch: 5/18
Training Loss: 0.017474219625000842
Top 1 Acc: 1.0

Running Validation Epoch 5/18


Validation Epoch: 5/18
Val Loss: 0.39351300171983894
Top 1 Acc: 0.907843137254902

Running Training Epoch 6/18


Training Epoch: 6/18
Training Loss: 0.043898234391235746
Top 1 Acc: 1.0

Running Validation Epoch 6/18


Validation Epoch: 6/18
Val Loss: 0.41797342558857054
Top 1 Acc: 0.9058823529411765

Running Training Epoch 7/18


Training Epoch: 7/18
Training Loss: 0.09149695426458493
Top 1 Acc: 0.9990196078431373

Running Validation Epoch 7/18


Validation Epoch: 7/18
Val Loss: 0.4712565544323297
Top 1 Acc: 0.8931372549019608

Running Training Epoch 8/18


Training Epoch: 8/18
Training Loss: 0.1224245154298842
Top 1 Acc: 1.0

Running Validation Epoch 8/18


Validation Epoch: 8/18
Val Loss: 0.5245802142017055
Top 1 Acc: 0.8921568627450981

Running Training Epoch 9/18


Training Epoch: 9/18
Training Loss: 0.12702065438497812
Top 1 Acc: 0.9980392156862745

Running Validation Epoch 9/18


Validation Epoch: 9/18
Val Loss: 0.5165597766754217
Top 1 Acc: 0.8980392156862745

Running Training Epoch 10/18


Training Epoch: 10/18
Training Loss: 0.08719327434664592
Top 1 Acc: 0.9990196078431373

Running Validation Epoch 10/18


Validation Epoch: 10/18
Val Loss: 0.48045873222872615
Top 1 Acc: 0.907843137254902

Running Training Epoch 11/18


Training Epoch: 11/18
Training Loss: 0.06748799287015572
Top 1 Acc: 1.0

Running Validation Epoch 11/18


Validation Epoch: 11/18
Val Loss: 0.4438641173474025
Top 1 Acc: 0.9137254901960784

Running Training Epoch 12/18


Training Epoch: 12/18
Training Loss: 0.050291253632167354
Top 1 Acc: 1.0

Running Validation Epoch 12/18


Validation Epoch: 12/18
Val Loss: 0.43283609559875913
Top 1 Acc: 0.9107843137254902

Running Training Epoch 13/18


Training Epoch: 13/18
Training Loss: 0.0417996046890039
Top 1 Acc: 1.0

Running Validation Epoch 13/18


Validation Epoch: 13/18
Val Loss: 0.42522090824786574
Top 1 Acc: 0.9088235294117647

Running Training Epoch 14/18


Training Epoch: 14/18
Training Loss: 0.03894200167269446
Top 1 Acc: 1.0

Running Validation Epoch 14/18


Validation Epoch: 14/18
Val Loss: 0.42405990276893135
Top 1 Acc: 0.9098039215686274

Running Training Epoch 15/18


Training Epoch: 15/18
Training Loss: 0.03521087218541652
Top 1 Acc: 1.0

Running Validation Epoch 15/18


Validation Epoch: 15/18
Val Loss: 0.42383157154836226
Top 1 Acc: 0.9117647058823529

Running Training Epoch 16/18


Training Epoch: 16/18
Training Loss: 0.03761145524913445
Top 1 Acc: 1.0

Running Validation Epoch 16/18


Validation Epoch: 16/18
Val Loss: 0.41609785216860473
Top 1 Acc: 0.9127450980392157

Running Training Epoch 17/18


Training Epoch: 17/18
Training Loss: 0.0321939253481105
Top 1 Acc: 1.0

Running Validation Epoch 17/18


Validation Epoch: 17/18
Val Loss: 0.42837292490003165
Top 1 Acc: 0.9058823529411765

Running Training Epoch 18/18


Training Epoch: 18/18
Training Loss: 0.03206482196401339
Top 1 Acc: 1.0

Running Validation Epoch 18/18


Validation Epoch: 18/18
Val Loss: 0.4149022806814173
Top 1 Acc: 0.9107843137254902



The resulting model is 90% sparse and quantized, while achieving a validation accuracy of ~91% (on par with the unoptimized dense model at ~91%) without much hyperparameter search. Key hyperparameter experiments you may want to run include:
- Learning rate
- Learning rate schedule
- Sparsity level
- Number of pruning epochs

In [18]:
print("Sparsity By Layer:")
for (name, layer) in get_prunable_layers(model):
    print(f"{name}.weight: {tensor_sparsity(layer.weight).item():.4f}")

Sparsity By Layer:
features.conv0.module.weight: 0.4833
features.denseblock1.denselayer1.conv1.module.weight: 0.7510
features.denseblock1.denselayer1.conv2.module.weight: 0.8344
features.denseblock1.denselayer2.conv1.module.weight: 0.7566
features.denseblock1.denselayer2.conv2.module.weight: 0.8695
features.denseblock1.denselayer3.conv1.module.weight: 0.7657
features.denseblock1.denselayer3.conv2.module.weight: 0.8344
features.denseblock1.denselayer4.conv1.module.weight: 0.8263
features.denseblock1.denselayer4.conv2.module.weight: 0.8338
features.denseblock1.denselayer5.conv1.module.weight: 0.8754
features.denseblock1.denselayer5.conv2.module.weight: 0.8872
features.denseblock1.denselayer6.conv1.module.weight: 0.8493
features.denseblock1.denselayer6.conv2.module.weight: 0.8465
features.transition1.conv.module.weight: 0.7229
features.denseblock2.denselayer1.conv1.module.weight: 0.9299
features.denseblock2.denselayer1.conv2.module.weight: 0.8891
features.denseblock2.denselayer2.conv1.mod

In [19]:
save_dir = "densenet-models"
exporter = ModuleExporter(model, output_dir=save_dir)
exporter.export_pytorch(name="densenet-pruned-int8.pth")