In [1]:
import math
import os

import torch, torchvision
import sparseml
from sparseml.pytorch.models import ModelRegistry
from sparseml.pytorch.optim import ScheduledModifierManager
from sparseml.pytorch.utils import TensorBoardLogger, ModuleExporter, get_prunable_layers, tensor_sparsity

from torch.utils.data import DataLoader
from torch.nn import CrossEntropyLoss, Linear
from torch.optim import Adam, RMSprop

from torchvision import transforms
from tqdm.notebook import tqdm

In [2]:
# record the library versions this notebook was executed with
for library in (torch, sparseml):
    print(library.__version__)

1.12.1+cu116
1.4.4


## **Step 1: Setup Dataset**

Oxford 102 Flower is an image classification dataset consisting of 102 flower categories. The flowers were chosen to be flowers commonly occurring in the United Kingdom. Each class consists of between 40 and 258 images. The images have large scale, pose and light variations. In addition, there are categories that have large variations within the category, and several very similar categories.

We use the standard PyTorch `datasets` and `dataloaders` to manage the dataset.

In [3]:
# ---- configuration: everything a reader might want to tune ----
NUM_LABELS = 102
BATCH_SIZE = 16
CROP_SIZE = 288
RESIZE_SIZE = 288
INTERPOLATION = transforms.InterpolationMode.BICUBIC
# per-channel normalization statistics used by the imagenet transforms below
IN_MEAN = [0.485, 0.456, 0.406]
IN_STD = [0.229, 0.224, 0.225]
# directory the Flowers102 files are downloaded to / read from
DATA_ROOT = "./data"
# never request more loader workers than the machine has CPUs
# (os.cpu_count() may return None, hence the fallback to 1)
NUM_WORKERS = min(16, os.cpu_count() or 1)
# pinned host memory only helps when batches are copied to a GPU
PIN_MEMORY = torch.cuda.is_available()

# imagenet transforms
# training: random crop + augmentation, then tensor conversion, normalization
# and occasional random erasing
train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(CROP_SIZE, interpolation=INTERPOLATION),
    transforms.autoaugment.TrivialAugmentWide(interpolation=INTERPOLATION),
    transforms.PILToTensor(),
    transforms.ConvertImageDtype(torch.float),
    transforms.Normalize(mean=IN_MEAN, std=IN_STD),
    transforms.RandomErasing(p=.1),
])

# validation: deterministic resize + center crop, same tensor conversion and normalization
val_transforms = transforms.Compose([
    transforms.Resize(RESIZE_SIZE, interpolation=INTERPOLATION),
    transforms.CenterCrop(CROP_SIZE),
    transforms.PILToTensor(),
    transforms.ConvertImageDtype(torch.float),
    transforms.Normalize(mean=IN_MEAN, std=IN_STD),
])

# datasets (downloaded into DATA_ROOT on first use)
train_dataset = torchvision.datasets.Flowers102(
    root=DATA_ROOT,
    split="train",
    transform=train_transforms,
    download=True,
)

val_dataset = torchvision.datasets.Flowers102(
    root=DATA_ROOT,
    split="val",
    transform=val_transforms,
    download=True,
)

# dataloaders: only the training data is shuffled
train_loader = DataLoader(
    train_dataset,
    BATCH_SIZE,
    shuffle=True,
    pin_memory=PIN_MEMORY,
    num_workers=NUM_WORKERS,
)
val_loader = DataLoader(
    val_dataset,
    BATCH_SIZE,
    shuffle=False,
    pin_memory=PIN_MEMORY,
    num_workers=NUM_WORKERS,
)

## Step 2: Setup PyTorch Training Loop

We will use this training loop below. This is standard PyTorch functionality.

In [4]:
# run on the GPU when PyTorch can see one, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

def run_model_one_epoch(model, data_loader, criterion, device, train=False, optimizer=None):
    """
    Run one full pass over ``data_loader`` and report the mean loss and top-1 accuracy.

    :param model: module that maps a batch of inputs to per-class logits
    :param data_loader: iterable of ``(inputs, labels)`` batches; must support ``len()``
    :param criterion: callable computing a scalar loss from ``(outputs, labels)``
    :param device: device each batch is moved to before the forward pass
    :param train: if True, put the model in training mode and update its weights
        after every batch; otherwise put it in eval mode and only measure
    :param optimizer: optimizer used for the weight updates; required when
        ``train`` is True, ignored otherwise
    :return: tuple ``(loss, accuracy)`` where ``loss`` is the mean of the
        per-batch losses and ``accuracy`` is the fraction of samples whose
        argmax prediction equals the label
    :raises ValueError: if ``train`` is True but no optimizer was given, or if
        ``data_loader`` yields no batches
    """
    if train and optimizer is None:
        raise ValueError("optimizer must be provided when train=True")

    if train:
        model.train()
    else:
        model.eval()

    running_loss = 0.0
    total_correct = 0
    total_predictions = 0
    num_batches = 0

    # autograd is only needed for the backward pass; disabling it during
    # evaluation avoids building (and holding memory for) the graph
    with torch.set_grad_enabled(bool(train)):
        # loop through batches
        for inputs, labels in tqdm(data_loader, total=len(data_loader)):
            inputs = inputs.to(device)
            labels = labels.to(device)

            if train:
                optimizer.zero_grad()

            # compute loss, run backpropagation (training only)
            outputs = model(inputs)  # model returns logits
            loss = criterion(outputs, labels)
            if train:
                loss.backward()
                optimizer.step()

            running_loss += loss.item()
            num_batches += 1

            # run evaluation
            predictions = outputs.argmax(dim=1)
            total_correct += torch.sum(predictions == labels).item()
            total_predictions += inputs.size(0)

    if num_batches == 0:
        raise ValueError("data_loader produced no batches")

    # return loss and evaluation metric
    loss = running_loss / num_batches
    accuracy = total_correct / total_predictions
    return loss, accuracy

cuda


## **Step 3: Train EfficientNet-B2 on Flowers102**

First, we will train a dense version of EfficientNet on the Flowers dataset.

In [5]:
# start from the pre-trained EfficientNet-B2 and replace its final linear layer
# with a fresh classification head that has one output per flower class
pretrained_weights = torchvision.models.EfficientNet_B2_Weights.DEFAULT
model = torchvision.models.efficientnet_b2(weights=pretrained_weights)
head_in_features = model.classifier[1].in_features
model.classifier[1] = Linear(in_features=head_in_features, out_features=NUM_LABELS)
model.to(device)

# loss function and optimizer; the lr given here is only a placeholder,
# it will be overridden by sparseml
criterion = CrossEntropyLoss(label_smoothing=0.1)
optimizer = Adam(model.parameters(), lr=8e-3)

In [6]:
!cat recipe.dense.yaml

# Epoch and Learning-Rate variables
num_epochs: 15.0
init_lr: 0.001
final_lr: 0.0005

training_modifiers:
  - !EpochRangeModifier
    start_epoch: 0.0
    end_epoch: eval(num_epochs)

  - !LearningRateFunctionModifier
    final_lr: eval(final_lr)
    init_lr: eval(init_lr)
    lr_func: cosine
    start_epoch: 0.0
    end_epoch: eval(num_epochs)


In [7]:
# path to the SparseML recipe (printed above) that defines the dense training schedule
dense_recipe_path = "recipe.dense.yaml"

In [8]:
# create ScheduledModifierManager and Optimizer wrapper
# The manager is built from the dense recipe; `manager.modify` returns the
# optimizer object that the training loop must step from here on.
# NOTE: `optimizer` is rebound to that returned object, so re-running this cell
# without re-creating the optimizer first would pass an already-wrapped
# optimizer back into `modify`.
manager = ScheduledModifierManager.from_yaml(dense_recipe_path)
optimizer = manager.modify(model, optimizer, steps_per_epoch=len(train_loader))

2023-04-24 21:03:32 sparseml.pytorch.utils.logger INFO     Logging all SparseML modifier-level logs to sparse_logs/24-04-2023_21.03.32.log


In [9]:
# the recipe, through the manager, decides how many epochs to run
num_epochs = manager.max_epochs
epoch = 0
for epoch in range(num_epochs):
    # label shared by every message printed for this epoch, e.g. "3/15"
    epoch_name = f"{epoch + 1}/{num_epochs}"

    # training pass: weights are updated through the manager-wrapped optimizer
    print(f"Running Training Epoch {epoch_name}")
    train_loss, train_acc = run_model_one_epoch(
        model,
        train_loader,
        criterion,
        device,
        train=True,
        optimizer=optimizer,
    )
    print(f"Training Epoch: {epoch_name}\nTraining Loss: {train_loss}\nTop 1 Acc: {train_acc}\n")

    # validation pass: measurement only, no weight updates
    print(f"Running Validation Epoch {epoch_name}")
    val_loss, val_acc = run_model_one_epoch(model, val_loader, criterion, device)
    print(f"Validation Epoch: {epoch_name}\nVal Loss: {val_loss}\nTop 1 Acc: {val_acc}\n")

# clean up: let the manager finish its work on the model now that training is done
manager.finalize(model)

Running Training Epoch 1/15


  0%|          | 0/64 [00:00<?, ?it/s]

Training Epoch: 1/15
Training Loss: 3.9966855123639107
Top 1 Acc: 0.17058823529411765

Running Validation Epoch 1/15


  0%|          | 0/64 [00:00<?, ?it/s]

Validation Epoch: 1/15
Val Loss: 2.601166397333145
Top 1 Acc: 0.5029411764705882

Running Training Epoch 2/15


  0%|          | 0/64 [00:00<?, ?it/s]

Training Epoch: 2/15
Training Loss: 2.7057583071291447
Top 1 Acc: 0.4696078431372549

Running Validation Epoch 2/15


  0%|          | 0/64 [00:00<?, ?it/s]

Validation Epoch: 2/15
Val Loss: 1.979369692504406
Top 1 Acc: 0.6450980392156863

Running Training Epoch 3/15


  0%|          | 0/64 [00:00<?, ?it/s]

Training Epoch: 3/15
Training Loss: 2.1555514000356197
Top 1 Acc: 0.6147058823529412

Running Validation Epoch 3/15


  0%|          | 0/64 [00:00<?, ?it/s]

Validation Epoch: 3/15
Val Loss: 1.7117218002676964
Top 1 Acc: 0.7617647058823529

Running Training Epoch 4/15


  0%|          | 0/64 [00:00<?, ?it/s]

Training Epoch: 4/15
Training Loss: 1.8120776172727346
Top 1 Acc: 0.7323529411764705

Running Validation Epoch 4/15


  0%|          | 0/64 [00:00<?, ?it/s]

Validation Epoch: 4/15
Val Loss: 1.4780828021466732
Top 1 Acc: 0.8343137254901961

Running Training Epoch 5/15


  0%|          | 0/64 [00:00<?, ?it/s]

Training Epoch: 5/15
Training Loss: 1.6700921803712845
Top 1 Acc: 0.7519607843137255

Running Validation Epoch 5/15


  0%|          | 0/64 [00:00<?, ?it/s]

Validation Epoch: 5/15
Val Loss: 1.4409347074106336
Top 1 Acc: 0.8392156862745098

Running Training Epoch 6/15


  0%|          | 0/64 [00:00<?, ?it/s]

Training Epoch: 6/15
Training Loss: 1.5235994048416615
Top 1 Acc: 0.8166666666666667

Running Validation Epoch 6/15


  0%|          | 0/64 [00:00<?, ?it/s]

Validation Epoch: 6/15
Val Loss: 1.3184419581666589
Top 1 Acc: 0.8990196078431373

Running Training Epoch 7/15


  0%|          | 0/64 [00:00<?, ?it/s]

Training Epoch: 7/15
Training Loss: 1.4649632824584842
Top 1 Acc: 0.8392156862745098

Running Validation Epoch 7/15


  0%|          | 0/64 [00:00<?, ?it/s]

Validation Epoch: 7/15
Val Loss: 1.3281712122261524
Top 1 Acc: 0.8931372549019608

Running Training Epoch 8/15


  0%|          | 0/64 [00:00<?, ?it/s]

Training Epoch: 8/15
Training Loss: 1.4370731841772795
Top 1 Acc: 0.8539215686274509

Running Validation Epoch 8/15


  0%|          | 0/64 [00:00<?, ?it/s]

Validation Epoch: 8/15
Val Loss: 1.2335777394473553
Top 1 Acc: 0.9088235294117647

Running Training Epoch 9/15


  0%|          | 0/64 [00:00<?, ?it/s]

Training Epoch: 9/15
Training Loss: 1.2929417416453362
Top 1 Acc: 0.9

Running Validation Epoch 9/15


  0%|          | 0/64 [00:00<?, ?it/s]

Validation Epoch: 9/15
Val Loss: 1.1658179545775056
Top 1 Acc: 0.9441176470588235

Running Training Epoch 10/15


  0%|          | 0/64 [00:00<?, ?it/s]

Training Epoch: 10/15
Training Loss: 1.2576863877475262
Top 1 Acc: 0.9029411764705882

Running Validation Epoch 10/15


  0%|          | 0/64 [00:00<?, ?it/s]

Validation Epoch: 10/15
Val Loss: 1.1877405578270555
Top 1 Acc: 0.9264705882352942

Running Training Epoch 11/15


  0%|          | 0/64 [00:00<?, ?it/s]

Training Epoch: 11/15
Training Loss: 1.2658279715105891
Top 1 Acc: 0.8921568627450981

Running Validation Epoch 11/15


  0%|          | 0/64 [00:00<?, ?it/s]

Validation Epoch: 11/15
Val Loss: 1.1265993416309357
Top 1 Acc: 0.9294117647058824

Running Training Epoch 12/15


  0%|          | 0/64 [00:00<?, ?it/s]

Training Epoch: 12/15
Training Loss: 1.2340737339109182
Top 1 Acc: 0.9127450980392157

Running Validation Epoch 12/15


  0%|          | 0/64 [00:00<?, ?it/s]

Validation Epoch: 12/15
Val Loss: 1.1847464069724083
Top 1 Acc: 0.9264705882352942

Running Training Epoch 13/15


  0%|          | 0/64 [00:00<?, ?it/s]

Training Epoch: 13/15
Training Loss: 1.1862440574914217
Top 1 Acc: 0.9225490196078432

Running Validation Epoch 13/15


  0%|          | 0/64 [00:00<?, ?it/s]

Validation Epoch: 13/15
Val Loss: 1.1055879639461637
Top 1 Acc: 0.9431372549019608

Running Training Epoch 14/15


  0%|          | 0/64 [00:00<?, ?it/s]

Training Epoch: 14/15
Training Loss: 1.1622321512550116
Top 1 Acc: 0.9284313725490196

Running Validation Epoch 14/15


  0%|          | 0/64 [00:00<?, ?it/s]

Validation Epoch: 14/15
Val Loss: 1.1169866248965263
Top 1 Acc: 0.9372549019607843

Running Training Epoch 15/15


  0%|          | 0/64 [00:00<?, ?it/s]

Training Epoch: 15/15
Training Loss: 1.1553612435236573
Top 1 Acc: 0.9372549019607843

Running Validation Epoch 15/15


  0%|          | 0/64 [00:00<?, ?it/s]

Validation Epoch: 15/15
Val Loss: 1.1012473637238145
Top 1 Acc: 0.9411764705882353



Export the model in case we want to reload in the future, so we do not have to rerun.

In [19]:
# output directory handed to ModuleExporter for the exported checkpoints and ONNX files
save_dir = "efficientnet-models-flowers"

In [11]:
# write the dense model twice: as a PyTorch checkpoint and as an ONNX graph
exporter = ModuleExporter(model, output_dir=save_dir)
exporter.export_pytorch(name="b2-dense-model.pth")

# a single random image-shaped batch serves as the sample input for the ONNX export
sample_batch = torch.randn(1, 3, CROP_SIZE, CROP_SIZE)
exporter.export_onnx(sample_batch, name="b2-dense-model.onnx", convert_qat=True)



In [11]:
# release PyTorch's cached, currently unused GPU memory before starting the next stage
torch.cuda.empty_cache()

## Step 4: Quantize The Model

With a model trained on Flowers, we are now ready to apply quantization-aware training (QAT) to quantize the model.

In [12]:
# Load the dense checkpoint exported earlier in this notebook.
# map_location="cpu" makes the load independent of the device the checkpoint
# was saved from (so it also works on a machine without a GPU) and avoids
# keeping a second copy of the weights in GPU memory; the model is moved to
# the target device after the weights have been loaded into it.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
state_dict = torch.load(
    "efficientnet-models-flowers/training/b2-dense-model.pth",
    map_location="cpu",
)

In [13]:
# rebuild the EfficientNet-B2 architecture (no pre-trained weights), give it the
# same classification head as the dense run, then restore the weights that were
# trained on Flowers102 from the loaded checkpoint
model = torchvision.models.efficientnet_b2()
head_in_features = model.classifier[1].in_features
model.classifier[1] = Linear(in_features=head_in_features, out_features=NUM_LABELS)
dense_weights = state_dict["state_dict"]
model.load_state_dict(dense_weights)

model.to(device)

# loss function and optimizer; the lr given here is only a placeholder,
# it will be overridden by sparseml
criterion = CrossEntropyLoss(label_smoothing=0.1)
optimizer = Adam(model.parameters(), lr=8e-3)

In [14]:
# path to the SparseML recipe that defines the quantization-aware training schedule
quant_recipe_path = "b2.recipe.quant.yaml"

In [15]:
!cat b2.recipe.quant.yaml

# Epoch and Learning-Rate variables
num_epochs: 8.0
quantization_epochs: 6.0
init_lr: 0.00005
final_lr: 0.00001
warmup_lr: 0.000001

training_modifiers:
  - !EpochRangeModifier
    start_epoch: 0.0
    end_epoch: eval(num_epochs)

  - !SetLearningRateModifier
    start_epoch: 0.0
    learning_rate: eval(warmup_lr)   

  - !LearningRateFunctionModifier
    final_lr: eval(final_lr)
    init_lr: eval(init_lr)
    lr_func: cosine
    start_epoch: eval(num_epochs - quantization_epochs)
    end_epoch: eval(num_epochs)

quantization_modifiers:
  - !QuantizationModifier
    start_epoch: eval(num_epochs - quantization_epochs)
    ignore: ['classifier', 'AdaptiveAvgPool2d', 'Sigmoid', 'features.0', 'features.1.0', 'features.2.0']
    disable_quantization_observer_epoch: eval(num_epochs - 0.1)
    freeze_bn_stats_epoch: eval(num_epochs - quantization_epochs)

In [16]:
# create ScheduledModifierManager and Optimizer wrapper
# The manager is built from the quantization recipe. The TensorBoard logger
# writes under ./tensorboard_outputs/efficientnet/quant-run; it is handed to
# the manager here and is also used by the training loop for per-epoch metrics.
# NOTE: `optimizer` is rebound to the object returned by `manager.modify`, so
# re-running this cell without re-creating the optimizer first would pass an
# already-wrapped optimizer back into `modify`.
manager = ScheduledModifierManager.from_yaml(quant_recipe_path)
logger = TensorBoardLogger(log_path="./tensorboard_outputs/efficientnet/quant-run")
optimizer = manager.modify(model, optimizer, loggers=[logger], steps_per_epoch=len(train_loader))

In [17]:
# run QAT algorithm
# the recipe, through the manager, decides how many epochs to run
num_epochs = manager.max_epochs
epoch = 0
for epoch in range(num_epochs):
    # label shared by every message printed for this epoch, e.g. "3/8"
    epoch_name = f"{epoch + 1}/{num_epochs}"

    # training pass: weights are updated through the manager-wrapped optimizer
    print(f"Running Training Epoch {epoch_name}")
    train_loss, train_acc = run_model_one_epoch(
        model,
        train_loader,
        criterion,
        device,
        train=True,
        optimizer=optimizer,
    )
    print(f"Training Epoch: {epoch_name}\nTraining Loss: {train_loss}\nTop 1 Acc: {train_acc}\n")

    # validation pass: measurement only, no weight updates
    print(f"Running Validation Epoch {epoch_name}")
    val_loss, val_acc = run_model_one_epoch(model, val_loader, criterion, device)
    print(f"Validation Epoch: {epoch_name}\nVal Loss: {val_loss}\nTop 1 Acc: {val_acc}\n")

    # send this epoch's four metrics to TensorBoard, keyed by the epoch index
    epoch_metrics = (
        ("Metrics/Loss (Train)", train_loss),
        ("Metrics/Accuracy (Train)", train_acc),
        ("Metrics/Loss (Validation)", val_loss),
        ("Metrics/Accuracy (Validation)", val_acc),
    )
    for tag, value in epoch_metrics:
        logger.log_scalar(tag, value, epoch)

# let the manager finish its work on the model now that training is done
manager.finalize(model)

Running Training Epoch 1/8


  0%|          | 0/64 [00:00<?, ?it/s]

Training Epoch: 1/8
Training Loss: 1.1161327175796032
Top 1 Acc: 0.946078431372549

Running Validation Epoch 1/8


  0%|          | 0/64 [00:00<?, ?it/s]

Validation Epoch: 1/8
Val Loss: 1.0953276669606566
Top 1 Acc: 0.9411764705882353

Running Training Epoch 2/8


  0%|          | 0/64 [00:00<?, ?it/s]

Training Epoch: 2/8
Training Loss: 1.1136755589395761
Top 1 Acc: 0.942156862745098

Running Validation Epoch 2/8


  0%|          | 0/64 [00:00<?, ?it/s]

Validation Epoch: 2/8
Val Loss: 1.087324595078826
Top 1 Acc: 0.9411764705882353

Running Training Epoch 3/8


  0%|          | 0/64 [00:00<?, ?it/s]

Training Epoch: 3/8
Training Loss: 1.3689014166593552
Top 1 Acc: 0.8696078431372549

Running Validation Epoch 3/8


  0%|          | 0/64 [00:00<?, ?it/s]

Validation Epoch: 3/8
Val Loss: 1.1817939430475235
Top 1 Acc: 0.9049019607843137

Running Training Epoch 4/8


  0%|          | 0/64 [00:00<?, ?it/s]

Training Epoch: 4/8
Training Loss: 1.1495222207158804
Top 1 Acc: 0.9362745098039216

Running Validation Epoch 4/8


  0%|          | 0/64 [00:00<?, ?it/s]

Validation Epoch: 4/8
Val Loss: 1.1346615916118026
Top 1 Acc: 0.9166666666666666

Running Training Epoch 5/8


  0%|          | 0/64 [00:00<?, ?it/s]

Training Epoch: 5/8
Training Loss: 1.127908794209361
Top 1 Acc: 0.9303921568627451

Running Validation Epoch 5/8


  0%|          | 0/64 [00:00<?, ?it/s]

Validation Epoch: 5/8
Val Loss: 1.1099653420969844
Top 1 Acc: 0.9225490196078432

Running Training Epoch 6/8


  0%|          | 0/64 [00:00<?, ?it/s]

Training Epoch: 6/8
Training Loss: 1.1017089420929551
Top 1 Acc: 0.9333333333333333

Running Validation Epoch 6/8


  0%|          | 0/64 [00:00<?, ?it/s]

Validation Epoch: 6/8
Val Loss: 1.0872347811236978
Top 1 Acc: 0.9284313725490196

Running Training Epoch 7/8


  0%|          | 0/64 [00:00<?, ?it/s]

Training Epoch: 7/8
Training Loss: 1.0860993107780814
Top 1 Acc: 0.9509803921568627

Running Validation Epoch 7/8


  0%|          | 0/64 [00:00<?, ?it/s]

Validation Epoch: 7/8
Val Loss: 1.0890111606568098
Top 1 Acc: 0.9254901960784314

Running Training Epoch 8/8


  0%|          | 0/64 [00:00<?, ?it/s]

Training Epoch: 8/8
Training Loss: 1.0858161579817533
Top 1 Acc: 0.9480392156862745

Running Validation Epoch 8/8


  0%|          | 0/64 [00:00<?, ?it/s]

Validation Epoch: 8/8
Val Loss: 1.0888487379997969
Top 1 Acc: 0.9254901960784314



In [20]:
# write the quantized model twice: as a PyTorch checkpoint and as an ONNX graph
exporter = ModuleExporter(model, output_dir=save_dir)
exporter.export_pytorch(name="b2-quant-model.pth")

# a single random image-shaped batch serves as the sample input for the ONNX export
sample_batch = torch.randn(1, 3, CROP_SIZE, CROP_SIZE)
exporter.export_onnx(sample_batch, name="b2-quant-model.onnx", convert_qat=True)

2023-04-24 21:35:51 sparseml.pytorch.sparsification.quantization.quantize_qat_export INFO     Converted 105 quantizable Conv ops with weight and bias to ConvInteger and Add
2023-04-24 21:35:51 sparseml.pytorch.sparsification.quantization.quantize_qat_export INFO     Converted 0 quantizable Gemm ops with weight and bias to MatMulInteger and Add
