In [1]:
import jax.numpy as jnp
import jax
from jax.scipy import optimize
jax.config.update("jax_enable_x64", True)
from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions

import numpy as np
import scipy
import matplotlib.pyplot as plt

import sys, os



In [2]:
sys.path.insert(0, "/home/storage/hans/jax_reco")
from lib.cgamma import c_gamma_prob

In [3]:
from lib.plotting_tools import adjust_plot_1d
from lib.network import get_network_eval_fn
from lib.geo import get_xyz_from_zenith_azimuth
from lib.trafos import transform_network_outputs, transform_network_inputs

In [4]:
eval_network = get_network_eval_fn(bpath='/home/storage/hans/jax_reco/data/network')

In [5]:
# Set choice of parameters.

dist = 10
z = -210
rho = 0.0
zenith = np.pi/2
azimuth = 0.0

x = jnp.array([dist, rho, z, zenith, azimuth])
x_prime = transform_network_inputs(x)
y = eval_network(x_prime)
logits, gamma_a, gamma_b = transform_network_outputs(y)

In [7]:
print(c_gamma_prob(0.0, gamma_a, gamma_b))

[0.01023211 0.05535401 0.00169289]


In [10]:
n_ev = 10
gamma_a_v = jnp.repeat(gamma_a.reshape((1,3)), n_ev, axis=0)
gamma_b_v = jnp.repeat(gamma_b.reshape((1,3)), n_ev, axis=0)

In [12]:
print(c_gamma_prob(jnp.linspace(0.0, 10.0, n_ev).reshape((10, 1)), gamma_a_v, gamma_b_v))

[[0.01023211 0.05535401 0.00169289]
 [0.01383071 0.07589159 0.00217874]
 [0.01732342 0.09372509 0.00260039]
 [0.02034257 0.10471538 0.0029184 ]
 [0.02267333 0.10635969 0.0031257 ]
 [0.02427747 0.09874732 0.00324088]
 [0.02524454 0.08430729 0.00329317]
 [0.0257174  0.06662201 0.0033094 ]
 [0.02583476 0.04906419 0.00330805]
 [0.02570489 0.03391365 0.00329938]]


In [31]:
from dom_track_eval import get_eval_network_doms_and_track
from time_sampler import sample_times_clean
from lib.sim_data_i3 import I3SimHandlerFtr
from lib.geo import center_track_pos_and_time_based_on_data
from lib.network import get_network_eval_v_fn

# Get network and eval logic.
eval_network_v = get_network_eval_v_fn(bpath='/home/storage/hans/jax_reco/data/network')
eval_network_doms_and_track = get_eval_network_doms_and_track(eval_network_v)

In [32]:
# Get an IceCube event.
event_index = 0
key = jax.random.PRNGKey(event_index)

bp = '/home/storage2/hans/i3files/21217'
sim_handler = I3SimHandlerFtr(os.path.join(bp, 'meta_ds_21217_from_35000_to_53530.ftr'),
                              os.path.join(bp, 'pulses_ds_21217_from_35000_to_53530.ftr'),
                              '/home/storage/hans/jax_reco/data/icecube/detector_geometry.csv')

meta, pulses = sim_handler.get_event_data(event_index)
print(f"muon energy: {meta['muon_energy_at_detector']/1.e3:.1f} TeV")

# Get dom locations, first hit times, and total charges (for each dom).
event_data = sim_handler.get_per_dom_summary_from_sim_data(meta, pulses)

muon energy: 2.1 TeV


In [33]:
# Let's generate some new first hit times following our triple pandel model.
# (avoid problems with time smearing for now -> to be implemented: gaussian convoluted triple pandel.)
track_pos = jnp.array([meta['muon_pos_x'], meta['muon_pos_y'], meta['muon_pos_z']])
track_time = meta['muon_time']
track_zenith = meta['muon_zenith']
track_azimuth = meta['muon_azimuth']
track_src = jnp.array([track_zenith, track_azimuth])

print("old track vertex:", track_pos)
centered_track_pos, centered_track_time = center_track_pos_and_time_based_on_data(event_data, track_pos, track_time, track_src)
print("new track vertex:", centered_track_pos)

key, subkey = jax.random.split(key)
first_times = sample_times_clean(eval_network_doms_and_track, event_data, track_pos, track_src, track_time, subkey)


old track vertex: [-1277.51128861 -1390.39564543 -1675.98024553]
new track vertex: [ -53.74394146  162.12452256 -233.73599134]


In [35]:
n_photons = np.round(event_data['charge'].to_numpy()+0.5)
fake_event_data = jnp.column_stack([
                                        jnp.array(event_data[['x', 'y', 'z']].to_numpy()),
                                        jnp.array(first_times),
                                        jnp.array(n_photons)
                                   ])

In [42]:
logits, av, bv, geo_time = eval_network_doms_and_track(fake_event_data[:,:3], 
                                                          centered_track_pos, 
                                                          track_src)

first_hit_times = fake_event_data[:, 3]
delay_time = first_hit_times - (geo_time + centered_track_time)
print(delay_time.shape)

(29,)


In [76]:
# now test triple pandel with convolution

c_gamma_probs = c_gamma_prob(delay_time.reshape(delay_time.shape[0], 1), av, bv)
print(c_gamma_probs.shape)

(29, 3)


In [77]:
print(c_gamma_probs)

[[6.25946769e-04 2.61379252e-08 7.78013696e-04]
 [2.03995672e-03 1.13777375e-02 4.67251774e-04]
 [2.75403582e-03 1.20880631e-02 5.16731684e-04]
 [2.15274726e-03 5.04793260e-04 6.39291571e-04]
 [3.60529834e-03 2.40529777e-02 1.21162800e-03]
 [6.60763833e-03 7.06583491e-03 1.45530919e-03]
 [2.64308705e-03 1.99627758e-05 1.07430285e-03]
 [1.04124954e-03 4.27634905e-03 2.50331667e-04]
 [4.09051876e-04 2.72004325e-03 1.30159442e-04]
 [6.55117270e-04 4.71247893e-03 2.19989731e-04]
 [4.52236530e-03 2.69873404e-02 1.64063070e-03]
 [1.83825025e-02 8.47790634e-02 2.21541717e-03]
 [2.44095178e-02 1.16694012e-01 2.00141696e-03]
 [8.56929425e-03 4.60298337e-02 1.80495396e-03]
 [2.26763300e-04 9.89077526e-16 1.21013072e-03]
 [3.48303443e-04 5.79982163e-09 8.08234637e-04]
 [1.65240942e-03 3.51729667e-06 1.10745912e-03]
 [1.67570975e-07 1.02344299e-26 6.84466157e-04]
 [4.55137966e-03 2.80682177e-05 1.37584337e-03]
 [2.51985227e-03 6.19725052e-04 9.56105590e-04]
 [1.29305602e-02 6.29070159e-02 4.716109

In [67]:
%timeit c_gamma_prob(delay_time.reshape(delay_time.shape[0], 1), av, bv)

8.85 ms ± 2.91 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [78]:
c_gamma_prob_v = jax.jit(jax.vmap(c_gamma_prob, (0, 0, 0, None, None), 0))

In [79]:
print(c_gamma_prob_v(delay_time, av, bv, 3.0, 10.0))

[[6.25946769e-04 2.61379252e-08 7.78013696e-04]
 [2.03995672e-03 1.13777375e-02 4.67251774e-04]
 [2.75403582e-03 1.20880631e-02 5.16731684e-04]
 [2.15274726e-03 5.04793260e-04 6.39291571e-04]
 [3.60529834e-03 2.40529777e-02 1.21162800e-03]
 [6.60763833e-03 7.06583491e-03 1.45530919e-03]
 [2.64308705e-03 1.99627758e-05 1.07430285e-03]
 [1.04124954e-03 4.27634905e-03 2.50331667e-04]
 [4.09051876e-04 2.72004325e-03 1.30159442e-04]
 [6.55117270e-04 4.71247893e-03 2.19989731e-04]
 [4.52236530e-03 2.69873404e-02 1.64063070e-03]
 [1.83825025e-02 8.47790634e-02 2.21541717e-03]
 [2.44095178e-02 1.16694012e-01 2.00141696e-03]
 [8.56929425e-03 4.60298337e-02 1.80495396e-03]
 [2.26763300e-04 9.89077526e-16 1.21013072e-03]
 [3.48303443e-04 5.79982163e-09 8.08234637e-04]
 [1.65240942e-03 3.51729667e-06 1.10745912e-03]
 [1.67570975e-07 1.02344299e-26 6.84466157e-04]
 [4.55137966e-03 2.80682177e-05 1.37584337e-03]
 [2.51985227e-03 6.19725052e-04 9.56105590e-04]
 [1.29305602e-02 6.29070159e-02 4.716109

In [65]:
print(c_gamma_prob_v(delay_time, av, bv, 3.0, 10.0).shape)

(29, 3)


In [66]:
%timeit c_gamma_prob_v(delay_time, av, bv, 3.0, 10.0)

8.87 ms ± 27.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [50]:
print(c_gamma_probs.shape)
print(logits.shape)
mix_probs = jax.nn.softmax(logits)

(29, 3)
(29, 3)


In [51]:
mix_probs.shape

(29, 3)

In [57]:
-2*jnp.sum(jnp.log(jnp.sum(mix_probs * c_gamma_probs, axis=-1)))

Array(341.65560723, dtype=float64)

In [91]:
@jax.jit
def c_triple_gamma_prob(x, mix_probs, a, b, sigma, delta):
    g_probs = c_gamma_prob_v(x, a, b, sigma, delta)
    return jnp.sum(mix_probs * g_probs, axis=-1)

In [92]:
print(c_triple_gamma_prob(delay_time, mix_probs, av, av, 3.0, 10.0))

[0.00000000e+000 9.13964284e-050 1.13991962e-053 2.13464238e-295
 1.03262062e-010 4.31926657e-030 2.82099431e-160 6.69006205e-289
 2.00292430e-161 3.59819301e-060 6.28935202e-003 1.48178214e-002
 1.29997095e-001 1.03520036e-001 5.86031222e-112 0.00000000e+000
 3.47631286e-278 0.00000000e+000 1.80710313e-096 0.00000000e+000
 1.23918927e-001 3.89946597e-061 3.81039898e-131 0.00000000e+000
 0.00000000e+000 2.22951279e-121 1.97464698e-078 2.88827316e-037
 5.44533549e-286]


In [75]:
print(mix_probs)

[[0.38102161 0.31853812 0.30044027]
 [0.34213766 0.33981617 0.31804617]
 [0.33457538 0.35614995 0.30927467]
 [0.43881224 0.31768618 0.24350158]
 [0.36701398 0.41494335 0.21804267]
 [0.36831803 0.39822807 0.2334539 ]
 [0.41111853 0.37520557 0.2136759 ]
 [0.37175665 0.309885   0.31835835]
 [0.36486899 0.32116195 0.31396906]
 [0.41046241 0.30773863 0.28179896]
 [0.3747033  0.4456408  0.1796559 ]
 [0.35821777 0.55163966 0.09014257]
 [0.27028351 0.67446183 0.05525466]
 [0.36600794 0.52633749 0.10765458]
 [0.35452917 0.44880993 0.1966609 ]
 [0.40700891 0.32196844 0.27102265]
 [0.41735839 0.32990915 0.25273246]
 [0.36529366 0.330786   0.30392034]
 [0.33694331 0.34171623 0.32134045]
 [0.51864232 0.4342015  0.04715617]
 [0.34063271 0.39377259 0.26559469]
 [0.37062639 0.31540087 0.31397274]
 [0.5351368  0.32890758 0.13595561]
 [0.41190759 0.31267058 0.27542183]
 [0.46683289 0.33053374 0.20263337]
 [0.36866516 0.32207533 0.30925951]
 [0.40448399 0.37253151 0.2229845 ]
 [0.43392179 0.37685668 0.18

In [87]:
jnp.sum(c_gamma_probs * mix_probs, axis=-1)

Array([0.00047225, 0.00471289, 0.00538641, 0.00126069, 0.011568  ,
       0.00558727, 0.00132366, 0.00179196, 0.00106369, 0.00178111,
       0.01401595, 0.05355214, 0.08541373, 0.02755797, 0.00031838,
       0.00036081, 0.0009707 , 0.00020808, 0.00198526, 0.00162107,
       0.0304282 , 0.00363625, 0.00389488, 0.00035686, 0.00100218,
       0.00320803, 0.00405583, 0.00410189, 0.00135668], dtype=float64)

In [88]:
c_gamma_probs = c_gamma_prob_v(delay_time, av, bv, 3.0, 10.0)

In [90]:
jnp.sum(c_gamma_probs * mix_probs, axis=-1)

Array([0.00047225, 0.00471289, 0.00538641, 0.00126069, 0.011568  ,
       0.00558727, 0.00132366, 0.00179196, 0.00106369, 0.00178111,
       0.01401595, 0.05355214, 0.08541373, 0.02755797, 0.00031838,
       0.00036081, 0.0009707 , 0.00020808, 0.00198526, 0.00162107,
       0.0304282 , 0.00363625, 0.00389488, 0.00035686, 0.00100218,
       0.00320803, 0.00405583, 0.00410189, 0.00135668], dtype=float64)

In [103]:
@jax.jit
def test(x, m, a, b, sigma, delta):
    return jnp.sum(m * c_gamma_prob(x, a, b, sigma, delta), axis=-1)

test_v = jax.jit(jax.vmap(test, (0, 0, 0, 0, None, None), 0))

In [104]:
test(delay_time[0], mix_probs[0], av[0], bv[0], 3.0, 10.0)

Array(0.00047225, dtype=float64)

In [108]:
print(test_v(delay_time, mix_probs, av, bv, 3.0, 10.0))

[0.00047225 0.00471289 0.00538641 0.00126069 0.011568   0.00558727
 0.00132366 0.00179196 0.00106369 0.00178111 0.01401595 0.05355214
 0.08541373 0.02755797 0.00031838 0.00036081 0.0009707  0.00020808
 0.00198526 0.00162107 0.0304282  0.00363625 0.00389488 0.00035686
 0.00100218 0.00320803 0.00405583 0.00410189 0.00135668]


In [109]:
%timeit test_v(delay_time, mix_probs, av, bv, 3.0, 10.0)

8.87 ms ± 32.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
