In [1]:
import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # Use only the first GPU
from typing import Optional, Tuple, Callable, Union, List
from functools import partial

import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
from jax import vmap, jit
from jax import random

import jdgsim
from jdgsim import construct_initial_state
from jdgsim.potentials import NFW
from jdgsim.integrators import leapfrog
from jdgsim.dynamics import direct_acc
from jdgsim.option_classes import SimulationConfig, SimulationParams
from jdgsim.initial_condition import Plummer_sphere
from jdgsim.utils import center_of_mass
from jdgsim.time_integration import time_integration

In [2]:
# Build the simulation configuration and the physical parameters, both from
# the library defaults, and echo them so the settings used for this run are
# recorded next to the results (see the printed values below).
config, params = SimulationConfig(), SimulationParams()

print(config, params, sep="\n")

SimulationConfig(N_particles=1000, dimensions=3, return_snapshots=False, numb_snapshots=10, fixed_timestep=True)
SimulationParams(G=4.498e-06, t_end=1.0)


In [3]:
# Sample particle positions, velocities and masses from a Plummer sphere.
# The PRNG key is fixed (seed 0) so the initial conditions are reproducible.
position, velocity, mass = Plummer_sphere(
    key=random.PRNGKey(0),
    params=params,
    config=config,
)

# Shift the whole cluster 10 kpc along +x, away from the galactic centre.
# JAX arrays are immutable, so `+=` rebinds `position` to a new array.
position += jnp.array([10, 0, 0])

# Combine positions and velocities into the state array consumed by the
# integrator.
initial_state = construct_initial_state(position, velocity)

# Centre of mass of the initial configuration.
com = center_of_mass(initial_state, mass)

In [4]:
# Evolve the initial state with the settings in `config` / `params`.
# NOTE(review): with `return_snapshots=False` (the default printed above),
# `final_state` is presumably only the end state, not a snapshot history — confirm.
final_state = time_integration(
    initial_state,
    mass,
    config,
    params,
)



In [None]:
# Compare the cluster before and after the integration: initial particle
# positions as small points, final positions as faint markers.
# Uses only variables created by the cells above, so the cell works after
# Restart Kernel -> Run All.
# NOTE(review): assumes the state array is laid out as (N_particles, 2, dim)
# with index 0 = positions, as the saved state output suggests — confirm
# against `construct_initial_state`.
initial_positions = position
final_positions = final_state[:, 0, :]

fig, ax = plt.subplots(subplot_kw={"projection": "3d"})
ax.scatter(
    initial_positions[:, 0],
    initial_positions[:, 1],
    initial_positions[:, 2],
    s=1,
    label="initial",
)
ax.scatter(
    final_positions[:, 0],
    final_positions[:, 1],
    final_positions[:, 2],
    alpha=0.1,
    label="final",
)
# Length unit: kpc, per the 10 kpc offset applied when building the initial state.
ax.set_xlabel("X [kpc]")
ax.set_ylabel("Y [kpc]")
ax.set_zlabel("Z [kpc]")
ax.set_title("Plummer sphere: initial vs. final particle positions")
ax.legend()
plt.show()

Array([[[-7.96182156e+00, -1.69075832e-01, -9.58577156e-01],
        [ 2.11125679e+01,  4.46847856e-01,  2.54345846e+00]],

       [[-8.69143486e+00, -1.22750308e-02,  3.45935896e-02],
        [ 1.90425797e+01,  2.75179297e-02, -8.01297650e-02]],

       [[-7.76151800e+00,  2.57508177e-02, -3.42531234e-01],
        [ 2.22194881e+01, -7.33626783e-02,  9.84624028e-01]],

       ...,

       [[ 7.46357083e-01,  4.60744292e-01, -4.01187152e-01],
        [-2.38797626e+01, -1.47444677e+01,  1.28366585e+01]],

       [[-6.98379850e+00, -4.64973748e-01, -4.02199209e-01],
        [ 2.48050671e+01,  1.65051675e+00,  1.42812657e+00]],

       [[-1.23370733e+01,  1.41978419e+00, -1.13456905e+00],
        [ 3.20607495e+00, -3.68995368e-01,  2.93717355e-01]]],      dtype=float32)