In [1]:
# Expose 10 virtual host CPU devices to XLA. This is set before jax (or diffrax,
# which imports jax) is imported so the flag is seen when XLA's backend initializes.
import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=10'

import sys

import diffrax
import jax
import jax.numpy as jnp
import jax.random as jr

# Make the local kozax checkout importable. Set the KOZAX_PATH environment
# variable to point elsewhere instead of editing this machine-specific default.
KOZAX_PATH = os.environ.get(
    "KOZAX_PATH",
    "/Users/sigur.de.vries/Library/Mobile Documents/com~apple~CloudDocs/phd/kozax",
)
sys.path.append(KOZAX_PATH)

from kozax.genetic_programming import GeneticProgramming
from kozax.fitness_functions.ODE_fitness_function import ODEFitnessFunction
from kozax.environments.SR_environments.lotka_volterra import LotkaVolterra

These device(s) are detected:  [CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7), CpuDevice(id=8), CpuDevice(id=9)]


In [2]:
def get_data(key, env, dt, T, batch_size=20, dt0=0.001, max_steps=500, atol=1e-7, rtol=1e-7, dtmin=0.001):
    """Simulate ``env`` from random initial states to build a training set.

    Draws ``batch_size`` initial states with ``env.sample_init_states`` and
    integrates ``env.drift`` from each of them with diffrax's Dopri5 solver and
    a PID step-size controller, saving the state at every time in
    ``ts = jnp.arange(0, T, dt)``.

    Args:
        key: JAX PRNG key passed to ``env.sample_init_states``.
        env: Environment exposing ``sample_init_states(batch_size, key)`` and a
            ``drift`` callable used as the ``diffrax.ODETerm`` vector field.
        dt: Spacing of the saved time points.
        T: End of the time grid (exclusive, as in ``jnp.arange``); integration
            runs from ``ts[0]`` to ``ts[-1]``.
        batch_size: Number of trajectories to simulate.
        dt0: Initial step size handed to ``diffrax.diffeqsolve``.
        max_steps: Maximum number of solver steps per trajectory.
        atol: Absolute tolerance of the PID step-size controller.
        rtol: Relative tolerance of the PID step-size controller.
        dtmin: Minimum step size the controller may take.

    Returns:
        ``(x0s, ts, ys)``: the sampled initial states, the save times, and the
        solutions at ``ts`` for every initial state, stacked along a leading
        batch axis (one entry per element of ``x0s``).
    """
    x0s = env.sample_init_states(batch_size, key)
    ts = jnp.arange(0, T, dt)

    # The solver setup is identical for every trajectory, so build it once
    # rather than inside the vmapped function.
    solver = diffrax.Dopri5()
    term = diffrax.ODETerm(env.drift)
    saveat = diffrax.SaveAt(ts=ts)
    controller = diffrax.PIDController(atol=atol, rtol=rtol, dtmin=dtmin)

    def solve(x0):
        """Integrate a single initial state and return its states at ``ts``."""
        sol = diffrax.diffeqsolve(
            term,
            solver,
            ts[0],
            ts[-1],
            dt0,
            x0,
            saveat=saveat,
            max_steps=max_steps,
            adjoint=diffrax.DirectAdjoint(),
            stepsize_controller=controller,
        )
        return sol.ys

    # Only the initial states are batched; env, ts and the solver config are
    # shared by closure.
    ys = jax.vmap(solve)(x0s)

    return x0s, ts, ys

In [None]:
# Random keys: one stream for generating the data, one for the GP search.
key = jr.PRNGKey(0)
data_key, gp_key = jr.split(key)

# Time horizon and save interval of the simulated trajectories.
T = 50
dt = 0.2
env = LotkaVolterra()

# Each entry: (symbol, binary function, arity, weight). The last value is
# presumably a relative sampling weight for the operator — confirm in kozax.
operator_list = [
    ("+", lambda x, y: jnp.add(x, y), 2, 0.5),
    ("-", lambda x, y: jnp.subtract(x, y), 2, 0.1),
    ("*", lambda x, y: jnp.multiply(x, y), 2, 0.5),
]

# Names x0 .. x{n_var-1}, wrapped in an outer list (presumably one list per
# entry of layer_sizes — confirm in kozax).
variable_list = [[f"x{i}" for i in range(env.n_var)]]

# Search budget.
population_size = 300
num_populations = 10
num_generations = 50

# ODE solver settings used when scoring candidate models.
fitness_function = ODEFitnessFunction(
    solver=diffrax.Dopri5(),
    dt0=0.01,
    stepsize_controller=diffrax.PIDController(atol=1e-6, rtol=1e-6, dtmin=0.001),
    max_steps=300,
)

# One output tree per state variable.
layer_sizes = jnp.array([env.n_var])

gp_options = dict(
    num_populations=num_populations,
    max_nodes=15,
    constant_optimization_method="evolution",
    constant_optimization_N_offspring=50,
    constant_optimization_steps=1,
    size_parsimony=0.003,
    optimize_constants_elite=250,
    constant_step_size_init=0.2,
)
strategy = GeneticProgramming(
    num_generations,
    population_size,
    fitness_function,
    operator_list,
    variable_list,
    layer_sizes,
    **gp_options,
)

# Number of simulated trajectories used as training data.
num_trajectories = 4
data = get_data(data_key, env, dt, T, batch_size=num_trajectories)

strategy.fit(gp_key, data, verbose=True)

Input data should be formatted as: ['x0', 'x1'].
In generation 1, best fitness = 1.7050, best solution = [-1.39*x0, 0.616 - 0.426*x1]
In generation 2, best fitness = 1.6697, best solution = [1.56 - 2.3*x0, 0.46 - 0.379*x1]
In generation 3, best fitness = 1.6618, best solution = [0.224 - 2.43*x0, x0 - 0.56*x1 + 0.743]
In generation 4, best fitness = 1.6595, best solution = [-0.121*x0*x1**2 + 0.248, 0.541*x0 - 0.336*x1]
In generation 5, best fitness = 1.6582, best solution = [0.903 - 1.97*x0, 1.31*x0 - 0.581*x1]
In generation 6, best fitness = 1.6491, best solution = [1.0 - 2.52*x0, x0 - 0.562*x1 + 0.36]
In generation 7, best fitness = 1.6483, best solution = [0.901 - 2.34*x0, x0 - 0.549*x1 + 0.389]
In generation 8, best fitness = 1.6245, best solution = [-0.352*x0*(2*x0 + x1 - 0.115) + x0, x0 - 0.359*x1]
In generation 9, best fitness = 1.6244, best solution = [-0.4*x0*(2*x0 + x1 - 0.382) + x0, x0 - 0.351*x1]
In generation 10, best fitness = 1.6244, best solution = [-0.4*x0*(2*x0 + x1 - 