In [None]:
import sys
import os
import json
import secrets
import numpy as np

# The project root is taken to be the parent of the current working directory.
project_dir = os.path.abspath(os.path.join(os.getcwd(), '../'))
src_dir, data_dir, fig_dir, logs_dir = (
    os.path.join(project_dir, sub_folder)
    for sub_folder in ('src', 'data', 'fig', 'logs')
)

# Output folders are created up front (src/ is not created here).
for _out_dir in (fig_dir, data_dir, logs_dir):
    os.makedirs(_out_dir, exist_ok=True)

# Make the project-local modules under src/ importable.
sys.path.append(src_dir)

import torch
import torch.nn as nn
from opacus import PrivacyEngine
import torch.optim as optim
from torch.utils.data import Subset, DataLoader
import torchvision.transforms as transforms

from utils import find_latest_checkpoint, load_checkpoint
from dataset import get_auditable_data_loaders, generate_poisoned_canaries_and_mask
from network_arch import WideResNet
from inference import compute_audit_score_blackbox, compute_loss
from auditing import CanaryScoreAuditor

In [2]:
# ==========================================
# Hyperparameters (Settings from the paper "Unlocking High-Accuracy Differentially Private Image Classification through Scale")
# These must match the values used when the audited model was trained.
# ==========================================
LOGICAL_BATCH_SIZE = 4096     # Target logical batch size (paper); used as the training DataLoader batch size
MAX_PHYSICAL_BATCH_SIZE = 128  # GPU limit (128 samples * 16 augmentations = 2048 effective images)
AUG_MULTIPLICITY = 16         # K=16 augmentations per sample
MAX_GRAD_NORM = 1.0           # Per-sample gradient clipping bound handed to Opacus (max_grad_norm)
EPSILON = 8.0                 # Target privacy budget. NOTE(review): not referenced in the visible cells - confirm
DELTA = 1e-5                  # delta used for the accountant's epsilon and for the one-run audit bound
EPOCHS = 140                   # Number of training epochs. NOTE(review): recorded output lists checkpoints up to epoch 180 - confirm
LR = 4.0                      # High LR for large batch
MOMENTUM = 0.0                # No momentum
NOISE_MULTIPLIER = 3.0        # DP-SGD noise multiplier (sigma); author's note: ~3.0 is optimal for BS=4096
CKPT_INTERVAL = 20            # Save checkpoint every 20 epochs


# ==========================================
# Experiment Parameters
# ==========================================
CANARY_COUNT = 5000           # Number of canaries
PKEEP = 0.5                   # Probability of including each canary in the training set
# Seed of the canary database. It is embedded in the experiment directory name
# and reduced with `%` to seed torch/numpy, so in this notebook it must be the
# integer seed of an existing training run (None would fail in both places).
DATABSEED = 53841938803364779163249839521218793645  # seed of the trained experiment to audit
     

In [3]:
# All artifacts of one experiment live in a folder whose name encodes the
# canary seed, the canary count and the keep-probability.
exp_name = f"mislabeled-canaries-{DATABSEED}-{CANARY_COUNT}-{PKEEP}-cifar10"
exp_dir = os.path.join(data_dir, exp_name)

# Auditing needs a previously trained model, so fail early when the folder is missing.
assert os.path.exists(exp_dir), f"Experiment directory {exp_dir} does not exist. You need to train an auditable DP-SGD model first. See train_auditable_DP_model_blackbox.ipynb as an example."

ckpt_dir, logits_dir, scores_dir = (
    os.path.join(exp_dir, sub_folder) for sub_folder in ("ckpt", "logits", "scores")
)

# Only the output folders are created here; ckpt/ is read from, not created.
for _out_dir in (logits_dir, scores_dir):
    os.makedirs(_out_dir, exist_ok=True)

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Run experiment on device: {device}")

# Rebuild the canary set and its in/out mask from the same seed and settings.
print("Loading canary database and mask...")
canary_settings = {
    "data_dir": data_dir,
    "canary_count": CANARY_COUNT,
    "seed": DATABSEED,
    "pkeep": PKEEP,
}
poisoned_canaries, inclusion_mask = generate_poisoned_canaries_and_mask(**canary_settings)

print(f"Loaded {len(poisoned_canaries)} total canaries")
n_in_canaries = np.sum(inclusion_mask)
print(f"  - In-canaries (included in training): {n_in_canaries}")
n_out_canaries = np.sum(~inclusion_mask)
print(f"  - Out-canaries (excluded from training): {n_out_canaries}")

# Fixed per-channel normalization, intended to match training (where it is
# applied after augmentation).
channel_mean = (0.4914, 0.4822, 0.4465)
channel_std = (0.2023, 0.1994, 0.2010)
normalize_transform = transforms.Normalize(channel_mean, channel_std)

class NormalizeWrapper(torch.utils.data.Dataset):
    """Dataset view that applies ``transform`` to the image of every sample.

    Each item of the wrapped dataset is expected to unpack into an
    ``(image, label)`` pair. The label is passed through unchanged and the
    wrapped dataset itself is never modified.
    """

    def __init__(self, dataset, transform):
        self.dataset, self.transform = dataset, transform

    def __len__(self):
        # Same length as the wrapped dataset: samples are neither added nor dropped.
        return len(self.dataset)

    def __getitem__(self, idx):
        sample = self.dataset[idx]
        image, label = sample
        transformed = self.transform(image)
        return transformed, label

# Canary dataset with the fixed normalization applied on access (no augmentation).
# NOTE(review): the augmentation-based scoring functions in this notebook
# normalize internally and are called with the raw `poisoned_canaries`; this
# wrapper is not referenced in the visible cells - confirm it is still needed.
normalized_canaries = NormalizeWrapper(poisoned_canaries, normalize_transform)

Run experiment on device: cuda
Loading canary database and mask...
Loaded 5000 total canaries
  - In-canaries (included in training): 2488
  - Out-canaries (excluded from training): 2512


  entry = pickle.load(f, encoding="latin1")


In [4]:
print("Loading data...")
# Same canary settings as the canary database. Despite its name, the helper's
# results are used as datasets (one of them is wrapped in a DataLoader below).
auditable_settings = {
    "data_dir": data_dir,
    "canary_count": CANARY_COUNT,
    "seed": DATABSEED,
    "pkeep": PKEEP,
}
train_dataset, test_dataset = get_auditable_data_loaders(**auditable_settings)

# Shuffled logical-batch loader; an incomplete final batch is dropped.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=LOGICAL_BATCH_SIZE,
    shuffle=True,
    num_workers=4,
    drop_last=True,
)

Loading data...


In [5]:
# Fold the (very large) database seed into the 32-bit range accepted by
# np.random.seed, and use that same value for torch.
_SEED_MODULUS = 2**32 - 1
torch_seed = np_seed = int(DATABSEED % _SEED_MODULUS)

# Seed every RNG before the model is constructed.
torch.manual_seed(torch_seed)
if torch.cuda.is_available():
    for _seed_cuda in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        _seed_cuda(torch_seed)
np.random.seed(np_seed)

# Untrained reference network, built once and reused when scoring every checkpoint.
# NOTE(review): presumably meant to reproduce the weights training started
# from - confirm the training notebook seeds identically before building its model.
initial_model = WideResNet(depth=16, widen_factor=4).to(device)

In [6]:
def _make_aug_transforms():
    """Build the transforms used when scoring canaries with augmentation.

    Returns:
        tuple: ``(augment_transform, normalize_transform)``. The first is a
        random crop to 32x32 after padding by 4, followed by a random
        horizontal flip. The second is the fixed per-channel normalization.
    """
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2023, 0.1994, 0.2010)
    random_ops = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    ]
    return transforms.Compose(random_ops), transforms.Normalize(channel_mean, channel_std)


def compute_loss_with_aug(model, canaries, device, batch_size=128, aug_multiplicity=16):
    """Compute the per-sample loss averaged over K random augmentations.

    Args:
        model: Network to evaluate. It is switched to eval mode and left there.
        canaries: Either a ``torch.utils.data.Dataset`` (wrapped here in a
            non-shuffling ``DataLoader``) or any iterable that yields
            ``(images, targets)`` batches. Images are normalized here, after
            augmentation, so they should not be normalized already.
        device: Device the batches are moved to before the forward pass.
        batch_size: Original samples per batch when ``canaries`` is a Dataset.
            Each forward pass sees ``batch_size * aug_multiplicity`` images.
        aug_multiplicity: K, the number of augmented copies per sample (>= 1).

    Returns:
        np.ndarray: 1-D array holding one mean cross-entropy loss per sample,
        in iteration order. Empty when ``canaries`` yields no batches.

    Raises:
        ValueError: If ``aug_multiplicity`` is smaller than 1.
    """
    if aug_multiplicity < 1:
        raise ValueError(f"aug_multiplicity must be >= 1, got {aug_multiplicity}")

    model.eval()
    criterion = nn.CrossEntropyLoss(reduction='none')

    if isinstance(canaries, torch.utils.data.Dataset):
        canary_loader = DataLoader(canaries, batch_size=batch_size, shuffle=False, num_workers=0)
    else:
        canary_loader = canaries

    augment_transform, normalize_transform = _make_aug_transforms()

    all_losses = []
    with torch.no_grad():
        for images, targets in canary_loader:
            images = images.to(device)
            targets = targets.to(device)

            # NOTE(review): the transform is applied to the whole batch tensor,
            # so each of the K draws presumably shares one crop/flip across all
            # samples in the batch - confirm this is the intended behaviour.
            aug_views = [augment_transform(images) for _ in range(aug_multiplicity)]

            # Stack to (B, K, ...) and merge the first two axes so the K views
            # of one sample are contiguous. Using the tensor's own trailing
            # dimensions avoids hard-coding 3x32x32.
            aug_images = torch.stack(aug_views, dim=1).flatten(0, 1)
            aug_images = normalize_transform(aug_images)

            # Repeat each label K times to line up with the sample-major layout.
            aug_targets = targets.repeat_interleave(aug_multiplicity)
            losses = criterion(model(aug_images), aug_targets)

            # Average loss across augmentations per original sample
            losses = losses.view(-1, aug_multiplicity).mean(dim=1)
            all_losses.append(losses.cpu().numpy())

    if not all_losses:
        # np.concatenate rejects an empty list; an empty input has no losses.
        return np.empty(0, dtype=np.float32)
    return np.concatenate(all_losses)


def compute_audit_score_with_aug(
    initial_model,
    final_model,
    canaries,
    device,
    batch_size=128,
    aug_multiplicity=16,
):
    """Compute ``loss_init - loss_final`` per sample, averaged over K augmentations.

    Both models are evaluated on the SAME augmented batch, so the difference
    is not affected by augmentation randomness between the two forward passes.

    Args:
        initial_model: Reference network. It is switched to eval mode and left there.
        final_model: Trained network. It is switched to eval mode and left there.
        canaries: Either a ``torch.utils.data.Dataset`` (wrapped here in a
            non-shuffling ``DataLoader``) or any iterable that yields
            ``(images, targets)`` batches. Images are normalized here, after
            augmentation, so they should not be normalized already.
        device: Device the batches are moved to before the forward passes.
        batch_size: Original samples per batch when ``canaries`` is a Dataset.
            Each forward pass sees ``batch_size * aug_multiplicity`` images.
        aug_multiplicity: K, the number of augmented copies per sample (>= 1).

    Returns:
        np.ndarray: 1-D array holding one score per sample, in iteration
        order. Empty when ``canaries`` yields no batches.

    Raises:
        ValueError: If ``aug_multiplicity`` is smaller than 1.
    """
    if aug_multiplicity < 1:
        raise ValueError(f"aug_multiplicity must be >= 1, got {aug_multiplicity}")

    initial_model.eval()
    final_model.eval()
    criterion = nn.CrossEntropyLoss(reduction='none')

    if isinstance(canaries, torch.utils.data.Dataset):
        canary_loader = DataLoader(canaries, batch_size=batch_size, shuffle=False, num_workers=0)
    else:
        canary_loader = canaries

    augment_transform, normalize_transform = _make_aug_transforms()

    all_scores = []
    with torch.no_grad():
        for images, targets in canary_loader:
            images = images.to(device)
            targets = targets.to(device)

            # NOTE(review): the transform is applied to the whole batch tensor,
            # so each of the K draws presumably shares one crop/flip across all
            # samples in the batch - confirm this is the intended behaviour.
            aug_views = [augment_transform(images) for _ in range(aug_multiplicity)]

            # Stack to (B, K, ...) and merge the first two axes so the K views
            # of one sample are contiguous. Using the tensor's own trailing
            # dimensions avoids hard-coding 3x32x32.
            aug_images = torch.stack(aug_views, dim=1).flatten(0, 1)
            aug_images = normalize_transform(aug_images)

            # Repeat each label K times to line up with the sample-major layout.
            aug_targets = targets.repeat_interleave(aug_multiplicity)
            losses_init = criterion(initial_model(aug_images), aug_targets)
            losses_final = criterion(final_model(aug_images), aug_targets)

            # Per-view loss drop, then the mean over the K views of each sample.
            scores = (losses_init - losses_final).view(-1, aug_multiplicity).mean(dim=1)
            all_scores.append(scores.cpu().numpy())

    if not all_scores:
        # np.concatenate rejects an empty list; an empty input has no scores.
        return np.empty(0, dtype=np.float32)
    return np.concatenate(all_scores)


# Find all checkpoints and process them
print("Finding all checkpoints...")
checkpoint_files = [f for f in os.listdir(ckpt_dir) if f.endswith('.npz')]
checkpoint_files.sort()  # Lexicographic sort; relies on the zero-padded file names

if not checkpoint_files:
    print("No checkpoints found!")
else:
    print(f"Found {len(checkpoint_files)} checkpoints: {checkpoint_files}")
    # Process each checkpoint
    for ckpt_file in checkpoint_files:
        ckpt_path = os.path.join(ckpt_dir, ckpt_file)
        ckpt_name = ckpt_file.replace('.npz', '')  # e.g., "0000000020"

        print(f"\nProcessing checkpoint: {ckpt_file}")

        # Create a fresh model for this checkpoint
        model = WideResNet(depth=16, widen_factor=4).to(device)
        optimizer = optim.SGD(model.parameters(), lr=LR, momentum=MOMENTUM)

        # Setup privacy engine. The wrapped loader gets its own name on
        # purpose: rebinding `train_loader` here would hand an already wrapped
        # loader back to make_private on the next iteration and leave the
        # global loader modified after this cell has run.
        privacy_engine = PrivacyEngine()
        model, optimizer, private_train_loader = privacy_engine.make_private(
            module=model,
            optimizer=optimizer,
            data_loader=train_loader,
            noise_multiplier=NOISE_MULTIPLIER,
            max_grad_norm=MAX_GRAD_NORM,
        )

        # Load checkpoint
        loaded_epoch, loaded_global_step = load_checkpoint(ckpt_path, model, optimizer, device)
        print(f"  Loaded epoch {loaded_epoch}, global step {loaded_global_step}")

        # We must tell the accountant we have already taken N steps.
        steps_per_epoch = len(train_loader)
        sample_rate = 1 / steps_per_epoch

        # Resolve a missing step count BEFORE it reaches the accountant, so a
        # None never ends up in its history.
        total_steps = loaded_global_step if loaded_global_step is not None else 0
        if total_steps > 0:
            # This line forces the accountant to remember the past
            privacy_engine.accountant.history.append((NOISE_MULTIPLIER, sample_rate, total_steps))
            current_eps = privacy_engine.get_epsilon(DELTA)
        else:
            # No recorded training steps means no privacy budget has been spent.
            current_eps = 0.0
        print(f"Current Cumulative Epsilon: {current_eps:.2f} with delta={DELTA}")

        # Compute losses for all canaries (in original order)
        print("  Computing losses for all canaries...")
        all_losses = compute_loss_with_aug(
            model,
            poisoned_canaries,
            device,
            batch_size=128,
            aug_multiplicity=AUG_MULTIPLICITY,
        )

        # Compute scores for all canaries (in original order)
        print("  Computing scores for all canaries...")
        all_scores = compute_audit_score_with_aug(
            initial_model=initial_model,
            final_model=model,
            canaries=poisoned_canaries,
            device=device,
            batch_size=128,
            aug_multiplicity=AUG_MULTIPLICITY,
        )

        # One-run audit epsilon lower bound (score = loss_init - loss_final)
        in_scores = all_scores[inclusion_mask]
        out_scores = all_scores[~inclusion_mask]
        auditor = CanaryScoreAuditor(in_scores, out_scores)
        eps_lb = auditor.epsilon_one_run(significance=0.05, delta=DELTA)
        auroc = auditor.attack_auroc()

        # Alternative score: negative final loss only
        loss_scores = -all_losses
        in_loss_scores = loss_scores[inclusion_mask]
        out_loss_scores = loss_scores[~inclusion_mask]
        auditor_loss = CanaryScoreAuditor(in_loss_scores, out_loss_scores)
        eps_lb_loss = auditor_loss.epsilon_one_run(significance=0.05, delta=DELTA)
        auroc_loss = auditor_loss.attack_auroc()

        print(f"  Score means: in={in_scores.mean():.4f}, out={out_scores.mean():.4f}")
        print(f"  AUROC (init-final): {auroc:.4f}")
        print(f"  One-run eps_lb (init-final, alpha=0.05): {eps_lb:.4f} vs true eps={current_eps:.2f}")
        print(f"  Loss-score means: in={in_loss_scores.mean():.4f}, out={out_loss_scores.mean():.4f}")
        print(f"  AUROC (neg final loss): {auroc_loss:.4f}")
        print(f"  One-run eps_lb (neg final loss, alpha=0.05): {eps_lb_loss:.4f}")


print("\n✅ Finished processing all checkpoints!")



Finding all checkpoints...
Found 9 checkpoints: ['0000000020.npz', '0000000040.npz', '0000000060.npz', '0000000080.npz', '0000000100.npz', '0000000120.npz', '0000000140.npz', '0000000160.npz', '0000000180.npz']

Processing checkpoint: 0000000020.npz
Loaded checkpoint from epoch 20, global step 220
  Loaded epoch 20, global step 220
Current Cumulative Epsilon: 1.91 with delta=1e-05
  Computing losses for all canaries...
  Computing scores for all canaries...
  Score means: in=-2.6039, out=-2.7647
  AUROC (init-final): 0.5230
  One-run eps_lb (init-final, alpha=0.05): 0.0000 vs true eps=1.91
  Loss-score means: in=-4.9238, out=-5.0943
  AUROC (neg final loss): 0.5239
  One-run eps_lb (neg final loss, alpha=0.05): 0.0000

Processing checkpoint: 0000000040.npz
Loaded checkpoint from epoch 40, global step 440
  Loaded epoch 40, global step 440
Current Cumulative Epsilon: 2.78 with delta=1e-05
  Computing losses for all canaries...
  Computing scores for all canaries...
  Score means: in=-2.