# Deep Learning on Graphs with Message Passing Neural Networks

At this point in the Straight Dope, we've seen a wide variety of different types of data fed as input to our models.

We started with linear regression models and MLPs, which take simple, 1-dimensional vectors of real numbers as input.  Then we met CNNs, which take images represented as 3-dimensional tensors as input.  Next we saw how RNNs can take sequence data, like time-series or natural-language sentences, or really anything we can represent as a sequence of tensors, as input.  And we even saw how to consume tree-structured data, like a parse tree of a natural-language sentence, using a Tree LSTM.

In this chapter we'll see how to build models to handle yet another type of data: graph-structured data. We'll learn how to build Message Passing Neural Networks (MPNNs), which are a class of deep models that can take arbitrary graphs as input.

**Wait, "graphs"?**

When I say "graph", I mean that word the way a mathematician means it.  [Wikipedia explains the concept well][1], if you're not familiar.  Going forward I'll assume you're familiar with graph-lingo like "directed edge" and "adjacency matrix", so take a gander at that link if you need to.
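
As a quick refresher on that lingo, here's a minimal sketch of how a small graph turns into an adjacency matrix, in both its undirected and directed flavors.  The toy edge list and the helper name `adjacency_matrix` are made up for illustration.

```python
import numpy as np

def adjacency_matrix(n_vertices, edges, directed=False):
    """Build a dense adjacency matrix from a list of (source, target) pairs."""
    adj = np.zeros((n_vertices, n_vertices), dtype='float32')
    for v, w in edges:
        adj[v, w] = 1
        if not directed:
            # An undirected edge can be traversed both ways
            adj[w, v] = 1
    return adj

# A toy graph on 4 vertices: a triangle (0, 1, 2) with vertex 3 hanging off vertex 2
edges = [(0, 1), (1, 2), (2, 0), (2, 3)]
undirected = adjacency_matrix(4, edges)
directed = adjacency_matrix(4, edges, directed=True)

print(undirected)
print(directed)
```

The undirected matrix is symmetric, while the directed one only has a 1 at row $v$, column $w$ when there's an edge pointing from $v$ to $w$.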

**So what exactly does "taking graphs as input" mean?**

Good question!  Reading papers or blogs about this topic can be confusing, since (at least) two distinct learning scenarios both go by the name "learning on graphs":
1. *We're trying to learn a model whose inputs are arbitrary graphs.*  Our dataset consists of (graph, label) pairs.  E.g. predicting the pharmacological activity of a molecule based on how its atoms are connected.
    
2. *We're trying to learn a model whose inputs are vertices in some graph.*  Our dataset is one big graph whose vertices are datapoints with edges between them, some labeled, some unlabeled.  E.g. predicting the impact factor of an article given a bag-of-words representation of the article and edges connecting it to its references.

In this chapter, we're focusing only on scenario 1, but MPNNs can be used for scenario 2 as well.

**Aren't sequences and trees just special types of graphs?  We already know models that handle those. (RNNs and Tree-RNNs.)**

Yes they are!  In fact you can (and people do) even think of images as graphs where each pixel is a vertex with edges to all its adjacent pixels.  But MPNNs can operate on *any* type of graph: directed or not, cyclic or not, etc.  Be careful though: MPNNs likely won't perform as well on sequences, trees, or images as models designed specifically for these data types will.
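
To make that view concrete, here's a small sketch that builds the adjacency matrix of a sequence (a path graph) and of an image (a grid graph).  The helper names `path_graph` and `grid_graph` are hypothetical, and linking each pixel to the pixels directly above, below, left, and right of it is just one possible reading of "adjacent pixels".

```python
import numpy as np

def path_graph(length):
    """A sequence of `length` items: each item is linked to the next one."""
    adj = np.zeros((length, length), dtype='float32')
    for i in range(length - 1):
        adj[i, i + 1] = adj[i + 1, i] = 1
    return adj

def grid_graph(height, width):
    """An image: each pixel is linked to the pixels directly above, below, left and right."""
    n = height * width
    adj = np.zeros((n, n), dtype='float32')
    for r in range(height):
        for c in range(width):
            v = r * width + c
            if c + 1 < width:   # right neighbor
                adj[v, v + 1] = adj[v + 1, v] = 1
            if r + 1 < height:  # neighbor below
                adj[v, v + width] = adj[v + width, v] = 1
    return adj

sequence_adj = path_graph(5)
image_adj = grid_graph(3, 3)

# Row sums give the degree (number of neighbors) of each vertex
print(sequence_adj.sum(axis=1))  # endpoints have 1 neighbor, interior items have 2
print(image_adj.sum(axis=1))     # corners have 2, edge pixels 3, the center pixel 4
```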

**But can't you basically represent anything as a graph if you try hard enough?**

Yeah, that's partly why graphs are ubiquitous in math and computer science.  They're a super general concept.  

This generality should make us veeeery suspicious that deep learning on graphs won't work as consistently well as, say, deep learning on real-world images does.  If it did, we could use deep networks to reason about nearly anything, and that would smell like a free lunch.

But MPNNs are still worth learning about.  They're the best tool we have at the moment for understanding graph-structured data, and they're a hot area of research.

[1]:https://en.wikipedia.org/wiki/Graph_(discrete_mathematics)

## Message Passing Neural Networks

Message Passing Neural Networks were introduced in [this paper](https://arxiv.org/pdf/1704.01212.pdf).  MPNNs are actually a family of models rather than a specific implementation, like how RNNs are a general model family, one implementation of which is an LSTM.  We'll first go over the general MPNN idea and then build a specific implementation.

### The Setup

We've got a dataset of `(graph, label)` pairs.  In each graph, each vertex $v$ has associated features $x_v$, and each edge between vertices $v$ and $w$ has features $e_{vw}$.  For simplicity of explanation we'll assume each graph is undirected, but once you understand MPNNs it's easy to see how to extend them to directed graphs or multigraphs.

### The Model

The goal of an MPNN is to take in a `graph` and output the correct `label`.  They do this by the following procedure:
1. Initialize a "hidden state" $h_v^0$ for each vertex $v$ in the graph as a function of the vertex's features: $$h_v^0 = \text{init_hidden}(x_v).$$
2. For each round $t$ out of $T$ total rounds:
    3. Each vertex $v$ receives a "message" $m_v^{t+1}$, which is the sum of messages passed by $v$'s neighbors as functions of their current hidden states and the edge features: $$m_v^{t+1} = \sum_{w \in \text{neighbors of }v} M_t(h_v^t, h_w^t, e_{vw}).$$
    4. Each vertex $v$ updates its hidden state as a function of the message it received: $$h_v^{t+1} = U_t(h_v^t, m_v^{t+1}).$$
5. The output is computed as the "readout" function of all the hidden states: $$\hat{y} = R_t(\{h_v^T \vert v \text{ is in the graph} \}).$$
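
To see the procedure above in action, here's a minimal NumPy sketch on a toy three-vertex graph.  The particular choices of `init_hidden`, `M`, `U`, and `R` (copy the features, scale the neighbor's state by the edge feature, average, sum) are made up for illustration and have no learnable parameters; a real MPNN would use neural networks for them.

```python
import numpy as np

# Toy undirected path graph 0 - 1 - 2, with a scalar feature on each edge
neighbors = {0: [1], 1: [0, 2], 2: [1]}
edge_features = {(0, 1): 1.0, (1, 0): 1.0, (1, 2): 0.5, (2, 1): 0.5}
vertex_features = {0: np.array([1.0, 0.0]), 1: np.array([0.0, 1.0]), 2: np.array([1.0, 1.0])}

def init_hidden(x_v):
    return x_v.copy()

def M(h_v, h_w, e_vw):
    # Message from neighbor w to vertex v: w's state scaled by the edge feature
    return e_vw * h_w

def U(h_v, m_v):
    # New state: average of the old state and the received message
    return 0.5 * (h_v + m_v)

def R(hidden_states):
    # Readout: sum over all vertices, so the output ignores vertex ordering
    return sum(hidden_states.values())

def run_mpnn(n_rounds):
    # Initialize a hidden state for each vertex
    h = {v: init_hidden(x) for v, x in vertex_features.items()}
    for t in range(n_rounds):
        # Message step: every vertex sums the messages from its neighbors
        m = {v: sum(M(h[v], h[w], edge_features[(v, w)]) for w in neighbors[v])
             for v in neighbors}
        # Update step: every vertex computes its new hidden state
        h = {v: U(h[v], m[v]) for v in neighbors}
    # Readout over the final hidden states
    return R(h)

print(run_mpnn(n_rounds=1))
```

Note that all messages in a round are computed from the hidden states of the previous round before any vertex is updated.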

Here's a base class for any type of MPNN that encapsulates this procedure:

In [None]:
import mxnet as mx
from mxnet import nd, autograd, gluon
import sklearn.metrics as metrics
import numpy as np
import scipy as sp
import scipy.sparse  # imported explicitly so that sp.sparse is guaranteed to be available
import math
np.random.seed(1)
mx.random.seed(1)
ctx = mx.cpu() # NOTE: CHANGE THIS TO .gpu() IF YOU HAVE A GPU - MUCH TIME WILL BE SAVED!

class MPNN(gluon.Block):
    '''
    General base class for all varieties of Message Passing Neural Network.
    '''
    def __init__(self, n_msg_pass_iters, *args, **kwargs):
        super(MPNN, self).__init__(**kwargs)
        self.n_msg_pass_iters = n_msg_pass_iters
    
    def init_hidden_states_and_edges(self, graph):
        # Performs "init_hidden" from above and prepares adjacency information from the graph
        # (This function is here so the model can be flexible about what format the graph is given to us in.)
        raise NotImplementedError()
    
    def compute_messages(self, hidden_states, edges, t):
        # Computes M_t from above and sums the messages
        raise NotImplementedError()
    
    def update_hidden_states(self, hidden_states, messages, t):
        # Performs U_t from above
        raise NotImplementedError()
    
    def readout(self, hidden_states):
        # Performs R_t from above
        raise NotImplementedError()
        
    def forward(self, graph):
        hidden_states, edges = self.init_hidden_states_and_edges(graph)
        for t in range(self.n_msg_pass_iters):
            messages = self.compute_messages(hidden_states, edges, t)
            hidden_states = self.update_hidden_states(hidden_states, messages, t)
            
        return self.readout(hidden_states)

Different flavors of MPNN use different functions for $\text{init_hidden}$, $M_t$, $U_t$, and $R_t$, and more often than not these functions are simpler than the fully general versions described above.  For example, in the GGSNN version of MPNN we'll discuss below, $M_t$ is the same function for each $t$, and it depends only on the neighboring vertex's hidden state $h_w^t$, not on the receiving vertex's hidden state $h_v^t$ or on any edge features.

## Gated Graph Sequence Neural Networks

Now that we've got the MPNN framework down, let's grab some real data and implement a particular type of MPNN, called a Gated Graph Sequence Neural Network (GGSNN), to learn on it.

### An actual dataset

As a demonstration task, we'll use the [Tox21 dataset][1].  The objective of this dataset is to take in the [chemical structure of a molecule][2], represented as an undirected graph with atoms as vertices and bonds as edges, and predict the toxicity of the molecule.  In particular, we'll try to predict whether a molecule might [activate a particular cellular response to pollutants in your body][3].

To access the Tox21 dataset we use the DeepChem package, which you'll have to install either from its [website][4] or with this code:

[1]:https://tripod.nih.gov/tox21/challenge/
[2]:https://en.wikipedia.org/wiki/Structural_formula
[3]:https://pubchem.ncbi.nlm.nih.gov/bioassay/743122#section=Top
[4]:https://deepchem.io/

In [None]:
# WARNING: You may have to change this command to suit the python environment you're working in, e.g. if you're not using conda
# You also may need to restart your jupyter notebook and/or kernel
!conda install -y -c deepchem -c rdkit -c conda-forge -c omnia deepchem=2.0.0
!conda update -y numpy

Now we'll load the data and convert it to graph format.  If you're not fluent in chemistry, don't worry about the details of the following preprocessing.  We're just transforming the data from a molecular format into the format we're used to seeing from above.

What we'll end up with is a dataset of `(graph, label)` tuples where each `label` is a binary label (toxic or not), and each `graph` is an undirected graph represented as a vector of features for each vertex and an adjacency matrix.

In [None]:
import deepchem as dc
from deepchem.feat.mol_graphs import ConvMol

tox21_tasks, tox21_datasets, transformers = dc.molnet.load_tox21(featurizer='GraphConv')
train_mols, valid_mols, test_mols = tox21_datasets

def molecules_to_dataset(molecules):
    dataset = []
    for ind, (mols, targets, _, _) in enumerate(molecules.iterbatches(1, deterministic=True)):
        mol = mols[0]
        target = targets[0]
        adj_list = mol.get_adjacency_list()
        if any(adj_list):
            d = {}
            # Grabbing "NR-AhR" endocrine system results, which is the 3rd assay in targets
            d['label'] = target[2]
            # Grabbing features for each atom, including the element type and some other chemical information
            d['vertex_features'] = mol.get_atom_features()
            # Grabbing the connectivity and converting it to sparse array
            a = sp.sparse.dok_matrix((len(adj_list), len(adj_list)), dtype='float32')
            for i in range(len(adj_list)):
                for j in adj_list[i]:
                    a[i,j] = 1
            d['adj_mat'] = a.tocsr()
            dataset.append(d)
    return dataset

train_dataset = molecules_to_dataset(train_mols)
valid_dataset = molecules_to_dataset(valid_mols)
test_dataset = molecules_to_dataset(test_mols)

### The GGSNN model

Now we'll implement a Gated Graph Sequence Neural Network, introduced in [this paper][1], on this dataset.

A GGSNN is an MPNN with the following customizations:
1. Hidden states $h_v^0$ for each vertex are initialized with a single-layer MLP.
2. The messages passed to $v$ by its neighbors are a simple matrix multiplication of each neighbor's hidden state: $$m_v^{t+1} = \sum_{w \in \text{neighbors of }v} W_{\texttt{msg_fxn}}h_w^t.$$
3. Each vertex $v$ updates its hidden state to be the output of a [GRU cell](http://gluon.mxnet.io/chapter05_recurrent-neural-networks/gru-scratch.html) (a type of RNN cell) whose hidden state is the vertex's hidden state and whose input is the message the vertex received: $$h_v^{t+1} = \text{GRU}(m_v^{t+1}, h_v^t).$$
4. The "readout" function is this funny little beast: $$\hat{y} = \text{softmax}\left(f_{\text{out}}\left(\sum_{v} \sigma\left(f_1([h_v^T, h_v^0])\right) \odot f_2(h_v^T)\right)\right),$$ where the $f$s are MLPs, $\sigma$ is the sigmoid function, and $\odot$ is elementwise multiplication.  This acts like a sort of attention mechanism that depends on how much each vertex's hidden state changed during message passing.
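
The message step in item 2 can be computed for all vertices at once: multiplying the adjacency matrix by the matrix of transformed hidden states does the summing over each vertex's neighbors.  Here's a NumPy sketch that checks this against the explicit sum.  The toy graph, the sizes, and the random matrix `W` (standing in for $W_{\texttt{msg_fxn}}$) are arbitrary assumptions.

```python
import numpy as np

rng = np.random.RandomState(0)
n_vertices, hidden_size = 4, 3

# Adjacency matrix of a toy undirected graph: a triangle (0, 1, 2) plus the edge 2 - 3
adj = np.array([[0, 1, 1, 0],
                [1, 0, 1, 0],
                [1, 1, 0, 1],
                [0, 0, 1, 0]], dtype='float32')
H = rng.randn(n_vertices, hidden_size)    # row v holds the hidden state h_v
W = rng.randn(hidden_size, hidden_size)   # stand-in for the message weight matrix

# Explicit version: for each vertex v, sum W h_w over its neighbors w
messages_loop = np.zeros((n_vertices, hidden_size))
for v in range(n_vertices):
    for w in range(n_vertices):
        if adj[v, w]:
            messages_loop[v] += W @ H[w]

# Vectorized version: transform every hidden state, then let the adjacency matrix do the summing
messages_matrix = adj @ (H @ W.T)

print(np.allclose(messages_loop, messages_matrix))
```

Since molecular graphs have few edges per vertex, storing the adjacency matrix in a sparse format keeps this product cheap.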

Here's an implementation of GGSNN that fills out the details of the MPNN base class from above:

> *A key implementation note about what follows:* You'll notice below that the GGSNN is coded as though it takes in a single graph, rather than a minibatch of graphs as you might expect.  This is intentional.  We want to reserve the 0th/batch dimension of the tensors in our implementation to index over the vertices of the graph.  This makes the implementation more elegant, since MXNet operations are built to handle inputs that vary in size along the 0th dimension, and the number of vertices in each graph is usually different.

> But of course, we DO want to process minibatches of data.  To do this, combine a minibatch of graphs into a single, large, disconnected graph, do all the message passing on this graph (no messages will get passed between minibatch elements, because their graphs are disconnected), and use the `batch_sizes` list to produce separate outputs for each graph in the minibatch in the `readout` step.

[1]:https://arxiv.org/pdf/1511.05493.pdf
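
The minibatching trick from the note above can be sketched in a few lines of NumPy.  This sketch assumes dense adjacency matrices for readability, and the helper `block_diagonal` and the two toy graphs are hypothetical.

```python
import numpy as np

def block_diagonal(mats):
    """Place square matrices along the diagonal of one big matrix of zeros."""
    n = sum(m.shape[0] for m in mats)
    out = np.zeros((n, n), dtype='float32')
    start = 0
    for m in mats:
        size = m.shape[0]
        out[start:start + size, start:start + size] = m
        start += size
    return out

# Two toy graphs: a single edge (2 vertices) and a 3-vertex path
adj_a = np.array([[0, 1], [1, 0]], dtype='float32')
adj_b = np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype='float32')
h_a = np.array([[1.0], [2.0]])            # one hidden unit per vertex
h_b = np.array([[10.0], [20.0], [30.0]])

big_adj = block_diagonal([adj_a, adj_b])
big_h = np.concatenate([h_a, h_b])
batch_sizes = [2, 3]                      # number of vertices in each graph

# One round of neighbor-summing on the big graph: nothing leaks between the two graphs
big_msgs = big_adj @ big_h
print(np.allclose(big_msgs, np.concatenate([adj_a @ h_a, adj_b @ h_b])))

# Per-graph readout: sum only the rows belonging to each graph
outputs, start = [], 0
for size in batch_sizes:
    outputs.append(big_msgs[start:start + size].sum(axis=0))
    start += size
print(outputs)
```

Because the off-diagonal blocks are all zeros, message passing on the combined graph gives exactly the same per-vertex results as running it on each graph separately.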

In [None]:
class GGSNN(MPNN):
    '''
    GGSNN model for operating on the Tox21 dataset
    '''
    def __init__(self, vertex_feature_size, hidden_size, output_size, **kwargs):
        super(GGSNN, self).__init__(**kwargs)
        
        # Initializing model components
        with self.name_scope():
            self.vertex_init = gluon.nn.Dense(hidden_size, in_units=vertex_feature_size)
            self.message_fxn = gluon.nn.Dense(hidden_size, in_units=hidden_size, use_bias=False)
            self.gru = gluon.rnn.GRUCell(hidden_size, input_size=hidden_size)
            self.readout_1 = gluon.nn.Sequential()
            with self.readout_1.name_scope():
                self.readout_1.add(gluon.nn.Dense(hidden_size*2, activation='tanh'))
                self.readout_1.add(gluon.nn.Dense(hidden_size))
            self.readout_2 = gluon.nn.Sequential()
            with self.readout_2.name_scope():
                self.readout_2.add(gluon.nn.Dense(hidden_size, activation='tanh'))
                self.readout_2.add(gluon.nn.Dense(hidden_size))
            self.readout_final = gluon.nn.Dense(output_size, in_units=hidden_size)
                
    def init_hidden_states_and_edges(self, graph):
        # vertex_features is a (num_vertices x num_features) NDArray
        # edges is a (num_vertices x num_vertices) sparse NDArray
        # batch_sizes is a list of the sizes of the graphs in the batch that were combined into the graph
        vertex_features, edges, batch_sizes = graph
        init_hidden_states = nd.tanh(self.vertex_init(vertex_features))
        # Saving these for use in the readout function later - not every MPNN requires this, but GGSNNs do
        self.init_hidden_states = init_hidden_states.copy()
        self.batch_sizes = batch_sizes
        return init_hidden_states, edges
    
    def compute_messages(self, hidden_states, edges, t):
        passed_msgs = self.message_fxn(hidden_states)
        summed_msgs = nd.sparse.dot(edges, passed_msgs)
        return summed_msgs
    
    def update_hidden_states(self, hidden_states, messages, t):
        hidden_states, _ = self.gru(messages, [hidden_states])
        return hidden_states
    
    def readout(self, hidden_states):
        readout_in_1 = nd.concat(hidden_states, self.init_hidden_states, dim=1)
        readout_hid_1 = nd.sigmoid(self.readout_1(readout_in_1))
        readout_hid_2 = self.readout_2(hidden_states)
        readout_hid = readout_hid_1 * readout_hid_2
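        # Sum the gated vertex outputs separately for each graph in the minibatch
        # (iterating instead of popping, so self.batch_sizes is left intact)
        readout_attention = []
        start = 0
        for size in self.batch_sizes:
            readout_attention.append(nd.sum(readout_hid[start:start+size], axis=0, keepdims=True))
            start += size
        readout_attention = nd.concat(*readout_attention, dim=0)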
        return self.readout_final(readout_attention)
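
The `update_hidden_states` step above hands all the work to a GRU cell.  As a rough NumPy sketch of what such a cell computes, here's one common formulation of the GRU equations, with the message as the input and the vertex's hidden state as the recurrent state.  The weight names and random values are hypothetical, biases are left out for brevity, and the exact gate layout inside `gluon.rnn.GRUCell` may differ in detail.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gru_update(m, h, params):
    """One GRU step applied row-wise: m and h are (num_vertices x hidden_size)."""
    W_r, U_r, W_z, U_z, W_n, U_n = params
    r = sigmoid(m @ W_r + h @ U_r)        # reset gate: how much of h feeds the candidate
    z = sigmoid(m @ W_z + h @ U_z)        # update gate: how much of the old h to keep
    n = np.tanh(m @ W_n + r * (h @ U_n))  # candidate hidden state
    return (1 - z) * n + z * h            # blend of the candidate and the old state

rng = np.random.RandomState(0)
num_vertices, hidden_size = 5, 4
params = [0.1 * rng.randn(hidden_size, hidden_size) for _ in range(6)]
messages = rng.randn(num_vertices, hidden_size)
hidden = np.tanh(rng.randn(num_vertices, hidden_size))

new_hidden = gru_update(messages, hidden, params)
print(new_hidden.shape)
```

The update gate lets each vertex decide, per hidden unit, whether to keep its old state or overwrite it with information from the incoming message, which is what makes many rounds of message passing trainable.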

### Let's train!

We'll create a new GGSNN instance and initialize our model parameters, loss function, and optimizer as usual:

In [None]:
model = GGSNN(vertex_feature_size=75, hidden_size=100, output_size=2, n_msg_pass_iters=6)
model.collect_params().initialize(mx.init.Normal(sigma=.01), ctx=ctx)
softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss()
trainer = gluon.Trainer(model.collect_params(), 'adam', {'learning_rate': .0002})

Then we'll add a few helper functions to keep the training loop code clean:

In [None]:
def batchify_graphs(graphs):
    '''
    Args:
        graphs: List of graphs, each a dict with 'vertex_features', 'adj_mat', and 'label' keys
        
    Returns:
        A (vertex_features, adj_mat, batch_sizes) tuple: the input graphs combined into one big disconnected graph
        The labels of each of the input graphs
    '''
    vertex_features = np.concatenate([g['vertex_features'] for g in graphs])
    vertex_features = nd.array(vertex_features, dtype='float32', ctx=ctx)
    adj_mat = sp.sparse.block_diag([g['adj_mat'] for g in graphs]).tocsr()
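    # Passing the shape explicitly so it isn't inferred wrongly when the last vertices have no edges
    adj_mat = nd.sparse.csr_matrix((adj_mat.data, adj_mat.indices, adj_mat.indptr), shape=adj_mat.shape, dtype='float32', ctx=ctx)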
    batch_sizes = [g['vertex_features'].shape[0] for g in graphs]
    labels = nd.array([g['label'] for g in graphs], dtype='float32', ctx=ctx)
    return (vertex_features, adj_mat, batch_sizes), labels

In [None]:
def evaluate_accuracy(dataset, model, n_batch):
    '''
    Measures the accuracy of the model on the provided dataset, in batches
    '''
    acc = mx.metric.Accuracy()
    for i in range(0, math.ceil(len(dataset)/n_batch)):
        data = dataset[n_batch*i:n_batch*(i+1)]
        graph, label = batchify_graphs(data)
        output = model(graph)
        predictions = nd.argmax(output, axis=1)
        acc.update(preds=predictions, labels=label)
    return acc.get()[1]

In [None]:
def evaluate_roc_score(dataset, model, n_batch):
    '''
    Measures the area under the ROC curve of the model on the provided dataset, in batches
    '''
    pos_probs = []
    labels = []
    for i in range(0, math.ceil(len(dataset)/n_batch)):
        data = dataset[n_batch*i:n_batch*(i+1)]
        graph, label = batchify_graphs(data)
        output = model(graph)
        pos_probs.append(nd.softmax(output)[:,1])
        labels.append(label)
    labels = nd.concat(*labels, dim=0).asnumpy()
    pos_probs = nd.concat(*pos_probs, dim=0).asnumpy()
    return metrics.roc_auc_score(labels, pos_probs)

Notice that the class balance in the dataset is heavily skewed toward the "not toxic" label:

In [None]:
labels = np.array([i['label'] for i in train_dataset])
print('Fraction of "not toxic" labels in training data = {}'.format(np.mean(labels == 0)))

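With an imbalance like this, a model that always predicts "not toxic" would already score a high accuracy.  This is why in the training loop below we're measuring the [ROC AUC](https://en.wikipedia.org/wiki/Receiver_operating_characteristic), in addition to just the accuracy.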
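
To see why accuracy alone can mislead on skewed data, here's a small self-contained sketch using synthetic labels and scores (none of these numbers come from Tox21).  The `roc_auc` helper is a plain NumPy stand-in that computes the AUC as the probability that a randomly chosen positive example is scored above a randomly chosen negative one.

```python
import numpy as np

def accuracy(labels, scores, threshold=0.5):
    return np.mean((scores >= threshold) == (labels == 1))

def roc_auc(labels, scores):
    """Probability that a random positive is scored above a random negative (ties count half)."""
    pos = scores[labels == 1]
    neg = scores[labels == 0]
    wins = (pos[:, None] > neg[None, :]).sum()
    ties = (pos[:, None] == neg[None, :]).sum()
    return (wins + 0.5 * ties) / (len(pos) * len(neg))

# Synthetic labels: 90 "not toxic" (0) and 10 "toxic" (1)
toy_labels = np.array([0] * 90 + [1] * 10)

# A lazy model that gives every molecule the same low score
lazy_scores = np.full(100, 0.1)
# A model that ranks every toxic molecule above every non-toxic one,
# but whose scores never cross the 0.5 decision threshold
ranking_scores = np.where(toy_labels == 1, 0.4, 0.2)

for name, scores in [('lazy', lazy_scores), ('ranking', ranking_scores)]:
    print(name, accuracy(toy_labels, scores), roc_auc(toy_labels, scores))
```

Both models get the same accuracy because both predict "not toxic" everywhere, but only the ROC AUC reveals that the second one has actually learned to rank toxic molecules higher.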

Now let's train!

In [None]:
n_epochs = 30
n_batch = 128

for e in range(n_epochs):
    cumulative_loss = 0
    for i in range(0, math.ceil(len(train_dataset)/n_batch)):
        data = train_dataset[n_batch*i:n_batch*(i+1)]
        graph, label = batchify_graphs(data)
        with autograd.record():
            output = model(graph)
            loss = softmax_cross_entropy(output, label)
        loss.backward()
        trainer.step(len(data))
        cumulative_loss += nd.sum(loss).asscalar()
    
    valid_accuracy = evaluate_accuracy(valid_dataset, model, n_batch)
    train_accuracy = evaluate_accuracy(train_dataset, model, n_batch)
    valid_roc = evaluate_roc_score(valid_dataset, model, n_batch)
    train_roc = evaluate_roc_score(train_dataset, model, n_batch)
    print('Epoch {}. Loss: {}, \n\tTrain_acc {}, Valid_acc {}\n\tTrain_roc_auc {}, Valid_roc_auc {}'.format(
            e, cumulative_loss/len(train_dataset), train_accuracy, valid_accuracy, train_roc, valid_roc))
    
print('Test Accuracy: {}'.format(evaluate_accuracy(test_dataset, model, n_batch)))
print('Test ROC AUC: {}'.format(evaluate_roc_score(test_dataset, model, n_batch)))

Alright!  Unsurprisingly, given the class imbalance, our accuracy didn't improve much; but our ROC score got much better, in line with the current state of the art on this dataset: see the physiology section [here](http://moleculenet.ai/latest-results).

Now go forth and invent your own types of MPNNs!