In [1]:
%reload_ext autoreload
%autoreload 2

import os, sys
# Site-specific paths for the author's cluster account. USER may be unset in some
# environments, so use .get(); os.uname() is only evaluated when USER matches.
if os.environ.get('USER') == 'yitians' and 'submit' in os.uname().nodename:
    os.environ['DM21CM_DATA_DIR'] = '/data/submit/yitians/dm21cm/DM21cm'
    os.environ['DH_DIR'] = '/work/submit/yitians/darkhistory/DarkHistory'
sys.path.append('..')  # presumably where the dm21cm package lives (imported below)
if 'DH_DIR' not in os.environ:
    # same exception type as a plain lookup, but with an actionable message
    raise KeyError('DH_DIR is not set; point it at a DarkHistory checkout.')
sys.path.append(os.environ['DH_DIR'])  # makes `darkhistory` (and DarkHistory's `main`) importable

from tqdm import tqdm
import time
import pickle

import numpy as np
import jax
import jax.numpy as jnp

import py21cmfast as p21c
from py21cmfast import plotting, cache_tools
print(f'Using 21cmFAST version {p21c.__version__}')

%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib import cm
mpl.rc_file('../matplotlibrc')

  "Your configuration file is out of date. Updating..."


Using 21cmFAST version 0.1.dev1578+g6f96f89.d20230224


In [2]:
import logging

# 21cmFAST's own logger: show INFO-level messages (e.g. cache clean-up reports).
logger = logging.getLogger('21cmFAST')
logger.setLevel(logging.INFO)

# Suppress every message from these py21cmfast submodules (threshold above CRITICAL).
for _quiet_name in ('py21cmfast._utils', 'py21cmfast.wrapper'):
    logging.getLogger(_quiet_name).setLevel(logging.CRITICAL + 1)

In [3]:
from dm21cm.dm_params import DMParams
from dm21cm.data_loader import load_data
import dm21cm.physics as phys

In [4]:
from darkhistory.spec.spectrum import Spectrum

In [5]:
# Show which devices JAX will use for the array work below (this run: CPU only).
jax.devices()



[CpuDevice(id=0)]

## 0. Global config

In [6]:
# Inspect the CPU layout of this machine (informs the N_THREADS choice in the next cell).
! lscpu | grep "CPU(s)"

CPU(s):                48
On-line CPU(s) list:   0-47
NUMA node0 CPU(s):     0,2,4,6,8,10,12,14,16,18,20,22,24,26,28,30,32,34,36,38,40,42,44,46
NUMA node1 CPU(s):     1,3,5,7,9,11,13,15,17,19,21,23,25,27,29,31,33,35,37,39,41,43,45,47


In [7]:
# Thread count passed to 21cmFAST's UserParams(N_THREADS=...). Capped at the number of
# logical CPUs so smaller machines are not oversubscribed (32 on the 48-CPU node above).
N_THREADS = min(32, os.cpu_count() or 1)

In [8]:
def get_z_arr(z_start=None, z_end=20, step_factor=None):
    """Redshift grid for the 21cmFAST loop, ordered from high to low redshift.

    Starting from ``z_end``, successive redshifts follow
    ``1 + z_next = (1 + z) * step_factor`` until one reaches ``z_start``.
    That final value (>= ``z_start``) is dropped, so every returned
    redshift is strictly below ``z_start``.

    Parameters
    ----------
    z_start : float, optional
        Exclusive upper bound. Defaults to ``p21c.global_params.Z_HEAT_MAX``.
    z_end : float
        Lowest redshift; it is the last element whenever ``z_end < z_start``.
    step_factor : float, optional
        Ratio between successive values of (1+z). Defaults to
        ``p21c.global_params.ZPRIME_STEP_FACTOR``.

    Returns
    -------
    np.ndarray
        Redshifts in decreasing order; empty when ``z_end >= z_start``.

    Raises
    ------
    ValueError
        If the grid could never reach ``z_start`` (``step_factor <= 1`` or
        ``z_end <= -1``), which would otherwise loop forever.
    """
    if z_start is None:
        z_start = p21c.global_params.Z_HEAT_MAX
    if step_factor is None:
        step_factor = p21c.global_params.ZPRIME_STEP_FACTOR

    if z_end < z_start:
        # (1+z)*step_factor - 1 > z only when step_factor > 1 and 1+z > 0.
        if step_factor <= 1:
            raise ValueError(f'step_factor must be > 1, got {step_factor}')
        if z_end <= -1:
            raise ValueError(f'z_end must be > -1, got {z_end}')

    z_arr = [z_end]
    # The list grows monotonically, so its last element is its maximum.
    while z_arr[-1] < z_start:
        z_arr.append((1 + z_arr[-1]) * step_factor - 1)
    return np.array(z_arr[::-1][1:])

In [17]:
# check cached runs
CACHE_DIR_BASE = '/scratch/submit/ctp/yitians/21cmFAST-cache'
os.listdir(CACHE_DIR_BASE)

['test',
 'phph_dh_mh3',
 'phph_phbath_dhinit_Tk_dhinit',
 'emf_comp_spf_mh3',
 'base',
 'testdhph',
 'turnon_phph',
 'phph_phbath_dhinit_Tk_dhinit_uniform_C1',
 'phot_bath_no_prop',
 'phph_phbath_dhinit_Tk_dhinit_uniform',
 'phph_dh_phbath_mh3_dhinit',
 'phph_dh_phbath_mh3',
 'emf_comp_dh_mh3',
 'phph_spf_mh3',
 'emf_comp_dh_mh6',
 'emf_comp_dh_mh9',
 'emf_comp_base',
 'phph_dh_phbath_mh3_nophot']

## 1. Initialization

In [13]:
# config
RUN_NAME = 'phph_phbath_dhinit_Tk_dhinit_uniform_C1'
struct_boost_model = 'erfc 1e-3' # default: erfc 1e-3
run_mode = 'inj' # '' = no injection, 'inj' = DM injection; anything else raises in the loop

dh_init_phot_bath = True # start the photon bath from DarkHistory's last high-energy photon spectrum
dh_init_T_k = True # shift the first injected heating so mean T_k matches DarkHistory's
enable_phot_bath = True # propagate the photon bath between steps

debug_uniform_injection = True # replace the delta and x_e boxes by their box means

reload = False # forwarded to load_data(..., reload=reload)
use_tqdm = True
save_slices = True

# Close the file promptly instead of leaving the handle to the garbage collector.
with open('../data/abscissas/abscs_nBs_test_2.p', 'rb') as abscs_file:
    abscs = pickle.load(abscs_file)

# per-run 21cmFAST cache directory
p21c.config['direc'] = f'{CACHE_DIR_BASE}/{RUN_NAME}'
os.makedirs(p21c.config['direc'], exist_ok=True)

In [14]:
# Remove previously cached 21cmFAST boxes so this run starts fresh.
# NOTE(review): presumably clears p21c.config['direc'] (set in the config cell); confirm it
# does not also delete the dhinit_soln.p pickle stored in the same directory.
cache_tools.clear_cache()

2023-04-14 09:47:12,716 | INFO | Removed 0 files from cache.
INFO:21cmFAST:Removed 0 files from cache.


## 2. Loop

In [15]:
# redshift
p21c.global_params.ZPRIME_STEP_FACTOR = 1.05
p21c.global_params.Z_HEAT_MAX = 44.
z_arr = get_z_arr(z_end=6.)

p21c.global_params.CLUMPING_FACTOR = 1.

# dark matter
dm_params = DMParams(mode='swave', primary='phot_delta', m_DM=1e10, sigmav=1e-23)

# DarkHistory's `main` module (found via DH_DIR on sys.path). Imported unconditionally so
# later calls to main.evolve work even when the initial-condition solution is cached.
import main

if dh_init_phot_bath or dh_init_T_k:
    
    # Cached per run directory. NOTE: a cache written with the old end_rs=z_arr[0] ends
    # near z = z_arr[0] - 1 rather than z_arr[0]; delete it to regenerate.
    dhinit_fn = f"{p21c.config['direc']}/dhinit_soln.p"
    
    if os.path.exists(dhinit_fn):
        with open(dhinit_fn, 'rb') as dhinit_file:
            dhinit_soln = pickle.load(dhinit_file)
    else:
        logger.info('Running DarkHistory to generate initial conditions.')
        
        dhinit_soln = main.evolve(
            DM_process=dm_params.mode, mDM=dm_params.m_DM,
            sigmav=dm_params.sigmav, primary=dm_params.primary,
            struct_boost=phys.struct_boost_func(model=struct_boost_model),
            # DarkHistory's *_rs arguments are 1+z: stop at the first 21cmFAST redshift z_arr[0]
            start_rs=3000, end_rs=1+z_arr[0], coarsen_factor=12, verbose=1
        )
        # `with` guarantees the pickle is flushed and closed.
        with open(dhinit_fn, 'wb') as dhinit_file:
            pickle.dump(dhinit_soln, dhinit_file)

In [16]:
import copy  # to copy the cached DarkHistory photon spectrum before it is redshifted in place

# physics
cosmo_params_EMF = dict(OMm=0.32, OMb=0.049, POWER_INDEX=0.96, SIGMA_8=0.83, hlittle=0.67)

initial_conditions = p21c.initial_conditions(
    user_params = p21c.UserParams(
        HII_DIM = 50, # [1] cells per side
        BOX_LEN = 50, # [p-Mpc] NOTE(review): 21cmFAST box lengths are usually comoving Mpc; confirm
        N_THREADS = N_THREADS
    ),
    cosmo_params = p21c.CosmoParams(**cosmo_params_EMF),
    random_seed=54321, write=True
)

# box
box_dim = initial_conditions.user_params.HII_DIM

# recording
records = [] # one dict of box-averaged quantities per step with i_z > 0
input_time_tot = 0.
p21c_time_tot = 0.
if use_tqdm:
    pbar = tqdm(total=len(z_arr))
    
if save_slices:
    saved_slices = []
    i_slice = box_dim // 2 # middle slice along the first axis

## LOOP
# Each step builds the injection inputs from the boxes of the *previous* step,
# then advances 21cmFAST to z.
for i_z in range(len(z_arr)):

    z = z_arr[i_z]
    
    input_timer = time.time()
    
    if i_z == 0:
        # first step: no previous boxes, so nothing is injected
        spin_temp = None
        input_heating = input_ionization = input_jalpha = None
        
        ## load tfs
        phot_phot_tf = load_data('phot_phot', reload=reload)
        elec_phot_tf = load_data('elec_phot', reload=reload)
        phot_dep_tf = load_data('phot_dep', reload=reload)
        elec_dep_tf = load_data('elec_dep', reload=reload)
        
        ## fix DM in spec
        phot_phot_tf.set_fixed_in_spec(dm_params.inj_phot_spec.N)
        elec_phot_tf.set_fixed_in_spec(dm_params.inj_elec_spec.N)
        phot_dep_tf.set_fixed_in_spec(dm_params.inj_phot_spec.N)
        elec_dep_tf.set_fixed_in_spec(dm_params.inj_elec_spec.N)
        
        ## photon bath
        if dh_init_phot_bath:
            # Copy: the `redshift` call below modifies the spectrum in place, which would
            # otherwise alter dhinit_soln itself and leak into re-runs of this cell.
            phot_bath_next = copy.deepcopy(dhinit_soln['highengphot'][-1])
        else:
            phot_bath_next = Spectrum(
                abscs['photE'], np.zeros_like(abscs['photE']), spec_type='N', rs=1+z_arr[0]
            ) # [N per Bavg]
    
    else: ## input from second step
        
        input_heating = p21c.input_heating(redshift=z, init_boxes=initial_conditions, write=False)
        input_ionization = p21c.input_ionization(redshift=z, init_boxes=initial_conditions, write=False)
        input_jalpha = p21c.input_jalpha(redshift=z, init_boxes=initial_conditions, write=False)
        
        if i_z == 1 and dh_init_T_k:
            # shift the heating input by the gap between DarkHistory's T_k and 21cmFAST's mean T_k
            T_k_21cmfast = np.mean(spin_temp.Tk_box)
            T_k_DH = dhinit_soln['Tm'][-1] / phys.kB
            input_heating.input_heating += (T_k_DH - T_k_21cmfast)

        if run_mode == '':
            if i_z == 1:
                logger.warning('Not injecting anything in this run!')

        elif run_mode == 'inj':
            
            ## common (boxes from the previous 21cmFAST step)
            delta_box = jnp.asarray(perturbed_field.density)
            x_e_box = jnp.asarray(1 - ionized_box.xH_box)
            B_per_Bavg = 1 + delta_box
            
            if debug_uniform_injection:
                delta_box = jnp.full_like(delta_box, jnp.mean(delta_box))
                x_e_box = jnp.full_like(x_e_box, jnp.mean(x_e_box))
                B_per_Bavg = 1 + delta_box
            
            ## photon bath: redshift from previous step
            if enable_phot_bath:
                phot_bath_next.redshift(1+z)
                phot_bath = phot_bath_next
            
            ## photon bath -> photon
            if enable_phot_bath:
                out_phot_N = phot_phot_tf(
                    rs = 1 + z,
                    in_spec = phot_bath.N,
                    nBs_s = (1+delta_box).ravel(),
                    x_s = x_e_box.ravel(),
                    sum_result = True,
                    out_of_bounds_action = 'clip',
                ) / (box_dim ** 3) # [N / Bavg]

                phot_bath_next = Spectrum(
                    phot_phot_tf.abscs['out'],
                    out_phot_N,
                    spec_type='N', rs=1+z
                ) # [N / Bavg]
            
            ## photon bath -> deposition
            if enable_phot_bath:
                phot_bath_dep_box = phot_dep_tf(
                    rs = 1 + z,
                    in_spec = phot_bath.N,
                    nBs_s = (1+delta_box).ravel(),
                    x_s = x_e_box.ravel(),
                    sum_result = False,
                    out_of_bounds_action = 'clip',
                ).reshape(box_dim, box_dim, box_dim, len(abscs['dep_c'])) # [eV / Bavg]
            else:
                phot_bath_dep_box = np.zeros((box_dim, box_dim, box_dim, len(abscs['dep_c']))) # [eV / Bavg]
            
            ## DM: common
            z_prev = z_arr[i_z-1]
            dt = phys.dt_between_z(z_prev, z) # [s]
            if dm_params.mode == 'swave':
                struct_boost = phys.struct_boost_func(model=struct_boost_model)(1+z)
            else:
                struct_boost = 1
            
            rho_DM_box = (1 + delta_box) * phys.rho_DM * (1+z)**3 # [eV cm^-3]
            n_Bavg = phys.n_B * (1+z)**3 # [Bavg cm^-3]
            inj_per_Bavg_box = phys.inj_rate(rho_DM_box, dm_params) * dt * struct_boost / n_Bavg # [inj/Bavg]
            
            ## DM -> photon
            if enable_phot_bath:
                out_phot_N_phot = phot_phot_tf(
                    rs = 1 + z,
                    in_spec = dm_params.inj_phot_spec.N,
                    nBs_s = (1+delta_box).ravel(),
                    x_s = x_e_box.ravel(),
                    sum_result = True,
                    sum_weight = inj_per_Bavg_box.ravel(),
                    out_of_bounds_action = 'clip',
                ) / (box_dim ** 3) # [N / Bavg]

                out_phot_N_elec = elec_phot_tf(
                    rs = 1 + z,
                    in_spec = dm_params.inj_elec_spec.N,
                    nBs_s = (1+delta_box).ravel(),
                    x_s = x_e_box.ravel(),
                    sum_result = True,
                    sum_weight = inj_per_Bavg_box.ravel(),
                    out_of_bounds_action = 'clip',
                ) / (box_dim ** 3) # [N / Bavg]

                phot_bath_next += Spectrum(
                    phot_phot_tf.abscs['out'],
                    out_phot_N_phot + out_phot_N_elec,
                    spec_type='N', rs=1+z
                ) # [N / Bavg]
            
            ## DM -> deposition
            DM_phot_dep_box = phot_dep_tf(
                rs = 1 + z,
                in_spec = dm_params.inj_phot_spec.N,
                nBs_s = (1+delta_box).ravel(),
                x_s = x_e_box.ravel(),
                sum_result = False,
                out_of_bounds_action = 'clip',
            ).reshape(box_dim, box_dim, box_dim, len(abscs['dep_c'])) * inj_per_Bavg_box[...,None] # [eV / Bavg]

            DM_elec_dep_box = elec_dep_tf(
                rs = 1 + z,
                in_spec = dm_params.inj_elec_spec.N,
                nBs_s = (1+delta_box).ravel(),
                x_s = x_e_box.ravel(),
                sum_result = False,
                out_of_bounds_action = 'clip',
            ).reshape(box_dim, box_dim, box_dim, len(abscs['dep_c'])) * inj_per_Bavg_box[...,None] # [eV / Bavg]
            
            ## deposit energy
            dep_box = phot_bath_dep_box + DM_phot_dep_box + DM_elec_dep_box # [eV / Bavg]
            # last dimension: ('H ion', 'He ion', 'exc', 'heat', 'cont')

            input_heating.input_heating += np.array(
                2 / (3*phys.kB*(1+x_e_box)) * dep_box[...,3] / B_per_Bavg
            ) # [K/Bavg] / [B/Bavg] = [K/B]
            input_ionization.input_ionization += np.array(
                (dep_box[...,0] + dep_box[...,1]) / phys.rydberg / B_per_Bavg
            ) # [1/Bavg] / [B/Bavg] = [1/B]
            
            n_lya = dep_box[...,2] * n_Bavg / phys.lya_eng # [lya cm^-3]
            dnu_lya = (phys.rydberg - phys.lya_eng) / (2*np.pi*phys.hbar) # [Hz] (energy / h)
            J_lya = n_lya * phys.c / (4*np.pi) / dnu_lya # [lya cm^-2 s^-1 sr^-1 Hz^-1]
            input_jalpha.input_jalpha += np.array( J_lya )
            
            ## record
            dE_inj_per_Bavg = dm_params.eng_per_inj * np.mean(inj_per_Bavg_box) # [eV per Bavg]
            dE_inj_per_Bavg_unclustered = dE_inj_per_Bavg / struct_boost
            record_inj = {
                'dE_inj_per_B' : dE_inj_per_Bavg,
                'f_heat' : np.mean(dep_box[...,3]) / dE_inj_per_Bavg_unclustered,
                'f_ion'  : np.mean(dep_box[...,0] + dep_box[...,1]) / dE_inj_per_Bavg_unclustered,
                'f_exc'  : np.mean(dep_box[...,2]) / dE_inj_per_Bavg_unclustered,
            }
            
        else:
            raise ValueError(run_mode)
            
    input_time_tot += time.time() - input_timer

    ## step in 21cmFAST
    p21c_timer = time.time()
    perturbed_field = p21c.perturb_field( # perturbed_field controls the redshift
        redshift=z,
        init_boxes=initial_conditions
    )
    spin_temp = p21c.spin_temperature(
        perturbed_field=perturbed_field,
        previous_spin_temp=spin_temp,
        input_heating_box=input_heating,
        input_ionization_box=input_ionization,
        input_jalpha_box=input_jalpha,
        write=True
    )
    ionized_box = p21c.ionize_box(
        spin_temp=spin_temp
    )
    brightness_temp = p21c.brightness_temperature(
        ionized_box=ionized_box,
        perturbed_field=perturbed_field,
        spin_temp=spin_temp
    )
    p21c_time_tot += time.time() - p21c_timer

    ## save results
    if i_z > 0:
        record = {
            'z'   : z,
            'T_s' : np.mean(spin_temp.Ts_box), # [K]
            'T_b' : np.mean(brightness_temp.brightness_temp), # [mK]
            'T_k' : np.mean(spin_temp.Tk_box), # [K]
            'x_e' : np.mean(1 - ionized_box.xH_box), # [1]
        }
        if run_mode == 'inj':
            record.update(record_inj)
        records.append(record)
        
    if save_slices:
        saved_slices.append({
            'z'   : z,
            'T_s' : spin_temp.Ts_box[i_slice], # [K]
            'T_b' : brightness_temp.brightness_temp[i_slice], # [mK]
            'T_k' : spin_temp.Tk_box[i_slice], # [K]
            'x_e' : 1 - ionized_box.xH_box[i_slice], # [1]
            'delta' : perturbed_field.density[i_slice], # [1]
        })
        
    if use_tqdm:
        pbar.update()

# Close the progress bar so later bars and output render cleanly.
if use_tqdm:
    pbar.close()
    
print(f'input used {input_time_tot:.4f} s')
print(f'p21c used {p21c_time_tot:.4f} s')

# list of per-step dicts -> dict of per-quantity arrays
arr_records = {k: np.array([r[k] for r in records]) for k in records[0].keys()}
os.makedirs('../data/run_info', exist_ok=True)
np.save(f'../data/run_info/{RUN_NAME}_records', arr_records)
if save_slices:
    np.save(f'../data/run_info/{RUN_NAME}_slices', saved_slices)

  "The USE_INTERPOLATION_TABLES setting has changed in v3.1.2 to be "

  0%|          | 0/39 [00:00<?, ?it/s][A
  3%|▎         | 1/39 [00:07<05:01,  7.93s/it][A
  5%|▌         | 2/39 [00:14<04:28,  7.24s/it][A
  8%|▊         | 3/39 [00:23<04:45,  7.92s/it][A
 10%|█         | 4/39 [00:30<04:29,  7.69s/it][A
 13%|█▎        | 5/39 [00:37<04:06,  7.24s/it][A
 15%|█▌        | 6/39 [00:43<03:52,  7.06s/it][A
 18%|█▊        | 7/39 [00:49<03:31,  6.60s/it][A
 21%|██        | 8/39 [00:55<03:22,  6.53s/it][A
 23%|██▎       | 9/39 [01:01<03:03,  6.11s/it][A
 26%|██▌       | 10/39 [01:06<02:52,  5.93s/it][A
 28%|██▊       | 11/39 [01:12<02:41,  5.76s/it][A
 31%|███       | 12/39 [01:17<02:34,  5.72s/it][A
 33%|███▎      | 13/39 [01:23<02:27,  5.67s/it][A
 36%|███▌      | 14/39 [01:28<02:20,  5.63s/it][A
 38%|███▊      | 15/39 [01:36<02:29,  6.23s/it][A
 41%|████      | 16/39 [01:44<02:36,  6.79s/it][A
 44%|████▎     | 17/39 [01:50<02:26,  6.65s/it][A
 46%|████▌     | 18/39 [02:00

input used 175.5371 s
p21c used 106.4929 s


## 3. DarkHistory comparison

### 3.1 DM

In [21]:
# DarkHistory's `main` module; imported here so this cell does not depend on an earlier
# cell having imported it (re-importing an already-loaded module is a no-op).
import main

# Full DarkHistory run for comparison. end_rs is 1+z, so this runs to z = 5, slightly past
# the 21cmFAST z_end = 6; the records cell below trims the range.
soln = main.evolve(
    DM_process=dm_params.mode, mDM=dm_params.m_DM,
    sigmav=dm_params.sigmav, primary=dm_params.primary,
    struct_boost=phys.struct_boost_func(model=struct_boost_model),
    start_rs=3000, end_rs=6, coarsen_factor=12, verbose=1
)
# `with` guarantees the pickle is flushed and the file closed.
with open('../data/run_info/rundh_phph_mh3_soln.p', 'wb') as soln_file:
    pickle.dump(soln, soln_file)

Loading time: 0.023 s


  0%|          | 0/518 [00:00<?, ?it/s]

Initialization time: 0.008 s


100%|██████████| 518/518 [04:26<00:00,  1.95it/s]

Main loop time: 266.007 s





In [23]:
import darkhistory.physics as dhphys

In [28]:
# The 21cmFAST redshift grid and the duration of its final step, used by the
# injection cross-check cells that follow.
z_arr = get_z_arr(z_end=6.)
dt = phys.dt_between_z(*z_arr[-2:])

In [31]:
# DarkHistory's energy injected per average baryon during the last step [eV per Bavg].
rs_last = 1 + z_arr[-1]
(
    dhphys.inj_rate('swave', rs_last, mDM=1e10, sigmav=1e-23)
    * dt
    * dhphys.struct_boost_func(model='erfc')(rs_last)
    / (dhphys.nB * rs_last**3)
)

22.664122047774903

In [34]:
# The same quantity from dm21cm's physics module, for comparison with the DarkHistory value.
rs_last = 1 + z_arr[-1]
(
    phys.inj_rate(phys.rho_DM * rs_last**3, dm_params)
    * dm_params.eng_per_inj
    * dt
    * phys.struct_boost_func(model='erfc 1e-3')(rs_last)
    / (phys.n_B * rs_last**3)
)

22.66412204777491

In [35]:
# dm21cm 'erfc 1e-3' structure-formation boost factor at the last redshift of the grid.
phys.struct_boost_func(model='erfc 1e-3')(1+z_arr[-1])

5029.444766101746

In [22]:



# DarkHistory reference curves in the same format as the 21cmFAST records. Uses its own
# names (dh_records, z_dh_asc) so it does not overwrite the global `z_arr` used by the
# injection cross-check cells, nor the 21cmFAST `arr_records`.
dh_records = {
    'z' : soln['rs'] - 1,
    'x_e' : soln['x'][:,0],
    'T_k' : soln['Tm']/phys.kB,
    'f_heat' : soln['f']['low']['heat'] + soln['f']['high']['heat'],
    'f_ion' : soln['f']['low']['H ion']  + soln['f']['high']['H ion'] + \
              soln['f']['low']['He ion'] + soln['f']['high']['He ion'],
    'f_exc' : soln['f']['low']['exc'] + soln['f']['high']['exc'],
}

# Keep only z in [z_low, z_high), the range covered by the 21cmFAST run.
z_low, z_high = 6, 44
# Assumes soln['rs'] decreases along the run; reversing gives the ascending order
# np.searchsorted requires.
z_dh_asc = dh_records['z'][::-1]
i_low = np.searchsorted(z_dh_asc, z_low)
i_high = np.searchsorted(z_dh_asc, z_high)

dh_records = {k: v[::-1][i_low:i_high] for k, v in dh_records.items()}

np.save('../data/run_info/rundh_phph_mh3_records', dh_records)