In [None]:
# NBVAL_SKIP
import os

# SPS_HOME must be set before the FSPS template is used further down.
# Respect a value that is already exported in the environment and only fall
# back to the author's local checkout when nothing is set, so the notebook
# does not clobber a working configuration on other machines.
os.environ.setdefault('SPS_HOME', '/home/annalena/sps_fsps')

In [None]:
# NBVAL_SKIP
import jax

# Turn on the "jax_enable_x64" flag for this kernel session before any
# pipeline code runs.
jax.config.update("jax_enable_x64", True)

In [None]:
#NBVAL_SKIP
# Single configuration dictionary handed to RubixPipeline in the cells below.
# Later cells only overwrite config["ssp"]["template"]["name"].
config = {
    "pipeline": {
        "name": "calc_ifu",
    },
    "logger": {
        "log_level": "DEBUG",
        "log_file_path": None,
        "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    },
    "data": {
        "name": "IllustrisAPI",
        "args": {
            # Looked up in the environment; .get() yields None when the
            # variable is not set.
            "api_key": os.environ.get("ILLUSTRIS_API_KEY"),
            "particle_type": ["stars"],
            "simulation": "TNG50-1",
            "snapshot": 99,
            "save_data_path": "data",
        },
        "load_galaxy_args": {
            "id": 11,
            "reuse": True,
        },
        "subset": {
            "use_subset": True,
            "subset_size": 200000,
        },
    },
    "simulation": {
        "name": "IllustrisTNG",
        "args": {
            "path": "data/galaxy-id-11.hdf5",
        },
    },
    "output_path": "output",
    "output_modified": False,
    "telescope": {
        "name": "MUSE",
        "psf": {
            "name": "gaussian",
            "size": 5,
            "sigma": 0.5,
        },
        "lsf": {
            "sigma": 0.5,
        },
        "noise": {
            "signal_to_noise": 100,
            "noise_distribution": "normal",
        },
    },
    "cosmology": {
        "name": "PLANCK15",
    },
    "galaxy": {
        "dist_z": 0.1,
        "rotation": {
            "type": "face-on",
        },
    },
    "ssp": {
        # First template of the comparison; the FSPS and MaStar cells replace
        # this name in place.
        "template": {
            "name": "BruzualCharlot2003",
        },
    },
}

# Bruzual & Charlot (2003)

In [None]:
#NBVAL_SKIP
import jax.numpy as jnp
from rubix.core.pipeline import RubixPipeline

# Select the template explicitly instead of relying on whatever is currently
# stored in `config`: the FSPS and MaStar cells overwrite this entry in place,
# so re-running this cell after them would otherwise store a non-Bruzual run
# under the name `rubixdata_bruzual`.
config["ssp"]["template"]["name"] = "BruzualCharlot2003"

pipe = RubixPipeline(config)

rubixdata = pipe.run()

# Keep a dedicated handle; `rubixdata` is rebound by the following cells.
rubixdata_bruzual = rubixdata

# FSPS

In [None]:
# NBVAL_SKIP
# Point the shared configuration at the FSPS template and rebuild the pipeline
# so the new template name is picked up.
config["ssp"]["template"]["name"] = "FSPS"
pipe = RubixPipeline(config)

# Run once; bind the result both to the generic name and to the FSPS handle.
rubixdata = rubixdata_fsps = pipe.run()

# MaStar

In [None]:
# NBVAL_SKIP
# Point the shared configuration at the MaStar template and rebuild the
# pipeline so the new template name is picked up.
config["ssp"]["template"]["name"] = "Mastar_CB19_SLOG_1_5"
pipe = RubixPipeline(config)

# Run once; bind the result both to the generic name and to the MaStar handle.
rubixdata = rubixdata_mastar = pipe.run()

# Convert luminosity to flux

In [None]:
# NBVAL_SKIP
from rubix.spectra.ifu import convert_luminoisty_to_flux
from rubix.cosmology import PLANCK15

# Observation redshift comes from the run configuration; the luminosity
# distance is derived from it with the PLANCK15 cosmology.
observation_z = config["galaxy"]["dist_z"]
observation_lum_dist = PLANCK15.luminosity_distance_to_z(observation_z)
pixel_size = 1.0

# Apply the same conversion to the stellar datacube of each of the three runs
# (the function name is spelled as exported by rubix.spectra.ifu).
spectra_bruzual, spectra_fsps, spectra_mastar = (
    convert_luminoisty_to_flux(run.stars.datacube, observation_lum_dist, observation_z, pixel_size)
    for run in (rubixdata_bruzual, rubixdata_fsps, rubixdata_mastar)
)

# Visualize the mock data

In [None]:
# NBVAL_SKIP
import jax.numpy as jnp
import matplotlib.pyplot as plt

# Wavelength grid of the most recently built pipeline and the stellar datacube
# of the MaStar run.
# NOTE(review): this is the datacube as returned by the pipeline, not the
# converted `spectra_mastar` from the flux-conversion cell, while the y-axis
# below is labelled "Flux" -- confirm which quantity is intended here.
wave = pipe.telescope.wave_seq
spectra = rubixdata_mastar.stars.datacube

# Spaxels to highlight, as indices into the first two datacube axes.
spaxel_x, spaxel_y = 12, 12
spaxel_x2, spaxel_y2 = 12, 14
spaxel_x3, spaxel_y3 = 12, 16
spaxel_x4, spaxel_y4 = 16, 12

# One entry per spaxel: (first-axis index, second-axis index, colour). The same
# colour is used for the spectrum line and the image marker so that the two
# panels can be matched by eye.
highlighted_spaxels = [
    (spaxel_x, spaxel_y, "red"),
    (spaxel_x2, spaxel_y2, "blue"),
    (spaxel_x3, spaxel_y3, "green"),
    (spaxel_x4, spaxel_y4, "orange"),
]

# Collapse the 4000-8000 wavelength window of the cube into a 2D image.
visible_indices = jnp.where((wave >= 4000) & (wave <= 8000))
visible_spectra = spectra[:, :, visible_indices[0]]
image = jnp.sum(visible_spectra, axis=2)

# Left panel: spectra of the highlighted spaxels; right panel: collapsed image.
fig, axes = plt.subplots(1, 2, figsize=(12, 6))

# Draw the image first so the markers are placed on top of it.
im = axes[1].imshow(image, origin="lower", cmap="inferno")

for sx, sy, colour in highlighted_spaxels:
    spaxel_label = f"Spaxel [{sx}, {sy}]"
    axes[0].plot(wave, spectra[sx, sy, :], color=colour, label=spaxel_label)
    # imshow draws the first array axis vertically, hence (second, first)
    # index order for the marker position.
    axes[1].scatter(sy, sx, color=colour, marker="*", s=100, label=spaxel_label)

axes[0].set_title("Spectra of the highlighted spaxels")
axes[0].set_xlabel("Wavelength [Å]")
axes[0].set_ylabel("Flux")
axes[0].legend()

axes[1].set_title("Spatial Image from Data Cube")
axes[1].legend()
cbar = fig.colorbar(im, ax=axes[1], orientation="vertical", label="Integrated Flux")

# Adjust layout and show the plots
plt.tight_layout()
plt.show()

In [None]:
# NBVAL_SKIP
import os

import jax.numpy as jnp
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
from matplotlib.gridspec import GridSpec


def _draw_image_panel(ax, image, spaxels, title, show_legend=False):
    """Show a collapsed image with the highlighted spaxels marked on it.

    Parameters
    ----------
    ax : matplotlib Axes to draw into.
    image : 2D array shown with ``imshow``.
    spaxels : iterable of ``(first_index, second_index, colour)`` tuples.
    title : panel title.
    show_legend : draw the marker legend on this panel when True.

    Returns
    -------
    The image artist returned by ``imshow``.
    """
    im = ax.imshow(image, origin="lower", cmap="inferno")
    for sx, sy, colour in spaxels:
        # imshow draws the first array axis vertically, hence (second, first)
        # index order for the marker position.
        ax.scatter(sy, sx, color=colour, marker="*", s=100, label=f"Spaxel [{sx}, {sy}]")
    ax.set_title(title)
    if show_legend:
        ax.legend()
    ax.figure.colorbar(im, ax=ax, orientation="vertical")
    return im


def _draw_spectrum_panel(ax, wave, cube, spaxels, title):
    """Plot the spectrum of every highlighted spaxel of ``cube`` into ``ax``.

    Line colours are taken from ``spaxels`` so they match the image markers.
    """
    for sx, sy, colour in spaxels:
        ax.plot(wave, cube[sx, sy, :], color=colour, label=f"Spaxel [{sx}, {sy}]")
    ax.set_title(title)
    ax.set_xlabel("Wavelength [Å]")
    ax.set_ylabel("Flux [erg/s/cm2/Å]")


# Wavelength grid of the most recently built pipeline and the three converted
# cubes from the flux-conversion cell.
wave = pipe.telescope.wave_seq
spectra1 = spectra_bruzual
spectra2 = spectra_fsps
spectra3 = spectra_mastar

# Spaxels to highlight, as indices into the first two cube axes.
spaxel_x, spaxel_y = 12, 12
spaxel_x2, spaxel_y2 = 12, 14
spaxel_x3, spaxel_y3 = 12, 16
highlighted_spaxels = [
    (spaxel_x, spaxel_y, "red"),
    (spaxel_x2, spaxel_y2, "blue"),
    (spaxel_x3, spaxel_y3, "green"),
]

# Collapse the 4000-8000 wavelength window of each cube into a 2D image.
visible_indices = jnp.where((wave >= 4000) & (wave <= 8000))
visible_spectra1 = spectra1[:, :, visible_indices[0]]
visible_spectra2 = spectra2[:, :, visible_indices[0]]
visible_spectra3 = spectra3[:, :, visible_indices[0]]
image1 = jnp.sum(visible_spectra1, axis=2)  # Bruzual image
image2 = jnp.sum(visible_spectra2, axis=2)  # FSPS image
image3 = jnp.sum(visible_spectra3, axis=2)  # MaStar image

# Not used by this cell; retained in case other cells rely on it.
vmin = 0

# Top row: one image per template; below: one full-width spectrum row each.
fig = plt.figure(figsize=(16, 14))
gs = GridSpec(4, 3, height_ratios=[0.7, 0.3, 0.3, 0.3], hspace=0.4)

# Quick numeric comparison: summed spectrum of the first highlighted spaxel.
sum1 = jnp.sum(spectra1[spaxel_x, spaxel_y, :])
sum2 = jnp.sum(spectra2[spaxel_x, spaxel_y, :])
sum3 = jnp.sum(spectra3[spaxel_x, spaxel_y, :])
print(sum1, sum2, sum3)

# First row: images (only the first panel carries the marker legend).
ax1 = fig.add_subplot(gs[0, 0])
im1 = _draw_image_panel(ax1, image1, highlighted_spaxels, "Bruzual&Charlot 2003", show_legend=True)

ax2 = fig.add_subplot(gs[0, 1])
im2 = _draw_image_panel(ax2, image2, highlighted_spaxels, "FSPS")

ax3 = fig.add_subplot(gs[0, 2])
im3 = _draw_image_panel(ax3, image3, highlighted_spaxels, "MaStar")

# Remaining rows: spectra of the highlighted spaxels, one template per row.
ax4 = fig.add_subplot(gs[1, :])
_draw_spectrum_panel(ax4, wave, spectra1, highlighted_spaxels, "Spectrum of Spaxels from Bruzual")

ax5 = fig.add_subplot(gs[2, :])
_draw_spectrum_panel(ax5, wave, spectra2, highlighted_spaxels, "Spectrum of Spaxels from FSPS")

ax6 = fig.add_subplot(gs[3, :])
_draw_spectrum_panel(ax6, wave, spectra3, highlighted_spaxels, "Spectrum of Spaxels from MaStar")

# Adjust layout, save and show.
plt.tight_layout()
# savefig does not create missing directories, so make sure the target exists.
os.makedirs("output", exist_ok=True)
# NOTE(review): the file name says 100000 while the configuration cell sets
# subset_size to 200000 -- confirm whether the name should be updated.
plt.savefig("output/ssp_compare_spectra_100000.png")
plt.show()