# bindPredict21

In [1]:
import os
import sys

# Project root (for `config`, `src.*`) and the bindPredict checkout (for `bindEmbed21DL`,
# `architectures`, trained checkpoints). Set the BINDPREDICT_DIR environment variable to
# point at a different checkout without editing the notebook.
BASE_DIR = '..'
BINDPREDICT_DIR = os.environ.get(
    'BINDPREDICT_DIR',
    '/afs/csail.mit.edu/u/s/samsl/Work/Applications/bindPredict',
)

# Guarded appends: re-running this cell must not pile up duplicate sys.path entries.
for _extra_path in (BINDPREDICT_DIR, BASE_DIR):
    if _extra_path not in sys.path:
        sys.path.append(_extra_path)

In [2]:
from config import FileSetter, FileManager
from bindEmbed21DL import BindEmbed21DL
from pathlib import Path

In [3]:
import pandas as pd

# BindingDB training split; one row per drug-target pair.
train_csv = f'{BASE_DIR}/dataset/BindingDB/train.csv'
df = pd.read_csv(train_csv)

In [4]:
# Give each distinct target sequence a stable integer id (ids follow sorted sequence order).
seqDict = dict(enumerate(sorted(df['Target Sequence'].unique())))

In [5]:
# Dump every unique target as a FASTA record headed ">seq<id>"; read back by FileManager below.
fasta_records = ''.join(f">seq{k}\n{v}\n" for k, v in seqDict.items())
with open('tmp.fasta','w+') as f:
    f.write(fasta_records)

In [6]:
# Stem of the bindPredict CNN checkpoints; the loading cells append a suffix such as "5.pt".
model_prefix = f'{BINDPREDICT_DIR}/trained_models/checkpoint'
# Presumably where embeddings get saved (not used in the visible cells).
prediction_folder = f'{BASE_DIR}/saved_embeddings'

In [7]:
# Re-read the FASTA written above; ids are presumably the ">seq<N>" header names.
query_fasta = 'tmp.fasta'
query_sequences = FileManager.read_fasta(query_fasta)
query_ids = [seq_id for seq_id in query_sequences.keys()]
# Whether to write RI or Probabilities
ri = False

In [32]:
import torch
from architectures import CNN2Layers
from src.prot_feats import ProtT5_XL_Uniref50_f
from functools import lru_cache

In [13]:
# ProtT5 featurizer with pooling disabled; its output is fed into the bindPredict CNN below.
p5tf = ProtT5_XL_Uniref50_f(
    pool=False,
)

Some weights of the model checkpoint at Rostlab/prot_t5_xl_uniref50 were not used when initializing T5EncoderModel: ['decoder.embed_tokens.weight', 'decoder.block.0.layer.0.SelfAttention.q.weight', 'decoder.block.0.layer.0.SelfAttention.k.weight', 'decoder.block.0.layer.0.SelfAttention.v.weight', 'decoder.block.0.layer.0.SelfAttention.o.weight', 'decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight', 'decoder.block.0.layer.0.layer_norm.weight', 'decoder.block.0.layer.1.EncDecAttention.q.weight', 'decoder.block.0.layer.1.EncDecAttention.k.weight', 'decoder.block.0.layer.1.EncDecAttention.v.weight', 'decoder.block.0.layer.1.EncDecAttention.o.weight', 'decoder.block.0.layer.1.layer_norm.weight', 'decoder.block.0.layer.2.DenseReluDense.wi.weight', 'decoder.block.0.layer.2.DenseReluDense.wo.weight', 'decoder.block.0.layer.2.layer_norm.weight', 'decoder.block.1.layer.0.SelfAttention.q.weight', 'decoder.block.1.layer.0.SelfAttention.k.weight', 'decoder.block.1.layer.0.SelfAtte

In [18]:
# Rebuild the bindPredict CNN, load checkpoint 5, and keep only the first conv stage as an embedder.
# Positional args are CNN2Layers' own (see bindPredict's architectures module); 1024 matches the
# ProtT5 channel count used in the .view(1, 1024, -1) calls in this notebook.
md = CNN2Layers(1024,128,5,1,2,0)
# Load onto CPU first: map_location='cuda:1' fails on machines with fewer than two GPUs, and
# load_state_dict copies into the (CPU) model anyway before .cuda() moves it to the default device.
checkpoint = torch.load(f"{model_prefix}5.pt", map_location='cpu')
md.load_state_dict(checkpoint['state_dict'])
md = md.cuda().eval()
# First two modules of conv1 — presumably conv + activation; TODO confirm against CNN2Layers.
cnn_first = md.conv1[:2]

In [28]:
def seq_2_bindPredict_embedding(seq, use_cuda=True):
    """Embed one protein sequence with ProtT5 followed by the first bindPredict conv stage.

    Previously this cell called a function defined in a since-deleted cell, so the notebook
    failed on Restart & Run All. This definition mirrors
    ``BindPredict21_f._seq_2_bindPredict_embedding``, which produced identical output
    for the same sequence.

    Args:
        seq: amino-acid sequence string.
        use_cuda: accepted for parity with the class method; unused (device is fixed by
            where ``md`` / ``cnn_first`` were moved).

    Returns:
        1-D tensor: the conv output averaged over its last (length) axis.
    """
    with torch.no_grad():
        protbert_e = p5tf(seq)
        # NOTE(review): .view only reinterprets memory; if ProtT5 returns (L, 1024) rather than
        # (1024, L) this mixes residues and channels — confirm the shape (a transpose may be needed).
        bindpredict_e = cnn_first(protbert_e.view(1, 1024, -1))
        return bindpredict_e.mean(axis=2).squeeze()


seq_2_bindPredict_embedding('MVMAGCCCGGGGGMV')

tensor([ 1.1825e-01,  3.0741e-02,  9.0513e-02, -4.9664e-02, -8.9960e-02,
        -1.1921e-01, -4.2788e-02,  1.5489e-01,  8.5960e-02,  1.5032e-01,
         1.6159e-02, -1.9241e-02,  1.9506e-01,  8.6426e-02,  1.7518e-04,
         1.6553e-01,  2.0697e-01, -8.3605e-02, -7.1277e-02, -1.2898e-01,
         4.5713e-02, -3.0861e-02, -1.7620e-01,  1.5022e-01,  9.7889e-02,
         1.3669e-01,  9.4064e-02, -3.5105e-02, -5.2145e-02,  1.4234e-01,
         5.9036e-02, -1.6029e-02, -7.7992e-02,  1.2072e-02,  1.9149e-02,
         1.8343e-01,  4.7672e-02, -1.6139e-01,  1.5365e-01,  1.0299e-01,
         1.8208e-02,  3.9731e-02, -8.4341e-04, -8.1135e-02,  6.1327e-02,
         8.5733e-02,  5.9068e-02, -1.8384e-01,  3.0972e-02,  1.2526e-01,
         7.7584e-02,  1.9154e-01,  6.9722e-02,  3.1139e-02,  1.4358e-02,
        -3.6775e-02, -1.6053e-02,  2.1056e-01,  8.8420e-02, -5.4827e-02,
        -1.0726e-01,  8.5581e-03,  2.9816e-02, -5.3433e-02,  1.0815e-01,
        -6.2001e-02,  3.0967e-02,  1.2867e-01, -1.2

In [42]:
class BindPredict21_f:
    """Protein featurizer built on the first conv stage of bindPredict's CNN (checkpoint 5).

    Each sequence is embedded with ProtT5 (``pool=False``), passed through
    ``CNN2Layers.conv1[:2]`` and averaged over the last (length) axis, giving one
    ``self._size`` (128)-feature vector per sequence.

    Args:
        pool: kept for interface parity with the other featurizers. Because the embedding
            is already averaged over the length axis, pooling only averages a leftover
            per-residue axis if ``_embed`` returns one; it never collapses the feature axis.
        bindpredict_dir: bindPredict checkout providing ``architectures.CNN2Layers`` and
            ``trained_models/checkpoint5.pt``.
        device: torch device the CNN is moved to (default ``'cuda'`` = same as ``.cuda()``).
    """

    def __init__(self, pool=False,
                 bindpredict_dir='/afs/csail.mit.edu/u/s/samsl/Work/Applications/bindPredict',
                 device='cuda'):
        model_prefix = f'{bindpredict_dir}/trained_models/checkpoint'
        # architectures.py lives in the bindPredict checkout; don't re-append on every instance.
        if bindpredict_dir not in sys.path:
            sys.path.append(bindpredict_dir)
        from architectures import CNN2Layers

        self.use_cuda = torch.device(device).type == 'cuda'
        self.pool = pool
        self._size = 128
        self._max_len = 1024
        self.precomputed = False

        self._p5tf = ProtT5_XL_Uniref50_f(pool=False)
        self._md = CNN2Layers(1024,128,5,1,2,0)
        # Load on CPU first: map_location='cuda:1' fails without a second GPU, and the weights
        # are moved to `device` right after anyway.
        checkpoint = torch.load(f"{model_prefix}5.pt", map_location='cpu')
        self._md.load_state_dict(checkpoint['state_dict'])
        self._md = self._md.to(device).eval()
        self._cnn_first = self._md.conv1[:2]
        self._embed = self._seq_2_bindPredict_embedding

    def _seq_2_bindPredict_embedding(self, seq, use_cuda=True):
        """Embed one sequence: ProtT5 -> first CNN stage -> mean over the last axis.

        ``use_cuda`` is accepted because ``_transform`` passes it, but it is unused; the
        device is fixed when the model is built.
        """
        with torch.no_grad():
            protbert_e = self._p5tf(seq)
            # NOTE(review): .view(1, 1024, -1) only reinterprets memory. If ProtT5 returns
            # (L, 1024) rather than (1024, L), residues and channels get mixed and a transpose
            # is needed instead — confirm the shape.
            bindpredict_e = self._cnn_first(protbert_e.view(1, 1024, -1))
            return bindpredict_e.mean(axis=2).squeeze()

    def precompute(self, seqs, to_disk_path=True, from_disk=True):
        """Embed every sequence once so later calls are dictionary lookups.

        Args:
            seqs: iterable of protein sequences; duplicates are embedded once.
            to_disk_path: path prefix for a pickle cache. ``None`` or any bool (including
                the legacy default ``True``) disables disk caching; previously a bool was
                formatted into the filename, producing ``True_BindPredict21_f_...pk``.
            from_disk: if the cache file exists, load it instead of recomputing.
        """
        import os
        import pickle as pk
        try:
            from tqdm.auto import tqdm
        except ImportError:  # progress bar is optional
            def tqdm(iterable):
                return iterable

        print("--- precomputing BindPredict21 protein featurizer ---")
        assert not self.precomputed

        if to_disk_path is None or isinstance(to_disk_path, bool):
            precompute_path = None
        else:
            precompute_path = f"{to_disk_path}_BindPredict21_f_PROTEINS{'_STACKED' if not self.pool else ''}.pk"

        if from_disk and precompute_path is not None and os.path.exists(precompute_path):
            print("--- loading from disk ---")
            # pickle can execute arbitrary code: only load cache files you wrote yourself.
            with open(precompute_path, "rb") as fh:
                self.prot_embs = pk.load(fh)
        else:
            self.prot_embs = {}
            for sq in tqdm(seqs):
                if sq in self.prot_embs:
                    continue
                self.prot_embs[sq] = self._transform(sq)

            if precompute_path is not None and not os.path.exists(precompute_path):
                print(f'--- saving protein embeddings to {precompute_path} ---')
                with open(precompute_path, "wb+") as fh:
                    pk.dump(self.prot_embs, fh)
        self.precomputed = True

    @lru_cache(maxsize=5000)
    def _transform(self, seq):
        """Embed ``seq`` truncated to ``self._max_len`` residues; memoized per (instance, seq).

        NOTE: lru_cache on a method keeps ``self`` alive while cached and returns the same
        tensor object for repeated sequences — don't mutate results in place.
        """
        seq = seq[:self._max_len]
        with torch.no_grad():
            lm_emb = self._embed(seq, use_cuda=self.use_cuda).squeeze()
            # The default embedder already averages over the length axis, so only pool when a
            # per-residue axis remains; mean(axis=0) on the 1-D vector would collapse every
            # feature into a single scalar.
            if self.pool and lm_emb.dim() > 1:
                return lm_emb.mean(axis=0)
            return lm_emb

    def __call__(self, seq):
        """Return the embedding for ``seq``, from the precomputed table when available."""
        if self.precomputed:
            return self.prot_embs[seq]
        return self._transform(seq)

In [43]:
# Build the featurizer (loads ProtT5 and the bindPredict CNN); pool stays at its default.
bp_f = BindPredict21_f(pool=False)

Some weights of the model checkpoint at Rostlab/prot_t5_xl_uniref50 were not used when initializing T5EncoderModel: ['decoder.embed_tokens.weight', 'decoder.block.0.layer.0.SelfAttention.q.weight', 'decoder.block.0.layer.0.SelfAttention.k.weight', 'decoder.block.0.layer.0.SelfAttention.v.weight', 'decoder.block.0.layer.0.SelfAttention.o.weight', 'decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight', 'decoder.block.0.layer.0.layer_norm.weight', 'decoder.block.0.layer.1.EncDecAttention.q.weight', 'decoder.block.0.layer.1.EncDecAttention.k.weight', 'decoder.block.0.layer.1.EncDecAttention.v.weight', 'decoder.block.0.layer.1.EncDecAttention.o.weight', 'decoder.block.0.layer.1.layer_norm.weight', 'decoder.block.0.layer.2.DenseReluDense.wi.weight', 'decoder.block.0.layer.2.DenseReluDense.wo.weight', 'decoder.block.0.layer.2.layer_norm.weight', 'decoder.block.1.layer.0.SelfAttention.q.weight', 'decoder.block.1.layer.0.SelfAttention.k.weight', 'decoder.block.1.layer.0.SelfAtte

In [44]:
# Sanity check: output should match the standalone embedding computed earlier for this sequence.
demo_seq = 'MVMAGCCCGGGGGMV'
bp_f(demo_seq)

tensor([ 1.1825e-01,  3.0741e-02,  9.0513e-02, -4.9664e-02, -8.9960e-02,
        -1.1921e-01, -4.2788e-02,  1.5489e-01,  8.5960e-02,  1.5032e-01,
         1.6159e-02, -1.9241e-02,  1.9506e-01,  8.6426e-02,  1.7518e-04,
         1.6553e-01,  2.0697e-01, -8.3605e-02, -7.1277e-02, -1.2898e-01,
         4.5713e-02, -3.0861e-02, -1.7620e-01,  1.5022e-01,  9.7889e-02,
         1.3669e-01,  9.4064e-02, -3.5105e-02, -5.2145e-02,  1.4234e-01,
         5.9036e-02, -1.6029e-02, -7.7992e-02,  1.2072e-02,  1.9149e-02,
         1.8343e-01,  4.7672e-02, -1.6139e-01,  1.5365e-01,  1.0299e-01,
         1.8208e-02,  3.9731e-02, -8.4341e-04, -8.1135e-02,  6.1327e-02,
         8.5733e-02,  5.9068e-02, -1.8384e-01,  3.0972e-02,  1.2526e-01,
         7.7584e-02,  1.9154e-01,  6.9722e-02,  3.1139e-02,  1.4358e-02,
        -3.6775e-02, -1.6053e-02,  2.1056e-01,  8.8420e-02, -5.4827e-02,
        -1.0726e-01,  8.5581e-03,  2.9816e-02, -5.3433e-02,  1.0815e-01,
        -6.2001e-02,  3.0967e-02,  1.2867e-01, -1.2

# NetSurfP-3.0

In [128]:
import sys

BASE_DIR = '..'
# NOTE(review): this points at the bindPredict checkout, not a NetSurfP install — looks like a
# copy-paste from the bindPredict section. NetSurfP-3.0 is loaded through biolib in the next
# cell, so this path may be unnecessary; confirm and fix or drop.
NETSURF_DIR = '/afs/csail.mit.edu/u/s/samsl/Work/Applications/bindPredict'

# Guarded appends: earlier cells may already have added these paths.
for _extra_path in (NETSURF_DIR, BASE_DIR):
    if _extra_path not in sys.path:
        sys.path.append(_extra_path)

In [52]:
import biolib

# biolib application identifier for NetSurfP-3.0.
NETSURFP3_APP = 'DTU/NetSurfP_3'
nsp3 = biolib.load(NETSURFP3_APP)

2022-06-22 15:25:06,737 | INFO : Loaded project DTU/NetSurfP-3:0.0.2
