In [None]:
import os
import sys
os.environ["CUDA_VISIBLE_DEVICES"]="0,1,2,3"
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats
from scipy.spatial import distance
from scipy.optimize import curve_fit
import tensorflow as tf
from sklearn.model_selection import train_test_split
import dama as dm
import pickle
import awkward as ak
import pyarrow.parquet as pq

from freedom.toy_model.toy_model_functions import toy_model
from freedom.toy_model.detectors import get_box_detector
from types import SimpleNamespace
#from toy_NN_trafo import build_q_trafo, build_h_trafo
from freedom.toy_model import NNs

%load_ext autoreload
%autoreload 2

In [None]:
# Global matplotlib font sizes shared by every figure in this notebook.
plt.rcParams.update({
    'xtick.labelsize': 18,
    'ytick.labelsize': 18,
    'axes.labelsize': 18,
    'axes.titlesize': 16,
    'legend.fontsize': 18,
})

# Short names of the eight hypothesis parameters (same order as the entries of `truth`).
par_names = ['x', 'y', 'z', 't', 'azi', 'zen', 'E', 'I']

def plot_truth(axes, truth, idx=(0,1)):
    """Mark the true parameter point with a 'T' on one or several axes.

    Parameters
    ----------
    axes : a single axes object or a numpy array of axes (any shape).
    truth : indexable parameter vector.
    idx : pair of indices into ``truth`` giving the x and y coordinate of the marker.
    """
    # A lone axes object is wrapped so that both cases can be iterated the same way.
    panels = axes if isinstance(axes, np.ndarray) else np.array([axes])
    for panel in panels.flat:
        panel.plot([truth[idx[0]]], [truth[idx[1]]], marker='$T$', markersize=10, color='k')

def plot_diff(a, b, axes, title_a='a', title_b='b', vmax=None, limit_diff=False, **kwargs):
    """Draw two delta-LLH maps as sigma contours and their difference as a colour map.

    Parameters
    ----------
    a, b : grid arrays that provide ``plot_contour`` and ``plot`` and support ``a - b``.
    axes : sequence of three axes; ``a`` goes to axes[0], ``b`` to axes[1], ``a - b`` to axes[2].
    title_a, title_b : titles of the first two panels.
    vmax : symmetric colour limit of the difference panel, used only when ``limit_diff`` is True.
        If it is None the limit falls back to max|a - b| (previously ``-None`` raised a TypeError).
    limit_diff : clamp the difference colour scale to +-vmax instead of +-max|a - b|.
    **kwargs : forwarded unchanged to every ``plot_contour`` / ``plot`` call.
    """
    # Half the chi2 (2 dof) quantile whose tail probability equals the two-sided
    # Gaussian tail probability at 1..5 sigma.
    levels = stats.chi2(df=2).isf(stats.norm.sf(np.arange(1,6))*2)/2
    labels = [str(i) + r'$\sigma$' for i in range(1,6)]
    colors = plt.cm.viridis(np.linspace(0, 0.9, 6))

    for grid, panel, title in ((a, axes[0], title_a), (b, axes[1], title_b)):
        grid.plot_contour(ax=panel, levels=levels, labels=labels, colors=colors, label=r'$\Delta$LLH', **kwargs)
        panel.set_title(title)

    diff = a - b
    if not limit_diff or vmax is None:
        # Data-driven symmetric range so that zero sits at the centre of the diverging colour map.
        vmax = np.max(np.abs(diff))
    diff.plot(ax=axes[2], cmap='RdBu', cbar=True, vmin=-vmax, vmax=vmax, label=r'$\Delta$LLH', **kwargs)
    axes[2].set_title('Difference') #axes[2].set_title(title_a + ' - ' + title_b)
    
def plot_point_dense(x, y):
    """Scatter-plot (x, y) on the current axes, coloured by a Gaussian-KDE density estimate.

    Points are drawn in order of increasing density, so the densest points are drawn last.
    ``x`` and ``y`` must be numpy arrays (they are fancy-indexed).
    """
    samples = np.vstack([x, y])
    density = stats.gaussian_kde(samples)(samples)

    order = np.argsort(density)

    plt.scatter(x[order], y[order], c=density[order])
    #plt.colorbar()

In [None]:
# Toy detector built from a box grid spec: 5 x-values in [-5, 5] and a single y=0, z=0 value.
detector = get_box_detector(x=np.linspace(-5,5,5), y=[0,], z=[0,])

In [None]:
# Toy experiment for this detector; used below for event generation and the analytic likelihood terms.
toy_experiment = toy_model(detector)

In [None]:
# Reference hypothesis: the test event is generated from it and it is marked with a 'T' in all scans.
truth = np.array([1., 1., 0, 0, 0, np.arccos(1), 2., 0.5]) #x, y, z, t, az, zen, energy, inelast

In [None]:
# generate one test event
# test_event[0] is passed to the hit ("p") terms and test_event[1] to the charge ("N") terms below.
test_event = toy_experiment.generate_event(truth)
test_event[1]

In [None]:
%%time
# Grid scan
# Evaluate the four analytic negative-log-likelihood terms of the toy model on a 100x100 (x, y) grid;
# all other parameters stay at their `truth` values.

g = dm.GridData(x=np.linspace(-7, 7, 100), y=np.linspace(-2, 2, 100))
#g = dm.GridData(x=np.linspace(1, 7, 100), y=np.linspace(0, 1, 100))
# Positions (in the 8-parameter vector) of the two scanned parameters.
IDX = (0,1)

# One pre-allocated grid variable per likelihood term.
g['dom_hit_term'] = np.empty(g.shape)
g['dom_charge_terms'] = np.empty(g.shape)
g['total_charge_hit_terms'] = np.empty(g.shape)
g['total_charge_terms'] = np.empty(g.shape)

# Working copy of the truth; only the two scanned entries are overwritten per grid point.
p = np.copy(truth)

for idx in np.ndindex(g.shape):
    p[IDX[0]] =  g['x'][idx]
    p[IDX[1]] =  g['y'][idx]
    segments = toy_experiment.model(*p)
    g['dom_hit_term'][idx] = toy_experiment.nllh_p_term_dom(segments, test_event[0])
    g['dom_charge_terms'][idx] = toy_experiment.nllh_N_term_dom(segments, test_event[1])
    g['total_charge_hit_terms'][idx] = toy_experiment.nllh_p_term_tot(segments, test_event[0])
    g['total_charge_terms'][idx] = toy_experiment.nllh_N_term_tot(segments, test_event[1])
    
# Shift every term so that its minimum over the grid is zero, then build the per-sensor ("dom")
# and all-sensor ("total_charge") sums and shift those to a zero minimum as well.
g['dom_hit_term'] -= g['dom_hit_term'].min()
g['dom_charge_terms'] -= g['dom_charge_terms'].min()
g['dom_llh'] = g['dom_hit_term'] + g['dom_charge_terms']
g['total_charge_hit_terms'] -= g['total_charge_hit_terms'].min()
g['total_charge_terms'] -= g['total_charge_terms'].min()
g['total_charge_llh'] = g['total_charge_hit_terms'] + g['total_charge_terms']
g['dom_llh'] -= g['dom_llh'].min()
g['total_charge_llh'] -= g['total_charge_llh'].min()

In [None]:
# 3x3 overview: rows = hit terms, charge terms, full LLH; columns = per-sensor, all-sensor, difference.
fig, ax = plt.subplots(3, 3, figsize=(12,12))
plt.subplots_adjust(wspace=0.2, hspace=0.25)

plot_diff(g['dom_hit_term'], g['total_charge_hit_terms'], axes=ax[0], title_a=r'Per-sensor $p_{s}(x_{i,s}|\theta)$ terms', title_b=r'All-sensor $p_{s_{i}}^{tot}(x_{i}|\theta)$ terms', vmax=20, limit_diff=True)
plot_diff(g['dom_charge_terms'], g['total_charge_terms'], axes=ax[1], title_a=r'Per-sensor $P_{s}(N_{s}|\theta)$ terms', title_b=r'All-sensor $P_{tot}(N_{tot}|\theta)$ term', vmax=20, limit_diff=True)
plot_diff(g['dom_llh'], g['total_charge_llh'], axes=ax[2], title_a='Per-sensor LLH', title_b='All-sensor LLH', limit_diff=False)

plot_truth(ax, truth, IDX)
# Keep the x label only on the bottom row (panels 6-8) and the y label only on the left column (0, 3, 6).
for i, a in enumerate(ax.flatten()):
    a.set_xlabel('x' if i in [6,7,8] else '')
    a.set_ylabel('y' if i in [0,3,6] else '')

#plt.savefig('../../plots/thesis/eml_decomp', bbox_inches='tight')

In [None]:
# Contour levels used in the plots: half the chi2 (2 dof) quantile whose tail probability equals
# the two-sided Gaussian tail probability at 1, 2, ..., 5 sigma.
stats.chi2(df=2).isf(stats.norm.sf(np.arange(1,6))*2)/2

# Train NNs

In [None]:
# Training sample: NE events generated in the same (x, y) box that the grid scan covers.
# `Set` is passed as `rand` (presumably the generator's random seed — TODO confirm).
NE, Set = 2_560_000, 0
events, meta = toy_experiment.generate_events(n=NE, gamma=0, gen_volume="box", e_lim=(1,7), inelast_lim=(0,1),
                                              x_lim=(-7,7), y_lim=(-2, 2), z_lim=(0,0), t_width=0, coszen_lim=(1,1),
                                              contained=False, min_hits=3, rand=Set)
#truths = NNs.make_truth_array(events)

In [None]:
# Histogram of the number of photon entries per event (count of photons.t along axis 1), unit-width bins 0..100.
plt.hist(ak.count(events.photons.t, axis=1).to_numpy(), np.linspace(0,100,101))
plt.yscale('log')

In [None]:
# Data-parallel training across all visible GPUs; nGPUs scales the batch sizes of the generators below.
strategy = tf.distribute.MirroredStrategy()
nGPUs = strategy.num_replicas_in_sync

## Hit Net - total charge

In [None]:
# Hit-level inputs (x) and matching hypothesis parameters (t); 10% is held out for validation.
x, t = NNs.get_hit_data(events)
x_train, x_test, t_train, t_test = train_test_split(x, t, test_size=0.1, random_state=42)
# Batch size grows with the number of replicas; time_spread=0 here (the per-DOM hit net below uses 30).
d_train = NNs.DataGenerator(x_train, t_train, batch_size=4096*nGPUs, time_spread=0) #10 30
d_valid = NNs.DataGenerator(x_test, t_test, batch_size=4096*nGPUs, time_spread=0) #10 30

#size = sys.getsizeof(d_train.indexes) + sys.getsizeof(d_train.data) + sys.getsizeof(d_train.params)
#size += sys.getsizeof(d_valid.indexes) + sys.getsizeof(d_valid.data) + sys.getsizeof(d_valid.params)
#size * 10e-6

In [None]:
# All-sensor hit network: 6 hit features + 8 hypothesis parameters, trained with binary cross-entropy.
optimizer = tf.keras.optimizers.Adam(1e-4)

# Build and compile inside the strategy scope so the model is replicated across devices.
with strategy.scope():
    hmodel = NNs.get_hmodel(x_shape=6, t_shape=8, trafo=NNs.hit_trafo, activation='swish', final_activation='swish',
                            nodes=250, n_layer=12) # big:(300,14), small:(200,10)
    hmodel.compile(loss='binary_crossentropy', optimizer=optimizer)
#hmodel.summary()

In [None]:
# Train the all-sensor hit network; `hist` keeps the loss curves plotted in the next cell.
hist = hmodel.fit(d_train, epochs=15, verbose=1, validation_data=d_valid)

In [None]:
# Training vs. validation loss per epoch (log scale).
plt.plot(hist.history['loss'])
plt.plot(hist.history['val_loss'])
plt.yscale('log')

In [None]:
# Network output on the first validation batch, histogrammed separately for label 0 and label 1.
inp, lab = d_valid.__getitem__(0)
pred = hmodel.predict(inp, batch_size=4096).flatten()
plt.hist(pred[lab==0], 100, histtype='step')
plt.hist(pred[lab==1], 100, histtype='step');
plt.yscale('log')

In [None]:
# Evaluate the hit network on the same (x, y) grid as the analytic scan.
# xxs: the test event's hits repeated once per grid point, flattened to rows of 6 features.
xxs = np.repeat(test_event[0][np.newaxis, :], np.prod(g.shape), axis=0)
xxs = xxs.reshape(-1, 6)

# tts: one hypothesis per grid point (truth with x/y replaced), each repeated once per hit so rows line up with xxs.
tts = np.repeat(truth[np.newaxis, :], np.prod(g.shape), axis=0)
tts[:, IDX[0]] = g.get_array('x', flat=True)
tts[:, IDX[1]] = g.get_array('y', flat=True)
tts = np.repeat(tts, len(test_event[0]), axis=0)

# Swap the output activation for a linear one so predict() returns the pre-activation output.
# NOTE: this mutates hmodel in place; rebuild the model before training it again.
hmodel.layers[-1].activation = tf.keras.activations.linear
hmodel.compile()

# Negate the network output and sum over the hits of the event (NaNs are replaced by 0 first).
llhs = -hmodel.predict((xxs, tts), batch_size=4096)
llhs = np.sum(np.nan_to_num(llhs.reshape(-1, len(test_event[0]))), axis=1)

g.hit_llh_total = llhs.reshape(g.shape)
g.hit_llh_total -= g.hit_llh_total.min()

fig, ax = plt.subplots(1, 3, figsize=(20,5))
plt.subplots_adjust(wspace=0.3)
#plt.suptitle(str(NE)+' events', size=17)

# Compare with the analytic all-sensor hit term; difference colour scale limited to +-10.
plot_diff(g.total_charge_hit_terms, g.hit_llh_total, title_a='Analytic', title_b='NN', vmax=10, axes=ax, limit_diff=True)
plot_truth(ax, truth, IDX)
for a in ax:
    a.set_xlabel(par_names[IDX[0]])
    a.set_ylabel(par_names[IDX[1]])

#plt.savefig('images/simple_tests/llh_scan_xt'+str(NE)+'.png', bbox_inches='tight')

In [None]:
#hmodel.save('networks/simple_toy/'+str(NE)+'/simple_toy_hitnet_total_'+str(NE)+'_set'+str(Set)+'.h5')
# Load a previously saved all-sensor hit network (used by the event-reweighting section).
# NOTE(review): this loads the 'set4' file, while Set is 0 in this notebook — confirm this is the intended network.
hmodel_t = tf.keras.models.load_model('networks/simple_toy/2560000/simple_toy_hitnet_total_2560000_set4.h5',
                                       custom_objects={'hit_trafo':NNs.hit_trafo})

## Charge Net - total charge

In [None]:
# Event-level charge inputs (x) and hypothesis parameters (t); same 90/10 split as for the hit net.
x, t = NNs.get_charge_data(events)
x_train, x_test, t_train, t_test = train_test_split(x, t, test_size=0.1, random_state=42)

d_train = NNs.DataGenerator(x_train, t_train, batch_size=4096*nGPUs)
d_valid = NNs.DataGenerator(x_test, t_test, batch_size=4096*nGPUs)

In [None]:
# All-sensor charge network: 2 event features + 8 hypothesis parameters, binary cross-entropy loss.
optimizer = tf.keras.optimizers.Adam(1e-3)

with strategy.scope():
    cmodel = NNs.get_cmodel(x_shape=2, t_shape=8, trafo=NNs.charge_trafo, activation='swish', final_activation='swish',
                            nodes=300, n_layer=15) # big:(170,13), small:(130,9)
    cmodel.compile(loss='binary_crossentropy', optimizer=optimizer)
#cmodel.summary()

In [None]:
# Train the all-sensor charge network (overwrites `hist` from the hit-net training).
hist = cmodel.fit(d_train, epochs=30, verbose=1, validation_data=d_valid)

In [None]:
# Training vs. validation loss per epoch (log scale).
plt.plot(hist.history['loss'])
plt.plot(hist.history['val_loss'])
plt.yscale('log')

In [None]:
# Network output on the first validation batch, histogrammed separately for label 0 and label 1.
inp, lab = d_valid.__getitem__(0)
pred = cmodel.predict(inp, batch_size=4096).flatten()
plt.hist(pred[lab==0], 100, histtype='step')
plt.hist(pred[lab==1], 100, histtype='step');
plt.yscale('log')

In [None]:
# Charge-net input of the test event, repeated for every grid point:
# [number of hits, number of distinct values in hit column 0].
# NOTE(review): the second feature counts distinct hit x-positions; the reco code further down uses
# sum(n_obs > 0) instead — presumably equivalent for this detector, confirm.
xxs = np.tile([len(test_event[0]), len(np.unique(test_event[0][:,0]))], np.prod(g.shape))
xxs = xxs.reshape(-1, 2)

# One hypothesis per grid point: truth with x/y replaced by the grid coordinates.
tts = np.repeat(truth[np.newaxis, :], np.prod(g.shape), axis=0)
tts[:, IDX[0]] = g.get_array('x', flat=True)
tts[:, IDX[1]] = g.get_array('y', flat=True)

# Swap the output activation for a linear one so predict() returns the pre-activation output.
# NOTE: this mutates cmodel in place; rebuild the model before training it again.
cmodel.layers[-1].activation = tf.keras.activations.linear
cmodel.compile()

llhs = np.nan_to_num(-cmodel.predict((xxs, tts), batch_size=4096))

g.charge_llh_total = llhs.reshape(g.shape)
g.charge_llh_total -= g.charge_llh_total.min()

fig, ax = plt.subplots(1, 3, figsize=(20,5))
plt.subplots_adjust(wspace=0.3)

# Compare with the analytic all-sensor charge term; difference colour scale limited to +-10.
plot_diff(g.total_charge_terms, g.charge_llh_total, title_a='Analytic', title_b='NN', vmax=10, axes=ax, limit_diff=True)
plot_truth(ax, truth, IDX)
#plt.savefig('images/chargeNNtest.png', bbox_inches='tight')

In [None]:
#cmodel.save('networks/simple_toy/'+str(NE)+'/simple_toy_chargenet_total_'+str(NE)+'_set'+str(Set)+'.h5')
# Load a previously saved all-sensor charge network.
# NOTE(review): `cmodel_t` is not referenced anywhere else in the visible notebook (the reweighting
# helper uses `cmodel`), and the file is 'set4' while Set is 0 — confirm which network is intended.
cmodel_t = tf.keras.models.load_model('networks/simple_toy/2560000/simple_toy_chargenet_total_2560000_set4.h5',
                                      custom_objects={'charge_trafo':NNs.charge_trafo})

## LLH - total charge

In [None]:
# All-sensor formulation, analytic vs. NN: rows = hit term, charge term, sum of both.
fig, ax = plt.subplots(3, 3, figsize=(20,17))
plt.subplots_adjust(wspace=0.25, hspace=0.25)

plot_diff(g.total_charge_hit_terms, 
          g.hit_llh_total, 
          title_a='Hit Analytic', title_b='Hit NN', vmax=10, axes=ax[0], limit_diff=True)

plot_diff(g.total_charge_terms, 
          g.charge_llh_total, 
          title_a='Charge Analytic', title_b='Charge NN', vmax=10, axes=ax[1], limit_diff=True)

# Sum hit and charge terms, then shift each sum to a zero minimum before comparing.
ana, NN = g.total_charge_hit_terms+g.total_charge_terms, g.hit_llh_total+g.charge_llh_total
plot_diff(ana-ana.min(), 
          NN-NN.min(), 
          title_a='Analytic', title_b='NN', vmax=10, axes=ax[2], limit_diff=True)
plot_truth(ax, truth, IDX)

#plt.savefig('images/NNtest_totalC.png', bbox_inches='tight')

## Hit Net - per DOM

In [None]:
# Same hit data as for the all-sensor hit net, but the generators use shuffle='inDOM' and time_spread=30.
# NOTE: x, t, d_train and d_valid from the previous section are overwritten here.
x, t = NNs.get_hit_data(events)
x_train, x_test, t_train, t_test = train_test_split(x, t, test_size=0.1, random_state=42)
d_train = NNs.DataGenerator(x_train, t_train, batch_size=4096*nGPUs, shuffle='inDOM', time_spread=30)
d_valid = NNs.DataGenerator(x_test, t_test, batch_size=4096*nGPUs, shuffle='inDOM', time_spread=30)

In [None]:
# Per-DOM hit network with the builder's default size.
# NOTE: this rebinds `hmodel`; from here on it no longer refers to the all-sensor hit network.
optimizer = tf.keras.optimizers.Adam(2e-5)

with strategy.scope():
    hmodel = NNs.get_hmodel(x_shape=6, t_shape=8, trafo=NNs.hit_trafo, activation='swish', final_activation='swish')
    hmodel.compile(loss='binary_crossentropy', optimizer=optimizer)

In [None]:
# Train the per-DOM hit network.
hist = hmodel.fit(d_train, epochs=7, verbose=1, validation_data=d_valid)

In [None]:
# Training vs. validation loss per epoch (log scale).
plt.plot(hist.history['loss'])
plt.plot(hist.history['val_loss'])
plt.yscale('log')

In [None]:
# Network output on the first validation batch, histogrammed separately for label 0 and label 1.
inp, lab = d_valid.__getitem__(0)
pred = hmodel.predict(inp, batch_size=4096).flatten()
plt.hist(pred[lab==0], 100, histtype='step')
plt.hist(pred[lab==1], 100, histtype='step');
plt.yscale('log')

In [None]:
# Same grid evaluation as for the all-sensor hit net, now with the per-DOM hit network.
# xxs: the test event's hits repeated once per grid point, flattened to rows of 6 features.
xxs = np.repeat(test_event[0][np.newaxis, :], np.prod(g.shape), axis=0)
xxs = xxs.reshape(-1, 6)

# tts: one hypothesis per grid point, each repeated once per hit so rows line up with xxs.
tts = np.repeat(truth[np.newaxis, :], np.prod(g.shape), axis=0)
tts[:, IDX[0]] = g.get_array('x', flat=True)
tts[:, IDX[1]] = g.get_array('y', flat=True)
tts = np.repeat(tts, len(test_event[0]), axis=0)

# Linear output activation -> predict() returns the pre-activation output (mutates hmodel in place).
hmodel.layers[-1].activation = tf.keras.activations.linear
hmodel.compile()

# Negate and sum over the hits of the event (NaNs are replaced by 0 first).
llhs = -hmodel.predict((xxs, tts), batch_size=4096)
llhs = np.sum(np.nan_to_num(llhs.reshape(-1, len(test_event[0]))), axis=1)

g.hit_llh_dom = llhs.reshape(g.shape)
g.hit_llh_dom -= g.hit_llh_dom.min()

fig, ax = plt.subplots(1, 3, figsize=(20,5))
plt.subplots_adjust(wspace=0.3)

# Compare with the analytic per-sensor hit term.
plot_diff(g.dom_hit_term, g.hit_llh_dom, title_a='Analytic', title_b='NN', vmax=10, axes=ax, limit_diff=True)
plot_truth(ax, truth, IDX)
#plt.savefig('images/hitNNtest.png', bbox_inches='tight')

In [None]:
#hmodel.save('networks/simple_toy/simple_toy_hitnet_dom.h5')

## Charge Net - per DOM

In [None]:
# Per-DOM charge inputs built from the events and the detector; same 90/10 split as before.
x, t = NNs.get_dom_data(events, detector)
x_train, x_test, t_train, t_test = train_test_split(x, t, test_size=0.1, random_state=42)

d_train = NNs.DataGenerator(x_train, t_train, batch_size=4096*nGPUs)
d_valid = NNs.DataGenerator(x_test, t_test, batch_size=4096*nGPUs)

In [None]:
# Per-DOM charge network: 4 input features + 8 hypothesis parameters.
# NOTE(review): built with NNs.get_hmodel (plus dom_trafo) rather than NNs.get_cmodel — confirm this is intended.
optimizer = tf.keras.optimizers.Adam(1e-4)

with strategy.scope():
    dmodel = NNs.get_hmodel(x_shape=4, t_shape=8, trafo=NNs.dom_trafo, activation='swish', final_activation='swish')
    dmodel.compile(loss='binary_crossentropy', optimizer=optimizer)

In [None]:
# Train the per-DOM charge network.
hist = dmodel.fit(d_train, epochs=60, verbose=1, validation_data=d_valid)

In [None]:
# Training vs. validation loss per epoch (log scale).
plt.plot(hist.history['loss'])
plt.plot(hist.history['val_loss'])
plt.yscale('log')

In [None]:
# Network output on the first validation batch, histogrammed separately for label 0 and label 1.
inp, lab = d_valid.__getitem__(0)
pred = dmodel.predict(inp, batch_size=4096).flatten()
plt.hist(pred[lab==0], 100, histtype='step')
plt.hist(pred[lab==1], 100, histtype='step');
plt.yscale('log')

In [None]:
# Per-DOM network input for the test event: one row per sensor = detector row + number of hits on that sensor.
xx = []
# Column 5 of the hit array is compared against the sensor index below.
ind = test_event[0][:, 5]
for i in range(len(detector)):
    d = np.append(detector[i], np.sum(ind==i))
    xx.append(list(d))
# Repeat the per-sensor block for every grid point and flatten to rows of 4 features.
xxs = np.repeat(np.array(xx)[np.newaxis, :], np.prod(g.shape), axis=0)
xxs = xxs.reshape(-1, 4)

# One hypothesis per grid point, each repeated once per sensor so rows line up with xxs.
tts = np.repeat(truth[np.newaxis, :], np.prod(g.shape), axis=0)
tts[:, IDX[0]] = g.get_array('x', flat=True)
tts[:, IDX[1]] = g.get_array('y', flat=True)
tts = np.repeat(tts, len(detector), axis=0)

# Linear output activation -> predict() returns the pre-activation output (mutates dmodel in place).
dmodel.layers[-1].activation = tf.keras.activations.linear
dmodel.compile()

# Negate and sum over the sensors (NaNs are replaced by 0 first).
llhs = -dmodel.predict((xxs, tts), batch_size=4096)
llhs = np.sum(np.nan_to_num(llhs.reshape(-1, len(detector))), axis=1)

g.charge_llh_dom = llhs.reshape(g.shape)
g.charge_llh_dom -= g.charge_llh_dom.min()

fig, ax = plt.subplots(1, 3, figsize=(20,5))
plt.subplots_adjust(wspace=0.3)

# Compare with the analytic per-sensor charge term.
plot_diff(g.dom_charge_terms, g.charge_llh_dom, title_a='Analytic', title_b='NN', vmax=10, axes=ax, limit_diff=True)
plot_truth(ax, truth, IDX)
# NOTE(review): same file name as in the per-DOM hit-net cell; enabling both would overwrite the figure.
#plt.savefig('images/hitNNtest.png', bbox_inches='tight')

In [None]:
#dmodel.save('networks/simple_toy/simple_toy_chargenet_dom.h5')

## LLH - per DOM

In [None]:
# Per-sensor formulation, analytic vs. NN: rows = hit term, charge term, sum of both.
fig, ax = plt.subplots(3, 3, figsize=(20,17))
plt.subplots_adjust(wspace=0.25, hspace=0.25)

plot_diff(g.dom_hit_term, 
          g.hit_llh_dom, 
          title_a='Hit Analytic', title_b='Hit NN', vmax=10, axes=ax[0], limit_diff=True)

plot_diff(g.dom_charge_terms, 
          g.charge_llh_dom, 
          title_a='Charge Analytic', title_b='Charge NN', vmax=10, axes=ax[1], limit_diff=True)

# Sum hit and charge terms, then shift each sum to a zero minimum before comparing.
ana, NN = g.dom_hit_term+g.dom_charge_terms, g.hit_llh_dom+g.charge_llh_dom
plot_diff(ana-ana.min(), 
          NN-NN.min(), 
          title_a='Analytic', title_b='NN', vmax=10, axes=ax[2], limit_diff=True)
plot_truth(ax, truth, IDX)

#plt.savefig('images/NNtest_perDOM.png', bbox_inches='tight')

## LLH - both

In [None]:
def plot_overlay(a, b, ax, **kwargs):
    """Overlay the 1-5 sigma contours of two delta-LLH maps on a single axes.

    ``a`` is drawn with labelled solid contours, ``b`` with dotted contours in the same colours.
    Two empty proxy lines ('Analytic' solid, 'NN' dotted) are added for the legend and the
    axis labels are set to 'x' and 'y'. ``kwargs`` are forwarded to both ``plot_contour`` calls.
    """
    n_sigma = np.arange(1,6)
    # Half the chi2 (2 dof) quantile matching the two-sided Gaussian tail at 1..5 sigma.
    levels = stats.chi2(df=2).isf(stats.norm.sf(n_sigma)*2)/2
    labels = [str(k) + r'$\sigma$' for k in range(1,6)]
    colors = plt.cm.viridis(np.linspace(0, 0.9, 6))

    contour_specs = (
        (a, {'labels': labels}),
        (b, {'linestyles': [':']*len(levels)}),
    )
    for grid, extra in contour_specs:
        grid.plot_contour(ax=ax, levels=levels, colors=colors, **extra, **kwargs)

    # Empty lines that only serve as legend handles.
    for legend_label, line_style in (('Analytic', {}), ('NN', {'linestyle': ':'})):
        ax.plot([], [], label=legend_label, color='Tab:blue', **line_style)

    ax.set_xlabel('x')
    ax.set_ylabel('y')

In [None]:
# Complete NN likelihoods (hit + charge) for both formulations, each shifted to a zero minimum.
g['llh_dom'] = g['hit_llh_dom'] + g['charge_llh_dom']
g['llh_total'] = g['hit_llh_total'] + g['charge_llh_total']
g['llh_dom'] -= g['llh_dom'].min()
g['llh_total'] -= g['llh_total'].min()

In [None]:
# Summary figure: analytic (solid) vs. NN (dotted) contours.
# Rows: per-sensor (top) and all-sensor (bottom); columns: hit term, charge term, complete LLH.
fig, ax = plt.subplots(2, 3, figsize=(20,12))
plt.subplots_adjust(wspace=0.25, hspace=0.25)

plot_truth(ax, truth, IDX)

ax[0][0].set_title('HitNet', size=20)
plot_overlay(g.dom_hit_term, g.hit_llh_dom, ax[0][0])
# Row label written as rotated text at x=-7.3, just left of the scanned x range [-7, 7].
ax[0][0].text(-7.3, 1, 'Per Sensor', rotation=90, size=20, ha='center', va='center')

ax[0][1].set_title('ChargeNet', size=20)
plot_overlay(g.dom_charge_terms, g.charge_llh_dom, ax[0][1])

ax[0][2].set_title('Complete LLH', size=20)
plot_overlay(g.dom_llh, g.llh_dom, ax[0][2])
ax[0][2].legend(loc='upper left')

plot_overlay(g.total_charge_hit_terms, g.hit_llh_total, ax[1][0])
ax[1][0].text(-7.3, 1, 'Total Detector', rotation=90, size=20, ha='center', va='center')

plot_overlay(g.total_charge_terms, g.charge_llh_total, ax[1][1])

plot_overlay(g.total_charge_llh, g.llh_total, ax[1][2])

#plt.savefig('images/NNtest_simple.png', bbox_inches='tight')

# Event reweighting

In [None]:
def get_hit_data(truth, vec, bins):
    """Network weights and analytic hit-time curves for one hypothesis.

    Reads the notebook globals ``hits``, ``hmodel_t``, ``detector`` and ``toy_experiment``.

    Parameters
    ----------
    truth : array of 8 hypothesis parameters.
    vec : weights multiplied onto the per-segment sums (one entry per segment returned by
        ``toy_experiment.model(*truth)``).
    bins : time values at which the analytic curves are evaluated; assumed equally spaced,
        since only the first spacing is used for the normalisation.

    Returns
    -------
    r_t : exp(network output) for every row of ``hits`` paired with ``truth``.
    hit_llh_p : per time value, the pandel pdf summed over sensors, weighted by ``vec`` and summed.
    hit_llh_ps : same, with the pdf multiplied by the survival term first.
    norm_p, norm_ps : sum of the respective curve times the bin width.
    """
    theta = np.repeat(truth.reshape((1,8)), len(hits), axis=0)

    r_t = np.exp(hmodel_t.predict([hits, theta], batch_size=5000)).flatten()
    #r_d = np.exp(hmodel_d.predict([hits, ts], batch_size=5000)).flatten()

    # Sensor-to-segment distances: rows = sensors, columns = segments.
    dists = distance.cdist(detector[:,:3], toy_experiment.model(*truth)[:,:3])
    survive = toy_experiment.survival(dists)

    hit_llh_p = []
    hit_llh_ps = []
    for t_edge in bins:
        # 4.333 scales distance into a time offset (presumably the photon travel time per unit length — TODO confirm).
        pdf_mat = toy_experiment.pandel.pdf(t_edge-dists*4.333, d=dists)
        hit_llh_p.append(np.sum(np.sum(pdf_mat, axis=0) * vec))
        hit_llh_ps.append(np.sum(np.sum(pdf_mat * survive, axis=0) * vec))

    norm_p = np.sum(np.array(hit_llh_p)) * np.diff(bins)[0]
    norm_ps = np.sum(np.array(hit_llh_ps)) * np.diff(bins)[0]

    return r_t, hit_llh_p, hit_llh_ps, norm_p, norm_ps #r_d,

def get_charge_data(truth, exp_bins, exp_bins_fine):
    """Network weights and analytic charge expectation for one hypothesis.

    Reads the notebook globals ``charges_t``, ``cmodel`` and ``toy_experiment``.
    NOTE(review): this uses the in-memory ``cmodel`` while the hit counterpart uses the loaded
    ``hmodel_t`` — confirm which charge network is intended. ``exp_bins_fine`` is not used.

    Returns
    -------
    r_t : exp(network output) for every row of ``charges_t`` paired with ``truth``.
    N_exp : ``toy_experiment.N_exp`` evaluated for the hypothesis.
    dom_c_llh : sum over the entries of ``N_exp`` of the Poisson pmf evaluated at ``exp_bins``.
    """
    theta_t = np.repeat(truth.reshape((1,8)), len(charges_t), axis=0)
    #ts_d = np.repeat(t, len(charges_d), axis=0)

    r_t = np.exp(cmodel.predict([charges_t, theta_t], batch_size=5000)).flatten()
    #r_d = np.exp(dmodel.predict([charges_d, ts_d], batch_size=5000)).flatten()

    N_exp = toy_experiment.N_exp(toy_experiment.model(*truth))

    # Accumulate one Poisson pmf per expectation value.
    dom_c_llh = np.zeros(len(exp_bins))
    for expectation in N_exp:
        dom_c_llh += stats.poisson.pmf(exp_bins, mu=expectation)

    return r_t, N_exp, dom_c_llh #r_d,

In [None]:
# Fresh, smaller sample (10k events) for the reweighting check.
# NOTE: this overwrites `events`/`meta` of the training sample; no `rand` is passed here.
events, meta = toy_experiment.generate_events(n=10_000, gamma=0, gen_volume="box", e_lim=(1,7), inelast_lim=(0,1),
                                              x_lim=(-7,7), y_lim=(-2, 2), z_lim=(0,0), t_width=0, coszen_lim=(1,1),
                                              contained=False, min_hits=3)
truths = NNs.make_truth_array(events)

hitnet

In [None]:
%%time
# Hit array of the reweighting sample, read as a global by the reweighting helper get_hit_data().
hits = NNs.get_hit_data(events)[0]
# Smear column 3 (plotted below as the hit time) with Gaussian noise of width 10.
# NOTE: unseeded and in place — re-running this cell smears again and changes the figure.
hits[:, 3] += np.random.normal(0, 10, len(hits))

In [None]:
# Time values at which the analytic hit-time curves are evaluated (also the histogram bin edges below).
bins = np.linspace(-40,100,100)

# Three test hypotheses theta_1..theta_3; the second argument is the per-segment weight vector.
# NOTE(review): the weight vectors have lengths 6, 7 and 3 — presumably matching the number of
# segments the model returns for each hypothesis; confirm.
r_t, hit_llh_p, hit_llh_ps, norm_p, norm_ps = get_hit_data(np.array([3., 1., 0, 0, 0, np.arccos(1), 5., 0.8]), 
                                                           [4,0.2,0.2,0.2,0.2,0.2],
                                                           bins)

r_t2, hit_llh_p2, hit_llh_ps2, norm_p2, norm_ps2 = get_hit_data(
                                                            np.array([-1., 1., 0, 0, 0, np.arccos(1), 1.5, 0.2]), 
                                                            [0.3,0.2,0.2,0.2,0.2,0.2,0.2],
                                                            bins)

# NOTE(review): theta_3 uses x=-5 here but x=-3 in the charge reweighting cell, and y=3 lies outside
# the generated y range (-2, 2) — confirm both are intended.
r_t3, hit_llh_p3, hit_llh_ps3, norm_p3, norm_ps3 = get_hit_data(
                                                            np.array([-5., 3., 0, 0, 0, np.arccos(1), 4., 0.9]), 
                                                            [3.6,0.2,0.2],
                                                            bins)

In [None]:
# Hit-time distribution of the sample (black), the same sample reweighted with the network output
# for each hypothesis (solid), and the normalised analytic curves (dashed).
plt.figure(figsize=(7,7*0.618))

#plt.hist(test_hits[:, 3], bins, label='Pulses example event', density=True, histtype='step')
plt.hist(hits[:, 3], bins, label=r'$p(x)$', density=True, histtype='step', color='black')

# The analytic curves are evaluated at the values in `bins` but drawn shifted by half a bin width.
plt.hist(hits[:, 3], bins, label=r'$p(x) \hat{r}(x,\theta_{1})$', weights=r_t, density=True, histtype='step')
plt.plot(bins+np.diff(bins)[0]/2, np.array(hit_llh_ps)/norm_ps, c='tab:blue', linestyle='--', label=r'$p(x|\theta_{1})$')

plt.hist(hits[:, 3], bins, label=r'$p(x) \hat{r}(x,\theta_{2})$', weights=r_t2, density=True, histtype='step')
plt.plot(bins+np.diff(bins)[0]/2, np.array(hit_llh_ps2)/norm_ps2, c='tab:orange', linestyle='--', label=r'$p(x|\theta_{2})$')

plt.hist(hits[:, 3], bins, label=r'$p(x) \hat{r}(x,\theta_{3})$', weights=r_t3, density=True, histtype='step')
plt.plot(bins+np.diff(bins)[0]/2, np.array(hit_llh_ps3)/norm_ps3, c='tab:green', linestyle='--', label=r'$p(x|\theta_{3})$')

plt.legend(prop={'size':14})
#plt.title('All-sensor')
plt.xlabel(r'$x$ = Hit time')
plt.ylabel('PDF')

#plt.savefig('images/simple_tests/NNtest_rweight_time', bbox_inches='tight')

chargenet

In [None]:
%%time
# Event-level charge inputs of the reweighting sample, read as a global by the reweighting helper get_charge_data().
charges_t = NNs.get_charge_data(events)[0]
#charges_d = NNs.get_dom_data(events, detector)[0]

In [None]:
# Charge values 0..12 (coarse) and 1..75 (fine, used for the plotted Poisson curves).
exp_bins, exp_bins_fine = np.linspace(0,12,13), np.linspace(1,75,75)

# Overwrites r_t, r_t2, r_t3 from the hit-time reweighting with the charge-network weights.
# NOTE(review): the third hypothesis has x=-3 here but x=-5 in the hit-time cell — confirm.
r_t, N_exp, dom_c_llh = get_charge_data(np.array([3., 1., 0, 0, 0, np.arccos(1), 5., 0.8]), exp_bins, exp_bins_fine)
r_t2, N_exp2, dom_c_llh2 = get_charge_data(np.array([-1., 1., 0, 0, 0, np.arccos(1), 1.5, 0.2]), exp_bins, exp_bins_fine)
r_t3, N_exp3, dom_c_llh3 = get_charge_data(np.array([-3., 3., 0, 0, 0, np.arccos(1), 4., 0.9]), exp_bins, exp_bins_fine)

In [None]:
# Distribution of charge column 0 (black), the sample reweighted with the network output
# (solid) and a Poisson pmf with mean sum(N_exp) for each hypothesis (dashed).
plt.figure(figsize=(7,7*0.618))

# NOTE: rebinds `bins` (previously the hit-time values) to unit-width charge bins 1..75.
bins = np.linspace(1,75,75)
#plt.hist(test_charges_t[:, 0], bins, label='Pulses example event', density=True, histtype='step')
plt.hist(charges_t[:, 0], bins, label='p(x)', density=True, histtype='step', color='black')

# The pmf is evaluated at the integer charges but drawn shifted by half a bin width.
plt.hist(charges_t[:, 0], bins, label=r'$p(x) \hat{r}(x,\theta_{1})$', weights=r_t, density=True, histtype='step')
plt.plot(exp_bins_fine+np.diff(exp_bins_fine)[0]/2, stats.poisson.pmf(exp_bins_fine, mu=np.sum(N_exp)), 
         c='tab:blue', linestyle='--', label=r'$p(x|\theta_{1})$')

plt.hist(charges_t[:, 0], bins, label=r'$p(x) \hat{r}(x,\theta_{2})$', weights=r_t2, density=True, histtype='step')
plt.plot(exp_bins_fine+np.diff(exp_bins_fine)[0]/2, stats.poisson.pmf(exp_bins_fine, mu=np.sum(N_exp2)), 
         c='tab:orange', linestyle='--', label=r'$p(x|\theta_{2})$')

plt.hist(charges_t[:, 0], bins, label=r'$p(x) \hat{r}(x,\theta_{3})$', weights=r_t3, density=True, histtype='step')
plt.plot(exp_bins_fine+np.diff(exp_bins_fine)[0]/2, stats.poisson.pmf(exp_bins_fine, mu=np.sum(N_exp3)), 
         c='tab:green', linestyle='--', label=r'$p(x|\theta_{3})$')

#plt.title('All-sensor')
plt.legend(prop={'size':14})
plt.xlabel(r'$x$ = Charge (total detector)')
plt.ylabel('PDF')

#plt.savefig('images/simple_tests/NNtest_rweight_charge', bbox_inches='tight')

### LLH error

In [None]:
# Random hypotheses and the analytic all-sensor hit term for the test event.
# This part used to be wrapped in a string literal, so `Ps` and `llhs_ana` were never defined
# and the code below (and the following plot cells) raised NameError on a fresh kernel.
# NOTE(review): 50000 analytic evaluations are slow — reduce the sample size for a quick run.
# The limits are kept under their own name so they cannot be confused with the reco `bounds`.
scan_bounds = np.array([[-5,5], [-2,2], [0,0], [-50,50], [0,2*np.pi], [0,0], [0,5], [0,5]])
uniforms = np.random.uniform(size=(50000, 8))
Ps = scan_bounds[:,0] + uniforms * (scan_bounds[:,1] - scan_bounds[:,0])

llhs_ana = []
for hypo in Ps:
    segments = toy_experiment.model(*hypo)
    llhs_ana.append(toy_experiment.nllh_p_term_tot(segments, test_event[0]))

llhs_ana = np.array(llhs_ana) - np.min(llhs_ana)

# Network evaluation for the same hypotheses: the event's hits repeated per hypothesis,
# each hypothesis repeated once per hit.
xxs = np.repeat(test_event[0][np.newaxis, :], len(Ps), axis=0)
xxs = xxs.reshape(-1, 6)
tts = np.repeat(Ps, len(test_event[0]), axis=0)

# NOTE(review): on a top-to-bottom run `hmodel` is the per-DOM hit network at this point, while
# the analytic reference above is the all-sensor term — confirm which network should be used.
llhs_nn = -hmodel.predict((xxs, tts), batch_size=4096)
llhs_nn = np.sum(np.nan_to_num(llhs_nn.reshape(-1, len(test_event[0]))), axis=1)
llhs_nn -= np.min(llhs_nn)

In [None]:
# Error of the NN delta-LLH versus the analytic delta-LLH for the random hypotheses.
plt.figure(figsize=(15,5))
plt.suptitle(str(NE)+' events', size=17)

# Left: all hypotheses, with mean / standard deviation / median absolute error written into the panel.
plt.subplot(121)
plot_point_dense(llhs_nn-llhs_ana, llhs_ana)
plt.axvline(0, color='black', linestyle='--')
plt.xlabel(r'$\Delta LLH_{NN} - \Delta LLH_{true}$')
plt.ylabel(r'$\Delta LLH_{true}$')
plt.text(np.min(llhs_nn-llhs_ana), 0, 'Mean(x) = %.2f'%(np.mean(llhs_nn-llhs_ana)), size=12)
plt.text(np.min(llhs_nn-llhs_ana), 0.1*np.max(llhs_ana), 'STD(x) = %.2f'%(np.std(llhs_nn-llhs_ana)), size=12)
plt.text(np.min(llhs_nn-llhs_ana), 0.2*np.max(llhs_ana), 'Median(|x|) = %.2f'%(np.median(np.abs(llhs_nn-llhs_ana))), size=12)

# Right: only hypotheses with analytic delta-LLH below 10.
plt.subplot(122)
plot_point_dense((llhs_nn-llhs_ana)[llhs_ana<10], llhs_ana[llhs_ana<10])
plt.axvline(0, color='black', linestyle='--')
plt.xlabel(r'$\Delta LLH_{NN} - \Delta LLH_{true}$')
plt.ylabel(r'$\Delta LLH_{true}$')

#plt.savefig('images/simple_tests/dLLH_error_'+str(NE)+'.png', bbox_inches='tight')

In [None]:
# NN-minus-analytic error as a function of hypothesis parameter 3 ('t' in par_names).
plot_point_dense(Ps[:,3], llhs_nn-llhs_ana)

# Reco

In [None]:
from spherical_opt import spherical_opt
from multiprocessing import Pool, Process

In [None]:
# Allowed [low, high] range per hypothesis parameter (8 rows); points outside return a penalty LLH in the fits below.
bounds = np.array([[-7,7], [-2,2], [-0.1,0.1], [-125,125], [0,2*np.pi], [-0.1,0.1], [1,7], [0,1]])

def trafo(points, d=1):
    """Convert the last two columns of ``points`` between two parametrisations, in place.

    d = 1 :  (a, b) -> (a + b, a / (a + b))
    d = -1:  (s, f) -> (s * f, s * (1 - f))     (inverse of d = 1)

    Parameters
    ----------
    points : 2-D float array; its last two columns are overwritten.
    d : direction of the transformation, 1 or -1.

    Returns
    -------
    The same array object that was passed in.

    Raises
    ------
    ValueError
        If ``d`` is neither 1 nor -1. (The former bare ``raise`` had no active exception
        and therefore produced an uninformative RuntimeError.)
    """
    if d == 1:
        total = points[:,-2] + points[:,-1]
        # Computed before any column is overwritten, so both inputs are still the originals.
        fraction = points[:,-2]/total
        points[:,-2] = total
        points[:,-1] = fraction
    elif d == -1:
        first = points[:,-2] * points[:,-1]
        second = points[:,-2] * (1-points[:,-1])
        points[:,-2] = first
        points[:,-1] = second
    else:
        raise ValueError("d must be 1 or -1, not %r" % (d,))

    return points

def init_points(hits, n_live_points, bound=None, seed=None):
    """Draw uniformly distributed starting points for the optimiser.

    Parameters
    ----------
    hits : 2-D array whose first three columns are hit positions; used only without a seed.
    n_live_points : number of points to draw.
    bound : (8, 2) array of [low, high] limits the points are clipped to. Defaults to the
        notebook-level ``bounds``. (Previously this argument was accepted but ignored.)
    seed : optional 8-parameter vector to sample around. ``None`` or a sequence whose first
        entry is ``None`` (the former default ``[None]``) selects the hit-based box instead.

    Returns
    -------
    np.ndarray of shape (n_live_points, 8).
    """
    if bound is None:
        bound = bounds
    bound = np.asarray(bound)

    if seed is None or seed[0] is None:
        # Box around the mean hit position; fixed ranges for the remaining five parameters.
        centre = np.average(hits[:, :3], axis=0)
        low_lims = np.concatenate([centre-np.array([3,2,0]), np.array([-30,0,0,1,0])])
        hig_lims = np.concatenate([centre+np.array([3,2,0]), np.array([30,2*np.pi,0,7,1])])
    else:
        half_width = np.array([1, 0.5, 0, 5, 0.5, 0, 3, 0.3])
        low_lims = seed - half_width
        hig_lims = seed + half_width

    uniforms = np.random.uniform(size=(n_live_points, 8))
    initial_points = low_lims + uniforms * (hig_lims - low_lims)
    return np.clip(initial_points, bound[:, 0], bound[:, 1])

In [None]:
# Sample of 10k events to reconstruct, plus the matching truth array.
# NOTE: overwrites `events`, `meta` and `truths` from the reweighting section; no `rand` is passed.
events, meta = toy_experiment.generate_events(n=10_000, gamma=0, gen_volume="box", e_lim=(1,7), inelast_lim=(0,1),
                                              x_lim=(-7,7), y_lim=(-2, 2), z_lim=(0,0), t_width=0, coszen_lim=(1,1),
                                              contained=False, min_hits=3) #, rand=1
truths = NNs.make_truth_array(events)
#np.save('recos/simple_toy/truths', truths)

### Analytic

In [None]:
def LLH_ana(X, hits, n_obs, form='total', fix=[None], bounds=bounds):
    """Analytic negative log-likelihood of one hypothesis for one event.

    Parameters
    ----------
    X : parameter vector (after inserting fixed values it must have one entry per row of ``bounds``).
    hits : hit array passed to the hit ("p") term.
    n_obs : observed charges passed to the charge ("N") term.
    form : 'total' (all-sensor terms) or 'dom' (per-sensor terms).
    fix : ``[indices, values]`` inserted into ``X`` with ``np.insert``; ``[None]`` disables it.
    bounds : (n, 2) array of [low, high] limits.

    Returns
    -------
    charge term + hit term, or the penalty value 1e9 if any parameter is out of bounds.

    Raises
    ------
    NameError
        If ``form`` is not one of the two supported formulations.
    """
    if fix[0] is not None:
        X = np.insert(X, fix[0], fix[1])

    # np.alltrue was removed in NumPy 2.0; np.all is the supported equivalent.
    in_bounds = np.logical_and(bounds[:,0] <= X, X <= bounds[:,1])
    if not np.all(in_bounds):
        return 1e9

    segments = toy_experiment.model(*X)
    if form == 'dom':
        h_term = toy_experiment.nllh_p_term_dom(segments, hits)
        c_term = toy_experiment.nllh_N_term_dom(segments, n_obs)
    elif form == 'total':
        h_term = toy_experiment.nllh_p_term_tot(segments, hits)
        c_term = toy_experiment.nllh_N_term_tot(segments, n_obs)
    else:
        raise NameError("Formulation must be one of ['total', 'dom'], not "+form)

    return c_term + h_term

def fit_event_ana(event):
    """Reconstruct one event by minimising the analytic likelihood with the CRS2 optimiser.

    The hit array is assembled from the photon fields x, y, z, t and sensor_id, the optimiser
    is started from 97 points drawn by ``init_points`` and parameters 4 and 5 are treated as
    a spherical pair.

    Returns
    -------
    list
        Best-fit parameters with entries 2, 4 and 5 removed.
    """
    columns = ['x', 'y', 'z', 't', 'sensor_id']
    hits = np.stack([event.photons[name].to_numpy() for name in columns], axis=1)
    n_obs = event.n_obs.to_numpy()
    #truth = event?

    def eval_LLH(params):
        # LLH_ana takes a single parameter vector, so the batch is evaluated point by point.
        return np.array([LLH_ana(theta, hits, n_obs) for theta in params]) #, fix=[[2, 5], [0, 0]]

    # seeding
    start_points = init_points(hits, 97) #, seed=truth

    # free fit
    result = spherical_opt.spherical_opt(
        func=eval_LLH,
        method="CRS2",
        initial_points=start_points,
        rand=np.random.default_rng(42),
        spherical_indices=[[4,5]],
        batch_size=12,
    )

    best_fit = np.delete(result['x'], [2, 4, 5])
    return list(best_fit)

In [None]:
%%time
# Fit all events with the analytic likelihood in 10 worker processes.
# NOTE: `p` rebinds the name used earlier for the grid-scan parameter vector.
with Pool(10) as p:
    outs = p.map(fit_event_ana, events)
recos_ana = np.array(outs)

In [None]:
# Reco minus truth; the truth columns 2, 4 and 5 are dropped to match what fit_event_ana returns.
diff_ana = recos_ana - np.delete(truths, [2, 4, 5], axis=1)
#np.save('recos/simple_toy/diff_ana', diff_ana)

### NNs

In [None]:
import math

from functools import partial
from freedom.llh_service.llh_service import LLHService
from freedom.llh_service.llh_client import LLHClient
from freedom.reco import crs_reco

from scipy.optimize import minimize

In [None]:
# Selects which saved network files the service configuration below points to.
# NOTE: rebinds NE/Set from the training section (Set was 0 there).
NE, Set = 2_560_000 , 1

In [None]:
# Configuration passed to the LLH service processes started below.
loc = 'networks/simple_toy/'
service_conf = {
        "poll_timeout": 1,
        "flush_period": 1,
        # Feature counts match the networks trained above: 8 hypothesis parameters,
        # 6 hit features and 2 event-level charge features.
        "n_hypo_params": 8,
        "n_hit_features": 6,
        "n_evt_features": 2, #len(detector)*4,
        "batch_size" : {
          "n_hypos": 200,
          "n_observations": 6000, 
        },
        "send_hwm": 10000,
        "recv_hwm": 10000,
        # NOTE(review): the charge net follows NE/Set while the hit net is a fixed 'noTsmear' file;
        # `loc` already ends with '/', so the hit-net path contains a double slash.
        #"hitnet_file": loc+'%s/simple_toy_hitnet_total_%s_set%s.h5'%(NE, NE, Set),
        "chargenet_file": loc+'%s/simple_toy_chargenet_total_%s_set%s.h5'%(NE, NE, Set),
        "hitnet_file": loc+'/simple_toy_hitnet_total_noTsmear.h5',
        #"chargenet_file": loc+'/simple_toy_chargenet_total_large.h5',
        #"hitnet_file": loc+'simple_toy_hitnet_dom.h5',
        #"domnet_file": loc+'simple_toy_chargenet_dom.h5',
        #"ndoms": len(detector),
        "toy": True
}

In [None]:
# Start one LLH service process per GPU, each with its own pair of IPC addresses.
# NOTE: n_gpus is hard-coded to 4, matching the four devices listed in CUDA_VISIBLE_DEVICES.
n_gpus = 4

base_req = "ipc:///tmp/recotestreq"
base_ctrl = "ipc:///tmp/recotestctrl"

# Addresses are the base strings with the service index appended.
req_addrs = []
ctrl_addrs = []
for i in range(n_gpus):
    req_addrs.append(f'{base_req}{i}')
    ctrl_addrs.append(f'{base_ctrl}{i}')
    
# The processes are kept in `procs`; nothing in the visible part of the notebook stops or joins them.
procs = []
for i in range(n_gpus):
    proc = Process(target=crs_reco.start_service, args=(service_conf, ctrl_addrs[i], req_addrs[i], i))
    proc.start()
    procs.append(proc)

In [None]:
def fit_events_nn(events, index, Truths, ctrl_addrs):
    """Reconstruct a batch of events with the likelihood served by an LLH service.

    Parameters
    ----------
    events : iterable
        Events to fit. Each event must expose ``photons`` (with the fields
        x, y, z, t, q, sensor_id) and ``n_obs``.
    index : int
        Position in ``ctrl_addrs`` of the service this worker connects to.
    Truths : sequence
        True parameters of the events. Currently unused; it is only needed
        for the optional truth seeding (``seed=Truths[j]``) of ``init_points``.
    ctrl_addrs : sequence of str
        Control addresses of the running LLH services.

    Returns
    -------
    list of np.ndarray
        One best-fit parameter vector per event with entries 2, 4 and 5 removed.

    Notes
    -----
    Relies on the notebook-level names ``bounds``, ``init_points`` and
    ``spherical_opt``.
    """
    outputs = []

    client = LLHClient(ctrl_addr=ctrl_addrs[index], conf_timeout=60000)

    def Eval_llh(params, hits, n_obs, fix=(None,)):
        """Evaluate the LLH for one hypothesis; out-of-bounds hypotheses get 1e9."""
        # Optionally re-insert parameters that are held fixed: fix = [indices, values],
        # e.g. fix=[[2, 5], [0, 0]].
        if fix[0] is not None:
            params = np.insert(params, fix[0], fix[1])

        # np.all replaces np.alltrue (removed in NumPy 2.0); `not` replaces the
        # bitwise `~`, which is only a logical negation for NumPy booleans.
        in_bounds = np.logical_and(bounds[:, 0] <= params, params <= bounds[:, 1])
        if not np.all(in_bounds):
            return 1e9

        # Event-level features: summed n_obs and number of entries with n_obs > 0.
        c_data = [np.sum(n_obs), np.sum(n_obs > 0)] #total
        #c_data = np.hstack([detector, n_obs[:, np.newaxis]]) #dom
        return client.eval_llh(hits, c_data, params)

    for j, event in enumerate(events):
        # Hit array with one row per photon and the six hit features as columns.
        hits = np.stack(
            [event.photons[var].to_numpy() for var in ['x', 'y', 'z', 't', 'q', 'sensor_id']],
            axis=1,
        )
        n_obs = event.n_obs.to_numpy()

        def eval_LLH(params):
            """Objective for the optimizer; accepts one hypothesis or a batch of them."""
            if params.ndim == 1:
                return Eval_llh(params, hits, n_obs) #, fix=[[2, 5], [0, 0]]
            return np.array([Eval_llh(p, hits, n_obs) for p in params]) #, fix=[[2, 5], [0, 0]]

        # seeding
        initial_points = init_points(hits, 97) #, seed=Truths[j]

        # free fit
        fit_res = spherical_opt.spherical_opt(
            func=eval_LLH,
            method="CRS2",
            initial_points=initial_points,
            rand=np.random.default_rng(42),
            spherical_indices=[[4,5]],
            batch_size=12,
        )
        # Keep only the parameters that are compared against the truth later on.
        outputs.append(np.delete(fit_res['x'], [2, 4, 5]))

    return outputs

In [None]:
# Split events and truths into `pool_size` consecutive chunks, one per worker.
events_to_process = len(events)
pool_size = 200
evts_per_proc = int(math.ceil(events_to_process/pool_size))

split_edges = [(k*evts_per_proc, (k+1)*evts_per_proc) for k in range(pool_size)]
evt_splits = [events[start:stop] for start, stop in split_edges]
true_splits = [truths[start:stop] for start, stop in split_edges]
# Sanity check: the chunk lengths must add up to the number of events.
print(sum(len(l) for l in evt_splits))

# Distribute the workers round-robin over the available services.
gpu_inds = np.arange(pool_size) % n_gpus

fit_events_partial = partial(fit_events_nn, ctrl_addrs=ctrl_addrs)

In [None]:
%%time
# Reconstruct with a worker pool; each worker opens its own LLH client.
with Pool(pool_size) as pool:
    outs = pool.starmap(fit_events_partial, zip(evt_splits, gpu_inds, true_splits))

# Flatten the per-worker result lists into one (n_events, 5) array.
flat_outs = [reco for chunk in outs for reco in chunk]
all_outs = np.array(flat_outs).reshape((events_to_process, 5))
recos_nn = np.array(all_outs)

In [None]:
# Residuals (reco - truth) of the NN reconstruction; truth columns 2, 4, 5
# (z, azi, zen in par_names) are removed to match the reco columns.
truths_kept = np.delete(truths, [2, 4, 5], axis=1)
diff_nn = recos_nn - truths_kept
#np.save('recos/simple_toy/%s/diff_nn_%s_set%s'%(NE, NE, Set), diff_nn)
#np.save('recos/simple_toy/diff_nn_noTsmear', diff_nn)

In [None]:
# Shut down every LLH service by sending "die" to its control address.
import zmq
zmq_context = zmq.Context.instance()
for proc, ctrl_addr in zip(procs, ctrl_addrs):
    with zmq_context.socket(zmq.REQ) as ctrl_sock:
        ctrl_sock.connect(ctrl_addr)
        ctrl_sock.send_string("die")
        # The join happens while the socket is still open.
        proc.join()

### plot

In [None]:
# Load previously saved residuals (reco - truth) from disk.
#truths = np.load('recos/simple_toy/truths.npy')
#truths_2 = np.load('recos/simple_toy/truths_2.npy')
diff_ana = np.load('recos/simple_toy/diff_ana.npy')
diff_ana2 = np.load('recos/simple_toy/diff_ana2.npy')
diff_ana_2 = np.load('recos/simple_toy/diff_ana_2.npy')
data_sizes = np.load('recos/simple_toy/data_sizes.npy')

#NEs = np.array([10000, 20000, 40000, 80000, 160000, 320000, 640000, 1280000, 2560000])
NEs = np.array([10000, 40000, 160000, 640000, 2560000])

# The arrays are bound to the notebook-level names the later cells use
# (diff_nn_<NE>, diff_nn_<NE>_big, diff_nn_<NE>_small, diff_nn_<NE>_set<s>).
# Assigning through globals() gives the same names as the former exec() calls
# without building and executing source strings.
for NE in NEs:
    for suffix in ('', '_big', '_small'):
        globals()['diff_nn_%s%s'%(NE, suffix)] = np.load(
            'recos/simple_toy/%s/diff_nn_%s%s.npy'%(NE, NE, suffix)
        )

#diff_nn_1280000_2 = np.load('recos/simple_toy/1280000/diff_nn_1280000_2.npy')
#diff_nn_1280000_3 = np.load('recos/simple_toy/1280000/diff_nn_1280000_3.npy')
#diff_nn_1280000_4 = np.load('recos/simple_toy/1280000/diff_nn_1280000_4.npy')
#diff_nn_1280000_5 = np.load('recos/simple_toy/1280000/diff_nn_1280000_5.npy')

for s in [1,2,3,4]:
    for NE in NEs:
        globals()['diff_nn_%s_set%s'%(NE, s)] = np.load(
            'recos/simple_toy/%s/diff_nn_%s_set%s.npy'%(NE, NE, s)
        )

In [None]:
# Reco - truth distributions, one panel per fitted parameter.
fig = plt.figure(figsize=(20, 10))

ranges = [(-2.5,2.5), (-4,4), (-15,20), (-5,5), (-1,1)]
hist_style = dict(histtype='step', linewidth=2)
for i, p in enumerate([0,1,3,6,7]):
    lo, hi = ranges[i]
    bins = np.linspace(lo, hi, 101)

    ax = fig.add_subplot(2, 3, i+1)
    ax.hist(diff_ana[:, i], bins, label='Analytic', **hist_style)
    # Other training-set sizes can be overlaid the same way, e.g.
    # ax.hist(diff_nn_10000[:, i], bins, label='NN_10k', **hist_style)
    ax.hist(diff_nn_2560000[:, i], bins, label='NN_2.56m', **hist_style)
    ax.hist(diff_nn[:, i], bins, label='NN', **hist_style)
    #ax.hist(recos_ana[:, i], 50, alpha=0.5, label='Analytic')
    #ax.hist(recos_nn[:, i], 50, alpha=0.5, label='NN')
    if i == 1:
        ax.set_title('Reco-Truth')
    if i == 2:
        ax.legend()
    ax.set_xlabel(par_names[p])

#plt.savefig('images/simple_tests/simple_reco_big_best.png', bbox_inches='tight')

In [None]:
def _nn_label(n_events):
    """Legend label for a training-set size, e.g. 10000 -> 'NN_10k', 2560000 -> 'NN_2.56m'."""
    if n_events >= 1_000_000:
        return 'NN_%.2fm' % (n_events / 1e6)
    return 'NN_%dk' % (n_events // 1000)


# Resolution of the NN reconstructions relative to the analytic one:
# left panel IQR of reco - truth, right panel median absolute reco - truth.
fig, (ax_iqr, ax_med) = plt.subplots(1, 2, figsize=(15, 6))
tick_labels = np.array(par_names)[[0,1,3,6,7]]

iqr_ana = stats.iqr(diff_ana, axis=0)
med_ana = np.median(np.abs(diff_ana), axis=0)

ax_iqr.scatter(range(5), iqr_ana/iqr_ana, label='Analytic')
ax_med.scatter(range(5), med_ana/med_ana, label='Analytic')

# Only the sizes listed in NEs are loaded, so iterate over those instead of
# hard-coding names such as diff_nn_20000 that may not exist.
for n_events in NEs:
    diff_nn_size = globals()['diff_nn_%s' % n_events]
    label = _nn_label(n_events)
    ax_iqr.scatter(range(5), stats.iqr(diff_nn_size, axis=0)/iqr_ana, label=label)
    ax_med.scatter(range(5), np.median(np.abs(diff_nn_size), axis=0)/med_ana, label=label)

ax_iqr.set_xticks(range(5))
ax_iqr.set_xticklabels(tick_labels)
ax_iqr.set_ylabel('IQR Reco-Truth / Analytic')
ax_iqr.set_ylim(0.93,1.3)

ax_med.set_xticks(range(5))
ax_med.set_xticklabels(tick_labels)
ax_med.legend()
ax_med.set_ylabel('Median absolute Reco-Truth / Analytic')
ax_med.set_ylim(0.93,1.3)

#plt.savefig('images/simple_tests/simple_reco_iqr_med2.png', bbox_inches='tight')

Kolmogorov–Smirnov (KS) test results

In [None]:
def _ks_per_param(reference, sample, n_params=5):
    """Two-sample KS statistic between matching columns of two residual arrays."""
    return [stats.ks_2samp(reference[:, col], sample[:, col])[0] for col in range(n_params)]


# KS statistics of the two additional analytic residual sets against diff_ana.
ks_ana2 = _ks_per_param(diff_ana, diff_ana2)
ks_ana_2 = _ks_per_param(diff_ana, diff_ana_2)

#ks = _ks_per_param(diff_ana, diff_nn)

# kss[s, j, i]: KS statistic of parameter i for training size NEs[j];
# s = 0 is the un-suffixed result, s = 1..4 are the "_set<s>" results.
kss = np.zeros((5, len(NEs), 5))
kss_big = np.zeros((len(NEs), 5))
kss_small = np.zeros((len(NEs), 5))

# The residual arrays live under notebook-level names; they are looked up
# through globals() instead of exec(), which would stop working inside a function.
for j, NE in enumerate(NEs):
    diff = globals()['diff_nn_'+str(NE)]
    kss[0][j] = _ks_per_param(diff_ana, diff)

    diff = globals()['diff_nn_'+str(NE)+'_big']
    kss_big[j] = _ks_per_param(diff_ana, diff)

    diff = globals()['diff_nn_'+str(NE)+'_small']
    kss_small[j] = _ks_per_param(diff_ana, diff)

for s in [1,2,3,4]:
    for j, NE in enumerate(NEs):
        diff = globals()['diff_nn_'+str(NE)+'_set'+str(s)]
        kss[s][j] = _ks_per_param(diff_ana, diff)

In [None]:
# Parameter-averaged KS value per training size for the big, default and small
# results (one row each); the cell shows their standard deviation across the
# three rows, averaged over the training sizes.
avg_ks_rows = np.vstack([
    np.mean(kss_big, axis=1),
    np.mean(kss[0], axis=1),
    np.mean(kss_small, axis=1),
])
np.mean(np.std(avg_ks_rows, axis=0))

In [None]:
# KS value per parameter versus training-set size (set 0 only).
fig = plt.figure(figsize=(7, 8*0.618))
ax1 = fig.add_subplot(111)
ax2 = ax1.twiny()

avg_ks = np.mean(kss[0], axis=1)

for col, p in enumerate([0,1,3,6,7]):
    ax1.scatter(NEs, kss[0][:, col], alpha=0.5, label=par_names[p])
ax1.scatter(NEs, avg_ks, label='Average', color='black')

#ax1.scatter(NEs, np.mean(kss_big, axis=1), label='Average big NN', color='black', marker='*')
#ax1.scatter(NEs, np.mean(kss_small, axis=1), label='Average small NN', color='black', marker='s')

ax1.axhline(np.mean(ks_ana2), color='black', linestyle='--', label=r'KS$_{rereco}$')
#ax1.axhline(np.mean(ks_ana_2), color='grey', linestyle=':', alpha=0.7, label='new events (ana)')

ax1.set_xscale('log')
ax1.set_xlabel('#Events in training set')
ax1.set_ylabel('KS value')
ax1.legend(prop={'size':15})

# Markerless scatter: only there to give the top axis its data-size range.
ax2.scatter(data_sizes[::2]*1e-9, avg_ks, marker='')
ax2.set_xscale('log')
ax2.set_xlabel('Data size (GB)')

#plt.savefig('images/simple_tests/simple_reco_ks.png', bbox_inches='tight') #_avg

In [None]:
# Parameter-averaged KS value of every set, plus mean and spread over the sets.
fig = plt.figure(figsize=(7, 8*0.618))
ax1 = fig.add_subplot(111)
ax2 = ax1.twiny()

set_avg_ks = np.mean(kss, axis=2)  # one row per set, one column per training size

ax1.errorbar(
    NEs,
    np.mean(set_avg_ks, axis=0),
    np.std(set_avg_ks, axis=0),
    label='Average all sets',
    color='black',
)
for i, a in enumerate(set_avg_ks):
    ax1.scatter(NEs, a, label='Average set '+str(i))

ax1.axhline(np.mean(ks_ana2), color='black', linestyle='--', label=r'KS$_{rereco}$')
#ax1.axhline(np.mean(ks_ana_2), color='grey', linestyle=':', alpha=0.7, label='new events (ana)')
#ax1.axhline(np.mean(ks), color='black', linestyle='--', label=r'KS$_{large}$')

ax1.set_xscale('log')
ax1.set_xlabel('#Events in training set')
ax1.set_ylabel('KS value')
# The y axis is shared with the twin axis, so setting it here covers both.
ax1.set_ylim(0, 0.27)
ax1.legend(prop={'size':15}, loc='upper right')

# Markerless scatter: only there to give the top axis its data-size range.
ax2.scatter(data_sizes[::2]*1e-9, np.mean(kss[0], axis=1), marker='')
ax2.set_xscale('log')
ax2.set_xlabel('Data size (GB)')

#plt.savefig('images/simple_tests/simple_reco_ks_avg_sets.png', bbox_inches='tight')

In [None]:
def func(x, a, b):
    """Fit model a / log(b*x), offset by the mean of the notebook-level ks_ana2."""
    decay = a / np.log(b * x)
    return decay + np.mean(ks_ana2)
def lin(x, a, b):
    """Straight line with slope a and intercept b, evaluated at x."""
    scaled = a * x
    return scaled + b

# Fit the set-averaged KS value versus training size with `func`, and the
# data size versus training size with a straight line.
x = NEs
set_avg_ks = np.mean(kss, axis=2)
y = np.mean(set_avg_ks, axis=0)
y_err = np.std(set_avg_ks, axis=0)

fit, cov = curve_fit(func, x, y, sigma=y_err, p0=[0.3, 5e-4])
fit2, cov = curve_fit(lin, x, data_sizes[::2], p0=[4000, 4e5])
# Training sizes at which the fitted curves are drawn.
ps = np.logspace(3.9,9,100)
fit, fit2

In [None]:
# Set-averaged KS value with the fitted a / log(b*x) curve.
fig = plt.figure(figsize=(7, 8*0.618))
ax1 = fig.add_subplot(111)
ax2 = ax1.twiny()

set_avg_ks = np.mean(kss, axis=2)

ax1.errorbar(
    NEs,
    np.mean(set_avg_ks, axis=0),
    np.std(set_avg_ks, axis=0),
    label='Average all sets',
    color='black',
)
ax1.axhline(np.mean(ks_ana2), color='black', linestyle='--', label=r'KS$_{rereco}$')
ax1.plot(ps, func(ps, fit[0], fit[1]), label=r'$\frac{a}{\log(b\cdot x)}+KS_{rereco}$')

ax1.set_xscale('log')
ax1.set_xlabel('#Events in training set')
ax1.set_ylabel('KS value')
ax1.legend(prop={'size':15})

# Markerless scatter at the data sizes given by the linear fit: only there to
# give the top axis its range.
ax2.scatter(lin(ps, fit2[0], fit2[1])*1e-9, np.zeros(len(ps)), marker='')
ax2.set_xscale('log')
ax2.set_xlabel('Data size (GB)')

#plt.savefig('images/simple_tests/simple_reco_ks_avg_fit.png', bbox_inches='tight')

Error per position

In [None]:
# For every training size: scatter the events in the plane of the two truth
# parameters selected by IDX, coloured by the absolute error of each fitted parameter.
IDX = (6,7)
for NE in NEs:
    # Look the residual array up by its notebook-level name (replaces exec()).
    diff = globals()['diff_nn_'+str(NE)]

    fig = plt.figure(figsize=(22, 10))
    for i, p in enumerate([0,1,3,6,7]):
        plt.subplot(2,3,i+1)
        # Draw in order of increasing error so the largest errors end up on top.
        idx = np.abs(diff[:,i]).argsort()

        plt.scatter(truths[idx,IDX[0]], truths[idx,IDX[1]], c=np.abs(diff[idx,i]), s=12)
        cbar=plt.colorbar()
        cbar.set_label('Absolute ' + par_names[p] + ' error')
        #plt.scatter(np.linspace(-5,5,5), np.zeros(5), color='red')

        if i==1: plt.title('Resolution depending on true position (%s)'%(NE))
        plt.xlabel(par_names[IDX[0]])
        plt.ylabel(par_names[IDX[1]])

    # Only used by the (commented) savefig call below.
    pname = par_names[IDX[0]]+par_names[IDX[1]]
    #plt.savefig('images/simple_tests/simple_resolutions_'+pname+'_'+str(NE)+'.png', bbox_inches='tight')