In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="0,1,2,3"
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats
import tensorflow as tf
from sklearn.model_selection import train_test_split
import dama as dm
import pickle
import os
import awkward as ak
import pyarrow.parquet as pq

from freedom.toy_model.toy_model_functions import toy_model
from freedom.toy_model.detectors import get_box_detector
from types import SimpleNamespace
from freedom.toy_model import NNs

%load_ext autoreload
%autoreload 2

In [None]:
# Global matplotlib defaults for every figure below: 7 x (7*0.618) inch canvas
# and enlarged legend / label / title / tick fonts.
params = {'figure.figsize': (7, 7*0.618),
          'legend.fontsize': 14,
          'axes.labelsize': 16,
          'axes.titlesize': 16,
          'xtick.labelsize': 16,
          'ytick.labelsize': 16}
plt.rcParams.update(params)

# Display names of the eight hypothesis parameters
# (presumably in the same order as the 8-element truth vectors used below - TODO confirm).
par_names = ['x', 'y', 'z', 't', r'$\phi^{azimuth}$', r'$\theta^{zenith}$',  r'$E^{deposited}$', 'I']

def plot_truth(axes, truth, idx=(0,1)):
    """Draw a black 'T' marker at the true parameter point.

    axes  : a single axes object, or a numpy array of axes (every element is marked)
    truth : indexable holding the true parameter values
    idx   : the two positions in ``truth`` used as the x and y coordinate
    """
    # A lone axes object is wrapped so that one code path handles both cases.
    axes_arr = axes if isinstance(axes, np.ndarray) else np.array([axes])
    marker_style = dict(marker='$T$', markersize=10, color='k')
    for panel in axes_arr.flat:
        panel.plot([truth[idx[0]]], [truth[idx[1]]], **marker_style)

def plot_diff(a, b, axes, title_a='a', title_b='b', vmax=None, limit_diff=False, **kwargs):
    """Draw two likelihood maps as sigma contours and their difference as a colour map.

    a, b       : grid arrays offering ``plot_contour`` / ``plot`` and supporting ``a - b``
    axes       : sequence of three axes -> [contours of a, contours of b, a - b]
    title_a/b  : panel titles; the third panel is titled '<title_a> - <title_b>'
    vmax       : symmetric colour range (-vmax, vmax) of the difference panel,
                 only used when ``limit_diff`` is True
    limit_diff : if False (or if no ``vmax`` is given) the colour range is taken
                 from the largest absolute difference
    **kwargs   : forwarded to every plotting call
    """
    # Delta-LLH values of the 1..5 sigma contours: two-sided Gaussian tail
    # probability -> chi2 quantile for 2 degrees of freedom -> divided by 2.
    levels = stats.chi2(df=2).isf(stats.norm.sf(np.arange(1,6))*2)/2
    labels = [str(i) + r'$\sigma$' for i in range(1,6)]
    colors = plt.cm.viridis(np.linspace(0, 0.9, 6))

    for ax, grid, title in ((axes[0], a, title_a), (axes[1], b, title_b)):
        grid.plot_contour(ax=ax, levels=levels, labels=labels, colors=colors,
                          label=r'$\Delta LLH$', **kwargs)
        ax.set_title(title)

    diff = a - b
    if not limit_diff or vmax is None:
        # Automatic symmetric range; also the fallback when limit_diff is requested
        # without a vmax (negating None used to raise a TypeError).
        vmax = np.max(np.abs(diff))
    diff.plot(ax=axes[2], cmap='RdBu', cbar=True, vmin=-vmax, vmax=vmax,
              label=r'$\Delta LLH$', **kwargs)
    axes[2].set_title(title_a + ' - ' + title_b)

In [None]:
# Placeholder sensor array: a (2, 3) block of zeros handed to toy_model and to
# NNs.get_dom_data for the training sample
# (presumably two sensors sitting at the origin - TODO confirm).
om = np.zeros((2,3))

In [None]:
# Toy experiment used to generate the training events; the name is rebound
# further down once an actual detector layout is built.
toy_experiment = toy_model(om)

# Train NNs

In [None]:
# Simulate the training sample: 10 million events generated in a sphere of radius 50,
# e_lim=(1, 50), inelast_lim=(0, 1); the min_hits cut is left commented out.
events, meta = toy_experiment.generate_events(n=10_000_000, gamma=0, gen_volume="sphere",
                                              e_lim=(1,50), inelast_lim=(0,1), radius=50., t_width=0,
                                              contained=False) #, min_hits=3
# Truth parameters of the generated events as an array (built by NNs.make_truth_array).
truths = NNs.make_truth_array(events)

In [None]:
# Distribution of the number of photon times recorded per event (unit-width bins, 0..100).
n_photons_per_event = ak.count(events.photons.t, axis=1).to_numpy()
plt.hist(n_photons_per_event, np.linspace(0,100,101))
plt.yscale('log')

In [None]:
# Data-parallel training over the visible GPUs; nGPUs (the replica count) scales
# the batch sizes of the data generators below.
strategy = tf.distribute.MirroredStrategy()
nGPUs = strategy.num_replicas_in_sync

## Hit Net - per DOM

In [None]:
# Hit-level training data: x are the per-hit inputs, t the matching hypothesis parameters
# (as returned by NNs.get_hit_data). 10 % is held out for validation with a fixed split seed.
x, t = NNs.get_hit_data(events)
x_train, x_test, t_train, t_test = train_test_split(x, t, test_size=0.1, random_state=42)
# The batch size grows with the number of replicas, i.e. 4096 samples per GPU.
d_train = NNs.DataGenerator(x_train, t_train, batch_size=4096*nGPUs, time_spread=50) #, shuffle='inDOM'
d_valid = NNs.DataGenerator(x_test, t_test, batch_size=4096*nGPUs, time_spread=50) #, shuffle='inDOM'

In [None]:
# Model AND optimizer are created inside the strategy scope: tf.distribute requires
# everything that owns variables to be constructed under the scope it is trained in.
with strategy.scope():
    optimizer = tf.keras.optimizers.Adam(2e-5)
    # 6 hit features, 8 hypothesis parameters.
    # NOTE(review): binary cross-entropy expects outputs in [0, 1] - confirm how
    # get_hmodel applies final_activation='swish' during training.
    hmodel_d = NNs.get_hmodel(x_shape=6, t_shape=8, trafo=NNs.hit_trafo_3D, activation='swish', final_activation='swish')
    hmodel_d.compile(loss='binary_crossentropy', optimizer=optimizer)

In [None]:
# Train the hit network for 60 epochs; `hist` keeps the loss curves plotted in the next cell.
hist = hmodel_d.fit(d_train, epochs=60, verbose=1, validation_data=d_valid)

In [None]:
# Training curve first, validation curve second, on a logarithmic loss axis.
for curve in ('loss', 'val_loss'):
    plt.plot(hist.history[curve])
plt.yscale('log')

In [None]:
# Network output on the first validation batch, histogrammed separately per class label.
inp, lab = d_valid[0]
pred = hmodel_d.predict(inp, batch_size=4096).flatten()
for label_value in (0, 1):
    plt.hist(pred[lab==label_value], 100, histtype='step')
plt.yscale('log')

In [None]:
# The commented lines record how the network was exported: switch the last layer
# to a linear activation, recompile, save.
#hmodel_d.layers[-1].activation = tf.keras.activations.linear
#hmodel_d.compile()

#hmodel_d.save('networks/string_toy_hitnet_dom.h5')
# NOTE: this rebinds hmodel_d to the network stored on disk, replacing whatever was
# trained above; the custom transformation layer has to be passed in for deserialisation.
hmodel_d = tf.keras.models.load_model('networks/string_toy_hitnet_dom.h5',
                                     custom_objects={'hit_trafo_3D':NNs.hit_trafo_3D})

## Charge Net - per DOM

In [None]:
# Sensor-level training data for the charge network. This rebinds x, t and the
# train/validation generators that the hit network used before.
x, t = NNs.get_dom_data(events, om)
x_train, x_test, t_train, t_test = train_test_split(x, t, test_size=0.1, random_state=42)

# 4096 samples per GPU, as for the hit network (no time_spread argument here).
d_train = NNs.DataGenerator(x_train, t_train, batch_size=4096*nGPUs)
d_valid = NNs.DataGenerator(x_test, t_test, batch_size=4096*nGPUs)

In [None]:
# Model AND optimizer are created inside the strategy scope: tf.distribute requires
# everything that owns variables to be constructed under the scope it is trained in.
with strategy.scope():
    optimizer = tf.keras.optimizers.Adam(1e-4)
    # 4 sensor features, 8 hypothesis parameters.
    # NOTE(review): binary cross-entropy expects outputs in [0, 1] - confirm how
    # get_hmodel applies final_activation='swish' during training.
    dmodel = NNs.get_hmodel(x_shape=4, t_shape=8, trafo=NNs.dom_trafo_3D, activation='swish', final_activation='swish')
    dmodel.compile(loss='binary_crossentropy', optimizer=optimizer)

In [None]:
# Train the charge network for 25 epochs; `hist` is rebound to this run's history.
hist = dmodel.fit(d_train, epochs=25, verbose=1, validation_data=d_valid)

In [None]:
# Training curve first, validation curve second, on a logarithmic loss axis.
for curve in ('loss', 'val_loss'):
    plt.plot(hist.history[curve])
plt.yscale('log')

In [None]:
# Network output on the first validation batch, histogrammed separately per class label.
inp, lab = d_valid[0]
pred = dmodel.predict(inp, batch_size=4096).flatten()
for label_value in (0, 1):
    plt.hist(pred[lab==label_value], 100, histtype='step')
plt.yscale('log')

In [None]:
# The commented lines record how the network was exported: switch the last layer
# to a linear activation, recompile, save.
#dmodel.layers[-1].activation = tf.keras.activations.linear
#dmodel.compile()

#dmodel.save('networks/string_toy_chargenet_dom.h5')
# NOTE: this rebinds dmodel to the network stored on disk, replacing whatever was
# trained above; the custom transformation layer has to be passed in for deserialisation.
dmodel = tf.keras.models.load_model('networks/string_toy_chargenet_dom.h5',
                                    custom_objects={'dom_trafo_3D':NNs.dom_trafo_3D})

## LLH - per DOM

In [None]:
def shift_om(xs, ts):
    """Re-express each row relative to the position stored in its first three columns.

    The first three columns of ``xs`` are subtracted from the first three columns
    of ``ts`` and then from themselves. Both arrays are modified IN PLACE and the
    same two objects are returned as ``(xs, ts)``.
    """
    origin = xs[:, :3]                       # view into xs, not a copy
    np.subtract(ts[:, :3], origin, out=ts[:, :3])
    np.subtract(origin, origin, out=origin)  # writes through to xs
    return xs, ts

def heart(n, r=5):
    """Return 2*n points (x, y) tracing a heart-shaped outline, scaled by ``r``.

    x runs over n equally spaced values in [-1, 1]; the first n rows hold the
    upper branch sqrt(|x|) + sqrt(1 - x^2), the last n rows the lower branch
    sqrt(|x|) - sqrt(1 - x^2).
    """
    xs = np.linspace(-1,1,n)
    cusp = np.sqrt(np.abs(xs))
    arc = np.sqrt(1-xs**2)
    upper = np.column_stack((xs, cusp + arc))
    lower = np.column_stack((xs, cusp - arc))
    return r*np.vstack((upper, lower))

def cube(d, n=5, return_dens=False):
    """Regular n x n x n lattice with spacing ``d``, centred on the origin.

    Returns an (n**3, 3) array of coordinates. With ``return_dens=True`` a tuple
    ``(points, density)`` is returned, where density = n**3 / (d*(n-1))**3,
    i.e. number of points divided by the volume spanned by the lattice.
    """
    edge = d*(n-1)
    ticks = np.linspace(0, edge, n) - 0.5*edge
    grid_a, grid_b, grid_c = np.meshgrid(ticks, ticks, ticks)
    out = np.stack([grid_a.ravel(), grid_b.ravel(), grid_c.ravel()], axis=1)
    if not return_dens:
        return out
    return out, n**3/edge**3

In [None]:
# Alternative (disabled) layout: strings at the given (x, y) positions with sensors along z.
#str_pos = np.array([[0,0], [5,5], [6,2.5], [7.5,0], [-2,2.5], [-2,5]]) #heart(20) #
#z_pos = np.linspace(-10,10,6) #np.linspace(-10,10,3) #

#detector = np.append(np.repeat(str_pos, len(z_pos), axis=0), np.tile(z_pos, len(str_pos)).reshape(-1,1), axis=1)
# Active layout: 5 x 5 x 5 lattice with spacing 10. toy_experiment is rebound to it.
detector = cube(10, 5)
toy_experiment = toy_model(detector)

# 3D scatter of the sensor positions as a visual check of the layout.
fig = plt.figure(figsize=(15,9))
ax = fig.add_subplot(projection='3d')
ax.scatter(detector[:,0], detector[:,1], detector[:,2])

In [None]:
# Hypothesis of the single test event: 8 values
# (presumably ordered as in par_names: x, y, z, t, azimuth, zenith, energy, inelasticity - TODO confirm).
truth = np.array([2., 0., 5., 0, 3, np.arccos(0), 2., 0.3])

# generate one test event
# test_event[0] is used below as the hit array and test_event[1] as the input of
# the charge term; the printed number is the sum over test_event[1].
test_event = toy_experiment.generate_event(truth)
print(np.sum(test_event[1]))

# Event display: sensors as black crosses, model segments sized by column 4 and
# coloured by column 3, hits coloured by the log of hit column 3.
segments = toy_experiment.model(*truth)
fig = plt.figure(figsize=(20,12))
ax = fig.add_subplot(projection='3d')
ax.set_box_aspect([1,1,1])
ax.scatter(detector[:,0],detector[:,1],detector[:,2],s=10, c='black', marker='x', alpha=0.5)
ax.scatter(segments[:,0],segments[:,1],segments[:,2],s=segments[:,4]/100, c=segments[:,3])
# NOTE(review): np.log yields NaN/-inf for non-positive values of column 3 - confirm they cannot occur.
ax.scatter(test_event[0][:, 0], test_event[0][:, 1], test_event[0][:, 2],
           s=30, c=np.log(test_event[0][:,3]), cmap='turbo')
#ax.view_init(45, 45)

#plt.savefig('images/string_pos/test_event_det4.png', bbox_inches='tight')

In [None]:
%%time
# Grid scan
# Analytic likelihood of the test event on a 100 x 100 grid over two parameters,
# all other parameters fixed to their true values.

g = dm.GridData(x=np.linspace(-7, 7, 100), y=np.linspace(-7, 7, 100))
#g = dm.GridData(x=np.linspace(1, 50, 100), y=np.linspace(0, 1, 100))
# Positions within the parameter vector that the grid's x and y axes scan.
IDX = (0,1)

g['dom_hit_term'] = np.empty(g.shape)
g['dom_charge_terms'] = np.empty(g.shape)

# Work on a copy so that `truth` itself stays untouched.
p = np.copy(truth)

for idx in np.ndindex(g.shape):
    p[IDX[0]] = g['x'][idx]
    p[IDX[1]] = g['y'][idx]
    segments = toy_experiment.model(*p)
    g['dom_hit_term'][idx] = toy_experiment.nllh_p_term_dom(segments, test_event[0])
    g['dom_charge_terms'][idx] = toy_experiment.nllh_N_term_dom(segments, test_event[1])
    
# Shift every term (and their sum) so that its minimum over the grid is zero.
g['dom_hit_term'] -= g['dom_hit_term'].min()
g['dom_charge_terms'] -= g['dom_charge_terms'].min()
g['dom_llh'] = g['dom_hit_term'] + g['dom_charge_terms']
g['dom_llh'] -= g['dom_llh'].min()

In [None]:
# Hit network on the same grid: every hit of the test event is paired with every grid hypothesis.
# Tile the hit array once per grid point -> one row per (grid point, hit) with 6 features.
xxs = np.repeat(test_event[0][np.newaxis, :], np.prod(g.shape), axis=0)
xxs = xxs.reshape(-1, 6)

# One hypothesis per grid point: the truth with the two scanned parameters overwritten ...
tts = np.repeat(truth[np.newaxis, :], np.prod(g.shape), axis=0)
tts[:, IDX[0]] = g.get_array('x', flat=True)
tts[:, IDX[1]] = g.get_array('y', flat=True)
# ... repeated once per hit so that the rows line up with xxs.
tts = np.repeat(tts, len(test_event[0]), axis=0)

#xxs, tts = shift_om(xxs, tts)

# Negated network output, summed over the hits of each grid point (NaNs count as 0).
llhs = -hmodel_d.predict((xxs, tts), batch_size=4096)
llhs = np.sum(np.nan_to_num(llhs.reshape(-1, len(test_event[0]))), axis=1)

# Store on the grid and shift so that the minimum is zero.
g.hit_llh_dom = llhs.reshape(g.shape)
g.hit_llh_dom -= g.hit_llh_dom.min()

In [None]:
# Charge network on the same grid.
# Per-sensor input row: the sensor position followed by the number of hits whose
# column 5 (read here as the sensor index) equals that sensor's row number.
xx = []
ind = test_event[0][:, 5]
for i in range(len(detector)):
    d = np.append(detector[i], np.sum(ind==i))
    xx.append(list(d))
# Tile the sensor table once per grid point -> one row per (grid point, sensor) with 4 features.
xxs = np.repeat(np.array(xx)[np.newaxis, :], np.prod(g.shape), axis=0)
xxs = xxs.reshape(-1, 4)

# One hypothesis per grid point, repeated once per sensor so that the rows line up with xxs.
tts = np.repeat(truth[np.newaxis, :], np.prod(g.shape), axis=0)
tts[:, IDX[0]] = g.get_array('x', flat=True)
tts[:, IDX[1]] = g.get_array('y', flat=True)
tts = np.repeat(tts, len(detector), axis=0)

#xxs, tts = shift_om(xxs, tts)

# Negated network output, summed over the sensors of each grid point (NaNs count as 0).
llhs = -dmodel.predict((xxs, tts), batch_size=4096)
llhs = np.sum(np.nan_to_num(llhs.reshape(-1, len(detector))), axis=1)

# Store on the grid and shift so that the minimum is zero.
g.charge_llh_dom = llhs.reshape(g.shape)
g.charge_llh_dom -= g.charge_llh_dom.min()

In [None]:
# 3 x 3 comparison figure. Rows: hit term, charge term, total.
# Columns: analytic, network, analytic - network.
fig, ax = plt.subplots(3, 3, figsize=(20,17))
plt.subplots_adjust(wspace=0.25, hspace=0.25)

plot_diff(g.dom_hit_term, 
          g.hit_llh_dom, 
          title_a='Hit Analytic', title_b='Hit NN', vmax=10, axes=ax[0], limit_diff=True)

plot_diff(g.dom_charge_terms, 
          g.charge_llh_dom, 
          title_a='Charge Analytic', title_b='Charge NN', vmax=10, axes=ax[1], limit_diff=True)

# The totals are re-shifted so that each map has its minimum at zero.
ana, NN = g.dom_hit_term+g.dom_charge_terms, g.hit_llh_dom+g.charge_llh_dom
plot_diff(ana-ana.min(), 
          NN-NN.min(), 
          title_a='Analytic', title_b='NN', vmax=10, axes=ax[2], limit_diff=True)

# plot_truth marks EVERY axis of the array it receives, so a single call covers all
# nine panels (calling it after each row stacked three markers on every panel).
plot_truth(ax, truth, IDX)

#plt.savefig('images/string_pos/test_llh_det4.png', bbox_inches='tight')

# Reco

In [None]:
from spherical_opt import spherical_opt
from multiprocessing import Pool, Process

In [None]:
def init_points(hits, n_live_points, bound=None, seed=None):
    """Draw uniformly distributed starting points for the optimizer.

    hits          : 2D array of hits; only the first three columns (positions) are read
    n_live_points : number of points to draw
    bound         : (n_params, 2) array of [lower, upper] limits used to clip the points.
                    Defaults to the notebook-level ``bounds``, looked up when the
                    function is CALLED (a definition-time default fails on a fresh
                    kernel, because ``bounds`` is only created in a later cell).
    seed          : optional 8-element parameter vector to sample around. ``None``
                    (or the legacy ``[None]``) means "no seed": the position is
                    sampled within +-3 of the mean hit position and the remaining
                    parameters over fixed wide ranges.

    Returns an (n_live_points, 8) array. Uses numpy's global random state.
    """
    if bound is None:
        bound = bounds
    bound = np.asarray(bound)

    if seed is None or seed[0] is None:
        avg = np.average(hits[:, :3], axis=0)
        low_lims = np.concatenate([avg-np.array([3,3,3]), np.array([-50,0,0,1,0])])
        hig_lims = np.concatenate([avg+np.array([3,3,3]), np.array([50,2*np.pi,np.pi,30,1])])
    else:
        # Sampling window around the seed; the clipping below keeps it inside the limits.
        half_width = np.array([1, 1, 1, 5, 0.5, 0.3, 3, 3])
        low_lims = seed - half_width
        hig_lims = seed + half_width

    uniforms = np.random.uniform(size=(n_live_points, 8))
    initial_points = low_lims + uniforms * (hig_lims - low_lims)
    # Clip with the limits that were actually passed in (previously always the global).
    return np.clip(initial_points, bound[:, 0], bound[:, 1])

In [None]:
# Alternative (disabled) layout: five strings along the x axis with sensors along z.
#str_pos = np.array([[0,0], [2,0], [4,0], [-2,0], [-4,0]])
#z_pos = np.linspace(-10,10,6)

#detector = np.append(np.repeat(str_pos, len(z_pos), axis=0), np.tile(z_pos, len(str_pos)).reshape(-1,1), axis=1)
# Detector used for the reconstruction study: 5 x 5 x 5 lattice with spacing 2.5.
# Rebinds both `detector` and `toy_experiment` from the likelihood-scan section.
detector = cube(2.5, 5) #20-2.5?
toy_experiment = toy_model(detector)

In [None]:
# Half-width of the box in which the test events are generated.
lim = 5
# 5000 test events in the box, e_lim=(1, 10), at least 4 hits each.
# Rebinds `events` and `truths` from the training section.
events, meta = toy_experiment.generate_events(n=5000, gamma=0, gen_volume="box", e_lim=(1,10), inelast_lim=(0,1),
                                              x_lim=(-lim,lim), y_lim=(-lim,lim), z_lim=(-lim,lim), t_width=0,
                                              contained=False, min_hits=4)
truths = NNs.make_truth_array(events)

# [lower, upper] limit per fit parameter (8 rows); the position limits extend 2 beyond
# the generation box. Read by init_points and by both likelihood wrappers.
bounds = np.array([[-lim-2,lim+2], [-lim-2,lim+2], [-lim-2,lim+2], [-300,300], [0,2*np.pi], [0,np.pi], [1,30], [0,1]])

In [None]:
# Number of photon times per test event, histogrammed in 100 bins up to the 99 % quantile.
ps = ak.count(events.photons.t, axis=1).to_numpy()
plt.hist(ps, np.linspace(0,np.quantile(ps, 0.99),101))
plt.yscale('log')

### Analytic

In [None]:
def LLH_ana(X, hits, n_obs, form='total', fix=None, bounds=None):
    """Analytic negative log-likelihood of one hypothesis, with a hard wall at the limits.

    X      : parameter vector (one entry short when ``fix`` is used)
    hits   : hit array handed to the hit term of the toy model
    n_obs  : observed counts handed to the charge term of the toy model
    form   : 'total' or 'dom' - which pair of toy-model likelihood terms to evaluate
    fix    : optional ``(index, value)``; the value is inserted into X at that index.
             ``None`` (or the legacy ``[None]``) leaves X untouched.
    bounds : (n_params, 2) array of [lower, upper] limits. Defaults to the
             notebook-level ``bounds``, looked up at call time so that re-running
             the bounds cell takes effect here as it does in init_points.

    Returns 1e9 for any point outside the limits, otherwise charge term + hit term.
    Raises NameError for an unknown ``form``.
    """
    if fix is not None and fix[0] is not None:
        X = np.insert(X, fix[0], fix[1])

    if bounds is None:
        # The parameter shadows the notebook variable of the same name.
        bounds = globals()['bounds']

    # np.all replaces np.alltrue, which was removed in NumPy 2.0.
    inside = np.all((bounds[:,0] <= X) & (X <= bounds[:,1]))
    if not inside:
        return 1e9

    segments = toy_experiment.model(*X)
    if form == 'dom':
        h_term = toy_experiment.nllh_p_term_dom(segments, hits)
        c_term = toy_experiment.nllh_N_term_dom(segments, n_obs)
    elif form == 'total':
        h_term = toy_experiment.nllh_p_term_tot(segments, hits)
        c_term = toy_experiment.nllh_N_term_tot(segments, n_obs)
    else:
        raise NameError("Formulation must be one of ['total', 'dom'], not "+form)

    return c_term + h_term

def fit_event_ana(event):
    """Reconstruct a single event by minimising the analytic likelihood with CRS2.

    The hit array is assembled from the event's photon fields x, y, z, t and
    sensor_id; 97 starting points come from ``init_points``. Returns the
    best-fit parameter vector as a plain list.
    """
    hit_fields = ['x', 'y', 'z', 't', 'sensor_id']
    hits = np.stack([event.photons[name].to_numpy() for name in hit_fields], axis=1)
    n_obs = event.n_obs.to_numpy()

    def eval_LLH(params):
        # A single hypothesis is evaluated directly, a batch row by row.
        if params.ndim == 1:
            return LLH_ana(params, hits, n_obs)
        return np.array([LLH_ana(row, hits, n_obs) for row in params])

    start_points = init_points(hits, 97)

    # Unconstrained fit; parameters 4 and 5 are declared as one spherical angle pair.
    result = spherical_opt.spherical_opt(
        func=eval_LLH,
        method="CRS2",
        initial_points=start_points,
        rand=np.random.default_rng(42),
        spherical_indices=[[4,5]],
        batch_size=12,
    )
    return list(result['x'])

In [None]:
%%time
# Fit all test events with the analytic likelihood on 10 worker processes.
# NOTE(review): init_points draws from numpy's global RNG; with a fork-based Pool every
# worker presumably starts from the same RNG state - confirm this is acceptable.
with Pool(10) as p:
    outs = p.map(fit_event_ana, events)
# One row of best-fit parameters per event.
recos_ana = np.array(outs)

In [None]:
# Plain reco - truth residuals for the analytic fit (no log-ratio column, unlike diff_nn below).
diff_ana = recos_ana - truths

### NNs

In [None]:
import math

from functools import partial
from freedom.llh_service.llh_service import LLHService
from freedom.llh_service.llh_client import LLHClient
from freedom.reco import crs_reco

In [None]:
# Directory holding the exported networks.
loc = 'networks/'
# Configuration handed to every LLH service process started below.
service_conf = {
        "poll_timeout": 1,
        "flush_period": 1,
        # 8 fit parameters; 6 values per hit (x, y, z, t, q, sensor_id as stacked in fit_events_nn).
        "n_hypo_params": 8,
        "n_hit_features": 6,
        # 4 values per sensor: position plus observed count (see c_data in fit_events_nn).
        "n_evt_features": len(detector)*4,
        "batch_size" : {
          "n_hypos": 200,
          "n_observations": 6000, 
        },
        "send_hwm": 10000,
        "recv_hwm": 10000,
        "hitnet_file": loc+'string_toy_hitnet_dom.h5',
        "domnet_file": loc+'string_toy_chargenet_dom.h5',
        "ndoms": len(detector),
        "toy": True
}

In [None]:
# One LLH service process per GPU, each with its own pair of IPC endpoints.
n_gpus = 4

base_req = "ipc:///tmp/recotestreq"
base_ctrl = "ipc:///tmp/recotestctrl"

req_addrs = [f'{base_req}{i}' for i in range(n_gpus)]
ctrl_addrs = [f'{base_ctrl}{i}' for i in range(n_gpus)]

procs = []
for i, (ctrl_addr, req_addr) in enumerate(zip(ctrl_addrs, req_addrs)):
    # The last argument is the index of the GPU the service runs on.
    proc = Process(target=crs_reco.start_service, args=(service_conf, ctrl_addr, req_addr, i))
    proc.start()
    procs.append(proc)

In [None]:
def fit_events_nn(events, index, Truths, ctrl_addrs):
    """Reconstruct a batch of events with the network likelihood served by an LLH service.

    events     : iterable of events to fit
    index      : position in ``ctrl_addrs`` of the service this worker talks to
    Truths     : true parameters of ``events``; currently unused (only referenced
                 by the disabled truth-seeding option), kept for the call signature
    ctrl_addrs : control addresses of the running LLH services

    Reads the notebook-level ``bounds`` and ``detector``.
    Returns a list with one best-fit parameter vector per event.
    """
    client = LLHClient(ctrl_addr=ctrl_addrs[index], conf_timeout=60000)
    lower, upper = bounds[:, 0], bounds[:, 1]

    def eval_single(params, hits, c_data, fix=None):
        """Likelihood of one hypothesis; 1e9 acts as a hard wall outside the limits."""
        if fix is not None and fix[0] is not None:
            params = np.insert(params, fix[0], fix[1])

        # np.all replaces np.alltrue, which was removed in NumPy 2.0.
        if not np.all((lower <= params) & (params <= upper)):
            return 1e9

        return client.eval_llh(hits, c_data, params)

    outputs = []
    for event in events:
        hits = np.stack([event.photons[var].to_numpy() for var in ['x', 'y', 'z', 't', 'q', 'sensor_id']], axis=1)
        n_obs = event.n_obs.to_numpy()
        # Per-sensor input (position + observed count): constant for the whole fit,
        # so it is built once per event instead of once per likelihood call.
        c_data = np.hstack([detector, n_obs[:, np.newaxis]])

        def eval_LLH(params):
            # A single hypothesis is evaluated directly, a batch row by row.
            if params.ndim == 1:
                return eval_single(params, hits, c_data)
            return np.array([eval_single(p, hits, c_data) for p in params])

        # seeding
        initial_points = init_points(hits, 97) #, seed=Truths[j]

        # free fit
        fit_res = spherical_opt.spherical_opt(
            func=eval_LLH,
            method="CRS2",
            initial_points=initial_points,
            rand=np.random.default_rng(42),
            spherical_indices=[[4,5]],
            batch_size=12,
        )
        outputs.append(fit_res['x'])

    return outputs

In [None]:
# Split the test events into `pool_size` consecutive chunks, one per worker process.
events_to_process = len(events)
pool_size = 200
evts_per_proc = int(math.ceil(events_to_process/pool_size))
evt_splits = [events[i*evts_per_proc:(i+1)*evts_per_proc] for i in range(pool_size)] #_red
true_splits = [truths[i*evts_per_proc:(i+1)*evts_per_proc] for i in range(pool_size)]
# Sanity check: the chunk sizes should add up to the number of events.
print(sum(len(l) for l in evt_splits))

# Workers are assigned to the LLH services round-robin.
gpu_inds = np.arange(pool_size) % n_gpus

# Bind the service addresses; the remaining positional arguments
# (events, index, Truths) are supplied per worker by starmap.
fit_events_partial = partial(
        fit_events_nn,
        ctrl_addrs=ctrl_addrs
)

In [None]:
%%time
# reconstruct with a worker pool; one LLH client per worker
with Pool(pool_size) as pool:
    outs = pool.starmap(fit_events_partial, zip(evt_splits, gpu_inds, true_splits))

# Flatten the per-worker lists of fit results into one (n_events, 8) array.
all_outs = [reco for worker_out in outs for reco in worker_out]
all_outs = np.array(all_outs).reshape((events_to_process, 8))
recos_nn = np.array(all_outs)

In [None]:
# reco - truth residuals of the network fit ...
diff_nn = recos_nn - truths
# ... except for the second-to-last parameter (the energy-like one according to par_names),
# which is stored as log10(reco / truth) instead.
diff_nn[:, -2] = np.log10(recos_nn[:, -2] / truths[:, -2])
#np.save('recos/string/diff_nn_cub_10_5_sameGen', diff_nn)

In [None]:
# kill all the services
# Each service receives the string "die" on its control address; the cell then
# blocks until that process has exited before moving on to the next one.
import zmq
for proc, ctrl_addr in zip(procs, ctrl_addrs): 
    with zmq.Context.instance().socket(zmq.REQ) as ctrl_sock:
        ctrl_sock.connect(ctrl_addr)
        ctrl_sock.send_string("die")
        # The join happens while the socket is still open.
        proc.join()
### Plot

In [None]:
def bootstrap_iqr(x, n=100, f=0.2):
    """Bootstrap estimate of the per-column interquartile range and its uncertainty.

    x : array of samples, one row per sample; any number of columns (1D input also works)
    n : number of bootstrap replicas
    f : unused - leftover of the disabled subsampling variant, kept so that
        existing calls remain valid

    Each replica resamples the rows of ``x`` with replacement (same size as ``x``)
    using numpy's global random state. Returns ``(mean, std)`` of the replica IQRs,
    each with one entry per column of ``x``.
    """
    x = np.asarray(x)
    n_rows = len(x)
    # Size the result from the input instead of assuming exactly 8 columns.
    iqrs = np.zeros((n,) + x.shape[1:])
    for i in range(n):
        #inds = np.random.choice(range(len(x)), size=int(f*len(x)), replace=False)
        inds = np.random.choice(n_rows, size=n_rows)
        iqrs[i] = stats.iqr(x[inds], axis=0)
    return np.mean(iqrs, axis=0), np.std(iqrs, axis=0)

In [None]:
# Lattice spacings of the stored reconstruction runs ('p' stands for the decimal point).
points, diffs, uncs = ['2p5','3','3p5','4p2','5','6','7','8p4','10','12','14','17','20'], [], []

for p in points:
    diff_p = np.load(f'recos/string/diff_nn_cub_{p}_5.npy')
    # Keep the historical notebook-level names (diff_nn_cub_<spacing>_5) available,
    # because the last cell reads some of them - but without building code strings for exec.
    globals()[f'diff_nn_cub_{p}_5'] = diff_p
    #diffs.append(stats.iqr(diff_p, axis=0))
    d, u = bootstrap_iqr(diff_p)
    diffs.append(d)
    uncs.append(u)

# Transpose to (n_parameters, n_spacings) so that row i is one parameter's curve.
diffs = np.vstack(diffs).T
uncs = np.vstack(uncs).T

# Sensor density of a 5 x 5 x 5 lattice for 13 log-spaced spacings between 2.5 and 20.
xs = np.logspace(np.log10(2.5), np.log10(20), 13)
dens = [cube(x, 5, True)[1] for x in xs]

In [None]:
# Per-parameter styling: line colour, position of the in-plot text label, line style.
colors = ['tab:blue','tab:blue','tab:blue','tab:orange','tab:green','tab:green','tab:red','tab:purple']
pos = [[2.5e-4,0.16], [3.2e-4,0.16], [2e-3,0.13], [3e-4,0.05], [2e-4,0.21], [2e-4,0.43], [1.9e-4,0.59], [3e-4,0.73]]
styles = ['-.','--',':','-','--',':','-','-',]
par_names = ['x', 'y', 'z', 't', r'$\phi^{azimuth}$', r'$\theta^{zenith}$',  r'$E^{deposited}$', 'I']
#par_names = ['x', 'y', 'z', 't', r'$\varphi$', r'$\vartheta$',  r'E', 'I']

fig = plt.figure(figsize=(7, 7*0.618))
ax1 = fig.add_subplot(111)
# Second x axis on top, sharing the y axis with ax1.
ax2 = ax1.twiny()

# Each curve is min(IQR) / IQR, so 1 marks the best resolution of that parameter;
# the band uses IQR -/+ the bootstrap standard deviation.
for i in range(8):
    ax1.fill_between(dens, np.min(diffs[i])/(diffs[i]-uncs[i]), np.min(diffs[i])/(diffs[i]+uncs[i]), 
                     color=colors[i], alpha=0.2)
    ax1.plot(dens, np.min(diffs[i])/diffs[i], label=par_names[i], color=colors[i], linestyle=styles[i])
#ax1.legend()
ax1.set_xscale('log')
ax1.set_xlabel('Sensor density')
ax1.set_ylabel('Parameter resolution (IQR) \n normalized to best value   ')
ax1.set_ylim(0,1.05)

# Text labels placed by hand instead of a legend.
for i in range(8):
    ax1.text(pos[i][0], pos[i][1], par_names[i], size=16, color=colors[i])

ax2.set_xlabel('Detector volume')
# Line at y = 2, i.e. above the y limit set above; presumably it only serves to give
# the top axis its data range. 125 / density is the lattice volume for 5**3 = 125 sensors.
ax2.plot(125/np.array(dens), 2*np.ones(len(dens)))
ax2.set_xscale('log')
# Volume falls as density rises, hence the inverted direction.
ax2.invert_xaxis()

#plt.savefig('images/string_pos/res_vs_dens.pdf', bbox_inches='tight')

In [None]:
# Residuals of the two "sameGen" reconstruction runs (spacing 20 and spacing 5).
diff_nn_cub_20_5_sameGen = np.load('recos/string/diff_nn_cub_20_5_sameGen.npy')
diff_nn_cub_5_5_sameGen = np.load('recos/string/diff_nn_cub_5_5_sameGen.npy')

# Per-parameter IQR ratio sameGen / standard run. diff_nn_cub_20_5 and diff_nn_cub_5_5
# are not assigned in this cell: they are created dynamically by the loading cell of the
# resolution scan, which therefore has to be run first.
print(stats.iqr(diff_nn_cub_20_5_sameGen, axis=0) / stats.iqr(diff_nn_cub_20_5, axis=0))
print('---------------------------')
print(stats.iqr(diff_nn_cub_5_5_sameGen, axis=0) / stats.iqr(diff_nn_cub_5_5, axis=0))