# Stochastic Training of GNN for Node Classification on Large Graphs

In this tutorial you will learn how to train a multi-layer GraphSAGE for node classification on the Amazon co-purchase network provided by OGB.  The dataset contains 2.4 million nodes and 61 million edges, which is too large to fit in the memory of a single GPU.

This tutorial covers how to

* Create your DGL graph from your own data in other formats such as CSV.
* Train a GNN model on a single machine with a single GPU on a graph of any size.

## Load Dataset

Although you can directly use the Python package provided by OGB, for demonstration we will instead manually download the dataset, peek into its contents, and process it with only `numpy`.

In [None]:
!wget https://s3.us-east-2.amazonaws.com/dgl.ai/dataset/ogbn-products.zip

In [None]:
!unzip -o ogbn-products.zip

The dataset contains these files:

* `products/raw/edge.csv` (source-destination pairs)
* `products/raw/node-feat.csv` (node features)
* `products/raw/node-label.csv` (node labels)
* `products/raw/num-edge-list.csv` (number of edges)
* `products/raw/num-node-list.csv` (number of nodes)

We will only use the first three CSV files.

The dataset also contains files defining the training-validation-test split in the directory `products/split/sales_ranking`.  `train.csv`, `valid.csv` and `test.csv` are all text files containing the node IDs in the training, validation and test sets respectively, one number per line.

<div class="alert alert-info">
    <b>Note:</b> The node IDs should be consecutive integers from 0 to the number of nodes minus 1.  If your node IDs are not consecutive or do not start from 0 (e.g. they start from <code>100000</code>), you need to relabel them yourself.  The <code>astype</code> method of a pandas DataFrame can conveniently relabel the IDs by converting the type to <code>"category"</code>; the resulting integer category codes then serve as the new IDs.
</div>
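To make the relabeling in the note concrete, here is a minimal plain-Python sketch of the same idea.  The edge list and its IDs are made up for illustration and are not part of the dataset; the mapping assigns new IDs in sorted order of the original ones.

```python
# Hypothetical raw edge list whose node IDs start at 100000 and have gaps.
raw_edges = [(100000, 100007), (100007, 100003), (100003, 100000), (100012, 100003)]

# Collect every ID that appears as a source or a destination, in sorted order.
unique_ids = sorted({nid for edge in raw_edges for nid in edge})

# Map each original ID to its position in the sorted list: 0, 1, 2, ...
old_to_new = {old: new for new, old in enumerate(unique_ids)}

# Rewrite the edges with the new consecutive IDs.
relabeled_edges = [(old_to_new[src], old_to_new[dst]) for src, dst in raw_edges]
print(relabeled_edges)
```

Keeping `old_to_new` around lets you translate predictions back to the original IDs later.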

In [None]:
import pandas as pd
edges = pd.read_csv('raw/edge.csv.gz', header=None).values
node_features = pd.read_csv('raw/node-feat.csv.gz', header=None).values
node_labels = pd.read_csv('raw/node-label.csv.gz', header=None).values[:, 0]

# pd.read_csv yields a DataFrame with one column, so we make them one-dimensional arrays.
train_nids = pd.read_csv('split/sales_ranking/train.csv.gz', header=None).values[:, 0]
valid_nids = pd.read_csv('split/sales_ranking/valid.csv.gz', header=None).values[:, 0]
test_nids = pd.read_csv('split/sales_ranking/test.csv.gz', header=None).values[:, 0]

We construct the graph as follows:

In [None]:
import dgl
import torch

graph = dgl.graph((edges[:, 0], edges[:, 1]))
node_features = torch.FloatTensor(node_features)
node_labels = torch.LongTensor(node_labels)

# Save the graph, features and training-validation-test split for use in later tutorials.
import pickle
with open('data.pkl', 'wb') as f:
    pickle.dump((graph, node_features, node_labels, train_nids, valid_nids, test_nids), f)

In [None]:
# Load the graph back from the file we saved
import dgl
import torch
import numpy as np
import pickle
with open('data.pkl', 'rb') as f:
    graph, node_features, node_labels, train_nids, valid_nids, test_nids = pickle.load(f)

We can see the size of the graph, features, and labels as follows.

In [None]:
print('Graph')
print(graph)
print('Shape of node features:', node_features.shape)
print('Shape of node labels:', node_labels.shape)

num_features = node_features.shape[1]
num_classes = (node_labels.max() + 1).item()
print('Number of classes:', num_classes)

## Define a Data Loader with Neighbor Sampling

### Neighbor sampling overview

The message passing formulation is commonly defined as the following:

$$
\begin{gathered}
  \boldsymbol{a}_v^{(l)} = \rho^{(l)} \left(
    \left\lbrace
      \boldsymbol{h}_u^{(l-1)} : u \in \mathcal{N} \left( v \right)
    \right\rbrace
  \right)
\\
  \boldsymbol{h}_v^{(l)} = \phi^{(l)} \left(
    \boldsymbol{h}_v^{(l-1)}, \boldsymbol{a}_v^{(l)}
  \right)
\end{gathered}
$$

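where $\rho^{(l)}$ and $\phi^{(l)}$ are parameterized functions, and $\mathcal{N}(v)$ is defined as the set of predecessors (or equivalently *neighbors*) of $v$ on graph $\mathcal{G}$.  Writing $\mathbb{E}$ for the edge set of $\mathcal{G}$, and $s(e)$ and $t(e)$ for the source and destination nodes of edge $e$: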
$$
\mathcal{N} \left( v \right) = \left\lbrace
  s \left( e \right) : e \in \mathbb{E}, t \left( e \right) = v
\right\rbrace
$$

For instance, to perform a message passing for updating the red node in the following graph:

![Imgur](https://i.imgur.com/xYPtaoy.png)

One needs to aggregate the node features of its neighbors, shown as green nodes:

![Imgur](https://i.imgur.com/OuvExp1.png)
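As a rough illustration of the two formulas, the sketch below performs one message passing step for a single node in plain Python.  The edge list, the scalar features, and the unparameterized choices of `rho` (a mean) and `phi` (an average of self and aggregate) are assumptions made for this example only; in a real GNN both functions carry learnable parameters.

```python
# Hypothetical toy graph as a list of directed edges (source, destination).
edges = [(4, 8), (5, 8), (7, 8), (11, 8), (8, 4), (5, 4)]
# Hypothetical scalar representations h^{(l-1)} for each node.
h_prev = {4: 1.0, 5: 2.0, 7: 3.0, 8: 10.0, 11: 6.0}

def predecessors(v, edges):
    """N(v): the sources of all edges whose destination is v."""
    return sorted(s for s, t in edges if t == v)

def rho(messages):
    """Aggregation: here simply the mean of the neighbor representations."""
    return sum(messages) / len(messages)

def phi(h_self, a):
    """Update: here an unparameterized average of self and aggregate."""
    return 0.5 * (h_self + a)

# a_8^{(l)}: aggregate over the predecessors of node 8.
a_8 = rho([h_prev[u] for u in predecessors(8, edges)])
# h_8^{(l)}: combine node 8's own previous representation with the aggregate.
h_8 = phi(h_prev[8], a_8)
print(predecessors(8, edges), a_8, h_8)
```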

Let's consider how multi-layer message passing works for computing the output of a single node.  In the following text we refer to the nodes whose GNN outputs are to be computed as *seed nodes*.

Consider computing with a 2-layer GNN the output of the seed node 8, colored red, in the following graph:

![Imgur](https://i.imgur.com/xYPtaoy.png)

By the formulation:

$$
\begin{gathered}
  \boldsymbol{a}_8^{(2)} = \rho^{(2)} \left(
    \left\lbrace
      \boldsymbol{h}_u^{(1)} : u \in \mathcal{N} \left( 8 \right)
    \right\rbrace
  \right) = \rho^{(2)} \left(
    \left\lbrace
      \boldsymbol{h}_4^{(1)}, \boldsymbol{h}_5^{(1)},
      \boldsymbol{h}_7^{(1)}, \boldsymbol{h}_{11}^{(1)}
    \right\rbrace
  \right)
\\
  \boldsymbol{h}_8^{(2)} = \phi^{(2)} \left(
    \boldsymbol{h}_8^{(1)}, \boldsymbol{a}_8^{(2)}
  \right)
\end{gathered}
$$

We can tell from the formulation that to compute $\boldsymbol{h}_8^{(2)}$ we need messages from nodes 4, 5, 7 and 11 (colored green) along the edges visualized below.

![Imgur](https://i.imgur.com/Gwjz05H.png)

This graph contains all the nodes in the original graph but only the edges necessary for message passing to the given output nodes.  We call this graph the *frontier* of the second GNN layer for node 8.

However, we still need the values of $\boldsymbol{h}_\cdot^{(1)}$, which are the outputs from the first GNN layer.  To compute those values for the red and green nodes, we further need to perform message passing on the edges visualized below.

![Imgur](https://i.imgur.com/g5Ptbj7.png)

Therefore, to compute the 2-layer GNN representation of the red node, we need the input features from the red node as well as the green and yellow nodes.

You may notice that the procedure which determines computation dependency is in the reverse direction of message aggregation: you start from the layer closest to the output, and work backwards to the input.
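The backward expansion can be sketched in a few lines of plain Python.  The edge list below is hypothetical (it is not the graph in the figures, although node 8 has the same four predecessors), and `required_nodes` is an illustrative helper, not a DGL function.

```python
# Hypothetical directed edges (source, destination); not the graph in the figures.
edges = [(4, 8), (5, 8), (7, 8), (11, 8), (1, 4), (2, 5), (5, 7), (9, 11), (0, 1)]

def predecessors(v):
    """Sources of all edges whose destination is v."""
    return {s for s, t in edges if t == v}

def required_nodes(seeds, num_layers):
    """Return the node sets needed at each level, from the output back to the input."""
    needed = [set(seeds)]
    for _ in range(num_layers):
        current = needed[-1]
        # A layer needs each node's own previous representation
        # plus the previous representations of its predecessors.
        previous = set(current)
        for v in current:
            previous |= predecessors(v)
        needed.append(previous)
    return needed

layers = required_nodes({8}, num_layers=2)
for depth, nodes in enumerate(layers):
    print("nodes needed", depth, "layer(s) below the output:", sorted(nodes))
```

Node 0 is three hops away from node 8 in this toy graph, so a 2-layer GNN never touches it.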

You can also see that computing representations for a small number of nodes often requires the input features of a significantly larger number of nodes.  Taking all neighbors for message aggregation is often too costly, since the nodes needed would easily cover a large portion of the graph.

Neighbor sampling addresses this issue by selecting a random subset of the neighbors to perform aggregation.  For instance, to compute $\boldsymbol{h}_8^{(2)}$ we can choose to sample 2 neighbors and aggregate over them.

![Imgur](https://i.imgur.com/2RJKDGB.png)

Similarly, to compute the first layer representation of the red and green nodes, we can also do neighbor sampling that takes 2 neighbors for each node.

![Imgur](https://i.imgur.com/jD79uy5.png)

You can see that with this method, fewer nodes' input features are needed.
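The effect of sampling on the number of input nodes can be checked with a small plain-Python experiment.  This is only a sketch under stated assumptions: the edge list is hypothetical, `input_nodes` is an illustrative helper rather than a DGL API, and a fanout of `None` is used here to mean "take all neighbors".

```python
import random

# Hypothetical directed edges (source, destination); not the graph in the figures.
edges = [
    (4, 8), (5, 8), (7, 8), (11, 8),
    (0, 4), (1, 4), (2, 4),
    (2, 5), (3, 5), (6, 5),
    (5, 7), (6, 7), (9, 7),
    (9, 11), (10, 11), (12, 11),
]

def predecessors(v):
    """All sources of edges pointing to v, in a fixed order."""
    return sorted(s for s, t in edges if t == v)

def input_nodes(seeds, fanouts, rng):
    """Nodes whose input features are needed; a fanout of None takes all neighbors."""
    needed = set(seeds)
    for fanout in fanouts:  # one entry per GNN layer, from the output layer back
        previous = set(needed)
        for v in sorted(needed):
            preds = predecessors(v)
            if fanout is not None and len(preds) > fanout:
                # Keep only a random subset of the predecessors.
                preds = rng.sample(preds, fanout)
            previous.update(preds)
        needed = previous
    return needed

rng = random.Random(0)
full = input_nodes({8}, [None, None], rng)
sampled = input_nodes({8}, [2, 2], rng)
print(len(full), "input nodes without sampling,", len(sampled), "with fanout 2")
```

In this toy graph, full aggregation needs every node, while a fanout of 2 on both layers needs at most 9 nodes regardless of the random seed.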

### Defining neighbor sampler and data loader in DGL

DGL provides useful tools to generate such computation dependencies while iterating over the dataset in minibatches.  For node classification, you can use `dgl.dataloading.NodeDataLoader` for iterating over the dataset, and `dgl.dataloading.MultiLayerNeighborSampler` to generate computation dependencies of the nodes from a multi-layer GNN with neighbor sampling.

The syntax of `dgl.dataloading.NodeDataLoader` is mostly similar to a PyTorch `DataLoader`, with the addition that it needs a graph to generate computation dependency from, a set of node IDs to iterate on, and the neighbor sampler you defined.
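For readers less familiar with the PyTorch `DataLoader` arguments, the plain-Python sketch below shows how `batch_size`, `shuffle` and `drop_last` conventionally split a set of node IDs into minibatches.  `iterate_minibatches` is a hypothetical helper written for this illustration; it is not how DGL or PyTorch implement batching, and it leaves out the neighbor sampling step entirely.

```python
import random

def iterate_minibatches(node_ids, batch_size, shuffle, drop_last, rng=None):
    """Yield lists of node IDs, one list per minibatch."""
    ids = list(node_ids)
    if shuffle:
        # Visit the nodes in a different random order on every pass.
        (rng or random).shuffle(ids)
    for start in range(0, len(ids), batch_size):
        batch = ids[start:start + batch_size]
        if drop_last and len(batch) < batch_size:
            # Skip the final, smaller batch.
            break
        yield batch

keep_all = list(iterate_minibatches(range(10), batch_size=4, shuffle=False, drop_last=False))
dropped = list(iterate_minibatches(range(10), batch_size=4, shuffle=False, drop_last=True))
print(keep_all)
print(dropped)
```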

Let's consider training a 3-layer GraphSAGE with neighbor sampling, where each node gathers messages from 4 neighbors on each layer.  The code defining the data loader and neighbor sampler will look like the following.

In [None]:
sampler = dgl.dataloading.MultiLayerNeighborSampler([4, 4, 4])
train_dataloader = dgl.dataloading.NodeDataLoader(
    graph, train_nids, sampler,
    batch_size=1024,
    shuffle=True,
    drop_last=False,
    num_workers=0
)

We can iterate over the data loader we created and see what it gives us.

In [None]:
example_minibatch = next(iter(train_dataloader))
print(example_minibatch)

`NodeDataLoader` gives us three items per iteration.

* The input node list for the nodes whose input features are needed to compute the outputs.
* The output node list for the nodes whose GNN representations are to be computed.
* The list of computation dependency for each layer.

In [None]:
input_nodes, output_nodes, blocks = example_minibatch
print("To compute {} nodes' output we need {} nodes' input features".format(len(output_nodes), len(input_nodes)))

The variable `blocks` shows how messages aggregate on each layer.

In [None]:
print(blocks)

## Defining Model

The model can be written as follows:

In [None]:
import torch.nn as nn
import torch.nn.functional as F
import dgl.nn as dglnn

class SAGE(nn.Module):
    def __init__(self, in_feats, n_hidden, n_classes, n_layers):
        super().__init__()
        self.n_layers = n_layers
        self.n_hidden = n_hidden
        self.n_classes = n_classes
        self.layers = nn.ModuleList()
        self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, 'mean'))
        for i in range(1, n_layers - 1):
            self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, 'mean'))
        self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, 'mean'))
        
    def forward(self, blocks, x):
        for l, (layer, block) in enumerate(zip(self.layers, blocks)):
            x = layer(block, x)
            if l != self.n_layers - 1:
                x = F.relu(x)
        return x

You can see that here we iterate over pairs of a neural network layer and the corresponding block generated by the data loader.

## Defining Training Loop

The following initializes the model and defines the optimizer.

In [None]:
model = SAGE(num_features, 128, num_classes, 3).cuda()
opt = torch.optim.Adam(model.parameters())

When computing the validation score for model selection, usually you can also do neighbor sampling.  To do that, you need to define another data loader.

In [None]:
valid_dataloader = dgl.dataloading.NodeDataLoader(
    graph, valid_nids, sampler,
    batch_size=1024,
    shuffle=False,
    drop_last=False,
    num_workers=0
)

The following is a training loop that performs validation every epoch.  It also saves the model with the best validation accuracy into a file.

In [None]:
import tqdm
import sklearn.metrics

best_accuracy = 0
best_model_path = 'model.pt'
for epoch in range(10):
    model.train()
    
    with tqdm.tqdm(train_dataloader) as tq:
        for step, (input_nodes, output_nodes, blocks) in enumerate(tq):
            blocks = [b.to(torch.device('cuda')) for b in blocks]
            inputs = node_features[input_nodes].cuda()
            labels = node_labels[output_nodes].cuda()
            predictions = model(blocks, inputs)

            loss = F.cross_entropy(predictions, labels)
            opt.zero_grad()
            loss.backward()
            opt.step()

            accuracy = sklearn.metrics.accuracy_score(labels.cpu().numpy(), predictions.argmax(1).detach().cpu().numpy())
            
            tq.set_postfix({'loss': '%.03f' % loss.item(), 'acc': '%.03f' % accuracy}, refresh=False)
        
    model.eval()
    
    predictions = []
    labels = []
    with tqdm.tqdm(valid_dataloader) as tq, torch.no_grad():
        for input_nodes, output_nodes, blocks in tq:
            blocks = [b.to(torch.device('cuda')) for b in blocks]
            inputs = node_features[input_nodes].cuda()
            labels.append(node_labels[output_nodes].numpy())
            predictions.append(model(blocks, inputs).argmax(1).cpu().numpy())
        predictions = np.concatenate(predictions)
        labels = np.concatenate(labels)
        accuracy = sklearn.metrics.accuracy_score(labels, predictions)
        print('Epoch {} Validation Accuracy {}'.format(epoch, accuracy))
        if best_accuracy < accuracy:
            best_accuracy = accuracy
            torch.save(model.state_dict(), best_model_path)

## Offline Inference without Neighbor Sampling

Usually for offline inference it is desirable to aggregate over the entire neighborhood to eliminate the randomness introduced by neighbor sampling.  However, using the same methodology as in training is not efficient, because there will be a lot of redundant computation.  Moreover, simply running the neighbor sampler with all neighbors taken will often exhaust GPU memory, because the number of nodes whose input features are required may be too large.

Instead, you need to compute the representations layer by layer: you first compute the output of the first GNN layer for all nodes, then you compute the output of the second GNN layer for all nodes using the first GNN layer's output as input, and so on.  This gives us a different algorithm from what is being used in training.  During training we have an outer loop that iterates over the nodes, and an inner loop that iterates over the layers.  In contrast, during inference we have an outer loop that iterates over the layers, and an inner loop that iterates over the nodes.
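The difference between the two loop orders can be seen on a toy example in plain Python.  Everything here is an assumption made for illustration: the tiny graph, the scalar features, and a stand-in `layer` function that scales the sum of a node's own value and the mean of its predecessors by a per-layer weight.  Both orders produce the same outputs when all neighbors are used, but the node-first order re-evaluates lower layers repeatedly.

```python
# Hypothetical toy graph, scalar input features, and per-layer scalar "parameters".
edges = [(0, 1), (1, 2), (2, 0), (0, 2), (3, 0)]
num_nodes = 4
features = [1.0, 2.0, 3.0, 4.0]
weights = [0.5, 2.0]

preds = {v: [s for s, t in edges if t == v] for v in range(num_nodes)}
calls = {"n": 0}  # counts how many times a layer is evaluated

def layer(l, h_self, h_neighbors):
    """Stand-in for one GNN layer: mean-aggregate, add self, scale."""
    calls["n"] += 1
    agg = sum(h_neighbors) / len(h_neighbors) if h_neighbors else 0.0
    return weights[l] * (h_self + agg)

def node_wise(v, l):
    """Outer loop over nodes: recurse through the layers separately for each node."""
    if l == 0:
        return features[v]
    return layer(l - 1, node_wise(v, l - 1), [node_wise(u, l - 1) for u in preds[v]])

def layer_wise():
    """Outer loop over layers: finish one layer for all nodes before the next."""
    h = list(features)
    for l in range(len(weights)):
        h = [layer(l, h[v], [h[u] for u in preds[v]]) for v in range(num_nodes)]
    return h

out_node_wise = [node_wise(v, len(weights)) for v in range(num_nodes)]
calls_node_wise = calls["n"]

calls["n"] = 0
out_layer_wise = layer_wise()
calls_layer_wise = calls["n"]

print(out_node_wise == out_layer_wise, calls_node_wise, calls_layer_wise)
```

The layer-first order evaluates each layer exactly once per node, which is why it suits full-neighborhood inference over the whole graph.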

You can still use the `dgl.dataloading.MultiLayerNeighborSampler` and `dgl.dataloading.NodeDataLoader` to do offline inference.

![Imgur](https://i.imgur.com/rr1FG7S.gif)

In [None]:
def inference(model, graph, input_features, batch_size):
    nodes = torch.arange(graph.number_of_nodes())
    
    sampler = dgl.dataloading.MultiLayerNeighborSampler([None])  # one layer at a time, taking all neighbors
    dataloader = dgl.dataloading.NodeDataLoader(
        graph, nodes, sampler,
        batch_size=batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=0)
    
    with torch.no_grad():
        for l, layer in enumerate(model.layers):
            # Allocate a buffer of output representations for every node
            # Note that the buffer is on CPU memory.
            output_features = torch.zeros(
                graph.number_of_nodes(), model.n_hidden if l != model.n_layers - 1 else model.n_classes)

            for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
                block = blocks[0].to(torch.device('cuda'))

                x = input_features[input_nodes].cuda()

                # the following code is identical to the loop body in model.forward()
                x = layer(block, x)
                if l != model.n_layers - 1:
                    x = F.relu(x)

                output_features[output_nodes] = x.cpu()
            input_features = output_features
    return output_features

The following code loads the best model from the file saved previously and performs offline inference.  It computes the accuracy on the test set afterwards.

In [None]:
model.load_state_dict(torch.load(best_model_path))
all_predictions = inference(model, graph, node_features, 8192)

In [None]:
test_predictions = all_predictions[test_nids].argmax(1)
test_labels = node_labels[test_nids]
# accuracy_score expects the ground truth first and the predictions second.
test_accuracy = sklearn.metrics.accuracy_score(test_labels.numpy(), test_predictions.numpy())
print('Test accuracy:', test_accuracy)

## Conclusion

In this tutorial, you have learned how to train a multi-layer GraphSAGE with neighbor sampling on a large dataset that cannot fit into the memory of a single GPU.  The method you have learned can scale to a graph of any size, and works on a single machine with a single GPU.

## What's next?

The next tutorial will be about training the same GraphSAGE model in an unsupervised manner with link prediction, i.e. predicting whether an edge exists between two nodes.