## Random Feature Example (PyG)

In [None]:
import argparse

from ogb.graphproppred import PygGraphPropPredDataset

import torch
from torch import nn
import torch.nn.functional as F

from torch_geometric.nn import global_mean_pool

from model import EGCNConv, EGINConv

from ogb.graphproppred.mol_encoder import AtomEncoder

from graph_pred import run_graph_pred

from utils import seed_everything

In [None]:
# Import the random-feature trick from gtrick
from gtrick import random_feature

### Define a Model

In [None]:
class EGNN(nn.Module):
    """Graph-level prediction model that appends a random feature to nodes.

    Node inputs are embedded with ``AtomEncoder``, extended by
    ``random_feature``, passed through ``num_layers`` convolution layers
    (``EGINConv`` or ``EGCNConv``), mean-pooled per graph and projected by a
    final linear layer.

    Args:
        hidden_channels: Width of the atom embedding. Every layer after the
            encoder works on ``hidden_channels + 1`` channels because of the
            extra random feature.
        out_channels: Size of the per-graph output.
        num_layers: Number of convolution layers; must be at least 1.
        dropout: Dropout probability used after each layer.
        conv_type: ``'gin'`` to use ``EGINConv`` or ``'gcn'`` to use
            ``EGCNConv``.
        mol: Flag forwarded as the second argument of the convolution
            layers; when it is false, batch norm is also applied after the
            last convolution. Defaults to ``True`` since the node encoder is
            always an ``AtomEncoder`` -- NOTE(review): confirm this default
            against the conv layers in ``model``.

    Raises:
        ValueError: If ``conv_type`` is not ``'gin'`` or ``'gcn'``, or if
            ``num_layers`` is smaller than 1.
    """

    def __init__(self, hidden_channels, out_channels, num_layers,
                 dropout, conv_type, mol=True):

        super(EGNN, self).__init__()

        # Fail early: an unknown conv type or an empty stack would otherwise
        # only surface as an IndexError inside forward().
        if conv_type not in ('gin', 'gcn'):
            raise ValueError(
                "conv_type must be 'gin' or 'gcn', got {!r}".format(conv_type))
        if num_layers < 1:
            raise ValueError(
                'num_layers must be at least 1, got {!r}'.format(num_layers))

        # Must be set before the conv layers are built: they receive it as
        # an argument, and forward() reads it as well.
        self.mol = mol

        self.node_encoder = AtomEncoder(hidden_channels)

        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()

        self.num_layers = num_layers

        # random_feature widens the hidden state by one channel, so every
        # layer after the encoder uses hidden_channels + 1.
        hidden_channels += 1

        conv_cls = EGINConv if conv_type == 'gin' else EGCNConv
        for _ in range(self.num_layers):
            self.convs.append(conv_cls(hidden_channels, self.mol))
            self.bns.append(torch.nn.BatchNorm1d(hidden_channels))

        self.dropout = dropout

        self.out = nn.Linear(hidden_channels, out_channels)

    def reset_parameters(self):
        """Re-initialise the atom embeddings, conv layers, norms and head."""
        for emb in self.node_encoder.atom_embedding_list:
            nn.init.xavier_uniform_(emb.weight.data)

        for conv, bn in zip(self.convs, self.bns):
            conv.reset_parameters()
            bn.reset_parameters()

        self.out.reset_parameters()

    def forward(self, batch_data):
        """Compute per-graph outputs for a batch.

        Args:
            batch_data: Object exposing ``x``, ``edge_index``, ``edge_attr``
                and ``batch`` attributes.

        Returns:
            The output of the final linear layer applied to the mean-pooled
            node states.
        """
        x = batch_data.x
        edge_index = batch_data.edge_index
        edge_attr = batch_data.edge_attr
        batch = batch_data.batch

        h = self.node_encoder(x)

        # Append the random feature to h. Presumably one random value per
        # node (one extra column) -- confirm against gtrick.random_feature.
        h = random_feature(h)

        # All layers but the last: conv -> batch norm -> ReLU -> dropout.
        for conv, bn in zip(self.convs[:-1], self.bns):
            h = conv(h, edge_index, edge_attr)
            h = bn(h)
            h = F.relu(h)
            h = F.dropout(h, p=self.dropout, training=self.training)

        # The last layer has no ReLU; its batch norm is skipped when mol is
        # set.
        h = self.convs[-1](h, edge_index, edge_attr)

        if not self.mol:
            h = self.bns[-1](h)

        h = F.dropout(h, self.dropout, training=self.training)

        h = global_mean_pool(h, batch)

        return self.out(h)

### Run Experiment

In [None]:
def _build_parser():
    """Create the command-line parser holding the experiment settings."""
    cli = argparse.ArgumentParser(
        description='train graph property prediction')

    # The two dataset options carry extra keywords, so they are spelled out.
    cli.add_argument('--dataset', type=str, default='ogbg-molhiv',
                     choices=['ogbg-molhiv', 'ogbg-ppa'])
    cli.add_argument('--dataset_path', type=str, default='/dev/dataset',
                     help='path to dataset')

    # Remaining options: (flag, type, default, help), in declaration order.
    plain_options = (
        ('--device', int, 1, None),
        ('--log_steps', int, 1, None),
        ('--num_layers', int, 4, None),
        ('--hidden_channels', int, 300, None),
        ('--dropout', float, 0.5, None),
        ('--lr', float, 1e-4, None),
        ('--batch_size', int, 64, 'batch size'),
        ('--num_workers', int, 0, 'number of workers (default: 0)'),
        ('--model', str, 'gcn', None),
        ('--epochs', int, 500, None),
        ('--runs', int, 3, None),
        ('--patience', int, 30, None),
    )
    for flag, kind, fallback, description in plain_options:
        cli.add_argument(flag, type=kind, default=fallback, help=description)

    return cli


parser = _build_parser()
# An empty argument list keeps the notebook on the defaults declared above.
args = parser.parse_args(args=[])
print(args)

seed_everything(3042)

In [None]:
# Instantiate the dataset named by --dataset under --dataset_path.
dataset = PygGraphPropPredDataset(
    root=args.dataset_path,
    name=args.dataset,
)

# Build the model from the parsed settings; the output size is the number
# of tasks reported by the dataset.
model = EGNN(
    hidden_channels=args.hidden_channels,
    out_channels=dataset.num_tasks,
    num_layers=args.num_layers,
    dropout=args.dropout,
    conv_type=args.model,
)

In [None]:
# Launch the experiment: hand the parsed settings, the model and the dataset
# to run_graph_pred (imported from graph_pred).
run_graph_pred(args, model, dataset)