In [58]:
from rlvs.molecule_world.datasets import DataStore
from rlvs.molecule_world.complex import Complex
from rlvs.molecule_world.helper_functions import *

DataStore.init(crop=False)
complexes = []
ligands = []
for i in range(2):
    protein, ligand = DataStore.DATA[i]
    print(protein.n_atoms, ligand.n_atoms)
    ligands.append(ligand)
    complex = Complex(protein, ligand)
    complexes.append(complex)


2859 48
429 18


In [59]:
from deepchem.feat.mol_graphs import MultiConvMol, ConvMol

def mols_to_inputs(mols):
    multiConvMol = ConvMol.agglomerate_mols(mols)
    n_samples = np.array([len(mols)])
    inputs = [multiConvMol.get_atom_features(), multiConvMol.deg_slice,
                np.array(multiConvMol.membership), n_samples]
    for i in range(1, len(multiConvMol.get_deg_adjacency_lists())):
        inputs.append(multiConvMol.get_deg_adjacency_lists()[i])
    return inputs
    
def get_gc_inputs(inputs):
    atom_features = np.expand_dims(inputs[0], axis=0)
    degree_slice = np.expand_dims(tf.cast(inputs[1], dtype=tf.int32), axis=0)
    membership = np.expand_dims(tf.cast(inputs[2], dtype=tf.int32), axis=0)
    n_samples = np.expand_dims(tf.cast(inputs[3], dtype=tf.int32), axis=0)
    deg_adjs = [np.expand_dims(tf.cast(deg_adj, dtype=tf.int32), axis=0) for deg_adj in inputs[4:]]

    in_layer = atom_features

    gc_in = [in_layer, degree_slice, membership, n_samples] + deg_adjs
    
    return gc_in

In [60]:
from rlvs.network.graph_layer import GraphConv, GraphGather, GraphPool
from tensorflow.keras.layers import Dense, Input, add
from tensorflow.keras.models import Model

import tensorflow as tf


In [61]:
def _create_molecule_network(jj=0):

    features_input = Input(shape=(None, 18,), name=f"critic_Feature_{jj}", batch_size=1) 
    degree_slice_input = Input(shape=(11,2), dtype=tf.int32, name=f"critic_Degree_slice_{jj}", batch_size=1)
    membership = Input(shape=(None,), dtype=tf.int32, name=f'membership_{jj}', batch_size=1)
    n_samples = Input(shape=(1,), dtype=tf.int32, name=f'n_samples_{jj}', batch_size=1)
    deg_adjs_input = [Input(shape=(None,None,), dtype=tf.int32, name=f"critic_deg_adjs_{jj}_{i}", batch_size=1) for i in  range(10)]

    input_states = [features_input, degree_slice_input, membership, n_samples] + deg_adjs_input
    graph_layer = GraphConv(out_channel=64, activation_fn=tf.nn.relu)(input_states)

    graph_pool_in = [graph_layer, degree_slice_input, membership, n_samples] + deg_adjs_input
    graph_pool = GraphPool()(graph_pool_in)
    dense_layer = Dense(128, activation=tf.nn.relu)(graph_pool)

    return input_states, GraphGather(activation_fn=tf.nn.tanh)([dense_layer, membership, n_samples])

In [62]:
ip_1, graph_gather_layer_1 = _create_molecule_network(1)
ip_2, graph_gather_layer_2 = _create_molecule_network(2)
mol1_model = Model(inputs=ip_1, outputs=graph_gather_layer_1)
mol2_model = Model(inputs=ip_2, outputs=graph_gather_layer_2)

combination_layer = add([mol1_model.output, mol2_model.output])
combined_dense_layer = Dense(64, activation=tf.nn.relu)(combination_layer)
conv_model_1 = Model([ip_1, ip_2], combined_dense_layer)

In [63]:
ligands = [c.ligand for c in complexes]
proteins = [c.protein for c in complexes]

protein_batch  = get_gc_inputs(mols_to_inputs(proteins))
ligand_batch  = get_gc_inputs(mols_to_inputs(ligands))

In [64]:
conv_model_1.compile()

In [65]:
conv_model_1([protein_batch, ligand_batch])

<tf.Tensor: shape=(2, 64), dtype=float32, numpy=
array([[0.        , 0.        , 0.        , 0.        , 0.        ,
        1.9273264 , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 3.4572368 ,
        1.6342638 , 1.299733  , 0.25167614, 0.        , 0.        ,
        0.        , 0.        , 0.        , 1.7568893 , 0.02819663,
        0.        , 0.        , 0.8645285 , 0.19175795, 0.        ,
        0.        , 1.3853934 , 0.4931846 , 0.12076679, 0.        ,
        0.        , 0.23465212, 0.        , 0.        , 0.        ,
        2.412134  , 0.        , 0.        , 0.        , 0.00671268,
        0.        , 0.        , 1.9239496 , 0.45157376, 0.        ,
        0.        , 1.2090018 , 0.        , 0.        , 1.2155864 ,
        0.        , 0.        , 0.        , 0.4181779 ],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        1.

In [90]:
assert( (atom_features[0, np.argwhere(membership==1)[:,1], :] == ligands[1].atom_features ).all())

  """Entry point for launching an IPython kernel.


False