In [None]:
import matplotlib.pyplot as plt
from rubix.core.pipeline import RubixPipeline 
import os

# Configuration for a RUBIX run that builds an IFU cube for the gas component
# of one IllustrisTNG galaxy. Key order is kept as-is in case it matters to
# the pipeline.
config = {
    "pipeline": {"name": "calc_ifu_gas"},
    # Logging setup handed to the pipeline.
    "logger": {
        "log_level": "DEBUG",
        "log_file_path": None,
        "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    },
    # Input data: which galaxy to fetch and which particle types to load.
    "data": {
        "name": "IllustrisAPI",
        "args": {
            # Read from the environment so the key never lands in the notebook;
            # os.environ.get yields None when ILLUSTRIS_API_KEY is not set.
            "api_key": os.environ.get("ILLUSTRIS_API_KEY"),
            "particle_type": ["stars", "gas"],
            "cube_type": ["gas"],
            "simulation": "TNG50-1",
            "snapshot": 99,
            "save_data_path": "data",
        },
        # NOTE(review): "reuse" presumably skips the download when the file
        # already exists locally — confirm against the RUBIX docs.
        "load_galaxy_args": {"id": 11, "reuse": True},
        # NOTE(review): presumably restricts the run to a subset of particles
        # to keep the example fast — confirm.
        "subset": {"use_subset": True, "subset_size": 10000},
    },
    "simulation": {
        "name": "IllustrisTNG",
        # Must point at the file for the galaxy id / save path chosen above.
        "args": {"path": "data/galaxy-id-11.hdf5"},
    },
    "output_path": "output",
    # Instrument model: PSF, LSF and noise settings.
    "telescope": {
        "name": "MUSE",
        "psf": {"name": "gaussian", "size": 5, "sigma": 0.6},
        "lsf": {"sigma": 0.5},
        "noise": {"signal_to_noise": 1, "noise_distribution": "normal"},
    },
    "cosmology": {"name": "PLANCK15"},
    # Observation setup for the galaxy.
    "galaxy": {
        "dist_z": 0.1,
        "rotation": {"type": "edge-on"},
    },
    "ssp": {
        "template": {"name": "BruzualCharlot2003"},
    },
}

# Build the pipeline from the configuration and execute it; `data` holds the
# pipeline result that the following cells read from.
pipe = RubixPipeline(config)
data = pipe.run()

In [None]:
# Collapse the gas datacube along its third axis (axis=2) and display the
# summed array as an image.
# NOTE(review): axis 2 is presumably the wavelength axis — confirm.
datacube = data.gas.datacube
img = datacube.sum(axis=2)
# origin="lower" places array index [0, 0] in the lower-left corner.
plt.imshow(img, origin="lower")

In [None]:
#NBVAL_SKIP
# Plot the spectrum of a single spaxel of the gas cube.
import matplotlib.pyplot as plt
import jax.numpy as jnp

wavelengths = pipe.telescope.wave_seq

# Indices of the wavelength bins inside the visible 4000-8000 Angstrom window.
# Single-argument jnp.where returns a tuple of index arrays, so the indices
# themselves are visible_indices[0].
visible_indices = jnp.where(
    jnp.logical_and(wavelengths >= 4000, wavelengths <= 8000)
)

# Spectrum at spaxel (12, 12) of the gas cube.
spec = data.gas.datacube[12, 12]

# Show the full wavelength range on a logarithmic flux axis. To restrict the
# plot to the visible window, index both arrays with `visible_indices`.
plt.plot(wavelengths, spec)
plt.yscale("log")

In [None]:
#NBVAL_SKIP
# Build an image from the visible-wavelength part of the IFU cube.
# This run is configured for the gas cube (pipeline "calc_ifu_gas",
# cube_type ["gas"]) and the other cells read `data.gas.datacube`, so the gas
# cube is used here as well instead of `data.stars.datacube`.
final_ifu_cube = data.gas.datacube

# `visible_indices` is the tuple returned by jnp.where in the spectrum cell;
# its first element holds the indices along the spectral (third) axis.
visible_spectra = final_ifu_cube[:, :, visible_indices[0]]

# Sum up all spectra to create an image
image = jnp.sum(visible_spectra, axis=2)
plt.imshow(image, origin="lower", cmap="inferno")
plt.colorbar()