
# Imports


In [2]:
import numpy as np
import torch
import torch.nn as nn


# Define the neural network, load data, train it

In [3]:
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

# Define a simple neural network
class Net(nn.Module):
    """Two-layer fully connected classifier: 784 inputs -> 500 hidden (ReLU) -> 10 logits."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        # Every sample is flattened to a 784-element row before the first layer.
        hidden = torch.relu(self.fc1(x.view(-1, 784)))
        return self.fc2(hidden)

# Seed torch's global RNG (all devices) so weight initialisation and the shuffled
# training order are repeatable; the scale constants exported later depend on both.
# On CUDA some kernels may still be non-deterministic.
torch.manual_seed(0)

# Set the device to use for computation
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Set up the network and optimizer
model = Net().to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# Load the training data.
# No Normalize transform is applied (this appears intentional): QuantNet later quantizes
# the raw ToTensor() pixels with an input scale of 1/127, which presumes values in [0, 1].
train_transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST(
    './data', train=True, download=True, transform=train_transform
)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)

# Train the model.
# The loss module has no parameters or state, so build it once rather than once per batch.
criterion = nn.CrossEntropyLoss()
NUM_EPOCHS = 10

model.train()
for epoch in range(NUM_EPOCHS):
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        # Periodic progress report: samples seen so far in this epoch and the current batch loss.
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 91939963.61it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 66962793.71it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 28918513.39it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 5658012.70it/s]


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw



In [4]:
# Load the test data with the same un-normalized ToTensor() preprocessing as the training set.
test_transform = transforms.Compose([transforms.ToTensor()])
test_dataset = datasets.MNIST('./data', train=False, transform=test_transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000, shuffle=True)


In [5]:
from copy import deepcopy
# Snapshot the trained float weights. A deep copy is used so that later edits to this dict
# never touch `model`. The observer network below is loaded from this snapshot.
q_model_dict = deepcopy(model.state_dict())


# Construct an observer model and calibrate it on the training set

This network has the same architecture as `Net`. Loaded with the trained weights from the previous step, it records the running maximum and minimum of each layer's output while it is run over the training set.

In [6]:
class ObserveNet(nn.Module):
    """Same layers as `Net`, plus running extremes of each layer's output.

    Every forward pass updates:
      * so_1_max / so_1_min -- largest / smallest post-ReLU fc1 activation seen so far,
      * so_2_max / so_2_min -- largest / smallest fc2 output seen so far.
    These are later turned into the output quantization scales.
    """

    def __init__(self):
        super(ObserveNet, self).__init__()
        self.fc1 = nn.Linear(784, 500)
        self.fc2 = nn.Linear(500, 10)
        # Kept as plain attributes (not registered buffers), so they are absent from
        # state_dict and loading Net's state_dict still matches every key.
        self.so_1_max = torch.tensor(-float('inf'))
        self.so_2_max = torch.tensor(-float('inf'))
        self.so_1_min = torch.tensor(float('inf'))
        self.so_2_min = torch.tensor(float('inf'))

    def forward(self, x):
        flat = x.view(-1, 784)
        hidden = torch.relu(self.fc1(flat))

        hidden_hi, hidden_lo = torch.max(hidden), torch.min(hidden)
        if hidden_hi > self.so_1_max:
            self.so_1_max = hidden_hi
        if hidden_lo < self.so_1_min:
            self.so_1_min = hidden_lo

        logits = self.fc2(hidden)

        logits_hi, logits_lo = torch.max(logits), torch.min(logits)
        if logits_hi > self.so_2_max:
            self.so_2_max = logits_hi
        if logits_lo < self.so_2_min:
            self.so_2_min = logits_lo

        return logits


# Float observer network holding the trained weights. It is placed on `device` because
# test() moves every batch there; a CPU-only model would fail when CUDA is available.
o_net = ObserveNet().to(device)
o_net.load_state_dict(q_model_dict)

def test(model, device, test_loader):
    model.eval()  # set the model to evaluation mode
    test_loss = 0
    correct = 0
    with torch.no_grad():  # disable gradient computation
        for data, target in test_loader:
            data = torch.round(data)
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += nn.CrossEntropyLoss()(output, target).item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

# Calibration pass: run the observer over the *training* set so it records each layer's
# output range. The "Test set" banner printed by test() is the same whichever loader is passed.
test(model=o_net, device=device, test_loader=train_loader)



Test set: Average loss: 0.0008, Accuracy: 59186/60000 (99%)



# Calculate the output scale constants for layer 1 and layer 2

In [7]:
# Symmetric int8 output scales: the largest magnitude observed at each layer's output maps to 127.
so_1, so_2 = (
    max(hi, torch.abs(lo)).float().item() / 127
    for hi, lo in (
        (o_net.so_1_max, o_net.so_1_min),
        (o_net.so_2_max, o_net.so_2_min),
    )
)
print(so_1)
print(so_2)

0.05031989315363366
0.2053285433551458


# Quantize the layer weights

In [8]:
from copy import deepcopy
# Start again from a fresh float copy of the trained weights. The statements below overwrite
# entries of this dict with quantized values, so re-copying keeps the cell safe to re-run
# (it never quantizes values that are already quantized).
q_model_dict = deepcopy(model.state_dict())

# Returns the maximum absolute value of a tensor divided by `m`, the largest integer
# magnitude in the quantization range.
def max_scale(x: torch.Tensor, m: int) -> float:
    """Return the symmetric quantization scale ``max(|x|) / m``.

    Args:
        x: tensor to be quantized.
        m: largest representable quantized magnitude (127 for int8).

    Returns:
        Python float such that ``x / scale`` lies in ``[-m, m]``.
    """
    return torch.max(torch.abs(x)).item() / m

# Scale of fc1 and fc2 determined by the maximum value of int8 (127) and the maximum
# absolute weight value.
def quantize_fc(x: torch.Tensor, m: int):
    """Symmetrically quantize a weight tensor.

    Args:
        x: float weight tensor.
        m: largest quantized magnitude (127 for int8).

    Returns:
        ``(scale, q)``: ``scale = max(|x|) / m`` (Python float) and ``q = round(x / scale)``
        as an int32 tensor, so that ``q * scale`` approximates ``x``.
    """
    # Compute the scale once and reuse it for the division.
    scale = max_scale(x, m)
    return scale, torch.round(x / scale).to(dtype=torch.int32)

# Scale of bias determined by the scale of the fc layer's output (accumulator),
# which is the scale of the input multiplied by the scale of the fc layer.
def quantize_bias(x: torch.Tensor, s: float, bound: int = 127):
    """Quantize a bias vector to accumulator scale `s`, saturating to ``[-bound, bound]``.

    Args:
        x: float bias tensor.
        s: scale of the accumulator the bias is added to (input scale * weight scale).
        bound: saturation limit. The default of 127 saturates to the int8 range.
            NOTE(review): with an accumulator scale as small as s_fc1 / 127, most fc1
            biases hit +/-127 (the printed fc1.bias values are mostly +/-127). A wider
            bound such as 2**31 - 1 (int32) would preserve them, if the consumer accepts it.

    Returns:
        Float tensor of rounded, clipped values.
    """
    # saturate
    return torch.clip(torch.round(x / s), min=-bound, max=bound)

# Quantize both weight matrices first; each gets its own symmetric int8 scale ...
s_fc1, q_model_dict['fc1.weight'] = quantize_fc(q_model_dict['fc1.weight'], 127)
s_fc2, q_model_dict['fc2.weight'] = quantize_fc(q_model_dict['fc2.weight'], 127)

# ... then the biases, at the scale of the accumulator they are added to:
#   fc1: weight scale * input scale (1/127)
#   fc2: weight scale * layer-1 output scale (so_1)
q_model_dict['fc1.bias'] = quantize_bias(q_model_dict['fc1.bias'], s_fc1 / 127)
q_model_dict['fc2.bias'] = quantize_bias(q_model_dict['fc2.bias'], s_fc2 * so_1)

class QuantNet(nn.Module):
    """Integer-arithmetic simulation of `Net`, meant to be loaded with the quantized state dict.

    When loaded with the quantized weights/biases, fc1/fc2 hold integer-valued floats. The
    input is mapped to integers, each layer's accumulator is requantized and saturated to
    [-127, 127], and only the final output is converted back to real-valued logits.

    Args:
        s_fc1, s_fc2: weight scales of fc1 and fc2.
        so_1, so_2: output (activation) scales of layer 1 and layer 2.
    """

    def __init__(self, s_fc1, s_fc2, so_1, so_2):
        super(QuantNet, self).__init__()
        self.fc1 = nn.Linear(784, 500)
        self.fc2 = nn.Linear(500, 10)
        self.s_fc1 = s_fc1
        self.s_fc2 = s_fc2
        self.so_1 = so_1
        self.so_2 = so_2
        # Input scale: x / s_x is approximately 127 * x.
        self.s_x = 1 / 127

    def forward(self, x):
        x = x.view(-1, 784)

        # Scale input
        x = torch.round(x / self.s_x)

        x = torch.relu(self.fc1(x))

        # Requantize the fc1 accumulator (scale s_fc1 * s_x) to the layer-1 output scale
        # so_1, then saturate. torch.clamp keeps this a pure torch op (device-agnostic,
        # no dependence on numpy's dispatch to tensor methods).
        x = torch.clamp(torch.round(x * ((self.s_fc1 * self.s_x) / self.so_1)), -127, 127)

        x = self.fc2(x)

        # Same for fc2: accumulator scale s_fc2 * so_1 -> output scale so_2.
        x = torch.clamp(torch.round(x * ((self.s_fc2 * self.so_1) / self.so_2)), -127, 127)

        # Dequantize to real-valued logits.
        return x * self.so_2

# Build the integer-simulating network on the same device test() sends data to,
# then load the quantized weights and biases.
q_net = QuantNet(s_fc1=s_fc1, s_fc2=s_fc2, so_1=so_1, so_2=so_2).to(device)
q_net.load_state_dict(q_model_dict)

<All keys matched successfully>

# Print the scaling values so they can be used inside of Urbit

In [9]:
# Order: layer-1 output scale, layer-2 output scale, fc1 weight scale, fc2 weight scale.
for scale in (so_1, so_2, s_fc1, s_fc2):
    print(scale)

0.05031989315363366
0.2053285433551458
0.002196800051711676
0.005814554184440553


# Test the QuantNet on the test set

In [10]:
# Fix the RNG so test_loader's shuffled batch order is the same on every run. The reported
# metrics aggregate over the whole set, so they are essentially order-independent.
torch.manual_seed(1)
def test(model, device, test_loader):
    model.eval()  # set the model to evaluation mode
    test_loss = 0
    correct = 0
    with torch.no_grad():  # disable gradient computation
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += nn.CrossEntropyLoss()(output, target).item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

# Evaluate the quantized network on the held-out test set.
test(model=q_net, device=device, test_loader=test_loader)



Test set: Average loss: 0.0001, Accuracy: 9773/10000 (98%)



# Write the QuantNet parameters (weights and biases) to disk as int32

In [11]:
def to_byte_array(array, name):
    """Write `array` to the file `name` as raw int32 bytes and return those bytes.

    The array is cast to int32 and flattened in row-major (C) order. `tobytes()` uses the
    machine's native byte order. NOTE(review): confirm this matches the Urbit-side reader.
    """
    flattened = array.astype(np.int32).flatten(order='C')
    byte_array = flattened.tobytes()
    with open(name, 'wb') as f:
        f.write(byte_array)
    return byte_array


for name, param in q_net.named_parameters():
    # .cpu() is required before .numpy() when q_net lives on a CUDA device. The values
    # were rounded during quantization, so the .int() cast does not lose information.
    values = param.detach().cpu().int().numpy()
    print(name)
    print(values)
    to_byte_array(values, f'{name}.mnist')

[[  8  -5  -3 ...   0   9  12]
 [ 14   3 -14 ...   0 -11 -12]
 [ 10 -12 -10 ...   7   4 -11]
 ...
 [ -9   5  15 ...  -8   5   3]
 [ 10  -8 -12 ...   2   6 -12]
 [  1 -10   7 ...  -5 -16   7]]
[ 105  111  127 -127  127  127  127 -127 -127  127 -127 -127  127 -127
  127  127  127 -127  127  127  127  127    7  127 -127  127  127  127
 -127 -127  -79 -127 -127 -127  127  127 -127  127 -127  127  127 -127
  127  127 -127  127  127  127 -127  127  127 -127 -127  127 -127  127
    6  127  127  127 -127  127  127  127 -127 -127  127  127  127 -127
  127 -127 -127 -127 -127  127 -127 -127  127  127 -127  127  127  127
  127  127  127  127 -127  127 -127 -127  127  127 -127  127  127  127
  127  127 -127  127 -127  127  127  127 -127 -127  127  127 -127 -127
 -127 -127  127 -127 -127  127  127 -127  127 -127  127  127  127 -127
 -127  127 -127  127 -127  127 -127 -127  127 -127  127  127  127  127
 -127  127 -127  127 -127 -127  127  127  127  127 -127 -127  -38 -127
   81 -127  127  127 -127 -