In [1]:
import tensorflow as tf
from models import MQAModel

from validate_performance_on_xtals import process_strucs, predict_on_xtals
import sys, os
import mdtraj as md
from glob import glob
import numpy as np

2022-05-19 22:30:54.379781: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0
2022-05-19 22:30:56.436321: I tensorflow/compiler/jit/xla_cpu_device.cc:41] Not creating XLA devices, tf_xla_enable_xla_devices not set
2022-05-19 22:30:56.437535: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcuda.so.1
2022-05-19 22:30:56.833467: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1720] Found device 0 with properties: 
pciBusID: 0001:00:00.0 name: Tesla V100-PCIE-16GB computeCapability: 7.0
coreClock: 1.38GHz coreCount: 80 deviceMemorySize: 15.78GiB deviceMemoryBandwidth: 836.37GiB/s
2022-05-19 22:30:56.834619: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1720] Found device 1 with properties: 
pciBusID: 0002:00:00.0 name: Tesla V100-PCIE-16GB computeCapability: 7.0
coreClock: 1.38GHz coreCount: 80 deviceMemorySize: 15.78GiB deviceMemoryBandwidth: 836.37GiB/s
2022-

### Get prediction labels

In [25]:
def make_predictions(strucs, model, nn_path):
    '''
    strucs : list of single frame MDTraj trajectories
    model : MQAModel corresponding to network in nn_path
    nn_path : path to checkpoint files
    '''
    X, S, mask = process_strucs(strucs)
    predictions = predict_on_xtals(model, nn_path, X, S, mask)
    return predictions

In [26]:
pdb_file_path = '../data/ACE2.pdb'
model_path = '../models/1646754348_001'
output_path = '../outputs'

##### Load the input PDB file and check for exceptions in the format

In [28]:
strucs = [md.load(pdb_file_path)]

#### Look for exceptions
# check if PDB file is malformed, return error
# check how many chains.... if more than 1 chain, also send message: only 1 chain supported
# info from non-protein ligands is not incorporated

##### Load model and get predictions

In [50]:
# create a MQA model
DROPOUT_RATE = 0.1
NUM_LAYERS = 4
HIDDEN_DIM = 100

# MQA Model used for selected NN network
model = MQAModel(node_features=(8, 50), edge_features=(1, 32), hidden_dim=(16, HIDDEN_DIM),
                     num_layers=NUM_LAYERS, dropout=DROPOUT_RATE)

predictions = make_predictions(strucs, model, model_path)

np.savetxt(os.path.join(output_path,'predictions.txt'), predictions, fmt='%.4g', delimiter='\n')


CHECKPOINT RESTORED FROM ../models/1646754348_001


In [51]:
predictions.shape

TensorShape([1, 596])

### Generate output PDB file

In [52]:
from Bio.PDB import PDBParser, PDBIO
import warnings

#warnings.filterwarnings("ignore")

residue_types = {"ALA" : "A" , "ARG" : "R" , "ASN" : "N" , "ASP" : "D" , "CYS" : "C" , "CYM" : "C", "GLU" : "E" , 
          "GLN" : "Q" , "GLY" : "G" , "HIS" : "H" , "ILE" : "I" , "LEU" : "L" , "LYS" : "K" , "MET" : "M" , 
          "PHE" : "F" , "PRO" : "P" , "SER" : "S" , "THR" : "T" , "TRP" : "W" , "TYR" : "Y" , "VAL" : "V"}

In [53]:
p = PDBParser()
structure = p.get_structure("ACE2", pdb_file_path)

# get only the first chain as that's the one where we have predictions
for chain in structure.get_chains():
    break

predictions = predictions.numpy()
num_res = 0
for res in chain.get_residues():
    if res.get_resname() in residue_types:
        for atom in res.get_atoms():
            atom.set_bfactor(predictions[0,num_res])
            #print(atom.get_bfactor())
        num_res = num_res+1
    
io = PDBIO()
io.set_structure(structure)
io.save(os.path.join(output_path,"out.pdb"))   #, preserve_atom_numbering = True



In [54]:
print(predictions)

[[0.48711956 0.56252795 0.46445793 0.18135533 0.44590116 0.3502969
  0.06687918 0.30247313 0.33985758 0.19999526 0.30081442 0.17988338
  0.21175835 0.44905275 0.32371986 0.30504623 0.23767522 0.435769
  0.43880817 0.39842495 0.47724032 0.51145244 0.47854042 0.27030686
  0.15423127 0.29797772 0.29912856 0.07981896 0.2522303  0.40138686
  0.20876311 0.19166611 0.23305494 0.32099757 0.10099658 0.21530071
  0.13083641 0.08977848 0.04803484 0.0646581  0.17444883 0.05745269
  0.09031501 0.1251416  0.21056415 0.08988038 0.17424329 0.21288955
  0.28903753 0.32426172 0.52118325 0.51214886 0.39083317 0.26220813
  0.52058667 0.60489964 0.24156858 0.05961951 0.4739507  0.5618186
  0.21233258 0.02833413 0.47707966 0.47276607 0.43228358 0.48979425
  0.5690386  0.46660158 0.1575257  0.34963158 0.33921582 0.3204312
  0.660057   0.6517522  0.27138853 0.48501074 0.6737432  0.61947596
  0.02133246 0.5648836  0.67542696 0.68322057 0.40329605 0.70807534
  0.53097    0.578135   0.5586743  0.6288768  0.65733