In [42]:
import torch
import os
import matplotlib.pyplot as plt
import tqdm.notebook as tqdm
import numpy as np
import random
import time

np.random.seed(42)
random.seed(42)
torch.manual_seed(42)

<torch._C.Generator at 0x1089f5d50>

In [43]:
mnist_config = {
    'input_size': 400,
    'batch_size': 64,
    'lr': 0.001,
    'hidden_layers': 2,
    'hidden_size': 1024,
    'epochs': 2,
    'log_interval': 10, # log every 10 batches
    'output_size': 10
}

# mnist_config = {
#     'input_size': 784,
#     'batch_size': 64,
#     'lr': 0.001,
#     'hidden_layers': 2,
#     'hidden_size': 512,
#     'epochs': 2,
#     'log_interval': 10, # log every 10 batches
#     'output_size': 10
# }

sst_config = {
    'batch_size': 64,
    'lr': 0.001,
    'hidden_layers': 3,
    'hidden_size': 256,
    'epochs': 2,
    'log_interval': 10,
    'output_size': 2
}

# SST Dataset

In [44]:
class SSTDataset(torch.utils.data.Dataset):
    def __init__(self, data_dir, data_type='train', suffix='tsv', vocab=None, top_k=0, remove_stopwords=False):
        self.datapath = os.path.join(data_dir, data_type + '.' + suffix)
        self.data_type = data_type
        self.sentences = []
        self.labels = []

        with open(self.datapath) as data_f:
            for line in data_f:
                sentence, label = line.strip().split('\t')
                self.sentences.append(sentence)
                self.labels.append(label)
        self.sentences = self.sentences[1:]
        self.labels = self.labels[1:]
        if vocab is None:
            self.vocab = self.build_vocab()
        else:
            self.vocab = vocab
        if top_k > 0:
            self.vocab = self.vocab[:top_k]
        if remove_stopwords:
            self.vocab = self.remove_stopwords()
        self.word2idx = {word: idx for idx, word in enumerate(self.vocab)}
        self.idx2word = {idx: word for word, idx in self.word2idx.items()}
        
        # add OOV token
        if vocab is None:
            self.vocab.append('<OOV>')
            self.word2idx['<OOV>'] = len(self.vocab) - 1
            self.idx2word[len(self.vocab) - 1] = '<OOV>'
            
    def remove_stopwords(self):
        # remove stopwords obtained from nltk
        from nltk.corpus import stopwords
        stop_words = set(stopwords.words('english'))
        return [word for word in self.vocab if word not in stop_words]

    def __len__(self):
        return len(self.sentences)

    def one_hot_encode(self, word):
        vector = torch.zeros(len(self.vocab))
        if word in self.word2idx:
            vector[self.word2idx[word]] += 1
        else:
            vector[self.word2idx['<OOV>']] += 1
        return vector
    
    def __getitem__(self, idx):
        sentence = self.sentences[idx]
        label = self.labels[idx]
        
        encoded_sentence = torch.sum(torch.stack([self.one_hot_encode(word) for word in sentence.split()]), dim=0)
        encoded_label = torch.tensor(int(label)) # 0: negative, 1: positive
        return encoded_sentence, encoded_label

    def build_vocab(self):
        vocab = {}
        for sentence in self.sentences:
            for word in sentence.split():
                if word not in vocab:
                    vocab[word] = 1
                else:
                    vocab[word] += 1
        return sorted(vocab, key=vocab.get, reverse=True)
        

# MNIST Dataset

In [45]:
class MNISTDataset(torch.utils.data.Dataset):
    def __init__(self, data_dir, data_type='train', suffix='csv', transforms=None):
        self.datapath = os.path.join(data_dir, f'mnist_{data_type}.{suffix}')
        self.labels = []
        self.data = []
        self.transforms = transforms
        
        with open(self.datapath) as data_f:
            line_cnt = 0
            for line in data_f:
                if line_cnt > 0:
                    data = line.strip().split(',')
                    self.labels.append(data[0])
                    self.data.append(torch.tensor([float(dp) for dp in data[1:]]))
                line_cnt += 1
    
    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, idx):
        label = self.labels[idx]
        data = self.data[idx]
        # normalize the data
        data = (data - torch.mean(data)) / torch.std(data)
        if self.transforms is not None:
            # reconstruct to image and do transforms
            data = data.reshape(1, 28, 28)
            data = self.transforms(data)
            # flatten
            data = data.reshape(-1)
        return data, torch.tensor(int(label))

# Model

In [46]:
class MLP(torch.nn.Module):
    def __init__(self, input_size, hidden_size, num_hidden_layers, output_size):
        super(MLP, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.output_size = output_size
        self.total_flops = 0
        self.layers = torch.nn.ModuleList()
        for layer_idx in range(self.num_hidden_layers):
            if layer_idx == 0:
                self.layers.append(torch.nn.Linear(self.input_size, self.hidden_size))
                self.total_flops += self.input_size * self.hidden_size
            else:
                self.layers.append(torch.nn.Linear(self.hidden_size, self.hidden_size))
                self.total_flops += self.hidden_size * self.hidden_size
            self.layers.append(torch.nn.ReLU())
            self.total_flops += self.hidden_size
        self.layers.append(torch.nn.Linear(self.hidden_size, self.output_size))
        self.total_flops += self.hidden_size * self.output_size
        self.softmax = torch.nn.Softmax(dim=1)
        self.total_flops += self.output_size * np.log2(self.output_size)
    
    def forward(self, x):
        # x: (batch_size, input_size)
        for layer in self.layers:
            x = layer(x)
        x = self.softmax(x) 
        return x
    
    def count_parameters(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)


# Train

In [117]:
import os

import torch
import time
import tqdm
import numpy as np
import math


class Trainer:
    def __init__(self, 
                 model, 
                 train_dl, 
                 valid_dl, 
                 config, 
                 ckpt_path='ckpt', 
                 device='cuda', 
                 save_ckpts=False,
                 no_train=False):
        self.model = model.to(device)
        self.train_dl = train_dl
        self.valid_dl = valid_dl
        self.config = config
        self.no_train = no_train
        if not no_train:
            self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.config['lr'])
            self.criterion = torch.nn.CrossEntropyLoss().to(device)
        self.training_time = 0
        self.training_time_per_epoch = []
        self.total_inference_time = 0
        self.inference_time = []
        self.avg_inference_time_per_epoch = []
        self.ckpt_path = ckpt_path
        self.best_acc = 0
        self.device = device
        self.save_ckpt = save_ckpts

    def train(self):
        start = time.time()
        for epoch in range(self.config['epochs']):
            start_time = time.time()
            self.train_one_epoch()
            end_time = time.time()
            self.training_time_per_epoch.append(end_time - start_time)
            acc = self.validate()
            if acc > self.best_acc:
                self.best_acc = acc
                if self.save_ckpt:
                    self.save(epoch)
        end = time.time()
        self.training_time = end - start

    def train_one_epoch(self):
        batch_bar = tqdm.tqdm(total=len(self.train_dl), dynamic_ncols=True, leave=True, position=0, desc='Train',
                              ncols=5)
        num_correct = 0
        total_loss = 0
        total_inference_epoch = 0

        for idx, (train_data, train_label) in enumerate(self.train_dl):
            unit_num_correct, unit_total_loss, batch_inference_time = self.train_one_batch(train_data, train_label)
            total_inference_epoch += batch_inference_time
            num_correct += unit_num_correct
            total_loss += unit_total_loss
            batch_bar.set_postfix(
                acc=f"{100 * num_correct / (self.config['batch_size'] * (idx + 1)):.4f}",
                loss=f"{total_loss / (self.config['batch_size'] * (idx + 1)):.4f}",
                num_correct=f"{num_correct}",
                lr=f"{self.optimizer.param_groups[0]['lr']:.4f}"
            )
            batch_bar.update()

        self.avg_inference_time_per_epoch.append(total_inference_epoch / len(self.train_dl.dataset))
        batch_bar.close()

    def train_one_batch(self, train_data, train_label):
        train_data = train_data.to(self.device)
        train_label = train_label.to(self.device)
        self.optimizer.zero_grad()
        inference_start = time.time()
        output = self.model(train_data)
        inference_end = time.time()
        self.total_inference_time += inference_end - inference_start
        # l2 regularization
        # loss = self.criterion(output, train_label) - 0.001 * torch.norm(self.model.layers[0].weight, p=2)
        loss = self.criterion(output, train_label)
        num_correct = torch.sum(torch.argmax(output, dim=1) == train_label).item()
        total_loss = float(loss.item())
        loss.backward()
        self.optimizer.step()

        return num_correct, total_loss, inference_end - inference_start

    def validate(self):
        # report accuracy
        correct_count = 0
        valid_loss = 0
        num_samples = 0
        for idx, (valid_data, valid_label) in enumerate(self.valid_dl):
            num_samples += len(valid_data)
            valid_data = valid_data.to(self.device)
            valid_label = valid_label.to(self.device)
            output = self.model(valid_data)
            pred = torch.argmax(output, dim=1)  # (batch_size)
            correct_count += torch.sum(pred == valid_label).item()
            if not self.no_train:
                loss = self.criterion(output, valid_label).item()
                valid_loss += loss
        if not self.no_train:
            valid_loss /= num_samples
        acc = correct_count / len(self.valid_dl.dataset)
        print(f'\nValidation Loss: {valid_loss if not self.no_train else "n/a"}\t Validation Accuracy: {correct_count / len(self.valid_dl.dataset)}')
        return acc
    
    def calculate_inference_latency(self, val_dl):
        # calculate inference latency per batch
        total_inference_time = 0
        num_of_batches = 0
        for idx, (valid_data, valid_label) in enumerate(val_dl):
            num_of_batches += 1
            valid_data = valid_data.to(self.device)
            inference_start = time.time()
            _ = self.model(valid_data)
            inference_end = time.time()
            total_inference_time += inference_end - inference_start
        return total_inference_time / num_of_batches

    def save(self, epoch):
        ckpt_dir = os.path.join(self.ckpt_path, f'{time.strftime("%m-%d-%H-%M", time.localtime())}')
        if not os.path.exists(ckpt_dir):
            os.makedirs(ckpt_dir)
        print(f'Saving model at Epoch {epoch}')
        ckpt_path = os.path.join(ckpt_dir, f'best_acc.pt')
        torch.save(self.model.state_dict(), ckpt_path)

    def get_training_stats(self):
        # we need flops, avg inference per sample and accuracy
        return {
            'num_parameters': self.model.param_count,
            'num_flops': self.model.flops,
            'training_time': self.training_time,
            'training_time_per_epoch': np.mean(self.training_time_per_epoch),
            'best_acc': self.best_acc,
            'average_inference_time_per_sample': self.total_inference_time / (self.config['epochs'] * len(self.train_dl.dataset)),
        }
    def report(self):
        print(f'Number of parameters: {self.model.count_parameters()}')
        print(f'Total number of FLOPs: {self.model.total_flops}')
        print(f'Average training time per epoch: {np.mean(self.training_time_per_epoch)} seconds')
        print(f'Training time: {self.training_time} seconds')
        print(
            f'Average inference time per sample: {self.total_inference_time / (self.config["epochs"] * len(self.train_dl.dataset))} seconds')


# Baseline

In [52]:
# dataloaders
from torchvision import transforms

sst_train_dataset = SSTDataset(data_dir='SST-2', data_type='train', top_k=5000)
sst_dev_dataset = SSTDataset(data_dir='SST-2', data_type='dev', vocab=sst_train_dataset.vocab)
mnist_train_dataset = MNISTDataset(data_dir='MNIST', data_type='train', transforms=transforms.CenterCrop((20, 20)))
mnist_valid_dataset = MNISTDataset(data_dir='MNIST', data_type='test', transforms=transforms.CenterCrop((20, 20)))


sst_dataloader = torch.utils.data.DataLoader(
    dataset=sst_train_dataset,
    batch_size=sst_config['batch_size'],
    shuffle=True
)
sst_dev_dataloader = torch.utils.data.DataLoader(
    dataset=sst_dev_dataset,
    batch_size=sst_config['batch_size'],
    shuffle=False
)
mnist_dataloader = torch.utils.data.DataLoader(dataset=mnist_train_dataset,
                                               batch_size=mnist_config['batch_size'],
                                               shuffle=True)
mnist_valid_dataloader = torch.utils.data.DataLoader(dataset=mnist_valid_dataset,
                                                     batch_size=mnist_config['batch_size'],
                                                     shuffle=False)

In [53]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [54]:
# MNIST model
mnist_model = MLP(input_size=mnist_config['input_size'], hidden_size=mnist_config['hidden_size'], num_hidden_layers=mnist_config['hidden_layers'], output_size=mnist_config['output_size'])

mnist_trainer = Trainer(
    model=mnist_model,
    train_dl=mnist_dataloader,
    valid_dl=mnist_valid_dataloader,
    config=mnist_config,
    device=device
)

mnist_trainer.train()
mnist_trainer.report()

Train: 100%|██████████| 938/938 [00:06<00:00, 138.93it/s, acc=90.5967, loss=0.0243, lr=0.0010, num_correct=54387]



Validation Loss: 0.023860583079792354	 Validation Accuracy: 0.9411941194119412


Train: 100%|██████████| 938/938 [00:06<00:00, 138.61it/s, acc=93.7450, loss=0.0238, lr=0.0010, num_correct=56277]



Validation Loss: 0.023877536336092107	 Validation Accuracy: 0.93999399939994
Number of parameters: 1470474
Total number of FLOPs: 1470497.2192809489
Average training time per epoch: 6.760261416435242 seconds
Training time: 14.094148874282837 seconds
Average inference time per sample: 1.3637519677095937e-05 seconds


In [55]:
# SST model
sst_model = MLP(input_size=len(sst_train_dataset.vocab), 
                hidden_size=sst_config['hidden_size'], 
                num_hidden_layers=sst_config['hidden_layers'], 
                output_size=sst_config['output_size'])

sst_trainer = Trainer(
    model=sst_model,
    train_dl=sst_dataloader,
    valid_dl=sst_dev_dataloader,
    config=sst_config,
    device=device
)
sst_trainer.train()

Train: 100%|██████████| 1053/1053 [00:10<00:00, 99.11it/s, acc=82.3392, loss=0.0075, lr=0.0010, num_correct=55490] 



Validation Loss: 0.008014699495440229	 Validation Accuracy: 0.8107798165137615


Train: 100%|██████████| 1053/1053 [00:10<00:00, 99.00it/s, acc=89.5878, loss=0.0064, lr=0.0010, num_correct=60375]



Validation Loss: 0.007898959681528425	 Validation Accuracy: 0.819954128440367


In [89]:
# get stats for baseline, 
#   - inference latency for batch size 1 and 64 over 5 runs
#   - model size
#   - parameter counts

def print_size_of_model(model):
    torch.save(model.state_dict(), "temp.p")
    size=os.path.getsize("temp.p")
    # print("model: ",label,' \t','Size (MB):', size/1e6)
    os.remove('temp.p')
    return size

def get_stats(num_runs,
              model,
              trainer,
              val_ds,
              model_type):
    stats = {}
    stats['model_name'] = model_type

    stats['acc'] = trainer.best_acc
    if model_type.startswith('mnist'):
        print(f'model name: {model_type}')
        for batch_size in [64, 1]:
            print(f'Running for batch size {batch_size}')
            avg_inference_latency = []
            for _ in range(num_runs):
                val_dl = torch.utils.data.DataLoader(
                    dataset=val_ds,
                    batch_size=batch_size,
                    shuffle=False
                )
                stats['model_size_mb'] = print_size_of_model(model) / 1e6
                stats['parameter_count'] = model.count_parameters()
                avg_inference_latency.append(trainer.calculate_inference_latency(val_dl))
            avg_inference_latency, std_inference_latency = np.mean(avg_inference_latency), np.std(avg_inference_latency)
            stats[f'avg_inference_latency_{batch_size}'] = avg_inference_latency
            stats[f'std_inference_latency_{batch_size}'] = std_inference_latency
    elif model_type.startswith('sst'):
        print(f'model name: {model_type}')
        for batch_size in [64, 1]:
            print(f'Running for batch size {batch_size}')
            avg_inference_latency = []
            for _ in range(num_runs):
                val_dl = torch.utils.data.DataLoader(
                    dataset=val_ds,
                    batch_size=batch_size,
                    shuffle=False
                )
                stats['model_size'] = print_size_of_model(model) / 1e6
                stats['parameter_count'] = model.count_parameters()
                avg_inference_latency.append(trainer.calculate_inference_latency(val_dl))
            avg_inference_latency, std_inference_latency = np.mean(avg_inference_latency), np.std(avg_inference_latency)
            stats[f'avg_inference_latency_{batch_size}'] = avg_inference_latency
            stats[f'std_inference_latency_{batch_size}'] = std_inference_latency
        
    return stats

In [90]:
mnist_stats = get_stats(5, mnist_model, mnist_trainer, mnist_valid_dataset, 'mnist')
print(mnist_stats)

model name: mnist
Running for batch size 64
Running for batch size 1
{'model_name': 'mnist', 'acc': 0.9411941194119412, 'model_size_mb': 5.883839, 'parameter_count': 1470474, 'avg_inference_latency_64': 0.0005422801728461199, 'std_inference_latency_64': 0.00011336873448196703, 'avg_inference_latency_1': 6.376058653076002e-05, 'std_inference_latency_1': 5.088073975139941e-06}


In [91]:
sst_stats = get_stats(5, sst_model, sst_trainer, sst_dev_dataset, 'sst')
print(sst_stats)

model name: sst
Running for batch size 64
Running for batch size 1
{'model_name': 'sst', 'acc': 0.819954128440367, 'model_size': 5.652849, 'parameter_count': 1412610, 'avg_inference_latency_64': 0.0006016629082815988, 'std_inference_latency_64': 2.821245048805144e-05, 'avg_inference_latency_1': 7.584997273366385e-05, 'std_inference_latency_1': 4.373535918128275e-07}


In [92]:
import pandas as pd

df = pd.DataFrame([mnist_stats, sst_stats])

In [93]:
df.to_csv('baseline.csv', index=False)

# Dynamic Quantization with PyTorch

In [123]:
import torch.quantization
print(torch.backends.quantized.supported_engines)
torch.backends.quantized.engine = 'qnnpack'


['qnnpack', 'none']


In [124]:
# MNIST model
mnist_model = MLP(input_size=mnist_config['input_size'], hidden_size=mnist_config['hidden_size'], num_hidden_layers=mnist_config['hidden_layers'], output_size=mnist_config['output_size'])

quantized_mnist_model_8 = torch.quantization.quantize_dynamic(
    mnist_model, {torch.nn.Linear}, dtype=torch.qint8
)

quantized_mnist_model_16 = torch.quantization.quantize_dynamic(
    mnist_model, {torch.nn.Linear}, dtype=torch.float16
)


RuntimeError: quantized::linear_prepack_fp16 is currently not supported by QNNPACK

In [118]:
mnist_trainer_quantized_8 = Trainer(
    model=quantized_mnist_model_8,
    train_dl=mnist_dataloader,
    valid_dl=mnist_valid_dataloader,
    config=mnist_config,
    device=device,
    no_train=True
)
mnist_trainer_quantized_8.validate()


Validation Loss: n/a	 Validation Accuracy: 0.1338133813381338


0.1338133813381338

In [110]:
eight_bit_stats = get_stats(5, quantized_mnist_model_8, mnist_trainer, mnist_valid_dataset, 'mnist_8bit')
print(eight_bit_stats)

ValueError: optimizer got an empty parameter list