# Peptide Detectability (Training and Fine-tuning) 

This notebook is designed to run in Google [Colaboratory](https://colab.research.google.com/). To train the model faster, change the Colab runtime to use a hardware accelerator (GPU or TPU).

This notebook provides a concise walkthrough of the process for reading a dataset, training, and fine-tuning a model for peptide detectability prediction. 

The dataset used in this example is derived from:

- **ProteomeTools Dataset**: Includes data from the PRIDE repository with the following identifiers: `PXD004732`, `PXD010595`, and `PXD021013`.
- **MassIVE Dataset**: Deposited in the ProteomeXchange Consortium via the MassIVE partner repository with the identifier `PXD024364`.
  

#### Installing the DLOmix Package

If you have not installed the DLOmix package yet, you need to do so before running the code. 

You can install the DLOmix package using pip.

In [None]:
# uncomment the following line to install the DLOmix package in the current environment using pip
# (the version specifier is quoted so that the shell does not treat ">" as an output redirect)

#!python -m pip install "dlomix>0.1.3"

#### Importing Required Libraries

Before running the code, ensure you import all the necessary libraries. These imports are essential for accessing the functionalities needed for data processing, model training, and evaluation.

In [None]:
import numpy as np
import pandas as pd
import tensorflow as tf
import dlomix
import sys
import os

In [None]:
dlomix.__version__

## 1. Load Data for Training

You can import the `DetectabilityDataset` class and create an instance to manage data for training, validation, and testing. This instance handles TensorFlow dataset objects and simplifies configuring and controlling how your data is preprocessed and split.

For the parameters of the dataset class, please refer to the DLOmix documentation: https://dlomix.readthedocs.io/en/main/dlomix.data.html#


**Note**: If class labels are provided, the following encoding scheme should be used:
- **Non-Flyer**: 0
- **Weak Flyer**: 1
- **Intermediate Flyer**: 2
- **Strong Flyer**: 3
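If your labels are stored as class names rather than integers, they need to be converted to the encoding above before training. The sketch below shows one way to do this; the `LABEL_TO_INT` dictionary and the `encode_labels` helper are hypothetical names written for illustration, and the class-name strings are assumed to match the names listed above.

```python
# Hypothetical mapping that mirrors the encoding scheme listed above.
LABEL_TO_INT = {
    "Non-Flyer": 0,
    "Weak Flyer": 1,
    "Intermediate Flyer": 2,
    "Strong Flyer": 3,
}


def encode_labels(names):
    """Convert class-name strings to the integer labels expected for training."""
    unknown = sorted(set(names) - set(LABEL_TO_INT))
    if unknown:
        raise ValueError(f"Unknown class names: {unknown}")
    return [LABEL_TO_INT[name] for name in names]


print(encode_labels(["Strong Flyer", "Non-Flyer", "Weak Flyer"]))  # [3, 0, 1]
```

Failing loudly on unknown names catches spelling variants (for example "Strong flyer") before they silently become missing labels.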

In [None]:
from dlomix.data import DetectabilityDataset

In [None]:
from dlomix.constants import CLASSES_LABELS, alphabet, aa_to_int_dict

In [None]:
CLASSES_LABELS, len(alphabet), aa_to_int_dict

In [None]:
max_pep_length = 40
BATCH_SIZE = 128 
            
# The Class handles all the inner details, we have to provide the column names and the alphabet for encoding
# If the data is already split with a specific logic (which is generally recommended) -> val_data_source and test_data_source are available as well

hf_data = "Wilhelmlab/detectability-proteometools"
detectability_data = DetectabilityDataset(data_source=hf_data,
                                          data_format='hub',
                                          max_seq_len=max_pep_length,
                                          label_column="Classes",
                                          sequence_column="Sequences",
                                          dataset_columns_to_keep=None,
                                          batch_size=BATCH_SIZE,
                                          with_termini=False,
                                          alphabet=aa_to_int_dict)

In [None]:
# This is the dataset with train, val, and test splits  
# You can see the column names under each split (the columns starting with _ are internal, but can also be used to look up original sequences for example "_parsed_sequence")
detectability_data

In [None]:
# Accessing elements in the dataset is done by specifying the split name and then the column name
# Example here for one sequence after encoding & padding compared to the original sequence

detectability_data["train"]["Sequences"][0], "".join(detectability_data["train"]["_parsed_sequence"][0])
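Conceptually, each peptide is mapped residue by residue to integers and then padded to `max_seq_len`. The plain-Python sketch below reproduces that idea; the toy `aa_to_int` dictionary and the padding value `0` are assumptions for illustration, so compare against `aa_to_int_dict` from `dlomix.constants` for the mapping the library actually uses. The sketch also simply rejects sequences longer than `max_len`, which may differ from how the library handles them.

```python
# Toy mapping for illustration only; the real mapping is aa_to_int_dict.
aa_to_int = {aa: i for i, aa in enumerate("ACDEFGHIKLMNPQRSTVWY", start=1)}


def encode_and_pad(sequence, mapping, max_len, pad_value=0):
    """Integer-encode a peptide and right-pad it to max_len (assumed pad value 0)."""
    if len(sequence) > max_len:
        raise ValueError(f"{sequence!r} is longer than max_len={max_len}")
    encoded = [mapping[aa] for aa in sequence]
    return encoded + [pad_value] * (max_len - len(encoded))


print(encode_and_pad("PEPTIDE", aa_to_int, max_len=10))
# [13, 4, 13, 17, 8, 3, 4, 0, 0, 0]
```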

## 2. Model

We can now create the model. The architecture is an encoder-decoder with an attention mechanism, based on a bidirectional recurrent neural network (BRNN) with gated recurrent units (GRU). The encoder and decoder each consist of a single layer, and the decoder additionally includes a dense layer. The cell below sets the number of units and the number of output classes explicitly; all other arguments keep their default values.
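To make the attention step more concrete, here is a minimal, framework-free sketch of dot-product attention: a decoder query is scored against each encoder state, the scores are normalized with a softmax, and the states are averaged with those weights. This is a generic illustration, not the DLOmix implementation, which may use a different scoring function (for example additive attention) and operates on batched tensors.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def dot_product_attention(query, encoder_states):
    """Return (context_vector, attention_weights) for one decoder query."""
    scores = [sum(q * h for q, h in zip(query, state)) for state in encoder_states]
    weights = softmax(scores)
    dim = len(encoder_states[0])
    context = [
        sum(w * state[d] for w, state in zip(weights, encoder_states))
        for d in range(dim)
    ]
    return context, weights


# Three encoder time steps (e.g. residues), each with a 2-dimensional hidden state.
states = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
context, weights = dot_product_attention([1.0, 0.0], states)
print([round(w, 3) for w in weights])  # [0.422, 0.155, 0.422]
```

Time steps whose hidden state aligns with the query receive more weight, so the context vector emphasizes those positions.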

In [None]:
from dlomix.models import DetectabilityModel

In [None]:
total_num_classes = len(CLASSES_LABELS)
input_dimension = len(alphabet)
num_cells = 64

model = DetectabilityModel(num_units = num_cells, num_clases = total_num_classes)

## 3. Training and saving the model

You can train the model using the standard Keras approach. The training parameters provided here are those initially configured for the detectability model. However, you have the flexibility to modify these parameters to suit your specific needs.

#### Compile the Model

Compile the model with the selected settings. You can use built-in TensorFlow options or define and pass custom settings for the optimizer, loss function, and metrics. The default configurations match those used in the original study, but you can modify these settings according to your preferences.

Early stopping is also configured with the original settings, but the parameters can be adjusted based on user preferences. Early stopping monitors a performance metric (e.g., validation loss) and halts training when no improvement is observed for a specified number of epochs. This feature helps prevent overfitting and ensures efficient training.
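The patience logic can be sketched in plain Python as follows. This is a simplified re-implementation for illustration (mode `'min'`, as configured below), not the Keras source; the function name `early_stopping_epoch` and the example losses are made up.

```python
def early_stopping_epoch(val_losses, patience=5, min_delta=0.0):
    """Return the 0-based epoch at which training would stop, or None if it never stops.

    Mimics mode='min': an epoch counts as an improvement only if the loss drops
    below the best value seen so far by more than min_delta.
    """
    best = float("inf")
    wait = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


losses = [0.90, 0.80, 0.75, 0.76, 0.77, 0.78, 0.79, 0.80, 0.70]
print(early_stopping_epoch(losses, patience=5))  # 7
```

Training stops after five epochs without improvement, even though the loss would have improved at the last epoch. Note also that `EarlyStopping` does not restore the best weights by default (`restore_best_weights=False`), which is one reason the checkpoint callback below saves the best weights separately.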

In [None]:
callback = tf.keras.callbacks.EarlyStopping(monitor = 'val_loss', 
                                            mode = 'min', 
                                            verbose = 1, 
                                            patience = 5)


model_save_path = 'output/weights/new_base_model/base_model_weights_detectability'

model_checkpoint = tf.keras.callbacks.ModelCheckpoint(filepath=model_save_path,
                                                      monitor='val_sparse_categorical_accuracy',
                                                      mode='max',
                                                      verbose=1,
                                                      save_best_only=True, 
                                                      save_weights_only=True)

model.compile(optimizer='adam',
              loss='SparseCategoricalCrossentropy', 
              metrics='sparse_categorical_accuracy')

We save the results of the training process to enable a detailed examination of the metrics and losses at a later stage. We define the number of epochs for training and supply the training and validation data previously generated. This approach allows us to effectively monitor the model’s performance and make any necessary adjustments.
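`model.fit` returns a Keras `History` object whose `history` attribute is a dictionary mapping each metric name (for example `loss`, `val_loss`, `val_sparse_categorical_accuracy`) to a list with one value per epoch. The sketch below uses a hand-written dictionary in that shape (the numbers are made up) to find the epoch that `ModelCheckpoint(..., mode='max', save_best_only=True)` would have kept; with a real run you would pass `history.history` instead. The helper `best_epoch` is hypothetical.

```python
def best_epoch(history_dict, monitor="val_sparse_categorical_accuracy", mode="max"):
    """Return (1-based epoch, value) of the best monitored value in a Keras-style history dict."""
    values = history_dict[monitor]
    pick = max if mode == "max" else min
    best_value = pick(values)
    # index() returns the first occurrence, matching a checkpoint that saves only on strict improvement
    return values.index(best_value) + 1, best_value


# Made-up numbers in the same shape as history.history.
example_history = {
    "loss": [1.10, 0.95, 0.90, 0.88],
    "val_loss": [1.00, 0.93, 0.94, 0.97],
    "val_sparse_categorical_accuracy": [0.52, 0.58, 0.60, 0.59],
}
print(best_epoch(example_history))  # (3, 0.6)
```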

In [None]:
# Access to the tensorflow datasets is done by referencing the tensor_train_data or tensor_val_data

history = model.fit(detectability_data.tensor_train_data,
                    validation_data = detectability_data.tensor_val_data,
                    epochs = 50, 
                    callbacks=[callback, model_checkpoint])

## 4. Testing and Reporting


We use the test split to assess the model's performance; this is only possible if labels are available. The `DetectabilityReport` class allows us to compute various metrics, generate reports, and create plots for a comprehensive evaluation of the model.

Note: The reporting module is currently under development, so some features may be unstable or subject to change.

In the next cell, set the path to the model weights. By default, it points to the newly trained base model. If using different weights, update the path accordingly.

In [None]:
# path to the weights to load (defaults to the base model trained above; edit it to use other weights)
model_save_path = 'output/weights/new_base_model/base_model_weights_detectability'

In [None]:
## Loading best model weights 

model.load_weights(model_save_path)

#### Generate Predictions on Test Data Using `model.predict`

To obtain predictions for your test data, use the Keras `model.predict` method. Simply pass your test dataset to this method, and it will return the model's predictions.
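The returned array is expected to have one row per peptide and one column per class (four here). Assuming the final layer outputs class probabilities, the predicted class is the column with the highest value. The plain-Python sketch below uses made-up rows, and the `CLASS_NAMES` list is an assumption that mirrors the encoding scheme from Section 1; with NumPy you would typically write `predictions.argmax(axis=1)`.

```python
# Assumed to mirror the encoding scheme from Section 1 (index = class label).
CLASS_NAMES = ["Non-Flyer", "Weak Flyer", "Intermediate Flyer", "Strong Flyer"]


def predicted_classes(prob_rows):
    """Return the index of the highest-probability class for each row."""
    return [max(range(len(row)), key=row.__getitem__) for row in prob_rows]


example_predictions = [
    [0.70, 0.10, 0.15, 0.05],
    [0.05, 0.10, 0.25, 0.60],
]
labels = predicted_classes(example_predictions)
print(labels, [CLASS_NAMES[i] for i in labels])
# [0, 3] ['Non-Flyer', 'Strong Flyer']
```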

In [None]:
predictions = model.predict(detectability_data.tensor_test_data)

In [None]:
predictions.shape

To generate reports and calculate evaluation metrics against predictions, we obtain the targets and the data for the specific dataset split. This can be achieved using the `DetectabilityDataset` class directly.

In [None]:
# access the test split and get the Classes column
test_targets = detectability_data["test"]["Classes"]


# if needed, the decoded version of the classes can be retrieved by looking up the class names
test_targets_decoded = [CLASSES_LABELS[x] for x in test_targets]


test_targets[0:5], test_targets_decoded[0:5]

In [None]:
# The dataframe needed for the report

test_data_df = pd.DataFrame(
    {
        "Sequences": detectability_data["test"]["_parsed_sequence"], # get the raw parsed sequences
        "Classes": test_targets, # get the test targets from above
#         "Proteins": detectability_data["test"]["Proteins"] # get the Proteins column from the dataset object (if the dataset has "Proteins" column)
    }
)

test_data_df.Sequences = test_data_df.Sequences.apply(lambda x: "".join(x)) # join the sequences since they are a list of string amino acids.
test_data_df.head(5)

In [None]:
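# Set as an environment variable (a plain Python variable has no effect), before importing the reporting module
os.environ["WANDB_REPORT_API_DISABLE_MESSAGE"] = "True"
from dlomix.reports.DetectabilityReport import DetectabilityReport, predictions_report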

#### Generate a Report Using the `DetectabilityReport` Class

The `DetectabilityReport` class provides a comprehensive way to evaluate your model by generating detailed reports and visualizations. The outputs include:

1. **A PDF Report**: This includes evaluation metrics and plots.
2. **A CSV File**: Contains the model’s predictions.
3. **Independent Image Files**: Visualizations are saved as separate image files.

To generate a report, provide the following parameters to the `DetectabilityReport` class:

- **targets**: The true labels for the dataset, which are used to assess the model’s performance.
- **predictions**: The model’s output predictions for the dataset, which will be compared against the true labels.
- **input_data_df**: The DataFrame containing the input data used for generating predictions.
- **output_path**: The directory path where the generated reports, images, and CSV file will be saved.
- **history**: The training history object (e.g., containing metrics from training) if available. Set this to `None` if not applicable, such as when the report is generated for predictions without training.
- **rank_by_prot**: A boolean indicating whether to rank peptides based on their associated proteins (`True` or `False`). Defaults to `False`.
- **threshold**: The classification threshold used to adjust the decision boundary for predictions. By default, this is set to `None`, meaning no specific threshold is applied.
- **name_of_dataset**: The name of the dataset used for generating predictions, which will be included in the report to provide context.
- **name_of_model**: The name of the model used to generate the predictions, which will be specified in the report for reference.


In [None]:
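# Since the detectability report expects the true labels in one-hot encoded format, we expand them here.
# Use the total number of classes rather than np.max(test_targets) + 1, which would produce too few
# columns if the highest class happened to be absent from the test split.

num_classes = len(CLASSES_LABELS)
test_targets_one_hot = np.eye(num_classes)[test_targets]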

In [None]:
report = DetectabilityReport(targets = test_targets_one_hot, 
                             predictions = predictions, 
                             input_data_df = test_data_df,
                             output_path = "./output/report_on_ProteomeTools",
                             history = history, 
                             rank_by_prot = False,
                             threshold = None,
                             name_of_dataset = 'ProteomeTools',
                             name_of_model = 'Base model (new)')

#### Predictions report

In [None]:
results_df = report.detectability_report_table
results_df.head(5)

#### Generating Evaluation Plots with `DetectabilityReport`

The `DetectabilityReport` class enables you to generate a range of plots to visualize and evaluate model performance. It offers a comprehensive suite of visualizations to help you interpret the results of your model's predictions. Here’s how to use it:

##### Training and Validation Metrics

These plots show the training and validation metrics over epochs. The first plot displays the loss, and the second shows the sparse categorical accuracy. Both plots are generated from the `history` object recorded during the model training process.

In [None]:
report.plot_keras_metric("loss")

In [None]:
report.plot_keras_metric("sparse_categorical_accuracy")

##### ROC curve (Binary)

In [None]:
report.plot_roc_curve_binary()

##### Confusion matrix (Binary)

In [None]:
report.plot_confusion_matrix_binary()

##### ROC curve (Multi-class)

In [None]:
report.plot_roc_curve()

##### Confusion matrix (Multi-class)

In [None]:
report.plot_confusion_matrix_multiclass()
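For reference, a multi-class confusion matrix counts how often each true class (rows) receives each predicted class (columns). The minimal sketch below uses made-up labels; the plot produced by `DetectabilityReport` may use a different orientation or apply normalization.

```python
def confusion_matrix(true_labels, predicted_labels, num_classes=4):
    """Rows are true classes, columns are predicted classes."""
    if len(true_labels) != len(predicted_labels):
        raise ValueError("true_labels and predicted_labels must have the same length")
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(true_labels, predicted_labels):
        matrix[t][p] += 1
    return matrix


y_true = [0, 1, 2, 3, 3, 0]
y_pred = [0, 2, 2, 3, 1, 0]
for row in confusion_matrix(y_true, y_pred):
    print(row)
# [2, 0, 0, 0]
# [0, 0, 1, 0]
# [0, 0, 1, 0]
# [0, 1, 0, 1]
```

The diagonal holds correct predictions; off-diagonal cells show which flyer classes are confused with each other.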

#### Heatmap of Average Error Between Actual and Predicted Classes

In [None]:
report.plot_heatmap_prediction_prob_error()

We can also produce a complete report with all the relevant plots in one PDF file by calling the `generate_report` function.

In [None]:
report.generate_report()

### Example: Defining a Classification Threshold

In the following example, a specific classification threshold is defined to adjust the decision boundary for the model's predictions. By setting a threshold, you can control the sensitivity of the model, influencing how it categorizes the output into different classes.
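One common way to apply a threshold to a four-class detectability output is to collapse it into a binary decision, treating every class except Non-Flyer as detectable. The sketch below assumes that P(detectable) = 1 - P(Non-Flyer) and compares it with the threshold; this is an illustrative assumption, and `DetectabilityReport` may derive its binary labels differently. The helper `binary_calls` is hypothetical.

```python
def binary_calls(prob_rows, threshold=0.5, non_flyer_index=0):
    """Label each peptide 1 (detectable) if 1 - P(Non-Flyer) >= threshold, else 0.

    Assumes each row holds class probabilities ordered as in the encoding scheme.
    """
    calls = []
    for row in prob_rows:
        p_detectable = 1.0 - row[non_flyer_index]
        calls.append(1 if p_detectable >= threshold else 0)
    return calls


rows = [
    [0.70, 0.10, 0.15, 0.05],  # P(detectable) = 0.30
    [0.40, 0.20, 0.20, 0.20],  # P(detectable) = 0.60
]
print(binary_calls(rows, threshold=0.5))  # [0, 1]
print(binary_calls(rows, threshold=0.7))  # [0, 0]
```

Raising the threshold makes the "detectable" call stricter, trading sensitivity for fewer false positives.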

In [None]:
report_using_threshold = DetectabilityReport(test_targets_one_hot, 
                                             predictions, 
                                             test_data_df, 
                                             output_path = "./output/report_on_ProteomeTools_with_threshold", 
                                             history = history, 
                                             rank_by_prot = False,
                                             threshold = 0.5,                              
                                             name_of_dataset = 'ProteomeTools',
                                             name_of_model = 'Base model (new) with threshold')

#### Predictions report 

In [None]:
report_using_threshold.detectability_report_table.head(5)

Generating a complete PDF report using the `generate_report` function.

In [None]:
report_using_threshold.generate_report()

## 5. Load Data for Fine-Tuning

For fine-tuning, the process mirrors the steps used during training. Simply create a `DetectabilityDataset` object with the fine-tuning data (refer to **Section 1: Load Data for Training**).

In [None]:
max_pep_length = 40
BATCH_SIZE = 128 
            
# The Class handles all the inner details, we have to provide the column names and the alphabet for encoding
# If the data is already split with a specific logic (which is generally recommended) -> val_data_source and test_data_source are available as well

hf_data = "Wilhelmlab/detectability-sinitcyn"
fine_tune_data = DetectabilityDataset(data_source=hf_data,
                                      data_format='hub',
                                      max_seq_len=max_pep_length,
                                      label_column="Classes",
                                      sequence_column="Sequences",
                                      dataset_columns_to_keep=['Proteins'],
                                      batch_size=BATCH_SIZE,
                                      with_termini=False,
                                      alphabet=aa_to_int_dict)

In [None]:
fine_tune_data

## 6. Fine-Tuning the Model

In the next cell, we create the model and load its weights for fine-tuning. By default, the path is set to the weights of the most recently trained base model. To use different weights, update the path to point to your desired model's weights.

In [None]:
# define the path explicitly in case model_save_path is not in the environment (e.g. when skipping the training run)
load_model_path = 'output/weights/new_base_model/base_model_weights_detectability'

fine_tuned_model = DetectabilityModel(num_units = num_cells,  
                                      num_clases = total_num_classes)

fine_tuned_model.load_weights(load_model_path)

#### Compile the Model

Compile the model with the selected settings. You can use built-in TensorFlow options or define and pass custom settings for the optimizer, loss function, and metrics. The default configurations match those used in the original study, but you can modify these settings according to your preferences. Early stopping is also configured with the original settings, but the parameters can be adjusted based on user preferences.

In [None]:
# define the early stopping and checkpoint callbacks, then compile the model with the optimizer, loss, and metrics we want to use
callback_FT = tf.keras.callbacks.EarlyStopping(monitor = 'val_loss', 
                                               mode = 'min', 
                                               verbose = 1, 
                                               patience = 5)


model_save_path_FT = 'output/weights/new_fine_tuned_model/fine_tuned_model_weights_detectability'

model_checkpoint_FT = tf.keras.callbacks.ModelCheckpoint(filepath=model_save_path_FT,
                                                         monitor='val_sparse_categorical_accuracy', 
                                                         mode='max',
                                                         verbose=1,
                                                         save_best_only=True, 
                                                         save_weights_only=True)

fine_tuned_model.compile(optimizer='adam',
                         loss='SparseCategoricalCrossentropy', 
                         metrics='sparse_categorical_accuracy') 

We store the result of training so that we can explore the metrics and the losses later. We specify the number of epochs for training and pass the training and validation data as previously described.

In [None]:
history_fine_tuned = fine_tuned_model.fit(fine_tune_data.tensor_train_data,
                                          validation_data=fine_tune_data.tensor_val_data,
                                          epochs=50, 
                                          callbacks=[callback_FT, model_checkpoint_FT])

## 7. Testing and Reporting (Fine-Tuned Model)

In the following cell, we load the best model weights obtained from fine-tuning. By default, the path points to the most recently fine-tuned model from the previous cell. Update the path if you wish to load different weights.

In [None]:
## Loading best model weights 

model_save_path_FT = 'output/weights/new_fine_tuned_model/fine_tuned_model_weights_detectability'

fine_tuned_model.load_weights(model_save_path_FT)

Generating predictions on the test data using the fine-tuned model with `model.predict`.

In [None]:
predictions_FT = fine_tuned_model.predict(fine_tune_data.tensor_test_data)

In [None]:
predictions_FT.shape

To generate reports and calculate evaluation metrics against predictions, we obtain the targets and the data for the specific dataset split. This can be achieved using the `DetectabilityDataset` class directly.

In [None]:
# access the test split and get the Classes column
test_targets_FT = fine_tune_data["test"]["Classes"]


# if needed, the decoded version of the classes can be retrieved by looking up the class names
test_targets_decoded_FT = [CLASSES_LABELS[x] for x in test_targets_FT]


test_targets_FT[0:5], test_targets_decoded_FT[0:5]

In [None]:
# The dataframe needed for the report

test_data_df_FT = pd.DataFrame(
    {
        "Sequences": fine_tune_data["test"]["_parsed_sequence"], # get the raw parsed sequences
        "Classes": test_targets_FT, # get the test targets from above
        "Proteins": fine_tune_data["test"]["Proteins"] # get the Proteins column from the dataset object
    }
)

test_data_df_FT.Sequences = test_data_df_FT.Sequences.apply(lambda x: "".join(x)) # join the sequences since they are a list of string amino acids.
test_data_df_FT.head(5)

Creating a report object with the test targets, predictions, and history to generate metrics and plots for the fine-tuned model. For more details, refer to Section 4: Testing and Reporting, which provides a detailed description of the same process for the initial or base model.

In [None]:
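# Since the detectability report expects the true labels in one-hot encoded format, we expand them here.
# Use the total number of classes rather than np.max(test_targets_FT) + 1, which would produce too few
# columns if the highest class happened to be absent from the test split.

num_classes = len(CLASSES_LABELS)
test_targets_FT_one_hot = np.eye(num_classes)[test_targets_FT]
test_targets_FT_one_hot.shape, len(test_targets_FT)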

In [None]:
report_FT = DetectabilityReport(test_targets_FT_one_hot, 
                                predictions_FT, 
                                test_data_df_FT, 
                                output_path = './output/report_on_Sinitcyn (Fine-tuned model)', 
                                history = history_fine_tuned, 
                                rank_by_prot = True,
                                threshold = None,                              
                                name_of_dataset = 'Sinitcyn test dataset',
                                name_of_model = 'Fine tuned model (new)')

#### Predictions report (Fine-tuned model)

In [None]:
results_df_FT = report_FT.detectability_report_table
results_df_FT

Generating a complete PDF report using the `generate_report` function.

In [None]:
report_FT.generate_report()

#### Generating the Evaluation Plots for the Fine-Tuned Model

##### Training and Validation Metrics

These plots show the training and validation metrics over epochs. The first plot displays the loss, and the second shows the sparse categorical accuracy.

In [None]:
report_FT.plot_keras_metric("loss")

In [None]:
report_FT.plot_keras_metric("sparse_categorical_accuracy")

##### ROC curve (Binary)

In [None]:
report_FT.plot_roc_curve_binary()

##### Confusion matrix (Binary)

In [None]:
report_FT.plot_confusion_matrix_binary()

##### ROC curve (Multi-class)

In [None]:
report_FT.plot_roc_curve()

##### Confusion matrix (Multi-class)

In [None]:
report_FT.plot_confusion_matrix_multiclass()

#### Heatmap of Average Error Between Actual and Predicted Classes

In [None]:
report_FT.plot_heatmap_prediction_prob_error()