In [None]:
%load_ext autoreload
%autoreload 2

import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

from sbifitter import MissingPhotometryHandler, SBI_Fitter

First, load a trained model.

In [None]:
# Photometry grid the model was built from, and where the trained model lives.
grid_path = "/home/tharvey/work/output/BPASS_Chab_LogNorm_5_z_12_phot_grid2.hdf5"
model_path = "/home/tharvey/work/ltu-ili_testing/models/BPASS_Chab_LogNorm_5_z_12_phot_grid2"

# Build the fitter from the HDF5 grid on disk.
fitter = SBI_Fitter.init_from_hdf5(
    "BPASS_Chab_LogNorm_5_z_12_phot_grid2",
    grid_path,
    return_output=False,
)

# Load the pickled trained model onto this fitter (set_self=True);
# the trailing semicolon hides the cell's return-value repr.
fitter.load_model_from_pkl(model_path, set_self=True);

Choose a random object from the test set.

In [None]:
# Draw one object at random from the held-out test indices.
index = np.random.choice(fitter._test_indices)

phot = fitter.feature_array[index]
true_params = fitter.fitted_parameter_array[index]  # kept for later inspection

# Show each feature name next to this object's value for it.
for i, value in enumerate(phot):
    print(fitter.feature_names[i], value)

In [None]:
# Boolean mask over the features: True marks a filter treated as missing.
mask = np.full(np.shape(phot), False)
# Let's pretend we don't have F460M.
# NOTE(review): assumes F460M is the 4th-from-last feature; confirm against
# fitter.feature_names.
mask[-4] = True

Create an instance of the `MissingPhotometryHandler` class from our SBI model.

In [None]:
# Wrap the trained fitter in a handler used below to infer masked photometry.
phot_sampler = MissingPhotometryHandler.init_from_sbifitter(
    fitter,
    verbose=False,
)

Now we take objects from the test set, pretend we don't have one filter, and see how well we can recover the missing magnitude in that filter.

In [None]:
# For 200 random test objects, hide the filter(s) selected by `mask`, ask the
# handler for their distribution, and store the true value next to the mean
# of the recovered distribution.
true = []
recovered = []

iters = np.random.choice(fitter._test_indices, size=200, replace=False)

for idx in tqdm(iters):
    phot = fitter.feature_array[idx]

    obs = {
        "mags_sbi": phot,
        "mags_unc_sbi": np.ones_like(phot) * 0.1,  # Say we have a 0.1 mag error
        "missing_mask": mask,
    }

    out = phot_sampler.process_observation(obs, noise_generator=None)
    true.append(phot[mask])
    # Average over axis 0 to get one value per missing filter. This matches
    # the multi-filter cell later in the notebook, which also uses axis=0 and
    # indexes recovered[:, i] per missing filter.
    # NOTE(review): assumes missing_photometry_dist is (n_samples, n_missing);
    # TODO confirm.
    recovered.append(np.mean(out["missing_photometry_dist"], axis=0))

true = np.array(true)  # (N, n_missing)
recovered = np.array(recovered)  # (N, n_missing)

In [None]:
# Residual (recovered - true) for the single masked filter.
# Collapse `recovered` to one value per object so the subtraction is
# element-wise, instead of broadcasting an (N,) array against a 2-D array.
true_mag = np.asarray(true)[:, 0]
recovered_mag = np.asarray(recovered).reshape(len(true_mag), -1).mean(axis=1)

fig, ax = plt.subplots()
ax.scatter(true_mag, recovered_mag - true_mag, alpha=1)
ax.axhline(0, color="k", lw=0.5)  # perfect-recovery reference
ax.set_ylabel(r"$\Delta$ Recovered Mag - True Mag")
ax.set_xlabel("True Mag");

In [None]:
# Recovered distribution of the masked filter for the last object processed
# in the loop above, compared with its true value.
missing_samples = np.ravel(out["missing_photometry_dist"])

fig, ax = plt.subplots()
ax.hist(missing_samples, bins=10)
ax.axvline(missing_samples.mean(), color="r", label="Mean")
# Read the true value through `mask` rather than a hard-coded index, so it
# always matches the filter that was actually hidden.
ax.axvline(phot[mask][0], color="g", label="True Value")
ax.set_xlabel("Mag")
ax.set_ylabel("Number of samples")
# The labels above are only shown once a legend is drawn.
ax.legend();

Now a more demanding example, where we hide several filters at once.

In [None]:
# NOTE(review): `mask` has not been changed yet at this point (the
# multi-filter mask is only defined in the next cell), so this cell repeats
# the single-filter experiment above. Its `true`/`recovered` are overwritten
# by the multi-filter loop further down. Consider deleting this cell or
# moving it below the mask definition.
true = []
recovered = []

iters = np.random.choice(fitter._test_indices, size=200, replace=False)

for idx in tqdm(iters):
    phot = fitter.feature_array[idx]

    obs = {
        "mags_sbi": phot,
        "mags_unc_sbi": np.ones_like(phot) * 0.1,  # Say we have a 0.1 mag error
        "missing_mask": mask,
    }

    out = phot_sampler.process_observation(obs, noise_generator=None)
    true.append(phot[mask])
    # One mean per missing filter (axis 0), matching the multi-filter cell
    # later in the notebook. Assumes missing_photometry_dist is
    # (n_samples, n_missing); TODO confirm.
    recovered.append(np.mean(out["missing_photometry_dist"], axis=0))

true = np.array(true)
recovered = np.array(recovered)

In [None]:
# Let's pretend we only have the JADES filters.
# Kept under its own name for reference; it is NOT used below.
# NOTE(review): whether 1 marks a JADES filter or a missing one is not
# stated here; confirm against fitter.feature_names before reusing.
jades_mask = np.array(
    [1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1],
    dtype=bool,
)

# Ok, pretending we only had the JADES filters ran into
# issues with the KDEs and low dimensionality.
# Probably can't stretch the model that far.

# Let's just hide three filters instead (True = treated as missing).
mask = np.zeros_like(phot, dtype=bool)
mask[[4, -5, -3]] = True

In [None]:
# Re-run the recovery test with the current (multi-filter) `mask` on 30
# random test objects.
n_objects = 30
iters = np.random.choice(fitter._test_indices, size=n_objects, replace=False)

true, recovered = [], []
for idx in tqdm(iters):
    phot = fitter.feature_array[idx]
    flat_unc = np.ones_like(phot) * 0.1  # assume a flat 0.1 mag error
    obs = {"mags_sbi": phot, "mags_unc_sbi": flat_unc, "missing_mask": mask}
    out = phot_sampler.process_observation(obs, noise_generator=None)

    true.append(phot[mask])
    # Mean along axis 0 -> one value per missing filter (the plotting cell
    # below indexes recovered[:, i] per filter).
    recovered.append(np.mean(out["missing_photometry_dist"], axis=0))

true, recovered = np.array(true), np.array(recovered)

Here we plot, for each missing filter, the difference between the mean recovered magnitude and the true magnitude. The closer to zero, the better the recovery. For this naive case, we can recover the magnitude to within around 0.01 mag.

In [None]:
# One panel per masked filter: residual (recovered - true) against true mag.
missing_idx = np.flatnonzero(mask)

# squeeze=False always returns a 2-D Axes array, so axes[0, i] also works
# when only a single filter is masked (plain subplots(1, 1) returns a bare
# Axes that cannot be indexed).
fig, axes = plt.subplots(1, missing_idx.size, figsize=(20, 5), squeeze=False)

for i, missing_filter_idx in enumerate(missing_idx):
    ax = axes[0, i]
    true_mag = true[:, i]
    # Compute recovered - true so the sign matches the axis label.
    ax.scatter(true_mag, recovered[:, i] - true_mag, alpha=1)
    ax.axhline(0, color="k", lw=0.5)  # perfect-recovery reference
    ax.set_ylabel(r"$\Delta$ Recovered Mag - True Mag")
    ax.set_xlabel("True Mag")
    ax.set_title(fitter.feature_names[missing_filter_idx])

In [None]:
# Adjacent string literals are concatenated at compile time. This keeps the
# line short without embedding the newline and indentation that a
# triple-quoted string puts inside the path (which made the path invalid).
grid_path = (
    "/home/tharvey/work/output/"
    "grid_Pop_II_LogNormal_SFH_5_z_12_logN_5.0_BPASS_Chab_v1.hdf5"
)

# Note: this rebinds `fitter`, replacing the model loaded at the top.
fitter = SBI_Fitter.init_from_hdf5("test", grid_path, return_output=False)