# Using ProteinEmbedding and Creating a Dataset

The _ProteinEmbedding_ object stores a protein sequence together with its embedding, a vector representation of that sequence,
which can then be manipulated with standard operations from the __skbio__ library.

To demonstrate this process, we will read protein sequences from a FASTA file, embed each one, and then attempt to write the 
embeddings to a dataset that can be streamed back in with the standard skbio.read.

In [1]:
# Necessary imports
from skbio.embedding import ProteinEmbedding
from skbio.sequence import Protein
from tqdm import tqdm
import argparse
import skbio
import re

## Loading Embeddings

This function takes a single protein sequence, feeds it through an embedding model (ProtT5), 
and returns the generated embedding wrapped in a ProteinEmbedding.

In [2]:
def load_protein_t5_embedding(sequence, model_name, tokenizer_name):
    import torch
    from transformers import T5Tokenizer, T5EncoderModel
    # (In case we want to use ONNX model)
    # from optimum.onnxruntime import ORTModel
    
    tokenizer = T5Tokenizer.from_pretrained(tokenizer_name)
    model = T5EncoderModel.from_pretrained(model_name)

    # replace the rare residues U, Z, O and B with X and separate residues with
    # spaces, which is the format this function passes to the tokenizer
    seq_list = [sequence]
    seqs = [" ".join(re.sub(r"[UZOB]", "X", str(seq))) for seq in seq_list]
    
    # tokenize sequences and pad up to the longest sequence in the batch
    ids = tokenizer.batch_encode_plus(seqs, add_special_tokens=True, padding="longest")
    input_ids = torch.tensor(ids['input_ids'])
    attention_mask = torch.tensor(ids['attention_mask'])

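    # generate embeddings
    with torch.no_grad():
        embedding_repr = model(input_ids=input_ids, attention_mask=attention_mask)

    # keep the single sequence in the batch and drop the last row, which belongs to
    # the end-of-sequence token appended by the tokenizer
    return ProteinEmbedding(embedding_repr.last_hidden_state.detach().cpu().numpy()[0][:-1], sequence)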
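
The preprocessing step inside load_protein_t5_embedding can be tried on its own, without downloading the model. The following is a minimal sketch of that step only; `format_for_prot_t5` is a hypothetical helper name introduced here for illustration and is not part of skbio or transformers.

```python
import re


def format_for_prot_t5(sequence):
    """Hypothetical helper mirroring the preprocessing in load_protein_t5_embedding."""
    # Map the rare residues U, Z, O and B to the unknown residue X
    cleaned = re.sub(r"[UZOB]", "X", str(sequence))
    # Put a space between residues so each one is tokenized separately
    return " ".join(cleaned)


formatted = format_for_prot_t5("MKUZAB")
print(formatted)  # M K X X A X
```

Anything that converts to a string with `str()` can be passed in, which is why the function above accepts a Protein object directly.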

### Create embeddings

In [3]:
# parse arguments
parser = argparse.ArgumentParser()

# Modify the default values of the arguments to match the desired values
parser.add_argument("--n_sequences", type=int, default=20)
parser.add_argument("--model_name", type=str, default="Rostlab/prot_t5_xl_uniref50")
parser.add_argument("--tokenizer_name", type=str, default="Rostlab/prot_t5_xl_uniref50")
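# pass an empty list so argparse uses the defaults instead of reading the notebook's own command-line arguments
args = parser.parse_args([])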

# skbio.read returns a generator of Protein objects; each one is converted with str()
# inside load_protein_t5_embedding
sequence_list = skbio.read('test.fa', format='fasta', constructor=Protein)
# To embed only the first n_sequences, use this line instead:
# sequence_list = list(skbio.read('test.fa', format='fasta', constructor=Protein))[:args.n_sequences]
loaded_proteins = lambda sequence: load_protein_t5_embedding(sequence, args.model_name, args.tokenizer_name)


embed_list = (loaded_proteins(x) for x in sequence_list)
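
Note that embed_list is a generator expression, so no embedding is computed until something iterates over it, and it can only be iterated once. The sketch below illustrates this behaviour with plain Python; `fake_embed` is a hypothetical stand-in for the model call and the sequences are made up for the example.

```python
calls = []


def fake_embed(seq):
    # Hypothetical stand-in for the expensive model call
    calls.append(seq)
    return len(seq)


sequences = ["MKV", "AAGT"]

# Building the generator does not call fake_embed yet
lazy = (fake_embed(s) for s in sequences)
calls_before = list(calls)

# The work happens only when the generator is consumed
first_pass = list(lazy)

# A second pass yields nothing because the generator is exhausted
second_pass = list(lazy)

print(calls_before, first_pass, second_pass)
```

If the embeddings are needed more than once, for example to write them and then inspect them in memory, wrapping the generator in `list(...)` keeps them available at the cost of holding every embedding in memory.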


## Write to file

Here, we write our embeddings to an .h5 file using skbio.write. To verify that the embeddings were stored correctly, 
we then read the file back with skbio.read and print each embedding.

In [5]:
#embed_list = map(loaded_proteins, sequence_list)
# Issue: Generator should include iterables
skbio.write(embed_list, format='embed', into="bagel.h5")

# read the file back and print each embedding to check that it was written correctly
read_embed = iter(skbio.read("bagel.h5", format='embed', constructor=ProteinEmbedding))
for item in read_embed:
    print(item.embedding)

UnboundLocalError: cannot access local variable 'idptr_fh' where it is not associated with a value