In [None]:
#NBVAL_SKIP
import matplotlib.pyplot as plt
from rubix.core.pipeline import RubixPipeline 
import os
import warnings

# Galaxy / data-location settings kept in one place so the download directory,
# the galaxy id and the converted-file path below cannot drift apart.
GALAXY_ID = 11
DATA_DIR = "data"

config = {
    "pipeline": {"name": "calc_ifu"},

    "logger": {
        "log_level": "DEBUG",
        "log_file_path": None,
        "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    },
    "data": {
        "name": "IllustrisAPI",
        "args": {
            # Read from the environment; never hard-code the key in the notebook.
            "api_key": os.environ.get("ILLUSTRIS_API_KEY"),
            "particle_type": ["stars", "gas"],
            "cube_type": ["gas"],
            "simulation": "TNG50-1",
            "snapshot": 99,
            "save_data_path": DATA_DIR,
        },
        "load_galaxy_args": {
            "id": GALAXY_ID,
            "reuse": True,
        },
        "subset": {
            "use_subset": False,
            "subset_size": 1000,
        },
    },
    "simulation": {
        "name": "IllustrisTNG",
        "args": {
            # NOTE(review): assumes the downloaded file is named
            # "galaxy-id-<id>.hdf5" inside `save_data_path` — confirm.
            "path": f"{DATA_DIR}/galaxy-id-{GALAXY_ID}.hdf5",
        },
    },
    "output_path": "output",

    "telescope": {
        "name": "MUSE",
        "psf": {"name": "gaussian", "size": 5, "sigma": 0.6},
        "lsf": {"sigma": 0.5},
        "noise": {"signal_to_noise": 0.001, "noise_distribution": "normal"},
    },
    "cosmology": {"name": "PLANCK15"},

    "galaxy": {
        "dist_z": 0.1,
        "rotation": {"type": "edge-on"},
    },

    "ssp": {
        "template": {"name": "BruzualCharlot2003"},
    },
    "cloudy": {
        "templates": {"name": "UVB + CMB"},
    },
}

# Surface a missing key early instead of failing deep inside the download step.
if config["data"]["args"]["api_key"] is None:
    warnings.warn(
        "ILLUSTRIS_API_KEY is not set; requests to the Illustris API will not be authenticated."
    )

# The whole pipeline could also be run in one call:
#     pipe = RubixPipeline(config)
#     data = pipe.run()
# The cells below instead apply each stage by hand so intermediate results can be inspected.

In [None]:
# NBVAL_SKIP
from rubix.core.data import convert_to_rubix, prepare_input

# Step 1: convert the simulation data to rubix format; the result is stored in
# the folder given by config["output_path"].
convert_to_rubix(config)

# Step 2: load the converted data as the input object for the pipeline stages.
rubixdata = prepare_input(config)

In [None]:
#NBVAL_SKIP
from rubix.core.rotation import get_galaxy_rotation

# Build the rotation stage from the config (presumably driven by
# config["galaxy"]["rotation"], "edge-on" here) and apply it.
rotate = get_galaxy_rotation(config)
rubixdata = rotate(rubixdata)

In [None]:
#NBVAL_SKIP
from rubix.core.telescope import get_filter_particles

# Telescope particle filter; the selection criteria live in rubix and are
# configured from `config`. Applied to the rotated data.
filter_particles = get_filter_particles(config)
rubixdata = filter_particles(rubixdata)

In [None]:
# NBVAL_SKIP
from rubix.core.telescope import get_spaxel_assignment

# Spaxel-assignment stage (presumably maps each particle to a telescope pixel).
bin_particles = get_spaxel_assignment(config)
rubixdata = bin_particles(rubixdata)

# For debugging, inspect rubixdata.gas.pixel_assignment and
# rubixdata.gas.spatial_bin_edges.

In [None]:
#NBVAL_SKIP
# NOTE(review): disabled exploratory cell. It computed the gas continuum with
# CueGasLookup directly; the active flow in this notebook uses
# rubix.core.cue.get_gas_emission instead. Uncomment to inspect the continuum.
#from rubix.spectra.cue.grid import CueGasLookup

#CueClass = CueGasLookup(config)

#rubixdata = CueClass.get_continuum(rubixdata)

In [None]:
#NBVAL_SKIP
# NOTE(review): disabled plot of the continuum computed by the (also disabled)
# CueGasLookup cell; only meaningful once rubixdata.gas.continuum exists.
#import matplotlib.pyplot as plt

#plt.figure()
#for i in range(10): #range(len(rubixdata.gas.mass)):
#    plt.plot(rubixdata.gas.continuum[0], rubixdata.gas.continuum[1][i])
# NOTE(review): the next line indexes continuum[0][0] / continuum[0][1], unlike
# the loop above (continuum[0] as x, continuum[1][i] as y) — confirm intent.
#plt.plot(rubixdata.gas.continuum[0][0], rubixdata.gas.continuum[0][1])
#plt.show()

In [None]:
#NBVAL_SKIP
# NOTE(review): disabled exploratory cell computing gas emission fluxes with
# CueGasLookup directly and plotting one spectrum. The active flow uses
# rubix.core.cue.get_gas_emission instead.
#from rubix.spectra.cue.grid import CueGasLookup

#CueClass = CueGasLookup(config)

#rubixdata = CueClass.get_gas_emission_flux(rubixdata)

#import matplotlib.pyplot as plt

#plt.plot(rubixdata.gas.wavelengthrange, rubixdata.gas.spectra[2])
#plt.ylim(0, 1e-24)
#plt.show()

In [None]:
#NBVAL_SKIP
# Inspect the Cue preprocessing of the config. This needs `config` from the first
# cell, which nbval skips, so this cell must be skipped as well.
from rubix.core.cue import preprocess_config

configdata = preprocess_config(config)
print(configdata)
print(configdata["factor"])

In [None]:
#NBVAL_SKIP
from rubix.core.cue import get_gas_emission

# Keep the factory and the stage it returns under different names. Rebinding
# `get_gas_emission` to its own result would make a re-run of this cell call
# the stage with `config` instead of building a fresh stage.
gas_emission = get_gas_emission(config)

rubixdata = gas_emission(rubixdata)

In [None]:
#NBVAL_SKIP
# Emission spectrum in row 2 of rubixdata.gas.spectra (presumably one gas
# particle), plotted over the emission wavelength grid. Needs `rubixdata` from
# skipped cells, hence NBVAL_SKIP.
fig, ax = plt.subplots()
ax.plot(rubixdata.gas.wavelengthrange, rubixdata.gas.spectra[2])
ax.set_ylim(0, 1e-24)
ax.set_xlabel("wavelength")
ax.set_ylabel("gas emission (spectra[2])")
ax.set_title("Gas emission spectrum")
plt.show()

In [None]:
#NBVAL_SKIP
from rubix.core.ifu import get_scale_spectrum_by_mass

# Mass-scaling stage: as its name says, scales the gas spectra by particle mass.
scale_spectrum_by_mass = get_scale_spectrum_by_mass(config)
rubixdata = scale_spectrum_by_mass(rubixdata)

In [None]:
#NBVAL_SKIP
import matplotlib.pyplot as plt

# After mass scaling: report the shape of one row and plot row 4.
# (Tighten the y-range with plt.ylim(0, 1e-16) if the line is hard to see.)
gas_spectra = rubixdata.gas.spectra
print(gas_spectra[0].shape)
plt.plot(rubixdata.gas.wavelengthrange, gas_spectra[4])
plt.show()

In [None]:
#NBVAL_SKIP
from rubix.core.ifu import get_doppler_shift_and_resampling

# Doppler-shift and resample the gas spectra. The shapes printed before and
# after show how resampling changes the spectral axis.
print(f"Initial rubixdata.gas.spectra shape: {rubixdata.gas.spectra.shape}")

doppler_shift_and_resampling = get_doppler_shift_and_resampling(config)
rubixdata = doppler_shift_and_resampling(rubixdata)

print(f"Processed rubixdata.gas.spectra shape: {rubixdata.gas.spectra.shape}")

In [None]:
#NBVAL_SKIP
from rubix.core.pipeline import RubixPipeline

# A pipeline object is built here only to get the telescope wavelength grid.
pipe = RubixPipeline(config)
wave = pipe.telescope.wave_seq

print(wave)
print(rubixdata.gas.spectra[0])

# Overlay the first three entries of rubixdata.gas.spectra[0] on that grid.
plt.figure()
for idx in range(3):
    plt.plot(wave, rubixdata.gas.spectra[0][idx])
plt.show()

In [None]:
#NBVAL_SKIP
import jax.numpy as jnp

# Sanity checks on the resampled gas spectra: shape, NaN, inf, negatives.
spectra_checked = rubixdata.gas.spectra
print(f"Shape of the array: {spectra_checked.shape}")

has_nan = jnp.isnan(spectra_checked).any()
print(f"Does the array contain NaN values? {has_nan}")

has_inf = jnp.isinf(spectra_checked).any()
print(f"Does the array contain inf values? {has_inf}")

has_negative_values = (spectra_checked < 0).any()
print(has_negative_values)

In [None]:
#NBVAL_SKIP
from rubix.core.ifu import get_calculate_datacube

calculate_datacube = get_calculate_datacube(config)
rubixdata = calculate_datacube(rubixdata)

# Collapse the cube over axis 2 (the spectral axis; spectra are read later as
# datacube[x, y, :]) to get a 2D image.
datacube = rubixdata.gas.datacube
img = datacube.sum(axis=2)
plt.imshow(img, origin="lower")

In [None]:
#NBVAL_SKIP
from rubix.core.psf import get_convolve_psf

# PSF convolution stage; presumably configured by config["telescope"]["psf"].
convolve_psf = get_convolve_psf(config)
rubixdata = convolve_psf(rubixdata)

In [None]:
#NBVAL_SKIP
from rubix.core.lsf import get_convolve_lsf

# LSF convolution stage; presumably configured by config["telescope"]["lsf"].
convolve_lsf = get_convolve_lsf(config)
rubixdata = convolve_lsf(rubixdata)

# Compare spaxel (12, 12) with the corner spaxel (0, 0) after the LSF.
cube_after_lsf = rubixdata.gas.datacube
plt.plot(wave, cube_after_lsf[12, 12, :])
plt.plot(wave, cube_after_lsf[0, 0, :])

In [None]:
#NBVAL_SKIP
from rubix.core.noise import get_apply_noise

# Noise stage; presumably configured by config["telescope"]["noise"].
apply_noise = get_apply_noise(config)
rubixdata = apply_noise(rubixdata)

# Show the noisy cube collapsed over the spectral axis, then echo its shape.
datacube = rubixdata.gas.datacube
img = datacube.sum(axis=2)
plt.imshow(img, origin="lower")
datacube.shape

In [None]:
# NOTE(review): disabled cell — image of spectral channels 1200:1600 only.
# If re-enabled, add an NBVAL_SKIP marker: it needs `datacube`, which is
# produced by skipped cells.
#subcube = datacube[:, :, 1200:1600]
#print(subcube.shape)

#img = subcube.sum(axis=2)
#plt.imshow(img, origin="lower")

In [None]:
#NBVAL_SKIP
# Final (noisy) spectrum of spaxel (12, 12); use datacube[0, 0, :] for the corner.
spaxel_x, spaxel_y = 12, 12
plt.plot(wave, datacube[spaxel_x, spaxel_y, :])

In [None]:
#NBVAL_SKIP
import numpy as np
import jax.numpy as jnp

# Diagnostics of the gas properties that feed the emission model.
# NOTE(review): despite its name this is log10(electron_abundance / density),
# not an ionising photon rate — confirm what is intended.
log_QH = jnp.log10(rubixdata.gas.electron_abundance / rubixdata.gas.density)
print(jnp.nanmin(log_QH), jnp.nanmax(log_QH))

# Report the density range on one (log10) scale; min and max used to be on
# different scales (linear min vs log10 max).
log_density = jnp.log10(rubixdata.gas.density)
print(jnp.nanmin(log_density), jnp.nanmax(log_density))

# Element mass-fraction ratios. Assumes rubixdata.gas.metals follows the
# IllustrisTNG GFM_Metals column order (0=H, 1=He, 2=C, 3=N, 4=O, ...) and holds
# mass fractions — TODO confirm against the rubix data loader.
OH_ratio = rubixdata.gas.metals[:, 4] / rubixdata.gas.metals[:, 0]
NO_ratio = rubixdata.gas.metals[:, 3] / rubixdata.gas.metals[:, 4]
CO_ratio = rubixdata.gas.metals[:, 2] / rubixdata.gas.metals[:, 4]

# log10 solar *number* ratios (the values match Anders & Grevesse 1989).
log_oh_sol = -3.07
log_co_sol = -0.37
log_no_sol = -0.88

# Mass ratio X_a/X_b -> number ratio n_a/n_b = (X_a/X_b) * (A_b/A_a), with
# atomic masses A_H = 1, A_C = 12, A_N = 14, A_O = 16. (The factors were inverted.)
oh_factor = 1 / 16
co_factor = 16 / 12
no_factor = 16 / 14

# O/H relative to solar in dex: log10(O/H) - log10(O/H)_sun (a difference of
# logs, not a quotient). C/O and N/O relative to solar as linear ratios.
final_log_oh = jnp.log10(OH_ratio * oh_factor) - log_oh_sol
final_co = CO_ratio * co_factor / 10**log_co_sol
final_no = NO_ratio * no_factor / 10**log_no_sol

print(jnp.nanmin(final_log_oh), jnp.nanmax(final_log_oh))
print(OH_ratio.shape, type(OH_ratio))
print(jnp.nanmin(final_co), jnp.nanmax(final_co))
print(jnp.min(final_co), jnp.max(final_co))
print(jnp.nanmin(final_no), jnp.nanmax(final_no))
print(jnp.min(final_no), jnp.max(final_no))

In [None]:
#NBVAL_SKIP
import jax.numpy as jnp

# Count NaN and +/-inf entries in each derived quantity (final_* come from the
# abundance cell). Both kinds break the histograms: log10 of a zero abundance
# gives -inf, and 0/0 ratios give NaN.
n_H = rubixdata.gas.density
quantities_to_check = {
    "final_no": final_no,
    "final_co": final_co,
    "final_log_oh": final_log_oh,
    "n_H": n_H,
}
for label, values in quantities_to_check.items():
    nan_count = int(jnp.sum(jnp.isnan(values)))
    inf_count = int(jnp.sum(jnp.isinf(values)))
    print(f"{label}: {nan_count} NaN, {inf_count} inf out of {len(values)}")

In [None]:
#NBVAL_SKIP
import numpy as np
import jax.numpy as jnp
import matplotlib.pyplot as plt

# Compare gas properties with the parameter ranges marked by the red lines.
# NOTE(review): these bounds are presumably the ranges covered by the Cue
# emission grid used in rubix.core.cue — confirm against that module.
boundary_log_u = [-4, -1]  # not plotted: log_QH below is only a constant placeholder
boundary_log_nh = [1, 4]
boundary_log_OH = [-2.2, 0.5]
boundary_CO = [0.1, 5.4]
boundary_NO = [0.1, 5.4]

print(10**boundary_log_nh[0], 10**boundary_log_nh[1])

# NOTE(review): assumes rubixdata.gas.density is a hydrogen number density whose
# log10 is directly comparable with boundary_log_nh — confirm units.
n_H = rubixdata.gas.density

# Element mass-fraction ratios. Assumes the IllustrisTNG GFM_Metals column order
# (0=H, 1=He, 2=C, 3=N, 4=O, ...) — TODO confirm against the rubix data loader.
OH_ratio = rubixdata.gas.metals[:, 4] / rubixdata.gas.metals[:, 0]
NO_ratio = rubixdata.gas.metals[:, 3] / rubixdata.gas.metals[:, 4]
CO_ratio = rubixdata.gas.metals[:, 2] / rubixdata.gas.metals[:, 4]

# log10 solar *number* ratios (the values match Anders & Grevesse 1989).
log_oh_sol = -3.07
log_co_sol = -0.37
log_no_sol = -0.88

# Mass ratio X_a/X_b -> number ratio n_a/n_b = (X_a/X_b) * (A_b/A_a), with
# A_H = 1, A_C = 12, A_N = 14, A_O = 16.
oh_factor = 1 / 16
co_factor = 16 / 12
no_factor = 16 / 14

# O/H relative to solar in dex (a difference of logs), matching boundary_log_OH.
# C/O and N/O relative to solar as linear ratios, matching the linear
# boundary_CO / boundary_NO.
final_log_oh = jnp.log10(OH_ratio * oh_factor) - log_oh_sol
final_co = CO_ratio * co_factor / 10**log_co_sol
final_no = NO_ratio * no_factor / 10**log_no_sol
log_QH = jnp.full(len(rubixdata.gas.mass), 49.58)

print(f"The galaxy is loaded with {len(rubixdata.gas.mass)} gas particles")


def _hist_with_bounds(ax, values, bounds, label, color):
    """Histogram the finite entries of ``values`` on ``ax`` and mark ``bounds``.

    Non-finite entries (e.g. -inf from log10 of a zero abundance) are dropped,
    because they would make matplotlib's automatic bin range non-finite.
    The values are converted to NumPy first (jnp.asarray would leave them as JAX arrays).
    """
    values = np.asarray(values)
    finite_values = values[np.isfinite(values)]
    ax.hist(finite_values, bins=50, color=color, alpha=0.7)
    ax.set_title(f"Histogram of {label}")
    ax.set_xlabel(label)
    ax.set_ylabel("Frequency")
    for bound in bounds:
        ax.axvline(x=bound, color="red")
    return ax


fig, axs = plt.subplots(2, 2, figsize=(20, 10))
# n_H is histogrammed in log10 so the bins are evenly spaced on the axis that
# the log boundaries are defined on.
_hist_with_bounds(axs[0, 0], jnp.log10(n_H), boundary_log_nh, "log10(n_H)", "blue")
_hist_with_bounds(axs[0, 1], final_log_oh, boundary_log_OH, "final_log_oh", "green")
_hist_with_bounds(axs[1, 0], final_co, boundary_CO, "final_co", "red")
_hist_with_bounds(axs[1, 1], final_no, boundary_NO, "final_no", "purple")

plt.tight_layout()
plt.show()

In [None]:
#NBVAL_SKIP
# Summarise the gas particle masses (shape and first entries) instead of
# echoing the whole array.
rubixdata.gas.mass.shape, rubixdata.gas.mass[:5]
