# General N-body simulations

While `jorbit` was designed primarily for solar system orbits and usually assumes that simulations should account for perturbations from the Sun and planets, it can also be used for more general N-body simulations.

In [None]:
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

from astropy.time import Time
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import HTML

from jorbit import Particle, System
from jorbit.utils.states import CartesianState, SystemState

We usually create `System` objects from a collection of massless `Particle` objects, but we can also create one directly from a `SystemState` object that doesn't need to reference the solar system at all. Here we'll create a 3-body system whose unusual-looking initial condition will become clear shortly.

In [None]:
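r = 1 / (2 * jnp.sin(jnp.pi / 3))  # circumradius of an equilateral triangle with unit sides
v = 1.0  # speed of each body, directed perpendicular to its radius vector
initial_state = SystemState(
    # no massless tracer particles in this simulation
    tracer_positions=jnp.empty((0, 3)),
    tracer_velocities=jnp.empty((0, 3)),
    # three massive bodies spaced 120 degrees apart on a circle of radius r in the xy-plane
    massive_positions=jnp.array(
        [
            [r * jnp.cos(0), r * jnp.sin(0), 0],
            [r * jnp.cos(2 * jnp.pi / 3), r * jnp.sin(2 * jnp.pi / 3), 0],
            [r * jnp.cos(4 * jnp.pi / 3), r * jnp.sin(4 * jnp.pi / 3), 0],
        ]
    ),
    # counterclockwise tangential velocities of magnitude v
    massive_velocities=jnp.array(
        [
            [-v * jnp.sin(0), v * jnp.cos(0), 0],
            [-v * jnp.sin(2 * jnp.pi / 3), v * jnp.cos(2 * jnp.pi / 3), 0],
            [-v * jnp.sin(4 * jnp.pi / 3), v * jnp.cos(4 * jnp.pi / 3), 0],
        ]
    ),
    log_gms=jnp.array([0.0, 0.0, 0.0]),  # log(GM) = 0, so every body has GM = 1
    time=0.0,
    acceleration_func_kwargs={},
)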
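These numbers are not arbitrary. Three equal masses at the corners of an equilateral triangle, each given the circular-orbit speed about the center, form Lagrange's equilateral solution of the three-body problem: the triangle simply rotates rigidly. A short check, taking side length $a = 1$ and $GM = 1$ for each body (which is what a zero `log_gms` corresponds to, assuming it stores $\log GM$). The circumradius is

$$
r = \frac{a}{2\sin(\pi/3)} = \frac{a}{\sqrt{3}}.
$$

Each body feels two pulls of magnitude $GM/a^2$, each $\pi/6$ away from the direction to the center, so the net acceleration points at the center with magnitude

$$
|\mathbf{a}| = 2\,\frac{GM}{a^2}\cos\frac{\pi}{6} = \frac{\sqrt{3}\,GM}{a^2}.
$$

Circular motion requires $|\mathbf{a}| = v^2/r$, hence

$$
v^2 = |\mathbf{a}|\,r = \frac{\sqrt{3}\,GM}{a^2}\cdot\frac{a}{\sqrt{3}} = \frac{GM}{a}.
$$

With $GM = 1$ and $a = 1$ this gives $v = 1$ and $r = 1/\sqrt{3}$, exactly the values used above. The rotation period is $T = 2\pi r/v = 2\pi/\sqrt{3} \approx 3.63$, so an integration over $0 \le t \le 10$ covers roughly 2.8 turns.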

From this we can create a `System` object, this time setting the `gravity` argument to either `"generic newtonian"` for plain Newtonian gravity or `"generic gr"` to include parameterized post-Newtonian (PPN) corrections.

In [3]:
s = System(state=initial_state, gravity="generic newtonian")
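For intuition about what the `"generic newtonian"` option evaluates, here is a minimal sketch of the standard pairwise inverse-square acceleration, $\mathbf{a}_i = \sum_{j \ne i} GM_j (\mathbf{r}_j - \mathbf{r}_i)/|\mathbf{r}_j - \mathbf{r}_i|^3$, in plain JAX. It is our own illustration, not jorbit's internal code; the function name `newtonian_accelerations` is hypothetical, and the sketch assumes `log_gms` holds natural logarithms of $GM$.

```python
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp


@jax.jit
def newtonian_accelerations(positions, gms):
    """Hypothetical helper: a_i = sum_{j != i} GM_j (r_j - r_i) / |r_j - r_i|^3.

    positions: (N, 3) array of Cartesian positions; gms: (N,) array of GM values.
    """
    n = positions.shape[0]
    # dr[i, j] = r_j - r_i, shape (N, N, 3)
    dr = positions[None, :, :] - positions[:, None, :]
    dist2 = jnp.sum(dr**2, axis=-1)
    # Add 1 on the diagonal so the i == j term never divides by zero;
    # dr[i, i] is zero, so that term still contributes nothing.
    dist2 = dist2 + jnp.eye(n)
    inv_r3 = dist2**-1.5
    # Weight each pair by the attracting body's GM and sum over j
    return jnp.sum(gms[None, :, None] * dr * inv_r3[:, :, None], axis=1)


# Rebuild the triangle initial condition from the cell above
r = 1 / (2 * jnp.sin(jnp.pi / 3))
angles = jnp.array([0.0, 2 * jnp.pi / 3, 4 * jnp.pi / 3])
pos0 = jnp.stack([r * jnp.cos(angles), r * jnp.sin(angles), jnp.zeros(3)], axis=1)
gms = jnp.exp(jnp.zeros(3))  # assumption: log_gms = 0 means GM = exp(0) = 1

acc = newtonian_accelerations(pos0, gms)
# Inward (toward the origin) component of each body's acceleration
radial = -jnp.sum(acc * pos0, axis=1) / r
print(radial * r)  # v**2 needed for a circular orbit; all approximately 1
```

On the triangle initial condition every acceleration points at the origin with magnitude $\sqrt{3}$, so $|\mathbf{a}|\,r = 1 = v^2$: the chosen speed is exactly the circular-orbit speed. Broadcasting over an $(N, N, 3)$ array keeps the sketch short but costs $O(N^2)$ memory, which is fine for a handful of bodies.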

In [4]:
times = jnp.linspace(0, 10.0, 500)  # 500 evenly spaced output times from t = 0 to t = 10
positions, velocities = s.integrate(times=times)
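Because this initial condition has a closed-form solution, it makes a handy accuracy test. The sketch below builds the exact, rigidly rotating trajectory on the same time grid. The helper `lagrange_equilateral_positions` is hypothetical (not part of jorbit), and the comparison assumes `positions` is laid out as `(time, body, xyz)`, as the `positions[0, :, 0]` indexing in the plotting code suggests.

```python
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp


def lagrange_equilateral_positions(times, side=1.0, gm=1.0):
    """Hypothetical helper: exact positions in Lagrange's rotating equilateral solution.

    Assumes three bodies with equal GM, starting at angles 0, 2pi/3, 4pi/3 and
    moving counterclockwise. Returns shape (len(times), 3, 3): time, body, (x, y, z).
    """
    r = side / jnp.sqrt(3.0)  # circumradius of the triangle
    v = jnp.sqrt(gm / side)  # circular speed, from v**2 = GM / a
    omega = v / r  # angular velocity of the rigid rotation
    phase0 = jnp.array([0.0, 2 * jnp.pi / 3, 4 * jnp.pi / 3])  # starting angles
    theta = omega * times[:, None] + phase0[None, :]  # shape (T, 3)
    return jnp.stack(
        [r * jnp.cos(theta), r * jnp.sin(theta), jnp.zeros_like(theta)], axis=-1
    )


times = jnp.linspace(0, 10.0, 500)
exact = lagrange_equilateral_positions(times)
print(exact.shape)  # (500, 3, 3)
```

After the integration has run, `jnp.max(jnp.abs(positions - exact))` gives a rough measure of how far the numerical trajectory drifts from the exact one. Keep in mind that for equal masses this equilateral configuration is dynamically unstable, so over long enough integrations small numerical errors are expected to grow and the two trajectories will eventually part ways.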

In [None]:
fig, ax = plt.subplots(figsize=(6, 6))
ax.set_xlim(-1.5, 1.5)
ax.set_ylim(-1.5, 1.5)
ax.set_aspect("equal")

scat = ax.scatter(positions[0, :, 0], positions[0, :, 1], s=50, c=["C0", "C1", "C2"])
time_text = ax.text(0.02, 0.95, "", transform=ax.transAxes)


def update(frame):
    # move each body to its (x, y) position at this output time
    scat.set_offsets(positions[frame, :, :2])
    time_text.set_text(f"t = {times[frame]:.2f}")
    return scat, time_text


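anim = FuncAnimation(fig, update, frames=len(times), interval=50, blit=True)
plt.close()  # close the static figure so only the animation is displayed

# to_html5_video needs ffmpeg to be available on the system
HTML(anim.to_html5_video())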