Speed up GNN training
===================
In the previous tutorial (GCN.ipynb), we saw how to implement an end-to-end GCN model for community detection with DGL.

Training looked fast in the previous example only because the graph is tiny (34 nodes, 190 edges) and the node features are low-dimensional. In practice, we often deal with large graphs that have millions of nodes and edges or more, each associated with large features.

There are two main challenges in making computation over graphs efficient:
- Nodes have different degrees (usually following a power-law distribution), so the best a generic implementation can do is batch the reduce function over nodes with the same in-degree.
- The number of edges is usually an order of magnitude larger than the number of nodes, and materializing messages on edges consumes a huge amount of memory.
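Here is a back-of-the-envelope estimate that puts the second challenge in numbers. It assumes a hypothetical graph with 10 million nodes, 100 million edges (ten times more edges than nodes), and 128-dimensional float32 features, where every edge carries a message of the same size as a node feature.

```python
# Back-of-the-envelope estimate of feature memory (float32 = 4 bytes per value).
BYTES_PER_FLOAT32 = 4

def node_memory_gb(num_nodes: int, feat_dim: int) -> float:
    """GB (10**9 bytes) needed to store one feature vector per node."""
    return num_nodes * feat_dim * BYTES_PER_FLOAT32 / 1e9

def message_memory_gb(num_edges: int, feat_dim: int) -> float:
    """GB needed to materialize one message of the same size on every edge."""
    return num_edges * feat_dim * BYTES_PER_FLOAT32 / 1e9

# Hypothetical graph: 10M nodes, 100M edges, 128-dim features.
print(f"node features: {node_memory_gb(10_000_000, 128):.2f} GB")     # prints "node features: 5.12 GB"
print(f"edge messages: {message_memory_gb(100_000_000, 128):.2f} GB")  # prints "edge messages: 51.20 GB"
```

The node features alone fit on a large GPU. Materializing the messages on edges multiplies that by the edge/node ratio.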

As a result, training GNN models on large graphs is easily slow and can even fail with an out-of-memory error. This tutorial provides optimization guidelines for writing efficient GNN models with DGL.

In [None]:
# A bit of setup, just ignore this cell
import matplotlib.pyplot as plt

# for auto-reloading external modules
%load_ext autoreload
%autoreload 2

%matplotlib inline
plt.rcParams['figure.figsize'] = (8.0, 6.0) # set default size of plots
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'
plt.rcParams['animation.html'] = 'html5'

Fuse message reduction into one kernel
--------------------------------------------------------------------------------

To address the first challenge, DGL exploits the fact that if the reduce function is a summation, the reduce phase can be replaced with a sparse matrix-vector multiplication (SPMV) between the [incidence matrix](https://en.wikipedia.org/wiki/Incidence_matrix) and the messages (one SPMV per feature dimension, or equivalently a single sparse-dense matrix product), as illustrated below:
![](https://www.dropbox.com/s/cditivsb50w2i5c/fuse_reduce.png?dl=1)

In the figure, $M_{ij}$ represents the message sent from node $i$ to node $j$, and $M_i$ represents the aggregated messages received by node $i$. The incidence matrix is a sparse matrix that encodes the connectivity between edges and their destination nodes: each row represents a destination node and each column represents an edge. If the entry at $(i, j)$ is $1$, then edge $j$ ends at node $i$.
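Here is a minimal PyTorch sketch of the idea on a tiny hypothetical graph with 3 nodes and 4 edges. We build the incidence matrix as a sparse COO tensor and multiply it with the stacked per-edge messages. The result matches summing each node's incoming messages with an explicit scatter-add. `torch.sparse.mm` stands in for DGL's internal kernel here; this sketch is not how DGL itself is implemented.

```python
import torch

# A tiny hypothetical directed graph: 3 nodes, 4 edges (src[e] -> dst[e]).
src = torch.tensor([0, 1, 2, 0])
dst = torch.tensor([1, 2, 1, 2])
num_nodes, num_edges = 3, src.numel()

# One 2-dimensional message per edge, stacked into a (num_edges, 2) dense matrix.
messages = torch.arange(num_edges * 2, dtype=torch.float32).reshape(num_edges, 2)

# Incidence matrix B (num_nodes x num_edges): B[v, e] = 1 iff edge e ends at node v.
indices = torch.stack([dst, torch.arange(num_edges)])
B = torch.sparse_coo_tensor(indices, torch.ones(num_edges), (num_nodes, num_edges))

# Sum-reduce as one sparse-dense product: row v of B @ messages sums the messages into v.
reduced = torch.sparse.mm(B, messages)

# Reference: explicit scatter-add of each message onto its destination node.
expected = torch.zeros(num_nodes, 2).index_add_(0, dst, messages)
assert torch.allclose(reduced, expected)
```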

By replacing the reduce phase with a single sparse matrix multiplication kernel, DGL avoids the cost of:
- analyzing the graph structure and assigning receiving nodes to execution buckets based on their in-degree
- looping over each degree bucket and performing the reduce
- merging the results of all degree buckets
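For contrast, here is a sketch of the degree-bucketing strategy that the fused kernel avoids. The helper `degree_bucketing_reduce` is hypothetical and only illustrates the general technique, not DGL's actual scheduler. It works with any reduce function, which is why batching by in-degree is the fallback when the reduce function is opaque. The price is the three steps listed above.

```python
import torch

def degree_bucketing_reduce(dst, messages, num_nodes, reduce_fn):
    """Reduce each node's incoming messages, batching nodes that share an in-degree.

    reduce_fn maps a (bucket_size, degree, feat_dim) tensor to (bucket_size, feat_dim).
    Nodes with no incoming messages are left as zeros.
    """
    out = torch.zeros(num_nodes, messages.shape[1], dtype=messages.dtype)
    # 1) Analyze the graph structure: in-degree of every node.
    in_deg = torch.bincount(dst, minlength=num_nodes)
    # Sort edges by destination so each node's messages become contiguous.
    sorted_msgs = messages[torch.argsort(dst)]
    offsets = torch.cumsum(in_deg, dim=0) - in_deg  # first message index of each node
    # 2) Loop over the degree buckets and reduce each one as a dense batch.
    for deg in torch.unique(in_deg).tolist():
        if deg == 0:
            continue
        nodes = (in_deg == deg).nonzero(as_tuple=True)[0]
        idx = offsets[nodes].unsqueeze(1) + torch.arange(deg).unsqueeze(0)
        # 3) Merge: write this bucket's results back into the output.
        out[nodes] = reduce_fn(sorted_msgs[idx])
    return out

# A tiny hypothetical graph: 3 nodes, 4 edges ending at nodes 1, 2, 1, 2.
dst = torch.tensor([1, 2, 1, 2])
messages = torch.arange(8, dtype=torch.float32).reshape(4, 2)
summed = degree_bucketing_reduce(dst, messages, 3, lambda m: m.sum(dim=1))
maxed = degree_bucketing_reduce(dst, messages, 3, lambda m: m.max(dim=1).values)
```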

To enable this optimization, DGL needs hints about what the reduce function does. DGL therefore provides commonly used builtin reduce functions such as `sum` and `max`. Using them saves users the trouble of defining their own reduce function, and it also tells DGL exactly what the reduce function computes so that the optimization can be applied.

Now we can redefine the GCN model to use the SPMV reduce optimization; DGL automatically generates the sparse matrix for fused execution.

In [None]:
import dgl.function as fn
import torch.nn as nn
import torch.nn.functional as F

# Define the message function; the reduce function is DGL's builtin fn.sum.
# NOTE: we ignore the normalization constant c_ij for now.
def gcn_message(edges):
    # messages are the features of the source nodes.
    return {'msg' : edges.src['h']}

# Define the GCN module
class GCN(nn.Module):
    def __init__(self, in_feats, out_feats):
        super(GCN, self).__init__()
        self.linear = nn.Linear(in_feats, out_feats)
    
    def forward(self, g, inputs):
        # g is the graph and inputs are the input node features
        # first perform linear transformation
        h = self.linear(inputs)
        # set the node features
        g.ndata['h'] = h
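        # trigger message passing: `send` computes messages on all edges,
        # then `recv` applies the builtin sum reduce (eligible for SPMV fusion)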
        g.send(g.edges(), gcn_message)
        g.recv(g.nodes(), fn.sum('msg', 'h'))
        # get the result node features
        h = g.ndata.pop('h')
        return h

### Fuse message passing into one kernel
Fusing the reduce phase into one kernel already reduces execution time significantly. If the message function also follows a known pattern, such as copying the source node representation (DGL's builtin message function `copy_src`), it can be fused with the reduce function as well:
![](https://www.dropbox.com/s/ws62v6ukjx968fb/fuse_mp.png?dl=1)

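Here the sparse matrix is the transpose of the adjacency matrix, which encodes the connectivity between nodes. Because the messages are consumed inside the kernel, they never need to be stored on edges. This also addresses the second challenge described above, the memory cost of materialized messages.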
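Here is a minimal PyTorch sketch of the fused version on a tiny hypothetical graph. Copying source features and summing them at each destination equals one sparse-dense product with $A^T$, where $A_{ij} = 1$ iff there is an edge $i \to j$. The unfused reference explicitly builds the per-edge message tensor `h[src]` of shape (num_edges, feat_dim), and that tensor is exactly the memory cost that fusion avoids. `torch.sparse.mm` stands in for DGL's kernel.

```python
import torch

# A tiny hypothetical directed graph: 3 nodes, 4 edges (src[e] -> dst[e]).
src = torch.tensor([0, 1, 2, 0])
dst = torch.tensor([1, 2, 1, 2])
num_nodes = 3
h = torch.tensor([[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]])  # node features

# A^T (num_nodes x num_nodes): row = destination, column = source.
A_T = torch.sparse_coo_tensor(torch.stack([dst, src]), torch.ones(src.numel()),
                              (num_nodes, num_nodes))

# Fused copy_src + sum: one SpMM, no per-edge message tensor is created.
fused = torch.sparse.mm(A_T, h)

# Unfused reference: materialize one message per edge, then scatter-add by destination.
messages = h[src]  # shape (num_edges, feat_dim)
unfused = torch.zeros_like(h).index_add_(0, dst, messages)
assert torch.allclose(fused, unfused)
```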

To trigger fusion of the entire message passing, DGL needs to know both the message function and the reduce function. DGL provides many routines that combine the basic `send` and `recv` in various ways; they are called **level-2 APIs**. For example, the `send_and_recv` API triggers both the message function and the reduce function in a single call. Furthermore, since GCN performs computation on the entire graph, we can use the `update_all` API in the GCN module so that `edges()` and `nodes()` can be omitted.
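A toy dispatcher shows why DGL needs to see both functions. `toy_update_all` and its string markers `"copy_src"` and `"sum"` are hypothetical stand-ins for DGL's builtins, not DGL's API. When both markers are recognized, the whole step collapses into one sparse product with $A^T$. When the message function is an arbitrary Python callable, nothing can be known about its output, so it has to be materialized on every edge first.

```python
import torch

def toy_update_all(src, dst, h, message_fn, reduce_fn):
    """Hypothetical mini update_all over all edges; returns (new node features, path).

    message_fn: the marker "copy_src" (a known builtin) or any callable mapping
    source-node features (num_edges, d) to messages (num_edges, d').
    reduce_fn: only the marker "sum" is supported in this toy.
    """
    if reduce_fn != "sum":
        raise NotImplementedError("this toy only supports a sum reduce")
    num_nodes = h.shape[0]
    if message_fn == "copy_src":
        # Both functions are known: fuse into one SpMM with A^T, no per-edge tensor.
        a_t = torch.sparse_coo_tensor(torch.stack([dst, src]), torch.ones(src.numel()),
                                      (num_nodes, num_nodes))
        return torch.sparse.mm(a_t, h), "fused"
    # Opaque callable: messages must be materialized per edge, then scatter-added.
    msgs = message_fn(h[src])
    out = torch.zeros(num_nodes, msgs.shape[1], dtype=msgs.dtype).index_add_(0, dst, msgs)
    return out, "unfused"

# Usage on a tiny hypothetical graph.
src = torch.tensor([0, 1, 2, 0])
dst = torch.tensor([1, 2, 1, 2])
h = torch.tensor([[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]])
out1, p1 = toy_update_all(src, dst, h, "copy_src", "sum")
out2, p2 = toy_update_all(src, dst, h, lambda hs: hs, "sum")  # same math, opaque to the dispatcher
```

Both calls compute the same result. Only the builtin marker lets the dispatcher take the fused path.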

In [None]:
# Re-define the GCN module using DGL builtin functions and level-2 APIs.
class GCN_level2(nn.Module):
    def __init__(self, in_feats, out_feats):
        super(GCN_level2, self).__init__()
        self.linear = nn.Linear(in_feats, out_feats)
    
    def forward(self, g, inputs):
        # g is the graph and inputs are the input node features
        # first perform linear transformation
        h = self.linear(inputs)
        # set the node features
        g.ndata['h'] = h
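        # trigger message passing using `update_all`,
        # equivalent to the earlier two-step version:
        #   g.send(g.edges(), gcn_message)
        #   g.recv(g.nodes(), fn.sum('msg', 'h'))
        # (newer DGL releases replace the builtin `fn.copy_src` with `fn.copy_u`)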
        g.update_all(fn.copy_src('h', 'msg'), fn.sum('msg', 'h'))
        # get the result node features
        h = g.ndata.pop('h')
        return h

Now let's define the GCN community detection model using `GCN_level2` and retrain it on the karate club graph from the previous example.

In [None]:
import dgl
import networkx as nx
import torch
import torch.nn as nn
import torch.nn.functional as F
from tutorial_utils import create_karate_graph, convert_to_bidirectional

# Define a 2-layer GCN model
class Net(nn.Module):
    def __init__(self, in_feats, hidden_size, num_classes):
        super(Net, self).__init__()
        self.gcn1 = GCN_level2(in_feats, hidden_size)
        self.gcn2 = GCN_level2(hidden_size, num_classes)
    
    def forward(self, g, inputs):
        h = self.gcn1(g, inputs)
        h = torch.relu(h)
        h = self.gcn2(g, h)
        return h

G = create_karate_graph()
GG = convert_to_bidirectional(G)

inputs = torch.eye(34)  # featureless inputs
labeled_nodes = torch.tensor([0, 33])  # only the instructor and the president nodes are labeled
labels = torch.tensor([0, 1])  # their labels are different
net = Net(34, 5, 2)
optimizer = torch.optim.Adam(net.parameters(), lr=0.01)

all_logits = []
for epoch in range(30):
    logits = net(GG, inputs)
    all_logits.append(logits.detach())
    logp = F.log_softmax(logits, 1)
    # we only compute loss for node 0 and node 33
    loss = F.nll_loss(logp[labeled_nodes], labels)
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    print('Epoch %d | Loss: %.4f' % (epoch, loss.item()))

In [None]:
# Visualize the node classification using the logits output.
import numpy as np
import matplotlib.animation as animation
from IPython.display import HTML

fig = plt.figure(dpi=150)
fig.clf()
ax = fig.subplots()
nx_G = G.to_networkx()
def draw(i):
    cls1color = '#00FFFF'
    cls2color = '#FF00FF'
    pos = {}
    colors = []
    for v in range(34):
        pos[v] = all_logits[i][v].numpy()
        cls = np.argmax(pos[v])
        colors.append(cls1color if cls else cls2color)
    ax.cla()
    ax.axis('off')
    ax.set_title('Epoch: %d' % i)
    nx.draw(nx_G.to_undirected(), pos, node_color=colors, with_labels=True, node_size=500)

ani = animation.FuncAnimation(fig, draw, frames=len(all_logits), interval=200)
HTML(ani.to_html5_video())

### Summary
Writing GNN models with DGL's builtin message and reduce functions allows DGL to perform optimizations like fusing computation into one kernel.

However, many GNN models use carefully designed message and reduce functions that DGL's builtins cannot express. In that case, the general principle is to push as much computation as possible into the message function and the node apply function, which are usually perfectly parallelizable, and to keep the reduce function as simple as possible.
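The normalization constant $c_{ij}$ ignored in the first GCN code cell shows this principle at work. Assume the symmetric normalization of the original GCN paper, where the message between $i$ and $j$ is scaled by $1/c_{ij}$ with $c_{ij} = \sqrt{d_i d_j}$. The per-edge weight then factorizes into one node-wise scale applied before message passing and another applied after it. The aggregation itself stays a plain copy-and-sum, which remains fusable. The sketch below checks in plain PyTorch, on a hypothetical star graph, that both formulations agree. It assumes every node has at least one neighbor; GCN implementations commonly add self-loops to guarantee this.

```python
import torch

torch.manual_seed(0)
# Hypothetical undirected star graph (center 0, leaves 1-3), stored with both edge directions.
src = torch.tensor([0, 1, 0, 2, 0, 3])
dst = torch.tensor([1, 0, 2, 0, 3, 0])
num_nodes = 4
h = torch.randn(num_nodes, 5)

deg = torch.bincount(dst, minlength=num_nodes).float()  # degrees (graph is symmetric)
norm = deg.pow(-0.5)                                     # d^{-1/2}; requires deg > 0

# (a) Per-edge normalization: a weighted message is materialized on every edge.
edge_w = norm[src] * norm[dst]                           # 1 / c_ij = 1 / sqrt(d_i * d_j)
msgs = h[src] * edge_w.unsqueeze(1)
per_edge = torch.zeros_like(h).index_add_(0, dst, msgs)

# (b) Node-side normalization: scale, plain copy + sum via one SpMM, scale again.
a_t = torch.sparse_coo_tensor(torch.stack([dst, src]), torch.ones(src.numel()),
                              (num_nodes, num_nodes))
node_side = norm.unsqueeze(1) * torch.sparse.mm(a_t, norm.unsqueeze(1) * h)

assert torch.allclose(per_edge, node_side, atol=1e-6)
```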