In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np
from astropy.io import fits
from unyt import um

from synference import SBI_Fitter

In [None]:
# Load the pre-computed SBI grid (stored as HDF5) into a fitter object.
grid_name = "Spectra_BPASS_Chab_Continuity_SFH_0.01_z_14_logN_4.4_Calzetti_v4_multinode"  # noqa: E501
grid_path = "/cosma7/data/dp276/dc-harv3/work/sbi_grids/grid_spectra_BPASS_Chab_Continuity_SFH_0.01_z_14_logN_4.4_Calzetti_v4_multinode.hdf5"  # noqa: E501
fitter = SBI_Fitter.init_from_hdf5(grid_name, grid_path)

In [None]:
# NIRSpec PRISM dispersion table: a wavelength grid (attached unit: microns)
# and the matching resolution curve R, both used to build the features below.
tab = fits.getdata("/cosma/apps/dp276/dc-harv3/synference/priv/jwst_nirspec_prism_disp.fits")
R = tab["R"]
wavs = tab["WAVELENGTH"] * um

In [None]:
# Build the feature array used for training/inference. Fluxes are stored as
# log10(nJy); later cells recover nJy with 10 ** feature_array.
fitter.create_feature_array(
    # Target wavelength grid and instrument resolution curve, both taken from
    # the PRISM dispersion table loaded above.
    resample_wavelengths=wavs,
    inst_resolution_wavelengths=wavs,
    inst_resolution_r=R,
    # NOTE(review): np.inf presumably means the model spectra are treated as
    # having unlimited native resolution — confirm against synference docs.
    theory_r=np.inf,
    # Presumably in microns, matching the unit attached to `wavs`.
    crop_wavelength_range=(0.6, 5.0),
    flux_units="log10 nJy",
    # NOTE(review): assumed to be a floor in the chosen (log10) flux units.
    min_flux_value=-10,
)

In [None]:
import matplotlib.pyplot as plt

# Quick-look plot for one grid entry: the processed PRISM spectrum against the
# raw (intrinsic) grid spectrum, both in AB magnitudes computed from nJy
# fluxes as m_AB = -2.5 * log10(f / nJy) + 31.4.
idx = 17

grid_params = fitter.fitted_parameter_array[idx]
redshift = grid_params[0]
mass = grid_params[1]

ax = plt.gca()
ax.plot(
    wavs,
    -2.5 * np.log10((1 + redshift) * 10 ** fitter.feature_array[idx]) + 31.4,
    label="Observed [NIRSpec PRISM]",
)
ax.invert_yaxis()

print(f"Redshift: {redshift:.2f}, log10(M/Msun): {mass:.2f}")
ax.text(
    3.7,
    32.5,
    f"$z$={redshift:.2f}\n$log_{{10}}(M/M_\\odot)={mass:.2f}$",
    fontsize=12,
    color="black",
    bbox=dict(facecolor="lightgrey", alpha=0.8, edgecolor="none"),
)
# Optional spectral-feature markers (disabled):
# ax.axvline(0.1216 * (1 + redshift), color="k", linestyle="--", label="Lyman break")
# ax.axvline(0.0912 * (1 + redshift), color="k", linestyle=":", label="Lyman limit")
# ax.axvline(0.4861 * (1 + redshift), color="r", linestyle="--", label="Hβ")
# ax.axvline(0.6563 * (1 + redshift), color="r", linestyle=":", label="Hα")
# ax.axvline(0.5007 * (1 + redshift), color="g", linestyle="--", label="[OIII]")
# ax.axvline(0.3727 * (1 + redshift), color="g", linestyle=":", label="[OII]")

# Pin the current limits explicitly so the intrinsic spectrum plotted next
# does not trigger autoscaling.
ax = plt.gca()
ax.set_ylim(ax.get_ylim())
ax.set_xlim(ax.get_xlim())
orig = -2.5 * np.log10(fitter.raw_observation_grid[:, idx]) + 31.4
ax.plot(
    fitter.raw_observation_names * (1 + redshift),
    orig,
    color="gray",
    alpha=0.5,
    linewidth=0.5,
    label="Intrinsic",
)
ax.set_xlabel("Wavelength [micron]")
ax.set_ylabel("Mag [ABmag]")
ax.legend()
ax.set_ylim(35, None)
ax.set_xlim(0.6, 10)
ax.set_xscale("log")

In [None]:
# Saved model location; its name matches the grid loaded at the top of the notebook.
model_path = "/cosma/apps/dp276/dc-harv3/synference/models/Spectra_BPASS_Chab_Continuity_SFH_0.01_z_14_logN_4.4_Calzetti_v4_multinode"  # noqa: E501
run_model = SBI_Fitter.load_saved_model(model_path)

In [None]:
# Draw posterior samples for the feature vector of grid entry `idx`.
n_posterior_samples = 100
samples = run_model.sample_posterior(
    num_samples=n_posterior_samples,
    X_test=fitter.feature_array[idx],
)

In [None]:
# Rebuild `run_model.simulator` (used in the next cell) from the library —
# presumably the same simulator that generated the training grid.
# The trailing ";" suppresses the cell's displayed output.
run_model.recreate_simulator_from_library();

In [None]:
from tqdm import trange

from synference.utils import transform_spectrum  # used by the plotting cells below

# Regenerate a model SED for every posterior sample, then summarise the
# ensemble with its 16th/50th/84th percentiles along the sample axis.
sim = run_model.simulator

# The model is fitted in Av, but the simulator takes tau_v (Av = 1.086 tau_v).
sim.param_transforms["Av"] = ("tau_v", lambda Av: Av / 1.086)
sim.output_type = ["fnu", "photo_fnu"]
sim.out_flux_unit = "nJy"

recovered_sed = []
for i in trange(samples.shape[0]):
    # {parameter name: sampled value} for this draw. Named `params` rather
    # than `input` so the builtin input() is not shadowed in the kernel.
    params = {
        name: samples[i, j]
        for j, name in enumerate(run_model.fitted_parameter_names)
    }
    out = sim(params)
    fnu = out["fnu"]
    # Wavelength grid of `fnu`; kept as a global because later cells use it.
    wav = out["fnu_wav"]
    recovered_sed.append(fnu)

recovered_sed = np.array(recovered_sed)

recovered_16, recovered_50, recovered_84 = np.percentile(recovered_sed, [16, 50, 84], axis=0)

In [None]:
import matplotlib.pyplot as plt

# Overlay the SED regenerated from the posterior-median parameters on the
# observed PRISM spectrum and the intrinsic grid spectrum of entry `idx`.
# All curves are AB magnitudes from nJy fluxes: m_AB = -2.5 * log10(f / nJy) + 31.4.
redshift = fitter.fitted_parameter_array[idx][0]
mass = fitter.fitted_parameter_array[idx][1]
plt.plot(
    wavs,
    -2.5 * np.log10((1 + redshift) * 10 ** fitter.feature_array[idx]) + 31.4,
    label="Observed [NIRSpec PRISM]",
)
plt.gca().invert_yaxis()

print(f"Redshift: {redshift:.2f}, log10(M/Msun): {mass:.2f}")
plt.text(
    3.7,
    32.5,
    f"$z$={redshift:.2f}\n$log_{{10}}(M/M_\\odot)={mass:.2f}$",
    fontsize=12,
    color="black",
    bbox=dict(facecolor="lightgrey", alpha=0.8, edgecolor="none"),
)
# Optional spectral-feature markers (disabled):
# plt.axvline(0.1216 * (1 + redshift), color="k", linestyle="--", label="Lyman break")
# plt.axvline(0.0912 * (1 + redshift), color="k", linestyle=":", label="Lyman limit")
# plt.axvline(0.4861 * (1 + redshift), color="r", linestyle="--", label="Hβ")
# plt.axvline(0.6563 * (1 + redshift), color="r", linestyle=":", label="Hα")
# plt.axvline(0.5007 * (1 + redshift), color="g", linestyle="--", label="[OIII]")
# plt.axvline(0.3727 * (1 + redshift), color="g", linestyle=":", label="[OII]")

# Pin the current limits explicitly so the intrinsic spectrum plotted next
# does not trigger autoscaling.
ax = plt.gca()
ax.set_ylim(ax.get_ylim())
ax.set_xlim(ax.get_xlim())
orig = -2.5 * np.log10(fitter.raw_observation_grid[:, idx]) + 31.4
plt.plot(
    fitter.raw_observation_names * (1 + redshift),
    orig,
    color="gray",
    alpha=0.5,
    linewidth=0.5,
    label="Intrinsic",
)
plt.xlabel("Wavelength [micron]")
plt.ylabel("Mag [ABmag]")

plt.xlim(0.6, 10)
plt.xscale("log")

# Resample the posterior-median SED onto the PRISM wavelength grid using the
# instrument resolution curve. This statement was previously indented at top
# level, which raised IndentationError and stopped the whole cell from running.
# NOTE(review): `wav` is divided by (1 + z) while z is also passed in —
# presumably `fnu_wav` is observed-frame and transform_spectrum expects
# rest-frame input; confirm.
_, transformed_50 = transform_spectrum(
    theory_wave=wav.to("um").value / (1 + redshift),
    theory_flux=recovered_50,
    z=redshift,
    observed_wave=wavs.to("um").value,
    resolution_curve_wave=wavs.to("um").value,
    resolution_curve_r=R,
)

# Optional 68% credible band from the 16th/84th percentile SEDs (disabled).
# The wavelength scaling matches the median call above.
# _, transformed_16 = transform_spectrum(
#     theory_wave=wav.to("um").value / (1 + redshift),
#     theory_flux=recovered_16,
#     z=redshift,
#     observed_wave=wavs.to("um").value,
#     resolution_curve_wave=wavs.to("um").value,
#     resolution_curve_r=R,
# )
# _, transformed_84 = transform_spectrum(
#     theory_wave=wav.to("um").value / (1 + redshift),
#     theory_flux=recovered_84,
#     z=redshift,
#     observed_wave=wavs.to("um").value,
#     resolution_curve_wave=wavs.to("um").value,
#     resolution_curve_r=R,
# )
# plt.fill_between(
#     wavs.to("um").value,
#     -2.5 * np.log10((1 + redshift) * transformed_16) + 31.4,
#     -2.5 * np.log10((1 + redshift) * transformed_84) + 31.4,
#     color="C1",
#     alpha=0.3,
#     label="68% credible interval",
# )

# Same nJy -> AB zero point (31.4) as the observed and intrinsic curves; the
# previous "+ 40" shifted this curve by ~8.6 mag, making it incomparable.
plt.plot(
    wavs.to("um").value,
    -2.5 * np.log10((1 + redshift) * transformed_50) + 31.4,
    color="C1",
    label="Recovered median SED",
)
plt.legend()

In [None]:
# Same comparison as above, but without the (1 + z) flux factor: observed and
# recovered spectra are converted directly from nJy to AB magnitudes
# (m_AB = -2.5 * log10(f / nJy) + 31.4). Reuses `redshift`, `mass` and
# `transformed_50` from the previous cells.
plt.plot(
    wavs,
    # feature_array holds log10(nJy), so -2.5 * log10(10 ** x) == -2.5 * x.
    -2.5 * fitter.feature_array[idx] + 31.4,
    label="Observed [NIRSpec PRISM]",
)
plt.gca().invert_yaxis()

print(f"Redshift: {redshift:.2f}, log10(M/Msun): {mass:.2f}")
plt.text(
    3.7,
    32.5,
    f"$z$={redshift:.2f}\n$log_{{10}}(M/M_\\odot)={mass:.2f}$",
    fontsize=12,
    color="black",
    bbox=dict(facecolor="lightgrey", alpha=0.8, edgecolor="none"),
)
# Pin the current limits explicitly so the intrinsic spectrum plotted next
# does not trigger autoscaling.
ax = plt.gca()
ax.set_ylim(ax.get_ylim())
ax.set_xlim(ax.get_xlim())
orig = -2.5 * np.log10(fitter.raw_observation_grid[:, idx]) + 31.4
plt.plot(
    fitter.raw_observation_names * (1 + redshift),
    orig,
    color="gray",
    alpha=0.5,
    linewidth=0.5,
    label="Intrinsic",
)
plt.xlabel("Wavelength [micron]")
plt.ylabel("Mag [ABmag]")
plt.xlim(0.6, 10)
plt.xscale("log")

# Apply the same 31.4 zero point as the observed curve. Without it this curve
# sat ~31 mag away from the pinned y-limits and never appeared on the axes.
plt.plot(
    wavs.to("um").value,
    -2.5 * np.log10(transformed_50) + 31.4,
    color="C1",
    label="Recovered median SED",
)
plt.legend()

In [None]:
# log10(1 + z): the (1 + z) flux factor used in the plots shifts magnitudes
# by -2.5 times this value.
np.log10(redshift + 1)

In [None]:
# Posterior-median SED on the simulator's own wavelength grid (`wav`) compared
# with the same SED after resampling onto the PRISM grid (`wavs`), both as AB
# magnitudes from nJy fluxes including the (1 + z) flux factor.
plt.plot(
    wav.to("um").value,
    -2.5 * np.log10((1 + redshift) * recovered_50) + 31.4,
    color="C1",
    label="Recovered median SED",
)
plt.plot(
    wavs.to("um").value,
    -2.5 * np.log10((1 + redshift) * transformed_50) + 31.4,
    color="C2",
    label="Recovered median SED (resampled)",
)

# Reversed y-limits put brighter (smaller) magnitudes at the top.
plt.ylim(36, 29)
plt.xlim(0.6, 10)
plt.xlabel("Wavelength [micron]")
plt.ylabel("Mag [ABmag]")
# Both curves carry labels; without a legend they cannot be told apart.
plt.legend()

In [None]:
# Resampled posterior-median SED on its own, over a wide magnitude range.
plt.plot(
    wavs.to("um").value,
    -2.5 * np.log10((1 + redshift) * transformed_50) + 31.4,
    color="C2",
    label="Recovered median SED (resampled)",
)
# Reversed y-limits put brighter (smaller) magnitudes at the top.
plt.ylim(50, 29)
plt.xlabel("Wavelength [micron]")
plt.ylabel("Mag [ABmag]")
# Show the curve's label, which was previously set but never displayed.
plt.legend()