In [1]:
import sys
sys.path.append('../src/')

In [2]:
import os
import os.path as osp

import torch
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score

import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid

from torch_geometric import seed_everything

import numpy as np
import pandas as pd

import networkx as nx
from torch_geometric.utils.convert import from_networkx

from tqdm import tqdm

from models import LinkGIN, LinkGCN, LinkSAGE, DeepVGAE
from decoders import InnerProductDecoder, CosineDecoder
from explainers import gnnexplainer, ig, deconvolution, backprop
from metrics import deletion_curve_edges, deletion_curve_features, linear_area_score
from utils import ws_graph_model, sbm_graph_model, get_computation_graph_as_nx
from utils import get_explanation
from plotting import visualize_explanation

from torch_geometric.utils import k_hop_subgraph

from matplotlib import pyplot as plt
import seaborn as sb

from torch_geometric.nn import APPNP, MessagePassing

Process the dataset from OGB (Open Graph Benchmark)

https://ogb.stanford.edu/docs/linkprop/

Reference: 

https://github.com/snap-stanford/ogb/blob/master/examples/linkproppred/collab/gnn.py

https://medium.com/@xjurajkmec/pairwise-learning-for-neural-link-prediction-f1d16a0d28f6

https://github.com/lustoo/OGB_link_prediction/blob/main/DDI/link_pred_ddi_graphsage_edge.py#L173


图基准数据集（Open Graph Benchmark, OGB）
https://zhuanlan.zhihu.com/p/358995864?utm_id=0

1. ogbl-collab
- Nodes: 235,868
- Edges: 1,285,465

2. ogbl-ddi
- Nodes: 4,267
- Edges: 1,334,889

In [3]:
# ---- Run configuration -------------------------------------------------
# Allowed values are listed next to each option.
graph_model = 'ogbl-collab'  # ['cora', 'ogbl-collab']
model_name = 'gcn'  # ['gin', 'gcn', 'sage', 'vgae']
explainer = 'gnnexplainer'  # ['random', 'ig', 'gnnexplainer', 'deconvolution', 'grad', 'empty', 'pgmexplainer']
decoder = 'inner'  # ['inner', 'cosine']
edge_noise_type = 'bernoulli_whole'  # ['bernoulli_whole', 'bernoulli_computation', 'kde']
load_model = True
seed = 0

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# VGAE is the only model that uses sigmoid=False / return_type='probs' and the
# *_vgae training routines; every other model uses the plain train/test pair.
sigmoid = model_name != 'vgae'
if sigmoid:
    return_type = 'raw'
    from train_test import train, test
else:
    return_type = 'probs'
    from train_test import train_vgae as train, test_vgae as test

print(seed, graph_model, model_name, explainer, decoder, return_type)

# Model checkpoints live one level above the per-explainer output folder.
model_folder = f"../outputs/{graph_model}/{model_name}/"
output_folder = f"{model_folder}{explainer}/"

In [4]:
seed_everything(seed)

if graph_model in ['cora']:
    # Planetoid pipeline: row-normalise features, move to `device`, then split
    # the links into train/test (no validation split is requested).
    transform = T.Compose([
        T.NormalizeFeatures(),
        T.ToDevice(device),
        T.RandomLinkSplit(num_val=0.0, num_test=0.2, is_undirected=True),
    ])

    path = osp.join('../', 'data', 'Planetoid')
    dataset = Planetoid(path, graph_model.capitalize(), transform=transform)
    # The middle element is the (empty) validation split and is ignored.
    train_data, _, test_data = dataset[0]

elif graph_model in ['ogbl-collab']:
    # The OGB dataset is loaded and split by the dedicated cells further down
    # (PygLinkPropPredDataset + get_edge_split). The branch needs an explicit
    # body: an `elif` followed only by comments is a syntax error.
    pass

else:
    raise ValueError(f"Unsupported graph_model: {graph_model!r}")

In [5]:
# Pin the run to an explicit CUDA device index when a GPU is available;
# otherwise use the CPU. The trailing expression displays the chosen device.
gpu_index = 0
device = torch.device(f'cuda:{gpu_index}') if torch.cuda.is_available() else torch.device('cpu')
device

In [6]:
from ogb.linkproppred import PygLinkPropPredDataset
import torch_geometric.utils as U
import torch_geometric.transforms as T

# Load ogbl-collab together with its predefined train/valid/test edge split.
dataset = PygLinkPropPredDataset(name='ogbl-collab')
split_edge = dataset.get_edge_split()
data = dataset[0]

# Keep a handle on the edge list before the sparse conversion; a later cell
# reads `edge_index` when it builds the train+validation adjacency.
# (presumably ToSparseTensor replaces data.edge_index with data.adj_t — confirm)
edge_index = data.edge_index

# Flatten the edge weights to 1-D and cast them to float.
data.edge_weight = data.edge_weight.view(-1).float()

to_sparse = T.ToSparseTensor()
data = to_sparse(data)
data

In [66]:
# torch.concat([split_edge['valid']['edge'], split_edge['train']['edge']], dim=0).shape

# new_edges = torch.concat([split_edge['valid']['edge'], split_edge['train']['edge']], dim=0)
# new_weights = torch.concat([split_edge['valid']['weight'], split_edge['train']['weight']])
# # new_edges, new_weights = U.coalesce(new_edges.t(), new_weights)
# split_edge['train']['edge'] = new_edges.t()
# split_edge['train']['weight'] = new_weights

In [107]:
# new_edges, new_weights = U.to_undirected(split_edge['train']['edge'].t(), split_edge['train']['weight'])
# new_weights = new_weights.unsqueeze(-1)
# data.edge_index = new_edges
# data.edge_weight = new_weights

# # data = T.ToSparseTensor()(data)
# # data.adj_t = data.adj_t.float()
# # row, col, data.edge_weight = data.adj_t.t().coo()
# # data.edge_index = torch.stack([row, col], dim=0)

# data

In [None]:
# # include validation edges in the original training set as the new training set

# data = train_data.clone()
# edge_split = deepcopy(edge_split)

# # Concatenate train and validation splits, remove duplicates.
# new_edges = torch.concat([edge_split['valid']['edge'], edge_split['train']['edge']], dim=0)
# new_edges = U.coalesce(new_edges.t())
# edge_split['train']['edge'] = new_edges.t()

# # Update the data object.
# # Create a new sparse adjacency matrix from the new training split.
# new_edges = U.to_undirected(edge_split['train']['edge'].t())

# data.edge_index = new_edges

# data = T.ToSparseTensor()(data)
# data.adj_t = data.adj_t.float()
# row, col, _ = data.adj_t.t().coo()
# data.edge_index = torch.stack([row, col], dim=0)



In [7]:
split_edge = dataset.get_edge_split()

# When True, validation edges are added to the message-passing graph that is
# used for inference on the test set (stored as data.full_adj_t).
use_valedges_as_input = False

if use_valedges_as_input:
    # Imported lazily so the default path does not need it. Without this import
    # the branch raised NameError, because SparseTensor was never brought into
    # scope anywhere in the notebook.
    from torch_geometric.typing import SparseTensor

    val_edge_index = split_edge['valid']['edge'].t()
    full_edge_index = torch.cat([edge_index, val_edge_index], dim=-1)
    data.full_adj_t = SparseTensor.from_edge_index(full_edge_index).t()
    data.full_adj_t = data.full_adj_t.to_symmetric()
else:
    # Test-time inference reuses the training adjacency.
    data.full_adj_t = data.adj_t

data = data.to(device)

In [13]:
# Display the first item of `dataset` (index 0) for a quick visual inspection.
dataset[0]

In [15]:
class GCN_OGB(torch.nn.Module):
    """Stack of ``GCNConv`` layers that turns node features into embeddings.

    The stack always has an input layer and an output layer, with
    ``num_layers - 2`` hidden-to-hidden layers in between (none when
    ``num_layers <= 2``). Every layer except the last is followed by ReLU
    and dropout.

    Args:
        in_channels: Size of the input node features.
        hidden_channels: Width of the intermediate layers.
        out_channels: Size of the returned node embeddings.
        num_layers: Requested depth (see above for how it maps to layers).
        dropout: Dropout probability used between layers while training.

    NOTE(review): every layer is built with ``cached=True``. If GCNConv caches
    the normalised adjacency from its first call, a later forward pass with a
    different adjacency (e.g. ``data.full_adj_t``) may silently reuse the first
    one — confirm against the GCNConv documentation.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
                 dropout):
        super().__init__()

        # Channel sizes at each layer boundary: in -> hidden ... hidden -> out.
        n_hidden_boundaries = max(num_layers - 2, 0) + 1
        widths = [in_channels] + [hidden_channels] * n_hidden_boundaries + [out_channels]
        self.convs = torch.nn.ModuleList(
            GCNConv(fan_in, fan_out, cached=True)
            for fan_in, fan_out in zip(widths[:-1], widths[1:])
        )

        self.dropout = dropout

    def reset_parameters(self):
        """Re-initialise the parameters of every convolution layer."""
        for layer in self.convs:
            layer.reset_parameters()

    def forward(self, x, adj_t):
        """Propagate ``x`` over ``adj_t`` and return the last layer's output."""
        *inner_layers, last_layer = self.convs
        for layer in inner_layers:
            activated = F.relu(layer(x, adj_t))
            x = F.dropout(activated, p=self.dropout, training=self.training)
        return last_layer(x, adj_t)
    
class LinkPredictor(torch.nn.Module):
    """MLP that scores a candidate edge from the embeddings of its endpoints.

    The two endpoint embeddings are combined by element-wise product, passed
    through ``Linear -> ReLU -> dropout`` layers and a final ``Linear``, and
    squashed with a sigmoid so each output lies in ``[0, 1]``.

    Args:
        in_channels: Size of each endpoint embedding.
        hidden_channels: Width of the intermediate linear layers.
        out_channels: Number of outputs per edge (1 for a single link score).
        num_layers: Requested depth; the input and output layers are always
            created, plus ``num_layers - 2`` hidden layers in between.
        dropout: Dropout probability used between layers while training.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
                 dropout):
        super().__init__()

        self.lins = torch.nn.ModuleList()
        self.lins.append(torch.nn.Linear(in_channels, hidden_channels))
        for _ in range(num_layers - 2):
            self.lins.append(torch.nn.Linear(hidden_channels, hidden_channels))
        self.lins.append(torch.nn.Linear(hidden_channels, out_channels))

        self.dropout = dropout

    def reset_parameters(self):
        """Re-initialise the parameters of every linear layer."""
        for lin in self.lins:
            lin.reset_parameters()

    def forward(self, x_i, x_j):
        """Return sigmoid scores for the edges whose endpoints are ``x_i``/``x_j``."""
        x = x_i * x_j
        for lin in self.lins[:-1]:
            x = lin(x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.lins[-1](x)
        # The closing parenthesis was missing here, which made the whole cell
        # a syntax error.
        return torch.sigmoid(x)



In [None]:
class LinkGCN(torch.nn.Module):
    """Two-layer GCN encoder paired with a pairwise similarity decoder.

    Args:
        in_channels: Size of the input node features.
        hidden_channels: Width of the first convolution's output.
        out_channels: Size of the node embeddings produced by the encoder.
        sim: ``'inner'`` selects ``InnerProductDecoder``, ``'cosine'`` selects
            ``CosineDecoder``. Any other value leaves ``self.decoder`` unset.
        sigmoid: Forwarded to the decoder as its ``sigmoid`` keyword.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, sim='inner', sigmoid=False):
        super().__init__()

        self.sigmoid = sigmoid

        # Map the similarity name to its decoder class; unknown names are
        # ignored here and only surface once decode() is called.
        decoder_types = {'inner': InnerProductDecoder, 'cosine': CosineDecoder}
        decoder_cls = decoder_types.get(sim)
        if decoder_cls is not None:
            self.decoder = decoder_cls()

        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def encoder(self, x, edge_index, edge_weight=None):
        """Run conv1 -> ReLU -> conv2 and return the node embeddings."""
        hidden = torch.relu(self.conv1(x, edge_index, edge_weight))
        return self.conv2(hidden, edge_index, edge_weight)

    def encode(self, x, edge_index, edge_weight=None):
        """Alias of :meth:`encoder`."""
        return self.encoder(x, edge_index, edge_weight)

    def decode(self, z, edge_label_index):
        """Score the node pairs in ``edge_label_index`` from embeddings ``z``."""
        return self.decoder(z, edge_label_index, sigmoid=self.sigmoid)

    def decode_all(self, z):
        """Delegate to the decoder's ``forward_all`` on embeddings ``z``."""
        return self.decoder.forward_all(z, sigmoid=self.sigmoid)

    def forward(self, x, edge_index, edge_label_index, edge_weight=None):
        """Encode the graph, decode the queried pairs, and flatten to 1-D."""
        embeddings = self.encode(x, edge_index, edge_weight)
        scores = self.decode(embeddings, edge_label_index)
        return scores.view(-1)

In [None]:
# import torch
# import torch.nn.functional as F
# from torch_geometric.nn import GCNConv

# class LinkGCNCombined(torch.nn.Module):
#     def __init__(self, in_channels, hidden_channels, out_channels, num_layers, dropout, sim='inner', sigmoid=True):
#         super().__init__()
#         self.sigmoid = sigmoid

#         # Initialize GCN layers
#         self.convs = torch.nn.ModuleList()
#         self.convs.append(GCNConv(in_channels, hidden_channels))
#         for _ in range(num_layers - 2):
#             self.convs.append(GCNConv(hidden_channels, hidden_channels))
#         self.convs.append(GCNConv(hidden_channels, out_channels))

#         # Initialize Link Predictor layers
#         self.lins = torch.nn.ModuleList()
#         self.lins.append(torch.nn.Linear(out_channels, hidden_channels))
#         for _ in range(num_layers - 2):
#             self.lins.append(torch.nn.Linear(hidden_channels, hidden_channels))
#         self.lins.append(torch.nn.Linear(hidden_channels, 1))

#         self.dropout = dropout

#         # Decide the similarity function
#         if sim == 'inner':
#             self.similarity = lambda x_i, x_j: (x_i * x_j).sum(dim=-1)
#         elif sim == 'cosine':
#             self.similarity = lambda x_i, x_j: F.cosine_similarity(x_i, x_j, dim=-1)

# #     def reset_parameters(self):
# #         for conv in self.convs:
# #             conv.reset_parameters()
# #         for lin in self.lins:
# #             lin.reset_parameters()

#     def encode(self, x, edge_index):
#         # Pass data through GCN layers
#         for conv in self.convs[:-1]:
#             x = conv(x, edge_index).relu()
#             x = F.dropout(x, p=self.dropout, training=self.training)
#         x = self.convs[-1](x, edge_index)
#         return x

#     def decode(self, z, edge_label_index):
#         # Extract node embeddings for the given edge indices
#         z_i, z_j = z[edge_label_index[0]], z[edge_label_index[1]]
        
#         # Apply link predictor layers
#         x = self.similarity(z_i, z_j)
        
#         for lin in self.lins[:-1]:
#             x = lin(x).relu()
#             x = F.dropout(x, p=self.dropout, training=self.training)
#         x = self.lins[-1](x)
        
#         return torch.sigmoid(x) if self.sigmoid else x

#     def forward(self, x, edge_index, edge_label_index):
#         z = self.encode(x, edge_index)
#         return self.decode(z, edge_label_index).view(-1)


In [24]:
import argparse

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

# from torch_sparse import SparseTensor
import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv, SAGEConv

from ogb.linkproppred import PygLinkPropPredDataset, Evaluator

from logging import Logger

# if args.use_sage:
#     model = SAGE(data.num_features, args.hidden_channels,
#                  args.hidden_channels, args.num_layers,
#                  args.dropout).to(device)
# else:
#     model = GCN(data.num_features, args.hidden_channels,
#                 args.hidden_channels, args.num_layers,
#                 args.dropout).to(device)

# ---- Hyperparameters ---------------------------------------------------
# Architecture
hidden_channels = 128
num_layers = 2
dropout = 0.0

# Optimisation
lr = 0.001
batch_size = 64 * 1024
epochs = 400
runs = 1

# Reporting cadence (in epochs)
eval_steps = 50
log_steps = 1

# Encoder: node features -> embeddings of width `hidden_channels`.
model = GCN_OGB(
    in_channels=data.num_features,
    hidden_channels=hidden_channels,
    out_channels=hidden_channels,
    num_layers=num_layers,
    dropout=dropout,
).to(device)

# Scoring head: a pair of embeddings -> one link score.
predictor = LinkPredictor(
    in_channels=hidden_channels,
    hidden_channels=hidden_channels,
    out_channels=1,
    num_layers=num_layers,
    dropout=dropout,
).to(device)

In [25]:
evaluator = Evaluator(name='ogbl-collab')

# One logger per reported metric.
# NOTE(review): `Logger` is imported from the standard `logging` module in this
# notebook, so `Logger(runs)` builds a logging.Logger whose *name* is `runs`.
# It is not a per-run results tracker (no add_result / print_statistics).
# The only uses of `loggers` are currently commented out — confirm the intended import.
loggers = {f'Hits@{k}': Logger(runs) for k in (10, 50, 100)}

In [26]:
def train(model, predictor, data, split_edge, optimizer, batch_size):
    """Run one training epoch over the positive training edges.

    For every mini-batch the node embeddings are recomputed, the positive
    edges of the batch are scored, and an equal number of node pairs is drawn
    uniformly at random to act as negatives. The loss is the sum of the
    negative log-likelihoods of both groups.

    Args:
        model: Encoder called as ``model(data.x, data.adj_t)``.
        predictor: Scoring head called as ``predictor(h_src, h_dst)``.
        data: Graph object providing ``x``, ``adj_t`` and ``num_nodes``.
        split_edge: Mapping whose ``['train']['edge']`` entry holds the
            positive training edges, one ``(src, dst)`` pair per row.
        optimizer: Optimizer that updates both ``model`` and ``predictor``.
        batch_size: Number of positive edges per mini-batch.

    Returns:
        The epoch loss averaged over the positive edges seen.
    """
    model.train()
    predictor.train()

    train_edges = split_edge['train']['edge'].to(data.x.device)
    eps = 1e-15  # keeps log() finite when a score saturates at 0 or 1

    loss_sum = 0
    edges_seen = 0
    batches = DataLoader(range(train_edges.size(0)), batch_size, shuffle=True)
    for batch_idx in batches:
        optimizer.zero_grad()

        h = model(data.x, data.adj_t)

        # Positive edges of this batch, transposed to shape [2, B].
        pos_edge = train_edges[batch_idx].t()
        pos_out = predictor(h[pos_edge[0]], h[pos_edge[1]])
        pos_loss = -(pos_out + eps).log().mean()

        # Negatives: uniformly random node pairs, same shape as the positives.
        neg_edge = torch.randint(0, data.num_nodes, pos_edge.size(),
                                 dtype=torch.long, device=h.device)
        neg_out = predictor(h[neg_edge[0]], h[neg_edge[1]])
        neg_loss = -(1 - neg_out + eps).log().mean()

        loss = pos_loss + neg_loss
        loss.backward()

        # Clip encoder and scoring head separately, each to unit norm.
        for module in (model, predictor):
            torch.nn.utils.clip_grad_norm_(module.parameters(), 1.0)

        optimizer.step()

        batch_edges = pos_out.size(0)
        loss_sum += loss.item() * batch_edges
        edges_seen += batch_edges

    return loss_sum / edges_seen


@torch.no_grad()
def test(model, predictor, data, split_edge, evaluator, batch_size):
    model.eval()
    predictor.eval()

    h = model(data.x, data.adj_t)

    pos_train_edge = split_edge['train']['edge'].to(h.device)
    pos_valid_edge = split_edge['valid']['edge'].to(h.device)
    neg_valid_edge = split_edge['valid']['edge_neg'].to(h.device)
    pos_test_edge = split_edge['test']['edge'].to(h.device)
    neg_test_edge = split_edge['test']['edge_neg'].to(h.device)

    pos_train_preds = []
    for perm in DataLoader(range(pos_train_edge.size(0)), batch_size):
        edge = pos_train_edge[perm].t()
        pos_train_preds += [predictor(h[edge[0]], h[edge[1]]).squeeze().cpu()]
    pos_train_pred = torch.cat(pos_train_preds, dim=0)

    pos_valid_preds = []
    for perm in DataLoader(range(pos_valid_edge.size(0)), batch_size):
        edge = pos_valid_edge[perm].t()
        pos_valid_preds += [predictor(h[edge[0]], h[edge[1]]).squeeze().cpu()]
    pos_valid_pred = torch.cat(pos_valid_preds, dim=0)

    neg_valid_preds = []
    for perm in DataLoader(range(neg_valid_edge.size(0)), batch_size):
        edge = neg_valid_edge[perm].t()
        neg_valid_preds += [predictor(h[edge[0]], h[edge[1]]).squeeze().cpu()]
    neg_valid_pred = torch.cat(neg_valid_preds, dim=0)

    h = model(data.x, data.full_adj_t)

    pos_test_preds = []
    for perm in DataLoader(range(pos_test_edge.size(0)), batch_size):
        edge = pos_test_edge[perm].t()
        pos_test_preds += [predictor(h[edge[0]], h[edge[1]]).squeeze().cpu()]
    pos_test_pred = torch.cat(pos_test_preds, dim=0)

    neg_test_preds = []
    for perm in DataLoader(range(neg_test_edge.size(0)), batch_size):
        edge = neg_test_edge[perm].t()
        neg_test_preds += [predictor(h[edge[0]], h[edge[1]]).squeeze().cpu()]
    neg_test_pred = torch.cat(neg_test_preds, dim=0)

    results = {}
    for K in [10, 50, 100]:
        evaluator.K = K
        train_hits = evaluator.eval({
            'y_pred_pos': pos_train_pred,
            'y_pred_neg': neg_valid_pred,
        })[f'hits@{K}']
        valid_hits = evaluator.eval({
            'y_pred_pos': pos_valid_pred,
            'y_pred_neg': neg_valid_pred,
        })[f'hits@{K}']
        test_hits = evaluator.eval({
            'y_pred_pos': pos_test_pred,
            'y_pred_neg': neg_test_pred,
        })[f'hits@{K}']

        results[f'Hits@{K}'] = (train_hits, valid_hits, test_hits)

    return results

In [27]:
for run in range(runs):
    # Each run starts from freshly initialised weights and a new optimizer
    # that updates the encoder and the scoring head together.
    model.reset_parameters()
    predictor.reset_parameters()
    trainable_params = [*model.parameters(), *predictor.parameters()]
    optimizer = torch.optim.Adam(trainable_params, lr=lr)

    for epoch in range(1, epochs + 1):
        loss = train(model, predictor, data, split_edge, optimizer, batch_size)

        # Evaluation only happens every `eval_steps` epochs.
        if epoch % eval_steps != 0:
            continue

        results = test(model, predictor, data, split_edge, evaluator, batch_size)
        for key, result in results.items():
            print(key, result)

        # Formatted report; only reached on evaluation epochs.
        if epoch % log_steps != 0:
            continue

        for key, (train_hits, valid_hits, test_hits) in results.items():
            print(key)
            print(f'Run: {run + 1:02d}, '
                  f'Epoch: {epoch:02d}, '
                  f'Loss: {loss:.4f}, '
                  f'Train: {100 * train_hits:.2f}%, '
                  f'Valid: {100 * valid_hits:.2f}%, '
                  f'Test: {100 * test_hits:.2f}%')
        print('---')

# Per-run statistics via `loggers` are currently disabled.

In [29]:
import os

# Persist the trained weights under `model_folder`.
# exist_ok=True replaces the exists()/makedirs() pair, which could race and
# skipped creating `model_folder` itself when it had no trailing slash.
os.makedirs(model_folder, exist_ok=True)
torch.save(model.state_dict(), osp.join(model_folder, 'model.pt'))

# The scoring head is trained jointly with the encoder and is required to
# reproduce the link scores, so it is stored next to the encoder checkpoint.
torch.save(predictor.state_dict(), osp.join(model_folder, 'predictor.pt'))

In [None]:
seed_everything(seed)

# Encoder/decoder models that share one constructor signature, paired with the
# exclusive upper bound of their training-epoch range.
link_models = {
    'gin': (LinkGIN, 61),
    'gcn': (LinkGCN, 201),
    'sage': (LinkSAGE, 101),
}

if model_name in link_models:
    model_cls, n_epochs = link_models[model_name]
    model = model_cls(train_data.num_features, 128, 64, sim=decoder).to(device)
    tot_epochs = n_epochs
elif model_name == 'vgae':
    # VGAE receives a decoder instance instead of a similarity name; an
    # unrecognised `decoder` value leaves `model` untouched.
    vgae_decoders = {'inner': InnerProductDecoder, 'cosine': CosineDecoder}
    if decoder in vgae_decoders:
        model = DeepVGAE(train_data.num_features, 128, 64, vgae_decoders[decoder]()).to(device)
    tot_epochs = 2001

In [None]:
# The notebook-level names `train` / `test` are redefined by the OGB cells
# above with a different signature, so calling them here raised TypeError.
# Bind the project's routines under private aliases to stay independent of
# that shadowing.
if model_name == 'vgae':
    from train_test import train_vgae as _fit_epoch
    from train_test import test_vgae as _evaluate
else:
    from train_test import train as _fit_epoch
    from train_test import test as _evaluate


def _test_scores():
    """Return ``(auc, accuracy)`` on ``test_data`` for the current ``model``.

    The VGAE evaluation routine also takes ``train_data``; every other model's
    routine takes only the evaluation split.
    """
    if model_name == 'vgae':
        return _evaluate(model, train_data, test_data)
    return _evaluate(model, test_data)


checkpoint_path = osp.join(model_folder, 'model.pt')

if load_model:
    # NOTE(review): the OGB cell above writes a GCN_OGB state dict to this same
    # path. Its parameter names presumably do not match the models built here,
    # so confirm which checkpoint is on disk.
    # map_location lets a checkpoint saved on GPU be restored on a CPU-only host.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.eval()
    print('Model loaded')
    test_auc, test_accuracy = _test_scores()
    print(f'Test auc: {test_auc:.4f}, Test accuracy: {test_accuracy:.4f}')

else:
    seed_everything(0)
    optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)
    for epoch in range(1, tot_epochs):
        loss = _fit_epoch(model, optimizer, train_data)
        # Report test metrics every 20 epochs.
        if epoch % 20 == 0:
            test_auc, test_accuracy = _test_scores()
            print(
                f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Test auc: {test_auc:.4f}, Test accuracy: {test_accuracy:.4f}')
    os.makedirs(model_folder, exist_ok=True)
    torch.save(model.state_dict(), checkpoint_path)

In [None]:
# Pick the query edges to explain: a reproducible random sample of the
# positively labelled edges in the test split.
# edge_label_index is the edge that is to be predicted
metric_list = []

# select random edges
selected_nodes_random_seed = 42
num_selected_query_edges = 100

# select edges with positive labels, and sample the num we need
# TODO: val_data -> test_data
# Compare against the scalar 1 and move the result to the CPU before calling
# .numpy(): test_data lives on `device`, so comparing with a CPU torch.ones(...)
# tensor and calling .numpy() directly both fail when running on CUDA.
positive_idx = (test_data.edge_label == 1).nonzero().flatten().cpu().numpy()
print('The number of positive edges: ', len(positive_idx))

# Never request more edges than there are positives (sampling is without
# replacement).
num_selected_query_edges = min(num_selected_query_edges, len(positive_idx))

# TODO: could still result in the problem of negative prediction despite the positive label
np.random.seed(selected_nodes_random_seed)
selected_edge_pair_idx = np.random.choice(positive_idx, num_selected_query_edges, replace=False)
print('Selected edges index: ', selected_edge_pair_idx)