In [1]:
import os

import pandas as pd
import numpy as np

import dgl
import dgl.function as fn
import torch as th

Using backend: pytorch


# Load graph

In [2]:
from ogb.nodeproppred import DglNodePropPredDataset, Evaluator



In [3]:
# Directory passed as `root` to the OGB dataset loader. The default is the
# path this notebook was originally run with, so existing setups keep working;
# set the OGB_DATASET_ROOT environment variable to use a different location.
DATASET_ROOT = os.environ.get("OGB_DATASET_ROOT", "/home/stanislas/dataset")

data = DglNodePropPredDataset(name='ogbn-proteins', root=DATASET_ROOT)
evaluator = Evaluator(name='ogbn-proteins')

# The split mapping is read through its "train" / "valid" / "test" entries.
splitted_idx = data.get_idx_split()
train_idx, val_idx, test_idx = splitted_idx["train"], splitted_idx["valid"], splitted_idx["test"]
# The dataset's item 0 unpacks into the graph and its label tensor.
graph, labels = data[0]

In [4]:
def one_hot_encoder(x, skip_index=4):
    """One-hot encode ``x`` against its own sorted unique values.

    Column ``j`` of the result corresponds to the ``j``-th entry of
    ``x.unique()``. Rows whose value maps to column ``skip_index`` are left
    all-zero, so that column is always empty.

    Args:
        x: Tensor holding exactly one scalar per row, e.g. shape ``(N,)`` or
            ``(N, 1)``.
        skip_index: Column (position in the unique-value ordering) that is
            never set. Defaults to 4, the value that was previously
            hard-coded; pass ``None`` to encode every category.
            NOTE(review): the reason category 4 is dropped is not visible in
            this notebook -- confirm against how the checkpoint was trained.

    Returns:
        Float tensor of shape ``(len(x), number_of_unique_values)``.

    Raises:
        ValueError: if ``x`` holds more than one value per row.
    """
    # `inverse` gives, for every element of x, its position within `ids`.
    ids, inverse = x.unique(return_inverse=True)
    inverse = inverse.reshape(-1)
    n_rows = len(x)
    if inverse.numel() != n_rows:
        raise ValueError("one_hot_encoder expects exactly one value per row")

    one_hot = th.zeros((n_rows, len(ids)), device=x.device)
    # Vectorised scatter: one indexed assignment instead of a Python loop
    # that calls .item() on every row.
    one_hot[th.arange(n_rows, device=x.device), inverse] = 1
    if skip_index is not None and 0 <= skip_index < len(ids):
        one_hot[:, skip_index] = 0

    return one_hot

In [5]:
# Node data stored on the graph under the key 'species'.
species = graph.ndata['species']
# One-hot encode it: one column per distinct value found in `species`.
features = one_hot_encoder(species)
# Stored as node feature 'feat'. Note that preprocess(graph, use_label=False)
# replaces ndata['feat'] with the aggregated 'feat_add', so these one-hot
# values are only kept in 'feat' when preprocess is called with use_label=True.
graph.ndata['feat'] = features

In [6]:
def preprocess(graph, use_label=False):
    """Build the node feature 'feat' from edge features, in place.

    The edge feature 'feat' is sum-aggregated onto the nodes as 'feat_add'
    through ``update_all``. Node feature 'feat' is then set to either
    'feat_add' alone, or 'feat_add' concatenated in front of the existing
    node 'feat' along dim 1 when ``use_label`` is true.

    Args:
        graph: graph carrying an edge feature 'feat' (and a node feature
            'feat' when ``use_label`` is true). It is modified in place.
        use_label: keep the existing node 'feat' by appending it after the
            aggregated edge features.

    Returns:
        The same ``graph`` object that was passed in.
    """
    # Copy each edge's 'feat' into message 'e', then sum messages per node.
    graph.update_all(fn.copy_e("feat", "e"), fn.sum("e", "feat_add"))

    aggregated = graph.ndata['feat_add']
    graph.ndata['feat'] = (
        th.cat((aggregated, graph.ndata['feat']), dim=1) if use_label else aggregated
    )

    graph.create_formats_()
    return graph

In [7]:
# preprocess() modifies `graph` in place and returns that same object, so
# `graph_prep` is an alias of `graph`, not a copy.
graph_prep = preprocess(graph, use_label=False)

# Load model

In [8]:
from gipa_model import *

# Hyper-parameters collected in one mapping so they can be inspected or logged.
# They must describe the same architecture as the checkpoint whose state dict
# is loaded into `model` afterwards.
gipa_hparams = dict(
    # input / embedding sizes
    n_node_feat=8,
    n_edge_feat=8,
    n_node_emb=80,
    n_edge_emb=16,
    # layer layout
    n_hiddens_att=[80],
    n_heads_att=8,
    n_hiddens_prop=[80],
    n_hiddens_agg=[],
    n_hiddens_deep=[],
    n_layers=6,
    n_classes=112,
    agg_type='sum',
    act_type='relu',
    # dropout rates
    edge_drop=0.1,
    dropout_node=0.1,
    dropout_att=0.1,
    dropout_prop=0.25,
    dropout_agg=0.25,
    dropout_deep=0.5,
)

model = GIPA(**gipa_hparams)

In [9]:
model_path = "./saved_models/gipa_protein_model.pt"
# map_location="cpu" lets a checkpoint that was saved from a GPU be
# deserialised on a CPU-only machine; load_state_dict then copies the values
# into the model's own parameters.
# NOTE(review): th.load unpickles the file -- only load checkpoints that come
# from a trusted source.
model.load_state_dict(th.load(model_path, map_location="cpu"))
# Switch to evaluation mode (turns the Dropout modules off). Kept as the last
# expression so the notebook displays the module summary.
model.eval()

GIPA(
  (node_emb): Linear(in_features=8, out_features=80, bias=False)
  (gipa_layers): ModuleList(
    (0): GIPAConv(
      (att_src_layers): ModuleList(
        (0): Linear(in_features=80, out_features=80, bias=False)
        (1): Linear(in_features=80, out_features=8, bias=False)
      )
      (att_dst_layers): ModuleList(
        (0): Linear(in_features=80, out_features=80, bias=False)
        (1): Linear(in_features=80, out_features=8, bias=False)
      )
      (att_edge_layers): ModuleList(
        (0): Linear(in_features=16, out_features=80, bias=False)
        (1): Linear(in_features=80, out_features=8, bias=False)
      )
      (src_prop_layers): ModuleList(
        (0): Linear(in_features=80, out_features=640, bias=True)
      )
      (dst_prop_layers): ModuleList(
        (0): Linear(in_features=80, out_features=640, bias=True)
      )
      (agg_layers): ModuleList()
      (dropout_att): Dropout(p=0.1, inplace=False)
      (dropout_prop): Dropout(p=0.25, inplace=False)
    

# Inference

In [10]:
# Inference only: no_grad() stops autograd from recording the forward pass,
# which saves memory. The module is called directly (not via .forward) so that
# nn.Module.__call__ runs and any registered hooks are honoured.
with th.no_grad():
    pred = model(graph_prep)

In [11]:
# ROC-AUC of the predictions restricted to the validation indices.
val_inputs = {"y_pred": pred[val_idx], "y_true": labels[val_idx]}
evaluator.eval(val_inputs)["rocauc"]

0.9189592164422503

In [12]:
# ROC-AUC of the predictions restricted to the test indices.
test_inputs = {"y_pred": pred[test_idx], "y_true": labels[test_idx]}
evaluator.eval(test_inputs)["rocauc"]

0.8705416020922279