# Precursor Charge State Prediction

This notebook presents a short walkthrough of reading a dataset and training a model for precursor charge state prediction. The dataset is an example dataset extracted from a ProteomeTools dataset generated at the **Chair of Bioanalytics** at the **School of Life Sciences** at the **Technical University of Munich**.

DLOmix, the framework used here, is a custom wrapper on top of Keras/TensorFlow.

In [None]:
# install the DLOmix package in the current environment using pip

!python -m pip install -q dlomix

In [None]:
import dlomix
dlomix.__version__

The available modules in the framework are as follows:

- `constants`: constants to be used in the framework (e.g. amino acid alphabet mapping)
- `data`: classes for representing datasets, wrappers around Hugging Face datasets to process input data and generate tensor datasets
- `eval`: custom evaluation metrics implemented in Keras/TF to work as `metrics` for model training
- `layers`: custom layer implementation required for the different models
- `models`: different model implementations (e.g. for Retention Time Prediction or Precursor Charge State Prediction)
- `pipelines`: complete pipelines to run a task (e.g. Retention Time prediction)

**Note**: `reports` and `pipelines` are work in progress; some functionalities are not complete.

In [None]:
from dlomix import constants, data, eval, layers, models, pipelines, reports
print([x for x in dir(dlomix) if not x.startswith("_")])

## Required imports

In [None]:
import tensorflow as tf

from dlomix.constants import PTMS_ALPHABET
from dlomix.data import ChargeStateDataset
from dlomix.eval import adjusted_mean_absolute_error
from dlomix.models import ChargeStatePredictor

## 1. Load Data

We can import the dataset class and create an object of type `ChargeStateDataset`. This object wraps around a Hugging Face dataset and can generate TensorFlow Dataset or Torch Dataset objects for training, validation, and testing. The splits are controlled by the arguments `val_ratio`, `val_data_source`, and `test_data_source`.

The most important columns of the charge state dataset are:
* `"modified_sequence"`: the peptide sequences, with modifications annotated using the UNIMOD encoding.
* `"charge_state_dist"`: the relative charge state distribution per peptide. Use it together with `model_flavour="relative"`, which is the default.
* `"most_abundant_charge_state"`: the most abundant charge state (as a binary vector) per peptide. Use it together with `model_flavour="dominant"`.
* `"observed_charge_states"`: all observed charge states (as a binary vector) per peptide. Use it together with `model_flavour="observed"`.
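
To make the three label columns more concrete, here is a minimal sketch of how such vectors could be derived for a single peptide. It is only an illustration: the helper `encode_charge_labels`, the input format (a dict of summed intensities per charge state), and the convention that index `i` stands for charge `i + 1` are assumptions, not the procedure used to build the dataset.

```python
def encode_charge_labels(intensity_by_charge, num_classes=6):
    """Build three label encodings from per-charge intensities.

    intensity_by_charge is a hypothetical input: a dict mapping a charge
    state (1..num_classes) to a non-negative summed intensity.
    """
    intensities = [
        float(intensity_by_charge.get(charge, 0.0))
        for charge in range(1, num_classes + 1)
    ]
    total = sum(intensities)
    if total <= 0:
        raise ValueError("at least one charge state needs a positive intensity")

    # relative distribution: intensities normalised so that they sum to 1
    charge_state_dist = [value / total for value in intensities]

    # one-hot vector marking the single most abundant charge state
    top_index = intensities.index(max(intensities))
    most_abundant = [1 if i == top_index else 0 for i in range(num_classes)]

    # multi-hot vector marking every charge state that was seen at all
    observed = [1 if value > 0 else 0 for value in intensities]

    return charge_state_dist, most_abundant, observed


# a peptide seen mostly as 2+ and sometimes as 3+
dist, dominant, observed = encode_charge_labels({2: 300.0, 3: 100.0})
print(dist)      # [0.0, 0.75, 0.25, 0.0, 0.0, 0.0]
print(dominant)  # [0, 1, 0, 0, 0, 0]
print(observed)  # [0, 1, 1, 0, 0, 0]
```

All three vectors have the same length, so the choice of label column mainly changes how the model output is interpreted.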

In [None]:
DATA_PATH = "Wilhelmlab/prospect-ptms-charge"   # complete PROSPECT dataset prepared for charge state prediction
BATCH_SIZE = 128

In [None]:
d = ChargeStateDataset(
    data_format="hub",
    data_source=DATA_PATH,
    sequence_column="modified_sequence",
    label_column="charge_state_dist",   # use this column for relative charge state distribution
    max_seq_len=30,
    batch_size=BATCH_SIZE,
)

Now we have a CS dataset that can be used directly with standard or custom `Keras` models. This wrapper contains the splits we chose when creating it. In our case, they are training and validation splits. To get the TF Dataset, we call the attributes `.tensor_train_data` and `.tensor_val_data`.

In [None]:
print("Hugging Face Dataset:", d)

print("Training examples:", len(d["train"]))
print("One training batch looks like:")
for x in d.tensor_train_data:
    print(x)
    break

print("Validation examples:", len(d["val"]))

## 2. Model

We can now create the model. We will use the relative charge state distribution version (set via the parameter `model_flavour="relative"`) of the Prosit-based Precursor Charge State Prediction model `ChargeStatePredictor`. It has default working arguments, but most of the parameters can be customized.

**Note**: It is important to ensure that the padding length used for the dataset object (`max_seq_len`) is equal to the sequence length passed to the model (`seq_length`).
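
One simple way to keep the two lengths in sync is to define the value once and reuse it. The sketch below illustrates this with a plain Python padding helper. The names `SEQ_LENGTH`, `PAD_ID`, and `pad_token_ids` are hypothetical, the padding index `0` is an assumption, and how DLOmix treats sequences longer than the limit is not covered here (the sketch simply raises an error).

```python
SEQ_LENGTH = 30  # single value to reuse for both dataset padding and model input
PAD_ID = 0       # assumed padding index, for illustration only


def pad_token_ids(token_ids, seq_length=SEQ_LENGTH, pad_id=PAD_ID):
    """Right-pad a list of integer token ids to exactly seq_length entries."""
    if len(token_ids) > seq_length:
        raise ValueError(
            f"sequence of length {len(token_ids)} exceeds seq_length={seq_length}"
        )
    return list(token_ids) + [pad_id] * (seq_length - len(token_ids))


padded = pad_token_ids([5, 12, 7, 3])
print(len(padded))  # 30
```

In the notebook, the same idea would mean passing one shared constant as `max_seq_len` to the dataset and as `seq_length` to the model.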

The three model flavours of `ChargeStatePredictor` are:

1. Dominant Charge State Prediction:
   - Task: Predict the dominant charge state of a given peptide sequence.
   - Model: Uses a deep learning model (RNN-based) inspired by Prosit's architecture to predict the most likely charge state.

2. Observed Charge State Prediction:
   - Task: Predict the observed charge states for a given peptide sequence.
   - Model: Uses a multi-label classification approach to predict all possible charge states.

3. Relative Charge State Prediction:
   - Task: Predict the proportion of each charge state for a given peptide sequence.
   - Model: Uses a regression approach to predict the proportion of each charge state.
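
The three tasks differ mainly in how a vector of six output values is read. The following sketch shows one common way to interpret such a vector for each task. The output activations actually used inside `ChargeStatePredictor` are not stated in this notebook, so softmax, sigmoid with a 0.5 threshold, and clipping plus renormalisation are assumptions chosen for illustration, and all function names are hypothetical.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of raw scores."""
    shift = max(scores)
    exps = [math.exp(score - shift) for score in scores]
    total = sum(exps)
    return [value / total for value in exps]


def sigmoid(score):
    return 1.0 / (1.0 + math.exp(-score))


def dominant_charge(scores):
    """Single-label view: the charge state with the highest probability."""
    probabilities = softmax(scores)
    return probabilities.index(max(probabilities)) + 1


def observed_charges(scores, threshold=0.5):
    """Multi-label view: every charge state whose probability passes the threshold."""
    return [i + 1 for i, score in enumerate(scores) if sigmoid(score) >= threshold]


def relative_distribution(outputs):
    """Regression view: clip negative outputs to zero and renormalise to sum to 1."""
    clipped = [max(value, 0.0) for value in outputs]
    total = sum(clipped)
    if total == 0:
        raise ValueError("all outputs are zero or negative")
    return [value / total for value in clipped]


scores = [-2.0, 3.0, 1.0, -1.0, -3.0, -4.0]
print(dominant_charge(scores))   # 2
print(observed_charges(scores))  # [2, 3]
print(relative_distribution([0.0, 0.7, 0.35, -0.05, 0.0, 0.0]))
```

The same raw scores lead to a single charge state, a set of charge states, or a distribution, depending on the task.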

In [None]:
optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001)

In [None]:
model = ChargeStatePredictor(
    num_classes=6, seq_length=30, alphabet=PTMS_ALPHABET, model_flavour="relative"
)

## 3. Training

We can then compile and train the model like a standard Keras model. The loss value should decrease as training progresses.

In [None]:
model.compile(
    optimizer=optimizer,
    loss="mean_squared_error",
    metrics=[adjusted_mean_absolute_error],
)
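
For the relative flavour, the `"mean_squared_error"` loss compares the predicted and the true charge state distribution entry by entry. The plain Python sketch below computes this value for one peptide. The helper name and the example vectors are made up for illustration, and it does not attempt to reproduce `adjusted_mean_absolute_error`.

```python
def mean_squared_error(y_true, y_pred):
    """Average of the squared differences between two equally long vectors."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    squared = [(t - p) ** 2 for t, p in zip(y_true, y_pred)]
    return sum(squared) / len(squared)


target = [0.0, 0.75, 0.25, 0.0, 0.0, 0.0]     # made-up true distribution
prediction = [0.0, 0.5, 0.5, 0.0, 0.0, 0.0]   # made-up model output
print(mean_squared_error(target, prediction))
```

During training, Keras computes this per example over the six charge state entries and then averages over the batch.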

In [None]:
history = model.fit(
    d.tensor_train_data, 
    validation_data=d.tensor_val_data,
    epochs=1,  # reduced for demonstration
)

### Train History

In [None]:
import matplotlib.pyplot as plt

In [None]:
def plot_learning_curves(history, title='Learning Curves'):
    history_dict = history.history
    loss = history_dict['loss']
    val_loss = history_dict.get('val_loss', [])
    
    epochs = range(1, len(loss) + 1)
    
    plt.figure(figsize=(8, 5))
    plt.plot(epochs, loss, 'b-', label='Training Loss')
    if val_loss:
        plt.plot(epochs, val_loss, 'r-', label='Validation Loss')
    plt.title(title)
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()

In [None]:
plot_learning_curves(history, title='Charge State Distribution Model')

## 4. Testing

The `ChargeStateDataset` also contains a test split that we can use to test our model.

**Note**: Currently there is no reporting module available for CS prediction.

In [None]:
test_targets = d["test"]["charge_state_dist"]
test_sequences = d["test"]["modified_sequence"]

In [None]:
predictions = model.predict(test_sequences)
print(test_sequences[:5])
print(test_targets[:5])
print(predictions[:5])
print(predictions.shape, len(test_targets))

## 5. Saving and Loading Models

Models can be saved the same way any Keras model would be saved. It is better to save the weights rather than the full model, since this makes loading the model again easier and more platform-independent. The extra step needed is to create a model object and then load the weights into it.

In [None]:
# save the model weights

save_path = "./output/csd_model"
model.save_weights(save_path)

In [None]:
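# models can be loaded later by creating a model object with the same arguments
# used for the trained model and then loading the weights

trained_model = ChargeStatePredictor(
    num_classes=6, seq_length=30, alphabet=PTMS_ALPHABET, model_flavour="relative"
)
trained_model.load_weights(save_path)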

In [None]:
# make predictions with the loaded model
loaded_model_predictions = trained_model.predict(test_sequences)
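
A quick sanity check after reloading is to confirm that both models produce the same predictions. This only holds if the reloaded model was built with the same arguments as the trained one. The sketch below uses a hypothetical helper `max_abs_difference` and small stand-in values. In the notebook, `predictions` and `loaded_model_predictions` would take their place.

```python
def max_abs_difference(first, second):
    """Largest element-wise absolute difference between two 2D predictions."""
    if len(first) != len(second):
        raise ValueError("both inputs must contain the same number of rows")
    worst = 0.0
    for row_a, row_b in zip(first, second):
        if len(row_a) != len(row_b):
            raise ValueError("rows must have the same length")
        for a, b in zip(row_a, row_b):
            worst = max(worst, abs(a - b))
    return worst


# stand-in values for illustration only
original = [[0.0, 0.75, 0.25], [0.1, 0.6, 0.3]]
reloaded = [[0.0, 0.75, 0.25], [0.1, 0.6, 0.3]]

difference = max_abs_difference(original, reloaded)
print("weights restored correctly:", difference < 1e-6)
```

A difference close to zero suggests that the weights were restored correctly.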