In [6]:
import numpy as np
from spacerocks.cbindings import calc_E_from_M, calc_M_from_E, calc_f_from_E
from spacerocks.constants import mu_bary
from spacerocks.units import Units
from spacerocks.observer import Observer

# spacerocks unit configuration; the time scale is switched to UTC.
units = Units()
units.timescale = 'utc'

#from spacerocks import unit
# Gravitational parameter used by every orbit calculation in this notebook
# (the bare numeric value of spacerocks' `mu_bary` constant).
mu = mu_bary.value
from spacerocks import SpaceRock

import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

from scipy.optimize import minimize_scalar, fsolve, newton
from scipy.spatial import KDTree
import math
from astropy import units as u
import random
import itertools
import sympy as sp
from sympy import re, im

# Arcseconds per radian: 180 * 3600 / pi.
rads_to_arcsec = 206264.806247

# PlanarBody Class, which calculates all relevant orbital elements for a given orbit

In [7]:
def nice_arccos(x):
    """Arccosine that tolerates arguments slightly outside [-1, 1].

    The input is clamped to the closed interval [-1, 1] before ``np.arccos``
    is applied, so small round-off excursions give 0 or pi instead of NaN.
    """
    clamped = np.clip(x, -1, 1)
    return np.arccos(clamped)

class PlanarBody:
    """Two-body (Kepler) orbit described inside its own orbital plane.

    The orbit is fixed by the state at a reference epoch ``t0``: the distance
    ``r`` from the central mass, the radial speed ``vr`` (negative while the
    body is approaching) and the transverse speed ``vo``.  The module-level
    gravitational parameter ``mu`` is used throughout, so every input must be
    expressed in units consistent with it.

    Attributes set by ``__init__``:
        r0, vr0, vo0, t0: the constructor arguments, stored unchanged.
        incoming: True when ``vr0 < 0``.
        a: semi-major axis from the vis-viva relation (negative for
            hyperbolic orbits, infinite for an exactly parabolic one).
        e: eccentricity.
        parabolic: True when ``|e - 1| < 1e-8``.
        argument: raw value of ``(1 - r0/a) / e`` (cos E0 or cosh E0) before
            any round-off clamp is applied.
        E0, M0: eccentric/hyperbolic anomaly and mean anomaly at ``t0``.
        f0: true anomaly at ``t0``; only set for non-parabolic orbits.
    """

    # Significant digits carried through the sympy-based parabolic branch.
    _PRECISION = 50

    def __init__(self, r, vr, vo, t0):
        self.r0 = r
        self.vr0 = vr
        self.vo0 = vo
        self.t0 = t0
        self.incoming = self.vr0 < 0

        vsq = vr**2 + vo**2
        # Vis-viva: 1/a = 2/r - v^2/mu.  An exactly parabolic state makes the
        # right-hand side zero, so guard the reciprocal instead of dividing by 0.
        inv_a = (2 / r) - (vsq / mu)
        self.a = 1 / inv_a if inv_a != 0 else np.inf
        self.e = np.sqrt(1.0 + r**2 * vo**2 * ((vsq / mu**2) - 2.0 / (mu * r)))
        self.parabolic = abs(self.e - 1) < 1e-8

        argument = (1.0 - self.r0 / self.a) / self.e
        self.argument = argument

        if self.e < 1:
            # Elliptical orbit: argument is cos(E0); pick the branch of E0
            # that matches the sign of the radial velocity.
            E0 = nice_arccos(argument)
            if self.incoming:
                E0 = 2 * np.pi - E0
        else:
            # Hyperbolic orbit: argument is cosh(E0) and must be >= 1.  For
            # very nearly parabolic orbits round-off can leave it at ~0.9999,
            # so clamp it to keep arccosh real.
            if argument < 1 and not self.parabolic:
                argument = 1.0
            E0 = np.arccosh(argument)
            if self.incoming:
                E0 = -E0

        self.E0 = E0
        self.M0 = calc_M_from_E(np.atleast_1d(self.e), np.atleast_1d(self.E0)).rad
        if not self.parabolic:
            self.f0 = self.f(self.t0)

    @property
    def n(self):
        ''' Find the mean motion'''
        return np.sqrt(mu / abs(self.a)**3)

    @classmethod
    def _hp(cls, value):
        """Convert ``value`` to a sympy Float carrying ``_PRECISION`` digits."""
        return sp.Float(str(value), cls._PRECISION)

    @staticmethod
    def _barker_true_anomaly(Mp):
        """Closed-form true anomaly of a parabolic orbit for the scaled time ``Mp``.

        The exponents are exact rationals: a Python float ``1/3`` would cap
        the whole evaluation at double precision and defeat the purpose of
        the high-precision branch.
        """
        third = sp.Rational(1, 3)
        w = 3*Mp + sp.sqrt((3*Mp)**2 + 1)
        return 2*sp.atan(w**third - w**(-third))

    def _f_parabolic_hp(self, t):
        """Signed parabolic true anomaly at ``t`` as a high-precision sympy Float."""
        prec = self._PRECISION
        Tp, tsym, r0, vo0, muu, Mp = sp.symbols('Tp tsym r0 vo0 muu Mp')

        Mp_expr = (muu**2 / (r0**3 * vo0**3)) * (tsym - Tp)

        if isinstance(t, np.ndarray):
            # Only the first element of an array argument is used.
            t_val = self._hp(t[0])
        else:
            t_val = self._hp(t)

        Mp_val = Mp_expr.subs({
            tsym: t_val,
            Tp: self._hp(self.T()),
            r0: self._hp(self.r0),
            vo0: self._hp(self.vo0),
            muu: self._hp(mu),
        }).evalf(prec)

        return self._barker_true_anomaly(Mp).subs({Mp: Mp_val}).evalf(prec)

    def r(self, t):
        ''' Find the distance to the body'''
        if self.parabolic:
            # Orbit equation r = (h^2 / mu) / (1 + cos f).  For some parabolic
            # orbits f is very close to pi, where 1 + cos f loses almost all
            # significant digits in double precision, so the whole chain is
            # evaluated with sympy.  The true anomaly is passed on at full
            # precision (not rounded to a double first); cos is even, so its
            # sign is irrelevant here.
            prec = self._PRECISION
            h, muu, f, r0, vo0 = sp.symbols('h muu f r0 vo0')

            orbit_eq = (h**2 / muu) * (1 / (1 + sp.cos(f)))
            h_val = (r0 * vo0).subs({r0: self._hp(self.r0), vo0: self._hp(self.vo0)}).evalf(prec)

            result = orbit_eq.subs({
                h: h_val,
                muu: self._hp(mu),
                f: self._f_parabolic_hp(t),
            }).evalf(prec)

            return float(result)

        elif self.e < 1:
            # Orbit equation for elliptical orbits
            return abs(self.a * (1 - self.e * np.cos(self.E(t))))
        else:
            # Orbit equation for hyperbolic orbits
            return abs(self.a * (self.e * np.cosh(self.E(t)) - 1))

    def E(self, t):
        ''' Find the eccentric anomaly for the body'''
        dt = t - self.t0
        M = self.M0 + self.n * dt
        return calc_E_from_M(np.atleast_1d(self.e), np.atleast_1d(M)).rad

    def f(self, t):
        ''' Find the true anomaly'''
        if self.parabolic:
            # Parabolic orbit: magnitude of the true anomaly as a Python float.
            return abs(float(self._f_parabolic_hp(t)))

        elif self.e < 1:
            # This is an elliptical orbit
            return calc_f_from_E(np.atleast_1d(self.e), np.atleast_1d(self.E(t))).rad
        else:
            # This is a hyperbolic orbit; the sign is flipped for incoming bodies.
            if self.incoming:
                return -calc_f_from_E(np.atleast_1d(self.e), np.atleast_1d(self.E(t))).rad
            else:
                return calc_f_from_E(np.atleast_1d(self.e), np.atleast_1d(self.E(t))).rad

    def longitude(self, t):
        ''' Find the longitude from the true anomaly'''
        # NOTE(review): self.f0 is only assigned for non-parabolic orbits, so
        # calling this on a parabolic body raises AttributeError.
        if self.e < 1:
            # This is an elliptical orbit
            return (self.f(t) - self.f0) % (2 * np.pi)
        else:
            # This is a hyperbolic orbit
            if self.incoming:
                return -(self.f(t) - self.f0)
            else:
                return (self.f(t) - self.f0) % (2 * np.pi)

## Section for handling parabolic orbits ##

    def f0_parabolic(self):
        ''' Find f0 for parabolic orbits'''
        # From the orbit equation: cos f0 = h^2 / (r0 * mu) - 1, with h = r0 * vo0.
        prec = self._PRECISION

        h, muu, r0, vo = sp.symbols('h muu r0 vo')
        expr = sp.acos((h**2 / (r0 * muu)) - 1)

        rval = self._hp(self.r0)
        h_val = (r0 * vo).subs({r0: rval, vo: self._hp(self.vo0)}).evalf(prec)

        result = expr.subs({h: h_val, muu: self._hp(mu), r0: rval}).evalf(prec)
        try:
            finalres = abs(float(result))
        except TypeError:
            # A cosine marginally above 1 makes acos complex; the magnitude
            # of the imaginary part is used in that case.
            finalres = abs(float(im(result)))

        return finalres

    def T(self):
        ''' Find the time of periapsis passage (same time axis as t0)'''
        prec = self._PRECISION
        Tp, tsym, r0, vo0, f0, muu = sp.symbols('Tp tsym r0 vo0 f0 muu')

        # Solve f0 = f_parabolic(t0 - Tp) for Tp.  f0 is non-negative, so the
        # root always lies at or before t0 (the outgoing-body solution).
        Mp = muu**2 / (r0**3 * vo0**3) * (tsym - Tp)
        expr = f0 - self._barker_true_anomaly(Mp)

        eq = expr.subs({
            tsym: self._hp(self.t0),
            r0: self._hp(self.r0),
            vo0: self._hp(self.vo0),
            f0: self._hp(self.f0_parabolic()),
            muu: self._hp(mu),
        })

        T_initial_guess = 0
        past_periapsis = float(sp.nsolve(eq, Tp, T_initial_guess).evalf(prec))

        if self.vr0 < 0:
            # Incoming body: periapsis is still ahead, i.e. the mirror image
            # of the root about t0.  (Mirroring about 0 is only right for
            # t0 == 0 and otherwise breaks |f(t0)| == f0.)
            return float(2 * self.t0 - past_periapsis)
        else:
            return past_periapsis


        



# Rectangle Class (should be hypercube or tesseract), which defines the grid which is refined via an Adaptive Mesh Refinement method

In [11]:
class Rectangle:
    """One 4-D cell (tesseract) of the adaptively refined grid.

    ``vertices`` is the list of the cell's corner points, each a tuple
    ``(vr, vo, r, i)``.  The ``need_*_subdivision`` flags are filled in by
    ``calculate_condition`` and tell the caller along which axes the cell
    still has to be split.
    """

    def __init__(self, vertices):

        self.vertices = vertices
        self.initial_tesseract = False

        self.need_vr_subdivision = False
        self.need_vo_subdivision = False
        self.need_r_subdivision = False
        self.need_i_subdivision = False

    def check_location(self, r, i):
        ''' Check whether the rectangle can be left alone because it lies
        entirely outside the boundary of bound orbits.

        Returns False for the initial tesseract (which is also flagged via
        ``self.initial_tesseract``) and for any cell with at least one vertex
        whose speed is below sqrt(2*mu / r); returns True otherwise, in which
        case the cell does not need to be subdivided further.'''
        initial_vertices = generate_initial_tesseract_vertices(r, i)

        if set(self.vertices) == set(initial_vertices):
            self.initial_tesseract = True
            return False

        for vertex in self.vertices:
            if np.sqrt(vertex[0]**2 + vertex[1]**2) < np.sqrt(2*mu / vertex[2]):
                return False

        return True

    def calculate_condition(self, tmax, epsilon):
        """Set the ``need_*_subdivision`` flags from the edge separations.

        An axis is flagged when any edge running along it has a separation
        (first entry of the records from ``get_edge_values``) above ``epsilon``.
        """
        values = self.get_edge_values(tmax)

        def exceeds(axis):
            return any(value[0] > epsilon for value in values[axis])

        self.need_vr_subdivision = exceeds('vr')
        self.need_vo_subdivision = exceeds('vo')
        self.need_r_subdivision = exceeds('r')
        self.need_i_subdivision = exceeds('i')

        if self.initial_tesseract:
            # Ensure that we subdivide in vr if we are checking the initial tesseract
            self.need_vr_subdivision = True

        return

    def get_edge_values(self, tmax):
        """Evaluate ``calc_max_sep_4D`` along every edge of the cell.

        Returns a dict keyed by axis name ('vr', 'vo', 'r', 'i'); each value
        is a list of ``[separation, (vertex_a, vertex_b)]`` records, one per
        edge parallel to that axis.
        """
        edges = identify_edges(self.vertices)

        values = {}
        for axis in ('vr', 'vo', 'r', 'i'):
            records = []
            for start, end in edges[axis]:
                separation = calc_max_sep_4D(start[0], start[1], start[2], start[3],
                                             end[0], end[1], end[2], end[3], tmax)
                records.append([separation, (start, end)])
            values[axis] = records

        return values

    def can_subdivide(self, epsilon):
        """True when at least one axis is flagged (``epsilon`` is unused)."""
        return self.need_vr_subdivision or self.need_vo_subdivision or self.need_r_subdivision or self.need_i_subdivision

    def subdivide(self, vr = False, vo = False, r = False, i = False):
        """Split the cell in half along every requested axis.

        The refinement is staged: as long as any of ``vr``, ``vo`` or ``r``
        is requested, the ``i`` request is ignored and the cell is only split
        along those three axes.  The cell is split along ``i`` only when
        ``i`` is the sole axis requested.

        Returns a list of new ``Rectangle`` objects.  When no axis is
        requested the list holds a single cell with the same bounds.
        """
        lows = [min(vertex[axis] for vertex in self.vertices) for axis in range(4)]
        highs = [max(vertex[axis] for vertex in self.vertices) for axis in range(4)]

        split_vr, split_vo, split_r = bool(vr), bool(vo), bool(r)
        # i is handled last: only once nothing else needs refinement.
        split_i = bool(i) and not (split_vr or split_vo or split_r)

        cells = generate_subdivided_tesseracts(
            (lows[0], highs[0]), (lows[1], highs[1]), (lows[2], highs[2]), (lows[3], highs[3]),
            vr = split_vr, vo = split_vo, r = split_r, i = split_i)

        return [Rectangle(cell) for cell in cells]


# Relevant functions used in the grid determination

In [12]:
def generate_subdivided_tesseracts(vr_range, vo_range, r_range, i_range, vr=True, vo=True, r=True, i=True):
    """Split a tesseract in half along each selected axis.

    Each ``*_range`` is a ``(low, high)`` pair; the matching boolean says
    whether that axis is bisected at its midpoint or kept whole.  The result
    is a list of child tesseracts, each given as a list of its 16 corner
    tuples ``(vr, vo, r, i)``.  Children are ordered with ``vr`` varying
    slowest and ``i`` fastest, and so are the corners inside each child.
    """
    def _intervals(bounds, split):
        low, high = bounds[0], bounds[1]
        middle = (low + high) / 2
        if split:
            return [(low, middle), (middle, high)]
        return [(low, high)]

    axis_intervals = (
        _intervals(vr_range, vr),
        _intervals(vo_range, vo),
        _intervals(r_range, r),
        _intervals(i_range, i),
    )

    # One child per combination of sub-intervals; its corners are the
    # Cartesian product of the four (low, high) pairs.
    return [list(itertools.product(*cell)) for cell in itertools.product(*axis_intervals)]


def generate_initial_tesseract_vertices(r, i):
    """Corner points of the starting tesseract for the given r and i ranges.

    ``r`` and ``i`` are ``(min, max)`` pairs.  The velocity extent is set by
    sqrt(2*mu / r_min): ``vr`` spans minus to plus that value and ``vo``
    spans 1e-4 up to it.  Returns the 16 corners as ``(vr, vo, r, i)`` tuples.
    """
    r_low, r_high = r
    i_low, i_high = i

    v_limit = np.sqrt(2*mu / r_low)

    axis_bounds = (
        (-v_limit, v_limit),   # vr
        (1e-4, v_limit),       # vo
        (r_low, r_high),       # r
        (i_low, i_high),       # i
    )
    return [corner for corner in itertools.product(*axis_bounds)]

def identify_edges(vertices):
    """Group the edges of a vertex set by the axis they run along.

    Two vertices form an edge when exactly one of their first four
    coordinates differs.  Returns a dict with keys 'vr', 'vo', 'r' and 'i'
    (coordinate positions 0-3) mapping to lists of ``(vertex_a, vertex_b)``.
    """
    axis_names = ('vr', 'vo', 'r', 'i')
    edges = {name: [] for name in axis_names}

    for first, second in itertools.combinations(vertices, 2):
        differing = [axis for axis in range(4) if first[axis] != second[axis]]
        if len(differing) == 1:
            edges[axis_names[differing[0]]].append((first, second))

    return edges

def stumpff_C(z):
    """Stumpff function C(z) for a scalar ``z``.

    C(z) = (1 - cos(sqrt(z))) / z for z > 0, (cosh(sqrt(-z)) - 1) / (-z)
    for z < 0, and 1/2 at z = 0.

    The half-angle identities 1 - cos(x) = 2 sin^2(x/2) and
    cosh(x) - 1 = 2 sinh^2(x/2) are used so that no nearly equal numbers are
    subtracted.  The textbook form loses all significance for small |z|
    (e.g. it returns 0 instead of 1/2 at z = 1e-16), which is exactly the
    regime reached by nearly parabolic orbits.
    """
    if z == 0:
        return 1/2
    elif z > 0:
        half_angle = np.sqrt(z) / 2
        ratio = np.sin(half_angle) / half_angle
    else:
        half_angle = np.sqrt(-z) / 2
        ratio = np.sinh(half_angle) / half_angle
    # C(z) = (1/2) * (sin(x/2) / (x/2))**2, with sinh for negative z.
    return 0.5 * ratio**2

def stumpff_S(z):
    """Stumpff function S(z) for a scalar ``z``.

    S(z) = (sqrt(z) - sin(sqrt(z))) / sqrt(z)**3 for z > 0,
    (sinh(sqrt(-z)) - sqrt(-z)) / sqrt(-z)**3 for z < 0, and 1/6 at z = 0.

    For |z| < 0.1 the closed forms subtract nearly equal numbers (at
    z = 1e-16 they return 0 instead of 1/6), so the power series
    S(z) = sum_k (-z)**k / (2k + 3)! is used there.  Six terms leave a
    truncation error below 1e-18 on that interval.
    """
    if z == 0:
        return 1/6
    if abs(z) < 0.1:
        # 1/3!, 1/5!, 1/7!, 1/9!, 1/11!, 1/13!
        coefficients = (1/6, 1/120, 1/5040, 1/362880, 1/39916800, 1/6227020800)
        total = 0.0
        for coefficient in reversed(coefficients):
            # Horner evaluation of the alternating series in z.
            total = coefficient - z * total
        return total
    elif z > 0:
        root = np.sqrt(z)
        return (root - np.sin(root)) / root**3
    else:
        root = np.sqrt(-z)
        return (np.sinh(root) - root) / root**3
    

def universal_kepler(chi, r_0, v_r0, alpha, delta_t, mu):
    """Residual of the universal Kepler equation at universal anomaly ``chi``.

    Meant to be passed to a root finder (e.g. Newton's method): the value is
    zero when ``chi`` corresponds to the elapsed time ``delta_t``.  ``r_0``
    and ``v_r0`` are the distance and radial speed at the reference epoch,
    ``alpha`` the reciprocal semi-major axis and ``mu`` the gravitational
    parameter.
    """
    root_mu = np.sqrt(mu)
    z = alpha * chi**2

    radial_part = r_0 * v_r0 / root_mu * chi**2 * stumpff_C(z)
    curvature_part = (1 - alpha * r_0) * chi**3 * stumpff_S(z)
    linear_part = r_0 * chi
    time_part = root_mu * delta_t

    return radial_part + curvature_part + linear_part - time_part

def d_universal_d_chi(chi, r_0, v_r0, alpha, delta_t, mu):
    """Derivative of ``universal_kepler`` with respect to ``chi``.

    Takes the same arguments as ``universal_kepler`` so both can share one
    ``args`` tuple in a Newton solver; ``delta_t`` does not enter the result.
    """
    z = alpha * chi**2

    radial_part = r_0 * v_r0 / np.sqrt(mu) * chi * (1 - z * stumpff_S(z))
    curvature_part = (1 - alpha * r_0) * chi**2 * stumpff_C(z)

    return radial_part + curvature_part + r_0

def lagrange_f(pb, chi, alpha, t):
    """Lagrange coefficient f in universal variables: 1 - chi^2 / r0 * C(alpha * chi^2).

    ``pb`` only needs an ``r0`` attribute; ``t`` is accepted for symmetry
    with ``lagrange_g`` and is not used.
    """
    return 1 - chi**2/pb.r0 * stumpff_C(chi**2 * alpha)

def lagrange_g(pb, chi, alpha, t):
    """Lagrange coefficient g in universal variables: t - chi^3 / sqrt(mu) * S(alpha * chi^2).

    ``t`` is the elapsed time since the reference epoch; ``pb`` is accepted
    for symmetry with ``lagrange_f`` and is not used.
    """
    return t - chi**3 / np.sqrt(mu) * stumpff_S(chi**2 * alpha)

def theta_function(t, psi1, psi2, pb1, pb2):
    ''' Cosine of the angle between the position vectors of two bodies at time t.

    Both bodies are propagated from their own reference epoch with the
    universal-variable Lagrange coefficients, and the dot product r1 . r2 is
    divided by the product of the two distances.  Minimising this value over
    t gives the maximum angular separation; the separation itself is the
    arccos of the returned value.

    psi1 and psi2 only enter through cos(psi1 - psi2) in the product of the
    two transverse velocities. '''

    alpha1 = 0 if pb1.parabolic else 1 / pb1.a
    alpha2 = 0 if pb2.parabolic else 1 / pb2.a

    # Each body is propagated over the time elapsed since ITS OWN epoch.
    dt1 = t - pb1.t0
    dt2 = t - pb2.t0

    root_mu = np.sqrt(mu)

    # Standard starting value for the universal anomaly: sqrt(mu) * |alpha| * dt.
    chi1 = newton(
        func=universal_kepler,
        fprime=d_universal_d_chi,
        x0=root_mu * np.abs(alpha1) * dt1,
        args=(pb1.r0, pb1.vr0, alpha1, dt1, mu))

    chi2 = newton(
        func=universal_kepler,
        fprime=d_universal_d_chi,
        x0=root_mu * np.abs(alpha2) * dt2,
        args=(pb2.r0, pb2.vr0, alpha2, dt2, mu))

    f1 = lagrange_f(pb1, chi1, alpha1, dt1)
    g1 = lagrange_g(pb1, chi1, alpha1, dt1)
    f2 = lagrange_f(pb2, chi2, alpha2, dt2)
    g2 = lagrange_g(pb2, chi2, alpha2, dt2)

    # r1 . r2 with r_k = f_k * r_k0 + g_k * v_k0, expanded term by term.
    arg1 = f1 * f2 * pb1.r0 * pb2.r0
    arg2 = f1 * g2 * pb1.r0 * pb2.vr0
    arg3 = g1 * f2 * pb2.r0 * pb1.vr0
    arg4 = g1 * g2 * (pb1.vr0 * pb2.vr0 + pb1.vo0 * pb2.vo0 * np.cos(psi1 - psi2))

    theta_arg = (arg1 + arg2 + arg3 + arg4) / (pb1.r(t) * pb2.r(t))

    return theta_arg


def calc_max_sep_4D(vr1, vo1, r1, psi1, vr2, vo2, r2, psi2, tmax):
    ''' Maximum angular separation between two grid points over 0 <= t <= tmax.

    Each point (vr, vo, r, psi) defines an orbit whose reference epoch is
    tmax/2, NOT 0.  The cosine of the separation (theta_function) is
    minimised over the time interval and converted with rads_to_arcsec.

    Returns 0 when the two points describe the same orbit (psi equal or
    differing by exactly 2*pi) or when the minimised cosine exceeds 1. '''
    same_state = vr1 == vr2 and vo1 == vo2 and r1 == r2
    if same_state and (psi1 == psi2 or psi1 == psi2 - 2*np.pi or psi1 == psi2 + 2*np.pi):
        # Nothing to minimise; also avoids building the two orbits at all.
        return 0

    pb1 = PlanarBody(r1, vr1, vo1, tmax/2)
    pb2 = PlanarBody(r2, vr2, vo2, tmax/2)

    bounds = (0, tmax)
    result = minimize_scalar(theta_function, args=(psi1, psi2, pb1, pb2), bounds=bounds, method = 'bounded')

    if result.fun > 1:
        return 0

    # Round-off can also push the cosine marginally below -1; clamp it so
    # arccos yields pi rather than NaN (a NaN would compare as "not greater
    # than epsilon" and silently suppress a needed subdivision).
    return np.arccos(np.maximum(result.fun, -1)) * rads_to_arcsec



def recursive_amr_grid(tmax, r, i, epsilon):
    ''' Entry point for building the AMR grid.

    Starts from the initial tesseract for the given r and i ranges, refines
    it with recursively_subdivide and returns a 4-tuple:
    the final rectangles, the vertex list of each rectangle, the list of
    distinct vertices, and those distinct vertices satisfying
    vr**2 + vo**2 <= 2*mu / r. '''
    root = Rectangle(generate_initial_tesseract_vertices(r, i))
    print('Initial rectangle:')
    print(root.vertices)

    result_rectangles = recursively_subdivide(root, r, i, tmax, epsilon)
    result_vertices = [rect.vertices for rect in result_rectangles]

    # Pool every corner of every cell, then drop duplicates shared by neighbours.
    pooled = []
    for cell_vertices in result_vertices:
        pooled.extend(cell_vertices)
    unique_list = list(set(pooled))

    points_inside = []
    for point in unique_list:
        if point[0]**2 + point[1]**2 <= (2*mu / point[2]):
            points_inside.append(point)

    return result_rectangles, result_vertices, unique_list, points_inside

def recursively_subdivide(rectangle, r, i, tmax, epsilon):
    ''' Function which does the work of recursively subdividing the tesseracts.

    Stage 1 refines breadth-first in (vr, vo, r): a cell is finished when
    check_location says it needs no refinement or when none of its vr/vo/r
    flags is set; otherwise it is replaced by its children.
    Stage 2 splits every finished cell along i: always once, then as many
    further times as it takes for the first child of the chain to stop
    requesting an i split.  All cells from stage 2 are returned. '''
    # Local import keeps this function self-contained.
    from collections import deque

    result = []
    print('Recursively Subdividing')
    # deque gives O(1) pops from the front; list.pop(0) is O(n) per pop.
    queue = deque([rectangle])

    while queue:
        print(f'There are currently {len(queue)} rectangles in the queue')
        current_rectangle = queue.popleft()

        if current_rectangle.check_location(r, i):
            # Lies entirely outside of the boundary: keep as is.
            result.append(current_rectangle)
            continue

        current_rectangle.calculate_condition(tmax, epsilon)

        needs_3d_split = (current_rectangle.need_vr_subdivision
                          or current_rectangle.need_vo_subdivision
                          or current_rectangle.need_r_subdivision)

        if not needs_3d_split:
            # Either fully resolved or only i is outstanding; i is handled in stage 2.
            result.append(current_rectangle)
        elif current_rectangle.can_subdivide(epsilon):
            queue.extend(current_rectangle.subdivide(
                vr = current_rectangle.need_vr_subdivision,
                vo = current_rectangle.need_vo_subdivision,
                r = current_rectangle.need_r_subdivision,
                i = current_rectangle.need_i_subdivision))

    all_rects = []

    # Loop variables are deliberately not called `i`: that name is the
    # i-range parameter and used to be overwritten by two nested indices.
    for position, coarse_rect in enumerate(result):
        # Calculate how many subdivisions are needed in the i direction for each tesseract,
        # make this many subdivisions, and add the resulting tesseracts to the list of all tesseracts
        print('Rectangles left in result: ', len(result) - position)

        first_half, second_half = coarse_rect.subdivide(i = True)

        # The probe is a separate object so that evaluating its flags does
        # not touch the rectangles that end up in the output.
        probe = coarse_rect.subdivide(i = True)[0]
        probe.calculate_condition(tmax, epsilon)
        number_of_subdivisions = 0
        while probe.need_i_subdivision:
            probe = probe.subdivide(i = True)[0]
            number_of_subdivisions += 1
            probe.calculate_condition(tmax, epsilon)

        rects = [first_half, second_half]
        for _ in range(number_of_subdivisions):
            rects = [child for rect in rects for child in rect.subdivide(i = True)]

        all_rects.extend(rects)

    return all_rects


def generate_random_4d_points(vr_range, vo_range, r_range, psi_range, num_points):
    ''' Draw random (vr, vo, r, psi) points for testing the grid.

    Each coordinate is sampled uniformly from its (low, high) range; a draw
    is kept only if sqrt(vr**2 + vo**2) < sqrt(2*mu / r).  Sampling repeats
    until num_points draws have been accepted. '''
    accepted = []
    while len(accepted) < num_points:
        # Coordinates are drawn in the fixed order vr, vo, r, psi.
        candidate = (
            random.uniform(*vr_range),
            random.uniform(*vo_range),
            random.uniform(*r_range),
            random.uniform(*psi_range),
        )
        speed = math.sqrt(candidate[0]**2 + candidate[1]**2)
        if speed < math.sqrt(2 * mu / candidate[2]):
            accepted.append(candidate)

    return accepted

def test_amr_grid(tmax, r, epsilon, npoints, grid_points):
    ''' Function to test the amr grid. This needs some work, since we aren't necessarily interested in checking nearest neighbor 
    in euclidean space'''
    
    
    vr_range = (-np.sqrt(2*mu / r[0]), np.sqrt(2*mu / r[0]))
    vo_range = (0.0001, np.sqrt(2*mu / r[0]))
    psi_range = (0, 2*np.pi)

    
    test_points = generate_random_4d_points(vr_range, vo_range, r, psi_range, npoints)

    #test_points = generate_random_4d_points(npoints, r)
    kd_tree = KDTree(grid_points)

    num_neighbors = 26
    nearest_neighbor_indices = kd_tree.query(test_points, k = num_neighbors)[1]

    # Output the indices of nearest neighbors for each random point
    epsilon_vals = []
    for i, nearest_neighbors in enumerate(nearest_neighbor_indices):
        print(i)
        min_sep = 1e10
        for neighbor in nearest_neighbors:
            #sep = calc_max_sep(r, tmax, test_points[i][0], test_points[i][1], grid_points[neighbor][0], grid_points[neighbor][1])
            sep = calc_max_sep_4D(test_points[i][0], test_points[i][1], test_points[i][2], test_points[i][3], grid_points[neighbor][0], grid_points[neighbor][1], grid_points[neighbor][2], grid_points[neighbor][3], tmax)[0]
            if sep < min_sep:
                min_sep = sep
        epsilon_vals.append(min_sep)
    
        
    return epsilon_vals

# Below is the cell which makes the grid for a given set of Initial Conditions

In [16]:
# Length of the time window searched for the maximum separation; orbits are
# referenced to its midpoint. 4/24 is presumably 4 hours in days — TODO confirm units.
tmax = 4/24
# (min, max) distance range of the grid — presumably in au; verify against mu's units.
r = (35,1000)
# Refinement threshold compared against calc_max_sep_4D, whose result is scaled by rads_to_arcsec.
epsilon = 1
# (min, max) of the fourth grid coordinate (the angle called psi elsewhere): one full turn.
i = (0,2*np.pi)
result_rectangles, result_vertices, all_points, points_inside = recursive_amr_grid(tmax, r, i, epsilon)

Initial rectangle:
[(-0.004114846117133663, 0.0001, 35, 0), (-0.004114846117133663, 0.0001, 35, 6.283185307179586), (-0.004114846117133663, 0.0001, 1000, 0), (-0.004114846117133663, 0.0001, 1000, 6.283185307179586), (-0.004114846117133663, 0.004114846117133663, 35, 0), (-0.004114846117133663, 0.004114846117133663, 35, 6.283185307179586), (-0.004114846117133663, 0.004114846117133663, 1000, 0), (-0.004114846117133663, 0.004114846117133663, 1000, 6.283185307179586), (0.004114846117133663, 0.0001, 35, 0), (0.004114846117133663, 0.0001, 35, 6.283185307179586), (0.004114846117133663, 0.0001, 1000, 0), (0.004114846117133663, 0.0001, 1000, 6.283185307179586), (0.004114846117133663, 0.004114846117133663, 35, 0), (0.004114846117133663, 0.004114846117133663, 35, 6.283185307179586), (0.004114846117133663, 0.004114846117133663, 1000, 0), (0.004114846117133663, 0.004114846117133663, 1000, 6.283185307179586)]
Recursively Subdividing
There are currently 1 rectangles in the queue
There are currently 8 