# Training a GNN for Graph Classification

In [1]:
import os
os.environ['DGLBACKEND'] = 'pytorch'

import dgl
import dgl.data
import torch
import torch.nn as nn
import torch.nn.functional as F

## Overview of Graph Classification with GNN

**Graph classification or regression** requires a model to predict certain **graph-level properties** of a single graph given its node and edge features. **Molecular property prediction** is one particular application.

## Loading Data

In [2]:
dataset = dgl.data.GINDataset('PROTEINS', self_loop=True)
dataset

Dataset("PROTEINS", num_graphs=1113, save_path=/Users/qinzijian/.dgl/PROTEINS_0c2c49a1)

In [3]:
type(dataset)

dgl.data.gindt.GINDataset

The dataset is a set of graphs, each with node features and a single label.

In [4]:
print('Node feature dimensionality: {}'.format(dataset.dim_nfeats))
print('Number of graph categories: {}'.format(dataset.gclasses))

Node feature dimensionality: 3
Number of graph categories: 2


## Defining Data Loader

 - A graph classification dataset usually contains two types of elements: a set of graphs, and their graph-level labels.
 - Similar to an image classification task, when the dataset is large enough, we need to train with mini-batches.
 - In PyTorch you will use a `DataLoader` to iterate over the dataset. In DGL, use `GraphDataLoader`.
 - You can also use the samplers provided in `torch.utils.data.sampler`. For example, this tutorial creates a training `GraphDataLoader` and a test `GraphDataLoader`, using `SubsetRandomSampler` to tell PyTorch to sample from only a subset of the dataset.
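
For intuition, the subset sampling described in the list above can be sketched without any library. The snippet is an illustration only: `subset_random_batches` is a hypothetical helper, not a PyTorch or DGL function. It mimics drawing a fixed subset of indices in random order and grouping them into mini-batches.

```python
import random

def subset_random_batches(indices, batch_size, drop_last=False, seed=None):
    """Yield mini-batches of indices drawn in random order from `indices`."""
    rng = random.Random(seed)
    shuffled = list(indices)
    rng.shuffle(shuffled)  # random order; every index appears exactly once
    for start in range(0, len(shuffled), batch_size):
        batch = shuffled[start:start + batch_size]
        if drop_last and len(batch) < batch_size:
            break  # skip the final, smaller batch
        yield batch

num_examples = 1113                  # number of graphs in PROTEINS
num_train = int(num_examples * 0.8)  # first 80% of the indices for training
train_batches = list(subset_random_batches(range(num_train), batch_size=5, seed=0))
test_batches = list(subset_random_batches(range(num_train, num_examples), batch_size=5, seed=0))
print(len(train_batches), len(test_batches))
```

With 890 training indices and a batch size of 5 this gives 178 batches. The 223 test indices give 45 batches, the last holding only 3 indices because nothing is dropped.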

In [5]:
from dgl.dataloading import GraphDataLoader
from torch.utils.data.sampler import SubsetRandomSampler

In [6]:
num_examples = len(dataset)
num_train = int(num_examples * 0.8)
print('all {}, train {}'.format(num_examples, num_train))

all 1113, train 890


In [7]:
train_sampler = SubsetRandomSampler(torch.arange(num_train))
test_sampler = SubsetRandomSampler(torch.arange(num_train, num_examples))

In [8]:
train_sampler

<torch.utils.data.sampler.SubsetRandomSampler at 0x7fa3b81cba90>

In [9]:
torch.arange(num_train)

tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
         14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
         28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
         42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
         56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
         70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
         84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
         98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
        112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
        126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139,
        140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153,
        154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167,
        168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 1

In [10]:
torch.arange(num_train, num_examples)

tensor([ 890,  891,  892,  893,  894,  895,  896,  897,  898,  899,  900,  901,
         902,  903,  904,  905,  906,  907,  908,  909,  910,  911,  912,  913,
         914,  915,  916,  917,  918,  919,  920,  921,  922,  923,  924,  925,
         926,  927,  928,  929,  930,  931,  932,  933,  934,  935,  936,  937,
         938,  939,  940,  941,  942,  943,  944,  945,  946,  947,  948,  949,
         950,  951,  952,  953,  954,  955,  956,  957,  958,  959,  960,  961,
         962,  963,  964,  965,  966,  967,  968,  969,  970,  971,  972,  973,
         974,  975,  976,  977,  978,  979,  980,  981,  982,  983,  984,  985,
         986,  987,  988,  989,  990,  991,  992,  993,  994,  995,  996,  997,
         998,  999, 1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009,
        1010, 1011, 1012, 1013, 1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021,
        1022, 1023, 1024, 1025, 1026, 1027, 1028, 1029, 1030, 1031, 1032, 1033,
        1034, 1035, 1036, 1037, 1038, 10

In [11]:
train_dataloader = GraphDataLoader(dataset, sampler=train_sampler, batch_size=5, drop_last=False)
test_dataloader = GraphDataLoader(dataset, sampler=test_sampler, batch_size=5, drop_last=False)

In [12]:
train_dataloader

<dgl.dataloading.dataloader.GraphDataLoader at 0x7fa3b81efd00>

Try iterating over the `GraphDataLoader` to see what it returns.

In [13]:
it = iter(train_dataloader)
batch = next(it)
print(batch)

[Graph(num_nodes=275, num_edges=1289,
      ndata_schemes={'attr': Scheme(shape=(3,), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int64)}
      edata_schemes={}), tensor([1, 0, 1, 1, 0])]


 - Each element in `dataset` is a graph paired with a label, so the `GraphDataLoader` returns two objects on each iteration.
 - The first is the **batched graph**, and the second is a **label vector** representing the category of each graph in the mini-batch.

## A Batched Graph in DGL

 - In each mini-batch, the sampled graphs are combined into a single bigger batched graph.
 - The single batched graph merges all original graphs as separate connected components, with the node and edge features concatenated.
 - This bigger graph is also a `DGLGraph` instance, so you can still treat it as a normal `DGLGraph` object.
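
The merging described above can be pictured as a disjoint union. Below is a minimal, library-free sketch. The dict-based graph layout and the `batch_graphs` helper are assumptions made for illustration and do not reflect how DGL stores graphs internally.

```python
def batch_graphs(graphs):
    """Merge graphs into one disjoint union.

    Each graph is a dict with 'num_nodes', 'edges' (list of (src, dst) pairs)
    and 'feat' (one feature row per node). This layout is hypothetical.
    """
    edges, feats, batch_num_nodes, batch_num_edges = [], [], [], []
    offset = 0
    for g in graphs:
        # shift node IDs so they do not collide with those of earlier graphs
        edges.extend((u + offset, v + offset) for u, v in g['edges'])
        feats.extend(g['feat'])  # features are simply concatenated
        batch_num_nodes.append(g['num_nodes'])
        batch_num_edges.append(len(g['edges']))
        offset += g['num_nodes']
    return {'num_nodes': offset, 'edges': edges, 'feat': feats,
            'batch_num_nodes': batch_num_nodes,
            'batch_num_edges': batch_num_edges}

g1 = {'num_nodes': 2, 'edges': [(0, 1), (1, 0)],
      'feat': [[1.0, 0.0], [0.0, 1.0]]}
g2 = {'num_nodes': 3, 'edges': [(0, 1), (1, 2)],
      'feat': [[1.0, 1.0], [0.0, 0.0], [2.0, 2.0]]}
bg = batch_graphs([g1, g2])
print(bg['num_nodes'], bg['edges'])
# 5 [(0, 1), (1, 0), (2, 3), (3, 4)]
```

No edge crosses from one original graph to another, so message passing on the merged graph never mixes information between graphs. Keeping the per-graph node and edge counts is what makes it possible to split the batch apart again.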

In [14]:
batched_graph, labels = batch

print('Number of nodes for each graph element in the batch: {}'\
     .format(batched_graph.batch_num_nodes()))
print('Number of edges for each graph element in the batch: {}'\
     .format(batched_graph.batch_num_edges()))

Number of nodes for each graph element in the batch: tensor([ 13,  33,  24, 150,  55])
Number of edges for each graph element in the batch: tensor([ 63, 169, 116, 692, 249])


In [15]:
graphs = dgl.unbatch(batched_graph)
print('The original graphs in the minibatch:')
print(graphs)

The original graphs in the minibatch:
[Graph(num_nodes=13, num_edges=63,
      ndata_schemes={'attr': Scheme(shape=(3,), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int64)}
      edata_schemes={}), Graph(num_nodes=33, num_edges=169,
      ndata_schemes={'attr': Scheme(shape=(3,), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int64)}
      edata_schemes={}), Graph(num_nodes=24, num_edges=116,
      ndata_schemes={'attr': Scheme(shape=(3,), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int64)}
      edata_schemes={}), Graph(num_nodes=150, num_edges=692,
      ndata_schemes={'attr': Scheme(shape=(3,), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int64)}
      edata_schemes={}), Graph(num_nodes=55, num_edges=249,
      ndata_schemes={'attr': Scheme(shape=(3,), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int64)}
      edata_schemes={})]


In [16]:
batched_graph

Graph(num_nodes=275, num_edges=1289,
      ndata_schemes={'attr': Scheme(shape=(3,), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int64)}
      edata_schemes={})

In [17]:
batched_graph.ndata

{'attr': tensor([[1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
 

## Define Model
 
 - This tutorial builds a two-layer Graph Convolutional Network (GCN).
 - Each of its layers computes new node representations by aggregating neighbour information.
 - There are two differences from the `1_introduction` tutorial:
     - The task is to predict a single category for the *entire graph* instead of for every node, so the model needs to aggregate the representations of all the nodes (and potentially the edges) into a graph-level representation. This process is referred to as a **readout**. The model below uses `dgl.mean_nodes`, which averages the node features of each graph.
     - The input to the model is a batched graph. The readout functions provided by DGL can handle batched graphs, so they return one representation for each mini-batch element.
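
The neighbour aggregation mentioned in the list above can be sketched in plain Python. This is a deliberately simplified stand-in: it takes an unweighted mean over incoming neighbours, whereas `GraphConv` also applies a learned weight matrix and its own degree-based normalization, neither of which is reproduced here. `mean_aggregate` is a hypothetical helper.

```python
def mean_aggregate(num_nodes, edges, feat):
    """Replace each node's features with the mean over its in-neighbours.

    `edges` is a list of (src, dst) pairs. Nodes with no incoming edge
    end up with a zero vector. There are no learnable weights here.
    """
    dim = len(feat[0])
    sums = [[0.0] * dim for _ in range(num_nodes)]
    counts = [0] * num_nodes
    for src, dst in edges:
        for k in range(dim):
            sums[dst][k] += feat[src][k]
        counts[dst] += 1
    return [[s / c for s in row] if c else row
            for row, c in zip(sums, counts)]

# a 3-node path 0 - 1 - 2, stored as directed edges in both directions,
# plus one self-loop per node
edges = [(0, 1), (1, 0), (1, 2), (2, 1), (0, 0), (1, 1), (2, 2)]
feat = [[1.0], [2.0], [6.0]]
print(mean_aggregate(3, edges, feat))
# [[1.5], [3.0], [4.0]]
```

Stacking two such steps, as the two-layer model does, lets each node's representation depend on nodes up to two hops away.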

In [18]:
from dgl.nn import GraphConv

class GCN(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GraphConv(in_feats, h_feats)
        self.conv2 = GraphConv(h_feats, num_classes)
    
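    def forward(self, g, in_feat):
        # two rounds of neighbour aggregation with a ReLU in between
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        # readout: average the node outputs of each graph in the batch
        g.ndata['h'] = h
        return dgl.mean_nodes(g, 'h')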
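
The last line of `forward` is the readout. As a library-free sketch of what a per-graph mean over a batched graph amounts to (the list-based layout and the `mean_readout` name are assumptions for illustration, not DGL internals):

```python
def mean_readout(node_feats, batch_num_nodes):
    """Average node feature rows separately for each graph in a batch.

    `node_feats` holds the rows of all graphs back to back, and
    `batch_num_nodes` says how many rows belong to each graph.
    """
    out, start = [], 0
    for n in batch_num_nodes:
        rows = node_feats[start:start + n]
        out.append([sum(col) / n for col in zip(*rows)])
        start += n
    return out

h = [[1.0, 3.0], [3.0, 5.0],               # graph 0 (2 nodes)
     [0.0, 0.0], [3.0, 6.0], [6.0, 0.0]]   # graph 1 (3 nodes)
print(mean_readout(h, [2, 3]))
# [[2.0, 4.0], [3.0, 2.0]]
```

The result has one row per graph rather than one row per node, which is why the model output lines up with the label vector of the mini-batch.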

## Training Loop

In [19]:
# create the model with given dimensions
model = GCN(dataset.dim_nfeats, 16, dataset.gclasses)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for epoch in range(20):
    for batched_graph, labels in train_dataloader:
        pred = model(batched_graph, batched_graph.ndata['attr'].float())
        loss = F.cross_entropy(pred, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
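
In the loop above, `F.cross_entropy` receives one row of raw scores (logits) per graph together with the integer labels, and with its default settings averages the loss over the mini-batch. For intuition, here is a minimal pure-Python sketch of that computation. The `cross_entropy` function below is a stand-in written for illustration, not the PyTorch implementation.

```python
import math

def cross_entropy(logits, labels):
    """Mean negative log-likelihood of the true class under a softmax."""
    total = 0.0
    for row, y in zip(logits, labels):
        m = max(row)  # subtract the max for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        total += log_sum_exp - row[y]  # equals -log softmax(row)[y]
    return total / len(labels)

# equal scores for two classes: the loss is log(2), whatever the label
print(cross_entropy([[0.0, 0.0]], [0]))
# a confident, correct prediction gives a smaller loss than a wrong one
print(cross_entropy([[4.0, 0.0]], [0]) < cross_entropy([[4.0, 0.0]], [1]))
```

A loss near log(2) on a two-class problem therefore means the model is doing no better than assigning equal probability to both classes.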

In [20]:
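num_correct = 0
num_tests = 0
model.eval()  # switch the model to evaluation mode
with torch.no_grad():  # gradients are not needed for evaluation
    for batched_graph, labels in test_dataloader:
        pred = model(batched_graph, batched_graph.ndata['attr'].float())
        # predicted class = index of the largest score for each graph
        num_correct += (pred.argmax(1) == labels).sum().item()
        num_tests += len(labels)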

print('Test accuracy {:.2f}'.format(num_correct / num_tests))

Test accuracy 0.35
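
An accuracy figure is easier to interpret next to a trivial baseline. The sketch below is not part of the original notebook. `majority_baseline_accuracy` is a hypothetical helper and the labels are made up for illustration. It computes the accuracy obtained by always predicting the most frequent training label.

```python
from collections import Counter

def majority_baseline_accuracy(train_labels, test_labels):
    """Accuracy of always predicting the most frequent training label."""
    majority = Counter(train_labels).most_common(1)[0][0]
    hits = sum(1 for y in test_labels if y == majority)
    return hits / len(test_labels)

# made-up labels, for illustration only
train_labels = [0, 0, 0, 1, 1]
test_labels = [1, 1, 0, 1]
print(majority_baseline_accuracy(train_labels, test_labels))
# 0.25
```

With the loaders defined earlier, the real labels could be gathered by iterating over `train_dataloader` and `test_dataloader` and collecting each `labels` tensor with `.tolist()`. A large gap between the label distributions of the two splits would show up as a low baseline score.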
