In [1]:
import jax.numpy as jnp
import jax
# Enable 64-bit floats globally. This is set before any arrays are created so that
# the float64 dtype used throughout the notebook is honored.
jax.config.update("jax_enable_x64", True)

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import sys, os
# Presumably makes the project packages (e.g. `lib.*`) importable.
# NOTE(review): hard-coded absolute path; only valid on the original machine.
sys.path.insert(0, "/home/storage/hans/jax_reco")

import logging
# Hide JAX log messages below ERROR level.
logging.getLogger("jax").setLevel(logging.ERROR)

In [2]:
from lib.sim_data_i3 import I3SimHandlerFtr
from lib.geo import center_track_pos_and_time_based_on_data

from lib.network import get_network_eval_v_fn
#from lib.network import get_network_matmul_eval_v_fn as get_network_eval_v_fn

from dom_track_eval import get_eval_network_doms_and_track
from time_sampler import sample_times_clean



In [3]:
# Working float precision for the network and likelihood (x64 is enabled in the first cell).
dtype = jnp.float64

# Fixed PRNG seed so the sampled first-hit times below are reproducible.
key = jax.random.PRNGKey(2)

In [4]:
# Load the trained network, then wrap it so it can be evaluated as
# (DOM positions, track vertex, track direction) -> per-DOM outputs, as used below.
network_path = '/home/storage/hans/jax_reco/data/network'
eval_network_v = get_network_eval_v_fn(bpath=network_path, dtype=dtype)
eval_network_doms_and_track = get_eval_network_doms_and_track(eval_network_v, dtype=dtype)

In [5]:
# Simulation dataset 21217: per-event metadata and pulses (feather files) plus the detector geometry.
bp = '/home/storage2/hans/i3files/21217'
meta_file = os.path.join(bp, 'meta_ds_21217_from_35000_to_53530.ftr')
pulses_file = os.path.join(bp, 'pulses_ds_21217_from_35000_to_53530.ftr')
geometry_file = '/home/storage/hans/jax_reco/data/icecube/detector_geometry.csv'
sim_handler = I3SimHandlerFtr(meta_file, pulses_file, geometry_file)

# Pick one simulated muon event and report its energy at the detector (divided by 1e3, shown in TeV).
event_index = 5
meta, pulses = sim_handler.get_event_data(event_index)
print(f"muon energy: {meta['muon_energy_at_detector']/1.e3:.1f} TeV")

# Per-DOM summary: DOM locations, first hit times and total charges.
event_data = sim_handler.get_per_dom_summary_from_sim_data(meta, pulses)
print("n_doms", len(event_data))

muon energy: 3.4 TeV
n_doms 43


In [6]:
# Generate new first-hit times from our triple-Pandel model (the network's
# three-component mixture), using the true MC track as the hypothesis.
# Time smearing is ignored for now; TODO: Gaussian-convolved triple Pandel.

track_pos = jnp.array([meta['muon_pos_x'], meta['muon_pos_y'], meta['muon_pos_z']])
track_time = meta['muon_time']
track_zenith = meta['muon_zenith']
track_azimuth = meta['muon_azimuth']
track_src = jnp.array([track_zenith, track_azimuth])

# Re-center the vertex and its time using the event's DOM data (lib.geo helper);
# the direction fits below use this centered pair.
print("old track vertex:", track_pos)
centered_track_pos, centered_track_time = center_track_pos_and_time_based_on_data(event_data, track_pos, track_time, track_src)
print("new track vertex:", centered_track_pos)

# Derive the sampling key from `key` without overwriting it. This makes re-running
# the cell reproduce the same first_times; the old `key, subkey = split(key)`
# advanced the key on every run. The first-run sample is unchanged.
subkey = jax.random.split(key)[1]
# NOTE(review): samples with the original (uncentered) vertex/time; presumably
# equivalent to the centered pair since both describe the same track — confirm.
first_times = sample_times_clean(eval_network_doms_and_track, event_data, track_pos, track_src, track_time, subkey)

old track vertex: [ 143.40123598  106.94081331 -815.20509865]
new track vertex: [ -25.42355091  403.14288541 -154.63778486]


In [7]:
# Create n_photons from qtot by rounding up. np.ceil is used because np.round(q + 0.5)
# rounds half to even: odd integer charges become q + 1 while even ones stay q.
n_photons = np.ceil(event_data['charge'].to_numpy())

# Combine into a single (n_doms, 5) data tensor for fitting.
# Columns: x, y, z, first hit time, n_photons.
fake_event_data = jnp.column_stack([jnp.array(event_data[['x', 'y', 'z']].to_numpy()),
                                    jnp.array(first_times),
                                    jnp.array(n_photons)])
print(fake_event_data.shape)

# No explicit transfer is needed: JAX creates arrays on its default device
# (the GPU here). `.devices()` only reports placement, so show it for every fit input.
fit_inputs = {
    "fake_event_data": fake_event_data,
    "centered_track_pos": centered_track_pos,
    "centered_track_time": centered_track_time,
    "track_src": track_src,
}
{name: arr.devices() for name, arr in fit_inputs.items()}

(43, 5)


{cuda(id=0)}

In [8]:
# Zero-pad the per-DOM rows to a fixed length. Events with different numbers of hit
# DOMs then share one array shape, and therefore one jit compilation.
# Padding rows have n_photons == 0, which the padded likelihood uses as its mask.
N_DOMS_PADDED = 50
n_pad_rows = N_DOMS_PADDED - fake_event_data.shape[0]
assert n_pad_rows >= 0, (
    f"event has {fake_event_data.shape[0]} DOMs, more than N_DOMS_PADDED={N_DOMS_PADDED}"
)
fake_event_data_padded = jnp.pad(fake_event_data, ((0, n_pad_rows), (0, 0)))

In [9]:
# Expected: 50 padded rows x 5 columns (x, y, z, t, n_photons).
fake_event_data_padded.shape

(50, 5)

In [10]:
# This event has 43 real DOMs, so rows 43 onward should be all-zero padding.
print(fake_event_data_padded[43:])

[[0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]]


In [11]:
from likelihood_const_vertex_vectorized import get_neg_mpe_llh_const_vertex_v2_padded
# Padded MPE likelihood from the project module; called below as
# neg_llh_padded(track direction, vertex, time, padded event data).
neg_llh_padded = get_neg_mpe_llh_const_vertex_v2_padded(eval_network_doms_and_track)

In [12]:
# Likelihood at the true MC direction, with the centered vertex and time.
print(neg_llh_padded(track_src, centered_track_pos, centered_track_time, fake_event_data_padded))

470.3698367909118


In [16]:
from jax.scipy import optimize

@jax.jit
def minimize(x0, track_vertex, track_time, event_data):
    """Fit the track direction with BFGS, holding vertex and time fixed.

    x0: starting (zenith, azimuth). The remaining arguments are passed through
    to `neg_llh_padded`. Returns the full OptimizeResults.
    """
    fixed_args = (track_vertex, track_time, event_data)
    return optimize.minimize(
        neg_llh_padded,
        x0,
        args=fixed_args,
        method="BFGS",
        tol=1.e-5,
    )


result = minimize(track_src, centered_track_pos, centered_track_time, fake_event_data_padded)

In [17]:
# NOTE: the stored output reports success=False (status=3), i.e. BFGS did not report
# convergence at tol=1e-5. Treat result.x with care.
print(result)

OptimizeResults(x=Array([2.66344555, 5.23024581], dtype=float64), success=Array(False, dtype=bool), status=Array(3, dtype=int64, weak_type=True), fun=Array(469.41627225, dtype=float64), jac=Array([ 7.04646980e-05, -9.19381577e-06], dtype=float64), hess_inv=Array([[1.62390920e-06, 1.40989147e-06],
       [1.40989147e-06, 1.23676182e-05]], dtype=float64), nfev=Array(61, dtype=int64, weak_type=True), njev=Array(61, dtype=int64, weak_type=True), nit=Array(12, dtype=int64, weak_type=True))


In [15]:
%timeit -n5 minimize(track_src, centered_track_pos, centered_track_time, fake_event_data_padded).block_until_ready()

218 ms ± 1.41 ms per loop (mean ± std. dev. of 7 runs, 5 loops each)


In [27]:
from likelihood_const_vertex_vectorized import get_neg_mpe_llh_const_vertex_v2
# Unpadded variant of the likelihood, for comparison with the padded one above.
neg_llh = get_neg_mpe_llh_const_vertex_v2(eval_network_doms_and_track)

In [28]:
# Should match the padded evaluation (both stored outputs are 470.3698367909118).
print(neg_llh(track_src, centered_track_pos, centered_track_time, fake_event_data))

470.3698367909118


In [29]:
from jax.scipy import optimize

@jax.jit
def minimize(x0, track_vertex, track_time, event_data):
    """BFGS fit of the track direction (vertex and time held fixed) using the
    unpadded likelihood `neg_llh`. Returns only the fitted (zenith, azimuth).

    NOTE: this redefines `minimize` from the padded-likelihood cell. The later
    %timeit and the vmapped batch fit use this version.
    """
    solution = optimize.minimize(
        neg_llh,
        x0,
        args=(track_vertex, track_time, event_data),
        method="BFGS",
        tol=1.e-5,
    )
    return solution.x


result = minimize(track_src, centered_track_pos, centered_track_time, fake_event_data)


In [30]:
# Fitted (zenith, azimuth); the stored output matches the padded fit's x.
print(result)

[2.66344555 5.23024581]


In [31]:
# This `minimize` returns an Array, so block_until_ready() applies directly.
%timeit -n5 minimize(track_src, centered_track_pos, centered_track_time, fake_event_data).block_until_ready()

209 ms ± 1.92 ms per loop (mean ± std. dev. of 7 runs, 5 loops each)


In [None]:
# Fitted direction (zenith, azimuth).
print(result)

In [None]:
# True MC direction (zenith, azimuth), for comparison with the fit.
print(track_src)

In [None]:
#from likelihood_const_vertex import get_neg_mpe_llh_const_vertex
#neg_llh = get_neg_mpe_llh_const_vertex(eval_network_doms_and_track, fake_event_data, centered_track_pos, centered_track_time)

In [None]:
#from jax.scipy import optimize
#
#@jax.jit
#def minimize(x0):
#    return optimize.minimize(neg_llh, x0, method="BFGS", tol=1.e-5)

#result = minimize(track_src)

In [None]:
# Unpadded event tensor (n_doms, 5), replicated below for the batch benchmark.
print(fake_event_data.shape)

In [None]:
# Batched fit: map `minimize` over the leading (event) axis of all four arguments.
minimize_v = jax.jit(jax.vmap(minimize, in_axes=0, out_axes=0))

In [None]:
n_ev = 1000

# Replicate the single event n_ev times along a new leading axis to benchmark the vmapped fit.
centered_track_pos_v = jnp.tile(centered_track_pos, (n_ev, 1))
print(centered_track_pos_v.shape)

centered_track_time_v = jnp.repeat(centered_track_time, n_ev)
print(centered_track_time_v.shape)

track_src_v = jnp.tile(track_src, (n_ev, 1))
print(track_src_v.shape)

fake_event_data_v = jnp.tile(fake_event_data, (n_ev, 1, 1))
print(fake_event_data_v.shape)

In [None]:
# Batched fit of the n_ev identical replicas; every row should reproduce the single-event fit.
result_v = minimize_v(track_src_v, centered_track_pos_v, centered_track_time_v, fake_event_data_v)

In [None]:
# (n_ev, 2) array of fitted (zenith, azimuth).
print(result_v)

In [None]:
# Benchmark the batched fit (minimize_v returns an Array, so block_until_ready() applies).
%timeit -n5 minimize_v(track_src_v, centered_track_pos_v, centered_track_time_v, fake_event_data_v).block_until_ready()

In [None]:
# Measured in an earlier run: ~2.17 s for 1000 event reconstructions with the batched fit (the timing cell above has no stored output).

In [19]:
# Debug: evaluate the network on every row of the padded tensor, including the
# all-zero padding rows (DOM position at the origin).
dom_pos = fake_event_data_padded[:, 0:3]
print(dom_pos[40:])
network_out = eval_network_doms_and_track(dom_pos, centered_track_pos, track_src)
logits, av, bv, geo_time = network_out

[[-101.06  490.22   -9.14]
 [-101.06  490.22 -196.37]
 [  22.11  509.5  -295.31]
 [   0.      0.      0.  ]
 [   0.      0.      0.  ]
 [   0.      0.      0.  ]
 [   0.      0.      0.  ]
 [   0.      0.      0.  ]
 [   0.      0.      0.  ]
 [   0.      0.      0.  ]]


In [21]:
# Mixture logits per DOM. All padding rows (index >= 43) sit at the origin, so they share one value.
print(logits[40:])

[[ 0.93037825  0.97019836  1.01150626]
 [ 0.98164109  0.7572359  -0.27043317]
 [ 1.18283628  0.53470128  0.00219727]
 [ 1.95618561 -0.06914704  1.62119527]
 [ 1.95618561 -0.06914704  1.62119527]
 [ 1.95618561 -0.06914704  1.62119527]
 [ 1.95618561 -0.06914704  1.62119527]
 [ 1.95618561 -0.06914704  1.62119527]
 [ 1.95618561 -0.06914704  1.62119527]
 [ 1.95618561 -0.06914704  1.62119527]]


In [22]:
# Per-component values used as Gamma concentrations in the likelihood.
print(av[40:])

[[ 3.62614324  5.8580399   1.83043423]
 [ 6.57806378  8.45045441  5.09057448]
 [ 5.45248613  6.65534714  4.45239311]
 [20.08752932 18.82070978 20.69526838]
 [20.08752932 18.82070978 20.69526838]
 [20.08752932 18.82070978 20.69526838]
 [20.08752932 18.82070978 20.69526838]
 [20.08752932 18.82070978 20.69526838]
 [20.08752932 18.82070978 20.69526838]
 [20.08752932 18.82070978 20.69526838]]


In [23]:
# Per-component values used as Gamma rates in the likelihood.
print(bv[40:])

[[0.0246426  0.09617788 0.00340015]
 [0.00896593 0.02044258 0.00379655]
 [0.00590327 0.01235341 0.00296858]
 [0.00511689 0.00469212 0.00561909]
 [0.00511689 0.00469212 0.00561909]
 [0.00511689 0.00469212 0.00561909]
 [0.00511689 0.00469212 0.00561909]
 [0.00511689 0.00469212 0.00561909]
 [0.00511689 0.00469212 0.00561909]
 [0.00511689 0.00469212 0.00561909]]


In [24]:
# Geometric time per DOM; the likelihood subtracts (geo_time + track_time) from the hit time to get the delay.
print(geo_time[40:])

[ 722.27043866  402.79418785  143.97842871 1156.0242326  1156.0242326
 1156.0242326  1156.0242326  1156.0242326  1156.0242326  1156.0242326 ]


In [25]:
from scipy.special import softmax
# Mixture weights implied by the padding rows' logits (values copied from the logits output above).
softmax([ 1.95618561, -0.06914704,  1.62119527])

array([0.54133207, 0.07142874, 0.38723919])

In [92]:

from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions


def get_neg_mpe_llh_const_vertex_v2_padded(eval_network_doms_and_track_fn, eps=jnp.float64(1.e-20), dtype=jnp.float64):
    """Build the padded MPE likelihood for a direction fit with fixed vertex and time.

    eval_network_doms_and_track_fn: callable (dom_pos, track_vertex, track_direction)
        -> (logits, av, bv, geo_time). These are the per-DOM mixture logits, Gamma
        concentrations, Gamma rates and geometric times.
    eps: floor added to each hit DOM's probability before taking the log.
    dtype: currently unused; kept for interface compatibility.

    Returns neg_mpe_llh_direction(track_direction, track_vertex, track_time, event_data),
    which evaluates -2 * log-likelihood.
    """

    def neg_mpe_llh_direction(track_direction, track_vertex, track_time, event_data):
        """
        track_direction: (zenith, azimuth) in radians
        track_vertex: (x, y, z)
        track_time: t (this time defines the fit vertex)
        event_data: 2D array (n_doms X 5) where columns are x,y,z of dom location, and t for first hit time, and estimated number of photon hits from Qtot.
            Padding rows must have n_photons == 0; each contributes log(1) = 0.
        """
        dom_pos = event_data[:, :3]
        first_hit_times = event_data[:, 3]
        n_photons = event_data[:, 4]

        # Real DOMs have n_photons > 0; rows with n_photons == 0 are padding.
        is_hit = n_photons > 0

        logits, av, bv, geo_time = eval_network_doms_and_track_fn(dom_pos, track_vertex, track_direction)
        delay_time = first_hit_times - (geo_time + track_time)

        # Give padding rows harmless placeholder inputs. jnp.where discards the masked
        # branch in the forward pass, but that branch's gradient is still computed and
        # multiplied by zero, so any inf/nan there would poison jax.grad. With
        # n_photons = 1 the survival factor below has exponent 0, instead of
        # (1 - cdf)**(-1), which is inf whenever cdf == 1.
        delay_time = jnp.where(is_hit, delay_time, 5.0)
        n_photons_safe = jnp.where(is_hit, n_photons, 1.0)

        gm = tfd.MixtureSameFamily(
            mixture_distribution=tfd.Categorical(logits=logits),
            components_distribution=tfd.Gamma(
                concentration=av,
                rate=bv,
                force_probs_to_zero_outside_support=True,
            ),
        )

        # MPE density (first of n photons): n * p(t) * (1 - F(t))**(n - 1), floored by eps.
        mpe_prob = n_photons_safe * gm.prob(delay_time) * (1 - gm.cdf(delay_time))**(n_photons_safe - 1) + eps
        prob = jnp.where(is_hit, mpe_prob, 1.)
        llh = jnp.sum(jnp.log(prob))

        return -2*llh

    return neg_mpe_llh_direction

In [93]:
# Rebind neg_llh_padded using the in-notebook builder defined in the previous cell.
neg_llh_padded = get_neg_mpe_llh_const_vertex_v2_padded(eval_network_doms_and_track)
# Gradient w.r.t. the first argument, the track direction (zenith, azimuth).
neg_llh_padded_grad = jax.grad(neg_llh_padded, argnums=0)

In [94]:
# The stored value matches the imported version's (470.3698367909118).
print(neg_llh_padded(track_src, centered_track_pos, centered_track_time, fake_event_data_padded))

470.3698367909118


In [95]:
# Gradient w.r.t. (zenith, azimuth) at the true MC direction.
print(neg_llh_padded_grad(track_src, centered_track_pos, centered_track_time, fake_event_data_padded))

[1102.27517602  -96.70991096]


In [96]:
# Rows 40-49: the last three real DOMs followed by the padding rows at (0, 0, 0).
fake_event_data_padded[40:, :3]

Array([[-101.06,  490.22,   -9.14],
       [-101.06,  490.22, -196.37],
       [  22.11,  509.5 , -295.31],
       [   0.  ,    0.  ,    0.  ],
       [   0.  ,    0.  ,    0.  ],
       [   0.  ,    0.  ,    0.  ],
       [   0.  ,    0.  ,    0.  ],
       [   0.  ,    0.  ,    0.  ],
       [   0.  ,    0.  ,    0.  ],
       [   0.  ,    0.  ,    0.  ]], dtype=float64)