# Protein Language Models Part 1
### Learning objectives: 
- Load a pre-trained pLM
- Investigate the internal representation of tokens

Optional
- Fine-Tune a pLM

------------------------------------------------
## 1. Investigate protein representation in the pLM ProtBERT
(Adapted from [DeepChem Tutorials](https://github.com/deepchem/deepchem/blob/master/examples/tutorials/ProteinLM_Tutorial0.ipynb), check the original notebook for more information.)

In [None]:
# libraries
from transformers import BertForMaskedLM, BertTokenizer, pipeline
import torch.nn.functional as F
import seaborn as sns

### Proteins of interest to be investigated
Hemoglobin is the protein in red blood cells that transports oxygen from the lungs to the rest of the body. It is a good test case for probing how protein language models behave, because it is highly conserved across species in some regions and variable in others. What would we expect the predicted distribution over amino acids to look like if we mask a highly conserved position? What about a highly variable one? Let's find out.

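Hemoglobin beta sequence homology across closely related mammals (from [Ali et al.](https://www.nature.com/articles/s41598-019-50619-w)). The `*` characters are alignment gaps, which keep residue positions comparable across species: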

In [None]:
hemoglobin_beta = {
'human':
"MVHLTPEEKSAVTALWGKVNVDEVGGEALGRLLVVYPWTQRFFESFGDLSTPDAVMGNPKVKAHGKKVLGAFSDGLAHLDNLKGTFATLSELHCDKLHVDPENFRLLGNVLVCVLAHHFGKEFTPPVQAAYQKVVAGVANALAHKYH",
'chimpanzee':
"MVHLTPEEKSAVTALWGKVNVDEVGGEALGRLLVVYPWTQRFFESFGDLSTPDAVMGNPKVKAHGKKVLGAFSDGLAHLDNLKGTFATLSELHCDKLHVDPENFRLLGNVLVCVLAHHFGKEFTPPVQAAYQKVVAGVANALAHKYH",
'camel':
"MVHLSGDEKNAVHGLWSKVKVDEVGGEALGRLLVVYPWTRRFFESFGDLSTADAVMNNPKVKAHGSKVLNSFGDGLNHLDNLKGTYAKLSELHCDKLHVDPENFRLLGNVLVVVLARHFGKEFTPDKQAAYQKVVAGVANALAHRYH",
'rabbit':
"MVHLSSEEKSAVTALWGKVNVEEVGGEALGRLLVVYPWTQRFFESFGDLSSANAVMNNPKVKAHGKKVLAAFSEGLSHLDNLKGTFAKLSELHCDKLHVDPENFRLLGNVLVIVLSHHFGKEFTPQVQAAYQKVVAGVANALAHKYH",
'pig':
"MVHLSAEEKEAVLGLWGKVNVDEVGGEALGRLLVVYPWTQRFFESFGDLSNADAVMGNPKVKAHGKKVLQSFSDGLKHLDNLKGTFAKLSELHCDQLHVDPENFRLLGNVIVVVLARRLGHDFNPNVQAAFQKVVAGVANALAHKYH",
'horse':
"*VQLSGEEKAAVLALWDKVNEEEVGGEALGRLLVVYPWTQRFFDSFGDLSNPGAVMGNPKVKAHGKKVLHSFGEGVHHLDNLKGTFAALSELHCDKLHVDPENFRLLGNVLVVVLARHFGKDFTPELQASYQKVVAGVANALAHKYH",
'bovine':
"M**LTAEEKAAVTAFWGKVKVDEVGGEALGRLLVVYPWTQRFFESFGDLSTADAVMNNPKVKAHGKKVLDSFSNGMKHLDDLKGTFAALSELHCDKLHVDPENFKLLGNVLVVVLARNFGKEFTPVLQADFQKVVAGVANALAHRYH",
'sheep':
"M**LTAEEKAAVTGFWGKVKVDEVGAEALGRLLVVYPWTQRFFEHFGDLSNADAVMNNPKVKAHGKKVLDSFSNGMKHLDDLKGTFAQLSELHCDKLHVDPENFRLLGNVLVVVLARHHGNEFTPVLQADFQKVVAGVANALAHKYH"
}
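
Before asking the model anything, you can score each alignment column by how many species share its most common residue. The sketch below is a minimal, hypothetical helper (`column_conservation` is not part of any library). It treats the gap character `*` like any other symbol, which is a simplifying assumption. The toy alignment is made up.

```python
from collections import Counter

def column_conservation(aligned_seqs):
    """For each alignment column, return the fraction of sequences carrying
    the most common symbol in that column (1.0 = fully conserved).

    Assumption: gap characters such as '*' are counted like residues.
    """
    seqs = list(aligned_seqs)
    length = len(seqs[0])
    if any(len(s) != length for s in seqs):
        raise ValueError("all sequences must be aligned to the same length")
    scores = []
    for i in range(length):
        column = Counter(s[i] for s in seqs)
        top_count = column.most_common(1)[0][1]
        scores.append(top_count / len(seqs))
    return scores

# Toy alignment (made-up, not real hemoglobin data)
toy_alignment = {"a": "MVHLT", "b": "MVHLS", "c": "*VQLS"}
print(column_conservation(toy_alignment.values()))
```

With the dictionary above, `column_conservation(hemoglobin_beta.values())` returns one score per column. Index 92 (histidine in every listed species) can then be compared with index 87, the more variable position examined later.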

### Load the model
[ProtBERT](https://arxiv.org/abs/2007.06225) is a protein language model based on the BERT architecture.
Load the [pre-trained UniRef100 ProtBERT model](https://huggingface.co/Rostlab/prot_bert) together with its tokeniser.

In [None]:
# ProtBERT's vocabulary uses upper-case residue letters, so do not lower-case the input
tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
model = BertForMaskedLM.from_pretrained("Rostlab/prot_bert", weights_only=True)
model

### See how the model recovers masked positions in Hemoglobin

In [None]:
# mask the F8 (proximal) histidine of the hemoglobin beta subunit: 0-based index 92, counting the initial Met
human_heme = list(hemoglobin_beta['human'])
human_heme[92] = "[MASK]"
masked_heme = ' '.join(human_heme)
print(masked_heme)
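
The cell above joins residues with single spaces because the ProtBERT tokenizer treats each whitespace-separated token as one residue. The Rostlab model card also maps the rare amino acids U, Z, O and B to X before tokenising. A small helper that wraps both steps might look like this. `prepare_masked_sequence` is a hypothetical name, not a transformers function.

```python
import re

def prepare_masked_sequence(seq, mask_positions, mask_token="[MASK]"):
    """Return a space-separated sequence ready for the ProtBERT tokenizer.

    - rare residues U, Z, O, B are replaced by X (as on the Rostlab model card)
    - residues at the given 0-based positions are replaced by mask_token
    """
    residues = list(re.sub(r"[UZOB]", "X", seq.upper()))
    for pos in mask_positions:
        if not 0 <= pos < len(residues):
            raise IndexError(f"mask position {pos} is outside the sequence")
        residues[pos] = mask_token
    return " ".join(residues)

print(prepare_masked_sequence("MVHLTPEEK", [2]))
```

For the human sequence, which contains none of U, Z, O or B, `prepare_masked_sequence(hemoglobin_beta['human'], [92])` builds the same string as `masked_heme`.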

In [None]:
# tokenise the masked sequence (the tokenizer adds [CLS] at the start and [SEP] at the end)
tokenized_sequence = tokenizer(masked_heme, return_tensors='pt')
tokenized_sequence

In [None]:
# pass the tokenised sequence through the model to get per-position logits
model_outs = model(**tokenized_sequence)
model_outs

In [None]:
# transform the logits
logits = model_outs.logits.squeeze()[1:-1]  # drop the [CLS] and [SEP] special-token positions
print(logits.shape)  # (sequence length, vocabulary size)
softmaxed = F.softmax(logits, dim=1).detach().numpy()  # softmax over the vocabulary: each position's probabilities sum to 1
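
To make the transformation in the cell above explicit, here is the same computation in plain Python on a made-up 4×3 logit matrix. It drops the first and last rows (the `[CLS]` and `[SEP]` positions), then applies a numerically stable softmax to each remaining row, which is what `F.softmax(logits, dim=1)` does per position. The values are illustrative only.

```python
import math

def softmax_rows(rows):
    """Softmax each row independently, subtracting the row max for numerical stability."""
    probs = []
    for row in rows:
        m = max(row)
        exps = [math.exp(x - m) for x in row]
        total = sum(exps)
        probs.append([e / total for e in exps])
    return probs

# Made-up logits: 4 token positions x 3 vocabulary entries
toy_logits = [
    [9.0, 0.0, 0.0],   # [CLS] position (dropped)
    [2.0, 1.0, 0.1],   # residue 0
    [0.5, 0.5, 3.0],   # residue 1
    [0.0, 0.0, 9.0],   # [SEP] position (dropped)
]
residue_logits = toy_logits[1:-1]      # same slicing as logits.squeeze()[1:-1]
probs = softmax_rows(residue_logits)
for row in probs:
    print([round(p, 3) for p in row], "sum =", round(sum(row), 6))
```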

In [None]:
# decode the logits using greedy decoding (highest-probability token at each position)
decoded_outputs = tokenizer.batch_decode(softmaxed.argmax(axis=1))
decoded_sequence = ''.join(decoded_outputs)
print(decoded_sequence)
print(f'The filled-in masked position is: {decoded_outputs[92]}')
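
Greedy decoding takes the highest-probability vocabulary entry at each position, independently of the other positions. The plain-Python sketch below uses a toy three-token vocabulary, not ProtBERT's real one. It also keeps the winning probability, which the cell above discards.

```python
def greedy_decode(prob_rows, vocab):
    """Return (decoded string, per-position top probability) using argmax at each position."""
    tokens, confidences = [], []
    for row in prob_rows:
        best = max(range(len(row)), key=row.__getitem__)  # index of the largest probability
        tokens.append(vocab[best])
        confidences.append(row[best])
    return "".join(tokens), confidences

toy_vocab = ["A", "H", "K"]           # hypothetical 3-token vocabulary
toy_probs = [[0.1, 0.8, 0.1],
             [0.5, 0.2, 0.3],
             [0.2, 0.2, 0.6]]
decoded, conf = greedy_decode(toy_probs, toy_vocab)
print(decoded, conf)
```

Joining tokens into one string assumes every predicted token is a single residue letter. If a multi-character special token such as `[UNK]` were predicted, string indices would shift, so indexing the per-position token list is the more robust choice.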

**Sanity check:** It looks like ProtBERT recovers the correct amino acid (histidine) at that position. But how confident was the model? Let's visualise the distribution at that position and see which other amino acids the model was choosing between.


In [None]:
# visualise the token distribution at the F8 histidine
import matplotlib.pyplot as plt

plt.bar(list(tokenizer.get_vocab().keys()), softmaxed[92])
plt.ylabel('Normalized Probability')
plt.xlabel('Model Vocabulary')
plt.title('Target Distribution at the F8 Histidine')
plt.xticks(rotation='vertical')
plt.show()
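
Besides plotting, it can help to list the few most probable tokens at a position. The minimal sketch below assumes the probabilities for one position and the vocabulary tokens are given as parallel sequences. In this notebook those would be `softmaxed[92]` and `list(tokenizer.get_vocab().keys())`. The helper name and toy values are hypothetical.

```python
def top_k_tokens(probs, vocab, k=5):
    """Return the k (token, probability) pairs with the highest probability, best first."""
    if len(probs) != len(vocab):
        raise ValueError("probs and vocab must have the same length")
    ranked = sorted(zip(vocab, probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]

# Hypothetical distribution over a 5-token toy vocabulary
toy_vocab = ["A", "H", "K", "Q", "Y"]
toy_probs = [0.05, 0.70, 0.10, 0.05, 0.10]
for token, p in top_k_tokens(toy_probs, toy_vocab, k=3):
    print(f"{token}: {p:.2f}")
```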

In [None]:
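# visualise the probability map across all positions (rows: residues, columns: vocabulary tokens)
plt.figure(figsize=(10, 16))
sns.heatmap(softmaxed, xticklabels=list(tokenizer.get_vocab().keys()))
plt.xlabel('Model Vocabulary')
plt.ylabel('Sequence Position')
plt.show()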

In [None]:
# look at a low-confidence position

plt.bar(list(tokenizer.get_vocab().keys()), softmaxed[87])
plt.ylabel('Normalized Probability')
plt.xlabel('Model Vocabulary')
plt.title('Target Distribution at Position 87')
plt.xticks(rotation='vertical')
plt.show()
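
You can also find low-confidence positions without reading them off the heatmap. One option is to rank positions by the Shannon entropy of their predicted distribution. Entropy is 0 bits when the model puts all its mass on one token and grows as the mass spreads over several tokens. The plain-Python sketch below runs on toy data, and the `position_entropy` helper is hypothetical.

```python
import math

def position_entropy(prob_rows):
    """Shannon entropy (in bits) of each position's probability distribution."""
    entropies = []
    for row in prob_rows:
        # terms with p == 0 are skipped, i.e. 0 * log 0 is taken as 0
        h = sum(-p * math.log2(p) for p in row if p > 0)
        entropies.append(h)
    return entropies

toy_probs = [
    [1.0, 0.0, 0.0, 0.0],      # fully confident
    [0.25, 0.25, 0.25, 0.25],  # maximally uncertain over 4 tokens
    [0.7, 0.1, 0.1, 0.1],
]
ent = position_entropy(toy_probs)
least_confident = sorted(range(len(ent)), key=ent.__getitem__, reverse=True)
print(ent, least_confident)
```

Applied to this notebook, `position_entropy(softmaxed)` gives one value per residue. The highest-entropy indices are candidates for regions like position 87.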

In [None]:
for animal in hemoglobin_beta:
    print(f'{animal} has residue {hemoglobin_beta[animal][87]} at position 87')
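
To connect the model's uncertainty at position 87 with the variation printed above, you can ask how much probability ProtBERT assigns to residues that actually occur at that column in the listed mammals. The hypothetical helper below assumes the distribution is given as a token-to-probability mapping. The example values are made up.

```python
def mass_on_observed(token_probs, observed):
    """Sum the probability assigned to tokens in `observed` (e.g. residues seen across species)."""
    return sum(token_probs.get(tok, 0.0) for tok in set(observed))

# Made-up distribution and made-up observed residues, for illustration only
toy_probs = {"T": 0.40, "K": 0.25, "A": 0.20, "S": 0.15}
observed = ["T", "T", "K", "A"]     # duplicates are ignored via set()
print(mass_on_observed(toy_probs, observed))
```

In the notebook, the mapping could be built as `dict(zip(tokenizer.get_vocab().keys(), softmaxed[87]))` and `observed` as `[seq[87] for seq in hemoglobin_beta.values()]`.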

------------------------------------------------
## Optional: running the second part requires Python < 3.12 and the latest DeepChem version. Skip it if setting up the environment is too tedious.
## 2. Fine-Tune ProtBERT for water solubility
(Adapted from [DeepChem Tutorials](https://github.com/deepchem/deepchem/blob/master/examples/tutorials/Introduction_to_ProtBERT.ipynb))

Use the [DeepLoc](https://academic.oup.com/bioinformatics/article/33/21/3387/3931857?login=false) dataset for fine-tuning. Its `water soluble` column provides the binary label to predict.

In [None]:
# libraries
import deepchem as dc
import pandas as pd
from deepchem.models.torch_models import ProtBERT
import torch.nn as nn
import matplotlib.pyplot as plt
import os
import shutil
from urllib.request import urlopen
from rich.progress import Progress, TransferSpeedColumn

In [None]:
## datasets for fine-tuning
URL_test = "https://deepchemdata.s3.us-west-1.amazonaws.com/datasets/DeepLoc_test.csv"
URL_train = "https://deepchemdata.s3.us-west-1.amazonaws.com/datasets/DeepLoc_train.csv"
out_test = "./datasets/DeepLoc_test.csv"
out_train = "./datasets/DeepLoc_train.csv"
os.makedirs("datasets", exist_ok=True)

def download(url, out_path):
    """Stream `url` to `out_path`, showing a progress bar with the transfer speed."""
    print(url)
    with Progress(*Progress.get_default_columns(), TransferSpeedColumn()) as progress:
        with urlopen(url) as res:
            length = int(res.headers["Content-Length"])
            with progress.wrap_file(res, total=length) as src, open(out_path, "wb") as dest:
                shutil.copyfileobj(src, dest)

download(URL_test, out_test)
download(URL_train, out_train)

In [None]:
# For demo purposes, keep only proteins shorter than 200 residues and draw a random subset
# (random_state fixes the draw so that the subset is reproducible)
train_df = pd.read_csv(out_train)
string_lengths = train_df["protein"].apply(len)
filtered_train_df = train_df[string_lengths < 200].sample(5000, random_state=0)
filtered_train_df.to_csv("./datasets/DeepLoc_train_5000.csv", index=False)

test_df = pd.read_csv(out_test)
string_lengths = test_df["protein"].apply(len)
filtered_test_df = test_df[string_lengths < 200].sample(1000, random_state=0)
filtered_test_df.to_csv("./datasets/DeepLoc_test_1000.csv", index=False)

filtered_train_df.head()
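
`DataFrame.sample` draws a different subset on each run unless `random_state` is set, and a random subset can shift the class balance. The plain-Python sketch below mirrors the filter-then-sample step with a fixed seed and then reports label fractions. The toy rows and the `subsample`/`class_balance` helpers are illustrative assumptions. In pandas, the balance check would be `filtered_train_df["water soluble"].value_counts(normalize=True)`.

```python
import random
from collections import Counter

def subsample(rows, max_len, n, seed=0):
    """Keep rows whose 'protein' string is shorter than max_len, then draw n of them reproducibly."""
    short = [r for r in rows if len(r["protein"]) < max_len]
    if n > len(short):
        raise ValueError(f"only {len(short)} rows pass the length filter, cannot sample {n}")
    return random.Random(seed).sample(short, n)

def class_balance(rows, label="water soluble"):
    """Fraction of rows carrying each label value."""
    counts = Counter(r[label] for r in rows)
    total = sum(counts.values())
    return {value: count / total for value, count in counts.items()}

# Toy rows standing in for DeepLoc records (sequences and labels are made up)
rows = [
    {"protein": "MKT" * 10, "water soluble": 1},
    {"protein": "MLL" * 100, "water soluble": 0},   # 300 residues: filtered out
    {"protein": "MAV" * 20, "water soluble": 0},
    {"protein": "MGS" * 15, "water soluble": 1},
]
subset = subsample(rows, max_len=200, n=2, seed=42)
print(len(subset), class_balance(subset))
```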

In [None]:
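# featurise the protein sequences: DummyFeaturizer passes the raw strings through unchanged, leaving tokenisation to the ProtBERT model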
featurizer = dc.feat.DummyFeaturizer()
tasks = ["water soluble"]
loader = dc.data.CSVLoader(tasks=tasks,
                            feature_field="protein",
                            featurizer=featurizer)
deeploc_train_dataset = loader.create_dataset("./datasets/DeepLoc_train_5000.csv")
deeploc_test_dataset  = loader.create_dataset("./datasets/DeepLoc_test_1000.csv")

### Load the model
Load [ProtBERT](https://academic.oup.com/bioinformatics/article/38/8/2102/6502274?login=false) from the [pre-trained UniRef100 model](https://huggingface.co/Rostlab/prot_bert), this time through DeepChem's `ProtBERT` wrapper so that a classification head can be attached.

In [None]:
# directory for fine-tuning checkpoints
finetune_model_dir = "finetuning/"

# Custom classification head on top of ProtBERT's 1024-dimensional embedding
custom_network = nn.Sequential(nn.Linear(1024, 512),
                               nn.ReLU(), nn.Linear(512, 256),
                               nn.ReLU(), nn.Linear(256, 2))

# ProtBERT model that can be fine-tuned for a downstream task
ProtBERTmodel_for_classification = ProtBERT(task='classification',
                                            model_path="Rostlab/prot_bert",
                                            n_tasks=1,
                                            cls_name="custom",
                                            classifier_net=custom_network,
                                            n_classes=2,
                                            model_dir=finetune_model_dir,
                                            batch_size=32,
                                            learning_rate=1e-5,
                                            log_frequency=5)
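
Only the classification head is trained below, because the ProtBERT encoder is frozen, so it is worth knowing how small that head is. Each `nn.Linear(in, out)` layer has `in * out` weights plus `out` biases, and `nn.ReLU` adds no parameters. The arithmetic for the layer widths used above, as a plain-Python sketch with a hypothetical helper:

```python
def linear_stack_params(layer_sizes):
    """Number of trainable parameters in a chain of fully connected layers:
    in * out weights plus out biases per layer (activations add none)."""
    return sum(n_in * n_out + n_out for n_in, n_out in zip(layer_sizes, layer_sizes[1:]))

head_sizes = [1024, 512, 256, 2]   # the layer widths of custom_network
print(linear_stack_params(head_sizes))  # 656642
```

In PyTorch, the same count is `sum(p.numel() for p in custom_network.parameters())`.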

### Fine-tune the loaded model

In [None]:
# Freeze the underlying ProtBERT encoder and only train the classifier head
for param in ProtBERTmodel_for_classification.model.bert.parameters():
    param.requires_grad = False

# track the loss: DeepChem appends one logged loss every log_frequency (=5) training steps
all_losses = []
loss = ProtBERTmodel_for_classification.fit(deeploc_train_dataset, nb_epoch=1, all_losses=all_losses)

# Plot training loss; the logged values correspond to steps 5, 10, 15, ...
batches = list(range(5, 5 * (len(all_losses) + 1), 5))
plt.plot(batches, all_losses, linestyle='-', color='b')
plt.title('Training Loss over Batches')
plt.xlabel('Training Step')
plt.ylabel('Training Loss')
plt.grid(True)
plt.show()
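
The x-axis above assumes one loss value is appended to `all_losses` every `log_frequency` (here 5) training steps. The logged losses are noisy, so a short moving average often makes the trend easier to read. A plain-Python sketch with hypothetical helper names and made-up loss values:

```python
def logged_steps(n_logged, log_frequency):
    """Training steps at which each logged loss was recorded: log_frequency, 2*log_frequency, ..."""
    return [log_frequency * (i + 1) for i in range(n_logged)]

def moving_average(values, window):
    """Simple trailing moving average; returns len(values) - window + 1 points."""
    if window < 1 or window > len(values):
        raise ValueError("window must be between 1 and len(values)")
    return [sum(values[i:i + window]) / window for i in range(len(values) - window + 1)]

toy_losses = [0.70, 0.69, 0.66, 0.64, 0.61, 0.60]   # made-up values
steps = logged_steps(len(toy_losses), log_frequency=5)
smoothed = moving_average(toy_losses, window=3)
print(steps)
print(smoothed)
```

To overlay the smoothed curve, pair it with `steps[window - 1:]` so that each average is plotted at the last step it covers.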

### Evaluate the model
Use DeepChem's metrics (e.g. accuracy) to evaluate the fine-tuned model on the held-out test subset.

In [None]:
classification_metric = dc.metrics.Metric(dc.metrics.accuracy_score)
eval_score = ProtBERTmodel_for_classification.evaluate(deeploc_test_dataset, [classification_metric], n_classes=2)
eval_score
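
Accuracy is easiest to interpret next to a trivial baseline: if one class dominates the test subset, always predicting that class already scores well. The plain-Python sketch below assumes predictions come as per-sample class probabilities and labels as integers. The helper names and toy values are hypothetical.

```python
from collections import Counter

def accuracy(prob_rows, labels):
    """Fraction of samples whose argmax class matches the integer label."""
    preds = [max(range(len(row)), key=row.__getitem__) for row in prob_rows]
    return sum(p == y for p, y in zip(preds, labels)) / len(labels)

def majority_baseline(labels):
    """Accuracy of always predicting the most frequent label."""
    return Counter(labels).most_common(1)[0][1] / len(labels)

toy_probs = [[0.2, 0.8], [0.9, 0.1], [0.4, 0.6], [0.3, 0.7]]   # made-up predictions
toy_labels = [1, 0, 0, 1]
print(accuracy(toy_probs, toy_labels), majority_baseline(toy_labels))
```

For the real test subset, the baseline could be computed as `majority_baseline(deeploc_test_dataset.y.flatten().tolist())` and compared with `eval_score`.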