In [None]:

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
# import matplotlib.pyplot as plt

In [None]:
class LeNet(nn.Module):
    """LeNet-5-style classifier: (N, 1, 32, 32) images -> (N, 10) class scores.

    The spatial size must be 32x32: 32 -conv5-> 28 -pool-> 14 -conv5-> 10
    -pool-> 5 -conv5-> 1, which is what lets ``fc6`` take 120 features.
    Other cells address layers by attribute name (e.g. ``lenet.conv5.weight``),
    so keep the names stable.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5, 1)
        self.relu1 = nn.ReLU()

        self.maxpool2 = nn.MaxPool2d(kernel_size = 2, stride = 2)
        # Max-pooling ReLU'd activations keeps them >= 0, so this ReLU (and
        # relu4) is a no-op; presumably kept to mirror a reference layer list.
        self.relu2 = nn.ReLU()

        self.conv3 = nn.Conv2d(6, 16, 5, 1)
        self.relu3 = nn.ReLU()

        self.maxpool4 = nn.MaxPool2d(kernel_size = 2, stride = 2)
        self.relu4 = nn.ReLU()

        self.conv5 = nn.Conv2d(16, 120, 5, 1)
        self.relu5 = nn.ReLU()

        self.fc6 = nn.Linear(120, 10)
        # Note: the outputs are ReLU'd scores, not raw logits (negatives clamp to 0).
        self.relu6 = nn.ReLU()

    def forward(self, x):
        """Return non-negative class scores of shape (N, 10) for x of shape (N, 1, 32, 32)."""
        x = self.relu1(self.conv1(x))       # (N, 6, 28, 28)
        x = self.relu2(self.maxpool2(x))    # (N, 6, 14, 14)
        x = self.relu3(self.conv3(x))       # (N, 16, 10, 10)
        x = self.relu4(self.maxpool4(x))    # (N, 16, 5, 5)
        # conv5 leaves (N, 120, 1, 1). Flatten only the trailing (C, H, W)
        # dims: squeeze() would also drop the batch dim when N == 1.
        x = torch.flatten(self.relu5(self.conv5(x)), start_dim=-3)  # (N, 120)
        x = self.relu6(self.fc6(x))         # (N, 10)
        return x

In [None]:
DATA_DIR = 'sd_card_files'

# params.bin: float32 weights then bias for each layer, concatenated in network
# order with no header. Slice offsets are derived from these shapes.
PARAM_SHAPES = [
    (6, 1, 5, 5), (6,),        # conv1 weight, bias
    (16, 6, 5, 5), (16,),      # conv3 weight, bias
    (120, 16, 5, 5), (120,),   # conv5 weight, bias
    (10, 120), (10,),          # fc6 weight, bias
]
param_sizes = [int(np.prod(shape)) for shape in PARAM_SHAPES]

params = np.fromfile(f'{DATA_DIR}/params.bin', dtype=np.float32)
assert params.size == sum(param_sizes), (
    f'params.bin holds {params.size} floats, expected {sum(param_sizes)}')
(conv1_weights, conv1_bias,
 conv3_weights, conv3_bias,
 conv5_weights, conv5_bias,
 fc6_weights, fc6_bias) = (
    torch.from_numpy(chunk.reshape(shape))
    for chunk, shape in zip(np.split(params, np.cumsum(param_sizes)[:-1]), PARAM_SHAPES)
)

# images.bin / labels.bin: presumably MNIST IDX files (16- and 8-byte headers)
# holding 28x28 uint8 images and uint8 labels.
IMAGES_HEADER_BYTES = 16
LABELS_HEADER_BYTES = 8

images_raw = np.fromfile(f'{DATA_DIR}/images.bin', dtype=np.uint8)[IMAGES_HEADER_BYTES:].reshape(-1, 1, 28, 28)
# Pad 28x28 -> 32x32 with -1 (the value a black pixel maps to) and rescale
# pixels from [0, 255] to [-1, 1].
images_padded = np.full((images_raw.shape[0], 1, 32, 32), -1.0)
images_padded[:, :, 2:30, 2:30] = images_raw / 255.0 * 2.0 - 1.0
images = torch.from_numpy(images_padded).float()

labels = torch.from_numpy(np.fromfile(f'{DATA_DIR}/labels.bin', dtype=np.uint8)[LABELS_HEADER_BYTES:])
assert len(labels) == len(images), f'{len(images)} images but {len(labels)} labels'

In [None]:
lenet = LeNet()
# load_state_dict copies the tensors into the model's own parameters, so the
# model no longer aliases the params.bin buffer. With the default strict=True
# it also raises on any missing/unexpected key or shape mismatch, which plain
# `.data =` assignment would accept silently.
lenet.load_state_dict({
    'conv1.weight': conv1_weights, 'conv1.bias': conv1_bias,
    'conv3.weight': conv3_weights, 'conv3.bias': conv3_bias,
    'conv5.weight': conv5_weights, 'conv5.bias': conv5_bias,
    'fc6.weight': fc6_weights, 'fc6.bias': fc6_bias,
});

In [None]:
# Spot-check one image and the model's prediction for it.
# The top-of-notebook matplotlib import is commented out, so import it here;
# otherwise `plt` is undefined on a fresh kernel.
import matplotlib.pyplot as plt

SAMPLE_IDX = 225
plt.imshow(images[SAMPLE_IDX].squeeze(), cmap='gray')
plt.title(f'label: {int(labels[SAMPLE_IDX])}')
plt.show()
with torch.no_grad():
    sample_pred = torch.argmax(lenet(images[SAMPLE_IDX:SAMPLE_IDX + 1]), dim=-1)
sample_pred

In [None]:
# Accuracy of the float model on the full image set. Nothing is trained, so
# run under no_grad to avoid recording an autograd graph for every image.
with torch.no_grad():
    pred = torch.argmax(lenet(images), dim=-1).to(dtype=torch.uint8)
(pred == labels).float().mean()

In [None]:
# The notebook only runs inference: freeze every parameter (conv1, conv3,
# conv5 and fc6), so forward passes build no autograd graph.
lenet.requires_grad_(False);

In [None]:
class Quantizer:
    """Min/max affine "fake" quantizer with ``2**nbits`` levels.

    Values map to integer codes ``round((w - bias) / scale)`` in
    ``[0, 2**nbits - 1]``, where ``bias = min`` and
    ``scale = (max - min) / (2**nbits - 1)``. The codes are then mapped back to
    floats (``code * scale + bias``), so the result can go straight into a
    float model.
    """

    def __init__(self, nbits=8) -> None:
        # Bits per code; the quantization grid has 2**nbits levels.
        self.nbits = nbits

    def channel_quantize(self, weights):
        """Quantize every slice along the last axis with its own min/max range.

        Args:
            weights: array-like of shape (..., K). Each length-K row is
                quantized independently.

        Returns:
            (t, scale, bias):
            - ``t``: the dequantized values, same shape as ``weights``.
            - ``scale``: each row's step size, repeated along the last axis
              to shape (..., K).
            - ``bias``: each row's minimum, repeated the same way.
            A constant row gets scale 1, since any non-zero step reproduces it
            exactly.
        """
        levels = 2 ** self.nbits - 1
        maximum = np.max(weights, axis=-1, keepdims=True)
        minimum = np.min(weights, axis=-1, keepdims=True)
        scale = (maximum - minimum) / levels
        # max == min gives scale 0 and 0/0 = NaN below; all codes are 0 in
        # that case, so a step of 1 reconstructs the row exactly.
        scale = np.where(scale == 0, np.ones_like(scale), scale)
        bias = minimum

        t = ((weights - bias) / scale).round()
        t = t * scale + bias

        # Callers expect per-row parameters expanded to the full (..., K) shape.
        scale = np.broadcast_to(scale, t.shape).copy()
        bias = np.broadcast_to(bias, t.shape).copy()
        return t, scale, bias

    def tensor_quantize(self, weights):
        """Quantize the whole array with a single min/max range.

        Returns:
            (t, scale, bias): ``t`` holds the dequantized values (same shape as
            ``weights``). ``scale`` and ``bias`` are the scalar step size and
            minimum shared by every element. A constant input gets scale 1.
        """
        maximum = np.max(weights)
        minimum = np.min(weights)
        scale = (maximum - minimum) / (2 ** self.nbits - 1)
        if scale == 0:
            # Constant tensor: avoid 0/0. All codes are 0, so any non-zero
            # step reconstructs it exactly. Keep the scalar's dtype.
            scale = np.asarray(scale).dtype.type(1)
        bias = minimum
        t = ((weights - bias) / scale).round()

        t = t * scale + bias

        return t, scale, bias

In [None]:
# 8-bit quantizer: 256 levels per group.
quantizer = Quantizer(nbits=8)

In [None]:
# Fake-quantize conv5 per (out_channel, in_channel) kernel: each kernel's taps
# share one min/max range. Start from the float weights loaded from params.bin
# (`conv5_weights`) rather than the live layer, so re-running this cell after
# the layer was replaced does not quantize already-quantized weights.
w5_float = conv5_weights.numpy()
t, scale, bias = quantizer.channel_quantize(w5_float.reshape(*w5_float.shape[:2], -1))
t = t.reshape(w5_float.shape)

In [None]:
# Swap the fake-quantized conv5 weights into the model (shares memory with `t`).
lenet.conv5.weight.data = torch.as_tensor(t)

In [None]:
# Accuracy with conv5 fake-quantized to 8 bits. Inference only, so no_grad
# avoids recording an autograd graph for every image.
with torch.no_grad():
    pred = torch.argmax(lenet(images), dim=-1).to(dtype=torch.uint8)
(pred == labels).float().mean()

In [None]:
# Recover the integer codes round((w - bias) / scale) from the fake-quantized
# conv5 weights `t` and write them as raw uint8 bytes.
n_out, n_in = t.shape[:2]
# Collapse scale/bias to one value per (out, in) kernel, shape (n_out, n_in).
# channel_quantize repeated them across each kernel's taps. Reshaping first
# makes this cell work whether or not they were already collapsed, so it can
# be re-run.
scale = np.asarray(scale).reshape(n_out, n_in, -1)[:, :, 0]
bias = np.asarray(bias).reshape(n_out, n_in, -1)[:, :, 0]

codes = ((t.reshape(n_out, n_in, -1) - bias[:, :, None]) / scale[:, :, None]).round()
# A NaN (zero scale) or out-of-range code would be mangled silently by astype(uint8).
assert np.all((codes >= 0) & (codes <= np.iinfo(np.uint8).max)), 'conv5 codes outside the uint8 range'
quantized_w5 = codes.astype(np.uint8).reshape(t.shape)
quantized_w5.tofile("quantized_w5.bin")

In [None]:
# One float per (out, in) kernel. The read-back check decodes these files as
# raw float32, so cast explicitly instead of relying on whatever dtype the
# numpy arithmetic produced.
np.asarray(scale, dtype=np.float32).tofile("scale.bin")
np.asarray(bias, dtype=np.float32).tofile("bias.bin")

In [None]:
# Read the exported files back and check they round-trip exactly. Read the
# paths directly instead of reusing `params`, which would clobber the
# model-parameter buffer loaded earlier.
read_from_file_quantized_w5 = np.fromfile('quantized_w5.bin', dtype=np.uint8).reshape(quantized_w5.shape)
read_from_file_scale = np.fromfile('scale.bin', dtype=np.float32).reshape(quantized_w5.shape[:2])
read_from_file_bias = np.fromfile('bias.bin', dtype=np.float32).reshape(quantized_w5.shape[:2])

print(np.sum(quantized_w5 == read_from_file_quantized_w5), np.sum(read_from_file_scale == scale), np.sum(read_from_file_bias == bias))
np.testing.assert_array_equal(read_from_file_quantized_w5, quantized_w5)
np.testing.assert_array_equal(read_from_file_scale, scale)
np.testing.assert_array_equal(read_from_file_bias, bias)

In [None]:
# Rebuild conv5 from the stored form, code * scale + bias, with one
# (scale, bias) pair per (out, in) kernel broadcast over its 5x5 taps.
# scale/bias are not rebound here, so the cell can be re-run safely.
n_out, n_in = quantized_w5.shape[:2]
w5_scale = np.asarray(scale).reshape(n_out, n_in, -1)[:, :, 0][:, :, None, None]
w5_bias = np.asarray(bias).reshape(n_out, n_in, -1)[:, :, 0][:, :, None, None]
dequantized_w5 = quantized_w5 * w5_scale + w5_bias

In [None]:
# Load the weights reconstructed from the uint8 codes into conv5 (shares memory with the array).
lenet.conv5.weight.data = torch.as_tensor(dequantized_w5)

In [None]:
# Accuracy with conv5 rebuilt from the stored uint8 codes + per-kernel
# scale/bias. Inference only, so no_grad avoids recording an autograd graph.
with torch.no_grad():
    pred = torch.argmax(lenet(images), dim=-1).to(dtype=torch.uint8)
(pred == labels).float().mean()