# Training a Next Event Predictor

For discovering a process model, a machine learning model that can predict the next event in a case is required.
Below, you can train either one of several BINet versions (V1–V3) or a Transformer model, configured with the following parameters:
* *Model*: The type of model to train.
* *Event Log*: The event log on which to train the model.
* *Use Event Attributes*: Whether the model should also be trained to predict the event attributes.
* *Epochs*: The maximum number of epochs to train the model for.
* *Batch Size*: The batch size when training.
* *Validation Split*: The fraction (between 0 and 1) of the training data held out to compute the validation loss, which is used for early stopping.
* *Patience*: The patience for early stopping.
* *Delta*: The minimum delta for early stopping.
* *Layers*: The number of encoding layers of the Transformer model (only shown and used when *Model* is Transformer).
* *Heads*: The number of attention heads in each encoding layer of the Transformer (only shown and used when *Model* is Transformer).

In [1]:
import os
import sys

# Silence TensorFlow's C++ logging. TensorFlow reads this variable when its
# native library loads (and caches the level on its first log message), so it
# has to be set BEFORE importing anything that pulls in TensorFlow --
# presumably the april / r2pa imports below (BINet / Transformer); confirm.
# Setting it after those imports does not suppress the start-up messages.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

# Make the project root (the parent of the notebook's working directory)
# importable so that the april / r2pa packages can be found.
sys.path.append(os.path.dirname(os.getcwd()))

from ipywidgets import widgets, interact, interact_manual, Layout, Button, Box
from IPython.display import display

from april.fs import MODEL_DIR, DATE_FORMAT, EVENTLOG_DIR
from r2pa.api import routines
from april.alignments.binet import BINet
from april import Dataset

import arrow

# --- Parameter widgets --------------------------------------------------------
model_name_widget = widgets.Dropdown(
    description='Model',
    options=['BINetV1', 'BINetV2', 'BINetV3', 'Transformer'],
)
# No options yet; on_reload() fills them from the event log directory.
event_log_widget = widgets.Dropdown(description='Event Log')
batch_size_widget = widgets.IntText(
    description='Batch Size',
    value=50,
)
validation_split_widget = widgets.FloatSlider(
    description='Valid. Split',
    min=0,
    max=1,
    step=0.01,
    value=0.1,
)
epochs_widget = widgets.IntText(
    description='Epochs',
    value=100,
)
patience_widget = widgets.IntText(
    description='Patience',
    value=5,
)
delta_widget = widgets.FloatText(
    description='Delta',
    value=0.01,
)
# Transformer-only sliders: start disabled and are toggled by on_value_change().
number_layers_widget = widgets.IntSlider(
    description='Layers',
    min=1,
    max=6,
    value=3,
    disabled=True,
)
number_heads_widget = widgets.IntSlider(
    description='Heads',
    min=1,
    max=6,
    value=3,
    disabled=True,
)
event_attributes_widget = widgets.Checkbox(
    description='Use Event Attributes',
    value=True,
)

# --- Action buttons -----------------------------------------------------------
train_button = widgets.Button(description="Train")
reload_button = widgets.Button(description="Reload Logs")

# --- Layout: one horizontal row per group of related controls -----------------
first_row = widgets.HBox([
    model_name_widget,
    event_log_widget,
    event_attributes_widget,
])
second_row = widgets.HBox([
    epochs_widget,
    batch_size_widget,
    validation_split_widget,
])
third_row = widgets.HBox([
    patience_widget,
    delta_widget,
])
fourth_row = widgets.HBox([
    number_layers_widget,
    number_heads_widget,
])
fifth_row = widgets.HBox([
    train_button,
    reload_button,
])

# Two output areas: one hosts the parameter form, the other training logs.
parameter_gui = widgets.Output()
output = widgets.Output()

with parameter_gui:
    display(widgets.VBox([
        first_row,
        second_row,
        third_row,
        fourth_row,
        fifth_row,
    ]))

def get_all_event_logs(directory=None):
    """Return the names of all event logs stored as ``.json.gz`` files.

    Parameters
    ----------
    directory : str or path-like, optional
        Directory to scan. Defaults to ``EVENTLOG_DIR``.

    Returns
    -------
    list of str
        File names with the ``.json.gz`` suffix stripped, sorted
        alphabetically. ``os.listdir`` returns entries in arbitrary order,
        and sorting keeps the dropdown order stable.
    """
    suffix = '.json.gz'
    if directory is None:
        directory = EVENTLOG_DIR
    return sorted(name[:-len(suffix)] for name in os.listdir(directory) if name.endswith(suffix))

def on_reload(button):
    """Refresh the event-log dropdown from disk and keep the current selection.

    Assigning new ``options`` to the Dropdown can reset its selection. The
    previously selected log is therefore restored if it is still present.

    Parameters
    ----------
    button : widgets.Button or None
        The clicked button (unused); ``None`` when called directly.
    """
    selected = event_log_widget.value
    event_logs = get_all_event_logs()
    event_log_widget.options = event_logs
    if selected in event_logs:
        event_log_widget.value = selected

def on_value_change(change):
    """Enable and show the Transformer-only sliders iff 'Transformer' is chosen.

    Observer for the ``value`` trait of ``model_name_widget``;
    ``change['new']`` is the newly selected model name.
    """
    transformer_selected = change['new'] == 'Transformer'
    for slider in (number_layers_widget, number_heads_widget):
        slider.disabled = not transformer_selected
    for slider in (number_heads_widget, number_layers_widget):
        slider.layout.visibility = 'visible' if transformer_selected else 'hidden'
    
# The Transformer-only sliders start hidden; on_value_change() reveals them
# whenever the model dropdown switches to 'Transformer'.
number_heads_widget.layout.visibility = number_layers_widget.layout.visibility = 'hidden'
model_name_widget.observe(on_value_change, names='value')

def _collect_training_parameters():
    """Read the hyper-parameters shared by all model types from the widgets."""
    return {'event_log': event_log_widget.value,
            'epochs': int(epochs_widget.value),
            'batch_size': int(batch_size_widget.value),
            'early_stopping_patience': int(patience_widget.value),
            'early_stopping_delta': float(delta_widget.value),
            'validation_split': validation_split_widget.value,
            'use_event_attributes': event_attributes_widget.value,
            'smoothing_extend': 0.0}


def _train_transformer(event_log, parameters, start_time):
    """Add the Transformer-specific parameters and train a Transformer on ``event_log``."""
    parameters['number_layers'] = number_layers_widget.value
    parameters['number_heads'] = number_heads_widget.value

    output_name = (f'{event_log}_{parameters["number_layers"]}TR{parameters["number_heads"]}'
                   f'_{start_time.format(DATE_FORMAT)}')
    print(f"Upon completion of the training, the Transformer will be stored at {MODEL_DIR / output_name}")
    routines.train_transformer(output_locations=[output_name], event_log=event_log, parameters=parameters)


def _train_binet(event_log, version, parameters, start_time):
    """Build the dataset and a BINet of the given ``version``, then train it on ``event_log``."""
    # Only the BINet needs the Dataset; loading it here avoids that cost for the Transformer.
    dataset = Dataset(event_log, use_event_attributes=parameters['use_event_attributes'], use_case_attributes=False)
    (present_activity, present_attribute), combination = routines.get_present_setting(version)
    binet = BINet(dataset, use_event_attributes=parameters['use_event_attributes'], use_case_attributes=False,
                  use_present_activity=present_activity, use_present_attributes=present_attribute)

    output_name = f'{event_log}_{binet.name}{combination}_{start_time.format(DATE_FORMAT)}'
    print(f"Upon completion of the training, the BINet will be stored as {output_name}")
    routines.train_binet(output_locations=[output_name], event_log=event_log, version=version,
                         parameters=parameters)


def train_model(button):
    """Train the model selected in the GUI on the selected event log.

    Click handler for the "Train" button. All progress and error output is
    written to the ``output`` widget, which is cleared first.

    Parameters
    ----------
    button : widgets.Button
        The clicked button (unused).
    """
    parameters = _collect_training_parameters()
    start_time = arrow.now()
    event_log = parameters['event_log']
    model = model_name_widget.value

    output.clear_output()
    # All work happens inside the Output context so that exceptions raised while
    # loading data or training show up in the notebook instead of being lost.
    with output:
        if not event_log:
            print("No event log selected. Add a .json.gz event log to the event log "
                  "directory and click 'Reload Logs'.")
            return
        if model == 'Transformer':
            _train_transformer(event_log, parameters, start_time)
        else:
            # BINet options are named 'BINetV<version>'.
            _train_binet(event_log, int(model[-1]), parameters, start_time)


# Hook the buttons up to their handlers (registration order is irrelevant).
reload_button.on_click(on_reload)
train_button.on_click(train_model)

# Populate the event-log dropdown once before the GUI is shown.
on_reload(None)

# Render the parameter form first, then the area that receives training output.
display(parameter_gui, output)


Output()

Output()