## Virtual Node Example

### Define a model

In [1]:
# Add the parent of the current working directory to the import search path.
# NOTE(review): presumably this is what makes the local `gtrick` package
# importable in the next cell — confirm against the repository layout.
import sys 
sys.path.append("..")

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from dgl.nn import AvgPooling
import dgl.function as fn
from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder

from gtrick import VirtualNode

In [3]:
class EGINConv(nn.Module):
    '''
        GIN-style convolution whose messages also carry edge embeddings.

        Each edge message is ``relu(source_node_feature + edge_embedding)``;
        messages are summed per destination node, added to
        ``(1 + eps) * x`` and passed through a two-layer MLP.
    '''

    def __init__(self, emb_dim):
        '''
            emb_dim (int): node embedding dimensionality; also used as the
                input, hidden and output width of the MLP.
        '''
        super().__init__()

        # Built in the same order as the final Sequential so parameter
        # initialisation consumes the RNG identically.
        layers = [
            nn.Linear(emb_dim, emb_dim),
            nn.BatchNorm1d(emb_dim),
            nn.ReLU(),
            nn.Linear(emb_dim, emb_dim),
        ]
        self.mlp = nn.Sequential(*layers)

        # Learnable scalar weighting each node's own feature; starts at 0.
        self.eps = nn.Parameter(torch.Tensor([0]))

    def reset_parameters(self):
        '''
            Re-initialise every MLP layer that supports it and set eps to 0.
        '''
        for layer in self.mlp:
            reset = getattr(layer, 'reset_parameters', None)
            if reset is not None:
                reset()
        nn.init.constant_(self.eps.data, 0)

    def forward(self, g, x, edge_embedding):
        '''
            g: DGL graph to propagate on (left unmodified; all work happens
                inside a local scope).
            x: node features, one row per node of ``g``.
            edge_embedding: edge features, added element-wise to the
                source-node feature copied onto each edge.

            Returns the updated node features produced by the MLP.
        '''
        with g.local_scope():
            g.ndata['x'] = x

            # Put each source node's feature on its outgoing edges.
            g.apply_edges(fn.copy_u('x', 'm'))

            # Mix in the edge embedding, then apply the non-linearity.
            messages = g.edata['m'] + edge_embedding
            g.edata['m'] = F.relu(messages)

            # Sum the incoming messages at every destination node.
            g.update_all(fn.copy_e('m', 'm'), fn.sum('m', 'new_x'))
            aggregated = g.ndata['new_x']

            self_term = (1 + self.eps) * x
            return self.mlp(self_term + aggregated)

In [9]:
class EGIN(nn.Module):
    '''
        Edge-aware GIN for graph-level prediction with virtual-node layers.

        Node and edge inputs are embedded by ``AtomEncoder`` / ``BondEncoder``.
        Every convolution except the last is wrapped in ``VirtualNode``; the
        final node states are pooled per graph with ``AvgPooling`` and
        projected by a linear output layer.

        hidden_channels (int): width of the embeddings and of every conv
        out_channels (int): size of the per-graph output
        num_layers (int): total number of ``EGINConv`` layers; the first
            ``num_layers - 1`` are wrapped in ``VirtualNode``
        dropout (float): forwarded to each ``VirtualNode`` wrapper and stored
            on the module; it is not applied directly in ``forward``
        task_type: stored on the module; not used inside this class
    '''

    def __init__(self, hidden_channels, out_channels, num_layers,
                 dropout, task_type):

        super(EGIN, self).__init__()

        self.node_encoder = AtomEncoder(hidden_channels)
        self.edge_encoder = BondEncoder(hidden_channels)

        self.convs = nn.ModuleList()

        # All but the last layer exchange information with a virtual node.
        for _ in range(num_layers-1):
            self.convs.append(
                VirtualNode(
                    EGINConv(hidden_channels),
                    hidden_channels,
                    hidden_channels,
                    dropout
                )
            )

        # The last layer is a plain convolution.
        self.convs.append(
            EGINConv(hidden_channels)
        )

        self.dropout = dropout

        self.task_type = task_type

        self.pool = AvgPooling()

        self.out = nn.Linear(hidden_channels, out_channels)

    @staticmethod
    def _reset_encoder(encoder):
        '''
            Re-initialise an input encoder in place.

            Uses the encoder's own ``reset_parameters`` when it has one;
            otherwise every ``nn.Embedding`` inside it is re-drawn.
        '''
        reset = getattr(encoder, 'reset_parameters', None)
        if callable(reset):
            reset()
            return

        # NOTE(review): Xavier-uniform is assumed to match how the OGB
        # encoders initialise their embedding tables — confirm against the
        # installed ogb version.
        for module in encoder.modules():
            if isinstance(module, nn.Embedding):
                nn.init.xavier_uniform_(module.weight.data)

    def reset_parameters(self):
        '''
            Re-initialise the encoders, every convolution and the output
            layer.

            The encoders are included so that weights learned before the
            reset cannot carry over into whatever training follows it.
        '''
        self._reset_encoder(self.node_encoder)
        self._reset_encoder(self.edge_encoder)

        for conv in self.convs:
            conv.reset_parameters()

        self.out.reset_parameters()

    def forward(self, g, x, ex):
        '''
            g: (batched) DGL graph
            x: raw node features fed to the node encoder
            ex: raw edge features fed to the edge encoder

            Returns the output of the linear head applied to the pooled
            node states (presumably one row per graph in ``g`` — confirm
            against ``AvgPooling``).
        '''
        h = self.node_encoder(x)
        eh = self.edge_encoder(ex)

        for i, conv in enumerate(self.convs[:-1]):
            if i == 0:
                # First wrapped layer: no virtual-node state exists yet.
                h, vh = conv(g, h, eh)
            else:
                # Later layers receive the state returned by the previous one.
                h, vh = conv(g, h, eh, vh)

        h = self.convs[-1](g, h, eh)

        h = self.pool(g, h)
        h = self.out(h)

        return h

### Graph Property Prediction

In [5]:
import argparse
from ogb.graphproppred import DglGraphPropPredDataset
from graph_pred import run_graph_pred

In [6]:
# Shell escape (IPython only): print the GPU status of this machine before training.
!nvidia-smi

Wed Apr 20 03:29:48 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 450.142.00   Driver Version: 450.142.00   CUDA Version: 11.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            On   | 00000000:00:1B.0 Off |                    0 |
| N/A   42C    P8    15W /  70W |      3MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  Tesla T4            On   | 00000000:00:1C.0 Off |                    0 |
| N/A   41C    P8    15W /  70W |      3MiB / 15109MiB |      0%      Default |
|       

In [7]:
import os

# Experiment settings are declared through argparse so the notebook shares
# its option names and defaults with a command-line style entry point.
parser = argparse.ArgumentParser(
    description='train graph property prediction')
parser.add_argument("--dataset", type=str, default="ogbg-molhiv",
                    choices=["ogbg-molhiv"])
# Default to a directory under the current user's home instead of a
# machine-specific absolute path, so the notebook runs for any user.
parser.add_argument("--dataset_path", type=str,
                    default=os.path.join(os.path.expanduser("~"), ".dgl_dataset"),
                    help="path to dataset")
parser.add_argument('--device', type=int, default=0)
parser.add_argument('--log_steps', type=int, default=1)
parser.add_argument('--num_layers', type=int, default=5)
parser.add_argument('--hidden_channels', type=int, default=256)
parser.add_argument('--dropout', type=float, default=0.5)
parser.add_argument('--lr', type=float, default=0.001)
parser.add_argument("--batch_size", type=int, default=128,
                    help="batch size")
parser.add_argument('--num_workers', type=int, default=0,
                    help='number of workers (default: 0)')
parser.add_argument('--epochs', type=int, default=500)
parser.add_argument('--runs', type=int, default=5)
parser.add_argument('--patience', type=int, default=30)

# Parse an explicit empty argument list so argparse does not read the
# kernel's own command line; every option therefore takes its default.
args = parser.parse_args(args=[])
print(args)

Namespace(batch_size=128, dataset='ogbg-molhiv', dataset_path='/home/ubuntu/.dgl_dataset', device=0, dropout=0.5, epochs=500, hidden_channels=256, log_steps=1, lr=0.001, num_layers=5, num_workers=0, patience=30, runs=5)


In [10]:
# Load the dataset selected in `args`.
dataset = DglGraphPropPredDataset(root=args.dataset_path, name=args.dataset)

# Build the model; its output width follows the dataset's number of tasks.
model = EGIN(
    hidden_channels=args.hidden_channels,
    out_channels=dataset.num_tasks,
    num_layers=args.num_layers,
    dropout=args.dropout,
    task_type=dataset.task_type,
)

In [11]:
# Train and evaluate `model` on `dataset` using the settings in `args`.
# NOTE(review): `run_graph_pred` comes from the local `graph_pred` module, which
# is not shown here; the log below suggests it reports per-epoch loss and
# train/valid/test scores and applies early stopping — confirm in that module.
run_graph_pred(args, model, dataset)

Run: 01, Epoch: 01, Loss: 0.0420, Train: 0.6862, Valid: 0.7034 Test: 0.6634
Run: 01, Epoch: 02, Loss: 0.0357, Train: 0.7200, Valid: 0.7163 Test: 0.6742
Run: 01, Epoch: 03, Loss: 0.0348, Train: 0.7432, Valid: 0.7408 Test: 0.7326
Run: 01, Epoch: 04, Loss: 0.0336, Train: 0.7460, Valid: 0.6760 Test: 0.6520
EarlyStopping counter: 1 out of 30

Run: 01, Epoch: 05, Loss: 0.0331, Train: 0.7687, Valid: 0.7180 Test: 0.7107
EarlyStopping counter: 2 out of 30

Run: 01, Epoch: 06, Loss: 0.0327, Train: 0.7622, Valid: 0.7355 Test: 0.6883
EarlyStopping counter: 3 out of 30

Run: 01, Epoch: 07, Loss: 0.0322, Train: 0.7805, Valid: 0.7394 Test: 0.7261
EarlyStopping counter: 4 out of 30

Run: 01, Epoch: 08, Loss: 0.0331, Train: 0.7654, Valid: 0.7443 Test: 0.6877
Run: 01, Epoch: 09, Loss: 0.0323, Train: 0.7457, Valid: 0.7680 Test: 0.7487
Run: 01, Epoch: 10, Loss: 0.0318, Train: 0.7777, Valid: 0.7974 Test: 0.7359
Run: 01, Epoch: 11, Loss: 0.0312, Train: 0.7791, Valid: 0.7692 Test: 0.7410
EarlyStopping counte