In [1]:
%reload_ext autoreload
%autoreload 2

import os, sys
sys.path.append('..')

import numpy as np
from tqdm import tqdm
import time

%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib import cm
mpl.rc_file('../matplotlibrc')

In [2]:
# 21cmFAST
import py21cmfast as p21c
from py21cmfast import plotting, cache_tools
print(f'Using 21cmFAST version {p21c.__version__}')

import logging

# Handle on the '21cmFAST' logger (21cmFAST's own INFO messages appear under this name);
# also used by the run loop for warnings. Show INFO and above.
logger = logging.getLogger(name='21cmFAST')
logger.setLevel(level=logging.INFO)

  "Your configuration file is out of date. Updating..."


Using 21cmFAST version 0.1.dev1578+g6f96f89.d20230224


In [3]:
from dm21cm.injection import get_input_boxs, DMParams

## 0. Global config

In [4]:
# Show the host's CPU / NUMA layout to choose N_THREADS for 21cmFAST (next cell).
! lscpu | grep "CPU(s)"

CPU(s):                48
On-line CPU(s) list:   0-47
NUMA node0 CPU(s):     0,2,4,6,8,10,12,14,16,18,20,22,24,26,28,30,32,34,36,38,40,42,44,46
NUMA node1 CPU(s):     1,3,5,7,9,11,13,15,17,19,21,23,25,27,29,31,33,35,37,39,41,43,45,47


In [5]:
# Threads handed to 21cmFAST via UserParams(N_THREADS=...). Capped at the number of
# logical CPUs so the notebook does not oversubscribe smaller machines
# (32 of the 48 CPUs listed by lscpu on the original host).
N_THREADS = min(32, os.cpu_count() or 1)

## TMP — scratch check of `BatchInterpolator` (manual sum vs. `sum_result=True`)

In [39]:
from dm21cm.interpolators import BatchInterpolator

In [16]:
# List the transfer-function files available for the interpolator check below.
! ls /data/submit/yitians/dm21cm/DM21cm/transferfunctions/nBs_test_2/

elec_dep_dlnz4.879E-2_aad.p	 phot_dep_dlnz4.879E-2_aad.p
elec_dep_dlnz4.879E-2_rexo_ad.p  phot_dep_dlnz4.879E-2_renxo_ad.p
elec_tf_dlnz4.879E-2_aad.p	 phot_tf_dlnz4.879E-2_aad.p
elec_tf_dlnz4.879E-2_rexo_ad.p	 phot_tf_dlnz4.879E-2_renxo_ad.p
npy


In [40]:
# Build a dm21cm BatchInterpolator from a transfer-function table. The file name
# (elec_dep_...) suggests an electron energy-deposition table — confirm.
# NOTE(review): `.p` presumably means pickle; only load trusted files.
interp = BatchInterpolator('/data/submit/yitians/dm21cm/DM21cm/transferfunctions/nBs_test_2/elec_dep_dlnz4.879E-2_aad.p')

In [45]:
# 2000 random x_s test inputs, uniform in [0, 1). They are drawn from a seeded generator
# so the interpolator check below reproduces after a kernel restart.
x_in = np.random.default_rng(seed=0).uniform(size=(2000,))

In [47]:
# Unsummed interpolator output for a flat (all-ones) 500-bin input spectrum at rs=8.
# Inputs outside the table range are clipped. The result is summed over axis 0 in the next cell.
r = interp(x_s=x_in, rs=8, in_spec=np.ones(500), out_of_bounds_action='clip')

In [48]:
# Manual reduction over axis 0 (presumably the x_s samples), for comparison with sum_result=True.
r.sum(axis=0)

DeviceArray([3.5963162e+09, 6.7704112e+07, 8.8451154e+09, 4.5057245e+11,
             6.1683034e+08], dtype=float32)

In [50]:
# Same query, summed inside the interpolator (sum_result=True). It must agree with the
# manual axis-0 sum of `r`. `r` is float32, while this result is presumably float64
# (its repr shows no dtype), so compare with a loose relative tolerance.
r_summed = interp(rs=8, in_spec=np.ones((500,)), x_s=x_in, sum_result=True, out_of_bounds_action='clip')
np.testing.assert_allclose(np.sum(r, axis=0), r_summed, rtol=1e-4)
r_summed

array([3.59631642e+09, 6.77041200e+07, 8.84511539e+09, 4.50572419e+11,
       6.16830400e+08])

## 1. Run

In [6]:
def get_z_arr(z_start=None, z_end=20, step_factor=None):
    """Build the descending redshift grid used to step 21cmFAST.

    Starting at ``z_end``, redshifts are generated backwards in time via
    ``1 + z_next = (1 + z) * step_factor`` until a value at or above ``z_start``
    is produced. That overshoot point is discarded, so every returned redshift
    is strictly below ``z_start``.

    Parameters
    ----------
    z_start : float, optional
        Exclusive upper bound. Defaults to ``p21c.global_params.Z_HEAT_MAX``.
    z_end : float, optional
        Lowest redshift. It is the last element when ``z_end < z_start``. Default 20.
    step_factor : float, optional
        Multiplicative step in (1 + z). Defaults to
        ``p21c.global_params.ZPRIME_STEP_FACTOR``. Must be > 1.

    Returns
    -------
    np.ndarray
        Redshifts in descending order ending at ``z_end``. Empty if
        ``z_end >= z_start``.

    Raises
    ------
    ValueError
        If ``step_factor <= 1``: the grid would never reach ``z_start`` and the
        loop would not terminate.
    """
    if z_start is None:
        z_start = p21c.global_params.Z_HEAT_MAX
    if step_factor is None:
        step_factor = p21c.global_params.ZPRIME_STEP_FACTOR
    if step_factor <= 1:
        raise ValueError(f'step_factor must be > 1 to reach z_start, got {step_factor}')

    z_arr = [z_end]
    # The sequence is strictly increasing, so the last element is always the maximum.
    while z_arr[-1] < z_start:
        z_arr.append((1 + z_arr[-1]) * step_factor - 1)
    # Reverse into descending order and drop the overshoot point (>= z_start).
    return np.array(z_arr[::-1][1:])

In [7]:
# Root of the per-run 21cmFAST caches; list the runs already stored there.
CACHE_DIR_BASE = '/scratch/submit/ctp/yitians/21cmFAST-cache'
os.listdir(path=CACHE_DIR_BASE)

['emf_comp_spf_mh3',
 'emf_comp_dh_mh3',
 'emf_comp_dh_mh6',
 'emf_comp_dh_mh9',
 'emf_comp_base']

In [8]:
# --- run configuration ---
RUN_NAME = 'test'                 # names the cache subdirectory and the saved records file
run_mode = 'inj'                  # '' = no injection; 'inj' = add DM injection boxes
f_scheme = 'DH'                   # forwarded to get_input_boxs
struct_boost_model = 'erfc 1e-3'  # forwarded to get_input_boxs

# Keep this run's 21cmFAST cache in its own subdirectory, creating it if needed.
p21c.config['direc'] = f'{CACHE_DIR_BASE}/{RUN_NAME}'
os.makedirs(name=p21c.config['direc'], exist_ok=True)

In [9]:
# Remove previously cached 21cmFAST files so this run starts clean.
# NOTE(review): presumably clears p21c.config['direc'] set above — confirm in cache_tools docs.
cache_tools.clear_cache()

2023-03-10 23:43:11,676 | INFO | Removed 0 files from cache.
INFO:21cmFAST:Removed 0 files from cache.


In [10]:
# Cosmological parameters for this run (labelled "EMF"; their source is not stated here).
cosmo_params_EMF = dict(OMm=0.32, OMb=0.049, POWER_INDEX=0.96, SIGMA_8=0.83, hlittle=0.67)

# initialize
# Seeded initial conditions, written to the run cache (write=True) for reuse.
initial_conditions = p21c.initial_conditions(
    user_params = p21c.UserParams(
        HII_DIM=50, # [1] grid cells per side (the run loop takes HII_DIM/2 as the central slice)
        BOX_LEN=50, # [p-Mpc] NOTE(review): 21cmFAST documents BOX_LEN in (comoving) Mpc — confirm this label
        N_THREADS=N_THREADS
    ),
    cosmo_params = p21c.CosmoParams(**cosmo_params_EMF),
    random_seed=54321, write=True
)

# redshift
# These mutate 21cmFAST module-level globals AFTER the initial conditions were built.
# get_z_arr reads both; presumably 21cmFAST also reads them internally during the run.
p21c.global_params.ZPRIME_STEP_FACTOR = 1.05  # (1+z) grows by this factor per step
p21c.global_params.Z_HEAT_MAX = 44.           # exclusive upper bound of z_arr (via get_z_arr)
z_arr = get_z_arr(z_end=6.)
print(z_arr)

# dark matter
# Injection model passed to get_input_boxs. Units of m_DM / sigmav are defined by
# dm21cm's DMParams (presumably eV and cm^3/s — confirm).
dm_params = DMParams(mode='swave', primary='mu', m_DM=1e10, sigmav=1e-26)

  "The USE_INTERPOLATION_TABLES setting has changed in v3.1.2 to be "


[43.69834103 41.5698486  39.54271295 37.61210757 35.77343578 34.02231979
 32.35459028 30.76627646 29.25359663 27.81294917 26.44090397 25.13419426
 23.88970882 22.70448459 21.57569961 20.50066629 19.47682504 18.50173813
 17.57308394 16.68865137 15.84633464 15.04412822 14.28012212 13.55249726
 12.8595212  12.199544   11.57099428 10.97237551 10.40226239  9.85929751
  9.34218811  8.84970296  8.38066948  7.93397094  7.50854375  7.103375
  6.7175      6.35        6.        ]


  coords_data = np.array(json.load(data_file))
  values_data = np.array(json.load(data_file))


In [11]:
# Silence these 21cmFAST internal loggers completely: a threshold above CRITICAL drops every record.
for _quiet_name in ('py21cmfast._utils', 'py21cmfast.wrapper'):
    logging.getLogger(_quiet_name).setLevel(logging.CRITICAL + 1)
del _quiet_name

In [61]:
records = []  # one dict of box-averaged quantities per redshift step (step 0 excluded)
i_slice = int(initial_conditions.user_params.HII_DIM/2)  # central slice index; not used in this cell
input_time_tot = 0.  # [s] wall time spent building input boxes
p21c_time_tot = 0.   # [s] wall time spent inside 21cmFAST

# Validate up front so a bad run_mode fails before the first (expensive) 21cmFAST step.
if run_mode not in ('', 'inj'):
    raise ValueError(run_mode)

for i_z, z in enumerate(tqdm(z_arr)):

    ## build input boxes for this step
    input_timer = time.perf_counter()
    if i_z == 0:
        # First step: no previous spin temperature and no injection inputs.
        spin_temp = None
        input_heating = input_ionization = input_jalpha = None

    else: ## input from second step
        z_prev = z_arr[i_z - 1]

        input_heating = p21c.input_heating(redshift=z, init_boxes=initial_conditions, write=False)
        input_ionization = p21c.input_ionization(redshift=z, init_boxes=initial_conditions, write=False)
        input_jalpha = p21c.input_jalpha(redshift=z, init_boxes=initial_conditions, write=False)

        if run_mode == '':
            if i_z == 1:
                logger.warning('Not injecting anything in this run!')

        else: # run_mode == 'inj' (validated before the loop)
            # perturbed_field and ionized_box still hold the PREVIOUS step (z_prev) here:
            # they are only overwritten by the 21cmFAST calls further down this iteration.
            input_boxs = get_input_boxs(
                delta_box = perturbed_field.density, # [1]
                x_e_box = 1 - ionized_box.xH_box, # [1]
                z_prev = z_prev,
                z = z,
                dm_params = dm_params,
                f_scheme = f_scheme,
                struct_boost_model = struct_boost_model
            )

            # Accumulate the DM injection boxes into the 21cmFAST input structures (in place).
            input_heating.input_heating += input_boxs['heat']
            input_ionization.input_ionization += input_boxs['ion']
            input_jalpha.input_jalpha += input_boxs['exc']

    input_time_tot += time.perf_counter() - input_timer

    ## step in 21cmFAST
    p21c_timer = time.perf_counter()
    perturbed_field = p21c.perturb_field( # perturbed_field controls the redshift
        redshift=z,
        init_boxes=initial_conditions
    )
    spin_temp = p21c.spin_temperature(
        perturbed_field=perturbed_field,
        previous_spin_temp=spin_temp, # None on the first step
        input_heating_box=input_heating,
        input_ionization_box=input_ionization,
        input_jalpha_box=input_jalpha,
        write=True
    )
    ionized_box = p21c.ionize_box(
        spin_temp=spin_temp
    )
    brightness_temp = p21c.brightness_temperature(
        ionized_box=ionized_box,
        perturbed_field=perturbed_field,
        spin_temp=spin_temp
    )
    p21c_time_tot += time.perf_counter() - p21c_timer

    ## save box-averaged results (step 0 has no input_boxs, so it is not recorded)
    if i_z > 0:
        record = {
            'z'   : z,
            'T_s' : np.mean(spin_temp.Ts_box), # [K]
            'T_b' : np.mean(brightness_temp.brightness_temp), # [mK]
            'T_k' : np.mean(spin_temp.Tk_box), # [K]
            'x_e' : np.mean(1 - ionized_box.xH_box), # [1]
        }
        if run_mode == 'inj':
            record.update({
                'dE_inj_per_B' : input_boxs['dE_inj_per_B_mean'], # [eV per B]
                'f_heat' : input_boxs['f_heat_mean'],
                'f_ion'  : input_boxs['f_ion_mean'],
                'f_exc'  : input_boxs['f_exc_mean'],
            })
        records.append(record)

print(f'input used {input_time_tot:.4f} s')
print(f'p21c used {p21c_time_tot:.4f} s')

100%|██████████| 39/39 [01:56<00:00,  2.98s/it]

input used 0.4286 s
p21c used 115.5553 s





In [62]:
# Transpose the list of per-step dicts into {key: 1-D array over recorded steps}.
if not records:
    raise RuntimeError('No records to save: the run loop recorded no steps (need len(z_arr) > 1).')
arr_records = {k: np.array([r[k] for r in records]) for k in records[0].keys()}

# np.save stores the dict pickled in a 0-d object array and appends '.npy'.
# Reload with np.load(path, allow_pickle=True).item().
os.makedirs('../data/run_info', exist_ok=True)
np.save(f'../data/run_info/{RUN_NAME}_records', arr_records)