## Virtual Node Example

### Define a model

In [1]:
import sys 
# Put the parent directory on the import path so sibling modules there can be
# imported — presumably where `graph_pred` (imported in a later cell) lives;
# assumes the notebook is launched from a subdirectory of the repo.
sys.path.append("..")

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from dgl.nn import AvgPooling
import dgl.function as fn
from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder

from gtrick import VirtualNode

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
class EGCNConv(nn.Module):
    """Symmetric-normalised graph convolution that mixes in edge embeddings.

    With ``W`` the inner linear projection, ``deg`` the out-degree plus one
    and ``r`` a learned root vector:

    * each edge (u, v) sends ``relu(W x_u + e_uv) / sqrt(deg_u * deg_v)``;
    * messages are summed at the destination node;
    * every node adds its own term ``relu(W x_v + r) / deg_v``.

    The output width equals ``in_channels``.
    """

    def __init__(self, in_channels):
        super().__init__()
        # Registration order (linear, then root_emb) fixes both the parameter
        # order and the order in which initial weights are drawn.
        self.linear = nn.Linear(in_channels, in_channels)
        # A single learnable vector added to each node's self (root) term.
        self.root_emb = nn.Embedding(1, in_channels)

    def reset_parameters(self):
        """Re-initialise the projection and the root embedding."""
        for submodule in (self.linear, self.root_emb):
            submodule.reset_parameters()

    def forward(self, g, x, edge_embedding):
        """Run one round of message passing over ``g``.

        Args:
            g: DGL graph; assumed undirected so out-degree equals in-degree.
            x: node features, presumably ``(num_nodes, in_channels)``.
            edge_embedding: per-edge features added to the projected source
                features, presumably ``(num_edges, in_channels)``.

        Returns:
            Updated node features. Temporary node/edge fields are written
            inside ``g.local_scope()`` and so are not kept on ``g``.
        """
        with g.local_scope():
            h = self.linear(x)

            # +1 accounts for the implicit self-loop; out-degree is used in
            # place of in-degree because molecular graphs are undirected.
            deg = (g.out_degrees().float() + 1).to(h.device)
            g.ndata['norm'] = torch.pow(deg, -0.5).unsqueeze(-1)  # (N, 1)
            g.ndata['x'] = h

            # Per-edge normaliser deg_u^-1/2 * deg_v^-1/2, then source features.
            g.apply_edges(fn.u_mul_v('norm', 'norm', 'norm'))
            g.apply_edges(fn.copy_u('x', 'm'))

            edge_msg = F.relu(g.edata['m'] + edge_embedding)
            g.edata['m'] = g.edata['norm'] * edge_msg
            g.update_all(fn.copy_e('m', 'm'), fn.sum('m', 'new_x'))

            self_term = F.relu(h + self.root_emb.weight) * 1. / deg.view(-1, 1)
            return g.ndata['new_x'] + self_term

In [4]:
class EGCN(nn.Module):
    """Stack of edge-aware GCN layers with virtual nodes and mean pooling.

    The first ``num_layers - 1`` convolutions are ``EGCNConv`` layers wrapped
    in gtrick's ``VirtualNode``; the final convolution is a plain
    ``EGCNConv``. Node states are average-pooled per graph and mapped to
    ``out_channels`` outputs by a linear head.

    Args:
        hidden_channels: width of the atom/bond embeddings and hidden layers.
        out_channels: number of prediction targets (e.g. ``dataset.num_tasks``).
        num_layers: total number of convolution layers, including the last
            plain one.
        dropout: forwarded to every ``VirtualNode`` wrapper; not read by
            ``forward`` directly.
        task_type: stored on the model as ``self.task_type``; not read inside
            this class (presumably consumed by the training helper).
    """

    def __init__(self, hidden_channels, out_channels, num_layers,
                 dropout, task_type):

        super(EGCN, self).__init__()

        self.node_encoder = AtomEncoder(hidden_channels)
        self.edge_encoder = BondEncoder(hidden_channels)

        self.convs = nn.ModuleList()

        # All but the last layer exchange information through a virtual node.
        # NOTE(review): meaning of the two hidden_channels args and dropout
        # follows gtrick's VirtualNode signature — confirm against gtrick.
        for _ in range(num_layers-1):
            self.convs.append(
                VirtualNode(
                    EGCNConv(hidden_channels),
                    hidden_channels,
                    hidden_channels,
                    dropout
                )
            )

        self.convs.append(
            EGCNConv(hidden_channels)
        )

        # Only forwarded to the VirtualNode wrappers above; forward() does
        # not apply dropout itself.
        self.dropout = dropout

        self.task_type = task_type

        self.pool = AvgPooling()

        self.out = nn.Linear(hidden_channels, out_channels)

    def reset_parameters(self):
        """Re-initialise every learnable parameter, encoders included.

        The atom/bond encoders do not appear to provide their own
        ``reset_parameters``, so their ``nn.Embedding`` tables are re-drawn
        here with Xavier-uniform init (the init OGB's encoders use at
        construction — verify against the installed ogb version). Without
        this, a reset model would keep encoder weights learned in a previous
        run, and repeated runs would not start from scratch.
        """
        for encoder in (self.node_encoder, self.edge_encoder):
            for module in encoder.modules():
                if isinstance(module, nn.Embedding):
                    nn.init.xavier_uniform_(module.weight)

        for conv in self.convs:
            conv.reset_parameters()

        self.out.reset_parameters()

    def forward(self, g, x, ex):
        """Compute graph-level predictions.

        Args:
            g: (batched) DGL graph.
            x: raw atom features, as consumed by ``AtomEncoder``.
            ex: raw bond features, as consumed by ``BondEncoder``.

        Returns:
            One row per graph in ``g`` with ``out_channels`` columns,
            presumably raw scores (no activation is applied here).
        """
        h = self.node_encoder(x)
        eh = self.edge_encoder(ex)

        # Each VirtualNode returns (node_states, virtual_node_states). The
        # first wrapper creates the virtual-node state; later wrappers take
        # the previous one as an extra argument.
        for i, conv in enumerate(self.convs[:-1]):
            if i == 0:
                h, vh = conv(g, h, eh)
            else:
                h, vh = conv(g, h, eh, vh)

        h = self.convs[-1](g, h, eh)

        h = self.pool(g, h)
        h = self.out(h)

        return h

### Graph Property Prediction

In [5]:
import argparse
from ogb.graphproppred import DglGraphPropPredDataset
from graph_pred import run_graph_pred

In [6]:
import os

# Hyper-parameters for the ogbg-molhiv run. parse_args(args=[]) ignores the
# notebook kernel's own argv and uses the defaults below.
parser = argparse.ArgumentParser(
    description='train graph property prediction')
parser.add_argument("--dataset", type=str, default="ogbg-molhiv",
                    choices=["ogbg-molhiv"])
# Resolve the dataset directory against the current user's home instead of
# hard-coding one machine's path. On the original /home/ubuntu host this
# resolves to the same /home/ubuntu/.dgl_dataset directory.
parser.add_argument("--dataset_path", type=str,
                    default=os.path.expanduser("~/.dgl_dataset"),
                    help="path to dataset")
parser.add_argument('--device', type=int, default=0)
parser.add_argument('--log_steps', type=int, default=1)
parser.add_argument('--num_layers', type=int, default=5)
parser.add_argument('--hidden_channels', type=int, default=256)
parser.add_argument('--dropout', type=float, default=0.5)
parser.add_argument('--lr', type=float, default=0.001)
parser.add_argument("--batch_size", type=int, default=128,
                    help="batch size")
parser.add_argument('--num_workers', type=int, default=0,
                    help='number of workers (default: 0)')
parser.add_argument('--epochs', type=int, default=500)
parser.add_argument('--runs', type=int, default=5)
parser.add_argument('--patience', type=int, default=30)
args = parser.parse_args(args=[])
print(args)

Namespace(batch_size=128, dataset='ogbg-molhiv', dataset_path='/home/ubuntu/.dgl_dataset', device=0, dropout=0.5, epochs=500, hidden_channels=256, log_steps=1, lr=0.001, num_layers=5, num_workers=0, patience=30, runs=5)


In [7]:
# Load the OGB graph-property dataset named by --dataset, rooted at
# --dataset_path.
dataset = DglGraphPropPredDataset(name=args.dataset, root=args.dataset_path)

# One output per prediction task; the task type is stored on the model.
model = EGCN(
    hidden_channels=args.hidden_channels,
    out_channels=dataset.num_tasks,
    num_layers=args.num_layers,
    dropout=args.dropout,
    task_type=dataset.task_type,
)

In [8]:
# Train and evaluate via graph_pred.run_graph_pred (defined outside this
# notebook). Judging from the log below, it prints per-epoch loss and
# Train/Valid/Test scores for each run and tracks an early-stopping counter
# up to --patience (30).
run_graph_pred(args, model, dataset)

Run: 01, Epoch: 01, Loss: 0.0405, Train: 0.6430, Valid: 0.6484 Test: 0.6169
Run: 01, Epoch: 02, Loss: 0.0354, Train: 0.7084, Valid: 0.6896 Test: 0.6567
Run: 01, Epoch: 03, Loss: 0.0345, Train: 0.7323, Valid: 0.7595 Test: 0.7077
Run: 01, Epoch: 04, Loss: 0.0335, Train: 0.7202, Valid: 0.7341 Test: 0.6924
EarlyStopping counter: 1 out of 30

Run: 01, Epoch: 05, Loss: 0.0329, Train: 0.7493, Valid: 0.7755 Test: 0.6907
Run: 01, Epoch: 06, Loss: 0.0337, Train: 0.7503, Valid: 0.6792 Test: 0.6728
EarlyStopping counter: 1 out of 30

Run: 01, Epoch: 07, Loss: 0.0335, Train: 0.7706, Valid: 0.7489 Test: 0.6759
EarlyStopping counter: 2 out of 30

Run: 01, Epoch: 08, Loss: 0.0316, Train: 0.7754, Valid: 0.7552 Test: 0.7172
EarlyStopping counter: 3 out of 30

Run: 01, Epoch: 09, Loss: 0.0315, Train: 0.7795, Valid: 0.7681 Test: 0.6830
EarlyStopping counter: 4 out of 30

Run: 01, Epoch: 10, Loss: 0.0313, Train: 0.7760, Valid: 0.7466 Test: 0.7373
EarlyStopping counter: 5 out of 30

Run: 01, Epoch: 11, Loss