In [1]:
import torch
import os
import matplotlib.pyplot as plt
import tqdm.notebook as tqdm
import numpy as np
import random
import time

np.random.seed(42)
random.seed(42)
torch.manual_seed(42)

<torch._C.Generator at 0x7f5bb2f3b410>

In [2]:
mnist_config = {
    'input_size': 400,
    'batch_size': 64,
    'lr': 0.001,
    'hidden_layers': 2,
    'hidden_size': 1024,
    'epochs': 2,
    'log_interval': 10, # log every 10 batches
    'output_size': 10
}

# mnist_config = {
#     'input_size': 784,
#     'batch_size': 64,
#     'lr': 0.001,
#     'hidden_layers': 2,
#     'hidden_size': 512,
#     'epochs': 2,
#     'log_interval': 10, # log every 10 batches
#     'output_size': 10
# }

sst_config = {
    'batch_size': 64,
    'lr': 0.001,
    'hidden_layers': 3,
    'hidden_size': 256,
    'epochs': 2,
    'log_interval': 10,
    'output_size': 2
}

# SST Dataset

In [3]:
class SSTDataset(torch.utils.data.Dataset):
    def __init__(self, data_dir, data_type='train', suffix='tsv', vocab=None, top_k=0, remove_stopwords=False):
        self.datapath = os.path.join(data_dir, data_type + '.' + suffix)
        self.data_type = data_type
        self.sentences = []
        self.labels = []

        with open(self.datapath) as data_f:
            for line in data_f:
                sentence, label = line.strip().split('\t')
                self.sentences.append(sentence)
                self.labels.append(label)
        self.sentences = self.sentences[1:]
        self.labels = self.labels[1:]
        if vocab is None:
            self.vocab = self.build_vocab()
        else:
            self.vocab = vocab
        if top_k > 0:
            self.vocab = self.vocab[:top_k]
        if remove_stopwords:
            self.vocab = self.remove_stopwords()
        self.word2idx = {word: idx for idx, word in enumerate(self.vocab)}
        self.idx2word = {idx: word for word, idx in self.word2idx.items()}
        
        # add OOV token
        if vocab is None:
            self.vocab.append('<OOV>')
            self.word2idx['<OOV>'] = len(self.vocab) - 1
            self.idx2word[len(self.vocab) - 1] = '<OOV>'
            
    def remove_stopwords(self):
        # remove stopwords obtained from nltk
        from nltk.corpus import stopwords
        stop_words = set(stopwords.words('english'))
        return [word for word in self.vocab if word not in stop_words]

    def __len__(self):
        return len(self.sentences)

    def one_hot_encode(self, word):
        vector = torch.zeros(len(self.vocab))
        if word in self.word2idx:
            vector[self.word2idx[word]] += 1
        else:
            vector[self.word2idx['<OOV>']] += 1
        return vector
    
    def __getitem__(self, idx):
        sentence = self.sentences[idx]
        label = self.labels[idx]
        
        encoded_sentence = torch.sum(torch.stack([self.one_hot_encode(word) for word in sentence.split()]), dim=0)
        encoded_label = torch.tensor(int(label)) # 0: negative, 1: positive
        return encoded_sentence, encoded_label

    def build_vocab(self):
        vocab = {}
        for sentence in self.sentences:
            for word in sentence.split():
                if word not in vocab:
                    vocab[word] = 1
                else:
                    vocab[word] += 1
        return sorted(vocab, key=vocab.get, reverse=True)
        

# MNIST Dataset

In [4]:
class MNISTDataset(torch.utils.data.Dataset):
    def __init__(self, data_dir, data_type='train', suffix='csv', transforms=None):
        self.datapath = os.path.join(data_dir, f'mnist_{data_type}.{suffix}')
        self.labels = []
        self.data = []
        self.transforms = transforms
        
        with open(self.datapath) as data_f:
            line_cnt = 0
            for line in data_f:
                if line_cnt > 0:
                    data = line.strip().split(',')
                    self.labels.append(data[0])
                    self.data.append(torch.tensor([float(dp) for dp in data[1:]]))
                line_cnt += 1
    
    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, idx):
        label = self.labels[idx]
        data = self.data[idx]
        # normalize the data
        data = (data - torch.mean(data)) / torch.std(data)
        if self.transforms is not None:
            # reconstruct to image and do transforms
            data = data.reshape(1, 28, 28)
            data = self.transforms(data)
            # flatten
            data = data.reshape(-1)
        return data, torch.tensor(int(label))

# Model

In [5]:
class MLP(torch.nn.Module):
    def __init__(self, input_size, hidden_size, num_hidden_layers, output_size):
        super(MLP, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.output_size = output_size
        self.total_flops = 0
        self.layers = torch.nn.ModuleList()
        for layer_idx in range(self.num_hidden_layers):
            if layer_idx == 0:
                self.layers.append(torch.nn.Linear(self.input_size, self.hidden_size))
                self.total_flops += self.input_size * self.hidden_size
            else:
                self.layers.append(torch.nn.Linear(self.hidden_size, self.hidden_size))
                self.total_flops += self.hidden_size * self.hidden_size
            self.layers.append(torch.nn.ReLU())
            self.total_flops += self.hidden_size
        self.layers.append(torch.nn.Linear(self.hidden_size, self.output_size))
        self.total_flops += self.hidden_size * self.output_size
        self.softmax = torch.nn.Softmax(dim=1)
        self.total_flops += self.output_size * np.log2(self.output_size)
    
    def forward(self, x):
        # x: (batch_size, input_size)
        for layer in self.layers:
            x = layer(x)
        x = self.softmax(x) 
        return x
    
    def count_parameters(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)


# Train

In [6]:
import os

import torch
import time
import tqdm
import numpy as np
import math


class Trainer:
    def __init__(self, 
                 model, 
                 train_dl, 
                 valid_dl, 
                 config, 
                 ckpt_path='ckpt', 
                 device='cuda', 
                 save_ckpts=False,
                 no_train=False):
        self.model = model.to(device)
        self.train_dl = train_dl
        self.valid_dl = valid_dl
        self.config = config
        self.no_train = no_train
        if not no_train:
            self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.config['lr'])
            self.criterion = torch.nn.CrossEntropyLoss().to(device)
        self.training_time = 0
        self.training_time_per_epoch = []
        self.total_inference_time = 0
        self.inference_time = []
        self.avg_inference_time_per_epoch = []
        self.ckpt_path = ckpt_path
        self.best_acc = 0
        self.device = device
        self.save_ckpt = save_ckpts

    def train(self):
        start = time.time()
        for epoch in range(self.config['epochs']):
            start_time = time.time()
            self.train_one_epoch()
            end_time = time.time()
            self.training_time_per_epoch.append(end_time - start_time)
            acc = self.validate()
            if acc > self.best_acc:
                self.best_acc = acc
                if self.save_ckpt:
                    self.save(epoch)
        end = time.time()
        self.training_time = end - start

    def train_one_epoch(self):
        batch_bar = tqdm.tqdm(total=len(self.train_dl), dynamic_ncols=True, leave=True, position=0, desc='Train',
                              ncols=5)
        num_correct = 0
        total_loss = 0
        total_inference_epoch = 0

        for idx, (train_data, train_label) in enumerate(self.train_dl):
            unit_num_correct, unit_total_loss, batch_inference_time = self.train_one_batch(train_data, train_label)
            total_inference_epoch += batch_inference_time
            num_correct += unit_num_correct
            total_loss += unit_total_loss
            batch_bar.set_postfix(
                acc=f"{100 * num_correct / (self.config['batch_size'] * (idx + 1)):.4f}",
                loss=f"{total_loss / (self.config['batch_size'] * (idx + 1)):.4f}",
                num_correct=f"{num_correct}",
                lr=f"{self.optimizer.param_groups[0]['lr']:.4f}"
            )
            batch_bar.update()

        self.avg_inference_time_per_epoch.append(total_inference_epoch / len(self.train_dl.dataset))
        batch_bar.close()

    def train_one_batch(self, train_data, train_label):
        train_data = train_data.to(self.device)
        train_label = train_label.to(self.device)
        self.optimizer.zero_grad()
        inference_start = time.time()
        output = self.model(train_data)
        inference_end = time.time()
        self.total_inference_time += inference_end - inference_start
        # l2 regularization
        # loss = self.criterion(output, train_label) - 0.001 * torch.norm(self.model.layers[0].weight, p=2)
        loss = self.criterion(output, train_label)
        num_correct = torch.sum(torch.argmax(output, dim=1) == train_label).item()
        total_loss = float(loss.item())
        loss.backward()
        self.optimizer.step()

        return num_correct, total_loss, inference_end - inference_start

    def validate(self):
        # report accuracy
        correct_count = 0
        valid_loss = 0
        num_samples = 0
        for idx, (valid_data, valid_label) in enumerate(self.valid_dl):
            num_samples += len(valid_data)
            valid_data = valid_data.to(self.device)
            valid_label = valid_label.to(self.device)
            output = self.model(valid_data)
            pred = torch.argmax(output, dim=1)  # (batch_size)
            correct_count += torch.sum(pred == valid_label).item()
            if not self.no_train:
                loss = self.criterion(output, valid_label).item()
                valid_loss += loss
        if not self.no_train:
            valid_loss /= num_samples
        acc = correct_count / len(self.valid_dl.dataset)
        print(f'\nValidation Loss: {valid_loss if not self.no_train else "n/a"}\t Validation Accuracy: {acc:.4f}')
        self.valid_acc = acc
        return acc
    
    def calculate_inference_latency(self, val_dl):
        # calculate inference latency per batch
        total_inference_time = 0
        num_of_batches = 0
        for idx, (valid_data, valid_label) in enumerate(val_dl):
            num_of_batches += 1
            valid_data = valid_data.to(self.device)
            inference_start = time.time()
            _ = self.model(valid_data)
            inference_end = time.time()
            total_inference_time += inference_end - inference_start
        return total_inference_time / num_of_batches

    def save(self, epoch):
        ckpt_dir = os.path.join(self.ckpt_path, f'{time.strftime("%m-%d-%H-%M", time.localtime())}')
        if not os.path.exists(ckpt_dir):
            os.makedirs(ckpt_dir)
        print(f'Saving model at Epoch {epoch}')
        ckpt_path = os.path.join(ckpt_dir, f'best_acc.pt')
        torch.save(self.model.state_dict(), ckpt_path)

    def get_training_stats(self):
        # we need flops, avg inference per sample and accuracy
        return {
            'num_parameters': self.model.param_count,
            'num_flops': self.model.flops,
            'training_time': self.training_time,
            'training_time_per_epoch': np.mean(self.training_time_per_epoch),
            'best_acc': self.best_acc,
            'average_inference_time_per_sample': self.total_inference_time / (self.config['epochs'] * len(self.train_dl.dataset)),
        }
    def report(self):
        print(f'Number of parameters: {self.model.count_parameters()}')
        print(f'Total number of FLOPs: {self.model.total_flops}')
        print(f'Average training time per epoch: {np.mean(self.training_time_per_epoch)} seconds')
        print(f'Training time: {self.training_time} seconds')
        print(
            f'Average inference time per sample: {self.total_inference_time / (self.config["epochs"] * len(self.train_dl.dataset))} seconds')


# Baseline

In [7]:
# dataloaders
from torchvision import transforms

sst_train_dataset = SSTDataset(data_dir='SST-2', data_type='train', top_k=5000)
sst_dev_dataset = SSTDataset(data_dir='SST-2', data_type='dev', vocab=sst_train_dataset.vocab)
mnist_train_dataset = MNISTDataset(data_dir='MNIST', data_type='train', transforms=transforms.CenterCrop((20, 20)))
mnist_valid_dataset = MNISTDataset(data_dir='MNIST', data_type='test', transforms=transforms.CenterCrop((20, 20)))


sst_dataloader = torch.utils.data.DataLoader(
    dataset=sst_train_dataset,
    batch_size=sst_config['batch_size'],
    shuffle=True
)
sst_dev_dataloader = torch.utils.data.DataLoader(
    dataset=sst_dev_dataset,
    batch_size=sst_config['batch_size'],
    shuffle=False
)
mnist_dataloader = torch.utils.data.DataLoader(dataset=mnist_train_dataset,
                                               batch_size=mnist_config['batch_size'],
                                               shuffle=True)
mnist_valid_dataloader = torch.utils.data.DataLoader(dataset=mnist_valid_dataset,
                                                     batch_size=mnist_config['batch_size'],
                                                     shuffle=False)

In [8]:
sst_train_dataset[0][0].shape

torch.Size([5001])

In [9]:
len(mnist_train_dataset)

59999

In [10]:
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
device = 'cpu'

In [11]:
# MNIST model
mnist_model = MLP(input_size=mnist_config['input_size'], hidden_size=mnist_config['hidden_size'], num_hidden_layers=mnist_config['hidden_layers'], output_size=mnist_config['output_size'])

mnist_trainer = Trainer(
    model=mnist_model,
    train_dl=mnist_dataloader,
    valid_dl=mnist_valid_dataloader,
    config=mnist_config,
    device=device
)

mnist_trainer.train()
mnist_trainer.report()

Train: 100%|██████████| 938/938 [00:14<00:00, 62.94it/s, acc=80.2322, loss=0.0259, lr=0.0010, num_correct=48165]



Validation Loss: 0.025196940413188047	 Validation Accuracy: 0.8553


Train: 100%|██████████| 938/938 [00:14<00:00, 65.64it/s, acc=87.5999, loss=0.0248, lr=0.0010, num_correct=52588]



Validation Loss: 0.023892757439329594	 Validation Accuracy: 0.9386
Number of parameters: 1470474
Total number of FLOPs: 1470497.2192809489
Average training time per epoch: 14.599167823791504 seconds
Training time: 30.736724138259888 seconds
Average inference time per sample: 2.0052999688692118e-05 seconds


In [12]:
# SST model
sst_model = MLP(input_size=len(sst_train_dataset.vocab), 
                hidden_size=sst_config['hidden_size'], 
                num_hidden_layers=sst_config['hidden_layers'], 
                output_size=sst_config['output_size'])

sst_trainer = Trainer(
    model=sst_model,
    train_dl=sst_dataloader,
    valid_dl=sst_dev_dataloader,
    config=sst_config,
    device=device
)
sst_trainer.train()

Train: 100%|██████████| 1053/1053 [00:21<00:00, 49.84it/s, acc=82.3718, loss=0.0074, lr=0.0010, num_correct=55512]



Validation Loss: 0.007868469444983596	 Validation Accuracy: 0.8257


Train: 100%|██████████| 1053/1053 [00:21<00:00, 48.83it/s, acc=89.6160, loss=0.0064, lr=0.0010, num_correct=60394]



Validation Loss: 0.007981193721841235	 Validation Accuracy: 0.8119


In [13]:
# get stats for baseline, 
#   - inference latency for batch size 1 and 64 over 5 runs
#   - model size
#   - parameter counts

def print_size_of_model(model):
    torch.save(model.state_dict(), "temp.p")
    size=os.path.getsize("temp.p")
    # print("model: ",label,' \t','Size (MB):', size/1e6)
    os.remove('temp.p')
    return size

def count_trainable_parameters(quantized_mode):
    total_trainable_params = 0
    for module in quantized_mode.modules():
        if hasattr(module, '_packed_params'):
            if isinstance(module, torch.nn.quantized.dynamic.Linear):
                weight, bias = module._packed_params.unpack()
                
                if weight.requires_grad:
                    total_trainable_params += weight.numel()
                if bias is not None and bias.requires_grad:
                    total_trainable_params += bias.numel()
    return total_trainable_params

def get_stats(num_runs,
              model,
              trainer,
              val_ds,
              model_type):
    stats = {}
    stats['model_name'] = model_type

    stats['acc'] = trainer.valid_acc
    if model_type.startswith('mnist'):
        print(f'model name: {model_type}')
        for batch_size in [64, 1]:
            print(f'Running for batch size {batch_size}')
            avg_inference_latency = []
            for _ in range(num_runs):
                val_dl = torch.utils.data.DataLoader(
                    dataset=val_ds,
                    batch_size=batch_size,
                    shuffle=False
                )
                stats['model_size_mb'] = print_size_of_model(model) / 1e6
                stats['parameter_count'] = sum(p.numel() for p in model.parameters() if p.requires_grad)
                avg_inference_latency.append(trainer.calculate_inference_latency(val_dl))
            avg_inference_latency, std_inference_latency = np.mean(avg_inference_latency), np.std(avg_inference_latency)
            stats[f'avg_inference_latency_{batch_size}'] = avg_inference_latency
            stats[f'std_inference_latency_{batch_size}'] = std_inference_latency
    elif model_type.startswith('sst'):
        print(f'model name: {model_type}')
        for batch_size in [64, 1]:
            print(f'Running for batch size {batch_size}')
            avg_inference_latency = []
            for _ in range(num_runs):
                val_dl = torch.utils.data.DataLoader(
                    dataset=val_ds,
                    batch_size=batch_size,
                    shuffle=False
                )
                stats['model_size_mb'] = print_size_of_model(model) / 1e6
                stats['parameter_count'] = sum(p.numel() for p in model.parameters() if p.requires_grad)
                avg_inference_latency.append(trainer.calculate_inference_latency(val_dl))
            avg_inference_latency, std_inference_latency = np.mean(avg_inference_latency), np.std(avg_inference_latency)
            stats[f'avg_inference_latency_{batch_size}'] = avg_inference_latency
            stats[f'std_inference_latency_{batch_size}'] = std_inference_latency
        
    return stats

In [14]:
mnist_stats = get_stats(5, mnist_model, mnist_trainer, mnist_valid_dataset, 'mnist')
print(mnist_stats)

model name: mnist
Running for batch size 64
Running for batch size 1
{'model_name': 'mnist', 'acc': 0.9385938593859386, 'model_size_mb': 5.883839, 'parameter_count': 1470474, 'avg_inference_latency_64': 0.0013329408730670905, 'std_inference_latency_64': 0.00015313183800845417, 'avg_inference_latency_1': 0.0001267622632853495, 'std_inference_latency_1': 7.931996021353464e-06}


In [15]:
sst_stats = get_stats(5, sst_model, sst_trainer, sst_dev_dataset, 'sst')
print(sst_stats)

model name: sst
Running for batch size 64
Running for batch size 1
{'model_name': 'sst', 'acc': 0.8119266055045872, 'model_size_mb': 5.652849, 'parameter_count': 1412610, 'avg_inference_latency_64': 0.0010027953556605747, 'std_inference_latency_64': 3.19563286115864e-05, 'avg_inference_latency_1': 0.00015513984435195223, 'std_inference_latency_1': 3.554871816429752e-06}


In [16]:
import pandas as pd

df = pd.DataFrame([mnist_stats, sst_stats])

In [17]:
df.to_csv('baseline.csv', index=False)

# Dynamic Quantization with PyTorch

In [18]:
import torch.quantization
print(torch.backends.quantized.supported_engines)
torch.backends.quantized.engine = 'x86'


['qnnpack', 'none', 'onednn', 'x86', 'fbgemm']


In [23]:
# MNIST model
quantized_mnist_model_8 = torch.quantization.quantize_dynamic(
    mnist_model, {torch.nn.Linear}, dtype=torch.qint8
)

quantized_mnist_model_16 = torch.quantization.quantize_dynamic(
    mnist_model, {torch.nn.Linear}, dtype=torch.float16
)

mnist_trainer_quantized_8 = Trainer(
    model=quantized_mnist_model_8,
    train_dl=mnist_dataloader,
    valid_dl=mnist_valid_dataloader,
    config=mnist_config,
    device=device,
    no_train=True
)
mnist_trainer_quantized_8.validate()

mnist_trainer_quantized_16 = Trainer(
    model=quantized_mnist_model_16,
    train_dl=mnist_dataloader,
    valid_dl=mnist_valid_dataloader,
    config=mnist_config,
    device=device,
    no_train=True
)
mnist_trainer_quantized_16.validate()

mnist_eight_bit_stats = get_stats(5, quantized_mnist_model_8, mnist_trainer_quantized_8, mnist_valid_dataset, 'mnist_8bit')
print(mnist_eight_bit_stats)

mnist_sixteen_bit_stats = get_stats(5, quantized_mnist_model_16, mnist_trainer_quantized_16, mnist_valid_dataset, 'mnist_16bit')
print(mnist_sixteen_bit_stats)


Validation Loss: n/a	 Validation Accuracy: 0.9386

Validation Loss: n/a	 Validation Accuracy: 0.9386
model name: mnist_8bit
Running for batch size 64
Running for batch size 1
{'model_name': 'mnist_8bit', 'acc': 0.9385938593859386, 'model_size_mb': 1.480535, 'parameter_count': 0, 'avg_inference_latency_64': 0.0005661390389606452, 'std_inference_latency_64': 6.417420069453933e-05, 'avg_inference_latency_1': 0.00015660790589251797, 'std_inference_latency_1': 3.47960730864735e-06}
model name: mnist_16bit
Running for batch size 64
Running for batch size 1
{'model_name': 'mnist_16bit', 'acc': 0.9385938593859386, 'model_size_mb': 5.885655, 'parameter_count': 0, 'avg_inference_latency_64': 0.000959180722570723, 'std_inference_latency_64': 0.00012071004421095663, 'avg_inference_latency_1': 0.0001696667488556717, 'std_inference_latency_1': 1.0326271385712067e-05}


In [24]:
# do the same for sst

quantized_sst_model_8 = torch.quantization.quantize_dynamic(
    sst_model, {torch.nn.Linear}, dtype=torch.qint8
)

quantized_sst_model_16 = torch.quantization.quantize_dynamic(
    sst_model, {torch.nn.Linear}, dtype=torch.float16
)

sst_trainer_quantized_8 = Trainer(
    model=quantized_sst_model_8,
    train_dl=sst_dataloader,
    valid_dl=sst_dev_dataloader,
    config=sst_config,
    device=device,
    no_train=True
)
sst_trainer_quantized_8.validate()

sst_trainer_quantized_16 = Trainer(
    model=quantized_sst_model_16,
    train_dl=sst_dataloader,
    valid_dl=sst_dev_dataloader,
    config=sst_config,
    device=device,
    no_train=True
)
sst_trainer_quantized_16.validate()

sst_eight_bit_stats = get_stats(5, quantized_sst_model_8, sst_trainer_quantized_8, sst_dev_dataset, 'sst_8bit')
print(sst_eight_bit_stats)

sst_sixteen_bit_stats = get_stats(5, quantized_sst_model_16, sst_trainer_quantized_16, sst_dev_dataset, 'sst_16bit')
print(sst_sixteen_bit_stats)


Validation Loss: n/a	 Validation Accuracy: 0.8096

Validation Loss: n/a	 Validation Accuracy: 0.8119
model name: sst_8bit
Running for batch size 64
Running for batch size 1
{'model_name': 'sst_8bit', 'acc': 0.8096330275229358, 'model_size_mb': 1.419903, 'parameter_count': 0, 'avg_inference_latency_64': 0.0008473566600254603, 'std_inference_latency_64': 8.789052910645608e-05, 'avg_inference_latency_1': 0.00020241206939067315, 'std_inference_latency_1': 1.8937977372946463e-05}
model name: sst_16bit
Running for batch size 64
Running for batch size 1
{'model_name': 'sst_16bit', 'acc': 0.8119266055045872, 'model_size_mb': 5.655295, 'parameter_count': 0, 'avg_inference_latency_64': 0.0010310615812029159, 'std_inference_latency_64': 3.63134646375152e-05, 'avg_inference_latency_1': 0.0002176146988474995, 'std_inference_latency_1': 5.180847951934833e-06}


In [25]:
# export to csv

dynamic_quantization_df = pd.DataFrame([mnist_eight_bit_stats, mnist_sixteen_bit_stats, sst_eight_bit_stats, sst_sixteen_bit_stats])
# add baseline stats
dynamic_quantization_df = pd.concat([df, dynamic_quantization_df], axis=0)
dynamic_quantization_df.to_csv('dynamic_quantization.csv', index=False)

# Static Quantization

In [26]:
# static quantization with prepare_fx

# do the same for sst
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx
from torch.ao.quantization import get_default_qconfig_mapping

qconfig_mapping = get_default_qconfig_mapping("fbgemm")
example_inputs = torch.randn(5001)
quantized_sst_model_8 = convert_fx(prepare_fx(sst_model, qconfig_mapping, example_inputs=example_inputs))

sst_trainer_quantized_8 = Trainer(
    model=quantized_sst_model_8,
    train_dl=sst_dataloader,
    valid_dl=sst_dev_dataloader,
    config=sst_config,
    device=device,
    no_train=True
)
sst_trainer_quantized_8.validate()


sst_eight_bit_stats = get_stats(5, quantized_sst_model_8, sst_trainer_quantized_8, sst_dev_dataset, 'sst_8bit_static')
print(sst_eight_bit_stats)




Validation Loss: n/a	 Validation Accuracy: 0.4954
model name: sst_8bit_static
Running for batch size 64
Running for batch size 1
{'model_name': 'sst_8bit_static', 'acc': 0.4954128440366973, 'model_size_mb': 1.434179, 'parameter_count': 0, 'avg_inference_latency_64': 0.0008141211100987025, 'std_inference_latency_64': 0.00011938487696745565, 'avg_inference_latency_1': 0.00019600331236463076, 'std_inference_latency_1': 2.0684235634766287e-05}


In [27]:
# do the same to mnist

quantized_mnist_model_8 = convert_fx(prepare_fx(mnist_model, qconfig_mapping, example_inputs=example_inputs))

mnist_trainer_quantized_8 = Trainer(
    model=quantized_mnist_model_8,
    train_dl=mnist_dataloader,
    valid_dl=mnist_valid_dataloader,
    config=mnist_config,
    device=device,
    no_train=True
)
mnist_trainer_quantized_8.validate()

mnist_eight_bit_stats = get_stats(5, quantized_mnist_model_8, mnist_trainer_quantized_8, mnist_valid_dataset, 'mnist_8bit_static')
print(mnist_eight_bit_stats)




Validation Loss: n/a	 Validation Accuracy: 0.9091
model name: mnist_8bit_static
Running for batch size 64
Running for batch size 1
{'model_name': 'mnist_8bit_static', 'acc': 0.9090909090909091, 'model_size_mb': 1.515047, 'parameter_count': 0, 'avg_inference_latency_64': 0.0005392086733678344, 'std_inference_latency_64': 0.0001448632877445155, 'avg_inference_latency_1': 0.0001490873412521306, 'std_inference_latency_1': 4.263607324251732e-06}


In [28]:
# export to csv
static_quantization_df = pd.DataFrame([mnist_eight_bit_stats, sst_eight_bit_stats])
static_quantization_df = pd.concat([dynamic_quantization_df, static_quantization_df], axis=0)

In [29]:
static_quantization_df.to_csv('static_quantization.csv', index=False)