In [1]:
%reload_ext autoreload
%autoreload 2

import os
import sys

from astropy.cosmology import Planck18
import py21cmfast as p21c

sys.path.append("..")
from dm21cm.dm_params import DMParams
from dm21cm.evolve import evolve

# Base DM21cm directory; run() reads f"{WDIR}/outputs/dh/..." from it.
# Fail with an actionable message instead of a bare KeyError when unset
# (still a KeyError, so the exception type is unchanged).
try:
    WDIR = os.environ['DM21CM_DIR']
except KeyError:
    raise KeyError(
        "Environment variable DM21CM_DIR is not set; set it to the DM21cm "
        "directory (e.g. `export DM21CM_DIR=...` or `%env DM21CM_DIR=...`)."
    ) from None



In [2]:
def run(
    data_dir='/n/holyscratch01/iaifi_lab/yitians/dm21cm/DM21cm/data/tf/zf01/data',
    run_name='ct_128_256Mpc_xray_noLX_nopop2_test',
):
    """Run a short DM21cm `evolve` benchmark (z = 45 -> 35) for profiling.

    Configuration: 128 cells per side in a 256 conformal-Mpc box, Planck18
    cosmology, transfer-function version 'zf01' with electrons disabled,
    DM injection skipped (`debug_skip_dm_injection=True`), pop-II
    ionization turned off, and the 'xraycheck' debug mode.

    Args:
        data_dir: Value written to os.environ['DM21CM_DATA_DIR'] before
            `evolve` is called. Defaults to the cluster path used for the
            original benchmark (its `zf01` component matches `tf_version`).
        run_name: Forwarded to `evolve`; the run log shows it is also used
            as the 21cmFAST cache sub-directory name.

    Returns:
        Whatever `evolve` returns (the original version discarded it).

    Side effects:
        os.environ['DM21CM_DATA_DIR'] is overwritten and stays set for the
        rest of the kernel session.

    Alternative settings tried during development (disabled here):
        debug_flags = ['uniform_xray']                           # homogeneous injection
        debug_flags = ['xraycheck', 'xc-noatten']                # our xray noatten to compare with 21cmFAST
        debug_flags = ['xraycheck', 'xc-bath', 'xc-force-bath']  # our xray ST forced to bath compare with DH
        SIGMA_8 = 1e-6                                           # in CosmoParams
        use_21totf = f"{WDIR}/outputs/stdout/xc_nopop2_noHe_nosp_noatten_esf.out"
        debug_even_split_f = True
        tf_on_device = False
    """
    # Overwrites any existing value on purpose: the data dir is tied to tf_version.
    os.environ['DM21CM_DATA_DIR'] = data_dir
    return evolve(
        run_name = run_name,
        z_start = 45.,
        z_end = 35.,
        zplusone_step_factor = 1.01,
        dm_params = DMParams(
            mode='decay',
            primary='phot_delta',
            m_DM=1e8, # [eV]
            lifetime=1e50, # [s]
        ),
        enable_elec = False,
        tf_version = 'zf01',

        p21c_initial_conditions = p21c.initial_conditions(
            user_params = p21c.UserParams(
                HII_DIM = 128,
                BOX_LEN = 128*2, # [conformal Mpc]
                N_THREADS = 32,
            ),
            cosmo_params = p21c.CosmoParams(
                OMm = Planck18.Om0,
                OMb = Planck18.Ob0,
                POWER_INDEX = Planck18.meta['n'],
                SIGMA_8 = Planck18.meta['sigma8'],
                hlittle = Planck18.h,
            ),
            random_seed = 54321,
            write = True,
        ),

        rerun_DH = False,
        clear_cache = True,
        use_tqdm = True,
        debug_flags = ['xraycheck'], # our xray ST compare with DH
        # L_X is a log10 value, so 0. means 10**0 rather than zero.
        # NOTE(review): confirm this is small enough for the 'noLX' run name.
        debug_astro_params = p21c.AstroParams(L_X = 0.),
        use_DH_init = True,
        custom_YHe = 0.245,
        debug_turn_off_pop2ion = True,
        debug_copy_dh_init = f"{WDIR}/outputs/dh/xc_xrayST_soln.p",
        track_Tk_xe = True,
        debug_skip_dm_injection = True,
    )

In [3]:
import cProfile

In [4]:
# Profile the whole benchmark; raw stats are dumped to 'run_stats'.
# NOTE(review): the analysis cells below read 'run_stats_jaxfft' and
# 'run_stats_scipyfft', so this file was presumably renamed after each run.
cProfile.run(statement='run()', filename='run_stats')

INFO:root:Using 21cmFAST version 0.1.dev1586+g60df221.d20231025
INFO:root:Cache dir: /n/holyscratch01/iaifi_lab/yitians/21cmFAST-cache/ct_128_256Mpc_xray_noLX_nopop2_test
INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA Interpreter
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
INFO:root:TransferFunctionWrapper: Loaded photon transfer functions.
INFO:root:Copied dh_init_soln.p from /n/home07/yitians/dm21cm/DM21cm/outputs/dh/xc_xrayST_soln.p
INFO:root:DarkHistoryWrapper: Found existing DarkHistory initial conditions.
100%|██████████| 24/24 [03:51<00:00,  9.64s/it]

xraycheck: 6.5065 +/- 3.2637 s
21cmFAST: 3.2093 +/- 0.1937 s
prep_next: 0.0603 +/- 0.0270 s





In [5]:
import pstats
from pstats import SortKey


## jaxfft

Profile of the benchmark run saved as `run_stats_jaxfft`.

In [12]:
# Load the saved profile of the 'jaxfft' run. Reports go to sys.stdout, which
# is exactly what pstats.Stats uses when no stream is given.
p = pstats.Stats('run_stats_jaxfft', stream=sys.stdout)

In [13]:
# Top 20 entries by cumulative time; trailing ';' hides the returned Stats repr.
p.sort_stats(SortKey.CUMULATIVE).print_stats(20);

Sun Oct 29 15:35:52 2023    run_stats_jaxfft

         7592848 function calls (7563713 primitive calls) in 267.474 seconds

   Ordered by: cumulative time
   List reduced from 2684 to 20 due to restriction <20>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
      6/1    0.000    0.000  277.889  277.889 {built-in method builtins.exec}
        1    0.000    0.000  277.889  277.889 <string>:1(<module>)
        1    0.018    0.018  277.889  277.889 /tmp/ipykernel_1355924/2603915052.py:1(run)
        1    3.248    3.248  254.214  254.214 /n/home07/yitians/dm21cm/DM21cm/benchmarking/../dm21cm/evolve.py:39(evolve)
      276    0.017    0.000  118.433    0.429 /n/home07/yitians/.conda/envs/dm21cm/lib/python3.11/site-packages/scipy/interpolate/_rgi.py:274(__call__)
       25    0.198    0.008   77.898    3.116 /n/home07/yitians/dm21cm/DM21cm/benchmarking/../dm21cm/evolve.py:639(p21c_step)
      276   68.160    0.247   72.489    0.263 /n/home07/yitians/.conda/envs/dm21c

<pstats.Stats at 0x7f7c5511b990>

In [11]:
# Top 20 entries by internal time (tottime); trailing ';' hides the Stats repr.
p.sort_stats(SortKey.TIME).print_stats(20);

Sun Oct 29 15:35:52 2023    run_stats_jaxfft

         7592848 function calls (7563713 primitive calls) in 267.474 seconds

   Ordered by: internal time
   List reduced from 2684 to 20 due to restriction <20>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
      276   68.160    0.247   72.489    0.263 /n/home07/yitians/.conda/envs/dm21cm/lib/python3.11/site-packages/scipy/interpolate/_rgi.py:391(_evaluate_linear)
      776   51.361    0.066   51.361    0.066 {method 'read' of 'h5py._selector.Reader' objects}
       25   33.776    1.351   33.776    1.351 {built-in method py21cmfast.c_21cmfast.ComputeTsBox}
      276   18.459    0.067   26.330    0.095 {scipy.interpolate._rgi_cython.find_indices}
        1   10.898   10.898   10.898   10.898 {built-in method py21cmfast.c_21cmfast.ComputeInitialConditions}
       25   10.380    0.415   10.380    0.415 {built-in method py21cmfast.c_21cmfast.ComputePerturbField}
     7652    8.653    0.001    8.653    0.001 {method 

<pstats.Stats at 0x7f7c5645a110>

## scipy

Profile of the benchmark run saved as `run_stats_scipyfft`.

In [None]:
# Load the scipyfft profile explicitly: `p` holds the jaxfft stats, so reusing
# it here would print the wrong run on a fresh kernel. Top 20 by cumulative time.
pstats.Stats('run_stats_scipyfft').sort_stats(SortKey.CUMULATIVE).print_stats(20);

Sun Oct 29 15:29:30 2023    run_stats_scipyfft

         3354756 function calls (3345503 primitive calls) in 269.990 seconds

   Ordered by: cumulative time
   List reduced from 1200 to 20 due to restriction <20>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000  278.887  278.887 {built-in method builtins.exec}
        1    0.000    0.000  278.887  278.887 <string>:1(<module>)
        1    0.012    0.012  278.887  278.887 /tmp/ipykernel_1352011/2603915052.py:1(run)
        1    3.313    3.313  278.338  278.338 /n/home07/yitians/dm21cm/DM21cm/benchmarking/../dm21cm/evolve.py:39(evolve)
      276    0.016    0.000  117.066    0.424 /n/home07/yitians/.conda/envs/dm21cm/lib/python3.11/site-packages/scipy/interpolate/_rgi.py:274(__call__)
       25    0.192    0.008   77.887    3.115 /n/home07/yitians/dm21cm/DM21cm/benchmarking/../dm21cm/evolve.py:639(p21c_step)
      276   69.327    0.251   73.928    0.268 /n/home07/yitians/.conda/envs/dm2

<pstats.Stats at 0x7f7cb9009c50>

In [None]:
# Load the scipyfft profile explicitly: `p` holds the jaxfft stats, so reusing
# it here would print the wrong run on a fresh kernel. Top 20 by internal time.
pstats.Stats('run_stats_scipyfft').sort_stats(SortKey.TIME).print_stats(20);

Sun Oct 29 15:29:30 2023    run_stats_scipyfft

         3354756 function calls (3345503 primitive calls) in 269.990 seconds

   Ordered by: internal time
   List reduced from 1200 to 20 due to restriction <20>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
      276   69.327    0.251   73.928    0.268 /n/home07/yitians/.conda/envs/dm21cm/lib/python3.11/site-packages/scipy/interpolate/_rgi.py:391(_evaluate_linear)
       25   33.833    1.353   33.833    1.353 {built-in method py21cmfast.c_21cmfast.ComputeTsBox}
      776   31.555    0.041   31.555    0.041 {method 'read' of 'h5py._selector.Reader' objects}
      276   18.321    0.066   31.412    0.114 /n/home07/yitians/dm21cm/DM21cm/benchmarking/../dm21cm/data_cacher.py:80(smooth_box)
      276   16.873    0.061   23.957    0.087 {scipy.interpolate._rgi_cython.find_indices}
      900   11.498    0.013   11.498    0.013 {built-in method numpy.fft._pocketfft_internal.execute}
        1   10.870   10.870   10.870

<pstats.Stats at 0x7f7cb9009c50>