In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from ode_solver import solve_ode, EulerRichardson, RungeKutta

def single_body(t, x, p):
    """
    Right-hand side of the equations of motion for one body around a fixed mass.

    The attracting mass is treated as sitting at the origin: the acceleration
    always points from (rx, ry) back towards (0, 0).

    Parameters:
    t (float): time; never read, accepted only so the signature fits the ODE solver
    x (np.array): state vector laid out as [rx, vx, ry, vy]
    p (dict): parameters dictionary; only these keys are read:
        - G (float): gravitational constant
        - M (float): mass of the central body

    Returns:
    np.array: time derivative of the state vector, [vx, ax, vy, ay]
    """
    grav_const = p['G']
    central_mass = p['M']

    pos_x, vel_x, pos_y, vel_y = x

    # |r|^3, where |r| is the distance from the origin
    dist_cubed = np.sqrt(pos_x**2 + pos_y**2) ** 3

    # Inverse-square attraction written component-wise: a = -G M r_vec / |r|^3
    acc_x = -grav_const * central_mass * pos_x / dist_cubed
    acc_y = -grav_const * central_mass * pos_y / dist_cubed

    return np.array([vel_x, acc_x, vel_y, acc_y])

def simulate_orbit(x0, vy0, algorithm=RungeKutta, dt=0.01, t_max=10.0):
    """
    Integrate an orbit that starts on the x axis with a velocity along y.

    The start state is [x0, 0, 0, vy0] in the [x, vx, y, vy] layout used by
    single_body, and the integration runs over the span [0, t_max].

    Parameters:
    x0 (float): Initial x position
    vy0 (float): Initial y velocity
    algorithm: ODE solver algorithm, handed to solve_ode unchanged
    dt (float): Time step, handed to solve_ode as first_step
    t_max (float): End of the simulated time span

    Returns:
    tuple: (t, positions, velocities, speeds) where positions holds columns
    0 and 2 of the solver output, velocities holds columns 1 and 3, and
    speeds is the per-row magnitude of velocities
    """
    # Units are chosen so that G * M_sun = 4π² (AU, years, solar masses)
    params = {
        'G': 4 * np.pi**2,
        'M': 1.0,
        'm': [1e-6],  # mass of the orbiting body; single_body never reads it
    }

    # State layout: [x, vx, y, vy]
    initial_state = np.array([x0, 0.0, 0.0, vy0])

    t, y = solve_ode(
        single_body, [0, t_max], initial_state, algorithm, params, first_step=dt
    )

    # Split the solution columns into position and velocity pairs
    positions = y[:, [0, 2]]
    velocities = y[:, [1, 3]]

    vel_x, vel_y = velocities[:, 0], velocities[:, 1]
    speeds = np.sqrt(vel_x**2 + vel_y**2)

    return t, positions, velocities, speeds

def plot_orbit(t, positions, speeds, title="Elliptical Orbit"):
    """Draw the orbit path and the speed history side by side; return the figure."""
    fig, (orbit_ax, speed_ax) = plt.subplots(1, 2, figsize=(15, 6))

    # Left panel: trajectory in the x-y plane with the central body at the origin
    xs, ys = positions[:, 0], positions[:, 1]
    orbit_ax.plot(xs, ys)
    orbit_ax.plot(0, 0, 'ro', markersize=10)
    orbit_ax.set(xlabel='x (AU)', ylabel='y (AU)', title='Orbit')
    orbit_ax.grid(True)
    orbit_ax.set_aspect('equal')  # equal scaling so the orbit shape is not distorted

    # Right panel: speed as a function of time
    speed_ax.plot(t, speeds)
    speed_ax.set(xlabel='Time (years)', ylabel='Speed (AU/year)', title='Speed vs Time')
    speed_ax.grid(True)

    # fig is the current figure here, so these match the pyplot-level calls
    fig.suptitle(title)
    fig.tight_layout()
    return fig

# Initial conditions to explore: start position on the x axis and initial y velocity
test_cases = [
    dict(x0=1.0, vy0=6.28, title='Nearly Circular Orbit (x0=1.0, vy0=6.28)'),
    dict(x0=1.5, vy0=5.0, title='Elliptical Orbit (x0=1.5, vy0=5.0)'),
    dict(x0=2.0, vy0=4.0, title='More Eccentric Orbit (x0=2.0, vy0=4.0)'),
    dict(x0=2.5, vy0=3.5, title='Highly Eccentric Orbit (x0=2.5, vy0=3.5)'),
]

# Simulate every case and write one annotated figure per case to disk
for i, case in enumerate(test_cases):
    t, positions, velocities, speeds = simulate_orbit(
        x0=case['x0'], vy0=case['vy0'], algorithm=RungeKutta, dt=0.01, t_max=10.0
    )

    # Sample indices of the fastest and slowest points along the trajectory
    max_speed_idx = np.argmax(speeds)
    min_speed_idx = np.argmin(speeds)
    max_pos = positions[max_speed_idx]
    min_pos = positions[min_speed_idx]

    fig = plot_orbit(t, positions, speeds, case['title'])
    ax1, ax2 = fig.axes  # (orbit panel, speed panel) as created by plot_orbit

    # Fastest point: green cross on the orbit, green dashed level on the speed plot
    ax1.plot(max_pos[0], max_pos[1], 'gx', markersize=10, label=f'Max Speed: {speeds[max_speed_idx]:.2f}')
    ax2.axhline(y=speeds[max_speed_idx], color='g', linestyle='--')

    # Slowest point: blue cross on the orbit, blue dashed level on the speed plot
    ax1.plot(min_pos[0], min_pos[1], 'bx', markersize=10, label=f'Min Speed: {speeds[min_speed_idx]:.2f}')
    ax2.axhline(y=speeds[min_speed_idx], color='b', linestyle='--')

    ax1.legend()

    # Files are numbered from 1; close the figure so open figures do not pile up
    fig.savefig(f'orbit_case_{i+1}.png')
    plt.close(fig)

# Create an animation of the most interesting orbit
def animate_orbit(x0=2.0, vy0=4.0, algorithm=RungeKutta, dt=0.01, t_max=5.0, save=False,
                  trail_length=100):
    """
    Animate a simulated orbit: a moving planet marker, a short trail and a speed readout.

    The orbit is first computed with simulate_orbit; the full path is drawn
    faintly in the background and one animation frame is produced per
    simulated sample.

    Parameters:
    x0 (float): Initial x position, passed to simulate_orbit
    vy0 (float): Initial y velocity, passed to simulate_orbit
    algorithm: ODE solver algorithm, passed to simulate_orbit
    dt (float): Time step, passed to simulate_orbit
    t_max (float): Maximum simulation time, passed to simulate_orbit
    save (bool): If True, also write the animation to 'elliptical_orbit.gif'
    trail_length (int): Maximum number of samples shown in the trail,
        counting the planet's current position

    Returns:
    FuncAnimation: the animation object; keep a reference to it so the
    animation is not garbage-collected while it is displayed
    """
    _, positions, _, speeds = simulate_orbit(
        x0, vy0, algorithm=algorithm, dt=dt, t_max=t_max
    )
    xs, ys = positions[:, 0], positions[:, 1]

    fig, ax = plt.subplots(figsize=(8, 8))
    ax.set_xlim(np.min(xs) - 0.5, np.max(xs) + 0.5)
    ax.set_ylim(np.min(ys) - 0.5, np.max(ys) + 0.5)
    ax.set_xlabel('x (AU)')
    ax.set_ylabel('y (AU)')
    ax.set_title(f'Elliptical Orbit Animation (x0={x0}, vy0={vy0})')
    ax.grid(True)

    # Plot the complete orbit path as a faint background guide
    ax.plot(xs, ys, 'b-', alpha=0.3)

    # Central body (sun)
    ax.plot(0, 0, 'ro', markersize=10)

    # Planet and trail; their data is filled in frame by frame
    line, = ax.plot([], [], 'b-')
    point, = ax.plot([], [], 'bo', markersize=6)

    # Speed indicator text, positioned in axes coordinates
    speed_text = ax.text(0.02, 0.95, '', transform=ax.transAxes)

    # Fix the layout before any frame is rendered so a saved GIF uses it too
    fig.tight_layout()

    def init():
        line.set_data([], [])
        point.set_data([], [])
        speed_text.set_text('')
        return line, point, speed_text

    def update(frame):
        # The trail ends on the current sample so it stays attached to the planet
        start = max(0, frame + 1 - trail_length)
        line.set_data(xs[start:frame + 1], ys[start:frame + 1])

        # Line2D.set_data needs sequences; recent Matplotlib rejects bare scalars
        point.set_data([xs[frame]], [ys[frame]])

        speed_text.set_text(f'Speed: {speeds[frame]:.2f} AU/year')
        return line, point, speed_text

    # frames=len(positions) yields frame indices 0 .. len(positions) - 1,
    # so update() never receives an out-of-range index
    ani = FuncAnimation(fig, update, frames=len(positions),
                        init_func=init, blit=True, interval=20)

    if save:
        ani.save('elliptical_orbit.gif', writer='pillow', fps=30)

    plt.show()

    return ani

# Uncomment to create an animation of one of the orbits
# animate_orbit(x0=2.0, vy0=4.0, save=True)

# Calculate and print orbital characteristics for each case.
# The rule width is derived from the header, and each row label is padded to
# the header's first-column width, so the table columns line up.
summary_header = (
    f"{'Initial Conditions':<20} {'Eccentricity':<15} {'Period (years)':<15} "
    f"{'Perihelion':<15} {'Aphelion':<15}"
)
summary_rule = "-" * len(summary_header)

print("Orbital Characteristics Summary:")
print(summary_rule)
print(summary_header)
print(summary_rule)

for case in test_cases:
    x0, vy0 = case['x0'], case['vy0']
    t, positions, velocities, speeds = simulate_orbit(
        x0, vy0, algorithm=RungeKutta, dt=0.01, t_max=10.0
    )

    # Closest and farthest sampled distances from the central body.
    # NOTE: these only match the true perihelion/aphelion if the simulated
    # span covers at least one full orbit.
    r = np.sqrt(positions[:, 0]**2 + positions[:, 1]**2)
    perihelion = np.min(r)
    aphelion = np.max(r)

    # Eccentricity from the two apsides: e = (r_max - r_min) / (r_max + r_min)
    eccentricity = (aphelion - perihelion) / (aphelion + perihelion)

    # Estimate the period as the time between the first two strict local
    # minima of r (consecutive perihelion passages). Endpoints are excluded
    # because they have only one neighbour.
    interior = r[1:-1]
    is_local_min = (interior < r[:-2]) & (interior < r[2:])
    min_indices = np.flatnonzero(is_local_min)[:2] + 1  # +1: back to indices into r

    if len(min_indices) >= 2:
        period = t[min_indices[1]] - t[min_indices[0]]
    else:
        period = np.nan  # Not enough data to estimate period

    row_label = f"x0={x0}, vy0={vy0:.2f}"
    print(f"{row_label:<20} {eccentricity:<15.4f} {period:<15.4f} {perihelion:<15.4f} {aphelion:<15.4f}")

# Analytical verification: print the Kepler's-law statements the runs are checked against
for banner_line in (
    "\nAnalytical Verification (Kepler's Laws):",
    "-" * 60,
    "For an elliptical orbit with the sun at one focus:",
    "1. The orbit is an ellipse with eccentricity e",
    "2. The speed is highest at perihelion and lowest at aphelion",
    "3. T² ∝ a³ (where T is the period and a is the semi-major axis)",
    "-" * 60,
):
    print(banner_line)

# Semi-major axis and the period Kepler's third law predicts for each case
for case in test_cases:
    x0 = case['x0']
    vy0 = case['vy0']
    t, positions, velocities, speeds = simulate_orbit(
        x0, vy0, algorithm=RungeKutta, dt=0.01, t_max=10.0
    )

    # Smallest and largest sampled distance from the central body
    r = np.sqrt(positions[:, 0]**2 + positions[:, 1]**2)
    perihelion, aphelion = r.min(), r.max()

    # The semi-major axis is the mean of the two extreme distances
    a = 0.5 * (perihelion + aphelion)

    # Third law: T² = (4π²/GM) * a³; with GM = 4π² in these units, T² = a³
    expected_period = np.sqrt(a**3)

    print(f"x0={x0}, vy0={vy0:<10.2f} Semi-major axis: {a:<10.4f} Expected period: {expected_period:<10.4f}")

Orbital Characteristics Summary:
------------------------------------------------------------
Initial Conditions   Eccentricity    Period (years)  Perihelion      Aphelion       
------------------------------------------------------------
x0=1.0, vy0=6.28       0.0010          1.0000          0.9980          1.0000         
x0=1.5, vy0=5.00       0.0501          1.7100          1.3568          1.5000         
x0=2.0, vy0=4.00       0.1894          2.1800          1.3630          2.0000         
x0=2.5, vy0=3.50       0.2243          2.9200          1.5841          2.5000         

Analytical Verification (Kepler's Laws):
------------------------------------------------------------
For an elliptical orbit with the sun at one focus:
1. The orbit is an ellipse with eccentricity e
2. The speed is highest at perihelion and lowest at aphelion
3. T² ∝ a³ (where T is the period and a is the semi-major axis)
------------------------------------------------------------
x0=1.0, vy0=6.28       Se