In [1]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report
from Bio import SeqIO
from transformers import T5Tokenizer, T5EncoderModel
import torch
import re
from io import StringIO
from Bio import SeqIO
import os
import pandas as pd
import tqdm
from IPython.display import clear_output
import sys
import os
import subprocess

In [2]:
# pLMs
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = "cpu"

print(device)

cpu


In [4]:
def load_ProstT5():
    global prostt5_tokenizer, prostt5_model
    if "prostt5_tokenizer" not in globals():
        prostt5_tokenizer = T5Tokenizer.from_pretrained('Rostlab/ProstT5_fp16')
    if "prostt5_model" not in globals():
        prostt5_model = T5EncoderModel.from_pretrained("Rostlab/ProstT5_fp16").to(device).eval()



# Load the CNN model once and make it a global variable
class CNN(torch.nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.classifier = torch.nn.Sequential(
            torch.nn.Conv2d(1024, 32, kernel_size=(7, 1), padding=(3, 0)),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.0),
            torch.nn.Conv2d(32, 20, kernel_size=(7, 1), padding=(3, 0))
        )

    def forward(self, x):
        x = x.permute(0, 2, 1).unsqueeze(dim=-1)
        Yhat = self.classifier(x)
        Yhat = Yhat.squeeze(dim=-1)
        return Yhat

# Function to predict 3Di sequence
def predict_3Di(sequence):
    """
    Predict 3Di sequence from an amino acid sequence.

    Args:
        sequence (str): Amino acid sequence.

    Returns:
        str: Predicted 3Di sequence.get_ProtT5_embeddings(accession, Seq_AA, site, feature_folder)
    """
    global prostt5_model, prostt5_tokenizer, cnn_model

    # Preprocess the sequence
    prefix = "<AA2fold>"
    seq = prefix + ' ' + ' '.join(list(sequence))
    token_encoding = prostt5_tokenizer(seq, return_tensors="pt").to(device)

    # Generate embeddings using the T5 model
    with torch.no_grad():
        embedding_repr = prostt5_model(**token_encoding)
        embedding = embedding_repr.last_hidden_state[:, 1:, :]  # Skip special token
        prediction = cnn_model(embedding)
        prediction = prediction.argmax(dim=1).squeeze().cpu().numpy()

    # Map predictions to 3Di symbols
    ss_mapping = {
        0: "A", 1: "C", 2: "D", 3: "E", 4: "F", 5: "G", 6: "H", 7: "I",
        8: "K", 9: "L", 10: "M", 11: "N", 12: "P", 13: "Q", 14: "R", 15: "S",
        16: "T", 17: "V", 18: "W", 19: "Y"
    }
    predicted_3Di = "".join([ss_mapping[p] for p in prediction])
    return predicted_3Di.lower()

In [5]:
def get_SaProt_embeddings(Seq_AA):
    """
    Get or compute SaProt embeddings for a protein sequence and its structural information.

    Parameters:
    - accession (str): Accession ID of the protein.
    - Seq_AA (str): Amino acid sequence of the protein.
    - site (int): Position of interest in the sequence.
    - feature_folder (str): Path to the folder containing precomputed features.
    - saprot_tokenizer: Tokenizer for SaProt.
    - saprot_model: Model for generating embeddings.
    - device: PyTorch device (e.g., 'cpu' or 'cuda').

    Returns:
    - torch.Tensor: Averaged representation of protein sequence.
    """

    Seq_3Di = predict_3Di(Seq_AA) # Use the provided foldseek code if pdb available
    
    # Combine sequence and structure
    combined_AA_3Di = "".join([a + b for a, b in zip(Seq_AA, Seq_3Di)])
    
    # Tokenize sequence
    inputs = saprot_tokenizer(combined_AA_3Di, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}  # Move inputs to the correct device
    
    # Generate embeddings
    embeddings_per_residue = saprot_model.get_hidden_states(inputs)[0]

    # Compute protein-level representation (mean pooling)
    protein_representation = embeddings_per_residue.mean(dim=0)
    
    return protein_representation

In [6]:
def parse_fasta(file_path):
    data = []
    for record in SeqIO.parse(file_path, "fasta"):
        accession_parts = record.id.split("|")
        accession = accession_parts[1] if len(accession_parts) > 1 else record.id  # Extract second part if exists
        AA_Seq = str(record.seq)
        data.append((accession, AA_Seq))
    return data

In [9]:
from model.saprot.base import SaprotBaseModel
from transformers import EsmTokenizer

def load_SaProt():
    global saprot_model, saprot_tokenizer

    saprot_config = {
        "task": "base",
        "config_path": "model/saprot/SaProt_650M_AF2/", # Note this is the directory path of SaProt, not the ".pt" file
        "load_pretrained": True,
    }
    
    if "saprot_tokenizer" not in globals():
        saprot_tokenizer = EsmTokenizer.from_pretrained(saprot_config["config_path"])
    if "saprot_model" not in globals():
        saprot_model = SaprotBaseModel(**saprot_config)

In [10]:
# Load ProstT5
load_ProstT5()

# Load SaProt
load_SaProt()

# Load CNN model
cnn_model = CNN()
checkpoint_path_3Di_prediction = "AA_to_3Di_prostt5_cnn_model.pt"
state = torch.load(checkpoint_path_3Di_prediction, map_location=device)
cnn_model.load_state_dict(state["state_dict"])
cnn_model = cnn_model.to(device).eval()

In [11]:
SaProt_features = get_SaProt_embeddings("MHWIATRNAVVSFPKWRFFFRSSYRTYSSLKPSSPILLNRRYSEGISCLRDGKSLKRITTASKKVKTSSDVLTDKDLSHLVWWKERLQTCKKPSTLQLIERLMYTNLLGLDPSLRNGSLKDGNLNWEMLQFKSRFPREVLLCRVGEFYEAIGIDACILVEYAGLNPFGGLRSDSIPKAGCPIMNLRQTLDDLTRNGYSVCIVEEVQGPTPARSRKGRFISGHAHPGSPYVYGLVGVDHDLDFPDPMPVVGISRSARGYCMISIFETMKAYSLDDGLTEEALVTKLRTRRCHHLFLHASLRHNASGTCRWGEFGEGGLLWGECSSRNFEWFEGDTLSELLSRVKDVYGLDDEVSFRNVNVPSKNRPRPLHLGTATQIGALPTEGIPCLLKVLLPSTCSGLPSLYVRDLLLNPPAYDIALKIQETCKLMSTVTCSIPEFTCVSSAKLVKLLEQREANYIEFCRIKNVLDDVLHMHRHAELVEILKLLMDPTWVATGLKIDFDTFVNECHWASDTIGEMISLDENESHQNVSKCDNVPNEFFYDMESSWRGRVKGIHIEEEITQVEKSAEALSLAVAEDFHPIISRIKATTASLGGPKGEIAYAREHESVWFKGKRFTPSIWAGTAGEDQIKQLKPALDSKGKKVGEEWFTTPKVEIALVRYHEASENAKARVLELLRELSVKLQTKINVLVFASMLLVISKALFSHACEGRRRKWVFPTLVGFSLDEGAKPLDGASRMKLTGLSPYWFDVSSGTAVHNTVDMQSLFLLTGPNGGGKSSLLRSICAAALLGISGLMVPAESACIPHFDSIMLHMKSYDSPVDGKSSFQVEMSEIRSIVSQATSRSLVLIDEICRGTETAKGTCIAGSVVESLDTSGCLGIVSTHLHGIFSLPLTAKNITYKAMGAENVEGQTKPTWKLTDGVCRESLAFETAKREGVPESVIQRAEALYLSVYAKDASAEVVKPDQIITSSNNDQQIQKPVSSERSLEKDLAKAIVKICGKKMIEPEAIECLSIGARELPPPSTVGSSCVYVMRRPDKRLYIGQTDDLEGRIRAHRAKEGLQGSSFLYLMVQGKSMACQLETLLINQLHEQGYSLANLADGKHRNFGTSSSLSTSDVVSIL")

In [13]:
SaProt_features.shape

torch.Size([1280])

In [23]:
# Define directory and file paths
data_dir = "/home/sp2530/Desktop/DNA-Binding-V2/data/plant/fasta/"

file_paths = [
    "DBP_independent.fasta",
    "DBP.fasta",
    "non_DBP.fasta",
    "non_DBP_independent.fasta"
]

# Parse FASTA files with full paths in a single list comprehension
data = [entry for file_path in file_paths for entry in parse_fasta(data_dir + file_path)]

In [24]:
data[:2]

[('Q5Z807',
  'MSRRQEICRNFQRGSCKYGAQCRYLHASPHQQQQQQQAKPNPFGFGTGSRQQQQPSFGSQFQQQQQQQQKPNPFGFGVQGANAQSRNAPGPAKPFQNKWVRDPSAPTKQTEAVQPPQAQAAHTSCEDPQSCRQQISEDFKNEAPIWKLTCYAHLRNGPCNIKGDISFEELRAKAYEEGKQGHSLQSIVEGERNLQNAKLMEFTNLLNSARPSQTPSFPTMSSFPEVKNNSSFGASQTNGPPVFSSFSQIGAATNIGPGPGTTAPGMPASSPFGHPSSAPLAAPTFGSSQMKFGVSSVFGNQGSGQPFGSFQAPRFPSSKSPASSVQHRDIDRQSQELLNGMVTPPSVMFEESVGNNKNENQDDSIWLKEKWAIGEIPLDEPPQRHVSHVF'),
 ('C0SVV6',
  'MRIPTYDFGSKFSVVQEVMRLQTVKHFLEPVLEPLIRKVVKEEVELALGKHLAGIKWICEKETHPLESRNLQLKFLNNLSLPVFTSARIEGDEGQAIRVGLIDPSTGQIFSSGPASSAKLEVFVVEGDFNSVSDWTDEDIRNNIVREREGKKPLLNGNVFAVLNDGIGVMDEISFTDNSSWTRSRKFRLGVRIVDQFDYVKIREAITESFVVRDHRGELYKKHHPPSLFDEVWRLEKIGKDGAFHRRLNLSNINTVKDFLTHFHLNSSKLRQVLGTGMSSKMWEITLDHARSCVLDSSVHVYQAPGFQKKTAVVFNVVAQVLGLLVDFQYIPAEKLSEIEKAQAEVMVIDALSHLNEVISYDDEVSMMRNVLNAPASQGSVAGIDYSGLSLTSLDGYGFVSSLHNTAECSGKHSDDVDMEVTPHGLYEDYDNLWNCSHILGLEEPQSELQSALDDFMSQKNASVGGKAHSKRWTKLFSVSRWLSVFKYVKLGKI')]

In [25]:
saprot_features_disk = "/home/sp2530/Desktop/DNA-Binding-V2/data/plant/features/saprot"

In [27]:
# Step 1: Create already_extracted_set
already_extracted_set = {
    os.path.splitext(file)[0].replace("_saprot", "")  # extract accession from file name
    for file in os.listdir(saprot_features_disk)
    if file.endswith("_saprot.pt")
}

already_extracted_set

{'C0SVV6',
 'Q5Z807',
 'Q75LX7',
 'Q7XC57',
 'Q84JF0',
 'Q8LFK2',
 'Q9FVV7',
 'Q9FX84'}

In [None]:
# Feature extraction
for accession, Seq_AA in tqdm.tqdm(data):
    get_SaProt_embeddings(accession, Seq_AA, saprot_features_disk)

  embeddings_per_residue = torch.load(feature_file)
100%|████████████████████████████████████▉| 2693/2695 [2:28:48<00:09,  4.71s/it]