In [32]:
import numpy as np
import pandas as pd

import torch
from torch.utils.data import Dataset

from embed_structure_model import trans_basic_block, trans_basic_block_Config
from tm_vec_utils import featurize_prottrans, embed_tm_vec, encode, load_database, query

from transformers import T5EncoderModel, T5Tokenizer
import re
import gc

from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import seaborn as sns 
from Bio import SeqIO
import time

In [2]:
#Choose the compute device: the first CUDA GPU when one is available, otherwise the CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

#Load the ProtTrans tokenizer and the ProtTrans encoder model from the same pretrained checkpoint
tokenizer = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_uniref50", do_lower_case=False)
model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_uniref50")
gc.collect()  #Run a garbage-collection pass once the model has been loaded

#Move the encoder onto the chosen device and put it in evaluation mode
model = model.to(device).eval()

Some weights of the model checkpoint at Rostlab/prot_t5_xl_uniref50 were not used when initializing T5EncoderModel: ['decoder.block.17.layer.2.layer_norm.weight', 'decoder.block.6.layer.2.layer_norm.weight', 'decoder.block.20.layer.0.SelfAttention.q.weight', 'decoder.block.1.layer.1.layer_norm.weight', 'decoder.block.2.layer.0.SelfAttention.k.weight', 'decoder.block.2.layer.0.SelfAttention.o.weight', 'decoder.block.10.layer.1.EncDecAttention.v.weight', 'decoder.block.10.layer.0.SelfAttention.v.weight', 'decoder.block.22.layer.1.EncDecAttention.k.weight', 'decoder.block.6.layer.2.DenseReluDense.wi.weight', 'decoder.block.3.layer.1.layer_norm.weight', 'decoder.block.15.layer.0.SelfAttention.o.weight', 'decoder.block.0.layer.1.EncDecAttention.v.weight', 'decoder.block.15.layer.1.EncDecAttention.q.weight', 'decoder.block.21.layer.0.SelfAttention.q.weight', 'decoder.block.20.layer.2.DenseReluDense.wo.weight', 'decoder.block.13.layer.1.layer_norm.weight', 'decoder.block.0.layer.1.layer_norm.

In [3]:
#Locations of the TM-Vec checkpoint and of the JSON file holding its parameters
tm_vec_model_cpnt = "/mnt/home/thamamsy/public_www/tm_vec_cath_model_large.ckpt"
tm_vec_model_config_path = "/mnt/home/thamamsy/public_www/tm_vec_cath_model_large_params.json"

#Build the TM-Vec config from the JSON file, then restore the model from the checkpoint with that config
tm_vec_model_config = trans_basic_block_Config.from_json(tm_vec_model_config_path)
model_deep = trans_basic_block.load_from_checkpoint(tm_vec_model_cpnt, config=tm_vec_model_config)

#Same placement as the ProtTrans model: move to `device`, then switch to evaluation mode
model_deep = model_deep.to(device).eval()

In [13]:
#Read every FASTA record, collecting record ids and sequences (as plain strings) in two parallel lists
sequence_path = '/mnt/home/thamamsy/public_www/cath-domain-seqs-large.fa'
record_ids, record_seqs = [], []
with open(sequence_path) as fasta_handle:
    #Stream (id, sequence) pairs so that parsed records are not all held in memory at once
    id_seq_pairs = ((entry.id, str(entry.seq)) for entry in SeqIO.parse(fasta_handle, "fasta"))
    for entry_id, entry_seq in id_seq_pairs:
        record_ids.append(entry_id)
        record_seqs.append(entry_seq)

In [30]:
#Encode the first 10K sequences (the "queries"); they are searched against the lookup database in a later cell
#NOTE(review): the database size was given here as 1M sequences, while the database-loading cell says 500K - confirm which is right
encode_started = time.time()

query_seqs = record_seqs[:10000]
queries = encode(query_seqs, model_deep, model, tokenizer, device)

#Report the wall-clock time spent encoding
encode_elapsed = time.time() - encode_started
print("--- %s seconds ---" % encode_elapsed)

--- 483.4614632129669 seconds ---


In [34]:
#Load the lookup database that the queries will be searched against, together with its metadata table
#(the database is described as holding 500K sequences - not confirmed by anything in this cell)
#The lookup database must have been encoded with the same model that is applied to the queries (i.e. CATH model and CATH database)
load_started = time.time()

lookup_database = load_database("/mnt/home/thamamsy/public_www/embeddings_cath_s100_final.npy")
metadata_for_lookup_database = pd.read_csv(
    "/mnt/home/thamamsy/public_www/embeddings_cath_s100_w_metadata.tsv",
    sep="\t",
)

#Report the wall-clock time spent loading both files
load_elapsed = time.time() - load_started
print("--- %s seconds ---" % load_elapsed)

Lookup Database dimensions: 
--- 0.39347171783447266 seconds ---


In [35]:
#Search the lookup database with the encoded queries
search_started = time.time()

k = 10 #Number of nearest neighbors to return for every query
#NOTE(review): D and I are presumably the neighbor distances and neighbor indices - confirm against tm_vec_utils.query
D, I = query(lookup_database, queries, k)

#Report the wall-clock time spent searching
search_elapsed = time.time() - search_started
print("--- %s seconds ---" % search_elapsed)

--- 29.233018398284912 seconds ---
