In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from utils import get_mnist, QALinear, train, test, LinearInt

In [3]:
# MNIST batch iterators built by the project-local `get_mnist` helper:
# batches of 64 for the training split, batches of 128 for the test split.
train_loader = get_mnist(train=True, batch_size=64)
test_loader = get_mnist(train=False, batch_size=128)

In [11]:
# Supported quantization bit widths and the torch integer dtype selected for each.
bits_to_dtype = {
    8: torch.int8,
    16: torch.int16,
    32: torch.int32,
}


class Model(nn.Module):
    """Two-layer fully connected classifier for flattened 28x28 inputs.

    Besides the plain float forward pass, the class can produce two derived
    models with the same layout: a quantization-aware-training (QAT) copy via
    `to_qat` and an integer copy via `quantize`.

    Args:
        hidden_features: Width of the hidden layer (output size of ``fc1``).
    """

    def __init__(self, hidden_features: int = 256):
        super().__init__()
        # Remembered so derived models are rebuilt with the same hidden width
        # instead of silently falling back to the default of 256.
        self.hidden_features = hidden_features
        self.fc1 = nn.Linear(28 * 28, hidden_features)
        self.fc2 = nn.Linear(hidden_features, 10)

    def forward(self, x):
        """Flatten ``x`` to rows of 784 values and return the 10 output scores per row."""
        x = x.view(-1, 28 * 28)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

    def to_qat(self, bits: int) -> "Model":
        """Return a new model whose layers are `QALinear` wrappers of this model's layers.

        Args:
            bits: Bit width forwarded to ``QALinear.from_linear``.

        Returns:
            A new ``Model`` with the same ``hidden_features``; ``self`` is not
            reassigned or replaced.
        """
        new_model = Model(hidden_features=self.hidden_features)
        new_model.fc1 = QALinear.from_linear(self.fc1, bits)
        # fc2 consumes the ReLU output of fc1, which is never negative;
        # presumably that is why only_positive_activations is set here —
        # confirm against QALinear in utils.
        new_model.fc2 = QALinear.from_linear(self.fc2, bits, only_positive_activations=True)
        return new_model

    def quantize(self, bits: int) -> "Model":
        """Return a new model whose layers are `LinearInt` layers built from QAT layers.

        Must be called on a model produced by `to_qat`.

        Args:
            bits: One of the keys of ``bits_to_dtype`` (8, 16 or 32).

        Returns:
            A new ``Model`` with the same ``hidden_features``.

        Raises:
            ValueError: If ``bits`` is not a supported bit width.
            TypeError: If ``fc1`` / ``fc2`` are not ``QALinear`` layers.
        """
        if bits not in bits_to_dtype:
            raise ValueError(
                f"unsupported bit width {bits!r}; expected one of {sorted(bits_to_dtype)}"
            )
        # Explicit check instead of `assert`, which disappears under `python -O`.
        if not (isinstance(self.fc1, QALinear) and isinstance(self.fc2, QALinear)):
            raise TypeError("quantize() requires a QAT model; call to_qat() first")
        int_dtype = bits_to_dtype[bits]
        new_model = Model(hidden_features=self.hidden_features)
        # NOTE(review): how LinearInt uses `int_dtype` (operands only, or also the
        # accumulator) is not visible here — verify in utils, since a narrow
        # accumulator can overflow.
        new_model.fc1 = LinearInt.from_qat(self.fc1, int_dtype)
        new_model.fc2 = LinearInt.from_qat(self.fc2, int_dtype)
        return new_model

In [5]:
# Float baseline with a small hidden layer (16 units instead of the default 256),
# moved to the GPU for training.
model = Model(hidden_features=16)
model = model.cuda()

In [6]:
# Train the float baseline for a single epoch with AdamW and cross-entropy loss.
epochs = 1
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-3)

# Epochs are numbered from 1 and the number is passed on to the `train` helper.
for epoch in range(1, epochs + 1):
    train(model, epoch, loss_fn, optimizer, train_loader,
          use_cuda=True, log_interval=500)



In [7]:
# Evaluate the float baseline on the test split; this is the reference
# accuracy that the QAT and quantized models are compared against below.
test(model, loss_fn, None, test_loader, use_cuda=True)


Test set: Average loss: 0.0022, Accuracy: 9163/10000 (92%)



tensor(91.6300)

In [12]:
# Derive an 8-bit QAT model from the trained float baseline. `to_qat` returns a
# new Model whose fc1/fc2 are QALinear layers built from the baseline's layers.
model_qat = model.to_qat(bits=8)

In [None]:
# Fine-tune the QAT model for three epochs; only its parameters are optimized.
epochs = 3
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model_qat.parameters(), lr=1e-3)

for epoch in range(1, epochs + 1):
    train(model_qat, epoch, loss_fn, optimizer, train_loader,
          use_cuda=True, log_interval=500)

    # After every epoch three results are printed, in this order:
    #   1. the float baseline — it is not part of this optimizer, so its
    #      numbers repeat unchanged and serve as a fixed reference;
    test(model, loss_fn, None, test_loader, use_cuda=True)
    #   2. the QAT model in its current state;
    test(model_qat, loss_fn, None, test_loader, use_cuda=True)
    #   3. the 8-bit integer model derived from the current QAT model,
    #      evaluated on the CPU.
    model_quantized = model_qat.quantize(bits=8)
    model_quantized = model_quantized.to('cpu')
    test(model_quantized, loss_fn, None, test_loader, use_cuda=False)


Test set: Average loss: 0.0022, Accuracy: 9163/10000 (92%)


Test set: Average loss: 0.0017, Accuracy: 9332/10000 (93%)


Test set: Average loss: 0.0017, Accuracy: 9339/10000 (93%)


Test set: Average loss: 0.0022, Accuracy: 9163/10000 (92%)


Test set: Average loss: 0.0016, Accuracy: 9386/10000 (94%)


Test set: Average loss: 0.0016, Accuracy: 9369/10000 (94%)


Test set: Average loss: 0.0022, Accuracy: 9163/10000 (92%)


Test set: Average loss: 0.0014, Accuracy: 9492/10000 (95%)


Test set: Average loss: 0.0041, Accuracy: 9077/10000 (91%)



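# Overflow

After the third QAT epoch the int8 model reaches only 91% accuracy, while the QAT model it was derived from reaches 95%. The cells below recompute the `fc1` matmul for one test batch in int8, int16 and fp32 to check whether the int8 computation overflows.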

In [71]:
def toi(t, i='8'):
    """Cast tensor ``t`` to a signed integer dtype and move the result to the CPU.

    Args:
        t: Tensor to convert.
        i: Bit width of the target dtype, as a string (``'8'``, ``'16'``,
            ``'32'``) or as the equivalent integer.

    Returns:
        ``t`` converted to ``torch.int8`` / ``torch.int16`` / ``torch.int32``
        and placed on the CPU.

    Raises:
        ValueError: If ``i`` is not one of the supported bit widths.
    """
    dtypes = {8: torch.int8, 16: torch.int16, 32: torch.int32}
    try:
        dtype = dtypes[int(i)]
    except (KeyError, ValueError, TypeError):
        # An explicit error instead of `assert 0`, which is removed under
        # `python -O` and would let the function return None silently.
        raise ValueError(
            f"unsupported bit width {i!r}; expected one of 8, 16, 32"
        ) from None
    # Cast first, then move: the same order as the original implementation.
    return t.to(dtype).to('cpu')

In [72]:
# Take the first batch of the test loader; keep both the raw pair and its
# unpacked inputs / labels.
test_x, test_y = test_inp = next(iter(test_loader))

In [73]:
# Operands of the first layer: the flattened input batch plus the weight and
# bias of the linear layer wrapped inside the QAT fc1.
a1 = test_x.view(-1, 28 * 28)
w1, b1 = model_qat.fc1.fc.weight, model_qat.fc1.fc.bias

In [74]:
# Run fc1's activation quantizer on the input batch: it returns the quantized
# values together with their scale. Multiplying the two gives the dequantized
# activations.
qa1, a1s = model_qat.fc1.quantizer_act(a1)
dqa1 = torch.mul(qa1, a1s)

In [75]:
# Same for the weights: fc1's weight quantizer returns the quantized values
# and their scale; their product is the dequantized weight matrix.
qw1, w1s = model_qat.fc1.quantizer_weight(w1)
dqw1 = torch.mul(qw1, w1s)

In [76]:
# int8 dequantized: the fc1 matmul is evaluated on operands cast to int8, then
# rescaled by the activation and weight scales and shifted by the float bias.
# The values displayed for this cell disagree with the int16 and fp32 versions
# of the same expression in the next cells, which is the overflow under study.
a1s.cpu() * w1s.cpu() * F.linear(toi(qa1), toi(qw1), None) + b1.cpu()

tensor([[ 0.0990,  0.0036, -0.1174,  ..., -0.0033, -0.1117, -0.0286],
        [ 0.2424,  0.0515,  0.1097,  ...,  0.1701,  0.0537, -0.2258],
        [ 0.0850, -0.1777,  0.2692,  ..., -0.1787, -0.2751, -0.2159],
        ...,
        [ 0.1189,  0.0893, -0.0676,  ..., -0.1268,  0.1334,  0.2544],
        [ 0.0890, -0.0581,  0.1954,  ...,  0.0266,  0.1971,  0.1707],
        [-0.0465,  0.1730,  0.1974,  ...,  0.2657, -0.1994, -0.0664]],
       grad_fn=<AddBackward0>)

In [77]:
# int16 dequantized: same expression as the int8 cell, but with the operands
# cast to int16. The displayed values agree with the fp32 reference in the
# next cell.
a1s.cpu() * w1s.cpu() * F.linear(toi(qa1, '16'), toi(qw1, '16'), None) + b1.cpu()

tensor([[ 31.2156,  -6.1177,  -6.2387,  ...,  12.2393,   7.5399,   6.6028],
        [ -4.8586,  17.3952,  10.8220,  ...,   2.2105, -18.8204,   0.7944],
        [  6.2063,  14.6155,  16.5926,  ...,  13.5943,  -5.3762,  -2.7664],
        ...,
        [ 12.3615,  -3.4814,  -1.5979,  ...,  14.1562,   9.8254,   3.8251],
        [ 16.4125,   0.4520,   2.2359,  ...,  13.2894,  -0.3130,  10.3729],
        [ 10.1557,  -2.3775,  -3.3733,  ...,  10.9780,  -4.2803,   2.4841]],
       grad_fn=<AddBackward0>)

In [78]:
# fp32 dequantized: reference result. The quantized activations and weights
# are multiplied without any integer cast, then rescaled and shifted by the
# bias exactly as in the integer cells.
a1s.cpu() * w1s.cpu() * F.linear(qa1.cpu(), qw1.cpu(), None) + b1.cpu()

tensor([[ 31.2156,  -6.1177,  -6.2387,  ...,  12.2393,   7.5399,   6.6028],
        [ -4.8586,  17.3952,  10.8220,  ...,   2.2105, -18.8204,   0.7944],
        [  6.2063,  14.6155,  16.5926,  ...,  13.5943,  -5.3762,  -2.7664],
        ...,
        [ 12.3615,  -3.4814,  -1.5979,  ...,  14.1562,   9.8254,   3.8251],
        [ 16.4125,   0.4520,   2.2359,  ...,  13.2894,  -0.3130,  10.3729],
        [ 10.1557,  -2.3775,  -3.3733,  ...,  10.9780,  -4.2803,   2.4841]],
       grad_fn=<AddBackward0>)

In [79]:
# fp32 before dequantization: the raw matmul of the quantized operands, i.e.
# the values an integer accumulator would have to hold. The displayed entries
# reach magnitudes in the thousands, far outside the int8 range of [-128, 127].
F.linear(qa1.cpu(), qw1.cpu(), bias=None)

tensor([[-15630.,   3071.,   3177.,  ...,  -6115.,  -3809.,  -3311.],
        [  2474.,  -8729.,  -5385.,  ...,  -1082.,   9420.,   -396.],
        [ -3079.,  -7334.,  -8281.,  ...,  -6795.,   2673.,   1391.],
        ...,
        [ -6168.,   1748.,    848.,  ...,  -7077.,  -4956.,  -1917.],
        [ -8201.,   -226.,  -1076.,  ...,  -6642.,    132.,  -5203.],
        [ -5061.,   1194.,   1739.,  ...,  -5482.,   2123.,  -1244.]],
       grad_fn=<MmBackward0>)