In [1]:
%reload_ext autoreload
%autoreload 2

import os, sys, shutil, logging, gc
import pickle
from tqdm import tqdm

from jax import config
config.update("jax_enable_x64", True)
import jax
import jax.numpy as jnp

import numpy as np
from scipy import interpolate
from astropy.cosmology import Planck18
import astropy.units as u

import py21cmfast as p21c
from py21cmfast import cache_tools

sys.path.append(os.environ['DH_DIR'])
from darkhistory.spec.spectrum import Spectrum
from darkhistory.history.reionization import alphaA_recomb
from darkhistory.history.tla import compton_cooling_rate

WDIR = os.environ['DM21CM_DIR']
sys.path.append(WDIR)
import dm21cm.physics as phys
from dm21cm.dh_wrappers import DarkHistoryWrapper, TransferFunctionWrapper
from dm21cm.utils import load_h5_dict
from dm21cm.data_cacher import Cacher
from dm21cm.spectrum import AttenuatedSpectrum
from dm21cm.interpolators_jax import SFRDInterpolator
from dm21cm.dm_params import DMParams

%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib import colormaps as cms
mpl.rc_file('../matplotlibrc')



In [3]:
from dm21cm.evolve import get_z_edges, split_xray, gen_injection_boxes, p21c_step, geom_inds, get_emissivity_bracket, DepTracker, debug_get_21totf_interp

In [4]:
def custom_SFRD(z, delta, r):
    return 1. + delta

In [2]:
z_start = 45.
z_end = 5.
zplusone_step_factor = 1.001
dm_params = DMParams(
    mode='decay',
    primary='phot_delta',
    m_DM=1e8, # [eV]
    lifetime=1e50, # [s]
)
enable_elec = False

p21c_initial_conditions = p21c.initial_conditions(
    user_params = p21c.UserParams(
        HII_DIM = 32,
        BOX_LEN = 32*2, # [conformal Mpc]
        N_THREADS = 32,
        USE_INTERPOLATION_TABLES = True, # for testing
    ),
    cosmo_params = p21c.CosmoParams(
        OMm = Planck18.Om0,
        OMb = Planck18.Ob0,
        POWER_INDEX = Planck18.meta['n'],
        SIGMA_8 = Planck18.meta['sigma8'],
        hlittle = Planck18.h,
    ),
    random_seed = 54321,
    write = True,
)

clear_cache = True
use_tqdm = True
debug_break_after_z = 10.
debug_record_extra = False

# 21cmFAST xray injection
use_21cmfast_xray = False
astro_params_before_step = p21c.AstroParams(L_X = 0.) # log10 value
debug_turn_off_pop2ion = True
debug_xray_Rmax_p21c = 500.

# DM21cm xray injection
debug_flags = ['xc-ots', 'xc-custom-SFRD', 'xc-01attenuation']
debug_unif_delta_dep = True
debug_unif_delta_tf_param = True
st_multiplier = 10.
debug_nodplus1 = True
debug_xray_Rmax_shell = 500.
debug_xray_Rmax_bath = 500.
adaptive_shell = 40

# defaults
tf_on_device = True
debug_depallion = False

In [7]:
#===== data and cache =====
os.environ['DM21CM_DATA_DIR'] = '/n/holyscratch01/iaifi_lab/yitians/dm21cm/DM21cm/data/tf/zf001/data'
data_dir = os.environ['DM21CM_DATA_DIR']
p21c.config['direc'] = "/n/home07/yitians/21cmFAST-cache/test"
gc.collect()

#===== initialize =====
#--- physics parameters ---
EPSILON = 1e-6
p21c.global_params.Z_HEAT_MAX = z_start + EPSILON
p21c.global_params.ZPRIME_STEP_FACTOR = zplusone_step_factor
p21c.global_params.CLUMPING_FACTOR = 1.
if debug_turn_off_pop2ion:
    p21c.global_params.Pop2_ion = 0.
if debug_xray_Rmax_p21c is not None:
    p21c.global_params.R_XLy_MAX = debug_xray_Rmax_p21c

abscs = load_h5_dict(f"{data_dir}/abscissas.h5")
if not np.isclose(np.log(zplusone_step_factor), abscs['dlnz']):
    raise ValueError('zplusone_step_factor and abscs mismatch')
dm_params.set_inj_specs(abscs)

box_dim = p21c_initial_conditions.user_params.HII_DIM
box_len = p21c_initial_conditions.user_params.BOX_LEN
cosmo = Planck18

#--- DarkHistory and transfer functions ---
tf_wrapper = TransferFunctionWrapper(
    box_dim = box_dim,
    abscs = abscs,
    prefix = data_dir,
    enable_elec = enable_elec,
    on_device = tf_on_device,
)

#--- xraycheck ---
delta_cacher = Cacher(
    data_path=f"{p21c.config['direc']}/xraycheck_brightness.h5",
    cosmo=cosmo, N=box_dim, dx=box_len/box_dim,
    shell_Rmax=debug_xray_Rmax_shell,
    Rmax=debug_xray_Rmax_bath,
)
# reconstruct spec cache
#spec_cache = pickle.load(open(f"{p21c.config['direc']}/spec_cache.p", 'rb'))
#delta_cacher.spectrum_cache = spec_cache
#delta_cacher.brightness_cache.z_s = spec_cache.z_s
    
xray_eng_lo = 0.5 * 1000 # [eV]
xray_eng_hi = 10.0 * 1000 # [eV]
xray_i_lo = np.searchsorted(abscs['photE'], xray_eng_lo)
xray_i_hi = np.searchsorted(abscs['photE'], xray_eng_hi)

sfrd_tables = load_h5_dict(f"{data_dir}/sfrd_tables.h5")
z_range = sfrd_tables['z_range']
delta_range = sfrd_tables['delta_range']
r_range = sfrd_tables['r_range']
cond_sfrd_table = sfrd_tables['cond_sfrd_table']
st_sfrd_table = sfrd_tables['st_sfrd_table']
Cond_SFRD_Interpolator = SFRDInterpolator(z_range, delta_range, r_range, cond_sfrd_table) # jax good
ST_SFRD_Interpolator = interpolate.interp1d(z_range, st_sfrd_table * st_multiplier)

#--- redshift stepping ---
z_edges = get_z_edges(z_start, z_end, p21c.global_params.ZPRIME_STEP_FACTOR)

#===== initial steps =====
dh_wrapper = DarkHistoryWrapper(
    dm_params,
    prefix = p21c.config[f'direc'],
)
debug_copy_dh_init = f"{WDIR}/outputs/dh/xc_xrayST_soln.p"

shutil.copy(debug_copy_dh_init, f"{p21c.config['direc']}/dh_init_soln.p")
logging.info(f'Copied dh_init_soln.p from {debug_copy_dh_init}')

# We have to synchronize at the second step because 21cmFAST acts weird in the first step:
# - global_params.TK_at_Z_HEAT_MAX is not set correctly (it is probably set and evolved for a step)
# - global_params.XION_at_Z_HEAT_MAX is not set correctly (it is probably set and evolved for a step)
# - first step ignores any values added to spin_temp.Tk_box and spin_temp.x_e_box
z_match = z_edges[1]
dh_wrapper.evolve(end_rs=(1+z_match)*0.9, rerun=False)
T_k_DH_init, x_e_DH_init, phot_bath_spec = dh_wrapper.get_init_cond(rs=1+z_match)

perturbed_field = p21c.perturb_field(redshift=z_edges[1], init_boxes=p21c_initial_conditions)
spin_temp, ionized_box, brightness_temp = p21c_step(perturbed_field=perturbed_field, spin_temp=None, ionized_box=None, astro_params=astro_params_before_step)
spin_temp.Tk_box += T_k_DH_init - np.mean(spin_temp.Tk_box)
spin_temp.x_e_box += x_e_DH_init - np.mean(spin_temp.x_e_box)
ionized_box.xH_box = 1 - spin_temp.x_e_box

records = []
records_extra = []

#===== main loop =====
#--- trackers ---
i_xraycheck_shell_start = ... # set later
i_xraycheck_bath_start = ... # set later

z_edges = z_edges[1:] # Maybe fix this later
z_range = range(len(z_edges)-1)
if use_tqdm:
    from tqdm import tqdm
    z_range = tqdm(z_range)
print_str = ''
dep_tracker = DepTracker()

INFO:root:TransferFunctionWrapper: Loaded photon transfer functions.
INFO:root:Copied dh_init_soln.p from /n/home07/yitians/dm21cm/DM21cm/outputs/dh/xc_xrayST_soln.p
INFO:root:DarkHistoryWrapper: Found existing DarkHistory initial conditions.


zp = 4.495863e+01 E_tot_ave = 0.000000e+00


  0%|          | 0/2037 [00:00<?, ?it/s]

In [None]:
#--- loop ---
for i_z in z_range:
    z_current = z_edges[i_z]
    z_next = z_edges[i_z+1]
    print(i_z, z_current)

    perturbed_field = p21c.perturb_field(redshift=z_next, init_boxes=p21c_initial_conditions)
    spin_temp, ionized_box, brightness_temp = p21c_step(
        perturbed_field, spin_temp, ionized_box,
        input_heating = None,
        input_ionization = None,
        input_jalpha = None,
        astro_params = astro_params_before_step
    )

    if z_next < 44:
        break

In [13]:
i_z += 1
z_current = z_edges[i_z]
z_next = z_edges[i_z+1]

In [14]:
perturbed_field = p21c.perturb_field(redshift=z_next, init_boxes=p21c_initial_conditions)
spin_temp = p21c.spin_temperature(
    perturbed_field = perturbed_field,
    #previous_spin_temp = spin_temp,
    input_heating_box = None,
    input_ionization_box = None,
    input_jalpha_box = None,
    astro_params = p21c.AstroParams(L_X = 40.),
    #astro_params = astro_params_before_step,
)

zp = 4.495863e+01 E_tot_ave = 0.000000e+00
zp = 4.491272e+01 E_tot_ave = 0.000000e+00
zp = 4.486685e+01 E_tot_ave = 0.000000e+00
zp = 4.482103e+01 E_tot_ave = 0.000000e+00
zp = 4.477525e+01 E_tot_ave = 0.000000e+00
zp = 4.472953e+01 E_tot_ave = 0.000000e+00
zp = 4.468384e+01 E_tot_ave = 0.000000e+00
zp = 4.463820e+01 E_tot_ave = 0.000000e+00
zp = 4.459261e+01 E_tot_ave = 0.000000e+00
zp = 4.454706e+01 E_tot_ave = 0.000000e+00
zp = 4.450156e+01 E_tot_ave = 0.000000e+00
zp = 4.445610e+01 E_tot_ave = 0.000000e+00
zp = 4.441069e+01 E_tot_ave = 0.000000e+00
zp = 4.436533e+01 E_tot_ave = 0.000000e+00
zp = 4.432001e+01 E_tot_ave = 0.000000e+00
zp = 4.427473e+01 E_tot_ave = 0.000000e+00


KeyboardInterrupt: 