# Fine-Tuning Protein Language Models

In [None]:
#! pip install transformers[torch] evaluate datasets requests pandas scikit-learn seaborn matplotlib openpyxl

In this notebook, we're going to do some transfer learning to fine-tune some large, pre-trained protein language models on tasks of interest. If that sentence feels a bit intimidating to you, don't panic - there's [a blog post](https://huggingface.co/blog/deep-learning-with-proteins) that explains the concepts here in much more detail.

The specific model we're going to use is ESM-2, which is the state-of-the-art protein language model at the time of writing (November 2022). The citation for this model is [Lin et al, 2022](https://www.biorxiv.org/content/10.1101/2022.07.20.500902v1).

There are several ESM-2 checkpoints with differing model sizes. Larger models will generally have better accuracy, but they require more GPU memory and will take much longer to train. The available ESM-2 checkpoints (at time of writing) are:

| Checkpoint name | Num layers | Num parameters |
|------------------------------|----|----------|
| `esm2_t48_15B_UR50D`         | 48 | 15B     |
| `esm2_t36_3B_UR50D`          | 36 | 3B      |
| `esm2_t33_650M_UR50D`        | 33 | 650M    |
| `esm2_t30_150M_UR50D`        | 30 | 150M    |
| `esm2_t12_35M_UR50D`         | 12 | 35M     |
| `esm2_t6_8M_UR50D`           | 6  | 8M      |

Note that the larger checkpoints may be very difficult to train without a large cloud GPU like an A100 or H100, and the largest 15B parameter checkpoint will probably be impossible to train on **any** single GPU! Also, note that memory usage for attention during training will scale as `O(batch_size * num_layers * seq_len^2)`, so larger models on long sequences will use quite a lot of memory! The code below loads the `esm2_t33_650M_UR50D` checkpoint. If you run out of GPU memory, switch to a smaller one such as `esm2_t12_35M_UR50D`, which should train on any Colab instance or modern GPU, and update `rep_layers` to match its number of layers.
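To get a feel for the scaling rule above, here is a rough back-of-the-envelope sketch. It only counts the raw `seq_len x seq_len` attention matrices and assumes one matrix per head, per layer and per sample is kept for the backward pass. The number of heads (20), the batch size of 8 and the 4 bytes per value (float32) are assumptions for illustration; real memory use depends on the attention implementation and on everything else stored during training.

```python
def attention_matrix_bytes(batch_size, num_layers, num_heads, seq_len, bytes_per_value=4):
    """Bytes needed to hold one seq_len x seq_len attention matrix
    per head, per layer and per sample."""
    return batch_size * num_layers * num_heads * seq_len ** 2 * bytes_per_value


# Assumed configuration: 33 layers, 20 heads, float32 values, batch size 8.
for seq_len in (256, 512, 1024):
    n_bytes = attention_matrix_bytes(batch_size=8, num_layers=33, num_heads=20, seq_len=seq_len)
    print(f"seq_len={seq_len:5d}: ~{n_bytes / 2**30:.1f} GiB of attention matrices")
```

Doubling the sequence length quadruples this estimate, which is why truncating or filtering very long sequences is often the cheapest way to save memory.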

In [None]:
model_checkpoint = "facebook/esm2_t33_650M_UR50D"
#model_checkpoint = 'facebook/esm2_t36_3B_UR50D'

rep_layers = 33  # number of transformer layers in the chosen checkpoint (the "t33" in its name)

# Data

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch

In [None]:
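# `data` is assumed to be a pandas DataFrame that has already been loaded
# and contains at least the 'sequence' and 'ss_H' columns.
cols = ['sequence', 'ss_H']
data = data[cols].rename(columns={'ss_H': 'label'})  # returns a copy, avoiding SettingWithCopyWarning
data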

In [None]:
# Quick check to make sure we got it right
sequences = data['sequence'].to_list()
labels = data['label'].to_list()
assert len(sequences) == len(labels)

## Splitting the data

Since the data we're loading isn't prepared for us as a machine learning dataset, we'll have to split the data into train and test sets ourselves! We can use sklearn's function for that:

In [None]:
from sklearn.model_selection import train_test_split

train_sequences, test_sequences, train_labels, test_labels = train_test_split(sequences, labels, test_size=0.2, shuffle=True)
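The split above is random, so it changes every time the cell is re-run unless a seed is fixed (in scikit-learn this is the `random_state` argument of `train_test_split`). The sketch below shows roughly what a seeded shuffle-and-split does, using only the standard library. The function name `simple_train_test_split` and the toy sequences are made up for illustration, and scikit-learn's own implementation differs in its details.

```python
import random


def simple_train_test_split(items, labels, test_size=0.2, seed=0):
    """Shuffle the indices with a fixed seed, then use the first
    `test_size` fraction as the test set and the rest as the training set."""
    indices = list(range(len(items)))
    random.Random(seed).shuffle(indices)
    n_test = int(round(len(items) * test_size))
    test_idx, train_idx = indices[:n_test], indices[n_test:]

    def pick(values, idx):
        return [values[i] for i in idx]

    return pick(items, train_idx), pick(items, test_idx), pick(labels, train_idx), pick(labels, test_idx)


# Toy data: ten made-up sequences with made-up labels.
toy_sequences = ["MKV", "ACD", "GHI", "LMN", "PQR", "STV", "WYA", "CDE", "FGH", "IKL"]
toy_labels = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]

train_s, test_s, train_l, test_l = simple_train_test_split(toy_sequences, toy_labels, seed=42)
print(len(train_s), "training examples,", len(test_s), "test examples")
```

Calling the function twice with the same seed returns the same split, which makes later results comparable across runs.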

## Tokenizing the data

All inputs to neural nets must be numerical. The process of converting strings into numerical indices suitable for a neural net is called **tokenization**. For natural language this can be quite complex, as usually the network's vocabulary will not contain every possible word, which means the tokenizer must handle splitting rarer words into pieces, as well as all the complexities of capitalization and unicode characters and so on.

With proteins, however, things are very easy. In protein language models, each amino acid is converted to a single token. Every model on `transformers` comes with an associated `tokenizer` that handles tokenization for it, and protein language models are no different. Let's get our tokenizer!

Models loaded through `transformers` are downloaded to `~/.cache/huggingface/hub/`, while checkpoints loaded through `fair-esm` are stored in `~/.cache/torch/hub/checkpoints`.

In [None]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [None]:
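max_pad = max(len(s) for s in sequences) + 2  # assumption: longest sequence plus the <cls> and <eos> tokens
tokenizer(train_sequences[0], max_length=max_pad, truncation=True, padding='max_length')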

In `input_ids`, `0` is the `<cls>` token, `1` is `<pad>` and `2` is `<eos>`; each amino acid has its own ID above these special tokens. In `attention_mask`, `1` marks a real token and `0` marks padding.

This looks good! We can see that our sequence has been converted into `input_ids`, which is the tokenized sequence, and an `attention_mask`. The attention mask handles the case when we have sequences of variable length - in those cases, the shorter sequences are padded with blank "padding" tokens, and the attention mask is padded with 0s to indicate that those tokens should be ignored by the model.
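To make the relationship between `input_ids` and `attention_mask` concrete, here is a toy tokenizer that needs no model download. The vocabulary below is invented for illustration (only the positions of `<cls>`, `<pad>` and `<eos>` are chosen to mirror the output above), so the amino acid IDs will not match those of the real ESM-2 tokenizer.

```python
# Toy vocabulary for illustration only; the real IDs come from the ESM-2 tokenizer.
SPECIAL = {"<cls>": 0, "<pad>": 1, "<eos>": 2}
AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"
VOCAB = {**SPECIAL, **{aa: i + len(SPECIAL) for i, aa in enumerate(AMINO_ACIDS)}}


def toy_tokenize(sequence, max_length):
    """One token per residue, wrapped in <cls>/<eos>, truncated and padded to max_length."""
    residues = [VOCAB[aa] for aa in sequence][: max_length - 2]  # leave room for <cls> and <eos>
    ids = [VOCAB["<cls>"]] + residues + [VOCAB["<eos>"]]
    n_pad = max_length - len(ids)
    return {
        "input_ids": ids + [VOCAB["<pad>"]] * n_pad,
        "attention_mask": [1] * len(ids) + [0] * n_pad,  # 1 = real token, 0 = padding
    }


print(toy_tokenize("MKV", max_length=8))
# {'input_ids': [0, 13, 11, 20, 2, 1, 1, 1], 'attention_mask': [1, 1, 1, 1, 1, 0, 0, 0]}
```

The three trailing `1`s in `input_ids` are padding tokens, and the matching `0`s in `attention_mask` tell the model to ignore them.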

So now, let's tokenize our whole dataset. Note that we don't need to do anything with the labels, as they're already in the format we need.

In [None]:
train_tokenized = tokenizer(train_sequences, max_length=max_pad, truncation=True, padding='max_length')
test_tokenized = tokenizer(test_sequences, max_length=max_pad, truncation=True, padding='max_length')
test_tokenized

## Dataset creation

Now we want to turn this data into a dataset that PyTorch can load samples from. We can use the HuggingFace `Dataset` class for this, although if you prefer you can also use `torch.utils.data.Dataset`, at the cost of some more boilerplate code.

In [None]:
from datasets import Dataset

train_dataset = Dataset.from_dict(train_tokenized)
test_dataset = Dataset.from_dict(test_tokenized)

train_dataset

This looks good, but we're missing our labels! Let's add those on as an extra column to the datasets.

In [None]:
train_dataset = train_dataset.add_column("labels", train_labels)
test_dataset = test_dataset.add_column("labels", test_labels)
train_dataset

## Model loading

Next, we want to load our model. Make sure to use exactly the same model as you used when loading the tokenizer, or your model might not understand the tokenization scheme you're using!

In [None]:
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer

# A single output unit turns the classification head into a regression head
# (the model then uses a mean-squared-error loss).
num_labels = 1
model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=num_labels)

# No manual torch.nn.DataParallel / model.cuda() here: the Trainer moves the model
# to the GPU and wraps it in DataParallel itself when several GPUs are visible.

# Freeze every encoder layer except the last `n_trainable`.
n_trainable = 6
for layer in model.base_model.encoder.layer[: rep_layers - n_trainable]:
    for param in layer.parameters():
        param.requires_grad = False

# Count frozen and trainable encoder layers; a layer counts as frozen
# when none of its parameters require gradients.
encoder_layers = model.base_model.encoder.layer
n_frozen = sum(not any(p.requires_grad for p in layer.parameters()) for layer in encoder_layers)

print(f"Number of frozen layers: {n_frozen}")
print(f"Number of trainable layers: {len(encoder_layers) - n_frozen}")

These warnings are telling us that the model is discarding some weights that it used for language modelling (the `lm_head`) and adding some weights for sequence classification (the `classifier`). This is exactly what we expect when we want to fine-tune a language model on a sequence classification task!

Next, we initialize our `TrainingArguments`. These control the various training hyperparameters, and will be passed to our `Trainer`.

In [None]:
model_name = model_checkpoint.split("/")[-1]
batch_size = 64
num_epochs = 10
lr = 0.00001

args = TrainingArguments(
    f"{model_name}-finetuned-regression",
    evaluation_strategy="epoch",  # newer transformers releases call this argument eval_strategy
    save_strategy="epoch",
    learning_rate=lr,  # Consider experimenting with this based on model performance
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=num_epochs,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",  # More suitable for regression
    push_to_hub=False,
    logging_steps=50,  # Adjusted for more frequent logging, modify as needed
    # Optional settings based on your dataset and compute resources:
    # gradient_accumulation_steps=2,  # Use if effective batch size needs to be larger
    # lr_scheduler_type="linear",  # Linear scheduler can be effective with a warmup phase
    # warmup_steps=50,  # Number of warmup steps, adjust as needed
)


**weight_decay** shrinks the model's weights slightly at every optimizer step, which acts like a penalty on large weights. Because very large weights are often a sign of overfitting, a small amount of weight decay usually helps the model generalize.
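As a minimal sketch of what weight decay does to a single weight, the function below applies one plain gradient step with decoupled weight decay, the form used by AdamW-style optimizers. It leaves out Adam's moment estimates entirely, and the function name and the numbers are illustrative assumptions rather than what the `Trainer` computes internally.

```python
def step_with_decoupled_decay(weight, grad, lr, weight_decay):
    """One simplified update: shrink the weight towards zero, then take a gradient step."""
    weight = weight - lr * weight_decay * weight  # decay term, independent of the gradient
    return weight - lr * grad                     # usual gradient step


# With a zero gradient, only the decay acts: the weight shrinks by a factor
# of (1 - lr * weight_decay) at every step.
w = 1.0
for _ in range(1000):
    w = step_with_decoupled_decay(w, grad=0.0, lr=1e-5, weight_decay=0.01)
print(f"weight after 1000 decay-only steps: {w:.6f}")
```

With a learning rate of `1e-5` the per-step shrinkage is tiny, so the effect of `weight_decay=0.01` only becomes noticeable over many thousands of steps.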

Next, we define the metrics we will use to evaluate our model and write a `compute_metrics` function. Since this is a regression task, we use the mean squared error, mean absolute error and R² score from scikit-learn.

In [None]:
import numpy as np
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score


def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    # With num_labels=1 the model returns one value per sequence, with shape (n, 1).
    # Drop the trailing axis; argmax would always return 0 here.
    predictions = np.asarray(predictions).squeeze(-1)

    return {
        'mean_squared_error': mean_squared_error(labels, predictions),
        'mean_absolute_error': mean_absolute_error(labels, predictions),
        'r2_score': r2_score(labels, predictions),
    }
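For reference, the three regression metrics imported from scikit-learn above are simple to compute by hand. The sketch below uses only the standard library and made-up numbers; it follows the textbook definitions and is not scikit-learn's implementation.

```python
def regression_metrics(y_true, y_pred):
    """Mean squared error, mean absolute error and R^2 from their textbook definitions."""
    n = len(y_true)
    squared_errors = [(t - p) ** 2 for t, p in zip(y_true, y_pred)]
    mse = sum(squared_errors) / n
    mae = sum(abs(t - p) for t, p in zip(y_true, y_pred)) / n
    mean_true = sum(y_true) / n
    ss_tot = sum((t - mean_true) ** 2 for t in y_true)  # total variance around the mean
    r2 = 1.0 - sum(squared_errors) / ss_tot              # 1 = perfect, 0 = no better than the mean
    return {"mse": mse, "mae": mae, "r2": r2}


# Made-up targets and predictions: two exact hits and one miss of 0.5.
print(regression_metrics([0.0, 0.5, 1.0], [0.0, 0.5, 0.5]))
```

Here the single miss of 0.5 gives an MSE of 0.25 / 3, an MAE of 0.5 / 3 and an R² of 0.5.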


And at last we're ready to initialize our `Trainer`:

In [None]:
from transformers.trainer_callback import EarlyStoppingCallback


trainer = Trainer(
    model,
    args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3, early_stopping_threshold=0.001)]
)


**early_stopping_patience** is the number of consecutive evaluations without improvement that are tolerated before training stops. Since we evaluate once per epoch, training stops if the validation loss (or another specified metric) fails to improve for 3 epochs in a row.

**early_stopping_threshold** is the minimum improvement that counts: the evaluation metric must improve by more than 0.001 (e.g., the validation loss must drop by more than 0.001) for an evaluation to be considered "better."
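The interaction between patience and threshold can be sketched in a few lines of plain Python. This is a simplified illustration of the logic for a lower-is-better metric such as the validation loss, not the library's implementation; the function name and the loss values are made up.

```python
def epoch_of_early_stop(eval_losses, patience=3, threshold=0.001):
    """Return the 1-based epoch at which training would stop, or None if it runs to the end."""
    best = None
    evals_without_improvement = 0
    for epoch, loss in enumerate(eval_losses, start=1):
        improved = best is None or (loss < best and best - loss > threshold)
        if improved:
            evals_without_improvement = 0
        else:
            evals_without_improvement += 1
        best = loss if best is None else min(best, loss)  # track the best loss seen so far
        if evals_without_improvement >= patience:
            return epoch
    return None


# The loss keeps falling after epoch 2, but by less than the threshold each time.
losses = [0.50, 0.40, 0.3995, 0.3992, 0.3991, 0.30]
print(epoch_of_early_stop(losses))  # 5
```

Training stops at epoch 5, so the large improvement at epoch 6 is never reached. A larger patience or a smaller threshold trades longer training for a lower risk of stopping too early.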

We can now finetune our model by just calling the `train` method:

In [None]:
trainer.train()

## Model evaluation

In [None]:
res = pd.read_table('fine_tuning/fine_tunning_esm2/results_finetunning.txt').head(4)
sns.lineplot(x=res['Epoch'], y=res['Training Loss'], label='Training loss')
sns.lineplot(x=res['Epoch'], y=res['Validation Loss'], label='Validation loss')

In [None]:
pg_val = pd.read_excel('/stor/work/Wilke/luiz/DMS_ML_AMP/data/pg1_muts_validation_set.xlsx', usecols=['ID', 'Sequence', 'MIC MH', '%hemo'])
pg_val.replace('>', '', regex=True, inplace=True)  # strip '>' from censored MIC values
pg_val['label'] = [1 if x <= 16 else 0 for x in pg_val['MIC MH'].astype(float)]  # 1 = active (MIC <= 16)

test_seqs = pg_val['Sequence'].to_list()
true_labels = pg_val['label'].to_list()
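The labelling rule in the cell above can be checked on a few hand-written values. The sketch below mirrors the two steps (strip the `>` from censored MIC readings, then threshold at 16) in plain Python; the example values and the function name are invented for illustration.

```python
def mic_to_label(mic_value, cutoff=16.0):
    """Return 1 (active) if the MIC is at or below the cutoff, else 0 (non-active).
    Censored readings such as '>64' are treated as the reported limit."""
    mic = float(str(mic_value).replace(">", ""))
    return 1 if mic <= cutoff else 0


# Invented example readings, including one censored value.
for raw in ["8", "16", "32", ">64"]:
    print(f"MIC {raw:>4} -> label {mic_to_label(raw)}")
```

Treating `>64` as 64 is harmless here because the reported limit is already above the cutoff. It would only matter if a censored value fell at or below 16.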

In [None]:
test_tok = tokenizer(test_seqs)
test_ = Dataset.from_dict(test_tok)
test_ = test_.add_column("labels", true_labels)

In [None]:
trainer.evaluate(test_)

In [None]:
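predictions = trainer.predict(test_)
# argmax assumes a classification head with one logit per class (e.g. num_labels=2);
# with the single-output regression head loaded above it would always return 0.
predicted_classes = np.argmax(predictions.predictions, axis=1)
accuracy = np.mean(predicted_classes == np.asarray(true_labels))
print(f"Accuracy: {accuracy * 100:.2f}%")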

In [None]:
from sklearn import metrics
import seaborn as sns
import matplotlib.pyplot as plt

conf_m = metrics.confusion_matrix(true_labels, predicted_classes)

group_names = ['True Neg', 'False Pos', 'False Neg', 'True Pos']
group_counts = ["{0:0.0f}".format(value) for value in conf_m.flatten()]
group_percentages = ["{0:.2%}".format(value) for value in conf_m.flatten() / np.sum(conf_m)]

# Use a separate name so the training `labels` list defined earlier is not overwritten
annot_labels = [f"{v1}\n{v2}\n{v3}" for v1, v2, v3 in zip(group_names, group_counts, group_percentages)]
annot_labels = np.asarray(annot_labels).reshape(2, 2)

class_names = ['Non_active', 'Active']
sns.heatmap(conf_m, annot=annot_labels, fmt='', cmap="YlGnBu", yticklabels=class_names, xticklabels=class_names)
plt.title('Confusion Matrix', fontsize=16)
plt.ylabel('Actual label', size=14)
plt.xlabel('Predicted label', size=14)
plt.tight_layout()
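The order of `group_names` in the cell above relies on the layout of a binary confusion matrix: rows are the actual classes, columns are the predicted classes, so flattening gives true negatives, false positives, false negatives and true positives. A minimal pure-Python sketch with made-up labels shows where each count comes from.

```python
def binary_confusion_matrix(y_true, y_pred):
    """Return [[TN, FP], [FN, TP]]: rows are actual classes, columns are predicted classes."""
    matrix = [[0, 0], [0, 0]]
    for actual, predicted in zip(y_true, y_pred):
        matrix[actual][predicted] += 1
    return matrix


# Made-up labels: 2 non-active and 3 active peptides.
actual = [0, 0, 1, 1, 1]
predicted = [0, 1, 1, 1, 0]

cm = binary_confusion_matrix(actual, predicted)
accuracy = (cm[0][0] + cm[1][1]) / len(actual)  # correct predictions lie on the diagonal
print(cm, accuracy)  # [[1, 1], [1, 2]] 0.6
```

Accuracy is the sum of the diagonal divided by the number of examples, which is the same quantity computed from `predicted_classes == true_labels` earlier.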