In [1]:
%load_ext autoreload
%autoreload 2
import sys
import h5py 
import pandas as pd
import numpy as np
import torch 
sys.path.insert(0, "../examples")
sys.path.insert(0, "data/components/")
from QMmodel import GNN_QM
from MDmodel import GNN_MD
from data.components.transformQM import GNNTransformQM
from data.components.transformMD import GNNTransformMD
from data.processing.inference_QM import main




## Creating an H5 file from a ligand PDB ID

We want to run inference on a new structure from the PDB. You can either provide an already downloaded file via `fileName`, or just give the `pdbid` and the file will be downloaded automatically. (If you run the script directly in the terminal, enter these values at the prompt.)

In [2]:
class Args:
    """Stand-in for the command-line arguments of the ligand HDF5 builder.

    ``fileName`` is left as ``None`` so the structure is obtained from
    ``pdbid`` instead of from a local file.
    """

    datasetOutName = 'inference_for_qm.hdf5'  # name of the HDF5 file to create
    fileName = None
    pdbid = "vww"


args = Args()

In [3]:
# Build the ligand HDF5 file named by args.datasetOutName for args.pdbid.
# Because args.fileName is None, the structure is downloaded rather than read locally.
main(args)

reading vww.sdf


## Predicting ionization potential and hardness with our model

We load the newly created HDF5 file and store the element names and atom coordinates in a DataFrame.

In [4]:
qmh5_file = "inference_for_qm.hdf5"  # must match Args.datasetOutName used above
# Open read-only. Older h5py versions default to append mode, which would
# silently create an empty file if the previous step failed. The handle is
# kept open because the next cell reads datasets from it.
qm_H5File = h5py.File(qmh5_file, "r")

In [5]:
# Assemble the per-atom table (x, y, z, element) that GNNTransformQM consumes.
qm_pdbid = "vww"
atom_props = qm_H5File[qm_pdbid]["atom_properties"]

# The first three columns of atom_properties_values are used as x, y, z.
coords = atom_props["atom_properties_values"][:, :3].astype(np.float32)

atoms = pd.DataFrame({
    "x": coords[:, 0],
    "y": coords[:, 1],
    "z": coords[:, 2],
    # Element names are taken as stored in the file (possibly bytes) — not decoded here.
    "element": atom_props["atoms_names"][:],
})


In [6]:
# Wrap the atom table in the sample dict GNNTransformQM expects. The label is a
# dummy 0 (this is inference) and no bond table is supplied.
item = dict(
    atoms=atoms,
    labels=0,
    bonds=None,
    id="vww",
)

transform = GNNTransformQM()
data_item = transform(item)

We run inference on the CPU.

In [7]:
# Build the QM network (input width taken from the transformed sample, hidden
# size 64) and restore the trained weights onto the CPU.
model = GNN_QM(data_item.num_features, 64)
qm_checkpoint = torch.load(
    "../examples/logs/QM_latest/best_weights_rep0.pt",
    map_location="cpu",  # NOTE: torch.load unpickles — only load trusted checkpoints
)
cpt = qm_checkpoint["model_state_dict"]
model.load_state_dict(cpt)
model.eval()

GNN_QM(
  (lin0): Linear(in_features=25, out_features=64, bias=True)
  (conv): NNConv(64, 64, aggr=mean, nn=Sequential(
    (0): Linear(in_features=1, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=4096, bias=True)
  ))
  (gru): GRU(64, 64)
  (set2set): Set2Set(64, 128)
  (lin1): Linear(in_features=128, out_features=64, bias=True)
  (lin2): Linear(in_features=64, out_features=2, bias=True)
)

In [8]:
# Predict with the model. no_grad() skips building the autograd graph, which
# inference does not need and which only costs memory.
with torch.no_grad():
    y_hat = model(data_item)

In [9]:
# Two outputs (the final layer has out_features=2). Per the section heading these
# are ionization potential and hardness — the output order is assumed, TODO confirm.
y_hat

tensor([ 0.0480, -0.0375], grad_fn=<ViewBackward0>)

## Creating an H5 file for a protein-ligand complex

In [4]:
from data.processing.pdb_to_h5 import main

In [5]:
class Args:
  pdbid = "11gs"
  fileName = None
  mapPath = "data/processing/Maps/"
  mask = "!@H=" # no Hydrogens, see https://amberhub.chpc.utah.edu/atom-mask-selection-syntax/
  # Write the complex to its own file. Reusing 'inference_for_qm.hdf5' would
  # overwrite the ligand file (still open as qm_H5File), and the adaptability
  # section reads from this path.
  # NOTE(review): assumes pdb_to_h5.main writes to this path as given — confirm.
  datasetOutName = 'data/processing/inference_from_pdb.hdf5'
args=Args()

In [6]:
# Run pdb_to_h5.main. This `main` was imported just above and now shadows the
# inference_QM one. It builds the HDF5 file for the complex described by args.
main(args)

11gs/11gs.pdb was created. Please always use this file for inspection because the coordinates might get translated during amber file generation and thus might vary from the input pdb file.
traj pytraj.TrajectoryIterator, 1 frames: 
Size: 0.000146 (GB)
<Topology: 6534 atoms, 416 residues, 2 mols, non-PBC>
           


## Predicting adaptability with our model

In [10]:
# Load the protein part of the complex into the atom table GNNTransformMD expects.
# NOTE(review): this path must match the datasetOutName used when the complex
# file was created above — confirm.
mdh5_file = "data/processing/inference_from_pdb.hdf5"
md_pdbid = "11gs"

# Every dataset is copied into memory with [:], so the file can be closed at once.
with h5py.File(mdh5_file, "r") as md_H5File:
    complex_group = md_H5File[md_pdbid]
    # Keep atoms before the start index of the last molecule. This drops the last
    # molecule (presumably the ligand — confirm).
    cutoff = complex_group["molecules_begin_atom_index"][:][-1]
    coords = complex_group["atoms_coordinates_ref"][:][:cutoff]  # read once, not per axis
    elements = complex_group["atoms_element"][:][:cutoff]

atoms_protein = pd.DataFrame({
    "x": coords[:, 0],
    "y": coords[:, 1],
    "z": coords[:, 2],
    "element": elements,
})

item = {
    "scores": 0,
    "id": md_pdbid,
    "atoms_protein": atoms_protein,
}

transform = GNNTransformMD()
data_item = transform(item)



In [11]:
# Build the MD network (input width taken from the transformed graph, hidden size 64)
# and restore the trained weights onto the CPU. torch is already imported in the
# first cell.
model = GNN_MD(data_item.num_features, 64)

# NOTE: torch.load unpickles — only load trusted checkpoints.
cpt = torch.load("../examples/logs/MD_latest/best_weights_rep0.pt", map_location=torch.device('cpu'))["model_state_dict"]

model.load_state_dict(cpt)

model.eval()

GNN_MD(
  (conv1): GCNConv(11, 64)
  (bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): GCNConv(64, 128)
  (bn2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): GCNConv(128, 256)
  (bn3): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv4): GCNConv(256, 256)
  (bn4): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv5): GCNConv(256, 512)
  (bn5): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc1): Linear(in_features=512, out_features=256, bias=True)
  (fc2): Linear(in_features=256, out_features=1, bias=True)
)

In [12]:
# One prediction per graph node. This is inference only, so autograd is disabled.
with torch.no_grad():
    md_pred = model(data_item)
md_pred.shape

torch.Size([1631])

In [13]:
# Summary of the transformed graph: node features, edges, edge attributes, positions and id.
data_item

Data(x=[1631, 11], edge_index=[2, 27710], edge_attr=[27710], y=[0], pos=[1631, 3], ids='11gs')