# Cleaning and Averaging ATLAS Light Curves
### Includes chi-square cut, uncertainty cut, control light curve cut, and averaging

This IPython notebook lets you apply each cut with a greater degree of control than automatic cleaning allows. It walks you through the following steps:
1. Correction for ATLAS reference template changes
2. Static uncertainty cut
3. Estimating true uncertainties
4. Dynamic chi-square cut
5. Control light curve cut
6. Averaging the light curve and cutting bad bins

After running a cell, the descriptions located above that cell will help you interpret the plots and make decisions about the supernova.

For this notebook to work correctly, the ATLAS light curves must already be downloaded and saved, and each light curve file must contain measurements in a single filter only.

## Step 1: Load the ATLAS light curves 

In [None]:
##### LOADING THE SN LIGHT CURVE #####

# Enter the target SN name:
tnsname = ''

# Enter the SN light curve file name:
filename = ''

# Enter the filter for this light curve (must be 'o' or 'c').
# NOTE: this variable shadows Python's built-in filter() for the rest of the notebook.
filter = ''

# Optionally, enter the SN's discovery date as an MJD (if None is entered, it will be 
# fetched automatically from TNS using the API key, TNS ID, and bot name):
discdate = None
api_key = None
tns_id = None
bot_name = None

# Minimum amount of baseline data (before the SN starts) a template region needs in
# order to use its existing baseline as-is. NOTE: get_baseline_regions() compares this
# against the NUMBER OF BASELINE MEASUREMENTS in the region, not a number of days.
Ndays_min = 6

##### LOADING CONTROL LIGHT CURVES #####

# Set to True if you are planning on applying the control light curve cut 
# and have already downloaded the control light curves:
load_controls = True

# Enter the number of control light curves to load:
Ncontrols = 8

# Enter the source directory of the control light curve files:
controls_dir = 'controls'

##### DEFAULT SETTINGS FOR PLOTTING #####

# If True, automatically calculate x and y limits for each plot (limits entered
# manually in a later cell still take precedence); if False, apply only manually
# entered limits and leave everything else to matplotlib
auto_xylimits = True

In [None]:
# import modules, set preliminary variables, etc.

import sys, re
import numpy as np
from copy import deepcopy

# storing, accessing, and manipulating the light curve
import pandas as pd
from pdastro import pdastrostatsclass, AandB, AnotB, AorB, not_AandB

# getting discovery date from TNS
import requests, json
from collections import OrderedDict
from astropy.time import Time

# plotting
import matplotlib
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import pylab as matlib
import warnings
# NOTE(review): filterwarnings("ignore") inserts a catch-all "ignore" filter in front
# of the RuntimeWarning -> error filter from the line above, so in effect ALL warnings
# (RuntimeWarnings included) are ignored — confirm which behavior is intended.
warnings.simplefilter('error', RuntimeWarning)
warnings.filterwarnings("ignore")

# plotting styles
# (the rcParams assignment at the end overrides the font size of 13 set just before it)
plt.rc('axes', titlesize=18)
plt.rc('axes', labelsize=14)
plt.rc('xtick', labelsize=13)
plt.rc('ytick', labelsize=13)
plt.rc('legend', fontsize=13)
plt.rc('font', size=13)
plt.rcParams['font.size'] = 12

# ATLAS reference template change dates (MJD).
# `global` at module level is a no-op; these are ordinary notebook globals.
global tchange1
global tchange2
tchange1 = 58417
tchange2 = 58882

# dictionary for storing light curve and other important information
global lc_info
lc_info = {}

# dictionary for optionally storing control light curves (keyed 1..Ncontrols)
global controls
controls = {}

# list for storing the new dflux column name ('duJy' vs. 'duJy_new');
# entry 0 is used for the SN light curve (see print_statistics), the remaining
# Ncontrols entries presumably for the control light curves
global dflux_colnames
dflux_colnames = ['duJy'] * (Ncontrols+1)

# bit flags that are OR-ed into the 'Mask' column (see update_mask_col)
flag_chisquare = 0x1
flag_uncertainty = 0x2
flag_controls_bad = 0x400000
flag_controls_questionable = 0x80000
flag_controls_x2 = 0x100
flag_controls_stn = 0x200
flag_controls_Nclip = 0x400
flag_controls_Ngood = 0x800
flag_badday = 0x800000
flag_ixclip = 0x1000
flag_smallnum = 0x2000

# get discovery date if needed, load in light curve, account for template changes

def get_tns_data(tnsname, api_key, tns_id, bot_name, timeout=60):
    """Query the TNS ``get/object`` API for object ``tnsname``.

    Returns the decoded JSON reply (an OrderedDict) on success. On any
    exception, returns a string starting with 'Error: ' instead of raising,
    so callers must check for that.

    ``timeout`` (seconds) is forwarded to ``requests.post``; without it an
    unresponsive server could block the notebook indefinitely.
    """
    try:
        get_obj = [("objname", tnsname), ("objid", ""), ("photometry", "1"), ("spectra", "1")]
        get_url = 'https://www.wis-tns.org/api/get/object'
        json_file = OrderedDict(get_obj)
        get_data = {'api_key': api_key, 'data': json.dumps(json_file)}
        headers = {'User-Agent': 'tns_marker{"tns_id":%s,"type": "bot", "name":"%s"}' % (str(tns_id), str(bot_name))}
        response = requests.post(get_url, data=get_data, headers=headers, timeout=timeout)
        return json.loads(response.text, object_pairs_hook=OrderedDict)
    except Exception as e:
        return 'Error: \n' + str(e)

def get_discdate(tnsname, api_key):
    """Return the discovery date of SN ``tnsname`` from TNS as an MJD (UTC).

    The TNS bot credentials are taken from the notebook globals ``tns_id``
    and ``bot_name``.

    Raises:
        RuntimeError: if the TNS query failed or the reply has no discovery date.
    """
    json_data = get_tns_data(tnsname, api_key, tns_id, bot_name)
    # get_tns_data() reports failures by returning an error string instead of raising
    if isinstance(json_data, str):
        raise RuntimeError('Could not fetch TNS data for SN %s: %s' % (tnsname, json_data))
    try:
        discoverydate = json_data['data']['reply']['discoverydate']
    except (KeyError, TypeError) as e:
        raise RuntimeError('TNS reply for SN %s contains no discovery date: %s' % (tnsname, json_data)) from e

    # the date and time parts are separated by a space; join them with 'T' for ISOT parsing
    date_part, _, time_part = discoverydate.partition(' ')
    return Time(date_part + 'T' + time_part, format='isot', scale='utc').mjd

def get_xth_percentile_flux(lc_type, percentile, indices=None):
    """Return the ``percentile``-th percentile of the 'uJy' flux of ``lc_info[lc_type]``.

    ``indices`` restricts the rows used (default: all rows). Returns None
    when there are no rows to use.
    """
    rows = lc_info[lc_type].getindices() if indices is None else indices
    if len(rows) == 0:
        return None
    fluxes = lc_info[lc_type].t.loc[rows, 'uJy']
    return np.percentile(fluxes, percentile)
	
def drop_extra_columns(lc_type):
    """Drop extra columns from the table of ``lc_info[lc_type]`` in place.

    Removes 'Noffsetlc' and '__tmp_SN' if present, plus every column whose
    name matches ``^c<digit>_``, and prints the list of dropped columns.
    """
    columns = lc_info[lc_type].t.columns
    dropcols = [col for col in ('Noffsetlc', '__tmp_SN') if col in columns]
    # raw string: '\d' inside a plain string literal is an invalid escape sequence
    # (DeprecationWarning, and SyntaxWarning since Python 3.12)
    dropcols.extend(col for col in columns if re.search(r'^c\d_', col))
    if dropcols:
        print('Dropping extra columns: ', dropcols)
        lc_info[lc_type].t.drop(columns=dropcols, inplace=True)

def load_lc(filename):
    """Load the SN light curve into lc_info['lc'] and split it into baseline / during-SN rows.

    Sets lc_info['baseline_ix'] to the rows with MJD < lc_info['discdate'] - 20
    and lc_info['duringsn_ix'] to all other rows.

    Raises:
        RuntimeError: if the file cannot be loaded or the baseline is empty.
    """
    lc_info['lc'] = pdastrostatsclass()
    print('\nLoading SN %s light curve at %s and clearing previous flags in "Mask" column...' % (lc_info['tnsname'], filename))
    try:
        lc_info['lc'].load_spacesep(filename, delim_whitespace=True)
    except Exception as e:
        # keep the original exception as the cause for debugging
        raise RuntimeError('Could not load light curve for SN %s at %s: %s' % (lc_info['tnsname'], filename, str(e))) from e

    # NOTE(review): lc_info['discdate'] is itself stored 20 days before the actual
    # discovery date, so the baseline ends 40 days before discovery — confirm intended.
    lc_info['baseline_ix'] = lc_info['lc'].ix_inrange(colnames=['MJD'], uplim=lc_info['discdate']-20, exclude_uplim=True)
    if len(lc_info['baseline_ix']) <= 0:
        raise RuntimeError('Baseline length is 0! Exiting...')
    lc_info['duringsn_ix'] = AnotB(lc_info['lc'].getindices(), lc_info['baseline_ix'])

def load_control_lcs(controls_dir, Ncontrols):
    """Load control light curves 1..Ncontrols into the global ``controls`` dict.

    Each file is read from '<controls_dir>/<tnsname>_i<NNN>.<filter>.lc.txt'
    and its 'Mask' column is reset to 0.
    """
    print('\nLoading control light curves and clearing previous flags in "Mask" column...')
    for idx in range(1, Ncontrols + 1):
        ctrl_lc = pdastrostatsclass()
        controls[idx] = ctrl_lc

        suffix = '_i%03d.' % idx
        path = controls_dir + '/' + lc_info['tnsname'] + suffix + lc_info['filter'] + '.lc.txt'
        print('# Loading control light curve at ', path)
        ctrl_lc.load_spacesep(path, delim_whitespace=True)

        # start every control light curve with an empty mask
        ctrl_lc.t['Mask'] = 0

def save_lc(lc_type, filename, overwrite=False):
    """Write light curve ``lc_info[lc_type]`` to ``filename`` (``overwrite`` is forwarded to its write())."""
    print('Saving light curve at %s' % filename)
    target = lc_info[lc_type]
    target.write(filename, overwrite=overwrite)

def verify_mjds(Ncontrols):
    """Sort the SN and control light curves by MJD and align each control's MJDs with the SN's.

    For each control light curve 1..Ncontrols:
      * MJDs present only in the SN light curve are added to the control as
        new rows with Mask=0 (other columns NaN);
      * control rows whose MJD does not occur in the SN light curve are dropped;
      * the control table is sorted by MJD (always, not only when MJDs differed).

    Raises:
        RuntimeError: if an MJD that should identify exactly one control row
            matches zero or several rows.
    """
    print()
    # sort SN lc by MJD (pandas keeps the original index labels)
    mjd_sorted_i = lc_info['lc'].ix_sort_by_cols('MJD')
    lc_info['lc'].t = lc_info['lc'].t.loc[mjd_sorted_i]
    sn_sorted = lc_info['lc'].t.loc[mjd_sorted_i, 'MJD'].to_numpy()

    for control_index in range(1, Ncontrols+1):
        control = controls[control_index]
        ctrl_sorted_i = control.ix_sort_by_cols('MJD')
        control_sorted = control.t.loc[ctrl_sorted_i, 'MJD'].to_numpy()

        # np.array_equal is already False for arrays of different lengths
        if not np.array_equal(sn_sorted, control_sorted):
            print('MJDs out of agreement for control light curve %03d, fixing...' % control_index)

            mjds_onlysn = AnotB(sn_sorted, control_sorted)
            mjds_onlycontrol = AnotB(control_sorted, sn_sorted)

            # for the mjds only in sn, add row with that mjd to control lc, with all values of other columns NaN
            for mjd in mjds_onlysn:
                control.newrow({'MJD': mjd, 'Mask': 0})

            # skip rows in control lc for which there is no mjd in the sn lc
            indices2skip = []
            for mjd in mjds_onlycontrol:
                ix = control.ix_equal('MJD', mjd)
                if len(ix) != 1:
                    raise RuntimeError(f'# Expected exactly one row with MJD={mjd} in control light curve {control_index:03d}, but found {len(ix)}!')
                indices2skip.extend(ix)
            keep_ix = AnotB(control.getindices(), indices2skip) if indices2skip else control.getindices()

            # the table changed (rows added), so recompute the sort order
            ctrl_sorted_i = control.ix_sort_by_cols('MJD', indices=keep_ix)

        # always apply the MJD sort so the control is ordered like the SN light curve
        # even when its MJDs already matched
        control.t = control.t.loc[ctrl_sorted_i]
    print('\nFinished sorting SN and control light curves')

lc_info['tnsname'] = tnsname

# fail loudly on an invalid filter: sys.exit() only raises SystemExit inside a
# notebook, and exits with status 0 when run as a script
if filter != 'o' and filter != 'c':
    raise RuntimeError('Filter must be "o" or "c"!')
lc_info['filter'] = filter

if discdate is None:
    print('Obtaining SN %s discovery date from TNS...' % lc_info['tnsname'])
    discdate = get_discdate(lc_info['tnsname'], api_key)
    print('Discovery date: ', discdate)
# NOTE(review): lc_info['discdate'] is stored 20 days before the actual discovery
# date, and load_lc() places the baseline cutoff another 20 days earlier — confirm
# this combined offset is intended.
lc_info['discdate'] = discdate - 20

# new text file that will contain record of each cut, etc.
# (left open on purpose: later cells keep appending to it via f.write)
f = open(f'{lc_info["tnsname"]}_output.md', 'w', encoding='utf-8')
# record the actual discovery date (not the offset value stored in lc_info), and
# 0 control light curves when they are not loaded
f.write(f'# SN {lc_info["tnsname"]} Light Curve Cleaning and Averaging\n\nFilter: {lc_info["filter"]}-band\nDiscovery date: {discdate}\nNumber of control light curves: {Ncontrols if load_controls else 0}')

load_lc(filename)
if load_controls:
    load_control_lcs(controls_dir, Ncontrols)
    verify_mjds(Ncontrols)
else:
    Ncontrols = 0

# Calculate uJy/duJy column (division by zero gives +/-inf, which is replaced by NaN)
print('\nCalculating uJy/duJy column...')
lc_info['lc'].t['uJy/duJy'] = lc_info['lc'].t['uJy']/lc_info['lc'].t['duJy']
lc_info['lc'].t = lc_info['lc'].t.replace([np.inf, -np.inf], np.nan)

In [None]:
# Plot control light curves underneath SN light curve?
plot_controls = True

# Optionally, manually enter the x and y limits for the following plot of the loaded light curves
# (None = not set; unset limits are computed automatically when auto_xylimits is True):
xlim_lower = None
xlim_upper = None
ylim_lower = None
ylim_upper = None

In [None]:
# plot SN and control light curves

def do_manual_xylimits(limits):
    """Return True if at least one entry of ``limits`` was set (is not None), else False."""
    return any(limit is not None for limit in limits)

def set_xylimits(limits, lc_type='lc'):
    """Resolve the [xmin, xmax, ymin, ymax] limits for a plot of ``lc_info[lc_type]``.

    ``limits`` holds the user-entered limits, with None meaning "not set";
    ``limits=None`` itself means no limit is set. The caller's list is never
    modified.

    If the global ``auto_xylimits`` is True, every unset entry is filled in:
      * x: minimum / maximum MJD multiplied by 0.999 / 1.001;
      * y: minimum / maximum 'uJy' of measurements with 'duJy' up to 160,
        padded by 5% of that flux range.
    If ``auto_xylimits`` is False, the manual limits are returned when at least
    one is set, otherwise None (leave the limits to matplotlib).
    """
    # copy (or create) the list: avoids mutating the caller's list and supports
    # callers such as plot_all_lcs() that default to limits=None
    limits = [None, None, None, None] if limits is None else list(limits)

    if not auto_xylimits:
        return limits if do_manual_xylimits(limits) else None

    t = lc_info[lc_type].t
    if limits[0] is None:
        limits[0] = t['MJD'].min() * 0.999
    if limits[1] is None:
        limits[1] = t['MJD'].max() * 1.001

    # exclude measurements with duJy > 160
    good_ix = lc_info[lc_type].ix_inrange(colnames='duJy', uplim=160)
    # pad the flux range by 5% of abs(max flux - min flux)
    flux_min = t.loc[good_ix, 'uJy'].min()
    flux_max = t.loc[good_ix, 'uJy'].max()
    offset = 0.05 * abs(flux_max - flux_min)

    if limits[2] is None:
        limits[2] = flux_min - offset
    if limits[3] is None:
        limits[3] = flux_max + offset
    return limits

def plot_all_lcs(add2title=None, plot_controls=False, templates=False, baseline_ix=None, duringsn_ix=None, dflux_colname='duJy', limits=None):
    """Plot the SN light curve, optionally on top of its control light curves.

    Baseline points (``baseline_ix``, default lc_info['baseline_ix']) are drawn
    in the filter color and during-SN points (``duringsn_ix``, default
    lc_info['duringsn_ix']) in red, with error bars from column ``dflux_colname``.
    Control light curves are drawn in blue underneath when the global
    ``load_controls`` and ``plot_controls`` are both True. ``templates=True``
    adds vertical lines at the ATLAS template change MJDs. ``add2title`` is
    appended to the title; ``limits`` is resolved with set_xylimits().
    """
    filter_color = 'orange' if lc_info["filter"] == 'o' else 'cyan'

    plt.figure(figsize=(12,6), tight_layout=True)
    plt.axhline(linewidth=1, color='k')
    plt.ylabel('Flux (µJy)')
    plt.xlabel('MJD')
    title = f'SN {lc_info["tnsname"]} and control light curves {lc_info["filter"]}-band flux'
    plt.title(title if add2title is None else title + add2title)
    if templates:
        plt.axvline(x=tchange1, color='magenta', label='ATLAS template change', zorder=30)
        plt.axvline(x=tchange2, color='magenta', zorder=30)

    baseline_ix = lc_info['baseline_ix'] if baseline_ix is None else baseline_ix
    duringsn_ix = lc_info['duringsn_ix'] if duringsn_ix is None else duringsn_ix

    # set x and y limits
    limits = set_xylimits(limits)
    if limits is not None:
        plt.xlim(limits[0], limits[1])
        plt.ylim(limits[2], limits[3])

    if load_controls and plot_controls:
        for control_index in range(1, Ncontrols+1):
            ctrl = controls[control_index].t
            plt.errorbar(ctrl['MJD'], ctrl['uJy'], yerr=ctrl[dflux_colname], fmt='none', ecolor='blue', elinewidth=1, c='blue', alpha=0.3, zorder=0)
            # only the first control gets a legend label, standing in for all of them
            label_kwargs = {'label': f'{Ncontrols} control light curves'} if control_index == 1 else {}
            plt.scatter(ctrl['MJD'], ctrl['uJy'], s=45, color='blue', marker='o', alpha=0.3, zorder=0, **label_kwargs)

    sn = lc_info['lc'].t
    # baseline points first (zorder 10), then during-SN points on top (zorder 20)
    point_sets = ((baseline_ix, filter_color, 10, 'Baseline flux'),
                  (duringsn_ix, 'red', 20, 'SN flux'))
    for ix, point_color, z, label in point_sets:
        plt.errorbar(sn.loc[ix,'MJD'], sn.loc[ix,'uJy'], yerr=sn.loc[ix,dflux_colname], fmt='none', ecolor=point_color, elinewidth=1, c=point_color, zorder=z)
        plt.scatter(sn.loc[ix,'MJD'], sn.loc[ix,'uJy'], s=45, color=point_color, marker='o', zorder=z, label=label)

    plt.legend(bbox_to_anchor=(1.04, 1), borderaxespad=0)

# plot the light curves as loaded, before any correction or cut
limits = [xlim_lower, xlim_upper, ylim_lower, ylim_upper]
plot_all_lcs(add2title=' (original)', plot_controls=plot_controls, limits=limits)

## Step 2: Correct for ATLAS reference template changes

This notebook takes into account ATLAS's periodic replacement of the difference image reference templates, which may cause step discontinuities in flux. Two template changes have been recorded at MJDs 58417 and 58882. More information can be found here: https://fallingstar-data.com/forcedphot/faq/.

In [None]:
# Minimum amount of baseline data a template region needs in order to keep its existing
# baseline. NOTE: get_baseline_regions() compares this against the NUMBER OF BASELINE
# MEASUREMENTS in the region, not a number of days. This overrides the value from the first cell.
Ndays_min = 6

# Plot the light curve before and after correcting for ATLAS reference template changes?:
plot = True

# Optionally, manually enter the x and y limits for the plot (None = not set):
xlim_lower = None
xlim_upper = None
ylim_lower = None
ylim_upper = None

In [None]:
# get template change regions and plot "before" template change correction

def get_Ndays(SN_region_index):
    """Return how many days at the end of a template region to use as baseline: 200 for t2, 40 otherwise."""
    if SN_region_index == 2:
        return 200
    return 40

def get_baseline_regions(Ndays_min):
    """Split the SN light curve into ATLAS template regions and choose a baseline for each.

    Template regions by MJD: t0 up to ``tchange1``, t1 between ``tchange1`` and
    ``tchange2``, t2 from ``tchange2`` on. Each baseline region b_tN starts as
    the part of tN that lies in lc_info['baseline_ix'].

    The "adjusted" region is the one containing lc_info['discdate'], or the next
    one if that region already has at least ``Ndays_min`` baseline measurements
    (never past t2). If the adjusted region has fewer than ``Ndays_min`` baseline
    measurements, its last get_Ndays() days become its baseline. Every region
    after the adjusted one is used in full as baseline.

    Also sets lc_info['baseline_rev_ix'] (all b_tN indices) and
    lc_info['durings_rev_ix'] (all remaining indices).

    Returns:
        dict with keys 't0'..'t2' and 'b_t0'..'b_t2' mapping to index arrays.

    Raises:
        RuntimeError: if the discovery date cannot be assigned to a region.
    """
    print('Getting region indices around SN... ')
    lc = lc_info['lc']
    regions = {}
    regions['t0']   = lc.ix_inrange(colnames=['MJD'],                  uplim=tchange1)
    regions['t1']   = lc.ix_inrange(colnames=['MJD'], lowlim=tchange1, uplim=tchange2)
    regions['t2']   = lc.ix_inrange(colnames=['MJD'], lowlim=tchange2)
    for region_index in range(0, 3):
        regions['b_t%d' % region_index] = AandB(regions['t%d' % region_index], lc_info['baseline_ix'])

    # find region SN starts in
    discdate = lc_info['discdate']
    if discdate <= tchange1:
        SN_region_index = 0
    elif tchange1 < discdate <= tchange2:
        SN_region_index = 1
    elif discdate > tchange2:
        SN_region_index = 2
    else:
        # none of the comparisons hold, e.g. when discdate is NaN
        raise RuntimeError('Could not find region with SN discovery date!')
    print('SN discovery date located in template region t%d' % SN_region_index)

    # for region with tail end of the SN, get last Ndays days and classify as baseline
    adjust_region_index = SN_region_index
    if adjust_region_index < 2 and len(regions['b_t%d' % adjust_region_index]) >= Ndays_min:
        adjust_region_index += 1
    adjust_key = 'b_t%d' % adjust_region_index
    if len(regions[adjust_key]) < Ndays_min:
        Ndays = get_Ndays(adjust_region_index)
        print('Getting baseline flux for template region t%d by obtaining last %d days of region... ' % (adjust_region_index, Ndays))
        t_region = regions['t%d' % adjust_region_index]
        if len(t_region) > 0:
            last_mjd = lc.t.loc[t_region[-1], 'MJD']
            regions[adjust_key] = lc.ix_inrange(colnames=['MJD'], lowlim=last_mjd - Ndays, uplim=last_mjd)
        else:
            print('# No measurements in template region t%d, keeping its baseline region as is' % adjust_region_index)
    # all regions after the adjusted one are entirely baseline
    if adjust_region_index < 1:
        regions['b_t1'] = regions['t1']
    if adjust_region_index < 2:
        regions['b_t2'] = regions['t2']

    for region_index in range(0, 3):
        t_region = regions['t%d' % region_index]
        if len(t_region) > 0:
            print('TEMPLATE REGION t%d MJD RANGE: %0.2f - %0.2f' % (region_index, lc.t.loc[t_region[0],'MJD'], lc.t.loc[t_region[-1],'MJD']))
        else:
            print('TEMPLATE REGION t%d MJD RANGE: not found' % region_index)
        b_region = regions['b_t%d' % region_index]
        if len(b_region) > 0:
            print('TEMPLATE REGION b_t%d BASELINE MJD RANGE: %0.2f - %0.2f' % (region_index, lc.t.loc[b_region[0],'MJD'], lc.t.loc[b_region[-1],'MJD']))
        else:
            print('TEMPLATE REGION b_t%d BASELINE MJD RANGE: not found' % region_index)

    # check to make sure baseline flux is still consistent by getting median of first and last halves of affected region
    b_region = regions[adjust_key]
    if len(b_region) == 0:
        print('No baseline measurements in template region b_t%d, skipping baseline flux consistency check' % adjust_region_index)
    else:
        mid_pos = int(len(b_region)/2)
        first_mjd = lc.t.loc[b_region[0], 'MJD']
        mid_mjd = lc.t.loc[b_region[mid_pos], 'MJD']
        last_mjd = lc.t.loc[b_region[-1], 'MJD']
        median1 = np.median(lc.t.loc[lc.ix_inrange(colnames=['MJD'], lowlim=first_mjd, uplim=mid_mjd), 'uJy'])
        # step to the next POSITION in the region rather than index label mid_i+1:
        # labels are not renumbered when the table is re-sorted, so mid_i+1 need
        # not be the next row (or exist at all)
        if mid_pos + 1 < len(b_region):
            next_mjd = lc.t.loc[b_region[mid_pos + 1], 'MJD']
            median2 = np.median(lc.t.loc[lc.ix_inrange(colnames=['MJD'], lowlim=next_mjd, uplim=last_mjd), 'uJy'])
        else:
            # no measurement after the midpoint: the second half is empty
            median2 = np.nan
        print('Checking that baseline flux is consistent throughout adjusted region...\n# Median of first half: %0.2f\n# Median of second half: %0.2f' % (median1, median2))

    lc_info['baseline_rev_ix'] = np.concatenate([regions['b_t0'], regions['b_t1'], regions['b_t2']])
    lc_info['durings_rev_ix'] = AnotB(lc.getindices(), lc_info['baseline_rev_ix'])
    return regions

def set_baseline_region(regions, region_index, start_override=None, end_override=None):
    """Override the start and/or end MJD of baseline region ``regions['b_t<region_index>']``.

    An override left as None keeps that endpoint at the MJD of the region's
    current first (start) or last (end) index. With both None, nothing changes.

    Returns:
        ``regions``, modified in place.
    """
    if start_override is None and end_override is None:
        return regions

    key = 'b_t%d' % region_index
    lc = lc_info['lc']
    # each endpoint is resolved independently, so START and END can be overridden together
    lowlim = start_override if start_override is not None else lc.t.loc[regions[key][0], 'MJD']
    uplim = end_override if end_override is not None else lc.t.loc[regions[key][-1], 'MJD']
    regions[key] = lc.ix_inrange(colnames=['MJD'], lowlim=lowlim, uplim=uplim)
    return regions

def _parse_mjd_override(answer):
    """Return ``answer`` as a float MJD, or None if it is not a finite number (e.g. 'n' or empty)."""
    try:
        value = float(answer)
    except ValueError:
        return None
    return value if np.isfinite(value) else None


def set_baseline_regions(regions):
    """Interactively let the user override the start/end MJD of baseline regions b_t0..b_t2.

    Any finite number (including decimal MJDs such as 58000.5) is accepted as
    an override; anything else (e.g. 'n') skips that endpoint. Afterwards
    lc_info['baseline_rev_ix'] and lc_info['durings_rev_ix'] are recomputed.

    Returns:
        The updated ``regions`` dict.
    """
    for region_index in range(0, 3):
        start_override = _parse_mjd_override(input('Override template region b_t%d START MJD (n to skip): ' % region_index))
        end_override = _parse_mjd_override(input('Override template region b_t%d END MJD (n to skip): ' % region_index))

        if start_override is not None and end_override is not None:
            print('# Overriding START for template region b_t%d with %f and END with %f: ' % (region_index, start_override, end_override))
            regions = set_baseline_region(regions, region_index, start_override=start_override, end_override=end_override)
        elif start_override is not None:
            print('# Overriding START for template region b_t%d with %f: ' % (region_index, start_override))
            regions = set_baseline_region(regions, region_index, start_override=start_override)
        elif end_override is not None:
            print('# Overriding END for template region b_t%d with %f: ' % (region_index, end_override))
            regions = set_baseline_region(regions, region_index, end_override=end_override)
        else:
            print('# Skipping region... ')

    lc_info['baseline_rev_ix'] = np.concatenate([regions['b_t0'], regions['b_t1'], regions['b_t2']])
    lc_info['durings_rev_ix'] = AnotB(lc_info['lc'].getindices(), lc_info['baseline_rev_ix'])

    return regions

def controls_correct_for_template(control_index, regions, region_index):
    """Correct control light curve ``control_index`` for the ATLAS template change in one region.

    The offset is the median control flux within the MJD range of the SN's
    baseline region b_t<region_index>, using only control measurements with
    chi/N <= 5 when any exist. It is subtracted from every control measurement
    within the MJD range of the SN's template region t<region_index>, matching
    what correct_for_template() does to the SN light curve. Both MJD ranges
    include their endpoints, like the SN's own region indices.
    """
    control = controls[control_index]
    sn_t = lc_info['lc'].t
    b_goodx2_ix = control.ix_inrange(colnames=['chi/N'], uplim=5)

    b_ix = regions[f'b_t{region_index}']
    lowlim = sn_t.loc[b_ix[0], 'MJD']
    uplim = sn_t.loc[b_ix[-1], 'MJD']
    b_region_ix = control.ix_inrange(colnames=['MJD'], lowlim=lowlim, uplim=uplim)

    if len(b_region_ix) == 0:
        print(f'### No valid region for baseline region {region_index}, skipping...')
        return

    print(f'## Adjusting for template change in region b_t{region_index} from {lowlim:0.2f}-{uplim:0.2f}...')
    # nanmedian: verify_mjds() inserts rows with NaN flux for MJDs missing from the control
    good_b_ix = AandB(b_region_ix, b_goodx2_ix)
    if len(good_b_ix) > 0:
        median = np.nanmedian(control.t.loc[good_b_ix, 'uJy'])
        print('### Median of measurements with chi-square ≤ 5 before correction: ', median)
    else:
        median = np.nanmedian(control.t.loc[b_region_ix, 'uJy'])
        print('### Median of measurements before correction: ', median)

    t_ix = regions[f't{region_index}']
    t_region_i = control.ix_inrange(colnames=['MJD'], lowlim=sn_t.loc[t_ix[0], 'MJD'], uplim=sn_t.loc[t_ix[-1], 'MJD'])

    print(f'### Subtracting median {median:0.1f} uJy from light curve flux due to potential flux in the template...')
    control.t.loc[t_region_i, 'uJy'] -= median
    print(f'### Median of measurements after correction: {np.nanmedian(control.t.loc[b_region_ix, "uJy"])}')

def correct_for_template(regions, proceed):
    """Remove template flux from each approved template region of the SN (and control) light curves.

    For each region index r with ``proceed[r]`` True and a non-empty baseline
    region b_t<r>: the median 'uJy' of b_t<r> (restricted to chi/N <= 5 when
    such measurements exist) is subtracted from every SN measurement in
    template region t<r>. The correction is recorded in the output file ``f``.
    If ``load_controls`` is set, each control light curve is corrected the same
    way via controls_correct_for_template().
    """
    if any(proceed.values()):
        f.write('\n\n## Correcting for ATLAS reference template changes')

    # baseline measurements with chi/N <= 5 are preferred for the median
    b_goodx2_ix = lc_info['lc'].ix_inrange(colnames=['chi/N'], uplim=5, indices=lc_info['baseline_rev_ix'])

    for region_index in range(0, 3):
        if not proceed[region_index]:
            continue
        region_ix = regions['b_t%d' % region_index]
        if len(region_ix) == 0:
            print('No baseline region for region b_t%d, skipping... ' % region_index)
            continue

        sn_t = lc_info['lc'].t
        print('\nAdjusting for template change in baseline region %d from %0.2f-%0.2f ' % (region_index, sn_t.loc[region_ix[0],'MJD'], sn_t.loc[region_ix[-1],'MJD']))
        good_region_ix = AandB(region_ix, b_goodx2_ix)
        if len(good_region_ix) > 0:
            median = np.median(sn_t.loc[good_region_ix, 'uJy'])
            print('# Median of baseline measurements with chi-square ≤ 5 before correction: ', median)
        else:
            median = np.median(sn_t.loc[region_ix, 'uJy'])
            print('# Median of baseline measurements before correction: ', median)

        print(f'# Subtracting median {median:0.1f} uJy from light curve flux due to potential flux in the template...')
        sn_t.loc[regions['t%d' % region_index], 'uJy'] -= median
        print('# Median of baseline measurements after correction: ', np.median(sn_t.loc[region_ix, 'uJy']))
        f.write(f'\nCorrection applied to baseline region {region_index:d}: {median} uJy subtracted')

        if load_controls:
            print('Correcting control light curves for potential flux in template...')
            for control_index in range(1, Ncontrols+1):
                print(f'# Control index: {control_index}')
                controls_correct_for_template(control_index, regions, region_index)

def get_weighted_mean(lc, indices=None):
    """Return the 3-sigma-clipped mean 'uJy' flux of ``lc``, optionally restricted to ``indices``.

    Runs ``lc.calcaverage_sigmacutloop`` (first iteration around the median)
    and returns the 'mean' entry it leaves in ``lc.statparams``.
    """
    lc.calcaverage_sigmacutloop('uJy', indices=indices, Nsigma=3.0, median_firstiteration=True, verbose=1)
    stats = lc.statparams
    return stats['mean']

def get_typical_uncertainty(lc, indices=None):
    """Return the median 'duJy' of ``lc.t``, over all rows or only the rows labelled by ``indices``."""
    duJy = lc.t['duJy'] if indices is None else lc.t.loc[indices, 'duJy']
    return np.median(duJy)

def proceed_check(regions):
    """Decide, for each template region, whether to apply the template change correction.

    For each baseline region b_t0..b_t2, the 3σ-clipped mean flux
    (get_weighted_mean) is compared with twice the typical uncertainty (median
    duJy, get_typical_uncertainty). If the mean exceeds that, the user is asked
    whether to proceed (answering 'y'/'Y' approves); otherwise the region is
    approved automatically. Empty baseline regions are skipped (not approved).

    Returns:
        dict mapping region index (0..2) to a bool.
    """
    proceed = {}
    for region_index in range(0, 3):
        region_ix = regions['b_t%d' % region_index]
        print(f'Baseline region {region_index:d}')

        if len(region_ix) == 0:
            # nothing to average; correct_for_template() has nothing to correct here either
            print('# No baseline measurements in this region. Skipping')
            proceed[region_index] = False
            continue

        weighted_mean = get_weighted_mean(lc_info['lc'], indices=region_ix)
        typical_uncertainty = get_typical_uncertainty(lc_info['lc'], indices=region_ix)
        print(f'# Weighted mean flux (uJy) obtained using 3σ cut: {weighted_mean:0.2f}\n# Typical uncertainty (duJy) obtained by taking median: {typical_uncertainty:0.2f}')

        if weighted_mean > 2*typical_uncertainty:
            answer = input('# WARNING: weighted mean flux > 2 * typical uncertainty! Template change correction strongly recommended for this region. Proceed? (y/n)')
            if answer.strip().lower() == 'y':
                print('# Proceeding')
                proceed[region_index] = True
            else:
                print('# Skipping')
                proceed[region_index] = False
        else:
            print('# Weighted mean flux <= 2 * typical uncertainty. Proceeding with template change correction in this region')
            proceed[region_index] = True

    return proceed

# split the light curve into template regions t0..t2 and choose baseline regions b_t0..b_t2
regions = get_baseline_regions(Ndays_min)

# plot lc before correction for template
if plot:
    limits = [xlim_lower, xlim_upper, ylim_lower, ylim_upper]
    plot_all_lcs(add2title='\nbefore correcting for template changes', templates=True, limits=limits)

In [None]:
# optionally, manually override baseline region endpoints; then correct for template changes

# user can manually override baseline region endpoints
if input('Override baseline region endpoint(s)? (y/n)') == 'y':
    regions = set_baseline_regions(regions)

# decide whether or not to correct each template region (dict: region index -> bool)
proceed = proceed_check(regions)
print(proceed)

# apply the approved corrections, then drop extra columns from the SN light curve
correct_for_template(regions, proceed)
drop_extra_columns('lc') 

# plot lc after correction for template
if plot:
    limits = [xlim_lower, xlim_upper, ylim_lower, ylim_upper]
    plot_all_lcs(add2title='\nafter correcting for template changes', templates=True, limits=limits)

## Step 3: Static uncertainty cut

The following uncertainty cut is static: it applies the same threshold to every light curve. Its purpose is to remove the most egregious outliers — measurements with large uncertainties but small chi-square values, which the dynamic chi-square cut would not catch. The default threshold (160 µJy) was chosen based on the typical uncertainty of bright stars just below the saturation limit.

WARNING: **if the SN is particularly bright, you may want to increase the value and rerun the cut**.

In [None]:
# You may change the following static uncertainty cut value to your liking;
# however, the default value is set to 160.
lc_info['uncertainty_cut'] = 160

# Plot the light curve before and after the applied uncertainty cut?
# (print_statistics() also plots automatically when more than 10% of measurements are cut):
plot = False

# Optionally, manually enter the x and y limits for the uncertainty cut plot (None = not set):
xlim_lower = None
xlim_upper = None
ylim_lower = None
ylim_upper = None

In [None]:
# restart 'Mask' column and update with uncertainty cut flag

def update_mask_col(lc, flag, indices):
    """OR the bit flag ``flag`` into the 'Mask' column of DataFrame ``lc`` for the rows labelled by ``indices``.

    ``lc`` is modified in place. If ``indices`` is empty, a warning is printed
    and nothing changes.
    """
    if len(indices) == 0:
        print('WARNING: must pass at least 1 index to update_mask_col()! No indices masked...')
        return
    # cast to integers first so the bitwise OR also works when 'Mask' has a float
    # dtype, for any number of indices (previously only the 1-index path cast)
    lc.loc[indices, 'Mask'] = lc.loc[indices, 'Mask'].astype(np.int64) | flag

def plot_cut_lc(lc_type, mask, dflux_colname, add2title=None, limits=None):
    """Plot light curve ``lc_info[lc_type]`` split into kept and cut measurements.

    Rows returned by ``ix_unmasked('Mask', maskval=mask)`` count as kept
    (presumably those with none of the ``mask`` bits set); all other rows are
    cut. The left panel shows both (cut rows as open circles), the right panel
    only the kept rows. Error bars come from column ``dflux_colname``. The axis
    limits of both panels come solely from set_xylimits(limits, lc_type=lc_type).
    """
    lc = lc_info[lc_type]
    good_ix = lc.ix_unmasked('Mask', maskval=mask)
    bad_ix = AnotB(lc.getindices(), good_ix)

    fig, (cut, clean) = plt.subplots(1, 2, figsize=(16, 6), tight_layout=True)
    title = 'SN %s %s-band' % (lc_info['tnsname'], lc_info['filter'])
    if lc_type == 'avglc':
        title += ', averaged'
    if add2title is not None:
        title += ', ' + add2title
    plt.suptitle(title, fontsize=19, y=1)

    color = 'orange' if lc_info['filter'] == 'o' else 'cyan'

    # set x and y limits (identical on both panels)
    limits = set_xylimits(limits, lc_type=lc_type)
    if limits is not None:
        for ax in (cut, clean):
            ax.set_xlim(limits[0], limits[1])
            ax.set_ylim(limits[2], limits[3])

    t = lc.t
    cut.errorbar(t.loc[good_ix,'MJD'], t.loc[good_ix,'uJy'], yerr=t.loc[good_ix,dflux_colname], fmt='none', ecolor=color, elinewidth=1, c=color)
    cut.scatter(t.loc[good_ix,'MJD'], t.loc[good_ix,'uJy'], s=50, color=color, marker='o', label='Kept measurements')
    cut.errorbar(t.loc[bad_ix,'MJD'], t.loc[bad_ix,'uJy'], yerr=t.loc[bad_ix,dflux_colname], fmt='none', mfc='white', ecolor=color, elinewidth=1, c=color)
    cut.scatter(t.loc[bad_ix,'MJD'], t.loc[bad_ix,'uJy'], s=50, facecolors='white', edgecolors=color, marker='o', label='Cut measurements')
    cut.set_title('All measurements')
    cut.axhline(linewidth=1, color='k')
    cut.set_xlabel('MJD')
    cut.set_ylabel('Flux (uJy)')

    # created before the right panel is drawn, so the figure legend only collects
    # the left panel's labelled artists (no duplicate 'Kept measurements' entry)
    fig.legend(loc='upper center', bbox_to_anchor=(0.5, 0), ncol=2)

    clean.errorbar(t.loc[good_ix,'MJD'], t.loc[good_ix,'uJy'], yerr=t.loc[good_ix,dflux_colname], fmt='none', ecolor=color, elinewidth=1, c=color)
    clean.scatter(t.loc[good_ix,'MJD'], t.loc[good_ix,'uJy'], s=50, color=color, marker='o', label='Kept measurements')
    clean.set_title('Kept measurements only')
    clean.axhline(linewidth=1, color='k')
    clean.set_xlabel('MJD')
    clean.set_ylabel('Flux (uJy)')

def print_statistics(num_cut, percent_cut, flag, add2title=None):
    """Report how many measurements a cut removed and optionally plot the result.

    A plot of the SN light curve (via plot_cut_lc) is drawn when the global `plot`
    switch is on, or unconditionally when more than 10% of measurements were cut.

    Args:
        num_cut: number of measurements removed by the cut (formatted as an int).
        percent_cut: percentage of all measurements removed by the cut.
        flag: mask flag of the cut, forwarded to plot_cut_lc.
        add2title: optional text forwarded to plot_cut_lc for the plot title.
    """
    print(f'\nNumber of cut measurements: {num_cut:d}\nPercent of total measurements cut: {percent_cut:0.2f}%')
    over_threshold = percent_cut > 10
    if over_threshold:
        print(f'WARNING: percent of total measurements cut is greater than 10%. Plotting...')
    if not (plot or percent_cut > 10):
        return
    plot_cut_lc('lc', flag, dflux_colnames[0], add2title=add2title,
                limits=[xlim_lower, xlim_upper, ylim_lower, ylim_upper])

# remove old mask column so the mask is rebuilt from scratch (it is re-created as all zeros below)
if 'Mask' in lc_info['lc'].t.columns: 
    print('Removing old \'Mask\' column...')
    lc_info['lc'].t.drop(columns=['Mask'],inplace=True)

# create new mask column and update it with uncertainty cut:
# every row NOT returned by ix_inrange(duJy, uplim=uncertainty_cut) gets flag_uncertainty set
print(f'Applying uncertainty cut of {lc_info["uncertainty_cut"]:0.2f}...')
lc_info['lc'].t['Mask'] = 0
kept_ix = lc_info['lc'].ix_inrange(colnames=['duJy'],uplim=lc_info['uncertainty_cut'])
cut_ix = AnotB(lc_info['lc'].getindices(), kept_ix)
update_mask_col(lc_info['lc'].t, flag_uncertainty, cut_ix)
print('Success')

# report the size of the cut (print_statistics also plots when `plot` is set or >10% was cut)
num_cut = len(cut_ix)
percent_cut = 100 * num_cut/len(lc_info['lc'].t)
print_statistics(num_cut, percent_cut, flag_uncertainty, add2title='uncertainty cut')
# NOTE(review): the hex value below is hard-coded as 0x2; confirm it matches flag_uncertainty
f.write(f'\n\n## Uncertainty cut\nNumber of cut measurements: {num_cut:d}\nPercent of total measurements cut: {percent_cut:0.2f}%\nHex value in "Mask" column: 0x2')

## Step 4: Estimating true uncertainties

This section attempts to account for an extra noise source in the data by estimating the true typical uncertainty, deriving the additional systematic uncertainty, and lastly applying this extra noise to a new uncertainty column. This new uncertainty column will be used in the cuts following this section.

Here is the exact procedure we use:
1. Keep the previously applied uncertainty cut and apply a preliminary chi-square cut at 20 (default value). Filter out any measurements flagged by these two cuts.
2. Calculate the true typical uncertainty $\text{sigma\_true\_typical}$ by taking a 3σ cut of the unflagged baseline flux and getting the standard deviation.
3. If $\text{sigma\_true\_typical}$ is at least 10% greater than the median uncertainty of the unflagged baseline flux, $\text{median}(∂µJy)$, proceed with estimating the extra noise to add (you will be asked to confirm). Otherwise, skip this procedure. 
4. Calculate the extra noise source using the following formula, where the median uncertainty, $\text{median}(∂µJy)$, is taken from the unflagged baseline flux:
    - $\text{sigma\_extra}^2 = \text{sigma\_true\_typical}^2 - \text{sigma\_poisson}^2 = \text{sigma\_true\_typical}^2 - \text{median}(∂µJy)^2 $
5. Apply the extra noise source to the existing uncertainty using the following formula:
    - $\text{new }∂µJy = \sqrt{(\text{old }∂µJy)^2 + \text{sigma\_extra}^2}$
6. Repeat steps 1-5 for each control light curve. For cuts following this procedure, use the new uncertainty column with the extra noise added instead of the old uncertainty column.

In [None]:
# Preliminary chi-square cut applied only while estimating true uncertainties
# (keep this at a high number):
prelim_x2_cut = 20

# Should the light curve be plotted before and after estimating true uncertainties?
plot = True

# Optional manual plot limits (leave as None to not set them by hand):
xlim_lower = xlim_upper = None
ylim_lower = ylim_upper = None

In [None]:
# estimate true uncertainties and create revised 'duJy_new' column; store preferred column names in dflux_colnames list

def get_median_dflux(lc, b_clean_ix):
    """Return the median 'duJy' uncertainty over the rows of lc.t selected by b_clean_ix."""
    selected_dflux = lc.t.loc[b_clean_ix, 'duJy']
    return np.median(selected_dflux)

def get_sigma_true_typical(lc, b_clean_ix):
    """Return the 3-sigma-clipped standard deviation of 'uJy' over b_clean_ix.

    Runs lc.calcaverage_sigmacutloop on the selected rows (first iteration around the
    median) and reads the resulting 'stdev' from lc.statparams.
    """
    lc.calcaverage_sigmacutloop('uJy', Nsigma=3.0, indices=b_clean_ix, verbose=1, median_firstiteration=True)
    clipped_stats = lc.statparams
    return clipped_stats['stdev']

def get_sigma_extra(sigma_true_typical, median_dflux):
    """Return the extra noise term to add in quadrature to the uncertainties.

    Implements sigma_extra^2 = sigma_true_typical^2 - median(duJy)^2, as described in
    the notebook text for this step. Both terms are squared; the previous version
    subtracted median_dflux unsquared.

    Args:
        sigma_true_typical: sigma-clipped standard deviation of the clean baseline flux.
        median_dflux: median uncertainty of the clean baseline flux.

    Returns:
        sigma_extra (same units as the inputs). When sigma_true_typical does not exceed
        median_dflux, no extra noise is needed and 0.0 is returned. Without this guard,
        sqrt of a negative number would give NaN and poison every new uncertainty.
    """
    extra_variance = sigma_true_typical*sigma_true_typical - median_dflux*median_dflux
    if extra_variance <= 0:
        return 0.0
    return np.sqrt(extra_variance)

def add_sigma_extra(lc, sigma_extra):
    """Store duJy_new = sqrt(duJy^2 + sigma_extra^2) as a new column of lc.t.

    The original 'duJy' column is left untouched.
    """
    print(f'# Adding sigma_extra {sigma_extra:0.4f} to new duJy column...')
    old_dflux = lc.t['duJy']
    extra_variance = sigma_extra*sigma_extra
    lc.t['duJy_new'] = np.sqrt(old_dflux*old_dflux + extra_variance)

def proceed_check(prelim_x2_cut, baseline_ix=None):
    """Decide whether the true-uncertainty estimation should run.

    Uses the SN light curve's baseline measurements that are not masked with
    flag_uncertainty and have chi/N < prelim_x2_cut. Compares their sigma-clipped
    standard deviation (sigma_true_typical) with their median uncertainty. If the former
    is at least 10% larger, the user is asked interactively whether to proceed.

    Args:
        prelim_x2_cut: preliminary chi-square cut (exclusive upper limit on chi/N).
        baseline_ix: baseline indices to use; defaults to lc_info['baseline_rev_ix'].

    Returns:
        True if the estimation should proceed, False otherwise.
    """
    print('Checking to see if true uncertainties estimation is necessary...')
    lc = lc_info['lc']
    if baseline_ix is None:
        baseline_ix = lc_info['baseline_rev_ix'] 
    clean_ix = AandB(lc.ix_unmasked('Mask',maskval=flag_uncertainty), lc.ix_inrange(colnames=['chi/N'],uplim=prelim_x2_cut,exclude_uplim=True))
    b_clean_ix = AandB(baseline_ix, clean_ix)

    sigma_true_typical = get_sigma_true_typical(lc, b_clean_ix)
    median_dflux = get_median_dflux(lc, b_clean_ix)
    print(f'Median uncertainty of baseline flux: {median_dflux:0.2f}\nTrue typical uncertainty of baseline flux: {sigma_true_typical:0.2f}')
    
    percent_greater = 100 * ((sigma_true_typical - median_dflux)/median_dflux)
    if percent_greater >= 10:
        print(f'WARNING: True typical uncertainty is {percent_greater:0.2f}% greater than median uncertainty of baseline flux. True uncertainties estimation recommended')
        answer = input('Proceed with true uncertainties estimation? (y/n):')
        # accept 'y', 'Y', 'yes' (ignoring surrounding whitespace); anything else skips
        if answer.strip().lower() in ('y', 'yes'):
            print('Proceeding...') 
            return True
        print('Skipping procedure')
        return False
    print(f'True typical uncertainty is {percent_greater:0.2f}% greater than median uncertainty of baseline flux--no estimation needed!')
    return False

def estimate_true_uncertainties(prelim_x2_cut, control_index, baseline_ix=None):
    """Add the extra noise term to one light curve and switch it to the 'duJy_new' column.

    Args:
        prelim_x2_cut: preliminary chi-square cut (exclusive upper limit on chi/N) used,
            together with the uncertainty-cut mask, to select clean measurements.
        control_index: 0 for the SN light curve (lc_info['lc']); otherwise the index into
            `controls`.
        baseline_ix: baseline indices for the SN light curve; defaults to
            lc_info['baseline_rev_ix']. Ignored for control light curves, where all clean
            measurements are used.

    Side effects:
        - Writes 'duJy_new' to the light curve (via add_sigma_extra).
        - Sets dflux_colnames[control_index] = 'duJy_new'.
        - Recomputes that light curve's 'uJy/duJy' column from 'duJy_new'; +/-inf becomes NaN.
        - For the SN light curve only, appends a summary to the output file `f`.
    """
    if control_index == 0:
        print('\nEstimating true uncertainties for SN light curve...')
        lc = lc_info['lc']
        if baseline_ix is None:
            baseline_ix = lc_info['baseline_rev_ix'] 
    else:
        print(f'\nEstimating true uncertainties for control light curve {control_index:03d}...')
        lc = controls[control_index]

    # measurements not masked by the uncertainty cut and with chi/N below the preliminary cut
    clean_ix = AandB(lc.ix_unmasked('Mask',maskval=flag_uncertainty), lc.ix_inrange(colnames=['chi/N'],uplim=prelim_x2_cut,exclude_uplim=True))
    # the SN light curve is restricted to its baseline; control light curves use every clean row
    b_clean_ix = AandB(baseline_ix, clean_ix) if control_index == 0 else clean_ix

    sigma_true_typical = get_sigma_true_typical(lc, b_clean_ix)
    median_dflux = get_median_dflux(lc, b_clean_ix)
    print(f'# Median uncertainty: {median_dflux:0.2f}; true typical uncertainty: {sigma_true_typical:0.2f}')

    # use new uncertainty column from now on
    dflux_colnames[control_index] = 'duJy_new'

    sigma_extra = get_sigma_extra(sigma_true_typical, median_dflux)
    add_sigma_extra(lc, sigma_extra)
    if control_index == 0:
        f.write(f'\n\n## Estimating true uncertainties\nMedian uncertainty: {median_dflux:0.2f} duJy\nTrue typical uncertainty: {sigma_true_typical:0.2f} duJy\nExtra noise added (in "duJy_new" column): {sigma_extra:0.2f} duJy')

    # recalculate uJy/duJy column of THIS light curve (previously always the SN light curve,
    # so control light curves kept stale flux/dflux values)
    print('# Recalculating uJy/duJy column using duJy_new as the uncertainties...')
    lc.t['uJy/duJy'] = lc.t['uJy']/lc.t['duJy_new']
    lc.t = lc.t.replace([np.inf, -np.inf], np.nan)

def plot_trueuncerts(baseline_ix=None, duringsn_ix=None, limits=None):
    """Show the SN light curve with its original (top) and revised (bottom) uncertainties.

    Both panels plot the same baseline and during-SN measurements. The top panel uses
    'duJy' error bars; the bottom panel uses the 'duJy_new' error bars written by
    add_sigma_extra.

    Args:
        baseline_ix: baseline indices; defaults to lc_info['baseline_rev_ix'].
        duringsn_ix: during-SN indices; defaults to lc_info['duringsn_ix'].
        limits: [x lower, x upper, y lower, y upper], passed through set_xylimits. If that
            returns something other than None, it overrides both panels' axis limits.
    """
    if lc_info["filter"] == 'o':
        band_color = 'orange'
    else:
        band_color = 'cyan'

    baseline_ix = lc_info['baseline_rev_ix'] if baseline_ix is None else baseline_ix
    duringsn_ix = lc_info['duringsn_ix'] if duringsn_ix is None else duringsn_ix

    table = lc_info['lc'].t
    # (indices, color, zorder, legend label) for each group of measurements
    groups = ((baseline_ix, band_color, 10, 'Baseline'),
              (duringsn_ix, 'red', 20, 'During SN'))

    fig, (before_ax, after_ax) = plt.subplots(2, constrained_layout=True)
    fig.set_figwidth(12)
    fig.set_figheight(8)

    def draw_panel(ax, title, dflux_col):
        # shared panel layout; only the title and the error-bar column differ
        ax.set_title(title)
        ax.axvline(x=tchange1, color='magenta', label='ATLAS template change')
        ax.axvline(x=tchange2, color='magenta')
        ax.set_xlim(xlim_lower,xlim_upper)
        ax.set_ylim(ylim_lower,ylim_upper)
        ax.set_ylabel('Flux (µJy)')
        for ix, clr, z, label in groups:
            mjd = table.loc[ix, 'MJD']
            flux = table.loc[ix, 'uJy']
            ax.errorbar(mjd, flux, yerr=table.loc[ix, dflux_col], fmt='none', ecolor=clr, elinewidth=1, c=clr, zorder=z)
            ax.scatter(mjd, flux, s=45, color=clr, marker='o', zorder=z, label=label)
        ax.axhline(linewidth=1, color='k', zorder=30)
        ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))

    draw_panel(before_ax, f'SN {lc_info["tnsname"]} {lc_info["filter"]}-band flux\nbefore estimating true uncertainties', 'duJy')
    before_ax.get_xaxis().set_ticks([])
    draw_panel(after_ax, 'after estimating true uncertainties', 'duJy_new')
    after_ax.set_xlabel('MJD')

    # set x and y limits
    limits = set_xylimits(limits)
    if limits is not None:
        for ax in (before_ax, after_ax):
            ax.set_xlim(limits[0],limits[1])
            ax.set_ylim(limits[2],limits[3])

# decide (interactively if recommended) whether to estimate true uncertainties, based on the SN baseline
check = proceed_check(prelim_x2_cut, baseline_ix=lc_info['baseline_rev_ix'])

if check:
    # index 0 is the SN light curve; 1..Ncontrols are the control light curves
    for control_index in range(Ncontrols+1):
        estimate_true_uncertainties(prelim_x2_cut, control_index)
        # only the SN light curve is plotted
        if control_index == 0 and plot:
            limits = [xlim_lower, xlim_upper, ylim_lower, ylim_upper]
            plot_trueuncerts(limits=limits)
    print('\nSuccess')

## Step 5: Dynamic chi-square cut

### 5a) Plot the flux/dflux and chi-square distributions

The following two histograms display the flux/dflux and chi-square distributions of the target SN. Both histograms show probability density so as to ease comparison between the groups plotted within each histogram.

- The first histogram focuses on the baseline flux/dflux (µJy/dµJy) measurements, where we can expect the flux to equal 0. In orange, we plot flux/dflux (µJy/dµJy) measurements with a chi-square value less than or equal to `x2bound`, which is currently set to 5 below; in blue, we plot flux/dflux (µJy/dµJy) measurements with a chi-square value greater than `x2bound`. 
- The second histogram focuses on the baseline chi-square measurements. In green, we plot chi-square measurements with an abs(µJy/dµJy) value less than or equal to `stnbound`, which is currently set to 3 below; in red, we plot chi-square measurements with an abs(µJy/dµJy) value greater than `stnbound`. 

Ideally, all measurements with a chi-square value less than or equal to `x2bound` should have an abs(µJy/dµJy) value less than or equal to `stnbound`, and measurements with a chi-square value greater than `x2bound` should have an abs(µJy/dµJy) value greater than `stnbound`. Our goal is to separate good measurements from bad measurements using a chi-square cut; in order for our cut to be effective, these histograms should hopefully showcase this relation between the target SN's flux/dflux and chi-square measurements.

In [None]:
# Chi-square bound separating a good measurement from a bad one:
x2bound = 5.0

# abs(flux/dflux) bound separating a good measurement from a bad one:
stnbound = 3.0

# Optional manual histogram x limits (None = use the baseline min/max):
# flux/dflux histogram
fdf_xlim_lower, fdf_xlim_upper = None, None
# chi-square histogram
x2_xlim_lower, x2_xlim_upper = None, None

In [None]:
# plot flux/dflux and chi-square distribution histograms

def plot_hists(x2bound, stnbound, fdf_xlim_lower=None, fdf_xlim_upper=None, x2_xlim_lower=None, x2_xlim_upper=None, baseline_ix=None):
    """Plot baseline-only histograms of flux/dflux and chi-square for the SN light curve.

    Left panel: uJy/duJy, split into rows with chi/N <= x2bound (orange) and the rest (blue).
    Right panel: chi/N, split into rows with -stnbound <= uJy/duJy <= stnbound (green) and
    the rest (red). Both histograms are density-normalised.

    Args:
        x2bound: chi-square value separating "good" from "bad" measurements.
        stnbound: abs(flux/dflux) value separating "good" from "bad" measurements.
        fdf_xlim_lower, fdf_xlim_upper: optional uJy/duJy histogram range; a missing bound
            defaults to the baseline min/max.
        x2_xlim_lower, x2_xlim_upper: same, for the chi/N histogram.
        baseline_ix: baseline indices; defaults to lc_info['baseline_rev_ix'].
    """
    if baseline_ix is None:
        baseline_ix = lc_info['baseline_rev_ix']
    lc = lc_info['lc']

    b_goodstn_i = lc.ix_inrange(colnames=['uJy/duJy'], lowlim=-stnbound, uplim=stnbound, indices=baseline_ix)
    b_badstn_i = AnotB(baseline_ix, b_goodstn_i)
    b_goodx2_i = lc.ix_inrange(colnames=['chi/N'], uplim=x2bound, indices=baseline_ix)
    b_badx2_i = AnotB(baseline_ix, b_goodx2_i)

    def hist_range(colname, lower, upper):
        # With no baseline rows, let matplotlib choose the range (range=None is hist's default).
        if len(baseline_ix) == 0:
            return None
        # Series.min()/max() skip NaN (uJy/duJy may hold NaN where it was infinite);
        # the builtin min()/max() can return NaN and make the histogram range invalid.
        if lower is None:
            lower = lc.t.loc[baseline_ix, colname].min()
        if upper is None:
            upper = lc.t.loc[baseline_ix, colname].max()
        return (lower, upper)

    fig, (stn, x2) = plt.subplots(1, 2, figsize=(10, 6.5), tight_layout=True)
    plt.suptitle('SN %s %s-band, baseline only' % (lc_info['tnsname'], lc_info['filter']), fontsize=17, y=1)

    stn.set_title('µJy/dµJy distribution')
    stn.set_xlabel('µJy/dµJy')
    stn.spines.right.set_visible(False)
    stn.spines.top.set_visible(False)
    # the "good" chi-square selection above is inclusive of x2bound, so label it with ≤
    orange = mpatches.Patch(color='orange', label='Data with chi-square≤%0.2f' % x2bound)
    blue = mpatches.Patch(color='blue', label='Data with chi-square>%0.2f' % x2bound)
    stn.legend(handles=[orange, blue], loc='upper center', bbox_to_anchor=(0.5, -0.2), ncol=1)
    fdf_range = hist_range('uJy/duJy', fdf_xlim_lower, fdf_xlim_upper)
    stn.hist(lc.t.loc[b_goodx2_i, 'uJy/duJy'], bins=30, color='orange', alpha=0.5, range=fdf_range, density=True)
    stn.hist(lc.t.loc[b_badx2_i, 'uJy/duJy'], bins=30, color='blue', alpha=0.5, range=fdf_range, density=True)

    x2.set_title('Chi-square distribution')
    x2.set_xlabel('Chi-square')
    x2.spines.right.set_visible(False)
    x2.spines.top.set_visible(False)
    # the "good" flux/dflux selection is the inclusive range -stnbound..stnbound
    green = mpatches.Patch(color='green', label='Data with abs(µJy/dµJy)≤%0.2f' % stnbound)
    red = mpatches.Patch(color='red', label='Data with abs(µJy/dµJy)>%0.2f' % stnbound)
    x2.legend(handles=[green, red], loc='upper center', bbox_to_anchor=(0.5, -0.2), ncol=1)
    x2_range = hist_range('chi/N', x2_xlim_lower, x2_xlim_upper)
    x2.hist(lc.t.loc[b_goodstn_i, 'chi/N'], bins=30, color='green', alpha=0.5, range=x2_range, density=True)
    x2.hist(lc.t.loc[b_badstn_i, 'chi/N'], bins=30, color='red', alpha=0.5, range=x2_range, density=True)

# draw the two baseline-only histograms using the bounds and optional x limits entered above
plot_hists(x2bound, stnbound, fdf_xlim_lower=fdf_xlim_lower, fdf_xlim_upper=fdf_xlim_upper, x2_xlim_lower=x2_xlim_lower, x2_xlim_upper=x2_xlim_upper)

### 5b) Calculate best chi-square cut based on contamination and loss

The following cells use two factors, <strong>contamination</strong> and <strong>loss</strong>, to attempt to calculate an optimal PSF chi-square cut for the target SN, with flux/dflux as the deciding factor of what constitutes a good measurement vs. a bad measurement. We aim to separate good measurements from bad using the calculated chi-square cut by removing as much contamination as possible with the smallest loss possible. Since we can assume that the expected value of the baseline flux is 0, we look only at the baseline measurements before the SN occurs in order to determine the best chi-square cut for the SN itself.

First, we decide what will determine a good measurement vs. a bad measurement using a factor outside of the chi-square values. Our chosen factor is the <strong>absolute value of flux (µJy) divided by dflux (dµJy)</strong>. The recommended boundary is a value of 3, such that any measurements with a value of abs(µJy/dµJy) less than or equal to 3 are regarded as "good" measurements, and any measurements with a value of abs(µJy/dµJy) greater than 3 are regarded as "bad" measurements. You can set this boundary to a different number by changing the value of `stn_cut` below.

Next, we set the upper and lower bounds of our final chi-square cut. We start at a low value of 3 (which can be changed by setting the value of `cut_start` below) and end at 50 (this value is inclusive and can be changed by setting the value of `cut_stop` below) with a step size of 1 (`cut_step` below). <strong>For chi-square cuts falling on or between `cut_start` and `cut_stop` in increments of `cut_step`, we can begin to calculate contamination and loss percentages.</strong>

We define contamination to be the number of bad kept measurements over the total number of kept measurements for that chi-square cut (<strong>contamination = Nbad,kept/Nkept</strong>). For our final chi-square cut, we can also set a limit on what maximum percent contamination we want to have--the recommended value is <strong>15%</strong> but can be changed by setting the value of `contam_lim` below.

We define loss to be the number of good cut measurements over the total number of good measurements for that chi-square cut (<strong>loss = Ngood,cut/Ngood</strong>). For our final chi-square cut, we can also set a limit on what maximum percent loss we want to have--the recommended value is <strong>10%</strong> but can be changed by setting the value of `loss_lim` below.

Finally, we define which limit (`contam_lim` or `loss_lim`) to prioritize in the event that an optimal chi-square cut fitting both limits is not found. The default prioritized limit is `loss_lim` but can be changed by setting the value of `lim_to_prioritize` below.

In [None]:
# Enter the abs(uJy/duJy) boundary that will determine a "good" measurement vs. "bad" measurement:
stn_cut = 3

# Enter the bounds for the final chi-square cut (minimum cut, maximum cut, and step):
cut_start = 3 # this is inclusive
cut_stop = 50 # this is inclusive
cut_step = 1

# Enter the contamination limit (contamination = Nbad,kept/Nkept must be <= contam_lim% 
# for the final chi-square cut):
contam_lim = 15.0

# Enter the loss limit (loss = Ngood,cut/Ngood must be <= loss_lim%
# for the final chi-square cut):
loss_lim = 10.0

# Enter the limit to prioritize (must be 'loss_lim' or 'contam_lim') in the event that
# one or both limits are not met:
lim_to_prioritize = 'loss_lim'

The following section describes in detail how we determine the final chi-square cut using the given contamination and loss limits (feel free to skip).

For each given limit (contamination and loss), we calculate a range of valid cuts whose contamination/loss percentage is less than or equal to that limit, and then choose a single cut within that valid range. Then, we pass through a decision tree that uses a variety of factors (including the user's selected `lim_to_prioritize`) to determine which of the two suggested cuts to use.

When choosing the loss cut according to the loss percentage limit `loss_lim`:
- <strong>If all loss percentages are below the limit</strong> `loss_lim`, all cuts falling on or between `cut_start` and `cut_stop` are valid.
- <strong>If all loss percentages are above the limit</strong> `loss_lim`, a cut with the required loss percentage is not possible; therefore, any cuts with the smallest percentage of loss are valid.
- <strong>Otherwise</strong>, the valid range of cuts includes any cuts with the loss percentage less than or equal to the limit `loss_lim`.
- The chosen cut for this limit is the <strong>minimum cut</strong> within the stated valid range of cuts.

When choosing the contamination cut according to the contamination percentage limit `contam_lim`:
- <strong>If all contamination percentages are below the limit</strong> `contam_lim`, all cuts falling on or between `cut_start` and `cut_stop` are valid.
- <strong>If all contamination percentages are above the limit</strong> `contam_lim`, a cut with the required contamination percentage is not possible; therefore, any cuts with the smallest percentage of contamination are valid.
- <strong>Otherwise</strong>, the valid range of cuts includes any cuts with the contamination percentage less than or equal to the limit `contam_lim`.
- The chosen cut for this limit is the <strong>maximum cut</strong> within the stated valid range of cuts.

After we have calculated two suggested cuts based on the loss and contamination percentage limits, we follow the decision tree in order to suggest a final cut:
- If both loss and contamination cut percentages were chosen from a range that spanned from `cut_start` to `cut_stop`, we set the final cut to `cut_start`.
- If one cut's percentage was chosen from a range that spanned from `cut_start` to `cut_stop` and the other cut's percentage was not, we set the final cut to the latter cut.
- If both percentages were chosen from ranges that fell above their respective limits, we suggest reselecting either or both limits.
- Otherwise, we take into account the user's prioritized limit `lim_to_prioritize`:
    - If the loss cut is greater than the contamination cut, we set the final cut to whichever cut is associated with `lim_to_prioritize`.
    - Otherwise, if `lim_to_prioritize` is set to `contam_lim`, we set the final cut to the loss cut, and if `lim_to_prioritize` is set to `loss_lim`, we set the final cut to the contamination cut.

In [None]:
# calculate the suggested best chi-square cut using contamination and loss

def plot_lim_cuts(lim_cuts, contam_lim_cut, loss_lim_cut):
    """Plot loss and contamination percentages against the tested chi-square cuts.

    For each quantity, draws its limit (dashed), its curve, the chosen cut (vertical line),
    and a shaded region. The loss shading runs from loss_lim_cut to cut_stop; the
    contamination shading runs from cut_start to contam_lim_cut.
    """
    plt.figure(figsize=(10,5), tight_layout=True)
    axis = plt.gca()
    for side in ('right', 'top'):
        axis.spines[side].set_visible(False)
    plt.title('SN %s %s-band chi-square cut' % (lc_info['tnsname'],lc_info['filter']))

    plt.axhline(linewidth=1,color='k')
    plt.xlabel('Chi-square cut')
    plt.ylabel('% of baseline measurements')

    tested_cuts = lim_cuts.t['PSF Chi-Square Cut']
    # (limit, column, color, limit label, curve label, cut label, chosen cut, shaded span)
    curves = (
        (loss_lim, 'Ploss', 'r', 'Loss limit', 'Loss', 'Loss cut', loss_lim_cut, (loss_lim_cut, cut_stop)),
        (contam_lim, 'Pcontamination', 'g', 'Contamination limit', 'Contamination', 'Contamination cut', contam_lim_cut, (cut_start, contam_lim_cut)),
    )
    for limit, column, clr, lim_label, curve_label, cut_label, chosen_cut, (span_lo, span_hi) in curves:
        plt.axhline(limit,linewidth=1,color=clr,linestyle='--',label=lim_label)
        plt.plot(tested_cuts, lim_cuts.t[column],ms=5,color=clr,marker='o',label=curve_label)
        plt.axvline(x=chosen_cut,color=clr,label=cut_label)
        plt.axvspan(span_lo, span_hi, alpha=0.2, color=clr)

    plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left')

def choose_btwn_lim_cuts(contam_lim_cut, loss_lim_cut, contam_case, loss_case):
    """Combine the contamination-limit cut and the loss-limit cut into one final chi-square cut.

    Each case is one of 'below lim', 'above lim' or 'crosses lim'. Decision table (1 = below,
    2 = above, 3 = crosses):
      1 and 1: final cut = cut_start
      1 and 2: take the cut of the limit that is above
      1 and 3: take the cut of the limit that crosses
      2 and 2: limits cannot be met -> NaN
      2 and 3 / 3 and 3: if loss_lim_cut > contam_lim_cut the ranges do not overlap and the
        cut of lim_to_prioritize is used; otherwise prioritizing contam_lim picks
        loss_lim_cut and prioritizing loss_lim picks contam_lim_cut.

    Reads the globals cut_start, contam_lim, loss_lim and lim_to_prioritize.

    Returns:
        The final chi-square cut, or np.nan when both limits are exceeded.
    """
    case1 = loss_case == 'below lim' or contam_case == 'below lim'
    case2 = loss_case == 'above lim' or contam_case == 'above lim'
    case3 = loss_case == 'crosses lim' or contam_case == 'crosses lim'

    final_cut = None
    if case1 and not case2 and not case3: # 1 and 1
        # report the actual value used instead of a hard-coded "3"
        print('Valid chi-square cut range from %0.2f to %0.2f! Setting to %0.2f...' % (loss_lim_cut, contam_lim_cut, cut_start))
        final_cut = cut_start
    elif case1: # 1
        if case2: # and 2
            if loss_case == 'above lim':
                print('WARNING: contam_lim_cut <= %0.2f falls below limit %0.2f%%, but loss_lim_cut >= %0.2f falls above limit %0.2f%%! Setting to %0.2f...' % (contam_lim_cut, contam_lim, loss_lim_cut, loss_lim, loss_lim_cut))
                final_cut = loss_lim_cut
            else:
                print('WARNING: loss_lim_cut <= %0.2f falls below limit %0.2f%%, but contam_lim_cut >= %0.2f falls above limit %0.2f%%! Setting to %0.2f...' % (loss_lim_cut, loss_lim, contam_lim_cut, contam_lim, contam_lim_cut))
                final_cut = contam_lim_cut
        else: # and 3
            if loss_case == 'crosses lim':
                print('Contam_lim_cut <= %0.2f falls below limit %0.2f%% and loss_lim_cut >= %0.2f crosses limit %0.2f%%, setting to %0.2f...' % (contam_lim_cut, contam_lim, loss_lim_cut, loss_lim, loss_lim_cut))
                final_cut = loss_lim_cut
            else:
                print('Loss_lim_cut <= %0.2f falls below limit %0.2f%% and contam_lim_cut >= %0.2f crosses limit %0.2f%%, setting to %0.2f...' % (loss_lim_cut, loss_lim, contam_lim_cut, contam_lim, contam_lim_cut))
                final_cut = contam_lim_cut
    elif case2 and not case3: # 2 and 2
        print('ERROR: chi-square loss_lim_cut >= %0.2f and contam_lim_cut <= %0.2f both fall above limits %0.2f%% and %0.2f%%! Try setting less strict limits. Setting final cut to nan.' % (loss_lim_cut, contam_lim_cut, loss_lim, contam_lim))
        final_cut = np.nan
    else: # 2 and 3 or 3 and 3
        if loss_lim_cut > contam_lim_cut:
            print('WARNING: chi-square loss_lim_cut >= %0.2f and contam_lim_cut <= %0.2f do not overlap! ' % (loss_lim_cut, contam_lim_cut))
            if lim_to_prioritize == 'contam_lim':
                print('Prioritizing %s and setting to %0.2f...' % (lim_to_prioritize, contam_lim_cut))
                final_cut = contam_lim_cut
            else:
                print('Prioritizing %s and setting to %0.2f... ' % (lim_to_prioritize, loss_lim_cut))
                final_cut = loss_lim_cut
        else:
            print('Valid chi-square cut range from %0.2f to %0.2f! ' % (loss_lim_cut, contam_lim_cut))
            # within the valid range, the lower cut minimises contamination and the higher
            # cut minimises loss
            if lim_to_prioritize == 'contam_lim':
                print('Prioritizing %s and setting to %0.2f... ' % (lim_to_prioritize, loss_lim_cut))
                final_cut = loss_lim_cut
            else:
                print('Prioritizing %s and setting to %0.2f... ' % (lim_to_prioritize, contam_lim_cut))
                final_cut = contam_lim_cut
    return final_cut

def get_lim_cuts_data(cut, colname, baseline_ix=None):
    """Store good/bad/kept/cut statistics of the baseline for one chi-square cut in lc_info.

    A baseline measurement is "good" if -stn_cut <= uJy/duJy <= stn_cut (as selected by
    ix_inrange) and "bad" otherwise. It is "kept" if it is selected by ix_inrange on chi/N
    with upper limit `cut`, and "cut" otherwise. Every statistic is stored under the key
    '<colname>_<statistic>' (e.g. 'loss_lim_cut_Ploss').

    Percentages whose denominator is empty (no baseline, no good, or no kept measurements)
    are stored as NaN instead of raising ZeroDivisionError.

    Args:
        cut: chi-square cut value to evaluate.
        colname: prefix for the lc_info keys.
        baseline_ix: baseline indices; defaults to lc_info['baseline_rev_ix'].
    """
    if baseline_ix is None:
        baseline_ix = lc_info['baseline_rev_ix'] 

    def pct(numerator, denominator):
        return 100*numerator/denominator if denominator > 0 else np.nan

    b_good_i = lc_info['lc'].ix_inrange(colnames=['uJy/duJy'],lowlim=-stn_cut,uplim=stn_cut,indices=baseline_ix)
    b_bad_i = AnotB(baseline_ix, b_good_i)
    b_kept_i = lc_info['lc'].ix_inrange(colnames=['chi/N'],uplim=cut,indices=baseline_ix)
    b_cut_i = AnotB(baseline_ix, b_kept_i)

    # compute each intersection once instead of repeating AandB for counts and percentages
    n_baseline = len(baseline_ix)
    n_good = len(b_good_i)
    n_kept = len(b_kept_i)
    n_good_kept = len(AandB(b_good_i,b_kept_i))
    n_good_cut = len(AandB(b_good_i,b_cut_i))
    n_bad_kept = len(AandB(b_bad_i,b_kept_i))
    n_bad_cut = len(AandB(b_bad_i,b_cut_i))

    stats = {
        'Ngood': n_good,
        'Nbad': len(b_bad_i),
        'Nkept': n_kept,
        'Ncut': len(b_cut_i),
        'Ngood,kept': n_good_kept,
        'Ngood,cut': n_good_cut,
        'Nbad,kept': n_bad_kept,
        'Nbad,cut': n_bad_cut,
        'Pgood,kept': pct(n_good_kept, n_baseline),
        'Pgood,cut': pct(n_good_cut, n_baseline),
        'Pbad,kept': pct(n_bad_kept, n_baseline),
        'Pbad,cut': pct(n_bad_cut, n_baseline),
        'Ngood,kept/Ngood': pct(n_good_kept, n_good),
        'Ploss': pct(n_good_cut, n_good),                # loss = Ngood,cut/Ngood
        'Pcontamination': pct(n_bad_kept, n_kept),       # contamination = Nbad,kept/Nkept
    }
    for key, value in stats.items():
        lc_info['%s_%s' % (colname, key)] = value

def get_lim_cuts(lim_cuts): 
    """Pick one chi-square cut for the loss limit and one for the contamination limit.

    lim_cuts.t is presumably the table built by get_lim_cuts_table: one row per tested cut,
    with 'PSF Chi-Square Cut', 'Ploss' and 'Pcontamination' (percent) columns.
    NOTE(review): the lookups below (.loc[0], .loc[len-1], loss1_i - 1, contam1_i + 1) assume
    lim_cuts.t is ordered by ascending cut with a contiguous 0..n-1 index -- TODO confirm.

    Each limit is classified as 'below lim' (every percentage < limit), 'above lim' (every
    percentage > limit) or 'crosses lim'. In the 'crosses lim' case the cut is linearly
    interpolated to where the percentage equals the limit. Summary statistics for both
    chosen cuts are then stored in lc_info via get_lim_cuts_data.

    Returns:
        (contam_lim_cut, loss_lim_cut, contam_case, loss_case)
    """
    contam_lim_cut = None
    loss_lim_cut = None
    contam_case = None
    loss_case = None

    # rows ordered by ascending loss; reset_index relabels them 0..n-1 in that order
    sortby_loss = lim_cuts.t.iloc[(lim_cuts.t['Ploss']).argsort()].reset_index()
    min_loss = sortby_loss.loc[0,'Ploss']
    max_loss = sortby_loss.loc[len(sortby_loss)-1,'Ploss']
    # if all loss below lim, loss_lim_cut is min cut
    if min_loss < loss_lim and max_loss < loss_lim:
        loss_case = 'below lim'
        loss_lim_cut = lim_cuts.t.loc[0,'PSF Chi-Square Cut']
    else:
        # else if all loss above lim, loss_lim_cut is min cut with min% loss
        if min_loss > loss_lim and max_loss > loss_lim:
            loss_case = 'above lim'
            a = np.where(lim_cuts.t['Ploss'] == min_loss)[0]
            b = lim_cuts.t.iloc[a]
            c = b.iloc[(b['PSF Chi-Square Cut']).argsort()].reset_index()
            loss_lim_cut = c.loc[0,'PSF Chi-Square Cut']
        # else if loss crosses lim at some point, loss_lim_cut is min cut with max% loss <= loss_lim
        else:
            loss_case = 'crosses lim'
            # largest loss value that still satisfies loss <= loss_lim
            valid_cuts = sortby_loss[sortby_loss['Ploss'] <= loss_lim]
            a = np.where(lim_cuts.t['Ploss'] == valid_cuts.loc[len(valid_cuts)-1,'Ploss'])[0]
            # sort by cuts
            b = lim_cuts.t.iloc[a]
            c = b.iloc[(b['PSF Chi-Square Cut']).argsort()].reset_index()
            # linearly interpolate between loss1 (smallest cut with that loss) and loss2 (the
            # preceding row) to find the cut where loss equals loss_lim
            loss1_i = np.where(lim_cuts.t['PSF Chi-Square Cut'] == c.loc[0,'PSF Chi-Square Cut'])[0][0]
            if lim_cuts.t.loc[loss1_i,'Ploss'] == loss_lim:
                loss_lim_cut = lim_cuts.t.loc[loss1_i,'PSF Chi-Square Cut']
            else:
                # NOTE(review): loss2_i is -1 when loss1_i is the first row, which .loc would not find
                loss2_i = loss1_i - 1
                x = np.array([lim_cuts.t.loc[loss1_i,'PSF Chi-Square Cut'], lim_cuts.t.loc[loss2_i,'PSF Chi-Square Cut']])
                contam_y = np.array([lim_cuts.t.loc[loss1_i,'Pcontamination'], lim_cuts.t.loc[loss2_i,'Pcontamination']])
                loss_y = np.array([lim_cuts.t.loc[loss1_i,'Ploss'], lim_cuts.t.loc[loss2_i,'Ploss']])
                # contam_line is computed but not used in this branch
                contam_line = np.polyfit(x,contam_y,1)
                loss_line = np.polyfit(x,loss_y,1)
                # solve slope*x + intercept == loss_lim for x
                loss_lim_cut = (loss_lim-loss_line[1])/loss_line[0]

    # rows ordered by ascending contamination; reset_index relabels them 0..n-1 in that order
    sortby_contam = lim_cuts.t.iloc[(lim_cuts.t['Pcontamination']).argsort()].reset_index()
    min_contam = sortby_contam.loc[0,'Pcontamination']
    max_contam = sortby_contam.loc[len(sortby_contam)-1,'Pcontamination']
    # if all contam below lim, contam_lim_cut is max cut
    if min_contam < contam_lim and max_contam < contam_lim:
        contam_case = 'below lim'
        contam_lim_cut = lim_cuts.t.loc[len(lim_cuts.t)-1,'PSF Chi-Square Cut']
    else:
        # else if all contam above lim, contam_lim_cut is max cut with min% contam
        if min_contam > contam_lim and max_contam > contam_lim:
            contam_case = 'above lim'
            a = np.where(lim_cuts.t['Pcontamination'] == min_contam)[0]
            b = lim_cuts.t.iloc[a]
            c = b.iloc[(b['PSF Chi-Square Cut']).argsort()].reset_index()
            contam_lim_cut = c.loc[len(c)-1,'PSF Chi-Square Cut']
        # else if contam crosses lim at some point, contam_lim_cut is max cut with max% contam <= contam_lim
        else:
            contam_case = 'crosses lim'
            # largest contamination value that still satisfies contamination <= contam_lim
            valid_cuts = sortby_contam[sortby_contam['Pcontamination'] <= contam_lim]
            a = np.where(lim_cuts.t['Pcontamination'] == valid_cuts.loc[len(valid_cuts)-1,'Pcontamination'])[0]
            # sort by cuts
            b = lim_cuts.t.iloc[a]
            c = b.iloc[(b['PSF Chi-Square Cut']).argsort()].reset_index()
            # linearly interpolate between contam1 (largest cut with that contamination) and
            # contam2 (the following row) to find the cut where contamination equals contam_lim
            contam1_i = np.where(lim_cuts.t['PSF Chi-Square Cut'] == c.loc[len(c)-1,'PSF Chi-Square Cut'])[0][0]
            if lim_cuts.t.loc[contam1_i,'Pcontamination'] == contam_lim:
                contam_lim_cut = lim_cuts.t.loc[contam1_i,'PSF Chi-Square Cut']
            else:
                # NOTE(review): contam2_i runs past the table when contam1_i is the last row
                contam2_i = contam1_i + 1
                x = np.array([lim_cuts.t.loc[contam1_i,'PSF Chi-Square Cut'], lim_cuts.t.loc[contam2_i,'PSF Chi-Square Cut']])
                contam_y = np.array([lim_cuts.t.loc[contam1_i,'Pcontamination'], lim_cuts.t.loc[contam2_i,'Pcontamination']])
                loss_y = np.array([lim_cuts.t.loc[contam1_i,'Ploss'], lim_cuts.t.loc[contam2_i,'Ploss']])
                contam_line = np.polyfit(x,contam_y,1)
                # loss_line is computed but not used in this branch
                loss_line = np.polyfit(x,loss_y,1)
                # solve slope*x + intercept == contam_lim for x
                contam_lim_cut = (contam_lim-contam_line[1])/contam_line[0]

    # store good/bad/kept/cut statistics for both chosen cuts in lc_info
    get_lim_cuts_data(loss_lim_cut, 'loss_lim_cut', baseline_ix=lc_info['baseline_rev_ix'])
    get_lim_cuts_data(contam_lim_cut, 'contam_lim_cut', baseline_ix=lc_info['baseline_rev_ix'])

    return contam_lim_cut, loss_lim_cut, contam_case, loss_case

def get_lim_cuts_table(stn_cut, cut_start, cut_stop, cut_step, baseline_ix=None):
    """Tabulate loss/contamination statistics for a range of candidate chi-square cuts.

    Baseline measurements are split into "good" (uJy/duJy within [-stn_cut, stn_cut])
    and "bad" (the rest). For every candidate cut in range(cut_start, cut_stop+1,
    cut_step), baseline measurements with chi/N within the cut are "kept" and the others
    "cut". Candidate cuts that keep fewer than 10% of the baseline are skipped.

    Parameters
    ----------
    stn_cut : float
        abs(uJy/duJy) bound separating good from bad baseline measurements.
    cut_start, cut_stop, cut_step : int
        Candidate chi-square cuts, cut_start to cut_stop inclusive (fed to ``range``).
    baseline_ix : list-like of int, optional
        Baseline measurement indices; defaults to lc_info['baseline_rev_ix'].

    Returns
    -------
    pdastrostatsclass
        Its ``.t`` table has one row per evaluated cut; it is empty if every cut was skipped.

    Raises
    ------
    RuntimeError
        If there are no baseline measurements (every percentage is relative to them).
    """
    print('abs(uJy/duJy) cut at %0.2f \nx2 cut from %0.2f to %0.2f inclusive, with step size %d' % (stn_cut,cut_start,cut_stop,cut_step))

    if baseline_ix is None:
        baseline_ix = lc_info['baseline_rev_ix']
    N = len(baseline_ix)
    if N == 0:
        raise RuntimeError('No baseline measurements available--cannot build the chi-square cut table!')

    columns = ['PSF Chi-Square Cut', 'N', 'Ngood', 'Nbad', 'Nkept', 'Ncut', 'Ngood,kept', 'Ngood,cut', 'Nbad,kept', 'Nbad,cut',
               'Pgood,kept', 'Pgood,cut', 'Pbad,kept', 'Pbad,cut', 'Ngood,kept/Ngood', 'Ploss', 'Pcontamination',
               'Nbad,cut 3<stn<=5', 'Nbad,cut 5<stn<=10', 'Nbad,cut 10<stn', 'Nbad,kept 3<stn<=5', 'Nbad,kept 5<stn<=10', 'Nbad,kept 10<stn']
    lim_cuts = pdastrostatsclass(columns=columns)
    lc = lc_info['lc']

    # static cut at x2 = 50, printed for reference only
    x2cut_50 = np.where(lc.t['chi/N'] < 50)[0]
    print('Static chi square cut at 50: %0.2f%% cut for baseline' % (100*len(AnotB(baseline_ix,x2cut_50))/N))

    # good/bad baseline measurement indices (by abs(uJy/duJy) vs. stn_cut)
    b_good_i = lc.ix_inrange(colnames=['uJy/duJy'],lowlim=-stn_cut,uplim=stn_cut,indices=baseline_ix)
    b_bad_i = AnotB(baseline_ix, b_good_i)
    n_good = len(b_good_i)

    # flux/dflux bins used to break down the bad measurements; they do not depend on the
    # chi-square cut, so they are computed once. The lower edge of the first bin is +3
    # to match its column label '3<stn<=5'.
    stn_3_5_i = lc.ix_inrange(colnames=['uJy/duJy'],lowlim=3,uplim=5,exclude_lowlim=True)
    stn_5_10_i = lc.ix_inrange(colnames=['uJy/duJy'],lowlim=5,uplim=10,exclude_lowlim=True)
    stn_10_i = lc.ix_inrange(colnames=['uJy/duJy'],lowlim=10,exclude_lowlim=True)

    rows = []
    # candidate chi-square cuts, increasing from cut_start to cut_stop (inclusive)
    for cut in range(cut_start,cut_stop+1,cut_step):
        # kept/cut baseline measurement indices for this cut
        b_kept_i = lc.ix_inrange(colnames=['chi/N'],uplim=cut,indices=baseline_ix)
        b_cut_i = AnotB(baseline_ix, b_kept_i)
        n_kept = len(b_kept_i)

        if 100*(n_kept/N) < 10:
            # less than 10% of measurements kept, so no chi-square cuts beyond this point are valid
            print(f'# At cut {cut}, less than 10% of measurements are kept ({100*(n_kept/N):0.2f}% kept)--skipping...')
            continue

        b_good_kept_i = AandB(b_good_i,b_kept_i)
        b_good_cut_i = AandB(b_good_i,b_cut_i)
        b_bad_kept_i = AandB(b_bad_i,b_kept_i)
        b_bad_cut_i = AandB(b_bad_i,b_cut_i)

        rows.append([cut, N, # N
                     n_good, # Ngood
                     len(b_bad_i), # Nbad
                     n_kept, # Nkept
                     len(b_cut_i), # Ncut
                     len(b_good_kept_i), # Ngood,kept
                     len(b_good_cut_i), # Ngood,cut
                     len(b_bad_kept_i), # Nbad,kept
                     len(b_bad_cut_i), # Nbad,cut
                     100*len(b_good_kept_i)/N, # Ngood,kept/Nbaseline
                     100*len(b_good_cut_i)/N, # Ngood,cut/Nbaseline
                     100*len(b_bad_kept_i)/N, # Nbad,kept/Nbaseline
                     100*len(b_bad_cut_i)/N, # Nbad,cut/Nbaseline
                     100*len(b_good_kept_i)/n_good, # Ngood,kept/Ngood
                     100*len(b_good_cut_i)/n_good, # Ngood,cut/Ngood = Loss
                     100*len(b_bad_kept_i)/n_kept, # Nbad,kept/Nkept = Contamination
                     len(AandB(b_bad_cut_i, stn_3_5_i)), # Nbad,cut 3<stn<=5
                     len(AandB(b_bad_cut_i, stn_5_10_i)), # Nbad,cut 5<stn<=10
                     len(AandB(b_bad_cut_i, stn_10_i)), # Nbad,cut 10<stn
                     len(AandB(b_bad_kept_i, stn_3_5_i)), # Nbad,kept 3<stn<=5
                     len(AandB(b_bad_kept_i, stn_5_10_i)), # Nbad,kept 5<stn<=10
                     len(AandB(b_bad_kept_i, stn_10_i)), # Nbad,kept 10<stn
                     ])

    # build the table with a single concat instead of one concat per row
    if rows:
        lim_cuts.t = pd.concat([lim_cuts.t, pd.DataFrame(rows, columns=columns)], ignore_index=True)
    return lim_cuts

# Validate the limits, build the chi-square cut table and derive the suggested final cut.
if lim_to_prioritize not in ('loss_lim', 'contam_lim'):
    print("ERROR: lim_to_prioritize must be 'loss_lim' or 'contam_lim'!")
    sys.exit()
print('Contamination limit: %0.2f%%\nLoss limit: %0.2f%%' % (contam_lim,loss_lim))

lim_cuts = get_lim_cuts_table(stn_cut, cut_start, cut_stop, cut_step, baseline_ix=lc_info['baseline_rev_ix'])
if lim_cuts.t.empty:
    # plain string (no %-formatting applied), so a single '%' is printed as-is
    raise RuntimeError('No cuts kept more than 10% of measurements--chi-square cut not applicable for this SN!')

contam_lim_cut, loss_lim_cut, contam_case, loss_case = get_lim_cuts(lim_cuts)
lc_info['contam_case'] = contam_case
lc_info['loss_case'] = loss_case
lc_info['contam_lim_cut'] = contam_lim_cut
lc_info['loss_lim_cut'] = loss_lim_cut

print('\nContamination cut according to given contam_limit, with %0.2f%% contamination and %0.2f%% loss: %0.2f' % (lc_info['contam_lim_cut_Pcontamination'], lc_info['contam_lim_cut_Ploss'], contam_lim_cut))
if lc_info['contam_case'] == 'above lim':
    print('WARNING: Contamination cut not possible with contamination <= contam_lim %0.2f%%!' % contam_lim)
print('Loss cut according to given loss_limit, with %0.2f%% contamination and %0.2f%% loss: %0.2f' % (lc_info['loss_lim_cut_Pcontamination'], lc_info['loss_lim_cut_Ploss'], loss_lim_cut))
if lc_info['loss_case'] == 'above lim':
    print('WARNING: Loss cut not possible with loss <= loss_lim %0.2f%%!' % loss_lim)

final_cut = choose_btwn_lim_cuts(contam_lim_cut, loss_lim_cut, contam_case, loss_case)
lc_info['final_cut'] = final_cut

if np.isnan(final_cut):
    print('\nERROR: Final suggested chi-square cut could not be determined. We suggest rethinking your contamination and loss limits.')
    lc_info['Pcontamination'] = np.nan
    lc_info['Ploss'] = np.nan
else:
    # final_cut is one of the two limit cuts; report that cut's statistics
    if final_cut==contam_lim_cut:
        lc_info['Pcontamination'] = lc_info['contam_lim_cut_Pcontamination']
        lc_info['Ploss'] = lc_info['contam_lim_cut_Ploss']
    else:
        lc_info['Pcontamination'] = lc_info['loss_lim_cut_Pcontamination']
        lc_info['Ploss'] = lc_info['loss_lim_cut_Ploss']
    print('\nFinal suggested chi-square cut is %0.2f, with %0.2f%% contamination and %0.2f%% loss.' % (final_cut, lc_info['Pcontamination'], lc_info['Ploss']))
    if lc_info['Pcontamination'] > contam_lim:
        print("WARNING: Final cut's contamination %0.2f%% exceeds contam_lim %0.2f%%!" % (lc_info['Pcontamination'],contam_lim))
    if lc_info['Ploss'] > loss_lim:
        print("WARNING: Final cut's loss %0.2f%% exceeds loss_lim %0.2f%%!" % (lc_info['Ploss'],loss_lim))

plot_lim_cuts(lim_cuts, contam_lim_cut, loss_lim_cut)

### 5c) Confirm or override the final cut, apply final cut, and optionally plot

In [None]:
# ----- Chi-square cut: plotting options -----

# Set to True to plot the light curve before and after the chi-square cut is applied:
plot = False

# Axis limits for that plot (leave as None unless you want to set them by hand):
xlim_lower = None
xlim_upper = None
ylim_lower = None
ylim_upper = None

In [None]:
# Confirm (or manually override) the suggested chi-square cut, then flag every SN
# measurement whose chi/N lies above the final cut in the 'Mask' column.

if np.isnan(float(lc_info['final_cut'])):
    # step 5b could not suggest a cut; accepting NaN would not give a usable cut
    print('No final chi-square cut could be suggested; please enter a manual cut.')
    answer = 'n'
else:
    answer = input('Accept final chi-square cut of %0.2f (y/n):' % float(lc_info['final_cut'])).strip().lower()
if answer != 'y':
    final_cut_override = float(input('Overriding final chi-square cut; enter manual cut: '))

    # Recompute loss and contamination on the same baseline (lc_info['baseline_rev_ix'])
    # that step 5b used for the suggested cut, so both sets of numbers are comparable.
    b_good_i = lc_info['lc'].ix_inrange(colnames=['uJy/duJy'],lowlim=-stn_cut,uplim=stn_cut,indices=lc_info['baseline_rev_ix'])
    b_bad_i = AnotB(lc_info['baseline_rev_ix'], b_good_i)
    b_kept_i = lc_info['lc'].ix_inrange(colnames=['chi/N'],uplim=final_cut_override,indices=lc_info['baseline_rev_ix'])
    b_cut_i = AnotB(lc_info['baseline_rev_ix'], b_kept_i)
    # NaN (instead of a ZeroDivisionError) when there are no good or no kept baseline measurements
    lc_info['Ploss'] = 100*len(AandB(b_good_i,b_cut_i))/len(b_good_i) if len(b_good_i) > 0 else np.nan
    lc_info['Pcontamination'] = 100*len(AandB(b_bad_i,b_kept_i))/len(b_kept_i) if len(b_kept_i) > 0 else np.nan
    lc_info['final_cut'] = final_cut_override

    print('Overridden: final cut is now %0.2f, with contamination %0.2f%% and loss %0.2f%%' % (lc_info['final_cut'],lc_info['Pcontamination'],lc_info['Ploss']))

print(f'Applying chi-square cut of {lc_info["final_cut"]:0.2f}...')
kept_ix = lc_info['lc'].ix_inrange(colnames=['chi/N'],uplim=lc_info['final_cut'])
cut_ix = AnotB(lc_info['lc'].getindices(), kept_ix)
update_mask_col(lc_info['lc'].t, flag_chisquare, cut_ix)
print('Success')

num_cut = len(cut_ix)
percent_cut = 100 * num_cut/len(lc_info['lc'].t)
print_statistics(num_cut, percent_cut, flag_chisquare, add2title='chi-square cut')
f.write(f'\n\n## Chi-square cut\nNumber of cut measurements: {num_cut:d}\nPercent of total measurements cut: {percent_cut:0.2f}%\nHex value in "Mask" column: 0x1')

## Step 6: Control light curves:  3σ-clipped average cut

While the chi-square and uncertainty cuts are effective in cutting out a majority of the bad measurements, tricky cases may require a larger set of control light curves that can be used as a basis of comparison for inconsistent flux. In order to account for this inconsistent flux, we can obtain ~8 quality control forced photometry light curves in a 17" circle pattern around the SN location OR around a nearby bright object that may be poorly subtracting. Then, we use statistics from these control light curves to cut bad measurements from the SN light curve.

For a given epoch, we have 1 SN measurement for which we examine 8 control measurements within the same epoch. We know that if the control light curve measurements are NOT consistent with 0, this indicates something wrong with this epoch, so the SN measurement is unreliable. Therefore, we obtain statistics for the control light curves by calculating the 3σ-clipped average of the control flux. 

For the given epoch, we cut the SN measurement for which the returned control statistics fulfill any of the following criteria: 
- A returned chi-square > 2.5
- A returned abs(flux/dflux) > 3.0
- Number of clipped/"bad" measurements in the 3σ-clipped average > 2
- Number of used/"good" measurements in the 3σ-clipped average < 4

Measurements not fulfilling any of the criteria above but with Nclip > 0 are flagged as questionable.

In [None]:
# ----- Control light curve cut settings -----
# NOTE: x2_max, Nclip_max and Ngood_min are reassigned by the Step 7 averaging
# settings cell; re-run this cell before re-running the control light curve cut.

# Enter the bound for an epoch's maximum chi-square 
# (if x2 > x2_max, flag SN measurement):
x2_max = 2.5

# Enter the bound for an epoch's maximum abs(flux/dflux) ratio 
# (if abs(flux/dflux) > stn_max, flag SN measurement):
stn_max = 3.0

# Enter the bound for an epoch's maximum number of clipped control measurements
# (if Nclip > Nclip_max, flag SN measurement):
Nclip_max = 2

# Enter the bound for an epoch's minimum number of good control measurements
# (if Ngood < Ngood_min, flag SN measurement):
Ngood_min = 4

# Plot the light curve before and after the applied control light curve cut?
# (the plot is also drawn automatically when more than 10% of measurements are flagged as bad):
plot = False

# Optionally, manually enter the x and y limits for the control light curve cut plot:
xlim_lower = None
xlim_upper = None
ylim_lower = None
ylim_upper = None

In [None]:
# apply control light curve cut and save flags in 'Mask' column

def get_control_stats(Ncontrols):
    """Compute 3-sigma-clipped control flux statistics at every SN epoch.

    For each SN epoch, the uJy values of the Ncontrols control light curves
    (controls[1..Ncontrols]) at that epoch are averaged with a 3-sigma clipping loop.
    Control measurements carrying flag_chisquare or flag_uncertainty are masked out
    of the average. The results are written into the 'c2_*' columns of lc_info['lc'].

    Raises:
        RuntimeError: if a control light curve's length or MJDs differ from the SN's.
    """
    print('\nCalculating control light curve statistics...')

    sn_lc = lc_info['lc']
    n_epochs = len(sn_lc.t['MJD'])

    # rows: control light curves; columns: SN epochs
    flux_arr = np.full((Ncontrols, n_epochs), np.nan)
    dflux_arr = np.full((Ncontrols, n_epochs), np.nan)
    mask_arr = np.full((Ncontrols, n_epochs), 0, dtype=np.int32)

    for control_index in range(1, Ncontrols + 1):
        control_t = controls[control_index].t
        same_epochs = len(control_t) == n_epochs and np.array_equal(sn_lc.t['MJD'], control_t['MJD'])
        if not same_epochs:
            raise RuntimeError(f'ERROR: SN lc not equal to control lc for control_index {control_index}! Rerun or debug verify_mjds().')
        row = control_index - 1
        flux_arr[row, :] = control_t['uJy']
        dflux_arr[row, :] = control_t[dflux_colnames[control_index]]
        mask_arr[row, :] = control_t['Mask']

    c2_param2columnmapping = sn_lc.intializecols4statparams(prefix='c2_',format4outvals='{:.2f}',skipparams=['converged','i'])

    # only these cuts exclude a control measurement from the per-epoch average
    relevant_flags = flag_chisquare | flag_uncertainty
    for epoch in range(n_epochs):
        epoch_stats = pdastrostatsclass()
        epoch_stats.t['uJy'] = flux_arr[:, epoch]
        epoch_stats.t[dflux_colnames[0]] = dflux_arr[:, epoch]
        epoch_stats.t['Mask'] = np.bitwise_and(mask_arr[:, epoch], relevant_flags)

        epoch_stats.calcaverage_sigmacutloop('uJy', noisecol=dflux_colnames[0], maskcol='Mask', maskval=relevant_flags, verbose=1, Nsigma=3.0, median_firstiteration=True)
        sn_lc.statresults2table(epoch_stats.statparams, c2_param2columnmapping, destindex=epoch)

def controls_cut():
    """Flag SN measurements whose per-epoch control light curve statistics are bad.

    Uses the 'c2_*' columns filled by get_control_stats() and the module-level
    bounds x2_max, stn_max, Nclip_max and Ngood_min:
      - c2_X2norm > x2_max                -> flag_controls_x2
      - abs(c2_mean/c2_mean_err) > stn_max -> flag_controls_stn
      - c2_Nclip > Nclip_max              -> flag_controls_Nclip
      - c2_Ngood < Ngood_min              -> flag_controls_Ngood
    Measurements carrying any of those four flags also get flag_controls_bad. The
    others get flag_controls_questionable if their c2_Nclip is not 0. Finally, the
    SN's control flags (and flag_badday) are OR-ed into every control light curve's
    'Mask' column.
    """
    print('Flagging SN light curve based on control light curve statistics...')

    # abs(flux/dflux) of the epoch's clipped control average: control flux that is
    # significantly non-zero in EITHER direction marks the epoch as unreliable
    lc_info['lc'].t['c2_abs_stn'] = np.abs(lc_info['lc'].t['c2_mean']/lc_info['lc'].t['c2_mean_err'])

    # flag measurements according to given bounds
    flag_x2_i = lc_info['lc'].ix_inrange(colnames=['c2_X2norm'], lowlim=x2_max, exclude_lowlim=True)
    flag_stn_i = lc_info['lc'].ix_inrange(colnames=['c2_abs_stn'], lowlim=stn_max, exclude_lowlim=True)
    flag_nclip_i = lc_info['lc'].ix_inrange(colnames=['c2_Nclip'], lowlim=Nclip_max, exclude_lowlim=True)
    flag_ngood_i = lc_info['lc'].ix_inrange(colnames=['c2_Ngood'], uplim=Ngood_min, exclude_uplim=True)
    lc_info['lc'].t.loc[flag_x2_i,'Mask'] |= flag_controls_x2
    lc_info['lc'].t.loc[flag_stn_i,'Mask'] |= flag_controls_stn
    lc_info['lc'].t.loc[flag_nclip_i,'Mask'] |= flag_controls_Nclip
    lc_info['lc'].t.loc[flag_ngood_i,'Mask'] |= flag_controls_Ngood

    # any of the four flags -> bad; none of them but some clipped controls -> questionable
    zero_Nclip_i = lc_info['lc'].ix_equal('c2_Nclip', 0)
    unmasked_i = lc_info['lc'].ix_unmasked('Mask', maskval=flag_controls_x2|flag_controls_stn|flag_controls_Nclip|flag_controls_Ngood)
    lc_info['lc'].t.loc[AnotB(unmasked_i,zero_Nclip_i),'Mask'] |= flag_controls_questionable
    lc_info['lc'].t.loc[AnotB(lc_info['lc'].getindices(),unmasked_i),'Mask'] |= flag_controls_bad

    # copy over SN's control cut flags to control light curve 'Mask' column
    flags_arr = np.full(lc_info['lc'].t['Mask'].shape, (flag_badday|flag_controls_questionable|flag_controls_x2|flag_controls_stn|flag_controls_Nclip|flag_controls_Ngood))
    flags_to_copy = np.bitwise_and(lc_info['lc'].t['Mask'], flags_arr)
    for control_index in range(1,Ncontrols+1):
        controls[control_index].t['Mask'] = controls[control_index].t['Mask'].astype(np.int32)
        if len(controls[control_index].t) < 1:
            continue
        elif len(controls[control_index].t) == 1:
            controls[control_index].t.loc[0,'Mask']= int(controls[control_index].t.loc[0,'Mask']) | flags_to_copy
        else:
            controls[control_index].t['Mask'] = np.bitwise_or(controls[control_index].t['Mask'], flags_to_copy)

def print_flag_stats():
    """Report what fraction of the SN light curve the control light curve cut flagged.

    Prints the percentage of SN measurements carrying each control flag and appends
    the same numbers to the summary file handle `f`. The cut is plotted when `plot`
    is True or more than 10% of the measurements were flagged as bad.
    """
    sn_lc = lc_info['lc']
    n_meas = len(sn_lc.t)

    def pct(flag):
        # percentage of SN measurements that ix_masked reports for this flag
        return 100 * len(sn_lc.ix_masked('Mask', maskval=flag)) / n_meas

    percent_cut = pct(flag_controls_bad)
    percent_questionable = pct(flag_controls_questionable)

    x2_max_pcnt = pct(flag_controls_x2)
    stn_max_pcnt = pct(flag_controls_stn)
    Nclip_max_pcnt = pct(flag_controls_Nclip)
    Ngood_min_pcnt = pct(flag_controls_Ngood)

    print(f'\nLength of SN light curve: {n_meas:d}')
    report = (
        ('Percent of data above x2_max bound', x2_max_pcnt),
        ('Percent of data above stn_max bound', stn_max_pcnt),
        ('Percent of data above Nclip_max bound', Nclip_max_pcnt),
        ('Percent of data below Ngood_min bound', Ngood_min_pcnt),
        ('Total percent of data flagged as bad', percent_cut),
        ('Total percent of data flagged as questionable', percent_questionable),
    )
    for label, value in report:
        print(f'{label}: {value:0.2f}%')

    f.write(f'\n\n## Control light curve cut\nPercent of data above x2_max bound: {x2_max_pcnt:0.2f}%\nPercent of data above stn_max bound: {stn_max_pcnt:0.2f}%\nPercent of data above Nclip_max bound: {Nclip_max_pcnt:0.2f}%\nPercent of data below Ngood_min bound: {Ngood_min_pcnt:0.2f}%')
    f.write(f'\nTotal percent of data flagged as bad: {percent_cut:0.2f}%\nTotal percent of data flagged as questionable: {percent_questionable:0.2f}%\nHex value in "Mask" column (flagged as "bad"): 0x400000\nHex value in "Mask" column (flagged as "questionable"): 0x80000')

    if not (plot or percent_cut > 10):
        return
    limits = [xlim_lower, xlim_upper, ylim_lower, ylim_upper]
    plot_cut_lc('lc', flag_controls_bad, dflux_colnames[0], add2title='control light curve cut', limits=limits)

# Run the control light curve cut only when control light curves were loaded.
if not load_controls:
    print('Load_controls set to False! Skipping...')
else:
    get_control_stats(Ncontrols)
    controls_cut()
    print_flag_stats()

## Plot the ATLAS light curve with a combination of previous cuts

In [None]:
# Choose which flags mark a measurement as "cut" in the combined plot.
# Available flags:
#   - flag_chisquare     (chi-square cut)
#   - flag_uncertainty   (uncertainty cut)
#   - flag_controls_bad  (control light curve cut)
flags = flag_chisquare | flag_uncertainty | flag_controls_bad

# Axis limits for the plot (leave as None unless you want to set them by hand):
xlim_lower = None
xlim_upper = None
ylim_lower = None
ylim_upper = None

In [None]:
# Plot the SN light curve, separating measurements flagged by the selected cuts.
limits = [xlim_lower, xlim_upper,
          ylim_lower, ylim_upper]
plot_cut_lc('lc',
            flags,
            dflux_colnames[0],
            add2title='\nuncertainty, chi-square, and control light curve cuts',
            limits=limits)

## Step 7: Averaging and cutting bad bins

Our goal is to identify and cut out bad MJD bins by taking a 3σ-clipped average of each bin. For each bin, we calculate the 3σ-clipped average of any SN measurements falling within that bin and use that average as our flux for that bin. Because the ATLAS survey takes about 4 exposures every 2 days, we usually average together approximately 4 measurements per epoch. However, only measurements not flagged by the uncertainty or chi-square cuts are included in the 3σ-clipped average; the control light curve cut flags do not exclude measurements here. The exception is a bin in which every measurement was flagged by those cuts: its measurements are averaged anyway, and the bin is flagged as bad.

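Then, for each bin with at least 3 usable measurements, we cut all SN measurements in that bin if the bin's statistics fulfill any of the following criteria: 
- A returned chi-square > 4.0
- Number of measurements averaged < 2
- Number of measurements clipped > 1

Bins with fewer than 3 usable measurements are not tested against these criteria. Instead, their measurements are flagged as belonging to a small-number bin.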

Known limitation: this step still needs improvement near the SN peak, where important epochs are sometimes cut (possibly because of the SN's fast rise).

In [None]:
# ----- Averaging settings -----

# Width of each MJD bin, in days:
mjdbinsize = 1

# True: MJD bins without any measurements appear in the averaged light curve as NaN rows;
# False: such bins are dropped from the averaged light curve.
keep_empty_bins = False

# Averaged fluxes are converted to magnitudes; a magnitude is an upper limit when its
# dmagnitude is NaN. Sigma level used for those upper limits:
flux2mag_sigma_limit = 3

# A bin is flagged as a bad day if it has more clipped measurements than this
# (Nclip > Nclip_max):
Nclip_max = 1

# ... or fewer good measurements than this (Ngood < Ngood_min):
Ngood_min = 2

# ... or a chi-square larger than this (x2 > x2_max):
x2_max = 4.0

# Axis limits for the averaged light curve plot (leave as None unless setting them by hand):
xlim_lower = None
xlim_upper = None
ylim_lower = None
ylim_upper = None

In [None]:
# average the light curve

def average_lc(Nclip_max, Ngood_min, x2_max, flag_badday, flag_ixclip, flag_smallnum, mjdbinsize=1, flux2mag_sigma_limit=3.0, keep_empty_bins=True):
    """Average lc_info['lc'] into MJD bins, filling lc_info['avglc'] and flagging bad bins.

    Bins are [mjd, mjd+mjdbinsize), starting at int(min MJD) and continuing while
    the bin start is <= int(max MJD)+1. In each bin, only measurements not flagged
    with flag_chisquare or flag_uncertainty are averaged (3-sigma clipping loop,
    median on the first iteration). If none remain, all measurements in the bin are
    averaged anyway and the bin is flagged with flag_badday.

    Args:
        Nclip_max: a bin whose average clipped more than this many measurements is
            flagged with flag_badday.
        Ngood_min: a bin whose average has fewer than this many 'Ngood' measurements
            is flagged with flag_badday.
        x2_max: a bin whose 'X2norm' exceeds this is flagged with flag_badday.
        flag_badday: flag for bad bins, set on the bin's measurements in
            lc_info['lc'] and on the bin's row in lc_info['avglc'].
        flag_ixclip: flag set in lc_info['lc'] on measurements clipped by the average.
        flag_smallnum: flag for bins with fewer than 3 usable measurements.
        mjdbinsize: bin width in days.
        flux2mag_sigma_limit: passed to flux2mag() as upperlim_Nsigma.
        keep_empty_bins: if True, bins without measurements are added as rows
            flagged with flag_badday; otherwise they are skipped.

    Notes:
        The Nclip_max/Ngood_min/x2_max criteria are only checked for bins with at
        least 3 usable measurements. Smaller bins get flag_smallnum instead.
        lc_info['avglc'] must already exist; rows are appended with newrow().
        After binning, 'uJy'/'duJy' are converted to 'm'/'dm' (zero point 23.9),
        extra columns are dropped, and the count/Mask columns are cast to int32.
    """
    # integer start of the first bin and last bin start considered
    mjd = int(np.amin(lc_info['lc'].t['MJD']))
    mjd_max = int(np.amax(lc_info['lc'].t['MJD']))+1

    # measurements not flagged by the uncertainty or chi-square cuts
    good_i = lc_info['lc'].ix_unmasked('Mask', maskval=flag_chisquare|flag_uncertainty)

    while mjd <= mjd_max:
        # all measurements in [mjd, mjd+mjdbinsize), and the usable subset
        range_i = lc_info['lc'].ix_inrange(colnames=['MJD'], lowlim=mjd, uplim=mjd+mjdbinsize, exclude_uplim=True)
        range_good_i = AandB(range_i,good_i)

        # add new row to avglc if keep_empty_bins or any measurements present
        if keep_empty_bins or len(range_i) >= 1:
            new_row = {'MJDbin':mjd+0.5*mjdbinsize, 'Nclip':0, 'Ngood':0, 'Nexcluded':len(range_i)-len(range_good_i), 'Mask':0}
            avglc_index = lc_info['avglc'].newrow(new_row)
        
        # if no measurements present, flag or skip over day
        if len(range_i) < 1:
            if keep_empty_bins:
                update_mask_col(lc_info['avglc'].t, flag_badday, [avglc_index])
            mjd += mjdbinsize
            continue
        
        # if no good measurements, average values anyway and flag
        if len(range_good_i) < 1:
            # average flux
            lc_info['lc'].calcaverage_sigmacutloop('uJy', noisecol=dflux_colnames[0], indices=range_i, Nsigma=3.0, median_firstiteration=True)
            # copy the flux results: statparams is read again after the next
            # calcaverage_sigmacutloop() call, which presumably overwrites it
            fluxstatparams = deepcopy(lc_info['lc'].statparams)

            # average mjd (Nsigma=0) over the measurements kept by the flux average
            # TODO(review): should noisecol here be duJy or None?
            lc_info['lc'].calcaverage_sigmacutloop('MJD', noisecol=dflux_colnames[0], indices=fluxstatparams['ix_good'], Nsigma=0, median_firstiteration=False)
            avg_mjd = lc_info['lc'].statparams['mean']

            # add row and flag
            lc_info['avglc'].add2row(avglc_index, {'MJD':avg_mjd, 
                                                   'uJy':fluxstatparams['mean'], 
                                                   'duJy':fluxstatparams['mean_err'], 
                                                   'stdev':fluxstatparams['stdev'],
                                                   'x2':fluxstatparams['X2norm'],
                                                   'Nclip':fluxstatparams['Nclip'],
                                                   'Ngood':fluxstatparams['Ngood'],
                                                   'Mask':0})
            update_mask_col(lc_info['lc'].t, flag_badday, range_i)
            update_mask_col(lc_info['avglc'].t, flag_badday, [avglc_index])

            mjd += mjdbinsize
            continue
        
        # average good measurements
        lc_info['lc'].calcaverage_sigmacutloop('uJy', noisecol=dflux_colnames[0], indices=range_good_i, Nsigma=3.0, median_firstiteration=True)
        fluxstatparams = deepcopy(lc_info['lc'].statparams)

        # the flux average failed (no mean or nothing kept): flag the bin and move on
        if fluxstatparams['mean'] is None or len(fluxstatparams['ix_good']) < 1:
            update_mask_col(lc_info['lc'].t, flag_badday, range_i)
            update_mask_col(lc_info['avglc'].t, flag_badday, [avglc_index])
            mjd += mjdbinsize
            continue

        # average mjd (Nsigma=0) over the measurements kept by the flux average
        # TODO(review): should noisecol here be duJy or None?
        lc_info['lc'].calcaverage_sigmacutloop('MJD', noisecol=dflux_colnames[0], indices=fluxstatparams['ix_good'], Nsigma=0, median_firstiteration=False)
        avg_mjd = lc_info['lc'].statparams['mean']

        # add row
        lc_info['avglc'].add2row(avglc_index, {'MJD':avg_mjd, 
                                                   'uJy':fluxstatparams['mean'], 
                                                   'duJy':fluxstatparams['mean_err'], 
                                                   'stdev':fluxstatparams['stdev'],
                                                   'x2':fluxstatparams['X2norm'],
                                                   'Nclip':fluxstatparams['Nclip'],
                                                   'Ngood':fluxstatparams['Ngood'],
                                                   'Mask':0})
        
        # flag clipped measurements in lc
        if len(fluxstatparams['ix_clip']) > 0:
            update_mask_col(lc_info['lc'].t, flag_ixclip, fluxstatparams['ix_clip'])
        
        # if small number within this bin, flag measurements
        # (such bins are not tested against Ngood_min/Nclip_max/x2_max)
        if len(range_good_i) < 3:
            update_mask_col(lc_info['lc'].t, flag_smallnum, range_good_i) # TODO(review): change to range_i?
            update_mask_col(lc_info['avglc'].t, flag_smallnum, [avglc_index])
        # else check sigmacut bounds and flag
        else:
            is_bad = False
            if fluxstatparams['Ngood'] < Ngood_min:
                is_bad = True
            if fluxstatparams['Nclip'] > Nclip_max:
                is_bad = True
            if not(fluxstatparams['X2norm'] is None) and fluxstatparams['X2norm'] > x2_max:
                is_bad = True
            if is_bad:
                update_mask_col(lc_info['lc'].t, flag_badday, range_i)
                update_mask_col(lc_info['avglc'].t, flag_badday, [avglc_index])

        mjd += mjdbinsize

    # convert flux to magnitude and dflux to dmagnitude
    for col in ['uJy','duJy']: 
        lc_info['avglc'].t[col] = lc_info['avglc'].t[col].astype(float)
    lc_info['avglc'].flux2mag('uJy','duJy','m','dm', zpt=23.9, upperlim_Nsigma=flux2mag_sigma_limit)

    print('Success')

    drop_extra_columns('avglc')

    # counts and mask were filled row by row; store them as int32
    for col in ['Nclip','Ngood','Nexcluded','Mask']: 
        lc_info['avglc'].t[col] = lc_info['avglc'].t[col].astype(np.int32)

# Average the cleaned SN light curve into MJD bins, then report and plot the bad bins.
if len(lc_info['lc'].t) < 1:
    print('ERROR: No data in lc so cannot average; exiting... ')
    sys.exit()

print('Averaging light curve with the following criteria: Nclip_max = %d, Ngood_min = %d, x2_max = %0.2f... ' % (Nclip_max, Ngood_min, x2_max))
lc_info['avglc'] = pdastrostatsclass(columns=['MJD','MJDbin','uJy','duJy','stdev','x2','Nclip','Ngood','Nexcluded','Mask'],hexcols=['Mask'])
# forward every setting from the averaging settings cell, including the bin size
# and the sigma level used for magnitude upper limits
average_lc(Nclip_max, Ngood_min, x2_max, flag_badday, flag_ixclip, flag_smallnum,
           mjdbinsize=mjdbinsize, flux2mag_sigma_limit=flux2mag_sigma_limit,
           keep_empty_bins=keep_empty_bins)

# print statistics and plot (counts are rows of the averaged light curve, i.e. MJD bins)
num_cut = len(lc_info['avglc'].ix_masked('Mask', flag_badday))
percent_cut = (num_cut/len(lc_info['avglc'].t)) * 100
print(f'\nNumber of cut measurements: {num_cut:d}\nPercent of total measurements cut: {percent_cut:0.2f}%')
f.write(f'\n\n## Averaging and cutting bad bins\nNumber of cut measurements: {num_cut}\nPercent of total measurements cut: {percent_cut:0.2f}%\nHex value in "Mask" column: 0x800000')
if percent_cut > 10:
    print('WARNING: percent of total measurements cut is greater than 10%')
limits = [xlim_lower, xlim_upper, ylim_lower, ylim_upper]
plot_cut_lc('avglc', flag_chisquare|flag_uncertainty|flag_badday, 'duJy', limits=limits)

## Optional: save the SN light curve with the new `'Mask'` and `'duJy_new'` columns

Hex values in the `'Mask'` column for each cut's flag:
- Uncertainty cut: 0x2
- Chi-square cut: 0x1
- Control light curve cut (flagged as "bad"): 0x400000
- Control light curve cut (flagged as "questionable"): 0x80000
- Bad day (for averaged light curves): 0x800000

You can combine these hex values together to create certain combinations of cuts that define a "bad" measurement.

In [None]:
# ----- Saving options -----

# Set to True to save the SN light curve including its updated 'Mask' column:
save = False

# Set to True to write over the original light curve file (`filename`):
overwrite_old_lc = False

# Output path used when overwrite_old_lc is False:
filename_new = ''

In [None]:
# Save the SN light curve (with its updated 'Mask' column) if requested.

if save:
    print('Saving light curve with updated mask column...')
    if overwrite_old_lc:
        print('Overwriting old light curve file at %s... ' % filename)
        save_lc('lc',filename,overwrite=True)
    elif not filename_new:
        # an empty filename_new is not a usable output path; report instead of calling save_lc
        print('ERROR: overwrite_old_lc is False but filename_new is empty; did not save SN light curve')
    else:
        print('Writing new file at %s... ' % filename_new)
        save_lc('lc',filename_new,overwrite=True)
else:
    print('Did not save SN light curve')

## Summary

In [None]:
# print summary of all cuts and corrections

# Close the summary writer `f` used by the earlier cells so its contents are flushed,
# then print <tnsname>_output.md (presumably the file `f` was writing to -- confirm).
f.close()
with open(f'{lc_info["tnsname"]}_output.md') as f1:
    content = f1.read()
print()
print(content)