<a href="https://colab.research.google.com/github/agemagician/ProtTrans/blob/master/Prediction/ProtBert_BFD_Predict_MS.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<h3>Stability-score prediction with a fine-tuned ProtTrans model</h3>

**1. Install and import the necessary libraries, including Hugging Face Transformers**

In [None]:
!pip install -q transformers

In [None]:
!pip install tape-proteins

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TextClassificationPipeline
import re
from typing import Union, List, Tuple, Sequence, Dict, Any, Optional
from copy import copy
from pathlib import Path
import pickle as pkl
import logging
import random
import tqdm
import numpy as np
import pandas as pd
from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments
from sklearn.metrics import r2_score, mean_squared_error
from scipy.stats import spearmanr
import torch
import torch.nn as nn
import lmdb
from tape import ProteinBertModel, TAPETokenizer
from transformers import T5Tokenizer, T5EncoderModel
import torch.nn.functional as F
from torch.utils.data import Dataset
from scipy.spatial.distance import pdist, squareform

**2. Load the sequence-classification pipeline, placing it on the GPU if one is available**

In [None]:
# Sequence-classification pipeline for the membrane checkpoint. It runs on GPU 0
# when CUDA is available and falls back to the CPU (device=-1) otherwise, instead
# of failing on machines without a GPU.
# NOTE(review): not used by the fine-tuning cells below. ProtBert tokenizers expect
# space-separated residues ("G S W K ..."), so format any input passed to it.
pipeline_model_name = "Rostlab/prot_bert_bfd_membrane"
pipeline = TextClassificationPipeline(
    model=AutoModelForSequenceClassification.from_pretrained(pipeline_model_name),
    tokenizer=AutoTokenizer.from_pretrained(pipeline_model_name),
    device=0 if torch.cuda.is_available() else -1,
)

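**3. Load sequences and stability-score labels from the TAPE stability dataset** (the code expects `stability_train.lmdb` and `stability_test.lmdb` in the working directory)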

In [None]:
logger = logging.getLogger(__name__)

class LMDBDataset(Dataset):
    """Creates a dataset from an lmdb file.

    The record count is read from the key ``b'num_examples'``. Record ``i`` is
    read from the key ``str(i).encode()``. Both are unpickled.

    Args:
        data_file (Union[str, Path]): Path to lmdb file.
        in_memory (bool, optional): Whether to cache each record in memory
            after its first read. Default: False.

    Raises:
        FileNotFoundError: If ``data_file`` does not exist.
        KeyError: If the database has no ``num_examples`` entry.

    Warning:
        Entries are deserialized with ``pickle``, which can execute arbitrary
        code. Only open LMDB files from a trusted source.
    """

    def __init__(self,
                 data_file: Union[str, Path],
                 in_memory: bool = False):

        data_file = Path(data_file)
        if not data_file.exists():
            raise FileNotFoundError(data_file)

        # Read-only, lock-free handle: this dataset never writes to the file.
        env = lmdb.open(str(data_file), max_readers=1, readonly=True,
                        lock=False, readahead=False, meminit=False)

        with env.begin(write=False) as txn:
            raw_count = txn.get(b'num_examples')
        if raw_count is None:
            # Fail with a clear message instead of pickle.loads(None) -> TypeError,
            # and do not leak the open environment.
            env.close()
            raise KeyError(f"{data_file} has no 'num_examples' entry")
        num_examples = pkl.loads(raw_count)

        if in_memory:
            # One slot per record, filled lazily on first access.
            self._cache = [None] * num_examples

        self._env = env
        self._in_memory = in_memory
        self._num_examples = num_examples

    def __len__(self) -> int:
        """Return the number of records in the database."""
        return self._num_examples

    def __getitem__(self, index: int):
        """Return record ``index`` as a dict.

        When a record has no ``'id'`` field, ``'id'`` is set to ``str(index)``.

        Raises:
            IndexError: If ``index`` is outside ``[0, len(self))``.
            KeyError: If the record is missing from the database.
        """
        if not 0 <= index < self._num_examples:
            raise IndexError(index)

        if self._in_memory and self._cache[index] is not None:
            return self._cache[index]

        with self._env.begin(write=False) as txn:
            raw_item = txn.get(str(index).encode())
        if raw_item is None:
            raise KeyError(f"record {index} is missing from the LMDB file")

        item = pkl.loads(raw_item)
        if 'id' not in item:
            item['id'] = str(index)
        if self._in_memory:
            self._cache[index] = item
        return item

    def close(self) -> None:
        """Close the underlying LMDB environment."""
        self._env.close()

In [None]:
# Open the TAPE stability splits; the LMDB files are expected in the working directory.
train_data = LMDBDataset(Path('stability_train.lmdb'))
test_data = LMDBDataset(Path('stability_test.lmdb'))

# Show one training record so the available fields are visible.
print(train_data[900])

{'id': b'EEHEE_rd1_0307.pdb_hp', 'primary': 'GSWKGYATANKKQAPTEEYLKDNAEQSGVDYAEKTKGKLEVDK', 'protein_length': 43, 'topology': b'EEHEE', 'parent': b'EEHEE_rd1_0307', 'stability_score': array([-1.], dtype=float32)}


In [None]:
# Collect the amino-acid sequence and scalar stability score of every training
# record, fetching (and unpickling) each record exactly once.
train_sequences = []
train_stability_scores = []
for i in range(len(train_data)):
    record = train_data[i]
    train_sequences.append(record['primary'])
    # stability_score is stored as a length-1 array; keep its single value.
    train_stability_scores.append(record['stability_score'][0])

# Summarise instead of dumping the full lists into the cell output.
print(f"{len(train_sequences)} training sequences")
print(train_sequences[:3])
print(train_stability_scores[:3])

['GSQEVNSGTQTYKNASPEEAERIARKAGATTWTEKGNKWEIRI', 'GSTTIEEAQNKKYQAEPRSWTKAGRTIGGKNWETEVNRAEASI', 'GSRETKKITTVGARGEATAEQAATEEGPKNNSRISNYKEQWWI', 'GSYELEVGNYRYRADDPEQLKEEAKKVGARDVQTDGNNFKVRG', 'GSYKGDLLYENREADNVYKATKRGRDPFGERGKEDVQNVEVQA', 'GSYAKDKEGLYDAGYGTRRPEVADRGNEEKVKLNVNEQQVDRF', 'GSWRVHFRGETYTADTEDDAKQLAKDAGARRIESSNGEVRVEL', 'GSIRLEVRGHTQSASNRKDATRAVTDGWGKDVEEYDLEARAEF', 'GSEELTHNWDEEFAGVRQTGDSATRYGVAKAVKERDLASRDRI', 'GSYTIEENGEKYTFRTRDEAEEWARRQGAQTVETRGTELRSRA', 'GSKEAYETITQRTARSEDNGEEWFRERAARQLETRGYTVTREG', 'GSVEEETEEWIAQARSNTDRQERLEKAYRGTGYRTFERRTATG', 'GSVELEDNGRRVEATSTQEARDRAKKEGATTWTESGTRIEVRG', 'GSAEVKTTVRDGTGKDEERGTNRIETRWARTGSERAEQVSAEL', 'GSQEERTESSVTEKRTRTGTAIDNEAEKWLVRTVGGAERDARG', 'GSWEVEVQGKRYEASTEDEAEEWARREGATEIRTDGNRIEVRK', 'GSIEAYIRVRERSVRQEREGEEAGTKDGARNWEETAEKWDVET', 'GSEATTKREENRAAWWVGERVVYEEIGDRETEIAKDQRSEEGR', 'GSQTVKDGTKTIDTDNADETAEKLARKYNGTYEKRGDEVEVRW', 'GSDIYVKERNGVATRLTEKDATETDRDDTAEKQEGKYVWTNKG', 'GSVRVEINGQEYDADTTEEAKRWAKEQGARKIQTKGTKLEVHK', 'GSGEITVQARQ

In [None]:
# Collect the amino-acid sequence and scalar stability score of every test
# record, fetching (and unpickling) each record exactly once.
test_sequences = []
test_stability_scores = []
for i in range(len(test_data)):
    record = test_data[i]
    test_sequences.append(record['primary'])
    # stability_score is stored as a length-1 array; keep its single value.
    test_stability_scores.append(record['stability_score'][0])

# Summarise instead of dumping the full lists into the cell output.
print(f"{len(test_sequences)} test sequences")
print(test_sequences[:3])
print(test_stability_scores[:3])

['TTIKVNGQEYTVPLSPEQAAKAAKKRWPDYEVQIHGNTVKVTR', 'TTIKVNGQEYTVPLSPEQDAKAAKKRWPDYEVQIHGNTVKVTR', 'TTIKVNGQEYTVPLSPEQEAKAAKKRWPDYEVQIHGNTVKVTR', 'TTIKVNGQEYTVPLSPEQFAKAAKKRWPDYEVQIHGNTVKVTR', 'TTIKVNGQEYTVPLSPEQGAKAAKKRWPDYEVQIHGNTVKVTR', 'TTIKVNGQEYTVPLSPEQHAKAAKKRWPDYEVQIHGNTVKVTR', 'TTIKVNGQEYTVPLSPEQIAKAAKKRWPDYEVQIHGNTVKVTR', 'TTIKVNGQEYTVPLSPEQKAKAAKKRWPDYEVQIHGNTVKVTR', 'TTIKVNGQEYTVPLSPEQLAKAAKKRWPDYEVQIHGNTVKVTR', 'TTIKVNGQEYTVPLSPEQMAKAAKKRWPDYEVQIHGNTVKVTR', 'TTIKVNGQEYTVPLSPEQNAKAAKKRWPDYEVQIHGNTVKVTR', 'TTIKVNGQEYTVPLSPEQPAKAAKKRWPDYEVQIHGNTVKVTR', 'TTIKVNGQEYTVPLSPEQQAKAAKKRWPDYEVQIHGNTVKVTR', 'TTIKVNGQEYTVPLSPEQRAKAAKKRWPDYEVQIHGNTVKVTR', 'TTIKVNGQEYTVPLSPEQSAKAAKKRWPDYEVQIHGNTVKVTR', 'TTIKVNGQEYTVPLSPEQTAKAAKKRWPDYEVQIHGNTVKVTR', 'TTIKVNGQEYTVPLSPEQVAKAAKKRWPDYEVQIHGNTVKVTR', 'TTIKVNGQEYTVPLSPEQWAKAAKKRWPDYEVQIHGNTVKVTR', 'TTIKVNGQEYTVPLSPEQYAKAAKKRWPDYEVQIHGNTVKVTR', 'TTIKVNGQEYTVPLSPEQADKAAKKRWPDYEVQIHGNTVKVTR', 'TTIKVNGQEYTVPLSPEQAEKAAKKRWPDYEVQIHGNTVKVTR', 'TTIKVNGQEYT

**4. Load the pre-trained model and adapt it for stability-score regression**

In [None]:
# Load the pre-trained model and tokenizer.
model_name = "Rostlab/prot_bert_bfd_membrane"
# A single-output head with problem_type="regression" makes the model's built-in
# loss MSE. With num_labels=2 and float labels it would pick a classification loss.
# ignore_mismatched_sizes lets the 1-output head be freshly initialised when the
# checkpoint's classifier has a different size.
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=1,
    problem_type="regression",
    ignore_mismatched_sizes=True,
)
# The ProtTrans examples load ProtBert tokenizers with do_lower_case=False
# (residues are upper-case tokens).
tokenizer = AutoTokenizer.from_pretrained(model_name, do_lower_case=False)

# Upper bound on tokenized length. It is far above the 43-residue sample record
# shown earlier, and it silences the "no maximum length" truncation warning.
MAX_LENGTH = 512


def format_for_protbert(sequence: str) -> str:
    """Space-separate residues and map rare residues U/Z/O/B to X.

    The ProtBert vocabulary holds single residues only, so an unspaced sequence
    is one "word" that WordPiece maps to [UNK].
    """
    return " ".join(re.sub(r"[UZOB]", "X", sequence))


# Model inputs (space-separated residues) for the training and test sets.
train_texts = [format_for_protbert(seq) for seq in train_sequences]
test_texts = [format_for_protbert(seq) for seq in test_sequences]

# Regression targets as float32 column vectors, matching the (batch, 1) logits.
train_labels = torch.as_tensor(np.asarray(train_stability_scores, dtype=np.float32)).unsqueeze(-1)
test_labels = torch.as_tensor(np.asarray(test_stability_scores, dtype=np.float32)).unsqueeze(-1)

# Tokenize each split, padding to its longest sequence.
train_encodings = tokenizer(train_texts, padding=True, truncation=True,
                            max_length=MAX_LENGTH, return_tensors="pt")
test_encodings = tokenizer(test_texts, padding=True, truncation=True,
                           max_length=MAX_LENGTH, return_tensors="pt")
train_input_ids = train_encodings['input_ids'].to(torch.int64)
test_input_ids = test_encodings['input_ids'].to(torch.int64)
train_attention_mask = train_encodings['attention_mask'].to(torch.int64)
test_attention_mask = test_encodings['attention_mask'].to(torch.int64)

# Guard against silently training on [UNK]-only inputs.
assert not (train_input_ids == tokenizer.unk_token_id).any(), "unexpected [UNK] tokens; check residue formatting"

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


**5. Fine-tune the ProtTrans model for stability-score regression**

In [None]:
NUM_EPOCHS = 1
BATCH_SIZE = 64
LEARNING_RATE = 1e-4

# Seed the shuffle order and dropout masks so the run can be reproduced.
torch.manual_seed(0)

# Train on the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
loss_fn = nn.MSELoss()
train_dataset = torch.utils.data.TensorDataset(train_input_ids, train_attention_mask, train_labels)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

model.train()
for epoch in range(NUM_EPOCHS):
    epoch_loss = 0.0
    progress_bar = tqdm.tqdm(train_dataloader, desc=f"Epoch {epoch + 1}", leave=False)
    for input_ids, attention_mask, labels in progress_bar:
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        labels = labels.to(device).float()

        optimizer.zero_grad()
        # Compute the regression loss explicitly. The model's own `outputs.loss`
        # chooses its loss from config.num_labels / problem_type. With
        # num_labels > 1 and float labels, that choice is BCEWithLogitsLoss, not MSE.
        logits = model(input_ids, attention_mask=attention_mask).logits
        loss = loss_fn(logits, labels)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        epoch_loss += batch_loss
        progress_bar.set_postfix({"batch_loss": batch_loss})
    avg_epoch_loss = epoch_loss / len(train_dataloader)
    print(f"Epoch {epoch + 1} average loss: {avg_epoch_loss:.4f}")

# Move the model back to the CPU before saving, so the weights load on machines
# without a GPU. This also matches the device of the CPU input tensors built earlier.
model.to("cpu")
torch.save(model.state_dict(), "model_weights.pth")



Epoch 1 average loss: 0.4767


**6. Use the fine-tuned model to predict on the test set and compute regression metrics (R², RMSE, Spearman's ρ)**

In [None]:
# This cell stores its R^2 result in `r2_score`, which the printing cell reads.
# That assignment shadows sklearn's function, so a re-run would call a float.
# Importing the metric under an alias keeps the cell re-runnable.
from sklearn.metrics import r2_score as compute_r2

EVAL_BATCH_SIZE = 64

# eval() disables dropout so test-set predictions are deterministic.
model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Predict in mini-batches; one forward pass over the whole test set can run out of memory.
test_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(test_input_ids, test_attention_mask),
    batch_size=EVAL_BATCH_SIZE,
    shuffle=False,  # keep predictions aligned with test_labels
)
batch_predictions = []
with torch.no_grad():
    for input_ids, attention_mask in test_loader:
        logits = model(input_ids.to(device), attention_mask=attention_mask.to(device)).logits
        batch_predictions.append(logits.cpu().numpy().ravel())
predictions = np.concatenate(batch_predictions)

# Flatten the targets (built with unsqueeze(-1)) to line up with `predictions`.
y_true = np.asarray(test_labels, dtype=np.float32).ravel()

# Coefficient of determination (R^2)
r2_score = compute_r2(y_true, predictions)

# Root-mean-squared error
mse = mean_squared_error(y_true, predictions)
rmse = np.sqrt(mse)

# Spearman's rank correlation
rho, _ = spearmanr(y_true, predictions)


In [None]:
# Report each test-set metric as "<label> <value>".
for metric_label, metric_value in (
    ("R2 score:", r2_score),
    ("RMSE:", rmse),
    ("Spearman's rho:", rho),
):
    print(metric_label, metric_value)

R2 score: -0.35
RMSE: 0.44
Spearman's rho: 0.71
