In [2]:
import os
# Expose only GPU 0 to CUDA. This is done in the very first cell, ahead of the
# numpyro/jax imports below -- presumably so the setting is already in place
# when those libraries initialise their GPU backend.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

from typing import NamedTuple

import numpyro 
import numpyro.distributions as dist

In [None]:
from math import pi

from typing import Optional, Tuple, Callable, Union, List
from functools import partial

import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
from jax import vmap, jit, pmap
from jax import random

# jax.config.update("jax_enable_x64", True)

import numpy as np
from astropy import units as u
from astropy import constants as c

import odisseo
from odisseo import construct_initial_state
from odisseo.integrators import leapfrog
from odisseo.dynamics import direct_acc, DIRECT_ACC, DIRECT_ACC_LAXMAP, DIRECT_ACC_FOR_LOOP, DIRECT_ACC_MATRIX
from odisseo.option_classes import SimulationConfig, SimulationParams, MNParams, NFWParams, PlummerParams, MN_POTENTIAL, NFW_POTENTIAL
from odisseo.initial_condition import Plummer_sphere, ic_two_body, sample_position_on_sphere, inclined_circular_velocity, sample_position_on_circle, inclined_position
from odisseo.utils import center_of_mass
from odisseo.time_integration import time_integration
from odisseo.units import CodeUnits
from odisseo.visualization import create_3d_gif, create_projection_gif, energy_angular_momentum_plot
from odisseo.potentials import MyamotoNagai, NFW
from odisseo.option_classes import DIFFRAX_BACKEND, DOPRI5, TSIT5, SEMIIMPLICITEULER, LEAPFROGMIDPOINT, REVERSIBLEHEUN


# Global matplotlib font sizes, grouped by value: 20 for the base font and
# axis labels, 13 for tick labels, 15 for legends.
plt.rcParams.update({
    **dict.fromkeys(('font.size', 'axes.labelsize'), 20),
    **dict.fromkeys(('xtick.labelsize', 'ytick.labelsize'), 13),
    'legend.fontsize': 15,
})

In [None]:
from jax.scipy.integrate import trapezoid
from galpy.potential import MiyamotoNagaiPotential

# Functions to compute the enclosed mass of the Miyamoto-Nagai and NFW potentials
@partial(jit, static_argnames=("code_units",))
def mass_enclosed_MN(R, z, params, code_units):
    """
    Compute the mass of the Miyamoto-Nagai potential enclosed at radius R and height z.

    Parameters
    ----------
    R :
        Radius forwarded to galpy's ``mass``. The scale lengths ``a`` and ``b``
        are converted to kpc before the potential is built, so R presumably
        has to be given in kpc as well -- TODO confirm.
    z :
        Height forwarded to galpy's ``mass`` (same length unit as R).
    params :
        Simulation parameters; ``params.MN_params.a``, ``.b`` and ``.M`` are
        read and converted from code units to kpc / Msun.
    code_units :
        Unit system providing ``code_length`` and ``code_mass``. Declared
        static for ``jit`` (the original spelling ``"code_untis"`` did not
        match any argument of this function).

    Returns
    -------
    The value of ``mp.mass(R, z)`` rescaled from Msun to code mass units.

    NOTE(review): galpy is not a JAX library, so it presumably cannot consume
    the traced values that ``jit`` passes for R, z and params -- confirm that
    this function actually runs under ``jit``.
    """
    a = params.MN_params.a * code_units.code_length.to(u.kpc)
    b = params.MN_params.b * code_units.code_length.to(u.kpc)
    M = params.MN_params.M * code_units.code_mass.to(u.Msun)

    # Build the galpy potential with the disc mass as amplitude
    mp = MiyamotoNagaiPotential(amp=M, a=a, b=b)

    return mp.mass(R, z) * u.Msun.to(code_units.code_mass)
    
@jit
def mass_enclosed_NFW(R, params):
    """
    Mass of the NFW halo enclosed within radius R.

    Reads ``c``, ``Mvir`` and ``r_s`` from ``params.NFW_params``. The
    characteristic density is normalised so that the enclosed mass equals
    ``Mvir`` at ``R = c * r_s``. R must be expressed in the same length unit
    as ``r_s``; the result is in the unit of ``Mvir``.

    ref: wikipedia
    """
    halo = params.NFW_params
    concentration = halo.c
    virial_mass = halo.Mvir
    scale_radius = halo.r_s

    # Normalisation term ln(1 + c) - c / (1 + c) evaluated at the virial radius
    virial_norm = jnp.log(1+concentration) - concentration/(1+concentration)
    characteristic_density = (virial_mass / (4*jnp.pi * scale_radius**3)) * virial_norm**-1

    # Same dimensionless profile evaluated at the requested radius
    radial_profile = jnp.log(1 + R/scale_radius) - R/(scale_radius + R)

    return 4*jnp.pi*characteristic_density*scale_radius**3 * radial_profile

    

In [None]:
# Code unit system shared by the cells below: one code length is 10 kpc and
# one code mass is 1e4 Msun.
code_length = 10.0 * u.kpc
code_mass = 1e4 * u.Msun
# Value of the gravitational constant handed to CodeUnits (and later to
# SimulationParams in run_simulation).
G = 1 
code_units = CodeUnits(code_length, code_mass, G=G)


In [None]:
def run_simulation(key,
                   with_noise: bool = True,):
    """
    NumPyro model that samples the simulation parameters, places a Plummer
    sphere on an orbit inside the NFW + Miyamoto-Nagai potential and
    integrates it in time.

    Parameters
    ----------
    key :
        JAX PRNG key forwarded to ``Plummer_sphere`` to draw the particles.
    with_noise : bool
        If False, the final state is recorded as the deterministic site
        ``"x"``. The noisy branch is not implemented yet and records nothing.

    Returns
    -------
    final_state :
        Output of ``time_integration`` for the sampled parameters.

    Sample sites: ``t_end``, ``Mtot_plummer``, ``a_plummer``, ``M_MN``, ``a``,
    ``b``, ``M_NFW``, ``r_s``, ``e``, ``phi``, ``costheta``.
    """
    # Static configuration: this cannot be differentiated
    config = SimulationConfig(N_particles = 10_000,
                          return_snapshots = False,
                          num_timesteps = 1000,
                          external_accelerations=(NFW_POTENTIAL, MN_POTENTIAL,  ),
                          acceleration_scheme = DIRECT_ACC_MATRIX,
                          softening = (0.1 * u.kpc).to(code_units.code_length).value) #default values

    # Plain float conversion factors. Multiplying the sampled JAX values by a
    # float keeps them JAX arrays, whereas wrapping them in an astropy
    # Quantity would force a conversion to NumPy.
    myr_to_code = u.Myr.to(code_units.code_time)
    msun_to_code = u.Msun.to(code_units.code_mass)
    kpc_to_code = u.kpc.to(code_units.code_length)
    code_to_kpc = code_units.code_length.to(u.kpc)

    # simulation parameters to be sampled
    t_end = numpyro.sample("t_end", dist.Uniform(0.500, 10.0))
    Mtot_plummer = numpyro.sample("Mtot_plummer", dist.Uniform(1e3, 1e5))
    a_plummer = numpyro.sample("a_plummer", dist.Uniform(0.1, 2.0))
    Mtot_MN = numpyro.sample("M_MN", dist.Uniform(5e10, 1e11))
    a = numpyro.sample("a", dist.Uniform(1.0, 5.0))
    b = numpyro.sample("b", dist.Uniform(0.1, 1))
    Mtot_NFW = numpyro.sample("M_NFW", dist.Uniform(5e11, 1.5e12))
    r_s = numpyro.sample("r_s", dist.Uniform(1, 20.0))

    params = SimulationParams(t_end = t_end * myr_to_code,
                          Plummer_params= PlummerParams(Mtot=Mtot_plummer * msun_to_code,
                                                        a=a_plummer * kpc_to_code),
                          MN_params= MNParams(M=Mtot_MN * msun_to_code,
                                              a = a * kpc_to_code,
                                              b = b * kpc_to_code),
                          NFW_params= NFWParams(Mvir=Mtot_NFW * msun_to_code,
                                               r_s= r_s * kpc_to_code,
                                               c = 8.0),
                          G=G, )

    # initial conditions
    # set up the particles in the initial state
    positions, velocities, mass = Plummer_sphere(key=key, params=params, config=config)

    # put the Plummer sphere on an orbit around the halo: apocentre fixed at
    # 200 kpc, pericentre derived from the sampled eccentricity
    ra = 200 * kpc_to_code
    e = numpyro.sample("e", dist.Uniform(0.0, 0.7)) # nuisance parameter
    rp = (1-e)/(1+e) * ra

    # sample the position of the center of mass
    # Sample phi uniformly in [0, 2π]
    phi = numpyro.sample("phi", dist.Uniform(0, 2*pi)) # nuisance parameter

    # Sample cos(theta) uniformly in [-1, 1] to ensure uniform distribution on the sphere.
    # This must be a numpyro sample site; jax.random.uniform expects a PRNG key
    # as its first argument, not a site name.
    costheta = numpyro.sample("costheta", dist.Uniform(-1, 1)) # nuisance parameter
    theta = jnp.arccos(costheta)  # Convert to theta

    # Convert to Cartesian coordinates
    x = rp * jnp.sin(theta) * jnp.cos(phi)
    y = rp * jnp.sin(theta) * jnp.sin(phi)
    z = rp * jnp.cos(theta)

    # Shape (1, 3): a single centre-of-mass position that broadcasts over the
    # particle positions when added below.
    pos_com = jnp.stack([x, y, z], axis=-1).reshape((1, 3))

    # z / rp == cos(theta), so the inclination is simply pi/2 - theta
    inclination = jnp.pi/2 - theta

    # mass_enclosed_MN builds its potential with a and b in kpc, so it is given
    # R and z in kpc; mass_enclosed_NFW works directly on the code-unit params.
    mass1 = (mass_enclosed_MN(rp * code_to_kpc, z * code_to_kpc, params, code_units)
             + mass_enclosed_NFW(rp, params))
    mass2 = params.Plummer_params.Mtot
    _, bulk_velocity, _ = ic_two_body(mass1=mass1,
                                    mass2=mass2,
                                    rp=rp,
                                    e=e,
                                    params=params)
    bulk_velocity_modulus = bulk_velocity[1, 1].reshape((1))
    vel_com = inclined_circular_velocity(pos_com, bulk_velocity_modulus, inclination)

    # Add the center of mass position and velocity to the Plummer sphere particles
    positions = positions + pos_com
    velocities = velocities + vel_com

    # initialize the initial state
    initial_state = construct_initial_state(positions, velocities)

    # time integration
    final_state = time_integration(initial_state, params, config)

    if not with_noise:
        # Noise-free model: expose the final state as a deterministic site
        numpyro.deterministic("x", final_state)
    # TODO: the noisy observation model (with_noise=True) is not implemented yet

    return final_state



    




    



    


In [None]:
def to_observable(state):
    """
    Convert a simulation state into Sun-centred observables.

    Parameters
    ----------
    state :
        Array indexed as ``state[:, 0]`` for positions and ``state[:, 1]`` for
        velocities, both in code units (one row per particle, three Cartesian
        components each).

    Returns
    -------
    Array with one row per particle and the columns
    ``[r, b, l, v_r, v_b, v_l, v_r, mu_b, mu_l]``, where positions were scaled
    to kpc and velocities to km/s before the projection.

    NOTE(review): ``v_r`` appears twice in the output (columns 3 and 6). The
    layout is kept unchanged for existing consumers -- confirm whether the
    second occurrence is intended.
    """
    # Positions in kpc, then halo frame -> Sun frame -> (r, b, l)
    X = state[:, 0] * code_units.code_length.to(u.kpc)
    X_sun = odisseo.utils.halo_to_sun(X)
    X_gal = odisseo.utils.sun_to_gal(X_sun) #r, b, l
    r = X_gal[:, 0]
    b = X_gal[:, 1]
    l = X_gal[:, 2]

    # Velocities in km/s. The conversion factor comes from ``.to`` (a unit is
    # not callable), and the components are taken column-wise: unpacking the
    # (N, 3) array directly would iterate over particles instead.
    V = state[:, 1] * code_units.code_velocity.to(u.km/u.s)
    v_x = V[:, 0]
    v_y = V[:, 1]
    v_z = V[:, 2]

    # Project the Cartesian velocity onto the (l, b, r) directions.
    # presumably l and b are in radians -- TODO confirm against sun_to_gal
    v_l = -v_x * jnp.sin(l) + v_y * jnp.cos(l)
    v_b = -v_x * jnp.cos(l) * jnp.sin(b) - v_y * jnp.sin(l) * jnp.sin(b) + v_z * jnp.cos(b)
    v_r = v_x * jnp.cos(l) * jnp.cos(b) + v_y * jnp.sin(l) * jnp.cos(b) + v_z * jnp.sin(b)

    # Proper motions: v [km/s] = 4.74047 * mu [mas/yr] * r [kpc]
    mu_l = v_l / (4.74047 * r)
    mu_b = v_b / (4.74047 * r)

    return jnp.stack([r, b, l, v_r, v_b, v_l, v_r, mu_b, mu_l], axis=1)



    

