In [1]:
%reload_ext autoreload
%autoreload 2

import os, sys
if os.environ['USER'] == 'yitians' and 'submit' in os.uname().nodename:
    os.environ['DM21CM_DATA_DIR'] = '/data/submit/yitians/dm21cm/DM21cm'
    os.environ['DH_DIR'] = '/work/submit/yitians/darkhistory/DarkHistory'
sys.path.append('..')
sys.path.append(os.environ['DH_DIR'])

from tqdm import tqdm
import time
import pickle

import numpy as np
import jax
import jax.numpy as jnp

import py21cmfast as p21c
from py21cmfast import plotting, cache_tools
print(f'Using 21cmFAST version {p21c.__version__}')

%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib import cm
mpl.rc_file('../matplotlibrc')



Using 21cmFAST version 0.1.dev1579+g6b1da6d.d20230421


In [2]:
import logging

# Notebook-level logger: INFO and above from '21cmFAST' are shown.
logger = logging.getLogger('21cmFAST')
logger.setLevel(logging.INFO)

# Silence these two py21cmfast loggers entirely by setting their threshold
# one above CRITICAL.
for _muted_name in ('py21cmfast._utils', 'py21cmfast.wrapper'):
    logging.getLogger(_muted_name).setLevel(logging.CRITICAL + 1)

In [3]:
from dm21cm.dm_params import DMParams
from dm21cm.data_loader import load_data
import dm21cm.physics as phys

from darkhistory.spec.spectrum import Spectrum # use branch numpy issue

/n/holyscratch01/iaifi_lab/yitians/darkhistory/DHdata_v1_1


In [4]:
# Display the devices JAX can see (sanity check of the accelerator setup).
jax.devices()

[StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0),
 StreamExecutorGpuDevice(id=1, process_index=0, slice_index=0)]

## 0. Global config

In [5]:
# Shell escape: show the CPU counts reported by `lscpu`, used to pick N_THREADS.
! lscpu | grep "CPU(s)"

CPU(s):                64
On-line CPU(s) list:   0-63
NUMA node0 CPU(s):     0-31
NUMA node1 CPU(s):     32-63


In [6]:
# Thread count handed to p21c.UserParams(N_THREADS=...) when the initial
# conditions are generated.
N_THREADS = 32

In [7]:
def get_z_arr(z_start=None, z_end=20, step_factor=None):
    """Build the descending redshift grid used to step the simulation.

    Starting from `z_end`, redshifts are generated upward with
    ``1 + z_next = (1 + z) * step_factor`` until one reaches or exceeds
    `z_start`. That overshooting entry is dropped, so every returned redshift
    is strictly below `z_start` and the last one is exactly `z_end`.

    Parameters
    ----------
    z_start : float, optional
        Upper bound of the grid. Defaults to `p21c.global_params.Z_HEAT_MAX`.
    z_end : float
        Final (lowest) redshift of the grid.
    step_factor : float, optional
        Multiplicative step in (1 + z). Defaults to
        `p21c.global_params.ZPRIME_STEP_FACTOR`.

    Returns
    -------
    numpy.ndarray
        Redshifts in descending order. Empty when `z_start <= z_end`.

    Raises
    ------
    ValueError
        If the grid has to grow (`z_end < z_start`) but cannot, i.e.
        `step_factor <= 1` or `1 + z_end <= 0`. Previously this looped forever.
    """
    if z_start is None:
        z_start = p21c.global_params.Z_HEAT_MAX
    if step_factor is None:
        step_factor = p21c.global_params.ZPRIME_STEP_FACTOR

    if z_end < z_start and (step_factor <= 1 or 1 + z_end <= 0):
        raise ValueError(
            f'Cannot step from z_end={z_end} up to z_start={z_start} '
            f'with step_factor={step_factor}.'
        )

    # The list is strictly increasing, so its last element is its maximum;
    # no need to rescan it with np.max on every iteration.
    z_arr = [z_end]
    while z_arr[-1] < z_start:
        z_arr.append((1 + z_arr[-1]) * step_factor - 1)

    # Reverse to descending order and drop the entry that is >= z_start.
    return np.array(z_arr[::-1][1:])

In [8]:
# check cached runs
# P21C_CACHE_DIR must be set in the environment; a KeyError is raised otherwise.
CACHE_DIR_BASE = os.environ['P21C_CACHE_DIR']
# Displayed as the cell output. Each run later uses the sub-directory
# f'{CACHE_DIR_BASE}/{run_name}' as its 21cmFAST cache.
os.listdir(CACHE_DIR_BASE)

['phph_dhinitPhTkX',
 'phph',
 'phph_dhinitPhTk',
 'phph_xray',
 'phph_dhinitPhTkX_new',
 'base']

## 1. Initialization

In [15]:
# config
run_name = 'base'                # sub-directory of CACHE_DIR_BASE and prefix of saved outputs
struct_boost_model = 'erfc 1e-3' # default: erfc 1e-3; passed to phys.struct_boost_func(model=...)
run_mode = '' # '' or 'inj'      # '' runs 21cmFAST without injection; 'inj' adds DM energy injection

dhinit_list = [] # 'phot', 'T_k', 'x_e'   # quantities to initialise from a DarkHistory solution
enable_phot_bath = False         # whether the photon bath is propagated between steps

debug_uniform_injection = False  # replace delta / x_e boxes by their means before injecting

reload = False                   # forwarded to load_data(..., reload=reload)
use_tqdm = True                  # show a progress bar for the redshift loop
save_slices = True               # keep one 2D slice of each box per redshift

# Abscissas are read from a local, project-provided pickle. pickle executes
# arbitrary code on load, so only point this at trusted files.
# The context manager closes the file handle that the bare open() call leaked.
with open('../data/abscissas/abscs_230408x.p', 'rb') as abscs_file:
    abscs = pickle.load(abscs_file)

# Per-run 21cmFAST cache directory.
p21c.config['direc'] = f'{CACHE_DIR_BASE}/{run_name}'
os.makedirs(p21c.config['direc'], exist_ok=True)

In [None]:
# Clear the py21cmfast cache before a fresh run (called with no arguments).
# NOTE(review): presumably this acts on p21c.config['direc'] set in the config
# cell — confirm against the cache_tools documentation before relying on it.
cache_tools.clear_cache()

## 2. Loop

In [16]:
# redshift
# The global step factor and Z_HEAT_MAX must be set before get_z_arr() reads them.
p21c.global_params.ZPRIME_STEP_FACTOR = 1.05
p21c.global_params.Z_HEAT_MAX = 44.
z_arr = get_z_arr(z_end=6.)
#p21c.global_params.CLUMPING_FACTOR = 1.

# dark matter
dm_params = DMParams(mode='swave', primary='phot_delta', m_DM=1e10, sigmav=1e-23)

# DarkHistory initial conditions are only needed when something is initialised from them.
if len(dhinit_list) > 0:

    # NOTE(review): the cache file name depends only on the run directory, not
    # on dm_params or struct_boost_model, so a stale solution is reused if
    # those change without changing run_name.
    dhinit_fn = f"{p21c.config['direc']}/dhinit_soln.p"

    if os.path.exists(dhinit_fn):
        # Context managers close the handles that the bare open() calls leaked.
        with open(dhinit_fn, 'rb') as dhinit_file:
            dhinit_soln = pickle.load(dhinit_file)
    else:
        logger.info('Running DarkHistory to generate initial conditions.')

        # Imported lazily: `main` is only importable via the DH_DIR path entry
        # and is only needed on a cache miss.
        import main
        # NOTE(review): end_rs receives z_arr[0], whereas the loop builds
        # spectra with rs=1+z_arr[0]; confirm whether end_rs should be
        # 1 + z_arr[0].
        dhinit_soln = main.evolve(
            DM_process=dm_params.mode, mDM=dm_params.m_DM,
            sigmav=dm_params.sigmav, primary=dm_params.primary,
            struct_boost=phys.struct_boost_func(model=struct_boost_model),
            start_rs=3000, end_rs=z_arr[0], coarsen_factor=12, verbose=1
        )
        with open(dhinit_fn, 'wb') as dhinit_file:
            pickle.dump(dhinit_soln, dhinit_file)

In [17]:
# physics
cosmo_params = {
    'OMm': 0.32,
    'OMb': 0.049,
    'POWER_INDEX': 0.96,
    'SIGMA_8': 0.83,
    'hlittle': 0.67,
}

# Generate (and write to the cache) the initial-condition boxes with a fixed seed.
initial_conditions = p21c.initial_conditions(
    user_params=p21c.UserParams(HII_DIM=50, BOX_LEN=50, N_THREADS=N_THREADS),
    cosmo_params=p21c.CosmoParams(**cosmo_params),
    random_seed=54321,
    write=True,
)

# box: number of cells per side
box_dim = initial_conditions.user_params.HII_DIM

# recording: per-redshift summaries and accumulated wall-clock timers
records = []
input_time_tot = 0.
p21c_time_tot = 0.

# One progress-bar tick per redshift step.
if use_tqdm:
    pbar = tqdm(total=len(z_arr), position=0)

# Slices are taken at the middle index along the first box axis.
if save_slices:
    saved_slices = []
    i_slice = int(box_dim / 2)

78it [05:51,  4.50s/it]                                                                      | 0/39 [00:00<?, ?it/s]


In [25]:
# xray
from dm21cm.xray.field_smoother import WindowedData

# Index range in the photon-energy abscissa covered by xray_eng_range.
xray_eng_range = (1e2, 1e4)
i_xray_fm = np.searchsorted(abscs['photE'], xray_eng_range[0])
i_xray_to = np.searchsorted(abscs['photE'], xray_eng_range[1])

xray_brightness_fn = p21c.config['direc'] + '/xray_brightness.h5'

# tmp remove
# Start from a clean file. Guarded so the cell also runs on a fresh cache
# directory, where the unconditional os.remove raised FileNotFoundError.
if os.path.exists(xray_brightness_fn):
    os.remove(xray_brightness_fn)

xray_brightness_boxes = WindowedData(
    data_path=xray_brightness_fn,
    N=box_dim, L=initial_conditions.user_params.BOX_LEN
)

In [18]:
## LOOP
# For every redshift in z_arr (descending), this loop
#   1. builds the 21cmFAST input boxes (heating, ionization, Lyman-alpha flux),
#      optionally adding dark-matter energy injection when run_mode == 'inj';
#   2. advances 21cmFAST one step (perturb_field -> spin_temperature ->
#      ionize_box -> brightness_temperature);
#   3. appends global means to `records` and, if save_slices, 2D slices to
#      `saved_slices`.
# NOTE: `records`, `saved_slices`, `i_slice`, `pbar` and the two timers are
# initialised in the previous cell. Re-run that cell before re-running this
# one, otherwise results accumulate and the progress bar over-counts.
for i_z in range(len(z_arr)):
#for i_z in range(2):

    z = z_arr[i_z]
    
    input_timer = time.time()
    
    if i_z == 0:
        # First step: no previous spin temperature and no input boxes.
        spin_temp = None
        input_heating = input_ionization = input_jalpha = None
        
        ## load tfs
        phot_phot_tf = load_data('phot_phot', reload=reload)
        #elec_phot_tf = load_data('elec_phot', reload=reload)
        phot_dep_tf = load_data('phot_dep', reload=reload)
        #elec_dep_tf = load_data('elec_dep', reload=reload)
        
        ## fix DM in spec
        phot_phot_tf.set_fixed_in_spec(dm_params.inj_phot_spec.N)
        #elec_phot_tf.set_fixed_in_spec(dm_params.inj_elec_spec.N)
        phot_dep_tf.set_fixed_in_spec(dm_params.inj_phot_spec.N)
        #elec_dep_tf.set_fixed_in_spec(dm_params.inj_elec_spec.N)
        
        ## photon bath
        # Start from the last DarkHistory photon spectrum if requested,
        # otherwise from an empty spectrum at the first redshift.
        if 'phot' in dhinit_list:
            phot_bath_next = dhinit_soln['highengphot'][-1]
        else:
            phot_bath_next = Spectrum(
                abscs['photE'], np.zeros_like(abscs['photE']), spec_type='N', rs=1+z_arr[0]
            ) # [N per Bavg]
    
    else: ## input from second step
        
        # Fresh input boxes for this redshift; they are filled in below.
        input_heating = p21c.input_heating(redshift=z, init_boxes=initial_conditions, write=False)
        input_ionization = p21c.input_ionization(redshift=z, init_boxes=initial_conditions, write=False)
        input_jalpha = p21c.input_jalpha(redshift=z, init_boxes=initial_conditions, write=False)
        
        # One-off shift at the second step so the box mean matches the
        # DarkHistory end state (uses spin_temp from the previous iteration).
        if i_z == 1 and 'T_k' in dhinit_list:
            T_k_21cmfast = np.mean(spin_temp.Tk_box)
            T_k_DH = dhinit_soln['Tm'][-1] / phys.kB
            input_heating.input_heating += (T_k_DH - T_k_21cmfast)
            
        # Same one-off shift for the ionization fraction (uses ionized_box
        # from the previous iteration).
        if i_z == 1 and 'x_e' in dhinit_list:
            x_e_21cmfast = np.mean(1 - ionized_box.xH_box)
            # NOTE(review): original label said "last step, HI"; DarkHistory
            # documents column 0 of 'x' as the ionized-hydrogen fraction — confirm.
            x_e_DH = dhinit_soln['x'][-1, 0] # last step, column 0
            input_ionization.input_ionization += (x_e_DH - x_e_21cmfast)

        # Warn once (at the second step) when nothing is injected.
        if i_z == 1 and run_mode == '':
            logger.warning('Not injecting anything in this run!')

        elif run_mode == 'inj':
            
            #========== common ==========
            # Fields from the previous iteration's 21cmFAST step.
            delta_box = jnp.asarray(perturbed_field.density)
            x_e_box = jnp.asarray(1 - ionized_box.xH_box)
            B_per_Bavg = 1 + delta_box
            
            # Debug option: replace both fields by their box means.
            if debug_uniform_injection:
                delta_box = jnp.full_like(delta_box, jnp.mean(delta_box))
                x_e_box = jnp.full_like(x_e_box, jnp.mean(x_e_box))
                B_per_Bavg = 1 + delta_box
            
            #========== photon bath ==========
            # redshift from previous step
            if enable_phot_bath:
                phot_bath_next.redshift(1+z)
                phot_bath = phot_bath_next
            
            #========== photon bath -> photon ==========
            # Box-averaged propagated spectrum (sum over cells / number of cells).
            if enable_phot_bath:
                out_phot_N = phot_phot_tf(
                    rs = 1 + z,
                    in_spec = phot_bath.N,
                    nBs_s = (1+delta_box).ravel(),
                    x_s = x_e_box.ravel(),
                    sum_result = True,
                    out_of_bounds_action = 'clip',
                ) / (box_dim ** 3) # [N / Bavg]

                phot_bath_next = Spectrum(
                    phot_phot_tf.abscs['out'],
                    out_phot_N,
                    spec_type='N', rs=1+z
                ) # [N / Bavg]
            
            #========== photon bath -> deposition ==========
            # Per-cell deposition from the bath; zeros when the bath is disabled.
            if enable_phot_bath:
                phot_bath_dep_box = phot_dep_tf(
                    rs = 1 + z,
                    in_spec = phot_bath.N,
                    nBs_s = (1+delta_box).ravel(),
                    x_s = x_e_box.ravel(),
                    sum_result = False,
                    out_of_bounds_action = 'clip',
                ).reshape(box_dim, box_dim, box_dim, len(abscs['dep_c'])) # [eV / Bavg]
            else:
                phot_bath_dep_box = np.zeros((box_dim, box_dim, box_dim, len(abscs['dep_c']))) # [eV / Bavg]
            
            #========== DM: common ==========
            z_prev = z_arr[i_z-1]
            dt = phys.dt_between_z(z_prev, z) # [s]
            # The structure-formation boost is only applied for s-wave annihilation.
            if dm_params.mode == 'swave':
                struct_boost = phys.struct_boost_func(model=struct_boost_model)(1+z)
            else:
                struct_boost = 1
            
            rho_DM_box = (1 + delta_box) * phys.rho_DM * (1+z)**3 # [eV cm^-3]
            n_Bavg = phys.n_B * (1+z)**3 # [Bavg cm^-3]
            inj_per_Bavg_box = phys.inj_rate(rho_DM_box, dm_params) * dt * struct_boost / n_Bavg # [inj/Bavg]
            
            #========== DM -> photon ==========
            # Injection-weighted, box-averaged photons added to the next bath.
            if enable_phot_bath:
                out_phot_N_phot = phot_phot_tf(
                    rs = 1 + z,
                    in_spec = dm_params.inj_phot_spec.N,
                    nBs_s = (1+delta_box).ravel(),
                    x_s = x_e_box.ravel(),
                    sum_result = True,
                    sum_weight = inj_per_Bavg_box.ravel(),
                    out_of_bounds_action = 'clip',
                ) / (box_dim ** 3) # [N / Bavg]

                # out_phot_N_elec = elec_phot_tf(
                #     rs = 1 + z,
                #     in_spec = dm_params.inj_elec_spec.N,
                #     nBs_s = (1+delta_box).ravel(),
                #     x_s = x_e_box.ravel(),
                #     sum_result = True,
                #     sum_weight = inj_per_Bavg_box.ravel(),
                #     out_of_bounds_action = 'clip',
                # ) / (box_dim ** 3) # [N / Bavg]

                phot_bath_next += Spectrum(
                    phot_phot_tf.abscs['out'],
                    out_phot_N_phot,# + out_phot_N_elec,
                    spec_type='N', rs=1+z
                ) # [N / Bavg]
            
            #========== DM -> deposition ==========
            # Per-cell deposition per injection, scaled by injections per cell.
            DM_phot_dep_box = phot_dep_tf(
                rs = 1 + z,
                in_spec = dm_params.inj_phot_spec.N,
                nBs_s = (1+delta_box).ravel(),
                x_s = x_e_box.ravel(),
                sum_result = False,
                out_of_bounds_action = 'clip',
            ).reshape(box_dim, box_dim, box_dim, len(abscs['dep_c'])) * inj_per_Bavg_box[...,None] # [eV / Bavg]

            # DM_elec_dep_box = elec_dep_tf(
            #     rs = 1 + z,
            #     in_spec = dm_params.inj_elec_spec.N,
            #     nBs_s = (1+delta_box).ravel(),
            #     x_s = x_e_box.ravel(),
            #     sum_result = False,
            #     out_of_bounds_action = 'clip',
            # ).reshape(box_dim, box_dim, box_dim, len(abscs['dep_c'])) * inj_per_Bavg_box[...,None] # [eV / Bavg]
            
            dep_box = phot_bath_dep_box + DM_phot_dep_box# + DM_elec_dep_box # [eV / Bavg]
            # last dimension: ('H ion', 'He ion', 'exc', 'heat', 'cont', 'xray')
            
            #========== xray ==========
            # remember to remove xray photon from bath
            # xray_brightness_boxes.set_field(dep_box[..., 5], z)
            # for i_shell_late in np.arange(2, i_z):
            #     i_shell_early = i_shell_late - 1
            #     z_early = z_arr[i_shell_early]
            #     z_late = z_arr[i_shell_late]
            #     xray_box_at_z = xray_brightness_boxes.get_smoothed_shell(
            #         phys.conformal_t_between(1+z, 1+z_early),
            #         phys.conformal_t_between(1+z, 1+z_late),
            #         z
            #     )

            #========== deposit energy in input_boxes ==========
            # Heat channel (index 3) converted to a temperature increment.
            input_heating.input_heating += np.array(
                2 / (3*phys.kB*(1+x_e_box)) * dep_box[...,3] / B_per_Bavg
            ) # [K/Bavg] / [B/Bavg] = [K/B]
            # H and He ionization channels (indices 0, 1) converted to a count
            # by dividing by phys.rydberg.
            input_ionization.input_ionization += np.array(
                (dep_box[...,0] + dep_box[...,1]) / phys.rydberg / B_per_Bavg
            ) # [1/Bavg] / [B/Bavg] = [1/B]
            
            # Excitation channel (index 2) converted to a Lyman-alpha intensity.
            n_lya = dep_box[...,2] * n_Bavg / phys.lya_eng # [lya cm^-3]
            # NOTE(review): originally labelled [Hz^-1]; an energy difference
            # divided by 2*pi*hbar is a frequency, i.e. [Hz] — confirm the
            # units of phys.hbar.
            dnu_lya = (phys.rydberg - phys.lya_eng) / (2*np.pi*phys.hbar) # [Hz] (see note above)
            J_lya = n_lya * phys.c / (4*np.pi) / dnu_lya # [lya cm^-2 s^-1 sr^-1 Hz^-1]
            input_jalpha.input_jalpha += np.array( J_lya )
            
            #========== record ==========
            # The f_* fractions are normalised by the injected energy without
            # the structure boost, while 'dE_inj_per_B' includes it.
            dE_inj_per_Bavg = dm_params.eng_per_inj * np.mean(inj_per_Bavg_box) # [eV per Bavg]
            dE_inj_per_Bavg_unclustered = dE_inj_per_Bavg / struct_boost
            record_inj = {
                'dE_inj_per_B' : dE_inj_per_Bavg,
                'f_heat' : np.mean(dep_box[...,3]) / dE_inj_per_Bavg_unclustered,
                'f_ion'  : np.mean(dep_box[...,0] + dep_box[...,1]) / dE_inj_per_Bavg_unclustered,
                'f_exc'  : np.mean(dep_box[...,2]) / dE_inj_per_Bavg_unclustered,
            }
            
    input_time_tot += time.time() - input_timer

    #========== step in 21cmFAST ==========
    p21c_timer = time.time()
    perturbed_field = p21c.perturb_field( # perturbed_field controls the redshift
        redshift=z,
        init_boxes=initial_conditions
    )
    # The input boxes built above (None on the first step) are passed in here.
    spin_temp = p21c.spin_temperature(
        perturbed_field=perturbed_field,
        previous_spin_temp=spin_temp,
        input_heating_box=input_heating,
        input_ionization_box=input_ionization,
        input_jalpha_box=input_jalpha,
        write=True
    )
    ionized_box = p21c.ionize_box(
        spin_temp=spin_temp
    )
    brightness_temp = p21c.brightness_temperature(
        ionized_box=ionized_box,
        perturbed_field=perturbed_field,
        spin_temp=spin_temp
    )
    # Bundle this step's boxes; `coeval` is not used further in this cell.
    coeval = p21c.Coeval(
        redshift = z,
        initial_conditions = initial_conditions,
        perturbed_field = perturbed_field,
        ionized_box = ionized_box,
        brightness_temp = brightness_temp,
        ts_box = spin_temp,
    )
    p21c_time_tot += time.time() - p21c_timer

    #========== save results ==========
    # Global means are recorded from the second step on (the first step has
    # no record_inj to merge in).
    # NOTE(review): the original unit labels were T_s [mK] and T_b [K];
    # 21cmFAST documents Ts_box in K and brightness_temp in mK — confirm.
    if i_z > 0:
        record = {
            'z'   : z,
            'T_s' : np.mean(spin_temp.Ts_box), # originally labelled [mK], see note
            'T_b' : np.mean(brightness_temp.brightness_temp), # originally labelled [K], see note
            'T_k' : np.mean(spin_temp.Tk_box), # [K]
            'x_e' : np.mean(1 - ionized_box.xH_box), # [1]
        }
        if run_mode == 'inj':
            record.update(record_inj)
        records.append(record)
        
    # Slices are saved for every step, including the first.
    if save_slices:
        saved_slices.append({
            'z'   : z,
            'T_s' : spin_temp.Ts_box[i_slice], # originally labelled [mK], see note
            'T_b' : brightness_temp.brightness_temp[i_slice], # originally labelled [K], see note
            'T_k' : spin_temp.Tk_box[i_slice], # [K]
            'x_e' : 1 - ionized_box.xH_box[i_slice], # [1]
            'delta' : perturbed_field.density[i_slice], # [1]
        })
        
    if use_tqdm:
        pbar.update()
        
#========== end of loop ==========
    
# Close the progress bar so it renders its final state exactly once.
if use_tqdm:
    pbar.close()

print(f'input used {input_time_tot:.4f} s')
print(f'p21c used {p21c_time_tot:.4f} s')

# Make sure the output directory exists before writing into it.
os.makedirs('../data/run_info', exist_ok=True)

# Turn the list of per-redshift dicts into a dict of arrays, one per key.
# With no records (a z_arr of length <= 1) this is an empty dict instead of
# an IndexError on records[0].
record_keys = records[0].keys() if records else ()
arr_records = {k: np.array([r[k] for r in records]) for k in record_keys}
# np.save stores these Python objects via pickle; load with allow_pickle=True.
np.save(f'../data/run_info/{run_name}_records', arr_records)
if save_slices:
    np.save(f'../data/run_info/{run_name}_slices', saved_slices)

100%|███████████████████████████████████████████████████████████████████████████████| 39/39 [00:05<00:00, 15.98it/s]

input used 0.7853 s
p21c used 2.5925 s


100%|███████████████████████████████████████████████████████████████████████████████| 39/39 [00:18<00:00, 15.98it/s]

## 3. Lightcone

In [29]:
import powerbox
# import importlib

# importlib.reload(powerbox)

In [39]:
# Run a lightcone down to the last redshift of z_arr, reusing the parameter
# objects attached to the final spin_temp box from the loop.
# NOTE(review): no input heating / ionization / jalpha boxes are passed here,
# so presumably this lightcone does not include the injected energy — confirm.
lightcone = p21c.run_lightcone(
    redshift = z_arr[-1],
    user_params  = spin_temp.user_params,
    cosmo_params = spin_temp.cosmo_params,
    astro_params = spin_temp.astro_params,
    flag_options = spin_temp.flag_options,
)

In [None]:
# Slice plot of the lightcone, with the axis aspect set to 10.
fig, ax = plt.subplots(figsize = (15, 10))
plotting.lightcone_sliceplot(lightcone, fig=fig, ax=ax)
# Trailing semicolon keeps the return value of ax.set(...) out of the cell output.
ax.set(aspect=10);

In [40]:
def compute_power(
    box,
    length,
    n_psbins,
    log_bins=True,
    ignore_kperp_zero=True,
    ignore_kpar_zero=False,
    ignore_k_zero=False,
):
    """Compute the binned power spectrum of a 3D box with powerbox.

    Parameters
    ----------
    box : array, 3D
        Field to transform. The last axis is treated as the line-of-sight
        (k_parallel) axis and the first two as the transverse axes.
    length : sequence
        Box length along each axis, forwarded as `boxlength`.
    n_psbins : int
        Number of k bins, forwarded as `bins`.
    log_bins : bool
        Use logarithmic bins. Also selects geometric (True) or arithmetic
        (False) mid-points for the returned k values.
    ignore_kperp_zero : bool
        Zero the weight of all modes at the central index of both transverse axes.
    ignore_kpar_zero : bool
        Zero the weight of all modes at the central index of the last axis.
    ignore_k_zero : bool
        Zero the weight of the single central mode.

    Returns
    -------
    list
        The result of `powerbox.tools.get_power` as a list, with element 1
        (the bin edges, since `bin_ave=False`) replaced by bin mid-points.
    """
    # Determine the weighting function required from ignoring k's.
    k_weights = np.ones(box.shape, int)
    # Central index of each axis. The second transverse axis uses its own
    # length, so boxes that are not square in the transverse plane are
    # handled correctly (previously shape[0] was used for both).
    mid_perp0 = k_weights.shape[0] // 2
    mid_perp1 = k_weights.shape[1] // 2
    mid_par = k_weights.shape[-1] // 2

    if ignore_kperp_zero:
        k_weights[mid_perp0, mid_perp1, :] = 0
    if ignore_kpar_zero:
        k_weights[:, :, mid_par] = 0
    if ignore_k_zero:
        k_weights[mid_perp0, mid_perp1, mid_par] = 0

    res = powerbox.tools.get_power(
        box,
        boxlength=length,
        bins=n_psbins,
        bin_ave=False,
        get_variance=False,
        log_bins=log_bins,
        k_weights=k_weights,
    )

    res = list(res)
    k = res[1]
    # Convert bin edges to bin centres in the space the bins are uniform in.
    if log_bins:
        k = np.exp((np.log(k[1:]) + np.log(k[:-1])) / 2)
    else:
        k = (k[1:] + k[:-1]) / 2

    res[1] = k
    return res

def powerspectra(brightness_temp, n_psbins=50, nchunks=20, min_k=0.1, max_k=1.0, logk=True):
    """Compute dimensionless power spectra in chunks along a lightcone.

    The lightcone is cut along its last axis into at most `nchunks`
    contiguous chunks that together cover every slice, and `compute_power`
    is called on each chunk.

    Parameters
    ----------
    brightness_temp : lightcone object
        Must provide `brightness_temp` (3D array), `n_slices`, `cell_size`
        and `user_params.BOX_LEN`.
    n_psbins : int
        Number of k bins per chunk.
    nchunks : int
        Requested number of chunks. Fewer are returned when there are not
        enough slices.
    min_k, max_k : float
        Currently unused; kept so existing callers keep working.
    logk : bool
        Use logarithmic k bins.

    Returns
    -------
    list of dict
        One ``{"k": ..., "delta": ...}`` per chunk, where
        ``delta = P(k) * k**3 / (2 * pi**2)``.
    """
    n_slices = brightness_temp.n_slices
    # Transverse box size taken from the lightcone itself rather than from a
    # notebook-level BOX_LEN that may be undefined or stale.
    box_len = brightness_temp.user_params.BOX_LEN

    # Chunk start indices. The step is at least 1 so range() cannot be given a
    # zero step, and at most `nchunks` starts are kept so that the final chunk
    # absorbs any remainder instead of dropping it.
    step = max(1, round(n_slices / nchunks))
    chunk_indices = list(range(0, n_slices, step))[:nchunks]
    chunk_indices.append(n_slices)

    data = []
    # Iterate over the chunks that actually exist, which may be fewer than
    # `nchunks`, instead of indexing a fixed range(nchunks).
    for start, end in zip(chunk_indices[:-1], chunk_indices[1:]):
        chunklen = (end - start) * brightness_temp.cell_size

        power, k = compute_power(
            brightness_temp.brightness_temp[:, :, start:end],
            (box_len, box_len, chunklen),
            n_psbins,
            log_bins=logk,
        )
        data.append({"k": k, "delta": power * k ** 3 / (2 * np.pi ** 2)})
    return data

In [41]:
# Box geometry used for the power-spectrum cells; these repeat the values
# given to p21c.UserParams when the initial conditions were generated.
BOX_LEN, HII_DIM = 50, 50

# Fundamental mode of the box and the largest k considered (HII_DIM times it).
k_fundamental = (2 * np.pi) / BOX_LEN
k_max = HII_DIM * k_fundamental
# Number of k bins, as a numpy integer.
Nk = np.floor(HII_DIM / 1).astype(int)

In [42]:
# Chunked power spectra of the lightcone (list of {"k", "delta"} dicts).
# Note: powerspectra accepts min_k / max_k but does not use them.
out = powerspectra(lightcone, min_k=k_fundamental, max_k=k_max)

In [43]:
# Save the power spectra next to the other run outputs. The context manager
# flushes and closes the file, which the bare open() call did not guarantee.
with open(f'../data/run_info/{run_name}_ps', 'wb') as ps_file:
    pickle.dump(out, ps_file)