In [1]:
# Basic python imports
import sys, os, time, h5py, logging, warnings
import numpy as np
from scipy import interpolate, integrate

# The LambdaCDM cosmology
from astropy.cosmology import Planck18 as cosmo

# Import modified 21cmFAST
import py21cmfast as p21c
from py21cmfast import cache_tools

# Configure environment for use with DarkHistory
# Honor a DH_DIR that is already exported in the environment; otherwise fall back to the
# original cluster checkout, so runs without DH_DIR set behave exactly as before.
os.environ.setdefault('DH_DIR', '/global/scratch/projects/pc_heptheory/fosterjw/21CM_Project/DarkHistory/')
sys.path.append(os.environ['DH_DIR'])
from darkhistory.spec.spectrum import Spectrum # use branch test_dm21cm

# Import DM21CM code for this project
sys.path.append("..")
import dm21cm.physics as phys
from dm21cm.utils import split_xray, get_z_edges, gen_injection_boxes, p21_step
from dm21cm.dh_wrapper import DMParams, DarkHistoryWrapper
from dm21cm.data_cacher import SpectrumCache, BrightnessCache 

# Plotting
%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib import colormaps as cms
mpl.rc_file("../matplotlibrc")
mpl.rcParams['text.usetex']=False

# Logging
# Root logger reports at INFO; the 21cmFAST loggers are set one notch above CRITICAL,
# so none of the standard severity levels pass their threshold.
logging.getLogger().setLevel(logging.INFO)
for _silenced_logger in ('21cmFAST', 'py21cmfast._utils', 'py21cmfast.wrapper'):
    logging.getLogger(_silenced_logger).setLevel(logging.CRITICAL+1)
logging.info(f'Using 21cmFAST version {p21c.__version__}')

import powerbox

INFO:root:Using 21cmFAST version 0.1.dev1579+g6b1da6d


# Configure 21cmFAST

In [2]:
# Cache directory handed to 21cmFAST. Later cells also build the DarkHistory
# initial-solution path and the x-ray brightness file path from this value.
p21c.config['direc'] = '../DebugCache/'
# NOTE(review): presumably empties that cache directory so each run starts clean — confirm
cache_tools.clear_cache()

# The range of times and how we step
z_start = 45.
z_end = 5.
z_step_factor = 1.01 # handed to 21cmFAST as ZPRIME_STEP_FACTOR in the next cell

# The size and resolution of our box
HII_DIM = 64 # ['Dimensionless']
BOX_LEN = HII_DIM*2. # ['Mpccm']

In [3]:
# Push our starting redshift and step factor into 21cmFAST's global parameters
p21c.global_params.Z_HEAT_MAX = z_start
p21c.global_params.ZPRIME_STEP_FACTOR = z_step_factor
p21c.global_params.CLUMPING_FACTOR = 1.

# Box geometry and threading
p21c_user_params = p21c.UserParams(HII_DIM=HII_DIM, BOX_LEN=BOX_LEN, N_THREADS=32)

# Cosmological parameters read off the astropy Planck18 object imported as `cosmo`
p21c_cosmo_params = p21c.CosmoParams(
    OMm=cosmo.Om(0),
    OMb=cosmo.Ob(0),
    POWER_INDEX=cosmo.meta['n'],
    SIGMA_8=cosmo.meta['sigma8'],
    hlittle=cosmo.h,
)

# Fixed seed so the realization is reproducible; write=True is passed through to 21cmFAST
p21c_initial_conditions = p21c.initial_conditions(
    user_params=p21c_user_params,
    cosmo_params=p21c_cosmo_params,
    random_seed=54321,
    write=True,
)



# Configure the Physics with DarkHistory

In [4]:
struct_boost_model = 'erfc 1e-3'
run_mode = 'xray'
dhinit_list = ['phot', 'T_k', 'x_e']
dhtf_version = '230629'

# File locations: DarkHistory initial solution (inside the 21cmFAST cache directory),
# energy abscissas, and the photon transfer-function prefix for this table version.
dh_init_path = f"{p21c.config['direc']}/dhinit_soln.p"
abscs_path = f'../data/abscissas/abscs_{dhtf_version}.h5'
transfer_prefix = f'../../tf/{dhtf_version}/phot'

# Our energy injection model
dm_params = DMParams(mode = 'swave', primary = 'phot_delta', m_DM = 1e10, abscs_path = abscs_path, sigmav = 1e-23)

# Determine the appropriate boost factor
if dm_params.mode == 'swave':
    # Build the boost function once here rather than rebuilding it on every call
    struct_boost = phys.struct_boost_func(model=struct_boost_model)
else:
    struct_boost = lambda rs: 1

# The DH Wrapper Class
# NOTE(review): the third positional argument was the literal 1.01; it is assumed to be the
# redshift step factor and now follows `z_step_factor` — confirm against DarkHistoryWrapper.
dh_wrapper = DarkHistoryWrapper(HII_DIM, dhinit_list, z_step_factor,
                                dh_init_path, abscs_path, transfer_prefix,
                                enable_elec = False, force_reload_tf = False)

INFO:root:Loaded photon propagation transfer function.
INFO:root:Loaded photon scattering transfer function.
INFO:root:Loaded photon deposition transfer function.


# Set Details of Caching

In [5]:
# Some details regarding our stepping. The redshift grid uses the same step factor that was
# handed to 21cmFAST (ZPRIME_STEP_FACTOR) so the two cannot drift apart.
z_edges = get_z_edges(z_start, z_end, z_step_factor)

# Where we start looking for annuli
xray_loop_start = 0

# X-ray band limits and the matching indices on the wrapper's photon energy grid.
# np.searchsorted (default side='left') returns the first index whose value is >= the target.
ex_lo, ex_hi = 1e2, 1e4 # [eV]
ix_lo = np.searchsorted(dh_wrapper.photeng, ex_lo) # i of first bin >= ex_lo, excluded
ix_hi = np.searchsorted(dh_wrapper.photeng, ex_hi) # i of first bin >= ex_hi, included

# Start from a fresh brightness file: remove any copy left by a previous run
xray_fn = p21c.config['direc']+'/xray_brightness.h5'
if os.path.isfile(xray_fn):
    os.remove(xray_fn)

box_cache = BrightnessCache(data_path = xray_fn, cosmo = cosmo, N = HII_DIM, dx = BOX_LEN / HII_DIM)
spectrum_cache = SpectrumCache()

# Evolution

### Synchronization

In [6]:
# Initial step with no injection
perturbed_field = p21c.perturb_field(redshift=z_edges[0], init_boxes=p21c_initial_conditions)
spin_temp, ionized_box, brightness_temp = p21_step(z_edges[0], perturbed_field, None, None)

# Manual synchronization of 21cmFAST with DarkHistory state.
# The DarkHistory arrays are reversed ([::-1]) before interpolation — presumably because the
# stored `rs` grid is decreasing while np.interp needs increasing sample points; confirm.
if 'T_k' in dhinit_list:
    T_k_DH = np.interp(1+spin_temp.redshift, dh_wrapper.dhinit_soln['rs'][::-1],
                       dh_wrapper.dhinit_soln['Tm'][::-1] / phys.kB) # [K]
    # Shift the whole box so its mean equals the DarkHistory value; fluctuations are kept
    spin_temp.Tk_box += T_k_DH - np.mean(spin_temp.Tk_box)

if 'x_e' in dhinit_list:
    # Column 0 of the DarkHistory 'x' solution is used as x_e
    # NOTE(review): the original comment labelled this column 'HI' — confirm which species it is
    x_e_DH = np.interp(1+spin_temp.redshift, dh_wrapper.dhinit_soln['rs'][::-1],
                       dh_wrapper.dhinit_soln['x'][::-1, 0])

    spin_temp.x_e_box += x_e_DH - np.mean(spin_temp.x_e_box)
    x_H_DH = 1 - x_e_DH
    ionized_box.xH_box += x_H_DH - np.mean(ionized_box.xH_box)

if 'phot' in dhinit_list:
    # Linear interpolation in log(1+z) between the two DarkHistory spectra bracketing our redshift
    logrs_dh_arr = np.log(dh_wrapper.dhinit_soln['rs'])[::-1]
    logrs = np.log(1+spin_temp.redshift)
    i = np.searchsorted(logrs_dh_arr, logrs)
    logrs_left, logrs_right = logrs_dh_arr[i-1:i+1]

    dh_spec_N_arr = np.array([s.N for s in dh_wrapper.dhinit_soln['highengphot']])[::-1]
    dh_spec_left, dh_spec_right = dh_spec_N_arr[i-1:i+1]
    dh_spec = ( dh_spec_left * np.abs(logrs - logrs_right) + \
                dh_spec_right * np.abs(logrs - logrs_left) ) / np.abs(logrs_right - logrs_left)
    phot_bath_spec = Spectrum(dh_wrapper.photeng, dh_spec, rs=1+spin_temp.redshift, spec_type='N')
else:
    # No initial photon bath requested: start from an empty spectrum on the wrapper's energy grid.
    # (The grid lives on the wrapper; a bare `photeng` name is not defined in this notebook.)
    phot_bath_spec = Spectrum(dh_wrapper.photeng, np.zeros_like(dh_wrapper.photeng),
                              rs=1+spin_temp.redshift, spec_type='N') # [N per Bavg]

## Now that we are synchronized, we enter our loop

In [7]:
def get_time_step(i_z):
    """Return the redshift interval and its duration for step `i_z`.

    Reads the notebook-level `z_edges` grid and the `cosmo` cosmology object.

    Parameters
    ----------
    i_z : int
        Index of the step; the interval runs from z_edges[i_z] to z_edges[i_z+1].

    Returns
    -------
    tuple
        (current_z, next_z, dt) where dt is cosmo.age(next_z) - cosmo.age(current_z)
        converted to seconds.
    """
    z_now, z_after = z_edges[i_z], z_edges[i_z+1]

    # Difference of cosmic ages at the two edges, expressed in [s]
    elapsed = cosmo.age(z_after) - cosmo.age(z_now)
    return z_now, z_after, elapsed.to('s').value

In [None]:
# Main evolution loop. For every redshift interval in `z_edges` this cell:
#   1. builds the per-cell injection and ionization inputs from the current 21cmFAST boxes,
#   2. deposits energy from cached x-ray shells, the homogeneous photon bath and the DM injection,
#   3. advances 21cmFAST by one step with the resulting heating/ionization/J-alpha boxes,
#   4. attenuates and redshifts the cached spectra and caches this step's x-ray emission,
#   5. appends global summary quantities to `records` and rewrites them to disk.
records = []

for i_z in range(len(z_edges)-1):
    start = time.time()
    
    print('Starting step: ', i_z)
    
    current_z, next_z, dt = get_time_step(i_z)
    nBavg = phys.n_B * (1+current_z)**3 # [Bavg / (physical cm)^3]
    
    print('Currently at: ', current_z)
    print('Advancing to: ', next_z)
    
    # Derived quantities that I need
    delta_plus_one_box = 1+ np.asarray(perturbed_field.density)
    rho_DM_box = delta_plus_one_box * phys.rho_DM * (1+current_z)**3 # [eV/(physical cm)^3]
    x_e_box = np.asarray(1 - ionized_box.xH_box) # x_e taken as 1 - xH; original author's note: check this
    inj_per_Bavg_box = phys.inj_rate(rho_DM_box, dm_params) * dt * struct_boost(1+current_z) / nBavg # [inj/Bavg]
    
    # Kwargs for the transfer function interpolation. `x_s` is capped at 0.999 as a
    # workaround flagged by the original author as hacky and still to be looked into.
    tf_kwargs = dict(rs = 1 + current_z,
                     nBs_s = delta_plus_one_box.ravel(),
                     x_s = np.minimum(.999, x_e_box).ravel(), 
                     out_of_bounds_action = 'clip'
                    )
    
    # Construct the empty arrays we deposit into
    dh_wrapper.set_empty_arrays()   
    
    # Set the kwargs for use in interpolations
    dh_wrapper.set_tf_kwargs(tf_kwargs)
    
    #############################
    ###   Energy Deposition   ###
    #############################
        
    # Now calculate photon emission and energy deposition from our X-ray annuli.
    # The range below is built once per outer step, so raising `xray_loop_start` inside the
    # loop only changes where later outer steps begin.
    for i_z_shell in range(xray_loop_start, i_z):
        
        # Load the effective density
        effective_density, is_box_average= box_cache.get_smoothed_shell(current_z,
                                                                                 z_edges[i_z_shell],
                                                                                 z_edges[i_z_shell+1])
        # Load the emission spectrum
        xray_spec = spectrum_cache.spectrum_list[i_z_shell]

        # If we are smoothing on the scale of the box then dump to the global bath spectrum.
        # The deposition will happen later, and we will not revisit this shell.
        if is_box_average:
            # The box is uniform in this case, so a single cell supplies the weight
            phot_bath_spec.N += effective_density[0, 0, 0]*xray_spec.N
            xray_loop_start = max(i_z_shell+1, xray_loop_start)

        else:
            dh_wrapper.photon_injection(xray_spec, bath = False, weight_box = effective_density)

    # Homogeneous bath injection
    dh_wrapper.photon_injection(phot_bath_spec)
    
    # Dark Matter injection
    dh_wrapper.photon_injection(dm_params.inj_phot_spec, bath =False, weight_box=inj_per_Bavg_box, ots = True)
    
    #############################################################
    ###   Generate the input boxes and take a 21cmFAST step   ###
    #############################################################
    
    # Access the propagating photon spectrum, emitted photon spectrum, and deposition box
    prop_phot_N, emit_phot_N, dep_box = dh_wrapper.get_state_arrays()
    
    # From here on `perturbed_field` refers to next_z; the boxes derived above still hold
    # the current_z values.
    perturbed_field = p21c.perturb_field(redshift=next_z, init_boxes=p21c_initial_conditions)    
    input_heating, input_ionization, input_jalpha = gen_injection_boxes(next_z, p21c_initial_conditions)
    dh_wrapper.populate_injection_boxes(input_heating, input_ionization, input_jalpha,
                                        x_e_box, delta_plus_one_box, nBavg)
     
    spin_temp, ionized_box, brightness_temp = p21_step(next_z, perturbed_field, spin_temp, ionized_box,
                                                       input_heating, input_ionization, input_jalpha)
    
    ########################################################
    ###   Prepare X-Ray and Bath Spectra for Next Step   ###
    ########################################################
    
    # Advance all cached spectra through this redshift step under the assumption of global
    # average quantities. First we attenuate, then we redshift.        
    dep_tf_at_point = dh_wrapper.phot_dep_tf.point_interp(rs=1+current_z, nBs=1, x=np.mean(x_e_box))
    # Sum of the first four deposition channels for each input energy bin
    dep_toteng = np.sum(dep_tf_at_point[:, :4], axis=1)
    
    # Surviving fraction per energy bin = 1 - (deposited energy / photon energy)
    spectrum_cache.attenuate(1 - dep_toteng/dh_wrapper.photeng)
    spectrum_cache.redshift(next_z)
    
    ################################################################
    ###   Cache X-Ray Emission from this Step and Prepare Bath   ###
    ################################################################
        
    # Split the x-ray spectrum into bath and emission
    emit_bath_N, emit_xray_N = split_xray(emit_phot_N, ix_lo, ix_hi)
    out_phot_N = prop_phot_N + emit_bath_N
    
    # Prepare the bath spectrum for the next step    
    out_phot_spec = Spectrum(dh_wrapper.photeng, out_phot_N, rs=1+current_z, spec_type='N')
    out_phot_spec.redshift(1+next_z)
    phot_bath_spec = out_phot_spec
    
    # Cache the x-ray brightness box: channel 5 of the deposition box normalized by the total
    # emitted x-ray energy.
    # NOTE(review): if no x-ray photons were emitted the denominator is zero — confirm this cannot occur
    xray_e_box = dep_box[..., 5] / np.dot(dh_wrapper.photeng, emit_xray_N) # energy / B_avg
    box_cache.cache_box(box=xray_e_box, z=current_z)
    
    # Cache the x-ray spectrum. This needs to be redshifted to the next time but not attenuated.
    xray_spec = Spectrum(dh_wrapper.photeng, emit_xray_N, rs=1+current_z, spec_type='N') # [photon / Bavg]
    xray_spec.redshift(1+next_z)
    spectrum_cache.append(xray_spec, current_z)
    
    #######################################################
    ###   Save some Global Quantities for Convenience   ###
    #######################################################
    
    dE_inj_per_Bavg = dm_params.eng_per_inj * np.mean(inj_per_Bavg_box) # [eV per Bavg]
    # Same quantity with the structure boost divided back out; the f_* fractions use this as denominator
    dE_inj_per_Bavg_unclustered = dE_inj_per_Bavg / struct_boost(1+current_z)

     
    record_inj = {'dE_inj_per_B' : dE_inj_per_Bavg,
                  'f_ion'  : np.mean(dep_box[...,0] + dep_box[...,1]) / dE_inj_per_Bavg_unclustered,
                  'f_exc'  : np.mean(dep_box[...,2]) / dE_inj_per_Bavg_unclustered,
                  'f_heat' : np.mean(dep_box[...,3]) / dE_inj_per_Bavg_unclustered,
    }
    
    # NOTE(review): the original unit tags were T_s [mK] and T_b [K]; 21cmFAST is believed to
    # report spin temperature in K and brightness temperature in mK — confirm against its docs.
    record = {
        'z'   : next_z,
        'T_s' : np.mean(spin_temp.Ts_box), # presumably [K] — confirm
        'T_b' : np.mean(brightness_temp.brightness_temp), # presumably [mK] — confirm
        'T_k' : np.mean(spin_temp.Tk_box), # [K]
        'x_e' : np.mean(1 - ionized_box.xH_box), # [1]
        'E_phot' : phot_bath_spec.toteng(), # [eV/Bavg]
    }
    if run_mode in ['bath', 'xray']:
        record.update(record_inj)
    records.append(record)
    print(record['T_b'])
    
    # Re-pack every record so far into per-key arrays and overwrite the file on each step,
    # so an interrupted run still leaves the completed steps on disk.
    arr_records = {k: np.array([r[k] for r in records]) for k in records[0].keys()}
    np.save('./New_Debug', arr_records)
    
    end = time.time()
    print('Time Elapsed:', end-start)
    print()

Starting step:  0
Currently at:  45.0
Advancing to:  44.67846508307498


INFO:jax._src.xla_bridge:Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'


-17.926395
Time Elapsed: 19.831497192382812

Starting step:  1
Currently at:  44.67846508307498
Advancing to:  44.22620305254949
44.67846508307498 45.0 44.67846508307498 8.181227271456812 0.0


  W1 = 3*(np.sin(self.kMag*R1) - self.kMag*R1 * np.cos(self.kMag*R1)) /(self.kMag*R1)**3
  W2 = 3*(np.sin(self.kMag*R2) - self.kMag*R2 * np.cos(self.kMag*R2)) /(self.kMag*R2)**3


-17.191023
Time Elapsed: 5.923879384994507

Starting step:  2
Currently at:  44.22620305254949
Advancing to:  43.77841886391038
44.22620305254949 45.0 44.67846508307498 19.836647575195023 11.655420328125528
44.22620305254949 44.67846508307498 44.22620305254949 11.655420328125528 0.0
-16.172785
Time Elapsed: 7.2218239307403564

Starting step:  3
Currently at:  43.77841886391038
Advancing to:  43.335068182089486
43.77841886391038 45.0 44.67846508307498 31.550901784142937 23.369674507546076
43.77841886391038 44.67846508307498 44.22620305254949 23.369674507546076 11.71425421787273
43.77841886391038 44.22620305254949 43.77841886391038 11.71425421787273 0.0


# Below we make Lightcones and Power Spectra

In [None]:
# Field names requested from run_lightcone in the next cell; the same tuple is passed
# there as both `lightcone_quantities` and `global_quantities`.
lightcone_quantities = ('brightness_temp','Ts_box','xH_box',"dNrec_box",'z_re_box',
                        'Gamma12_box','J_21_LW_box',"density")


In [None]:
# Lightcone down to the final redshift of our grid, reusing the initial conditions and the
# flag/astro parameters carried by the last ionized box.
# The cache directory comes from the 21cmFAST config set at the top of the notebook instead
# of a second hard-coded copy of the path.
lightcone_fid = p21c.run_lightcone(
    redshift = z_edges[-1],
    init_box = p21c_initial_conditions,
    flag_options = ionized_box.flag_options,
    astro_params = ionized_box.astro_params,
    lightcone_quantities = lightcone_quantities,
    global_quantities = lightcone_quantities,
    random_seed = 54321,  # same seed as the initial conditions
    direc = p21c.config['direc'],
)


In [None]:
from py21cmfast import plotting

In [None]:
# Draw the lightcone slice on an explicitly created figure/axes pair so the size is controlled here
fig, ax = plt.subplots(figsize = (14, 10))
plotting.lightcone_sliceplot(lightcone_fid, fig =fig,ax = ax) 

In [None]:
def compute_power(
    box,
    length,
    n_psbins,
    log_bins=True,
    ignore_kperp_zero=True,
    ignore_kpar_zero=False,
    ignore_k_zero=False,
):
    """Compute a binned power spectrum of a 3D box with `powerbox.tools.get_power`.

    Parameters
    ----------
    box : array
        3D field; the first two axes are treated as perpendicular, the last as parallel.
    length : float or sequence
        Passed to get_power as `boxlength`.
    n_psbins : int
        Passed to get_power as `bins`.
    log_bins : bool
        Passed to get_power as `log_bins`; also selects geometric (True) or arithmetic
        (False) midpoints when the returned k values are converted to bin centres.
    ignore_kperp_zero : bool
        Zero the weight of the central (axis 0, axis 1) column for every parallel index.
    ignore_kpar_zero : bool
        Zero the weight of the central plane along the last axis.
    ignore_k_zero : bool
        Zero the weight of the single central cell.

    Returns
    -------
    list
        The get_power result as a list, with element 1 replaced by the midpoints of
        consecutive k values (so it is one entry shorter than what get_power returned).
    """
    # Determine the weighting function required from ignoring k's.
    k_weights = np.ones(box.shape, int)

    # Central index along each axis. Axis 1 uses its own length so a box whose first
    # two axes differ in size is masked at the correct cell.
    mid_perp0 = k_weights.shape[0] // 2
    mid_perp1 = k_weights.shape[1] // 2
    mid_par = k_weights.shape[-1] // 2

    if ignore_kperp_zero:
        k_weights[mid_perp0, mid_perp1, :] = 0
    if ignore_kpar_zero:
        k_weights[:, :, mid_par] = 0
    if ignore_k_zero:
        k_weights[mid_perp0, mid_perp1, mid_par] = 0

    res = powerbox.tools.get_power(
        box,
        boxlength=length,
        bins=n_psbins,
        bin_ave=False,
        get_variance=False,
        log_bins=log_bins,
        k_weights=k_weights,
    )

    res = list(res)
    k = res[1]
    # With bin_ave=False the k array is reduced here to one value per bin
    if log_bins:
        k = np.exp((np.log(k[1:]) + np.log(k[:-1])) / 2)
    else:
        k = (k[1:] + k[:-1]) / 2

    res[1] = k
    return res

def powerspectra(brightness_temp, n_psbins=50, nchunks=20, min_k=0.1, max_k=1.0, logk=True):
    """Split a lightcone along its last axis and compute a power spectrum per chunk.

    Parameters
    ----------
    brightness_temp : object
        Must expose `n_slices`, `cell_size` and a 3D `brightness_temp` array whose last
        axis has `n_slices` entries.
    n_psbins : int
        Number of bins forwarded to `compute_power`.
    nchunks : int
        Maximum number of chunks. Fewer are returned when the rounded chunk depth does
        not fit `nchunks` times into `n_slices`.
    min_k, max_k : float
        Accepted for interface compatibility; not used by this function.
    logk : bool
        Forwarded to `compute_power` as `log_bins`.

    Returns
    -------
    list of dict
        One {"k": ..., "delta": power * k**3 / (2 pi^2)} entry per chunk, in slice order.

    Notes
    -----
    The transverse box size is read from the notebook-level `BOX_LEN`.
    """
    n_slices = brightness_temp.n_slices

    # Nominal chunk depth in slices; at least 1 so range() always receives a valid step
    chunk_step = max(1, round(n_slices / nchunks))

    # At most `nchunks` start indices; the closing edge is n_slices, so the final chunk
    # absorbs any remainder instead of leaving trailing slices unused.
    chunk_indices = list(range(0, n_slices, chunk_step))[:nchunks]
    chunk_indices.append(n_slices)

    data = []
    # Walk consecutive edge pairs rather than assuming exactly `nchunks` intervals exist
    for start, end in zip(chunk_indices[:-1], chunk_indices[1:]):
        chunklen = (end - start) * brightness_temp.cell_size

        power, k = compute_power(
            brightness_temp.brightness_temp[:, :, start:end],
            (BOX_LEN, BOX_LEN, chunklen),
            n_psbins,
            log_bins=logk,
        )
        data.append({"k": k, "delta": power * k ** 3 / (2 * np.pi ** 2)})
    return data

In [None]:
# Wavenumber of one full wavelength across the box
k_fundamental = 2*np.pi/BOX_LEN
# NOTE(review): with cell size BOX_LEN/HII_DIM the per-axis Nyquist wavenumber is
# k_fundamental*HII_DIM/2, so this value is twice that — confirm it is intended
k_max = k_fundamental * HII_DIM
# NOTE(review): Nk is not referenced by any later cell visible in this notebook
Nk=np.floor(HII_DIM/1).astype(int)

In [None]:
# One power spectrum per redshift chunk of the lightcone (default 20 chunks).
# NOTE(review): powerspectra accepts min_k/max_k but never reads them, so they have no effect here
out = powerspectra(lightcone_fid, min_k = k_fundamental, max_k = k_max)


In [None]:
# One log-log panel per redshift chunk on a 4x5 grid (20 panels, matching the default
# number of chunks produced by powerspectra).
fig, axs = plt.subplots(ncols = 5, nrows = 4, figsize = (30, 16))

for i, item in enumerate(out):
    # Map the flat chunk index onto the (row, column) position in the grid
    ax = axs[np.unravel_index(i, axs.shape)]
    ax.plot(item['k'], item['delta'], color = 'black')
    ax.set_xscale('log')
    ax.set_yscale('log')
    ax.text(.6, .1, 'Redshift Chunk ' + str(i), transform=ax.transAxes, fontsize = 18)

# Only the bottom row keeps its x tick labels
for panel in axs[:-1].ravel():
    panel.set_xticklabels([])

for panel in axs[-1]:
    panel.set_xlabel('k [Mpc$^{-1}$]', fontsize = 22)

# The plotted 'delta' is power * k^3 / (2 pi^2), so the label carries the 2 pi^2 factor
for panel in axs[:, 0]:
    panel.set_ylabel(r'$k^3 P(k) / 2\pi^2$', fontsize = 22)

plt.tight_layout()
plt.show()