<a href="https://colab.research.google.com/github/tymor22/tm-vec/blob/master/google_colabs/Search_use_TM_Vec_search_to_search_for_related_sequences.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Notes:
1. To use TM-Vec and DeepBlast, you need to install TM-Vec, DeepBlast, and the Hugging Face transformers library.
2. You will also need to download the ProtT5-XL-UniRef50 encoder (the large protein language model that TM-Vec and DeepBlast use), the trained TM-Vec model, and the trained DeepBlast model. Because the ProtT5-XL-UniRef50 checkpoint is very large (~11.3 GB), you may have to use a CPU runtime on Google Colab unless your GPU has more memory than the model requires.
3. This notebook demonstrates how TM-Vec can be used to search large protein databases for proteins related to a set of query proteins.
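To judge whether the encoder will fit on your GPU, a back-of-envelope estimate helps. The sketch below assumes the ~11.3 GB checkpoint stores float32 weights (4 bytes per parameter); the helper names are hypothetical. Note that the checkpoint holds both encoder and decoder, while `T5EncoderModel` keeps only the encoder, so the loaded model is smaller than the file; activations during inference add memory on top of the weights. Whether TM-Vec works with a half-precision encoder is not covered in this notebook.

```python
def implied_param_count(checkpoint_gb: float, bytes_per_param: int = 4) -> float:
    """Number of parameters a checkpoint of this size holds (1 GB = 1e9 bytes)."""
    return checkpoint_gb * 1e9 / bytes_per_param


def weight_memory_gb(n_params: float, bytes_per_param: int) -> float:
    """Memory taken by the weights alone, in GB; activations come on top."""
    return n_params * bytes_per_param / 1e9


# Assumption: the ~11.3 GB checkpoint stores float32 weights (4 bytes each).
n_params = implied_param_count(11.3)      # roughly 2.8 billion parameters
fp32_gb = weight_memory_gb(n_params, 4)   # back to ~11.3 GB
fp16_gb = weight_memory_gb(n_params, 2)   # roughly half, about 5.65 GB
print(f"~{n_params / 1e9:.1f}B parameters: {fp32_gb:.1f} GB (fp32), {fp16_gb:.1f} GB (fp16)")
```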


<h3>Searching for related protein sequences using a trained TM-Vec model, and then aligning the related sequences using DeepBlast</h3>

**1. Install the relevant libraries, including tm-vec, deepblast, the Hugging Face transformers library, and FAISS**

In [None]:
%pip install -q git+https://github.com/tymor22/tm-vec.git
%pip install -q git+https://github.com/flatironinstitute/deepblast.git
%pip install -q SentencePiece transformers
%pip install -q faiss-cpu
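After installing, it can help to confirm that every package is importable before running the rest of the notebook. This sketch uses the standard library's `importlib.util.find_spec`, which returns `None` for a module that cannot be found. The module names are assumptions about each package's top-level import name (for example, `faiss-cpu` is imported as `faiss`).

```python
import importlib.util

# Assumed top-level import names for the packages installed above.
REQUIRED_MODULES = ["tm_vec", "deepblast", "transformers", "sentencepiece", "faiss"]


def missing_modules(names):
    """Return the module names that cannot be found in the current environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


missing = missing_modules(REQUIRED_MODULES)
if missing:
    print("Missing modules, re-run the install cell:", ", ".join(missing))
else:
    print("All required modules are importable.")
```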


<b>2. Load the relevant libraries</b>

In [None]:
# Standard library
import gc
import re

# Third-party libraries
import faiss
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
from transformers import T5EncoderModel, T5Tokenizer

# TM-Vec model definition and search utilities
from tm_vec.embed_structure_model import trans_basic_block, trans_basic_block_Config
from tm_vec.tm_vec_utils import featurize_prottrans, embed_tm_vec, encode, load_database, query

<b>3. Load the ProtT5-XL-UniRef50 tokenizer and model</b>

In [None]:
tokenizer = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_uniref50", do_lower_case=False)
model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_uniref50")
gc.collect()

Some weights of the model checkpoint at Rostlab/prot_t5_xl_uniref50 were not used when initializing T5EncoderModel: ['decoder.block.22.layer.2.DenseReluDense.wi.weight', 'decoder.block.11.layer.1.EncDecAttention.q.weight', ...]

586
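ProtT5 expects each residue as a separate, space-separated token, with the rare or ambiguous amino acids U, Z, O and B mapped to X, as described on the Rostlab ProtT5 model card. The `encode` helper from `tm_vec.tm_vec_utils` presumably handles this formatting internally; the sketch below (with a hypothetical helper name) only illustrates the preprocessing in case you call the tokenizer yourself.

```python
import re


def prepare_for_prott5(sequence: str) -> str:
    """Map rare/ambiguous residues (U, Z, O, B) to X and space-separate residues."""
    sequence = re.sub(r"[UZOB]", "X", sequence.upper())
    return " ".join(sequence)


batch = [prepare_for_prott5(s) for s in ["MKRESHKH", "MERPYUCPB"]]
print(batch)
# The formatted batch could then be tokenized, for example:
# tokenizer(batch, add_special_tokens=True, padding="longest", return_tensors="pt")
```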

<b>4. Move the model onto your GPU if one is available, and switch it to inference mode</b>

In [None]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)
model = model.to(device)
model = model.eval()

cpu


<b>5. Download a trained TM-Vec model, its configuration file, and a trained DeepBlast model</b>

In [None]:
!wget -q https://users.flatironinstitute.org/thamamsy/public_www/tm_vec_cath_model.ckpt
!wget -q https://users.flatironinstitute.org/thamamsy/public_www/tm_vec_cath_model_params.json
!wget -q https://users.flatironinstitute.org/jmorton/public_www/deepblast-public-data/checkpoints/deepblast-lstm4x.pt
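Re-running the cell above downloads the files again. If you restart the notebook often, a small guard that skips files already on disk can save time. This is a minimal sketch using only the standard library; the helper name is hypothetical, and the URLs are copied from the cell above.

```python
import os
import urllib.request

BASE = "https://users.flatironinstitute.org"

# Same files as the wget cell above.
FILES = {
    "tm_vec_cath_model.ckpt": f"{BASE}/thamamsy/public_www/tm_vec_cath_model.ckpt",
    "tm_vec_cath_model_params.json": f"{BASE}/thamamsy/public_www/tm_vec_cath_model_params.json",
    "deepblast-lstm4x.pt": f"{BASE}/jmorton/public_www/deepblast-public-data/checkpoints/deepblast-lstm4x.pt",
}


def download_if_missing(filename: str, url: str) -> bool:
    """Download url to filename unless a non-empty file already exists.

    Returns True if a download happened, False if it was skipped.
    """
    if os.path.exists(filename) and os.path.getsize(filename) > 0:
        return False
    urllib.request.urlretrieve(url, filename)
    return True


# Usage (performs network requests):
# for name, url in FILES.items():
#     download_if_missing(name, url)
```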

<b>6. Load the trained TM-Vec model</b>

In [None]:
#TM-Vec model paths
tm_vec_model_cpnt = "tm_vec_cath_model.ckpt"
tm_vec_model_config = "tm_vec_cath_model_params.json"

#Load the TM-Vec model
tm_vec_model_config = trans_basic_block_Config.from_json(tm_vec_model_config)
model_deep = trans_basic_block.load_from_checkpoint(tm_vec_model_cpnt, config=tm_vec_model_config)
model_deep = model_deep.to(device)
model_deep = model_deep.eval()

INFO:pytorch_lightning.utilities.migration.utils:Lightning automatically upgraded your loaded checkpoint from v1.5.8 to v1.9.0. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint --file tm_vec_cath_model.ckpt`
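As we understand TM-Vec's design, the model is trained so that the cosine similarity between two protein embeddings approximates the TM-score between the two structures; this is consistent with the search results in step 11, where each query's top hit scores about 1.0. The pure-Python sketch below computes cosine similarity for two toy vectors; the values are made up for illustration.

```python
import math


def cosine_similarity(u, v):
    """Cosine similarity between two equal-length, non-zero vectors."""
    if len(u) != len(v):
        raise ValueError("vectors must have the same length")
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    if norm_u == 0 or norm_v == 0:
        raise ValueError("cosine similarity is undefined for a zero vector")
    return dot / (norm_u * norm_v)


# Toy 4-dimensional vectors; real TM-Vec embeddings have many more dimensions.
emb_a = [0.1, 0.3, -0.2, 0.9]
emb_b = [0.2, 0.1, -0.1, 0.8]
print(f"predicted TM-score (toy example): {cosine_similarity(emb_a, emb_b):.3f}")
```

With the notebook's objects, `cosine_similarity(queries[0], queries[1])` would compare the two query embeddings from step 9, assuming `encode` returns one row per sequence.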


<b>7. Download one of our TM-Vec embedding databases and its associated metadata, or use your own (i.e., embed your own collection of protein sequences)</b>

In [None]:
!wget -q https://users.flatironinstitute.org/thamamsy/public_www/embeddings_cath_s100_final.npy
!wget -q https://users.flatironinstitute.org/thamamsy/public_www/embeddings_cath_s100_w_metadata.tsv

<b>8. Load or paste the sequences you would like to query the database with</b>

In [None]:
sequences = ["MKRESHKHAEQARRNRLAVALHELASLIPAEWKQQNVSAAPSKATTVEAACRYIRHLQQNGST","MERPYACPVESCDRRFSQSGSLTRHIRIHTGQ"]
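Sequences pasted from other sources often contain line breaks, spaces, lowercase letters, or stray characters. A quick check before embedding can catch these early. This is a minimal sketch with a hypothetical helper; the accepted alphabet (the 20 standard amino acids plus X, U, Z, O and B) is an assumption you may want to tighten.

```python
# 20 standard amino acids plus the rare/ambiguous codes X, U, Z, O, B.
VALID_RESIDUES = frozenset("ACDEFGHIKLMNPQRSTVWYXUZOB")


def check_sequences(raw_sequences):
    """Normalize pasted sequences and reject any with non-amino-acid characters.

    Returns the cleaned list; raises ValueError naming the first bad sequence.
    """
    cleaned = []
    for i, seq in enumerate(raw_sequences):
        seq = "".join(seq.split()).upper()  # drop spaces/newlines from pasting
        bad = set(seq) - VALID_RESIDUES
        if not seq or bad:
            raise ValueError(f"sequence {i} is empty or has invalid characters: {sorted(bad)}")
        cleaned.append(seq)
    return cleaned


cleaned = check_sequences(["MKRESHKHAEQ ARRNRLA\n", "merpyacpve"])
print(cleaned)
```

In the notebook, `sequences = check_sequences(sequences)` would normalize the list defined above.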

<b>9. Embed your query sequences using the same TM-Vec model that was used to build the embedding database</b>

In [None]:
queries = encode(sequences, model_deep, model, tokenizer, device)
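With many query sequences on a CPU runtime, embedding them all in one call may use a lot of memory. One option is to embed them in batches and stack the results. The sketch below only defines a chunking helper (hypothetical name); the commented usage assumes `encode` returns a 2-D NumPy array with one row per input sequence, and the batch size of 16 is arbitrary.

```python
from typing import Iterator, List, Sequence


def chunks(items: Sequence[str], size: int) -> Iterator[List[str]]:
    """Yield consecutive slices of at most `size` items."""
    if size < 1:
        raise ValueError("size must be positive")
    for start in range(0, len(items), size):
        yield list(items[start:start + size])


print(list(chunks(["SEQ1", "SEQ2", "SEQ3"], 2)))

# Hypothetical usage with the notebook's objects:
# queries = np.concatenate(
#     [encode(batch, model_deep, model, tokenizer, device) for batch in chunks(sequences, 16)]
# )
```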

<b>10. Load and index the lookup database</b>

In [None]:
# Load the database that we will query
# Make sure the lookup database was embedded with the same TM-Vec model used to embed the queries (e.g., the CATH model with the CATH database)
lookup_database = load_database("embeddings_cath_s100_final.npy")
metadata_for_lookup_database = pd.read_csv("embeddings_cath_s100_w_metadata.tsv", sep="\t")
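The internals of `load_database` and `query` are not shown in this notebook. Because the top score for each query in step 11 is about 1.0, a reasonable guess is that they perform an exact inner-product search over L2-normalized embeddings (so inner product equals cosine similarity), which is what a flat FAISS inner-product index computes. The pure-Python sketch below is a conceptual stand-in for that search, not the library's code; it returns `D` and `I` in the same per-query, best-first layout.

```python
import heapq
import math


def l2_normalize(vec):
    """Scale a vector to unit length so inner product equals cosine similarity."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec] if norm > 0 else list(vec)


def search(database, queries, k):
    """Exact top-k inner-product search over L2-normalized vectors.

    Returns (D, I): for each query, the k highest similarities and the
    row indices in `database` they came from, best first.
    """
    db = [l2_normalize(v) for v in database]
    D, I = [], []
    for q in queries:
        qn = l2_normalize(q)
        scores = ((sum(a * b for a, b in zip(qn, row)), idx) for idx, row in enumerate(db))
        top = heapq.nlargest(k, scores)
        D.append([score for score, _ in top])
        I.append([idx for _, idx in top])
    return D, I


toy_db = [[1.0, 0.0], [0.6, 0.8], [0.0, 1.0]]
D_toy, I_toy = search(toy_db, [[2.0, 0.0]], k=2)
print(D_toy, I_toy)
```

An exact flat search compares each query against every database row, which is simple and exact but scales linearly with database size.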

<b>11. Return the k nearest neighbors of each query sequence</b>

In [None]:
k = 10
D, I = query(lookup_database, queries, k)

In [None]:
print("TM scores for the nearest neighbors")
D

TM scores for the nearest neighbors


array([[1.0000001 , 0.81756675, 0.8130522 , 0.8094741 , 0.8080609 ,
        0.80032474, 0.7942065 , 0.7830712 , 0.7816005 , 0.78118867],
       [1.        , 0.94502103, 0.92700404, 0.92061406, 0.90921414,
        0.90611005, 0.9052031 , 0.84973043, 0.8436762 , 0.8435321 ]],
      dtype=float32)
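Each row of `D` holds predicted TM-scores for one query, and `I` holds the matching database row indices. TM-scores above about 0.5 are commonly taken to indicate that two proteins share the same fold, so a cutoff is a natural way to shortlist hits. This sketch (hypothetical helper) keeps only neighbors above a threshold; the scores are copied from the first query's output above, the first index (94) matches the metadata row shown below, and the other indices are made up for illustration.

```python
def hits_above(D, I, threshold=0.5):
    """For each query, keep (database_index, score) pairs with score >= threshold.

    D and I are row-aligned: D[q][j] is the score of neighbor I[q][j].
    """
    return [
        [(int(idx), float(score)) for score, idx in zip(scores, indices) if score >= threshold]
        for scores, indices in zip(D, I)
    ]


example_D = [[1.0000001, 0.81756675, 0.8130522]]
example_I = [[94, 1203, 57]]  # indices after 94 are illustrative only
print(hits_above(example_D, example_I, threshold=0.815))
```

The same call works on the notebook's NumPy arrays, for example `hits_above(D, I, threshold=0.5)`.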

In [None]:
# Get metadata for the top (nearest) neighbor of each query
near_meta = []
for i in range(I.shape[0]):
    meta = metadata_for_lookup_database.iloc[I[i, 0]]
    near_meta.append(meta)
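The loop above keeps only the single nearest neighbor per query. To inspect all k neighbors with their scores, you can flatten `D` and `I` into one record per (query, neighbor) pair. This sketch uses a hypothetical helper and a made-up two-row metadata table; like the loop above, it assumes metadata row j describes embedding row j.

```python
def neighbor_table(metadata_rows, D, I, columns=("Cath_ID", "Domain_length")):
    """Flatten (D, I) into one record per (query, neighbor) with metadata fields.

    metadata_rows[j] must describe row j of the embedding database.
    """
    records = []
    for q, (scores, indices) in enumerate(zip(D, I)):
        for rank, (score, idx) in enumerate(zip(scores, indices), start=1):
            meta = metadata_rows[int(idx)]
            record = {"query": q, "rank": rank, "tm_score": float(score)}
            record.update({col: meta[col] for col in columns})
            records.append(record)
    return records


# Hypothetical metadata and search result, for illustration only.
toy_meta = [{"Cath_ID": "dom0", "Domain_length": 63}, {"Cath_ID": "dom1", "Domain_length": 120}]
rows = neighbor_table(toy_meta, D=[[0.9, 0.4]], I=[[1, 0]])
for r in rows:
    print(r)
```

In the notebook, `metadata_for_lookup_database.to_dict("records")` would give the list of row dicts; alternatively, `metadata_for_lookup_database.iloc[I[0]]` alone returns the metadata of all k neighbors of the first query as a DataFrame.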

In [None]:
# Metadata for the first query's nearest neighbor
near_meta[0]

Cath_ID                                 1a0aA00
CATH_full               cath|4_3_0|1a0aA00/0-62
Cath_Domain                             1a0aA00
Class                                         4
Architecture                                 10
Topology                                    280
Homology                                     10
S35_cluster                                  15
S60_cluster                                   1
S95_cluster                                   1
S100_cluster                                  1
S100_count                                    1
Domain_length                                63
Structure_resolution                        2.8
Name: 94, dtype: object
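The Class, Architecture, Topology and Homology columns are the four levels of the CATH hierarchy; joining them with dots gives the usual CATH superfamily code (4.10.280.10 for the row above). `CATH_full` appears to combine the CATH release (4_3_0), the domain ID, and the residue range. This sketch uses hypothetical helpers and keeps the range as a string, since ranges can be more complex than a single start-end pair.

```python
def cath_code(row) -> str:
    """Join the Class/Architecture/Topology/Homology levels into a CATH code."""
    levels = ("Class", "Architecture", "Topology", "Homology")
    return ".".join(str(row[level]) for level in levels)


def parse_cath_full(label: str) -> dict:
    """Split a label like 'cath|4_3_0|1a0aA00/0-62' into release, domain and range."""
    _, release, rest = label.split("|")
    domain, residue_range = rest.split("/", 1)
    return {"release": release.replace("_", "."), "domain": domain, "range": residue_range}


# Values copied from the metadata row shown above.
row = {"Class": 4, "Architecture": 10, "Topology": 280, "Homology": 10,
       "CATH_full": "cath|4_3_0|1a0aA00/0-62"}
print(cath_code(row))  # 4.10.280.10
print(parse_cath_full(row["CATH_full"]))
```

Both helpers also accept a pandas row, so `cath_code(near_meta[0])` should work on the notebook's own output.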