### Imports

In [9]:
import numpy as np
import scipy.sparse as sp
import torch
import torch.nn as nn
import pickle as pkl
import networkx as nx

from tqdm import tqdm
from DGI.models import DGI, LogReg
from DGI.utils import process
from tensorflow.keras.optimizers import Adam, Nadam

from scripts.utils import *

### Class Definitions

In [4]:
class BaseEmbedder:
    def __init__(self, graph, embed_shape = (128,)):
        self.embed(graph)
        self.E = list(graph.edges())
        self.graph = graph
        self.embed_shape = embed_shape
    
    def embed(self, graph):
        raise NotImplementedError
    
    def get_embedding(self):
        raise NotImplementedError
        

In [36]:
class DGIEmbedding(BaseEmbedder):
    def __init__(self, embed_dim = 64, graph = None, feature_matrix = None, use_xm = False, debug = False, batch_size = 1, nb_epochs = 2500, patience = 20, ortho_ = 0.1, sparse_ = 0.1, lr = 1e-3, l2_coef = 0.0, drop_prob = 0.0, sparse = True, nonlinearity = 'prelu'):

        self.embed_dim = embed_dim
        self.debug = debug
        
        # Training Params
        self.graph = graph
        self.batch_size = batch_size
        self.nb_epochs = nb_epochs
        self.patience = patience
        self.lr = lr
        self.l2_coef = l2_coef
        self.feature_matrix = feature_matrix
        self.drop_prob = drop_prob
        self.hid_units = embed_dim
        self.sparse = sparse
        self.nonlinearity = nonlinearity
        self.use_xm = use_xm
        self.ortho_ = ortho_
        self.sparse_ = sparse_
        
        if graph is not None:
            self.embed()
        else:
            self.graph = None
    
    def embed(self):

        
        if self.feature_matrix is None:
            feature_matrix = np.identity(len(self.graph))
        else: 
            feature_matrix = self.feature_matrix

        adj = nx.to_scipy_sparse_array(self.graph)
        features = sp.lil_matrix(feature_matrix)
        features, _ = process.preprocess_features(features)

        nb_nodes = features.shape[0]
        ft_size = features.shape[1]

        adj = process.normalize_adj(adj + sp.eye(adj.shape[0]))

        if self.sparse:
            sp_adj = process.sparse_mx_to_torch_sparse_tensor(adj)
        else:
            adj = (adj + sp.eye(adj.shape[0])).todense()

        features = torch.FloatTensor(features[np.newaxis])
        if not self.sparse:
            adj = torch.FloatTensor(adj[np.newaxis])

        if self.feature_matrix is not None: 
            sense_features = torch.FloatTensor(self.feature_matrix)


        model = DGI(ft_size, self.hid_units, self.nonlinearity)
        optimiser = torch.optim.Adam(model.parameters(), lr = self.lr, weight_decay = self.l2_coef)

        b_xent = nn.BCEWithLogitsLoss()
        xent = nn.CrossEntropyLoss()
        cnt_wait = 0
        best = 1e9
        best_t = 0

        for epoch in tqdm(range(self.nb_epochs)):
            model.train()
            optimiser.zero_grad()

            idx = np.random.permutation(nb_nodes)
            shuf_fts = features[:, idx, :]

            lbl_1 = torch.ones(self.batch_size, nb_nodes)
            lbl_2 = torch.zeros(self.batch_size, nb_nodes)
            lbl = torch.cat((lbl_1, lbl_2), 1)

            if torch.cuda.is_available():
                shuf_fts = shuf_fts.cuda()
                lbl = lbl.cuda()

            logits = model(features, shuf_fts, sp_adj if self.sparse else adj, self.sparse, None, None, None) 
            print (logits.shape)

            if self.use_xm == True and feature_matrix is not None:
                sf = sense_features[idx, :]
                sense_mat = torch.einsum('ij, ik -> ijk', logits, sf)
                E = sense_mat
                y_norm = torch.diagonal(torch.matmul(logits, torch.transpose(logits, 0, 1)))
                sense_norm = torch.diagonal(torch.matmul(sf, torch.transpose(sf, 0, 1)))
                norm = torch.multiply(y_norm, sense_norm)
                print (E.shape)
                print (norm.shape)
                E = torch.transpose(torch.transpose(E, 0, 1) / norm)

                ortho_loss = (self.ortho_ * E) / self.batch_size
                sparse_loss = (self.sparse_ * torch.sum(torch.linalg.norm(E, ord = 1, axis = 0))) / batch_size

            loss = b_xent(logits, lbl) + ortho_loss + sparse_loss

            if self.debug:
                print('Loss:', loss)

            if loss < best:
                best = loss
                best_t = epoch
                cnt_wait = 0
                torch.save(model.state_dict(), 'best_dgi.pkl')
            else:
                cnt_wait += 1

            if cnt_wait == self.patience:
                if self.debug: 
                    print('Early stopping!')
                break

            loss.backward()
            optimiser.step()

        if self.debug: 
            print('Loading {}th epoch'.format(best_t))
        model.load_state_dict(torch.load('best_dgi.pkl'))

        self.node_model = model
        self.fitted = True

        embeds, _ = model.embed(features, sp_adj if self.sparse else adj, self.sparse, None)
        self.embeddings = embeds
    
    def get_embedding(self):
        return np.squeeze(self.embeddings.numpy())
    

In [6]:
with open('./data/email.pkl', 'rb') as file: 
    graph_dict = pkl.load(file)
    
graph = nx.Graph(nx.to_numpy_array(graph_dict['graph']))

In [37]:
dgi = DGIEmbedding(graph = graph, 
                   embed_dim = 64, 
                   feature_matrix = sense_features, 
                   use_xm = True, 
                   ortho_ = 0.1, 
                   sparse_ = 0.1, 
                   batch_size = 10)

  0%|                                                  | 0/2500 [00:00<?, ?it/s]

torch.Size([1, 1972])
torch.Size([986, 1972, 7])
torch.Size([986])





RuntimeError: The size of tensor a (7) must match the size of tensor b (986) at non-singleton dimension 2

In [10]:
sense_feat_dict, sense_features = get_sense_features(graph, ppr_flag = 'std')

uncorrelated_feats = ['Degree',
                    'Clustering Coefficient',
                    'Personalized Page Rank - Standard Deviation',
                    'Average Neighbor Degree',
                    'Average Neighbor Clustering',
                    'Eccentricity',
                    'Katz Centrality']
sense_features = sense_features[:, [list(sense_feat_dict).index(feat) for feat in uncorrelated_feats]]
sense_feat_dict = {feat : idx for idx, feat in enumerate(uncorrelated_feats)}

Calculating Degrees...                                   Calculating Average Neighbor Degree...                    Calculating Clustering Coefficient...                     Calculating Average Neighbor Clustering Coefficients...   Calculating Eccentricity...                               Calculating Page Rank...                                  Calculating Personalized Page Rank...                     

986it [01:45,  9.31it/s]


Done                                                      

  A = nx.adjacency_matrix(G, nodelist=nodelist, weight=weight).todense().T


In [19]:
sf = torch.FloatTensor(sense_features)
torch.diagonal(torch.matmul(sf, torch.transpose(sf, 0, 1)))

tensor([0.6181, 0.5628, 0.9938, 0.9559, 0.8618, 1.0928, 0.7465, 0.8630, 0.8639,
        0.8429, 0.7087, 0.7559, 0.6545, 0.6687, 0.5673, 0.8514, 0.5724, 0.6970,
        0.4402, 0.7604, 0.7440, 0.7943, 0.7221, 0.8261, 1.0566, 0.9244, 0.9343,
        0.8780, 0.8233, 0.8003, 0.7912, 0.7670, 1.0185, 0.9083, 1.2444, 0.6826,
        0.7661, 0.7836, 0.7537, 0.8212, 0.8221, 0.6622, 0.6772, 1.0986, 0.8488,
        0.8046, 0.8970, 0.7833, 0.8579, 0.8335, 0.8521, 0.5424, 0.4901, 0.4613,
        1.0689, 0.9767, 1.1368, 0.9257, 0.8808, 1.1523, 0.8627, 0.5037, 0.8086,
        0.7742, 0.5804, 0.9541, 0.9653, 0.8088, 0.7450, 0.8030, 0.7846, 1.0215,
        0.7370, 0.5899, 0.4515, 0.8545, 1.0824, 0.8220, 0.8494, 0.6706, 0.7349,
        0.8424, 0.9342, 0.8012, 0.5575, 0.3935, 0.7924, 0.6692, 1.0691, 1.1735,
        0.7741, 0.8809, 0.5204, 0.7094, 0.7348, 0.6425, 1.1050, 1.0988, 0.8956,
        0.9687, 0.8292, 0.9124, 0.9144, 0.6828, 0.9563, 0.7152, 0.7443, 0.9620,
        0.9270, 0.6514, 1.2519, 1.2761, 