In [None]:
#!pip3 install --user astropy
#!pip3 install --user kaleido

In [None]:
import numpy as np
import astropy as ap
import pandas as pd
from astropy.io import fits
import scipy.linalg as slg
from scipy.stats import norm, pearsonr
#import scipy.stats
from math import ceil
import csv

import plotly.io as pio
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from plotly.offline import iplot
import kaleido

# import matplotlib.pyplot as plt
# import matplotlib.colors as colors
# from matplotlib.colors import LinearSegmentedColormap
import glob
import os
# These are in Functions
from os.path import join as pj
from os.path import exists # as pj
# from os.path import abspath as absp

from IPython.display import Image
from IPython.display import display

from joblib import Parallel, delayed

import PIL
import pickle

import sys
from collections import namedtuple as nt

In [None]:
# --- Configuration: locate SpArcFiRe's GalfitModule and import its helpers ---
# NOTE(review): hard-coded, user-specific absolute path. This unconditionally
# overrides any SPARCFIRE_HOME already present in the environment; consider
# os.environ.setdefault(...) or reading it from a config cell instead.
os.environ["SPARCFIRE_HOME"] = "/home/portmanm/sparcfire_matt/"

# NOTE(review): _HOME_DIR is not used anywhere in the visible cells.
_HOME_DIR = os.path.expanduser("~")
try:
    _SPARCFIRE_DIR = os.environ["SPARCFIRE_HOME"]
    _MODULE_DIR = pj(_SPARCFIRE_DIR, "GalfitModule")
except KeyError:
    # As written this branch is unreachable, because SPARCFIRE_HOME is assigned
    # just above. NOTE(review): if it ever runs while __name__ != "__main__",
    # _MODULE_DIR is never bound and sys.path.append below raises NameError.
    if __name__ == "__main__":
        print("SPARCFIRE_HOME is not set. Please run 'setup.bash' inside SpArcFiRe directory if not done so already.")
        print("Checking the current directory for GalfitModule, otherwise quitting.")
            
        _MODULE_DIR = pj(os.getcwd(), "GalfitModule")
        
        if not exists(_MODULE_DIR):
            raise Exception("Could not find GalfitModule!")
    
# Make the GalfitModule package importable.
sys.path.append(_MODULE_DIR)
# Wildcard project imports. Presumably these supply helpers used later in the
# notebook that are not imported explicitly, e.g. find_files and deepcopy.
# TODO confirm which module exports them and import those names explicitly.
from Classes.Components import *
from Classes.Containers import *
from Classes.FitsHandlers import *
from Functions.helper_functions import *

# Result containers. residual_analysis returns an all_results_nt.
# combined_results_nt is presumably built by combine_multi_run_results
# (that cell continues beyond this view).
all_results_nt      = nt("all_results_nt", ["full_df", "success_df", "not_success_df", "by_eye_success_df", "by_eye_not_success_df"])
combined_results_nt = nt("combined_results_nt", ["bool_df", "full_df", "success_df", "by_eye_success_df"])
# Visual separator for printed output: a newline, 40 '=' characters, then a newline.
mini_sep    = "\n" + 40*"=" + "\n"

In [None]:
# # Defunct
# def check_galfit_chi(gal_name, base_path):
#     # An example line
#     # # Chi^2/nu = 4.661,  Chi^2 = 12025.575,  Ndof = 2580
    
#     #galfit_txt_out = "galfit.01" # in the future galfit.01 may change
#     filename = os.path.join(base_path, gal_name, galfit_txt_out)
#     with open(filename, "r") as f:
#         for line in f:
#             if "Chi" in line:
#                 chi_line = line.strip("# ")
    
#     # This also works but it's quite devious...
#     # chi_line.replace("^", "").replace("/", "_").replace(",  ", "\n").lower()
#     # exec(chi_line)
    
#     out_vals = chi_line.split(",")
#     chi2_nu = float(out_vals[0].strip().split("=")[-1])
#     chi2 = float(out_vals[1].strip().split("=")[-1])
#     ndof = int(out_vals[2].strip().split("=")[-1])
    
#     return chi2_nu, chi2, ndof

In [None]:
def get_total_galaxies(in_dir = "sparcfire-in", out_dir = "sparcfire-out", pattern = "123*"):
    """
    Count the galaxies available for analysis.

    Uses the project helper ``find_files`` to count the entries in ``in_dir``
    (kind "f") and ``out_dir`` (kind "d") whose names match ``pattern``. It
    returns the smaller count. If that is zero (one side is empty), it returns
    the larger count instead.

    Parameters
    ----------
    in_dir : str
        SpArcFiRe input directory.
    out_dir : str
        SpArcFiRe output directory.
    pattern : str, optional
        Name pattern forwarded to ``find_files``. Defaults to ``"123*"``, the
        value previously hard-coded here.

    Returns
    -------
    int
        Number of galaxies.
    """
    n_in  = len(find_files(in_dir, pattern, "f"))
    n_out = len(find_files(out_dir, pattern, "d"))
    # min() is 0 whenever either side is empty; fall back to whatever was found.
    return min(n_in, n_out) or max(n_in, n_out)

In [374]:
def grab_header_parameters_for_flux(input_file, default_exptime = 1, default_mag_zpt = 24.8):
    """
    Read the exposure time and magnitude zero point from a GALFIT output FITS file.

    EXPTIME is read from the header of HDU 1 and MAGZPT from the header of
    HDU 2. A missing keyword falls back to the matching default. If the file
    does not exist, or has too few HDUs to hold those headers, both defaults
    are returned.

    Parameters
    ----------
    input_file : str
        Path to the FITS file.
    default_exptime : float, optional
        Fallback exposure time (default 1).
    default_mag_zpt : float, optional
        Fallback magnitude zero point (default 24.8, labelled as the SDSS r-band value).

    Returns
    -------
    tuple
        ``(exptime, mag_zpt)``
    """
    try:
        with fits.open(input_file) as hdul:
            exptime = hdul[1].header.get("EXPTIME", default_exptime)
            mag_zpt = hdul[2].header.get("MAGZPT", default_mag_zpt) # Default for SDSS r band
    except FileNotFoundError:
        print(f"Could not fits.open {input_file}.")
        print("Using exptime and mag zpt defaults from SDSS DR7 r-band.")
        exptime, mag_zpt = default_exptime, default_mag_zpt
    except IndexError:
        # The file opened but lacks the extension HDUs that carry these headers.
        print(f"{input_file} does not have the expected header extensions.")
        print("Using exptime and mag zpt defaults from SDSS DR7 r-band.")
        exptime, mag_zpt = default_exptime, default_mag_zpt

    return exptime, mag_zpt

In [375]:
def load_residual_df(
    out_dir, 
    basename,
    **kwargs
):
    """
    Load the pickled GALFIT fit-quality results for a run and add derived columns.

    Reads the lexicographically last ``{basename}_output_results*.pkl`` in
    ``out_dir/basename``. It adds a score column named after ``method``, drops
    every ``*_sky_2`` / ``*_sky_3`` column, and adds ``arm_flux_ratio``. That
    ratio is the flux of the Sersic component numbered like the
    ``inner_rad_power<n>`` column, divided by the summed flux of all
    lower-numbered Sersic components.

    Parameters
    ----------
    out_dir : str
        SpArcFiRe output directory.
    basename : str
        Run name: the sub-directory of ``out_dir`` and the pickle-file prefix.
    **kwargs
        method : str, default "nmr_x_1-p"
            "nmr_x_1-p"   -> (1 - KS_P) * NMR
            "nmr_neg_log" -> NMR / -log(KS_P + 1e-10)
            "W_quality"   -> KS_P / W_NMR
        verbose : bool, default True
            Print how many models were loaded and how many pass the cutoff.
        residual_cutoff_val : float, default 0.007
            Score cutoff, used only for the verbose count.
        max_header_probe : int, default 100
            Maximum number of galaxies to try when looking for a
            ``<gname>_galfit_out.fits`` to read EXPTIME/MAGZPT from.

    Returns
    -------
    pandas.DataFrame
        Results sorted ascending by the ``method`` column.

    Raises
    ------
    FileNotFoundError
        If no results pickle is found for the run.
    Exception
        If ``method`` is not recognised.
    """
    import re

    method              = kwargs.get("method", "nmr_x_1-p")
    verbose             = kwargs.get("verbose", True)
    residual_cutoff_val = kwargs.get("residual_cutoff_val", 0.007)
    max_header_probe    = kwargs.get("max_header_probe", 100)

    run_dir = pj(out_dir, basename)
    pickle_candidates = sorted(find_files(run_dir, f'{basename}_output_results*.pkl', "f"))
    if not pickle_candidates:
        raise FileNotFoundError(f"No {basename}_output_results*.pkl found in {run_dir}")
    pickle_filename = pj(run_dir, pickle_candidates[-1])

    # NOTE: unpickling can execute arbitrary code; only load trusted run output.
    residual_df = pd.read_pickle(pickle_filename)

    if method == "nmr_x_1-p":
        result_of_method = (1 - residual_df["KS_P"])*residual_df["NMR"]
    elif method == "nmr_neg_log":
        result_of_method = residual_df["NMR"]/-np.log(residual_df["KS_P"] + 1e-10)
    elif method == "W_quality":
        result_of_method = residual_df["KS_P"]/residual_df["W_NMR"]
    else:
        raise Exception(f"Method given: {method} is not a valid method (yet).")

    residual_df[method] = result_of_method

    # Drop the "_sky_2" and "_sky_3" columns; "_sky_4" columns are kept.
    cols_to_drop = [col for col in residual_df.columns if col.endswith(("_sky_2", "_sky_3"))]
    residual_df.drop(columns = cols_to_drop, inplace = True)

    if verbose:
        print(f"{len(residual_df)} galaxy models generated.")
        residual_cutoff = residual_df[method] <= residual_cutoff_val
        print(f"{sum(residual_cutoff)} models pass score cutoff.")

    # Assumes a single spiral component. Its number is the trailing integer of
    # the first inner_rad_power* column; this works for multi-digit numbers too.
    spiral_col = [col for col in residual_df.columns if col.startswith("inner_rad_power")][0]
    spiral_component_num = int(re.search(r"(\d+)$", spiral_col).group(1))

    # Find a GALFIT output FITS file to read EXPTIME/MAGZPT from. Never probe
    # beyond the number of rows. If nothing is found, the header reader falls
    # back to its defaults.
    input_file = "./fake_file.fake_file"
    for gname in residual_df.index[:max_header_probe]:
        input_file = pj(out_dir, gname, f"{gname}_galfit_out.fits")
        if exists(input_file):
            break

    exptime, mag_zpt = grab_header_parameters_for_flux(input_file)

    def _flux(magnitude):
        # Inverts GALFIT's m_tot = -2.5*log10(F_tot/t_exp) + mag_zpt
        # =>            F_tot = t_exp*10^[-(m_tot - mag_zpt)/2.5]
        return exptime*10**(-(magnitude - mag_zpt)/2.5)

    arm_flux   = _flux(residual_df[f"magnitude_sersic_{spiral_component_num}"])
    # Sum the remaining components from the highest number down (same order as before).
    other_flux = sum(
        _flux(residual_df[f"magnitude_sersic_{i}"])
        for i in range(spiral_component_num - 1, 0, -1)
    )

    residual_df["arm_flux_ratio"] = arm_flux/other_flux

    return residual_df.sort_values(by = method)

In [None]:
def load_galaxy_csv(out_dir, basename, pre_post):
    """
    Load the galaxy-level SpArcFiRe pitch angle for one stage of a run.

    Reads ``out_dir/basename/{basename}_{pre_post}_galfit_galaxy.csv``. It keeps
    only the ``name`` index and the dominant-chirality, arc-length-weighted
    pitch angle, cast to float. Index labels become strings, and the pitch
    column is renamed ``galaxy_{pre_post}_pa``.

    Parameters
    ----------
    out_dir : str
        SpArcFiRe output directory.
    basename : str
        Run name.
    pre_post : str
        "pre" or "post" (GALFIT stage).

    Returns
    -------
    pandas.DataFrame
    """
    # The leading space is part of the column header in SpArcFiRe's CSV output.
    pitch_field = " pa_alenWtd_avg_domChiralityOnly"
    csv_path    = pj(out_dir, basename, f"{basename}_{pre_post}_galfit_galaxy.csv")

    galaxy_df = pd.read_csv(
        csv_path,
        index_col    = "name",
        on_bad_lines = "warn",
        usecols      = ["name", pitch_field],
    )
    galaxy_df[pitch_field] = galaxy_df[pitch_field].astype(float)
    galaxy_df.index        = galaxy_df.index.map(str)

    return galaxy_df.rename(columns = {pitch_field : f"galaxy_{pre_post}_pa"})

In [None]:
def load_galaxy_arcs_csv(out_dir, basename, pre_post, **kwargs):
    """
    Load per-arc SpArcFiRe output and keep up to two arcs per galaxy of the dominant chirality.

    Steps:
      1. Read ``{basename}_{pre_post}_galfit_galaxy_arcs.csv``, keeping only the
         name, pitch-angle and arc-length columns.
      2. Drop near-circular arcs (``|pitch angle| <= 1``).
      3. Take the first three arcs per galaxy (file order). The dominant
         chirality is the sign of the summed pitch-angle signs of those arcs.
      4. Keep the first two of those arcs whose sign matches the dominant
         chirality. A galaxy whose signs sum to 0 has no dominant chirality
         and is dropped.

    Parameters
    ----------
    out_dir, basename : str
        Locate the CSV at ``out_dir/basename``.
    pre_post : str
        "pre" or "post"; used in the file name and the sign column name.
    **kwargs
        field_pa : str, default "pitch_angle"
        field_alen : str, default "arc_length"
        name_col : str, default "gxyName"

    Returns
    -------
    pandas.DataFrame
        Columns ``[name_col, field_pa, field_alen, f"{pre_post}_sign"]`` with a
        fresh RangeIndex.
    """
    field_pa   = kwargs.get("field_pa"  , "pitch_angle")
    field_alen = kwargs.get("field_alen", "arc_length")
    name_col   = kwargs.get("name_col"  , "gxyName")

    sign_col     = f"{pre_post}_sign"
    dom_sign_col = f"{sign_col}_dom"

    fname = pj(out_dir, basename, f"{basename}_{pre_post}_galfit_galaxy_arcs.csv")
    arcs_df = pd.read_csv(
        fname,
        index_col = name_col,
        usecols   = [name_col, field_pa, field_alen],
        dtype     = {field_pa : float, field_alen : float},
    )
    arcs_df.index = arcs_df.index.map(str)

    # Filter out pure circles and near circles.
    arcs_df = arcs_df[abs(arcs_df[field_pa]) > 1]

    top3 = arcs_df.groupby(name_col).head(3).reset_index()
    # Use the configured pitch-angle column, not a hard-coded attribute.
    top3[sign_col] = np.sign(top3[field_pa])

    # Dominant chirality = sign of the summed arc signs (ties give 0).
    dom_sign = np.sign(top3.groupby(name_col)[sign_col].sum())
    top3 = top3.join(dom_sign, rsuffix = "_dom", on = name_col)

    agrees = top3[dom_sign_col] == top3[sign_col]
    top2 = (
        top3[agrees]
        .groupby(name_col)
        .head(2)
        .reset_index(drop = True)
        .drop(columns = [dom_sign_col])
    )
    return top2

In [None]:
def prepare_arcs_output(sparc_output_arcs_top2, pre_post, **kwargs):
    """
    Pivot the (at most) two leading arcs per galaxy into one wide row per galaxy.

    A galaxy with only one arc gets a placeholder second arc with the same arc
    length and a pitch angle of 0, so every galaxy has exactly two arcs.
    Arcs are numbered 1 and 2 in the order they appear for that galaxy, so
    pa1/alen1 and pa2/alen2 always describe the same arc of the same galaxy.

    Parameters
    ----------
    sparc_output_arcs_top2 : pandas.DataFrame
        Long-format arcs with ``name_col``, ``field_pa`` and ``field_alen``
        columns, at most two rows per galaxy (as produced by
        ``load_galaxy_arcs_csv``).
    pre_post : str
        Prefix for the output columns ("pre" or "post").
    **kwargs
        field_pa : str, default "pitch_angle"
        field_alen : str, default "arc_length"
        name_col : str, default "gxyName"

    Returns
    -------
    pandas.DataFrame
        Indexed by galaxy name (sorted), with columns
        ``[{pre_post}_alen1, {pre_post}_alen2, {pre_post}_pa1, {pre_post}_pa2]``.
    """
    field_pa   = kwargs.get("field_pa"  , "pitch_angle")
    field_alen = kwargs.get("field_alen", "arc_length")
    name_col   = kwargs.get("name_col"  , "gxyName")

    out_cols = [f"{pre_post}_alen1", f"{pre_post}_alen2", f"{pre_post}_pa1", f"{pre_post}_pa2"]

    if sparc_output_arcs_top2.empty:
        return pd.DataFrame(columns = out_cols, index = pd.Index([], name = name_col), dtype = float)

    arcs = sparc_output_arcs_top2[[name_col, field_pa, field_alen]]

    # Duplicate single-arm galaxies with a zero pitch angle (arc length kept).
    single_arm = arcs[~arcs.duplicated(name_col, keep = False)].copy()
    single_arm[field_pa] = 0.0
    filled_in = pd.concat([arcs, single_arm], ignore_index = True)

    # Number arcs per galaxy rather than by global row parity. The placeholders
    # are appended at the end, so parity would mix up arcs from different galaxies.
    filled_in["arc_rank"] = filled_in.groupby(name_col).cumcount() + 1

    wide = filled_in.pivot(index = name_col, columns = "arc_rank", values = [field_pa, field_alen])
    sp_out = wide.loc[:, [(field_alen, 1), (field_alen, 2), (field_pa, 1), (field_pa, 2)]]
    return sp_out.set_axis(out_cols, axis = 1)

In [None]:
def before_after_galfit_comparison(all_sparc_out, pre_sparc_output_csv, post_sparc_output_csv):
    """
    Compare SpArcFiRe arc measurements before and after GALFIT for each galaxy.

    Adds to a copy of ``all_sparc_out``:
      - ``chiral_agreement``: sign(pre_pa1) == sign(post_pa1), computed before
        the pitch-angle columns are made absolute.
      - ``min_pre`` / ``min_post``: the pre/post absolute pitch angles from the
        arc pairing (1-1, 2-2, 1-2, 2-1) with the smallest difference; NaN when
        every pairing is NaN.
      - ``min_arm_pa_diff``: |min_post - min_pre|.
      - ``pa_diff_galaxy``: |galaxy_post_pa - galaxy_pre_pa| (aligned on index).
      - ``min_pa_diff``: the smaller of the two differences above.
      - ``alen_ratio``: min/max of the post-GALFIT arc lengths.
    The pre/post pa1/pa2 columns are replaced with their absolute values.

    Parameters
    ----------
    all_sparc_out : pandas.DataFrame
        Must contain pre/post ``_pa1``, ``_pa2`` and post ``_alen1``, ``_alen2`` columns.
    pre_sparc_output_csv : pandas.DataFrame
        Has a ``galaxy_pre_pa`` column.
    post_sparc_output_csv : pandas.DataFrame
        Has a ``galaxy_post_pa`` column.

    Returns
    -------
    pandas.DataFrame
    """
    comparison_df = all_sparc_out.copy(deep = True)

    comparison_df["chiral_agreement"] = np.sign(comparison_df["pre_pa1"]) == np.sign(comparison_df["post_pa1"])

    for col in ("pre_pa1", "pre_pa2", "post_pa1", "post_pa2"):
        comparison_df[col] = comparison_df[col].abs()

    # Candidate pre/post arc pairings, in tie-break priority order.
    pairings = {
        "1-1": ("pre_pa1", "post_pa1"),
        "2-2": ("pre_pa2", "post_pa2"),
        "1-2": ("pre_pa1", "post_pa2"),
        "2-1": ("pre_pa2", "post_pa1"),
    }
    pair_diffs = pd.DataFrame(
        {key: (comparison_df[pre] - comparison_df[post]).abs() for key, (pre, post) in pairings.items()},
        index = comparison_df.index,
    )

    # idxmin is only evaluated on rows with at least one finite pairing. All-NaN
    # rows are deprecated for idxmin (and raise in newer pandas); they get NaN.
    has_any_pair = pair_diffs.notna().any(axis = 1)
    best_pair = pd.Series(np.nan, index = comparison_df.index, dtype = object)
    if has_any_pair.any():
        best_pair.loc[has_any_pair] = pair_diffs.loc[has_any_pair].idxmin(axis = 1)

    min_pre  = pd.Series(np.nan, index = comparison_df.index)
    min_post = pd.Series(np.nan, index = comparison_df.index)
    for key, (pre, post) in pairings.items():
        chosen   = best_pair == key
        min_pre  = min_pre.mask(chosen, comparison_df[pre])
        min_post = min_post.mask(chosen, comparison_df[post])

    comparison_df["min_pre"]         = min_pre
    comparison_df["min_post"]        = min_post
    comparison_df["min_arm_pa_diff"] = (min_post - min_pre).abs()
    comparison_df["pa_diff_galaxy"]  = (post_sparc_output_csv["galaxy_post_pa"] - pre_sparc_output_csv["galaxy_pre_pa"]).abs()

    comparison_df["min_pa_diff"]     = comparison_df[["min_arm_pa_diff", "pa_diff_galaxy"]].min(axis = 1)

    post_alens = comparison_df[["post_alen1", "post_alen2"]]
    comparison_df["alen_ratio"] = post_alens.min(axis = 1)/post_alens.max(axis = 1)

    return comparison_df

In [None]:
def gather_everything(residual_df, before_after_galfit_df, method):
    """
    Left-join the before/after comparison onto the residual results.

    Rows whose index label is null are dropped. The result is sorted ascending
    by the ``method`` score column.
    """
    combined   = residual_df.join(before_after_galfit_df)
    named_rows = combined.index.notnull()
    return combined.loc[named_rows].sort_values(by = method)

In [None]:
def determine_success(
    full_df, 
    **kwargs
):
    """
    Split galaxies into successful / not-successful fits using score and SpArcFiRe cutoffs.

    A galaxy succeeds when all four conditions hold:
      - score column ``method`` <= ``residual_cutoff_val``
      - ``min_pa_diff`` < ``pa_cutoff_val``
      - ``alen_ratio`` > ``alen_cutoff_val``
      - ``chiral_agreement`` is True (or False when ``flip_chiral_agreement``).
        A missing (NaN) chirality never passes.

    Parameters
    ----------
    full_df : pandas.DataFrame
    **kwargs
        in_dir, out_dir : str
            Used only to count galaxies (verbose mode) when no total is known.
        sparcfire_processed : DataFrame, optional
            Galaxies SpArcFiRe could reprocess. Computed from ``full_df`` when omitted.
        flip_chiral_agreement : bool, default False
        residual_cutoff_val : float, default 0.007
        pa_cutoff_val : float, default 10
        alen_cutoff_val : float, default 0.5
        method : str, default "nmr_x_1-p"
            Score column to cut on.
        total_galaxies : int, optional
            Total galaxy count for the verbose report. Falls back to a
            notebook-global ``total_galaxies``, then to ``get_total_galaxies``.
        verbose : bool, default True

    Returns
    -------
    (success_df, not_success_df) : tuple of DataFrame copies
    """
    in_dir                = kwargs.get("in_dir", "sparcfire-in") 
    out_dir               = kwargs.get("out_dir","sparcfire-out")
    sparcfire_processed   = kwargs.get("sparcfire_processed", None)
    flip_chiral_agreement = kwargs.get("flip_chiral_agreement", False)
    residual_cutoff_val   = kwargs.get("residual_cutoff_val", 0.007)
    pa_cutoff_val         = kwargs.get("pa_cutoff_val", 10)
    alen_cutoff_val       = kwargs.get("alen_cutoff_val", 0.5)
    method                = kwargs.get("method", "nmr_x_1-p")
    total_galaxies        = kwargs.get("total_galaxies", None)
    verbose               = kwargs.get("verbose", True)

    residual_cutoff = full_df[method] <= residual_cutoff_val
    pa_cutoff       = full_df["min_pa_diff"] < pa_cutoff_val
    alen_cutoff     = full_df["alen_ratio"] > alen_cutoff_val

    # chiral_agreement is NaN for galaxies missing from the comparison. A plain
    # astype(bool) would have turned those into True; treat them as failing.
    chirality   = full_df["chiral_agreement"]
    sign_cutoff = chirality.eq(True)
    if flip_chiral_agreement:
        sign_cutoff = chirality.notna() & ~sign_cutoff

    passes         = residual_cutoff & pa_cutoff & alen_cutoff & sign_cutoff
    success_df     = full_df[passes].copy()
    not_success_df = full_df[~passes].copy()

    if verbose:
        print(f"{sum(pa_cutoff)} pass pitch angle cutoff")
        print(f"{sum(alen_cutoff)} pass arm length ratio cutoff")
        print(f"{sum(sign_cutoff)} pass chiral agreement")

        n_success, n_total = len(success_df), len(full_df)
        pct = 100*n_success/n_total if n_total else 0.0
        print(f"{n_success} or {pct:.2f}% ({n_success}/{n_total}) succeed by SpArcFiRe+Score")

        # Previously the None check was inverted, which left this as None.
        if sparcfire_processed is None:
            sparcfire_processed = full_df.dropna(subset = ["min_arm_pa_diff", "pa_diff_galaxy"], how = "all")

        # Previously this relied on an undefined global; resolve it explicitly.
        if total_galaxies is None:
            total_galaxies = globals().get("total_galaxies")
        if total_galaxies is None:
            total_galaxies = get_total_galaxies(in_dir, out_dir)

        print(f"{total_galaxies - len(sparcfire_processed)}/{total_galaxies} models failed reprocessing by SpArcFiRe")

    return success_df, not_success_df

In [None]:
def extract_by_eye_data(
    out_dir, 
    basename, 
    residual_df, 
    full_df,
    **kwargs
):
    """
    Read the by-eye success / not-success galaxy lists for a run.

    Each non-blank line of ``{basename}_by-eye_success.txt`` and
    ``{basename}_by-eye_not_success.txt`` (in ``out_dir/basename``) contributes
    the text before its first underscore as a galaxy name.

    Parameters
    ----------
    out_dir, basename : str
        Locate the by-eye files.
    residual_df : pandas.DataFrame
        Its length is the percentage denominator in verbose mode (unless ``subset``).
    full_df : pandas.DataFrame
        Only galaxies in its index are returned.
    **kwargs
        sparcfire_processed : DataFrame, optional
            Galaxies SpArcFiRe could reprocess. Computed from ``full_df`` when omitted.
        subset : int, optional
            Override for the percentage denominator.
        verbose : bool, default True

    Returns
    -------
    (by_eye_success_galaxies, by_eye_not_success_galaxies) : tuple of list[str]
        File order is preserved; names not in ``full_df`` are dropped.
    """
    sparcfire_processed = kwargs.get("sparcfire_processed", None)
    subset              = kwargs.get("subset", None)
    verbose             = kwargs.get("verbose", True)

    def _read_gnames(suffix):
        # Galaxy name = text before the first "_" on each non-blank line.
        with open(f"{pj(out_dir, basename, basename)}_{suffix}.txt", "r") as f:
            return [line.split("_")[0].strip() for line in f if line.strip()]

    def _pct(num, den):
        # Percentage string that does not divide by zero on empty lists.
        return f"{num/den*100:.2f}%" if den else "n/a"

    raw_by_eye_success_galaxies     = _read_gnames("by-eye_success")
    raw_by_eye_not_success_galaxies = _read_gnames("by-eye_not_success")

    by_eye_success_galaxies     = [i for i in raw_by_eye_success_galaxies if i in full_df.index]
    by_eye_not_success_galaxies = [i for i in raw_by_eye_not_success_galaxies if i in full_df.index]

    if verbose:
        # Previously the None check was inverted, so this stayed None and crashed below.
        if sparcfire_processed is None:
            sparcfire_processed = full_df.dropna(subset = ["min_arm_pa_diff", "pa_diff_galaxy"], how = "all")

        total = len(residual_df)
        if subset:
            total = subset
            print(f"Working on a subset of {total} galaxies")

        report = (
            ("successful",     by_eye_success_galaxies,     raw_by_eye_success_galaxies),
            ("not successful", by_eye_not_success_galaxies, raw_by_eye_not_success_galaxies),
        )
        for i, (label, kept, raw) in enumerate(report):
            if i:
                print()
            raw_set = set(raw)
            n_processed = sum(1 for g in sparcfire_processed.index if g in raw_set)
            align = len(f"{len(kept)}/{len(raw)}")

            print(f"Number of *total* by eye {label} galaxies")
            print(f"{len(raw):<{align}} => {_pct(len(raw), total)}")
            print(f"Number of by eye {label} galaxies that SpArcFiRe *could* process")
            print(f"{n_processed}/{len(raw)} => {_pct(n_processed, len(raw))}")

    return by_eye_success_galaxies, by_eye_not_success_galaxies

In [None]:
def calculate_false_positive_negative(
    by_eye_success_galaxies, 
    by_eye_not_success_galaxies, 
    success_df, 
    not_success_df, 
    full_df,
    method  = "nmr_x_1-p",
    verbose = True
):
    """
    Compare automatic success classification against by-eye labels.

    False positive rate = FP / (FP + TN): by-eye not-successful galaxies that
    were classified successful, over all by-eye not-successful galaxies that
    were classified. False negative rate = FN / (FN + TP), defined analogously.
    (The previous denominator, FP + #by-eye-not-success, counted the false
    positives twice.)

    Parameters
    ----------
    by_eye_success_galaxies, by_eye_not_success_galaxies : list of str
        Galaxy names labelled by eye; all must be in ``full_df``.
    success_df, not_success_df : pandas.DataFrame
        Automatic classification, indexed by galaxy name.
    full_df : pandas.DataFrame
        Source of the returned by-eye frames.
    method : str
        Column used to sort the returned frames.
    verbose : bool
        Print both rates.

    Returns
    -------
    (by_eye_success_df, by_eye_not_success_df, FP_rate, FN_rate)
        The rates are strings of the form ``"a/(a + b)"``.
    """
    by_eye_pos = set(by_eye_success_galaxies)
    by_eye_neg = set(by_eye_not_success_galaxies)
    auto_pos   = set(success_df.index)
    auto_neg   = set(not_success_df.index)

    false_positive = by_eye_neg & auto_pos
    true_negative  = by_eye_neg & auto_neg
    false_negative = by_eye_pos & auto_neg
    true_positive  = by_eye_pos & auto_pos

    by_eye_success_df     = full_df.loc[by_eye_success_galaxies].sort_values(by = method)
    by_eye_not_success_df = full_df.loc[by_eye_not_success_galaxies].sort_values(by = method)

    n_fp, n_tn = len(false_positive), len(true_negative)
    n_fn, n_tp = len(false_negative), len(true_positive)

    FP_rate = f"{n_fp}/({n_fp} + {n_tn})"
    FN_rate = f"{n_fn}/({n_fn} + {n_tp})"

    if verbose:
        # Compute directly instead of eval()-ing the display string; avoid 0/0.
        fp_pct = f"{100*n_fp/(n_fp + n_tn):.2f}%" if (n_fp + n_tn) else "n/a"
        fn_pct = f"{100*n_fn/(n_fn + n_tp):.2f}%" if (n_fn + n_tp) else "n/a"
        print(f"False positive rate (by eye) -- {FP_rate} = {fp_pct}")
        print(f"False negative rate (by eye) -- {FN_rate} = {fn_pct}")

    return by_eye_success_df, by_eye_not_success_df, FP_rate, FN_rate

In [None]:
def vprint(verbosity, *args, **kwargs):
    """Forward ``args``/``kwargs`` to ``print`` only when ``verbosity`` is truthy."""
    if not verbosity:
        return
    print(*args, **kwargs)

In [None]:
def load_full_df_with_petromags(
    full_df,
    petromag_df,
    color_band
):
    """
    Join Petrosian magnitudes and add per-component magnitude differences.

    Joins ``petromag_df["petroMag_{color_band}"]`` onto ``full_df``. Then, for
    every ``magnitude_sersic_<n>`` column present (in ascending numeric ``n``),
    it adds ``petromag_{color_band}_diff_<n> = magnitude_sersic_<n> - petroMag``.

    Parameters
    ----------
    full_df : pandas.DataFrame
    petromag_df : pandas.DataFrame
        Indexed like ``full_df``, with a ``petroMag_{color_band}`` column.
    color_band : str
        Photometric band suffix, e.g. "r".

    Returns
    -------
    pandas.DataFrame
        A new frame; the input is not modified.
    """
    import re

    petromag_col = f"petroMag_{color_band}"

    full_df = full_df.join(petromag_df[petromag_col])

    # Parse component numbers as integers. Sorting the column names as strings
    # and taking the last character misses components >= 10.
    sersic_pattern = re.compile(r"^magnitude_sersic_(\d+)$")
    component_nums = sorted(
        int(match.group(1))
        for match in (sersic_pattern.match(str(col)) for col in full_df.columns)
        if match
    )

    for i in component_nums:
        full_df[f"petromag_{color_band}_diff_{i}"] = full_df[f"magnitude_sersic_{i}"] - full_df[petromag_col]

    return full_df

In [None]:
def residual_analysis(
    **kwargs
):
    """
    Run the full per-run analysis pipeline and classify each galaxy's fit.

    Steps: load the residual scores; load SpArcFiRe galaxy/arc output from
    before and after GALFIT; compare them; join everything; optionally add
    Petrosian magnitude differences; decide success; and optionally compare
    against the by-eye labels.

    Keyword arguments
    -----------------
    in_dir, out_dir : str
        SpArcFiRe input/output directories (defaults "sparcfire-in"/"sparcfire-out").
    basename : str
        Run name; run output lives in ``out_dir/basename``.
    method : str, default "nmr_x_1-p"
        Score column to compute and cut on.
    flip_chiral_agreement : bool, default False
    pa_cutoff_val : float, default 10
        Maximum allowed pitch-angle difference.
    residual_cutoff_val : float, default 0.007
        Maximum allowed score. (Previously defaulted to 0.5; it had been
        swapped with alen_cutoff_val.)
    alen_cutoff_val : float, default 0.5
        Minimum post-GALFIT arc-length ratio. (Previously defaulted to 0.007.)
    incl_by_eye : bool, default True
        Also read the by-eye label files and compute FP/FN rates.
    by_eye_subset : int, optional
    color_band : str, default "r"
    petromag_df : DataFrame, optional
        When non-empty, Petrosian magnitude differences are added.
    verbose : bool, default False
        Print progress messages.

    Returns
    -------
    all_results_nt
        (full_df, success_df, not_success_df, by_eye_success_df,
        by_eye_not_success_df), each with a "runname" column set to ``basename``.
    """
    in_dir                = kwargs.get("in_dir", "sparcfire-in")
    out_dir               = kwargs.get("out_dir", "sparcfire-out")
    basename              = kwargs.get("basename", "") 
    method                = kwargs.get("method", "nmr_x_1-p")
    flip_chiral_agreement = kwargs.get("flip_chiral_agreement", False)
    pa_cutoff_val         = kwargs.get("pa_cutoff_val", 10)
    # These two defaults were swapped; they now match load_residual_df/determine_success.
    residual_cutoff_val   = kwargs.get("residual_cutoff_val", 0.007)
    alen_cutoff_val       = kwargs.get("alen_cutoff_val", 0.5)
    incl_by_eye           = kwargs.get("incl_by_eye", True)
    by_eye_subset         = kwargs.get("by_eye_subset", None)
    color_band            = kwargs.get("color_band", "r")
    petromag_df           = kwargs.get("petromag_df", pd.DataFrame())
    verbose               = kwargs.get("verbose", False)

    vprint(verbose, "Load residual.")
    residual_df = load_residual_df(
        out_dir, 
        basename, 
        method = method, 
        residual_cutoff_val = residual_cutoff_val
    )

    vprint(verbose, "Load pre galaxy csv.")
    pre_sparc_output_csv        = load_galaxy_csv(out_dir,      basename, pre_post = "pre")

    vprint(verbose, "Load pre galaxy arcs csv.")
    pre_sparc_output_arcs_top2  = load_galaxy_arcs_csv(out_dir, basename, pre_post = "pre")

    vprint(verbose, "Load post galaxy csv.")
    post_sparc_output_csv       = load_galaxy_csv(out_dir,      basename, pre_post = "post")

    vprint(verbose, "Load post galaxy arcs csv.")
    post_sparc_output_arcs_top2 = load_galaxy_arcs_csv(out_dir, basename, pre_post = "post")

    # ================================================================================================================

    vprint(verbose, "Prep pre galaxy arcs df")
    pre_sp_out    = prepare_arcs_output(pre_sparc_output_arcs_top2,  pre_post = "pre")
    vprint(verbose, "Prep post galaxy arcs df")
    post_sp_out   = prepare_arcs_output(post_sparc_output_arcs_top2, pre_post = "post")

    vprint(verbose, "And combine")
    all_sparc_out = pd.concat([pre_sp_out, post_sp_out], axis = 1)

    # ================================================================================================================

    vprint(verbose, "Compare SpArcFiRe analysis before and after")
    before_after_galfit_df = before_after_galfit_comparison(
        all_sparc_out, 
        pre_sparc_output_csv, 
        post_sparc_output_csv
    )

    vprint(verbose, "Bring everything together")
    full_df = gather_everything(residual_df, before_after_galfit_df, method)

    if not petromag_df.empty:
        vprint(verbose, "Load dataframe with petromag information")
        full_df = load_full_df_with_petromags(full_df, petromag_df, color_band)

    # Galaxies SpArcFiRe could reprocess: at least one PA difference is available.
    sparcfire_processed = full_df.dropna(subset = ["min_arm_pa_diff", "pa_diff_galaxy"], how = "all")

    vprint(verbose, "Determine success")
    success_df, not_success_df = determine_success(
        full_df, 
        in_dir                 = in_dir, 
        out_dir                = out_dir, 
        flip_chiral_agreement  = flip_chiral_agreement,
        sparcfire_processed    = sparcfire_processed,
        pa_cutoff_val          = pa_cutoff_val, 
        residual_cutoff_val    = residual_cutoff_val,
        alen_cutoff_val        = alen_cutoff_val,
        # Cut on the same score column that was computed above.
        method                 = method
    )

    full_df["success"] = full_df.index.isin(success_df.index)
    print()

    # ================================================================================================================

    by_eye_success_df     = pd.DataFrame()
    by_eye_not_success_df = pd.DataFrame()

    if incl_by_eye:
        vprint(verbose, "Extract by-eye evaluation")
        by_eye_success_galaxies, by_eye_not_success_galaxies = extract_by_eye_data(
            out_dir, 
            basename, 
            residual_df, 
            full_df, 
            subset = by_eye_subset,
            sparcfire_processed = sparcfire_processed
        )
        print()

        # Guard against by-eye names that did not make it into full_df.
        by_eye_success_limited     = list(set(by_eye_success_galaxies).intersection(full_df.index))
        by_eye_not_success_limited = list(set(by_eye_not_success_galaxies).intersection(full_df.index))

        vprint(verbose, "Calculate by-eye statistics")
        by_eye_success_df, by_eye_not_success_df, FP_rate, FN_rate = calculate_false_positive_negative(
            by_eye_success_limited, 
            by_eye_not_success_limited, 
            success_df, 
            not_success_df, 
            full_df,
            method = method
        )

        full_df["by_eye_success"] = full_df.index.isin(by_eye_success_df.index)

    results_nt = all_results_nt(full_df, success_df, not_success_df, by_eye_success_df, by_eye_not_success_df)
    # Tag every frame with the run it came from (frames are modified in place).
    for df in results_nt:
        df["runname"] = basename

    return results_nt

In [None]:
def _safe_ratio(numerator, denominator):
    # Ratio helper for the printed by-eye statistics: returns NaN instead of raising
    # ZeroDivisionError when a category (e.g. no by-eye successes) is empty.
    return numerator / denominator if denominator else float("nan")


def combine_multi_run_results(
    method, 
    *args,
    **kwargs
):
    """
    Combine the results of several runs over the same galaxies into per-galaxy
    success flags and a single chosen ("best") fit per galaxy.

    Parameters
    ----------
    method : str
        Name of the score column present in every ``full_df``. The run with the
        smallest value for a galaxy (``idxmin``) is treated as that galaxy's best run.
    *args : all_results_nt
        One results namedtuple per run. Each ``full_df`` must contain ``method``
        and a boolean ``success`` column; when ``incl_by_eye`` is True it must also
        contain ``by_eye_success``, and ``by_eye_success_df`` / ``by_eye_not_success_df``
        must be populated.
    **kwargs
        df_names : sequence of str
            Label for each run, in the same order as ``args``. If omitted or of the
            wrong length, ``df_0, df_1, ..., df_n`` is used.
        incl_by_eye : bool, default True
            Also compute the by-eye comparison columns and ``by_eye_success_df``.
        by_eye_subset : int or None
            Size of the by-eye subset; only used as the denominator in printed output.
        verbose : bool, default True
            Print summary statistics.
        total_galaxies : int
            Denominator for the printed totals. Defaults to the number of galaxies
            across all runs.

    Returns
    -------
    combined_results_nt
        ``(bool_df, full_df, success_df, by_eye_success_df)``; ``by_eye_success_df``
        is None when ``incl_by_eye`` is False.

    Raises
    ------
    ValueError
        If no results are supplied.
    """
    from functools import reduce
    import operator

    if not args:
        raise ValueError("combine_multi_run_results needs at least one set of results to combine.")

    df_names      = kwargs.get("df_names", [])
    incl_by_eye   = kwargs.get("incl_by_eye", True)
    by_eye_subset = kwargs.get("by_eye_subset", None)
    verbose       = kwargs.get("verbose", True)

    num_dfs = len(args)
    print(f"Joining {num_dfs} attempts...")

    # Resolve run labels up front so that best_fit, best_fit_by_eye and the final
    # concatenations all use the same names regardless of `verbose`. Previously an
    # empty df_names made the final pd.concat fail, and best_fit_by_eye was only
    # renamed when verbose was on.
    default_names = [f"df_{i}" for i in range(num_dfs)]
    if not df_names:
        df_names = default_names
    elif len(df_names) != num_dfs:
        if verbose:
            print("Length of dataframe names supplied should be equal to the number of dataframes supplied.")
            print("Leaving current convention in the dataframe (df_0, df_1, ..., df_n)")
        df_names = default_names
    else:
        df_names = list(df_names)

    def _any(conditions):
        # Element-wise OR across an iterable of boolean Series/arrays (cond_0 | cond_1 | ...)
        return reduce(operator.or_, conditions)

    # Rename the score column of every non-primary run so the runs can sit side by side
    primary_full_df = deepcopy(args[0].full_df)
    all_full_dfs    = [primary_full_df]
    all_methods     = [method]
    for i, arg in enumerate(args[1:]):
        alt_method = f"{i}_{method}"
        all_methods.append(alt_method)
        all_full_dfs.append(arg.full_df.rename(columns = {method : alt_method}))

    method_to_name = dict(zip(all_methods, df_names))

    # BY RESIDUAL -- outer join on galaxy name, then pick the lowest-scoring run per galaxy
    combined_bool_df = pd.concat([df[m] for df, m in zip(all_full_dfs, all_methods)], axis = 1)
    combined_bool_df["minima"] = combined_bool_df.idxmin(axis = 1)

    # success_n | success_m | ...
    combined_bool_df["by_sparcfire_score_success"] = _any(df.success for df in all_full_dfs)

    minima_conditions  = [combined_bool_df.minima == m for m in all_methods]
    success_conditions = [df.success for df in all_full_dfs]

    # minima -> success_minima
    combined_bool_df["by_sparcfire_and_best_score_success"] = _any(
        is_min & is_success for is_min, is_success in zip(minima_conditions, success_conditions)
    )

    combined_bool_df["best_fit"] = combined_bool_df[combined_bool_df.by_sparcfire_score_success].minima.replace(
        method_to_name
    )

    if incl_by_eye:
        # Use by eye success df to account for by eye subsets (and to shorten the array) rather than info in full_df
        by_eye_success_conditions     = [combined_bool_df.index.isin(res.by_eye_success_df.index) for res in args]
        all_by_eye_success_gnames     = list({gname for res in args for gname in res.by_eye_success_df.index})
        all_by_eye_not_success_gnames = list({gname for res in args for gname in res.by_eye_not_success_df.index})

        # (success_m | success_n | ...) & by eye
        by_sparcfire_success_by_eye = (
            combined_bool_df.index.isin(all_by_eye_success_gnames) & combined_bool_df.by_sparcfire_score_success
        )
        combined_bool_df["residual_success_by_eye"] = by_sparcfire_success_by_eye

        # How well does choosing the smallest residual score across all runs work in picking a successful fit
        # when compared with the by eye analysis?
        # minima -> (success_minima & by eye)
        combined_bool_df["residual_minima_success_by_eye"] = _any(
            is_min & by_eye for is_min, by_eye in zip(minima_conditions, by_eye_success_conditions)
        )

        # (minima -> [success_minima & by eye]) | ([success_m | success_n] & by eye)
        combined_bool_df["by_minima_or_sparcfire_success_by_eye"] = (
            by_sparcfire_success_by_eye | combined_bool_df.residual_minima_success_by_eye
        )

        # By-eye best fit: the single run judged successful by eye; if several were,
        # fall back to the lowest-scoring run; if none, None.
        best_fit_by_eye = []
        for gname, minima in combined_bool_df.minima.items():
            passing = [
                m for m, full_df in zip(all_methods, all_full_dfs)
                if gname in full_df.index and full_df.loc[gname, "by_eye_success"]
            ]
            if len(passing) > 1:
                chosen = minima
            elif passing:
                chosen = passing[0]
            else:
                chosen = None
            best_fit_by_eye.append(method_to_name.get(chosen))

        combined_bool_df["best_fit_by_eye"] = best_fit_by_eye
        combined_bool_df["by_eye_success"]  = combined_bool_df.best_fit_by_eye.notna()

    total_galaxies = kwargs.get("total_galaxies", len(combined_bool_df))

    if verbose:
        print(f"Total success by combining SpArcFiRe + score: {combined_bool_df.by_sparcfire_score_success.sum()}/{total_galaxies}")
        print("i.e. success_n | success_m | ...")
        print()
        print(f"Total success by combining SpArcFiRe + best score: {combined_bool_df.by_sparcfire_and_best_score_success.sum()}/{total_galaxies}")
        print("i.e. minima -> success_minima")
        print(mini_sep)

        if incl_by_eye:
            print("Checking against the by eye determination...")
            _total_galaxies = total_galaxies
            if by_eye_subset:
                _total_galaxies = by_eye_subset
                print("Using a subset of galaxies for the by eye determination...")

            total_by_eye = combined_bool_df.by_eye_success.sum()
            print(f"Total success by eye: {total_by_eye}/{_total_galaxies}")
            print()
            print(f"By eye captured by either score: {combined_bool_df.residual_success_by_eye.sum()}/{total_by_eye}")
            print("i.e. (success_m | success_n | ...) & by eye")
            print()
            print(f"By eye captured by best score: {combined_bool_df.residual_minima_success_by_eye.sum()}/{total_by_eye}")
            print("i.e. minima -> (success_minima & by eye)")
            print()
            print(f"By eye captured by SpArcFiRe or choosing best score between the two runs: {combined_bool_df.by_minima_or_sparcfire_success_by_eye.sum()}/{total_by_eye}")
            print("i.e. (minima -> [success_minima & by eye]) | ([success_m | success_n | ...] & by eye)")
            print(mini_sep)

            score_success = combined_bool_df.by_sparcfire_score_success
            bss  = set(combined_bool_df[score_success].index)
            bsns = set(combined_bool_df[~score_success].index)

            TP = set(all_by_eye_success_gnames)
            # Exclude the ones found in the success galaxies because some runs may find success where the others didn't
            TN = set(all_by_eye_not_success_gnames) - TP
            assert len(TP) + len(TN) == _total_galaxies, f"True positive and true negative don't add up to {_total_galaxies}!"

            FP = bss & TN
            FN = bsns & TP

            sparc_positive = bss & TP
            sparc_negative = bsns & TN
            fraction = _safe_ratio(len(sparc_positive), total_by_eye)

            not_by_eye   = ~combined_bool_df.by_eye_success
            denom        = not_by_eye[not_by_eye.index.isin(all_by_eye_not_success_gnames)].sum()
            neg_fraction = _safe_ratio(len(sparc_negative), denom)

            print(f"By eye success found by SpArcFiRe + score:  {len(sparc_positive)}/{total_by_eye} = {100*fraction:.2f}%")
            print(f"By eye not success found by SpArcFiRe + score:  {len(sparc_negative)}/{denom} = {100*neg_fraction:.2f}%")

            # FPR = FP/(FP + TN), FNR = FN/(FN + TP)
            FP_rate = f"{len(FP)} / ({len(FP)} + {len(TN)})"
            FN_rate = f"{len(FN)} / ({len(FN)} + {len(TP)})"
            fp_value = _safe_ratio(len(FP), len(FP) + len(TN))
            fn_value = _safe_ratio(len(FN), len(FN) + len(TP))

            print()
            print(f"False positive rate (by eye) -- {FP_rate} = {100*fp_value:.2f}%")
            print(f"False negative rate (by eye) -- {FN_rate} = {100*fn_value:.2f}%")

            # TODO: GENERATE CONFUSION MATRIX

    # Restore the shared score column name before stacking the runs (rename returns
    # new frames, so the caller's DataFrames are untouched).
    all_full_dfs = [
        full_df.rename(columns = {m : method}) for m, full_df in zip(all_methods, all_full_dfs)
    ]

    combined_full_df = pd.concat(all_full_dfs)

    # For each run, keep the rows of galaxies whose chosen best fit is that run
    combined_success_df = pd.concat(
        full_df.loc[combined_bool_df[combined_bool_df.best_fit == name].index, :]
        for name, full_df in zip(df_names, all_full_dfs)
    )

    combined_by_eye_success_df = None
    if incl_by_eye:
        # Get index, i.e. galaxy name from choosing the best fit then feed that into the full_dfs in all_full_dfs via loc
        # to grab the row
        combined_by_eye_success_df = pd.concat(
            full_df.loc[combined_bool_df[combined_bool_df.best_fit_by_eye == name].index, :]
            for name, full_df in zip(df_names, all_full_dfs)
        )

    return combined_results_nt(combined_bool_df, combined_full_df, combined_success_df, combined_by_eye_success_df)

In [188]:
def create_ecdf(
    x,
    df, 
    dict_o_kwargs
):
    """
    Build an (un-normalized, i.e. count-based) ECDF of column ``x`` of ``df``.

    Parameters
    ----------
    x : str
        Column to plot.
    df : pandas.DataFrame
        Source data.
    dict_o_kwargs : dict
        Options read here: ``marginal`` (passed to ``px.ecdf``), ``cutoff_val``,
        ``add_vline`` (vertical line at the cutoff) and ``add_hline`` (horizontal
        line at the number of rows below the cutoff).

    Returns
    -------
    plotly.graph_objects.Figure
    """
    fig = px.ecdf(
        df,
        x        = x,
        markers  = True, 
        lines    = False, 
        marginal = dict_o_kwargs.get("marginal"),
        ecdfnorm = None,
    )

    cutoff_val = dict_o_kwargs.get("cutoff_val")
    # Without a cutoff there is nothing to mark (and `df[x] < None` would raise)
    if cutoff_val is None:
        return fig

    if dict_o_kwargs.get("add_vline"):
        fig.add_vline(
            x                   = cutoff_val, 
            row                 = 1,
            line_color          = "cyan",
            annotation_text     = f"{cutoff_val}", 
            annotation_position = "bottom",
        )

    if dict_o_kwargs.get("add_hline"):
        # ecdfnorm=None plots counts, so the line sits at the number of values below the cutoff
        yval = int((df.loc[:, x] < cutoff_val).sum())
        fig.add_hline(
            y                   = yval, 
            row                 = 1,
            col                 = 1,
            line_color          = "magenta",
            annotation_text     = f"{yval}",
            annotation_position = "bottom left",
        )

    return fig

In [189]:
def create_scatter(
    x, 
    y, 
    df, 
    dict_o_kwargs
):
    """
    Build a scatter plot of ``y`` versus ``x`` from ``df``.

    Points are colored by ``dict_o_kwargs["color"]`` (if given) using the
    "Agsunset" continuous scale, optionally clamped to ``dict_o_kwargs["range_color"]``.

    Returns
    -------
    plotly.graph_objects.Figure
    """
    scatter_options = {
        "color"                  : dict_o_kwargs.get("color"),
        "color_continuous_scale" : "Agsunset",
        "range_color"            : dict_o_kwargs.get("range_color"),
    }

    return px.scatter(df, x = x, y = y, **scatter_options)

In [412]:
def create_histogram(
    x, 
    df, 
    dict_o_kwargs, 
):
    """
    Build a histogram of column ``x`` of ``df``, optionally faceted.

    Parameters
    ----------
    x : str
        Column to histogram.
    df : pandas.DataFrame
        Source data. It is NOT modified; ``hist_offset`` is applied to a copy.
    dict_o_kwargs : dict
        Options read here: ``hist_offset`` (added to ``x`` before binning),
        ``color``, ``color_discrete_sequence``, ``histnorm``, ``facet_col``,
        ``facet_row``, ``nbins`` (0 lets plotly choose) and ``marginal``.

    Returns
    -------
    plotly.graph_objects.Figure
    """
    hist_offset = dict_o_kwargs.get("hist_offset", 0) or 0

    # Shift a copy: the previous in-place `df[x] += offset` mutated the caller's
    # frame, so re-running a cell kept adding the offset.
    plot_df = df
    if hist_offset:
        plot_df = df.copy()
        plot_df[x] = plot_df[x] + hist_offset

    fig = px.histogram(
        plot_df,
        x                       = x,
        color                   = dict_o_kwargs.get("color"),
        color_discrete_sequence = dict_o_kwargs.get("color_discrete_sequence"),
        histnorm                = dict_o_kwargs.get("histnorm"),
        facet_col               = dict_o_kwargs.get("facet_col"),
        facet_row               = dict_o_kwargs.get("facet_row"),
        nbins                   = dict_o_kwargs.get("nbins", 0),
        marginal                = dict_o_kwargs.get("marginal"),
    )

    # Facet titles default to "column=value"; keep only the value
    if dict_o_kwargs.get("facet_col") or dict_o_kwargs.get("facet_row"):
        fig.for_each_annotation(lambda a: a.update(text = a.text.split("=")[-1]))

    return fig

In [334]:
def create_overlay_histogram(
    x,
    dfs,
    dict_o_kwargs
):
    """
    Overlay per-column histograms for several DataFrames, one subplot per DataFrame
    (at most 3 subplots per row).

    Parameters
    ----------
    x : str
        Label for every x-axis.
    dfs : dict[str, pandas.DataFrame]
        Subplot title -> frame. Every column of each frame becomes one overlaid
        histogram; column names are expected to end in ``_<int>`` (e.g. ``sersic_0``),
        which is shown 1-indexed in the legend.
    dict_o_kwargs : dict
        Options read here: ``histnorm`` (default "probability"), ``nbins`` (default 40),
        ``xaxis_range`` (used to fix shared bin edges when both ends are numbers),
        ``hist_offset`` (added to every value).

    Returns
    -------
    plotly.graph_objects.Figure

    Raises
    ------
    ValueError
        If ``dfs`` is empty.
    """
    import numbers

    max_cols    = 3
    histnorm    = dict_o_kwargs.get("histnorm", "probability")
    nbinsx      = dict_o_kwargs.get("nbins", 40)
    hist_offset = dict_o_kwargs.get("hist_offset", 0) or 0

    xmin, xmax  = None, None
    xaxis_range = dict_o_kwargs.get("xaxis_range", (xmin, xmax))

    # Compare against None explicitly: all((0, 5)) is False, so a range starting
    # at 0 used to be ignored.
    if (
        isinstance(xaxis_range, (tuple, list))
        and len(xaxis_range) == 2
        and all(bound is not None for bound in xaxis_range)
    ):
        xmin, xmax = xaxis_range

    xbins = None
    if nbinsx and isinstance(xmin, numbers.Real) and isinstance(xmax, numbers.Real):
        xbins = dict(
            start = xmin,
            end   = xmax,
            size  = (xmax - xmin)/nbinsx,
        )

    plotly_colors = px.colors.qualitative.Plotly
    colors        = list(plotly_colors)
    # Bulge -- Redder, Disk -- Bluer
    colors[0], colors[1] = plotly_colors[1], plotly_colors[0]

    titles = list(dfs.keys())
    frames = list(dfs.values())
    if not frames:
        raise ValueError("create_overlay_histogram needs at least one DataFrame to plot.")

    # Grid is at most `max_cols` wide; the old len(dfs)-wide grid indexed past the
    # end (or repeated panels) once more than 3 frames were supplied.
    num_cols = min(len(frames), max_cols)
    num_rows = ceil(len(frames)/max_cols)

    fig = make_subplots(
        num_rows, 
        num_cols,
        subplot_titles = titles,
        shared_yaxes   = True,
    )

    for idx, df in enumerate(frames):
        row, col = divmod(idx, max_cols)
        # Only the first subplot contributes legend entries
        showlegend = idx == 0

        # Reverse so the first component is drawn last (on top); legend order is reversed below
        reversed_cols   = list(reversed(df.columns))
        reversed_colors = reversed(colors[:len(df.columns)])

        for col_name, color in zip(reversed_cols, reversed_colors):
            name_parts     = col_name.split("_")
            name_parts[-1] = str(int(name_parts[-1]) + 1) # since we index at 0
            name           = " ".join(name_parts)

            fig.add_trace(
                go.Histogram(
                    x            = df[col_name] + hist_offset,
                    histnorm     = histnorm,
                    name         = name,
                    marker_color = color,
                    nbinsx       = nbinsx,
                    xbins        = xbins,
                    showlegend   = showlegend,
                    bingroup     = 1,
                ),
                row = row + 1, col = col + 1
            )

    # This applies the yaxis title text to just 'one' of the facet cols, i.e. the first
    fig.update_layout(
        barmode           = "overlay",
        legend_traceorder = "reversed",
        yaxis_title_text  = histnorm,
    )

    fig.update_traces(
        opacity           = 0.75,
        marker_line_width = 1,
        marker_line_color = "white",
    )

    # This applies x to *all* facet cols
    fig.update_xaxes(
        title_text = x
    )

    return fig

In [392]:
def create_plot(
    x, 
    runname, 
    plot_type, 
    df_or_dfs, 
    output_image_dir = "for_paper_images", 
    **kwargs
):
    """
    Build a plotly figure of the requested type, apply shared layout options,
    and optionally show and/or write it to disk.

    Parameters
    ----------
    x : str
        Column plotted on the x-axis (also used in the output filename).
    runname : str
        Label used in the output filename.
    plot_type : {"ecdf", "scatter", "histogram", "overlay_histogram"}
        Any other value returns None without plotting.
    df_or_dfs : pandas.DataFrame or dict[str, pandas.DataFrame]
        A single frame, or for "overlay_histogram" a mapping of subplot title -> frame.
    output_image_dir : str
        Directory the image is written to when ``write=True``.
    **kwargs
        Plot options forwarded to the create_* helpers (the keys of ``dict_o_kwargs``
        below), plus: title/title_x/title_y, xaxis_title/yaxis_title, log_x/log_y,
        xaxis_range/yaxis_range, height (default 800), width_multiplier (default 1.5),
        show, interactive, write, filetype (default "png") and clear_figure
        (default True).

    Returns
    -------
    plotly.graph_objects.Figure or None
        The figure. With the default ``clear_figure=True`` its data and layout are
        emptied before returning (to release them); pass ``clear_figure=False`` to
        get the populated figure back.
    """
    dict_o_kwargs = {
        "y"               : None,

        # Need these for histogram binning
        "xaxis_range"     : (None, None),
        "log_x"           : None,

        "color"                     : None,
        "color_discrete_sequence"   : None,
        "range_color"               : None,

        "hist_offset"     : 0,
        "histnorm"        : "",
        "marginal"        : None,
        "nbins"           : 0,

        "facet_col"       : None,
        "facet_row"       : None,

        "cutoff_val"      : 0.007,
        "add_vline"       : True,
        "add_hline"       : True,
    }

    # Updating with kwargs
    dict_o_kwargs = {key : kwargs.get(key, default) for key, default in dict_o_kwargs.items()}

    # NOTE: no plt.clf() here -- these are plotly figures, and matplotlib.pyplot's
    # import is commented out at the top of the notebook, so `plt` may not exist.
    if plot_type == "ecdf":
        fig = create_ecdf(x, df_or_dfs, dict_o_kwargs)
    elif plot_type == "scatter":
        fig = create_scatter(x, dict_o_kwargs.get("y"), df_or_dfs, dict_o_kwargs)
    elif plot_type == "histogram":
        fig = create_histogram(x, df_or_dfs, dict_o_kwargs)
    elif plot_type == "overlay_histogram":
        fig = create_overlay_histogram(x, df_or_dfs, dict_o_kwargs)
    else:
        return

    if kwargs.get("title"):
        fig.update_layout(
            title_text = kwargs.get("title"), 
            title_x    = kwargs.get("title_x"), 
            title_y    = kwargs.get("title_y")
        )

    # Single-panel plots target the main axes; overlay histograms apply to every subplot
    row, col = None, None
    if plot_type != "overlay_histogram":
        row, col = 1, 1

    if kwargs.get("xaxis_title"):
        fig.update_xaxes(title_text = kwargs.get("xaxis_title"), row = row, col = col)
    if kwargs.get("yaxis_title"):
        fig.update_yaxes(title_text = kwargs.get("yaxis_title"), row = row, col = col)

    if kwargs.get("log_x"):
        fig.update_xaxes(type = "log", row = row, col = col)
    if kwargs.get("log_y"):
        fig.update_yaxes(type = "log", row = row, col = col)

    xaxis_range = kwargs.get("xaxis_range")
    yaxis_range = kwargs.get("yaxis_range")
    if isinstance(xaxis_range, (tuple, list)):
        fig.update_xaxes(range = xaxis_range, row = row, col = col)
    if isinstance(yaxis_range, (tuple, list)):
        fig.update_yaxes(range = yaxis_range, row = row, col = col)

    height           = kwargs.get("height", 800)
    width_multiplier = kwargs.get("width_multiplier", 1.5)
    width            = height*width_multiplier

    if kwargs.get("show"):
        # Invert boolean so interactive = True means NOT static plot and vice-versa
        fig.show(
            config = {
                'staticPlot' : not kwargs.get("interactive", False),
                'toImageButtonOptions': {
                    'height': height,
                    'width' : width
                }
            }
        )

    if kwargs.get("write"):
        filetype = kwargs.get("filetype", "png")
        fig.write_image(
            f"{output_image_dir}/{plot_type}_{x}_{runname}.{filetype}", 
            height = height, 
            width  = width
        )

    # Release the figure's contents unless the caller wants the populated figure back
    if kwargs.get("clear_figure", True):
        fig.data   = []
        fig.layout = {}

    return fig

In [419]:
def create_all_plots(
    df_container,
    method, 
    basename, 
    output_image_dir,
    **kwargs
):
    
    #color_band   = kwargs.get("color_band", "r")
    incl_by_eye = kwargs.get("incl_by_eye", True)
    show        = kwargs.get("show", False)
    interactive = kwargs.get("interactive", False)
    write       = kwargs.get("write", False)

# ============================================================================================================================================================
# FULL ECDF
# ============================================================================================================================================================

    # Use a cutoff because there tends to be some extremely high values which skew the plot
    plot_df = df_container.full_df[df_container.full_df.loc[:, method] < kwargs.get("score_ecdf_cutoff", 0.015)].copy()
    
    _ = create_plot(
        x                = method,
        runname          = basename,
        plot_type        = "ecdf",
        df_or_dfs        = plot_df,
        output_image_dir = output_image_dir,
        xaxis_title      = method, #"KStest+NMR",
        marginal         = "histogram",
        # title       = f"1000 galaxies: ECDF for KStest+NMR on all models",
        # title_y     = 0.92
        cutoff_val       = kwargs.get("residual_cutoff_val", 0.007),
        show             = show,
        interactive      = interactive,
        write            = write
    )
    
# ============================================================================================================================================================
# ECDF OF BY EYE SCORE
# ============================================================================================================================================================
    
    if incl_by_eye:
        _ = create_plot(
            x                = method,
            runname          = f"{basename}_by-eye",
            plot_type        = "ecdf",
            df_or_dfs        = df_container.by_eye_success_df,
            output_image_dir = output_image_dir,
            xaxis_title      = method, #"KStest+NMR",
            marginal         = "histogram",
            #add_hline        = False,
            # title       = f"1000 galaxies: ECDF for KStest+NMR on by-eye successful model fits",
            # title_y     = 0.92
            show             = show,
            interactive      = interactive,
            write            = write
        )
    
# ============================================================================================================================================================
# SERSIC INDEX HISTOGRAMS
# ============================================================================================================================================================

    # x1   = "sersic_index_sersic_1"
    # x2   = "sersic_index_sersic_2"
    # x3   = "sersic_index_sersic_3"
    
    x    = "n"
    
    cols_to_use = sorted([x for x in df_container.full_df.columns if x.startswith("sersic_index_sersic")])        
    rename      = {col_name : f"sersic_{i}" for i, col_name in enumerate(cols_to_use)}
    
    plot_df_all          = df_container.full_df[cols_to_use].copy().rename(columns = rename)
    plot_df_success      = df_container.success_df[cols_to_use].copy().rename(columns = rename)
    
    to_plot  = {"all models" : plot_df_all, "success" : plot_df_success}
    runname   = f"{basename}_all-vs-success"
    
    if incl_by_eye:
        plot_df_by_eye = df_container.by_eye_success_df[cols_to_use].copy().rename(columns = rename)
        to_plot["by-eye success"] = plot_df_by_eye
        runname   = f"{basename}_all-vs-success-vs-by-eye"
    
    _ = create_plot(
        x                       = x,
        runname                 = runname,
        plot_type               = "overlay_histogram",
        df_or_dfs               = to_plot,
        output_image_dir        = output_image_dir,
        histnorm                = "probability",
        #marginal                = "rug",
        #color                   = "component",
        #color_discrete_sequence = colors,
        nbins                   = kwargs.get("nbins", 40),
        #facet_col               = fcol,
        xaxis_range             = kwargs.get("xaxis_range_sersic_hist", None),
        yaxis_range             = kwargs.get("yaxis_range_sersic_hist", None),
        #reversed_order          = True,
        # title       = f"{runname} galaxies: distribution of magnitudes for by-eye successful models"
        # title_y     = 0.85
        show                    = show,
        interactive             = interactive,
        write                   = write
    )
    
# ============================================================================================================================================================
# MAGNITUDE HISTOGRAMS 
# ============================================================================================================================================================

    # x1   = "magnitude_sersic_1"
    # x2   = "magnitude_sersic_2"
    # x3   = "magnitude_sersic_3"
        
    x    = "m"
        
    cols_to_use = sorted([x for x in df_container.full_df.columns if x.startswith("magnitude_sersic")])      
    rename      = {col_name : f"sersic_{i}" for i, col_name in enumerate(cols_to_use)}
    
    plot_df_all          = df_container.full_df[cols_to_use].copy().rename(columns = rename)
    plot_df_success      = df_container.success_df[cols_to_use].copy().rename(columns = rename)
    
    to_plot  = {"all models" : plot_df_all, "success" : plot_df_success}
    runname  = f"{basename}_all-vs-success"
    
    if incl_by_eye:
        plot_df_by_eye            = df_container.by_eye_success_df[cols_to_use].copy().rename(columns = rename)
        to_plot["by-eye success"] = plot_df_by_eye
        runname                   = f"{basename}_all-vs-success-vs-by-eye"
    
    _ = create_plot(
        x                       = x,
        runname                 = runname,
        plot_type               = "overlay_histogram",
        df_or_dfs               = to_plot,
        output_image_dir        = output_image_dir,
        histnorm                = "probability",
        nbins                   = kwargs.get("mag_nbins"),
        hist_offset             = kwargs.get("mag_hist_offset", 3),
        xaxis_range             = kwargs.get("xaxis_range_mag_hist", None),
        yaxis_range             = kwargs.get("yaxis_range_mag_hist", None),
        # title       = f"{runname} galaxies: distribution of magnitudes for by-eye successful models"
        # title_y     = 0.85
        show                    = show,
        interactive             = interactive,
        write                   = write
    )
    
# ============================================================================================================================================================
# PETROMAG DIFFERENCE HISTOGRAM
# ============================================================================================================================================================

    if "petromag" in " ".join(df_container.full_df.columns):
        x    = "m-petromag"
        fcol = "domain"

        cols_to_use = sorted([x for x in df_container.full_df.columns if x.startswith(f"petromag")])
        rename      = {col_name : f"sersic_{i}" for i, col_name in enumerate(cols_to_use)}
    
        plot_df_all          = df_container.full_df[cols_to_use].copy().rename(columns = rename)
        plot_df_success      = df_container.success_df[cols_to_use].copy().rename(columns = rename)

        to_plot  = {"all models" : plot_df_all, "success" : plot_df_success}
        runname  = f"{basename}_all-vs-success"

        if incl_by_eye:
            plot_df_by_eye            = df_container.by_eye_success_df[cols_to_use].copy().rename(columns = rename)
            to_plot["by-eye success"] = plot_df_by_eye
            runname                   = f"{basename}_all-vs-success-vs-by-eye"

        _ = create_plot(
            x                       = x,
            runname                 = runname,
            plot_type               = "overlay_histogram",
            df_or_dfs               = to_plot,
            output_image_dir        = output_image_dir,
            histnorm                = "probability",
            nbins                   = kwargs.get("mag_nbins"),
            hist_offset             = kwargs.get("mag_hist_offset", 3),
            xaxis_range             = kwargs.get("xaxis_range_petromag_hist", None),
            yaxis_range             = kwargs.get("yaxis_range_petromag_hist", None),
            # title       = f"{runname} galaxies: distribution of magnitudes for by-eye successful models"
            # title_y     = 0.85
            show                    = show,
            interactive             = interactive,
            write                   = write
        )
        
# ============================================================================================================================================================
# SPIRAL ARM FLUX RATIO HISTOGRAM
# ============================================================================================================================================================

    x     = "arm_flux_ratio"
    fcol  = "domain"
    xmin, xmax = kwargs.get("xaxis_range_arm_flux_hist", (0, 5))

    plot_df       = df_container.full_df[df_container.full_df[x] < xmax][x].copy().to_frame()
    plot_df[fcol] = "all models"

    plot_df1       = df_container.success_df[df_container.success_df[x] < xmax][x].copy().to_frame()
    plot_df1[fcol] = "success"
    
    to_concat = [plot_df, plot_df1]
    runname   = f"{basename}_all-vs-success"
    
    if incl_by_eye:
        plot_df2       = df_container.by_eye_success_df[df_container.by_eye_success_df[x] < xmax][x].copy().to_frame()
        plot_df2[fcol] = "by-eye success"
        to_concat.append(plot_df2)
        runname   = f"{basename}_by-eye-vs-success-vs-all"

    plot_df = pd.concat(to_concat, axis = 0)

    _ = create_plot(
        x                       = x,
        runname                 = runname,
        plot_type               = "histogram",
        df_or_dfs               = plot_df,
        output_image_dir        = output_image_dir,
        histnorm                = "probability",
        multi                   = False,
        facet_col               = fcol,
        #nbins                   = 40, #kwargs.get("mag_nbins"),
        #hist_offset             = kwargs.get("mag_hist_offset", 3),
        #xaxis_range             = kwargs.get("xaxis_range_arm_flux_hist", None),
        yaxis_range             = kwargs.get("yaxis_range_arm_flux_hist", None),
        # title       = f"{runname} galaxies: distribution of magnitudes for by-eye successful models"
        # title_y     = 0.85
        show                    = show,
        interactive             = interactive,
        write                   = write
    )
    
# ============================================================================================================================================================
# ALEN HISTOGRAM
# ============================================================================================================================================================

    x    = "alen_ratio"
    fcol = "domain"
    
    plot_df       = df_container.full_df[x].copy().to_frame()
    plot_df[fcol] = "all models"

    plot_df1       = df_container.success_df[x].copy().to_frame()
    plot_df1[fcol] = "success"
    
    to_concat = [plot_df, plot_df1]
    runname   = f"{basename}_all-vs-success"
    
    if incl_by_eye:
        plot_df2       = df_container.by_eye_success_df[x].copy().to_frame()
        plot_df2[fcol] = "by-eye success"
        to_concat.append(plot_df2)
        runname   = f"{basename}_by-eye-vs-success-vs-all"

    plot_df = pd.concat(to_concat, axis = 0)
    
    _ = create_plot(
        x                       = x,
        runname                 = runname,
        plot_type               = "histogram",
        df_or_dfs               = plot_df,
        output_image_dir        = output_image_dir,
        histnorm                = "probability",
        multi                   = False,
        # color                   = "component",
        # color_discrete_sequence = colors,
        facet_col               = fcol,
        xaxis_range             = kwargs.get("xaxis_range_alen_hist"), # [10, 20],
        yaxis_range             = kwargs.get("yaxis_range_alen_hist"), # [0, 0.15],
        # title       = f"{runname} galaxies: distribution of magnitudes for by-eye successful models"
        # title_y     = 0.85
        show                    = show,
        interactive             = interactive,
        write                   = write
    )
    
# ============================================================================================================================================================
# SCATTER OF PITCH ANGLE DIFFERENCES
# ============================================================================================================================================================
    
    x     = "observation"
    y     = "model"
    color = "difference"

    plot_df        = df_container.full_df.copy()
    plot_df[x]     = df_container.full_df.loc[:, "min_pre"]
    plot_df[y]     = df_container.full_df.loc[:, "min_post"]
    plot_df[color] = df_container.full_df.loc[:, "min_arm_pa_diff"]

    _ = create_plot(
        x                = x,
        y                = y,
        runname          = f"{basename}_pa_diff",
        plot_type        = "scatter",
        df_or_dfs        = plot_df,
        output_image_dir = output_image_dir,
        color            = color,
        #xaxis_title      = "
        #color_continuous_midpoint = kwargs.get("pa_cutoff_val"),
        #range_color      = [0, 90],
        width_multiplier = 1,
        # title     = "Pitch angle difference reported by SpArcFiRe, model vs observation"
        # title_y   = 0.85
        show             = show,
        write            = write
    )
    
# ============================================================================================================================================================
# SCATTER OF PITCH ANGLE DIFFERENCES FOR BY EYE SUCCESSFUL
# ============================================================================================================================================================
    if incl_by_eye:
        x     = "observation"
        y     = "model"
        color = "difference"
        
        plot_df        = df_container.by_eye_success_df.copy()
        plot_df[x]     = df_container.by_eye_success_df.loc[:, "min_pre"]
        plot_df[y]     = df_container.by_eye_success_df.loc[:, "min_post"]
        plot_df[color] = df_container.by_eye_success_df.loc[:, "min_arm_pa_diff"]

        _ = create_plot(
            x                = x,
            y                = y,
            runname          = f"{basename}_by-eye_pa_diff",
            plot_type        = "scatter",
            df_or_dfs        = plot_df,
            output_image_dir = output_image_dir,
            color            = color,
            #color_continuous_midpoint = kwargs.get("pa_cutoff_val"),
            width_multiplier = 1,
            # title     = "Pitch angle difference reported by SpArcFiRe, model vs observation"
            # title_y   = 0.85
            show             = show,
            interactive      = interactive,
            write            = write
        )
    
# ============================================================================================================================================================
# ECDF OF PITCH ANGLE DIFFERENCES
# ============================================================================================================================================================

    # _ = create_plot(
    #     x                = "min_pa_diff",
    #     runname          = basename,
    #     plot_type        = "ecdf",
    #     df_or_dfs        = df_container.full_df,
    #     output_image_dir = output_image_dir,
    #     xaxis_title      = "Pitch Angle Difference (deg)",
    #     cutoff_val       = kwargs.get("pa_cutoff_val", 10),
    #     marginal         = "histogram",
    #     # title       = f"ECDF of pitch angle difference reported by SpArcFiRe, model vs observation"
    #     # title_y     = 0.85
    #     show             = show,
    #     interactive      = interactive,
    #     write            = write
    # )

In [None]:
def create_quantiles(
    out_dir, 
    df, 
    method,
    **kwargs
):
    """
    Pick example galaxies at several quantiles of ``method`` and emit a LaTeX
    image grid (and optionally copies of their PNGs) for them.

    Parameters
    ----------
    out_dir : str
        SpArcFiRe output directory; per-run results live in ``out_dir/<runname>``.
    df : pandas.DataFrame
        Indexed by galaxy name, with a ``runname`` column and a ``method``
        column. It is NOT modified (a sorted copy is used internally).
    method : str
        Column used to rank galaxies (ascending).
    **kwargs
        print_latex (bool, default True): write the LaTeX lines to
            ``<runname>[_by_eye]_for_latex.txt`` (an old file is deleted first).
        copy_png (bool, default False): copy each selected
            ``<gname>_combined.png`` into per-quantile folders, then tar them.
        incl_by_eye (bool): use the ``_by_eye`` output paths.
        quantiles (list, default ["0", "20", "40", "60", "80"]): percentiles
            at which each LaTeX row starts.
        n_per_quantile (int, default 8): galaxies taken per quantile.

    Returns
    -------
    list
        Selected galaxy names in quantile order (a galaxy may appear under
        several quantiles). Empty if ``df`` is empty.
    """
    import shutil  # only needed when copy_png is set

    print_latex    = kwargs.get("print_latex", True)
    copy_png       = kwargs.get("copy_png", False)
    quantiles      = kwargs.get("quantiles", ["0", "20", "40", "60", "80"])
    n_per_quantile = kwargs.get("n_per_quantile", 8)

    if df.empty:
        print("No galaxies to quantile.")
        return []

    # Sort a copy so the caller's DataFrame keeps its original order
    df = df.sort_values(by = method)

    # Expect that if there exists more than one runname,
    # then we're working with combined data
    runnames = list(set(df.runname))
    if len(runnames) > 1:
        prefixes = {i.split("_")[0] for i in runnames}
        runname  = f"{prefixes.pop()}_combined" if len(prefixes) == 1 else "combined"
    else:
        runname = runnames[0]

    by_eye_tag       = "_by_eye" if kwargs.get("incl_by_eye", None) else ""
    success_dir      = pj(out_dir, runname, f"{runname}{by_eye_tag}_galfit_png")
    print_latex_file = pj(out_dir, runname, f"{runname}{by_eye_tag}_for_latex.txt")

    os.makedirs(success_dir, exist_ok = True)

    if print_latex:
        if exists(print_latex_file):
            print("Deleting old latex output file...")
            os.remove(print_latex_file)

        print(f"Writing latex to file {print_latex_file}")

    # Path prefix inside the LaTeX document. Uses runname (not each row's own
    # runname) so combined runs point at the combined folder.
    latex_prefix = f"images/{runname}/{runname}_all_quantile/quantile_"

    quantiled_galaxies = []
    print_latex_all    = []

    for q in quantiles:
        print_latex_all.append(f"{q} &")

        if copy_png:
            quantile_dir = pj(success_dir, f"{runname}_all_quantile", f"quantile_{q}")
            if exists(quantile_dir):
                shutil.rmtree(quantile_dir)
            os.makedirs(quantile_dir)

        # Galaxies at or above the q-th percentile, lowest first; keep the first few
        cutoff   = df[method].quantile(0.01 * float(q), interpolation = 'lower')
        selected = df[method][df[method] >= cutoff].head(n_per_quantile)

        initial_str = f"    \\includegraphics[height=0.18\\textheight]{{{latex_prefix}{q}/"
        for count, gname in enumerate(selected.index):
            # The last image in a row ends it with a LaTeX line break
            end_str = "} \\\\" if count == len(selected) - 1 else "} &"
            print_latex_all.append(f"{initial_str}{gname}_combined.png{end_str}")

            if copy_png:
                # PNGs live under each galaxy's own run, even for combined data
                copy_rname = df.loc[gname, "runname"]
                png_dir    = pj(out_dir, copy_rname, f"{copy_rname}_galfit_png")
                shutil.copy(pj(png_dir, f"{gname}_combined.png"), quantile_dir)

            quantiled_galaxies.append(gname)

    if print_latex:
        with open(print_latex_file, "w") as plf:
            plf.write("\n".join(print_latex_all))
            plf.write("\n")

    if copy_png:
        # Tar it all up!
        sp(f"tar -czvf {pj(out_dir, runname, runname)}_all_quantile.tar.gz -C {success_dir} {runname}_all_quantile")

    return quantiled_galaxies

In [None]:
def fprint(input_str, fill_char = "*", fill_len = 100):
    """
    Print ``input_str`` as a banner: the text, padded by one space on each
    side, centered in a line of ``fill_len`` characters filled with
    ``fill_char``, with a blank line before and after.
    """
    padded = f" {input_str} "
    banner = f"{padded:{fill_char}^{fill_len}}"
    print("\n" + banner + "\n")

In [None]:
def load_petromags(
    gzoo_file,
    color_band
):
    """
    Load Petrosian magnitudes for one color band from a tab-separated Galaxy Zoo file.

    Parameters
    ----------
    gzoo_file : str
        Path to the TSV file. It must contain a ``GZ_dr8objid`` column and a
        ``petroMag_<color_band>`` column.
    color_band : str
        Band suffix, e.g. ``"r"`` reads column ``petroMag_r``.

    Returns
    -------
    pandas.DataFrame
        Indexed by ``GZ_dr8objid`` (as str), with the single
        ``petroMag_<color_band>`` column. Missing values are filled with the
        string ``"None"``. Only the first row of each duplicated id is kept.
        If ``gzoo_file`` does not exist, an empty DataFrame is returned, which
        is the same placeholder main() uses when petromags are not loaded.
    """
    if not exists(gzoo_file):
        print(f"Could not find {gzoo_file}, proceeding without petromag info.")
        return pd.DataFrame()

    petromag_col = f"petroMag_{color_band}"
    gname_col    = "GZ_dr8objid"

    gzoo_data = pd.read_csv(
        gzoo_file,
        sep       = "\t",
        index_col = gname_col,
        usecols   = [gname_col, petromag_col],
        dtype     = {
            gname_col    : str,
            petromag_col : np.float32
        }
    ).fillna("None")

    # Drop repeated galaxy ids, keeping the first occurrence
    gzoo_data = gzoo_data[~gzoo_data.index.duplicated(keep='first')]

    return gzoo_data


In [None]:
def _quantile_results(out_dir, results, method, incl_by_eye, kwargs):
    """Run create_quantiles on the by-eye sample (if requested) and on the regular success sample.

    Returns the galaxy names selected by both runs, de-duplicated in first-seen order.
    """
    # create_quantiles takes out_dir/df/method positionally; drop same-named
    # keys so they cannot collide ("got multiple values for argument").
    quantile_kwargs = {k: v for k, v in kwargs.items() if k not in ("out_dir", "df", "method")}

    fprint("QUANTILING IMAGES FROM RESULTS")
    quantiled = []
    if incl_by_eye:
        # Do it twice to have a by eye sample and a regular sample
        print("... by eye")
        quantiled.extend(
            create_quantiles(out_dir, results.by_eye_success_df, method, **{**quantile_kwargs, "incl_by_eye": True})
        )

    quantiled.extend(
        create_quantiles(out_dir, results.success_df, method, **{**quantile_kwargs, "incl_by_eye": False})
    )

    return list(dict.fromkeys(quantiled))


def _extract_quantiled_models(out_dir, basename, galaxy_names):
    """Untar the GALFIT output FITS of galaxy_names into <out_dir>/<basename>/<basename>_galfits."""
    import shutil

    fprint("JUST KIDDING, EXTRACTING QUANTILED MODELS TO BE CONVERTED TO PNG")
    if not galaxy_names:
        print("No quantiled galaxies to extract; enable print_latex or copy_png so create_quantiles runs first.")
        return

    run_prefix  = pj(out_dir, basename, basename)
    galfits_dir = f"{run_prefix}_galfits"

    to_untar = ' '.join([f"./{gname}_galfit_out.fits" for gname in galaxy_names])
    tar_file = f"{run_prefix}_galfits.tar.gz"
    sp(f"tar -xzvf {tar_file} --occurrence {to_untar}")

    # Without an existing directory, shutil.move would rename each file onto that path
    os.makedirs(galfits_dir, exist_ok = True)
    for gname in galaxy_names:
        shutil.move(f"{gname}_galfit_out.fits", galfits_dir)

    print(f"Please generate the pngs corresponding with the fits in the {galfits_dir} directory.")
    print("You may then proceed to run the 'create_quantiles' function again with copy_png set to True.")


def main(
    run_path, 
    *basenames, 
    **kwargs
):
    """
    Run residual analysis, plotting and quantiling for each given run and,
    when several runs are given, for their combination.

    Parameters
    ----------
    run_path : str
        Directory holding ``sparcfire-in`` / ``sparcfire-out`` /
        ``for_paper_images`` (each overridable via kwargs).
    *basenames : str
        Run names to analyze. With more than one, results are also combined
        under ``"<prefix>_combined"`` (or ``"combined"`` if the names don't
        share the prefix before the first underscore).
    **kwargs
        in_dir, out_dir, output_image_dir, method, color_band, gzoo_file,
        do_not_load_petromags, incl_by_eye, by_eye_subset, write, show,
        interactive, print_latex, copy_png, prep_for_quantile,
        pa_cutoff_val, residual_cutoff_val, alen_cutoff_val, and
        plot_options (basename -> extra create_all_plots kwargs).
        All kwargs except out_dir/df/method are forwarded to create_quantiles.

    Returns
    -------
    dict
        basename -> residual_analysis result namedtuple, plus the combined
        basename -> combine_multi_run_results output when applicable.
    """
    if in_notebook():
        run_path = run_path.replace("ics-home", "portmanm")

    # Each directory comes from its own same-named kwarg
    in_dir  = kwargs.get("in_dir",  pj(run_path, "sparcfire-in"))
    out_dir = kwargs.get("out_dir", pj(run_path, "sparcfire-out"))

    output_image_dir = kwargs.get("output_image_dir", pj(run_path, "for_paper_images"))
    os.makedirs(output_image_dir, exist_ok = True)

    method     = kwargs.get("method", "nmr_x_1-p")
    color_band = kwargs.get("color_band", "r")

    global total_galaxies
    total_galaxies = get_total_galaxies(in_dir = in_dir, out_dir = out_dir)

    petromag_df = pd.DataFrame()
    # Loading the Galaxy Zoo table is slow; allow skipping it
    if not kwargs.get("do_not_load_petromags"):
        print("Loading petromag info...")
        gzoo_file   = kwargs.get("gzoo_file", pj("/home", "portmanm", "kelly_stuff", "Kelly-29k.tsv"))
        petromag_df = load_petromags(gzoo_file, color_band)

    # FUNCTION OPTIONS
    incl_by_eye       = kwargs.get("incl_by_eye", False)
    by_eye_subset     = kwargs.get("by_eye_subset", False)
    write             = kwargs.get("write", False)
    show              = kwargs.get("show", False)
    interactive       = kwargs.get("interactive", False)
    print_latex       = kwargs.get("print_latex", True)
    copy_png          = kwargs.get("copy_png", True)
    prep_for_quantile = kwargs.get("prep_for_quantile", False)

    # One set of cutoffs for per-run and combined analysis. The defaults follow
    # the scale of the values used in this notebook's run cells
    # (residual ~0.007-0.009, alen ~0.45-0.62).
    pa_cutoff_val       = kwargs.get("pa_cutoff_val", 10)
    residual_cutoff_val = kwargs.get("residual_cutoff_val", 0.007)
    alen_cutoff_val     = kwargs.get("alen_cutoff_val", 0.5)

    all_results  = {}
    plot_options = kwargs.get("plot_options") or {}

    # Arguments shared by every create_all_plots call
    plot_kwargs = dict(
        incl_by_eye         = incl_by_eye,
        write               = write,
        show                = show,
        interactive         = interactive,
        pa_cutoff_val       = pa_cutoff_val,
        residual_cutoff_val = residual_cutoff_val,
    )

    # LOOPING THROUGH NAMES GIVEN FOR ANALYSIS
    for basename in basenames:
        fprint(f"PERFORMING RESIDUAL ANALYSIS FOR {basename}")
        analysis_results = residual_analysis(
            in_dir              = in_dir,
            out_dir             = out_dir,
            basename            = basename,
            method              = method,
            incl_by_eye         = incl_by_eye,
            by_eye_subset       = by_eye_subset,
            color_band          = color_band,
            petromag_df         = petromag_df,
            pa_cutoff_val       = pa_cutoff_val,
            residual_cutoff_val = residual_cutoff_val,
            alen_cutoff_val     = alen_cutoff_val
        )
        all_results[basename] = analysis_results

        if write or show:
            fprint("CREATING PLOTS")
            _ = create_all_plots(
                analysis_results,
                method,
                basename,
                output_image_dir,
                **plot_kwargs,
                **plot_options.get(basename, {})
            )

        galaxy_set_q = []
        if print_latex or copy_png:
            galaxy_set_q = _quantile_results(out_dir, analysis_results, method, incl_by_eye, kwargs)

        # Unfortunately have to do this after and have the user generate the pngs from here
        # in order to rerun create_quantiles
        if prep_for_quantile:
            _extract_quantiled_models(out_dir, basename, galaxy_set_q)

    if len(basenames) > 1:
        fprint("COMBINING RESULTS FROM ALL RUNS FED IN")
        combined = combine_multi_run_results(
            method,
            *all_results.values(),
            df_names      = basenames,
            incl_by_eye   = incl_by_eye,
            by_eye_subset = by_eye_subset
        )

        prefixes     = {name.split("_")[0] for name in basenames}
        new_basename = f"{prefixes.pop()}_combined" if len(prefixes) == 1 else "combined"

        all_results[new_basename] = combined

        if write or show:
            fprint("CREATING PLOTS")
            combined_options = plot_options.get(new_basename)
            if combined_options is None:
                print(f"Were plot options specified with the correct combined basename, {new_basename}?")
                print("Proceeding without plot options.")
                combined_options = {}

            _ = create_all_plots(
                combined,
                method,
                new_basename,
                output_image_dir,
                **plot_kwargs,
                **combined_options
            )

        galaxy_set_q = []
        if print_latex or copy_png:
            galaxy_set_q = _quantile_results(out_dir, combined, method, incl_by_eye, kwargs)

        if prep_for_quantile:
            _extract_quantiled_models(out_dir, new_basename, galaxy_set_q)

    fprint("DONE!!!")

    # {basename : namedtuple(full_df, success_df, not_success_df, by_eye_success_df, by_eye_not_success_df), ...}
    # plus {combined basename : combine_multi_run_results output} when more than one run was given
    return all_results
    

In [None]:
if __name__ == "__main__":
    #pd.options.mode.chained_assignment = 'raise'

    # Run the "14_NC2" set from testing_python_control: show plots, write nothing.
    # Add "14_NC3" to the basenames to also analyze it and the combined results.
    run_settings = dict(
        pa_cutoff_val       = 10,
        alen_cutoff_val     = 0.5,
        residual_cutoff_val = 0.007,
        color_band          = "r",
        incl_by_eye         = True,
        show                = True,
        write               = False,
        copy_png            = False,
        print_latex         = False,
        interactive         = False,
    )
    galaxy_set_14_results = main("testing_python_control", "14_NC2", **run_settings)

In [None]:
if __name__ == "__main__":

    # Per-run axis ranges / binning passed through to create_all_plots, keyed by
    # basename ("1000_combined" only applies when more than one 1000_* run is given).
    plot_options = {
        "1000_NC2" : dict(
            xaxis_range_sersic_hist   = [0, 8],
            yaxis_range_sersic_hist   = [0, 0.18],
            xaxis_range_mag_hist      = [13, 23], #[10, 20],
            yaxis_range_mag_hist      = [0, 0.18],
            xaxis_range_petromag_hist = [-5, 9],
            yaxis_range_petromag_hist = [0, 0.34],
            mag_nbins                 = 40,
        ),
        "1000_NC3" : dict(
            xaxis_range_sersic_hist   = [0, 8],
            yaxis_range_sersic_hist   = [0, 0.26],
            xaxis_range_mag_hist      = [13, 25], #[10, 22],
            yaxis_range_mag_hist      = [0, 0.22],
            xaxis_range_petromag_hist = [-5, 9], #[-8, 6],
            yaxis_range_petromag_hist = [0, 0.34],
            mag_nbins                 = 40,
        ),
        "1000_combined" : dict(
            xaxis_range_sersic_hist   = [0, 8],
            yaxis_range_sersic_hist   = [0, 0.26],
            xaxis_range_mag_hist      = [13, 25], #[10, 22],
            yaxis_range_mag_hist      = [0, 0.24],
            xaxis_range_petromag_hist = [-5, 9],
            yaxis_range_petromag_hist = [0, 0.31],
            mag_nbins                 = 40,
        ),
    }

    run_settings = dict(
        pa_cutoff_val       = 7,
        alen_cutoff_val     = 0.62,
        residual_cutoff_val = 0.007,
        incl_by_eye         = True,
        write               = False,
        show                = True,
        copy_png            = False,
        print_latex         = False,
        interactive         = False,
        plot_options        = plot_options,
    )
    # Add "1000_NC3" to the basenames to also analyze it and the combined results
    galaxy_set_no_elps_1000_results = main("run13_for_paper", "1000_NC2", **run_settings)
    

In [None]:
if __name__ == "__main__":

    # Per-run axis ranges / binning passed through to create_all_plots, keyed by
    # basename; "29k_combined" is used for the combined NC2 + NC3 results.
    plot_options = {
        "29k_NC2" : dict(
            xaxis_range_sersic_hist   = [0, 8],
            yaxis_range_sersic_hist   = [0, 0.22],
            xaxis_range_mag_hist      = [12, 23],
            yaxis_range_mag_hist      = [0, 0.13],
            xaxis_range_petromag_hist = [-3, 5],
            yaxis_range_petromag_hist = [0, 0.23],
            mag_nbins                 = 40,
        ),
        "29k_NC3" : dict(
            xaxis_range_sersic_hist   = [0, 8],
            yaxis_range_sersic_hist   = [0, 0.35],
            xaxis_range_mag_hist      = [12, 23],
            yaxis_range_mag_hist      = [0, 0.15],
            xaxis_range_petromag_hist = [-3, 5],
            yaxis_range_petromag_hist = [0, 0.18],
            mag_nbins                 = 40,
        ),
        "29k_combined" : dict(
            xaxis_range_sersic_hist   = [0, 8],
            yaxis_range_sersic_hist   = [0, 0.31],
            xaxis_range_mag_hist      = [12, 23],
            yaxis_range_mag_hist      = [0, 0.15],
            xaxis_range_petromag_hist = [-3, 5],
            yaxis_range_petromag_hist = [0, 0.19],
            mag_nbins                 = 40,
        ),
    }

    run_settings = dict(
        plot_options        = plot_options,
        pa_cutoff_val       = 9,
        alen_cutoff_val     = 0.45, #0.52,
        residual_cutoff_val = 0.009,
        write               = True,
        show                = False,
        incl_by_eye         = True,
        by_eye_subset       = 1000,
        copy_png            = False,
        print_latex         = False,
        interactive         = False,
    )
    galaxy_set_29k_results = main("29k_galaxies", "29k_NC2", "29k_NC3", **run_settings)

In [None]:
def generate_images_old(input_df, png_dir:str, variable_name:str, custom_range = None, width = 500, height = 500):
    """
    Collect IPython ``Image`` objects for selected rows of ``input_df`` and
    print each galaxy's name and ``variable_name`` value.

    Parameters
    ----------
    input_df : pandas.DataFrame
        Indexed by galaxy name; must contain ``variable_name``.
    png_dir : str
        Directory containing ``<galaxy name>_combined.png`` files.
    variable_name : str
        Column whose value is printed for each row.
    custom_range : iterable of int, optional
        Positional (iloc) row numbers to visit. If falsy (None or empty),
        every 50th row is used.
    width, height : int
        Display size passed to ``Image`` (default 500 x 500).

    Returns
    -------
    list
        One ``Image`` per visited row, in visiting order.
    """
    images_out = []

    if not custom_range:
        custom_range = range(0, len(input_df), 50)

    for index_num in custom_range:
        # iloc returns the row as a Series; its .name is the row label (galaxy name)
        g_variable     = input_df.iloc[index_num]
        gname          = g_variable.name
        variable_value = g_variable[variable_name]

        out_str = f"{gname}_combined.png"
        images_out.append(Image(filename = pj(png_dir, out_str), width = width, height = height))

        print(f"{gname}, sorted #: {index_num}")
        print(f"{variable_name} = {variable_value:.6f}")
        print()

    return images_out

In [None]:
# Collect images for every 50th galaxy (generate_images_old's default range).
# NOTE(review): `full_df` and `out_dir` do not appear to be defined at module level
# (out_dir is a local inside main()), so this cell likely depends on leftover kernel
# state. Presumably full_df should come from one of the *_results dicts returned by
# main() — TODO confirm, along with whether a "diff" column exists in it.
images_to_disp = generate_images_old(full_df, pj(out_dir, "galfit_png"), "diff") #, range(0,len(full_df)))

In [None]:
# Render every image collected by generate_images_old inline
display(*images_to_disp)

In [None]:
def generate_images(input_df, png_dir:str, cutoff_val = 0.01, variable_name = "norm_masked_residual", custom_range = None, width = 500, height = 500):
    """
    Split galaxy images into those below and at/above a cutoff of one column.

    Parameters
    ----------
    input_df : pandas.DataFrame
        Indexed by galaxy name; must contain ``variable_name``.
    png_dir : str
        Directory containing ``<galaxy name>_combined.png`` files.
    cutoff_val : float
        Rows with ``variable_name < cutoff_val`` go in the first list.
    variable_name : str
        Column compared against ``cutoff_val`` and printed for each row.
    custom_range : iterable of int, optional
        Positional (iloc) row numbers to visit. If falsy (None or empty),
        every 50th row is used.
    width, height : int
        Display size passed to ``Image`` (default 500 x 500).

    Returns
    -------
    (list, list)
        ``Image`` objects below the cutoff, and at/above it. A line of "="
        is printed once, just before the first at/above-cutoff row is reported.
    """
    images_below_cutoff = []
    images_above_cutoff = []

    if not custom_range:
        custom_range = range(0, len(input_df), 50)

    separator_printed = False
    for index_num in custom_range:
        # iloc returns the row as a Series; its .name is the row label (galaxy name)
        g_variable     = input_df.iloc[index_num]
        gname          = g_variable.name
        variable_value = g_variable[variable_name]

        image = Image(filename = pj(png_dir, f"{gname}_combined.png"), width = width, height = height)

        if variable_value < cutoff_val:
            images_below_cutoff.append(image)
        else:
            if not separator_printed:
                print("="*80)
                separator_printed = True
            images_above_cutoff.append(image)

        print(f"{gname}, sorted #: {index_num}")
        print(f"{variable_name} = {variable_value:.6f}")
        print()

    return images_below_cutoff, images_above_cutoff

In [None]:
# Split sampled galaxies into below/above the residual cutoff for visual inspection.
# NOTE(review): run_path, out_dir, residual_df, cutoff_val and analysis_var do not
# appear to be defined at module level (run_path/out_dir are locals of main()), so
# this cell likely only works with leftover kernel state — TODO define them explicitly.
png_dir = os.path.join(run_path, out_dir, "galfit_png")
#below, above = generate_images(residual_df, png_dir, cutoff_val = 0.013342, variable_name = analysis_var, custom_range = range(700, len(residual_df), 10) )
below, above = generate_images(residual_df, png_dir, cutoff_val = cutoff_val, variable_name = analysis_var, custom_range = range(800, len(residual_df), 10) )

In [None]:
# Images whose variable value is below the cutoff
display(*below)

In [None]:
# Images whose variable value is at or above the cutoff
display(*above)

In [None]:
# good_fit = "1237671262278582530"
# bad_fit = "1237668366388756890"

# good_fit_obj = OutputFits(pj(out_dir, good_fit, f"{good_fit}_galfit_out.fits"))
# bad_fit_obj = OutputFits(pj(out_dir, bad_fit, f"{bad_fit}_galfit_out.fits"))
# good_residual = good_fit_obj.residual.data

# scipy.stats.probplot(good_residual.flatten(), plot = plt)

In [None]:
# bad_residual = bad_fit_obj.residual.data
# scipy.stats.probplot(bad_residual.flatten(), plot = plt)

In [None]:
# Thanks to https://jakevdp.github.io/PythonDataScienceHandbook/04.07-customizing-colorbars.html
def grayscale_cmap(cmap):
    """
    Return a grayscale version of the given colormap.

    Parameters
    ----------
    cmap : str or matplotlib Colormap
        Anything accepted by ``matplotlib.pyplot.get_cmap``.

    Returns
    -------
    matplotlib.colors.LinearSegmentedColormap
        Same number of entries as ``cmap``, named ``"<cmap.name>_gray"``, with
        each RGB channel replaced by the perceived luminance (alpha kept).
    """
    # The matplotlib imports at the top of this notebook are commented out,
    # so bring in what this helper needs locally.
    import matplotlib.pyplot as plt
    from matplotlib.colors import LinearSegmentedColormap

    # plt.get_cmap is still available; plt.cm.get_cmap was removed in matplotlib 3.9
    cmap = plt.get_cmap(cmap)
    colors = cmap(np.arange(cmap.N))

    # convert RGBA to perceived grayscale luminance
    # cf. http://alienryderflex.com/hsp.html
    RGB_weight = [0.299, 0.587, 0.114]
    luminance = np.sqrt(np.dot(colors[:, :3] ** 2, RGB_weight))
    colors[:, :3] = luminance[:, np.newaxis]

    return LinearSegmentedColormap.from_list(cmap.name + "_gray", colors, cmap.N)