In [1]:
import sys, os
from constants import *
import torch, esm, scipy
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
from tqdm import tqdm
from Bio.PDB import PDBParser, PPBuilder
import matplotlib.patches as patches

# Seed NumPy's global RNG so NumPy-based randomness in this notebook is repeatable.
# NOTE(review): only NumPy is seeded here; torch's RNG is left unseeded.
np.random.seed(1024)
import pickle

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
# Load the pretrained ESM-2 model together with its alphabet, then move it to
# `device` and switch it to evaluation mode.
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
batch_converter = alphabet.get_batch_converter()
model.to(device)
model.eval()

# One-letter codes of the 20 amino acids that are scored.
AAs = "ACDEFGHIKLMNPQRSTVWY"

# Amino acid -> token index in the ESM alphabet (kept in ESM vocabulary order).
aa2indx = {
    tok: idx
    for idx, tok in enumerate(alphabet.all_toks)
    if tok in AAs
}
# Compact 0..N-1 numbering of the same amino acids, in the same order as the
# keys of `aa2indx`; `aa2id_np` is the inverse lookup (amino acid -> compact id).
id2aa = dict(enumerate(aa2indx))
aa2id_np = {aa: i for i, aa in enumerate(aa2indx)}

# VR1 region: slice bounds into `seq_wt` and the wild-type residues they cover.
VR_start = 253
VR_end = 274
VRi = 'NHLYKQISNSTSGGSSNDNAY' #seq_wt[253:274]
assert seq_wt[VR_start:VR_end] == VRi

In [3]:
def get_ESM2_probs_VR1(motif, pep):
    """Mean ESM-2 probability of the wild-type VR1 residues after grafting `pep`.

    The first occurrence of the module-level `VR` segment in `seq_wt` is
    replaced by `pep`. The resulting sequence is run through the ESM-2 model,
    the logits are restricted to the amino-acid columns listed in `aa2indx`
    and soft-maxed, and the probabilities assigned to the wild-type residues
    `VRi` at positions `VR_start:VR_end` are averaged.

    Relies on names defined outside this function: `seq_wt`, `VR`, `VRi`,
    `VR_start`, `VR_end`, `model`, `alphabet`-derived `batch_converter`,
    `aa2indx`, `aa2id_np` and `device`.

    Parameters
    ----------
    motif : str
        Expected core of `pep`; must equal `pep[2:-2]`.
    pep : str
        Replacement segment spliced into `seq_wt` in place of `VR`.

    Returns
    -------
    float
        Average, over the VR1 positions, of the probability the model assigns
        to the wild-type residue at that position.

    Raises
    ------
    AssertionError
        If `pep[2:-2] != motif`, or if the VR1 window of the mutated sequence
        no longer equals `VRi` (e.g. the splice shifted or overwrote it).
    ValueError
        If `VR` does not occur in `seq_wt` (raised by `.index`).
    """
    # Import the submodule explicitly: a bare `import scipy` does not guarantee
    # that `scipy.special` has been loaded on every SciPy version.
    from scipy.special import softmax

    assert pep[2:-2] == motif, f"pep {pep!r} does not wrap motif {motif!r}"

    # Locate the segment to replace (single scan) and splice the peptide in.
    start = seq_wt.index(VR)
    end = start + len(VR)
    seq_mu = seq_wt[:start] + pep + seq_wt[end:]

    # Fail fast, before the expensive forward pass: the VR1 window must be
    # untouched by the splice, otherwise the indices below would be wrong.
    assert seq_mu[VR_start:VR_end] == seq_wt[VR_start:VR_end] == VRi, \
        "VR1 window was altered or shifted by the peptide insertion"

    data = [("seq", seq_mu)]
    _, _, batch_tokens = batch_converter(data)
    with torch.no_grad():
        results = model(batch_tokens.to(device))

    # Keep only the amino-acid columns and drop the first/last token positions
    # (presumably the special start/end tokens added by the batch converter --
    # confirm). The batch axis is removed in a separate step on purpose: fusing
    # the integer index with the list index would change the axis order.
    aa_columns = list(aa2indx.values())
    logits = results['logits'][:, 1:-1, aa_columns].cpu()[0].numpy()
    probs = softmax(logits, axis=-1)

    # Probability of the wild-type residue at each VR1 position.
    vr1_probs = probs[VR_start:VR_end]
    wt_probs = [float(vr1_probs[i, aa2id_np[aa]]) for i, aa in enumerate(VRi)]
    return float(np.average(wt_probs))

In [4]:
# For every motif, score each peptide in its table and write the table back to
# the same CSV with an added `ESM2_mean_VR1` column.
for p, motif in motifs.items():
    csv_path = f'../ESM2_MPNN_score/score_rosetta_ESM2_MPNN-BBonly_{p}.csv'
    df = pd.read_csv(csv_path)

    peptides = df['VR4_seq'].tolist()
    scores = [get_ESM2_probs_VR1(motif, pep) for pep in tqdm(peptides)]

    df['ESM2_mean_VR1'] = scores
    df.to_csv(csv_path, index=False)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 715/715 [02:05<00:00,  5.70it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 577/577 [01:42<00:00,  5.64it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 600/600 [01:46<00:00,  5.61it/s]
