    Copyright 2024 National Technology & Engineering Solutions of Sandia, LLC (NTESS). 
    Under the terms of Contract DE-NA0003525 with NTESS, the U.S. Government retains certain rights in this software.

**Step 9 - train/evaluate model on 8192 channel data**

In [None]:
import os
import time

import numpy as np
import tensorflow as tf
from config import (BG_SEED_FILE, DATA_DIR, FG_SEED_FILE, MIX_TEST_FG_SAMPLES,
                    MIX_TRAIN_BG_ALPHA, MIX_TRAIN_BG_SAMPLES,
                    MIX_TRAIN_BG_SIZE, MIX_TRAIN_FG_LAMBDA,
                    MIX_TRAIN_FG_SAMPLES, MODEL_ACTIVITY_L1_REG,
                    MODEL_ACTIVITY_L2_REG, MODEL_BATCH_SIZE, MODEL_BETA,
                    MODEL_DIR, MODEL_DROPOUT, MODEL_EPOCHS,
                    MODEL_HIDDEN_LAYER_ACTIVATION, MODEL_HIDDEN_LAYERS,
                    MODEL_INIT_LR, MODEL_KERNEL_L1_REG, MODEL_KERNEL_L2_REG,
                    MODEL_LR_SCHED_MIN_DELTA, MODEL_LR_SCHED_PATIENCE,
                    MODEL_MIN_DELTA, MODEL_NORMALIZE_SUP_LOSS, MODEL_OOD_FPR,
                    MODEL_OPTIMIZER, MODEL_OPTIMIZER_KWARGS, MODEL_PATIENCE,
                    MODEL_SPLINE_BINS, MODEL_SPLINE_K, MODEL_SPLINE_S,
                    MODEL_SUP_LOSS, MODEL_SUP_NORMALIZE_SCALER,
                    MODEL_TARGET_LEVEL, MODEL_TRAIN_ON_GPU,
                    MODEL_TRAIN_ON_STATIC_SYNTH_DATA, MODEL_UNSUP_LOSS,
                    MODEL_VAL_SPLIT, OOD_SEED_FILE, RANDOM_SEED,
                    STATIC_SYNTH_BG_CPS, STATIC_SYNTH_LIVE_TIME_RANGE,
                    STATIC_SYNTH_LIVE_TIME_SAMPLING, STATIC_SYNTH_LONG_BG_CPS,
                    STATIC_SYNTH_SNR_RANGE, STATIC_SYNTH_SNR_SAMPLING,
                    STATIC_SYNTH_SPS, STATIC_SYNTH_TEST_SNR_RANGE, TARGET_ECAL)
from riid.data.sampleset import read_hdf
from riid.data.synthetic.seed import SeedMixer
from riid.data.synthetic.static import StaticSynthesizer
from riid.models.neural_nets import LabelProportionEstimator
from sklearn.metrics import mean_absolute_error as mae
from utils import SaveTruthsandPredictionsCallback

In [None]:
# Number of channels every spectrum is downsampled to in this step.
TARGET_BINS = 8192
# Suffix appended to every artifact written/read by this step. Derived from
# TARGET_BINS so the two cannot drift apart (evaluates to "_8192channels").
file_identifier = f"_{TARGET_BINS}channels"

# All data files live under DATA_DIR and carry the suffix above. TRAIN_FILE
# previously lacked it ("train.h5"), unlike every sibling path here.
BG_MIX_FILE = os.path.join(DATA_DIR, f"bg_mixtures{file_identifier}.h5")
TRAIN_FILE = os.path.join(DATA_DIR, f"train{file_identifier}.h5")
TRAIN_FG_MIX_FILE = os.path.join(DATA_DIR, f"train_fg_mixtures{file_identifier}.h5")
TEST_FILE = os.path.join(DATA_DIR, f"test{file_identifier}.h5")
TEST_FG_MIX_FILE = os.path.join(DATA_DIR, f"test_fg_mixtures{file_identifier}.h5")

In [None]:
"""Generate IND data for 8192  channels"""
# Purpose of this cell: build the background mixtures, the training set and the
# test set at TARGET_BINS channels and write each of them to its HDF5 file.
# Statement order matters: `rng` is shared by both static_syn.generate() calls
# and several SampleSet methods below are called for their in-place effect.

# Set rng
# Seeded generator handed to the StaticSynthesizer below.
rng = np.random.default_rng(seed=RANDOM_SEED)

# Load in seeds
# split_fg_and_bg() returns a pair; the first element is kept as the
# foreground seeds here and the second as the background seeds further down.
fg_seeds_ss = read_hdf(FG_SEED_FILE)
fg_seeds_ss, _ = fg_seeds_ss.split_fg_and_bg()
fg_seeds_ss.drop_sources_columns_with_all_zeros()

bg_seeds_ss = read_hdf(BG_SEED_FILE)
_, bg_seeds_ss = bg_seeds_ss.split_fg_and_bg()
bg_seeds_ss.drop_sources_columns_with_all_zeros()

# Get expected source contributions
# Key: text before the first comma of each "Seed"-level column label.
# Value: the matching entry of info.total_counts.
# NOTE(review): labels sharing the same pre-comma prefix would collapse into
# one dict entry, making Z shorter than fg_seeds_ss.n_samples — confirm the
# prefixes are unique.
source_counts = {
    x.split(",")[0]: v
    for x, v in zip(
        fg_seeds_ss.sources.columns.get_level_values("Seed").values,
        fg_seeds_ss.info.total_counts
    )
}
Z = np.array(list(source_counts.values()))
# Each source's share of the summed total counts; scaled by
# MIX_TRAIN_FG_LAMBDA below to form the Dirichlet alpha for foreground mixing.
expected_props = Z / Z.sum()

# Preprocessing
# Both seed sets get the same treatment: move to TARGET_ECAL, downsample to
# TARGET_BINS channels, normalize. as_ecal() returns a new object (it is
# re-assigned); the other two calls are used for their in-place effect.
fg_seeds_ss = fg_seeds_ss.as_ecal(*TARGET_ECAL)
fg_seeds_ss.downsample_spectra(target_bins=TARGET_BINS)
fg_seeds_ss.normalize()

bg_seeds_ss = bg_seeds_ss.as_ecal(*TARGET_ECAL)
bg_seeds_ss.downsample_spectra(target_bins=TARGET_BINS)
bg_seeds_ss.normalize()

# One synthesizer is reused for train and test; only snr_function_args is
# swapped before the test set is generated.
# NOTE(review): long_bg_live_time is fed a constant whose name ends in
# "_CPS" — confirm in config that the value really is a live time.
static_syn = StaticSynthesizer(
    samples_per_seed=STATIC_SYNTH_SPS,
    bg_cps=STATIC_SYNTH_BG_CPS,
    live_time_function=STATIC_SYNTH_LIVE_TIME_SAMPLING,
    live_time_function_args=STATIC_SYNTH_LIVE_TIME_RANGE,
    snr_function=STATIC_SYNTH_SNR_SAMPLING,
    snr_function_args=STATIC_SYNTH_SNR_RANGE,
    long_bg_live_time=STATIC_SYNTH_LONG_BG_CPS,
    rng=rng
)

# Background
# Mixed background seeds are written to disk and shared by train and test.
mixed_bg_seeds_ss = SeedMixer(
    bg_seeds_ss,
    mixture_size=MIX_TRAIN_BG_SIZE,
    dirichlet_alpha=MIX_TRAIN_BG_ALPHA,
    random_state=RANDOM_SEED
).generate(MIX_TRAIN_BG_SAMPLES)
mixed_bg_seeds_ss.to_hdf(BG_MIX_FILE)

# Train
# mixture_size equals the number of foreground seeds, so each mixture is
# asked to combine as many seeds as exist.
train_mixed_fg_seeds_ss = SeedMixer(
    fg_seeds_ss,
    mixture_size=fg_seeds_ss.n_samples,
    dirichlet_alpha=expected_props * MIX_TRAIN_FG_LAMBDA,
    random_state=RANDOM_SEED
).generate(MIX_TRAIN_FG_SAMPLES)
train_mixed_fg_seeds_ss.to_hdf(TRAIN_FG_MIX_FILE)
# generate() returns a pair; only the first element is kept.
train_ss, _ = static_syn.generate(
    fg_seeds_ss=train_mixed_fg_seeds_ss,
    bg_seeds_ss=mixed_bg_seeds_ss
)
train_ss.drop_spectra_with_no_contributors()
train_ss.clip_negatives()
train_ss.to_hdf(TRAIN_FILE)

# Test
# Same synthesizer, but with the test SNR range from here on.
static_syn.snr_function_args = STATIC_SYNTH_TEST_SNR_RANGE
# NOTE(review): the test mixer uses the same RANDOM_SEED and the *train*
# lambda (MIX_TRAIN_FG_LAMBDA) as the train mixer above. If SeedMixer is
# deterministic for a given random_state, the test mixtures may repeat the
# training ones — confirm this is intended.
test_mixed_fg_seeds_ss = SeedMixer(
    fg_seeds_ss,
    mixture_size=fg_seeds_ss.n_samples,
    dirichlet_alpha=expected_props * MIX_TRAIN_FG_LAMBDA,
    random_state=RANDOM_SEED
).generate(MIX_TEST_FG_SAMPLES)
test_mixed_fg_seeds_ss.to_hdf(TEST_FG_MIX_FILE)
test_ss, _ = static_syn.generate(
    fg_seeds_ss=test_mixed_fg_seeds_ss,
    bg_seeds_ss=mixed_bg_seeds_ss
)
test_ss.drop_spectra_with_no_contributors()
test_ss.clip_negatives()
test_ss.to_hdf(TEST_FILE)

In [None]:
"""Train model on 8192 channel data."""
# Purpose of this cell: reload and preprocess the seeds, load the train/test
# sets from disk, fit a LabelProportionEstimator and save it under MODEL_DIR.

# Restrict the CUDA devices visible to this process ("-1" hides all of them).
# NOTE(review): this runs after `import tensorflow`; it presumably only takes
# effect if TensorFlow has not touched the GPU yet in this kernel — confirm.
if MODEL_TRAIN_ON_GPU:
    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"  # optionally select valid gpus
else:
    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

# The timestamp keeps model/callback artifacts from different runs apart.
time_str = time.strftime("%Y%m%d-%H%M%S")
model_name = f"lpe{file_identifier}_{MODEL_SUP_LOSS}_{MODEL_UNSUP_LOSS}_{MODEL_BETA}_{time_str}"

# Seeds are re-read from disk so this cell does not depend on the
# data-generation cell having been run in the same session.
fg_seeds_ss = read_hdf(FG_SEED_FILE)
fg_seeds_ss, _ = fg_seeds_ss.split_fg_and_bg()
fg_seeds_ss.drop_sources_columns_with_all_zeros()
ood_seeds_ss = read_hdf(OOD_SEED_FILE)
ood_seeds_ss, _ = ood_seeds_ss.split_fg_and_bg()
ood_seeds_ss.drop_sources_columns_with_all_zeros()

# Same preprocessing for both seed sets: TARGET_ECAL, TARGET_BINS channels,
# then normalize.
fg_seeds_ss = fg_seeds_ss.as_ecal(*TARGET_ECAL)
fg_seeds_ss.downsample_spectra(target_bins=TARGET_BINS)
fg_seeds_ss.normalize()
ood_seeds_ss = ood_seeds_ss.as_ecal(*TARGET_ECAL)
ood_seeds_ss.downsample_spectra(target_bins=TARGET_BINS)
ood_seeds_ss.normalize()

run_dir = MODEL_DIR
callback_dir = os.path.join(run_dir, "callbacks")

# exist_ok=True replaces the separate existence check: no race between the
# check and the creation, and run_dir is created as a parent when missing.
os.makedirs(callback_dir, exist_ok=True)

# Train either on the synthesized spectra or directly on the foreground
# seed mixtures, depending on the config flag.
if MODEL_TRAIN_ON_STATIC_SYNTH_DATA:
    train_ss = read_hdf(TRAIN_FILE)
    test_ss = read_hdf(TEST_FILE)
else:
    train_ss = read_hdf(TRAIN_FG_MIX_FILE)
    test_ss = read_hdf(TEST_FG_MIX_FILE)

# fit_spline follows the same flag as the data choice above, so the spline is
# only fitted when training on static synthetic data.
model = LabelProportionEstimator(
    hidden_layers=MODEL_HIDDEN_LAYERS,
    sup_loss=MODEL_SUP_LOSS,
    unsup_loss=MODEL_UNSUP_LOSS,
    beta=MODEL_BETA,
    fg_dict=None,
    optimizer=MODEL_OPTIMIZER,
    optimizer_kwargs=MODEL_OPTIMIZER_KWARGS,
    learning_rate=MODEL_INIT_LR,
    metrics=["mae"],
    hidden_layer_activation=MODEL_HIDDEN_LAYER_ACTIVATION,
    kernel_l1_regularization=MODEL_KERNEL_L1_REG,
    kernel_l2_regularization=MODEL_KERNEL_L2_REG,
    activity_l1_regularization=MODEL_ACTIVITY_L1_REG,
    activity_l2_regularization=MODEL_ACTIVITY_L2_REG,
    dropout=MODEL_DROPOUT,
    target_level=MODEL_TARGET_LEVEL,
    bg_cps=STATIC_SYNTH_BG_CPS,
    fit_spline=MODEL_TRAIN_ON_STATIC_SYNTH_DATA,
    ood_fp_rate=MODEL_OOD_FPR,
    spline_bins=MODEL_SPLINE_BINS,
    spline_k=MODEL_SPLINE_K,
    spline_s=MODEL_SPLINE_S
)

# Callbacks: cut the learning rate by 10x when val_loss plateaus, and hand the
# test spectra/sources to the project callback that writes into callback_dir.
callbacks = [
    tf.keras.callbacks.ReduceLROnPlateau(
        monitor="val_loss",
        factor=0.1,
        patience=MODEL_LR_SCHED_PATIENCE,
        min_delta=MODEL_LR_SCHED_MIN_DELTA
    ),
    SaveTruthsandPredictionsCallback(
        test_ss.spectra,
        test_ss.sources,
        model_name,
        model.activation,
        callback_dir=callback_dir
    )
]

history = model.fit(
    seeds_ss=fg_seeds_ss,
    ss=train_ss,
    batch_size=MODEL_BATCH_SIZE,
    epochs=MODEL_EPOCHS,
    validation_split=MODEL_VAL_SPLIT,
    callbacks=callbacks,
    patience=MODEL_PATIENCE,
    verbose=True,
    normalize_scaler=MODEL_SUP_NORMALIZE_SCALER,
    normalize_sup_loss=MODEL_NORMALIZE_SUP_LOSS,
    bg_cps=STATIC_SYNTH_BG_CPS,
    es_min_delta=MODEL_MIN_DELTA
)

# Persist the trained model next to its callbacks directory.
model.save(os.path.join(run_dir, f"{model_name}.onnx"))

In [None]:
"""Run forward pass on test set."""
# Run the model on a full slice of the test set; the predictions are read
# back from `tmp_ss` afterwards.
# NOTE(review): presumably `test_ss[:]` is a copy so `test_ss` stays
# untouched — confirm against SampleSet slicing semantics.
tmp_ss = test_ss[:]
model.predict(tmp_ss, bg_cps=STATIC_SYNTH_BG_CPS)

y_true = tmp_ss.sources.values
y_pred = tmp_ss.prediction_probas.values
# Transposing turns every row (spectrum) into an "output", so raw_values
# yields one MAE per spectrum rather than one per source column.
maes = mae(y_true.T, y_pred.T, multioutput="raw_values")

# Per-spectrum SNR estimate: spectrum total over the square root of
# (background cps * live time).
spectrum_totals = tmp_ss.spectra.values.sum(axis=1)
bg_counts = STATIC_SYNTH_BG_CPS * tmp_ss.info.live_time.values
snrs = spectrum_totals / np.sqrt(bg_counts)

# Info column named after the model's unsupervised loss; not used in this cell.
recon_errors = tmp_ss.info[model.unsup_loss_func_name]

high_snr_mask = snrs > 100
print(f"Test MAE: {maes.mean()}")
print(f"Test MAE (SNR > 100): {maes[high_snr_mask].mean()}")