# 写你自己的 GNN 模块

通常，模型并不是仅仅通过简单堆叠现有 GNN 模块实现的

In [1]:
import os
os.environ['DGLBACKEND'] = 'pytorch'

import dgl
import dgl.function as fn
import torch
import torch.nn as nn
import torch.nn.functional as F

# 信息传递

尽管可以直接使用 DGL 内建的模块 `dgl.nn.SAGEConv` 来实现 GraphSAGE，本教程仍然自行搭建这一卷积层，以演示消息传递的写法。

下面代码中的 `g.update_all` 用于收集并平均邻居节点的特征，它接收两个函数：
 - message 函数 `fn.copy_u('h', 'm')`：把源节点上名为 `h` 的特征复制为名为 `m` 的 *message*，并沿着边发送给目标节点；
 - reduce 函数 `fn.mean('m', 'h_N')`：每个节点对收到的所有名为 `m` 的 message 取平均，并将结果保存为名为 `h_N` 的新节点特征。

In [2]:
class SAGEConv(nn.Module):
    '''GraphSAGE convolution layer with mean aggregation.

    Each node's own feature is concatenated with the mean of the features
    sent by its in-neighbours, and the result goes through one linear layer.

    Parameters
    ----------
    in_feat: int
        Input feature size
    out_feat: int
        Output feature size
    '''
    def __init__(self, in_feat, out_feat):
        super().__init__()
        # The linear layer sees [self feature, aggregated feature], hence 2x.
        self.linear = nn.Linear(2 * in_feat, out_feat)

    def forward(self, g, h):
        '''Apply the layer to node features ``h`` on graph ``g``.

        Parameters
        ----------
        g: Graph
            The input graph.
        h: Tensor
            The input node feature

        Returns
        -------
        Tensor
            Output of the linear layer applied to ``cat([h, h_N], dim=1)``.
        '''
        # local_scope keeps the temporary 'h' / 'h_N' entries from
        # persisting on the caller's graph after the block exits.
        with g.local_scope():
            g.ndata['h'] = h
            # Copy source-node 'h' into message 'm', then average the
            # incoming messages into 'h_N' on each destination node.
            g.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h_N'))
            neighbour_mean = g.ndata['h_N']
            return self.linear(torch.cat((h, neighbour_mean), dim=1))

可以堆叠自己的 GraphSAGE 卷积层，从而形成多层的 GraphSAGE 网络

In [3]:
class Model(nn.Module):
    '''Two stacked ``SAGEConv`` layers with a ReLU in between.

    Parameters
    ----------
    in_feats: int
        Input feature size of the first layer
    h_feats: int
        Output size of the first layer and input size of the second
    num_classes: int
        Output size of the second layer
    '''
    def __init__(self, in_feats, h_feats, num_classes):
        super().__init__()
        self.conv1 = SAGEConv(in_feats, h_feats)
        self.conv2 = SAGEConv(h_feats, num_classes)

    def forward(self, g, in_feat):
        '''Run both layers on graph ``g`` starting from ``in_feat``.

        The second layer's output is returned as-is (no activation).
        '''
        hidden = F.relu(self.conv1(g, in_feat))
        return self.conv2(g, hidden)

## 训练

In [4]:
import dgl.data

# Load the Cora dataset through DGL's built-in dataset class.
dataset = dgl.data.CoraGraphDataset()
# Index 0 selects the graph object used for the rest of the notebook.
g = dataset[0]
# Bare expression: as the last statement of the cell, the notebook echoes
# the graph summary (node/edge counts and ndata schemes).
g

  NumNodes: 2708
  NumEdges: 10556
  NumFeats: 1433
  NumClasses: 7
  NumTrainingSamples: 140
  NumValidationSamples: 500
  NumTestSamples: 1000
Done loading data from cached files.


Graph(num_nodes=2708, num_edges=10556,
      ndata_schemes={'feat': Scheme(shape=(1433,), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int64), 'val_mask': Scheme(shape=(), dtype=torch.bool), 'test_mask': Scheme(shape=(), dtype=torch.bool), 'train_mask': Scheme(shape=(), dtype=torch.bool)}
      edata_schemes={})

In [5]:
def train(g, model, num_epochs=200, lr=0.01, log_every=5):
    '''Train ``model`` for node classification on ``g`` and print progress.

    Every epoch runs one forward pass over the whole graph, computes the
    cross-entropy loss on the training nodes only, and takes one Adam step.
    Accuracies are measured on the logits produced *before* that step.

    Parameters
    ----------
    g: Graph
        Graph whose ``ndata`` holds ``'feat'``, ``'label'``, ``'train_mask'``,
        ``'val_mask'`` and ``'test_mask'``.
    model: nn.Module
        Called as ``model(g, features)``; its output is used as per-node
        logits.
    num_epochs: int
        Number of optimization steps to run (default 200).
    lr: float
        Learning rate passed to Adam (default 0.01).
    log_every: int
        Print one progress line every ``log_every`` epochs (default 5).
        A falsy value disables printing.

    Returns
    -------
    None
        Results are only reported through ``print``.
    '''
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    best_val_acc = 0
    best_test_acc = 0

    features = g.ndata['feat']
    labels = g.ndata['label']
    train_mask = g.ndata['train_mask']
    val_mask = g.ndata['val_mask']
    test_mask = g.ndata['test_mask']

    for epoch in range(num_epochs):
        # forward
        logits = model(g, features)

        # compute prediction
        pred = logits.argmax(1)

        # compute loss on the training nodes only
        loss = F.cross_entropy(logits[train_mask], labels[train_mask])

        # compute accuracy on the validation and test splits
        val_acc = (pred[val_mask] == labels[val_mask]).float().mean()
        test_acc = (pred[test_mask] == labels[test_mask]).float().mean()

        # Model selection by validation accuracy: the reported "best" test
        # accuracy is the one observed at the best-validation epoch, not the
        # maximum test accuracy seen.
        if best_val_acc < val_acc:
            best_val_acc = val_acc
            best_test_acc = test_acc

        # backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if log_every and epoch % log_every == 0:
            print('In epoch {}, loss {:.3f}, val acc {:.3f} (best {:.3f}), test acc {:.3f} (best {:.3f})'\
                 .format(epoch, loss, val_acc, best_val_acc, test_acc, best_test_acc))

In [6]:
# Input size is the width of the node feature matrix, the hidden size is 16,
# and the output size is the dataset's number of classes.
model = Model(g.ndata['feat'].shape[1], 16, dataset.num_classes)
# Run the training loop; it prints a progress line periodically.
train(g, model)

In epoch 0, loss 1.952, val acc 0.122 (best 0.122), test acc 0.130 (best 0.130)
In epoch 5, loss 1.881, val acc 0.236 (best 0.236), test acc 0.243 (best 0.243)
In epoch 10, loss 1.742, val acc 0.356 (best 0.424), test acc 0.352 (best 0.409)
In epoch 15, loss 1.534, val acc 0.436 (best 0.436), test acc 0.431 (best 0.431)
In epoch 20, loss 1.269, val acc 0.558 (best 0.558), test acc 0.540 (best 0.540)
In epoch 25, loss 0.974, val acc 0.634 (best 0.634), test acc 0.623 (best 0.623)
In epoch 30, loss 0.691, val acc 0.674 (best 0.674), test acc 0.678 (best 0.678)
In epoch 35, loss 0.453, val acc 0.710 (best 0.710), test acc 0.711 (best 0.711)
In epoch 40, loss 0.280, val acc 0.740 (best 0.740), test acc 0.753 (best 0.753)
In epoch 45, loss 0.168, val acc 0.738 (best 0.744), test acc 0.749 (best 0.754)
In epoch 50, loss 0.102, val acc 0.748 (best 0.752), test acc 0.752 (best 0.751)
In epoch 55, loss 0.064, val acc 0.746 (best 0.752), test acc 0.760 (best 0.751)
In epoch 60, loss 0.043, val a

## 笔记

 - model 是通过 Model 定义的，三个输入参数依次为输入特征维度、隐藏层特征维度、输出类别数
 - train 函数中，logits 值为当前模型的输出值，是由 model(g, features) 获得的（其实这两个参数 g 和 features 均为常量，logits 在每轮训练中不同的原因是调整了 model 中各神经元的权重）
 - logits 的 shape 为（节点数 x 类别数），即每个节点都有 7 个相应的计算值

# 更多的自定义

In [8]:
class WeightedSAGEConv(nn.Module):
    '''GraphSAGE convolution layer whose messages are scaled by edge weights.

    Each message is the source node's feature multiplied by the weight of
    the edge it travels on; incoming messages are averaged, concatenated
    with the node's own feature, and passed through one linear layer.

    Parameters
    ----------
    in_feat: int
        Input feature size
    out_feat: int
        Output feature size
    '''

    def __init__(self, in_feat, out_feat):
        super().__init__()
        # The linear layer sees [self feature, aggregated feature], hence 2x.
        self.linear = nn.Linear(2 * in_feat, out_feat)

    def forward(self, g, h, w):
        '''Apply the layer to node features ``h`` with edge weights ``w``.

        Parameters
        ----------
        g: Graph
            The input graph.
        h: Tensor
            The input node feature.
        w: Tensor
            The edge weight.

        Returns
        -------
        Tensor
            Output of the linear layer applied to ``cat([h, h_N], dim=1)``.
        '''
        # local_scope keeps the temporary 'h' / 'w' / 'h_N' entries from
        # persisting on the caller's graph after the block exits.
        with g.local_scope():
            g.ndata['h'] = h
            g.edata['w'] = w
            # Message 'm' = source 'h' * edge 'w'; the mean of incoming
            # messages is stored as 'h_N' on each destination node.
            g.update_all(fn.u_mul_e('h', 'w', 'm'), fn.mean('m', 'h_N'))
            neighbour_mean = g.ndata['h_N']
            return self.linear(torch.cat((h, neighbour_mean), dim=1))

由于数据集中的图并没有边的权重，我们手动将所有边的权重设为 1（此时 `WeightedSAGEConv` 的计算结果与前面的 `SAGEConv` 等价）。

In [9]:
class Model(nn.Module):
    '''Two stacked ``WeightedSAGEConv`` layers using unit edge weights.

    Parameters
    ----------
    in_feats: int
        Input feature size of the first layer
    h_feats: int
        Output size of the first layer and input size of the second
    num_classes: int
        Output size of the second layer
    '''
    def __init__(self, in_feats, h_feats, num_classes):
        super(Model, self).__init__()
        self.conv1 = WeightedSAGEConv(in_feats, h_feats)
        self.conv2 = WeightedSAGEConv(h_feats, num_classes)

    def forward(self, g, in_feat):
        '''Run both layers on graph ``g`` starting from ``in_feat``.

        Every edge gets weight 1. The second layer's output is returned
        as-is (no activation).
        '''
        # Allocate the weights on the graph's device: a plain torch.ones(...)
        # always lands on the CPU and would clash with a graph on the GPU.
        # Both layers only read the tensor, so one allocation is shared.
        edge_weight = torch.ones(g.num_edges(), 1, device=g.device)
        h = self.conv1(g, in_feat, edge_weight)
        h = F.relu(h)
        h = self.conv2(g, h, edge_weight)
        return h

In [10]:
# Assuming the cells ran in order, `Model` now refers to the most recent
# definition (the WeightedSAGEConv-based one). The sizes are the same as
# before: feature width in, 16 hidden units, number of classes out.
model = Model(g.ndata['feat'].shape[1], 16, dataset.num_classes)
# Run the training loop; it prints a progress line periodically.
train(g, model)

In epoch 0, loss 1.954, val acc 0.114 (best 0.114), test acc 0.103 (best 0.103)
In epoch 5, loss 1.870, val acc 0.232 (best 0.232), test acc 0.236 (best 0.236)
In epoch 10, loss 1.718, val acc 0.514 (best 0.588), test acc 0.536 (best 0.577)
In epoch 15, loss 1.493, val acc 0.522 (best 0.588), test acc 0.518 (best 0.577)
In epoch 20, loss 1.201, val acc 0.560 (best 0.588), test acc 0.573 (best 0.577)
In epoch 25, loss 0.880, val acc 0.618 (best 0.618), test acc 0.627 (best 0.627)
In epoch 30, loss 0.584, val acc 0.658 (best 0.658), test acc 0.673 (best 0.673)
In epoch 35, loss 0.357, val acc 0.714 (best 0.714), test acc 0.719 (best 0.719)
In epoch 40, loss 0.209, val acc 0.736 (best 0.736), test acc 0.739 (best 0.739)
In epoch 45, loss 0.122, val acc 0.738 (best 0.738), test acc 0.741 (best 0.741)
In epoch 50, loss 0.074, val acc 0.742 (best 0.742), test acc 0.743 (best 0.743)
In epoch 55, loss 0.048, val acc 0.740 (best 0.742), test acc 0.746 (best 0.743)
In epoch 60, loss 0.033, val a