In [1]:
from loadmodules import *
import numpy as np
import scipy as sp

import joblib
from joblib import Parallel, delayed

In [None]:
def expand_indices(counts):
    """Return, for every group, the running index 0..count-1 of its members.

    ``counts`` is a 1-D integer NumPy array holding the size of each group.
    The result is the concatenation ``[0..counts[0]-1, 0..counts[1]-1, ...]``
    as ``int64``; groups of size zero contribute nothing.

    Example: ``counts = [2, 0, 3]`` -> ``[0, 1, 0, 1, 2]``.
    """
    # Position in the flattened output at which each group begins.
    group_starts = np.cumsum(counts) - counts
    # Global position of every output element ...
    flat_positions = np.arange(counts.sum())
    # ... minus the start of the group it belongs to gives the local index.
    start_per_element = np.repeat(group_starts, counts)
    return (flat_positions - start_per_element).astype(np.int64)

def mask_equal_to_previous(arr):
    """Return a boolean mask that drops elements equal to their predecessor.

    The mask has the same length as ``arr``. Entry ``k`` is ``True`` when
    ``arr[k]`` differs from ``arr[k-1]``; the first entry is always ``True``.
    Applying the mask to a sorted array therefore leaves strictly changing
    values (consecutive duplicates are removed). Despite the name, ``True``
    marks the elements to *keep*, not the repeated ones.
    """
    differs = np.empty(len(arr), dtype=bool)
    differs[:1] = True  # the first element has no predecessor
    differs[1:] = arr[1:] != arr[:-1]
    return differs

def B(x):
    """Velocity-distribution factor of the Chandrasekhar dynamical-friction formula.

    B(x) = erf(x) - 2 x exp(-x**2) / sqrt(pi)

    ``x`` may be a scalar or a NumPy array; the result has the same shape.
    B(0) = 0 and B(x) -> 1 for large ``x``; B(1) ~ 0.4276.

    The previous implementation multiplied the two terms instead of
    subtracting them, which tends to 0 (not 1) for large ``x``.
    """
    return sp.special.erf(x) - 2. * x * np.exp(-x**2) / np.sqrt(np.pi)

In [3]:
# Snapshot layout of the run: snapshots are numbered 0 .. num_snaps - 1 and
# each snapshot is split into `files_per_snap` HDF5 file parts.
files_per_snap = 8
num_snaps = 128

# Output directory of the simulation; it is passed to the snapshot/subfind
# loaders and used as prefix for the 'snapdir_XXX/snapshot_XXX.K.hdf5' files.
path = './compare_comp_time/Au6_lvl5_31a2_eqrhnoevo/output/'

In [None]:
# Write the header of the log file that collects the disrupted clusters
# (opened with 'w', so a log left by a previous run is overwritten).
with open('./output/disrupted_clusters_DF.txt', 'w') as log_file:
    log_file.write("Snapshot, Time, PID, SCID, Initial Mass, Current Mass, FormationTime \n")


def velocity_dispersion(radius, parts_radius, starparts, s_data_age, s_data_vel, s_data_type):
    """Return the standard deviation of the particle speeds inside ``radius``.

    Star particles with ``age > 0`` are used when at least 48 of them lie
    inside ``radius``; otherwise every particle whose type is neither 4 nor 0
    is used instead. Returns 0.0 when no particle is selected.

    Parameters
    ----------
    radius : float
        Radius of the enclosing sphere.
    parts_radius : array
        Radius of every particle.
    starparts : boolean array
        Selects the type-4 particles out of ``parts_radius`` / ``s_data_vel``.
    s_data_age : array
        Age field; indexed against the ``starparts`` selection.
    s_data_vel : (N, 3) array
        Particle velocities.
    s_data_type : array
        Particle type of every particle.
    """
    within_radius = parts_radius[starparts][s_data_age > 0.] < radius
    if within_radius.sum() >= 48:
        velocities = np.sqrt(np.sum(s_data_vel[starparts][s_data_age > 0.][within_radius]**2, axis=1))
    else:
        mask_dm = (s_data_type != 4) * (s_data_type != 0)
        within_radius = parts_radius[mask_dm] < radius
        velocities = np.sqrt(np.sum(s_data_vel[mask_dm][within_radius]**2, axis=1))

    return np.std(velocities) if velocities.size else 0.0


# NOTE(review): the loop starts at snapshot 127, so with num_snaps = 128 only
# the last snapshot is processed -- confirm this is intended.
for i in range(127, num_snaps):
    print('Loading snapshot', i)
    sf = load_subfind(i, dir=path, hdf5=True, loadonly=['fpos', 'frc2', 'svel', 'flty', 'fnsh', 'slty', 'spos', 'smty', 'ffsh'] )
    s = gadget_readsnap(i, snappath=path, subfind=sf, hdf5=True, loadonlyhalo=0)
    print('Redshift:', s.redshift, ' cosmo time:', s.time)

    # Nothing to do without star particles or without stars hosting clusters.
    if (s.data['type']==4).sum() == 0:
        print('No stars in this snapshot')
        continue
    if (s.data['incl']>0).sum() == 0:
        print('NO STARS WITH GC IN MAIN HALO IN THIS SNAPSHOT')
        continue

    print('Found stars with GC in this snapshot')
    s.calc_sf_indizes( sf )
    s.select_halo( sf, use_principal_axis=True, use_cold_gas_spin=False, do_rotation=True, verbose=False )

    Gcosmo = 43.  # gravitational constant -- presumably in the snapshot's internal units; TODO confirm
    starparts = s.data['type']==4
    # Particle radii, computed once and reused below (the original called
    # s.r() repeatedly; assumes s.r() only depends on the loaded positions).
    radius = s.r()

    kinetic_energy = np.sum(s.data['vel']**2, axis=1)
    orbital_energy = s.data['pot'] + 0.5 * kinetic_energy
    Jtot = np.sqrt((np.cross( s.data['pos'], (s.data['vel'] ))**2).sum(axis=1))

    # Enclosed mass and circular velocity at the radius of every particle.
    isort_parts = np.argsort(radius)
    revert_sort = np.argsort(isort_parts)
    cummass = np.cumsum(s.data['mass'][isort_parts])
    Vc_parts = np.sqrt(Gcosmo*cummass[revert_sort]/radius)

    # Energy of circular orbits at increasing radii
    Ecirc = 0.5*Vc_parts[isort_parts]**2 + s.data['pot'][isort_parts]
    finite = ~np.isinf(Ecirc)
    e_max = np.nanmax(Ecirc[finite])
    # Shift both energies by the same constant so that max(Ecirc) is zero.
    orbital_energy -= e_max
    Ecirc -= e_max

    # PchipInterpolator needs strictly increasing abscissae: drop repeated radii.
    r_sorted_finite = radius[isort_parts][finite]
    mask = mask_equal_to_previous(r_sorted_finite)
    r_nodes = r_sorted_finite[mask]

    r_test = np.logspace(-5., np.log10(radius.max()), 500)
    Ecirc_f = sp.interpolate.PchipInterpolator(r_nodes, Ecirc[finite][mask])
    Vc_f = sp.interpolate.PchipInterpolator(r_nodes, Vc_parts[isort_parts][finite][mask])
    Mr_f = sp.interpolate.PchipInterpolator(r_nodes, cummass[finite][mask])

    # Star particles that were assigned at least one cluster.
    mask_clusters_initial = (s.data['incl'] > 0)
    incl_hosts = s.data['incl'][mask_clusters_initial]
    nclt_hosts = s.data['nclt'][mask_clusters_initial]

    # Radius of the circular orbit whose energy is closest to the energy of
    # each host star, and the corresponding circular angular momentum.
    i_rc = np.argmin(np.abs(orbital_energy[starparts][mask_clusters_initial,np.newaxis] - Ecirc_f(r_test)), axis=1)
    rc = r_test[i_rc]
    vc = Vc_f(rc)
    Lzmax = rc*vc

    # Per-cluster quantities: flatten the per-star cluster slots and keep
    # only the slots with a non-zero initial mass.
    cluster_masses = s.data['mclt'][mask_clusters_initial].flatten()
    init_cluster_masses = s.data['imcl'][mask_clusters_initial].flatten()
    cluster_mlost_sh = s.data['mlsk'][mask_clusters_initial].flatten()
    cluster_mlost_rx = s.data['mlrx'][mask_clusters_initial].flatten()
    not_empty_clusters = (init_cluster_masses > 0.)
    cluster_masses = cluster_masses[not_empty_clusters]
    cluster_mlost_sh = cluster_mlost_sh[not_empty_clusters]
    cluster_mlost_rx = cluster_mlost_rx[not_empty_clusters]
    init_cluster_masses = init_cluster_masses[not_empty_clusters]

    # Host particle ID and slot index (0 .. incl-1) of every cluster.
    part_id = np.repeat(s.data['id'][starparts], s.data['incl'])
    scs_id = expand_indices(incl_hosts)

    clusters_formtime = np.repeat(s.data['age'], s.data['incl'])
    clusters_age = s.cosmology_get_lookback_time_from_a(clusters_formtime, is_flat=True) - s.cosmology_get_lookback_time_from_a(s.time, is_flat=True)

    # Do the DF timescale estimate for clusters with mass
    mask_mass = (cluster_masses > 0.)

    rc_clus = np.repeat(rc, incl_hosts)
    M_rc_clus = np.repeat(Mr_f(rc), incl_hosts)
    vc_rc_clus = np.repeat(vc, incl_hosts)
    sigma_rc_clus = np.zeros_like(rc_clus)

    # The large arrays are handed to the joblib workers as read-only memory
    # maps of temporary files in the working directory. The files are removed
    # in a `finally` so that a failure in the parallel section does not leave
    # them behind.
    shared_arrays = {
        'parts_radius.npy': radius,
        'starparts.npy': starparts,
        's_data_age.npy': s.data['age'],
        's_data_vel.npy': s.data['vel'],  # assuming 3D velocity
        's_data_type.npy': s.data['type'],
    }
    try:
        for tmp_name, tmp_array in shared_arrays.items():
            joblib.dump(tmp_array, tmp_name)

        # Separate names for the memory maps: the in-memory `starparts` is
        # still needed after the temporary files are gone.
        parts_radius_mm = joblib.load('parts_radius.npy', mmap_mode='r')
        s_data_age_mm = joblib.load('s_data_age.npy', mmap_mode='r')
        s_data_vel_mm = joblib.load('s_data_vel.npy', mmap_mode='r')
        s_data_type_mm = joblib.load('s_data_type.npy', mmap_mode='r')
        starparts_mm = joblib.load('starparts.npy', mmap_mode='r')

        # One dispersion per host star that still has clusters (nclt > 0).
        sigma_clus = Parallel(n_jobs=-1)(delayed(velocity_dispersion)(r, parts_radius_mm, starparts_mm, s_data_age_mm, s_data_vel_mm, s_data_type_mm)
                                          for r in rc[nclt_hosts>0])
    finally:
        for tmp_name in shared_arrays:
            if os.path.exists(tmp_name):
                os.remove(tmp_name)

    # NOTE(review): assumes the number of clusters with mass > 0 equals
    # sum(nclt) over the host stars -- otherwise this assignment fails.
    sigma_rc_clus[mask_mass] = np.repeat(sigma_clus, nclt_hosts[nclt_hosts>0])

    # Orbital-circularity factor and Coulomb logarithm.
    feps = np.repeat((Jtot[starparts][mask_clusters_initial]/Lzmax)**0.78, incl_hosts)
    coulumblog = np.zeros_like(rc_clus)
    coulumblog[mask_mass] = np.log(1. + M_rc_clus[mask_mass]/cluster_masses[mask_mass])

    # Dynamical-friction timescale; clusters without mass keep the large
    # default value so that they are never flagged.
    tdf = 2e4 * np.ones_like(rc_clus)
    tdf[mask_mass] = feps[mask_mass]/(2*B(vc_rc_clus[mask_mass]/(np.sqrt(2.)*sigma_rc_clus[mask_mass])))*np.sqrt(2.)*sigma_rc_clus[mask_mass]* \
                    rc_clus[mask_mass]**2./(Gcosmo*cluster_masses[mask_mass]*coulumblog[mask_mass])
    # Internal time unit (length / velocity, in seconds) -> Gyr.
    tdf *= s.UnitLength_in_cm/s.UnitVelocity_in_cm_per_s / (1e9*365.25*24*3600)

    mask_disrupted = (tdf < clusters_age)

    # Write the information of disrupted clusters to a log file
    with open('./output/disrupted_clusters_DF.txt', 'a') as log_file:
        for c in np.where(mask_disrupted)[0]:
            log_file.write(f"{i}, {s.time}, {part_id[c]}, {scs_id[c]}, {init_cluster_masses[c]}, {cluster_masses[c]}, {clusters_formtime[c]}\n")

    n_disrupted = mask_disrupted.sum()
    if n_disrupted == 0:
        print('No clusters disrupted by dynamical friction in this snapshot')
        continue

    # Now go through all subsequent snapshots and remove the disrupted clusters
    print('Clusters disrupted by dynamical friction {:d}'.format(n_disrupted))
    dis_pid = part_id[mask_disrupted]
    dis_scid = scs_id[mask_disrupted]
    dis_mlost_sh = cluster_mlost_sh[mask_disrupted]
    dis_mlost_rx = cluster_mlost_rx[mask_disrupted]

    for j in range(i, num_snaps):
        found = 0
        # Visit the file parts until every disrupted cluster has been located.
        # The loop is bounded by files_per_snap: the original kept increasing
        # the part index and failed on a non-existent file when an ID was
        # missing from the snapshot.
        for k in range(files_per_snap):
            if found >= n_disrupted:
                break
            snap_file = path + 'snapdir_{:03d}/snapshot_{:03d}.{:d}.hdf5'.format(j,j,k)
            # Context manager: the file is closed even if an error occurs.
            with h5py.File(snap_file, 'r+') as h5_file:
                if 'PartType4' not in h5_file:
                    continue  # this file part holds no star particles
                stars = h5_file['PartType4']
                ids = stars['ParticleIDs'][:]
                in_this_file = np.isin(dis_pid, ids)
                n_here = in_this_file.sum()
                found += n_here
                if n_here == 0:
                    continue  # nothing to modify: do not rewrite the datasets

                print('Found {:d} clusters in snapshot {:d} part {:d}'.format(n_here, j, k))
                clus_mass = stars['ClusterMass'][:]
                clus_radius = stars['ClusterRadius'][:]
                disruption_time = stars['DisruptionTime'][:]
                mlost_shocks = stars['MassLostShocks'][:]
                mlost_relax = stars['MassLostRelaxation'][:]
                nclus = stars['NumberOfClusters'][:]

                for pid, scid, ml_sh, ml_rx in zip(dis_pid[in_this_file], dis_scid[in_this_file],
                                                   dis_mlost_sh[in_this_file], dis_mlost_rx[in_this_file]):
                    mask_id = np.isin(ids, pid)
                    # Empty the cluster slot and freeze its mass-loss
                    # bookkeeping at the values of the disruption snapshot.
                    clus_mass[mask_id, scid] = 0.0
                    clus_radius[mask_id, scid] = 0.0
                    mlost_shocks[mask_id, scid] = ml_sh
                    mlost_relax[mask_id, scid] = ml_rx
                    nclus[mask_id] = (clus_mass[mask_id] > 0.).sum()
                    disruption_time[mask_id, scid] = s.time

                stars['ClusterMass'][:] = clus_mass
                stars['ClusterRadius'][:] = clus_radius
                stars['DisruptionTime'][:] = disruption_time
                stars['MassLostShocks'][:] = mlost_shocks
                stars['MassLostRelaxation'][:] = mlost_relax
                stars['NumberOfClusters'][:] = nclus

        if found < n_disrupted:
            print('WARNING: only {:d} of {:d} disrupted clusters located in snapshot {:d}'.format(found, n_disrupted, j))

# Identify the star clusters (SCs) whose stored information needs to be corrected

In [2]:
# Rows copied from ./output/disrupted_clusters_DF.txt, one row per disrupted
# star cluster. Column index -> meaning (same order as the log-file header):
#   0: Snapshot       snapshot number in which the cluster was flagged
#   1: Time           value of s.time at that snapshot
#   2: PID            ID of the star particle hosting the cluster
#   3: SCID           index of the cluster slot within that star particle
#   4: Initial Mass   initial cluster mass
#   5: Current Mass   cluster mass at the flagged snapshot
#   6: FormationTime  formation time of the host star particle
disrupted_scs = [
[62, 0.344341999449342, 8796098036440, 0, 6.574573490070179e-05, 3.3196440199390054e-05, 0.31109389662742615],
[62, 0.344341999449342, 8796098085888, 0, 5.590692671830766e-05, 3.454939724178985e-05, 0.30888494849205017],
[62, 0.344341999449342, 8796098581042, 0, 0.00021077124984003603, 0.00013148513971827924, 0.3019043505191803],
[62, 0.344341999449342, 8796098774989, 0, 1.1844956134154927e-05, 5.524972038983833e-06, 0.3049369752407074],
[62, 0.344341999449342, 8796090239295, 0, 0.0007212080527096987, 0.00043532243580557406, 0.23272006213665009]
]

In [3]:
# Snapshot layout: snapshots are numbered 0 .. num_snaps - 1 and each one is
# split into `files_per_snap` HDF5 file parts.
files_per_snap = 8
num_snaps = 128

# Output directory of the run whose snapshots are corrected; used as prefix
# for the 'snapdir_XXX/snapshot_XXX.K.hdf5' file names.
simulation = 'Au6_lvl4_cfea_sh50myr10/output/'

In [4]:
# For every cluster listed in `disrupted_scs`:
#   1. read its reference mass-loss values from the snapshot in which it was
#      disrupted;
#   2. walk backwards from the last snapshot and overwrite the stored cluster
#      properties wherever they disagree with the reference, stopping at the
#      first snapshot that is already consistent.
#
# Fixes with respect to the previous version:
#   * assigning to the loop variable (`part = files_per_snap`) does not leave
#     a Python `for` loop; real `break` statements are used instead;
#   * the backward walk stops at the disruption snapshot, so snapshots in
#     which the cluster was still alive are never modified;
#   * the walk stops (instead of looping forever) when the host particle is
#     not present in a snapshot;
#   * HDF5 files are opened through context managers so they are closed even
#     when an error occurs.
for sc in range(len(disrupted_scs)):
    snap_disrupted = int(disrupted_scs[sc][0])
    time_disrupted = disrupted_scs[sc][1]
    pid = disrupted_scs[sc][2]
    cl_idx = int(disrupted_scs[sc][3])

    # --- 1. reference values from the disruption snapshot -----------------
    ref_mlost_shocks = None
    ref_mlost_relax = None
    for j in range(files_per_snap):
        with h5py.File(simulation + 'snapdir_{:03d}/snapshot_{:03d}.{:d}.hdf5'.format(snap_disrupted,snap_disrupted,j), 'r') as h5_file:
            stars = h5_file['PartType4']
            ids = stars['ParticleIDs'][:]
            inverse_mask = np.isin(ids, pid)
            if inverse_mask.sum() == 0:
                continue  # host particle is not in this file part

            clus_mass = stars['ClusterMass'][:]
            clus_radius = stars['ClusterRadius'][:]
            disruption_time = stars['DisruptionTime'][:]
            mlost_shocks = stars['MassLostShocks'][:]
            mlost_relax = stars['MassLostRelaxation'][:]
            nclus = stars['NumberOfClusters'][:]
            print('Found cluster {:d} in snapshot {:d} part {:d}'.format(sc+1, snap_disrupted, j))
            print(clus_mass[inverse_mask, cl_idx],
                  clus_radius[inverse_mask, cl_idx],
                  mlost_shocks[inverse_mask, cl_idx],
                  mlost_relax[inverse_mask, cl_idx],
                  nclus[inverse_mask],
                  disruption_time[inverse_mask, cl_idx])
            ref_mlost_shocks = mlost_shocks[inverse_mask, cl_idx]
            ref_mlost_relax = mlost_relax[inverse_mask, cl_idx]
        break  # found: the remaining file parts cannot contain the same ID

    if ref_mlost_shocks is None:
        print('Cluster {:d} not found in snapshot {:d}; skipping'.format(sc+1, snap_disrupted))
        continue

    # --- 2. walk backwards from the last snapshot -------------------------
    snap = num_snaps - 1
    need_to_correct = True
    while need_to_correct and snap >= snap_disrupted:
        for part in range(files_per_snap):
            with h5py.File(simulation + 'snapdir_{:03d}/snapshot_{:03d}.{:d}.hdf5'.format(snap,snap,part), 'r+') as h5_file_next:
                stars_next = h5_file_next['PartType4']
                ids_next = stars_next['ParticleIDs'][:]
                inverse_mask_next = np.isin(ids_next, pid)
                if inverse_mask_next.sum() == 0:
                    continue  # host particle is not in this file part

                clus_mass_next = stars_next['ClusterMass'][:]
                clus_radius_next = stars_next['ClusterRadius'][:]
                disruption_time_next = stars_next['DisruptionTime'][:]
                mlost_shocks_next = stars_next['MassLostShocks'][:]
                mlost_relax_next = stars_next['MassLostRelaxation'][:]
                nclus_next = stars_next['NumberOfClusters'][:]
                print('Found cluster {:d} in snapshot {:d} part {:d}'.format(sc+1, snap, part))

                # np.any makes the test explicit instead of relying on the
                # truth value of one-element arrays.
                inconsistent = (np.any(clus_mass_next[inverse_mask_next, cl_idx] > 0.)
                                or np.any(disruption_time_next[inverse_mask_next, cl_idx] != time_disrupted)
                                or np.any(mlost_shocks_next[inverse_mask_next, cl_idx] != ref_mlost_shocks)
                                or np.any(mlost_relax_next[inverse_mask_next, cl_idx] != ref_mlost_relax))
                if inconsistent:
                    print('Need to fix cluster {:d} in snapshot {:d} part {:d}'.format(sc+1, snap, part))
                    # Values before the correction ...
                    print(clus_mass_next[inverse_mask_next, cl_idx],
                          clus_radius_next[inverse_mask_next, cl_idx],
                          mlost_shocks_next[inverse_mask_next, cl_idx],
                          mlost_relax_next[inverse_mask_next, cl_idx],
                          nclus_next[inverse_mask_next],
                          disruption_time_next[inverse_mask_next, cl_idx])
                    clus_mass_next[inverse_mask_next, cl_idx] = 0.
                    clus_radius_next[inverse_mask_next, cl_idx] = 0.
                    mlost_shocks_next[inverse_mask_next, cl_idx] = ref_mlost_shocks
                    mlost_relax_next[inverse_mask_next, cl_idx] = ref_mlost_relax
                    nclus_next[inverse_mask_next] = (clus_mass_next[inverse_mask_next] > 0.).sum()
                    disruption_time_next[inverse_mask_next, cl_idx] = time_disrupted
                    # ... and after it.
                    print(clus_mass_next[inverse_mask_next, cl_idx],
                          clus_radius_next[inverse_mask_next, cl_idx],
                          mlost_shocks_next[inverse_mask_next, cl_idx],
                          mlost_relax_next[inverse_mask_next, cl_idx],
                          nclus_next[inverse_mask_next],
                          disruption_time_next[inverse_mask_next, cl_idx])
                    stars_next['ClusterMass'][:] = clus_mass_next
                    stars_next['ClusterRadius'][:] = clus_radius_next
                    stars_next['DisruptionTime'][:] = disruption_time_next
                    stars_next['MassLostShocks'][:] = mlost_shocks_next
                    stars_next['MassLostRelaxation'][:] = mlost_relax_next
                    stars_next['NumberOfClusters'][:] = nclus_next
                    snap -= 1  # continue with the previous snapshot
                else:
                    print('From this snapshot down the cluster should be correct')
                    need_to_correct = False
            break  # host particle handled: skip the remaining file parts
        else:
            # No file part of this snapshot contains the host particle.
            print('Cluster {:d} not found in snapshot {:d}; stopping'.format(sc+1, snap))
            need_to_correct = False

Found cluster 1 in snapshot 62 part 0
[0.] [0.] [7.5806724e-06] [6.267579e-07] [0] [0.344342]
Found cluster 1 in snapshot 127 part 0
Need to fix cluster 1 in snapshot 127 part 0
[0.] [0.] [6.1110795e-06] [3.7020984e-06] [0] [0.344342]
[0.] [0.] [7.5806724e-06] [6.267579e-07] [0] [0.344342]
Found cluster 1 in snapshot 126 part 0
From this snapshot down the cluster should be correct
Found cluster 2 in snapshot 62 part 0
[0.] [0.] [1.3543137e-06] [9.462226e-07] [0] [0.344342]
Found cluster 2 in snapshot 127 part 0
Need to fix cluster 2 in snapshot 127 part 0
[0.] [0.] [6.1110795e-06] [3.7020984e-06] [0] [0.344342]
[0.] [0.] [1.3543137e-06] [9.462226e-07] [0] [0.344342]
Found cluster 2 in snapshot 126 part 0
From this snapshot down the cluster should be correct
Found cluster 3 in snapshot 62 part 0
[0.] [0.] [6.277888e-06] [1.3067355e-06] [0] [0.344342]
Found cluster 3 in snapshot 127 part 0
Need to fix cluster 3 in snapshot 127 part 0
[0.] [0.] [6.1110795e-06] [3.7020984e-06] [0] [0.34434