In [1]:
import jax.numpy as jnp
import jax
jax.config.update("jax_enable_x64", True)

import numpy as np
import scipy
import matplotlib.pyplot as plt

import sys, os

In [2]:
sys.path.insert(0, "/home/storage/hans/jax_reco_new")
from lib.cgamma import c_gamma_prob, c_gamma_sf, c_multi_gamma_prob, c_multi_gamma_sf
from lib.plotting import adjust_plot_1d
from lib.network import get_network_eval_fn
from lib.geo import get_xyz_from_zenith_azimuth
from lib.trafos import transform_network_outputs, transform_network_inputs

In [3]:
eval_network = get_network_eval_fn(bpath='/home/storage/hans/jax_reco_new/data/network')
# Vectorise the mixture pdf / survival function over the first argument x only
# (all other arguments are shared across the batch) and jit-compile the result,
# so a large x grid runs as one compiled program instead of being dispatched
# op by op in eager mode (the 30000-point evaluation below was interrupted).
# NOTE(review): jit traces every argument, including the trailing 6th argument
# of c_multi_gamma_prob; conv_mpe_pdf already jit-traces sigma, but confirm the
# 6th argument is not used in Python-level control flow inside the library.
c_multi_gamma_prob_v1d_x = jax.jit(jax.vmap(c_multi_gamma_prob, (0, None, None, None, None, None), 0))
c_multi_gamma_sf_v1d_x = jax.jit(jax.vmap(c_multi_gamma_sf, (0, None, None, None, None), 0))

In [4]:
# Reference configuration; `dist` is the distance from the track in m
# (it is printed as such in the plot titles below).
dist, rho, z = 30, 0.0, -500
zenith, azimuth = 1.57, 3.9264083

# Network input vector, ordered [dist, rho, z, zenith, azimuth].
x = jnp.array([dist, rho, z, zenith, azimuth])
x_prime = transform_network_inputs(x)
y = eval_network(x_prime)
logits, gamma_a, gamma_b = transform_network_outputs(y)
# Softmax-normalised mixture weights passed as `mix_probs` to the pdf helpers.
mix_probs = jax.nn.softmax(logits)

In [5]:
# Delay-time grid and the convolved mixture pdf evaluated on it
# (sigma=3.0, trailing argument 0.1).
xvals = np.linspace(start=-20, stop=6000, num=30000)
probs = c_multi_gamma_prob_v1d_x(
    xvals,
    mix_probs, gamma_a, gamma_b,
    3.0, 0.1,
)

KeyboardInterrupt: 

In [None]:
# Plot the mixture pdf `probs` from the previous cell up to 1500 ns.
fig, ax = plt.subplots()
ax.plot(xvals, probs)
plot_args = dict(
    xlim=[-20, 1500],
    ylim=[0.0, 1.2 * np.amax(probs)],
    xlabel='delay time [ns]',
    ylabel='pdf',
)
adjust_plot_1d(fig, ax, plot_args=plot_args)
plt.tight_layout()
plt.show()

In [None]:
# MPE densities N * p * S**(N-1) (the pdf of the earliest of N i.i.d. delay
# times) for several photon counts N, drawn against the mixture pdf p itself.
probs = c_multi_gamma_prob_v1d_x(xvals, mix_probs, gamma_a, gamma_b, 3.0, 0.01)
sfs = c_multi_gamma_sf_v1d_x(xvals, mix_probs, gamma_a, gamma_b, 3.0)

fig, ax = plt.subplots()
ax.plot(xvals, probs, linestyle='dashed', color='black', zorder=100, label='Triple Pandel')

y_max = np.amax(probs)
# Round to whole photon counts: np.linspace(1, 50, 10) alone yields values such
# as 6.44, which the "N=..." labels would silently display as 6.
# A separate name avoids clobbering `n_photons`, which later cells use as a scalar.
n_photons_grid = np.rint(np.linspace(1, 50, 10))
for n_p in n_photons_grid:
    ys = n_p * probs * sfs**(n_p - 1)
    ax.plot(xvals, ys, label=f'TP MPE (N={n_p:.0f})')
    y_max = np.amax([y_max, np.amax(ys)])
plot_args = {'xlim': [-10, 200],
             'ylim': [0.0, 1.2*y_max],
             'xlabel': 'delay time [ns]',
             'ylabel': 'pdf'}

# Show the per-curve labels; nothing else in this cell draws a legend.
# NOTE(review): if adjust_plot_1d already adds a legend, this call is redundant.
ax.legend()
adjust_plot_1d(fig, ax, plot_args=plot_args)
plt.title(f"distance from track {dist:.1f}m")
plt.tight_layout()
plt.show()

In [None]:
from scipy.integrate import quad

In [None]:
@jax.jit
def conv_mpe_pdf(x, mix_probs, gamma_a, gamma_b, sigma, n_photons):
    """MPE density: the pdf of the minimum of ``n_photons`` i.i.d. delay times.

    With ``p`` and ``S`` the convolved mixture pdf and survival function at
    ``x``, returns ``n_photons * p * S**(n_photons - 1)``.
    """
    # NOTE(review): the plotting cells pass an extra trailing argument
    # (0.1 / 0.01) through c_multi_gamma_prob_v1d_x; this call omits it and so
    # relies on c_multi_gamma_prob's default for it. Confirm this is intended.
    p = c_multi_gamma_prob(x, mix_probs, gamma_a, gamma_b, sigma)
    sf = c_multi_gamma_sf(x, mix_probs, gamma_a, gamma_b, sigma)
    return n_photons * p * jnp.power(sf, n_photons - 1)


def _conv_mpe_moment(order, dist, mix_probs, gamma_a, gamma_b, sigma, n_photons):
    """Raw moment ``int x**order * conv_mpe_pdf(x, ...) dx`` over ``[-15, 30 * dist]``.

    The integral is computed with ``scipy.integrate.quad`` (``epsabs=1e-4``); the
    upper limit scales with ``dist``. The result is NOT divided by the
    normalisation (order 0).
    """
    def integrand(x):
        # quad expects a plain float; convert the 0-d JAX array explicitly.
        value = float(conv_mpe_pdf(x, mix_probs, gamma_a, gamma_b, sigma, n_photons))
        return x**order * value

    return quad(integrand, -15, dist * 30, epsabs=1.e-4)[0]


def norm_mpe_pdf(dist, mix_probs, gamma_a, gamma_b, sigma, n_photons):
    """Integral of conv_mpe_pdf over [-15, 30 * dist]; close to 1 if the range covers the density."""
    return _conv_mpe_moment(0, dist, mix_probs, gamma_a, gamma_b, sigma, n_photons)


def mean_conv_mpe_pdf(dist, mix_probs, gamma_a, gamma_b, sigma, n_photons):
    """Unnormalised first moment of conv_mpe_pdf over [-15, 30 * dist]."""
    return _conv_mpe_moment(1, dist, mix_probs, gamma_a, gamma_b, sigma, n_photons)


def second_moment_conv_mpe_pdf(dist, mix_probs, gamma_a, gamma_b, sigma, n_photons):
    """Unnormalised second moment of conv_mpe_pdf over [-15, 30 * dist]."""
    return _conv_mpe_moment(2, dist, mix_probs, gamma_a, gamma_b, sigma, n_photons)

In [None]:
# Mean and spread of the MPE delay-time distribution for n_photons=400 at the
# reference distance. The moment integrals are unnormalised, so divide by
# `norm`: the truncated integration range need not integrate to exactly 1.
print(dist)
n_photons = 400
norm = norm_mpe_pdf(dist, mix_probs, gamma_a, gamma_b, 3.0, n_photons)
print(norm)
assert norm > 0.0, f"MPE pdf integrates to {norm}; cannot normalise moments"
mean = mean_conv_mpe_pdf(dist, mix_probs, gamma_a, gamma_b, 3.0, n_photons) / norm
print(mean)
second_moment = second_moment_conv_mpe_pdf(dist, mix_probs, gamma_a, gamma_b, 3.0, n_photons) / norm
# Clamp at 0 so quadrature round-off cannot turn a tiny variance into NaN.
err = np.sqrt(max(second_moment - mean**2, 0.0))
print(err)

In [None]:
from matplotlib.pyplot import cm

# MPE density for the photon count(s) in `n_ps`, overlaid with the mean and
# the +/- err band computed in the previous cell.
xvals = np.linspace(-20, 50, 1000)
probs = c_multi_gamma_prob_v1d_x(xvals, mix_probs, gamma_a, gamma_b, 3.0, 0.01)
sfs = c_multi_gamma_sf_v1d_x(xvals, mix_probs, gamma_a, gamma_b, 3.0)

n_ps = [n_photons]
color = cm.rainbow(np.linspace(0, 1, len(n_ps)))

fig, ax = plt.subplots()
y_max = 0
for i in range(len(n_ps)):
    n_p = n_ps[i]
    ys = n_p * probs * sfs**(n_p - 1)
    ax.plot(xvals, ys, color=color[i])
    y_max = np.amax([y_max, np.amax(ys)])

plot_args = dict(
    xlim=[-10, 20],
    ylim=[0.0, 1.2 * y_max],
    xlabel='delay time [ns]',
    ylabel='pdf',
)

ax.axvline(mean)
ax.axvspan(mean - err, mean + err, alpha=0.2)
adjust_plot_1d(fig, ax, plot_args=plot_args)
plt.title(f"distance from track {dist:.1f}m")
plt.tight_layout()
plt.show()

In [None]:
def _mpe_mixture_params(distance, rho, z, zenith, azimuth):
    """Run the network for one geometry and return ``(mix_probs, gamma_a, gamma_b)``.

    The network input vector is ordered ``[distance, rho, z, zenith, azimuth]``;
    the output logits are softmax-normalised into ``mix_probs``.
    """
    x = jnp.array([distance, rho, z, zenith, azimuth])
    y = eval_network(transform_network_inputs(x))
    logits, gamma_a, gamma_b = transform_network_outputs(y)
    return jax.nn.softmax(logits), gamma_a, gamma_b


def find_charge_max(distance, min_charge, max_charge, z=-500, rho=0.0,
                    zenith=1.57, azimuth=3.9264083, sigma=3.0):
    """Find the photon count at which the MPE mean delay time turns negative.

    Bisects ``N`` over ``[min_charge, max_charge]`` in unit steps. It looks for
    the count where the first moment from ``mean_conv_mpe_pdf`` satisfies
    ``mean(N) < 0`` and ``mean(N - 1) > 0``. The search assumes this mean
    decreases as ``N`` grows (TODO confirm for all geometries).

    Parameters
    ----------
    distance : float
        First network input; also passed as ``dist`` to ``mean_conv_mpe_pdf``.
    min_charge, max_charge : float
        Inclusive search bounds for ``N``.
    z, rho, zenith, azimuth : float, optional
        Remaining network inputs; the defaults are the previously hard-coded values.
    sigma : float, optional
        ``sigma`` argument forwarded to ``mean_conv_mpe_pdf`` (default 3.0).

    Returns
    -------
    The photon count ``N`` found, computed as ``(lo + hi) // 2`` of the bounds.

    Raises
    ------
    ValueError
        If the search interval is exhausted without finding the sign change,
        instead of recursing without end.
    """
    # The network output does not depend on N, so evaluate it once rather than
    # on every bisection step.
    mix_probs, gamma_a, gamma_b = _mpe_mixture_params(distance, rho, z, zenith, azimuth)

    def mean_at(n_photons):
        return mean_conv_mpe_pdf(distance, mix_probs, gamma_a, gamma_b, sigma, n_photons)

    lo, hi = min_charge, max_charge
    while lo <= hi:
        mid = (lo + hi) // 2
        if mean_at(mid) < 0.0:
            # Only evaluate mean(mid - 1) when mean(mid) is already negative.
            if mean_at(mid - 1) > 0.0:
                return mid
            hi = mid - 1  # the sign change lies below mid
        else:
            lo = mid + 1  # mean still non-negative: need more photons
    raise ValueError(
        f"no sign change of the MPE mean found for distance {distance} "
        f"in [{min_charge}, {max_charge}]"
    )

In [None]:
# Charge-threshold scan over distances 1..30 with an upper search bound that
# grows exponentially with distance. cuts[0] = 2.0 is the seed value (plotted at
# distance 0 below). Each search starts from the previous threshold, which
# assumes thresholds do not decrease with distance (TODO confirm).
distances = np.linspace(1.0, 30.0, 30)
uppers = np.exp(0.25*distances) + 4

cuts = [2.0]
# The loop variable is `distance`, not `dist`, so the global `dist` used by the
# cells above is not overwritten.
for distance, upper in zip(distances, uppers):
    print(f"distance: {distance}")
    thresh = find_charge_max(distance, cuts[-1], upper)
    print(thresh)
    cuts.append(thresh)

In [None]:
# Charge thresholds vs. distance with a smooth parametrisation (the same
# expression as charge_clip_fn further down). The x-axis is built from
# `distances` so it always lines up with `cuts` (cuts[0] is the seed at 0).
xvals = np.concatenate([[0.0], distances])
fig, ax = plt.subplots()
ax.plot(xvals, cuts, "rx", label='found thresholds')
ax.plot(xvals, np.exp(0.21*xvals) / (1 + np.exp(-0.04*xvals))+1.5, label='parametrisation')
ax.set_yscale('log')
ax.set_ylim([1.0, 1.e3])
ax.set_xlabel('distance from track [m]')
ax.set_ylabel('charge threshold')
ax.legend();

In [None]:
# Coarser 21-point grid over 0-30, evaluated with the same exponential
# upper-bound formula as `uppers` in the threshold scan.
xvals = np.linspace(start=0.0, stop=30.0, num=21)
yvals = 4 + np.exp(xvals / 4)

In [None]:
# Upper search bound vs. distance on a logarithmic y-axis.
plt.plot(xvals, yvals)
plt.gca().set_yscale('log')

In [None]:
# Raw upper-bound values on the coarse grid.
print(str(yvals))

In [None]:
def charge_clip_fn(x):
    """Smooth charge-threshold parametrisation as a function of distance ``x``.

    Evaluates ``exp(0.21 x) / (1 + exp(-0.04 x)) + 1.5``, the curve drawn over
    the scanned thresholds. Works elementwise on NumPy arrays.
    """
    growth = np.exp(0.21 * x)
    damping = 1 + np.exp(-0.04 * x)
    return growth / damping + 1.5

In [None]:
# Spot-check the parametrisation at distance 1.
print(charge_clip_fn(x=1.0))