In [None]:
import jax
import jax.numpy as jnp
import jax.random as jr
import gymnax

from kozax.genetic_programming import GeneticProgramming

In [None]:
class GymFitnessFunction:
    """Fitness function that scores a candidate by how quickly its policy reaches a gymnax goal.

    The candidate is turned into a discrete 3-action policy from the sign of the first
    output of ``tree_evaluator``. Each rollout is scored as the number of steps taken
    before the first rewarded (goal) step, divided by ``num_steps``. A rollout that never
    reaches the goal scores 1.0, so smaller values mean the goal was reached sooner.
    """

    def __init__(self, env_name, num_steps=200) -> None:
        """Create the gymnax environment.

        Args:
            env_name: gymnax environment id passed to ``gymnax.make`` (e.g. "Acrobot-v1").
            num_steps: number of environment steps simulated per rollout (default 200).
        """
        self.env, self.env_params = gymnax.make(env_name)
        self.num_steps = num_steps

    def __call__(self, candidate, keys, tree_evaluator):
        """Mean rollout score of ``candidate`` over a batch of PRNG keys.

        Args:
            candidate: candidate program, forwarded unchanged to ``tree_evaluator``.
            keys: stacked PRNG keys; one rollout is simulated per key (vmapped over axis 0).
            tree_evaluator: callable ``(candidate, observation) -> outputs``.

        Returns:
            Scalar mean of the per-rollout scores returned by ``simulate_rollout``.
        """
        scores = jax.vmap(self.simulate_rollout, in_axes=(None, 0, None))(candidate, keys, tree_evaluator)
        return jnp.mean(scores)

    def simulate_rollout(self, candidate, key, tree_evaluator):
        """Run one ``num_steps``-long rollout and return its normalized time-to-goal.

        Args:
            candidate: candidate program, forwarded to ``tree_evaluator``.
            key: PRNG key used for the reset and for every environment step.
            tree_evaluator: callable ``(candidate, observation) -> outputs``; only
                ``outputs[0]`` is used.

        Returns:
            ``first_goal_step / num_steps`` if some step's reward exceeds the worst
            reward of the rollout, otherwise ``1.0``.
        """
        key, reset_key = jr.split(key)
        obs, env_state = self.env.reset(reset_key, self.env_params)

        def policy(obs):
            # Map the sign of the first output to an action: <0 -> 0, ==0 -> 1, >0 -> 2.
            # NaN outputs (possible from "/" or "**") fail both comparisons and map to 0.
            a = tree_evaluator(candidate, obs)[0]
            return jnp.where(a == 0, 1, jnp.where(a > 0, 2, 0))

        def step_fn(carry, _):
            obs, env_state, key = carry
            action = policy(obs)

            key, step_key = jr.split(key)
            next_obs, next_env_state, reward, _, _ = self.env.step(
                step_key, env_state, action, self.env_params
            )
            # Only rewards are stacked; the visited observations were never used downstream.
            return (next_obs, next_env_state, key), reward

        _, rewards = jax.lax.scan(
            step_fn, (obs, env_state, key), None, length=self.num_steps
        )

        # argmax returns the FIRST index of the best reward. Index 0 is a valid goal step,
        # so "no goal reached" is detected explicitly: every reward equals the worst one.
        first_best = jnp.argmax(rewards)
        reached_goal = rewards[first_best] > jnp.min(rewards)
        steps_to_goal = jnp.where(reached_goal, first_best, self.num_steps)
        return steps_to_goal / self.num_steps
    
# Fitness function backed by the gymnax "Acrobot-v1" environment.
fitness_function = GymFitnessFunction(env_name="Acrobot-v1")

In [None]:
# --- Hyperparameters -----------------------------------------------------------
num_generations = 50
population_size = 100
# NOTE(review): num_populations is never passed to GeneticProgramming below, so it
# currently has no effect. Pass it if the installed kozax supports it, or remove it.
num_populations = 10
# Number of PRNG keys (i.e. rollouts) split off for each fitness evaluation.
batch_size = 4

# --- Primitive set -------------------------------------------------------------
# Each entry is (symbol, function, arity, weight). The last field is presumably the
# relative sampling probability; confirm against the kozax documentation.
operator_list = [
    ("+",   lambda lhs, rhs: jnp.add(lhs, rhs),      2, 0.5),
    ("-",   lambda lhs, rhs: jnp.subtract(lhs, rhs), 2, 0.1),
    ("*",   lambda lhs, rhs: jnp.multiply(lhs, rhs), 2, 0.5),
    ("**",  lambda lhs, rhs: jnp.power(lhs, rhs),    2, 0.1),
    ("/",   lambda lhs, rhs: jnp.divide(lhs, rhs),   2, 0.1),
    ("sin", lambda operand: jnp.sin(operand),        1, 0.1),
]

# Six input variable names, presumably one per entry of the environment
# observation. TODO: confirm the observation size.
variable_list = [["y1", "y2", "y3", "y4", "y5", "y6"]]

# --- Strategy ------------------------------------------------------------------
strategy = GeneticProgramming(
    num_generations,
    population_size,
    fitness_function,
    operator_list,
    variable_list,
)

In [None]:
# Derive every random stream from a single root key so the run is reproducible.
key = jr.PRNGKey(0)
key, init_key, data_key = jr.split(key, 3)

# The same batch_keys are reused for the evaluation of every generation.
batch_keys = jr.split(data_key, batch_size)

# Initial population
population = strategy.initialize_population(init_key)

for g in range(num_generations):
    key, eval_key, sample_key = jr.split(key, 3)
    fitness, population = strategy.evaluate_population(population, batch_keys, eval_key)

    # The final generation is only evaluated. Evolving it would produce a
    # population that is never scored.
    if g + 1 < num_generations:
        population = strategy.evolve(population, fitness, sample_key)

strategy.print_pareto_front()