In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from speech_command_dataset import SpeechCommandDataset
import numpy as np
import matplotlib.pyplot as plt
from model import M5
import time

In [None]:
# Reproducibility: fixed RNG seed plus deterministic (non-autotuned) cuDNN kernels.
torch.manual_seed(0)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Everything in this notebook (float and quantized models) runs on the CPU.
device = torch.device("cpu")
print(device)

In [None]:
# Declare the data loaders.
# NOTE(review): the batch sizes were missing from this cell (`"batch_size": ,` is a SyntaxError).
# 32 is a placeholder — TODO confirm the values the model was evaluated with.
CALIB_BATCH_SIZE = 32
TEST_BATCH_SIZE = 32

# Calibration only needs a representative sample, so shuffling and dropping a ragged tail is fine.
calib_params = {"batch_size": CALIB_BATCH_SIZE,
                "shuffle": True,
                "drop_last": True,
                "num_workers": 1}

# Evaluation must see every sample: with drop_last=True the final partial batch was skipped
# while test() still divided by len(test_set), biasing the reported accuracy downward.
testing_params = {"batch_size": TEST_BATCH_SIZE,
                  "shuffle": False,
                  "drop_last": False,
                  "num_workers": 1}

# Default constructor args presumably select the training split (the test split passes
# is_training=False) — confirm in speech_command_dataset.py.
calib_set = SpeechCommandDataset()
calib_loader = DataLoader(calib_set, **calib_params)

test_set = SpeechCommandDataset(is_training=False)
test_loader = DataLoader(test_set, **testing_params)

In [None]:
def test(model, epoch):
    model.eval()
    correct = 0
    for data, target in test_loader:

        data = data.to(device)
        target = target.to(device)

        #forward
        output = model(data)

        pred = output.argmax(dim=-1)
        correct += pred.squeeze().eq(target).sum().item()
        
    # print testing stats
    test_acc = 100.0 * float(correct) / len(test_set)
    print('Epoch: %3d' % epoch, '|test accuracy: %.2f' % test_acc)
    return test_acc

In [None]:
# Load the float model that is quantized in the steps below.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
model_path = './Checkpoint/best_model.pth.tar'

print("=> loading checkpoint '{}'".format(model_path))
checkpoint = torch.load(model_path, map_location=device)

model = M5(cfg=checkpoint['cfg'])
model = model.to(device)
model.load_state_dict(checkpoint['state_dict'])

In [None]:
# Float architecture before fusion; the layer indices used by the fusion step match this printout.
print(f"{model}")

In [None]:
# Storage size of one weight element of the first feature layer (presumably float32 -> 4 bytes).
float_weight = model.features[0].weight
print('\nbytes per element:', float_weight.element_size())

### Static quantization of a model consists of the following steps:

1. Fuse modules
2. Insert Quant/DeQuant Stubs
3. Prepare the fused module (insert observers before and after layers)
4. Calibrate the prepared module (pass it representative data)
5. Convert the calibrated module (replace with quantized version)

### 1. Fuse modules

In [None]:
# Conv/BatchNorm fusion folds the BN statistics into the conv weights, so fuse in eval mode.
model.eval()

# Fuse each triple (4k, 4k+1, 4k+2) of `model.features` for k = 0..3; indices 3, 7, 11 are
# left untouched. Presumably each triple is (Conv1d, BatchNorm1d, ReLU) — see the printout above.
# NOTE(review): fusion happens in place and is not idempotent; re-running this cell on an
# already-fused model fails.
fuse_groups = [[str(first + offset) for offset in range(3)] for first in (0, 4, 8, 12)]
torch.quantization.fuse_modules(model.features, fuse_groups, inplace=True)

print(model)

### 2. Insert Quant/DeQuant Stubs

In [None]:
# Insert stubs: QuantStub converts the float input to a quantized tensor and DeQuantStub converts
# the result back, so everything in between runs quantized after conversion.
# NOTE(review): flattening M5 into nn.Sequential assumes M5.forward is exactly
# features -> avgpool -> flatten -> fc; any extra ops in forward (permute, softmax, ...) are
# dropped here — confirm against model.py.
# NOTE(review): this rebinds `model`, so the cell cannot be re-run (a Sequential has no `.features`).
stub_layers = [
    torch.quantization.QuantStub(),
    *model.features,
    model.avgpool,
    model.flatten,
    model.fc,
    torch.quantization.DeQuantStub(),
]
model = nn.Sequential(*stub_layers)

print(model)

### 3. Prepare the fused module (insert observers before and after layers)

In [None]:
backend = "fbgemm"  # running on a x86 CPU. Use "qnnpack" if running on ARM.

# The quantized engine decides how weights are packed at convert time and which kernels run at
# inference. Keep it in sync with the qconfig backend; otherwise changing `backend` above would
# still execute on the platform's default engine.
torch.backends.quantized.engine = backend
model.qconfig = torch.quantization.get_default_qconfig(backend)

# Insert observers that record value ranges during calibration.
torch.quantization.prepare(model, inplace=True)

print(model)

### 4. Calibrate the prepared module (pass it representative data)

In [None]:
# Feed a few representative batches through the prepared model so its observers record
# activation ranges. Only the inputs matter; labels and outputs are discarded.
NUM_CALIB_BATCH = 10

with torch.inference_mode():
    # zip() stops as soon as range() is exhausted, so exactly NUM_CALIB_BATCH batches are drawn
    # (fewer, without a StopIteration, if the loader is shorter). No loader iterator is left
    # alive in a global afterwards.
    for _, (inputs, _labels) in zip(range(NUM_CALIB_BATCH), calib_loader):
        # Use the notebook's device, matching where `model` was placed.
        model(inputs.to(device))

### 5. Convert the calibrated module (replace with quantized version)

In [None]:
# Convert: swap the observed float modules for their quantized counterparts.
# inplace=False leaves the calibrated `model` untouched and returns a new module.
quantized_model = torch.quantization.convert(model, inplace=False)
print(quantized_model)

In [None]:
# Index 0 is the QuantStub, so index 1 is the first fused block of the original `model.features`.
# Quantized modules expose their weight through a weight() method rather than an attribute.
quant_weight = quantized_model[1].weight()
print('\nbytes per element:', quant_weight.element_size())

In [None]:
# Accuracy of the quantized model on the test split; `epoch` only labels the printed line.
test_acc = test(quantized_model, epoch=0)

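## Run benchmark

Compare the forward-pass time of the quantized model with the float checkpoints over the same test batches.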

In [None]:
# Load the float checkpoints that are benchmarked against the quantized model below.
best_path = './Checkpoint/best_model.pth.tar'
fine_path = './Checkpoint/fine_model.pth.tar'
coarse_path = './Checkpoint/coarse_model.pth.tar'


def load_m5_checkpoint(path):
    """Rebuild an ``M5`` model from a checkpoint file.

    The checkpoint is indexed as a dict with the model config under ``'cfg'``
    and the weights under ``'state_dict'``. Tensors are mapped onto the
    notebook-level ``device``.

    NOTE(review): ``torch.load`` unpickles the file, which can execute arbitrary
    code — only load checkpoints from a trusted source.

    Returns:
        ``(model, checkpoint)``: the model in eval mode and the raw checkpoint.
    """
    checkpoint = torch.load(path, map_location=device)
    loaded_model = M5(cfg=checkpoint['cfg']).to(device)
    loaded_model.load_state_dict(checkpoint['state_dict'])
    loaded_model.eval()
    return loaded_model, checkpoint


# The *_checkpoint names are kept so any later cell that reads them still works.
best_model, best_checkpoint = load_m5_checkpoint(best_path)
fine_model, fine_checkpoint = load_m5_checkpoint(fine_path)
coarse_model, coarse_checkpoint = load_m5_checkpoint(coarse_path)

In [None]:
def run_benchmark(model, num_batch, loader=None, run_device=None):
    """Time the forward passes of ``model`` over the first ``num_batch`` batches.

    Only the ``model(data)`` call is timed (``time.perf_counter``); data loading
    and device transfers are excluded. The total is printed and returned.

    Args:
        model: module to benchmark; switched to eval mode.
        num_batch: maximum number of batches to time; values <= 0 time nothing.
        loader: iterable of ``(data, target)`` batches. Defaults to the
            notebook-level ``test_loader``.
        run_device: device each batch is moved to. Defaults to the
            notebook-level ``device``.

    Returns:
        Total forward-pass time in seconds.
    """
    if loader is None:
        loader = test_loader
    if run_device is None:
        run_device = device

    model.eval()
    elapsed = 0.0

    # Disable autograd: otherwise float models record a graph during forward, an overhead
    # that is not part of inference and skews the comparison with the quantized model.
    with torch.inference_mode():
        # zip() stops once range() is exhausted, so no extra batch is loaded past num_batch.
        for _, (data, _target) in zip(range(num_batch), loader):
            data = data.to(run_device)
            start = time.perf_counter()
            model(data)
            elapsed += time.perf_counter() - start

    print('inference time: %.3f s' % (elapsed))
    return elapsed

In [None]:
NUM_BATCH: int = 100  # number of test batches timed per model in the benchmark cells below

In [None]:
run_benchmark(quantized_model, num_batch=NUM_BATCH);  # quantized model

In [None]:
run_benchmark(best_model, num_batch=NUM_BATCH);  # float model from best_model.pth.tar

In [None]:
run_benchmark(fine_model, num_batch=NUM_BATCH);  # float model from fine_model.pth.tar

In [None]:
run_benchmark(coarse_model, num_batch=NUM_BATCH);  # float model from coarse_model.pth.tar