In [None]:
#NBVAL_SKIP
import matplotlib.pyplot as plt
from rubix.core.pipeline import RubixPipeline 
import os
# Single configuration dict for this notebook; the later cells hand it to the
# rubix factory functions (get_galaxy_rotation, get_spaxel_assignment, ...).
config = {
    "pipeline": {"name": "calc_ifu"},
    "logger": {
        "log_level": "DEBUG",
        "log_file_path": None,
        "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    },
    "data": {
        "name": "IllustrisAPI",
        "args": {
            # Read from the environment so the key is never stored in the notebook;
            # os.environ.get yields None when the variable is unset.
            "api_key": os.environ.get("ILLUSTRIS_API_KEY"),
            "particle_type": ["stars", "gas"],
            "cube_type": ["gas"],
            "simulation": "TNG50-1",
            "snapshot": 99,
            "save_data_path": "data",
        },
        "load_galaxy_args": {"id": 11, "reuse": True},
        "subset": {"use_subset": True, "subset_size": 100},
    },
    "simulation": {
        "name": "IllustrisTNG",
        # NOTE(review): the file name carries the same id (11) as
        # "load_galaxy_args" -- keep the two in sync when changing the galaxy.
        "args": {"path": "data/galaxy-id-11.hdf5"},
    },
    "output_path": "output",
    "telescope": {
        "name": "TESTGAS",
        "psf": {"name": "gaussian", "size": 5, "sigma": 0.6},
        "lsf": {"sigma": 0.5},
        "noise": {"signal_to_noise": 1, "noise_distribution": "normal"},
    },
    "cosmology": {"name": "PLANCK15"},
    "galaxy": {
        "dist_z": 0.1,
        "rotation": {"type": "edge-on"},
    },
    "ssp": {"template": {"name": "BruzualCharlot2003"}},
    "cloudy": {"templates": {"name": "UVB + CMB"}},
}


#pipe = RubixPipeline(config)
#data= pipe.run()

In [None]:
# NBVAL_SKIP
from rubix.core.data import convert_to_rubix, prepare_input

# Convert the configured input data to the rubix format and store it in the
# output_path folder.
convert_to_rubix(config)
# Prepare the input for the pipeline; every following cell works on (and
# reassigns) this `rubixdata` object.
rubixdata = prepare_input(config)

In [None]:
# Display the `metals` field of the gas component as the cell output.
rubixdata.gas.metals  # append [:, 0] to inspect only index 0 of the second axis

In [None]:
from rubix.core.rotation import get_galaxy_rotation

# get_galaxy_rotation returns a callable built from `config`
# (presumably driven by config["galaxy"]["rotation"] -- confirm in rubix.core.rotation).
rotate = get_galaxy_rotation(config)
# Apply the rotation; the returned object replaces `rubixdata` for later cells.
rubixdata = rotate(rubixdata)

In [None]:

from rubix.core.telescope import get_filter_particles
# Build the particle filter step from the configuration.
filter_particles = get_filter_particles(config)

# Apply it and keep the result as the new `rubixdata`.
# NOTE(review): presumably removes particles outside the telescope aperture -- confirm.
rubixdata = filter_particles(rubixdata)

In [None]:
# NBVAL_SKIP
from rubix.core.telescope import get_spaxel_assignment
# Build the spaxel-assignment step from the configuration and run it.
bin_particles = get_spaxel_assignment(config)
rubixdata = bin_particles(rubixdata)

# Print the gas pixel assignment first, then the spatial bin edges.
for gas_field in ("pixel_assignment", "spatial_bin_edges"):
    print(getattr(rubixdata.gas, gas_field))

In [None]:
#from rubix.core.data import get_reshape_data
#reshape_data = get_reshape_data(config)

#rubixdata = reshape_data(rubixdata)

In [None]:
from rubix.spectra.cue.cue.grid import CueGasLookup

# Build the Cue gas-emission lookup from the pipeline configuration.
CueClass = CueGasLookup(config)

# Compute the gas emission spectra with the lookup created above. The call must
# go through `CueClass`: no `CloudyClass` object is created in this notebook, so
# referencing that name raises NameError on a fresh kernel.
gas_emission = CueClass.get_spectra(rubixdata)
# get_spectra hands back the updated data object (the next cells read
# rubixdata.gas.spectra from it), so it becomes the new `rubixdata`.
rubixdata = gas_emission

In [None]:
import matplotlib.pyplot as plt

# Plot the first gas spectrum against the gas wavelength range, with the
# y-axis clipped to [0, 1e-22].
fig, ax = plt.subplots()
ax.plot(rubixdata.gas.wavelengthrange, rubixdata.gas.spectra[0])
ax.set_ylim(0, 1e-22)
plt.show()

In [None]:
import jax.numpy as jnp

# Locate NaN entries in the gas spectra and report whether any are present.
nan_mask = jnp.isnan(rubixdata.gas.spectra)
print(nan_mask.any())

# Replace NaN values with 0; all other entries are carried over unchanged.
cleaned_array = jnp.where(nan_mask, 0, rubixdata.gas.spectra)

# Confirm that no NaN survived the replacement.
print(jnp.isnan(cleaned_array).any())

# Store the cleaned spectra back on the gas component for the following cells.
rubixdata.gas.spectra = cleaned_array

In [None]:
from rubix.core.ifu import get_scale_spectrum_by_mass

# Build the mass-scaling step from the configuration.
scale_spectrum_by_mass = get_scale_spectrum_by_mass(config)

# Apply it; the returned object replaces `rubixdata`.
rubixdata = scale_spectrum_by_mass(rubixdata)
# Print the full gas spectra array after scaling (can be a large output).
print(rubixdata.gas.spectra)

In [None]:
import matplotlib.pyplot as plt

# Report the shape of the first gas spectrum, then plot it against the gas
# wavelength range with the y-axis clipped to [0, 1e-16].
print(rubixdata.gas.spectra[0].shape)

fig, ax = plt.subplots()
ax.plot(rubixdata.gas.wavelengthrange, rubixdata.gas.spectra[0])
ax.set_ylim(0, 1e-16)
plt.show()

In [None]:
# Show the velocity array shapes of the stars and gas components together.
# A notebook cell only displays its last expression, so two separate lines
# would silently discard the stars shape; a single tuple (stars, gas) shows both.
rubixdata.stars.velocity.shape, rubixdata.gas.velocity.shape

In [None]:
from rubix.core.ifu import get_doppler_shift_and_resampling

# Debug aid (disabled): indices of non-zero entries before the shift.
#print(jnp.nonzero(rubixdata.gas.spectra))
# Shape of the gas spectra before Doppler shifting / resampling.
print(f"Initial rubixdata.gas.spectra shape: {rubixdata.gas.spectra.shape}")

# Build the Doppler-shift-and-resampling step from the configuration.
doppler_shift_and_resampling = get_doppler_shift_and_resampling(config)

# Apply it; the returned object replaces `rubixdata`.
rubixdata = doppler_shift_and_resampling(rubixdata)
# Debug aid (disabled): stellar spectra after the step.
#print(rubixdata.stars.spectra)

# Debug aid (disabled): indices of non-zero entries after the shift.
#print(jnp.nonzero(rubixdata.gas.spectra))
# Shape after the step, for comparison with the initial shape printed above.
print(f"Processed rubixdata.gas.spectra shape: {rubixdata.gas.spectra.shape}")

In [None]:
from rubix.core.pipeline import RubixPipeline 

# Instantiate the pipeline only to read the telescope wavelength grid from it;
# `wave` is reused by the plotting cells further down.
pipe = RubixPipeline(config)
wave = pipe.telescope.wave_seq

print(wave)
print(rubixdata.gas.spectra[0][:])

# Overlay the first three gas spectra on the telescope wavelength grid.
for spectrum_index in range(3):
    plt.plot(wave, rubixdata.gas.spectra[spectrum_index][:])
plt.ylim(0, 1e-25)

In [None]:
from rubix.core.ifu import get_calculate_datacube
# Build the datacube-assembly step from the configuration.
calculate_datacube = get_calculate_datacube(config)

# Apply it; the returned object replaces `rubixdata` and now carries gas.datacube.
rubixdata = calculate_datacube(rubixdata)
#print(rubixdata.gas.datacube)

# Sum over axis 2 (the axis plotted against `wave` in later cells) to get a
# 2D image, and show it with the origin in the lower-left corner.
datacube = rubixdata.gas.datacube
img = datacube.sum(axis=2)
plt.imshow(img, origin="lower")

In [None]:
from rubix.core.psf import get_convolve_psf
# Build the PSF convolution step from the configuration
# (presumably uses config["telescope"]["psf"] -- confirm in rubix.core.psf).
convolve_psf = get_convolve_psf(config)

# Apply it; the returned object replaces `rubixdata`.
rubixdata = convolve_psf(rubixdata)

In [None]:
from rubix.core.lsf import get_convolve_lsf
# Build the LSF convolution step from the configuration
# (presumably uses config["telescope"]["lsf"] -- confirm in rubix.core.lsf).
convolve_lsf = get_convolve_lsf(config)

# Apply it; the returned object replaces `rubixdata`.
rubixdata = convolve_lsf(rubixdata)

# Compare the spectra of two spaxels of the gas datacube: index (12, 12) and
# the corner (0, 0). `wave` comes from the earlier RubixPipeline cell.
plt.plot(wave, rubixdata.gas.datacube[12,12,:])
plt.plot(wave, rubixdata.gas.datacube[0,0,:])
#print(rubixdata.gas.datacube)

In [None]:
from rubix.core.noise import get_apply_noise

# Build the noise step from the configuration
# (presumably uses config["telescope"]["noise"] -- confirm in rubix.core.noise).
apply_noise = get_apply_noise(config)

# Apply it; the returned object replaces `rubixdata`.
rubixdata = apply_noise(rubixdata)

# Sum the noisy gas datacube over axis 2 and show the image on a log10 scale.
# NOTE(review): log10 yields NaN / -inf wherever the summed value is <= 0,
# which may happen after adding noise -- such pixels will not render.
# `jnp` is the jax.numpy alias imported in the NaN-cleaning cell above.
datacube = rubixdata.gas.datacube
img = datacube.sum(axis=2)
plt.imshow(jnp.log10(img), origin="lower")
# Last expression of the cell: display the datacube shape.
datacube.shape

In [None]:
# Restrict the datacube to indices 1200..1599 along axis 2.
# NOTE(review): the bounds are hard-coded indices into the wavelength axis;
# they only make sense for the current telescope grid -- confirm when changing it.
subcube = datacube[:, :, 1200:1600]
print(subcube.shape)

# Sum the slice over axis 2 and show the image on a log10 scale
# (non-positive sums give NaN / -inf and will not render).
img = subcube.sum(axis=2)
plt.imshow(jnp.log10(img), origin="lower")

In [None]:
# Plot the spectra of spaxels (12, 12) and (0, 0) from `datacube`, which was
# last assigned in the noise cell, against the telescope wavelength grid `wave`.
plt.plot(wave, datacube[12,12,:])
plt.plot(wave, datacube[0,0,:])

In [None]:
# Re-read the gas datacube from `rubixdata`, sum it over axis 2 and show the
# resulting image on a linear scale with the origin in the lower-left corner.
datacube = rubixdata.gas.datacube
img = datacube.sum(axis=2)
plt.imshow(img, origin="lower")