# The N-body problem

### Imports

In [1]:
# Magic function: enables interactive plot
%matplotlib widget
import numpy as np

import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

from dataclasses import dataclass

### Initial conditions

In [2]:
class Body:
    def __init__(self, name: str, mass: float, initial_position: np.array, initial_velocity: np.array):
        self.name = name
        self.mass = mass
        self.positions = [initial_position]
        self.velocities = [initial_velocity]

    def get_last_position(self):
        return self.positions[-1]

    def get_last_velocity(self):
        return self.velocities[-1]

    def update_state(self, new_position: np.array, new_velocity: np.array):
        self.positions.append(new_position)
        self.velocities.append(new_velocity)
        
    
    

In [3]:
Body1 = Body(name="body1",
             mass = 1e26, # (kg)
             initial_position = np.array([2E4, 0, 0]), # (km)
             initial_velocity = np.array([0, -10, 0])  # (km/s)
            ) 

Body2 = Body(name="body2",
             mass = 1e26, # (kg)
             initial_position = np.array([-2E4, 0, 0]), # (km)
             initial_velocity = np.array([0, 10, 0])    # (km/s)
            ) 

Body3 = Body(name="body3",
             mass = 1e31, # (kg)
             initial_position = np.array([0, 0, 0]), # (km)
             initial_velocity = np.array([0, 0, 0])  # (km/s)
            ) 

In [4]:
G = 6.67259e-20  # Gravitational constant (km**3/kg/s**2)

bodies = [Body1, Body2, Body3]

In [5]:
from typing import Callable

In [18]:
class SolarSystem:
    def __init__(self, planets: list[Body], increment_func: Callable):
        self.planets = planets
        self.increment_func = increment_func

    def evolve(self, n_steps: int):
        tn=0
        for n in range(n_steps):
            tn+=n*dt
            increments = self.increment_func(tn)
            for increment in increments:
                a, v = increment[0], increment[1]
                body.update_state(
                    position=body.position + v*dt,
                    velocity=body.velocity + a*dt
                )
                    

    def eqm_derivatives(self):
        derivatives = []
        for body_A in self.planets:
            r_A = body_A.get_last_position()
            m_A = body_A.mass
            v_A = body_A.get_last_velocity()
    
            a_A = 0
            for body_B in self.planets:
                if body_B.name != body_A.name:
                    r_B = body_B.get_last_position()
                    
                    a_A += _G * m_A * (r_B - r_A) / np.linalg.norm(r_B - r_A)**3
            
            derivatives.append(np.stack([v_A, a_A]))
    
        derivatives = np.stack(derivatives)
    
        return derivatives


    def euler(tn):
        f = self.eqm_derivatives(tn)
        return f
    
    def runge_kutta4(tn):
        f1 = self.eqm_derivatives(tn)
        f2 = self.eqm_derivatives(bodies + f1 * dt / 2, tn + dt / 2)  # TODO: fix this
        f3 = self.eqm_derivatives(bodies + f2 * dt / 2, tn + dt / 2)
        f4 = self.eqm_derivatives(bodies + f3 * dt, tn + dt)
        return (f1 + 2 * f2 + 2 * f3 + f4) / 6
    

### Forward time evolution

In [19]:
dt = 0.1  # time step (s)
tf = 1E3  # end of simulation (s)

### Update methods

### Running the simulation

In [21]:
solar_system = SolarSystem(
    planets=bodies,
    increment_func=runge_kutta4
)
solar_system.evolve(n_steps=100)

TypeError: SolarSystem.eqm_derivatives() takes 1 positional argument but 3 were given

### Plotting the outcomes

In [None]:
# Trajectories
trajectories = history[:, 0, :, slice(None, None, 100)]
trajectories.shape

In [None]:
def compute_relative_marker_size(m, max_m):
    return 30 - 3*(np.log10(max_m) - np.log10(m))

In [None]:
plt.style.use('dark_background')

fig = plt.figure()
ax = plt.axes(projection='3d')
colors = mpl.colormaps["Set3"].colors
marker_sizes = [compute_relative_marker_size(m, max(masses)) for m in masses]

xm = np.min(trajectories[:, 0, :])
xM = np.max(trajectories[:, 0, :])
ym = np.min(trajectories[:, 1, :])
yM = np.max(trajectories[:, 1, :])
zm = np.min(trajectories[:, 2, :])
zM = np.max(trajectories[:, 2, :])

def animate(frame_num):
    ax.clear()
    ax.set(xlim3d=(xm, xM), xlabel="X")
    ax.set(ylim3d=(ym, yM), xlabel="Y")
    ax.set(zlim3d=(zm, zM), xlabel="Z")

    for i in range(trajectories.shape[0]):
        color=np.array(colors[i]).reshape(1, -1)
        ax.plot3D(
            trajectories[i, 0, :frame_num],
            trajectories[i, 1, :frame_num],
            trajectories[i, 2, :frame_num],
            c=color,
        )
        ax.scatter(
            trajectories[i, 0, frame_num],
            trajectories[i, 1, frame_num],
            trajectories[i, 2, frame_num],
            c=color,
            marker='o',
            s=marker_sizes[i]
        )

anim = FuncAnimation(fig, animate, frames=trajectories.shape[-1], interval=100, repeat=False)
plt.show()