<h1 align="center"><a href="https://github.com/sborquez/gerumo/">*</a> GeRUMo - Gamma-ray Events Reconstructor with Uncertain models</h1>

<h2 align="center">Train Model</h2>

<center>
<img src="https://upload.wikimedia.org/wikipedia/commons/2/2f/Cta_concept.jpg" width="30%" alt="icon"></img>
</center>





## Setup

The first step is to sync this notebook with Google Drive, and change directory to gerumo repository.

#### Colab Setup

In [None]:
!pip install -q ctaplot

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
%cd "/content/drive/My Drive/Projects/gerumo"
!ls

#### Local Setup

In [None]:
%cd D:\sebas\Google Drive\Projects\gerumo

In [1]:
%cd /home/bapanes/Research-Now/gerumo

/home/bapanes/Research-Now/gerumo


## Load Gerumo and extra modules.

In [2]:
from gerumo import *

import logging
import time
from os import path

import numpy as np
from tensorflow import keras
from tensorflow.keras.utils import plot_model

import json
from tqdm.notebook import tqdm

## Load Configuration file

In [3]:
#config_file = "/content/drive/My Drive/Projects/gerumo/train/config/umonna_sst_colab.json"
#config_file = "/content/drive/My Drive/Projects/gerumo/train/config/umonna_mst_colab.json"
#config_file = "/content/drive/My Drive/Projects/gerumo/train/config/umonna_lst_colab.json"

#config_file = "D:/sebas/Google Drive/Projects/gerumo/train/config/local/alt_az/bmo_mst_local.json"
config_file = "/home/bapanes/Research-Now/gerumo/train/config/local/alt_az/bmo_mst_local_bapanes.json"

In [4]:
print(f"Loading config from: {config_file}")
with open(config_file) as cfg_file:
    config = json.load(cfg_file)

# Model
model_name = config["model_name"]
model_constructor = MODELS[config["model_constructor"]]
model_extra_params = config["model_extra_params"]

# Dataset Parameters
version = config["version"]
output_folder = config["output_folder"]
replace_folder_train = config["replace_folder_train"]
replace_folder_validation = config["replace_folder_validation"]
train_events_csv    = config["train_events_csv"]
train_telescope_csv = config["train_telescope_csv"]
validation_events_csv    = config["validation_events_csv"]
validation_telescope_csv = config["validation_telescope_csv"]

# Input and Target Parameters 
telescope = config["telescope"]
min_observations = config["min_observations"]
input_image_mode = config["input_image_mode"]
input_image_mask = config["input_image_mask"]
input_features = config["input_features"]
targets = config["targets"]
target_mode = config["target_mode"]
target_shapes = config["target_shapes"]
target_domains = config["target_domains"]
#
if config["model_constructor"] == 'umonna':
    target_resolutions = get_resolution(targets, target_domains, target_shapes)

    # Prepare Generator target_mode_config 
    target_mode_config = {
        "target_shapes":      tuple([target_shapes[target]      for target in targets]),
        "target_domains":     tuple([target_domains[target]     for target in targets]),
        "target_resolutions": tuple([target_resolutions[target] for target in targets])
    }
    if target_mode == "probability_map":
        target_sigmas = config["target_sigmas"]
        target_mode_config["target_sigmas"] = tuple([target_sigmas[target] for target in targets])
else:
    target_mode_config = {
        "target_domains":     tuple([target_domains[target]     for target in targets]),
        "target_shapes":      tuple([np.inf                     for target in targets]),
        "target_resolutions": tuple([np.inf                     for target in targets])
    }
    target_resolutions = tuple([np.inf      for target in targets])

# Training Parameters
batch_size = config["batch_size"]
epochs = config["epochs"]
loss = config["loss"]
optimizer = config["optimizer"]["name"].lower()
learning_rate = config["optimizer"]["learning_rate"]
optimizer_parameters = config["optimizer"]["extra_parameters"]
optimizer_parameters = {} if optimizer_parameters is None else optimizer_parameters
save_checkpoints = config["save_checkpoints"]

# Debug
save_plot = config["save_plot"]
plot_only = config["plot_only"]
summary = config["summary"]

Loading config from: /home/bapanes/Research-Now/gerumo/train/config/local/alt_az/bmo_mst_local_bapanes.json


In [5]:
import pprint
pprint.pprint(config, width=1)

{'batch_size': 64,
 'epochs': 50,
 'input_features': ['x',
                    'y'],
 'input_image_mask': True,
 'input_image_mode': 'simple-shift',
 'loss': 'mse',
 'min_observations': 1,
 'model_constructor': 'bmo_unit',
 'model_extra_params': {'dropout_rate': 0.25,
                        'latent_variables': 600},
 'model_name': 'BMO_UNIT_MST',
 'optimizer': {'extra_parameters': {'momentum': 0.01,
                                    'nesterov': True},
               'learning_rate': 0.001,
               'name': 'sgd'},
 'output_folder': '/home/bapanes/Research-Now/local/ml-valpo-local/umonna/dataset/ML1/train_output_models/',
 'plot_only': False,
 'replace_folder_test': '...',
 'replace_folder_train': '/home/bapanes/Research-Now/local/ml-valpo-local/umonna/dataset/ML1/raw_data_generation_list/',
 'replace_folder_validation': '/home/bapanes/Research-Now/local/ml-valpo-local/umonna/dataset/ML1/raw_data_generation_list/',
 'save_checkpoints': True,
 'save_plot': False,
 'summary': Fal

## Load Dataset

In [6]:
import pandas as pd

In [7]:
# Prepare datasets
train_dataset      = load_dataset(train_events_csv, train_telescope_csv, replace_folder_train)
validation_dataset = load_dataset(validation_events_csv, validation_telescope_csv, replace_folder_validation)

train_dataset = aggregate_dataset(train_dataset, az=True, log10_mc_energy=True)
train_dataset = filter_dataset(train_dataset, telescope, min_observations, target_domains)

validation_dataset = aggregate_dataset(validation_dataset, az=True, log10_mc_energy=True, hdf5_file=True)
validation_dataset = filter_dataset(validation_dataset, telescope, min_observations, target_domains)

# Preprocessing pipes
preprocess_input_pipes = []
preprocess_output_pipes = []

In [8]:
train_dataset.groupby("type")[["alt", "az","log10_mc_energy"]].describe()

Unnamed: 0_level_0,alt,alt,alt,alt,alt,alt,alt,alt,az,az,az,az,az,log10_mc_energy,log10_mc_energy,log10_mc_energy,log10_mc_energy,log10_mc_energy,log10_mc_energy,log10_mc_energy,log10_mc_energy
Unnamed: 0_level_1,count,mean,std,min,25%,50%,75%,max,count,mean,...,75%,max,count,mean,std,min,25%,50%,75%,max
type,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2
MST_FlashCam,3841.0,1.217563,0.041101,1.069269,1.185915,1.217665,1.24741,1.3529,3841.0,-0.008639,...,0.086528,0.499082,3841.0,-0.135418,0.772045,-1.989464,-0.747702,-0.274761,0.339106,2.160553


In [9]:
validation_dataset.groupby("type")[["alt", "az","log10_mc_energy"]].describe()

Unnamed: 0_level_0,alt,alt,alt,alt,alt,alt,alt,alt,az,az,az,az,az,log10_mc_energy,log10_mc_energy,log10_mc_energy,log10_mc_energy,log10_mc_energy,log10_mc_energy,log10_mc_energy,log10_mc_energy
Unnamed: 0_level_1,count,mean,std,min,25%,50%,75%,max,count,mean,...,75%,max,count,mean,std,min,25%,50%,75%,max
type,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2
MST_FlashCam,366.0,1.222962,0.044939,1.101235,1.191892,1.222966,1.257264,1.334297,366.0,-0.004328,...,0.100594,0.367815,366.0,-0.318261,0.654453,-1.711332,-0.92434,-0.360779,0.20917,2.157028


## Data Generators

In [10]:
# Generators
train_generator = AssemblerUnitGenerator(train_dataset, batch_size, 
                                        input_image_mode=input_image_mode, 
                                        input_image_mask=input_image_mask, 
                                        input_features=input_features,
                                        targets=targets,
                                        target_mode=target_mode, 
                                        target_mode_config=target_mode_config,
                                        preprocess_input_pipes=preprocess_input_pipes,
                                        preprocess_output_pipes=preprocess_output_pipes,
                                        version=version
                                        )
validation_generator = AssemblerUnitGenerator(validation_dataset, batch_size//2, 
                                                input_image_mode=input_image_mode,
                                                input_image_mask=input_image_mask, 
                                                input_features=input_features,
                                                targets=targets,
                                                target_mode=target_mode, 
                                                target_mode_config=target_mode_config,
                                                preprocess_input_pipes=preprocess_input_pipes,
                                                preprocess_output_pipes=preprocess_output_pipes,
                                                version=version
                                            )

### Generator sample

In [11]:
i = 4
batch_i = train_generator[i]
print(f"Batch size: {batch_i[0][0].shape[0]}")
print()
print(f"Input mode: {input_image_mode}")
print(f"Input image shape: {batch_i[0][0].shape[1:]}")
print(f"Input features: {input_features}")
print(f"Input features shape: {batch_i[0][1].shape[1:]}")
print()
print(f"Target mode: {target_mode}")
print(f"Target shape: {batch_i[1].shape[1:]}")

img_sample =batch_i[0][0][0]
feature_sample = batch_i[0][1][0]
target_sample = batch_i[1][0]

#show_input_sample(img_sample, input_image_mode, feature_sample, make_simple=False)

#show_target_sample(target_sample, targets, target_mode, target_domains)

Batch size: 64

Input mode: simple-shift
Input image shape: (2, 84, 29, 3)
Input features: ['x', 'y']
Input features shape: (2,)

Target mode: lineal
Target shape: (2,)


In [12]:
# Debug: Check if train dataset is loadable
for i in tqdm(range(len(train_generator))):
  _ = train_generator[i]

# Debug: Check if train dataset is loadable
for i in tqdm(range(len(validation_generator))):
  _ = validation_generator[i]

HBox(children=(FloatProgress(value=0.0, max=60.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))




## Compile model 

In [13]:
output_folder

'/home/bapanes/Research-Now/local/ml-valpo-local/umonna/dataset/ML1/train_output_models/'

In [14]:
model_name

'BMO_UNIT_MST'

In [15]:
telescope

'MST_FlashCam'

In [16]:
loss

'mse'

In [17]:
# CallBacks
callbacks = []
save_checkpoints = True

if save_checkpoints:
#if False:
    # Checkpoint parameters
    checkpoint_filepath = "%s_%s_%s_e{epoch:03d}_{val_loss:.4f}.h5"%(model_name, telescope, loss)
    checkpoint_filepath = path.join(output_folder, checkpoint_filepath)
    callbacks.append(
        keras.callbacks.ModelCheckpoint(checkpoint_filepath, monitor='val_loss', 
              verbose=1, save_weights_only=False, mode='min', save_best_only=True))

# Train
input_img_shape = INPUT_SHAPE[f"{input_image_mode}-mask" if input_image_mask else input_image_mode][telescope]
input_features_shape = (len(input_features),)
target_shapes = target_mode_config["target_shapes"]
model = model_constructor(telescope, input_image_mode, input_image_mask, 
                input_img_shape, input_features_shape,
                targets, target_mode, 
                target_shapes=target_shapes, 
                **model_extra_params)

Cause: mangled names are not yet supported
Cause: mangled names are not yet supported


### Summary model

In [18]:
model.summary()

Model: "BMO_Unit_MST_FlashCam"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
image_input (InputLayer)        [(None, 2, 84, 29, 3 0                                            
__________________________________________________________________________________________________
encoder_hex_conv_layer (HexConv (None, 41, 13, 32)   896         image_input[0][0]                
__________________________________________________________________________________________________
encoder_conv_layer_1_a (Conv2D) (None, 41, 13, 32)   25632       encoder_hex_conv_layer[0][0]     
__________________________________________________________________________________________________
encoder_ReLU_1_a (Activation)   (None, 41, 13, 32)   0           encoder_conv_layer_1_a[0][0]     
______________________________________________________________________________

In [19]:
plot_model(model, to_file="./model.png", show_shapes=True)

Failed to import pydot. You must install pydot and graphviz for `pydotprint` to work.


### Loss Function

In [None]:
## Loss function
loss = loss if loss.split('_')[-1] == 'loss' else f'{loss}_loss'
if loss == "crossentropy":
    loss_ = LOSS[loss](dimensions=len(targets))
elif loss == "distance":
    loss_ = mean_distance_loss(target_shapes)
else:
    loss_ = LOSS[loss]()

### Optimizer

In [None]:
optimizer_ = OPTIMIZERS[optimizer](
    learning_rate=learning_rate,
    **optimizer_parameters
)

## Fit Model

In [None]:
model.compile(
    optimizer=optimizer_,
    loss=loss_
)

In [None]:
start_time = time.time()
history = model.fit(
    train_generator,
    epochs = epochs,
    verbose = 1,
    validation_data = validation_generator,
    validation_steps = len(validation_generator),
    callbacks = callbacks,
    #use_multiprocessing = True,
    #workers = 4,
    max_queue_size = 75
)
training_time = (time.time() - start_time)/60.0

## Show some predictions (DEPRECATED)

In [None]:
'''
import matplotlib.pyplot as plt

# Validate
s = 0
for j in range(0,50,2):
  s += 1
  batch_0 = validation_generator[j]
  i=np.random.randint(len(batch_0[1]))
  #prediction = model.predict(batch_0[0])[i]
  prediction = BMO.bayesian_estimation(model, batch_0[0], 100, 0)[0][i]
  target = batch_0[1][i]
  input_img_e = batch_0[0][0][i][0][:,:,0]
  input_img_t = batch_0[0][0][i][0][:,:,1]
  input_img_m = batch_0[0][0][i][0][:,:,2]

  # TODO: Add plots to viz module
  fig, (ax0, ax1, ax2) = plt.subplots(1, 3, figsize=(8, 4))
  fig.suptitle(f'validation sample {s}')
  ax0.imshow(input_img_e)
  ax0.set_title(f"Input charge")
  ax1.imshow(input_img_t)
  ax1.set_title(f"Input peak")
  ax2.imshow(input_img_m)
  ax2.set_title(f"Input mask")
  plt.show()

  _, (ax2, ax3)  = plt.subplots(1, 2, figsize=(8, 4))
  ax2.imshow(target, cmap="jet")
  ax2.set_title(f"Target")

  #ax3.imshow(prediction, cmap="jet")
  #ax3.set_title("Prediction")
  show_pdf_2d(prediction, prediction_point, targets, target_domains, targets_values=None, axis=ax3)
  plt.show()
  ''';