# The N-body problem

## Imports

In [None]:
# Magic function: enables interactive plot
%matplotlib widget
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

## Initial conditions

In [None]:
# Initial conditions for the three bodies. Every vector is built as a float
# array, so the stacked state (which is updated in place during integration)
# is float regardless of how the literals are written.
# NOTE: the circular-orbit speed around mC at a distance of 1e9 m is
# sqrt(G * mC / r) ~ 8.2e5 m/s. A and B start at 2e3 m/s, far below that,
# so on long runs they plunge toward C instead of following near-circular
# orbits.

# body mA initial conditions
mA = 1e26  # mass (kg)
rA0 = np.array([1e9, 0, 0], dtype=float)  # initial position (m)
vA0 = np.array([0, 2e3, 0], dtype=float)  # initial velocity (m/s)

# body mB initial conditions
mB = 1e26  # mass (kg)
rB0 = np.array([-1e9, 0, 0], dtype=float)  # initial position (m)
vB0 = np.array([0, -2e3, 0], dtype=float)  # initial velocity (m/s)

# body mC initial conditions
mC = 1e31  # mass (kg)
rC0 = np.zeros(3)  # initial position (m)
vC0 = np.zeros(3)  # initial velocity (m/s)

In [None]:
# Full system state with shape (n_bodies, 2, 3): axis 1 picks position (0)
# or velocity (1), and axis 2 holds the x, y, z components.
y0 = np.array([
    [rA0, vA0],
    [rB0, vB0],
    [rC0, vC0],
])

In [None]:
masses = list((mA, mB, mC))  # kg; same body order as axis 0 of y0

In [None]:
np.shape(y0)  # (bodies, position/velocity, xyz)

In [None]:
# Newtonian constant of gravitation, CODATA 2018 recommended value
# 6.67430(15)e-11 (the former 6.67259e-11 is the older CODATA 1986 figure).
G = 6.67430e-11  # Gravitational constant (m**3/kg/s**2)

## The derivatives (equations of motion)


In [None]:
def n_body_eqm_derivatives(_y, t, _G, masses):
    """
    Time derivatives of the Newtonian n-body equations of motion.

    Parameters
    ----------
    _y : ndarray, shape (n, 2, 3)
        System state: ``_y[i, 0]`` is the position and ``_y[i, 1]`` the
        velocity of body ``i``.
    t : float
        Current time. Unused (the system has no explicit time dependence),
        but kept for consistency with scipy's ``f(y, t, ...)`` convention.
    _G : float
        Gravitational constant.
    masses : sequence of float, length n
        Body masses, in the same order as axis 0 of ``_y``.

    Returns
    -------
    ndarray, shape (n, 2, 3)
        ``[:, 0]`` holds the velocities (d position / dt) and ``[:, 1]`` the
        gravitational accelerations (d velocity / dt). A single body gets
        zero acceleration.
    """
    positions = _y[:, 0, :]
    velocities = _y[:, 1, :]
    m = np.asarray(masses, dtype=float)

    # separations[i, j] = r_j - r_i, the vector pointing from body i to body j.
    separations = positions[np.newaxis, :, :] - positions[:, np.newaxis, :]
    distances = np.linalg.norm(separations, axis=-1)
    # No self-interaction: an infinite self-distance makes each diagonal term
    # 0 / inf = 0 instead of 0 / 0.
    np.fill_diagonal(distances, np.inf)

    # a_i = G * sum_j m_j (r_j - r_i) / |r_j - r_i|**3
    accelerations = _G * np.einsum(
        "j,ijk->ik", m, separations / distances[:, :, np.newaxis] ** 3
    )

    return np.stack([velocities, accelerations], axis=1)


## Forward time evolution

In [None]:
# Fixed-step integration settings, both in seconds; the run takes roughly
# tf / dt steps.
tf = 100.0  # end of simulation (s)
dt = 1e-3   # time step (s)

In [None]:
def evolve(y0, tf, dt, method, params):
    """
    Integrate a state forward in time with a fixed step size.

    Parameters
    ----------
    y0 : array-like
        Initial state. It is copied (as float) before integrating, so the
        caller's array is never modified.
    tf : float
        End time; steps start at the times ``np.arange(0, tf, dt)``.
    dt : float
        Step size.
    method : str
        Integration scheme name, forwarded to ``evolve_one_step``.
    params : tuple
        Extra arguments, forwarded unchanged to ``evolve_one_step``.

    Returns
    -------
    ndarray
        The state after every step, stacked along a new last axis:
        ``history[..., k]`` is the state at time ``(k + 1) * dt``. The
        initial state itself is not included.
    """
    # evolve_one_step may update its argument in place; integrating a private
    # float copy keeps y0 intact (and avoids int -> float in-place failures).
    yn = np.array(y0, dtype=float)
    history = []
    for tn in np.arange(0, tf, dt):
        yn = evolve_one_step(yn, tn, dt, method, params)
        # Snapshot: the next step may reuse/mutate yn's buffer.
        history.append(yn.copy())

    return np.stack(history, axis=-1)

In [None]:
def evolve_one_step(yn, tn, dt, method, params):
    """
    Advance the state ``yn`` from time ``tn`` to ``tn + dt``.

    Parameters
    ----------
    yn : ndarray
        Current state. It is not modified.
    tn : float
        Current time.
    dt : float
        Step size.
    method : {"euler", "rk4"}
        Explicit Euler, or classical fourth-order Runge-Kutta.
    params : tuple
        Extra positional arguments passed to ``n_body_eqm_derivatives``
        after ``(y, t)``.

    Returns
    -------
    ndarray
        A new array holding the state at ``tn + dt``.

    Raises
    ------
    ValueError
        If ``method`` is not a supported scheme name.
    """
    if method == "euler":
        f = n_body_eqm_derivatives(yn, tn, *params)
        return yn + f * dt
    if method == "rk4":
        f1 = n_body_eqm_derivatives(yn, tn, *params)
        f2 = n_body_eqm_derivatives(yn + f1 * dt / 2, tn + dt / 2, *params)
        f3 = n_body_eqm_derivatives(yn + f2 * dt / 2, tn + dt / 2, *params)
        f4 = n_body_eqm_derivatives(yn + f3 * dt, tn + dt, *params)
        return yn + (f1 + 2 * f2 + 2 * f3 + f4) * dt / 6
    # Previously an unknown name silently returned yn unchanged, freezing the
    # simulation without any error.
    raise ValueError(
        f"unknown integration method {method!r}; expected 'euler' or 'rk4'"
    )

## Running the simulation

In [None]:
# Run the full simulation with classical RK4; params are forwarded to the derivative function.
history = evolve(y0, tf, dt, method="rk4", params=(G, masses))

In [None]:
np.shape(history)  # state shape plus a trailing time axis

## Visualizing the outcome

In [None]:
# Trajectories
trajectories = history[:, 0, :, ::100]
trajectories.shape

In [None]:
def compute_relative_marker_size(m, max_m, min_size=1.0):
    """
    Scatter marker size for a body of mass ``m`` relative to the heaviest.

    The heaviest body (``m == max_m``) gets size 30, and each factor of ten
    lighter shrinks it by 3. The result is clipped from below at
    ``min_size``, so bodies more than about 10**9.7 times lighter still get a
    visible, positive marker instead of a zero or negative size.

    Works elementwise on NumPy arrays as well as on scalars.
    """
    size = 30 - 3 * (np.log10(max_m) - np.log10(m))
    return np.maximum(size, min_size)

In [None]:
plt.style.use('dark_background')

fig = plt.figure()
ax = fig.add_subplot(projection='3d')
colors = mpl.colormaps["Set3"].colors
marker_sizes = [compute_relative_marker_size(m, max(masses)) for m in masses]


def _padded_limits(values, pad_fraction=0.05):
    """
    Return ``(low, high)`` axis limits covering ``values`` plus a margin.

    The margin is ``pad_fraction`` of the data span. When every value is
    identical (e.g. all bodies stay in the z = 0 plane), the magnitude of the
    value (or 1.0 for zero) is used as the span instead, so the two limits
    never coincide. Identical limits would give matplotlib a singular range.
    """
    low = float(np.min(values))
    high = float(np.max(values))
    span = high - low
    if span == 0:
        span = abs(high) or 1.0
    pad = span * pad_fraction
    return low - pad, high + pad


# Fixed axis limits for the whole animation, computed once over all frames.
xm, xM = _padded_limits(trajectories[:, 0, :])
ym, yM = _padded_limits(trajectories[:, 1, :])
zm, zM = _padded_limits(trajectories[:, 2, :])


def animate(frame_num):
    """Draw frame ``frame_num``: each body's path so far and its current position."""
    # ax.clear() also wipes limits and labels, so re-apply them every frame.
    ax.clear()
    ax.set(xlim3d=(xm, xM), xlabel="X")
    ax.set(ylim3d=(ym, yM), ylabel="Y")
    ax.set(zlim3d=(zm, zM), zlabel="Z")

    for i in range(trajectories.shape[0]):
        # Wrap around the palette so any number of bodies gets a colour; the
        # (1, 3) shape marks it as a single RGB colour for scatter.
        color = np.array(colors[i % len(colors)]).reshape(1, -1)
        # Trail up to and including the current frame, so it ends at the marker.
        ax.plot3D(
            trajectories[i, 0, :frame_num + 1],
            trajectories[i, 1, :frame_num + 1],
            trajectories[i, 2, :frame_num + 1],
            c=color,
        )
        ax.scatter(
            trajectories[i, 0, frame_num],
            trajectories[i, 1, frame_num],
            trajectories[i, 2, frame_num],
            c=color,
            marker='o',
            s=marker_sizes[i]
        )

anim = FuncAnimation(fig, animate, frames=trajectories.shape[-1], interval=100, repeat=False)
plt.show()

In [None]:
# Export the animation as a GIF through the Pillow writer at 30 frames per second.
anim.save("N_body_animation.gif", fps=30, writer="pillow")