In [1]:
import pandas as pd
import sys, os
import numpy as np

sys.path.insert(0, "/home/storage/hans/jax_reco_new")

In [2]:
from lib.simdata_i3 import I3SimHandler

In [3]:
# Load one simulated IceCube event from dataset 21217 (metadata + pulses) and
# condense it into a per-DOM summary (DOM location, first hit time, total charge).
event_index = 0
bp = '/home/storage2/hans/i3files/21217'
meta_path = os.path.join(bp, 'meta_ds_21217_from_35000_to_53530.ftr')
pulses_path = os.path.join(bp, 'pulses_ds_21217_from_35000_to_53530.ftr')
geometry_path = '/home/storage/hans/jax_reco/data/icecube/detector_geometry.csv'

sim_handler = I3SimHandler(meta_path, pulses_path, geometry_path)
meta, pulses = sim_handler.get_event_data(event_index)
event_data = sim_handler.get_per_dom_summary_from_sim_data(meta, pulses)

In [32]:
# Per-DOM summary with up to `n_pulses` pulses per DOM (hit times followed by
# charges, as sliced by the likelihood functions below).
# Use `event_index` (not a literal 0) so this stays the same event as
# `meta` / `event_data` loaded above.
n_pulses = 1
event_data_all = sim_handler.get_per_dom_summary_extended_from_index(event_index, n_pulses=n_pulses)

In [33]:
# Sanity check of the per-DOM summary: (rows, columns). Presumably one row per
# hit DOM; columns 0-2 are the DOM position (TODO: document remaining columns).
event_data.shape

(29, 6)

In [34]:
from lib.cgamma import c_multi_gamma_prob_v
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

In [35]:
# TriplePandelSPE/JAX stuff
from lib.geo import center_track_pos_and_time_based_on_data
from lib.network import get_network_eval_v_fn
from dom_track_eval import get_eval_network_doms_and_track

In [36]:
def get_neg_c_triple_gamma_llh(eval_network_doms_and_track_fn, n_pulses,
                               sigma=3.0, X_safe=20.0, delta=0.1):
    """
    Build a -2 * log-likelihood function for a track hypothesis (unpadded data).

    The per-pulse time-residual pdf is a mixture of gamma distributions
    convolved with a Gaussian of width ``sigma``, evaluated by
    ``c_multi_gamma_prob_v``. Every row and every hit-time column of
    ``event_data`` contributes to the sum, so this variant must only be used
    with unpadded event data (see ``get_neg_c_triple_gamma_llh_padded``).

    Args:
        eval_network_doms_and_track_fn: Callable
            ``(dom_pos, track_vertex, track_direction) -> (logits, av, bv, geo_time)``
            returning per-DOM mixture logits, per-DOM gamma-mixture parameters
            ``av`` / ``bv`` and the per-DOM geometric arrival time.
        n_pulses: Number of hit-time columns per DOM in ``event_data``.
        sigma: Width of the Gaussian convolution (same units as the hit times).
        X_safe: Negative time residuals are floored at ``-X_safe * sigma``,
            i.e. where evaluation of negative residuals stops, in units of sigma.
        delta: How the approximate and exact hyp1f1 evaluation regions are
            combined inside the convolution. Small values are faster, large
            values are more accurate.

    Returns:
        ``neg_c_triple_gamma_llh(track_direction, track_vertex, track_time,
        event_data)`` returning the scalar ``-2 * sum(log(pdf))`` over all DOMs
        and pulses. ``event_data`` has one row per DOM with columns
        ``[x, y, z, t_1, ..., t_n_pulses, ...]``; further columns are ignored.
    """

    # Deliberately not wrapped in jax.jit here; callers may jit the result.
    def neg_c_triple_gamma_llh(track_direction,
                               track_vertex,
                               track_time,
                               event_data):
        dom_pos = event_data[:, :3]
        hit_times = event_data[:, 3:3 + n_pulses]
        logits, av, bv, geo_time = eval_network_doms_and_track_fn(dom_pos, track_vertex, track_direction)

        # Column vector (n_doms, 1) so the geometric time broadcasts over pulses.
        geo_time = geo_time.reshape((geo_time.shape[0], 1))
        mix_probs = jax.nn.softmax(logits)
        delay_time = hit_times - (geo_time + track_time)

        # Floor on negative time residuals (effectively a floor on the pdf).
        # Todo: think about noise.
        min_delay = -X_safe * sigma
        safe_delay_time = jnp.where(delay_time > min_delay, delay_time, min_delay)

        # Re-arrange so that dims are (n_doms, n_pulses, n_mixture_components).
        safe_delay_time = jnp.expand_dims(safe_delay_time, 2)
        mix_probs = jnp.expand_dims(mix_probs, 1)
        av = jnp.expand_dims(av, 1)
        bv = jnp.expand_dims(bv, 1)

        y = c_multi_gamma_prob_v(safe_delay_time,
                                 mix_probs,
                                 av,
                                 bv,
                                 sigma,
                                 delta)

        return -2.0 * jnp.sum(jnp.log(y))

    return neg_c_triple_gamma_llh

In [37]:
def get_neg_c_triple_gamma_llh_padded(eval_network_doms_and_track_fn, n_pulses,
                                      sigma=3.0, X_safe=20.0, delta=0.1):
    """
    Build a padding-aware -2 * log-likelihood function for a track hypothesis.

    Same model as ``get_neg_c_triple_gamma_llh`` (Gaussian-convolved mixture
    of gammas via ``c_multi_gamma_prob_v``), but padded DOM rows and padded
    pulse slots in ``event_data`` contribute exactly 0 to the result.

    Padding convention used by the masks (assumed -- TODO confirm against
    ``I3SimHandler.get_per_dom_summary_extended_from_index``):

    * a DOM row is padding when its x coordinate (column 0) is exactly 0.0;
    * a pulse slot is padding when its charge is exactly 0.0, or when its DOM
      row is padding.

    Args:
        eval_network_doms_and_track_fn: Callable
            ``(dom_pos, track_vertex, track_direction) -> (logits, av, bv, geo_time)``
            returning per-DOM mixture logits, per-DOM gamma-mixture parameters
            ``av`` / ``bv`` (mixture components on the last axis) and the
            per-DOM geometric arrival time.
        n_pulses: Number of pulse slots per DOM in ``event_data``.
        sigma: Width of the Gaussian convolution (same units as the hit times).
        X_safe: Negative time residuals are floored at ``-X_safe * sigma``,
            i.e. where evaluation of negative residuals stops, in units of sigma.
        delta: How the approximate and exact hyp1f1 evaluation regions are
            combined inside the convolution. Small values are faster, large
            values are more accurate.

    Returns:
        ``neg_c_triple_gamma_llh(track_direction, track_vertex, track_time,
        event_data)`` returning the scalar ``-2 * sum(log(pdf))`` over all
        non-padded pulses. ``event_data`` has one row per DOM with columns
        ``[x, y, z, t_1..t_n_pulses, q_1..q_n_pulses]``; further columns are
        ignored.
    """

    # Deliberately not wrapped in jax.jit here; callers may jit the result.
    def neg_c_triple_gamma_llh(track_direction,
                               track_vertex,
                               track_time,
                               event_data):
        dom_pos = event_data[:, :3]
        hit_times = event_data[:, 3:3 + n_pulses]
        # Take exactly n_pulses charge columns so the pulse mask has the same
        # shape as hit_times; an open-ended slice would also pick up trailing
        # columns and mis-broadcast against the pdf values.
        hit_charges = event_data[:, 3 + n_pulses:3 + 2 * n_pulses]

        # Validity masks (True = real data, False = padding).
        dom_is_valid = (event_data[:, 0] != 0.0)[:, None]        # (n_doms, 1)
        pulse_is_valid = (hit_charges != 0.0) & dom_is_valid     # (n_doms, n_pulses)

        logits, av, bv, geo_time = eval_network_doms_and_track_fn(dom_pos, track_vertex, track_direction)

        # Replace network outputs of padded DOMs with values for which the pdf
        # is computable (no NaN); those terms are masked out below anyway.
        # Fill values are built from the trailing (mixture-component) axis so
        # any number of components works; jnp.ones keeps the default float
        # dtype (float64 with jax_enable_x64).
        logits = jnp.where(dom_is_valid, logits, jnp.ones(logits.shape[-1:]))
        av = jnp.where(dom_is_valid, av, jnp.ones(av.shape[-1:]) + 3.0)
        bv = jnp.where(dom_is_valid, bv, jnp.ones(bv.shape[-1:]) * 1.e-3)

        mix_probs = jax.nn.softmax(logits)

        # Column vector (n_doms, 1) so the geometric time broadcasts over pulses.
        geo_time = geo_time.reshape((geo_time.shape[0], 1))
        delay_time = hit_times - (geo_time + track_time)

        # Floor on negative time residuals (effectively a floor on the pdf).
        # Todo: think about noise.
        min_delay = -X_safe * sigma
        safe_delay_time = jnp.where(delay_time > min_delay, delay_time, min_delay)

        # Re-arrange so that dims are (n_doms, n_pulses, n_mixture_components).
        safe_delay_time = jnp.expand_dims(safe_delay_time, 2)
        mix_probs = jnp.expand_dims(mix_probs, 1)
        av = jnp.expand_dims(av, 1)
        bv = jnp.expand_dims(bv, 1)

        # Presumably (n_doms, n_pulses) after combining mixture components.
        probs = c_multi_gamma_prob_v(safe_delay_time,
                                     mix_probs,
                                     av,
                                     bv,
                                     sigma,
                                     delta)

        # "Double where": masked entries are replaced by 1.0 *before* the log,
        # so log() never sees a zero/underflowed pdf there and the gradient
        # cannot become inf * 0 = NaN. log(1.0) == 0.0, so masked entries
        # still contribute exactly 0, as before.
        log_probs = jnp.log(jnp.where(pulse_is_valid, probs, 1.0))

        return -2.0 * jnp.sum(log_probs)

    return neg_c_triple_gamma_llh

In [38]:
# Build the network evaluator (presumably loading the trained weights found
# under `bpath`) and wrap it into a function that is called inside the
# likelihood as fn(dom_pos, track_vertex, track_direction) and returns
# (logits, av, bv, geo_time) for all DOMs of the event.
eval_network_v = get_network_eval_v_fn(bpath='/home/storage/hans/jax_reco/data/network')
eval_network_doms_and_track = get_eval_network_doms_and_track(eval_network_v)

In [39]:
# Padding-aware likelihood with `n_pulses` pulse slots per DOM.
neg_llh = get_neg_c_triple_gamma_llh_padded(
    eval_network_doms_and_track_fn=eval_network_doms_and_track,
    n_pulses=n_pulses,
)

In [40]:
# MC-truth seed: true muon vertex, time and direction (zenith, azimuth).
track_pos = jnp.array([meta['muon_pos_x'],
                       meta['muon_pos_y'],
                       meta['muon_pos_z']])
track_time = meta['muon_time']
track_zenith, track_azimuth = meta['muon_zenith'], meta['muon_azimuth']
track_src = jnp.array([track_zenith, track_azimuth])
print("original seed vertex:", track_pos)

# Shift the seed vertex (and time) to a point determined by the hit DOMs in
# `event_data`; presumably moved along the track direction -- see helper.
centered_track_pos, centered_track_time = center_track_pos_and_time_based_on_data(
    event_data, track_pos, track_time, track_src
)
print("shifted seed vertex:", centered_track_pos)

original seed vertex: [-1277.51128861 -1390.39564543 -1675.98024553]
shifted seed vertex: [ -53.74394146  162.12452256 -233.73599134]


In [41]:
# -2 log L at the centered MC-truth seed, using the multi-pulse event data.
neg_llh(track_direction=track_src,
        track_vertex=centered_track_pos,
        track_time=centered_track_time,
        event_data=event_data_all)

(1, 1)
(1, 3)


Array(347.80341404, dtype=float64)