In [1]:
import jax
import jax.numpy as jnp
from evosax.problems import (
    MetaBBOBProblem as MetaProblem,
)


In [None]:
# Distribution over BBOB tasks that every ES below is benchmarked on.
# min_num_dims / max_num_dims bound the problem dimensionality (presumably
# each sampled task draws its dimension from this range — confirm in the
# evosax MetaBBOBProblem docs).
meta_problem = MetaProblem(
    min_num_dims=2,
    max_num_dims=16,
    fn_names=[
        # Separable (BBOB f1-f5)
        "sphere",
        "ellipsoidal",
        "rastrigin",
        "bueche_rastrigin",
        "linear_slope",
        # Low / moderate conditioning (BBOB f6-f9)
        "attractive_sector",
        "step_ellipsoidal",
        "rosenbrock",
        "rosenbrock_rotated",
        # High conditioning, unimodal (BBOB f10-f14)
        "ellipsoidal_rotated",
        "discus",
        "bent_cigar",
        "sharp_ridge",
        "different_powers",
        # Multimodal, adequate global structure (BBOB f15-f19)
        "rastrigin_rotated",
        "weierstrass",
        "schaffers_f7",
        "schaffers_f7_ill_cond",
        "griewank_rosenbrock",
        # Multimodal, weak global structure (BBOB f23, f24)
        "katsuura",
        "lunacek",
    ],
)

In [None]:
import optax
from evosax.algorithms import algorithms

# Benchmark size: tasks sampled from meta_problem, candidates per generation,
# and generations per ES run.
num_tasks = 2**7
population_size = 2**10
num_generations = 2**13

# ES name -> extra constructor kwargs. Each name is looked up in evosax's
# `algorithms` registry by the benchmark loop.
es_dict = dict(
    SimpleES={},
    PGPE={},
    Open_ES=dict(optimizer=optax.adam(1e-3)),
    SNES={},
    Sep_CMA_ES={},
    CMA_ES={},
)

# ES name -> per-generation metrics averaged over all tasks (filled by the loop).
results = dict()

# Root PRNG key for the whole experiment. Nothing earlier in the notebook
# defines `key`, so without this the first split below raises NameError on a
# fresh kernel. A fixed seed keeps Restart & Run All reproducible.
seed = 0
key = jax.random.key(seed)

# Sample one set of BBOB problem parameters per task.
key, subkey = jax.random.split(key)
keys = jax.random.split(subkey, num_tasks)
problem_params = jax.vmap(meta_problem.sample_params)(keys)
# Force the optimal value f_opt to 0 for every task (presumably so raw fitness
# is comparable across tasks — confirm against evosax's BBOB definition).
problem_params = problem_params.replace(f_opt=jnp.zeros_like(problem_params.f_opt))

# Initialise one problem state per task from its sampled parameters.
key, subkey = jax.random.split(key)
keys = jax.random.split(subkey, num_tasks)
problem_state = jax.vmap(meta_problem.init)(keys, problem_params)

# Template solution handed to each ES constructor as `solution=`.
key, subkey = jax.random.split(key)
solution = meta_problem.sample(subkey)

# Sample one candidate solution per task.
key, subkey = jax.random.split(key)
keys = jax.random.split(subkey, num_tasks)
solutions = jax.vmap(meta_problem.sample)(keys)

# Benchmark every ES in es_dict on all sampled BBOB tasks.
for es_name, es_kwargs in es_dict.items():
    print(f"Running {es_name}...")

    # Look up the ES class by name and build it with its extra kwargs.
    ES = algorithms[es_name]
    es = ES(
        population_size=population_size,
        solution=solution,
        **es_kwargs,
    )
    params = es.default_params

    # Named `evaluate_task` rather than `eval`: this cell runs at notebook top
    # level, so `eval` would shadow Python's builtin for the rest of the session.
    @jax.jit
    def evaluate_task(key, params, problem_state, problem_params):
        """Run `es` from a fresh state on one task for `num_generations` steps.

        Args:
            key: PRNG key for this task.
            params: ES hyper-parameters (shared across tasks).
            problem_state: this task's slice of the vmapped problem state.
            problem_params: this task's slice of the vmapped problem params.

        Returns:
            (metrics, final_mean): the per-generation metrics returned by
            `es.tell`, stacked along a leading generation axis, and the ES
            state's `mean` after the last generation.
        """
        # Three independent keys, so the initial solution, the ES init and the
        # per-generation keys never share randomness.
        key_sample, key_init, key_scan = jax.random.split(key, 3)
        init_solution = jnp.clip(meta_problem.sample(key_sample), -4.0, 4.0)
        state = es.init(key_init, init_solution, params)

        def step(carry, step_key):
            """One ask -> eval -> tell generation; emits only the metrics."""
            state, problem_state = carry
            key_ask, key_eval, key_tell = jax.random.split(step_key, 3)

            population, state = es.ask(key_ask, state, params)
            # Keep candidates inside [-5, 5] (the standard BBOB search box).
            population = jnp.clip(population, -5.0, 5.0)
            fitness, problem_state, _ = meta_problem.eval(
                key_eval, population, problem_state, problem_params
            )
            state, metrics = es.tell(key_tell, population, fitness, state, params)

            # Emitting the whole ES state here would stack num_generations
            # copies of it per task; only the final mean is needed, and that
            # is available from the final carry.
            return (state, problem_state), metrics

        step_keys = jax.random.split(key_scan, num_generations)
        (state, problem_state), metrics = jax.lax.scan(
            step, (state, problem_state), step_keys
        )

        return metrics, state.mean

    # Run evaluation across all tasks: one key / problem slice per task,
    # ES params broadcast to every task.
    key, subkey = jax.random.split(key)
    keys = jax.random.split(subkey, num_tasks)
    metrics_batch, final_means = jax.vmap(evaluate_task, in_axes=(0, None, 0, 0))(
        keys, params, problem_state, problem_params
    )

    # Average metrics across tasks (leading axis of every leaf).
    metrics = jax.tree.map(lambda x: jnp.mean(x, axis=0), metrics_batch)

    # Store the results
    results[es_name] = metrics
