# Stochastic Training of GNN for Link Prediction on Large Graphs

In this tutorial you will learn how to train a multi-layer GraphSAGE model in an unsupervised learning setting via link prediction on the Amazon co-purchase network provided by OGB. The dataset contains 2.4 million nodes and 61 million edges, so it does not fit in the memory of a single GPU.

This tutorial covers how to:

* Train a GNN model on a large graph using a single machine with a single GPU.
* Train a GNN model for link prediction.
* Train a GNN model for unsupervised learning.

This tutorial is based on the data downloaded in the previous tutorial.

## Link Prediction Overview

The objective of link prediction is to predict whether an edge exists between two given nodes. We often formulate the problem as predicting a score $s_{uv} = \phi(\boldsymbol{h}^{(l)}_u, \boldsymbol{h}^{(l)}_v)$ that reflects how likely an edge between nodes $u$ and $v$ is to exist, where $\boldsymbol{h}^{(l)}_u$ is the representation of $u$ computed by the $l$-th GNN layer and $\phi$ is a scoring function. We train the model via *negative sampling*, i.e., by comparing the score of a real edge against the scores of "non-existent" edges.

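A common loss function is the negative log-likelihood:

$$
\mathcal{L} = -\log \sigma\left(s_{uv}\right) - Q \, \mathbb{E}_{v^- \sim P^-(v)}\left[ \log \sigma\left(-s_{uv^-}\right) \right]
$$

where $\sigma$ is the sigmoid function, $P^-(v)$ is the distribution from which negative destination nodes $v^-$ are drawn, and $Q$ is the number of negative examples per positive edge.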
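
As a rough illustration of how the formula maps onto code, the sketch below evaluates it with NumPy on made-up scores (it is not the tutorial's PyTorch code). It assumes the scores of the $Q$ negatives of each positive edge can be arranged as one row of a `(B, Q)` array, and estimates the expectation by the sample mean, so $Q$ times the mean becomes a sum over the samples. The helper names are illustrative only.

```python
import numpy as np

def log_sigmoid(x):
    # Numerically stable log(sigmoid(x)) = -log(1 + exp(-x))
    return -np.logaddexp(0.0, -x)

def link_prediction_nll(pos_score, neg_score):
    """Per-positive-edge loss from the formula above.

    pos_score: shape (B,), the scores s_uv of B positive edges.
    neg_score: shape (B, Q), the scores s_uv^- of the Q negatives drawn for each edge.
    """
    Q = neg_score.shape[1]
    # Monte Carlo estimate: Q * E[log sigmoid(-s_uv^-)] becomes a sum over the Q samples
    return -log_sigmoid(pos_score) - Q * log_sigmoid(-neg_score).mean(axis=1)

def bce_with_logits(score, label):
    # Elementwise binary cross-entropy on raw scores (logits)
    return -(label * log_sigmoid(score) + (1.0 - label) * log_sigmoid(-score))

# Made-up scores for B = 4 positive edges with Q = 5 negatives each
rng = np.random.default_rng(0)
pos = rng.normal(size=4)
neg = rng.normal(size=(4, 5))
score = np.concatenate([pos, neg.ravel()])
label = np.concatenate([np.ones(4), np.zeros(20)])
print(link_prediction_nll(pos, neg).mean(), bce_with_logits(score, label).mean())
```

Labeling positive edges 1 and negatives 0 and averaging binary cross-entropy over all $B(1+Q)$ scores, as the training loop later in this tutorial does with `F.binary_cross_entropy_with_logits`, gives the mean of $\mathcal{L}$ divided by the constant $1+Q$, so both objectives have the same minimizer.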

You can also use other loss functions such as BPR or margin loss.
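
For comparison, here is a NumPy sketch of those two alternatives, each averaged over every (positive, negative) pair. It uses the same assumed `(B, Q)` layout for the negative scores; the margin of 1.0 is an arbitrary value chosen for illustration.

```python
import numpy as np

def bpr_loss(pos_score, neg_score):
    # BPR: -log sigmoid(s_uv - s_uv^-), averaged over all (positive, negative) pairs
    diff = pos_score[:, None] - neg_score      # shape (B, Q)
    return np.logaddexp(0.0, -diff).mean()     # log(1 + exp(-diff)) = -log sigmoid(diff)

def margin_loss(pos_score, neg_score, margin=1.0):
    # Hinge: max(0, margin - s_uv + s_uv^-); zero once a positive beats its negative by `margin`
    diff = pos_score[:, None] - neg_score
    return np.maximum(0.0, margin - diff).mean()

pos = np.array([2.0, 0.0])                 # B = 2 positive scores
neg = np.array([[0.0, 3.0], [0.0, -1.0]])  # Q = 2 negative scores per positive
print(bpr_loss(pos, neg), margin_loss(pos, neg))
```

Unlike the negative log-likelihood, both losses depend only on the difference between a positive score and a negative score, so they train the model to rank real edges above sampled ones rather than to output calibrated probabilities.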

Note that the formulation is very similar to that in implicit matrix factorization or word embedding learning.

## Overview of Unsupervised Learning with GNNs

Link prediction itself is already useful in various tasks such as recommendation, where you predict whether a node will interact with another node. It is also useful in an unsupervised learning setting, where you only want to learn a latent representation of every node.

The model is trained in an unsupervised manner by predicting whether two nodes are connected by an edge, and the learned representations can later be used for nearest neighbor search or for training a downstream classifier. The objective function can also be combined with a supervised cross-entropy loss for node classification.
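
To make the nearest neighbor use case concrete, the sketch below runs a brute-force cosine-similarity search over a small made-up embedding matrix with NumPy. `top_k_neighbors` is a hypothetical helper; at the scale of this dataset, an approximate nearest neighbor index would usually replace the dense similarity matrix.

```python
import numpy as np

def top_k_neighbors(emb, query_ids, k):
    """For each query node, return the k most similar other nodes by cosine similarity."""
    normed = emb / np.linalg.norm(emb, axis=1, keepdims=True)
    sims = normed[query_ids] @ normed.T                    # (num_queries, num_nodes)
    sims[np.arange(len(query_ids)), query_ids] = -np.inf   # never return the query itself
    return np.argsort(-sims, axis=1)[:, :k]                # highest similarity first

# Made-up 2-dimensional embeddings for 4 nodes
emb = np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [-1.0, 0.0]])
print(top_k_neighbors(emb, np.array([0]), k=2))
```

For node 0 this returns nodes 1 and 2: node 1 comes first because its embedding points in almost the same direction, while node 3, which points the opposite way, is ranked last and left out.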

## Load Dataset

We directly load the dataset preprocessed by the previous tutorial.

```python
import dgl
import torch
import numpy as np
import utils
import pickle

with open('data.pkl', 'rb') as f:
    data = pickle.load(f)
graph, node_features, node_labels, train_nids, valid_nids, test_nids = data
utils.prepare_mp(graph)
```

## Define Data Loader with Neighbor Sampling

Unlike node classification, link prediction requires iterating over edges and then computing the output representations of their incident nodes with neighbor sampling and the GNN.

DGL provides `EdgeDataLoader`, which lets you iterate over edges for edge classification or link prediction tasks. To perform link prediction, you also need to provide a negative sampler.

For homogeneous graphs, the negative sampler can be any callable that has the following signature:

```python
def negative_sampler(g: DGLGraph, eids: Tensor) -> Tuple[Tensor, Tensor]:
    pass
```

The first argument is the original graph and the second argument is the minibatch of edge IDs. The function returns a pair of node ID tensors holding the source nodes $u$ and the destination nodes $v^-$ of the negative examples.
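
To make this contract concrete without depending on DGL, the sketch below implements a uniform negative sampler with NumPy. It is a hypothetical stand-in: instead of receiving a `DGLGraph`, it stores the source node of every edge ID (`src_of_edge`) and the node count up front, so its `__call__` takes only `eids`. For real use, DGL ships a built-in uniform sampler, `dgl.dataloading.negative_sampler.Uniform`.

```python
import numpy as np

class UniformNegativeSampler:
    """Hypothetical NumPy stand-in for a DGL negative sampler.

    For every edge in the minibatch, keep its source node u and pair it with k
    destination nodes v^- drawn uniformly at random. Pairs that happen to be real
    edges are not filtered out; on a sparse graph such collisions are rare.
    """
    def __init__(self, src_of_edge, num_nodes, k, seed=0):
        self.src_of_edge = np.asarray(src_of_edge)
        self.num_nodes = num_nodes
        self.k = k
        self.rng = np.random.default_rng(seed)

    def __call__(self, eids):
        src = np.repeat(self.src_of_edge[eids], self.k)   # each u repeated k times in a row
        dst = self.rng.integers(0, self.num_nodes, size=len(src))
        return src, dst

# Toy graph: edge IDs 0..3 have source nodes 0, 1, 2, 2; 10 nodes in total
sampler = UniformNegativeSampler([0, 1, 2, 2], num_nodes=10, k=3)
src, dst = sampler(np.array([1, 3]))
print(src)  # [1 1 1 2 2 2]
print(dst)
```

Both returned arrays have length `len(eids) * k`: one $(u, v^-)$ pair per negative example.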

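The following code implements a negative sampler that builds `k` negative examples for each edge in the minibatch: it keeps the edge's source node $u$ and samples `k` destination nodes $v^-$ according to the distribution $P^-(v) \propto d(v)^{0.75}$, where $d(v)$ is the in-degree of $v$. The sampled pairs are assumed to be non-existent edges; the sampler does not check this, which is usually acceptable on a sparse graph.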

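```python
class NegativeSampler(object):
    def __init__(self, g, k):
        # Number of negative destination nodes drawn per positive edge
        self.k = k
        # Unnormalized sampling weights d(v)^0.75; multinomial() normalizes them
        self.weights = g.in_degrees().float() ** 0.75

    def __call__(self, g, eids):
        # Keep the source node u of every edge in the minibatch...
        src, _ = g.find_edges(eids)
        # ...repeated k times, once per negative example
        src = src.repeat_interleave(self.k)
        # Draw the corrupted destination nodes v^- with replacement
        dst = self.weights.multinomial(len(src), replacement=True)
        return src, dst
```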
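
To see what the exponent 0.75 does, the NumPy sketch below compares sampling in proportion to the raw degree with the smoothed weights used by `NegativeSampler`, on three made-up degrees chosen so that $d^{0.75}$ is an integer ($1, 16, 81$ become $1, 8, 27$). The smoothing lowers the share of the highest-degree node from $81/98 \approx 0.83$ to $27/36 = 0.75$ and raises the shares of the others, so low-degree nodes are drawn as negatives more often than under plain degree-proportional sampling.

```python
import numpy as np

degrees = np.array([1.0, 16.0, 81.0])   # made-up in-degrees d(v)

by_degree = degrees / degrees.sum()     # P(v) proportional to d(v): 1/98, 16/98, 81/98
weights = degrees ** 0.75               # same weighting as NegativeSampler.weights
smoothed = weights / weights.sum()      # P^-(v) proportional to d(v)^0.75: 1/36, 8/36, 27/36

# Drawing with replacement, analogous to weights.multinomial(n, replacement=True)
rng = np.random.default_rng(0)
samples = rng.choice(len(degrees), size=10, p=smoothed)
print(by_degree, smoothed, samples)
```

One consequence of using `in_degrees()` in `NegativeSampler` is that a node with in-degree zero gets weight zero and is never drawn as a negative.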

After defining the negative sampler, you can define the edge data loader with neighbor sampling. Here we take 5 negative examples per positive example and sample 4 neighbors per node in each of the 3 layers.

```python
sampler = dgl.dataloading.MultiLayerNeighborSampler([4, 4, 4])
k = 5
train_dataloader = dgl.dataloading.EdgeDataLoader(
    graph, torch.arange(graph.number_of_edges()), sampler,
    negative_sampler=NegativeSampler(graph, k),
    batch_size=1024,
    shuffle=True,
    drop_last=False,
    num_workers=4
)
```
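
A back-of-the-envelope calculation shows why this setup keeps the per-minibatch work bounded. The hypothetical helper below uses the rough edge count from the introduction (61 million; the exact value of `graph.number_of_edges()` depends on the preprocessing) and the settings above. Each negative example reuses the source node of its positive edge, so a minibatch has at most $2B + Bk$ output nodes, and each sampling layer keeps its destination nodes and adds at most `fanout` neighbors per destination node, growing the node set by at most a factor of $1 + \text{fanout}$.

```python
import math

def minibatch_budget(num_edges, batch_size, k, fanouts):
    """Upper bounds on the size of one EdgeDataLoader minibatch (hypothetical helper)."""
    num_batches = math.ceil(num_edges / batch_size)  # minibatches per epoch
    num_neg_edges = batch_size * k                   # negative examples per minibatch
    # u and v of each positive edge, plus one v^- per negative edge
    max_output_nodes = 2 * batch_size + num_neg_edges
    max_input_nodes = max_output_nodes
    for fanout in fanouts:
        # each layer keeps its nodes and samples at most `fanout` neighbors per node
        max_input_nodes *= 1 + fanout
    return num_batches, num_neg_edges, max_output_nodes, max_input_nodes

print(minibatch_budget(61_000_000, batch_size=1024, k=5, fanouts=[4, 4, 4]))
# (59571, 5120, 7168, 896000)
```

The bound on input nodes is loose, because sampled neighborhoods overlap and duplicate nodes are merged, and it is further capped by the number of nodes in the graph. What matters is that it depends only on the batch size, `k` and the fanouts, not on how many edges the graph has.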

You can peek at one minibatch from `train_dataloader` to see what it yields.

```python
example_minibatch = next(iter(train_dataloader))
print(example_minibatch)
```

The example minibatch consists of four elements:

* `input_nodes`: the input nodes needed to compute the representations of the output nodes.
* `pos_graph`: a graph containing the real edges sampled in the minibatch. Its nodes are all the output nodes of the minibatch, including those that appear only in the negative examples.
* `neg_graph`: a graph with the same nodes as `pos_graph`, containing the non-existent edges produced by the negative sampler.
* `blocks`: a list of blocks describing the computation dependency of each GNN layer.

```python
input_nodes, pos_graph, neg_graph, blocks = example_minibatch
print('Number of input nodes:', len(input_nodes))
print('Positive graph # nodes:', pos_graph.number_of_nodes(), '# edges:', pos_graph.number_of_edges())
print('Negative graph # nodes:', neg_graph.number_of_nodes(), '# edges:', neg_graph.number_of_edges())
print(blocks)
```

## Defining a Model for Node Representation

The model is a multi-layer GraphSAGE that applies one `SAGEConv` layer with mean aggregation to each block:

```python
import torch.nn as nn
import torch.nn.functional as F
import dgl.nn as dglnn

class SAGE(nn.Module):
    def __init__(self, in_feats, n_hidden, n_layers):
        super().__init__()
        self.n_layers = n_layers
        self.n_hidden = n_hidden
        self.layers = nn.ModuleList()
        self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, 'mean'))
        for i in range(1, n_layers):
            self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, 'mean'))

    def forward(self, blocks, x):
        for l, (layer, block) in enumerate(zip(self.layers, blocks)):
            x = layer(block, x)
            if l != self.n_layers - 1:
                x = F.relu(x)
        return x
```
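
For intuition about what each `SAGEConv(..., 'mean')` layer computes, here is a NumPy sketch of the mean-aggregator update on a toy graph: $\boldsymbol{h}'_v = \boldsymbol{h}_v W_{\text{self}} + \big(\frac{1}{|\mathcal{N}(v)|}\sum_{u \in \mathcal{N}(v)} \boldsymbol{h}_u\big) W_{\text{neigh}} + \boldsymbol{b}$. It is a simplified sketch, not DGL's implementation: details such as how the bias is applied, optional normalization, and the handling of nodes without in-neighbors may differ between DGL versions. `in_neighbors` is a hypothetical adjacency-list input.

```python
import numpy as np

def sage_mean_layer(h, in_neighbors, W_self, W_neigh, b):
    """One GraphSAGE layer with mean aggregation.

    h: (N, in_feats) input representations.
    in_neighbors[v]: list of source nodes of v's incoming edges.
    Nodes without in-neighbors get a zero neighbor term in this sketch.
    """
    h = np.asarray(h, dtype=float)
    h_neigh = np.zeros_like(h)
    for v, nbrs in enumerate(in_neighbors):
        if len(nbrs) > 0:
            h_neigh[v] = h[nbrs].mean(axis=0)   # average the neighbors' representations
    return h @ W_self + h_neigh @ W_neigh + b

# Toy graph: node 0 <- {1, 2}, node 1 <- {0}, node 2 has no in-edges
h = np.array([[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]])
out = sage_mean_layer(h, [[1, 2], [0], []], np.eye(2), np.eye(2), np.zeros(2))
print(out)
```

In the `SAGE` model above, a ReLU is applied between layers, and each block restricts the average to the sampled neighbors instead of the full neighborhood.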

## Obtaining Node Representation from GNN

In the previous tutorial we discussed offline inference of a graph neural network without neighbor sampling. The same code can be reused as-is to compute the node representations output by the GNN in the unsupervised setting: it computes the representations of all nodes one layer at a time, using full neighborhoods.

```python
import tqdm

def inference(model, graph, input_features, batch_size):
    nodes = torch.arange(graph.number_of_nodes())

    sampler = dgl.dataloading.MultiLayerNeighborSampler([None])  # one layer at a time, taking all neighbors
    dataloader = dgl.dataloading.NodeDataLoader(
        graph, nodes, sampler,
        batch_size=batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=0)

    with torch.no_grad():
        for l, layer in enumerate(model.layers):
            # Allocate a buffer of output representations for every node.
            # Note that the buffer is in CPU memory.
            output_features = torch.zeros(graph.number_of_nodes(), model.n_hidden)

            for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
                block = blocks[0].to(torch.device('cuda'))

                x = input_features[input_nodes].cuda()

                # The following code is identical to the loop body in model.forward()
                x = layer(block, x)
                if l != model.n_layers - 1:
                    x = F.relu(x)

                output_features[output_nodes] = x.cpu()
            # The next layer takes this layer's output for all nodes as input
            input_features = output_features
    return output_features
```
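
The sketch below isolates the idea behind `inference` with NumPy on a toy graph: every node is computed exactly once per layer, in batches, and each layer reads the complete output buffer of the previous layer. The layer is simplified to "own representation plus mean of in-neighbors" with no weights or ReLU, and `in_neighbors` is a hypothetical adjacency list; the point is only that the batch size does not change the result.

```python
import numpy as np

def aggregate(h, in_neighbors, nodes):
    # Simplified layer: h_v + mean of h over v's in-neighbors (no weights, no ReLU)
    out = h[nodes].copy()
    for i, v in enumerate(nodes):
        nbrs = in_neighbors[v]
        if len(nbrs) > 0:
            out[i] += h[nbrs].mean(axis=0)
    return out

def layerwise_inference(h, in_neighbors, n_layers, batch_size):
    h = np.asarray(h, dtype=float)
    num_nodes = h.shape[0]
    for _ in range(n_layers):
        out = np.empty_like(h)  # per-layer output buffer (kept in CPU memory in the tutorial)
        for start in range(0, num_nodes, batch_size):
            nodes = np.arange(start, min(start + batch_size, num_nodes))
            out[nodes] = aggregate(h, in_neighbors, nodes)
        h = out  # the next layer reads this layer's output for all nodes
    return h

in_neighbors = [[1], [2], [0, 1]]
per_node = layerwise_inference(np.eye(3), in_neighbors, n_layers=2, batch_size=1)
all_at_once = layerwise_inference(np.eye(3), in_neighbors, n_layers=2, batch_size=3)
print(np.allclose(per_node, all_at_once))  # True
```

Because a layer never reads values written during the same layer, any batch size yields identical representations; in `inference`, the batch size only trades GPU memory for the number of iterations.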

## Define the Score Predictor for Edges

After computing the node representations needed for the minibatch, we predict the scores of both the real edges and the non-existent edges sampled in the minibatch. This is easily done with the `apply_edges` method. Here, we simply compute the score of an edge as the dot product of the representations of its two incident nodes.

```python
class ScorePredictor(nn.Module):
    def forward(self, subgraph, x):
        with subgraph.local_scope():
            subgraph.ndata['x'] = x
            subgraph.apply_edges(dgl.function.u_dot_v('x', 'x', 'score'))
            return subgraph.edata['score']
```
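
`u_dot_v('x', 'x', 'score')` scores each edge $(u, v)$ as $\boldsymbol{x}_u^\top \boldsymbol{x}_v$. The NumPy sketch below computes the same quantity from explicit source and destination arrays (`edge_dot_scores` is a hypothetical helper, not a DGL function) and keeps a trailing dimension of size 1, which is the shape DGL's dot-product builtins produce. For example, edge 0 to 1 scores $1 \cdot 3 + 2 \cdot 4 = 11$.

```python
import numpy as np

def edge_dot_scores(src, dst, x):
    # score[e] = <x[src[e]], x[dst[e]]>, returned with shape (E, 1)
    return np.sum(x[src] * x[dst], axis=1, keepdims=True)

x = np.array([[1.0, 2.0], [3.0, 4.0], [0.0, 1.0]])   # toy node representations
scores = edge_dot_scores(np.array([0, 1]), np.array([1, 2]), x)
print(scores)
```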

## Evaluate Performance of the Learned Embedding

In this tutorial we evaluate the learned embeddings by training a linear classifier (logistic regression) that takes the embeddings of the training nodes as input, and then computing its accuracy on the validation and test sets.

```python
import sklearn.linear_model
import sklearn.metrics

def evaluate(emb, label, train_nids, valid_nids, test_nids):
    classifier = sklearn.linear_model.LogisticRegression(solver='lbfgs', multi_class='multinomial', verbose=1, max_iter=1000)
    classifier.fit(emb[train_nids], label[train_nids])
    valid_pred = classifier.predict(emb[valid_nids])
    test_pred = classifier.predict(emb[test_nids])
    valid_acc = sklearn.metrics.accuracy_score(label[valid_nids], valid_pred)
    test_acc = sklearn.metrics.accuracy_score(label[test_nids], test_pred)
    return valid_acc, test_acc
```

## Defining the Training Loop

The following initializes the model and defines the optimizer.

```python
model = SAGE(node_features.shape[1], 128, 3).cuda()
predictor = ScorePredictor().cuda()
opt = torch.optim.Adam(list(model.parameters()) + list(predictor.parameters()))
```

The following is the training loop for unsupervised learning. After every epoch it evaluates the learned embeddings and saves the model that performs best on the validation set:

```python
import tqdm

best_accuracy = 0
best_model_path = 'model.pt'
for epoch in range(10):
    model.train()

    with tqdm.tqdm(train_dataloader) as tq:
        for step, (input_nodes, pos_graph, neg_graph, blocks) in enumerate(tq):
            # Move the sampled structures and the input features to the GPU
            blocks = [b.to(torch.device('cuda')) for b in blocks]
            pos_graph = pos_graph.to(torch.device('cuda'))
            neg_graph = neg_graph.to(torch.device('cuda'))
            inputs = node_features[input_nodes].cuda()

            # Representations of the output nodes, then scores of real and negative edges
            outputs = model(blocks, inputs)
            pos_score = predictor(pos_graph, outputs)
            neg_score = predictor(neg_graph, outputs)

            # Real edges are labeled 1, sampled negatives 0
            score = torch.cat([pos_score, neg_score])
            label = torch.cat([torch.ones_like(pos_score), torch.zeros_like(neg_score)])
            loss = F.binary_cross_entropy_with_logits(score, label)

            opt.zero_grad()
            loss.backward()
            opt.step()

            tq.set_postfix({'loss': '%.03f' % loss.item()}, refresh=False)

    # Embed all nodes and evaluate the embeddings with a linear classifier
    model.eval()
    emb = inference(model, graph, node_features, 16384)
    valid_acc, test_acc = evaluate(emb.numpy(), node_labels.numpy(), train_nids, valid_nids, test_nids)
    print('Epoch {} Validation Accuracy {} Test Accuracy {}'.format(epoch, valid_acc, test_acc))
    # Keep the checkpoint with the best validation accuracy
    if best_accuracy < valid_acc:
        best_accuracy = valid_acc
        torch.save(model.state_dict(), best_model_path)
```

## Conclusion

In this tutorial, you have learned how to train a multi-layer GraphSAGE model for unsupervised learning via link prediction on a dataset too large to fit into GPU memory. Because each minibatch only involves a bounded, sampled neighborhood, the method scales to much larger graphs on a single machine with a single GPU, as long as the graph and its node features fit in the machine's CPU memory.

## What's next?

The next tutorial will be about scaling the training procedure out to multiple GPUs on a single machine.