In [1]:
import os 
os.environ["XLA_FLAGS"] = ("--xla_cpu_multi_thread_eigen=false "
                           "intra_op_parallelism_threads=1")
os.environ["CUDA_VISIBLE_DEVICES"] = ""

from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions

import jax.numpy as jnp
import jax
jax.config.update("jax_enable_x64", True)

import numpy as np
import scipy
import matplotlib.pyplot as plt


In [3]:
jax.devices("cpu")

[CpuDevice(id=0)]

In [4]:
import sys
sys.path.insert(0, "/home/storage/hans/jax_reco_new")

from lib.network import get_network_eval_v_fn
from lib.geo import center_track_pos_and_time_based_on_data
from lib.simdata_i3 import I3SimHandler
from dom_track_eval import get_eval_network_doms_and_track

In [5]:
dtype = jnp.float32
eval_network_v = get_network_eval_v_fn(bpath='/home/storage/hans/jax_reco_new/data/network', dtype=dtype)

In [6]:
eval_network_doms_and_track = get_eval_network_doms_and_track(eval_network_v, dtype=dtype)
eval_network_doms_and_track = jax.jit(eval_network_doms_and_track, device=jax.devices("cpu")[0])

In [7]:
event_index = 2

# Get an IceCube event.
bp = '/home/storage2/hans/i3files/21217'
sim_handler = I3SimHandler(os.path.join(bp, 'meta_ds_21217_from_35000_to_53530.ftr'),
                              os.path.join(bp, 'pulses_ds_21217_from_35000_to_53530.ftr'),
                              '/home/storage/hans/jax_reco/data/icecube/detector_geometry.csv')

meta, pulses = sim_handler.get_event_data(event_index)
print(f"muon energy: {meta['muon_energy_at_detector']/1.e3:.1f} TeV")

# Get dom locations, first hit times, and total charges (for each dom).
event_data = sim_handler.get_per_dom_summary_from_sim_data(meta, pulses)

print("n_doms", len(event_data))
fitting_event_data = jnp.array(event_data[['x', 'y', 'z', 'time', 'charge']].to_numpy())

muon energy: 4.7 TeV
n_doms 102


In [8]:
# Make MCTruth seed.
track_pos = jnp.array([meta['muon_pos_x'], meta['muon_pos_y'], meta['muon_pos_z']])
track_time = meta['muon_time']
track_zenith = meta['muon_zenith']
track_azimuth = meta['muon_azimuth']
track_src = jnp.array([track_zenith, track_azimuth])

print("original seed vertex:", track_pos)
centered_track_pos, centered_track_time = center_track_pos_and_time_based_on_data(event_data, 
                                                                                  track_pos, 
                                                                                  track_time, 
                                                                                  track_src)
print("shifted seed vertex:", centered_track_pos)

original seed vertex: [ -777.15166078 -1656.22843231 -1472.45624098]
shifted seed vertex: [-164.52122541  320.02418746 -330.17880541]


In [9]:
dom_pos = fitting_event_data[:, :3]
first_hit_times = fitting_event_data[:, 3]

logits, av, bv, geo_time = eval_network_doms_and_track(dom_pos, centered_track_pos, track_src)
mix_probs = jax.nn.softmax(logits)

delay_time = first_hit_times - (geo_time + centered_track_time)

In [10]:
def func():
    logits, av, bv, geo_time = eval_network_doms_and_track(dom_pos, centered_track_pos, track_src)
    return logits

In [11]:
print(func())

[[ 1.3099004   0.41508648  0.7899757 ]
 [ 1.08043289  0.63937539  0.78115928]
 [ 1.0732367   0.63215673  0.79428816]
 [ 1.10383368  0.6851685   0.71266651]
 [ 0.81204283  0.68980879  0.76411849]
 [ 0.82604986  0.79009175  0.65311944]
 [ 0.80182284  0.51668698  0.8372454 ]
 [ 1.01524544  0.82357275  0.98089504]
 [ 1.01920187 -0.04484183  0.74508947]
 [ 0.98598206  0.32891527  0.78397167]
 [ 1.14493775  0.74641258  0.65447938]
 [ 0.93164986  1.02077699  0.51619339]
 [ 0.92517084  1.37723231  0.20653334]
 [ 0.92823893  1.51335359  0.04566085]
 [ 1.05364347  0.03645623  0.67124033]
 [ 1.05058408  0.065938    0.71161842]
 [ 1.04519749  0.12827605  0.75358486]
 [ 1.05302382  0.23230958  0.7946915 ]
 [ 0.8986789   0.51733941  0.90101951]
 [ 1.22307026  0.20317674  0.67973828]
 [ 1.37939024  0.05905008  0.81567758]
 [ 0.9845953   0.28479028  0.77510369]
 [ 0.83352786  0.35420507  0.42862621]
 [ 0.77977115  0.3802537   0.40279236]
 [ 0.7197293   0.39994609  0.39455202]
 [ 0.66483128  0.42120051

In [12]:
%timeit func().block_until_ready()

962 μs ± 11 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [17]:
n_repeat = 14
dom_pos_v = jnp.repeat(jnp.expand_dims(dom_pos, axis=0), 14, axis=0) 
centered_track_pos_v = jnp.repeat(jnp.expand_dims(centered_track_pos, axis=0), 14, axis=0)
track_src_v = jnp.repeat(jnp.expand_dims(track_src, axis=0), 14, axis=0) 

In [18]:
eval_dom_track_pmap = jax.pmap(eval_network_doms_and_track, in_axes=(0,0,0), out_axes=0,
                              devices=jax.devices("cpu"))

In [19]:
eval_dom_track_pmap(dom_pos_v, centered_track_pos_v, track_src_v)

(Array([[[1.30989981, 0.41508666, 0.78997564],
         [1.08043253, 0.63937545, 0.78115976],
         [1.07323694, 0.63215667, 0.79428816],
         ...,
         [0.91866064, 0.44539762, 0.78513825],
         [0.87827969, 0.29374149, 0.79243755],
         [1.29587126, 0.34756616, 0.90942353]],
 
        [[1.30989981, 0.41508666, 0.78997564],
         [1.08043253, 0.63937545, 0.78115976],
         [1.07323694, 0.63215667, 0.79428816],
         ...,
         [0.91866064, 0.44539762, 0.78513825],
         [0.87827969, 0.29374149, 0.79243755],
         [1.29587126, 0.34756616, 0.90942353]],
 
        [[1.30989981, 0.41508666, 0.78997564],
         [1.08043253, 0.63937545, 0.78115976],
         [1.07323694, 0.63215667, 0.79428816],
         ...,
         [0.91866064, 0.44539762, 0.78513825],
         [0.87827969, 0.29374149, 0.79243755],
         [1.29587126, 0.34756616, 0.90942353]],
 
        ...,
 
        [[1.30989981, 0.41508666, 0.78997564],
         [1.08043253, 0.63937545, 0.78115

In [21]:
%timeit eval_dom_track_pmap(dom_pos_v, centered_track_pos_v, track_src_v)[0].block_until_ready()

2.92 ms ± 34.5 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [22]:
dom_pos_vv = jnp.concatenate(dom_pos_v)

In [23]:
eval_network_doms_and_track(dom_pos_vv, centered_track_pos, track_src)

(Array([[1.30990016, 0.41508648, 0.789976  ],
        [1.08043265, 0.63937551, 0.78115976],
        [1.07323718, 0.63215637, 0.79428852],
        ...,
        [0.9186604 , 0.44539785, 0.78513801],
        [0.87827879, 0.29374108, 0.79243803],
        [1.29587114, 0.34756547, 0.9094227 ]], dtype=float64),
 Array([[6.32997656, 8.30815315, 5.03821182],
        [5.56648064, 7.972013  , 3.75736237],
        [5.43217659, 7.76834679, 3.66270065],
        ...,
        [2.55771756, 4.39169788, 1.76048601],
        [2.85465074, 4.88382244, 1.99021387],
        [4.2364254 , 6.30274868, 3.17609906]], dtype=float64),
 Array([[0.00464397, 0.01059342, 0.00219386],
        [0.00594419, 0.01647322, 0.00214449],
        [0.00602127, 0.01681723, 0.00212893],
        ...,
        [0.01115927, 0.06523703, 0.00248541],
        [0.00952063, 0.05056033, 0.0023965 ],
        [0.00633129, 0.02247833, 0.00245935]], dtype=float64),
 Array([-1314.38215149, -1194.21617964, -1245.8626102 , ...,
          742.0971644

In [25]:
%timeit eval_network_doms_and_track(dom_pos_vv, centered_track_pos, track_src)[0].block_until_ready()

7.45 ms ± 144 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
