In [1]:
%load_ext autoreload
%autoreload 2
import sys
import h5py 
import pandas as pd
import numpy as np
import torch 
sys.path.insert(0, "../examples")
sys.path.insert(0, "misato_dataset/components/")
from QMmodel import GNN_QM
from MDmodel import GNN_MD
from misato_dataset.components.transformQM import GNNTransformQM
from misato_dataset.components.transformMD import GNNTransformMD
from misato_dataset.processing.inference_QM import main




## Creating an H5 file from a ligand PDB ID

We want to run inference on a new structure from the PDB. You can either provide an already downloaded file via `fileName` or just give the `pdbid`, in which case the file is downloaded automatically. (If you run the script directly in the terminal, just enter these values at the prompt.)

In [2]:
class Args:
    """Settings for the QM inference preprocessing passed to ``main``."""

    # Ligand PDB identifier; per the notes above it is downloaded when fileName is None.
    pdbid = "vww"
    fileName = None
    datasetOutName = 'inference_for_qm.hdf5'


args = Args()

In [3]:
# Run the QM preprocessing with the settings in `args`; the resulting file
# (args.datasetOutName) is read back in the next section.
main(args)

reading vww.sdf


## Prediction of ionization potential and hardness by our model

We load the created h5 file and store the elements and coordinates in a dataframe.

In [None]:
# Open the QM file produced above. Read-only is explicit: h5py < 3.0 defaulted
# to mode 'a', which would silently create/modify the file instead of failing.
qmh5_file = "inference_for_qm.hdf5"
qm_H5File = h5py.File(qmh5_file, "r")

In [None]:
# Per-atom table for the ligand: x/y/z from the first three columns of
# atom_properties_values, element labels from atoms_names.
qm_atom_group = qm_H5File["vww"]["atom_properties"]
# Read the dataset once instead of hitting the HDF5 file once per column.
atom_property_values = qm_atom_group["atom_properties_values"][:]

atoms = pd.DataFrame({
    "x": atom_property_values[:, 0].astype(np.float32),
    "y": atom_property_values[:, 1].astype(np.float32),
    "z": atom_property_values[:, 2].astype(np.float32),
    "element": np.array(list(qm_atom_group["atoms_names"][:])),
})


In [None]:
# Package the ligand table into the dict passed to GNNTransformQM.
# labels=0 is a placeholder: no ground truth is available at inference time.
item = dict(atoms=atoms, labels=0, bonds=None, id="vww")

transform = GNNTransformQM()
data_item = transform(item)

We run inference on the CPU.

In [None]:
# Instantiate the QM network (width 64) and restore the trained weights on the CPU.
model = GNN_QM(data_item.num_features, 64)
cpt = torch.load(
    "../examples/logs/QM_latest/best_weights_rep0.pt",
    map_location=torch.device('cpu'),
)["model_state_dict"]
model.load_state_dict(cpt)
model.eval()

GNN_QM(
  (lin0): Linear(in_features=25, out_features=64, bias=True)
  (conv): NNConv(64, 64, aggr=mean, nn=Sequential(
    (0): Linear(in_features=1, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=4096, bias=True)
  ))
  (gru): GRU(64, 64)
  (set2set): Set2Set(64, 128)
  (lin1): Linear(in_features=128, out_features=64, bias=True)
  (lin2): Linear(in_features=64, out_features=2, bias=True)
)

In [None]:
# predict with the model
y_hat = model(data_item)

In [None]:
# Two model outputs (ionization potential and hardness per the section heading;
# NOTE(review): output order not verified here).
y_hat

tensor([ 0.0480, -0.0375], grad_fn=<ViewBackward0>)

## Creating an H5 file for a protein-ligand complex

Similar to the ligand case, we download a PDB file, convert it to Amber format and store it in an H5 file. This step requires AmberTools, so you might have to switch to a conda environment where it is installed.

In [None]:
from data.processing.pdb_to_h5 import main

In [None]:
# Output path of the hydrogen-stripped protein-ligand H5 file (datasetOutName below).
mdh5_file = "inference_for_md.hdf5"

In [None]:
class Args:
    """Settings for pdb_to_h5.main: protein-ligand H5 file without hydrogens."""

    pdbid = "11GS"
    fileName = None
    mapPath = "data/processing/Maps/"
    # Amber atom mask; "!@H=" drops hydrogens.
    # See https://amberhub.chpc.utah.edu/atom-mask-selection-syntax/
    mask = "!@H="
    datasetOutName = mdh5_file


args = Args()

In [None]:
# Build the hydrogen-stripped (mask "!@H=") H5 file for 11GS; needs AmberTools.
main(args)

11GS/11GS.pdb was created. Please always use this file for inspection because the coordinates might get translated during amber file generation and thus might vary from the input pdb file.
The following trajectory was created: pytraj.TrajectoryIterator, 1 frames: 
Size: 0.000146 (GB)
<Topology: 6534 atoms, 416 residues, 2 mols, non-PBC>
           
molecule begin atom index [0, 1631, 3262] [1631, 1631]


In [None]:
class Args:
    """Settings for pdb_to_h5.main: all-atom H5 file used to map indices for visualization."""

    pdbid = "11GS"
    fileName = "11GS.pdb"
    mapPath = "data/processing/Maps/"
    # Empty Amber mask keeps every atom.
    # See https://amberhub.chpc.utah.edu/atom-mask-selection-syntax/
    mask = ""
    datasetOutName = 'all_atoms_11GS.hdf5'


args = Args()

In [None]:
# Same conversion, this time keeping all atoms (mask "") and reading args.fileName.
main(args)

11GS/11GS.pdb was created. Please always use this file for inspection because the coordinates might get translated during amber file generation and thus might vary from the input pdb file.
The following trajectory was created: pytraj.TrajectoryIterator, 1 frames: 
Size: 0.000146 (GB)
<Topology: 6534 atoms, 416 residues, 2 mols, non-PBC>
           
molecule begin atom index [0, 3267, 6534] [3267, 3267]


## Prediction of adaptability by our model

In [None]:
# switch to misato env if not running from container
mdh5_file = "inference_for_md.hdf5"
md_H5File = h5py.File(mdh5_file)

column_names = ["x", "y", "z", "element"]
atoms_protein = pd.DataFrame(columns = column_names)
cutoff = md_H5File["11GS"]["molecules_begin_atom_index"][:][-1] # cutoff defines protein atoms

atoms_protein["x"] = md_H5File["11GS"]["atoms_coordinates_ref"][:][:cutoff, 0]
atoms_protein["y"] = md_H5File["11GS"]["atoms_coordinates_ref"][:][:cutoff, 1]
atoms_protein["z"] = md_H5File["11GS"]["atoms_coordinates_ref"][:][:cutoff, 2]

atoms_protein["element"] = md_H5File["11GS"]["atoms_element"][:][:cutoff]  

item = {}
item["scores"] = 0
item["id"] = "11GS"
item["atoms_protein"] = atoms_protein

transform = GNNTransformMD()
data_item = transform(item)



In [None]:
 # Molecule start offsets; the last entry is the `cutoff` used above.
 # NOTE(review): stray leading space — IPython strips it, plain Python would not.
 md_H5File["11GS"]["molecules_begin_atom_index"][:]

array([   0, 1631, 3262])

In [None]:
import torch

# Instantiate the MD network (width 64) and restore the trained weights on the CPU.
model = GNN_MD(data_item.num_features, 64)
cpt = torch.load(
    "../examples/logs/MD_latest/best_weights_rep0.pt",
    map_location=torch.device('cpu'),
)["model_state_dict"]
model.load_state_dict(cpt)
model.eval()

GNN_MD(
  (conv1): GCNConv(11, 64)
  (bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): GCNConv(64, 128)
  (bn2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): GCNConv(128, 256)
  (bn3): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv4): GCNConv(256, 256)
  (bn4): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv5): GCNConv(256, 512)
  (bn5): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc1): Linear(in_features=512, out_features=256, bias=True)
  (fc2): Linear(in_features=256, out_features=1, bias=True)
)

In [None]:
# Sanity check: one prediction per atom given to the model (3262, matching the
# stripped atoms_element dataset). NOTE(review): this is a full forward pass
# that the next cell repeats.
model(data_item).shape

torch.Size([3262])

In [None]:
# Per-atom adaptability predictions. no_grad() skips autograd bookkeeping,
# which is not needed for inference.
with torch.no_grad():
    adaptability = model(data_item)


In [None]:
# Convert the predictions to NumPy. The guard keeps the cell idempotent:
# re-running it on an ndarray would otherwise raise AttributeError.
if isinstance(adaptability, torch.Tensor):
    adaptability = adaptability.detach().cpu().numpy()

In [None]:
# This step might be necessary in case you have to change the kernel to the
# ambertools env: the predictions are written to disk and reloaded there.
import pickle

# adaptability is normally already a NumPy array at this point, and ndarray has
# no .detach(); only convert when it is still a tensor.
if isinstance(adaptability, torch.Tensor):
    adaptability_to_save = adaptability.detach().cpu().numpy()
else:
    adaptability_to_save = adaptability
# `with` guarantees the file is flushed and closed.
with open('inference_adaptability.pickle', 'wb') as handle:
    pickle.dump(adaptability_to_save, handle)

## Visualization of Adaptability

In [None]:
# switch to ambertools env if not running from container
import nglview as nv
import pytraj as pt
import os
import h5py
import pickle
import numpy as np



In [None]:
# Reload the predictions saved in the misato env. pickle.load can execute
# arbitrary code, so only load a file you wrote yourself.
with open('inference_adaptability.pickle', 'rb') as handle:
    adaptability = pickle.load(handle)

In [None]:
def get_index_conversion(all_atom_file, excluded_element=1):
    """Map each position in the stripped atom list to its index in the all-atom list.

    Parameters
    ----------
    all_atom_file : mapping
        HDF5 group (or any mapping) providing array-like ``atoms_element`` and
        ``atoms_coordinates_ref`` entries for the all-atom structure.
    excluded_element : int, default 1
        ``atoms_element`` value that was stripped (1 = hydrogen by default;
        change it if a different element was removed).

    Returns
    -------
    dict[int, int]
        ``{stripped_index: all_atom_index}`` for every atom whose element differs
        from ``excluded_element``, numbered consecutively in all-atom order.
        Only atoms that also have a coordinate row are included.
    """
    atoms_element = np.asarray(all_atom_file['atoms_element'][:])
    # np.shape uses the dataset's .shape, so the coordinates are not read into memory.
    n_atoms = np.shape(all_atom_file['atoms_coordinates_ref'])[0]
    kept = np.flatnonzero(atoms_element != excluded_element)
    kept = kept[kept < n_atoms]
    # enumerate over the kept indices replaces the former per-index `in` scan (O(n^2)).
    return {noh_index: int(all_atom_index) for noh_index, all_atom_index in enumerate(kept)}


def show_ada_spheres(traj, ada_indices, residue_indices, prediction, color, radiusFactor, ngl_view=None):
    """Draw one sphere per selected atom, sized by its predicted adaptability.

    Parameters
    ----------
    traj : pytraj trajectory
        Structure the atom indices refer to; each sphere is centred on the atom's
        position in the first frame.
    ada_indices : sequence of int
        Indices into ``prediction`` (stripped / no-hydrogen numbering).
    residue_indices : sequence of int
        0-based atom indices in ``traj``, paired element-wise with ``ada_indices``.
    prediction : array-like
        Per-atom adaptability values.
    color : sequence of float
        RGB colour passed to ``add_sphere``.
    radiusFactor : float
        Sphere radius is ``prediction[ada_index] / radiusFactor``.
    ngl_view : NGLWidget, optional
        Widget to draw on. Defaults to the notebook-global ``view`` so existing
        calls keep working.
    """
    widget = view if ngl_view is None else ngl_view
    for position, ada_index in enumerate(ada_indices):
        atom_index = residue_indices[position]
        # Amber/pytraj '@N' masks are 1-based, hence the +1.
        x, y, z = traj['@' + str(atom_index + 1)].xyz[0][0]
        widget.shape.add_sphere([x, y, z], color, prediction[ada_index] / radiusFactor)

def add_opacity_to_spheres(num_spheres, opacity, ngl_view=None):
    """Set ``opacity`` on the last ``num_spheres`` components of an NGL widget.

    Components are 0-indexed, so the targets are ``n_components - 1`` down to
    ``n_components - num_spheres``. ``num_spheres`` is clamped to the number of
    components, so passing ``n_components`` affects every component.

    Parameters
    ----------
    num_spheres : int
        How many trailing components to update.
    opacity : float
        Opacity forwarded to ``update_representation``.
    ngl_view : NGLWidget, optional
        Widget to modify. Defaults to the notebook-global ``view`` so existing
        calls keep working.
    """
    widget = view if ngl_view is None else ngl_view
    n_components = widget.n_components
    # Start at offset 1: component index n_components is out of range.
    for offset in range(1, min(num_spheres, n_components) + 1):
        widget.update_representation(component=n_components - offset, opacity=opacity)
        
def convert_indices(indices, index_conversion):
    """Translate every index through ``index_conversion``.

    Returns a new list ``[index_conversion[i] for i in indices]`` in input order;
    an index missing from the mapping raises ``KeyError``.
    """
    return [index_conversion[idx] for idx in indices]

We need to load both the H5 file with hydrogens and the H5 file with hydrogens stripped (noh) after processing, so that we can map the predictions to the correct atom indices of the PDB file we want to visualize.

In [None]:
# Open both H5 files read-only: the hydrogen-stripped file the model ran on and
# the all-atom file. They stay open because the following cells read from them.
f_inference = h5py.File('inference_for_md.hdf5', 'r')
f_all_atom = h5py.File('all_atoms_11GS.hdf5', 'r')

In [None]:
# Datasets stored for 11GS in the hydrogen-stripped file.
f_inference['11GS'].keys()

<KeysViewHDF5 ['atoms_coordinates_ref', 'atoms_element', 'atoms_number', 'atoms_residue', 'atoms_type', 'molecules_begin_atom_index']>

In [None]:
# Its length should equal the number of adaptability predictions (3262 above).
f_inference["11GS"]['atoms_element']

<HDF5 dataset "atoms_element": shape (3262,), type "<i8">

In [None]:
# Datasets stored for 11GS in the all-atom file.
f_all_atom["11GS"].keys()

<KeysViewHDF5 ['atoms_coordinates_ref', 'atoms_element', 'atoms_number', 'atoms_residue', 'atoms_type', 'molecules_begin_atom_index']>

In [None]:
# {stripped (no-hydrogen) index: all-atom index}
index_conversion = get_index_conversion(f_all_atom["11GS"])

In [None]:
# Reverse lookup: {all-atom index: stripped (no-hydrogen) index}.
inverse_index_conversion = dict(zip(index_conversion.values(), index_conversion.keys()))

In [None]:
# Load the PDB written during H5 creation (11GS/11GS.pdb) and wrap it in an NGL widget.
struct = '11GS'
traj = pt.load(f'{struct}/{struct}.pdb')
view = nv.show_pytraj(traj)

In [None]:
# Display the widget; spheres added by later cells are drawn into it.
view

NGLWidget()

In [None]:
# Atoms of residues 1 and 327 whose names start with C, N, O or S (Amber "="
# wildcard), i.e. presumably the heavy atoms; these get adaptability spheres.
residue_indices1 = list(traj.top.atom_indices(':1' + '@C=,N=,O=,S='))
residue_indices2 = list(traj.top.atom_indices(':327' + '@C=,N=,O=,S='))

residue_indices = [*residue_indices1, *residue_indices2]
# Translate all-atom indices into positions in the (hydrogen-stripped) predictions.
converted_indices = convert_indices(residue_indices, inverse_index_conversion)

In [None]:
# Red spheres (RGB (1,0,0)); radius = predicted adaptability / 1.5.
show_ada_spheres(traj, converted_indices, residue_indices, adaptability, (1,0,0), 1.5)

In [None]:
# Render a snapshot of the widget. NOTE(review): factor=12 presumably scales the
# image resolution — see nglview render_image docs.
view.render_image(trim=True, factor=12)

Image(value=b'', width='99%')

In [None]:
# Download the rendered image (presumably through the browser; see nglview docs).
view.download_image()

In [None]:
# NOTE(review): num_spheres is the total component count here (structure plus
# spheres); confirm which components are meant to become translucent.
add_opacity_to_spheres(view.n_components, 0.5)

In [None]:
# Number of components in the widget (presumably the structure plus one per sphere — confirm).
view.n_components

17