In [1]:
import sys
import os
import json
import secrets
import numpy as np

# Navigate to the parent directory of the project structure
project_dir = os.path.abspath(os.path.join(os.getcwd(), '../../'))
src_dir = os.path.join(project_dir, 'src')
data_dir = os.path.join(project_dir, 'data')
fig_dir = os.path.join(project_dir, 'fig')
logs_dir = os.path.join(project_dir, 'logs')
os.makedirs(fig_dir, exist_ok=True)
os.makedirs(data_dir, exist_ok=True)
os.makedirs(logs_dir, exist_ok=True)

# Add the src directory to sys.path
sys.path.append(src_dir)

import torch
from opacus import PrivacyEngine
import torch.optim as optim

from utils import setup_logging, save_checkpoint, find_latest_checkpoint, load_checkpoint
from network_arch import WideResNet
from train import test, train_whitebox
from classifier.white_box_dp_sgd import sample_gaussian


In [None]:
# ==========================================
# Hyperparameters (Settings from the paper "Unlocking High-Accuracy Differentially Private Image Classification through Scale")
# ==========================================
LOGICAL_BATCH_SIZE = 4096     # Target batch size (Paper)
MAX_PHYSICAL_BATCH_SIZE = 128  # GPU limit (128 * 16 = 512 effective images)
AUG_MULTIPLICITY = 1         # K=16 augmentations
MAX_GRAD_NORM = 1.0
EPSILON = 8.0
DELTA = 1e-5
EPOCHS = 20                   # Increase to 100+ for best results
LR = 4.0                      # High LR for large batch
MOMENTUM = 0.0                # No momentum
NOISE_MULTIPLIER = 3.0        # Sigma ~ 3.0 is optimal for BS=4096
CKPT_INTERVAL = 10            # Save checkpoint every 10 epochs


# ==========================================
# Experiment Parameters
# ==========================================
CANARY_COUNT = 10000           # Number of canaries
PKEEP = 0.5                   # Probability of including each canary in the training set
# DATABSEED = 53841938803364779163249839521218793645  # if seed is set to None then seed is random
DATABSEED = 27198899012190525004019618245709479116  # if seed is set to None then seed is random

In [3]:
logger, log_file = setup_logging(log_dir=logs_dir)
logdir_path = os.path.dirname(log_file) 

# Create experiment directory
if DATABSEED is not None:
    exp_dir = os.path.join(data_dir, f"mislabeled-canaries-{DATABSEED}-{CANARY_COUNT}-{PKEEP}-cifar10")
else:
    DATABSEED = secrets.randbits(128)
    logger.info(f"Generated random 128-bit seed: {DATABSEED}")
    exp_dir = os.path.join(data_dir, f"mislabeled-canaries-{DATABSEED}-{CANARY_COUNT}-{PKEEP}-cifar10")

os.makedirs(exp_dir, exist_ok=True)
logger.info(f"Experiment directory: {exp_dir}")

# Create checkpoint directory under experiment directory
ckpt_dir = os.path.join(exp_dir, "ckpt")
os.makedirs(ckpt_dir, exist_ok=True)
logger.info(f"Checkpoint directory: {ckpt_dir}")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Run experiment on device: {device}")

# Set random seeds for reproducibility
# Cast 128-bit seed to 64-bit for PyTorch compatibility
torch_seed = int(DATABSEED % (2**32 - 1))
np_seed = int(DATABSEED % (2**32 - 1))
torch.manual_seed(torch_seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(torch_seed)
    torch.cuda.manual_seed_all(torch_seed)
np.random.seed(np_seed)
logger.info(f"Set random seeds (torch, numpy) to: {torch_seed} (from DATABSEED: {DATABSEED})")
rng = np.random.default_rng(np_seed)

# Store hyperparameters in a dictionary
params = {
    'logical_batch_size': LOGICAL_BATCH_SIZE,
    'max_physical_batch_size': MAX_PHYSICAL_BATCH_SIZE,
    'aug_multiplicity': AUG_MULTIPLICITY,
    'max_grad_norm': MAX_GRAD_NORM,
    'epsilon': EPSILON,
    'delta': DELTA,
    'epochs': EPOCHS,
    'lr': LR,
    'momentum': MOMENTUM,
    'noise_multiplier': NOISE_MULTIPLIER,
    'ckpt_interval': CKPT_INTERVAL,
    'canary_count': CANARY_COUNT,
    'pkeep': PKEEP,
    'database_seed': DATABSEED
}

# Save hyperparameters to experiment directory
hparams_path = os.path.join(exp_dir, 'hparams.json')
with open(hparams_path, 'w') as f:
    json.dump(params, f, indent=2)
logger.info(f"Hyperparameters saved to: {hparams_path}")


2026-02-02 16:34:19 - INFO - Logging initialized. Log file: /storage/coda1/p-vzikas3/0/ywei368/Yu-Project/Auditing/dpsgd-auditbench/logs/train_20260202_163419.log
2026-02-02 16:34:19 - INFO - Experiment directory: /storage/coda1/p-vzikas3/0/ywei368/Yu-Project/Auditing/dpsgd-auditbench/data/mislabeled-canaries-27198899012190525004019618245709479116-10000-0.5-cifar10
2026-02-02 16:34:19 - INFO - Checkpoint directory: /storage/coda1/p-vzikas3/0/ywei368/Yu-Project/Auditing/dpsgd-auditbench/data/mislabeled-canaries-27198899012190525004019618245709479116-10000-0.5-cifar10/ckpt
2026-02-02 16:34:19 - INFO - Run experiment on device: cuda
2026-02-02 16:34:19 - INFO - Set random seeds (torch, numpy) to: 2413360701 (from DATABSEED: 27198899012190525004019618245709479116)
2026-02-02 16:34:19 - INFO - Hyperparameters saved to: /storage/coda1/p-vzikas3/0/ywei368/Yu-Project/Auditing/dpsgd-auditbench/data/mislabeled-canaries-27198899012190525004019618245709479116-10000-0.5-cifar10/hparams.json


In [4]:
from dataset import get_data_loaders
logger.info("Loading data...")
train_loader, test_dataset = get_data_loaders(
    data_dir=data_dir,
    logical_batch_size=LOGICAL_BATCH_SIZE,
    num_workers=4
)

test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=1024, shuffle=False, num_workers=4
)

2026-02-02 16:34:19 - INFO - Loading data...


  entry = pickle.load(f, encoding="latin1")


In [5]:
# Create model
logger.info("Creating model...")
model = WideResNet(depth=16, widen_factor=4).to(device)
optimizer = optim.SGD(model.parameters(), lr=LR, momentum=MOMENTUM)

# Setup privacy engine
logger.info("Setting up privacy engine...")
privacy_engine = PrivacyEngine()
model, optimizer, train_loader = privacy_engine.make_private(
    module=model,
    optimizer=optimizer,
    data_loader=train_loader,
    noise_multiplier=NOISE_MULTIPLIER,
    max_grad_norm=MAX_GRAD_NORM,
)

2026-02-02 16:34:22 - INFO - Creating model...
2026-02-02 16:34:22 - INFO - Setting up privacy engine...




In [6]:
# ==========================================
# Checkpoint Loading
# ==========================================
start_epoch = 1

# Find the latest checkpoint (largest epoch number)
checkpoint_result = find_latest_checkpoint(ckpt_dir)
if checkpoint_result is not None:
    checkpoint_path, checkpoint_epoch = checkpoint_result
else:
    checkpoint_path, checkpoint_epoch = None, None

if checkpoint_path is not None:
    logger.info(f"Loading checkpoint '{checkpoint_path}' (epoch {checkpoint_epoch})...")
    
    # Load model and optimizer state
    loaded_epoch, loaded_global_step = load_checkpoint(checkpoint_path, model, optimizer, device, logger)
    start_epoch = loaded_epoch + 1
    
    # RECOVER PRIVACY STATE
    # We must tell the accountant we have already taken N steps.
    steps_per_epoch = len(train_loader) 
    sample_rate = 1 / len(train_loader)

    # This line forces the accountant to remember the past
    privacy_engine.accountant.history.append((NOISE_MULTIPLIER, sample_rate, loaded_global_step))
    
    logger.info(f"Resumed from Epoch {start_epoch}")
    logger.info(f"Privacy Accountant updated with {loaded_global_step} past steps.")
    
    # Initialize total_steps from loaded checkpoint
    total_steps = loaded_global_step if loaded_global_step is not None else 0

    # Load in_scores and out_scores for the same epoch so sum_scores is restored correctly
    in_scores_path = os.path.join(exp_dir, f'in_scores_{checkpoint_epoch:06d}.csv')
    out_scores_path = os.path.join(exp_dir, f'out_scores_{checkpoint_epoch:06d}.csv')
    if os.path.isfile(in_scores_path) and total_steps > 0:
        in_scores_loaded = np.loadtxt(in_scores_path, delimiter=',')
        sum_scores = in_scores_loaded * total_steps
        logger.info(f"Loaded in_scores for epoch {checkpoint_epoch} -> restored sum_scores (total_steps={total_steps})")
    else:
        sum_scores = 0

    # Verify Epsilon matches where you left off
    current_eps = privacy_engine.get_epsilon(DELTA)
    logger.info(f"Current Cumulative Epsilon: {current_eps:.2f}")
else:
    logger.info("No checkpoint found. Starting from scratch.")
    total_steps = 0
    sum_scores = 0

2026-02-02 16:34:22 - INFO - Loading checkpoint '/storage/coda1/p-vzikas3/0/ywei368/Yu-Project/Auditing/dpsgd-auditbench/data/mislabeled-canaries-27198899012190525004019618245709479116-10000-0.5-cifar10/ckpt/0000000005.npz' (epoch 5)...


2026-02-02 16:34:22 - INFO - Loaded checkpoint from epoch 5, global step 60
2026-02-02 16:34:22 - INFO - Resumed from Epoch 6
2026-02-02 16:34:22 - INFO - Privacy Accountant updated with 60 past steps.
2026-02-02 16:34:22 - INFO - Loaded in_scores for epoch 5 -> restored sum_scores (total_steps=60)
2026-02-02 16:34:22 - INFO - Current Cumulative Epsilon: 0.88


In [7]:
# Build 10k dirac canaries
params = list(model.parameters())
canary_dirac_indices = []
remaining = CANARY_COUNT
for p_idx, p in enumerate(params):
    take = min(remaining, p.numel())
    canary_dirac_indices.extend((p_idx, i) for i in range(take))
    remaining -= take
    if remaining == 0:
        break

In [8]:
# Training loop
logger.info("Starting training...")


# Initialize total_steps if not already set from checkpoint loading
if 'total_steps' not in locals():
    total_steps = 0

final_test_acc = None
for epoch in range(start_epoch, EPOCHS + 1):
    train_loss, num_steps, scores = train_whitebox(
        model, optimizer, train_loader, device, epoch, AUG_MULTIPLICITY, MAX_PHYSICAL_BATCH_SIZE, LOGICAL_BATCH_SIZE, canary_dirac_indices=canary_dirac_indices, canary_prob=1.0 / len(train_loader), return_scores=True)

    scores = np.asarray(scores)   # (num_steps, num_canaries)
    assert scores.shape[0] == num_steps, f"scores.shape[0] = {scores.shape[0]} != num_steps = {num_steps}"
    sum_scores += scores.sum(axis=0)

    total_steps += num_steps
    test_acc = test(model, test_loader, device)
    
    # Get current privacy budget (epsilon)
    epsilon = privacy_engine.get_epsilon(delta=DELTA)
    
    logger.info(f"Epoch {epoch} - Train Loss: {train_loss:.4f}, Test Accuracy: {test_acc:.2f}%, Epsilon: {epsilon:.2f}, Delta: {DELTA}, Steps: {num_steps}, Total Steps: {total_steps}")
    final_test_acc = test_acc  # Store for final checkpoint
    
    # Save checkpoint every N epochs
    if epoch % CKPT_INTERVAL == 0:
        save_checkpoint(model, optimizer, epoch, test_acc, ckpt_dir, logger, global_step=total_steps)

        in_scores = sum_scores / total_steps
        out_canary_observations = sample_gaussian(total_steps, CANARY_COUNT, NOISE_MULTIPLIER, rng)
        out_scores = out_canary_observations.sum(axis=1) / total_steps

        np.savetxt(os.path.join(exp_dir, f'out_scores_{epoch:06d}.csv'), out_scores, delimiter=',')
        np.savetxt(os.path.join(exp_dir, f'in_scores_{epoch:06d}.csv'), in_scores, delimiter=',')
        
        # Save current_eps and delta for this epoch (same epoch as in_scores/out_scores)
        np.savetxt(os.path.join(exp_dir, f'privacy_params_{epoch:06d}.csv'), 
                [[epsilon, DELTA]], delimiter=',', header='current_eps,delta', comments='')


logger.info("Training complete!")
save_checkpoint(model, optimizer, EPOCHS, final_test_acc, ckpt_dir, logger, global_step=total_steps)
logger.info(f"Final log file saved at: {log_file}")

2026-02-02 16:34:22 - INFO - Starting training...


Epoch 6: 400batch [01:00,  6.62batch/s, loss=1.92]                      


2026-02-02 16:35:24 - INFO - Epoch 6 - Train Loss: 1.9473, Test Accuracy: 30.48%, Epsilon: 0.97, Delta: 1e-05, Steps: 12, Total Steps: 72


Epoch 7: 396batch [00:59,  6.64batch/s, loss=1.84]                      


2026-02-02 16:36:25 - INFO - Epoch 7 - Train Loss: 1.8397, Test Accuracy: 30.29%, Epsilon: 1.05, Delta: 1e-05, Steps: 12, Total Steps: 84


Epoch 8: 397batch [00:59,  6.65batch/s, loss=1.8]                       


2026-02-02 16:37:26 - INFO - Epoch 8 - Train Loss: 1.8427, Test Accuracy: 34.93%, Epsilon: 1.12, Delta: 1e-05, Steps: 12, Total Steps: 96


Epoch 9: 398batch [01:00,  6.63batch/s, loss=1.93]                      


2026-02-02 16:38:27 - INFO - Epoch 9 - Train Loss: 1.8464, Test Accuracy: 31.21%, Epsilon: 1.19, Delta: 1e-05, Steps: 12, Total Steps: 108


Epoch 10: 396batch [00:59,  6.66batch/s, loss=1.83]                      


2026-02-02 16:39:28 - INFO - Epoch 10 - Train Loss: 1.8135, Test Accuracy: 34.20%, Epsilon: 1.26, Delta: 1e-05, Steps: 12, Total Steps: 120
2026-02-02 16:39:29 - INFO - Checkpoint saved: /storage/coda1/p-vzikas3/0/ywei368/Yu-Project/Auditing/dpsgd-auditbench/data/mislabeled-canaries-27198899012190525004019618245709479116-10000-0.5-cifar10/ckpt/0000000010.npz (Epoch 10, Global Step 120, Test Accuracy: 34.20%)


Epoch 11: 397batch [00:59,  6.64batch/s, loss=1.97]                      


2026-02-02 16:40:30 - INFO - Epoch 11 - Train Loss: 1.7973, Test Accuracy: 36.92%, Epsilon: 1.32, Delta: 1e-05, Steps: 12, Total Steps: 132


Epoch 12: 397batch [00:59,  6.67batch/s, loss=1.66]                      


2026-02-02 16:41:31 - INFO - Epoch 12 - Train Loss: 1.8677, Test Accuracy: 41.77%, Epsilon: 1.39, Delta: 1e-05, Steps: 12, Total Steps: 144


Epoch 13: 400batch [00:59,  6.69batch/s, loss=1.68]                      


2026-02-02 16:42:32 - INFO - Epoch 13 - Train Loss: 1.7482, Test Accuracy: 37.13%, Epsilon: 1.45, Delta: 1e-05, Steps: 12, Total Steps: 156


Epoch 14: 397batch [00:59,  6.64batch/s, loss=1.55]                      


2026-02-02 16:43:33 - INFO - Epoch 14 - Train Loss: 1.7874, Test Accuracy: 39.26%, Epsilon: 1.50, Delta: 1e-05, Steps: 12, Total Steps: 168


Epoch 15: 394batch [00:59,  6.61batch/s, loss=1.65]                      


AssertionError: scores.shape[0] = 12 != num_steps = 11