# Deep Dive

The `Particle` and `System` classes are convenient for simple simulations, but they are built on top of much more flexible individual functions. Here we'll demonstrate how some of them come together to move our particles around.

```python
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

from astropy.time import Time

from jorbit.accelerations import create_newtonian_ephemeris_acceleration_func
from jorbit.accelerations.newtonian import newtonian_gravity
from jorbit.accelerations.gr import ppn_gravity
from jorbit.ephemeris import Ephemeris
from jorbit.integrators.ias15 import ias15_evolve, initialize_ias15_integrator_state
from jorbit.utils.states import SystemState, IAS15IntegratorState
```


Here's a simple setup with a handful of massless tracer particles and a few massive ones. Since we're not in the center-of-mass frame, the system as a whole will drift, and some of these particles are likely unbound. For our purposes, though, we'll just let them go and see what happens.
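If you'd rather remove that drift, one option is to shift everything into the center-of-mass frame before integrating. The sketch below is our own NumPy illustration, not a jorbit helper: `to_com_frame` is a hypothetical function whose arguments mirror the `SystemState` fields. It assumes the tracers are massless, so only the massive particles (weighted by GM, which is proportional to mass) define the barycenter.

```python
import numpy as np


def to_com_frame(massive_x, massive_v, tracer_x, tracer_v, log_gms):
    """Hypothetical helper: shift all particles into the massive particles' barycentric frame.

    Tracers are massless, so only the massive particles define the center of mass.
    GM is proportional to mass, so exp(log_gms) works as a weight.
    """
    w = np.exp(log_gms)[:, None]  # shape (n_massive, 1)
    com_x = np.sum(w * massive_x, axis=0) / np.sum(w)
    com_v = np.sum(w * massive_v, axis=0) / np.sum(w)
    return massive_x - com_x, massive_v - com_v, tracer_x - com_x, tracer_v - com_v


rng = np.random.default_rng(0)
mx, mv = rng.uniform(size=(3, 3)) * 10, rng.uniform(size=(3, 3))
tx, tv = rng.uniform(size=(10, 3)) * 10, rng.uniform(size=(10, 3))
log_gms = np.log(rng.uniform(size=3) * 1e-3)

mx2, mv2, tx2, tv2 = to_com_frame(mx, mv, tx, tv, log_gms)
# the GM-weighted momentum of the massive particles is now zero (up to rounding),
# so the barycenter no longer moves
print(np.sum(np.exp(log_gms)[:, None] * mv2, axis=0))
```

Since this is just a Galilean shift, all relative positions and velocities are unchanged; only the bulk motion is removed.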

```python
n_tracer_particles = 10
n_massive_particles = 3

# SystemState is the underlying state representation behind Particle and System
s = SystemState(
    tracer_positions=jax.random.uniform(jax.random.PRNGKey(0), (n_tracer_particles, 3)) * 10,
    tracer_velocities=jax.random.uniform(jax.random.PRNGKey(1), (n_tracer_particles, 3)),
    massive_positions=jax.random.uniform(jax.random.PRNGKey(2), (n_massive_particles, 3)) * 10,
    massive_velocities=jax.random.uniform(jax.random.PRNGKey(3), (n_massive_particles, 3)),
    log_gms=jnp.log(jax.random.uniform(jax.random.PRNGKey(3), (n_massive_particles,)) * 1e-3),
    acceleration_func_kwargs={},
    time=0.0,
)

# ias15_evolve accepts any function wrapped in jax.tree_util.Partial that takes a
# SystemState and returns an array of accelerations with the same shape as the
# positions, with the massive particles first and the tracers after them. The
# function may depend on time and/or on time-dependent parameters, which is why
# SystemState carries both acceleration_func_kwargs and time. In the usual solar
# system case, we use the time attribute to compute the positions of the perturbing
# planets/asteroids at the timestep in question.
acceleration_func = jax.tree_util.Partial(newtonian_gravity)

# the integrator needs the accelerations at the starting state
a0 = acceleration_func(s)
init_integrator = initialize_ias15_integrator_state(a0=a0)

# now we run it
positions, velocities, final_system_state, final_integrator_state = ias15_evolve(
    initial_system_state=s,
    times=jnp.linspace(0, 10, 10),  # output times, not step sizes: IAS15 picks its own timesteps
    acceleration_func=acceleration_func,
    initial_integrator_state=init_integrator,
)
```
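To make that contract concrete, here is a minimal sketch of a custom, time-dependent acceleration function. It's a standalone NumPy illustration under our own assumptions: `ToyState` is a hypothetical stand-in that copies the `SystemState` field names, and the `omega0`/`eps` keys in `acceleration_func_kwargs` are invented for this example. The force itself (a harmonic pull toward the origin, `a = -omega(t)**2 * x`, whose strength grows linearly in time) has no physical significance; the point is the output shape and the massive-first ordering.

```python
from typing import NamedTuple

import numpy as np


class ToyState(NamedTuple):
    """Hypothetical stand-in with the same field names as jorbit's SystemState."""
    tracer_positions: np.ndarray
    tracer_velocities: np.ndarray
    massive_positions: np.ndarray
    massive_velocities: np.ndarray
    log_gms: np.ndarray
    acceleration_func_kwargs: dict
    time: float


def harmonic_acceleration(state: ToyState) -> np.ndarray:
    """Toy time-dependent force a = -omega(t)**2 * x on every particle.

    omega(t) = omega0 * (1 + eps * t); omega0 and eps are hypothetical kwargs.
    """
    kw = state.acceleration_func_kwargs
    omega = kw["omega0"] * (1.0 + kw["eps"] * state.time)
    # massive particles first, then tracers: the ordering ias15_evolve expects
    x = np.concatenate([state.massive_positions, state.tracer_positions])
    return -(omega**2) * x


state = ToyState(
    tracer_positions=np.ones((4, 3)),
    tracer_velocities=np.zeros((4, 3)),
    massive_positions=2.0 * np.ones((2, 3)),
    massive_velocities=np.zeros((2, 3)),
    log_gms=np.log(np.full(2, 1e-3)),
    acceleration_func_kwargs={"omega0": 1.0, "eps": 0.0},
    time=0.0,
)
acc = harmonic_acceleration(state)
print(acc.shape)  # (6, 3): 2 massive rows, then 4 tracer rows
```

To try something like this with jorbit itself, you would presumably swap in `SystemState` and `jax.numpy` and wrap the function with `jax.tree_util.Partial` as above; we haven't checked this particular function against `ias15_evolve`.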

Right now, `newtonian_gravity` is the only built-in acceleration function that's optimized for large systems. It treats the massless tracer particles separately from the massive ones, so it never evaluates tracer-tracer interactions (which would contribute nothing anyway), and that lets us evaluate the accelerations of much larger systems cheaply. However, actually *integrating* those accelerations is still slower than ideal, so keep your systems small for now.
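Here's a rough sketch of why that split pays off. It is our own NumPy illustration, not jorbit's implementation: `pairwise_accel`, `split_gravity`, and `all_pairs_gravity` are hypothetical names. Each computes `a_i = sum_j GM_j (x_j - x_i) / |x_j - x_i|^3`. The split version only evaluates massive-massive and tracer-massive pairs; the brute-force version treats tracers as zero-GM particles and evaluates every pair, which gives the same answer at far higher cost.

```python
import numpy as np


def pairwise_accel(targets, sources, gms, exclude_self=False):
    """Acceleration of each target due to every source: sum_j GM_j (x_j - x_i) / |x_j - x_i|^3."""
    dx = sources[None, :, :] - targets[:, None, :]  # (n_targets, n_sources, 3)
    r2 = np.sum(dx**2, axis=-1)                      # (n_targets, n_sources)
    if exclude_self:
        # targets and sources are the same particles: push the diagonal to infinity
        # so each particle's self-term becomes 0 instead of 0/0
        r2 = r2 + np.diag(np.full(len(targets), np.inf))
    inv_r3 = r2**-1.5
    return np.sum(gms[None, :, None] * dx * inv_r3[:, :, None], axis=1)


def split_gravity(massive_x, tracer_x, gms):
    # n_m * n_m massive-massive pairs + n_t * n_m tracer-massive pairs
    a_massive = pairwise_accel(massive_x, massive_x, gms, exclude_self=True)
    a_tracer = pairwise_accel(tracer_x, massive_x, gms)
    return np.concatenate([a_massive, a_tracer])  # massive rows first


def all_pairs_gravity(massive_x, tracer_x, gms):
    # (n_m + n_t)**2 pairs, with the tracers given GM = 0
    x = np.concatenate([massive_x, tracer_x])
    all_gms = np.concatenate([gms, np.zeros(len(tracer_x))])
    return pairwise_accel(x, x, all_gms, exclude_self=True)


rng = np.random.default_rng(0)
massive_x = rng.uniform(size=(3, 3)) * 10
tracer_x = rng.uniform(size=(200, 3)) * 10
gms = rng.uniform(size=3) * 1e-3

a_split = split_gravity(massive_x, tracer_x, gms)
a_full = all_pairs_gravity(massive_x, tracer_x, gms)
print(np.allclose(a_split, a_full))  # True

# pair counts for a system with 1e6 tracers and 3 massive particles
n_t, n_m = 10**6, 3
print(n_m**2 + n_t * n_m)  # 3000009
print((n_m + n_t) ** 2)    # 1000006000009
```

For a million tracers and three massive bodies, the split needs about 3 million pair evaluations instead of about 10^12, which is the difference between milliseconds and something impractical.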

```python
# same as before, but now with far more tracer particles

n_tracer_particles = int(1e6)
n_massive_particles = 3

s = SystemState(
    tracer_positions=jax.random.uniform(jax.random.PRNGKey(0), (n_tracer_particles, 3)) * 10,
    tracer_velocities=jax.random.uniform(jax.random.PRNGKey(1), (n_tracer_particles, 3)),
    massive_positions=jax.random.uniform(jax.random.PRNGKey(2), (n_massive_particles, 3)) * 10,
    massive_velocities=jax.random.uniform(jax.random.PRNGKey(3), (n_massive_particles, 3)),
    log_gms=jnp.log(jax.random.uniform(jax.random.PRNGKey(3), (n_massive_particles,)) * 1e-3),
    acceleration_func_kwargs={},
    time=0.0,
)

acceleration_func = jax.tree_util.Partial(newtonian_gravity)

a0 = acceleration_func(s)  # run it once to compile

%timeit acceleration_func(s).block_until_ready()
```

```
10.2 ms ± 286 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
```


Within the solar system, rather than these vanilla gravitational acceleration functions, we use ones that also account for the perturbations from all the planets, as given by the DE440 ephemeris. This happens automatically in the `Particle` and `System` classes; here we'll do it by hand.

First, let's create an `Ephemeris` object that can extract data from our local copy of the DE440 ephemeris:

```python
eph = Ephemeris(
    ssos="default planets",
    earliest_time=Time("1980-01-01"),
    latest_time=Time("2050-01-01"),
)
```

This creates a user-facing object that can serve up the positions and velocities of the planets at any time within the requested range:

```python
eph.state(Time("2000-01-01"))
```

```
{'sun': {'x': <Quantity [-0.00713986, -0.00264396, -0.00092139] AU>,
  'v': <Quantity [ 5.37426823e-06, -6.76193952e-06, -3.03437408e-06] AU / d>,
  'log_gm': Array(-8.12544774, dtype=float64, weak_type=True)},
 'mercury': {'x': <Quantity [-0.14785222, -0.40063289, -0.198918  ] AU>,
  'v': <Quantity [ 0.02117455, -0.00551464, -0.00514067] AU / d>,
  'log_gm': Array(-23.73665301, dtype=float64, weak_type=True)},
 'venus': {'x': <Quantity [-0.7257697 , -0.03968176,  0.02789532] AU>,
  'v': <Quantity [ 0.00051933, -0.01851507, -0.0083622 ] AU / d>,
  'log_gm': Array(-21.045753, dtype=float64, weak_type=True)},
 'earth': {'x': <Quantity [-0.17567731,  0.88619693,  0.3844338 ] AU>,
  'v': <Quantity [-0.01722853, -0.00276646, -0.00119947] AU / d>,
  'log_gm': Array(-20.84118348, dtype=float64, weak_type=True)},
 'moon': {'x': <Quantity [-0.17780043,  0.88461595,  0.3840147 ] AU>,
  'v': <Quantity [-0.01690458, -0.0031899 , -0.0013841 ] AU / d>,
  'log_gm': Array(-25.23933649, dtype=float64, 
```
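The `log_gm` entries are easier to read once exponentiated. A quick check, assuming (as the AU and AU/d units on `x` and `v` suggest) that they are natural logs of GM in AU³/day²: the Sun's value should then match the square of the Gaussian gravitational constant k = 0.01720209895, and ratios should reproduce familiar mass ratios. The numbers below are copied from the output above.

```python
import math

# log_gm values copied from the eph.state(...) output above
log_gm = {"sun": -8.12544774, "earth": -20.84118348, "moon": -25.23933649}
gm = {name: math.exp(value) for name, value in log_gm.items()}

k = 0.01720209895  # Gaussian gravitational constant; k**2 is GM_sun in AU^3/day^2

print(gm["sun"], k**2)           # both about 2.9591e-4
print(gm["sun"] / gm["earth"])   # about 332946, the Sun/Earth mass ratio
print(gm["earth"] / gm["moon"])  # about 81.3, the Earth/Moon mass ratio
```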

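It also carries a pytree-compatible JAX counterpart with the same functionality, an `EphemerisProcessor`, available as `eph.processor`. Its `state` method takes a TDB Julian date instead of an astropy `Time` and returns plain arrays of positions and velocities, one row per body: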

```python
eph.processor.state(Time("2000-01-01").tdb.jd)
```

```
(Array([[-7.13986335e-03, -2.64396337e-03, -9.21394198e-04],
        [-1.47852217e-01, -4.00632892e-01, -1.98918003e-01],
        [-7.25769699e-01, -3.96817640e-02,  2.78953240e-02],
        [-1.75677314e-01,  8.86196930e-01,  3.84433804e-01],
        [-1.77800434e-01,  8.84615947e-01,  3.84014702e-01],
        [ 1.38322176e+00, -8.13948942e-03, -4.10352972e-02],
        [ 3.99631685e+00,  2.73099757e+00,  1.07327637e+00],
        [ 6.40141168e+00,  6.17025198e+00,  2.27302953e+00],
        [ 1.44233796e+01, -1.25101393e+01, -5.68313086e+00],
        [ 1.68036194e+01, -2.29835774e+01, -9.82565798e+00],
        [-9.88400421e+00, -2.79809491e+01, -5.75398118e+00]],      dtype=float64),
 Array([[ 5.37426823e-06, -6.76193952e-06, -3.03437408e-06],
        [ 2.11745508e-02, -5.51463941e-03, -5.14066968e-03],
        [ 5.19329969e-04, -1.85150738e-02, -8.36219771e-03],
        [-1.72285335e-02, -2.76645660e-03, -1.19946950e-03],
        [-1.69045775e-02, -3.18990180e-03, -1.38409671e-03],
```
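"Pytree-compatible" means JAX knows how to flatten the object into arrays and rebuild it, so it can be passed straight through `jax.jit` and other transformations as an ordinary argument. A minimal sketch of that pattern, using a hypothetical `ToyCircularEphemeris` that puts planets on circular, coplanar orbits (the radii, periods, and GMs are rough illustrative values, and this is not how `EphemerisProcessor` is actually implemented):

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class ToyCircularEphemeris:
    """Hypothetical ephemeris: planets on circular, coplanar orbits centered on the origin."""

    def __init__(self, radii, periods, log_gms):
        self.radii = radii      # AU
        self.periods = periods  # days
        self.log_gms = log_gms  # ln(GM) with GM in AU^3/day^2

    def state(self, t):
        """Positions and velocities at time t (days), one row per planet."""
        omega = 2 * jnp.pi / self.periods
        phase = omega * t
        zeros = jnp.zeros_like(phase)
        x = jnp.stack([self.radii * jnp.cos(phase), self.radii * jnp.sin(phase), zeros], axis=-1)
        v = jnp.stack(
            [-self.radii * omega * jnp.sin(phase), self.radii * omega * jnp.cos(phase), zeros],
            axis=-1,
        )
        return x, v

    # These two methods make the class a pytree: the array attributes are the
    # "children" that JAX traces through; there is no static auxiliary data.
    def tree_flatten(self):
        return (self.radii, self.periods, self.log_gms), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


@jax.jit
def perturber_positions(processor, t):
    # the processor is a regular jit argument, not a closed-over constant
    x, _ = processor.state(t)
    return x


proc = ToyCircularEphemeris(
    radii=jnp.array([1.0, 5.2]),
    periods=jnp.array([365.25, 4332.6]),
    log_gms=jnp.log(jnp.array([8.9e-10, 2.8e-7])),
)
print(perturber_positions(proc, 0.0))  # rows close to [1, 0, 0] and [5.2, 0, 0]
```

Because the arrays are pytree children rather than baked-in constants, a different instance with the same array shapes and dtypes should reuse the compiled function instead of triggering a recompile. We'd expect `EphemerisProcessor` to benefit in a similar way, but we haven't inspected its internals.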
  

We can use this `EphemerisProcessor` to build an acceleration function:

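```python
# the jittable processor behind the Ephemeris object we created above
ephem_processor = eph.processor


def func(inputs: SystemState) -> jnp.ndarray:
    # positions, velocities, and log GMs of the perturbers at the current time
    # (inputs.time is passed straight to the processor, i.e. as a TDB Julian date)
    perturber_xs, perturber_vs = ephem_processor.state(inputs.time)
    perturber_log_gms = ephem_processor.log_gms

    # prepend the perturbers to our own massive particles
    new_state = SystemState(
        massive_positions=jnp.concatenate([perturber_xs, inputs.massive_positions]),
        massive_velocities=jnp.concatenate(
            [perturber_vs, inputs.massive_velocities]
        ),
        tracer_positions=inputs.tracer_positions,
        tracer_velocities=inputs.tracer_velocities,
        log_gms=jnp.concatenate([perturber_log_gms, inputs.log_gms]),
        time=inputs.time,
        acceleration_func_kwargs=inputs.acceleration_func_kwargs,
    )

    # rows come out as: perturbers, our massive particles, then tracers
    accs = newtonian_gravity(new_state)

    # drop the perturbers' rows so the output lines up with the input state
    num_perturbers = perturber_xs.shape[0]
    return accs[num_perturbers:]


acceleration_func = jax.tree_util.Partial(func)
```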

This can now be used with `ias15_evolve` just like the simpler `newtonian_gravity` function. Whenever we ask for the accelerations of a `SystemState`, it computes the positions and velocities of the perturbing planets at that time, tacks them onto the `SystemState`, computes self-consistent accelerations for everything, then cleaves the perturbers off again at the end.
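Here's a toy version of that bookkeeping, small enough to check by hand. It's a NumPy sketch, not jorbit code: `newtonian_accels` and `with_perturbers` are hypothetical names, and for brevity `newtonian_accels` treats any zero-distance pair as a self-interaction. One perturber with GM = 1 sits at the origin; we supply one massive particle (GM = 0.5) at (0, 4, 0) and one tracer at (2, 0, 0).

```python
import numpy as np


def newtonian_accels(massive_x, tracer_x, gms):
    """Toy stand-in for newtonian_gravity: rows are massive particles first, then tracers."""
    x = np.concatenate([massive_x, tracer_x])
    dx = massive_x[None, :, :] - x[:, None, :]  # (n_all, n_massive, 3)
    r2 = np.sum(dx**2, axis=-1)
    r2 = np.where(r2 == 0.0, np.inf, r2)         # zero separation = self-term, contributes 0
    return np.sum(gms[None, :, None] * dx * r2[:, :, None] ** -1.5, axis=1)


def with_perturbers(massive_x, tracer_x, gms, perturber_x, perturber_gms):
    # prepend the perturbers to the massive particles, as func does above
    accs = newtonian_accels(
        np.concatenate([perturber_x, massive_x]),
        tracer_x,
        np.concatenate([perturber_gms, gms]),
    )
    # drop the perturbers' rows: what's left lines up with our own particles
    return accs[len(perturber_x):]


acc = with_perturbers(
    massive_x=np.array([[0.0, 4.0, 0.0]]),
    tracer_x=np.array([[2.0, 0.0, 0.0]]),
    gms=np.array([0.5]),
    perturber_x=np.array([[0.0, 0.0, 0.0]]),
    perturber_gms=np.array([1.0]),
)
print(acc.shape)  # (2, 3): our massive particle, then our tracer
```

By hand: the massive particle feels only the perturber, 1 · (0, −4, 0)/4³ = (0, −0.0625, 0). The tracer feels (−0.25, 0, 0) from the perturber plus 0.5 · (−2, 4, 0)/20^1.5 ≈ (−0.0112, 0.0224, 0) from the massive particle, for a total of about (−0.2612, 0.0224, 0). The perturber's own acceleration is computed and then discarded, which is fine because its motion comes from the ephemeris rather than from the integrator.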