In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from utils import get_mnist, QALinear, train, test, LinearInt

In [3]:
# Build the two data loaders with the project helper `get_mnist` (from the
# local `utils` module): batches of 64 for the training split and batches of
# 128 for the held-out split used by the evaluation cells.
train_loader = get_mnist(batch_size=64, train=True)
test_loader = get_mnist(batch_size=128, train=False)

In [4]:
# Maps a quantization bit width to the torch integer dtype that
# Model.quantize hands to LinearInt.from_quantized. Only 8, 16 and 32 are
# listed; looking up any other width raises KeyError.
bits_to_dtype = {
    8: torch.int8,
    16: torch.int16,
    32: torch.int32,
}

class Model(nn.Module):
    """Two-layer fully connected classifier: 784 -> 256 -> ReLU -> 10.

    Besides the float forward pass, the class can produce two derived copies
    of itself: a quantization-aware-training variant (`to_qat`) and an
    integer-layer variant (`quantize`). Both return a new `Model` whose
    `fc1`/`fc2` attributes have been swapped for the converted layers; the
    instance they are called on is left untouched.
    """

    def __init__(self):
        super().__init__()
        # Hidden layer first, then the 10-way output layer.
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        """Flatten `x` to rows of 784 values and return the raw fc2 outputs.

        No softmax is applied here; the notebook pairs this model with
        `nn.CrossEntropyLoss`.
        """
        flat = x.view(-1, 784)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)

    def to_qat(self, bits: int) -> "Model":
        """Return a new Model whose layers are `QALinear` wrappers of this one's.

        Args:
            bits: bit width forwarded to `QALinear.from_linear` for both layers.

        Returns:
            A fresh `Model` with `fc1` and `fc2` replaced by the converted layers.
        """
        qat_model = Model()
        qat_model.fc1 = QALinear.from_linear(self.fc1, bits)
        # fc2 is flagged with only_positive_activations=True -- presumably
        # because its input comes straight out of the ReLU in `forward`.
        qat_model.fc2 = QALinear.from_linear(
            self.fc2,
            bits,
            only_positive_activations=True,
        )
        return qat_model

    def quantize(self, bits: int) -> "Model":
        """Return a new Model whose layers are `LinearInt` versions of this one's.

        Must be called on a model produced by `to_qat`: both layers are
        asserted to be `QALinear` instances.

        Args:
            bits: bit width; translated to a torch integer dtype through the
                module-level `bits_to_dtype` table (KeyError if not listed).

        Returns:
            A fresh `Model` with `fc1` and `fc2` replaced by integer layers.

        NOTE(review): the evaluation cell in this notebook reports ~10%
        accuracy for `quantize(bits=8)`, and the manual int8 matmul there
        disagrees with the int16/fp32 ones. Whether the dtype passed here also
        controls LinearInt's accumulator is not visible from this file --
        TODO confirm in `utils.LinearInt`.
        """
        assert all(isinstance(layer, QALinear) for layer in (self.fc1, self.fc2))
        int_dtype = bits_to_dtype[bits]
        int_model = Model()
        for layer_name in ("fc1", "fc2"):
            qat_layer = getattr(self, layer_name)
            setattr(int_model, layer_name, LinearInt.from_quantized(qat_layer, int_dtype))
        return int_model

In [5]:
# Instantiate the float baseline and move it to the GPU (this cell requires
# a CUDA device; the training cells below pass use_cuda=True to match).
model = Model().cuda()

In [6]:
# Stage 1: train the plain float model for a single epoch, before any
# quantization is introduced.
# NOTE: `epochs`, `optimizer` and `loss_fn` are notebook-level names that the
# QAT training cell rebinds later; `loss_fn` is also reused by the evaluation
# cells.
epochs = 1
optimizer = optim.AdamW(model.parameters(), lr=0.001)
loss_fn = nn.CrossEntropyLoss()
# Epochs are numbered from 1; `train` is the loop helper from the local
# `utils` module.
for epoch in range(1, epochs + 1):
    train(
        model,
        epoch,
        loss_fn,
        optimizer,
        train_loader,
        use_cuda=True,
        log_interval=500,
    )



In [7]:
# Derive the quantization-aware model: both linear layers of the trained
# float model are converted to 8-bit QALinear layers (see Model.to_qat).
model_qat = model.to_qat(bits=8)

In [8]:
# Stage 2: fine-tune the QAT model for four epochs. A fresh AdamW optimizer is
# built over model_qat's parameters, so no optimizer state carries over from
# the float training run.
epochs = 4
optimizer = optim.AdamW(model_qat.parameters(), lr=0.001)
loss_fn = nn.CrossEntropyLoss()
# Same loop shape as the float stage, only the model differs.
for epoch in range(1, epochs + 1):
    train(
        model_qat,
        epoch,
        loss_fn,
        optimizer,
        train_loader,
        use_cuda=True,
        log_interval=500,
    )



In [9]:
# Convert the QAT model into its integer-layer form (LinearInt layers built
# with torch.int8, see Model.quantize) and move the result to the CPU, where
# it is evaluated below.
model_quantized = model_qat.quantize(bits=8).to('cpu')

In [10]:
# Evaluate the float baseline on the GPU. The third positional argument is
# passed as None; its meaning is defined by `utils.test`, which is not visible
# in this notebook.
test(model, loss_fn, None, test_loader, use_cuda=True)


Test set: Average loss: 0.0005, Accuracy: 9788/10000 (98%)



tensor(97.8800)

In [11]:
# Evaluate the QAT model on the GPU, with the same arguments as the float
# baseline so the two accuracies are directly comparable.
test(model_qat, loss_fn, None, test_loader, use_cuda=True)


Test set: Average loss: 0.0006, Accuracy: 9770/10000 (98%)



tensor(97.7000)

In [12]:
# Evaluate the integer model on the CPU (it was moved there when it was
# created, hence use_cuda=False).
test(model_quantized, loss_fn, None, test_loader, use_cuda=False)


Test set: Average loss: 0.0194, Accuracy: 984/10000 (10%)



tensor(9.8400)

# Overflow: why does the int8 model drop to ~10% accuracy?

In [13]:
def toi(t, i='8'):
    """Cast tensor `t` to a signed integer dtype and return it on the CPU.

    Args:
        t: tensor to convert.
        i: integer width, given either as a string ('8', '16', '32' -- the
            original calling convention) or as an int (8, 16, 32). Defaults
            to '8'.

    Returns:
        `t` cast to torch.int8 / torch.int16 / torch.int32 and moved to the CPU.

    Raises:
        ValueError: if `i` is not one of the supported widths. This is an
            explicit raise rather than `assert 0`, so the check still fires
            when Python runs with assertions disabled (-O) instead of the
            function silently returning None.
    """
    dtype_by_width = {
        '8': torch.int8,
        '16': torch.int16,
        '32': torch.int32,
    }
    # str() lets int and str spellings of the width share a single table.
    width = str(i)
    if width not in dtype_by_width:
        raise ValueError(
            f"unsupported integer width {i!r}; expected one of 8, 16 or 32"
        )
    # Cast first, then move -- same order as before.
    return t.to(dtype_by_width[width]).to('cpu')

In [14]:
# Pull one batch from the test loader to use as a fixed probe input for the
# manual fc1 computations below; it is unpacked as (inputs, labels).
test_inp = next(iter(test_loader))
test_x, test_y = test_inp

In [15]:
# Reproduce fc1's inputs by hand: flatten the batch exactly as Model.forward
# does, and grab the weight and bias of the linear layer stored on the
# QALinear wrapper's `.fc` attribute.
a1 = test_x.view(-1, 28*28)
w1 = model_qat.fc1.fc.weight
b1 = model_qat.fc1.fc.bias

In [16]:
# Run fc1's activation quantizer on the flattened batch. Its result is
# unpacked as (quantized tensor, scale) -- presumably integer-valued floats
# plus a scale factor, judging by how qa1 is cast with toi() and rescaled by
# a1s in the cells below; confirm against utils.QALinear.
qa1, a1s = model_qat.fc1.quantizer_act(a1)
# Multiplying back by the scale gives the dequantized activations.
dqa1 = qa1 * a1s

In [17]:
# Same as the previous cell, but for fc1's weights: the weight quantizer's
# result is unpacked as (quantized tensor, scale).
qw1, w1s = model_qat.fc1.quantizer_weight(w1)
# Multiplying back by the scale gives the dequantized weights.
dqw1 = qw1 * w1s

In [18]:
# int8 dequantized: run the fc1 matmul on int8 casts of the quantized
# activations and weights (toi defaults to int8), then rescale by both scales
# and add the float bias. This is the variant under suspicion: its output
# disagrees with the int16 and fp32 versions computed in the next two cells.
a1s.cpu() * w1s.cpu() * F.linear(toi(qa1), toi(qw1), None) + b1.cpu()

tensor([[ 1.0006, -0.0435, -1.2142,  ...,  1.3935, -0.4836,  1.5288],
        [ 1.4083, -0.6861,  1.2076,  ..., -0.8553,  1.4935,  1.0346],
        [-0.8405, -1.3409, -0.2998,  ...,  1.3194, -1.3609,  1.4423],
        ...,
        [ 0.8152, -0.8467,  0.6516,  ..., -0.3240,  0.4185,  1.0593],
        [ 0.4569, -0.7108,  0.4168,  ..., -1.5473, -0.0017,  0.3055],
        [ 0.5557, -0.5996,  0.5157,  ..., -1.5473,  0.0107, -1.3379]],
       grad_fn=<AddBackward0>)

In [19]:
# int16 dequantized: identical computation, but with the operands cast to
# int16. For this batch the displayed result matches the fp32 reference in the
# following cell.
a1s.cpu() * w1s.cpu() * F.linear(toi(qa1, '16'), toi(qw1, '16'), None) + b1.cpu()

tensor([[ 1.0006e+00, -6.3700e+00, -2.6520e+01,  ..., -1.7586e+01,
          2.6797e+00, -7.9609e+00],
        [-1.7549e+00, -3.8493e+00, -1.4609e+01,  ..., -4.0186e+00,
         -1.6698e+00, -8.4551e+00],
        [-8.4055e-01, -1.3409e+00, -1.6116e+01,  ..., -1.8439e+00,
         -7.6874e+00, -8.0474e+00],
        ...,
        [-1.1838e+01, -4.0099e+00, -1.2001e+01,  ..., -3.4873e+00,
          4.1845e-01, -8.4304e+00],
        [ 4.5687e-01, -7.1077e-01, -1.2236e+01,  ..., -1.1037e+01,
         -1.6653e-03, -6.0209e+00],
        [ 3.7190e+00, -3.7628e+00, -5.8108e+00,  ..., -4.7105e+00,
         -9.4790e+00, -7.6643e+00]], grad_fn=<AddBackward0>)

In [20]:
# fp32 dequantized: reference result. The matmul runs directly on the
# quantizer outputs without any integer cast, then gets the same rescale and
# bias as the integer variants above.
a1s.cpu() * w1s.cpu() * F.linear(qa1.cpu(), qw1.cpu(), None) + b1.cpu()

tensor([[ 1.0006e+00, -6.3700e+00, -2.6520e+01,  ..., -1.7586e+01,
          2.6797e+00, -7.9609e+00],
        [-1.7549e+00, -3.8493e+00, -1.4609e+01,  ..., -4.0186e+00,
         -1.6698e+00, -8.4551e+00],
        [-8.4055e-01, -1.3409e+00, -1.6116e+01,  ..., -1.8439e+00,
         -7.6874e+00, -8.0474e+00],
        ...,
        [-1.1838e+01, -4.0099e+00, -1.2001e+01,  ..., -3.4873e+00,
          4.1845e-01, -8.4304e+00],
        [ 4.5687e-01, -7.1077e-01, -1.2236e+01,  ..., -1.1037e+01,
         -1.6653e-03, -6.0209e+00],
        [ 3.7190e+00, -3.7628e+00, -5.8108e+00,  ..., -4.7105e+00,
         -9.4790e+00, -7.6643e+00]], grad_fn=<AddBackward0>)

In [21]:
# fp32 before dequantization: the raw matmul values, with no scales and no
# bias applied. The displayed entries reach magnitudes in the thousands
# (e.g. 2.1480e+03), far outside the int8 range [-128, 127].
F.linear(qa1.cpu(), qw1.cpu(), bias=None)

tensor([[-8.2000e+01,  5.1700e+02,  2.1480e+03,  ...,  1.4240e+03,
         -2.1600e+02,  6.4500e+02],
        [ 1.4100e+02,  3.1300e+02,  1.1840e+03,  ...,  3.2600e+02,
          1.3600e+02,  6.8500e+02],
        [ 6.7000e+01,  1.1000e+02,  1.3060e+03,  ...,  1.5000e+02,
          6.2300e+02,  6.5200e+02],
        ...,
        [ 9.5700e+02,  3.2600e+02,  9.7300e+02,  ...,  2.8300e+02,
         -3.3000e+01,  6.8300e+02],
        [-3.8000e+01,  5.9000e+01,  9.9200e+02,  ...,  8.9400e+02,
          1.0000e+00,  4.8800e+02],
        [-3.0200e+02,  3.0600e+02,  4.7200e+02,  ...,  3.8200e+02,
          7.6800e+02,  6.2100e+02]], grad_fn=<MmBackward0>)