In [25]:
from __future__ import division
from __future__ import print_function

import argparse
import time

import numpy as np
import scipy.sparse as sp
import torch
from torch import optim

from gae.model import GCNModelVAE
from gae.optimizer import loss_function
from gae.utils import load_data, mask_test_edges, preprocess_graph, get_roc_score

In [26]:
# Experiment configuration. Hyper-parameters are declared as command-line
# options so the same defaults work both from a script and from a notebook.
parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, default='gcn_vae', help="models used")
parser.add_argument('--seed', type=int, default=42, help='Random seed.')
parser.add_argument('--epochs', type=int, default=200, help='Number of epochs to train.')
parser.add_argument('--hidden1', type=int, default=32, help='Number of units in hidden layer 1.')
parser.add_argument('--hidden2', type=int, default=16, help='Number of units in hidden layer 2.')
parser.add_argument('--lr', type=float, default=0.01, help='Initial learning rate.')
parser.add_argument('--dropout', type=float, default=0., help='Dropout rate (1 - keep probability).')
# argparse exposes this option as the attribute `args.dataset_str`.
parser.add_argument('--dataset-str', type=str, default='cora', help='type of dataset.')

# parse_known_args() (rather than parse_args()) tolerates argv entries that are
# not declared above, e.g. the options a Jupyter kernel is launched with.
# The leftover list is bound to a descriptive private name instead of `_`:
# rebinding `_` inside an IPython session stops IPython from updating its
# `_` / `__` / `___` output-history variables.
args, _unparsed_argv = parser.parse_known_args()

In [27]:
def gae_for(args):
    """Train a GCN variational auto-encoder for link prediction and report scores.

    The graph named by ``args.dataset_str`` is loaded, part of its edges is
    held out for validation/testing via ``mask_test_edges``, and a
    ``GCNModelVAE`` is trained full-batch on the remaining edges. Node
    features are a one-hot identity matrix (featureless setting); any features
    returned by ``load_data`` are intentionally ignored.

    Args:
        args: Namespace providing ``dataset_str``, ``hidden1``, ``hidden2``,
            ``dropout``, ``lr`` and ``epochs``. If it also has a non-``None``
            ``seed`` attribute, the NumPy and PyTorch global RNGs are seeded
            with it before any data is loaded or split.

    Returns:
        tuple: ``(roc_score, ap_score)`` as returned by ``get_roc_score`` on the
        held-out test edges. The same values are printed.
    """
    print("Using {} dataset".format(args.dataset_str))

    # Apply the configured seed so that repeated runs are reproducible.
    seed = getattr(args, 'seed', None)
    if seed is not None:
        np.random.seed(seed)
        torch.manual_seed(seed)

    adj, _ = load_data(args.dataset_str)

    # Identity features, sized from the loaded graph rather than a hard-coded
    # node count so that datasets other than Cora work as well.
    features = torch.eye(adj.shape[0])
    n_nodes, feat_dim = features.shape

    # Store original adjacency matrix (without diagonal entries) for later
    adj_orig = adj
    adj_orig = adj_orig - sp.dia_matrix((adj_orig.diagonal()[np.newaxis, :], [0]), shape=adj_orig.shape)
    adj_orig.eliminate_zeros()

    adj_train, train_edges, val_edges, val_edges_false, test_edges, test_edges_false = mask_test_edges(adj)
    adj = adj_train

    # Some preprocessing
    adj_norm = preprocess_graph(adj)
    # Reconstruction target: training adjacency plus self-loops, as a dense tensor.
    adj_label = adj_train + sp.eye(adj_train.shape[0])
    adj_label = torch.FloatTensor(adj_label.toarray())

    # Class re-weighting derived from the number of stored entries in the
    # training adjacency versus all N*N node pairs.
    n_pairs = adj.shape[0] * adj.shape[0]
    adj_total = adj.sum()
    pos_weight = torch.tensor(float(n_pairs - adj_total) / adj_total)
    norm = n_pairs / float((n_pairs - adj_total) * 2)

    model = GCNModelVAE(feat_dim, args.hidden1, args.hidden2, args.dropout)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    for epoch in range(args.epochs):
        t = time.time()
        model.train()
        optimizer.zero_grad()
        recovered, mu, logvar = model(features, adj_norm)

        loss = loss_function(preds=recovered, labels=adj_label,
                             mu=mu, logvar=logvar, n_nodes=n_nodes,
                             norm=norm, pos_weight=pos_weight)
        loss.backward()
        cur_loss = loss.item()
        optimizer.step()

        # Validation reuses the embeddings of this epoch's training forward
        # pass, so no extra forward pass is needed per epoch.
        hidden_emb = mu.detach().cpu().numpy()
        roc_curr, ap_curr = get_roc_score(hidden_emb, adj_orig, val_edges, val_edges_false)

        print("Epoch:", '%04d' % (epoch + 1), "train_loss=", "{:.5f}".format(cur_loss),
              "val_ap=", "{:.5f}".format(ap_curr),
              "time=", "{:.5f}".format(time.time() - t)
              )

    print("Optimization Finished!")

    # Score the test edges with embeddings from the final parameters, with the
    # model in evaluation mode and without gradient tracking. This also keeps
    # the function working when args.epochs is 0 (no training pass ever ran).
    model.eval()
    with torch.no_grad():
        _, mu, _ = model(features, adj_norm)
    hidden_emb = mu.cpu().numpy()

    roc_score, ap_score = get_roc_score(hidden_emb, adj_orig, test_edges, test_edges_false)
    print('Test ROC score: ' + str(roc_score))
    print('Test AP score: ' + str(ap_score))
    return roc_score, ap_score

In [28]:
# Entry point: run the experiment with the parsed configuration, but only when
# this code is executed directly rather than imported as a module.
if __name__ == "__main__":
    gae_for(args=args)

Using cora dataset
Epoch: 0001 train_loss= 1.73497 val_ap= 0.64207 time= 0.19647
Epoch: 0002 train_loss= 1.69780 val_ap= 0.70084 time= 0.20346
Epoch: 0003 train_loss= 1.68353 val_ap= 0.71687 time= 0.20844
Epoch: 0004 train_loss= 1.68172 val_ap= 0.72343 time= 0.20046
Epoch: 0005 train_loss= 1.59166 val_ap= 0.71836 time= 0.20345
Epoch: 0006 train_loss= 1.57900 val_ap= 0.71662 time= 0.21243
Epoch: 0007 train_loss= 1.54107 val_ap= 0.71319 time= 0.19857
Epoch: 0008 train_loss= 1.45779 val_ap= 0.70946 time= 0.20054
Epoch: 0009 train_loss= 1.43472 val_ap= 0.70605 time= 0.19643
Epoch: 0010 train_loss= 1.36174 val_ap= 0.70362 time= 0.21144
Epoch: 0011 train_loss= 1.31118 val_ap= 0.70000 time= 0.20351
Epoch: 0012 train_loss= 1.25039 val_ap= 0.69795 time= 0.19775
Epoch: 0013 train_loss= 1.17574 val_ap= 0.69556 time= 0.21543
Epoch: 0014 train_loss= 1.11304 val_ap= 0.69392 time= 0.19847
Epoch: 0015 train_loss= 1.05155 val_ap= 0.69340 time= 0.20744
Epoch: 0016 train_loss= 1.00844 val_ap= 0.69263 tim

Epoch: 0133 train_loss= 0.44096 val_ap= 0.88232 time= 0.20734
Epoch: 0134 train_loss= 0.44085 val_ap= 0.88194 time= 0.20722
Epoch: 0135 train_loss= 0.44029 val_ap= 0.88182 time= 0.20693
Epoch: 0136 train_loss= 0.44036 val_ap= 0.88225 time= 0.19961
Epoch: 0137 train_loss= 0.44008 val_ap= 0.88292 time= 0.18697
Epoch: 0138 train_loss= 0.43963 val_ap= 0.88288 time= 0.22749
Epoch: 0139 train_loss= 0.43921 val_ap= 0.88285 time= 0.23279
Epoch: 0140 train_loss= 0.43905 val_ap= 0.88270 time= 0.23974
Epoch: 0141 train_loss= 0.43867 val_ap= 0.88229 time= 0.20556
Epoch: 0142 train_loss= 0.43822 val_ap= 0.88218 time= 0.18953
Epoch: 0143 train_loss= 0.43811 val_ap= 0.88232 time= 0.18761
Epoch: 0144 train_loss= 0.43791 val_ap= 0.88291 time= 0.19132
Epoch: 0145 train_loss= 0.43739 val_ap= 0.88283 time= 0.18752
Epoch: 0146 train_loss= 0.43691 val_ap= 0.88269 time= 0.19491
Epoch: 0147 train_loss= 0.43685 val_ap= 0.88258 time= 0.19484
Epoch: 0148 train_loss= 0.43656 val_ap= 0.88293 time= 0.19448
Epoch: 0