In [2]:
import importlib
import utils

# Reload utils module
importlib.reload(utils)

# import 
from utils import *

## Load Float Model

In [3]:
# Build the float model and load the trained FP32 weights. map_location puts the
# checkpoint tensors on `device`, so a checkpoint saved from a different device
# (e.g. GPU) still loads when `device` differs.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints
# (recent PyTorch versions support weights_only=True to restrict this).
model = ResNet18().to(device)
model.load_state_dict(torch.load("resnet18_float32.pth", map_location=device))

<All keys matched successfully>

### MinMax Observer

In [4]:
class MinMaxObserver:
    """Running min/max observer that derives 8-bit quantization parameters.

    Call the instance on each tensor to widen the tracked range, then call
    ``get_scale_zp`` to obtain ``(scale, zero_point)``.
    """

    def __init__(self):
        # inf / -inf mean "nothing observed yet"
        self.min = float('inf')
        self.max = float('-inf')

    def __call__(self, x):
        """Widen the tracked range with ``x.min()`` / ``x.max()`` (x: tensor-like)."""
        self.min = min(self.min, x.min().item())
        self.max = max(self.max, x.max().item())

    def get_scale_zp(self, symmetric=True):
        """Return ``(scale, zero_point)`` for the observed range.

        Args:
            symmetric: if True, signed int8 with ``scale = max(|min|, |max|) / 127``
                and ``zero_point = 0``. If False, unsigned uint8 in [0, 255] with
                ``scale = (hi - lo) / 255``, where the range ``[lo, hi]`` is widened
                to include 0 so that 0.0 is exactly representable and the zero
                point always lies in [0, 255].

        Returns:
            ``(scale, zero_point)``. A degenerate range (max |value| == 0) yields
            ``(1.0, 0)`` instead of a zero scale.

        Raises:
            RuntimeError: if called before any tensor was observed.
        """
        if self.min > self.max:
            raise RuntimeError("MinMaxObserver.get_scale_zp called before observing any tensor")
        if symmetric:
            r = max(abs(self.min), abs(self.max))
            scale = r / 127 if r > 0 else 1.0
            return scale, 0
        # Include 0 in the range (as PyTorch's observers do). Without this, an
        # all-positive range yields a negative zero point, and a constant range
        # gives scale == 0, which makes -lo / scale raise ZeroDivisionError.
        lo = min(self.min, 0.0)
        hi = max(self.max, 0.0)
        scale = (hi - lo) / 255
        if scale == 0:
            return 1.0, 0
        zp = round(-lo / scale)
        # Clamp guards against float rounding pushing zp just outside [0, 255].
        return scale, max(0, min(255, zp))

### Activation calibration with forward hooks

In [10]:
# Number of calibration batches. The original note says 16 batches ≈ 2048 images,
# i.e. presumably batch_size=128 in calib_loader — TODO confirm.
NUM_CALIB_BATCHES = 16

observers = {}


def get_hook(name):
    """Return a forward hook that records the running min/max of a module's output.

    One MinMaxObserver per module name is created lazily in ``observers``.
    NOTE(review): if the same module instance is called several times per forward
    (e.g. a shared nn.ReLU), all calls feed the same observer.
    """
    def hook(module, input, output):
        if name not in observers:
            observers[name] = MinMaxObserver()
        # The observer only reads .min()/.max(), so no full-tensor .cpu() copy is needed.
        observers[name](output.detach())
    return hook


# Register calibration hooks after Conv, Linear, ReLU, BatchNorm and keep the
# handles so the hooks can be removed through the public API afterwards.
calib_hook_handles = []
for name, m in model.named_modules():
    if isinstance(m, (nn.Conv2d, nn.Linear, nn.ReLU, nn.BatchNorm2d)):
        calib_hook_handles.append(m.register_forward_hook(get_hook(name)))

# Calibration: run NUM_CALIB_BATCHES batches through the float model, then detach
# the observer hooks (also on error) so re-running this cell does not stack hooks.
model.eval()
try:
    with torch.no_grad():
        for i, (x, _) in enumerate(calib_loader):
            if i >= NUM_CALIB_BATCHES:
                break
            model(x.to(device))
finally:
    for handle in calib_hook_handles:
        handle.remove()


### Activation quantization

In [14]:
# Turn each calibrated observer into per-layer (scale, zero_point) pairs.
# Activations use asymmetric uint8 (0~255) quantization; the zero point is
# stored as a plain Python number even if an observer returned a tensor.
activation_scales = {}
for layer_name, observer in observers.items():
    act_scale, act_zp = observer.get_scale_zp(symmetric=False)
    if torch.is_tensor(act_zp):
        act_zp = act_zp.item()
    activation_scales[layer_name] = (act_scale, act_zp)

def fake_quant_act(x, scale, zp, qmin=0, qmax=255):
    """Fake-quantize ``x``: round to integers clamped to [qmin, qmax], then dequantize.

    The defaults give asymmetric uint8 quantization (the original behavior).

    Args:
        x: tensor to fake-quantize.
        scale: quantization step size (expected > 0).
        zp: zero point added before rounding.
        qmin: smallest integer code (inclusive).
        qmax: largest integer code (inclusive).

    Returns:
        Tensor of the same shape as ``x`` holding the dequantized values.
    """
    q_x = torch.clamp(torch.round(x / scale + zp), qmin, qmax)
    return (q_x - zp) * scale

# Remove the calibration hooks before installing the fake-quant hooks.
# NOTE(review): `_forward_hooks` is private PyTorch state, and clearing it drops
# *every* forward hook on these layer types, not only the calibration ones.
# Keeping the handles from register_forward_hook and calling .remove() is the
# public alternative.
for module in model.modules():
    if not isinstance(module, (nn.Conv2d, nn.Linear, nn.ReLU, nn.BatchNorm2d)):
        continue
    module._forward_hooks.clear()

# Quantization hooks for activations: each hook replaces a layer's output with
# its fake-quantized version, using that layer's calibrated (scale, zp).
def get_quant_hook(name):
    """Build a forward hook that fake-quantizes the output of layer ``name``.

    Layers with no entry in ``activation_scales`` pass through unchanged
    (the hook returns the original output).
    """
    def hook(module, input, output):
        qparams = activation_scales.get(name)
        if qparams is None:
            return output
        act_scale, act_zp = qparams
        return fake_quant_act(output, act_scale, act_zp)
    return hook

# Attach the fake-quant hooks. The RemovableHandles are kept so the hooks can be
# removed later with the public `handle.remove()` API. Handles left over from a
# previous run of this cell are removed first, so re-running it does not stack
# duplicate hooks (removing an already-cleared handle is a no-op).
for _old_handle in globals().get("quant_hook_handles", ()):
    _old_handle.remove()
quant_hook_handles = []
for name, m in model.named_modules():
    if isinstance(m, (nn.Conv2d, nn.Linear, nn.ReLU, nn.BatchNorm2d)):
        quant_hook_handles.append(m.register_forward_hook(get_quant_hook(name)))


### Weight quantization (symmetric)

In [15]:
# Per-tensor symmetric int8 fake quantization of every multi-dimensional weight.
# `p.dim() > 1` skips 1-D parameters such as biases and norm-layer affine weights.
# The model's float weights are overwritten in place with their dequantized
# values, and the per-tensor scale is kept in `weight_scales`.
weight_scales = {}
with torch.no_grad():
    for name, p in model.named_parameters():
        if 'weight' in name and p.dim() > 1:
            obs = MinMaxObserver()
            obs(p)  # min/max computed on p's own device; no .cpu() copy needed
            scale, _ = obs.get_scale_zp(symmetric=True)
            weight_scales[name] = scale
            q_w = torch.round(p / scale).clamp(-128, 127)
            # copy_ under no_grad is the supported way to overwrite a parameter
            # (assigning to .data is discouraged).
            p.copy_(q_w * scale)  # dequantized

### Print Results

In [16]:
# Summarize the fake-quantized model. Each metric is printed right after it is
# computed. NOTE(review): the weights are still float32 tensors after fake
# quantization, and the latency includes the fake-quant hook overhead.
ptq_accuracy = evaluate(model)
print(f"Vanilla PTQ Accuracy: {ptq_accuracy:.2f}%")
float_size_mb = get_float_model_size(model)
print(f"Float Model Size: {float_size_mb:.2f} MB")
int8_size_mb = get_quantized_model_size(model, weight_scales)
print(f"Quantized Model Size (INT8) {int8_size_mb:.2f} MB")
latency_ms = measure_inference_time(model)
print(f"Inference Time: {latency_ms:.3f} ms/image")

Vanilla PTQ Accuracy: 99.26%
Float Model Size: 44.76 MB
Quantized Model Size (INT8) 11.23 MB
Inference Time: 0.045 ms/image
