In [2]:
# NBVAL_SKIP
import os

# Galaxy to fetch and the local directory for downloaded cutouts. The simulation
# input path below is derived from these so the id only has to be changed here.
# NOTE(review): assumes the IllustrisAPI loader writes "<save_data_path>/galaxy-id-<id>.hdf5"
# (matches the previously hard-coded path) — TODO confirm against the loader.
GALAXY_ID = 12
DATA_DIR = "data"

config = {
    "logger": {
        "log_level": "DEBUG",
        "log_file_path": None,
        "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    },
    "data": {
        "name": "IllustrisAPI",
        "args": {
            # Read from the environment so the key is never written into the notebook;
            # evaluates to None when ILLUSTRIS_API_KEY is unset.
            "api_key": os.environ.get("ILLUSTRIS_API_KEY"),
            "particle_type": ["stars", "gas"],
            "simulation": "TNG50-1",
            "snapshot": 99,
            "save_data_path": DATA_DIR,
        },
        "load_galaxy_args": {
            "id": GALAXY_ID,
            "reuse": True,
        },
        "subset": {
            "use_subset": True,
            "subset_size": 1000,
        },
    },
    "simulation": {
        "name": "IllustrisTNG",
        "args": {
            "path": f"{DATA_DIR}/galaxy-id-{GALAXY_ID}.hdf5",
        },
    },
    "output_path": "output",
}

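Convert the data into a Rubix galaxy HDF5 file. This calls the Illustris API to download the galaxy data and then converts it into the Rubix HDF5 format using the input handler. If the Rubix galaxy file already exists, the conversion is skipped and the existing file is reused.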

In [3]:
# NBVAL_SKIP
from rubix.core.data import convert_to_rubix

In [4]:
# NBVAL_SKIP
# Download the galaxy via the configured IllustrisAPI loader and convert it to the
# Rubix HDF5 format. If the Rubix file already exists, conversion is skipped.
# The displayed return value was 'output', i.e. config["output_path"] in the saved run.
convert_to_rubix(config)

2024-07-16 11:53:49,747 - rubix - INFO - 
   ___  __  _____  _____  __
  / _ \/ / / / _ )/  _/ |/_/
 / , _/ /_/ / _  |/ /_>  <  
/_/|_|\____/____/___/_/|_|  
                            

2024-07-16 11:53:49,749 - rubix - INFO - Rubix version: 0.0.post66+g42d5801.d20240712
2024-07-16 11:53:49,750 - rubix - INFO - Rubix galaxy file already exists, skipping conversion


'output'

In [5]:
# NBVAL_SKIP
import sys

import jax.numpy as jnp

from rubix.core.data import convert_to_rubix, prepare_input

# NOTE(review): the saved run of this cell failed with KeyError: 'gas'. The conversion
# cell logged "Rubix galaxy file already exists, skipping conversion", so the existing
# file may predate adding "gas" to particle_type — TODO confirm, and regenerate the
# Rubix file if so.
rubixdata = prepare_input(config)

# Public (non-dunder) attribute names of the stellar container; dir() returns them
# sorted alphabetically.
attr = [name for name in dir(rubixdata.stars) if not name.startswith("__")]
print(attr)

# Replace one stellar attribute with a JAX array built from its current value.
# NOTE(review): selecting by position depends on the sorted dir() listing; switch to
# an explicit attribute name once the intended field is confirmed.
converted_name = attr[2]
setattr(rubixdata.stars, converted_name, jnp.array(getattr(rubixdata.stars, converted_name)))

# Check the attribute that was actually converted (not an unrelated field).
print(converted_name, type(getattr(rubixdata.stars, converted_name)))

{'logger': {'log_level': 'DEBUG', 'log_file_path': None, 'format': '%(asctime)s - %(name)s - %(levelname)s - %(message)s'}, 'data': {'name': 'IllustrisAPI', 'args': {'api_key': '<REDACTED>', 'particle_type': ['stars', 'gas'], 'simulation': 'TNG50-1', 'snapshot': 99, 'save_data_path': 'data'}, 'load_galaxy_args': {'id': 12, 'reuse': True}, 'subset': {'use_subset': True, 'subset_size': 1000}}, 'simulation': {'name': 'IllustrisTNG', 'args': {'path': 'data/galaxy-id-12.hdf5'}}, 'output_path': 'output'}


KeyError: 'gas'

In [6]:
# NBVAL_SKIP
from rubix.utils import print_hdf5_file_structure

# Inspect the layout of the Rubix galaxy file written under config["output_path"],
# so the path stays in sync with the configuration instead of being hard-coded.
print(print_hdf5_file_structure(os.path.join(config["output_path"], "rubix_galaxy.h5")))

File: output/rubix_galaxy.h5
Group: galaxy
    Dataset: center (float64) ((3,))
    Dataset: halfmassrad_stars (float64) (())
    Dataset: redshift (float64) (())
Group: meta
    Dataset: BoxSize (float64) (())
    Dataset: CutoutID (int64) (())
    Dataset: CutoutRequest (object) (())
    Dataset: CutoutType (object) (())
    Dataset: Git_commit (|S40) (())
    Dataset: Git_date (|S29) (())
    Dataset: HubbleParam (float64) (())
    Dataset: MassTable (float64) ((6,))
    Dataset: NumFilesPerSnapshot (int64) (())
    Dataset: NumPart_ThisFile (int32) ((6,))
    Dataset: Omega0 (float64) (())
    Dataset: OmegaBaryon (float64) (())
    Dataset: OmegaLambda (float64) (())
    Dataset: Redshift (float64) (())
    Dataset: SimulationName (object) (())
    Dataset: SnapshotNumber (int64) (())
    Dataset: Time (float64) (())
    Dataset: UnitLength_in_cm (float64) (())
    Dataset: UnitMass_in_g (float64) (())
    Dataset: UnitVelocity_in_cm_per_s (float64) (())
Group: particles
    Group

In [7]:
# NBVAL_SKIP
from rubix.utils import load_galaxy_data

# Load the converted galaxy back from the Rubix HDF5 file under config["output_path"],
# so the path stays in sync with the configuration instead of being hard-coded.
load_galaxy_data(os.path.join(config["output_path"], "rubix_galaxy.h5"))

({'subhalo_center': array([11413.86337268, 35893.86374042, 32027.01940138]),
  'subhalo_halfmassrad_stars': 7.727193253526112,
  'redshift': 2.220446049250313e-16,
  'particle_data': {'stars': {'age': array([6.61104195, 7.36190701, 6.66032986, ..., 6.40950954, 6.42073787,
           7.89013795]),
    'coords': array([[11413.87546682, 35893.80997872, 32027.00659656],
           [11413.86013316, 35893.80119009, 32027.00812409],
           [11413.87029215, 35893.78527919, 32027.01561964],
           ...,
           [11396.95793694, 35940.89677239, 32025.26185701],
           [11365.51268778, 35930.26383014, 32011.07761748],
           [11407.98233896, 35899.36944122, 32062.31917969]]),
    'mass': array([ 86821.71309025,  73289.33728592, 111279.11567214, ...,
           103573.98244387,  85964.95135161,  83291.24409303]),
    'metallicity': array([0.09918208, 0.04759451, 0.07264117, ..., 0.00494921, 0.00485685,
           0.00233734], dtype=float32),
    'velocity': array([[-2.5744148e-15