# Using ProteinEmbedding and Forming a Dataset

The _ProteinEmbedding_ object stores protein sequences together with their embedding (vector) representations,
which can then be manipulated with standard operations from the __skbio__ library.

To demonstrate this process, we will read in a list of protein sequences, embed them, and then write the
embeddings to a dataset that can be streamed back in with the standard skbio.read.

In [1]:
# Necessary imports
from skbio.embedding import ProteinEmbedding
from skbio.sequence import Protein
from tqdm import tqdm
import numpy as np
import skbio

## Protein Sequences

We will now define a function that generates a list of n_prots random protein sequences, each 20 to 99 residues long, to be embedded.

In [2]:
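def generate_proteins(n_prots):
    # Return n_prots random protein strings, each 20 to 99 residues long
    PROTEIN_ALPHABET = "ACDEFGHIKLMNPQRSTVWY"
    np.random.seed(42)  # fixed seed so the same proteins are generated on every run
    proteins = []
    for _ in range(n_prots):
        # randint's upper bound is exclusive, so lengths fall in [20, 99]
        prot = "".join(
            np.random.choice(list(PROTEIN_ALPHABET),
                             size=np.random.randint(20, 100)))
        proteins.append(prot)
    return proteins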
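
As a variation on the function above, the following sketch draws from a local NumPy Generator instead of reseeding the global random state, so calling it does not affect other random draws in the notebook. The name generate_proteins_rng and its seed parameter are hypothetical and are not part of scikit-bio or the original notebook.

```python
import numpy as np

PROTEIN_ALPHABET = "ACDEFGHIKLMNPQRSTVWY"


def generate_proteins_rng(n_prots, seed=42):
    """Return n_prots random protein strings, each 20 to 99 residues long."""
    rng = np.random.default_rng(seed)  # local generator; global state is untouched
    residues = np.array(list(PROTEIN_ALPHABET))
    proteins = []
    for _ in range(n_prots):
        length = rng.integers(20, 100)  # upper bound is exclusive
        proteins.append("".join(rng.choice(residues, size=length)))
    return proteins


proteins = generate_proteins_rng(5)
print([len(p) for p in proteins])
```

Because the generator is created from the seed on every call, repeated calls with the same arguments return the same list.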

## Loading FASTA File

_bagel.fa_ is a file containing bacteriocin sequences stored in FASTA format. We will parse this
file and use the parsed sequences to create our __ProteinEmbedding__ objects.

## Loading Embeddings

The functions below take the input protein sequences and feed them through an embedding model (ProtT5),
returning the generated embeddings as ProteinEmbedding objects.

In [9]:
import re

import torch
from transformers import T5Tokenizer, T5EncoderModel
from skbio.embedding import ProteinEmbedding

def load_protein_t5_embedding(sequence, model_name, tokenizer_name):
    # (In case we want to use ONNX model)
    # from optimum.onnxruntime import ORTModel
    # NOTE: the tokenizer and model are reloaded on every call; load them once
    # outside this function when embedding many sequences.
    tokenizer = T5Tokenizer.from_pretrained(tokenizer_name)
    model = T5EncoderModel.from_pretrained(model_name)
    # ProtT5 expects space-separated residues, with rare amino acids mapped to X
    spaced = " ".join(re.sub(r"[UZOB]", "X", sequence))
    # tokenize as a batch of one sequence (padding is a no-op for a single sequence)
    ids = tokenizer.batch_encode_plus([spaced], add_special_tokens=True, padding="longest")
    input_ids = torch.tensor(ids['input_ids'])
    attention_mask = torch.tensor(ids['attention_mask'])
    # generate embeddings
    with torch.no_grad():
        embedding_repr = model(input_ids=input_ids, attention_mask=attention_mask)
    # keep one vector per residue, dropping the trailing end-of-sequence token
    emb = embedding_repr.last_hidden_state[0, :len(sequence)]
    return ProteinEmbedding(emb, sequence)


def to_embeddings(sequences : list, model_name, tokenizer_name):
    # Lazily embed each random/inputted protein sequence,
    # yielding one ProteinEmbedding at a time
    for sequence in tqdm(sequences):
        yield load_protein_t5_embedding(str(sequence), model_name, tokenizer_name)
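
The ProtT5 model card describes the input format its tokenizer expects: residues separated by single spaces, with the rare amino acids U, Z, O, and B replaced by X. The sketch below isolates that pre-processing step. The helper name prepare_for_prot_t5 is hypothetical and is not part of transformers or scikit-bio.

```python
import re


def prepare_for_prot_t5(sequence):
    """Format a raw protein string the way the ProtT5 tokenizer expects."""
    # Map rare or ambiguous residues (U, Z, O, B) to the unknown residue X
    cleaned = re.sub(r"[UZOB]", "X", sequence)
    # Separate residues with spaces so each one becomes its own token
    return " ".join(cleaned)


prepared = prepare_for_prot_t5("MKUAZ")
print(prepared)  # M K X A X
```

The prepared string keeps one token per residue, so the per-residue embedding has the same length as the original sequence once the end-of-sequence token is dropped.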

## Passing to file

Finally, we can write the embeddings to a "bagel.h5" file, which can be used further,
as will be demonstrated in other scikit-bio tutorials.

In [11]:
model_name = "Rostlab/prot_t5_xl_uniref50"
tokenizer_name = "Rostlab/prot_t5_xl_uniref50"

# Parse bagel.fa
sequence_list = skbio.io.read("bagel.fa", format='fasta')
embed_list = to_embeddings(sequence_list, model_name, tokenizer_name)
skbio.write(embed_list, format='embed', into="bagel.h5")

# Test that the file was written correctly by reading back the first embedding
read_embed = iter(skbio.read("bagel.h5", format='embed', constructor=ProteinEmbedding))
item = next(read_embed)
item

0it [00:00, ?it/s]
torch.Size([21, 1024])
1it [00:07,  7.60s/it]
torch.Size([47, 1024])
2it [00:14,  7.32s/it]
torch.Size([44, 1024])
...
38it [04:18,  6.64s/it]
torch.Size([61, 1024])
39it [04:25,  6.53s/it]
torch.Size([58, 1024])
torch.Size([40, 1024])

IndexError: Index (39) out of range for (0-38)