# Robust and Efficient DeepFake Detection via Quantization and Adversarial Training 
## By Rahul Champaneria

**Dataset:** _OpenForensics_

**Model Type**: _RECCE_

**Quantization Types:**
* Dynamic
    * _Baseline INT8-Dynamic_ ```recce_standard_int8_dynamic.pt```
    * _PGD-AT INT8-Dynamic_ ```recce_int8dyn_pgd_at.pt```
    * _TRADES INT8-Dynamic_ ```recce_int8dyn_trades.pt```
* Static
    * _Baseline INT8-Static-FX_ ```recce_standard_int8_static_fx.pt```
    * _Standard INT8-Static-FX_ ```recce_int8static_fx_standard.pt```
* Quantized Aware Training (QAT)
    * _Baseline INT8-QAT-FX_ ```recce_standard_int8_qat_fx.pt```
    * _Standard INT8-QAT-FX_ ```recce_int8qat_fx_standard.pt```

In [2]:
!pip install torch torchvision torchaudio --upgrade

!pip install timm

!pip install torchmetrics

!pip install numpy pandas scikit-learn

!pip install matplotlib seaborn tqdm

!pip install tensorboard

!pip install pyyaml rich

!pip install --upgrade jupyter ipywidgets tqdm

!conda install -c conda-forge ipywidgets



'conda' is not recognized as an internal or external command,
operable program or batch file.


## Part 1 - Environment Setup & Dataset Overview
Loading of Dataset location, Pytorch, TorchVision, TorchMetrics, and other necessary libraries for:
* Model Construction
* Dataset Preprocessing
* Training Loops
* Adversarial Attacks
* Quantization

```text
Dataset/
├── Train/
│   ├── real/
│   └── fake/
├── Validation/
│   ├── real/
│   └── fake/
└── Test/
    ├── real/
    └── fake/
```


In [4]:
import os

DATA_ROOT = r"C:\Users\admin\Github\DFDetectAdversAttack\Dataset"

SEED = 42
IMSIZE = 224
BATCH = 32
NUM_WORKERS = 4

PGD_EPS = 8/255
PGD_ALPHA = 2/255
PGD_STEPS = 10
TRADES_BETA = 6.0

for split in ["Train","Validation","Test"]:
    real_path = os.path.join(DATA_ROOT, split, "real")
    fake_path = os.path.join(DATA_ROOT, split, "fake")
    print(f"\n[{split.upper()}]")
    print(" Real Images:", len(os.listdir(real_path)))
    print(" Fake Images:", len(os.listdir(fake_path)))


[TRAIN]
 Real Images: 70001
 Fake Images: 70001

[VALIDATION]
 Real Images: 19787
 Fake Images: 19641

[TEST]
 Real Images: 5413
 Fake Images: 5492


## Part 2 - Core Libraries, Device Setup, and Metric Utilities

This section:
- Imports all core libraries (PyTorch, TorchVision, TorchMetrics, NumPy, etc.).
- Sets the `device` (GPU if available, otherwise CPU) and random seeds for reproducibility.
- Defines:
  - ImageNet normalization (`mean`, `std`)
  - `make_metrics()` to create AUROC, AUPRC, and F1 metric objects
  - `denorm_bounds()` and `to_norm_eps()` to convert pixel-space epsilon/alpha to normalized space

These helpers are reused by training, PGD, TRADES, and robustness evaluation.

In [5]:
import os, time, copy, json, math, numpy as np
from typing import Dict, Any

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms, models
import timm

from torchmetrics.classification import BinaryAUROC, BinaryAveragePrecision, BinaryF1Score

# Repro
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(SEED); np.random.seed(SEED)

# Normalization (ImageNet)
mean = (0.485, 0.456, 0.406)
std  = (0.229, 0.224, 0.225)

def make_metrics():
    return (BinaryAUROC().to(device),
            BinaryAveragePrecision().to(device),
            BinaryF1Score().to(device))

def denorm_bounds():
    lo = (torch.tensor(0.0) - torch.tensor(mean).view(3,1,1)) / torch.tensor(std).view(3,1,1)
    hi = (torch.tensor(1.0) - torch.tensor(mean).view(3,1,1)) / torch.tensor(std).view(3,1,1)
    return lo.to(device), hi.to(device)

def to_norm_eps(eps_pixel=PGD_EPS, alpha_pixel=PGD_ALPHA):
    eps_n = torch.tensor(eps_pixel, device=device) / torch.tensor(std, device=device).view(3,1,1)
    alpha_n = torch.tensor(alpha_pixel, device=device) / torch.tensor(std, device=device).view(3,1,1)
    return eps_n, alpha_n


## Part 3 — Data Transforms, Datasets, and DataLoaders

This section:
- Defines data augmentations for training (random crop, flip, color jitter) and deterministic transforms for validation/test.
- Builds `ImageFolder` datasets for `Train`, `Validation`, and `Test`.
- Wraps them into `DataLoader`s.
- Checks that the task is binary (`NUM_CLASSES == 2`) and prints the class names.
- Probes one mini-batch to verify shapes and value range after normalization.


In [6]:
train_tf = transforms.Compose([
    transforms.RandomResizedCrop(IMSIZE, scale=(0.8,1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.2,0.2,0.2,0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean,std)
])
eval_tf = transforms.Compose([
    transforms.Resize(IMSIZE+32),
    transforms.CenterCrop(IMSIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean,std)
])

train_ds = datasets.ImageFolder(os.path.join(DATA_ROOT, "Train"), transform=train_tf)
val_ds   = datasets.ImageFolder(os.path.join(DATA_ROOT, "Validation"),   transform=eval_tf)
test_ds  = datasets.ImageFolder(os.path.join(DATA_ROOT, "Test"),  transform=eval_tf)

train_loader = DataLoader(train_ds, batch_size=BATCH, shuffle=True,  num_workers=NUM_WORKERS, pin_memory=True)
val_loader   = DataLoader(val_ds,   batch_size=BATCH, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)
test_loader  = DataLoader(test_ds,  batch_size=BATCH, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)

classes = train_ds.classes
NUM_CLASSES = len(classes)
assert NUM_CLASSES == 2, f"This notebook assumes binary classification; got {classes}"
print("Classes:", classes)
# quick loader probe
probe_loader = torch.utils.data.DataLoader(
    train_ds, batch_size=8, shuffle=True, num_workers=0, pin_memory=False
)
xb, yb = next(iter(probe_loader))
print("Probe batch:", xb.shape, yb.shape, "min/max", float(xb.min()), float(xb.max()))



Classes: ['Fake', 'Real']
Probe batch: torch.Size([8, 3, 224, 224]) torch.Size([8]) min/max -2.1179039478302 2.640000104904175


## Part 4 — Model Architectures (SRM, RECCE, SRMNet, and Factory)

This section:
- Defines `SRMConv2d`, a fixed high-pass filter block over the luminance channel.
- Defines `RECCE`:
  - Extracts spatial cues (Sobel X/Y, Laplacian) from grayscale.
  - Concatenates them with RGB to form a 6-channel input.
  - Uses a ResNet-18 backbone with `QuantStub`/`DeQuantStub` to support quantization.
- Loads a pretrained FP32 RECCE model from `./runs/recce_standard_fp32_final.pt`.
- Defines `SRMNet`, which uses SRM as a frontend and ResNet-18 as the classifier.
- Provides `make_xception()` and `make_model(kind)` as a factory to instantiate `xception`, `recce`, or `srm` models.


In [7]:
class SRMConv2d(nn.Module):
    def __init__(self):
        super().__init__()
        k = np.stack([
            [[0,0,0,0,0],[0,0,0,0,0],[0,-1,2,-1,0],[0,0,0,0,0],[0,0,0,0,0]],
            [[0,0,0,0,0],[0,-1,2,-1,0],[0,2,-4,2,0],[0,-1,2,-1,0],[0,0,0,0,0]],
            [[-1,2,-1,2,-1],[2,-6,8,-6,2],[-1,8,-12,8,-1],[2,-6,8,-6,2],[-1,2,-1,2,-1]]
        ]).astype(np.float32)
        k = k[:,None,:,:]  # (3,1,5,5)
        self.weight = nn.Parameter(torch.from_numpy(k), requires_grad=False)
        self.bias   = nn.Parameter(torch.zeros(3), requires_grad=False)

    def forward(self, x):  # x: (N,3,H,W) RGB normalized
        r,g,b = x[:,0:1], x[:,1:2], x[:,2:3]
        y = 0.299*r + 0.587*g + 0.114*b
        return F.conv2d(y, self.weight, self.bias, padding=2)

class RECCE(nn.Module):
    def __init__(self, num_classes=2):
        super().__init__()
        self.sobel_x = nn.Conv2d(1,1,3,padding=1,bias=False)
        self.sobel_y = nn.Conv2d(1,1,3,padding=1,bias=False)
        self.lap     = nn.Conv2d(1,1,3,padding=1,bias=False)
        with torch.no_grad():
            self.sobel_x.weight.copy_(torch.tensor([[[[-1,0,1],[-2,0,2],[-1,0,1]]]], dtype=torch.float32))
            self.sobel_y.weight.copy_(torch.tensor([[[[-1,-2,-1],[0,0,0],[1,2,1]]]], dtype=torch.float32))
            self.lap.weight.copy_(torch.tensor([[[[0,1,0],[1,-4,1],[0,1,0]]]], dtype=torch.float32))
        for m in [self.sobel_x, self.sobel_y, self.lap]:
            m.requires_grad_(False)

        self.backbone = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        self.backbone.conv1 = nn.Conv2d(6, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.backbone.fc = nn.Linear(self.backbone.fc.in_features, num_classes)

        self.quant   = torch.ao.quantization.QuantStub()
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x):
        r,g,b = x[:,0:1], x[:,1:2], x[:,2:3]
        y = 0.299*r + 0.587*g + 0.114*b
        sx = self.sobel_x(y)
        sy = self.sobel_y(y)
        lp = self.lap(y)
        x6 = torch.cat([x, sx, sy, lp], dim=1)

        x6_q = self.quant(x6)
        out  = self.backbone(x6_q)
        out  = self.dequant(out)
        return out

model = RECCE()
model.load_state_dict(torch.load("./runs/recce_standard_fp32_final.pt", map_location="cpu"))
print("Loaded pretrained FP32 RECCE.")

class SRMNet(nn.Module):
    def __init__(self, num_classes=2):
        super().__init__()
        self.srm = SRMConv2d()
        self.head = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        self.head.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.head.fc = nn.Linear(self.head.fc.in_features, num_classes)

    def forward(self, x):
        s = self.srm(x)
        return self.head(s)

def make_xception(num_classes=2):
    return timm.create_model("xception", pretrained=True, num_classes=num_classes)

def make_model(kind:str):
    k = kind.lower()
    if k == "xception": return make_xception(NUM_CLASSES).to(device)
    if k == "recce":    return RECCE(NUM_CLASSES).to(device)
    if k == "srm":      return SRMNet(NUM_CLASSES).to(device)
    raise ValueError("kind must be one of: xception | recce | srm")


Loaded pretrained FP32 RECCE.


## Part 5 — Adversarial Attack & Robust Training Losses

This section:
- Implements `pgd_linf()` to generate ℓ∞ PGD adversarial examples:
  - Works in normalized space.
  - Uses multi-step gradient ascent on the loss.
- Implements `trades_loss()`:
  - Inner loop builds adversarial examples by minimizing KL divergence between clean and adversarial predictions.
  - Final loss combines standard cross-entropy and a robustness KL penalty scaled by `TRADES_BETA`.

These functions are later used for PGD-Adversarial Training and TRADES-based training.


In [None]:
def pgd_linf(model, x, y, eps=PGD_EPS, alpha=PGD_ALPHA, steps=PGD_STEPS):
    model.eval()
    from contextlib import nullcontext
    autocast_off = torch.cuda.amp.autocast(enabled=False)

    eps_n, alpha_n = to_norm_eps(eps, alpha)
    lo, hi = denorm_bounds()

    x_adv = x.detach() + torch.empty_like(x).uniform_(-eps_n.max(), eps_n.max())
    x_adv = x_adv.clamp(lo, hi)

    for _ in range(steps):
        x_adv.requires_grad_(True)
        with autocast_off:
            logits = model(x_adv)
            loss = F.cross_entropy(logits, y)
        g = torch.autograd.grad(loss, x_adv, only_inputs=True)[0]
        x_adv = x_adv.detach() + alpha_n * torch.sign(g)
        x_adv = torch.max(torch.min(x_adv, x + eps_n), x - eps_n)
        x_adv = x_adv.clamp(lo, hi)
    return x_adv.detach()


def trades_loss(model, x, y, beta=TRADES_BETA, eps=PGD_EPS, alpha=PGD_ALPHA, steps=PGD_STEPS):
    from contextlib import nullcontext
    autocast_off = torch.cuda.amp.autocast(enabled=False)

    model.train()
    eps_n, alpha_n = to_norm_eps(eps, alpha)
    lo, hi = denorm_bounds()

    x_adv = x.detach() + 0.001*torch.randn_like(x)
    x_adv = torch.max(torch.min(x_adv, x + eps_n), x - eps_n).clamp(lo, hi)

    with torch.no_grad():
        p = F.softmax(model(x), dim=1)

    for _ in range(steps):
        x_adv.requires_grad_(True)
        with autocast_off:
            q_logits = model(x_adv)
            loss_kl = F.kl_div(F.log_softmax(q_logits, dim=1), p, reduction='batchmean')
        g = torch.autograd.grad(loss_kl, x_adv, only_inputs=True)[0]
        x_adv = x_adv.detach() + alpha_n * torch.sign(g)
        x_adv = torch.max(torch.min(x_adv, x + eps_n), x - eps_n).clamp(lo, hi)

    logits_clean = model(x)
    logits_adv   = model(x_adv)
    loss_nll = F.cross_entropy(logits_clean, y)
    loss_rob = F.kl_div(F.log_softmax(logits_adv,dim=1),
                        F.softmax(logits_clean.detach(),dim=1),
                        reduction='batchmean')
    return loss_nll + beta * loss_rob


## Part 6 — Evaluation Function and Base Training Loop

This section:
- Defines `evaluate(model, loader)`:
  - Computes Accuracy, AUROC, AUPRC, and F1 on a given DataLoader.
- Defines a general `train()` function:
  - Supports three regimes:
    - `"standard"` (clean training)
    - `"pgd_at"` (PGD Adversarial Training)
    - `"trades"` (TRADES robust training)
  - Tracks validation metrics and keeps the best model based on AUROC.

This is the core training/evaluation logic used throughout the pipeline.


In [8]:
def evaluate(model, loader) -> Dict[str, float]:
    model.eval()
    auroc, auprc, f1 = make_metrics()
    correct = total = 0
    with torch.no_grad():
        for x,y in loader:
            x,y = x.to(device), y.to(device)
            logits = model(x)
            prob_fake = torch.softmax(logits, dim=1)[:,1]
            pred = torch.argmax(logits, dim=1)
            correct += (pred==y).sum().item()
            total   += y.numel()
            auroc.update(prob_fake, y)
            auprc.update(prob_fake, y)
            f1.update(pred, y)
    return {
        "acc": float(correct/total),
        "auroc": float(auroc.compute().item()),
        "auprc": float(auprc.compute().item()),
        "f1": float(f1.compute().item())
    }

def train(model, regime="standard", epochs=10, lr=3e-4, weight_decay=1e-4,
          pgd_eps=PGD_EPS, pgd_alpha=PGD_ALPHA, pgd_steps=PGD_STEPS, trades_beta=TRADES_BETA):
    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    best, best_state = -1, None
    for ep in range(1, epochs+1):
        model.train()
        running = 0.0
        for x,y in train_loader:
            x,y = x.to(device), y.to(device)
            opt.zero_grad(set_to_none=True)
            if regime == "standard":
                loss = F.cross_entropy(model(x), y)
            elif regime == "pgd_at":
                x_adv = pgd_linf(model, x, y, eps=pgd_eps, alpha=pgd_alpha, steps=pgd_steps)
                loss  = F.cross_entropy(model(x_adv), y)
            elif regime == "trades":
                loss  = trades_loss(model, x, y, beta=trades_beta, eps=pgd_eps, alpha=pgd_alpha, steps=pgd_steps)
            else:
                raise ValueError("regime must be 'standard' | 'pgd_at' | 'trades'")
            loss.backward(); opt.step()
            running += loss.item() * x.size(0)

        val_m = evaluate(model, val_loader)
        print(f"Ep {ep:02d} | train_loss={(running/len(train_ds)):.4f} | "
              f"val_acc={val_m['acc']:.3f} auroc={val_m['auroc']:.3f} auprc={val_m['auprc']:.3f} f1={val_m['f1']:.3f}")
        if val_m['auroc'] > best:
            best = val_m['auroc']; best_state = copy.deepcopy(model.state_dict())

    if best_state is not None:
        model.load_state_dict(best_state)
    return model


## Part 7 — Eager Quantization Utilities (Dynamic, Static, QAT)

This section:
- Defines:
  - `model_size_mb(path)`: compute model size in MB.
  - `latency_ms(model)`: approximate CPU inference latency.
  - `save_state(tag, model)`: save weights and report size/latency.
- Implements:
  - `quantize_dynamic()`: dynamic (weights-only) INT8 quantization of Linear layers.
  - `quantize_static()`: static post-training quantization with calibration.
  - `qat_int8()`: a basic Quantization-Aware Training loop for ResNet-like models.

These are early quantization utilities; later you use an FX-based version that is more advanced.


In [None]:
import torch.ao.quantization as tq

def model_size_mb(path): return os.path.getsize(path)/1e6

def latency_ms(model, iters=50, bs=1):
    m = copy.deepcopy(model).to('cpu').eval()
    x = torch.randn(bs,3,IMSIZE,IMSIZE)
    with torch.no_grad():
        for _ in range(10): m(x)
        t0=time.time()
        for _ in range(iters): m(x)
        t1=time.time()
    return 1e3*(t1-t0)/iters

def save_state(tag, model):
    path = os.path.join(RUN_DIR, f"{tag}.pt")
    torch.save(model.state_dict(), path)
    return path, model_size_mb(path), latency_ms(model)

def quantize_dynamic(model_cpu):
    return tq.quantize_dynamic(model_cpu, {nn.Linear}, dtype=torch.qint8)

def quantize_static(model_cpu, calib_loader, backend="fbgemm"):
    m = copy.deepcopy(model_cpu).to('cpu').eval()
    m.qconfig = tq.get_default_qconfig(backend)
    tq.prepare(m, inplace=True)
    with torch.no_grad():
        for x,_ in calib_loader:
            m(x)
    tq.convert(m, inplace=True)
    return m

def qat_int8(base_model, epochs=5, lr=1e-4, backend="fbgemm"):
    m = copy.deepcopy(base_model).to('cpu')
    m.train()
    m.qconfig = tq.get_default_qat_qconfig(backend)
    tq.prepare_qat(m, inplace=True)
    m = m.to(device)
    opt = torch.optim.AdamW(m.parameters(), lr=lr)
    for ep in range(1, epochs+1):
        m.train()
        for x,y in train_loader:
            x,y = x.to(device), y.to(device)
            loss = F.cross_entropy(m(x), y)
            opt.zero_grad(set_to_none=True); loss.backward(); opt.step()
        print(f"[QAT] epoch {ep:02d} done")
    m_cpu = copy.deepcopy(m).to('cpu').eval()
    tq.convert(m_cpu, inplace=True)
    return m_cpu


## Part 8 — RECCE-Specific Eager PTQ and QAT Utilities

This section:
- Defines `recce_static_ptq()`:
  - Runs static post-training quantization on a RECCE model using its `QuantStub/DeQuantStub` structure and a calibration loader.
- Defines `recce_qat()`:
  - Runs QAT for RECCE on CPU/GPU:
    - Prepares the model with `prepare_qat`.
    - Fine-tunes for a given number of epochs.
    - Converts to a final INT8 model.
- Defines `save_and_report(tag, m)`:
  - Saves a given model under `./runs`.
  - Returns its path, size in MB, and CPU latency.

These helpers are an earlier quantization path; later you add the FX-based quantization (more modern).


In [None]:
import torch.ao.quantization as tq

def recce_static_ptq(recce_model, calib_loader, backend="fbgemm"):
    m = recce_model.to('cpu').eval()
    tq.backend = backend
    m.qconfig = tq.get_default_qconfig(backend)
    tq.prepare(m, inplace=True)
    with torch.no_grad():
        for x,_ in calib_loader:
            m(x)
    tq.convert(m, inplace=True)
    return m

def recce_qat(recce_model, train_loader, epochs=5, lr=1e-4, backend="fbgemm", device=None):
    dev = device or ("cuda" if torch.cuda.is_available() else "cpu")
    m = recce_model.to('cpu').train()
    tq.backend = backend
    m.qconfig = tq.get_default_qat_qconfig(backend)
    torch.ao.quantization.prepare_qat(m, inplace=True)

    m = m.to(dev)
    opt = torch.optim.AdamW(m.parameters(), lr=lr)
    for ep in range(1, epochs+1):
        m.train()
        for x,y in train_loader:
            x,y = x.to(dev), y.to(dev)
            loss = F.cross_entropy(m(x), y)
            opt.zero_grad(set_to_none=True); loss.backward(); opt.step()
        print(f"[QAT] epoch {ep}/{epochs} done")

    m_cpu = m.to('cpu').eval()
    torch.ao.quantization.convert(m_cpu, inplace=True)
    return m_cpu

def save_and_report(tag, m):
    p = os.path.join("./runs", f"{tag}.pt"); torch.save(m.state_dict(), p)
    from pathlib import Path; size_mb = Path(p).stat().st_size/1e6
    lat_ms = latency_ms(m)
    return {"path": p, "size_mb": size_mb, "lat_ms": lat_ms}


## Part 9 — Smoke Test: Single Forward Pass

This section:
- Creates a tiny DataLoader to grab a small batch.
- Instantiates `model_dbg = make_model("recce")`.
- Runs a single forward pass and prints:
  - Output shape (logits)
  - Approximate forward time in milliseconds.

Purpose: verify that the model compiles and runs end-to-end before heavy training.


In [10]:
import time, torch

tiny_loader = torch.utils.data.DataLoader(
    train_ds, batch_size=8, shuffle=True, num_workers=0, pin_memory=False
)

x,y = next(iter(tiny_loader))
x,y = x.to(device), y.to(device)

model_dbg = make_model("recce")
t0 = time.time()
with torch.no_grad():
    logits = model_dbg(x)
t1 = time.time()
print("Forward OK. Logits shape:", logits.shape, "Time:", round(1000*(t1-t0),1), "ms")


Forward OK. Logits shape: torch.Size([8, 2]) Time: 145.0 ms


## Part 10 — DataLoader Configuration (Windows-Safe)

This section:
- Rebuilds `train_loader`, `val_loader`, and `test_loader` with:
  - `num_workers = 0`
  - `persistent_workers = False`
  - `pin_memory = False`
for maximum stability on Windows and inside some IDEs (e.g., VS Code Jupyter).


In [None]:
import os, torch
from torch.utils.data import DataLoader

HAS_CUDA = torch.cuda.is_available()
NUM_WORKERS = 0

train_loader = DataLoader(
    train_ds, batch_size=min(32, BATCH), shuffle=True,
    num_workers=NUM_WORKERS, pin_memory=False, persistent_workers=False
)
val_loader = DataLoader(
    val_ds, batch_size=min(32, BATCH), shuffle=False,
    num_workers=NUM_WORKERS, pin_memory=False, persistent_workers=False
)
test_loader = DataLoader(
    test_ds, batch_size=min(32, BATCH), shuffle=False,
    num_workers=NUM_WORKERS, pin_memory=False, persistent_workers=False
)


## Part 11 — Progress-Bar Training Loop (TQDM)

This section:
- Redefines `train()` to:
  - Use a `tqdm` progress bar during training.
  - Show batch-level loss live.
- Supports the same training regimes (`standard`, `pgd_at`, `trades`).
- Prints per-epoch summary with train loss and validation metrics.

This is the user-friendly training loop you actually used to train RECCE for 8 epochs.


In [None]:
from tqdm.auto import tqdm

def train(model, regime="standard", epochs=3, lr=3e-4, weight_decay=1e-4,
          pgd_eps=8/255, pgd_alpha=2/255, pgd_steps=10, trades_beta=6.0):
    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    best, best_state = -1, None
    for ep in range(1, epochs+1):
        model.train()
        running = 0.0
        it = tqdm(train_loader, desc=f"Epoch {ep}/{epochs} [{regime}]", leave=False)
        for x,y in it:
            x,y = x.to(device), y.to(device)
            opt.zero_grad(set_to_none=True)
            if regime == "standard":
                loss = F.cross_entropy(model(x), y)
            elif regime == "pgd_at":
                x_adv = pgd_linf(model, x, y, eps=pgd_eps, alpha=pgd_alpha, steps=min(pgd_steps, 2))
                loss  = F.cross_entropy(model(x_adv), y)
            elif regime == "trades":
                loss  = trades_loss(model, x, y, beta=trades_beta, eps=pgd_eps, alpha=pgd_alpha, steps=min(pgd_steps, 2))
            else:
                raise ValueError("regime must be 'standard' | 'pgd_at' | 'trades'")
            loss.backward(); opt.step()
            running += loss.item() * x.size(0)
            it.set_postfix(loss=f"{loss.item():.3f}")
        val_m = evaluate(model, val_loader)
        print(f"Ep {ep:02d} | train_loss={(running/len(train_ds)):.4f} | "
              f"val_acc={val_m['acc']:.3f} auroc={val_m['auroc']:.3f} auprc={val_m['auprc']:.3f} f1={val_m['f1']:.3f}")
        if val_m['auroc'] > best:
            best = val_m['auroc']; best_state = copy.deepcopy(model.state_dict())
    if best_state is not None:
        model.load_state_dict(best_state)
    return model


## Part 12 — Debug Training and Initial 1-Epoch Run

This section:
- Defines `train_debug()`:
  - Runs only a few batches per epoch (e.g., 5 batches).
  - Prints detailed batch-level loss and validation metrics.
  - Useful for debugging logic or data issues.
- Then sets:
  - `MODEL_KIND = "recce"`
  - `REGIME = "standard"`
  - `EPOCHS = 1`
- Trains RECCE for 1 epoch as a quick sanity check.


In [14]:

from itertools import islice

def train_debug(model, regime="standard", epochs=1, lr=3e-4, weight_decay=1e-4,
                pgd_eps=PGD_EPS, pgd_alpha=PGD_ALPHA, pgd_steps=PGD_STEPS, trades_beta=TRADES_BETA):
    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    for ep in range(1, epochs+1):
        model.train()
        running = 0.0
        print(f"[start epoch {ep}/{epochs}] regime={regime}", flush=True)

        for b, (x,y) in enumerate(islice(train_loader, 5)):
            if b == 0: print(" first batch pulled", flush=True)
            x,y = x.to(device), y.to(device)
            opt.zero_grad(set_to_none=True)

            if regime == "standard":
                loss = F.cross_entropy(model(x), y)
            elif regime == "pgd_at":
                x_adv = pgd_linf(model, x, y, eps=pgd_eps, alpha=pgd_alpha, steps=2)
                loss  = F.cross_entropy(model(x_adv), y)
            elif regime == "trades":
                loss  = trades_loss(model, x, y, beta=trades_beta, eps=pgd_eps, alpha=pgd_alpha, steps=2)
            else:
                raise ValueError("regime must be 'standard' | 'pgd_at' | 'trades'")

            loss.backward(); opt.step()
            running += loss.item() * x.size(0)
            print(f"  batch {b}: loss={loss.item():.4f}", flush=True)

        val_m = evaluate(model, val_loader)
        print(f"[end epoch {ep}] train_loss={(running/(5*train_loader.batch_size)):.4f} | "
              f"val_acc={val_m['acc']:.3f} auroc={val_m['auroc']:.3f}", flush=True)

    return model


MODEL_KIND = "recce"
REGIME     = "standard"
EPOCHS     = 1

model = make_model(MODEL_KIND)
model = train(model, regime=REGIME, epochs=EPOCHS)


Epoch 1/1 [standard]:   0%|          | 0/4376 [00:00<?, ?it/s]

Ep 01 | train_loss=0.1193 | val_acc=0.952 auroc=0.992 auprc=0.992 f1=0.952


## Part 13 — Main FP32 Training (8 Epochs, RECCE)

This section:
- Configures:
  - `MODEL_KIND = "recce"`
  - `REGIME = "standard"`
  - `EPOCHS = 8`
- Recreates the RECCE model and trains it for 8 epochs.
- This is the primary FP32 baseline you later quantize and evaluate.


In [15]:
MODEL_KIND = "recce"
REGIME     = "standard"
EPOCHS     = 8

model = make_model(MODEL_KIND)
model = train(model, regime=REGIME, epochs=EPOCHS)

Epoch 1/8 [standard]:   0%|          | 0/4376 [00:00<?, ?it/s]

Ep 01 | train_loss=0.1196 | val_acc=0.953 auroc=0.994 auprc=0.994 f1=0.952


Epoch 2/8 [standard]:   0%|          | 0/4376 [00:00<?, ?it/s]

Ep 02 | train_loss=0.0712 | val_acc=0.955 auroc=0.994 auprc=0.994 f1=0.954


Epoch 3/8 [standard]:   0%|          | 0/4376 [00:00<?, ?it/s]

Ep 03 | train_loss=0.0599 | val_acc=0.967 auroc=0.996 auprc=0.996 f1=0.967


Epoch 4/8 [standard]:   0%|          | 0/4376 [00:00<?, ?it/s]

Ep 04 | train_loss=0.0528 | val_acc=0.959 auroc=0.996 auprc=0.996 f1=0.958


Epoch 5/8 [standard]:   0%|          | 0/4376 [00:00<?, ?it/s]

Ep 05 | train_loss=0.0465 | val_acc=0.967 auroc=0.996 auprc=0.996 f1=0.967


Epoch 6/8 [standard]:   0%|          | 0/4376 [00:00<?, ?it/s]

Ep 06 | train_loss=0.0442 | val_acc=0.974 auroc=0.997 auprc=0.997 f1=0.974


Epoch 7/8 [standard]:   0%|          | 0/4376 [00:00<?, ?it/s]

Ep 07 | train_loss=0.0409 | val_acc=0.954 auroc=0.996 auprc=0.996 f1=0.952


Epoch 8/8 [standard]:   0%|          | 0/4376 [00:00<?, ?it/s]

Ep 08 | train_loss=0.0395 | val_acc=0.973 auroc=0.997 auprc=0.997 f1=0.973


## Part 14 — Evaluation Helpers for FP32 and INT8 (CPU Version, Eager)

This section:
- Defines `evaluate_fp32()`:
  - Evaluates FP32 models on GPU (or CPU if no CUDA) with Accuracy, AUROC, AUPRC, and F1.
- Defines `evaluate_int8_cpu()`:
  - Evaluates quantized INT8 models on CPU (for static/dynamic/QAT) using TorchMetrics.

These were the first evaluation helpers for your quantized models; later updated for FX-based quantization.


In [None]:
def evaluate_fp32(model, loader) -> Dict[str, float]:
    """Use this for FP32 models on GPU (or CPU if no CUDA)."""
    model = model.to(device).eval()
    auroc, auprc, f1 = make_metrics()
    correct = total = 0
    with torch.no_grad():
        for x,y in loader:
            x,y = x.to(device), y.to(device)
            logits = model(x)
            prob_fake = torch.softmax(logits, dim=1)[:,1]
            pred = torch.argmax(logits, dim=1)
            correct += (pred==y).sum().item()
            total   += y.numel()
            auroc.update(prob_fake, y)
            auprc.update(prob_fake, y)
            f1.update(pred, y)
    return {
        "acc": float(correct/total),
        "auroc": float(auroc.compute().item()),
        "auprc": float(auprc.compute().item()),
        "f1": float(f1.compute().item())
    }

def evaluate_int8_cpu(model, loader) -> Dict[str, float]:
    """Use this for STATIC/QAT/DYNAMIC INT8 models (CPU-only)."""
    model = copy.deepcopy(model).to("cpu").eval()
    auroc = BinaryAUROC().to("cpu")
    auprc = BinaryAveragePrecision().to("cpu")
    f1    = BinaryF1Score().to("cpu")

    correct = total = 0
    with torch.no_grad():
        for x,y in loader:
            logits = model(x)
            prob_fake = torch.softmax(logits, dim=1)[:,1]
            pred = torch.argmax(logits, dim=1)
            correct += (pred==y).sum().item()
            total   += y.numel()
            auroc.update(prob_fake.cpu(), y.cpu())
            auprc.update(prob_fake.cpu(), y.cpu())
            f1.update(pred.cpu(), y.cpu())
    return {
        "acc": float(correct/total),
        "auroc": float(auroc.compute().item()),
        "auprc": float(auprc.compute().item()),
        "f1": float(f1.compute().item())
    }


## Part 15 — FX Graph-Mode Quantization Helpers for RECCE

This section:
- Imports FX-based quantization utilities: `prepare_fx`, `convert_fx`, `prepare_qat_fx`.
- Redefines:
  - `evaluate_fp32()` and `evaluate_int8_cpu()` (now consistent with FX flow).
- Implements:
  - `recce_fx_static_ptq()`:
    - Uses FX `prepare_fx` and `convert_fx` with an example input and calibration loader.
  - `recce_fx_qat()`:
    - Uses `prepare_qat_fx` to inject fake-quant nodes.
    - Fine-tunes with QAT for a small number of epochs.
    - Converts to a fully quantized INT8 model.

This is the “advanced” quantization path compatible with PyTorch 2.9+ deprecations.


In [None]:
import torch.ao.quantization as tq
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx, prepare_qat_fx

def evaluate_fp32(model, loader) -> Dict[str, float]:
    model = model.to(device).eval()
    auroc, auprc, f1 = make_metrics()
    correct = total = 0
    with torch.no_grad():
        for x,y in loader:
            x,y = x.to(device), y.to(device)
            logits = model(x)
            prob_fake = torch.softmax(logits, dim=1)[:,1]
            pred = torch.argmax(logits, dim=1)
            correct += (pred==y).sum().item()
            total   += y.numel()
            auroc.update(prob_fake, y)
            auprc.update(prob_fake, y)
            f1.update(pred, y)
    return {
        "acc": float(correct/total),
        "auroc": float(auroc.compute().item()),
        "auprc": float(auprc.compute().item()),
        "f1": float(f1.compute().item())
    }

def evaluate_int8_cpu(model, loader) -> Dict[str, float]:
    m = copy.deepcopy(model).to("cpu").eval()
    auroc = BinaryAUROC().to("cpu")
    auprc = BinaryAveragePrecision().to("cpu")
    f1    = BinaryF1Score().to("cpu")
    correct = total = 0
    with torch.no_grad():
        for x,y in loader:
            logits = m(x)
            prob_fake = torch.softmax(logits, dim=1)[:,1]
            pred = torch.argmax(logits, dim=1)
            correct += (pred==y).sum().item()
            total   += y.numel()
            auroc.update(prob_fake.cpu(), y.cpu())
            auprc.update(prob_fake.cpu(), y.cpu())
            f1.update(pred.cpu(), y.cpu())
    return {
        "acc": float(correct/total),
        "auroc": float(auroc.compute().item()),
        "auprc": float(auprc.compute().item()),
        "f1": float(f1.compute().item())
    }

def recce_fx_static_ptq(model_fp32, calib_loader, backend="fbgemm"):
    m = copy.deepcopy(model_fp32).to("cpu").eval()
    tq.backend = backend
    qconfig = tq.get_default_qconfig(backend)
    qconfig_dict = {"": qconfig}

    example_x, _ = next(iter(calib_loader))
    prepared = prepare_fx(m, qconfig_dict, example_inputs=example_x)

    with torch.no_grad():
        for x,_ in calib_loader:
            prepared(x)
    quantized = convert_fx(prepared)
    return quantized

def recce_fx_qat(model_fp32, train_loader, epochs=5, lr=1e-4,
                 backend="fbgemm", device=None):
    dev = device or ("cuda" if torch.cuda.is_available() else "cpu")
    base = copy.deepcopy(model_fp32).to("cpu").train()
    tq.backend = backend
    qconfig = tq.get_default_qat_qconfig(backend)
    qconfig_dict = {"": qconfig}

    example_x, _ = next(iter(train_loader))
    prepared = prepare_qat_fx(base, qconfig_dict, example_inputs=example_x).to(dev)

    opt = torch.optim.AdamW(prepared.parameters(), lr=lr)
    for ep in range(1, epochs+1):
        prepared.train()
        running = 0.0
        for x,y in train_loader:
            x,y = x.to(dev), y.to(dev)
            logits = prepared(x)
            loss = F.cross_entropy(logits, y)
            opt.zero_grad(set_to_none=True); loss.backward(); opt.step()
            running += loss.item() * x.size(0)
        print(f"[FX-QAT] epoch {ep}/{epochs} done, train_loss={running/len(train_ds):.4f}")

    prepared_cpu = prepared.to("cpu").eval()
    quantized = convert_fx(prepared_cpu)
    return quantized


## Part 16 — Quantization Suite (FP32 vs Dynamic vs Static-FX vs QAT-FX)

This section:
- Asserts that all necessary symbols from earlier parts exist.
- Configures:
  - `RUN_DIR`, `MODEL_TAG`, number of calibration samples, and QAT epochs.
- Evaluates and saves:
  1. **FP32 baseline** (`recce_standard_fp32_final`)
  2. **INT8 Dynamic** (weights-only)
  3. **INT8 Static FX** (calibrated)
  4. **INT8 QAT FX** (fine-tuned with fake-quant)
- Computes:
  - Accuracy, AUROC, AUPRC, F1
  - Model size (MB)
  - CPU latency (ms)
- Displays a summary table and saves a JSON file `*_quant_summary.json` with all metrics and model info.

This is your main **quantization comparison** section.


In [28]:
import os, json, copy, numpy as np
from pathlib import Path
from torch.utils.data import Subset, DataLoader
import torch
import pandas as pd

needed = ["model","train_ds","train_loader","test_loader","device",
          "evaluate_fp32","evaluate_int8_cpu","recce_fx_static_ptq","recce_fx_qat","latency_ms"]
missing = [n for n in needed if n not in globals()]
assert not missing, f"Missing in notebook: {missing}"

RUN_DIR = "./runs"; os.makedirs(RUN_DIR, exist_ok=True)
MODEL_TAG = "recce_standard"
CALIB_SAMPLES = min(1024, len(train_ds))
QAT_EPOCHS = 5
NUM_WORKERS = 0

def save_and_report(tag, m):
    p = os.path.join(RUN_DIR, f"{tag}.pt")
    torch.save(m.state_dict(), p)
    size_mb = Path(p).stat().st_size / 1e6
    lat_ms = latency_ms(m)
    return {"path": p, "size_mb": size_mb, "lat_ms": lat_ms}

print("[FP32] evaluating baseline…")
fp32_metrics = evaluate_fp32(model, test_loader)
fp32_info    = save_and_report(f"{MODEL_TAG}_fp32_final",
                               copy.deepcopy(model).to('cpu').eval())

print("[INT8-Dynamic] quantizing + evaluating…")
dyn = torch.ao.quantization.quantize_dynamic(
    copy.deepcopy(model).to('cpu').eval(),
    {torch.nn.Linear},
    dtype=torch.qint8
)
dyn_metrics = evaluate_int8_cpu(dyn, test_loader)
dyn_info    = save_and_report(f"{MODEL_TAG}_int8_dynamic", dyn)

print("[INT8-Static FX] building calibration loader & evaluating…")
calib_idx = np.random.choice(len(train_ds), size=CALIB_SAMPLES, replace=False)
calib_loader = DataLoader(Subset(train_ds, calib_idx), batch_size=64, shuffle=False,
                          num_workers=NUM_WORKERS, pin_memory=False, persistent_workers=False)

st = recce_fx_static_ptq(model, calib_loader)
st_metrics = evaluate_int8_cpu(st, test_loader)
st_info    = save_and_report(f"{MODEL_TAG}_int8_static_fx", st)

print(f"[INT8-QAT FX] fine-tuning for {QAT_EPOCHS} epochs + evaluating…")
qatm = recce_fx_qat(model, train_loader, epochs=QAT_EPOCHS, device=device)
qat_metrics = evaluate_int8_cpu(qatm, test_loader)
qat_info    = save_and_report(f"{MODEL_TAG}_int8_qat_fx", qatm)

rows = [
    ["FP32",          fp32_metrics.get("acc"), fp32_metrics.get("auroc"), fp32_metrics.get("auprc"), fp32_metrics.get("f1"),
                      fp32_info["size_mb"], fp32_info["lat_ms"]],
    ["INT8-Dynamic",  dyn_metrics.get("acc"), dyn_metrics.get("auroc"), dyn_metrics.get("auprc"), dyn_metrics.get("f1"),
                      dyn_info["size_mb"], dyn_info["lat_ms"]],
    ["INT8-StaticFX", st_metrics.get("acc"),  st_metrics.get("auroc"),  st_metrics.get("auprc"),  st_metrics.get("f1"),
                      st_info["size_mb"], st_info["lat_ms"]],
    ["INT8-QATFX",    qat_metrics.get("acc"), qat_metrics.get("auroc"), qat_metrics.get("auprc"), qat_metrics.get("f1"),
                      qat_info["size_mb"], qat_info["lat_ms"]],
]
df = pd.DataFrame(rows, columns=["Version","Acc","AUROC","AUPRC","F1","Size (MB)","CPU Latency (ms)"])
display(df.style.format({"Acc":"{:.3f}","AUROC":"{:.3f}","AUPRC":"{:.3f}","F1":"{:.3f}",
                         "Size (MB)":"{:.1f}","CPU Latency (ms)":"{:.1f}"}))

summary = {
    "model_tag": MODEL_TAG,
    "fp32":           {"metrics": fp32_metrics, **fp32_info},
    "int8_dynamic":   {"metrics": dyn_metrics,  **dyn_info},
    "int8_static_fx": {"metrics": st_metrics,   **st_info},
    "int8_qat_fx":    {"metrics": qat_metrics,  **qat_info},
}
with open(os.path.join(RUN_DIR, f"{MODEL_TAG}_quant_summary.json"), "w") as f:
    json.dump(summary, f, indent=2)
print("Saved JSON:", os.path.join(RUN_DIR, f"{MODEL_TAG}_quant_summary.json"))


[FP32] evaluating baseline…
[INT8-Dynamic] quantizing + evaluating…


For migrations of users: 
1. Eager mode quantization (torch.ao.quantization.quantize, torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode quantize_ API instead 
2. FX graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx,torch.ao.quantization.quantize_fx.convert_fx, please migrate to use torchao pt2e quantization API instead (prepare_pt2e, convert_pt2e) 
3. pt2e quantization has been migrated to torchao (https://github.com/pytorch/ao/tree/main/torchao/quantization/pt2e) 
see https://github.com/pytorch/ao/issues/2259 for more details
  dyn = torch.ao.quantization.quantize_dynamic(


[INT8-Static FX] building calibration loader & evaluating…


For migrations of users: 
1. Eager mode quantization (torch.ao.quantization.quantize, torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode quantize_ API instead 
2. FX graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx,torch.ao.quantization.quantize_fx.convert_fx, please migrate to use torchao pt2e quantization API instead (prepare_pt2e, convert_pt2e) 
3. pt2e quantization has been migrated to torchao (https://github.com/pytorch/ao/tree/main/torchao/quantization/pt2e) 
see https://github.com/pytorch/ao/issues/2259 for more details
  prepared = prepare_fx(m, qconfig_dict, example_inputs=example_x)
  prepared = prepare(
For migrations of users: 
1. Eager mode quantization (torch.ao.quantization.quantize, torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode quantize_ API instead 
2. FX graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx,torch.ao.quantization.quantize_fx.convert_fx, please migrat

[INT8-QAT FX] fine-tuning for 5 epochs + evaluating…


For migrations of users: 
1. Eager mode quantization (torch.ao.quantization.quantize, torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode quantize_ API instead 
2. FX graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx,torch.ao.quantization.quantize_fx.convert_fx, please migrate to use torchao pt2e quantization API instead (prepare_pt2e, convert_pt2e) 
3. pt2e quantization has been migrated to torchao (https://github.com/pytorch/ao/tree/main/torchao/quantization/pt2e) 
see https://github.com/pytorch/ao/issues/2259 for more details
  prepared = prepare_qat_fx(base, qconfig_dict, example_inputs=example_x).to(dev)
  prepared = prepare(


[FX-QAT] epoch 1/5 done, train_loss=0.0371
[FX-QAT] epoch 2/5 done, train_loss=0.0323
[FX-QAT] epoch 3/5 done, train_loss=0.0308
[FX-QAT] epoch 4/5 done, train_loss=0.0286
[FX-QAT] epoch 5/5 done, train_loss=0.0275


For migrations of users: 
1. Eager mode quantization (torch.ao.quantization.quantize, torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode quantize_ API instead 
2. FX graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx,torch.ao.quantization.quantize_fx.convert_fx, please migrate to use torchao pt2e quantization API instead (prepare_pt2e, convert_pt2e) 
3. pt2e quantization has been migrated to torchao (https://github.com/pytorch/ao/tree/main/torchao/quantization/pt2e) 
see https://github.com/pytorch/ao/issues/2259 for more details
  quantized = convert_fx(prepared_cpu)


Unnamed: 0,Version,Acc,AUROC,AUPRC,F1,Size (MB),CPU Latency (ms)
0,FP32,0.901,0.96,0.969,0.894,44.8,21.8
1,INT8-Dynamic,0.901,0.96,0.969,0.894,44.8,23.9
2,INT8-StaticFX,0.899,0.96,0.968,0.891,11.3,11.4
3,INT8-QATFX,0.888,0.967,0.974,0.875,11.3,12.1


Saved JSON: ./runs\recce_standard_quant_summary.json


## Part 17 — PGD Robustness Evaluation Helper

This section:
- Defines `evaluate_under_pgd()`:
  - For each batch in a loader:
    - Generates PGD adversarial examples (white-box).
    - Evaluates model on these adversarial images.
  - Returns Accuracy, AUROC, AUPRC, and F1 under PGD-10.

Used in the robustness section to measure how well models survive strong adversarial attacks.


In [32]:
from tqdm.auto import tqdm

def evaluate_under_pgd(model, loader, steps=10, eps=PGD_EPS, alpha=PGD_ALPHA):
    """
    Evaluate a model under white-box PGD (L_inf) on a given loader.
    Returns the same metric dict structure as `evaluate`.
    """
    model = model.to(device).eval()
    auroc, auprc, f1 = make_metrics()
    correct = total = 0

    for x,y in tqdm(loader, desc=f"PGD-{steps} eval", leave=False):
        x,y = x.to(device), y.to(device)
        x_adv = pgd_linf(model, x, y, eps=eps, alpha=alpha, steps=steps)

        with torch.no_grad():
            logits = model(x_adv)
            prob_fake = torch.softmax(logits, dim=1)[:,1]
            pred = torch.argmax(logits, dim=1)
            correct += (pred == y).sum().item()
            total   += y.numel()
            auroc.update(prob_fake, y)
            auprc.update(prob_fake, y)
            f1.update(pred, y)

    return {
        "acc":  float(correct/total),
        "auroc": float(auroc.compute().item()),
        "auprc": float(auprc.compute().item()),
        "f1":    float(f1.compute().item())
    }


## Part 18 — Robustness & Adversarial Evaluation (FP32 & INT8)

This section runs the full robustness study:
- **Sanity checks & config**:
  - Ensures all required symbols exist.
  - Sets `MODEL_TAG`, `ROBUST_EPOCHS`, and flags:
    - `DO_TRAIN_PGD_AT`, `DO_TRAIN_TRADES`, `DO_AUTOATTACK`.
- **AutoAttack (optional)**:
  - Sets up AutoAttack if enabled.
- **JPEG robustness**:
  - Implements JPEG compression for:
    - FP32 models (`eval_jpeg_fp32`)
    - INT8 models (`eval_jpeg_int8`)
- **Evaluation helpers**:
  - `eval_all_fp32(...)`:
    - Clean metrics
    - PGD-10 metrics
    - Optional AutoAttack metrics
    - JPEG robustness at qualities in `JPEG_QUALITIES`
  - `eval_all_int8(...)`:
    - Clean + JPEG robustness for INT8 models
- **Models evaluated**:
  - FP32 baseline (`*_fp32_standard`)
  - FP32 PGD-AT (`*_fp32_pgd_at`) — if enabled
  - FP32 TRADES (`*_fp32_trades`) — if enabled
  - INT8 Dynamic, Static-FX, and QAT-FX variants (baseline + robust versions where available)
- **Outputs**:
  - A Pandas DataFrame `df` summarizing all models and robustness metrics.
  - Saved to:
    - `*_robust_summary.json`
    - `*_robust_summary.csv`

This section gives you the final story:
how FP32 vs INT8 and Standard vs PGD-AT vs TRADES behave under attacks and compression.


In [None]:
import os, copy, json, numpy as np, pandas as pd, torch
from pathlib import Path

need_syms = [
    "make_model", "train", "evaluate", "evaluate_under_pgd",
    "quantize_dynamic", "device", "train_ds", "test_loader",
    "PGD_EPS", "PGD_ALPHA", "PGD_STEPS", "TRADES_BETA",
    "evaluate_int8_cpu"  # from Part 8 helpers
]
missing = [n for n in need_syms if n not in globals()]
assert not missing, f"Missing symbols from earlier parts: {missing}"

RUN_DIR = "./runs"; os.makedirs(RUN_DIR, exist_ok=True)
MODEL_TAG = "recce"


DO_TRAIN_PGD_AT   = True    
DO_TRAIN_TRADES   = True   
DO_AUTOATTACK     = False   
AUTOATTACK_EPS    = PGD_EPS 
JPEG_QUALITIES    = [90, 70]  


ROBUST_EPOCHS = 8


if DO_AUTOATTACK:
    get_ipython().system("pip -q install autoattack")
    from autoattack import AutoAttack
else:
    AutoAttack = None

def autoattack_acc(model, loader, eps=AUTOATTACK_EPS):
    """AutoAttack accuracy for FP32 models (GPU)."""
    assert AutoAttack is not None, "AutoAttack not enabled (DO_AUTOATTACK=False)"
    model = model.to(device).eval()
    adversary = AutoAttack(model, norm='Linf', eps=eps, version='standard', device=device)
    total = correct = 0
    for x,y in loader:
        x,y = x.to(device), y.to(device)
        x_adv = adversary.run_standard_evaluation(x, y, bs=x.size(0))
        with torch.no_grad():
            pred = model(x_adv).argmax(1)
        correct += (pred==y).sum().item()
        total   += y.numel()
    return correct/total

from PIL import Image
import io, torchvision.transforms.functional as TF
mean = (0.485, 0.456, 0.406); std = (0.229, 0.224, 0.225)

def jpeg_compress_tensor(x, quality=70):
    x_cpu = x.detach().cpu().clamp(0,1)
    out = []
    for i in range(x_cpu.size(0)):
        img = TF.to_pil_image(x_cpu[i])
        buf = io.BytesIO(); img.save(buf, format="JPEG", quality=quality); buf.seek(0)
        out.append(TF.to_tensor(Image.open(buf)))
    return torch.stack(out, dim=0).to(x.device)

def eval_jpeg_fp32(model, loader, quality=70):
    """JPEG robustness for FP32 models (on GPU)."""
    model = model.to(device).eval()
    from torchmetrics.classification import BinaryAUROC, BinaryAveragePrecision, BinaryF1Score
    auroc, auprc, f1 = BinaryAUROC().to(device), BinaryAveragePrecision().to(device), BinaryF1Score().to(device)
    correct=total=0
    mean_t = torch.tensor(mean, device=device).view(1,3,1,1)
    std_t  = torch.tensor(std,  device=device).view(1,3,1,1)
    with torch.no_grad():
        for x,y in loader:
            x,y = x.to(device), y.to(device)
            x_den = x*std_t + mean_t
            x_jpg = jpeg_compress_tensor(x_den, quality=quality)
            x_nrm = (x_jpg - mean_t)/std_t
            logits = model(x_nrm)
            prob = torch.softmax(logits, dim=1)[:,1]
            pred = logits.argmax(1)
            correct += (pred==y).sum().item(); total += y.numel()
            auroc.update(prob, y); auprc.update(prob, y); f1.update(pred, y)
    return {"acc":correct/total,"auroc":auroc.compute().item(),"auprc":auprc.compute().item(),"f1":f1.compute().item()}

def eval_jpeg_int8(model_cpu, loader, quality=70):
    """
    JPEG robustness for INT8 models (Dynamic/Static/QAT).
    Everything stays on CPU: model and tensors.
    """
    from torchmetrics.classification import BinaryAUROC, BinaryAveragePrecision, BinaryF1Score
    m = copy.deepcopy(model_cpu).to("cpu").eval()
    auroc = BinaryAUROC().to("cpu")
    auprc = BinaryAveragePrecision().to("cpu")
    f1    = BinaryF1Score().to("cpu")

    mean_t = torch.tensor(mean).view(1,3,1,1)
    std_t  = torch.tensor(std).view(1,3,1,1)

    correct = total = 0
    with torch.no_grad():
        for x,y in loader:
            # keep on CPU
            x_den = x*std_t + mean_t
            x_jpg = jpeg_compress_tensor(x_den, quality=quality)
            x_nrm = (x_jpg - mean_t)/std_t
            logits = m(x_nrm)
            prob = torch.softmax(logits, dim=1)[:,1]
            pred = logits.argmax(1)
            correct += (pred==y).sum().item(); total += y.numel()
            auroc.update(prob.cpu(), y.cpu())
            auprc.update(prob.cpu(), y.cpu())
            f1.update(pred.cpu(), y.cpu())
    return {
        "acc": float(correct/total),
        "auroc": float(auroc.compute().item()),
        "auprc": float(auprc.compute().item()),
        "f1": float(f1.compute().item())
    }

records = []

def eval_all_fp32(tag, m):
    """
    Evaluate a FP32 model:
      - Clean metrics
      - PGD-10 metrics
      - Optional AutoAttack
      - JPEG robustness at JPEG_QUALITIES
    """
    row = {"ModelTag": tag}
    met_clean = evaluate(m.to(device), test_loader)
    row.update({f"clean_{k}": float(v) for k,v in met_clean.items()})
    met_pgd = evaluate_under_pgd(m, test_loader, steps=10)
    row.update({f"pgd_{k}": float(v) for k,v in met_pgd.items()})
    if DO_AUTOATTACK:
        aa_acc = autoattack_acc(m, test_loader, eps=AUTOATTACK_EPS)
        row["aa_acc"] = float(aa_acc)
    for q in JPEG_QUALITIES:
        mj = eval_jpeg_fp32(m, test_loader, quality=q)
        for k,v in mj.items():
            row[f"jpeg{q}_{k}"] = float(v)
    return row

def eval_all_int8(tag, m_cpu):
    """
    Evaluate an INT8 model (Dynamic/Static/QAT) on:
      - Clean metrics (CPU, via evaluate_int8_cpu)
      - JPEG robustness
    We DO NOT run PGD or AutoAttack on quantized models (gradients & support are unreliable).
    """
    row = {"ModelTag": tag}
    met_clean = evaluate_int8_cpu(m_cpu, test_loader)
    row.update({f"clean_{k}": float(v) for k,v in met_clean.items()})
    for q in JPEG_QUALITIES:
        mj = eval_jpeg_int8(m_cpu, test_loader, quality=q)
        for k,v in mj.items():
            row[f"jpeg{q}_{k}"] = float(v)
    return row

baseline_fp32 = copy.deepcopy(model).to(device).eval()
records.append(eval_all_fp32(f"{MODEL_TAG}_fp32_standard", baseline_fp32))

if DO_TRAIN_PGD_AT:
    model_pgd = make_model(MODEL_TAG)
    model_pgd = train(model_pgd, regime="pgd_at", epochs=ROBUST_EPOCHS)
    torch.save(model_pgd.state_dict(), os.path.join(RUN_DIR, f"{MODEL_TAG}_pgdat_fp32.pt"))
    records.append(eval_all_fp32(f"{MODEL_TAG}_fp32_pgd_at", model_pgd.eval()))

if DO_TRAIN_TRADES:
    model_trd = make_model(MODEL_TAG)
    model_trd = train(model_trd, regime="trades", epochs=ROBUST_EPOCHS)
    torch.save(model_trd.state_dict(), os.path.join(RUN_DIR, f"{MODEL_TAG}_trades_fp32.pt"))
    records.append(eval_all_fp32(f"{MODEL_TAG}_fp32_trades", model_trd.eval()))

if "dyn" in globals():
    dyn_base = dyn
else:
    dyn_base = quantize_dynamic(copy.deepcopy(baseline_fp32).to('cpu').eval())
records.append(eval_all_int8(f"{MODEL_TAG}_int8dyn_standard", dyn_base))

if DO_TRAIN_PGD_AT:
    dyn_pgd = quantize_dynamic(copy.deepcopy(model_pgd).to('cpu').eval())
    records.append(eval_all_int8(f"{MODEL_TAG}_int8dyn_pgd_at", dyn_pgd))
if DO_TRAIN_TRADES:
    dyn_trd = quantize_dynamic(copy.deepcopy(model_trd).to('cpu').eval())
    records.append(eval_all_int8(f"{MODEL_TAG}_int8dyn_trades", dyn_trd))

if "st" in globals():
    records.append(eval_all_int8(f"{MODEL_TAG}_int8static_fx_standard", st))

if "qatm" in globals():
    records.append(eval_all_int8(f"{MODEL_TAG}_int8qat_fx_standard", qatm))

df = pd.DataFrame.from_records(records)

front_cols = [
    "clean_acc","clean_auroc","clean_auprc","clean_f1",
    "pgd_acc","pgd_auroc","pgd_auprc","pgd_f1"
]
aa_col = ["aa_acc"] if DO_AUTOATTACK else []
jpeg_cols = []
for q in JPEG_QUALITIES:
    jpeg_cols += [f"jpeg{q}_acc", f"jpeg{q}_auroc", f"jpeg{q}_auprc", f"jpeg{q}_f1"]

ordered = [c for c in front_cols+aa_col+jpeg_cols if c in df.columns] + \
          [c for c in df.columns if c not in (["ModelTag"]+front_cols+aa_col+jpeg_cols)]
df = df[["ModelTag"] + ordered]

fmt = {c:"{:.3f}" for c in df.columns if c != "ModelTag"}
display(df.style.format(fmt))

json_path = os.path.join(RUN_DIR, f"{MODEL_TAG}_robust_summary.json")
csv_path  = os.path.join(RUN_DIR, f"{MODEL_TAG}_robust_summary.csv")
with open(json_path, "w") as f:
    json.dump(df.to_dict(orient="records"), f, indent=2)
df.to_csv(csv_path, index=False)
print("Saved:", json_path, "|", csv_path)


PGD-10 eval:   0%|          | 0/341 [00:00<?, ?it/s]

  autocast_off = torch.cuda.amp.autocast(enabled=False)


Epoch 1/8 [pgd_at]:   0%|          | 0/4376 [00:00<?, ?it/s]

Ep 01 | train_loss=0.6946 | val_acc=0.498 auroc=0.611 auprc=0.591 f1=0.000


Epoch 2/8 [pgd_at]:   0%|          | 0/4376 [00:00<?, ?it/s]

Ep 02 | train_loss=0.6931 | val_acc=0.502 auroc=0.500 auprc=0.502 f1=0.668


Epoch 3/8 [pgd_at]:   0%|          | 0/4376 [00:00<?, ?it/s]

Ep 03 | train_loss=0.6932 | val_acc=0.502 auroc=0.500 auprc=0.502 f1=0.668


Epoch 4/8 [pgd_at]:   0%|          | 0/4376 [00:00<?, ?it/s]

Ep 04 | train_loss=0.6932 | val_acc=0.498 auroc=0.500 auprc=0.502 f1=0.000


Epoch 5/8 [pgd_at]:   0%|          | 0/4376 [00:00<?, ?it/s]

Ep 05 | train_loss=0.6932 | val_acc=0.498 auroc=0.500 auprc=0.502 f1=0.000


Epoch 6/8 [pgd_at]:   0%|          | 0/4376 [00:00<?, ?it/s]

Ep 06 | train_loss=0.6932 | val_acc=0.502 auroc=0.500 auprc=0.502 f1=0.668


Epoch 7/8 [pgd_at]:   0%|          | 0/4376 [00:00<?, ?it/s]

Ep 07 | train_loss=0.6932 | val_acc=0.502 auroc=0.500 auprc=0.502 f1=0.668


Epoch 8/8 [pgd_at]:   0%|          | 0/4376 [00:00<?, ?it/s]

Ep 08 | train_loss=0.6932 | val_acc=0.502 auroc=0.500 auprc=0.502 f1=0.668


PGD-10 eval:   0%|          | 0/341 [00:00<?, ?it/s]

Epoch 1/8 [trades]:   0%|          | 0/4376 [00:00<?, ?it/s]

  autocast_off = torch.cuda.amp.autocast(enabled=False)


Ep 01 | train_loss=0.8847 | val_acc=0.904 auroc=0.968 auprc=0.968 f1=0.907


Epoch 2/8 [trades]:   0%|          | 0/4376 [00:00<?, ?it/s]

Ep 02 | train_loss=1.0119 | val_acc=0.888 auroc=0.975 auprc=0.974 f1=0.880


Epoch 3/8 [trades]:   0%|          | 0/4376 [00:00<?, ?it/s]

Ep 03 | train_loss=1.0155 | val_acc=0.890 auroc=0.981 auprc=0.980 f1=0.881


Epoch 4/8 [trades]:   0%|          | 0/4376 [00:00<?, ?it/s]

Ep 04 | train_loss=0.9843 | val_acc=0.930 auroc=0.984 auprc=0.984 f1=0.929


Epoch 5/8 [trades]:   0%|          | 0/4376 [00:00<?, ?it/s]

Ep 05 | train_loss=0.9640 | val_acc=0.905 auroc=0.974 auprc=0.977 f1=0.899


Epoch 6/8 [trades]:   0%|          | 0/4376 [00:00<?, ?it/s]

Ep 06 | train_loss=0.9378 | val_acc=0.903 auroc=0.987 auprc=0.987 f1=0.895


Epoch 7/8 [trades]:   0%|          | 0/4376 [00:00<?, ?it/s]

Ep 07 | train_loss=0.9238 | val_acc=0.941 auroc=0.992 auprc=0.992 f1=0.940


Epoch 8/8 [trades]:   0%|          | 0/4376 [00:00<?, ?it/s]

Ep 08 | train_loss=0.9039 | val_acc=0.918 auroc=0.990 auprc=0.989 f1=0.913


PGD-10 eval:   0%|          | 0/341 [00:00<?, ?it/s]

For migrations of users: 
1. Eager mode quantization (torch.ao.quantization.quantize, torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode quantize_ API instead 
2. FX graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx,torch.ao.quantization.quantize_fx.convert_fx, please migrate to use torchao pt2e quantization API instead (prepare_pt2e, convert_pt2e) 
3. pt2e quantization has been migrated to torchao (https://github.com/pytorch/ao/tree/main/torchao/quantization/pt2e) 
see https://github.com/pytorch/ao/issues/2259 for more details
  return tq.quantize_dynamic(model_cpu, {nn.Linear}, dtype=torch.qint8)


Unnamed: 0,ModelTag,clean_acc,clean_auroc,clean_auprc,clean_f1,pgd_acc,pgd_auroc,pgd_auprc,pgd_f1,jpeg90_acc,jpeg90_auroc,jpeg90_auprc,jpeg90_f1,jpeg70_acc,jpeg70_auroc,jpeg70_auprc,jpeg70_f1
0,recce_fp32_standard,0.901,0.96,0.969,0.894,0.0,0.0,0.304,0.0,0.9,0.96,0.969,0.892,0.896,0.959,0.968,0.887
1,recce_fp32_pgd_at,0.504,0.565,0.558,0.0,0.504,0.079,0.318,0.0,0.504,0.566,0.558,0.0,0.504,0.565,0.558,0.0
2,recce_fp32_trades,0.917,0.986,0.986,0.911,0.31,0.267,0.359,0.205,0.917,0.986,0.986,0.911,0.916,0.987,0.986,0.91
3,recce_int8dyn_standard,0.901,0.96,0.969,0.894,,,,,0.9,0.96,0.969,0.893,0.895,0.959,0.968,0.887
4,recce_int8dyn_pgd_at,0.504,0.554,0.547,0.0,,,,,0.504,0.553,0.546,0.0,0.504,0.554,0.546,0.0
5,recce_int8dyn_trades,0.917,0.986,0.986,0.911,,,,,0.917,0.986,0.986,0.911,0.916,0.986,0.986,0.91
6,recce_int8static_fx_standard,0.899,0.96,0.968,0.891,,,,,0.897,0.959,0.968,0.888,0.891,0.959,0.967,0.88
7,recce_int8qat_fx_standard,0.888,0.967,0.974,0.875,,,,,0.887,0.968,0.974,0.875,0.878,0.968,0.974,0.862


Saved: ./runs\recce_robust_summary.json | ./runs\recce_robust_summary.csv


In [34]:
!pip install git+https://github.com/RobustBench/robustbench.git


Collecting git+https://github.com/RobustBench/robustbench.git
  Cloning https://github.com/RobustBench/robustbench.git to c:\users\admin\appdata\local\temp\pip-req-build-c1qg66uo
  Resolved https://github.com/RobustBench/robustbench.git to commit 78fcc9e48a07a861268f295a777b975f25155964
  Installing build dependencies: started
  Installing build dependencies: finished with status 'done'
  Getting requirements to build wheel: started
  Getting requirements to build wheel: finished with status 'done'
  Preparing metadata (pyproject.toml): started
  Preparing metadata (pyproject.toml): finished with status 'done'
Collecting torchdiffeq (from robustbench==1.1)
  Downloading torchdiffeq-0.2.5-py3-none-any.whl.metadata (440 bytes)
Collecting geotorch (from robustbench==1.1)
  Downloading geotorch-0.3.0-py3-none-any.whl.metadata (14 kB)
Collecting autoattack@ git+https://github.com/fra31/auto-attack.git@a39220048b3c9f2cca9a4d3a54604793c68eca7e#egg=autoattack (from robustbench==1.1)
  Cloning 

  Running command git clone --filter=blob:none --quiet https://github.com/RobustBench/robustbench.git 'C:\Users\admin\AppData\Local\Temp\pip-req-build-c1qg66uo'
  Running command git clone --filter=blob:none --quiet https://github.com/fra31/auto-attack.git 'C:\Users\admin\AppData\Local\Temp\pip-install-mi5mzuc3\autoattack_df12421caa13418db9cc632f0dab8f6a'
  Running command git rev-parse -q --verify 'sha^a39220048b3c9f2cca9a4d3a54604793c68eca7e'
  Running command git fetch -q https://github.com/fra31/auto-attack.git a39220048b3c9f2cca9a4d3a54604793c68eca7e


In [None]:
import copy, json, torch, pandas as pd
from tqdm.auto import tqdm

CKPT_FP32 = "./runs/recce_standard_fp32_final.pt"
MODEL_TAG_BASELINE = "recce_fp32_standard"   

base_model = make_model("recce")
base_model.load_state_dict(torch.load(CKPT_FP32, map_location=device))
base_model = base_model.to(device).eval()

mean_t = torch.tensor(mean, device=device).view(1,3,1,1)
std_t  = torch.tensor(std,  device=device).view(1,3,1,1)

class NormalizedWrapper(torch.nn.Module):
    def __init__(self, model, mean_t, std_t):
        super().__init__()
        self.model = model
        self.mean_t = mean_t
        self.std_t = std_t

    def forward(self, x_pixel):          # x_pixel in [0,1]
        x_norm = (x_pixel - self.mean_t) / self.std_t
        return self.model(x_norm)

aa_model = NormalizedWrapper(base_model, mean_t, std_t).to(device).eval()

from math import inf
MAX_SAMPLES = 2000

x_list, y_list = [], []
collected = 0
for x_norm, y in tqdm(test_loader, desc="Collecting subset for AutoAttack"):
    x_den = x_norm * torch.tensor(std).view(1,3,1,1) + torch.tensor(mean).view(1,3,1,1)
    x_den = x_den.clamp(0, 1)
    x_list.append(x_den)
    y_list.append(y)
    collected += x_norm.size(0)
    if collected >= MAX_SAMPLES:
        break

x_test = torch.cat(x_list, dim=0)[:MAX_SAMPLES].to(device)
y_test = torch.cat(y_list, dim=0)[:MAX_SAMPLES].to(device)
print(f"AutoAttack will run on {x_test.shape[0]} test images (subset).")

try:
    from autoattack import AutoAttack
except ImportError:
    import subprocess, sys
    subprocess.run([sys.executable, "-m", "pip", "install", "autoattack"], check=True)
    from autoattack import AutoAttack

adversary = AutoAttack(
    aa_model,
    norm="Linf",
    eps=float(PGD_EPS),
    version="standard",
    device=device,
    log_path=None
)

x_adv = adversary.run_standard_evaluation(x_test, y_test, bs=BATCH)

with torch.no_grad():
    logits_adv = aa_model(x_adv)
    pred_adv   = logits_adv.argmax(dim=1)
    aa_acc     = (pred_adv == y_test).float().mean().item()

print(f"[AutoAttack] subset robust accuracy for {MODEL_TAG_BASELINE}: {aa_acc:.4f}")

csv_path  = "./runs/recce_robust_summary.csv"
json_path = "./runs/recce_robust_summary.json"

df = pd.read_csv(csv_path)

if "aa_acc" not in df.columns:
    df["aa_acc"] = float("nan")

mask = df["ModelTag"] == MODEL_TAG_BASELINE
if not mask.any():
    print(f"WARNING: ModelTag '{MODEL_TAG_BASELINE}' not found in {csv_path}")
else:
    df.loc[mask, "aa_acc"] = aa_acc
    print(f"Updated aa_acc for row '{MODEL_TAG_BASELINE}'")

df.to_csv(csv_path, index=False)
with open(json_path, "w") as f:
    json.dump(df.to_dict(orient="records"), f, indent=2)

print("Saved updated robust summary with AutoAttack to:")
print("  CSV :", csv_path)
print("  JSON:", json_path)


Collecting subset for AutoAttack:   0%|          | 0/341 [00:15<?, ?it/s]

AutoAttack will run on 2000 test images (subset).
setting parameters for standard version
using standard version including apgd-ce, apgd-t, fab-t, square.
initial accuracy: 96.20%
apgd-ce - 1/61 - 32 out of 32 successfully perturbed
apgd-ce - 2/61 - 32 out of 32 successfully perturbed
apgd-ce - 3/61 - 32 out of 32 successfully perturbed
apgd-ce - 4/61 - 32 out of 32 successfully perturbed
apgd-ce - 5/61 - 32 out of 32 successfully perturbed
apgd-ce - 6/61 - 32 out of 32 successfully perturbed
apgd-ce - 7/61 - 32 out of 32 successfully perturbed
apgd-ce - 8/61 - 32 out of 32 successfully perturbed
apgd-ce - 9/61 - 32 out of 32 successfully perturbed
apgd-ce - 10/61 - 32 out of 32 successfully perturbed
apgd-ce - 11/61 - 32 out of 32 successfully perturbed
apgd-ce - 12/61 - 32 out of 32 successfully perturbed
apgd-ce - 13/61 - 32 out of 32 successfully perturbed
apgd-ce - 14/61 - 32 out of 32 successfully perturbed
apgd-ce - 15/61 - 32 out of 32 successfully perturbed
apgd-ce - 16/61 - 3

In [None]:
import pandas as pd
import json
import numpy as np

CSV_PATH = "./runs/recce_robust_summary.csv"
JSON_PATH = "./runs/recce_robust_summary.json"

df = pd.read_csv(CSV_PATH)

print("==============================================")
print(" FULL RECCE ROBUSTNESS + QUANTIZATION SUMMARY")
print("==============================================\n")

for col in df.columns:
    if col != "ModelTag":
        df[col] = pd.to_numeric(df[col], errors="ignore")

num_cols = df.select_dtypes(include=[np.number]).columns
fmt_dict = {c: "{:.4f}" for c in num_cols}

display(df.style.format(fmt_dict))

print("\n==============================================")
print(" DETAILED PRINTED RESULTS (PER MODEL)")
print("==============================================\n")

for idx, row in df.iterrows():
    tag = row["ModelTag"]
    print(f"ModelTag: {tag}")
    print("-" * (10 + len(str(tag))))

    def safe_get(name):
        v = row.get(name, np.nan)
        try:
            return float(v)
        except Exception:
            return np.nan

    # Clean metrics
    print(f" Clean Accuracy     : {safe_get('clean_acc'):.4f}")
    print(f" Clean AUROC        : {safe_get('clean_auroc'):.4f}")
    print(f" Clean AUPRC        : {safe_get('clean_auprc'):.4f}")
    print(f" Clean F1           : {safe_get('clean_f1'):.4f}")

    if "pgd_acc" in df.columns:
        print(f" PGD Accuracy       : {safe_get('pgd_acc'):.4f}")
        print(f" PGD AUROC          : {safe_get('pgd_auroc'):.4f}")
        print(f" PGD AUPRC          : {safe_get('pgd_auprc'):.4f}")
        print(f" PGD F1             : {safe_get('pgd_f1'):.4f}")

    if "aa_acc" in df.columns:
        aa = safe_get("aa_acc")
        if not np.isnan(aa):
            print(f" AutoAttack Robust Acc : {aa:.4f}")

    for q in [90, 70]:
        acc_col = f"jpeg{q}_acc"
        if acc_col in df.columns:
            print(f" JPEG{q} Acc        : {safe_get(acc_col):.4f}")
            print(f" JPEG{q} AUROC      : {safe_get(f'jpeg{q}_auroc'):.4f}")
            print(f" JPEG{q} AUPRC      : {safe_get(f'jpeg{q}_auprc'):.4f}")
            print(f" JPEG{q} F1         : {safe_get(f'jpeg{q}_f1'):.4f}")

    print("\n----------------------------------------------\n")

with open(JSON_PATH, "r") as f:
    json_data = json.load(f)

print("==============================================")
print(" RAW JSON SUMMARY (for reference)")
print("==============================================\n")
print(json.dumps(json_data, indent=2))


 FULL RECCE ROBUSTNESS + QUANTIZATION SUMMARY



  df[col] = pd.to_numeric(df[col], errors="ignore")


Unnamed: 0,ModelTag,clean_acc,clean_auroc,clean_auprc,clean_f1,pgd_acc,pgd_auroc,pgd_auprc,pgd_f1,jpeg90_acc,jpeg90_auroc,jpeg90_auprc,jpeg90_f1,jpeg70_acc,jpeg70_auroc,jpeg70_auprc,jpeg70_f1,aa_acc
0,recce_fp32_standard,0.9009,0.9597,0.9691,0.894,0.0,0.0,0.3041,0.0,0.8996,0.9597,0.969,0.8924,0.8957,0.959,0.9682,0.8871,0.0
1,recce_fp32_pgd_at,0.5036,0.5653,0.5578,0.0,0.5036,0.0794,0.3184,0.0,0.5036,0.5656,0.5579,0.0,0.5036,0.5654,0.5579,0.0,
2,recce_fp32_trades,0.9169,0.9864,0.9859,0.9111,0.3099,0.267,0.3585,0.2051,0.9169,0.9863,0.9859,0.9111,0.9158,0.9865,0.986,0.9098,
3,recce_int8dyn_standard,0.9007,0.9597,0.969,0.8938,,,,,0.8998,0.9596,0.9689,0.8926,0.8954,0.9589,0.9682,0.8866,
4,recce_int8dyn_pgd_at,0.5036,0.554,0.5466,0.0,,,,,0.5036,0.5531,0.5461,0.0,0.5036,0.5538,0.5463,0.0,
5,recce_int8dyn_trades,0.9168,0.9863,0.9859,0.911,,,,,0.9168,0.9863,0.9858,0.911,0.916,0.9865,0.986,0.91,
6,recce_int8static_fx_standard,0.8987,0.9595,0.9684,0.8906,,,,,0.8969,0.9594,0.9682,0.8884,0.8906,0.9586,0.9674,0.8801,
7,recce_int8qat_fx_standard,0.8877,0.9675,0.9739,0.8751,,,,,0.8872,0.9678,0.974,0.8746,0.8777,0.9682,0.9737,0.8621,



 DETAILED PRINTED RESULTS (PER MODEL)

ModelTag: recce_fp32_standard
-----------------------------
 Clean Accuracy     : 0.9009
 Clean AUROC        : 0.9597
 Clean AUPRC        : 0.9691
 Clean F1           : 0.8940
 PGD Accuracy       : 0.0000
 PGD AUROC          : 0.0000
 PGD AUPRC          : 0.3041
 PGD F1             : 0.0000
 AutoAttack Robust Acc : 0.0000
 JPEG90 Acc        : 0.8996
 JPEG90 AUROC      : 0.9597
 JPEG90 AUPRC      : 0.9690
 JPEG90 F1         : 0.8924
 JPEG70 Acc        : 0.8957
 JPEG70 AUROC      : 0.9590
 JPEG70 AUPRC      : 0.9682
 JPEG70 F1         : 0.8871

----------------------------------------------

ModelTag: recce_fp32_pgd_at
---------------------------
 Clean Accuracy     : 0.5036
 Clean AUROC        : 0.5653
 Clean AUPRC        : 0.5578
 Clean F1           : 0.0000
 PGD Accuracy       : 0.5036
 PGD AUROC          : 0.0794
 PGD AUPRC          : 0.3184
 PGD F1             : 0.0000
 JPEG90 Acc        : 0.5036
 JPEG90 AUROC      : 0.5656
 JPEG90 AUPRC      

: 