In [3]:
# Uncomment & run me to install Spektral! 
# !pip install spektral



**The QM9 Dataset**

The QM9 dataset is a benchmark dataset of 133,885 small organic molecules, each represented as a graph. Nodes are atoms labeled with one of five types (H, C, N, O, F, i.e. hydrogen, carbon, nitrogen, oxygen, and fluorine), and edges represent the chemical bonds between them.
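To make "atoms as nodes, bonds as edges" concrete, here is a hand-built sketch of methane (CH4), which has one carbon bonded to four hydrogens. The graph is stored as a list of atom types plus a symmetric adjacency matrix with a 1 wherever two atoms share a bond. The `methane_graph` helper is hypothetical and only illustrates the idea; it is not Spektral's loading code.

```python
import numpy as np

def methane_graph():
    """Hypothetical helper: methane (CH4) as atom labels plus an adjacency matrix."""
    # Node 0 is the carbon; nodes 1-4 are hydrogens
    atoms = ["C", "H", "H", "H", "H"]
    n = len(atoms)
    a = np.zeros((n, n), dtype=int)
    # Each hydrogen shares one bond with the carbon; bonds are undirected,
    # so the matrix is filled symmetrically
    for h in range(1, n):
        a[0, h] = a[h, 0] = 1
    return atoms, a

atoms, a = methane_graph()
print(a.sum(axis=1))  # node degrees: [4 1 1 1 1]
```

The carbon has degree 4 and each hydrogen degree 1, matching methane's four single bonds.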

Node features represent our atoms' chemical properties and include: 

- Position of the atom in x, y and z dimensions
- The atomic number of the atom, one-hot encoded
- The atomic charge
- The mass difference from the monoisotope 
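All 10 node features of an atom are packed into one row. Judging from the node feature rows printed further down (five 0/1 columns, then three coordinates, then two zeros), the column layout appears to be a five-way one-hot atom type, the x/y/z position, the charge, and the mass difference, in that order. The sketch below decodes one row under that assumption; `decode_node` and the `ATOM_TYPES` ordering (H, C, N, O, F, as listed above) are assumptions, not Spektral API.

```python
# Assumed column layout of a QM9 node feature row (inferred from printed output)
ATOM_TYPES = ["H", "C", "N", "O", "F"]

def decode_node(row):
    """Hypothetical helper: split a 10-dim feature row into its named parts."""
    one_hot, xyz, charge, mass_diff = row[:5], row[5:8], row[8], row[9]
    # The position of the single 1 in the one-hot block gives the atom type
    atom = ATOM_TYPES[list(one_hot).index(1)]
    return atom, tuple(xyz), charge, mass_diff

# First node feature row from the printout further down
row = [0.0, 1.0, 0.0, 0.0, 0.0, -0.2233, 1.5468, -0.0164, 0.0, 0.0]
print(decode_node(row))  # ('C', (-0.2233, 1.5468, -0.0164), 0.0, 0.0)
```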

In [4]:
# Import and read in our dataset
from spektral.datasets.qm9 import QM9

# Instantiate our dataset
dataset = QM9()

# Check out our features:
print(f"Dataset: {dataset}")
print(f"First graph: {dataset[0]}")
print(f"Number of graphs: {len(dataset)}")

Loading QM9 dataset.
Reading SDF


100%|██████████████████████████████████| 133885/133885 [02:34<00:00, 864.14it/s]


Dataset: QM9(n_graphs=133885)
First graph: Graph(n_nodes=5, n_node_features=10, n_edge_features=4, n_labels=19)
Number of graphs: 133885


In [6]:
# Print the graph-level labels and the first 10 rows of the node feature matrix
# The graph label (y) is a 19-dimensional target vector for regression
print(f"Graph labels: {dataset[0].y}")
# x is the node feature matrix: one row per atom, 10 features per row
print(f"First 10 node feature rows: {dataset[0].x[0:10]}")

Graph labels: [ 3.29932000e+00  1.02987000e+00  9.70160000e-01  5.52060000e+00
  7.93700000e+01 -2.77600000e-01  3.86000000e-02  3.16200000e-01
  1.29977970e+03  1.71140000e-01 -4.03187169e+02 -4.03178075e+02
 -4.03177131e+02 -4.03221255e+02  3.47780000e+01 -1.95050003e+03
 -1.96256891e+03 -1.97383081e+03 -1.81564082e+03]
First 10 node feature rows: [[ 0.      1.      0.      0.      0.     -0.2233  1.5468 -0.0164  0.
   0.    ]
 [ 0.      1.      0.      0.      0.      0.0687  0.0474 -0.0079  0.
   0.    ]
 [ 0.      0.      0.      1.      0.     -1.157  -0.6947  0.0414  0.
   0.    ]
 [ 0.      1.      0.      0.      0.      0.9169 -0.434  -1.2068  0.
   0.    ]
 [ 0.      1.      0.      0.      0.      1.4848 -1.775  -0.7332  0.
   0.    ]
 [ 0.      1.      0.      0.      0.      1.9259 -1.4485  0.7173  0.
   0.    ]
 [ 0.      1.      0.      0.      0.      0.8673 -0.4236  1.2363  0.
   0.    ]
 [ 0.      1.      0.      0.      0.      2.0745 -2.625   1.5671  0.
   0.    ]
 [ 0. 
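The printed graph labels span nearly five orders of magnitude (from about 0.04 up to about 1,974 in absolute value), so with a squared-error loss the large targets would dominate training. A common remedy, not part of the original notebook, is to standardize each target column using statistics from the training split only, then invert the transform on predictions. A minimal NumPy sketch, using a random stand-in for the label matrix:

```python
import numpy as np

def fit_standardizer(y_train):
    """Per-column mean and std, computed on training labels only."""
    mean = y_train.mean(axis=0)
    std = y_train.std(axis=0)
    std[std == 0] = 1.0  # guard against constant columns
    return mean, std

def standardize(y, mean, std):
    return (y - mean) / std

def unstandardize(y_scaled, mean, std):
    return y_scaled * std + mean

# Stand-in for an (n_graphs, 19) label matrix whose columns have very different scales
rng = np.random.default_rng(0)
y_train = rng.normal(size=(1000, 19)) * np.logspace(-2, 3, 19)
mean, std = fit_standardizer(y_train)
y_scaled = standardize(y_train, mean, std)
print(np.allclose(y_scaled.std(axis=0), 1.0))  # True
```

Predictions made on the standardized scale go through `unstandardize` before being compared with the original units.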

We're going to use a crystal convolutional layer, similar to the approach in [this paper](https://arxiv.org/abs/1710.10324), which also uses QM9 as its benchmark dataset. As a first step, the model below starts from a simpler graph convolution (`GCNConv`).
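Spektral's documentation describes its `CrystalConv` layer as computing, for each neighbour j of node i, `x_i' = x_i + sum_j sigmoid(z_ij W_f + b_f) * g(z_ij W_s + b_s)`, where `z_ij = x_i || x_j || e_ji` concatenates the two atoms' features with the bond's features and `g` is an activation. The NumPy sketch below implements that rule directly for a small graph with random weights, assuming softplus for `g`. It illustrates the data flow and is not the library layer's actual code.

```python
import numpy as np

def sigmoid(t):
    return 1.0 / (1.0 + np.exp(-t))

def softplus(t):
    return np.logaddexp(0.0, t)  # log(1 + e^t), numerically stable

def crystal_conv(x, a, e, w_f, b_f, w_s, b_s):
    """Gated neighbour sum; output has the same shape as x.

    x: (n, F) node features, a: (n, n) 0/1 adjacency,
    e: (n, n, S) edge features, with e[j, i] describing the bond from j to i.
    """
    out = x.copy()
    for i in range(x.shape[0]):
        for j in np.flatnonzero(a[i]):
            z = np.concatenate([x[i], x[j], e[j, i]])  # z_ij = x_i || x_j || e_ji
            gate = sigmoid(z @ w_f + b_f)              # how much of the message passes
            msg = softplus(z @ w_s + b_s)              # the message itself
            out[i] += gate * msg
    return out

rng = np.random.default_rng(0)
n, F, S = 5, 10, 4
a = np.zeros((n, n))
a[0, 1:] = a[1:, 0] = 1                 # star graph, shaped like methane
x = rng.normal(size=(n, F))
e = np.zeros((n, n, S))
e[a == 1, 0] = 1.0                      # one-hot "bond type 0" on every edge
w_f = rng.normal(size=(2 * F + S, F))   # random placeholder weights
w_s = rng.normal(size=(2 * F + S, F))
b_f, b_s = np.zeros(F), np.zeros(F)
print(crystal_conv(x, a, e, w_f, b_f, w_s, b_s).shape)  # (5, 10)
```

Unlike a plain GCN layer, the bond features enter every message through `z_ij`, which is what makes this layer a natural fit for molecules with typed bonds.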

The goal is to predict chemical properties of molecules. We'll keep explanations high-level in this walkthrough; the aim is to get comfortable using existing libraries to implement different types of GCNs.

Spektral is built on top of TensorFlow and Keras, which makes it especially intuitive for engineers who already work with these tools. Another great option for implementing Graph Neural Networks is [PyTorch Geometric](https://pytorch-geometric.readthedocs.io/en/latest/index.html), which I've provided an example for [here](https://github.com/sidneyarcidiacono/graph-convolutional-networks/blob/main/notebooks/PROTEINS_Embedding.ipynb) if you're interested in seeing how the implementation differs from Spektral.

In my opinion, Spektral's main advantage is convenience for quick prototyping or proof-of-concept work: it's easy to install with pip, without the need to deal with additional dependencies, GPU setup, and so on.

// later on talk about the implications of being able to do this for material science, chemistry, biology, etc: What's important to take away is the implication of being able to ...

In [7]:
# A quick (unseeded) train/test split: shuffle the graphs in place, then take the first 80% for training
import numpy as np

np.random.shuffle(dataset)
split = int(0.8 * len(dataset))
data_train, data_test = dataset[:split], dataset[split:]
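`np.random.shuffle` reorders the dataset in place and is unseeded, so the split changes on every run. A reproducible alternative, sketched below, shuffles an index array with a seeded generator and leaves the dataset itself untouched. The `train_test_indices` helper is hypothetical, and using the indices assumes the dataset can be indexed with an integer array, as Spektral's own examples appear to do.

```python
import numpy as np

def train_test_indices(n_items, train_frac=0.8, seed=42):
    """Hypothetical helper: reproducible, shuffled train/test index arrays."""
    rng = np.random.default_rng(seed)
    idx = rng.permutation(n_items)
    split = int(train_frac * n_items)
    return idx[:split], idx[split:]

idx_train, idx_test = train_test_indices(133885)
print(len(idx_train), len(idx_test))  # 107108 26777
# Assumed usage with a Spektral dataset:
# data_train, data_test = dataset[idx_train], dataset[idx_test]
```

The sizes match the 80/20 split above; the difference is that rerunning the notebook with the same seed gives the same partition.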

In [92]:
# Spektral is built on top of Keras, so we can use standard Keras tools (here, model subclassing) to build a model that first embeds
# the nodes with a graph convolution, then sums them together (global pooling), then maps the result to our labels with a dense layer

# First, let's import the necessary layers:
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Dropout
from spektral.layers import GCNConv, GlobalSumPool
import tensorflow as tf
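Before wiring up Keras, it can help to see the computation described in the comment above. The NumPy sketch below follows the GCN propagation rule `ReLU(D^-1/2 (A + I) D^-1/2 X W)`, sums the node embeddings into one graph vector, and applies a linear layer that outputs one value per target. Weights are random placeholders, the ReLU is an assumed activation, and dropout is omitted because it only acts during training. Spektral's `GCNConv` documentation describes its adjacency input as this already-normalized matrix rather than the raw adjacency, so the normalization is a preprocessing step worth checking for your Spektral version (e.g. a `LayerPreprocess(GCNConv)` transform in Spektral 1.x).

```python
import numpy as np

def normalized_adjacency(a):
    """Symmetric GCN normalization: D^-1/2 (A + I) D^-1/2."""
    a_hat = a + np.eye(a.shape[0])             # add self-loops
    d_inv_sqrt = 1.0 / np.sqrt(a_hat.sum(axis=1))
    return a_hat * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]

def gcn_layer(x, a_norm, w):
    """One graph convolution, A_norm X W, followed by a ReLU (assumed activation)."""
    return np.maximum(a_norm @ x @ w, 0.0)

def predict(x, a, w_conv, w_out, b_out):
    h = gcn_layer(x, normalized_adjacency(a), w_conv)  # node embeddings, (n_nodes, n_hidden)
    g = h.sum(axis=0)                                  # global sum pooling, (n_hidden,)
    return g @ w_out + b_out                           # linear regression head, (n_labels,)

rng = np.random.default_rng(0)
n_nodes, n_features, n_hidden, n_labels = 5, 10, 32, 19
a = np.zeros((n_nodes, n_nodes))
a[0, 1:] = a[1:, 0] = 1                               # star graph, shaped like methane
x = rng.normal(size=(n_nodes, n_features))
w_conv = rng.normal(size=(n_features, n_hidden))      # random placeholder weights
w_out = rng.normal(size=(n_hidden, n_labels))
b_out = np.zeros(n_labels)
y_hat = predict(x, a, w_conv, w_out, b_out)
print(y_hat.shape)  # (19,)
```

Because the node embeddings are summed, reordering the atoms leaves the prediction unchanged, which is what makes sum pooling a natural graph-level readout.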

In [123]:
# Now, we can use model subclassing to define our model:

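class MoleculePredictor(Model):

    def __init__(self, n_hidden, n_labels):
        super().__init__()
        # Graph convolution layer with n_hidden output channels
        self.graph_conv = GCNConv(n_hidden)
        # Dropout layer with rate 0.5 (drops 50% of activations during training)
        self.dropout = Dropout(0.5)
        # Dense output layer: one linear unit per regression target
        self.dense = Dense(n_labels)

    def call(self, inputs, mask=None):
        # With edge features present, BatchLoader passes three arrays as inputs:
        # node features, adjacency matrices, and edge features (labels come separately)
        x, a, e = inputs
        print(f"Shape of X: {x.shape}")

        # Update node features
        # NOTE: on a Keras Model, self.weights is the list of the model's variables
        # (still empty here), and propagate() belongs to Spektral's message-passing
        # layers rather than to Model, so this line raises the error shown below
        x = tf.matmul(x, self.weights)

        return self.propagate(x=x, a=a)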

In [130]:
# Instantiate our model for training
model = MoleculePredictor(32, dataset.n_labels)

In [129]:
# Compile the model with the Adam optimizer and mean squared error loss (our 19 targets are continuous)
model.compile('adam', 'mse')

In [126]:
# Here's the trick: we can't pass the dataset straight to Keras' fit() method.
# Instead, we use a Loader, as the Spektral getting-started tutorial shows. Loaders create mini-batches by iterating over the graphs in the dataset
# Since we're using Spektral for an experiment, our first trial uses the loader recommended in that tutorial

# TODO: read up on modes and try other loaders later
from spektral.data import BatchLoader

loader = BatchLoader(data_train, batch_size=32)
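Conceptually, a batch-mode loader turns graphs of different sizes into dense, zero-padded arrays and yields them over and over. The pure-NumPy generator below is a hypothetical stand-in written to illustrate that idea, not Spektral's implementation: it pads each batch to its largest graph, reshuffles every pass, and never stops, which is why Keras needs `steps_per_epoch` (here `ceil(n_graphs / batch_size)`) to know where an epoch ends.

```python
import math
import numpy as np

def batch_generator(graphs, batch_size, seed=0):
    """Hypothetical loader: endlessly yields ((x, a), y) zero-padded batches.

    graphs: list of (x, a, y) with x (n_i, F), a (n_i, n_i), y (n_labels,).
    """
    rng = np.random.default_rng(seed)
    while True:  # loop forever, like a loader with no epoch limit
        order = rng.permutation(len(graphs))
        for start in range(0, len(graphs), batch_size):
            batch = [graphs[k] for k in order[start:start + batch_size]]
            n_max = max(x.shape[0] for x, _, _ in batch)  # pad to largest graph in batch
            n_feat = batch[0][0].shape[1]
            xs = np.zeros((len(batch), n_max, n_feat))
            adj = np.zeros((len(batch), n_max, n_max))
            for b, (x, a, _) in enumerate(batch):
                n = x.shape[0]
                xs[b, :n] = x
                adj[b, :n, :n] = a
            ys = np.stack([y for _, _, y in batch])
            yield (xs, adj), ys

def steps_per_epoch(n_graphs, batch_size):
    return math.ceil(n_graphs / batch_size)

print(steps_per_epoch(107108, 32))  # 3348
```

With the 107,108 training graphs from the 80/20 split and a batch size of 32, one epoch is 3,348 batches, the last one partially filled.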

In [127]:
# Now we can train! We don't pass a batch size here, since the loader already yields batches of 32
# We do need steps_per_epoch: the loader's generator repeats indefinitely, so Keras needs to know how many batches make up one epoch

model.fit(loader.load(), steps_per_epoch=loader.steps_per_epoch, epochs=10)

Shape of X: (32, 29, 10)


InvalidArgumentError: In[1] ndims must be >= 2: 1 [Op:BatchMatMulV2]
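The error most likely comes from `tf.matmul(x, self.weights)`. On a Keras `Model`, `self.weights` is the list of the model's variables, which is still empty because none of the layers created in `__init__` have been called yet, so TensorFlow receives a rank-1, empty second operand (hence "ndims must be >= 2: 1"). `self.propagate` is likewise a method of Spektral's message-passing layers, not of `Model`. The usual fix, following the pattern of Spektral's getting-started model, is to call the layers built in `__init__` instead: roughly `out = self.graph_conv([x, a])`, then `self.dropout(out)`, a `GlobalSumPool()` layer (imported above but never instantiated), and finally `self.dense(out)`. The NumPy snippet below reproduces the shape problem with the batch shape printed above and shows what a proper `(10, n_hidden)` kernel would produce; NumPy words the error differently because it accepts 1-D operands, but the root cause is the same.

```python
import numpy as np

batch, n_max, n_features, n_hidden = 32, 29, 10, 32
x = np.zeros((batch, n_max, n_features))  # same shape as "Shape of X" above

# What the broken call() effectively did: multiply by an empty weight list
empty_weights = np.asarray([])
try:
    np.matmul(x, empty_weights)
except ValueError as err:
    print("fails:", type(err).__name__)   # fails: ValueError

# What a dense kernel does instead: (batch, n_max, F) @ (F, n_hidden)
kernel = np.ones((n_features, n_hidden))
print(np.matmul(x, kernel).shape)         # (32, 29, 32)
```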