In [1]:
import os
import sys

# Make the project root importable so the local `utils` and `experiment`
# packages resolve. The default keeps the original location; set the
# ES_BASELINES_ROOT environment variable to run from another checkout.
PROJECT_ROOT = os.environ.get(
    "ES_BASELINES_ROOT",
    "/mnt/new_home/ronedr/evolution-strategy-baselines-comparison",
)
# Guard the append so re-running this cell does not keep growing sys.path.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [2]:
from tqdm import tqdm
import brax.envs as brax_envs
from evosax.problems import BraxProblem as Problem
from evosax.problems.networks import MLP
from utils.problem_utils import get_problem_settings
from experiment.run_experiments import run_experiment_permutations
import jax

In [3]:
# Evolution strategies to benchmark. Each name maps to its own (initially
# empty) options dict; presumably empty means "library defaults" — confirm
# against run_experiment_permutations. Key order follows the tuple below.
es_dict = {
    es_name: {}
    for es_name in (
        "SimpleES",
        "LES",
        "DES",
        "EvoTF_ES",
        "PGPE",
        "Open_ES",
        "SNES",
        "Sep_CMA_ES",
        "CMA_ES",
    )
}

# Run budget shared by every (environment, strategy, seed) combination.
num_generations = 512
population_size = 128

# Seeds 0..4 inclusive.
seeds = list(range(5))

# Output location for experiment results (relative to the notebook's cwd).
result_dir = "../experiment_results"

# NOTE(review): `_envs` is brax's private environment registry; iterating a
# dict yields its keys, so this is the list of all registered env names.
problems_brax_envs = list(brax_envs._envs)

In [None]:
# Policy architecture and rollout settings shared by every Brax problem.
# The output layer width is environment-specific (from get_problem_settings).
HIDDEN_LAYER_SIZES = (32, 32, 32, 32)
EPISODE_LENGTH = 1000
BRAX_BACKEND = "generalized"

for env_name in tqdm(problems_brax_envs, desc="Loading Problems .."):
    # Loading is best-effort per environment: any failure while resolving the
    # settings or building the problem skips this env instead of aborting the
    # whole sweep. get_problem_settings is inside the try for the same reason.
    try:
        action_num, out_fn = get_problem_settings(env_name)
        problem = Problem(
            env_name=env_name,
            policy=MLP(
                layer_sizes=(*HIDDEN_LAYER_SIZES, action_num),
                output_fn=out_fn,
            ),
            episode_length=EPISODE_LENGTH,
            env_kwargs={"backend": BRAX_BACKEND},
        )
    except Exception as e:
        # tqdm.write keeps the progress bar intact, unlike print().
        tqdm.write(f"Failed to load: {env_name} {e}")
        continue

    tqdm.write(f"Successfully loaded: {env_name}")

    # Run each (strategy, seed) pair separately. One failing run is reported
    # with its real cause and does not skip the remaining runs for this env.
    for es_name, es_config in es_dict.items():
        for seed in seeds:
            try:
                run_experiment_permutations(
                    problems=[problem],
                    es_dict={es_name: es_config},
                    num_generations=num_generations,
                    population_size=population_size,
                    seed=seed,
                    result_dir=result_dir,
                    run_again_if_exist=False,
                )
            except Exception as e:
                tqdm.write(
                    f"Failed to run: {env_name} {es_name} seed={seed} {e}"
                )
