In [1]:
%matplotlib inline
from matplotlib import pyplot as plt
from molgym.mpnn.layers import GraphNetwork, Squeeze
from molgym.mpnn.data import make_data_loader
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.optimizers.schedules import InverseTimeDecay
from tensorflow.keras.layers import Input, Lambda, Dense
from tensorflow.keras.models import Model
from tensorflow.keras import callbacks as cb
from scipy.stats import spearmanr, kendalltau
import tensorflow as tf
import numpy as np
import json

In [2]:
# Only the training loader is given a shuffle buffer (1024); the validation and
# test loaders are built with make_data_loader's default arguments.
train_loader = make_data_loader('train_data.proto', shuffle_buffer=1024)
val_loader, test_loader = [
    make_data_loader(proto_path)
    for proto_path in ('val_data.proto', 'test_data.proto')
]

In [3]:
def _count_json_entries(path):
    """Return ``len()`` of the JSON document stored at ``path``."""
    with open(path) as fp:
        return len(json.load(fp))


# Vocabulary sizes; build_fn reads both as notebook-level globals.
atom_type_count = _count_json_entries('atom_types.json')
bond_type_count = _count_json_entries('bond_types.json')

In [4]:
def build_fn(atom_features=64, message_steps=8):
    """Assemble the Keras model: four int32 inputs -> 'mpnn' GraphNetwork -> 'scale' Dense(1).

    Reads the notebook-level globals ``atom_type_count`` and ``bond_type_count``,
    so both must be defined before this function is called.

    Args:
        atom_features: Passed to ``GraphNetwork`` as its third positional argument.
        message_steps: Passed to ``GraphNetwork`` as its fourth positional argument.

    Returns:
        A ``Model`` whose inputs are, in order, ``node_graph_indices``, ``atom``,
        ``bond`` and ``connectivity``, and whose output is the single-unit linear
        ``scale`` layer applied to the ``mpnn`` output.
    """
    def _int_input(name, width):
        # Every model input is int32 and differs only in its name and width.
        return Input(shape=(width,), name=name, dtype='int32')

    graph_idx_in = _int_input('node_graph_indices', 1)
    atom_in = _int_input('atom', 1)
    bond_in = _int_input('bond', 1)
    connectivity_in = _int_input('connectivity', 2)

    # Apply Squeeze(axis=1) to the three width-1 inputs (graph indices, atom types,
    # bond types). `connectivity` is passed to the network as-is.
    graph_idx, atom_ids, bond_ids = [
        Squeeze(axis=1)(tensor) for tensor in (graph_idx_in, atom_in, bond_in)
    ]

    mpnn_layer = GraphNetwork(
        atom_type_count, bond_type_count, atom_features, message_steps,
        output_layer_sizes=[512, 256, 128],
        atomic_contribution=False,
        reduce_function='max',
        name='mpnn',
    )
    mpnn_out = mpnn_layer([atom_ids, bond_ids, graph_idx, connectivity_in])

    # Final linear layer; its kernel/bias act as a scale and offset on the output.
    scaled_out = Dense(1, activation='linear', name='scale')(mpnn_out)

    return Model(
        inputs=[graph_idx_in, atom_in, bond_in, connectivity_in],
        outputs=scaled_out,
    )

In [5]:
# Build the network with 256 atom features (build_fn's default is 64);
# message_steps=8 matches the default and is spelled out for clarity.
model = build_fn(atom_features=256, message_steps=8)

In [7]:
# Make one full pass over the training loader and stack element [1] of every
# batch into a single array.
# NOTE(review): element [1] is presumably the IC50 target - confirm against make_data_loader.
ic50s = np.concatenate(
    [batch[1].numpy() for batch in train_loader],
    axis=0,
)

In [8]:
# Initialise the linear 'scale' layer so that it computes std * x + mean of the
# values collected in `ic50s`.
# NOTE(review): the (1, 1) kernel assumes the 'mpnn' output has one feature - confirm.
scale_kernel = np.array([[ic50s.std()]])
scale_bias = np.array([ic50s.mean()])
model.get_layer('scale').set_weights([scale_kernel, scale_bias])

In [6]:
# Learning-rate schedule: lr(step) = 1e-3 / (1 + 0.5 * step / 64).
lr_schedule = InverseTimeDecay(
    initial_learning_rate=1e-3,
    decay_steps=64,
    decay_rate=0.5,
)
# Optimise mean squared error and report mean absolute error alongside it.
model.compile(
    optimizer=Adam(learning_rate=lr_schedule),
    loss='mean_squared_error',
    metrics=['mean_absolute_error'],
)

In [10]:
# Callbacks, in the order Keras will invoke them.
# NOTE(review): patience=128 is larger than epochs=10, so EarlyStopping cannot
# trigger during this run - confirm that is intended.
training_callbacks = [
    cb.ModelCheckpoint('best_model.h5', save_best_only=True),
    cb.EarlyStopping(patience=128, restore_best_weights=True),
    cb.CSVLogger('train_log.csv'),
    cb.TerminateOnNaN(),
]

# NOTE(review): if the loaders are tf.data datasets, Keras ignores `shuffle`;
# the training loader is already created with its own shuffle buffer.
history = model.fit(
    train_loader,
    validation_data=val_loader,
    epochs=10,
    verbose=False,
    shuffle=False,
    callbacks=training_callbacks,
)

  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "
  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "
  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "
  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "
  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "
  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "
  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "
  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "
  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "
  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "
  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "
  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "
  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "
  "Converting sparse IndexedSlices to a dense Tenso

In [9]:
# Reload the weights written by the ModelCheckpoint callback in the training cell
# ('best_model.h5', kept only when the monitored validation loss improves). The
# previous 'model.h5' is not produced by any cell shown in this notebook, so a
# fresh top-to-bottom run would either fail here or export stale weights.
model.load_weights('best_model.h5')

# Export the complete model (architecture + weights) to a directory; a path
# without an .h5 suffix is written in TensorFlow's SavedModel format.
model.save('./saved_models/mpnn_12_4_22')

INFO:tensorflow:Assets written to: ./saved_models/mpnn_12_4_22/assets
