In [None]:
import jax

In [None]:
# Show the devices JAX sees on its default backend (displayed as the cell output).
jax.devices(backend=None)

In [None]:
from jax.lib import xla_bridge


In [None]:
# Print the platform name (e.g. "cpu", "gpu", "tpu") of JAX's default backend.
# `jax.default_backend()` is the public API for this; the
# `jax.lib.xla_bridge.get_backend().platform` route is deprecated internals.
print(jax.default_backend())

In [None]:
import os

# Prefer an ILLUSTRIS_API_KEY that is already exported in the environment and
# only fall back to the placeholder when none is set. Assigning unconditionally
# would overwrite a real key with the placeholder. Replace the placeholder with
# your own key locally, and never commit a real key in the notebook.
os.environ.setdefault("ILLUSTRIS_API_KEY", "your_illustris_key_here")

In [None]:
import yaml
from pathlib import Path

In [None]:
# Galaxy to download and process. Both the IllustrisAPI load args and the
# simulation HDF5 path below derive from it, so changing it here keeps them in sync.
GALAXY_ID = 14
# Passed as `save_data_path` and used as the parent directory of the simulation
# HDF5 path, so both always point at the same place.
DATA_DIR = "data"

config = {
    "pipeline": {"name": "calc_ifu"},
    "logger": {
        "log_level": "DEBUG",
        "log_file_path": None,
        "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    },
    "data": {
        "name": "IllustrisAPI",
        "args": {
            # Read when this cell runs; re-run the cell after changing the key.
            "api_key": os.environ.get("ILLUSTRIS_API_KEY"),
            "particle_type": ["stars"],
            "simulation": "TNG50-1",
            "snapshot": 99,
            "save_data_path": DATA_DIR,
        },
        "load_galaxy_args": {
            "id": GALAXY_ID,
            "reuse": True,
        },
        "subset": {
            "use_subset": True,
            "subset_size": 25000,
        },
    },
    "simulation": {
        "name": "IllustrisTNG",
        "args": {
            # Assumes the IllustrisAPI download is named galaxy-id-<id>.hdf5,
            # matching the previously hard-coded "data/galaxy-id-14.hdf5".
            "path": f"{DATA_DIR}/galaxy-id-{GALAXY_ID}.hdf5",
        },
    },
    "output_path": "output",
    "telescope": {"name": "MUSE"},
    "cosmology": {"name": "PLANCK15"},
    "galaxy": {"dist_z": 0.1},
    "ssp": {
        "template": {"name": "BruzualCharlot2003"},
    },
}

In [None]:
# NBVAL_SKIP
from rubix.core.data import convert_to_rubix, prepare_input

# Convert the data described by `config` to rubix format; it is stored in the
# `output_path` folder.
convert_to_rubix(config)

# Prepare the pipeline inputs and unpack them into five named values.
(coords, vel, metallicity, mass, age) = prepare_input(config)

In [None]:
from rubix.core import pipeline as rpl


In [None]:
# Construct the pipeline from the same config used for data preparation.
# (An assignment produces no cell output, so no trailing ';' is needed.)
pipeline = rpl.RubixPipeline(config)

In [None]:
data = pipeline.run()  # JAX may dispatch asynchronously and return before device work is done

# Wait for every result to be ready. The trailing ';' keeps the notebook from
# echoing the (large) result as cell output.
jax.block_until_ready(data);

The pipeline run finishes in under 1 s, but profiling it produces no output in TensorBoard. To view traces in TensorBoard, first install:

```bash 
pip install tensorflow tensorboard-plugin-profile
```

Execute the code below, then start TensorBoard:

```bash 
tensorboard --logdir=/tmp/jax-trace 
```

In [None]:
# Record a profiler trace of one full pipeline run into /tmp/jax-trace
# (inspect with: tensorboard --logdir=/tmp/jax-trace).
with jax.profiler.trace("/tmp/jax-trace"):
    data = pipeline.run()
    # Block inside the trace context so the dispatched device work is finished
    # (and captured) before tracing stops.
    jax.block_until_ready(data)

The following cell, which is the example code from [the JAX profiling documentation](https://jax.readthedocs.io/en/latest/profiling.html), does not produce profiling output either.

In [28]:
with jax.profiler.trace("/tmp/tensorboard"):
    # Example workload from the JAX profiling docs: a 5000x5000 random matmul.
    key = jax.random.key(0)
    x = jax.random.normal(key, (5000, 5000))
    y = x @ x
    # Wait for the asynchronous matmul so it completes inside the trace.
    jax.block_until_ready(y)