In [1]:
import jax

In [2]:
# Show the devices JAX can see; as the cell's last expression the list is displayed as its output.
jax.devices()

[CpuDevice(id=0)]

In [3]:
from jax.lib import xla_bridge


In [4]:
# Report the platform of JAX's default backend (e.g. "cpu").
# jax.default_backend() is the public API for this; going through
# jax.lib.xla_bridge.get_backend() relies on a deprecated internal shim.
print(jax.default_backend())

cpu


In [5]:
import yaml
from pathlib import Path
import os

In [6]:
# Run configuration passed to the rubix helpers and to RubixPipeline below,
# kept as a single nested dict.
config = {
    "pipeline": {"name": "calc_ifu"},
    "logger": {
        "log_level": "DEBUG",
        "log_file_path": None,
        "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    },
    "data": {
        "name": "IllustrisAPI",
        "args": {
            # Read from the environment so the key is never stored in the
            # notebook; this evaluates to None when the variable is unset.
            "api_key": os.environ.get("ILLUSTRIS_API_KEY"),
            "particle_type": ["stars"],
            "simulation": "TNG50-1",
            "snapshot": 99,
            "save_data_path": "data",
        },
        "load_galaxy_args": {"id": 14, "reuse": True},
        "subset": {"use_subset": True, "subset_size": 10000},
    },
    "simulation": {
        "name": "IllustrisTNG",
        # NOTE(review): presumably the file written for galaxy id 14 under
        # save_data_path above -- confirm the two stay in sync.
        "args": {"path": "data/galaxy-id-14.hdf5"},
    },
    "output_path": "output",
    "telescope": {"name": "MUSE"},
    "cosmology": {"name": "PLANCK15"},
    "galaxy": {"dist_z": 0.1},
    "ssp": {"template": {"name": "BruzualCharlot2003"}},
}

In [7]:
# NBVAL_SKIP
# The marker above must stay on the first line: nbval skips cells that carry it.
from rubix.core.data import convert_to_rubix, prepare_input

# Both helpers are driven entirely by the `config` dict defined earlier in the notebook.
convert_to_rubix(config) # Convert the config to rubix format and store in output_path folder
coords, vel, metallicity, mass, age = prepare_input(config) # Prepare the input for the pipeline

2024-05-28 19:00:40,000 - rubix - INFO - 
   ___  __  _____  _____  __
  / _ \/ / / / _ )/  _/ |/_/
 / , _/ /_/ / _  |/ /_>  <  
/_/|_|\____/____/___/_/|_|  
                            

2024-05-28 19:00:40,001 - rubix - INFO - Rubix version: 0.0.post89+g16c73c0.d20240528
2024-05-28 19:00:40,001 - rubix - INFO - Rubix galaxy file already exists, skipping conversion


In [8]:
from rubix.core import pipeline as rpl


In [9]:
# Build the pipeline object from the notebook's config dict; the log output
# of this cell shows the particle data being loaded during construction.
pipeline = rpl.RubixPipeline(config)

2024-05-28 19:00:41,365 - rubix - INFO - Getting rubix data...
2024-05-28 19:00:41,366 - rubix - INFO - Rubix galaxy file already exists, skipping conversion
2024-05-28 19:00:41,945 - rubix - INFO - Data loaded with 10000 particles.
2024-05-28 19:00:41,946 - rubix - DEBUG - Data Shape: {'coords': (10000, 3), 'velocities': (10000, 3), 'metallicity': (10000,), 'mass': (10000,), 'age': (10000,)}


In [10]:
# Execute the pipeline built earlier and keep its result in `data`.
data = pipeline.run()

# Wait until JAX has finished computing `data`. The trailing `;` keeps the
# value returned by block_until_ready from being echoed as cell output.
jax.block_until_ready(data);

2024-05-28 19:00:41,952 - rubix - INFO - Setting up the pipeline...
2024-05-28 19:00:41,952 - rubix - DEBUG - Pipeline Configuration: {'Transformers': {'rotate_galaxy': {'name': 'rotate_galaxy', 'depends_on': None, 'args': [], 'kwargs': {'type': 'face-on'}}, 'spaxel_assignment': {'name': 'spaxel_assignment', 'depends_on': 'rotate_galaxy', 'args': [], 'kwargs': {}}, 'reshape_data': {'name': 'reshape_data', 'depends_on': 'spaxel_assignment', 'args': [], 'kwargs': {}}, 'calculate_spectra': {'name': 'calculate_spectra', 'depends_on': 'reshape_data', 'args': [], 'kwargs': {}}, 'scale_spectrum_by_mass': {'name': 'scale_spectrum_by_mass', 'depends_on': 'calculate_spectra', 'args': [], 'kwargs': {}}, 'doppler_shift_and_resampling': {'name': 'doppler_shift_and_resampling', 'depends_on': 'scale_spectrum_by_mass', 'args': [], 'kwargs': {}}}}
2024-05-28 19:00:42,234 - rubix - DEBUG - Method not defined, using default method: cubic
2024-05-28 19:00:42,265 - rubix - DEBUG - SSP Wave: (842,)
2024-05-

rotating galaxy:  face-on


2024-05-28 19:00:42,539 - rubix - DEBUG - Calculation Finished! Spectra shape: (1, 10000, 842)
2024-05-28 19:00:42,540 - rubix - INFO - Scaling Spectra by Mass...
2024-05-28 19:00:42,546 - rubix - DEBUG - Doppler Shifted SSP Wave: (1, 10000, 842)
2024-05-28 19:00:42,547 - rubix - DEBUG - Telescope Wave Seq: (3721,)
2024-05-28 19:00:46,290 - rubix - INFO - Pipeline run completed in 4.34 seconds.


Finishes in < 1 s and does not produce any profiling output when using TensorBoard. Install the profiler dependencies first:

```bash 
pip install tensorflow tensorboard-plugin-profile
```

Execute the code cell below, then start TensorBoard:

```bash 
tensorboard --logdir=/tmp/jax-trace 
```

In [11]:
# Run the pipeline again under the JAX profiler; the trace goes to /tmp/jax-trace.
with jax.profiler.trace("/tmp/jax-trace"):
    data = pipeline.run()
    # Block inside the `with` so the trace stays open until the result is ready.
    jax.block_until_ready(data)

2024-05-28 19:00:51,018 - rubix - INFO - Setting up the pipeline...
2024-05-28 19:00:51,020 - rubix - DEBUG - Pipeline Configuration: {'Transformers': {'rotate_galaxy': {'name': 'rotate_galaxy', 'depends_on': None, 'args': [], 'kwargs': {'type': 'face-on'}}, 'spaxel_assignment': {'name': 'spaxel_assignment', 'depends_on': 'rotate_galaxy', 'args': [], 'kwargs': {}}, 'reshape_data': {'name': 'reshape_data', 'depends_on': 'spaxel_assignment', 'args': [], 'kwargs': {}}, 'calculate_spectra': {'name': 'calculate_spectra', 'depends_on': 'reshape_data', 'args': [], 'kwargs': {}}, 'scale_spectrum_by_mass': {'name': 'scale_spectrum_by_mass', 'depends_on': 'calculate_spectra', 'args': [], 'kwargs': {}}, 'doppler_shift_and_resampling': {'name': 'doppler_shift_and_resampling', 'depends_on': 'scale_spectrum_by_mass', 'args': [], 'kwargs': {}}}}
2024-05-28 19:00:51,056 - rubix - DEBUG - Method not defined, using default method: cubic
2024-05-28 19:00:51,080 - rubix - DEBUG - SSP Wave: (842,)
2024-05-

rotating galaxy:  face-on


2024-05-28 19:00:51,626 - rubix - DEBUG - Calculation Finished! Spectra shape: (1, 10000, 842)
2024-05-28 19:00:51,627 - rubix - INFO - Scaling Spectra by Mass...
2024-05-28 19:00:51,632 - rubix - DEBUG - Doppler Shifted SSP Wave: (1, 10000, 842)
2024-05-28 19:00:51,633 - rubix - DEBUG - Telescope Wave Seq: (3721,)
2024-05-28 19:00:56,291 - rubix - INFO - Pipeline run completed in 5.27 seconds.


... and neither does this, which is the example code from [the JAX profiling documentation](https://jax.readthedocs.io/en/latest/profiling.html):

In [12]:
# Minimal profiling example: trace one large matrix product into /tmp/tensorboard.
with jax.profiler.trace("/tmp/tensorboard"):
    key = jax.random.key(0)
    x = jax.random.normal(key, shape=(5000, 5000))
    y = x @ x
    # Wait for the product before the trace context closes.
    jax.block_until_ready(y)