In [1]:
import tensorflow as tf
#from tensorflow.keras import layers, models, optimizers

import aux_code
import numpy as np

import os

#import matplotlib.pyplot as plt
import pickle

# Define folders

In [2]:
# Input location: folder containing the dataset
dataset_folder = os.path.join('.', 'dataset')

# Output locations: training history and optimized MLP weights
history_folder = os.path.join('.', 'history')
model_folder   = os.path.join('.', 'model')

# Hand each output folder to the project helper, history first, then model
for output_folder in (history_folder, model_folder):
    aux_code.create_folder(output_folder)

# Define MLP

In [3]:
# Build the MLP through the project helper; the architecture itself lives in
# aux_code.define_mlp and is not visible in this notebook. The returned object
# is compiled and trained further down with the Keras compile/fit API.
model = aux_code.define_mlp()

2024-04-23 09:43:30.254425: I tensorflow/core/platform/cpu_feature_guard.cc:145] This TensorFlow binary is optimized with Intel(R) MKL-DNN to use the following CPU instructions in performance critical operations:  SSE4.1 SSE4.2
To enable them in non-MKL-DNN operations, rebuild TensorFlow with the appropriate compiler flags.
2024-04-23 09:43:30.254749: I tensorflow/core/common_runtime/process_util.cc:115] Creating new thread pool with default inter op setting: 8. Tune using inter_op_parallelism_threads for best performance.


# Load and normalize dataset

In [4]:
# Locations of the two .npz archives inside the dataset folder
training_path   = os.path.join(dataset_folder, 'training_set.npz')
validation_path = os.path.join(dataset_folder, 'validation_set.npz')

# Training set: read both arrays from the archive, then pass them to the
# project normalization helper, which returns the (x, y) pair used for fitting
training_set  = np.load(training_path)
train_inputs  = training_set['nn_data_train']
train_targets = training_set['stx_flare_loc_train']
x_train, y_train = aux_code.normalize_dataset(train_inputs, train_targets)

# Validation set: same procedure, applied to the validation archive
validation_set = np.load(validation_path)
valid_inputs   = validation_set['nn_data_valid']
valid_targets  = validation_set['stx_flare_loc_valid']
x_valid, y_valid = aux_code.normalize_dataset(valid_inputs, valid_targets)

# Train the model

In [5]:
# Output files written by this cell
weights_path = os.path.join(model_folder, 'model_weights.h5')
history_path = os.path.join(history_folder, 'training_history.pkl')

# Compile: Adam optimizer, mean-squared-error loss
model.compile(loss='mse', optimizer='adam')

# Callback 1: multiply the learning rate by 0.5 once val_loss has stalled for
# 10 epochs, with 1e-6 as the lower bound
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,
    patience=10,
    min_lr=1e-6,
)

# Callback 2: write a checkpoint each time val_loss reaches a new minimum
# (save_weights_only is left at its default value)
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=weights_path,
    monitor='val_loss',
    mode='min',
    save_best_only=True,
    verbose=1,
)

# Callback 3: stop training once val_loss has not improved for 20 epochs
early_stopping_callback = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    mode='min',
    patience=20,
    verbose=1,
)

# Fit: at most 1000 epochs with mini-batches of 100 samples; the validation
# pair provides the val_loss value monitored by all three callbacks
training_callbacks = [reduce_lr, checkpoint_callback, early_stopping_callback]
history = model.fit(
    x_train,
    y_train,
    batch_size=100,
    epochs=1000,
    validation_data=(x_valid, y_valid),
    callbacks=training_callbacks,
)

# Persist the per-epoch metrics dictionary
with open(history_path, 'wb') as history_file:
    pickle.dump(history.history, history_file)

Train on 18145 samples, validate on 12824 samples
Epoch 1/1000
Epoch 00001: val_loss improved from inf to 0.00076, saving model to ./model/model_weights.h5
Epoch 2/1000
Epoch 00002: val_loss improved from 0.00076 to 0.00057, saving model to ./model/model_weights.h5
Epoch 3/1000
Epoch 00003: val_loss improved from 0.00057 to 0.00048, saving model to ./model/model_weights.h5
Epoch 4/1000
Epoch 00004: val_loss improved from 0.00048 to 0.00045, saving model to ./model/model_weights.h5
Epoch 5/1000
Epoch 00005: val_loss did not improve from 0.00045
Epoch 6/1000
Epoch 00006: val_loss did not improve from 0.00045
Epoch 7/1000
Epoch 00007: val_loss improved from 0.00045 to 0.00043, saving model to ./model/model_weights.h5
Epoch 8/1000
Epoch 00008: val_loss did not improve from 0.00043
Epoch 9/1000
Epoch 00009: val_loss did not improve from 0.00043
Epoch 10/1000
Epoch 00010: val_loss improved from 0.00043 to 0.00042, saving model to ./model/model_weights.h5
Epoch 11/1000
Epoch 00011: val_loss d