    Copyright 2024 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
    Under the terms of Contract DE-NA0003525 with NTESS, the U.S. Government retains certain rights in this software.

**Step 7 - compute results on measured data from PNNL**

In [None]:
"""Imports"""

import glob
import os
import re

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from config import (FG_SEED_FILE, IMAGE_DIR, PNNL_DATA_DIR, PNNL_FG_FILE,
                    PNNL_MEASUREMENTS_FILE, TARGET_BINS, TARGET_ECAL)
from riid.data.sampleset import (SampleSet, SpectraState, SpectraType,
                                 read_hdf, read_pcf)
from riid.losses import (chi_squared_diff, jensen_shannon_divergence,
                         poisson_nll_diff, sse_diff)
from riid.visualize import plot_spectra
from scipy.optimize import nnls
from scipy.spatial.distance import jensenshannon
from utils import load_final_models

In [None]:
"""Focus unsupervised loss."""

# Display names for the unsupervised losses, indexed by position. A raw string is
# used so "\c" in the LaTeX is not parsed as an (invalid) escape sequence, which
# Python 3.12+ reports as a SyntaxWarning.
plt_names = [r"$\chi^2$", "JSD", "PNLL", "SSE"]
# Position of the loss to focus on (used to index plt_names and the per-loss model results).
focus_unsup_loss = 1  # JSD

In [None]:
"""Link each measurement spectrum to the bg spectrum subtracted from it.

Keys are measurement filenames (lower-cased, without extension); values are the
filenames of the matching background spectra (compared case-insensitively).
"""

bg_links = {
    'u-235_fiss_pnnl_6-hl_cnf':         'Bkg_K0PSV2ML_080811',
    'u-235_fiss_68pct_24-min_cnf':      'B1109173',
    'u-235_fiss_pnnl_6-h_cnf':          'Bkg_K0PSV2ML_080811',
    'u-235_fiss_pnnl_148-h_cnf':        'Bkg_K0PSV2ML_080811',
    'u-235_fiss_68pct_173-min_cnf':     'B1109173',
    'u-235_fiss_68pct_57-min_cnf':      'B1109173',
    'u-235_fiss_pnnl_7-h_cnf':          'Bkg_C0FILPVL',
    'u-235_fiss_68pct_91-min_cnf':      'B1109173',
    'u-235_fiss_68pct_444-min_cnf':     'B1109173',
    'u-235_fiss_68pct_35-min_cnf':      'B1109173',
    'u-235_fiss_68pct_12-min_cnf':      'B1109173',
    'u-235_fiss_68pct_236-min_cnf':     'B1109173',
    'u-235_fiss_68pct_5-min_cnf':       'B1109173',
    'u-235_fiss_68pct_107-min_cnf':     'B1109173',
    'u-235_fiss_pnnl_30-h_cnf':         'Bkg_K0PSV2ML_080811',
    'u-235_fiss_pnnl_24-h_cnf':         'Bkg_K0PSV2ML_080811',
    'u-235_fiss_68pct_46-min_cnf':      'B1109173',
    'u-235_fiss_68pct_139-min_cnf':     'B1109173',
    'u-235_fiss_68pct_18-min_cnf':      'B1109173',
    'u-235_fiss_68pct_74-min_cnf':      'B1109173',
    'np-237_epi_68pct_16-min_cnf':      'B1109173',
    'np-237_epi_68pct_87-min_cnf':      'B1109173',
    'np-237_epi_68pct_103-min_cnf':     'B1109173',
    'np-237_epi_68pct_44-min_cnf':      'B1109173',
    'np-237_epi_pnnl_100-h_cnf':        'Bkg_H0960',
    'np-237_epi_68pct_22-min_cnf':      'B1109173',
    'np-237_epi_68pct_10-min_cnf':      'B1109173',
    'np-237_epi_68pct_4-min_cnf':       'B1109173',
    'np-237_epi_pnnl_294-h_cnf':        'Bkg_H0960',
    'np-237_epi_68pct_167-min_cnf':     'B1109173',
    'np-237_epi_pnnl_241-h_cnf':        'Bkg_H0960',
    'np-237_epi_68pct_55-min_cnf':      'B1109173',
    'np-237_epi_68pct_230-min_cnf':     'B1109173',
    'np-237_epi_68pct_293-min_cnf':     'B1109173',
    'np-237_epi_pnnl_247-h_cnf':        'Bkg_H0960',
    'np-237_epi_68pct_135-min_cnf':     'B1109173',
    'np-237_epi_68pct_33-min_cnf':      'B1109173',
    'np-237_epi_68pct_71-min_cnf':      'B1109173',
    'np-237_epi_pnnl_6-h_cnf':          'Bkg_H0960',
    'pu-239_th_pnnl_7-h_cnf':           'Bkg_E0PSV2ML_080811',
    'pu-239_th_68pct_100-min_cnf':      'B1109173',
    'pu-239_th_68pct_31-min_cnf':       'B1109173',
    'pu-239_th_68pct_42-min_cnf':       'B1109173',
    'pu-239_th_68pct_192-min_cnf':      'B1109173',
    'pu-239_th_68pct_131-min_cnf':      'B1109173',
    'pu-239_th_pnnl_77-h_cnf':          'Bkg_E0PSV2ML_080811',
    'pu-239_th_68pct_15-min_cnf':       'B1109173',
    'pu-239_th_68pct_84-min_cnf':       'B1109173',
    'pu-239_th_pnnl_170-h_cnf':         'Bkg_E0PSV2ML_080811',
    'pu-239_th_68pct_69-min_cnf':       'B1109173',
    'pu-239_th_68pct_4-min_cnf':        'B1109173',
    'pu-239_th_pnnl_55-h_cnf':          'Bkg_E0PSV2ML_080811',
    'pu-239_th_68pct_253-min_cnf':      'B1109173',
    'pu-239_th_pnnl_8-h_cnf':           'Bkg_E0PSV2ML_080811',
    'pu-239_th_68pct_20-min_cnf':       'B1109173',
    'pu-239_th_68pct_313-min_cnf':      'B1109173',
    'pu-239_th_pnnl_54-h_cnf':          'Bkg_E0PSV2ML_080811',
    'pu-239_th_68pct_53-min_cnf':       'B1109173',
    'pu-239_th_68pct_162-min_cnf':      'B1109173',
    'pu-239_th_pnnl_28-h_cnf':          'Bkg_E0PSV2ML_080811',
    'pu-239_th_68pct_9-min_cnf':        'B1109173',
    'pu-239_epi_pnnl_56-h_cnf':         'Bkg_D0PSV10MLW',
    'pu-239_epi_pnnl_6-hl_cnf':         'Bkg_D0PSV10MLW',
    'pu-239_epi_68pct_285-min_cnf':     'B1109173',
    'pu-239_epi_68pct_4-min_cnf':       'B1109173',
    'pu-239_epi_68pct_87-min_cnf':      'B1109173',
    'pu-239_epi_68pct_21-min_cnf':      'B1109173',
    'pu-239_epi_68pct_164-min_cnf':     'B1109173',
    'pu-239_epi_68pct_44-min_cnf':      'B1109173',
    'pu-239_epi_68pct_133-min_cnf':     'B1109173',
    'pu-239_epi_68pct_102-min_cnf':     'B1109173',
    'pu-239_epi_68pct_10-min_cnf':      'B1109173',
    'pu-239_epi_68pct_15-min_cnf':      'B1109173',
    'pu-239_epi_68pct_225-min_cnf':     'B1109173',
    'pu-239_epi_68pct_32-min_cnf':      'B1109173',
    'pu-239_epi_pnnl_6-h_cnf':          'Bkg_D0PSV10MLW',
    'pu-239_epi_68pct_55-min_cnf':      'B1109173',
    'pu-239_epi_68pct_71-min_cnf':      'B1109173',
    'pu-239_epi_pnnl_344-h_cnf':        'Bkg_D0PSV10MLW',
    'pu-239_epi_pnnl_202-h_cnf':        'Bkg_D0PSV10MLW',
    'np-237_fiss_68pct_103-min_cnf':    'B1109173',
    'np-237_fiss_68pct_54-min_cnf':     'B1109173',
    'np-237_fiss_68pct_21-min_cnf':     'B1109173',
    'np-237_fiss_pnnl_322-h_cnf':       'Bkg_C0FILPVL',  # should it be 321 hr?
    'np-237_fiss_68pct_176-min_cnf':    'B1109173',
    'np-237_fiss_pnnl_34-h_cnf':        'Bkg_C0FILPVL',  # ?
    'np-237_fiss_68pct_10-min_cnf':     'B1109173',
    'np-237_fiss_pnnl_27-h_cnf':        'Bkg_C0FILPVL',  # ?
    'np-237_fiss_68pct_298-min_cnf':    'B1109173',
    'np-237_fiss_pnnl_6-h_cnf':         'Bkg_C0FILPVL',
    'np-237_fiss_68pct_15-min_cnf':     'B1109173',
    'np-237_fiss_pnnl_56-h_cnf':        'Bkg_C0FILPVL',  # ?
    'np-237_fiss_68pct_32-min_cnf':     'B1109173',
    'np-237_fiss_68pct_70-min_cnf':     'B1109173',
    'np-237_fiss_68pct_86-min_cnf':     'B1109173',
    'np-237_fiss_pnnl_9-h_cnf':         'Bkg_C0FILPVL',  # ?
    'np-237_fiss_68pct_43-min_cnf':     'B1109173',
    'np-237_fiss_68pct_4-min_cnf':      'B1109173',
    'np-237_fiss_68pct_136-min_cnf':    'B1109173',
    'np-237_fiss_68pct_233-min_cnf':    'B1109173',
    'np-237_fiss_pnnl_174-h_cnf':       'Bkg_C0FILPVL',  # ?
    'u-238_epi_68pct_86-min_cnf':       'B1109173',
    'u-238_epi_68pct_151-min_cnf':      'B1109173',
    'u-238_epi_68pct_29-min_cnf':       'B1109173',
    'u-238_epi_68pct_246-min_cnf':      'B1109173',
    'u-238_epi_68pct_42-min_cnf':       'B1109173',
    'u-238_epi_pnnl_6-h_cnf':           'Bkg_H0960',
    'u-238_epi_68pct_69-min_cnf':       'B1109173',
    'u-238_epi_68pct_4-min_cnf':        'B1109173',
    'u-238_epi_68pct_118-min_cnf':      'B1109173',
    'u-238_epi_pnnl_105-h_cnf':         'Bkg_H0960',
    'u-238_epi_68pct_22-min_cnf':       'B1109173',
    'u-238_epi_68pct_13-min_cnf':       'B1109173',
    'u-238_epi_pnnl_75-h_cnf':          'Bkg_H0960',
    'u-238_epi_68pct_183-min_cnf':      'B1109173',
    'u-238_epi_68pct_102-min_cnf':      'B1109173',
    'u-238_epi_68pct_54-min_cnf':       'B1109173',
    'u-238_epi_68pct_309-min_cnf':      'B1109173',
    'u-233_th_68pct_17-min_cnf':        'B1109173',
    'u-233_th_68pct_106-min_cnf':       'B1109173',
    'u-233_th_68pct_137-min_cnf':       'B1109173',
    'u-233_th_68pct_291-min_cnf':       'B1109173',
    'u-233_th_pnnl_9-h_cnf':            'Bkg_E0PSV2ML_080811',
    'u-233_th_68pct_89-min_cnf':        'B1109173',
    'u-233_th_68pct_11-min_cnf':        'B1109173',
    'u-233_th_68pct_23-min_cnf':        'B1109173',
    'u-233_th_68pct_4-min_cnf':         'B1109173',
    'u-233_th_68pct_34-min_cnf':        'B1109173',
    'u-233_th_68pct_44-min_cnf':        'B1109173',
    'u-233_th_pnnl_7-h_cnf':            'Bkg_E0PSV2ML_080811',
    'u-233_th_pnnl_52-h_cnf':           'Bkg_E0PSV2ML_080811',
    'u-233_th_pnnl_23-h_cnf':           'Bkg_E0PSV2ML_080811',
    'u-233_th_68pct_229-min_cnf':       'B1109173',
    'u-233_th_68pct_73-min_cnf':        'B1109173',
    'u-233_th_pnnl_30-h_cnf':           'Bkg_E0PSV2ML_080811',
    'u-233_th_68pct_57-min_cnf':        'B1109173',
    'u-233_th_68pct_168-min_cnf':       'B1109173',
    'u-233_epi_pnnl_29-h_cnf':          'Bkg_E0PSV2ML_080811',
    'u-233_epi_68pct_306-min_cnf':      'B1109173',
    'u-233_epi_68pct_10-min_cnf':       'B1109173',
    'u-233_epi_68pct_22-min_cnf':       'B1109173',
    'u-233_epi_pnnl_11-h_cnf':          'Bkg_H0960',
    'u-233_epi_68pct_32-min_cnf':       'B1109173',
    'u-233_epi_68pct_70-min_cnf':       'B1109173',
    'u-233_epi_68pct_102-min_cnf':      'B1109173',
    'u-233_epi_68pct_245-min_cnf':      'B1109173',
    'u-233_epi_pnnl_82-h_cnf':          'Bkg_H0960',
    'u-233_epi_68pct_54-min_cnf':       'B1109173',
    'u-233_epi_68pct_149-min_cnf':      'B1109173',
    'u-233_epi_68pct_16-min_cnf':       'B1109173',
    'u-233_epi_68pct_43-min_cnf':       'B1109173',
    'u-233_epi_68pct_4-min_cnf':        'B1109173',
    'u-233_epi_68pct_184-min_cnf':      'B1109173',
    'u-233_epi_pnnl_6-h_cnf':           'Bkg_H0960',
    'u-233_epi_68pct_86-min_cnf':       'B1109173',
    'u-233_fiss_68pct_14-min_cnf':      'B1109173',
    'u-233_fiss_68pct_85-min_cnf':      'B1109173',
    'u-233_fiss_68pct_53-min_cnf':      'B1109173',
    'u-233_fiss_68pct_9-min_cnf':       'B1109173',
    'u-233_fiss_pnnl_6-h_cnf':          'Bkg_E0PSV2ML_080811',
    'u-233_fiss_pnnl_50-h_cnf':         'Bkg_E0PSV2ML_080811',
    'u-233_fiss_68pct_20-min_cnf':      'B1109173',
    'u-233_fiss_68pct_227-min_cnf':     'B1109173',
    'u-233_fiss_pnnl_77-h_cnf':         'Bkg_K0PSV2ML_080811',
    'u-233_fiss_68pct_69-min_cnf':      'B1109173',
    'u-233_fiss_68pct_4-min_cnf':       'B1109173',
    'u-233_fiss_68pct_102-min_cnf':     'B1109173',
    'u-233_fiss_68pct_42-min_cnf':      'B1109173',
    'u-233_fiss_68pct_133-min_cnf':     'B1109173',
    'u-233_fiss_68pct_31-min_cnf':      'B1109173',
    'u-233_fiss_68pct_289-min_cnf':     'B1109173',
    'u-233_fiss_68pct_164-min_cnf':     'B1109173',
    'u-235_epi_68pct_134-min_cnf':      'B1109173',
    'u-235_epi_pnnl_6-h_cnf':           'Bkg_E0PSV2ML_080811',
    'u-235_epi_68pct_72-min_cnf':       'B1109173',
    'u-235_epi_68pct_228-min_cnf':      'B1109173',
    'u-235_epi_pnnl_58-h_cnf':          'Bkg_E0PSV2ML_080811',
    'u-235_epi_68pct_165-min_cnf':      'B1109173',
    'u-235_epi_68pct_46-min_cnf':       'B1109173',
    'u-235_epi_68pct_23-min_cnf':       'B1109173',
    'u-235_epi_68pct_4-min_cnf':        'B1109173',
    'u-235_epi_pnnl_29-h_cnf':          'Bkg_E0PSV2ML_080811',
    'u-235_epi_68pct_16-min_cnf':       'B1109173',
    'u-235_epi_68pct_34-min_cnf':       'B1109173',
    'u-235_epi_68pct_10-min_cnf':       'B1109173',
    'u-235_epi_68pct_104-min_cnf':      'B1109173',
    'u-235_epi_pnnl_8-h_cnf':           'Bkg_E0PSV2ML_080811',
    'u-235_epi_pnnl_20-h_cnf':          'Bkg_E0PSV2ML_080811',
    'u-235_epi_68pct_88-min_cnf':       'B1109173',
    'u-235_epi_68pct_289-min_cnf':      'B1109173',
    'u-235_epi_68pct_57-min_cnf':       'B1109173',
    'pu-239_fiss_68pct_71-min_cnf':     'B1109173',
    'pu-239_fiss_68pct_228-min_cnf':    'B1109173',
    'pu-239_fiss_68pct_103-min_cnf':    'B1109173',
    'pu-239_fiss_68pct_43-min_cnf':     'B1109173',
    'pu-239_fiss_68pct_33-min_cnf':     'B1109173',
    'pu-239_fiss_68pct_166-min_cnf':    'B1109173',
    'pu-239_fiss_68pct_135-min_cnf':    'B1109173',
    'pu-239_fiss_pnnl_30-h_cnf':        'Bkg_E0PSV2ML_080811',
    'pu-239_fiss_68pct_10-min_cnf':     'B1109173',
    'pu-239_fiss_68pct_22-min_cnf':     'B1109173',
    'pu-239_fiss_pnnl_6-h_cnf':         'Bkg_E0PSV2ML_080811',
    'pu-239_fiss_68pct_87-min_cnf':     'B1109173',
    'pu-239_fiss_68pct_54-min_cnf':     'B1109173',
    'pu-239_fiss_68pct_16-min_cnf':     'B1109173',
    'pu-239_fiss_68pct_4-min_cnf':      'B1109173',
    'pu-239_fiss_pnnl_125-h_cnf':       'Bkg_E0PSV2ML_080811',
    'pu-239_fiss_68pct_290-min_cnf':    'B1109173',
    'pu-239_fiss_pnnl_54-h_cnf':        'Bkg_E0PSV2ML_080811',
    'u-238_fiss_68pct_70-min_cnf':      'B1109173',
    'u-238_fiss_68pct_4-min_cnf':       'B1109173',
    'u-238_fiss_pnnl_14-h_cnf':         'Bkg_E0PSV2ML_080811',
    'u-238_fiss_68pct_22-min_cnf':      'B1109173',
    # NOTE(review): every other *_68pct_* entry links to 'B1109173'; confirm this one.
    'u-238_fiss_68pct_10-min_cnf':      'Bkg_E0PSV2ML_080811',
    'u-238_fiss_68pct_131-min_cnf':     'B1109173',
    'u-238_fiss_pnnl_8-h_cnf':          'Bkg_E0PSV2ML_080811',
    'u-238_fiss_68pct_44-min_cnf':      'B1109173',
    'u-238_fiss_68pct_16-min_cnf':      'B1109173',
    'u-238_fiss_pnnl_81-h_cnf':         'Bkg_E0PSV2ML_080811',
    'u-238_fiss_68pct_102-min_cnf':     'B1109173',
    'u-238_fiss_68pct_162-min_cnf':     'B1109173',
    'u-238_fiss_68pct_33-min_cnf':      'B1109173',
    'u-238_fiss_68pct_310-min_cnf':     'B1109173',
    'u-238_fiss_pnnl_10-h_cnf':         'Bkg_E0PSV2ML_080811',
    'u-238_fiss_pnnl_55-h_cnf':         'Bkg_E0PSV2ML_080811',
    'u-238_fiss_68pct_55-min_cnf':      'B1109173',
    'u-238_fiss_68pct_86-min_cnf':      'B1109173',
    'u-238_fiss_68pct_223-min_cnf':     'B1109173',
    'u-238_fiss_pnnl_30-h_cnf':         'Bkg_E0PSV2ML_080811',
    # NOTE(review): this key ends in '_cnf_cnf' unlike its siblings; it must equal the
    # actual (lower-cased) filename for the bg lookup to succeed. Confirm.
    'u-235_th_68pct_4-min_cnf_cnf':         'B1109173',
    'u-235_th_68pct_10-min_cnf':        'B1109173',
    'u-235_th_68pct_15-min_cnf':        'B1109173',
    'u-235_th_68pct_20-min_cnf':        'B1109173',
    'u-235_th_68pct_31-min_cnf':        'B1109173',
    'u-235_th_68pct_41-min_cnf':        'B1109173',
    'u-235_th_68pct_52-min_cnf':        'B1109173',
    'u-235_th_68pct_68-min_cnf':        'B1109173',
    'u-235_th_68pct_83-min_cnf':        'B1109173',
    'u-235_th_68pct_99-min_cnf':        'B1109173',
    'u-235_th_68pct_129-min_cnf':       'B1109173',
    'u-235_th_68pct_160-min_cnf':       'B1109173',
    'u-235_th_68pct_220-min_cnf':       'B1109173',
    'u-235_th_68pct_280-min_cnf':       'B1109173',
    'u-235_th_pnnl_7-h_cnf':            'Bkg_E0PSV2ML_080811',
    'u-235_th_pnnl_8-h_cnf':            'Bkg_E0PSV2ML_080811',
    'u-235_th_pnnl_8-hl_cnf':            'Bkg_E0PSV2ML_080811',
    'u-235_th_pnnl_22-h_cnf':            'Bkg_E0PSV2ML_080811',
    'u-235_th_pnnl_30-h_cnf':            'Bkg_E0PSV2ML_080811',
    'u-235_th_pnnl_45-h_cnf':            'Bkg_E0PSV2ML_080811',
    'u-235_th_pnnl_54-h_cnf':            'Bkg_E0PSV2ML_080811',
    'u-235_th_pnnl_74-h_cnf':            'Bkg_E0PSV2ML_080811',
    'u-235_th_pnnl_95-h_cnf':            'Bkg_E0PSV2ML_080811',
    'u-235_th_pnnl_103-h_cnf':            'Bkg_E0PSV2ML_080811',
    'u-235_th_pnnl_124-h_cnf':            'Bkg_E0PSV2ML_080811',
    'u-235_th_pnnl_151-h_cnf':            'Bkg_E0PSV2ML_080811',
    'u-235_th_pnnl_177-h_cnf':            'Bkg_E0PSV2ML_080811',
    'u-235_th_pnnl_191-h_cnf':            'Bkg_E0PSV2ML_080811',
    'u-235_th_pnnl_218-h_cnf':            'Bkg_E0PSV2ML_080811',
    'u-235_th_detective_14min':           'Detective_bkgd_many_day',
    'u-235_th_detective_19min':           'Detective_bkgd_many_day',
    'u-235_th_detective_25min':           'Detective_bkgd_many_day',
    'u-235_th_detective_31min':           'Detective_bkgd_many_day',
    'u-235_th_detective_42min':           'Detective_bkgd_many_day',
    'u-235_th_detective_52min':           'Detective_bkgd_many_day',
    'u-235_th_detective_1h3min':           'Detective_bkgd_many_day',
    'u-235_th_detective_1h19min':           'Detective_bkgd_many_day',
    'u-235_th_detective_1h35min':           'Detective_bkgd_many_day',
    'u-235_th_detective_1h51min':           'Detective_bkgd_many_day',
    'u-235_th_detective_2h24min':           'Detective_bkgd_many_day',
    'u-235_th_detective_2h54min':           'Detective_bkgd_many_day',
    'u-235_th_detective_3h26min':           'Detective_bkgd_many_day',
    'u-235_th_detective_4h28min':           'Detective_bkgd_many_day',
    'u-235_th_detective_5h29min':           'Detective_bkgd_many_day',
    'u-235_th_detective_6h40min':           'Detective_bkgd_many_day',
    'u-235_th_microdetective_13min':      'Detective_bkgd_many_day',  # unclear from documentation
    'u-235_th_microdetective_19min':      'Detective_bkgd_many_day',  # unclear from documentation
    'u-235_th_microdetective_24min':      'Detective_bkgd_many_day',  # unclear from documentation
    'u-235_th_microdetective_31min':      'Detective_bkgd_many_day',  # unclear from documentation
    'u-235_th_microdetective_42min':      'Detective_bkgd_many_day',  # unclear from documentation
    'u-235_th_microdetective_52min':      'Detective_bkgd_many_day',  # unclear from documentation
    'u-235_th_microdetective_1h3min':      'Detective_bkgd_many_day',  # unclear from documentation
    'u-235_th_microdetective_1h19min':      'Detective_bkgd_many_day',  # unclear from documentation
    'u-235_th_microdetective_1h35min':      'Detective_bkgd_many_day',  # unclear from documentation
    'u-235_th_microdetective_1h51min':      'Detective_bkgd_many_day',  # unclear from documentation
    'u-235_th_microdetective_2h24min':      'Detective_bkgd_many_day',  # unclear from documentation
    'u-235_th_microdetective_2h54min':      'Detective_bkgd_many_day',  # unclear from documentation
    'u-235_th_microdetective_3h26min':      'Detective_bkgd_many_day',  # unclear from documentation
    'u-235_th_microdetective_4h28min':      'Detective_bkgd_many_day',  # unclear from documentation
    'u-235_th_microdetective_5h29min':      'Detective_bkgd_many_day',  # unclear from documentation
    'u-235_th_microdetective_6h40min':      'Detective_bkgd_many_day',  # unclear from documentation
    'u-235_th_falcon_11min':                'wsu_canberra_falcon_bkg',
    'u-235_th_falcon_17min':                'wsu_canberra_falcon_bkg',
    'u-235_th_falcon_23min':                'wsu_canberra_falcon_bkg',
    'u-235_th_falcon_28min':                'wsu_canberra_falcon_bkg',
    'u-235_th_falcon_35min':                'wsu_canberra_falcon_bkg',
    'u-235_th_falcon_46min':                'wsu_canberra_falcon_bkg',
    'u-235_th_falcon_57min':                'wsu_canberra_falcon_bkg',
    'u-235_th_falcon_1hr8min':                'wsu_canberra_falcon_bkg',
    'u-235_th_falcon_1hr25min':                'wsu_canberra_falcon_bkg',
    'u-235_th_falcon_1hr39min':                'wsu_canberra_falcon_bkg',
    'u-235_th_falcon_1hr59min':                'wsu_canberra_falcon_bkg',
    'u-235_th_falcon_2hr27min':                'wsu_canberra_falcon_bkg',
    'u-235_th_falcon_2hr57min':                'wsu_canberra_falcon_bkg',
    'u-235_th_falcon_3hr30min':                'wsu_canberra_falcon_bkg',
    'u-235_th_falcon_4hr33min':                'wsu_canberra_falcon_bkg',
    'u-235_th_falcon_5hr34min':                'wsu_canberra_falcon_bkg',
    'u-235_th_falcon_6hr43min':                'wsu_canberra_falcon_bkg',
}

In [None]:
"""Get all the PNNL data."""


def _parse_delay_minutes(delay_token):
    """Return the delay encoded in a filename token, in minutes.

    Hyphens are removed and the token is lower-cased. The first ``<digits>min``
    group counts as minutes and the first ``<digits>h`` group as hours, e.g.
    ``"24-min" -> 24``, ``"6-h" -> 360``, ``"1h19min" -> 79``, ``"1hr8min" -> 68``.
    A token containing neither group yields 0.
    """
    delay = delay_token.replace("-", "").lower()
    delay_min = re.findall(r"(\d+)min", delay)
    delay_hr = re.findall(r"(\d+)h", delay)
    delay_time = 0
    if delay_min:
        delay_time += int(delay_min[0])
    if delay_hr:
        delay_time += 60 * int(delay_hr[0])
    return delay_time


# Sorted so the sample order, and every index derived from it here and in later
# cells, does not depend on the filesystem's directory-listing order.
pnnl_pcf_files = sorted(
    pcf_path
    for dirpath, _, _ in os.walk(PNNL_DATA_DIR)
    for pcf_path in glob.glob(os.path.join(dirpath, "*.pcf"))
)

valid_isos = ["np-237", "pu-239", "u-233", "u-235", "u-238"]
# Lower-cased filenames of every background some measurement links to.
valid_bgs = {each.lower() for each in bg_links.values()}
pnnl_categories = []
pnnl_times = []
pnnl_filenames = []

pnnl_bg_ss = SampleSet()
pnnl_gross_ss = SampleSet()

for pcf_file in pnnl_pcf_files:
    filename = os.path.splitext(os.path.basename(pcf_file))[0]
    name_parts = filename.split("_")
    if name_parts[0].lower() in valid_isos:
        # Measurement file: part [0] is the isotope, part [3] the delay token.
        pnnl_times.append(_parse_delay_minutes(name_parts[3]))

        # category = first three name parts; description = first four
        spectra_category = "_".join(name_parts[0:3]).lower()
        pnnl_categories.append(spectra_category)

        ss = read_pcf(pcf_file)
        spectra_description = "_".join(name_parts[0:4]).lower()
        ss.info.description = spectra_description
        ss.downsample_spectra(target_bins=TARGET_BINS)
        pnnl_gross_ss.concat(ss)

        pnnl_filenames.append(filename)

    elif filename.lower() in valid_bgs:
        ss = read_pcf(pcf_file)
        ss.info.description = filename.lower()
        ss.downsample_spectra(target_bins=TARGET_BINS)
        pnnl_bg_ss.concat(ss)
    else:
        # neither a recognized measurement nor a linked background: report and skip
        print(filename)

# Sample indices grouped by category, in first-appearance order.
grouped_inds = pd.Series(range(len(pnnl_categories))).groupby(pnnl_categories, sort=False).apply(list).tolist()
pnnl_times = np.array(pnnl_times)
pnnl_categories = np.array(pnnl_categories)

# Optionally, save SampleSet to file.
pnnl_gross_ss.to_hdf(PNNL_MEASUREMENTS_FILE)

In [None]:
"""Plot PNNL background spectra."""

# One legend entry per background, labeled by its description.
bg_labels = pnnl_bg_ss.info.description.tolist()
fig, ax = plot_spectra(pnnl_bg_ss, show=False, in_energy=True, figsize=(12,6))
ax.set_title('PNNL bg spectra')
ax.legend(bg_labels)
plt.show()

In [None]:
"""Create bg-subtracted fg SampleSet from PNNL data."""

pnnl_snrs = []
pnnl_fg_ss = SampleSet()

# Row index of each bg spectrum by description, built once. setdefault keeps the
# first occurrence, matching list.index(), instead of re-listing and linearly
# searching the descriptions for every measurement.
bg_row_by_description = {}
for row, description in enumerate(pnnl_bg_ss.info.description):
    bg_row_by_description.setdefault(description, row)

for idx, each in enumerate(pnnl_filenames):
    # find the bg spectrum linked to this gross (measured) spectrum
    bg_description = bg_links[each.lower()].lower()
    if bg_description not in bg_row_by_description:
        raise ValueError(
            f"Background '{bg_description}' linked to '{each}' was not loaded from PNNL_DATA_DIR"
        )
    bg_idx = bg_row_by_description[bg_description]

    bg_ss = pnnl_bg_ss[bg_idx]
    # put the bg on the gross spectrum's energy calibration before subtracting
    bg_ss = bg_ss.as_ecal(*pnnl_gross_ss.ecal[idx])
    bg_ss.spectra_type = SpectraType.Background
    bg_ss.spectra_state = SpectraState.Counts
    gross_ss = pnnl_gross_ss[idx]
    gross_ss.spectra_type = SpectraType.Gross
    gross_ss.spectra_state = SpectraState.Counts

    # SNR = (gross rate - bg rate) / sqrt(bg rate), rates in total counts / live time
    bg_cps = bg_ss.info.total_counts.values / bg_ss.info.live_time.values
    gross_cps = gross_ss.info.total_counts.values / gross_ss.info.live_time.values
    fg_cps = gross_cps - bg_cps
    pnnl_snrs.append((fg_cps / np.sqrt(bg_cps))[0])

    fg_ss = gross_ss - bg_ss
    pnnl_fg_ss.concat(fg_ss)

pnnl_fg_ss.clip_negatives()

pnnl_fg_ss.to_hdf(PNNL_FG_FILE)
pnnl_fg_ss

In [None]:
"""Compare a gross (measured) and fg (bg-subtracted) spectrum just to check."""

plot_idx = 0
plot_ss = SampleSet()
plot_ss.concat(pnnl_gross_ss[plot_idx])
plot_ss.concat(pnnl_fg_ss[plot_idx])
fig, ax = plot_spectra(plot_ss, figsize=(12,6), show=False, in_energy=True)
# .iloc: plot_idx is a position, not an index label.
ax.set_title('gross vs. fg PNNL spectrum ({})'.format(pnnl_gross_ss.info.description.iloc[plot_idx]))
# The first spectrum is the measured gross spectrum; the bg-subtracted one is the net/fg.
ax.legend(['gross', 'fg'])
plt.show()

In [None]:
"""Generate reference spectrum from reconstruction."""

fg_seeds_ss = read_hdf(FG_SEED_FILE)
fg_seeds_ss, _ = fg_seeds_ss.split_fg_and_bg()
fg_seeds_ss.drop_sources_columns_with_all_zeros()

# get expected source contributions before anything else!
# (from total counts, before the spectra are downsampled, recalibrated and normalized below)
seed_total_counts = fg_seeds_ss.info.total_counts
source_counts = {
    x.split(",")[0]: v
    for x, v in zip(
        fg_seeds_ss.sources.columns.get_level_values("Seed").values,
        seed_total_counts
    )
}
# Z comes straight from the per-seed totals, not from source_counts.values().
# Seeds whose names share the text before the first "," would collapse into one
# dict entry, shortening Z and misaligning expected_props with the seed rows.
Z = np.asarray(seed_total_counts, dtype=float)
expected_props = Z / Z.sum()

fg_seeds_ss.downsample_spectra(target_bins=TARGET_BINS)
fg_seeds_ss = fg_seeds_ss.as_ecal(*TARGET_ECAL)
fg_seeds_ss.normalize(p=1)

# Sum of the L1-normalized seed spectra weighted by their expected proportions
# (row i of the spectra matrix pairs with expected_props[i]).
reconstructed_fission_seed = (expected_props @ fg_seeds_ss.spectra.values).reshape(1, TARGET_BINS)

reconstructed_fission_seed_ss = SampleSet()
reconstructed_fission_seed_ss.spectra = pd.DataFrame(
    data=reconstructed_fission_seed
)
# Single-row sources table (first seed's row as the template) holding the expected
# proportions. The values are set on the frame itself rather than through the
# property getter. iloc[[0]] also stays correct if index labels repeat, unlike
# drop(index[1:]).
recon_sources = fg_seeds_ss.sources.iloc[[0]].astype(float)
recon_sources.iloc[0] = expected_props
reconstructed_fission_seed_ss.sources = recon_sources
# First seed's info row as metadata; iloc keeps each column's dtype (a .values
# round-trip would turn mixed columns into object dtype).
reconstructed_fission_seed_ss.info = fg_seeds_ss.info.iloc[[0]].reset_index(drop=True)

# Seed positions ordered by descending expected proportion.
isotope_order = np.argsort(expected_props)[::-1]

fig, ax = plot_spectra(reconstructed_fission_seed_ss, title="Reconstructed Spectrum", in_energy=True)
plt.show()

# Plot SME estimates
fig, ax = plt.subplots(figsize=(12,6))
ax.scatter(np.arange(1, fg_seeds_ss.n_samples+1), expected_props[isotope_order])
ax.set_xticks(np.arange(1, fg_seeds_ss.n_samples+1), np.array(fg_seeds_ss.get_labels())[isotope_order], rotation=60)
ax.set_xlabel("Isotope")
ax.set_ylabel("Proportion")
ax.grid(alpha=0.3)
fig.tight_layout()
fig.savefig(os.path.join(IMAGE_DIR, "SME_expectation.png"), dpi=300)
plt.show()

In [None]:
"""Find closest measurement to reference spectrum."""

# Recalibrate all fg measurements to the common target energy calibration once,
# instead of once per sample for the search and again for plotting.
cal_pnnl_fg_ss = pnnl_fg_ss[:]
cal_pnnl_fg_ss = cal_pnnl_fg_ss.as_ecal(*TARGET_ECAL)
recon_spectrum = reconstructed_fission_seed_ss.spectra.values[0,:]

# scipy's jensenshannon rescales each input to sum to 1, so measured counts can be
# compared directly with the (unit-sum) reconstruction.
jsds = [jensenshannon(measured, recon_spectrum) for measured in cal_pnnl_fg_ss.spectra.values]
closest_idx = int(np.argmin(jsds))

ss = SampleSet()
ss.concat(cal_pnnl_fg_ss[closest_idx])
calibrated_recon_ss = reconstructed_fission_seed_ss[0]
calibrated_recon_ss = calibrated_recon_ss.as_ecal(*cal_pnnl_fg_ss.ecal[closest_idx])
ss.concat(calibrated_recon_ss)
ss.normalize()

fig, ax = plot_spectra(ss, in_energy=True, figsize=(10,5), show=False)
labels = [pnnl_fg_ss.info.description.values[closest_idx]]
labels.extend(["Reconstructed Spectrum"])
ax.legend(labels)
plt.show()

# save final results to CSV, sorted by ascending JSD
final_results = pd.DataFrame(
    data=sorted(zip(jsds, pnnl_fg_ss.info.description)),
    columns=["JSD", "PNNL_source"]
)
final_results.to_csv("PNNL_JSDS.csv")

In [None]:
"""Compare to target spectrum."""

target_description = "u-235_th_falcon_11min"

# Locate the target measurement and look up its JSD to the reconstruction.
fg_descriptions = pnnl_fg_ss.info.description.tolist()
target_idx = fg_descriptions.index(target_description)
target_jsd = jsds[target_idx]
print(f"JSD (reconstruction to {target_description}) = {target_jsd:.4f}")

# Overlay the target with the reconstruction resampled onto the target's own
# energy calibration; both are L1-normalized before plotting.
target_ecal = pnnl_fg_ss.ecal[target_idx]
calibrated_recon_ss = reconstructed_fission_seed_ss[0].as_ecal(*target_ecal)
ss = SampleSet()
for member_ss in (pnnl_fg_ss[target_idx], calibrated_recon_ss):
    ss.concat(member_ss)
ss.normalize(p=1)

labels = [target_description, "reconstruction"]
fig, ax = plot_spectra(ss, in_energy=True, figsize=(10,5), show=False, title=f"JSD: {target_jsd:.4f}")
ax.legend(labels)
ax.set_title(f"SME Reconstruction Compared to {target_description} (JSD={target_jsd:.4f})")
plt.show()

In [None]:
"""JSD time trend plots."""

# First for particular fission group
target_description = "u-235_th_falcon_11min"
# Category = description without its trailing delay token (e.g. "u-235_th_falcon").
target_source = "_".join(target_description.split("_")[0:-1])
# Exact category match: the same rule used for the marker choice below. A prefix
# test on descriptions would also pick up any longer category name that merely
# starts with target_source.
target_inds = np.flatnonzero(pnnl_categories == target_source)
if target_inds.size == 0:
    raise ValueError(f"No PNNL measurements in category '{target_source}'")
target_times = pnnl_times[target_inds]
target_shortest_idx = target_inds[int(np.argmin(target_times))]
shortest_spectrum = cal_pnnl_fg_ss.spectra.values[target_shortest_idx]

# JSD of each spectrum in the target category to its shortest-delay spectrum.
target_jsds = [
    jensenshannon(cal_pnnl_fg_ss.spectra.values[each], shortest_spectrum)
    for each in target_inds
]

fig, ax = plt.subplots()
ax.scatter(target_times, target_jsds)
ax.set_xlabel("Time (in minutes)")
ax.set_ylabel("JSD between current spectra and initial spectra")
ax.set_title(f"target source: {target_description}")

plt.show()

marker_list = ["x", "o", "v"]


def _plot_jsds_by_category(jsd_values, title):
    """Scatter ``jsd_values`` against delay time, one labeled series per category.

    Categories cycle through ``marker_list``; the target category uses square
    markers. The time axis is log-scaled.
    """
    fig, ax = plt.subplots(figsize=(8,7))
    for group_num, group in enumerate(grouped_inds):
        group_title = pnnl_categories[group[0]]
        if group_title.lower() == target_source:
            marker = "s"
        else:
            marker = marker_list[group_num % len(marker_list)]
        ax.scatter(pnnl_times[group], jsd_values[group], label=group_title, marker=marker, alpha=0.7)

    ax.legend(bbox_to_anchor=(1.05, 1.0))
    ax.set_xlabel("Time (in minutes)")
    ax.set_title(title)
    ax.set_ylabel("JSD")
    ax.set_xscale("log")

    fig.tight_layout()
    plt.show()


all_jsds_to_target = np.array([
    jensenshannon(spectrum, shortest_spectrum) for spectrum in cal_pnnl_fg_ss.spectra.values
])
_plot_jsds_by_category(all_jsds_to_target, f"JSD of All PNNL Measurements to {target_description}")

recon_spectrum = reconstructed_fission_seed_ss.spectra.values[0]
jsds_to_recon = np.array([
    jensenshannon(spectrum, recon_spectrum) for spectrum in cal_pnnl_fg_ss.spectra.values
])
_plot_jsds_by_category(jsds_to_recon, "JSD of All PNNL Measurements to SME-Based Reconstruction")

In [None]:
"""Plot of SNRS for PNNL data."""

target_snr = pnnl_snrs[target_idx]

fig, ax = plt.subplots()
ax.scatter(np.arange(pnnl_fg_ss.n_samples), pnnl_snrs, alpha=0.5, label="measurement")
# Label the highlighted point from target_description instead of a hard-coded
# name, so it stays correct if a different target measurement is chosen.
ax.scatter(target_idx, target_snr, marker="x", label=f"{target_description} ({target_snr:.0f})")
ax.axhline(5000, color="red", linestyle="--", label="SNR = 5000")
ax.axhline(2000, color="green", linestyle="--", label="SNR = 2000")
ax.set_yscale("log")
ax.set_xlabel("PNNL Sample #")
ax.set_ylabel("SNR")
ax.legend(loc="lower right")
plt.show()

In [None]:
"""Load in final model for each unsupervised loss (either locally or with W&B)."""

# Unpack the (runs, models) pair. Later cells walk best_runs' keys and index
# best_models by the matching position.
loaded_final_models = load_final_models()
best_runs, best_models = loaded_final_models

In [None]:
"""Run models on all PNNL data."""

# do forward pass on data
# bg_cps = count rate (total counts / live time) of the LAST spectrum in pnnl_bg_ss.
# NOTE(review): which bg is last depends on file-discovery order, not on the
# measurement being evaluated. Confirm this is the intended bg_cps for every model.
bg_cps = (pnnl_bg_ss.spectra.values.sum(axis=1) / pnnl_bg_ss.info.live_time.values)[-1]

n_models = len(best_runs)
pred_ss = []
for idx in range(n_models):
    # predict() populates tmp_ss; its prediction_probas and loss info column are read below
    tmp_ss = cal_pnnl_fg_ss[:]
    best_models[idx].predict(tmp_ss, bg_cps=bg_cps)
    pred_ss.append(tmp_ss)

# compute reconstructions and reconstruction errors
fg_dict = np.array(best_models[0].source_dict)
n_meas = cal_pnnl_fg_ss.n_samples
all_reconstructions = []
all_reconstruction_errors = []
for idx in range(n_models):
    probas = pred_ss[idx].prediction_probas.values[:n_meas]
    # Row i of probas @ fg_dict equals fg_dict.T @ probas[i]. Each reconstruction
    # is kept as a column vector, as the previous per-sample np.dot produced.
    recon_rows = probas @ fg_dict
    all_reconstructions.append([row[:, np.newaxis] for row in recon_rows])

    loss_column = best_models[idx].unsup_loss_func_name
    all_reconstruction_errors.append(np.array(pred_ss[idx].info[loss_column].values[:n_meas]))

In [None]:
"""Compare SME estimates to model predictions on IND PNNL data."""
target_spectrum = cal_pnnl_fg_ss[target_idx]
target_spectrum.normalize()

# reconstruction error metrics, in the same order as plt_names
recon_error_funcs = [
    chi_squared_diff,
    jensen_shannon_divergence,
    poisson_nll_diff,
    sse_diff
]
# per-metric errors for the SME, NNLS and per-model reconstructions of the target
sme_target_reconstruction_errors = []
nnls_target_reconstruction_errors = []
model_target_reconstruction_errors = {
    each: [] for each in best_runs.keys()
}

# raw (un-normalized) target measurement as a (1, TARGET_BINS) row
target_measurement = cal_pnnl_fg_ss.spectra.values[target_idx, :].reshape(1, TARGET_BINS)

# SME reconstruction rescaled to the target measurement's total counts
sme_reconstruction = reconstructed_fission_seed_ss.spectra.values[0,:].reshape(1, TARGET_BINS) * target_measurement.sum()

# compute nnls reconstruction: non-negative weights over the foreground
# dictionary, fitted to the normalized target spectrum
nnls_sol, nnls_rnorm = nnls(fg_dict.T, target_spectrum.spectra.values[0])
nnls_target_reconstruction = np.dot(fg_dict.T, nnls_sol)
nnls_target_reconstruction = nnls_target_reconstruction / nnls_target_reconstruction.sum()  # L1 normalize
nnls_target_reconstruction *= target_measurement.sum()  # convert to counts
# NOTE(review): jsd_nnls is not used in this part of the notebook -- confirm it is still needed.
jsd_nnls = jensenshannon(nnls_target_reconstruction, target_spectrum.spectra.values[0])
print(f"NNLS Sol: {nnls_sol}")
# ".4f" = four decimal places ("4f" would only set a minimum width of 4)
print(f"NNLS Residual = {nnls_rnorm:.4f}")

# Compute reconstruction errors: every metric in recon_error_funcs compares the
# raw target measurement against each reconstruction. The SME and NNLS
# reconstructions are L1-normalized before scoring.
# NOTE(review): the model reconstructions are passed as-is (not explicitly
# L1-normalized like SME/NNLS); this assumes fg_dict rows and the predicted
# proportions each sum to 1 -- confirm.
sme_recon_normalized = sme_reconstruction / sme_reconstruction.sum()
nnls_recon_normalized = nnls_target_reconstruction.reshape(1, TARGET_BINS) / nnls_target_reconstruction.sum()

for recon_error_func in recon_error_funcs:
    # SME recon error
    sme_recon_error = recon_error_func(target_measurement, sme_recon_normalized)
    sme_target_reconstruction_errors.append(sme_recon_error.numpy().sum())

    # NNLS recon error
    nnls_recon_error = recon_error_func(target_measurement, nnls_recon_normalized)
    nnls_target_reconstruction_errors.append(nnls_recon_error.numpy().sum())

    # recon error of each model's reconstruction of the target spectrum
    for unsup_idx, unsup_loss in enumerate(best_runs.keys()):
        model_reconstruction = all_reconstructions[unsup_idx][target_idx]
        model_recon_error = recon_error_func(
            target_measurement,
            model_reconstruction.reshape(1, TARGET_BINS)
        )
        model_target_reconstruction_errors[unsup_loss].append(model_recon_error.numpy().sum())

print("SME Reconstruction Errors:")
display(sme_target_reconstruction_errors)
print("NNLS Reconstruction Errors:")
display(nnls_target_reconstruction_errors)
for each in best_runs.keys():
    print(f"{each} Model Reconstruction Errors:")
    display(model_target_reconstruction_errors[each])

# scatter, per isotope (ordered by isotope_order), the SME expected
# proportions, the NNLS weights and the focus model's predicted proportions
alpha = 0.7
isotope_positions = np.arange(fg_seeds_ss.n_samples)
fig, ax = plt.subplots(figsize=(10, 6))
ind_model_preds = pred_ss[focus_unsup_loss].prediction_probas.values[target_idx][:, np.newaxis]

estimate_series = (
    ("SME expectation", expected_props[isotope_order], "#006BA4"),
    ("NNLS solution", nnls_sol[isotope_order], "#FF800E"),
    (f"model prediction ({plt_names[focus_unsup_loss]})", ind_model_preds[isotope_order], "#595959"),
)
for series_label, series_values, series_color in estimate_series:
    ax.scatter(
        isotope_positions,
        series_values,
        label=series_label,
        alpha=alpha,
        color=series_color
    )

isotope_labels = np.array(fg_seeds_ss.get_labels())[isotope_order]
ax.set_xticks(isotope_positions, isotope_labels, rotation=60)
ax.set_xlabel("Isotope")
ax.set_ylabel("Proportion")
ax.legend()
ax.grid(alpha=0.3)
ax.set_yscale("linear")
fig.tight_layout()
fig.savefig(os.path.join(IMAGE_DIR, "ind_pnnl_estimates.png"), dpi=300)
plt.show()

# plot all reconstruction comparison: for each model, overlay the PNNL target
# measurement and that model's reconstruction, both normalized and then
# rescaled to the target measurement's total counts
for model_idx, loss_key in enumerate(best_runs.keys()):
    recon_ss = SampleSet()
    recon_ss.concat(cal_pnnl_fg_ss[target_idx])  # target measurement

    # write the model's spectrum and proportions into sample 0 of the SME
    # reconstruction SampleSet (presumably indexing returns a copy -- confirm)
    model_recon_ss = reconstructed_fission_seed_ss[0]
    model_recon_ss.spectra.iloc[0] = all_reconstructions[model_idx][target_idx].squeeze()
    model_recon_ss.sources.iloc[0] = pred_ss[model_idx].prediction_probas.values[target_idx]
    recon_ss.concat(model_recon_ss)  # model reconstruction

    recon_ss.normalize()
    target_total = cal_pnnl_fg_ss[target_idx].spectra.values.sum()
    count_scale = np.array([target_total] * recon_ss.n_samples)[:, np.newaxis]
    recon_ss.spectra = pd.DataFrame(data=recon_ss.spectra.values * count_scale)
    recon_ss.info["label"] = [
        "PNNL measurement",
        f"model ({plt_names[model_idx]}) reconstruction",
    ]

    fig, ax = plot_spectra(recon_ss, in_energy=True, show=False, ylim=(100, None))
    ax.legend(recon_ss.info.label)
    fig.tight_layout()
    output_path = os.path.join(IMAGE_DIR, f"ind_pnnl_reconstruction_{loss_key}.png")
    fig.savefig(output_path, dpi=300)
    plt.show()

# plot measurement and SME/NNLS reconstructions
# -- SME: target measurement vs. SME reconstruction, normalized and then
#    rescaled to the target measurement's total counts
recon_ss = SampleSet()
for component_ss in (cal_pnnl_fg_ss[target_idx], reconstructed_fission_seed_ss):
    recon_ss.concat(component_ss)  # target measurement, then SME reconstruction
recon_ss.normalize()
target_total = cal_pnnl_fg_ss[target_idx].spectra.values.sum()
count_scale = np.array([target_total] * recon_ss.n_samples)[:, np.newaxis]
recon_ss.spectra = pd.DataFrame(data=recon_ss.spectra.values * count_scale)
recon_ss.info["label"] = ["PNNL measurement", "SME reconstruction"]

fig, ax = plot_spectra(recon_ss, in_energy=True, show=False, ylim=(100, None))
ax.legend(recon_ss.info.label)
fig.tight_layout()
fig.savefig(os.path.join(IMAGE_DIR, "ind_pnnl_reconstruction_SME.png"), dpi=300)
plt.show()

# -- NNLS: target measurement vs. NNLS reconstruction. The NNLS spectrum and
#    weights are written into sample 0 of the SME reconstruction SampleSet
#    (presumably indexing returns a copy -- confirm).
nnls_recon_ss = reconstructed_fission_seed_ss[0]
nnls_recon_ss.spectra.iloc[0] = nnls_target_reconstruction
nnls_recon_ss.sources.iloc[0] = nnls_sol

recon_ss = SampleSet()
recon_ss.concat(cal_pnnl_fg_ss[target_idx])  # target measurement
recon_ss.concat(nnls_recon_ss)  # NNLS reconstruction
recon_ss.normalize()
target_total = cal_pnnl_fg_ss[target_idx].spectra.values.sum()
count_scale = np.array([target_total] * recon_ss.n_samples)[:, np.newaxis]
recon_ss.spectra = pd.DataFrame(data=recon_ss.spectra.values * count_scale)
recon_ss.info["label"] = ["PNNL measurement", "NNLS reconstruction"]

fig, ax = plot_spectra(recon_ss, in_energy=True, show=False, ylim=(100, None))
ax.legend(recon_ss.info.label)
fig.tight_layout()
fig.savefig(os.path.join(IMAGE_DIR, "ind_pnnl_reconstruction_NNLS.png"), dpi=300)
plt.show()

# plot errors in counts for each reconstruction: one panel per model, followed
# by the SME and NNLS reconstructions. Each panel shows
# (target measurement - reconstruction), both normalized and then rescaled to
# the target measurement's total counts.
loss_keys = list(best_runs.keys())
n_model_panels = len(loss_keys)
n_panels = n_model_panels + 2  # + SME + NNLS
n_rows = (n_panels + 1) // 2  # two panels per row
target_total_counts = cal_pnnl_fg_ss[target_idx].spectra.values.sum()

fig, axes = plt.subplots(n_rows, 2, sharex=True, sharey=True, figsize=(10, 4 * n_rows))
flat_axes = axes.reshape(-1)

# an odd panel count leaves the last grid cell empty: hide it and restore the
# x tick labels on the panel above it (sharex hides them on non-bottom rows)
if n_panels < flat_axes.size:
    for unused_ax in flat_axes[n_panels:]:
        unused_ax.set_visible(False)
    flat_axes[n_panels - 2].tick_params(axis="x", labelbottom=True)

for plt_idx, ax in enumerate(flat_axes[:n_panels]):
    recon_ss = SampleSet()
    recon_ss.concat(cal_pnnl_fg_ss[target_idx])  # target measurement

    # model reconstructions
    if plt_idx < n_model_panels:
        model_recon_ss = reconstructed_fission_seed_ss[0]
        model_recon_ss.spectra.iloc[0] = all_reconstructions[plt_idx][target_idx].squeeze()
        model_recon_ss.sources.iloc[0] = pred_ss[plt_idx].prediction_probas.values[target_idx]
        recon_ss.concat(model_recon_ss)  # model reconstruction
        # fall back to the loss key when there are more models than plt_names entries
        panel_name = plt_names[plt_idx] if plt_idx < len(plt_names) else loss_keys[plt_idx]
        ax.set_title(f"Model ({panel_name}) Reconstruction")

    # SME reconstruction
    elif plt_idx == n_model_panels:
        recon_ss.concat(reconstructed_fission_seed_ss)  # SME reconstruction
        ax.set_title("SME Reconstruction")

    # NNLS reconstruction
    else:
        nnls_recon_ss = reconstructed_fission_seed_ss[0]
        nnls_recon_ss.spectra.iloc[0] = nnls_target_reconstruction
        nnls_recon_ss.sources.iloc[0] = nnls_sol
        recon_ss.concat(nnls_recon_ss)  # NNLS reconstruction
        ax.set_title("NNLS Reconstruction")

    recon_ss.normalize()
    recon_ss.spectra = pd.DataFrame(
        data=recon_ss.spectra.values * np.array([target_total_counts] * recon_ss.n_samples)[:, np.newaxis]
    )

    recon_errors = recon_ss.spectra.values[0] - recon_ss.spectra.values[1]
    ax.plot(np.arange(1, recon_ss.n_channels + 1), recon_errors, alpha=1.0, linewidth=1)
    ax.set_yscale("linear")

fig.supylabel("Error in Counts")
fig.supxlabel("Energy Channel")
fig.tight_layout()
fig.savefig(os.path.join(
    IMAGE_DIR,
    "ind_pnnl_count_error.png"
), dpi=300)
plt.show()

In [None]:
"""Generate relative min-max recon error plots for all PNNL data."""

fig, axes = plt.subplots(2, 2, figsize=(10, 10), sharex=True, sharey=True)

for ax_idx, ax in enumerate(axes.reshape(-1)):
    loss_errors = all_reconstruction_errors[ax_idx]
    min_error = loss_errors.min()
    recon_error_range = loss_errors.max() - min_error
    if recon_error_range > 0:
        # rescale to [0, 1]: 0 = lowest error, 1 = highest error for this loss
        rel_min_max_errors = (loss_errors - min_error) / recon_error_range
    else:
        # every spectrum has the same error: avoid 0/0 (an all-NaN panel)
        rel_min_max_errors = np.zeros_like(loss_errors, dtype=float)
    # NOTE(review): the minimum error maps to exactly 0, which the log y-axis
    # below cannot display, so the lowest-error spectrum is not visible in this
    # plot -- consider a symlog axis if that point matters.

    for idx, each in enumerate(grouped_inds):
        group_title = pnnl_categories[each[0]]
        marker = marker_list[idx % len(marker_list)]
        if group_title.lower() == target_source:
            marker = "s"  # square markers for the target source's group

        ax.scatter(pnnl_times[each], rel_min_max_errors[each], label=group_title, marker=marker, alpha=0.7)

    ax.set_title(f"Unsupervised Loss: {list(best_runs.keys())[ax_idx]}")
    ax.set_xscale("log")
    ax.set_yscale("log")

# every panel plots the same groups, so the last axis' legend entries suffice
handles, labels = ax.get_legend_handles_labels()
fig.legend(handles, labels, loc="center right", bbox_to_anchor=(1.23, 0.6))

fig.supxlabel("Time Post-Irradiation (in minutes)")
fig.supylabel("Relative Min-Max Reconstruction Error")
fig.tight_layout()
plt.show()

In [None]:
"""Generate recon error plots for all PNNL data."""

# y-axis label for each unsupervised loss key. Raw string: "\c" is not a valid
# escape sequence and emits a SyntaxWarning on Python 3.12+.
_RECON_ERROR_YLABELS = {
    "chi_squared": r"$\chi^2$",
    "jsd": "JSD",
    "pnll": "PNLL / Counts",
    "sse": "SSE / Counts$^2$",
}


def _scale_recon_errors(loss_name, recon_errors, counts):
    """Scale per-spectrum reconstruction errors for plotting.

    Args:
        loss_name: unsupervised loss key (a key of ``best_runs``).
        recon_errors: array of reconstruction errors, one per spectrum.
        counts: total counts of each spectrum, aligned with ``recon_errors``.

    Returns:
        SSE divided by counts squared; JSD and chi-squared unchanged; any other
        loss (PNLL) divided by counts.
    """
    if loss_name == "sse":
        return recon_errors / np.square(counts)
    if loss_name in ("jsd", "chi_squared"):
        return recon_errors
    return recon_errors / counts


def _scatter_recon_errors(ax, plt_recon_errors):
    """Scatter errors vs. time post-irradiation per PNNL group and circle the target spectrum."""
    for group_idx, group_inds in enumerate(grouped_inds):
        group_title = pnnl_categories[group_inds[0]]
        marker = marker_list[group_idx % len(marker_list)]
        if group_title.lower() == target_source:
            marker = "s"  # square markers for the target source's group

        ax.scatter(
            pnnl_times[group_inds],
            plt_recon_errors[group_inds],
            label=group_title,
            marker=marker,
            alpha=0.5
        )

    ax.scatter(
        pnnl_times[target_idx],
        plt_recon_errors[target_idx],
        s=110,
        facecolors="none",
        edgecolors="r",
        label="IND target spectrum"
    )


loss_names = list(best_runs.keys())
cnts = cal_pnnl_fg_ss.spectra.values.sum(axis=1)

fig, axes = plt.subplots(2, 2, figsize=(10, 10), sharex=True, sharey=False)

for ax_idx, ax in enumerate(axes.reshape(-1)):
    loss_name = loss_names[ax_idx]
    plt_recon_errors = _scale_recon_errors(loss_name, all_reconstruction_errors[ax_idx], cnts)
    _scatter_recon_errors(ax, plt_recon_errors)

    ax.set_title(f"Unsupervised Loss: {plt_names[ax_idx]}")
    ax.set_xscale("log")
    ax.set_yscale("log")
    # label looked up by loss name so it always matches this axis' data
    if loss_name in _RECON_ERROR_YLABELS:
        ax.set_ylabel(_RECON_ERROR_YLABELS[loss_name])

# every panel plots the same groups, so the last axis' legend entries suffice
handles, labels = ax.get_legend_handles_labels()
fig.legend(handles, labels, loc="center right", bbox_to_anchor=(1.23, 0.6))

fig.supxlabel("Time Post-Irradiation (in minutes)")
fig.supylabel("Reconstruction Error Metric")
fig.tight_layout()
fig.savefig(os.path.join(
    IMAGE_DIR,
    "ood_pnnl_reconstruction_error_vs_time.png"
), dpi=300, bbox_inches='tight')
plt.show()

# generate them separately for each unsup loss as well
for loss_idx in range(len(best_models)):
    loss_name = loss_names[loss_idx]
    print(loss_name)
    fig, ax = plt.subplots(figsize=(10, 10))

    plt_recon_errors = _scale_recon_errors(loss_name, all_reconstruction_errors[loss_idx], cnts)
    _scatter_recon_errors(ax, plt_recon_errors)

    ax.set_title(f"Unsupervised Loss: {plt_names[loss_idx]}")
    ax.set_xscale("log")
    ax.set_yscale("log")

    handles, labels = ax.get_legend_handles_labels()
    fig.legend(handles, labels, loc="center right", bbox_to_anchor=(1.23, 0.6))

    if loss_name in _RECON_ERROR_YLABELS:
        ax.set_ylabel(_RECON_ERROR_YLABELS[loss_name])

    fig.supxlabel("Time Post-Irradiation (in minutes)")
    fig.tight_layout()
    fig.savefig(os.path.join(
        IMAGE_DIR,
        f"ood_pnnl_reconstruction_error_vs_time_{loss_name}.png"
    ), dpi=300, bbox_inches='tight')
    plt.show()