# Equivariant Graph Neural Networks

- The molecules for this notebook were taken from the [refined QM9 dataset](https://doi.org/10.1038/s41597-019-0121-7) by H. Kim, J.Y. Park, and S. Choi, *Sci. Data*, **6**, 109 (2019).

<img src="https://ehoogeboom.github.io/publication/egnn/featured_hua4419112e0b0f9c21e721be460820b18_120982_680x500_fill_q90_lanczos_center_2.png" alt="Illustration of rotation equivariance" align="right" style="width: 500px;float: right;"/>

## Overview

In this notebook we will learn how to implement an [Equivariant Graph Neural Network](https://arxiv.org/abs/2102.09844), as reported by Satorras, Hoogeboom, and Welling. Their work explores three types of equivariance on a set of particles ${\bf x}$, namely,

* Translation equivariance, where translating the input results in an equivalent translation of the output.
* Rotation/Reflection equivariance, where rotating/reflecting the input results in an equivalent rotation/reflection of the output.
* Permutation equivariance, where permuting the input results in the same permutation of the output.
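
Written out for a map $\phi$ with output ${\bf y} = \phi({\bf x})$, these three properties can be stated as below. The symbols ${\bf g}$ (a translation vector), $Q$ (an orthogonal matrix), and $P$ (a permutation of the particle indices) are introduced here for illustration and follow the convention of the paper.

$$
\begin{aligned}
{\bf y} + {\bf g} &= \phi({\bf x} + {\bf g})\, , \\
Q {\bf y} &= \phi(Q {\bf x})\, , \\
P({\bf y}) &= \phi(P({\bf x}))\, .
\end{aligned}
$$

Here ${\bf x} + {\bf g}$ is shorthand for $({\bf x}_1 + {\bf g}, \dots, {\bf x}_n + {\bf g})$ and $Q{\bf x}$ for $(Q{\bf x}_1, \dots, Q{\bf x}_n)$.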

As seen previously, we start by considering a graph of $n$ nodes $v_i \in \mathcal{V}$ and edges $e_{ij} \in \mathcal{E}$, together with node features ${\bf h} = ({\bf h}_1, {\bf h}_2, \dots, {\bf h}_n)$ and the coordinates ${\bf x}_i$ associated with each graph node. An equivariant graph convolutional layer (EGCL) takes as input the node embeddings ${\bf h}$, the coordinate embeddings ${\bf x}$, and the edge information $\mathcal{E} = (e_{ij})$, and outputs the transformed embeddings ${\bf h'}$ and ${\bf x'}$. Concisely, ${\bf h'}, {\bf x'} = \mathrm{EGCL}[{\bf h}, {\bf x}, \mathcal{E}]$. An EGCL is defined by the following set of equations:
$$
\begin{aligned}
{\bf m}_{ij} &= \phi_e \left( {\bf h}_{i}, {\bf h}_{j}, \left\| {\bf x}_{i} - {\bf x}_{j} \right\|^{2}, e_{i j} \right)\, , \\
{\bf x'}_{i} &= {\bf x}_{i} + C \sum_{j \neq i} \left( {\bf x}_{i} - {\bf x}_{j} \right) \phi_x \left( {\bf m}_{ij} \right)\, , \\
{\bf m}_{i} &= \sum_{j \neq i} {\bf m}_{ij}\, , \\
{\bf h'}_{i} &= \phi_h \left( {\bf h}_{i}, {\bf m}_{i} \right)\, .
\end{aligned}
$$

Notice that the layer incorporates the squared relative distance between two coordinates, $\left\| {\bf x}_{i} - {\bf x}_{j} \right\|^{2}$, into the **edge operation** $\phi_e$. The node embeddings ${\bf h}_i$ and ${\bf h}_j$ and the edge attributes $e_{ij}$ are also provided as input to that same edge operation. It is worth noting that the embeddings ${\bf m}_{ij}$ can carry information from the whole graph and not only from the given edge $e_{ij}$.

Next we update the position of each particle ${\bf x}_i$ as a vector field in a radial direction. Here, the position ${\bf x}_i$ is updated by the weighted sum of all the relative differences $({\bf x}_i - {\bf x}_j)_{\forall j}$. The weights of this sum are the output of the **coordinate operation** $\phi_x$, which takes as input the edge embedding ${\bf m}_{ij}$ and returns a scalar. The constant $C = 1/(n - 1)$ divides the sum by its number of terms. The overall operation results in the updated particle positions ${\bf x'}_i$.

These two operations are followed by an aggregation step that combines messages from all $j \neq i$ nodes.

Finally, the **node operation** $\phi_h$ takes as input the node embeddings ${\bf h}_i$ and the aggregated messages ${\bf m}_{i}$, and returns the updated node embeddings ${\bf h'}_i$.
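
The four equations can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions, not the `EquivariantGraphNetwork` imported later in this notebook: the graph is fully connected, edge attributes $e_{ij}$ are omitted, each of $\phi_e$, $\phi_x$, and $\phi_h$ is a single random dense layer (with `tanh` where an activation is used), and the names `linear` and `egcl` are hypothetical. The sketch then checks numerically that the coordinates transform equivariantly and the node embeddings stay invariant.

```python
import numpy as np

rng = np.random.default_rng(0)

def linear(n_in, n_out):
    """Random weights for one dense layer (stand-in for a trained MLP)."""
    return rng.normal(scale=0.5, size=(n_in, n_out)), np.zeros(n_out)

def egcl(h, x, params):
    """One EGCL pass on a fully connected graph without edge attributes."""
    (We, be), (Wx, bx), (Wh, bh) = params
    n = h.shape[0]
    diff = x[:, None, :] - x[None, :, :]             # x_i - x_j, shape (n, n, 3)
    dist2 = (diff ** 2).sum(-1, keepdims=True)       # ||x_i - x_j||^2, shape (n, n, 1)
    h_i = np.repeat(h[:, None, :], n, axis=1)        # h_i repeated over j
    h_j = np.repeat(h[None, :, :], n, axis=0)        # h_j repeated over i
    m_ij = np.tanh(np.concatenate([h_i, h_j, dist2], -1) @ We + be)  # phi_e
    mask = 1.0 - np.eye(n)[:, :, None]               # drops the j == i terms
    w_ij = m_ij @ Wx + bx                            # phi_x, one scalar per edge
    x_new = x + (diff * w_ij * mask).sum(1) / (n - 1)
    m_i = (m_ij * mask).sum(1)                       # aggregated messages
    h_new = np.tanh(np.concatenate([h, m_i], -1) @ Wh + bh)          # phi_h
    return h_new, x_new

n, nf, nm = 5, 4, 8                                  # nodes, feature size, message size
params = [linear(2 * nf + 1, nm), linear(nm, 1), linear(nf + nm, nf)]
h = rng.normal(size=(n, nf))
x = rng.normal(size=(n, 3))

Q, _ = np.linalg.qr(rng.normal(size=(3, 3)))         # random orthogonal matrix
g = rng.normal(size=3)                               # random translation
perm = rng.permutation(n)                            # random node permutation

h1, x1 = egcl(h, x, params)
h2, x2 = egcl(h, x @ Q.T + g, params)                # rotated and translated input
h3, x3 = egcl(h[perm], x[perm], params)              # permuted input

print(np.allclose(h2, h1), np.allclose(x2, x1 @ Q.T + g))
print(np.allclose(h3, h1[perm]), np.allclose(x3, x1[perm]))
```

The checks hold for any weights because the layer only sees the coordinates through $\left\| {\bf x}_{i} - {\bf x}_{j} \right\|^{2}$, which is unchanged by rotations and translations, and through the differences ${\bf x}_i - {\bf x}_j$, which rotate with the input.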

## Libraries

In [None]:
import torch

import numpy   as np
import pandas  as pd
import seaborn as sns

import matplotlib.pyplot as plt

from matplotlib             import cm
from scipy                  import stats
from pathlib                import Path
from torch_geometric.loader import DataLoader
#
### Import local libraries
#
from model import EquivariantGraphNetwork
from model import batches, passdata

plt.rc('xtick', labelsize=18) 
plt.rc('ytick', labelsize=18)

## 1. Prepare data

First we need to load our dataset. We will use the `data.pth` file, which includes the nodes, edges, node features, edge features, and coordinates for all molecules. These descriptors are the inputs and the atomization energy of each molecule is the target. The data is shuffled, hence we can directly divide our set into 60 % for training, 20 % for testing, and the remaining 20 % for validation.

In [None]:
#
### Define data path
#
imhere  = Path.cwd()
#
### Load dataset
#
datapth = torch.load(imhere/'data.pth')

#datapth = datapth[:20_000] # Reduce as needed from size = 132_723
#
### Divide into 60% training, 20% testing, and 20% validating
#
limit   = 40*len(datapth)//100

print(f'train = { len(datapth[:-limit]) }, '
      f'test = { len(datapth[-limit:-limit//2]) }, '
      f'validate = { len(datapth[-limit//2:]) }')
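
The slicing in the cell above can be checked by hand. The sketch below repeats the same arithmetic on a `range` that stands in for the list of molecules, using the dataset size of 132_723 quoted in the cell. Note that `-limit//2` is evaluated as `(-limit)//2`, so floor division rounds toward minus infinity.

```python
n_total = 132_723                  # dataset size quoted in the loading cell
limit = 40 * n_total // 100        # size of the test and validation sets together

data = range(n_total)              # stand-in for the list of molecules
train = data[:-limit]
test = data[-limit:-limit // 2]    # -limit // 2 is evaluated as (-limit) // 2
validate = data[-limit // 2:]

print(len(train), len(test), len(validate))
```

With these numbers the three lengths are 79634, 26544, and 26545, which add up to the full 132723 samples without overlap.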

## 2. Settings and hyperparameters

Our optimization algorithm is ADAptive Moment estimation ([Adam](https://arxiv.org/pdf/1412.6980.pdf)), which is based on stochastic gradient descent. We need to define the **learning rate** and the **weight decay**. The learning rate is a hyperparameter that controls how much the weights of our network are adjusted with respect to the loss gradient, whereas the weight decay is a regularization term that penalizes large weights.

The number of **epochs** is the number of times the learning algorithm will work through the entire training dataset. One epoch means that each sample in the training dataset has had an opportunity to update the internal model parameters.

Because we are interested in learning the regression for a continuous variable, we will use the Mean Squared Error **loss function**.
$$
\mathrm{MSE} = \frac{1}{N} \sum_{i=1}^{N} (y_i - \hat{y}_i)^2
$$

where $N$ is the number of samples in the training set, $y_i$ is the reference value, and $\hat{y}_i$ is the predicted value.
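
As a small worked example, the formula can be evaluated in plain Python on made-up numbers (the values of `y_ref` and `y_hat` below are hypothetical). The sketch also keeps the summed squared error, since a loss created with `reduction='sum'`, as done for `criterion` in this notebook, returns that sum rather than the mean.

```python
y_ref = [1.0, 2.0, 3.0, 4.0]       # hypothetical reference values
y_hat = [1.5, 2.0, 2.0, 4.5]       # hypothetical predictions

squared_errors = [(y - p) ** 2 for y, p in zip(y_ref, y_hat)]
sse = sum(squared_errors)          # summed squared error
mse = sse / len(y_ref)             # mean squared error

print(sse, mse)
```

The squared errors are 0.25, 0, 1, and 0.25, so the sum is 1.5 and the MSE is 0.375.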

The training, testing, and validation loaders can each be iterated in a loop,

~~~
for batch in training: print(batch.shape)
~~~

where each iteration yields one batch of `batch_size` samples that can then be passed to the neural network.
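
To see what batching does to such a loop, here is a plain-Python stand-in for a loader. It is not the `torch_geometric` `DataLoader`, which additionally collates the graphs of a batch into one object and can shuffle the samples. The helper name `iterate_batches` and the 25 samples are hypothetical.

```python
import math

def iterate_batches(samples, batch_size):
    """Yield consecutive chunks of at most batch_size samples."""
    for start in range(0, len(samples), batch_size):
        yield samples[start:start + batch_size]

samples = list(range(25))          # 25 hypothetical samples
chunks = list(iterate_batches(samples, batch_size=10))

print([len(chunk) for chunk in chunks])
print(len(chunks) == math.ceil(len(samples) / 10))
```

The chunk sizes are 10, 10, and 5: the number of iterations per epoch is the number of samples divided by the batch size, rounded up, and the last batch may be smaller than the others.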

In [None]:
#
### Training parameters
#
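learning_rate = 1e-3
weight_decay  = 1e-5

epochs       = 1   # number of passes over the training set
batch_size   = 10  # number of molecules per batch
test_epoch   = 10  # test every test_epoch epochs (no test runs while epochs < test_epoch)
#
### Define neural network
#
network = EquivariantGraphNetwork(hidden_nf=4, activation=torch.nn.SiLU(), aggregation='sum')
#
### Optimizer and Loss
#
optimizer = torch.optim.Adam(params=network.parameters(), lr=learning_rate, weight_decay=weight_decay)
# reduction='sum' returns the summed squared error of a batch;
# divide the accumulated value by the number of samples to obtain the MSE
criterion = torch.nn.MSELoss(reduction='sum')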
#
### Training and testing data
#
training   = DataLoader(datapth[:-limit],          shuffle=True,  batch_size=batch_size)
testing    = DataLoader(datapth[-limit:-limit//2], shuffle=False, batch_size=batch_size)
validating = DataLoader(datapth[-limit//2:],       shuffle=False, batch_size=batch_size)

## 3. Training

We can now train our neural network for the total number of `epochs` we selected and test it every `test_epoch` epochs.

Passing training data to the network consists of five steps:

1. Set the gradients to zero, `optimizer.zero_grad()`.

2. Pass batch to the network, `output = network(batch)`.

3. Compute the loss, `loss = criterion(output, y)`.

4. Perform backward pass, `loss.backward()`.

5. Perform the optimization step, `optimizer.step()`.

Keep in mind that during testing you **DO NOT** want to update the weights of your neural network. Otherwise your model will also learn from the testing set, and the test loss will no longer measure how well it generalizes. During testing, skip the backward pass and the optimization step, and run the forward pass inside the `torch.no_grad()` context manager, which disables gradient tracking so that no gradients are computed or stored.
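
The five steps and the gradient-free evaluation can be sketched on a toy problem. Everything below is an assumption made for illustration: a `torch.nn.Linear` layer stands in for the equivariant network, random tensors stand in for the molecular data, and the `toy_` names and the `evaluate` helper are hypothetical. How a batch is unpacked and passed to the real network depends on the local `model` module and is not shown here.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Toy regression data standing in for the molecular dataset: y = 2*x0 - x1
X = torch.randn(200, 2)
y = X @ torch.tensor([[2.0], [-1.0]])
toy_training = DataLoader(TensorDataset(X[:160], y[:160]), batch_size=10, shuffle=True)
toy_testing = DataLoader(TensorDataset(X[160:], y[160:]), batch_size=10)

toy_network = torch.nn.Linear(2, 1)    # stand-in for the equivariant network
toy_optimizer = torch.optim.Adam(toy_network.parameters(), lr=1e-2)
toy_criterion = torch.nn.MSELoss(reduction='sum')

def evaluate(loader):
    """Mean squared error over a loader, computed without gradient tracking."""
    toy_network.eval()
    total = 0.0
    with torch.no_grad():
        for xb, yb in loader:
            total += toy_criterion(toy_network(xb), yb).item()
    return total / len(loader.dataset)

loss_before = evaluate(toy_testing)
for epoch in range(20):
    toy_network.train()
    for xb, yb in toy_training:
        toy_optimizer.zero_grad()          # 1. set the gradients to zero
        output = toy_network(xb)           # 2. pass the batch to the network
        loss = toy_criterion(output, yb)   # 3. compute the loss
        loss.backward()                    # 4. backward pass
        toy_optimizer.step()               # 5. optimization step
loss_after = evaluate(toy_testing)

print(f'test MSE before = {loss_before:.4f}, after = {loss_after:.4f}')
```

Because the loss uses `reduction='sum'`, the evaluation accumulates the per-batch sums and divides by the number of samples once at the end, which gives the mean squared error over the whole set.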

In [None]:
for epoch in range(epochs):

    # your code here for the training set

    print(f'{epoch+1},train,{loss:.4f}')

    if (epoch+1)%test_epoch == 0:
        # your code here for the testing set

        print(f'{epoch+1},test,{loss:.4f}')

## 4. Test the model

Use the validation set to compare the reference and predicted atomization energy.
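
One possible way to quantify the comparison is sketched below with NumPy, using made-up numbers. The arrays `reference` and `predicted` are hypothetical placeholders for the energies collected from the validation loop, and the choice of metrics (mean absolute error, root mean squared error, Pearson correlation) is an assumption, not something prescribed by this notebook.

```python
import numpy as np

# Hypothetical reference and predicted energies (arbitrary units)
reference = np.array([-10.0, -12.5, -8.0, -15.0, -11.0])
predicted = np.array([-10.2, -12.0, -8.1, -14.6, -11.3])

error = predicted - reference
mae = np.abs(error).mean()                           # mean absolute error
rmse = np.sqrt((error ** 2).mean())                  # root mean squared error
pearson_r = np.corrcoef(reference, predicted)[0, 1]  # linear correlation

print(f'MAE = {mae:.3f}, RMSE = {rmse:.3f}, r = {pearson_r:.4f}')
```

A scatter plot of `predicted` against `reference` with the diagonal drawn in gives a visual counterpart to these numbers.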