In [1]:
import numpy as np
np.random.seed(0)
import matplotlib.pyplot as plt
import torch
torch.manual_seed(0)
import torch_geometric.transforms as T
from torch_geometric.transforms import RandomLinkSplit
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
import torch.nn.functional as F

from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv, VGAE, aggr
from torch_geometric.utils import k_hop_subgraph, to_scipy_sparse_matrix
from sklearn.metrics import roc_auc_score, average_precision_score
from scipy.sparse.csgraph import shortest_path
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
transform = T.Compose([
    T.NormalizeFeatures(),
    T.ToDevice(device),
    T.RandomLinkSplit(
        num_val=0.05,
        num_test=0.1,
        is_undirected=True,
        add_negative_train_samples=False,
        split_labels=True,
    )])
cora_dataset = Planetoid(root='../data/Cora', name='Cora', transform=transform)

In [3]:
train_data, val_data, test_data = cora_dataset[0]

In [4]:
class Encoder(torch.nn.Module):
    def __init__(self, dim_in,dim_out):
        super(Encoder, self).__init__()
        self.conv1 = GCNConv(dim_in, dim_out*2)
        self.conv_mu = GCNConv(dim_out*2, dim_out)
        self.conv_logstd = GCNConv(2*dim_out, dim_out)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        mu = self.conv_mu(x, edge_index)
        logstd = self.conv_logstd(x, edge_index)
        return mu, logstd
    
model = VGAE(Encoder(cora_dataset.num_features, 16)).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

def train():
    model.train()
    optimizer.zero_grad()
    z = model.encode(train_data.x, train_data.edge_index)
    loss = model.recon_loss(z, train_data.pos_edge_label_index) 
    loss += (1 / train_data.num_nodes) * model.kl_loss()
    loss.backward()
    optimizer.step()
    return float(loss)
@torch.no_grad()
def test(data):
    model.eval()
    z = model.encode(data.x, data.edge_index)
    return model.test(z, data.pos_edge_label_index, data.neg_edge_label_index)

for epoch in range(301):
    loss = train()
    if epoch % 50 == 0:
        print(f'Epoch {epoch:03d}, Loss: {loss:.4f}')
        auc, ap = test(val_data)
        print(f'Validation AUC: {auc:.4f}, AP: {ap:.4f}')

Epoch 000, Loss: 3.4718
Validation AUC: 0.7351, AP: 0.7586
Epoch 050, Loss: 1.3337
Validation AUC: 0.7270, AP: 0.7591
Epoch 100, Loss: 1.2080
Validation AUC: 0.7598, AP: 0.7930
Epoch 150, Loss: 1.0411
Validation AUC: 0.8266, AP: 0.8387
Epoch 200, Loss: 1.0026
Validation AUC: 0.8602, AP: 0.8766
Epoch 250, Loss: 0.9948
Validation AUC: 0.8730, AP: 0.8838
Epoch 300, Loss: 0.9635
Validation AUC: 0.8834, AP: 0.8953


In [5]:
test_auc, test_ap = test(test_data)
print(f'Test AUC: {test_auc:.4f}, AP: {test_ap:.4f}')

Test AUC: 0.8564, AP: 0.8560


In [6]:
z= model.encode(test_data.x, test_data.edge_index)
Ahat = torch.sigmoid(z @z.T)
Ahat

tensor([[0.8063, 0.7095, 0.7639,  ..., 0.4197, 0.7835, 0.7631],
        [0.7095, 0.8262, 0.8304,  ..., 0.5361, 0.7983, 0.7014],
        [0.7639, 0.8304, 0.8474,  ..., 0.5167, 0.8228, 0.7478],
        ...,
        [0.4197, 0.5361, 0.5167,  ..., 0.5859, 0.4664, 0.4564],
        [0.7835, 0.7983, 0.8228,  ..., 0.4664, 0.8185, 0.7523],
        [0.7631, 0.7014, 0.7478,  ..., 0.4564, 0.7523, 0.7309]],
       device='cuda:0', grad_fn=<SigmoidBackward0>)

In [7]:
transform = RandomLinkSplit(
    num_val=0.05,
    num_test=0.1,
    is_undirected=True,
    split_labels=True,
)

cora_dataset = Planetoid(root='../data/Cora', name='Cora', transform=transform)
train_data, val_data, test_data = cora_dataset[0]
train_data

Data(x=[2708, 1433], edge_index=[2, 8976], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708], pos_edge_label=[4488], pos_edge_label_index=[2, 4488], neg_edge_label=[4488], neg_edge_label_index=[2, 4488])

In [10]:
def seal_processing(data, edge_label_index, y):
    data_list = []
    for src, dst in edge_label_index.t().tolist():
        sub_nodes, sub_edge_index, mapping, _= k_hop_subgraph([src, dst], 2, data.edge_index, relabel_nodes=True)
        src, dst = mapping.tolist()
        mask1 = (sub_edge_index[0] != src) | (sub_edge_index[1] != dst)
        mask2 = (sub_edge_index[0] != dst) | (sub_edge_index[1] != src)
        sub_edge_index = sub_edge_index[:, mask1 & mask2]
        src, dst = (dst, src) if src> dst else (src, dst)
        adj = to_scipy_sparse_matrix(sub_edge_index, num_nodes=sub_nodes.size(0)).tocsr()
        idx = list(range(src)) + list(range(src+1, adj.shape[0]))
        adj_wo_src = adj[idx,:][:, idx]
        idx = list(range(dst)) + list(range(dst+1, adj.shape[0]))
        adj_wo_dst = adj[idx,:][:, idx]
        d_src = shortest_path(adj_wo_dst , directed=False, unweighted=True, indices=src)
        d_src = np.insert(d_src, dst, 0, axis= 0)
        d_src = torch.from_numpy(d_src)
        d_dst = shortest_path(adj_wo_src , directed=False, unweighted=True, indices=dst-1)
        d_dst = np.insert(d_dst, src, 0, axis= 0)
        d_dst = torch.from_numpy(d_dst)
        dist = d_src + d_dst
        z= 1 + torch.min(d_src, d_dst) + dist//2*(dist//2 + dist % 2 -1)
        z[src], z[dst], z[torch.isnan(z)] = 1.,1.,0.
        z = z.to(torch.long)
        node_lables = F.one_hot(z, num_classes=200).to(torch.float)
        node_emb = data.x[sub_nodes]
        node_x = torch.cat([node_emb, node_lables], dim=1)
        data_obj = Data(x=node_x,z =z, edge_index=sub_edge_index, y=y)
        data_list.append(data_obj)
    return data_list

In [9]:
train_pos_data_list = seal_processing(train_data, train_data.pos_edge_label_index, 1)
train_neg_data_list = seal_processing(train_data, train_data.neg_edge_label_index, 0)
val_pos_data_list = seal_processing(val_data, val_data.pos_edge_label_index, 1)
val_neg_data_list = seal_processing(val_data, val_data.neg_edge_label_index, 0)
test_pos_data_list = seal_processing(test_data, test_data.pos_edge_label_index, 1)
test_neg_data_list = seal_processing(test_data, test_data.neg_edge_label_index, 0)

In [11]:
train_dataset = train_pos_data_list + train_neg_data_list
val_dataset = val_pos_data_list + val_neg_data_list
test_dataset = test_pos_data_list + test_neg_data_list

In [32]:
train_dataset[0].y

1

In [None]:
train_dataset[1].num_features

1633

In [34]:
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
class DGCNN(torch.nn.Module):
    def __init__(self, dim_in,dim_h, dim_out, k=30):
        super(DGCNN, self).__init__()
        self.gcn1 = GCNConv(dim_in, dim_h) # 
        self.gcn2 = GCNConv(dim_h, dim_h)
        self.gcn3 = GCNConv(dim_h, dim_h)
        self.gcn4 = GCNConv(dim_h, dim_out)
        self.global_pool = aggr.SortAggregation(k=k)
        self.conv1 = torch.nn.Conv1d(1,16,97,97)
        self.conv2 = torch.nn.Conv1d(16,32,5,1)
        self.maxpool = torch.nn.MaxPool1d(2,2)
        self.linear1 = torch.nn.Linear(352, 128)
        self.dropout = torch.nn.Dropout(0.5)
        self.linear2 = torch.nn.Linear(128, 1)
    def forward(self, x, edge_index, batch):
        h1 = self.gcn1(x, edge_index).tanh()
        h2 = self.gcn2(h1, edge_index).tanh()
        h3 = self.gcn3(h2, edge_index).tanh()
        h4 = self.gcn4(h3, edge_index).tanh()
        h= torch.cat([h1, h2, h3, h4], dim=-1)
        h = self.global_pool(h, batch)
        h= h.view(h.size(0), 1, h.size(-1))
        h = self.conv1(h)
        h = self.maxpool(h)
        h = self.conv2(h)
        h = h.view(h.size(0), -1)
        h = self.linear1(h)
        h = self.dropout(h)
        h = self.linear2(h).sigmoid()
        return h

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = DGCNN(train_dataset[0].num_features, 32,train_dataset[0].y, k = 30).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
criterion = torch.nn.BCELoss()
def train():
    model.train()
    total_loss = 0
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data.x, data.edge_index, data.batch)
        loss = criterion(out.view(-1), data.y.to(torch.float))
        loss.backward()
        optimizer.step()
        total_loss += float(loss) * data.num_graphs
    return total_loss / len(train_dataset)
@torch.no_grad()
def test(loader):
    model.eval()
    y_pred, y_true = [], []
    for data in loader:
        data = data.to(device)
        out = model(data.x, data.edge_index, data.batch)
        y_pred.append(out.view(-1).cpu())
        y_true.append(data.y.view(-1).cpu().to(torch.float))
    auc = roc_auc_score(torch.cat(y_true), torch.cat(y_pred))
    ap = average_precision_score(torch.cat(y_true), torch.cat(y_pred))
    return auc, ap
epochs = 30
for epoch in range(epochs):
    loss = train()
    val_auc, val_ap = test(val_loader)
    print(f'Epoch {epoch:03d}, Loss: {loss:.4f}, Validation AUC: {val_auc:.4f}, AP: {val_ap:.4f}')


Epoch 000, Loss: 0.5937, Validation AUC: 0.8297, AP: 0.8438
Epoch 001, Loss: 0.4681, Validation AUC: 0.8569, AP: 0.8820
Epoch 002, Loss: 0.4263, Validation AUC: 0.8699, AP: 0.8948
Epoch 003, Loss: 0.3927, Validation AUC: 0.8831, AP: 0.9063
Epoch 004, Loss: 0.3616, Validation AUC: 0.8906, AP: 0.9153
Epoch 005, Loss: 0.3309, Validation AUC: 0.8900, AP: 0.9142
Epoch 006, Loss: 0.3110, Validation AUC: 0.8845, AP: 0.9109
Epoch 007, Loss: 0.2965, Validation AUC: 0.8806, AP: 0.9080
Epoch 008, Loss: 0.2873, Validation AUC: 0.8744, AP: 0.9029
Epoch 009, Loss: 0.2768, Validation AUC: 0.8734, AP: 0.9021
Epoch 010, Loss: 0.2681, Validation AUC: 0.8677, AP: 0.9003
Epoch 011, Loss: 0.2641, Validation AUC: 0.8628, AP: 0.8962
Epoch 012, Loss: 0.2582, Validation AUC: 0.8627, AP: 0.8962
Epoch 013, Loss: 0.2528, Validation AUC: 0.8653, AP: 0.8981
Epoch 014, Loss: 0.2478, Validation AUC: 0.8673, AP: 0.8983
Epoch 015, Loss: 0.2436, Validation AUC: 0.8662, AP: 0.8976
Epoch 016, Loss: 0.2362, Validation AUC:

In [35]:
test_auc, test_ap = test(test_loader)
print(f'Test AUC: {test_auc:.4f}, AP: {test_ap:.4f}')

Test AUC: 0.9115, AP: 0.9335
