In [None]:
import os

import h5py
import matplotlib.colors
import matplotlib.pyplot as plt
import numpy as np
import scipy
import scipy.interpolate
import verdict

In [None]:
import galaxy_dive.analyze_data.particle_data as particle_data
import galaxy_dive.plot_data.generic_plotter as generic_plotter
import galaxy_dive.utils.astro as astro_utils

In [None]:
import galaxy_dive.utils.data_operations as data_operations

# Load Galaxy Data

In [None]:
# Snapshot number to analyze.
snum = 600

# Input locations: halo files and raw simulation output.
halo_data_dir = '/scratch/03057/zhafen/halo_files/multiphysics/m12i_res7100_mhdcv'
sim_data_dir = '/scratch/03057/zhafen/multiphysics/m12i_res7100_mhdcv_old/output'

# Linefinder data directory for the same run.
data_dir = '/scratch/03057/zhafen/linefinder_data/multiphysics/m12i_res7100_mhdcv/data'

In [None]:
# Particle data for snapshot `snum`, centered on halo 0.
# ptype = 4 is presumably the star particles (later cells read 'Age' and
# histogram formation times) -- TODO confirm.
s_data = particle_data.ParticleData(
    snum = snum,
    ptype = 4,
    main_halo_id = 0,
    sdir = sim_data_dir,
    halo_data_dir = halo_data_dir,
)

In [None]:
# Plotting helper wrapped around the loaded particle data.
s_plotter = generic_plotter.GenericPlotter(s_data)

# Rotate

In [None]:
# Particle positions, transposed so that each row is one particle
# (the cells below index the result as pos[:,1], pos[:,2]).
unrotated_pos = s_data.get_data( 'P' ).transpose()

# Rotate using the total angular momentum vector; presumably this puts the
# angular momentum along the last coordinate -- TODO confirm in align_axes.
pos = data_operations.align_axes( unrotated_pos, s_data.total_ang_momentum )

### Make sure it looks okay

In [None]:
%matplotlib inline
fig = plt.figure( figsize=(6,6), facecolor='w' )
ax = plt.gca()

_ = ax.hist2d(
    pos[:,1],
    pos[:,2],
    bins = [ np.linspace( -10, 10, 128 ), ] * 2,
    norm = matplotlib.colors.LogNorm(),
)

ax.set_aspect( 'equal' )


In [None]:
# Boolean mask: particles whose radius 'R' is below averaging_frac * length_scale.
particle_radii = s_data.get_data( 'R' )
galaxy_radius = s_data.averaging_frac * s_data.length_scale
inside_galaxy = particle_radii < galaxy_radius

In [None]:
fig, ax = plt.subplots( figsize=(10,6), facecolor='w' )

# Third rotated coordinate and 'M' values of the particles inside the galaxy.
galaxy_heights = pos[:,2][inside_galaxy]
galaxy_masses = s_data.get_data( 'M' )[inside_galaxy]

# 'M'-weighted distribution of the height coordinate.
hist, bins, _ = ax.hist( galaxy_heights, bins=2048, weights=galaxy_masses )

In [None]:
# Bin centers: left edges shifted by half of the (first) bin width.
half_bin_width = 0.5 * ( bins[1] - bins[0] )
centers = bins[:-1] + half_bin_width

# Keep only the non-negative half of the profile.
is_nonnegative = centers >= 0
positive_centers = centers[is_nonnegative]
positive_hist = hist[is_nonnegative]

In [None]:
# Inverse lookup: histogram value -> height (the histogram is the x argument).
interp_fn = scipy.interpolate.interp1d(
    positive_hist,
    positive_centers,
)

In [None]:
# Height at which the histogram equals its maximum divided by e.
peak_over_e = hist.max() / np.e
scale_height = interp_fn( peak_over_e )

# Bursty Phase Lookback Time

In [None]:
from abg_python.smooth_utils import find_first_window,boxcar_average

In [None]:
def findBurstyRegime(
    time_edges:np.ndarray,
    SFRs:np.ndarray,
    thresh:float=0.3, ## dex of scatter
    window_size:float=0.3, ## size of window to compute scatter within
    mode:str=None,
    thresh_window:float=1.5):
    """ Finds the window where the "bursty condition" stops being true, i.e. where a
        scatter statistic of the star formation history stays below `thresh` for a
        window of width `thresh_window`. The statistic can be defined 3 ways (using the
        mode parameter). By default it is the scatter in log SFR, compared against the
        ~scatter in the SFMS (0.3 dex). A history with larger scatter one would not call
        "constant" w.r.t. the SFMS, one with smaller scatter would occupy a spot on the
        SFMS (not directly on it, it just wouldn't jump around).
        1. Sigma_300(log(<SFR>_1)) < thresh (default, used for any other value of mode;
            standard deviation of log10(SFR) in a boxcar of width window_size,
            thresh=0.3 consistent with defn in
            https://ui.adsabs.harvard.edu/abs/2022arXiv220304321G/abstract
            https://ui.adsabs.harvard.edu/abs/2021MNRAS.501.4812F/abstract
            https://ui.adsabs.harvard.edu/abs/2021ApJ...911...88S/abstract
            https://ui.adsabs.harvard.edu/abs/2020MNRAS.498.3664G/abstract )
        2. Sigma(<SFR>_10)/<SFR>_10 < thresh (mode == 'anna',
            thresh=0.5 consistent with defn in
            https://ui.adsabs.harvard.edu/abs/2021MNRAS.505..889Y/abstract ).
            The SFR is first boxcar averaged with width 0.01 and the scatter is then
            measured in a boxcar of hard-coded width 0.5; window_size is ignored.
        3. peak(SFR/median) - trough(SFR/median) (mode == 'peaktrough',
            used to check "visual intuition" but works badly. Do not use this.)
            Difference of the 0.9 and 0.1 quantiles of SFR/median in a running window
            of hard-coded width 0.05, boxcar averaged with width 0.3;
            window_size is ignored.
    Parameters
    ----------
    time_edges : np.ndarray
        SFR histogram edges in Gyr, ideally spaced by 1 Myr (we boxcar average anyway)
    SFRs : np.ndarray
        SFR histogram in msun/year (or any other units). Must contain at least one
        positive value; a tenth of the smallest positive value is added to every bin
        so that the logarithm is finite in empty bins.
    thresh : float, optional
        threshold value that the scatter should be below, by default 0.3
    window_size: float, optional
        the window that the scatter in log SFR should be computed in,
        by default 0.3 (300 Myr). Only used by the default mode.
    mode : str, optional
        None (default), 'anna' or 'peaktrough'; selects the statistic as listed above.
    thresh_window : float, optional
        width of window that the scatter must remain below threshold for
        (to avoid little excursions below counting as the "end" of bursty SFR),
        by default 1.5
    Returns
    -------
    l_window
        the time corresponding to the left edge of the thresh_window
        that satisfies the threshold condition, as returned by find_first_window.
    rel_scatters
        the scatter statistic that was compared against thresh
    Raises
    ------
    ValueError
        if SFRs has no positive entry, or (mode == 'peaktrough') if a running
        window contains no finite value.
    """

    SFRs = np.asarray(SFRs, dtype=float)
    positive_sfrs = SFRs[SFRs>0]
    if positive_sfrs.size == 0:
        raise ValueError("SFRs must contain at least one positive value.")

    ## floor the SFR so that log10 below is finite where SFR == 0
    adjusted_sfrs = (SFRs + positive_sfrs.min()/10)

    if mode == 'peaktrough':

        this_window_size = 0.05 #window_size
        ## bin spacing taken from the edges themselves, so that the function
        ##  does not depend on a global variable being defined by the caller
        sfh_dt = np.median(np.diff(time_edges))
        ## half-width of the running window, in bins
        window_size_n = int(this_window_size/sfh_dt/2)

        rel_scatters = np.zeros(adjusted_sfrs.size)

        for i in range(adjusted_sfrs.size):
            ## slice ends are exclusive: go one past the right-hand neighbour so
            ##  the window is centered on bin i and is never empty
            window = adjusted_sfrs[
                max(0,i-window_size_n):
                min(adjusted_sfrs.size,i+window_size_n+1)]

            median = np.nanmedian(window)
            if np.isnan(median):
                raise ValueError(
                    "No finite SFR in the running window around bin %d."%i)

            per_l,per_r = np.quantile(
                window/median,
                [0.1,0.9])

            ## peak - trough of the median-normalized SFR
            rel_scatters[i] = (per_r - per_l)

        _,rel_scatters = boxcar_average(
            time_edges,
            rel_scatters,
            0.3,
            assign='center')

    elif mode == 'anna':
        ## calculate scatter using 10 Myr running average in
        ##  a 0.5 sized window
        _,adjusted_sfrs = boxcar_average(
            time_edges,
            adjusted_sfrs,
            0.01)

        _,boxcar_ys_300 = boxcar_average(
            time_edges,
            adjusted_sfrs,
            0.5,#window_size,
            assign='center')

        _,boxcar_ys2_300 = boxcar_average(
            time_edges,
            adjusted_sfrs**2,
            0.5,#window_size,
            assign='center')

        ## <std>/<SFR>; round-off can make <x^2> - <x>^2 slightly negative, which
        ##  would turn a vanishing scatter into NaN, so clip at zero first
        variance = boxcar_ys2_300 - boxcar_ys_300**2
        rel_scatters = np.sqrt(np.clip(variance,0,None))/boxcar_ys_300
    else:
        log_sfrs = np.log10(adjusted_sfrs)

        _,boxcar_ys_300 = boxcar_average(
            time_edges,
            log_sfrs,
            window_size,
            assign='center')

        _,boxcar_ys2_300 = boxcar_average(
            time_edges,
            log_sfrs**2,
            window_size,
            assign='center')

        ## standard deviation of log SFR, clipped as above to avoid NaN
        variance = boxcar_ys2_300 - boxcar_ys_300**2
        rel_scatters = np.sqrt(np.clip(variance,0,None))

    ## find the thresh_window wide window that is consistently below the threshold
    ## NOTE(review): last=True is forwarded to find_first_window; presumably it
    ##  selects the last such window rather than the first -- confirm in abg_python
    l_window, r_window = find_first_window(
        time_edges,
        rel_scatters,
        lambda x,y: y < thresh,
        thresh_window,
        last=True)

    return l_window, rel_scatters

In [None]:
# z = 1/a - 1, which treats 'Age' as the scale factor at formation.
# NOTE(review): confirm that get_data( 'Age' ) returns a scale factor.
formation_scale_factor = s_data.get_data( 'Age' )
formation_redshift = 1. / formation_scale_factor - 1.

In [None]:
# Age of the universe at each particle's formation redshift, using the
# cosmological parameters stored with the snapshot.
formation_time = astro_utils.age_of_universe(
    formation_redshift,
    omega_matter = s_data.data_attrs['omega_matter'],
    h = s_data.data_attrs['hubble'],
)

In [None]:
# Age of the universe at redshift zero for the same cosmological parameters.
age_at_z0 = astro_utils.age_of_universe(
    0.,
    omega_matter = s_data.data_attrs['omega_matter'],
    h = s_data.data_attrs['hubble'],
)

In [None]:
# Formation times of the particles selected by the inside_galaxy mask.
main_galaxy_formation_time = formation_time[inside_galaxy]

# Time elapsed between formation and redshift zero.
main_galaxy_ages = age_at_z0 - main_galaxy_formation_time

In [None]:
fig, ax = plt.subplots()

# Unweighted counts of particles per formation-time bin (no weights are passed),
# used below as the star formation history.
SFR, time_edges, img = ax.hist( main_galaxy_formation_time, bins=256 )

ax.set_yscale( 'log' )

In [None]:
# Scatter-based bursty-phase time, with findBurstyRegime's default settings.
# NOTE(review): findBurstyRegime is given bin edges of formation time, so the
# value it returns is presumably on that same axis rather than a lookback time.
# Confirm which of the two names below is the lookback time, and that the
# vertical line is drawn at the right position on this formation-time axis.
t_lookback_bursty, rel_scatters = findBurstyRegime( time_edges, SFR )
t_bursty = age_at_z0 - t_lookback_bursty

# Mark the time on the histogram figure from the previous cell and redisplay it.
ax.axvline( t_bursty, color='k' )
ax.set_ylim( 1e3, 3e5 )

fig

In [None]:
# Display both values for inspection.
t_bursty, t_lookback_bursty

In [None]:
fig, ax = plt.subplots()

# Scatter statistic returned by findBurstyRegime, against the left bin edges.
ax.plot( time_edges[:-1], rel_scatters )

# Vertical line at t_bursty; horizontal line at 0.3, the default `thresh`
# of findBurstyRegime that was used above.
ax.axvline( t_bursty, color='k' )
ax.axhline( 0.3, color='k' )

# Save

In [None]:
# Quantities to store, keyed by name (units are given in the key).
galaxy_stats = {
    'scale_height (kpc)': scale_height,
    'galaxy_angular_momentum (Msun*kpc*km/s)': s_data.total_ang_momentum,
}
data = verdict.Dict( galaxy_stats )

In [None]:
# Write the galaxy statistics to disk, creating the output directory first
# so that the save does not depend on ./data already existing.
galaxy_stats_path = './data/galaxy_stats.h5'
os.makedirs( os.path.dirname( galaxy_stats_path ), exist_ok=True )
data.to_hdf5( galaxy_stats_path )