In [25]:
import importlib
import utils

# Reload utils module
importlib.reload(utils)

# import 
from utils import *

In [26]:
model = ResNet18().to(device)
model.load_state_dict(torch.load("resnet18_float32.pth"))

<All keys matched successfully>

In [27]:
class STEQuant(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scale):
        s = scale
        q = torch.clamp(torch.round(x/s), -128, 127)
        return q * s
    
    @staticmethod
    def backward(ctx, grad_output):
        return grad_output, None
    
class LSQFakeQuant(nn.Module):
    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.tensor(0.1)) # learnable, dummy value
        self.register_buffer('initialized', torch.tensor(False))
    
    def forward(self, x):
        if self.training:
            if not self.initialized:
                Qp = 127.0
                mean_abs = x.detach().abs().mean()
                init_scale = 2*mean_abs/ (Qp ** 0.5)
                self.scale.data.copy_(init_scale)
                self.initialized.fill_(True)

            return STEQuant.apply(x, self.scale)
        return x
    
# LSQ QAT Model
class LSQQATModel(nn.Module):
    def __init__(self, base_model):
        super().__init__()
        self.base = base_model

        # put LSQFakeQuant on main activation
        self.fq_conv1 = LSQFakeQuant()
        self.fq_layer1 = LSQFakeQuant()
        self.fq_layer2 = LSQFakeQuant()
        self.fq_layer3 = LSQFakeQuant()
        self.fq_layer4 = LSQFakeQuant()

    def forward(self, x):
        x = self.base.conv1(x)
        x = self.base.bn1(x)
        x = self.base.relu(x)
        x = self.fq_conv1(x)

        x = self.base.layer1(x)
        x = self.fq_layer1(x)
        x = self.base.layer2(x)
        x = self.fq_layer2(x)
        x = self.base.layer3(x)
        x = self.fq_layer3(x)
        x = self.base.layer4(x)
        x = self.fq_layer4(x)

        x = self.base.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.base.fc(x)
        return x

In [28]:
qat_model = LSQQATModel(model).to(device)

In [29]:
# Weight learnable scale
weight_scale_params = {}
Qp= 127.0
for name, p in qat_model.named_parameters():
    if 'weight' in name and p.dim() > 1:
        original_name = name.replace('base.', '')
        mean_abs = p.detach().abs().abs().mean().item()
        init_scale = 2 * mean_abs / (Qp ** 0.5)
        weight_scale_params[original_name] = nn.Parameter(torch.tensor(init_scale))

# LSQ weight quantization function
def apply_lsq_weight_quant(model, scale_params):
    for name, p in model.named_parameters():
        if 'weight' in name and p.dim() > 1:
            original_name = name.replace('base.', '')
            s = scale_params[original_name]
            p.data = STEQuant.apply(p.data, s)

apply_lsq_weight_quant(qat_model, weight_scale_params)

In [30]:
# Optimizer
optimizer = optim.Adam(list(qat_model.parameters()) + list(weight_scale_params.values()), lr = 1e-4)
criterion = nn.CrossEntropyLoss()
epochs = 90
for epoch in range(epochs):
    qat_model.train()
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        loss = criterion(qat_model(x), y)
        loss.backward()
        optimizer.step()
        apply_lsq_weight_quant(qat_model, weight_scale_params)
    
    print(f"Epoch {epoch+1} Acc: {evaluate(qat_model):.2f}%")

Epoch 1 Acc: 29.23%
Epoch 2 Acc: 47.53%
Epoch 3 Acc: 68.19%
Epoch 4 Acc: 79.40%
Epoch 5 Acc: 87.80%
Epoch 6 Acc: 91.35%
Epoch 7 Acc: 93.97%
Epoch 8 Acc: 95.18%
Epoch 9 Acc: 95.98%
Epoch 10 Acc: 96.35%
Epoch 11 Acc: 96.88%
Epoch 12 Acc: 97.22%
Epoch 13 Acc: 97.39%
Epoch 14 Acc: 97.63%
Epoch 15 Acc: 97.81%
Epoch 16 Acc: 97.97%
Epoch 17 Acc: 98.07%
Epoch 18 Acc: 98.15%
Epoch 19 Acc: 98.29%
Epoch 20 Acc: 98.32%
Epoch 21 Acc: 98.42%
Epoch 22 Acc: 98.49%
Epoch 23 Acc: 98.52%
Epoch 24 Acc: 98.60%
Epoch 25 Acc: 98.69%
Epoch 26 Acc: 98.70%
Epoch 27 Acc: 98.81%
Epoch 28 Acc: 98.83%
Epoch 29 Acc: 98.82%
Epoch 30 Acc: 98.87%
Epoch 31 Acc: 98.92%
Epoch 32 Acc: 98.91%
Epoch 33 Acc: 98.99%
Epoch 34 Acc: 98.94%
Epoch 35 Acc: 99.03%
Epoch 36 Acc: 99.12%
Epoch 37 Acc: 99.10%
Epoch 38 Acc: 99.10%
Epoch 39 Acc: 99.16%
Epoch 40 Acc: 99.15%
Epoch 41 Acc: 99.17%
Epoch 42 Acc: 99.18%
Epoch 43 Acc: 99.16%
Epoch 44 Acc: 99.19%
Epoch 45 Acc: 99.18%
Epoch 46 Acc: 99.22%
Epoch 47 Acc: 99.19%
Epoch 48 Acc: 99.21%
E