In [1]:
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '2' 
from autocvd import autocvd
autocvd(num_gpus = 1)

import jax 
import jax.numpy as jnp
from jax import jit, random
import equinox as eqx
from jax.sharding import Mesh, PartitionSpec, NamedSharding

# jax.config.update("jax_enable_x64", True)
import matplotlib.pyplot as plt

from functools import partial

import numpy as np
from astropy import units as u
from astropy import constants as c

import odisseo
from odisseo import construct_initial_state
from odisseo.integrators import leapfrog
from odisseo.dynamics import direct_acc, DIRECT_ACC, DIRECT_ACC_LAXMAP, DIRECT_ACC_FOR_LOOP, DIRECT_ACC_MATRIX, NO_SELF_GRAVITY
from odisseo.option_classes import SimulationConfig, SimulationParams, MNParams, NFWParams, PlummerParams, PSPParams, MN_POTENTIAL, NFW_POTENTIAL, PSP_POTENTIAL, DIFFRAX_BACKEND, LEAPFROG
from odisseo.option_classes import SEMIIMPLICITEULER, TSIT5
from odisseo.initial_condition import Plummer_sphere, Plummer_sphere_reparam
from odisseo.utils import center_of_mass
from odisseo.time_integration import time_integration
from odisseo.units import CodeUnits
from odisseo.visualization import create_3d_gif, create_projection_gif, energy_angular_momentum_plot
from odisseo.potentials import MyamotoNagai, NFW

from odisseo.utils import halo_to_gd1_velocity_vmap, halo_to_gd1_vmap, projection_on_GD1
from jax.test_util import check_grads
from numpyro.infer import MCMC, NUTS, AIES
import arviz as az

plt.rcParams.update({
    'font.size': 20,
    'axes.labelsize': 20,
    'xtick.labelsize': 13,
    'ytick.labelsize': 13,
    'legend.fontsize': 15,
})

plt.style.use('default')


import jax.numpy as jnp
import jax
from jax import jit
import pandas as pd
from tqdm import tqdm

code_length = 10 * u.kpc
code_mass = 1e4 * u.Msun
G = 1
code_time = 3 * u.Gyr
code_units = CodeUnits(code_length, code_mass, G=1, unit_time = code_time )  


config = SimulationConfig(N_particles = 1000, 
                          return_snapshots = True, 
                          num_snapshots = 1000, 
                          num_timesteps = 1000, 
                          external_accelerations=(NFW_POTENTIAL, MN_POTENTIAL, PSP_POTENTIAL), 
                          acceleration_scheme = DIRECT_ACC_MATRIX,
                          softening = (0.1 * u.pc).to(code_units.code_length).value,
                          integrator = DIFFRAX_BACKEND,
                          differentation_mode=TSIT5,
                          fixed_timestep=False,
                          ) #default values

params = SimulationParams(t_end = (3 * u.Gyr).to(code_units.code_time).value,  
                          Plummer_params= PlummerParams(Mtot=(10**4.5 * u.Msun).to(code_units.code_mass).value,
                                                        a=(8 * u.pc).to(code_units.code_length).value),
                           MN_params= MNParams(M = (68_193_902_782.346756 * u.Msun).to(code_units.code_mass).value,
                                              a = (3.0 * u.kpc).to(code_units.code_length).value,
                                              b = (0.280 * u.kpc).to(code_units.code_length).value),
                          NFW_params= NFWParams(Mvir=(4.3683325e11 * u.Msun).to(code_units.code_mass).value,
                                               r_s= (16.0 * u.kpc).to(code_units.code_length).value,),      
                          PSP_params= PSPParams(M = 4501365375.06545 * u.Msun.to(code_units.code_mass),
                                                alpha = 1.8, 
                                                r_c = (1.9*u.kpc).to(code_units.code_length).value),                    
                          G=code_units.G, ) 


key = random.PRNGKey(1)

#set up the particles in the initial state
positions, velocities, mass = Plummer_sphere(key=key, params=params, config=config)

#the center of mass needs to be integrated backwards in time first 
config_com = config._replace(N_particles=1,)
params_com = params._replace(t_end=-params.t_end,)

#this is the final position of the cluster, we need to integrate backwards in time 
pos_com_final = jnp.array([[11.8, 0.79, 6.4]]) * u.kpc.to(code_units.code_length)
vel_com_final = jnp.array([[109.5,-254.5,-90.3]]) * (u.km/u.s).to(code_units.code_velocity)


mass_com = jnp.array([params_com.Plummer_params.Mtot])
final_state_com = construct_initial_state(pos_com_final, vel_com_final)

snapshots_com = time_integration(final_state_com, mass_com, config_com, params_com)
pos_com, vel_com = snapshots_com.states[-1, :, 0], snapshots_com.states[-1, :, 1]


# Add the center of mass position and velocity to the Plummer sphere particles
positions = positions + pos_com
velocities = velocities + vel_com

#initialize the initial state
initial_state_stream = construct_initial_state(positions, velocities)

#run the simulation
snapshots = time_integration(initial_state_stream, mass, config, params)

final_state = snapshots.states[-1]
stream_data = projection_on_GD1(final_state, code_units=code_units,)

params_sim = params


# ----------------------------- Load observation & precompute target densities ----------------
true_GD1_observation_path = '/export/data/vgiusepp/odisseo_data/data_fix_position/true.npz'
_obs = np.load(true_GD1_observation_path)
stream_data = jnp.array(_obs['x']).reshape(1000, 6)  # will be used only to compute target densities
true_theta = jnp.array(_obs['theta'])


@jit
def run_simulation(params):

    #Final position and velocity of the center of mass
    pos_com_final = jnp.array([[11.8, 0.79, 6.4]]) * u.kpc.to(code_units.code_length)
    vel_com_final = jnp.array([[109.5,-254.5,-90.3]]) * (u.km/u.s).to(code_units.code_velocity)
    
    #we construmt the initial state of the com 
    initial_state_com = construct_initial_state(pos_com_final, vel_com_final,)

    #function that integrates the com backwards and forwards in time and then the stream, and projects it on the sky
    @jit
    def assign_params_integrate_projection(t_end):
        new_params = params_sim._replace(
                        NFW_params=params_sim.NFW_params._replace(
                            Mvir = params['M_NFW']*u.Msun.to(code_units.code_mass),
                            # r_s = params['r_s']*u.kpc.to(code_units.code_length),xs
                        ),
                        MN_params=params_sim.MN_params._replace(
                            M = params['M_MN']*u.Msun.to(code_units.code_mass),
                            # a = params['a_MN']*u.kpc.to(code_units.code_length),
                        ),
                        t_end = t_end,)
        snapshots = time_integration(initial_state_com, mass, config=config_com, params=new_params)
        stream_coordinate = jax.vmap(projection_on_GD1, in_axes=(0, None))(snapshots.states, code_units)
        return stream_coordinate

    t_end_mag = 0.2 * u.Gyr.to(code_units.code_time)
    t_end_array = jnp.array([-t_end_mag, t_end_mag])  # backward, forward
    
    # vmap over both parameters
    stream_coordinate_com = jax.vmap(assign_params_integrate_projection)(t_end_array)

    return stream_coordinate_com


@jit
def stream_loglikelihood(stream_coordinate_com, ):
    stream_coordinate_com_backward, stream_coordinate_com_forward = stream_coordinate_com[0], stream_coordinate_com[1]
    phi1_min, phi1_max = -120, 70
    phi2_min, phi2_max = -8, 2

    # Create masks for valid time steps
    mask_window_backward = (stream_coordinate_com_backward[:, 0, 1] < phi1_max) & \
                          (stream_coordinate_com_backward[:, 0, 1] > phi1_min) & \
                          (stream_coordinate_com_backward[:, 0, 2] < phi2_max) & \
                          (stream_coordinate_com_backward[:, 0, 2] > phi2_min)
    
    mask_diff_backward = jnp.ediff1d(stream_coordinate_com_backward[:, 0, 1], to_begin=1) > 0
    
    mask_window_forward = (stream_coordinate_com_forward[:, 0, 1] < phi1_max) & \
                         (stream_coordinate_com_forward[:, 0, 1] > phi1_min) & \
                         (stream_coordinate_com_forward[:, 0, 2] < phi2_max) & \
                         (stream_coordinate_com_forward[:, 0, 2] > phi2_min)
    
    mask_diff_forward = jnp.ediff1d(stream_coordinate_com_forward[:, 0, 1], to_begin=-1) < 0

    # Combined time step masks
    valid_time_backward = mask_window_backward & mask_diff_backward
    valid_time_forward = mask_window_forward & mask_diff_forward

    # Create masked coordinates for interpolation (only valid time steps)
    phi1_backward_valid = jnp.where(valid_time_backward, 
                                   stream_coordinate_com_backward[:, 0, 1], 
                                   100000.)
    
    
    phi1_forward_valid = jnp.where(valid_time_forward, 
                                  stream_coordinate_com_forward[:, 0, 1], 
                                  100000.)
    

    # Stream data masks - which data points to use for each direction
    mask_stream_backward = stream_data[:, 1] > stream_coordinate_com_backward[0, 0, 1]
    mask_stream_forward = stream_data[:, 1] < stream_coordinate_com_forward[0, 0, 1]

    
    coord_indices=jnp.array([2, 3, 4, 5])

    def interpolate_coord_backward(coord_idx):

        coord_backward_valid = jnp.where(valid_time_backward, 
                                   stream_coordinate_com_backward[:, 0, coord_idx], 
                                   100000.0)

        return jnp.interp(
            stream_data[:, 1], 
            phi1_backward_valid, 
            coord_backward_valid
        )
    
    def interpolate_coord_forward(coord_idx):

        coord_forward_valid = jnp.where(valid_time_forward, 
                                   stream_coordinate_com_forward[:, 0, coord_idx], 
                                   100000.0)

        return jnp.interp(
            stream_data[:, 1], 
            phi1_forward_valid, 
            coord_forward_valid
        )

    # Apply interpolation to all coordinates
    interp_tracks_backward = jax.vmap(interpolate_coord_backward)(coord_indices)  # Shape: (n_coords, n_data)
    interp_tracks_forward = jax.vmap(interpolate_coord_forward)(coord_indices)  # Shape: (n_coords, n_data)

    # Calculate residuals for all coordinates
    data_coords = stream_data[:, coord_indices].T  # Shape: (n_coords, n_data)

    #
    sigma = jnp.array([0.5, 10., 2., 2. ]) 
 
    # Calculate chi2 using only the appropriate data points for each direction
    residuals_backward = jnp.where(mask_stream_backward, 
                                  (data_coords - interp_tracks_backward)/sigma[:, None],
                                  0.0)
    residuals_forward = jnp.where(mask_stream_forward, 
                                 (data_coords - interp_tracks_forward)/sigma[:, None],
                                 0.0)
    
    chi2_backward = jnp.sum(residuals_backward**2) 
    chi2_forward = jnp.sum(residuals_forward**2) 
    
    # Use only backward for now (as in your original code)
    chi2 = chi2_backward + chi2_forward

    n_valid = jnp.sum(mask_stream_backward) + jnp.sum(mask_stream_forward)
    log_norm = - 0.5*n_valid * jnp.sum(jnp.log(2 * jnp.pi * sigma**2))
    log_likelihood = -0.5 * chi2 + log_norm

    return log_likelihood
    
 


# --- keep your existing imports; add these ---
import numpy as np
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
from numpyro.infer.initialization import init_to_value, init_to_median
from numpyro.handlers import reparam
from jax import random

# --- ASSUMPTION: run_simulation(params_dict) returns the stream coordinates for given params
# --- and stream_loglikelihood(stream_coordinate_com) returns the log-likelihood (float)

# Wrap run_simulation and stream_loglikelihood in jitted functions (you already have @jit in your code).
# Here we assume run_simulation expects a dict with keys 'M_NFW','r_s','M_MN','a_MN' (floats).

# --- NUMPYRO MODEL ---
def numpyro_stream_model():
    """
    Model parameterization:
      - sample **log** masses / radii so variables are roughly on similar additive scales.
      - this is a simple, effective reparameterization to improve NUTS geometry.
    """
    # Priors in log-space (you can widen/narrow the stddev as you prefer)
    # log_M_NFW = numpyro.sample("log_M_NFW", dist.Normal(jnp.log(4.3683325e11), 1.5))
    # log_M_MN  = numpyro.sample("log_M_MN",  dist.Normal(jnp.log(6.8193902782e10), 1.5))  # (example)
    # log_r_s   = numpyro.sample("log_r_s",   dist.Normal(jnp.log(16.0), 0.6))
    # log_a_MN  = numpyro.sample("log_a_MN",  dist.Normal(jnp.log(3.0), 0.6))

    # Deterministic (expose transformed parameters for diagnostics/traces)
    # M_NFW = numpyro.deterministic("M_NFW", jnp.exp(log_M_NFW))
    # M_MN  = numpyro.deterministic("M_MN",  jnp.exp(log_M_MN))
    # r_s   = numpyro.deterministic("r_s",   jnp.exp(log_r_s))
    # a_MN  = numpyro.deterministic("a_MN",  jnp.exp(log_a_MN))

    # M_NFW = numpyro.sample("M_NFW", dist.Uniform(4.3683325e11 * 0.25, 4.3683325e11 * 2.0))

    # M_MN  = numpyro.sample("M_MN",  dist.Uniform(68_193_902_782.346756 * 0.25, 68_193_902_782.346756 * 2.0))  # (example)
    # LogUniform distribution (recommended for masses)
    log_M_NFW = numpyro.sample("log_M_NFW", dist.Uniform(jnp.log(4.3683325e11 * 0.5), jnp.log(4.3683325e11 * 2.0)))
    log_M_MN = numpyro.sample("log_M_MN", dist.Uniform(jnp.log(68_193_902_782.346756 * 0.5), jnp.log(68_193_902_782.346756 * 2.0)))

    # Transform back to physical units
    M_NFW = numpyro.deterministic("M_NFW", jnp.exp(log_M_NFW))
    M_MN = numpyro.deterministic("M_MN", jnp.exp(log_M_MN))

    params_dict = {
        "M_NFW": M_NFW,
        # "r_s": r_s,
        "M_MN": M_MN,
        # "a_MN": a_MN,
    }


    stream_coordinate_com = run_simulation(params_dict)   # returns shape (2, n_snapshots, ncols) as before

    # Register the likelihood with NumPyro
    # numpyro.factor adds an arbitrary log-probability term to the joint.
    # numpyro.factor("sim_loglik", log_like)
    phi1_min, phi1_max = -90, 10
    phi2_min, phi2_max = -8, 2

    stream_coordinate_com_backward, stream_coordinate_com_forward = stream_coordinate_com[0], stream_coordinate_com[1]
    

    # Create masks for valid time steps
    mask_window_backward = (stream_coordinate_com_backward[:, 0, 1] < phi1_max) & \
                          (stream_coordinate_com_backward[:, 0, 1] > phi1_min) & \
                          (stream_coordinate_com_backward[:, 0, 2] < phi2_max) & \
                          (stream_coordinate_com_backward[:, 0, 2] > phi2_min)
    
    mask_diff_backward = jnp.ediff1d(stream_coordinate_com_backward[:, 0, 1], to_begin=1) > 0
    # New mask - True until first False appears
    mask_diff_backward = jnp.cumprod(mask_diff_backward, dtype=bool)


    mask_window_forward = (stream_coordinate_com_forward[:, 0, 1] < phi1_max) & \
                         (stream_coordinate_com_forward[:, 0, 1] > phi1_min) & \
                         (stream_coordinate_com_forward[:, 0, 2] < phi2_max) & \
                         (stream_coordinate_com_forward[:, 0, 2] > phi2_min)
    
    mask_diff_forward = jnp.ediff1d(stream_coordinate_com_forward[:, 0, 1], to_begin=-1) < 0
    mask_diff_forward = jnp.cumprod(mask_diff_forward, dtype=bool)


    # Combined time step masks
    valid_time_backward = mask_window_backward & mask_diff_backward
    valid_time_forward = mask_window_forward & mask_diff_forward

    # Create masked coordinates for interpolation (only valid time steps)
    phi1_backward_valid = jnp.where(valid_time_backward, 
                                   stream_coordinate_com_backward[:, 0, 1], 
                                   10000.)
    
    
    phi1_forward_valid = jnp.where(valid_time_forward, 
                                  stream_coordinate_com_forward[:, 0, 1], 
                                  -10000.)
    

    # Stream data masks - which data points to use for each direction
    mask_stream_backward = stream_data[:, 1] > stream_coordinate_com_backward[0, 0, 1]
    mask_stream_forward = stream_data[:, 1] < stream_coordinate_com_forward[0, 0, 1]

    mask_evaluate_inside_track_backward = (stream_data[:, 1] < jnp.max(phi1_backward_valid)) & (stream_data[:, 1] < phi1_max)
    mask_evaluate_inside_track_forward = (stream_data[:, 1] > jnp.min(phi1_forward_valid)) & (stream_data[:, 1] > phi1_min)

    def interpolate_coord_backward(coord_idx):

        coord_backward_valid = jnp.where(valid_time_backward, 
                                   stream_coordinate_com_backward[:, 0, coord_idx], 
                                   -100000.0)

        return jnp.interp(
            jnp.where(mask_stream_backward & mask_evaluate_inside_track_backward, stream_data[:, 1], 100000.0), 
            phi1_backward_valid, 
            coord_backward_valid
        )
    
    def interpolate_coord_forward(coord_idx):

        coord_forward_valid = jnp.where(valid_time_forward, 
                                   stream_coordinate_com_forward[:, 0, coord_idx], 
                                   100000.0)

        return jnp.interp(
            -jnp.where(mask_stream_forward & mask_evaluate_inside_track_forward, stream_data[:, 1], -100000.0), 
            -phi1_forward_valid, 
            coord_forward_valid
        )
        
    coord_indices=jnp.array([2, 3, 4, 5])

    # Apply interpolation to all coordinates
    interp_tracks_backward = jax.vmap(interpolate_coord_backward)(coord_indices)  # Shape: (n_coords, n_data)
    interp_tracks_forward = jax.vmap(interpolate_coord_forward)(coord_indices)  # Shape: (n_coords, n_data)

    # Calculate residuals for all coordinates
    data_coords = stream_data[:, coord_indices].T  # Shape: (n_coords, n_data)
    sigma = jnp.array([0.5, 10., 2., 2. ])

    
    mask_correct_interpolation_backward = phi1_backward_valid < 8
    mask_correct_interpolation_forward = phi1_forward_valid > - 88

    with numpyro.plate("N", data_coords.shape[1]):
        masked_dist_backward = dist.Normal(interp_tracks_backward, sigma[:, None], ).mask(mask_stream_backward & mask_evaluate_inside_track_backward & mask_correct_interpolation_backward)
        masked_dist_forward = dist.Normal(interp_tracks_forward, sigma[:, None], ).mask(mask_stream_forward & mask_evaluate_inside_track_forward & mask_correct_interpolation_forward)
        numpyro.sample('obs_backward', masked_dist_backward, obs=data_coords,)
        numpyro.sample('obs_forward', masked_dist_forward, obs=data_coords,) 
   

# --- RUN AIES / MCMC ---
rng_key = random.PRNGKey(42)

# NUTS kernel: try dense_mass=True for complex geometry; set target_accept higher if needed
kernel = AIES(numpyro_stream_model, )  # try True if needed
mcmc = MCMC(kernel, num_warmup=100, num_samples=5000, num_chains=100, chain_method='vectorized', progress_bar=True, jit_model_args=True,)

# (Optional) choose a good init strategy:
# - init_to_median() is a reasonable generic choice if the prior is informative
# - init_to_value({'log_M_NFW': jnp.log(4.368e11), ...}) can be used to start near truth
# for demonstration we run with default init; you can pass init_strategy to NUTS()
mcmc.run(rng_key, )
# mcmc.print_summary(exclude_deterministic=False)  # include deterministic fields (M_NFW, r_s, ...)

#saving
# numpyro_data = az.from_numpyro(mcmc, extra_fields="potential_energy")
# numpyro_data.to_json('./aies')

  mcmc.run(rng_key, )
  mcmc.run(rng_key, )
sample: 100%|██████████| 5100/5100 [3:44:36<00:00,  2.64s/it, acc. prob=0.23]  


In [2]:
mcmc.print_summary(exclude_deterministic=False)  # include deterministic fields (M_NFW, r_s, ...)


                 mean       std    median      5.0%     95.0%     n_eff     r_hat
       M_MN 97683415040.00 12486968320.00 98729828352.00 74401325056.00 116830240768.00  21807.54      1.00
      M_NFW 289782267904.00 56222973952.00 277026701312.00 218418380800.00 369257447424.00  24775.69      1.00
   log_M_MN     25.30      0.13     25.32     25.09     25.53  21958.71      1.00
  log_M_NFW     26.38      0.18     26.35     26.11     26.63  24172.18      1.00



<numpyro.infer.mcmc.MCMC at 0x7f17d8592e70>

In [5]:
AIES_data = az.from_numpyro(mcmc, )

2025-10-08 02:14:11.269642: E external/xla/xla/service/gpu/gpu_hlo_schedule.cc:795] The byte size of input/output arguments (16004024048) exceeds the base limit (8654290944). This indicates an error in the calculation!
2025-10-08 02:14:11.325524: W external/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization.cc:3023] Can't reduce memory use below 0B (0 bytes) by rematerialization; only reduced to 14.90GiB (16000240000 bytes), down from 14.90GiB (16000244144 bytes) originally
2025-10-08 02:14:23.239086: W external/xla/xla/tsl/framework/bfc_allocator.cc:501] Allocator (GPU_0_bfc) ran out of memory trying to allocate 7.45GiB (rounded to 8000000000)requested by op 
If the cause is memory fragmentation maybe the environment variable 'TF_GPU_ALLOCATOR=cuda_malloc_async' will improve the situation. 
Current allocation summary follows.
Current allocation summary follows.
2025-10-08 02:14:23.239203: W external/xla/xla/tsl/framework/bfc_allocator.cc:512] ************************************

XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 8000000000 bytes.