In [112]:
import importlib
import utils

# Reload utils module
importlib.reload(utils)

# import 
from utils import *

In [113]:
# MinMaxObserver (for calibration)
class MinMaxObserver:
    def __init__(self):
        self.min = float('inf')
        self.max = float('-inf')
    def __call__(self, x):
        self.min = min(self.min, x.min().item())
        self.max = max(self.max, x.max().item())
    def get_scale(self):
        r = max(abs(self.min), self.max)
        return r / 127

In [114]:
# Build STE
class STEQuant(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scale, q_min=-128, q_max=127):
        if isinstance(scale, float):
            s = torch.tensor(scale, device=x.device).clamp(min=1e-8)
        else:
            s = scale.clamp(min=1e-8)
        q = torch.clamp(torch.round(x/s), q_min, q_max)
        return q * s # dequant return -> propagate noise to next layer
    
    @staticmethod
    def backward(ctx, grad_output):
        # Backward: STE - ignore quant, propagate 1
        return grad_output, None 

In [115]:
# Vanilla FakeQuant (fixed scale)
class VanillaFakeQuant(nn.Module):
    def __init__(self, scale):
        super().__init__()
        self.scale = torch.tensor(scale) # fixed

    def forward(self, x):
        if self.training:
            return STEQuant.apply(x, self.scale)
        return x

In [116]:
model = ResNet18().to(device)
model.load_state_dict(torch.load("resnet18_float32.pth"))

<All keys matched successfully>

In [117]:
# Calibration (activation + weight scale)
observers = {}
hooks = []
count = 0

def get_safe_name(module):
    global count
    name = f"{module.__class__.__name__}_{count}"
    count += 1
    return name

def get_hook(safe_name):
    def hook(m, i, o):
        if name not in observers:
            observers[safe_name] = MinMaxObserver()
        observers[safe_name](o.detach().cpu())
    return hook

count = 0
for m in model.modules():
    if isinstance(m, (nn.Conv2d, nn.Linear, nn.ReLU)):
        safe_name = get_safe_name(m)
        hooks.append(m.register_forward_hook(get_hook(safe_name)))

# Calibration (2048 images)
model.eval()
with torch.no_grad():
    for i, (x, _) in enumerate(calib_loader):
        if i > 15: break
        model(x.to(device))

for h in hooks:
    h.remove()

act_scales = {name: obs.get_scale() for name, obs in observers.items()}

weight_scales = {}
for name, p in model.named_parameters():
    if 'weight' in name and p.dim() > 1:
        obs = MinMaxObserver()
        obs(p.data.cpu())
        weight_scales[name] = obs.get_scale()

In [118]:
# Vanilla QAT model (manual forward + Insert FakeQuant)
class VanillaQATModel(nn.Module):
    def __init__(self, base_model, act_scales):
        super().__init__()
        self.base = base_model 
        self.fq_dict = nn.ModuleDict({k: VanillaFakeQuant(v) for k, v in act_scales.items()})
    
    def forward(self, x):
        x = self.base.conv1(x)
        x = self.base.bn1(x)
        x = self.base.relu(x)
        if 'conv1' in self.fq_dict: x = self.fq_dict['conv1'](x)

        x = self.base.layer1(x)
        if 'layer1' in self.fq_dict: x = self.fq_dict['layer1'](x)
        x = self.base.layer2(x)
        if 'layer2' in self.fq_dict: x = self.fq_dict['layer2'](x)
        x = self.base.layer3(x)
        if 'layer3' in self.fq_dict: x = self.fq_dict['layer3'](x)
        x = self.base.layer4(x)
        if 'layer4' in self.fq_dict: x = self.fq_dict['layer4'](x)

        x = self.base.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.base.fc(x)
        return x
    
qat_model = VanillaQATModel(model, act_scales).to(device)

# Weight fixed quant (DO STE every step)
def apply_fixed_weight_quant(model, scales):
    for name, p in model.named_parameters():
        if 'weight' in name and p.dim() > 1:
            # remove 'base.' prefix
            original_name = name.replace('base.', '')
            s = scales[original_name]
            p.data = STEQuant.apply(p.data, s)

apply_fixed_weight_quant(qat_model, weight_scales)



In [119]:
# Fine-tuning
optimizer = optim.Adam(qat_model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()
epochs = 20

for epoch in range(epochs):
    qat_model.train()
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        loss = criterion(qat_model(x), y)
        loss.backward()
        optimizer.step()
    apply_fixed_weight_quant(qat_model, weight_scales)

    print(f"Epoch {epoch+1}/{epochs} Acc: {evaluate(qat_model):.2f}%")

print(f"Final Vanilla QAT Acc: {evaluate(qat_model):.2f}%")

Epoch 1/20 Acc: 81.79%
Epoch 2/20 Acc: 99.20%
Epoch 3/20 Acc: 99.39%
Epoch 4/20 Acc: 99.23%
Epoch 5/20 Acc: 99.24%
Epoch 6/20 Acc: 99.31%
Epoch 7/20 Acc: 99.13%
Epoch 8/20 Acc: 99.33%
Epoch 9/20 Acc: 99.24%
Epoch 10/20 Acc: 99.37%
Epoch 11/20 Acc: 99.39%
Epoch 12/20 Acc: 99.37%
Epoch 13/20 Acc: 99.18%
Epoch 14/20 Acc: 99.35%
Epoch 15/20 Acc: 99.37%
Epoch 16/20 Acc: 99.35%
Epoch 17/20 Acc: 99.36%
Epoch 18/20 Acc: 99.41%
Epoch 19/20 Acc: 99.45%
Epoch 20/20 Acc: 99.31%
Final Vanilla QAT Acc: 99.31%
