In [1]:
import jax.numpy as jnp
import jax
from jax.scipy import optimize
jax.config.update("jax_enable_x64", True)
from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions

import numpy as np
import scipy
import matplotlib.pyplot as plt

import sys, os



In [2]:
# Make the jax_reco project importable (it provides the `lib` package used in later cells).
# Guarded so re-running this cell does not keep prepending duplicate sys.path entries,
# while still guaranteeing the project directory has top import priority.
JAX_RECO_DIR = "/home/storage/hans/jax_reco"
if sys.path[:1] != [JAX_RECO_DIR]:
    sys.path.insert(0, JAX_RECO_DIR)
from lib.cgamma import c_multi_gamma_prob, c_multi_gamma_prob_v

In [3]:
from lib.plotting import adjust_plot_1d
from lib.network import get_network_eval_fn
from lib.geo import get_xyz_from_zenith_azimuth
from lib.trafos import transform_network_outputs, transform_network_inputs

In [4]:
# Single (non-vectorized) network evaluator.
# NOTE(review): `eval_network` is not referenced by the later cells shown here; the
# vectorized evaluator built in a following cell is what the analysis uses.
eval_network = get_network_eval_fn(
    bpath="/home/storage/hans/jax_reco/data/network",
)

In [8]:
from dom_track_eval import get_eval_network_doms_and_track
from time_sampler import sample_times_clean
from lib.simdata_i3 import I3SimHandlerFtr
from lib.geo import center_track_pos_and_time_based_on_data
from lib.network import get_network_eval_v_fn

# Build the vectorized network evaluator, then wrap it into a function that takes
# DOM positions plus a track hypothesis (vertex, direction) and returns the network
# outputs together with a geometric time for every DOM.
eval_network_v = get_network_eval_v_fn(
    bpath="/home/storage/hans/jax_reco/data/network",
)
eval_network_doms_and_track = get_eval_network_doms_and_track(eval_network_v)

In [9]:
# Load one simulated IceCube event (dataset 21217) and reduce it to per-DOM summaries.
event_index = 0
# Seed the PRNG with the event index so the time sampling further down is reproducible per event.
key = jax.random.PRNGKey(event_index)

bp = '/home/storage2/hans/i3files/21217'
sim_handler = I3SimHandlerFtr(
    os.path.join(bp, 'meta_ds_21217_from_35000_to_53530.ftr'),
    os.path.join(bp, 'pulses_ds_21217_from_35000_to_53530.ftr'),
    '/home/storage/hans/jax_reco/data/icecube/detector_geometry.csv',
)
meta, pulses = sim_handler.get_event_data(event_index)

print(f"muon energy: {meta['muon_energy_at_detector']/1.e3:.1f} TeV")

# Per-DOM locations, first hit times and total charges.
event_data = sim_handler.get_per_dom_summary_from_sim_data(meta, pulses)

muon energy: 2.1 TeV


In [10]:
# Draw fresh first-hit times from our triple-Pandel model.
# Time smearing is deliberately left out for now; a Gaussian-convolved
# triple Pandel is still to be implemented.
track_zenith = meta['muon_zenith']
track_azimuth = meta['muon_azimuth']
track_src = jnp.array([track_zenith, track_azimuth])
track_time = meta['muon_time']
track_pos = jnp.array([meta[k] for k in ('muon_pos_x', 'muon_pos_y', 'muon_pos_z')])

print("old track vertex:", track_pos)
centered_track_pos, centered_track_time = center_track_pos_and_time_based_on_data(
    event_data, track_pos, track_time, track_src
)
print("new track vertex:", centered_track_pos)

# NOTE(review): sampling uses the original (uncentered) vertex and time, whereas the
# likelihood evaluation in a later cell uses the centered ones. Presumably equivalent if
# centering only slides the vertex along the track (with a matching time shift) — confirm.
key, subkey = jax.random.split(key)
first_times = sample_times_clean(
    eval_network_doms_and_track, event_data, track_pos, track_src, track_time, subkey
)

old track vertex: [-1277.51128861 -1390.39564543 -1675.98024553]
new track vertex: [ -53.74394146  162.12452256 -233.73599134]


In [11]:
# Turn per-DOM charges into integer photon counts by rounding up.
# np.ceil rather than np.round(charge + 0.5): the two agree for every non-integer
# charge, but np.round rounds exact halves to even, so integer charges were treated
# inconsistently (1.0 -> 2, 2.0 -> 2, 3.0 -> 4). With ceil, any positive charge gives
# at least one photon.
n_photons = np.ceil(event_data['charge'].to_numpy())

# Fake event: one row per DOM with columns [x, y, z, sampled first-hit time, n_photons].
fake_event_data = jnp.column_stack([
    jnp.array(event_data[['x', 'y', 'z']].to_numpy()),
    jnp.array(first_times),
    jnp.array(n_photons),
])

In [12]:
# Evaluate the network for every DOM (columns 0-2 = positions) against the centered track.
logits, av, bv, geo_time = eval_network_doms_and_track(
    fake_event_data[:, :3], centered_track_pos, track_src
)

# Delay of each sampled first hit (column 3) relative to the track's geometric time at that DOM.
first_hit_times = fake_event_data[:, 3]
delay_time = first_hit_times - (geo_time + centered_track_time)
print(delay_time.shape)

(29,)


In [14]:
# Mixture weights from the network logits: softmax over the last axis (jax's default).
mix_probs = jax.nn.softmax(logits, axis=-1)
# NOTE(review): 3.0 and 10.0 are passed straight through to lib.cgamma; their meaning
# is not visible here — consider promoting them to named constants.
print(c_multi_gamma_prob_v(delay_time, mix_probs, av, bv, 3.0, 10.0))

[0.00047225 0.00471289 0.00538641 0.00126069 0.011568   0.00558727
 0.00132366 0.00179196 0.00106369 0.00178111 0.01401595 0.05355214
 0.08541373 0.02755797 0.00031838 0.00036081 0.0009707  0.00020808
 0.00198526 0.00162107 0.0304282  0.00363625 0.00389488 0.00035686
 0.00100218 0.00320803 0.00405583 0.00410189 0.00135668]


In [17]:
%timeit c_multi_gamma_prob_v(delay_time, mix_probs, av, bv, 3.0, 10.0)

8.86 ms ± 22.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [18]:
# Scalar (single-DOM) version on the first DOM, as a cross-check of the vectorized
# c_multi_gamma_prob_v output's first entry.
print(
    c_multi_gamma_prob(
        delay_time[0], mix_probs[0], av[0], bv[0], 3.0, 10.0
    )
)

0.0004722542195556937


In [19]:
%timeit c_multi_gamma_prob(delay_time[0], mix_probs[0], av[0], bv[0], 3.0, 10.0)

1.07 ms ± 23.4 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
