In [1]:
import torch, torchvision
import sparseml
from sparseml.pytorch.models import ModelRegistry
from sparseml.pytorch.optim import ScheduledModifierManager
from sparseml.pytorch.utils import TensorBoardLogger, ModuleExporter, get_prunable_layers, tensor_sparsity

from torch.utils.data import DataLoader
from torch.nn import CrossEntropyLoss
from torch.optim import Adam, RMSprop

from torchvision import transforms
from tqdm.auto import tqdm
import math

In [2]:
# Record the library versions this notebook was executed with.
for library in (torch, sparseml):
    print(library.__version__)

1.12.1+cu116
1.4.4


## **Step 1: Setup Dataset**

Oxford 102 Flower is an image classification dataset consisting of 102 flower categories. The flowers were chosen to be flowers commonly occurring in the United Kingdom. Each class consists of between 40 and 258 images. The images have large scale, pose and light variations. In addition, there are categories that have large variations within the category, and several very similar categories.

We use the standard PyTorch `datasets` and `dataloaders` to manage the dataset.

In [3]:
# --- dataset / preprocessing configuration ---
NUM_LABELS = 102
BATCH_SIZE = 32
CROP_SIZE = 240      # side length of the crop fed to the model
RESIZE_SIZE = 256    # validation images are resized to this before center-cropping
INTERPOLATION = transforms.InterpolationMode.BICUBIC
IN_MEAN = [0.485, 0.456, 0.406]
IN_STD = [0.229, 0.224, 0.225]


def _to_normalized_tensor_steps():
    """Return fresh PIL -> float tensor -> normalized steps shared by both pipelines."""
    return [
        transforms.PILToTensor(),
        transforms.ConvertImageDtype(torch.float),
        transforms.Normalize(mean=IN_MEAN, std=IN_STD),
    ]


# training pipeline: random crop + augmentation, tensor conversion, random erasing
train_transforms = transforms.Compose(
    [
        transforms.RandomResizedCrop(CROP_SIZE, interpolation=INTERPOLATION),
        transforms.autoaugment.TrivialAugmentWide(interpolation=INTERPOLATION),
        *_to_normalized_tensor_steps(),
        transforms.RandomErasing(p=.1),
    ]
)

# validation pipeline: deterministic resize + center crop, tensor conversion
val_transforms = transforms.Compose(
    [
        transforms.Resize(RESIZE_SIZE, interpolation=INTERPOLATION),
        transforms.CenterCrop(CROP_SIZE),
        *_to_normalized_tensor_steps(),
    ]
)

# datasets (downloaded into ./data on first use); built in train, val order
train_dataset, val_dataset = (
    torchvision.datasets.Flowers102(
        root="./data",
        split=split_name,
        transform=split_transforms,
        download=True,
    )
    for split_name, split_transforms in (
        ("train", train_transforms),
        ("val", val_transforms),
    )
)

# dataloaders: identical settings, only the training loader shuffles
loader_kwargs = dict(batch_size=BATCH_SIZE, pin_memory=True, num_workers=16)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

## Step 2: Setup PyTorch Training Loop

We will use this training loop below. This is standard PyTorch functionality.

In [4]:
# Select the compute device.
# "cuda:1" (the second GPU) is only a valid device when at least two GPUs are
# visible; on a single-GPU machine fall back to "cuda:0" instead of failing
# later when tensors are moved to a non-existent device.
if torch.cuda.is_available():
    device = "cuda:1" if torch.cuda.device_count() > 1 else "cuda:0"
else:
    device = "cpu"
print(device)

def run_model_one_epoch(model, data_loader, criterion, device, train=False, optimizer=None):
    """Run one pass over ``data_loader`` and return ``(mean_batch_loss, accuracy)``.

    Args:
        model: module whose forward returns a 2-tuple; the first element is
            used as the logits.
        data_loader: sized iterable yielding ``(inputs, labels)`` batches.
        criterion: callable computing the loss from ``(logits, labels)``.
        device: device the inputs and labels are moved to.
        train: when True the model is put in train mode and the optimizer is
            stepped on every batch; otherwise the model is put in eval mode
            and the forward pass runs with gradient tracking disabled.
        optimizer: optimizer to step; required when ``train`` is True.

    Returns:
        Tuple of the loss averaged over batches and the top-1 accuracy over
        all samples seen.

    Raises:
        ValueError: if ``train`` is True without an optimizer, or if
            ``data_loader`` yields no batches.
    """
    # fail before touching the model so its mode is not changed by a bad call
    if train and optimizer is None:
        raise ValueError("an optimizer is required when train=True")

    if train:
        model.train()
    else:
        model.eval()

    running_loss = 0.0
    total_correct = 0
    total_predictions = 0
    num_batches = 0

    # loop through batches
    for inputs, labels in tqdm(data_loader, total=len(data_loader)):
        inputs = inputs.to(device)
        labels = labels.to(device)

        if train:
            optimizer.zero_grad()

        # only track gradients when they are needed for backpropagation;
        # evaluation then does not build an autograd graph
        with torch.set_grad_enabled(train):
            outputs, _ = model(inputs)  # first element of the output is the logits
            loss = criterion(outputs, labels)

        if train:
            loss.backward()
            optimizer.step()

        running_loss += loss.item()

        # run evaluation
        predictions = outputs.argmax(dim=1)
        total_correct += torch.sum(predictions == labels).item()
        total_predictions += inputs.size(0)
        num_batches += 1

    if num_batches == 0:
        raise ValueError("data_loader yielded no batches")

    # return loss (mean of the per-batch losses) and evaluation metric
    loss = running_loss / num_batches
    accuracy = total_correct / total_predictions
    return loss, accuracy

cuda:1


## **Step 3: Train EfficientNet-B1 on Flowers102**

First, we will train a dense version of EfficientNet on the Flowers dataset.

In [5]:
# SparseZoo stub of the pre-trained checkpoint used to initialise the model
ZOO_STUB = "zoo:cv/classification/efficientnet-b1/pytorch/sparseml/imagenet/base-none"

# build EfficientNet-B1 from the stub with a NUM_LABELS-way classification head,
# then move it to the selected device
model = ModelRegistry.create(
    key="efficientnet-b1",
    pretrained_path=ZOO_STUB,
    num_classes=NUM_LABELS,
)
model = model.to(device)

# loss function and optimizer; the learning rate given here is a placeholder
# that the SparseML recipe overrides
criterion = CrossEntropyLoss(label_smoothing=0.1)
optimizer = RMSprop(params=model.parameters(), lr=8e-3)

In [6]:
# Print the dense-training recipe so its epoch and learning-rate schedule are visible in the notebook.
!cat ../recipe.dense.yaml

# Epoch and Learning-Rate variables
num_epochs: 20.0
init_lr: 0.001
final_lr: 0.0001

training_modifiers:
  - !EpochRangeModifier
    start_epoch: 0.0
    end_epoch: eval(num_epochs)

  - !LearningRateFunctionModifier
    final_lr: eval(final_lr)
    init_lr: eval(init_lr)
    lr_func: cosine
    start_epoch: 0.0
    end_epoch: eval(num_epochs)


In [7]:
# Path to the dense-training recipe (the file printed by the `!cat` cell above).
dense_recipe_path = "../recipe.dense.yaml"

In [8]:
# Build the modifier manager from the dense recipe and let it wrap the optimizer;
# the recipe's schedule is expressed in epochs, so the manager is told how many
# optimizer steps make up one epoch.
steps_per_epoch = len(train_loader)
manager = ScheduledModifierManager.from_yaml(dense_recipe_path)
optimizer = manager.modify(model, optimizer, steps_per_epoch=steps_per_epoch)

2023-04-06 18:24:16 sparseml.pytorch.utils.logger INFO     Logging all SparseML modifier-level logs to sparse_logs/06-04-2023_18.24.16.log
INFO:sparseml.pytorch.utils.logger:Logging all SparseML modifier-level logs to sparse_logs/06-04-2023_18.24.16.log


In [9]:
epoch = 0
total_epochs = manager.max_epochs
for epoch in range(total_epochs):
    # human-readable "current/total" label used in every message of this epoch
    epoch_name = f"{epoch + 1}/{total_epochs}"

    # training pass (updates the weights)
    print(f"Running Training Epoch {epoch_name}")
    train_loss, train_acc = run_model_one_epoch(
        model,
        train_loader,
        criterion,
        device,
        train=True,
        optimizer=optimizer,
    )
    print(f"Training Epoch: {epoch_name}\nTraining Loss: {train_loss}\nTop 1 Acc: {train_acc}\n")

    # validation pass (no weight updates)
    print(f"Running Validation Epoch {epoch_name}")
    val_loss, val_acc = run_model_one_epoch(
        model,
        val_loader,
        criterion,
        device,
    )
    print(f"Validation Epoch: {epoch_name}\nVal Loss: {val_loss}\nTop 1 Acc: {val_acc}\n")

# training is done: let the manager finalize the model
manager.finalize(model)

Running Training Epoch 1/20


  0%|          | 0/32 [00:00<?, ?it/s]

Training Epoch: 1/20
Training Loss: 4.4801114201545715
Top 1 Acc: 0.04607843137254902

Running Validation Epoch 1/20


  0%|          | 0/32 [00:00<?, ?it/s]

Validation Epoch: 1/20
Val Loss: 6.156083934009075
Top 1 Acc: 0.08333333333333333

Running Training Epoch 2/20


  0%|          | 0/32 [00:00<?, ?it/s]

Training Epoch: 2/20
Training Loss: 3.4624276012182236
Top 1 Acc: 0.23431372549019608

Running Validation Epoch 2/20


  0%|          | 0/32 [00:00<?, ?it/s]

Validation Epoch: 2/20
Val Loss: 2.7912555411458015
Top 1 Acc: 0.3519607843137255

Running Training Epoch 3/20


  0%|          | 0/32 [00:00<?, ?it/s]

Training Epoch: 3/20
Training Loss: 2.54279213398695
Top 1 Acc: 0.4715686274509804

Running Validation Epoch 3/20


  0%|          | 0/32 [00:00<?, ?it/s]

Validation Epoch: 3/20
Val Loss: 1.981403086334467
Top 1 Acc: 0.6558823529411765

Running Training Epoch 4/20


  0%|          | 0/32 [00:00<?, ?it/s]

Training Epoch: 4/20
Training Loss: 2.0715203434228897
Top 1 Acc: 0.6372549019607843

Running Validation Epoch 4/20


  0%|          | 0/32 [00:00<?, ?it/s]

Validation Epoch: 4/20
Val Loss: 1.7747747823596
Top 1 Acc: 0.7205882352941176

Running Training Epoch 5/20


  0%|          | 0/32 [00:00<?, ?it/s]

Training Epoch: 5/20
Training Loss: 1.7515205591917038
Top 1 Acc: 0.7421568627450981

Running Validation Epoch 5/20


  0%|          | 0/32 [00:00<?, ?it/s]

Validation Epoch: 5/20
Val Loss: 1.6067880429327488
Top 1 Acc: 0.7754901960784314

Running Training Epoch 6/20


  0%|          | 0/32 [00:00<?, ?it/s]

Training Epoch: 6/20
Training Loss: 1.565021213144064
Top 1 Acc: 0.8225490196078431

Running Validation Epoch 6/20


  0%|          | 0/32 [00:00<?, ?it/s]

Validation Epoch: 6/20
Val Loss: 1.447959542274475
Top 1 Acc: 0.8245098039215686

Running Training Epoch 7/20


  0%|          | 0/32 [00:00<?, ?it/s]

Training Epoch: 7/20
Training Loss: 1.4621017798781395
Top 1 Acc: 0.8431372549019608

Running Validation Epoch 7/20


  0%|          | 0/32 [00:00<?, ?it/s]

Validation Epoch: 7/20
Val Loss: 1.3863703589886427
Top 1 Acc: 0.8480392156862745

Running Training Epoch 8/20


  0%|          | 0/32 [00:00<?, ?it/s]

Training Epoch: 8/20
Training Loss: 1.3518884666264057
Top 1 Acc: 0.8607843137254902

Running Validation Epoch 8/20


  0%|          | 0/32 [00:00<?, ?it/s]

Validation Epoch: 8/20
Val Loss: 1.3022270128130913
Top 1 Acc: 0.8735294117647059

Running Training Epoch 9/20


  0%|          | 0/32 [00:00<?, ?it/s]

Training Epoch: 9/20
Training Loss: 1.2652370445430279
Top 1 Acc: 0.8911764705882353

Running Validation Epoch 9/20


  0%|          | 0/32 [00:00<?, ?it/s]

Validation Epoch: 9/20
Val Loss: 1.2661871016025543
Top 1 Acc: 0.8950980392156863

Running Training Epoch 10/20


  0%|          | 0/32 [00:00<?, ?it/s]

Training Epoch: 10/20
Training Loss: 1.2551480904221535
Top 1 Acc: 0.884313725490196

Running Validation Epoch 10/20


  0%|          | 0/32 [00:00<?, ?it/s]

Validation Epoch: 10/20
Val Loss: 1.2433946020901203
Top 1 Acc: 0.9

Running Training Epoch 11/20


  0%|          | 0/32 [00:00<?, ?it/s]

Training Epoch: 11/20
Training Loss: 1.1810289807617664
Top 1 Acc: 0.9147058823529411

Running Validation Epoch 11/20


  0%|          | 0/32 [00:00<?, ?it/s]

Validation Epoch: 11/20
Val Loss: 1.2338924258947372
Top 1 Acc: 0.8921568627450981

Running Training Epoch 12/20


  0%|          | 0/32 [00:00<?, ?it/s]

Training Epoch: 12/20
Training Loss: 1.1727205365896225
Top 1 Acc: 0.9215686274509803

Running Validation Epoch 12/20


  0%|          | 0/32 [00:00<?, ?it/s]

Validation Epoch: 12/20
Val Loss: 1.1792064141482115
Top 1 Acc: 0.907843137254902

Running Training Epoch 13/20


  0%|          | 0/32 [00:00<?, ?it/s]

Training Epoch: 13/20
Training Loss: 1.1652562860399485
Top 1 Acc: 0.9127450980392157

Running Validation Epoch 13/20


  0%|          | 0/32 [00:00<?, ?it/s]

Validation Epoch: 13/20
Val Loss: 1.1808396149426699
Top 1 Acc: 0.9019607843137255

Running Training Epoch 14/20


  0%|          | 0/32 [00:00<?, ?it/s]

Training Epoch: 14/20
Training Loss: 1.14595115929842
Top 1 Acc: 0.9333333333333333

Running Validation Epoch 14/20


  0%|          | 0/32 [00:00<?, ?it/s]

Validation Epoch: 14/20
Val Loss: 1.1687483321875334
Top 1 Acc: 0.9058823529411765

Running Training Epoch 15/20


  0%|          | 0/32 [00:00<?, ?it/s]

Training Epoch: 15/20
Training Loss: 1.12690051458776
Top 1 Acc: 0.9313725490196079

Running Validation Epoch 15/20


  0%|          | 0/32 [00:00<?, ?it/s]

Validation Epoch: 15/20
Val Loss: 1.1451063640415668
Top 1 Acc: 0.9245098039215687

Running Training Epoch 16/20


  0%|          | 0/32 [00:00<?, ?it/s]

Training Epoch: 16/20
Training Loss: 1.122094664722681
Top 1 Acc: 0.9294117647058824

Running Validation Epoch 16/20


  0%|          | 0/32 [00:00<?, ?it/s]

Validation Epoch: 16/20
Val Loss: 1.1377974078059196
Top 1 Acc: 0.9235294117647059

Running Training Epoch 17/20


  0%|          | 0/32 [00:00<?, ?it/s]

Training Epoch: 17/20
Training Loss: 1.0358950830996037
Top 1 Acc: 0.9588235294117647

Running Validation Epoch 17/20


  0%|          | 0/32 [00:00<?, ?it/s]

Validation Epoch: 17/20
Val Loss: 1.1366141978651285
Top 1 Acc: 0.9294117647058824

Running Training Epoch 18/20


  0%|          | 0/32 [00:00<?, ?it/s]

Training Epoch: 18/20
Training Loss: 1.0829532723873854
Top 1 Acc: 0.9450980392156862

Running Validation Epoch 18/20


  0%|          | 0/32 [00:00<?, ?it/s]

Validation Epoch: 18/20
Val Loss: 1.1273325141519308
Top 1 Acc: 0.9235294117647059

Running Training Epoch 19/20


  0%|          | 0/32 [00:00<?, ?it/s]

Training Epoch: 19/20
Training Loss: 1.0940583441406488
Top 1 Acc: 0.9392156862745098

Running Validation Epoch 19/20


  0%|          | 0/32 [00:00<?, ?it/s]

Validation Epoch: 19/20
Val Loss: 1.1134563777595758
Top 1 Acc: 0.9303921568627451

Running Training Epoch 20/20


  0%|          | 0/32 [00:00<?, ?it/s]

Training Epoch: 20/20
Training Loss: 1.0575630385428667
Top 1 Acc: 0.9558823529411765

Running Validation Epoch 20/20


  0%|          | 0/32 [00:00<?, ?it/s]

Validation Epoch: 20/20
Val Loss: 1.1045657861977816
Top 1 Acc: 0.9333333333333333



Export the trained dense model so it can be reloaded later without rerunning training. Step 4 loads this checkpoint as its starting point.

In [10]:
# Output directory and checkpoint file name for the dense model. A later cell
# rebuilds the checkpoint path as save_dir + "/training/" + model_name.
save_dir = "efficientnet-models-flowers"
model_name = "b1-dense-model.pth"

In [11]:
# Export the trained dense model into `save_dir`.
# The checkpoint name is taken from `model_name` instead of repeating the
# literal: the quantization step reloads the file from
# save_dir + "/training/" + model_name, so the two must stay in sync.
exporter = ModuleExporter(model, output_dir=save_dir)
exporter.export_pytorch(name=model_name)
# ONNX export is driven by one random sample of shape (1, 3, CROP_SIZE, CROP_SIZE).
exporter.export_onnx(torch.randn(1, 3, CROP_SIZE, CROP_SIZE), name="b1-dense-model.onnx", convert_qat=True)



In [12]:
# Release cached GPU memory that PyTorch is no longer using before starting the quantization stage.
torch.cuda.empty_cache()

## Step 4: Quantize The Model

With a model trained on Flowers, we are now ready to apply quantization-aware training (QAT) to quantize the model.

In [15]:
# Rebuild EfficientNet-B1 and initialise it from the dense checkpoint exported
# earlier in this notebook (a local file, not a SparseZoo stub).
dense_checkpoint_path = f"{save_dir}/training/{model_name}"
model = ModelRegistry.create(
    key="efficientnet-b1",
    pretrained_path=dense_checkpoint_path,
    num_classes=NUM_LABELS,
)
model = model.to(device)

# fresh loss function and optimizer for the quantization run; the learning rate
# given here is a placeholder that the SparseML recipe overrides
criterion = CrossEntropyLoss(label_smoothing=0.1)
optimizer = RMSprop(params=model.parameters(), lr=8e-3)

In [16]:
# Path to the quantization recipe that the manager cell below loads.
quant_recipe_path = "../recipe.quant.yaml"

In [17]:
# Print the quantization recipe so its learning-rate and quantization schedule are visible in the notebook.
!cat ../recipe.quant.yaml

# Epoch and Learning-Rate variables
num_epochs: 6.0
quantization_epochs: 5.0
init_lr: 0.0001
warmup_lr: 0.00001

training_modifiers:
  - !EpochRangeModifier
    start_epoch: 0.0
    end_epoch: eval(num_epochs)

  - !SetLearningRateModifier
    start_epoch: 0.0
    learning_rate: eval(warmup_lr)   

  - !LearningRateFunctionModifier
    final_lr: 0.0
    init_lr: eval(init_lr)
    lr_func: cosine
    start_epoch: eval(num_epochs-quantization_epochs)
    end_epoch: eval(num_epochs)

quantization_modifiers:
  - !QuantizationModifier
    start_epoch: eval(num_epochs - quantization_epochs)
    ignore: ['classifier', 'AdaptiveAvgPool2d', 'Sigmoid', 'sections.1.0.spatial']
    disable_quantization_observer_epoch: eval(num_epochs - 0.1)
    freeze_bn_stats_epoch: eval(num_epochs - quantization_epochs)


In [18]:
# Build the modifier manager from the quantization recipe, attach a TensorBoard
# logger, and let the manager wrap the optimizer (one epoch = len(train_loader) steps).
manager = ScheduledModifierManager.from_yaml(quant_recipe_path)
logger = TensorBoardLogger(log_path="./tensorboard_outputs/efficientnet/quant-run")
steps_per_epoch = len(train_loader)
optimizer = manager.modify(
    model,
    optimizer,
    loggers=[logger],
    steps_per_epoch=steps_per_epoch,
)

In [19]:
# Quantization-aware training: same train/validate cycle as the dense run, with
# the per-epoch metrics additionally written to the TensorBoard logger.
epoch = 0
total_epochs = manager.max_epochs
for epoch in range(total_epochs):
    # human-readable "current/total" label used in every message of this epoch
    epoch_name = f"{epoch + 1}/{total_epochs}"

    # training pass (updates the weights)
    print(f"Running Training Epoch {epoch_name}")
    train_loss, train_acc = run_model_one_epoch(
        model,
        train_loader,
        criterion,
        device,
        train=True,
        optimizer=optimizer,
    )
    print(f"Training Epoch: {epoch_name}\nTraining Loss: {train_loss}\nTop 1 Acc: {train_acc}\n")

    # validation pass (no weight updates)
    print(f"Running Validation Epoch {epoch_name}")
    val_loss, val_acc = run_model_one_epoch(
        model,
        val_loader,
        criterion,
        device,
    )
    print(f"Validation Epoch: {epoch_name}\nVal Loss: {val_loss}\nTop 1 Acc: {val_acc}\n")

    # log the four epoch-level scalars, keyed by the zero-based epoch index
    for scalar_tag, scalar_value in (
        ("Metrics/Loss (Train)", train_loss),
        ("Metrics/Accuracy (Train)", train_acc),
        ("Metrics/Loss (Validation)", val_loss),
        ("Metrics/Accuracy (Validation)", val_acc),
    ):
        logger.log_scalar(scalar_tag, scalar_value, epoch)

# training is done: let the manager finalize the model
manager.finalize(model)

Running Training Epoch 1/6


  0%|          | 0/32 [00:00<?, ?it/s]

Training Epoch: 1/6
Training Loss: 1.0479964073747396
Top 1 Acc: 0.9490196078431372

Running Validation Epoch 1/6


  0%|          | 0/32 [00:00<?, ?it/s]

Validation Epoch: 1/6
Val Loss: 1.0989590268582106
Top 1 Acc: 0.9382352941176471

Running Training Epoch 2/6


  0%|          | 0/32 [00:00<?, ?it/s]

Training Epoch: 2/6
Training Loss: 1.7036013696342707
Top 1 Acc: 0.75

Running Validation Epoch 2/6


  0%|          | 0/32 [00:00<?, ?it/s]

Validation Epoch: 2/6
Val Loss: 1.292163461446762
Top 1 Acc: 0.8705882352941177

Running Training Epoch 3/6


  0%|          | 0/32 [00:00<?, ?it/s]

Training Epoch: 3/6
Training Loss: 1.1844023391604424
Top 1 Acc: 0.9127450980392157

Running Validation Epoch 3/6


  0%|          | 0/32 [00:00<?, ?it/s]

Validation Epoch: 3/6
Val Loss: 1.2466791197657585
Top 1 Acc: 0.884313725490196

Running Training Epoch 4/6


  0%|          | 0/32 [00:00<?, ?it/s]

Training Epoch: 4/6
Training Loss: 1.129274770617485
Top 1 Acc: 0.9245098039215687

Running Validation Epoch 4/6


  0%|          | 0/32 [00:00<?, ?it/s]

Validation Epoch: 4/6
Val Loss: 1.1841304209083319
Top 1 Acc: 0.9098039215686274

Running Training Epoch 5/6


  0%|          | 0/32 [00:00<?, ?it/s]

Training Epoch: 5/6
Training Loss: 1.1279195342212915
Top 1 Acc: 0.9313725490196079

Running Validation Epoch 5/6


  0%|          | 0/32 [00:00<?, ?it/s]

Validation Epoch: 5/6
Val Loss: 1.1864626351743937
Top 1 Acc: 0.9107843137254902

Running Training Epoch 6/6


  0%|          | 0/32 [00:00<?, ?it/s]

Training Epoch: 6/6
Training Loss: 1.110894814133644
Top 1 Acc: 0.9362745098039216

Running Validation Epoch 6/6


  0%|          | 0/32 [00:00<?, ?it/s]

Validation Epoch: 6/6
Val Loss: 1.1804747302085161
Top 1 Acc: 0.9127450980392157



In [21]:
# Export the quantized model next to the dense one: a PyTorch checkpoint plus an
# ONNX file with the QAT graph converted (convert_qat=True).
save_dir = "efficientnet-models-flowers"
exporter = ModuleExporter(model, output_dir=save_dir)
exporter.export_pytorch(name="b1-quant-model.pth")

# one random sample of shape (1, 3, CROP_SIZE, CROP_SIZE) drives the ONNX export
sample_batch = torch.randn(1, 3, CROP_SIZE, CROP_SIZE)
exporter.export_onnx(
    sample_batch,
    name="b1-quant-model.onnx",
    convert_qat=True,
)

2023-04-06 18:34:27 sparseml.pytorch.sparsification.quantization.quantize_qat_export INFO     Converted 113 quantizable Conv ops with weight and bias to ConvInteger and Add
INFO:sparseml.pytorch.sparsification.quantization.quantize_qat_export:Converted 113 quantizable Conv ops with weight and bias to ConvInteger and Add
2023-04-06 18:34:27 sparseml.pytorch.sparsification.quantization.quantize_qat_export INFO     Converted 0 quantizable Gemm ops with weight and bias to MatMulInteger and Add
INFO:sparseml.pytorch.sparsification.quantization.quantize_qat_export:Converted 0 quantizable Gemm ops with weight and bias to MatMulInteger and Add
