In [1]:
# List the installed torch-family packages together with their build tags.
! pip list | grep torch

torch                         1.13.1+cu116
torchaudio                    0.13.1+cu116
torchsummary                  1.5.1
torchtext                     0.14.1
torchvision                   0.14.1+cu116


In [2]:
# Install DGL from the "cu116" wheel index (presumably the CUDA 11.6 build, matching
# the +cu116 torch packages installed in this environment).
! pip install  dgl -f https://data.dgl.ai/wheels/cu116/repo.html
# Install dglgo from the "wheels-test" index.
# NOTE(review): nothing in the cells visible here imports dglgo — confirm it is still needed.
! pip install  dglgo -f https://data.dgl.ai/wheels-test/repo.html

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in links: https://data.dgl.ai/wheels/cu116/repo.html
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in links: https://data.dgl.ai/wheels-test/repo.html


In [3]:
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.utils.data import Dataset
from torch.utils.data import DataLoader

import dgl
from dgl.data import GINDataset
from dgl import function as fn
from dgl.utils import expand_as_pair

from sklearn.metrics import classification_report

class GINConv(nn.Module):
    """One Graph Isomorphism Network (GIN) convolution layer.

    For every destination node the layer computes
    ``mlp((1 + eps) * own_feature + sum(features arriving over its edges))``.

    Args:
        input_dim: size of the incoming node features.
        hidden_dim: width of every linear layer of the MLP (and the output size).
        init_eps: initial value of ``eps``.
        learn_eps: if True ``eps`` receives gradients, otherwise it stays at
            ``init_eps``.
    """

    def __init__(
        self,
        input_dim,
        hidden_dim,
        init_eps=0,
        learn_eps=False,
    ):
        super().__init__()

        # The paper (p. 5) allows eps to be either a learnable parameter or a
        # fixed scalar; `learn_eps` selects between the two.
        eps_init = torch.FloatTensor([init_eps])
        self.eps = nn.Parameter(eps_init, requires_grad=learn_eps)

        # According to section 5.1 a 1-layer perceptron is not sufficient, so
        # three BatchNorm -> ReLU -> Linear stages are stacked.
        stage_dims = (
            (input_dim, hidden_dim),
            (hidden_dim, hidden_dim),
            (hidden_dim, hidden_dim),
        )
        stages = []
        for fan_in, fan_out in stage_dims:
            stages.append(nn.BatchNorm1d(fan_in))
            stages.append(nn.ReLU())
            stages.append(nn.Linear(fan_in, fan_out))
        self.mlp = nn.Sequential(*stages)

    def forward(self, graph, feat):
        """Apply the layer.

        Args:
            graph: DGL graph to pass messages on.
            feat: node features in any form accepted by ``expand_as_pair``.

        Returns:
            The MLP output for the destination nodes.
        """
        # local_scope keeps the temporary 'h' / 'neigh' fields off the caller's graph.
        with graph.local_scope():
            src_feat, dst_feat = expand_as_pair(feat, graph)
            graph.srcdata['h'] = src_feat
            # Copy each source feature along its edges and sum them per destination.
            graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'neigh'))
            neighbour_sum = graph.dstdata['neigh']
            combined = (1 + self.eps) * dst_feat + neighbour_sum
            return self.mlp(combined)


class GIN(nn.Module):
    """Graph Isomorphism Network for graph-level classification.

    Stacks ``num_layers`` :class:`GINConv` layers, sum-pools the node features of
    every depth (the raw input included) graph by graph, concatenates the pooled
    vectors and feeds them to a three-stage MLP prediction head.

    Args:
        input_dim: size of the input node features.
        hidden_dim: width of every GIN layer and of the head's hidden layers.
        output_dim: number of outputs (classes) produced per graph.
        num_layers: number of GINConv layers.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.ginlayers = nn.ModuleList()
        for layer in range(num_layers):
            self.ginlayers.append(
                # According to p. 8, GIN-0 (fixed eps) shows strong empirical
                # performance, so learn_eps is set to False.
                GINConv(
                    input_dim=input_dim if layer == 0 else hidden_dim,
                    hidden_dim=hidden_dim,
                    learn_eps=False,
                )
            )

        # Readout uses every depth (bottom of p. 5: "features from earlier
        # iterations may sometimes generalize better"), so the head sees the
        # input features plus one hidden vector per layer.
        readout_dim = hidden_dim * num_layers + input_dim
        # 3-stage MLP that turns the concatenated readout into predictions.
        self.prediction_head = nn.Sequential(
            nn.BatchNorm1d(readout_dim),
            nn.ReLU(),
            nn.Linear(readout_dim, hidden_dim),

            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),

            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

    @staticmethod
    def _sum_per_graph(node_feats, nodes_per_graph):
        """Sum consecutive row groups of ``node_feats``, one group per graph.

        ``nodes_per_graph`` is a list of ints; ``torch.split`` requires the
        sizes to add up to the number of rows. Returns a tensor with one row
        per graph.
        """
        segments = torch.split(node_feats, nodes_per_graph, dim=0)
        return torch.stack([segment.sum(dim=0) for segment in segments], dim=0)

    def forward(self, g, h):
        """Compute per-graph predictions for a batched graph.

        Args:
            g: batched graph; must provide ``batch_num_nodes()``.
            h: node features with one row per node. The per-graph pooling
                slices consecutive rows, so rows are assumed to be ordered
                graph by graph, as in a DGL batched graph (TODO confirm for
                other inputs).

        Returns:
            Tensor with one row per graph and ``output_dim`` columns.
        """
        # Hidden representation at every depth, starting with the input layer.
        hidden_rep = [h]
        for layer in self.ginlayers:
            h = layer(g, h)
            hidden_rep.append(h)

        # The node count of each graph is the same at every depth, so fetch it
        # (and pay the device-to-host copy) once instead of once per layer.
        nodes_per_graph = g.batch_num_nodes().tolist()

        # Graph-level sum pooling at each depth, then concatenation (eq. 4.2).
        readout_list = [
            self._sum_per_graph(rep, nodes_per_graph) for rep in hidden_rep
        ]
        readout = torch.cat(readout_list, dim=-1)

        return self.prediction_head(readout)


def evaluate(dataloader, device, model):
    """Score ``model`` on every batch of ``dataloader``.

    The model is switched to eval mode (and left there) and run without
    gradient tracking. Node features are taken from the ``"attr"`` field of
    each batched graph.

    Args:
        dataloader: iterable of ``(batched_graph, labels)`` pairs.
        device: device the graphs and labels are moved to.
        model: callable as ``model(batched_graph, feat)`` returning per-graph logits.

    Returns:
        dict with keys ``'accuracy'`` and ``'macro avg f1'``.
    """
    model.eval()
    predicted, expected = [], []
    with torch.no_grad():
        for batched_graph, labels in dataloader:
            batched_graph = batched_graph.to(device)
            labels = labels.to(device)
            node_feats = batched_graph.ndata.pop("attr")
            scores = model(batched_graph, node_feats)
            # Predicted class = index of the largest logit.
            predicted.append(torch.argmax(scores, dim=1))
            expected.append(labels)

    y_pred = torch.cat(predicted, dim=-1).cpu().numpy()
    y_true = torch.cat(expected, dim=-1).cpu().numpy()
    report = classification_report(
        y_true, y_pred, output_dict=True, digits=4, zero_division=0
    )
    accuracy = report['accuracy']
    macro_f1 = report['macro avg']['f1-score']
    return {'accuracy': accuracy, 'macro avg f1': macro_f1}


def train(train_dataloader, val_dataloader, device, model, max_epochs=10000, patience=50):
    """Train ``model`` with AdamW and early stopping on validation accuracy.

    After every epoch the model is scored on both loaders with ``evaluate``.
    Training stops once validation accuracy has failed to improve for more
    than ``patience`` consecutive epochs, or after ``max_epochs`` epochs,
    whichever comes first.

    Args:
        train_dataloader: iterable of ``(batched_graph, labels)`` training batches.
        val_dataloader: iterable of ``(batched_graph, labels)`` validation batches.
        device: device the batches are moved to.
        model: module called as ``model(batched_graph, feat)``; updated in place.
        max_epochs: upper bound on the number of epochs.
        patience: number of non-improving epochs tolerated before stopping.

    Returns:
        ``(best_acc, best_f1)``: the best validation accuracy seen and the
        validation macro F1 of that same epoch. Always a tuple, also when the
        epoch budget runs out before early stopping triggers.
    """
    # loss function, optimizer
    loss_fcn = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters())

    best_acc = 0.
    best_f1 = 0.
    not_improved = 0
    for epoch in range(max_epochs):
        model.train()
        total_loss = 0.
        num_batches = 0
        for batched_graph, labels in train_dataloader:
            batched_graph = batched_graph.to(device)
            labels = labels.to(device)
            feat = batched_graph.ndata.pop("attr")
            logits = model(batched_graph, feat)
            loss = loss_fcn(logits, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1
        # An empty loader yields no batches; report NaN instead of crashing.
        mean_loss = total_loss / num_batches if num_batches else float("nan")

        train_result = evaluate(train_dataloader, device, model)
        val_result = evaluate(val_dataloader, device, model)
        # "\r" + end='' keeps the progress report on a single, overwritten line.
        print(
            f"\rEpoch {epoch:05d} | Loss {mean_loss:.4f} | "
            + f"Train Acc {train_result['accuracy']:.4f} | Val Acc {val_result['accuracy']:.4f} |" 
            + f"Train Macro F1 {train_result['macro avg f1']:.4f} | Val Macro F1 {val_result['macro avg f1']:.4f}",
            end = ''
        )

        if best_acc < val_result['accuracy']:
            best_acc = val_result['accuracy']
            best_f1 = val_result['macro avg f1']
            not_improved = 0
        else:
            not_improved += 1
        if not_improved > patience:
            break

    # Reached on early stop and on exhausting max_epochs alike, so callers
    # can always unpack the result.
    print(f"\rBest Val Acc {best_acc:.6f} | Best Val Macro F1 {best_f1:.6f}")
    return best_acc, best_f1
            
class GraphDataset(Dataset):
    """View onto a subset of a DGL ``GINDataset``.

    Args:
        name: dataset name handed to ``GINDataset``.
        self_loop: self-loop flag handed to ``GINDataset``.
        indices: positions in the full dataset that make up this subset.
    """

    def __init__(self, name:str, self_loop:bool, indices:list):
        # Positions of this subset's samples inside the full dataset.
        self.indices = indices
        # Full dataset, loaded through DGL.
        self.dataset = GINDataset(name, self_loop)

    def __len__(self):
        """Number of samples in the subset."""
        return len(self.indices)

    def __getitem__(self, idx):
        """Return ``(graph, label)`` for local position ``idx``.

        The label is reshaped to one element so a batch of labels can be
        joined with ``torch.cat``.
        """
        # Translate the subset-local position into the full-dataset position.
        raw_idx = self.indices[idx]
        sample = self.dataset[raw_idx]
        graph, label = sample
        return graph, label.reshape(1)

def collate_fn(data):
    """Merge a list of dataset samples into one batch.

    Args:
        data: sequence of ``(graph, label)`` pairs, each label a 1-D tensor.

    Returns:
        ``(batched_graph, labels)``: the graphs merged with ``dgl.batch`` and
        the labels concatenated in the same order.
    """
    graph_seq, label_seq = zip(*data)
    batched_graph = dgl.batch(graph_seq)
    batched_labels = torch.cat(label_seq)
    return batched_graph, batched_labels


# ---- Experiment configuration ----
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
name = "MUTAG" # "MUTAG", "PTC", "NCI1", "PROTEINS"
self_loop = True
raw_dataset = GINDataset(name, self_loop)
num_samples = len(raw_dataset)
in_size = raw_dataset.dim_nfeats
out_size = raw_dataset.gclasses
train_ratio = 0.6
val_ratio = 1-train_ratio  # everything not used for training is used for validation
n_train = int(train_ratio*num_samples)
hidden_size = 16
n_layers = 5
batch_size = 128
n_rounds = 5  # number of independent random train/val splits
seed = 0

# Seed the RNGs once so the random splits, the weight initialisation and the
# loader shuffling repeat when the notebook is re-run from a fresh kernel.
torch.manual_seed(seed)
np.random.seed(seed)

# ---- Repeated random-split evaluation ----
best_acc_list, best_f1_list = [], []
for n_round in range(n_rounds):
    # Fresh random permutation per round; plain Python ints match the
    # `indices: list` parameter of GraphDataset.
    shuffle_idx = torch.randperm(num_samples).tolist()
    train_idx = shuffle_idx[:n_train]
    val_idx = shuffle_idx[n_train:]

    train_dataset = GraphDataset(name, self_loop, train_idx)
    val_dataset = GraphDataset(name, self_loop, val_idx)

    train_dataloader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, collate_fn = collate_fn, 
        pin_memory=torch.cuda.is_available()
    )
    val_dataloader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False, collate_fn = collate_fn, 
        pin_memory=torch.cuda.is_available()
    )

    # A new model per round so no weights carry over between splits.
    model = GIN(in_size, hidden_size, out_size, n_layers).to(device)
    best_acc, best_f1 = train(train_dataloader, val_dataloader, device, model)
    best_acc_list.append(best_acc)
    best_f1_list.append(best_f1)

# Mean and standard deviation of the per-round best validation scores.
print(f"Avg. Accuracy: {np.mean(best_acc_list):.4f} ± {np.std(best_acc_list):.4f}")
print(f"Avg. Macro F1: {np.mean(best_f1_list):.4f} ± {np.std(best_f1_list):.4f}")

Best Val Acc 0.842105 | Best Val Macro F1 0.802597
Best Val Acc 0.842105 | Best Val Macro F1 0.821176
Best Val Acc 0.868421 | Best Val Macro F1 0.856387
Best Val Acc 0.815789 | Best Val Macro F1 0.786859
Best Val Acc 0.894737 | Best Val Macro F1 0.885110
Avg. Accuracy: 0.8526 ± 0.0268
Avg. Macro F1: 0.8304 ± 0.0358
