In [146]:
import os, platform, sys, json, io, time, math, subprocess, threading
from pathlib import Path

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms as T
from tinyimagenet import TinyImageNet
import pandas as pd
import copy

# Seed the default RNG (and every CUDA generator when a GPU is present), then
# pick the device string used by the rest of the notebook.
torch.manual_seed(0)
device = "cpu"
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(0)
    device = "cuda"

def gpu_specs():
    """Describe CUDA device 0.

    Returns:
        ``None`` when CUDA is unavailable; otherwise a dict with the device
        ``name``, its ``compute_capability`` and ``vram_gb`` (total memory
        divided by 1024**3, rounded to two decimals).
    """
    if not torch.cuda.is_available():
        return None

    total_bytes = torch.cuda.get_device_properties(0).total_memory
    return {
        "name": torch.cuda.get_device_name(0),
        "compute_capability": torch.cuda.get_device_capability(0),
        "vram_gb": round(total_bytes / (1024**3), 2),
    }

# Snapshot of the software/hardware environment this run executed on; printed
# as JSON so it is stored alongside the results in the notebook output.
specs = {
    "python": sys.version.split()[0],  # first whitespace-separated token of sys.version
    "pytorch": torch.__version__,
    "device": device,
    "cpu": platform.processor(),
    "machine": platform.machine(),
    "platform": platform.platform(),
    "gpu": gpu_specs(),  # None when CUDA is unavailable
}
print(json.dumps(specs, indent=1))

{
 "python": "3.12.7",
 "pytorch": "2.6.0.dev20241112+cu121",
 "device": "cuda",
 "cpu": "Intel64 Family 6 Model 154 Stepping 3, GenuineIntel",
 "machine": "AMD64",
 "platform": "Windows-11-10.0.26100-SP0",
 "gpu": {
  "name": "NVIDIA GeForce RTX 4050 Laptop GPU",
  "compute_capability": [
   8,
   9
  ],
  "vram_gb": 6.0
 }
}


# Data

In [147]:
# --- Data: TinyImageNet loaders ---
# Training pipeline: resize to 64x64, random horizontal flip, convert to tensor.
train_tfms = T.Compose([
    T.Resize((64, 64)),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
])
# Validation pipeline: same resize + tensor conversion, without the random flip.
val_tfms = T.Compose([
    T.Resize((64, 64)),
    T.ToTensor(),
])

# Both splits are read from the same root directory under the user's home.
train_ds = TinyImageNet(Path("~/.torchvision/tinyimagenet/").expanduser(), split="train")
val_ds   = TinyImageNet(Path("~/.torchvision/tinyimagenet/").expanduser(), split="val")
# Transforms are attached after construction by assigning the `transform` attribute.
train_ds.transform = train_tfms
val_ds.transform   = val_tfms

# Loading happens in the main process (num_workers=0); host memory is pinned
# only when the target device is CUDA. Only the training loader shuffles.
train_loader = DataLoader(train_ds, batch_size=128, shuffle=True,
                          num_workers=0, pin_memory=(device=="cuda"))
val_loader   = DataLoader(val_ds,   batch_size=256, shuffle=False,
                          num_workers=0, pin_memory=(device=="cuda"))
print(f"TinyImageNet train={len(train_ds)}, val={len(val_ds)}")

TinyImageNet train=100000, val=10000


# Model

In [148]:
class ToyLinearModel(torch.nn.Module):
    """Two stacked bias-free linear layers (``m -> k -> n``) with no activation between them.

    Args:
        m: input feature count (defaults to a flattened 3x64x64 image).
        n: output feature count.
        k: width of the intermediate layer.
    """

    def __init__(self, m=3*64*64, n=200, k=64):
        super().__init__()
        self.linear1 = torch.nn.Linear(m, k, bias=False)
        self.linear2 = torch.nn.Linear(k, n, bias=False)

    def forward(self, x):
        """Apply both layers; 4-D inputs are first flattened to ``(B, -1)``."""
        flat = x.flatten(1) if x.dim() == 4 else x
        return self.linear2(self.linear1(flat))

# Observers

In [149]:
# --- TorchAO observers & helpers ---
from torchao.quantization.granularity import PerAxis, PerTensor
from torchao.quantization.observer import AffineQuantizedMinMaxObserver
from torchao.quantization.quant_primitives import MappingType
from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter

# per tensor input activation asymmetric quantization
# Template observer for layer inputs: asymmetric mapping to uint8 with a single
# (per-tensor) set of quantization parameters. insert_observers_ deep-copies
# this object for every layer it wraps.
act_obs = AffineQuantizedMinMaxObserver(
    MappingType.ASYMMETRIC,
    torch.uint8,
    granularity=PerTensor(),
    eps=torch.finfo(torch.float32).eps,  # float32 machine epsilon
    scale_dtype=torch.float32,
    zero_point_dtype=torch.float32,
)

# per channel weight asymmetric quantization
# Template observer for weights: same mapping and dtypes as above, but with
# PerAxis(axis=0) granularity instead of per-tensor.
weight_obs = AffineQuantizedMinMaxObserver(
    MappingType.ASYMMETRIC,
    torch.uint8,
    granularity=PerAxis(axis=0),
    eps=torch.finfo(torch.float32).eps,  # float32 machine epsilon
    scale_dtype=torch.float32,
    zero_point_dtype=torch.float32,
)

In [150]:
class ObservedLinear(torch.nn.Linear):
    """``nn.Linear`` that routes its input and its weight through observer modules.

    The observers are called on every forward pass and whatever they return is
    fed to ``F.linear``.

    Args:
        in_features: input feature count.
        out_features: output feature count.
        act_obs: module applied to the layer input.
        weight_obs: module applied to the layer weight.
        bias, device, dtype: forwarded to ``nn.Linear``.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        act_obs: torch.nn.Module,
        weight_obs: torch.nn.Module,
        bias: bool = True,
        device=None,
        dtype=None,
    ):
        super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype)
        self.act_obs = act_obs
        self.weight_obs = weight_obs

    def forward(self, input: torch.Tensor):
        """Observe the input, then the weight, and return the linear output."""
        seen_input = self.act_obs(input)
        seen_weight = self.weight_obs(self.weight)
        return F.linear(seen_input, seen_weight, self.bias)

    @classmethod
    def from_float(cls, float_linear, act_obs, weight_obs):
        """Wrap ``float_linear``; the new layer shares its weight and bias objects."""
        wrapped = cls(
            float_linear.in_features,
            float_linear.out_features,
            act_obs,
            weight_obs,
            bias=False,
            device=float_linear.weight.device,
            dtype=float_linear.weight.dtype,
        )
        # Re-point the parameters at the originals rather than copying them.
        wrapped.weight = float_linear.weight
        wrapped.bias = float_linear.bias
        return wrapped

In [151]:
def insert_observers_(model, act_obs, weight_obs):
    """Swap the ``nn.Linear`` modules of ``model`` for ``ObservedLinear`` wrappers.

    Every replaced layer receives its own deep copies of ``act_obs`` and
    ``weight_obs``, so the objects passed in are never attached to the model.
    The traversal is delegated to torchao's
    ``_replace_with_custom_fn_if_matches_filter``; nothing is returned
    (presumably the model is modified in place, as the trailing underscore suggests).
    """

    def _matches_linear(module, fqn):
        return isinstance(module, torch.nn.Linear)

    def _observe(module):
        return ObservedLinear.from_float(
            module,
            copy.deepcopy(act_obs),
            copy.deepcopy(weight_obs),
        )

    _replace_with_custom_fn_if_matches_filter(model, _observe, _matches_linear)

# Quantization

In [152]:
# --- QuantizedLinear backend (statically quantized weights and activations) ---
from torchao.dtypes import to_affine_quantized_intx_static
class QuantizedLinear(torch.nn.Module):
    """Linear layer whose weight and input are quantized with fixed (static) parameters.

    Quantization parameters are read once, at construction time, from the two
    observers. The weight is quantized immediately; the input is quantized on
    every forward call with the stored activation scale / zero point.

    Args:
        in_features, out_features: accepted for signature parity with the
            other linear classes; not used by this constructor.
        act_obs: observer providing the activation scale / zero point.
        weight_obs: observer providing the weight scale / zero point.
        weight: 2-D float weight to quantize.
        bias: bias passed unchanged to ``F.linear`` (may be ``None``).
        target_dtype: dtype handed to ``to_affine_quantized_intx_static``.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        act_obs: torch.nn.Module,
        weight_obs: torch.nn.Module,
        weight: torch.Tensor,
        bias: torch.Tensor,
        target_dtype: torch.dtype,
    ):
        super().__init__()
        act_qparams = act_obs.calculate_qparams()
        self.act_scale, self.act_zero_point = act_qparams
        w_scale, w_zero_point = weight_obs.calculate_qparams()
        assert weight.dim() == 2
        # One quantization block per weight row.
        row_block = (1, weight.shape[1])
        self.target_dtype = target_dtype
        self.bias = bias
        self.qweight = to_affine_quantized_intx_static(
            weight, w_scale, w_zero_point, row_block, self.target_dtype
        )

    def forward(self, input: torch.Tensor):
        """Quantize ``input`` as a single block (its full shape) and apply the linear op."""
        quantized_in = to_affine_quantized_intx_static(
            input,
            self.act_scale,
            self.act_zero_point,
            input.shape,
            self.target_dtype,
        )
        return F.linear(quantized_in, self.qweight, self.bias)

    @classmethod
    def from_observed(cls, observed_linear, target_dtype):
        """Build a quantized layer from an observed one, reusing its observers, weight and bias."""
        return cls(
            observed_linear.in_features,
            observed_linear.out_features,
            observed_linear.act_obs,
            observed_linear.weight_obs,
            observed_linear.weight,
            observed_linear.bias,
            target_dtype,
        )

In [153]:
from dataclasses import dataclass
from torchao.core.config import AOBaseConfig
from torchao.quantization import quantize_
from torchao.quantization.transform_module import register_quantize_module_handler

@dataclass
class StaticQuantConfig(AOBaseConfig):
    """Configuration object passed to ``quantize_`` for static quantization.

    Attributes:
        target_dtype: dtype forwarded to ``QuantizedLinear.from_observed``.
    """

    target_dtype: torch.dtype


@register_quantize_module_handler(StaticQuantConfig)
def _apply_static_quant(module: torch.nn.Module, config: StaticQuantConfig):
    """Handler registered for ``StaticQuantConfig``: return the ``QuantizedLinear`` built from ``module``."""
    target = config.target_dtype
    return QuantizedLinear.from_observed(module, target)


def is_observed_linear(m, fqn):
    """Module filter: true only for ``ObservedLinear`` instances (``fqn`` is ignored)."""
    return isinstance(m, ObservedLinear)

# FP32

In [154]:
# --- FP32 training baseline ---
if device == "cuda":
    torch.cuda.empty_cache()


def train_float_baseline(epochs=10, lr=2e-3):
    """Train a fresh ``ToyLinearModel`` in float32 on the global ``train_loader``.

    Uses AdamW and cross-entropy loss, prints the mean batch loss after every
    epoch, and returns the model switched to eval mode.

    Args:
        epochs: number of passes over ``train_loader``.
        lr: AdamW learning rate.
    """
    model = ToyLinearModel().to(torch.float32).to(device).train()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    criterion = torch.nn.CrossEntropyLoss()
    async_copy = device == "cuda"

    for ep in range(epochs):
        running = 0.0
        for inputs, targets in train_loader:
            inputs = inputs.to(device=device, dtype=torch.float32, non_blocking=async_copy)
            targets = targets.to(device)
            optimizer.zero_grad(set_to_none=True)
            batch_loss = criterion(model(inputs), targets)
            batch_loss.backward()
            optimizer.step()
            running += batch_loss.item()
        print(f"[FP32] epoch {ep+1}/{epochs} loss={running/len(train_loader):.4f}")

    model.eval()
    return model


m_fp_trained = train_float_baseline(epochs=10, lr=2e-3)

[FP32] epoch 1/10 loss=5.4695
[FP32] epoch 2/10 loss=4.9832
[FP32] epoch 3/10 loss=4.9515
[FP32] epoch 4/10 loss=4.9722
[FP32] epoch 5/10 loss=4.9785
[FP32] epoch 6/10 loss=4.9695
[FP32] epoch 7/10 loss=4.9703
[FP32] epoch 8/10 loss=4.9595
[FP32] epoch 9/10 loss=4.9587
[FP32] epoch 10/10 loss=4.9436


# PTQ

In [155]:
# --- PTQ: observe, calibrate, quantize ---
# Work on a copy so the trained float model stays untouched for the baseline.
m_obs = copy.deepcopy(m_fp_trained).to(torch.float32).to(device).eval()
insert_observers_(m_obs, act_obs, weight_obs)

# Calibration: run training batches through the observed model so the
# observers see real inputs. Labels are ignored and no gradients are recorded.
with torch.inference_mode():
    for i, (images, _) in enumerate(train_loader):
        images = images.to(device=device, dtype=torch.float32, non_blocking=(device=="cuda"))
        _ = m_obs(images)
        if i >= 50:  # decent min/max calibration (stops after batch index 50, i.e. 51 batches)
            break

# Quantize a copy, keeping the calibrated observed model available as m_obs.
m_ptq = copy.deepcopy(m_obs)
# Only ObservedLinear modules are converted (filter: is_observed_linear).
quantize_(m_ptq, StaticQuantConfig(torch.uint8), is_observed_linear)
m_ptq.eval().to(device)
print("PTQ ready.")

PTQ ready.


# QAT

In [156]:
# --- QAT: fake-quant with STE ---
def _aq_clamp_round(x, scale, zp, qmin=0, qmax=255):
    q = torch.round(x / scale) + zp
    q = torch.clamp(q, qmin, qmax)
    return (q - zp) * scale

class FakeQLinear(torch.nn.Linear):
    """``nn.Linear`` whose forward pass fake-quantizes both the input and the weight.

    On every call the observers are updated with the current input / weight,
    fresh quantization parameters are read from them, and the fake-quantized
    tensors are fed to ``F.linear``.
    """

    def __init__(self, in_features, out_features, act_obs, weight_obs, bias=True, device=None, dtype=None):
        # nn.Linear.__init__ already records in_features / out_features.
        super().__init__(in_features, out_features, bias, device, dtype)
        self.act_obs = act_obs
        self.weight_obs = weight_obs

    @classmethod
    def from_observed(cls, observed_linear: torch.nn.Linear):
        """Create a fake-quant layer from an observed one.

        The observer objects are reused as-is; weight and bias values are
        copied into newly created parameters.
        """
        src = observed_linear
        has_bias = src.bias is not None
        layer = cls(
            src.in_features,
            src.out_features,
            src.act_obs,
            src.weight_obs,
            bias=has_bias,
            device=src.weight.device,
            dtype=src.weight.dtype,
        )
        with torch.no_grad():
            layer.weight.copy_(src.weight)
            if has_bias:
                layer.bias.copy_(src.bias)
        return layer

    def forward(self, x: torch.Tensor):
        """Observe, fake-quantize input and weight, then apply the linear op."""
        observed_x = self.act_obs(x)
        observed_w = self.weight_obs(self.weight)
        a_scale, a_zp = self.act_obs.calculate_qparams()
        wt_scale, wt_zp = self.weight_obs.calculate_qparams()
        # Activation parameters are applied as returned (per-tensor).
        fq_x = _aq_clamp_round(observed_x, a_scale, a_zp)
        # Weight parameters are reshaped to a column so they apply per output row (axis=0).
        fq_w = _aq_clamp_round(observed_w, wt_scale.view(-1, 1), wt_zp.view(-1, 1))
        return F.linear(fq_x, fq_w, self.bias)

def replace_observed_with_fakeq_(model):
    """Replace every ``ObservedLinear`` in ``model`` with ``FakeQLinear.from_observed`` of it.

    The traversal is delegated to torchao's
    ``_replace_with_custom_fn_if_matches_filter``; nothing is returned.
    """

    def _matches_observed(module, fqn):
        return isinstance(module, ObservedLinear)

    _replace_with_custom_fn_if_matches_filter(
        model, FakeQLinear.from_observed, _matches_observed
    )

def replace_fakeq_with_quantized_(model, target_dtype=torch.uint8):
    """Replace every ``FakeQLinear`` in ``model`` with a ``QuantizedLinear``.

    Each new layer is built from the fake-quant layer's own observers, weight
    and bias. The traversal is delegated to torchao's
    ``_replace_with_custom_fn_if_matches_filter``; nothing is returned.

    Args:
        model: module tree to convert.
        target_dtype: dtype passed to ``QuantizedLinear``.
    """

    def _matches_fakeq(module, fqn):
        return isinstance(module, FakeQLinear)

    def _quantize(layer):
        return QuantizedLinear(
            layer.in_features,
            layer.out_features,
            layer.act_obs,
            layer.weight_obs,
            layer.weight,
            layer.bias,
            target_dtype,
        )

    _replace_with_custom_fn_if_matches_filter(model, _quantize, _matches_fakeq)

In [157]:
# QAT finetune from trained float
# Start from a copy of the trained FP32 model, wrap its linears with observers,
# then convert those wrappers to fake-quant layers for training.
m_qat_train = copy.deepcopy(m_fp_trained).to(torch.float32).to(device).train()
insert_observers_(m_qat_train, act_obs, weight_obs)
replace_observed_with_fakeq_(m_qat_train)

# The optimizer is created after the replacement so it holds the FakeQLinear
# parameters (FakeQLinear.from_observed creates new parameter tensors).
opt = torch.optim.AdamW(m_qat_train.parameters(), lr=1e-3)
loss_fn = torch.nn.CrossEntropyLoss()
epochs_qat = 10

# NOTE(review): same loop body as train_float_baseline, applied to the fake-quant model.
for ep in range(epochs_qat):
    running=0.0
    for xb, yb in train_loader:
        xb = xb.to(device=device, dtype=torch.float32, non_blocking=(device=="cuda"))
        yb = yb.to(device)
        opt.zero_grad(set_to_none=True)
        logits = m_qat_train(xb)   # fake-quant forward
        loss = loss_fn(logits, yb)
        loss.backward()
        opt.step()
        running += loss.item()
    # Mean batch loss over the epoch.
    print(f"[QAT] epoch {ep+1}/{epochs_qat} loss={running/len(train_loader):.4f}")

# Convert a copy of the finetuned model: each FakeQLinear becomes a
# QuantizedLinear built from that layer's observers and weights.
m_qat_train.eval()
m_qat = copy.deepcopy(m_qat_train).eval()
replace_fakeq_with_quantized_(m_qat, target_dtype=torch.uint8)
m_qat = m_qat.to(device).eval()
print("QAT ready.")

[QAT] epoch 1/10 loss=4.8713
[QAT] epoch 2/10 loss=4.8646
[QAT] epoch 3/10 loss=4.8613
[QAT] epoch 4/10 loss=4.8567
[QAT] epoch 5/10 loss=4.8516
[QAT] epoch 6/10 loss=4.8472
[QAT] epoch 7/10 loss=4.8443
[QAT] epoch 8/10 loss=4.8393
[QAT] epoch 9/10 loss=4.8373
[QAT] epoch 10/10 loss=4.8360
QAT ready.


# Metrics

In [158]:
# --- Metrics: size, accuracy, latency/power/energy ---
def linear_fp32_size_bytes(mod: torch.nn.Linear):
    """Bytes needed for ``mod``'s weight (and bias, if present) at 4 bytes per element."""
    n_elements = mod.weight.numel()
    if mod.bias is not None:
        n_elements += mod.bias.numel()
    return n_elements * 4

def linear_int8_size_bytes(out_features, in_features, has_bias=True, per_channel=True):
    """Estimate the storage of an 8-bit quantized linear layer, in bytes.

    Counts 1 byte per weight, 4 bytes each for scale and zero point (one pair
    per output channel when ``per_channel`` is true, otherwise a single pair),
    and 4 bytes per bias element when ``has_bias`` is true.
    """
    qparam_sets = out_features if per_channel else 1
    total = out_features * in_features  # int8 weights, 1 byte each
    total += qparam_sets * 4            # scales
    total += qparam_sets * 4            # zero points
    if has_bias:
        total += out_features * 4
    return total

def model_size_mb_fp32(model: torch.nn.Module):
    """Size in MiB of all ``nn.Linear`` weights and biases in ``model`` at 4 bytes per element."""

    def _linear_bytes(layer):
        bias_elems = layer.bias.numel() if layer.bias is not None else 0
        return layer.weight.numel() * 4 + bias_elems * 4

    total_bytes = sum(
        _linear_bytes(layer)
        for layer in model.modules()
        if isinstance(layer, torch.nn.Linear)
    )
    return total_bytes / (1024**2)

def model_size_mb_int8_like(model: torch.nn.Module, per_channel=True):
    """Estimate the size in MiB of ``model``'s linear-like layers if stored as 8-bit.

    Layer dimensions come from ``out_features`` / ``in_features`` when both
    exist, otherwise from the shape of ``weight`` or, failing that,
    ``qweight``. Layers exposing none of these are skipped. The per-layer
    estimate is computed by ``linear_int8_size_bytes``.
    """
    linear_like = (torch.nn.Linear, ObservedLinear, FakeQLinear, QuantizedLinear)
    total_bytes = 0
    for layer in model.modules():
        if not isinstance(layer, linear_like):
            continue
        n_out = getattr(layer, "out_features", None)
        n_in = getattr(layer, "in_features", None)
        if n_out is None or n_in is None:
            # No feature attributes: read the dimensions off the stored tensor.
            if hasattr(layer, "weight"):
                n_out, n_in = layer.weight.shape
            elif hasattr(layer, "qweight"):
                n_out, n_in = layer.qweight.shape
            else:
                continue
        bias_present = getattr(layer, "bias", None) is not None
        total_bytes += linear_int8_size_bytes(n_out, n_in, bias_present, per_channel=per_channel)
    return total_bytes / (1024**2)

def top1_accuracy(model, loader, device="cuda", dtype=torch.float32):
    """Fraction of samples in ``loader`` whose arg-max logit equals the label.

    The model is put in eval mode and run under ``torch.inference_mode``.
    Returns ``0.0`` when the loader yields no samples.

    Args:
        model: callable module returning logits whose dim 1 indexes classes.
        loader: iterable of ``(inputs, labels)`` batches.
        device: device both tensors are moved to.
        dtype: dtype the inputs are cast to.
    """
    model.eval()
    async_copy = device == "cuda"
    n_correct = 0
    n_seen = 0
    with torch.inference_mode():
        for inputs, labels in loader:
            inputs = inputs.to(device=device, dtype=dtype, non_blocking=async_copy)
            labels = labels.to(device)
            predictions = model(inputs).argmax(dim=1)
            n_correct += (predictions == labels).sum().item()
            n_seen += labels.numel()
    # max(1, ...) avoids dividing by zero on an empty loader.
    return n_correct / max(1, n_seen)

def latency_power_energy(model, batch_tensor, iters=200, warm=20, device="cuda"):
    """Time repeated forward passes on one batch while sampling GPU power.

    Args:
        model: module to benchmark; switched to eval mode.
        batch_tensor: input passed to ``model`` on every iteration.
        iters: number of timed forward passes (must be >= 1).
        warm: number of untimed warm-up passes.
        device: ``"cuda"`` enables CUDA synchronisation and power sampling;
            any other value skips both.

    Returns:
        dict with
          * ``lat_ms``   - mean wall-clock milliseconds per forward pass,
          * ``avg_W``    - mean of the finite power samples (NaN if fewer than
            two samples were collected),
          * ``energy_J`` - trapezoidal integral of power over the whole
            sampling window, i.e. for all ``iters`` passes together; the
            sampler starts slightly before and stops slightly after the
            timed region.

    Raises:
        ValueError: if ``iters`` is smaller than 1.
    """
    # Fail before doing any work rather than with ZeroDivisionError at the end.
    if iters < 1:
        raise ValueError(f"iters must be >= 1, got {iters}")

    model.eval()
    # warmup
    with torch.inference_mode():
        for _ in range(warm):
            _ = model(batch_tensor)
    if device == "cuda":
        torch.cuda.synchronize()

    class PowerSampler:
        """Background thread that records ``(timestamp, watts)`` samples."""

        def __init__(self, interval_s=0.02, device_index=0):
            self.interval_s = interval_s
            self.samples = []
            self.device_index = device_index
            self._stop = threading.Event()
            self._thr = None
            try:
                import pynvml  # noqa: F401  (availability probe only)
                self.use_nvml = True
            except Exception:
                self.use_nvml = False

        def _loop_nvml(self):
            import pynvml
            try:
                pynvml.nvmlInit()
            except Exception:
                # pynvml imports but NVML is unusable: use nvidia-smi instead
                # of letting the sampler thread die without any samples.
                self._loop_smi()
                return
            try:
                h = pynvml.nvmlDeviceGetHandleByIndex(self.device_index)
                while not self._stop.is_set():
                    try:
                        watts = pynvml.nvmlDeviceGetPowerUsage(h) / 1000.0  # mW -> W
                    except Exception:
                        watts = math.nan
                    self.samples.append((time.perf_counter(), watts))
                    # Event.wait returns early once stop is requested.
                    self._stop.wait(self.interval_s)
            finally:
                # Always release NVML, even if the loop raised.
                pynvml.nvmlShutdown()

        def _loop_smi(self):
            # Query the same GPU index the NVML path would use.
            cmd = [
                "nvidia-smi", "--query-gpu=power.draw",
                "--format=csv,noheader,nounits", "-i", str(self.device_index),
            ]
            while not self._stop.is_set():
                try:
                    out = subprocess.check_output(cmd, stderr=subprocess.DEVNULL)
                    watts = float(out.decode().strip().splitlines()[0])
                except Exception:
                    # Best effort: an unreadable sample is recorded as NaN.
                    watts = math.nan
                self.samples.append((time.perf_counter(), watts))
                self._stop.wait(self.interval_s)

        def __enter__(self):
            if device == "cuda" and torch.cuda.is_available():
                target = self._loop_nvml if self.use_nvml else self._loop_smi
                self._thr = threading.Thread(target=target, daemon=True)
                self._thr.start()
            return self

        def __exit__(self, *args):
            if self._thr:
                self._stop.set()
                self._thr.join()

        def summary(self):
            if len(self.samples) < 2:
                return {"avg_W": float("nan"), "energy_J": float("nan")}
            e = 0.0
            finite = [p for _, p in self.samples if not math.isnan(p)]
            # Trapezoidal rule over consecutive samples; intervals touching a NaN are skipped.
            for (t0, p0), (t1, p1) in zip(self.samples[:-1], self.samples[1:]):
                if not (math.isnan(p0) or math.isnan(p1)):
                    e += 0.5 * (p0 + p1) * (t1 - t0)
            avg_W = sum(finite) / len(finite) if finite else float("nan")
            return {"avg_W": avg_W, "energy_J": e}

    with PowerSampler(interval_s=0.02) as ps:
        t0 = time.perf_counter()
        with torch.inference_mode():
            for _ in range(iters):
                _ = model(batch_tensor)
        # Wait for queued GPU work so the elapsed time covers execution, not just launch.
        if device == "cuda":
            torch.cuda.synchronize()
        t1 = time.perf_counter()
    lp = ps.summary()
    return {"lat_ms": (t1 - t0) / iters * 1000.0, "avg_W": lp["avg_W"], "energy_J": lp["energy_J"]}


# Evaluation

In [159]:
# --- Evaluate FP32 vs PTQ vs QAT ---
# fixed batch for latency/power
# (the first validation batch, moved to the device once and reused for every model)
xb_eval, _ = next(iter(val_loader))
xb_eval = xb_eval.to(device=device, dtype=torch.float32, non_blocking=(device=="cuda"))

# Top-1 accuracy of each model over the full validation loader.
acc_fp  = top1_accuracy(m_fp_trained, val_loader, device=device, dtype=torch.float32)
acc_ptq = top1_accuracy(m_ptq,        val_loader, device=device, dtype=torch.float32)
acc_qat = top1_accuracy(m_qat,        val_loader, device=device, dtype=torch.float32)

# Latency / power / energy: 20 warm-up passes, then 200 timed passes per model.
# NOTE(review): the models are measured back to back in a fixed order, and each
# timed window is short compared with the 20 ms power sampling interval, so the
# power and energy columns rest on few samples -- confirm before comparing them.
lpe_fp  = latency_power_energy(m_fp_trained, xb_eval, iters=200, warm=20, device=device)
lpe_ptq = latency_power_energy(m_ptq,        xb_eval, iters=200, warm=20, device=device)
lpe_qat = latency_power_energy(m_qat,        xb_eval, iters=200, warm=20, device=device)

# Size: the float model is counted at 4 bytes per parameter; the quantized
# models use the analytical 8-bit estimate with per-channel scale/zero point.
size_fp  = model_size_mb_fp32(m_fp_trained)
size_ptq = model_size_mb_int8_like(m_ptq, per_channel=True)
size_qat = model_size_mb_int8_like(m_qat, per_channel=True)

# One row per model; energy_J is the total over the timed run, not per inference.
df = pd.DataFrame([
    {"model":"FP32 baseline",
     "lat_ms": lpe_fp["lat_ms"],  "avg_W": lpe_fp["avg_W"],  "energy_J": lpe_fp["energy_J"],
     "top1": acc_fp, "size_MB": size_fp},
    {"model":"PTQ INT8 (TorchAO)",
     "lat_ms": lpe_ptq["lat_ms"], "avg_W": lpe_ptq["avg_W"], "energy_J": lpe_ptq["energy_J"],
     "top1": acc_ptq, "size_MB": size_ptq},
    {"model":"QAT INT8 (TorchAO)",
     "lat_ms": lpe_qat["lat_ms"], "avg_W": lpe_qat["avg_W"], "energy_J": lpe_qat["energy_J"],
     "top1": acc_qat, "size_MB": size_qat},
])
print(df)

                model    lat_ms  avg_W  energy_J    top1   size_MB
0       FP32 baseline  0.744754   4.35  0.573032  0.0444  3.048828
1  PTQ INT8 (TorchAO)  1.461631   5.25  1.227457  0.0443  0.764221
2  QAT INT8 (TorchAO)  1.153280  25.39  4.411358  0.0447  0.764221
