Message Passing with DGL
=====================

In the previous tutorial (1_Basics.ipynb), we covered the basic usage of DGL, such as creating and manipulating a `DGLGraph`. In this tutorial, we focus on how to perform computation on graph structures following the message passing paradigm.

In [None]:
# Use MXNet as backend
import os
os.environ['DGLBACKEND'] = 'mxnet'

# A bit of setup, just ignore this cell
import matplotlib.pyplot as plt

# for auto-reloading external modules
%load_ext autoreload
%autoreload 2

%matplotlib inline
plt.rcParams['figure.figsize'] = (8.0, 6.0) # set default size of plots
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'
plt.rcParams['animation.html'] = 'html5'

In this tutorial, we again use the karate club network as the example. A utility function is provided to create the graph:

In [None]:
import dgl
import mxnet.ndarray as nd
import networkx as nx
from tutorial_utils import create_karate_graph
G = create_karate_graph()

![](https://www.dropbox.com/s/uqzor4lqsmbnz8k/karate1.jpg?dl=1)

Message passing on graph
-------------------------------------------------

Many graph neural networks follow the **message passing** computation model ([Gilmer et al., 2017](https://arxiv.org/abs/1704.01212)):
- Each node receives and aggregates messages from its neighbors:
$$m_v^{t+1} = \sum\limits_{w\in \mathcal{N}(v)}M_t(h_v^t, h_w^t, e_{vw}^t)$$
where $\mathcal{N}(v)$ is the neighbor set of node $v$, $h_v^t$ is the embedding of node $v$ at step $t$, $e_{vw}^t$ is the feature of the edge between $v$ and $w$, and $M_t$ is the message function.

- Each node updates its own embedding using the aggregated messages, where $U_t$ is the update function:
$$h_v^{t+1} = U_t(h_v^t, m_v^{t+1})$$
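
The two equations can be sketched in plain Python, independently of DGL. The toy graph, the feature values, and the `message` and `update` functions below are made-up placeholders for $M_t$ and $U_t$ (they are not DGL APIs), and edge features are left out for brevity.

```python
# Toy graph: each key is a node v, each value lists its neighbor set N(v).
neighbors = {0: [1, 2], 1: [0], 2: [0, 1]}
# Scalar node embeddings h_v^t (illustrative values).
h = {0: 1.0, 1: 2.0, 2: 3.0}

def message(h_v, h_w):
    # Placeholder for M_t: the message is simply the neighbor's embedding.
    return h_w

def update(h_v, m_v):
    # Placeholder for U_t: add the aggregated message to the current embedding.
    return h_v + m_v

# Aggregation step: m_v^{t+1} is the sum of messages from all neighbors of v.
m = {v: sum(message(h[v], h[w]) for w in neighbors[v]) for v in neighbors}
# Update step: h_v^{t+1} = U_t(h_v^t, m_v^{t+1}).
h_next = {v: update(h[v], m[v]) for v in neighbors}

print(m)
print(h_next)
```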

We will go through the basic mechanism of message passing using a toy task:

Suppose the karate club president (node 33) is sending out an invitation to the club's annual karate match. The president also asks the club members to pass the news on to their friends in the club. We use a scalar to represent whether a member has received the invitation (1 for invited, 0 for not invited). Initially, every node has state 0 except node 33.

In [None]:
# We first convert the uni-directional edges to bi-directional ones so that
#   messages can be sent in both directions.
src, dst = G.edges()
G.add_edges(dst, src)
# Add a self-loop to each node for convenience.
v = G.nodes()
G.add_edges(v, v)
print('We now have %d edges!' % G.number_of_edges())

# init the state
G.ndata['invited'] = nd.zeros((34,))
G.nodes[33].data['invited'] = nd.array([1.])
print(G.ndata['invited'])

We first define the function that computes the messages. In DGL, the message function is an **Edge UDF** that takes a single argument, `edges`. This argument has three members, `src`, `dst`, and `data`, for accessing source node features, destination node features, and edge features, respectively.

In [None]:
def message_func(edges):
    # The message is simply the 'invited' state of the source nodes.
    return {'msg' : edges.src['invited']}

Next, we define the reduce function, which accumulates and consumes the messages to update the node features. In DGL, the reduce function is a **Node UDF** that takes a single argument, `nodes`, which has two members, `data` and `mailbox`. `data` contains the node features, while `mailbox` contains all incoming message features, stacked along the second dimension (hence the `axis=1` argument).

In [None]:
def reduce_func(nodes):
    # The reduce function sets the 'invited' state to one if the node has already
    #   been invited or if any of the received messages contains an invitation (is one).
    #   This can be done using sum and clip operations as follows.
    accum = nodes.mailbox['msg'].sum(axis=1)  # messages are stacked along axis 1
    return {'invited' : accum.clip(a_min=float("-inf"), a_max=1)}
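
To see why a sum followed by a clip behaves like a logical OR over the incoming messages, here is a small plain-Python sketch. The `clip` helper and the mailbox values are illustrative assumptions rather than DGL or MXNet code. The last entry of each row stands for the node's own state, which assumes messages are also sent along the self-loop edges added earlier.

```python
def clip(x, lo, hi):
    # Scalar stand-in for a tensor clip operation (hypothetical helper).
    return max(lo, min(x, hi))

# One row of incoming 'invited' messages per node; the last entry of each row
# is the node's own state, delivered through its self-loop.
mailbox = [
    [0.0, 0.0, 0.0],  # not invited, no invited neighbor
    [1.0, 0.0, 0.0],  # one invited neighbor
    [1.0, 1.0, 0.0],  # two invited neighbors
    [0.0, 0.0, 1.0],  # already invited, neighbors are not
]

# Sum each row, then cap the result at 1 so the state stays binary.
new_state = [clip(sum(row), 0.0, 1.0) for row in mailbox]
print(new_state)
```

The cap matters in the third row, where the raw sum would be 2; the self-loop entry matters in the fourth row, where it keeps an already invited node at 1.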

To trigger the message and reduce functions, use the `send` and `recv` APIs. The following code sends out the messages from node 33:

In [None]:
# The first argument to `G.send` is the edges along which the messages are sent.
# Note that we can use the same syntax used in adding edges to the graph.
# The second argument is the message function we just defined.
G.send((33, G.successors(33)), message_func)

We then call `recv` on the receiver nodes to trigger the reduce function.

In [None]:
G.recv(G.successors(33), reduce_func)

You can print out the `'invited'` status to see the invitation being propagated.

In [None]:
print(G.ndata['invited'])

**What's under the hood?**

The key idea is that DGL automatically batches the node and edge features so that your UDFs can compute message passing on multiple nodes and edges in parallel.

```python
def message_func(edges):
    return {'msg' : edges.src['invited']}
```

The `edges` argument is an `EdgeBatch` object representing a batch of edges. It has three members: `src`, `dst`, and `data`. `edges.src['invited']` returns a tensor of shape `(B,)`, where `B` is the number of edges being triggered.
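
As a rough mental model of edge batching (not DGL's actual implementation), the source features of a batch of edges can be thought of as a gather over the node feature array. The 4-node graph and the edge lists below are made up for illustration.

```python
# Node feature 'invited' for a 4-node toy graph; only node 3 is invited.
invited = [0.0, 0.0, 0.0, 1.0]

# A batch of B = 3 edges, given as parallel source and destination lists.
src = [3, 3, 0]
dst = [0, 1, 2]

# Conceptually what edges.src['invited'] holds: one source feature per edge.
src_invited = [invited[u] for u in src]

# The message function runs once on the whole batch and returns one message per edge.
msg = src_invited
print(msg)
```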

```python
def reduce_func(nodes):
    accum = nodes.mailbox['msg'].sum(axis=1)
    return {'invited' : accum.clip(a_min=float("-inf"), a_max=1)}
```

Similarly, the `nodes` argument of the reduce function is a `NodeBatch` object representing a batch of nodes. It has two members: `data` and `mailbox`. `nodes.mailbox['msg']` returns a tensor of shape `(B, deg)`, where `B` is the number of nodes that share the same in-degree `deg`. Because nodes are grouped by in-degree, the reduce function may be called several times in one `recv`, once for each degree group.
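
The degree grouping can be illustrated with a minimal plain-Python sketch. This only shows the idea and makes no claim about how DGL implements it internally; the `inbox` data and the `reduce_batch` helper are hypothetical.

```python
from collections import defaultdict

# Incoming messages per destination node (hypothetical toy data).
inbox = {0: [1.0, 0.0], 1: [0.0], 2: [0.0, 0.0], 3: [1.0]}

# Group nodes by in-degree so that each group's mailbox is a rectangular
# (B, deg) batch: B nodes, each with exactly deg messages.
buckets = defaultdict(list)
for node, msgs in inbox.items():
    buckets[len(msgs)].append(node)

def reduce_batch(mailbox):
    # mailbox is a list of B rows with deg messages each; sum each row and cap at 1.
    return [min(sum(row), 1.0) for row in mailbox]

invited = {}
for deg, nodes in sorted(buckets.items()):
    mailbox = [inbox[n] for n in nodes]      # shape (B, deg)
    results = reduce_batch(mailbox)          # one call per degree group
    for n, value in zip(nodes, results):
        invited[n] = value

print(dict(buckets))
print(invited)
```

Here `reduce_batch` runs twice, once for the nodes with in-degree 1 and once for those with in-degree 2.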

In [None]:
# Exercise: write code to keep broadcasting the invitation until all members in the graph are invited.

# Hint 1: you can trigger the message function on all edges and the reduce function on all nodes.
# Hint 2: you can get all edges with G.edges() and all nodes with G.nodes().

# >>> YOUR CODE STARTS

num_invited = int(G.ndata['invited'].sum().asscalar())
print("{} members invited".format(num_invited))

while num_invited < 34:
    G.send(G.edges(), message_func)
    G.recv(G.nodes(), reduce_func)
    num_invited = int(G.ndata['invited'].sum().asscalar())
    print("{} members invited".format(num_invited))
    
# <<< YOUR CODE ENDS
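
One way to reason about the loop is to simulate the same broadcast without DGL. The sketch below uses a made-up 4-node path graph rather than the karate club graph, with the same bidirectional edges, self-loops, and sum-then-cap reduction. Note that such a loop terminates only if every node is reachable from the initially invited node.

```python
# Path graph 0 - 1 - 2 - 3 (illustrative, not the karate club graph).
edges = [(0, 1), (1, 2), (2, 3)]
edges += [(v, u) for u, v in edges]   # make edges bidirectional
edges += [(v, v) for v in range(4)]   # add self-loops

invited = [0.0, 0.0, 0.0, 1.0]        # only node 3 starts invited
history = [int(sum(invited))]         # number of invited nodes after each round

while sum(invited) < len(invited):
    # "send": every edge carries the source node's state to its destination.
    inbox = [0.0] * len(invited)
    for u, v in edges:
        inbox[v] += invited[u]
    # "recv": sum of incoming messages, capped at 1.
    invited = [min(x, 1.0) for x in inbox]
    history.append(int(sum(invited)))

print(history)
```

On a path, the invitation advances one hop per round, so the number of rounds equals the distance from the starting node to the farthest node.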