<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#Cora-Dataset" data-toc-modified-id="Cora-Dataset-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>Cora Dataset</a></span></li><li><span><a href="#使用GraphConv定义GCN" data-toc-modified-id="使用GraphConv定义GCN-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>Defining a GCN with GraphConv</a></span></li><li><span><a href="#使用builtin定义GCN" data-toc-modified-id="使用builtin定义GCN-3"><span class="toc-item-num">3&nbsp;&nbsp;</span>Defining a GCN with builtin functions</a></span></li><li><span><a href="#模型训练" data-toc-modified-id="模型训练-4"><span class="toc-item-num">4&nbsp;&nbsp;</span>Model training</a></span></li></ul></div>

How to implement a GCN (Graph Convolutional Network) model with DGL and use it for node classification on the Cora dataset.

## Cora Dataset

In [1]:
import dgl.data.citation_graph as citegrh

Using backend: pytorch


In [2]:
from dgl import DGLGraph

In [3]:
import numpy as np

In [4]:
import torch as th

In [5]:
def load_cora_data():
    data = citegrh.load_cora()
    features = th.FloatTensor(data.features)
    labels = th.LongTensor(data.labels)
    train_mask = th.BoolTensor(data.train_mask)
    val_mask = th.BoolTensor(data.val_mask)
    test_mask = th.BoolTensor(data.test_mask)
    g = DGLGraph(data.graph)
    
    return g, features, labels, train_mask, val_mask, test_mask
    

In [6]:
g, features, labels, train_mask, val_mask, test_mask = load_cora_data()

In [7]:
num_nodes, feature_dims = features.shape
num_nodes, feature_dims

(2708, 1433)

In [33]:
np.unique(labels)

array([0, 1, 2, 3, 4, 5, 6])

In [8]:
num_classes = len(np.unique(labels))
num_classes

7

## Defining a GCN with GraphConv

DGL offers several ways to define a GCN: the predefined `GraphConv` class, the send/recv interface, and the builtin interface.

In [9]:
from dgl.nn import GraphConv

In [12]:
from torch import nn

In [13]:
import torch.nn.functional as F

In [15]:
class GCNGraphConv(nn.Module):
    def __init__(self, in_feats, hidden_size, num_classes):
        super(GCNGraphConv, self).__init__()
        self._conv1 = GraphConv(in_feats, hidden_size)
        self._conv2 = GraphConv(hidden_size, num_classes)
        
    def forward(self, g, features):
        h = self._conv1(g, features)
        h = F.relu(h)
        h = self._conv2(g, h)
        return h
        

In [16]:
hidden_size = 512
model_graph_conv = GCNGraphConv(feature_dims, hidden_size, num_classes)

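## Defining a GCN with builtin functions

DGL is designed around the idea that a graph neural network is a message-passing process: each node sends messages to its neighbors, receives messages from other nodes, aggregates them, and updates its own state. Sending corresponds to `message_func` in DGL, and aggregating corresponds to `reduce_func`. The `builtin` functions are `message_func` and `reduce_func` implementations that DGL packages for speed. The previous section defined the GCN with DGL's `GraphConv`; internally, `GraphConv` is built on the same idea.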
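
The message-passing idea can be sketched without DGL. The following is a minimal pure-Python illustration, not DGL's implementation: `copy_src_message`, `sum_reduce` and `update_all` are hypothetical helpers that only mimic the roles of `message_func` and `reduce_func`, and the toy edge list is an assumption made for the example.

```python
from collections import defaultdict

def copy_src_message(src_feat):
    """Message function: forward the source node's feature unchanged."""
    return list(src_feat)

def sum_reduce(messages, dim):
    """Reduce function: element-wise sum of all incoming messages."""
    total = [0.0] * dim
    for msg in messages:
        for k, value in enumerate(msg):
            total[k] += value
    return total

def update_all(edges, feats, message_func, reduce_func):
    """Send a message along every edge, then reduce per destination node."""
    dim = len(next(iter(feats.values())))
    mailbox = defaultdict(list)
    for src, dst in edges:
        mailbox[dst].append(message_func(feats[src]))
    # Nodes without incoming edges end up with an all-zero aggregate.
    return {node: reduce_func(mailbox[node], dim) for node in feats}

# Toy directed graph (hypothetical): 0 -> 1, 0 -> 2, 1 -> 2
edges = [(0, 1), (0, 2), (1, 2)]
feats = {0: [1.0, 0.0], 1: [0.0, 2.0], 2: [3.0, 3.0]}

aggregated = update_all(edges, feats, copy_src_message, sum_reduce)
print(aggregated)
```

Node 2 receives messages from nodes 0 and 1 and sums them, node 1 receives only node 0's feature, and node 0 has no incoming edges. A copy-source message combined with a sum reduce is the same pairing the layer in this section uses.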

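In a GCN, the state of each node is updated according to:
$$h_i^{(l+1)} = \sigma\Big(b^{(l)} + \sum_{j\in\mathcal{N}(i)}\frac{1}{c_{ij}}h_j^{(l)}W^{(l)}\Big) = 
\sigma\Big(b^{(l)} + \frac{1}{\sqrt{|\mathcal{N}(i)|}}\Big(\sum_{j\in\mathcal{N}(i)}\frac{1}{\sqrt{|\mathcal{N}(j)|}}h_j^{(l)}\Big)W^{(l)}\Big)$$

Here $\mathcal{N}(i)$ is the set of neighbors of node $i$, so $|\mathcal{N}(i)|$ is the number of nodes connected to $i$, and $c_{ij} = \sqrt{|\mathcal{N}(i)|}\sqrt{|\mathcal{N}(j)|}$ acts as a normalization constant. The weights $W^{(l)}$ are initialized from a Glorot uniform distribution, the bias $b^{(l)}$ is initialized to 0, and $\sigma$ is an activation function.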
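
The two forms of the update rule differ only in where the normalization is applied. The sketch below checks this numerically on a small undirected graph, leaving out $W^{(l)}$, $b^{(l)}$ and $\sigma$ because they are the same in both forms. The graph, the scalar features and the helper names (`edge_wise`, `pre_post_scaled`) are assumptions made for illustration.

```python
import math

# Toy undirected graph (hypothetical), stored as adjacency lists.
neighbors = {0: [1, 2], 1: [0], 2: [0, 3], 3: [2]}
h = {0: 1.0, 1: 2.0, 2: 3.0, 3: 4.0}  # one scalar feature per node

def deg(i):
    """|N(i)|: the number of neighbors of node i."""
    return len(neighbors[i])

def edge_wise(i):
    """sum_j h_j / c_ij with c_ij = sqrt(|N(i)|) * sqrt(|N(j)|)."""
    return sum(h[j] / (math.sqrt(deg(i)) * math.sqrt(deg(j)))
               for j in neighbors[i])

def pre_post_scaled(i):
    """Scale sources by 1/sqrt(|N(j)|), sum, then scale by 1/sqrt(|N(i)|)."""
    summed = sum(h[j] / math.sqrt(deg(j)) for j in neighbors[i])
    return summed / math.sqrt(deg(i))

for i in neighbors:
    # Both forms agree up to floating-point rounding.
    assert math.isclose(edge_wise(i), pre_post_scaled(i))
    print(i, round(edge_wise(i), 4))
```

The second form is the cheaper one to implement with message passing: features are scaled once per node before and after aggregation, so no per-edge coefficient has to be stored.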

In [48]:
from torch import nn

In [50]:
from torch.nn import init

In [51]:
import torch as th

In [13]:
import torch.nn.functional as F

In [58]:
import dgl.function as fn

In [104]:
class GraphConvMsgPass(nn.Module):
    """A graph convolutional layer whose weights are initialized from a
    Glorot uniform distribution and whose bias is initialized to zero.
    No activation function is applied."""
    def __init__(self, in_feats, out_feats):
        super(GraphConvMsgPass, self).__init__()
        self._in_feats = in_feats  # size of input features
        self._out_feats = out_feats  # size of output features
        
        # initialize weights with Glorot uniform distribution
        self._weights = nn.Parameter(th.Tensor(in_feats, out_feats))
        init.xavier_uniform_(self._weights)
        
        # initialize bias to zero
        self._bias = nn.Parameter(th.Tensor(out_feats))
        init.zeros_(self._bias)
    
    def forward(self, graph, features):
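        # out-degree of each node, |N(i)|, clamped to at least 1 to avoid dividing by zero
        degs = graph.out_degrees().float().clamp(1)
        # 1 / sqrt(|N(i)|), reshaped to (num_nodes, 1) so it broadcasts over feature columns
        norm = th.pow(degs, -0.5)
        norm = th.reshape(norm, degs.shape + (1,))
        # scale source features: h_j * (1 / sqrt(|N(j)|))
        features = features * norm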
        
        graph.srcdata['h'] = features
        graph.update_all(message_func=fn.copy_src(src='h', out='m'),
                         reduce_func=fn.sum(msg='m', out='h'))
        rst = graph.dstdata['h']
        
        # weight
        rst = th.matmul(rst, self._weights)
        
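        # normalization -> 1/sqrt(|N(i)|); this reuses the out-degree norm, which equals
        # the in-degree norm when every edge has a reverse edge (assumed here)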
        rst = rst * norm
        
        # bias
        rst = rst + self._bias
        
        return rst
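
For reference, `init.xavier_uniform_` samples each weight from $U(-a, a)$ with $a = \text{gain}\cdot\sqrt{6 / (\text{fan\_in} + \text{fan\_out})}$. The sketch below reproduces that bound in plain Python for the first layer's shape used in this notebook (1433 inputs, 512 hidden units). `glorot_uniform_bound` and `glorot_uniform` are hypothetical helper names, not PyTorch functions.

```python
import math
import random

def glorot_uniform_bound(fan_in, fan_out, gain=1.0):
    """Half-width a of the uniform range U(-a, a) used by Glorot init."""
    return gain * math.sqrt(6.0 / (fan_in + fan_out))

def glorot_uniform(rows, cols, rng):
    """Sample a rows x cols matrix (nested lists) from U(-a, a)."""
    bound = glorot_uniform_bound(rows, cols)
    return [[rng.uniform(-bound, bound) for _ in range(cols)]
            for _ in range(rows)]

# Bound for a (1433, 512) weight matrix, as in the first layer here.
bound = glorot_uniform_bound(1433, 512)
print(f"bound = {bound:.5f}")

# A small sample matrix, seeded so the draw is repeatable.
weights = glorot_uniform(4, 3, random.Random(0))
print(weights)
```

Because the bound depends only on the sum of the two dimensions, it does not matter which of them is treated as the input side.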
         

In [105]:
class GCNGraphMsgPass(nn.Module):
    def __init__(self, in_feats, hidden_size, num_classes):
        super(GCNGraphMsgPass, self).__init__()
        self._conv1 = GraphConvMsgPass(in_feats, hidden_size)
        self._conv2 = GraphConvMsgPass(hidden_size, num_classes)
        
    def forward(self, g, features):
        h = self._conv1(g, features)
        h = F.relu(h)
        h = self._conv2(g, h)
        return h
        

In [106]:
hidden_size = 512
model_graph_msg = GCNGraphMsgPass(feature_dims, hidden_size, num_classes)

## Model training

In [17]:
from torch.optim import Adam

In [18]:
import torch as th

In [44]:
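def validate(model, g, features, labels, mask):
    # Evaluation needs no gradients, so skip building the autograd graph.
    with th.no_grad():
        logits = model(g, features)
        logits = F.log_softmax(logits, dim=1)
    # Predicted class = index of the largest log-probability in each row.
    predict = th.argmax(logits, dim=1)
    acc = th.sum(labels[mask] == predict[mask])

    # Fraction of masked nodes that were classified correctly.
    return acc.item() / len(labels[mask])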
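
What `validate` computes can be illustrated on a few hand-written rows. This is a pure-Python sketch with made-up logits, labels and mask; `masked_accuracy` is a hypothetical helper. It also shows that taking `log_softmax` first does not change the predicted class, since it shifts every entry of a row by the same amount.

```python
import math

def log_softmax(row):
    """Numerically stable log-softmax of one row of logits."""
    m = max(row)
    log_z = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_z for x in row]

def argmax(row):
    """Index of the largest entry in a row."""
    return max(range(len(row)), key=row.__getitem__)

def masked_accuracy(logits, labels, mask):
    """Accuracy over the nodes whose mask entry is True."""
    selected = [i for i, keep in enumerate(mask) if keep]
    correct = sum(argmax(logits[i]) == labels[i] for i in selected)
    return correct / len(selected)

# Made-up data: 4 nodes, 3 classes; the last node is masked out.
logits = [[2.0, 0.5, 0.1], [0.2, 0.3, 1.5], [0.9, 1.1, 0.4], [0.1, 0.2, 0.3]]
labels = [0, 2, 0, 2]
mask = [True, True, True, False]

acc_raw = masked_accuracy(logits, labels, mask)
acc_log = masked_accuracy([log_softmax(r) for r in logits], labels, mask)
print(acc_raw, acc_log)
```

Two of the three masked nodes are predicted correctly, and both calls return the same value.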

In [46]:
def train_model(model, max_epochs=50, val_interval=5):
    optimizer = Adam(model.parameters(), lr=0.01)
    model.train()
    for epoch in range(max_epochs):
        logits = model(g, features)
        logits = F.log_softmax(logits, dim=1)
        loss = F.nll_loss(logits[train_mask], labels[train_mask])
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        print(f'+++ Epoch = {epoch}, loss = {loss.item():.4f}')
        
        if (epoch + 1) % val_interval == 0:
            acc = validate(model, g, features, labels, val_mask)
            print(f'+++ accuracy for validation dataset is {acc:.4f}')
            
    test_acc = validate(model, g, features, labels, test_mask)
    print(f'++++ accuracy for test dataset is {test_acc:.4f}')
            
    

In [107]:
train_model(model_graph_conv)

+++ Epoch = 0, loss = 0.0099
+++ Epoch = 1, loss = 0.0137
+++ Epoch = 2, loss = 0.0101
+++ Epoch = 3, loss = 0.0110
+++ Epoch = 4, loss = 0.0121
+++ accuracy for validation dataset is 0.7633
+++ Epoch = 5, loss = 0.0111
+++ Epoch = 6, loss = 0.0100
+++ Epoch = 7, loss = 0.0101
+++ Epoch = 8, loss = 0.0108
+++ Epoch = 9, loss = 0.0110
+++ accuracy for validation dataset is 0.7233
+++ Epoch = 10, loss = 0.0106
+++ Epoch = 11, loss = 0.0100
+++ Epoch = 12, loss = 0.0099
+++ Epoch = 13, loss = 0.0102
+++ Epoch = 14, loss = 0.0105
+++ accuracy for validation dataset is 0.7600
+++ Epoch = 15, loss = 0.0104
+++ Epoch = 16, loss = 0.0101
+++ Epoch = 17, loss = 0.0099
+++ Epoch = 18, loss = 0.0100
+++ Epoch = 19, loss = 0.0101
+++ accuracy for validation dataset is 0.7367
+++ Epoch = 20, loss = 0.0102
+++ Epoch = 21, loss = 0.0101
+++ Epoch = 22, loss = 0.0099
+++ Epoch = 23, loss = 0.0099
+++ Epoch = 24, loss = 0.0100
+++ accuracy for validation dataset is 0.7533
+++ Epoch = 25, loss = 0.0101

Note: the loss in this run already starts near 0.01. A freshly initialized 7-class model would start near $\ln 7 \approx 1.95$, as the `model_graph_msg` run below does, so `model_graph_conv` had most likely already been trained in an earlier execution of this cell.


In [108]:
train_model(model_graph_msg)

+++ Epoch = 0, loss = 1.9455
+++ Epoch = 1, loss = 1.8626
+++ Epoch = 2, loss = 1.7700
+++ Epoch = 3, loss = 1.6800
+++ Epoch = 4, loss = 1.6054
+++ accuracy for validation dataset is 0.4667
+++ Epoch = 5, loss = 1.5224
+++ Epoch = 6, loss = 1.4174
+++ Epoch = 7, loss = 1.3029
+++ Epoch = 8, loss = 1.1911
+++ Epoch = 9, loss = 1.0832
+++ accuracy for validation dataset is 0.7167
+++ Epoch = 10, loss = 0.9752
+++ Epoch = 11, loss = 0.8670
+++ Epoch = 12, loss = 0.7619
+++ Epoch = 13, loss = 0.6632
+++ Epoch = 14, loss = 0.5725
+++ accuracy for validation dataset is 0.8033
+++ Epoch = 15, loss = 0.4900
+++ Epoch = 16, loss = 0.4162
+++ Epoch = 17, loss = 0.3517
+++ Epoch = 18, loss = 0.2968
+++ Epoch = 19, loss = 0.2509
+++ accuracy for validation dataset is 0.8133
+++ Epoch = 20, loss = 0.2131
+++ Epoch = 21, loss = 0.1824
+++ Epoch = 22, loss = 0.1573
+++ Epoch = 23, loss = 0.1368
+++ Epoch = 24, loss = 0.1199
+++ accuracy for validation dataset is 0.8267
+++ Epoch = 25, loss = 0.1060
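
The epoch-0 loss of about 1.9455 in the `model_graph_msg` run is what an untrained 7-class classifier should produce. `log_softmax` followed by `nll_loss` is the cross-entropy loss, and for near-uniform predictions over 7 classes it equals $\ln 7$. The sketch below checks this in plain Python; the helper functions and the toy targets are assumptions for illustration, not the PyTorch implementations.

```python
import math

def log_softmax(row):
    """Numerically stable log-softmax of one row of logits."""
    m = max(row)
    log_z = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_z for x in row]

def nll_loss(log_probs, targets):
    """Mean negative log-likelihood of the target classes."""
    picked = [row[t] for row, t in zip(log_probs, targets)]
    return -sum(picked) / len(picked)

num_classes = 7
# All-zero logits give a uniform distribution over the classes.
uniform_logits = [[0.0] * num_classes for _ in range(3)]
targets = [0, 3, 6]

loss = nll_loss([log_softmax(r) for r in uniform_logits], targets)
print(f"{loss:.4f}")  # prints 1.9459, i.e. ln 7
```

A first-epoch loss far below this value is a sign that the model being trained is not freshly initialized.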
