In [15]:
import os.path as osp
import argparse

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch_geometric.datasets import Planetoid, Coauthor, Amazon
from torch_geometric.utils import train_test_split_edges
from torch_geometric.nn import GAE, VGAE, APPNP
import torch_geometric.transforms as T

from sklearn.metrics import roc_auc_score
from sklearn.metrics import average_precision_score

import pandas as pd
import numpy as np
from torch_geometric.data import Data

In [16]:
parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, default='VGNAE')
parser.add_argument('--dataset', type=str, default='email')
parser.add_argument('--epochs', type=int, default=800)
parser.add_argument('--channels', type=int, default=128)
parser.add_argument('--scaling_factor', type=float, default=1.8)
parser.add_argument('--training_rate', type=float, default=0.9) 
args,_ = parser.parse_known_args()

In [17]:
nodes_number = 1133

filename_adj = "datasets/ia-email-univ.mtx"

raw_edges = pd.read_csv(filename_adj, header=None,sep=' ') - 1

drop_self_loop = raw_edges[raw_edges[0]!=raw_edges[1]]

graph_np=np.zeros((nodes_number, nodes_number))

for i in range(drop_self_loop.shape[0]):
    graph_np[drop_self_loop.iloc[i,0],drop_self_loop.iloc[i,1]]=1
    graph_np[drop_self_loop.iloc[i,1],drop_self_loop.iloc[i,0]]=1
    
edges = torch.tensor([list(graph_np.nonzero()[0]),list(graph_np.nonzero()[1])])

features = torch.eye(nodes_number)

data = Data(x=features, edge_index=edges)

data = T.NormalizeFeatures()(data)

In [18]:
class Encoder(torch.nn.Module):
    def __init__(self, in_channels, out_channels, edge_index):
        super(Encoder, self).__init__()
        self.linear1 = nn.Linear(in_channels, out_channels)
        self.linear2 = nn.Linear(in_channels, out_channels)
        self.propagate = APPNP(K=1, alpha=0)

    def forward(self, x, edge_index,not_prop=0):
        if args.model == 'GNAE':
            x = self.linear1(x)
            x = F.normalize(x,p=2,dim=1)  * args.scaling_factor
            x = self.propagate(x, edge_index)
            return x

        if args.model == 'VGNAE':
            x_ = self.linear1(x)
            x_ = self.propagate(x_, edge_index)

            x = self.linear2(x)
            x = F.normalize(x,p=2,dim=1) * args.scaling_factor
            x = self.propagate(x, edge_index)
            return x, x_

        return x

In [19]:
dev = torch.device('cpu')
channels = args.channels
train_rate = args.training_rate
val_ratio = (1-args.training_rate) / 3
test_ratio = (1-args.training_rate) / 3 * 2
data = train_test_split_edges(data.to(dev), val_ratio=val_ratio, test_ratio=test_ratio)

N = int(data.x.size()[0])
if args.model == 'GNAE':   
    model = GAE(Encoder(data.x.size()[1], channels, data.train_pos_edge_index)).to(dev)
if args.model == 'VGNAE':
    model = VGAE(Encoder(data.x.size()[1], channels, data.train_pos_edge_index)).to(dev)

data.train_mask = data.val_mask = data.test_mask = data.y = None
x, train_pos_edge_index = data.x.to(dev), data.train_pos_edge_index.to(dev)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)



In [20]:
def computer_indi(z, pos_edge_index, neg_edge_index, model):
    pos_y = z.new_ones(pos_edge_index.size(1))
    neg_y = z.new_zeros(neg_edge_index.size(1))
    y = torch.cat([pos_y, neg_y], dim=0)

    pos_pred = model.decoder(z, pos_edge_index, sigmoid=True)
    neg_pred = model.decoder(z, neg_edge_index, sigmoid=True)
    pred = torch.cat([pos_pred, neg_pred], dim=0)

    y, pred = y.detach().cpu().numpy(), pred.detach().cpu().numpy()
    
    auc = roc_auc_score(y, pred)
    ap = average_precision_score(y, pred)

    return ap, auc

In [21]:
def train():
    model.train()
    optimizer.zero_grad()
    z  = model.encode(x, train_pos_edge_index)
    loss = model.recon_loss(z, train_pos_edge_index)
    if args.model in ['VGAE']:
        loss = loss + (1 / data.num_nodes) * model.kl_loss()
    loss.backward()
    optimizer.step()
    return loss

def test(pos_edge_index, neg_edge_index, plot_his=0):
    model.eval()
    with torch.no_grad():
        z = model.encode(x, train_pos_edge_index)
    return computer_indi(z, pos_edge_index, neg_edge_index, model)

In [22]:
for epoch in range(1,args.epochs):
    loss = train()
    loss = float(loss)
    
    with torch.no_grad():
        test_pos, test_neg = data.test_pos_edge_index, data.test_neg_edge_index
        ap, auc = test(data.test_pos_edge_index, data.test_neg_edge_index)
        print('Epoch: {:03d}, AP: {:.4f}, AUC: {:.4f}'.format(epoch, ap, auc))

Epoch: 001, AP: 0.7597, AUC: 0.7696
Epoch: 002, AP: 0.7661, AUC: 0.7762
Epoch: 003, AP: 0.7703, AUC: 0.7805
Epoch: 004, AP: 0.7730, AUC: 0.7835
Epoch: 005, AP: 0.7756, AUC: 0.7863
Epoch: 006, AP: 0.7787, AUC: 0.7893
Epoch: 007, AP: 0.7817, AUC: 0.7921
Epoch: 008, AP: 0.7851, AUC: 0.7949
Epoch: 009, AP: 0.7874, AUC: 0.7974
Epoch: 010, AP: 0.7907, AUC: 0.8002
Epoch: 011, AP: 0.7937, AUC: 0.8034
Epoch: 012, AP: 0.7975, AUC: 0.8064
Epoch: 013, AP: 0.8008, AUC: 0.8095
Epoch: 014, AP: 0.8039, AUC: 0.8119
Epoch: 015, AP: 0.8073, AUC: 0.8147
Epoch: 016, AP: 0.8109, AUC: 0.8173
Epoch: 017, AP: 0.8142, AUC: 0.8197
Epoch: 018, AP: 0.8179, AUC: 0.8213
Epoch: 019, AP: 0.8216, AUC: 0.8230
Epoch: 020, AP: 0.8260, AUC: 0.8248
Epoch: 021, AP: 0.8293, AUC: 0.8262
Epoch: 022, AP: 0.8332, AUC: 0.8282
Epoch: 023, AP: 0.8376, AUC: 0.8306
Epoch: 024, AP: 0.8422, AUC: 0.8333
Epoch: 025, AP: 0.8472, AUC: 0.8364
Epoch: 026, AP: 0.8508, AUC: 0.8389
Epoch: 027, AP: 0.8543, AUC: 0.8411
Epoch: 028, AP: 0.8577, AUC:

Epoch: 230, AP: 0.8829, AUC: 0.8658
Epoch: 231, AP: 0.8830, AUC: 0.8661
Epoch: 232, AP: 0.8832, AUC: 0.8663
Epoch: 233, AP: 0.8832, AUC: 0.8664
Epoch: 234, AP: 0.8833, AUC: 0.8665
Epoch: 235, AP: 0.8834, AUC: 0.8666
Epoch: 236, AP: 0.8835, AUC: 0.8668
Epoch: 237, AP: 0.8836, AUC: 0.8668
Epoch: 238, AP: 0.8837, AUC: 0.8669
Epoch: 239, AP: 0.8838, AUC: 0.8671
Epoch: 240, AP: 0.8839, AUC: 0.8673
Epoch: 241, AP: 0.8839, AUC: 0.8672
Epoch: 242, AP: 0.8839, AUC: 0.8673
Epoch: 243, AP: 0.8838, AUC: 0.8672
Epoch: 244, AP: 0.8839, AUC: 0.8672
Epoch: 245, AP: 0.8841, AUC: 0.8671
Epoch: 246, AP: 0.8839, AUC: 0.8668
Epoch: 247, AP: 0.8838, AUC: 0.8668
Epoch: 248, AP: 0.8837, AUC: 0.8667
Epoch: 249, AP: 0.8836, AUC: 0.8665
Epoch: 250, AP: 0.8834, AUC: 0.8663
Epoch: 251, AP: 0.8833, AUC: 0.8662
Epoch: 252, AP: 0.8833, AUC: 0.8661
Epoch: 253, AP: 0.8832, AUC: 0.8659
Epoch: 254, AP: 0.8830, AUC: 0.8657
Epoch: 255, AP: 0.8828, AUC: 0.8654
Epoch: 256, AP: 0.8828, AUC: 0.8654
Epoch: 257, AP: 0.8829, AUC:

Epoch: 458, AP: 0.8938, AUC: 0.8759
Epoch: 459, AP: 0.8937, AUC: 0.8758
Epoch: 460, AP: 0.8938, AUC: 0.8758
Epoch: 461, AP: 0.8937, AUC: 0.8756
Epoch: 462, AP: 0.8937, AUC: 0.8754
Epoch: 463, AP: 0.8938, AUC: 0.8754
Epoch: 464, AP: 0.8936, AUC: 0.8751
Epoch: 465, AP: 0.8935, AUC: 0.8750
Epoch: 466, AP: 0.8936, AUC: 0.8749
Epoch: 467, AP: 0.8935, AUC: 0.8748
Epoch: 468, AP: 0.8936, AUC: 0.8748
Epoch: 469, AP: 0.8936, AUC: 0.8748
Epoch: 470, AP: 0.8936, AUC: 0.8749
Epoch: 471, AP: 0.8937, AUC: 0.8750
Epoch: 472, AP: 0.8940, AUC: 0.8753
Epoch: 473, AP: 0.8941, AUC: 0.8756
Epoch: 474, AP: 0.8941, AUC: 0.8757
Epoch: 475, AP: 0.8942, AUC: 0.8757
Epoch: 476, AP: 0.8943, AUC: 0.8759
Epoch: 477, AP: 0.8944, AUC: 0.8760
Epoch: 478, AP: 0.8946, AUC: 0.8763
Epoch: 479, AP: 0.8947, AUC: 0.8763
Epoch: 480, AP: 0.8947, AUC: 0.8763
Epoch: 481, AP: 0.8948, AUC: 0.8764
Epoch: 482, AP: 0.8948, AUC: 0.8764
Epoch: 483, AP: 0.8948, AUC: 0.8763
Epoch: 484, AP: 0.8947, AUC: 0.8762
Epoch: 485, AP: 0.8947, AUC:

Epoch: 687, AP: 0.8964, AUC: 0.8775
Epoch: 688, AP: 0.8965, AUC: 0.8777
Epoch: 689, AP: 0.8966, AUC: 0.8778
Epoch: 690, AP: 0.8967, AUC: 0.8778
Epoch: 691, AP: 0.8967, AUC: 0.8779
Epoch: 692, AP: 0.8968, AUC: 0.8781
Epoch: 693, AP: 0.8970, AUC: 0.8784
Epoch: 694, AP: 0.8971, AUC: 0.8786
Epoch: 695, AP: 0.8973, AUC: 0.8787
Epoch: 696, AP: 0.8974, AUC: 0.8788
Epoch: 697, AP: 0.8975, AUC: 0.8789
Epoch: 698, AP: 0.8976, AUC: 0.8789
Epoch: 699, AP: 0.8976, AUC: 0.8788
Epoch: 700, AP: 0.8976, AUC: 0.8786
Epoch: 701, AP: 0.8974, AUC: 0.8784
Epoch: 702, AP: 0.8974, AUC: 0.8783
Epoch: 703, AP: 0.8975, AUC: 0.8783
Epoch: 704, AP: 0.8974, AUC: 0.8781
Epoch: 705, AP: 0.8975, AUC: 0.8781
Epoch: 706, AP: 0.8973, AUC: 0.8780
Epoch: 707, AP: 0.8973, AUC: 0.8779
Epoch: 708, AP: 0.8972, AUC: 0.8779
Epoch: 709, AP: 0.8971, AUC: 0.8779
Epoch: 710, AP: 0.8972, AUC: 0.8780
Epoch: 711, AP: 0.8971, AUC: 0.8780
Epoch: 712, AP: 0.8970, AUC: 0.8781
Epoch: 713, AP: 0.8970, AUC: 0.8781
Epoch: 714, AP: 0.8969, AUC: