In [None]:
import wandb

from ydnpd.pretraining.trainer import TransformerTrainer, ModelConfig
from ydnpd.pretraining.utils import set_strict_reproducibility_by_config

In [7]:
from ydnpd import ALL_EXPERIMENTS
from ydnpd.harness.config import EPSILONS

# Dataset family whose experiments this sweep draws its data pointers from.
DATASET_FAMILY = "acs"

# Dev dataset names of the family, with the family's test dataset name excluded.
# NOTE(review): the "dataaset" spelling is kept as-is so that any existing
# references to this module-level name keep working.
public_dataaset_pointers = list(
    filter(
        lambda name: name != ALL_EXPERIMENTS[DATASET_FAMILY].test_name,
        ALL_EXPERIMENTS[DATASET_FAMILY].dev_names,
    )
)

# Grid sweep over training hyper-parameters and epsilon values; the sweep's
# objective is to maximize the metric named "auc".
sweep_configuration = dict(
    method="grid",
    metric=dict(goal="maximize", name="auc"),
    parameters=dict(
        num_epochs=dict(values=[1, 3, 9]),
        batch_size=dict(values=[8, 32, 128]),
        lr=dict(values=[3e-4, 3e-5]),
        # The private data is always the family's test dataset.
        private_data_pointer=dict(value=ALL_EXPERIMENTS[DATASET_FAMILY].test_name),
        # Only the first public dataset pointer is included in the grid.
        public_data_pointer=dict(values=public_dataaset_pointers[:1]),
        epsilon=dict(values=EPSILONS),
    ),
)

def runner():
    """Execute one sweep trial and log its results to Weights & Biases.

    Starts a run in the "ydnpd-dp-ft" project, reads the trial's
    hyper-parameters from ``wandb.config``, applies the reproducibility
    settings derived from that config, trains and evaluates the model via
    ``TransformerTrainer.train_and_evaluate``, and logs the returned metrics.

    Takes no arguments and returns nothing; it is passed as the ``function``
    callback to ``wandb.agent``.
    """
    wandb.init(project="ydnpd-dp-ft")
    config = wandb.config
    print(config)
    set_strict_reproducibility_by_config(config)

    results = TransformerTrainer.train_and_evaluate(
        config=ModelConfig(
            num_epochs=config.num_epochs,
            batch_size=config.batch_size,
            lr=config.lr,
            epsilon=config.epsilon,
        ),
        public_data_pointer=config.public_data_pointer,
        private_data_pointer=config.private_data_pointer,
    )

    # wandb.log expects the metrics mapping as its single positional `data`
    # argument. Unpacking it with ** would turn every metric name into an
    # unexpected keyword argument and raise TypeError.
    # NOTE(review): presumably `results` contains the "auc" key that the sweep
    # optimizes -- confirm against TransformerTrainer.train_and_evaluate.
    wandb.log(results)

In [None]:
# Register the sweep defined by `sweep_configuration` in the "ydnpd-dp-ft"
# project and keep the identifier it returns.
sweep_id = wandb.sweep(
    sweep=sweep_configuration,
    project="ydnpd-dp-ft",
)

# Start an agent for that sweep; it invokes `runner` for the trials it executes.
wandb.agent(
    sweep_id=sweep_id,
    function=runner,
)