In [81]:
import importlib
import utils

# Reload utils module
importlib.reload(utils)

# import 
from utils import *

In [82]:
model = ResNet18().to(device)
model.load_state_dict(torch.load("resnet18_float32.pth"))

<All keys matched successfully>

In [83]:
class STEQuant(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scale):
        s = scale
        q = torch.clamp(torch.round(x/s), -128, 127)
        return q * s
    
    @staticmethod
    def backward(ctx, grad_output):
        return grad_output, None
    
class LSQFakeQuant(nn.Module):
    def __init__(self, init_scale):
        super().__init__()
        self.scale = nn.Parameter(torch.tensor(init_scale)) # learnable
    
    def forward(self, x):
        if self.training:
            return STEQuant.apply(x, self.scale)
        return x
    
# LSQ QAT Model
class LSQQATModel(nn.Module):
    def __init__(self, base_model):
        super().__init__()
        self.base = base_model

        # put LSQFakeQuant on main activation
        self.fq_conv1 = LSQFakeQuant(init_scale=0.07)
        self.fq_layer1 = LSQFakeQuant(init_scale=0.07)
        self.fq_layer2 = LSQFakeQuant(init_scale=0.07)
        self.fq_layer3 = LSQFakeQuant(init_scale=0.07)
        self.fq_layer4 = LSQFakeQuant(init_scale=0.07)

    def forward(self, x):
        x = self.base.conv1(x)
        x = self.base.bn1(x)
        x = self.base.relu(x)
        x = self.fq_conv1(x)

        x = self.base.layer1(x)
        x = self.fq_layer1(x)
        x = self.base.layer2(x)
        x = self.fq_layer2(x)
        x = self.base.layer3(x)
        x = self.fq_layer3(x)
        x = self.base.layer4(x)
        x = self.fq_layer4(x)

        x = self.base.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.base.fc(x)
        return x

In [84]:
qat_model = LSQQATModel(model).to(device)

In [85]:
# Weight learnable scale
weight_scale_params = {}
for name, p in qat_model.named_parameters():
    if 'weight' in name and p.dim() > 1:
        weight_scale_params[name.replace('base.', '')] = nn.Parameter(torch.tensor(0.07))

# LSQ weight quantization function
def apply_lsq_weight_quant(model, scale_params):
    for name, p in model.named_parameters():
        if 'weight' in name and p.dim() > 1:
            original_name = name.replace('base.', '')
            s = scale_params[original_name]
            p.data = STEQuant.apply(p.data, s)

apply_lsq_weight_quant(qat_model, weight_scale_params)

In [86]:
# Optimizer
optimizer = optim.Adam(list(qat_model.parameters()) + list(weight_scale_params.values()), lr = 1e-4)
criterion = nn.CrossEntropyLoss()
epochs = 30
for epoch in range(epochs):
    qat_model.train()
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        loss = criterion(qat_model(x), y)
        loss.backward()
        optimizer.step()
        apply_lsq_weight_quant(qat_model, weight_scale_params)
    
    print(f"Epoch {epoch+1} Acc: {evaluate(qat_model):.2f}%")

Epoch 1 Acc: 33.99%
Epoch 2 Acc: 59.91%
Epoch 3 Acc: 76.78%
Epoch 4 Acc: 84.64%
Epoch 5 Acc: 91.44%
Epoch 6 Acc: 94.14%
Epoch 7 Acc: 95.30%
Epoch 8 Acc: 95.88%
Epoch 9 Acc: 96.32%
Epoch 10 Acc: 96.55%
Epoch 11 Acc: 96.91%
Epoch 12 Acc: 97.16%
Epoch 13 Acc: 97.36%
Epoch 14 Acc: 97.56%
Epoch 15 Acc: 97.79%
Epoch 16 Acc: 97.91%
Epoch 17 Acc: 98.00%
Epoch 18 Acc: 98.04%
Epoch 19 Acc: 98.15%
Epoch 20 Acc: 98.15%
Epoch 21 Acc: 98.12%
Epoch 22 Acc: 98.08%
Epoch 23 Acc: 98.19%
Epoch 24 Acc: 98.39%
Epoch 25 Acc: 98.48%
Epoch 26 Acc: 98.35%
Epoch 27 Acc: 98.53%
Epoch 28 Acc: 98.57%
Epoch 29 Acc: 98.61%
Epoch 30 Acc: 98.69%
