## FLAG (Free Large-scale Adversarial Augmentation on Graphs) Example with PyTorch Geometric

In [None]:
import argparse

from ogb.nodeproppred import PygNodePropPredDataset, Evaluator

import torch
from torch import nn
import torch.nn.functional as F

from torch_geometric.nn import GCNConv, SAGEConv
import torch_geometric.transforms as T

from utils import Logger, EarlyStopping

In [None]:
import sys
sys.path.append('../../')
from gtrick import FLAG

### Define a Model

In [None]:
class GNN(torch.nn.Module):
    """Stack of GCN or GraphSAGE convolutions for node classification.

    Every layer except the last is followed by
    ``BatchNorm1d -> ReLU -> dropout``. The last convolution outputs raw
    scores of width ``out_channels`` with no activation.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Width of every hidden layer.
        out_channels: Width of the final layer (e.g. the number of classes).
        num_layers: Number of convolution layers; must be at least 1.
        dropout: Dropout probability applied after each hidden layer.
        conv_type: ``'gcn'`` to use ``GCNConv`` or ``'sage'`` to use
            ``SAGEConv``.

    Raises:
        ValueError: If ``conv_type`` is not ``'gcn'``/``'sage'`` or
            ``num_layers`` is smaller than 1.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
                 dropout, conv_type):
        super(GNN, self).__init__()

        # Validate eagerly: otherwise an unknown conv type or zero layers
        # silently builds an empty conv stack that only fails in forward().
        if conv_type not in ('gcn', 'sage'):
            raise ValueError(
                f"unsupported conv_type {conv_type!r}; expected 'gcn' or 'sage'")
        if num_layers < 1:
            raise ValueError(f'num_layers must be >= 1, got {num_layers}')

        self.convs = torch.nn.ModuleList()
        self.bns = torch.nn.ModuleList()

        for i in range(num_layers):
            is_last = i == num_layers - 1
            # The first layer reads the raw features and the last layer emits
            # out_channels. Both rules also apply when num_layers == 1.
            layer_in = in_channels if i == 0 else hidden_channels
            layer_out = out_channels if is_last else hidden_channels

            if conv_type == 'gcn':
                # cached=True: PyG caches the normalized graph after the first
                # call, which assumes one fixed graph (transductive setting).
                conv = GCNConv(layer_in, layer_out, cached=True)
            else:
                conv = SAGEConv(layer_in, layer_out)
            self.convs.append(conv)

            # Hidden layers get a BatchNorm; the output layer does not.
            if not is_last:
                self.bns.append(torch.nn.BatchNorm1d(hidden_channels))

        self.dropout = dropout

    def reset_parameters(self):
        """Re-initialize every convolution and BatchNorm layer."""
        for conv in self.convs:
            conv.reset_parameters()
        for bn in self.bns:
            bn.reset_parameters()

    def forward(self, x, adj_t, perturb=None):
        """Compute per-node output scores.

        Args:
            x: Node feature matrix, assumed ``(num_nodes, in_channels)``.
            adj_t: Graph connectivity, passed unchanged to every conv layer.
            perturb: Optional FLAG perturbation added to ``x``. It must be
                broadcast-compatible with ``x``.

        Returns:
            Output of the last convolution (no activation applied).
        """
        # Out-of-place add on purpose: ``x += perturb`` would write into the
        # caller's feature tensor (e.g. ``data.x``) in place.
        if perturb is not None:
            x = x + perturb

        for i, conv in enumerate(self.convs[:-1]):
            x = conv(x, adj_t)
            x = self.bns[i](x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.convs[-1](x, adj_t)
        return x

### Define the Training Process

In [None]:
# pass flag to train func
def train(model, data, train_idx, flag):
    """Run one FLAG adversarial-training step on the training nodes.

    Args:
        model: Network whose forward accepts ``(x, adj_t, perturb)``.
        data: Graph object exposing ``x``, ``adj_t`` and ``y``.
        train_idx: Indices of the nodes used for the loss.
        flag: Callable ``flag(model, forward, num_nodes, y)`` that returns
            ``(loss, out)``. Presumably it performs the perturbation ascent
            and the optimizer update (gtrick ``FLAG``); verify against it.

    Returns:
        The training loss as a Python float.
    """
    # test() switches the model to eval mode every epoch. Re-enable training
    # mode explicitly so dropout/BatchNorm behave correctly during the update,
    # instead of relying on FLAG to do it.
    model.train()

    # Drop the trailing label dimension; assumes OGB-style (num_nodes, 1)
    # labels.
    y = data.y[train_idx].squeeze(1)

    def forward(perturb):
        # Full-graph pass; only the training rows feed the loss.
        return model(data.x, data.adj_t, perturb)[train_idx]

    # The third argument is the node count; presumably FLAG sizes the
    # perturbation as (num_nodes, in_feats).
    loss, _ = flag(model, forward, data.x.shape[0], y)

    return loss.item()

In [None]:
@torch.no_grad()
def test(model, data, split_idx, evaluator, eval_metric):
    model.eval()

    y = data.y
    out = model(data.x, data.adj_t)
    y_pred = out.argmax(dim=-1, keepdim=True)

    train_metric = evaluator.eval({
        'y_true': y[split_idx['train']],
        'y_pred': y_pred[split_idx['train']],
    })[eval_metric]
    valid_metric = evaluator.eval({
        'y_true': y[split_idx['valid']],
        'y_pred': y_pred[split_idx['valid']],
    })[eval_metric]
    test_metric = evaluator.eval({
        'y_true': y[split_idx['test']],
        'y_pred': y_pred[split_idx['test']],
    })[eval_metric]

    return train_metric, valid_metric, test_metric


In [None]:
def run_node_pred(args, model, dataset):
    """Train and evaluate ``model`` with FLAG for ``args.runs`` independent runs.

    Args:
        args: Parsed options; reads ``device``, ``dataset``, ``runs``, ``lr``,
            ``patience``, ``epochs`` and ``log_steps``.
        model: Node-classification network exposing ``reset_parameters()``.
        dataset: OGB node-property-prediction dataset (``task_type``,
            ``eval_metric``, ``get_idx_split()``, ``dataset[0]``).

    Raises:
        ValueError: If ``dataset.task_type`` is neither binary nor multiclass
            classification.
    """
    # Choose the loss first. It does not depend on the run, and an
    # unsupported task type should fail before any data is moved. Previously
    # it left ``loss_func`` unbound.
    if dataset.task_type == 'binary classification':
        loss_func = nn.BCEWithLogitsLoss()
    elif dataset.task_type == 'multiclass classification':
        loss_func = nn.CrossEntropyLoss()
    else:
        raise ValueError(
            f'unsupported task type {dataset.task_type!r}; expected '
            "'binary classification' or 'multiclass classification'")

    device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)

    model.to(device)

    evaluator = Evaluator(name=args.dataset)

    data = dataset[0]
    # Symmetrize the sparse adjacency before moving the graph to the device.
    data.adj_t = data.adj_t.to_symmetric()
    data = data.to(device)

    split_idx = dataset.get_idx_split()
    # Keep the training indices on the same device as the data they index.
    train_idx = split_idx['train'].to(device)

    logger = Logger(args.runs, mode='max')

    for run in range(args.runs):
        model.reset_parameters()
        # A fresh optimizer per run so optimizer state is not shared across runs.
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
        early_stopping = EarlyStopping(
            patience=args.patience, verbose=True, mode='max')

        # FLAG is bound to this run's optimizer; params: in_feats, loss_func,
        # optimizer.
        flag = FLAG(data.x.shape[1], loss_func, optimizer)

        for epoch in range(1, 1 + args.epochs):
            loss = train(model, data, train_idx, flag)
            result = test(model, data, split_idx,
                          evaluator, dataset.eval_metric)
            logger.add_result(run, result)

            train_acc, valid_acc, test_acc = result

            if epoch % args.log_steps == 0:
                print(f'Run: {run + 1:02d}, '
                      f'Epoch: {epoch:02d}, '
                      f'Loss: {loss:.4f}, '
                      f'Train: {100 * train_acc:.2f}%, '
                      f'Valid: {100 * valid_acc:.2f}% '
                      f'Test: {100 * test_acc:.2f}%')

            # Presumably EarlyStopping returns a truthy value once validation
            # has not improved for ``patience`` epochs — verify in utils.
            if early_stopping(valid_acc, model):
                break

        logger.print_statistics(run)
    logger.print_statistics()

### Run Experiment

In [None]:
# Experiment configuration. parse_args(args=[]) deliberately ignores the real
# command line so the defaults below are used when running the notebook.
parser = argparse.ArgumentParser(
    description='train node property prediction')
parser.add_argument("--dataset", type=str, default="ogbn-arxiv",
                    choices=["ogbn-arxiv"])
parser.add_argument("--dataset_path", type=str, default="/dev/dataset",
                    help="path to dataset")
parser.add_argument('--device', type=int, default=0)
parser.add_argument('--log_steps', type=int, default=1)
# Restrict to the convolution types GNN implements, so a typo is rejected
# here instead of producing a broken model.
parser.add_argument('--model', type=str, default='sage',
                    choices=['gcn', 'sage'])
parser.add_argument('--num_layers', type=int, default=3)
parser.add_argument('--hidden_channels', type=int, default=256)
parser.add_argument('--dropout', type=float, default=0.5)
parser.add_argument('--lr', type=float, default=0.01)
parser.add_argument('--epochs', type=int, default=500)
parser.add_argument('--runs', type=int, default=1)
parser.add_argument('--patience', type=int, default=10)
args = parser.parse_args(args=[])
print(args)

In [None]:
# Load the OGB dataset. ToSparseTensor makes each graph carry its
# connectivity as ``adj_t``, which the model and training code consume.
dataset = PygNodePropPredDataset(
    name=args.dataset,
    transform=T.ToSparseTensor(),
    root=args.dataset_path,
)
data = dataset[0]

# The first convolution's input width is the per-node feature dimension.
num_features = data.x.size(1)

model = GNN(
    in_channels=num_features,
    hidden_channels=args.hidden_channels,
    out_channels=dataset.num_classes,
    num_layers=args.num_layers,
    dropout=args.dropout,
    conv_type=args.model,
)

In [None]:
# Train and evaluate the configured model on the loaded dataset.
run_node_pred(args=args, model=model, dataset=dataset)