In [1]:
import torch
import torch_geometric

train_data = torch.load('train_data.pt')
val_data = torch.load('val_data.pt')
test_data = torch.load('test_data.pt')

In [2]:
from sklearn.metrics import roc_auc_score, average_precision_score, recall_score
from scipy.sparse.csgraph import shortest_path

import torch.nn.functional as F
from torch.nn import Conv1d, MaxPool1d, Linear, Dropout, BCEWithLogitsLoss, GRU

from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import RandomLinkSplit
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, aggr, global_sort_pool
from torch_geometric.utils import k_hop_subgraph, to_scipy_sparse_matrix

In [3]:
from torch_geometric.data import Data
import numpy as np
def seal_processing(dataset, edge_label_index, y):
    data_list = []
    for src, dst in edge_label_index.t().tolist():
        sub_nodes, sub_edge_index, mapping, _ = k_hop_subgraph([src, dst], 2, dataset.edge_index, relabel_nodes=True)
        src, dst = mapping.tolist()

        # Remove target link from the subgraph
        mask1 = (sub_edge_index[0] != src) | (sub_edge_index[1] != dst)
        mask2 = (sub_edge_index[0] != dst) | (sub_edge_index[1] != src)
        sub_edge_index = sub_edge_index[:, mask1 & mask2]

        # Double-radius node labeling (DRNL)
        src, dst = (dst, src) if src > dst else (src, dst)
        adj = to_scipy_sparse_matrix(sub_edge_index, num_nodes=sub_nodes.size(0)).tocsr()

        idx = list(range(src)) + list(range(src + 1, adj.shape[0]))
        adj_wo_src = adj[idx, :][:, idx]

        idx = list(range(dst)) + list(range(dst + 1, adj.shape[0]))
        adj_wo_dst = adj[idx, :][:, idx]

        # Calculate the distance between every node and the source target node
        d_src = shortest_path(adj_wo_dst, directed=False, unweighted=True, indices=src)
        d_src = np.insert(d_src, dst, 0, axis=0)
        d_src = torch.from_numpy(d_src)

        # Calculate the distance between every node and the destination target node
        d_dst = shortest_path(adj_wo_src, directed=False, unweighted=True, indices=dst-1)
        d_dst = np.insert(d_dst, src, 0, axis=0)
        d_dst = torch.from_numpy(d_dst)

        # Calculate the label z for each node
        dist = d_src + d_dst
        z = 1 + torch.min(d_src, d_dst) + dist // 2 * (dist // 2 + dist % 2 - 1)
        z[src], z[dst], z[torch.isnan(z)] = 1., 1., 0.
        z = z.to(torch.long)

        # Concatenate node features and one-hot encoded node labels (with a fixed number of classes)
        node_labels = F.one_hot(z, num_classes=200).to(torch.float)
        node_emb = dataset.x[sub_nodes]
        node_x = torch.cat([node_emb, node_labels], dim=1)

        # Create data object
        data = Data(x=node_x, z=z, edge_index=sub_edge_index, y=y)
        data_list.append(data)

    return data_list

In [4]:
# Enclosing subgraphs extraction
train_pos_data_list = seal_processing(train_data, train_data.pos_edge_label_index, 1)
train_neg_data_list = seal_processing(train_data, train_data.neg_edge_label_index, 0)

In [5]:
val_pos_data_list = seal_processing(val_data, val_data.pos_edge_label_index, 1)
val_neg_data_list = seal_processing(val_data, val_data.neg_edge_label_index, 0)

In [6]:
test_pos_data_list = seal_processing(test_data, test_data.pos_edge_label_index, 1)
test_neg_data_list = seal_processing(test_data, test_data.neg_edge_label_index, 0)

In [7]:
train_dataset = train_pos_data_list + train_neg_data_list
val_dataset = val_pos_data_list + val_neg_data_list
test_dataset = test_pos_data_list + test_neg_data_list

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=True)

In [8]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GATConv
class GAT(torch.nn.Module):
    def __init__(self, dim_in, num_heads):
        super(GAT, self).__init__()
        # GCN layers
        self.gcn1 = GATConv(dim_in, 32, heads=num_heads, concat=True)
        self.gcn2 = GATConv(64,1, heads=1, concat=False)

        self.lin1 = Linear(65, dim_in)
        self.lin2 = Linear(dim_in, 1)

    def forward(self, x, edge_index, batch):
        x1 = self.gcn1(x, edge_index).tanh()
        x2 = self.gcn2(x1, edge_index).tanh()
        x = torch.cat([x1, x2], dim=-1)

        _, center_indices = np.unique(batch.cpu().numpy(), return_index=True)
        x_src = x[center_indices]
        x_dst = x[center_indices + 1]
        x = (x_src * x_dst)
        x = F.relu(self.lin1(x))
        x = self.lin2(x)

        return x

In [9]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GAT(train_dataset[0].num_features,2).to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)
criterion = BCEWithLogitsLoss()

In [10]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score

def train():
    model.train()
    total_loss = 0

    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data.x, data.edge_index, data.batch)
        loss = criterion(out.view(-1), data.y.to(torch.float))
        loss.backward()
        optimizer.step()
        total_loss += float(loss) * data.num_graphs

    return total_loss / len(train_dataset)

@torch.no_grad()
def test(loader):
    model.eval()
    y_pred, y_true = [], []

    for data in loader:
        data = data.to(device)
        out = model(data.x, data.edge_index, data.batch)
        y_pred.append(out.view(-1).cpu())
        y_true.append(data.y.view(-1).cpu().to(torch.float))

    y_pred_binary = (torch.cat(y_pred) > 0.5).numpy()
    y_true_array = torch.cat(y_true).numpy()

    auc = roc_auc_score(y_true_array, torch.cat(y_pred))
    accuracy = accuracy_score(y_true_array, y_pred_binary)
    f1 = f1_score(y_true_array, y_pred_binary)
    precision = precision_score(y_true_array, y_pred_binary)
    recall = recall_score(y_true_array, y_pred_binary)

    return auc, accuracy, f1, precision, recall
    

In [11]:
import matplotlib.pyplot as plt
train_loss = []
for epoch in range(100):
    loss = train()
    val_results = test(val_loader)
    val_auc, val_accuracy, val_f1, val_precision, val_recall = val_results
    print(f'Epoch {epoch:>2} | Loss: {loss:.4f} | Val AUC: {val_auc:.4f} ')
    train_loss.append(loss)

test_results = test(test_loader)
test_auc, test_accuracy, test_f1, test_precision, test_recall = test_results 
print(f'Test AUC: {test_auc:.4f} | Test Accuracy: {test_accuracy:.4f} | Test F1: {test_f1:.4f} | Test Precision: {test_precision:.4f} | Test Recall: {test_recall:.4f}')

Epoch  0 | Loss: 0.3151 | Val AUC: 0.7806 
Epoch  1 | Loss: 0.2635 | Val AUC: 0.8083 
Epoch  2 | Loss: 0.2544 | Val AUC: 0.8204 
Epoch  3 | Loss: 0.2474 | Val AUC: 0.8174 
Epoch  4 | Loss: 0.2424 | Val AUC: 0.8211 
Epoch  5 | Loss: 0.2375 | Val AUC: 0.8230 
Epoch  6 | Loss: 0.2326 | Val AUC: 0.8260 
Epoch  7 | Loss: 0.2287 | Val AUC: 0.8162 
Epoch  8 | Loss: 0.2252 | Val AUC: 0.8275 
Epoch  9 | Loss: 0.2210 | Val AUC: 0.8173 
Epoch 10 | Loss: 0.2178 | Val AUC: 0.8325 
Epoch 11 | Loss: 0.2154 | Val AUC: 0.8312 
Epoch 12 | Loss: 0.2123 | Val AUC: 0.8249 
Epoch 13 | Loss: 0.2102 | Val AUC: 0.8303 
Epoch 14 | Loss: 0.2077 | Val AUC: 0.8204 
Epoch 15 | Loss: 0.2055 | Val AUC: 0.8256 
Epoch 16 | Loss: 0.2027 | Val AUC: 0.8296 
Epoch 17 | Loss: 0.2018 | Val AUC: 0.8361 
Epoch 18 | Loss: 0.1986 | Val AUC: 0.8314 
Epoch 19 | Loss: 0.1977 | Val AUC: 0.8320 
Epoch 20 | Loss: 0.1946 | Val AUC: 0.8280 
Epoch 21 | Loss: 0.1941 | Val AUC: 0.8248 
Epoch 22 | Loss: 0.1921 | Val AUC: 0.8301 
Epoch 23 | 