# QM9 Modeling

QM9 is a dataset of 134,000 molecules consisting of 9 heavy atoms drawn from the elements C, H, O, N, F{cite}`ramakrishnan2014quantum`. The features are the xyz coordinates ($\mathbf{X}$) and elements ($\vec{e}$) of the molecule. The coordinates are determined from B3LYP/6-31G(2df,p) level DFT geometry optimization. There are multiple labels (see table below), but we'll be interested specifically in the energy of formation (Enthalpy at 298.15 K). The goal in this chapter is to see how we can model this data.


QM9 is one of the most popular dataset for machine learning and deep learning since it came out in 2014. The first papers could achieve about 10 kcal/mol on this regression problem and now are down to ~1 kcal/mol and lower. Any model on this dataset must be translation, rotation, and permutation invariant. 

## Label Description

|Index | Name | Units | Description|
 |:-----|-------|-------|-----------:|
  |0  |index  |   -            |Consecutive, 1-based integer identifier of molecule|
  |1  |A      |   GHz          |Rotational constant A|
  |2  |B      |   GHz          |Rotational constant B|
  |3  |C      |   GHz          |Rotational constant C|
  |4  |mu     |   Debye        |Dipole moment|
  |5  |alpha  |   Bohr^3       |Isotropic polarizability|
  |6  |homo   |   Hartree      |Energy of Highest occupied molecular orbital (HOMO)|
  |7  |lumo   |   Hartree      |Energy of Lowest occupied molecular orbital (LUMO)|
 |8 | gap   |    Hartree     | Gap, difference between LUMO and HOMO|
 |9 | r2    |    Bohr^2      | Electronic spatial extent|
 |10 | zpve  |    Hartree     | Zero point vibrational energy|
 |11 | U0    |    Hartree     | Internal energy at 0 K|
 |12 | U     |    Hartree     | Internal energy at 298.15 K|
 |13 | H     |    Hartree     | Enthalpy at 298.15 K|
 |14 | G     |    Hartree     | Free energy at 298.15 K|
 |15 | Cv    |    cal/(mol K) | Heat capacity at 298.15 K|


## Data

I have written some helper code in the `fetch_qm9.py` file. It downloads the data and converts into a format easily used in Python. The data returned from this function is broken into the features $\mathbf{X}$ and $\vec{e}$. $\mathbf{X}$ is an $N\times4$ matrix of atom positions + partial charge of the atom. $\vec{e}$ is vector of atomic numbers for each atom in the molecule. Remember to slice the specific label you want from the label vector.

## Running This Notebook


Click the &nbsp;<i aria-label="Launch interactive content" class="fas fa-rocket"></i>&nbsp; above to launch this page as an interactive Google Colab. See details below on installing packages, either on your own environment or on Google Colab

````{tip} My title
:class: dropdown
To install packages, execute this code in a new cell

```
!pip install matplotlib numpy pandas seaborn tensorflow jax jaxlib
```

````

In [1]:
import tensorflow as tf
import numpy as np
np.random.seed(0)
from fetch_qm9 import fetch_qm9, get_qm9

Let's load the data. This step will take a few minutes as it is downloaded and processed. 

In [2]:
qm9_records = fetch_qm9()
data = get_qm9(qm9_records)

Found existing record file, delete if you want to re-fetch


`data` is an iterable containing the data for the 133k molecules. Let's examine the first one.

In [3]:
for d in data:
    print(d)
    break

((<tf.Tensor: shape=(5,), dtype=int64, numpy=array([6, 1, 1, 1, 1])>, <tf.Tensor: shape=(5, 4), dtype=float32, numpy=
array([[-1.2698136e-02,  1.0858041e+00,  8.0009960e-03, -5.3568900e-01],
       [ 2.1504159e-03, -6.0313176e-03,  1.9761203e-03,  1.3392100e-01],
       [ 1.0117308e+00,  1.4637512e+00,  2.7657481e-04,  1.3392200e-01],
       [-5.4081506e-01,  1.4475266e+00, -8.7664372e-01,  1.3392299e-01],
       [-5.2381361e-01,  1.4379326e+00,  9.0639728e-01,  1.3392299e-01]],
      dtype=float32)>), <tf.Tensor: shape=(16,), dtype=float32, numpy=
array([ 1.0000000e+00,  1.5771181e+02,  1.5770998e+02,  1.5770699e+02,
        0.0000000e+00,  1.3210000e+01, -3.8769999e-01,  1.1710000e-01,
        5.0480002e-01,  3.5364101e+01,  4.4748999e-02, -4.0478931e+01,
       -4.0476063e+01, -4.0475117e+01, -4.0498596e+01,  6.4689999e+00],
      dtype=float32)>)


These are Tensorflow Tensors. They can be converted to numpy arrays via `x.numpy()`. The first item is the element vector `6,1,1,1,1`. Do you recognize the elements? It's C, H, H, H, H. The positions come next. Note that the there is an extra column containing the atom partial charges, which we will not use as a feature. Finally, the last tensor is the label vector. 

Now we will do some processing of the data to get into a more usable format. Let's convert to numpy arrays, remove the partial charges, and convert the elements into one-hot vectors.

In [4]:
def convert_record(d):
    # break up record
    (e, x), y = d
    # 
    e = e.numpy()
    x = x.numpy()
    r = x[:, :3]    
    # use nearest power of 2 (16)
    ohc = np.zeros((len(e), 16))
    ohc[np.arange(len(e)), e - 1] = 1    
    return (ohc, r), y.numpy()[13]

for d in data:
    (e,x), y = convert_record(d)
    print('Element one hots\n', e)
    print('Coordinates\n', x)
    print('Label:', y)
    break

Element one hots
 [[0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]
Coordinates
 [[-1.2698136e-02  1.0858041e+00  8.0009960e-03]
 [ 2.1504159e-03 -6.0313176e-03  1.9761203e-03]
 [ 1.0117308e+00  1.4637512e+00  2.7657481e-04]
 [-5.4081506e-01  1.4475266e+00 -8.7664372e-01]
 [-5.2381361e-01  1.4379326e+00  9.0639728e-01]]
Label: -40.475117


## Example GNN Model

We now can work with this data to build a model. Let's build a simple model that can model energy and obeys the invariances required of the problem. We will use a graph neural network (GNN) because it obeys permutation invariance. We will create a *graph* from the coordinates/element vector by joining all atoms to all other atoms and using their inverse pairwise distance as the edge weight. The choice of pairwise distance gives us translation and rotation invariance. The choice of inverse distance means that atoms which are far away naturally have low edge weights.

I will now define our model using the Battaglia equations {cite}`battaglia2018relational`. As opposed to most examples we've seen in class, I will use the graph level feature vector $\vec{u}$ which will ultimately be our estimate of energy. The edge update will only consider the sender and the edge weight with trainable parameters:

\begin{equation}
      \vec{e}^{'}_k = \phi^e\left( \vec{e}_k, \vec{v}_{rk}, \vec{v}_{sk}, \vec{u}\right) = \sigma\left(\vec{v}_{sk}\vec{w}_ee_k\right)
\end{equation}

where the input edge $e_k$ will be a single number (inverse pairwise distance). We will use a sum aggregation for edges (not shown). The node update will be 

\begin{equation}
   \vec{v}^{'}_i = \phi^v\left( \bar{e}^{'}_i, \vec{v}_i, \vec{u}\right) = \sigma\left(\mathbf{W}_v \bar{e}^{'}_i\right) + \vec{v}_i
\end{equation}

The global node aggregation will also be a sum. Finally we have our graph feature vector update:

\begin{equation}
    \vec{u}^{'} = \phi^u\left( \bar{e}^{'},\bar{v}^{'}, \vec{u}\right) = \sigma\left(\mathbf{W}_u\bar{v}^{'}\right) + \vec{u}
\end{equation}


To compute the final energy, we'll use our regression equation:

\begin{equation}
    \hat{E} = \vec{w}\cdot \vec{u} + b
\end{equation}

One final detail is that we will pass on $\vec{u}$ and the nodes, but we will keep the edges the same at each GNN layer. Remember this is an example model: there are many changes that could be made to the above. Also, it is not kernel learning which is the favorite for this domain. Let's implement it though and see if it works. 

### JAX Model Implementation

In [5]:
import jax.numpy as jnp
import jax.experimental.optimizers as optimizers
import jax
import warnings
warnings.filterwarnings('ignore')

def x2e(x):
    '''convert xyz coordinates to inverse pairwise distance'''    
    r2 = jnp.sum((x - x[:, jnp.newaxis, :])**2, axis=-1)
    e = jnp.where(r2 != 0, 1 / r2, 0.)
    return e

def gnn_layer(nodes, edges, features, we, wv, wu):
    '''Implementation of the GNN'''
    # make nodes be N x N so we can just multiply directly
    ek = jax.nn.relu(
        jnp.repeat(nodes[jnp.newaxis,...], nodes.shape[0], axis=0) @ we * edges[...,jnp.newaxis])
    ebar = jnp.sum(ek, axis=1)
    new_nodes = jax.nn.relu(ebar @ wv) + nodes
    global_node_features = jnp.sum(new_nodes, axis=0)
    new_features = jax.nn.relu(global_node_features  @ wu) + features    
    return new_nodes, edges, new_features


We have implemented the code to convert coordinates into inverse pairwise distance and the GNN equations above. Let's test them out.

In [67]:
graph_feature_len = 8
node_feature_len = 16
msg_feature_len = 16

# make our weights
def init_weights(g, n, m):
    we = np.random.normal(size=(n, m), scale=1e-1)
    wv = np.random.normal(size=(m, n), scale=1e-1)
    wu = np.random.normal(size=(n, g), scale=1e-1)
    return we, wv, wu

# make a graph
nodes = e
edges = x2e(x)
features = jnp.zeros(graph_feature_len)

# eval
out = gnn_layer(nodes, edges, features, *init_weights(graph_feature_len, node_feature_len, msg_feature_len))
print('input feautres', features)
print('output features', out[2])

input feautres [0. 0. 0. 0. 0. 0. 0. 0.]
output features [0.73143464 0.         0.63617605 1.041506   2.0111923  1.0279881
 0.         0.7413224 ]


Great! Our model can update the graph features. Now we need to convert this into callable and loss. We'll stack two GNN layers.

In [72]:
# get weights for both layers
w1 = init_weights(graph_feature_len, node_feature_len, msg_feature_len)
w2 = init_weights(graph_feature_len, node_feature_len, msg_feature_len)
w3 = np.random.normal(size=(graph_feature_len))
b = -325. # starting guess

@jax.jit
def model(nodes, coords, w1, w2, w3, b):
    f0 = jnp.zeros(graph_feature_len)
    e0 = x2e(coords)
    n0 = nodes
    n,e,f = gnn_layer(n0, e0, f0, *w1)
    n,e,f = gnn_layer(n, e, f, *w2)
    yhat = f @ w3 + b
    return yhat

def loss(nodes, coords, y, w1, w2, w3, b):
    return (model(nodes, coords, w1, w2, w3, b) - y)**2
loss_grad = jax.grad(loss, (3, 4, 5, 6))

We're going to do something a little different. Since we cannot batch our graphs (all different numbers of atoms), we'll compute a few gradients to build a mini-batch before updating the parameters. Why do we want to do this? Batches can improve our gradient updates and are not just a computational convenience.

In [73]:
# we'll just train on 5,000 and use 1,000 for test
test_set = data.take(1000)
valid_set = data.skip(1000).take(100)
train_set = data.skip(1100).take(500).shuffle(500)

epochs = 16
batch_size = 32
eta = 1e-2
val_loss = [0. for _ in range(epochs)]
for epoch in range(epochs):
    bi = 0
    grad_est = None
    for d in train_set:         
        # do training step
        # but do not update
        # until have enough points
        (e,x), y = convert_record(d)
        if grad_est is None:
            grad_est = loss_grad(e, x, y, w1, w2, w3, b)
        else:
            grad_est += loss_grad(e, x, y, w1, w2, w3, b)
        bi += 1
        if bi == batch_size:
            # have enough to update            
            # update regression weights
            w3 -= eta * grad_est[2]  / batch_size
            b -= eta * grad_est[3]  / batch_size
            # update GNN weights            
            for i,w in [(0, w1), (1, w2)]:
                for j, param in enumerate(w):
                    param -= eta * grad_est[i][j] / batch_size
            # reset tracking of batch index
            bi = 0            
            grad_est = None            
    # compute validation loss    
    for v in valid_set:
        (e,x), y = convert_record(v)
        # convert SE to RMSE
        val_loss[epoch] += jnp.sqrt(loss(e, x, y, w1, w2, w3, b) / 1000)
    print('epoch:', epoch, 'loss: {:.2f}'.format(val_loss[epoch]))

epoch: 0 loss: 58.11
epoch: 1 loss: 59.59
epoch: 2 loss: 60.31
epoch: 3 loss: 61.78
epoch: 4 loss: 62.12
epoch: 5 loss: 62.49
epoch: 6 loss: 62.30
epoch: 7 loss: 61.85
epoch: 8 loss: 61.12
epoch: 9 loss: 61.99
epoch: 10 loss: 60.74
epoch: 11 loss: 60.81
epoch: 12 loss: 61.73
epoch: 13 loss: 60.56
epoch: 14 loss: 60.90
epoch: 15 loss: 60.66


This is a large dataset and we're under training, but hopefully you get the principles of this process! Finally, we'll examine our parity plot. 

In [49]:
import matplotlib.pyplot as plt

In [53]:
import seaborn as sns
import matplotlib as mpl
sns.set_context('notebook')
sns.set_style('dark',  {'xtick.bottom':True, 'ytick.left':True, 'xtick.color': '#666666', 'ytick.color': '#666666',
                        'axes.edgecolor': '#666666', 'axes.linewidth':     0.8 , 'figure.dpi': 300})
color_cycle = ['#1BBC9B', '#F06060', '#5C4B51', '#F3B562', '#6e5687']
mpl.rcParams['axes.prop_cycle'] = mpl.cycler(color=color_cycle) 

In [None]:
ys = []
yhats = []
for v in valid_set:
    (e,x), y = convert_record(v)
    ys.append(y)
    # convert SE to RMSE
    yhats.append(model(e, x, w1, w2, w3, b))

    
plt.plot(ys, ys, '-')
plt.plot(ys, yhats, 'o')
plt.xlabel('Energy')
plt.ylabel('Predicted Energy')
plt.show()

The clusters are molecule types/sizes. You can see we're starting to get the correct trend within the clusters, but a lot of work needs to be done to move some of them. Additional learning required!

## Cited References

```{bibliography} references.bib
```