In [None]:
import importlib
import utils

# Reload utils module
importlib.reload(utils)

# Pull every public name exported by utils into the notebook namespace
from utils import *

In [None]:
# MinMaxObserver (for calibration)
class MinMaxObserver:
    """Running min/max tracker used to derive a symmetric quantization scale.

    Call the instance with a tensor to fold that tensor's extrema into the
    running range, then call :meth:`get_scale` once calibration is finished.
    """

    def __init__(self):
        # Sentinels chosen so the first observed tensor always replaces them.
        self.min = float('inf')
        self.max = float('-inf')

    def __call__(self, x):
        """Update the running range with the extrema of ``x``.

        Args:
            x: tensor exposing ``numel()``, ``min()`` and ``max()``.
               Empty tensors carry no range information and are ignored
               (``min()``/``max()`` would raise on them).
        """
        if x.numel() == 0:
            return
        self.min = min(self.min, x.min().item())
        self.max = max(self.max, x.max().item())

    def get_scale(self, q_max=127):
        """Return the symmetric scale ``max(|min|, |max|) / q_max``.

        Args:
            q_max: largest positive integer level of the target grid
                   (127 for signed 8-bit, the default).

        Returns:
            The scale as a Python float.

        Raises:
            ValueError: if no data has been observed yet. Without this guard
                the sentinels would produce an infinite scale, which turns
                into NaN once values are divided by and multiplied with it.
        """
        if self.min > self.max:
            raise ValueError("MinMaxObserver.get_scale() called before any data was observed")
        r = max(abs(self.min), abs(self.max))
        return r / q_max

In [None]:
# Build STE
class STEQuant(torch.autograd.Function):
    """Fake quantization with a straight-through estimator (STE).

    Forward rounds ``x`` onto an integer grid and immediately dequantizes it;
    backward treats the whole operation as the identity with respect to ``x``.
    """

    @staticmethod
    def forward(ctx, x, scale, q_min=-128, q_max=127):
        """Quantize ``x`` with step ``scale`` and return the dequantized result.

        Args:
            ctx: autograd context (unused; nothing is saved for backward).
            x: tensor to fake-quantize.
            scale: tensor holding the quantization step; floored at 1e-8 so
                the division below never hits zero.
            q_min: lowest integer level kept after rounding.
            q_max: highest integer level kept after rounding.

        Returns:
            ``clamp(round(x / scale), q_min, q_max) * scale``.
        """
        s = scale.clamp(min=1e-8)
        q = torch.clamp(torch.round(x/s), q_min, q_max)
        return q * s # dequant return -> propagate noise to next layer

    @staticmethod
    def backward(ctx, grad_output):
        """Pass ``grad_output`` through unchanged to ``x`` (STE).

        Returns one entry per ``forward`` input (x, scale, q_min, q_max).
        Only ``x`` receives a gradient; the scale and the integer bounds get
        ``None``. Returning just two entries made backward fail whenever
        ``q_min``/``q_max`` were passed explicitly to ``apply``.
        """
        # Backward: STE - ignore quant, propagate 1
        return grad_output, None, None, None

In [None]:
# Vanilla FakeQuant (fixed scale)
class VanillaFakeQuant(nn.Module):
    """Fake-quantization layer whose scale is set once at construction.

    In training mode the input is routed through ``STEQuant``; in eval mode
    the input is returned untouched.
    """

    def __init__(self, scale):
        """Store ``scale`` as a tensor attribute.

        The scale is a plain tensor attribute, not an ``nn.Parameter`` or a
        registered buffer, so it is neither trained nor part of ``state_dict``.
        """
        super().__init__()
        fixed_scale = torch.tensor(scale) # fixed
        self.scale = fixed_scale

    def forward(self, x):
        """Return fake-quantized ``x`` while training, ``x`` itself otherwise."""
        if not self.training:
            return x
        return STEQuant.apply(x, self.scale)

In [None]:
# Build the network on `device`, then restore the weights stored in
# resnet18_float32.pth.
model = ResNet18().to(device)
# map_location remaps the checkpoint's tensors onto `device` while loading, so
# a checkpoint written on a GPU machine still loads on a CPU-only one.
model.load_state_dict(torch.load("resnet18_float32.pth", map_location=device))

In [None]:
# Calibration (activation + weight scale)
observers = {}  # module name -> MinMaxObserver fed with that module's outputs
hooks = []      # handles returned by register_forward_hook, kept for removal
def get_hook(name):
    """Build a forward hook that records the output range of module ``name``.

    The observer for ``name`` is created lazily on the first forward pass that
    reaches the module and is stored in the module-level ``observers`` dict.
    """
    def hook(m, i, o):
        if name not in observers:
            observers[name] = MinMaxObserver()
        # The observer only reads the scalar min/max, so reduce on the device
        # the activation already lives on instead of copying the whole tensor
        # to host memory first.
        observers[name](o.detach())
    return hook

# Number of calib_loader batches used for calibration
# (2048 images — assumes a batch size of 128; TODO confirm against calib_loader).
NUM_CALIB_BATCHES = 16

try:
    # Observe the output of every Conv2d / Linear / ReLU module.
    for name, m in model.named_modules():
        if isinstance(m, (nn.Conv2d, nn.Linear, nn.ReLU)):
            hooks.append(m.register_forward_hook(get_hook(name)))

    # Forward-only passes: the hooks collect the activation ranges.
    model.eval()
    with torch.no_grad():
        for i, (x, _) in enumerate(calib_loader):
            if i >= NUM_CALIB_BATCHES: break
            model(x.to(device))
finally:
    # Always detach the hooks, even when registration or a calibration batch
    # raises; otherwise they stay on the model and re-running this cell would
    # register a second set on top of them.
    for h in hooks:
        h.remove()

# One activation scale per observed module, keyed by module name.
act_scales = {name: obs.get_scale() for name, obs in observers.items()}

weight_scales = {}
for name, p in model.named_parameters():