In [1]:
import os.path as osp
import argparse

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch_geometric.datasets import Planetoid, Coauthor, Amazon
from torch_geometric.utils import train_test_split_edges
from torch_geometric.nn import GAE, VGAE, APPNP
import torch_geometric.transforms as T

from sklearn.metrics import roc_auc_score
from sklearn.metrics import average_precision_score

import pandas as pd
import numpy as np
from torch_geometric.data import Data

In [2]:
parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, default='VGNAE')
parser.add_argument('--dataset', type=str, default='citeseer')
parser.add_argument('--epochs', type=int, default=800)
parser.add_argument('--channels', type=int, default=128)
parser.add_argument('--scaling_factor', type=float, default=1.8)
parser.add_argument('--training_rate', type=float, default=0.9) 
args,_ = parser.parse_known_args()

In [3]:
nodes_number = 3327

filename_adj = "datasets/citeseer-edges.txt"

raw_edges = pd.read_csv(filename_adj, header=None)

drop_self_loop = raw_edges[raw_edges[0]!=raw_edges[1]]

graph_np=np.zeros((nodes_number, nodes_number))

for i in range(drop_self_loop.shape[0]):
    graph_np[drop_self_loop.iloc[i,0],drop_self_loop.iloc[i,1]]=1
    graph_np[drop_self_loop.iloc[i,1],drop_self_loop.iloc[i,0]]=1
    
edges = torch.tensor([list(graph_np.nonzero()[0]),list(graph_np.nonzero()[1])])

features = torch.eye(nodes_number)

data = Data(x=features, edge_index=edges)

data = T.NormalizeFeatures()(data)

In [4]:
class Encoder(torch.nn.Module):
    def __init__(self, in_channels, out_channels, edge_index):
        super(Encoder, self).__init__()
        self.linear1 = nn.Linear(in_channels, out_channels)
        self.linear2 = nn.Linear(in_channels, out_channels)
        self.propagate = APPNP(K=1, alpha=0)

    def forward(self, x, edge_index,not_prop=0):
        if args.model == 'GNAE':
            x = self.linear1(x)
            x = F.normalize(x,p=2,dim=1)  * args.scaling_factor
            x = self.propagate(x, edge_index)
            return x

        if args.model == 'VGNAE':
            x_ = self.linear1(x)
            x_ = self.propagate(x_, edge_index)

            x = self.linear2(x)
            x = F.normalize(x,p=2,dim=1) * args.scaling_factor
            x = self.propagate(x, edge_index)
            return x, x_

        return x

In [5]:
dev = torch.device('cpu')
channels = args.channels
train_rate = args.training_rate
val_ratio = (1-args.training_rate) / 3
test_ratio = (1-args.training_rate) / 3 * 2
data = train_test_split_edges(data.to(dev), val_ratio=val_ratio, test_ratio=test_ratio)

N = int(data.x.size()[0])
if args.model == 'GNAE':   
    model = GAE(Encoder(data.x.size()[1], channels, data.train_pos_edge_index)).to(dev)
if args.model == 'VGNAE':
    model = VGAE(Encoder(data.x.size()[1], channels, data.train_pos_edge_index)).to(dev)

data.train_mask = data.val_mask = data.test_mask = data.y = None
x, train_pos_edge_index = data.x.to(dev), data.train_pos_edge_index.to(dev)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)



In [6]:
def computer_indi(z, pos_edge_index, neg_edge_index, model):
    pos_y = z.new_ones(pos_edge_index.size(1))
    neg_y = z.new_zeros(neg_edge_index.size(1))
    y = torch.cat([pos_y, neg_y], dim=0)

    pos_pred = model.decoder(z, pos_edge_index, sigmoid=True)
    neg_pred = model.decoder(z, neg_edge_index, sigmoid=True)
    pred = torch.cat([pos_pred, neg_pred], dim=0)

    y, pred = y.detach().cpu().numpy(), pred.detach().cpu().numpy()
    
    auc = roc_auc_score(y, pred)
    ap = average_precision_score(y, pred)

    return ap, auc

In [7]:
def train():
    model.train()
    optimizer.zero_grad()
    z  = model.encode(x, train_pos_edge_index)
    loss = model.recon_loss(z, train_pos_edge_index)
    if args.model in ['VGAE']:
        loss = loss + (1 / data.num_nodes) * model.kl_loss()
    loss.backward()
    optimizer.step()
    return loss

def test(pos_edge_index, neg_edge_index, plot_his=0):
    model.eval()
    with torch.no_grad():
        z = model.encode(x, train_pos_edge_index)
    return computer_indi(z, pos_edge_index, neg_edge_index, model)

In [8]:
for epoch in range(1,args.epochs):
    loss = train()
    loss = float(loss)
    
    with torch.no_grad():
        test_pos, test_neg = data.test_pos_edge_index, data.test_neg_edge_index
        ap, auc = test(data.test_pos_edge_index, data.test_neg_edge_index)
        print('Epoch: {:03d}, AP: {:.4f}, AUC: {:.4f}'.format(epoch, ap, auc))

Epoch: 001, AP: 0.6953, AUC: 0.6676
Epoch: 002, AP: 0.7312, AUC: 0.7035
Epoch: 003, AP: 0.7640, AUC: 0.7319
Epoch: 004, AP: 0.7904, AUC: 0.7496
Epoch: 005, AP: 0.8105, AUC: 0.7623
Epoch: 006, AP: 0.8252, AUC: 0.7727
Epoch: 007, AP: 0.8346, AUC: 0.7804
Epoch: 008, AP: 0.8379, AUC: 0.7820
Epoch: 009, AP: 0.8394, AUC: 0.7812
Epoch: 010, AP: 0.8392, AUC: 0.7794
Epoch: 011, AP: 0.8380, AUC: 0.7767
Epoch: 012, AP: 0.8376, AUC: 0.7747
Epoch: 013, AP: 0.8371, AUC: 0.7732
Epoch: 014, AP: 0.8371, AUC: 0.7717
Epoch: 015, AP: 0.8368, AUC: 0.7704
Epoch: 016, AP: 0.8371, AUC: 0.7699
Epoch: 017, AP: 0.8381, AUC: 0.7710
Epoch: 018, AP: 0.8395, AUC: 0.7728
Epoch: 019, AP: 0.8416, AUC: 0.7752
Epoch: 020, AP: 0.8433, AUC: 0.7775
Epoch: 021, AP: 0.8452, AUC: 0.7804
Epoch: 022, AP: 0.8471, AUC: 0.7833
Epoch: 023, AP: 0.8484, AUC: 0.7853
Epoch: 024, AP: 0.8497, AUC: 0.7870
Epoch: 025, AP: 0.8502, AUC: 0.7877
Epoch: 026, AP: 0.8511, AUC: 0.7886
Epoch: 027, AP: 0.8518, AUC: 0.7891
Epoch: 028, AP: 0.8525, AUC:

Epoch: 230, AP: 0.8652, AUC: 0.8041
Epoch: 231, AP: 0.8652, AUC: 0.8040
Epoch: 232, AP: 0.8653, AUC: 0.8044
Epoch: 233, AP: 0.8654, AUC: 0.8044
Epoch: 234, AP: 0.8655, AUC: 0.8049
Epoch: 235, AP: 0.8657, AUC: 0.8051
Epoch: 236, AP: 0.8659, AUC: 0.8054
Epoch: 237, AP: 0.8662, AUC: 0.8059
Epoch: 238, AP: 0.8664, AUC: 0.8063
Epoch: 239, AP: 0.8665, AUC: 0.8065
Epoch: 240, AP: 0.8669, AUC: 0.8072
Epoch: 241, AP: 0.8670, AUC: 0.8074
Epoch: 242, AP: 0.8670, AUC: 0.8075
Epoch: 243, AP: 0.8669, AUC: 0.8074
Epoch: 244, AP: 0.8665, AUC: 0.8069
Epoch: 245, AP: 0.8662, AUC: 0.8065
Epoch: 246, AP: 0.8660, AUC: 0.8062
Epoch: 247, AP: 0.8657, AUC: 0.8056
Epoch: 248, AP: 0.8653, AUC: 0.8050
Epoch: 249, AP: 0.8652, AUC: 0.8048
Epoch: 250, AP: 0.8653, AUC: 0.8048
Epoch: 251, AP: 0.8653, AUC: 0.8049
Epoch: 252, AP: 0.8654, AUC: 0.8050
Epoch: 253, AP: 0.8654, AUC: 0.8050
Epoch: 254, AP: 0.8653, AUC: 0.8048
Epoch: 255, AP: 0.8652, AUC: 0.8046
Epoch: 256, AP: 0.8651, AUC: 0.8044
Epoch: 257, AP: 0.8650, AUC:

Epoch: 459, AP: 0.8700, AUC: 0.8104
Epoch: 460, AP: 0.8700, AUC: 0.8103
Epoch: 461, AP: 0.8699, AUC: 0.8102
Epoch: 462, AP: 0.8698, AUC: 0.8099
Epoch: 463, AP: 0.8697, AUC: 0.8098
Epoch: 464, AP: 0.8697, AUC: 0.8097
Epoch: 465, AP: 0.8695, AUC: 0.8097
Epoch: 466, AP: 0.8695, AUC: 0.8096
Epoch: 467, AP: 0.8695, AUC: 0.8097
Epoch: 468, AP: 0.8694, AUC: 0.8096
Epoch: 469, AP: 0.8693, AUC: 0.8095
Epoch: 470, AP: 0.8694, AUC: 0.8097
Epoch: 471, AP: 0.8695, AUC: 0.8099
Epoch: 472, AP: 0.8697, AUC: 0.8103
Epoch: 473, AP: 0.8697, AUC: 0.8104
Epoch: 474, AP: 0.8696, AUC: 0.8102
Epoch: 475, AP: 0.8696, AUC: 0.8101
Epoch: 476, AP: 0.8695, AUC: 0.8101
Epoch: 477, AP: 0.8695, AUC: 0.8101
Epoch: 478, AP: 0.8695, AUC: 0.8101
Epoch: 479, AP: 0.8693, AUC: 0.8098
Epoch: 480, AP: 0.8694, AUC: 0.8099
Epoch: 481, AP: 0.8692, AUC: 0.8095
Epoch: 482, AP: 0.8692, AUC: 0.8094
Epoch: 483, AP: 0.8692, AUC: 0.8094
Epoch: 484, AP: 0.8690, AUC: 0.8091
Epoch: 485, AP: 0.8692, AUC: 0.8095
Epoch: 486, AP: 0.8693, AUC:

Epoch: 687, AP: 0.8677, AUC: 0.8060
Epoch: 688, AP: 0.8676, AUC: 0.8058
Epoch: 689, AP: 0.8675, AUC: 0.8056
Epoch: 690, AP: 0.8673, AUC: 0.8054
Epoch: 691, AP: 0.8672, AUC: 0.8052
Epoch: 692, AP: 0.8670, AUC: 0.8048
Epoch: 693, AP: 0.8670, AUC: 0.8048
Epoch: 694, AP: 0.8669, AUC: 0.8047
Epoch: 695, AP: 0.8667, AUC: 0.8044
Epoch: 696, AP: 0.8667, AUC: 0.8046
Epoch: 697, AP: 0.8668, AUC: 0.8046
Epoch: 698, AP: 0.8668, AUC: 0.8047
Epoch: 699, AP: 0.8670, AUC: 0.8052
Epoch: 700, AP: 0.8672, AUC: 0.8055
Epoch: 701, AP: 0.8674, AUC: 0.8059
Epoch: 702, AP: 0.8677, AUC: 0.8064
Epoch: 703, AP: 0.8679, AUC: 0.8067
Epoch: 704, AP: 0.8680, AUC: 0.8067
Epoch: 705, AP: 0.8682, AUC: 0.8070
Epoch: 706, AP: 0.8684, AUC: 0.8072
Epoch: 707, AP: 0.8685, AUC: 0.8073
Epoch: 708, AP: 0.8685, AUC: 0.8071
Epoch: 709, AP: 0.8687, AUC: 0.8074
Epoch: 710, AP: 0.8689, AUC: 0.8076
Epoch: 711, AP: 0.8690, AUC: 0.8078
Epoch: 712, AP: 0.8690, AUC: 0.8077
Epoch: 713, AP: 0.8691, AUC: 0.8076
Epoch: 714, AP: 0.8693, AUC: