# Symbolic regression of a dynamical system

In this example, Kozax is applied to recover the state equations of the Lotka-Volterra system. The candidate solutions are integrated as a system of differential equations, after which the predictions are compared to the true observations to determine a fitness score.
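For reference, the Lotka-Volterra system is commonly written in the following general form, where $x_0$ is the prey population and $x_1$ the predator population. The parameter symbols $\alpha, \beta, \delta, \gamma$ are generic placeholders used here for illustration; the values used by the Kozax environment are not stated in this example.

$$
\begin{aligned}
\dot{x}_0 &= \alpha x_0 - \beta x_0 x_1, \\
\dot{x}_1 &= \delta x_0 x_1 - \gamma x_1.
\end{aligned}
$$

A successful run should therefore return two expressions with this structure: a linear term and a bilinear term $x_0 x_1$ in each equation.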

In [1]:
# Specify the cores to use for XLA
import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=10'

import jax
import diffrax
import jax.numpy as jnp
import jax.random as jr

from kozax.genetic_programming import GeneticProgramming
from kozax.fitness_functions.ODE_fitness_function import ODEFitnessFunction
from kozax.environments.SR_environments.lotka_volterra import LotkaVolterra

These device(s) are detected: [CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7), CpuDevice(id=8), CpuDevice(id=9)]


First, the data are generated: initial conditions, time points and the true observations. Kozax provides the Lotka-Volterra environment, whose dynamics are integrated with Diffrax.

In [2]:
def get_data(key, env, dt, T, batch_size=20):
    x0s = env.sample_init_states(batch_size, key)
    ts = jnp.arange(0, T, dt)

    solver = diffrax.Dopri5()
    dt0 = 0.001
    saveat = diffrax.SaveAt(ts=ts)

    system = diffrax.ODETerm(env.drift)

    def solve(ts, x0):
        # Solve the system for a single initial condition
        sol = diffrax.diffeqsolve(
            system, solver, ts[0], ts[-1], dt0, x0,
            saveat=saveat,
            max_steps=500,
            adjoint=diffrax.DirectAdjoint(),
            stepsize_controller=diffrax.PIDController(atol=1e-7, rtol=1e-7, dtmin=0.001),
            args=jnp.array([0.0]),
        )

        return sol.ys

    # Vectorize the solve over the batch of initial conditions
    ys = jax.vmap(solve, in_axes=[None, 0])(ts, x0s)

    return x0s, ts, ys

key = jr.PRNGKey(0)
data_key, gp_key = jr.split(key)

T = 30
dt = 0.2
env = LotkaVolterra()

# Simulate the data
data = get_data(data_key, env, dt, T, batch_size=4)
x0s, ts, ys = data

For the fitness function, we use `ODEFitnessFunction`, which integrates candidate solutions with Diffrax. The solver, initial time step, maximum number of steps and step-size controller can be chosen to balance efficiency and accuracy. To help the genetic programming algorithm converge, constant optimization is applied to the best candidates in every generation. The constant optimization consists of a few gradient-based update steps, with gradients obtained through automatic differentiation in JAX. The hyperparameters that define the constant optimization are `constant_optimization_steps` (number of optimization iterations for each candidate), `optimize_constants_elite` (number of candidates that constant optimization is applied to), `constant_step_size_init` (initial value of the step size for sampling constants) and `optimizer_class` (the Optax optimizer that applies the gradient updates).
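To make the idea of an ODE-based fitness concrete, the following is a minimal sketch in plain Python. It is not Kozax's `ODEFitnessFunction`: it swaps Diffrax's adaptive solver for a fixed-step RK4 integrator, assumes mean absolute error as the error measure, and uses assumed reference parameters and an assumed initial state. The helper names (`rk4_step`, `integrate`, `fitness`) are hypothetical. The candidate is the complexity-6 expression pair from the final Pareto front printed in this example.

```python
# Illustrative sketch only: fixed-step RK4 and mean absolute error are
# assumptions, not the behaviour of Kozax's ODEFitnessFunction.

def lotka_volterra(x, alpha=1.1, beta=0.4, delta=0.1, gamma=0.4):
    """Reference drift; the parameter values are assumed for illustration."""
    x0, x1 = x
    return (alpha * x0 - beta * x0 * x1, delta * x0 * x1 - gamma * x1)


def candidate(x):
    """A too-simple candidate: pure linear decay in both states."""
    x0, x1 = x
    return (-0.649 * x0, -0.332 * x1)


def rk4_step(f, x, h):
    """Advance the state x by one classical Runge-Kutta step of size h."""
    k1 = f(x)
    k2 = f(tuple(xi + 0.5 * h * ki for xi, ki in zip(x, k1)))
    k3 = f(tuple(xi + 0.5 * h * ki for xi, ki in zip(x, k2)))
    k4 = f(tuple(xi + h * ki for xi, ki in zip(x, k3)))
    return tuple(
        xi + h / 6.0 * (a + 2.0 * b + 2.0 * c + d)
        for xi, a, b, c, d in zip(x, k1, k2, k3, k4)
    )


def integrate(f, x0, dt, n_points, substeps=20):
    """Return the states at n_points times spaced dt apart, starting at x0."""
    x = tuple(x0)
    xs = [x]
    h = dt / substeps
    for _ in range(n_points - 1):
        for _ in range(substeps):
            x = rk4_step(f, x, h)
        xs.append(x)
    return xs


def fitness(f, x0, ys, dt):
    """Mean absolute error between the integrated candidate and the observations."""
    pred = integrate(f, x0, dt, len(ys))
    errors = [abs(p - y) for ps, yo in zip(pred, ys) for p, y in zip(ps, yo)]
    return sum(errors) / len(errors)


x0 = (2.0, 1.0)  # assumed initial state
dt = 0.2
ys = integrate(lotka_volterra, x0, dt, n_points=50)  # synthetic "observations"

true_fitness = fitness(lotka_volterra, x0, ys, dt)
candidate_fitness = fitness(candidate, x0, ys, dt)
print(true_fitness, candidate_fitness)
```

Lower is better: the reference equations reproduce the observations exactly in this sketch, while the linear-decay candidate cannot follow the oscillations and receives a larger error.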

In [3]:
# Define the nodes and hyperparameters
operator_list = [
    ("+", lambda x, y: jnp.add(x, y), 2, 0.5),
    ("*", lambda x, y: jnp.multiply(x, y), 2, 0.5),
]

variable_list = [["x" + str(i) for i in range(env.n_var)]]
layer_sizes = jnp.array([env.n_var])

population_size = 100
num_populations = 5
num_generations = 50

import optax

# Initialize the fitness function and the genetic programming strategy
fitness_function = ODEFitnessFunction(
    solver=diffrax.Dopri5(),
    dt0=0.01,
    stepsize_controller=diffrax.PIDController(atol=1e-6, rtol=1e-6, dtmin=0.001),
    max_steps=300,
)

strategy = GeneticProgramming(
    num_generations, population_size, fitness_function, operator_list, variable_list, layer_sizes,
    num_populations=num_populations,
    constant_optimization_method="gradient",
    constant_optimization_steps=10,
    optimizer_class=optax.adam,
    optimize_constants_elite=100,
    constant_step_size_init=0.025,
)

Input data should be formatted as: ['x0', 'x1'].


Kozax provides a `fit` function that receives the data and a random key. However, it is also possible to run Kozax with a simple loop that alternates between evaluating and evolving the population. This is useful because different input data can be provided at each evaluation. In symbolic regression of dynamical systems, it helps to first optimize on a subset of the time points and to provide the full trajectories only later; here, the first half of the time points is used for the first 25 generations.
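The schedule for the number of time points can be factored into a small helper. This is a sketch only: `num_time_points` and its parameters are hypothetical names, not part of Kozax. The defaults mirror the loop in this example (half of the time points until generation 25).

```python
def num_time_points(generation, total_points, switch_generation=25, fraction=0.5):
    """Number of leading time points to evaluate on at a given generation."""
    if generation >= switch_generation:
        return total_points  # full trajectories from the switch onwards
    return max(1, int(total_points * fraction))  # truncated trajectories early on


# With T = 30 and dt = 0.2 there are 150 time points.
early = num_time_points(0, 150)
late = num_time_points(25, 150)
print(early, late)
```

The returned value would play the role of `end_ts` when slicing `ts` and `ys` before each evaluation.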

In [4]:
key = jr.PRNGKey(0)
data_key, gp_key = jr.split(key)

T = 30
dt = 0.2
env = LotkaVolterra()

# Simulate the data
data = get_data(data_key, env, dt, T, batch_size=4)
x0s, ts, ys = data

# Sample the initial population
population = strategy.initialize_population(gp_key)

# Start with only the first half of the time points
end_ts = ts.shape[0] // 2

for g in range(num_generations):
    if g == 25: # After 25 generations, use the full data
        end_ts = ts.shape[0]

    key, eval_key, sample_key = jr.split(key, 3)
    # Evaluate the population on the data, and return the fitness
    fitness, population = strategy.evaluate_population(population, (x0s, ts[:end_ts], ys[:,:end_ts]), eval_key)

    if g % 5 == 0:
        print("Generation:", g)
        strategy.print_pareto_front()

    # Evolve the population until the last generation. The fitness should be given to the evolve function.
    if g < (num_generations-1):
        population = strategy.evolve_population(population, fitness, sample_key)

Generation: 0
Complexity: 4, fitness: 31.509140014648438, equations: [x1 + 1.62, 0.423]
Complexity: 6, fitness: 5.7940993309021, equations: [x1 + 0.836, -0.376*x0]
Complexity: 8, fitness: 3.7546756267547607, equations: [0.262*x1, -0.78*x1]
Complexity: 12, fitness: 1.8652304410934448, equations: [-1.03*x0, -0.0371*x0*x1*(x0 + 0.608)]
Complexity: 14, fitness: 1.3812541961669922, equations: [-0.701*x0**2, 0.121*x0 - 0.302*x1]
Generation: 5
Complexity: 2, fitness: 3.1859934329986572, equations: [-0.894, -0.276]
Complexity: 4, fitness: 2.3859434127807617, equations: [-0.253*x0, -0.760]
Complexity: 6, fitness: 2.2180819511413574, equations: [-0.183*x1, -0.249*x1]
Complexity: 8, fitness: 1.3809990882873535, equations: [-0.284*x0*x1, -0.283*x1]
Complexity: 10, fitness: 1.3652344942092896, equations: [-0.0956*x0**2*x1, -0.29*x1]
Complexity: 12, fitness: 1.3651535511016846, equations: [-0.0273*x0**2*x1**2, -0.304*x1]
Complexity: 14, fitness: 1.2635600566864014, equations: [-0.0294*x0*x1**2, 0.40

In [5]:
strategy.print_pareto_front()

Complexity: 2, fitness: 3.115840435028076, equations: [-0.404, -0.411]
Complexity: 4, fitness: 2.1877434253692627, equations: [-1.6*x0, -1.04]
Complexity: 6, fitness: 1.5642895698547363, equations: [-0.649*x0, -0.332*x1]
Complexity: 8, fitness: 1.380143404006958, equations: [-1.78*x0, 0.188 - 0.353*x1]
Complexity: 10, fitness: 1.35891592502594, equations: [-0.0526*x0*x1**2, -0.301*x1]
Complexity: 12, fitness: 1.2722225189208984, equations: [-0.238*x0*x1, 0.579*x0 - 0.387*x1]
Complexity: 14, fitness: 1.2413502931594849, equations: [-0.309*x0*(x0 + x1 - 0.111) + x0, -0.28*x1]
Complexity: 16, fitness: 0.14472123980522156, equations: [-0.382*x0*x1 + x0, 0.103*x0*x1 - 0.431*x1]
Complexity: 20, fitness: 0.11848549544811249, equations: [-0.387*x0*(x1 - 0.0732) + x0, 0.102*x0*x1 - 0.42*x1]
Complexity: 22, fitness: 0.049823880195617676, equations: [-0.404*x0*(0.976*x1 - 0.211) + x0, 0.102*x0*x1 - 0.404*x1]
