In [None]:
import coolingFunction

In [None]:
import copy
import numpy as np
import h5py
import scipy
import sys
import verdict
import os
import unyt

In [None]:
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.patheffects as path_effects
import matplotlib.cm as cm
import matplotlib.colors as plt_colors
import matplotlib.gridspec as gridspec
import matplotlib.transforms
import palettable

In [None]:
import linefinder.analyze_data.worldlines as a_worldlines
import linefinder.analyze_data.worldline_set as worldline_set
import linefinder.analyze_data.plot_worldlines as p_worldlines
import linefinder.utils.presentation_constants as p_constants

In [None]:
import galaxy_dive.analyze_data.ahf as analyze_ahf
import galaxy_dive.plot_data.ahf as plot_ahf
import galaxy_dive.analyze_data.particle_data as particle_data
import galaxy_dive.plot_data.generic_plotter as generic_plotter
import galaxy_dive.plot_data.plotting as plotting
import galaxy_dive.utils.data_operations as data_operations
import galaxy_dive.utils.executable_helpers as exec_helpers

In [None]:
import linefinder.utils.file_management as file_management
import linefinder.config as config

In [None]:
import trove

# Load Data

In [None]:
# Notebook-level parameters; passed to trove.link_params_to_config in the next cell
pm = {
    'snum': 600,
    'tables_dir': '/work/03057/zhafen/CoolingTables/',
    'study_duplicates': False,
    'ahf_index': 600,
}

In [None]:
# Link the notebook parameters to the shared trove config for this analysis. Keys such as
# 'tag', 'data_dir', 'variation' and 'figure_dir' are read from the returned pm below.
trove_config_fp = '/home1/03057/zhafen/papers/Hot-Accretion-in-FIRE/analysis/hot_accretion.trove'
pm = trove.link_params_to_config( trove_config_fp, **pm )

In [None]:
# Frequently used values, pulled out of pm for brevity.
# ind is the index used with time-ordered arrays (e.g. t[ind]); it is 0 when ahf_index == snum.
snum, ind = pm['snum'], pm['ahf_index'] - pm['snum']

In [None]:
# Worldlines (particle tracks) for this simulation, configured from the linked parameters
worldline_kwargs = {
    key: pm[key]
    for key in ( 'tag', 'data_dir', 'halo_data_dir', 'ahf_index' )
}
w = a_worldlines.Worldlines( **worldline_kwargs )

In [None]:
# Load halo data for the worldlines (presumably required before w.m_vir, w.r_vir and
# w.redshift are accessed below — confirm).
w.retrieve_halo_data()

In [None]:
# Label used for plots: halo mass at snum as a power of ten, plus the redshift.
# The LaTeX pieces are raw strings so backslash sequences such as '\o' are not treated as
# (invalid) Python escape sequences.
m_plot_label  = r'$M_{\rm h} = 10^{' + '{:.02g}'.format( np.log10( w.m_vir[snum] ) )
m_plot_label += r'} M_\odot$'
plot_label = m_plot_label + ', z={:.02}'.format( w.redshift[snum] )
print( plot_label )

In [None]:
# Shallow copy of linefinder's CGM-fate classifications, so edits here leave the shared constant intact
classification_list = copy.copy( p_constants.CLASSIFICATIONS_CGM_FATE )

In [None]:
# Worldlines plotter labeled with the halo mass / redshift string built above
w_plotter = p_worldlines.WorldlinesPlotter( w, label=plot_label )

# Data Pre-Processing

## Calculate $\theta$
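This angle is also called $\phi$: the code stores it as `Phi` (computed by `calc_abs_phi` relative to the total angular momentum vector), while the plots label it $\theta$.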

In [None]:
# Total angular momentum vector for this variation and snapshot, read from processed data
tot_momentum_fp = os.path.join( pm['processed_data_dir'], 'tot_momentums.hdf5' )
tot_momentums = verdict.Dict.from_hdf5( tot_momentum_fp )
snum_key = 'snum{:03d}'.format( snum )
tot_ang_momentum = tot_momentums[pm['variation']][snum_key]

In [None]:
# Angle of each particle relative to tot_ang_momentum. It is presumably stored under the
# 'Phi' key read later, and plotted as theta.
w.calc_abs_phi( normal_vector=tot_ang_momentum )

## Calculate mass deposition

In [None]:
# Change in particle mass between consecutive entries along the time axis
# (entry i minus entry i+1). Negative changes are masked out, so each particle's
# deposited_m is the sum of its positive changes only.
mass = w.get_data( 'M' )
delta_m = mass[:, :-1] - mass[:, 1:]
positive_changes = np.ma.masked_array( delta_m, mask=( delta_m < 0 ) )
deposited_m = positive_changes.sum( axis=1 ).data

# Analysis

In [None]:
# Analysis time window: t_window before the time of entry ind
# (units of t, presumably Gyr given the *1e3 -> Myr labels below)
t_window = 1.
t = w.get_data( 'time' )
t_end = t[ind]
x_range = [ t_end - t_window, t_end ]

In [None]:
# Snapshot times strictly inside x_range, in reversed array order
in_x_range = ( t > x_range[0] ) & ( t < x_range[1] )
t_snaps = t[in_x_range][::-1]

In [None]:
# Number of snapshots inside the analysis window
n_snaps = len( t_snaps )

In [None]:
# Spacing between consecutive window snapshots
dt = np.diff( t_snaps )

In [None]:
# Time-bin edges around each snapshot: interior edges at midpoints between snapshots,
# outer edges half a spacing beyond the first and last snapshots.
half_dt = dt / 2.
t_bins = np.empty( t_snaps.size + 1 )
t_bins[0] = t_snaps[0] - half_dt[0]
t_bins[1:-1] = t_snaps[:-1] + half_dt
t_bins[-1] = t_snaps[-1] + half_dt[-1]

In [None]:
   
w.data_masker.clear_masks()
w.data_masker.mask_data( 'PType', data_value=0 )

# get_selected_data( ..., compress=False ) may return masked arrays. NumPy's nan-aware
# reductions and fancy indexing do not reliably honor masks, so masked entries are turned
# into NaN up front (np.ma.filled is a no-op for plain ndarrays).
logT = np.ma.filled( np.log10( w.get_selected_data( 'T', compress=False ) ), np.nan )
R = np.ma.filled( w.get_selected_data( 'R', compress=False ), np.nan )
L = np.ma.filled( w.get_selected_data( 'Lmag', compress=False ), np.nan )
M = np.ma.filled( w.get_selected_data( 'M', compress=False ), np.nan )

# Median and 16th/84th-percentile stats across particles, per snapshot
logT_med = np.nanmedian( logT, axis=0 )
R_med = np.nanmedian( R, axis=0 )

logT_low = np.nanpercentile( logT, 16, axis=0 )
logT_high = np.nanpercentile( logT, 84, axis=0 )

R_low = np.nanpercentile( R, 16, axis=0 )
R_high = np.nanpercentile( R, 84, axis=0 )

# For each particle, the first index along the time axis (scanning from index 0, which in
# this notebook appears to be the most recent snapshot) where T > 1e5 K.
# -1 flags particles that never exceed 1e5 K; it is kept because later cells index t[inds].
is_hot = logT > 5.
ever_hot = is_hot.any( axis=1 )
inds = np.where( ever_hot, is_hot.argmax( axis=1 ), -1 )

# Values at that snapshot. Never-hot particles get NaN rather than silently picking up
# the entry at index -1 (the other end of the time axis).
particle_rows = np.arange( inds.size )
R_at_Tcool = np.where( ever_hot, R[particle_rows, inds], np.nan )
M_at_Tcool = np.where( ever_hot, M[particle_rows, inds], np.nan )
L_at_Tcool = np.where( ever_hot, L[particle_rows, inds], np.nan )

t_at_Tcool = np.where( ever_hot, t[inds], np.nan )


## Accretion Tracks and $R_{\rm 10^5K}$ Distribution

In [None]:
# Accretion-track plot settings. Times are relative to t_{T=1e5 K}, in the units of t.
dt_before, dt_after = -2., 0.5     # window of t - t_{T=1e5 K} drawn for each track
color_dt = 0.2                     # arrow colors span +/- color_dt (shown as color_dt*1e3 Myr)
n_particles = 10                   # number of randomly chosen tracks
x_lim = np.array( [ 0, 105 ] )     # radius range (kpc)
y_lim = np.array( [ 5e3, 3e6 ] )   # temperature range (K); log10 is used for the axis
y2_lim = np.array( [ 1, 1e2 ] )    # entropy range; log10 is used for the axis

In [None]:
%matplotlib inline

# Load sim data
r_vir = w.r_vir[snum]

w.data_masker.clear_masks()

# Intended to keep only particles that have never left the main galaxy.
# NOTE(review): the flow plot below reads data via w.get_data and applies n_out == 0 itself;
# confirm whether this mask affects anything in this cell.
w.data_masker.mask_data( 'n_out', -1, 1 )

# Makes the random particle sample for the accretion tracks reproducible
np.random.seed( 2 )

fig = plt.figure( figsize=(12, 11), facecolor='w' )
# NOTE(review): this figure-sized axes is reused as ax3 in the flow-plot section. It overlaps
# every GridSpec panel, and older matplotlib versions of pyplot.subplot remove such overlapping
# axes, so artists drawn on it may not render. Confirm the intended layout.
ax = plt.gca()

# 7 rows: rows 0-1 hold the R_{T=1e5 K} histogram, rows 2-6 the accretion tracks
gs = gridspec.GridSpec(7, 1)
gs.update( hspace=0.001 )

### R_{T=1e5 K} HISTOGRAM ####

ax1 = plt.subplot(gs[:2,0])

# Normalized distribution of the radius at which each particle was last above 1e5 K
n, bins, patches = ax1.hist(
    R_at_Tcool,
    bins = np.linspace( 0., 2.*r_vir, 256 ),
    color = '0.5',
    density = True,
)

print( 'Median R_at_Tcool = {:.3g} Rvir'.format( np.nanmedian( R_at_Tcool ) / r_vir ) )

# A volume-filling (uniform within Rvir) reference distribution is deliberately not drawn:
# it lies so far out that it is not visible in this x-range.

ax1.set_xlim( x_lim )

ax1.set_xlabel( r'$R_{T=10^5{\rm K}}$ (kpc)', fontsize=22, labelpad=10 )
ax1.xaxis.set_label_position( 'top' )

# x ticks and labels on top only; no y ticks for the normalized histogram
ax1.tick_params( axis='x', top=True, labeltop=True, bottom=False, labelbottom=False )
ax1.tick_params( axis='y', left=False, labelleft=False, )

# Annotation text is passed positionally: matplotlib renamed annotate's ``s`` keyword and
# later removed it, while the positional form works across versions.
ax1.annotate(
    'radius at which accreted gas cools\n{}'.format( pm['variation'] ),
    xy=(1,1),
    xycoords='axes fraction',
    xytext=(-10,-10),
    textcoords='offset points',
    ha = 'right',
    va = 'top',
    fontsize = 22,
)


### FLOW PLOT ####

ax2 = plt.subplot(gs[2:,0])
# ax3 is the figure-sized axes created with plt.gca() above. A separate panel,
# plt.subplot(gs[6:,0]), was tried previously.
ax3 = ax

# Choose a random sample of particles to draw tracks for
particle_inds = np.random.choice( np.arange( w.n_particles ), size=n_particles, replace=False )

# Time of each snapshot relative to the particle's T = 1e5 K snapshot (rows = particles)
t_min_t_cool = ( t[:,np.newaxis] - t[inds] ).transpose()
# inds == -1 flags particles that never exceed 1e5 K. t[-1] is not a cooling time for them,
# so their rows are blanked instead of being allowed to pass the time window below.
t_min_t_cool[inds < 0] = np.nan
w.data['t_rel_t1e5'] = t_min_t_cool

# Snapshots inside the plotted time window, restricted to n_out == 0
valid_value = ( t_min_t_cool < dt_after ) & ( t_min_t_cool > dt_before )
valid_value = valid_value & ( w.get_data( 'n_out' ) == 0 )
valid_value_inds = valid_value[particle_inds]
r_vecs_all = w.get_data( 'R' )[particle_inds]
T_vecs_all = w.get_data( 'T' )[particle_inds]
t_min_t_cool_sample = t_min_t_cool[particle_inds]

# Arrows are colored by t - t_{T=1e5 K} in Myr on a symmetric range; shared by all tracks
quiver_cmap = palettable.scientific.diverging.Berlin_5_r.mpl_colormap
quiver_norm = plt.Normalize( -color_dt*1e3, color_dt*1e3 )

# One quiver per particle in (R, log T) space
for k in range( particle_inds.size ):

    valid_k = valid_value_inds[k]
    r_vecs = r_vecs_all[k][valid_k]
    T_vecs = np.log10( T_vecs_all[k][valid_k] )

    # Arrow j starts at entry j+1 and points to entry j (the preceding array entry)
    x = r_vecs[1:]
    y = T_vecs[1:]
    dx = r_vecs[:-1] - r_vecs[1:]
    dy = T_vecs[:-1] - T_vecs[1:]
    C = t_min_t_cool_sample[k][valid_k][1:] * 1e3

    quiver = ax2.quiver(
        x, y,
        dx, dy,
        C,
        angles = 'xy',
        units = 'y',
        scale = 10,
        minshaft = 2,
        headwidth = 2,
        headlength = 3.5,
        cmap = quiver_cmap,
        norm = quiver_norm,
    )

# A single colorbar is enough: every quiver uses the same cmap and norm, so adding one per
# particle only stacked identical colorbars at the same location.
plotting.add_colorbar(
    fig,
    quiver,
    ax_location = [0.905, 0.125, 0.03, 0.6],
)

# Reference entropy profiles K ∝ r^a (a = 0, 1, 2), normalized to the median entropy of
# in-window snapshots whose R lies within 10% of the outer x-limit.
r = w.get_data( 'R' )
at_border = ( r < 1.1 * x_lim[1] ) & ( r > 0.9 * x_lim[1] ) & valid_value
k_at_border = w.get_data( 'entropy' )[at_border]
med_k_at_border = np.nanmedian( k_at_border )
a_vals = [ 0, 1, 2 ]
r_arr = np.linspace( x_lim[0], x_lim[1], 256 )
k_arrs = [ med_k_at_border * ( r_arr / x_lim[1] )**a for a in a_vals ]
for a, k_arr in zip( a_vals, k_arrs ):
    # If r_arr starts at 0, K = 0 there for a > 0 and log10 gives -inf. matplotlib leaves
    # that non-finite point out, so only the divide-by-zero warning is silenced.
    with np.errstate( divide='ignore' ):
        log_k_arr = np.log10( k_arr )
    ax3.plot(
        r_arr,
        log_k_arr,
        color = '0.25',
        linewidth = 1.5,
    )
    # Label each line at sample 100 of r_arr (~40% of the x-range). The text is passed
    # positionally because annotate's ``s`` keyword was removed in newer matplotlib.
    ax3.annotate(
        r'$\propto r^{' + str( a ) + r'}$',
        xy = ( r_arr[100], log_k_arr[100] ),
        xycoords = 'data',
        xytext = ( 0, 0 ),
        textcoords = 'offset points',
        ha = 'right',
        va = 'bottom',
        fontsize = 24,
        color = '0.25',
    )

# Panel labels. Text is passed positionally because annotate's ``s`` keyword was removed in
# newer matplotlib versions.
for ax_k in [ ax2, ax3 ]:
    ax_k.annotate(
        'accretion tracks',
        xy=(1,0),
        xycoords='axes fraction',
        xytext=(-10,10),
        textcoords='offset points',
        ha = 'right',
        va = 'bottom',
        fontsize = 22,
    )

# Label below the lower-right corner of ax2 (presumably captioning the colorbar)
t_label = ax2.annotate(
    r'$t - t_{T=10^5 {\rm K}}$ (Myr)',
    xy = ( 1, 0 ),
    xycoords = 'axes fraction',
    xytext = ( 20, -30 ),
    textcoords = 'offset points',
    ha = 'center',
    va = 'top',
    fontsize = 24,
)

# Reference radii at 0.05 Rvir and Rvir on every panel
for ax_k in [ ax1, ax2, ax3 ]:
    for r_ref in [ 0.05 * r_vir, r_vir ]:
        ax_k.axvline(
            r_ref,
            color = 'k',
            linestyle = '--',
            linewidth = 3,
        )
    # Label the 0.05 Rvir line at the top of first-row panels
    if ax_k.is_first_row():
        # x in data coordinates, y in axes coordinates
        trans = matplotlib.transforms.blended_transform_factory( ax_k.transData, ax_k.transAxes )
        ax_k.annotate(
            r'$0.05 R_{\rm vir}$',
            xy = ( 0.05 * r_vir, 1.0 ),
            xycoords = trans,
            xytext = ( 6, -10 ),
            textcoords = 'offset points',
            ha = 'left',
            va = 'top',
            fontsize = 24,
        )

# Axis limits and labels: ax2 shows log temperature, ax3 log entropy.
# LaTeX labels are raw strings so '\l' is not parsed as an (invalid) Python escape.
ax2.set_xlim( x_lim )
ax2.set_ylim( np.log10( y_lim ) )
ax2.set_ylabel( r'$\log$T (K)', fontsize=22 )
ax3.set_xlim( x_lim )
ax3.set_ylim( np.log10( y2_lim) )
ax2.set_xlabel( 'R (kpc)', fontsize=22 )
ax3.set_ylabel( r'$\log$K (keV cm$^2$)', fontsize=22 )

plotting.save_fig(
    out_dir = os.path.join( pm['figure_dir'], 'tracks' ),
    save_file = 'tracks_{}.pdf'.format( pm['variation'] ),
    fig = fig,
)


## Angular Distribution

In [None]:
# Centers of the windows in t - t_{T=1e5 K}, keyed by "is the property Phi?":
#   True  -> finer, non-uniformly spaced windows used for the angular distributions
#   False -> three uniformly spaced windows (-0.1, 0, 0.1) used for the other properties
all_t_phase_centers = {
    True: np.array( [ -0.15, -0.06, -0.03, 0., 0.03, 0.06, 0.15 ] ),
    False: np.linspace( -0.1, 0.1, 3 ),
}

In [None]:
# For each property, collect values (and cos-histograms) in windows of time relative to the
# snapshot at which each particle was last above 1e5 K.
binned_props = {}
all_dists = {}
props = [ 'Phi', 'Rx', 'Ry', 'Rz', 'T' ]

# Time of each snapshot relative to t_{T=1e5 K}, one row per particle
t_tphase = ( t[:,np.newaxis] - t[inds] ).transpose()
# inds == -1 flags particles that never exceed 1e5 K. t[-1] is not a cooling time for them,
# so they are excluded from every time window (NaN fails both window comparisons).
t_tphase[inds < 0] = np.nan
t_tphase_flat = t_tphase.flatten()

# Edges of the cos(angle) histograms. These are only meaningful for Phi (an angle in degrees);
# histograms are also computed for the other props, but only all_dists['Phi'] is plotted.
cos_bin_edges = np.linspace( -1., 1., 64 )

w.data_masker.clear_masks()
w.data_masker.mask_data( 'PType', data_value=0 )
for key in props:

    # Masked entries -> NaN. np.histogram / np.histogram2d ignore masks, while NaN values
    # fall outside explicit bin edges and are not counted. No-op for plain ndarrays.
    prop = np.ma.filled( w.get_selected_data( key, compress=False ), np.nan )
    prop_flat = prop.flatten()

    # Every window has the width of the first center spacing, so non-uniformly spaced
    # centers (the Phi set) give overlapping windows. The window width gets its own name so
    # the notebook-level dt (snapshot spacing) is not overwritten.
    t_phase_centers = all_t_phase_centers[key == 'Phi']
    t_window_width = t_phase_centers[1] - t_phase_centers[0]

    prop_dists = []
    binned_prop = []
    for center in t_phase_centers:
        bin_low = center - t_window_width / 2.
        bin_high = center + t_window_width / 2.
        in_bin = ( t_tphase_flat > bin_low ) & ( t_tphase_flat < bin_high )
        prop_arr = prop_flat[in_bin]
        # bin_edges is also read by the next cell
        hist, bin_edges = np.histogram(
            np.cos( prop_arr * np.pi / 180 ),
            bins = cos_bin_edges,
            density = True,
        )
        binned_prop.append( prop_arr )
        prop_dists.append( hist )

    all_dists[key] = prop_dists
    binned_props[key] = binned_prop

In [None]:
# Centers of the cos(theta) histogram bins (bin_edges comes from the distribution cell above)
bin_width = bin_edges[1] - bin_edges[0]
bin_centers = bin_edges[:-1] + bin_width / 2.

In [None]:
# The angular-distribution plot uses the finer Phi time windows, and every curve is labeled
t_tphase_centers = all_t_phase_centers[True]
labeled_is = np.arange( t_tphase_centers.size )

In [None]:
fig = plt.figure( figsize=(10, 4.5 ), facecolor='w' )
ax = plt.gca()

# Curve colors: each time-window center is mapped linearly onto [0, 1] of the colormap
z_max = t_tphase_centers.max()
z_min = t_tphase_centers.min()
z_width = z_max - z_min
line_cmap = palettable.scientific.diverging.Roma_3.mpl_colormap

phi_dists = all_dists['Phi']

# One PDF of cos(theta) per time window
for i, phi_dist in enumerate( phi_dists ):

    t_center = t_tphase_centers[i]
    color = line_cmap( ( t_center - z_min ) / z_width )

    if i in labeled_is:
        # Snap near-zero centers to exactly zero so the label shows no round-off. A local
        # value is used so t_tphase_centers (an alias of all_t_phase_centers[True]) stays unmodified.
        if np.isclose( t_center, 0. ):
            t_center = 0.
        label = '{:.3g}'.format( t_center*1e3 ) + r' Myr'
    else:
        label = None

    ax.plot(
        bin_centers,
        phi_dist,
        linewidth = 5,
        color = color,
        label = label,
    )

ax.tick_params(
    axis = 'x',
    top = True,
    labeltop = ax.is_first_row(),
    bottom = ax.is_last_row(),
    labelbottom = ax.is_last_row(),
)

# An isotropic (spherical) distribution has a flat PDF of 1/2 in cos(theta) on [-1, 1]
ax.axhline(
    0.5,
    color = '.2',
    linestyle = '-',
    linewidth = 2,
)
# cos(theta) = 0 reference line
ax.axvline(
    0,
    color = '.2',
    linestyle = '-',
    linewidth = 2,
)

# Annotation text is passed positionally because annotate's ``s`` keyword was removed in
# newer matplotlib versions.

# Sim name label
ax.annotate(
    pm['variation'],
    xy = ( 0, 1 ),
    xycoords = 'axes fraction',
    xytext = ( 20, -20 ),
    textcoords = 'offset points',
    ha = 'left',
    va = 'top',
    fontsize = 26,
)

# Spherical line label
ax.annotate(
    'spherical\ndistribution',
    xy = ( -1, 0.5 ),
    xycoords = 'data',
    xytext = ( 10, 10 ),
    textcoords = 'offset points',
    ha = 'left',
    va = 'bottom',
    fontsize = 22,
)

# Legend title, kept above other artists
t_label = ax.annotate(
    r'$t - t_{T=10^5 {\rm K}}$',
    xy = ( 1, 0.875 ),
    xycoords = 'axes fraction',
    xytext = ( -25, 0 ),
    textcoords = 'offset points',
    ha = 'right',
    va = 'bottom',
    fontsize = 24,
)
t_label.set_zorder( 1000 )
ax.legend(
    prop={'size': 17},
    loc = 'center right',
)

ax.set_xlim( -1, 1 )
ax.set_ylim( 0, 3.75 )

ax.set_xlabel( r'$\cos\ \theta$', fontsize=22 )
if ax.is_first_row():
    ax.xaxis.set_label_position( 'top' )
ax.set_ylabel( r'PDF$\ (\cos\ \theta$)', fontsize=22 )

plotting.save_fig(
    out_dir = os.path.join( pm['figure_dir'], 'ang_dist_evolution' ),
    save_file = 'theta_vs_t_{}.pdf'.format( pm['variation'] ),
    fig = fig,
)

## Intuition-Building Visualizations

In [None]:
# The projected maps use the coarse time windows, matching how binned_props was built for Ry, Rz and T
t_tphase_centers = all_t_phase_centers[False]

In [None]:
# Custom diverging colormap (blue--red) for deviations around zero, built from
# polynomial / rational fits to each RGB channel sampled at 256 points on [0, 1].
# NOTE(review): the "PaulT" name suggests the fits come from Paul Tol's colour-scheme notes — confirm.
def _paul_t_plusmin_rgb( x ):
    """Return the (r, g, b) tuple of the PaulT_plusmin colormap at position x in [0, 1]."""
    red = 0.237 - 2.13*x + 26.92*x**2 - 65.5*x**3 + 63.5*x**4 - 22.36*x**5
    green = ((0.572 + 1.524*x - 1.811*x**2)/(1 - 0.291*x + 0.1574*x**2))**2
    blue = 1/(1.579 - 4.03*x + 12.92*x**2 - 31.4*x**3 + 48.6*x**4 - 23.36*x**5)
    return ( red, green, blue )

cols = [ _paul_t_plusmin_rgb( x ) for x in np.linspace( 0, 1, 256 ) ]
cmap = matplotlib.colors.LinearSegmentedColormap.from_list( "PaulT_plusmin", cols )

In [None]:
# One panel per coarse time window: mean log10 T of particles in (Ry, Rz) pixels
n_cols = t_tphase_centers.size

fig = plt.figure( figsize=(6*n_cols, 8), facecolor='w' )

gs = gridspec.GridSpec(1,n_cols)
gs.update( wspace=0.01 )

# Pixel edges shared by every panel (same units as Ry / Rz)
bins = np.linspace( -30., 30., 128 )

# Pixels without particles are NaN and are drawn light grey
cmap.set_bad(color='0.8')

axs = []
for i in range( n_cols ):

    ax = plt.subplot(gs[0,i])

    # Masked entries -> NaN, which histogram2d drops as out-of-range samples
    # (np.histogram2d itself ignores masks). No-op for plain ndarrays.
    ry = np.ma.filled( binned_props['Ry'][i], np.nan )
    rz = np.ma.filled( binned_props['Rz'][i], np.nan )
    log_T = np.log10( np.ma.filled( binned_props['T'][i], np.nan ) )

    # Mean log T per pixel = (sum of log T) / (particle count); empty pixels stay NaN
    # without emitting 0/0 warnings.
    hist, x_edges, y_edges = np.histogram2d( ry, rz, bins = bins, weights = log_T )
    norm_hist, x_edges, y_edges = np.histogram2d( ry, rz, bins = bins )
    hist = np.divide( hist, norm_hist, out = np.full_like( hist, np.nan ), where = norm_hist > 0 )

    # histogram2d puts Ry on axis 0; rot90 makes Ry run left-to-right and Rz bottom-to-top
    img = ax.imshow(
        np.rot90( hist ),
        cmap = cmap,
        vmin = 3.,
        vmax = 7.,
        extent = [ x_edges[0], x_edges[-1], y_edges[0], y_edges[-1] ],
    )

    if not ax.is_first_col():
        ax.tick_params( left=False, labelleft=False )

    # Time-window label (text passed positionally; annotate's ``s`` keyword was removed in
    # newer matplotlib). The black stroke keeps the white text readable on any color.
    t_label = ax.annotate(
        '{:.3g}'.format( t_tphase_centers[i]*1e3 ) + r' Myr',
        xy = ( 1, 1 ),
        xycoords = 'axes fraction',
        xytext = ( -5, -5 ),
        textcoords = 'offset points',
        ha = 'right',
        va = 'top',
        fontsize = 24,
        color = 'w',
    )
    t_label.set_path_effects([
        path_effects.Stroke(linewidth=4, foreground='black'),
        path_effects.Normal(),
    ])

    ax.set_aspect( 'equal' )
    axs.append( ax )

fig.colorbar( img, ax=axs, shrink=0.6 )

plotting.save_fig(
    out_dir = os.path.join( pm['figure_dir'], 'projected' ),
    save_file = 'projected_temp_{}.pdf'.format( pm['variation'] ),
    fig = fig,
)