# rail_nncflag 

**Author:** Sam Schmidt 

**Last successfully run:** Sep 4, 2025

This notebook demonstrates the neural-net quality flag code, originally written by Adam Broussard and updated by Irene Moskowitz.  It takes in photometric data and the ensemble PDFs for a set of galaxies, along with true redshifts for a training set, and feeds in the reference magnitude and all colors to predict which objects return good point estimate redshifts.

There is also an option to compute ODDS, the integrated probability within +/- 0.06*(1+zmode) of the mode redshift, and add it as an additional feature.  This is controlled by the `include_odds` configuration parameter.  If `include_odds` is set to `True`, ODDS, which is itself an additional quality flag, is also included as an output.
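
To make the ODDS definition concrete, here is a minimal numpy sketch that integrates gridded PDFs within +/- 0.06*(1+zmode) of the mode. It is an illustration only, not the RAIL implementation: the function name `odds_from_grid`, the trapezoidal integration, and the toy Gaussian PDFs are all assumptions made for this example.

```python
import numpy as np


def odds_from_grid(zgrid, pdfs, zmode, width=0.06):
    """Integrated probability within +/- width*(1+zmode) of each mode.

    zgrid: 1D increasing redshift grid
    pdfs:  (ngal, ngrid) array of PDF values on zgrid
    zmode: (ngal,) array of mode redshifts
    (hypothetical helper, not part of RAIL or qp)
    """
    zgrid = np.asarray(zgrid, dtype=float)
    pdfs = np.atleast_2d(np.asarray(pdfs, dtype=float))
    zmode = np.atleast_1d(np.asarray(zmode, dtype=float))

    # Cumulative trapezoidal integral of each PDF, normalized to end at 1
    steps = 0.5 * (pdfs[:, 1:] + pdfs[:, :-1]) * np.diff(zgrid)
    cdf = np.concatenate(
        [np.zeros((pdfs.shape[0], 1)), np.cumsum(steps, axis=1)], axis=1
    )
    cdf /= cdf[:, -1:]

    # Integration limits, clipped to the extent of the grid
    half = width * (1.0 + zmode)
    lo = np.clip(zmode - half, zgrid[0], zgrid[-1])
    hi = np.clip(zmode + half, zgrid[0], zgrid[-1])

    # ODDS = CDF(upper limit) - CDF(lower limit) for each galaxy
    return np.array(
        [np.interp(h, zgrid, c) - np.interp(l, zgrid, c)
         for l, h, c in zip(lo, hi, cdf)]
    )


# Toy PDFs: one sharply peaked, one broad, both with mode at z = 0.5
zgrid = np.linspace(0.0, 3.0, 301)
narrow = np.exp(-0.5 * ((zgrid - 0.5) / 0.02) ** 2)
broad = np.exp(-0.5 * ((zgrid - 0.5) / 0.5) ** 2)
odds_demo = odds_from_grid(zgrid, np.vstack([narrow, broad]), np.array([0.5, 0.5]))
print(odds_demo)
```

The sharply peaked PDF gives an ODDS value close to 1, while the broad PDF gives a much smaller value, which is why ODDS can serve as a quality flag on its own.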

We begin with our usual imports. We will use the two cosmoDC2 samples often employed in RAIL notebooks as our train and test sets. As PDFs are needed as well, we include a cell below that fetches pre-computed PDFs from BPZ; you can also download these files from NERSC as the tar file `bpz_traintest_ensembles.tar`.


The cell below will download a tar file (72 MB) from NERSC and untar it if the tar file is not already present in the directory:

In [None]:
import os
tar_file = "./bpz_traintest_ensembles.tar"

# Download and unpack the BPZ ensembles only if the tar file is not already here
if not os.path.exists(tar_file):
    os.system('curl -O https://portal.nersc.gov/cfs/lsst/PZ/bpz_traintest_ensembles.tar')
    os.system('tar -xvf bpz_traintest_ensembles.tar')

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import qp
import tables_io
from rail.core.data import TableHandle, ModelHandle, QPHandle
from rail.core.stage import RailStage
from rail.utils.path_utils import RAILDIR
from rail.estimation.algos.nnc_flag import NNFlagInformer, NNFlagEstimator

Let's load our data, both our photometry files and PDF ensembles from BPZ:

In [None]:
DS = RailStage.data_store
DS.__class__.allow_overwrite = True

In [None]:
trainFile = os.path.join(RAILDIR, 'rail/examples_data/testdata/test_dc2_training_9816.hdf5')
training_data = DS.read_file("training_data", TableHandle, trainFile)

testFile = os.path.join(RAILDIR, 'rail/examples_data/testdata/test_dc2_validation_9816.hdf5')
test_data = DS.read_file("test_data", TableHandle, testFile)

pdftrainfile = "./bpz_results_defaultprior_trainfile.hdf5"
pdfdata = DS.read_file("pdf_data", QPHandle, pdftrainfile)

pdftestfile = "./bpz_results_defaultprior_testfile.hdf5"
pdftestdata = DS.read_file("pdf_test_data", QPHandle, pdftestfile)

## Set up the inform stage

There are several parameters that control the behavior of the nncflag calculation.  Any parameter not set will take on its default value:

- bands (list, default: SHARED_PARAMS): the list of photometric bands
- ref_band (str, default: SHARED_PARAMS): the reference band. The magnitude in this band is used as an input; the magnitudes in the other bands are not, but the differences between adjacent bands (colors) are used as the remaining inputs
- redshift_col (str, default: SHARED_PARAMS): name of the true redshift column in the photometry file
- nodecounts (list, default: [100, 200, 100, 50, 1]): the number of nodes in each layer of each neural net model
- splitnum (int, default: 5): the number of neural net models to train
- activations (list, default: ['selu', 'selu', 'selu', 'selu', 'sigmoid']): the activation functions for the neural net nodes.  NOTE: the length of `activations` must match the length of `nodecounts`!
- epochs (int, default: 1000): the max number of training epochs for the NN models
- acc_cutoff (float, default: 0.07): the boundary value of dz/(1+z) that separates "good" from "bad" point estimates when training the NN. Adjusting this value changes which objects are classified as "good" and shifts the distribution of output flag values.
- zphot_name (str, default: "zmode"): the name of the point estimate to grab from the qp Ensemble file, usually 'zmode' or 'zmean'
- trainfrac (float, default: 0.75): the fraction of galaxies used for training, the remainder is reserved for validation in fitting the model
- seed (int, default: 1234): seed used by numpy for reproducibility
- include_odds (bool, default: False): if set to True, calculates the ODDS parameter and includes it, in addition to the ref_band magnitude and colors, as an input to the neural net.  ODDS will also be written to the output file as an additional flag if this is set
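
As a rough illustration of how `bands`, `ref_band`, and `acc_cutoff` fit together, the sketch below builds a feature matrix from the reference magnitude plus adjacent-band colors, and assigns good/bad training labels. The helper names `build_features` and `good_labels`, the choice of `mag_i_lsst` as reference band, the normalization by (1 + true redshift), and the toy magnitudes are assumptions for this example; the actual stage may differ in detail.

```python
import numpy as np


def build_features(mags, bands, ref_band):
    """Stack the reference-band magnitude with adjacent-band colors.

    (hypothetical helper, not part of RAIL)
    """
    cols = [np.asarray(mags[ref_band], dtype=float)]
    for blue, red in zip(bands[:-1], bands[1:]):
        # Color = difference between magnitudes in adjacent bands
        cols.append(np.asarray(mags[blue], dtype=float)
                    - np.asarray(mags[red], dtype=float))
    return np.column_stack(cols)


def good_labels(zphot, ztrue, acc_cutoff=0.07):
    """Return 1 where |zphot - ztrue| / (1 + ztrue) < acc_cutoff, else 0.

    Assumes dz is normalized by the true redshift.
    """
    zphot = np.asarray(zphot, dtype=float)
    ztrue = np.asarray(ztrue, dtype=float)
    return (np.abs(zphot - ztrue) / (1.0 + ztrue) < acc_cutoff).astype(int)


# Toy data: two galaxies, six bands, each band 0.5 mag brighter than the last
bands = [f"mag_{b}_lsst" for b in "ugrizy"]
toy_mags = {name: np.array([24.0, 23.0]) - 0.5 * k for k, name in enumerate(bands)}

features = build_features(toy_mags, bands, ref_band="mag_i_lsst")
labels = good_labels(zphot=[0.52, 1.40], ztrue=[0.50, 1.00])
print(features.shape, labels)
```

With six bands this gives six input features per galaxy: one reference magnitude and five colors. Setting `include_odds` to True would add ODDS as one more column.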
 


We need to tell NNFlagInformer which columns to use, set whether or not to include the ODDS parameter, and tell it which point estimate to use.

We will set up the band names present in the train and test data, and create a dictionary that sets `include_odds` to True and `zphot_name` to `zmode`.  The dictionary also sets `trainfrac` to 0.5 rather than the default of 0.75.

In [None]:
bands = ["u", "g", "r", "i", "z", "y"]
lsst_bands = []
for band in bands:
    lsst_bands.append(f"mag_{band}_lsst")
print(lsst_bands)

In [None]:
default_dict = dict(hdf5_groupname="photometry", model="NNCflag_model.pkl",
                    bands=lsst_bands, trainfrac=0.5,
                    include_odds=True,
                    zphot_name='zmode')
run_default = NNFlagInformer.make_stage(name="halfway", **default_dict)

In [None]:
%%time
nnc_model = run_default.inform(training_data, pdfdata)

We can look at the contents of the model. It contains the feature averages and variances (used to whiten the data; these will be used in the estimate stage as well) and 'nnlist', which holds the stored neural net models.  The value of `include_odds` is stored as well, because the estimate stage must match the training: if the model was trained with ODDS included, the estimator must include it too, and likewise if it was trained without.

In [None]:
nnc_model()
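
The whitening mentioned above can be sketched as follows: features are shifted by the stored means and scaled by the square root of the stored variances, and the same training-set statistics are reused for the test data. This is a minimal sketch under the assumption of simple per-feature standardization; the function name `whiten` and the random toy features are made up for illustration.

```python
import numpy as np


def whiten(features, means, variances):
    """Standardize each feature column using precomputed means and variances."""
    return (np.asarray(features, dtype=float) - means) / np.sqrt(variances)


# Toy features: a magnitude-like column and a color-like column
rng = np.random.default_rng(1234)
train_feats = rng.normal(loc=[24.0, 0.5], scale=[1.5, 0.3], size=(500, 2))
test_feats = rng.normal(loc=[24.0, 0.5], scale=[1.5, 0.3], size=(200, 2))

# Statistics are computed on the training set only
train_means = train_feats.mean(axis=0)
train_vars = train_feats.var(axis=0)

white_train = whiten(train_feats, train_means, train_vars)
# The test set is whitened with the *training* statistics
white_test = whiten(test_feats, train_means, train_vars)
print(white_train.mean(axis=0), white_train.var(axis=0))
```

Reusing the training statistics at estimation time keeps the test inputs on the same scale the networks saw during training, which is why these values need to be stored in the model.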

## Run the estimator stage

Now we can use the model from the inform stage to estimate flag values for our test galaxies:

In [None]:
default_dict2 = dict(hdf5_groupname="photometry", output="NNCflags_output.hdf5",
                    bands=lsst_bands, model="NNCflag_model.pkl",
                    zphot_name='zmode')

In [None]:
run_est = NNFlagEstimator.make_stage(name="testest", **default_dict2)

In [None]:
%%time
results = run_est.estimate(test_data, pdftestdata)

## Results plots

As we ran with `include_odds` set to `True`, our results file contains two arrays, `nncflag` and `ODDS`.  Both are quality flags that range from 0 to 1, with values closer to 0 indicating a likely bad fit, and values closer to 1 indicating a good fit.  We can make some plots showing how cuts in each flag look when compared to the overall sample.  We can set up the data and see what a cut of 0.5 in the flag value looks like on a zmode vs specz point estimate plot:

In [None]:
nnflag = results()['nncflag']

In [None]:
odds = results()['ODDS']

In [None]:
sz = test_data()['photometry']['redshift']

In [None]:
testpdffile = "./bpz_results_defaultprior_testfile.hdf5"
testpdfdata = qp.read(testpdffile)

In [None]:
zb = testpdfdata.ancil['zmode'].flatten()

In [None]:
plt.hist(nnflag, bins=np.linspace(0,1,51));
plt.xlabel('nncflag value', fontsize=12)
plt.ylabel('number', fontsize=12);

In [None]:
plt.hist(odds, bins=np.linspace(0,1,51));
plt.xlabel('ODDS value', fontsize=12)
plt.ylabel('number', fontsize=12);

In [None]:
flagcut = 0.5
mask = (nnflag>flagcut)
szx = sz[mask]
zbx = zb[mask]
plt.figure(figsize=(8,8))
plt.scatter(sz,zb, s=5,c='k', label="full sample")
plt.scatter(szx,zbx, s=2,c='r', label=f"nncflag>{flagcut}")
plt.legend(loc='center left', fontsize=12)
plt.xlabel("redshift", fontsize=14)
plt.ylabel("zmode", fontsize=14);

In [None]:
flagcut = 0.5
mask = (odds>flagcut)
szx = sz[mask]
zbx = zb[mask]
plt.figure(figsize=(8,8))
plt.scatter(sz,zb, s=5,c='k', label="full sample")
plt.scatter(szx,zbx, s=2,c='r', label=f"ODDS>{flagcut}")
plt.legend(loc='center left', fontsize=12)
plt.xlabel("redshift", fontsize=14)
plt.ylabel("zmode", fontsize=14);
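
The scatter plots above show the effect of a cut qualitatively. One rough way to put numbers on it is to compare the fraction of objects kept with the outlier fraction before and after the cut. The sketch below does this on toy arrays; the helper `cut_summary`, its outlier definition (|zphot - ztrue| / (1 + ztrue) >= 0.07, mirroring `acc_cutoff`), and the toy values are assumptions for illustration rather than part of RAIL.

```python
import numpy as np


def cut_summary(ztrue, zphot, flag, flagcut=0.5, acc_cutoff=0.07):
    """Summarize a quality-flag cut (hypothetical helper)."""
    ztrue, zphot, flag = (np.asarray(a, dtype=float) for a in (ztrue, zphot, flag))
    keep = flag > flagcut
    # An object counts as an outlier if its scaled redshift error is too large
    bad = np.abs(zphot - ztrue) / (1.0 + ztrue) >= acc_cutoff
    n_keep = int(keep.sum())
    return {
        "kept_fraction": n_keep / keep.size,
        "outlier_fraction_all": float(bad.mean()),
        "outlier_fraction_kept": float(bad[keep].mean()) if n_keep else float("nan"),
    }


# Toy sample of four galaxies; the third is a catastrophic outlier
toy_ztrue = np.array([0.5, 1.0, 1.5, 2.0])
toy_zphot = np.array([0.52, 1.02, 0.30, 2.05])
toy_flag = np.array([0.9, 0.8, 0.2, 0.4])

summary = cut_summary(toy_ztrue, toy_zphot, toy_flag)
print(summary)
```

In the notebook, the same function could be called with `sz`, `zb`, and either `nnflag` or `odds` to compare the two flags at the same cut value.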

We can also plot ODDS vs nncflag. We see that they are roughly correlated, but with a lot of scatter, particularly for "bad" galaxies.  However, "good" galaxies tend to have high values of both flags, which is a good indicator that both are functioning similarly for high quality data.  It may be the case that a looser selection on both flags could result in a better selection than a stricter cut on one flag.

In [None]:
plt.figure(figsize=(7,7))
plt.scatter(nnflag, odds, s=1,c='k')
plt.plot([0,1],[0,1],'r--', lw=2)
plt.xlabel("NN flag", fontsize=13)
plt.ylabel("ODDS", fontsize=13);
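
The idea of a looser selection on both flags versus a stricter cut on one flag can be expressed with boolean masks. The sketch below is only an illustration on made-up flag values; the helper `joint_mask` and the specific thresholds (0.6 and 0.8) are arbitrary assumptions, and whether a joint cut actually performs better would need to be checked against the true redshifts.

```python
import numpy as np


def joint_mask(nnflag, odds, nn_cut, odds_cut):
    """Select objects that pass both flag cuts (hypothetical helper)."""
    return (np.asarray(nnflag) > nn_cut) & (np.asarray(odds) > odds_cut)


# Made-up flag values for five galaxies
toy_nnflag = np.array([0.95, 0.70, 0.65, 0.85, 0.30])
toy_odds = np.array([0.90, 0.75, 0.20, 0.55, 0.80])

# A strict cut on nncflag alone
strict_single = toy_nnflag > 0.8
# A looser cut applied to both flags at once
loose_joint = joint_mask(toy_nnflag, toy_odds, nn_cut=0.6, odds_cut=0.6)

print("strict single cut:", strict_single)
print("loose joint cut:  ", loose_joint)
```

Both selections keep two galaxies here, but not the same two: the joint cut drops an object with a high nncflag and a mediocre ODDS, and admits one that is moderately good in both flags.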