# Using ProteinEmbedding and Forming a Dataset

The _ProteinEmbedding_ object stores a protein sequence together with its embedding, a matrix holding one vector representation per residue.
It can then be handled with the standard machinery of the __skbio__ (scikit-bio) library, such as reading and writing through `skbio.read` and `skbio.write`.

To demonstrate, we read protein sequences from a FASTA file, embed each one with ProtT5, write the embeddings to an HDF5 dataset,
and then stream them back in with the standard `skbio.read`.

```python
# Necessary imports
import argparse
import re

import skbio
from skbio.embedding import ProteinEmbedding
from skbio.sequence import Protein
```

## Loading Embeddings

The function below takes a single protein sequence, passes it through the ProtT5 encoder (`T5EncoderModel`), and returns a
`ProteinEmbedding` that pairs the sequence with its per-residue embeddings.

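```python
import re
from functools import lru_cache

import torch
from transformers import T5EncoderModel, T5Tokenizer
from skbio.embedding import ProteinEmbedding
# (In case we want to use an ONNX model)
# from optimum.onnxruntime import ORTModel


@lru_cache(maxsize=None)
def _load_prot_t5(model_name, tokenizer_name):
    # Load the tokenizer and encoder once per (model, tokenizer) pair and reuse
    # them; reloading a multi-gigabyte model for every sequence is very slow.
    tokenizer = T5Tokenizer.from_pretrained(tokenizer_name)
    model = T5EncoderModel.from_pretrained(model_name)
    model.eval()  # inference mode (disables dropout)
    return tokenizer, model


def load_protein_t5_embedding(sequence, model_name, tokenizer_name):
    tokenizer, model = _load_prot_t5(model_name, tokenizer_name)

    # ProtT5 expects space-separated residues; map the rare or ambiguous
    # amino acids U, Z, O and B to X
    seqs = [" ".join(re.sub(r"[UZOB]", "X", str(sequence)))]

    # tokenize; with a batch of one, "longest" padding adds no pad tokens
    ids = tokenizer(seqs, add_special_tokens=True, padding="longest", return_tensors="pt")

    # generate embeddings without tracking gradients
    with torch.no_grad():
        embedding_repr = model(input_ids=ids["input_ids"], attention_mask=ids["attention_mask"])

    # take the single batch item and drop the trailing </s> token,
    # leaving one embedding vector per residue
    per_residue = embedding_repr.last_hidden_state[0, :-1].cpu().numpy()
    return ProteinEmbedding(per_residue, sequence)
```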
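The two less obvious steps inside `load_protein_t5_embedding` can be checked without downloading the model. The sketch below reproduces them in isolation, using hypothetical helper names (`format_for_prot_t5`, `trim_eos`) and a zero-filled NumPy array standing in for `last_hidden_state`; the hidden size of 1024 is assumed to match ProtT5-XL. Because the tokenizer appends a single `</s>` token, the encoder returns one more position than there are residues.

```python
import re

import numpy as np


def format_for_prot_t5(sequence: str) -> str:
    """Hypothetical helper mirroring the preprocessing in load_protein_t5_embedding."""
    # map rare/ambiguous residues (U, Z, O, B) to the unknown residue X
    cleaned = re.sub(r"[UZOB]", "X", sequence)
    # one residue per whitespace-separated token
    return " ".join(cleaned)


def trim_eos(last_hidden_state: np.ndarray) -> np.ndarray:
    """Hypothetical helper: take the only batch item and drop the final </s> position."""
    return last_hidden_state[0][:-1]


seq = "MKUVZB"
print(format_for_prot_t5(seq))  # M K X V X X

# stand-in for model output: batch of 1, len(seq) residues + 1 </s>, 1024 dims
fake_hidden = np.zeros((1, len(seq) + 1, 1024), dtype=np.float32)
print(trim_eos(fake_hidden).shape)  # (6, 1024)
```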

### Create embeddings

```python
from itertools import islice

# parse arguments; modify the default values below to match the desired settings
parser = argparse.ArgumentParser()
parser.add_argument("--n_sequences", type=int, default=20)
parser.add_argument("--model_name", type=str, default="Rostlab/prot_t5_xl_uniref50")
parser.add_argument("--tokenizer_name", type=str, default="Rostlab/prot_t5_xl_uniref50")
# an empty argument list makes argparse fall back to the defaults inside a notebook
args = parser.parse_args([])

# skbio.read yields Protein objects lazily; keep at most n_sequences of them
sequence_list = islice(
    skbio.read("test.fa", format="fasta", constructor=Protein), args.n_sequences
)


def embed_protein(sequence):
    return load_protein_t5_embedding(sequence, args.model_name, args.tokenizer_name)


# generator: each sequence is embedded only when embed_list is consumed
embed_list = (embed_protein(x) for x in sequence_list)
```
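Because `embed_list` is a generator expression, creating it runs no model code: each embedding is computed only when something iterates over it, and the generator can be iterated only once. The sketch below shows both properties with a hypothetical `fake_embed` stand-in, so it runs without ProtT5. In practice, if the write cell is rerun, this cell must be rerun first to recreate `embed_list`.

```python
calls = []


def fake_embed(seq):
    # hypothetical stand-in for load_protein_t5_embedding: records each call
    calls.append(seq)
    return len(seq)


sequences = ["MKV", "GAVL", "PQ"]
gen = (fake_embed(s) for s in sequences)
print(len(calls))  # 0: nothing has been embedded yet

first_pass = list(gen)  # consuming the generator triggers every call
print(first_pass, len(calls))  # [3, 4, 2] 3

second_pass = list(gen)  # the generator is now exhausted
print(second_pass)  # []
```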


## Write to file

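Here we write the embeddings to an HDF5 (`.h5`) file with `skbio.write`, using scikit-bio's `embed` format. Because `embed_list` is a generator, this is also the step at which the model actually runs on each sequence. To verify that the embeddings were stored correctly, we read the file back with `skbio.read` and print each embedding.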

```python
# Write the embeddings to an HDF5 file in scikit-bio's "embed" format.
# Known issue: the writer is given a generator here; other iterables
# (e.g. a list, or the result of map()) should ideally be accepted as well.
skbio.write(embed_list, format="embed", into="bagel.h5")

# test that the file was written correctly by reading it back and printing each embedding
for item in skbio.read("bagel.h5", format="embed", constructor=ProteinEmbedding):
    print(item.embedding)
```

Running the cell above prints a `T5Tokenizer` notice that the default legacy behaviour is being used; as the message itself states, this is expected and nothing changes (background: https://github.com/huggingface/transformers/pull/24565). The embeddings read back from `bagel.h5` are then printed, one array per sequence:

```text
[[ 0.05566912  0.05675846 -0.16483116 ...  0.05404978  0.01345288
   0.03236827]
 [ 0.14818218  0.12554237 -0.2660721  ...  0.1882809  -0.05386918
   0.09356216]
 [ 0.14772612  0.10723014 -0.21425001 ...  0.05648247  0.2655668
  -0.05122016]
 ...
 [ 0.12038261  0.15324979  0.13185279 ...  0.22860621 -0.15126939
  -0.09792379]
 [-0.02801923  0.04539812 -0.05980184 ...  0.07925158 -0.08836559
  -0.17490435]
 [ 0.09879396  0.21208242  0.15040213 ...  0.20142882 -0.02289281
  -0.08747597]]
[[-0.08464682 -0.12554063 -0.16223185 ...  0.21731049  0.5165856
  -0.03173729]
 [-0.09767824 -0.04928781 -0.1939477  ...  0.17679492  0.1372093
  -0.18649372]
 [-0.01484734 -0.09123899 -0.21456687 ...  0.14749224 -0.2751867
  -0.2023742 ]
 ...
 [-0.19169022  0.01500933  0.2335486  ...  0.28378057 -0.20650242
  -0.32116067]
 [ 0.11082172  0.00871864 -0.1059447  ...  0.097903   -0.28522074
  -0.00503905]
 [ 0.0339542  -0.21907638  0.08403985 ...  0.02584727 -0.15630914
  -0.45342624]]
```
