In [1]:
import torch
import torch_geometric

train_data = torch.load('train_data.pt')
val_data = torch.load('val_data.pt')
test_data = torch.load('test_data.pt')

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from sklearn.metrics import roc_auc_score, average_precision_score, recall_score
from scipy.sparse.csgraph import shortest_path

import torch.nn.functional as F
from torch.nn import Conv1d, MaxPool1d, Linear, Dropout, BCEWithLogitsLoss, GRU

from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import RandomLinkSplit
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, aggr, global_sort_pool
from torch_geometric.utils import k_hop_subgraph, to_scipy_sparse_matrix

In [3]:
from torch_geometric.data import Data
import numpy as np
import torch
from torch.nn import functional as F
from scipy.sparse.csgraph import shortest_path
from torch_geometric.utils import k_hop_subgraph, to_scipy_sparse_matrix

def seal_processing(dataset, edge_label_index, y, max_dist=6):
    data_list = []
    for src, dst in edge_label_index.t().tolist():
        # 提取子图
        sub_nodes, sub_edge_index, mapping, _ = k_hop_subgraph([src, dst], 2, dataset.edge_index, relabel_nodes=True)
        src, dst = mapping.tolist()

        # 从子图中移除目标边
        mask1 = (sub_edge_index[0] != src) | (sub_edge_index[1] != dst)
        mask2 = (sub_edge_index[0] != dst) | (sub_edge_index[1] != src)
        sub_edge_index = sub_edge_index[:, mask1 & mask2]

        # 确保 src < dst
        src, dst = (dst, src) if src > dst else (src, dst)

        # 计算邻接矩阵
        adj = to_scipy_sparse_matrix(sub_edge_index, num_nodes=sub_nodes.size(0)).tocsr()

        # 计算距离编码
        dist = shortest_path(adj, directed=False, unweighted=True, indices=[src, dst])
        dist = torch.from_numpy(dist).to(torch.long)

        # 处理距离编码
        dist[dist > max_dist] = max_dist
        dist[torch.isnan(dist)] = max_dist + 1

        # 将距离转置以匹配预期的形状 (num_nodes, 2)
        dist = dist.t()

        # 对每一个距离进行 one-hot 编码
        node_labels_src = F.one_hot(dist[:, 0], num_classes=max_dist + 2).to(torch.float)
        node_labels_dst = F.one_hot(dist[:, 1], num_classes=max_dist + 2).to(torch.float)

        # 将两个 one-hot 编码的结果拼接起来
        node_labels = torch.cat([node_labels_src, node_labels_dst], dim=1)

        # 获取子图中的节点特征
        node_emb = dataset.x[sub_nodes]
        
        # 拼接节点特征和距离标签
        node_x = torch.cat([node_emb, node_labels], dim=1)

        # 创建数据对象
        data = Data(x=node_x, z=dist, edge_index=sub_edge_index, y=y)
        data_list.append(data)

    return data_list

In [4]:
# Enclosing subgraphs extraction
train_pos_data_list = seal_processing(train_data, train_data.pos_edge_label_index, 1)
train_neg_data_list = seal_processing(train_data, train_data.neg_edge_label_index, 0)

In [5]:
val_pos_data_list = seal_processing(val_data, val_data.pos_edge_label_index, 1)
val_neg_data_list = seal_processing(val_data, val_data.neg_edge_label_index, 0)

In [6]:
test_pos_data_list = seal_processing(test_data, test_data.pos_edge_label_index, 1)
test_neg_data_list = seal_processing(test_data, test_data.neg_edge_label_index, 0)

In [7]:
train_dataset = train_pos_data_list + train_neg_data_list
val_dataset = val_pos_data_list + val_neg_data_list
test_dataset = test_pos_data_list + test_neg_data_list

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=True)

In [8]:
class GCN(torch.nn.Module):
    def __init__(self, dim_in):
        super(GCN, self).__init__()
        # GCN layers
        self.gcn1 = GCNConv(dim_in, 32)
        self.gcn2 = GCNConv(32, 32)
        self.gcn3 = GCNConv(32, 32)
        self.gcn4 = GCNConv(32, 1)

        self.lin1 = Linear(97, dim_in)
        self.lin2 = Linear(dim_in, 1)

    def forward(self, x, edge_index, batch):
        x1 = self.gcn1(x, edge_index).tanh()
        x2 = self.gcn2(x1, edge_index).tanh()
        x3 = self.gcn3(x2, edge_index).tanh()
        x4 = self.gcn4(x3, edge_index).tanh()
        x = torch.cat([x1, x2, x3, x4], dim=-1)

        _, center_indices = np.unique(batch.cpu().numpy(), return_index=True)
        x_src = x[center_indices]
        x_dst = x[center_indices + 1]
        x = (x_src * x_dst)
        x = F.relu(self.lin1(x))
        x = self.lin2(x)

        return x

In [9]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GCN(train_dataset[0].num_features).to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)
criterion = BCEWithLogitsLoss()

In [10]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score

def train():
    model.train()
    total_loss = 0

    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data.x, data.edge_index, data.batch)
        loss = criterion(out.view(-1), data.y.to(torch.float))
        loss.backward()
        optimizer.step()
        total_loss += float(loss) * data.num_graphs

    return total_loss / len(train_dataset)

@torch.no_grad()
def test(loader):
    model.eval()
    y_pred, y_true = [], []

    for data in loader:
        data = data.to(device)
        out = model(data.x, data.edge_index, data.batch)
        y_pred.append(out.view(-1).cpu())
        y_true.append(data.y.view(-1).cpu().to(torch.float))

    y_pred_binary = (torch.cat(y_pred) > 0.5).numpy()
    y_true_array = torch.cat(y_true).numpy()

    auc = roc_auc_score(y_true_array, torch.cat(y_pred))
    accuracy = accuracy_score(y_true_array, y_pred_binary)
    f1 = f1_score(y_true_array, y_pred_binary)
    precision = precision_score(y_true_array, y_pred_binary)
    recall = recall_score(y_true_array, y_pred_binary)

    return auc, accuracy, f1, precision, recall
    

In [11]:
import matplotlib.pyplot as plt
train_loss = []
for epoch in range(200):
    loss = train()
    val_results = test(val_loader)
    val_auc, val_accuracy, val_f1, val_precision, val_recall = val_results
    print(f'Epoch {epoch:>2} | Loss: {loss:.4f} | Val AUC: {val_auc:.4f} ')
    train_loss.append(loss)

test_results = test(test_loader)
test_auc, test_accuracy, test_f1, test_precision, test_recall = test_results 
print(f'Test AUC: {test_auc:.4f} | Test Accuracy: {test_accuracy:.4f} | Test F1: {test_f1:.4f} | Test Precision: {test_precision:.4f} | Test Recall: {test_recall:.4f}')

Epoch  0 | Loss: 0.2879 | Val AUC: 0.7955 
Epoch  1 | Loss: 0.2557 | Val AUC: 0.8235 
Epoch  2 | Loss: 0.2465 | Val AUC: 0.8354 
Epoch  3 | Loss: 0.2354 | Val AUC: 0.8327 
Epoch  4 | Loss: 0.2337 | Val AUC: 0.8310 
Epoch  5 | Loss: 0.2242 | Val AUC: 0.8427 
Epoch  6 | Loss: 0.2180 | Val AUC: 0.8365 
Epoch  7 | Loss: 0.2160 | Val AUC: 0.8382 
Epoch  8 | Loss: 0.2139 | Val AUC: 0.8334 
Epoch  9 | Loss: 0.2107 | Val AUC: 0.8255 
Epoch 10 | Loss: 0.2080 | Val AUC: 0.8349 
Epoch 11 | Loss: 0.2067 | Val AUC: 0.8370 
Epoch 12 | Loss: 0.2016 | Val AUC: 0.8389 
Epoch 13 | Loss: 0.1987 | Val AUC: 0.8432 
Epoch 14 | Loss: 0.1977 | Val AUC: 0.8407 
Epoch 15 | Loss: 0.1937 | Val AUC: 0.8366 
Epoch 16 | Loss: 0.1950 | Val AUC: 0.8381 
Epoch 17 | Loss: 0.1928 | Val AUC: 0.8410 
Epoch 18 | Loss: 0.1899 | Val AUC: 0.8363 
Epoch 19 | Loss: 0.1903 | Val AUC: 0.8404 
Epoch 20 | Loss: 0.1879 | Val AUC: 0.8338 
Epoch 21 | Loss: 0.1842 | Val AUC: 0.8390 
Epoch 22 | Loss: 0.1857 | Val AUC: 0.8314 
Epoch 23 | 