In [1]:
from rubix.core.pipeline import RubixPipeline 
import os
# Galaxy selection and local storage location, hoisted so the config entries
# that must agree with each other are set in exactly one place.
GALAXY_ID = 14
DATA_DIR = "data"

config = {
    "pipeline": {"name": "calc_ifu"},
    "logger": {
        "log_level": "DEBUG",
        "log_file_path": None,
        "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    },
    "data": {
        "name": "IllustrisAPI",
        "args": {
            # Read from the environment so the key is never stored in the
            # notebook; this is None when ILLUSTRIS_API_KEY is unset.
            "api_key": os.environ.get("ILLUSTRIS_API_KEY"),
            "particle_type": ["stars"],
            "simulation": "TNG50-1",
            "snapshot": 99,
            "save_data_path": DATA_DIR,
        },
        "load_galaxy_args": {
            "id": GALAXY_ID,
            "reuse": True,
        },
    },
    "simulation": {
        "name": "IllustrisTNG",
        "args": {
            # Derived from DATA_DIR / GALAXY_ID (previously hard-coded as
            # "data/galaxy-id-14.hdf5") so selecting another galaxy cannot
            # leave this pointing at a stale file.
            # NOTE(review): assumes the data loader writes galaxies as
            # "<save_data_path>/galaxy-id-<id>.hdf5" — confirm for other ids.
            "path": f"{DATA_DIR}/galaxy-id-{GALAXY_ID}.hdf5",
        },
    },
    "output_path": "output",
    "telescope": {"name": "MUSE"},
    "cosmology": {"name": "PLANCK15"},
    "galaxy": {"dist_z": 0.1},
    "ssp": {
        "template": {"name": "BruzualCharlot2003"},
    },
}

pipe = RubixPipeline(config)

# Run the configured pipeline; `data` is a dict whose keys are listed in the
# next cell.
data = pipe.run()

2024-05-14 17:22:33,341 - rubix - INFO - Getting rubix data...
2024-05-14 17:22:33,342 - rubix - INFO - Rubix galaxy file already exists, skipping conversion
2024-05-14 17:22:35,392 - rubix - INFO - Data loaded with 484076 particles.
2024-05-14 17:22:35,393 - rubix - INFO - Setting up the pipeline...
2024-05-14 17:22:35,393 - rubix - DEBUG - Pipeline Configuration: {'Transformers': {'rotate_galaxy': {'name': 'rotate_galaxy', 'depends_on': None, 'args': [], 'kwargs': {'type': 'face-on'}}, 'spaxel_assignment': {'name': 'spaxel_assignment', 'depends_on': 'rotate_galaxy', 'args': [], 'kwargs': {}}}}
2024-05-14 17:22:35,805 - rubix - INFO - Assembling the pipeline...
2024-05-14 17:22:35,806 - rubix - INFO - Compiling the expressions...
2024-05-14 17:22:35,807 - rubix - INFO - Running the pipeline on the input data...
2024-05-14 17:22:36,024 - rubix - INFO - Pipeline run completed in 0.63 seconds.


rotating galaxy:  face-on


In [2]:
available_fields = data.keys()  # fields returned by the pipeline run
print(available_fields)

dict_keys(['age', 'coords', 'mass', 'metallicity', 'n_particles', 'pixel_assignment', 'spatial_bin_edges', 'velocities'])


In [3]:
# Work on a small subset of stars so the lookup experiments below stay fast.
n_stars = 1000

# The first n_stars entries of each property, in stored order (not a random
# sample).
mass, metallicity, age = (
    data[field][:n_stars] for field in ("mass", "metallicity", "age")
)



In [4]:
import jax 
# Number of devices JAX sees on its default backend. Despite the name, these
# are not necessarily GPUs: on a CPU-only machine the CPU device is counted.
print(n_gpus := jax.device_count())

1


In [5]:
# pmap maps over the leading axis, so give each particle array the shape
# (n_gpus, particles_per_device). A bare reshape(n_gpus, -1) raises whenever
# the particle count is not a multiple of n_gpus, so the remainder is trimmed.
# Flattening first keeps the cell idempotent if it is re-run.
if not (mass.size == metallicity.size == age.size):
    raise ValueError(
        f"particle arrays differ in size: mass={mass.size}, "
        f"metallicity={metallicity.size}, age={age.size}"
    )

particles_per_device = mass.size // n_gpus
if particles_per_device == 0:
    raise ValueError(
        f"need at least {n_gpus} particle(s) to split over {n_gpus} device(s), "
        f"got {mass.size}"
    )

n_used = particles_per_device * n_gpus
if n_used < mass.size:
    print(
        f"Dropping {mass.size - n_used} particle(s) so they split evenly "
        f"over {n_gpus} device(s)"
    )

mass, metallicity, age = (
    x.reshape(-1)[:n_used].reshape(n_gpus, particles_per_device)
    for x in (mass, metallicity, age)
)

print(mass.shape)

(1, 1000)


In [6]:
# Build the SSP template lookup (an interpolator over the template grid) from
# the same config the pipeline used. It is evaluated per particle further down.
from rubix.core.ssp import get_lookup

ssp_lookup = get_lookup(config)
lookup = ssp_lookup
lookup

2024-05-14 17:22:36,113 - rubix - DEBUG - Getting SSP template: BruzualCharlot2003
2024-05-14 17:22:36,125 - rubix - DEBUG - Method not defined, using default method: cubic


Partial(<PjitFunction of <function interp2d at 0x7aab2ada4860>>, method='cubic', x=Array([1.e-04, 4.e-04, 4.e-03, 8.e-03, 2.e-02, 5.e-02], dtype=float32), y=Array([ 0.       ,  5.100002 ,  5.1500006,  5.1999993,  5.25     ,
        5.3000016,  5.350002 ,  5.4000006,  5.4500012,  5.500002 ,
        5.550002 ,  5.600002 ,  5.6500025,  5.700002 ,  5.750002 ,
        5.8000026,  5.850003 ,  5.900003 ,  5.950003 ,  6.       ,
        6.0200005,  6.040001 ,  6.0599985,  6.0799985,  6.100002 ,
        6.120001 ,  6.1399984,  6.16     ,  6.18     ,  6.1999993,
        6.2200007,  6.24     ,  6.2599998,  6.2799997,  6.2999997,
        6.3199987,  6.3399997,  6.3600006,  6.3799996,  6.3999987,
        6.4200006,  6.44     ,  6.4599996,  6.4799995,  6.499999 ,
        6.52     ,  6.539999 ,  6.56     ,  6.5799994,  6.6      ,
        6.6199994,  6.6399994,  6.66     ,  6.679999 ,  6.699999 ,
        6.72     ,  6.7399993,  6.7599993,  6.7799997,  6.799999 ,
        6.819999 ,  6.839999 ,  6.85999

In [7]:
# Spectrum of a single particle: first device shard, first particle.
z_first, age_first = metallicity[0, 0], age[0, 0]
lookup(z_first, age_first)

Array([-1.27273414e-09, -1.51689794e-09, -1.72743986e-09, -1.93381600e-09,
       -2.02696859e-09, -1.97521643e-09, -1.75261683e-09, -1.58770663e-09,
       -1.30428035e-09, -1.60758717e-09, -7.20912441e-10,  4.65053801e-10,
        1.24268569e-08,  1.64785874e-08,  1.76358341e-08,  1.89205860e-08,
        2.06282174e-08,  2.59454733e-08,  2.69601408e-08,  3.60744181e-08,
        3.84739494e-08,  1.31027292e-07,  3.38923940e-07,  3.53446353e-07,
        3.72944726e-07,  3.97857292e-07,  4.12958855e-07,  4.40954295e-07,
        4.58570213e-07,  4.76729269e-07,  4.94884887e-07,  5.07890718e-07,
        5.16900116e-07,  5.34818298e-07,  5.57560327e-07,  5.75684908e-07,
        5.90230684e-07,  6.09832455e-07,  6.23636424e-07,  6.13515738e-07,
        6.66773190e-07,  6.87257852e-07,  6.90145612e-07,  7.12055282e-07,
        1.09979089e-06,  1.80914742e-06,  2.37767790e-06,  2.15393561e-06,
        2.30861133e-06,  1.82965368e-06,  2.20347692e-06,  2.33049082e-06,
        2.29904981e-06,  

In [8]:
age_shape = age.shape  # expected: (n_gpus, particles per device)
age_shape

(1, 1000)

In [9]:
from jax import pmap, vmap

# Batch the single-particle lookup over particles: axis 0 of both the first
# (metallicity) and second (age) argument is mapped. pmap is then used to
# spread such batches over devices.
particle_axes = (0, 0)
lookup_vmap = vmap(lookup, in_axes=particle_axes)

In [10]:
first_metallicity_shard = metallicity[0]  # particles held by device 0
first_metallicity_shard.shape

(1000,)

In [11]:
# This cell previously duplicated the one above verbatim. Instead, check the
# invariant vmap(in_axes=(0, 0)) relies on: both inputs share the mapped
# (leading) axis length.
assert metallicity[0].shape == age[0].shape, (metallicity[0].shape, age[0].shape)
age[0].shape

(1000,)

In [12]:
# Evaluate the lookup for every particle on the first device shard.
# NOTE(review): in the saved run, every displayed row after the first was all
# zeros. Check that the metallicity/age values lie inside the template grid
# and use its units before trusting these spectra.
first_shard_inputs = (metallicity[0], age[0])
lookup_vmap(*first_shard_inputs)

Array([[-1.2727341e-09, -1.5168979e-09, -1.7274399e-09, ...,
         3.7306543e-06,  3.6678998e-06,  3.6346380e-06],
       [ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00, ...,
         0.0000000e+00,  0.0000000e+00,  0.0000000e+00],
       [ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00, ...,
         0.0000000e+00,  0.0000000e+00,  0.0000000e+00],
       ...,
       [ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00, ...,
         0.0000000e+00,  0.0000000e+00,  0.0000000e+00],
       [ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00, ...,
         0.0000000e+00,  0.0000000e+00,  0.0000000e+00],
       [ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00, ...,
         0.0000000e+00,  0.0000000e+00,  0.0000000e+00]], dtype=float32)

In [13]:
# One vmapped lookup per device: pmap splits the leading (device) axis, and
# vmap inside it handles that device's particles.
lookup_pmap = pmap(lookup_vmap, in_axes=(0, 0))

sharded_inputs = (metallicity, age)
lookup_pmap(*sharded_inputs)

Array([[[-1.2727341e-09, -1.5168979e-09, -1.7274399e-09, ...,
          3.7306543e-06,  3.6678998e-06,  3.6346380e-06],
        [ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00, ...,
          0.0000000e+00,  0.0000000e+00,  0.0000000e+00],
        [ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00, ...,
          0.0000000e+00,  0.0000000e+00,  0.0000000e+00],
        ...,
        [ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00, ...,
          0.0000000e+00,  0.0000000e+00,  0.0000000e+00],
        [ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00, ...,
          0.0000000e+00,  0.0000000e+00,  0.0000000e+00],
        [ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00, ...,
          0.0000000e+00,  0.0000000e+00,  0.0000000e+00]]], dtype=float32)