    Copyright 2023 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
    Under the terms of Contract DE-NA0003525 with NTESS,
    the U.S. Government retains certain rights in this software.

# Evaluate the final label proportion estimator (LPE)

In [None]:
"""Imports"""
import os
import re

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
from config import (BG_SEED_FILE, DATA_DIR, FG_SEED_FILE, IND_FG_SEED_FILE,
                    MEASURED_LLD, MEASUREMENT_FILE,
                    MEASUREMENT_MATCH_SYNTHETIC_FILE,
                    MEASUREMENT_TEST_MEASURED_FILE, MODEL_DIR, TARGET_BINS,
                    TRAIN_FILE)
from matplotlib.ticker import ScalarFormatter
from mpl_toolkits.axes_grid1 import make_axes_locatable
from riid.data.sampleset import SampleSet, SpectraState, SpectraType, read_hdf
from riid.losses.sparsemax import sparsemax
from riid.models.neural_nets import LabelProportionEstimator
from riid.visualize import plot_learning_curve, plot_spectra
from scipy.interpolate import UnivariateSpline
from sklearn.metrics import mean_absolute_error

In [None]:
"""Load in measurement data and synthetic seeds."""
# Read the raw measurements and mark their spectra as counts.
measurements_ss = read_hdf(f"{MEASUREMENT_FILE}")
measurements_ss.spectra_state = SpectraState.Counts

# Split off the "Background" record(s), tag both parts with their spectra
# type, and subtract to obtain the measured foreground.
gross_measurements_ss, bg_measurements_ss = measurements_ss.split_fg_and_bg(["Background"])
bg_measurements_ss.spectra_type = SpectraType.Background
gross_measurements_ss.spectra_type = SpectraType.Gross
fg_measurements_ss = gross_measurements_ss - bg_measurements_ss

# Move the measured foreground onto the energy calibration of the first
# synthetic foreground seed.
fg_seeds_ss = read_hdf(f"{FG_SEED_FILE}")
seed_ecal = fg_seeds_ss.ecal[0]
fg_measurements_ss = fg_measurements_ss.as_ecal(*seed_ecal)

# Rebin to TARGET_BINS, zero the first MEASURED_LLD channels, normalize (p=1).
fg_measurements_ss.downsample_spectra(target_bins=TARGET_BINS)
fg_measurements_ss.spectra.iloc[:, :MEASURED_LLD] = 0
fg_measurements_ss.normalize(p=1)

# Hand-chosen ordering of the measurements used by the plots further down.
srt_order = [
    19, 13, 9, 3, 2, 18, 7, 20, 4, 15, 16,
    5, 21, 12, 10, 6, 8, 14, 1, 11, 0, 17,
]
fg_measurements_ss = fg_measurements_ss[srt_order]

# In-distribution foreground seeds and background seeds, rebinned likewise.
id_seeds_ss = read_hdf(IND_FG_SEED_FILE)
id_seeds_ss.downsample_spectra(target_bins=TARGET_BINS)
bg_seeds_ss = read_hdf(BG_SEED_FILE)
bg_seeds_ss.downsample_spectra(target_bins=TARGET_BINS)

In [None]:
"""Plot in-distribution synthetic foreground seeds."""
# Overlay all in-distribution synthetic foreground seeds on an energy axis.
plot_spectra(
    id_seeds_ss,
    target_level="Seed",
    in_energy=True,
    title="ID Synthetic Foreground Seeds",
)

In [None]:
"""Visually compare synthetic and measured spectra."""
comp_labels = ["Am241", "Ba133", "Co60", "U232"]
fg_labels = fg_measurements_ss.get_labels()
id_labels = id_seeds_ss.get_labels()

# Count rate of the measured background: total counts per unit live time.
bg_info = bg_measurements_ss.info
bg_cps = (bg_info.total_counts / bg_info.live_time).values

for isotope in comp_labels:
    # Pick the measured and synthetic spectra carrying this label.
    measured_ss = fg_measurements_ss[fg_labels == isotope]
    synthetic_ss = id_seeds_ss[id_labels == isotope]
    synthetic_ss.downsample_spectra(target_bins=TARGET_BINS)

    # SNR per measurement: total counts / sqrt(live time * background rate).
    measured_info = measured_ss.info
    plt_snrs = (
        measured_info.total_counts / np.sqrt(measured_info.live_time * bg_cps)
    ).values

    # Stack measured and synthetic spectra so they share one plot.
    concat_ss = SampleSet()
    concat_ss.concat([measured_ss, synthetic_ss])

    fig, ax = plot_spectra(
        concat_ss, ylim=(None, None), in_energy=True, target_level="Seed", show=False, title=""
    )

    # Legend: each measured seed annotated with its SNR, then the first
    # synthetic seed label.
    measured_seed_labels = measured_ss.get_labels(target_level="Seed")
    plt_labels = [
        seed_label + f" ({plt_snrs[pos]:.0f} SNR)"
        for pos, seed_label in enumerate(measured_seed_labels)
    ]
    plt_labels.append(synthetic_ss.get_labels(target_level="Seed")[0])
    ax.legend(plt_labels)
    ax.set(title=None)
    plt.show()

In [None]:
"""Load model and run forward pass with LPE."""
# Location of the exported ONNX model inside MODEL_DIR.
model_path = os.path.join(MODEL_DIR, "lpe_sparsemax_jsd_beta0.85_20231012-125247.onnx")
model = LabelProportionEstimator()
model.load(model_path)

# predict() leaves its output on the SampleSet; it is read back below through
# `prediction_probas`.
model.predict(fg_measurements_ss)

# Ground-truth contributions and predictions summed over the "Isotope" level.
y = fg_measurements_ss.get_source_contributions().astype(float)
isotope_probas = fg_measurements_ss.prediction_probas.groupby(level="Isotope", axis=1).sum()
preds = isotope_probas.values.astype(float)

measurement_mae = mean_absolute_error(y, preds)
print("MAE:", f"{measurement_mae:.3f}")

In [None]:
"""Plot predictions on single-isotope measurements."""
# Draw one bar chart of predicted proportions per measurement. Measurements
# whose label is among the in-distribution (ID) seed sources get plain bars;
# all others are hatched as out-of-distribution (OOD).
ID_SOURCES = list(id_seeds_ss.sources.columns.levels[1])
ncols = 4
nrows = int(np.ceil(fg_measurements_ss.n_samples / ncols))
fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols*3, nrows*2))
flat_axes = axes.reshape(-1)
x_bar = list(fg_measurements_ss.prediction_probas.groupby(level="Isotope", axis=1).sum().columns)

bg_cps = bg_measurements_ss.info.total_counts / bg_measurements_ss.info.live_time

# Look the labels up once instead of on every loop iteration.
seed_labels = fg_measurements_ss.get_labels(target_level="Seed")
isotope_labels = fg_measurements_ss.get_labels()

for idx in range(fg_measurements_ss.n_samples):
    # Panel 19 is deliberately left blank (it is switched off below), so
    # every sample after index 18 is shifted one panel to the right.
    ax = flat_axes[idx] if idx <= 18 else flat_axes[idx + 1]

    snr = fg_measurements_ss.info.total_counts[idx] / np.sqrt(fg_measurements_ss.info.live_time[idx] * bg_cps)[0]
    # Build the title from a pre-fetched label: reusing the enclosing quote
    # character inside an f-string expression is a SyntaxError before
    # Python 3.12.
    seed_label = seed_labels[idx]
    ax.set_title(f"{seed_label}, {snr:.0f} σ")

    if isotope_labels[idx] in ID_SOURCES:
        plt_hatch = None
        id_plt = ax.bar(x_bar, preds[idx, :], color="white", edgecolor="black", hatch=plt_hatch)
    else:
        plt_hatch = "//"
        ood_plt = ax.bar(x_bar, preds[idx, :], color="white", edgecolor="black", hatch=plt_hatch)

    ax.set_xticks(x_bar)
    ax.set_xticklabels(x_bar, rotation=45)
    ax.set_ylim((0, 1))

# Hide the two unused panels of the grid.
flat_axes[19].set_axis_off()
flat_axes[23].set_axis_off()
fig.supylabel("Estimated Proportion")
fig.supxlabel("Model Outputs")
fig.legend((id_plt, ood_plt), ("ID", "OOD"), loc="outside right lower")
fig.tight_layout()
fig.savefig("./imgs/measurement_preds.png", dpi=300)
plt.show()

In [None]:
"""Evaluate test samplesets, generated from both simulated and measured seeds."""
# Load both test sets, normalize them (p=1), and run the model on each.
test_fg_measurement_ss = read_hdf(MEASUREMENT_TEST_MEASURED_FILE)
test_fg_synthetic_ss = read_hdf(MEASUREMENT_MATCH_SYNTHETIC_FILE)
test_fg_measurement_ss.normalize(p=1)
test_fg_synthetic_ss.normalize(p=1)

model.predict(test_fg_measurement_ss)
model.predict(test_fg_synthetic_ss)


def _isotope_level_preds(ss):
    """Sum `prediction_probas` over the "Isotope" level and apply sparsemax."""
    summed = ss.prediction_probas.groupby(level="Isotope", axis=1).sum().values
    return sparsemax(summed).numpy().astype(float)


def _isotope_level_truth(ss):
    """Return true per-isotope proportions as a float array, columns sorted by name.

    The "Cf252" and "Eu152" columns are set to zero (created if absent) --
    presumably because the model outputs them while these test sets do not
    contain them; TODO confirm.
    """
    truth = ss[:].sources.groupby(axis=1, level="Isotope").sum()
    truth["Cf252"] = 0
    truth["Eu152"] = 0
    return truth.reindex(sorted(truth.columns), axis=1).values.astype(float)


test_measurement_preds = _isotope_level_preds(test_fg_measurement_ss)
test_synthetic_preds = _isotope_level_preds(test_fg_synthetic_ss)

test_measurement_y = _isotope_level_truth(test_fg_measurement_ss)
test_synthetic_y = _isotope_level_truth(test_fg_synthetic_ss)

# Per-sample MAE (one value per row), then the mean over all samples.
test_synthetic_maes = [
    mean_absolute_error(truth_row, pred_row)
    for truth_row, pred_row in zip(test_synthetic_y, test_synthetic_preds)
]
test_measured_maes = [
    mean_absolute_error(truth_row, pred_row)
    for truth_row, pred_row in zip(test_measurement_y, test_measurement_preds)
]
test_synthetic_mae = np.mean(test_synthetic_maes)
test_measured_mae = np.mean(test_measured_maes)

# The f-prefix belongs on the string holding the placeholder; previously the
# literal text "{test_synthetic_mae:.3f}" was printed instead of the value.
print("MAE (synthetic):", f"{test_synthetic_mae:.3f}")
print("MAE (measured): ", f"{test_measured_mae:.3f}")

In [None]:
"""Plot distribution of reconstruction errors."""
fig, ax = plt.subplots()
# Overlay the "unsup_jsd_loss" values of the synthetic and measured test sets.
for sample_set, legend_name in (
    (test_fg_synthetic_ss, "synthetic"),
    (test_fg_measurement_ss, "measured"),
):
    ax.hist(sample_set.info["unsup_jsd_loss"], bins=50, alpha=0.7, label=legend_name)
ax.legend()
ax.set(xlabel="Reconstruction Error", ylabel="Counts")
fig.savefig("imgs/reconstruction_error_dist.png", dpi=300)
plt.show()

In [None]:
"""Plot MAE vs SNR."""
# Box plots of per-sample MAE against SNR for the synthetic and measured test
# sets. Samples are grouped either into equal-width SNR bins
# (use_constant_bins=True) or into equal-population quantile bins.
use_constant_bins = True

plt_bins = 15
# NOTE(review): 300 looks like an assumed background count rate -- confirm.
lt = test_fg_measurement_ss.info.live_time
cnts = test_fg_measurement_ss.info.total_counts
measured_snrs = np.array(cnts / np.sqrt(lt * 300))

lt = test_fg_synthetic_ss.info.live_time
cnts = test_fg_synthetic_ss.info.total_counts
synthetic_snrs = np.array(cnts / np.sqrt(lt * 300))

measured_maes_arr = np.array(test_measured_maes)
synthetic_maes_arr = np.array(test_synthetic_maes)

if not use_constant_bins:
    # pd.qcut(..., labels=False) on an ndarray returns a plain integer array
    # (no `.values`), so np.asarray is used to handle either return type.
    a = np.asarray(pd.qcut(measured_snrs, plt_bins, labels=False))
    plt_measured_maes = [measured_maes_arr[a == k] for k in range(plt_bins)]
    plt_measured_mean_maes = [np.mean(each) for each in plt_measured_maes]
    plt_measured_snrs = [np.mean(measured_snrs[a == k]) for k in range(plt_bins)]
    plt_measured_stds = [np.std(each) for each in plt_measured_maes]

    b = np.asarray(pd.qcut(synthetic_snrs, plt_bins, labels=False))
    plt_synthetic_maes = [synthetic_maes_arr[b == k] for k in range(plt_bins)]
    plt_synthetic_mean_maes = [np.mean(each) for each in plt_synthetic_maes]
    plt_synthetic_snrs = [np.mean(synthetic_snrs[b == k]) for k in range(plt_bins)]
    plt_synthetic_stds = [np.std(each) for each in plt_synthetic_maes]

    # Quantile bins differ between the two data sets, so each box is placed
    # at the mean SNR of its own bin.
    plt_measured_positions = plt_measured_snrs
    plt_synthetic_positions = plt_synthetic_snrs

else:
    plt_snr_steps = np.linspace(0, 105, plt_bins + 1)
    plt_snr_inds_measured = [np.where(
        (measured_snrs > plt_snr_steps[i]) & (measured_snrs <= plt_snr_steps[i+1])
    ) for i in range(plt_bins)]
    plt_measured_maes = [measured_maes_arr[each] for each in plt_snr_inds_measured]
    plt_measured_mean_maes = [np.mean(each) for each in plt_measured_maes]
    plt_measured_snrs = [np.mean(measured_snrs[each]) for each in plt_snr_inds_measured]

    plt_snr_inds_synthetic = [np.where(
        (synthetic_snrs > plt_snr_steps[i]) & (synthetic_snrs <= plt_snr_steps[i+1])
    ) for i in range(plt_bins)]
    plt_synthetic_maes = [synthetic_maes_arr[each] for each in plt_snr_inds_synthetic]
    plt_synthetic_mean_maes = [np.mean(each) for each in plt_synthetic_maes]
    plt_synthetic_snrs = [np.mean(synthetic_snrs[each]) for each in plt_snr_inds_synthetic]

    # Integer midpoint of every equal-width bin; shared by both plots.
    plt_snr_labels = [int(np.mean([plt_snr_steps[i], plt_snr_steps[i+1]])) for i in range(plt_bins)]
    plt_measured_positions = plt_snr_labels
    plt_synthetic_positions = plt_snr_labels


xticks = list(range(0, 101, 10))
xticklabels = [f"{each:.0f}" for each in xticks]

# Same figure layout for both data sets; synthetic first, then measured.
for box_data, box_positions, out_path in (
    (plt_synthetic_maes, plt_synthetic_positions, "./imgs/mae_vs_snr_synthetic.png"),
    (plt_measured_maes, plt_measured_positions, "./imgs/mae_vs_snr_measured.png"),
):
    fig, ax = plt.subplots()
    ax.boxplot(
        box_data,
        positions=box_positions,
        vert=True,
        widths=5
    )
    ax.set_xticks(xticks)
    ax.set_xticklabels(xticklabels)
    ax.set_xlabel("SNR")
    ax.set_ylabel("MAE")

    fig.savefig(out_path, dpi=300)
    plt.show()

In [None]:
"""Plot predicted proportions alongside true proportions for simulated/measured scenarios."""
# For each scenario (synthetic, measured) draw one subplot per isotope with
# the true and predicted proportions of the first `n` test samples.
n = 1000
iso_names = list(test_fg_measurement_ss.prediction_probas.groupby(
    level="Isotope",
    axis=1
).sum().columns)

for truth, estimate, out_path in (
    (test_synthetic_y, test_synthetic_preds, "imgs/pred_props_synthetic.png"),
    (test_measurement_y, test_measurement_preds, "imgs/pred_props_measured.png"),
):
    fig, axes = plt.subplots(len(iso_names), 1, sharey=True, figsize=(8, 10))
    # plt.subplots returns a bare Axes when only one row is requested.
    axes = np.atleast_1d(axes)

    # Iterate over the isotope names that size the subplot grid, rather than
    # over the column count of an unrelated array from an earlier cell.
    for i, iso_name in enumerate(iso_names):
        axes[i].plot(truth[:n, i], marker="o", label="true")
        axes[i].plot(estimate[:n, i], marker="o", label="prediction", alpha=0.5)

        axes[i].legend()
        axes[i].set_title(f"{iso_name}")

    fig.supxlabel("Test Sample")
    fig.supylabel("Proportion")
    fig.tight_layout()
    fig.savefig(out_path, dpi=300)
    plt.show()

In [None]:
"""Plot learning curve."""
# Draw training and validation loss without showing the figure yet, so the
# title can be cleared before the image is written.
loss_history = model.history
fig, ax = plot_learning_curve(
    train_loss=loss_history["loss"],
    validation_loss=loss_history["val_loss"],
    ylim=(None, None),
    yscale="linear",
    show=False,
)
ax.set(title=None)
fig.savefig("./imgs/learning_curve.png", dpi=300)
plt.show()

In [None]:
"""Evaluate models on range of beta."""
# Evaluate every ONNX model in the beta-sweep directory on the synthetic test
# set, recording its beta (parsed from the file name), its MAE, and its mean
# reconstruction error.
beta_models_dir = "./beta_models/models_sparsemax_jsd_20231012-131200"
beta_model_paths = [each for each in os.listdir(beta_models_dir) if each.endswith("onnx")]
beta_models = []
beta_maes = []
beta_recon_errors = []
betas = []

# File names are expected to contain "beta<value>_trial".
beta_pattern = re.compile(r"beta(.*?)_trial")

for each in beta_model_paths:
    beta = float(beta_pattern.search(each).group(1))
    betas.append(beta)
    beta_model = LabelProportionEstimator()

    beta_model.load(os.path.join(beta_models_dir, each))
    # predict() writes its predictions and loss column onto the shared
    # SampleSet, replacing whatever an earlier model stored there.
    beta_model.predict(test_fg_synthetic_ss)
    beta_models.append(beta_model)

    beta_preds = test_fg_synthetic_ss.prediction_probas.groupby(
        level="Isotope",
        axis=1
    ).sum().values

    beta_maes.append(mean_absolute_error(test_synthetic_y, beta_preds))

    beta_recon_errors.append(
        np.mean(test_fg_synthetic_ss.info[beta_model.unsup_loss_func_name])
    )

# Restore the selected model's outputs on the shared SampleSet. Otherwise
# later cells that read test_fg_synthetic_ss.info[model.unsup_loss_func_name]
# would see the values of whichever beta model happened to run last.
model.predict(test_fg_synthetic_ss)

In [None]:
"""Plot results of beta models."""
# Aggregate the sweep results per distinct beta (mean and std over trials)
# and plot MAE and reconstruction error against beta.
plt_x = sorted(set(betas))
betas_arr = np.array(betas)
beta_inds = [np.where(betas_arr == each)[0] for each in plt_x]

beta_maes_arr = np.array(beta_maes)
beta_recon_errors_arr = np.array(beta_recon_errors)
beta_mae_means = [np.mean(beta_maes_arr[each]) for each in beta_inds]
beta_mae_stds = [np.std(beta_maes_arr[each]) for each in beta_inds]
beta_recon_error_means = [np.mean(beta_recon_errors_arr[each]) for each in beta_inds]
beta_recon_error_stds = [np.std(beta_recon_errors_arr[each]) for each in beta_inds]

# Both figures share one layout; the shaded band is labelled in both so the
# two legends agree.
for band_means, band_stds, y_label, out_path in (
    (beta_mae_means, beta_mae_stds, "MAE", "imgs/mae_vs_beta.png"),
    (beta_recon_error_means, beta_recon_error_stds, "Reconstruction Error", "imgs/recon_error_vs_beta.png"),
):
    band_means = np.array(band_means)
    band_stds = np.array(band_stds)

    fig, ax = plt.subplots()

    ax.plot(plt_x, band_means, color="black", linestyle="--", label="mean")
    ax.fill_between(
        plt_x,
        band_means - band_stds,
        band_means + band_stds,
        label="+/- std. dev."
    )

    ax.xaxis.set_major_formatter(ScalarFormatter())

    # Vertical marker at the beta of the model used in the rest of the notebook.
    ax.axvline(0.85, color="black", linestyle=":", label="selected β")
    ax.set_ylabel(y_label)
    ax.legend()

    ax.set_xlabel("β")
    fig.savefig(out_path, dpi=300)
    plt.show()

In [None]:
"""Plot tanh function."""
# Sample tanh(5x) on [-0.5, 0.5].
x = np.linspace(-0.5, 0.5, 10000)
tanh_values = tf.math.tanh(5*x).numpy()

fig, ax = plt.subplots()
# Horizontal and vertical reference lines through the origin.
ax.axhline(0, color="black")
ax.axvline(0, color="black")
ax.plot(x, tanh_values)
ax.set_xlabel("x")
ax.set_ylabel("tanh(5*x)")
ax.set_title("Hyperbolic Tangent Function")
fig.savefig("./imgs/tanh.png", dpi=300)
plt.show()

In [None]:
"""Generate distribution of training proportions."""
train_ss = read_hdf(f"./{TRAIN_FILE}")

# Source proportions summed over the "Isotope" column level.
train_source_df = train_ss.sources.groupby(axis=1, level="Isotope").sum()
isotope_columns = list(train_source_df.columns)
proportion_matrix = train_source_df.values

# One log-scaled histogram per panel of the 2x3 grid, strictly positive
# proportions only.
fig, axes = plt.subplots(2, 3, figsize=(10, 6), sharex=True, sharey=True)
for idx, ax in enumerate(axes.reshape(-1)):
    column_values = proportion_matrix[:, idx]
    ax.hist(column_values[column_values > 0.0], bins=20)
    ax.set_title(isotope_columns[idx])
    ax.set_yscale("log")
fig.supylabel("Counts")
fig.supxlabel("Source Proportion")
fig.tight_layout()
fig.savefig("./imgs/train_data_distributions.png", dpi=300)
plt.show()

In [None]:
"""Plot distribution of mix sizes."""
# Count, for every training sample, how many sources contribute more than
# `mix_threshold`, and show how many samples have 1, 2, or 3 such sources.
mix_threshold = 0.05
mix_sizes = np.sum(train_ss.sources.values > mix_threshold, axis=1)
x_mixes = [1, 2, 3]
y_mixes = [np.sum(mix_sizes == each) for each in x_mixes]

fig, ax = plt.subplots()
ax.bar(x_mixes, y_mixes)
ax.set_ylabel("# of Samples")
# The label is derived from the threshold actually applied above; it used to
# read "> 0.1" while the counts were computed with 0.05.
ax.set_xlabel(f"# of Sources with Contribution > {mix_threshold}")
ax.set_yscale("log")

ax.set_xticks([1.0, 2.0, 3.0])
ax.set_xticklabels([1, 2, 3])
ax.set_yticks([1e4, 1e5, 1e6])

# Write each count just above its bar.
for x_pos, count in zip(x_mixes, y_mixes):
    ax.text(x_pos-0.15, count+0.07*count, str(count))

fig.savefig("./imgs/mix_sizes.png", dpi=300)
plt.show()

In [None]:
"""Visualize random sample."""
# Draw one training sample at random and show its sources and info rows.
random_sample = np.random.randint(train_ss.n_samples)
for frame in (train_ss.sources, train_ss.info):
    display(frame.iloc[random_sample])

# Plot its spectrum, then strip legend and title before saving.
fig, ax = plot_spectra(
    train_ss[random_sample],
    in_energy=True,
    legend_loc=None,
    show=False,
    title="Sample Training Spectrum",
)
ax.get_legend().remove()
ax.set(title=None)
fig.savefig("./imgs/sample_training_spectrum.png", dpi=300)
plt.show()

In [None]:
"""Load in OOD test data."""
# Read the six OOD test files, normalize each (p=1), and run the model on each.
print("loading in data")
ood_file_names = (
    "Bi207_ood_test.h5",
    "Cs137_ood_test.h5",
    "K40_ood_test.h5",
    "Ra226_ood_test.h5",
    "Th232_ood_test.h5",
    "Cosmic_ood_test.h5",
)
ood_sets = [read_hdf(os.path.join(DATA_DIR, file_name)) for file_name in ood_file_names]
(
    bi207_ood_ss,
    cs137_ood_ss,
    k40_ood_ss,
    ra226_ood_ss,
    th232_ood_ss,
    cosmic_ood_ss,
) = ood_sets

print("normalizing...")
for ood_ss in ood_sets:
    ood_ss.normalize(p=1)

print("predicting...")
for ood_ss in ood_sets:
    model.predict(ood_ss)

# Group the sets: K40/Ra226/Th232/Cosmic into one SampleSet, Bi207/Cs137 into
# another.
print("combining...")
bg_ood_ss = SampleSet()
measured_ood_ss = SampleSet()

bg_ood_ss.concat([k40_ood_ss, ra226_ood_ss, th232_ood_ss, cosmic_ood_ss])
measured_ood_ss.concat([bi207_ood_ss, cs137_ood_ss])

In [None]:
"""Plot recon error vs. OOD contribution."""
n_bins = 10

# Equal-width bin edges over the OOD contribution range [0, 1].
ood_contrib_steps = np.linspace(0.0, 1.0, n_bins+1)

# Reconstruction error per sample and the largest value among the selected
# source columns. NOTE(review): the column indices are hard-coded and
# presumably pick the OOD sources of each set -- confirm against the data.
bg_recon_errors = bg_ood_ss.info[model.unsup_loss_func_name].values
bg_ood_contribs = np.max(bg_ood_ss.sources.values[:, [0, 5, 6, 7]], axis=1)

measured_recon_errors = measured_ood_ss.info[model.unsup_loss_func_name].values
measured_ood_contribs = np.max(measured_ood_ss.sources.values[:, [0, 5]], axis=1)

# Sample indices falling into each (lower, upper] contribution bin.
bin_edges = list(zip(ood_contrib_steps[:-1], ood_contrib_steps[1:]))
bg_ood_contrib_inds = [
    np.flatnonzero((bg_ood_contribs > lower) & (bg_ood_contribs <= upper))
    for lower, upper in bin_edges
]
measured_ood_contrib_inds = [
    np.flatnonzero((measured_ood_contribs > lower) & (measured_ood_contribs <= upper))
    for lower, upper in bin_edges
]
bg_recon_error_steps = [bg_recon_errors[each] for each in bg_ood_contrib_inds]
measured_recon_error_steps = [measured_recon_errors[each] for each in measured_ood_contrib_inds]

# Boxes sit at the bin centres.
contrib_x = 0.5*(ood_contrib_steps[1:] + ood_contrib_steps[:-1])

xticks = np.linspace(0.0, 1.0, n_bins+1)
xticklabels = [f"{each:.1f}" for each in xticks]

# One box plot per data set, identical layout.
for recon_error_steps, out_path in (
    (bg_recon_error_steps, "./imgs/recon_error_vs_ood_contrib_bg.png"),
    (measured_recon_error_steps, "./imgs/recon_error_vs_ood_contrib_measured.png"),
):
    fig, ax = plt.subplots()

    ax.boxplot(recon_error_steps, positions=contrib_x, widths=0.03)
    ax.set_yscale("linear")
    ax.set_xticks(xticks)
    ax.set_xticklabels(xticklabels)
    ax.set_xlim((0.0, 1.0))
    ax.set_xlabel("OOD Contribution")
    ax.set_ylabel("Reconstruction Error")
    fig.savefig(out_path, dpi=300)
    plt.show()

In [None]:
"""Fit spline for snr and recon error."""
# Split the synthetic test samples into `nbins` SNR quantile bins.
nbins = 15
out, bins = pd.qcut(synthetic_snrs, nbins, labels=False, retbins=True)
bins = list(bins)

# NOTE(review): this column is rewritten by every predict() call on
# test_fg_synthetic_ss -- confirm it currently holds `model`'s values.
synthetic_recon_errors = test_fg_synthetic_ss.info[model.unsup_loss_func_name]

# Per bin: the 0.99 quantile of the reconstruction error and the mean SNR.
thresholds = []
avg_snrs = []
for i in range(nbins):
    in_bin = out == int(i)
    thresholds.append(np.quantile(synthetic_recon_errors.values[in_bin], 0.99))
    avg_snrs.append(np.mean(synthetic_snrs[in_bin]))

# Cubic spline through the per-bin points (s=0).
snr_spl = UnivariateSpline(avg_snrs, thresholds, k=3, s=0)

fig, ax = plt.subplots()
ax.scatter(synthetic_snrs, synthetic_recon_errors, alpha=0.4, label="synthetic test sample")
ax.scatter(avg_snrs, thresholds, label="bins", marker="x")
ax.plot(bins, snr_spl(bins), label="U-Spline threshold function", linestyle="--", color="red")
ax.set(xlabel="SNR", ylabel="Reconstruction Error")
ax.legend()
fig.savefig("./imgs/threshold_function_training.png", dpi=300)
plt.show()

In [None]:
"""Plot OOD heatmaps."""
# Heat maps of the OOD false-negative rate over an (SNR, OOD proportion) grid
# for the two OOD data sets. A sample is flagged as OOD when its
# reconstruction error exceeds the spline threshold at its SNR.
snr_bins = 20
ood_prop_bins = 20

# NOTE(review): 300 looks like an assumed background count rate -- confirm.
measured_ood_snrs = measured_ood_ss.info.total_counts.values / np.sqrt(measured_ood_ss.info.live_time.values * 300)
bg_ood_snrs = bg_ood_ss.info.total_counts.values / np.sqrt(bg_ood_ss.info.live_time.values * 300)

measured_ood_decisions = measured_recon_errors > snr_spl(measured_ood_snrs)
bg_ood_decisions = bg_recon_errors > snr_spl(bg_ood_snrs)

snr_steps = np.linspace(0, 100, snr_bins+1)
ood_prop_steps = np.linspace(0, 1.0, ood_prop_bins+1)


def _false_negative_grid(snrs, contribs, decisions):
    """Return the (snr_bins, ood_prop_bins) grid of false-negative rates.

    Cell (i, j) holds 1 - mean(decisions) over the samples whose SNR lies in
    (snr_steps[i], snr_steps[i+1]] and whose OOD contribution lies in
    (ood_prop_steps[j], ood_prop_steps[j+1]]. Empty cells start as NaN and
    are then filled by pandas' default `DataFrame.interpolate()`.
    """
    grid = np.full((snr_bins, ood_prop_bins), np.nan)
    for i in range(snr_bins):
        # The SNR selection depends only on i, so do it once per row.
        in_snr_bin = (snrs > snr_steps[i]) & (snrs <= snr_steps[i+1])
        row_contribs = contribs[in_snr_bin]
        row_decisions = decisions[in_snr_bin]
        for j in range(ood_prop_bins):
            in_prop_bin = (
                (row_contribs > ood_prop_steps[j]) &
                (row_contribs <= ood_prop_steps[j+1])
            )
            selected = row_decisions[in_prop_bin]
            # Leave NaN for empty cells instead of taking the mean of an
            # empty slice (same value, no RuntimeWarning).
            if selected.size:
                grid[i, j] = 1.0 - selected.mean()
    return pd.DataFrame(grid).interpolate().values


def _plot_fnr_heatmap(fns, out_path, **imshow_kwargs):
    """Draw one false-negative-rate heat map with a colour bar and save it."""
    fig, ax = plt.subplots()
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    im = ax.imshow(fns.T, extent=[0, 100, 0, 1.0], origin="lower", aspect="auto", **imshow_kwargs)
    ax.set_ylim((0, 0.9))
    ax.set_xlabel("SNR")
    ax.set_ylabel("OOD Proportion")
    cb = fig.colorbar(im, cax=cax, orientation="vertical")
    cb.set_label("OOD FNR")
    fig.tight_layout()
    fig.savefig(out_path, dpi=300)
    plt.show()
    return fig, ax


measured_fns = _false_negative_grid(measured_ood_snrs, measured_ood_contribs, measured_ood_decisions)
# NOTE(review): only this figure requests interpolation="nearest"; the second
# one keeps matplotlib's default, as before -- confirm that is intended.
fig, ax = _plot_fnr_heatmap(measured_fns, "./imgs/ood_plot_measured.png", interpolation="nearest")

bg_fns = _false_negative_grid(bg_ood_snrs, bg_ood_contribs, bg_ood_decisions)
fig, ax = _plot_fnr_heatmap(bg_fns, "./imgs/ood_plot_bg.png")

In [None]:
"""Generate spectral distance matrix."""

# Select a subset of the measured foreground spectra and drop source columns
# that are zero for every selected sample.
fg_measurement_seed_inds = [3, 7, 10, 14, 18, 21]
fg_measurement_seeds_ss = fg_measurements_ss[fg_measurement_seed_inds]
fg_measurement_seeds_ss.drop_sources_columns_with_all_zeros()

# Combine the ID seeds, background seeds, and selected measurements.
ood_id_seeds_ss = SampleSet()
ood_id_seeds_ss.concat([id_seeds_ss, bg_seeds_ss, fg_measurement_seeds_ss])

# For every source column take the row label of its maximum, and reorder all
# three tables by those values.
sort_indices = ood_id_seeds_ss.sources.idxmax().values


def _reordered(table):
    """Return `table` with rows taken in `sort_indices` order and a fresh index."""
    return table.iloc[sort_indices].reset_index(drop=True)


ood_id_seeds_ss.sources = _reordered(ood_id_seeds_ss.sources)
ood_id_seeds_ss.spectra = _reordered(ood_id_seeds_ss.spectra)
ood_id_seeds_ss.info = _reordered(ood_id_seeds_ss.info)

# Write the distance matrix as HTML and CSV.
sdm = ood_id_seeds_ss.get_spectral_distance_matrix()
sdm.to_html("./spectral_distances.html")
sdm.to_csv("./spectral_distances.csv")

In [None]:
"""Plot synthetic fg seeds."""
# One figure per foreground seed; the file name uses element 1 of the i-th
# sources column key.
for i in range(id_seeds_ss.n_samples):
    seed_name = list(id_seeds_ss.sources.columns[i])[1]
    fig, ax = plot_spectra(id_seeds_ss[i], show=False, in_energy=True)
    ax.set_title(f"{seed_name}")
    # The title is cleared again before the figure is saved.
    ax.set(title=None)
    fig.savefig(f"./imgs/synthetic_{seed_name}.png", dpi=300)
    plt.show()

In [None]:
"""Plot synthetic bg seeds."""
# One figure per background seed; the file name uses element 2 of the i-th
# sources column key.
for i in range(bg_seeds_ss.n_samples):
    seed_name = list(bg_seeds_ss.sources.columns[i])[2]
    fig, ax = plot_spectra(bg_seeds_ss[i], show=False, in_energy=True)
    ax.set_title(f"{seed_name}")
    # The title is cleared again before the figure is saved.
    ax.set(title=None)
    fig.savefig(f"./imgs/synthetic_bg_{seed_name}.png", dpi=300)
    plt.show()

In [None]:
"""Plot measurements."""
gross_labels = np.array(gross_measurements_ss.get_labels())

# One figure per distinct label, containing every gross measurement with it.
unique_labels = list(set(list(gross_measurements_ss.get_labels())))
for each in unique_labels:
    measured_ss = gross_measurements_ss[gross_labels == each]

    fig, ax = plot_spectra(
        measured_ss,
        ylim=(None, None),
        in_energy=True,
        target_level="Seed",
        show=False,
        title=f"Gross Measurements for {each}",
    )
    # The title is cleared before the figure is saved.
    ax.set(title=None)
    out_path = f"./imgs/measured_{each}.png"
    fig.savefig(out_path, dpi=300)
    plt.show()

In [None]:
"""Plot measured bg."""
# Plot the measured background on an energy axis without showing it yet.
fig, ax = plot_spectra(
    bg_measurements_ss,
    show=False,
    in_energy=True,
)
ax.set_title("Background Measurement")
# The title is cleared again before the figure is saved.
ax.set(title=None)
fig.savefig("./imgs/measured_bg.png", dpi=300)
plt.show()