In [None]:
#NBVAL_SKIP
# Configure and run the RUBIX "calc_ifu" pipeline on IllustrisTNG (TNG50-1,
# snapshot 99) galaxy 11. The resulting `data` object is inspected by the
# cells below, so they can only run after this cell.
import matplotlib.pyplot as plt
from rubix.core.pipeline import RubixPipeline 
import os
config = {
    "pipeline":{"name": "calc_ifu"},
    
    "logger": {
        "log_level": "DEBUG",
        "log_file_path": None,
        "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    },
    "data": {
        "name": "IllustrisAPI",
        "args": {
            # Read from the environment; evaluates to None when ILLUSTRIS_API_KEY
            # is unset. NOTE(review): presumably only required when the galaxy
            # is not already cached locally (see "reuse" below) — confirm.
            "api_key": os.environ.get("ILLUSTRIS_API_KEY"),
            "particle_type": ["stars", "gas"],
            "cube_type":["stars"],
            "simulation": "TNG50-1",
            "snapshot": 99,
            "save_data_path": "data",
        },
        
        "load_galaxy_args": {
        "id": 11,
        "reuse": True,
        },
        
        # NOTE(review): presumably restricts the run to 1000 particles so the
        # example is fast — confirm against the RUBIX data loader.
        "subset": {
            "use_subset": True,
            "subset_size": 1000,
        },
    },
    "simulation": {
        "name": "IllustrisTNG",
        "args": {
            # Appears to correspond to save_data_path ("data") and galaxy id 11
            # above; keep these in sync if either changes.
            "path": "data/galaxy-id-11.hdf5",
        },
    
    },
    "output_path": "output",

    # Instrument model: MUSE with a Gaussian PSF, a line-spread function and
    # normally distributed noise at signal-to-noise 1.
    "telescope":
        {"name": "MUSE",
         "psf": {"name": "gaussian", "size": 5, "sigma": 0.6},
         "lsf": {"sigma": 0.5},
         "noise": {"signal_to_noise": 1,"noise_distribution": "normal"},},
    "cosmology":
        {"name": "PLANCK15"},
        
    # NOTE(review): dist_z is presumably the redshift at which the galaxy is
    # placed — confirm.
    "galaxy":
        {"dist_z": 0.1,
         "rotation": {"type": "edge-on"},
        },
        
    # Simple stellar population templates (Bruzual & Charlot 2003).
    "ssp": {
        "template": {
            "name": "BruzualCharlot2003"
        },
    },        
}

pipe = RubixPipeline(config)

data= pipe.run()

# Optional quick-look image (disabled): collapse the datacube along axis 2.
#datacube = data.stars.datacube
#img = datacube.sum(axis=2)
#plt.imshow(img, origin="lower")

In [None]:
#NBVAL_SKIP
# Skipped under nbval: `data` is only created by the pipeline cell, which is
# itself skipped, so running this cell there would raise NameError.
import jax.numpy as jnp

# Inspect the raw pipeline outputs for the stellar component.
print(data.stars.pixel_assignment)
print(data.stars.spatial_bin_edges)
print(data.stars.spectra.shape)

# Sum the spectra over axis 1 (presumably the particle axis, leaving one
# spectrum per leading-axis slice — TODO confirm the spectra layout).
sum_spectra = jnp.sum(data.stars.spectra, axis=1)
print(sum_spectra.shape)
print(sum_spectra)

In [None]:
#NBVAL_SKIP
# Skipped under nbval: this cell reads `data`, which only exists after the
# (also skipped) pipeline cell has run.
def print_object_attributes(obj, show_values=False):
    """Print the attribute names of ``obj``, one per line.

    Names starting with a double underscore (dunder methods and name-mangled
    attributes) are skipped. Names appear in the sorted order returned by
    :func:`dir`.

    Args:
        obj: Any object to inspect.
        show_values: If True, print ``name: value`` instead of only the name.
            Defaults to False, which keeps the original name-only output.
    """
    for attr in dir(obj):
        if attr.startswith('__'):
            continue
        if show_values:
            # getattr's default only covers AttributeError; other exceptions
            # raised by properties still propagate.
            print(f"{attr}: {getattr(obj, attr, '<unavailable>')}")
        else:
            print(f"{attr}")

# Inspect the stars and galaxy components of the pipeline output `data`.
stars = data.stars
print_object_attributes(stars)
print(stars.mask)
print_object_attributes(data.galaxy)


In [None]:
#NBVAL_SKIP
# Overlay two individual stellar spectra on the telescope wavelength grid.
wave = pipe.telescope.wave_seq

fig, ax = plt.subplots()
for idx in (0, 1):
    ax.plot(wave, data.stars.spectra[0][idx], label=f"spectra[0][{idx}]")
ax.set(xlabel="Wavelength [Å]", ylabel="Flux", title="Example stellar spectra")
ax.legend();


Some of the spectra may be zero. This happens when the metallicity or age values lie outside the range covered by the SSP model, and is currently the expected behavior.

In [None]:
#NBVAL_SKIP
import jax 
import jax.numpy as jnp
# Create a function to calculate a single IFU cube
def calculate_ifu_cube(stars_spectra, pixel_indices, num_pixels_per_side=25):
    """Bin per-particle spectra into a square IFU cube.

    Rows of ``stars_spectra`` that share a value in ``pixel_indices`` are
    summed into the same spatial pixel. Pixel indices are flat (row-major)
    indices into a ``num_pixels_per_side`` x ``num_pixels_per_side`` grid.

    Args:
        stars_spectra: 2D array (n_particles, n_wavelengths), one spectrum
            per row.
        pixel_indices: Integer array (n_particles,) holding the flat pixel
            index of each row.
        num_pixels_per_side: Side length of the square spatial grid. Defaults
            to 25, the value previously hard-coded here.
            NOTE(review): should match the telescope's spatial binning.

    Returns:
        Array of shape (num_pixels_per_side, num_pixels_per_side,
        n_wavelengths).
    """
    num_pixels = num_pixels_per_side * num_pixels_per_side

    # segment_sum drops indices outside [0, num_pixels), so particles that
    # fall outside the grid do not contribute to the cube.
    ifu_cube = jax.ops.segment_sum(stars_spectra, pixel_indices, num_segments=num_pixels)

    # Take the wavelength count from the input instead of hard-coding it, so
    # the reshape works for any wavelength grid.
    num_wavelengths = stars_spectra.shape[-1]
    return ifu_cube.reshape((num_pixels_per_side, num_pixels_per_side, num_wavelengths))

spectra, assignments = data.stars.spectra, data.stars.pixel_assignment

# Build one IFU cube per slice along the shared leading axis of both inputs.
cube_builder = jax.vmap(calculate_ifu_cube, in_axes=(0, 0))
ifu_cubes = cube_builder(spectra, assignments)

# Collapse the per-slice cubes into a single cube.
final_ifu_cube = ifu_cubes.sum(axis=0)
final_ifu_cube.shape

In [None]:
#NBVAL_SKIP
# Telescope wavelength grid, reused by the plotting cells below.
wavelengths = pipe.telescope.wave_seq

# Channels inside the visible band, 4000-8000 Angstroms (both bounds inclusive).
in_visible_band = (wavelengths >= 4000) & (wavelengths <= 8000)
visible_indices = jnp.where(in_visible_band)



In [None]:
#NBVAL_SKIP
# Display the telescope wavelength grid (pipe.telescope.wave_seq).
wavelengths

In [None]:
#NBVAL_SKIP
# Display one raw spectrum: index 0 on the leading axis, index 7 on the second.
spectra[0,7]

In [None]:
#NBVAL_SKIP
# Plot the spectrum of the central spatial pixel over the visible band.
import matplotlib.pyplot as plt

# Derive the central pixel from the cube shape rather than hard-coding
# (12, 12), which is only the centre of a 25x25 grid.
center_row = final_ifu_cube.shape[0] // 2
center_col = final_ifu_cube.shape[1] // 2
spec = final_ifu_cube[center_row, center_col]

fig, ax = plt.subplots()
ax.plot(wavelengths[visible_indices], spec[visible_indices])
# Zero-flux channels cannot be drawn on a log axis.
ax.set_yscale("log")
ax.set(xlabel="Wavelength [Å]", ylabel="Flux", title=f"Spectrum of pixel ({center_row}, {center_col})");

In [None]:
#NBVAL_SKIP
# Restrict every spatial pixel of the cube to the visible-band channels.
visible_channel_idx = visible_indices[0]
visible_spectra = final_ifu_cube[..., visible_channel_idx]
visible_spectra.shape

In [None]:
#NBVAL_SKIP
# Collapse the visible-band channels into a 2D image by summing over wavelength.
image = jnp.sum(visible_spectra, axis = 2)

fig, ax = plt.subplots()
im = ax.imshow(image, origin="lower", cmap="inferno")
ax.set(xlabel="x pixel", ylabel="y pixel", title="Visible-band (4000-8000 Å) image")
fig.colorbar(im, ax=ax, label="Summed flux");

In [None]:
import jax.numpy as jnp

class SpectrumCalculator:
    """Build a spectrum as the sum of emission-line profiles.

    Args:
        gaussian: Callable ``gaussian(wavelengthrange, peak, wavelength, dispersion)``
            that returns the profile of a single line evaluated on
            ``wavelengthrange``.
    """

    def __init__(self, gaussian):
        self.gaussian = gaussian

    def compute_spectrum(self, wavelengthrange, emission_peaks, wavelengths, dispersionfactor):
        """Sum one line profile per (peak, wavelength) pair.

        Line ``i`` contributes ``self.gaussian(wavelengthrange, emission_peaks[i],
        wavelengths[i], dispersionfactor * wavelengths[i])``, so each line's
        width scales linearly with its central wavelength.

        Args:
            wavelengthrange: Wavelength grid on which the spectrum is evaluated.
            emission_peaks: Sequence of line amplitudes.
            wavelengths: Sequence of line centres, same length as ``emission_peaks``.
            dispersionfactor: Ratio of line width to line centre.

        Returns:
            The summed spectrum, accumulated on top of
            ``jnp.zeros_like(wavelengthrange)``.

        Raises:
            ValueError: If ``emission_peaks`` and ``wavelengths`` differ in length.
        """
        if len(emission_peaks) != len(wavelengths):
            raise ValueError(
                f"emission_peaks ({len(emission_peaks)}) and wavelengths "
                f"({len(wavelengths)}) must have the same length"
            )

        spectrum = jnp.zeros_like(wavelengthrange)
        # Iterate over the actual inputs instead of a hard-coded range(128),
        # so every line is used exactly once whatever the input length.
        for peak, center in zip(emission_peaks, wavelengths):
            spectrum = spectrum + self.gaussian(wavelengthrange, peak, center, dispersionfactor * center)

        return spectrum

# Gaussian line profile used by the SpectrumCalculator example below.
def gaussian(wavelengthrange, peak, wavelength, dispersion):
    """Evaluate a Gaussian line of height ``peak`` centred at ``wavelength``.

    ``dispersion`` is the standard deviation of the profile, in the same
    units as ``wavelengthrange``.
    """
    offset = (wavelengthrange - wavelength) / dispersion
    return peak * jnp.exp(-0.5 * offset ** 2)

# jax.numpy has no `random` module; random draws come from jax.random, and
# `jax` must be imported here because the cell that imported it may be skipped.
import jax

# Fixed seed so the example is reproducible under Restart & Run All.
key_peaks, key_lines = jax.random.split(jax.random.PRNGKey(0))

wavelengthrange = jnp.linspace(1000, 10000, 1000)
emission_peaks = jax.random.uniform(key_peaks, (128,))
# Named line_wavelengths (not `wavelengths`) so this toy example does not
# overwrite the telescope wavelength grid used by the earlier cells.
line_wavelengths = jax.random.uniform(key_lines, (128,), minval=1000.0, maxval=10000.0)
dispersionfactor = 0.1

calculator = SpectrumCalculator(gaussian)
spectrum = calculator.compute_spectrum(wavelengthrange, emission_peaks, line_wavelengths, dispersionfactor)
print(spectrum)