In [None]:
import copy
import glob
import itertools
import os
import sys

import h5py
import numpy as np
import scipy
import scipy.interpolate
import tqdm
import unyt
import verdict

In [None]:
import matplotlib as mpl
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from mpl_toolkits.axes_grid1 import make_axes_locatable
import palettable

In [None]:
import kalepy as kale

In [None]:
import galaxy_dive.analyze_data.halo_data as halo_data
import galaxy_dive.plot_data.plotting as plotting

In [None]:
import trove
import trove.config_parser

# Load Data

In [None]:
# Path to the trove config describing all simulation variations. It can be
# overridden with the HOT_ACCRETION_TROVE environment variable so the notebook
# is not tied to one user's home directory.
config_fp = os.environ.get(
    'HOT_ACCRETION_TROVE',
    '/home1/03057/zhafen/papers/Hot-Accretion-in-FIRE/analysis/hot_accretion.trove',
)
cp = trove.config_parser.ConfigParser( config_fp )
# Parameters of the reference variation; used below for shared settings such as
# processed_data_dir, disk_costheta, and the figure directory.
pm = trove.link_params_to_config(
    config_fp,
    variation = 'm12i_md',
)

In [None]:
# All per-variation summary quantities are stored in a single HDF5 file
# inside the processed-data directory.
data_filepath = os.path.join(
    pm['processed_data_dir'],
    'summary.hdf5',
)
data = verdict.Dict.from_hdf5( data_filepath )

In [None]:
# Dimensionless Hubble parameter. Halo-catalog masses and radii are divided by it
# below (presumably converting from h^-1 units; confirm against the catalog).
h_param = 0.702

# Extract Quantities

In [None]:
# Halo/galaxy properties at the final snapshot for every variation. Values already
# present in the summary file are reused; missing ones are read from the halo
# merger-tree data and stored back into the summary file.
snum_final = 600  # snapshot the properties are evaluated at
summary_data_types = [ 'Mvir', 'Mstar', 'Vc', 'Rstar0.5' ]
mvir = verdict.Dict({})
mstar = verdict.Dict({})
vc = verdict.Dict({})
rstar = verdict.Dict({})
summary_updated = False
for key in tqdm.tqdm( cp.variations ):

    # Make sure each summary entry exists so new values can be stored in it
    for data_type in summary_data_types:
        if data_type not in data:
            data[data_type] = {}

    # Load from summary data if available
    if all( key in data[data_type] for data_type in summary_data_types ):
        mvir[key] = data['Mvir'][key]
        mstar[key] = data['Mstar'][key]
        vc[key] = data['Vc'][key]
        rstar[key] = data['Rstar0.5'][key]
        continue

    # Otherwise read from this variation's halo data (same config as above)
    pm_i = trove.link_params_to_config(
        config_fp,
        variation = key,
    )

    h_data = halo_data.HaloData(
        data_dir = pm_i['halo_data_dir'],
        mt_kwargs = { 'tag': 'smooth' },
    )

    # Masses and the half-mass radius are divided by h_param; Vmax is used as-is
    mvir[key] = h_data.get_mt_data( 'Mvir', snums=[snum_final,] )[0] / h_param
    mstar[key] = h_data.get_mt_data( 'M_star', snums=[snum_final,] )[0] / h_param
    vc[key] = h_data.get_mt_data( 'Vmax', snums=[snum_final,] )[0]
    rstar[key] = h_data.get_mt_data( 'Rstar0.5', snums=[snum_final,] )[0] / h_param

    data['Mvir'][key] = mvir[key]
    data['Mstar'][key] = mstar[key]
    data['Vc'][key] = vc[key]
    data['Rstar0.5'][key] = rstar[key]
    summary_updated = True

# Only rewrite the summary file when new values were computed
if summary_updated:
    data.to_hdf5( data_filepath )

In [None]:
# tcool/tff: read the (log) cooling-to-free-fall time ratio of each variation
# from its .npz file. Variations without a file are recorded in `missing`.
missing = []
tcool_tff = verdict.Dict({})
for key in cp.variations:

    # Resolution tag: last '_'-separated token of the subpath minus its first
    # three characters (presumably a 'res' prefix; confirm)
    res = cp.get( key, 'subpath' ).split( '_' )[-1][3:]
    fp = os.path.join(
        pm['processed_data_dir'],
        'tcool_tff',
        't_cool_to_t_ff_{}_{}.npz'.format( key, res ),
    )
    try:
        # np.load on an .npz returns an NpzFile holding an open file handle;
        # the context manager closes it once the value has been read.
        with np.load( fp ) as tcool_tff_file:
            log_tcool_tff = tcool_tff_file['log_t_cool_to_t_ff_smooth'][0]
    except FileNotFoundError:
        missing.append( key )
        continue

    tcool_tff[key] = 10.**log_tcool_tff

In [None]:
# Change in thin disk fraction, last entry minus first entry
# (doesn't actually make sense for gas)
thin_disk_fractions = data['circularity']['thin_disk_fractions']
delta_thin_disk = thin_disk_fractions.inner_item( -1 ) - thin_disk_fractions.inner_item( 0 )

In [None]:
# Width of the cos(theta) distribution: spread between its 16th and 84th
# percentiles, plus how that width (and the STD) changes from first to last entry.
cosphi_summary = data['cosphi']
sigma_cosphi = cosphi_summary['84th_percentile'] - cosphi_summary['16th_percentile']

sigma_first = sigma_cosphi.inner_item( 0 )
sigma_last = sigma_cosphi.inner_item( -1 )
delta_sigma_cosphi = sigma_last - sigma_first
negative_delta_sigma_cosphi = -1. * delta_sigma_cosphi
sigma_cosphi_ratio = sigma_first / sigma_last

std_first = cosphi_summary['std'].inner_item( 0 )
std_last = cosphi_summary['std'].inner_item( -1 )
std_cosphi_ratio = std_first / std_last

In [None]:
# Change (difference and ratio, last vs. first entry) of the PDF at cos(theta)=0
pdf_at_plane = data['cosphi']['pdf(cos theta=0)']
pdf_at_plane_first = pdf_at_plane.inner_item( 0 )
pdf_at_plane_last = pdf_at_plane.inner_item( -1 )
delta_pdf = pdf_at_plane_last - pdf_at_plane_first
ratio_pdf = pdf_at_plane_last / pdf_at_plane_first

In [None]:
# Spherical harmonics: change in the q20 and q33 coefficients, last minus first
def _last_minus_first( series ):
    '''Difference between the last and first inner items of a verdict.Dict.'''
    return series.inner_item( -1 ) - series.inner_item( 0 )

delta_q20 = _last_minus_first( data['cosphi']['q20'] )
delta_q33 = _last_minus_first( data['cosphi']['q33'] )

In [None]:
# Disk fraction: for each PDF of each variation, the fraction of the PDF with
# |cos(theta)| < pm['disk_costheta'], plus its change from first to last PDF.
disk_frac = verdict.Dict({})
for key in cp.variations:
    # The mask depends only on the variation's points, not on the individual
    # PDF, so build it once per key
    in_disk = np.abs( data['cosphi']['points'][key] ) < pm['disk_costheta']
    disk_frac[key] = [
        pdf[in_disk].sum() / pdf.sum()
        for pdf in data['cosphi']['pdf'][key]
    ]
delta_disk_frac = disk_frac.inner_item( -1 ) - disk_frac.inner_item( 0 )

In [None]:
# Disk fraction (aligned fraction) for thin disk stars, for use as a reference
# point: fraction of each stellar cos(theta) histogram with
# |cos(theta)| < pm['disk_costheta'].
aligned_frac = verdict.Dict({})
aligned_frac_recent = verdict.Dict({})
thin_disk_aligned_missing = []
# The histogram centers are shared by all variations, so the mask is built once
aligned = np.abs( data['cosphi_stars']['centers'] ) < pm['disk_costheta']
for key in cp.variations:
    try:
        pdf = data['cosphi_stars']['thin_disk'][key]
        aligned_frac[key] = pdf[aligned].sum()/pdf.sum()

        pdf = data['cosphi_stars']['thin_disk_recent'][key]
        aligned_frac_recent[key] = pdf[aligned].sum()/pdf.sum()
    except KeyError:
        # NOTE: if only 'thin_disk_recent' is missing, aligned_frac is still
        # filled in for this key even though the key is listed here
        thin_disk_aligned_missing.append( key )

In [None]:
# Median of cos(theta), and median of |cos(theta)|, for every PDF of every variation
def _cdf_median( pdf_values, grid ):
    '''Grid value where the normalized cumulative sum of pdf_values reaches 0.5.'''
    cumulative = np.cumsum( pdf_values )
    cumulative /= cumulative[-1]
    return scipy.interpolate.interp1d( cumulative, grid )( 0.5 )

med_cosphi = verdict.Dict({})
abs_med_cosphi = verdict.Dict({})
for key in cp.variations:
    cosphi_points = data['cosphi']['points'][key]
    is_above = cosphi_points > 0.
    is_below = cosphi_points < 0.

    medians = []
    abs_medians = []
    for pdf in data['cosphi']['pdf'][key]:
        medians.append( _cdf_median( pdf, cosphi_points ) )

        # Fold the negative half onto the positive half. The negative half is
        # reversed so that equal |cos(theta)| line up (assumes a grid symmetric
        # about zero; confirm). Points exactly at 0 are in neither mask.
        folded = pdf[is_above] + pdf[is_below][::-1]
        abs_medians.append( _cdf_median( folded, cosphi_points[is_above] ) )

    med_cosphi[key] = np.array( medians )
    abs_med_cosphi[key] = np.array( abs_medians )

In [None]:
# Change of the (absolute) median, first entry minus last entry
# (note: the opposite sign convention from the other deltas, which are last - first)
delta_med_cosphi, delta_abs_med_cosphi = (
    medians.inner_item( 0 ) - medians.inner_item( -1 )
    for medians in ( med_cosphi, abs_med_cosphi )
)

In [None]:
# Quiet acc fraction (not a rigorous definition): the stored fraction at the
# configured t/c cuts, and the strict fraction additionally at the r cut.
def _cut_index( cut_values, target ):
    '''Return the index of the entry of `cut_values` close to `target`.

    Raises ValueError if no entry matches. np.argmax on an all-False mask
    would otherwise silently return 0, i.e. select the wrong cut.
    '''
    matches = np.isclose( cut_values, target )
    if not np.any( matches ):
        raise ValueError(
            'Cut value {} not found among available cuts {}'.format( target, cut_values )
        )
    return np.argmax( matches )

quiet_frac_all = data['quiet_acc_fraction']
quiet_frac = {}
quiet_frac_strict = {}
for key in cp.variations:
    i_t_cut = _cut_index( quiet_frac_all[key]['t_cuts'], pm['t_cut'] )
    j_c_cut = _cut_index( quiet_frac_all[key]['c_cuts'], pm['c_cut'] )
    k_r_cut = _cut_index( quiet_frac_all[key]['r_cuts'], pm['r_cut'] )
    quiet_frac[key] = quiet_frac_all[key]['fraction'][i_t_cut,j_c_cut]
    quiet_frac_strict[key] = quiet_frac_all[key]['fraction_strict'][i_t_cut,j_c_cut,k_r_cut]
# Wrap both so they behave like the other verdict.Dict quantities (e.g. `.array()`)
quiet_frac = verdict.Dict( quiet_frac )
quiet_frac_strict = verdict.Dict( quiet_frac_strict )

In [None]:
# Fraction of the R_{T=1e5 K} distribution beyond r = 40 kpc:
# one minus the normalized cumulative PDF interpolated at r.
r = 40.
cdfs = data['R1e5K']['pdf'].cumsum()
cdfs /= cdfs.inner_item( -1 )
f_above_r = verdict.Dict({})
for sim_key, sim_cdf in cdfs.items():
    cdf_at_r = scipy.interpolate.interp1d( data['R1e5K']['points'][sim_key], sim_cdf )( r )
    f_above_r[sim_key] = 1. - cdf_at_r

In [None]:
# Every per-variation quantity available for plotting, keyed by name
values = dict(
    delta_thin_disk = delta_thin_disk,
    median_R1e5K = data['R1e5K']['median'],
    median_R1e5K_rgal = data['R1e5K_rgal']['median'],
    thin_disk_frac = data['thin_disk_fraction'],
    thin_disk_frac_recent = data['thin_disk_fraction_recent'],
    thin_disk_frac_tracked = data['f_thin_tracked_z0'],
    mvir = mvir,
    mstar = mstar,
    rstar = rstar,
    tcool_tff = tcool_tff,
    negative_delta_sigma_cosphi = negative_delta_sigma_cosphi,
    sigma_cosphi_ratio = sigma_cosphi_ratio,
    std_cosphi_ratio = std_cosphi_ratio,
    delta_pdfcosphi = delta_pdf,
    ratio_pdfcosphi = ratio_pdf,
    disk_frac = disk_frac,
    delta_disk_frac = delta_disk_frac,
    delta_q20 = delta_q20,
    delta_q33 = delta_q33,
    delta_med_cosphi = delta_med_cosphi,
    delta_abs_med_cosphi = delta_abs_med_cosphi,
    quiet_frac = quiet_frac,
    quiet_frac_strict = quiet_frac_strict,
    f_above_r = f_above_r,
)

# Plot

## Settings

In [None]:
# Axis / colorbar label for every plottable quantity. Every key of `values` needs
# an entry here, or the plotting loops raise KeyError.
labels = {
    'delta_thin_disk': r'$\Delta f_{\rm thin}$',
    'median_R1e5K': r'median $R_{T=10^5\, {\rm K}}$ (kpc)',
    'median_R1e5K_rgal': r'median $R_{T=10^5\, {\rm K}}$ / $R_{\rm gal}$',
    'thin_disk_frac': r'stellar thin disk fraction',
    'thin_disk_frac_recent': r'$f_{\rm thin\,disk}$($z=0$, age $<1$ Gyr)',
    'thin_disk_frac_tracked': r'$f_{\rm thin}(\star,z=0,$ tracked)',
    'mvir': r'$M_{\rm vir}$ $[M_\odot]$',
    'mstar': r'$M_{\star}$ $[M_\odot]$',
    'rstar': r'$R_{\star, 0.5}$ [kpc]',
    'tcool_tff': r'$t_{\rm cool}^{(s)}$ / $t_{\rm ff}$ at $0.1 R_{\rm vir}$',
    'sigma_cosphi': r'$\sigma( \cos\theta )$',
    'negative_delta_sigma_cosphi': r'$\sigma_{\cos\theta,\,{\rm hot}}$ - $\sigma_{\cos\theta,\,{\rm cool}}$',
    'sigma_cosphi_ratio': r'$\sigma_{\cos\theta,\,{\rm hot}}$ / $\sigma_{\cos\theta,\,{\rm cool}}$',
    'std_cosphi_ratio': r'${\rm STD}_{\cos\theta,\,{\rm hot}}$ / ${\rm STD}_{\cos\theta,\,{\rm cool}}$',
    'delta_pdfcosphi': r'$( dM_{\rm after\,cooling} - dM_{\rm before\,cooling} )\mid_{\rm galaxy\,plane}$',
    'ratio_pdfcosphi': r'${\rm PDF}(\cos\theta=0)_{\rm cool}$ / ${\rm PDF}(\cos\theta=0)_{\rm hot}$',
    'pdfcosphi_0': r'${\rm PDF}(\cos\theta=0)$',
    'disk_frac': 'Aligned Accretion',
    'delta_disk_frac': r'$\Delta$(Aligned Accretion)',
    'delta_q20': r'$\Delta q_{20}$',
    'delta_q33': r'$\Delta q_{33}$',
    'med_cosphi': r'$\cos\theta_{50}$',
    'delta_med_cosphi': r'$\Delta \cos\theta_{50}$',
    'abs_med_cosphi': r'$\mid \cos\theta_{50}\mid$',
    'delta_abs_med_cosphi': r'$\mid \Delta \cos\theta_{50}\mid$',
    'quiet_frac': r'CCF fraction, no R cut',
    'quiet_frac_strict': r'CCF fraction',
    # Uses the radius `r` (kpc) that f_above_r was evaluated at
    'f_above_r': r'$f( R_{T=10^5\, {\rm K}} > ' + '{:.3g}'.format( r ) + r'\,{\rm kpc} )$',
}

In [None]:
# Quantities drawn on logarithmic axes (and with a LogNorm colorbar)
logscale = [
    'mvir',
    'mstar',
    'tcool_tff',
]

In [None]:
# Quantities treated as fractions: get_lim gives them fixed [0, 1] limits, and
# plotting one against another adds an equal aspect ratio and a 1:1 line.
fractions = [
    'thin_disk_frac',
    'thin_disk_frac_recent',
    'thin_disk_frac_tracked',
    'quiet_frac',
    'quiet_frac_strict',
]

In [None]:
# Arrow properties for annotation leader lines: a plain line without a head
arrowprops = dict( arrowstyle = '-' )

In [None]:
# Per-plot overrides for the simulation-name annotations, keyed by (x_key, y_key).
# Each inner value either updates the default annotate() kwargs for that
# simulation, or is None to suppress its individual label (those points get a
# bracketed group label in the special cases of the plotting loop instead).
_grouped_m11_keys = [ 'm11a', 'm11c', 'm11d_md', 'm11e_md', 'm11i_md', 'm11q_md' ]
custom_annot_args = {
    ( 'delta_pdfcosphi', 'thin_disk_frac_recent' ): {
        'm12f_md': dict( ha = 'right', va = 'bottom', xytext = ( -3, 3 ) ),
        'm12i_cr': dict( ha = 'right', va = 'center', xytext = ( -8, 0 ) ),
        'm12i_md': dict( ha = 'center', va = 'bottom', xytext = ( 3, 5 ) ),
        'm12i': dict( ha = 'left', va = 'top', xytext = ( 3, -3 ) ),
        **dict.fromkeys( _grouped_m11_keys ),
    },
    ( 'thin_disk_frac_recent', 'delta_disk_frac' ): {
        'm12i': dict( ha = 'right', va = 'bottom', xytext = ( -3, 3 ) ),
        **dict.fromkeys( _grouped_m11_keys ),
    },
}

In [None]:
# Scatter marker per physics suffix of a variation key (the part after '_';
# '' for keys without a suffix). Insertion order sets the legend order.
markers = {}
markers['md'] = 'o'
markers['cr'] = '^'
# markers['mhdcv'] = 'P'  (currently not plotted)
markers[''] = 's'

In [None]:
# Legend text for each marker / physics suffix (same keys as `markers`)
marker_labels = {}
marker_labels['md'] = 'Hydro+'
marker_labels['cr'] = 'CR+'
# marker_labels['mhdcv'] = 'MHD+'  (currently not plotted)
marker_labels[''] = 'no metal diffusion'

In [None]:
# Fixed axis limits; get_lim returns these instead of data-driven limits
custom_lims = dict(
    median_R1e5K_rgal = [ 0, 4 ],
    mvir = [ 3e10, 2e12 ],
    mstar = [ 7e7, 2e11 ],
    tcool_tff = [ 0.08, 20 ],
)

In [None]:
def get_lim( vs, is_log, v_key, scale_upper=1., scale_lower=1. ):
    '''Axis limits for the quantity `v_key`.

    Args:
        vs (verdict.Dict): Values of the quantity; only used when no fixed
            limits apply.
        is_log (bool): Whether the axis is logarithmic. Linear axes always
            include 0.
        v_key (str): Quantity name, looked up in `custom_lims` and `fractions`.
        scale_upper, scale_lower (float): Factors multiplying the data maximum
            and the lower limit. They are multiplicative, so for negative
            values a factor < 1 moves the limit toward zero, not outward.

    Returns:
        The `custom_lims` entry if present, [0, 1] for fractions, otherwise a
        (lower, upper) tuple.
    '''
    # Fixed limits take precedence
    if v_key in custom_lims:
        return custom_lims[v_key]
    if v_key in fractions:
        return [ 0, 1 ]

    all_values = vs.array()
    lower = np.nanmin( all_values )
    upper = np.nanmax( all_values )
    if not is_log:
        lower = min( 0, lower )

    return lower * scale_lower, upper * scale_upper

## One Value per Axis

In [None]:
# Automatic: every 3-element combination (x, y, color) of the quantities in
# `values` (superseded by the manual list in the next cell)
value_keys = list( values )
combinations = itertools.combinations( value_keys, 3 )

In [None]:
# Manual: the (x, y, color) combinations actually plotted
combinations = [
    #  x                          y                          color
    ( 'delta_pdfcosphi',       'thin_disk_frac_recent',   'mvir' ),
    ( 'thin_disk_frac_recent', 'delta_pdfcosphi',         'mvir' ),
    ( 'delta_disk_frac',       'thin_disk_frac_recent',   'mvir' ),
    ( 'thin_disk_frac_recent', 'delta_disk_frac',         'mvir' ),
    ( 'quiet_frac_strict',     'thin_disk_frac_recent',   'mvir' ),
    ( 'quiet_frac_strict',     'delta_pdfcosphi',         'mvir' ),
    ( 'mvir',                  'rstar',                   'thin_disk_frac_recent' ),
]

In [None]:
# Continuous colormap for the single-panel scatter plots' color axis
_viridis_palette = palettable.matplotlib.Viridis_20
cmap = _viridis_palette.mpl_colormap

In [None]:
# One figure per (x, y, color) combination: each variation is a point at (x, y)
# colored by z, with its marker set by the physics suffix of its key.
for x_key, y_key, z_key in combinations:

    xs = values[x_key]
    ys = values[y_key]
    zs = values[z_key]

    x_log = x_key in logscale
    y_log = y_key in logscale
    z_log = z_key in logscale

    x_lims = get_lim( xs, x_log, x_key, 1.2, 0.95 )
    y_lims = get_lim( ys, y_log, y_key, 1.1, 0.95 )
    z_lims = get_lim( zs, z_log, z_key, )

    fig, ax = plt.subplots( figsize=(8,8), facecolor='w' )

    for key in cp.variations:

        # Skip variations that lack, or have NaN for, any of the three quantities
        try:
            x = xs[key]
            y = ys[key]
            z = zs[key]
        except KeyError:
            continue
        if np.isnan( x ) or np.isnan( y ) or np.isnan( z ):
            continue

        # Keys are '<sim>' or '<sim>_<physics>'
        if '_' in key:
            sim_name, physics = key.split( '_' )
        else:
            sim_name = key
            physics = ''

        # Position of z within z_lims (in log space for log quantities),
        # used to pick the point's color
        if not z_log:
            c_value = ( z - z_lims[0] ) / ( z_lims[1] - z_lims[0] )
        else:
            c_value = ( np.log10( z ) - np.log10( z_lims[0] ) ) / ( np.log10( z_lims[1] ) - np.log10( z_lims[0] ) )
        c = cmap( c_value )

        ax.scatter(
            x,
            y,
            s = 100,
            color = c,
            marker = markers[physics]
        )

        # Annotate simulation names. custom_annot_args can override the default
        # kwargs per (x_key, y_key) and simulation, or disable the label (None).
        annot_args = {
            'textcoords': 'offset points',
            'fontsize': 22,
            'va': 'bottom',
            'ha': 'left',
            'xytext': ( 3, 3 ),
        }
        plot_custom_args = custom_annot_args.get( ( x_key, y_key ), {} )
        if key in plot_custom_args:
            custom_args = plot_custom_args[key]
            if custom_args is None:
                annot_args = None
            else:
                annot_args.update( custom_args )
        if annot_args is not None:
            ax.annotate(
                text = sim_name,
                xy = ( x, y ),
                **annot_args
            )

        # Special cases: one bracketed group label for clustered m11 points
        if ( x_key, y_key ) == ( 'delta_pdfcosphi', 'thin_disk_frac_recent' ):
            if key == 'm11e_md':
                ax.annotate(
                    text = 'm11a, m11c, m11d\nm11e, m11i, m11q',
                    xy = ( x+0.05, y ),
                    xytext = ( 30, 5 ),
                    textcoords = 'offset points',
                    fontsize = 22,
                    va = 'center',
                    arrowprops = {
                        'arrowstyle': '-[',
                        'lw': 1.5,
                    },
                )
        elif ( x_key, y_key ) == ( 'thin_disk_frac_recent', 'delta_disk_frac' ):
            if key == 'm11e_md':
                ax.annotate(
                    text = 'm11e, m11i, m11q',
                    xy = ( x+0.01, y+0.01 ),
                    xytext = ( -10, 35 ),
                    textcoords = 'offset points',
                    fontsize = 22,
                    ha = 'left',
                    va = 'center',
                    arrowprops = {
                        'arrowstyle': '-[',
                        'lw': 1.5,
                    },
                )
            elif key == 'm11d_md':
                ax.annotate(
                    text = 'm11a, m11c, m11d',
                    xy = ( x+0.03, y ),
                    xytext = ( 20, 0 ),
                    textcoords = 'offset points',
                    fontsize = 22,
                    va = 'center',
                    arrowprops = {
                        'arrowstyle': '-[',
                        'lw': 1.5,
                    },
                )

    ax.set_xlabel( labels[x_key], fontsize=22 )
    ax.set_ylabel( labels[y_key], fontsize=22 )

    if x_log:
        ax.set_xscale( 'log' )
    if y_log:
        ax.set_yscale( 'log' )

    ax.set_xlim( x_lims )
    ax.set_ylim( y_lims )

    # Fraction vs. fraction: square axes with a 1:1 reference line
    if x_key in fractions and y_key in fractions:
        ax.set_aspect( 'equal' )
        ax.plot(
            [ 0, 1 ],
            [ 0, 1 ],
            color = '.2',
            linewidth = 1,
        )

    # Legend: one entry per physics marker
    legend_elements = [
        Line2D([0], [0], marker=markers[_], color='w', label=marker_labels[_], markerfacecolor='k', markersize=15)
        for _ in markers.keys()
    ]
    ax.legend(
        handles=legend_elements,
        prop = {'size': 16 },
    )

    # Colorbar matching the point colors (log-normalized for log quantities)
    if not z_log:
        norm_class = mpl.colors.Normalize
    else:
        norm_class = mpl.colors.LogNorm
    norm = norm_class(vmin=z_lims[0], vmax=z_lims[1])
    divider = make_axes_locatable( ax )
    cax = divider.append_axes( "right", pad=0.05, size='5%' )
    mpl.colorbar.ColorbarBase( cax, cmap=cmap, norm=norm, )

    # Colorbar label
    cax.annotate(
        text = labels[z_key],
        xy = (1,1),
        xytext = ( 0, 5 ),
        xycoords = 'axes fraction',
        textcoords = 'offset points',
        fontsize = 22,
        ha = 'right',
        va = 'bottom',
    )

## Multi-Panel

In [None]:
# (x, y, color) for each panel of the multi-panel figure, left to right:
# aligned-fraction change vs. a galaxy property, colored by recent thin disk fraction
multipanel_combinations = [
    ( panel_x_key, 'delta_disk_frac', 'thin_disk_frac_recent' )
    for panel_x_key in ( 'mvir', 'mstar', 'tcool_tff' )
]

In [None]:
# Continuous colormap for the multi-panel figure's color axis
_plasma_palette = palettable.matplotlib.Plasma_20
cmap = _plasma_palette.mpl_colormap

In [None]:
# Multi-panel figure: one panel per entry of multipanel_combinations, sharing the
# y axis, with the legend on the first panel and the colorbar on the last.
n_cols = len( multipanel_combinations )

# No plt.gca() here: on a fresh figure it would create an extra, unused axes
# underneath the GridSpec panels.
fig = plt.figure( figsize=(6*n_cols,5), facecolor='w' )

gs = gridspec.GridSpec( 1, n_cols )
gs.update( wspace=0.0001 )

# Tick values whose labels are replaced with plain decimal strings
replace_values = [ 0.01, 0.1, 1., 10., 100. ]
replace_strs = [ '0.01', '0.1', '1', '10', '100' ]

# One list per panel of the variations skipped because a value was missing
invalid = []
for i, (x_key, y_key, z_key) in enumerate( multipanel_combinations ):

    panel_invalid = []

    # Panel position. Used instead of Axes.is_first_col()/is_last_col(), which
    # were deprecated in Matplotlib 3.4 and later removed.
    is_first_col = i == 0
    is_last_col = i == n_cols - 1

    ax = plt.subplot(gs[0,i])

    xs = values[x_key]
    ys = values[y_key]
    zs = values[z_key]

    x_log = x_key in logscale
    y_log = y_key in logscale
    z_log = z_key in logscale

    x_lims = get_lim( xs, x_log, x_key, 1.2, )
    y_lims = get_lim( ys, y_log, y_key, 1.1, 2 )
    z_lims = get_lim( zs, z_log, z_key, )

    for key in cp.variations:

        try:
            x = xs[key]
            y = ys[key]
            z = zs[key]
        except KeyError:
            panel_invalid.append( key )
            continue
        if np.isnan( x ) or np.isnan( y ) or np.isnan( z ):
            continue

        # Keys are '<sim>' or '<sim>_<physics>'
        if '_' in key:
            sim_name, physics = key.split( '_' )
        else:
            sim_name = key
            physics = ''

        # Position of z within z_lims (in log space for log quantities)
        if not z_log:
            c_value = ( z - z_lims[0] ) / ( z_lims[1] - z_lims[0] )
        else:
            c_value = ( np.log10( z ) - np.log10( z_lims[0] ) ) / ( np.log10( z_lims[1] ) - np.log10( z_lims[0] ) )
        c = cmap( c_value )

        ax.scatter(
            x,
            y,
            s = 200,
            color = c,
            marker = markers[physics]
        )

    ax.set_xlabel( labels[x_key], fontsize=22 )
    if is_first_col:
        ax.set_ylabel( labels[y_key], fontsize=22 )

    # Zero line
    if y_key == 'delta_disk_frac':
        ax.axhline(
            0,
            color = pm['background_linecolor'],
            linewidth = 1,
            zorder = -100,
        )

    if x_log:
        ax.set_xscale( 'log' )
    if y_log:
        ax.set_yscale( 'log' )

    ax.set_xlim( x_lims )
    ax.set_ylim( y_lims )

    # Fraction vs. fraction: square axes with a 1:1 reference line
    if x_key in fractions and y_key in fractions:
        ax.set_aspect( 'equal' )
        ax.plot(
            [ 0, 1 ],
            [ 0, 1 ],
            color = '.2',
            linewidth = 1,
        )

    # Panels share the y axis, so only the first keeps its y tick labels
    if not is_first_col:
        ax.tick_params( left=False, labelleft=False )

    if is_first_col:
        # Legend: one entry per physics marker
        legend_elements = [
            Line2D([0], [0], marker=markers[_], color='w', label=marker_labels[_], markerfacecolor='k', markersize=15)
            for _ in markers.keys()
        ]
        ax.legend(
            handles=legend_elements,
            prop = {'size': 16 },
        )

    if is_last_col:
        # Colorbar (log-normalized for log quantities)
        if not z_log:
            norm_class = mpl.colors.Normalize
        else:
            norm_class = mpl.colors.LogNorm
        norm = norm_class(vmin=z_lims[0], vmax=z_lims[1])
        divider = make_axes_locatable( ax )
        cax = divider.append_axes( "right", pad=0.05, size='5%' )
        mpl.colorbar.ColorbarBase( cax, cmap=cmap, norm=norm, )

        # Colorbar label
        cax.annotate(
            text = 'color: '+ labels[z_key],
            xy = (1,1),
            xytext = ( 0, 5 ),
            xycoords = 'axes fraction',
            textcoords = 'offset points',
            fontsize = 22,
            ha = 'right',
            va = 'bottom',
        )

    # Relabel ticks at 0.01, 0.1, 1, 10, 100 with plain decimals. The loop index
    # gets its own name so it no longer shadows the panel index `i`.
    tick_values = ax.get_xticks()
    new_ticklabels = []
    modified = False
    for i_tick, tick_label in enumerate( ax.get_xticklabels() ):
        for replace_val, replace_str in zip( replace_values, replace_strs ):
            if np.isclose( tick_values[i_tick], replace_val ):
                tick_label.set_text( replace_str )
                modified = True
        new_ticklabels.append( tick_label )
    if modified:
        ax.set_xticklabels( new_ticklabels )

    invalid.append( panel_invalid )

plotting.save_fig(
    out_dir = os.path.join( pm['figure_dir'], 'prevalence' ),
    save_file = 'aligned_fraction_vs_galaxy_props.pdf',
    fig = fig,
)

In [None]:
# Variations skipped in each multipanel panel (one list per panel) because one of
# the panel's quantities had no entry for them
invalid

## Connected Before-After

In [None]:
# Quantities with one value per entry (first entry = "Hot", last = "Cool"),
# plotted as connected before/after pairs below
multi_values = dict(
    sigma_cosphi = sigma_cosphi,
    med_cosphi = med_cosphi,
    abs_med_cosphi = abs_med_cosphi,
    pdfcosphi_0 = data['cosphi']['pdf(cos theta=0)'],
    disk_frac = disk_frac,
)

In [None]:
# Despite the name, this is a list of three colors (not a Colormap); the first
# and last are used for the "Hot" and "Cool" points.
_berlin_palette = palettable.scientific.diverging.Berlin_3_r
rd_bu_cmap = _berlin_palette.mpl_colors

In [None]:
# Set to True to label each before/after point with its simulation name
annotate = False

In [None]:
# For each multi-valued quantity: plot the first ("Hot") and last ("Cool") value of
# every variation against its recent thin disk fraction, with an arrow between them.
for x_key in multi_values.keys():

    fig, ax = plt.subplots( figsize=(8,8), facecolor='w' )

    y_key = 'thin_disk_frac_recent'

    xs = multi_values[x_key]
    xs_pre = xs.inner_item( 0 )
    xs_post = xs.inner_item( -1 )
    ys = values[y_key]

    # Only x and y scales are needed; there is no color axis in this plot
    # (the stale `z_key` from an earlier cell is no longer read here).
    x_log = x_key in logscale
    y_log = y_key in logscale

    x_lims = get_lim( xs, x_log, x_key, 1.1 )
    y_lims = get_lim( ys, y_log, y_key, 1.1 )

    for key in cp.variations:

        try:
            x_pre = xs_pre[key]
            y = ys[key]
        except KeyError:
            continue

        if np.isnan( x_pre ):
            continue

        # Keys are '<sim>' or '<sim>_<physics>'
        if '_' in key:
            sim_name, physics = key.split( '_' )
        else:
            sim_name = key
            physics = ''

        # Pre (first entry)
        ax.scatter(
            x_pre,
            y,
            s = 100,
            color = rd_bu_cmap[0],
            marker = markers[physics]
        )
        # Post (last entry)
        ax.scatter(
            xs_post[key],
            y,
            s = 100,
            color = rd_bu_cmap[-1],
            marker = markers[physics]
        )

        if annotate:
            ax.annotate(
                text = sim_name,
                xy = ( x_pre, y ),
                xytext = ( 5, 5 ),
                textcoords = 'offset points',
                fontsize = 22,
                ha = 'left',
                va = 'bottom',
            )

        # Horizontal arrow from the pre value to the post value
        ax.arrow(
            x_pre, y,
            xs_post[key] - x_pre, 0,
            color = 'k',
            zorder = -100,
            width = 0.002,
            head_length = 0.06,
            head_width = 0.02,
            length_includes_head = True,
        )

    if x_key == 'sigma_cosphi':
        # Reference line: 16th-84th percentile width of cos(theta) for a
        # uniform distribution on [-1, 1], i.e. (2*0.84 - 1) - (2*0.16 - 1)
        sigma_cosphi_sphere = 0.84 * 2 - 0.16 * 2
        ax.axvline(
            sigma_cosphi_sphere,
            color = '.5',
            linestyle = '-',
            linewidth = 1,
        )
        ax.annotate(
            text = 'spherical\ndistribution',
            xy = ( sigma_cosphi_sphere, 0.6 ),
            xycoords = 'data',
            xytext = ( -5, -5 ),
            textcoords = 'offset points',
            ha = 'right',
            va = 'top',
            fontsize = 22,
            color = '.5',
        )

    ax.set_xlabel( labels[x_key], fontsize=22 )
    ax.set_ylabel( labels[y_key], fontsize=22 )

    if x_log:
        ax.set_xscale( 'log' )
    if y_log:
        ax.set_yscale( 'log' )

    ax.set_xlim( x_lims )
    ax.set_ylim( y_lims )

    # Legend: pre ("Hot") and post ("Cool") colors
    legend_labels = [ 'Hot', 'Cool' ]
    legend_elements = [
        Line2D([0], [0], marker='o', color='k', label=legend_labels[_], markerfacecolor=rd_bu_cmap[_], markersize=15)
        for _ in [ 0, -1 ]
    ]
    ax.legend(
        handles=legend_elements,
        prop = {'size': 16 },
    )

## Combined

In [None]:
# Multi-valued quantities used by the combined figure
multi_values = dict( disk_frac = disk_frac )

In [None]:
# t_t1e5 bin centers, taken from the first variation in the summary data
# (presumably identical for all variations; confirm)
_centers_by_variation = data['cosphi']['t_t1e5_centers']
t_t1e5_centers = list( _centers_by_variation.values() )[0]

In [None]:
# Diverging Roma palette with one color per t_t1e5 bin: a list of colors
# (rd_bu_colors, indexed by bin) and the matching continuous colormap (rd_bu_cmap).
rd_bu_colors_base = getattr( palettable.scientific.diverging, 'Roma_{}'.format( len( t_t1e5_centers ) ) )
rd_bu_colors = rd_bu_colors_base.mpl_colors
rd_bu_cmap = rd_bu_colors_base.mpl_colormap

In [None]:
# Continuous colormap for the combined figure
_viridis_palette = palettable.matplotlib.Viridis_20
cmap = _viridis_palette.mpl_colormap

In [None]:
# Set to True to label each point of the combined figure with its simulation name
annotate = False

In [None]:
# Variations to include in the combined figure; an empty list means all of them
sims_subset = list()

In [None]:
n_cols = 2
n_rows = 1
scale = 1.5

fig = plt.figure( figsize=(6*n_cols*scale,5*scale), facecolor='w' )
main_ax = plt.gca()

gs = gridspec.GridSpec( 1, n_cols )
#     gs.update( wspace=0.0001 )

x_key = 'thin_disk_frac_recent'
z_key = 'mvir'
y_key = 'disk_frac'

xs = values[x_key]
zs = values[z_key]
ys = values[y_key]
y_multis = multi_values[y_key]
ys_pre = ys.inner_item( 0 )
ys_post = ys.inner_item( -1 )
ys_delta = ys_post - ys_pre

x_log = x_key in logscale
y_log = y_key in logscale
z_log = z_key in logscale

x_lims = get_lim( xs, x_log, x_key, 1.1 )
y_lims = get_lim( ys, y_log, y_key, 1.1 )

####################################################
# Left panel
####################################################

ax = plt.subplot( gs[0,0] )

for key in cp.variations:
    
    if key not in sims_subset and len( sims_subset ) > 0:
        continue

    try:
        y_pre = ys_pre[key]
        x = xs[key]
    except KeyError:
        continue

    if np.isnan( y_pre ):
        continue

    if '_' in key:
        sim_name, physics = key.split( '_' )
    else:
        sim_name = key
        physics = ''

    # Pre
    s = ax.scatter(
        x,
        y_pre,
        s = 300,
        color = rd_bu_colors[0],
        marker = markers[physics]
    )
    # Post
    s = ax.scatter(
        x,
        ys_post[key],
        s = 300,
        color = rd_bu_colors[-1],
        marker = markers[physics]
    )
    # Change along the way
    y_multi_vals = y_multis[key]
    for j, val in enumerate( y_multi_vals ):
        
        # Ticks
        tick_length = 0.015
        linewidth = 5
        if np.isclose( t_t1e5_centers[j], 0.0 ):
            linewidth = 7
            tick_length *= 2
            ax.plot(
                [ x - tick_length / 2, x + tick_length / 2],
                [ val, ] * 2,
                linewidth = linewidth,
                color = rd_bu_colors[j],
                zorder = -10,
            )
            
            # Markers
#             ax.scatter(
#                 [ x, ] ,
#                 [ val, ] * 2,
#                 color = rd_bu_colors[j],
#                 s = 300,
#                 marker = markers[physics]
#             )

        # Line itself
        if j != 0:
            start_point = y_multi_vals[j-1] + 0.5 * ( val - y_multi_vals[j-1] )
        else:
            start_point = val
        if j != len( y_multi_vals ) - 1:
            end_point = y_multi_vals[j+1] - 0.5 * ( y_multi_vals[j+1] - val )
        else:
            end_point = val
        ax.plot(
            [ x, ] * 2,
            [ start_point, end_point ],
            linewidth = 3,
            color = rd_bu_colors[j],
            zorder = -20,
        )
#     s = ax.scatter(
#         [ x, ] * len( y_multi_vals ),
#         y_multi_vals,
#         s = 300,
#         color = rd_bu_colors,
#         marker = markers[physics],
#     )

    if annotate:
        ax.annotate(
            text = sim_name,
            xy = ( x, y_pre ),
            xytext = ( 5, 5 ),
            textcoords = 'offset points',
            fontsize = 22,
            ha = 'left',
            va = 'bottom',
        )
    ax.plot(
        [ x, x ],
        [ y_pre, ys_post[key] ],
        linewidth = 1,
        color = 'k',
        zorder = -100,
    )

# Reference line marking the value expected for an isotropic distribution
if y_key == 'sigma_cosphi':
    # NOTE(review): assumes sigma_cosphi is the 84th-minus-16th percentile
    # spread of cos(phi) -- confirm. For a spherical distribution cos(phi) is
    # uniform on [-1, 1], so its q-th quantile is 2q - 1 and the spread is
    # ( 2 * 0.84 - 1 ) - ( 2 * 0.16 - 1 ) = 2 * 0.84 - 2 * 0.16.
    sigma_cosphi_sphere = 0.84 * 2 - 0.16 * 2
    # sigma_cosphi is on the y-axis in this panel, so the reference is horizontal.
    ax.axhline(
        sigma_cosphi_sphere,
        color = pm['background_linecolor'],
        linestyle = '-',
        linewidth = 1,
        zorder = -100,
    )
    # Anchor the label at the right edge of the axes (x in axes fraction,
    # y in data units) so it stays visible whatever the x-axis range is.
    ax.annotate(
        text = 'spherical\ndistribution',
        xy = ( 1, sigma_cosphi_sphere ),
        xycoords = ( 'axes fraction', 'data' ),
        xytext = ( -5, -5 ),
        textcoords = 'offset points',
        ha = 'right',
        va = 'top',
        fontsize = 22,
        color = '.5',
    )
elif y_key == 'disk_frac':
    ax.axhline(
        pm['disk_costheta'],
        color = pm['background_linecolor'],
        linewidth = 1,
        zorder = -100,
    )
    # NOTE(review): x = 0.5 is in data coordinates; the label is only visible
    # when the x-axis range contains 0.5.
    ax.annotate(
        text = 'isotropic distribution value',
        xy = ( 0.5, pm['disk_costheta'] ),
        xycoords = 'data',
        xytext = ( 5, -15 ),
        textcoords = 'offset points',
        ha = 'right',
        va = 'top',
        fontsize = 18,
        color = pm['background_linecolor'],
        arrowprops = { 'arrowstyle': '-', 'color': '.5', },
    )


# Axis labels for the left panel
ax.set_xlabel( labels[x_key], fontsize=22 )
ax.set_ylabel( labels[y_key], fontsize=22 )

# Optional log scaling, then fixed limits
if x_log:
    ax.set_xscale( 'log' )
if y_log:
    ax.set_yscale( 'log' )
ax.set_xlim( x_lims )
ax.set_ylim( y_lims )

# Legend: one circular marker per time offset relative to T = 1e5 K,
# filled with the matching color of the three-color Roma palette.
rd_bu_colors_legend = palettable.scientific.diverging.Roma_3.mpl_colors
legend_labels = [
    r'150 Myrs before $T = 10^5\,{\rm K}$',
    r'At $T = 10^5\,{\rm K}$',
    r'150 Myrs after $T = 10^5\,{\rm K}$',
]
legend_elements = [
    Line2D(
        [0], [0],
        marker = 'o',
        color = 'k',
        label = entry_label,
        markerfacecolor = entry_color,
        markersize = 15,
    )
    for entry_label, entry_color in zip( legend_labels, rd_bu_colors_legend )
]
ax.legend( handles=legend_elements, prop={'size': 16} )

####################################################
# Right panel
####################################################

# The right panel shows the change in the quantity rather than its value.
y_key = 'delta_' + y_key

x_lims = get_lim( xs, x_log, x_key, 1.2, )
y_lims = get_lim( ys_delta, y_log, y_key, 1.1, 2 )
z_lims = get_lim( zs, z_log, z_key, )

ax = plt.subplot(gs[0,1])

# Map z to [0, 1] with the same normalization class the colorbar uses
# (linear, or logarithmic when z_log), so marker colors match the colorbar.
if z_log:
    z_norm = mpl.colors.LogNorm( vmin=z_lims[0], vmax=z_lims[1] )
else:
    z_norm = mpl.colors.Normalize( vmin=z_lims[0], vmax=z_lims[1] )

for key in cp.variations:

    # Skip variations that lack any of the three plotted quantities
    try:
        x = xs[key]
        y = ys_delta[key]
        z = zs[key]
    except KeyError:
        continue
    if np.isnan( x ) or np.isnan( y ) or np.isnan( z ):
        continue

    # Keys look like '<sim>_<physics>'; a key without '_' maps to physics ''.
    physics = key.partition( '_' )[2]

    ax.scatter(
        x,
        y,
        s = 300,
        color = cmap( z_norm( z ) ),
        marker = markers[physics]
    )
    
# Mark zero change when plotting the change in the disk fraction
if y_key == 'delta_disk_frac':
    ax.axhline( 0, color=pm['background_linecolor'], linewidth=1, zorder=-100 )

# Axis labels for the right panel
ax.set_xlabel( labels[x_key], fontsize=22 )
ax.set_ylabel( labels[y_key], fontsize=22 )

# Optional log scaling, then fixed limits
if x_log:
    ax.set_xscale( 'log' )
if y_log:
    ax.set_yscale( 'log' )
ax.set_xlim( x_lims )
ax.set_ylim( y_lims )

# When both axes are fractions, square the panel and draw the one-to-one line
if ( x_key in fractions ) and ( y_key in fractions ):
    ax.set_aspect( 'equal' )
    ax.plot( [ 0, 1 ], [ 0, 1 ], color=pm['lighter_background_linecolor'], linewidth=1 )

# Legend: one black marker per physics variant in `markers`
legend_elements = [
    Line2D(
        [0], [0],
        marker = markers[physics_variant],
        color = 'w',
        label = marker_labels[physics_variant],
        markerfacecolor = 'k',
        markersize = 15,
    )
    for physics_variant in markers.keys()
]
ax.legend( handles=legend_elements, prop={'size': 16} )

# Colorbar for z, normalized linearly or logarithmically according to z_log
norm_class = mpl.colors.LogNorm if z_log else mpl.colors.Normalize
norm = norm_class( vmin=z_lims[0], vmax=z_lims[1] )
divider = make_axes_locatable( ax )
cax = divider.append_axes( "right", size='5%', pad=0.05 )
cbar = mpl.colorbar.ColorbarBase( cax, cmap=cmap, norm=norm, )

# Colorbar label, placed just above the colorbar's top-right corner
cax.annotate(
    text = labels[z_key],
    xy = ( 1, 1 ),
    xycoords = 'axes fraction',
    xytext = ( 0, 5 ),
    textcoords = 'offset points',
    ha = 'right',
    va = 'bottom',
    fontsize = 22,
)

# Save via galaxy_dive's plotting helper into the 'prevalence' figure subdirectory
prevalence_dir = os.path.join( pm['figure_dir'], 'prevalence' )
plotting.save_fig(
    out_dir = prevalence_dir,
    save_file = 'aligned_fraction.pdf',
    fig = fig,
)