# Cleaning and Averaging ATLAS Light Curves
### Includes chi-square cut, uncertainty cut, control light curve cut, and averaging

This IPython notebook will help you apply each cut with a greater degree of control than automatic cleaning allows. You will be walked through the following:
1. Static uncertainty cut
2. Estimating true uncertainties
3. Dynamic chi-square cut
4. Control light curve cut
5. Averaging the light curve and cutting bad bins
6. Optionally correct for ATLAS template changes
7. Save files

After running a cell, the descriptions located above that cell will help you interpret the plots and make decisions about the supernova.

In order for this notebook to work correctly, the ATLAS light curves must already be downloaded and saved. Each light curve must also only include measurements for a single filter.

## Step 1: Load the ATLAS light curves 

In [None]:
##### LOADING THE SN LIGHT CURVE #####

# Enter the target SN name:
tnsname = '2018gkr'

# Enter the path to the data directory that contains the SN directory:
source_dir = '/Users/sofiarest/Desktop/Supernovae/data/temp'

# Enter the path to a directory to optionally save any plots:
output_dir = f'{source_dir}/{tnsname}/plots'

# Enter the filter for this light curve (must be 'o' or 'c'):
filt = 'o'

# Optionally, enter the SN's discovery date (if None is entered, it will be 
# fetched automatically from TNS using the API key, TNS ID, and bot name):
discovery_date = 58207.146991
api_key = None
tns_id = None
bot_name = None

##### LOADING CONTROL LIGHT CURVES #####

# Set to True if you are planning on applying the control light curve cut 
# and have already downloaded the control light curves:
load_controls = True

# Enter the number of control light curves to load:
n_controls = 8

# Enter the source directory of the control light curve files:
controls_dir = f'{source_dir}/{tnsname}/controls'

##### DEFAULT SETTINGS FOR PLOTTING #####

# If True, try to calculate the best x and y limits automatically for each plot;
# if False, leave the limits to matplotlib unless they are entered manually
auto_xylimits = True

In [None]:
# import modules, set preliminary variables, etc.

from pdastro import pdastrostatsclass, AandB, AnotB, AorB, not_AandB
from atlas_lc import atlas_lc

import sys, os
import numpy as np
from copy import deepcopy
import pandas as pd

# plotting
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import warnings
warnings.simplefilter('error', RuntimeWarning)
warnings.filterwarnings("ignore")
# plotting styles
plt.rc('axes', titlesize = 17)
plt.rc('xtick', labelsize = 12)
plt.rc('ytick', labelsize = 12)
plt.rc('legend', fontsize = 10)
plt.rcParams['font.size'] = 12
plt.rcParams['font.family'] = 'Times New Roman'
plt.rcParams['axes.prop_cycle'] = matplotlib.cycler(color=['red', 'orange', 'green', 'blue', 'purple', 'magenta'])
matplotlib.rcParams['xtick.major.size'] = 6
matplotlib.rcParams['xtick.major.width'] = 1
matplotlib.rcParams['xtick.minor.size'] = 3
matplotlib.rcParams['xtick.minor.width'] = 1
matplotlib.rcParams['ytick.major.size'] = 6
matplotlib.rcParams['ytick.major.width'] = 1
matplotlib.rcParams['ytick.minor.size'] = 3
matplotlib.rcParams['ytick.minor.width'] = 1
matplotlib.rcParams['axes.linewidth'] = 1
marker_size = 30
marker_edgewidth = 1.5
sn_flux = 'orange' if filt == 'o' else 'cyan'
sn_flagged_flux = 'red' 
ctrl_flux = 'steelblue' #'darkgreen' #'limegreen' 
#select_ctrl_flux = 'darkgreen'

# ATLAS template change dates (MJD)
#tchange0 = 57500
tchange1 = 58417
tchange2 = 58882

# 'Mask' column flags
flags = {'chisquare':0x1, 
		 
		 'uncertainty':0x2,
		 
		 'controls_bad':0x400000,
		 'controls_questionable':0x80000,
		 'controls_x2':0x100,
		 'controls_stn':0x200,
		 'controls_Nclip':0x400,
		 'controls_Ngood':0x800,
		 
		 'avg_badday':0x800000,
		 'avg_ixclip':0x1000,
		 'avg_smallnum':0x2000}

In [None]:
# get discovery date if needed, load in light curve, account for template changes

if filt != 'o' and filt != 'c': 
	raise RuntimeError('Filter must be "o" or "c"!')

# new text file that will contain record of each cut, etc.
f = open(f'{source_dir}/{tnsname}/{tnsname}_output.md', 'w')
f.write(f'# SN {tnsname} Light Curve Cleaning and Averaging\n\nFilter: {filt}-band\nDiscovery date: {discovery_date}\nNumber of control light curves: {n_controls}')

# SN and control light curves
lc = atlas_lc(tnsname, discdate=discovery_date)
if lc.discdate is None:
	lc._get_tns_data(tnsname, api_key, tns_id, bot_name)
lc._load(source_dir, filt, n_controls)
lc.prep_for_cleaning()
print(lc)

## Plot SN and control light curves

In [None]:
# Plot control light curves underneath SN light curve?
plot_controls = True

# Optionally, manually enter the x and y limits:
xlim_lower = None
xlim_upper = None
ylim_lower = None
ylim_upper = None

In [None]:
# plot SN and control light curves

def do_manual_xylimits(limits):
    # True if the user entered at least one limit manually
    return any(limit is not None for limit in limits)

def set_xylimits(lc, limits, control_index=0, indices=None):
    if auto_xylimits:
        if limits[0] is None:
            limits[0] = lc.lcs[control_index].t['MJD'].min() * 0.999
        if limits[1] is None:
            limits[1] = lc.lcs[control_index].t['MJD'].max() * 1.001
        
        if indices is None:
            indices = lc.get_ix()
        # exclude measurements with duJy > 160
        good_ix = lc.lcs[control_index].ix_inrange(colnames='duJy', uplim=160, indices=indices)
        # get 5% of abs(max flux - min flux)
        flux_min = lc.lcs[control_index].t.loc[good_ix, 'uJy'].min()
        flux_max = lc.lcs[control_index].t.loc[good_ix, 'uJy'].max()
        diff = abs(flux_max - flux_min)
        offset = 0.05 * diff

        if limits[2] is None: 
            limits[2] = flux_min - offset
        if limits[3] is None:
            limits[3] = flux_max + offset

        return limits
    
    if do_manual_xylimits(limits):
        return limits
    
    return None

def save_plot(save_filename=None):
    if not save_filename is None:
        filename = f'{output_dir}/{save_filename}.png'
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        print(f'Saving plot: {filename}')
        plt.savefig(filename, dpi=200)

def plot_all_lcs(lc, add2title=None, plot_controls=False, plot_templates=False, limits=None, save_filename=None):
    fig, ax1 = plt.subplots(1, constrained_layout=True)
    fig.set_figwidth(7)
    fig.set_figheight(4)

    title = f'SN {tnsname} & control light curves {filt}-band flux'
    if not(add2title is None):
        title += add2title
    ax1.set_title(title)

    ax1.minorticks_on()
    ax1.tick_params(direction='in', which='both')
    ax1.set_ylabel(r'Flux ($\mu$Jy)')
    ax1.set_xlabel('MJD')
    ax1.axhline(linewidth=1, color='k')

    # set x and y limits
    limits = set_xylimits(lc, limits)
    if not limits is None:
        ax1.set_xlim(limits[0],limits[1])
        ax1.set_ylim(limits[2],limits[3])

    if plot_templates:
        ax1.axvline(x=tchange1, color='k', linestyle='dotted', label='ATLAS template change', zorder=100)
        ax1.axvline(x=tchange2, color='k', linestyle='dotted', zorder=100)

    preSN_ix = lc.get_pre_SN_ix()
    postSN_ix = lc.get_post_SN_ix()

    if load_controls and plot_controls:
        for control_index in range(1, lc.num_controls+1):
            plt.errorbar(lc.lcs[control_index].t['MJD'], lc.lcs[control_index].t['uJy'], yerr=lc.lcs[control_index].t[lc.dflux_colnames[control_index]], fmt='none', ecolor=ctrl_flux, elinewidth=1.5, capsize=1.2, c=ctrl_flux, alpha=0.5, zorder=0)
            if control_index == 1:
                plt.scatter(lc.lcs[control_index].t['MJD'], lc.lcs[control_index].t['uJy'], s=marker_size, color=ctrl_flux, marker='o', alpha=0.5, zorder=0, label=f'{lc.num_controls} control light curves')
            else:
                plt.scatter(lc.lcs[control_index].t['MJD'], lc.lcs[control_index].t['uJy'], s=marker_size, color=ctrl_flux, marker='o', alpha=0.5, zorder=0)


    plt.errorbar(lc.lcs[0].t.loc[preSN_ix,'MJD'], lc.lcs[0].t.loc[preSN_ix,'uJy'], yerr=lc.lcs[0].t.loc[preSN_ix,lc.dflux_colnames[0]], fmt='none', ecolor=sn_flux, elinewidth=1, capsize=1.2, c=sn_flux, alpha=0.5, zorder=10)
    plt.scatter(lc.lcs[0].t.loc[preSN_ix,'MJD'], lc.lcs[0].t.loc[preSN_ix,'uJy'], s=marker_size, lw=marker_edgewidth, color=sn_flux, marker='o', alpha=0.5, zorder=10, label='Pre-SN light curve')

    plt.errorbar(lc.lcs[0].t.loc[postSN_ix,'MJD'], lc.lcs[0].t.loc[postSN_ix,'uJy'], yerr=lc.lcs[0].t.loc[postSN_ix,lc.dflux_colnames[0]], fmt='none', ecolor='red', elinewidth=1, capsize=1.2, c=sn_flux, alpha=0.5, zorder=10)
    plt.scatter(lc.lcs[0].t.loc[postSN_ix,'MJD'], lc.lcs[0].t.loc[postSN_ix,'uJy'], s=marker_size, lw=marker_edgewidth, color='red', marker='o', alpha=0.5, zorder=10, label='Post-SN light curve')

    ax1.legend(loc='upper right', facecolor='white', framealpha=1.0).set_zorder(100)

    save_plot(save_filename=save_filename)

limits = [xlim_lower, xlim_upper, ylim_lower, ylim_upper]
plot_all_lcs(lc, plot_controls=plot_controls, plot_templates=True, limits=limits)

## Step 2: Static uncertainty cut

The following uncertainty cut implements a static cut that applies the same way to each light curve. The purpose of this cut is to identify and clean out the most egregious outliers with large uncertainties and small chi-square values that would not be cut out in the dynamic chi-square cut. The default value of this cut (160) was determined after calculating the typical uncertainty of bright stars just below the saturation limit. 

WARNING: **if the SN is particularly bright, you may want to increase the value and rerun the cut**.
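
As a minimal, standalone sketch of what this cut does (not the `atlas_lc` implementation used in the cells below): rows whose `duJy` exceeds the cut get the uncertainty bit `0x2` OR-ed into their `Mask` value, so flags from other cuts are preserved. The toy table and the helper name `flag_uncertainty` are illustrative assumptions.

```python
import pandas as pd

UNCERTAINTY_FLAG = 0x2  # same bit this notebook assigns to the uncertainty cut

def flag_uncertainty(table, cut=160, flag=UNCERTAINTY_FLAG):
    """Return a copy of `table` with `flag` set in 'Mask' wherever duJy > cut."""
    out = table.copy()
    too_uncertain = out['duJy'] > cut
    # bitwise OR keeps any flags that were already set on the row
    out.loc[too_uncertain, 'Mask'] = out.loc[too_uncertain, 'Mask'] | flag
    return out

# hypothetical measurements; the second row already carries the chi-square flag 0x1
toy = pd.DataFrame({
    'MJD':  [58000.1, 58001.1, 58002.1, 58003.1],
    'uJy':  [12.0, -850.0, 30.0, 5.0],
    'duJy': [20.0, 400.0, 160.0, 161.0],
    'Mask': [0, 0x1, 0, 0],
})

flagged = flag_uncertainty(toy)
cut_rows = (flagged['Mask'] & UNCERTAINTY_FLAG) > 0
percent_cut = 100 * cut_rows.sum() / len(flagged)
print(flagged[['duJy', 'Mask']])
print(f'Percent of total measurements cut: {percent_cut:0.2f}%')
```

In this sketch a measurement with `duJy` exactly equal to the cut is kept; only strictly larger uncertainties are flagged.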

In [None]:
# You may change the following static uncertainty cut value to your liking;
# however, the default value is set to 160.
uncertainty_cut = 160

# Plot the light curve before and after the applied uncertainty cut?:
plot = True

# Optionally, manually enter the x and y limits for the uncertainty cut plot:
xlim_lower = None
xlim_upper = None
ylim_lower = None
ylim_upper = None

In [None]:
# update 'Mask' column with uncertainty cut
print(f'Applying uncertainty cut of {uncertainty_cut:0.2f}...')

def plot_cut_lc(lc, title, flag, limits=None, save_filename=None):
    fig, (ax2, ax1) = plt.subplots(2, constrained_layout=True)
    fig.set_figwidth(7)
    fig.set_figheight(5)

    fig.suptitle(f'{title} (flag {hex(flag)})')

    ax1.minorticks_on()
    ax1.tick_params(direction='in', which='both')
    ax2.get_xaxis().set_ticks([])
    ax1.set_ylabel(r'Flux ($\mu$Jy)')
    ax1.axhline(linewidth=1, color='k')

    ax2.minorticks_on()
    ax2.tick_params(direction='in', which='both')
    ax2.set_ylabel(r'Flux ($\mu$Jy)')
    ax1.set_xlabel('MJD')
    ax2.axhline(linewidth=1, color='k')

    # set x and y limits
    limits = set_xylimits(lc, limits)
    if not limits is None:
        ax1.set_xlim(limits[0],limits[1])
        ax1.set_ylim(limits[2],limits[3])
        ax2.set_xlim(limits[0],limits[1])
        ax2.set_ylim(limits[2],limits[3])

    good_ix = lc.lcs[0].ix_unmasked('Mask', maskval=flag)
    bad_ix = lc.lcs[0].ix_masked('Mask', maskval=flag)

    ax1.errorbar(lc.lcs[0].t.loc[good_ix,'MJD'], lc.lcs[0].t.loc[good_ix,'uJy'], yerr=lc.lcs[0].t.loc[good_ix,lc.dflux_colnames[0]], fmt='none', ecolor=sn_flux, elinewidth=1, capsize=1.2, c=sn_flux, alpha=0.5)
    ax1.scatter(lc.lcs[0].t.loc[good_ix,'MJD'], lc.lcs[0].t.loc[good_ix,'uJy'], s=marker_size, lw=marker_edgewidth, color=sn_flux, marker='o', alpha=0.5, label='Kept measurements')

    ax2.errorbar(lc.lcs[0].t.loc[good_ix,'MJD'], lc.lcs[0].t.loc[good_ix,'uJy'], yerr=lc.lcs[0].t.loc[good_ix,lc.dflux_colnames[0]], fmt='none', ecolor=sn_flux, elinewidth=1, capsize=1.2, c=sn_flux, alpha=0.5, zorder=5)
    ax2.scatter(lc.lcs[0].t.loc[good_ix,'MJD'], lc.lcs[0].t.loc[good_ix,'uJy'], s=marker_size, lw=marker_edgewidth, color=sn_flux, marker='o', alpha=0.5, label='Kept measurements', zorder=5)

    ax2.errorbar(lc.lcs[0].t.loc[bad_ix,'MJD'], lc.lcs[0].t.loc[bad_ix,'uJy'], yerr=lc.lcs[0].t.loc[bad_ix,lc.dflux_colnames[0]], fmt='none', ecolor=sn_flagged_flux, elinewidth=1, capsize=1.2, c=sn_flagged_flux, alpha=0.5, zorder=10)
    ax2.scatter(lc.lcs[0].t.loc[bad_ix,'MJD'], lc.lcs[0].t.loc[bad_ix,'uJy'], s=marker_size, lw=marker_edgewidth, color=sn_flagged_flux, facecolors='none', edgecolors=sn_flagged_flux, marker='o', alpha=0.5, label='Cut measurements', zorder=10)

    ax1.legend(loc='upper right', facecolor='white', framealpha=1.0).set_zorder(100)
    ax2.legend(loc='upper right', facecolor='white', framealpha=1.0).set_zorder(100)

    save_plot(save_filename=save_filename)

def print_statistics(num_cut, percent_cut):
    print(f'\nNumber of cut measurements: {num_cut:d}\nPercent of total measurements cut: {percent_cut:0.2f}%')
    if percent_cut > 10:
        print(f'WARNING: percent of total measurements cut is greater than 10%. Plotting...')

def apply_uncertainty_cut(lc, uncertainty_cut, control_index=0):
    ix = lc.get_ix(control_index=control_index)
    kept_ix = lc.lcs[control_index].ix_inrange(colnames=['duJy'],uplim=uncertainty_cut)
    cut_ix = AnotB(ix, kept_ix)
    lc.update_mask_col(flags['uncertainty'], cut_ix, control_index=control_index)

    if control_index == 0:
        num_cut = len(cut_ix)
        percent_cut = 100 * num_cut/len(ix)
        print_statistics(num_cut, percent_cut)
        f.write(f'\n\n## Uncertainty cut\nNumber of cut measurements: {num_cut:d}\nPercent of total measurements cut: {percent_cut:0.2f}%\nHex value in "Mask" column: 0x2')
        
        if plot or percent_cut > 10:
            limits = [xlim_lower, xlim_upper, ylim_lower, ylim_upper]
            plot_cut_lc(lc, 'Uncertainty cut', flags['uncertainty'], limits=limits, save_filename='uncertainty_cut')

for control_index in range(lc.num_controls+1):
    apply_uncertainty_cut(lc, uncertainty_cut, control_index=control_index)
print('Success')

## Step 3: Estimating true uncertainties

This section attempts to account for an extra noise source in the data by estimating the true typical uncertainty, deriving the additional systematic uncertainty, and lastly applying this extra noise to a new uncertainty column. This new uncertainty column will be used in the cuts following this section.

Here is the procedure we use:
1. Keep the previously applied uncertainty cut and apply a preliminary chi-square cut at 20 (default value). Filter out any measurements flagged by these two cuts.
2. Calculate the true typical uncertainties $\text{sigma\_true\_typical}$ for each control light curve by taking a 3σ cut of the unflagged flux and getting the standard deviation.
3. ~~If $\text{sigma\_true\_typical}$ is 10%+ greater than the median uncertainty of the unflagged control light curve flux, $\text{median}(∂µJy)$, proceed with estimating the extra noise to add. Otherwise, skip this procedure.~~
4. Calculate the extra noise source for each control light curve using the following formula, where $\text{sigma\_poisson}$ is estimated by the median uncertainty, $\text{median}(∂µJy)$, of the unflagged baseline flux:
    - $\text{sigma\_extra}^2 = \text{sigma\_true\_typical}^2 - \text{sigma\_poisson}^2$
    - $\text{sigma\_extra} = \sqrt{\text{sigma\_true\_typical}^2 - \text{median}(∂µJy)^2}$
5. Calculate the final extra noise source by taking the median of all $\text{sigma\_extra}$.
6. Apply the extra noise source to the existing uncertainty using the following formula:
    - $\text{new }∂µJy = \sqrt{(\text{old }∂µJy)^2 + \text{sigma\_extra}^2}$
7. For cuts following this procedure, use the new uncertainty column with the extra noise added instead of the old uncertainty column.
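
The arithmetic of this procedure can be sketched with plain NumPy. This is a hedged illustration only: the (standard deviation, median uncertainty) pairs are made-up numbers standing in for the sigma-clipped statistics of three control light curves, and `sigma_extra_for_control` is a hypothetical helper name.

```python
import numpy as np

def sigma_extra_for_control(stdev, median_dflux):
    """Extra noise in quadrature; 0 if the scatter is already explained."""
    return np.sqrt(max(0.0, stdev**2 - median_dflux**2))

# hypothetical (sigma_true_typical, median uncertainty) pairs for three controls
controls = [(25.0, 20.0), (13.0, 12.0), (18.0, 19.0)]
per_control = [sigma_extra_for_control(s, m) for s, m in controls]

# the final extra noise is the median over all control light curves
final_sigma_extra = float(np.median(per_control))

# add the extra noise in quadrature to the existing uncertainties
old_dflux = np.array([12.0, 20.0])
new_dflux = np.sqrt(old_dflux**2 + final_sigma_extra**2)

print(per_control, final_sigma_extra, new_dflux)
```

The third control has less scatter than its median uncertainty, so its contribution is clamped to 0 rather than producing the square root of a negative number.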

In [None]:
# Enter a preliminary chi-square cut (keep at a high number):
prelim_x2_cut = 20

# Plot the light curve before and after estimating true uncertainties?:
plot = False

# Optionally, manually enter the x and y limits for the plot:
xlim_lower = None
xlim_upper = None
ylim_lower = None
ylim_upper = None

In [None]:
def get_median_dflux(lc, indices=None, control_index=0):
    if indices is None:
        indices = lc.get_ix(control_index=control_index)
    return np.median(lc.lcs[control_index].t.loc[indices, 'duJy'])

def get_stdev(lc, indices=None, control_index=0):
    lc.lcs[control_index].calcaverage_sigmacutloop('uJy', indices=indices, Nsigma=3.0, median_firstiteration=True)
    return lc.lcs[control_index].statparams['stdev']

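def get_sigma_extra(median_dflux, stdev):
    # clamp the difference at 0 before the square root so that a stdev smaller
    # than the median uncertainty yields 0 instead of NaN
    return np.sqrt(max(0, stdev**2 - median_dflux**2))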

def get_stats(lc):
    stats = pd.DataFrame(columns=['control_index', 'median_dflux', 'stdev', 'sigma_extra'])
    stats['control_index'] = list(range(1, lc.num_controls+1))
    stats.set_index('control_index', inplace=True)

    for control_index in range(1, lc.num_controls+1):
        dflux_clean_ix = lc.lcs[control_index].ix_unmasked('Mask', maskval=flags['uncertainty'])
        x2_clean_ix = lc.lcs[control_index].ix_inrange(colnames=['chi/N'], uplim=prelim_x2_cut, exclude_uplim=True)
        clean_ix = AandB(dflux_clean_ix, x2_clean_ix)

        median_dflux = get_median_dflux(lc, indices=clean_ix, control_index=control_index)
        stdev = get_stdev(lc, indices=clean_ix, control_index=control_index)
        sigma_extra = get_sigma_extra(median_dflux, stdev)

        stats.loc[control_index, 'median_dflux'] = median_dflux
        stats.loc[control_index, 'stdev'] = stdev
        stats.loc[control_index, 'sigma_extra'] = sigma_extra

    print(stats)
    return stats

def get_final_sigma_extra(stats):
    return np.median(stats['sigma_extra'])

def hist_stats(stats):
    fig, ax1 = plt.subplots(1, constrained_layout=True)
    fig.set_figwidth(4)
    fig.set_figheight(2)

    max_y = max(max(stats['median_dflux']), max(stats['stdev']), max(stats['sigma_extra']))
    bins = np.linspace(0, max_y+3, 10)
    median_sigma_extra = np.median(stats['sigma_extra'])
    #labels = ['median dflux', 'std dev', 'sigma extra']
    #ax1.hist([stats['median_dflux'], stats['stdev'], stats['sigma_extra']], bins=bins, label=labels)
   
    ax1.hist(stats['median_dflux'], bins=bins, alpha=0.5, label='median dflux')
    
    ax1.hist(stats['stdev'], bins=bins, alpha=0.5, label='std dev')

    ax1.hist(stats['sigma_extra'], bins=bins, alpha=0.5, label='sigma extra')
    ax1.axvline(median_sigma_extra, color='green', label='median sigma extra')

    ax1.legend(facecolor='white', framealpha=1, loc='upper left',  bbox_to_anchor=(1, 1))

def add_noise(lc, sigma_extra):
    lc.dflux_colnames = ['duJy_new'] * (lc.num_controls+1)
    for control_index in range(0, lc.num_controls+1):
        lc.lcs[control_index].t['duJy_new'] = np.sqrt(lc.lcs[control_index].t['duJy']*lc.lcs[control_index].t['duJy'] + sigma_extra**2)
        lc.recalculate_fdf(control_index=control_index)
    return lc

def get_results(lc, stats, final_sigma_extra, limits=None):
    print(f'Final sigma extra: {final_sigma_extra:0.2f}')

    sigma_typical_old = np.median(stats['median_dflux'])
    sigma_typical_new = np.sqrt(final_sigma_extra**2 + sigma_typical_old**2)
    percent_greater = 100 * ((sigma_typical_new - sigma_typical_old)/sigma_typical_old)
    
    print(f'Adding an additional systematic uncertainty of {final_sigma_extra:0.2f} in quadrature would increase the typical uncertainties from {sigma_typical_old:0.2f} to {sigma_typical_new:0.2f}.')
    print(f'New typical uncertainty is {percent_greater:0.2f}% greater than old typical uncertainty.')
    if percent_greater >= 10:
        answer = input('True uncertainties estimation recommended. Proceed? (y/n)')
        if answer == 'y':
            print('Calculating new uncertainties in \'duJy_new\' column for each light curve...') 
            lc = add_noise(lc, sigma_extra=final_sigma_extra)
            print('Success')
            print('Quick sanity check:')
            print(lc.lcs[0].t[['MJD', 'uJy', 'duJy', 'duJy_new']].head())
            if plot:
                plot_true_uncertainties(lc, limits=limits, save_filename='true_uncerts')
        else:
            print('Skipping procedure.')
    else:
        print(f'True uncertainties estimation not needed.')
    return lc

def plot_true_uncertainties(lc, limits=None, save_filename=None):
    fig, (ax1, ax2) = plt.subplots(2, constrained_layout=True)
    fig.set_figwidth(7)
    fig.set_figheight(5)

    ax1.set_title(f'SN {tnsname} {filt}-band flux\nbefore true uncertainties estimation')
    ax1.minorticks_on()
    ax1.tick_params(direction='in', which='both')
    ax1.get_xaxis().set_ticks([])
    ax1.set_ylabel(r'Flux ($\mu$Jy)')
    ax1.axhline(linewidth=1, color='k')

    ax2.set_title(f'after true uncertainties estimation')
    ax2.minorticks_on()
    ax2.tick_params(direction='in', which='both')
    ax2.set_ylabel(r'Flux ($\mu$Jy)')
    ax2.set_xlabel('MJD')
    ax2.axhline(linewidth=1, color='k')

    # set x and y limits
    limits = set_xylimits(lc, limits)
    if not limits is None:
        ax1.set_xlim(limits[0],limits[1])
        ax1.set_ylim(limits[2],limits[3])
        ax2.set_xlim(limits[0],limits[1])
        ax2.set_ylim(limits[2],limits[3])

    ax1.errorbar(lc.lcs[0].t['MJD'], lc.lcs[0].t['uJy'], yerr=lc.lcs[0].t['duJy'], fmt='none', ecolor=sn_flux, elinewidth=1, capsize=1.2, c=sn_flux, alpha=0.5)
    ax1.scatter(lc.lcs[0].t['MJD'], lc.lcs[0].t['uJy'], s=marker_size, lw=marker_edgewidth, color=sn_flux, marker='o', alpha=0.5)

    ax2.errorbar(lc.lcs[0].t['MJD'], lc.lcs[0].t['uJy'], yerr=lc.lcs[0].t['duJy_new'], fmt='none', ecolor=sn_flux, elinewidth=1, capsize=1.2, c=sn_flux, alpha=0.5)
    ax2.scatter(lc.lcs[0].t['MJD'], lc.lcs[0].t['uJy'], s=marker_size, lw=marker_edgewidth, color=sn_flux, marker='o', alpha=0.5)

    save_plot(save_filename=save_filename)

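stats = get_stats(lc)
hist_stats(stats)
final_sigma_extra = get_final_sigma_extra(stats)
# rebuild the limits from this step's settings instead of reusing the list from an earlier cell
limits = [xlim_lower, xlim_upper, ylim_lower, ylim_upper]
lc = get_results(lc, stats, final_sigma_extra, limits=limits)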

## Step 4: Dynamic chi-square cut

### 4a) Plot the flux/dflux and chi-square distributions

The following two histograms display the flux/dflux and chi-square distributions of the SN and control light curves. Both histograms show probability density so as to ease comparison between the groups plotted within each histogram.

- The first histogram focuses on the baseline flux/dflux (µJy/dµJy) measurements, where we can expect the flux to equal 0. In green, we plot flux/dflux (µJy/dµJy) measurements with a chi-square value less than or equal to `x2bound`, which is currently set to 5 below; in red, we plot flux/dflux (µJy/dµJy) measurements with a chi-square value greater than `x2bound`. 
- The second histogram focuses on the baseline chi-square measurements. In green, we plot chi-square measurements with an abs(µJy/dµJy) value less than or equal to `stnbound`, which is currently set to 3 below; in red, we plot chi-square measurements with an abs(µJy/dµJy) value greater than `stnbound`. 

Ideally, all measurements with a chi-square value less than or equal to `x2bound` should have an abs(µJy/dµJy) value less than or equal to `stnbound`, and measurements with a chi-square value greater than `x2bound` should have an abs(µJy/dµJy) value greater than `stnbound`. Our goal is to separate good measurements from bad measurements using a chi-square cut; for the cut to be effective, these histograms should show this relation between the target SN's flux/dflux and chi-square measurements.

In [None]:
# Enter the bound that should separate a good chi-square measurement from a bad one:
x2bound = 5.0

# Enter the bound that should separate a good abs(flux/dflux) measurement from a bad one:
stnbound = 3.0

# Optionally, set the density (if True, each bin will display the bin's raw count divided by total count and bin width)
# flux/dflux histogram density:
fdf_density = True
# chi-square histogram density:
#x2_density = True

# Optionally, manually enter the histograms' x limits here:
# flux/dflux histogram x limits:
fdf_xlims = (-30, 30)
# chi-square histogram x limits:
#x2_xlims = (-10, 50)

In [None]:
def fdf_hist(lc, all_controls, limits=None, density=True): 
    preSN_ix = lc.get_pre_SN_ix()
    preSN_good_ix = lc.lcs[0].ix_inrange(colnames=['chi/N'], uplim=x2bound, indices=preSN_ix)
    preSN_bad_ix = AnotB(preSN_ix, preSN_good_ix)

    controls_ix = all_controls.t.index.values
    controls_good_ix = all_controls.ix_inrange(colnames=['chi/N'], uplim=x2bound)
    controls_bad_ix = AnotB(controls_ix, controls_good_ix)

    fig, (ax1, ax2) = plt.subplots(2,1, constrained_layout=True)
    fig.set_figwidth(3)
    fig.set_figheight(3)

    ax1.set_title('for pre-SN light curve', fontsize=12)
    ax2.set_title('for control light curves', fontsize=12)
    ax2.set_xlabel('µJy/dµJy')
    
    bins = None
    if not limits is None:
        bins = np.linspace(limits[0]-10, limits[1]+10, 20)
        ax1.get_xaxis().set_ticks([])
        ax1.set_xlim(limits[0], limits[1])
        ax2.set_xlim(limits[0], limits[1])

    ax1.hist(lc.lcs[0].t.loc[preSN_good_ix,'uJy/duJy'], bins=bins, color='green', alpha=0.5, label=f'Data with chi-square≤{x2bound}', density=density)
    ax1.hist(lc.lcs[0].t.loc[preSN_bad_ix,'uJy/duJy'], bins=bins, color='red', alpha=0.5, label=f'Data with chi-square>{x2bound}', density=density)
    
    ax2.hist(all_controls.t.loc[controls_good_ix,'uJy/duJy'], bins=bins, color='green', alpha=0.5, density=density)
    ax2.hist(all_controls.t.loc[controls_bad_ix,'uJy/duJy'], bins=bins, color='red', alpha=0.5, density=density)

    fig.legend(facecolor='white', framealpha=1, loc='upper left',  bbox_to_anchor=(1, 1))

def x2_hist(lc, all_controls, limits=None, density=True):
    preSN_ix = lc.get_pre_SN_ix()
    preSN_good_ix = lc.lcs[0].ix_inrange(colnames=['uJy/duJy'], lowlim=-stnbound, uplim=stnbound, indices=preSN_ix)
    preSN_bad_ix = AnotB(preSN_ix, preSN_good_ix)

    controls_ix = all_controls.t.index.values
    controls_good_ix = all_controls.ix_inrange(colnames=['uJy/duJy'], lowlim=-stnbound, uplim=stnbound)
    controls_bad_ix = AnotB(controls_ix, controls_good_ix)

    fig, (ax1, ax2) = plt.subplots(2,1, constrained_layout=True)
    fig.set_figwidth(3)
    fig.set_figheight(3)

    ax1.set_title('for pre-SN light curve', fontsize=12)
    ax2.set_title('for control light curves', fontsize=12)
    ax2.set_xlabel('chi-square')
    
    bins = None
    if not limits is None:
        bins = np.linspace(limits[0]-10, limits[1]+10, 20)
        ax1.get_xaxis().set_ticks([])
        ax1.set_xlim(limits[0], limits[1])
        ax2.set_xlim(limits[0], limits[1])

    ax1.hist(lc.lcs[0].t.loc[preSN_good_ix,'chi/N'], bins=bins, color='green', alpha=0.5, label=f'Data with abs(µJy/dµJy)≤{stnbound}', density=density)
    ax1.hist(lc.lcs[0].t.loc[preSN_bad_ix,'chi/N'], bins=bins, color='red', alpha=0.5, label=f'Data with abs(µJy/dµJy)>{stnbound}', density=density)
    
    ax2.hist(all_controls.t.loc[controls_good_ix,'chi/N'], bins=bins, color='green', alpha=0.5, density=density)
    ax2.hist(all_controls.t.loc[controls_bad_ix,'chi/N'], bins=bins, color='red', alpha=0.5, density=density)

    fig.legend(facecolor='white', framealpha=1, loc='upper left',  bbox_to_anchor=(1, 1))

controls = [lc.lcs[control_index].t for control_index in lc.lcs if control_index > 0]
all_controls = pdastrostatsclass()
all_controls.t = pd.concat(controls, ignore_index=True)

fdf_hist(lc, all_controls, limits=fdf_xlims, density=fdf_density)
#x2_hist(lc, all_controls, limits=x2_xlims, density=x2_density)

### 4b) Calculate best chi-square cut based on contamination and loss

The following cells use two factors, <strong>contamination</strong> and <strong>loss</strong>, to attempt to calculate an optimal PSF chi-square cut for the target SN, with flux/dflux as the deciding factor of what constitutes a good measurement vs. a bad measurement. We aim to separate good measurements from bad using the calculated chi-square cut by removing as much contamination as possible with the smallest loss possible. Since we can assume that the expected value of the baseline flux is 0, we look only at the baseline measurements before the SN occurs in order to determine the best chi-square cut for the SN itself.

First, we decide what will determine a good measurement vs. a bad measurement using a factor outside of the chi-square values. Our chosen factor is the <strong>absolute value of flux (µJy) divided by dflux (dµJy)</strong>. The recommended boundary is a value of 3, such that any measurements with a value of abs(µJy/dµJy) less than or equal to 3 are regarded as "good" measurements, and any measurements with a value of abs(µJy/dµJy) greater than 3 are regarded as "bad" measurements. You can set this boundary to a different number by changing the value of `stn_cut` below.

Next, we set the upper and lower bounds of our final chi-square cut. We start at a low value of 3 (which can be changed by setting the value of `cut_start` below) and end at 50 (this value is inclusive and can be changed by setting the value of `cut_stop` below) with a step size of 1 (`cut_step` below). <strong>For chi-square cuts falling on or between `cut_start` and `cut_stop` in increments of `cut_step`, we can begin to calculate contamination and loss percentages.</strong>

We define contamination to be the number of bad kept measurements over the total number of kept measurements for that chi-square cut (<strong>contamination = Nbad,kept/Nkept</strong>). For our final chi-square cut, we can also set a limit on what maximum percent contamination we want to have--the recommended value is <strong>15%</strong> but can be changed by setting the value of `contam_lim` below.

We define loss to be the number of good cut measurements over the total number of good measurements for that chi-square cut (<strong>loss = Ngood,cut/Ngood</strong>). For our final chi-square cut, we can also set a limit on what maximum percent loss we want to have--the recommended value is <strong>10%</strong> but can be changed by setting the value of `loss_lim` below.

Finally, we define which limit (`contam_lim` or `loss_lim`) to prioritize in the event that an optimal chi-square cut fitting both limits is not found. The default prioritized limit is `loss_lim` but can be changed by setting the value of `lim_to_prioritize` below.
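
To make the two definitions concrete, here is a small standalone sketch that computes contamination and loss for a few candidate cuts. It assumes that a measurement is kept when its chi-square is less than or equal to the cut; the arrays and the function name `contamination_and_loss` are illustrative and not part of this notebook's API.

```python
import numpy as np

def contamination_and_loss(chi2, stn, cut, stn_cut=3.0):
    """Percent contamination and percent loss for one chi-square cut.

    good = abs(flux/dflux) <= stn_cut; kept = chi-square <= cut (assumed).
    """
    chi2 = np.asarray(chi2, dtype=float)
    good = np.abs(np.asarray(stn, dtype=float)) <= stn_cut
    kept = chi2 <= cut

    n_kept = kept.sum()
    n_good = good.sum()
    # contamination = Nbad,kept / Nkept
    contamination = 100 * (kept & ~good).sum() / n_kept if n_kept else np.nan
    # loss = Ngood,cut / Ngood
    loss = 100 * (~kept & good).sum() / n_good if n_good else np.nan
    return contamination, loss

# hypothetical baseline measurements
chi2 = [1, 2, 4, 6, 8, 30]
stn = [0.5, -1.0, 5.0, 2.0, -4.0, 10.0]

for cut in (3, 5, 7):
    contamination, loss = contamination_and_loss(chi2, stn, cut)
    print(f'cut={cut}: contamination={contamination:0.2f}%, loss={loss:0.2f}%')
```

Raising the cut keeps more measurements, so loss can only fall or stay the same, while contamination depends on how many bad measurements are admitted along the way.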

In [None]:
# Enter the abs(uJy/duJy) boundary that will determine a "good" measurement vs. "bad" measurement:
stn_cut = 3

# Enter the bounds for the final chi-square cut (minimum cut, maximum cut, and step):
cut_start = 3 # this is inclusive
cut_stop = 50 # this is inclusive
cut_step = 1

# Enter the contamination limit (contamination = Nbad,kept/Nkept must be <= contam_lim% 
# for the final chi-square cut):
contam_lim = 15.0

# Enter the loss limit (loss = Ngood,cut/Ngood must be <= loss_lim%
# for the final chi-square cut):
loss_lim = 10.0

# Enter the limit to prioritize (must be 'loss_lim' or 'contam_lim') in the event that
# one or both limits are not met:
lim_to_prioritize = 'loss_lim'

# If set to True, we use the pre-SN light curve to find the best cut.
# Else, we use the control light curves.
use_preSN_lc = False

The following section describes in detail how we determine the final chi-square cut using the given contamination and loss limits (feel free to skip).

For each given limit (contamination and loss), we calculate a range of valid cuts whose contamination/loss percentage is less than or equal to that limit and then choose a single cut within that valid range. Then, we pass the two suggested cuts through a decision tree that selects the final cut based on several factors (including the user's selected `lim_to_prioritize`).

When choosing the loss cut according to the loss percentage limit `loss_lim`:
- <strong>If all loss percentages are below the limit</strong> `loss_lim`, all cuts falling on or between `cut_start` and `cut_stop` are valid.
- <strong>If all loss percentages are above the limit</strong> `loss_lim`, a cut with the required loss percentage is not possible; therefore, any cuts with the smallest percentage of loss are valid.
- <strong>Otherwise</strong>, the valid range of cuts includes any cuts with the loss percentage less than or equal to the limit `loss_lim`.
- The chosen cut for this limit is the <strong>minimum cut</strong> within the stated valid range of cuts.

When choosing the contamination cut according to the contamination percentage limit `contam_lim`:
- <strong>If all contamination percentages are below the limit</strong> `contam_lim`, all cuts falling on or between `cut_start` and `cut_stop` are valid.
- <strong>If all contamination percentages are above the limit</strong> `contam_lim`, a cut with the required contamination percentage is not possible; therefore, any cuts with the smallest percentage of contamination are valid.
- <strong>Otherwise</strong>, the valid range of cuts includes any cuts with the contamination percentage less than or equal to the limit `contam_lim`.
- The chosen cut for this limit is the <strong>maximum cut</strong> within the stated valid range of cuts.
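
The two selection rules above can be sketched in plain Python. This is a simplified illustration rather than the notebook's implementation: the helper `pick_cut` and the toy percentages are hypothetical, and unlike `get_lim_cuts` it only returns cuts on the tested grid instead of interpolating between neighboring cuts.

```python
def pick_cut(cuts, percents, limit, choose):
    """Return (case, cut) for one limit.

    cuts     -- candidate chi-square cuts, in ascending order
    percents -- loss or contamination percentage at each cut
    limit    -- loss_lim or contam_lim, in percent
    choose   -- min for the loss cut, max for the contamination cut
    """
    if all(p < limit for p in percents):
        case, valid = 'below lim', list(cuts)
    elif all(p > limit for p in percents):
        # The limit cannot be met: fall back to the cuts with the smallest percentage.
        best = min(percents)
        case, valid = 'above lim', [c for c, p in zip(cuts, percents) if p == best]
    else:
        case, valid = 'crosses lim', [c for c, p in zip(cuts, percents) if p <= limit]
    return case, choose(valid)

# Toy data: loss falls and contamination rises as the cut is loosened.
cuts = [3, 4, 5, 6, 7]
loss = [22.0, 14.0, 9.0, 6.0, 4.0]
contam = [5.0, 8.0, 12.0, 17.0, 21.0]

loss_case, loss_cut = pick_cut(cuts, loss, limit=10.0, choose=min)
contam_case, contam_cut = pick_cut(cuts, contam, limit=15.0, choose=max)
print(loss_case, loss_cut)      # crosses lim 5
print(contam_case, contam_cut)  # crosses lim 5
```

Taking the minimum valid cut for loss and the maximum valid cut for contamination marks the two ends of the interval in which both limits could be satisfied at once.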

After we have calculated two suggested cuts based on the loss and contamination percentage limits, we follow the decision tree in order to suggest a final cut:
- If both the loss cut and the contamination cut were chosen from a valid range spanning all of `cut_start` to `cut_stop` (that is, both limits are met at every cut), we set the final cut to `cut_start`.
- If only one of the two cuts was chosen from a valid range spanning all of `cut_start` to `cut_stop`, we set the final cut to the other cut.
- If neither limit can be met at any cut (all loss and all contamination percentages fall above their respective limits), no final cut is set, and we suggest relaxing one or both limits.
- Otherwise, we take into account the user's prioritized limit `lim_to_prioritize`:
    - If the loss cut is greater than the contamination cut, we set the final cut to whichever cut is associated with `lim_to_prioritize`.
    - Otherwise, if `lim_to_prioritize` is set to `contam_lim`, we set the final cut to the loss cut, and if `lim_to_prioritize` is set to `loss_lim`, we set the final cut to the contamination cut.
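
The decision tree above can be condensed into a small function. This is a minimal sketch that mirrors the listed rules without any printing. The name `suggest_final_cut` and its keyword defaults are hypothetical, and the case labels are assumed to be the strings `'below lim'`, `'above lim'`, and `'crosses lim'`.

```python
import math

def suggest_final_cut(loss_cut, contam_cut, loss_case, contam_case,
                      cut_start=3, lim_to_prioritize='loss_lim'):
    """Pick a final chi-square cut from the two suggested cuts."""
    cases = {loss_case, contam_case}
    if cases == {'below lim'}:
        # Both limits are met at every cut: use the strictest cut tested.
        return cut_start
    if 'below lim' in cases:
        # One limit is met at every cut, so the other limit decides.
        return contam_cut if loss_case == 'below lim' else loss_cut
    if cases == {'above lim'}:
        # Neither limit can be met: no cut is suggested.
        return math.nan
    if loss_cut > contam_cut:
        # The valid ranges do not overlap: honor the prioritized limit.
        return loss_cut if lim_to_prioritize == 'loss_lim' else contam_cut
    # The valid ranges overlap on [loss_cut, contam_cut].
    return contam_cut if lim_to_prioritize == 'loss_lim' else loss_cut

print(suggest_final_cut(8, 5, 'crosses lim', 'crosses lim'))  # 8
print(suggest_final_cut(4, 9, 'crosses lim', 'crosses lim'))  # 9
```

In the overlapping case, prioritizing `loss_lim` selects the loosest cut in the valid interval (fewest good measurements lost), while prioritizing `contam_lim` selects the strictest one.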

In [None]:
def plot_lim_cuts(lim_cuts, contam_lim_cut, loss_lim_cut):
    loss_color = 'darkmagenta'
    contam_color = 'teal'

    fig, ax1 = plt.subplots(1, constrained_layout=True)
    fig.set_figwidth(5.5)
    fig.set_figheight(3)

    ax1.set_title(f'SN {tnsname} {filt}-band chi-square cut')

    #ax1.set_xlim(0, max())

    ax1.minorticks_on()
    ax1.tick_params(direction='in', which='both')
    if use_preSN_lc:
        ax1.set_ylabel(f'% pre-SN light curve measurements')
    else:
        ax1.set_ylabel(f'% control light curve measurements')
    ax1.set_xlabel('Chi-square cut')
    ax1.axhline(linewidth=1, color='k')

    ax1.axhline(loss_lim, linewidth=1, color=loss_color, linestyle='dotted')#, label='Loss limit')
    ax1.text(3.5, loss_lim+0.1, 'Loss limit', color=loss_color)
    ax1.plot(lim_cuts['PSF Chi-Square Cut'], lim_cuts['Ploss'], ms=3.5, color=loss_color, marker='o', label='Loss')
    ax1.axvline(x=loss_lim_cut, color=loss_color, linestyle='--', label='Loss cut')
    ax1.axvspan(loss_lim_cut, cut_stop, alpha=0.2, color=loss_color)

    ax1.axhline(contam_lim, linewidth=1, color=contam_color, linestyle='dotted')#, label='Contamination limit')
    ax1.text(3.5, contam_lim+0.1, 'Contamination limit', color=contam_color)
    ax1.plot(lim_cuts['PSF Chi-Square Cut'], lim_cuts['Pcontamination'], ms=3.5, color=contam_color, marker='o', label='Contamination')
    ax1.axvline(x=contam_lim_cut, color=contam_color, linestyle='--', label='Contamination cut')
    ax1.axvspan(cut_start, contam_lim_cut, alpha=0.2, color=contam_color)

    if ax1.get_ylim()[1] < loss_lim+2 or ax1.get_ylim()[1] < contam_lim+2:
        ax1.set_ylim(0, max(loss_lim+2, contam_lim+2))

    ax1.legend(facecolor='white', framealpha=1, bbox_to_anchor=(1.02, 1), loc='upper left')

def choose_btwn_lim_cuts(contam_lim_cut, loss_lim_cut, contam_case, loss_case):
    # case 1 and 1: final_cut = cut_start
    # case 1 and 2: take limit of case 2
    # case 1 and 3: take limit of case 3
    # case 2 and 2: print lims don't work
    # case 2 and 3: choose_btwn_lim_cuts
    # case 3 and 3: choose_btwn_lim_cuts

    case1 = loss_case == 'below lim' or contam_case == 'below lim'
    case2 = loss_case == 'above lim' or contam_case == 'above lim'
    case3 = loss_case == 'crosses lim' or contam_case == 'crosses lim'

    final_cut = None
    if case1 and not case2 and not case3: # 1 and 1
        print('Valid chi-square cut range from %0.2f to %0.2f! Setting to %0.2f...' % (loss_lim_cut, contam_lim_cut, cut_start))
        final_cut = cut_start
    elif case1: # 1
        if case2: # and 2
            if loss_case == 'above lim':
                print('WARNING: contam_lim_cut <= %0.2f falls below limit %0.2f%%, but loss_lim_cut >= %0.2f falls above limit %0.2f%%! Setting to %0.2f...' % (contam_lim_cut, contam_lim, loss_lim_cut, loss_lim, loss_lim_cut))
                final_cut = loss_lim_cut
            else:
                print('WARNING: loss_lim_cut <= %0.2f falls below limit %0.2f%%, but contam_lim_cut >= %0.2f falls above limit %0.2f%%! Setting to %0.2f...' % (loss_lim_cut, loss_lim, contam_lim_cut, contam_lim, contam_lim_cut))
                final_cut = contam_lim_cut
        else: # and 3
            if loss_case == 'crosses lim':
                print('Contam_lim_cut <= %0.2f falls below limit %0.2f%% and loss_lim_cut >= %0.2f crosses limit %0.2f%%, setting to %0.2f...' % (contam_lim_cut, contam_lim, loss_lim_cut, loss_lim, loss_lim_cut))
                final_cut = loss_lim_cut
            else:
                print('Loss_lim_cut <= %0.2f falls below limit %0.2f%% and contam_lim_cut >= %0.2f crosses limit %0.2f%%, setting to %0.2f...' % (loss_lim_cut, loss_lim, contam_lim_cut, contam_lim, contam_lim_cut))
                final_cut = contam_lim_cut
    elif case2 and not case3: # 2 and 2
        print('ERROR: chi-square loss_lim_cut >= %0.2f and contam_lim_cut <= %0.2f both fall above limits %0.2f%% and %0.2f%%! Try setting less strict limits. Setting final cut to nan.' % (loss_lim_cut, contam_lim_cut, loss_lim, contam_lim))
        final_cut = np.nan
    else: # 2 and 3 or 3 and 3
        if loss_lim_cut > contam_lim_cut:
            print('WARNING: chi-square loss_lim_cut >= %0.2f and contam_lim_cut <= %0.2f do not overlap! ' % (loss_lim_cut, contam_lim_cut))
            if lim_to_prioritize == 'contam_lim':
                print('Prioritizing %s and setting to %0.2f...' % (lim_to_prioritize, contam_lim_cut))
                final_cut = contam_lim_cut
            else:
                print('Prioritizing %s and setting to %0.2f... ' % (lim_to_prioritize, loss_lim_cut))
                final_cut = loss_lim_cut
        else:
            print('Valid chi-square cut range from %0.2f to %0.2f! ' % (loss_lim_cut, contam_lim_cut))
            if lim_to_prioritize == 'contam_lim':
                print('Prioritizing %s and setting to %0.2f... ' % (lim_to_prioritize, loss_lim_cut))
                final_cut = loss_lim_cut
            else:
                print('Prioritizing %s and setting to %0.2f... ' % (lim_to_prioritize, contam_lim_cut))
                final_cut = contam_lim_cut
    return final_cut

def get_lim_cuts(lim_cuts): 
    contam_lim_cut = None
    loss_lim_cut = None
    contam_case = None
    loss_case = None

    sortby_loss = lim_cuts.iloc[(lim_cuts['Ploss']).argsort()].reset_index()
    min_loss = sortby_loss.loc[0,'Ploss']
    max_loss = sortby_loss.loc[len(sortby_loss)-1,'Ploss']
    # if all loss below lim, loss_lim_cut is min cut
    if min_loss < loss_lim and max_loss < loss_lim:
        loss_case = 'below lim'
        loss_lim_cut = lim_cuts.loc[0,'PSF Chi-Square Cut']
    else:
        # else if all loss above lim, loss_lim_cut is min cut with min% loss
        if min_loss > loss_lim and max_loss > loss_lim:
            loss_case = 'above lim'
            a = np.where(lim_cuts['Ploss'] == min_loss)[0]
            b = lim_cuts.iloc[a]
            c = b.iloc[(b['PSF Chi-Square Cut']).argsort()].reset_index()
            loss_lim_cut = c.loc[0,'PSF Chi-Square Cut']
        # else if loss crosses lim at some point, loss_lim_cut is min cut with max% loss <= loss_lim
        else:
            loss_case = 'crosses lim'
            valid_cuts = sortby_loss[sortby_loss['Ploss'] <= loss_lim]
            a = np.where(lim_cuts['Ploss'] == valid_cuts.loc[len(valid_cuts)-1,'Ploss'])[0]
            # sort by cuts
            b = lim_cuts.iloc[a]
            c = b.iloc[(b['PSF Chi-Square Cut']).argsort()].reset_index()
            # linearly interpolate between loss1 and loss2 (two points on either side of lim)
            loss1_i = np.where(lim_cuts['PSF Chi-Square Cut'] == c.loc[0,'PSF Chi-Square Cut'])[0][0]
            if lim_cuts.loc[loss1_i,'Ploss'] == loss_lim:
                loss_lim_cut = lim_cuts.loc[loss1_i,'PSF Chi-Square Cut']
            else:
                loss2_i = loss1_i - 1
                x = np.array([lim_cuts.loc[loss1_i,'PSF Chi-Square Cut'], lim_cuts.loc[loss2_i,'PSF Chi-Square Cut']])
                contam_y = np.array([lim_cuts.loc[loss1_i,'Pcontamination'], lim_cuts.loc[loss2_i,'Pcontamination']])
                loss_y = np.array([lim_cuts.loc[loss1_i,'Ploss'], lim_cuts.loc[loss2_i,'Ploss']])
                contam_line = np.polyfit(x,contam_y,1)
                loss_line = np.polyfit(x,loss_y,1)
                loss_lim_cut = (loss_lim-loss_line[1])/loss_line[0]

    sortby_contam = lim_cuts.iloc[(lim_cuts['Pcontamination']).argsort()].reset_index()
    min_contam = sortby_contam.loc[0,'Pcontamination']
    max_contam = sortby_contam.loc[len(sortby_contam)-1,'Pcontamination']
    # if all contam below lim, contam_lim_cut is max cut
    if min_contam < contam_lim and max_contam < contam_lim:
        contam_case = 'below lim'
        contam_lim_cut = lim_cuts.loc[len(lim_cuts)-1,'PSF Chi-Square Cut']
    else:
        # else if all contam above lim, contam_lim_cut is max cut with min% contam
        if min_contam > contam_lim and max_contam > contam_lim:
            contam_case = 'above lim'
            a = np.where(lim_cuts['Pcontamination'] == min_contam)[0]
            b = lim_cuts.iloc[a]
            c = b.iloc[(b['PSF Chi-Square Cut']).argsort()].reset_index()
            contam_lim_cut = c.loc[len(c)-1,'PSF Chi-Square Cut']
        # else if contam crosses lim at some point, contam_lim_cut is max cut with max% contam <= contam_lim
        else:
            contam_case = 'crosses lim'
            valid_cuts = sortby_contam[sortby_contam['Pcontamination'] <= contam_lim]
            a = np.where(lim_cuts['Pcontamination'] == valid_cuts.loc[len(valid_cuts)-1,'Pcontamination'])[0]
            # sort by cuts
            b = lim_cuts.iloc[a]
            c = b.iloc[(b['PSF Chi-Square Cut']).argsort()].reset_index()
            # linearly interpolate between contam1 and contam2 (two points on either side of lim)
            contam1_i = np.where(lim_cuts['PSF Chi-Square Cut'] == c.loc[len(c)-1,'PSF Chi-Square Cut'])[0][0]
            if lim_cuts.loc[contam1_i,'Pcontamination'] == contam_lim:
                contam_lim_cut = lim_cuts.loc[contam1_i,'PSF Chi-Square Cut']
            else:
                contam2_i = contam1_i + 1
                x = np.array([lim_cuts.loc[contam1_i,'PSF Chi-Square Cut'], lim_cuts.loc[contam2_i,'PSF Chi-Square Cut']])
                contam_y = np.array([lim_cuts.loc[contam1_i,'Pcontamination'], lim_cuts.loc[contam2_i,'Pcontamination']])
                loss_y = np.array([lim_cuts.loc[contam1_i,'Ploss'], lim_cuts.loc[contam2_i,'Ploss']])
                contam_line = np.polyfit(x,contam_y,1)
                loss_line = np.polyfit(x,loss_y,1)
                contam_lim_cut = (contam_lim-contam_line[1])/contam_line[0]

    return contam_lim_cut, loss_lim_cut, contam_case, loss_case

def get_keptcut_indices(lc, ix, cut):
    kept_ix = lc.ix_inrange(colnames=['chi/N'], uplim=cut, indices=ix)
    cut_ix = AnotB(ix, kept_ix)
    return kept_ix, cut_ix

def get_goodbad_indices(lc, ix):
    good_ix = lc.ix_inrange(colnames=['uJy/duJy'], lowlim=-stn_cut, uplim=stn_cut, indices=ix)
    bad_ix = AnotB(ix, good_ix)
    return good_ix, bad_ix

def get_lim_cuts_data(lc, cut, ix, good_ix, bad_ix):
    kept_ix, cut_ix = get_keptcut_indices(lc, ix, cut)
    out = {}
    out['PSF Chi-Square Cut'] = cut
    out['N'] = len(ix)
    out['Ngood'] = len(good_ix)
    out['Nbad'] = len(bad_ix)
    out['Nkept'] = len(kept_ix)
    out['Ncut'] = len(cut_ix)
    out['Ngood,kept'] = len(AandB(good_ix,kept_ix))
    out['Ngood,cut'] = len(AandB(good_ix,cut_ix))
    out['Nbad,kept'] = len(AandB(bad_ix,kept_ix))
    out['Nbad,cut'] = len(AandB(bad_ix,cut_ix))
    out['Pgood,kept'] = 100*len(AandB(good_ix,kept_ix))/len(ix)
    out['Pgood,cut'] = 100*len(AandB(good_ix,cut_ix))/len(ix)
    out['Pbad,kept'] = 100*len(AandB(bad_ix,kept_ix))/len(ix)
    out['Pbad,cut'] = 100*len(AandB(bad_ix,cut_ix))/len(ix)
    out['Ngood,kept/Ngood'] = 100*len(AandB(good_ix,kept_ix))/len(good_ix)
    out['Ploss'] = 100*len(AandB(good_ix,cut_ix))/len(good_ix)
    out['Pcontamination'] = 100*len(AandB(bad_ix,kept_ix))/len( kept_ix)
    return out

def get_lim_cuts_table(lc, ix, good_ix, bad_ix, stn_cut, cut_start, cut_stop, cut_step, is_SNlc=False):
    print('abs(uJy/duJy) cut at %0.2f \nx2 cut from %0.2f to %0.2f inclusive, with step size %d' % (stn_cut,cut_start,cut_stop,cut_step))

    lim_cuts = pd.DataFrame(columns=['PSF Chi-Square Cut', 'N', 'Ngood', 'Nbad', 'Nkept', 'Ncut', 'Ngood,kept', 'Ngood,cut', 'Nbad,kept', 'Nbad,cut',
                                     'Pgood,kept', 'Pgood,cut', 'Pbad,kept', 'Pbad,cut', 'Ngood,kept/Ngood', 'Ploss', 'Pcontamination'])#,
                                     #'Nbad,cut 3<stn<=5', 'Nbad,cut 5<stn<=10', 'Nbad,cut 10<stn', 'Nbad,kept 3<stn<=5', 'Nbad,kept 5<stn<=10', 'Nbad,kept 10<stn'])

    # static cut at x2 = 50
    x2cut_50 = np.where(lc.t['chi/N'] >= 50)[0]
    print('Static chi square cut at 50: %0.2f%% cut for baseline' % (100*len(AandB(ix,x2cut_50))/len(ix)))

    # for each chi-square cut from cut_start to cut_stop (inclusive)
    for cut in range(cut_start,cut_stop+1,cut_step):
        kept_ix, cut_ix = get_keptcut_indices(lc, ix, cut)
        if 100*(len(kept_ix)/len(ix)) < 10:
            # less than 10% of measurements kept, so no chi-square cuts beyond this point are valid
            print(f'# At cut {cut}, less than 10% of measurements are kept ({100*(len(kept_ix)/len(ix)):0.2f}% kept)--skipping...')
            continue
        
        data = get_lim_cuts_data(lc, cut, ix, good_ix, bad_ix)
        lim_cuts = pd.concat([lim_cuts, pd.DataFrame([data])],ignore_index=True)

    return lim_cuts

if lim_to_prioritize != 'loss_lim' and lim_to_prioritize != 'contam_lim':
    print("ERROR: lim_to_prioritize must be 'loss_lim' or 'contam_lim'!")
    sys.exit()
print('Contamination limit: %0.2f%%\nLoss limit: %0.2f%%' % (contam_lim,loss_lim))

if use_preSN_lc:
    lc_temp = lc.lcs[0]
    ix = lc_temp.ix_inrange('MJD', uplim=discovery_date)
else:
    lc_temp = all_controls
    ix = all_controls.t.index.values
good_ix, bad_ix = get_goodbad_indices(lc_temp, ix)

lim_cuts = get_lim_cuts_table(lc_temp, 
                              ix, 
                              good_ix, 
                              bad_ix, 
                              stn_cut, 
                              cut_start, 
                              cut_stop, 
                              cut_step, 
                              is_SNlc=use_preSN_lc)
#print(lim_cuts.to_string())
if lim_cuts.empty:
    raise RuntimeError('No cuts kept more than 10% of measurements--chi-square cut not applicable for this SN!')

contam_lim_cut, loss_lim_cut, contam_case, loss_case = get_lim_cuts(lim_cuts)
#kept_ix, cut_ix = get_keptcut_indices(lc_temp, )
contam_data = get_lim_cuts_data(lc_temp, contam_lim_cut, ix, good_ix, bad_ix)
loss_data = get_lim_cuts_data(lc_temp, loss_lim_cut, ix, good_ix, bad_ix)

print('\nContamination cut according to given contam_limit, with %0.2f%% contamination and %0.2f%% loss: %0.2f' % (contam_data['Pcontamination'], contam_data['Ploss'], contam_lim_cut))
if contam_case == 'above lim':
    print('WARNING: Contamination cut not possible with contamination <= contam_lim %0.2f%%!' % contam_lim)
print('Loss cut according to given loss_limit, with %0.2f%% contamination and %0.2f%% loss: %0.2f' % (loss_data['Pcontamination'], loss_data['Ploss'], loss_lim_cut))
if loss_case == 'above lim':
    print('WARNING: Loss cut not possible with loss <= loss_lim %0.2f%%!' % loss_lim)

final_cut = choose_btwn_lim_cuts(contam_lim_cut, loss_lim_cut, contam_case, loss_case)

if np.isnan(final_cut):
    print('\nERROR: Final suggested chi-square cut could not be determined. We suggest rethinking your contamination and loss limits.')
    Pcontamination = np.nan
    Ploss = np.nan
else:
    if final_cut == contam_lim_cut:
        Pcontamination = contam_data['Pcontamination']
        Ploss = contam_data['Ploss']
    else: # final_cut == loss_lim_cut
        Pcontamination = loss_data['Pcontamination']
        Ploss = loss_data['Ploss']
    print('\nFinal suggested chi-square cut is %0.2f, with %0.2f%% contamination and %0.2f%% loss.' % (final_cut, Pcontamination, Ploss))
    if (Pcontamination > contam_lim):
        print('WARNING: Final cut\'s contamination %0.2f%% exceeds contam_lim %0.2f%%!' % (Pcontamination,contam_lim))
    if (Ploss > loss_lim):
        print('WARNING: Final cut\'s loss %0.2f%% exceeds loss_lim %0.2f%%!' % (Ploss,loss_lim))

fdf_hist(lc, all_controls, limits=fdf_xlims, density=fdf_density)
plot_lim_cuts(lim_cuts, contam_lim_cut, loss_lim_cut)

### 5c) Confirm or override the final cut, apply final cut, and optionally plot

For this last section, we apply the final chi-square cut to the SN and control light curves. Note that for very bright SNe, the chi-square values may increase even for good measurements due to imperfections in PSF fitting. We therefore recommend double-checking the chi-square values (or this section's plot) to verify that the cut is working as intended.
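
To make the flagging convention concrete, here is a minimal sketch of how a chi-square cut can be recorded as a bit in a "Mask" value instead of deleting rows. The value `0x2` for the chi-square flag is assumed to match the hex value this notebook writes to its output log. `EARLIER_FLAG`, the function name, and the toy numbers are hypothetical.

```python
CHISQUARE_FLAG = 0x2   # assumed to match flags['chisquare']
EARLIER_FLAG = 0x1     # hypothetical bit set by some earlier cut

def flag_chisquare(chi_values, masks, cut):
    """Return new mask values with CHISQUARE_FLAG set wherever chi/N > cut."""
    return [(m | CHISQUARE_FLAG) if chi > cut else m
            for chi, m in zip(chi_values, masks)]

chi_values = [0.9, 12.5, 7.0]
masks = [0, EARLIER_FLAG, 0]

new_masks = flag_chisquare(chi_values, masks, cut=5.0)
print(new_masks)  # [0, 3, 2]

# A measurement is cut by this step if its chi-square bit is set.
n_cut = sum(1 for m in new_masks if m & CHISQUARE_FLAG)
print(f'{100 * n_cut / len(new_masks):0.2f}% cut')  # 66.67% cut
```

Because each cut sets its own bit, flags from earlier cuts are preserved and any combination of cuts can later be selected with a bitwise OR of their flags.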

In [None]:
# Plot the light curve before and after the applied chi-square cut?:
plot = True

# Optionally, manually enter the x and y limits for the chi-square cut plot:
xlim_lower = None
xlim_upper = None
ylim_lower = None
ylim_upper = None

In [None]:
def apply_chisquare_cut(lc, chisquare_cut, control_index=0):
    ix = lc.get_ix(control_index=control_index)
    kept_ix = lc.lcs[control_index].ix_inrange(colnames=['chi/N'],uplim=chisquare_cut)
    cut_ix = AnotB(ix, kept_ix)
    lc.update_mask_col(flags['chisquare'], cut_ix, control_index=control_index)

    if control_index == 0:
        num_cut = len(cut_ix)
        percent_cut = 100 * num_cut/len(ix)
        print_statistics(num_cut, percent_cut)
        f.write(f'\n\n## Chi-square cut\nNumber of cut measurements: {num_cut:d}\nPercent of total measurements cut: {percent_cut:0.2f}%\nHex value in "Mask" column: 0x2')

        if plot or percent_cut > 10:
            limits = [xlim_lower, xlim_upper, ylim_lower, ylim_upper]
            plot_cut_lc(lc, 'Chi-square cut', flags['chisquare'], limits=limits, save_filename='chisquare_cut')

def get_percent_greater(a, b):
    return 100 * ((a - b)/b)

def sanity_check(lc, all_controls, chisquare_cut):
    # percentages (0-100) of measurements cut after discovery, before discovery, and in the controls
    postSN_ix = lc.get_post_SN_ix()
    postSN_kept_ix, postSN_cut_ix = get_keptcut_indices(lc.lcs[0], postSN_ix, chisquare_cut)
    postSN_percent_cut = 100 * len(postSN_cut_ix)/len(postSN_ix)

    preSN_ix = lc.get_pre_SN_ix()
    preSN_kept_ix, preSN_cut_ix = get_keptcut_indices(lc.lcs[0], preSN_ix, chisquare_cut)
    preSN_percent_cut = 100 * len(preSN_cut_ix)/len(preSN_ix)

    controls_ix = all_controls.t.index.values
    controls_kept_ix, controls_cut_ix = get_keptcut_indices(all_controls, controls_ix, chisquare_cut)
    controls_percent_cut = 100 * len(controls_cut_ix)/len(controls_ix)

    postSN_percent_greater = get_percent_greater(postSN_percent_cut, controls_percent_cut)
    preSN_percent_greater = get_percent_greater(preSN_percent_cut, controls_percent_cut)

    print('\nSanity check:')
    print(f'# {controls_percent_cut:0.4f}% of measurements cut in control light curves')
    
    out = f'{postSN_percent_cut:0.4f}% of measurements cut in SN light curve'
    if postSN_percent_greater >= 50:
        out = f'# WARNING: {out} (increase by a factor of approx. {postSN_percent_greater/100+1:0.2f})'
        out += f'\n  Bright SNe may cause the chi-square values to increase even at good measurements due to PSF fitting imperfections. Please double check the plot.'
    else:
        out = f'# {out}'
    print(out)

    out = f'{preSN_percent_cut:0.4f}% of measurements cut in pre-SN light curve'
    if preSN_percent_greater >= 50:
        out = f'# WARNING: {out} (increase by a factor of approx. {preSN_percent_greater/100+1:0.2f})'
        out += f'\n  Bright SNe may cause the chi-square values to increase even at good measurements due to PSF fitting imperfections. Please double check the plot.'
    else:
        out = f'# {out}'
    print(out)

answer = input(f'Accept final chi-square cut of {final_cut:0.2f} (y/n), or enter your own cut that minimizes contamination and loss?:')
if answer != 'y':
    final_cut = float(input('Overriding final chi-square cut; enter manual cut: '))
    if use_preSN_lc:
        lc_temp = lc.lcs[0]
        ix = lc.lcs[0].ix_inrange('MJD', uplim=discovery_date)
    else:
        lc_temp = all_controls
        ix = all_controls.t.index.values
    good_ix, bad_ix = get_goodbad_indices(lc_temp, ix)
    data = get_lim_cuts_data(lc_temp, final_cut, ix, good_ix, bad_ix)
    print(f'Overridden: final cut is now {final_cut}, with contamination {data["Pcontamination"]:0.2f}% and loss {data["Ploss"]:0.2f}%')

for control_index in range(lc.num_controls+1):
    apply_chisquare_cut(lc, final_cut, control_index=control_index)
print('Success')
sanity_check(lc, all_controls, final_cut)

## Step 6: Control light curves: 3σ-clipped average cut

While the chi-square and uncertainty cuts are effective in cutting out a majority of the bad measurements, tricky cases may require a larger set of control light curves that can be used as a basis of comparison for inconsistent flux. In order to account for this inconsistent flux, we can obtain ~8 quality control forced photometry light curves in a 17" circle pattern around the SN location OR around a nearby bright object that may be poorly subtracting. Then, we use statistics from these control light curves to cut bad measurements from the SN light curve.

For a given epoch, we have 1 SN measurement for which we examine 8 control measurements within the same epoch. We know that if the control light curve measurements are NOT consistent with 0, this indicates something wrong with this epoch, so the SN measurement is unreliable. Therefore, we obtain statistics for the control light curves by calculating the 3σ-clipped average of the control flux. 

For the given epoch, we cut the SN measurement for which the returned control statistics fulfill any of the following criteria: 
- A returned chi-square > 2.5
- A returned abs(flux/dflux) > 3.0
- Number of clipped/"bad" measurements in the 3σ-clipped average > 2
- Number of used/"good" measurements in the 3σ-clipped average < 4

Measurements not fulfilling any of the criteria above but with Nclip > 0 are flagged as questionable.
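
The per-epoch logic described above can be sketched in plain Python. This is a simplified stand-in for `calcaverage_sigmacutloop`, whose exact algorithm is not reproduced here. It starts from the median, clips any control measurement more than 3 of its own uncertainties away from the current average, and repeats until the kept set stops changing. The function names, the clipping rule, and the toy fluxes are assumptions for illustration only.

```python
import math
import statistics

def clipped_average(flux, dflux, nsigma=3.0, max_iter=10):
    """Simplified sigma-clipped weighted mean of one epoch's control fluxes."""
    center = statistics.median(flux)  # start from the median
    keep = []
    for _ in range(max_iter):
        new_keep = [i for i, (f, df) in enumerate(zip(flux, dflux))
                    if abs(f - center) <= nsigma * df]
        if not new_keep or new_keep == keep:
            break  # nothing left, or the kept set stopped changing
        keep = new_keep
        weights = [1.0 / dflux[i] ** 2 for i in keep]
        center = sum(w * flux[i] for w, i in zip(weights, keep)) / sum(weights)

    n_good, n_clip = len(keep), len(flux) - len(keep)
    if n_good < 2:
        return {'mean': math.nan, 'mean_err': math.nan, 'X2norm': math.nan,
                'Ngood': n_good, 'Nclip': n_clip}
    weights = [1.0 / dflux[i] ** 2 for i in keep]
    x2norm = sum(((flux[i] - center) / dflux[i]) ** 2 for i in keep) / (n_good - 1)
    return {'mean': center, 'mean_err': 1.0 / math.sqrt(sum(weights)),
            'X2norm': x2norm, 'Ngood': n_good, 'Nclip': n_clip}

def epoch_is_bad(stats, x2_max=2.5, stn_max=3.0, Nclip_max=2, Ngood_min=4):
    """Apply the four criteria listed above to one epoch's control statistics."""
    if stats['Ngood'] < Ngood_min or stats['Nclip'] > Nclip_max:
        return True
    abs_stn = abs(stats['mean'] / stats['mean_err'])
    return stats['X2norm'] > x2_max or abs_stn > stn_max

# Toy epoch: 8 control fluxes (uJy), one of them a clear outlier.
flux = [5.0, -12.0, 8.0, -3.0, 15.0, -7.0, 2.0, 250.0]
dflux = [10.0] * 8

stats = clipped_average(flux, dflux)
print(stats['Ngood'], stats['Nclip'])  # 7 1
print(epoch_is_bad(stats))             # False
```

In this toy epoch the outlier is clipped and the remaining controls are consistent with zero, so the SN measurement is not cut. Since Nclip > 0, it would be flagged as questionable.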

In [None]:
# Enter the bound for an epoch's maximum chi-square 
# (if x2 > x2_max, flag SN measurement):
x2_max = 2.5

# Enter the bound for an epoch's maximum abs(flux/dflux) ratio 
# (if abs(flux/dflux) > stn_max, flag SN measurement):
stn_max = 3.0

# Enter the bound for an epoch's maximum number of clipped control measurements
# (if Nclip > Nclip_max, flag SN measurement):
Nclip_max = 2

# Enter the bound for an epoch's minimum number of good control measurements
# (if Ngood < Ngood_min, flag SN measurement):
Ngood_min = 4

# Plot the light curve before and after the applied control light curve cut?:
plot = True

# Optionally, manually enter the x and y limits for the control light curve cut plot:
xlim_lower = None
xlim_upper = None
ylim_lower = None
ylim_upper = None

In [None]:
# apply control light curve cut and save flags in 'Mask' column

def get_control_stats(lc):
    print('\nCalculating control light curve statistics...')

    len_mjd = len(lc.lcs[0].t['MJD'])

    # construct arrays for control lc data
    uJy = np.full((lc.num_controls, len_mjd), np.nan)
    duJy = np.full((lc.num_controls, len_mjd), np.nan)
    Mask = np.full((lc.num_controls, len_mjd), 0, dtype=np.int32)
    
    for control_index in range(1, lc.num_controls+1):
        if (len(lc.lcs[control_index].t) != len_mjd) or not np.array_equal(lc.lcs[0].t['MJD'], lc.lcs[control_index].t['MJD']):
            raise RuntimeError(f'ERROR: SN lc not equal to control lc for control_index {control_index}! Rerun or debug verify_mjds().')
        else:
            uJy[control_index-1,:] = lc.lcs[control_index].t['uJy']
            duJy[control_index-1,:] = lc.lcs[control_index].t[lc.dflux_colnames[control_index]]
            Mask[control_index-1,:] = lc.lcs[control_index].t['Mask']

    c2_param2columnmapping = lc.lcs[0].intializecols4statparams(prefix='c2_',format4outvals='{:.2f}',skipparams=['converged','i'])

    for index in range(uJy.shape[-1]):
        pda4MJD = pdastrostatsclass()
        pda4MJD.t['uJy'] = uJy[0:,index]
        pda4MJD.t[lc.dflux_colnames[0]] = duJy[0:,index]
        pda4MJD.t['Mask'] = np.bitwise_and(Mask[0:,index], flags['chisquare']|flags['uncertainty'])
        
        pda4MJD.calcaverage_sigmacutloop('uJy',
                                         noisecol=lc.dflux_colnames[0],
                                         maskcol='Mask',
                                         maskval=(flags['chisquare']|flags['uncertainty']),
                                         verbose=1, Nsigma=3.0, median_firstiteration=True)
        lc.lcs[0].statresults2table(pda4MJD.statparams, c2_param2columnmapping, destindex=index)

    return lc

def controls_cut(lc):
    print('Flagging SN light curve based on control light curve statistics...')

    lc.lcs[0].t['c2_abs_stn'] = np.abs(lc.lcs[0].t['c2_mean'] / lc.lcs[0].t['c2_mean_err'])

    # flag measurements according to given bounds
    flag_x2_i = lc.lcs[0].ix_inrange(colnames=['c2_X2norm'], lowlim=x2_max, exclude_lowlim=True)
    flag_stn_i = lc.lcs[0].ix_inrange(colnames=['c2_abs_stn'], lowlim=stn_max, exclude_lowlim=True)
    flag_nclip_i = lc.lcs[0].ix_inrange(colnames=['c2_Nclip'], lowlim=Nclip_max, exclude_lowlim=True)
    flag_ngood_i = lc.lcs[0].ix_inrange(colnames=['c2_Ngood'], uplim=Ngood_min, exclude_uplim=True)
    lc.lcs[0].t.loc[flag_x2_i,'Mask'] |= flags['controls_x2']
    lc.lcs[0].t.loc[flag_stn_i,'Mask'] |= flags['controls_stn']
    lc.lcs[0].t.loc[flag_nclip_i,'Mask'] |= flags['controls_Nclip']
    lc.lcs[0].t.loc[flag_ngood_i,'Mask'] |= flags['controls_Ngood']

    # update mask column with control light curve cut on any measurements flagged according to given bounds
    zero_Nclip_ix = lc.lcs[0].ix_equal('c2_Nclip', 0)
    flags_temp = (flags['controls_x2'] | flags['controls_stn'] | flags['controls_Nclip'] | flags['controls_Ngood'])
    unmasked_ix = lc.lcs[0].ix_unmasked('Mask',maskval=flags_temp)
    lc.lcs[0].t.loc[AnotB(unmasked_ix,zero_Nclip_ix),'Mask'] |= flags['controls_questionable']
    lc.lcs[0].t.loc[AnotB(lc.get_ix(),unmasked_ix),'Mask'] |= flags['controls_bad']

    # copy over SN's control cut flags to control light curve 'Mask' column
    flags_temp = (flags['controls_questionable'] | flags['controls_x2'] | flags['controls_stn'] | flags['controls_Nclip'] | flags['controls_Ngood'])
    flags_arr = np.full(lc.lcs[0].t['Mask'].shape, flags_temp)
    flags_to_copy = np.bitwise_and(lc.lcs[0].t['Mask'], flags_arr)
    for control_index in range(1,lc.num_controls+1):
        lc.lcs[control_index].t['Mask'] = lc.lcs[control_index].t['Mask'].astype(np.int32)
        if len(lc.lcs[control_index].t) < 1:
            continue
        elif len(lc.lcs[control_index].t) == 1:
            lc.lcs[control_index].t.loc[0,'Mask']= int(lc.lcs[control_index].t.loc[0,'Mask']) | flags_to_copy
        else:
            lc.lcs[control_index].t['Mask'] = np.bitwise_or(lc.lcs[control_index].t['Mask'], flags_to_copy)

    return lc

def print_flag_stats(lc):
    percent_cut = 100 * len(lc.lcs[0].ix_masked('Mask', maskval=flags['controls_bad'])) / len(lc.lcs[0].t) 
    percent_questionable = 100 * len(lc.lcs[0].ix_masked('Mask', maskval=flags['controls_questionable'])) / len(lc.lcs[0].t)

    x2_max_pcnt = 100 * len(lc.lcs[0].ix_masked('Mask', maskval=flags['controls_x2'])) / len(lc.lcs[0].t)
    stn_max_pcnt = 100 * len(lc.lcs[0].ix_masked('Mask', maskval=flags['controls_stn'])) / len(lc.lcs[0].t)
    Nclip_max_pcnt = 100 * len(lc.lcs[0].ix_masked('Mask', maskval=flags['controls_Nclip'])) / len(lc.lcs[0].t)
    Ngood_min_pcnt = 100 * len(lc.lcs[0].ix_masked('Mask', maskval=flags['controls_Ngood'])) / len(lc.lcs[0].t)

    print('\nLength of SN light curve: %d' % len(lc.lcs[0].t))
    print('Percent of data above x2_max bound: %0.2f%%' % x2_max_pcnt)
    print('Percent of data above stn_max bound: %0.2f%%' % stn_max_pcnt)
    print('Percent of data above Nclip_max bound: %0.2f%%' % Nclip_max_pcnt)
    print('Percent of data below Ngood_min bound: %0.2f%%' % Ngood_min_pcnt)
    print('Total percent of data flagged as bad: %0.2f%%' % percent_cut)
    print('Total percent of data flagged as questionable: %0.2f%%' % percent_questionable)

    f.write(f'\n\n## Control light curve cut\nPercent of data above x2_max bound: {x2_max_pcnt:0.2f}%\nPercent of data above stn_max bound: {stn_max_pcnt:0.2f}%\nPercent of data above Nclip_max bound: {Nclip_max_pcnt:0.2f}%\nPercent of data below Ngood_min bound: {Ngood_min_pcnt:0.2f}%')
    f.write(f'\nTotal percent of data flagged as bad: {percent_cut:0.2f}%\nTotal percent of data flagged as questionable: {percent_questionable:0.2f}%\nHex value in "Mask" column (flagged as "bad"): 0x400000\nHex value in "Mask" column (flagged as "questionable"): 0x80000')

    if plot or percent_cut > 10:
        limits = [xlim_lower, xlim_upper, ylim_lower, ylim_upper]
        plot_cut_lc(lc, 'Control light curve cut', flags['controls_bad'], limits=limits, save_filename='controls_cut')

if load_controls:    
    lc = get_control_stats(lc)
    lc = controls_cut(lc)
    print_flag_stats(lc)
else:
    print('Load_controls set to False! Skipping...')

## Plot the ATLAS light curve with a combination of previous cuts

In [None]:
limits = [xlim_lower, xlim_upper, ylim_lower, ylim_upper]
plot_cut_lc(lc, 'Uncertainty, chi-square, and control light curve cuts', flags['uncertainty']|flags['chisquare']|flags['controls_bad'], limits=limits, save_filename='all_cut')

## Step 7: Averaging and cutting bad bins

Our goal is to identify and cut out bad MJD bins by taking a 3σ-clipped average of each bin. For each bin, we calculate the 3σ-clipped average of any SN measurements falling within that bin and use that average as our flux for that bin. Because the ATLAS survey takes about 4 exposures every 2 days, we usually average together approximately 4 measurements per epoch. However, out of these 4 exposures, only measurements not cut in the previous methods are averaged in the 3σ-clipped average cut. (The exception to this statement would be the case that all 4 measurements are cut in previous methods; in this case, they are averaged anyway and flagged as a bad bin.)

Then we cut any measurements in the SN light curve for the given epoch for which statistics fulfill any of the following criteria: 
- A returned chi-square > 4.0
- Number of measurements averaged < 2
- Number of measurements clipped > 1

For this part of the cleaning, we still need to improve the cutting at the peak of the SN (important epochs are sometimes cut, maybe due to fast rise, etc.).
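
The binning step described above can be sketched as follows. This is a minimal illustration of grouping measurements into MJD bins and flagging bins with too few measurements. It uses a plain inverse-variance weighted mean without the 3σ clipping, drops empty bins, and assumes that already-cut measurements have been removed from the inputs. The function name `bin_by_mjd` and the toy data are hypothetical.

```python
import math
from collections import defaultdict

def bin_by_mjd(mjds, fluxes, dfluxes, mjd_bin_size=1.0, Ngood_min=2):
    """Average measurements in fixed-width MJD bins (no sigma clipping)."""
    start = int(min(mjds))  # bins start at the integer MJD of the first measurement
    bins = defaultdict(list)
    for mjd, f, df in zip(mjds, fluxes, dfluxes):
        k = int((mjd - start) // mjd_bin_size)
        bins[k].append((f, df))

    rows = []
    for k in sorted(bins):
        weights = [1.0 / df ** 2 for _, df in bins[k]]
        mean = sum(w * f for w, (f, _) in zip(weights, bins[k])) / sum(weights)
        rows.append({
            'MJDbin': start + (k + 0.5) * mjd_bin_size,  # bin center
            'uJy': mean,
            'duJy': 1.0 / math.sqrt(sum(weights)),
            'Ngood': len(bins[k]),
            'bad': len(bins[k]) < Ngood_min,  # too few measurements to trust the bin
        })
    return rows

# Toy data: four exposures on one night and a single exposure two nights later.
mjds = [58000.10, 58000.12, 58000.14, 58000.16, 58002.30]
fluxes = [100.0, 110.0, 90.0, 100.0, 50.0]
dfluxes = [10.0] * 5

rows = bin_by_mjd(mjds, fluxes, dfluxes)
for row in rows:
    print(row['MJDbin'], row['Ngood'], row['bad'])
```

With a 1-day bin size, the four exposures from the first night fall into one bin, while the lone later exposure forms its own bin and is flagged because it has fewer than `Ngood_min` measurements.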

In [None]:
# Enter the MJD bin size in days:
mjd_bin_size = 1

# Should MJD bins with no measurements be translated as NaN (True) 
# or removed from the averaged light curve (False)?
keep_empty_bins = False

# After the flux is averaged, magnitudes are computed with a flux-to-magnitude conversion.
# A magnitude is an upper limit if its dmagnitude is NaN. Enter the sigma level used for these limits:
flux2mag_sigma_limit = 3

# Enter the bound for a bin's maximum number of clipped measurements
# (if Nclip > Nclip_max, flag day):
Nclip_max = 1

# Enter the bound for a bin's minimum number of good measurements
# (if Ngood < Ngood_min, flag day):
Ngood_min = 2

# Enter the bound for a bin's maximum chi-square (if x2 > x2_max, flag day):
x2_max = 4.0

# Optionally, manually enter the x and y limits for the plot:
xlim_lower = None
xlim_upper = None
ylim_lower = None
ylim_upper = None

In [None]:
def average_lc(lc, avglc, Nclip_max, Ngood_min, x2_max, mjd_bin_size=1, flux2mag_sigma_limit=3.0, keep_empty_bins=True):
    mjd = int(np.amin(lc.lcs[0].t['MJD']))
    mjd_max = int(np.amax(lc.lcs[0].t['MJD']))+1

    good_ix = lc.lcs[0].ix_unmasked('Mask', maskval=flags['chisquare']|flags['uncertainty'])

    while mjd <= mjd_max:
        range_ix = lc.lcs[0].ix_inrange(colnames=['MJD'], lowlim=mjd, uplim=mjd+mjd_bin_size, exclude_uplim=True)
        range_good_ix = AandB(range_ix,good_ix)

        # add new row to avglc if keep_empty_bins or any measurements present
        if keep_empty_bins or len(range_ix) >= 1:
            new_row = {'MJDbin':mjd+0.5*mjd_bin_size, 'Nclip':0, 'Ngood':0, 'Nexcluded':len(range_ix)-len(range_good_ix), 'Mask':0}
            avglc_index = avglc.lcs[0].newrow(new_row)
        
        # if no measurements present, flag or skip over day
        if len(range_ix) < 1:
            if keep_empty_bins:
                avglc.update_mask_col(flags['avg_badday'], [avglc_index])
            mjd += mjd_bin_size
            continue
        
        # if no good measurements, average values anyway and flag
        if len(range_good_ix) < 1:
            # average flux
            lc.lcs[0].calcaverage_sigmacutloop('uJy', noisecol=lc.dflux_colnames[0], indices=range_ix, Nsigma=3.0, median_firstiteration=True)
            fluxstatparams = deepcopy(lc.lcs[0].statparams)

            # average mjd
            # TODO: should noisecol here be the flux uncertainty column or None?
            lc.lcs[0].calcaverage_sigmacutloop('MJD', noisecol=lc.dflux_colnames[0], indices=fluxstatparams['ix_good'], Nsigma=0, median_firstiteration=False)
            avg_mjd = lc.lcs[0].statparams['mean']

            # add row and flag
            avglc.lcs[0].add2row(avglc_index, {'MJD':avg_mjd, 
                                               'uJy':fluxstatparams['mean'], 
                                               'duJy':fluxstatparams['mean_err'], 
                                               'stdev':fluxstatparams['stdev'],
                                               'x2':fluxstatparams['X2norm'],
                                               'Nclip':fluxstatparams['Nclip'],
                                               'Ngood':fluxstatparams['Ngood'],
                                               'Mask':0})
            lc.update_mask_col(flags['avg_badday'], range_ix)
            avglc.update_mask_col(flags['avg_badday'], [avglc_index])

            mjd += mjd_bin_size
            continue
        
        # average good measurements
        lc.lcs[0].calcaverage_sigmacutloop('uJy', noisecol=lc.dflux_colnames[0], indices=range_good_ix, Nsigma=3.0, median_firstiteration=True)
        fluxstatparams = deepcopy(lc.lcs[0].statparams)

        if fluxstatparams['mean'] is None or len(fluxstatparams['ix_good']) < 1:
            lc.update_mask_col(flags['avg_badday'], range_ix)
            avglc.update_mask_col(flags['avg_badday'], [avglc_index])
            mjd += mjd_bin_size
            continue

        # average mjd
        # TODO: should noisecol here be the flux uncertainty column or None?
        lc.lcs[0].calcaverage_sigmacutloop('MJD', noisecol=lc.dflux_colnames[0], indices=fluxstatparams['ix_good'], Nsigma=0, median_firstiteration=False)
        avg_mjd = lc.lcs[0].statparams['mean']

        # add row
        avglc.lcs[0].add2row(avglc_index, {'MJD':avg_mjd, 
                                           'uJy':fluxstatparams['mean'], 
                                           'duJy':fluxstatparams['mean_err'], 
                                           'stdev':fluxstatparams['stdev'],
                                           'x2':fluxstatparams['X2norm'],
                                           'Nclip':fluxstatparams['Nclip'],
                                           'Ngood':fluxstatparams['Ngood'],
                                           'Mask':0})
        
        # flag clipped measurements in lc
        if len(fluxstatparams['ix_clip']) > 0:
            lc.update_mask_col(flags['avg_ixclip'], fluxstatparams['ix_clip'])
        
        # if small number within this bin, flag measurements
        if len(range_good_ix) < 3:
            lc.update_mask_col(flags['avg_smallnum'], range_good_ix) # CHANGE TO RANGE_I??
            avglc.update_mask_col(flags['avg_smallnum'], [avglc_index])
        # else check sigmacut bounds and flag
        else:
            is_bad = False
            if fluxstatparams['Ngood'] < Ngood_min:
                is_bad = True
            if fluxstatparams['Nclip'] > Nclip_max:
                is_bad = True
            if not(fluxstatparams['X2norm'] is None) and fluxstatparams['X2norm'] > x2_max:
                is_bad = True
            if is_bad:
                lc.update_mask_col(flags['avg_badday'], range_ix)
                avglc.update_mask_col(flags['avg_badday'], [avglc_index])

        mjd += mjd_bin_size

    # convert flux to magnitude and dflux to dmagnitude
    for col in ['uJy','duJy']: 
        avglc.lcs[0].t[col] =avglc.lcs[0].t[col].astype(float)
    avglc.lcs[0].flux2mag('uJy','duJy','m','dm', zpt=23.9, upperlim_Nsigma=flux2mag_sigma_limit)

    print('Success')

    avglc.drop_extra_columns()

    for col in ['Nclip','Ngood','Nexcluded','Mask']: 
        avglc.lcs[0].t[col] = avglc.lcs[0].t[col].astype(np.int32)

    return avglc

if len(lc.lcs[0].t) < 1:
    print('ERROR: No data in lc so cannot average; exiting... ')
    sys.exit()

avglc = atlas_lc(tnsname, is_averaged=True, mjd_bin_size=mjd_bin_size, discdate=lc.discdate)
avglc.lcs[0] = pdastrostatsclass(columns=['MJD','MJDbin','uJy','duJy','stdev','x2','Nclip','Ngood','Nexcluded','Mask'],hexcols=['Mask'])
print('Averaging light curve with the following criteria: MJD bin size = %0.1f day(s), Nclip_max = %d, Ngood_min = %d, x2_max = %0.2f... ' % (mjd_bin_size, Nclip_max, Ngood_min, x2_max))
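# pass the user-entered bin size and sigma limit so they are not silently replaced by the defaults
avglc = average_lc(lc, avglc, Nclip_max, Ngood_min, x2_max, mjd_bin_size=mjd_bin_size, flux2mag_sigma_limit=flux2mag_sigma_limit, keep_empty_bins=keep_empty_bins)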

# print statistics and plot
num_cut = len(avglc.lcs[0].ix_masked('Mask', flags['avg_badday']))
percent_cut = (num_cut/len(avglc.lcs[0].t)) * 100
print(f'\nNumber of cut bins: {num_cut:d}\nPercent of total bins cut: {percent_cut:0.2f}%')
f.write(f'\n\n## Averaging and cutting bad bins\nNumber of cut bins: {num_cut}\nPercent of total bins cut: {percent_cut:0.2f}%\nHex value in "Mask" column: 0x800000')
if percent_cut > 10:
    print('WARNING: percent of total bins cut is greater than 10%')
limits = [xlim_lower, xlim_upper, ylim_lower, ylim_upper]
plot_cut_lc(avglc, 'Averaged light curve', flags['avg_badday'], limits=limits, save_filename='averaged')

## Optional: Correct for ATLAS reference template changes

This notebook takes into account ATLAS's periodic replacement of the difference image reference templates, which may cause step discontinuities in flux. Two template changes have been recorded at MJDs 58417 and 58882. More information can be found here: https://fallingstar-data.com/forcedphot/faq/.
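As a rough illustration of what a step discontinuity looks like numerically, the sketch below estimates the flux offset across one template change by comparing the mean of the last few measurements before the change with the mean of the first few after it. This is not the notebook's `lc.template_correction` method, whose internals are not shown here; the function name `step_offset` and the sample data are hypothetical, and the estimate is only meaningful if the light curve is time-sorted and contains no real transient flux near the transition.

```python
from statistics import mean

def step_offset(mjd, flux, t_change, n=40):
    """Estimate the flux step across a template change at MJD t_change (illustrative)."""
    before = [f for t, f in zip(mjd, flux) if t < t_change][-n:]   # last n before the change
    after = [f for t, f in zip(mjd, flux) if t >= t_change][:n]    # first n after the change
    if not before or not after:
        return None  # the light curve does not span this template change
    return mean(after) - mean(before)

# Made-up baseline fluxes (uJy) around the first recorded template change
mjd = [58410, 58412, 58414, 58416, 58418, 58420, 58422]
flux = [1.0, -1.0, 2.0, -2.0, 11.0, 9.0, 10.0]
offset = step_offset(mjd, flux, t_change=58417)
print(offset)
```

Subtracting such an offset from the later region would bring the two baselines into agreement. The window of 40 measurements mirrors the number used for the region means in the plotting cell.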

In [None]:
# Set to True for ATLAS template change correction
template_correction = False

In [None]:
# 3-panel plot of (1) template regions in different colors and (2) zoom-in on transitions

def plot_template_correction(lc, limits=None):
    colors = ['salmon', 'sandybrown', 'darkseagreen']
    
    t1, t2 = 58417, 58882
    region1_ix = lc.lcs[0].ix_inrange('MJD', uplim=t1)
    region2_ix = lc.lcs[0].ix_inrange('MJD', lowlim=t1, uplim=t2)
    region3_ix = lc.lcs[0].ix_inrange('MJD', lowlim=t2)

    region1_mean = lc._get_mean(region1_ix[-40:]) # last 40 measurements before t1
    region2a_mean = lc._get_mean(region2_ix[:40]) # first 40 measurements after t1
    region2b_mean = lc._get_mean(region2_ix[-40:]) # last 40 measurements before t2
    region3_mean = lc._get_mean(region3_ix[:40]) # first 40 measurements after t2

    gs = gridspec.GridSpec(2, 2, height_ratios=[1, 1], hspace=0.35, wspace=0.4)
    fig = plt.figure()
    fig.set_figwidth(6)
    fig.set_figheight(6)
    fig.tight_layout()

    ax1 = plt.subplot(gs[0, :])
    ax1.axvline(x=t1, color='k', linestyle='dotted', label='ATLAS template change', zorder=100)
    ax1.axvline(x=t2, color='k', linestyle='dotted', zorder=100)
    ax1.axhline(color='k',zorder=0)
    limits = set_xylimits(lc, limits)
    ax1.set_xlim(limits[0], limits[1])
    ax1.set_ylim(limits[2], limits[3])

    ax1.errorbar(lc.lcs[0].t.loc[region1_ix,'MJD'], lc.lcs[0].t.loc[region1_ix,'uJy'], yerr=lc.lcs[0].t.loc[region1_ix,lc.dflux_colnames[0]], fmt='none', ecolor=colors[0], elinewidth=1, capsize=1.2, c=sn_flux, alpha=0.5, zorder=10)
    ax1.scatter(lc.lcs[0].t.loc[region1_ix,'MJD'], lc.lcs[0].t.loc[region1_ix,'uJy'], s=marker_size, lw=marker_edgewidth, color=colors[0], marker='o', alpha=0.5, zorder=10, label='Region 1 flux')
    ax1.errorbar(lc.lcs[0].t.loc[region2_ix,'MJD'], lc.lcs[0].t.loc[region2_ix,'uJy'], yerr=lc.lcs[0].t.loc[region2_ix,lc.dflux_colnames[0]], fmt='none', ecolor=colors[1], elinewidth=1, capsize=1.2, c=sn_flux, alpha=0.5, zorder=10)
    ax1.scatter(lc.lcs[0].t.loc[region2_ix,'MJD'], lc.lcs[0].t.loc[region2_ix,'uJy'], s=marker_size, lw=marker_edgewidth, color=colors[1], marker='o', alpha=0.5, zorder=10, label='Region 2 flux')
    ax1.errorbar(lc.lcs[0].t.loc[region3_ix,'MJD'], lc.lcs[0].t.loc[region3_ix,'uJy'], yerr=lc.lcs[0].t.loc[region3_ix,lc.dflux_colnames[0]], fmt='none', ecolor=colors[2], elinewidth=1, capsize=1.2, c=sn_flux, alpha=0.5, zorder=10)
    ax1.scatter(lc.lcs[0].t.loc[region3_ix,'MJD'], lc.lcs[0].t.loc[region3_ix,'uJy'], s=marker_size, lw=marker_edgewidth, color=colors[2], marker='o', alpha=0.5, zorder=10, label='Region 3 flux')

    ax1.legend(facecolor='white', framealpha=1, loc='upper left',  bbox_to_anchor=(1, 1))

    ax2 = plt.subplot(gs[1, 0])
    ax2.axvline(x=t1, color='k', linestyle='dotted', zorder=100)
    ax2.axhline(color='k',zorder=0)
    ax2.set_xlim(lc.lcs[0].t.loc[region1_ix[-40:][0], 'MJD'], lc.lcs[0].t.loc[region2_ix[:40][-1], 'MJD'])

    ax2.errorbar(lc.lcs[0].t.loc[region1_ix,'MJD'], lc.lcs[0].t.loc[region1_ix,'uJy'], yerr=lc.lcs[0].t.loc[region1_ix,lc.dflux_colnames[0]], fmt='none', ecolor=colors[0], elinewidth=1, capsize=1.2, c=sn_flux, alpha=0.5, zorder=10)
    ax2.scatter(lc.lcs[0].t.loc[region1_ix,'MJD'], lc.lcs[0].t.loc[region1_ix,'uJy'], s=marker_size, lw=marker_edgewidth, color=colors[0], marker='o', alpha=0.5, zorder=10)
    ax2.errorbar(lc.lcs[0].t.loc[region2_ix,'MJD'], lc.lcs[0].t.loc[region2_ix,'uJy'], yerr=lc.lcs[0].t.loc[region2_ix,lc.dflux_colnames[0]], fmt='none', ecolor=colors[1], elinewidth=1, capsize=1.2, c=sn_flux, alpha=0.5, zorder=10)
    ax2.scatter(lc.lcs[0].t.loc[region2_ix,'MJD'], lc.lcs[0].t.loc[region2_ix,'uJy'], s=marker_size, lw=marker_edgewidth, color=colors[1], marker='o', alpha=0.5, zorder=10)

    ax2.axhline(y=region1_mean, color=colors[0], linestyle='dashed', label='Region 1 mean')
    ax2.axhline(y=region2a_mean, color=colors[1], linestyle='dashed', label='Region 2 mean')

    limits = set_xylimits(lc, limits, indices=AorB(region1_ix[-40:], region2_ix[:40]))
    ax2.set_ylim(limits[2], limits[3])
    ax2.legend(facecolor='white', framealpha=1)

    ax3 = plt.subplot(gs[1, 1]) 
    ax3.axvline(x=t2, color='k', linestyle='dotted', zorder=100)
    ax3.axhline(color='k',zorder=0)
    ax3.set_xlim(lc.lcs[0].t.loc[region2_ix[-40:][0], 'MJD'], lc.lcs[0].t.loc[region3_ix[:40][-1], 'MJD'])
    ax3.set_ylim(limits[2], limits[3])

    ax3.errorbar(lc.lcs[0].t.loc[region2_ix,'MJD'], lc.lcs[0].t.loc[region2_ix,'uJy'], yerr=lc.lcs[0].t.loc[region2_ix,lc.dflux_colnames[0]], fmt='none', ecolor=colors[1], elinewidth=1, capsize=1.2, c=sn_flux, alpha=0.5, zorder=10)
    ax3.scatter(lc.lcs[0].t.loc[region2_ix,'MJD'], lc.lcs[0].t.loc[region2_ix,'uJy'], s=marker_size, lw=marker_edgewidth, color=colors[1], marker='o', alpha=0.5, zorder=10)
    ax3.errorbar(lc.lcs[0].t.loc[region3_ix,'MJD'], lc.lcs[0].t.loc[region3_ix,'uJy'], yerr=lc.lcs[0].t.loc[region3_ix,lc.dflux_colnames[0]], fmt='none', ecolor=colors[2], elinewidth=1, capsize=1.2, c=sn_flux, alpha=0.5, zorder=10)
    ax3.scatter(lc.lcs[0].t.loc[region3_ix,'MJD'], lc.lcs[0].t.loc[region3_ix,'uJy'], s=marker_size, lw=marker_edgewidth, color=colors[2], marker='o', alpha=0.5, zorder=10)

    ax3.axhline(y=region2b_mean, color=colors[1], linestyle='dashed', label='Region 2 mean')
    ax3.axhline(y=region3_mean, color=colors[2], linestyle='dashed', label='Region 3 mean')

    limits = set_xylimits(lc, limits, indices=AorB(region2_ix[-40:], region3_ix[:40]))
    ax3.set_ylim(limits[2], limits[3])
    ax3.legend(facecolor='white', framealpha=1)

    for ax in (ax1, ax2, ax3):
        ax.minorticks_on()
        ax.tick_params(direction='in', which='both')
        ax.set_xlabel('MJD')
        ax.set_ylabel('uJy')

if template_correction:
    plot_template_correction(lc, limits=[None]*4)

In [None]:
# Optionally enter manual offsets for each region (set each to None for automatic correction)
global_offset = None
region1_offset = None
region2_offset = None
region3_offset = None

In [None]:
if template_correction:
    output = lc.template_correction(maskval=0x1|0x2|0x400000|0x800000, 
                                    region1_offset=region1_offset, 
                                    region2_offset=region2_offset, 
                                    region3_offset=region3_offset)
    f.write(f'\n\n## ATLAS template change correction\n{output}')
else:
    print('Skipping template correction')

In [None]:
if template_correction:
    plot_template_correction(lc, limits=[None]*4)

## Optional: save the SN light curve with the new `'Mask'` and `'duJy_new'` columns

Hex values in the `'Mask'` column for each cut's flag:
- Uncertainty cut: 0x2
- Chi-square cut: 0x1
- Control light curve cut: 0x400000
- Bad day (for averaged light curves): 0x800000

You can combine these hex values with a bitwise OR to define which combination of cuts marks a measurement as "bad".
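A short sketch of how such a combination might be built and applied is shown below. The hex values are the ones listed above; the constant names, the helper `is_bad`, and the sample mask values are hypothetical and only serve the illustration.

```python
# Flag values as listed above (constant names are illustrative)
FLAG_CHISQUARE = 0x1
FLAG_UNCERTAINTY = 0x2
FLAG_CONTROLS_BAD = 0x400000
FLAG_AVG_BADDAY = 0x800000

def is_bad(mask_value, bad_mask):
    """True if the measurement carries at least one of the flags in bad_mask."""
    return (mask_value & bad_mask) != 0

# Treat the chi-square, uncertainty, and control light curve cuts as "bad"
bad_mask = FLAG_CHISQUARE | FLAG_UNCERTAINTY | FLAG_CONTROLS_BAD
print(hex(bad_mask))

# Made-up 'Mask' values for four measurements
masks = [0x0, 0x2, 0x800000, 0x400001]
good = [m for m in masks if not is_bad(m, bad_mask)]
print([hex(m) for m in good])
```

Here the measurement flagged only as a bad day (`0x800000`) is kept, because that flag was left out of `bad_mask`; adding `FLAG_AVG_BADDAY` to the OR would exclude it as well.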

In [None]:
# save the SN and control light curves
save_lc = False

# save the SN and control averaged light curves
save_avglc = False

In [None]:
if save_lc:
    print('Saving light curve with updated mask column...')
    lc._save(source_dir, filt=filt, overwrite=True)

In [None]:
if save_avglc:
    print('Saving averaged light curve with updated mask column...')
    avglc._save(source_dir, filt=filt, overwrite=True)

## Summary

In [None]:
# print summary of all cuts and corrections

f.close()
with open(f'{source_dir}/{tnsname}/{tnsname}_output.md') as f1:
    content = f1.read()
print()
print(content)