In [1]:
import jax.numpy as jnp
import jax
jax.config.update("jax_enable_x64", True)

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import sys, os
sys.path.insert(0, "/home/storage/hans/jax_reco")

import logging
logging.getLogger("jax").setLevel(logging.ERROR)

In [2]:
from lib.sim_data_i3 import I3SimHandlerFtr
from lib.geo import center_track_pos_and_time_based_on_data

from lib.network import get_network_eval_v_fn
#from lib.network import get_network_matmul_eval_v_fn as get_network_eval_v_fn

from dom_track_eval import get_eval_network_doms_and_track
from time_sampler import sample_times_clean



In [3]:
key = jax.random.PRNGKey(2)

dtype = jnp.float64

In [4]:
eval_network_v = get_network_eval_v_fn(bpath='/home/storage/hans/jax_reco/data/network', dtype=dtype)
eval_network_doms_and_track = get_eval_network_doms_and_track(eval_network_v, dtype=dtype)

In [5]:
bp = '/home/storage2/hans/i3files/21217'
sim_handler = I3SimHandlerFtr(os.path.join(bp, 'meta_ds_21217_from_35000_to_53530.ftr'),
                              os.path.join(bp, 'pulses_ds_21217_from_35000_to_53530.ftr'),
                              '/home/storage/hans/jax_reco/data/icecube/detector_geometry.csv')

# Get a simulated muon event.
event_index = 5
meta, pulses = sim_handler.get_event_data(event_index)
print(f"muon energy: {meta['muon_energy_at_detector']/1.e3:.1f} TeV")

# Get dom locations, first hit times, and total charges (for each dom).
event_data = sim_handler.get_per_dom_summary_from_sim_data(meta, pulses)

print("n_doms", len(event_data))

muon energy: 3.4 TeV
n_doms 43


In [6]:
# Let's generate some new first hit times following our triple pandel model.
# (avoid problems with time smearing for now -> to be implemented: gaussian convoluted triple pandel.)

track_pos = jnp.array([meta['muon_pos_x'], meta['muon_pos_y'], meta['muon_pos_z']])
track_time = meta['muon_time']
track_zenith = meta['muon_zenith']
track_azimuth = meta['muon_azimuth']
track_src = jnp.array([track_zenith, track_azimuth])

print("old track vertex:", track_pos)
centered_track_pos, centered_track_time = center_track_pos_and_time_based_on_data(event_data, track_pos, track_time, track_src)
print("new track vertex:", centered_track_pos)

key, subkey = jax.random.split(key)
first_times = sample_times_clean(eval_network_doms_and_track, event_data, track_pos, track_src, track_time, subkey)

old track vertex: [ 143.40123598  106.94081331 -815.20509865]
new track vertex: [ -25.42355091  403.14288541 -154.63778486]


In [7]:
# Create some n_photons from qtot (by rounding up).
n_photons = np.round(event_data['charge'].to_numpy()+0.5)

# Combine into single data tensor for fitting.
fake_event_data = jnp.column_stack([jnp.array(event_data[['x', 'y', 'z']].to_numpy()), 
                                    jnp.array(first_times), 
                                    jnp.array(n_photons)])
print(fake_event_data.shape)

# Send to GPU.
fake_event_data.devices()
centered_track_pos.devices()
centered_track_time.devices()
track_src.devices()

(43, 5)


{cuda(id=0)}

In [8]:
from likelihood_const_vertex_vectorized import get_neg_mpe_llh_const_vertex_v2
neg_llh = get_neg_mpe_llh_const_vertex_v2(eval_network_doms_and_track)

In [16]:
from jax.scipy import optimize

@jax.jit
def minimize(x0, track_vertex, track_time, event_data):
    return optimize.minimize(neg_llh, x0, args=(track_vertex, track_time, event_data), method="BFGS", tol=1.e-5).x

result = minimize(track_src, centered_track_pos, centered_track_time, fake_event_data)

In [39]:
%timeit -n5 minimize(track_src, centered_track_pos, centered_track_time, fake_event_data).block_until_ready()

210 ms ± 1.2 ms per loop (mean ± std. dev. of 7 runs, 5 loops each)


In [17]:
print(result)

[2.66344555 5.23024581]


In [18]:
print(track_src)

[2.66512738 5.23043102]


In [19]:
#from likelihood_const_vertex import get_neg_mpe_llh_const_vertex
#neg_llh = get_neg_mpe_llh_const_vertex(eval_network_doms_and_track, fake_event_data, centered_track_pos, centered_track_time)

In [20]:
#from jax.scipy import optimize
#
#@jax.jit
#def minimize(x0):
#    return optimize.minimize(neg_llh, x0, method="BFGS", tol=1.e-5)

#result = minimize(track_src)

In [21]:
print(fake_event_data.shape)

(43, 5)


In [41]:
minimize_v = jax.jit(jax.vmap(minimize, (0, 0, 0, 0), 0))

In [50]:
n_ev = 1000

centered_track_pos_v = jnp.repeat(centered_track_pos.reshape(1,3), n_ev, axis=0)
print(centered_track_pos_v.shape)

centered_track_time_v = jnp.repeat(centered_track_time, n_ev)
print(centered_track_time_v.shape)

track_src_v = jnp.repeat(track_src.reshape(1,2), n_ev, axis=0)
print(track_src_v.shape)

fake_event_data_v = jnp.repeat(fake_event_data.reshape(1, fake_event_data.shape[0], fake_event_data.shape[1]), n_ev, axis=0)
print(fake_event_data_v.shape)

(1000, 3)
(1000,)
(1000, 2)
(1000, 43, 5)


In [47]:
result_v = minimize_v(track_src_v, centered_track_pos_v, centered_track_time_v, fake_event_data_v)

In [48]:
print(result_v)

[[2.66344555 5.23024581]
 [2.66344555 5.23024581]
 [2.66344555 5.23024581]
 [2.66344555 5.23024581]
 [2.66344555 5.23024581]
 [2.66344555 5.23024581]
 [2.66344555 5.23024581]
 [2.66344555 5.23024581]
 [2.66344555 5.23024581]
 [2.66344555 5.23024581]
 [2.66344555 5.23024581]
 [2.66344555 5.23024581]
 [2.66344555 5.23024581]
 [2.66344555 5.23024581]
 [2.66344555 5.23024581]
 [2.66344555 5.23024581]
 [2.66344555 5.23024581]
 [2.66344555 5.23024581]
 [2.66344555 5.23024581]
 [2.66344555 5.23024581]
 [2.66344555 5.23024581]
 [2.66344555 5.23024581]
 [2.66344555 5.23024581]
 [2.66344555 5.23024581]
 [2.66344555 5.23024581]
 [2.66344555 5.23024581]
 [2.66344555 5.23024581]
 [2.66344555 5.23024581]
 [2.66344555 5.23024581]
 [2.66344555 5.23024581]
 [2.66344555 5.23024581]
 [2.66344555 5.23024581]
 [2.66344555 5.23024581]
 [2.66344555 5.23024581]
 [2.66344555 5.23024581]
 [2.66344555 5.23024581]
 [2.66344555 5.23024581]
 [2.66344555 5.23024581]
 [2.66344555 5.23024581]
 [2.66344555 5.23024581]


In [51]:
%timeit -n5 minimize_v(track_src_v, centered_track_pos_v, centered_track_time_v, fake_event_data_v).block_until_ready()

2.17 s ± 336 ms per loop (mean ± std. dev. of 7 runs, 5 loops each)


In [52]:
# 2.17s for 1000 event reconstructions