In [1]:
%reload_ext autoreload
%autoreload 2

import sys

import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

sys.path.append("..")
from likelihoods.npll_jax import log_like_np
from templates.variable_templates import NFWTemplate
from utils import create_mask as cm
from utils.psf import KingPSF
from utils.psf_correction import PSFCorrection

# JAX

In [2]:
data = jnp.load(f"../data/fermi_data_573w/fermi_data_128/fermidata_counts.npy").astype(jnp.int32)
mask_ps = jnp.load(f"../data/mask_3fgl_0p8deg.npy")
mask = cm.make_mask_total(nside=128, band_mask=True, band_mask_range=2, mask_ring=True, inner=0, outer=25, custom_mask=mask_ps)
data_in = data[~mask]
jnp.save("../data/tmp/data_in.npy", data_in)
pt = jnp.load("../data/fermi_data_573w/fermi_data_128/template_Opi.npy")[~mask]
jnp.save("../data/tmp/pt.npy", pt)
nfw_template = NFWTemplate(nside=128)
npt = nfw_template.get_NFW2_template(gamma=1.2)[~mask]
npt = npt[None, :]
jnp.save("../data/tmp/npt.npy", npt)
k_max = int(jnp.max(data_in))
npixROI = int(jnp.sum(~mask))

# psf
kp = KingPSF()
pc_inst = PSFCorrection(delay_compute=True, num_f_bins=15, nside=128)
pc_inst.psf_r_func = lambda r: kp.psf_fermi_r(r)
pc_inst.sample_psf_max = 10.0 * kp.spe * (kp.score + kp.stail) / 2.0
pc_inst.psf_samples = 10000
pc_inst.psf_tag = "Fermi_PSF_2GeV2_nside128"
pc_inst.make_or_load_psf_corr()
f_arr = pc_inst.f_ary
df_rho_div_f_arr = pc_inst.df_rho_div_f_ary

Loading the psf correction from: ../data/psf_dir/Fermi_PSF_2GeV2_nside128.npy


In [3]:
def jax_run(n):
    keys = [jax.random.PRNGKey(0)]
    for _ in range(n):
        keys = jax.random.split(keys[0], 7)
        Sps = jax.random.uniform(keys[1], shape=(1,), minval=1.0, maxval=2.0)
        n1 = jax.random.uniform(keys[2], shape=(1,), minval=4.0, maxval=6.0)
        n2 = jax.random.uniform(keys[3], shape=(1,), minval=0.5, maxval=1.99)
        n3 = jax.random.uniform(keys[4], shape=(1,), minval=-6.0, maxval=-5.0)
        sb1 = jax.random.uniform(keys[5], shape=(1,), minval=5.0, maxval=10.0)
        lambda_s = jax.random.uniform(keys[6], shape=(1,), minval=0.1, maxval=0.95)
        theta = jnp.array([[Sps, n1, n2, n3, sb1, lambda_s*sb1]])
        log_like_np(theta, pt, npt, data_in, f_arr, df_rho_div_f_arr, k_max, npixROI).block_until_ready()

In [4]:
%time jax_run(5000)

CPU times: user 19 s, sys: 3.92 s, total: 23 s
Wall time: 25.4 s


In [5]:
jax.devices()

[cuda(id=0)]

In [6]:
! nvidia-smi

Tue Jun 18 01:56:06 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100-SXM4-40GB          On  |   00000000:31:00.0 Off |                   On |
| N/A   24C    P0             51W /  400W |   15253MiB /  40960MiB |     N/A      Default |
|                                         |                        |              Enabled |
+-----------------------------------------+------------------------+----------------------+

+----------------------------------------------

  pid, fd = os.forkpty()


+-----------------------------------------------------------------------------------------+


# NPTF

In [1]:
%reload_ext autoreload
%autoreload 2

import sys

import numpy as np
np.float = np.float64

sys.path.append("/n/home07/yitians/fermi/NPTFit/NPTFit")
sys.path.append("/n/home07/yitians/fermi/NPTFit")
from NPTFit import npll

In [2]:
data_in = np.load("../data/tmp/data_in.npy")
pt = np.load("../data/tmp/pt.npy")
npt = np.load("../data/tmp/npt.npy")
k_max = int(np.max(data_in))
npixROI = len(data_in)
f_arr, df_rho_div_f_arr = np.load("../data/psf_dir/Fermi_PSF_2GeV2_nside128.npy")

In [4]:
def nptf_run(n):
    for _ in range(n):
        Sps = np.random.uniform(1.0, 2.0)
        n1 = np.random.uniform(4.0, 6.0)
        n2 = np.random.uniform(0.5, 1.99)
        n3 = np.random.uniform(-6.0, -5.0)
        sb1 = np.random.uniform(5.0, 10.0)
        lambda_s = np.random.uniform(0.1, 0.95)
        theta = np.array([[Sps, n1, n2, n3, sb1, lambda_s*sb1]])
        npll.log_like(pt, theta, f_arr, df_rho_div_f_arr, npt, data_in)

In [6]:
%time nptf_run(500)

CPU times: user 17.4 s, sys: 1.35 s, total: 18.8 s
Wall time: 18.8 s


In [7]:
! lscpu

Architecture:        x86_64
CPU op-mode(s):      32-bit, 64-bit
Byte Order:          Little Endian
CPU(s):              64
On-line CPU(s) list: 0-63
Thread(s) per core:  1
Core(s) per socket:  32
Socket(s):           2
NUMA node(s):        2
Vendor ID:           GenuineIntel
CPU family:          6
Model:               106
Model name:          Intel(R) Xeon(R) Platinum 8358 CPU @ 2.60GHz
Stepping:            6
CPU MHz:             3400.000
CPU max MHz:         3400.0000
CPU min MHz:         800.0000
BogoMIPS:            5200.00
Virtualization:      VT-x
L1d cache:           48K
L1i cache:           32K
L2 cache:            1280K
L3 cache:            49152K
NUMA node0 CPU(s):   0-31
NUMA node1 CPU(s):   32-63
Flags:               fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 

In [9]:
! cat /proc/meminfo

MemTotal:       527460592 kB
MemFree:        439802064 kB
MemAvailable:   515752928 kB
Buffers:            5392 kB
Cached:         76174200 kB
SwapCached:            0 kB
Active:         53588728 kB
Inactive:       26749928 kB
Active(anon):       3920 kB
Inactive(anon):  4228756 kB
Active(file):   53584808 kB
Inactive(file): 22521172 kB
Unevictable:       65104 kB
Mlocked:           65104 kB
SwapTotal:             0 kB
SwapFree:              0 kB
Dirty:                 0 kB
Writeback:            12 kB
AnonPages:       4174964 kB
Mapped:           507100 kB
Shmem:             54628 kB
KReclaimable:    3265276 kB
Slab:            4971488 kB
SReclaimable:    3265276 kB
SUnreclaim:      1706212 kB
KernelStack:       21968 kB
PageTables:        87728 kB
NFS_Unstable:          0 kB
Bounce:                0 kB
WritebackTmp:          0 kB
CommitLimit:    263730296 kB
Committed_AS:    3143888 kB
VmallocTotal:   13743895347199 kB
VmallocUsed:      377996 kB
VmallocChunk:          0 kB
Percpu:   

# JAX CPU

In [2]:
jax.devices()

[CpuDevice(id=0)]

In [3]:
data = jnp.load(f"../data/fermi_data_573w/fermi_data_128/fermidata_counts.npy").astype(jnp.int32)
mask_ps = jnp.load(f"../data/mask_3fgl_0p8deg.npy")
mask = cm.make_mask_total(nside=128, band_mask=True, band_mask_range=2, mask_ring=True, inner=0, outer=25, custom_mask=mask_ps)
data_in = data[~mask]
jnp.save("../data/tmp/data_in.npy", data_in)
pt = jnp.load("../data/fermi_data_573w/fermi_data_128/template_Opi.npy")[~mask]
jnp.save("../data/tmp/pt.npy", pt)
nfw_template = NFWTemplate(nside=128)
npt = nfw_template.get_NFW2_template(gamma=1.2)[~mask]
npt = npt[None, :]
jnp.save("../data/tmp/npt.npy", npt)
k_max = int(jnp.max(data_in))
npixROI = int(jnp.sum(~mask))

# psf
kp = KingPSF()
pc_inst = PSFCorrection(delay_compute=True, num_f_bins=15, nside=128)
pc_inst.psf_r_func = lambda r: kp.psf_fermi_r(r)
pc_inst.sample_psf_max = 10.0 * kp.spe * (kp.score + kp.stail) / 2.0
pc_inst.psf_samples = 10000
pc_inst.psf_tag = "Fermi_PSF_2GeV2_nside128"
pc_inst.make_or_load_psf_corr()
f_arr = pc_inst.f_ary
df_rho_div_f_arr = pc_inst.df_rho_div_f_ary

Loading the psf correction from: ../data/psf_dir/Fermi_PSF_2GeV2_nside128.npy


In [4]:
def jax_run(n):
    keys = [jax.random.PRNGKey(0)]
    for _ in range(n):
        keys = jax.random.split(keys[0], 7)
        Sps = jax.random.uniform(keys[1], shape=(1,), minval=1.0, maxval=2.0)
        n1 = jax.random.uniform(keys[2], shape=(1,), minval=4.0, maxval=6.0)
        n2 = jax.random.uniform(keys[3], shape=(1,), minval=0.5, maxval=1.99)
        n3 = jax.random.uniform(keys[4], shape=(1,), minval=-6.0, maxval=-5.0)
        sb1 = jax.random.uniform(keys[5], shape=(1,), minval=5.0, maxval=10.0)
        lambda_s = jax.random.uniform(keys[6], shape=(1,), minval=0.1, maxval=0.95)
        theta = jnp.array([[Sps, n1, n2, n3, sb1, lambda_s*sb1]])
        log_like_np(theta, pt, npt, data_in, f_arr, df_rho_div_f_arr, k_max, npixROI).block_until_ready()

In [5]:
%time jax_run(100)

CPU times: user 16.2 s, sys: 244 ms, total: 16.5 s
Wall time: 13.2 s


In [6]:
! lscpu

Architecture:        x86_64
CPU op-mode(s):      32-bit, 64-bit
Byte Order:          Little Endian
CPU(s):              112
On-line CPU(s) list: 0-111
Thread(s) per core:  1
Core(s) per socket:  56
Socket(s):           2
NUMA node(s):        2
Vendor ID:           GenuineIntel
CPU family:          6
Model:               143
Model name:          Intel(R) Xeon(R) Platinum 8480CL
Stepping:            7
CPU MHz:             3800.000
CPU max MHz:         3800.0000
CPU min MHz:         800.0000
BogoMIPS:            4000.00
L1d cache:           48K
L1i cache:           32K
L2 cache:            2048K
L3 cache:            107520K
NUMA node0 CPU(s):   0-55
NUMA node1 CPU(s):   56-111
Flags:               fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq dtes64 ds_cpl smx est tm2 

  pid, fd = os.forkpty()


In [7]:
! cat /proc/meminfo

MemTotal:       1056005468 kB
MemFree:        768496272 kB
MemAvailable:   1011436476 kB
Buffers:          105312 kB
Cached:         213106144 kB
SwapCached:            0 kB
Active:         84617888 kB
Inactive:       133548260 kB
Active(anon):     158116 kB
Inactive(anon):  5167096 kB
Active(file):   84459772 kB
Inactive(file): 128381164 kB
Unevictable:       89596 kB
Mlocked:           89596 kB
SwapTotal:             0 kB
SwapFree:              0 kB
Dirty:              1452 kB
Writeback:             0 kB
AnonPages:       5017188 kB
Mapped:          1180648 kB
Shmem:            353648 kB
KReclaimable:   35928132 kB
Slab:           65858696 kB
SReclaimable:   35928132 kB
SUnreclaim:     29930564 kB
KernelStack:       36064 kB
PageTables:        65632 kB
NFS_Unstable:          0 kB
Bounce:                0 kB
WritebackTmp:          0 kB
CommitLimit:    528002732 kB
Committed_AS:    7225472 kB
VmallocTotal:   13743895347199 kB
VmallocUsed:     1107632 kB
VmallocChunk:          0 kB
Percp

In [4]:
def jax_run(n):
    rng_key = jax.random.PRNGKey(0)
    keys = jax.random.split(rng_key, n)

    Sps = jax.random.uniform(keys[0], shape=(n,), minval=1e-5, maxval=2.0)
    n1 = jax.random.uniform(keys[1], shape=(n,), minval=4.0, maxval=6.0)
    n2 = jax.random.uniform(keys[2], shape=(n,), minval=0.5, maxval=1.99)
    n3 = jax.random.uniform(keys[3], shape=(n,), minval=-6.0, maxval=-5.0)
    sb1 = jax.random.uniform(keys[4], shape=(n,), minval=5.0, maxval=40.0)
    lambda_s = jax.random.uniform(keys[5], shape=(n,), minval=0.1, maxval=0.95)

    theta = jnp.stack([Sps, n1, n2, n3, sb1, lambda_s * sb1], axis=1)[:, None, :]

    ll_total = jax.vmap(lambda theta_single: log_like_np(
        theta_single, pt, npt, data_in, f_arr, df_rho_div_f_arr, k_max, npixROI).sum())(theta)

In [6]:
%time jax_run(1000)

CPU times: user 2min 33s, sys: 2.01 s, total: 2min 35s
Wall time: 1min 20s


Array(-inf, dtype=float64)

# NPTF (CPU cluster for comparison)

In [1]:
%reload_ext autoreload
%autoreload 2

import sys

import numpy as np
np.float = np.float64

sys.path.append("/n/home07/yitians/fermi/NPTFit/NPTFit")
sys.path.append("/n/home07/yitians/fermi/NPTFit")
from NPTFit import npll

In [2]:
data_in = np.load("../data/tmp/data_in.npy")
pt = np.load("../data/tmp/pt.npy")
npt = np.load("../data/tmp/npt.npy")
k_max = int(np.max(data_in))
npixROI = len(data_in)
f_arr, df_rho_div_f_arr = np.load("../data/psf_dir/Fermi_PSF_2GeV2_nside128.npy")

In [3]:
def nptf_run(n):
    for _ in range(n):
        Sps = np.random.uniform(1.0, 2.0)
        n1 = np.random.uniform(4.0, 6.0)
        n2 = np.random.uniform(0.5, 1.99)
        n3 = np.random.uniform(-6.0, -5.0)
        sb1 = np.random.uniform(5.0, 10.0)
        lambda_s = np.random.uniform(0.1, 0.95)
        theta = np.array([[Sps, n1, n2, n3, sb1, lambda_s*sb1]])
        npll.log_like(pt, theta, f_arr, df_rho_div_f_arr, npt, data_in)

In [5]:
%time nptf_run(100)

CPU times: user 3.22 s, sys: 272 ms, total: 3.49 s
Wall time: 3.5 s


In [6]:
! lscpu

Architecture:        x86_64
CPU op-mode(s):      32-bit, 64-bit
Byte Order:          Little Endian
CPU(s):              112
On-line CPU(s) list: 0-111
Thread(s) per core:  1
Core(s) per socket:  56
Socket(s):           2
NUMA node(s):        2
Vendor ID:           GenuineIntel
CPU family:          6
Model:               143
Model name:          Intel(R) Xeon(R) Platinum 8480CL
Stepping:            7
CPU MHz:             3800.000
CPU max MHz:         3800.0000
CPU min MHz:         800.0000
BogoMIPS:            4000.00
L1d cache:           48K
L1i cache:           32K
L2 cache:            2048K
L3 cache:            107520K
NUMA node0 CPU(s):   0-55
NUMA node1 CPU(s):   56-111
Flags:               fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq dtes64 ds_cpl smx est tm2 

In [7]:
! cat /proc/meminfo

MemTotal:       1056005468 kB
MemFree:        765468108 kB
MemAvailable:   1011379640 kB
Buffers:          105312 kB
Cached:         216077372 kB
SwapCached:            0 kB
Active:         84682728 kB
Inactive:       136282512 kB
Active(anon):     158188 kB
Inactive(anon):  4994924 kB
Active(file):   84524540 kB
Inactive(file): 131287588 kB
Unevictable:       89588 kB
Mlocked:           89588 kB
SwapTotal:             0 kB
SwapFree:              0 kB
Dirty:                 0 kB
Writeback:           156 kB
AnonPages:       4846024 kB
Mapped:          1002152 kB
Shmem:            353640 kB
KReclaimable:   35928268 kB
Slab:           66090588 kB
SReclaimable:   35928268 kB
SUnreclaim:     30162320 kB
KernelStack:       34336 kB
PageTables:        66440 kB
NFS_Unstable:          0 kB
Bounce:                0 kB
WritebackTmp:          0 kB
CommitLimit:    528002732 kB
Committed_AS:    5963960 kB
VmallocTotal:   13743895347199 kB
VmallocUsed:     1105868 kB
VmallocChunk:          0 kB
Percp