In [1]:
import sys
import os
import shutil
import json

# Resolve the project layout. The project root is taken to be the parent of
# the current working directory.
project_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
src_dir = os.path.join(project_dir, 'src')
data_dir = os.path.join(project_dir, 'data')
fig_dir = os.path.join(project_dir, 'fig')
logs_dir = os.path.join(project_dir, 'logs')

# Output folders are created up front; exist_ok keeps the cell safe to re-run.
for _output_dir in (fig_dir, data_dir, logs_dir):
    os.makedirs(_output_dir, exist_ok=True)

# Make the project's src/ modules importable. The membership check keeps
# sys.path free of duplicate entries when this cell is executed more than once.
if src_dir not in sys.path:
    sys.path.append(src_dir)

import torch
from opacus import PrivacyEngine
import torch.optim as optim

from utils import setup_logging, save_checkpoint, find_latest_checkpoint, load_checkpoint
from dataset import get_data_loaders
from network_arch import WideResNet
from train import train, test


In [2]:
# ==========================================
# Hyperparameters (Settings from the paper "Unlocking High-Accuracy Differentially Private Image Classification through Scale")
# NOTE(review): some values below differ from what the earlier inline notes
# claimed (see AUG_MULTIPLICITY and CKPT_INTERVAL) -- confirm which are intended.
# ==========================================
LOGICAL_BATCH_SIZE = 4096     # Target (logical) batch size, passed to get_data_loaders
# Cap handed to train() as the physical batch size. The earlier note read
# "128 * 16 = 512 effective images", which is inconsistent (128 * 16 = 2048,
# and with AUG_MULTIPLICITY = 2 the product is 256) -- TODO confirm against train().
MAX_PHYSICAL_BATCH_SIZE = 128  # GPU limit per physical batch
AUG_MULTIPLICITY = 2         # Augmentation multiplicity K passed to train(); earlier note said K=16 but the value used is 2 -- TODO confirm
MAX_GRAD_NORM = 1.0          # Clipping norm passed to make_private(max_grad_norm=...)
# EPSILON is only recorded in hparams.json; it is not passed to make_private in
# this notebook (noise is fixed via NOISE_MULTIPLIER), so it is not enforced as a budget here.
EPSILON = 8.0
DELTA = 1e-5                 # Delta used when querying privacy_engine.get_epsilon
EPOCHS = 140                   # Last epoch index of the training loop (range(start_epoch, EPOCHS + 1))
LR = 4.0                      # High LR for large batch
MOMENTUM = 0.0                # No momentum
NOISE_MULTIPLIER = 3.0        # Sigma passed to make_private; author's note: ~3.0 is optimal for BS=4096
CKPT_INTERVAL = 2            # Save a checkpoint whenever epoch % CKPT_INTERVAL == 0 (i.e. every 2 epochs)

# Experiment id; it names the output folder "exp-{expid}-cifar10" under data_dir.
expid = 1

In [3]:
# Logging comes first so that every later step can report through `logger`.
logger, log_file = setup_logging(log_dir=logs_dir)
logdir_path = os.path.dirname(log_file)

# Per-experiment output folder: <data_dir>/exp-<expid>-cifar10
exp_dir = os.path.join(data_dir, f"exp-{expid}-cifar10")
os.makedirs(exp_dir, exist_ok=True)
logger.info(f"Experiment directory: {exp_dir}")

# Checkpoints live in a "ckpt" sub-folder of the experiment folder.
ckpt_dir = os.path.join(exp_dir, "ckpt")
os.makedirs(ckpt_dir, exist_ok=True)
logger.info(f"Checkpoint directory: {ckpt_dir}")

# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
logger.info(f"Run experiment on device: {device}")

# Snapshot of the run configuration, written next to the experiment outputs.
params = dict(
    logical_batch_size=LOGICAL_BATCH_SIZE,
    max_physical_batch_size=MAX_PHYSICAL_BATCH_SIZE,
    aug_multiplicity=AUG_MULTIPLICITY,
    max_grad_norm=MAX_GRAD_NORM,
    epsilon=EPSILON,
    delta=DELTA,
    epochs=EPOCHS,
    lr=LR,
    momentum=MOMENTUM,
    noise_multiplier=NOISE_MULTIPLIER,
    expid=expid,
    ckpt_interval=CKPT_INTERVAL,
)

# The file is rewritten on every run of this cell.
hparams_path = os.path.join(exp_dir, 'hparams.json')
with open(hparams_path, 'w') as f:
    f.write(json.dumps(params, indent=2))
logger.info(f"Hyperparameters saved to: {hparams_path}")


2026-01-09 15:31:03 - INFO - Logging initialized. Log file: /storage/coda1/p-vzikas3/0/ywei368/Yu-Project/Auditing/onerun-lira/logs/train_20260109_153103.log
2026-01-09 15:31:03 - INFO - Experiment directory: /storage/coda1/p-vzikas3/0/ywei368/Yu-Project/Auditing/onerun-lira/data/exp-1-cifar10


2026-01-09 15:31:03 - INFO - Checkpoint directory: /storage/coda1/p-vzikas3/0/ywei368/Yu-Project/Auditing/onerun-lira/data/exp-1-cifar10/ckpt
2026-01-09 15:31:03 - INFO - Run experiment on device: cuda
2026-01-09 15:31:03 - INFO - Hyperparameters saved to: /storage/coda1/p-vzikas3/0/ywei368/Yu-Project/Auditing/onerun-lira/data/exp-1-cifar10/hparams.json


In [4]:
# Build the data pipeline: the project helper returns the training loader and
# the test dataset; the test dataset is then wrapped in a fixed-order loader
# for evaluation.
logger.info("Loading data...")
train_loader, test_dataset = get_data_loaders(
    data_dir=data_dir,
    logical_batch_size=LOGICAL_BATCH_SIZE,
    num_workers=4,
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=1024,
    shuffle=False,
    num_workers=4,
)

2026-01-09 15:31:03 - INFO - Loading data...


  entry = pickle.load(f, encoding="latin1")


In [5]:
# Instantiate the network on the selected device and pair it with plain SGD.
logger.info("Creating model...")
model = WideResNet(depth=16, widen_factor=4)
model = model.to(device)
optimizer = optim.SGD(model.parameters(), lr=LR, momentum=MOMENTUM)

# Hand model, optimizer and training loader to Opacus. The returned triple
# replaces the originals, so later cells see the wrapped objects under the
# same names.
logger.info("Setting up privacy engine...")
privacy_engine = PrivacyEngine()
private_components = privacy_engine.make_private(
    module=model,
    optimizer=optimizer,
    data_loader=train_loader,
    noise_multiplier=NOISE_MULTIPLIER,
    max_grad_norm=MAX_GRAD_NORM,
)
model, optimizer, train_loader = private_components

2026-01-09 15:31:06 - INFO - Creating model...
2026-01-09 15:31:06 - INFO - Setting up privacy engine...




In [6]:
# ==========================================
# Checkpoint Loading
# ==========================================
# Resume from the newest checkpoint in ckpt_dir (if any) and rebuild the
# privacy accountant's history so the reported epsilon covers the epochs that
# were trained before this kernel started.
start_epoch = 1

# Find the latest checkpoint (largest epoch number)
checkpoint_result = find_latest_checkpoint(ckpt_dir)
if checkpoint_result is not None:
    checkpoint_path, checkpoint_epoch = checkpoint_result
else:
    checkpoint_path, checkpoint_epoch = None, None

if checkpoint_path is not None:
    logger.info(f"Loading checkpoint '{checkpoint_path}' (epoch {checkpoint_epoch})...")

    # Load model and optimizer state; training continues with the next epoch.
    loaded_epoch = load_checkpoint(checkpoint_path, model, optimizer, device, logger)
    start_epoch = loaded_epoch + 1

    # RECOVER PRIVACY STATE
    # The accountant only knows about steps taken in this kernel, so the steps
    # from earlier epochs have to be re-inserted by hand.
    #
    # Both quantities are read from the DP train_loader instead of being
    # recomputed from a hard-coded dataset size: make_private derives its
    # sample rate as 1 / len(data_loader), which is generally not equal to
    # LOGICAL_BATCH_SIZE / 50000, and the number of logical steps per epoch is
    # the loader length rather than 50000 // LOGICAL_BATCH_SIZE.
    # NOTE(review): assumes train_loader is the loader returned by make_private
    # (it exposes `sample_rate`); the fallback keeps the cell usable otherwise.
    steps_per_epoch = len(train_loader)
    sample_rate = getattr(train_loader, "sample_rate", 1 / steps_per_epoch)
    past_steps = (start_epoch - 1) * steps_per_epoch

    # History entries are tuples of (noise_multiplier, sample_rate, num_steps).
    # The history is reset before inserting so that re-running this cell does
    # not count the past steps twice; after loading a checkpoint the accountant
    # should describe exactly the epochs contained in that checkpoint.
    # Assumes every earlier epoch was trained with the current NOISE_MULTIPLIER.
    privacy_engine.accountant.history.clear()
    if past_steps > 0:
        privacy_engine.accountant.history.append((NOISE_MULTIPLIER, sample_rate, past_steps))

    logger.info(f"Resumed from Epoch {start_epoch}")
    logger.info(f"Privacy Accountant updated with {past_steps} past steps.")

    # Report the recovered budget so it can be compared with the last logged value.
    current_eps = privacy_engine.get_epsilon(DELTA)
    logger.info(f"Current Cumulative Epsilon: {current_eps:.2f}")
else:
    logger.info("No checkpoint found. Starting from scratch.")

2026-01-09 15:31:06 - INFO - No checkpoint found. Starting from scratch.


In [7]:
# Training loop: one train + evaluation pass per epoch, with the cumulative
# privacy budget logged after each epoch and a checkpoint written every
# CKPT_INTERVAL epochs.
logger.info("Starting training...")

final_test_acc = None  # stays None when the loop body never runs (start_epoch > EPOCHS)
for epoch in range(start_epoch, EPOCHS + 1):
    train_loss = train(
        model, optimizer, train_loader, device, epoch, AUG_MULTIPLICITY, MAX_PHYSICAL_BATCH_SIZE
    )
    test_acc = test(model, test_loader, device)

    # Get current privacy budget (epsilon)
    epsilon = privacy_engine.get_epsilon(delta=DELTA)

    logger.info(f"Epoch {epoch} - Train Loss: {train_loss:.4f}, Test Accuracy: {test_acc:.2f}%, Epsilon: {epsilon:.2f}, Delta: {DELTA}")
    final_test_acc = test_acc  # Store for final checkpoint

    # Save checkpoint every N epochs
    if epoch % CKPT_INTERVAL == 0:
        save_checkpoint(model, optimizer, epoch, test_acc, ckpt_dir, logger)


logger.info("Training complete!")
if final_test_acc is None:
    # Resumed from a checkpoint at or beyond EPOCHS: nothing was trained or
    # evaluated, so writing a "final" checkpoint would replace the existing one
    # with a None accuracy (and presumably break save_checkpoint's formatting
    # of the accuracy -- verify against utils.save_checkpoint).
    logger.info(f"No epochs were run (start_epoch={start_epoch}, EPOCHS={EPOCHS}); final checkpoint left untouched.")
elif EPOCHS % CKPT_INTERVAL != 0:
    # The last epoch was not covered by the periodic save inside the loop.
    save_checkpoint(model, optimizer, EPOCHS, final_test_acc, ckpt_dir, logger)
logger.info(f"Final log file saved at: {log_file}")

2026-01-09 15:31:06 - INFO - Starting training...


Epoch 1: 394batch [00:40,  9.74batch/s, loss=2.06]                      


2026-01-09 15:31:47 - INFO - Epoch 1 - Train Loss: 2.1714, Test Accuracy: 25.58%, Epsilon: 0.41, Delta: 1e-05


Epoch 2: 397batch [00:40,  9.81batch/s, loss=2.21]                      


2026-01-09 15:32:28 - INFO - Epoch 2 - Train Loss: 2.1448, Test Accuracy: 22.74%, Epsilon: 0.56, Delta: 1e-05
2026-01-09 15:32:28 - INFO - Checkpoint saved: /storage/coda1/p-vzikas3/0/ywei368/Yu-Project/Auditing/onerun-lira/data/exp-1-cifar10/ckpt/0000000002.npz (Epoch 2, Test Accuracy: 22.74%)


Epoch 3: 396batch [00:40,  9.78batch/s, loss=1.89]                      


2026-01-09 15:33:09 - INFO - Epoch 3 - Train Loss: 1.9672, Test Accuracy: 26.20%, Epsilon: 0.69, Delta: 1e-05


Epoch 4: 397batch [00:40,  9.83batch/s, loss=2.16]                      


2026-01-09 15:33:50 - INFO - Epoch 4 - Train Loss: 2.1219, Test Accuracy: 26.21%, Epsilon: 0.79, Delta: 1e-05
2026-01-09 15:33:51 - INFO - Checkpoint saved: /storage/coda1/p-vzikas3/0/ywei368/Yu-Project/Auditing/onerun-lira/data/exp-1-cifar10/ckpt/0000000004.npz (Epoch 4, Test Accuracy: 26.21%)


Epoch 5:  14%|█▎        | 53/391 [00:05<00:35,  9.55batch/s, loss=1.96]


KeyboardInterrupt: 