# Single File Calibration

**by Josh Dillon, Aaron Parsons, Tyler Cox, and Zachary Martinot**, last updated July 15, 2024

This notebook is designed to infer as much information as possible about the array from a single file, including pushing the calibration and RFI mitigation as far as possible. Calibration includes redundant-baseline calibration, RFI-based calibration of delay slopes, model-based calibration of overall amplitudes, and a full per-frequency phase gradient absolute calibration if abscal model files are available.

Here's a set of links to skip to particular figures and tables:
# [• Figure 1: RFI Flagging](#Figure-1:-RFI-Flagging)
# [• Figure 2: Plot of autocorrelations with classifications](#Figure-2:-Plot-of-autocorrelations-with-classifications)
# [• Figure 3: Summary of antenna classifications prior to calibration](#Figure-3:-Summary-of-antenna-classifications-prior-to-calibration)
# [• Figure 4: Redundant calibration of a single baseline group](#Figure-4:-Redundant-calibration-of-a-single-baseline-group)
# [• Figure 5: Absolute calibration of redcal degeneracies](#Figure-5:-Absolute-calibration-of-redcal-degeneracies)
# [• Figure 6: Relative Phase Calibration](#Figure-6:-Relative-Phase-Calibration)
# [• Figure 7: chi^2 per antenna across the array](#Figure-7:-chi^2-per-antenna-across-the-array)
# [• Figure 8: Summary of antenna classifications after redundant calibration](#Figure-8:-Summary-of-antenna-classifications-after-redundant-calibration)
# [• Table 1: Complete summary of per antenna classifications](#Table-1:-Complete-summary-of-per-antenna-classifications)


In [2]:
import time
# wall-clock time at the top of the notebook
tstart = time.time()
# IPython shell escape: prints the name of the machine the notebook is running on
!hostname

wario


In [7]:
import os
import time
# set before h5py is imported on the next line
# NOTE(review): presumably so HDF5 picks the setting up when the library loads — confirm
os.environ['HDF5_USE_FILE_LOCKING'] = 'FALSE'
import h5py
import hdf5plugin  # REQUIRED to have the compression plugins available
import numpy as np
from scipy import constants, interpolate
import copy
import glob
import re
import matplotlib
import matplotlib.pyplot as plt
import pandas as pd
# allow up to 1000 rows when a DataFrame is displayed
pd.set_option('display.max_rows', 1000)
from uvtools.plot import plot_antpos, plot_antclass
from hera_qm import ant_metrics, ant_class, xrfi
from hera_cal import io, utils, redcal, apply_cal, datacontainer, abscal
from hera_filters import dspec
#from hera_notebook_templates.data import DATA_PATH as HNBT_DATA
from IPython.display import display, HTML
import linsolve
# inject CSS so the notebook container uses the full page width
display(HTML("<style>.container { width:100% !important; }</style>"))
_ = np.seterr(all='ignore')  # get rid of red warnings
# IPython magic: render inline figures in 'retina' format
%config InlineBackend.figure_format = 'retina'

In [2]:
# this enables better memory management on linux
import ctypes
def malloc_trim():
    """Best-effort call to ``malloc_trim(0)`` from ``libc.so.6``.

    If loading the library or making the call raises ``OSError`` (for
    example because ``libc.so.6`` is not available on this platform), the
    error is swallowed and nothing happens. Always returns ``None``.
    """
    try:
        libc = ctypes.CDLL('libc.so.6')
        libc.malloc_trim(0)
    except OSError:
        return None

## Parse inputs and outputs

To use this notebook interactively, you will have to provide a sum filename path if none exists as an environment variable. All other parameters have reasonable default values.


In [3]:
# figure out whether to save results
SAVE_RESULTS = os.environ.get("SAVE_RESULTS", "TRUE").upper() == "TRUE"
SAVE_OMNIVIS_FILE = os.environ.get("SAVE_OMNIVIS_FILE", "FALSE").upper() == "TRUE"


# get infile names: the SUM_FILE environment variable takes precedence, and the hard-coded
# path is only the fallback used when the variable is not set (e.g. interactive runs)
SUM_FILE = os.environ.get(
    "SUM_FILE",
    "/safepool/rbyrne/hera_data/H6C-data/2459861/zen.2459861.45004.sum.abs_calibrated.red_avg.uvh5",
)
# swap 'sum' -> 'diff' in the file name only, so that a parent directory whose name happens
# to contain 'sum' is left untouched
DIFF_FILE = os.path.join(os.path.dirname(SUM_FILE), os.path.basename(SUM_FILE).replace('sum', 'diff'))

# get outfilenames (all None when results are not being saved)
AM_FILE = (SUM_FILE.replace('.uvh5', '.ant_metrics.hdf5') if SAVE_RESULTS else None)
ANTCLASS_FILE = (SUM_FILE.replace('.uvh5', '.ant_class.csv') if SAVE_RESULTS else None)
OMNICAL_FILE = (SUM_FILE.replace('.uvh5', '.omni.calfits') if SAVE_RESULTS else None)
OMNIVIS_FILE = (SUM_FILE.replace('.uvh5', '.omni_vis.uvh5') if SAVE_RESULTS else None)

# echo every parsed value; a plain namespace lookup replaces eval() on the name
for fname in ['SUM_FILE', 'DIFF_FILE', 'AM_FILE', 'ANTCLASS_FILE', 'OMNICAL_FILE', 'OMNIVIS_FILE', 'SAVE_RESULTS', 'SAVE_OMNIVIS_FILE']:
    print(f"{fname} = '{globals()[fname]}'")

SUM_FILE = '/safepool/rbyrne/hera_data/H6C-data/2459861/zen.2459861.45004.sum.abs_calibrated.red_avg.uvh5'
DIFF_FILE = '/safepool/rbyrne/hera_data/H6C-data/2459861/zen.2459861.45004.diff.abs_calibrated.red_avg.uvh5'
AM_FILE = '/safepool/rbyrne/hera_data/H6C-data/2459861/zen.2459861.45004.sum.abs_calibrated.red_avg.ant_metrics.hdf5'
ANTCLASS_FILE = '/safepool/rbyrne/hera_data/H6C-data/2459861/zen.2459861.45004.sum.abs_calibrated.red_avg.ant_class.csv'
OMNICAL_FILE = '/safepool/rbyrne/hera_data/H6C-data/2459861/zen.2459861.45004.sum.abs_calibrated.red_avg.omni.calfits'
OMNIVIS_FILE = '/safepool/rbyrne/hera_data/H6C-data/2459861/zen.2459861.45004.sum.abs_calibrated.red_avg.omni_vis.uvh5'
SAVE_RESULTS = 'True'
SAVE_OMNIVIS_FILE = 'False'


### Parse settings
Load settings relating to the operation of the notebook, then print what was loaded (or default).

In [4]:
# parse plotting settings
PLOT = os.environ.get("PLOT", "TRUE").upper() == "TRUE"
if PLOT:
    %matplotlib inline

# parse omnical settings
OC_MAX_DIMS = int(os.environ.get("OC_MAX_DIMS", 4))
OC_MIN_DIM_SIZE = int(os.environ.get("OC_MIN_DIM_SIZE", 8))
OC_SKIP_OUTRIGGERS = os.environ.get("OC_SKIP_OUTRIGGERS", "FALSE").upper() == "TRUE"
OC_MIN_BL_LEN = float(os.environ.get("OC_MIN_BL_LEN", 1))
OC_MAX_BL_LEN = float(os.environ.get("OC_MAX_BL_LEN", 1e100))
OC_MAXITER = int(os.environ.get("OC_MAXITER", 50))
OC_MAX_RERUN = int(os.environ.get("OC_MAX_RERUN", 4))
OC_RERUN_MAXITER = int(os.environ.get("OC_MAXITER", 25))
OC_MAX_CHISQ_FLAGGING_DYNAMIC_RANGE = float(os.environ.get("OC_MAX_CHISQ_FLAGGING_DYNAMIC_RANGE", 1))
OC_USE_PRIOR_SOL = os.environ.get("OC_USE_PRIOR_SOL", "FALSE").upper() == "TRUE"
OC_PRIOR_SOL_FLAG_THRESH = float(os.environ.get("OC_PRIOR_SOL_FLAG_THRESH", .95))
OC_USE_GPU = os.environ.get("SAVE_RESULTS", "FALSE").upper() == "TRUE"

# parse RFI settings
RFI_DPSS_HALFWIDTH = float(os.environ.get("RFI_DPSS_HALFWIDTH", 300e-9))
RFI_NSIG = float(os.environ.get("RFI_NSIG", 4))

# parse abscal settings
ABSCAL_MODEL_FILES_GLOB = os.environ.get("ABSCAL_MODEL_FILES_GLOB", None)
ABSCAL_MIN_BL_LEN = float(os.environ.get("ABSCAL_MIN_BL_LEN", 1.0))
ABSCAL_MAX_BL_LEN = float(os.environ.get("ABSCAL_MAX_BL_LEN", 140.0))
CALIBRATE_CROSS_POLS = os.environ.get("CALIBRATE_CROSS_POLS", "FALSE").upper() == "TRUE"

# print settings
for setting in ['PLOT', 'SAVE_RESULTS', 'OC_MAX_DIMS', 'OC_MIN_DIM_SIZE', 'OC_SKIP_OUTRIGGERS', 
                'OC_MIN_BL_LEN', 'OC_MAX_BL_LEN', 'OC_MAXITER', 'OC_MAX_RERUN', 'OC_RERUN_MAXITER', 
                'OC_MAX_CHISQ_FLAGGING_DYNAMIC_RANGE', 'OC_USE_PRIOR_SOL', 'OC_PRIOR_SOL_FLAG_THRESH', 
                'OC_USE_GPU', 'RFI_DPSS_HALFWIDTH', 'RFI_NSIG', 'ABSCAL_MODEL_FILES_GLOB', 
                'ABSCAL_MIN_BL_LEN', 'ABSCAL_MAX_BL_LEN', "CALIBRATE_CROSS_POLS"]:
    print(f'{setting} = {eval(setting)}')

PLOT = True
SAVE_RESULTS = True
OC_MAX_DIMS = 4
OC_MIN_DIM_SIZE = 8
OC_SKIP_OUTRIGGERS = False
OC_MIN_BL_LEN = 1.0
OC_MAX_BL_LEN = 1e+100
OC_MAXITER = 50
OC_MAX_RERUN = 4
OC_RERUN_MAXITER = 25
OC_MAX_CHISQ_FLAGGING_DYNAMIC_RANGE = 1.0
OC_USE_PRIOR_SOL = False
OC_PRIOR_SOL_FLAG_THRESH = 0.95
OC_USE_GPU = False
RFI_DPSS_HALFWIDTH = 3e-07
RFI_NSIG = 4.0
ABSCAL_MODEL_FILES_GLOB = None
ABSCAL_MIN_BL_LEN = 1.0
ABSCAL_MAX_BL_LEN = 140.0
CALIBRATE_CROSS_POLS = False


### Parse bounds
Load settings related to classifying antennas as good, suspect, or bad, then print what was loaded (or default).

In [5]:
# ant_metrics bounds for low correlation / dead antennas
am_corr_bad = (0, float(os.getenv("AM_CORR_BAD", 0.3)))
am_corr_suspect = (
    float(os.getenv("AM_CORR_BAD", 0.3)),
    float(os.getenv("AM_CORR_SUSPECT", 0.5)),
)

# ant_metrics bounds for cross-polarized antennas
am_xpol_bad = (-1, float(os.getenv("AM_XPOL_BAD", -0.1)))
am_xpol_suspect = (
    float(os.getenv("AM_XPOL_BAD", -0.1)),
    float(os.getenv("AM_XPOL_SUSPECT", 0)),
)

# bounds on solar altitude (in degrees); both ranges meet at SUSPECT_SOLAR_ALTITUDE
good_solar_altitude = (-90, float(os.getenv("SUSPECT_SOLAR_ALTITUDE", 0)))
suspect_solar_altitude = (float(os.getenv("SUSPECT_SOLAR_ALTITUDE", 0)), 90)

# bounds on zeros in spectra
good_zeros_per_eo_spectrum = (0, int(os.getenv("MAX_ZEROS_PER_EO_SPEC_GOOD", 2)))
suspect_zeros_per_eo_spectrum = (0, int(os.getenv("MAX_ZEROS_PER_EO_SPEC_SUSPECT", 8)))

# bounds on autocorrelation power
auto_power_good = (
    float(os.getenv("AUTO_POWER_GOOD_LOW", 5)),
    float(os.getenv("AUTO_POWER_GOOD_HIGH", 30)),
)
auto_power_suspect = (
    float(os.getenv("AUTO_POWER_SUSPECT_LOW", 1)),
    float(os.getenv("AUTO_POWER_SUSPECT_HIGH", 60)),
)

# bounds on autocorrelation slope
auto_slope_good = (
    float(os.getenv("AUTO_SLOPE_GOOD_LOW", -0.4)),
    float(os.getenv("AUTO_SLOPE_GOOD_HIGH", 0.4)),
)
auto_slope_suspect = (
    float(os.getenv("AUTO_SLOPE_SUSPECT_LOW", -0.6)),
    float(os.getenv("AUTO_SLOPE_SUSPECT_HIGH", 0.6)),
)

# bounds on autocorrelation RFI
auto_rfi_good = (0, float(os.getenv("AUTO_RFI_GOOD", 1.5)))
auto_rfi_suspect = (0, float(os.getenv("AUTO_RFI_SUSPECT", 2)))

# bounds on autocorrelation shape
auto_shape_good = (0, float(os.getenv("AUTO_SHAPE_GOOD", 0.1)))
auto_shape_suspect = (0, float(os.getenv("AUTO_SHAPE_SUSPECT", 0.2)))

# bound on per-xengine non-noiselike power in diff
bad_xengine_zcut = float(os.getenv("BAD_XENGINE_ZCUT", 10.0))

# bounds on chi^2 per antenna in omnical
oc_cspa_good = (0, float(os.getenv("OC_CSPA_GOOD", 2)))
oc_cspa_suspect = (0, float(os.getenv("OC_CSPA_SUSPECT", 3)))

# print bounds, one "name = value" line per bound, in the order they were parsed
for bound in (
    'am_corr_bad', 'am_corr_suspect', 'am_xpol_bad', 'am_xpol_suspect',
    'good_solar_altitude', 'suspect_solar_altitude',
    'good_zeros_per_eo_spectrum', 'suspect_zeros_per_eo_spectrum',
    'auto_power_good', 'auto_power_suspect', 'auto_slope_good', 'auto_slope_suspect',
    'auto_rfi_good', 'auto_rfi_suspect', 'auto_shape_good', 'auto_shape_suspect',
    'bad_xengine_zcut', 'oc_cspa_good', 'oc_cspa_suspect',
):
    print(f'{bound} = {globals()[bound]}')

am_corr_bad = (0, 0.3)
am_corr_suspect = (0.3, 0.5)
am_xpol_bad = (-1, -0.1)
am_xpol_suspect = (-0.1, 0.0)
good_solar_altitude = (-90, 0.0)
suspect_solar_altitude = (0.0, 90)
good_zeros_per_eo_spectrum = (0, 2)
suspect_zeros_per_eo_spectrum = (0, 8)
auto_power_good = (5.0, 30.0)
auto_power_suspect = (1.0, 60.0)
auto_slope_good = (-0.4, 0.4)
auto_slope_suspect = (-0.6, 0.6)
auto_rfi_good = (0, 1.5)
auto_rfi_suspect = (0, 2.0)
auto_shape_good = (0, 0.1)
auto_shape_suspect = (0, 0.2)
bad_xengine_zcut = 10.0
oc_cspa_good = (0, 2.0)
oc_cspa_suspect = (0, 3.0)


## Load sum and diff data

In [8]:
# time the read so its duration can be reported at the end of the cell
read_start = time.time()
hd = io.HERADataFastReader(SUM_FILE)
# only the visibilities are kept; flags and nsamples are not read
data, _, _ = hd.read(read_flags=False, read_nsamples=False)
# NOTE(review): the diff-file read below is commented out, so `hd_diff` and `diff_data` are
# never defined; every later cell that references them raises NameError until it is restored
#hd_diff = io.HERADataFastReader(DIFF_FILE)
#diff_data, _, _ = hd_diff.read(read_flags=False, read_nsamples=False, dtype=np.complex64, fix_autos_func=np.real)
print(f'Finished loading data in {(time.time() - read_start) / 60:.2f} minutes.')

Finished loading data in 0.01 minutes.


In [15]:
# every antpol (each end of every baseline listed in hd.bls), de-duplicated and sorted
ants = sorted({antpol for bl in hd.bls for antpol in utils.split_bl(bl)})
# baselines in the data whose two antenna indices match and whose two split polarizations match
auto_bls = [
    bl for bl in data
    if bl[0] == bl[1] and len(set(utils.split_pol(bl[2]))) == 1
]
# the distinct second elements (polarization labels) of the entries in `ants`
antpols = sorted({antpol[1] for antpol in ants})

In [14]:
# print basic information about the file
print(f'File: {SUM_FILE}')
# the median step in hd.times is scaled by 24 * 3600 and reported in seconds
print(f'JDs: {hd.times} ({np.median(np.diff(hd.times)) * 24 * 3600:.5f} s integrations)')
# hd.lsts is scaled by 12 / pi and reported in hours
print(f'LSTS: {hd.lsts * 12 / np.pi } hours')
# channel count, median channel spacing, and first/last frequency, all divided by 1e6 and reported in MHz
print(f'Frequencies: {len(hd.freqs)} {np.median(np.diff(hd.freqs)) / 1e6:.5f} MHz channels from {hd.freqs[0] / 1e6:.5f} to {hd.freqs[-1] / 1e6:.5f} MHz')
print(f'Antennas: {len(hd.data_ants)}')
print(f'Polarizations: {hd.pols}')

File: /safepool/rbyrne/hera_data/H6C-data/2459861/zen.2459861.45004.sum.abs_calibrated.red_avg.uvh5
JDs: [2459861.44998783 2459861.45009967] (9.66368 s integrations)
LSTS: [1.39788835 1.40058006] hours
Frequencies: 1536 0.12207 MHz channels from 46.92078 to 234.29871 MHz
Antennas: 172
Polarizations: ['nn', 'ee']


## Classify good, suspect, and bad antpols

### Run `ant_metrics`

This classifies antennas as cross-polarized, low-correlation, or dead. Such antennas are excluded from any calibration.

In [15]:
# NOTE(review): `diff_data` is only assigned by the commented-out diff read in the data-loading
# cell, so this cell currently stops with NameError (see the recorded output below)
am = ant_metrics.AntennaMetrics(SUM_FILE, DIFF_FILE, sum_data=data, diff_data=diff_data)
# the cuts passed here are the upper edges of the "bad" ranges am_xpol_bad and am_corr_bad
am.iterative_antenna_metrics_and_flagging(crossCut=am_xpol_bad[1], deadCut=am_corr_bad[1])
am.all_metrics = {}  # this saves time and disk by getting rid of per-iteration information we never use
if SAVE_RESULTS:
    am.save_antenna_metrics(AM_FILE, overwrite=True)

NameError: name 'diff_data' is not defined

In [None]:
# Turn ant metrics into classifications
# antennas whose entry in am.xants equals -1 are collected as "totally dead"
totally_dead_ants = [ant for ant, i in am.xants.items() if i == -1]
am_totally_dead = ant_class.AntennaClassification(good=[ant for ant in ants if ant not in totally_dead_ants], bad=totally_dead_ants)
# bounds-check the final 'corr' and 'corrXPol' metrics against the ranges parsed earlier
am_corr = ant_class.antenna_bounds_checker(am.final_metrics['corr'], bad=[am_corr_bad], suspect=[am_corr_suspect], good=[(0, 1)])
am_xpol = ant_class.antenna_bounds_checker(am.final_metrics['corrXPol'], bad=[am_xpol_bad], suspect=[am_xpol_suspect], good=[(-1, 1)])
# combine the three classifications into one
ant_metrics_class = am_totally_dead + am_corr + am_xpol
# abort if the first antpol of every auto baseline is classified bad
if np.all([ant_metrics_class[utils.split_bl(bl)[0]] == 'bad' for bl in auto_bls]):
    raise ValueError('All antennas are flagged for ant_metrics.')

### Mark sun-up (or high solar altitude) data as suspect

In [16]:
# a single number for the whole file: the minimum of utils.get_sun_alt over hd.times
min_sun_alt = np.min(utils.get_sun_alt(hd.times))
# every antpol gets that same value, so all antpols share one solar classification
solar_class = ant_class.antenna_bounds_checker({ant: min_sun_alt for ant in ants}, good=[good_solar_altitude], suspect=[suspect_solar_altitude])

### Classify antennas responsible for 0s in visibilities as bad: 

This classifier looks for X-engine failure or packet loss specific to an antenna which causes either the even visibilities (or the odd ones, or both) to be 0s. 

In [17]:
# NOTE(review): needs `diff_data`, which is undefined while the diff read in the data-loading
# cell is commented out (see the recorded NameError below)
zeros_class = ant_class.even_odd_zeros_checker(data, diff_data, good=good_zeros_per_eo_spectrum, suspect=suspect_zeros_per_eo_spectrum)
# abort if the first antpol of every auto baseline is classified bad
if np.all([zeros_class[utils.split_bl(bl)[0]] == 'bad' for bl in auto_bls]):
    raise ValueError('All antennas are flagged for too many even/odd zeros.')

NameError: name 'diff_data' is not defined

### Examine and classify autocorrelation power and slope

These classifiers look for antennas with too high or low power or too steep a slope.

In [18]:
auto_power_class = ant_class.auto_power_checker(data, good=auto_power_good, suspect=auto_power_suspect)
auto_slope_class = ant_class.auto_slope_checker(data, good=auto_slope_good, suspect=auto_slope_suspect, edge_cut=100, filt_size=17)
# abort if power and slope together leave no usable antenna (the recorded run below hit this)
if np.all([(auto_power_class + auto_slope_class)[utils.split_bl(bl)[0]] == 'bad' for bl in auto_bls]):
    raise ValueError('All antennas are flagged for bad autocorrelation power/slope.')
# running combination of every classification made so far; only reached if the check above passes
overall_class = auto_power_class + auto_slope_class + zeros_class + ant_metrics_class + solar_class

ValueError: All antennas are flagged for bad autocorrelation power/slope.

### Find starting set of array flags

In [None]:
# first pass: channel-difference flagger with a threshold of 5 x RFI_NSIG
antenna_flags, array_flags = xrfi.flag_autos(data, flag_method="channel_diff_flagger", nsig=RFI_NSIG * 5, 
                                             antenna_class=overall_class, flag_broadcast_thresh=.5)
# replace each per-antenna entry with the array-wide flags before they are passed to the second pass
for key in antenna_flags:
    antenna_flags[key] = array_flags
cache = {}
# second pass: DPSS-based flagger at RFI_NSIG, starting from the flags above; only the array-wide result is kept
_, array_flags = xrfi.flag_autos(data, freqs=data.freqs, flag_method="dpss_flagger",
                                 nsig=RFI_NSIG, antenna_class=overall_class,
                                 filter_centers=[0], filter_half_widths=[RFI_DPSS_HALFWIDTH],
                                 eigenval_cutoff=[1e-9], flags=antenna_flags, mode='dpss_matrix', 
                                 cache=cache, flag_broadcast_thresh=.5)

### Classify antennas based on non-noiselike diffs

In [None]:
# NOTE(review): needs `diff_data`, which is undefined while the diff read in the data-loading
# cell is commented out
xengine_diff_class = ant_class.non_noiselike_diff_by_xengine_checker(data, diff_data, flag_waterfall=array_flags, 
                                                                     antenna_class=overall_class, 
                                                                     xengine_chans=96, bad_xengine_zcut=bad_xengine_zcut)
# fold the new classification into the running total
overall_class += xengine_diff_class
# abort if the first antpol of every auto baseline is now classified bad
if np.all([overall_class[utils.split_bl(bl)[0]] == 'bad' for bl in auto_bls]):
    raise ValueError('All antennas are flagged after flagging non-noiselike diffs.')

### Examine and classify autocorrelation excess RFI and shape, finding consensus RFI mask along the way

This classifier iteratively identifies antennas for excess RFI (characterized by RMS of DPSS-filtered autocorrelations after RFI flagging) and bad shape, as determined by a discrepancy with the mean good normalized autocorrelation's shape. Along the way, it iteratively discovers a consensus array-wide RFI mask.

In [None]:
def auto_bl_zscores(data, flag_array, cache={}):
    '''Return a dict mapping each baseline in the global auto_bls to a z-score array.

    For every autocorrelation, a DPSS model is fit with dspec.fourier_filter (entries set in
    flag_array get zero weight), and the residual is divided by a noise estimate derived from
    the model, the integration time, and the channel width. Entries set in flag_array are
    np.nan in the output.

    Note that the default cache dict is created once, when the function is defined, so it is
    shared by every call that does not pass its own cache.'''
    if not auto_bls:
        return {}

    # the number of samples per integration per channel is the same for every baseline
    seconds_per_int = 24 * 3600 * np.median(np.diff(data.times))
    hz_per_chan = np.median(np.diff(data.freqs))
    n_samples = int(seconds_per_int * hz_per_chan)

    per_bl_z = {}
    for bl in auto_bls:
        # weight 1 where unflagged, 0 where flagged
        weights = np.array(np.logical_not(flag_array), dtype=np.float64)
        smooth, _, _ = dspec.fourier_filter(hd.freqs, data[bl], weights, filter_centers=[0],
                                            filter_half_widths=[RFI_DPSS_HALFWIDTH], mode='dpss_solve',
                                            suppression_factors=[1e-9], eigenval_cutoff=[1e-9], cache=cache)
        noise = np.abs(smooth) / np.sqrt(n_samples / 2)
        z = (data[bl] - smooth) / noise
        z[flag_array] = np.nan
        per_bl_z[bl] = z

    return per_bl_z

In [None]:
def rfi_from_avg_autos(data, auto_bls_to_use, prior_flags=None, nsig=RFI_NSIG):
    '''Average the baselines in auto_bls_to_use and return an RFI mask for that average.

    The average is flagged first with xrfi.flag_autos using its default flag method at 5 * nsig
    and then with the 'dpss_flagger' method at nsig. If prior_flags is given, it replaces the
    first-pass flags as the starting point of the second pass. Returns the array-wide flags
    from the second pass.'''
    avg_key = (-1, -1, 'ee')

    # effective integration count of the average; the time step is taken from the first entry
    # of the global auto_bls (not from auto_bls_to_use)
    bl_times = data.times_by_bl[auto_bls[0][0:2]]
    seconds_per_int = 24 * 3600 * np.median(np.diff(bl_times))
    hz_per_chan = np.median(np.diff(data.freqs))
    int_count = int(seconds_per_int * hz_per_chan) * len(auto_bls_to_use)
    stacked = [data[bl] for bl in auto_bls_to_use]
    avg_auto = {avg_key: np.mean(stacked, axis=0)}

    # first pass: coarse flags at 5 * nsig
    starting_flags, _ = xrfi.flag_autos(avg_auto, int_count=int_count, nsig=(nsig * 5))
    if prior_flags is not None:
        starting_flags[avg_key] = prior_flags

    # second pass: DPSS-based flagging seeded with the flags above
    _, rfi_flags = xrfi.flag_autos(
        avg_auto,
        int_count=int_count,
        flag_method='dpss_flagger',
        flags=starting_flags,
        freqs=data.freqs,
        filter_centers=[0],
        filter_half_widths=[RFI_DPSS_HALFWIDTH],
        eigenval_cutoff=[1e-9],
        nsig=nsig,
    )
    return rfi_flags

In [None]:
# Iteratively develop RFI mask, excess RFI classification, and autocorrelation shape classification
# Stages 1 and 2 each run for one pass; stage 3 repeats until the pair
# (number of bad antennas, number of flagged samples) has been seen in an earlier stage-3 pass.
stage = 1
rfi_flags = np.array(array_flags)
prior_end_states = set()
while True:
    # compute DPSS-filtered z-scores with current array-wide RFI mask
    zscores = auto_bl_zscores(data, rfi_flags)
    # RMS of the z-scores per auto; np.inf when an auto has no finite z-score at all
    rms = {bl: np.nanmean(zscores[bl]**2)**.5 if np.any(np.isfinite(zscores[bl])) else np.inf for bl in zscores}
    
    # figure out which autos to use for finding new set of flags
    candidate_autos = [bl for bl in auto_bls if overall_class[utils.split_bl(bl)[0]] != 'bad']
    if stage == 1:
        # use best half of the unflagged antennas
        med_rms = np.nanmedian([rms[bl] for bl in candidate_autos])
        autos_to_use = [bl for bl in candidate_autos if rms[bl] <= med_rms]
    elif stage == 2:
        # use all unflagged antennas which are auto RFI good, or the best half, whichever is larger
        # (auto_rfi_class here is the one computed at the end of the previous pass)
        med_rms = np.nanmedian([rms[bl] for bl in candidate_autos])
        best_half_autos = [bl for bl in candidate_autos if rms[bl] <= med_rms]
        # NOTE(review): the != 'bad' test repeats the filter already applied to candidate_autos
        good_autos = [bl for bl in candidate_autos if (overall_class[utils.split_bl(bl)[0]] != 'bad')
                      and (auto_rfi_class[utils.split_bl(bl)[0]] == 'good')]
        autos_to_use = (best_half_autos if len(best_half_autos) > len(good_autos) else good_autos)
    elif stage == 3:
        # use all unflagged antennas which are auto RFI good or suspect
        autos_to_use = [bl for bl in candidate_autos if (overall_class[utils.split_bl(bl)[0]] != 'bad')]

    # compute new RFI flags
    rfi_flags = rfi_from_avg_autos(data, autos_to_use)

    # perform auto shape and RFI classification
    # overall_class is rebuilt from the earlier classifications on every pass, so the RFI and
    # shape results of previous passes are replaced rather than accumulated
    overall_class = auto_power_class + auto_slope_class + zeros_class + ant_metrics_class + solar_class + xengine_diff_class
    auto_rfi_class = ant_class.antenna_bounds_checker(rms, good=auto_rfi_good, suspect=auto_rfi_suspect, bad=(0, np.inf))
    overall_class += auto_rfi_class
    # the mask is collapsed along axis 0: an entry is flagged if it is flagged in any row
    auto_shape_class = ant_class.auto_shape_checker(data, good=auto_shape_good, suspect=auto_shape_suspect,
                                                    flag_spectrum=np.sum(rfi_flags, axis=0).astype(bool), 
                                                    antenna_class=overall_class)
    overall_class += auto_shape_class
    
    # check for convergence by seeing whether we've previously gotten to this number of flagged antennas and channels
    if stage == 3:
        if (len(overall_class.bad_ants), np.sum(rfi_flags)) in prior_end_states:
            break
        prior_end_states.add((len(overall_class.bad_ants), np.sum(rfi_flags)))
    else:
        stage += 1

In [None]:
# classification based on the autocorrelation checks only (power, slope, RFI, shape)
auto_class = auto_power_class + auto_slope_class + auto_rfi_class + auto_shape_class
# abort if the first antpol of every auto baseline is bad in the overall classification
if np.all([overall_class[utils.split_bl(bl)[0]] == 'bad' for bl in auto_bls]):
    raise ValueError('All antennas are flagged after flagging for bad autos power/slope/rfi/shape.')

In [None]:
def rfi_plot(cls, flags=rfi_flags):
    '''Plot the first row of the average autocorrelation of all antennas not classified bad in cls.

    The full spectrum is drawn as a thin red line and the same spectrum with flagged channels
    set to np.nan is drawn on top of it, so flagged channels show up in red. The legend reports
    the number of flagged entries in flags[0]. The default for flags is the value the global
    rfi_flags had when this function was defined.'''
    usable_autos = [data[bl] for bl in auto_bls if cls[utils.split_bl(bl)[0]] != 'bad']
    mean_auto = np.mean(usable_autos, axis=0)
    freqs_mhz = hd.freqs / 1e6

    plt.figure(figsize=(12, 5), dpi=100)
    unflagged_spectrum = np.where(flags, np.nan, mean_auto)[0]
    plt.semilogy(freqs_mhz, unflagged_spectrum, label = 'Average Good or Suspect Autocorrelation', zorder=100)
    full_spectrum = np.where(False, np.nan, mean_auto)[0]
    plt.semilogy(freqs_mhz, full_spectrum, 'r', lw=.5, label=f'{np.sum(flags[0])} Channels Flagged for RFI')
    plt.legend()
    plt.xlabel('Frequency (MHz)')
    plt.ylabel('Uncalibrated Autocorrelation')
    plt.tight_layout()

# *Figure 1: RFI Flagging*

This figure shows RFI identified using the average of all autocorrelations---excluding bad antennas---for the first integration in the file. 

In [None]:
if PLOT: rfi_plot(overall_class)

In [None]:
def autocorr_plot(cls):
    '''Plot every antenna's row-averaged autocorrelation, one panel per entry of the global antpols.

    Each line is colored by the antenna's class in cls: the position of the class within
    cls.quality_classes selects dark green, goldenrod, or maroon, in that order. The legend on
    the second panel lists the capitalized class names.'''
    palette = ['darkgreen', 'goldenrod', 'maroon']
    _, panels = plt.subplots(1, 2, figsize=(14, 5), dpi=100, sharey=True, gridspec_kw={'wspace': 0})

    for panel, pol in zip(panels, antpols):
        for ant in (a for a in cls.ants if a[1] == pol):
            shade = palette[cls.quality_classes.index(cls[ant])]
            spectrum = np.mean(data[utils.join_bl(ant, ant)], axis=0)
            panel.semilogy(spectrum, color=shade, lw=.5)
        panel.set_xlabel('Channel', fontsize=12)
        panel.set_title(f'{utils.join_pol(pol, pol)}-Polarized Autos')

    panels[0].set_ylabel('Raw Autocorrelation', fontsize=12)
    # proxy line handles, one per color, paired with the class names
    handles = [matplotlib.lines.Line2D([0], [0], color=shade) for shade in palette]
    names = [cl.capitalize() for cl in cls.quality_classes]
    panels[1].legend(handles, names, ncol=1, fontsize=12, loc='upper right', framealpha=1)
    plt.tight_layout()

# *Figure 2: Plot of autocorrelations with classifications*
This figure shows a plot of all autocorrelations in the array, split by polarization. 
Antennas are classified based on their autocorrelations into good, suspect, and bad, by examining power, slope, and RFI-occupancy. 

In [None]:
if PLOT: autocorr_plot(auto_class)

### Summarize antenna classification prior to redundant-baseline calibration

In [None]:
def array_class_plot(cls, extra_label=""):
    '''Draw the antenna classification cls on the array layout with plot_antclass.

    Antennas numbered below 320 go in a panel titled 'HERA Core' followed by extra_label. If
    the data contain antennas numbered 320 or above, they are drawn in a second, narrower
    'Outriggers' panel with radius=50.'''
    core = [ant for ant in hd.data_ants if ant < 320]
    outriggers = [ant for ant in hd.data_ants if ant >= 320]
    core_title = f'HERA Core{extra_label}'

    if not outriggers:
        # single-panel layout
        _, ax = plt.subplots(1, 1, figsize=(9, 6), dpi=100)
        plot_antclass(hd.antpos, cls, ax=ax, ants=core, legend=False, title=core_title)
        return

    # two-panel layout: core on the left (twice as wide), outriggers on the right
    _, (core_ax, outrigger_ax) = plt.subplots(1, 2, figsize=(14, 6), dpi=100, gridspec_kw={'width_ratios': [2, 1]})
    plot_antclass(hd.antpos, cls, ax=core_ax, ants=core, legend=False, title=core_title)
    plot_antclass(hd.antpos, cls, ax=outrigger_ax, ants=outriggers, radius=50, title='Outriggers')

# *Figure 3: Summary of antenna classifications prior to calibration*
This figure shows the location and classification of all antennas prior to calibration. 
Antennas are split along the diagonal, with ee-polarized antpols represented by the southeast half of each antenna and nn-polarized antpols represented by the northwest half.
Outriggers are split from the core and shown at exaggerated size in the right-hand panel. This classification includes `ant_metrics`, a count of the zeros in the even or odd visibilities, and autocorrelation power, slope, and RFI occupancy.
An antenna classified as bad in *any* classification will be considered bad. 
An antenna marked as suspect in *any* classification will be considered suspect unless it is also classified as bad elsewhere.

In [None]:
if PLOT: array_class_plot(overall_class)

In [None]:
# delete diffs to save memory
# Each name is dropped only if it exists: while the diff-file read is commented out,
# `diff_data` and `hd_diff` are never created and a plain `del` would raise NameError.
for _name in ('diff_data', 'hd_diff', 'cache'):
    globals().pop(_name, None)
del _name
malloc_trim()

## Perform redundant-baseline calibration

In [9]:
def classify_off_grid(reds, all_ants):
    '''Classify each antenna in all_ants as good if it appears in any baseline of reds, bad otherwise.

    Returns an ant_class.AntennaClassification built from the two lists.'''
    # collect both ends of every baseline in every redundant group
    gridded = set()
    for red in reds:
        for bl in red:
            gridded.update(utils.split_bl(bl))

    present = [ant for ant in all_ants if ant in gridded]
    missing = [ant for ant in all_ants if ant not in gridded]
    return ant_class.AntennaClassification(good=present, bad=missing)

In [10]:
def per_pol_filter_reds(reds, pols=['nn', 'ee'], **kwargs):
    '''Run redcal.filter_reds once per polarization and concatenate the results in pols order.

    Each call gets its own deep copy of reds and a single-element pols list; all other keyword
    arguments are passed through unchanged.'''
    filtered = []
    for pol in pols:
        filtered.extend(redcal.filter_reds(copy.deepcopy(reds), pols=[pol], **kwargs))
    return filtered

In [11]:
def check_if_whole_pol_flagged(redcal_class, pols=['Jee', 'Jnn']):
    '''Checks if an entire polarization is flagged. If it is, returns True and marks all antennas as bad in redcal_class.

    Arguments:
        redcal_class: antenna classification, indexable and assignable by antenna, iterable
            over antennas, with an .ants attribute; antenna[1] is compared against pols.
        pols: antenna polarizations to check. Any number of entries is supported (the previous
            np.logical_or(*...) form only worked for exactly two).

    Returns:
        True if, for at least one polarization in pols, every antenna with that polarization
        is 'bad' (a polarization with no antennas at all counts as entirely flagged); in that
        case a message is printed and every antenna in redcal_class is set to 'bad'.
        False otherwise, leaving redcal_class untouched.
    '''
    whole_pol_bad = any(
        all(redcal_class[ant] == 'bad' for ant in redcal_class.ants if ant[1] == pol)
        for pol in pols
    )
    if not whole_pol_bad:
        return False

    print('An entire polarization has been flagged. Stopping redcal.')
    for ant in redcal_class:
        redcal_class[ant] = 'bad'
    return True

In [12]:
def recheck_chisq(cspa, sol, cutoff, avg_alg):
    '''Recompute chisq per ant without apparently bad antennas to see if any antennas get better.

    Reads the notebook globals rfi_flags, hd, data, and fr_settings.

    Arguments:
        cspa: dictionary mapping antenna to its chi^2 per antenna array.
        sol: redcal.RedSol whose gains, reds, and vis are used for the recomputation.
        cutoff: antennas whose averaged chi^2 exceeds this are left out of the recomputation.
        avg_alg: function that reduces an array containing nans to a single number (e.g. np.nanmean).

    Returns:
        Dictionary mapping each antenna in cspa to its averaged chi^2 (RFI-flagged samples ignored),
        replaced by the recomputed average wherever that is positive and smaller.
    '''
    def _avg(chisq):
        # samples flagged for RFI are excluded from the average
        return avg_alg(np.where(rfi_flags, np.nan, chisq))

    avg_cspa = {ant: _avg(cspa[ant]) for ant in cspa}

    # new solution that only keeps the gains of antennas at or below the cutoff
    sol2 = redcal.RedSol(sol.reds, gains={ant: sol[ant] for ant in avg_cspa if avg_cspa[ant] <= cutoff}, vis=sol.vis)
    new_chisq_per_ant = {ant: np.array(cspa[ant]) for ant in sol2.gains}

    # only recompute if both polarizations survive filtering down to the remaining antennas
    pols_left = {bl[2] for red in per_pol_filter_reds(sol2.reds, ants=sol2.gains.keys(), antpos=hd.data_antpos, **fr_settings)
                 for bl in red}
    if len(pols_left) >= 2:
        # presumably updates new_chisq_per_ant in place -- confirm against redcal.expand_omni_gains
        redcal.expand_omni_gains(sol2, sol2.reds, data, chisq_per_ant=new_chisq_per_ant)

    for ant in avg_cspa:
        if ant not in new_chisq_per_ant:
            continue
        new_chisq = new_chisq_per_ant[ant]
        if not np.any(np.isfinite(new_chisq)) or np.all(np.isclose(new_chisq, 0)):
            continue
        # average the *recomputed* chi^2; averaging cspa[ant] again would just reproduce avg_cspa[ant]
        new_avg_cspa = _avg(new_chisq)
        if new_avg_cspa > 0:
            avg_cspa[ant] = np.min([avg_cspa[ant], new_avg_cspa])
    return avg_cspa

### Perform iterative `redcal`

In [16]:
# settings shared by the reds-filtering calls used for redundant calibration
fr_settings = dict(max_dims=OC_MAX_DIMS, min_dim_size=OC_MIN_DIM_SIZE, min_bl_cut=OC_MIN_BL_LEN, max_bl_cut=OC_MAX_BL_LEN)

# build all ee and nn redundant groups, then filter them one polarization at a time
# NOTE(review): no antennas are excluded here (ex_ants=[]); the previous, commented-out form of this call was
# per_pol_filter_reds(reds, ex_ants=overall_class.bad_ants, antpos=hd.data_antpos, **fr_settings) -- confirm which is intended
reds = per_pol_filter_reds(
    redcal.get_reds(hd.data_antpos, pols=['ee', 'nn'], pol_mode='2pol', bl_error_tol=2.0),
    ex_ants=[],
    antpos=hd.data_antpos,
    **fr_settings,
)

# optionally drop every baseline involving an outrigger (antenna number 320 or higher)
if OC_SKIP_OUTRIGGERS:
    reds = redcal.filter_reds(reds, ex_ants=[ant for ant in ants if ant[0] >= 320])

# antennas that no longer appear in any redundant group are classified as bad
redcal_class = classify_off_grid(reds, ants)

In [17]:
if OC_USE_PRIOR_SOL:
    # Find closest omnical file
    # The glob pattern is OMNICAL_FILE with its 5th- and 4th-from-last dot-separated fields replaced by a
    # single wildcard (presumably the integer and fractional parts of the JD -- confirm against the file naming)
    omnical_files = sorted(glob.glob('.'.join(OMNICAL_FILE.split('.')[:-5]) + '.*.' + '.'.join(OMNICAL_FILE.split('.')[-3:])))
    if len(omnical_files) == 0:
        OC_USE_PRIOR_SOL = False
    else:
        # the last decimal number in each file name is taken as that file's JD
        # (raw string: "\d" and "\." are regex escapes, not valid Python string escapes)
        omnical_jds = np.array([float(re.findall(r"\d+\.\d+", ocf)[-1]) for ocf in omnical_files])
        closest_omnical = omnical_files[np.argmin(np.abs(omnical_jds - data.times[0]))]

        # Load closest omnical file and use it if the antenna flagging is not too dissimilar
        hc = io.HERACal(closest_omnical)
        prior_gains, prior_flags, _, _ = hc.read()
        # antennas that are neither bad according to redcal_class nor fully flagged in the prior solution
        not_bad_not_prior_flagged = [ant for ant in overall_class if ant not in redcal_class.bad_ants and not np.all(prior_flags[ant])]
        if (len(redcal_class.bad_ants) == len(redcal_class.ants)):
            OC_USE_PRIOR_SOL = False  # all antennas flagged
        elif (len(not_bad_not_prior_flagged) / (len(redcal_class.ants) - len(redcal_class.bad_ants))) < OC_PRIOR_SOL_FLAG_THRESH:
            OC_USE_PRIOR_SOL = False  # too many antennas unflagged that were flagged in the prior sol
        else:
            print(f'Using {closest_omnical} as a starting point for redcal.')

In [18]:
# Run the first stage of redundant calibration, or build a starting point from a prior solution.
# Every branch assigns `sol`; `meta` and `max_dly` are only assigned in some branches (see notes below).
redcal_start = time.time()
rc_settings = {'max_dims': OC_MAX_DIMS, 'oc_conv_crit': 1e-10, 'gain': 0.4, 'run_logcal': False,
               'oc_maxiter': OC_MAXITER, 'check_after': OC_MAXITER, 'use_gpu': OC_USE_GPU}

if check_if_whole_pol_flagged(redcal_class):
    # skip redcal, initialize empty sol and meta 
    sol = redcal.RedSol(reds)
    meta = {'chisq': None, 'chisq_per_ant': None}
else:    
    if OC_USE_PRIOR_SOL:
        # use prior unflagged gains and data to create starting point for next step
        # NOTE(review): this branch assigns neither `meta` nor `max_dly`. `max_dly` is only read later when
        # OC_USE_PRIOR_SOL is False, but `meta` is read unconditionally later on -- confirm it is always
        # assigned by the iterative rerun before then.
        ants_in_reds = set([ant for red in reds for bl in red for ant in utils.split_bl(bl)])
        sol = redcal.RedSol(reds=reds, gains={ant: prior_gains[ant] for ant in not_bad_not_prior_flagged})
        # first solve for visibilities using only baselines whose two antennas both have prior gains...
        reds_to_update = [[bl for bl in red if (utils.split_bl(bl)[0] in sol.gains) and (utils.split_bl(bl)[1] in sol.gains)] for red in reds]
        reds_to_update = [red for red in reds_to_update if len(red) > 0]
        sol.update_vis_from_data(data, reds_to_update=reds_to_update)
        # ...then expand the gains to the remaining antennas in reds and update all visibilities
        redcal.expand_omni_gains(sol, reds, data)
        sol.update_vis_from_data(data)
    else:
        # perform first stage of redundant calibration 
        meta, sol = redcal.redundantly_calibrate(data, reds, **rc_settings)
        max_dly = np.max(np.abs(list(meta['fc_meta']['dlys'].values())))  # Needed for RFI delay-slope cal
        median_cspa = recheck_chisq(meta['chisq_per_ant'], sol, oc_cspa_suspect[1] * 5, np.nanmedian)
        # remove particularly bad antennas (5x the bound on median, not mean)
        cspa_class = ant_class.antenna_bounds_checker(median_cspa, good=(oc_cspa_good[0], oc_cspa_suspect[1] * 5), bad=[(-np.inf, np.inf)])
        redcal_class += cspa_class
        print(f'Removing {cspa_class.bad_ants} for >5x high median chi^2.')
        for ant in cspa_class.bad_ants:
            print(f'\t{ant}: {median_cspa[ant]:.3f}')
        
    malloc_trim()

KeyError: "Cannot find either (4, 5, 'nn') or (5, 4, 'nn') in this DataContainer."

In [None]:
# iteratively rerun redundant calibration
# Each pass refilters reds without the currently bad antennas, reruns redcal starting from the previous
# solution, and flags antennas whose mean chi^2 is out of bounds. The loop stops when a pass flags no new
# antennas, when an entire polarization is flagged, or on the pass with i == OC_MAX_RERUN.
redcal_done = False
rc_settings['oc_maxiter'] = rc_settings['check_after'] = OC_RERUN_MAXITER
for i in range(OC_MAX_RERUN + 1):
    # refilter reds and update classification to reflect new off-grid ants, if any
    reds = per_pol_filter_reds(reds, ex_ants=(overall_class + redcal_class).bad_ants, antpos=hd.data_antpos, **fr_settings)
    reds = sorted(reds, key=len, reverse=True)  # largest redundant groups first
    redcal_class += classify_off_grid(reds, ants)
    ants_in_reds = set([ant for red in reds for bl in red for ant in utils.split_bl(bl)])
    
    # check to see whether we're done
    # (the last pass, i == OC_MAX_RERUN, only refilters and reclassifies; it never reruns redcal)
    if check_if_whole_pol_flagged(redcal_class) or redcal_done or (i == OC_MAX_RERUN):
        break

    # re-run redundant calibration using previous solution, updating bad and suspicious antennas
    meta, sol = redcal.redundantly_calibrate(data, reds, sol0=sol, **rc_settings)
    malloc_trim()
    
    # recompute chi^2 for bad antennas without bad antennas to make sure they are actually bad
    mean_cspa = recheck_chisq(meta['chisq_per_ant'], sol, oc_cspa_suspect[1], np.nanmean)
    
    # remove bad antennas
    cspa_class = ant_class.antenna_bounds_checker(mean_cspa, good=oc_cspa_good, suspect=oc_cspa_suspect, bad=[(-np.inf, np.inf)])
    for ant in cspa_class.bad_ants:
        # compare against the largest mean chi^2 of any antenna in this pass
        if mean_cspa[ant] < np.max(list(mean_cspa.values())) / OC_MAX_CHISQ_FLAGGING_DYNAMIC_RANGE:
            cspa_class[ant] = 'suspect'  # reclassify as suspect if they are much better than the worst antennas
    redcal_class += cspa_class
    print(f'Removing {cspa_class.bad_ants} for high mean unflagged chi^2.')
    for ant in cspa_class.bad_ants:
        print(f'\t{ant}: {mean_cspa[ant]:.3f}')

    if len(cspa_class.bad_ants) == 0:
        redcal_done = True  # no new antennas to flag

print(f'Finished redcal in {(time.time() - redcal_start) / 60:.2f} minutes.')

In [None]:
# fold the classification produced during redundant calibration into the overall antenna classification
overall_class += redcal_class

### Expand solution to include calibratable baselines excluded from redcal (e.g. because they were too long)

In [None]:
# Rebuild the redundant groups from scratch, this time with no baseline length cuts, excluding only antennas
# that are bad in the ant_metrics, solar, zeros, auto, or xengine_diff classifications
expanded_reds = per_pol_filter_reds(
    redcal.get_reds(hd.data_antpos, pols=['ee', 'nn'], pol_mode='2pol', bl_error_tol=2.0),
    ex_ants=(ant_metrics_class + solar_class + zeros_class + auto_class + xengine_diff_class).bad_ants,
    max_dims=OC_MAX_DIMS,
    min_dim_size=OC_MIN_DIM_SIZE,
)

# optionally drop every baseline involving an outrigger (antenna number 320 or higher)
if OC_SKIP_OUTRIGGERS:
    expanded_reds = redcal.filter_reds(expanded_reds, ex_ants=[ant for ant in ants if ant[0] >= 320])

# nothing to expand from if redcal produced no gains
if len(sol.gains) > 0:
    redcal.expand_omni_vis(sol, expanded_reds, data, chisq=meta['chisq'], chisq_per_ant=meta['chisq_per_ant'])

In [None]:
# now figure out flags, nsamples etc.
# gains are flagged wherever they are non-finite, and everywhere for antennas classified as bad
omni_flags = {ant: (~np.isfinite(g)) | (ant in overall_class.bad_ants) for ant, g in sol.gains.items()}
# visibility solutions are flagged wherever they are non-finite
vissol_flags = datacontainer.RedDataContainer({bl: ~np.isfinite(v) for bl, v in sol.vis.items()}, reds=sol.vis.reds)
# every baseline in data gets nsamples of 1; note that all keys share the same array object
single_nsamples_array = np.ones((len(hd.times), len(hd.freqs)), dtype=float)
nsamples = datacontainer.DataContainer({bl: single_nsamples_array for bl in data})
# count samples per redundant group, restricted to groups with a visibility solution and to not-bad antennas
vissol_nsamples = redcal.count_redundant_nsamples(nsamples, [red for red in expanded_reds if red[0] in vissol_flags], 
                                                  good_ants=[ant for ant in overall_class if ant not in overall_class.bad_ants])
# also flag visibility solutions that no samples contributed to
for bl in vissol_flags:
    vissol_flags[bl][vissol_nsamples[bl] == 0] = True
# flags are computed above, before non-finite values are replaced in the solution
sol.make_sol_finite()

### Fix the `firstcal` delay slope degeneracy using RFI transmitters

In [None]:
# Fix the delay slope degeneracy using known RFI transmitters. Skipped entirely when a prior solution was
# used, and when no antenna is left that is not classified as bad.
if not OC_USE_PRIOR_SOL:
    # find channels clearly contaminated by RFI
    not_bad_ants = [ant for ant in overall_class.ants if overall_class[ant] != 'bad']
    if len(not_bad_ants) > 0:
        # average the median-filter-detrended autocorrelations over antennas and over the first data axis
        # (presumably time -- confirm), then mark channels where that average exceeds 5
        chan_flags = np.mean([xrfi.detrend_medfilt(data[utils.join_bl(ant, ant)], Kf=8, Kt=2) for ant in not_bad_ants], axis=(0, 1)) > 5

        # hardcoded RFI transmitters and their headings
        # channel: frequency (Hz), heading (rad), chi^2
        phs_sol = {359: ( 90744018.5546875, 0.7853981, 23.3),
                   360: ( 90866088.8671875, 0.7853981, 10.8),
                   385: ( 93917846.6796875, 0.7853981, 27.3),
                   386: ( 94039916.9921875, 0.7853981, 18.1),
                   400: ( 95748901.3671875, 6.0632738, 24.0),
                   441: (100753784.1796875, 0.7853981, 21.7),
                   442: (100875854.4921875, 0.7853981, 19.4),
                   455: (102462768.5546875, 6.0632738, 18.8),
                   456: (102584838.8671875, 6.0632738,  8.8),
                   471: (104415893.5546875, 0.7853981, 13.3),
                   484: (106002807.6171875, 6.0632738, 21.2),
                   485: (106124877.9296875, 6.0632738,  4.0),
                  1181: (191085815.4296875, 0.7853981, 26.3),
                  1182: (191207885.7421875, 0.7853981, 27.0),
                  1183: (191329956.0546875, 0.7853981, 25.6),
                  1448: (223678588.8671875, 2.6075219, 25.7),
                  1449: (223800659.1796875, 2.6075219, 22.6),
                  1450: (223922729.4921875, 2.6075219, 11.6),
                  1451: (224044799.8046875, 2.6075219,  5.9),
                  1452: (224166870.1171875, 2.6075219, 22.6),
                  1510: (231246948.2421875, 0.1068141, 23.9)}

        if not np.isclose(hd.freqs[0], 46920776.3671875, atol=0.001) or len(hd.freqs) != 1536:
            # We have fewer frequencies than usual (maybe testing), so the hardcoded channel numbers do not apply:
            # re-key each transmitter inside the band by the index of the nearest channel in hd.freqs
            phs_sol = {np.argmin(np.abs(hd.freqs - freq)): (freq, heading, chisq) for chan, (freq, heading, chisq) in phs_sol.items() if hd.freqs[0] <= freq <= hd.freqs[-1]}


        # only use transmitter channels that were actually found to be contaminated above
        rfi_chans = [chan for chan in phs_sol if chan_flags[chan]]
        print('Channels used for delay-slope calibration with RFI:', rfi_chans)
        rfi_angles = np.array([phs_sol[chan][1] for chan in rfi_chans])
        # heading vectors built from each angle as (cos, sin, 0), one column per channel
        rfi_headings = np.array([np.cos(rfi_angles), np.sin(rfi_angles), np.zeros_like(rfi_angles)])
        rfi_chisqs = np.array([phs_sol[chan][2] for chan in rfi_chans])

        # resolve firstcal degeneracy with delay slopes set by RFI transmitters, update cal
        # (channels are weighted by the inverse of their tabulated chi^2; max_dly comes from the first stage of redcal)
        RFI_dly_slope_gains = abscal.RFI_delay_slope_cal([red for red in expanded_reds if red[0] in sol.vis], hd.antpos, sol.vis, hd.freqs, rfi_chans, rfi_headings, rfi_wgts=rfi_chisqs**-1,
                                                         min_tau=-max_dly, max_tau=max_dly, delta_tau=0.1e-9, return_gains=True, gain_ants=sol.gains.keys())
        sol.gains = {ant: g * RFI_dly_slope_gains[ant] for ant, g in sol.gains.items()}
        apply_cal.calibrate_in_place(sol.vis, RFI_dly_slope_gains)
        malloc_trim()

### Perform absolute amplitude calibration using a model of autocorrelations

In [None]:
# Load simulated and then downsampled model of autocorrelations that includes receiver noise, then interpolate to upsample
hd_auto_model = io.HERAData(f'{HNBT_DATA}/SSM_autocorrelations_downsampled_sum_pol_convention.uvh5')
model, _, _ = hd_auto_model.read()
per_pol_interpolated_model = {}
for bl in model:
    # sort the model LSTs (dropping duplicates), keeping the row index each one came from
    sorted_lsts, lst_indices = np.unique(model.lsts, return_index=True)
    # append a copy of the first row at (first LST + 2 pi) so that the first and last rows match,
    # which lets the spline in LST use periodic boundary conditions
    periodic_model = np.vstack([model[bl][lst_indices, :], model[bl][lst_indices[0], :]])
    periodic_lsts = np.append(sorted_lsts, sorted_lsts[0] + 2 * np.pi)
    # interpolate along axis 0 onto the data LSTs, then along axis 1 onto the data frequencies
    lst_interpolated = interpolate.CubicSpline(periodic_lsts, periodic_model, axis=0, bc_type='periodic')(data.lsts)
    # results are keyed by bl[2] only, so a later baseline with the same bl[2] overwrites an earlier one
    per_pol_interpolated_model[bl[2]] = interpolate.CubicSpline(model.freqs, lst_interpolated, axis=1)(data.freqs)
# replace model with one entry per autocorrelation in auto_bls whose antenna is not classified as bad
model = {bl: per_pol_interpolated_model[bl[2]] for bl in auto_bls if utils.split_bl(bl)[0] not in overall_class.bad_ants}

In [None]:
# Run abscal and update omnical gains with abscal gains
if len(model) > 0:
    redcaled_autos = {bl: sol.calibrate_bl(bl, data[bl]) for bl in auto_bls if utils.split_bl(bl)[0] not in overall_class.bad_ants}
    g_abscal = abscal.abs_amp_logcal(model, redcaled_autos, verbose=False, return_gains=True, gain_ants=sol.gains)
    sol.gains = {ant: g * g_abscal[ant] for ant, g in sol.gains.items()}
    apply_cal.calibrate_in_place(sol.vis, g_abscal)
    del redcaled_autos, g_abscal

### Full absolute calibration of phase gradients
If an `ABSCAL_MODEL_FILES_GLOB` is provided, try to perform a full absolute calibration of tip-tilt phase gradients across the array using those model files. Specifically, this step calibrates omnical visibility solutions using unique baselines simulated with a model of the sky and HERA's beam.

In [None]:
if ABSCAL_MODEL_FILES_GLOB is None:
    # no glob was provided, so look in the known locations: first on site...
    abscal_model_files = sorted(glob.glob('/mnt/sn1/data1/abscal_models/H6C/zen.2458894.?????.uvh5'))
    if not abscal_model_files:
        # ...and, failing that, at NRAO
        abscal_model_files = sorted(glob.glob('/lustre/aoc/projects/hera/h6c-analysis/abscal_models/h6c_abscal_files_unique_baselines/zen.2458894.?????.uvh5'))
else:
    abscal_model_files = sorted(glob.glob(ABSCAL_MODEL_FILES_GLOB))

# report how many files were found and, if any, the folder of the first one
print(f'Found {len(abscal_model_files)} abscal model files{" in " + os.path.dirname(abscal_model_files[0]) if len(abscal_model_files) > 0 else ""}.')

In [None]:
# Try to perform a full abscal of phase
# Sets DO_FULL_ABSCAL in every branch; abscal_start, hdm, and d2m_time_map are only defined along the way to True.
if len(abscal_model_files) == 0:
    DO_FULL_ABSCAL = False
    print('No model files found... not performing full absolute calibration of phase gradients.')
elif np.all([ant in overall_class.bad_ants for ant in ants]):
    DO_FULL_ABSCAL = False
    print('All antennas classified as bad... skipping absolute calibration of phase gradients.')
else:
    abscal_start = time.time()
    # figure out which model files match the LSTs of the data
    matched_model_files = sorted(set(abscal.match_times(SUM_FILE, abscal_model_files, filetype='uvh5')))
    if len(matched_model_files) == 0:
        DO_FULL_ABSCAL = False
        print(f'No model files found matching the LSTs of this file after searching for {(time.time() - abscal_start) / 60:.2f} minutes. '
              'Not performing full absolute calibration of phase gradients.')
    else:
        DO_FULL_ABSCAL = True
        # figure out appropriate model times to load
        hdm = io.HERAData(matched_model_files)
        all_model_times, all_model_lsts = abscal.get_all_times_and_lsts(hdm, unwrap=True)
        # map from each data time to a model time, using unwrapped LSTs on both sides
        d2m_time_map = abscal.get_d2m_time_map(data.times, np.unwrap(data.lsts), all_model_times, all_model_lsts, extrap_limit=.5)

In [None]:
# Absolute calibration of the phase gradients, one polarization at a time.
# Fills abscal_meta[pol] and updates sol.gains and sol.vis in place.
if DO_FULL_ABSCAL:
    abscal_meta = {}
    for pol in ['ee', 'nn']:
        print(f'Performing absolute phase gradient calibration of {pol}-polarized visibility solutions...')
        
        # load matching times and baselines
        # (only visibility solutions of this polarization that are not entirely flagged are used)
        unflagged_data_bls = [bl for bl in vissol_flags if not np.all(vissol_flags[bl]) and bl[2] == pol]
        model_bls = copy.deepcopy(hdm.bls)
        model_antpos = hdm.data_antpos
        if len(matched_model_files) > 1:  # in this case, it's a dictionary
            # flatten the per-file baselines and antenna positions into one list and one dictionary
            model_bls = list(set([bl for bls in list(hdm.bls.values()) for bl in bls]))
            model_antpos = {ant: pos for antpos in hdm.data_antpos.values() for ant, pos in antpos.items()}
        data_bls, model_bls, data_to_model_bl_map = abscal.match_baselines(unflagged_data_bls, model_bls, data.antpos, model_antpos=model_antpos, 
                                                                         pols=[pol], data_is_redsol=True, model_is_redundant=True, tol=1.0,
                                                                         min_bl_cut=ABSCAL_MIN_BL_LEN, max_bl_cut=ABSCAL_MAX_BL_LEN, verbose=True)
        # read only the model times that some data time maps to
        model, model_flags, _ = io.partial_time_io(hdm, np.unique([d2m_time_map[time] for time in data.times]), bls=model_bls)
        # reorder the model baselines so that they line up one-to-one with data_bls
        model_bls = [data_to_model_bl_map[bl] for bl in data_bls]
        
        # rephase model to match in lsts
        model_blvecs = {bl: model.antpos[bl[0]] - model.antpos[bl[1]] for bl in model.keys()}
        utils.lst_rephase(model, model_blvecs, model.freqs, data.lsts - model.lsts,
                          lat=hdm.telescope_location_lat_lon_alt_degrees[0], inplace=True)

        # run abscal and apply 
        abscal_meta[pol], delta_gains = abscal.complex_phase_abscal(sol.vis, model, sol.reds, data_bls, model_bls)
        
        # apply gains
        # (antennas that are missing from delta_gains keep their gains unchanged)
        sol.gains = {antpol : g * delta_gains.get(antpol, 1) for antpol, g in sol.gains.items()}
        apply_cal.calibrate_in_place(sol.vis, delta_gains)            
     
    # free the products of the last polarization; hdm is kept because the cross-pol calibration reads it
    del model, model_flags, delta_gains
    malloc_trim()    
    
    print(f'Finished absolute calibration of tip-tilt phase slopes in {(time.time() - abscal_start) / 60:.2f} minutes.')

In [None]:
# Calibrate the relative phase between polarizations using the cross-polarized (en, ne) visibilities.
# Only runs after a full abscal, since it reuses hdm, matched_model_files, and d2m_time_map.
if DO_FULL_ABSCAL and CALIBRATE_CROSS_POLS:
    cross_pol_cal_start = time.time()

    # Compute reds for good antennas 
    # NOTE(review): the filtering below is given hd.antpos, while get_reds is given hd.data_antpos -- confirm this is intended
    cross_reds = redcal.get_reds(hd.data_antpos, pols=['en', 'ne'], bl_error_tol=2.0)        
    cross_reds = redcal.filter_reds(cross_reds, ex_ants=overall_class.bad_ants, pols=['en', 'ne'], antpos=hd.antpos, **fr_settings)    
    # the first baseline of each group is used as the key for that group
    unflagged_data_bls = [red[0] for red in cross_reds]

    # Get cross-polarized model visibilities
    model_bls = copy.deepcopy(hdm.bls)
    model_antpos = hdm.data_antpos
    if len(matched_model_files) > 1:  # in this case, it's a dictionary
        # flatten the per-file baselines and antenna positions into one list and one dictionary
        model_bls = list(set([bl for bls in list(hdm.bls.values()) for bl in bls]))
        model_antpos = {ant: pos for antpos in hdm.data_antpos.values() for ant, pos in antpos.items()}

    data_bls, model_bls, data_to_model_bl_map = abscal.match_baselines(unflagged_data_bls, model_bls, data.antpos, model_antpos=model_antpos, 
                                                                     pols=['en', 'ne'], data_is_redsol=False, model_is_redundant=True, tol=1.0,
                                                                     min_bl_cut=ABSCAL_MIN_BL_LEN, max_bl_cut=ABSCAL_MAX_BL_LEN, verbose=True)
    
    # read only the model times that some data time maps to
    model, model_flags, _ = io.partial_time_io(hdm, np.unique([d2m_time_map[time] for time in data.times]), bls=model_bls)
    # reorder the model baselines so that they line up one-to-one with data_bls
    model_bls = [data_to_model_bl_map[bl] for bl in data_bls]

    # rephase model to match in lsts
    model_blvecs = {bl: model.antpos[bl[0]] - model.antpos[bl[1]] for bl in model.keys()}
    utils.lst_rephase(model, model_blvecs, model.freqs, data.lsts - model.lsts, lat=hdm.telescope_location_lat_lon_alt_degrees[0], inplace=True)


    wgts_here = {}
    data_here = {}

    
    # average the calibrated data within each redundant group that was matched to a model baseline
    for red in cross_reds:
        data_bl = red[0]
        if data_bl in data_to_model_bl_map:

            # weight: number of baselines in the group for which neither antenna is flagged
            wgts_here[data_bl] = np.sum([
                np.logical_not(omni_flags[utils.split_bl(bl)[0]] | omni_flags[utils.split_bl(bl)[1]])
                for bl in red
            ], axis=0)
            # data: mean of the calibrated visibilities over the group, ignoring flagged samples
            data_here[data_bl] = np.nanmean([
                np.where(
                    omni_flags[utils.split_bl(bl)[0]] | omni_flags[utils.split_bl(bl)[1]],
                    np.nan, sol.calibrate_bl(bl, data[bl])
                ) 
                for bl in red
            ], axis=0)
    
    # Run cross-polarized phase calibration
    delta_gains = abscal.cross_pol_phase_cal(
        model=model, data=data_here, wgts=wgts_here, data_bls=data_bls, model_bls=model_bls, return_gains=True,
        gain_ants=sol.gains.keys()
    )
    
    # apply gains
    # \Delta = \phi_e - \phi_n, where V_{en}^{cal} = V_{en}^{uncal} * e^{i \Delta} 
    # and V_{ne}^{cal} = V_{ne}^{uncal} * e^{-i \Delta}
    sol.gains = {antpol: g * delta_gains[antpol] for antpol, g in sol.gains.items()}
    apply_cal.calibrate_in_place(sol.vis, delta_gains)
    del hdm, model, model_flags, delta_gains
    print(f'Finished relative polarized phase calibration in {(time.time() - cross_pol_cal_start) / 60:.2f} minutes.')

## Plotting

In [None]:
def redundant_group_plot():
    '''Plots the calibrated visibilities of the largest redundant group of each polarization, with its solution.

    Only the first integration is shown. Phases are in the top row and amplitudes in the bottom row, with
    ee in the left column and nn in the right. Thin lines are the individual calibrated baselines and the
    black line is the redundant visibility solution. Reads the notebook globals ants, overall_class, hd,
    sol, and data. Prints a message and plots nothing if every antenna is classified as bad.
    '''
    if all(ant in overall_class.bad_ants for ant in ants):
        print('All antennas classified as bad. Nothing to plot.')
        return

    fig, axes = plt.subplots(2, 2, figsize=(14, 6), dpi=100, sharex='col', sharey='row', gridspec_kw={'hspace': 0, 'wspace': 0})
    for col, pol in enumerate(['ee', 'nn']):
        phs_ax, amp_ax = axes[0, col], axes[1, col]

        # pick the group with the most baselines once bad antennas are excluded
        pol_reds = redcal.filter_reds(redcal.get_reds(hd.data_antpos, pols=[pol], pol_mode='1pol', bl_error_tol=2.0),
                                      ex_ants=overall_class.bad_ants)
        red = sorted(pol_reds, key=len, reverse=True)[0]

        # calibrate every baseline in the group, then draw the first integration of each
        rc_data = {bl: sol.calibrate_bl(bl, data[bl]) for bl in red}
        for bl in red:
            first_int = rc_data[bl][0]
            phs_ax.plot(hd.freqs/1e6, np.angle(first_int), alpha=.5, lw=.5)
            amp_ax.semilogy(hd.freqs/1e6, np.abs(first_int), alpha=.5, lw=.5)

        # overlay the redundant visibility solution for the group
        vis_sol = sol.vis[red[0]][0]
        phs_ax.plot(hd.freqs / 1e6, np.angle(vis_sol), lw=1, c='k')
        amp_ax.semilogy(hd.freqs / 1e6, np.abs(vis_sol), lw=1, c='k', label=f'Baseline Group:\n{red[0]}')
        amp_ax.set_xlabel('Frequency (MHz)')
        amp_ax.legend(loc='upper right')

    axes[0, 0].set_ylabel('Visibility Phase (radians)')
    axes[1, 0].set_ylabel('Visibility Amplitude (Jy)')
    plt.tight_layout()

In [None]:
def abscal_degen_plot():
    '''Plots the phase gradient solutions and Re[Z] from the full absolute calibration for the first integration.

    The top two panels show every component of abscal_meta[pol]['Lambda_sol'] for ee and nn, and the bottom
    panel shows the real part of abscal_meta[pol]['Z_sol'] for both. Channels flagged in rfi_flags[0] are
    left out. Does nothing unless DO_FULL_ABSCAL is set.
    '''
    if not DO_FULL_ABSCAL:
        return

    fig, axes = plt.subplots(3, 1, figsize=(14, 6), dpi=100, sharex=True, gridspec_kw={'hspace': .05})

    # unflagged channels of the first integration, and their frequencies in MHz
    unflagged = ~rfi_flags[0]
    freqs_MHz = hd.freqs[unflagged] * 1e-6

    for ax, pol in zip(axes[:2], ['ee', 'nn']):
        Lambda_sol = abscal_meta[pol]['Lambda_sol']
        for kk in range(Lambda_sol.shape[-1]):
            ax.plot(freqs_MHz, Lambda_sol[0, unflagged, kk], '.', ms=1, label=f"Component {kk}")

        ax.set_ylim(-np.pi-0.5, np.pi+0.5)
        ax.set_xlabel('Frequency (MHz)')
        ax.set_ylabel('Phase Gradient\nVector Component')
        ax.legend(markerscale=20, title=f'{pol}-polarization', loc='lower right')
        ax.grid()

    Z_ax = axes[2]
    for pol, color in zip(['ee', 'nn'], ['b', 'r']):
        Z_ax.plot(freqs_MHz, abscal_meta[pol]['Z_sol'].real[0, unflagged], '.', ms=1, label=pol, color=color)
    Z_ax.set_ylim(-.25, 1.05)
    Z_ax.set_ylabel('Re[Z($\\nu$)]')
    Z_ax.legend(markerscale=20, loc='lower right')
    Z_ax.grid()
    plt.tight_layout()

In [None]:
def polarized_gain_phase_plot():
    '''Plots delta, the relative phase between polarizations, against frequency for every integration.

    Samples flagged in rfi_flags are not drawn. Does nothing unless both CALIBRATE_CROSS_POLS and
    DO_FULL_ABSCAL are set. Reads the notebook globals data, rfi_flags, and delta.

    NOTE(review): delta is indexed as delta[i, :] with i running over data.times, but it is not assigned
    in the cross-polarized calibration cell -- confirm where it is defined.
    '''
    if CALIBRATE_CROSS_POLS and DO_FULL_ABSCAL:
        plt.figure(figsize=(14, 4), dpi=100)
        # the loop variable is not called `time`, which would shadow the imported module inside this function
        for i, t in enumerate(data.times):
            plt.plot(data.freqs / 1e6, np.where(rfi_flags[i], np.nan, delta[i, :]), '.', ms=1.5, label=f'{t:.6f}')
        plt.ylim(-np.pi-0.5, np.pi+0.5)
        plt.xlabel('Frequency (MHz)')
        # raw string: the LaTeX backslashes are not valid Python escape sequences
        plt.ylabel(r'Relative Phase $\Delta \ (\phi_{ee} - \phi_{nn})$')
        plt.grid()
        plt.legend()

# *Figure 4: Redundant calibration of a single baseline group*
The results of a redundant-baseline calibration of a single integration and a single group, the one with the highest redundancy in each polarization after antenna classification and excision based on the above, plus the removal of antennas with high chi^2 per antenna. The black line is the redundant visibility solution. Each thin colored line is a different baseline in that group. Phases are shown in the top row, amplitudes in the bottom, ee-polarized visibilities in the left column, and nn-polarized visibilities in the right.

In [None]:
# only draw Figure 4 when plotting is enabled
if PLOT:
    redundant_group_plot()

# *Figure 5: Absolute calibration of `redcal` degeneracies*
This figure shows the per-frequency phase gradient solutions across the array for both polarizations and all components of the degenerate subspace of redundant-baseline calibration. While full HERA only has two such tip-tilt degeneracies, a subset of HERA can have up to `OC_MAX_DIMS` (depending on antenna flagging). In addition to the absolute amplitude, this is the full set of the calibration degrees of freedom not constrained by `redcal`. This figure also includes a plot of $Re[Z(\nu)]$, the complex objective function which varies from -1 to 1 and indicates how well the data and the absolute calibration model have been made to agree. Perfect agreement is 1.0 and good agreement is anything above $\sim$0.5. Decorrelation yields values closer to 0, where anything below $\sim$0.3 is suspect.

In [None]:
# only draw Figure 5 when plotting is enabled
if PLOT:
    abscal_degen_plot()

# *Figure 6: Relative Phase Calibration*

This figure shows the relative phase calibration between the `ee` vs. `nn` polarizations.

In [None]:
# only draw Figure 6 when plotting is enabled
if PLOT:
    polarized_gain_phase_plot()

### Attempt to calibrate some flagged antennas
This attempts to calibrate bad antennas using information from good or suspect antennas without allowing bad antennas to affect their calibration. However, introducing 0s in gains or infs/nans in gains or visibilities can create problems down the line, so those are removed.

In [None]:
# Expand the solution to antennas that were left out of calibration, then clean up the gains and solutions.
expand_start = time.time()
# all ee and nn redundant groups, with no filtering at all
expanded_reds = redcal.get_reds(hd.data_antpos, pols=['ee', 'nn'], pol_mode='2pol', bl_error_tol=2.0)
sol.vis.build_red_keys(expanded_reds)
redcal.expand_omni_gains(sol, expanded_reds, data, chisq_per_ant=meta['chisq_per_ant'])
# visibility solutions are only expanded if at least one antenna is not classified as bad
if not np.all([ant in overall_class.bad_ants for ant in ants]):
    redcal.expand_omni_vis(sol, expanded_reds, data)

# Replace near-zeros in gains and infs/nans in gains/sols
for ant in sol.gains:
    zeros_in_gains = np.isclose(sol.gains[ant], 0)
    # only antennas that already had flags get the near-zero samples flagged; all near-zero gains are set to 1
    if ant in omni_flags:
        omni_flags[ant][zeros_in_gains] = True
    sol.gains[ant][zeros_in_gains] = 1.0 + 0.0j
sol.make_sol_finite()
malloc_trim()
print(f'Finished expanding gain solution in {(time.time() - expand_start) / 60:.2f} minutes.')

In [None]:
def array_chisq_plot(include_outriggers=True):
    """Scatter-plot the RFI-masked mean chi^2 per antenna at each antenna's position.

    One figure with side-by-side 'ee' and 'nn' panels is drawn for antennas numbered
    below 320. If `include_outriggers` is truthy and the data contain antennas numbered
    320 or above, a second figure with larger markers is drawn for those. Antenna
    numbers are written in red for ant-pols in `overall_class.bad_ants` and in white
    otherwise. If every entry of `ants` is classified as bad, a message is printed and
    nothing is plotted.

    Parameters
    ----------
    include_outriggers : bool, optional
        Whether to also draw the separate figure for antennas numbered >= 320.
    """
    if np.all([ant in overall_class.bad_ants for ant in ants]):
        print('All antennas classified as bad. Nothing to plot.')
        return

    # Antenna numbers at or above this value are plotted separately as outriggers.
    first_outrigger = 320

    def _chisq_subplot(antnums, size=250):
        """Draw one two-panel (ee, nn) chi^2 figure restricted to the given antenna numbers."""
        antnums = set(antnums)  # constant-time membership tests in the comprehension below
        _, axes = plt.subplots(1, 2, figsize=(14, 5), dpi=100)
        for ax, pol in zip(axes, ['ee', 'nn']):
            # A list keeps the colors, positions, and labels below in one consistent order.
            ants_to_plot = [ant for ant in meta['chisq_per_ant']
                            if utils.join_pol(ant[1], ant[1]) == pol and (ant[0] in antnums)]
            # Average chi^2 over all samples that are not RFI-flagged.
            cspas = np.array([np.nanmean(np.where(rfi_flags, np.nan, meta['chisq_per_ant'][ant])) for ant in ants_to_plot])
            xpos = [hd.antpos[ant[0]][0] for ant in ants_to_plot]
            ypos = [hd.antpos[ant[0]][1] for ant in ants_to_plot]
            # Markers whose value cannot be shown on the log color scale get a black outline instead.
            scatter = ax.scatter(xpos, ypos, s=size, c=cspas, lw=.25, edgecolors=np.where(np.isfinite(cspas) & (cspas > 0), 'none', 'k'),
                                 norm=matplotlib.colors.LogNorm(vmin=1, vmax=oc_cspa_suspect[1]))
            for ant in ants_to_plot:
                ax.text(hd.antpos[ant[0]][0], hd.antpos[ant[0]][1], ant[0], va='center', ha='center', fontsize=8,
                        c=('r' if ant in overall_class.bad_ants else 'w'))
            plt.colorbar(scatter, ax=ax, extend='both')
            ax.axis('equal')
            ax.set_xlabel('East-West Position (meters)')
            ax.set_ylabel('North-South Position (meters)')
            ax.set_title(f'{pol}-pol $\\chi^2$ / Antenna (Red is Flagged)')
        plt.tight_layout()

    _chisq_subplot([ant for ant in hd.data_ants if ant < first_outrigger])
    outriggers = [ant for ant in hd.data_ants if ant >= first_outrigger]
    # Logical `and` (not bitwise `&`) so any truthy flag value behaves as expected.
    if include_outriggers and outriggers:
        _chisq_subplot(outriggers, size=400)

# *Figure 7: chi^2 per antenna across the array*

This plot shows the mean (taken over time and frequency, excluding RFI-flagged samples) of the normalized chi^2 per antenna.
The expectation value for this quantity when the array is perfectly redundant is 1.0. 
Antennas that are classified as bad for any reason have their numbers shown in red. 
Some of those antennas were classified as bad during redundant calibration for high chi^2. 
Some of those antennas were originally excluded from redundant calibration because they were classified as bad earlier for some reason. 
See [here for more details.](#Attempt-to-calibrate-some-flagged-antennas)
Note that the color scale saturates below 1 and above 10.

In [None]:
# Outriggers are left out of the chi^2 figure when OC_SKIP_OUTRIGGERS is set.
if PLOT:
    array_chisq_plot(include_outriggers=(not OC_SKIP_OUTRIGGERS))

# *Figure 8: Summary of antenna classifications after redundant calibration*

This figure is the same as [Figure 3](#Figure-3:-Summary-of-antenna-classifications-prior-to-calibration), except that it now includes additional suspect or bad antennas based on redundant calibration.
This can include antennas with high chi^2, but it can also include antennas classified as "bad" because they would add extra degeneracies to calibration.

In [None]:
# Re-draw the antenna classification summary using the final overall classification.
if PLOT:
    array_class_plot(overall_class, extra_label=", Post-Redcal")

In [None]:
# Build two parallel tables with one row per ant-pol in `ants`:
# `to_show` holds the text displayed in each cell and `classes` holds the
# classification that determines that cell's background color.
to_show = {'Antenna': [f'{ant[0]}{ant[1][-1]}' for ant in ants]}
classes = {'Antenna': [overall_class[ant] if ant in overall_class else '-' for ant in ants]}
to_show['Dead?'] = [{'good': 'No', 'bad': 'Yes'}[am_totally_dead[ant]] if (ant in am_totally_dead) else '' for ant in ants]
classes['Dead?'] = [am_totally_dead[ant] if (ant in am_totally_dead) else '' for ant in ants]
for title, ac in [('Low Correlation', am_corr),
                  ('Cross-Polarized', am_xpol),
                  ('Solar Alt', solar_class),
                  ('Even/Odd Zeros', zeros_class),
                  ('Autocorr Power', auto_power_class),
                  ('Autocorr Slope', auto_slope_class),
                  ('Auto RFI RMS', auto_rfi_class),
                  ('Autocorr Shape', auto_shape_class),
                  ('Bad Diff X-Engines', xengine_diff_class)]:
    to_show[title] = [f'{ac._data[ant]:.2G}' if (ant in ac._data) else '' for ant in ants]
    # An ant-pol missing from a classifier is colored as bad in that column.
    classes[title] = [ac[ant] if ant in ac else 'bad' for ant in ants]

# Redcal chi^2 is averaged over all samples that are not RFI-flagged; it is left blank
# for ant-pols that have no chi^2 entry.
to_show['Redcal chi^2'] = [
    f'{np.nanmean(np.where(rfi_flags, np.nan, meta["chisq_per_ant"][ant])):.3G}'
    if (meta['chisq_per_ant'] is not None and ant in meta['chisq_per_ant']) else ''
    for ant in ants
]
classes['Redcal chi^2'] = [redcal_class[ant] if ant in redcal_class else '' for ant in ants]

df = pd.DataFrame(to_show)
df_classes = pd.DataFrame(classes)
colors = {'good': 'darkgreen', 'suspect': 'goldenrod', 'bad': 'maroon'}

# DataFrame.applymap is deprecated as of pandas 2.1 in favor of DataFrame.map;
# use map when the installed pandas provides it and fall back to applymap otherwise.
elementwise = df_classes.map if hasattr(df_classes, 'map') else df_classes.applymap
# Cells without a recognized class get no style rather than the invalid CSS 'background-color: None'.
df_colors = elementwise(lambda x: f'background-color: {colors[x]}' if x in colors else '')

table = (
    df.style.hide()
    .apply(lambda x: pd.DataFrame(df_colors.values, columns=x.columns), axis=None)
    .set_properties(subset=['Antenna'], **{'font-weight': 'bold', 'border-right': "3pt solid black"})
    .set_properties(subset=df.columns[1:], **{'border-left': "1pt solid black"})
    .set_properties(**{'text-align': 'center', 'color': 'white'})
)

# *Table 1: Complete summary of per-antenna classifications*

This table summarizes the results of the various classification schemes detailed above.
As before, <font color='#006400'>green is good</font>, <font color='#DAA520'>yellow is suspect</font>, and <font color='#800000'>red is bad</font>. The color for each antenna (first column) is the final summary of all other classifications.
Antennas missing from redcal $\chi^2$ were excluded from redundant-baseline calibration, either because they were flagged by `ant_metrics` or the even/odd zeros check, or because they would add unwanted extra degeneracies.

In [None]:
# Display the styled classification table as the cell output via its HTML representation.
HTML(table.to_html())

In [None]:
# Save antenna classification table as a csv
if SAVE_RESULTS:
    # Work on a copy so that `df` (which backs the displayed `table`) is left untouched
    # and this cell can be re-run without hitting duplicate-column errors.
    csv_df = df.copy()
    # Walk the columns right-to-left so each '<column> Class' column lands immediately
    # after its value column without shifting the positions still to be filled.
    for ind, col in zip(range(len(df.columns), 0, -1), df_classes.columns[::-1]):
        csv_df.insert(ind, col + ' Class', df_classes[col])
    csv_df.to_csv(ANTCLASS_FILE)

In [None]:
# Print the final overall classification object for the record.
print('Final Ant-Pol Classification:\n\n', overall_class)

## Save calibration solutions

In [None]:
# update flags in omnical gains and visibility solutions
# OR the RFI flags into every per-antenna gain flag array and then
# every per-baseline visibility-solution flag array.
for flag_dict in (omni_flags, vissol_flags):
    for key in flag_dict:
        flag_dict[key] |= rfi_flags

In [None]:
# Write the redcal gain solutions (and optionally the visibility solutions) to disk.
if SAVE_RESULTS:
    # Record the output of `conda env export` in the file history for provenance.
    add_to_history = 'Produced by file_calibration notebook with the following environment:\n' + '=' * 65 + '\n' + os.popen('conda env export').read() + '=' * 65    
    
    # Build a calibration object from the sum file, then fill it with the gains, flags, and chi^2.
    hd_vissol = io.HERAData(SUM_FILE)
    hc_omni = hd_vissol.init_HERACal(gain_convention='divide', cal_style='redundant')
    # Polarization convention and gain units are copied from hd_auto_model.
    hc_omni.pol_convention = hd_auto_model.pol_convention
    hc_omni.gain_scale = hd_auto_model.vis_units
    hc_omni.update(gains=sol.gains, flags=omni_flags, quals=meta['chisq_per_ant'], total_qual=meta['chisq'])
    hc_omni.history += add_to_history
    hc_omni.write_calfits(OMNICAL_FILE, clobber=True)
    # Drop the calibration object and trim the heap before the (optional) visibility write.
    del hc_omni
    malloc_trim()
    
    if SAVE_OMNIVIS_FILE:
        # output results, harmonizing keys over polarizations
        all_reds = redcal.get_reds(hd.data_antpos, pols=['ee', 'nn', 'en', 'ne'], pol_mode='4pol', bl_error_tol=2.0)
        # Map every baseline to the first baseline of its redundant group.
        bl_to_red_map = {bl: red[0] for red in all_reds for bl in red}
        # NOTE(review): presumably this reads only the selected baselines to set up the
        # object's arrays, which are then emptied and refilled below -- confirm against hera_cal.io.
        hd_vissol.read(bls=[bl_to_red_map[bl] for bl in sol.vis], return_data=False)
        hd_vissol.empty_arrays()
        hd_vissol.history += add_to_history
        # All three dictionaries are re-keyed by the first baseline of each redundant group.
        hd_vissol.update(data={bl_to_red_map[bl]: sol.vis[bl] for bl in sol.vis}, 
                         flags={bl_to_red_map[bl]: vissol_flags[bl] for bl in vissol_flags}, 
                         nsamples={bl_to_red_map[bl]: vissol_nsamples[bl] for bl in vissol_nsamples})
        hd_vissol.pol_convention = hd_auto_model.pol_convention
        hd_vissol.vis_units = hd_auto_model.vis_units
        hd_vissol.write_uvh5(OMNIVIS_FILE, clobber=True)

In [None]:
if SAVE_RESULTS:
    # hd_vissol only exists when SAVE_RESULTS is set; drop it and trim the heap.
    del hd_vissol
    malloc_trim()

### Output fully flagged calibration file if `OMNICAL_FILE` is not written

In [None]:
# Fallback: if results should be saved but no calibration file exists on disk,
# write a placeholder with unit gains and every sample flagged.
if SAVE_RESULTS and not os.path.exists(OMNICAL_FILE):
    print(f'WARNING: No calibration file produced at {OMNICAL_FILE}. Creating a fully-flagged placeholder calibration file.')
    hd_writer = io.HERAData(SUM_FILE)
    # Gains and flags are both shaped from hd_writer so they always agree with each other
    # and with the freqs/times axes (also taken from hd_writer) passed to write_cal.
    placeholder_shape = (hd_writer.Ntimes, hd_writer.Nfreqs)
    # NOTE(review): antenna_positions is built by iterating a set of antenna numbers, so its row
    # order is not explicitly tied to the order of the gains keys -- confirm what write_cal expects.
    # add_to_history is the string defined in the save-results cell.
    io.write_cal(OMNICAL_FILE, freqs=hd_writer.freqs, times=hd_writer.times,
                 gains={ant: np.ones(placeholder_shape, dtype=np.complex64) for ant in ants},
                 flags={ant: np.ones(placeholder_shape, dtype=bool) for ant in ants},
                 quality=None, total_qual=None, outdir='', overwrite=True, history=utils.history_string(add_to_history),
                 x_orientation=hd_writer.x_orientation, telescope_location=hd_writer.telescope_location, lst_array=np.unique(hd_writer.lsts),
                 antenna_positions=np.array([hd_writer.antenna_positions[hd_writer.antenna_numbers == antnum].flatten() for antnum in set(ant[0] for ant in ants)]),
                 antnums2antnames=dict(zip(hd_writer.antenna_numbers, hd_writer.antenna_names)))

### Output empty visibility file if `OMNIVIS_FILE` is not written

In [None]:
# Fallback: if a visibility-solution file was requested but none exists on disk,
# initialize one from the sum file's HERAData object, overwriting any existing file.
if SAVE_RESULTS and SAVE_OMNIVIS_FILE and not os.path.exists(OMNIVIS_FILE):
    print(f'WARNING: No omnivis file produced at {OMNIVIS_FILE}. Creating an empty visibility solution file.')
    hd_writer = io.HERAData(SUM_FILE)
    hd_writer.initialize_uvh5_file(OMNIVIS_FILE, clobber=True)

## Metadata

In [None]:
import importlib

# Report the installed version of each listed package.
for repo in ['pyuvdata', 'hera_cal', 'hera_filters', 'hera_qm', 'hera_notebook_templates']:
    # import_module resolves the package by name without exec'ing generated source code.
    print(f'{repo}: {importlib.import_module(repo).__version__}')

In [None]:
# Total wall-clock time since tstart was recorded in the notebook's first cell.
print(f'Finished execution in {(time.time() - tstart) / 60:.2f} minutes.')