In [1]:
import os.path as osp
import argparse

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch_geometric.datasets import Planetoid, Coauthor, Amazon
from torch_geometric.utils import train_test_split_edges
from torch_geometric.nn import GAE, VGAE, APPNP
import torch_geometric.transforms as T

from sklearn.metrics import roc_auc_score
from sklearn.metrics import average_precision_score
from sklearn.metrics import accuracy_score
from sklearn.metrics import f1_score

In [2]:
parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, default='VGNAE')
parser.add_argument('--dataset', type=str, default='Cora')
parser.add_argument('--epochs', type=int, default=500)
parser.add_argument('--channels', type=int, default=128)
parser.add_argument('--scaling_factor', type=float, default=1.8)
parser.add_argument('--training_rate', type=float, default=0.8) 
args,_ = parser.parse_known_args()

In [4]:
dataset = Planetoid('./datasets/', args.dataset, 'public')

data = dataset[0]
data.x = torch.eye(2708)
data = T.NormalizeFeatures()(data)

In [5]:
class Encoder(torch.nn.Module):
    def __init__(self, in_channels, out_channels, edge_index):
        super(Encoder, self).__init__()
        self.linear1 = nn.Linear(in_channels, out_channels)
        self.linear2 = nn.Linear(in_channels, out_channels)
        self.propagate = APPNP(K=1, alpha=0)

    def forward(self, x, edge_index,not_prop=0):
        if args.model == 'GNAE':
            x = self.linear1(x)
            x = F.normalize(x,p=2,dim=1)  * args.scaling_factor
            x = self.propagate(x, edge_index)
            return x

        if args.model == 'VGNAE':
            x_ = self.linear1(x)
            x_ = self.propagate(x_, edge_index)

            x = self.linear2(x)
            x = F.normalize(x,p=2,dim=1) * args.scaling_factor
            x = self.propagate(x, edge_index)
            return x, x_

        return x

In [6]:
dev = torch.device('cpu')
channels = args.channels
train_rate = args.training_rate
val_ratio = (1-args.training_rate) / 3
test_ratio = (1-args.training_rate) / 3 * 2
data = train_test_split_edges(data.to(dev), val_ratio=val_ratio, test_ratio=test_ratio)

N = int(data.x.size()[0])
if args.model == 'GNAE':   
    model = GAE(Encoder(data.x.size()[1], channels, data.train_pos_edge_index)).to(dev)
if args.model == 'VGNAE':
    model = VGAE(Encoder(data.x.size()[1], channels, data.train_pos_edge_index)).to(dev)

data.train_mask = data.val_mask = data.test_mask = data.y = None
x, train_pos_edge_index = data.x.to(dev), data.train_pos_edge_index.to(dev)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)



In [7]:
def computer_indi(z, pos_edge_index, neg_edge_index, model):
    pos_y = z.new_ones(pos_edge_index.size(1))
    neg_y = z.new_zeros(neg_edge_index.size(1))
    y = torch.cat([pos_y, neg_y], dim=0)

    pos_pred = model.decoder(z, pos_edge_index, sigmoid=True)
    neg_pred = model.decoder(z, neg_edge_index, sigmoid=True)
    pred = torch.cat([pos_pred, neg_pred], dim=0)

    y, pred = y.detach().cpu().numpy(), pred.detach().cpu().numpy()
    
    auc = roc_auc_score(y, pred)
    ap = average_precision_score(y, pred)
    
    pred[pred>0.5] = 1
    pred[pred<=0.5] = 0
    
    acc = accuracy_score(y, pred)
    f1 = f1_score(y, pred, average='macro')

    return ap, auc, acc, f1

In [8]:
def train():
    model.train()
    optimizer.zero_grad()
    z  = model.encode(x, train_pos_edge_index)
    loss = model.recon_loss(z, train_pos_edge_index)
    if args.model in ['VGAE']:
        loss = loss + (1 / data.num_nodes) * model.kl_loss()
    loss.backward()
    optimizer.step()
    return loss

def test(pos_edge_index, neg_edge_index, plot_his=0):
    model.eval()
    with torch.no_grad():
        z = model.encode(x, train_pos_edge_index)
    return computer_indi(z, pos_edge_index, neg_edge_index, model)

In [9]:
for epoch in range(1,args.epochs):
    loss = train()
    loss = float(loss)
    
    with torch.no_grad():
        test_pos, test_neg = data.test_pos_edge_index, data.test_neg_edge_index
        ap, auc, acc, f1 = test(data.test_pos_edge_index, data.test_neg_edge_index)
        print('Epoch: {:03d}, AP: {:.4f}, AUC: {:.4f}, ACC: {:.4f}, F1: {:.4f}'.format(epoch, ap, auc, acc, f1))

Epoch: 001, AP: 0.7557, AUC: 0.7296, ACC: 0.5000, F1: 0.3333
Epoch: 002, AP: 0.7647, AUC: 0.7414, ACC: 0.5000, F1: 0.3333
Epoch: 003, AP: 0.7727, AUC: 0.7513, ACC: 0.5000, F1: 0.3333
Epoch: 004, AP: 0.7808, AUC: 0.7605, ACC: 0.5000, F1: 0.3333
Epoch: 005, AP: 0.7889, AUC: 0.7686, ACC: 0.5000, F1: 0.3333
Epoch: 006, AP: 0.7978, AUC: 0.7762, ACC: 0.5000, F1: 0.3333
Epoch: 007, AP: 0.8070, AUC: 0.7827, ACC: 0.5000, F1: 0.3333
Epoch: 008, AP: 0.8161, AUC: 0.7884, ACC: 0.5000, F1: 0.3333
Epoch: 009, AP: 0.8247, AUC: 0.7927, ACC: 0.5000, F1: 0.3333
Epoch: 010, AP: 0.8317, AUC: 0.7948, ACC: 0.5000, F1: 0.3346
Epoch: 011, AP: 0.8376, AUC: 0.7964, ACC: 0.5007, F1: 0.3386
Epoch: 012, AP: 0.8417, AUC: 0.7962, ACC: 0.5000, F1: 0.3491
Epoch: 013, AP: 0.8447, AUC: 0.7950, ACC: 0.5050, F1: 0.3694
Epoch: 014, AP: 0.8476, AUC: 0.7948, ACC: 0.5185, F1: 0.4046
Epoch: 015, AP: 0.8502, AUC: 0.7951, ACC: 0.5249, F1: 0.4299
Epoch: 016, AP: 0.8519, AUC: 0.7955, ACC: 0.5405, F1: 0.4664
Epoch: 017, AP: 0.8537, 

Epoch: 137, AP: 0.8880, AUC: 0.8345, ACC: 0.6927, F1: 0.6856
Epoch: 138, AP: 0.8882, AUC: 0.8350, ACC: 0.6927, F1: 0.6861
Epoch: 139, AP: 0.8883, AUC: 0.8353, ACC: 0.6920, F1: 0.6855
Epoch: 140, AP: 0.8885, AUC: 0.8356, ACC: 0.6935, F1: 0.6869
Epoch: 141, AP: 0.8887, AUC: 0.8360, ACC: 0.6942, F1: 0.6875
Epoch: 142, AP: 0.8889, AUC: 0.8363, ACC: 0.6942, F1: 0.6875
Epoch: 143, AP: 0.8890, AUC: 0.8365, ACC: 0.6935, F1: 0.6865
Epoch: 144, AP: 0.8891, AUC: 0.8367, ACC: 0.6942, F1: 0.6872
Epoch: 145, AP: 0.8892, AUC: 0.8368, ACC: 0.6949, F1: 0.6880
Epoch: 146, AP: 0.8893, AUC: 0.8371, ACC: 0.6956, F1: 0.6885
Epoch: 147, AP: 0.8895, AUC: 0.8374, ACC: 0.6942, F1: 0.6869
Epoch: 148, AP: 0.8896, AUC: 0.8377, ACC: 0.6927, F1: 0.6855
Epoch: 149, AP: 0.8898, AUC: 0.8379, ACC: 0.6927, F1: 0.6856
Epoch: 150, AP: 0.8898, AUC: 0.8380, ACC: 0.6942, F1: 0.6875
Epoch: 151, AP: 0.8899, AUC: 0.8382, ACC: 0.6949, F1: 0.6883
Epoch: 152, AP: 0.8900, AUC: 0.8383, ACC: 0.6970, F1: 0.6905
Epoch: 153, AP: 0.8902, 

Epoch: 273, AP: 0.8960, AUC: 0.8492, ACC: 0.7084, F1: 0.7030
Epoch: 274, AP: 0.8960, AUC: 0.8493, ACC: 0.7084, F1: 0.7030
Epoch: 275, AP: 0.8961, AUC: 0.8493, ACC: 0.7091, F1: 0.7039
Epoch: 276, AP: 0.8961, AUC: 0.8493, ACC: 0.7084, F1: 0.7030
Epoch: 277, AP: 0.8961, AUC: 0.8493, ACC: 0.7091, F1: 0.7036
Epoch: 278, AP: 0.8961, AUC: 0.8493, ACC: 0.7048, F1: 0.6992
Epoch: 279, AP: 0.8962, AUC: 0.8494, ACC: 0.7041, F1: 0.6983
Epoch: 280, AP: 0.8963, AUC: 0.8496, ACC: 0.7048, F1: 0.6989
Epoch: 281, AP: 0.8964, AUC: 0.8498, ACC: 0.7063, F1: 0.7004
Epoch: 282, AP: 0.8964, AUC: 0.8498, ACC: 0.7070, F1: 0.7012
Epoch: 283, AP: 0.8964, AUC: 0.8499, ACC: 0.7055, F1: 0.6996
Epoch: 284, AP: 0.8964, AUC: 0.8498, ACC: 0.7070, F1: 0.7012
Epoch: 285, AP: 0.8964, AUC: 0.8497, ACC: 0.7091, F1: 0.7034
Epoch: 286, AP: 0.8963, AUC: 0.8497, ACC: 0.7077, F1: 0.7018
Epoch: 287, AP: 0.8963, AUC: 0.8498, ACC: 0.7084, F1: 0.7026
Epoch: 288, AP: 0.8962, AUC: 0.8497, ACC: 0.7063, F1: 0.7001
Epoch: 289, AP: 0.8961, 

Epoch: 409, AP: 0.9000, AUC: 0.8555, ACC: 0.7148, F1: 0.7092
Epoch: 410, AP: 0.9002, AUC: 0.8557, ACC: 0.7155, F1: 0.7099
Epoch: 411, AP: 0.9003, AUC: 0.8559, ACC: 0.7141, F1: 0.7083
Epoch: 412, AP: 0.9005, AUC: 0.8562, ACC: 0.7148, F1: 0.7091
Epoch: 413, AP: 0.9007, AUC: 0.8566, ACC: 0.7148, F1: 0.7090
Epoch: 414, AP: 0.9008, AUC: 0.8568, ACC: 0.7155, F1: 0.7096
Epoch: 415, AP: 0.9009, AUC: 0.8570, ACC: 0.7134, F1: 0.7073
Epoch: 416, AP: 0.9010, AUC: 0.8571, ACC: 0.7155, F1: 0.7096
Epoch: 417, AP: 0.9010, AUC: 0.8572, ACC: 0.7162, F1: 0.7104
Epoch: 418, AP: 0.9011, AUC: 0.8573, ACC: 0.7169, F1: 0.7112
Epoch: 419, AP: 0.9011, AUC: 0.8573, ACC: 0.7169, F1: 0.7113
Epoch: 420, AP: 0.9011, AUC: 0.8574, ACC: 0.7155, F1: 0.7099
Epoch: 421, AP: 0.9011, AUC: 0.8573, ACC: 0.7169, F1: 0.7114
Epoch: 422, AP: 0.9010, AUC: 0.8572, ACC: 0.7169, F1: 0.7115
Epoch: 423, AP: 0.9009, AUC: 0.8571, ACC: 0.7141, F1: 0.7085
Epoch: 424, AP: 0.9008, AUC: 0.8569, ACC: 0.7141, F1: 0.7084
Epoch: 425, AP: 0.9008, 