In [1]:
# from rlvs.molecule_world.datasets import DataStore

import tensorflow as tf
import numpy as np

In [2]:
tf.__version__

'2.2.0'

### Get test data

In [3]:
from rlvs.molecule_world.helper_functions import *
file_ligand = "data/pafnucy_data/complexes/4mrw/4mrw_ligand.mol2"
file_protein = "data/pafnucy_data/complexes/4mrw/4mrw_pocket.mol2"

f = Featurizer()

obmol_ligand = read_to_OB(filename=file_ligand, filetype="mol2")
nodes_ligand, canon_adj_list_ligand = f.get_mol_features(obmol=obmol_ligand, molecule_type=1, bond_verbose=0)
ligand = Molecule(atom_features=nodes_ligand, canon_adj_list=canon_adj_list_ligand)

obmol_protein = read_to_OB(filename=file_protein, filetype="mol2")
nodes_protein, canon_adj_list_protein = f.get_mol_features(obmol=obmol_protein, molecule_type=-1, bond_verbose=0)
protein = Molecule(atom_features=nodes_protein, canon_adj_list=canon_adj_list_protein)

In [4]:
# from rdkit import Chem
# from deepchem.feat import ConvMolFeaturizer

# smiles = ['COC(C)(C)CCCC(C)CC=CC(C)=CC(=O)OC(C)C',
#           'CCOC(=O)CC',
#           'CSc1nc(NC(C)C)nc(NC(C)C)n1',
#           'CC(C#C)N(C)C(=O)Nc1ccc(Cl)cc1',
#           'Cc1cc2ccccc2cc1C']

# mols = [Chem.MolFromSmiles(s) for s in smiles]
# featurizer = ConvMolFeaturizer()
# mols = featurizer.featurize(mols)

In [5]:
mol1 = protein
mol2 = ligand
mol1.get_atom_features().shape, mol2.get_atom_features().shape

((550, 18), (18, 18))

### Prepare tensor inputs for convolution layer

TODO: Feature normalization?

In [6]:
inputs = []
for mol in [mol1, mol2]:
    atom_features = np.expand_dims(mol.get_atom_features(), axis=0)
    degree_slice = np.expand_dims((mol.deg_slice.astype(dtype='int32')), axis=0)
    deg_adjs = [np.expand_dims(deg_adj.astype(dtype='int32'), axis=0) for deg_adj in mol.get_deg_adjacency_lists()[1:]]

    gc_in = [atom_features, degree_slice] + deg_adjs
    
    inputs.append(gc_in)


### Define model (inputs -> protein-ligand combined feature vector)

In [7]:
from rlvs.network.graph_layer import GraphConv, GraphPool, GraphGather
from tensorflow.keras.layers import Input, Dense, add
from tensorflow.keras.models import Model

# first molecule inputs
features_input_1 = Input(shape=(None,mol.n_feat,), batch_size=1)
degree_slice_input_1 = Input(shape=(11,2), dtype=tf.int32, batch_size=1) #degree per atom
deg_adjs_input_1 = []
for i in range(10):
    deg_adjs_input_1.append(Input(shape=(None,None,), dtype=tf.int32, batch_size=1))
ip_1 = [features_input_1, degree_slice_input_1] + deg_adjs_input_1

# first molecule convolution
graph_layer_1 = GraphConv(out_channel=64, activation_fn=tf.nn.relu)(ip_1)
gp_in_1 = [graph_layer_1, degree_slice_input_1] + deg_adjs_input_1

# first molecule pooling (analogous to max pooling)
graph_pool_1 = GraphPool()(gp_in_1)
dense_layer_1 = Dense(128, activation="relu")(graph_pool_1)

# first molecule gather node level features into molecule level features
graph_gather_layer_1 = GraphGather(activation_fn=tf.nn.relu)(dense_layer_1)

mol1_model = Model(inputs=ip_1, outputs=graph_gather_layer_1)

# second molecule inputs
features_input_2 = Input(shape=(None,mol.n_feat,), batch_size=1)
degree_slice_input_2 = Input(shape=(11,2), dtype=tf.int32, batch_size=1)
deg_adjs_input_2 = []
for i in range(10):
    deg_adjs_input_2.append(Input(shape=(None,None,), dtype=tf.int32, batch_size=1))
ip_2 = [features_input_2, degree_slice_input_2] + deg_adjs_input_2

# second molecule convolution
graph_layer_2 = GraphConv(out_channel=64, activation_fn=tf.nn.relu)(ip_2)
gp_in_2 = [graph_layer_2, degree_slice_input_2] + deg_adjs_input_2

# second molecule pooling (analogous to max pooling)
graph_pool_2 = GraphPool()(gp_in_2)
dense_layer_2 = Dense(128, activation="relu")(graph_pool_2)

# second molecule gather node level features into molecule level features
graph_gather_layer_2 = GraphGather(activation_fn=tf.nn.relu)(dense_layer_2)

mol2_model = Model(inputs=ip_2, outputs=graph_gather_layer_2)

# combine molecule1 and molecule2 feature vectors
#TODO: is "add" the correct operation?
combination_layer = add([mol1_model.output, mol2_model.output])
combined_dense_layer = Dense(64, activation="relu")(combination_layer)

# define full model
m = Model([ip_1, ip_2], combined_dense_layer)
m.compile()
m.summary()

Model: "model_2"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(1, None, 18)]      0                                            
__________________________________________________________________________________________________
input_2 (InputLayer)            [(1, 11, 2)]         0                                            
__________________________________________________________________________________________________
input_3 (InputLayer)            [(1, None, None)]    0                                            
__________________________________________________________________________________________________
input_4 (InputLayer)            [(1, None, None)]    0                                            
____________________________________________________________________________________________

#### Actor Network Test

In [8]:
mol.n_feat

18

In [10]:
from rlvs.network import ActorGNN
from rlvs.molecule_world.env import GraphEnv
env = GraphEnv()
actor = ActorGNN(
            18,
            6,
            0.00005,
            0.001
        )
inputs = env.reset()
actor.actor(inputs)
inputs[0]

[array([[[ 9.5850e+00,  6.3040e+00,  4.8043e+01, ...,  0.0000e+00,
           2.0000e+00, -1.0000e+00],
         [ 1.5418e+01,  1.4069e+01,  4.5875e+01, ...,  1.0000e+00,
           2.7190e-01, -1.0000e+00],
         [ 1.3540e+01,  1.2143e+01,  4.8911e+01, ...,  0.0000e+00,
          -5.6790e-01, -1.0000e+00],
         ...,
         [-2.9130e+00,  1.1315e+01,  4.6582e+01, ...,  0.0000e+00,
          -1.8250e-01, -1.0000e+00],
         [-3.8290e+00,  2.2978e+01,  4.4985e+01, ...,  1.0000e+00,
          -2.7500e-02, -1.0000e+00],
         [-3.4360e+00,  2.1593e+01,  4.4447e+01, ...,  0.0000e+00,
          -5.0000e-03, -1.0000e+00]]]),
 array([[[  0,   1],
         [  1, 240],
         [241,  28],
         [269,  96],
         [365,  64],
         [  0,   0],
         [  0,   0],
         [  0,   0],
         [  0,   0],
         [  0,   0],
         [  0,   0]]], dtype=int32),
 array([[[241],
         [269],
         [242],
         [365],
         [366],
         [366],
         [271],


### Forward run of test data through model 

In [11]:
o = m(inputs)

In [12]:
o

<tf.Tensor: shape=(1, 64), dtype=float32, numpy=
array([[  545.26935,     0.     ,     0.     ,  3791.2192 ,  2360.0024 ,
         9495.732  ,  1209.4473 ,  4813.2646 ,     0.     ,   789.11066,
          684.5331 ,   826.39307,  7852.6924 ,  3336.8691 ,     0.     ,
            0.     ,  5453.1475 ,     0.     ,     0.     ,  2616.227  ,
            0.     ,  4508.199  ,  3329.6023 ,  6843.436  ,     0.     ,
         2942.189  ,     0.     ,     0.     ,   386.92505,     0.     ,
         5114.6855 ,  5607.7563 , 10091.157  ,   925.6692 ,     0.     ,
         3731.1091 ,     0.     ,  8693.792  ,     0.     ,  1674.274  ,
            0.     ,  1856.3945 ,   466.98758,     0.     ,     0.     ,
         2826.5874 ,  6300.3926 ,     0.     ,  1114.1597 ,  4391.5474 ,
            0.     ,  3172.9285 ,     0.     ,     0.     ,     0.     ,
         2555.6826 ,     0.     ,     0.     ,     0.     ,     0.     ,
            0.     ,  1317.4073 ,  2840.6704 ,     0.     ]],
      dtype=f

In [13]:
m.predict(inputs)

array([[  545.26935,     0.     ,     0.     ,  3791.2192 ,  2360.0024 ,
         9495.732  ,  1209.4473 ,  4813.2646 ,     0.     ,   789.11066,
          684.5331 ,   826.39307,  7852.6924 ,  3336.8691 ,     0.     ,
            0.     ,  5453.1475 ,     0.     ,     0.     ,  2616.227  ,
            0.     ,  4508.199  ,  3329.6023 ,  6843.436  ,     0.     ,
         2942.189  ,     0.     ,     0.     ,   386.92505,     0.     ,
         5114.6855 ,  5607.7563 , 10091.157  ,   925.6692 ,     0.     ,
         3731.1091 ,     0.     ,  8693.792  ,     0.     ,  1674.274  ,
            0.     ,  1856.3945 ,   466.98758,     0.     ,     0.     ,
         2826.5874 ,  6300.3926 ,     0.     ,  1114.1597 ,  4391.5474 ,
            0.     ,  3172.9285 ,     0.     ,     0.     ,     0.     ,
         2555.6826 ,     0.     ,     0.     ,     0.     ,     0.     ,
            0.     ,  1317.4073 ,  2840.6704 ,     0.     ]],
      dtype=float32)