In [None]:
import os
from datetime import datetime
from pathlib import Path

import numba_extinction.numba_extinction as ne
import numpy as np
import pandas as pd
import speclite.filters as filters
from astropy import units

# can be anything from astropy.cosmology
from astropy.cosmology import Planck18 as cosmology

from quest_qso import mlconfig as cfg
from quest_qso.photometry import generate_photometry as gp
from quest_qso.scripts import generate_photometry as gpscript
from quest_qso.utils import generate_photometry_utils as gp_qa
from quest_qso.utils import resources, utilities

In [None]:
# Set all the variables. This would be handled by argparse in the script version
#  but for the purpose of this notebook we are setting them manually here
globals_ = {}

# only related to this notebook
globals_["PLOT_FILTER_RESPONSE"] = True

# default folder for everything
# NB: os.environ[...] is used instead of os.getenv(...) so that an unset variable
#  fails right here with a KeyError naming it, rather than with an opaque
#  TypeError raised by Path(None)
globals_["LOCAL_PATH"] = Path(os.environ["ML_QSO_MODEL_LOCALPATH"])
# This should be the path that contains the json model file and the trained model files
globals_["MODEL_PATH"] = Path(os.environ["QUEST_GP_MODEL_LOCALPATH"])

# timestamps, used to save outputs in unique folders
# both strings come from a single clock reading, so DAY always matches the date
#  part of TIMESTAMP (two separate readings could straddle midnight)
_run_started = datetime.now()
globals_["DAY"] = _run_started.strftime("%Y%m%d")
globals_["TIMESTAMP"] = _run_started.strftime("%Y%m%d_%H%M%S")

# clean intermediate products, useful to lower the memory footprint
globals_["CLEAR_MEMORY"] = False
globals_["SEED"] = 42

# make QA plots? How many examples to plot?
globals_["MAKE_QA_PLOTS"] = True
globals_["N_EXAMPLES"] = 10

# Simulation parameters for the quasar grid
# number of objects per mag-z bin
globals_["N_PER_BIN"] = 5

# number of redshift bins, and redshift limits
globals_["N_Z"] = 20
globals_["LOW_Z_LIM"] = 1.0
globals_["HIGH_Z_LIM"] = 5.0

# number of magnitude bins, and magnitude limits
globals_["N_M1450"] = 20
globals_["FAINT_M1450_LIM"] = -22.0
globals_["BRIGHT_M1450_LIM"] = -30.0

# Reddening parameters
globals_["REDDENING_MODEL"] = ne.Go23
globals_["B_BOUNDS"] = 15 * units.deg
globals_["R_V"] = 3.1

# Sampling mode, valid are "uniform", "lf" and "hist2d",
# note that hist2d is currently not supported as an option from command line and as a consequence here
globals_["SAMPLE_MODE"] = "uniform"

# These parameters are only needed for "lf" mode and ignored otherwise
# The package currently leverages Atelier (https://github.com/jtschindler/atelier) to perform the sampling
#  in lf mode
# has to be the same as one of the classes in atelier/lumfun.py
globals_["CHOSEN_LF"] = "Matsuoka2023DPLQLF"
globals_["LF_SKY_AREA"] = 10000  # in sq deg

# print summary of parameters
gpscript.print_cfg_summary(globals_)

In [None]:
# fetch the filter response curves matching "sdss2010-*" through speclite
SVO_SDSS = filters.load_filters("sdss2010-*")

# the response curves are only drawn when requested in the configuration cell
if globals_["PLOT_FILTER_RESPONSE"]:
    from quest_qso.photometry import custom_filters as cf

    # plotting options, collected in one place so they are easy to tweak
    filter_plot_options = dict(
        figsize=(8, 8 / 1.61 / 2),
        legend=False,
        add_filter_name=True,
        add_effective_wavelength=True,
    )
    fig, ax = cf.plot_filters(SVO_SDSS, **filter_plot_options)

In [None]:
# Build the (M1450, redshift) grid that drives the photometry generation.
# This sets the redshift and magnitude grid according to the input parameters,
#  i.e., generates a list of tuples (M1450, z) that will be assigned to each sampled quasar
#  and used to scale its flux.
# Later cells read (and extend) the table exposed as `z_M1450_grid.grid_data`.

# derive the grid parameters from the notebook configuration (globals_)
z_M1450_grid_params = gpscript.generate_grid_params(globals_)

# actually instantiate the grid from those parameters
z_M1450_grid = gp.generate_grid(params=z_M1450_grid_params)

In [None]:
# Prepare the VAE model configuration; the sampling itself is done in the next cell

# Read the model parameters from `params.json`, expected inside globals_["MODEL_PATH"],
#  and print them for reference
model_param_fname = globals_["MODEL_PATH"] / "params.json"
model_params = cfg.MLConfig().from_json(model_param_fname)
model_params.print_params()

In [None]:
# Load the model, get it ready for sampling and draw spectra following the
#  user defined parameters; both the spectra and their dispersion are returned
spectra, dispersion = gp.sample_from_VAE(z_M1450_grid, model_params)

# QA: plot a few of the freshly sampled spectra (still in arbitrary units)
qa_plot_dir = globals_["LOCAL_PATH"].joinpath(
    "QA", str(globals_["DAY"]), "GeneratePhotometry"
)
gp_qa.plot_example_spectra(
    globals_["N_EXAMPLES"],
    spectra,
    dispersion,
    "GenPhotExample.png",
    ylabel="Flux Density [A.U.]",
    dir=qa_plot_dir,
    make_qa=globals_["MAKE_QA_PLOTS"],
)

In [None]:
# The sampled spectra are in arbitrary units: bring them to physical units using
#  the redshift and M1450 assigned to each object

# Step 1 - scaling factor of each spectrum, from its assigned redshift and M1450
# NB: grid.grid_data is modified in place, the corresponding columns are added to it
# Worth noting:
# - The scaling is done to produce a spectrum in erg / (cm^2 s AA)
# - Cosmology is currently Planck 18, any other astropy cosmology object works too
gp.compute_scale_factor(dispersion, spectra, z_M1450_grid.grid_data, cosmology)

# Step 2 - apply to each spectrum the "scale" value computed in step 1
scale_factors = z_M1450_grid.grid_data["scale"]
scaled_spectra = gp.scale_VAE_spectra(spectra, scale_factors)

# QA: same example plot as before, now in physical units
qa_plot_dir = globals_["LOCAL_PATH"].joinpath(
    "QA", str(globals_["DAY"]), "GeneratePhotometry"
)
gp_qa.plot_example_spectra(
    globals_["N_EXAMPLES"],
    scaled_spectra,
    dispersion,
    "GenPhotExample_Scaled.png",
    ylabel=r"Flux density [erg s$^{-1}$ cm$^{-2}$ $\AA^{-1}$]",
    dir=qa_plot_dir,
    make_qa=globals_["MAKE_QA_PLOTS"],
)

# the unscaled spectra are not used by the following cells, drop them if requested
if globals_["CLEAR_MEMORY"]:
    print("[INFO] Clearing memory and removing intermediate products.")
    del spectra

In [None]:
# Shift the spectra to the observed frame and resample them on a new wavelength grid.

# The wavelength grid can either be provided by the user or be computed from the
#  requested filters.
# In the first case, the easiest thing is to use utilities.gen_wave_grid(), as follows:
#  new_rest_frame_dispersion = utilities.gen_wave_grid(7000 * units.AA, 10000 * units.AA, 140 * utilities.kms)
# Note that the new grid will be equally spaced in velocity space, and units are required.
#  the returned array is a quantity array!
# Otherwise, if no grid is provided, then the grid is computed based on the filter set used.
# Likewise, the wavelength grid has constant bin width in velocity space.

# redshift assigned to each object
grid_redshifts = z_M1450_grid.grid_data["redshift"].to_numpy()

# returns the resampled spectra, and the new (common) wavelength grid
resampled_spectra, new_dispersion = gp.resample_on_wavelength_grid(
    SVO_SDSS,  # Set of filter response functions
    dispersion,  # Current dispersion (rest frame, it is multiplied by redshift)
    grid_redshifts,  # redshift of each object
    scaled_spectra,  # Spectra in physical units
    new_rest_frame_dispersion=None,  # Computed based on filters, otherwise see above
)

# QA: example spectra on the new wavelength grid
qa_plot_dir = globals_["LOCAL_PATH"].joinpath(
    "QA", str(globals_["DAY"]), "GeneratePhotometry"
)
gp_qa.plot_example_spectra(
    globals_["N_EXAMPLES"],
    resampled_spectra,
    new_dispersion,
    "GenPhotExample_Scaled_Resampled.png",
    ylabel=r"Flux density [erg s$^{-1}$ cm$^{-2}$ $\AA^{-1}$]",
    dir=qa_plot_dir,
    make_qa=globals_["MAKE_QA_PLOTS"],
)

# clear out some memory, otherwise things get very, very slow
if globals_["CLEAR_MEMORY"]:
    print("[INFO] Clearing memory and removing intermediate products.")
    del scaled_spectra

In [None]:
# Apply to the resampled spectra the IGM transmission computed with SimQSO
# Note that this is slow! And it gets slower the wider the wavelength coverage is
resampled_spectra_applied_IGM = gp.compute_apply_IGM_simqso(
    new_dispersion, z_M1450_grid, resampled_spectra
)

# QA: example spectra after the IGM transmission has been applied
qa_plot_dir = globals_["LOCAL_PATH"].joinpath(
    "QA", str(globals_["DAY"]), "GeneratePhotometry"
)
gp_qa.plot_example_spectra(
    globals_["N_EXAMPLES"],
    resampled_spectra_applied_IGM,
    new_dispersion,
    "GenPhotExample_Scaled_Resampled_IGMApplied.png",
    ylabel=r"Flux density [erg s$^{-1}$ cm$^{-2}$ $\AA^{-1}$]",
    dir=qa_plot_dir,
    make_qa=globals_["MAKE_QA_PLOTS"],
)

In [None]:
# Optionally, re-apply reddening to the spectra as we use dereddened spectra while training
# Note that this modifies the spectra in place!
# NOTE(review): because of the in-place update, re-running this cell presumably
#  reddens `resampled_spectra_applied_IGM` a second time - re-run the IGM cell
#  first if this one has to be executed again (TODO confirm)
# The reddening model, R_V and b_bounds all come from the configuration cell;
#  the trailing `;` only suppresses the cell output
gp.redden_sampled_spectra(
    new_dispersion,
    resampled_spectra_applied_IGM,
    globals_["REDDENING_MODEL"],
    globals_["R_V"],
    b_bounds=globals_["B_BOUNDS"],
);

In [None]:
# Save the generated spectra, their wavelength grid and the parameter grid.
# Every file name carries globals_["TIMESTAMP"], so separate runs do not overwrite
#  each other.
generated_spectra_output = globals_["LOCAL_PATH"] / "generated_spectra"
# exist_ok=True replaces the separate exists() check: same result, no window
#  between the check and the creation of the folder
generated_spectra_output.mkdir(parents=True, exist_ok=True)

# wavelength grid; .value saves the bare numbers of the array, without its unit
np.save(
    generated_spectra_output
    / f"VAE_Spectra_with_IGM_dispersion_{globals_['TIMESTAMP']}.npy",
    new_dispersion.value,
)

# spectra, in whatever state the previous cells left them
np.save(
    generated_spectra_output / f"VAE_Spectra_with_IGM_{globals_['TIMESTAMP']}.npy",
    resampled_spectra_applied_IGM,
)

# parameter grid, including the columns added while scaling the spectra
z_M1450_grid.grid_data.to_hdf(
    generated_spectra_output / f"Param_grid_{globals_['TIMESTAMP']}.hdf5",
    key="data",
)

In [None]:
# Finally, generate new magnitudes based on the output of the previous cells

# AB magnitudes of every spectrum in every loaded filter, one column per filter
ab_mags = SVO_SDSS.get_ab_magnitudes(
    resampled_spectra_applied_IGM, new_dispersion
).to_pandas()

# catalogue = grid parameters side by side with the synthetic magnitudes
mags = pd.concat((z_M1450_grid.grid_data, ab_mags), axis=1)

# get the bands used - note that this does not ensure that the order of the filters is always the same
# The names are read from the magnitude table itself rather than from a fixed
#  offset into `mags.columns`, so the result does not depend on how many columns
#  the grid dataframe happens to have
bands = gp_qa.merge_bands(ab_mags.columns)

# generate a (long but informative) filename
outfilepath = globals_["LOCAL_PATH"] / "generated_photometry" / str(globals_["DAY"])
outfilepath.mkdir(parents=True, exist_ok=True)

# Conventionally, - indicates separation between different fields, _ indicates separation within fields
# The luminosity function and its sky area only enter the name in "lf" mode; they
#  are two distinct fields, hence the "-" between them
lf_fields = (
    f"{globals_['CHOSEN_LF']}-{globals_['LF_SKY_AREA']}-"
    if globals_["SAMPLE_MODE"] == "lf"
    else ""
)
outfilename = (
    f"GeneratedPhot-no_pert-{bands}-"
    + f"{globals_['SAMPLE_MODE']}-"
    + lf_fields
    + f"seed_{globals_['SEED']}-"
    + f"redsh_{str(globals_['LOW_Z_LIM']).replace('.', 'p')}to{str(globals_['HIGH_Z_LIM']).replace('.', 'p')}_"
    + f"M1450_m{np.abs(globals_['BRIGHT_M1450_LIM'])}tom{np.abs(globals_['FAINT_M1450_LIM'])}_{globals_['TIMESTAMP']}.hdf5"
)

# save catalogues
mags.to_hdf(
    outfilepath / outfilename,
    key="data",
)

print(f"[INFO] Catalogue saved to {outfilepath / outfilename}")
print("[INFO] Photometry generation completed!")
print("[INFO] Edit and run `add_pert.py` with the appropriate filters to add noise.")

In [None]:
# the magnitudes generated so far are error-free
# noise is added in the next cells

# pandas does not like units, so we temporarily go back to numpy arrays
# columns are picked explicitly, in u, g, r, i, z order
sdss_mag_columns = [f"sdss2010-{band}" for band in "ugriz"]
mags_np = mags[sdss_mag_columns].to_numpy()

# AB magnitudes -> fluxes, keeping one column per band
flux_np = gp.AB_to_flux(mags_np)

In [None]:
# load the SDSS error functions and attach the unit to them
sdss_error_table = resources.load_sdss_error_functions()
error_function_sdss = (
    utilities.pandas_to_recarray(sdss_error_table) * units.microJansky
)

# one entry per band, keyed as "SDSS_<BAND>"; column n of flux_np goes with the
#  n-th band of the sequence below
perturbed_photometry = {}
for n, band in enumerate("UGRIZ"):
    band_key = f"SDSS_{band}"
    # the output are in this order:
    # band_perturbed, band_err_perturbed, band_flux_perturbed, band_snr
    perturbed_photometry[band_key] = gp.generate_perturbed_magnitudes(
        flux_np[:, n],
        error_function_sdss,
        band_key,
        flag=99.0,
    )

In [None]:
# structure the output in a pandas dataframe, four columns per band:
# _pert -> perturbed magnitudes
# _sigma -> errors on perturbed magnitudes
# _flux -> perturbed flux from which I compute the magnitudes
# _snr -> last of the four outputs stored for each band

# the suffix order mirrors the order of the outputs stored for each band
sdss_band_keys = [f"SDSS_{band}" for band in "UGRIZ"]
output_suffixes = ("pert", "sigma", "flux", "snr")

# band-major column names: SDSS_U_pert, SDSS_U_sigma, ..., SDSS_Z_snr
perturbed_column_names = [
    f"{band_key}_{suffix}"
    for band_key in sdss_band_keys
    for suffix in output_suffixes
]

# matching values, stripped of their units; transposed to get one row per object
perturbed_values = np.array(
    tuple(
        perturbed_photometry[band_key][output_index].value
        for band_key in sdss_band_keys
        for output_index in range(len(output_suffixes))
    )
).T

mags_perturbed = utilities.numpy_to_pandas(perturbed_column_names, perturbed_values)

# error-free and perturbed photometry, side by side
catalogue = pd.concat((mags, mags_perturbed), axis=1)
catalogue

In [None]:
# Finally, save catalogues in the same folder as the previous one, but with an updated filename
# The path is built once, so the message below reports the file that was
#  actually written (the "w_pert" one, not the error-free catalogue)
perturbed_outfile = outfilepath / outfilename.replace("no_pert", "w_pert")

catalogue.to_hdf(
    perturbed_outfile,
    key="data",
)

print(f"Catalogue saved to {perturbed_outfile}")