# In [None]:
import os

import jax
import jax.numpy as jnp
import numpy as np
from evosax.algorithms import CMA_ES
from evosax.algorithms.distribution_based.learned_es import LearnedES
from evosax.problems import BBOBProblem as Problem, bbob_fns

# Config
# Number of outer-loop (CMA-ES) generations run by main().
meta_gens = 100
# Population size of the outer CMA-ES.
meta_pop = 32
# Generations per inner rollout (used as the scan length in eval_inner).
inner_gens = 50
# Population size of each inner LearnedES instance.
inner_pop = 16
# Dimensionality of the sampled BBOB problems and of the outer solution vector.
dim = 2
# Seed for the root PRNG key created in main().
seed = 0
# Output directory for the saved parameters; created at import time if missing.
result_dir = "./les_meta_train"
os.makedirs(result_dir, exist_ok=True)


def sample_task(key):
    """Draw one BBOB problem at random.

    Args:
        key: JAX PRNG key used to pick the function.

    Returns:
        A ``Problem`` for the chosen function name with ``num_dims=dim`` and
        ``sample_rotations=False``.
    """
    fn_names = tuple(bbob_fns)
    # The draw is converted to a Python int so it can index the name tuple;
    # this runs eagerly on the host.
    pick = int(jax.random.choice(key, len(fn_names), shape=()))
    return Problem(fn_name=fn_names[pick], num_dims=dim, sample_rotations=False)


def eval_inner(es, params, state, prob, prob_state, key, num_gens=None):
    """Run an inner ask/eval/tell rollout and summarize its per-generation best.

    Args:
        es: Strategy object exposing ``ask(key, state, params)`` returning
            ``(population, state)`` and ``tell(key, population, fitness,
            state, params)`` returning ``(state, metrics)``.
        params: Strategy parameters forwarded unchanged to ``ask``/``tell``.
        state: Initial strategy state.
        prob: Problem object exposing ``eval(key, population, problem_state)``
            returning ``(fitness, problem_state, info)``.
        prob_state: Initial problem state.
        key: PRNG key driving the whole rollout.
        num_gens: Number of generations to run. Defaults to the module-level
            ``inner_gens`` when ``None``.

    Returns:
        Scalar array: the mean over generations of each generation's minimum
        fitness.
    """
    if num_gens is None:
        num_gens = inner_gens

    def step(carry, _):
        es_state, problem_state, step_key = carry
        # Give ask, eval and tell their own subkeys; sharing one key between
        # eval and tell would make their random draws identical.
        step_key, ask_key, eval_key, tell_key = jax.random.split(step_key, 4)
        population, es_state = es.ask(ask_key, es_state, params)
        fitness, problem_state, _ = prob.eval(eval_key, population, problem_state)
        es_state, _ = es.tell(tell_key, population, fitness, es_state, params)
        return (es_state, problem_state, step_key), jnp.min(fitness)

    # Only the stacked per-generation minima are needed; the final carry is dropped.
    _, best_per_gen = jax.lax.scan(
        step, (state, prob_state, key), None, length=num_gens)
    return jnp.mean(best_per_gen)


def main(num_tasks=4):
    """Run the outer CMA-ES loop over LearnedES candidates and save the result.

    Each outer generation asks CMA-ES for ``meta_pop`` candidates, scores every
    candidate with ``eval_inner`` on ``num_tasks`` randomly sampled BBOB
    problems, and feeds the scores back to CMA-ES. After ``meta_gens``
    generations the best candidate of the final generation is written to
    ``best_les_theta.npy`` inside ``result_dir``.

    Args:
        num_tasks: Number of randomly sampled tasks each candidate is scored on.
    """
    key = jax.random.PRNGKey(seed)
    dummy_solution = jnp.zeros(dim)

    # Outer ES (CMA-ES) meta-optimizing LES parameters
    # NOTE(review): the outer search space is `dim`-dimensional and each
    # candidate is only handed to LearnedES as `solution`, while the LES
    # parameters come from `les.default_params`. Confirm that the candidate
    # is actually wired into the LES meta-parameters as intended.
    outer = CMA_ES(population_size=meta_pop, solution=dummy_solution)
    theta = outer.default_params
    # Use a dedicated subkey for init so the root key is never both consumed
    # and split.
    key, outer_init_key = jax.random.split(key)
    state_outer = outer.init(outer_init_key, outer.solution, theta)

    for gen in range(meta_gens):
        key, ask_key = jax.random.split(key)
        solutions, state_outer = outer.ask(ask_key, state_outer, theta)

        fitness_list = []
        for sol in solutions:
            les = LearnedES(population_size=inner_pop, solution=sol)
            les_params = les.default_params
            key, init_key = jax.random.split(key)
            state_inner = les.init(init_key, les.solution, les_params)

            task_scores = []
            for _ in range(num_tasks):
                key, tk, ek, sk = jax.random.split(key, 4)
                prob = sample_task(tk)
                prob_state = prob.init(sk)
                score = eval_inner(les, les_params, state_inner, prob, prob_state, ek)
                task_scores.append(score)
            # Negated mean loss: higher meta-fitness means lower BBOB loss.
            fitness_list.append(-jnp.mean(jnp.stack(task_scores)))

        fitness_arr = jnp.stack(fitness_list)
        key, tell_key = jax.random.split(key)
        # tell returns (state, metrics), as in eval_inner; keep only the state.
        # The optimizer minimizes, so hand it the loss (negated meta-fitness)
        # rather than the higher-is-better value used for reporting.
        state_outer, _ = outer.tell(
            tell_key, solutions, -fitness_arr, state_outer, theta)
        print(f"[Meta gen {gen:03d}] meta-fitness = {fitness_arr.mean():.4f}")

    # Save best meta-learned LES params (best candidate of the final generation)
    best_theta = np.array(solutions[jnp.argmax(fitness_arr)])
    np.save(os.path.join(result_dir, "best_les_theta.npy"), best_theta)
    print("Saved best LES parameters to best_les_theta.npy")


# Run meta-training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()