In [None]:
import wandb

from ydnpd.pretraining.trainer import TransformerTrainer, ModelConfig, PreTrainConfig
from ydnpd.pretraining.utils import set_strict_reproducibility_by_config
from ydnpd import ALL_EXPERIMENTS
from ydnpd.harness.config import EPSILONS


In [2]:
# Key into ALL_EXPERIMENTS selecting which experiment family (its test dataset
# and dev datasets) the cells below operate on.
DATASET_FAMILY: str = "acs"

In [None]:
# Single run outside the sweep, relying on train_and_evaluate's defaults for
# everything except the public data source.
# NOTE(review): this passes the family's held-out *test* dataset as the public
# data pointer — presumably a non-private reference run; confirm this is
# intended and not a leftover scratch cell (it is expensive on Run All).
test_dataset_name = ALL_EXPERIMENTS[DATASET_FAMILY].test_name
TransformerTrainer.train_and_evaluate(public_data_pointer=test_dataset_name)

In [None]:
# Candidate public datasets: every dev dataset of the family except the one
# used as the private/test dataset.
experiment_family = ALL_EXPERIMENTS[DATASET_FAMILY]
public_dataset_pointers = [
    name
    for name in experiment_family.dev_names
    if name != experiment_family.test_name
]
# Backward-compatible alias for the previously misspelled name.
public_dataaset_pointers = public_dataset_pointers

# How many candidate public datasets to include in the sweep, in addition to
# `None`. Raise to widen the sweep.
MAX_PUBLIC_DATASETS = 1

# W&B grid sweep: every combination of the values below becomes one trial.
# NOTE(review): when public_data_pointer is None the pre_* parameters are
# presumably unused, so the grid repeats equivalent runs for each pre_*
# combination — confirm, and consider splitting into two sweeps if so.
sweep_configuration = {
    "method": "grid",
    "metric": {"goal": "maximize", "name": "auc"},
    "parameters": {
        # Pretraining hyperparameters (fed to PreTrainConfig by runner).
        "pre_num_epochs": {"values": [1, 3, 9]},
        "pre_batch_size": {"values": [4, 32, 128]},
        "pre_lr": {"values": [3e-4, 3e-5]},
        # DP training hyperparameters (fed to ModelConfig by runner).
        # `value` pins a single constant, so it must be the scalar 20; a list
        # here would reach ModelConfig as num_epochs=[20].
        "dp_num_epochs": {"value": 20},
        "dp_batch_size": {"values": [64, 128, 256]},
        "dp_lr": {"values": [3e-3, 3e-4]},
        "epsilon": {"values": EPSILONS},
        "private_data_pointer": {"value": experiment_family.test_name},
        "public_data_pointer": {
            "values": [None] + public_dataset_pointers[:MAX_PUBLIC_DATASETS]
        },
    },
}

def runner():
    """Execute one W&B sweep trial.

    Opens a run in the "ydnpd-dp-ft" project and reads the hyperparameters the
    sweep agent placed in the run config. It then calls
    ``TransformerTrainer.train_and_evaluate``:

    * ``dp_*`` keys and ``epsilon`` go to ``ModelConfig``.
    * ``pre_*`` keys go to ``PreTrainConfig``.

    Finally it logs the returned metrics to the run.
    """
    # The context manager finishes the run even if training raises, so the
    # agent can move on to the next trial cleanly.
    with wandb.init(project="ydnpd-dp-ft") as run:
        config = run.config
        print(config)
        # NOTE(review): the sweep defines no explicit seed parameter; confirm
        # how set_strict_reproducibility_by_config derives its seed.
        set_strict_reproducibility_by_config(config)
        results = TransformerTrainer.train_and_evaluate(
            config=ModelConfig(
                num_epochs=config.dp_num_epochs,
                batch_size=config.dp_batch_size,
                lr=config.dp_lr,
                epsilon=config.epsilon,
            ),
            pretrain_config=PreTrainConfig(
                num_epochs=config.pre_num_epochs,
                batch_size=config.pre_batch_size,
                lr=config.pre_lr,
            ),
            public_data_pointer=config.public_data_pointer,
            private_data_pointer=config.private_data_pointer,
        )
        # log() takes a single dict of metrics; unpacking it into keyword
        # arguments (log(**results)) raises TypeError on the metric names.
        # The sweep optimizes the "auc" key, so results should contain it.
        run.log(results)

In [None]:
# Register `sweep_configuration` with W&B, then run its trials in this kernel,
# calling `runner` once per trial. No `count` is given, so the agent keeps
# pulling trials until the sweep is exhausted.
sweep_id = wandb.sweep(
    sweep=sweep_configuration,
    project="ydnpd-dp-ft",
)
wandb.agent(sweep_id=sweep_id, function=runner)