In [1]:
import json
import os
import pickle
import numpy as np
import tensorflow as tf
import mpnn
from datetime import datetime
from time import time
from typing import Dict

# Drop any Keras state left over from a previous run of this notebook in the
# same kernel, so re-running from the top starts from a clean slate.
tf.keras.backend.clear_session()
# Seed every random source the notebook uses. TensorFlow keeps its own global
# seed, so seeding NumPy alone leaves TensorFlow's random ops unseeded and the
# run non-reproducible.
np.random.seed(42)
tf.random.set_seed(42)

In [2]:
def get_data(molecules_fn, target="dipole_moment"):
    """Lazily yield one ``(GraphInput, label)`` pair per molecule JSON file.

    Each file is parsed as a JSON object. Its ``"edge_sources"`` and
    ``"edge_targets"`` entries become the index tensors; every remaining key
    containing ``"node_"`` (resp. ``"edge_"``) contributes one column to the
    node (resp. edge) feature matrix, in the order the keys appear in the file.

    Args:
        molecules_fn: Iterable of paths to molecule JSON files.
        target: Key of the JSON entry to use as the regression label.

    Yields:
        A tuple ``(data, y)`` where ``data`` is an ``mpnn.GraphInput`` and
        ``y`` is a one-element tensor holding the value stored under ``target``.

    Raises:
        KeyError: If a file lacks ``"edge_sources"``, ``"edge_targets"`` or
            the ``target`` key.
    """

    def _stack_columns(record, marker):
        # One float tensor per key containing `marker`, stacked along axis 1
        # so that each matching key becomes a column.
        columns = []
        for name, column in record.items():
            if marker in name:
                columns.append(tf.constant(column, dtype=float))
        return tf.stack(columns, axis=1)

    for path in molecules_fn:
        with open(path, "r") as handle:
            record = json.load(handle)

        # Pop the index lists before building the feature matrices: their keys
        # contain "edge_" and would otherwise be stacked as edge features.
        sources = tf.constant(record.pop("edge_sources"))
        targets = tf.constant(record.pop("edge_targets"))

        node_matrix = _stack_columns(record, "node_")
        edge_matrix = _stack_columns(record, "edge_")

        graph_input = mpnn.GraphInput(
            node_features=node_matrix,
            edge_features=edge_matrix,
            edge_sources=sources,
            edge_targets=targets,
        )
        label = tf.constant([record[target]])
        yield graph_input, label

In [3]:
training_dir = "data/training"
test_dir = "data/test"
target = "dipole_moment"
# Widths of the node / edge feature matrices declared to tf.data below; they
# must equal the number of "node_*" / "edge_*" columns that get_data stacks.
num_node_features = 11
num_edge_features = 5


def _list_molecule_files(directory):
    """Return the sorted paths of all non-hidden entries in `directory`."""
    # os.listdir returns entries in arbitrary order; sorting makes the sample
    # order (and therefore training) the same on every machine and run.
    return sorted(
        os.path.join(directory, fn) for fn in os.listdir(directory) if not fn.startswith(".")
    )


def _make_dataset(filenames, target):
    """Wrap `get_data(filenames, target)` in a `tf.data.Dataset`.

    Each element is `((node_features, edge_features, edge_sources,
    edge_targets), y)`. The leading dimension of every graph tensor is left
    as None because it varies from one molecule to the next.
    """
    return tf.data.Dataset.from_generator(
        # from_generator needs a zero-argument callable so it can restart the
        # generator for every pass over the dataset.
        lambda: get_data(filenames, target),
        output_types=((tf.float32, tf.float32, tf.int32, tf.int32), tf.float32),
        output_shapes=(
            (tf.TensorShape([None, num_node_features]), tf.TensorShape([None, num_edge_features]),
             tf.TensorShape([None]), tf.TensorShape([None])),
            tf.TensorShape([1])
        )
    )


training_fn = _list_molecule_files(training_dir)
test_fn = _list_molecule_files(test_dir)

training = _make_dataset(training_fn, target)
test = _make_dataset(test_fn, target)

In [4]:
# Instantiate the message-passing network with no constructor arguments, i.e.
# with whatever defaults `mpnn.MessagePassingNet` defines.
model = mpnn.MessagePassingNet()

In [5]:
# Training objective and the extra metric reported alongside it.
# NOTE(review): MAPE divides the error by |y_true|, so targets at or near zero
# can dominate the loss - confirm this suits the chosen regression target.
loss = tf.keras.losses.MeanAbsolutePercentageError()
metrics = [tf.keras.metrics.MeanSquaredError()]

# Learning rate decays linearly (power=1.0) from 1.935e-4 towards 1.84e-4
# over 900 optimizer steps.
learning_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=1.935e-4,
    decay_steps=900,
    end_learning_rate=1.84e-4,
    power=1.0,
)
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_schedule)

model.compile(optimizer=optimizer, loss=loss, metrics=metrics)

In [None]:
# Train for one epoch over the training dataset and time it.
start = time()
# NOTE: this rebinds `loss`, which until now held the loss function that was
# passed to `model.compile`; from here on it is whatever `model.fit` returned.
loss = model.fit(training, epochs=1)
# Last expression of the cell: elapsed wall-clock seconds, shown as the output.
time() - start

In [None]:
# `loss` holds the return value of `model.fit`. Stock Keras `fit` returns a
# History object, which cannot be indexed directly; its per-epoch loss values
# live in `.history["loss"]`. Plain indexing is kept as a fallback in case
# `mpnn` overrides `fit` to return a sequence of losses.
epoch_losses = loss.history["loss"] if hasattr(loss, "history") else loss
# Last expression of the cell: the loss of the final epoch.
epoch_losses[-1]

In [None]:
# Evaluate the trained model on the held-out test dataset and time it.
start = time()
metric = model.evaluate(test)
# Last expression of the cell: elapsed wall-clock seconds, shown as the output.
time() - start

In [None]:
# Display what `model.evaluate` returned for the test dataset.
# NOTE(review): with one compiled loss and one metric this is presumably
# [MAPE loss, mean squared error] - confirm against the Keras version in use.
metric