### Intro

In this notebook, we go through the actual modelling (the most exciting part!) of the light curves with the AttnLNP (also called the Latte model). There are many tunable parameters, all of which are defined in the first code cell, before anything is run.

### Defining Important Hyperparameters

In [None]:
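import torch  #Needed here for torch.device below
import pandas as pd  #Needed here for reading the parameters CSV below

#The suffix to attach to the output files
file_name_output = 'LC'
#The suffix of the dataset folders (dataset_{file_name}/train and /val); must match the suffix used when splitting the data
file_name = 'LC'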
#The batch size for the input data
BATCH_SIZE = 8

#Model hyperparameters
encoding_size = 128 #Encoder MLP layer size
latent_size = 128 #Latent dimension size

attention_type = 'scaledot' #Can also use multihead, but scaledot works better
cross_attention = True #Whether to include cross-attention in the deterministic path
self_attention = True #Whether to include self-attention in both paths

lstm_layers = 0 #The number of LSTM layers to use for pre-encoding
lstm_size = 32 #The size of the LSTM layer

use_scheduler = False # Whether to use a learning rate scheduler (has not been found to be effective)
replace_lstm_with_gru = False # Whether to use a GRU instead of an LSTM
bidirectional = False #Whether to use bidirectional LSTM/GRU layers
lstm_agg = False #Whether to aggregate the latent space representations via an LSTM instead of mean pooling
augment = True #Whether to augment the input data by randomly adding or subtracting the error on the fly
activation = 'relu' #Can also be 'leaky' for LeakyReLU, but ReLU seems to work better
lr = 1e-3 #The learning rate for the Adam optimizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") #The device for training: CUDA if available, otherwise CPU

num_epochs = 1000 #The maximum number of epochs to train per run
num_runs = 1 #The number of training runs. Each run trains for up to another num_epochs unless it stops early
early_stopping_limit = 500 #The number of epochs without improvement before training stops early
validation_epochs = 10 #The number of epochs between checks of the validation data

beta_tf = 0.1 #The weight of the transfer function (TF) term in the loss
transfer_function_length = 0 #The length of the transfer function of the data. If there is no transfer function, it should be 0
tf_folder = None #Path to the folder of the transfer functions
#tf_folder = 'Transfer_Functions/'
#tf_folder = 'Transfer_Functions/band_name'


param_df_path = None #Path to the parameters CSV. If there is no parameters dataframe, it should be None
#param_df_path = 'Parameters.csv'

if param_df_path is not None:
    param_df = pd.read_csv(param_df_path)
    param_beta = 0.1 #The weight of the parameter term in the loss
    param_columns=['Log_Mass','Inclination','Log_Tau','z','Eddington_Ratio','SFinf'] #Change to the names of your columns
    param_length = len(param_columns)
else:
    param_df = None
    param_beta = 0
    param_columns = []
    param_length = 0
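With `augment = True`, the input data are augmented on the fly by randomly adding or subtracting the measurement error. A minimal sketch of that idea, assuming each flux point is shifted by exactly plus or minus its error with a random sign (the package's exact scheme may differ, and `augment_fluxes` is a hypothetical helper):

```python
import random


def augment_fluxes(fluxes, errors, rng=None):
    """Hypothetical sketch: shift each flux by +error or -error, with the sign chosen at random.

    This is an assumed scheme for illustration, not the QNPy_Latte implementation.
    """
    rng = rng or random.Random()
    return [f + rng.choice((-1.0, 1.0)) * e for f, e in zip(fluxes, errors)]


# Example: one light curve with three points and their errors
fluxes = [1.0, 2.0, 3.0]
errors = [0.1, 0.2, 0.3]
print(augment_fluxes(fluxes, errors, random.Random(0)))
```

Because a fresh draw would be made every time a batch is loaded, the model would see a slightly different version of each light curve in every epoch.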

### Importing Relevant Packages

In [None]:
import QNPy_Latte.SPLITTING_AND_TRAINING as st #Importing the SPLITTING_AND_TRAINING module from the package
from QNPy_Latte.SPLITTING_AND_TRAINING import * #Importing all names from the SPLITTING_AND_TRAINING module

### Creating the Output Folder and Defining Paths

In [None]:
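import os  #Also available through the star import above; imported explicitly for clarity

os.makedirs(f'output_{file_name_output}', exist_ok=True) #Folder for the model and the loss histories
#file_name is the dataset suffix used when the data were split into train and val folders
DATA_PATH_TRAIN = f"./dataset_{file_name}/train" #path to train folder
DATA_PATH_VAL = f"./dataset_{file_name}/val" #path to val folder

MODEL_PATH = f"./output_{file_name_output}/model_{file_name_output}.pth" #path for saving model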

### Initializing model, optimizer and metrics

In [None]:
model, optimizer, scheduler, criterion, mseMetric, maeMetric = st.create_model_and_optimizer(
    device, encoding_size, latent_size,
    attention=cross_attention, self_attention=self_attention,
    use_scheduler=use_scheduler, lstm_layers=lstm_layers,
    lstm_size=lstm_size, transfer_function_length=transfer_function_length,
    parameters_length=param_length, classes=0,
    replace_lstm_with_gru=replace_lstm_with_gru,
    bidirectional=bidirectional, activation=activation,
    lr=lr, attention_type=attention_type)
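Setting `attention_type = 'scaledot'` selects scaled dot-product attention, in which each target point weights the context points by a softmax over scaled dot products. A minimal pure-Python sketch of softmax(QK^T / sqrt(d)) V, assuming plain lists of vectors and no learned projections (the model itself presumably applies learned layers first; `scaled_dot_attention` is a hypothetical helper, not the package's API):

```python
import math


def scaled_dot_attention(queries, keys, values):
    """Compute softmax(Q K^T / sqrt(d)) V for lists of vectors (no learned projections)."""
    d = len(keys[0])
    out = []
    for q in queries:
        # Similarity of this query to every key, scaled by sqrt(d)
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in keys]
        m = max(scores)  # subtract the max for numerical stability
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        weights = [w / total for w in weights]  # softmax: weights sum to 1
        # Weighted average of the values
        out.append([sum(w * v[i] for w, v in zip(weights, values))
                    for i in range(len(values[0]))])
    return out


# Hypothetical example: 2 target times attending over 3 context points (1-D features)
context_x = [[0.0], [1.0], [2.0]]
context_y = [[0.5], [1.5], [2.5]]
target_x = [[0.5], [1.5]]
print(scaled_dot_attention(target_x, context_x, context_y))
```

Since the softmax weights sum to one, each output is a weighted average of the context values, which is what lets the deterministic path focus on the most relevant observations.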


### Training Model

The AttnLNP model tracks several losses and metrics, each recorded separately for the training and the validation data:

- History_Loss - The overall loss used to train the model
- History_mse - The Mean Squared Error
- History_mae - The Mean Absolute Error
- Epoch_Counter_Loss - The epoch counter for the overall loss
- Epoch_Counter_mse - The epoch counter for the MSE
- Epoch_Counter_mae - The epoch counter for the MAE
- History_Loss_Reconstruction - The reconstruction loss (a Gaussian LogProbLoss)
- History_Loss_TF - The transfer function loss (a Gaussian LogProbLoss)
- History_Loss_param - The loss associated with the parameters (a Gaussian LogProbLoss)
- History_Loss_classes - The classification loss (unused here, since `classes = 0` and `beta_classifier = 0`)
- History_KL_loss - The KL divergence between the posterior and prior latent distributions, which keeps the sampling from the latent space coherent
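To make these terms concrete, here is a minimal sketch of how a Gaussian log-probability loss, a KL divergence between diagonal Gaussian latent distributions, and the MSE/MAE metrics could be computed and combined. The weighting in `total_loss` (reconstruction + KL + `beta_tf` * TF loss + `param_beta` * parameter loss) is an assumption for illustration, not taken from QNPy_Latte, and all function names are hypothetical.

```python
import math


def gaussian_nll(y, mu, sigma):
    """Mean negative log-likelihood of y under N(mu, sigma^2)."""
    terms = [0.5 * math.log(2 * math.pi * s ** 2) + (yi - m) ** 2 / (2 * s ** 2)
             for yi, m, s in zip(y, mu, sigma)]
    return sum(terms) / len(terms)


def kl_diag_gaussians(mu_q, sigma_q, mu_p, sigma_p):
    """KL(q || p) between diagonal Gaussians, summed over latent dimensions."""
    return sum(math.log(sp / sq) + (sq ** 2 + (mq - mp) ** 2) / (2 * sp ** 2) - 0.5
               for mq, sq, mp, sp in zip(mu_q, sigma_q, mu_p, sigma_p))


def total_loss(rec, kl, tf=0.0, param=0.0, beta_tf=0.1, beta_param=0.0):
    """Hypothetical weighted sum of the loss terms (assumed weighting)."""
    return rec + kl + beta_tf * tf + beta_param * param


# Example: two observed points and the model's predicted mean and standard deviation
y, mu, sigma = [1.0, 2.0], [1.1, 1.9], [0.2, 0.2]
rec = gaussian_nll(y, mu, sigma)
kl = kl_diag_gaussians([0.1, -0.2], [0.9, 1.1], [0.0, 0.0], [1.0, 1.0])
mse = sum((a - b) ** 2 for a, b in zip(y, mu)) / len(y)
mae = sum(abs(a - b) for a, b in zip(y, mu)) / len(y)
print(total_loss(rec, kl), mse, mae)
```

The MSE and MAE only compare the predicted mean with the data, whereas the Gaussian log-probability loss also rewards well-calibrated uncertainties.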

In [None]:
#Train the model and return all the losses associated with it
(history_loss_train, history_loss_val,
 history_mse_train, history_mse_val,
 history_mae_train, history_mae_val,
 epoch_counter_train_loss, epoch_counter_train_mse, epoch_counter_train_mae,
 epoch_counter_val_loss, epoch_counter_val_mse, epoch_counter_val_mae,
 history_loss_reconstruction_train, history_loss_reconstruction_val,
 history_loss_tf_train, history_loss_tf_val,
 history_loss_param_train, history_loss_param_val,
 history_loss_classes_train, history_loss_classes_val,
 history_kl_loss_train, history_kl_loss_val) = st.train_model(
    model, criterion, optimizer, scheduler, num_runs, num_epochs, early_stopping_limit,
    mseMetric, maeMetric, device, DATA_PATH_TRAIN, DATA_PATH_VAL, BATCH_SIZE,
    beta_classifier=0, beta_tf=beta_tf, beta_param=param_beta, tf_dir=tf_folder,
    param_df=param_df, param_columns=param_columns, augment=augment,
    validation_epochs=validation_epochs)
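The interplay of `num_epochs`, `validation_epochs` and `early_stopping_limit` can be illustrated with a hypothetical sketch. It assumes the validation loss is checked every `validation_epochs` epochs and that training stops once `early_stopping_limit` epochs pass without a new best validation loss; how `st.train_model` implements this internally may differ, and `epochs_trained` is not part of the package.

```python
def epochs_trained(val_loss_at, num_epochs, early_stopping_limit, validation_epochs):
    """Hypothetical early-stopping loop driven by periodic validation.

    val_loss_at(epoch) returns the validation loss at that epoch.
    Returns (epoch at which training stopped, epoch of the best validation loss).
    """
    best = float("inf")
    best_epoch = 0
    for epoch in range(1, num_epochs + 1):
        # ... one training epoch would run here ...
        if epoch % validation_epochs == 0:
            val = val_loss_at(epoch)
            if val < best:
                best, best_epoch = val, epoch  # the best model would be checkpointed here
        if epoch - best_epoch >= early_stopping_limit:
            return epoch, best_epoch  # no improvement for early_stopping_limit epochs
    return num_epochs, best_epoch


# A validation loss that keeps improving until epoch 200 and then plateaus
print(epochs_trained(lambda e: max(1.0 - e / 200, 0.0), 1000, 500, 10))
```

Under these assumptions, with the settings above (1000 epochs, a limit of 500, validation every 10 epochs), a validation loss that stops improving at epoch 200 would end training at epoch 700.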

### Saving

In [None]:
st.save_model(model, MODEL_PATH) #Save the trained model

In [None]:
# Map each output file name to its history list (same order as the train_model outputs)
histories = {
    "history_loss_train": history_loss_train, "history_loss_val": history_loss_val,
    "history_mse_train": history_mse_train, "history_mse_val": history_mse_val,
    "history_mae_train": history_mae_train, "history_mae_val": history_mae_val,
    "epoch_counter_train_loss": epoch_counter_train_loss, "epoch_counter_train_mse": epoch_counter_train_mse,
    "epoch_counter_train_mae": epoch_counter_train_mae, "epoch_counter_val_loss": epoch_counter_val_loss,
    "epoch_counter_val_mse": epoch_counter_val_mse, "epoch_counter_val_mae": epoch_counter_val_mae,
    "history_loss_reconstruction_train": history_loss_reconstruction_train,
    "history_loss_reconstruction_val": history_loss_reconstruction_val,
    "history_loss_tf_train": history_loss_tf_train, "history_loss_tf_val": history_loss_tf_val,
    "history_loss_param_train": history_loss_param_train, "history_loss_param_val": history_loss_param_val,
    "history_loss_classes_train": history_loss_classes_train, "history_loss_classes_val": history_loss_classes_val,
    "history_kl_loss_train": history_kl_loss_train, "history_kl_loss_val": history_kl_loss_val,
}

# Define the file names (one CSV per history) and the matching lists
file_names = [f"output_{file_name_output}/{name}.csv" for name in histories]
lists = list(histories.values())

# Save all lists with histories
save_list = st.save_lists_to_csv(file_names, lists)