In [1]:
# import sys
# import gizmo_analysis as gizmo
# import utilities as ut
# import numpy as np
# import matplotlib.pyplot as plt

In [2]:
from scipy.stats import binned_statistic

def bin_helper(x, y, numbins=20, stat='mean'):
    """Bin ``y`` by ``x``; return a per-bin statistic, its spread, and left bin edges.

    Parameters
    ----------
    x, y : array_like
        Sample coordinates and the values to summarise. Both are passed
        straight to ``scipy.stats.binned_statistic``.
    numbins : int or sequence, default 20
        Forwarded as the ``bins`` argument of ``binned_statistic``.
    stat : str or callable, default 'mean'
        Statistic computed in each bin. Any value ``binned_statistic`` accepts
        for ``statistic`` works here.

    Returns
    -------
    tuple
        ``(values, spread, left_edges)``:

        - ``values`` is the requested statistic per bin.
        - ``spread`` is the per-bin standard deviation of ``y``.
        - ``left_edges`` are the bin edges with the final right edge dropped.
          These are the left edges, not the bin centres.
    """
    def _per_bin(statistic):
        # binned_statistic returns a named result; only the per-bin values
        # and the edges are needed here.
        result = binned_statistic(x, y, statistic=statistic, bins=numbins)
        return result.statistic, result.bin_edges

    values, _ = _per_bin(stat)
    spread, bin_edges = _per_bin('std')
    return values, spread, bin_edges[:-1]

# ut.bin.get_statistics_of_array()

In [3]:
def plot_helper(x, y, label, c='r', stat='mean', ax=None, scatter=False, grid=True, numbins=14):
    """Plot a binned statistic of ``y`` versus ``x`` with a +/-1 std band.

    Parameters
    ----------
    x, y : array_like
        Data to bin; passed to ``bin_helper``.
    label : str
        Legend label for the binned curve.
    c : color, default 'r'
        Colour of the curve and of the shaded band.
    stat : str or callable, default 'mean'
        Per-bin statistic forwarded to ``bin_helper``.
    ax : matplotlib Axes, optional
        Target axes; defaults to ``plt.gca()``.
    scatter : bool, default False
        Also draw the raw points (small, faint markers) underneath the curve.
    grid : bool, default True
        Turn the axes grid on.
    numbins : int or sequence, default 14
        Bin specification forwarded to ``bin_helper``. The previous
        hard-coded value was 14.

    Returns
    -------
    ax
        The axes drawn on.

    Notes
    -----
    Points are placed at the left bin edges returned by ``bin_helper``,
    not at the bin centres.
    """
    if ax is None:
        ax = plt.gca()
    if scatter:
        ax.scatter(x, y, s=1, alpha=0.1)
    stat_y, std, bins = bin_helper(x, y, numbins, stat=stat)

    ax.plot(bins, stat_y, 'o-', c=c, label=label)
    ax.fill_between(bins, stat_y - std, stat_y + std, color=c, alpha=0.5)
    ax.legend(loc='upper right')
    if grid:
        # Pass True explicitly: a bare ax.grid() *toggles* visibility, so
        # overlaying a second curve on the same axes would hide the grid again.
        ax.grid(True)
    return ax

In [4]:
def ratio_checker(v_circ_form, v_circ_diff, inds, cutoff=0.2):
    """Ratio of "cold-torqued" to heated stars among stars born on near-circular orbits.

    The function reads several names that are not defined in this cell:

    - ``ps`` and ``vs``: current particle positions and velocities.
    - ``ps_form`` and ``vs_form``: the same quantities at formation.
    - ``ut``: the project utilities module.

    NOTE(review): the column layout is assumed to be cylindrical, (R, phi, z)
    for positions and (v_R, v_phi, v_z) for velocities. Positions are assumed
    to be in kpc and velocities in km/s. Confirm against the cell that builds
    these arrays.

    Parameters
    ----------
    v_circ_form : float
        Upper limit on sqrt(v_R^2 + v_z^2) at formation for a star to count
        as born on a circular orbit.
    v_circ_diff : float
        Threshold on the change since formation, sqrt(dv_R^2 + dv_z^2).
        Below it the orbit "stays circular". At or above it the orbit counts
        as non-circular (heated).
    inds : array of int
        Indices into the global particle arrays to consider.
    cutoff : float, default 0.2
        Minimum |fractional change| in angular momentum for a star that
        stayed circular to count as torqued.

    Returns
    -------
    ratio : float
        len(cold-torqued) / len(non-circular). If no star in the sample is
        non-circular, this is ``np.nan``; that case previously raised
        ZeroDivisionError.
    j_diff_cold : ndarray
        Fractional angular-momentum change (j - j_form) / j_form for every
        star that stayed circular.
    """
    # Fancy-index each global array once rather than on every column access.
    p_now, v_now = ps[inds], vs[inds]
    p_form, v_form = ps_form[inds], vs_form[inds]

    # Stars born close to the mid-plane: |z| < 0.3 (300 pc per the original note).
    form_inds = ut.array.get_indices(p_form[:, 2], [-0.3, 0.3])
    # ...that were also on circular orbits at formation. The third argument
    # presumably restricts the search to the previous selection.
    v_tot_form = np.sqrt(np.square(v_form[:, 0]) + np.square(v_form[:, 2]))
    form_inds = ut.array.get_indices(v_tot_form, [0, v_circ_form], form_inds)

    vz_diff = v_now[:, 2] - v_form[:, 2]
    vR_diff = v_now[:, 0] - v_form[:, 0]

    # j = column 0 of positions times column 1 of velocities (R * v_phi
    # under the layout assumed above), now and at formation.
    j_now = p_now[:, 0] * v_now[:, 1]
    j_form = p_form[:, 0] * v_form[:, 1]
    j_diff = (j_now - j_form) / j_form
    abs_j_diff = np.abs(j_diff)

    # Quadrature sum of the change in the non-rotational velocity components.
    v_tot_change = np.sqrt(np.square(vz_diff) + np.square(vR_diff))

    # Use an open upper limit instead of the array maximum. If get_indices
    # treats the upper limit as exclusive, max(...) silently drops the most
    # extreme particle. np.inf selects the same set or a superset either way.
    non_circular_inds = ut.array.get_indices(v_tot_change, [v_circ_diff, np.inf], form_inds)
    cold_T_inds = ut.array.get_indices(v_tot_change, [0, v_circ_diff], form_inds)
    cold_T_cutoff_inds = ut.array.get_indices(abs_j_diff, [cutoff, np.inf], cold_T_inds)

    if len(non_circular_inds) == 0:
        # Undefined ratio; report NaN instead of crashing a loop over samples.
        return np.nan, j_diff[cold_T_inds]
    return len(cold_T_cutoff_inds) / len(non_circular_inds), j_diff[cold_T_inds]


In [5]:
def plot_hist_stats(data, ax=None, numstd=1):
    """Overlay the mean and +/- k-sigma lines of ``data`` on an axes.

    Parameters
    ----------
    data : array_like
        Values whose mean and population standard deviation (``np.std``)
        are marked.
    ax : matplotlib Axes, optional
        Target axes; defaults to ``plt.gca()``.
    numstd : int, default 1
        Number of sigma levels to draw, at mean +/- k*std for k = 1..numstd.
        Only the lower line of each level carries a legend label.

    Returns
    -------
    ax
        The axes drawn on, after ``ax.legend()`` has been called.
    """
    if ax is None:
        ax = plt.gca()
    mean = np.mean(data)
    std = np.std(data)
    colors = ['r', 'orange']
    ax.axvline(mean, c=colors[0], label=f'mean = {mean:.3f}')
    for k in range(1, numstd + 1):
        # Cycle the palette so numstd > len(colors) no longer raises IndexError.
        color = colors[(k - 1) % len(colors)]
        offset = std * k
        # Raw f-string: '\s' in a plain literal is an invalid escape sequence
        # (DeprecationWarning, SyntaxWarning from 3.12). The rendered text is unchanged.
        ax.axvline(mean - offset, c=color, linestyle='dashed', label=rf'{k} $\sigma$ = {offset:.3f}')
        ax.axvline(mean + offset, linestyle='dashed', c=color)
    ax.legend()
    return ax