In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules
# (e.g. the local `main2`) are reloaded before each cell executes.
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from main2 import get_mnist, QALinear, train, test, LinearInt

In [3]:
# Build the two data loaders with main2.get_mnist:
# train=True with batch size 64, train=False with batch size 128.
train_loader = get_mnist(batch_size=64, train=True)
test_loader = get_mnist(batch_size=128, train=False)

In [4]:
class Model(nn.Module):
    """Two-layer fully connected classifier built from QALinear layers.

    The input is flattened to 28*28 features, passed through ``fc1``
    (28*28 -> 256) followed by ReLU, then through ``fc2`` (256 -> 10).

    Args:
        bits: Value forwarded as ``bit=`` to both QALinear layers.
    """

    def __init__(self, bits=8):
        super().__init__()
        self.fc1 = QALinear(28 * 28, 256, bit=bits)
        # fc2 consumes the ReLU output of fc1, hence only_positive_activations=True.
        self.fc2 = QALinear(256, 10, bit=bits, only_positive_activations=True)

    def forward(self, x):
        """Flatten ``x`` to (-1, 28*28) and return the output of ``fc2``."""
        flat = x.view(-1, 28 * 28)
        hidden = torch.relu(self.fc1(flat))
        return self.fc2(hidden)

    def quantize(self, int_dtype):
        """Return a new Model whose layers are LinearInt versions of this one's.

        A fresh ``Model()`` is created with default arguments, then ``fc1`` and
        ``fc2`` are replaced by ``LinearInt.from_quantized(layer, int_dtype)``
        built from this model's corresponding layers.

        Args:
            int_dtype: Passed through unchanged to ``LinearInt.from_quantized``.
        """
        new_model = Model()
        for layer_name in ("fc1", "fc2"):
            source_layer = getattr(self, layer_name)
            setattr(
                new_model,
                layer_name,
                LinearInt.from_quantized(source_layer, int_dtype),
            )
        return new_model

In [5]:
# Instantiate the model with its default bits=8 and move it to the GPU (requires CUDA).
model = Model().cuda()

In [6]:
# Train `model` on `train_loader` for `epochs` passes using AdamW and cross-entropy loss.
epochs = 1
# NOTE(review): `criterion` is not referenced by any cell visible here;
# `loss_fn` is the object actually handed to train()/test().
criterion = nn.CrossEntropyLoss()
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-3)

# Epoch numbers passed to train() are 1-based.
for epoch in range(1, epochs + 1):
    train(model, epoch, loss_fn, optimizer, train_loader, use_cuda=True, log_interval=500)



In [7]:
# Evaluate the trained model on the test loader, on the GPU.
# NOTE(review): the third positional argument is passed as None; its meaning
# is defined by main2.test — confirm there.
test(model, loss_fn, None, test_loader, use_cuda=True)


Test set: Average loss: 0.0010, Accuracy: 9616/10000 (96%)



tensor(96.1600)

In [8]:
# Convert both layers to LinearInt using int16, move the result to the CPU,
# and evaluate it on the same test loader.
quantized_model = model.quantize(torch.int16).to('cpu')
test(quantized_model, loss_fn, None, test_loader, use_cuda=False)


Test set: Average loss: 0.0040, Accuracy: 8984/10000 (90%)



tensor(89.8400)

In [9]:
# Same conversion with int8. NOTE: this rebinds `quantized_model`, replacing
# the int16 model created in the previous cell.
# The recorded output of this cell shows accuracy falling to about 10%.
quantized_model = model.quantize(torch.int8).to('cpu')
test(quantized_model, loss_fn, None, test_loader, use_cuda=False)


Test set: Average loss: 0.0182, Accuracy: 992/10000 (10%)



tensor(9.9200)

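# OVERFLOW

The int8 model drops to chance-level accuracy (about 10%), while the int16 model stays near 90%. The cells below recompute the first layer by hand to show why: the integer matrix product of the quantized activations and weights reaches values far outside the int8 range, so the int8 result overflows, whereas the int16 result matches the fp32 reference.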

In [10]:
def toi(t, i='8'):
    """Cast tensor ``t`` to a signed integer dtype and move it to the CPU.

    Args:
        t: Tensor to convert.
        i: Bit width of the target dtype: ``'8'``, ``'16'`` or ``'32'``
            (the integers ``8``, ``16`` and ``32`` are accepted as well).

    Returns:
        ``t`` converted to ``torch.int8`` / ``torch.int16`` / ``torch.int32``,
        located on the CPU.

    Raises:
        ValueError: If ``i`` is not one of the supported widths.
    """
    dtype_by_width = {
        '8': torch.int8,
        '16': torch.int16,
        '32': torch.int32,
    }
    width = str(i)
    if width not in dtype_by_width:
        # An explicit exception instead of `assert 0`: asserts are removed under
        # `python -O`, which would let the function silently return None.
        raise ValueError(
            f"unsupported integer width {i!r}; expected one of '8', '16', '32'"
        )
    return t.to(dtype_by_width[width]).to('cpu')

In [11]:
# Take the first batch yielded by the test loader and unpack it into its
# two components (named here as inputs and labels).
test_inp = next(iter(test_loader))
test_x, test_y = test_inp

In [12]:
# Operands of the first layer: the batch flattened to (-1, 28*28), plus the
# weight and bias of the `fc` module wrapped by model.fc1.
a1 = test_x.view(-1, 28*28)
w1 = model.fc1.fc.weight
b1 = model.fc1.fc.bias

In [13]:
# Run fc1's activation quantizer on the flattened inputs; it returns two
# values, used here as (quantized values, scale).
qa1, a1s = model.fc1.quantizer_act(a1)
# Dequantized activations: quantized values multiplied by their scale.
dqa1 = qa1 * a1s

In [14]:
# Run fc1's weight quantizer on the layer weights; it returns two values,
# used here as (quantized values, scale).
qw1, w1s = model.fc1.quantizer_weigh(w1)
# Dequantized weights: quantized values multiplied by their scale.
dqw1 = qw1 * w1s

In [15]:
# int8 dequantized: cast both quantized operands to int8, compute the linear
# product in that dtype, then rescale by both scales and add the float bias.
# The recorded output differs from the int16/fp32 results in most entries.
a1s.cpu() * w1s.cpu() * F.linear(toi(qa1), toi(qw1), None) + b1.cpu()

tensor([[ 0.0427,  0.7478, -0.2351,  ...,  0.0679,  0.5372, -0.5603],
        [ 0.4999,  0.0721, -0.3170,  ..., -0.3757,  0.2028, -0.6081],
        [-0.1825,  0.6113,  0.5771,  ..., -0.6965, -0.8550,  0.6886],
        ...,
        [-0.1143, -0.2828,  0.6385,  ..., -0.4235, -0.8482, -0.5603],
        [ 0.5887, -0.6649,  0.5634,  ...,  0.8186, -0.3841, -0.6763],
        [ 0.6978, -0.8697,  0.7613,  ..., -0.5736,  0.5031, -0.1303]],
       grad_fn=<AddBackward0>)

In [16]:
# int16 dequantized: the same computation with both operands cast to int16.
# The recorded output matches the fp32 reference result.
a1s.cpu() * w1s.cpu() * F.linear(toi(qa1, '16'), toi(qw1, '16'), None) + b1.cpu()

tensor([[-3.4515,  2.4949, -0.2351,  ..., -3.4263, -1.2099,  8.1752],
        [-6.4885, -1.6750, -0.3170,  ...,  4.8656, -1.5443, 11.6217],
        [-3.6767, -1.1358, -2.9171,  ...,  1.0506, -2.6021,  9.4241],
        ...,
        [-3.6085, -5.5241, -6.3499,  ...,  1.3236, -9.5837, 11.6694],
        [-2.9056,  6.3235,  2.3105,  ..., -2.6756,  3.1101,  4.5650],
        [-4.5435, -0.8697,  0.7613,  ...,  4.6677, -2.9911,  8.6052]],
       grad_fn=<AddBackward0>)

In [17]:
# fp32 dequantized: reference computation using the quantized values as
# returned by the quantizers (moved to the CPU, no integer cast), then
# rescaled by both scales with the bias added.
a1s.cpu() * w1s.cpu() * F.linear(qa1.cpu(), qw1.cpu(), None) + b1.cpu()

tensor([[-3.4515,  2.4949, -0.2351,  ..., -3.4263, -1.2099,  8.1752],
        [-6.4885, -1.6750, -0.3170,  ...,  4.8656, -1.5443, 11.6217],
        [-3.6767, -1.1358, -2.9171,  ...,  1.0506, -2.6021,  9.4241],
        ...,
        [-3.6085, -5.5241, -6.3499,  ...,  1.3236, -9.5837, 11.6694],
        [-2.9056,  6.3235,  2.3105,  ..., -2.6756,  3.1101,  4.5650],
        [-4.5435, -0.8697,  0.7613,  ...,  4.6677, -2.9911,  8.6052]],
       grad_fn=<AddBackward0>)

In [18]:
# fp32 before dequantization: the raw linear product of the quantized
# operands, without scales or bias. The recorded values exceed 1000 in
# magnitude, well outside the int8 range of [-128, 127].
F.linear(qa1.cpu(), qw1.cpu(), bias=None)

tensor([[  510.,  -368.,    32.,  ...,   505.,   179., -1192.],
        [  955.,   243.,    44.,  ...,  -710.,   228., -1697.],
        [  543.,   164.,   425.,  ...,  -151.,   383., -1375.],
        ...,
        [  533.,   807.,   928.,  ...,  -191.,  1406., -1704.],
        [  430.,  -929.,  -341.,  ...,   395.,  -454.,  -663.],
        [  670.,   125.,  -114.,  ...,  -681.,   440., -1255.]],
       grad_fn=<MmBackward0>)