In [2]:
%load_ext autoreload
%autoreload 2

import os
import sys
import h5py
import pickle

sys.path.append('..')

import numpy as np
import jax.numpy as jnp
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl

from analysis_utils import envs

%matplotlib inline
# NOTE(review): absolute, user-specific path. This line presumably fails on
# any machine where this style file is not present at exactly this location;
# consider a path relative to the repository or a configurable directory.
plt.style.use('/mnt/home/tnguyen/default.mplstyle')

In [3]:
def calc_all_props(
    catalog, halo_idx, remove_center=True, sort_mass=True,
    min_subhalo_mass=1,
):
    """ Collect the properties of one halo and of all its subhalos.

    Parameters
    ----------
    catalog : mapping
        Catalog read as ``catalog[key][halo_idx]``. The keys used are
        ``GroupPos``, ``GroupVel``, ``GroupMass``, ``GroupMassType``,
        ``SubhaloPos``, ``SubhaloVel``, ``SubhaloMass``, ``SubhaloMassType``,
        ``SubhaloVmax``, ``SubhaloVmaxRad`` and ``sobol_params``.
        The catalog is never modified by this function.
    halo_idx : int
        Index of the halo inside the catalog.
    remove_center : bool
        If True, drop the first entry of every subhalo array
        (presumably the central subhalo -- confirm against the catalog).
    sort_mass : bool
        If True, order the subhalos by decreasing ``subhalo_mvir``.
    min_subhalo_mass : float
        Linear mass threshold. Only subhalos whose log10 mass is strictly
        greater than ``log10(min_subhalo_mass)`` are kept.

    Returns
    -------
    dict
        Halo-level entries ``halo_pos``, ``halo_vel``, ``halo_mvir``,
        ``halo_mstar`` (masses in log10), ``wdm_mass``, ``sn1``, ``sn2``,
        ``agn1``, and subhalo-level arrays ``subhalo_pos``, ``subhalo_vel``
        (both relative to the halo), ``subhalo_mvir``, ``subhalo_mstar``
        (log10) and ``subhalo_vmax_tilde``. The last three have shape
        ``(num_subhalos, 1)``. All subhalo arrays share the same row order.
    """
    prop = {}

    # get halo properties
    # NOTE(review): the 1e3 and 1e10 factors look like unit conversions
    # (kpc -> Mpc, 1e10 Msun -> Msun) -- confirm against the catalog units.
    prop['halo_pos'] = catalog['GroupPos'][halo_idx] / 1e3
    prop['halo_vel'] = catalog['GroupVel'][halo_idx]
    prop['halo_mvir'] = catalog['GroupMass'][halo_idx] * 1e10
    # column 4 of the mass-by-type table is used as the stellar mass
    prop['halo_mstar'] = catalog['GroupMassType'][halo_idx, 4] * 1e10

    # get all subhalo properties
    prop['subhalo_pos'] = catalog['SubhaloPos'][halo_idx] / 1e3
    prop['subhalo_vel'] = catalog['SubhaloVel'][halo_idx]
    prop['subhalo_mvir'] = catalog['SubhaloMass'][halo_idx] * 1e10
    prop['subhalo_mstar'] = catalog['SubhaloMassType'][halo_idx][:, 4] * 1e10
    prop['subhalo_vmax_tilde'] = (
        catalog['SubhaloVmax'][halo_idx] / catalog['SubhaloVmaxRad'][halo_idx])
    prop['subhalo_vmax_tilde'] = prop['subhalo_vmax_tilde'].reshape(-1, 1)

    # get Sobol parameters
    prop['wdm_mass'] = catalog['sobol_params'][halo_idx, 0]
    prop['sn1'] = catalog['sobol_params'][halo_idx, 1]
    prop['sn2'] = catalog['sobol_params'][halo_idx, 2]
    prop['agn1'] = catalog['sobol_params'][halo_idx, 3]

    # apply some transformations and preprocessing
    # convert all masses to log10; a zero mass becomes -inf and NumPy emits
    # a RuntimeWarning for it
    prop['halo_mvir'] = np.log10(prop['halo_mvir'])
    prop['halo_mstar'] = np.log10(prop['halo_mstar'])
    prop['subhalo_mvir'] = np.log10(prop['subhalo_mvir'])
    prop['subhalo_mstar'] = np.log10(prop['subhalo_mstar'])

    # center the subhalo positions and velocities on the halo.
    # Out-of-place subtraction on purpose: 'subhalo_vel' still refers to the
    # catalog's own array here, and an in-place `-=` would modify the catalog.
    prop['subhalo_pos'] = prop['subhalo_pos'] - prop['halo_pos']
    prop['subhalo_vel'] = prop['subhalo_vel'] - prop['halo_vel']

    # every per-subhalo array must go through the same selections below,
    # otherwise the rows of different properties stop matching each other
    subhalo_keys = [key for key in prop if 'subhalo' in key]

    # remove the center subhalo
    if remove_center:
        for key in subhalo_keys:
            prop[key] = prop[key][1:]

    # sort the subhalos by mass (descending)
    if sort_mass:
        order = np.argsort(prop['subhalo_mvir'])[::-1]
        for key in subhalo_keys:
            prop[key] = prop[key][order]

    # apply a minimum subhalo mass cut
    keep = prop['subhalo_mvir'] > np.log10(min_subhalo_mass)
    for key in subhalo_keys:
        prop[key] = prop[key][keep]

    # reshape the 1D mass arrays into column vectors
    prop['subhalo_mvir'] = prop['subhalo_mvir'].reshape(-1, 1)
    prop['subhalo_mstar'] = prop['subhalo_mstar'].reshape(-1, 1)

    return prop

def pad_and_create_mask(features, max_len=None):
    """ Zero-pad variable-length feature arrays and build a Transformer mask.

    Parameters
    ----------
    features : sequence of 2D arrays
        One array of shape ``(length_i, num_feat)`` per sample. Every array
        must have the same ``num_feat`` as the first one.
    max_len : int, optional
        Length of the padded axis. Defaults to the longest sequence. Samples
        longer than ``max_len`` are truncated to their first ``max_len`` rows.

    Returns
    -------
    padded_features : np.ndarray
        Array of shape ``(len(features), max_len, num_feat)``, zero-padded.
    mask : np.ndarray
        Array of shape ``(len(features), max_len)`` holding 1 for valid
        entries and 0 for padded entries.

    Raises
    ------
    ValueError
        If ``features`` is empty.
    """
    if len(features) == 0:
        raise ValueError("`features` is empty: nothing to pad")

    if max_len is None:
        max_len = max(f.shape[0] for f in features)

    num_feat = features[0].shape[1]
    padded_features = np.zeros((len(features), max_len, num_feat))

    # mask (batch_size, max_len)
    # note that jax mask is 1 for valid entries and 0 for invalid entries
    # this is the opposite of the pytorch mask
    mask = np.zeros((len(features), max_len))

    for i, f in enumerate(features):
        # never write past max_len: longer samples are truncated
        num_valid = min(f.shape[0], max_len)
        padded_features[i, :num_valid] = f[:num_valid]
        mask[i, :num_valid] = 1

    return padded_features, mask

In [24]:
# Preprocessing configuration
# ---------------------------
# halo selection and properties
num_max = 50   # maximum number of subhalos to include per halo
pad_features = False   # if False, halos with fewer than num_max subhalos are skipped
remove_center = True   # drop the first subhalo entry of every halo
sort_mass = True   # order subhalos by decreasing mass
shuffle = True   # randomly permute the halos before saving
min_subhalo_mass = 1e8   # linear mass threshold for the subhalo mass cut
# per-subhalo quantities concatenated (in this order) into the feature columns
feat_parameters = [
    'subhalo_pos', 'subhalo_vel', 'subhalo_mvir', 'subhalo_vmax_tilde']
# per-halo quantities written (in this order) to the conditioning table
cond_parameters = ['halo_mvir', 'wdm_mass', 'sn1', 'sn2', 'agn1']

# input / output paths
raw_dset_path = envs.DEFAULT_RAW_DATASET_DIR / 'mw_zooms-wdm.pkl'
out_dset_root = envs.DEFAULT_DATASET_DIR / 'mw_zooms-wdm-dmprop'
# build the tag from num_max so the output file names cannot drift away from
# the setting that produced them (evaluates to "nmax50-vmaxtilde" for 50)
out_dset_name = f"nmax{num_max}-vmaxtilde"

In [25]:
# read in the FoF catalog
# NOTE(review): pickle.load can execute arbitrary code while unpickling;
# only open catalog files that come from a trusted source.
with open(raw_dset_path, 'rb') as f:
    catalog = pickle.load(f)
num_total = len(catalog['box_num'])

features = []
conditions = []
for halo_idx in range(num_total):
    # get all halo and subhalo properties
    prop = calc_all_props(
        catalog, halo_idx,
        remove_center=remove_center, sort_mass=sort_mass,
        min_subhalo_mass=min_subhalo_mass
    )
    # without padding every halo must supply at least num_max subhalos
    if not pad_features and prop['subhalo_pos'].shape[0] < num_max:
        continue

    # one row per subhalo; columns follow the order of feat_parameters.
    # Only the first num_max subhalos are kept.
    feat = np.concatenate([prop[p] for p in feat_parameters], axis=1)
    feat = feat[:num_max]
    features.append(feat)

    # one conditioning vector per halo, ordered like cond_parameters
    cond = np.stack([prop[p] for p in cond_parameters])
    conditions.append(cond)

# NOTE(review): the padded length is the longest sequence found, which can be
# smaller than num_max when pad_features is True -- confirm this is intended.
features, mask = pad_and_create_mask(features)
conditions = np.stack(conditions)

# features, mask and conditions must describe the same halos row by row
assert len(features) == len(mask) == len(conditions)

print(f"Number of halos: {len(features)}")

Number of halos: 227


  prop['subhalo_mstar'] = np.log10(prop['subhalo_mstar'])


In [26]:
if shuffle:
    # apply ONE permutation to every per-halo array so that row i of
    # features, mask and conditions keeps describing the same halo
    # (the mask has to be permuted too, otherwise padded entries would be
    # attributed to the wrong halo)
    # NOTE(review): the permutation is not seeded, so every run writes the
    # halos in a different order.
    idx = np.random.permutation(len(features))
    features = features[idx]
    mask = mask[idx]
    conditions = conditions[idx]

# save the dataset
os.makedirs(out_dset_root, exist_ok=True)

out_dset_path = os.path.join(out_dset_root, out_dset_name + '_feat.npy')
jnp.save(out_dset_path, features)
out_dset_path = os.path.join(out_dset_root, out_dset_name + '_mask.npy')
jnp.save(out_dset_path, mask)

# save the conditioning as CSV instead, one column per conditioning parameter
out_dset_path = os.path.join(out_dset_root, out_dset_name + '_cond.csv')
table = pd.DataFrame(
    conditions.reshape(-1, len(cond_parameters)), columns=cond_parameters)
table.to_csv(out_dset_path, index=False)