In [5]:
import torch

In [6]:
class MyModel(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.quant = torch.ao.quantization.QuantStub()
        self.linear1 = torch.nn.Linear(3, 3, bias=False)
        self.relu = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(3, 1, bias=False)
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x):
        q_inputs = self.quant(x)
        outputs = self.linear2(self.relu(self.linear1(q_inputs)))
        f_outputs = self.dequant(outputs)
        return f_outputs

In [7]:
weights = torch.tensor([[1.1], [2.2], [3.3]])
torch.manual_seed(123)
training_features = torch.randn(12000, 3)
training_labels = training_features @ weights

torch.manual_seed(123)
test_features = torch.randn(1000, 3)
test_labels = test_features @ weights

In [8]:
model = MyModel()
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
for epoch in range(100):
    preds = model(training_features)
    loss = torch.nn.functional.mse_loss(preds, training_labels)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

In [9]:
model.eval()
with torch.no_grad():
    preds = model(test_features)
    mse = torch.nn.functional.mse_loss(preds, test_labels)
    print(f"float32 model testing loss: {mse.item():.3f}")

float32 model testing loss: 0.199


In [10]:
model.qconfig = torch.ao.quantization.get_default_qconfig('x86')
model_prepared = torch.ao.quantization.prepare(model)
model_prepared(test_features)



tensor([[-1.1362e+00],
        [ 2.0354e+00],
        [ 1.1349e+00],
        [-4.5833e+00],
        [ 5.9918e+00],
        [-3.5204e+00],
        [ 6.5489e+00],
        [ 3.8726e+00],
        [ 5.6379e-01],
        [-1.9973e+00],
        [-1.0604e+01],
        [ 2.0177e-01],
        [-3.8567e+00],
        [ 5.2840e+00],
        [ 6.4200e-01],
        [-1.6240e+00],
        [-2.5490e-02],
        [ 2.4547e+00],
        [ 8.2064e+00],
        [ 6.6574e-01],
        [-1.4114e+00],
        [ 4.3773e+00],
        [-2.4360e+00],
        [ 5.4590e-01],
        [-1.1498e+00],
        [ 6.8148e-02],
        [-1.5154e+00],
        [ 4.0666e+00],
        [-3.6973e-02],
        [-8.2602e+00],
        [-2.7473e+00],
        [-4.2103e+00],
        [-1.1666e+01],
        [-8.5860e-01],
        [ 6.8589e+00],
        [-8.1069e-01],
        [-5.1303e+00],
        [-2.8813e+00],
        [ 2.0469e+00],
        [-3.2249e+00],
        [-7.8458e-01],
        [ 5.0872e+00],
        [-6.2163e+00],
        [-6

In [17]:
model_int8 = torch.ao.quantization.convert(model_prepared)
with torch.no_grad():
    preds = model_int8(test_features)
    mse = torch.nn.functional.mse_loss(preds, test_labels)
    print(f"int8 model testing loss: {mse.item():.3f}")

int8 model testing loss: 0.219


In [18]:
model.linear1.weight

Parameter containing:
tensor([[-0.8265, -1.9637, -3.6101],
        [ 0.1484,  0.5460,  1.7600],
        [ 0.9747,  1.7071,  1.6407]], requires_grad=True)

In [19]:
torch.int_repr(
    model_int8.linear1.weight()
)

tensor([[ -29,  -69, -127],
        [  11,   40,  127],
        [  73,  127,  123]], dtype=torch.int8)

In [20]:
model_int8.linear1.weight()

tensor([[-0.8211, -1.9537, -3.5960],
        [ 0.1518,  0.5522,  1.7531],
        [ 0.9774,  1.7004,  1.6468]], size=(3, 3), dtype=torch.qint8,
       quantization_scheme=torch.per_channel_affine,
       scale=tensor([0.0283, 0.0138, 0.0134], dtype=torch.float64),
       zero_point=tensor([0, 0, 0]), axis=0)