In [1]:
import os
from math import pi

# Optional: emulate several host (CPU) devices to try the sharding code without GPUs.
# os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=2'

# Expose only physical GPUs 1 and 2 to this process. This is set ahead of the
# `import jax` below so the restriction is in place when JAX looks for devices.
# The value is a plain comma-separated list with no embedded whitespace.
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2"
from typing import Optional, Tuple, Callable, Union, List
from functools import partial

import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
from jax import vmap, jit, pmap
from jax import random

# jax.config.update("jax_enable_x64", True)

import numpy as np
from astropy import units as u
from astropy import constants as c

import odisseo
from odisseo import construct_initial_state
from odisseo.integrators import leapfrog
from odisseo.dynamics import direct_acc, DIRECT_ACC, DIRECT_ACC_LAXMAP, DIRECT_ACC_FOR_LOOP, DIRECT_ACC_MATRIX
from odisseo.option_classes import SimulationConfig, SimulationParams, MNParams, NFWParams, PlummerParams, MN_POTENTIAL, NFW_POTENTIAL
from odisseo.initial_condition import Plummer_sphere, ic_two_body, sample_position_on_sphere, inclined_circular_velocity
from odisseo.utils import center_of_mass
from odisseo.time_integration import time_integration
from odisseo.units import CodeUnits
from odisseo.visualization import create_3d_gif, create_projection_gif, energy_angular_momentum_plot
from odisseo.potentials import MyamotoNagai, NFW



# Global matplotlib text sizes, set group by group:
# font.size, axes.labelsize, {x,y}tick.labelsize and legend.fontsize.
plt.rc('font', size=20)
plt.rc('axes', labelsize=20)
plt.rc('xtick', labelsize=13)
plt.rc('ytick', labelsize=13)
plt.rc('legend', fontsize=15)

In [2]:
# Display the devices JAX can see (the recorded output lists two CUDA devices).
jax.devices()

[CudaDevice(id=0), CudaDevice(id=1)]

In [3]:
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental import shard_map
from jax.sharding import NamedSharding

# One-dimensional device mesh over every device JAX reports; its single axis
# is named 'N_particles'.
devices = jax.devices()
mesh = Mesh(devices, ('N_particles',))
# Sharding that partitions an array's leading dimension along the 'N_particles' axis.
sharding = NamedSharding(mesh=mesh, spec=P('N_particles'))

# NOTE(review): the original author observed that combining `jit` with
# `device_put` can give surprising placement; confirm the actual placement of
# the result with `.devices()` / `jax.debug.visualize_array_sharding`.
@partial(jit, static_argnames=('sharding',))
def shard_positions_with_shard_map(positions, sharding):
    """
    Place ``positions`` according to ``sharding`` using ``jax.device_put``.

    Despite its name, this helper does not call ``shard_map``; it only forwards
    the array and the requested sharding to ``jax.device_put`` from inside a
    jitted function.

    Args:
        positions: Array to be distributed (e.g. shape (n, d)).
        sharding: Sharding passed to ``jax.device_put``. It is declared as a
            static argument of ``jit``, so it must be hashable, and a new
            sharding value triggers a new compilation.

    Returns:
        The array returned by ``jax.device_put(positions, sharding)``. Only the
        array is returned, not the mesh or the sharding.
    """
    return jax.device_put(positions, sharding)


In [4]:
# Base length and mass scales, plus the value of G, handed to CodeUnits.
G = 1
code_length = u.Quantity(10.0, u.kpc)
code_mass = u.Quantity(1e8, u.Msun)
code_units = CodeUnits(code_length, code_mass, G=G)

# Static simulation settings: particle count, snapshot/timestep counts,
# acceleration scheme and softening length (converted to code units).
config = SimulationConfig(
    N_particles=1_000_000,
    return_snapshots=True,
    num_snapshots=100,
    num_timesteps=1_000,
    acceleration_scheme=DIRECT_ACC_LAXMAP,
    double_map=True,
    batch_size=100,
    softening=(0.1 * u.kpc).to(code_units.code_length).value,
)

# Physical parameters: end time and Plummer total mass / scale radius,
# each converted to code units before being passed in.
params = SimulationParams(
    t_end=(1 * u.Gyr).to(code_units.code_time).value,
    Plummer_params=PlummerParams(
        Mtot=(1e8 * u.Msun).to(code_units.code_mass).value,
        a=(1 * u.kpc).to(code_units.code_length).value,
    ),
    G=G,
)

# Draw the particles from a Plummer sphere with a fixed PRNG key (seed 0)
# and assemble them into the initial state.
positions, velocities, mass = Plummer_sphere(
    key=random.PRNGKey(0),
    params=params,
    config=config,
)
initial_state = construct_initial_state(positions, velocities)

# Plummer sphere distribution
# fig = plt.figure(figsize=(15, 5))
# ax = fig.add_subplot(121)
# ax.hist((jnp.linalg.norm(positions, axis=1) * code_units.code_length).to(u.kpc), bins=100, histtype='step', color='k')
# ax.axvline((params.Plummer_params.a*code_units.code_length).to(u.kpc).value, color='r', label='Plummer a')
# ax.set_xlabel('R [kpc]')

# ax = fig.add_subplot(122)
# ax.hist(jnp.linalg.norm((velocities * code_units.code_velocity).to(u.km/u.s).value, axis=1), bins=100, histtype='step', color='k')
# ax.set_xlabel('v [km/s]')
# plt.show()




In [5]:
# Distribute the particle positions according to `sharding`, then inspect
# where the resulting array lives.
shard_position = jax.device_put(positions, sharding)
print(shard_position.devices())  # set of devices holding the array
jax.debug.visualize_array_sharding(shard_position)

{CudaDevice(id=0), CudaDevice(id=1)}


In [6]:
# distance = jnp.linalg.norm(shard_position - shard_position[:, None, :], axis=2)
# jax.debug.visualize_array_sharding(distance)

In [7]:
# Shard the full initial state (positions and velocities) and the masses
# with the same `sharding`, printing shape and device placement of each.
# NOTE: the name `shard_inital_state` (sic) is referenced with this exact
# spelling by the integration cell below.
shard_inital_state = jax.device_put(initial_state, sharding)
print(shard_inital_state.shape)
print(shard_inital_state.devices())

shard_mass = jax.device_put(mass, sharding)
print(shard_mass.shape)
print(shard_mass.devices())


(1000000, 2, 3)
{CudaDevice(id=0), CudaDevice(id=1)}
(1000000,)
{CudaDevice(id=0), CudaDevice(id=1)}


In [None]:
# Run the time integration on the sharded state and masses.
# `jax.block_until_ready` waits for the result to be computed before the
# cell returns.
snapshots = jax.block_until_ready(time_integration(shard_inital_state, shard_mass, config, params))

In [None]:
# Show which devices hold the snapshot states stored in `snapshots`.
snapshots.states.devices()

{CudaDevice(id=0), CudaDevice(id=1)}

In [None]:
# Same integration on the original (not explicitly sharded) state and masses.
# NOTE: this rebinds `snapshots`, replacing the result of the sharded run.
snapshots = jax.block_until_ready(time_integration(initial_state, mass, config, params))