    Copyright 2024 National Technology & Engineering Solutions of Sandia, LLC (NTESS). 
    Under the terms of Contract DE-NA0003525 with NTESS, the U.S. Government retains certain rights in this software.

**Step 6 - compute results on synthetic data**

In [None]:
"""Imports"""

import os

import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from config import (BG_MIX_FILE, DATA_DIR, FG_SEED_FILE, IMAGE_DIR,
                    MIX_TRAIN_FG_LAMBDA, OOD_SEED_FILE, RANDOM_SEED,
                    STATIC_SYNTH_BG_CPS, STATIC_SYNTH_LIVE_TIME_RANGE,
                    STATIC_SYNTH_LIVE_TIME_SAMPLING, STATIC_SYNTH_LONG_BG_CPS,
                    STATIC_SYNTH_SNR_RANGE, STATIC_SYNTH_SNR_SAMPLING,
                    STATIC_SYNTH_SPS, TARGET_BINS, TARGET_ECAL, TEST_FILE,
                    TRAIN_FILE)
from matplotlib import cm
from matplotlib.ticker import FormatStrFormatter
from mpl_toolkits.axes_grid1 import make_axes_locatable
from riid.data.sampleset import SampleSet, read_hdf
from riid.data.synthetic.seed import SeedMixer
from riid.data.synthetic.static import StaticSynthesizer
from riid.visualize import plot_spectra
from sklearn.metrics import mean_absolute_error as mae
from sklearn.metrics import r2_score
from utils import load_final_models

In [None]:
"""Load in final model for each unsupervised loss (either locally or with W&B)."""

# load_final_models() is provided by the project's utils module. The two
# results are used in parallel throughout this notebook: best_runs is iterated
# via .keys() (one entry per unsupervised loss) and best_models is indexed by
# the matching position.
best_runs, best_models = load_final_models()

In [None]:
"""Focus unsupervised loss."""

# Display names of the unsupervised losses, indexed in parallel with
# best_models (plt_names[idx] titles the plots for best_models[idx]).
# A raw string keeps the LaTeX backslash literal: "\c" is not a valid escape
# sequence in a regular string literal and triggers a warning on recent
# Python versions.
plt_names = [r"$\chi^2$", "JSD", "PNLL", "SSE"]
# Position (into plt_names / best_models) of the loss shown in "focus" figures.
focus_unsup_loss = 1  # JSD

In [None]:
"""Generate SME reconstruction."""

# Keep only the foreground seeds and drop source columns that are all zero.
fg_seeds_ss, _ = read_hdf(FG_SEED_FILE).split_fg_and_bg()
fg_seeds_ss.drop_sources_columns_with_all_zeros()

# Expected source contributions are derived from the seeds' total counts
# BEFORE the spectra are downsampled, recalibrated or normalized below.
seed_names = fg_seeds_ss.sources.columns.get_level_values("Seed").values
source_counts = {}
for seed_name, seed_total_counts in zip(seed_names, fg_seeds_ss.info.total_counts):
    # key on the part of the seed name before the first comma
    source_counts[seed_name.split(",")[0]] = seed_total_counts
Z = np.array([*source_counts.values()])
expected_props = Z / np.sum(Z)
# seed indices sorted from largest to smallest expected proportion
isotope_order = np.argsort(expected_props)[::-1]

# Bring the seeds onto the target binning / energy calibration and L1-normalize.
fg_seeds_ss.downsample_spectra(target_bins=TARGET_BINS)
fg_seeds_ss = fg_seeds_ss.as_ecal(*TARGET_ECAL)
fg_seeds_ss.normalize(p=1)

# Proportion-weighted sum of the seed spectra, accumulated seed by seed.
reconstructed_fission_seed = np.zeros((1, TARGET_BINS))
seed_spectra = fg_seeds_ss.spectra.values
for seed_row in range(fg_seeds_ss.n_samples):
    reconstructed_fission_seed += expected_props[seed_row] * seed_spectra[seed_row]

# Wrap the reconstruction in a one-sample SampleSet: spectra from the weighted
# sum, sources from the first seed row overwritten with the expected
# proportions, and info copied from the first seed.
reconstructed_fission_seed_ss = SampleSet()
reconstructed_fission_seed_ss.spectra = pd.DataFrame(data=reconstructed_fission_seed)
all_but_first_row = fg_seeds_ss.sources.index.to_list()[1:]
reconstructed_fission_seed_ss.sources = fg_seeds_ss.sources.drop(all_but_first_row, axis=0)
reconstructed_fission_seed_ss.sources.iloc[0] = expected_props
first_info_row = fg_seeds_ss.info.values[0, :]
reconstructed_fission_seed_ss.info = pd.DataFrame(
    data=first_info_row.reshape(1, len(first_info_row)),
    columns=fg_seeds_ss.info.columns,
)

fig, ax = plot_spectra(
    reconstructed_fission_seed_ss,
    title=None,
    show=False,
    in_energy=True,
)
ax.legend(["SME Reconstruction"])
ax.set(title=None)
plt.show()

In [None]:
"""Load IND synth train and test datasets."""

# Read the held-out test set first, then the training set, from the paths
# configured in config.py.
ind_synth_test_ss, ind_synth_train_ss = read_hdf(TEST_FILE), read_hdf(TRAIN_FILE)

In [None]:
"""Plot learning curves."""

CM = cm.tab20
train_color, validation_color = CM(0), CM(1)
# (history key, legend label, colour) for the two curves drawn on every panel
curve_specs = (
    ("loss", "train", train_color),
    ("val_loss", "validation", validation_color),
)

fig, axes = plt.subplots(2, 2, figsize=(10,6), sharex=True)

# One panel per model: both loss curves first, then a dashed horizontal line
# at the minimum of each curve.
for idx, ax in enumerate(axes.flat):
    history = best_models[idx].history
    for history_key, curve_label, curve_color in curve_specs:
        ax.plot(history[history_key], label=curve_label, color=curve_color)
    for history_key, _curve_label, curve_color in curve_specs:
        ax.axhline(np.min(history[history_key]), linestyle="--", color=curve_color)
    ax.set_title(f"Unsupervised Loss: {plt_names[idx]}")

# a single legend (bottom-right panel) serves the whole grid
axes[1,1].legend()
fig.supxlabel("Epoch")
fig.supylabel("Loss")
fig.tight_layout()
training_curves_path = os.path.join(IMAGE_DIR, "training_curves.png")
fig.savefig(training_curves_path, dpi=300)
plt.show()

In [None]:
"""Do forward pass with all models on IND dataset."""

# Per-model accumulators; position k in every list belongs to best_models[k].
ind_synth_test_preds, all_maes, all_snrs = [], [], []
all_recon_errors, all_preds, all_trues = [], [], []

for idx, unsup_loss in enumerate(best_runs.keys()):
    model = best_models[idx]

    # predict() writes its results into the SampleSet it is given, so each
    # model gets its own slice-copy of the test set.
    tmp_ss = ind_synth_test_ss[:]
    model.predict(tmp_ss, bg_cps=STATIC_SYNTH_BG_CPS)

    y_true = tmp_ss.sources.values
    y_pred = tmp_ss.prediction_probas.values
    # Transposing makes every sample an "output", so raw_values returns one
    # MAE per sample rather than one per source.
    maes = mae(y_true.T, y_pred.T, multioutput="raw_values")

    # SNR = summed spectrum / sqrt(background cps * live time)
    spectrum_sums = tmp_ss.spectra.values.sum(axis=1)
    snr_denominator = np.sqrt(STATIC_SYNTH_BG_CPS * tmp_ss.info.live_time.values)
    snrs = spectrum_sums / snr_denominator

    # info column named after the model's unsupervised loss function
    recon_errors = tmp_ss.info[model.unsup_loss_func_name]

    all_maes.append(maes)
    all_snrs.append(snrs)
    all_recon_errors.append(recon_errors)
    all_preds.append(y_pred.flatten())
    all_trues.append(y_true.flatten())
    ind_synth_test_preds.append(tmp_ss)

In [None]:
"""Show random prediction."""

# Pick one test sample at random (unseeded) and compare its true proportions
# with each model's prediction.
sample_num = np.random.randint(ind_synth_test_ss.n_samples)
isotope_positions = np.arange(1, len(fg_seeds_ss)+1)

fig, ax = plt.subplots(figsize=(12,6))
true_props = ind_synth_test_ss.sources.values[sample_num]
ax.scatter(isotope_positions, true_props, label="true")
for idx, _model in enumerate(best_models):
    predicted_props = ind_synth_test_preds[idx].prediction_probas.values[sample_num]
    ax.scatter(
        isotope_positions,
        predicted_props,
        label=f"{plt_names[idx]} prediction",
        alpha=0.5,
    )

sample_snr = ind_synth_test_ss.info.snr[sample_num]
ax.set_title(f"Sample Number: {sample_num}, SNR = {sample_snr:.1f}")
tick_positions = np.arange(1, fg_seeds_ss.n_samples+1)
tick_labels = list(fg_seeds_ss.get_labels())
ax.set_xticks(tick_positions, tick_labels, rotation=60)
ax.set_xlabel("Isotope")
ax.set_ylabel("Proportion")
ax.legend()
ax.grid(alpha=0.3)
fig.tight_layout()
plt.show()

In [None]:
"""Get IND Synthetic Test Plots for LPE."""
nbins = 10

# first generate mae vs snr plots for each
fig, axes = plt.subplots(2, 2, figsize=(10,6), sharex=True, sharey=True)

# Every model was evaluated on a copy of the same test set and the binning
# always uses all_snrs[0], so the SNR bins are identical for all panels:
# build them once instead of once per panel.
snrs = all_snrs[0]
# nbins log-spaced bins spanning the observed SNR range
plt_snr_steps = np.logspace(
    np.log10(min(snrs)),
    np.log10(max(snrs)),
    nbins+1
)
# Sample indices per bin; bins are half-open (lower, upper].
# NOTE(review): because the lower edge is exclusive, a sample sitting exactly
# on the lowest edge (the minimum SNR) is not assigned to any bin.
plt_snr_inds = [
    np.where((snrs > lower_edge) & (snrs <= upper_edge))[0].astype(int)
    for lower_edge, upper_edge in zip(plt_snr_steps[:-1], plt_snr_steps[1:])
]
# Boxes are positioned at the mean SNR of the samples in each bin
# (NaN, with a NumPy warning, if a bin is empty).
plt_snrs = [np.mean(snrs[bin_inds]) for bin_inds in plt_snr_inds]


def width(p, w):
    """Return the linear width of a box centred at p that spans w decades."""
    return 10**(np.log10(p)+w/2.)-10**(np.log10(p)-w/2.)


for idx, ax in enumerate(axes.reshape(-1)):
    # per-sample MAEs of this model, grouped by SNR bin
    plt_maes = [all_maes[idx][bin_inds] for bin_inds in plt_snr_inds]

    ax.boxplot(
        plt_maes,
        positions=plt_snrs,
        vert=1,
        widths=width(plt_snrs, 0.1)
    )
    ax.set_xscale("log")
    ax.set_title(f"Unsupervised Loss: {plt_names[idx]}")

fig.supxlabel("SNR")
fig.supylabel("MAE")
fig.tight_layout()
fig.savefig(os.path.join(
    IMAGE_DIR,
    "ind_synth_test_mae_vs_snr.png"
), dpi=300)
plt.show()


# Calibration plots on the full test set: predicted vs. true proportion for
# every (sample, source) pair, one panel per model.
fig, axes = plt.subplots(2, 2, figsize=(10,6), sharex=True, sharey=True)

# Common axis limit: largest value across all predictions and truths, padded.
plt_max = np.hstack([
    np.asarray(all_preds).ravel(),
    np.asarray(all_trues).ravel(),
]).max() + 0.05

# y = x reference line, drawn on every panel
ideal_line = np.linspace(0, 1.0, 1000)

for idx, ax in enumerate(axes.flat):
    y_true, y_pred = all_trues[idx], all_preds[idx]
    sample_label = f"Sample (MAE = {mae(y_true, y_pred):.5f}, $r^2$ = {r2_score(y_true, y_pred):.3f})"

    ax.plot(ideal_line, ideal_line, label="Ideal Calibration", linestyle="--", color="black")
    ax.scatter(y_true, y_pred, label=sample_label, alpha=0.5)
    ax.set_xlim((0, plt_max))
    ax.set_ylim((0, plt_max))
    ax.legend(loc="upper left")
    ax.set_title(f"Unsupervised Loss: {plt_names[idx]}")

fig.supxlabel("True Proportion")
fig.supylabel("Predicted Proportion")
fig.tight_layout()
calibration_path = os.path.join(IMAGE_DIR, "ind_synth_test_calibration.png")
fig.savefig(calibration_path, dpi=300)
plt.show()

# Calibration plots restricted to high-SNR test samples.
snr_threshold = 100
# positions of the test samples whose SNR exceeds the threshold
valid_inds = np.flatnonzero(all_snrs[0] > snr_threshold)

fig, axes = plt.subplots(2, 2, figsize=(10,6), sharex=True, sharey=True)

# Axis limit is still taken over ALL samples so the scale matches the
# full-test-set calibration figure.
plt_max = np.hstack([
    np.asarray(all_preds).ravel(),
    np.asarray(all_trues).ravel(),
]).max() + 0.05

# y = x reference line, drawn on every panel
ideal_line = np.linspace(0, 1.0, 1000)

for idx, ax in enumerate(axes.flat):
    pred_ss = ind_synth_test_preds[idx]
    y_true = pred_ss.sources.values[valid_inds].flatten()
    y_pred = pred_ss.prediction_probas.values[valid_inds].flatten()
    sample_label = f"Sample (MAE = {mae(y_true, y_pred):.5f}, $r^2$ = {r2_score(y_true, y_pred):.3f})"

    ax.plot(ideal_line, ideal_line, label="Ideal Calibration", linestyle="--", color="black")
    ax.scatter(y_true, y_pred, label=sample_label, alpha=0.5)
    ax.set_xlim((0, plt_max))
    ax.set_ylim((0, plt_max))
    ax.legend(loc="upper left")
    ax.set_title(f"Unsupervised Loss: {plt_names[idx]}")

fig.supxlabel("True Proportion")
fig.supylabel("Predicted Proportion")
fig.tight_layout()
high_snr_calibration_path = os.path.join(
    IMAGE_DIR,
    "ind_synth_test_calibration_high_snr.png",
)
fig.savefig(high_snr_calibration_path, dpi=300)
plt.show()

# MAE-vs-SNR boxplot for the focus unsupervised loss only.
fig, ax = plt.subplots()

snrs = all_snrs[0]
# nbins log-spaced bins spanning the observed SNR range
plt_snr_steps = np.logspace(np.log10(min(snrs)), np.log10(max(snrs)), nbins+1)
# sample indices per half-open bin (lower, upper]
plt_snr_inds = [
    np.where((snrs > lower_edge) & (snrs <= upper_edge))[0].astype(int)
    for lower_edge, upper_edge in zip(plt_snr_steps[:-1], plt_snr_steps[1:])
]
# per-sample MAEs of the focus model grouped by bin, and each bin's mean SNR
plt_maes = [all_maes[focus_unsup_loss][bin_inds] for bin_inds in plt_snr_inds]
plt_snrs = [np.mean(snrs[bin_inds]) for bin_inds in plt_snr_inds]


def width(p, w):
    """Return the linear width of a box centred at p that spans w decades."""
    return 10**(np.log10(p)+w/2.)-10**(np.log10(p)-w/2.)


ax.boxplot(plt_maes, positions=plt_snrs, vert=1, widths=width(plt_snrs, 0.1))
ax.set_xscale("log")
ax.set_title(f"Unsupervised Loss: {plt_names[focus_unsup_loss]}")

fig.supxlabel("SNR")
fig.supylabel("MAE")
fig.tight_layout()
focus_loss_key = list(best_runs.keys())[focus_unsup_loss]
fig.savefig(
    os.path.join(IMAGE_DIR, f"ind_synth_test_mae_vs_snr_{focus_loss_key}.png"),
    dpi=300,
)
plt.show()

# Side-by-side calibration plots for the focus loss: all test samples (left)
# and the high-SNR subset selected by valid_inds (right).
fig, ax = plt.subplots(1, 2, sharey=True, sharex=True, figsize=(10, 4))
plt_max = np.hstack([
    np.asarray(all_preds).ravel(),
    np.asarray(all_trues).ravel(),
]).max() + 0.05

focus_pred_ss = ind_synth_test_preds[focus_unsup_loss]
# (panel title, true proportions, predicted proportions), left to right
calibration_panels = (
    (
        "All Test Samples",
        all_trues[focus_unsup_loss],
        all_preds[focus_unsup_loss],
    ),
    (
        "High-SNR Test Samples (SNR > 100)",
        focus_pred_ss.sources.values[valid_inds].flatten(),
        focus_pred_ss.prediction_probas.values[valid_inds].flatten(),
    ),
)

# y = x reference line, drawn on both panels
ideal_line = np.linspace(0, 1.0, 1000)

for panel_ax, (panel_title, y_true, y_pred) in zip(ax, calibration_panels):
    sample_label = f"Sample (MAE = {mae(y_true, y_pred):.5f}, $r^2$ = {r2_score(y_true, y_pred):.3f})"
    panel_ax.plot(ideal_line, ideal_line, label="Ideal Calibration", linestyle="--", color="black")
    panel_ax.scatter(y_true, y_pred, label=sample_label, alpha=0.5)
    panel_ax.set_xlim((0, plt_max))
    panel_ax.set_ylim((0, plt_max))
    panel_ax.legend(loc="upper left")
    panel_ax.set_title(panel_title)

fig.suptitle(f"Unsupervised Loss: {plt_names[focus_unsup_loss]}")
fig.supxlabel("True Proportion")
fig.supylabel("Predicted Proportion")
fig.tight_layout()
focus_loss_key = list(best_runs.keys())[focus_unsup_loss]
fig.savefig(
    os.path.join(IMAGE_DIR, f"ind_synth_test_calibration_{focus_loss_key}.png"),
    dpi=300,
)
plt.show()


In [None]:
"""Plot proportion distributions of the training and test data."""

# One boxplot figure per dataset (train first, then test); columns are ordered
# by decreasing expected proportion via isotope_order.
distribution_figures = (
    (ind_synth_train_ss, "ind_synth_train_data_distributions.png"),
    (ind_synth_test_ss, "ind_synth_test_data_distributions.png"),
)

for dataset_ss, image_name in distribution_figures:
    fig, ax = plt.subplots(figsize=(12,6))
    ax.boxplot(dataset_ss.sources.iloc[:, isotope_order].values)
    tick_positions = np.arange(1, fg_seeds_ss.n_samples+1)
    tick_labels = np.array(fg_seeds_ss.get_labels())[isotope_order]
    ax.set_xticks(tick_positions, tick_labels, rotation=60)
    ax.set_xlabel("Isotope")
    ax.set_ylabel("Proportion")
    ax.grid(alpha=0.3)
    fig.tight_layout()
    fig.savefig(os.path.join(IMAGE_DIR, image_name), dpi=300)
    plt.show()

In [None]:
"""Do forward pass on OOD synthetic data."""

# The OOD seeds are only needed for their labels, which name the per-source
# OOD test files.
ood_seeds_ss = read_hdf(OOD_SEED_FILE).as_ecal(*TARGET_ECAL)
ood_seeds_ss.downsample_spectra(TARGET_BINS)
ood_names = list(ood_seeds_ss.get_labels())

# Load one OOD test set per source and compute its per-sample SNR with the
# same formula used for the IND test set.
ood_ss, ood_snrs = [], []
for each in ood_names:
    ood_test_path = os.path.join(DATA_DIR, f"{each}_ood_test.h5")
    ss = read_hdf(ood_test_path)
    snr_denominator = np.sqrt(ss.info.live_time.values * STATIC_SYNTH_BG_CPS)
    snrs = ss.spectra.values.sum(axis=1) / snr_denominator
    ood_ss.append(ss)
    ood_snrs.append(snrs)

# all_ood_pred_ss[model index][OOD source index] -> SampleSet with predictions
all_ood_pred_ss = []
for idx, each in enumerate(best_runs.keys()):
    model = best_models[idx]
    tmp_ss_list = []
    for ood_idx, ood_name in enumerate(ood_names):
        # predict on a slice-copy so every model starts from the loaded data
        tmp_ss = ood_ss[ood_idx][:]
        model.predict(tmp_ss, bg_cps=STATIC_SYNTH_BG_CPS)
        tmp_ss_list.append(tmp_ss)
    all_ood_pred_ss.append(tmp_ss_list)


In [None]:
"""Generate OOD plots."""

# Threshold plots: each model's stored (SNR, reconstruction error) points with
# its OOD threshold function drawn on top.
fig, axes = plt.subplots(2, 2, figsize=(10,6), sharex=True)

for idx, ax in enumerate(axes.flat):
    model = best_models[idx]
    spline_snrs = model.spline_snrs

    ax.scatter(
        spline_snrs,
        model.spline_recon_errors,
        alpha=0.3,
        label="training sample",
    )
    # evaluate the threshold on 100 evenly spaced SNRs across the stored range
    x = np.linspace(np.min(spline_snrs), np.max(spline_snrs), 100)
    ax.plot(
        x,
        model.ood_threshold_func(x),
        label="OOD threshold func",
        color="black",
        linestyle="--",
    )
    ax.set_title(f"Unsupervised Loss: {plt_names[idx]}")
    ax.set_yscale("log")

axes[1,1].legend()
fig.supxlabel("SNR")
fig.supylabel("Reconstruction Error")
fig.tight_layout()
threshold_plot_path = os.path.join(IMAGE_DIR, "ood_threshold_func.png")
fig.savefig(threshold_plot_path, dpi=300)
plt.show()

# Reconstruction error vs. OOD contribution, one figure per OOD source and one
# panel per model, using every OOD test sample.
# n_bins equal-width contribution bins on [0, 1].
n_bins = 10
plt_ood_contrib_steps = np.linspace(0, 1.0, n_bins+1)
# SNR at which each model's OOD threshold function is evaluated for the
# horizontal reference line.
ood_threshold_snr = 100

# for all test samples
for ood_idx, tmp_ss in enumerate(ood_ss):

    fig, axes = plt.subplots(2, 2, figsize=(14,8), sharex=True)

    for idx, ax in enumerate(axes.reshape(-1)):
        # Last sources column of the predicted OOD set.
        # NOTE(review): presumably the OOD source's proportion — confirm
        # against how the *_ood_test.h5 files were generated.
        ood_contribs = all_ood_pred_ss[idx][ood_idx].sources.values[:,-1]
        # info column named after the model's unsupervised loss function
        ood_recon_errors = all_ood_pred_ss[idx][ood_idx].info[best_models[idx].unsup_loss_func_name].values

        # Sample indices per half-open bin (lower, upper]; a contribution of
        # exactly 0 therefore falls in no bin.
        plt_ood_inds = [np.where(
            (ood_contribs > plt_ood_contrib_steps[i]) & (ood_contribs <= plt_ood_contrib_steps[i+1])
        )[0].astype(int) for i in range(n_bins)]
        plt_recon_errors = [ood_recon_errors[each] for each in plt_ood_inds]
        # NOTE(review): computed but not used below — the boxes are placed at
        # the bin centres, not at these per-bin means.
        plt_ood_contribs = [np.mean(ood_contribs[each]) for each in plt_ood_inds]

        ax.boxplot(
            plt_recon_errors,
            positions=[(plt_ood_contrib_steps[i] + plt_ood_contrib_steps[i+1])/2 for i in range(n_bins)],
            vert=1,
            widths=0.05
        )
        # Reference line: mean reconstruction error of this model over the
        # IND test samples with SNR > 0.
        mean_ind_recon_error = np.mean(all_recon_errors[idx][np.where(all_snrs[idx] > 0)[0]])
        ax.axhline(
            mean_ind_recon_error,
            label=f"mean IND recon error: {mean_ind_recon_error:.2E}",
            linestyle="--"
        )
        # Reference line: the model's OOD threshold evaluated at ood_threshold_snr.
        # NOTE(review): the "0.05 FPR" in the label is hard-coded text — confirm
        # it matches how ood_threshold_func was fitted.
        ood_threshold = best_models[idx].ood_threshold_func(ood_threshold_snr)
        ax.axhline(
            ood_threshold,
            label=f"ood threshold (0.05 FPR @ {ood_threshold_snr} SNR): {ood_threshold:.2E}",
            linestyle="-",
            color="red"
        )

        ax.set_yscale("log")
        ax.set_xlim(0.0, 1.0)
        ax.xaxis.set_major_formatter(FormatStrFormatter('%.2f'))
        ax.set_title(f"Unsupervised Loss: {plt_names[idx]}")

        ax.legend()
    fig.suptitle(f"OOD Source: {ood_names[ood_idx]}")
    fig.supxlabel("OOD Contribution")
    fig.supylabel("Reconstruction Error")
    fig.tight_layout()
    fig.savefig(os.path.join(
        IMAGE_DIR,
        f"ood_synth_test_recon_error_vs_ood_contrib_{ood_names[ood_idx]}.png"
    ), dpi=300)
    plt.show()

# for high-snr samples
# Same figure as the all-sample version, restricted to OOD samples whose SNR
# exceeds snr_threshold; the IND reference line uses the same threshold.
snr_threshold = 100

for ood_idx, tmp_ss in enumerate(ood_ss):

    # positions of this OOD set's samples above the SNR threshold
    valid_inds = np.where(ood_snrs[ood_idx] > snr_threshold)[0]

    fig, axes = plt.subplots(2, 2, figsize=(14,8), sharex=True)

    for idx, ax in enumerate(axes.reshape(-1)):
        # Last sources column of the predicted OOD set.
        # NOTE(review): presumably the OOD source's proportion — confirm
        # against how the *_ood_test.h5 files were generated.
        ood_contribs = all_ood_pred_ss[idx][ood_idx].sources.values[:,-1][valid_inds]
        # info column named after the model's unsupervised loss function
        ood_recon_errors = all_ood_pred_ss[idx][ood_idx].info[best_models[idx].unsup_loss_func_name].values[valid_inds]

        # Sample indices per half-open bin (lower, upper]; a contribution of
        # exactly 0 therefore falls in no bin.
        plt_ood_inds = [np.where(
            (ood_contribs > plt_ood_contrib_steps[i]) & (ood_contribs <= plt_ood_contrib_steps[i+1])
        )[0].astype(int) for i in range(n_bins)]
        plt_recon_errors = [ood_recon_errors[each] for each in plt_ood_inds]

        # boxes sit at the bin centres
        ax.boxplot(
            plt_recon_errors,
            positions=[(plt_ood_contrib_steps[i] + plt_ood_contrib_steps[i+1])/2 for i in range(n_bins)],
            vert=1,
            widths=0.05
        )
        # Reference line: mean reconstruction error of this model over the
        # IND test samples above the same SNR threshold. Computed once and
        # reused for both the line position and its label.
        ind_high_snr_inds = np.where(all_snrs[idx] > snr_threshold)[0]
        mean_ind_recon_error = np.mean(all_recon_errors[idx][ind_high_snr_inds])
        ax.axhline(
            mean_ind_recon_error,
            label=f"mean IND recon error: {mean_ind_recon_error:.2E}",
            linestyle="--"
        )
        # Reference line: the model's OOD threshold evaluated at ood_threshold_snr.
        ood_threshold = best_models[idx].ood_threshold_func(ood_threshold_snr)
        ax.axhline(
            ood_threshold,
            label=f"ood threshold (0.05 FPR @ {ood_threshold_snr} SNR): {ood_threshold:.2E}",
            linestyle="-",
            color="red"
        )

        ax.set_yscale("log")
        ax.set_xlim(0.0, 1.0)
        ax.xaxis.set_major_formatter(FormatStrFormatter('%.2f'))
        ax.set_title(f"Unsupervised Loss: {plt_names[idx]}")

        ax.legend()

    fig.suptitle(f"OOD Source: {ood_names[ood_idx]}")
    fig.supxlabel("OOD Contribution")
    fig.supylabel("Reconstruction Error")
    fig.tight_layout()
    fig.savefig(os.path.join(
        IMAGE_DIR,
        f"ood_synth_test_recon_error_vs_ood_contrib_{ood_names[ood_idx]}_high_snr.png"
    ), dpi=300)
    plt.show()

# focus unsup loss
# Single figure for the focus loss with two sub-figures: the left column uses
# all OOD samples, the right column only the high-SNR ones. Each column has
# three rows indexed by ood_idx, so this layout requires len(ood_ss) <= 3.
fig = plt.figure(layout='constrained', figsize=(10, 10))
subfigs = fig.subfigures(1, 2, wspace=0.07)

axsLeft = subfigs[0].subplots(3, 1, sharey=True, sharex=True)

for ood_idx, tmp_ss in enumerate(ood_ss):

    # Last sources column of the predicted OOD set.
    # NOTE(review): presumably the OOD source's proportion — confirm.
    ood_contribs = all_ood_pred_ss[focus_unsup_loss][ood_idx].sources.values[:,-1]
    # info column named after the focus model's unsupervised loss function
    ood_recon_errors = all_ood_pred_ss[focus_unsup_loss][ood_idx].info[best_models[focus_unsup_loss].unsup_loss_func_name].values

    # Sample indices per half-open contribution bin (lower, upper].
    plt_ood_inds = [np.where(
        (ood_contribs > plt_ood_contrib_steps[i]) & (ood_contribs <= plt_ood_contrib_steps[i+1])
    )[0].astype(int) for i in range(n_bins)]
    plt_recon_errors = [ood_recon_errors[each] for each in plt_ood_inds]
    # NOTE(review): computed but not used below — boxes are placed at bin centres.
    plt_ood_contribs = [np.mean(ood_contribs[each]) for each in plt_ood_inds]

    axsLeft[ood_idx].boxplot(
        plt_recon_errors,
        positions=[(plt_ood_contrib_steps[i] + plt_ood_contrib_steps[i+1])/2 for i in range(n_bins)],
        vert=1,
        widths=0.05
    )
    # Reference line: mean IND reconstruction error over test samples with SNR > 0.
    mean_ind_recon_error = np.mean(all_recon_errors[focus_unsup_loss][np.where(all_snrs[focus_unsup_loss] > 0)[0]])
    axsLeft[ood_idx].axhline(
        mean_ind_recon_error,
        label=f"mean IND recon error: {mean_ind_recon_error:.2E}",
        linestyle="--"
    )
    # Reference line: OOD threshold evaluated at ood_threshold_snr.
    ood_threshold = best_models[focus_unsup_loss].ood_threshold_func(ood_threshold_snr)
    axsLeft[ood_idx].axhline(
        ood_threshold,
        label=f"ood threshold (0.05 FPR @ {ood_threshold_snr} SNR): {ood_threshold:.2E}",
        linestyle="-",
        color="red"
    )

    axsLeft[ood_idx].set_yscale("log")
    axsLeft[ood_idx].set_xlim(0.0, 1.0)
    axsLeft[ood_idx].xaxis.set_major_formatter(FormatStrFormatter('%.2f'))
    axsLeft[ood_idx].set_title(f"OOD Source: {ood_names[ood_idx]}")

    axsLeft[ood_idx].legend()

axsRight = subfigs[1].subplots(3, 1, sharey=True, sharex=True)

for ood_idx, tmp_ss in enumerate(ood_ss):

    # Restrict this OOD set to samples whose SNR exceeds snr_threshold.
    valid_inds = np.where(ood_snrs[ood_idx] > snr_threshold)[0]
    ood_contribs = all_ood_pred_ss[focus_unsup_loss][ood_idx].sources.values[:,-1][valid_inds]
    ood_recon_errors = all_ood_pred_ss[focus_unsup_loss][ood_idx].info[best_models[focus_unsup_loss].unsup_loss_func_name].values[valid_inds]

    # Sample indices per half-open contribution bin (lower, upper].
    plt_ood_inds = [np.where(
        (ood_contribs > plt_ood_contrib_steps[i]) & (ood_contribs <= plt_ood_contrib_steps[i+1])
    )[0].astype(int) for i in range(n_bins)]
    plt_recon_errors = [ood_recon_errors[each] for each in plt_ood_inds]
    # NOTE(review): computed but not used below — boxes are placed at bin centres.
    plt_ood_contribs = [np.mean(ood_contribs[each]) for each in plt_ood_inds]

    axsRight[ood_idx].boxplot(
        plt_recon_errors,
        positions=[(plt_ood_contrib_steps[i] + plt_ood_contrib_steps[i+1])/2 for i in range(n_bins)],
        vert=1,
        widths=0.05
    )
    # Reference line: mean IND reconstruction error over test samples with
    # SNR > 100 (literal value; the OOD selection above uses snr_threshold).
    mean_ind_recon_error = np.mean(all_recon_errors[focus_unsup_loss][np.where(all_snrs[focus_unsup_loss] > 100)[0]])
    axsRight[ood_idx].axhline(
        mean_ind_recon_error,
        label=f"mean IND recon error: {mean_ind_recon_error:.2E}",
        linestyle="--"
    )
    # Reference line: OOD threshold evaluated at ood_threshold_snr.
    ood_threshold = best_models[focus_unsup_loss].ood_threshold_func(ood_threshold_snr)
    axsRight[ood_idx].axhline(
        ood_threshold,
        label=f"ood threshold (0.05 FPR @ {ood_threshold_snr} SNR): {ood_threshold:.2E}",
        linestyle="-",
        color="red"
    )

    axsRight[ood_idx].set_yscale("log")
    axsRight[ood_idx].set_xlim(0.0, 1.0)
    axsRight[ood_idx].xaxis.set_major_formatter(FormatStrFormatter('%.2f'))
    axsRight[ood_idx].set_title(f"OOD Source: {ood_names[ood_idx]}")

    axsRight[ood_idx].legend()

# The y-axis is shared within each column (sharey=True above), so the limits
# are set on the first axes of each column only.
axsLeft[0].set_ylim(1e-4, 0.5)
axsRight[0].set_ylim(1e-4, 0.5)

subfigs[0].suptitle("All OOD Test Samples")
subfigs[1].suptitle("High-SNR OOD Test Samples (SNR > 100)")

fig.supxlabel("OOD Contribution")
fig.supylabel("Reconstruction Error")
fig.savefig(os.path.join(
    IMAGE_DIR,
    f"ood_synth_test_recon_error_vs_ood_contrib_{list(best_runs.keys())[focus_unsup_loss]}.png"
), dpi=300)
plt.show()

In [None]:
"""Generate OOD heatmaps."""

# Heatmaps of the OOD false-negative rate (fraction of OOD samples NOT flagged)
# over an (SNR bin, OOD-contribution bin) grid: one figure per OOD source and
# one panel per model.
nbins = 20
# only OOD samples with SNR >= snr_threshold are included
snr_threshold = 25
for ood_idx, tmp_ss in enumerate(ood_ss):
    snrs = ood_snrs[ood_idx]
    valid_inds = np.where(snrs >= snr_threshold)[0]
    snrs = snrs[valid_inds]

    fig, axes = plt.subplots(2, 2, figsize=(10,6), sharex=True)

    for idx, ax in enumerate(axes.reshape(-1)):
        # Last sources column (presumably the OOD source's proportion — confirm)
        # and the "ood" info column written for this model's predictions
        # (assumed to be a 0/1 or boolean flag — confirm).
        ood_contribs = all_ood_pred_ss[idx][ood_idx].sources.values[:,-1][valid_inds]
        ood_decisions = all_ood_pred_ss[idx][ood_idx].info.ood.values[valid_inds]

        # fns[i, j]: FNR in SNR bin i (log-spaced) and contribution bin j
        # (linear); every cell is overwritten in the loop below.
        fns = np.ones((nbins, nbins))
        snr_steps = np.logspace(
            np.log10(snr_threshold),
            np.log10(np.max(snrs)),
            nbins+1
        )
        ood_prop_steps = np.linspace(
            0.0,
            1.0,
            nbins+1
        )

        for i in range(nbins):
            for j in range(nbins):
                # Bins are half-open (lower, upper]. snr_inds does not depend
                # on j but is recomputed for every cell.
                snr_inds = np.where(
                    (snrs > snr_steps[i]) &
                    (snrs <= snr_steps[i+1])
                )
                prop_inds = np.where(
                    (ood_contribs[snr_inds[0]] > ood_prop_steps[j]) &
                    (ood_contribs[snr_inds[0]] <= ood_prop_steps[j+1])
                )

                # 1 - mean(flag) = fraction not flagged as OOD; an empty cell
                # yields NaN (with a NumPy warning).
                fns[i, j] = 1.0 - ood_decisions[snr_inds[0]][prop_inds[0]].mean()

        # Fill NaN (empty) cells using pandas' default interpolation, which
        # works down each column, i.e. across SNR bins for a fixed
        # contribution bin.
        fns = pd.DataFrame(
            fns
        ).interpolate().values

        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="5%", pad=0.05)

        # fns is transposed so SNR bins run along x. The extent is linear in
        # SNR although the bins are log-spaced, so an x position corresponds
        # to a bin INDEX rather than to the SNR value printed on the axis.
        im = ax.imshow(
            fns.T,
            extent=[snr_threshold, np.max(snrs), 0, 1.0],
            origin="lower",
            aspect="auto",
            interpolation="nearest",
            vmin=0.0,
            vmax=1.0
        )
        cb = fig.colorbar(im, cax=cax, orientation="vertical")
        cb.set_label("OOD FNR")
        cb.ax.set_ylim(0.0, 1.0)

        # Convert "last bin edge below STATIC_SYNTH_SNR_RANGE[0]" from a bin
        # index to an x position on the linear image axis.
        # NOTE(review): np.max raises ValueError on an empty selection, i.e.
        # if no bin edge lies below STATIC_SYNTH_SNR_RANGE[0].
        tmp_snr_range = np.max(snrs) - snr_threshold
        plt_threshold_idx = np.max(np.where(snr_steps < STATIC_SYNTH_SNR_RANGE[0])[0])
        plt_threshold_loc = ((plt_threshold_idx + 1) / nbins) * tmp_snr_range + snr_threshold
        ax.axvline(plt_threshold_loc, color="red", linestyle="dashed", label=f"valid SNR threshold ({STATIC_SYNTH_SNR_RANGE[0]})", linewidth=3)

        # FPR shown in the title: fraction of IND test samples with SNR >= 50
        # that this model flagged as OOD (independent of ood_idx).
        # NOTE(review): 50 is hard-coded; confirm whether it is meant to track
        # STATIC_SYNTH_SNR_RANGE[0], which is what the dashed line is labelled with.
        ind_valid_inds = np.where(all_snrs[0] >= 50)[0]
        fpr = np.mean(ind_synth_test_preds[idx].info.ood.values[ind_valid_inds])

        ax.set_title(f"Unsupervised Loss: {plt_names[idx]} (FPR = {fpr:.3f})")
        ax.set_xscale("linear")

        # Place the "10^2" and "10^3" tick labels by locating those SNRs on a
        # 1000-point log grid and mapping the grid index to an x position.
        tmp_sample_space = np.logspace(np.log10(snr_threshold), np.log10(np.max(snrs)), 1000)
        first_tick_step = np.max(np.where(tmp_sample_space < 100)[0])
        second_tick_step = np.max(np.where(tmp_sample_space < 1000)[0])

        first_tick_loc = ((first_tick_step + 1) / 1000) * tmp_snr_range + snr_threshold
        second_tick_loc = ((second_tick_step + 1) / 1000) * tmp_snr_range + snr_threshold

        xticks = np.hstack([snr_threshold, first_tick_loc, second_tick_loc])
        ax.set_xticks(xticks)
        ax.set_xticklabels([snr_threshold, "$10^2$", "$10^3$"])

    # legend for the dashed threshold line, shown on the bottom-right panel only
    axes[1, 1].legend()

    fig.suptitle(f"OOD Source: {ood_names[ood_idx]}")
    fig.supxlabel("SNR")
    fig.supylabel("OOD Contribution")
    fig.tight_layout()
    fig.savefig(os.path.join(
        IMAGE_DIR,
        f"ood_synth_test_heatmap_{ood_names[ood_idx]}.png"
    ), dpi=300)
    plt.show()

# focus unsup loss
# One figure for the focus loss with one heatmap per OOD source. The 2x5
# GridSpec places two panels on the top row and one on the bottom row; axes
# are indexed by ood_idx, so this layout requires len(ood_ss) <= 3.
# Uses nbins and snr_threshold as set at the top of this cell.
fig = plt.figure(figsize=(12,6))
gs = gridspec.GridSpec(ncols=5, nrows=2, figure=fig)

ax1 = fig.add_subplot(gs[0, 0:2])
ax2 = fig.add_subplot(gs[0, 2:4])
ax3 = fig.add_subplot(gs[1, 1:3])

axes = [ax1, ax2, ax3]

for ood_idx, tmp_ss in enumerate(ood_ss):
    ax = axes[ood_idx]

    # keep only OOD samples with SNR >= snr_threshold
    snrs = ood_snrs[ood_idx]
    valid_inds = np.where(snrs >= snr_threshold)[0]
    snrs = snrs[valid_inds]

    # Last sources column (presumably the OOD source's proportion — confirm)
    # and the "ood" info column (assumed to be a 0/1 or boolean flag — confirm).
    ood_contribs = all_ood_pred_ss[focus_unsup_loss][ood_idx].sources.values[:,-1][valid_inds]
    ood_decisions = all_ood_pred_ss[focus_unsup_loss][ood_idx].info.ood.values[valid_inds]

    # fns[i, j]: FNR in SNR bin i (log-spaced) and contribution bin j (linear)
    fns = np.ones((nbins, nbins))
    snr_steps = np.logspace(
        np.log10(snr_threshold),
        np.log10(np.max(snrs)),
        nbins+1
    )
    ood_prop_steps = np.linspace(
        0.0,
        1.0,
        nbins+1
    )

    for i in range(nbins):
        for j in range(nbins):
            # half-open bins (lower, upper]
            snr_inds = np.where(
                (snrs > snr_steps[i]) &
                (snrs <= snr_steps[i+1])
            )
            prop_inds = np.where(
                (ood_contribs[snr_inds[0]] > ood_prop_steps[j]) &
                (ood_contribs[snr_inds[0]] <= ood_prop_steps[j+1])
            )

            # fraction not flagged as OOD; NaN (with a warning) for an empty cell
            fns[i, j] = 1.0 - ood_decisions[snr_inds[0]][prop_inds[0]].mean()

    # Fill NaN (empty) cells with pandas' default interpolation, which works
    # down each column, i.e. across SNR bins for a fixed contribution bin.
    fns = pd.DataFrame(
        fns
    ).interpolate().values

    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)

    # Transposed so SNR bins run along x; the extent is linear in SNR although
    # the bins are log-spaced, so x positions correspond to bin indices.
    im = ax.imshow(
        fns.T,
        extent=[snr_threshold, np.max(snrs), 0, 1.0],
        origin="lower",
        aspect="auto",
        interpolation="nearest",
        vmin=0.0,
        vmax=1.0
    )
    cb = fig.colorbar(im, cax=cax, orientation="vertical")
    cb.set_label("OOD FNR")
    cb.ax.set_ylim(0.0, 1.0)

    # Map "last bin edge below STATIC_SYNTH_SNR_RANGE[0]" from a bin index to
    # an x position on the linear image axis.
    # NOTE(review): np.max raises ValueError if no bin edge lies below
    # STATIC_SYNTH_SNR_RANGE[0].
    tmp_snr_range = np.max(snrs) - snr_threshold
    plt_threshold_idx = np.max(np.where(snr_steps < STATIC_SYNTH_SNR_RANGE[0])[0])
    plt_threshold_loc = ((plt_threshold_idx + 1) / nbins) * tmp_snr_range + snr_threshold
    ax.axvline(plt_threshold_loc, color="red", linestyle="dashed", label=f"valid SNR threshold ({STATIC_SYNTH_SNR_RANGE[0]})", linewidth=3)

    # Fraction of IND test samples with SNR >= 50 flagged as OOD by the focus
    # model. It does not depend on ood_idx; the value from the last iteration
    # is used in the figure title after the loop.
    # NOTE(review): 50 is hard-coded; confirm whether it is meant to track
    # STATIC_SYNTH_SNR_RANGE[0].
    ind_valid_inds = np.where(all_snrs[0] >= 50)[0]
    fpr = np.mean(ind_synth_test_preds[focus_unsup_loss].info.ood.values[ind_valid_inds])

    ax.set_xscale("linear")

    # Place the "10^2" and "10^3" tick labels by locating those SNRs on a
    # 1000-point log grid and mapping the grid index to an x position.
    tmp_sample_space = np.logspace(np.log10(snr_threshold), np.log10(np.max(snrs)), 1000)
    first_tick_step = np.max(np.where(tmp_sample_space < 100)[0])
    second_tick_step = np.max(np.where(tmp_sample_space < 1000)[0])

    first_tick_loc = ((first_tick_step + 1) / 1000) * tmp_snr_range + snr_threshold
    second_tick_loc = ((second_tick_step + 1) / 1000) * tmp_snr_range + snr_threshold

    xticks = np.hstack([snr_threshold, first_tick_loc, second_tick_loc])
    ax.set_xticks(xticks)
    ax.set_xticklabels([snr_threshold, "$10^2$", "$10^3$"])

    ax.set_title(f"OOD Source: {ood_names[ood_idx]}")

fig.supxlabel("SNR")
fig.supylabel("OOD Contribution")
# legend for the dashed threshold line, shown on the bottom panel only
axes[2].legend()
fig.suptitle(f"Unsupervised Loss: {plt_names[focus_unsup_loss]} (FPR = {fpr:.3f})", x=0.43, y=0.95)
fig.tight_layout()
fig.savefig(os.path.join(
        IMAGE_DIR,
        f"ood_synth_test_heatmap_{list(best_runs.keys())[focus_unsup_loss]}.png"
    ), dpi=300)
plt.show()

In [None]:
"""Generate OOD synthetic datasets which are OOD due to novel proportion combinations"""

# load in seeds
fg_seeds_ss = read_hdf(FG_SEED_FILE)
fg_seeds_ss, _ = fg_seeds_ss.split_fg_and_bg()
fg_seeds_ss.drop_sources_columns_with_all_zeros()

# get expected source contributions before anything else!
# (they come from the seeds' total counts, so they must be computed before
# the spectra are recalibrated, downsampled and normalized below)
source_counts = {
    x.split(",")[0]: v
    for x, v in zip(
        fg_seeds_ss.sources.columns.get_level_values("Seed").values,
        fg_seeds_ss.info.total_counts
    )
}
Z = np.array(list(source_counts.values()))
N = len(Z)
expected_props = Z / Z.sum()

# NOTE(review): here the seeds are recalibrated first and downsampled second,
# whereas the SME-reconstruction cell downsamples first — confirm the two
# orders are interchangeable.
fg_seeds_ss = fg_seeds_ss.as_ecal(*TARGET_ECAL)
fg_seeds_ss.downsample_spectra(target_bins=TARGET_BINS)
fg_seeds_ss.normalize(p=1)

# Background seed mixtures (read once; the original read the same file twice).
mixed_bg_seeds_ss = read_hdf(BG_MIX_FILE)

# NOTE(review): rng is created but never passed to SeedMixer or
# StaticSynthesizer below, so RANDOM_SEED has no effect on the generated data.
# Confirm whether it should be wired in for reproducibility.
rng = np.random.default_rng(seed=RANDOM_SEED)

# Dirichlet concentration scales, ordered from MIX_TRAIN_FG_LAMBDA down to 0.1.
n_lambda = 20
lambdas = np.logspace(-1, np.log10(MIX_TRAIN_FG_LAMBDA), n_lambda)[::-1]
snr_range = STATIC_SYNTH_SNR_RANGE

# NOTE(review): long_bg_live_time is given a constant whose name ends in
# "_CPS" — confirm this is the intended value.
static_syn = StaticSynthesizer(
    samples_per_seed=STATIC_SYNTH_SPS,
    bg_cps=STATIC_SYNTH_BG_CPS,
    live_time_function=STATIC_SYNTH_LIVE_TIME_SAMPLING,
    live_time_function_args=STATIC_SYNTH_LIVE_TIME_RANGE,
    snr_function=STATIC_SYNTH_SNR_SAMPLING,
    snr_function_args=snr_range,
    long_bg_live_time=STATIC_SYNTH_LONG_BG_CPS
)


def _synthesize_lambda_sweep(base_alpha):
    """Mix seeds and synthesize foreground spectra once per value in lambdas.

    Args:
        base_alpha: per-seed Dirichlet concentration before scaling; it is
            multiplied by each lambda in turn.

    Returns:
        Tuple (mixed_seeds, synth_mixes) of lists with one entry per lambda,
        in the order of ``lambdas``: the 1000 mixed seeds produced by
        SeedMixer and the foreground SampleSet synthesized from them.
    """
    mixed_seeds = []
    synth_mixes = []
    for lam in lambdas:
        mixed_ss = SeedMixer(
            fg_seeds_ss,
            mixture_size=fg_seeds_ss.n_samples,
            dirichlet_alpha=base_alpha * lam
        ).generate(1000)
        mixed_seeds.append(mixed_ss)

        # only the foreground result of generate() is kept
        fg_ss, _unused = static_syn.generate(
            fg_seeds_ss=mixed_ss,
            bg_seeds_ss=mixed_bg_seeds_ss
        )

        fg_ss.drop_spectra_with_no_contributors()
        fg_ss.clip_negatives()

        synth_mixes.append(fg_ss)
    return mixed_seeds, synth_mixes


# Mixtures centred on the expected (IND) proportions ...
mixed_fg_seeds, synth_fg_mixes = _synthesize_lambda_sweep(expected_props)

# ... and mixtures centred on uniform proportions (1/n for each of n seeds).
uniform_props = np.full(fg_seeds_ss.n_samples, 1/fg_seeds_ss.n_samples)
uniform_mixed_fg_seeds, uniform_synth_fg_mixes = _synthesize_lambda_sweep(uniform_props)

In [None]:
"""Generate summary plots for synth OOD performance due to novel proportions."""

def _score_predictions(model, mixes_ss):
    """Predict on mixes_ss in place and summarize the result.

    Returns a tuple (MAE over all flattened proportions, mean of the info
    column named after the model's unsupervised loss, mean of the "ood"
    info column).
    """
    model.predict(mixes_ss, bg_cps=STATIC_SYNTH_BG_CPS)
    y_true = mixes_ss.sources.values.flatten()
    y_pred = mixes_ss.prediction_probas.values.flatten()
    mean_unsup_loss = np.mean(mixes_ss.info[model.unsup_loss_func_name].values)
    ood_rate = np.mean(mixes_ss.info["ood"].values)
    return mae(y_true, y_pred), mean_unsup_loss, ood_rate


# Outer index: model; inner index: lambda. The "uniform_" lists hold the same
# statistics for the uniform-expectation mixtures.
maes, unsup_losses, ood_rates = [], [], []
uniform_maes, uniform_unsup_losses, uniform_ood_rates = [], [], []

for idx, unsup_loss in enumerate(best_runs.keys()):
    model = best_models[idx]
    tmp_maes, tmp_unsup_losses, tmp_ood_rates = [], [], []
    uniform_tmp_maes, uniform_tmp_unsup_losses, uniform_tmp_ood_rates = [], [], []

    for i in range(n_lambda):
        # first predict and compute stats on IND expectation
        ind_mae, ind_loss, ind_ood_rate = _score_predictions(model, synth_fg_mixes[i])
        tmp_maes.append(ind_mae)
        tmp_unsup_losses.append(ind_loss)
        tmp_ood_rates.append(ind_ood_rate)

        # then on OOD expectation
        uni_mae, uni_loss, uni_ood_rate = _score_predictions(model, uniform_synth_fg_mixes[i])
        uniform_tmp_maes.append(uni_mae)
        uniform_tmp_unsup_losses.append(uni_loss)
        uniform_tmp_ood_rates.append(uni_ood_rate)

    maes.append(tmp_maes)
    unsup_losses.append(tmp_unsup_losses)
    ood_rates.append(tmp_ood_rates)
    uniform_maes.append(uniform_tmp_maes)
    uniform_unsup_losses.append(uniform_tmp_unsup_losses)
    uniform_ood_rates.append(uniform_tmp_ood_rates)

# One three-panel figure per model: MAE, mean reconstruction error and OOD
# rate as functions of the Dirichlet scale lambda, for mixtures centred on the
# IND expectation vs. on uniform proportions.
for idx, unsup_loss in enumerate(best_runs.keys()):

    fig, axes = plt.subplots(1, 3, figsize=(12,6), sharex=True)

    axes[0].plot(lambdas, maes[idx], label="IND expectation")
    axes[0].plot(lambdas, uniform_maes[idx], label="uniform expectation")
    axes[0].set_ylabel("MAE")
    # the x-axis is shared (sharex=True), so the log scale applies to all panels
    axes[0].set_xscale("log")

    axes[1].plot(lambdas, unsup_losses[idx], label="IND expectation")
    axes[1].plot(lambdas, uniform_unsup_losses[idx], label="uniform expectation")
    axes[1].set_ylabel("Reconstruction Error")

    axes[2].plot(lambdas, ood_rates[idx], label="IND expectation")
    axes[2].plot(lambdas, uniform_ood_rates[idx], label="uniform expectation")
    axes[2].set_ylabel("OOD Rate")
    axes[2].set_ylim(0.0, 1.05)
    axes[2].legend()

    fig.suptitle(f"Unsupervised Loss: {plt_names[idx]}")
    # Raw string: "\l" is not a valid escape sequence in a regular string
    # literal and triggers a warning on recent Python versions.
    fig.supxlabel(r"$\lambda$")
    fig.tight_layout()
    fig.savefig(os.path.join(
        IMAGE_DIR,
        f"ood_synth_novel_proportions_{unsup_loss}.png"
    ), dpi=300)
    plt.show()