![pyriid](https://user-images.githubusercontent.com/1079118/124811147-623bd280-df1f-11eb-9f3a-a4a5e6ec5f94.png)

# PyRIID Primer 2

PyRIID (pronounced: PIE-rid) stands for Python-based Radioisotope IDentification (RIID).

PyRIID is a Python package intended to streamline the gamma spectrum synthesis and model fitting workflow.

**Intended audience:**

Potential or current PyRIID users studying gamma spec-related questions who want a more comprehensive demonstration of PyRIID utilities compared to Primer 1.

**Assumed background knowledge:**

1. Basic understanding of Python and how to install both it and PyRIID
2. Familiarity with what a gamma spectrum is and how one is obtained
3. Knowledge of how to install GADRAS

**Topics not covered in detail:**

- Extensive GADRAS details
   - Familiarity with the basics of the Detector and Inject tabs in GADRAS is helpful, but not strictly necessary
- Model performance metrics, neural network basics, stochastic gradient descent, and related topics

**Duration:**

When we present the content of this notebook, it takes ~2 hours as we provide a lot of commentary and answer questions.
Going through it on your own and just running the cells will take considerably less time.

## Data Synthesis

### Simple Seed Synthesis

In [89]:
"""Imports, constants, and paths"""
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import yaml

from riid import (SampleSet, SeedMixer, SeedSynthesizer, StaticSynthesizer,
                  read_hdf, read_json, read_pcf)
from riid.models import MLPClassifier
from riid.visualize import plot_spectra

# SNR values for example problem
MIN_SNR = 5
MAX_SNR = 300

# Directories
SEED_CONFIGS_DIR = Path("./seed_configs/")
assert SEED_CONFIGS_DIR.exists()
DATA_DIR = Path("./problem_data")
DATA_DIR.mkdir(exist_ok=True)

# Files
BASIC_SEED_CONFIG_PATH = SEED_CONFIGS_DIR.joinpath("basic.yaml")
ADVANCED_SEED_CONFIG_PATH = SEED_CONFIGS_DIR.joinpath("advanced.yaml")
PROBLEM_SEED_CONFIG_PATH = SEED_CONFIGS_DIR.joinpath("problem.yaml")
SEEDS_PATH = str(DATA_DIR.joinpath("seeds.h5"))
FG_SEEDS_PATH = str(DATA_DIR.joinpath("fg_seeds.h5"))
BG_SEEDS_PATH = str(DATA_DIR.joinpath("bg_seeds.h5"))
MIXED_BG_SEEDS_PATH = str(DATA_DIR.joinpath("mixed_bg_seeds.h5"))
TRAIN_PATH = str(DATA_DIR.joinpath("train.h5"))
MODEL_JSON_PATH = str(DATA_DIR.joinpath("model.json"))
MODEL_ONNX_PATH = str(DATA_DIR.joinpath("model.onnx"))
MODEL_TFLITE_PATH = str(DATA_DIR.joinpath("model.tflite"))
IND_TEST_PATH = str(DATA_DIR.joinpath("test.h5"))  # Hold out, not seen in training

In [4]:
"""Basic seed synthesis"""
with open(BASIC_SEED_CONFIG_PATH) as fin:
    seed_synth_config = yaml.safe_load(fin)
seed_synth = SeedSynthesizer()
simple_seeds = seed_synth.generate(seed_synth_config)

In [None]:
"""Check where PyRIID is looking for GADRAS"""
from riid.gadras.api import GADRAS_INSTALL_PATH

print(GADRAS_INSTALL_PATH)

# If you have a non-default GADRAS installation, set an environment variable
# named "GADRAS_DIR" to point to your custom install location

### The `SampleSet`

#### Core Pandas `DataFrame`s

In [None]:
"""`spectra` DataFrame: the samples, the gamma spectra"""
simple_seeds.spectra

_ = plot_spectra(simple_seeds, in_energy=True)

In [None]:
"""`sources` DataFrame: the ground truth"""
simple_seeds.sources
# simple_seeds.sources.columns  # Note that the columns of `sources` DataFrame are a "MultiIndex" to track ground truth at multiple levels

In [None]:
"""`info` DataFrame: additional information describing each sample"""
simple_seeds.info

In [None]:
"""`prediction_probas` DataFrame: predictions made on data, likely differing from `sources`"""
simple_seeds.prediction_probas

In [None]:
"""`detector_info` dictionary: information about the detector response function"""
simple_seeds.detector_info

#### Core methods

In [12]:
"""Selecting and copying your data"""
seeds = simple_seeds[:]  # Select all samples; how you make a full copy; careful with large datasets...
# seeds[0].sources  # select first sample
# seeds[:4].sources  # select first 4 samples
# seeds[-2:].sources  # select last 2 samples

In [None]:
"""Separating foregrounds from background"""
fg_seeds, bg_seeds = simple_seeds.split_fg_and_bg()

print(fg_seeds)
print(bg_seeds)

In [None]:
"""Concatenating multiple SampleSets"""
new_seeds = SampleSet()
new_seeds.concat([fg_seeds, bg_seeds])

print(new_seeds)  # note loss of detector info; no good way to restore them yet
print(simple_seeds)

In [None]:
"""Get labels"""
seeds.get_labels()  # By default, the isotope level is extracted
seeds.get_labels(target_level="Isotope")
seeds.get_labels(target_level="Category")
seeds.get_labels(target_level="Seed")
seeds.get_labels(include_value=True)

# More arguments will be explored later as they become relevant
# Note: get_predictions() works identically to get_labels(), except it pulls from `prediction_probas` instead of `sources`

In [None]:
"""Energy calibration information"""
seeds.info  # A lot of noise
seeds.info.loc[:, seeds.ECAL_INFO_COLUMNS]  # Better
seeds.ecal  # Succinct, but not a DF
seeds.get_channel_energies(0)  # Converting channels to energy
seeds.get_all_channel_energies()  # Converting channels to energy for all samples
new_ecal_seeds = seeds.as_ecal(0, 1000, 0, 0, 0)  # Interpolating spectra to a new energy calibration
# Note: if you have a SampleSet with DISPARATE BUT KNOWN energy calibration parameters,
#   `as_ecal()` is useful for transforming all the spectra such that they have the same energy calibration.

# PyRIID defines energy calibration in the same terms as GADRAS, which uses the Full Range Fraction defined in ANSI N42.42-2006.
# However, PyRIID does not currently support deviation pairs.

# Plot in channel space
IDX = 0
fig, ax = plt.subplots()
ax.plot(simple_seeds.spectra.iloc[IDX], label="original")
ax.plot(new_ecal_seeds.spectra.iloc[IDX], label="new")
ax.set_yscale("log")
ax.set_title("Spectra when plotted using channels")
ax.legend()
plt.show()

# Plot in energy space (note that we have cut off our data)
fig, ax = plt.subplots()
ax.plot(simple_seeds.get_channel_energies(IDX), simple_seeds.spectra.iloc[IDX], label="original")
ax.plot(new_ecal_seeds.get_channel_energies(IDX), new_ecal_seeds.spectra.iloc[IDX], label="new")
ax.set_yscale("log")
ax.set_title("Spectra when plotted using energy")
ax.legend()
plt.show()
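
The channel-to-energy conversion used above can be sketched with NumPy alone. This is a minimal sketch, assuming the Full Range Fraction form `E = offset + gain*x + quadratic*x**2 + cubic*x**3 + low_energy / (1 + 60*x)` with `x = channel / n_channels`, and assuming the five `as_ecal()` arguments follow that same order. `channel_energies` is a hypothetical helper written for illustration, not a PyRIID function.

```python
import numpy as np


def channel_energies(n_channels, offset, gain, quadratic=0.0, cubic=0.0, low_energy=0.0):
    """Energy (keV) assigned to each channel under an FRF-style calibration (assumed form)."""
    # Full Range Fraction: the channel index scaled into [0, 1)
    x = np.arange(n_channels) / n_channels
    return (
        offset
        + gain * x
        + quadratic * x**2
        + cubic * x**3
        + low_energy / (1 + 60 * x)
    )


# Mirrors the parameters passed to as_ecal(0, 1000, 0, 0, 0) above
energies = channel_energies(512, offset=0, gain=1000)
print(energies[:3])
print(energies[-1])
```

With only an offset and a gain, the mapping is linear, so a gain of 1000 spreads the channels evenly over 0 to just under 1000 keV.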

In [None]:
"""Downsampling (AKA down-binning, AKA rebinning into fewer channels)"""
downsampled_seeds = simple_seeds[:]  # Make a copy, we're about to be destructive
downsampled_seeds.downsample_spectra(128)
downsampled_seeds
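
To make the idea of down-binning concrete, here is a minimal NumPy sketch that sums groups of adjacent channels. It assumes the original channel count is an integer multiple of the target; PyRIID's `downsample_spectra()` may be implemented differently, and `downsample_by_summing` is a hypothetical name.

```python
import numpy as np


def downsample_by_summing(spectra, target_bins):
    """Rebin each spectrum (row) by summing groups of adjacent channels."""
    spectra = np.asarray(spectra, dtype=float)
    n_samples, n_channels = spectra.shape
    if n_channels % target_bins != 0:
        raise ValueError("n_channels must be an integer multiple of target_bins")
    factor = n_channels // target_bins
    # Group every `factor` neighboring channels, then sum within each group
    return spectra.reshape(n_samples, target_bins, factor).sum(axis=2)


rng = np.random.default_rng(0)
spectra = rng.poisson(5.0, size=(3, 1024))  # three made-up 1024-channel spectra
small = downsample_by_summing(spectra, 128)
print(small.shape)
```

Summing (rather than averaging) preserves the total counts in each spectrum, which matters if the spectra are normalized later.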

In [None]:
"""Normalization"""
print(simple_seeds.spectra.sum(axis=1))

l1_norm_seeds = simple_seeds[:]  # We're about to be destructive again
l1_norm_seeds.normalize()  # p=1, L1 norm, dividing each channel by the sum of counts, default, seeds.normalize(p=1)
print(l1_norm_seeds.spectra.sum(axis=1))  # No change, seeds are already normalized

l2_norm_seeds = simple_seeds[:]
l2_norm_seeds.normalize(p=2)  # p=2, L2 norm, dividing each channel by the square root of the sum of squared counts
print(l2_norm_seeds.spectra.sum(axis=1))

# Plot
IDX = 3
fig, ax = plt.subplots()
ax.plot(l1_norm_seeds.spectra.iloc[IDX], label="L1 normalized")
ax.plot(l2_norm_seeds.spectra.iloc[IDX], label="L2 normalized")
ax.set_yscale("log")
ax.legend()
plt.show()

# General rule: bounding the range of your data via normalization makes it easier for a model to learn.
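
The difference between the two norms is easy to check by hand. The sketch below is a plain NumPy stand-in for the normalization step (not PyRIID's implementation; `normalize` here is a hypothetical free function): each row is divided by its p-norm, so L1-normalized rows sum to 1 while L2-normalized rows have squared values that sum to 1.

```python
import numpy as np


def normalize(spectra, p=1):
    """Divide each spectrum (row) by its p-norm."""
    spectra = np.asarray(spectra, dtype=float)
    norms = (np.abs(spectra) ** p).sum(axis=1, keepdims=True) ** (1 / p)
    norms[norms == 0] = 1.0  # leave all-zero spectra untouched
    return spectra / norms


counts = np.array([[3.0, 4.0, 0.0], [10.0, 10.0, 20.0]])
l1 = normalize(counts, p=1)
l2 = normalize(counts, p=2)

print(l1.sum(axis=1))         # each row sums to 1
print((l2 ** 2).sum(axis=1))  # each row's squared values sum to 1
print(l2.sum(axis=1))         # the plain sum is generally not 1
```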

In [None]:
"""SampleSet arithmetic

You can only do the following:
- gross = fg + bg
- fg = gross - bg

Despite being strict, it will detect if you have L1-normalized spectra and rescale for you.
"""
fg_seeds, bg_seeds = simple_seeds.split_fg_and_bg()
# fg_seeds + bg_seeds  # breaks
# fg_seeds[0] + fg_seeds[0]  # breaks
# bg_seeds[0] + bg_seeds[0]  # breaks
# new_ss = fg_seeds[0] + bg_seeds[0]  # works
new_ss = fg_seeds + bg_seeds[0]  # works

new_ss.spectra.sum(axis=1)  # Spectra still sum to 1 due to automatic rescaling

In [None]:
"""Spectra states and types

State: the form in which the spectra exist
Type: Background, Foreground, Gross

Both are used by PyRIID to track and properly carry out certain operations.
"""
simple_seeds.spectra_state, simple_seeds.spectra_type, fg_seeds.spectra_type, bg_seeds.spectra_type, new_ss.spectra_type

In [None]:
"""Seed health

Pretty simple as it checks two things:
- all spectra sum to 1 (required for later operations)
- dead time is not 100% (we don't like dead seeds)
"""
simple_seeds.check_seed_health()  # Performs the default check
print(simple_seeds.info.dead_time_prop)  # Note: prop = proportion
simple_seeds.check_seed_health(dead_time_threshold=0.01)  # You can tighten the dead time check. Note how it now fails as intended!

In [None]:
"""Managing sources

Tracking ground truth is important and can be a lot of work.
You may end up managing it by hand if you're loading data from your own custom files.

The truth is, for better or worse, what you make it when you synthesize data.
"""
from riid.data.labeling import label_to_index_element

seeds_sources_as_counts = seed_synth.generate(seed_synth_config, normalize_sources=False)

# seeds.normalize_sources()
# seeds.drop_sources()  # By default, drops background seeds at seed level; can provide your own; normalizes by default
# seeds.drop_spectra_with_no_contributors()
# seeds.drop_sources_columns_with_all_zeros()

# Outside of these functions, you have to build and set the sources DataFrame manually
labels = ["K40", "Am241", "Ba133", "U235", "U238", "Mo99", "Tc99m", "WGPu", "K40"]
n_samples = len(labels)
column_tuples = [label_to_index_element(l, label_level="Seed") for l in labels]
columns = pd.MultiIndex.from_tuples(column_tuples, names=SampleSet.SOURCES_MULTI_INDEX_NAMES)

sources_df = pd.DataFrame(np.identity(n_samples), columns=columns)\
    .sort_index(axis=1)\
    .T.groupby(level=SampleSet.SOURCES_MULTI_INDEX_NAMES)\
    .sum().T
sources_df
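
The transpose, group, sum, transpose pattern above is what lets ground truth be read at different levels. Here is a pandas-only sketch of the same idea, assuming the three level names are "Category", "Isotope", and "Seed"; the category and seed labels below are made up for illustration.

```python
import pandas as pd

LEVELS = ["Category", "Isotope", "Seed"]  # assumed level names
columns = pd.MultiIndex.from_tuples(
    [
        ("Industrial", "Am241", "Am241,100uC"),          # hypothetical labels
        ("Industrial", "Ba133", "Ba133,100uC"),
        ("Industrial", "Ba133", "Ba133,100uC {26,10}"),
    ],
    names=LEVELS,
)
# One row per sample, one column per seed; each row sums to 1
sources = pd.DataFrame(
    [
        [1.0, 0.0, 0.0],
        [0.0, 0.7, 0.3],
        [0.0, 0.0, 1.0],
    ],
    columns=columns,
)

# Collapse seed-level truth to isotope level by summing columns that share an isotope
by_isotope = sources.T.groupby(level="Isotope").sum().T
# The label for each sample is the isotope with the largest contribution
labels = by_isotope.idxmax(axis=1)
print(by_isotope)
print(labels.tolist())
```

The second sample is split across two Ba133 seed configurations, yet it collapses to a single Ba133 contribution of 1.0 at the isotope level.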

In [25]:
"""Saving data"""
simple_seeds.to_hdf("seeds.h5", complevel=3)  # Preserves all information; configurable compression only meaningful for large datasets
simple_seeds.to_pcf("seeds.pcf")  # Useful for taking to GADRAS
simple_seeds.to_json("seeds.json")  # Useful for review processes

In [26]:
"""Loading data"""
hdf_seeds = read_hdf("seeds.h5")
pcf_seeds = read_pcf("seeds.pcf")
json_seeds = read_json("seeds.json")

### Advanced Synthesis

In [None]:
"""Randomizing injects"""
DETECTORS = [
    "Generic\\NaI\\3x3\\Front\\MidScat",
    "Generic\\CZT\\1cm-1cm-1cm",
    "Generic\\PVT\\2x2",
]

with open(ADVANCED_SEED_CONFIG_PATH) as fin:
    seed_synth_config = yaml.safe_load(fin)

detector_seeds = {}
for d in DETECTORS:
    seed_synth_config["gamma_detector"]["name"] = d
    seeds = seed_synth.generate(seed_synth_config, verbose=True)
    detector_seeds[d] = seeds
    print(seeds)

"""
Takeaways:
- Detector variation is slow; DRF (.dat) must be updated every time
  - PyRIID attempts to return each DRF to its original state, but errors happen--make backups
- Source variation is faster: GADRAS 19 utilizes batch inject per foreground source
"""

In [None]:
"""Behind the scenes of randomizing injects"""
from riid.gadras.api import validate_inject_config, get_expanded_config

with open(ADVANCED_SEED_CONFIG_PATH) as fin:
    seed_synth_config = yaml.safe_load(fin)

validate_inject_config(seed_synth_config)

get_expanded_config(seed_synth_config)

# You can set the random seed in the config for reproducibility

### Complex Sources

- Use built-in ones
- Build your own (1DM/3DM)
- Simulate them (GAM)

### Seed Mixing

In [None]:
"""Seed mixing

More details: https://www.osti.gov/biblio/2335905

Nothing is built in to PyRIID to estimate alphas from data,
but if you happen to have proportions, it is a pretty straightforward
maximum likelihood estimation problem.
"""
_, bg_seeds = simple_seeds.split_fg_and_bg()
alphas = [1, 3, 3, 3]
mixer = SeedMixer(
    bg_seeds,
    mixture_size=bg_seeds.n_samples,
    dirichlet_alpha=alphas,
)
mixed_bg_seeds = mixer.generate(100)

print(mixed_bg_seeds.sources.mean())
mixed_bg_seeds.sources.iloc[:, 1].hist()
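
The role of the alphas can be illustrated without PyRIID. This is a minimal sketch, assuming mixing amounts to drawing proportions from a Dirichlet distribution and taking the corresponding weighted sum of normalized seeds; the seed spectra are random placeholders, and `SeedMixer` may differ in its details.

```python
import numpy as np

rng = np.random.default_rng(42)
alphas = np.array([1.0, 3.0, 3.0, 3.0])

# Four made-up, L1-normalized "background seeds" with 8 channels each
seeds = rng.random((4, 8))
seeds /= seeds.sum(axis=1, keepdims=True)

# One row of mixing proportions per mixture; every row sums to 1
proportions = rng.dirichlet(alphas, size=1000)

# Each mixture is a convex combination of the seeds, so it also sums to 1
mixtures = proportions @ seeds

# The average proportions approach alphas / alphas.sum()
print(proportions.mean(axis=0))
print(alphas / alphas.sum())
```

Scaling all alphas up by the same factor keeps the same mean proportions but concentrates the draws more tightly around them.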

In [None]:
"""Another way to visualize mixtures"""
barh_kwargs = {
    "height": 1.0,
    "edgecolor": "black",
    "linewidth": 0.5,
}
bar_x = np.arange(mixed_bg_seeds.n_samples)+1
props = mixed_bg_seeds.sources.to_numpy(float)
cols = mixed_bg_seeds.sources.columns.get_level_values("Seed")

fig, ax = plt.subplots(figsize=(8, 8), sharey=True, sharex=True)

ax.barh(bar_x, props[:,0], left=0, label=cols[0], **barh_kwargs)
ax.barh(bar_x, props[:,1], left=props[:,0], label=cols[1], **barh_kwargs)
ax.barh(bar_x, props[:,2], left=props[:,:2].sum(axis=1), label=cols[2], **barh_kwargs)
ax.barh(bar_x, props[:,3], left=props[:,:3].sum(axis=1), label=cols[3], **barh_kwargs)

ax.set_title(rf"$\alpha$ = {alphas}")
ax.set_xlabel("Partitions")
ax.set_xlim((0, 1))
ax.set_ylim((0, mixed_bg_seeds.n_samples))
ax.set_ylabel("Sample #")
ax.legend()
fig.tight_layout()
plt.show()

### Static Synthesis

In [None]:
"""Static synthesis

Static synthesis takes your foreground seeds and your background seeds and adds them together, randomly capturing variation in:
- live time
- signal-to-noise ratio (SNR)
- background count rate (effectively)
- Poisson fluctuations

You can obtain foreground, background, or gross spectra.
"""
static_synth = StaticSynthesizer(
    samples_per_seed=100,
    bg_cps=300.0,
    live_time_function="uniform",
    live_time_function_args=(0.25, 8),
    snr_function="log10",
    snr_function_args=(1, 200),
    long_bg_live_time=120,  # adjust this to make background subtraction worse
    return_fg=True,
    return_gross=False,
)
foregrounds, _ = static_synth.generate(fg_seeds, mixed_bg_seeds)
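
For intuition, one sample of this process can be sketched by hand. Everything below is an assumption made for illustration rather than a description of `StaticSynthesizer` internals: the seeds are synthetic shapes, SNR is taken to mean foreground counts divided by the square root of background counts, and the background estimate comes from a separate long measurement scaled to the sample's live time.

```python
import numpy as np

rng = np.random.default_rng(7)
n_channels = 64
channels = np.arange(n_channels)

# Made-up L1-normalized seeds: a peaked foreground and a sloped background
fg_seed = np.exp(-0.5 * ((channels - 40) / 2.0) ** 2)
fg_seed /= fg_seed.sum()
bg_seed = np.exp(-channels / 20.0)
bg_seed /= bg_seed.sum()

bg_cps = 300.0
long_bg_live_time = 120.0
live_time = rng.uniform(0.25, 8)
# "log10" sampling: uniform in log10(SNR) between 1 and 200
snr = 10 ** rng.uniform(np.log10(1), np.log10(200))

# Assumed SNR definition: fg_counts / sqrt(bg_counts)
expected_bg_counts = bg_cps * live_time
expected_fg_counts = snr * np.sqrt(expected_bg_counts)

# Short gross measurement with Poisson fluctuations
gross = rng.poisson(expected_fg_counts * fg_seed + expected_bg_counts * bg_seed)
# Separate long background measurement, rescaled to the short live time
long_bg = rng.poisson(bg_cps * long_bg_live_time * bg_seed)
bg_estimate = long_bg * (live_time / long_bg_live_time)

# Background-subtracted foreground; individual channels can go negative
fg_estimate = gross - bg_estimate
print(f"live time = {live_time:.2f} s, SNR = {snr:.1f}")
print(f"foreground counts: expected {expected_fg_counts:.0f}, estimated {fg_estimate.sum():.0f}")
```

Under these assumptions, a shorter long-background measurement gives a noisier `bg_estimate`, which is one way background subtraction gets worse.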

In [None]:
"""Sample our dataset to plot"""
_ = plot_spectra(foregrounds.sample(3), in_energy=True)  # Note the negatives

In [None]:
"""Inspect the dataset"""
foregrounds.info

## Model Fitting

There are many different models one can fit.
In this course, due to limited time, we will fit one type of classifier (with a simple architecture) and study it.
The goal is to demonstrate basic principles for training, using, and testing models that are not limited to "simple" ones.

### Our problem

1. Our detector setup is static, with parameters defined in our seed config
1. Every spectrum we measure will be background subtracted
1. We could observe a wide variety of NORM, medical, and industrial sources, unshielded and shielded.
   - Ideally, this is informed by an SME and we iteratively improve our model over time as we get more info
1. We would like to classify measurements in reasonably detailed terms (we'll target isotope, collapsing specific configurations)
1. We want a basis for "confidence"
1. We would like to characterize some out-of-distribution behavior
   1. Did we generalize to SNR change?
   1. Did we generalize to OOD sources?


### Data Synthesis

In [None]:
"""Seeds"""
with open(PROBLEM_SEED_CONFIG_PATH) as fin:
    seed_synth_config = yaml.safe_load(fin)
seed_synth = SeedSynthesizer()
seeds = seed_synth.generate(seed_synth_config, verbose=True)
seeds.to_hdf(SEEDS_PATH)

In [None]:
"""Inspect seeds"""
print(f"Maximum dead time present: {seeds.info.dead_time_prop.max():.4f}")
print(f"# of distinct seeds:       {seeds.n_samples}")
print(f"# of distinct isotopes:    {seeds.get_labels().unique().shape[0]}")
_ = plot_spectra(seeds[:7], in_energy=True, target_level="Seed")

In [43]:
"""Split and mix"""
seeds = read_hdf(SEEDS_PATH)
# Downsample at the seed stage!  It speeds everything up.
seeds.downsample_spectra(128)
fg_seeds, bg_seeds = seeds.split_fg_and_bg()
fg_seeds.to_hdf(FG_SEEDS_PATH)
bg_seeds.to_hdf(BG_SEEDS_PATH)

mixer = SeedMixer(
    bg_seeds,
    mixture_size=bg_seeds.n_samples,
    dirichlet_alpha=2,
)
mixed_bg_seeds = mixer.generate(1)
mixed_bg_seeds.to_hdf(MIXED_BG_SEEDS_PATH)

In [None]:
"""Static synthesis"""
mixed_bg_seeds = read_hdf(MIXED_BG_SEEDS_PATH)
fg_seeds = read_hdf(FG_SEEDS_PATH)

static_synth = StaticSynthesizer(
    samples_per_seed=500,
    bg_cps=300,
    live_time_function="uniform",
    live_time_function_args=(1, 10),
    snr_function="log10",
    snr_function_args=(MIN_SNR, MAX_SNR),
    long_bg_live_time=120,  # adjust this to make background subtraction worse
    return_fg=True,
    return_gross=False,
)

foregrounds, _ = static_synth.generate(fg_seeds, mixed_bg_seeds)
foregrounds.to_hdf(TRAIN_PATH)

static_synth.samples_per_seed //= 4
foregrounds, _ = static_synth.generate(fg_seeds, mixed_bg_seeds)
foregrounds.to_hdf(IND_TEST_PATH)

### Training

In [45]:
"""Load and pre-process training data"""
def load_and_preprocess_data(path):
    """This function standardizes how we load and pre-process data,
    reducing the chance of a bug later.
    """
    data = read_hdf(path)
    data.normalize(p=1)
    return data


training_data = load_and_preprocess_data(TRAIN_PATH)

In [None]:
"""Training"""
model = MLPClassifier(dense_layer_size=64, dropout=0.8)
history = model.fit(training_data, target_level="Isotope", epochs=200, verbose=True)

In [None]:
"""Learning curve"""
from riid.visualize import plot_learning_curve

# _ = plot_learning_curve(history.history["loss"], [0])
_ = plot_learning_curve(history.history["loss"], history.history["val_loss"])

In [None]:
"""Model architecture summary"""
model.model.summary()

In [None]:
"""Save model
There are multiple ways to "save" a model out as a file:

1. JSON: a PyRIID-specific format enabling you to load the model back in with PyRIID later.
   This format is useful because all information about the model, including metadata, is encapsulated in one place.
2. ONNX: Open Neural Network Exchange format, an open format for machine learning models.
3. TFLite: TensorFlow Lite format, useful for targeting TF runtimes in various places.
"""

model.save(MODEL_JSON_PATH)
model.to_onnx(MODEL_ONNX_PATH)
model.to_tflite(MODEL_TFLITE_PATH)

In [59]:
"""Load model"""
model = MLPClassifier()
model.load(MODEL_JSON_PATH)

### Testing

In [60]:
"""Load in-distribution (IND) test data"""
testing_data = load_and_preprocess_data(IND_TEST_PATH)

In [None]:
"""Predict IND data"""
model.predict(testing_data)
testing_data.get_predictions(include_value=True)

In [None]:
"""Score IND test data"""
from sklearn.metrics import f1_score

f1_score(testing_data.get_labels(),
         testing_data.get_predictions(),
         average="micro")
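
A note on the `average="micro"` choice: when every sample has exactly one true label and one predicted label, micro-averaged F1 reduces to plain accuracy. The pure-Python sketch below shows this on made-up labels; `micro_f1` is a hypothetical helper, not the scikit-learn function.

```python
def micro_f1(y_true, y_pred):
    """Micro-averaged F1: pool TP/FP/FN over all classes before computing F1."""
    labels = set(y_true) | set(y_pred)
    tp = fp = fn = 0
    for label in labels:
        for t, p in zip(y_true, y_pred):
            tp += t == label and p == label
            fp += t != label and p == label
            fn += t == label and p != label
    return 2 * tp / (2 * tp + fp + fn)


# Made-up labels for illustration
y_true = ["Am241", "Ba133", "Ba133", "Y88", "Ra226"]
y_pred = ["Am241", "Ba133", "Y88", "Y88", "Y88"]

accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
print(micro_f1(y_true, y_pred), accuracy)
```

Every misclassification counts once as a false positive (for the predicted class) and once as a false negative (for the true class), which is why the two numbers agree.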

In [None]:
"""Confusion matrix"""
from riid.visualize import confusion_matrix

cm_kwargs = {
    "as_percentage": True,
    "figsize": (14, 14),
}
_ = confusion_matrix(testing_data, **cm_kwargs)  # We should dispel any notion of a "perfect" model--this looks good
_ = confusion_matrix(testing_data[testing_data.info.snr > 10], **cm_kwargs)

In [None]:
"""Model score vs. SNR"""
from riid.visualize import plot_snr_vs_score

_ = plot_snr_vs_score(testing_data)

### Investigating sources

In [None]:
"""Investigate specific Y88 performance and plot"""
from sklearn.metrics import recall_score

y88_label_mask = testing_data.get_labels() == "Y88"
y88_testing_data = testing_data[y88_label_mask]
y88_seed_labels = y88_testing_data.get_labels("Seed")
unique_y88_seed_labels = y88_seed_labels.unique()
for l in unique_y88_seed_labels:
    y88_config_testing_data = y88_testing_data[y88_seed_labels == l]
    recall = recall_score(
        y88_config_testing_data.get_labels(),
        y88_config_testing_data.get_predictions(),
        average="micro",
    )
    print(f"{l} recall = {recall:.4f}")

plot_data = SampleSet()
fg_seed_labels = fg_seeds.get_labels("Seed")
fg_isotope_labels = fg_seeds.get_labels()
plot_data.concat([
    fg_seeds[fg_seed_labels == "Y88,100uC {10,50}"],
    fg_seeds[fg_seed_labels == "Y88,100uC {26,30}"],
    fg_seeds[fg_isotope_labels == "Ra226"],
])
_ = plot_spectra(plot_data, target_level="Seed")

In [None]:
"""Investigate Y88 similarity to Ra226"""
from scipy.spatial.distance import jensenshannon

y88_seeds = fg_seeds[fg_isotope_labels == "Y88"]
y88_seed_labels = y88_seeds.get_labels("Seed")
y88_spectra = y88_seeds.spectra.to_numpy()
target_y88_seed_label = "Y88,100uC {10,50}"
target_y88_spectrum = fg_seeds[fg_seed_labels == target_y88_seed_label].spectra.to_numpy()[0]

ra226_seeds = fg_seeds[fg_isotope_labels == "Ra226"]
ra226_seed_labels = ra226_seeds.get_labels("Seed")
ra226_spectra = ra226_seeds.spectra.to_numpy()

def calculate_and_print_jsd(spec1, spec1_label, spec2, spec2_label):
    jsd = jensenshannon(spec1, spec2)
    print(f"{spec1_label} <-> {spec2_label}".ljust(45), f"= {jsd:.3f}")

for i, l in enumerate(ra226_seed_labels):
    calculate_and_print_jsd(target_y88_spectrum, target_y88_seed_label, ra226_spectra[i], l)

for i, l in enumerate(y88_seed_labels):
    calculate_and_print_jsd(target_y88_spectrum, target_y88_seed_label, y88_spectra[i], l)
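
When reading these numbers, it helps to know that SciPy's `jensenshannon` returns the Jensen-Shannon distance, i.e., the square root of the divergence, using the natural logarithm by default and normalizing its inputs. The NumPy sketch below reproduces that calculation for non-negative inputs; `js_distance` is a hypothetical helper and the spectra are toy values.

```python
import numpy as np


def js_distance(p, q):
    """Jensen-Shannon distance (natural log) between two non-negative vectors."""
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    p = p / p.sum()
    q = q / q.sum()
    m = 0.5 * (p + q)

    def kl(a, b):
        mask = a > 0  # terms with a == 0 contribute nothing
        return np.sum(a[mask] * np.log(a[mask] / b[mask]))

    divergence = 0.5 * kl(p, m) + 0.5 * kl(q, m)
    return np.sqrt(max(divergence, 0.0))  # guard against tiny negative rounding


same = js_distance([1, 2, 3], [1, 2, 3])
disjoint = js_distance([1, 0, 0], [0, 0, 1])
print(same, disjoint, np.sqrt(np.log(2)))
```

Identical spectra give 0, and spectra with no overlapping channels give the maximum of sqrt(ln 2), roughly 0.83, so values between seeds should be read on that scale.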

### Out-of-distribution (OOD) Detection

In [67]:
"""Build some post-processing data"""
fg_seeds = read_hdf(FG_SEEDS_PATH)
fg_seeds.downsample_spectra(128)
# We specifically do NOT want to use `sources` for confidence as we won't have that in practice
sample_jsds = testing_data.get_multiclass_jsds(fg_seeds, model.target_level)

def multiclass_jsds_to_top_jsd(jsds):
    post = []
    for d in jsds:
        min_key = min(d, key=d.get)
        post.append((min_key, d[min_key]))
    post_df = pd.DataFrame(post, columns=["seed", "jsd"])
    return post_df

post_df = multiclass_jsds_to_top_jsd(sample_jsds)
post_df["model_proba"] = testing_data.prediction_probas.T.groupby(model.target_level).sum().max()
post_df["snr"] = testing_data.info.snr

In [None]:
"""Model probability vs. SNR"""
fig, ax = plt.subplots()
ax.scatter(post_df.snr, post_df.model_proba)
ax.set_xlabel("SNR")
ax.set_ylabel("Model probability")
plt.show()

In [None]:
"""JSD vs. SNR"""
fig, ax = plt.subplots()
ax.scatter(post_df.snr, post_df.jsd)
ax.set_xlabel("SNR")
ax.set_ylabel("JSD (sample vs. top seed)")
plt.show()

In [None]:
"""JSD histogram"""
post_df.jsd.hist()

In [None]:
"""JSD vs. model probability"""
fig, ax = plt.subplots()

ax.scatter(post_df.jsd, post_df.model_proba)
ax.set_xlabel("JSD")
ax.set_ylabel("Model output")
# ax.set_xscale("log")
# ax.set_yscale("log")
ax.set_xlim((0, 1))
plt.show()

In [None]:
"""Confidence using an out-of-distribution (OOD) detector (a binary classifier)

For an OOD binary classifier, the positive class is OOD and negative class is in-distribution (IND).
As such, a false positive (FP) corresponds to calling an IND sample OOD.
In practice, we typically only start out knowing the behavior of our model and in-distribution (synthetic) data.
Therefore, the OOD detector, at least for now, must be based on observing deviations from IND data, i.e., negative samples (in this context).
To do this, we threshold on the true negative rate (TNR), which is 1 minus our desired false positive rate (the fraction of IND samples we accept flagging as OOD).
"""
from scipy.interpolate import UnivariateSpline

N_QUANTILES = 10
TARGET_FP_RATE = 0.001
TARGET_TNR = 1 - TARGET_FP_RATE
SPLINE_K = 2
SPLINE_S = 0
snrs = post_df.snr
jsds = post_df.jsd

snr_buckets = pd.qcut(snrs, N_QUANTILES, labels=False)
bucket_thresholds = [
    np.quantile(np.array(jsds)[snr_buckets == int(i)], TARGET_TNR)
    for i in range(N_QUANTILES)
]
median_snrs = [
    np.median(np.array(snrs)[snr_buckets == int(i)])
    for i in range(N_QUANTILES)
]
spline = UnivariateSpline(
    median_snrs,
    bucket_thresholds,
    k=SPLINE_K,
    s=SPLINE_S,
)  # The spline is a function which takes an SNR and returns the JSD representing our targeted FPR threshold

is_ood = jsds > spline(snrs)
post_df["ood"] = is_ood
print(f"Target FPR:   {TARGET_FP_RATE:.4f}")
print(f"Observed FPR: {post_df.ood.sum() / post_df.shape[0]:.4f}")
post_df
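
The thresholding logic can be tried on synthetic data without PyRIID. This is a minimal sketch with several substitutions, all of them assumptions: the scores are random numbers that shrink with SNR (standing in for JSDs of IND samples), `np.quantile` bucket edges replace `pd.qcut`, and piecewise-linear interpolation via `np.interp` replaces the spline.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 20000
snr = 10 ** rng.uniform(np.log10(5), np.log10(300), size=n)
# Made-up IND scores that get smaller as SNR grows
score = rng.gamma(shape=2.0, scale=0.05, size=n) / np.sqrt(snr)

N_BUCKETS = 10
TARGET_FPR = 0.01

# Equal-population SNR buckets
edges = np.quantile(snr, np.linspace(0, 1, N_BUCKETS + 1))
bucket = np.clip(np.searchsorted(edges, snr, side="right") - 1, 0, N_BUCKETS - 1)

# Per-bucket score threshold at the target true negative rate
thresholds = np.array(
    [np.quantile(score[bucket == b], 1 - TARGET_FPR) for b in range(N_BUCKETS)]
)
centers = np.array([np.median(snr[bucket == b]) for b in range(N_BUCKETS)])

# Interpolate thresholds between bucket centers (in log-SNR)
threshold_at = np.interp(np.log10(snr), np.log10(centers), thresholds)
is_ood = score > threshold_at
observed = is_ood.mean()

print(f"Target FPR:   {TARGET_FPR:.4f}")
print(f"Observed FPR: {observed:.4f}")
```

The observed rate lands near the target rather than exactly on it, because the interpolated threshold only matches the per-bucket quantile at the bucket centers.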

### OOD SNR Behavior

In [None]:
"""Generate some OOD SNR data.

In practical terms, we know very little about positive classes.
That is why we almost have to construct the OOD detector using IND samples.
But there is some more we can do.
"""
mixed_bg_seeds = read_hdf(MIXED_BG_SEEDS_PATH)
fg_seeds = read_hdf(FG_SEEDS_PATH)

OOD_MIN_SNR = 0.01
OOD_MAX_SNR = 1000
static_synth = StaticSynthesizer(
    samples_per_seed=100,
    bg_cps=300,
    live_time_function="uniform",
    live_time_function_args=(1, 10),
    snr_function="log10",
    snr_function_args=(OOD_MIN_SNR, OOD_MAX_SNR),
    long_bg_live_time=120,  # adjust this to make background subtraction worse
    return_fg=True,
    return_gross=False,
)
ood_snr_data, _ = static_synth.generate(fg_seeds, mixed_bg_seeds)
ood_snr_data.normalize()

In [119]:
"""Predict OOD SNR data"""
model.predict(ood_snr_data)

In [None]:
"""OOD SNR performance"""
f1_score(ood_snr_data.get_labels(), ood_snr_data.get_predictions(), average="micro")  # Things are worse, of course

In [121]:
"""Compute JSDs"""
ood_snr_data_jsds = ood_snr_data.get_multiclass_jsds(fg_seeds, model.target_level)

In [None]:
"""Get top JSDs"""
new_post_df = multiclass_jsds_to_top_jsd(ood_snr_data_jsds)
new_post_df["model_proba"] = ood_snr_data.prediction_probas.max(axis=1)
new_post_df["snr"] = ood_snr_data.info.snr
new_post_df["ood"] = (new_post_df.jsd > spline(new_post_df.snr)) | ~new_post_df.snr.between(MIN_SNR, MAX_SNR)
print(new_post_df.ood.value_counts())

In [None]:
"""Plot OOD and IND vs. SNR"""
ood_samples = new_post_df[new_post_df.ood]
ind_samples = new_post_df[~new_post_df.ood]

ALPHA = 0.3
fig, ax = plt.subplots()
ax.scatter(ood_samples.snr, ood_samples.jsd, color="black", label="OOD", alpha=ALPHA, marker="x")
ax.scatter(ind_samples.snr, ind_samples.jsd, color="blue", label="IND", alpha=ALPHA, marker=".")
ax.vlines([MIN_SNR, MAX_SNR], 0, 1, label="IND SNR range", color="red", linestyle="solid")
# Plot spline
snr_range = np.logspace(np.log10(MIN_SNR), np.log10(MAX_SNR), num=100)
ax.plot(snr_range, spline(snr_range), color="red", linestyle="dashed", label="Spline decision threshold")

ax.set_yscale("log")
ax.set_xscale("log")
ax.set_xlabel("SNR")
ax.set_ylabel("JSD")
ax.set_ylim((new_post_df.jsd.min(), 1))
ax.legend()
plt.show()

### OOD Source Behavior

In [None]:
"""Generate OOD source data"""
# Here, we're using individual background components (K, U, T, and Cosmic) as OOD sources
mixed_bg_seeds = read_hdf(MIXED_BG_SEEDS_PATH)
bg_seeds = read_hdf(BG_SEEDS_PATH)

ood_src_data, _ = static_synth.generate(bg_seeds, mixed_bg_seeds)
ood_src_data.normalize()

In [172]:
"""Predict"""
model.predict(ood_src_data)

In [175]:
"""Compute JSDs"""
ood_src_data_jsds = ood_src_data.get_multiclass_jsds(fg_seeds, model.target_level)

In [None]:
"""Get top JSDs"""
new_post_df = multiclass_jsds_to_top_jsd(ood_src_data_jsds)
new_post_df["model_proba"] = ood_src_data.prediction_probas.max(axis=1)
new_post_df["snr"] = ood_src_data.info.snr
new_post_df["ood"] = (new_post_df.jsd > spline(new_post_df.snr)) | ~new_post_df.snr.between(MIN_SNR, MAX_SNR)
print(new_post_df.ood.value_counts())  # what??? they should all be OOD

In [None]:
"""Plot OOD and IND vs. SNR"""
ood_samples = new_post_df[new_post_df.ood]
ind_samples = new_post_df[~new_post_df.ood]

ALPHA = 0.3
fig, ax = plt.subplots()
ax.scatter(ood_samples.snr, ood_samples.jsd, color="black", label="OOD", alpha=ALPHA, marker="x")
ax.scatter(ind_samples.snr, ind_samples.jsd, color="blue", label="IND", alpha=ALPHA, marker=".")
ax.vlines([MIN_SNR, MAX_SNR], 0, 1, label="IND SNR range", color="red", linestyle="solid")
# Plot spline
snr_range = np.logspace(np.log10(MIN_SNR), np.log10(MAX_SNR), num=100)
ax.plot(snr_range, spline(snr_range), color="red", linestyle="dashed", label="Spline decision threshold")

ax.set_yscale("log")
ax.set_xscale("log")
ax.set_xlabel("SNR")
ax.set_ylabel("JSD")
ax.set_ylim((new_post_df.jsd.min(), 1))
ax.legend()
plt.show()

# Here we can see that at > ~100 SNR, background components K, U, T, and cosmic reliably fall OOD.
# We can also see that each background component diverges differently due to its distinct features.