In [None]:
import jax
import diffrax
import jax.numpy as jnp
import jax.random as jrandom
import matplotlib.pyplot as plt

from MultiTreeGP.expression import Expression
import MultiTreeGP.evaluators.SR_evaluator as evaluator
from MultiTreeGP.environments.SR_environments.vd_pol_oscillator import VanDerPolOscillator
from MultiTreeGP.algorithms.genetic_programming import GeneticProgramming

SEED = 0  # seed of the root PRNG key used for data generation and evolution below
key = jrandom.PRNGKey(SEED)

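# Generate training data

Simulate a batch of Van der Pol oscillator trajectories (states `xs` and observations `ys`) that the symbolic model is fitted to.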

In [None]:
def get_data(key, env, T, batch_size=20, dt=0.1, dt0=0.001):
    """Simulate a batch of SDE trajectories of `env` together with their observations.

    Each trajectory starts from a state drawn by `env.sample_init_states`. It is
    integrated with diffrax's Euler-Heun solver under drift `env.drift` and
    diffusion `env.diffusion`, driven by a Brownian path. It is then mapped to
    observations by scanning `env.f_obs` over the saved time points.

    Args:
        key: JAX PRNG key; split into keys for the initial states, the process
            noise and the observation noise.
        env: environment exposing `sample_init_states`, `n_var`, `drift`,
            `diffusion` and `f_obs`.
        T: end of the save grid. States are saved at `jnp.arange(0, T, dt)`,
            so `T` itself is excluded.
        batch_size: number of trajectories to simulate.
        dt: spacing of the save grid (default 0.1, the previously hard-coded value).
        dt0: integrator step size (default 0.001, the previously hard-coded value).

    Returns:
        Tuple `(x0s, ts, xs, ys)`: the initial states, the shared save grid, the
        simulated states and the observations. The last two carry a leading
        batch axis of size `batch_size`.
    """
    init_key, process_key, obs_key = jrandom.split(key, 3)
    x0s = env.sample_init_states(batch_size, init_key)
    process_noise_keys = jrandom.split(process_key, batch_size)
    obs_noise_keys = jrandom.split(obs_key, batch_size)
    ts = jnp.arange(0, T, dt)

    # Give the solver enough steps to cover [0, T) even for a small dt0.
    # With the default T/dt0 this stays at the original budget of 16**5.
    max_steps = max(16**5, int(T / dt0) + 1)

    def solve(x0, process_noise_key, obs_noise_key):
        """Integrate a single trajectory from `x0` and produce its observations."""
        solver = diffrax.EulerHeun()
        saveat = diffrax.SaveAt(ts=ts)

        brownian_motion = diffrax.UnsafeBrownianPath(
            shape=(env.n_var,), key=process_noise_key, levy_area=diffrax.BrownianIncrement
        )
        system = diffrax.MultiTerm(
            diffrax.ODETerm(env.drift), diffrax.ControlTerm(env.diffusion, brownian_motion)
        )

        sol = diffrax.diffeqsolve(
            system,
            solver,
            ts[0],
            ts[-1],
            dt0,
            x0,
            saveat=saveat,
            max_steps=max_steps,
            adjoint=diffrax.DirectAdjoint(),
        )
        xs = sol.ys
        # f_obs is scanned over (t, x) pairs with the observation-noise key as the carry.
        _, ys = jax.lax.scan(env.f_obs, obs_noise_key, (ts, xs))
        return xs, ys

    # env and ts are shared by all trajectories (closed over); only x0 and the keys are batched.
    xs, ys = jax.vmap(solve)(x0s, process_noise_keys, obs_noise_keys)

    return x0s, ts, xs, ys



# Split off fresh keys for population initialisation and data generation, and
# advance `key` itself. The evolution loop later calls `jrandom.split(key)`;
# without advancing `key` here, that call would reproduce exactly
# (init_key, data_key) and reuse the same randomness.
key, init_key, data_key = jrandom.split(key, 3)

# NOTE(review): the meaning of the two constructor arguments is not visible here
# (presumably process/observation noise levels) — confirm against VanDerPolOscillator.
env = VanDerPolOscillator(0, 0)

T = 40  # end time of the simulated trajectories
x0s, ts, xs, ys = get_data(data_key, env, T=T, batch_size=10)

## Visualize trajectories

In [None]:
# Plot the first four training trajectories, one panel per trajectory.
fig, axes = plt.subplots(2, 2, sharex=True)
for i, ax_i in enumerate(axes.ravel()):
    for j in range(env.n_var):
        ax_i.plot(ts, xs[i, :, j], color=f"C{j}", label=f"$x_{j}$")
    ax_i.set(title=f"Trajectory {i}", xlabel="$t$")
# One legend suffices: every panel shows the same state variables.
axes[-1, -1].legend()
fig.tight_layout()
plt.show()

# Symbolic Regression

In [None]:
# Define hyperparameters
population_size = 50
num_populations = 3
pool_size = 8
num_generations = 10

# Define expressions
operators = ["+", "-", "*", "/", "square", "power"]
# Sampling weight per operator, aligned index-by-index with `operators`;
# "/" and "power" get weight 0, so they are effectively disabled.
# NOTE(review): the weights do not sum to 1 — presumably Expression normalises them; confirm.
operator_probs = jnp.array([0.5, 0.1, 0.5, 0., 0.1, 0])
assert len(operator_probs) == len(operators), "operator_probs needs one weight per operator"


def contains_variable(tree):
    """Return True if at least one leaf of `tree` mentions the state variable "x"."""
    return any("x" in str(leaf) for leaf in jax.tree_util.tree_leaves(tree))


expressions = [Expression([["x", env.n_var]], operators, operator_probs, condition=contains_variable)]
# One tree per state variable, so the learned vector field matches the state dimension
# (previously hard-coded as 2).
# NOTE(review): assumes layer_sizes counts output trees — confirm against GeneticProgramming.
layer_sizes = jnp.array([env.n_var])

# Define evaluator
fitness_function = evaluator.Evaluator(dt0=0.001)

In [None]:
# Build the genetic-programming strategy
strategy = GeneticProgramming(
    num_generations,
    population_size,
    fitness_function,
    expressions,
    layer_sizes,
    num_populations=num_populations,
    pool_size=pool_size,
    init_method="ramped",
    tournament_size=4,
    max_init_depth=4,
    max_depth=8,
    size_parsinomy=0.0,  # NOTE(review): spelling (sic) presumably matches the library keyword
    leaf_sd=1,
    migration_period=5,
    gradient_optimisation=True,
    gradient_steps=5,
)

# Create the initial population
population = strategy.initialize_population(init_key)

for g in range(num_generations):
    if g > 0:
        # Breed this generation from the one evaluated in the previous iteration.
        key, sample_key = jrandom.split(key)
        population = strategy.evolve_population(population, sample_key)

    fitnesses, population = strategy.evaluate_population(population, (x0s, ts, ys))
    best_fitness, best_solution = strategy.get_statistics(g)
    print(f"In generation {g+1}, best fitness = {best_fitness:.4f}, best solution = {best_solution}")

best_fitnesses, best_solutions = strategy.get_statistics()

## Visualize the best solution on held-out test data

In [None]:
# Generate held-out test data under separate names. Reusing x0s/ts/xs/ys here would
# overwrite the training data, so re-running the training cell afterwards would
# silently train on these 4 test trajectories.
x0s_test, ts_test, xs_test, ys_test = get_data(jrandom.PRNGKey(42), env, T=T, batch_size=4)

# Evaluate the best model of the final generation on the test data
best_model = best_solutions[-1].tree_to_function(expressions)
pred, fitness = fitness_function.evaluate_model(best_model, (x0s_test, ts_test, ys_test))

# One panel per test trajectory: true states next to the model prediction.
fig, axes = plt.subplots(2, 2, sharex=True)
for i, ax_i in enumerate(axes.ravel()):
    for j in range(env.n_var):
        ax_i.plot(ts_test, xs_test[i, :, j], color=f"C{j}", label=f"$x_{j}$")
        ax_i.plot(ts_test, pred[i, :, j], color=f"C{9-j}", label=f"$y_{j}$")
    ax_i.set(title=f"Test trajectory {i}", xlabel="$t$")
# One legend suffices: every panel shows the same series.
axes[-1, -1].legend()
fig.tight_layout()
plt.show()