In [11]:
import torch

In [12]:
class MyModel(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(3, 3, bias=False)
        self.relu = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(3, 1, bias=False)

    def forward(self, x):
        x = self.linear2(self.relu(self.linear1(x)))
        return x

In [13]:
weights = torch.tensor([[1.1], [2.2], [3.3]])
torch.manual_seed(123)
training_features = torch.randn(12000, 3)
training_labels = torch.matmul(training_features, weights)

torch.manual_seed(123)
test_features = torch.randn(1000, 3)
test_labels = torch.matmul(test_features, weights)

In [14]:
model = MyModel()
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
for epoch in range(100):
    preds = model(training_features)
    loss = torch.nn.functional.mse_loss(preds, training_labels)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

In [15]:
model.eval()
with torch.no_grad():
    preds = model(test_features)
    mse = torch.nn.functional.mse_loss(preds, test_labels)
    print(f"float32 model testing loss: {mse.item():.3f}")

float32 model testing loss: 0.199


In [16]:
model_int8 = torch.ao.quantization.quantize_dynamic(model, {torch.nn.Linear},
                                                    dtype=torch.qint8)
with torch.no_grad():
    preds = model_int8(test_features)
    mse = torch.nn.functional.mse_loss(preds, test_labels)
    print(f"int8 model testing loss: {mse.item():.3f}")

int8 model testing loss: 0.212


In [17]:
model.linear1.weight

Parameter containing:
tensor([[-0.8265, -1.9637, -3.6101],
        [ 0.1484,  0.5460,  1.7600],
        [ 0.9747,  1.7071,  1.6407]], requires_grad=True)

In [18]:
torch.int_repr(model_int8.linear1.weight())

tensor([[ -29,  -69, -127],
        [   5,   19,   62],
        [  34,   60,   58]], dtype=torch.int8)

In [19]:
model_int8.linear1.weight()

tensor([[-0.8211, -1.9537, -3.5960],
        [ 0.1416,  0.5380,  1.7555],
        [ 0.9627,  1.6989,  1.6423]], size=(3, 3), dtype=torch.qint8,
       quantization_scheme=torch.per_tensor_affine, scale=0.028314679861068726,
       zero_point=0)