# MNIST Low Precision Training Example
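In this notebook, we present a quick example of how to simulate low-precision training of a deep neural network with QPyTorch. We use the (very small) MNIST data set because a model can be trained on it in about 10 minutes on a laptop; note that simulated low-precision training is several times slower than plain floating-point training, as the timings below show.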

## 1. Training MNIST in Floating Point

In [1]:
# import useful modules
import argparse
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from qtorch.quant import Quantizer
from qtorch.optim import OptimLP
from torch.optim import SGD
from qtorch import BlockFloatingPoint, FloatingPoint, FixedPoint
from tqdm import tqdm

Record the start time so we can report the total execution time at the end.

In [2]:
import time

# Wall-clock reference point; the final cell subtracts it from time.time()
# to report the notebook's total execution time.
start_time = time.time()

We first load the data. In this example, we will experiment with MNIST.

In [3]:
# Load MNIST. Both splits get the same preprocessing: convert images to
# tensors, then normalize with mean 0.1307 and std 0.3081.
ds = torchvision.datasets.MNIST
path = os.path.join("./data", "MNIST")

mnist_mean, mnist_std = (0.1307,), (0.3081,)
transform_train = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mnist_mean, mnist_std)]
)
transform_test = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mnist_mean, mnist_std)]
)

train_set = ds(path, train=True, download=True, transform=transform_train)
test_set = ds(path, train=False, download=True, transform=transform_test)

# DataLoader settings shared by both splits; only the training split is shuffled.
loader_kwargs = dict(batch_size=64, num_workers=1, pin_memory=True)
loaders = {
    'train': torch.utils.data.DataLoader(train_set, shuffle=True, **loader_kwargs),
    'test': torch.utils.data.DataLoader(test_set, **loader_kwargs),
}

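We then define the quantization settings we are going to use: a low-precision and a high-precision number format for different parts of the computation. (The floating-point baseline trained in this section uses plain SGD and no activation quantizers, so these settings only illustrate the API here; Section 2 redefines and uses them.)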

In [4]:
# define the two number formats: a low-precision fixed-point format
# (word length 8, 7 fractional bits) and a high-precision floating-point one.
# NOTE: the baseline model and plain SGD optimizer created below do not use
# any of these quantizers; they only illustrate the API in this section.
lowp = FixedPoint(wl=8, fl=7)
highp = FloatingPoint(exp=8, man=7)  # this is bfloat16

# define quantization functions
weight_quant = Quantizer(forward_number=lowp, backward_number=None,
                        forward_rounding="nearest", backward_rounding="nearest")
grad_quant = Quantizer(forward_number=lowp, backward_number=None,
                        forward_rounding="nearest", backward_rounding="stochastic")
momentum_quant = Quantizer(forward_number=highp, backward_number=None,
                        forward_rounding="nearest", backward_rounding="stochastic")
acc_quant = Quantizer(forward_number=highp, backward_number=None,
                        forward_rounding="nearest", backward_rounding="nearest")


def act_error_quant(number=lowp):
    """Return a new activation/error Quantizer so it can be duplicated easily.

    The default argument binds the current ``lowp`` at definition time, so a
    later reassignment of ``lowp`` in another cell does not silently change
    the format this factory produces.
    """
    return Quantizer(forward_number=number, backward_number=number,
                     forward_rounding="nearest", backward_rounding="nearest")

Next, we define a simple LeNet network. 

In [5]:
# let's define the model we are using: a LeNet-style CNN for 1x28x28 inputs
class Net(nn.Module):
    """Two conv + max-pool stages followed by two fully connected layers.

    ``forward`` returns per-class log-probabilities (log_softmax over dim 1).
    """

    def __init__(self):
        super().__init__()
        # layers are created in the same order as before so parameter
        # initialization consumes the RNG identically
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4 * 4 * 50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        # conv -> relu -> 2x2 max-pool, applied twice
        for conv in (self.conv1, self.conv2):
            x = F.max_pool2d(F.relu(conv(x)), 2, 2)
        # flatten the 50 feature maps of size 4x4
        hidden = F.relu(self.fc1(x.view(-1, 4 * 4 * 50)))
        return F.log_softmax(self.fc2(hidden), dim=1)



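We now create the model and a standard SGD optimizer. This floating-point baseline uses plain SGD; the low-precision optimizer wrapper (`OptimLP`), which quantizes the weights, gradients, momentum, and gradient accumulator, is introduced in Section 2.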

In [6]:
# Full-precision baseline: the plain Net trained with unwrapped SGD.
mxepochs = 10
use_cuda = False
device = torch.device("cuda") if use_cuda else torch.device("cpu")
model = Net().to(device)
optimizer = SGD(
    model.parameters(), lr=0.05, momentum=0.9, weight_decay=1e-4
)

We can reuse a standard training loop without any extra code to handle quantization.

In [7]:
def run_epoch(loader, model, criterion, optimizer=None, phase="train"):
    """Run one pass over ``loader`` and return the mean loss and accuracy.

    Batches are moved to the module-level ``device`` before the forward pass.

    Args:
        loader: iterable of ``(inputs, target)`` batches (e.g. a DataLoader).
        model: module whose output holds class scores along dim 1.
        criterion: callable ``criterion(output, target)`` returning a scalar loss.
        optimizer: optimizer stepped after every batch; required when
            ``phase == "train"``, ignored otherwise.
        phase: ``"train"`` (gradients on, parameters updated) or ``"eval"``
            (gradients off, model in eval mode).

    Returns:
        dict with ``'loss'`` (sample-weighted mean loss) and ``'accuracy'``
        (percentage of correctly classified samples).

    Raises:
        ValueError: if ``phase`` is invalid, if training is requested
            without an optimizer, or if ``loader`` yields no samples.
    """
    if phase not in ("train", "eval"):
        raise ValueError(f"invalid running phase: {phase!r}")
    training = phase == "train"
    if training and optimizer is None:
        raise ValueError("an optimizer is required when phase='train'")

    model.train(training)  # model.train(False) is the same as model.eval()

    loss_sum = 0.0
    correct = 0  # plain int: avoids accumulating a tensor across batches
    ttl = 0
    with torch.autograd.set_grad_enabled(training):
        for inputs, target in tqdm(loader):
            inputs = inputs.to(device=device)
            target = target.to(device=device)
            output = model(inputs)
            loss = criterion(output, target)

            batch_size = inputs.size(0)
            # assumes criterion returns a per-batch mean (F.cross_entropy's
            # default), so re-weight by batch size for a per-sample average
            loss_sum += loss.item() * batch_size
            pred = output.argmax(dim=1)
            correct += (pred == target).sum().item()
            ttl += batch_size

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

    if ttl == 0:
        raise ValueError("loader yielded no samples")
    return {
        'loss': loss_sum / ttl,
        'accuracy': correct / ttl * 100.0,
    }

Run the training in floating point.

In [8]:
# Train the floating-point baseline, evaluating on the test split after
# every epoch; fp_test_res keeps the final epoch's result for the plot later.
for epoch in range(mxepochs):
    fp_train_res = run_epoch(loaders['train'], model, F.cross_entropy,
                             optimizer=optimizer, phase="train")
    fp_test_res = run_epoch(loaders['test'], model, F.cross_entropy,
                            optimizer=optimizer, phase="eval")
    print('epoch', epoch)
    for split_res in (fp_train_res, fp_test_res):
        print(split_res)

100%|██████████| 938/938 [00:34<00:00, 26.82it/s]
100%|██████████| 157/157 [00:01<00:00, 82.45it/s]

epoch 0
{'loss': 0.13975346226207913, 'accuracy': 95.67833333333333}
{'loss': 0.048652271528047276, 'accuracy': 98.52}



100%|██████████| 938/938 [00:32<00:00, 28.81it/s]
100%|██████████| 157/157 [00:01<00:00, 84.60it/s]

epoch 1
{'loss': 0.04398178725702067, 'accuracy': 98.6}
{'loss': 0.02990385150157963, 'accuracy': 99.06}



100%|██████████| 938/938 [00:32<00:00, 28.87it/s]
100%|██████████| 157/157 [00:01<00:00, 89.04it/s]

epoch 2
{'loss': 0.03272670375339997, 'accuracy': 98.97166666666666}
{'loss': 0.037594826527242546, 'accuracy': 98.92}



100%|██████████| 938/938 [00:33<00:00, 27.78it/s]
100%|██████████| 157/157 [00:02<00:00, 78.46it/s]

epoch 3
{'loss': 0.02493680008233253, 'accuracy': 99.20333333333333}
{'loss': 0.03149537170268013, 'accuracy': 99.03}



100%|██████████| 938/938 [00:33<00:00, 28.12it/s]
100%|██████████| 157/157 [00:01<00:00, 80.47it/s]

epoch 4
{'loss': 0.020724476177947752, 'accuracy': 99.36500000000001}
{'loss': 0.03945704701770155, 'accuracy': 98.83}



100%|██████████| 938/938 [00:33<00:00, 28.33it/s]
100%|██████████| 157/157 [00:01<00:00, 85.93it/s]

epoch 5
{'loss': 0.015872756416495153, 'accuracy': 99.49166666666667}
{'loss': 0.03982625451160675, 'accuracy': 98.76}



100%|██████████| 938/938 [00:32<00:00, 28.67it/s]
100%|██████████| 157/157 [00:01<00:00, 84.76it/s]

epoch 6
{'loss': 0.012929226745905666, 'accuracy': 99.55833333333334}
{'loss': 0.033759961486868634, 'accuracy': 99.14}



100%|██████████| 938/938 [00:32<00:00, 28.81it/s]
100%|██████████| 157/157 [00:01<00:00, 85.80it/s]

epoch 7
{'loss': 0.012680374797104257, 'accuracy': 99.6}
{'loss': 0.03555196945745847, 'accuracy': 99.02}



100%|██████████| 938/938 [00:35<00:00, 26.65it/s]
100%|██████████| 157/157 [00:02<00:00, 78.24it/s]

epoch 8
{'loss': 0.011234540776508645, 'accuracy': 99.64}
{'loss': 0.035852644269423764, 'accuracy': 99.00999999999999}



100%|██████████| 938/938 [00:35<00:00, 26.37it/s]
100%|██████████| 157/157 [00:01<00:00, 81.83it/s]

epoch 9
{'loss': 0.01165138601524324, 'accuracy': 99.60333333333334}
{'loss': 0.032838685022593565, 'accuracy': 99.15}





## 2. Block Floating Point Training

Now we repeat the training with simulated low-precision (quantized) arithmetic. We first define the number formats.

In [9]:
# define the two number formats: an 8-bit block floating point format for
# low precision and bfloat16 for high precision
lowp = BlockFloatingPoint(wl=8, dim=-1)
highp = FloatingPoint(exp=8, man=7)      # this is bfloat16

# define quantization functions
weight_quant = Quantizer(forward_number=lowp, backward_number=None,
                        forward_rounding="nearest", backward_rounding="nearest")
grad_quant = Quantizer(forward_number=lowp, backward_number=None,
                        forward_rounding="nearest", backward_rounding="stochastic")
momentum_quant = Quantizer(forward_number=highp, backward_number=None,
                        forward_rounding="nearest", backward_rounding="stochastic")
acc_quant = Quantizer(forward_number=highp, backward_number=None,
                        forward_rounding="nearest", backward_rounding="nearest")


def act_error_quant(number=lowp):
    """Return a new activation/error Quantizer so it can be duplicated easily.

    The default argument binds the current ``lowp`` at definition time, so a
    later reassignment of ``lowp`` in another cell does not silently change
    the format this factory produces.
    """
    return Quantizer(forward_number=number, backward_number=number,
                     forward_rounding="nearest", backward_rounding="nearest")

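Now we define the network. In the definition, we insert a quantization module after every layer (each convolution, max-pooling, and fully connected layer). Note that the quantization of the weights, gradients, momentum, and gradient accumulator is not handled here; it is handled by the low-precision optimizer wrapper (`OptimLP`) below.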

In [10]:
# let's define the low-precision model: the same LeNet-style network as Net,
# with a quantizer applied to the output of every layer
class lp_Net(nn.Module):
    """LeNet-style CNN with a quantizer applied after every layer.

    Args:
        quant: zero-argument factory returning the quantizer module, e.g.
            ``act_error_quant``. A single instance is created and reused at
            every quantization point. If ``None`` (the default), an
            ``nn.Identity`` is used, i.e. no quantization is applied.
    """

    def __init__(self, quant=None):
        super(lp_Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4*4*50, 500)
        self.fc2 = nn.Linear(500, 10)
        # quant=None means "no quantization": fall back to a pass-through
        # module instead of failing on None()
        self.quant = quant() if quant is not None else nn.Identity()

    def forward(self, x):
        """Return per-class log-probabilities for a batch of 1x28x28 images."""
        x = F.relu(self.conv1(x))
        x = self.quant(x)
        x = F.max_pool2d(x, 2, 2)
        x = self.quant(x)
        x = F.relu(self.conv2(x))
        x = self.quant(x)
        x = F.max_pool2d(x, 2, 2)
        x = self.quant(x)
        x = x.view(-1, 4*4*50)
        x = F.relu(self.fc1(x))
        x = self.quant(x)
        x = self.fc2(x)
        x = self.quant(x)
        return F.log_softmax(x, dim=1)

In [11]:
use_cuda = False
device = torch.device("cuda") if use_cuda else torch.device("cpu")
model = lp_Net(act_error_quant).to(device)

# Wrap plain SGD with OptimLP, handing it the weight, gradient, momentum and
# accumulator quantizers defined above.
optimizer = SGD(model.parameters(), lr=0.05, momentum=0.9, weight_decay=1e-4)
lp_optimizer = OptimLP(
    optimizer,
    weight_quant=weight_quant,
    grad_quant=grad_quant,
    momentum_quant=momentum_quant,
    acc_quant=acc_quant,
)

for epoch in range(mxepochs):
    train_res = run_epoch(loaders['train'], model, F.cross_entropy,
                          optimizer=lp_optimizer, phase="train")
    test_res = run_epoch(loaders['test'], model, F.cross_entropy,
                         optimizer=lp_optimizer, phase="eval")
    print('epoch', epoch)
    print(train_res)
    print(test_res)

100%|██████████| 938/938 [02:34<00:00,  6.46it/s]
100%|██████████| 157/157 [00:08<00:00, 17.79it/s]

epoch 0
{'loss': 0.134907551719745, 'accuracy': 95.73166666666667}
{'loss': 0.11002362940609454, 'accuracy': 96.78}



100%|██████████| 938/938 [02:25<00:00,  6.46it/s]
100%|██████████| 157/157 [00:06<00:00, 26.11it/s]

epoch 1
{'loss': 0.0475153127261127, 'accuracy': 98.55333333333334}
{'loss': 0.04490658403052948, 'accuracy': 98.55000000000001}



100%|██████████| 938/938 [02:09<00:00,  7.24it/s]
100%|██████████| 157/157 [00:05<00:00, 28.14it/s]

epoch 2
{'loss': 0.03311258720730742, 'accuracy': 98.985}
{'loss': 0.035893300364725295, 'accuracy': 98.92}



100%|██████████| 938/938 [02:05<00:00,  7.47it/s]
100%|██████████| 157/157 [00:06<00:00, 24.37it/s]

epoch 3
{'loss': 0.02580145928617567, 'accuracy': 99.18}
{'loss': 0.0380030062812788, 'accuracy': 98.75}



100%|██████████| 938/938 [02:08<00:00,  7.30it/s]
100%|██████████| 157/157 [00:05<00:00, 27.32it/s]

epoch 4
{'loss': 0.02097408466776833, 'accuracy': 99.34333333333333}
{'loss': 0.03383344763555724, 'accuracy': 99.05000000000001}



100%|██████████| 938/938 [02:04<00:00,  7.55it/s]
100%|██████████| 157/157 [00:05<00:00, 27.30it/s]

epoch 5
{'loss': 0.01713064722112031, 'accuracy': 99.44666666666667}
{'loss': 0.02794512863709533, 'accuracy': 99.14}



100%|██████████| 938/938 [02:03<00:00,  7.57it/s]
100%|██████████| 157/157 [00:05<00:00, 28.28it/s]

epoch 6
{'loss': 0.016992174969456392, 'accuracy': 99.44166666666666}
{'loss': 0.03320048576105037, 'accuracy': 99.03}



100%|██████████| 938/938 [02:07<00:00,  7.33it/s]
100%|██████████| 157/157 [00:05<00:00, 27.31it/s]

epoch 7
{'loss': 0.010031001010301407, 'accuracy': 99.67}
{'loss': 0.03418132924698075, 'accuracy': 99.03999999999999}



100%|██████████| 938/938 [02:01<00:00,  7.70it/s]
100%|██████████| 157/157 [00:05<00:00, 28.68it/s]

epoch 8
{'loss': 0.008979617804124184, 'accuracy': 99.72166666666666}
{'loss': 0.03690760707998743, 'accuracy': 99.03999999999999}



100%|██████████| 938/938 [02:19<00:00,  6.74it/s]
100%|██████████| 157/157 [00:05<00:00, 27.79it/s]

epoch 9
{'loss': 0.006913615543941948, 'accuracy': 99.78833333333334}
{'loss': 0.03011619034334435, 'accuracy': 99.2}





## 3. Accuracy vs wordlength

First, import the numerical and plotting libraries.

In [12]:
import numpy as np
import matplotlib
%matplotlib inline
import matplotlib.pyplot as plt

Compute the network's test accuracy as a function of the block floating point word length.

In [None]:
# Sweep the block floating point word length: train a fresh model for each
# wl and record its final-epoch test accuracy in `res` as (wl, accuracy).
res = []
(minp, maxp) = (1, 8)
# the device choice does not depend on wl, so decide it once
use_cuda = False
device = torch.device("cuda" if use_cuda else "cpu")
for wl in range(minp, maxp+1):
    # low-precision block floating point format with a wl-bit word length,
    # plus bfloat16 as the high-precision format
    lowp = BlockFloatingPoint(wl=wl, dim=-1)
    highp = FloatingPoint(exp=8, man=7)      # this is bfloat16

    # define quantization functions
    weight_quant = Quantizer(forward_number=lowp, backward_number=None,
                            forward_rounding="nearest", backward_rounding="nearest")
    grad_quant = Quantizer(forward_number=lowp, backward_number=None,
                            forward_rounding="nearest", backward_rounding="stochastic")
    momentum_quant = Quantizer(forward_number=highp, backward_number=None,
                            forward_rounding="nearest", backward_rounding="stochastic")
    acc_quant = Quantizer(forward_number=highp, backward_number=None,
                            forward_rounding="nearest", backward_rounding="nearest")

    # factory so the Quantizer module can be duplicated easily; the default
    # argument binds this iteration's lowp at definition time
    def act_error_quant(number=lowp):
        return Quantizer(forward_number=number, backward_number=number,
                         forward_rounding="nearest", backward_rounding="nearest")

    model = lp_Net(act_error_quant).to(device)
    optimizer = SGD(model.parameters(), lr=0.05, momentum=0.9, weight_decay=1e-4)
    lp_optimizer = OptimLP(optimizer,
                        weight_quant=weight_quant,
                        grad_quant=grad_quant,
                        momentum_quant=momentum_quant,
                        acc_quant=acc_quant
    )

    # reset so a stale test_res left over from an earlier cell (e.g. the
    # Section 2 run) can never be recorded for this wl if no epoch runs
    test_res = None
    for epoch in range(mxepochs):
        train_res = run_epoch(loaders['train'], model, F.cross_entropy,
                                    optimizer=lp_optimizer, phase="train")
        test_res = run_epoch(loaders['test'], model, F.cross_entropy,
                                    optimizer=lp_optimizer, phase="eval")
        print('wl', wl, 'epoch', epoch)
        print(train_res)
        print(test_res)

    # record the final-epoch test accuracy for the plot below
    if test_res is not None:
        res.append((wl, test_res['accuracy']))
print(res)

100%|██████████| 938/938 [02:04<00:00,  7.56it/s]
100%|██████████| 157/157 [00:05<00:00, 28.80it/s]

wl 1 epoch 0
{'loss': 2.459815688451131, 'accuracy': 10.183333333333334}
{'loss': 2.3025834590911867, 'accuracy': 10.09}



100%|██████████| 938/938 [05:07<00:00,  3.05it/s]  
100%|██████████| 157/157 [00:05<00:00, 27.92it/s]

wl 1 epoch 1
{'loss': 2.995434170659383, 'accuracy': 9.975000000000001}
{'loss': 2.556058658218384, 'accuracy': 10.09}



100%|██████████| 938/938 [02:10<00:00,  7.21it/s]
100%|██████████| 157/157 [00:05<00:00, 28.14it/s]

wl 1 epoch 2
{'loss': 2.4692039801279706, 'accuracy': 9.915000000000001}
{'loss': 2.3025834590911867, 'accuracy': 10.09}



  0%|          | 0/938 [00:00<?, ?it/s]

Make a scatterplot of the results, and draw a horizontal line at the floating-point baseline test accuracy from Section 1 for comparison.

In [None]:
# Plot final test accuracy vs word length, with a horizontal reference line
# at the floating-point baseline accuracy from Section 1 (fp_test_res).
# reshape(-1, 2) keeps the array 2-D even if `res` is empty, so the column
# indexing below does not fail.
plt_res = np.array(res, dtype=float).reshape(-1, 2)
plt.plot(plt_res[:, 0], plt_res[:, 1], 'x', label='block floating point')
plt.plot((minp, maxp), (fp_test_res['accuracy'], fp_test_res['accuracy']),
         label='floating-point baseline')
plt.xlabel('word length (bits)')
plt.ylabel('test accuracy (%)')
plt.legend()


In [None]:
# wall-clock seconds since start_time was recorded at the top of the notebook
elapsed_s = time.time() - start_time
print("Total execution time (s):", elapsed_s)