In [None]:
%reload_ext autoreload
%autoreload 2

import os
import sys
sys.path.append("..")
sys.path.append(os.environ['DH_DIR'])

from tqdm import tqdm
import time
import pickle
import h5py
import logging
import warnings

import numpy as np
from scipy import signal, ndimage, stats, interpolate, integrate
from astropy import cosmology, constants, units
from astropy.cosmology import Planck18
import jax
import jax.numpy as jnp
print(jax.devices())

%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib import cm
mpl.rc_file('../matplotlibrc')

In [None]:
import py21cmfast as p21c
from py21cmfast import plotting, cache_tools
print(f'Using 21cmFAST version {p21c.__version__}')
logger = logging.getLogger('21cmFAST')
logger.setLevel(logging.INFO)
logging.getLogger('py21cmfast._utils').setLevel(logging.CRITICAL+1)
logging.getLogger('py21cmfast.wrapper').setLevel(logging.CRITICAL+1)

from dm21cm.dm_params import DMParams
from dm21cm.data_loader import load_data
import dm21cm.physics as phys
from XRay_Development.field_smoother import WindowedData # move into dm21cm?

from darkhistory.spec.spectrum import Spectrum # use branch numpy issue
#from dm21cm.spec import Spectrum

## 0. Global config

In [None]:
# Show the CPU count, to help choose N_THREADS for 21cmFAST.
! lscpu | grep "CPU(s)"

In [None]:
N_THREADS = 32 # number of threads passed to 21cmFAST via UserParams(N_THREADS=...)

In [None]:
def get_z_arr(z_start=None, z_end=20, step_factor=None):
    """Decreasing redshift grid ending at `z_end`, with constant ratio in (1+z).

    The grid is built backwards from `z_end`: each new point satisfies
    (1 + z_new) = (1 + z_old) * step_factor. Building stops once a point reaches
    or exceeds `z_start`, and that last point is dropped. All returned redshifts
    are therefore < z_start.

    Parameters
    ----------
    z_start : float, optional
        Upper redshift limit. Defaults to p21c.global_params.Z_HEAT_MAX.
    z_end : float
        Lowest redshift; always the last element when the grid is non-empty.
    step_factor : float, optional
        Ratio between consecutive (1+z) values. Defaults to
        p21c.global_params.ZPRIME_STEP_FACTOR.

    Returns
    -------
    np.ndarray
        Redshifts in decreasing order. Empty if z_end >= z_start.

    Raises
    ------
    ValueError
        If step_factor <= 1 while z_end < z_start (the grid would never grow).
    """
    if z_start is None:
        z_start = p21c.global_params.Z_HEAT_MAX
    if step_factor is None:
        step_factor = p21c.global_params.ZPRIME_STEP_FACTOR
    if step_factor <= 1 and z_end < z_start:
        raise ValueError(f'step_factor must be > 1 to reach z_start, got {step_factor}')

    z_arr = [z_end]
    # The grid is increasing while being built, so the last element is the maximum.
    while z_arr[-1] < z_start:
        z_arr.append((1 + z_arr[-1]) * step_factor - 1)
    # Reverse to decreasing order and drop the first point, which is >= z_start.
    return np.array(z_arr[::-1][1:])

In [None]:
# check cached runs
CACHE_DIR_BASE = os.environ['P21C_CACHE_DIR']
os.listdir(CACHE_DIR_BASE)

## 1. Initialization

In [None]:
"""Docstring for all the switches.

========== run config ==========
run_name : str
    Name of the run; used for the 21cmFAST cache subdirectory and output file names.
use_tqdm : bool
    Whether to show a tqdm progress bar for the main loop.
save_slices : bool
    Whether to save slices of T_s, T_b, T_k, x_e and delta from the run.

========== physics ==========
struct_boost_model : {'erfc 1e-3', 'erfc 1e-6', 'erfc 1e-9'}
run_mode : {'no inj', 'bath', 'xray'}
    'no inj' runs plain 21cmFAST; 'bath' treats every photon as part of the uniform bath;
    'xray' treats xray photons separately.
dhinit_list : list
    Which variables to initialize with DarkHistory; can include 'phot', 'T_k', 'x_e'.

========== debug ==========
DEBUG_ENABLE_ELEC : bool
    Electron transfer functions are not yet implemented; keep False.
DEBUG_RELOAD_TF : bool
    Whether to reload the DarkHistory transfer functions.
"""

run_name = 'base'
use_tqdm = True
save_slices = True

struct_boost_model = 'erfc 1e-3'
run_mode = 'bath'
dhinit_list = ['phot', 'T_k', 'x_e']

DEBUG_ENABLE_ELEC = False
DEBUG_RELOAD_TF = False

# Fail early on typos; the main loop compares run_mode against these exact strings.
assert run_mode in ('no inj', 'bath', 'xray'), f'unknown run_mode {run_mode!r}'
assert set(dhinit_list) <= {'phot', 'T_k', 'x_e'}, f'unknown dhinit_list entries {dhinit_list!r}'

In [None]:
# Energy abscissas shared by the transfer functions and the photon spectra.
# `with` closes the file handle (the bare open() inside pickle.load leaked it).
with open('../data/abscissas/abscs_230408x.p', 'rb') as fh:
    abscs = pickle.load(fh)
photeng = abscs['photE']
eleceng = abscs['elecE']

# Per-run 21cmFAST cache directory.
p21c.config['direc'] = f'{CACHE_DIR_BASE}/{run_name}'
os.makedirs(p21c.config['direc'], exist_ok=True)

In [None]:
# Clear 21cmFAST's cached outputs before the run.
# NOTE(review): confirm which directory cache_tools.clear_cache() targets by default
# (presumably p21c.config['direc'], set in the previous cell) and whether it also touches dhinit_soln.p.
cache_tools.clear_cache()

## 2. Loop

In [16]:
# redshift steps
p21c.global_params.ZPRIME_STEP_FACTOR = 1.05
p21c.global_params.Z_HEAT_MAX = 44.
z_edges = get_z_arr(z_end=6.) # decreasing redshifts, ending at z_end
# NOTE(review): geometric mean of z rather than of (1+z) -- confirm which midpoint is intended.
z_mids = np.sqrt(z_edges[1:] * z_edges[:-1])
z_dh_stops = z_edges[1] # the DarkHistory initial-condition run ends at rs = 1 + z_dh_stops
#p21c.global_params.CLUMPING_FACTOR = 1.

# dark matter
dm_params = DMParams(mode='swave', primary='phot_delta', m_DM=1e10, sigmav=1e-23)

# DarkHistory
if len(dhinit_list) > 0:
    # NOTE(review): the cache file name does not encode dm_params, struct_boost_model or
    # z_dh_stops, so changing them without deleting this file reuses a stale solution.
    dhinit_fn = f"{p21c.config['direc']}/dhinit_soln.p"
    
    if os.path.exists(dhinit_fn):
        with open(dhinit_fn, 'rb') as fh:
            dhinit_soln = pickle.load(fh)
    else:
        logger.info('Running DarkHistory to generate initial conditions.')
        
        import main
        dhinit_soln = main.evolve(
            DM_process=dm_params.mode, mDM=dm_params.m_DM,
            sigmav=dm_params.sigmav, primary=dm_params.primary,
            struct_boost=phys.struct_boost_func(model=struct_boost_model),
            start_rs=3000, end_rs=(1+z_dh_stops), coarsen_factor=12, verbose=1
        )
        # `with` guarantees the cache file is flushed and closed.
        with open(dhinit_fn, 'wb') as fh:
            pickle.dump(dhinit_soln, fh)

In [None]:
# 21cmFAST
initial_conditions = p21c.initial_conditions(
    user_params = p21c.UserParams(
        HII_DIM = 50, # [1]
        BOX_LEN = 50, # [p-Mpc]
        N_THREADS = N_THREADS
    ),
    cosmo_params = p21c.CosmoParams(
        OMm=0.32,
        OMb=0.049,
        POWER_INDEX=0.96,
        SIGMA_8=0.83,
        hlittle=0.67
    ),
    random_seed=54321, write=True
)
box_dim = initial_conditions.user_params.HII_DIM

# recording
records = []
input_time_tot = 0.
p21c_time_tot = 0.
if use_tqdm:
    pbar = tqdm(total=len(z_edges)-1, position=0)
    
if save_slices:
    saved_slices = []
    i_slice = int(box_dim/2)

In [25]:
# xray
ex_lo, ex_hi = 1e2, 1e4 # [eV]
ix_lo = np.searchsorted(photeng, ex_lo) # i of first bin greater than ex_lo, excluded
ix_hi = np.searchsorted(photeng, ex_hi) # i of first bin greater than ex_hi, included

def split_xray(phot_N, i_lo=None, i_hi=None):
    """Split a photon spectrum into a bath part and an xray-band part.

    Parameters
    ----------
    phot_N : np.ndarray
        Photon numbers per energy bin. It is not modified.
    i_lo, i_hi : int, optional
        Bin range [i_lo, i_hi) forming the xray band. Default to the
        notebook-level ix_lo / ix_hi at call time.

    Returns
    -------
    (bath_N, xray_N) : tuple of np.ndarray
        bath_N is phot_N with the band zeroed. xray_N is phot_N with everything
        outside the band zeroed. For finite input, bath_N + xray_N == phot_N.
    """
    if i_lo is None:
        i_lo = ix_lo
    if i_hi is None:
        i_hi = ix_hi
    bath_N = phot_N.copy()
    xray_N = phot_N.copy()
    bath_N[i_lo:i_hi] *= 0
    xray_N[:i_lo] *= 0
    xray_N[i_hi:] *= 0
    return bath_N, xray_N

xray_brightness_fn = p21c.config['direc'] + '/xray_brightness.h5'

# tmp: start every run from an empty xray brightness file. The existence check
# avoids a FileNotFoundError on a fresh cache directory where the file does not exist yet.
if os.path.exists(xray_brightness_fn):
    os.remove(xray_brightness_fn)

# NOTE(review): cosmo=Planck18 differs from the CosmoParams given to 21cmFAST
# (e.g. OMm=0.32, hlittle=0.67) -- confirm this mismatch is acceptable.
xray_windowed_data = WindowedData(
    data_path=xray_brightness_fn,
    cosmo=Planck18,
    N=initial_conditions.user_params.HII_DIM,
    dx=initial_conditions.user_params.BOX_LEN / initial_conditions.user_params.HII_DIM,
    cache=True,
)

### TODO
- [ ] Produce phot_prop_tf, phot_scat_tf, and phot_phot_tf = phot_prop_tf + phot_scat_tf.
- [ ] Produce phot_dep_tf with xray bin only including scattered xray.
- [ ] Build attenuator.

In [18]:
#========== main loop ==========
# Iteration i_z advances 21cmFAST to z_edges[i_z]. For i_z > 0 the step passes
# through z_mids[i_z-1]. Energy injected over that step is added to 21cmFAST's
# input heating / ionization / Lyman-alpha boxes before the step is taken.
for i_z in range(len(z_edges)):
    
    input_timer = time.time()
    z = z_edges[i_z]
    
    if i_z == 0:
        
        # At this step we will arrive at z_edges[0], so z_mid is not defined yet.
        spin_temp = None
        input_heating = input_ionization = input_jalpha = None
        
        #----- load tfs -----
        phot_phot_tf = load_data('phot_phot', reload=DEBUG_RELOAD_TF)
        phot_dep_tf = load_data('phot_dep', reload=DEBUG_RELOAD_TF)
        if DEBUG_ENABLE_ELEC:
            elec_phot_tf = load_data('elec_phot', reload=DEBUG_RELOAD_TF)
            elec_dep_tf = load_data('elec_dep', reload=DEBUG_RELOAD_TF)
        
        #----- initialize DM in_spec -----
        phot_phot_tf.set_fixed_in_spec(dm_params.inj_phot_spec.N)
        phot_dep_tf.set_fixed_in_spec(dm_params.inj_phot_spec.N)
        if DEBUG_ENABLE_ELEC:
            elec_phot_tf.set_fixed_in_spec(dm_params.inj_elec_spec.N)
            elec_dep_tf.set_fixed_in_spec(dm_params.inj_elec_spec.N)
        
        #----- initialize photon bath -----
        # note that photon bath also includes all uniform xray contributions
        if 'phot' in dhinit_list:
            # NOTE(review): DarkHistory was run to end_rs = 1 + z_dh_stops (= 1 + z_edges[1]),
            # but the spectrum is labelled rs = 1 + z_edges[0] here -- confirm the intended label.
            dh_spec = dhinit_soln['highengphot'][-1] # [N per Bavg]
            phot_bath_spec = Spectrum(dh_spec.eng, dh_spec.N, rs=1+z_edges[0], spec_type='N')
        else:
            phot_bath_spec = Spectrum(photeng, np.zeros_like(photeng), rs=1+z_edges[0], spec_type='N') # [N per Bavg]
    
    else: # input from second step
        
        z_mid = z_mids[i_z-1] # At this step we will arrive at z_edges[i], passing through z_mids[i-1].
        
        input_heating = p21c.input_heating(redshift=z, init_boxes=initial_conditions, write=False)
        input_ionization = p21c.input_ionization(redshift=z, init_boxes=initial_conditions, write=False)
        input_jalpha = p21c.input_jalpha(redshift=z, init_boxes=initial_conditions, write=False)
        
        if i_z == 1:
            # spin_temp is still the result at z_edges[0]; shift its global means to DarkHistory's values.
            if 'T_k' in dhinit_list:
                T_k_DH = dhinit_soln['Tm'][-1] / phys.kB # [K]
                spin_temp.Tk_box +=  T_k_DH - np.mean(spin_temp.Tk_box)
                # T_k_21cmfast = np.mean(spin_temp.Tk_box)
                # input_heating.input_heating += (T_k_DH - T_k_21cmfast) # old method of adjusting input boxes
            
            if 'x_e' in dhinit_list:
                # Last DarkHistory step, column 0. The original note called this 'HI', but it is
                # matched against x_e below -- NOTE(review): confirm it is the ionized fraction.
                x_e_DH = dhinit_soln['x'][-1, 0]
                # Fixed: previously referenced the undefined name dh_xe_global.
                spin_temp.x_e_box +=  x_e_DH - np.mean(spin_temp.x_e_box)
                # x_e_21cmfast = np.mean(1 - ionized_box.xH_box)
                # input_ionization.input_ionization += (x_e_DH - x_e_21cmfast) # old method of adjusting input boxes

        if run_mode == 'no inj':
            if i_z == 1:
                logger.warning('Not injecting anything in this run!')

        else:
            #========== calculate some quantities ==========
            z_prev = z_edges[i_z-1]
            dt = phys.dt_between_z(z_prev, z) # [s]
            if dm_params.mode == 'swave':
                struct_boost = phys.struct_boost_func(model=struct_boost_model)(1+z_mid)
            else:
                struct_boost = 1
                
            n_Bavg = phys.n_B * (1+z_mid)**3 # [Bavg cm^-3]
            
            # perturbed_field and ionized_box come from the previous iteration (z_edges[i_z-1]).
            delta_box = jnp.asarray(perturbed_field.density)
            B_per_Bavg = 1 + delta_box
            rho_DM_box = (1 + delta_box) * phys.rho_DM * (1+z_mid)**3 # [eV cm^-3]
            x_e_box = jnp.asarray(1 - ionized_box.xH_box)
            inj_per_Bavg_box = phys.inj_rate(rho_DM_box, dm_params) * dt * struct_boost / n_Bavg # [inj/Bavg]
            
            tf_kwargs = dict(
                rs = 1 + z_mid,
                nBs_s = (1+delta_box).ravel(),
                x_s = x_e_box.ravel(),
                out_of_bounds_action = 'clip',
            )
            
            #========== initialize ==========
            prop_phot_N = np.zeros_like(photeng) # [N / Bavg]
            emit_phot_N = np.zeros_like(photeng) # [N / Bavg]
            # Fixed: np.zeros takes the shape as a single tuple.
            dep_box = np.zeros((box_dim, box_dim, box_dim, len(abscs['dep_c'])))
            # last dimension: ('H ion', 'He ion', 'exc', 'heat', 'cont', 'xray')
            
            #========== photon bath -> prop emit dep ==========
            # TODO: phot_prop_tf and phot_scat_tf are not defined in this notebook yet (see the TODO list above).
            prop_phot_N += phot_prop_tf(
                in_spec=phot_bath_spec.N, sum_result=True, **tf_kwargs,
            ) / (box_dim ** 3) # [N / Bavg]
            
            emit_phot_N += phot_scat_tf(
                in_spec=phot_bath_spec.N, sum_result=True, **tf_kwargs,
            ) / (box_dim ** 3) # [N / Bavg]
            
            dep_box += phot_dep_tf(
                in_spec=phot_bath_spec.N, sum_result=False, **tf_kwargs,
            ).reshape(dep_box.shape) # [eV / Bavg]
            
            #========== DM (prompt phot+elec) -> emit dep ==========
            emit_phot_N += phot_phot_tf(
                in_spec=dm_params.inj_phot_spec.N, sum_result=True, sum_weight=inj_per_Bavg_box.ravel(), **tf_kwargs,
            ) / (box_dim ** 3) # [N / Bavg]

            dep_box += phot_dep_tf(
                in_spec=dm_params.inj_phot_spec.N, sum_result=False, **tf_kwargs,
            ).reshape(dep_box.shape) * inj_per_Bavg_box[..., None] # [eV / Bavg]

            if DEBUG_ENABLE_ELEC:
                emit_phot_N += elec_phot_tf(
                    in_spec=dm_params.inj_elec_spec.N, sum_result=True, sum_weight=inj_per_Bavg_box.ravel(), **tf_kwargs,
                ) / (box_dim ** 3) # [N / Bavg]
                
                dep_box += elec_dep_tf(
                    in_spec=dm_params.inj_elec_spec.N, sum_result=False, **tf_kwargs,
                ).reshape(dep_box.shape) * inj_per_Bavg_box[..., None] # [eV / Bavg]
            
            #========== emitted xray (difference) -> emit dep ==========
            if run_mode == 'xray':
                # TODO: work in progress -- z_receiver, z_evals, lookback_index and the
                # xray_spec / xray_e_box placeholders below are not defined yet.
                for i_z_shell in range(2, i_z):
                    #-----[EDIT]-----
                    shell_smoothed, shell_spec = xray_windowed_data.get_smoothed_shell(
                        z_receiver,
                        z_evals[lookback_index],
                        z_evals[lookback_index+1],
                    )
                    xray_spec = ... # unit: [photon / Bavg]
                    # [HERE] redshift xray_spec, apply attenuation
                    # xray_mean_eng = xray_spec.toteng() # unit: [eV / Bavg]
                    xray_e_box = ... # = local xray band energy / mean xray band energy
                    #-----[EDIT]-----

                    emit_phot_N += phot_scat_tf(
                        in_spec=xray_spec.N, sum_result=True, sum_weight=xray_e_box.ravel(), **tf_kwargs,
                    ) / (box_dim ** 3) # [N / Bavg]
                    
                    dep_box += phot_dep_tf(
                        in_spec=xray_spec.N, sum_result=False, **tf_kwargs,
                    ).reshape(dep_box.shape) * xray_e_box[..., None] # [eV / Bavg] # CHECK SHAPE HERE
            else:
                pass # do nothing, because xray will be in photon bath already

            #========== update input_boxes ==========
            input_heating.input_heating += np.array(
                2 / (3*phys.kB*(1+x_e_box)) * dep_box[...,3] / B_per_Bavg
            ) # [K/Bavg] / [B/Bavg] = [K/B]
            input_ionization.input_ionization += np.array(
                (dep_box[...,0] + dep_box[...,1]) / phys.rydberg / B_per_Bavg
            ) # [1/Bavg] / [B/Bavg] = [1/B]
            
            n_lya = dep_box[...,2] * n_Bavg / phys.lya_eng # [lya cm^-3]
            dnu_lya = (phys.rydberg - phys.lya_eng) / (2*np.pi*phys.hbar) # [Hz], i.e. eV / (eV s) assuming phys.hbar is in eV s
            J_lya = n_lya * phys.c / (4*np.pi) / dnu_lya # [lya cm^-2 s^-1 sr^-1 Hz^-1]
            input_jalpha.input_jalpha += np.array(J_lya)
            
            #========== record ==========
            dE_inj_per_Bavg = dm_params.eng_per_inj * np.mean(inj_per_Bavg_box) # [eV per Bavg]
            dE_inj_per_Bavg_unclustered = dE_inj_per_Bavg / struct_boost
            record_inj = {
                'dE_inj_per_B' : dE_inj_per_Bavg,
                'f_heat' : np.mean(dep_box[...,3]) / dE_inj_per_Bavg_unclustered,
                'f_ion'  : np.mean(dep_box[...,0] + dep_box[...,1]) / dE_inj_per_Bavg_unclustered,
                'f_exc'  : np.mean(dep_box[...,2]) / dE_inj_per_Bavg_unclustered,
            }
            
    input_time_tot += time.time() - input_timer

    #========== step in 21cmFAST ==========
    p21c_timer = time.time()
    perturbed_field = p21c.perturb_field( # perturbed_field controls the redshift
        redshift=z,
        init_boxes=initial_conditions
    )
    spin_temp = p21c.spin_temperature(
        perturbed_field=perturbed_field,
        previous_spin_temp=spin_temp,
        input_heating_box=input_heating,
        input_ionization_box=input_ionization,
        input_jalpha_box=input_jalpha,
        write=True
    )
    ionized_box = p21c.ionize_box(
        spin_temp=spin_temp
    )
    brightness_temp = p21c.brightness_temperature(
        ionized_box=ionized_box,
        perturbed_field=perturbed_field,
        spin_temp=spin_temp
    )
    coeval = p21c.Coeval(
        redshift = z,
        initial_conditions = initial_conditions,
        perturbed_field = perturbed_field,
        ionized_box = ionized_box,
        brightness_temp = brightness_temp,
        ts_box = spin_temp,
    )
    p21c_time_tot += time.time() - p21c_timer
    
    #========== prepare for next step ==========
    # prop_phot_N / emit_phot_N exist only after an injection step (i_z > 0 and
    # run_mode != 'no inj'). Without this guard the section raised NameError on a
    # fresh kernel at i_z == 0, or silently used stale arrays from a previous run
    # and overwrote the initial photon bath.
    if i_z > 0 and run_mode != 'no inj':
        if run_mode == 'xray':
            emit_bath_N, emit_xray_N = split_xray(emit_phot_N)
            out_phot_N = prop_phot_N + emit_bath_N
            
            emit_xray_spec = Spectrum(photeng, emit_xray_N, rs=1+z, spec_type='N')
            xray_e_box = dep_box[..., 5] / np.dot(photeng, emit_xray_N)
            # maybe do a fourier transform
            # save these two
            # TODO: phot_xray_box, phot_xray_spec, z_evals and zindex are undefined here --
            # presumably meant to be xray_e_box, emit_xray_spec and z.
            xray_windowed_data.set_field(phot_xray_box, phot_xray_spec, z_evals[zindex])
            xray_windowed_data.global_Tk = np.append(xray_windowed_data.global_Tk, np.mean(spin_temp.Tk_box))
            # Fixed: was the undefined name ionized_field.
            xray_windowed_data.global_x = np.append(xray_windowed_data.global_x, np.mean(ionized_box.xH_box))
        else:
            out_phot_N = prop_phot_N + emit_phot_N # everything treated as uniform
        
        out_phot_spec = Spectrum(photeng, out_phot_N, rs=1+z, spec_type='N')
        if i_z < len(z_edges) - 1:
            # The bath for the next step is redshifted to that step's arrival redshift.
            out_phot_spec.redshift(1+z_edges[i_z+1])
            phot_bath_spec = out_phot_spec

    #========== save results ==========
    if i_z > 0:
        record = {
            'z'   : z,
            'T_s' : np.mean(spin_temp.Ts_box), # [mK]
            'T_b' : np.mean(brightness_temp.brightness_temp), # [K]
            'T_k' : np.mean(spin_temp.Tk_box), # [K]
            'x_e' : np.mean(1 - ionized_box.xH_box), # [1]
        }
        # record_inj exists whenever injection ran. The previous test `run_mode == 'inj'`
        # never matched a valid run_mode, so injection records were never saved.
        if run_mode != 'no inj':
            record.update(record_inj)
        records.append(record)
        
    if save_slices:
        saved_slices.append({
            'z'   : z,
            'T_s' : spin_temp.Ts_box[i_slice], # [mK]
            'T_b' : brightness_temp.brightness_temp[i_slice], # [K]
            'T_k' : spin_temp.Tk_box[i_slice], # [K]
            'x_e' : 1 - ionized_box.xH_box[i_slice], # [1]
            'delta' : perturbed_field.density[i_slice], # [1]
        })
        
    if use_tqdm:
        pbar.update()
        
#========== end of loop ==========
if use_tqdm:
    pbar.close()
    
print(f'input used {input_time_tot:.4f} s')
print(f'p21c used {p21c_time_tot:.4f} s')

arr_records = {k: np.array([r[k] for r in records]) for k in records[0].keys()}
np.save(f'../data/run_info/{run_name}_records', arr_records)
if save_slices:
    np.save(f'../data/run_info/{run_name}_slices', saved_slices)

100%|███████████████████████████████████████████████████████████████████████████████| 39/39 [00:05<00:00, 15.98it/s]

input used 0.7853 s
p21c used 2.5925 s


100%|███████████████████████████████████████████████████████████████████████████████| 39/39 [00:18<00:00, 15.98it/s]

## 3. Lightcone

In [29]:
import powerbox
# import importlib
# importlib.reload(powerbox)

In [39]:
# NOTE(review): this is a standard 21cmFAST lightcone. It does not include the DM
# injection that the main loop applies through the input boxes.
lightcone = p21c.run_lightcone(
    redshift = z_edges[-1],
    user_params  = spin_temp.user_params,
    cosmo_params = spin_temp.cosmo_params,
    astro_params = spin_temp.astro_params,
    flag_options = spin_temp.flag_options,
    # Reuse the coeval run's seed so the lightcone is reproducible and starts from the
    # same initial density field. Without it, run_lightcone draws a fresh random seed.
    random_seed = initial_conditions.random_seed,
)

In [None]:
# Plot a slice through the lightcone.
fig, ax = plt.subplots(figsize = (15, 10))
plotting.lightcone_sliceplot(lightcone, fig=fig, ax=ax)
ax.set(aspect=10) # stretch the axis ratio; presumably to make the narrow transverse extent readable

In [40]:
def compute_power(
    box,
    length,
    n_psbins,
    log_bins=True,
    ignore_kperp_zero=True,
    ignore_kpar_zero=False,
    ignore_k_zero=False,
):
    """Spherically averaged power spectrum of a 3D box, computed with powerbox.

    Parameters
    ----------
    box : array
        3D field. The last axis is treated as the line-of-sight (k_par) axis.
    length : tuple of float
        Physical side lengths of `box`, passed to powerbox as `boxlength`.
    n_psbins : int
        Number of k bins.
    log_bins : bool
        Use logarithmically spaced bins. Bin centres are then geometric means of the edges.
    ignore_kperp_zero, ignore_kpar_zero, ignore_k_zero : bool
        Give zero weight to the k_perp = 0 line, the k_par = 0 plane, or the k = 0 mode.

    Returns
    -------
    list
        powerbox.tools.get_power output with element 1 replaced by the bin centres.
        With get_variance=False this is [power, k].
    """
    k_weights = np.ones(box.shape, int)
    # Zero-mode index on each axis. This assumes powerbox orders k with the zero mode
    # at n // 2 (centred grid), as the original code did -- NOTE(review) confirm.
    # Each axis uses its own length; the original used shape[0] for axis 1 too,
    # which is only correct for square transverse shapes.
    mid0, mid1, mid2 = (n // 2 for n in box.shape)

    if ignore_kperp_zero:
        k_weights[mid0, mid1, :] = 0
    if ignore_kpar_zero:
        k_weights[:, :, mid2] = 0
    if ignore_k_zero:
        k_weights[mid0, mid1, mid2] = 0

    res = powerbox.tools.get_power(
        box,
        boxlength=length,
        bins=n_psbins,
        bin_ave=False,
        get_variance=False,
        log_bins=log_bins,
        k_weights=k_weights,
    )

    res = list(res)
    # With bin_ave=False, res[1] holds bin edges; convert them to bin centres.
    k = res[1]
    if log_bins:
        k = np.exp((np.log(k[1:]) + np.log(k[:-1])) / 2)
    else:
        k = (k[1:] + k[:-1]) / 2

    res[1] = k
    return res

def powerspectra(brightness_temp, n_psbins=50, nchunks=20, min_k=0.1, max_k=1.0, logk=True, box_len=None):
    """Dimensionless power spectra of a lightcone, one per line-of-sight chunk.

    The lightcone's brightness_temp is split along its last axis into up to
    `nchunks` chunks. Each chunk's power is computed with `compute_power` and
    converted to Delta^2(k) = P(k) k^3 / (2 pi^2).

    Parameters
    ----------
    brightness_temp : lightcone-like
        Must provide `.brightness_temp` (3D array), `.n_slices` and `.cell_size`.
    n_psbins : int
        Number of k bins per chunk.
    nchunks : int
        Maximum number of line-of-sight chunks.
    min_k, max_k : float
        Currently unused. Accepted for interface compatibility; no k-range cut is applied.
    logk : bool
        Use logarithmic k bins.
    box_len : float, optional
        Transverse side length of the box. Defaults to the notebook-global BOX_LEN.

    Returns
    -------
    list of dict
        One {"k": ..., "delta": ...} entry per chunk.
    """
    if box_len is None:
        box_len = BOX_LEN

    n_slices = brightness_temp.n_slices
    # max(1, ...) prevents a zero range step (ValueError) when n_slices < nchunks / 2.
    step = max(1, round(n_slices / nchunks))
    chunk_indices = list(range(0, n_slices, step))
    if len(chunk_indices) > nchunks:
        chunk_indices = chunk_indices[:-1]
    chunk_indices.append(n_slices)

    # At most nchunks chunks, as before, but never index past the available boundaries
    # (the original raised IndexError when rounding produced fewer chunks).
    bounds = list(zip(chunk_indices[:-1], chunk_indices[1:]))[:nchunks]

    data = []
    for start, end in bounds:
        chunklen = (end - start) * brightness_temp.cell_size

        power, k = compute_power(
            brightness_temp.brightness_temp[:, :, start:end],
            (box_len, box_len, chunklen),
            n_psbins,
            log_bins=logk,
        )
        data.append({"k": k, "delta": power * k ** 3 / (2 * np.pi ** 2)})
    return data

In [41]:
# Box geometry of the run, read back from the 21cmFAST parameters instead of being
# re-typed, so it cannot drift from the UserParams used to build initial_conditions.
BOX_LEN = initial_conditions.user_params.BOX_LEN
HII_DIM = initial_conditions.user_params.HII_DIM

k_fundamental = 2*np.pi / BOX_LEN # smallest nonzero wavenumber of the box
k_max = k_fundamental * HII_DIM
Nk = int(HII_DIM) # NOTE(review): not used in the visible notebook

In [42]:
# Dimensionless power spectra (key "delta" = P(k) k^3 / (2 pi^2)) for each line-of-sight chunk of the lightcone.
out = powerspectra(lightcone, min_k=k_fundamental, max_k=k_max)

In [43]:
# `with` guarantees the pickle is flushed and the file closed (the bare open() leaked the handle).
with open(f'../data/run_info/{run_name}_ps', 'wb') as fh:
    pickle.dump(out, fh)