In [1]:
import torch
from torch_geometric.loader import DataLoader
import torch.optim as optim
import torch.nn.functional as F
import models

from tqdm import tqdm
import argparse
import time
import numpy as np

from torch_geometric.utils import dense_to_sparse, to_networkx
import subgraph

import importlib
importlib.reload(models)
importlib.reload(subgraph)

### importing OGB
from ogb.graphproppred import PygGraphPropPredDataset, Evaluator

In [2]:
def get_pair_info(batch):
    """Compute the neighbourhood pair data and the sparse adjacency of a batch.

    `batch` is a PyG mini-batch, e.g.
    DataBatch(edge_index=[2, 1826], edge_attr=[1826, 3], x=[847, 9], y=[32, 1],
              num_nodes=847, batch=[847], ptr=[33]).

    Note:
        The flags `edge_feat` and `edge_only` are read from module scope, not
        passed in, so both must be defined before this function is called.

    Returns:
        A pair `(nhbr_info, adj)`:
        * `nhbr_info` is whatever `subgraph.compute_nhbr_pair_data` returns
          for the batch converted to a networkx graph.
        * `adj` is a coalesced sparse COO tensor of shape
          (num_nodes, num_nodes) holding a 1 for each column of
          `batch.edge_index`.
    """
    edge_index = batch.edge_index
    num_nodes = batch.x.size(0)
    num_edges = edge_index.size(1)

    pair_data = subgraph.compute_nhbr_pair_data(
        to_networkx(batch), edge_index, edge_feat, edge_only
    )

    # One unit-weight entry per edge; coalesce() puts the tensor in canonical form.
    adjacency = torch.sparse_coo_tensor(
        edge_index, torch.ones(num_edges), (num_nodes, num_nodes)
    ).coalesce()
    return pair_data, adjacency


In [3]:
# Loss functions: binary cross-entropy on raw logits and mean-squared error.
cls_criterion = torch.nn.BCEWithLogitsLoss()
reg_criterion = torch.nn.MSELoss()

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# One dictionary per data split (train / validation / test).
# NOTE(review): presumably reserved for caching neighbourhood pair data per
# batch; nothing is stored in them at the moment.
train_nhbr_map, val_nhbr_map, test_nhbr_map = {}, {}, {}

def train(model, device, loader, optimizer, task_type):
    """Run one optimisation pass (one epoch) over `loader`.

    Args:
        model: Network invoked as
            `model(x, e, adj, (con, con_sct, not_con, not_con_sct), batch_idx)`.
        device: torch device the inputs are moved to before the forward pass.
        loader: Iterable yielding PyG mini-batches.
        optimizer: Optimizer stepping the parameters of `model`.
        task_type: Dataset task description. If it contains "classification"
            the loss is `cls_criterion` (BCE with logits); otherwise it is
            `reg_criterion` (MSE).

    Returns:
        None. The model parameters are updated in place.

    Batches that contain a single node, or whose last node belongs to graph 0
    (i.e. a one-graph batch), are skipped.
    """
    model.train()

    # The criterion only depends on the task, so pick it once.
    criterion = cls_criterion if "classification" in task_type else reg_criterion

    for batch in tqdm(loader, desc="Iteration"):
        # Skip degenerate batches (single node / single graph).
        if batch.x.shape[0] == 1 or batch.batch[-1] == 0:
            continue

        nhbr_info, adj = get_pair_info(batch)
        # get_pair_info's first result is expected to hold exactly four tensors.
        con, con_sct, not_con, not_con_sct = (t.to(device) for t in nhbr_info)
        adj = adj.to(device)
        x = batch.x.to(device)

        # Append one all-zero row after the real edge features. Its width
        # follows edge_attr instead of being hard-coded to 3, so datasets with
        # a different number of edge features work too.
        # NOTE(review): presumably the zero row is the feature vector used for
        # node pairs without an edge — confirm in models / subgraph.
        pad_row = torch.zeros((1, batch.edge_attr.shape[1]))
        e = torch.vstack((batch.edge_attr, pad_row)).to(device)

        batch_idx = batch.batch.to(device)
        pred = model(x, e, adj, (con, con_sct, not_con, not_con_sct), batch_idx)

        optimizer.zero_grad()
        ## ignore nan targets (unlabeled) when computing training loss.
        target = batch.y.to(device).to(torch.float32)
        is_labeled = target == target  # NaN != NaN, so unlabeled entries are False
        loss = criterion(pred.to(torch.float32)[is_labeled], target[is_labeled])
        loss.backward()
        optimizer.step()

def eval(model, device, loader, evaluator, split_type):
    """Score `model` on every batch of `loader` with the given evaluator.

    NOTE: this function shadows Python's built-in `eval`; the name is kept
    because existing callers use it.

    Args:
        model: Network invoked as
            `model(x, e, adj, (con, con_sct, not_con, not_con_sct), batch_idx)`.
        device: torch device the inputs are moved to before the forward pass.
        loader: Iterable yielding PyG mini-batches.
        evaluator: Object whose `eval({"y_true": ..., "y_pred": ...})` method
            computes the metrics.
        split_type: Name of the split ("train", "val" or anything else for
            test). Currently unused; kept so existing call sites stay valid.

    Returns:
        Whatever `evaluator.eval` returns for the concatenated targets and
        predictions (both passed as numpy arrays).

    Batches that contain a single node are skipped and therefore do not
    contribute to the metrics.
    """
    model.eval()
    y_true = []
    y_pred = []

    for batch in tqdm(loader, desc="Iteration"):
        if batch.x.shape[0] == 1:
            continue

        with torch.no_grad():
            nhbr_info, adj = get_pair_info(batch)
            # get_pair_info's first result is expected to hold exactly four tensors.
            con, con_sct, not_con, not_con_sct = (t.to(device) for t in nhbr_info)
            adj = adj.to(device)
            x = batch.x.to(device)

            # Append one all-zero row after the real edge features. Its width
            # follows edge_attr instead of being hard-coded to 3.
            # NOTE(review): presumably the zero row is the feature vector used
            # for node pairs without an edge — confirm in models / subgraph.
            pad_row = torch.zeros((1, batch.edge_attr.shape[1]))
            e = torch.vstack((batch.edge_attr, pad_row)).to(device)

            batch_idx = batch.batch.to(device)
            pred = model(x, e, adj, (con, con_sct, not_con, not_con_sct), batch_idx)

        # Targets are reshaped to match the prediction layout before collecting.
        y_true.append(batch.y.view(pred.shape).detach().cpu())
        y_pred.append(pred.detach().cpu())

    y_true = torch.cat(y_true, dim=0).numpy()
    y_pred = torch.cat(y_pred, dim=0).numpy()

    input_dict = {"y_true": y_true, "y_pred": y_pred}

    return evaluator.eval(input_dict)



In [None]:
# Training settings
# Every option has a default, so the cell runs unchanged inside a notebook.
parser = argparse.ArgumentParser(description='GNN baselines on ogbgmol* data with Pytorch Geometrics')
parser.add_argument('--device', type=int, default=0,
                    help='which gpu to use if any (default: 0)')
parser.add_argument('--gnn', type=str, default='gin-virtual',
                    help='GNN gin, gin-virtual, or gcn, or gcn-virtual (default: gin-virtual)')
parser.add_argument('--drop_ratio', type=float, default=0.5,
                    help='dropout ratio (default: 0.5)')
parser.add_argument('--num_layer', type=int, default=5,
                    help='number of GNN message passing layers (default: 5)')
parser.add_argument('--emb_dim', type=int, default=300,
                    help='dimensionality of hidden units in GNNs (default: 300)')
parser.add_argument('--batch_size', type=int, default=32,
                    help='input batch size for training (default: 32)')
parser.add_argument('--epochs', type=int, default=100,
                    help='number of epochs to train (default: 100)')
parser.add_argument('--num_workers', type=int, default=0,
                    help='number of workers (default: 0)')
parser.add_argument('--dataset', type=str, default="ogbg-molhiv",
                    help='dataset name (default: ogbg-molhiv)')
# type=float is required: without it a value given on the command line stays a
# string and is handed to the optimizer as the learning rate.
parser.add_argument('--lr', type=float, default=0.001, help='Initial learning rate.')
parser.add_argument('--weight_decay', type=float, default=0,
                    help='Weight decay (L2 loss on parameters).')
parser.add_argument('--edge_feat', action='store_true', default=False,
                    help='Inject edge features')
parser.add_argument('--edge_only', action='store_true', default=False,
                    help='Aggregate edge only pairs in neighbourhood.')
# parse_known_args() collects unrecognised argv entries in `unknown` instead of
# exiting, so extra arguments on the process command line are tolerated.
args, unknown = parser.parse_known_args()

# Module-level flags read by get_pair_info() and passed to the model.
edge_feat = args.edge_feat
edge_only = args.edge_only

### automatic dataloading and splitting
dataset = PygGraphPropPredDataset(name=args.dataset)
split_idx = dataset.get_idx_split()

### automatic evaluator. takes dataset name as input
evaluator = Evaluator(args.dataset)

# One loader per split; only the training loader shuffles its graphs.
train_loader = DataLoader(
    dataset[split_idx["train"]],
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=args.num_workers,
)
valid_loader = DataLoader(
    dataset[split_idx["valid"]],
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=args.num_workers,
)
test_loader = DataLoader(
    dataset[split_idx["test"]],
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=args.num_workers,
)

# Input width = raw node-feature columns + embedding size.
# TODO consider virtual node
model = models.GNN_graphprop(
    nfeat=dataset.data.x.shape[1] + args.emb_dim,
    nhid=args.emb_dim,
    nclass=1,  # BCE for molhiv
    nlayers=args.num_layer,
    dropout=args.drop_ratio,
    edge_feat=edge_feat,
).to(device)

optimizer = torch.optim.Adam(
    model.parameters(),
    lr=args.lr,
    weight_decay=args.weight_decay,
)

# Metric history, one entry per epoch and split.
train_curve, valid_curve, test_curve = [], [], []

for epoch in range(1, args.epochs + 1):
    print("=====Epoch {}".format(epoch))
    print('Training...')
    train(model, device, train_loader, optimizer, dataset.task_type)

    # Score all three splits after every epoch.
    print('Evaluating...')
    train_perf = eval(model, device, train_loader, evaluator, "train")
    valid_perf = eval(model, device, valid_loader, evaluator, "val")
    test_perf = eval(model, device, test_loader, evaluator, "test")
    print({'Train': train_perf, 'Validation': valid_perf, 'Test': test_perf})

    train_curve.append(train_perf[dataset.eval_metric])
    valid_curve.append(valid_perf[dataset.eval_metric])
    test_curve.append(test_perf[dataset.eval_metric])

# The best epoch is the argmax of the validation curve for classification
# tasks and the argmin otherwise; best_train follows the same direction.
maximize = 'classification' in dataset.task_type
best_val_epoch = np.argmax(np.array(valid_curve)) if maximize else np.argmin(np.array(valid_curve))
best_train = max(train_curve) if maximize else min(train_curve)

# Report validation and test scores at the best validation epoch.
print('Finished training!')
print('Best validation score: {}'.format(valid_curve[best_val_epoch]))
print('Test score: {}'.format(test_curve[best_val_epoch]))


=====Epoch 1
Training...


Iteration:   6%|███                                               | 62/1029 [00:02<00:32, 29.53it/s]