In [23]:
import jax

In [24]:
# Sanity check: list the devices JAX can see (the cell output shows the result).
jax.devices()

[cuda(id=0)]

In [25]:
from jax.lib import xla_bridge


In [26]:
# Report which platform JAX selected as its default backend (e.g. "cpu" or "gpu").
# jax.default_backend() is the public API for this; it avoids reaching into the
# internal jax.lib.xla_bridge module, which newer JAX releases deprecate/remove.
print(jax.default_backend())

gpu


In [15]:
import os 
# Placeholder only: export a real ILLUSTRIS_API_KEY in your environment and never
# commit it to the notebook. The placeholder is applied only when the variable is
# missing, so a key that is already exported is not overwritten.
if 'ILLUSTRIS_API_KEY' not in os.environ:
    os.environ['ILLUSTRIS_API_KEY'] = 'your_illustris_key_here'

In [16]:
import yaml
from pathlib import Path

In [17]:
# Values that appear in more than one place of the configuration. Keeping them
# here means the galaxy id / download directory and the file path derived from
# them cannot drift apart when one of them is edited.
GALAXY_ID = 14
DATA_DIR = "data"

# Configuration dictionary handed to the rubix helpers and to RubixPipeline.
config = {
    "pipeline": {"name": "calc_ifu"},

    # Logging setup; no log file path is given.
    "logger": {
        "log_level": "DEBUG",
        "log_file_path": None,
        "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    },

    # Input data source and which galaxy / particles to use.
    "data": {
        "name": "IllustrisAPI",
        "args": {
            # Read from the environment so the key never lives in the notebook;
            # this is None when ILLUSTRIS_API_KEY is not set.
            "api_key": os.environ.get("ILLUSTRIS_API_KEY"),
            "particle_type": ["stars"],
            "simulation": "TNG50-1",
            "snapshot": 99,
            "save_data_path": DATA_DIR,
        },
        "load_galaxy_args": {
            "id": GALAXY_ID,
            "reuse": True,
        },
        # Work on a subset of the particles instead of the full galaxy.
        "subset": {
            "use_subset": True,
            "subset_size": 25000,
        },
    },

    "simulation": {
        "name": "IllustrisTNG",
        "args": {
            # Presumably the file stored under DATA_DIR for GALAXY_ID by the data
            # step above (TODO confirm the naming scheme); evaluates to
            # "data/galaxy-id-14.hdf5" with the constants defined at the top.
            "path": f"{DATA_DIR}/galaxy-id-{GALAXY_ID}.hdf5",
        },
    },

    "output_path": "output",

    "telescope": {"name": "MUSE"},
    "cosmology": {"name": "PLANCK15"},
    "galaxy": {"dist_z": 0.1},

    "ssp": {
        "template": {
            "name": "BruzualCharlot2003",
        },
    },
}

In [18]:
# NBVAL_SKIP
from rubix.core.data import convert_to_rubix, prepare_input

# Step 1: convert the data described by `config` to the rubix format and store
# it in the output_path folder.
convert_to_rubix(config)

# Step 2: prepare the per-particle inputs for the pipeline.
(coords, vel, metallicity, mass, age) = prepare_input(config)

2024-05-28 17:12:06,378 - rubix - INFO - Rubix galaxy file already exists, skipping conversion


In [19]:
from rubix.core import pipeline as rpl


In [20]:
# Build the pipeline from the config; the log output below shows that this step
# already loads the (subset of the) galaxy data.
pipeline = rpl.RubixPipeline(config)

2024-05-28 17:12:06,465 - rubix - INFO - Getting rubix data...
2024-05-28 17:12:06,466 - rubix - INFO - Rubix galaxy file already exists, skipping conversion
2024-05-28 17:12:06,504 - rubix - INFO - Data loaded with 25000 particles.
2024-05-28 17:12:06,504 - rubix - DEBUG - Data Shape: {'coords': (25000, 3), 'velocities': (25000, 3), 'metallicity': (25000,), 'mass': (25000,), 'age': (25000,)}


In [None]:
# Run the pipeline once outside the profiler. block_until_ready waits for JAX's
# asynchronously dispatched work to finish and hands back the same result.
data = jax.block_until_ready(pipeline.run())

In [18]:
# Profile one complete pipeline run and write the trace to /tmp/jax-trace.
with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True):
    # Block inside the context so the asynchronously dispatched computation is
    # finished (and therefore captured) before the trace is closed.
    data = jax.block_until_ready(pipeline.run())

In [21]:
# Minimal, pipeline-independent workload to check that profiling works: the
# trace is written to /tmp/jax-trace and a Perfetto link is printed.
with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True):
    key = jax.random.key(0)
    x = jax.random.normal(key, shape=(5000, 5000))
    y = x @ x
    # Wait for the matmul to finish so it is part of the recorded trace.
    jax.block_until_ready(y)



Open URL in browser: https://ui.perfetto.dev/#!/?url=http://127.0.0.1:9001/perfetto_trace.json.gz


127.0.0.1 - - [28/May/2024 17:12:53] code 501, message Unsupported method ('OPTIONS')
127.0.0.1 - - [28/May/2024 17:12:53] "OPTIONS /status HTTP/1.1" 501 -
127.0.0.1 - - [28/May/2024 17:12:53] code 404, message File not found
127.0.0.1 - - [28/May/2024 17:12:53] "POST /status HTTP/1.1" 404 -
127.0.0.1 - - [28/May/2024 17:12:53] code 501, message Unsupported method ('OPTIONS')
127.0.0.1 - - [28/May/2024 17:12:53] "OPTIONS /perfetto_trace.json.gz HTTP/1.1" 501 -
127.0.0.1 - - [28/May/2024 17:12:53] "GET /perfetto_trace.json.gz HTTP/1.1" 200 -


In [22]:
# Same toy workload as a quick profiling check, this time writing the trace to
# /tmp/tensorboard without creating a Perfetto link.
with jax.profiler.trace("/tmp/tensorboard"):
    key = jax.random.key(0)
    x = jax.random.normal(key, shape=(5000, 5000))
    y = x @ x
    # Wait for the matmul to finish so it is part of the recorded trace.
    jax.block_until_ready(y)