In [12]:
def get_fitter(neg_llh,
               use_multiple_vertex_seeds=False,
               prescan_time=False,
               scale=100.0,
               scale_rad=50.0,
               rtol=1e-8,
               atol=1e-4,
               prescan_dt=100.0,
               prescan_nt=20):
    """Creates a fitter() function that performs a 5D likelihood optimization.

    All arguments are captured by the closures defined in the body, i.e.
    they are fixed once the fitter has been created.

    Args
    ----
        neg_llh: A valid negative log likelihood function, called as
            neg_llh(direction, vertex, time, data).
        use_multiple_vertex_seeds: If True, the fitter will perform the
            reconstruction starting from additional vertex seeds
            (obtained from get_vertex_seeds, which must be importable in
            this namespace) and keep the best fit.
            This increases robustness against local minima
            but increases the run-time.
        prescan_time: If True, the fitter will search for the best time
            for a given vertex seed before optimizing.
            This can improve convergence at the cost of run-time.
        scale: Re-scales the vertex coordinate units during optimization
            (a step of 1 in optimizer space corresponds to `scale` m).
        scale_rad: Re-scales the directional coordinate units during
            optimization (a step of 1 in optimizer space corresponds to
            1 / `scale_rad` radians).
        rtol: relative tolerance of the optimizer (see optimistix.BFGS).
        atol: absolute tolerance of the optimizer (see optimistix.BFGS).
        prescan_dt: half-width (ns) of the time window scanned around the
            seed time when prescan_time is True.
        prescan_nt: number of evenly spaced times evaluated in that window.

    Returns
    -------
        fitter: A function fitter(track_dir_seed, vertex_seed, track_time, data)
            that performs the likelihood fit according to the given
            properties / behavior and returns
            (neg_llh_value, direction, vertex, time).
    """

    # Vectorize likelihood along time argument.
    # Reminder: arguments are (direction, vertex, time, data).
    neg_llh_time_v = jax.vmap(neg_llh, (None, None, 0, None), 0)

    def get_track_time(track_dir, track_vertex, seed_time, data):
        """Find the time that best matches the given vertex seed
        and track direction, i.e. the time on a regular grid that yields
        the lowest negative log likelihood for the given track parameters.

        Args
        ----
            track_dir: jnp.array([zenith, azimuth]) in radians
            track_vertex: jnp.array([x, y, z]) in m
            seed_time: jnp.array(t) in ns
            data: jnp.array(data) with shape
                (n_sensors, n_features) = (N, 5)

        Returns
        -------
            best_time: jnp.array(float)
        """
        # Scan prescan_dt before and after seed_time with prescan_nt points.
        time = jnp.linspace(seed_time - prescan_dt, seed_time + prescan_dt, prescan_nt)
        llh = neg_llh_time_v(track_dir, track_vertex, time, data)
        return time[jnp.argmin(llh, axis=0)]

    # Vectorize across vertex dimension. This allows performing
    # this operation for multiple vertex seeds at the
    # given track direction.
    # Reminder: arguments are (direction, vertex, time, data).
    get_track_time_v = jax.vmap(
        get_track_time,
        (None, 0, None, None),
        0
    )

    def project_direction(zenith, azimuth):
        """Maps (zenith, azimuth) in radians back onto [0, pi] x [0, 2*pi].

        A zenith in (pi, 2*pi) is reflected to 2*pi - zenith, which describes
        the same direction once the azimuth is rotated by pi.

        Returns
        -------
            jnp.array([zenith, azimuth])
        """
        zenith = jnp.fmod(zenith, 2.0*jnp.pi)
        zenith = jnp.where(zenith < 0, zenith + 2.0*jnp.pi, zenith)
        reflect = zenith > jnp.pi
        zenith = jnp.where(reflect, 2.0*jnp.pi - zenith, zenith)
        azimuth = jnp.where(reflect, azimuth - jnp.pi, azimuth)

        azimuth = jnp.fmod(azimuth, 2.0*jnp.pi)
        azimuth = jnp.where(azimuth < 0, azimuth + 2.0*jnp.pi, azimuth)
        return jnp.array([zenith, azimuth])

    def neg_llh_5D(x, args):
        """Negative log likelihood in rescaled optimizer coordinates.

        x = [zenith*scale_rad, azimuth*scale_rad, x/scale, y/scale, z/scale]
        args = (track_time, data); the time is held fixed during the fit.
        """
        projected_dir = project_direction(x[0] / scale_rad, x[1] / scale_rad)
        track_time, data = args
        return neg_llh(projected_dir, x[2:]*scale, track_time, data)

    def reconstruct_event(track_vertex_seed, track_dir_seed, track_time, data):
        """Performs a single event reconstruction.

        Args
        ----
            track_vertex_seed: jnp.array([x, y, z]) in m
            track_dir_seed: jnp.array([zenith, azimuth]) in radians
            track_time: jnp.array(t) in ns (kept fixed during the fit)
            data: jnp.array(data) with shape
                (n_sensors, n_features) = (N, 5)

        Returns
        -------
            Best-fit negative log likelihood value and the corresponding
            coordinates as tuple(jnp.array, jnp.array([zenith, azimuth]),
            jnp.array([x, y, z])). The direction is projected onto
            [0, pi] x [0, 2*pi], i.e. it is the direction at which the
            returned likelihood value was evaluated.
        """
        solver = optx.BestSoFarMinimiser(optx.BFGS(rtol=rtol, atol=atol, use_inverse=True))
        args = (track_time, data)
        x0 = jnp.concatenate([track_dir_seed*scale_rad, track_vertex_seed/scale])
        # throw=False: return the best point found even if the solver
        # did not report convergence.
        sol = optx.minimise(neg_llh_5D,
                            solver,
                            x0,
                            args=args,
                            throw=False).value

        sol_dir = project_direction(sol[0] / scale_rad, sol[1] / scale_rad)
        sol_pos = sol[2:] * scale

        return neg_llh_5D(sol, args), sol_dir, sol_pos

    # Vectorize over vertex argument
    reconstruct_event_v = jax.vmap(reconstruct_event, (0, None, None, None), 0)

    # Vectorize over vertex and time arguments
    reconstruct_event_vt = jax.vmap(reconstruct_event, (0, None, 0, None), 0)

    def run_reconstruction(track_dir_seed, vertex_seed, track_time, data):
        """Wraps a single reconstruction for a given track direction seed.
        Allows reconstructing that event multiple times with different
        vertex seed values, and optionally adjusts the corresponding
        time to provide the best starting conditions for the
        reconstruction.

        Args
        ----
            track_dir_seed: jnp.array([zenith, azimuth]) in radians
            vertex_seed: jnp.array([x, y, z]) in m
            track_time: jnp.array(t) in ns
            data: jnp.array(data) with shape
                (n_sensors, n_features) = (N, 5)

        Returns
        -------
            (logl, direction, vertex, time) of the best fit, where time is
            the (possibly pre-scanned) fixed time used for that fit.
        """
        if use_multiple_vertex_seeds:
            # Get additional vertex seeds using cylindrical geometry
            vertex_seeds = get_vertex_seeds(vertex_seed, track_dir_seed)

            if prescan_time:
                # For each vertex seed, we optimize the track time,
                # i.e. we use the best-matching time for each vertex reconstruction.
                seed_times = get_track_time_v(track_dir_seed, vertex_seeds, track_time, data)
                logls, dirs, verts = reconstruct_event_vt(vertex_seeds, track_dir_seed, seed_times, data)

            else:
                # Do not perform additional time matching. We reconstruct
                # each vertex seed with a fixed initial track time.
                logls, dirs, verts = reconstruct_event_v(vertex_seeds, track_dir_seed, track_time, data)
                seed_times = jnp.ones(vertex_seeds.shape[0]) * track_time

            # The solution is given by the fit with the best likelihood value
            # across all fits performed.
            ix = jnp.argmin(logls)
            return logls[ix], dirs[ix], verts[ix], seed_times[ix]

        # We are using only a single vertex seed
        if prescan_time:
            # Update time with best-match for given vertex_seed
            track_time = get_track_time(track_dir_seed, vertex_seed, track_time, data)

        logl, direction, vertex = reconstruct_event(vertex_seed, track_dir_seed, track_time, data)
        return logl, direction, vertex, track_time

    return run_reconstruction

In [13]:
import sys, os
# Put a fixed project checkout first on the import path, presumably where the
# lib.*, dom_track_eval and likelihood_* modules below live.
# NOTE(review): absolute, machine-specific path.
sys.path.insert(0, "/home/storage/hans/jax_reco_gupta_corrections4/")
# Restrict to GPU 0; set before jax is imported below so that it takes effect.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

from tensorflow_probability.substrates import jax as tfp

import jax.numpy as jnp
import jax
# Enable 64-bit floats in JAX (the fits below run in float64).
jax.config.update("jax_enable_x64", True)
import optimistix as optx

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

# TriplePandelSPE/JAX stuff
from lib.simdata_i3 import I3SimHandler
from lib.geo import center_track_pos_and_time_based_on_data
from lib.gupta_network_eqx_4comp import get_network_eval_v_fn
# get_vertex_seeds is looked up at call time by the fitter when
# use_multiple_vertex_seeds=True.
from lib.experimental_methods import get_vertex_seeds

from dom_track_eval import get_eval_network_doms_and_track
from likelihood_conv_mpe_logsumexp_gupta import get_neg_c_triple_gamma_llh
from palettable.cubehelix import Cubehelix
# Colormap for plots. NOTE(review): not used in the cells shown here.
cx =Cubehelix.make(start=0.3, rotation=-0.5, n=16, reverse=False, gamma=1.0,
                           max_light=1.0,max_sat=0.5, min_sat=1.4).get_mpl_colormap()

# NOTE(review): not used in the cells shown here.
import time

In [21]:
# Event Index.
event_index = 2

# Get network and eval logic.
dtype = jnp.float64

# Input locations. NOTE(review): absolute, machine-specific paths; adjust when
# running elsewhere.
NETWORK_PATH = '/home/storage/hans/photondata/gupta/ftpv1/n96_errscale1_32bit_4comp_update_regularized/cache/new_model_no_penalties_tree_start_epoch_1000.eqx'
GEOMETRY_PATH = '/home/storage/hans/jax_reco_new/data/icecube/detector_geometry.csv'

eval_network_v = get_network_eval_v_fn(bpath=NETWORK_PATH, dtype=dtype, n_hidden=96)
eval_network_doms_and_track = get_eval_network_doms_and_track(eval_network_v, dtype=dtype, gupta=True, n_comp=4)

# Get an IceCube event.
bp = '/home/fast_storage/i3/22645/ftr/'

# Plain string literals: these file names contain no {placeholders}.
sim_handler = I3SimHandler(os.path.join(bp, 'meta_ds_22645_from_0_to_1000_10_to_100TeV.ftr'),
                           os.path.join(bp, 'pulses_ds_22645_from_0_to_1000_10_to_100TeV.ftr'),
                           GEOMETRY_PATH)

meta, pulses = sim_handler.get_event_data(event_index)
print(f"muon energy: {meta['muon_energy_at_detector']/1.e3:.1f} TeV")

# Get dom locations, first hit times, and total charges (for each dom).
event_data = sim_handler.get_per_dom_summary_from_sim_data(meta, pulses)

# Remove early pulses.
sim_handler.replace_early_pulse(event_data, pulses)
print("n_doms", len(event_data))

muon energy: 37.3 TeV
n_doms 77


In [22]:
# Get MCTruth (used only to compare against the fit result).
true_time = meta['muon_time']
true_zenith, true_azimuth = meta['muon_zenith'], meta['muon_azimuth']
true_pos = jnp.array([meta['muon_pos_x'], meta['muon_pos_y'], meta['muon_pos_z']])
true_src = jnp.array([true_zenith, true_azimuth])
print("true direction:", true_src)

# Use SplineMPE as a seed.
track_time = meta['spline_mpe_time']
track_zenith, track_azimuth = meta['spline_mpe_zenith'], meta['spline_mpe_azimuth']
track_pos = jnp.array([meta['spline_mpe_pos_x'], meta['spline_mpe_pos_y'], meta['spline_mpe_pos_z']])
track_src = jnp.array([track_zenith, track_azimuth])
print("seed direction:", track_src)

# Re-center the seed vertex/time using the event data
# (see lib.geo.center_track_pos_and_time_based_on_data).
print("original seed vertex:", track_pos)
centered_track_pos, centered_track_time = center_track_pos_and_time_based_on_data(
    event_data, track_pos, track_time, track_src
)
print("shifted seed vertex:", centered_track_pos)

true direction: [1.57599182 2.28262353]
seed direction: [1.58173391 2.22337364]
original seed vertex: [ 477.84495319  267.44606706 -413.56920501]
shifted seed vertex: [ 547.41093803  176.4243718  -412.31612685]


In [23]:
# Per-DOM summary as a dense array with one row per DOM and the
# columns x, y, z, time, charge.
feature_columns = ['x', 'y', 'z', 'time', 'charge']
fitting_event_data = jnp.array(event_data[feature_columns].to_numpy())
print(fitting_event_data.shape)

# Setup likelihood.
neg_llh = get_neg_c_triple_gamma_llh(eval_network_doms_and_track)

(77, 5)


In [24]:
# Setup fitter.
fit_llh = get_fitter(neg_llh, use_multiple_vertex_seeds=False, prescan_time=False)  # single seed, no time pre-scan

In [25]:
fit_llh(track_dir_seed=track_src, vertex_seed=centered_track_pos, track_time=centered_track_time, data=fitting_event_data)

(Array(1051.13267881, dtype=float64),
 Array([1.58097487, 2.22337661], dtype=float64),
 Array([ 546.73316124,  173.68773598, -412.38368277], dtype=float64),
 Array(11054.31841689, dtype=float64))

In [26]:
# Refit from several vertex seeds and keep the fit with the lowest neg. llh.
fit_llh = get_fitter(neg_llh, prescan_time=False, use_multiple_vertex_seeds=True)
fit_llh(track_src, centered_track_pos, centered_track_time, fitting_event_data)

(Array(1042.50581372, dtype=float64),
 Array([1.58193127, 2.2828955 ], dtype=float64),
 Array([ 582.44798209,  200.18097542, -412.63000858], dtype=float64),
 Array(11054.31841689, dtype=float64))

In [27]:
# Multiple vertex seeds, each started from its own best pre-scanned time.
fit_llh = get_fitter(
    neg_llh,
    prescan_time=True,
    use_multiple_vertex_seeds=True,
)
fit_llh(
    track_src,
    centered_track_pos,
    centered_track_time,
    fitting_event_data,
)

(Array(1042.50581527, dtype=float64),
 Array([1.58193004, 2.28289371], dtype=float64),
 Array([ 599.97320828,  179.87708047, -412.33059148], dtype=float64),
 Array(11143.7921011, dtype=float64))

In [28]:
# Best-fit (zenith, azimuth) in radians, copied by hand from the
# multi-vertex-seed fit output above; shown in degrees.
# NOTE(review): hard-coded — re-running the fits will not update these values.
sol_dir = jnp.array([1.58193127, 2.2828955])
jnp.degrees(sol_dir)

Array([ 90.63798525, 130.80027722], dtype=float64)