In [None]:
from dataclasses import asdict
from dp_virtual_projection_population_only import ExperimentConfig as VProjConfig, train as vp_train
from local_dpsgd_experiment import ExperimentConfig as LDConfig, train as ld_train
from experiment_utils import clear_memory

from privacy_pipeline.config import ExperimentConfig as BaseConfig
from privacy_pipeline.training import train as base_train, train_with_outlier_clipping

In [None]:
# Baseline DP-SGD experiment
# Based on: https://proceedings.neurips.cc/paper_files/paper/2015/file/52d080a3e172c33fd6886a37e7288491-Paper.pdf
# Builds the config, echoes it, trains with privacy_pipeline's `train`
# (passing 'outputs/baseline_run' as the run path), then frees memory.
baseline_cfg = BaseConfig(
    num_epochs=1,
    batch_size=512,
    lr=0.05,
    outer_momentum=0.9,
    inner_momentum=0.10,
    noise_mult=1.0,
    delta=1e-5,
    c_start=4.0,
    c_end=2.0,
    self_aug_factor=1,
    # NOTE(review): with num_epochs=1, milestones 12 and 18 are never reached
    # if they are epoch-indexed, so schedule_gamma would never be applied.
    # Confirm whether the scheduler counts epochs or steps.
    schedule_milestones=[12, 18],
    schedule_gamma=0.1,
    max_momentum_size=10000,
)
print(asdict(baseline_cfg))
try:
    baseline_results = base_train(baseline_cfg, 'outputs/baseline_run')
finally:
    # Always clean up, even if training raises, so a failed run does not
    # leave memory held for the following experiment cells.
    # (clear_memory comes from experiment_utils; assumed to release
    # GC/accelerator memory — TODO confirm.)
    clear_memory()
baseline_results

In [None]:
# Synthetic data mixing experiment
# Based on: https://arxiv.org/pdf/2311.01295
# Trains with privacy_pipeline's `train_with_outlier_clipping`, passing
# 'outputs/outlier_run' as the run path, then frees memory.
# c_start/c_end and the schedule_* / max_momentum_size fields are not passed
# here, so whatever defaults BaseConfig defines for them apply.
outlier_cfg = BaseConfig(
    seed=0,
    batch_size=1000,
    lr=0.1,
    outer_momentum=0.9,
    inner_momentum=0.08,
    noise_mult=1.5,
    delta=1e-5,
    num_epochs=1,
    self_aug_factor=3,
)
print(asdict(outlier_cfg))
try:
    outlier_results = train_with_outlier_clipping(outlier_cfg, 'outputs/outlier_run')
finally:
    # Always clean up, even if training raises, so a failed run does not
    # leave memory held for the following experiment cells.
    # (clear_memory comes from experiment_utils; assumed to release
    # GC/accelerator memory — TODO confirm.)
    clear_memory()
outlier_results

In [None]:
# Synthetic data projection experiment
# Based on: https://arxiv.org/pdf/2506.16661
# Trains with dp_virtual_projection_population_only's `train`, passing
# 'outputs/vproj_run' as the run path, then frees memory.
# The hyperparameters below are identical to the baseline DP-SGD cell's;
# presumably this is deliberate so the two runs are comparable — keep the
# two cells in sync when tuning either one.
vp_cfg = VProjConfig(
    num_epochs=1,
    batch_size=512,
    lr=0.05,
    outer_momentum=0.9,
    inner_momentum=0.10,
    noise_mult=1.0,
    delta=1e-5,
    c_start=4.0,
    c_end=2.0,
    self_aug_factor=1,
    # NOTE(review): with num_epochs=1, milestones 12 and 18 are never reached
    # if they are epoch-indexed, so schedule_gamma would never be applied.
    # Confirm whether the scheduler counts epochs or steps.
    schedule_milestones=[12, 18],
    schedule_gamma=0.1,
    max_momentum_size=10000,
)
print(asdict(vp_cfg))
try:
    vp_results = vp_train(vp_cfg, 'outputs/vproj_run')
finally:
    # Always clean up, even if training raises, so a failed run does not
    # leave memory held for the following experiment cells.
    # (clear_memory comes from experiment_utils; assumed to release
    # GC/accelerator memory — TODO confirm.)
    clear_memory()
vp_results


In [None]:
# Local DP-SGD/Scaffold experiment
# Based on: https://arxiv.org/pdf/2306.16504
# Trains with local_dpsgd_experiment's `train`, passing 'outputs/localdp_run'
# as the run path, then frees memory. LDConfig uses UPPER_CASE field names
# (its own interface); fields not passed here take LDConfig's defaults.
ld_cfg = LDConfig(
    EXPERIMENT='local_pub',
    PRIV_BATCH=512,
    NUM_EPOCHS=1,
    LR_OUTER=0.05,
    MOMENTUM_OUTER=0.9,
)
print(asdict(ld_cfg))
try:
    ld_results = ld_train(ld_cfg, 'outputs/localdp_run')
finally:
    # Always clean up, even if training raises, so a failed run does not
    # leave memory held for the following experiment cells.
    # (clear_memory comes from experiment_utils; assumed to release
    # GC/accelerator memory — TODO confirm.)
    clear_memory()
ld_results
