In [1]:
import jax.numpy as jnp
import jax
jax.config.update("jax_enable_x64", True)

import numpy as np
import scipy
import matplotlib.pyplot as plt

import sys, os

In [2]:
from tensorflow_probability.substrates import jax as tfp

In [3]:

sys.path.insert(0, "/home/storage/hans/jax_reco_new")
from lib.c_mpe_gamma import c_multi_gamma_mpe_prob_v, c_multi_gamma_mpe_prob
from lib.plotting import adjust_plot_1d
from lib.simdata_i3 import I3SimHandler
from lib.geo import center_track_pos_and_time_based_on_data
from lib.network import get_network_eval_v_fn
from lib.experimental_methods import remove_early_pulses
from likelihood_conv_mpe import get_neg_c_triple_gamma_llh
from dom_track_eval import get_eval_network_doms_and_track as get_eval_network_doms_and_track

In [4]:
from lib.geo import cherenkov_cylinder_coordinates_w_rho_v
from lib.geo import get_xyz_from_zenith_azimuth

In [13]:


# Floating point type handed to the network helpers, and the index of the
# simulated event that the rest of the notebook inspects.
dtype = jnp.float32
event_index = 2

# Get network and eval logic: first the network evaluation function itself,
# then the wrapper that evaluates it for a set of DOMs and a track hypothesis.
eval_network_v = get_network_eval_v_fn(
    bpath='/home/storage/hans/jax_reco/data/network',
    dtype=dtype,
)
eval_network_doms_and_track = get_eval_network_doms_and_track(
    eval_network_v,
    dtype=dtype,
)



In [14]:


# Get an IceCube event.
# The handler is built from the meta table, the pulse table and the detector geometry.
bp = '/home/storage2/hans/i3files/21217'
sim_handler = I3SimHandler(
    os.path.join(bp, 'meta_ds_21217_from_35000_to_53530.ftr'),
    os.path.join(bp, 'pulses_ds_21217_from_35000_to_53530.ftr'),
    '/home/storage/hans/jax_reco/data/icecube/detector_geometry.csv',
)

meta, pulses = sim_handler.get_event_data(event_index)
print(f"muon energy: {meta['muon_energy_at_detector']/1.e3:.1f} TeV")

# Get dom locations, first hit times, and total charges (for each dom).
event_data = sim_handler.get_per_dom_summary_from_sim_data(meta, pulses)

print("n_doms", len(event_data))



muon energy: 4.7 TeV
n_doms 102


In [15]:
# Make MCTruth seed.
# Vertex, direction (zenith, azimuth) and time are all read from the event meta data.
track_pos = jnp.array([meta[key] for key in ('muon_pos_x', 'muon_pos_y', 'muon_pos_z')])
track_zenith = meta['muon_zenith']
track_azimuth = meta['muon_azimuth']
track_src = jnp.array([track_zenith, track_azimuth])
track_time = meta['muon_time']

print("original seed vertex:", track_pos)

# Re-center the seed vertex and its time using the per-DOM event data
# (see lib.geo.center_track_pos_and_time_based_on_data for the exact definition).
centered_track_pos, centered_track_time = center_track_pos_and_time_based_on_data(
    event_data,
    track_pos,
    track_time,
    track_src,
)
print("shifted seed vertex:", centered_track_pos)

original seed vertex: [ -777.15166078 -1656.22843231 -1472.45624098]
shifted seed vertex: [-164.52122541  320.02418746 -330.17880541]


In [16]:


# One row per DOM with columns x, y, z, time, charge.
fitting_event_data = jnp.array(
    event_data[['x', 'y', 'z', 'time', 'charge']].to_numpy()
)
print(fitting_event_data.shape)

# Filter the per-DOM data relative to the (centered) seed track; the shape is
# printed before and after so any change in the number of rows is visible.
fitting_event_data = remove_early_pulses(
    eval_network_doms_and_track,
    fitting_event_data,
    centered_track_pos,
    track_src,
    centered_track_time,
)
print(fitting_event_data.shape)



(102, 5)
(102, 5)


In [17]:
def get_neg_c_triple_gamma_llh_mpe(eval_network_doms_and_track_fn,
                                   sigma=2.0,
                                   X_safe=20.0,
                                   nmax=20,
                                   nint=41):
    """
    Build a jitted -2 * log-likelihood for the gaussian-convolved multi-gamma
    MPE pdf, summed over all DOMs of an event.

    Parameters
    ----------
    eval_network_doms_and_track_fn : callable
        Called as ``fn(dom_pos, track_vertex, track_direction)`` and expected
        to return ``(logits, av, bv, geo_time)`` per DOM.
    sigma : float, optional
        Width of the gaussian convolution. Forwarded to
        ``c_multi_gamma_mpe_prob_v``.
    X_safe : float, optional
        When to stop evaluating negative time residuals, in units of
        ``sigma``. Residuals below ``-X_safe * sigma`` are clamped to that
        value.
    nmax, nint : int, optional
        Forwarded unchanged to ``c_multi_gamma_mpe_prob_v``.

    All four constants are captured by the closure as plain Python values, so
    they are fixed at trace time. Call this factory again to obtain a
    likelihood with different settings.

    Returns
    -------
    callable
        ``neg_c_triple_gamma_llh(track_direction, track_vertex, track_time,
        event_data)`` returning the scalar ``-2 * sum(log(pdf))``.
        ``event_data`` is a 2D array with one row per DOM: columns 0-2 are the
        DOM position, column 3 the first hit time, column 4 the charge.
    """

    # Most negative time residual that is still evaluated as-is.
    min_delay_time = -X_safe * sigma

    @jax.jit
    def neg_c_triple_gamma_llh(track_direction,
                               track_vertex,
                               track_time,
                               event_data):

        dom_pos = event_data[:, :3]
        first_hit_times = event_data[:, 3]
        charges = event_data[:, 4]
        # Photon count estimate from the charge.
        # NOTE(review): jnp.round rounds halves to the nearest even value, so
        # exactly integer charges map unevenly (1.0 -> 2, 2.0 -> 2). Confirm
        # this is the intended estimate.
        n_photons = jnp.round(charges + 0.5)

        logits, av, bv, geo_time = eval_network_doms_and_track_fn(dom_pos, track_vertex, track_direction)
        delay_time = first_hit_times - (geo_time + track_time)

        # Floor on negative time residuals.
        # Effectively a floor on the pdf.
        # Todo: think about noise.
        safe_delay_time = jnp.where(delay_time > min_delay_time, delay_time, min_delay_time)

        probs = c_multi_gamma_mpe_prob_v(safe_delay_time,
                                         logits,
                                         av,
                                         bv,
                                         n_photons,
                                         sigma,
                                         nmax,
                                         nint)

        return -2.0 * jnp.sum(jnp.log(probs))

    return neg_c_triple_gamma_llh

In [18]:


neg_llh_mpe = get_neg_c_triple_gamma_llh_mpe(eval_network_doms_and_track)
# Gradient with respect to the first argument (the track direction).
neg_llh_mpe_grad = jax.grad(neg_llh_mpe)
# Value and gradient from a single evaluation, so the scan below does not
# run the likelihood twice per DOM.
neg_llh_mpe_value_and_grad = jax.value_and_grad(neg_llh_mpe)

# Evaluate the likelihood one DOM at a time and report every DOM whose
# gradient contains a NaN. The one-row slice keeps event_data 2D.
for index in range(len(fitting_event_data)):
    single_dom_data = fitting_event_data[index: index+1]
    llh_val, llh_grad = neg_llh_mpe_value_and_grad(track_src,
                                                   centered_track_pos,
                                                   centered_track_time,
                                                   single_dom_data)
    if np.any(np.isnan(llh_grad)):
        print("DOM", index)
        print(llh_val, llh_grad)



In [19]:
def get_dom_info(track_direction, track_vertex, track_time, event_data,
                 sigma=2.0, X_safe=20.0, nmax=20, nint=41):
    """
    Print and return the per-DOM ingredients of the MPE likelihood for debugging.

    Uses the notebook-level ``eval_network_doms_and_track`` to evaluate the
    network, and prints the photon counts, the track geometry relative to each
    DOM, the gamma parameters, the clamped delay times and the resulting pdf
    values.

    Parameters
    ----------
    track_direction
        Track direction, passed to ``get_xyz_from_zenith_azimuth`` and to the
        network evaluation.
    track_vertex
        Track vertex position.
    track_time
        Track time; added to the geometric time to form the time residual.
    event_data
        2D array with one row per DOM: columns 0-2 are the DOM position,
        column 3 the first hit time, column 4 the charge.
    sigma : float, optional
        Width of the gaussian convolution.
    X_safe : float, optional
        When to stop evaluating negative time residuals, in units of ``sigma``.
    nmax, nint : int, optional
        Forwarded unchanged to ``c_multi_gamma_mpe_prob_v``. The defaults
        match the values used in ``get_neg_c_triple_gamma_llh_mpe``.

    Returns
    -------
    tuple
        ``(mix_probs, av, bv, safe_delay_time)``: the softmax of the network
        logits, the two gamma parameter arrays printed as gamma_a and gamma_b,
        and the clamped time residuals.
    """
    # Most negative time residual that is still evaluated as-is.
    min_delay_time = -X_safe * sigma

    dom_pos = event_data[:, :3]
    first_hit_times = event_data[:, 3]
    charges = event_data[:, 4]
    # Same photon-count estimate as used in the likelihood.
    n_photons = jnp.round(charges + 0.5)
    print("n_hits:", n_photons)

    track_dir_xyz = get_xyz_from_zenith_azimuth(track_direction)
    # Only the geometry is used from this call. The geometric time used for
    # the residual comes from the network evaluation below, so the first
    # return value is discarded explicitly instead of being overwritten.
    _, closest_approach_dist, closest_approach_z, closest_approach_rho = \
        cherenkov_cylinder_coordinates_w_rho_v(dom_pos, track_vertex, track_dir_xyz)

    print("dist, z, rho=", closest_approach_dist, closest_approach_z, closest_approach_rho)

    logits, av, bv, geo_time = eval_network_doms_and_track(dom_pos, track_vertex, track_direction)
    print("gamma_a:", av)
    print("gamma_b:", bv)
    mix_probs = jax.nn.softmax(logits)
    delay_time = first_hit_times - (geo_time + track_time)

    # Floor on negative time residuals.
    # Effectively a floor on the pdf.
    # Todo: think about noise.
    safe_delay_time = jnp.where(delay_time > min_delay_time, delay_time, min_delay_time)
    print("delay time:", safe_delay_time)

    probs = c_multi_gamma_mpe_prob_v(safe_delay_time,
                                     logits,
                                     av,
                                     bv,
                                     n_photons,
                                     sigma,
                                     nmax,
                                     nint)

    print("pdf:", probs)

    return mix_probs, av, bv, safe_delay_time

In [20]:
# Inspect a single DOM. The one-row slice keeps the data 2D, which is what
# get_dom_info indexes into.
index = 4
mix_probs, gamma_a, gamma_b, delay_time = get_dom_info(
    track_src,
    centered_track_pos,
    centered_track_time,
    fitting_event_data[index: index+1],
)

n_hits: [1.]
dist, z, rho= [96.18811464] [-549.97664403] [-0.32022561]
gamma_a: [[3.84631443 5.87746286 2.27034545]]
gamma_b: [[0.00949955 0.03401769 0.00209393]]
delay time: [138.70062895]
pdf: [0.00236885]
