In [1]:
import numpy as np
import pandas as pd
import pickle as pkl
import json
import scipy.sparse as sp

import networkx as nx
from networkx.readwrite import json_graph

from gae.model import *
from gae.optimizer import *
from gae.utils import *

import torch
from torch import optim
import torch.nn.functional as F

from datetime import datetime, timedelta
import time
import random
from collections import OrderedDict
import warnings; warnings.filterwarnings('ignore')

In [2]:
graphs = ['whole']
obj = []
for graph in graphs:
    with open(graph +'.graph', 'r') as f:
        data = json.load(f)
    obj.append(data)
whole_g = json_graph.node_link_graph(obj[0])
print(len(whole_g.nodes()), len(whole_g.edges()))

330 1564


In [52]:
idx = [i.upper() for i in whole_g.nodes()]

In [4]:
adj = nx.adjacency_matrix(whole_g, nodelist=whole_g.nodes())

In [5]:
idx2nodes = {}
nodes2idx = {}
for idx, node in enumerate(whole_g.nodes()):
    idx2nodes[idx] = node
    nodes2idx[node] = idx

### Model

In [40]:
hidden1 = 32
hidden2 = 16
lr = 0.0001
dropout = 0.
epochs = 200
val_ratio = 0.05

In [41]:
pos_weight = torch.Tensor([float(adj.shape[0] * adj.shape[0] - adj.sum()) / adj.sum()])
norm = adj.shape[0] * adj.shape[0] / float((adj.shape[0] * adj.shape[0] - adj.sum()) * 2)

In [42]:
device = torch.device('cpu')

In [43]:
features = sp.identity(adj.shape[0])
features = torch.FloatTensor(np.array(features.todense()))
n_nodes, feat_dim = features.shape

In [44]:
model_name = 'GAE'

In [45]:
n_iter = 50

In [46]:
def sigmoid(x):
    return 1 / (1 + np.exp(-x))

In [47]:
recon_adjs = np.zeros((adj.shape[0], adj.shape[0]))
for i in range(n_iter):
    if model_name == 'GAE':
        model = GCN_AE(feat_dim, hidden1, hidden2, dropout)
    elif model_name == 'VGAE':
        model = GCN_VAE(feat_dim, hidden1, hidden2, dropout)
        
    while True:
        try:
            train_adj, train_edges, val_edges, val_edges_false = mask_val_edges(adj, val_ratio)
        except AssertionError:
            continue
        break
    
    train_adj_norm = preprocess_graph(train_adj)
    adj_label = train_adj + sp.eye(train_adj.shape[0])
    adj_label = torch.FloatTensor(adj_label.toarray())
    
    train_adj_norm, features, adj_label, pos_weight = train_adj_norm.to(device), features.to(device), adj_label.to(device), pos_weight.to(device)
    
    model= model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    hidden_emb = None
    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()
    
        if model_name == 'GAE':
            recovered = model(features, train_adj_norm)
            loss = loss_function_gae(preds=recovered, labels=adj_label, norm=norm, pos_weight=pos_weight)
        else:
            recovered, mu, logvar = model(features, train_adj_norm)
            loss = loss_function_vgae(preds=recovered, labels=adj_label, mu=mu, logvar=logvar, n_nodes=n_nodes, norm=norm, pos_weight=pos_weight)

        loss.backward()
        cur_loss = loss.item()
        optimizer.step()
 
        hidden_emb = recovered.data.cpu().numpy()
    roc_score, ap_score, recon_adj = roc_ap_score(hidden_emb, adj, val_edges, val_edges_false)
    recon_adjs += sigmoid(recon_adj)
    print('Expriments {} result: val_roc: {:.4f}, val_ap: {:.4f}'.format(i+1, roc_score*100, ap_score*100), end='\n')

Expriments 1 result: val_roc: 78.7804, val_ap: 84.8147
Expriments 2 result: val_roc: 86.0289, val_ap: 86.0692
Expriments 3 result: val_roc: 85.3222, val_ap: 89.2717
Expriments 4 result: val_roc: 86.4727, val_ap: 87.9302
Expriments 5 result: val_roc: 85.8974, val_ap: 88.6267
Expriments 6 result: val_roc: 82.1006, val_ap: 81.8186
Expriments 7 result: val_roc: 86.2426, val_ap: 88.2326
Expriments 8 result: val_roc: 87.5575, val_ap: 88.8548
Expriments 9 result: val_roc: 82.2156, val_ap: 84.6825
Expriments 10 result: val_roc: 76.8409, val_ap: 81.4270
Expriments 11 result: val_roc: 87.8041, val_ap: 88.6845
Expriments 12 result: val_roc: 86.6535, val_ap: 88.1049
Expriments 13 result: val_roc: 81.7390, val_ap: 86.3509
Expriments 14 result: val_roc: 86.9165, val_ap: 86.4673
Expriments 15 result: val_roc: 81.0980, val_ap: 80.4722
Expriments 16 result: val_roc: 87.1137, val_ap: 87.5966
Expriments 17 result: val_roc: 79.4707, val_ap: 81.0751
Expriments 18 result: val_roc: 82.6430, val_ap: 80.9124
E

In [48]:
avg_recon_adjs = recon_adjs/50 - sp.eye(adj.shape[0])

In [49]:
avg_recon_adjs[avg_recon_adjs < 0.9] = 0.

In [50]:
np.count_nonzero(adj.todense())/2

1564.0

In [51]:
np.count_nonzero(avg_recon_adjs + adj)/2

1696.0

In [53]:
pred = pd.DataFrame(avg_recon_adjs + adj, columns=idx, index=idx)

In [56]:
pred

Unnamed: 0,A01K,A01M,A01N,A41D,A42B,A47B,A61B,A61F,A61H,A61K,...,H04M,H04N,H04Q,H04R,H04S,H04W,H05B,H05F,H05H,H05K
A01K,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
A01M,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
A01N,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
A41D,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
A42B,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
H04W,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
H05B,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
H05F,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
H05H,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [57]:
pred_G = nx.from_pandas_adjacency(pred)
pred_G_json = json_graph.node_link_data(pred_G)

In [58]:
# with open('pred.graph', 'w') as f:
#     json.dump(pred_G_json, f)

In [54]:
# pred.to_csv('pred_graph.csv')

In [37]:
ds = pd.read_csv('pred_graph.csv'); np.count_nonzero(ds.values)/2

1884.0