In [10]:
from torch_geometric.datasets import Planetoid
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.module import Module
from torch.nn.parameter import Parameter 
import numpy as np
from sklearn.manifold import TSNE 

In [5]:
# Load the Cora citation dataset through torch_geometric's Planetoid
# loader (stored under /tmp/Cora) and take its first graph object.
dataset = Planetoid(name='Cora', root='/tmp/Cora')
data = dataset[0]

In [80]:
class GraphAE(Module):
    """Graph convolutional auto-encoder with an inner-product decoder.

    Encoder:  H = act_func(Â X W_0),   Z = Â H W_mu
    Decoder:  output = sigmoid(Z Z^T)

    where Â = D^{-1/2} Ã D^{-1/2}, Ã = A + I if ``renormalize`` else A, and
    D holds the column sums (degrees) of Ã.

    Args:
        X: feature matrix; only ``X.shape[1]`` (the input feature width) is
            read here. Stored unchanged as ``self.X``.
        A: square adjacency matrix, anything ``np.array`` can convert
            (numpy array, nested list, CPU tensor). Stored unchanged as
            ``self.A``; a float64 copy is normalised, the original is never
            modified.
        hid_dim: width of the hidden graph-convolution layer.
        embed_dim: width of the node embeddings Z.
        renormalize: if True, add self-loops (A + I) before normalising.
        act_func: activation applied after the first layer.
    """

    def __init__(self, X, A, hid_dim=32, embed_dim=16,
                 renormalize=False, act_func=F.relu):
        super().__init__()
        self.X = X
        self.A = A
        self.act_func = act_func

        self.in_feat = X.shape[1]
        # Uninitialised storage; reset_parameters() fills it in below.
        self.weight_0 = Parameter(
            torch.empty(self.in_feat, hid_dim, dtype=torch.float32))
        self.weight_mu = Parameter(
            torch.empty(hid_dim, embed_dim, dtype=torch.float32))

        # A buffer (not a plain attribute) so that model.to(device) / .cuda()
        # moves it together with the weights; it is also saved in state_dict.
        self.register_buffer(
            "norm_adj",
            torch.tensor(self._normalize_adj(A, renormalize),
                         dtype=torch.float32),
        )

        self.reset_parameters()

    @staticmethod
    def _normalize_adj(A, renormalize=False):
        """Return D^{-1/2} Ã D^{-1/2} as a float64 numpy array.

        Nodes with zero degree get a zero scaling factor instead of the
        ``inf`` that ``0 ** -0.5`` would give (which turns into NaN rows
        after multiplication), so isolated nodes get all-zero rows/columns.
        """
        # Float copy: never mutates the caller's A, and lets `+= eye` work
        # even when A has an integer dtype.
        adj = np.array(A, dtype=np.float64)
        if renormalize:
            adj += np.eye(adj.shape[0])
        deg = adj.sum(axis=0)
        d_inv_sqrt = np.zeros_like(deg)
        nonzero = deg > 0
        d_inv_sqrt[nonzero] = deg[nonzero] ** -0.5
        # Same as diag(d) @ adj @ diag(d), but via broadcasting: O(n^2)
        # instead of two dense O(n^3) matrix products.
        return d_inv_sqrt[:, None] * adj * d_inv_sqrt[None, :]

    def reset_parameters(self):
        """Re-initialise both weight matrices with Xavier-normal values."""
        torch.nn.init.xavier_normal_(self.weight_0)
        torch.nn.init.xavier_normal_(self.weight_mu)

    def forward(self, input_):
        """Encode node features and decode pairwise edge probabilities.

        Args:
            input_: feature matrix with one row per node of the graph that
                built ``norm_adj`` and ``in_feat`` columns.

        Returns:
            Tuple ``(mu_embed_, output)``: the node embeddings Z
            (num_nodes x embed_dim) and ``sigmoid(Z Z^T)``
            (num_nodes x num_nodes).
        """
        # torch.mm accepts dense and sparse-COO left operands, so it covers
        # everything torch.spmm was being used for here.
        hid_ = self.act_func(
            torch.mm(self.norm_adj, torch.mm(input_, self.weight_0)))
        mu_embed_ = torch.mm(self.norm_adj, torch.mm(hid_, self.weight_mu))

        output = torch.sigmoid(torch.mm(mu_embed_, mu_embed_.t()))
        return mu_embed_, output
        

In [81]:
# Display the graph object's summary (edge_index, masks, x, y shapes).
data

Data(edge_index=[2, 10556], test_mask=[2708], train_mask=[2708], val_mask=[2708], x=[2708, 1433], y=[2708])

In [82]:
# Construct the node-feature matrix X and the dense, symmetric adjacency
# matrix A (float64, one row/column per node) from the COO edge list.
X = data.x
A = np.zeros((X.shape[0], X.shape[0]))
src, dst = data.edge_index.numpy()
# Vectorised scatter instead of a per-edge Python loop; writing both
# (src, dst) and (dst, src) keeps A symmetric even if an edge is listed once.
A[src, dst] = 1
A[dst, src] = 1

In [83]:
# Seed torch's RNG so the Xavier-initialised weights are reproducible
# across kernel restarts.
torch.manual_seed(0)
model = GraphAE(X, A)

In [84]:
# Run one forward pass over all nodes; `a` is the tuple returned by
# GraphAE.forward: (node embeddings, sigmoid(Z Z^T) reconstruction).
a = model(X)

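# TODO
1. Training function (optimisation loop over `model`)
2. Loss function (e.g. a reconstruction loss comparing the decoder `output` with `A`)
3. Plotting function (e.g. t-SNE of the learned embeddings; `TSNE` is already imported)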