# Methods for Training GNN Models on Large Graphs

We have seen an example of training GNNs on an entire graph.  In practice, however, graphs are often very large: they may contain millions or billions of nodes and edges.  The storage required grows many times over once node and edge features are included.  If we want to use GPUs for faster computation, full-graph training is often impossible because the graph and its features cannot fit into a single GPU's memory.  On top of that, the node representations of intermediate layers must also be stored for backpropagation.

To overcome this limit, we employ two techniques:

1. Stochastic training on graphs.
2. Neighbor sampling on graphs.

## Stochastic Training on Graphs

If you are familiar with deep learning for images/texts/etc., you should know stochastic gradient descent (SGD) very well.  In SGD, you sample a minibatch of examples, compute the loss on those examples only, find the gradients, and update the model parameters.

Stochastic training on graphs resembles SGD on image/text datasets in that one also samples a minibatch of nodes (or pairs/tuples of nodes, depending on the task) and computes the loss on those nodes only.  The difference is that the output representation of a small set of nodes may depend on the input features of a substantially larger set of nodes.

### GraphSAGE Recap

In the previous session, we discussed the GraphSAGE model.

The output representation $\boldsymbol{y}_u$ of node $u$ from a GraphSAGE layer is simply computed by:

* Aggregating the input features of all neighbors of $u$, for instance by averaging.
* Concatenating the aggregation with node $u$'s own representation.
* Passing the concatenation to an MLP.
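
The three steps above can be sketched in plain Python for a single node.  This is only an illustration under simplifying assumptions: the MLP is reduced to a single linear map, and the names `mean_aggregate`, `sage_layer_single_node`, `weight`, and `bias` are hypothetical and not part of DGL.

```python
def mean_aggregate(neighbor_feats):
    """Average a list of equally sized feature vectors element-wise."""
    n = len(neighbor_feats)
    return [sum(col) / n for col in zip(*neighbor_feats)]


def sage_layer_single_node(self_feat, neighbor_feats, weight, bias):
    """Toy GraphSAGE update for one node (assumption: a single linear map).

    weight has shape (out_dim, 2 * in_dim); bias has shape (out_dim,).
    """
    # Step 1: aggregate the neighbors' input features by averaging.
    agg = mean_aggregate(neighbor_feats)
    # Step 2: concatenate the aggregation with the node's own representation.
    concat = agg + self_feat
    # Step 3: pass the concatenation through the linear map.
    return [sum(w * x for w, x in zip(row, concat)) + b
            for row, b in zip(weight, bias)]


# Node u has feature [1, 2]; its two neighbors have [3, 4] and [5, 8].
self_feat = [1.0, 2.0]
neighbor_feats = [[3.0, 4.0], [5.0, 8.0]]
# A hypothetical 1 x 4 weight matrix that simply sums the concatenated entries.
weight = [[1.0, 1.0, 1.0, 1.0]]
bias = [0.0]

y_u = sage_layer_single_node(self_feat, neighbor_feats, weight, bias)
print(y_u)
```

In a real model the linear map is learned and usually followed by a nonlinearity; the sketch only fixes the order of the three operations.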

We also defined the following `GraphSAGEModel`. It leverages DGL's built-in class `SAGEConv` and runs a forward pass on a whole graph.

In [1]:
import dgl
import dgl.function as fn
import dgl.nn as dglnn
import torch
from torch import nn
import torch.nn.functional as F

In [2]:
from dgl.nn.pytorch import conv as dgl_conv

class GraphSAGEModel(nn.Module):
    def __init__(self,
                 in_feats,
                 n_hidden,
                 out_dim,
                 n_layers,
                 activation,
                 dropout,
                 aggregator_type):
        super(GraphSAGEModel, self).__init__()
        self.layers = nn.ModuleList()

        # input layer
        self.layers.append(dgl_conv.SAGEConv(in_feats, n_hidden, aggregator_type,
                                         feat_drop=dropout, activation=activation))
        # hidden layers
        for i in range(n_layers - 1):
            self.layers.append(dgl_conv.SAGEConv(n_hidden, n_hidden, aggregator_type,
                                             feat_drop=dropout, activation=activation))
        # output layer
        self.layers.append(dgl_conv.SAGEConv(n_hidden, out_dim, aggregator_type,
                                         feat_drop=dropout, activation=None))

    def forward(self, g, features):
        h = features
        for layer in self.layers:
            h = layer(g, h)
        return h

### Batching on a Graph

For stochastic training, we want to split the training data into small batches and put only the necessary information onto the GPU for each training step. In the case of node classification, we want to split the labeled nodes into batches. Let us take a close look at what information is necessary for a batch of nodes.

For instance, consider the following graph:

![Graph](assets/graph.png)

In [3]:
# A small graph

import networkx as nx

example_graph = nx.Graph(
    [(0, 2), (0, 4), (0, 6), (0, 7), (0, 8), (0, 9), (0, 10),
     (1, 2), (1, 3), (1, 5), (2, 3), (2, 4), (2, 6), (3, 5),
     (3, 8), (4, 7), (8, 9), (8, 11), (9, 10), (9, 11)])
example_graph = dgl.graph(example_graph)
# We also assign features for each node
INPUT_FEATURES = 5
OUTPUT_FEATURES = 6
example_graph.ndata['features'] = torch.randn(12, INPUT_FEATURES)

If we wish to compute the output representations of nodes 4 and 6 with a GraphSAGE layer, we need the input features of nodes 4 and 6 themselves, as well as those of their neighbors (nodes 7, 0, and 2):

![Graph](assets/graph_1layer_46.png)

We can see that nodes 7, 0, and 2 contribute to the representation of node 4, while nodes 0 and 2 contribute to that of node 6.
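
The dependency described above can be checked without any graph library.  The following is a minimal plain-Python sketch that rebuilds the small example graph from its edge list; the helper names `build_adjacency` and `one_layer_inputs` are hypothetical.

```python
# Edge list of the small example graph above (undirected).
EDGES = [(0, 2), (0, 4), (0, 6), (0, 7), (0, 8), (0, 9), (0, 10),
         (1, 2), (1, 3), (1, 5), (2, 3), (2, 4), (2, 6), (3, 5),
         (3, 8), (4, 7), (8, 9), (8, 11), (9, 10), (9, 11)]


def build_adjacency(edges):
    """Map each node to the sorted list of its neighbors."""
    adj = {}
    for u, v in edges:
        adj.setdefault(u, set()).add(v)
        adj.setdefault(v, set()).add(u)
    return {node: sorted(nbrs) for node, nbrs in adj.items()}


def one_layer_inputs(adj, output_nodes):
    """Nodes whose input features a single layer needs for output_nodes."""
    needed = set(output_nodes)      # the output nodes themselves
    for node in output_nodes:
        needed.update(adj[node])    # plus all of their neighbors
    return sorted(needed)


adj = build_adjacency(EDGES)
print('Neighbors of node 4:', adj[4])
print('Neighbors of node 6:', adj[6])
print('Input nodes for outputs 4 and 6:', one_layer_inputs(adj, [4, 6]))
```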

### Finding Neighbors of Nodes

DGL provides an API, `dgl.in_subgraph`, that takes in a set of nodes and returns a graph consisting of all edges going into any of the given nodes.  Such a graph describes exactly the computation dependency above.

In [4]:
sampled_node_batch = torch.LongTensor([4, 6])   # These are the nodes whose outputs are to be computed
sampled_graph = dgl.in_subgraph(example_graph, sampled_node_batch)
print(sampled_graph.all_edges())

(tensor([0, 2, 7, 0, 2]), tensor([4, 4, 4, 6, 6]))


The result above reads as follows: nodes 0, 2, and 7 connect to node 4, while nodes 0 and 2 connect to node 6.

#### Subgraph to Blocks

Nodes in such a subgraph can be separated into two roles:
* The input nodes, which are the neighbors of the nodes whose outputs are to be computed.
* The output nodes, which are the nodes whose outputs are to be computed.

In later propagation, the logic of information flow differs between nodes with different roles.
DGL further provides a bipartite structure called a *block* to better reflect this. A subgraph can be easily converted to a block with the function `dgl.to_block`.

In [5]:
sampled_block = dgl.to_block(sampled_graph, sampled_node_batch)

def print_block_info(sampled_block):
    sampled_input_nodes = sampled_block.srcdata[dgl.NID]
    print('Node ID of input nodes in original graph:', sampled_input_nodes)

    sampled_output_nodes = sampled_block.dstdata[dgl.NID]
    print('Node ID of output nodes in original graph:', sampled_output_nodes)

    sampled_block_edges_src, sampled_block_edges_dst = sampled_block.all_edges()
    # We need to map the src and dst node IDs in the blocks to the node IDs in the original graph.
    sampled_block_edges_src_mapped = sampled_input_nodes[sampled_block_edges_src]
    sampled_block_edges_dst_mapped = sampled_output_nodes[sampled_block_edges_dst]
    print('Edge connections:', sampled_block_edges_src_mapped, sampled_block_edges_dst_mapped)
    
print_block_info(sampled_block)

Node ID of input nodes in original graph: tensor([4, 6, 0, 2, 7])
Node ID of output nodes in original graph: tensor([4, 6])
Edge connections: tensor([0, 2, 7, 0, 2]) tensor([4, 4, 4, 6, 6])


We can see that the input nodes also include nodes 4 and 6, which are the output nodes themselves. The edge connections are preserved (i.e., they map to the same ones in `sampled_graph`).
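
To make the relabeling concrete, here is a plain-Python sketch of the idea behind converting an edge list into a block.  It is an illustration only, not DGL's implementation: the function `to_block_sketch` and its return format are assumptions chosen so that the result mirrors the printed output above.

```python
def to_block_sketch(edges, dst_nodes):
    """Relabel a (src, dst) edge list into a bipartite 'block'.

    Returns (src_ids, dst_ids, local_edges).  src_ids and dst_ids hold
    original node IDs; local_edges holds positions into those two lists.
    """
    dst_ids = list(dst_nodes)
    # Output nodes are placed first on the input side as well.
    src_ids = list(dst_nodes)
    src_pos = {nid: i for i, nid in enumerate(src_ids)}
    for src, _ in edges:
        if src not in src_pos:
            src_pos[src] = len(src_ids)
            src_ids.append(src)
    dst_pos = {nid: i for i, nid in enumerate(dst_ids)}
    local_edges = [(src_pos[s], dst_pos[d]) for s, d in edges]
    return src_ids, dst_ids, local_edges


# Incoming edges of nodes 4 and 6, written as (source, destination).
in_edges = [(0, 4), (2, 4), (7, 4), (0, 6), (2, 6)]
src_ids, dst_ids, local_edges = to_block_sketch(in_edges, [4, 6])
print('Input nodes :', src_ids)
print('Output nodes:', dst_ids)
print('Local edges :', local_edges)

# Mapping the local IDs back recovers the original edges.
recovered = [(src_ids[s], dst_ids[d]) for s, d in local_edges]
print('Recovered   :', recovered)
```

Keeping the output nodes at the front of the input list is what later allows the output-side features to be obtained by slicing the input-side features.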

#### GraphSAGE Layer on Blocks

The sampled block is essentially a bipartite graph. We saw in the previous example that DGL's built-in class `SAGEConv` works on a whole graph. Does it also work properly on a *block*? The answer is yes. Actually, all of DGL's neural network layers support both homogeneous graphs and bipartite graphs.

In [6]:
import dgl.nn as dglnn
sageconv_module = dglnn.SAGEConv(INPUT_FEATURES, OUTPUT_FEATURES, 'mean', activation=F.relu)

sampled_block_src_features = example_graph.ndata['features'][sampled_block.srcdata[dgl.NID]]
sampled_block_dst_features = example_graph.ndata['features'][sampled_block.dstdata[dgl.NID]]

output_of_sampled_node_batch = sageconv_module(
    sampled_block, (sampled_block_src_features, sampled_block_dst_features))
print(output_of_sampled_node_batch)

tensor([[1.0929, 0.9714, 1.6120, 0.7548, 1.2298, 0.0000],
        [2.7942, 1.0297, 0.4052, 2.9697, 0.0000, 0.0000]],
       grad_fn=<ReluBackward0>)


### Multiple Layers

Now we wish to compute the output of nodes 4 and 6 from a 2-layer GraphSAGE.  This requires the input features of not only the nodes themselves and their neighbors, but also the neighbors of these neighbors.

![](assets/graph_2layer_46.png)

To compute the 2-layer output of nodes 4 and 6, we first need the 1-layer output of nodes 4 and 6 as well as that of their neighbors (nodes 7, 0, and 2).  To obtain the 1-layer output of all these nodes, we in turn need the input features of these nodes (nodes 4, 6, 7, 0, and 2) as well as those of *their* neighbors (nodes 10, 9, 8, 1, and 3).  This is one reason why `dgl.to_block` also includes the output nodes among the input nodes.

We can see that the generation of computation dependency for multi-layer GNNs is a bottom-up process: we start from the output layer and grow the node set towards the input layer.
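
The bottom-up expansion can be traced in plain Python on the small example graph.  This is a sketch under the assumption that every layer uses all neighbors; `build_adjacency` and `dependency_per_layer` are hypothetical helper names.

```python
# Edge list of the small example graph (undirected).
EDGES = [(0, 2), (0, 4), (0, 6), (0, 7), (0, 8), (0, 9), (0, 10),
         (1, 2), (1, 3), (1, 5), (2, 3), (2, 4), (2, 6), (3, 5),
         (3, 8), (4, 7), (8, 9), (8, 11), (9, 10), (9, 11)]


def build_adjacency(edges):
    """Map each node to the set of its neighbors."""
    adj = {}
    for u, v in edges:
        adj.setdefault(u, set()).add(v)
        adj.setdefault(v, set()).add(u)
    return adj


def dependency_per_layer(adj, seeds, num_layers):
    """Return the input node set of every layer, input layer first."""
    layers = []
    frontier = set(seeds)
    for _ in range(num_layers):
        # Inputs of this layer: the current nodes plus all their neighbors.
        inputs = set(frontier)
        for node in frontier:
            inputs |= adj[node]
        # Dependencies are generated bottom-up, so prepend each new layer.
        layers.insert(0, sorted(inputs))
        frontier = inputs
    return layers


adj = build_adjacency(EDGES)
layers = dependency_per_layer(adj, [4, 6], 2)
for depth, nodes in enumerate(layers, start=1):
    print('Layer', depth, 'needs', len(nodes), 'input nodes:', nodes)
```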

The following code returns the list of blocks that describes the computation dependency for a multi-layer GNN.

In [7]:
class FullNeighborBlockSampler(object):
    def __init__(self, g, num_layers):
        self.g = g
        self.num_layers = num_layers
        
    def sample(self, seeds):
        blocks = []
        for i in range(self.num_layers):
            sampled_graph = dgl.in_subgraph(self.g, seeds)
            sampled_block = dgl.to_block(sampled_graph, seeds)
            seeds = sampled_block.srcdata[dgl.NID]
            # Because the computation dependency is generated bottom-up, we prepend the new block instead of
            # appending it.
            blocks.insert(0, sampled_block)
            
        return blocks

In [8]:
block_sampler = FullNeighborBlockSampler(example_graph, 2)
sampled_blocks = block_sampler.sample(sampled_node_batch)

print('Block for first layer')
print('---------------------')
print_block_info(sampled_blocks[0])
print()
print('Block for second layer')
print('----------------------')
print_block_info(sampled_blocks[1])

Block for first layer
---------------------
Node ID of input nodes in original graph: tensor([ 4,  6,  0,  2,  7,  8,  9, 10,  1,  3])
Node ID of output nodes in original graph: tensor([4, 6, 0, 2, 7])
Edge connections: tensor([ 0,  2,  7,  0,  2,  2,  4,  6,  7,  8,  9, 10,  0,  4,  6,  1,  3,  0,
         4]) tensor([4, 4, 4, 6, 6, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 2, 7, 7])

Block for second layer
----------------------
Node ID of input nodes in original graph: tensor([4, 6, 0, 2, 7])
Node ID of output nodes in original graph: tensor([4, 6])
Edge connections: tensor([0, 2, 7, 0, 2]) tensor([4, 4, 4, 6, 6])


Message propagation, in contrast to computation dependency generation, is a top-down process: we start from the input layer and compute the representations towards the output layer.

Now we modify our `GraphSAGEModel` so that it can run a forward pass on blocks.

In [9]:
class BatchGraphSAGEModel(nn.Module):
    def __init__(self, n_layers, in_feats, out_feats, hidden_feats=None):
        super().__init__()
        self.convs = nn.ModuleList()
        
        if hidden_feats is None:
            hidden_feats = out_feats
        
        if n_layers == 1:
            self.convs.append(dglnn.SAGEConv(in_feats, out_feats, 'mean'))
        else:
            self.convs.append(dglnn.SAGEConv(in_feats, hidden_feats, 'mean', activation=F.relu))
            for i in range(n_layers - 2):
                self.convs.append(dglnn.SAGEConv(hidden_feats, hidden_feats, 'mean', activation=F.relu))
            self.convs.append(dglnn.SAGEConv(hidden_feats, out_feats, 'mean'))
        
    def forward(self, blocks, input_features):
        """
        blocks : List of blocks generated by block sampler.
        input_features : Input features of the first block.
        """
        h = input_features
        for layer, block in zip(self.convs, blocks):
            h = self.propagate(block, h, layer)
        return h
    
    def propagate(self, block, h, layer):
        # GraphSAGE requires not only the features of the neighbors, but also the features
        # of the output nodes themselves on the current layer, so we copy the output node features
        # from the input side to the output side ourselves to make GraphSAGE work correctly.
        # The output nodes of a block are guaranteed to appear first among the input nodes, so we can
        # conveniently slice them off like this:
        h_dst = h[:block.number_of_dst_nodes()]
        h = layer(block, (h, h_dst))
        return h

In [10]:
model = BatchGraphSAGEModel(2, INPUT_FEATURES, OUTPUT_FEATURES)

# The input nodes for computing 2-layer GraphSAGE output on the given output nodes can be obtained like this:
sampled_input_nodes = sampled_blocks[0].srcdata[dgl.NID]

# Get the input features.
# In real life we want to copy this to GPU.  But in this hands-on tutorial we don't have GPUs.
sampled_input_features = example_graph.ndata['features'][sampled_input_nodes]

output_of_sampled_node_batch = model(sampled_blocks, sampled_input_features)
print(output_of_sampled_node_batch)

tensor([[ 2.3278,  3.2866, -2.6597,  1.5736,  2.4844,  1.8405],
        [ 1.1273,  1.4332, -1.0207,  2.3226,  0.3350, -1.2296]],
       grad_fn=<AddBackward0>)


## Neighborhood Sampling

Notice in the above example that the 2-hop neighborhood already covers almost the entire graph (10 of the 12 nodes).  In real-world graphs, whose node degrees often follow a power-law distribution (i.e., there exist a few "hub" nodes with lots of edges), we often observe that even for a small set of output nodes of a multi-layer GNN, the input nodes still cover a large part of the graph.  The goal of saving GPU memory is thus defeated again in this setting.

Neighborhood sampling offers a solution by *not* taking all neighbors for every node during computation dependency generation.  Instead, we pick a small subset of neighbors and estimate the aggregation of all neighbors from this subset.  This trick often not only reduces memory consumption, but also improves model generalization.
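
As a rough illustration of the idea, the following plain-Python sketch keeps at most `fanout` uniformly chosen neighbors per node on the small example graph.  It is not DGL's sampler: the function `sample_neighbors_sketch`, the fixed seed, and the choice of sampling without replacement are assumptions made for the example.

```python
import random

# Edge list of the small example graph (undirected).
EDGES = [(0, 2), (0, 4), (0, 6), (0, 7), (0, 8), (0, 9), (0, 10),
         (1, 2), (1, 3), (1, 5), (2, 3), (2, 4), (2, 6), (3, 5),
         (3, 8), (4, 7), (8, 9), (8, 11), (9, 10), (9, 11)]


def build_adjacency(edges):
    """Map each node to the sorted list of its neighbors."""
    adj = {}
    for u, v in edges:
        adj.setdefault(u, set()).add(v)
        adj.setdefault(v, set()).add(u)
    return {node: sorted(nbrs) for node, nbrs in adj.items()}


def sample_neighbors_sketch(adj, seeds, fanout, rng):
    """Pick at most `fanout` neighbors per seed, uniformly, without replacement."""
    sampled = {}
    for node in seeds:
        nbrs = adj[node]
        if len(nbrs) <= fanout:
            sampled[node] = list(nbrs)      # few neighbors: keep them all
        else:
            sampled[node] = rng.sample(nbrs, fanout)
    return sampled


rng = random.Random(0)  # fixed seed so that reruns give the same sample
adj = build_adjacency(EDGES)
sampled = sample_neighbors_sketch(adj, [4, 6, 0], fanout=2, rng=rng)
for node, nbrs in sampled.items():
    print('node', node, 'keeps', nbrs, 'out of', adj[node])
```

Node 0 is a small "hub" with seven neighbors, so it is the node for which sampling saves the most; node 6 has only two neighbors and keeps both.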

DGL provides a function `dgl.sampling.sample_neighbors` for uniformly sampling a fixed number of neighbors for each node.  One can also replace `dgl.sampling.sample_neighbors` with any other neighborhood sampling algorithm (including your own).

In [11]:
class NeighborSampler(object):
    def __init__(self, g, num_fanouts):
        """
        num_fanouts : list of fanouts on each layer.
        """
        self.g = g
        self.num_fanouts = num_fanouts
        
    def sample(self, seeds):
        seeds = torch.LongTensor(seeds)
        blocks = []
        for fanout in reversed(self.num_fanouts):
            # We simply switch from in_subgraph to sample_neighbors for neighbor sampling.
            sampled_graph = dgl.sampling.sample_neighbors(self.g, seeds, fanout)
            
            sampled_block = dgl.to_block(sampled_graph, seeds)
            seeds = sampled_block.srcdata[dgl.NID]
            # Because the computation dependency is generated bottom-up, we prepend the new block instead of
            # appending it.
            blocks.insert(0, sampled_block)
            
        return blocks

In [12]:
block_sampler = NeighborSampler(example_graph, [2, 2])
sampled_blocks = block_sampler.sample(sampled_node_batch)

print('Block for first layer')
print('---------------------')
print_block_info(sampled_blocks[0])
print()
print('Block for second layer')
print('----------------------')
print_block_info(sampled_blocks[1])

Block for first layer
---------------------
Node ID of input nodes in original graph: tensor([4, 6, 7, 2, 0, 3, 9])
Node ID of output nodes in original graph: tensor([4, 6, 7, 2, 0])
Edge connections: tensor([0, 7, 0, 2, 0, 4, 3, 6, 2, 9]) tensor([4, 4, 6, 6, 7, 7, 2, 2, 0, 0])

Block for second layer
----------------------
Node ID of input nodes in original graph: tensor([4, 6, 7, 2, 0])
Node ID of output nodes in original graph: tensor([4, 6])
Edge connections: tensor([7, 2, 0, 2]) tensor([4, 4, 6, 6])


We can see that each output node now has at most 2 neighbors.

The code for message passing on blocks generated with neighborhood sampling does not change at all.

In [13]:
sagenet = BatchGraphSAGEModel(2, INPUT_FEATURES, OUTPUT_FEATURES)

# The input nodes for computing 2-layer GraphSAGE output on the given output nodes can be obtained like this:
sampled_input_nodes = sampled_blocks[0].srcdata[dgl.NID]

# Get the input features.
# In real life we want to copy this to GPU.  But in this hands-on tutorial we don't have GPUs.
sampled_input_features = example_graph.ndata['features'][sampled_input_nodes]

output_of_sampled_node_batch = sagenet(sampled_blocks, sampled_input_features)
print(output_of_sampled_node_batch)

tensor([[-2.0956,  1.7314, -1.4967, -0.1753,  4.9156, -3.3693],
        [-0.5386,  2.7805, -0.1162,  0.6231,  3.9517, -2.7087]],
       grad_fn=<AddBackward0>)


### Inference with Models Trained with Neighbor Sampling

Recall that modules such as dropout or batch normalization behave differently in training and inference.  The reason is that we do not wish to introduce any randomness during inference or model deployment.  Similarly, we do not want to sample neighbors during inference; aggregation should be performed on all neighbors to eliminate randomness.  However, directly using the multi-layer `FullNeighborBlockSampler` would still cost a lot of memory even during inference, due to the large number of input nodes being covered.

The solution is to compute the representations of all nodes one layer at a time.  To be more specific, for a multi-layer GraphSAGE model, we first compute the representations of all nodes on the 1st GraphSAGE layer, using a 1-layer `FullNeighborBlockSampler` to take all neighbors into account.  Such representations are computed in minibatches.  After all the representations from the 1st GraphSAGE layer are computed, we start from there and compute the representations of all nodes on the 2nd GraphSAGE layer.  We repeat the process until we reach the last layer.
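
The layer-at-a-time scheme can be illustrated without DGL.  In the plain-Python sketch below, `toy_layer` is a hypothetical stand-in for a GNN layer (it averages a node's value with the mean of its neighbors), and features are scalars for brevity.  The point is only that the batch size affects how much is processed at once, not the result.

```python
# Edge list of the small example graph (undirected).
EDGES = [(0, 2), (0, 4), (0, 6), (0, 7), (0, 8), (0, 9), (0, 10),
         (1, 2), (1, 3), (1, 5), (2, 3), (2, 4), (2, 6), (3, 5),
         (3, 8), (4, 7), (8, 9), (8, 11), (9, 10), (9, 11)]


def build_adjacency(edges):
    """Map each node to the sorted list of its neighbors."""
    adj = {}
    for u, v in edges:
        adj.setdefault(u, set()).add(v)
        adj.setdefault(v, set()).add(u)
    return {node: sorted(nbrs) for node, nbrs in adj.items()}


def toy_layer(self_value, neighbor_values):
    """Hypothetical layer: average of the node and its neighbor mean."""
    neighbor_mean = sum(neighbor_values) / len(neighbor_values)
    return 0.5 * (self_value + neighbor_mean)


def layerwise_inference(adj, features, num_layers, batch_size):
    """Compute all node representations one layer at a time, in batches."""
    num_nodes = len(features)
    h = list(features)
    for _ in range(num_layers):             # outer loop: layers
        new_h = []
        for start in range(0, num_nodes, batch_size):   # inner loop: batches
            stop = min(start + batch_size, num_nodes)
            for node in range(start, stop):
                # Only this node and its full neighborhood are read from h.
                new_h.append(toy_layer(h[node], [h[v] for v in adj[node]]))
        # Every node of this layer is done; the result feeds the next layer.
        h = new_h
    return h


adj = build_adjacency(EDGES)
features = [float(i) for i in range(12)]
out_small = layerwise_inference(adj, features, num_layers=2, batch_size=2)
out_full = layerwise_inference(adj, features, num_layers=2, batch_size=12)
print('Same result for both batch sizes:', out_small == out_full)
```

Each batch only touches its own nodes and their direct neighbors in the previous layer's output, which is why the memory needed per step stays bounded regardless of the number of layers.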

In [14]:

def inference_with_sagenet(sagenet, graph, input_features, batch_size):
    block_sampler = FullNeighborBlockSampler(graph, 1)
    h = input_features
    
    with torch.no_grad():
        # We are computing all representations of one layer at a time.
        # The outer loop iterates over GNN layers.
        for conv in sagenet.convs:
            new_h_list = []
            node_ids = torch.arange(graph.number_of_nodes())
            # The inner loop iterates over batch of nodes.
            for batch_start in range(0, graph.number_of_nodes(), batch_size):
                # Sample a block with full neighbors of the current node batch
                block = block_sampler.sample(node_ids[batch_start:batch_start+batch_size])[0]
                # Get the necessary input node IDs for this node batch on this layer
                input_node_ids = block.srcdata[dgl.NID]
                # Get the input features
                h_input = h[input_node_ids]
                # Compute the output of this node batch on this layer
                new_h = sagenet.propagate(block, h_input, conv)
                new_h_list.append(new_h)
            # We finished computing all representations on this layer.  We need to compute the
            # representations of next layer.
            h = torch.cat(new_h_list)
        
    return h

In [31]:
print(inference_with_sagenet(sagenet, example_graph, example_graph.ndata['features'], 2))

tensor([[-0.3472, -0.3333,  3.1190, -2.0490, -0.0394,  0.1479],
        [ 0.0063, -0.7284,  1.4344,  0.3252, -4.1741, -3.4755],
        [-1.1008,  2.5305,  2.4708,  0.8450, -0.2665, -0.4709],
        [-2.5998, -2.6475,  2.6640, -2.3707, -1.7509, -4.0879],
        [-1.6154,  4.0293,  3.1417,  2.7561,  0.9981,  0.2866],
        [-0.4941,  1.3706,  1.5416,  2.2070, -3.7743, -4.5323],
        [-1.5748,  4.4402,  3.5040,  3.0403,  1.9333, -0.6245],
        [ 1.8063,  2.3488,  4.2769, -0.1978, -0.1918,  0.0698],
        [-0.2144,  1.8032,  1.7731,  2.6832, -2.3726, -3.6091],
        [ 1.0633,  0.5951,  1.2866,  0.3168, -1.1016, -0.3529],
        [ 0.2300,  1.2489,  1.7450,  2.1502, -0.3055, -0.1434],
        [-1.4086, -1.1529,  1.1392, -0.5382, -0.6736, -0.6523]])


## Putting It Together

Now let's see how we could apply stochastic training to a node classification task.  We take the PubMed dataset as an example.

### Load Dataset

In [15]:
import dgl.data

dataset = dgl.data.citation_graph.load_pubmed()

# Set features and labels for each node
graph = dgl.graph(dataset.graph)
graph.ndata['features'] = torch.FloatTensor(dataset.features)
graph.ndata['labels'] = torch.LongTensor(dataset.labels)

# Find the node IDs in the training, validation, and test set.
train_nid = dataset.train_mask.nonzero()[0]
val_nid = dataset.val_mask.nonzero()[0]
test_nid = dataset.test_mask.nonzero()[0]

Downloading /Users/liuxuefe/.dgl/pubmed.zip from https://data.dgl.ai/dataset/pubmed.zip...
Extracting file to /Users/liuxuefe/.dgl/pubmed
Finished data loading and preprocessing.
  NumNodes: 19717
  NumEdges: 88651
  NumFeats: 500
  NumClasses: 3
  NumTrainingSamples: 60
  NumValidationSamples: 500
  NumTestSamples: 1000


### Define Neighbor Sampler

We can reuse our neighbor sampler code above.

In [16]:
neighbor_sampler = NeighborSampler(graph, [10, 10])

### Define DataLoader

PyTorch generates minibatches with a `DataLoader` object.  We can use it here as well.

Note that to compute the output of a minibatch of nodes, we need a list of blocks as described above.  Therefore, we need to change the `collate_fn` argument, which defines how individual examples are composed into a minibatch.
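
The role of `collate_fn` can be pictured with a plain-Python stand-in for the data loader.  The sketch below is not PyTorch's implementation: `iterate_minibatches` and `toy_collate` are hypothetical names, and `toy_collate` merely takes the place of a sampler's `sample` method.

```python
import random


def iterate_minibatches(node_ids, batch_size, collate_fn, rng):
    """Shuffle node IDs, cut them into batches, and collate each batch."""
    ids = list(node_ids)
    rng.shuffle(ids)
    for start in range(0, len(ids), batch_size):
        # collate_fn receives the raw list of node IDs in this batch and
        # decides what a "minibatch" actually is.
        yield collate_fn(ids[start:start + batch_size])


def toy_collate(seed_batch):
    """Hypothetical stand-in for a block sampler's `sample` method."""
    return {'seeds': seed_batch, 'num_seeds': len(seed_batch)}


rng = random.Random(0)
batches = list(iterate_minibatches(range(12), 5, toy_collate, rng))
for batch in batches:
    print(batch)
```

Passing the sampler's method as `collate_fn` means that each item produced by the loader is already the sampled computation dependency for that batch of seed nodes, rather than a plain tensor of node IDs.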

In [17]:
import torch.utils.data

BATCH_SIZE = 5

train_dataloader = torch.utils.data.DataLoader(
    train_nid, batch_size=BATCH_SIZE, collate_fn=neighbor_sampler.sample, shuffle=True)

### Define Model and Optimizer

In [19]:
HIDDEN_FEATURES = 10
model = BatchGraphSAGEModel(2, dataset.features.shape[1], dataset.num_labels, HIDDEN_FEATURES)

opt = torch.optim.Adam(model.parameters(), lr=1e-3)

### Evaluation

In [20]:
def compute_accuracy(pred, labels):
    return (pred.argmax(1) == labels).float().mean().item()

### Training Loop

In [21]:
NUM_EPOCHS = 50
EVAL_BATCH_SIZE = 1000
for epoch in range(NUM_EPOCHS):
    model.train()
    for blocks in train_dataloader:
        input_nodes = blocks[0].srcdata[dgl.NID]
        output_nodes = blocks[-1].dstdata[dgl.NID]
        
        input_features = graph.ndata['features'][input_nodes]
        output_labels = graph.ndata['labels'][output_nodes]
        
        output_predictions = model(blocks, input_features)
        loss = F.cross_entropy(output_predictions, output_labels)
        opt.zero_grad()
        loss.backward()
        opt.step()
        
    model.eval()
    all_predictions = inference_with_sagenet(model, graph, graph.ndata['features'], EVAL_BATCH_SIZE)
    
    val_predictions = all_predictions[val_nid]
    val_labels = graph.ndata['labels'][val_nid]
    test_predictions = all_predictions[test_nid]
    test_labels = graph.ndata['labels'][test_nid]
    
    print('Validation acc:', compute_accuracy(val_predictions, val_labels),
          'Test acc:', compute_accuracy(test_predictions, test_labels))

Validation acc: 0.3880000114440918 Test acc: 0.4129999876022339
Validation acc: 0.3919999897480011 Test acc: 0.41600000858306885
Validation acc: 0.4020000100135803 Test acc: 0.4300000071525574
Validation acc: 0.4320000112056732 Test acc: 0.4560000002384186
Validation acc: 0.4959999918937683 Test acc: 0.49799999594688416
Validation acc: 0.5540000200271606 Test acc: 0.5509999990463257
Validation acc: 0.5640000104904175 Test acc: 0.5709999799728394
Validation acc: 0.6259999871253967 Test acc: 0.6420000195503235
Validation acc: 0.6340000033378601 Test acc: 0.6320000290870667
Validation acc: 0.6520000100135803 Test acc: 0.6510000228881836
Validation acc: 0.6520000100135803 Test acc: 0.6610000133514404
Validation acc: 0.6759999990463257 Test acc: 0.671999990940094
Validation acc: 0.6660000085830688 Test acc: 0.675000011920929
Validation acc: 0.6880000233650208 Test acc: 0.6840000152587891
Validation acc: 0.7160000205039978 Test acc: 0.7070000171661377
Validation acc: 0.7039999961853027 Test 