In [1]:
import sys

import h5py 

import pandas as pd
import numpy as np

import torch 

sys.path.insert(0, "../examples")

from QMmodel import GNN_QM
from MDmodel import GNN_MD

from data.components.transformQM import GNNTransformQM
from data.components.transformMD import GNNTransformMD

from data.processing.inference_QM import main

## Creating an H5 file from a pdbid

In [13]:
class Args:
    """Plain attribute container passed to ``main`` in place of parsed CLI arguments."""

    # Identifier of the entry to process.
    pdbid = "vww"
    # No explicit input file name is supplied.
    fileName = None
    # Value handed over as the map path (presumably a directory — confirm in inference_QM).
    mapPath = "Maps/"


args = Args()

In [14]:
# Run the entry point imported from data.processing.inference_QM with the
# settings defined above.
# NOTE(review): presumably this writes the HDF5 file read in the next section — confirm.
main(args)

reading vww.sdf
{'H': 1, 'Be': 4, 'B': 5, 'C': 6, 'N': 7, 'O': 8, 'F': 9, 'Na': 11, 'Mg': 12, 'Si': 14, 'P': 15, 'S': 16, 'Cl': 17, 'K': 19, 'Ca': 20, 'V': 23, 'Fe': 26, 'Co': 27, 'Cu': 29, 'Zn': 30, 'As': 33, 'Se': 34, 'Br': 35, 'Ru': 44, 'Rh': 45, 'Sb': 51, 'Te': 52, 'I': 53, 'Re': 75, 'Os': 76, 'Ir': 77, 'Pt': 78}


## Prediction of ionization potential and hardness by our model

In [2]:
# HDF5 file that holds the input for the QM inference.
qmh5_file = "inference_for_qm.hdf5"
# Open read-only explicitly: the notebook only reads from this file, and the
# implicit default mode differs between h5py 2.x (append) and 3.x (read).
qm_H5File = h5py.File(qmh5_file, "r")

In [3]:
# Assemble the per-atom table (coordinates plus element) for the "vww" entry.
column_names = ["x", "y", "z", "element"]
atoms = pd.DataFrame(columns=column_names)

atom_group = qm_H5File["vww"]["atom_properties"]
prop = atom_group["atom_properties_values"]

# Columns 0, 1 and 2 of the property matrix become x, y and z (as float32).
for axis, coord in enumerate(column_names[:3]):
    atoms[coord] = prop[:, axis].astype(np.float32)

# Copy every entry of "atoms_names" into the element column.
atoms["element"] = np.array(list(atom_group["atoms_names"][:]))


In [4]:
# Bundle the inputs into the dictionary handed to the QM transform.
# "labels" is set to 0 and "bonds" to None here (presumably placeholders for
# inference — confirm against GNNTransformQM).
item = dict(atoms=atoms, labels=0, bonds=None, id="vww")

# Apply the QM transform to obtain the object that is fed to the model.
transform = GNNTransformQM()
data_item = transform(item)

In [6]:
# Instantiate the QM network from the transformed item's feature count.
model = GNN_QM(data_item.num_features, 64)

# NOTE(review): torch.load unpickles the file — only load checkpoints from a
# trusted source.
checkpoint = torch.load(
    "../examples/logs/qm_test/best_weights_rep0.pt",
    map_location=torch.device("cpu"),
)
cpt = checkpoint["model_state_dict"]

model.load_state_dict(cpt)

# Switch to evaluation mode; as the last expression this also displays the model.
model.eval()

GNN_QM(
  (lin0): Linear(in_features=25, out_features=64, bias=True)
  (conv): NNConv(64, 64, aggr=mean, nn=Sequential(
    (0): Linear(in_features=1, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=4096, bias=True)
  ))
  (gru): GRU(64, 64)
  (set2set): Set2Set(64, 128)
  (lin1): Linear(in_features=128, out_features=64, bias=True)
  (lin2): Linear(in_features=64, out_features=2, bias=True)
)

In [7]:
# Predict with the model. This is pure inference, so autograd is disabled:
# no computation graph is recorded and the result carries no grad_fn.
with torch.no_grad():
    y_hat = model(data_item)

In [8]:
# Display the prediction tensor returned by the QM model.
y_hat

tensor([6.6293, 1.9390], grad_fn=<ViewBackward0>)

## Prediction of softness and hardness by our model

In [9]:
# HDF5 file that holds the input for the MD inference; opened read-only
# explicitly because the notebook never writes to it.
mdh5_file = "data/processing/inference_from_pdb.hdf5"
md_H5File = h5py.File(mdh5_file, "r")

# Look the "11gs" group up once instead of on every access.
md_entry = md_H5File["11gs"]

column_names = ["x", "y", "z", "element"]
atoms_protein = pd.DataFrame(columns=column_names)

# Only atoms before the last entry of "molecules_begin_atom_index" are kept
# (presumably this drops the final molecule, e.g. the ligand — confirm).
cutoff = md_entry["molecules_begin_atom_index"][:][-1]

# Read the reference coordinates a single time and reuse them for x, y and z.
coords_ref = md_entry["atoms_coordinates_ref"][:][:cutoff]
for axis, coord in enumerate(column_names[:3]):
    atoms_protein[coord] = coords_ref[:, axis]

atoms_protein["element"] = md_entry["atoms_element"][:][:cutoff]

# Dictionary handed to the MD transform; "scores" is set to 0 here
# (presumably a placeholder for inference — confirm against GNNTransformMD).
item = {}
item["scores"] = 0
item["id"] = "11gs"
item["atoms_protein"] = atoms_protein

transform = GNNTransformMD()
data_item = transform(item)



In [10]:
# Instantiate the MD network from the transformed item's feature count.
# (torch is already imported in the first cell, so it is not re-imported here.)
model = GNN_MD(data_item.num_features, 64)

# TODO: no checkpoint is loaded for this model, so its outputs come from
# whatever parameters the constructor sets up (presumably random
# initialisation) and must not be read as real predictions.
# The previously commented-out code pointed at an absolute cluster path to
# logs/qm_test/best_weights_rep0.pt, i.e. the same file name as the QM
# checkpoint loaded earlier. Once trained MD weights are available, load them
# from a relative path, for example:
#   cpt = torch.load("<path to MD checkpoint>", map_location=torch.device("cpu"))["model_state_dict"]
#   model.load_state_dict(cpt)

# Switch to evaluation mode; as the last expression this also displays the model.
model.eval()

GNN_MD(
  (conv1): GCNConv(11, 64)
  (bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): GCNConv(64, 128)
  (bn2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): GCNConv(128, 256)
  (bn3): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv4): GCNConv(256, 256)
  (bn4): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv5): GCNConv(256, 512)
  (bn5): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc1): Linear(in_features=512, out_features=256, bias=True)
  (fc2): Linear(in_features=256, out_features=1, bias=True)
)

In [11]:
# Run the MD model with autograd disabled (only the output shape is inspected,
# so recording a computation graph would be wasted work), then show the shape.
with torch.no_grad():
    md_prediction = model(data_item)
md_prediction.shape

torch.Size([1631])

In [12]:
# Display the transformed item that was passed to the MD model.
data_item

Data(x=[1631, 11], edge_index=[2, 27710], edge_attr=[27710], y=[0], pos=[1631, 3], ids='11gs')