In [None]:
import numpy as np
import tensorflow as tf
from freedom.utils.i3cols_dataloader import load_charges
import os

from freedom.neural_nets.transformations import chargenet_trafo

In [None]:
def unison_shuffled_copies(a, b, c):
    """Return shuffled copies of three arrays using one shared permutation.

    The same random row order (drawn from NumPy's global RNG) is applied to
    all three inputs, so rows that were aligned before the call stay aligned
    afterwards.

    Parameters
    ----------
    a, b, c : array-like supporting integer-array indexing
        Must all have the same length along the first axis.

    Returns
    -------
    tuple
        ``(a_shuffled, b_shuffled, c_shuffled)``.
    """
    n_rows = len(a)
    assert n_rows == len(b) == len(c)
    order = np.random.permutation(n_rows)
    return tuple(arr[order] for arr in (a, b, c))

In [None]:
class DataGenerator(tf.keras.utils.Sequence):
    """Keras ``Sequence`` that serves shuffled batches for a binary classifier.

    All directories are loaded into memory up front. Each batch contains every
    selected data row twice: once paired with its own parameter row (label 1)
    and once paired with a parameter row taken from a random permutation of
    the batch's parameters (label 0). A batch therefore holds
    ``2 * batch_size`` samples.
    """

    def __init__(self, func, dirs, labels, batch_size=4096):
        """Load every directory and prepare the index order for the first epoch.

        Parameters
        ----------
        func : callable
            Loader invoked as ``func(dir=<entry of dirs>, labels=labels)``.
            It must return a ``(data, params, labels)`` triple whose first two
            items can be concatenated along axis 0.
        dirs : iterable
            Directories to load; their contents are stacked in order.
        labels : object
            Forwarded unchanged to ``func``.
        batch_size : int
            Number of data rows drawn per batch (before doubling).

        Raises
        ------
        ValueError
            If ``dirs`` yields no entries, since there would be nothing to
            serve.
        """
        # Keras 3's PyDataset expects the base initialiser to be called; on
        # older tf.keras this resolves to object.__init__ and is a no-op.
        super().__init__()
        self.batch_size = batch_size

        data_chunks = []
        param_chunks = []
        for path in dirs:
            # `dir` is the keyword name of the loader's interface; only the
            # local variable is renamed to avoid shadowing the builtin.
            # self.labels ends up holding the labels returned for the last
            # directory.
            data, params, self.labels = func(dir=path, labels=labels)
            data_chunks.append(data)
            param_chunks.append(params)

        if not data_chunks:
            raise ValueError("dirs must contain at least one directory")

        # Concatenate once instead of re-copying the growing array for every
        # additional directory.
        self.data = self._stack(data_chunks)
        self.params = self._stack(param_chunks)

        self.on_epoch_end()

    @staticmethod
    def _stack(chunks):
        """Join per-directory arrays along axis 0 (a single chunk is returned as is)."""
        if len(chunks) == 1:
            return chunks[0]
        return np.concatenate(chunks, axis=0)

    def __len__(self):
        """Return the number of full batches per epoch (a partial tail batch is dropped)."""
        return int(len(self.data) // self.batch_size)

    def __getitem__(self, index):
        """Build batch number ``index`` as ``([data, params], labels)``."""
        start = index * self.batch_size
        batch_indexes = self.indexes[start:start + self.batch_size]
        return self.__data_generation(batch_indexes)

    def on_epoch_end(self):
        """Draw a fresh random row order for the next epoch."""
        self.indexes = np.arange(len(self.data))
        np.random.shuffle(self.indexes)

    def __data_generation(self, indexes_temp):
        """Assemble the matched/mismatched sample pairs for the given row indexes.

        Returns ``[d_X, d_P], d_labels`` where ``d_labels`` is a 1-D array with
        one entry per sample (1 for a matched data/parameter pair, 0 for a
        permuted one), all three shuffled with the same permutation.
        """
        x = np.take(self.data, indexes_temp, axis=0)
        p = np.take(self.params, indexes_temp, axis=0)

        # Size the labels from the rows actually taken rather than from
        # self.batch_size, so a short index slice cannot desynchronise the
        # labels from the samples.
        n_samples = len(x)
        d_labels = np.concatenate([
            np.ones(n_samples, dtype=x.dtype),
            np.zeros(n_samples, dtype=x.dtype),
        ])

        # First half: true pairs. Second half: same data, permuted parameters.
        d_X = np.concatenate([x, x], axis=0)
        d_P = np.concatenate([p, np.random.permutation(p)], axis=0)

        # Mix true and false pairs so the batch is not ordered by label.
        d_X, d_P, d_labels = unison_shuffled_copies(d_X, d_P, d_labels)

        return [d_X, d_P], d_labels

In [None]:
# Parameter names handed to the loader; their count also sets the width of the
# network's parameter input.
labels = ['x', 'y', 'z', 'time', 'azimuth','zenith', 'cascade_energy', 'track_energy']

# One entry per dataset directory. Further directories can be appended to
# combine datasets, e.g. '/localscratch/weldert/140000_i3cols_train/' and
# '/localscratch/weldert/140000_i3cols_valid/'.
train_d = ['/localscratch/weldert/120000_i3cols_train/']
valid_d = ['/localscratch/weldert/120000_i3cols_valid/']

# Single source of truth for the batch size of both generators.
BATCH_SIZE = 2048

training_generator = DataGenerator(load_charges, train_d, labels, batch_size=BATCH_SIZE)
validation_generator = DataGenerator(load_charges, valid_d, labels, batch_size=BATCH_SIZE)

In [None]:
# Network inputs: a 2-component vector per sample and one value per entry of
# `labels`.
charge_input = tf.keras.Input(shape=(2,))
params_input = tf.keras.Input(shape=(len(labels),))

# Project-provided transformation applied to both inputs before the dense
# stack (NOTE(review): its output layout is defined in
# freedom.neural_nets.transformations - confirm there).
t = chargenet_trafo(labels=labels)
h = t(charge_input, params_input)

# Widths of the hidden layers, in order; each Dense layer is followed by a
# Dropout layer with the same rate.
HIDDEN_UNITS = (32, 64, 128, 256, 512, 256, 128, 64, 32)
DROPOUT_RATE = 0.001

for units in HIDDEN_UNITS:
    h = tf.keras.layers.Dense(units, activation="relu")(h)
    h = tf.keras.layers.Dropout(DROPOUT_RATE)(h)

# Single sigmoid unit: the network is trained as a binary classifier.
outputs = tf.keras.layers.Dense(1, activation='sigmoid')(h)

chargenet = tf.keras.Model(inputs=[charge_input, params_input], outputs=outputs)

In [None]:
# Adam with a learning rate of 1e-4; binary cross-entropy matches the single
# sigmoid output of the model.
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
chargenet.compile(
    optimizer=optimizer,
    loss='binary_crossentropy',
    metrics=['accuracy'],
)

In [None]:
# Train for 5 epochs, evaluating on the validation generator after each one;
# the returned History object is kept in `hist`.
hist = chargenet.fit(
    training_generator,
    validation_data=validation_generator,
    epochs=5,
    verbose=1,
)