In [1]:
import json
import os
import pickle
import random

import esm
import matplotlib.pyplot as plt
import numpy as np
import torch
from scipy.special import softmax
from tqdm import tqdm

from utils import *

In [2]:
# Protein identifiers to process; the batch loop at the end of the notebook iterates over this object.
with open('../data/selected_protein.json', 'r') as file:
    selected_protein = json.loads(file.read())

In [8]:
# Expose only physical GPU 1 to this process (it then appears as 'cuda:0'); only effective if CUDA has not been initialised yet.
os.environ.update(CUDA_VISIBLE_DEVICES='1')

In [9]:
# Load the pretrained ESM-2 checkpoint `esm2_t36_3B_UR50D` together with its token alphabet.
# `alphabet` provides the batch converter, the token list (`all_toks`) and `mask_idx` used by the helpers below.
model, alphabet = esm.pretrained.esm2_t36_3B_UR50D()

In [10]:
# Prefer the first visible CUDA device; fall back to the CPU when none is available.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
# Move the weights to `device` and switch to evaluation mode (eval() returns the module itself).
model = model.to(device).eval()

# MSA alphabet aligned with the model's 20 amino-acid logit columns (vocabulary indices 4..23), plus '-' for gaps.
esm_alphabet = "".join(alphabet.all_toks[4:24]) + "-"

In [11]:
def get_logits(seq):
  """Return unmasked ESM-2 logits for `seq`, restricted to the 20 amino-acid columns.

  A single forward pass is run on the tokenised sequence. Token position 0 is
  skipped (rows 1..len(seq) are kept) and only vocabulary columns 4..23 are returned.

  Returns:
    numpy array of shape (len(seq), 20), assuming one token per residue.
  """
  n_res = len(seq)
  tokens = alphabet.get_batch_converter()([("seq", seq)])[-1]
  with torch.no_grad():
    all_logits = model(tokens.to(device))["logits"]
    return all_logits[0, 1:n_res + 1, 4:24].cpu().numpy()

def get_masked_logits(seq, p=None, get_pll=False):
  """Masked-marginal logits for every position of `seq`, or its pseudo-log-likelihood.

  For each residue position i, a copy of the tokenised sequence with token i+1
  replaced by `alphabet.mask_idx` is run through the model. The logits at that
  masked position (vocabulary columns 4..23) are collected as row i.

  Args:
    seq: protein sequence string.
    p: number of masked copies per forward pass (the batch size). Defaults to
      len(seq), which masks every position in one pass; lower it if that batch
      does not fit in device memory. Values < 1 are treated as 1.
    get_pll: if True, return the pseudo-log-likelihood
      sum_i log_softmax(logits[i])[token_i - 4] instead of the logits.

  Returns:
    float64 numpy array of shape (len(seq), 20), or a numpy scalar when
    `get_pll` is True. An empty `seq` yields a (0, 20) array, or a PLL of 0.0.

  Note:
    The PLL maps each residue token id to column (token_id - 4). A residue
    outside vocabulary indices 4..23 (e.g. 'X') therefore raises IndexError.
  """
  from scipy.special import log_softmax  # numerically stabler than log(softmax(.))

  x, ln = alphabet.get_batch_converter()([(None, seq)])[-1], len(seq)
  # Clamp to >= 1: range() rejects a zero step (e.g. an empty seq with p=None).
  step = max(ln if p is None else p, 1)
  logits = np.zeros((ln, 20))
  with torch.no_grad():
    for start in range(0, ln, step):
      stop = min(start + step, ln)
      rows = torch.arange(stop - start)
      # Row r masks residue start + r, i.e. token start + r + 1 (logits below drop token 0).
      batch = x.repeat(stop - start, 1)
      batch[rows, rows + start + 1] = alphabet.mask_idx
      out = model(batch.to(device))["logits"][:, 1:(ln + 1), 4:24]
      rows_dev = rows.to(out.device)
      # Keep only each row's logits at its own masked position.
      logits[start:stop] = out[rows_dev, rows_dev + start].cpu().numpy()
  if get_pll:
    log_probs = log_softmax(logits, axis=-1)
    aa_idx = x[0, 1:(ln + 1)].numpy() - 4
    return log_probs[np.arange(ln), aa_idx].sum()
  return logits

In [12]:
def get_categorical_jacobian(seq):
  """Finite-difference "categorical Jacobian" of the model's amino-acid logits.

  Every position of `seq` is substituted, in turn, with each of the 20 tokens
  4..23 (one batch of 20 variants per position). The result is

      jac[i, a, j, b] = logits_wt[j, b] - logits_mut(i -> a)[j, b]

  where logits are restricted to rows 1..len(seq) and vocabulary columns 4..23.

  Returns:
    float64 numpy array of shape (len(seq), 20, len(seq), 20).
  """
  n_res = len(seq)
  tokens = alphabet.get_batch_converter()([("seq", seq)])[-1]
  aa_tokens = torch.arange(4, 24)  # the 20 amino-acid token ids
  with torch.no_grad():
    def logits_np(batch):
      return model(batch)["logits"][..., 1:(n_res + 1), 4:24].cpu().numpy()

    wild_type = logits_np(tokens.to(device))[0]
    variants_base = tokens.repeat(20, 1).to(device)
    mutant = np.zeros((n_res, 20, n_res, 20))
    for pos in range(n_res):
      variants = variants_base.clone()
      variants[:, pos + 1] = aa_tokens  # row a carries amino-acid token 4 + a at this position
      mutant[pos] = logits_np(variants)
    # wild_type (L, 20) broadcasts against the trailing two axes of mutant.
    return wild_type - mutant

In [13]:
def _dump_pickle(obj, path):
    """Pickle `obj` to `path`, overwriting any existing file."""
    with open(path, 'wb') as f:
        pickle.dump(obj, f)


def get_data(pdb, msa_dir='msa', data_dir='../data', results_dir='../results'):
    """Compute and save MSA-based and ESM-2-Jacobian-based coupling data for one protein.

    Reads `<msa_dir>/<pdb>.a3m` and writes four pickles:
      <data_dir>/inv_cov_msa/<pdb>_inv_cov_msa.pkl            full result of inv_cov_jax(msa)
      <data_dir>/msa_contact/<pdb>_msa_contact.pkl            its 'apc' entry
      <results_dir>/esm2_jac/<pdb>_esm2_jac.pkl                centred, symmetrised categorical Jacobian
      <results_dir>/esm2_jac_contact/<pdb>_esm2_jac_contact.pkl  get_contacts() of that Jacobian

    The defaults reproduce the notebook's original paths.
    NOTE(review): `msa_dir` is CWD-relative 'msa', unlike the '../data' outputs; confirm this is intended.
    """
    paths = {
        'inv_cov': os.path.join(data_dir, 'inv_cov_msa', pdb + '_inv_cov_msa.pkl'),
        'msa_contact': os.path.join(data_dir, 'msa_contact', pdb + '_msa_contact.pkl'),
        'jac': os.path.join(results_dir, 'esm2_jac', pdb + '_esm2_jac.pkl'),
        'jac_contact': os.path.join(results_dir, 'esm2_jac_contact', pdb + '_esm2_jac_contact.pkl'),
    }
    # Create every output folder before computing anything, so a missing directory
    # can never throw away the slow Jacobian step (one forward pass per residue).
    for path in paths.values():
        os.makedirs(os.path.dirname(path), exist_ok=True)

    # MSA statistics.
    _, seqs = parse_fasta(os.path.join(msa_dir, pdb + '.a3m'), a3m=True)
    msa = mk_msa(seqs, alphabet=esm_alphabet)
    seq = seqs[0]  # the first a3m record (presumably the query) is the sequence fed to ESM-2
    msa_stats = inv_cov_jax(msa)
    print(msa_stats['apc'].shape)
    _dump_pickle(msa_stats, paths['inv_cov'])
    _dump_pickle(msa_stats['apc'], paths['msa_contact'])

    # ESM-2 categorical Jacobian, shape (L, 20, L, 20).
    jac = get_categorical_jacobian(seq)
    # Centre: subtract the mean along each of the four axes in turn.
    for axis in range(4):
        jac -= jac.mean(axis, keepdims=True)
    # Symmetrise (i, a, j, b) <-> (j, b, i, a).
    jac = (jac + jac.transpose(2, 3, 0, 1)) / 2
    print(jac.shape)
    _dump_pickle(jac, paths['jac'])
    _dump_pickle(get_contacts(jac), paths['jac_contact'])

In [None]:
# Run get_data (MSA statistics + ESM-2 categorical Jacobian, all written to disk) for each item
# obtained by iterating `selected_protein` (its elements if it is a list, its keys if a dict).
# NOTE(review): an exception on one protein aborts the whole run; consider per-protein error
# handling and skipping proteins whose output pickles already exist, so reruns are restartable.
for pdb in tqdm(selected_protein):
    get_data(pdb)

  0%|          | 0/237 [00:00<?, ?it/s]

4M2MA


No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
