In [1]:
import os 
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=14'
os.environ["CUDA_VISIBLE_DEVICES"] = ""

from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions

import jax.numpy as jnp
import jax
jax.config.update("jax_enable_x64", True)

import numpy as np
import scipy
import matplotlib.pyplot as plt


In [3]:
jax.devices("cpu")

[CpuDevice(id=0),
 CpuDevice(id=1),
 CpuDevice(id=2),
 CpuDevice(id=3),
 CpuDevice(id=4),
 CpuDevice(id=5),
 CpuDevice(id=6),
 CpuDevice(id=7),
 CpuDevice(id=8),
 CpuDevice(id=9),
 CpuDevice(id=10),
 CpuDevice(id=11),
 CpuDevice(id=12),
 CpuDevice(id=13)]

In [4]:
import sys
sys.path.insert(0, "/home/storage/hans/jax_reco_new")

from lib.network import get_network_eval_v_fn
from lib.geo import center_track_pos_and_time_based_on_data
from lib.simdata_i3 import I3SimHandler
from dom_track_eval import get_eval_network_doms_and_track

In [5]:
dtype = jnp.float32
eval_network_v = get_network_eval_v_fn(bpath='/home/storage/hans/jax_reco_new/data/network', dtype=dtype)

In [6]:
eval_network_doms_and_track = get_eval_network_doms_and_track(eval_network_v, dtype=dtype)
eval_network_doms_and_track = jax.jit(eval_network_doms_and_track)

In [7]:
event_index = 2

# Get an IceCube event.
bp = '/home/storage2/hans/i3files/21217'
sim_handler = I3SimHandler(os.path.join(bp, 'meta_ds_21217_from_35000_to_53530.ftr'),
                              os.path.join(bp, 'pulses_ds_21217_from_35000_to_53530.ftr'),
                              '/home/storage/hans/jax_reco/data/icecube/detector_geometry.csv')

meta, pulses = sim_handler.get_event_data(event_index)
print(f"muon energy: {meta['muon_energy_at_detector']/1.e3:.1f} TeV")

# Get dom locations, first hit times, and total charges (for each dom).
event_data = sim_handler.get_per_dom_summary_from_sim_data(meta, pulses)

print("n_doms", len(event_data))
fitting_event_data = jnp.array(event_data[['x', 'y', 'z', 'time', 'charge']].to_numpy())

muon energy: 4.7 TeV
n_doms 102


In [8]:
# Make MCTruth seed.
track_pos = jnp.array([meta['muon_pos_x'], meta['muon_pos_y'], meta['muon_pos_z']])
track_time = meta['muon_time']
track_zenith = meta['muon_zenith']
track_azimuth = meta['muon_azimuth']
track_src = jnp.array([track_zenith, track_azimuth])

print("original seed vertex:", track_pos)
centered_track_pos, centered_track_time = center_track_pos_and_time_based_on_data(event_data, 
                                                                                  track_pos, 
                                                                                  track_time, 
                                                                                  track_src)
print("shifted seed vertex:", centered_track_pos)

original seed vertex: [ -777.15166078 -1656.22843231 -1472.45624098]
shifted seed vertex: [-164.52122541  320.02418746 -330.17880541]


In [9]:
dom_pos = fitting_event_data[:, :3]
first_hit_times = fitting_event_data[:, 3]

logits, av, bv, geo_time = eval_network_doms_and_track(dom_pos, centered_track_pos, track_src)
mix_probs = jax.nn.softmax(logits)

delay_time = first_hit_times - (geo_time + centered_track_time)

In [10]:
def func():
    logits, av, bv, geo_time = eval_network_doms_and_track(dom_pos, centered_track_pos, track_src)
    return logits


In [11]:
print(func())

[[ 1.30989981  0.41508666  0.78997564]
 [ 1.08043253  0.63937545  0.78115976]
 [ 1.07323694  0.63215667  0.79428816]
 [ 1.10383415  0.68516821  0.71266603]
 [ 0.81204313  0.68980891  0.76411843]
 [ 0.82605052  0.79009211  0.6531195 ]
 [ 0.80182356  0.51668692  0.83724576]
 [ 1.01524615  0.8235727   0.98089534]
 [ 1.01920176 -0.04484212  0.74508882]
 [ 0.9859823   0.32891491  0.78397214]
 [ 1.14493823  0.74641228  0.65447974]
 [ 0.93164921  1.02077746  0.51619405]
 [ 0.9251709   1.37723207  0.20653372]
 [ 0.92823952  1.51335335  0.0456607 ]
 [ 1.05364358  0.03645617  0.67123985]
 [ 1.05058408  0.065938    0.71161854]
 [ 1.04519749  0.12827635  0.75358391]
 [ 1.05302393  0.2323103   0.79469204]
 [ 0.89867944  0.51733977  0.90101922]
 [ 1.22307038  0.20317671  0.6797384 ]
 [ 1.3793906   0.05905032  0.81567776]
 [ 0.98459589  0.28478992  0.77510321]
 [ 0.83352816  0.35420558  0.42862645]
 [ 0.77976996  0.38025376  0.4027923 ]
 [ 0.71972871  0.39994583  0.39455181]
 [ 0.66483164  0.4212001 

In [12]:
%timeit func().block_until_ready()

580 μs ± 25 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
