In [1]:
%load_ext autoreload
%autoreload 2
import gust  # library for loading graph data

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import scipy.sparse as sp
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as dist
import time
from scipy.spatial.distance import squareform
torch.set_default_tensor_type('torch.cuda.FloatTensor')
%matplotlib inline
sns.set_style('whitegrid')

In [2]:
# Load the dataset using `gust` library
# graph.standardize() makes the graph unweighted, undirected and selects
# the largest connected component
# graph.unpack() returns the necessary vectors / matrices

A, X, _, y = gust.load_dataset('cora').standardize().unpack()
# A - adjacency matrix 
# X - attribute matrix - not needed
# y - node labels

if (A != A.T).sum() > 0:
    raise RuntimeError("The graph must be undirected!")

if (A.data != 1).sum() > 0:
    raise RuntimeError("The graph must be unweighted!")

In [3]:
num_nodes = A.shape[0]
num_edges = A.sum()

# Convert adjacency matrix to a CUDA Tensor
adj = torch.FloatTensor(A.toarray()).cuda()

In [4]:
torch.manual_seed(123)
# Define the embedding matrix
embedding_dim = 64
emb = nn.Parameter(torch.empty(num_nodes, embedding_dim).normal_(0.0, 1.0))



# Initialize the bias
# The bias is initialized in such a way that if the dot product between two embedding vectors is 0 
# (i.e. z_i^T z_j = 0), then their connection probability is sigmoid(b) equals to the 
# background edge probability in the graph. This significantly speeds up training
edge_proba = num_edges / (num_nodes**2 - num_nodes)
bias_init = np.log(edge_proba / (1 - edge_proba))
b = nn.Parameter(torch.Tensor([bias_init]))


# Regularize the embeddings but don't regularize the bias
# The value of weight_decay has a significant effect on the performance of the model (don't set too high!)
opt = torch.optim.Adam([
    {'params': [emb], 'weight_decay': 1e-7}, {'params': [b]}],
    lr=1e-2)


In [9]:
def compute_loss_ber_sig(adj, emb, b=0.1): 
    #kernel: theta(z_i,z_j)=sigma(z_i^Tz_j+b)
    # Initialization
    N,d=emb.shape
    
    #compute f(z_i, z_j) = sigma(z_i^Tz_j+b)
    dot=torch.matmul(emb,emb.T)
    logits =dot+b
    
    #transform adj
    ind=torch.triu_indices(N,N,offset=1)
    logits = logits[ind[0], ind[1]] 
    labels = adj[ind[0],ind[1]]
    
    
    #compute p(A|Z)
    loss = F.binary_cross_entropy_with_logits(logits, labels, weight=None, size_average=None, reduce=None, reduction='mean')
    return loss

def compute_loss_ber_exp1(adj, emb, b=0.1):
    #Init
    N,d=emb.shape
    gamma=0.001
    
    print('adj', adj.size(), adj.dtype, type(adj), adj.requires_grad, adj.device)
    print('emb', emb.size(), emb.dtype, type(emb), emb.requires_grad, emb.device)
    
    
    #get indices of upper triangular matrix
    ind=torch.triu_indices(N,N,offset=1)
    labels = adj[ind[0],ind[1]]
    
    #compute f(z_i,z_j) = exp(-gamma||z_i-z_j||^2)
    dist=F.pdist(emb, p=2)
    print('dist', dist, dist.size(), dist.dtype, type(dist), dist.requires_grad, dist.device)
    
    #put distances into upper triangular matrix
    #dist_matrix0 = nn.Parameter(torch.empty(num_nodes, embedding_dim))
    
    logits=torch.exp(-gamma * dist**2)
    print('logits: ', logits.size(), logits.dtype, type(logits), logits.requires_grad, logits.device)

    
    print('labels', labels.size(), labels.dtype, type(labels), labels.requires_grad, labels.device)
    #compute loss
    loss = F.binary_cross_entropy_with_logits(logits, labels, reduction='mean')
    print(loss)
    return loss

def compute_loss_d1(adj, emb, b=0.0): 
    """Compute the rdf distance of the Bernoulli model."""
    # Initialization
    start_time = time.time()
    N,d=emb.shape
    squared_euclidian = torch.zeros(N,N).cuda()
    gamma= 0.1
    end_time= time.time()
    duration= end_time -start_time
    #print(f' Time for initialization = {duration:.5f}')
    # Compute squared euclidian
    start_time = time.time()
    for index, embedding in enumerate(emb):
        sub =  embedding-emb + 10e-9
        squared_euclidian[index,:]= torch.sum(torch.pow(sub,2),1)
    end_time= time.time()
    duration= end_time -start_time
    #print(f' Time for euclidian = {duration:.5f}')
    # Compute exponentianl
    start_time = time.time()
    radial_exp = torch.exp (-gamma * torch.sqrt(squared_euclidian))
    loss = F.binary_cross_entropy(radial_exp, adj, reduction='none')
    loss[np.diag_indices(adj.shape[0])] = 0.0
    end_time= time.time()
    duration= end_time -start_time
    #print(f' Time for loss  = {duration:.5f}')
    return loss.mean()


def compute_loss_ber_exp2(adj, emb, b=0.1):
    #Init
    N,d=emb.shape

    #get indices of upper triangular matrix
    ind=torch.triu_indices(N,N,offset=1)
    
    #compute f(z_i, z_j) = sigma(z_i^Tz_j+b)
    dot=torch.matmul(emb,emb.T)
    print('dist: ', dot, dot.size(), type(dot))
    logits=1-torch.exp(-dot)
    logits=logits[ind[0],ind[1]]
    labels = adj[ind[0],ind[1]]
    print('logits: ', logits, logits.size(), type(logits))
    
    #compute loss
    loss = F.binary_cross_entropy_with_logits(logits, labels, reduction='mean')

    return loss

def compute_loss_KL(adj, emb, b=0.0):
    #adj = torch.FloatTensor(A.toarray()).cuda()
    degree= torch.from_numpy(adj.sum(axis=1))
    inv_degree=torch.diagflat(1/degree).cuda()
    P = inv_degree.mm(adj) 
    loss = -(P*torch.log( 10e-9+ F.softmax(emb.mm(emb.t() ),dim=1,dtype=torch.float)))
    return loss.mean()

In [10]:
max_epochs = 1000
display_step = 250
compute_loss = compute_loss_ber_exp1

for epoch in range(max_epochs):
    opt.zero_grad()
    loss = compute_loss(adj, emb, b)
    loss.backward()
    opt.step()
    # Training loss is printed every display_step epochs
    if epoch == 0 or (epoch + 1) % display_step == 0:
        print(f'Epoch {epoch+1:4d}, loss = {loss.item():.5f}')

adj torch.Size([2485, 2485]) torch.float32 <class 'torch.Tensor'> False cuda:0
emb torch.Size([2485, 64]) torch.float32 <class 'torch.nn.parameter.Parameter'> True cuda:0
dist tensor([24.3387, 22.5350, 26.6627,  ..., 24.8556, 27.0524, 21.2161],
       grad_fn=<PdistBackward>) torch.Size([3086370]) torch.float32 <class 'torch.Tensor'> True cuda:0
logits:  torch.Size([3086370]) torch.float32 <class 'torch.Tensor'> True cuda:0
labels torch.Size([3086370]) torch.float32 <class 'torch.Tensor'> False cuda:0
tensor(1.0108, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)


RuntimeError: CUDA error: invalid configuration argument

In [71]:
exp1:
    dist:  tensor([[ 0.0000, 11.7170, 11.2746,  ..., 12.0473, 10.1984, 11.3381],
        [ 0.0000,  0.0000, 12.3547,  ..., 11.8961, 11.2169, 12.2756],
        [ 0.0000,  0.0000,  0.0000,  ...,  9.9336, 11.4303, 12.5992],
        ...,
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000, 12.0894, 12.9180],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000, 10.4011],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],
       grad_fn=<IndexPutBackward>) torch.Size([2485, 2485]) <class 'torch.Tensor'>
exp2: 
    dist:  tensor([[ 6.8276e+01,  2.0167e+00,  3.5997e-02,  ..., -5.0973e+00,
          1.1831e+01,  4.7927e+00],
        [ 2.0167e+00,  7.3249e+01, -1.0391e+01,  ..., -8.7648e-01,
          3.3881e+00, -3.9294e+00],
        [ 3.5997e-02, -1.0391e+01,  5.9067e+01,  ...,  1.3639e+01,
         -6.0809e+00, -1.5023e+01],
        ...,
        [-5.0973e+00, -8.7648e-01,  1.3639e+01,  ...,  6.6841e+01,
         -9.9354e+00, -1.5082e+01],
        [ 1.1831e+01,  3.3881e+00, -6.0809e+00,  ..., -9.9354e+00,
          5.9443e+01,  1.0573e+01],
        [ 4.7927e+00, -3.9294e+00, -1.5023e+01,  ..., -1.5082e+01,
          1.0573e+01,  6.9868e+01]], grad_fn=<MmBackward>) torch.Size([2485, 2485]) <class 'torch.Tensor'>
        
        
exp1: 
    logits:  tensor([0.8717, 0.8806, 0.8570,  ..., 0.8640, 0.8463, 0.8975],
       grad_fn=<IndexBackward>) torch.Size([3086370]) <class 'torch.Tensor'>
exp2: 
    logits:  tensor([ 8.7085e-01,  2.0776e-02, -5.2848e+05,  ..., -2.1115e+04,
        -3.6931e+06,  9.9997e-01], grad_fn=<IndexBackward>) torch.Size([3086370]) <class 'torch.Tensor'>

SyntaxError: invalid syntax (<ipython-input-71-54a037b251c7>, line 1)