In [34]:
import gust
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from scipy.spatial.distance import squareform

In [35]:
torch.set_default_tensor_type('torch.cuda.FloatTensor')

A, X, _, z = gust.load_dataset('cora_ml').standardize().unpack()
adj = torch.FloatTensor(A.toarray()).cuda()

In [36]:
N = A.shape[0]
D = 32

Z = nn.Parameter(torch.empty(N, D).normal_(std=0.1))
opt = torch.optim.Adam([Z], lr=1e-2)
e1, e2 = A.nonzero()

In [28]:
def sig(Z, b=0.1): 
    dist = torch.matmul(Z,Z.T)
    print('dist: ', dist, type(dist))
    sigdist = nn.Sigmoid(dist)
    print('sigdist: ', sigdist, type(sigdist))
    neg_term = torch.log(1-nn.Sigmoid(dist+b))
    print('neg_term: ', neg_term, type(neg_term))
    neg_term[np.diag_indices(N)] = 0.0
    print('neg_term new: ', neg_term, type(neg_term))
    pos_term = nn.LogSimoid(e1,e2.T)
    print('pos_term: ', pos_term, type(pos_term))
    neg_term[e1,e2]= 0.0
    print('neg_term last: ', neg_term, type(neg_term))
    return -(pos_term.sum() + neg_term.sum()) / Z.shape[0]**2

In [37]:
def sig2(Z, b=0.1):
    #kernel: theta(z_i,z_j)=sigma(z_i^Tz_j+b)
    # Initialization
    N,d=Z.shape
    
    #compute f(z_i, z_j) = sigma(z_i^Tz_j+b)
    dot=torch.matmul(Z,Z.T)
    logits =dot+b
    
    #transform adj
    ind=torch.triu_indices(N,N,offset=1)
    logits = logits[ind[0], ind[1]] 
    labels = adj[ind[0],ind[1]]
    
    
    #compute p(A|Z)
    loss = F.binary_cross_entropy_with_logits(logits, labels, weight=None, size_average=None, reduce=None, reduction='mean')

    return loss

In [23]:
def dist(Z, eps=1e-5):
    dist = ((Z[:, None] - Z[None, :]).pow(2.0).sum(-1) + eps).sqrt()
    print('dist: ', dist, type(dist))
    helper = -torch.expm1(-dist) + 1e-5
    print('helper: ', helper, type(helper))
    neg_term = torch.log(-torch.expm1(-dist) + 1e-5)
    print('neg_term: ', neg_term, type(neg_term))
    neg_term[np.diag_indices(N)] = 0.0
    print('neg_term new: ', neg_term, type(neg_term))
    pos_term = -dist[e1, e2]   
    print('pos_term: ', pos_term, type(pos_term))
    neg_term[e1, e2] = 0.0
    print('neg_term new: ', neg_term, type(neg_term))
    return -(pos_term.sum() + neg_term.sum()) / Z.shape[0]**2

In [39]:
for epoch in range(1000):
    opt.zero_grad()
    loss = sig2(Z)
    loss.backward()
    opt.step()
    print(loss.item())

0.7445080280303955
0.7444092035293579
0.7444114089012146
0.7444655299186707
0.7444756031036377
0.7444179654121399
0.7443504333496094
0.7443345189094543
0.7443631887435913
0.7443763017654419
0.7443443536758423
0.7443006038665771
0.744286298751831
0.7442971467971802
0.7443013191223145
0.7442848086357117
0.7442617416381836
0.7442489266395569
0.7442482113838196
0.7442488074302673
0.7442406415939331
0.7442262768745422
0.7442153692245483
0.74421226978302
0.7442108988761902
0.7442045211791992
0.7441942095756531
0.744186282157898
0.7441825866699219
0.7441791892051697
0.7441728115081787
0.744165301322937
0.7441591620445251
0.7441543340682983
0.7441496253013611
0.7441436648368835
0.7441369891166687
0.7441309094429016
0.7441254258155823
0.7441198825836182
0.7441136837005615
0.7441068887710571
0.7441003322601318
0.7440942525863647
0.7440879344940186
0.7440808415412903
0.7440736293792725
0.7440666556358337
0.744059681892395
0.7440523505210876
0.7440445423126221
0.7440367341041565
0.74402916431427
0

0.7309325933456421
0.7308791279792786
0.7308257818222046
0.7307723760604858
0.7307192087173462
0.7306659817695618
0.7306129932403564
0.7305603623390198
0.7305083274841309
0.730456531047821
0.7304046750068665
0.7303525805473328
0.7302997708320618
0.7302463054656982
0.7301934957504272
0.7301415205001831
0.7300896048545837
0.7300373911857605
0.729985773563385
0.7299350500106812
0.7298839092254639
0.7298315167427063
0.7297791838645935
0.7297273874282837
0.7296758890151978
0.7296246290206909
0.7295737266540527
0.7295224070549011
0.72947096824646
0.7294196486473083
0.7293686270713806
0.7293171286582947
0.7292655110359192
0.7292141914367676
0.7291629910469055
0.7291117906570435
0.7290607690811157
0.7290101051330566
0.7289593815803528
0.7289084792137146
0.7288571000099182
0.7288059592247009
0.7287551760673523
0.7287047505378723
0.7286543846130371
0.7286038398742676
0.7285529971122742
0.7285022735595703
0.7284519672393799
0.7284018397331238
0.7283510565757751
0.7283002138137817
0.72824960947036

0.7139185070991516
0.7138963341712952
0.7138742804527283
0.7138522863388062
0.7138302326202393
0.7138081789016724
0.713786244392395
0.713764488697052
0.7137426733970642
0.7137208580970764
0.7136994004249573
0.7136781215667725
0.7136574387550354
0.7136369943618774
0.7136179804801941
0.7135996222496033
0.7135842442512512
0.7135680317878723
0.713556706905365
0.7135366201400757
0.7135224938392639
0.7134935855865479
0.7134738564491272
0.7134478092193604
0.7134184837341309
0.7133849263191223
0.7133551239967346
0.7133395671844482
0.7133275270462036
0.7133067846298218
0.7132741212844849
0.7132443785667419
0.713225781917572
0.7132105827331543
0.7131893038749695
0.7131631374359131
0.7131392359733582
0.7131187319755554
0.7130984663963318
0.7130760550498962
0.7130537629127502
0.7130331993103027
0.713014543056488
0.7129935622215271
0.7129712104797363
0.7129489779472351
0.7129279375076294
0.7129080295562744
0.7128878831863403
0.7128667235374451
0.7128457427024841
0.7128251791000366
0.712804555892944

In [None]:
tensor([[0.0032, 0.9944, 0.9861,  ..., 0.9856, 0.9849, 0.9889],
        [0.9944, 0.0032, 0.9867,  ..., 0.9849, 0.9918, 0.9850],
        [0.9861, 0.9867, 0.0032,  ..., 0.9867, 0.9932, 0.9920],
        ...,
        [0.9856, 0.9849, 0.9867,  ..., 0.0032, 0.9917, 0.9857],
        [0.9849, 0.9918, 0.9932,  ..., 0.9917, 0.0032, 0.9906],
        [0.9889, 0.9850, 0.9920,  ..., 0.9857, 0.9906, 0.0032]],
       grad_fn=<AddBackward0>) <class 'torch.Tensor'>