In [None]:
import os

# Restrict this process to GPUs 0-3. This is set *before* TensorFlow is
# imported so it is guaranteed to be in place before CUDA is initialised, and
# it uses the documented plain comma-separated form (no spaces).
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"

import sys

import numpy as np
import tensorflow as tf

from freedom.utils.dataset import Data, DataGenerator
from freedom.utils.i3cols_dataloader import load_hits, load_charges
from freedom.neural_nets.hitnet import get_hitnet
from freedom.neural_nets.chargenet import get_chargenet

In [None]:
# Parameter names passed to get_hitnet and to the data loaders below.
labels = ['x', 'y', 'z', 'time', 'azimuth', 'zenith', 'cascade_energy', 'track_energy']
LEARNING_RATE = 1e-4

# MirroredStrategy replicates the model onto every visible GPU.
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    # The optimizer must be created inside the strategy scope together with
    # the model; creating it outside can place its (slot) variables under a
    # different distribution scope than the model variables.
    optimizer = tf.keras.optimizers.Adam(LEARNING_RATE)
    hitnet = get_hitnet(labels)
    hitnet.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])

# Number of replicas (GPUs); used below to scale the global batch size.
nGPUs = strategy.num_replicas_in_sync

### Training HitNet with a `DataGenerator`

In [None]:
# Input directories for the training and validation splits.
train_d = [
    '/tf/localscratch/weldert/120000_i3cols_train/',
    '/tf/localscratch/weldert/140000_i3cols_train/',
]
valid_d = [
    '/tf/localscratch/weldert/120000_i3cols_valid/',
    '/tf/localscratch/weldert/140000_i3cols_valid/',
]

# 4096 samples per replica; the global batch therefore scales with the GPU count.
global_batch_size = 4096 * nGPUs

training_generator = DataGenerator(load_hits, train_d, labels, batch_size=global_batch_size)
validation_generator = DataGenerator(load_hits, valid_d, labels, batch_size=global_batch_size)

In [None]:
# Train HitNet on the DataGenerator pipeline, validating after each epoch.
hist = hitnet.fit(
    x=training_generator,
    validation_data=validation_generator,
    epochs=5,
    verbose=1,
)

### Training HitNet with the `Data` dataset pipeline

In [None]:
# Build train/test splits from the training directories via the Data helper.
data = Data(dirs=[
    '/tf/localscratch/weldert/120000_i3cols_train',
    '/tf/localscratch/weldert/140000_i3cols_train',
])

# 4096 samples per replica, matching the DataGenerator cells.
dataset_batch_size = 4096 * nGPUs

# NOTE(review): this requests *chargenet* data, but the result is fed to
# `hitnet.fit` below and no chargenet model is built in this notebook.
# Confirm whether a hitnet-specific accessor on `Data` was intended.
train_data, test_data = data.get_chargenet_data(
    train_batch_size=dataset_batch_size,
    test_batch_size=dataset_batch_size,
)

In [None]:
# Train on the Data-based pipeline. If the DataGenerator cells above were run,
# this continues training the same `hitnet` instance rather than starting fresh.
hist = hitnet.fit(
    x=train_data,
    validation_data=test_data,
    epochs=10,
    verbose=1,
)