# Peptide Detectability (Training and fine tuning) 

This notebook is designed to run in Google [Colaboratory](https://colab.research.google.com/). To train the model faster, change the Colab runtime to use a hardware accelerator (GPU or TPU).

This notebook provides a concise walkthrough of the process for reading a dataset, training, and fine-tuning a model for peptide detectability prediction. 

The dataset used in this example is derived from:

- **ProteomeTools Dataset**: Includes data from the PRIDE repository with the following identifiers: `PXD004732`, `PXD010595`, and `PXD021013`.
- **MassIVE Dataset**: Deposited in the ProteomeXchange Consortium via the MassIVE partner repository with the identifier `PXD024364`.

The framework used is a custom wrapper on top of Keras/TensorFlow. The current working name of the package is DLOmix (`dlomix`).

#### Installing the DLOmix Package

If you haven't installed the DLOmix package yet, install it with pip before running the code.

In [1]:
# install the DLOmix package in the current environment using pip (uncomment the line below)

# !python -m pip install -q git+https://github.com/wilhelm-lab/dlomix

#### Importing Required Libraries

Before running the code, ensure you import all the necessary libraries. These imports are essential for accessing the functionalities needed for data processing, model training, and evaluation.

In [2]:
import sys

In [3]:
sys.path.append('C:/Users/JZ05DL/OneDrive - Aalborg Universitet/Documents/GitHub/dlomix/src')
sys.path.append('C:/Users/JZ05DL/OneDrive - Aalborg Universitet/Documents/GitHub/dlomix/src/dlomix/data')
sys.path.append('C:/Users/JZ05DL/OneDrive - Aalborg Universitet/Documents/GitHub/dlomix/src/dlomix/models')
sys.path.append('C:/Users/JZ05DL/OneDrive - Aalborg Universitet/Documents/GitHub/dlomix/src/dlomix/reports')

In [4]:
import numpy as np
import pandas as pd
import tensorflow as tf
import dlomix

import os
print([x for x in dir(dlomix)])

['META_DATA', '__builtins__', '__cached__', '__doc__', '__file__', '__loader__', '__name__', '__package__', '__path__', '__spec__', '__version__']


The package is organized into the following submodules:

- `constants`: constants to be used in the framework (e.g. amino acid alphabet mapping)
- `data`: classes for representing datasets, wrappers around TensorFlow Dataset
- `eval`: custom evaluation metrics implemented in Keras/TF to work as `metrics` for model training
- `layers`: custom layer implementation required for the different models
- `models`: different model implementations (e.g. Detectability Prediction)
- `pipelines`: complete pipelines to run a task (e.g. Retention Time prediction)
- `utils`: helper modules

**Note**: reports and pipelines are work in progress; some functionalities are not complete.

## 1. Load Data for Training

You can import the `detectability_dataset` class and create an instance to manage data for training, validation, and testing. This instance handles TensorFlow dataset objects and allows for easy control of dataset splits.

#### Parameters

- **data_source** (`str`, `tuple of two numpy.ndarray`, `numpy.ndarray`): Specifies the data source. This can be a tuple of two arrays (sequences and classes), a single array (sequences), which is useful for test data, or a file path to a CSV file. Defaults to `None`.

- **protein_data** (`str`, `numpy.ndarray`): Specifies the protein data. This can be a single array with protein names or IDs, or the name of the column with protein information in the CSV file if a file path is provided. Defaults to `'proteins'`.

- **sep** (`str`): Separator used in the CSV file if the data source is a CSV file. Defaults to `","`.

- **sequence_col** (`str`): Name of the column containing sequences in the CSV file. Defaults to `"sequences"`.

- **classes_col** (`str`): Name of the column containing classes in the CSV file. Defaults to `"classes"`.

- **split_on_protein** (`bool`): Determines if the dataset should be split based on proteins, ensuring that all peptides from a specific protein are kept together. Requires `protein_data` to be provided. Defaults to `False`.

- **seq_length** (`int`): Maximum sequence length. Sequences longer than this length will be removed, and shorter sequences will be padded. Defaults to `40`.

- **batch_size** (`int`): Batch size for training. Defaults to `32`.

- **val_ratio** (`float`): Fraction of the dataset to be used for validation. Defaults to `0.1` (10%).

- **test_ratio** (`float`): Fraction of the dataset to be used for testing. Defaults to `0.2` (20%).

- **seed** (`int`): Seed for reproducible data splitting. Defaults to `21`.

- **test** (`bool`): Indicates if the dataset is used solely for testing. Defaults to `False`.

- **sample_run** (`bool`): Limits the number of examples for testing and debugging purposes. Defaults to `False`.
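
The effect of `seq_length` can be illustrated with a small pure-Python sketch. The helper `filter_and_pad` and the padding character `"0"` are hypothetical and chosen only for illustration; the actual `detectability_dataset` implementation may encode and pad sequences differently.

```python
def filter_and_pad(sequences, seq_length=40, pad_char="0"):
    """Drop sequences longer than seq_length and right-pad the rest.

    Hypothetical helper: it mirrors the behaviour described for the
    `seq_length` parameter, not the library's internal code.
    """
    kept = [s for s in sequences if len(s) <= seq_length]
    return [s.ljust(seq_length, pad_char) for s in kept]


peptides = ["PEPTIDE", "A" * 45, "ACDEFGHIK"]
padded = filter_and_pad(peptides, seq_length=12)
print(padded)  # the 45-residue sequence is removed, the others are padded to 12
```

With `seq_length=12`, the two short peptides are kept and padded to the same length, while the 45-residue sequence is discarded.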


**Note**: If class labels are provided, the following encoding scheme should be used:
- **Non-Flyer**: 0
- **Weak Flyer**: 1
- **Intermediate Flyer**: 2
- **Strong Flyer**: 3
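
As a minimal sketch of this encoding, the mapping below converts class names to the integers listed above and then to one-hot vectors. One-hot targets are what a categorical cross-entropy loss expects; whether `detectability_dataset` performs this expansion internally is an assumption, and the names `CLASS_TO_INDEX` and `one_hot` are hypothetical.

```python
# Integer encoding taken from the note above.
CLASS_TO_INDEX = {
    "Non-Flyer": 0,
    "Weak Flyer": 1,
    "Intermediate Flyer": 2,
    "Strong Flyer": 3,
}


def one_hot(label, num_classes=4):
    """Return a one-hot list for a class name (hypothetical helper)."""
    index = CLASS_TO_INDEX[label]
    return [1 if i == index else 0 for i in range(num_classes)]


encoded = [one_hot(name) for name in ["Non-Flyer", "Strong Flyer"]]
print(encoded)
```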

In [7]:
from dlomix.data.detectability_dataset import detectability_dataset

ModuleNotFoundError: No module named 'datasets'

In [None]:
TRAIN_DATAPATH = './example_dataset/detectability_data.csv'

max_pep_length = 40
BATCH_SIZE = 128
            
detectability_data = detectability_dataset(data_source = TRAIN_DATAPATH, 
                                           #protein_data = np.array([x.upper() for x in protein_data]),
                                           seq_length = max_pep_length,
                                           split_on_protein = False, 
                                           batch_size = BATCH_SIZE, 
                                           val_ratio = 0.1, 
                                           test_ratio = 0.2, 
                                           test = False)

The `detectability_dataset` class can be used directly with both standard and custom `Keras` models. The wrapper provides predefined splits for training, validation, and testing, accessible via the attributes `.train_data`, `.val_data`, and `.test_data`. The length of each subset is measured in batches, so the number of examples can be estimated as `batch_size x len(subset)`; the estimate is approximate when the subset size is not a multiple of the batch size.
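
The relation between batches and examples can be checked with a short calculation. This sketch assumes the last, partially filled batch is kept (the default behaviour of `tf.data.Dataset.batch`); the example count of 1000 is made up for illustration.

```python
import math


def num_batches(num_examples, batch_size):
    """Number of batches when a partial final batch is kept."""
    return math.ceil(num_examples / batch_size)


n_examples = 1000   # hypothetical subset size
batch_size = 128
n_batches = num_batches(n_examples, batch_size)
estimate = batch_size * n_batches

print(n_batches, estimate, estimate - n_examples)
```

Under this assumption, `batch_size * len(subset)` is an upper bound on the true number of examples, off by less than one batch.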

If the parameter `test=True` is set, only the test dataset is generated, skipping the training and validation splits.

In [None]:
"Training examples", BATCH_SIZE * len(detectability_data.train_data)

In [None]:
"Validation examples", BATCH_SIZE * len(detectability_data.val_data)

In [None]:
"Test examples", BATCH_SIZE * len(detectability_data.test_data)

It is also possible to retrieve the different dataset splits as DataFrames by using the `.get_split_dataframe` method and specifying the desired split: `train`, `val`, or `test`.

In [None]:
test_data_df = detectability_data.get_split_dataframe(split = "test")
test_data_df.head(5)

## 2. Model

We can now create the model. The architecture is an encoder-decoder with an attention mechanism, based on a Bidirectional Recurrent Neural Network (BRNN) with Gated Recurrent Units (GRU). The encoder and the decoder each consist of a single layer, and the decoder additionally includes a Dense layer. The model is created with its default working arguments.

In [None]:
from dlomix.models import detetability_model
from dlomix.detectability_model_constants import CLASSES_LABELS, alphabet

In [None]:
total_num_classes = len(CLASSES_LABELS)
input_dimension = len(alphabet)
num_cells = 64

model = detetability_model.detetability_model(num_units = num_cells,
                                              num_clases = total_num_classes)

## 3. Training and saving the model

You can train the model using the standard Keras approach. The training parameters provided here are those initially configured for the detectability model. However, you have the flexibility to modify these parameters to suit your specific needs.

#### Compile the Model

Compile the model with the selected settings. You can use built-in TensorFlow options or define and pass custom settings for the optimizer, loss function, and metrics. The default configurations match those used in the original study, but you can modify these settings according to your preferences.

Early stopping is also configured with the original settings, but the parameters can be adjusted based on user preferences. Early stopping monitors a performance metric (e.g., validation loss) and halts training when no improvement is observed for a specified number of epochs. This feature helps prevent overfitting and ensures efficient training.
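
The patience logic can be sketched in plain Python. This is a simplified re-implementation for illustration, not the Keras code: it tracks the best validation loss seen so far and stops once `patience` consecutive epochs bring no strict improvement. The loss values are made up.

```python
def early_stopping_epoch(val_losses, patience=5):
    """Return the 1-based epoch at which training would stop, or None.

    Simplified sketch: an epoch counts as an improvement only if its
    validation loss is strictly lower than the best value seen so far.
    """
    best = float("inf")
    wait = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


# Hypothetical validation losses: the best value is reached at epoch 3.
losses = [1.0, 0.8, 0.7, 0.71, 0.72, 0.70, 0.75, 0.74, 0.6]
stop = early_stopping_epoch(losses, patience=5)
print(stop)
```

In this example the improvement at epoch 9 is never reached, because five epochs without improvement have already elapsed by epoch 8.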

In [None]:
callback = tf.keras.callbacks.EarlyStopping(monitor = 'val_loss', 
                                            mode = 'min', 
                                            verbose = 1, 
                                            patience = 5)


model_checkpoint = tf.keras.callbacks.ModelCheckpoint(filepath = 'output/weights/new_base_model/base_model_weights',
                                                      monitor = 'val_categorical_accuracy',
                                                      mode = 'max',
                                                      verbose = 1,
                                                      save_best_only = True, 
                                                      save_weights_only = True)

model.compile(optimizer = 'adam',
              loss = 'CategoricalCrossentropy',
              metrics = ['categorical_accuracy'])

We save the results of the training process to enable a detailed examination of the metrics and losses at a later stage. We define the number of epochs for training and supply the training and validation data previously generated. This approach allows us to effectively monitor the model’s performance and make any necessary adjustments.

In [None]:
history = model.fit(detectability_data.train_data,
                    validation_data = detectability_data.val_data,
                    epochs = 50, 
                    callbacks=[callback, model_checkpoint])
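
The object returned by `model.fit` exposes a `history` attribute: a dictionary mapping each metric name to a list with one value per epoch. The sketch below works on a hand-written dictionary of the same shape; all numbers are made up for illustration and are not results from this model.

```python
# Stand-in for `history.history`; values are hypothetical.
history_dict = {
    "loss": [1.20, 0.95, 0.80, 0.72],
    "val_loss": [1.10, 0.98, 0.90, 0.93],
    "categorical_accuracy": [0.45, 0.58, 0.66, 0.70],
    "val_categorical_accuracy": [0.48, 0.57, 0.63, 0.62],
}

val_loss = history_dict["val_loss"]
# Index of the lowest validation loss, converted to a 1-based epoch number.
best_epoch = min(range(len(val_loss)), key=val_loss.__getitem__) + 1
best_val_accuracy = history_dict["val_categorical_accuracy"][best_epoch - 1]

print(best_epoch, best_val_accuracy)
```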

## 4. Testing and Reporting


We use the test dataset to assess our model's performance, which is only applicable if labels are available. The `detectability_report` class allows us to compute various metrics, generate reports, and create plots for a comprehensive evaluation of the model.

Note: The reporting module is currently under development, so some features may be unstable or subject to change.

In the next cell, set the path to the model weights. By default, it points to the newly trained base model. If using different weights, update the path accordingly.

In [None]:
# Load the best model's weights

model.load_weights('output/weights/new_base_model/base_model_weights')

##### Generate Predictions on Test Data Using `model.predict`

To obtain predictions for your test data, use the Keras `model.predict` method. Simply pass your test dataset to this method, and it will return the model's predictions.

In [None]:
# Use Keras' model.predict directly on the test data

predictions = model.predict(detectability_data.test_data)
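
Assuming the predictions form one row of four class probabilities per peptide (an assumption about the output shape, not something shown here), the most likely class can be read off with an argmax. The helper `to_class_labels` and the example rows are hypothetical.

```python
CLASS_NAMES = ["Non-Flyer", "Weak Flyer", "Intermediate Flyer", "Strong Flyer"]


def to_class_labels(probabilities):
    """Map each row of class probabilities to the name of its largest entry."""
    labels = []
    for row in probabilities:
        index = max(range(len(row)), key=row.__getitem__)
        labels.append(CLASS_NAMES[index])
    return labels


# Hypothetical predictions for two peptides.
example_predictions = [
    [0.70, 0.20, 0.05, 0.05],
    [0.10, 0.15, 0.25, 0.50],
]
predicted_labels = to_class_labels(example_predictions)
print(predicted_labels)
```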

To generate reports and calculate evaluation metrics against predictions, you need to obtain the targets and DataFrames for the specific dataset split. This can be achieved using the `get_split_targets` and `get_split_dataframe` functions from the `detectability_dataset` class.

In [None]:
test_targets = detectability_data.get_split_targets(split = "test")
test_data_df = detectability_data.get_split_dataframe(split = "test")

In [8]:
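# WANDB_REPORT_API_DISABLE_MESSAGE is an environment variable, so it is set before the import
os.environ["WANDB_REPORT_API_DISABLE_MESSAGE"] = "True"
from dlomix.reports.detectability_report import detectability_report, predictions_report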

#### Generate a Report Using the `detectability_report` Class

The `detectability_report` class provides a comprehensive way to evaluate your model by generating detailed reports and visualizations. The outputs include:

1. **A PDF Report**: This includes evaluation metrics and plots.
2. **A CSV File**: Contains the model’s predictions.
3. **Independent Image Files**: Visualizations are saved as separate image files.

To generate a report, provide the following parameters to the `detectability_report` class:

- **targets**: The true labels for the dataset, which are used to assess the model’s performance.
- **predictions**: The model’s output predictions for the dataset, which will be compared against the true labels.
- **input_data_df**: The DataFrame containing the input data used for generating predictions.
- **output_path**: The directory path where the generated reports, images, and CSV file will be saved.
- **history**: The training history object (e.g., containing metrics from training) if available. Set this to `None` if not applicable, such as when the report is generated for predictions without training.
- **rank_by_prot**: A boolean indicating whether to rank peptides based on their associated proteins (`True` or `False`).
- **threshold**: The classification threshold used to adjust the decision boundary for predictions. By default, this is set to `None`, meaning no specific threshold is applied.
- **name_of_dataset**: The name of the dataset used for generating predictions, which will be included in the report to provide context.
- **name_of_model**: The name of the model used to generate the predictions, which will be specified in the report for reference.


In [None]:
report = detectability_report(targets = test_targets, 
                              predictions = predictions, 
                              input_data_df = test_data_df,
                              output_path = "./output/report_on_ProteomeTools", 
                              history = history, 
                              rank_by_prot = False,
                              threshold = None,
                              name_of_dataset = 'ProteomeTools',
                              name_of_model = 'Base model (new)')

#### Predictions report

In [None]:
results_df = report.detectability_report_table
results_df

#### Generating Evaluation Plots with `detectability_report`

The `detectability_report` class enables you to generate a range of plots to visualize and evaluate model performance. It offers a comprehensive suite of visualizations to help you interpret the results of your model's predictions. Here’s how to use it:

##### Training and Validation Metrics

These plots show the training and validation metrics over epochs. The first plot displays the loss, and the second shows the categorical accuracy. Both plots are generated from the `history` object recorded during the model training process.

In [None]:
report.plot_keras_metric("loss")

In [None]:
report.plot_keras_metric("categorical_accuracy")

##### ROC curve (Binary)

In [None]:
report.plot_roc_curve_binary()

##### Confusion matrix (Binary)

In [None]:
report.plot_confusion_matrix_binary()

##### ROC curve (Multi-class)

In [None]:
report.plot_roc_curve()

##### Confusion matrix (Multi-class)

In [None]:
report.plot_confusion_matrix_multiclass()
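
For readers unfamiliar with multi-class confusion matrices, the sketch below builds one by hand for the four detectability classes. It is independent of the report class; the orientation (rows are true classes, columns are predicted classes) is a convention assumed here, and the labels are made up.

```python
def confusion_matrix(true_labels, predicted_labels, num_classes=4):
    """Count (true, predicted) pairs; rows = true class, columns = predicted."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(true_labels, predicted_labels):
        matrix[t][p] += 1
    return matrix


# Hypothetical integer labels (0 = Non-Flyer ... 3 = Strong Flyer).
true = [0, 0, 1, 2, 3, 3]
pred = [0, 1, 1, 2, 3, 2]

cm = confusion_matrix(true, pred)
# Accuracy is the sum of the diagonal divided by the number of examples.
accuracy = sum(cm[i][i] for i in range(4)) / len(true)

for row in cm:
    print(row)
print(round(accuracy, 3))
```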

#### Heatmap of Average Error Between Actual and Predicted Classes

In [None]:
report.plot_heatmap_prediction_prob_error()

We can also produce a complete report with all the relevant plots in one PDF file by calling the `generate_report` function.

In [None]:
report.generate_report()

### Example: Defining a Classification Threshold

In the following example, a specific classification threshold is defined to adjust the decision boundary for the model's predictions. By setting a threshold, you can control the sensitivity of the model, influencing how it categorizes the output into different classes.

In [None]:
report_using_threshold = detectability_report(test_targets, 
                                              predictions, 
                                              test_data_df, 
                                              output_path = "./output/report_on_ProteomeTools_with_threshold", 
                                              history = history, 
                                              rank_by_prot = False,
                                              threshold = 0.02,
                                              name_of_dataset = 'ProteomeTools',
                                              name_of_model = 'Base model (new) with threshold')
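
How `detectability_report` applies `threshold` internally is not shown in this notebook. The general idea of a decision threshold can still be sketched under a loudly stated assumption: each peptide is given a binary "flyer score" equal to one minus its non-flyer probability, and it is called a flyer when that score reaches the threshold. The helper names and the data are hypothetical.

```python
def flyer_scores(probabilities):
    """Assumed binary score: 1 - P(Non-Flyer), i.e. total flyer probability."""
    return [1.0 - row[0] for row in probabilities]


def sensitivity(scores, truth, threshold):
    """Fraction of true flyers (truth == 1) whose score reaches the threshold."""
    predicted = [1 if s >= threshold else 0 for s in scores]
    true_positives = sum(1 for p, t in zip(predicted, truth) if p == 1 and t == 1)
    positives = sum(truth)
    return true_positives / positives


# Hypothetical class probabilities and binary ground truth (1 = flyer).
probs = [
    [0.10, 0.30, 0.30, 0.30],
    [0.40, 0.20, 0.20, 0.20],
    [0.70, 0.10, 0.10, 0.10],
    [0.99, 0.01, 0.00, 0.00],
    [0.60, 0.20, 0.10, 0.10],
]
truth = [1, 1, 1, 0, 0]

scores = flyer_scores(probs)
print(sensitivity(scores, truth, threshold=0.5))
print(sensitivity(scores, truth, threshold=0.02))
```

Lowering the threshold in this sketch recovers more true flyers at the cost of labelling more non-flyers as flyers, which is the trade-off a threshold controls.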

#### Predictions report 

In [None]:
report_using_threshold.detectability_report_table

Generating a complete PDF report using the `generate_report` function.

In [None]:
report_using_threshold.generate_report()

## 5. Load data for fine tuning

For fine-tuning, the process mirrors the steps used during training. Simply create a `detectability_dataset` object with the fine-tuning data (refer to **Section 1: Load Data for Training**).

In [None]:
TRAIN_DATAPATH = './example_dataset/Sinitcyn_train_data.csv'

max_pep_length = 40
BATCH_SIZE = 128
            
fine_tune_data = detectability_dataset(data_source = TRAIN_DATAPATH, 
                                       protein_data = "proteins",
                                       split_on_protein = True, 
                                       seq_length = max_pep_length, 
                                       batch_size = BATCH_SIZE, 
                                       val_ratio = 0.1, 
                                       test_ratio = 0.2, 
                                       test = False)
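
The call above sets `split_on_protein = True`. As a minimal sketch of what a protein-grouped split means, the hypothetical helper below assigns whole proteins to either the training or the test subset, so peptides from one protein never appear in both. It is not the library's implementation, and here the ratio applies to proteins rather than peptides.

```python
import random


def split_by_protein(records, test_ratio=0.2, seed=21):
    """Split (peptide, protein) records so that no protein spans both subsets."""
    proteins = sorted({protein for _, protein in records})
    rng = random.Random(seed)          # seeded for a reproducible split
    rng.shuffle(proteins)
    n_test = max(1, round(len(proteins) * test_ratio))
    test_proteins = set(proteins[:n_test])
    train = [r for r in records if r[1] not in test_proteins]
    test = [r for r in records if r[1] in test_proteins]
    return train, test


# Hypothetical peptide/protein pairs.
records = [
    ("PEPTIDEK", "P1"), ("ACDEFGHK", "P1"),
    ("LMNPQR", "P2"), ("STVWYK", "P3"),
    ("GHIKLMR", "P4"), ("NPQRSTK", "P5"), ("VWYACDR", "P5"),
]
train, test = split_by_protein(records)
train_proteins = {protein for _, protein in train}
test_proteins = {protein for _, protein in test}
print(train_proteins & test_proteins)  # no protein is shared between subsets
```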

In [None]:
"Training examples", BATCH_SIZE * len(fine_tune_data.train_data)

In [None]:
"Validation examples", BATCH_SIZE * len(fine_tune_data.val_data)

In [None]:
"Test examples", BATCH_SIZE * len(fine_tune_data.test_data)

## 6. Fine tuning the model

In the next cell, we create the model and load its weights for fine-tuning. By default, the path is set to the weights of the most recently trained base model. To use different weights, update the path to point to your desired model's weights.

In [None]:
save_path = "./output/weights/new_base_model/base_model_weights"

fine_tuned_model = detetability_model.detetability_model(num_units = num_cells,  
                                                         num_clases = total_num_classes)

fine_tuned_model.load_weights(save_path)

#### Compile the Model

Compile the model with the selected settings. You can use built-in TensorFlow options or define and pass custom settings for the optimizer, loss function, and metrics. The default configurations match those used in the original study, but you can modify these settings according to your preferences.

Early stopping is also configured with the original settings, but the parameters can be adjusted based on user preferences.

In [None]:
# Set up the callbacks, then compile the model with the optimizer, loss, and metrics we want to use.

callback = tf.keras.callbacks.EarlyStopping(monitor = 'val_loss', 
                                            mode = 'min', 
                                            verbose = 1, 
                                            patience = 5)


model_checkpoint = tf.keras.callbacks.ModelCheckpoint(filepath = 'output/weights/new_fine_tuned_model/fine_tuned_model_weights',
                                                      monitor = 'val_categorical_accuracy',
                                                      mode = 'max',
                                                      verbose = 1,
                                                      save_best_only = True, 
                                                      save_weights_only = True)

fine_tuned_model.compile(optimizer = 'adam',
                         loss = 'CategoricalCrossentropy',
                         metrics = ['categorical_accuracy'])

We store the result of training so that we can explore the metrics and the losses later. We specify the number of epochs for training and pass the training and validation data as previously described.

In [None]:
history_fine_tuned = fine_tuned_model.fit(fine_tune_data.train_data,
                                          validation_data = fine_tune_data.val_data,
                                          epochs = 50, 
                                          callbacks=[callback, model_checkpoint])

## 7. Testing and Reporting (Fine-Tuned Model)

In the following cell, we load the best model weights obtained from fine-tuning. By default, the path points to the most recently fine-tuned model from the previous cell. Update the path if you wish to load different weights.

In [None]:
# Load the best model's weights

fine_tuned_model.load_weights('output/weights/new_fine_tuned_model/fine_tuned_model_weights')

Generating predictions on the test data using the fine-tuned model with `model.predict`.

In [None]:
predictions_FT = fine_tuned_model.predict(fine_tune_data.test_data)

Obtaining the targets and DataFrame for the test dataset using the `get_split_targets` and `get_split_dataframe` functions from the `detectability_dataset` class.

In [None]:
test_targets_FT = fine_tune_data.get_split_targets(split = "test")
test_data_df_FT = fine_tune_data.get_split_dataframe(split = "test")

Creating a report object with the test targets, predictions, and history to generate metrics and plots for the fine-tuned model. For more details, refer to Section 4: Testing and Reporting, which provides a detailed description of the same process for the initial or base model.

In [None]:
report_FT = detectability_report(test_targets_FT, 
                                 predictions_FT, 
                                 test_data_df_FT, 
                                 output_path = "./output/report_on_Sinitcyn (Fine tuned model)", 
                                 history = history_fine_tuned, 
                                 rank_by_prot = False,
                                 threshold = None,
                                 name_of_dataset = 'Sinitcyn test subset',
                                 name_of_model = 'Fine tuned model (new)')

#### Predictions report (Fine-tuned model)

In [None]:
results_df_FT = report_FT.detectability_report_table
results_df_FT

Generating a complete PDF report using the `generate_report` function.

In [None]:
report_FT.generate_report()

#### Generating the Evaluation Plots for the Fine-Tuned Model

##### Training and Validation Metrics

These plots show the training and validation metrics over epochs. The first plot displays the loss, and the second shows the categorical accuracy.

In [None]:
report_FT.plot_keras_metric("loss")

In [None]:
report_FT.plot_keras_metric("categorical_accuracy")

##### ROC curve (Binary)

In [None]:
report_FT.plot_roc_curve_binary()

##### Confusion matrix (Binary)

In [None]:
report_FT.plot_confusion_matrix_binary()

##### ROC curve (Multi-class)

In [None]:
report_FT.plot_roc_curve()

##### Confusion matrix (Multi-class)

In [None]:
report_FT.plot_confusion_matrix_multiclass()

#### Heatmap of Average Error Between Actual and Predicted Classes

In [None]:
report_FT.plot_heatmap_prediction_prob_error()