In [2]:
from Bio.Blast import NCBIWWW
from Bio.Blast import NCBIXML
from Bio import SeqIO, Entrez
from io import StringIO

In [18]:
# parse the blastp output file

sequences = []

# Parse the results from the files
with open("./proteins/keratin/keratin-blastp.xml") as in_handle:
    blast_records = NCBIXML.parse(in_handle)
    for blast_record in blast_records:
        for alignment in blast_record.alignments:
            for hsp in alignment.hsps:
                if hsp.expect < 0.0001:
                    print("****Alignment****")
                    print("sequence:", alignment.title)
                    print("length:", alignment.length)
                    print("e value:", hsp.expect)
                    print(hsp.query[0:75] + "...")
                    print(hsp.match[0:75] + "...")
                    print(hsp.sbjct[0:75] + "...")
        

****Alignment****
sequence: gb|EAW58243.1| hCG1997648 [Homo sapiens] >gb|KAI4066085.1| keratin 81 [Homo sapiens] >dbj|BAI45776.1| keratin 81, partial [synthetic construct] >emb|CAA73943.1| keratin [Homo sapiens]
length: 505
e value: 0.0
MTCGSGFGGRAFSCISACGPRPGRCCITAAPYRGISCYRGLTGGFGSHSVCGGFRAGSCGRSFGYRSGGVCGPSP...
MTCGSGFGGRAFSCISACGPRPGRCCITAAPYRGISCYRGLTGGFGSHSVCGGFRAGSCGRSFGYRSGGVCGPSP...
MTCGSGFGGRAFSCISACGPRPGRCCITAAPYRGISCYRGLTGGFGSHSVCGGFRAGSCGRSFGYRSGGVCGPSP...
****Alignment****
sequence: ref|XP_018893534.2| keratin, type II cuticular Hb1 [Gorilla gorilla gorilla]
length: 505
e value: 0.0
MTCGSGFGGRAFSCISACGPRPGRCCITAAPYRGISCYRGLTGGFGSHSVCGGFRAGSCGRSFGYRSGGVCGPSP...
MTCGSGFGGRAFSCISACGPRPGRCCITAAPYRGISCYRGLTGGFGSHSVCGGFRAGSCGRSFGYRSGGVCGPSP...
MTCGSGFGGRAFSCISACGPRPGRCCITAAPYRGISCYRGLTGGFGSHSVCGGFRAGSCGRSFGYRSGGVCGPSP...
****Alignment****
sequence: ref|XP_024112725.1| keratin, type II cuticular Hb1 [Pongo abelii] >gb|PNJ25503.1| KRT86 isoform 1 [Pongo abelii]
length: 505
e valu

# ESM to extract the embeddings of the sequences

The Meta Fundamental AI Research Protein Team (FAIR) already made available a series of pre-trained models for working with biological data. The ESM-2 language model is able to extract representations directly relying solely on the single sequence and produce accurate structure prediction. Recently, facebook research made available the [ESM metagenomic atlas](https://esmatlas.com/) an open database made of 617 million predicted metagenomic protein structures.

Actually, from the sequences we aim to extract the embeddings using ESM-2.

In [186]:
# !pip install git+https://github.com/facebookresearch/esm.git

Collecting git+https://github.com/facebookresearch/esm.git
  Cloning https://github.com/facebookresearch/esm.git to /tmp/pip-req-build-uftfbug8
  Running command git clone --filter=blob:none --quiet https://github.com/facebookresearch/esm.git /tmp/pip-req-build-uftfbug8
  Resolved https://github.com/facebookresearch/esm.git to commit c9c7d4f0fec964ce10c3e11dccec6c16edaa5144
  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone
[?25hBuilding wheels for collected packages: fair-esm
  Building wheel for fair-esm (pyproject.toml) ... [?25ldone
[?25h  Created wheel for fair-esm: filename=fair_esm-2.0.1-py3-none-any.whl size=105311 sha256=1c64850705d1296536ad3369af9a41e2829834f8693ba90ed481eb94841cf435
  Stored in directory: /tmp/pip-ephem-wheel-cache-o4de2_du/wheels/f3/b2/ec/4db0b108f6367c7563f99b2445e1137d486003fb2f9bfd2f53
Successfully built fair-esm
Installing collected packa

In [1]:
import esm
import torch
import numpy as np
import glob
import pickle

In [2]:
# Load ESM-2 model
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D() #esm2_t36_3B_UR50D()#esm2_t48_15B_UR50D()
batch_converter = alphabet.get_batch_converter()
# half precision float16
model = model.half()
model.eval()  # disables dropout for deterministic results

ESM2(
  (embed_tokens): Embedding(33, 1280, padding_idx=1)
  (layers): ModuleList(
    (0-32): 33 x TransformerLayer(
      (self_attn): MultiheadAttention(
        (k_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (v_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (q_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (out_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (rot_emb): RotaryEmbedding()
      )
      (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      (fc1): Linear(in_features=1280, out_features=5120, bias=True)
      (fc2): Linear(in_features=5120, out_features=1280, bias=True)
      (final_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
    )
  )
  (contact_head): ContactPredictionHead(
    (regression): Linear(in_features=660, out_features=1, bias=True)
    (activation): Sigmoid()
  )
  (emb_layer_norm_after): LayerNorm((1280,), eps=1

In [3]:
# allocate the model on the GPU
model = model.cuda()

In [4]:
# load the sequences and cast them
files = glob.glob("./proteins/*/sequences/sequences_pck.pkl")

# list with all the sequences
sequences = []

for file in files:
    # load the sequence
    with open(file, 'rb') as f:
        seq = pickle.load(f)

    # cast all the sequences as a string 
    for key in seq.keys():
        seq[key] = str(seq[key])

    sequences.append(seq)

In [17]:
print('number of protein: ',len(sequences))
print()
[print(f'number of sequences protein {idx+1}: ',len(s)) for idx, s in enumerate(sequences)]
print()
print('total number of sequences: ',sum([len(s) for s in sequences]))

number of protein:  15

number of sequences protein 1:  500
number of sequences protein 2:  500
number of sequences protein 3:  500
number of sequences protein 4:  500
number of sequences protein 5:  500
number of sequences protein 6:  498
number of sequences protein 7:  489
number of sequences protein 8:  500
number of sequences protein 9:  500
number of sequences protein 10:  500
number of sequences protein 11:  500
number of sequences protein 12:  500
number of sequences protein 13:  500
number of sequences protein 14:  500
number of sequences protein 15:  500

total number of sequences:  7487


In [4]:
import gc

def get_batch(sequences_token, idx):
    """ Function to get the batch of sequences

    Returns:
        list: list with all the sequences
    """
    if idx == 5:
        yield [sequences_token[m:M]
                    for m, M in zip(np.linspace(0, 488, 60)[:59].astype(int),np.linspace(0, 488, 60)[1:].astype(int))]
        
    elif idx == 6:
        yield [sequences_token[m:M]
                    for m, M in zip(np.linspace(0, 498, 60)[:59].astype(int),np.linspace(0, 498, 60)[1:].astype(int))]
    elif idx == 7:
        yield [sequences_token[m:M]
                    for m, M in zip(np.linspace(0, 489, 60)[:59].astype(int),np.linspace(0, 489, 60)[1:].astype(int))]
        
    else:
        yield [sequences_token[m:M]
                    for m, M in zip(np.linspace(0, 500, 80)[:79].astype(int),np.linspace(0, 500, 80)[1:].astype(int))]


def get_esm_encoding(batch_tokens, model, device):
    # allocate the token on the GPU
    batch_tokens = batch_tokens.to(device)

    # Extract per-residue representations
    with torch.no_grad():
        results = model(batch_tokens, repr_layers=[33], return_contacts=False)

    yield results["representations"][33].type(torch.float16).cpu()

def extract_embedding(sequences_token, model, idx):
    """ Function to extract the embedding from the sequences

    Args:
        sequences_token (list): list with all the sequences
        model (torch.nn.Module): model to use to extract the embedding
        alphabet (esm.Alphabet): alphabet of the model
        batch_converter (esm.pretrained.esm1_t6_43M_UR50S): converter of the model

    Returns:
        list: list with all the embedding
    """

    # list with all the extracted representations
    token_representations = []

    for batch_tokens in list(get_batch(sequences_token, idx))[0]:

        # extract the embedding
        token_representations.append(list(get_esm_encoding(batch_tokens, model, device))[0])

        # free the memory
        gc.collect()
        del batch_tokens

    return token_representations


from tqdm import tqdm

# load the sequences and cast them
files = glob.glob("./proteins/*/sequences/sequences_pck.pkl")

# list with all the sequences
sequences = []

for file in files:
    # load the sequence
    with open(file, 'rb') as f:
        seq = pickle.load(f)

    s_seq = dict()
    # cast all the sequences as a string 
    for key in seq.keys():
        # I have to check and filter out the 
        # sequences with J otherwise the model
        # get an error
        if 'J' in str(seq[key]):
            break
        else:
            # store the sequence
            s_seq[key] = str(seq[key])

    sequences.append(s_seq)

del seq

print('number of protein: ',len(sequences))
print()
[print(f'protein {idx+1}: ',len(s)) for idx,s in enumerate(sequences)]
print()

sequences = sequences[8:11]

# check is cuda available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# list with all the sequences tokenized
sequences_token = []

for idx in tqdm(range(len(sequences))):
    # extract data
    data = list(sequences[idx].items())

    # convert the data in a batch
    batch_labels, batch_strs, batch_tokens = batch_converter(data)

    # check the length of the batch
    batch_lens = (batch_tokens != alphabet.padding_idx).sum(1)

    sequences_token.append(extract_embedding(batch_tokens, model,idx+8))



number of protein:  15

protein 1:  500
protein 2:  500
protein 3:  500
protein 4:  500
protein 5:  488
protein 6:  498
protein 7:  489
protein 8:  500
protein 9:  500
protein 10:  500
protein 11:  500
protein 12:  500
protein 13:  500
protein 14:  500
protein 15:  500



100%|██████████| 3/3 [05:31<00:00, 110.50s/it]


In [58]:
# print('number of protein: ',len(sequences_token))
# print()
# open('esm_embedding_0_3.pkl', 'wb').write(pickle.dumps(sequences_token))


number of protein:  4



4000073055

In [81]:
# print('number of protein: ',len(sequences_token))
# print()
# open('esm_embedding_4_7.pkl', 'wb').write(pickle.dumps(sequences_token))

number of protein:  4



1413264558

In [5]:
print('number of protein: ',len(sequences_token))
print()
open('esm_embedding_8_10.pkl', 'wb').write(pickle.dumps(sequences_token))

number of protein:  3



3299913281

In [6]:
import gc

def get_batch(sequences_token, idx):
    """ Function to get the batch of sequences

    Returns:
        list: list with all the sequences
    """
    if idx == 5:
        yield [sequences_token[m:M]
                    for m, M in zip(np.linspace(0, 488, 60)[:59].astype(int),np.linspace(0, 488, 60)[1:].astype(int))]
        
    elif idx == 6:
        yield [sequences_token[m:M]
                    for m, M in zip(np.linspace(0, 498, 60)[:59].astype(int),np.linspace(0, 498, 60)[1:].astype(int))]
    elif idx == 7:
        yield [sequences_token[m:M]
                    for m, M in zip(np.linspace(0, 489, 60)[:59].astype(int),np.linspace(0, 489, 60)[1:].astype(int))]
        
    else:
        yield [sequences_token[m:M]
                    for m, M in zip(np.linspace(0, 500, 80)[:79].astype(int),np.linspace(0, 500, 80)[1:].astype(int))]


def get_esm_encoding(batch_tokens, model, device):
    # allocate the token on the GPU
    batch_tokens = batch_tokens.to(device)

    # Extract per-residue representations
    with torch.no_grad():
        results = model(batch_tokens, repr_layers=[33], return_contacts=False)

    yield results["representations"][33].type(torch.float16).cpu()

def extract_embedding(sequences_token, model, idx):
    """ Function to extract the embedding from the sequences

    Args:
        sequences_token (list): list with all the sequences
        model (torch.nn.Module): model to use to extract the embedding
        alphabet (esm.Alphabet): alphabet of the model
        batch_converter (esm.pretrained.esm1_t6_43M_UR50S): converter of the model

    Returns:
        list: list with all the embedding
    """

    # list with all the extracted representations
    token_representations = []

    for batch_tokens in list(get_batch(sequences_token, idx))[0]:

        # extract the embedding
        token_representations.append(list(get_esm_encoding(batch_tokens, model, device))[0])

        # free the memory
        gc.collect()
        del batch_tokens

    return token_representations


from tqdm import tqdm

# load the sequences and cast them
files = glob.glob("./proteins/*/sequences/sequences_pck.pkl")

# list with all the sequences
sequences = []

for file in files:
    # load the sequence
    with open(file, 'rb') as f:
        seq = pickle.load(f)

    s_seq = dict()
    # cast all the sequences as a string 
    for key in seq.keys():
        # I have to check and filter out the 
        # sequences with J otherwise the model
        # get an error
        if 'J' in str(seq[key]):
            break
        else:
            # store the sequence
            s_seq[key] = str(seq[key])

    sequences.append(s_seq)

del seq

print('number of protein: ',len(sequences))
print()
[print(f'protein {idx+1}: ',len(s)) for idx,s in enumerate(sequences)]
print()

sequences = sequences[11:]

# check is cuda available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# list with all the sequences tokenized
sequences_token = []

for idx in tqdm(range(len(sequences))):
    # extract data
    data = list(sequences[idx].items())

    # convert the data in a batch
    batch_labels, batch_strs, batch_tokens = batch_converter(data)

    # check the length of the batch
    batch_lens = (batch_tokens != alphabet.padding_idx).sum(1)

    sequences_token.append(extract_embedding(batch_tokens, model,idx+8))



number of protein:  15

protein 1:  500
protein 2:  500
protein 3:  500
protein 4:  500
protein 5:  488
protein 6:  498
protein 7:  489
protein 8:  500
protein 9:  500
protein 10:  500
protein 11:  500
protein 12:  500
protein 13:  500
protein 14:  500
protein 15:  500



100%|██████████| 4/4 [04:17<00:00, 64.26s/it]


In [7]:
print('number of protein: ',len(sequences_token))
print()
open('esm_embedding_11_14.pkl', 'wb').write(pickle.dumps(sequences_token))

number of protein:  4



3467617775

In [None]:
from esm.data import ESMStructuralSplitDataset
from esm.pretrained import load_model_and_alphabet


# load the model
model, alphabet = load_model_and_alphabet('esm1_t34_670M_UR50S')

# load the dataset
data = ESMStructuralSplitDataset

The dataframe has 773,846,840 records, and the file size is around 16GB. This dataframe has 10 columns:

+ id is the MGnify ID

+ ptm is the predicted TM score

+ plddt is the predicted average lddt

+ num_conf is the number of residues with plddt > 0.7

+ len is the total residues in the protein

+ is_fragment indicates whether the protein sequence is identified as a fragment in the MGnify90 sequence database.

+ sequenceChecksum is the CRC64 hash of the sequence. Can be used for cheaper lookups.

+ esmfold_version is the version of ESMFold, matching the model accessible as esm.pretrained.esmfold_v{0,1}

+ atlas_version is the Atlas version where this structure first appeared. Note: some of the predictions appearing for the first time in v0 are also part of Atlas v2023_02.

+ sequence_dbs is the metagenomic source databases where this structure is part of, as MGnify90_2022_05, comma-separated if it exists in more than one release, as MGnify90_2022_05,MGnify90_2023_02

source:https://github.com/facebookresearch/esm/tree/main/scripts/atlas

In [2]:
import pandas as pd

# df = pd.read_parquet('/home/rickbook/document/applied-nlp/metadata-rc2.parquet')

{'AAP36614.1': Seq('MLRAAARFGPRLGRRLLSAAATQAVPAPNQQPEVFCNQIFINNEWHDAVSRKTF...KNS'),
 'NP_000681.2': Seq('MLRAAARFGPRLGRRLLSAAATQAVPAPNQQPEVFCNQIFINNEWHDAVSRKTF...KNS'),
 'BAD97093.1': Seq('MLRAAARFGPRLGRRLLSAAATQAVPAPNQQPEVFCNQIFINNEWHDAVSRKTF...KNS'),
 'XP_054298134.1': Seq('MLRAAAHFGPRLGRRLLSAAATQAVPAPNQQPEVFCNQIFINNEWHDAVSRKTF...KNS'),
 'AAA51693.1': Seq('MLRAAARFGPRLGRRLLSAAATQAVPAPNQQPEVFCNQIFINNEWHDAVSRKTF...KNS'),
 'AAT41621.1': Seq('MLRAAARFGPRLGRRLLSAAATQAVPAPNQQPEVFCNQIFINNEWHDAVSRKTF...KNS'),
 'NP_001124747.1': Seq('MLRAAARFGPRLGRRLLSAAATQAVPAPNQQPEVFCNQIFINNEWHDAVSRKTF...KNS'),
 'XP_030677272.1': Seq('MLRAAALFGPRLGRRLLSAAATQAVPAPNQQPEVFCNQIFINNEWHDAVSRKTF...KNS'),
 'XP_003832508.1': Seq('MLRAAACFGPRLGRRLLSAAATQAVPAPNQQPEVFCNQIFINNEWHDAVSRKTF...KNS'),
 'CAG33272.1': Seq('MLRAAARFGPRLGRRLLSAAATQAVPAPNQQPEVFCNQIFINNEWHDAVSRKTF...KNS'),
 'XP_016779748.1': Seq('MFRAAACLGPRLGRRLLSAAATQAVPAPNQQPEVFCNQIFINNEWHDAVSRKTF...KNS'),
 'XP_010377208.1': Seq('MLRAAARFGPRLGLRLLSAAATQAVPAPNQQ