In [1]:
import json
import numpy as np
import os
import pandas as pd
import pysmiles
from tqdm.auto import tqdm

In [2]:
data_dir = "data/dsgdb9nsd"
limit = 1000
output_fn = "data/molecules/mol_{id}.json"

In [3]:
def process_files(data_dir="dsgdb9nsd", limit=None, output_fn="data/mol_{id}.json", verbose=False):
    files = os.listdir(data_dir)
    if limit is not None:
        files = files[:limit]
        name = output_fn.split(".")[0]
        output_fn = f"{name}_{limit}.json"
    for file in files[:limit]:
        with open(os.path.join(data_dir, file), "r") as fp:
            lines = [l.replace("\n", "").split("\t") for l in fp.readlines()]
        na = int(lines[0][0])
        molecule = {
            "na": na,
            "molecule_props": dict(zip(
                ["id", "rotational_a", "rotational_b", "rotational_c", "dipole_moment", "polarizability", "homo_energy", "lumo_energy", "spatial_extent", "internal_energy_0k", "internal_energy_298k", "free_energy", "heat_capacity"],
                [float(e.replace("gdb", "").strip()) for e in lines[1][:-1]]
            )),
            "coordinates": [
                [c[0], float(c[1].replace("*^", "e")), float(c[2].replace("*^", "e")), float(c[3].replace("*^", "e")), float(c[4].replace("*^", "e"))]
                for c in lines[2:(na+2)]
            ],
            "frequencies": lines[na+2],
            "smiles": lines[na+3][0],
            "inchi": lines[na+4]
        }
        mol_graph = pysmiles.read_smiles(molecule["smiles"], explicit_hydrogen=True)
        edges = list(mol_graph.edges(data=True))
        atoms = list(mol_graph.nodes(data=True))
        graph = {}
        coordinates = molecule["coordinates"]
        # Edge features
        graph["edge_sources"] = []
        graph["edge_targets"] = []
        graph["edge_distance"] = []
        graph["edge_order_1"] = []
        graph["edge_order_1_5"] = []
        graph["edge_order_2"] = []
        graph["edge_order_3"] = []
        for edge in edges:  # Two edges representing both message passing directions
            start = np.array(coordinates[edge[0]][1:])
            end = np.array(coordinates[edge[1]][1:])
            distance = np.sqrt(np.sum(np.square(end - start)))
            graph["edge_sources"].extend([edge[0], edge[1]])
            graph["edge_targets"].extend([edge[1], edge[0]])
            graph["edge_distance"].extend([distance, distance])
            edge_order = edge[2]["order"]
            graph["edge_order_1"].extend([int(edge_order == 1), int(edge_order == 1)])
            graph["edge_order_1_5"].extend([int(edge_order == 1.5), int(edge_order == 1.5)])
            graph["edge_order_2"].extend([int(edge_order == 2), int(edge_order == 2)])
            graph["edge_order_3"].extend([int(edge_order == 3), int(edge_order == 3)])
        # Node features
        graph["node_x"] = [float(coord[1]) for coord in molecule["coordinates"]]
        graph["node_y"] = [float(coord[2]) for coord in molecule["coordinates"]]
        graph["node_z"] = [float(coord[3]) for coord in molecule["coordinates"]]
        graph["node_e"] = [float(coord[4]) for coord in molecule["coordinates"]]
        graph["node_element_h"] = [1 if atom[1]["element"] == "H" else 0 for atom in atoms]
        graph["node_element_c"] = [1 if atom[1]["element"] == "C" else 0 for atom in atoms]
        graph["node_element_n"] = [1 if atom[1]["element"] == "N" else 0 for atom in atoms]
        graph["node_element_o"] = [1 if atom[1]["element"] == "O" else 0 for atom in atoms]
        graph["node_element_f"] = [1 if atom[1]["element"] == "F" else 0 for atom in atoms]
        graph["node_charge"] = [atom[1]["charge"] for atom in atoms]
        graph["node_aromatic"] = [1 if atom[1]["aromatic"] else 0 for atom in atoms]
        graph["molecule_props"] = molecule["molecule_props"]
        fn = output_fn.format(id=str(molecule["molecule_props"]["id"]).replace(".", "_"))
        with open(fn, "w") as fp:
            json.dump(graph, fp)
        if verbose:
            print(f"Processed file {file} and saved to {fn}.")

In [4]:
process_files(data_dir=data_dir, limit=limit, output_fn=output_fn)

### Tensorflow Datset

In [5]:
import tensorflow as tf
from typing import NamedTuple

class GraphInput(NamedTuple):
    """Input named tuple for the MessagePassingNet."""
    node_features: tf.Tensor
    edge_features: tf.Tensor
    edge_sources: tf.Tensor
    edge_targets: tf.Tensor

In [6]:
mol_dir = "data/molecules"
filename = os.listdir(mol_dir)[0]
target = "dipole_moment"

graph = json.load(open(os.path.join(mol_dir, filename)))
y = graph["molecule_props"][target]

print(f"Target: {target}={y}")
for key, value in graph.items():
    print(f"- {key:15} length: {len(value)}")

Target: dipole_moment=2.532
- edge_sources    length: 44
- edge_targets    length: 44
- edge_distance   length: 44
- edge_order_1    length: 44
- edge_order_1_5  length: 44
- edge_order_2    length: 44
- edge_order_3    length: 44
- node_x          length: 21
- node_y          length: 21
- node_z          length: 21
- node_e          length: 21
- node_element_h  length: 21
- node_element_c  length: 21
- node_element_n  length: 21
- node_element_o  length: 21
- node_element_f  length: 21
- node_charge     length: 21
- node_aromatic   length: 21
- molecule_props  length: 13
