In [1]:
import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # Use only the first GPU
from typing import Optional, Tuple, Callable, Union, List
from functools import partial

import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
from jax import vmap, jit
from jax import random
jax.config.update("jax_enable_x64", True)

import numpy as np
from astropy import units as u
from astropy import constants as c

import jdgsim
from jdgsim import construct_initial_state
from jdgsim.integrators import leapfrog
from jdgsim.dynamics import direct_acc, DIRECT_ACC, DIRECT_ACC_LAXMAP
from jdgsim.option_classes import SimulationConfig, SimulationParams, NFWParams
from jdgsim.initial_condition import Plummer_sphere, ic_two_body
from jdgsim.utils import center_of_mass, E_tot, Angular_momentum
from jdgsim.time_integration import time_integration
from jdgsim.units import CodeUnits

# Notebook-wide matplotlib style: 15 pt for base text, axis labels and legends,
# smaller 10 pt tick labels.
plt.rcParams.update({
    **{key: 15 for key in ('font.size', 'axes.labelsize')},
    **{key: 10 for key in ('xtick.labelsize', 'ytick.labelsize')},
    'legend.fontsize': 15,
})

# Two-body problem

In [2]:
# Internal (code) unit system: one length unit is 0.1 kpc, one mass unit is
# 1 Msun, and the gravitational constant is set to 1. The matching time unit is
# exposed by CodeUnits as `code_units.code_time`.
code_length, code_mass = 0.10 * u.kpc, 1 * u.Msun
G = 1
code_units = CodeUnits(code_length, code_mass, G=G)

In [7]:
# Define the simulation configuration (N-body numerics) and the physical
# parameters for the two-body run. Fields not passed here keep the
# SimulationConfig / SimulationParams defaults (e.g. dimensions, integrator,
# batch_size), as shown in the printed reprs.
config = SimulationConfig(N_particles=2,
                          return_snapshots=False,
                          num_snapshots=100,
                          num_timesteps=10,
                          acceleration_scheme=DIRECT_ACC,
                          double_map=True,
                          external_accelerations=(),
                          softening=1e-3)

# Integrate for 0.1 Gyr, expressed in code time units. Reuse the same G that
# defined `code_units` so the dynamics and the unit system cannot drift apart.
params = SimulationParams(t_end=(0.1*u.Gyr).to(code_units.code_time).value,
                          G=G)

print(config)
print(params)

# Target run: a 100 Msun primary and a 1 Msun companion with e=0 and
# rp = 0.01 kpc (presumably the pericentre distance), all in code units.
mass1 = (100*u.Msun).to(code_units.code_mass).value
mass2 = (1*u.Msun).to(code_units.code_mass).value
rp = (0.01*u.kpc).to(code_units.code_length).value
pos, vel, mass = ic_two_body(mass1,
                             mass2,
                             rp=rp,
                             e=0.,
                             config=config,
                             params=params)
# NOTE(review): the mass array returned by ic_two_body is replaced by an
# explicit [mass1, mass2] array here — confirm the two always agree.
mass = jnp.array([mass1, mass2])
initial_state = construct_initial_state(pos, vel)
target_state = time_integration(initial_state, mass, config, params)

# Reference observables that the loss defined below is compared against:
# the total energy and component 2 (z) of Angular_momentum.
energy_target = E_tot(target_state, mass, config, params)
Lz_target = Angular_momentum(target_state, mass)[2]
print(f"log10(- Total Energy): {jnp.log10(-energy_target)}, Lz: {Lz_target}")



def time_integration_for_mass_grad(
    big_mass,
    *,
    small_mass=(1*u.Msun).to(code_units.code_mass).value,
    rp=(0.01*u.kpc).to(code_units.code_length).value,
    e=0.,
):
    """Loss between a two-body run with primary mass ``big_mass`` and the target run.

    Builds two-body initial conditions with ``ic_two_body``, integrates them
    with ``time_integration`` for 0.1 Gyr, and compares the final total energy
    and ``Angular_momentum(...)[2]`` (Lz) with the module-level
    ``energy_target`` and ``Lz_target``. The function is traced by
    ``jax.value_and_grad`` to differentiate the loss w.r.t. ``big_mass``.

    Parameters
    ----------
    big_mass : float or JAX scalar
        Mass of the first (primary) body, in code mass units.
    small_mass : float, keyword-only
        Mass of the companion in code mass units. The default is 1 Msun, as in
        the target run.
    rp : float, keyword-only
        ``rp`` passed to ``ic_two_body``, in code length units. The default is
        0.01 kpc, as in the target run.
    e : float, keyword-only
        ``e`` passed to ``ic_two_body`` (presumably the orbital eccentricity).
        The default is 0, as in the target run.

    Returns
    -------
    JAX scalar
        ``(|E - energy_target| + |Lz - Lz_target|) / 2``.

    Notes
    -----
    Reads the globals ``config``, ``code_units``, ``energy_target`` and
    ``Lz_target``.
    """
    # Same physical parameters as the target run, so only the masses/orbit differ.
    params = SimulationParams(t_end=(0.1*u.Gyr).to(code_units.code_time).value,
                              G=1)

    # The mass array returned by ic_two_body is not used; an explicit
    # [big_mass, small_mass] array is built instead, exactly as in the target run.
    pos, vel, _ = ic_two_body(big_mass,
                              small_mass,
                              rp=rp,
                              e=e,
                              config=config,
                              params=params)
    mass = jnp.array([big_mass, small_mass])
    initial_state = construct_initial_state(pos, vel)
    final_state = time_integration(initial_state, mass, config, params)

    energy = E_tot(final_state, mass, config, params)
    Lz = Angular_momentum(final_state, mass)[2]
    # NOTE(review): energy (|E| ~ 1e3 in the target run) and Lz (~3) enter
    # unweighted, so the energy term dominates the loss; consider normalising
    # each term by its target value.
    loss = (jnp.abs(energy - energy_target) + jnp.abs(Lz - Lz_target)) / 2
    return loss


# Evaluate the loss and its gradient w.r.t. the primary (big) body's mass at an
# initial guess of 50 Msun (the target run used 100 Msun).
Mtot = (50 * u.Msun).to(code_units.code_mass).value
loss, grad = jax.value_and_grad(time_integration_for_mass_grad)(Mtot)
print("Gradient of the loss w.r.t. the primary mass:\n", grad)
print("Loss:\n", loss)


SimulationConfig(N_particles=2, dimensions=3, return_snapshots=False, num_snapshots=100, fixed_timestep=True, num_timesteps=10, softening=0.001, integrator=0, acceleration_scheme=0, batch_size=10000, double_map=True, external_accelerations=())
SimulationParams(G=1, t_end=0.006707087409203456, Plummer_params=PlummerParams(a=<Quantity 7. kpc>, Mtot=<Quantity 1. solMass>), NFW_params=NFWParams(Mvir=<Quantity 1.62e+11 solMass>, r_s=<Quantity 15.3 kpc>, c=10, d_c=1.4888043637074615), PointMass_params=PointMassParams(M=<Quantity 1. solMass>), MN_params=MNParams(M=<Quantity 6.5e+10 solMass>, a=<Quantity 3. kpc>, b=<Quantity 0.28 kpc>))
lo10(- Total Energy): 3.1654009278422692, Lz: 3.1465838776377626
Gradient of the total mass of the Plummer sphere:
 -7.339575754297227
Loss:
 363.5115742393411
