# Graph classification with DGL

Notebook created by Rui Valente de Almeida, for the ISM PhD Course; FCT NOVA, 2019

This notebook follows the [tutorial on graph classification with DGL](https://docs.dgl.ai/en/latest/tutorials/basics/4_batch.html), by the DGL team. In it, we use a few toy graph examples and implement a graph convolutional network (GCN) that learns to classify each graph as an instance of one of the labeled classes.

Solutions to problems like this one are useful in a variety of fields, such as cyber security and social network analysis, and there are already several papers on the subject:
* [Ying et al., 2018](https://arxiv.org/abs/1806.08804);
* [Cangea et al., 2018](https://arxiv.org/abs/1811.01287);
* [Knyazev et al., 2018](https://arxiv.org/abs/1811.09595);
* [Bianchi et al., 2019](https://arxiv.org/abs/1901.01343);
* [Liao et al., 2019](https://arxiv.org/abs/1901.01484);
* [Gao et al., 2019](https://openreview.net/forum?id=HJePRoAct7).

The idea of this tutorial is to train the network on a synthetic dataset, generated with one of DGL's built-in functions, so that it learns to classify graphs. We aim for a balanced dataset, giving each of the 8 classes the same number of graphs. The classes contain graphs like the following:
![8classes](img/dataset_overview.png "")

## Imports

As usual, we import what we need.

In [None]:
from dgl.data import MiniGCDataset
from matplotlib import pyplot as plt
import networkx as nx
import dgl
import dgl.function as fn
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

## Dataset generation

We now generate our dataset: 80 graphs, each with between 10 and 20 nodes.

In [None]:
dataset = MiniGCDataset(80, 10, 20)

Data visualisation is very important: we should always have some kind of function that lets us quickly look at a sample of what we are working with.

In [1]:
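def view_ds_sample(graph_index, ret=False):
    # Fetch one (graph, label) pair from the dataset
    g, l = dataset[graph_index]
    fig, ax = plt.subplots()
    # Convert the DGL graph to networkx so it can be drawn
    nx.draw(g.to_networkx(), ax=ax)
    ax.set_title(f'Class: {int(l)}')
    plt.show()

    # Optionally return the sample for further inspection
    if ret:
        return g, l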

In [None]:
view_ds_sample(0)

## Batching

When working in deep learning, it is common, and good practice, to group whatever we are classifying (images, text, documents, ...) into mini-batches before running it through the network we are training. With Euclidean data this is easy: to form a mini-batch of images, we stack as many images as we want into one tensor, and the batch simply becomes another dimension of that tensor. Non-Euclidean data such as graphs are a different matter: graphs differ in their numbers of nodes and edges, so they cannot simply be stacked.

To get around this, the DGL developers made a clever observation: a mini-batch of graphs can be viewed as a single large graph whose disjoint connected components are the individual graphs. Batching graphs is therefore little more than concatenating the graphs we want to batch. We now define a function that does this for us.

In [2]:
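def collate(samples):
    # samples is a list of (graph, label) pairs
    graphs, labels = map(list, zip(*samples))
    # Merge the graphs into one batched graph (their disjoint union)
    batched_graph = dgl.batch(graphs)
    return batched_graph, torch.tensor(labels)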
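To make the disjoint-union idea concrete, here is a minimal plain-Python sketch of what batching does to the graph structure. It is not DGL's implementation: the `(num_nodes, edges)` representation and the `batch_edge_lists` helper are hypothetical. Each graph's node ids are shifted by the number of nodes that came before it, so no edge ever connects two different graphs; the sketch also keeps the per-graph node counts so that per-graph quantities can still be computed afterwards.

```python
def batch_edge_lists(graphs):
    """Merge graphs into one disjoint-union graph (hypothetical helper).

    Each graph is a (num_nodes, edges) pair, where edges is a list of
    (src, dst) pairs using node ids local to that graph.
    Returns (total_nodes, merged_edges, batch_num_nodes).
    """
    merged_edges = []
    batch_num_nodes = []
    offset = 0
    for num_nodes, edges in graphs:
        # Shift node ids so this graph's nodes come after the previous graphs' nodes
        merged_edges.extend((u + offset, v + offset) for u, v in edges)
        batch_num_nodes.append(num_nodes)
        offset += num_nodes
    return offset, merged_edges, batch_num_nodes


# A 3-node path and a single 2-node edge, both stored with edges in both directions
path = (3, [(0, 1), (1, 0), (1, 2), (2, 1)])
pair = (2, [(0, 1), (1, 0)])
total, edges, sizes = batch_edge_lists([path, pair])
print(total, sizes)  # 5 [3, 2]
print(edges)         # [(0, 1), (1, 0), (1, 2), (2, 1), (3, 4), (4, 3)]
```

The second graph's nodes become 3 and 4, and its only edges stay between them, so the two graphs remain separate connected components of the merged graph.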

This trick has another important consequence: the object returned by `dgl.batch` is itself a graph, so we can call on it any method we would call on any other graph.

## Classification

Finally we reach the classification stage, which happens in three steps. First, the nodes pass their feature information (messages) to their neighbours; this phase is called message passing or convolution (as we have seen in the presentation). Next comes the aggregation (readout) step, in which the node representations of each graph are combined into a single tensor that summarises the whole graph. Finally, these graph representations (embeddings) are passed to a classifier that predicts the graph label. The next image depicts the whole process.
![graph_class](img/graph_classifier.png "")
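The message-passing step can be illustrated with a minimal plain-Python sketch of one round of mean aggregation over an edge list. The helper name and the list-of-lists feature layout are hypothetical, chosen only for illustration. Every edge `(src, dst)` carries a copy of the source node's feature to the destination, and each node then averages what it received. Nodes with no incoming edge keep their old feature; that is a choice made for this sketch, and DGL may treat such nodes differently.

```python
def mean_message_passing(num_nodes, edges, h):
    """One round of mean aggregation (hypothetical helper).

    h is a list of per-node feature vectors (lists of floats).
    Returns the new features: the mean of each node's in-neighbour features.
    """
    dim = len(h[0])
    sums = [[0.0] * dim for _ in range(num_nodes)]
    counts = [0] * num_nodes
    for src, dst in edges:
        # Message: the source node sends a copy of its feature to dst
        for k in range(dim):
            sums[dst][k] += h[src][k]
        counts[dst] += 1
    # Reduce: average the received messages; nodes with none keep their feature
    return [
        [s / counts[v] for s in sums[v]] if counts[v] else list(h[v])
        for v in range(num_nodes)
    ]


# Path 0-1-2 (edges in both directions), with in-degrees as initial features
edges = [(0, 1), (1, 0), (1, 2), (2, 1)]
h = [[1.0], [2.0], [1.0]]
print(mean_message_passing(3, edges, h))  # [[2.0], [1.0], [2.0]]
```

The middle node averages its two neighbours' features, while each end node simply copies the middle node's feature.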

In [None]:
# Message function: each node sends its feature 'h' along its out-edges as message 'm'
msg = fn.copy_u(u='h', out='m')

# Reduce function: average the messages in each node's mailbox (dimension 1 indexes
# the in-neighbours) and store the result as the node's new 'h'.
def reduce(nodes):
    return {'h': torch.mean(nodes.mailbox['m'], 1)}

class NodeApplyModule(nn.Module):
    def __init__(self, in_feats, out_feats, activation):
        super(NodeApplyModule, self).__init__()
        self.linear = nn.Linear(in_feats, out_feats)
        self.activation = activation
    
    def forward(self, node): return {'h': self.activation(self.linear(node.data['h']))}

class GCN(nn.Module):
    def __init__(self, in_feats, out_feats, activation):
        super(GCN, self).__init__()
        self.apply_mod = NodeApplyModule(in_feats, out_feats, activation)
        
    def forward(self, g, feature):
        g.ndata['h'] = feature
        g.update_all(msg, reduce)
        g.apply_nodes(func=self.apply_mod)
        return g.ndata.pop('h')
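Reading directly off the code (our notation, not the tutorial's), one `GCN` layer updates the feature of a node $v$ as follows, where $\mathcal{N}(v)$ is the set of in-neighbours of $v$, $W^{(l)}$ and $b^{(l)}$ are the weight and bias of the layer's `nn.Linear`, and $\sigma$ is the activation passed to the layer:

$$
h_v^{(l+1)} = \sigma\left( W^{(l)} \, \frac{1}{|\mathcal{N}(v)|} \sum_{u \in \mathcal{N}(v)} h_u^{(l)} + b^{(l)} \right)
$$

The mean corresponds to `reduce`, and the linear map followed by $\sigma$ corresponds to `NodeApplyModule`. Note that a node's own previous feature $h_v^{(l)}$ only contributes if the graph has a self-loop at $v$.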

## Aggregation and classification

In this tutorial, the initial feature of each node is its in-degree. We run the convolution for two rounds and then aggregate, averaging the node features of each graph in the mini-batch.
DGL handles this step with the `dgl.mean_nodes()` function; after that, it is just a matter of feeding the resulting graph embeddings to a classifier network.

In [None]:
class Classifier(nn.Module):
    def __init__(self, in_dim, hidden_dim, n_classes):
        super(Classifier, self).__init__()
        self.layers = nn.ModuleList([
            GCN(in_dim, hidden_dim, F.relu),
            GCN(hidden_dim, hidden_dim, F.relu)
        ])
        self.classify = nn.Linear(hidden_dim, n_classes)
    
    def forward(self, g):
        h = g.in_degrees().view(-1,1).float()
        
        for conv in self.layers:
            h = conv(g, h)
        
        g.ndata['h'] = h
        return self.classify(dgl.mean_nodes(g, 'h'))
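Leaving out the two convolution rounds, the feature initialisation and the readout in `Classifier.forward` can be sketched in plain Python as follows. This is a hypothetical illustration of what `g.in_degrees()` and `dgl.mean_nodes` compute on a batched graph, not DGL's implementation; the helper names and the `batch_num_nodes` list of per-graph node counts are assumptions of the sketch.

```python
def in_degrees(num_nodes, edges):
    """Count incoming edges per node (hypothetical helper)."""
    deg = [0] * num_nodes
    for _, dst in edges:
        deg[dst] += 1
    return deg


def mean_readout(h, batch_num_nodes):
    """Average node features separately for each graph in a batch.

    h holds the features of all nodes in the batched graph, ordered graph by
    graph; batch_num_nodes gives how many nodes each graph contributes.
    """
    out = []
    start = 0
    for n in batch_num_nodes:
        block = h[start:start + n]
        # Column-wise mean over this graph's nodes only
        out.append([sum(col) / n for col in zip(*block)])
        start += n
    return out


# Batched graph: a 3-node path (nodes 0-2) and a 2-node edge (nodes 3-4)
edges = [(0, 1), (1, 0), (1, 2), (2, 1), (3, 4), (4, 3)]
h = [[float(d)] for d in in_degrees(5, edges)]  # [[1.0], [2.0], [1.0], [1.0], [1.0]]
print(mean_readout(h, [3, 2]))  # [[1.3333333333333333], [1.0]]
```

Each graph in the batch yields one embedding, which is what the final `nn.Linear` layer receives.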

## Let's train

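We generate 400 graphs with 10 to 20 nodes each and divide them 80/20 between training and testing. Rather than splitting a single dataset, the code generates the two sets separately: 320 graphs for training and 80 for testing.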

In [None]:
ds_size = 400
bs = 32
lr = 0.001
# Use integer graph counts (0.8 * ds_size is a float)
n_train = int(0.8 * ds_size)
n_test = ds_size - n_train
train = MiniGCDataset(n_train, 10, 20)
test = MiniGCDataset(n_test, 10, 20)
data_loader = DataLoader(train, batch_size=bs, shuffle=True, collate_fn=collate)

model = Classifier(1, 256, train.num_classes)
loss_func = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)
model.train()

epoch_losses = []
for epoch in range(80):
    epoch_loss = 0
    for i, (bg, label) in enumerate(data_loader):
        prediction = model(bg)
        loss = loss_func(prediction, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.detach().item()
    epoch_loss /= (i + 1)
    print(f'Epoch {epoch}, loss {epoch_loss}')
    epoch_losses.append(epoch_loss)
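For reference, the sizes used in the training cell work out as follows. This is plain arithmetic on the cell's own values; the per-class figure assumes the 8 classes are evenly balanced, as described at the start of the notebook.

```python
import math

ds_size = 400
bs = 32

n_train = int(0.8 * ds_size)                 # training graphs
n_test = ds_size - n_train                   # test graphs
per_class_train = n_train // 8               # assumes 8 evenly balanced classes
batches_per_epoch = math.ceil(n_train / bs)  # a final partial batch counts as one

print(n_train, n_test, per_class_train, batches_per_epoch)  # 320 80 40 10
```

With its default `drop_last=False`, `DataLoader` yields a final partial batch whenever the training set size is not a multiple of `bs`; here 320 is exactly 10 batches of 32, so the `i + 1` divisor in the loop equals 10.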
