# Graph Attention Network

- The molecules for this notebook were taken from the [QM9 dataset](https://www.nature.com/articles/sdata201422) by R. Ramakrishnan, P. O. Dral, M. Rupp, and O. A. von Lilienfeld, *Sci. Data*, **1**, 140022 (2014).

- Description of the Graph Attention Network taken from https://openreview.net/forum?id=rJXMpikCZ.

<img src="https://miro.medium.com/v2/resize:fit:720/format:webp/1*3D844_twutCaunYMPuo-Sw.png" alt="Illustration of the attention mechanism" align="right" style="width: 500px;float: right;"/>

## Overview

In this notebook we will learn how to implement a [Graph Attention Network](https://arxiv.org/abs/1710.10903), as reported by Veličković, Cucurull, Casanova, Romero, Liò, and Bengio. This algorithm uses the graph structure that represents an isolated molecule or the unit cell for a given material.

To understand the architecture, consider a graph of $n$ nodes, specified as a set of node features, $(\vec{h}_1, \vec{h}_2, \dots, \vec{h}_n)$, and an adjacency matrix $\bf A$, such that ${\bf A}_{ij} = 1$ if $i$ and $j$ are connected, and 0 otherwise. A graph convolutional layer then computes a set of new node features, $(\vec{h}'_1, \vec{h}'_2, \dots, \vec{h}'_n)$, based on the input features as well as the graph structure. In order to achieve a higher-level representation, every graph convolutional layer starts from a shared node-wise feature transformation, specified by a weight matrix ${\bf W}$. This transforms the feature vectors into $\vec{g}_i = {\bf W}\vec{h}_i$. After this, the vectors $\vec{g}_i$ typically are recombined in some way at each node.

In general, we can define a graph convolutional operator as an aggregation of features across neighborhoods. Let $\mathcal{N}_i$ be the neighborhood of node $i$, which typically consists of all first-order neighbors of $i$, including $i$ itself. We can then define the output features of node $i$ as
$$
\vec{h}'_i = \sigma\left(\sum_{j\in\mathcal{N}_i}\alpha_{ij}\vec{g}_j\right)\, ,
$$
where $\sigma$ is an activation function, and $\alpha_{ij}$ specifies the weighting factor (importance) of node $j$’s features to node $i$.
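
As a minimal NumPy sketch of the aggregation above, assume a hypothetical 4-node path graph, random features and weights, uniform weighting factors $\alpha_{ij} = 1/|\mathcal{N}_i|$ (a plain mean over the neighborhood), and $\sigma = \tanh$. All of these choices are illustrative and are not taken from the notebook's `model` module.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical path graph 0-1-2-3 (illustrative only)
A = np.array([[0, 1, 0, 0],
              [1, 0, 1, 0],
              [0, 1, 0, 1],
              [0, 0, 1, 0]], dtype=float)
A_hat = A + np.eye(4)            # neighborhoods include the node itself

h = rng.normal(size=(4, 3))      # n = 4 nodes, 3 input features each
W = rng.normal(size=(2, 3))      # shared transformation to 2 features

g = h @ W.T                      # g_i = W h_i, shape (4, 2)

# Assumed weighting: uniform over each neighborhood, alpha_ij = 1/|N_i|
alpha = A_hat / A_hat.sum(axis=1, keepdims=True)

h_new = np.tanh(alpha @ g)       # sigma = tanh, chosen only for illustration
print(h_new.shape)               # (4, 2)
```

Each row of `alpha` sums to one and is zero outside the neighborhood, so node $i$ only mixes features from $\mathcal{N}_i$.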

Rather than fixing $\alpha_{ij}$ explicitly, we can let it be implicitly defined by employing self-attention over the node features. Self-attention has previously been shown to be sufficient for state-of-the-art results on machine translation, as demonstrated by the [Transformer architecture](https://arxiv.org/abs/1706.03762).

Generally, we let $\alpha_{ij}$ be computed as a byproduct of an attention mechanism, $a : \mathbb{R}^N \times \mathbb{R}^N \rightarrow \mathbb{R}$ (with $N$ the number of features per node), which computes unnormalised coefficients $e_{ij}$ across pairs of nodes $i, j$, based on their features
$$
e_{ij} = a(\vec{h}_i, \vec{h}_j)\, .
$$
We inject the graph structure by only allowing node $i$ to attend over nodes in its neighborhood, $j \in \mathcal{N}_i$. These coefficients then are normalised using the [Softmax](https://pytorch.org/docs/stable/generated/torch.nn.Softmax.html) function, in order to be comparable across different neighborhoods
$$
\alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k\in\mathcal{N}_i}\exp(e_{ik})}\, .
$$
Overall, the framework is agnostic to the choice of attention mechanism $a$. The parameters of the mechanism are trained jointly with the rest of the network in an end-to-end fashion.
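
The masked softmax can be sketched in NumPy as follows. Because the framework leaves $a$ open, this sketch assumes the single-layer mechanism used in the GAT paper, $e_{ij} = \mathrm{LeakyReLU}(\vec{a}^\top[\vec{g}_i \Vert \vec{g}_j])$, applied to the transformed features $\vec{g}_i$. The small graph, the random parameters, and the helper name `attention_coefficients` are hypothetical.

```python
import numpy as np

def attention_coefficients(g, adj, a_vec, slope=0.2):
    """Return alpha (n, n): softmax of e_ij over each neighborhood N_i."""
    n, f = g.shape
    # a^T [g_i || g_j] splits into a_left . g_i + a_right . g_j
    left = g @ a_vec[:f]                    # shape (n,)
    right = g @ a_vec[f:]                   # shape (n,)
    e = left[:, None] + right[None, :]
    e = np.where(e > 0, e, slope * e)       # LeakyReLU
    # Inject graph structure: non-neighbors get -inf, i.e. zero weight
    e = np.where(adj > 0, e, -np.inf)
    e = e - e.max(axis=1, keepdims=True)    # numerical stability
    w = np.exp(e)
    return w / w.sum(axis=1, keepdims=True)

rng = np.random.default_rng(0)
adj = np.array([[1, 1, 0],
                [1, 1, 1],
                [0, 1, 1]], dtype=float)    # self-loops included
g = rng.normal(size=(3, 2))
a_vec = rng.normal(size=4)

alpha = attention_coefficients(g, adj, a_vec)
print(alpha.sum(axis=1))                    # each row sums to 1
```

Masking with $-\infty$ before the softmax is one common way to restrict attention to $\mathcal{N}_i$; it relies on every node having at least one neighbor, which the self-loops guarantee here.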

To stabilise the learning process of self-attention, multi-head attention might be beneficial. Namely, the operations of the layer are independently replicated $K$ times, each replica with different parameters, and the outputs are aggregated feature-wise by concatenation or addition. For concatenation this reads
$$
\vec{h}'_i = \Big\Vert_{k=1}^K \sigma\left(\sum_{j\in\mathcal{N}_i}\alpha_{ij}^k{\bf W}^k\vec{h}_j\right)\, ,
$$
where $\Vert$ denotes concatenation, $\alpha_{ij}^k$ are the attention coefficients derived by the $k$-th replica, and ${\bf W}^k$ is the weight matrix specifying the linear transformation of the $k$-th replica.
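
A possible NumPy sketch of multi-head aggregation is shown below. It assumes the same LeakyReLU-based attention mechanism as the GAT paper and $\sigma = \tanh$; the helper `gat_head`, the graph, and all parameter values are hypothetical and only meant to show how the head outputs are combined.

```python
import numpy as np

def gat_head(h, adj, W, a_vec):
    """One attention head: returns sum_j alpha_ij W h_j before the activation."""
    g = h @ W.T
    f = g.shape[1]
    e = (g @ a_vec[:f])[:, None] + (g @ a_vec[f:])[None, :]
    e = np.where(e > 0, e, 0.2 * e)            # LeakyReLU (assumed mechanism)
    e = np.where(adj > 0, e, -np.inf)          # attend only over N_i
    w = np.exp(e - e.max(axis=1, keepdims=True))
    alpha = w / w.sum(axis=1, keepdims=True)
    return alpha @ g

rng = np.random.default_rng(1)
adj = np.array([[1, 1, 0], [1, 1, 1], [0, 1, 1]], dtype=float)
h = rng.normal(size=(3, 5))                    # 3 nodes, 5 input features
K, out_f = 4, 2                                # 4 heads, 2 features per head

heads = []
for k in range(K):
    W_k = rng.normal(size=(out_f, 5))          # independent parameters per head
    a_k = rng.normal(size=2 * out_f)
    heads.append(np.tanh(gat_head(h, adj, W_k, a_k)))   # sigma = tanh here

h_cat = np.concatenate(heads, axis=1)          # concatenation: (3, K * out_f)
h_sum = np.sum(heads, axis=0)                  # addition:      (3, out_f)
print(h_cat.shape, h_sum.shape)                # (3, 8) (3, 2)
```

Concatenation grows the feature dimension by a factor of $K$, whereas addition keeps it fixed, which matters when choosing the input size of the next layer.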

## Libraries

In [None]:
import torch

import numpy   as np
import pandas  as pd
import seaborn as sns

import matplotlib.pyplot as plt

from matplotlib             import cm
from scipy                  import stats
from pathlib                import Path
from torch_geometric.loader import DataLoader
#
### Import local libraries
#
from model import GraphAttentionNetwork
from model import batches, passdata

plt.rc('xtick', labelsize=18) 
plt.rc('ytick', labelsize=18)

## 1. Prepare data

First we need to load our dataset. We will use the `data.pth` file, which includes the adjacency matrix and node features for all molecules. Our training data will be these two descriptors and our target data will be the HOMO-LUMO gap for each molecule. The data is already shuffled, hence we can directly divide our set into 60 % for training, 20 % for testing, and the remaining 20 % for validation.

In [None]:
#
### Define data path
#
imhere  = Path.cwd()
#
### Load dataset
#
datapth = torch.load(imhere/'data.pth')
#
#datapth = datapth[:20_000] # Reduce as needed from size = 132_364
#
### Divide into 60% training, 20% testing, and 20% validating
#
limit   = 40*len(datapth)//100

print(f'train = { len(datapth[:-limit]) }, '
      f'test = { len(datapth[-limit:-limit//2]) }, '
      f'validate = { len(datapth[-limit//2:]) }')
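
To see how these negative slices partition the data, the same indexing can be applied to a small stand-in list. The 10 integer "samples" below are hypothetical and only replace the molecular graphs for illustration.

```python
# Stand-in for the shuffled dataset: 10 hypothetical samples
samples = list(range(10))

limit = 40 * len(samples) // 100        # 40 % held out -> 4

train = samples[:-limit]                # first 60 %
test = samples[-limit:-limit // 2]      # next 20 %
validate = samples[-limit // 2:]        # last 20 %

print(train, test, validate)
# [0, 1, 2, 3, 4, 5] [6, 7] [8, 9]
```

The three slices are disjoint and together cover the whole list, so no sample is shared between training, testing, and validation.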

## 2. Settings and hyperparameters

Our optimization algorithm is ADAptive Moment estimation, [Adam](https://arxiv.org/pdf/1412.6980.pdf), which is based on stochastic gradient descent. We will need to define the **learning rate** and the **weight decay**. The learning rate is a hyperparameter that controls how much we adjust the weights of our network with respect to the loss gradient, whereas the weight decay is a regularization term that penalizes large weights.

The number of **epochs** is the number of times the learning algorithm will work through the entire training dataset. One epoch means that each sample in the training dataset has had an opportunity to update the internal model parameters.

Because we are interested in learning the regression for a continuous variable, we will use the Mean Squared Error **loss function**.
$$
\mathrm{MSE} = \frac{1}{N} \sum_{i=1}^{N} (y_i - \hat{y}_i)^2
$$

where $N$ is the number of samples in the training set, $y_i$ is the reference value, and $\hat{y}_i$ is the predicted value.
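
As a small worked example, the sketch below evaluates the MSE by hand and with `torch.nn.MSELoss`, and shows how the `reduction='sum'` variant relates to it. The reference and predicted values are made up for illustration.

```python
import torch

# Hypothetical reference and predicted values (illustrative only)
y = torch.tensor([1.0, 2.0, 3.0, 4.0])
y_hat = torch.tensor([1.5, 2.0, 2.0, 4.0])

manual = ((y - y_hat) ** 2).mean()                   # (1/N) * sum of squared errors
mse = torch.nn.MSELoss(reduction='mean')(y_hat, y)   # same quantity
sse = torch.nn.MSELoss(reduction='sum')(y_hat, y)    # sum without the 1/N factor

print(manual.item(), mse.item(), (sse / len(y)).item())
# 0.3125 0.3125 0.3125
```

With `reduction='sum'` the criterion returns the summed squared error, so it has to be divided by the number of samples to recover the MSE defined above.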

The training, testing, and validation loaders can each be iterated over in a loop,

~~~
for batch in training: print(batch)
~~~

where each iteration yields one batch of samples that can be passed to the neural network.

In [None]:
#
### Training parameters
#
learning_rate = 1e-5
weight_decay  = 1e-5

epochs        = 1
test_epoch    = 2 # testing only runs once epochs >= test_epoch
#
### Define neural network
#
network = GraphAttentionNetwork(hidden_nf=4, output_nf=1, attention_nf=1, reduce='cat', drop=0.0)
#
### Optimizer and Loss
#
optimizer = torch.optim.Adam(params=network.parameters(), lr=learning_rate, weight_decay=weight_decay)
criterion = torch.nn.MSELoss(reduction='sum') # summed squared error; divide by the sample count for the MSE
#
### Training and testing data
#
training   = DataLoader(datapth[:-limit], shuffle=True)
testing    = DataLoader(datapth[-limit:-limit//2], shuffle=False)
validating = DataLoader(datapth[-limit//2:], shuffle=False)

## 3. Training

We can now train our neural network for the total number of `epochs` we selected, testing it every `test_epoch` epochs.

Passing training data to the network consists of five steps:

1. Set the gradients to zero, `optimizer.zero_grad()`.

2. Pass batch to the network, `output = network(batch)`.

3. Compute the loss, `loss = criterion(output, y)`.

4. Perform backward pass, `loss.backward()`.

5. Perform the optimization step, `optimizer.step()`.

Keep in mind that during testing you **DO NOT** want to update the parameters of your neural network. Otherwise you will leak testing information and your model will also learn from the testing set. To guard against this, skip the backward pass and the optimization step, and run the forward pass inside the `torch.no_grad()` context manager. This disables gradient tracking, so no gradients are computed or stored during testing.
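
The five training steps and the gradient-free evaluation can be sketched on a toy problem. This is only an illustration under stated assumptions: a `torch.nn.Linear` model and random tensors stand in for `GraphAttentionNetwork` and the molecular graphs, and the learning rate and number of steps are arbitrary.

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-ins: a linear model and synthetic data instead of
# GraphAttentionNetwork and the molecular graphs
network = torch.nn.Linear(3, 1)
optimizer = torch.optim.Adam(network.parameters(), lr=1e-2)
criterion = torch.nn.MSELoss()

x_train = torch.randn(32, 3)
y_train = x_train.sum(dim=1, keepdim=True)   # synthetic target
x_test = torch.randn(8, 3)
y_test = x_test.sum(dim=1, keepdim=True)

network.train()
first_loss = None
for step in range(50):
    optimizer.zero_grad()                 # 1. reset accumulated gradients
    output = network(x_train)             # 2. forward pass
    loss = criterion(output, y_train)     # 3. compute the loss
    loss.backward()                       # 4. backward pass
    optimizer.step()                      # 5. update the parameters
    if first_loss is None:
        first_loss = loss.item()

network.eval()
with torch.no_grad():                     # no gradients are tracked here
    test_output = network(x_test)
    test_loss = criterion(test_output, y_test)

print(f'train {first_loss:.4f} -> {loss.item():.4f}, test {test_loss.item():.4f}')
print(test_output.requires_grad)          # False
```

In the notebook itself, the forward pass and loss would use the graph batches from the `training` and `testing` loaders; how the target is read from each batch depends on the local `model` module and is not shown here.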

In [None]:
for epoch in range(epochs):

    # your code here for the training set

    print(f'{epoch+1},train,{loss:.4f}')

    if (epoch+1)%test_epoch == 0:
        # your code here for the testing set

        print(f'{epoch+1},test,{loss:.4f}')

## 4. Test the model

Use the validation set to compare the reference and predicted HOMO-LUMO gap.
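
One possible way to quantify the comparison is through the mean absolute error and the Pearson correlation between reference and predicted values. The sketch below assumes both have already been collected into arrays; the numbers are placeholders, not results from this model.

```python
import numpy as np

# Placeholder values only; in the notebook these would be collected from the
# validation loop (reference gaps and the network's predictions)
reference = np.array([0.25, 0.30, 0.21, 0.27, 0.33])
predicted = np.array([0.24, 0.32, 0.20, 0.29, 0.31])

mae = np.mean(np.abs(reference - predicted))        # mean absolute error
r = np.corrcoef(reference, predicted)[0, 1]         # Pearson correlation

print(f'MAE = {mae:.4f}, r = {r:.3f}')
```

A parity plot of `predicted` against `reference`, for example with the `matplotlib` import above, gives a visual counterpart to these two numbers.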