# Protein Language Models Part 1
### Learning objectives: 
- Load a pre-trained pLM
- Investigate the internal representation of tokens
- Fine-Tune a pLM

## 0. Background
### 0.1 Introduction
This DeepChem tutorial is an introductory primer on protein language models, a powerful and versatile method of processing protein sequence information inspired by methods from natural language processing. Over the past decade, natural language processing has shown the strength of using learned representations to encapsulate the semantic meaning of text data. Notable models like word2vec [[1]](https://arxiv.org/abs/1301.3781) and GloVe [[2]](https://aclanthology.org/D14-1162/) showed that self-supervised pre-training on large, unlabeled corpora creates robust feature embeddings that preserve similarity and analogy relationships between words. However, these models were limited in utility by their context-free embeddings. The advent of context-aware models, starting with BERT [[3]](https://aclanthology.org/N19-1423/), led to numerous sequence models applicable beyond language domains. In biology, self-supervised pre-training of protein language models has achieved state-of-the-art performance in various tasks by deriving context-aware amino acid embeddings that can be fine-tuned to capture information on the structure [[4]](https://www.biorxiv.org/content/10.1101/2020.12.15.422761v1) and function [[5]](https://www.biorxiv.org/content/early/2023/08/24/2023.08.23.554486.full.pdf) of proteins.

This tutorial aims to provide an overview of the concepts and intuition needed to work with protein language models and to understand their inputs and outputs, strengths, and failure modes. We skip the detailed breakdown of their architecture, but invite the community to build upon this tutorial by adding content via pull request.

**Disclaimer**: For brevity's sake, we assume familiarity with the multi-layer perceptron, neural networks, and learning by gradient descent. Additionally, we assume some fluency with probability theory on matters such as discrete vs. continuous distributions, likelihood, and conditional distributions. Wherever necessary, we link non-obvious topics and concepts to vetted, beginner-friendly external sources as a starting point for the more complicated topics. Follow along for a high-level overview of why protein language models have been so successful across a broad range of tasks.

### 0.2 What is a language model?
Under the hood, all language models are nothing more than probability distributions over tokens, or discrete sub-sequences. In natural language, a very intuitive choice of tokens is the words of a language, or perhaps even the characters. Both have their own pros and cons. For simplicity, let's work with words as tokens here, though this changes for proteins. Since the learned distribution is over discrete units (words), this distribution is a [categorical distribution](https://en.wikipedia.org/wiki/Categorical_distribution), not a [continuous](https://en.wikipedia.org/wiki/Probability_distribution) one. To make this more concrete, take for example a common language model that you have likely interacted with many times in your life: text auto-complete. Text auto-complete is a conditional language model that takes the previous words you have written, computes the [conditional probability](http://www.stat.yale.edu/Courses/1997-98/101/condprob.htm) of every word in its vocabulary, and returns the highest-probability words given that context. If you'd like a very intuitive and fine-grained explanation of both the form and function of language models, the 3Blue1Brown [walk-through](https://www.3blue1brown.com/lessons/gpt) is a great resource that breaks down the basics of the architecture, the flow of information, and the process of training a specific LLM (GPT). In this section we skip over the architecture of language models and instead treat them as a black box, focusing on how language models learn from sequences to better motivate their use in the protein domain.

A simple way to visualize what a language model is doing in the background is to think of it as updating and indexing a huge square matrix of [transition probabilities](https://en.wikipedia.org/wiki/Stochastic_matrix) of size $D \times D$, where $D$ is the vocabulary size of the model. Here vocabulary size refers to the number of unique words or sub-words that make up the state space of the categorical distribution. So a model that only knows the words ['a', 'boy', 'cute', 'is', 'student', 'the', 'walking'] has a vocabulary size of 7. We also introduce a special token to designate the end of a sequence (EOS), which gives 8 possible next tokens. If we start off with an untrained model whose transition probabilities follow a [uniform](https://www.investopedia.com/terms/u/uniform-distribution.asp) initialization, every entry is $1/8 = 0.125$ and the transition matrix looks something like this:

|         | a    | boy  | cute | is   |student| the     | walking | EOS |
|---------|------|------|------|------|-------|---------|---------|-----|
| a       | 0.125| 0.125| 0.125| 0.125| 0.125 | 0.125   | 0.125   |0.125|
| boy     | 0.125| 0.125| 0.125| 0.125| 0.125 | 0.125   | 0.125   |0.125|
| cute    | 0.125| 0.125| 0.125| 0.125| 0.125 | 0.125   | 0.125   |0.125|
| is      | 0.125| 0.125| 0.125| 0.125| 0.125 | 0.125   | 0.125   |0.125|
| student | 0.125| 0.125| 0.125| 0.125| 0.125 | 0.125   | 0.125   |0.125|
| the     | 0.125| 0.125| 0.125| 0.125| 0.125 | 0.125   | 0.125   |0.125|
| walking | 0.125| 0.125| 0.125| 0.125| 0.125 | 0.125   | 0.125   |0.125|


However, if we look at some of the transition probabilities, we can immediately see that the model is not very good. For example, the probability of the word 'a' coming after 'a' should be close to 0. Same goes for the word 'the' coming after 'a'. It's pretty clear that we need some way of training this model so that we can get some realistic transition probabilities. 

### 0.3 Methods for learning language
#### 0.3.1 Causal learning
The first language models were trained on the principle of [causal language modeling](https://huggingface.co/docs/transformers/en/tasks/language_modeling), where the model is tasked with next word prediction during each training step.

$$\text{The quick brown fox jumped over the lazy -----.}$$
$$ P(x_t|x_{<t}) = ? $$
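
Written out, the causal objective factorizes the probability of a whole sequence into a product of next-token terms (the chain rule of probability), and training minimizes the negative log-likelihood of the observed tokens. The sequence length $T$ and the model parameters $\theta$ are symbols introduced here for illustration only:

$$ P(x_1, \dots, x_T) = \prod_{t=1}^{T} P(x_t|x_{<t}) $$
$$ \mathcal{L}_{\text{causal}}(\theta) = -\sum_{t=1}^{T} \log P_{\theta}(x_t|x_{<t}) $$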

After enough rounds of this training protocol the model learns a much more plausible distribution over the words - something that looks like the following:

|         | a    | boy  | cute | is   | student | the  | walking | EOS |
|---------|------|------|------|------|---------|------|---------|-----|
| a       | 0.0  | 0.5  | 0.1  | 0.05 | 0.25    | 0.05 | 0.05    | 0.0 |
| boy     | 0.15 | 0.0  | 0.1  | 0.4  | 0.05    | 0.15 | 0.05    | 0.1 |
| cute    | 0.05 | 0.2  | 0.0  | 0.1  | 0.25    | 0.1  | 0.0     | 0.3 |
| is      | 0.2  | 0.0  | 0.3  | 0.0  | 0.0     | 0.2  | 0.3     | 0.0 |
| student | 0.15 | 0.05 | 0.1  | 0.5  | 0.0     | 0.05 | 0.05    | 0.1 |
| the     | 0.0  | 0.5  | 0.2  | 0.0  | 0.25    | 0.0  | 0.05    | 0.0 |
| walking | 0.1  | 0.0  | 0.05 | 0.2  | 0.05    | 0.3  | 0.0     | 0.3 |
| EOS     | 0.0  | 0.0  | 0.0  | 0.0  | 0.0     | 0.0  | 0.0     | 1.0 |

Here we can see that the model has learned that none of these words is typically repeated twice in a row. It also assigns higher probability to the subject words ['boy', 'student'] after the word 'the' than to the verbs ['is', 'walking']. If we start at 'the' and follow a high-probability transition at each step, we can generate the following sentence as a path through the model: 'the' -> 'boy' -> 'is' -> 'walking' -> 'EOS'. This mode of choosing a word at every time step and then conditioning on the previously chosen words is known as auto-regressive generation.
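
To make the idea of a learned transition table concrete, here is a minimal sketch that swaps gradient-based training for plain bigram counting. The five-sentence corpus is made up for illustration, and the names `corpus`, `transitions` and `generate` are hypothetical helpers, not part of any library.

```python
from collections import Counter, defaultdict

EOS = "EOS"
# Made-up toy corpus (an assumption for illustration, not data from this tutorial).
corpus = [
    "the boy is walking",
    "the boy is cute",
    "the student is walking",
    "a student is cute",
    "a boy is walking",
]

# Count how often each word follows another (bigram counts).
counts = defaultdict(Counter)
for sentence in corpus:
    words = sentence.split() + [EOS]
    for prev, nxt in zip(words, words[1:]):
        counts[prev][nxt] += 1

# Normalise each row of counts into a conditional distribution P(next | prev).
transitions = {
    prev: {word: n / sum(following.values()) for word, n in following.items()}
    for prev, following in counts.items()
}

def generate(start, max_len=10):
    """Greedy auto-regressive generation: always append the most likely next word."""
    sentence = [start]
    while sentence[-1] != EOS and len(sentence) < max_len:
        options = transitions[sentence[-1]]
        sentence.append(max(options, key=options.get))
    return sentence

print(transitions["the"])
print(" -> ".join(generate("the")))
```

On this toy corpus the greedy path starting at 'the' is the -> boy -> is -> walking -> EOS. Replacing `max` with a weighted random draw would turn greedy decoding into sampling.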


#### 0.3.2 The power of neural networks
A key point that motivates the use of increasingly complicated neural networks for language modeling is that, in our illustrative example, we hold a transition probability matrix and sample from it at each step like a [Markov chain](https://en.wikipedia.org/wiki/Markov_chain). However, there are longer-range dependencies in language that are not captured by conditioning only on the previous word. So why not construct matrices that map transition probabilities between pairs of words, or triples, or even more? Beyond issues of computational feasibility, such a model would require every possible n-gram to have been seen at least once during training, which greatly limits the model's learning and usability. Neural networks have emerged as a great way of using large contexts: they generate a neural representation of the sequence and then apply a loose analogue of a transition matrix that maps between that representation and all the words in the vocabulary. For this reason, the output of the last layer of a neural language model is usually shown as a one-dimensional probability distribution over the vocabulary rather than as a probability matrix. Keeping this in mind, we can see how these language models accommodate another method of learning language that draws on the context before **AND** after a word.
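
A quick back-of-the-envelope sketch shows why lookup tables over longer contexts become infeasible. The helper `table_entries` is hypothetical, and the vocabulary size of 30,000 is an assumed, illustrative figure rather than a property of any specific model.

```python
def table_entries(vocab_size, context_len):
    """Entries in a lookup table holding P(next word | previous `context_len` words)."""
    n_contexts = vocab_size ** context_len   # every possible context of that length
    return n_contexts * vocab_size           # one probability per candidate next word

# 8 is the toy vocabulary above (7 words + EOS); 30_000 is an assumed, illustrative size.
for vocab_size in (8, 30_000):
    for context_len in (1, 2, 3):
        entries = table_entries(vocab_size, context_len)
        print(f"vocab={vocab_size:>6}  context={context_len}  entries={entries:.1e}")
```

The table grows exponentially with the context length, whereas a neural network keeps a fixed number of parameters and only has to output one distribution over the vocabulary per step.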

#### 0.3.3 Masked language modeling
Causal language modeling has a key drawback in that sometimes the necessary context to make sense of a word in a sentence comes after the word and not before. Masked language modeling is like causal modeling, but makes use of the fact that context may come before and after the word of interest.

$$\text{The quick brown [MASK] jumped over the lazy dog.}$$
$$ P(x_t|x_{\neq t}) = ? $$

This approach is what underlies the powerful BERT [[3]](https://arxiv.org/pdf/1810.04805) language model, which masks about 15% of the input tokens during pre-training. Amazingly, this approach has been tried on sequences other than language and has been shown to be a robust way of learning the syntax and semantics of sequential data of various modalities, including time series data, videos, and yes, even proteins!
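
The masking step itself can be sketched in a few lines. This is a simplified illustration: BERT additionally replaces some of the selected tokens with random tokens or leaves them unchanged, which is omitted here. The function name `mask_tokens` and its parameters are hypothetical.

```python
import random

def mask_tokens(tokens, mask_rate=0.15, mask_token="[MASK]", seed=0):
    """Hide a random subset of tokens; return the masked copy and the hidden targets."""
    rng = random.Random(seed)
    n_mask = max(1, round(mask_rate * len(tokens)))      # always mask at least one token
    positions = rng.sample(range(len(tokens)), n_mask)   # positions chosen without replacement
    masked = list(tokens)
    targets = {}
    for pos in positions:
        targets[pos] = tokens[pos]   # what the model must predict
        masked[pos] = mask_token     # what the model actually sees
    return masked, targets

sentence = "The quick brown fox jumped over the lazy dog".split()
masked, targets = mask_tokens(sentence)
print(" ".join(masked))
print(targets)
```

During training, the loss is computed only at the positions stored in `targets`, so the model is rewarded for reconstructing hidden tokens from the context on both sides.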

### 0.4 How do protein language models (pLMs) work?
Inspired by the success of [large language models (LLMs)](https://magazine.sebastianraschka.com/p/understanding-large-language-models) in a broad variety of natural language tasks, protein language models represent a powerful new approach to understanding the syntax and semantics of protein sequences [[6]](https://arxiv.org/abs/2007.06225). These models are trained with the masked language modeling objective: portions of each sequence are masked out and the model infers which amino acids belong there. Trained across billions of protein sequences, they learn to identify patterns and relationships within the sequences that are crucial for structure and function. This step is called pre-training, and it imbues the language model with a general understanding of the structural dependencies within the language, in this case proteins.

An optional second training step known as [fine-tuning](https://medium.com/@bijit211987/the-evolution-of-language-models-pre-training-fine-tuning-and-in-context-learning-b63d4c161e49) can be applied to a pre-trained protein language model to further train it on a specific task using protein sequences annotated with labels. In practice, starting from the pretrained weights has been shown to perform better than starting from randomly initialized weights, as the model only has to learn how to use the strong input representations learned during pretraining instead of jointly learning the representation AND how to use it. PLMs fine-tuned on specific protein families or functional classes can significantly enhance predictive power compared to non-pretrained models, and can be applied in a number of different use cases, such as predicting binding sites or the effects of mutations.


One of the most compelling benefits of PLMs is their ability to capture coevolutionary relationships within and across protein sequences [[7]](https://www.biorxiv.org/content/10.1101/2024.01.30.577970v1.full.pdf). In the same way that words in a sentence co-occur to convey coherent meaning, amino acid residues in a protein sequence co-evolve to maintain the protein's structural integrity and functionality. PLMs capture these coevolutionary patterns, allowing for the prediction of how changes in one part of a protein may affect other parts. Thus, from a design perspective, directed evolution is an area where PLMs offer substantial advantages. In a directed evolution experiment, a naturally occurring protein is mutated according to some heuristic and the variants are then checked for improvement in a desired function. Since PLMs capture intra-sequence conditional distributions, this process can be vastly streamlined by masking the portions of the protein we wish to 'mutate' and sampling from the distribution of amino acids that are strong candidates to occur given the rest of the sequence. PLMs thus have the potential to significantly reduce experimental burden by identifying promising candidates with a higher hit rate.
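
As a rough sketch of the sampling step, suppose a pLM has already produced a probability distribution for one masked position. The numbers below are made up for illustration and are not model output, and `propose_mutations` is a hypothetical helper.

```python
import random

def propose_mutations(probs, wild_type, n_samples=5, seed=0):
    """Sample candidate substitutions for a masked position, excluding the wild-type residue."""
    candidates = {aa: p for aa, p in probs.items() if aa != wild_type}
    rng = random.Random(seed)
    residues = list(candidates)
    weights = list(candidates.values())   # random.choices treats these as relative weights
    return rng.choices(residues, weights=weights, k=n_samples)

# Hypothetical probabilities for one masked position (made-up numbers, not model output).
masked_position_probs = {"L": 0.40, "I": 0.25, "V": 0.20, "M": 0.10, "F": 0.05}
print(propose_mutations(masked_position_probs, wild_type="L"))
```

Residues the model considers plausible in context are proposed more often, which is the intuition behind using a pLM to prioritise variants before any wet-lab screening.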

#### 0.4.1 Reconciling Sequence and Structure

Some protein language models combine amino acid sequence data with structural data, such as the 3D coordinates of the protein's atoms, in their training input. The goal is to incorporate structural information explicitly, aiming to enhance the representation and ultimately the prediction of unseen protein structures and functions. This is in contrast to sequence-only models, which model structure only implicitly, relying on the fact that structure is more closely conserved than sequence across homologous proteins.

Models like ESM-1b [[4]](https://www.biorxiv.org/content/10.1101/2020.12.15.422761v1.full.pdf) and ESM-2 [[8]](https://www.biorxiv.org/content/10.1101/2022.07.20.500902v3) are examples of sequence-only pLMs that do not explicitly incorporate 3D structural information. These sequence-based pLMs have demonstrated impressive performance on a variety of protein function prediction tasks by learning patterns from large protein sequence datasets.
However, the lack of structural information can limit the generalization capabilities of sequence-only pLMs. This is especially true for applications that depend heavily on protein structure, such as contact prediction. Moreover, the inclusion of structural information helps overcome the distributional biases that exist in sequence training datasets.

Structure-aware pLMs like S-PLM [[9]](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC10441326/) and ESMFold [[8]](https://www.biorxiv.org/content/10.1101/2022.07.20.500902v3) are trained on both sequence and structural information, and in turn generate protein representations that encode both. These models use various methods, such as the multi-view contrastive learning used by S-PLM, to align the sequence and structure representations in a shared latent space. This structural awareness enables them to achieve comparable or superior performance to specialized structure-based methods or sequence-based pLMs, particularly for applications that rely heavily on protein structure.

Interestingly, the recently released ESM-3 [[10]](https://www.evolutionaryscale.ai/blog/esm3-release) pLM reasons over sequence, structure, and function, meaning that for each protein, its sequence, structure, and function are extracted, tokenized, and partially masked during pre-training.

![image.png](https://www.biorxiv.org/content/biorxiv/early/2024/05/13/2023.08.06.552203/F1.large.jpg?width=800&height=600&carousel=1)
*The framework of S-PLM and lightweight tuning strategies for downstream supervised learning. a, The framework of S-PLM: During pretraining, the model inputs both the amino acid sequences and contact maps derived from protein structures simultaneously. After pretraining, the ESM-Adapter that generates the AA-level embeddings before the projector layer is used for downstream tasks. The entire ESM-Adapter model can be fully frozen or learnable through lightweight tuning. b, Architecture of the ESM-Adapter. c, Adapter tuning for supervised downstream tasks. d, LoRA tuning for supervised downstream tasks is implemented. Adapted from [[9]](https://www.biorxiv.org/content/10.1101/2023.08.06.552203v3).*


### 0.5 MSA-aware vs non-MSA-aware protein language models
Multiple Sequence Alignment (MSA) is a method used to align three or more biological sequences (protein or nucleic acid) to identify regions of similarity that may indicate functional, structural, or evolutionary relationships. MSAs are a cornerstone in bioinformatics for tasks such as phylogenetic analysis, structure prediction, and function annotation.

In the context of pLMs, an MSA provides evolutionary context for the representations of protein sequences. PLMs can be either MSA-aware or non-MSA-aware:
#### MSA-aware models

MSA-aware models, such as the MSA Transformer (ESM-MSA) [[11]](https://proceedings.mlr.press/v139/rao21a.html) and the Evoformer (used in AlphaFold) [[12]](https://www.nature.com/articles/s41586-021-03819-2), take MSAs as input so that they can incorporate evolutionary information and relationships between sequences and learn richer representations. They operate on alignments of multiple homologous sequences to capture conserved and variable regions. The rationale is that conserved regions often indicate functionally or structurally important parts of the protein, while variable regions can provide insights into evolutionary divergence and adaptation.

MSA-aware models can provide deeper insights into protein function and structure due to the evolutionary context. However, they are computationally intensive and require high-quality MSAs, which may not be available for all protein families.
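
The notion of conserved versus variable columns can be illustrated with a toy alignment. This is a minimal sketch on made-up sequences, with '-' as an assumed gap symbol; `column_conservation` is a hypothetical helper, not a function from an alignment library.

```python
from collections import Counter

def column_conservation(msa, gap="-"):
    """For each alignment column, return (most common residue, fraction of sequences sharing it)."""
    result = []
    for column in zip(*msa):                       # iterate over aligned positions
        residues = [aa for aa in column if aa != gap]
        if not residues:                           # column made only of gaps
            result.append((gap, 0.0))
            continue
        residue, count = Counter(residues).most_common(1)[0]
        result.append((residue, count / len(column)))
    return result

# Made-up toy alignment (4 sequences, 6 columns); '-' marks a gap.
toy_msa = ["MKHLTA", "MRHLSA", "MKHL-A", "MQHITA"]
for pos, (aa, frac) in enumerate(column_conservation(toy_msa)):
    print(pos, aa, f"{frac:.2f}")
```

Columns with a fraction of 1.0 are fully conserved in this toy example, which is the kind of signal an MSA-aware model can read directly from its input.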

#### Non-MSA-aware models
Non-MSA-aware models, such as ESMFold (ESM-2) [[8]](https://www.biorxiv.org/content/10.1101/2022.07.20.500902v3), ProtBERT [[6]](https://arxiv.org/abs/2007.06225), and TAPE, treat each protein sequence independently and do not explicitly incorporate evolutionary information from MSAs. They are trained on large datasets of individual protein sequences, learning patterns and representations directly from the sequence data.

While they can generalize well to diverse sequences and are computationally efficient, they may miss out on the evolutionary context that can be crucial for certain tasks.

<img src="https://cdn.prod.website-files.com/621e95f9ac30687a56e4297e/64a8d21628b03e0f9f71a4fc_V2_1677884143661_5f927638-5fb9-40ac-9301-8f25c3bcf649.png" alt="image.png" width="800">

*"Multiple Sequence Alignment". BioRender.*


#### Benefits and challenges of MSA-aware models:
- Evolutionary insight: MSAs provide evolutionary information, highlighting conserved residues that are often critical for protein function and structure.
- Improved predictions: By incorporating evolutionary context, MSA-aware models can improve performance on tasks such as secondary structure prediction, contact prediction, and function annotation.
- Functional and structural understanding: MSAs help in identifying functionally important regions and understanding the structural constraints of proteins.
- Computational complexity: Generating and processing MSAs is computationally expensive and time-consuming.
- Data availability: High-quality MSAs are not available for all protein families, especially those with few known homologs.
- Model complexity: MSA-aware models are more complex and require sophisticated architectures to effectively utilize the evolutionary information.

Other considerations:
- Benchmarking MSA-aware against non-MSA-aware models for predicting the 3D structure of proteins, as well as their function and other properties, is currently an active topic of research.
- Interestingly, MSA-free models have been reported to efficiently generate MSAs that are accurate enough to be used as input for MSA-aware models.

Let's see how this works hands-on!

------------------------------------------------
## 1. Investigate protein representation in the pLM ProtBERT
(Adapted from [DeepChem Tutorials](https://github.com/deepchem/deepchem/blob/master/examples/tutorials/ProteinLM_Tutorial0.ipynb), check the original notebook for more information.)

In [None]:
# libraries
from transformers import BertForMaskedLM, BertTokenizer, pipeline
import torch.nn.functional as F
import seaborn as sns

### Proteins of interest to be investigated
![hemoglobin.png](https://www.researchgate.net/profile/Lakna-Panawala/publication/313841668/figure/fig1/AS:463461898559488@1487509335507/Structure-of-Hemoglobin.png)

Image Source: *Adapted from "Représentation simplifiée de l'hémoglobine et de l'hème". Wikimedia Commons.*


Hemoglobin is the protein responsible for transporting oxygen from the lungs to all the cells of our body via red blood cells. Hemoglobin is a great protein to interrogate the behaviors of protein language models as it is highly conserved in certain regions across species, and also slightly variable in other places. What would we expect the distribution over amino acids to look like if we mask out a highly conserved region? What about a highly diverse region? Let's find out.


**Hemoglobin Sequence Homology across closely related mammals** (from [[13]](https://www.nature.com/articles/s41598-019-50619-w)):


In [None]:
# aligned hemoglobin beta sequences; '*' marks an alignment gap
hemoglobin_beta = {
'human':
"MVHLTPEEKSAVTALWGKVNVDEVGGEALGRLLVVYPWTQRFFESFGDLSTPDAVMGNPKVKAHGKKVLGAFSDGLAHLDNLKGTFATLSELHCDKLHVDPENFRLLGNVLVCVLAHHFGKEFTPPVQAAYQKVVAGVANALAHKYH",
'chimpanzee':
"MVHLTPEEKSAVTALWGKVNVDEVGGEALGRLLVVYPWTQRFFESFGDLSTPDAVMGNPKVKAHGKKVLGAFSDGLAHLDNLKGTFATLSELHCDKLHVDPENFRLLGNVLVCVLAHHFGKEFTPPVQAAYQKVVAGVANALAHKYH",
'camel':
"MVHLSGDEKNAVHGLWSKVKVDEVGGEALGRLLVVYPWTRRFFESFGDLSTADAVMNNPKVKAHGSKVLNSFGDGLNHLDNLKGTYAKLSELHCDKLHVDPENFRLLGNVLVVVLARHFGKEFTPDKQAAYQKVVAGVANALAHRYH",
'rabbit':
"MVHLSSEEKSAVTALWGKVNVEEVGGEALGRLLVVYPWTQRFFESFGDLSSANAVMNNPKVKAHGKKVLAAFSEGLSHLDNLKGTFAKLSELHCDKLHVDPENFRLLGNVLVIVLSHHFGKEFTPQVQAAYQKVVAGVANALAHKYH",
'pig':
"MVHLSAEEKEAVLGLWGKVNVDEVGGEALGRLLVVYPWTQRFFESFGDLSNADAVMGNPKVKAHGKKVLQSFSDGLKHLDNLKGTFAKLSELHCDQLHVDPENFRLLGNVIVVVLARRLGHDFNPNVQAAFQKVVAGVANALAHKYH",
'horse':
"*VQLSGEEKAAVLALWDKVNEEEVGGEALGRLLVVYPWTQRFFDSFGDLSNPGAVMGNPKVKAHGKKVLHSFGEGVHHLDNLKGTFAALSELHCDKLHVDPENFRLLGNVLVVVLARHFGKDFTPELQASYQKVVAGVANALAHKYH",
'bovine':
"M**LTAEEKAAVTAFWGKVKVDEVGGEALGRLLVVYPWTQRFFESFGDLSTADAVMNNPKVKAHGKKVLDSFSNGMKHLDDLKGTFAALSELHCDKLHVDPENFKLLGNVLVVVLARNFGKEFTPVLQADFQKVVAGVANALAHRYH",
'sheep':
"M**LTAEEKAAVTGFWGKVKVDEVGAEALGRLLVVYPWTQRFFEHFGDLSNADAVMNNPKVKAHGKKVLDSFSNGMKHLDDLKGTFAQLSELHCDKLHVDPENFRLLGNVLVVVLARHHGNEFTPVLQADFQKVVAGVANALAHKYH"
}

As we can see, there is a great degree of overlap between the hemoglobin $\beta$ subunits of these mammals. The part of the hemoglobin sequence that is essential to carrying oxygen is the part that binds the heme group. This is handled by a single amino acid, namely the histidine (H) near position 92 on the beta chain (index 92 in the human string above, inside the conserved `LSELHCD` stretch). Unsurprisingly, given its functional importance, the amino acid (H) at this position is unchanged across all species. Can a language model recapitulate this?
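
To put a number on this overlap, one option is the percent identity between two aligned sequences. The sketch below is a hypothetical helper that skips positions where either sequence has a gap; the two fragments are the first ten residues of the human and bovine entries above.

```python
def percent_identity(seq_a, seq_b, gap="*"):
    """Percentage of aligned, non-gap positions at which two equal-length sequences agree."""
    if len(seq_a) != len(seq_b):
        raise ValueError("sequences must be aligned to the same length")
    pairs = [(a, b) for a, b in zip(seq_a, seq_b) if a != gap and b != gap]
    matches = sum(a == b for a, b in pairs)
    return 100.0 * matches / len(pairs)

# First ten residues of the human and bovine entries; '*' marks an alignment gap.
human_fragment = "MVHLTPEEKS"
bovine_fragment = "M**LTAEEKA"
print(f"{percent_identity(human_fragment, bovine_fragment):.1f}% identical")
```

Assuming all entries of `hemoglobin_beta` share the same aligned length, the same function could be applied to each full sequence against the human one.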

### Load the model
[ProtBERT](https://arxiv.org/abs/2007.06225) is a protein language model based on the BERT model.
Load ProtBERT, use the [pre-trained Uniref100 Model](https://huggingface.co/Rostlab/prot_bert). Also load the tokeniser.

In [None]:
tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False )
model = BertForMaskedLM.from_pretrained("Rostlab/prot_bert", weights_only=True)
model

### **ProtBERT learned representation of Hemoglobin 𝛽**

ProtBERT [[6]](https://arxiv.org/abs/2007.06225) is a BERT-style protein language model that was trained via masked amino acid modeling on UniRef100 [[14]](https://doi.org/10.1093/bioinformatics/btm098), a dataset consisting of 217 million protein sequences and 88B amino acids. The UniRef databases contain protein sequences from UniProt that are clustered, and thus deduplicated, at a given threshold of sequence identity. UniRef100 uses a 100% sequence identity threshold, UniRef90 uses 90%, and UniRef50 uses 50%. As such, ProtBERT was trained on the largest of these databases. Let's see what the ProtBERT model loaded above has learned.

### See how the model recovers masked positions in Hemoglobin

In [None]:
# mask the F8 Histidine of Hemoglobin B Subunit
human_heme = list(hemoglobin_beta['human'])
human_heme[92] = "[MASK]"
masked_heme = ' '.join(human_heme)
print(masked_heme)

In [None]:
# tokenise the masked sequence
tokenized_sequence = tokenizer(masked_heme, return_tensors='pt')
tokenized_sequence

In [None]:
# pass the tokenised sequence through the model
model_outs = model(**tokenized_sequence)
model_outs

In [None]:
# transform the logits
logits = model_outs.logits.squeeze()[1:-1] # Ignore the [CLS] and [SEP] special tokens
print(logits.shape)
softmaxed = F.softmax(logits, dim=1).detach().numpy() # Softmax to normalize the logits to sum to 1
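
For reference, the softmax used above can be written out in plain Python. This is an illustrative re-implementation on made-up scores, not the code that PyTorch runs internally.

```python
import math

def softmax(logits):
    """Turn a list of real-valued scores into probabilities that sum to 1."""
    largest = max(logits)
    exps = [math.exp(x - largest) for x in logits]   # subtract the max for numerical stability
    total = sum(exps)
    return [e / total for e in exps]

toy_logits = [2.0, 1.0, 0.1, -1.0]   # made-up scores for a four-token vocabulary
probs = softmax(toy_logits)
print([round(p, 3) for p in probs], sum(probs))
```

Each row of `softmaxed` is the result of this operation applied to the logits of one sequence position, so every row is a categorical distribution over the model vocabulary.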

In [None]:
# decode the Logits Using Greedy Decoding (Max Probability at Each Timestep)
decoded_outputs = tokenizer.batch_decode(softmaxed.argmax(axis=1))
decoded_sequence = ''.join(decoded_outputs)
print(decoded_sequence)
print(f'The predicted residue at the masked position is: {decoded_sequence[92]}')

**Sanity Check:** Looks like the pLM ProtBERT was able to recapitulate the correct amino acid at that position. But how confident was the model? Let's visualize the distribution at that position and see what other amino acids the model was choosing between.


In [None]:
# visualise the Token Distribution at the F8 Histidine
import matplotlib.pyplot as plt

plt.bar(tokenizer.get_vocab().keys(), softmaxed[92])
plt.ylabel('Normalized Probability')
plt.xlabel('Model Vocabulary')
plt.title('Target Distribution at the F8 Histidine')
plt.xticks(rotation='vertical')
plt.show()

In [None]:
# visualise the probability map across all positions
plt.figure(figsize=(10,16))
sns.heatmap(softmaxed, xticklabels=list(tokenizer.get_vocab().keys()))
plt.show()

In [None]:
# look at a Low Confidence Region

plt.bar(tokenizer.get_vocab().keys(), softmaxed[87])
plt.ylabel('Normalized Probability')
plt.xlabel('Model Vocabulary')
plt.title('Target Distribution at Position 87')
plt.xticks(rotation='vertical')
plt.show()

In [None]:
for animal in hemoglobin_beta:
    print(f'{animal} has residue {hemoglobin_beta[animal][87]} at position 87')
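
One way to summarise confidence at a position in a single number is the Shannon entropy of its distribution: low entropy means a peaked, confident prediction, and high entropy means the model is spreading probability over many residues. The sketch below uses made-up distributions, and `entropy_bits` is a hypothetical helper.

```python
import math

def entropy_bits(probs):
    """Shannon entropy (in bits) of a discrete distribution; zero-probability terms contribute 0."""
    return -sum(p * math.log2(p) for p in probs if p > 0)

confident = [0.97, 0.01, 0.01, 0.01]   # made-up, sharply peaked distribution
uncertain = [0.25, 0.25, 0.25, 0.25]   # made-up, uniform distribution
print(entropy_bits(confident), entropy_bits(uncertain))
```

Applied row by row to `softmaxed`, for example `[entropy_bits(row) for row in softmaxed]`, this should give one uncertainty score per residue that can be compared with the conservation pattern across species.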

------------------------------------------------
**Optional:** Running the second part requires a Python version below 3.12 and the newest DeepChem version. Skip it if setting up the environment is too tedious.

## 2. Fine-Tune ProtBERT for water solubility
(Adapted from [DeepChem Tutorials](https://github.com/deepchem/deepchem/blob/master/examples/tutorials/Introduction_to_ProtBERT.ipynb))

Use the [DeepLoc](https://academic.oup.com/bioinformatics/article/33/21/3387/3931857?login=false) dataset for fine-tuning.

### Understanding ProtBERT

ProtBERT is a specialized variant of the BERT (Bidirectional Encoder Representations from Transformers) model, specifically designed for processing protein sequences. Developed by researchers at the Rostlab, ProtBERT leverages the transformative capabilities of BERT to encode the complex and nuanced features present in amino acid sequences.

#### Key Features of ProtBERT:

1. **BERT Architecture Adaptation:** ProtBERT adapts the original BERT architecture to the unique characteristics of protein sequences. It consists of transformer layers that capture both local and global dependencies in the sequence, making it suitable for tasks ranging from masked language modeling (MLM) to sequence classification.

2. **Tokenization and Embedding:** Similar to how BERT tokenizes words in natural language, ProtBERT tokenizes protein sequences, treating each amino acid as a single token (which is why residues are separated by spaces before tokenization). The model then generates embeddings that capture the semantic meaning and context of each amino acid.

3. **Pretraining and Fine-tuning:** ProtBERT supports pretraining on large-scale protein datasets such as UniRef100 and BFD (Big Fantastic Database), which helps it learn representations that generalize well across diverse protein sequences. These pretrained models can then be fine-tuned for specific tasks like protein classification (e.g., predicting membrane proteins or subcellular localization). The authors first pretrain on protein sequences with lengths less than 512, then on sequences less than 1024, and finally on sequences up to 40,000.

4. **Task-specific Adaptation:** Depending on the task, ProtBERT can be adapted with different classifier heads. For instance, it can be configured for single-label or multi-label classification tasks, allowing researchers to tailor it to specific biological questions.


### Loading ProtBERT

Pretrained ProtBERT checkpoints are available for the Uniref100 and BFD datasets, and can be loaded for both Masked Language Modeling (MLM) and Sequence Classification tasks. This section covers how to load ProtBERT in different modes and provides details about the pretrained datasets available.

#### Pretrained Models:

1. **Uniref100 Model:**
   - **Description:** The Uniref100 model is pretrained on the Uniref100 [1] dataset.
   - **Usage:** Initialize ProtBERT with `model_path = 'Rostlab/prot_bert'` to load the Uniref100 pretrained model.

2. **BFD Model:**
   - **Description:** The BFD model is pretrained on the BFD (Big Fantastic Database) dataset [2][3].
   - **Usage:** Initialize ProtBERT with `model_path = 'Rostlab/prot_bert_bfd'` to load the BFD pretrained model.

#### Supported Modes:

1. **Masked Language Modeling (MLM):**
   - **Description:** ProtBERT learns to predict masked amino acids in protein sequences, facilitating a deeper understanding of amino acid relationships and sequence contexts.
   - **Usage:** Initialize ProtBERT with `task='mlm'` and specify either `model_path = 'Rostlab/prot_bert'` or `model_path = 'Rostlab/prot_bert_bfd'` for MLM tasks.

2. **Sequence Classification:**
   - **Description:** Enables classification tasks such as predicting membrane proteins, subcellular localization, or custom classifications using a user-defined classifier head.
   - **Usage:** Set `task='classification'` to utilize ProtBERT for sequence classification. Specify the `cls_name` parameter as 'LogReg', 'FFN', or 'custom' to use Logistic Regression, a simple 1-layer FFN, or a custom classifier network, respectively.
     - **Custom Task:** Set `cls_name='custom'` and provide a custom classifier head using the `classifier_net` argument. This allows users to apply a custom classifier head on top of the pretrained ProtBERT model.

References:

[1] Suzek, Baris E., et al. "UniRef: comprehensive and non-redundant UniProt reference clusters." Bioinformatics 23.10 (2007): 1282-1288.

[2] Steinegger, Martin, Milot Mirdita, and Johannes Söding. "Protein-level assembly increases protein sequence recovery from metagenomic samples manyfold." Nature methods 16.7 (2019): 603-606.

[3] Steinegger, Martin, and Johannes Söding. "Clustering huge protein sequence sets in linear time." Nature communications 9.1 (2018): 2542.

In [None]:
# libraries
import deepchem as dc
import pandas as pd
from deepchem.models.torch_models import ProtBERT
import torch.nn as nn
import matplotlib.pyplot as plt
import os
import shutil
from urllib.request import urlopen
from rich.progress import Progress, TransferSpeedColumn

In [None]:
## datasets for fine-tuning
URL_test = "https://deepchemdata.s3.us-west-1.amazonaws.com/datasets/DeepLoc_test.csv"
URL_train = "https://deepchemdata.s3.us-west-1.amazonaws.com/datasets/DeepLoc_train.csv"
out_test = './datasets/DeepLoc_test.csv'
out_train = "./datasets/DeepLoc_train.csv"
os.makedirs('datasets', exist_ok=True)

# download each file with a progress bar
for url, out_path in [(URL_test, out_test), (URL_train, out_train)]:
    print(url)
    with Progress(*Progress.get_default_columns(), TransferSpeedColumn()) as progress:
        with urlopen(url) as res:
            length = int(res.headers["Content-Length"])
            with progress.wrap_file(res, total=length) as src, open(out_path, "wb") as dest:
                shutil.copyfileobj(src, dest)

In [None]:
# For demo purposes we choose a subset of the original data
train_df = pd.read_csv(out_train)
string_lengths = train_df["protein"].apply(len)
filtered_train_df = train_df[string_lengths < 200].sample(5000)
filtered_train_df.to_csv("./datasets/DeepLoc_train_5000.csv",index=False)


test_df = pd.read_csv(out_test)
string_lengths = test_df["protein"].apply(len)
filtered_test_df = test_df[string_lengths < 200].sample(1000)
filtered_test_df.to_csv("./datasets/DeepLoc_test_1000.csv",index=False)

filtered_train_df.head()

In [None]:
# featurise the datasets (DummyFeaturizer passes the raw sequences through unchanged)
featurizer = dc.feat.DummyFeaturizer()
tasks = ["water soluble"]
loader = dc.data.CSVLoader(tasks=tasks,
                            feature_field="protein",
                            featurizer=featurizer)
deeploc_train_dataset = loader.create_dataset("./datasets/DeepLoc_train_5000.csv")
deeploc_test_dataset  = loader.create_dataset("./datasets/DeepLoc_test_1000.csv")

### Load the model
Load [ProtBERT](https://academic.oup.com/bioinformatics/article/38/8/2102/6502274?login=false), use the [pre-trained Uniref100 Model](https://huggingface.co/Rostlab/prot_bert). 

In [None]:
# dir for finetuning
finetune_model_dir = "finetuning/"

# Network for custom classification task
custom_network = nn.Sequential(nn.Linear(1024, 512),
                               nn.ReLU(), nn.Linear(512, 256),
                               nn.ReLU(), nn.Linear(256, 2)) 

# ProtBERT model that can be used for fine-tuning for a downstream task
ProtBERTmodel_for_classification = ProtBERT(task='classification',
                                            model_path="Rostlab/prot_bert",
                                            n_tasks=1,
                                            cls_name="custom",
                                            classifier_net=custom_network,
                                            n_classes=2,
                                            model_dir=finetune_model_dir,
                                            batch_size=32,
                                            learning_rate=1e-5,
                                            log_frequency = 5) 
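
To get a feel for the size of the custom classifier head defined above, its trainable parameters can be counted by hand. This is a small arithmetic sketch; `linear_params` is a hypothetical helper, and the layer sizes are copied from the `nn.Linear` layers of `custom_network`.

```python
def linear_params(n_in, n_out):
    """Trainable parameters of a fully connected layer: a weight matrix plus a bias vector."""
    return n_in * n_out + n_out

layer_sizes = [(1024, 512), (512, 256), (256, 2)]   # mirrors the nn.Linear layers above
per_layer = [linear_params(n_in, n_out) for n_in, n_out in layer_sizes]
print(per_layer, sum(per_layer))
```

The head has 656,642 trainable parameters in total, and most of them sit in the first layer, which maps the 1024-dimensional ProtBERT representation down to 512 units.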

### Fine-tune the loaded model

In [None]:
# Freeze underlying ProtBERT and only train the classifier head
for param in ProtBERTmodel_for_classification.model.bert.parameters():
    param.requires_grad = False

# track the loss
all_losses = []
loss = ProtBERTmodel_for_classification.fit(deeploc_train_dataset, nb_epoch=1, all_losses=all_losses)

# Plot training loss
batches = list(range(5, 5 * (len(all_losses) + 1), 5))
plt.plot(batches, all_losses, linestyle='-', color='b')
plt.title('Training Loss over Batches')
plt.xlabel('Training Step')
plt.ylabel('Training Loss')
plt.grid(True)
plt.show()

### Evaluate the model
Use the DeepChem metrics (e.g. accuracy score) to evaluate your final model.

In [None]:
classification_metric = dc.metrics.Metric(dc.metrics.accuracy_score)
eval_score = ProtBERTmodel_for_classification.evaluate(deeploc_test_dataset, [classification_metric],n_classes=2)
eval_score
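
When reading the resulting score, it helps to compare it with a trivial baseline that always predicts the most frequent class. The sketch below spells out both quantities on made-up labels. The encoding 1 = water soluble, 0 = membrane-bound is an assumption for illustration, and both functions are hypothetical helpers rather than DeepChem APIs.

```python
def accuracy(y_true, y_pred):
    """Fraction of predictions that match the true labels."""
    if len(y_true) != len(y_pred):
        raise ValueError("label lists must have the same length")
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)

def majority_baseline(y_true):
    """Accuracy obtained by always predicting the most frequent class."""
    most_common = max(set(y_true), key=y_true.count)
    return y_true.count(most_common) / len(y_true)

# Made-up labels (assumed encoding: 1 = water soluble, 0 = membrane-bound).
y_true = [1, 1, 0, 1, 0, 1, 1, 0]
y_pred = [1, 0, 0, 1, 0, 1, 1, 1]
print(accuracy(y_true, y_pred), majority_baseline(y_true))
```

A fine-tuned model is only informative if its accuracy clearly exceeds the majority baseline on the same test set.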