In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="0,1,2,3"
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats
import tensorflow as tf
from sklearn.model_selection import train_test_split
import dama as dm
import pickle
import os

from freedom.toy_model.toy_model_functions import toy_model
from freedom.toy_model.detectors import get_spherical_detector
from types import SimpleNamespace

%load_ext autoreload
%autoreload 2

In [None]:
# Font sizes applied to every figure drawn in this notebook.
plt.rcParams.update({
    'xtick.labelsize': 14,
    'ytick.labelsize': 14,
    'axes.labelsize': 16,
    'axes.titlesize': 16,
    'legend.fontsize': 14,
})

In [None]:
# Detector geometry shared by the toy model and by the per-sensor network inputs further down.
detector = get_spherical_detector(radius=10, subdivisions=4)

In [None]:
# Toy experiment on this detector; it supplies event generation, the expectation
# model (`model`) and the analytic nllh_* terms used in the scans and fits below.
toy_experiment = toy_model(detector)

In [None]:
# Reference 8-parameter hypothesis; entries 0 and 1 are the ones scanned on the grid below.
# Ordering presumably matches par_names (x, y, z, t, azi, zen, Ecscd, Etrck) - confirm.
truth = np.array([3., 1., 0, 0, 0, np.arccos(1), 10., 0])

In [None]:
# generate one test event from the reference hypothesis;
# element 0 is later passed to the hit terms, element 1 to the charge terms
test_event = toy_experiment.generate_event(truth)

In [None]:
%%time
# Grid scan
# Evaluate the four analytic nllh terms of the test event on a 100 x 100 grid in
# hypothesis parameters 0 and 1; all other parameters stay at their truth value.

g = dm.GridData(x=np.linspace(-5, 5, 100), y=np.linspace(0, 2, 100))

# Allocate one uninitialised array per term; every cell is filled in the loop below.
g['dom_hit_term'] = np.empty(g.shape)
g['dom_charge_terms'] = np.empty(g.shape)
g['total_charge_hit_terms'] = np.empty(g.shape)
g['total_charge_terms'] = np.empty(g.shape)

# Work on a copy so the scan never modifies `truth` itself.
p = np.copy(truth)

for idx in np.ndindex(g.shape):
    # Only parameters 0 and 1 change from one grid point to the next.
    p[0] =  g['x'][idx]
    p[1] =  g['y'][idx]
    segments = toy_experiment.model(*p)
    # "p" terms receive element 0 of the event, "N" terms element 1;
    # *_dom and *_tot are the per-sensor and whole-detector formulations.
    g['dom_hit_term'][idx] = toy_experiment.nllh_p_term_dom(segments, test_event[0])
    g['dom_charge_terms'][idx] = toy_experiment.nllh_N_term_dom(segments, test_event[1])
    g['total_charge_hit_terms'][idx] = toy_experiment.nllh_p_term_tot(segments, test_event[0])
    g['total_charge_terms'][idx] = toy_experiment.nllh_N_term_tot(segments, test_event[1])

In [None]:
# Shift every term so that its minimum over the grid is zero, build the combined
# LLH of each formulation from the shifted parts, then re-zero the sums
# (the minima of the two parts need not sit at the same grid point).
g['dom_hit_term'] -= g['dom_hit_term'].min()
g['dom_charge_terms'] -= g['dom_charge_terms'].min()
g['dom_llh'] = g['dom_hit_term'] + g['dom_charge_terms']
g['total_charge_hit_terms'] -= g['total_charge_hit_terms'].min()
g['total_charge_terms'] -= g['total_charge_terms'].min()
g['total_charge_llh'] = g['total_charge_hit_terms'] + g['total_charge_terms']
g['dom_llh'] -= g['dom_llh'].min()
g['total_charge_llh'] -= g['total_charge_llh'].min()

In [None]:
def plot_truth(axes, truth):
    """Mark the true (x, y) position with a black 'T' on one or many axes.

    Parameters
    ----------
    axes : object with a ``plot`` method, or numpy array of such objects
        Anything that is not a numpy array is treated as a single axis.
    truth : sequence
        Only ``truth[0]`` (x) and ``truth[1]`` (y) are read.
    """
    marker_style = {'marker': '$T$', 'markersize': 10, 'color': 'k'}
    # Wrap a lone axis so that the single and the array case share one loop.
    targets = axes if isinstance(axes, np.ndarray) else np.array([axes])
    for target in targets.ravel():
        target.plot([truth[0]], [truth[1]], **marker_style)

In [None]:
def plot_diff(a, b, axes, title_a='a', title_b='b', vmax=None, limit_diff=False, **kwargs):
    """Draw two LLH maps as sigma contours and their difference as a colour map.

    Parameters
    ----------
    a, b : grid arrays providing ``plot_contour`` and supporting ``a - b``,
        where the difference provides ``plot``.
    axes : sequence of three axes
        ``axes[0]`` receives the contours of `a`, ``axes[1]`` those of `b`,
        ``axes[2]`` the colour map of ``a - b``.
    title_a, title_b : str
        Panel titles; the third panel is titled ``title_a + ' - ' + title_b``.
    vmax : float, optional
        Half-range of the symmetric colour scale of the difference panel.
        Only used when `limit_diff` is true.
    limit_diff : bool
        If true and `vmax` is given, clamp the colour scale to ``[-vmax, vmax]``;
        otherwise the scale spans the largest absolute difference.
    **kwargs
        Forwarded to every ``plot_contour`` / ``plot`` call.
    """
    # Delta-LLH thresholds of the 1..5 sigma contours: the two-sided Gaussian tail
    # probability of n sigma is mapped through the chi2 (2 dof) inverse survival
    # function and halved.
    levels = stats.chi2(df=2).isf(stats.norm.sf(np.arange(1,6))*2)/2
    labels = [str(i) + r'$\sigma$' for i in range(1,6)]
    colors = plt.cm.viridis(np.linspace(0, 0.9, 6))

    for grid, ax, title in ((a, axes[0], title_a), (b, axes[1], title_b)):
        grid.plot_contour(ax=ax, levels=levels, labels=labels, colors=colors, label=r'$\Delta LLH$', **kwargs)
        ax.set_title(title)

    diff = a - b
    # Fall back to the data-driven range when no explicit limit is available;
    # previously limit_diff=True with the default vmax=None failed on `-vmax`.
    if not limit_diff or vmax is None:
        vmax = np.max(np.abs(diff))
    diff.plot(ax=axes[2], cmap='RdBu', cbar=True, vmin=-vmax, vmax=vmax, label=r'$\Delta LLH$', **kwargs)
    axes[2].set_title(title_a + ' - ' + title_b)

In [None]:
#stats.norm.isf(stats.chi2(df=2).sf(g['dom_llh']*2)/2)

In [None]:
# 3 x 3 comparison of the two analytic formulations.
# Rows: hit term, charge term, combined LLH. Columns: per DOM, total, difference.
fig, ax = plt.subplots(3, 3, figsize=(20,17))
plt.subplots_adjust(wspace=0.3, hspace=0.3)

plot_diff(g['dom_hit_term'], g['total_charge_hit_terms'], axes=ax[0], title_a='per DOM hit', title_b='total hit', vmax=20, limit_diff=True)
plot_diff(g['dom_charge_terms'], g['total_charge_terms'], axes=ax[1], title_a='per DOM charge', title_b='total charge', vmax=20, limit_diff=True)
# The combined LLH row uses the full data range for its difference panel.
plot_diff(g['dom_llh'], g['total_charge_llh'], axes=ax[2], title_a='per DOM llh', title_b='total llh', limit_diff=False)

# Mark the true position on all nine panels.
plot_truth(ax, truth)

In [None]:
# Display the Delta-LLH values of the 1-5 sigma contour levels (same expression as inside plot_diff).
stats.chi2(df=2).isf(stats.norm.sf(np.arange(1,6))*2)/2

# Train NNs

In [None]:
from freedom.toy_model import NNs
%aimport freedom.toy_model.NNs

In [None]:
# Training sample: 100k generated events with their true parameters.
# NOTE: `events` / `truths` are overwritten by a 1000-event sample in the Reco section.
events, truths = toy_experiment.generate_event_sphere(n=100_000, e_lim=(1,10), inelast_lim=(0,1),
                                                      t_width=0, contained=True, radius=9)

In [None]:
# Mirrored multi-device training; the replica count scales the batch sizes used below.
strategy = tf.distribute.MirroredStrategy()
nGPUs = strategy.num_replicas_in_sync

## Hit Net - total charge

In [None]:
# Hit-level inputs and hypotheses, split 90/10 into training and validation
# (fixed random_state so the split is reproducible).
x, t = NNs.get_hit_data(events, truths)
x_train, x_test, t_train, t_test = train_test_split(x, t, test_size=0.1, random_state=42)
# Batch size grows with the number of replicas of the distribution strategy.
d_train = NNs.DataGenerator(x_train, t_train, batch_size=4096*nGPUs, time_spread=50)
d_valid = NNs.DataGenerator(x_test, t_test, batch_size=4096*nGPUs, time_spread=50)

In [None]:
optimizer = tf.keras.optimizers.Adam(1e-4)

# Build and compile the hit network inside the distribution strategy's scope;
# it is trained as a binary classifier (binary cross-entropy loss).
with strategy.scope():
    hmodel = NNs.get_hmodel(x_shape=6, t_shape=8, trafo=NNs.hit_trafo_3D, activation='swish', final_activation='swish')
    hmodel.compile(loss='binary_crossentropy', optimizer=optimizer)

In [None]:
# Train for 25 epochs; `hist` keeps the per-epoch training and validation loss.
hist = hmodel.fit(d_train, epochs=25, verbose=1, validation_data=d_valid)

In [None]:
# Training and validation loss per epoch on a logarithmic y-axis.
for curve in ('loss', 'val_loss'):
    plt.plot(hist.history[curve])
plt.yscale('log')

In [None]:
# Network output on the first validation batch, histogrammed separately per label.
inp, lab = d_valid[0]
pred = hmodel.predict(inp, batch_size=4096).flatten()
for label_value in (0, 1):
    plt.hist(pred[lab == label_value], 100, histtype='step')
plt.yscale('log')

In [None]:
# Pair every hit of the test event with every grid hypothesis:
# xxs repeats the whole hit array once per grid point ...
xxs = np.repeat(test_event[0][np.newaxis, :], np.prod(g.shape), axis=0)
xxs = xxs.reshape(-1, 6)

# ... and tts holds one hypothesis per grid point (truth with parameters 0 and 1
# replaced by the grid coordinates), each repeated once per hit.
tts = np.repeat(truth[np.newaxis, :], np.prod(g.shape), axis=0)
tts[:, 0] = g.get_array('x', flat=True)
tts[:, 1] = g.get_array('y', flat=True)
tts = np.repeat(tts, len(test_event[0]), axis=0)

# Replace the activation of the last layer by a linear one and recompile, so that
# predict() returns the last layer's output without its activation.
# NOTE: this modifies hmodel in place.
hmodel.layers[-1].activation = tf.keras.activations.linear
hmodel.compile()

llhs = -hmodel.predict((xxs, tts), batch_size=4096)

# Sum the per-hit values of each grid point; NaNs are replaced by 0 before summing.
llhs = np.sum(np.nan_to_num(llhs.reshape(-1, len(test_event[0]))), axis=1)

g.hit_llh_total = llhs.reshape(g.shape)

# Shift so that the minimum over the grid is zero, like the analytic terms.
g.hit_llh_total -= g.hit_llh_total.min()

fig, ax = plt.subplots(1, 3, figsize=(20,5))
plt.subplots_adjust(wspace=0.3)

plot_diff(g.total_charge_hit_terms, g.hit_llh_total, title_a='Analytic', title_b='NN', vmax=10, axes=ax, limit_diff=True)
plot_truth(ax, truth)
#plt.savefig('images/hitNNtest.png', bbox_inches='tight')

In [None]:
#hmodel.save('networks/spherical_toy_hitnet_total.h5')

## Charge Net - total charge

In [None]:
# Event-level charge inputs and hypotheses, split 90/10 into training and validation.
x, t = NNs.get_charge_data(events, truths)
x_train, x_test, t_train, t_test = train_test_split(x, t, test_size=0.1, random_state=42)

d_train = NNs.DataGenerator(x_train, t_train, batch_size=4096*nGPUs)
d_valid = NNs.DataGenerator(x_test, t_test, batch_size=4096*nGPUs)

In [None]:
optimizer = tf.keras.optimizers.Adam(1e-3)

# Charge network with two event-level input features and the 8-parameter hypothesis,
# built and compiled inside the distribution strategy's scope.
with strategy.scope():
    cmodel = NNs.get_cmodel(x_shape=2, t_shape=8, trafo=NNs.charge_trafo_3D, activation='swish', final_activation='swish') #
    cmodel.compile(loss='binary_crossentropy', optimizer=optimizer)

In [None]:
# Train for 150 epochs; `hist` keeps the per-epoch training and validation loss.
hist = cmodel.fit(d_train, epochs=150, verbose=1, validation_data=d_valid)

In [None]:
# Training and validation loss per epoch on a logarithmic y-axis.
for curve in ('loss', 'val_loss'):
    plt.plot(hist.history[curve])
plt.yscale('log')

In [None]:
# Network output on the first validation batch, histogrammed separately per label.
inp, lab = d_valid[0]
pred = cmodel.predict(inp, batch_size=4096).flatten()
for label_value in (0, 1):
    plt.hist(pred[lab == label_value], 100, histtype='step')
plt.yscale('log')

In [None]:
# Event-level features of the test event: [number of hits, number of sensors hit].
# Sensors are counted through the sensor-index column (5) of the hit array.
# Counting unique x-coordinates (column 0) merges distinct sensors that share
# the same x position and therefore undercounts.
n_hit_sensors = len(np.unique(test_event[0][:, 5]))
xxs = np.tile([len(test_event[0]), n_hit_sensors], np.prod(g.shape))
xxs = xxs.reshape(-1, 2)

# One hypothesis per grid point: truth with parameters 0 and 1 replaced by the grid coordinates.
tts = np.repeat(truth[np.newaxis, :], np.prod(g.shape), axis=0)
tts[:, 0] = g.get_array('x', flat=True)
tts[:, 1] = g.get_array('y', flat=True)

# Replace the activation of the last layer by a linear one and recompile, so that
# predict() returns the last layer's output without its activation.
# NOTE: this modifies cmodel in place.
cmodel.layers[-1].activation = tf.keras.activations.linear
cmodel.compile()

# NaNs are replaced by 0.
llhs = np.nan_to_num(-cmodel.predict((xxs, tts), batch_size=4096))

g.charge_llh_total = llhs.reshape(g.shape)

# Shift so that the minimum over the grid is zero, like the analytic terms.
g.charge_llh_total -= g.charge_llh_total.min()

fig, ax = plt.subplots(1, 3, figsize=(20,5))
plt.subplots_adjust(wspace=0.3)

plot_diff(g.total_charge_terms, g.charge_llh_total, title_a='Analytic', title_b='NN', vmax=10, axes=ax, limit_diff=True)
plot_truth(ax, truth)
#plt.savefig('images/chargeNNtest.png', bbox_inches='tight')

In [None]:
#cmodel.save('networks/spherical_toy_chargenet_total.h5')

## LLH - total charge

In [None]:
# Analytic vs. NN for the total-charge formulation.
# Rows: hit term, charge term, combined LLH. Columns: analytic, NN, difference.
fig, ax = plt.subplots(3, 3, figsize=(20,17))
plt.subplots_adjust(wspace=0.25, hspace=0.25)

plot_diff(g.total_charge_hit_terms, 
          g.hit_llh_total, 
          title_a='Hit Analytic', title_b='Hit NN', vmax=10, axes=ax[0], limit_diff=True)
# NOTE: plot_truth receives the whole 3 x 3 array each time, so every call
# draws the marker on all nine panels.
plot_truth(ax, truth)

plot_diff(g.total_charge_terms, 
          g.charge_llh_total, 
          title_a='Charge Analytic', title_b='Charge NN', vmax=10, axes=ax[1], limit_diff=True)
plot_truth(ax, truth)

# Combined LLH: sum of hit and charge parts, re-zeroed at the minimum of the sum.
ana, NN = g.total_charge_hit_terms+g.total_charge_terms, g.hit_llh_total+g.charge_llh_total
plot_diff(ana-ana.min(), 
          NN-NN.min(), 
          title_a='Analytic', title_b='NN', vmax=10, axes=ax[2], limit_diff=True)
plot_truth(ax, truth)

#plt.savefig('images/NNtest_totalC.png', bbox_inches='tight')

## Hit Net - per DOM

In [None]:
# Same hit-level data and split as for the total-charge hit network; the only
# difference is that the generators are created with shuffle='inDOM'.
x, t = NNs.get_hit_data(events, truths)
x_train, x_test, t_train, t_test = train_test_split(x, t, test_size=0.1, random_state=42)
d_train = NNs.DataGenerator(x_train, t_train, batch_size=4096*nGPUs, shuffle='inDOM', time_spread=50)
d_valid = NNs.DataGenerator(x_test, t_test, batch_size=4096*nGPUs, shuffle='inDOM', time_spread=50)

In [None]:
optimizer = tf.keras.optimizers.Adam(2e-5)

# NOTE: this rebinds `hmodel`; the total-charge hit network trained above is no
# longer reachable under that name afterwards.
with strategy.scope():
    hmodel = NNs.get_hmodel(x_shape=6, t_shape=8, trafo=NNs.hit_trafo_3D, activation='swish', final_activation='swish')
    hmodel.compile(loss='binary_crossentropy', optimizer=optimizer)

In [None]:
# Train for 8 epochs; `hist` keeps the per-epoch training and validation loss.
hist = hmodel.fit(d_train, epochs=8, verbose=1, validation_data=d_valid)

In [None]:
# Training and validation loss per epoch on a logarithmic y-axis.
for curve in ('loss', 'val_loss'):
    plt.plot(hist.history[curve])
plt.yscale('log')

In [None]:
# Network output on the first validation batch, histogrammed separately per label.
inp, lab = d_valid[0]
pred = hmodel.predict(inp, batch_size=4096).flatten()
for label_value in (0, 1):
    plt.hist(pred[lab == label_value], 100, histtype='step')
plt.yscale('log')

In [None]:
# Pair every hit of the test event with every grid hypothesis:
# xxs repeats the whole hit array once per grid point ...
xxs = np.repeat(test_event[0][np.newaxis, :], np.prod(g.shape), axis=0)
xxs = xxs.reshape(-1, 6)

# ... and tts holds one hypothesis per grid point (truth with parameters 0 and 1
# replaced by the grid coordinates), each repeated once per hit.
tts = np.repeat(truth[np.newaxis, :], np.prod(g.shape), axis=0)
tts[:, 0] = g.get_array('x', flat=True)
tts[:, 1] = g.get_array('y', flat=True)
tts = np.repeat(tts, len(test_event[0]), axis=0)

# Replace the activation of the last layer by a linear one and recompile, so that
# predict() returns the last layer's output without its activation.
# NOTE: this modifies hmodel in place.
hmodel.layers[-1].activation = tf.keras.activations.linear
hmodel.compile()

llhs = -hmodel.predict((xxs, tts), batch_size=4096)

# Sum the per-hit values of each grid point; NaNs are replaced by 0 before summing.
llhs = np.sum(np.nan_to_num(llhs.reshape(-1, len(test_event[0]))), axis=1)

g.hit_llh_dom = llhs.reshape(g.shape)

# Shift so that the minimum over the grid is zero, like the analytic terms.
g.hit_llh_dom -= g.hit_llh_dom.min()

fig, ax = plt.subplots(1, 3, figsize=(20,5))
plt.subplots_adjust(wspace=0.3)

plot_diff(g.dom_hit_term, g.hit_llh_dom, title_a='Analytic', title_b='NN', vmax=10, axes=ax, limit_diff=True)
plot_truth(ax, truth)
#plt.savefig('images/hitNNtest.png', bbox_inches='tight')

In [None]:
#hmodel.save('networks/spherical_toy_hitnet_dom.h5')

## Charge Net - per DOM

In [None]:
# Per-sensor inputs (built with the detector geometry) and hypotheses,
# split 90/10 into training and validation.
x, t = NNs.get_dom_data(events, truths, detector)
x_train, x_test, t_train, t_test = train_test_split(x, t, test_size=0.1, random_state=42)

d_train = NNs.DataGenerator(x_train, t_train, batch_size=4096*nGPUs)
d_valid = NNs.DataGenerator(x_test, t_test, batch_size=4096*nGPUs)

In [None]:
optimizer = tf.keras.optimizers.Adam(5e-4)

# The per-sensor charge network is created with the hit-model factory
# (get_hmodel), using 4 input features per sensor and the dom_trafo_3D transformation.
with strategy.scope():
    dmodel = NNs.get_hmodel(x_shape=4, t_shape=8, trafo=NNs.dom_trafo_3D, activation='swish', final_activation='swish')
    dmodel.compile(loss='binary_crossentropy', optimizer=optimizer)

In [None]:
# Train for 10 epochs; `hist` keeps the per-epoch training and validation loss.
hist = dmodel.fit(d_train, epochs=10, verbose=1, validation_data=d_valid)

In [None]:
# Training and validation loss per epoch on a logarithmic y-axis.
for curve in ('loss', 'val_loss'):
    plt.plot(hist.history[curve])
plt.yscale('log')

In [None]:
# Network output on the first validation batch, histogrammed separately per label.
inp, lab = d_valid[0]
pred = dmodel.predict(inp, batch_size=4096).flatten()
for label_value in (0, 1):
    plt.hist(pred[lab == label_value], 100, histtype='step')
plt.yscale('log')

In [None]:
# One input row per sensor: the sensor's entry in `detector` followed by the
# number of test-event hits whose column-5 value equals that sensor's index.
xx = []
ind = test_event[0][:, 5]
for i in range(len(detector)):
    d = np.append(detector[i], np.sum(ind==i))
    xx.append(list(d))
# Repeat the full sensor table once per grid point.
xxs = np.repeat(np.array(xx)[np.newaxis, :], np.prod(g.shape), axis=0)
xxs = xxs.reshape(-1, 4)

# One hypothesis per grid point (truth with parameters 0 and 1 replaced by the
# grid coordinates), each repeated once per sensor.
tts = np.repeat(truth[np.newaxis, :], np.prod(g.shape), axis=0)
tts[:, 0] = g.get_array('x', flat=True)
tts[:, 1] = g.get_array('y', flat=True)
tts = np.repeat(tts, len(detector), axis=0)

# Replace the activation of the last layer by a linear one and recompile, so that
# predict() returns the last layer's output without its activation.
# NOTE: this modifies dmodel in place.
dmodel.layers[-1].activation = tf.keras.activations.linear
dmodel.compile()

llhs = -dmodel.predict((xxs, tts), batch_size=4096)

# Sum the per-sensor values of each grid point; NaNs are replaced by 0 before summing.
llhs = np.sum(np.nan_to_num(llhs.reshape(-1, len(detector))), axis=1)

g.charge_llh_dom = llhs.reshape(g.shape)

# Shift so that the minimum over the grid is zero, like the analytic terms.
g.charge_llh_dom -= g.charge_llh_dom.min()

fig, ax = plt.subplots(1, 3, figsize=(20,5))
plt.subplots_adjust(wspace=0.3)

plot_diff(g.dom_charge_terms, g.charge_llh_dom, title_a='Analytic', title_b='NN', vmax=10, axes=ax, limit_diff=True)
plot_truth(ax, truth)
#plt.savefig('images/hitNNtest.png', bbox_inches='tight')

In [None]:
#dmodel.save('networks/spherical_toy_chargenet_dom.h5')

## LLH - per DOM

In [None]:
# Analytic vs. NN for the per-DOM formulation.
# Rows: hit term, charge term, combined LLH. Columns: analytic, NN, difference.
fig, ax = plt.subplots(3, 3, figsize=(20,17))
plt.subplots_adjust(wspace=0.25, hspace=0.25)

plot_diff(g.dom_hit_term, 
          g.hit_llh_dom, 
          title_a='Hit Analytic', title_b='Hit NN', vmax=10, axes=ax[0], limit_diff=True)
# NOTE: plot_truth receives the whole 3 x 3 array each time, so every call
# draws the marker on all nine panels.
plot_truth(ax, truth)

plot_diff(g.dom_charge_terms, 
          g.charge_llh_dom, 
          title_a='Charge Analytic', title_b='Charge NN', vmax=10, axes=ax[1], limit_diff=True)
plot_truth(ax, truth)

# Combined LLH: sum of hit and charge parts, re-zeroed at the minimum of the sum.
ana, NN = g.dom_hit_term+g.dom_charge_terms, g.hit_llh_dom+g.charge_llh_dom
plot_diff(ana-ana.min(), 
          NN-NN.min(), 
          title_a='Analytic', title_b='NN', vmax=10, axes=ax[2], limit_diff=True)
plot_truth(ax, truth)

#plt.savefig('images/NNtest_perDOM.png', bbox_inches='tight')

## LLH - both

In [None]:
def plot_overlay(a, b, ax, **kwargs):
    """Overlay the 1-5 sigma contours of two LLH maps on a single axis.

    `a` is drawn with labelled contour lines, `b` with dotted lines at the same
    levels and colours. Two empty proxy lines labelled 'Analytic' and 'NN' are
    added so that a later ``ax.legend()`` call can explain the two line styles.

    Parameters
    ----------
    a, b : grid arrays providing ``plot_contour``
    ax : axis to draw on; its x/y labels are set to 'x' and 'y'
    **kwargs
        Forwarded to both ``plot_contour`` calls.
    """
    # Two-sided Gaussian tail probability of n sigma, mapped through the chi2
    # (2 dof) inverse survival function and halved to give Delta-LLH thresholds.
    tail_prob = stats.norm.sf(np.arange(1,6))*2
    levels = stats.chi2(df=2).isf(tail_prob)/2
    labels = [str(n_sigma) + r'$\sigma$' for n_sigma in range(1,6)]
    colors = plt.cm.viridis(np.linspace(0, 0.9, 6))
    dotted = [':' for _ in levels]

    a.plot_contour(ax=ax, levels=levels, labels=labels, colors=colors, **kwargs)
    b.plot_contour(ax=ax, levels=levels, linestyles=dotted, colors=colors, **kwargs)

    # Legend proxies: nothing is drawn, only the label and line style are registered.
    for legend_label, extra_style in (('Analytic', {}), ('NN', {'linestyle': ':'})):
        ax.plot([], [], label=legend_label, color='Tab:blue', **extra_style)

    ax.set_xlabel('x')
    ax.set_ylabel('y')

In [None]:
# Combined NN LLH of each formulation (hit part + charge part), shifted so that
# its minimum over the grid is zero.
g['llh_dom'] = g['hit_llh_dom'] + g['charge_llh_dom']
g['llh_total'] = g['hit_llh_total'] + g['charge_llh_total']
g['llh_dom'] -= g['llh_dom'].min()
g['llh_total'] -= g['llh_total'].min()

In [None]:
# Summary figure. Top row: per-sensor formulation, bottom row: total-detector formulation.
# Columns: hit term, charge term, complete LLH. Each panel overlays analytic and NN contours.
fig, ax = plt.subplots(2, 3, figsize=(20,12))
plt.subplots_adjust(wspace=0.25, hspace=0.25)

plot_truth(ax, truth)

ax[0][0].set_title('HitNet', size=20)
plot_overlay(g.dom_hit_term, g.hit_llh_dom, ax[0][0])
# Row label, placed at x = -7.3 (left of the scanned x range of -5..5).
ax[0][0].text(-7.3, 1, 'Per Sensor', rotation=90, size=20, ha='center', va='center')

ax[0][1].set_title('ChargeNet', size=20)
plot_overlay(g.dom_charge_terms, g.charge_llh_dom, ax[0][1])

ax[0][2].set_title('Complete LLH', size=20)
# g.dom_llh is the analytic combined LLH, g.llh_dom the NN one.
plot_overlay(g.dom_llh, g.llh_dom, ax[0][2])
ax[0][2].legend(loc='upper left')

plot_overlay(g.total_charge_hit_terms, g.hit_llh_total, ax[1][0])
ax[1][0].text(-7.3, 1, 'Total Detector', rotation=90, size=20, ha='center', va='center')

plot_overlay(g.total_charge_terms, g.charge_llh_total, ax[1][1])

plot_overlay(g.total_charge_llh, g.llh_total, ax[1][2])

#plt.savefig('images/NNtest_spherical.png', bbox_inches='tight')

# Reco

In [None]:
from spherical_opt import spherical_opt
from multiprocessing import Pool, Process

In [None]:
# Per-parameter [lower, upper] limits of the fit; one row per hypothesis parameter.
bounds = np.array([[-9,9], [-9,9], [-9,9], [-200,200], [0,2*np.pi], [0,np.pi], [0,10], [0,10]])

def init_points(hits, n_live_points, bound=bounds, seed=[None]):
    """Draw uniformly distributed starting points for the optimizer.

    Parameters
    ----------
    hits : ndarray, 2-D
        Hit array of one event. Without a seed, the mean of its first four
        columns is the centre of the sampling box for parameters 0-3.
    n_live_points : int
        Number of points to draw.
    bound : ndarray, shape (8, 2)
        Lower/upper limit per parameter; the drawn points are clipped to it.
    seed : sequence of 8 floats, optional
        If given, points are drawn in a box around this parameter vector.
        The default ``[None]`` (never mutated) means "no seed".

    Returns
    -------
    ndarray, shape (n_live_points, 8)
    """
    if seed is None or seed[0] is None:
        avg = np.average(hits[:, :4], axis=0)
        # Parameter 3 is sampled only below its hit average (offsets -40 .. 0);
        # parameters 4-7 are sampled over fixed ranges.
        low_lims = np.concatenate([avg-np.array([3,3,3,40]), np.array([0,0,0,0])])
        hig_lims = np.concatenate([avg+np.array([3,3,3,0]), np.array([2*np.pi,np.pi,10,10])])
    else:
        seed = np.asarray(seed, dtype=float)
        half_width = np.array([1, 1, 1, 5, 0.5, 0.3, 3, 3])
        low_lims = seed - half_width
        hig_lims = seed + half_width

    uniforms = np.random.uniform(size=(n_live_points, 8))
    initial_points = low_lims + uniforms * (hig_lims - low_lims)
    # Clip against the limits passed by the caller; the previous code always used
    # the module-level `bounds` and silently ignored the `bound` argument.
    bound = np.asarray(bound)
    initial_points = np.clip(initial_points, bound[:, 0], bound[:, 1])
    return initial_points

In [None]:
# Evaluation sample for the reconstruction: 1000 events generated with N_min=3.
# NOTE: this overwrites the `events` / `truths` used for network training above.
events, truths = toy_experiment.generate_event_sphere(n=1000, e_lim=(1,10), inelast_lim=(0,1),
                                                      t_width=0, contained=True, radius=9, N_min=3)

### Analytic

In [None]:
def LLH_ana(X, event, form='total', fix=[None], bounds=bounds):
    """Analytic negative log-likelihood of one event for hypothesis `X`.

    Parameters
    ----------
    X : ndarray, 1-D
        Hypothesis parameters (without the fixed ones if `fix` is used).
    event : sequence
        ``event[0]`` is passed to the hit term, ``event[1]`` to the charge term.
    form : {'total', 'dom'}
        Whole-detector or per-sensor formulation.
    fix : sequence, optional
        ``(index, value)`` inserted into `X` via ``np.insert`` before evaluation;
        the default ``[None]`` (never mutated) means nothing is fixed.
    bounds : ndarray, shape (n_params, 2)
        Lower/upper limit per parameter.

    Returns
    -------
    float
        Sum of charge and hit term, or 1e9 if any parameter is out of bounds.

    Raises
    ------
    NameError
        If `form` is neither 'total' nor 'dom'.
    """
    # Identity test: `!= None` becomes an element-wise comparison when fix[0] is an array.
    if fix[0] is not None:
        X = np.insert(X, fix[0], fix[1])

    # np.alltrue was removed in NumPy 2.0; np.all is its replacement. `not` is
    # used instead of `~`, which is a bitwise operator and only acts as a
    # logical negation on numpy booleans.
    if not np.all(np.logical_and(bounds[:,0] <= X, X <= bounds[:,1]), axis=-1):
        # Large penalty keeps the optimizer inside the allowed region.
        return 1e9

    segments = toy_experiment.model(*X)
    if form == 'dom':
        h_term = toy_experiment.nllh_p_term_dom(segments, event[0])
        c_term = toy_experiment.nllh_N_term_dom(segments, event[1])
    elif form == 'total':
        h_term = toy_experiment.nllh_p_term_tot(segments, event[0])
        c_term = toy_experiment.nllh_N_term_tot(segments, event[1])
    else:
        raise NameError("Formulation must be one of ['total', 'dom'], not "+form)

    return c_term + h_term

def fit_event_ana(event):
    """Reconstruct one event by minimising the analytic LLH with CRS2.

    Parameters
    ----------
    event : tuple
        ``(event_data, truth)`` pair; the truth part is unpacked but not used.

    Returns
    -------
    list
        Best-fit parameter vector (``fit_res['x']`` of the optimizer).
    """
    event, _truth = event

    def eval_LLH(params):
        # Accepts a single parameter vector (1-D) or a stack of vectors.
        if params.ndim != 1:
            return np.array([LLH_ana(point, event) for point in params])
        return LLH_ana(params, event)

    # seeding: 97 start points derived from the hits (add seed=_truth to start around the truth)
    start_points = init_points(event[0], 97)

    # free fit; parameters 4 and 5 are treated as one spherical (azimuth, zenith) pair
    result = spherical_opt.spherical_opt(
        func=eval_LLH,
        method="CRS2",
        initial_points=start_points,
        rand=np.random.default_rng(42),
        spherical_indices=[[4,5]],
        batch_size=12,
    )

    return list(result['x'])

In [None]:
%%time
# Fit all events in parallel with 10 worker processes; every work item is an
# (event, truth) pair. NOTE: this rebinds `p`, which the grid scan used for the hypothesis copy.
with Pool(10) as p:
    outs = p.map(fit_event_ana, zip(events, truths))
recos_ana = np.array(outs)

In [None]:
# Residuals of the analytic fit: reconstructed minus true parameters.
diff_ana = recos_ana - truths

### NNs

In [None]:
import math

from functools import partial
from freedom.llh_service.llh_service import LLHService
from freedom.llh_service.llh_client import LLHClient
from freedom.reco import crs_reco

In [None]:
# Configuration handed to every LLH service process started below.
# The active entries select the per-sensor ("dom") networks; the commented-out
# pair selects the total-charge networks instead.
# NOTE(review): the .h5 files are only written by the hmodel.save / cmodel.save /
# dmodel.save cells above, which are currently commented out.
loc = 'networks/'
service_conf = {
        "poll_timeout": 1,
        "flush_period": 1,
        "n_hypo_params": 8,
        "n_hit_features": 6,
        # 4 features per sensor; the trailing "2" is presumably the value for the
        # total-charge network (it is built with x_shape=2) - confirm.
        "n_evt_features": len(detector)*4, #2
        "batch_size" : {
          "n_hypos": 200,
          "n_observations": 6000, 
        },
        "send_hwm": 10000,
        "recv_hwm": 10000,
        #"hitnet_file": loc+'spherical_toy_hitnet_total.h5',
        #"chargenet_file": loc+'spherical_toy_chargenet_total.h5',
        "hitnet_file": loc+'spherical_toy_hitnet_dom.h5',
        "domnet_file": loc+'spherical_toy_chargenet_dom.h5',
        "ndoms": len(detector),
        "toy": True
}

In [None]:
n_gpus = 4

base_req = "ipc:///tmp/recotestreq"
base_ctrl = "ipc:///tmp/recotestctrl"

# One request and one control address per service, told apart by an index suffix.
req_addrs = [f'{base_req}{gpu}' for gpu in range(n_gpus)]
ctrl_addrs = [f'{base_ctrl}{gpu}' for gpu in range(n_gpus)]

# Launch one service process per index; the index itself is the last argument
# of crs_reco.start_service.
procs = []
for gpu, (ctrl_addr, req_addr) in enumerate(zip(ctrl_addrs, req_addrs)):
    proc = Process(target=crs_reco.start_service, args=(service_conf, ctrl_addr, req_addr, gpu))
    proc.start()
    procs.append(proc)

In [None]:
def fit_events_nn(events, index, Truths, ctrl_addrs):
    """Reconstruct a chunk of events with the NN likelihood served by one LLH service.

    Parameters
    ----------
    events : sequence
        Events to fit; ``event[0]`` is sent to the service as hit data.
    index : int
        Position in `ctrl_addrs` of the service this worker connects to.
    Truths : sequence
        True parameters of the events; currently unused (only referenced in a
        commented-out seeding option).
    ctrl_addrs : sequence of str
        Control addresses of all running services.

    Returns
    -------
    list
        One best-fit parameter vector (``fit_res['x']``) per event.
    """
    outputs = []

    client = LLHClient(ctrl_addr=ctrl_addrs[index], conf_timeout=60000)
    def Eval_llh(params, event, fix=[None]):
        # Identity test: `!= None` becomes an element-wise comparison when fix[0] is an array.
        if fix[0] is not None:
            params = np.insert(params, fix[0], fix[1])

        # np.alltrue was removed in NumPy 2.0; np.all is its replacement. `not`
        # replaces the bitwise `~`, which only negates numpy booleans correctly.
        if not np.all(np.logical_and(bounds[:,0] <= params, params <= bounds[:,1]), axis=-1):
            # Large penalty keeps the optimizer inside the allowed region.
            return 1e9

        #c_data = [np.sum(event[1]), np.sum(event[1] > 0)] #total
        # NOTE(review): c_data depends only on the event, not on params; it could
        # be computed once per event if get_dom_data is deterministic - confirm.
        c_data, _ = NNs.get_dom_data(np.array([event]), np.ones((1,8)), detector) #dom
        return client.eval_llh(event[0], c_data, params)

    for j, event in enumerate(events):
        def eval_LLH(params):
            # Accepts a single parameter vector (1-D) or a stack of vectors.
            if params.ndim == 1:
                return Eval_llh(params, event)
            else:
                o = []
                for p in params:
                    o.append(Eval_llh(p, event))
                return np.array(o)

        # seeding
        initial_points = init_points(event[0], 97) #, seed=Truths[j]

        #free fit
        fit_res = spherical_opt.spherical_opt(
            func=eval_LLH,
            method="CRS2",
            initial_points=initial_points,
            rand=np.random.default_rng(42),
            spherical_indices=[[4,5]],
            batch_size=12,
        )
        outputs.append(fit_res['x'])

    return outputs

In [None]:
# Cut the sample into `pool_size` consecutive chunks, one per worker process.
events_to_process = len(events)
pool_size = 200
evts_per_proc = int(math.ceil(events_to_process/pool_size))
chunk_slices = [slice(k*evts_per_proc, (k+1)*evts_per_proc) for k in range(pool_size)]
evt_splits = [events[chunk] for chunk in chunk_slices]
true_splits = [truths[chunk] for chunk in chunk_slices]
# Sanity check: the chunk lengths must add up to the number of events.
print(sum(map(len, evt_splits)))

# Assign the workers to the services round-robin.
gpu_inds = np.arange(pool_size) % n_gpus

fit_events_partial = partial(fit_events_nn, ctrl_addrs=ctrl_addrs)

In [None]:
%%time
# reconstruct with a worker pool; one LLH client per worker
with Pool(pool_size) as p:
    outs = p.starmap(fit_events_partial, zip(evt_splits, gpu_inds, true_splits))

# Flatten the per-worker result lists into one list of fits, in event order.
all_outs = [fit for worker_fits in outs for fit in worker_fits]
all_outs = np.array(all_outs).reshape((events_to_process, 8))
#recos_nn_t = np.array(all_outs)
recos_nn_d = np.array(all_outs)

In [None]:
#diff_nn_t = recos_nn_t - truths
# Residuals (reconstructed minus true parameters) of the per-sensor ("dom") network fit.
diff_nn_d = recos_nn_d - truths

In [None]:
# kill all the services
import zmq
# For every service: send the "die" command to its control address over a fresh
# REQ socket, then wait for the service process to exit.
for proc, ctrl_addr in zip(procs, ctrl_addrs): 
    with zmq.Context.instance().socket(zmq.REQ) as ctrl_sock:
        ctrl_sock.connect(ctrl_addr)
        ctrl_sock.send_string("die")
        # join() runs while the socket is still open (inside the with-block).
        proc.join()
### Plot

In [None]:
# Display names of the eight hypothesis parameters; used as x-axis labels of the residual histograms.
par_names = ['x', 'y', 'z', 't', 'azi', 'zen', 'Ecscd', 'Etrck']

In [None]:
fig = plt.figure(figsize=(25, 17))

# Plot only the residual sets that exist in this session. The NN reconstruction
# cells store either the 'total' or the 'dom' result depending on which lines are
# commented in, so diff_nn_t or diff_nn_d may be undefined; referencing a missing
# one directly raised a NameError.
residual_sets = {'Analytic': diff_ana}
for set_label, var_name in (('NN (total)', 'diff_nn_t'), ('NN (dom)', 'diff_nn_d')):
    if var_name in globals():
        residual_sets[set_label] = globals()[var_name]

for i in range(8):
    # Symmetric histogram range per parameter.
    # NOTE(review): the half-width is upper + 0.5*lower as in the original - confirm this is intended.
    r = bounds[i][1] + 0.5*bounds[i][0]
    bins = np.linspace(-r, r, 50)

    plt.subplot(3,3,i+1)
    for set_label, residual in residual_sets.items():
        plt.hist(residual[:, i], bins, alpha=0.5, label=set_label)
    #plt.hist(recos_ana[:, i], 50, alpha=0.5, label='Analytic')
    #plt.hist(recos_nn_t[:, i], 50, alpha=0.5, label='NN (total)')
    #plt.hist(recos_nn_d[:, i], 50, alpha=0.5, label='NN (dom)')
    if i == 2: plt.legend()
    if i == 1: plt.title('Spherical Detector, Reco-True')
    plt.xlabel(par_names[i])

#plt.savefig('images/spherical_reco.png', bbox_inches='tight')