In [None]:
import sys
import os

# Location of the freeDOM checkout. Set the FREEDOM_PATH environment variable to
# override it instead of editing the notebook (an alternative location used
# before was '/localscratch/weldert/freeDOM/').
FREEDOM_PATH = os.environ.get('FREEDOM_PATH', '/tf/localscratch/weldert/freeDOM/')
if FREEDOM_PATH not in sys.path:
    sys.path.append(FREEDOM_PATH)

# Expose GPUs 0-3 to TensorFlow (must be set before TensorFlow is imported).
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"

import numpy as np
import matplotlib.pyplot as plt
#from matplotlib.colors import LogNorm
from mpl_toolkits.mplot3d import Axes3D
from scipy import stats
from freedom.toy_model import toy_model
import tensorflow as tf
import tensorflow.keras.backend as K
import tensorflow_addons as tfa
from sklearn.model_selection import train_test_split
#import dragoman as dm
import pickle
from types import SimpleNamespace
%load_ext autoreload
%autoreload 2

In [None]:
# Larger default font sizes for all figures in this notebook.
plt.rcParams.update({
    'xtick.labelsize': 14,
    'ytick.labelsize': 14,
    'axes.labelsize': 16,
    'axes.titlesize': 16,
    'legend.fontsize': 14,
})

In [None]:
toy_experiment = toy_model.toy_experiment(t_std=1, detector_xs=np.linspace(-5, 5, num=11))  # 11 detectors on x in [-5, 5]

In [None]:
# True source parameters of the example event: position x, offset b, source strength N
example_x_src, example_b_src, example_N_src = 2.1, 1.1, 2.3

In [None]:
# generate a single test event at the example source parameters
test_event = toy_experiment.generate_event(N_src=example_N_src, b=example_b_src, x_src=example_x_src)

In [None]:
# Grid scan

x = np.linspace(-5, 5, 100)
y = np.linspace(-2, 2, 100)
x, y = np.meshgrid(x, y)

# One array per analytic likelihood term, evaluated on every (x, b) grid point.
term_names = ('dom_hit_term', 'dom_charge_terms', 'total_charge_hit_terms', 'total_charge_terms')
g = {name: np.empty(x.shape) for name in term_names}

event_charges, event_hits = test_event[0], test_event[1]
for idx in np.ndindex(x.shape):
    hypo_x, hypo_b = x[idx], y[idx]
    hypo_t = 0
    hypo_N_src = example_N_src
    # Every term is stored negated.
    # NOTE(review): dom_hit_term gets a literal 0 as its 4th argument; confirm
    # whether that parameter is the time or the source strength.
    g['dom_hit_term'][idx] = -toy_experiment.dom_hit_term(event_hits, hypo_x, hypo_b, 0)
    g['dom_charge_terms'][idx] = -toy_experiment.dom_charge_term(event_charges, hypo_x, hypo_b, hypo_N_src)
    g['total_charge_hit_terms'][idx] = -toy_experiment.total_charge_hit_term(event_hits, hypo_x, hypo_b, hypo_t, hypo_N_src)
    g['total_charge_terms'][idx] = -toy_experiment.total_charge_term(event_charges, hypo_x, hypo_b, hypo_N_src)

In [None]:
# Full analytic LLHs (hit + charge), shifted so each has its minimum at 0.
g['dom_llh'] = g['dom_hit_term'] + g['dom_charge_terms']
g['dom_llh'] -= g['dom_llh'].min()
g['total_charge_llh'] = g['total_charge_hit_terms'] + g['total_charge_terms']
g['total_charge_llh'] -= g['total_charge_llh'].min()

In [None]:
def plot_diff(a, b, axes, title_a='a', title_b='b', vmax=None, txt=0, **kwargs):
    """Plot two maps on the global (x, y) scan grid and their difference.

    Parameters
    ----------
    a, b : np.ndarray
        Maps with the same shape as the global meshgrid arrays ``x``/``y``.
    axes : sequence of 3 matplotlib Axes
        Targets for ``a``, ``b`` and ``a - b`` respectively.
    title_a, title_b : str
        Titles of the first two panels.
    vmax : float or None
        Upper colour limit for ``a``/``b``; also caps the symmetric range of the
        difference panel. ``None`` lets matplotlib / the data decide.
    txt : number
        1 -> annotate the maximum absolute difference (and the relative
        difference at that point); > 1 -> annotate the mean absolute
        difference over points with ``|a - b| <= txt``; otherwise no annotation.
    **kwargs
        Forwarded to every ``pcolormesh`` call.
    """
    for ax, data, title in ((axes[0], a, title_a), (axes[1], b, title_b)):
        mesh = ax.pcolormesh(x, y, data, cmap='Spectral', vmax=vmax, label=r'$\Delta LLH$', **kwargs)
        plt.colorbar(mesh, ax=ax)
        ax.set_title(title)

    diff = a - b
    abs_diff = np.abs(diff)
    max_abs_diff = np.max(abs_diff)
    # min(value, None) raises TypeError in Python 3, so only cap when vmax is given.
    vlim = max_abs_diff if vmax is None else min(max_abs_diff, vmax)
    mesh = axes[2].pcolormesh(x, y, diff, cmap='RdBu', vmin=-vlim, vmax=vlim, label=r'$\Delta LLH$', **kwargs)
    plt.colorbar(mesh, ax=axes[2])
    axes[2].set_title('diff')

    if txt == 1:
        p = np.unravel_index(np.argmax(abs_diff, axis=None), a.shape)
        # The relative difference is undefined where the reference map is 0
        # (e.g. at its minimum after min-subtraction).
        rel_diff = diff[p] / a[p] if a[p] != 0 else np.nan
        axes[2].text(-3.5, 1.5, r'Max abs diff = %.1f (%.2f)' % (max_abs_diff, rel_diff), size=15)
    elif txt > 1:
        vs = abs_diff[abs_diff <= txt]
        # Guard against an empty selection instead of dividing by zero.
        mean_diff = np.sum(vs) / len(vs) if len(vs) else np.nan
        axes[2].text(-3, 1.5, r'Mean diff = %.1f' % mean_diff, size=15)

In [None]:
def plot_truth(axes, x, y):
    """Mark the true source position (x, y) with a white 'T' on every axis.

    ``axes`` may be a single Axes object or a numpy array of Axes of any shape
    (as returned by ``plt.subplots``).
    """
    axes_array = axes if isinstance(axes, np.ndarray) else np.array([axes])
    for single_ax in axes_array.ravel():
        single_ax.plot([x], [y], marker='$T$', markersize=10, color='white')

In [None]:
fig, ax = plt.subplots(3, 3, figsize=(20, 17))
plt.subplots_adjust(wspace=0.3, hspace=0.3)

# One row per LLH component: per-DOM formulation vs total-charge formulation.
comparison_rows = [
    ('dom_hit_term', 'total_charge_hit_terms', 'per DOM hit', 'total hit'),
    ('dom_charge_terms', 'total_charge_terms', 'per DOM charge', 'total charge'),
    ('dom_llh', 'total_charge_llh', 'per DOM llh', 'total llh'),
]
for axes_row, (key_a, key_b, label_a, label_b) in zip(ax, comparison_rows):
    plot_diff(g[key_a], g[key_b], axes=axes_row, title_a=label_a, title_b=label_b, vmax=200)

plot_truth(ax, example_x_src, example_b_src)

# Train NNs

In [None]:
#!rm events.pkl

In [None]:
#%%time
fname = 'events.pkl'
if not os.path.isfile(fname):
    # generate some MC (slow: roughly 15 min for 1e5 events) and cache it
    events = toy_experiment.generate_events(int(1e5), N_lims=(0, 20))
    with open(fname, 'wb') as file:
        pickle.dump(events, file, protocol=pickle.HIGHEST_PROTOCOL)

    # optional second sample close to the detector line:
    #events2 = toy_experiment.generate_events(int(1e5), N_lims=(0, 20), blims=(-0.5,0.5))
    #with open('events_close.pkl', 'wb') as file:
    #    pickle.dump(events2, file, protocol=pickle.HIGHEST_PROTOCOL)
else:
    # NOTE: pickle.load can execute arbitrary code -- only load caches you created.
    with open(fname, 'rb') as file:
        events = pickle.load(file)

    #with open('events_close.pkl', 'rb') as file:
    #    events2 = pickle.load(file)

In [None]:
# events is a pair: per-event simulated observations and the matching true parameters
[mc, truth] = events
# to also use the close-by sample:
#mc2, truth2 = events2
#mc = np.append(mc, mc2, axis=0)
#truth = np.append(truth, truth2, axis=0)

In [None]:
# Simple attribute containers holding data, inputs and models of each network.
hitnet, chargenet = SimpleNamespace(), SimpleNamespace()

Data generator and custom activation functions

In [None]:
class DataGenerator(tf.keras.utils.Sequence):
    """Keras Sequence yielding balanced batches for the likelihood-ratio classifier.

    Each batch contains ``batch_size/2`` observation rows paired with their true
    parameters (label 1) and the same rows paired with "wrong" parameters
    (label 0), shuffled together.

    Parameters
    ----------
    x : np.ndarray
        Observation rows.
    t : np.ndarray
        Parameter rows, one per row of ``x``.
    batch_size : int
        Total rows per batch (true + false).
    shuffle : str
        How the false parameters are drawn:
        'event' -- ``t`` rolled by ``len(toy_experiment.detector_xs)`` rows
                   (re-rolled every epoch), i.e. assumes each event occupies that
                   many consecutive rows;
        'DOM'   -- ``t`` permuted among rows sharing the same value in column 2
                   of ``x`` (presumably the DOM index -- TODO confirm);
        other   -- a fresh random permutation of the batch's parameters.
    weights : bool
        If True the per-sample weight is column 2 of the parameter row,
        otherwise every sample has weight 1.
    """
    def __init__(self, x, t, batch_size=4096, shuffle='event', weights=False):

        self.batch_size = int(batch_size/2) # half true labels half false labels
        self.data = x
        self.params = t
        if shuffle == 'event':
            self.shuffled_params = np.roll(t, len(toy_experiment.detector_xs), axis=0)
        elif shuffle == 'DOM':
            self.shuffled_params = np.empty_like(self.params)
            # Loop over the column-2 values actually present instead of a
            # hard-coded range(11): otherwise rows with any other value keep the
            # uninitialised contents of np.empty_like, and a detector with a
            # different number of DOMs is silently mishandled.
            for DOM_index in np.unique(self.data[:, 2]):
                mask = self.data[:, 2] == DOM_index
                self.shuffled_params[mask] = np.random.permutation(self.params[mask])

        self.indexes = np.arange(len(self.data))
        self.shuffle = shuffle
        self.weights = weights
        self.on_epoch_end()

    def __len__(self):
        'Denotes the number of batches per epoch (a trailing partial batch is dropped)'
        return int(np.floor(len(self.data) / self.batch_size))

    def __getitem__(self, index):
        """Return ``([X, T], labels, sample_weights)`` for batch ``index``."""
        # Generate indexes of the batch
        indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]

        # Generate data
        X, y, w = self.__data_generation(indexes) #

        return X, y, w

    def on_epoch_end(self):
        'Updates indexes after each epoch'
        np.random.shuffle(self.indexes) # mix between batches
        if self.shuffle == 'event':
            # pair each row with a different event's parameters in the next epoch
            self.shuffled_params = np.roll(self.shuffled_params, len(toy_experiment.detector_xs), axis=0)

    def __data_generation(self, indexes_temp):
        'Generates data containing batch_size samples'
        # Generate data similar to Data.get_dataset()
        x = np.take(self.data, indexes_temp, axis=0)
        t = np.take(self.params, indexes_temp, axis=0)
        if self.shuffle == 'event' or self.shuffle == 'DOM':
            t_shuffle = np.take(self.shuffled_params, indexes_temp, axis=0)

        d_true_labels = np.ones((self.batch_size, 1), dtype=x.dtype)
        d_false_labels = np.zeros((self.batch_size, 1), dtype=x.dtype)

        # first half: true (x, t) pairs; second half: same x with wrong parameters
        d_X = np.append(x, x, axis=0)
        if self.shuffle == 'event' or self.shuffle == 'DOM':
            d_T = np.append(t, t_shuffle, axis=0)
        else:
            d_T = np.append(t, np.random.permutation(t), axis=0)
        d_labels = np.append(d_true_labels, d_false_labels)

        d_X, d_T, d_labels = self.unison_shuffled_copies(d_X, d_T, d_labels)

        #weights = np.where((d_X[:,0]+1) * (np.sqrt(np.square(d_T[:,0]-d_X[:,1])+np.square(d_T[:,1]))+0.05) < 0.5, 1000, 1)
        #R = np.sqrt(np.square(d_T[:,0]-d_X[:,1])+np.square(d_T[:,1]))
        if self.weights:
            weights = d_T[:,2] #np.clip(d_T[:,2], 0, 2) #d_X[:,0]/100+1
        else:
            weights = np.ones(len(d_T[:,2]))

        return [d_X, d_T], d_labels, weights

    def unison_shuffled_copies(self, a, b, c):
        'Shuffles arrays in the same way'
        assert len(a) == len(b) == len(c)
        p = np.random.permutation(len(a))

        return a[p], b[p], c[p]

In [None]:
def x2(x):
    """Activation x + x**2 for x >= 0 and 0 otherwise."""
    quadratic = x + tf.math.pow(x, 2)  # alternatives tried: *tf.math.exp(x/30), 0.1*x for x < 0
    return tf.where(x >= 0, quadratic, 0)
    
class combi_activation(tf.keras.layers.Layer):
    """Trainable piecewise activation, applied element-wise.

    f(x) = a*x + b*x**2   for x >= 0
    f(x) = c*x            for x <  0

    ``a`` and ``b`` are per-feature weights of shape (1, n_features), clipped to
    [0, 3]; ``c`` is a single shared slope, clipped to [0, 0.2].
    """

    def __init__(self):
        super().__init__()

    def build(self, input_shape):
        n_features = input_shape[-1]

        # linear coefficient of the positive branch (random start in [0, 1))
        self.a = self.add_weight(
            name='a',
            shape=(1, n_features),
            initializer=tf.keras.initializers.RandomUniform(0, 1),
            trainable=True,
            constraint=lambda w: tf.clip_by_value(w, 0, 3),
        )

        # quadratic coefficient of the positive branch
        self.b = self.add_weight(
            name='b',
            shape=(1, n_features),
            initializer='ones',
            trainable=True,
            constraint=lambda w: tf.clip_by_value(w, 0, 3),
        )

        # shared slope of the negative branch (one scalar, not per feature)
        self.c = self.add_weight(
            name='c',
            shape=(1,),
            initializer='zeros',
            trainable=True,
            constraint=lambda w: tf.clip_by_value(w, 0, 0.2),
        )

    def call(self, inputs):
        positive_branch = self.a * inputs + self.b * tf.math.pow(inputs, 2)
        negative_branch = self.c * inputs
        return tf.where(inputs >= 0, positive_branch, negative_branch)


class ParametricSoftExp(tf.keras.layers.Layer):
    def __init__(self, init_alpha=0.0, init_stdv=0.01, **kwargs):
        """
        Soft Exponential activation function with trainable alpha.
        alpha is initialised from a normal distribution N(init_alpha, init_stdv)
        and learned during the optimization process.
        See: https://arxiv.org/pdf/1602.01321.pdf by Godfrey and Gashler
        Soft Exponential f(α, x):
           α == 0:  x
           α  > 0:  (exp(αx)-1) / α + α
           α  < 0:  -ln(1-α(x + α)) / α
        :param init_alpha: mean of the normal distribution alpha is drawn from
        :param init_stdv: standard deviation of that distribution
        :param kwargs: forwarded to tf.keras.layers.Layer
        """
        # Configs written by the previous get_config() stored the mean under
        # 'alpha_init'; accept that key so such configs still deserialize
        # (Layer.__init__ rejects unknown keyword arguments).
        if 'alpha_init' in kwargs:
            init_alpha = kwargs.pop('alpha_init')
        super(ParametricSoftExp, self).__init__(**kwargs)
        self.init_mean = init_alpha
        self.init_stdv = init_stdv
        self.atol = tf.constant(1e-08)

    def build(self, temp):
        # Initialize alpha (a single trainable scalar shared by all units)
        alpha_init = tf.random_normal_initializer(mean=self.init_mean, stddev=self.init_stdv)
        self.alpha_actv = tf.Variable(initial_value=alpha_init(shape=(1,), dtype='float32'), trainable=True)

    def call_lt0(self, x):
        return -(tf.math.log(1 - self.alpha_actv * (x + self.alpha_actv))) / (self.alpha_actv)

    def call_gt0(self, x):
        return (tf.math.exp(self.alpha_actv * x) - 1) / self.alpha_actv + self.alpha_actv

    def call(self, x):
        # Rescale the whole batch to [-2, 2] using its global min/max.
        # divide_no_nan returns 0 instead of NaN when the batch is constant
        # (max == min), otherwise it is an ordinary division.
        x_min = tf.reduce_min(x)
        x_range = tf.reduce_max(x) - x_min
        x = tf.math.divide_no_nan(4*(x - x_min), x_range) - 2 #tf.clip_by_value(x, -5, 5)
        # Check for equal-ness first, based on a certain tolerance.
        cond_equal = tf.less_equal(tf.abs(self.alpha_actv), self.atol)
        # Otherwise go to greater or lower than zero
        res = K.switch(cond_equal, x, K.switch(self.alpha_actv > 0, self.call_gt0(x), self.call_lt0(x)))
        #res = tf.clip_by_value(res, -1000, 1000)
        return res

    def get_config(self):
        # Keys must match the __init__ parameter names so that
        # ParametricSoftExp.from_config(layer.get_config()) works.
        config = {'init_alpha': float(self.init_mean), 'init_stdv': float(self.init_stdv)}
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))

Prepare data for the NNs

In [None]:
def get_total_charge_infos(doms, truth):
    """Summarise one event's per-DOM charges.

    Parameters
    ----------
    doms : np.ndarray
        One row per DOM; column 0 is the charge.
    truth : unused
        Kept for interface compatibility (previously used for distance features
        to the nearest hit / non-hit DOM).

    Returns
    -------
    list
        ``[total charge, number of DOMs with non-zero charge]``.
    """
    charges = doms[:, 0]
    return [np.sum(charges), np.count_nonzero(charges)]

In [None]:
Typ = "dom"  # charge-net input layout: 'dom' (one row per DOM) or 'total' (summed charge)

In [None]:
# Split every MC event into charge-net rows (per-DOM charges) and hit-net rows
# (individual hits), dropping events without any recorded charge.
chargenet.x = []
hitnet.x = []
n_hits_per_event = []
failures = []
for event_idx, event in enumerate(mc):
    dom_charges, dom_hits = event[0], event[1]
    if np.sum(dom_charges[:, 0]) == 0:
        failures.append(event_idx)
        continue
    chargenet.x.append(dom_charges) #get_total_charge_infos(dom_charges, truth[event_idx])
    hitnet.x.append(dom_hits)
    n_hits_per_event.append(dom_hits.shape[0])

# Only filter while truth still has one row per MC event: re-running this cell
# would otherwise delete `failures` a second time from the already-filtered
# array and silently misalign truth with the kept events.
if len(truth) == len(mc):
    truth = np.delete(truth, failures, axis=0)
assert len(truth) == len(n_hits_per_event), 'truth is not aligned with the kept events'

In [None]:
# Stack the per-event lists into flat row arrays (events stay contiguous, in order).
chargenet.x = np.concatenate(chargenet.x)  # alternative for the 'total' layout: np.array(chargenet.x)
hitnet.x = np.concatenate(hitnet.x)
n_hits_per_event = np.asarray(n_hits_per_event)

In [None]:
# Give every observation row a copy of its event's true parameters.
hitnet.t = np.repeat(truth, n_hits_per_event, axis=0)  # one copy per hit
chargenet.t = np.repeat(truth, len(toy_experiment.detector_xs), axis=0)  # one copy per DOM ('dom' layout; 'total' would use truth)

In [None]:
# every observation row needs exactly one parameter row
assert len(chargenet.x) == len(chargenet.t)
assert np.shape(hitnet.x) == np.shape(hitnet.t)

## charge Net

In [None]:
#chargenet.x_train, chargenet.x_test, chargenet.t_train, chargenet.t_test = train_test_split(chargenet.x, chargenet.t, test_size=0.2, random_state=42)

# Charge-net inputs for the grid scan: one hypothesis per grid point, paired
# with the observation rows of the test event.
n_grid = np.prod(x.shape)
chargenet.tt = np.vstack([x.flatten(), y.flatten(), np.ones(n_grid) * example_N_src]).T
if Typ == 'dom':
    # each grid hypothesis is repeated once per DOM, and the test event's DOM
    # rows are repeated once per grid point, so row i of both arrays matches
    chargenet.tts = np.repeat(chargenet.tt, len(toy_experiment.detector_xs), axis=0)

    chargenet.xxs = np.repeat(test_event[0][np.newaxis,:, :], n_grid, axis=0)
    chargenet.xxs = chargenet.xxs.reshape(-1, 3)

elif Typ == 'total':
    chargenet.tts = chargenet.tt

    # Tile the (1, 2) row [total charge, n hit DOMs] once per grid point.
    # np.repeat on the flat 2-element list would instead give
    # [c, c, ..., n, n, ...], which reshape(-1, 2) turns into (c, c) / (n, n) rows.
    total_infos = np.asarray(get_total_charge_infos(test_event[0], [example_x_src, example_b_src]))
    chargenet.xxs = np.tile(total_infos[np.newaxis, :], (n_grid, 1))

Prepare the NN (input transformation and architecture)

In [None]:
#r = np.log(np.sqrt((chargenet.t[:,0]-chargenet.x[:,1])**2 + chargenet.t[:,1]**2))
#np.mean(r), np.std(r)
# offset and range of parameter column 2 (N_src in the grid-scan layout), used for normalisation
n_src_col = chargenet.t[:, 2]
n_src_col.min(), n_src_col.max() - n_src_col.min()

In [None]:
# Input transformation layer of the charge net; which variant is defined depends
# on Typ. NOTE(review): if Typ is neither 'dom' nor 'total', charge_trafo is
# never defined and the last line raises NameError.
if Typ == 'dom':
    class charge_trafo(tf.keras.layers.Layer):
        """Per-DOM features: charge, DOM position, log squared distance, hypothesis.

        The order of the stacked features and the hard-coded offsets/scales must
        stay fixed -- a trained network depends on them.
        """

        def call(self, charges, theta):
            # squared distance between hypothesis point (theta[:,0], theta[:,1])
            # and the DOM at (charges[:,1], 0)
            r2 = tf.math.square(theta[:,0] - charges[:,1]) + tf.math.square(theta[:,1])
            # NOTE: despite the name, r is log(r2), i.e. 2*log(distance); it is
            # -inf if the hypothesis sits exactly on a DOM.
            r = tf.math.log(r2) #tf.math.sqrt()
            #d = (charges[:,0])/(r2) #+0.05**2
            #(charges[:,2]-5.0)/3.16,
            # (value - offset) / scale; the constants look like dataset mean/std
            # estimates -- TODO confirm against the data they were taken from
            out = tf.stack([
                    #charges[:,0],
                    (charges[:,0]-5.8)/62.1,
                    #charges[:,1],
                    (charges[:,1])/3.16,
                    #r2,
                    (r-2.22)/1.46,
                    #theta[:,0],
                    (theta[:,0])/2.89,
                    #theta[:,1],
                    (theta[:,1])/1.15,
                    #theta[:,2]
                    (theta[:,2]-9.98)/5.78
                    ],
                    axis=1
                    )
            return out

elif Typ == 'total':
    class charge_trafo(tf.keras.layers.Layer):
        """Event-level features: total charge plus the raw hypothesis parameters.

        NOTE(review): charges[:,1] (number of hit DOMs) is not used as an input.
        """

        def call(self, charges, theta):
            out = tf.stack([
                     charges[:,0],
                     #(charges[:,0]-63.9)/205.7,
                     #charges[:,1],
                     #(charges[:,1]-6.)/2.36,
                     theta[:,0],
                     #(theta[:,0])/2.89,
                     theta[:,1],
                     #(theta[:,1])/1.15,
                     theta[:,2]
                     #(theta[:,2]-9.98)/5.78
                    ],
                    axis=1
                    )
            return out

chargenet.trafo = charge_trafo

In [None]:
def get_charge_model(activation='relu'):
    """Build the charge-net classifier.

    Inputs are an observation row and a hypothesis row (widths taken from
    ``chargenet.x`` / ``chargenet.t``). They pass through ``chargenet.trafo``, then
    five Dense blocks (32-64-128-64-32), each followed by 1% dropout, and finally a
    single sigmoid unit. The 128- and 32-unit blocks use the trainable
    ``combi_activation``; the others use ``activation``.
    """
    x_input = tf.keras.Input(shape=(chargenet.x.shape[1],))
    t_input = tf.keras.Input(shape=(chargenet.t.shape[1],))

    h = chargenet.trafo()(x_input, t_input)

    # (units, use combi_activation instead of `activation`)
    hidden_blocks = ((32, False), (64, False), (128, True), (64, False), (32, True))
    for units, use_combi in hidden_blocks:
        if use_combi:
            h = tf.keras.layers.Dense(units, activation=None)(h)
            h = combi_activation()(h)
        else:
            h = tf.keras.layers.Dense(units, activation=activation)(h)
        h = tf.keras.layers.Dropout(0.01)(h)

    outputs = tf.keras.layers.Dense(1, activation='sigmoid')(h)

    return tf.keras.Model(inputs=[x_input, t_input], outputs=outputs)

#model = get_charge_model()
#model.summary()

In [None]:
strategy = tf.distribute.MirroredStrategy()
nGPUs = strategy.num_replicas_in_sync

with strategy.scope():
    cmodel = get_charge_model(activation=tfa.activations.mish)
    optimizer = tf.keras.optimizers.Adam(1e-3)
    #radam = tfa.optimizers.RectifiedAdam(lr=1e-3)
    #optimizer = tfa.optimizers.Lookahead(radam)
    bce = tf.keras.losses.BinaryCrossentropy()
    cmodel.compile(loss=bce, optimizer=optimizer, metrics=['accuracy']) #, run_eagerly=True

# Hold out the first ~20% of events for validation. For Typ == 'dom' every event
# contributes one row per DOM, so cut at a multiple of that; a plain
# int(0.2*len(chargenet.x)) cut can split one event between train and test.
rows_per_event = len(toy_experiment.detector_xs) if Typ == 'dom' else 1
n_events = len(chargenet.x) // rows_per_event
n_test = int(0.2 * n_events) * rows_per_event
d_train = DataGenerator(chargenet.x[n_test:], chargenet.t[n_test:], batch_size=4096*nGPUs, weights=True)
d_test = DataGenerator(chargenet.x[:n_test], chargenet.t[:n_test], batch_size=4096*nGPUs, weights=True)

hist = cmodel.fit(d_train, epochs=15, verbose=1, validation_data=d_test)

In [None]:
# Charge-net training curves (weighted binary cross-entropy).
plt.plot(cmodel.history.history['loss'], label='train')
plt.plot(cmodel.history.history['val_loss'], label='validation')
plt.gca().set_yscale('log')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend()
#plt.ylim(1, 10)

In [None]:
# Reuse the trained classifier as the LLH estimator: with the final Dense(1)
# layer's sigmoid replaced by a linear activation, the model outputs the
# pre-sigmoid logit, which the following cells use as the (negated) LLH.
# NOTE: this mutates cmodel in place (chargenet.llh is the same object), so
# cmodel no longer outputs probabilities after this cell.
chargenet.llh = cmodel
chargenet.llh.layers[-1].activation = tf.keras.activations.linear
# presumably recompiled so predict() is rebuilt with the new activation
chargenet.llh.compile()

In [None]:
chargenet.llhs = chargenet.llh.predict([chargenet.xxs, chargenet.tts])

# Map the network output back onto the scan grid as a negated, min-shifted LLH.
if Typ == 'dom':
    # one output per (grid point, DOM): sum the DOM contributions per grid point
    per_dom_llhs = chargenet.llhs.reshape(-1, len(toy_experiment.detector_xs))
    g['charge_llh'] = -per_dom_llhs.sum(axis=1).reshape(x.shape)
elif Typ == 'total':
    g['charge_llh'] = -chargenet.llhs.reshape(x.shape)
g['charge_llh'] -= g['charge_llh'].min()

In [None]:
fig, ax = plt.subplots(3, 3, figsize=(20,17))
plt.subplots_adjust(wspace=0.2, hspace=0.2)

typ = Typ+'_charge_terms'
# Analytic reference, shifted so its minimum is 0 like g['charge_llh'].
analytic_charge = g[typ] - np.min(g[typ])
# Same comparison at three colour scales (vmax 20, 2, 200); only the last row is annotated.
for axes_row, vmax, txt in ((ax[0], 20, 0), (ax[1], 2, 0), (ax[2], 200, 1)):
    plot_diff(analytic_charge, g['charge_llh'], title_a='Analytic', title_b='NN', vmax=vmax, axes=axes_row, txt=txt)
# plot_truth marks every axis, so a single call after all panels is enough.
plot_truth(ax, example_x_src, example_b_src)

#plt.savefig('../../plots/toy_model/xy_old_close', bbox_inches='tight')

In [None]:
fig = plt.figure(figsize=(20, 17))
ax3D = fig.add_subplot(111, projection='3d')
# analytic charge LLH relative to its minimum (used by the alternative plots below)
LLH = g[typ] - g[typ].min()

ax3D.plot_surface(X=x, Y=y, Z=g['charge_llh'])  # , cmap='RdBu'
# alternatives: analytic surface, absolute and relative differences
#ax3D.plot_surface(X=x, Y=y, Z=LLH)
#ax3D.plot_surface(X=x, Y=y, Z=LLH-g['charge_llh'], cmap='RdBu', vmin=-2, vmax=2)
#ax3D.plot_surface(X=x, Y=y, Z= np.divide(LLH-g['charge_llh'], LLH, out=np.zeros_like(LLH), where=LLH>0.1), cmap='RdBu', vmin=-1, vmax=1)
ax3D.set_zlim(0, 820)
#ax3D.set_xlim(-1, 1)
#ax3D.set_ylim(-1, 1)

#plt.savefig('../../plots/toy_model/xy_old_3D_close', bbox_inches='tight')

In [None]:
# 1D scan of the source strength N at the true (x, b): NN charge LLH (C) and
# analytic charge term (T), both shifted to a minimum of 0.
N_scan = np.linspace(0.01, 10, 100)
tt = np.vstack([np.full_like(N_scan, example_x_src), np.full_like(N_scan, example_b_src), N_scan]).T
if Typ == 'dom':
    # one row per (N hypothesis, DOM)
    tts = np.repeat(tt, len(toy_experiment.detector_xs), axis=0)

    xxs = np.repeat(test_event[0][np.newaxis,:, :], len(N_scan), axis=0)
    xxs = xxs.reshape(-1, 3)
elif Typ == 'total':
    tts = tt
    total_infos = np.asarray(get_total_charge_infos(test_event[0], [example_x_src, example_b_src]))
    xxs = np.tile(total_infos[np.newaxis, :], (len(N_scan), 1))
else:
    # previously this fell through and reused stale/undefined tts, xxs and C
    raise ValueError(f"unknown Typ {Typ!r}; expected 'dom' or 'total'")

llhs = chargenet.llh.predict([xxs, tts])

if Typ == 'dom':
    C = -np.sum(llhs.reshape(-1, len(toy_experiment.detector_xs)), axis=1)
    analytic_charge_term = toy_experiment.dom_charge_term
else:
    C = -llhs.reshape(-1)
    analytic_charge_term = toy_experiment.total_charge_term
C -= np.min(C)

T = np.array([-analytic_charge_term(test_event[0], example_x_src, example_b_src, N) for N in N_scan])
T -= np.min(T)

In [None]:
N_scan = np.linspace(0.01, 10, 100)
plt.plot(N_scan, C, label='NN')
#plt.plot(N_scan, T, label='Analytic')
#plt.yscale('log')
plt.xlabel(r'$N_{src}$ hypothesis')
plt.ylabel(r'$\Delta LLH$ (charge)')
plt.legend()

## DOM hit Net

In [None]:
#hitnet.x_train, hitnet.x_test, hitnet.t_train, hitnet.t_test = train_test_split(hitnet.x, hitnet.t, test_size=0.2, random_state=42)

# Hit-net inputs for the grid scan: every grid hypothesis is paired with every
# hit of the test event (hypotheses repeated per hit, hits repeated per grid point).
n_grid_points = np.prod(x.shape)
n_test_hits = test_event[1].shape[0]
hitnet.tt = np.vstack([x.flatten(), y.flatten(), np.ones(n_grid_points) * example_N_src]).T
hitnet.tts = np.repeat(hitnet.tt, n_test_hits, axis=0)
hitnet.xxs = np.repeat(test_event[1][np.newaxis, :, :], n_grid_points, axis=0).reshape(-1, 3)

In [None]:
# mean and spread of the log source-hit distance over the training rows
dx = hitnet.t[:, 0] - hitnet.x[:, 1]
r = np.log(np.sqrt(dx**2 + hitnet.t[:, 1]**2))
np.mean(r), np.std(r)

In [None]:
class hit_trafo(tf.keras.layers.Layer):
    """Input transformation of the hit net.

    Builds 7 features per hit: hit column 0, hit position (column 1), distance to
    the hypothesis point, column 0 minus distance / c, and the three hypothesis
    parameters. Feature order and offsets/scales are fixed -- a trained network
    depends on them.
    NOTE(review): column 0 is treated as a time (delta_t = t - r/c); confirm units.
    """

    def call(self, hits, theta):
        c = 0.3
        dist = tf.math.sqrt(tf.math.square(theta[:, 0] - hits[:, 1]) + tf.math.square(theta[:, 1]))

        features = [
            (hits[:, 0] - 2.6) / 3.64,
            hits[:, 1] / 2.93,
            (dist - 0.77) / 1.05,
            hits[:, 0] - dist / c,  # residual after the straight-line travel time, unnormalised
            theta[:, 0] / 2.87,
            theta[:, 1] / 0.622,
            (theta[:, 2] - 13.3) / 4.69,
        ]
        return tf.stack(features, axis=1)
hitnet.trafo = hit_trafo

In [None]:
def get_hit_model(activation='relu'):
    """Build the hit-net classifier.

    Inputs are a hit row and a hypothesis row (widths taken from ``hitnet.x`` /
    ``hitnet.t``). They pass through ``hitnet.trafo``, then Dense layers of
    32-64-128-64-32 units with ``activation``, each followed by 1% dropout, and
    finally a single sigmoid unit.
    """
    x_input = tf.keras.Input(shape=(hitnet.x.shape[1],))
    t_input = tf.keras.Input(shape=(hitnet.t.shape[1],))

    h = hitnet.trafo()(x_input, t_input)

    # tried on the 128-unit layer: Activation(x2), combi_activation(), ParametricSoftExp()
    for units in (32, 64, 128, 64, 32):
        h = tf.keras.layers.Dense(units, activation=activation)(h)
        h = tf.keras.layers.Dropout(0.01)(h)

    outputs = tf.keras.layers.Dense(1, activation='sigmoid')(h)

    return tf.keras.Model(inputs=[x_input, t_input], outputs=outputs)

In [None]:
strategy = tf.distribute.MirroredStrategy()
nGPUs = strategy.num_replicas_in_sync

with strategy.scope():
    hmodel = get_hit_model(activation=tfa.activations.mish)
    optimizer = tf.keras.optimizers.Adam(1e-3)
    #radam = tfa.optimizers.RectifiedAdam(lr=0.001)
    #optimizer = tfa.optimizers.Lookahead(radam)
    hmodel.compile(loss='binary_crossentropy', optimizer=optimizer)

# Train on the first ~80% of events and validate on the rest. hitnet.x holds a
# variable number of rows per event (n_hits_per_event, in the same event order),
# so cut at an event boundary; int(0.8*len(hitnet.x)) can split an event's hits
# between the training and validation sets.
n_train_events = int(0.8 * len(n_hits_per_event))
n_train = int(np.sum(n_hits_per_event[:n_train_events]))
d_train = DataGenerator(hitnet.x[:n_train], hitnet.t[:n_train], batch_size=4096*nGPUs, shuffle='free')
d_test = DataGenerator(hitnet.x[n_train:], hitnet.t[n_train:], batch_size=4096*nGPUs, shuffle='free')

hist = hmodel.fit(d_train, epochs=4, verbose=1, validation_data=d_test)

In [None]:
# Hit-net training curves (binary cross-entropy).
plt.plot(hmodel.history.history['loss'], label='train')
plt.plot(hmodel.history.history['val_loss'], label='validation')
plt.gca().set_yscale('log')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend()

In [None]:
# Reuse the trained hit classifier as the LLH estimator: with the final Dense(1)
# layer's sigmoid replaced by a linear activation, the model outputs the
# pre-sigmoid logit, which the next cell sums (negated) over the hits.
# NOTE: this mutates hmodel in place (hitnet.llh is the same object).
hitnet.llh = hmodel
hitnet.llh.layers[-1].activation = tf.keras.activations.linear
# presumably recompiled so predict() is rebuilt with the new activation
hitnet.llh.compile()

In [None]:
hitnet.llhs = hitnet.llh.predict([hitnet.xxs, hitnet.tts])

# one output per (grid point, hit): sum over the test event's hits, negate,
# reshape onto the scan grid and shift to a minimum of 0
per_hit_llhs = hitnet.llhs.reshape(-1, test_event[1].shape[0])
g['hit_llh'] = -per_hit_llhs.sum(axis=1).reshape(x.shape)
g['hit_llh'] -= g['hit_llh'].min()

In [None]:
# Key of the analytic hit term that the hit NN is compared against below.
# NOTE(review): this reuses the name Typ and overwrites the 'dom'/'total' switch
# of the charge net; re-running any charge-net cell after this one matches
# neither branch there.
Typ = 'total_charge_hit_terms' #'dom_hit_term'

In [None]:
fig, ax = plt.subplots(3, 3, figsize=(20,17))
plt.subplots_adjust(wspace=0.2, hspace=0.2)

typ = Typ
# Analytic reference, shifted so its minimum is 0 like g['hit_llh'].
analytic_hit = g[typ] - np.min(g[typ])
# Same comparison at three colour scales (vmax 20, 2, 200); only the last row is annotated.
for axes_row, vmax, txt in ((ax[0], 20, 0), (ax[1], 2, 0), (ax[2], 200, 1)):
    plot_diff(analytic_hit, g['hit_llh'], title_a='Analytic', title_b='NN', vmax=vmax, axes=axes_row, txt=txt)
# plot_truth marks every axis, so a single call after all panels is enough.
plot_truth(ax, example_x_src, example_b_src)

#plt.savefig('../../plots/toy_model/xy_old_hit', bbox_inches='tight')

In [None]:
fig = plt.figure(figsize=(20, 17))
ax3D = fig.add_subplot(111, projection='3d')
# analytic per-DOM hit term relative to its minimum; sets the z range
LLH = g['dom_hit_term'] - g['dom_hit_term'].min()

ax3D.plot_surface(X=x, Y=y, Z=g['hit_llh'], cmap='RdBu')
# alternatives: analytic surface, absolute and relative differences
#ax3D.plot_surface(X=x, Y=y, Z=LLH)
#ax3D.plot_surface(X=x, Y=y, Z=LLH-g['hit_llh'], cmap='RdBu', vmin=-2, vmax=2)
#ax3D.plot_surface(X=x, Y=y, Z= np.divide(LLH-g['hit_llh'], LLH, out=np.zeros_like(LLH), where=LLH!=0), cmap='RdBu', vmin=-1, vmax=1)
ax3D.set_zlim(0, LLH.max())
#plt.savefig('../../plots/toy_model/hit_LLH_NN_3D_zoomOut', bbox_inches='tight')

In [None]:
l = 8  # index into hmodel.weights of the variable to inspect
inspected_weight = hmodel.weights[l]
print(inspected_weight.name)
# .values[0]: first per-replica copy of the variable created under MirroredStrategy
plt.hist(inspected_weight.values[0].numpy().flatten())

## All together

In [None]:
# full NN LLH = charge part + hit part, shifted to a minimum of 0
total_nn_llh = g['charge_llh'] + g['hit_llh']
g['llh'] = total_nn_llh - total_nn_llh.min()

In [None]:
fig, ax = plt.subplots(3, 3, figsize=(20,17))
plt.subplots_adjust(wspace=0.2, hspace=0.2)

# (analytic key, NN key, left title, middle title) per row; analytic maps are
# shifted to a minimum of 0 like the NN maps.
# NOTE(review): the hit row uses 'dom_hit_term', while the hit-net comparison
# above used Typ = 'total_charge_hit_terms' -- confirm which reference is intended.
final_panels = (
    ('dom_hit_term', 'hit_llh', 'Hit Analytic', 'Hit NN'),
    ('dom_charge_terms', 'charge_llh', 'Charge Analytic', 'Charge NN'),
    ('dom_llh', 'llh', 'Analytic', 'NN'),
)
for axes_row, (analytic_key, nn_key, title_a, title_b) in zip(ax, final_panels):
    analytic = g[analytic_key] - np.min(g[analytic_key])
    plot_diff(analytic, g[nn_key], title_a=title_a, title_b=title_b, vmax=20, axes=axes_row, txt=1)
# plot_truth marks every axis, so a single call after all panels is enough.
plot_truth(ax, example_x_src, example_b_src)

#plt.savefig('../../plots/toy_model/event3', bbox_inches='tight')