Convert Ananke coordinates to FIRE coordinates

In [6]:
import os
from os.path import join
import h5py

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import pandas as pd

%matplotlib inline

plt.style.use('../matplotlib_style/latex_plot_style.mplstyle')

In [7]:
# Names of the input-catalog datasets that are copied to the subsampled file.
# Grouped by kind for readability; the order is unchanged.
all_keys = (
    # identifiers
    'source_id', 'parentid',
    # columns without the `_true` suffix
    'ra', 'dec', 'l', 'b',
    'parallax', 'pmra', 'pmdec', 'radial_velocity',
    # age and [Fe/H] columns
    'age', 'feh',
    # `_true` position and velocity components
    'px_true', 'py_true', 'pz_true',
    'vx_true', 'vy_true', 'vz_true',
    # `_true` counterparts of the columns above
    'ra_true', 'dec_true', 'l_true', 'b_true',
    'parallax_true', 'pmra_true', 'pmdec_true', 'radial_velocity_true',
)

In [8]:
# Configuration for subsampling the dataset
gal = 'm12f'              # substituted into the input/output paths below
lsr = 0                   # substituted into the input/output paths below
batch_size = 1_000_000    # number of input rows processed per batch

# Input catalog: `sim_file` is given relative to `sim_dir`
sim_dir = '/scratch/05328/tg846280/FIRE_Public_Simulations/'
sim_file = f'ananke/{gal}/lsr-{lsr}/lsr-{lsr}-rslice-0.{gal}-res7100-md-sliced-gcat-dr2.hdf5'

# Output catalog: `out_file` is given relative to `out_dir`
out_dir = f'/scratch/05328/tg846280/FIRE_Public_Simulations/ananke_subsamples/{gal}'
out_file = f'lsr-{lsr}-rslice-0.{gal}-res7100-md-sliced-gcat-dr2.hdf5'

# Make sure the output directory exists (no error if it already does)
os.makedirs(out_dir, exist_ok=True)

In [9]:
# Load the accretion-history table for this galaxy and keep the sorted,
# de-duplicated values of its `id_stars` column.
table = pd.read_csv(
    f'accretion_history/stars_accretion_history_{gal}_res7100_v2.csv'
)
id_stars = np.unique(table['id_stars'].values)

In [10]:
# Copy the rows that pass the selection from the input catalog to the output
# file, one batch of `batch_size` rows at a time so the full catalog never has
# to be held in memory. `n_select_total` accumulates the number of rows kept.
n_select_total = 0

# Open and validate the input BEFORE opening the output: mode 'w' truncates an
# existing output file, so a wrong input path or a missing dataset must fail
# before anything is overwritten.
with h5py.File(join(sim_dir, sim_file), 'r') as in_f:
    missing_keys = [k for k in ('parallax_over_error', *all_keys) if k not in in_f]
    if missing_keys:
        raise KeyError(f'{sim_file} is missing datasets: {missing_keys}')

    # number of input rows; also the upper bound on the number of selected rows
    n_total = in_f['parallax_over_error'].len()
    num_batches = (n_total - 1) // batch_size + 1    # ceil(n_total / batch_size)

    with h5py.File(join(out_dir, out_file), 'w') as out_f:
        for i in range(0, n_total, batch_size):

            # progress: zero-based batch index and total number of batches
            print(i//batch_size, num_batches)

            # row range of this batch; slicing past the end is clipped
            start = i
            stop = start + batch_size

            # parallax cut: keep rows with parallax_over_error > 10,
            # equivalent to dp / p < 0.1
            poe = in_f['parallax_over_error'][start: stop]
            select = poe > 10

            # keep only stars whose parentid is listed in the accretion
            # table's `id_stars`
            parent_id = in_f['parentid'][start: stop]
            select = select & np.isin(parent_id, id_stars)

            n_select = np.sum(select)
            n_select_total += n_select

            # nothing in this batch passes the cut
            if n_select == 0:
                continue

            for key in all_keys:
                vals = in_f[key][start: stop][select]

                if key not in out_f:
                    # first write: create a dataset that can grow along
                    # axis 0 up to the input length
                    dset = out_f.create_dataset(
                        key, data=vals, maxshape=(n_total, *vals.shape[1:]))
                else:
                    # later writes: grow the dataset and fill the new tail
                    dset = out_f[key]
                    n_written = dset.shape[0]
                    dset.resize(n_written + n_select, axis=0)
                    dset[n_written:] = vals

0 431
1 431
2 431
3 431
4 431
5 431
6 431
7 431
8 431
9 431
10 431
11 431
12 431
13 431
14 431
15 431
16 431
17 431
18 431
19 431
20 431
21 431
22 431
23 431
24 431
25 431
26 431
27 431
28 431
29 431
30 431
31 431
32 431
33 431
34 431
35 431
36 431
37 431
38 431
39 431
40 431
41 431
42 431
43 431
44 431
45 431
46 431
47 431
48 431
49 431
50 431
51 431
52 431
53 431
54 431
55 431
56 431
57 431
58 431
59 431
60 431
61 431
62 431
63 431
64 431
65 431
66 431
67 431
68 431
69 431
70 431
71 431
72 431
73 431
74 431
75 431
76 431
77 431
78 431
79 431
80 431
81 431
82 431
83 431
84 431
85 431
86 431
87 431
88 431
89 431
90 431
91 431
92 431
93 431
94 431
95 431
96 431
97 431
98 431
99 431
100 431
101 431
102 431
103 431
104 431
105 431
106 431
107 431
108 431
109 431
110 431
111 431
112 431
113 431
114 431
115 431
116 431
117 431
118 431
119 431
120 431
121 431
122 431
123 431
124 431
125 431
126 431
127 431
128 431
129 431
130 431
131 431
132 431
133 431
134 431
135 431
136 431
137 431
138 43