In [1]:
import tensorflow as tf
from models import MQAModel

from validate_performance_on_xtals import process_strucs, predict_on_xtals
import sys, os
import mdtraj as md
from glob import glob
import numpy as np

2023-11-01 02:36:37.763198: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


### Get prediction labels

In [2]:
def make_predictions(strucs, model, nn_path):
    '''
    Featurize a batch of structures and run the checkpointed network on them.

    strucs : list of single frame MDTraj trajectories
    model : MQAModel corresponding to network in nn_path
    nn_path : path to checkpoint files

    Returns the value produced by ``predict_on_xtals`` for this batch.
    '''
    # process_strucs yields three arrays (coordinates/sequence/mask going by the
    # original names X, S, mask -- NOTE(review): confirm in validate_performance_on_xtals).
    coords, seqs, masks = process_strucs(strucs)
    return predict_on_xtals(model, nn_path, coords, seqs, masks)

In [3]:
# Paths are relative to the notebook's working directory.
pdb_file_path = '../data/ACE2.pdb'            # input structure to score
model_path = '../models/pocketminer.index'    # checkpoint restored by predict_on_xtals
output_path = '../outputs'                    # directory for predictions.txt and out.pdb

##### Load the input PDB file and check it for formatting problems

In [4]:
import warnings

# md.load raises on files it cannot parse, so a badly malformed PDB stops here.
# TODO: surface a friendlier error message for malformed PDB files.
strucs = [md.load(pdb_file_path)]

# Only one chain is supported: the output PDB written later labels only the
# first chain, so warn loudly instead of silently ignoring the other chains.
# Info from non-protein ligands is not incorporated by the model.
n_chains = strucs[0].topology.n_chains
if n_chains > 1:
    warnings.warn(
        f"{pdb_file_path} contains {n_chains} chains; only 1 chain is supported "
        "and only the first chain will be labelled in the output PDB."
    )

##### Load model and get predictions

In [5]:
# create a MQA model
# These hyperparameters (and the feature dims below) presumably have to match
# the architecture the checkpoint at model_path was trained with -- TODO confirm.
DROPOUT_RATE = 0.1
NUM_LAYERS = 4
HIDDEN_DIM = 100

# MQA Model used for selected NN network
model = MQAModel(node_features=(8, 50), edge_features=(1, 32), hidden_dim=(16, HIDDEN_DIM),
                 num_layers=NUM_LAYERS, dropout=DROPOUT_RATE)

predictions = make_predictions(strucs, model, model_path)

# np.savetxt does not create missing directories; make sure the target exists.
os.makedirs(output_path, exist_ok=True)
# delimiter='\n' writes the single row of scores one value per line.
np.savetxt(os.path.join(output_path, 'predictions.txt'), predictions, fmt='%.4g', delimiter='\n')


Instructions for updating:
Restoring a name-based tf.train.Saver checkpoint using the object-based restore API. This mode uses global names to match variables, and so is somewhat fragile. It also adds new restore ops to the graph each time it is called when graph building. Prefer re-encoding training checkpoints in the object-based format: run save() on the object-based saver (the same one this message is coming from) and use that checkpoint in the future.
CHECKPOINT RESTORED FROM ../models/pocketminer.index


In [6]:
# Displayed as (1, n): presumably one row per input structure and one score per
# residue -- NOTE(review): inferred from the output shown here; confirm in predict_on_xtals.
predictions.shape

TensorShape([1, 596])

### Generate output PDB file

In [None]:
from Bio.PDB import PDBParser, PDBIO
import warnings

#warnings.filterwarnings("ignore")

# Three-letter residue names mapped to one-letter codes. CYM is an alternative
# cysteine residue name and also maps to "C". Residue names missing from this
# table are treated as non-protein by the code that consumes it.
residue_types = dict(
    ALA="A", ARG="R", ASN="N", ASP="D",
    CYS="C", CYM="C",
    GLU="E", GLN="Q", GLY="G", HIS="H",
    ILE="I", LEU="L", LYS="K", MET="M",
    PHE="F", PRO="P", SER="S", THR="T",
    TRP="W", TYR="Y", VAL="V",
)

In [None]:
# Write the per-residue scores into the B-factor column of a copy of the input
# structure so they can be visualized in a molecular viewer.
parser = PDBParser()
structure = parser.get_structure("ACE2", pdb_file_path)

# get only the first chain as that's the one where we have predictions
chain = next(structure.get_chains())

# np.asarray accepts both the TF tensor returned by the model and an ndarray,
# so this cell can be re-run (calling predictions.numpy() in place could not).
pocket_scores = np.asarray(predictions)
n_scores = pocket_scores.shape[1]

# Residues whose name has no one-letter code (waters, ligands, ...) are not
# scored and keep their original B-factors in the output file.
protein_residues = [res for res in chain.get_residues()
                    if res.get_resname() in residue_types]

if len(protein_residues) > n_scores:
    raise ValueError(
        f"chain {chain.id} has {len(protein_residues)} protein residues but only "
        f"{n_scores} predictions; cannot align scores to residues."
    )
if len(protein_residues) < n_scores:
    warnings.warn(
        f"chain {chain.id} has {len(protein_residues)} protein residues but "
        f"{n_scores} predictions; only the first {len(protein_residues)} scores are written."
    )

for res, score in zip(protein_residues, pocket_scores[0]):
    for atom in res.get_atoms():
        atom.set_bfactor(float(score))

# The fixed-width PDB B-factor field holds two decimals, so out.pdb stores
# rounded scores; predictions.txt keeps more precision.
os.makedirs(output_path, exist_ok=True)
io = PDBIO()
io.set_structure(structure)
io.save(os.path.join(output_path, "out.pdb"))

In [None]:
# Dumps every per-residue score; for large proteins consider showing a summary
# (shape, min/max) instead of the full array.
print(predictions)