# Predicting Peptide Collisional Cross Section from timsTOF Data

In this lab, we will build a Transformer model to predict the measured collisional cross section (CCS) of a peptide from its sequence and charge state, using subsets of the original training and test data from [Meier et al.](https://www.nature.com/articles/s41467-021-21352-8).
To accomplish this task, we'll create a Transformer encoder for peptide sequences and charge states, and add a feed forward neural network to predict the CCS of each peptide.

This lab makes use of Depthcharge, a Python package that Wout and Will have written to model mass spectrometry data with neural networks. Depthcharge provides nice building blocks for us to use within the PyTorch deep learning framework to build these models.

**Before proceeding with this notebook, make sure to switch to a GPU runtime on Google Colab.** To do this, click `Runtime` -> `Change runtime type`, and select `GPU` under `Hardware accelerator`. If you have a box called `GPU type`, we recommend selecting the `T4` GPU to run this notebook.

## Setup

The following code installs the additional dependencies we'll need: Depthcharge, PyTorch Lightning, and TensorBoard.
It also downloads the data that we'll be using directly from the Meier et al. code repository, [here](https://github.com/theislab/DeepCollisionalCrossSection).
In the end, we are left with our data in the working directory as `combined_sm.csv`.

In [None]:
%%capture
%%bash
pip install lightning tensorboard git+https://github.com/wfondrie/depthcharge.git
wget -nc https://github.com/theislab/DeepCollisionalCrossSection/raw/master/data/combined_sm.csv.tar.gz
tar -xzvf combined_sm.csv.tar.gz

## Import the libraries we'll need
To work with our data, we'll need a handful of standard data science tools (NumPy, Pandas, etc.).
For model building, we'll use PyTorch Lightning to wrap our model from Depthcharge, making it easy to train.

From Depthcharge, we'll use the following classes:
- `PeptideDataset` - This is a PyTorch Dataset that is designed to hold peptide sequences, their charge states, and additional metadata.
- `FeedForward` - This is a utility PyTorch Module for quickly creating feed forward neural networks.
- `PeptideTokenizer` - This class defines the amino acid alphabet, including modifications, that are valid tokens for our model. 
  It also tells Depthcharge how to convert a peptide sequence into tokens and back. 
  First-class support for ProForma formatted peptide sequences is included out-of-the-box.
- `PeptideTransformerEncoder` - This is a PyTorch Module that embeds the peptide and its residues using a Transformer architecture.

After importing these libraries, the following code also sets a plotting theme and a random seed for reproducibility.

In [None]:
import lightning.pytorch as pl
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import torch
from depthcharge.data import PeptideDataset
from depthcharge.feedforward import FeedForward
from depthcharge.tokenizers import PeptideTokenizer
from depthcharge.transformers import PeptideTransformerEncoder
from lightning.pytorch.callbacks import Callback
from sklearn.preprocessing import StandardScaler

# Set our plotting theme:
sns.set_style("ticks")

# Set random seeds
pl.seed_everything(42, workers=True)

## Parse the data
With our libraries loaded, we can now parse the CSV file from Meier et al. and sample 10% of the peptides from the dataset for our lab.
The peptide sequences are provided in a MaxQuant format, which we convert to be ProForma compliant.

We then split the data into training, validation, and test sets, matching the test data to that described in the paper.
The paper states that the ProteomeTools subset was used as the test set; these peptides are flagged in the `PT` column.

In [None]:
# Load the data
data = (
    pd.read_csv("combined_sm.csv", index_col=0)  # Read the data
    .sample(frac=0.1)  # Use only 10% of the data
    .reset_index()  # Renumber rows
    .rename(columns={"Modified sequence": "Seq"})  # Make the column shorter
)

# Convert sequences to be ProForma-compliant.
data["Seq"] = (
    data["Seq"]
    .str.replace("_(ac)", "[Acetyl]-", regex=False)
    .str.replace("M(ox)", "M[Oxidation]", regex=False)
    .str.replace("_", "", regex=False)
)

# Verify we've accounted for all modifications:
assert not data["Seq"].str.contains("(", regex=False).sum()

# Split the data
# The test data contains all of the ProteomeTools sequences.
test_df = data.loc[data["PT"], :]
data_df = data.loc[~data["PT"], :]

# Use 20% of the training set for validation.
n_train = int(0.8 * len(data_df))
train_df = data_df.iloc[:n_train, :].copy()
validation_df = data_df.iloc[n_train:, :].copy()

# Print the number in each set:
print("The training set contains", len(train_df), "peptides")
print("The validation set contains", len(validation_df), "peptides")
print("The test set contains", len(test_df), "peptides")
print("\nThis is what the training set looks like:")
train_df.head()
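
To make the sequence conversion above concrete, here is a minimal sketch that applies the same three replacements to plain Python strings. The function name `maxquant_to_proforma` and the example peptides are made up for illustration; they are not taken from the dataset.

```python
def maxquant_to_proforma(seq):
    """Apply the same three string replacements used on the `Seq` column."""
    return (
        seq.replace("_(ac)", "[Acetyl]-")  # N-terminal acetylation
        .replace("M(ox)", "M[Oxidation]")  # Oxidized methionine
        .replace("_", "")  # Strip the flanking underscores
    )


# Hypothetical MaxQuant-style sequences, for illustration only:
examples = ["_PEPTIDEK_", "_(ac)AM(ox)K_", "_LESM(ox)R_"]
converted = [maxquant_to_proforma(s) for s in examples]
for before, after in zip(examples, converted):
    print(before, "->", after)
```

The order of the replacements matters: the `_(ac)` prefix has to be handled before the remaining underscores are stripped, otherwise the acetylation marker would no longer match.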

## Create a tokenizer
Now that we know the peptides that we want to consider, we need to create a tokenizer that accounts for all of the amino acids and modifications that may be present. The tokenizer will split the peptide strings into the amino acid residues and modifications that comprise them. Fortunately, the `PeptideTokenizer` class has a `from_proforma()` method that allows us to extract the amino acids and modifications from a collection of peptides.

In [None]:
# Create the tokenizer from every peptide, so that residues and modifications
# from the training and test sets are included too:
tokenizer = PeptideTokenizer.from_proforma(data["Seq"])

# See the amino acid tokens:
pd.DataFrame(tokenizer.residues.items(), columns=["Token", "Mass"])

It looks like our tokenizer has captured all of the residues we expect.
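
If you're curious what tokenizing a peptide means in practice, the sketch below splits ProForma strings into residue-level tokens with a regular expression and builds a small vocabulary from them. This is an illustration only, not how Depthcharge implements `PeptideTokenizer`: the pattern covers just the simple modifications used in this lab, and the names `TOKEN_PATTERN`, `split_proforma`, and `vocab`, as well as the example peptides, are hypothetical.

```python
import re

# A token is either an N-terminal modification such as "[Acetyl]-" or a
# single residue letter with an optional bracketed modification.
TOKEN_PATTERN = re.compile(r"\[[^\]]+\]-|[A-Z](?:\[[^\]]+\])?")


def split_proforma(seq):
    """Split a simple ProForma string into residue-level tokens."""
    return TOKEN_PATTERN.findall(seq)


# Hypothetical peptides, for illustration only:
peptides = ["PEPTIDEK", "[Acetyl]-AM[Oxidation]K"]
tokens = [split_proforma(p) for p in peptides]

# Build a vocabulary from every token observed.
# Starting at 1 (keeping 0 free for padding) is an assumption of this sketch.
unique_tokens = sorted({tok for toks in tokens for tok in toks})
vocab = {tok: idx for idx, tok in enumerate(unique_tokens, start=1)}

print(tokens[1])
print(vocab)
```

In this sketch, any token that was not seen when `vocab` was built has no index, which is why a vocabulary is best built from every peptide the model will encounter.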

## Preparing Datasets
When modeling data using PyTorch, we typically need to put our data into a PyTorch `Dataset` and serve it to our model using a PyTorch `DataLoader`. 
Here, we use Depthcharge's `PeptideDataset` class, which handles transforming the peptide strings for modeling. 
Because this dataset is small from a memory perspective, we go ahead and load it onto the GPU as well, to increase our training speed.

We also transform the measured CCS using standard scaling, making it an easier value for the model to learn.

In [None]:
scaler = StandardScaler()

# Create a PeptideDataset holding the training data:
train_dataset = PeptideDataset(
    tokenizer,
    train_df["Seq"].to_numpy(),
    torch.tensor(train_df["Charge"].to_numpy()),
    torch.tensor(scaler.fit_transform(train_df[["CCS"]]).flatten()),
)

# Create a PeptideDataset containing the validation data:
validation_dataset = PeptideDataset(
    tokenizer,
    validation_df["Seq"].to_numpy(),
    torch.tensor(validation_df["Charge"].to_numpy()),
    torch.tensor(scaler.transform(validation_df[["CCS"]]).flatten()),
)

# Create a PeptideDataset containing the test data:
test_dataset = PeptideDataset(
    tokenizer,
    test_df["Seq"].to_numpy(),
    torch.tensor(test_df["Charge"].to_numpy()),
)

# Transfer all of the data to the GPU.
# This data is relatively small so it can all live there.
# Many datasets won't all fit on the GPU at once though.
for dset in (train_dataset, validation_dataset, test_dataset):
    # Use a name other than `data` so the DataFrame above isn't overwritten.
    dset.tensors = tuple(tensor.to("cuda") for tensor in dset.tensors)
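
The standard scaling applied to the CCS values works as follows: the mean and standard deviation are learned from the training set only, and the same two numbers are then reused for the validation set and for converting predictions back to the original units. Here is a minimal pure-Python sketch of that arithmetic. The CCS values and the helper names `scale` and `unscale` are hypothetical; `StandardScaler` does the equivalent with NumPy arrays.

```python
import math

# Hypothetical CCS values, for illustration only:
train_ccs = [400.0, 500.0, 600.0]
validation_ccs = [450.0, 650.0]

# "Fit": learn the mean and population standard deviation from training data.
mean = sum(train_ccs) / len(train_ccs)
std = math.sqrt(sum((x - mean) ** 2 for x in train_ccs) / len(train_ccs))


def scale(values):
    """Center and scale values using the training statistics."""
    return [(x - mean) / std for x in values]


def unscale(values):
    """Map scaled values back to the original units."""
    return [z * std + mean for z in values]


scaled_validation = scale(validation_ccs)
restored = unscale(scaled_validation)
print(scaled_validation)
print(restored)
```

Reusing the training statistics keeps information about the validation and test sets from influencing preprocessing.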

Using our datasets, we create the PyTorch DataLoaders that we'll need:

In [None]:
# Create data loaders to feed data to the model:
train_loader = train_dataset.loader(batch_size=128, shuffle=True)
validation_loader = validation_dataset.loader(batch_size=1024)
test_loader = test_dataset.loader(batch_size=1024)
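
Conceptually, a data loader walks through a dataset in batches of `batch_size` items, optionally shuffling the order first. The sketch below shows only that core idea in plain Python. The function `iterate_batches` and its `seed` parameter are hypothetical, and a real PyTorch `DataLoader` does considerably more, such as collating items into tensors.

```python
import random


def iterate_batches(n_items, batch_size, shuffle=False, seed=42):
    """Yield lists of item indices, one list per batch."""
    indices = list(range(n_items))
    if shuffle:
        # A dedicated generator keeps the shuffle reproducible.
        random.Random(seed).shuffle(indices)

    for start in range(0, n_items, batch_size):
        yield indices[start : start + batch_size]


# Ten items in batches of four: the final batch holds the remainder.
batches = list(iterate_batches(n_items=10, batch_size=4, shuffle=True))
print([len(batch) for batch in batches])
```

Shuffling is typically used for the training data only, since the order of the validation and test examples does not affect their predictions.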

## Create a model

Time to create a deep learning model using PyTorch Lightning and Depthcharge! 
Our model consists of a `PeptideTransformerEncoder` module to embed peptides and a `FeedForward` module to predict CCS from the latent representation. 
With PyTorch Lightning, we need to specify the modules that comprise our model, define the optimizer(s) we will use to train it, and tell Lightning how to run the model when training, validating, and predicting.

For this task, we're trying to minimize the mean squared error (MSE) loss function:
$$ L = \frac{1}{n}\sum^{n}_{i=1}(Y_i - \hat{Y}_i)^2$$

Where $n$ is the number of peptides, $Y_i$ is the measured CCS of peptide $i$, and $\hat{Y}_i$ is its predicted CCS.
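
As a quick worked example of this formula, the sketch below computes the MSE by hand for four made-up values. The helper `mse` and the numbers are hypothetical; during training, PyTorch's `torch.nn.functional.mse_loss` computes the same quantity on tensors.

```python
def mse(measured, predicted):
    """Mean squared error between two equal-length sequences."""
    if len(measured) != len(predicted):
        raise ValueError("Inputs must have the same length.")

    squared_errors = [(y - y_hat) ** 2 for y, y_hat in zip(measured, predicted)]
    return sum(squared_errors) / len(squared_errors)


# Hypothetical scaled CCS values, for illustration only:
measured = [0.0, 1.0, -1.0, 2.0]
predicted = [0.5, 1.0, -2.0, 2.0]

loss = mse(measured, predicted)
print(loss)  # (0.25 + 0 + 1 + 0) / 4 = 0.3125
```

Because the errors are squared, the single prediction that is off by 1 contributes four times as much to the loss as the one that is off by 0.5.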

In [None]:
class CCSPredictor(pl.LightningModule):
    """A Transformer model for CCS prediction."""

    def __init__(self, tokenizer, d_model, n_layers):
        """Initialize the CCSPredictor."""
        super().__init__()
        self.peptide_encoder = PeptideTransformerEncoder(
            n_tokens=tokenizer,
            d_model=d_model,
            n_layers=n_layers,
        )
        self.ccs_head = FeedForward(d_model, 1, 3)

    def step(self, batch, batch_idx):
        """A training/validation/inference step."""
        seqs, charges, *ccs = batch
        try:
            embedded, _ = self.peptide_encoder(seqs, charges)
        except IndexError:
            # Show the offending batch to aid debugging, then re-raise.
            print(batch)
            raise

        pred = self.ccs_head(embedded[:, 0, :]).flatten()
        if ccs:
            ccs = ccs[0].type_as(pred)
            loss = torch.nn.functional.mse_loss(pred, ccs)
        else:
            loss = None

        return pred, loss

    def training_step(self, batch, batch_idx):
        """A training step."""
        _, loss = self.step(batch, batch_idx)
        self.log(
            "train_loss", loss, on_step=False, on_epoch=True, prog_bar=True
        )
        return loss

    def validation_step(self, batch, batch_idx):
        """A validation step."""
        _, loss = self.step(batch, batch_idx)
        self.log(
            "validation_loss",
            loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
        )
        return loss

    def predict_step(self, batch, batch_idx):
        """An inference step."""
        pred, _ = self.step(batch, batch_idx)
        return pred

    def configure_optimizers(self):
        """Configure the optimizer for training."""
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-5)
        return optimizer

## Prepare the model

With our model defined and our data loaders ready to go, it's almost time to fit the model to the data.
The PyTorch Lightning `Trainer` will handle a lot of the training for us. 
We cap training at 30 epochs here. An early stopping criterion, which would end training once the MSE on our validation dataset stops improving, could be added as another callback. 
This model should take ~5 min to train.

First we need to create a model and a way to track the loss as the model learns:

In [None]:
# Create a model.
# If you have time, try changing d_model and n_layers.
model = CCSPredictor(tokenizer, d_model=32, n_layers=4)


# Create a way to log our losses.
class LossLogger(Callback):
    """A helper class to log our loss function."""

    def __init__(self):
        """Initialize the LossLogger."""
        self.history = []

    def on_train_epoch_end(self, trainer, pl_module):
        """Add the loss to the history."""
        self.history.append(dict(trainer.callback_metrics))


# Create our logger:
logger = LossLogger()

# Create the model trainer.
# If you have time, try changing max_epochs
trainer = pl.Trainer(callbacks=[logger], max_epochs=30)
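
If you'd like to see what an early stopping criterion does, here is a minimal pure-Python sketch of the idea: stop once the validation loss has failed to improve for `patience` consecutive epochs. The function `stopping_epoch`, its `patience` and `min_delta` parameters, and the loss values are all hypothetical, and this is not Lightning's implementation. For real training runs, Lightning provides its own `EarlyStopping` callback in `lightning.pytorch.callbacks`.

```python
def stopping_epoch(validation_losses, patience=3, min_delta=0.0):
    """Return the index of the epoch at which training would stop.

    Training stops once the loss has not improved by more than `min_delta`
    for `patience` consecutive epochs. If that never happens, the index of
    the last epoch is returned.
    """
    best = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(validation_losses):
        if loss < best - min_delta:
            # A new best loss resets the counter.
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch

    return len(validation_losses) - 1


# Hypothetical validation losses, one per epoch:
history = [1.00, 0.60, 0.45, 0.46, 0.47, 0.45, 0.50, 0.40]
print(stopping_epoch(history, patience=3))
```

With these made-up losses, a patience of 3 stops training before the later improvement to 0.40 is ever reached, which illustrates the trade-off: a small patience saves time but can end training too soon.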

## Fit the model

Now we're finally ready to fit the model!

In [None]:
trainer.fit(
    model=model,
    train_dataloaders=train_loader,
    val_dataloaders=validation_loader,
)

Congratulations! You just trained a deep learning model. 

Let's take a look at the loss curves from training now:

In [None]:
losses = pd.DataFrame(logger.history)

plt.figure()
plt.plot(losses["train_loss"], label="Training MSE")
plt.plot(losses["validation_loss"], label="Validation MSE")
plt.legend()
plt.xlabel("Epoch")
plt.ylabel("MSE Loss")
plt.show()

## Predict on the Validation dataset

We now want to see how we've done, aside from just looking at the MSE. 
To get the predicted CCS for every peptide in our validation set, we use the `predict()` method for the trainer on our validation data loader.
We then visualize our results using a hexbin plot, which in this case is like a scatterplot and heatmap all in one.

In [None]:
pred = trainer.predict(model, validation_loader)
validation_df = validation_df.copy()
validation_df["pred"] = scaler.inverse_transform(
    torch.cat(pred).detach().cpu().numpy()[:, None]
).flatten()

plt.viridis()
plt.figure()
plt.hexbin(
    validation_df["CCS"],
    validation_df["pred"],
    mincnt=1,
    gridsize=200,
    bins="log",
)
plt.axis("equal")
plt.xlabel("Measured CCS")
plt.ylabel("Predicted CCS")
plt.show()

This looks pretty good to me. 
If we want to perform further tweaks and optimizations, we should turn back and do them now. 
If not, we're ready to get our predictions for the test set, after which we should cease trying to optimize our model.
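
A plot is helpful, but a single summary number makes it easier to compare models while tuning. One option, sketched below in plain Python, is the median relative error between the measured and predicted CCS. The function `median_relative_error` and the example values are hypothetical and chosen only for illustration.

```python
import statistics


def median_relative_error(measured, predicted):
    """Median of |predicted - measured| / measured, as a fraction."""
    errors = [abs(p - m) / m for m, p in zip(measured, predicted)]
    return statistics.median(errors)


# Hypothetical CCS values, for illustration only:
measured = [400.0, 500.0, 620.0]
predicted = [404.0, 490.0, 651.0]

mre = median_relative_error(measured, predicted)
print(f"Median relative error: {mre:.1%}")
```

On the validation DataFrame, the same quantity could be computed with `((validation_df["pred"] - validation_df["CCS"]).abs() / validation_df["CCS"]).median()`.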

## Predict on the Test dataset

Like with our validation data, we use the `predict()` method to get the predicted CCS for each of our test dataset peptides. When you're sure that you're ready to proceed to the test set, go ahead and delete the first line in the cell below, then run it!

In [None]:
raise RuntimeError("Are you sure you're ready to run this?")

trainer = pl.Trainer()
pred = trainer.predict(model, test_loader)

test_df = test_df.copy()
test_df["pred"] = scaler.inverse_transform(
    torch.cat(pred).detach().cpu().numpy()[:, None]
).flatten()

plt.viridis()
plt.figure()
plt.hexbin(
    test_df["CCS"],
    test_df["pred"],
    mincnt=1,
    gridsize=200,
    bins="log",
)
plt.axis("equal")
plt.xlabel("Measured CCS")
plt.ylabel("Predicted CCS")
plt.show()

Nice! This looks great. 