https://levelup.gitconnected.com/the-two-body-problem-in-python-6bbe4a0b2f88

### Imports

At the beginning of each notebook, we place the **imports** of the libraries we intend to use.

In our case, we will use `numpy` for numerical arrays and vector operations, and both `plotly` and `matplotlib` for plotting.

In [None]:
# IPython magic command: enables interactive matplotlib plots (requires the ipympl package)
%matplotlib widget
import numpy as np

import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import plotly.express as px


### Initial conditions

Let's start by setting the kinematic initial conditions of our two bodies $A$ and $B$, namely their positions $r^A(0), r^B(0)$ and their velocities $v^A(0), v^B(0)$.

Each of these vectors is stored as a NumPy array with three components, for the $x$, $y$ and $z$ directions respectively.

In [None]:
# body A initial conditions
mA = 1e26  # mass (kg)
rA0 = np.array([1E3, 0, 0])  # initial position (km)
vA0 = np.array([10, -20, 10])  # initial velocity (km/s)

# body B initial conditions
mB = 1e26  # mass (kg)
rB0 = np.array([-1E3, 0, 0])  # initial position (km)
vB0 = np.array([10, 40, 10])  # initial velocity (km/s)

We want to store this information in a more compact and useful way: we define a *state vector* $y(0)$ containing all initial conditions.

We structure it as a three-dimensional array (a rank-3 *tensor*, in mathematical terms), so that each axis encodes a different, independent piece of information.

To do so we use `numpy.stack`, which joins a sequence of arrays of the same shape along a new axis (by default, a new first axis).
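To see how the two nested `np.stack` calls build the axes one at a time, here is a minimal sketch with made-up placeholder vectors (`r` and `v` are illustrative values, not the initial conditions above):

```python
import numpy as np

# placeholder position and velocity vectors, each of shape (3,)
r = np.array([1.0, 2.0, 3.0])
v = np.array([4.0, 5.0, 6.0])

# stacking two (3,) arrays adds a new leading axis -> shape (2, 3)
body = np.stack([r, v])
print(body.shape)  # (2, 3)

# stacking two (2, 3) arrays adds another leading axis -> shape (2, 2, 3)
state = np.stack([body, body])
print(state.shape)  # (2, 2, 3)

# the new axis comes first, so body[1] is the velocity vector again
print(body[1])  # [4. 5. 6.]
```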

In [None]:
y0 = np.stack([np.stack([rA0, vA0]), np.stack([rB0, vB0])])

If we inspect the shape of $y(0)$, we can see that it indeed has 3 axes:
- the first can have two values, selecting body $A$ or $B$;
- the second can have two values as well and chooses the kinematic variable, either the position $r$ or the velocity $v$;
- the third can have three values, corresponding to the three dimensions $x$, $y$, and $z$.

In [None]:
y0.shape

We can index the array to see what comes out of it: a single element is retrieved by passing all three indices needed to identify it, while *slices* (projections) are obtained by using a `:` in place of one or more of the indices.

Here is $v_y^A(0)$:

In [None]:
# body=A, variable=v, dimension=y
y0[0, 1, 1]

And here is $r^B(0)$:

In [None]:
# body=B, variable=r, all dimensions
y0[1, 0, :]
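Putting `:` in more than one position gives larger projections. A short sketch, rebuilding the same `y0` so that it runs on its own:

```python
import numpy as np

# same initial conditions as above
rA0, vA0 = np.array([1e3, 0, 0]), np.array([10, -20, 10])
rB0, vB0 = np.array([-1e3, 0, 0]), np.array([10, 40, 10])
y0 = np.stack([np.stack([rA0, vA0]), np.stack([rB0, vB0])])

# positions of both bodies: all bodies, variable=r, all dimensions -> shape (2, 3)
positions = y0[:, 0, :]
print(positions.shape)  # (2, 3)

# z component of every vector: all bodies, all variables, dimension=z -> shape (2, 2)
z_components = y0[:, :, 2]
print(z_components)
```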

Finally, let's set the value of $G$, the gravitational constant:

In [None]:
G = 6.67259e-20  # Gravitational constant (km**3/kg/s**2)
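The value above is $G$ expressed with kilometres instead of metres, to match the km and km/s used for positions and velocities. Since $1\ \text{m}^3 = 10^{-9}\ \text{km}^3$, the conversion from the SI value is a single factor, as this small check shows:

```python
# G in SI units (m**3/kg/s**2), with the same digits as the value used above
G_SI = 6.67259e-11

# 1 m = 1e-3 km, so 1 m**3 = (1e-3)**3 km**3 = 1e-9 km**3
G_km = G_SI * (1e-3) ** 3
print(G_km)
```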

### The two-body equations of motion
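The function below encodes Newton's law of universal gravitation for the two bodies. Each body is accelerated towards the other, proportionally to the *other* body's mass:

$$
\ddot r^A = G\, m_B \frac{r^B - r^A}{\lVert r^B - r^A \rVert^3}, \qquad
\ddot r^B = G\, m_A \frac{r^A - r^B}{\lVert r^B - r^A \rVert^3}.
$$

Rewriting these second-order equations as a first-order system for the state $y = (r^A, v^A, r^B, v^B)$ gives $\dot y = f(y)$ with

$$
\dot r^A = v^A, \qquad \dot v^A = \ddot r^A, \qquad \dot r^B = v^B, \qquad \dot v^B = \ddot r^B,
$$

and this right-hand side $f(y)$, laid out with the same (body, variable, dimension) axes as $y$, is what `two_body_eqm_derivatives` is meant to compute.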


In [None]:
def two_body_eqm_derivatives(_y, t, _G, _mA, _mB):
    """
    Time derivative dy/dt of the two-body state array _y, of shape (2, 2, 3).
    t is unused (the system does not depend explicitly on time), but the
    f(y, t, *args) signature matches the convention of scipy.integrate.odeint.
    """
    # positions
    rA = _y[0, 0, :]
    rB = _y[1, 0, :]

    # velocities
    vA = _y[0, 1, :]
    vB = _y[1, 1, :]

    # distance between the two bodies, |rB - rA|
    distance = np.linalg.norm(rB - rA)

    # accelerations from Newton's law of gravitation:
    # A is pulled towards B proportionally to mB, and B towards A proportionally to mA
    aA = _G * _mB * (rB - rA) / distance**3
    aB = _G * _mA * (rA - rB) / distance**3

    # d(position)/dt = velocity, d(velocity)/dt = acceleration
    derivatives = np.stack([np.stack([vA, aA]), np.stack([vB, aB])])

    return derivatives
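A quick sanity check on the accelerations: by Newton's third law the two gravitational forces are equal and opposite, so $m_A \ddot r^A + m_B \ddot r^B = 0$. This sketch repeats the acceleration formulas in a hypothetical helper `accelerations` (not part of the notebook) and checks the identity with deliberately unequal masses, which would expose a swapped $m_A$/$m_B$ that equal masses hide:

```python
import numpy as np

G = 6.67259e-20  # gravitational constant (km**3/kg/s**2)


def accelerations(rA, rB, mA, mB):
    """Gravitational accelerations of A and B (hypothetical helper for this check)."""
    d = rB - rA
    dist3 = np.linalg.norm(d) ** 3
    aA = G * mB * d / dist3   # A is pulled towards B, scaled by B's mass
    aB = -G * mA * d / dist3  # B is pulled towards A, scaled by A's mass
    return aA, aB


rA = np.array([1e3, 0.0, 0.0])
rB = np.array([-1e3, 0.0, 0.0])
mA, mB = 1e26, 3e26  # unequal on purpose

aA, aB = accelerations(rA, rB, mA, mB)

# net force, relative to the size of a single force: ~0 up to rounding
net_force = mA * aA + mB * aB
print(np.linalg.norm(net_force) / (mA * np.linalg.norm(aA)))

# the lighter body A feels the larger acceleration
print(np.linalg.norm(aA) > np.linalg.norm(aB))  # True
```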


### Forward time evolution of the system

In [None]:
dt = 0.001  # time step (s)
tf = 1E2  # end of simulation (s)
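With these values the integration takes $t_f / \Delta t = 10^5$ steps. The `evolve` function below builds its time grid with `np.arange(0, tf, dt)`, and the NumPy documentation warns that with a non-integer step the length of the result may not be numerically stable. A slightly more robust alternative, sketched here as an option and not what `evolve` does, is to fix the number of steps first and derive the times from it:

```python
import numpy as np

dt = 0.001  # time step (s)
tf = 1e2    # end of simulation (s)

# number of steps as an integer, rounding away tiny floating-point errors in tf / dt
n_steps = int(round(tf / dt))

# start time of each step: 0, dt, 2*dt, ..., tf - dt
t_axis = np.arange(n_steps) * dt
print(n_steps)  # 100000
```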

In [None]:
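def evolve(y0, tf, dt, method, params):
    """Integrate from t=0 to tf with step dt; return the state after each step."""
    history = []
    # work on a float copy, so that in-place updates can never modify y0 itself
    yn = np.array(y0, dtype=float)
    t_axis = np.arange(0, tf, dt)
    for tn in t_axis:
        yn = evolve_one_step(yn, tn, dt, method, params)
        history.append(yn.copy())

    # time becomes the last axis: shape (2, 2, 3, number of steps)
    history = np.stack(history, axis=-1)
    return history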

### Euler's method and fourth-order Runge-Kutta (RK4)
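Both methods advance the state $y_n \approx y(t_n)$ to $y_{n+1} \approx y(t_n + \Delta t)$ using evaluations of the right-hand side $f(y, t)$ of $\dot y = f(y, t)$. Euler's method uses a single slope, taken at the start of the step:

$$
y_{n+1} = y_n + \Delta t\, f(y_n, t_n),
$$

while the classical fourth-order Runge-Kutta method (RK4) combines four slopes:

$$
\begin{aligned}
k_1 &= f(y_n,\ t_n),\\
k_2 &= f\!\left(y_n + \tfrac{\Delta t}{2} k_1,\ t_n + \tfrac{\Delta t}{2}\right),\\
k_3 &= f\!\left(y_n + \tfrac{\Delta t}{2} k_2,\ t_n + \tfrac{\Delta t}{2}\right),\\
k_4 &= f\!\left(y_n + \Delta t\, k_3,\ t_n + \Delta t\right),\\
y_{n+1} &= y_n + \tfrac{\Delta t}{6}\left(k_1 + 2k_2 + 2k_3 + k_4\right).
\end{aligned}
$$

Euler's global error shrinks proportionally to $\Delta t$ and RK4's proportionally to $\Delta t^4$, at the cost of four evaluations of $f$ per step instead of one.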

In [None]:
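def evolve_one_step(yn, tn, dt, method, params):
    """Advance the state yn from time tn to tn + dt with the chosen method."""
    if method == "euler":
        # explicit Euler: one slope evaluation per step, first-order accurate
        f = two_body_eqm_derivatives(yn, tn, *params)
        yn = yn + f * dt
    elif method == "rk4":
        # classical 4th-order Runge-Kutta: four slope evaluations per step
        f1 = two_body_eqm_derivatives(yn, tn, *params)
        f2 = two_body_eqm_derivatives(yn + f1 * dt / 2, tn + dt / 2, *params)
        f3 = two_body_eqm_derivatives(yn + f2 * dt / 2, tn + dt / 2, *params)
        f4 = two_body_eqm_derivatives(yn + f3 * dt, tn + dt, *params)
        yn = yn + (f1 + 2 * f2 + 2 * f3 + f4) * dt / 6
    else:
        raise ValueError(f"unknown method: {method!r}")

    # a new array is returned: the caller's state is never modified in place
    return yn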
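To see the difference in accuracy, here is a small experiment on a test problem with a known solution, $\dot y = -y$ with $y(0) = 1$, whose exact value at $t = 1$ is $e^{-1}$. The update rules are written out again in a hypothetical helper `step`, mirroring `evolve_one_step`, so the sketch runs on its own. Halving $\Delta t$ should roughly halve Euler's error and divide RK4's by about 16:

```python
import math


def f(y, t):
    # right-hand side of the test equation dy/dt = -y (t is unused)
    return -y


def step(y, t, dt, method):
    """One Euler or RK4 step (hypothetical helper mirroring evolve_one_step)."""
    if method == "euler":
        return y + f(y, t) * dt
    if method == "rk4":
        k1 = f(y, t)
        k2 = f(y + k1 * dt / 2, t + dt / 2)
        k3 = f(y + k2 * dt / 2, t + dt / 2)
        k4 = f(y + k3 * dt, t + dt)
        return y + (k1 + 2 * k2 + 2 * k3 + k4) * dt / 6
    raise ValueError(f"unknown method: {method!r}")


def error_at_t1(method, n_steps):
    """Absolute error of the numerical solution at t = 1."""
    dt = 1.0 / n_steps
    y, t = 1.0, 0.0
    for _ in range(n_steps):
        y = step(y, t, dt, method)
        t += dt
    return abs(y - math.exp(-1.0))


ratios = {}
for method in ("euler", "rk4"):
    ratios[method] = error_at_t1(method, 10) / error_at_t1(method, 20)
    print(f"{method}: error ratio when halving dt = {ratios[method]:.1f}")
```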

### Running the simulation

Finally, let's run the main `evolve` function, here with Euler's method (passing `"rk4"` instead selects the Runge-Kutta method):

In [None]:
history = evolve(y0, tf, dt, "euler", params=(G, mA, mB))

In [None]:
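# to compare with scipy's odeint (which works on flat 1-D states, puts time on axis 0
# and includes y(0) as the first sample), flatten the state and restore the layout afterwards:
#from scipy.integrate import odeint
#flat_rhs = lambda y, t, *args: two_body_eqm_derivatives(y.reshape(2, 2, 3), t, *args).ravel()
#sol = odeint(flat_rhs, y0.ravel(), np.arange(0, tf, dt), args=(G, mA, mB))
#history = np.moveaxis(sol.reshape(-1, 2, 2, 3), 0, -1)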

The `history` output has the same three axes as the initial state $y(0)$, plus a fourth, last axis for time: it is the sequence of states $y(t)$ after each step, for $t = \Delta t, 2\Delta t, \dots, t_f$ with $\Delta t$ the time step `dt` (the initial state $y(0)$ itself is not stored).

In [None]:
history.shape
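Besides its shape, the contents of `history` can be checked physically. The total energy

$$
E = \tfrac12 m_A \lVert v^A\rVert^2 + \tfrac12 m_B \lVert v^B\rVert^2 - \frac{G\, m_A m_B}{\lVert r^B - r^A\rVert}
$$

is conserved by the exact dynamics, so its drift along the simulation is a convenient measure of integration error, for instance to compare `"euler"` with `"rk4"`. A minimal sketch of a hypothetical helper `total_energy` working on the (body, variable, dimension, time) layout of `history`, tested here on a tiny synthetic array:

```python
import numpy as np


def total_energy(history, G, mA, mB):
    """Total (kinetic + potential) energy at every stored time step.

    history has shape (2, 2, 3, T): (body, variable, dimension, time).
    With the notebook's units the result is in kg km**2/s**2.
    Hypothetical helper, not part of the original notebook.
    """
    rA, vA = history[0, 0], history[0, 1]  # each of shape (3, T)
    rB, vB = history[1, 0], history[1, 1]
    kinetic = 0.5 * mA * np.sum(vA**2, axis=0) + 0.5 * mB * np.sum(vB**2, axis=0)
    distance = np.linalg.norm(rB - rA, axis=0)  # shape (T,)
    potential = -G * mA * mB / distance
    return kinetic + potential


# synthetic example: two time samples of a state at rest, bodies 2000 km apart
G, mA, mB = 6.67259e-20, 1e26, 1e26
state = np.zeros((2, 2, 3))
state[0, 0] = [1e3, 0, 0]
state[1, 0] = [-1e3, 0, 0]
fake_history = np.stack([state, state], axis=-1)

E = total_energy(fake_history, G, mA, mB)
print(E.shape)  # (2,)
```

With the real simulation, `E = total_energy(history, G, mA, mB)` followed by `(E - E[0]) / abs(E[0])` gives the relative energy drift over time.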

### Plotting the outcomes

First, let's extract the trajectories of both bodies from `history`, recalling that its axes are, in order, body, variable, dimension and time:

In [None]:
# Trajectories
xA = history[0, 0, 0, :]
yA = history[0, 0, 1, :]
zA = history[0, 0, 2, :]

xB = history[1, 0, 0, :]
yB = history[1, 0, 1, :]
zB = history[1, 0, 2, :]
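Since `history[0, 0]` is an array of shape (3, T) whose rows are the $x$, $y$ and $z$ coordinates, the six lines above can also be written with tuple unpacking. A sketch on a random stand-in array with the same layout as `history` (`fake_history` is purely illustrative):

```python
import numpy as np

# stand-in with the layout of history: (body, variable, dimension, time)
rng = np.random.default_rng(0)
fake_history = rng.normal(size=(2, 2, 3, 5))

# iterating over the first axis of a (3, T) array yields the x, y and z rows
xA, yA, zA = fake_history[0, 0]
xB, yB, zB = fake_history[1, 0]

# identical to the explicit indexing used above
print(np.array_equal(yA, fake_history[0, 0, 1, :]))  # True
```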

For example, we can plot a single position coordinate against time.

In [None]:
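t = dt * np.arange(1, len(xA) + 1)  # time (s) of each stored state
px.line(x=t, y=xA, labels={"x": "t (s)", "y": "x_A (km)"})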

But we can do better: with `matplotlib`'s `FuncAnimation` class, which redraws the figure by calling a function once per frame, we can display an animated 3d plot in which our two bodies can be seen dancing around one another.

In [None]:
plt.style.use('dark_background')

fig = plt.figure()
ax = plt.axes(projection='3d')

# fixed axis limits, computed once from both complete trajectories
xm, xM = min(xA.min(), xB.min()), max(xA.max(), xB.max())
ym, yM = min(yA.min(), yB.min()), max(yA.max(), yB.max())
zm, zM = min(zA.min(), zB.min()), max(zA.max(), zB.max())

# draw only every `stride`-th stored state: one frame per step (~10^5 frames) would be very slow
stride = 100


def animate(frame_num):
    # ax.clear() also resets limits and labels, so they are set again below
    ax.clear()

    # trajectory so far (line) and current position (dot) of body A
    ax.plot3D(xA[:frame_num], yA[:frame_num], zA[:frame_num], c='blue')
    ax.scatter(xA[frame_num], yA[frame_num], zA[frame_num], c='blue', marker='o')

    # same for body B
    ax.plot3D(xB[:frame_num], yB[:frame_num], zB[:frame_num], c='orange')
    ax.scatter(xB[frame_num], yB[frame_num], zB[frame_num], c='orange', marker='o')

    ax.set(xlim3d=(xm, xM), xlabel='X')
    ax.set(ylim3d=(ym, yM), ylabel='Y')
    ax.set(zlim3d=(zm, zM), zlabel='Z')


# keep a reference to the animation, otherwise it can be garbage-collected and stop
anim = FuncAnimation(fig, animate, frames=range(0, len(xA), stride), interval=10, repeat=False)
plt.show()
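With the initial velocities above, the total momentum $m_A v^A(0) + m_B v^B(0)$ is not zero, so the pair drifts through space while orbiting. Subtracting the centre of mass $r_{cm} = (m_A r^A + m_B r^B)/(m_A + m_B)$, and likewise its velocity, removes this drift and isolates the relative motion. A minimal sketch with a hypothetical helper `to_com_frame`, which could be applied to `history` before extracting the trajectories, tested on a synthetic drifting pair:

```python
import numpy as np


def to_com_frame(history, mA, mB):
    """Return history with positions and velocities relative to the centre of mass.

    history has shape (2, 2, 3, T): (body, variable, dimension, time).
    Hypothetical helper, not part of the original notebook.
    """
    # mass-weighted average over the body axis -> shape (2, 3, T): (variable, dimension, time)
    com = (mA * history[0] + mB * history[1]) / (mA + mB)
    # broadcasting subtracts the same centre-of-mass state from both bodies
    return history - com[np.newaxis]


# synthetic example: two equal masses moving together with a common drift velocity
mA, mB = 1e26, 1e26
T = 4
t = np.arange(T)                   # time samples 0, 1, 2, 3
drift = np.array([1.0, 2.0, 3.0])  # common velocity of both bodies
fake_history = np.zeros((2, 2, 3, T))
fake_history[0, 0] = np.array([1e3, 0.0, 0.0])[:, None] + drift[:, None] * t
fake_history[1, 0] = np.array([-1e3, 0.0, 0.0])[:, None] + drift[:, None] * t
fake_history[:, 1] = drift[:, None]  # both bodies move with the drift velocity

com_history = to_com_frame(fake_history, mA, mB)
# in the centre-of-mass frame the drift is gone: body A stays at x = +1000 km
print(com_history[0, 0, :, -1])
```

On the notebook's data this would read, for example, `history_cm = to_com_frame(history, mA, mB)`, with `xA`, `yA`, ... then extracted from `history_cm` exactly as before.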