In [1]:
import sys
import os
import shutil
import json

# Resolve project paths from the notebook's working directory. The notebook is
# expected to sit one level below the project root, so '../' is the root.
project_dir = os.path.abspath(os.path.join(os.getcwd(), '../'))
src_dir = os.path.join(project_dir, 'src')
data_dir = os.path.join(project_dir, 'data')
fig_dir = os.path.join(project_dir, 'fig')
logs_dir = os.path.join(project_dir, 'logs')

# Output directories are created up front; exist_ok keeps re-runs harmless.
os.makedirs(fig_dir, exist_ok=True)
os.makedirs(data_dir, exist_ok=True)
os.makedirs(logs_dir, exist_ok=True)

# Make the project's src/ modules (utils, dataset, network_arch, train) importable.
# Prepend instead of append so these generic module names take precedence over
# any same-named packages already on sys.path. Skip the insert on re-runs so
# sys.path does not accumulate duplicates.
if src_dir not in sys.path:
    sys.path.insert(0, src_dir)

import torch
from opacus import PrivacyEngine
import torch.optim as optim

from utils import setup_logging, save_checkpoint, find_latest_checkpoint, load_checkpoint
from dataset import get_data_loaders
from network_arch import WideResNet
from train import train, test


In [2]:
# ==========================================
# Hyperparameters (Settings from the paper "Unlocking High-Accuracy Differentially Private Image Classification through Scale")
# ==========================================
LOGICAL_BATCH_SIZE = 4096      # Target (logical) batch size per optimizer step (paper setting)
MAX_PHYSICAL_BATCH_SIZE = 128  # Per-micro-batch image cap for GPU memory; with AUG_MULTIPLICITY=16 this is
                               # presumably 128 * 16 = 2048 augmented samples per forward pass (depends on train())
AUG_MULTIPLICITY = 16          # K=16 augmentations per image
MAX_GRAD_NORM = 1.0            # Gradient clipping norm passed to Opacus make_private
EPSILON = 8.0                  # Target privacy budget. Recorded in hparams only; noise is NOT calibrated from it
                               # (NOISE_MULTIPLIER below is fixed), so reported epsilon can exceed this value
DELTA = 1e-5                   # Delta used when reporting epsilon
EPOCHS = 140                   # Total number of training epochs
LR = 4.0                       # High LR for large batch
MOMENTUM = 0.0                 # No momentum
NOISE_MULTIPLIER = 3.0         # Sigma; ~3.0 is the suggested value for BS=4096
CKPT_INTERVAL = 20             # Save a checkpoint every 20 epochs

expid = 1                      # Experiment id; part of the experiment directory name
SEED = 42                      # Random seed for reproducibility

In [3]:
# Per-run setup: logging, experiment/checkpoint directories, device selection,
# RNG seeding, and an on-disk record of this experiment's hyperparameters.
import random

import numpy as np

logger, log_file = setup_logging(log_dir=logs_dir)
logdir_path = os.path.dirname(log_file)  # directory holding this run's log file

# Create experiment directory
exp_dir = os.path.join(data_dir, f"original-DP-{expid}-cifar10")
os.makedirs(exp_dir, exist_ok=True)
logger.info(f"Experiment directory: {exp_dir}")

# Create checkpoint directory under experiment directory
ckpt_dir = os.path.join(exp_dir, "ckpt")
os.makedirs(ckpt_dir, exist_ok=True)
logger.info(f"Checkpoint directory: {ckpt_dir}")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Run experiment on device: {device}")

# Seed every RNG the run may draw from: Python's `random`, NumPy and torch.
# torch.manual_seed already seeds all CUDA devices; manual_seed_all only makes
# that explicit.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
logger.info(f"Set random seeds (torch, numpy, random) to: {SEED}")

# Store hyperparameters in a dictionary
params = {
    'logical_batch_size': LOGICAL_BATCH_SIZE,
    'max_physical_batch_size': MAX_PHYSICAL_BATCH_SIZE,
    'aug_multiplicity': AUG_MULTIPLICITY,
    'max_grad_norm': MAX_GRAD_NORM,
    'epsilon': EPSILON,
    'delta': DELTA,
    'epochs': EPOCHS,
    'lr': LR,
    'momentum': MOMENTUM,
    'noise_multiplier': NOISE_MULTIPLIER,
    'expid': expid,
    'ckpt_interval': CKPT_INTERVAL,
    'seed': SEED
}

hparams_path = os.path.join(exp_dir, 'hparams.json')

# When resuming, the checkpoint cell replays past steps into the privacy
# accountant using the *current* NOISE_MULTIPLIER. Changing hyperparameters
# between runs therefore silently corrupts epsilon, so warn before the
# previous record is overwritten.
if os.path.exists(hparams_path):
    try:
        with open(hparams_path) as f:
            previous_params = json.load(f)
    except (OSError, json.JSONDecodeError):
        previous_params = None
    if isinstance(previous_params, dict) and previous_params != params:
        changed_keys = sorted(
            k for k in set(previous_params) | set(params)
            if previous_params.get(k) != params.get(k)
        )
        logger.warning(
            f"Hyperparameters differ from existing {hparams_path} for keys {changed_keys}; "
            "overwriting. Resumed privacy accounting may be inaccurate."
        )

# Save hyperparameters to experiment directory
with open(hparams_path, 'w') as f:
    json.dump(params, f, indent=2)
logger.info(f"Hyperparameters saved to: {hparams_path}")


2026-01-09 16:58:54 - INFO - Logging initialized. Log file: /storage/coda1/p-vzikas3/0/ywei368/Yu-Project/Auditing/onerun-lira/logs/train_20260109_165854.log
2026-01-09 16:58:54 - INFO - Experiment directory: /storage/coda1/p-vzikas3/0/ywei368/Yu-Project/Auditing/onerun-lira/data/original-DP-1-cifar10
2026-01-09 16:58:54 - INFO - Checkpoint directory: /storage/coda1/p-vzikas3/0/ywei368/Yu-Project/Auditing/onerun-lira/data/original-DP-1-cifar10/ckpt
2026-01-09 16:58:54 - INFO - Run experiment on device: cuda
2026-01-09 16:58:54 - INFO - Hyperparameters saved to: /storage/coda1/p-vzikas3/0/ywei368/Yu-Project/Auditing/onerun-lira/data/original-DP-1-cifar10/hparams.json


In [4]:
# Load data. get_data_loaders receives LOGICAL_BATCH_SIZE and returns the
# training loader plus the raw test Dataset; the test split is batched here
# purely for evaluation.
logger.info("Loading data...")
train_loader, test_dataset = get_data_loaders(
    data_dir=data_dir, logical_batch_size=LOGICAL_BATCH_SIZE, num_workers=4,
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=1024,
    shuffle=False,   # deterministic evaluation order
    num_workers=4,
)

2026-01-09 16:58:54 - INFO - Loading data...


  entry = pickle.load(f, encoding="latin1")


In [None]:
# ------------------------------------------
# Model + optimizer, then wrap all three training objects for DP-SGD.
# make_private returns replacements for model, optimizer and loader; the same
# names are rebound so later cells transparently use the private versions.
# ------------------------------------------
logger.info("Creating model...")
model = WideResNet(depth=16, widen_factor=4).to(device)
optimizer = optim.SGD(params=model.parameters(), lr=LR, momentum=MOMENTUM)

logger.info("Setting up privacy engine...")
privacy_engine = PrivacyEngine()
model, optimizer, train_loader = privacy_engine.make_private(
    module=model, optimizer=optimizer, data_loader=train_loader,
    noise_multiplier=NOISE_MULTIPLIER, max_grad_norm=MAX_GRAD_NORM,
)

In [None]:
# ==========================================
# Checkpoint Loading
# ==========================================
# Resume from the newest checkpoint in ckpt_dir (if any) and replay its step
# count into the Opacus accountant, so reported epsilon covers every step taken
# across runs.
# NOTE: re-running this cell without first re-running the model/privacy-engine
# cell appends the history a second time and double-counts past steps.
start_epoch = 1
total_steps = 0

# find_latest_checkpoint returns (path, epoch) for the largest epoch, or None.
checkpoint_result = find_latest_checkpoint(ckpt_dir)
checkpoint_path, checkpoint_epoch = checkpoint_result if checkpoint_result is not None else (None, None)

if checkpoint_path is None:
    logger.info("No checkpoint found. Starting from scratch.")
else:
    logger.info(f"Loading checkpoint '{checkpoint_path}' (epoch {checkpoint_epoch})...")

    # Load model and optimizer state; returns the saved epoch and step count.
    loaded_epoch, loaded_global_step = load_checkpoint(checkpoint_path, model, optimizer, device, logger)
    start_epoch = loaded_epoch + 1

    # Without a step count the spent privacy cannot be reconstructed. Previously
    # a None was pushed into the accountant and get_epsilon crashed later.
    # Fail clearly here, before touching the accountant, rather than resuming
    # with an under-counted epsilon.
    if loaded_global_step is None:
        raise ValueError(
            f"Checkpoint '{checkpoint_path}' has no global_step; cannot reconstruct the "
            "privacy accountant history. Refusing to resume with an under-counted epsilon."
        )
    total_steps = loaded_global_step

    # Accountant history entries are (noise_multiplier, sample_rate, num_steps).
    # After make_private, 1 / len(train_loader) should equal the sampling rate
    # Opacus derived from the original loader (assumes default Poisson sampling;
    # re-check if make_private arguments change).
    sample_rate = 1 / len(train_loader)

    # Manually insert the past steps so the accountant "remembers" them.
    privacy_engine.accountant.history.append((NOISE_MULTIPLIER, sample_rate, total_steps))

    logger.info(f"Resumed from Epoch {start_epoch}")
    logger.info(f"Privacy Accountant updated with {total_steps} past steps.")
    if start_epoch > EPOCHS:
        logger.info(f"Checkpoint already covers all {EPOCHS} epochs; the training loop will not run.")

    # Verify Epsilon matches where you left off
    current_eps = privacy_engine.get_epsilon(DELTA)
    logger.info(f"Current Cumulative Epsilon: {current_eps:.2f}")

In [None]:
# Training loop: train + evaluate each epoch, report cumulative epsilon, save a
# checkpoint every CKPT_INTERVAL epochs, and save a final checkpoint if needed.
logger.info("Starting training...")

# Initialize total_steps if not already set from checkpoint loading
if 'total_steps' not in locals():
    total_steps = 0

final_test_acc = None
last_trained_epoch = None   # last epoch actually run in this session
last_saved_epoch = None     # last epoch written by the periodic checkpoint
budget_warning_logged = False

for epoch in range(start_epoch, EPOCHS + 1):
    train_loss, num_steps = train(
        model, optimizer, train_loader, device, epoch, AUG_MULTIPLICITY, MAX_PHYSICAL_BATCH_SIZE, LOGICAL_BATCH_SIZE
    )
    total_steps += num_steps
    test_acc = test(model, test_loader, device)

    # Get current privacy budget (epsilon)
    epsilon = privacy_engine.get_epsilon(delta=DELTA)

    logger.info(f"Epoch {epoch} - Train Loss: {train_loss:.4f}, Test Accuracy: {test_acc:.2f}%, Epsilon: {epsilon:.2f}, Delta: {DELTA}, Steps: {num_steps}, Total Steps: {total_steps}")

    # Noise is fixed by NOISE_MULTIPLIER, so nothing stops epsilon from passing
    # the EPSILON target; surface that once instead of letting it go unnoticed.
    if epsilon > EPSILON and not budget_warning_logged:
        logger.warning(
            f"Epsilon {epsilon:.2f} exceeded the target budget EPSILON={EPSILON} at epoch {epoch}; "
            "training continues because noise is fixed by NOISE_MULTIPLIER."
        )
        budget_warning_logged = True

    final_test_acc = test_acc  # Store for final checkpoint
    last_trained_epoch = epoch

    # Save checkpoint every N epochs
    if epoch % CKPT_INTERVAL == 0:
        save_checkpoint(model, optimizer, epoch, test_acc, ckpt_dir, logger, global_step=total_steps)
        last_saved_epoch = epoch


logger.info("Training complete!")
if last_trained_epoch is None:
    # Nothing was trained (resumed at/after EPOCHS): don't write a checkpoint
    # with a None accuracy over the one that was just loaded.
    logger.info("No epochs were run; skipping final checkpoint.")
elif last_saved_epoch != last_trained_epoch:
    # Skip when the periodic save above already wrote this exact epoch.
    save_checkpoint(model, optimizer, last_trained_epoch, final_test_acc, ckpt_dir, logger, global_step=total_steps)
logger.info(f"Final log file saved at: {log_file}")