In [None]:
import sys, os
sys.path.insert(0, "/home/storage/hans/jax_reco_new")
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

import jax.numpy as jnp
from jax.scipy import optimize
import jax
jax.config.update("jax_enable_x64", True)
import optimistix as optx

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from lib.simdata_i3 import I3SimBatchHandlerTFRecord
from lib.geo import center_track_pos_and_time_based_on_data_batched_v
from lib.experimental_methods import get_clean_pulses_fn_v
from lib.network import get_network_eval_v_fn

from likelihood_mpe_padded_input import get_neg_c_triple_gamma_llh
from lib.geo import get_xyz_from_zenith_azimuth, __c
from dom_track_eval import get_eval_network_doms_and_track2 as get_eval_network_doms_and_track
import time

# Precision handed to both factory functions below. NOTE(review): jax_enable_x64
# is switched on in this cell while the network is built in float32 -- confirm
# that this mix is intended.
dtype = jnp.float32

# Network evaluator, loaded from the directory given by `bpath`.
eval_network_v = get_network_eval_v_fn(
    bpath='/home/storage/hans/jax_reco/data/network',
    dtype=dtype,
)

# Wrapper that is later called with a batch of DOM positions plus one track
# (position, direction) and returns (logits, a, b, geo_time).
eval_network_doms_and_track = get_eval_network_doms_and_track(
    eval_network_v,
    dtype=dtype,
)

# Create padded batches (with different seq length).
# Identifiers of the available events. They stay strings because they are
# interpolated verbatim into the TFRecord file name below.
event_ids = [
    '1022', '10393', '10644', '10738', '11086', '11232', '13011', '13945', '14017', '14230',
    '15243', '16416', '16443', '1663', '1722', '17475', '18846', '19455', '20027', '21113',
    '21663', '22232', '22510', '22617', '23574', '23638', '23862', '24530', '24726', '25181',
    '25596', '25632', '27063', '27188', '27285', '28188', '28400', '29040', '29707', '3062',
    '31920', '31989', '32781', '32839', '33119', '33656', '34506', '35349', '37086', '37263',
    '37448', '37786', '37811', '39166', '39962', '40023', '41381', '41586', '42566', '42568',
    '42677', '43153', '43483', '4397', '44081', '48309', '48448', '48632', '49067', '50832',
    '51687', '51956', '54374', '55301', '55526', '55533', '56041', '5620', '56741', '56774',
    '57174', '57394', '57723', '59010', '59029', '59089', '59099', '59228', '62274', '62512',
    '63373', '65472', '6586', '8', '8604', '8674', '8840', '9410', '9419', '9505',
]

# Position in event_ids of the event analysed by the rest of the notebook.
ix = 0
print(event_ids[ix])
tfrecord = f"/home/storage2/hans/i3files/alerts/bfrv2/event_{event_ids[ix]}_N100_from_0_to_10_1st_pulse.tfrecord"



#tfrecord = "/home/storage2/hans/i3files/golden_muons/IC/NominalIce/MuMinus_150e3GeV_Horizontal_CloseToDOMs_Smooth_N100_from_0_to_10_1st_pulse.tfrecord"
#tfrecord = "/home/storage2/hans/i3files/golden_muons/IC/NominalIce/MuMinus_150e3GeV_Horizontal_CloseToDOMs_Stochastic_N100_from_0_to_10_1st_pulse.tfrecord"
# ridiculous pulse: 2063.8000043034554pe
# ridiculous pulse 97pe

#tfrecord = "/home/storage2/hans/i3files/golden_muons/IC/NominalIce/MuMinus_150e3GeV_Horizontal_FarFromDOMs_Stochastic_N100_from_0_to_10_1st_pulse.tfrecord"
#tfrecord = "/home/storage2/hans/i3files/golden_muons/IC/NominalIce/MuMinus_150e3GeV_Horizontal_FarFromDOMs_Smooth_N100_from_0_to_10_1st_pulse.tfrecord"
# ridiculous pulse: 2063.8000043034554pe
# ridiculous pulse 97pe

# Iterator over batches of 100 entries read from the selected TFRecord file.
batch_maker = I3SimBatchHandlerTFRecord(tfrecord, batch_size=100)
batch_iter = batch_maker.get_batch_iterator()

# Until LLH has a noise-term, we need to remove crazy early noise pulses
clean_pulses_fn_v = get_clean_pulses_fn_v(eval_network_doms_and_track)

# Pull one batch and convert both members (pulse data, MC truth) to JAX arrays.
data, mctruth = (jnp.array(tensor.numpy()) for tensor in batch_iter.next())
data_clean_padded = clean_pulses_fn_v(data, mctruth)

# mctruth columns used below: 2:4 -> track direction (later passed to
# get_xyz_from_zenith_azimuth), 4 -> track time, 5:8 -> track position.
track_srcs, track_times, track_positions = (
    mctruth[:, 2:4],
    mctruth[:, 4],
    mctruth[:, 5:8],
)

In [None]:
# Column-vector view of the per-entry track times, so that it broadcasts against
# the (entry, pulse) time array when the hit times are shifted further down.
track_times_ = track_times.reshape((-1, 1))

In [None]:
# Detector geometry table; plot_event reads its 'x', 'y' and 'z' columns.
geo = pd.read_csv('/home/storage/hans/jax_reco_new/data/icecube/detector_geometry.csv')

# Pulses of the first entry in the cleaned batch, one row per pulse.
pulses = data_clean_padded[0]
df = pd.DataFrame(data=pulses, columns=["x", "y", "z", "time", "charge"])

def plot_event(df, geo=None):
    """Draw a 3D event display of the pulses in ``df``.

    Only rows with ``charge > 0`` are drawn. Marker area scales with
    ``sqrt(100 * charge)`` and marker colour encodes ``time``.

    Parameters
    ----------
    df : pandas.DataFrame
        Must provide the columns ``x``, ``y``, ``z``, ``time`` and ``charge``.
    geo : pandas.DataFrame, optional
        Detector geometry with columns ``x``, ``y`` and ``z``. When given, all
        positions are drawn as faint grey background markers. When ``None``,
        the background is skipped.

    Returns
    -------
    None
        The figure is shown via ``plt.show()``.
    """
    fig = plt.figure(figsize=(8,6))
    ax = plt.subplot(projection='3d')
    ax.set_xlabel('pos.x [m]', fontsize=16, labelpad=-25)
    ax.set_ylabel('pos.y [m]', fontsize=16, labelpad=-25)
    ax.set_zlabel('pos.z [m]', fontsize=16, labelpad=-25)

    # Select the rows to draw once instead of chaining df[idx][col] repeatedly.
    hits = df.loc[df['charge'] > 0]

    # The geometry is optional. An explicit None check replaces the former bare
    # `except: pass`, so real problems (e.g. a missing column) are no longer hidden.
    if geo is not None:
        ax.scatter(geo['x'], geo['y'], geo['z'], s=0.5, c='0.7', alpha=0.4)

    im = ax.scatter(hits['x'], hits['y'], hits['z'],
                    s=np.sqrt(hits['charge']*100), c=hits['time'],
                    cmap='rainbow_r', edgecolors='k', zorder=1000)
    ax.tick_params(axis='both', which='both', width=1.5, colors='0.0', labelsize=16)
    cb = plt.colorbar(im, orientation="vertical", pad=0.1)
    cb.set_label(label='time [ns]', size='x-large')
    cb.ax.tick_params(labelsize='x-large')

    plt.show()

plot_event(df, geo)

In [None]:
# Flatten the cleaned batch into one NumPy array per pulse attribute.
# Last-axis layout: 0..2 -> x, y, z; 3 -> time; 4 -> charge.
hit_x, hit_y, hit_z = (
    np.array(data_clean_padded[..., col].flatten()) for col in range(3)
)
# Pulse times are expressed relative to the track time of their own entry.
hit_t = np.array((data_clean_padded[..., 3] - track_times_).flatten())
hit_q = np.array(data_clean_padded[..., 4].flatten())

In [None]:
from collections import defaultdict
# Per-DOM collections keyed by the (x, y, z) position tuple of the pulse.
times_dict = defaultdict(list)
charges_dict = defaultdict(list)

for x, y, z, t, q in zip(hit_x, hit_y, hit_z, hit_t, hit_q):
    # Drop (near-)zero charges and extremely large charges.
    if q < 1.e-3 or q > 1000:
        continue

    times_dict[(x, y, z)].append(t)
    charges_dict[(x, y, z)].append(q)

# NOTE: despite its name, qtot_dict holds the *mean* of the collected charges
# per DOM (sum divided by number of entries), not their sum.
qtot_dict = {dom: sum(qs) / len(qs) for dom, qs in charges_dict.items()}

In [None]:
# DOM with the largest value in qtot_dict. The ranking key must come from
# qtot_dict itself: ranking by charges_dict.get would compare the raw charge
# *lists* lexicographically (i.e. by their first entry).
max_om = max(qtot_dict, key=qtot_dict.get)

In [None]:
# Value stored for the selected DOM: mean of its collected charges (sum(qs) / len(qs)).
print(qtot_dict[max_om])

In [None]:
# Individual charges collected for the selected DOM (only those passing the 1e-3 <= q <= 1000 cut).
print(charges_dict[max_om])

In [None]:
# Hit times of the selected DOM, followed by their minimum, maximum and
# standard deviation (one value per line).
print(times_dict[max_om])
print(*(stat(times_dict[max_om]) for stat in (np.min, np.max, np.std)), sep="\n")

In [None]:
# Distribution of the selected DOM's hit times around their median.
plt.hist(
    np.subtract(times_dict[max_om], np.median(times_dict[max_om])),
    bins=np.linspace(-20, 20, 31),
    density=True,
)
plt.show()

In [None]:
# Stack every DOM position that has recorded times into one (n_doms, 3) array.
# Row order follows the iteration order of times_dict.
positions_batch = jnp.concatenate(
    [jnp.array(key).reshape((1, 3)) for key in times_dict.keys()],
    axis=0,
)

In [None]:
# Evaluate the network for all selected DOM positions against the track of the
# first entry in the batch.
logits, av, bv, geo_time = eval_network_doms_and_track(
    positions_batch,
    track_positions[0],
    track_srcs[0],
)
# Turn the logits into mixture weights (softmax over the last axis).
mix_probs = jax.nn.softmax(logits, axis=-1)

In [None]:
# Subtract each DOM's geo_time from all of its recorded times.
# WARNING: this rewrites the lists inside times_dict in place, so the cell is
# not idempotent -- run it exactly once after times_dict has been (re)built.
for i, dom_pos in enumerate(positions_batch):
    pos = tuple(np.array(dom_pos))
    shift = float(geo_time[i])
    times_dict[pos][:] = [t - shift for t in times_dict[pos]]

In [None]:
# Histogram of the selected DOM's times in a +-10 window with 30 bins.
plt.hist(
    times_dict[max_om],
    bins=np.linspace(-10, 10, 31),
    density=True,
)
plt.xlabel('delay time [ns]')
plt.show()

# Spread of the same times.
print(np.std(times_dict[max_om]))

In [None]:
# Times of the selected DOM; when the cells are run in order, geo_time has already been subtracted in place.
print(times_dict[max_om])

In [None]:
from lib.geo import cherenkov_cylinder_coordinates_w_rho2_v as cherenkov_cylinder_coordinates_w_rho_v
from lib.geo import get_xyz_from_zenith_azimuth

In [None]:
# Cartesian direction of the first entry's track.
track_dir_xyz = get_xyz_from_zenith_azimuth(track_srcs[0])

# Geometry of every selected DOM relative to that track.
# NOTE: this rebinds geo_time, replacing the value returned by the network cell.
(geo_time,
 closest_approach_dist,
 closest_approach_z,
 closest_approach_rho) = cherenkov_cylinder_coordinates_w_rho_v(
    positions_batch,
    track_positions[0],
    track_dir_xyz,
)

In [None]:
# First "closest approach" output of cherenkov_cylinder_coordinates_w_rho_v for positions_batch
# (presumably one distance per DOM row -- TODO confirm).
print(closest_approach_dist)

In [None]:
# Row index of the selected DOM inside positions_batch (stays None if absent).
dom_i = None
for i, dom_pos in enumerate(positions_batch):
    if tuple(np.array(dom_pos)) == max_om:
        dom_i = i

In [None]:
# Closest-approach distance, z and rho of the selected DOM, one per line.
print(
    closest_approach_dist[dom_i],
    closest_approach_z[dom_i],
    closest_approach_rho[dom_i],
    sep="\n",
)

In [None]:
from lib.cgamma import c_multi_gamma_prob, c_multi_gamma_sf
from lib.plotting import adjust_plot_1d

In [None]:
# Vectorise the pdf over its first argument (the time axis) only; the remaining
# five arguments are shared by all evaluations.
c_multi_gamma_prob_vx = jax.vmap(
    c_multi_gamma_prob,
    in_axes=(0, None, None, None, None, None),
    out_axes=0,
)

xvals = np.linspace(-10, 3000, 30000)

# Mixture parameters predicted for the selected DOM.
m, a, b = mix_probs[dom_i], av[dom_i], bv[dom_i]
yval = c_multi_gamma_prob_vx(xvals, m, a, b, 2.0, 0.1)
# (a - 1) / b of mixture component 1; presumably the mode of that gamma
# component (shape a, rate b) -- TODO confirm parameterisation.
mode = (a[1]-1)/b[1]

fig, ax = plt.subplots()
ax.plot(xvals, yval)

# Overlay every time recorded for the selected DOM as a faint vertical line.
for tx in times_dict[max_om]:
    ax.axvline(tx, alpha=0.2, color='black', lw=0.5)

plot_args = {
    'xlim': [-10, np.max([20, 10 * mode])],
    'ylim': [0.0, 1.2 * np.amax(yval)],
    'xlabel': 'delay time [ns]',
    'ylabel': 'pdf',
}

adjust_plot_1d(fig, ax, plot_args=plot_args)
plt.tight_layout()
plt.show()

In [None]:
# Vectorise c_multi_gamma_sf over its first argument (the time axis) only.
c_multi_gamma_sf_vx = jax.vmap(c_multi_gamma_sf, (0, None, None, None, None), 0)

# Effective count for the selected DOM: its qtot_dict value shifted by 0.5 and
# rounded, then capped at 3.
n_p = qtot_dict[tuple(np.array(positions_batch[dom_i]))]
n_p = np.round(n_p+0.5)
print(n_p)
n_p = np.min([3.0, n_p])
print(n_p)

# Mixture parameters predicted for the selected DOM.
m = mix_probs[dom_i]
a = av[dom_i]
b = bv[dom_i]

probs = c_multi_gamma_prob_vx(xvals, m, a, b, 3.0, 0.1)
sfs = c_multi_gamma_sf_vx(xvals, m, a, b, 3.0)
# n * pdf * sf**(n-1): density of the earliest of n draws, assuming
# c_multi_gamma_sf is the survival function matching c_multi_gamma_prob
# -- TODO confirm.
yval = n_p * probs * sfs**(n_p-1)

mode = (a[1]-1)/b[1]
fig, ax = plt.subplots()
plt.title(f"event {event_ids[ix]}")
plt.plot(xvals, yval)

# The x-range follows the largest time recorded for the selected DOM. The former
# expression used `tx`, a loop variable left over from an earlier cell (it is
# only assigned further down in this cell), so the limit depended on hidden
# kernel state and on the *last* time in the list rather than the maximum.
plot_args = {'xlim':[-10, np.max([20, 5 * np.amax(times_dict[max_om])])],
                     'ylim':[0.0, 1.2 * np.amax(yval)],
                     'xlabel':'delay time [ns]',
                     'ylabel':'pdf'}

# Overlay every time recorded for the selected DOM as a faint vertical line.
for tx in times_dict[max_om]:
    plt.axvline(tx, alpha=0.2, color='black', lw=0.5)

adjust_plot_1d(fig, ax, plot_args=plot_args)
plt.tight_layout()
plt.show()

In [None]:
# qtot_dict value of every DOM, in the same row order as positions_batch.
qtots = [qtot_dict[tuple(np.array(dom_pos))] for dom_pos in positions_batch]

fig, ax = plt.subplots()
ax.scatter(closest_approach_dist, qtots)

plot_args = {
    'xlim': [0.0, 500],
    'ylim': [0.0, 500],
    'xlabel': 'distance to track [m]',
    'ylabel': 'qtot [p.e.]',
}
adjust_plot_1d(fig, ax, plot_args=plot_args)

# Log charge axis; this y-range replaces the one passed through plot_args.
plt.yscale('log')
plt.ylim([1.0, 500])
plt.show()