In [1]:
import os
import sys

# Make the repository root importable; presumably this is what lets the
# project-local `utils` / `experiment` packages in the imports cell resolve.
# Set ESBC_PROJECT_ROOT to point at a checkout on another machine; the default
# preserves the original hardcoded location.
PROJECT_ROOT = os.environ.get(
    "ESBC_PROJECT_ROOT",
    "/mnt/new_home/ronedr/evolution-strategy-baselines-comparison",
)
# Re-running this cell must not keep appending the same entry to sys.path.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [None]:
import optax
from tqdm import tqdm
import brax.envs as brax_envs
from evosax.problems import BraxProblem as Problem
from evosax.problems.networks import MLP
from utils.problem_utils import get_problem_settings
from experiment.run_experiments import run_experiment_permutations

In [None]:
# Sweep configuration shared by every (task, ES) run in this notebook.
num_generations, population_size = 64, 128  # generations per run; population size per ES
seed = 0  # PRNG seed: builds the jax.random key and is forwarded to run_experiment_permutations
result_dir = "../experiment_results"  # relative path, resolved against the kernel's working directory

In [None]:
# Image-classification setup. NOTE: this import deliberately rebinds `Problem`
# to TorchVisionProblem, shadowing the `BraxProblem as Problem` alias from the
# imports cell; the experiment loop constructs `Problem(task_name=..., ...)`.
from evosax.problems import CNN, TorchVisionProblem as Problem, identity_output_fn

# Network shared by every task: presumably two conv layers (8 then 16 filters,
# 5x5 kernels, stride 1) followed by one dense layer of width 10. The outputs are
# presumably left untransformed by identity_output_fn (raw class scores) — confirm
# against the evosax CNN docs.
network = CNN(
    num_filters=[8, 16],
    kernel_sizes=[(5, 5)] * 2,
    strides=[(1, 1)] * 2,
    mlp_layer_sizes=[10],
    output_fn=identity_output_fn,
)

# Evolution strategies to benchmark. Keys are names looked up in evosax's
# `algorithms` registry; values are presumably per-ES config overrides
# (currently none). Insertion order is the order in which they are run.
es_dict = {
    es_name: {}
    for es_name in (
        "LES",
        "SimpleES",
        "PGPE",
        "Open_ES",
        "SNES",
        "Sep_CMA_ES",
        "CMA_ES",
    )
}

In [None]:
import jax
import os
from utils.problem_utils import get_problem_name
from evosax.algorithms import algorithms

# Run every ES in `es_dict` on each image-classification task. A (task, ES) pair
# is skipped when its result JSON already exists under `result_dir`, so the cell
# can be re-run to resume an interrupted sweep.
# NOTE(review): the skip check assumes run_experiment_permutations writes its
# output to exactly this path — confirm.

# Loop-invariant: the same key was previously rebuilt for every (task, ES) pair.
template_key = jax.random.key(seed)

for task_name in tqdm(["MNIST", "FashionMNIST", "CIFAR10", "SVHN"], desc="Loading Problems .."):
    # Dataset/problem setup failures skip the whole task.
    try:
        problem = Problem(task_name=task_name, network=network, batch_size=1024)
        problem_name = get_problem_name(problem)
        # Template solution handed to each ES constructor. It depends only on the
        # task and a fixed key, so sample it once instead of once per ES.
        template_solution = problem.sample(template_key)
    except Exception as e:
        print("Failed to load:", task_name, e)
        continue
    print("Successfully loaded:", task_name)

    for es, es_params in es_dict.items():
        # Run failures are isolated per ES: one crashing strategy must not abort
        # the remaining ones for this task, nor be misreported as a load failure.
        try:
            # Instantiate the ES only to recover its concrete class name, which
            # names the result file.
            es_class_name = algorithms[es](
                population_size=population_size, solution=template_solution
            ).__class__.__name__
            es_path = f"{result_dir}/{problem_name}/{es_class_name}.json"
            print(es_path)
            if os.path.exists(es_path):
                print("Path exists")
                continue
            print("Path does not exist")
            # Forward this ES's configured entry rather than a hardcoded `{}`,
            # so overrides placed in `es_dict` are not silently dropped.
            run_experiment_permutations(problems=[problem],
                                        es_dict={es: es_params},
                                        num_generations=num_generations,
                                        population_size=population_size,
                                        seed=seed,
                                        result_dir=result_dir)
        except Exception as e:
            print("Failed to run:", task_name, es, e)