In [None]:
import copy
import time, importlib
import h5py
import numpy as np
import os
import pandas as pd
import scipy, scipy.ndimage
import scipy.interpolate
import tqdm
import unyt

In [None]:
import kalepy as kale

In [None]:
import matplotlib
import matplotlib.colors
import matplotlib.gridspec as gridspec
import matplotlib.patheffects as patheffects
import matplotlib.pyplot as plt

In [None]:
import palettable

In [None]:
import linefinder.analyze_data.worldlines as a_worldlines
import linefinder.analyze_data.plot_worldlines as p_worldlines
import linefinder.config as l_config

In [None]:
import galaxy_dive.plot_data.plotting as plotting

In [None]:
import coolingFunction

In [None]:
import trove
import verdict
from py2tex import py2tex

In [None]:
import helpers

In [None]:
# Render figures inline in the notebook output.
%matplotlib inline
# Apply a custom style sheet located under the user's home directory.
# NOTE(review): this path is machine-specific; presumably the cell fails on a
# machine without this repository checked out — confirm before sharing.
matplotlib.style.use( '~/repos/clean-bold/clean-bold.mplstyle' )

# Parameters

## Manual

In [None]:
# Manually chosen parameters. These are the keyword arguments handed to
# trove.link_params_to_config when the data are loaded.
pm = {
    'snum': 600,
    'ahf_index': 600,
}

In [None]:
# Global variations to compare. The empty string stands for the default case:
# no `global_variation` parameter is set for it when the data are loaded.
global_variations = [ '', 'track_all_thin_disk_stars', 'track_all_recent_stars' ]

In [None]:
# Annotation text for the plotted curves, looked up by the index of the sample.
# There are four labels for three global variations; the last one is only
# reached if a fourth sample is appended to the loaded samples.
labels = [
    'main sample',                          # index 0
    r'all stars with $j_z/j_c(E) > 0.8$',   # index 1
    r'all stars with age $<1$ Gyr',         # index 2
    r'all stars with $j_z/j_c(E) > 0.9$',   # index 3
]

In [None]:
# Line colors, one per sample: entries 0, 1, 3 and 5 of the qualitative
# Pastel_10 palette.
cmap = palettable.cartocolors.qualitative.Pastel_10.mpl_colors
colors = [ cmap[k] for k in ( 0, 1, 3, 5 ) ]

# Load Data

In [None]:
# Link the parameters to the config and open the worldlines for every global
# variation. The manual parameters in `pm` are only read here, never rebound or
# modified, so a `global_variation` chosen for one variation cannot leak into
# another one or into a later re-run of this cell.
#
# Variations whose worldline data cannot be found are skipped, so `pms` and `ws`
# may end up shorter than `global_variations`.
pms = []
ws = []
for gv in global_variations:

    # Fresh copy of the manual parameters for this variation only.
    pm_args = copy.deepcopy( pm )
    if gv != '':
        pm_args['global_variation'] = gv

    pm_i = trove.link_params_to_config(
        '/home1/03057/zhafen/papers/Hot-Accretion-in-FIRE/analysis/hot_accretion.trove',
        script_id = 'nb.8',
        **pm_args
    )
    ws_i = a_worldlines.Worldlines(
        tag = pm_i['tag'],
        data_dir = pm_i['data_dirs']['jug.4'],
        halo_data_dir = pm_i['halo_data_dir'],
        ahf_index = pm_i['ahf_index'],
    )

    # Accessing n_particles is used as a probe: if it raises FileNotFoundError
    # the variation is left out, but the skip is reported instead of being silent.
    try:
        ws_i.n_particles
    except FileNotFoundError:
        print( 'Skipping global variation {!r}: worldline data not found.'.format( gv ) )
        continue

    pms.append( pm_i )
    ws.append( ws_i )

# Calculate Accretion and cooling time

## Times

In [None]:
# For every loaded sample, compute the accretion and cooling time indices,
# keep the particles that have both, and collect their times.
#
# The loop variables are named `pm_i` / `w_i` so that the notebook-level `pm`
# (the manual parameters read by the Load Data cell) is not overwritten.
taccs = []
tcoolss = []
is_hot_accs = []
particle_indss = []
n_accreted = []
for i, ( pm_i, w_i ) in enumerate( zip( pms, ws ) ):

    # NOTE(review): pms/ws leave out variations that failed to load, so this
    # name can be off if an earlier variation was skipped — confirm.
    print(  'Global variation: {}'.format( global_variations[i] ) )

    # Keyword arguments shared by all three calculations below.
    time_kwargs = dict(
        lookback_time_max = pm_i['lookback_time_max'],
        choose_first = pm_i['choose_first'],
    )

    tacc_inds = w_i.calc_tacc_inds( **time_kwargs )
    tcools_inds = w_i.calc_tcools_inds( B = pm_i['logTcools'], **time_kwargs )
    is_hot_accretion = w_i.calc_is_hot_accretion( B = pm_i['logTcools'], **time_kwargs )

    # A particle counts as accreted when its accretion index is not the fill
    # value; it is valid when its cooling index is not the fill value either.
    is_accreted = tacc_inds != l_config.INT_FILL_VALUE
    valid = is_accreted & ( tcools_inds != l_config.INT_FILL_VALUE )
    tacc_inds = tacc_inds[valid]
    tcools_inds = tcools_inds[valid]
    particle_inds = np.arange( w_i.n_particles )[valid]

    # Times for the valid particles only.
    tacc = w_i.get_data( 'tacc' )[particle_inds]
    tcools = w_i.get_data( 'tcools' )[particle_inds]

    taccs.append( tacc )
    tcoolss.append( tcools )
    is_hot_accs.append( is_hot_accretion )
    particle_indss.append( particle_inds )
    # Counted before the cooling-time requirement is applied.
    n_accreted.append( is_accreted.sum() )

### Special Case
For the "all thin disk stars" case we also want the times for the stars on the most circular orbits ($j_z/j_c(E) > 0.9$). The cell below is currently disabled (`if False`).

In [None]:
# Disabled: the `if False` guard means nothing in this cell runs.
# If enabled, it would append a fourth sample to taccs, tcoolss and pms: the
# subset of sample 1 whose 'Jz/Jcirc' value in column 0 exceeds 0.9.
# NOTE(review): ws, is_hot_accs, particle_indss and n_accreted are not extended
# here, and the hard-coded index 1 assumes the second loaded sample is the
# thin-disk one — confirm both before enabling.
if False:
    pm = pms[1]
    w = ws[1]
    particle_inds = particle_indss[1]

    # Get total angular momentum vector
    # Read from tot_momentums.hdf5 for this variation at snapshot 600, then
    # normalized to unit length.
    base_processed_data_dir = pm['config_parser'].get( 'DEFAULT', 'processed_data_dir' )
    tot_momentum_fp = os.path.join( base_processed_data_dir, 'tot_momentums.hdf5' )
    tot_ang_momentum = verdict.Dict.from_hdf5( tot_momentum_fp )[pm['variation']]['snum{:03d}'.format( 600 )]
    tot_ang_momentum_normed = tot_ang_momentum / np.linalg.norm( tot_ang_momentum )

    # Calculate superthin stars
    # The normalized vector is set on the worldlines object before requesting
    # 'Jz/Jcirc'; only column 0 is used, restricted to the valid particles.
    w.total_ang_momentum = tot_ang_momentum_normed
    c = w.get_data( 'Jz/Jcirc' )
    is_superthin = c[:,0][particle_inds] > 0.9

    # Store
    taccs.append( taccs[1][is_superthin] )
    tcoolss.append( tcoolss[1][is_superthin] )
    pms.append( pm )
## Time distributions

### General

In [None]:
# Time axis shared by all distributions.
# Upper edge: first entry of the 'time' data of the first sample.
t_z0 = ws[0].get_data( 'time' )[0]
t_zi = ws[0].get_data( 'lookback_time' )[-1]

# Lower edge: the smaller of the 1st percentiles of all accretion times and of
# all cooling times (NaNs ignored).
t_min = np.nanmin([
    np.nanpercentile( np.hstack( all_times ), 1. )
    for all_times in ( taccs, tcoolss )
])
t_range = [ t_min, t_z0 ]

# Bin edges in steps of 0.065, and the midpoint of every bin.
t_bins = np.arange( t_range[0], t_range[1], 0.065 )
t_centers = 0.5 * ( t_bins[:-1] + t_bins[1:] )

In [None]:
# Distributions of the accretion ('tacc') and cooling ('tcools') times for each
# sample: a histogram PDF, the CDF derived from it, and a kernel density
# estimate evaluated at the bin centers.
#
# The outer loop runs over sample indices only, so the notebook-level `pm`
# (the manual parameters read by the Load Data cell) is not overwritten.
t_keys = [ 'tacc', 'tcools' ]
pdfs = { t_key: [] for t_key in t_keys }
cdfs = { t_key: [] for t_key in t_keys }
kdes = { t_key: [] for t_key in t_keys }
for i in tqdm.tqdm( range( len( pms ) ) ):
    for t_key, tchange in zip( t_keys, [ taccs, tcoolss ] ):

        pdf, bins = np.histogram(
            tchange[i],
            bins = t_bins,
            density = True,
        )
        pdfs[t_key].append( pdf )

        # Running sum of the PDF, rescaled so that the last bin equals 1.
        cdf = np.cumsum( pdf )
        cdf /= cdf[-1]
        cdfs[t_key].append( cdf )

        # KDE evaluated at the bin centers, with reflecting boundaries at 0
        # and t_z0.
        points, kde = kale.density(
            tchange[i],
            points = t_centers,
            reflect = [ 0, t_z0 ],
            probability = True,
        )
        kdes[t_key].append( kde )

# Plot time distributions

## Settings

In [None]:
# x-axis positions at which each curve's label is anchored, one entry per
# sample, keyed by the time quantity being plotted. An entry of None makes the
# plotting cell fall back to the position where the CDF reaches 0.5.
x_annots = dict(
    tacc = [ 13.05, 11, 12.6, 12. ],
    tcools = [ 12.9, 9.2, 12, 12. ],
)

## Plot

In [None]:
# One figure per time quantity, with the KDE of every sample drawn as a labelled
# curve, saved as a PDF in the 'selected_to_all_comparison' figure directory.

# Figure-wide settings (reference line color, x label, file name) come from the
# last loaded parameter set. This is made explicit here rather than relying on
# a variable left over from the curve loop.
pm_ref = pms[-1]

for t_key in [ 'tacc', 'tcools' ]:

    fig, ax = plt.subplots( figsize=(7, 5.25/2) )

    for i, kde_i in enumerate( kdes[t_key] ):
        # Later samples get a lower zorder, so earlier ones are drawn on top.
        ax.plot(
            t_centers,
            kde_i,
            c = colors[i],
            zorder = -i,
        )

        # Anchor the label on the curve at the chosen x position; without a
        # chosen position, use the x at which the CDF reaches 0.5.
        x_annot = x_annots[t_key][i]
        if x_annot is None:
            x_annot = scipy.interpolate.interp1d( cdfs[t_key][i], t_centers )( 0.5 )
        y_annot = scipy.interpolate.interp1d( t_centers, kde_i )( x_annot )
        annotation = ax.annotate(
            text = labels[i],
            xy = ( x_annot, y_annot ),
            xytext = ( -5, 5 ),
            textcoords = 'offset points',
            ha = 'right',
            va = 'bottom',
            fontsize = 20,
            c = colors[i],
        )
        # White outline behind the text keeps it readable over the curves.
        annotation.set_path_effects([
            patheffects.Stroke( linewidth=4, foreground='white' ),
            patheffects.Normal(),
        ])

    ax.set_xticks( np.arange( 0, t_z0, 1. ) )
    ax.set_yticks( np.arange( 0, 5, 0.5 ) )

    ax.set_xlim( 6, t_centers[-1] )
    ax.set_ylim( 0, np.nanmax( kdes[t_key] )*1.05 )

    # Thin reference line one x-axis unit before t_z0.
    ax.axvline(
        t_z0 - 1.,
        c = pm_ref['background_linecolor'],
        linewidth = 1,
        zorder = -100,
    )

    x_label = (
        helpers.get_tchange_label( pm_ref, central_indices='{}_inds'.format( t_key ) ) +
        ' [Gyr]'
    )
    ax.set_xlabel( x_label, fontsize=22 )
    ax.set_ylabel( 'PDF', fontsize=22 )

    plotting.save_fig(
        out_dir = os.path.join( pms[0]['figure_dir'], 'selected_to_all_comparison' ),
        save_file = '{}_{}.pdf'.format( t_key, pm_ref['variation'] ),
        fig = fig,
    )

# Store General Properties

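For each loaded variation, store in `summary.hdf5`: the number of tracked particles (`n_tracked`), the number of accreted particles that also have a valid cooling time, i.e. that reach T>Tcools prior to accretion (`n_valid`), the number that are hot accretion (`n_hot`), the number of accreted particles (`n_accreted`), and the fraction that is hot (`f_hot = n_hot / n_accreted`).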

In [None]:
# Write summary counts for each sample into the summary.hdf5 file in that
# sample's processed data directory. Only the first three samples are stored.
for i, is_hot_accretion in enumerate( is_hot_accs[:3] ):

    data_fp = os.path.join( pms[i]['processed_data_dir'], 'summary.hdf5' )
    print( 'Updating data at {}'.format( data_fp ) )
    data = verdict.Dict.from_hdf5( data_fp, create_nonexistent=True )

    # Quantities to store, in the order they are printed and written.
    n_hot = is_hot_accretion.sum()
    summary = {
        'n_tracked': ws[i].n_particles,
        'n_valid': tcoolss[i].size,
        'n_hot': n_hot,
        'n_accreted': n_accreted[i],
        'f_hot': n_hot / n_accreted[i],
    }

    # Each quantity is filed under its key and then under the variation name,
    # creating the per-key entry the first time it is needed.
    for key, quantity in summary.items():
        print( '    {}: {:.5g}'.format( key, quantity ) )
        if key in data:
            data[key][pms[i]['variation']] = quantity
        else:
            data[key] = { pms[i]['variation']: quantity }

    data.to_hdf5( data_fp, handle_jagged_arrs='row datasets' )