In [None]:
#NBVAL_SKIP
import os

# Point SPS_HOME at a local FSPS checkout (the "ssp" template configured
# further down is FSPS). setdefault() only fills the variable in when it is
# not already set, so a value exported in the shell or kernel environment is
# respected instead of being overwritten by this machine-specific path.
os.environ.setdefault("SPS_HOME", "/home/annalena/sps_fsps")

In [None]:
#NBVAL_SKIP
# Values that appear in more than one place of the configuration are named
# once, so the download settings and the simulation input path cannot drift
# apart when a different galaxy or data directory is chosen.
GALAXY_ID = 11
DATA_DIR = "data"

# Full pipeline configuration passed to RubixPipeline in the next cell.
config = {
    "pipeline": {"name": "calc_ifu"},
    "logger": {
        "log_level": "DEBUG",
        "log_file_path": None,
        "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    },
    "data": {
        "name": "IllustrisAPI",
        "args": {
            # Read from the environment so the key is never stored in the
            # notebook; .get() yields None when ILLUSTRIS_API_KEY is unset.
            "api_key": os.environ.get("ILLUSTRIS_API_KEY"),
            "particle_type": ["stars"],
            "simulation": "TNG50-1",
            "snapshot": 99,
            "save_data_path": DATA_DIR,
        },
        "load_galaxy_args": {
            "id": GALAXY_ID,
            "reuse": True,
        },
        "subset": {
            "use_subset": True,
            "subset_size": 100000,
        },
    },
    "simulation": {
        "name": "IllustrisTNG",
        "args": {
            # Same file name as the previously hard-coded
            # "data/galaxy-id-11.hdf5", now derived from the shared values.
            "path": f"{DATA_DIR}/galaxy-id-{GALAXY_ID}.hdf5",
        },
    },
    "output_path": "output",
    "output_modified": False,
    "telescope": {
        "name": "MUSE",
        "psf": {"name": "gaussian", "size": 5, "sigma": 0.5},
        "lsf": {"sigma": 0.5},
        "noise": {"signal_to_noise": 100, "noise_distribution": "normal"},
    },
    "cosmology": {"name": "PLANCK15"},
    "galaxy": {
        "dist_z": 0.1,
        "rotation": {"type": "face-on"},
    },
    "ssp": {
        "template": {"name": "FSPS"},
    },
}

In [None]:
#NBVAL_SKIP
import jax.numpy as jnp
from rubix.core.pipeline import RubixPipeline

# Build the pipeline from the configuration dictionary defined above.
pipe = RubixPipeline(config)

# Run it once and bind the result to both names; `rubixdata_fsps` is the alias
# the following cells read the star datacube from.
rubixdata_fsps = rubixdata = pipe.run()

In [None]:
#NBVAL_SKIP
from rubix.spectra.ifu import convert_luminoisty_to_flux
from rubix.cosmology import PLANCK15

# Observation parameters: the redshift comes from the galaxy section of the
# config, and the luminosity distance is evaluated at that same redshift.
pixel_size = 1.0
observation_z = config["galaxy"]["dist_z"]
observation_lum_dist = PLANCK15.luminosity_distance_to_z(observation_z)

# Convert the star datacube produced by the pipeline run.
# ("luminoisty" is the spelling of the function as it is imported above.)
spectra_fsps = convert_luminoisty_to_flux(
    rubixdata_fsps.stars.datacube,
    observation_lum_dist,
    observation_z,
    pixel_size,
)

In [None]:
#NBVAL_SKIP
from rubix.telescope.filters import (
    load_filter,
    print_filter_list,
    print_filter_list_info,
    print_filter_property,
)

# Load all filter curves for SLOAN.
curves = load_filter("SLOAN")

In [None]:
# NBVAL_SKIP
# Draw the filter curves loaded by load_filter("SLOAN") via the object's own plot() method.
curves.plot()

In [None]:
#NBVAL_SKIP
from rubix.telescope.filters import convolve_filter_with_spectra
import matplotlib.pyplot as plt

In [None]:
# NBVAL_SKIP
# Inputs shared with the grid plot further down: the telescope wavelength
# sequence of the pipeline and the flux cube computed from the pipeline output.
wave = pipe.telescope.wave_seq
datacube = spectra_fsps

# One figure per filter curve. The loop variable is deliberately not called
# `filter`: at notebook scope that name would shadow the builtin filter() for
# every cell executed afterwards.
for filter_curve in curves:
    convolved = convolve_filter_with_spectra(filter_curve, datacube, wave)
    # Explicit figure/axes objects instead of the implicit pyplot "current
    # figure", so the colorbar and title are tied to this image.
    fig, ax = plt.subplots()
    image = ax.imshow(convolved)
    fig.colorbar(image, ax=ax)
    ax.set_title(filter_curve.name)

In [None]:
#NBVAL_SKIP
import os

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.cm import ScalarMappable
from matplotlib.colors import Normalize


def _shared_limits(images):
    """Return ``(vmin, vmax)`` spanning every image in ``images``.

    Each image only needs ``.min()`` and ``.max()``. The running values start
    at +inf / -inf and are folded with the builtin min()/max(), one image at a
    time.
    """
    lowest, highest = np.inf, -np.inf
    for image in images:
        lowest = min(lowest, image.min())
        highest = max(highest, image.max())
    return lowest, highest


# Requires `curves`, `datacube` and `wave` from the cells above.

# Indices into `curves` shown in each row of the grid, left to right. They are
# defined once and used for both the colour limits and the plotting, so the two
# can no longer disagree.
# NOTE(review): the grouping presumably separates two families of SLOAN
# filters; confirm against the order returned by load_filter("SLOAN").
row1_indices = [0, 3, 5, 7, 9]
row2_indices = [1, 2, 4, 6, 8]
row1_cmap = "viridis"
row2_cmap = "inferno"

num_filters = len(curves)
nrows = 2
ncols = len(row1_indices)

# Fail early with a readable message instead of an indexing error halfway
# through plotting when fewer filters were loaded than the layout expects.
required_filters = max(row1_indices + row2_indices) + 1
if num_filters < required_filters:
    raise ValueError(
        f"The grid layout needs at least {required_filters} filter curves, "
        f"but only {num_filters} were loaded."
    )

# Convolve the datacube with every filter curve. The loop variable is not
# called `filter`, which would shadow the builtin at notebook scope.
convolved_list = [
    convolve_filter_with_spectra(filter_curve, datacube, wave)
    for filter_curve in curves
]

# One shared colour scale per row, computed only from the images actually
# drawn in that row (filters beyond the layout no longer widen row two's range).
vmin_row1, vmax_row1 = _shared_limits(convolved_list[i] for i in row1_indices)
vmin_row2, vmax_row2 = _shared_limits(convolved_list[i] for i in row2_indices)

fig, axes = plt.subplots(nrows, ncols, figsize=(15, 6))

# Plot each convolved image in the grid, row by row.
row_specs = [
    (row1_indices, vmin_row1, vmax_row1, row1_cmap),
    (row2_indices, vmin_row2, vmax_row2, row2_cmap),
]
for row_axes, (indices, vmin, vmax, cmap) in zip(axes, row_specs):
    for ax, filter_index in zip(row_axes, indices):
        ax.imshow(convolved_list[filter_index], vmin=vmin, vmax=vmax, cmap=cmap)
        ax.set_title(curves[filter_index].name)
        ax.axis('off')

# Adjust layout with tight_layout before the colorbar axes are added.
plt.tight_layout()

# Reserve space on the right and create small axes for the two colorbars.
fig.subplots_adjust(right=0.85)
cbar_ax1 = fig.add_axes([0.87, 0.55, 0.02, 0.35])  # Colorbar of the first row
cbar_ax2 = fig.add_axes([0.87, 0.07, 0.02, 0.35])  # Colorbar of the second row

# ScalarMappable objects carrying each row's colour scale and colormap.
norm_row1 = Normalize(vmin=vmin_row1, vmax=vmax_row1)
norm_row2 = Normalize(vmin=vmin_row2, vmax=vmax_row2)
sm_row1 = ScalarMappable(norm=norm_row1, cmap=row1_cmap)
sm_row2 = ScalarMappable(norm=norm_row2, cmap=row2_cmap)

# Add one colorbar per row.
fig.colorbar(sm_row1, cax=cbar_ax1)
fig.colorbar(sm_row2, cax=cbar_ax2)

# savefig does not create missing directories, so make sure the target exists.
output_file = "output/filters_fsps_galaxy.png"
os.makedirs(os.path.dirname(output_file), exist_ok=True)
plt.savefig(output_file)
plt.show()