In [21]:
#!/usr/bin/env python

from iminuit import Minuit
import optimistix as optx
import lineax as lx

import sys, os
sys.path.insert(0, "/home/storage/hans/jax_reco")
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import jax.numpy as jnp
import jax
jax.config.update("jax_enable_x64", True)
from jax.scipy import optimize

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# TriplePandelSPE/JAX stuff
from lib.simdata_i3 import I3SimHandlerFtr
from lib.geo import center_track_pos_and_time_based_on_data
from lib.network import get_network_eval_v_fn
from dom_track_eval import get_eval_network_doms_and_track
from likelihood_spe import get_neg_c_triple_gamma_llh
from likelihood_spe import get_llh_and_grad_fs_for_iminuit_migrad

from palettable.cubehelix import Cubehelix
cx =Cubehelix.make(start=0.3, rotation=-0.5, n=16, reverse=False, gamma=1.0,
                           max_light=1.0,max_sat=0.5, min_sat=1.4).get_mpl_colormap()

# Number of scan points on 1D
n_eval = 50 # making it a 100x100 grid

# Scan range (truth +/- dzen, +/- dazi)
dzen = 0.03 # rad
dazi = 0.03 # rad

In [2]:
# Event Index.
event_index = 3

# Get network and eval logic.
eval_network_v = get_network_eval_v_fn(bpath='/home/storage/hans/jax_reco/data/network')
eval_network_doms_and_track = get_eval_network_doms_and_track(eval_network_v)

# Get an IceCube event.
bp = '/home/storage2/hans/i3files/21217'
sim_handler = I3SimHandlerFtr(os.path.join(bp, 'meta_ds_21217_from_35000_to_53530.ftr'),
                              os.path.join(bp, 'pulses_ds_21217_from_35000_to_53530.ftr'),
                              '/home/storage/hans/jax_reco/data/icecube/detector_geometry.csv')

meta, pulses = sim_handler.get_event_data(event_index)
print(f"muon energy: {meta['muon_energy_at_detector']/1.e3:.1f} TeV")

# Get dom locations, first hit times, and total charges (for each dom).
event_data = sim_handler.get_per_dom_summary_from_sim_data(meta, pulses)

print("n_doms", len(event_data))

# Make MCTruth seed.
track_pos = jnp.array([meta['muon_pos_x'], meta['muon_pos_y'], meta['muon_pos_z']])
track_time = meta['muon_time']
track_zenith = meta['muon_zenith']
track_azimuth = meta['muon_azimuth']
track_src = jnp.array([track_zenith, track_azimuth])

print("original seed vertex:", track_pos)
centered_track_pos, centered_track_time = center_track_pos_and_time_based_on_data(event_data, track_pos, track_time, track_src)
print("shifted seed vertex:", centered_track_pos)

# Create some n_photons from qtot (by rounding up).
n_photons = np.round(event_data['charge'].to_numpy()+0.5)

muon energy: 2.2 TeV
n_doms 58
original seed vertex: [ 1180.18566012 -1499.16735802  -782.32800156]
shifted seed vertex: [-38.65768538 207.2454018  145.80801123]


In [3]:
# Combine into single data tensor for fitting.
fitting_event_data = jnp.array(event_data[['x', 'y', 'z', 'time']].to_numpy())

obj_fn, obj_grad = get_llh_and_grad_fs_for_iminuit_migrad(eval_network_doms_and_track)

# put the thing below into a for loop if you want to reconstruct many events (without jit-recompiling everything)
f_prime = lambda x: obj_fn(x, centered_track_time, fitting_event_data)
grad_prime = lambda x: obj_grad(x, centered_track_time, fitting_event_data)

x0 = jnp.concatenate([track_src, centered_track_pos])
m = Minuit(f_prime, x0, grad=grad_prime)
m.errordef = Minuit.LIKELIHOOD
m.limits = ((0.0, np.pi), (0.0, 2.0 * np.pi), (-500.0, 500.0),  (-500.0, 500.0),  (-500.0, 500.0))
m.strategy = 0
m.migrad()

print("... solution found.")
print(f"-2*logl={m.fval:.3f}")
print(f"zenith={m.values[0]:.3f}rad")
print(f"azimuth={m.values[1]:.3f}rad")
print(f"x={m.values[2]:.3f}m")
print(f"y={m.values[3]:.3f}m")
print(f"z={m.values[4]:.3f}m")
print(f"at fix time t={centered_track_time:.3f}ns")

... solution found.
-2*logl=737.833
zenith=1.990rad
azimuth=5.346rad
x=-38.957m
y=209.023m
z=145.781m
at fix time t=12375.863ns


In [56]:
# Setup likelihood
neg_llh = get_neg_c_triple_gamma_llh(eval_network_doms_and_track)

scale = 20.0
@jax.jit
def neg_llh_5D(x, args):
        return neg_llh(x[:2]/scale, x[2:]*scale, centered_track_time, fitting_event_data)

In [57]:
bounds = {'lower': jnp.array([0.0*scale, 0.0*scale, -700.0/scale, -700.0/scale, -700.0/scale]),
'upper': jnp.array([jnp.pi*scale, 2.0*jnp.pi*scale, 700.0/scale, 700.0/scale, 700.0/scale])}

newton = optx.Newton(rtol=1e-6, atol=1e-3)
neg_llh_grad = jax.jit(jax.grad(neg_llh_5D))

x0 = jnp.concatenate([track_src*scale, centered_track_pos/scale])
result = optx.root_find(neg_llh_grad, newton, x0,  options=bounds)

In [58]:
new_result = jnp.concatenate([result.value[:2]/scale, result.value[2:]*scale])
print(new_result)
print(neg_llh_5D(result.value, None))

[  1.99009918   5.3461513  -38.95568903 209.02091478 145.78149786]
737.8325328986425


In [59]:
%timeit optx.root_find(neg_llh_grad, newton, x0,  options=bounds)

19.7 ms ± 96.5 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [75]:
bounds = {'lower': jnp.array([-700.0/scale, -700.0/scale, -700.0/scale]),
'upper': jnp.array([700.0/scale, 700.0/scale, 700.0/scale])}
x0 = jnp.array(centered_track_pos/scale)

newton = optx.Newton(rtol=1e-6, atol=1e-3, linear_solver=lx.SVD())

@jax.jit
def neg_llh_3D(x, track_dir):
    return neg_llh(track_dir, x*scale, centered_track_time, fitting_event_data)

neg_llh_grad_3D = jax.jit(jax.grad(neg_llh_3D))

def run_3D(track_dir):
    values = optx.root_find(neg_llh_grad_3D, newton, x0, args=track_dir, options=bounds, max_steps=1000).value
    return neg_llh_3D(values, track_dir)

run_3D_v = jax.jit(jax.vmap(run_3D, 0, 0))

In [76]:
run_3D(track_src)

Array(740.68839203, dtype=float64)

In [77]:
%timeit run_3D(track_src)

46.2 ms ± 54.6 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [78]:
import time

time1 = time.time()
# Number of scan points on 1D
n_eval = 10 # making it a 30x30 grid

# Scan range (truth +/- dzen, +/- dazi)
dzen = 0.03 # rad
dazi = 0.03 # rad

zenith = jnp.linspace(track_src[0]-dzen, track_src[0]+dazi, n_eval)
azimuth = jnp.linspace(track_src[1]-dzen, track_src[1]+dazi, n_eval)
X, Y = jnp.meshgrid(zenith, azimuth)
init_dirs = jnp.column_stack([X.flatten(), Y.flatten()])

logls = run_3D_v(init_dirs)

logls = logls.reshape(X.shape)
time2 = time.time()

print(f"elapsed: {time2-time1}s")

jax.pure_callback failed
Traceback (most recent call last):
  File "/home/hans/.local/share/virtualenvs/py3_jax_latest-mr9UFGRS/lib/python3.11/site-packages/jax/_src/callback.py", line 79, in pure_callback_impl
    return tree_util.tree_map(np.asarray, callback(*args))
                                          ^^^^^^^^^^^^^^^
  File "/home/hans/.local/share/virtualenvs/py3_jax_latest-mr9UFGRS/lib/python3.11/site-packages/jax/_src/callback.py", line 64, in __call__
    return tree_util.tree_leaves(self.callback_func(*args, **kwargs))
                                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/hans/.local/share/virtualenvs/py3_jax_latest-mr9UFGRS/lib/python3.11/site-packages/equinox/_errors.py", line 70, in raises
    raise EqxRuntimeError(msgs[_index.item()])
equinox._errors.EqxRuntimeError: The maximum number of steps was reached in the nonlinear solver. The problem may not be solveable (e.g., a root-find on a function that has no roots), or you may need to incre

XlaRuntimeError: INTERNAL: CustomCall failed: CpuCallback error: Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/home/hans/.local/share/virtualenvs/py3_jax_latest-mr9UFGRS/lib/python3.11/site-packages/ipykernel_launcher.py", line 18, in <module>
  File "/home/hans/.local/share/virtualenvs/py3_jax_latest-mr9UFGRS/lib/python3.11/site-packages/traitlets/config/application.py", line 1075, in launch_instance
  File "/home/hans/.local/share/virtualenvs/py3_jax_latest-mr9UFGRS/lib/python3.11/site-packages/ipykernel/kernelapp.py", line 739, in start
  File "/home/hans/.local/share/virtualenvs/py3_jax_latest-mr9UFGRS/lib/python3.11/site-packages/tornado/platform/asyncio.py", line 205, in start
  File "/home/hans/.pyenv/versions/3.11.5/lib/python3.11/asyncio/base_events.py", line 607, in run_forever
  File "/home/hans/.pyenv/versions/3.11.5/lib/python3.11/asyncio/base_events.py", line 1922, in _run_once
  File "/home/hans/.pyenv/versions/3.11.5/lib/python3.11/asyncio/events.py", line 80, in _run
  File "/home/hans/.local/share/virtualenvs/py3_jax_latest-mr9UFGRS/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 545, in dispatch_queue
  File "/home/hans/.local/share/virtualenvs/py3_jax_latest-mr9UFGRS/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 534, in process_one
  File "/home/hans/.local/share/virtualenvs/py3_jax_latest-mr9UFGRS/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 437, in dispatch_shell
  File "/home/hans/.local/share/virtualenvs/py3_jax_latest-mr9UFGRS/lib/python3.11/site-packages/ipykernel/ipkernel.py", line 362, in execute_request
  File "/home/hans/.local/share/virtualenvs/py3_jax_latest-mr9UFGRS/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 778, in execute_request
  File "/home/hans/.local/share/virtualenvs/py3_jax_latest-mr9UFGRS/lib/python3.11/site-packages/ipykernel/ipkernel.py", line 449, in do_execute
  File "/home/hans/.local/share/virtualenvs/py3_jax_latest-mr9UFGRS/lib/python3.11/site-packages/ipykernel/zmqshell.py", line 549, in run_cell
  File "/home/hans/.local/share/virtualenvs/py3_jax_latest-mr9UFGRS/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3075, in run_cell
  File "/home/hans/.local/share/virtualenvs/py3_jax_latest-mr9UFGRS/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3130, in _run_cell
  File "/home/hans/.local/share/virtualenvs/py3_jax_latest-mr9UFGRS/lib/python3.11/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
  File "/home/hans/.local/share/virtualenvs/py3_jax_latest-mr9UFGRS/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3334, in run_cell_async
  File "/home/hans/.local/share/virtualenvs/py3_jax_latest-mr9UFGRS/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3517, in run_ast_nodes
  File "/home/hans/.local/share/virtualenvs/py3_jax_latest-mr9UFGRS/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3577, in run_code
  File "/tmp/ipykernel_20568/2394509584.py", line 16, in <module>
  File "/home/hans/.local/share/virtualenvs/py3_jax_latest-mr9UFGRS/lib/python3.11/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
  File "/home/hans/.local/share/virtualenvs/py3_jax_latest-mr9UFGRS/lib/python3.11/site-packages/jax/_src/pjit.py", line 304, in cache_miss
  File "/home/hans/.local/share/virtualenvs/py3_jax_latest-mr9UFGRS/lib/python3.11/site-packages/jax/_src/pjit.py", line 181, in _python_pjit_helper
  File "/home/hans/.local/share/virtualenvs/py3_jax_latest-mr9UFGRS/lib/python3.11/site-packages/jax/_src/core.py", line 2789, in bind
  File "/home/hans/.local/share/virtualenvs/py3_jax_latest-mr9UFGRS/lib/python3.11/site-packages/jax/_src/core.py", line 391, in bind_with_trace
  File "/home/hans/.local/share/virtualenvs/py3_jax_latest-mr9UFGRS/lib/python3.11/site-packages/jax/_src/core.py", line 879, in process_primitive
  File "/home/hans/.local/share/virtualenvs/py3_jax_latest-mr9UFGRS/lib/python3.11/site-packages/jax/_src/pjit.py", line 1525, in _pjit_call_impl
  File "/home/hans/.local/share/virtualenvs/py3_jax_latest-mr9UFGRS/lib/python3.11/site-packages/jax/_src/pjit.py", line 1508, in call_impl_cache_miss
  File "/home/hans/.local/share/virtualenvs/py3_jax_latest-mr9UFGRS/lib/python3.11/site-packages/jax/_src/pjit.py", line 1462, in _pjit_call_impl_python
  File "/home/hans/.local/share/virtualenvs/py3_jax_latest-mr9UFGRS/lib/python3.11/site-packages/jax/_src/profiler.py", line 335, in wrapper
  File "/home/hans/.local/share/virtualenvs/py3_jax_latest-mr9UFGRS/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 1177, in __call__
  File "/home/hans/.local/share/virtualenvs/py3_jax_latest-mr9UFGRS/lib/python3.11/site-packages/jax/_src/interpreters/mlir.py", line 2483, in _wrapped_callback
  File "/home/hans/.local/share/virtualenvs/py3_jax_latest-mr9UFGRS/lib/python3.11/site-packages/jax/_src/callback.py", line 221, in _callback
  File "/home/hans/.local/share/virtualenvs/py3_jax_latest-mr9UFGRS/lib/python3.11/site-packages/jax/_src/callback.py", line 82, in pure_callback_impl
  File "/home/hans/.local/share/virtualenvs/py3_jax_latest-mr9UFGRS/lib/python3.11/site-packages/jax/_src/callback.py", line 64, in __call__
  File "/home/hans/.local/share/virtualenvs/py3_jax_latest-mr9UFGRS/lib/python3.11/site-packages/equinox/_errors.py", line 70, in raises
EqxRuntimeError: The maximum number of steps was reached in the nonlinear solver. The problem may not be solveable (e.g., a root-find on a function that has no roots), or you may need to increase `max_steps`.