In [1]:
import json
import os
import pickle
from datetime import datetime, timezone
from typing import Dict

import numpy as np
import tensorflow as tf

import mpnn

# Start from a clean Keras state and seed every RNG the notebook draws from,
# so the data shuffle and the model's weight initialisation are repeatable.
SEED = 42

tf.keras.backend.clear_session()
np.random.seed(SEED)
# Keras weight initialisers draw from TensorFlow's RNG, not NumPy's, so seeding
# NumPy alone leaves the initial model weights different on every run.
tf.random.set_seed(SEED)

In [2]:
# Peek at one sample molecule to learn how many features each node and each
# edge carries; those widths define the model's input signature below.
data_fn = "data/molecules/mol_100000_0.json"
with open(data_fn, "r") as fp:
    graph = json.load(fp)

node_rows = [list(attrs.values()) for attrs in graph["node_list"].values()]
edge_rows = [list(graph["edge_list"][edge_key].values()) for edge_key in sorted(graph["edge_list"])]
num_node_features = tf.convert_to_tensor(node_rows).shape[1]
num_edge_features = tf.convert_to_tensor(edge_rows).shape[1]

# The leading dimension is left as None so graphs with any number of
# nodes/edges fit the same signature; only the feature width is fixed.
dataShape = mpnn.GraphInputShapes(
    node_features=tf.TensorShape([None, num_node_features]),
    edge_features=tf.TensorShape([None, num_edge_features]),
)
    
def get_data(molecules_fn):
    """Yield ``(GraphInput, target)`` pairs, one per molecule JSON file.

    Each file is expected to contain an object with the keys ``"node_list"``,
    ``"edge_list"``, ``"adjacency"`` and ``"molecule_props"``.

    Args:
        molecules_fn: Iterable of paths to molecule JSON files. Files are read
            lazily, one per ``next()`` on the returned generator.

    Yields:
        ``(data, target)`` where ``data`` is an ``mpnn.GraphInput`` with:

        - ``node_features``: each node's attribute values, in ``node_list``
          iteration order;
        - ``edge_features``: each edge's attribute values, ordered by the
          sorted ``edge_list`` keys;
        - ``adjacency_lists``: one tensor per ``adjacency`` entry;
        - ``edge_to_idx``: maps every ``edge_list`` key to its row in
          ``edge_features``.

        ``target`` is a 1-element tensor holding ``molecule_props["dipole_moment"]``.
    """
    for fn in molecules_fn:
        # JSON is UTF-8 by spec; don't depend on the platform default encoding.
        with open(fn, "r", encoding="utf-8") as fp:
            graph = json.load(fp)

        node_list = graph["node_list"]
        edge_list = graph["edge_list"]

        # Sort the edge keys once. Edge feature rows follow this order, and
        # edge_to_idx must index into the same order. A position dict replaces
        # a per-key list.index() on a re-sorted list, which was quadratic in
        # the number of edges.
        sorted_edge_keys = sorted(edge_list)
        row_of_edge = {key: row for row, key in enumerate(sorted_edge_keys)}

        data = mpnn.GraphInput(
            node_features=tf.convert_to_tensor(
                [list(node.values()) for node in node_list.values()]
            ),
            edge_features=tf.convert_to_tensor(
                [list(edge_list[key].values()) for key in sorted_edge_keys]
            ),
            adjacency_lists=tuple(
                tf.convert_to_tensor(adj) for adj in graph["adjacency"].values()
            ),
            # Iterate edge_list itself so the mapping keeps the file's key order.
            edge_to_idx={key: row_of_edge[key] for key in edge_list},
        )
        target = tf.constant([graph["molecule_props"]["dipole_moment"]])
        yield data, target

# Show the inferred input signature (per-node and per-edge feature widths) as this cell's output.
dataShape

GraphInputShapes(node_features=TensorShape([None, 11]), edge_features=TensorShape([None, 5]))

In [3]:
# Build the message-passing model: MSE loss on the regression target, plus
# mean absolute percentage error as an extra reported metric.
model = mpnn.MessagePassingNet()
model.compile(
    optimizer="adam",
    loss="mse",
    # Keras documents `metrics` as a list of metrics; pass one explicitly
    # rather than relying on a bare string being wrapped.
    # NOTE(review): Keras' MAPE divides by |y_true| (clipped at epsilon); if
    # some dipole moments are ~0 this metric can explode — consider adding MAE.
    metrics=["mean_absolute_percentage_error"],
)
# Create the weights up front from the per-field input shapes in dataShape.
model.build(dataShape)

In [4]:
# Split the molecule files into a held-out validation set and a training set.
data_dir = "data/molecules/"
VALIDATION_SIZE = 10_000  # number of molecules held out for evaluation
SPLIT_SEED = 42

# os.listdir order is filesystem-dependent, so sort before shuffling; otherwise
# the same seed can give different splits on different machines. Only .json
# files are kept because get_data parses every path as JSON.
molecules_fn = np.array(sorted(
    os.path.join(data_dir, name)
    for name in os.listdir(data_dir)
    if name.endswith(".json")
))
# A dedicated, locally seeded generator keeps this cell idempotent: re-running
# it reproduces the same split instead of advancing the global NumPy RNG.
np.random.default_rng(SPLIT_SEED).shuffle(molecules_fn)

assert len(molecules_fn) > VALIDATION_SIZE, (
    f"need more than {VALIDATION_SIZE} molecule files in {data_dir}, found {len(molecules_fn)}"
)
validation = molecules_fn[:VALIDATION_SIZE].tolist()
train = molecules_fn[VALIDATION_SIZE:].tolist()

In [None]:
# The generator reads one molecule file per step and is exhausted after a single
# pass, so this trains for exactly one epoch (Keras' default).
# Despite its name, `loss` holds the History object returned by fit.
train_stream = get_data(train)
loss = model.fit(x=train_stream)

In [None]:
# Score the held-out molecules; `metric` is whatever Keras evaluate returns
# for the compiled loss and metrics.
validation_stream = get_data(validation)
metric = model.evaluate(x=validation_stream)

In [None]:
# Persist the training curves and validation metrics for later comparison.
# Keras' History callback keeps a reference to the model (heavy and not
# reliably picklable), so store only its plain per-epoch `history` dict.
history_dict = getattr(loss, "history", loss)

# NOTE: the filename has hour resolution, so runs within the same UTC hour
# overwrite each other.
results_path = os.path.join(
    "data", f"{datetime.now(timezone.utc).strftime('%Y%m%dT%H')}.pkl"
)
with open(results_path, "wb") as fp:
    pickle.dump([history_dict, metric], fp)