# Experimenting with VGAE Code

Code source: https://github.com/DaehanKim/vgae_pytorch
Paper reference: "Variational Graph Auto-Encoders" by Thomas N. Kipf and Max Welling, 2016

## To figure out: 
- [ ] how do they pre-process their data? What form does their input data take? 
- [ ] How do GAE and VGAE work? Can I implement them?

In [34]:
!pip install networkx

Collecting networkx
  Downloading networkx-2.6.2-py3-none-any.whl (1.9 MB)
[K     |████████████████████████████████| 1.9 MB 29.0 MB/s eta 0:00:01
[?25hInstalling collected packages: networkx
Successfully installed networkx-2.6.2


In [35]:
import torch
import torch.nn.functional as F
from torch.optim import Adam
from sklearn.metrics import roc_auc_score, average_precision_score
import scipy.sparse as sp
import numpy as np
import os
import time
from pyprojroot import here
import sys
import pickle as pkl
import networkx as nx

# from input_data import load_data
# from preprocessing import *
# import args
# import model

In [36]:
# Resolve the project root with pyprojroot, using the ".here" marker file.
root = here(project_files=[".here"])
# Put the project root on the import path so modules stored there can be imported.
# The membership check keeps the cell idempotent: re-running it no longer stacks
# duplicate entries on sys.path.
if str(root) not in sys.path:
    sys.path.append(str(root))

In [21]:
# Show the resolved project root; load_data reads its files from "<root>/data".
print(root)

/home/emma/vgae_pytorch


In [37]:
def parse_index_file(filename):
    """Read a text file holding one integer per line and return them as a list.

    Parameters
    ----------
    filename : str or path-like
        Path of the index file to read.

    Returns
    -------
    list of int
        The parsed integers, in the order they appear in the file.

    Raises
    ------
    ValueError
        If a non-blank line cannot be parsed as an integer.
    """
    # The context manager closes the handle deterministically; a bare open()
    # in the loop header would leave that to the garbage collector.
    with open(filename) as index_file:
        # Blank lines (e.g. an extra newline at the end of the file) carry no
        # index, so they are skipped instead of crashing on int('').
        return [int(entry) for entry in map(str.strip, index_file) if entry]

In [47]:
def load_data(dataset):
    """Load a dataset's graph and node features from ``<root>/data``.

    Reads the pickled files ``ind.<dataset>.x``, ``.tx``, ``.allx`` and
    ``.graph`` plus the text file ``ind.<dataset>.test.index``, all located
    under the module-level ``root`` directory.

    Parameters
    ----------
    dataset : str
        Dataset name as it appears in the file names (e.g. 'cora').
        'citeseer' receives extra padding of the test features.

    Returns
    -------
    adj
        Sparse adjacency matrix produced by ``nx.adjacency_matrix`` from the
        ``graph`` dict of lists.
    features : scipy.sparse.lil_matrix
        ``allx`` stacked on top of ``tx``, with the test rows reordered.
    """
    # NOTE(security): unpickling can execute arbitrary code; only load trusted files.
    # On Python 3 the pickles are decoded as latin1 (presumably because they
    # were written under Python 2 -- TODO confirm).
    if sys.version_info > (3, 0):
        unpickle_kwargs = {'encoding': 'latin1'}
    else:
        unpickle_kwargs = {}

    def _unpickle(part):
        # Load one of the ind.<dataset>.<part> pickles.
        with open("{}/data/ind.{}.{}".format(root, dataset, part), 'rb') as fh:
            return pkl.load(fh, **unpickle_kwargs)

    x, tx, allx, graph = [_unpickle(part) for part in ('x', 'tx', 'allx', 'graph')]

    # graph is a dict (a collections.defaultdict): each key is a node and each
    # value is the list of nodes adjacent to it.
    print('graph', type(graph))

    # Test-node indices in file order, and the same indices sorted ascending.
    shuffled_test_idx = parse_index_file("{}/data/ind.{}.test.index".format(root, dataset))
    sorted_test_idx = np.sort(shuffled_test_idx)

    if dataset == 'citeseer':
        # Citeseer has isolated nodes that are missing from the test index.
        # Allocate one row per index in the full min..max span and copy tx into
        # the rows that do exist; the remaining rows stay as zero vectors.
        full_span = range(min(shuffled_test_idx), max(shuffled_test_idx) + 1)
        tx_padded = sp.lil_matrix((len(full_span), x.shape[1]))
        tx_padded[sorted_test_idx - min(sorted_test_idx), :] = tx
        tx = tx_padded

    # Stack all features, then move the test rows to the positions given by
    # the (unsorted) index file.
    features = sp.vstack((allx, tx)).tolil()
    features[shuffled_test_idx, :] = features[sorted_test_idx, :]

    adj = nx.adjacency_matrix(nx.from_dict_of_lists(graph))
    return adj, features

In [50]:
# The dataset name selects which ind.<dataset>.* files load_data reads.
dataset = 'cora'
adj, features = load_data(dataset)

# Quick sanity check of the adjacency matrix: its dimensions and its concrete type.
print(adj.shape)
print(type(adj))
# print(adj)

# adj is a SciPy sparse matrix built from the `graph` dict inside load_data,
# so it carries the same node-adjacency information in matrix form.

graph <class 'collections.defaultdict'>
(2708, 2708)
<class 'scipy.sparse.csr.csr_matrix'>


In [55]:
# Keep the original adjacency matrix, minus its diagonal entries, for later.
# TODO: look up the scipy sparse matrix format
# Subtracting a matrix that holds only adj's main diagonal zeroes that diagonal.
# The subtraction builds a new matrix, so `adj` itself is left unchanged.
adj_orig = adj - sp.dia_matrix((adj.diagonal()[np.newaxis, :], [0]), shape=adj.shape)
# Drop any explicitly stored zero entries, in place.
adj_orig.eliminate_zeros()

In [56]:
# TODO: understand these functions and what they're doing
def sparse_to_tuple(sparse_mx):
    """Break a SciPy sparse matrix into a ``(coords, values, shape)`` triple.

    Parameters
    ----------
    sparse_mx : scipy sparse matrix
        Matrix in any sparse format; it is converted to COO if it is not
        already a COO matrix.

    Returns
    -------
    coords : numpy.ndarray
        One ``[row, col]`` pair per stored entry.
    values : numpy.ndarray
        The stored entries, in the same order as ``coords``.
    shape : tuple
        Shape of the matrix.
    """
    coo = sparse_mx if sp.isspmatrix_coo(sparse_mx) else sparse_mx.tocoo()
    # Stack the row and column index arrays, then flip so each line is (row, col).
    index_pairs = np.vstack([coo.row, coo.col]).T
    return index_pairs, coo.data, coo.shape

def preprocess_graph(adj):
    """Add self-loops to ``adj`` and scale it by the inverse square-root degrees.

    With ``A_ = adj + I`` and ``D`` the diagonal matrix of the row sums of
    ``A_``, this computes ``(A_ D^-1/2)^T D^-1/2``, which equals
    ``D^-1/2 A_ D^-1/2`` when ``adj`` is symmetric.

    Parameters
    ----------
    adj : scipy sparse matrix or array-like
        Square adjacency matrix; anything accepted by ``sp.coo_matrix``.

    Returns
    -------
    tuple
        ``(coords, values, shape)`` of the normalized matrix, as produced by
        ``sparse_to_tuple``.
    """
    adj_coo = sp.coo_matrix(adj)
    n_nodes = adj_coo.shape[0]
    # Self-loops: every node is counted as its own neighbour.
    adj_with_loops = adj_coo + sp.eye(n_nodes)
    # Row sums of the self-looped matrix act as the node degrees.
    degrees = np.array(adj_with_loops.sum(1))
    inv_sqrt_degree = sp.diags(np.power(degrees, -0.5).flatten())
    # Scale the columns, transpose, then scale the columns again.
    column_scaled = adj_with_loops.dot(inv_sqrt_degree)
    normalized = column_scaled.transpose().dot(inv_sqrt_degree).tocoo()
    return sparse_to_tuple(normalized)

In [63]:
# Normalize the adjacency matrix. preprocess_graph returns a 3-tuple
# (coords, values, shape) rather than a sparse matrix object.
adj_norm = preprocess_graph(adj)
print(len(adj_norm))
print(adj_norm[0]) # coords: one [row, col] pair per stored entry
print(adj_norm[1]) # values: the normalized entries, in the same order as coords
print(adj_norm[2]) # shape of the normalized matrix

3
[[   0    0]
 [ 633    0]
 [1862    0]
 ...
 [1473 2707]
 [2706 2707]
 [2707 2707]]
[0.25      0.25      0.2236068 ... 0.2       0.2       0.2      ]
(2708, 2708)
