In [14]:
import gust
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from scipy.spatial.distance import squareform

In [15]:
torch.set_default_tensor_type('torch.cuda.FloatTensor')

A, X, _, z = gust.load_dataset('cora_ml').standardize().unpack()
adj = torch.FloatTensor(A.toarray()).cuda()

In [16]:
N = A.shape[0]
D = 32

Z = nn.Parameter(torch.empty(N, D).normal_(std=0.1))
opt = torch.optim.Adam([Z], lr=1e-2)
e1, e2 = A.nonzero()

In [17]:
def sig(Z, b=0.1, eps=1e-8): 
    dist = torch.matmul(Z,Z.T) +b
    sigdist = 1/(1+torch.exp(dist+eps)+eps)
    logsigdist = torch.log(sigdist+eps)
    pos_term = logsigdist[e1,e2]
    neg_term = torch.log(1-sigdist)
    neg_term[np.diag_indices(N)] = 0.0
    
    return -(pos_term.sum() + neg_term.sum()) / Z.shape[0]**2

In [18]:
def dist(Z, eps=1e-5):
    dist = ((Z[:, None] - Z[None, :]).pow(2.0).sum(-1) + eps).sqrt()
    neg_term = torch.log(-torch.expm1(-dist) + eps)
    neg_term[np.diag_indices(N)] = 0.0
    pos_term = -dist[e1, e2]   
    neg_term[e1, e2] = 0.0
    
    return -(pos_term.sum() + neg_term.sum()) / Z.shape[0]**2

In [19]:
for epoch in range(100):
    opt.zero_grad()
    loss = sig(Z)
    loss.backward()
    opt.step()
    print(loss.item())

0.6460767388343811
0.6446844339370728
0.6415046453475952
0.6360921859741211
0.6282792091369629
0.617963433265686
0.6050869226455688
0.5896317362785339
0.5716220736503601
0.5511276125907898
0.528266966342926
0.5032113790512085
0.4761869013309479
0.4474751949310303
0.4174121916294098
0.3863829970359802
0.35481417179107666
0.3231607973575592
0.2918904423713684
0.2614634037017822
0.23231090605258942
0.2048133760690689
0.1792810708284378
0.15593978762626648
0.1349232792854309
0.11627373844385147
0.09994973242282867
0.0858401507139206
0.07378186285495758
0.063578300178051
0.05501674860715866
0.04788283258676529
0.04197130352258682
0.03709319606423378
0.03307977691292763
0.029783945530653
0.027079854160547256
0.02486143261194229
0.02304030768573284
0.02154347486793995
0.020311031490564346
0.019294053316116333
0.018452702090144157
0.01775461435317993
0.0171735230833292
0.01668810471892357
0.016281068325042725
0.01593836210668087
0.01564856432378292
0.015402384102344513
0.015192242339253426
0.0