### Imports

In [9]:
import numpy as np
import scipy.sparse as sp
import torch
import torch.nn as nn
import pickle as pkl
import networkx as nx

from tqdm import tqdm
from DGI.models import DGI, LogReg
from DGI.utils import process
from tensorflow.keras.optimizers import Adam, Nadam

from scripts.utils import *

### Class Definitions

In [4]:
class BaseEmbedder:
    """Minimal interface shared by graph embedders.

    The constructor records the graph, its edge list and the requested
    embedding shape, then calls ``embed``. Subclasses are expected to
    override both ``embed`` and ``get_embedding``; the base versions only
    raise ``NotImplementedError``.
    """

    def __init__(self, graph, embed_shape = (128,)):
        """Store the graph state and trigger the embedding computation.

        Parameters
        ----------
        graph :
            Object exposing ``edges()``; it is passed unchanged to ``embed``.
        embed_shape : tuple
            Requested shape of the embedding, stored as ``self.embed_shape``.

        Raises
        ------
        NotImplementedError
            If ``embed`` has not been overridden by a subclass.
        """
        # Populate the instance *before* calling embed() so that an overriding
        # embed() can rely on self.graph / self.E / self.embed_shape.
        self.graph = graph
        self.E = list(graph.edges())
        self.embed_shape = embed_shape
        self.embed(graph)

    def embed(self, graph):
        """Compute the embedding for ``graph``; must be overridden."""
        raise NotImplementedError

    def get_embedding(self):
        """Return the computed embedding; must be overridden."""
        raise NotImplementedError
        

In [90]:
class DGIEmbedding(BaseEmbedder):
    """Node embedder built on the imported ``DGI`` model.

    Training minimises a binary cross-entropy between the model's logits and
    "real vs. shuffled" labels. When ``use_xm`` is set and a feature matrix
    was supplied, two extra penalties (weighted by ``ortho_`` and
    ``sparse_``) are added to the loss.

    Attributes set by ``embed``: ``node_model`` (the trained model),
    ``embeddings`` (a CPU tensor) and ``fitted`` (``True`` after training).
    """

    def __init__(self, embed_dim = 64, graph = None, feature_matrix = None, use_xm = False, debug = False, batch_size = 1, nb_epochs = 2500, patience = 20, ortho_ = 0.1, sparse_ = 0.1, lr = 1e-3, l2_coef = 0.0, drop_prob = 0.0, sparse = True, nonlinearity = 'prelu'):
        """Store the hyper-parameters and, if a graph is given, train at once.

        Parameters
        ----------
        embed_dim : int
            Hidden/embedding size passed to ``DGI`` (also kept as ``hid_units``).
        graph :
            Graph to embed; when ``None`` no training happens in the constructor.
        feature_matrix :
            Optional node-feature matrix; an identity matrix is used when omitted.
        use_xm : bool
            Enable the extra orthogonality/sparsity penalties (needs ``feature_matrix``).
        debug : bool
            Print the loss every epoch and early-stopping information.
        batch_size : int
            Number of label rows and divisor of the extra penalties.
        nb_epochs, patience : int
            Maximum number of epochs and early-stopping patience.
        ortho_, sparse_ : float
            Weights of the two extra penalties.
        lr, l2_coef : float
            Adam learning rate and weight decay.
        drop_prob : float
            Stored only; not used by the code in this class.
        sparse : bool
            Use a torch sparse adjacency instead of a dense one.
        nonlinearity : str
            Activation name forwarded to ``DGI``.
        """
        # NOTE: BaseEmbedder.__init__ is deliberately not called here: it
        # invokes embed(graph) with an argument, which this embed() does not take.
        self.embed_dim = embed_dim
        self.debug = debug

        # Training Params
        self.graph = graph
        self.batch_size = batch_size
        self.nb_epochs = nb_epochs
        self.patience = patience
        self.lr = lr
        self.l2_coef = l2_coef
        self.feature_matrix = feature_matrix
        self.drop_prob = drop_prob
        self.hid_units = embed_dim
        self.sparse = sparse
        self.nonlinearity = nonlinearity
        self.use_xm = use_xm
        self.ortho_ = ortho_
        self.sparse_ = sparse_

        # Defined up-front so an untrained instance fails with a clear message
        # instead of an AttributeError.
        self.node_model = None
        self.embeddings = None
        self.fitted = False

        if graph is not None:
            self.embed()

    def _xm_penalties(self, model, logits, sense_features, adj_input):
        """Return ``(ortho_loss, sparse_loss)`` for the current model state.

        The arithmetic is kept exactly as in the original training loop.
        ``E`` is a rank-3 tensor whose first axis is shared by the embedding
        and the feature matrix (einsum ``'ij, ik -> ijk'``).
        """
        # NOTE(review): this feeds the raw feature matrix to the model, whereas
        # the main forward pass uses the output of process.preprocess_features.
        embeds, _ = model.embed(sense_features, adj_input, self.sparse, None)
        embeds = torch.squeeze(embeds)
        E = torch.einsum('ij, ik -> ijk', embeds, sense_features)

        # NOTE(review): y_norm is derived from `logits`, not from `embeds` —
        # confirm this is the intended normaliser.
        y_norm = torch.diagonal(torch.matmul(logits, torch.transpose(logits, 0, 1)))
        sense_norm = torch.diagonal(torch.matmul(sense_features, torch.transpose(sense_features, 0, 1)))
        norm = torch.multiply(y_norm, sense_norm)
        # Transposing puts the shared first axis last so that the division
        # broadcasts `norm` along it; the second transpose restores the layout.
        E = torch.transpose(torch.transpose(E, 0, 2) / norm, 0, 2)

        E_t = torch.transpose(E, 1, 2)
        E_o = torch.sum(torch.einsum('aij, ajh -> aih', E, E_t))
        ortho_loss = (self.ortho_ * E_o) / self.batch_size

        sparse_loss = (self.sparse_ * torch.sum(torch.linalg.norm(E, ord = 1, axis = 0))) / self.batch_size
        return ortho_loss, sparse_loss

    def embed(self):
        """Train the DGI model on ``self.graph`` and store the node embeddings.

        Side effects: writes the best checkpoint to ``best_dgi.pkl`` in the
        working directory, and sets ``node_model``, ``embeddings`` and
        ``fitted``.

        Raises
        ------
        ValueError
            If no graph has been assigned to the instance.
        """
        if self.graph is None:
            raise ValueError('DGIEmbedding.embed() requires a graph, but self.graph is None.')

        has_feature_matrix = self.feature_matrix is not None
        if has_feature_matrix:
            feature_matrix = self.feature_matrix
        else:
            feature_matrix = np.identity(len(self.graph))

        adj = nx.to_scipy_sparse_array(self.graph)
        features = sp.lil_matrix(feature_matrix)
        features, _ = process.preprocess_features(features)

        nb_nodes = features.shape[0]
        ft_size = features.shape[1]

        adj = process.normalize_adj(adj + sp.eye(adj.shape[0]))

        if self.sparse:
            adj_input = process.sparse_mx_to_torch_sparse_tensor(adj)
        else:
            # NOTE(review): the identity is added a second time here, after
            # normalisation — kept as in the original; confirm it is intended.
            adj = (adj + sp.eye(adj.shape[0])).todense()
            adj_input = torch.FloatTensor(adj[np.newaxis])

        features = torch.FloatTensor(features[np.newaxis])

        # Keep the model and every tensor on ONE device. Previously only the
        # shuffled features and labels were moved to CUDA, which raised a
        # device-mismatch error whenever a GPU was available.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        features = features.to(device)
        adj_input = adj_input.to(device)

        # The extra penalties need the user-supplied feature matrix; without
        # it training falls back to the plain DGI loss.
        apply_xm = bool(self.use_xm) and has_feature_matrix
        if apply_xm:
            sense_features = torch.FloatTensor(self.feature_matrix).to(device)

        model = DGI(ft_size, self.hid_units, self.nonlinearity).to(device)
        optimiser = torch.optim.Adam(model.parameters(), lr = self.lr, weight_decay = self.l2_coef)

        b_xent = nn.BCEWithLogitsLoss()

        # Labels never change between epochs: ones for the real nodes,
        # zeros for the shuffled (corrupted) ones.
        lbl_1 = torch.ones(self.batch_size, nb_nodes)
        lbl_2 = torch.zeros(self.batch_size, nb_nodes)
        lbl = torch.cat((lbl_1, lbl_2), 1).to(device)

        cnt_wait = 0
        best = 1e9
        best_t = 0
        checkpoint_written = False

        for epoch in tqdm(range(self.nb_epochs)):
            model.train()
            optimiser.zero_grad()

            # Corruption: same features, randomly permuted across nodes.
            idx = np.random.permutation(nb_nodes)
            shuf_fts = features[:, idx, :]

            logits = model(features, shuf_fts, adj_input, self.sparse, None, None, None)

            if apply_xm:
                ortho_loss, sparse_loss = self._xm_penalties(model, logits, sense_features, adj_input)
                loss = b_xent(logits, lbl) + ortho_loss + sparse_loss
            else:
                loss = b_xent(logits, lbl)

            if self.debug:
                print('Loss:', loss)

            # Compare plain floats so `best` does not keep a graph-attached tensor.
            loss_value = loss.item()
            if loss_value < best:
                best = loss_value
                best_t = epoch
                cnt_wait = 0
                torch.save(model.state_dict(), 'best_dgi.pkl')
                checkpoint_written = True
            else:
                cnt_wait += 1

            if cnt_wait == self.patience:
                if self.debug:
                    print('Early stopping!')
                break

            loss.backward()
            optimiser.step()

        if self.debug:
            print('Loading {}th epoch'.format(best_t))
        # Only restore a checkpoint produced by THIS run; otherwise a stale or
        # missing best_dgi.pkl would be loaded (e.g. nb_epochs == 0 or NaN loss).
        if checkpoint_written:
            model.load_state_dict(torch.load('best_dgi.pkl'))

        self.node_model = model
        self.fitted = True

        embeds, _ = model.embed(features, adj_input, self.sparse, None)
        # Stored on the CPU so get_embedding() can convert it to NumPy.
        self.embeddings = embeds.cpu()

    def get_embedding(self):
        """Return the trained embeddings as a NumPy array with unit axes removed.

        Raises
        ------
        RuntimeError
            If ``embed`` has not been run yet.
        """
        if self.embeddings is None:
            raise RuntimeError('No embeddings available: call embed() with a graph first.')
        return np.squeeze(self.embeddings.detach().cpu().numpy())
    

In [87]:
# Read the pickled dataset dictionary from disk.
# NOTE(review): unpickling can execute arbitrary code — only load trusted files.
with open('./data/email.pkl', 'rb') as handle:
    graph_dict = pkl.load(handle)

# Rebuild the graph from its adjacency matrix, so `graph` is a plain
# nx.Graph constructed from the matrix rather than the stored object itself.
adjacency = nx.to_numpy_array(graph_dict['graph'])
graph = nx.Graph(adjacency)

In [88]:
# Compute every sense feature for the graph.
sense_feat_dict, sense_features = get_sense_features(graph, ppr_flag = 'std')

# Subset of features kept for the analysis, in the column order used below.
uncorrelated_feats = [
    'Degree',
    'Clustering Coefficient',
    'Personalized Page Rank - Standard Deviation',
    'Average Neighbor Degree',
    'Average Neighbor Clustering',
    'Eccentricity',
    'Katz Centrality',
]

# Map each kept feature name to its column in the full matrix, then slice.
all_feature_names = list(sense_feat_dict)
kept_columns = [all_feature_names.index(name) for name in uncorrelated_feats]
sense_features = sense_features[:, kept_columns]

# Re-index the name -> column mapping to match the reduced matrix.
sense_feat_dict = dict(zip(uncorrelated_feats, range(len(uncorrelated_feats))))

Calculating Personalized Page Rank...                     

986it [00:27, 36.20it/s]


Done                                                      

  A = nx.adjacency_matrix(G, nodelist=nodelist, weight=weight).todense().T


In [113]:
# Baseline DGI on the sense features, without the extra penalties.
# It must be built here: the comparison cell that follows reads `dgi_og`,
# which was previously commented out and raised a NameError on a fresh run.
dgi_og = DGIEmbedding(graph = graph, 
                   embed_dim = 32, 
                   feature_matrix = sense_features, 
                   use_xm = False, 
                   ortho_ = 0, 
                   sparse_ = 0, 
                   batch_size = 1)

# Optional variant trained on identity features (not used further below).
# dgi_id = DGIEmbedding(graph = graph, 
#                    embed_dim = 32, 
#                    feature_matrix = np.identity(len(graph)), 
#                    use_xm = False, 
#                    ortho_ = 0, 
#                    sparse_ = 0, 
#                    batch_size = 1)

# DGI+ : same setup with the orthogonality and sparsity penalties enabled.
dgi_plus = DGIEmbedding(graph = graph, 
                   embed_dim = 32, 
                   feature_matrix = sense_features, 
                   use_xm = True, 
                   ortho_ = 10, 
                   sparse_ = 1, 
                   batch_size = 1)

100%|███████████████████████████████████████| 2500/2500 [00:51<00:00, 48.67it/s]


In [114]:
def _rescale_to_unit(values):
    """Min-max scale ``values`` using its global minimum and peak-to-peak range."""
    return (values - np.min(values)) / np.ptp(values)


# Arguments shared by both feature-membership runs.
_membership_kwargs = dict(sense_features = sense_features,
                          sense_feat_dict = sense_feat_dict,
                          top_k = 8,
                          solver = 'nmf')

# Baseline DGI: scale the embedding, explain it, scale the explanation.
embed_og = _rescale_to_unit(dgi_og.get_embedding())
feature_dict_og = find_feature_membership(input_embed = embed_og,
                                          embed_name = 'DGI',
                                          **_membership_kwargs)
explain_og = _rescale_to_unit(feature_dict_og['explain_norm'])

# DGI+: identical pipeline on the penalised model.
embed_plus = _rescale_to_unit(dgi_plus.get_embedding())
feature_dict_plus = find_feature_membership(input_embed = embed_plus,
                                            embed_name = 'DGI+',
                                            **_membership_kwargs)
explain_plus = _rescale_to_unit(feature_dict_plus['explain_norm'])

In [115]:
# One heatmap per model; each title carries the nuclear norm of the
# explanation matrix being shown.
for title_prefix, explain_matrix in (('DGI : ', explain_og), ('DGI+ : ', explain_plus)):
    nuclear_norm = np.linalg.norm(explain_matrix, ord = 'nuc')

    fig = go.Figure()
    fig.add_trace(go.Heatmap(z = explain_matrix,
                             x = list(sense_feat_dict)))
    fig.update_layout(title_text = title_prefix + str(nuclear_norm),
                      xaxis_title_text = 'Sense Features',
                      yaxis_title_text = 'Dimensions')

    fig.show()