In [1]:
import optax

In [2]:
# Experiment settings; each of these is forwarded to run_experiment_permutations.
num_generations: int = 64  # also reused as transition_steps of the decay schedules
population_size: int = 128
seed: int = 0
result_dir: str = "../experiment_results"  # relative path, one level above the working directory

In [None]:
from evosax.problems import CNN, TorchVisionProblem as Problem, identity_output_fn

# Single CNN definition that is passed to every problem created below.
# Both filter entries (8 and 16) use the same kernel size and stride.
conv_kernel, conv_stride = (5, 5), (1, 1)
network = CNN(
    num_filters=[8, 16],
    kernel_sizes=[conv_kernel] * 2,
    strides=[conv_stride] * 2,
    mlp_layer_sizes=[10],  # presumably one output per class (10 classes) - confirm
    output_fn=identity_output_fn,
)

# One TorchVisionProblem per task name; every problem shares the same
# network object and a batch size of 1024.
task_names = ("MNIST", "FashionMNIST", "CIFAR10", "SVHN")
problems = [
    Problem(task_name=task_name, network=network, batch_size=1024)
    for task_name in task_names
]

# Exponential decay schedules spanning the run: with
# transition_steps=num_generations the value is multiplied by decay_rate
# once per num_generations steps.
# NOTE(review): neither schedule is referenced by the other cells visible
# here (every es_dict entry is empty) - confirm whether they are meant to be
# passed to the strategies.

# Learning rate: 0.01 at step 0, 0.001 after num_generations steps.
lr_schedule = optax.exponential_decay(
    init_value=0.01, decay_rate=0.1, transition_steps=num_generations
)

# Std: 0.05 at step 0, 0.01 after num_generations steps.
std_schedule = optax.exponential_decay(
    init_value=0.05, decay_rate=0.2, transition_steps=num_generations
)

# Evolution strategies to run, keyed by name. Each name maps to its own
# empty dict - presumably per-strategy keyword overrides read by
# run_experiment_permutations; confirm against that function.
es_names = (
    "LES",
    "EvoTF_ES",
    "DES",
    "SimpleES",
    "PGPE",
    "Open_ES",
    "SNES",
    "Sep_CMA_ES",
    "CMA_ES",
)
# The comprehension builds a fresh {} per key, so entries are never shared.
es_dict = {es_name: {} for es_name in es_names}

In [None]:
from experiment.run_experiments import run_experiment_permutations

# Launch the experiments with the problems, strategies and settings defined
# in the cells above. Presumably every strategy in es_dict is run on every
# problem and results are written under result_dir - confirm in
# experiment.run_experiments.
run_experiment_permutations(
    problems=problems,
    es_dict=es_dict,
    num_generations=num_generations,
    population_size=population_size,
    seed=seed,
    result_dir=result_dir,
)