### Imports

In [1]:
import pickle as pkl
import time

import networkx as nx
import numpy as np
import scipy.sparse as sp
import torch
import torch.nn as nn
from tensorflow.keras.optimizers import Adam, Nadam
from tqdm import tqdm

from DGI.models import DGI, LogReg
from DGI.utils import process
from scripts.utils import *

  from .autonotebook import tqdm as notebook_tqdm


### Class Definitions

In [2]:
class BaseEmbedder:
    """Minimal interface shared by graph embedders.

    Subclasses are expected to override ``embed`` (fit on a graph) and
    ``get_embedding`` (return the fitted vectors); both raise
    ``NotImplementedError`` in this base class.
    """

    def __init__(self, graph, embed_shape = (128,)):
        """Record the graph and requested embedding shape, then call ``embed``.

        Args:
            graph: object exposing ``edges()``; stored as ``self.graph`` and its
                edge list is cached in ``self.E``.
            embed_shape: stored as ``self.embed_shape``; not otherwise used here.

        Raises:
            NotImplementedError: when ``embed`` has not been overridden.
        """
        # Assign all state *before* fitting so that a subclass ``embed`` can rely on
        # self.graph / self.E / self.embed_shape already being present.
        self.graph = graph
        self.E = list(graph.edges())
        self.embed_shape = embed_shape
        self.embed(graph)

    def embed(self, graph):
        """Fit the embedding on ``graph``; must be overridden by subclasses."""
        raise NotImplementedError

    def get_embedding(self):
        """Return the fitted embedding; must be overridden by subclasses."""
        raise NotImplementedError
        

In [3]:
class DGIEmbedding(BaseEmbedder):
    """Trains a ``DGI`` model on a graph and exposes the resulting node embeddings.

    When ``use_xm`` is enabled and a ``feature_matrix`` was supplied, two extra
    penalty terms (weighted by ``ortho_`` and ``sparse_``) computed from the
    per-node outer product of embedding and feature rows are added to the
    binary cross-entropy objective.
    """

    def __init__(self, embed_dim = 64, graph = None, feature_matrix = None, use_xm = False, debug = False, batch_size = 1, nb_epochs = 2500, patience = 20, ortho_ = 0.1, sparse_ = 0.1, lr = 1e-3, l2_coef = 0.0, drop_prob = 0.0, sparse = True, nonlinearity = 'prelu', model_name = ''):
        """Store the hyper-parameters and train immediately when ``graph`` is given.

        Args:
            embed_dim: hidden size handed to ``DGI`` (also kept as ``hid_units``).
            graph: networkx graph to embed; when ``None`` no training happens.
            feature_matrix: node feature array; an identity matrix of size
                ``len(graph)`` is used for the DGI input when this is ``None``.
            use_xm: add the ortho/sparse penalties to the loss (only takes effect
                when ``feature_matrix`` is not ``None``).
            debug: print the loss every epoch plus early-stop / reload messages.
            batch_size: number of rows of the label tensor; also divides the penalties.
            nb_epochs: maximum number of training epochs.
            patience: stop once this many consecutive epochs fail to improve the loss.
            ortho_: weight of the orthogonality-style penalty.
            sparse_: weight of the L1 penalty.
            lr: Adam learning rate.
            l2_coef: Adam ``weight_decay``.
            drop_prob: stored only; ``embed`` does not read it.
            sparse: feed the adjacency to the model as a torch sparse tensor.
            nonlinearity: activation name forwarded to ``DGI``.
            model_name: checkpoint prefix; weights go to ``model_name + '.pkl'``.
        """
        # BaseEmbedder.__init__ is intentionally not invoked: it calls
        # self.embed(graph), whereas embed() below takes no argument.
        self.embed_dim = embed_dim
        self.debug = debug

        # Training Params
        self.graph = graph
        self.batch_size = batch_size
        self.nb_epochs = nb_epochs
        self.patience = patience
        self.lr = lr
        self.l2_coef = l2_coef
        self.feature_matrix = feature_matrix
        self.drop_prob = drop_prob
        self.hid_units = embed_dim
        self.sparse = sparse
        self.nonlinearity = nonlinearity
        self.use_xm = use_xm
        self.ortho_ = ortho_
        self.sparse_ = sparse_
        self.model_name = model_name

        self.time_per_epoch = None
        self.fitted = False

        if graph is not None:
            self.embed()

    def _xm_losses(self, model, graph_op, sense_features, xm_batch_size = 128):
        """Return ``(ortho_loss, sparse_loss)`` accumulated over node mini-batches.

        Args:
            model: the DGI model being trained (its ``embed`` method is called).
            graph_op: adjacency tensor in the layout selected by ``self.sparse``.
            sense_features: torch tensor built from ``self.feature_matrix``.
            xm_batch_size: number of nodes processed per mini-batch.
        """
        embeds, _ = model.embed(sense_features, graph_op, self.sparse, None)
        embeds = torch.squeeze(embeds)
        nb_graph_nodes = len(self.graph)

        ortho_loss = 0
        sparse_loss = 0
        # range() never yields an empty trailing slice, unlike a manual while-loop
        # that only stops once the end index overshoots the node count.
        for start_idx in range(0, nb_graph_nodes, xm_batch_size):
            end_idx = min(start_idx + xm_batch_size, nb_graph_nodes)

            sf = sense_features[start_idx : end_idx]
            embeds_ = embeds[start_idx : end_idx]

            # Per-node outer product of embedding row and feature row.
            E = torch.einsum('ij, ik -> ijk', embeds_, sf)
            # Squared row norms, taken from the diagonal of the Gram matrices.
            y_norm = torch.diagonal(torch.matmul(embeds_, torch.transpose(embeds_, 0, 1)))
            sense_norm = torch.diagonal(torch.matmul(sf, torch.transpose(sf, 0, 1)))
            norm = torch.multiply(y_norm, sense_norm)
            # Transpose so the node axis is last and broadcasts against `norm`.
            E = torch.transpose(torch.transpose(E, 0, 2) / norm, 0, 2)
            # Min-max rescale each node's matrix independently.
            E = (E - torch.amin(E, dim = [-1, -2], keepdim = True)) / (torch.amax(E, dim = [-1, -2], keepdim = True) - torch.amin(E, dim = [-1, -2], keepdim = True))

            E_t = torch.transpose(E, 1, 2)
            E_o = torch.einsum('aij, ajh -> aih', E, E_t)
            E_o = torch.sum(E_o)
            batch_ortho_loss = (self.ortho_ * E_o) / self.batch_size

            batch_sparse_loss = (self.sparse_ * torch.sum(torch.linalg.norm(E, ord = 1, axis = 0))) / self.batch_size

            ortho_loss += batch_ortho_loss
            sparse_loss += batch_sparse_loss

        return ortho_loss, sparse_loss

    def embed(self):
        """Train DGI on ``self.graph`` and cache the embeddings of the best epoch.

        Side effects: writes/reads the checkpoint ``self.model_name + '.pkl'`` and
        sets ``self.node_model``, ``self.fitted``, ``self.embeddings`` and
        ``self.time_per_epoch`` (``None`` when no epoch ran).
        """
        has_sense_features = self.feature_matrix is not None
        if has_sense_features:
            feature_matrix = self.feature_matrix
        else:
            feature_matrix = np.identity(len(self.graph))

        use_cuda = torch.cuda.is_available()

        adj = nx.to_scipy_sparse_array(self.graph)
        features = sp.lil_matrix(feature_matrix)
        features, _ = process.preprocess_features(features)

        nb_nodes = features.shape[0]
        ft_size = features.shape[1]

        adj = process.normalize_adj(adj + sp.eye(adj.shape[0]))

        # `graph_op` is whichever adjacency representation the model will consume.
        if self.sparse:
            graph_op = process.sparse_mx_to_torch_sparse_tensor(adj)
        else:
            adj = (adj + sp.eye(adj.shape[0])).todense()
            graph_op = torch.FloatTensor(adj[np.newaxis])

        features = torch.FloatTensor(features[np.newaxis])

        sense_features = None
        if has_sense_features:
            sense_features = torch.FloatTensor(self.feature_matrix)

        model = DGI(ft_size, self.hid_units, self.nonlinearity)

        # Labels are constant across epochs: ones for the real nodes, zeros for the shuffled ones.
        lbl = torch.cat((torch.ones(self.batch_size, nb_nodes), torch.zeros(self.batch_size, nb_nodes)), 1)

        # Move everything to the GPU once, up front, for both adjacency layouts.
        if use_cuda:
            model = model.cuda()
            features = features.cuda()
            graph_op = graph_op.cuda()
            lbl = lbl.cuda()
            if sense_features is not None:
                sense_features = sense_features.cuda()

        optimiser = torch.optim.Adam(model.parameters(), lr = self.lr, weight_decay = self.l2_coef)

        b_xent = nn.BCEWithLogitsLoss()
        cnt_wait = 0
        best = 1e9
        best_t = 0
        epochs_run = 0

        start_time = time.time()
        for epoch in tqdm(range(self.nb_epochs)):
            epochs_run = epoch + 1
            model.train()
            optimiser.zero_grad()

            # Corrupted input: the same feature rows in a random node order.
            idx = np.random.permutation(nb_nodes)
            shuf_fts = features[:, idx, :]

            logits = model(features, shuf_fts, graph_op, self.sparse, None, None, None)

            # The penalties need the raw feature tensor, which only exists when a
            # feature_matrix was supplied by the caller.
            if self.use_xm and has_sense_features:
                ortho_loss, sparse_loss = self._xm_losses(model, graph_op, sense_features)
                loss = b_xent(logits, lbl) + ortho_loss + sparse_loss
            else:
                loss = b_xent(logits, lbl)

            if self.debug:
                print('Loss:', loss)

            # Compare on a plain float so `best` does not keep an autograd graph alive.
            loss_value = loss.item()
            if loss_value < best:
                best = loss_value
                best_t = epoch
                cnt_wait = 0
                torch.save(model.state_dict(), self.model_name + '.pkl')
            else:
                cnt_wait += 1

            if cnt_wait == self.patience:
                if self.debug:
                    print('Early stopping!')
                break

            loss.backward()
            optimiser.step()

        # Average over the number of epochs actually started (not the last index).
        elapsed = time.time() - start_time
        self.time_per_epoch = elapsed / epochs_run if epochs_run else None

        if self.debug:
            print('Loading {}th epoch'.format(best_t))
        model.load_state_dict(torch.load(self.model_name + '.pkl'))

        self.node_model = model
        self.fitted = True

        embeds, _ = model.embed(features, graph_op, self.sparse, None)
        self.embeddings = embeds

    def get_embedding(self):
        """Return the cached embeddings as a NumPy array with size-1 axes removed."""
        # detach()/cpu() are no-ops for tensors that are already detached / on the CPU.
        return np.squeeze(self.embeddings.detach().cpu().numpy())
    




In [None]:
# NOTE(review): pickle.load can execute arbitrary code while deserialising;
# only open this file if it comes from a trusted source.
with open('./data/squirrel.pkl', 'rb') as file:
    graph_dict = pkl.load(file)

# Rebuild the graph from its dense adjacency matrix so that nodes are labelled
# by matrix index. One round trip is sufficient: repeating the conversion on
# the resulting undirected graph yields the same adjacency matrix again.
graph = nx.Graph(nx.to_numpy_array(graph_dict))

# get_sense_features is a project helper (not defined in this notebook);
# it returns the feature-name mapping and the node feature matrix.
sense_feat_dict, sense_features = get_sense_features(graph, ppr_flag = 'std')


Calculating Personalized Page Rank...                     

4733it [01:46, 47.30it/s]

In [5]:
# Names of the sense features retained for the rest of the notebook.
uncorrelated_feats = [
    'Degree',
    'Clustering Coefficient',
    'Personalized Page Rank - Standard Deviation',
    'Average Neighbor Degree',
    'Average Neighbor Clustering',
    'Eccentricity',
    'Katz Centrality',
]
# Map every retained name to its column position in the current ordering of
# sense_feat_dict, then keep exactly those columns (in the order listed above).
sense_features = sense_features[:, list(map(list(sense_feat_dict).index, uncorrelated_feats))]
# Re-index the name -> column mapping so it matches the reduced matrix.
sense_feat_dict = dict(zip(uncorrelated_feats, range(len(uncorrelated_feats))))

In [15]:
def _fit_and_explain(embed_name, use_xm, ortho_, sparse_, **dgi_kwargs):
    """Train one DGIEmbedding on ``graph`` and derive its explanation matrices.

    Args:
        embed_name: label forwarded to ``find_feature_membership``.
        use_xm, ortho_, sparse_: forwarded to ``DGIEmbedding``.
        **dgi_kwargs: any further ``DGIEmbedding`` keyword arguments (e.g. ``patience``).

    Returns:
        ``(model, embed, feature_dict, explain, error)`` where ``embed`` is the
        min-max scaled embedding, ``feature_dict`` is the raw result of
        ``find_feature_membership``, ``explain`` is its min-max scaled
        ``'explain_norm'`` entry and ``error`` is the element-wise divergence
        between ``sense_features`` and ``embed @ explain_norm``.
    """
    model = DGIEmbedding(graph = graph,
               embed_dim = 128,
               feature_matrix = sense_features,
               use_xm = use_xm,
               ortho_ = ortho_,
               sparse_ = sparse_,
               batch_size = 1,
               model_name = '-',
               **dgi_kwargs)
    embed = model.get_embedding()
    # Rescale the whole embedding matrix to [0, 1].
    embed = (embed - np.min(embed)) / np.ptp(embed)
    # find_feature_membership is a project helper not defined in this notebook.
    feature_dict = find_feature_membership(input_embed = embed,
                                           embed_name = embed_name,
                                           sense_features = sense_features,
                                           sense_feat_dict = sense_feat_dict,
                                           top_k = 8,
                                           solver = 'nmf')

    explain = feature_dict['explain_norm']
    # The error uses the *unscaled* explain matrix; 1e-10 guards log(0) and x/0.
    reconstruction = embed @ explain
    error = sense_features * np.log((sense_features + 1e-10) / (reconstruction + 1e-10)) - sense_features + reconstruction
    explain = (explain - np.min(explain)) / np.ptp(explain)
    return model, embed, feature_dict, explain, error


# Baseline run without the penalty terms. Later cells read embed_og / explain_og,
# so this run has to execute for the notebook to work on a fresh kernel.
# Both runs share the checkpoint file '-.pkl'; they execute one after the other.
dgi_og, embed_og, feature_dict_og, explain_og, error_og = _fit_and_explain(
    embed_name = 'DGI-SF',
    use_xm = False,
    ortho_ = 0,
    sparse_ = 0)

# Run with the ortho/sparse penalties enabled and a longer early-stopping patience.
dgi_plus, embed_plus, feature_dict_plus, explain_plus, error_plus = _fit_and_explain(
    embed_name = 'DGI+XM',
    use_xm = True,
    ortho_ = 100,
    sparse_ = 100,
    patience = 500)

 30%|████████████▏                           | 758/2500 [00:56<02:09, 13.42it/s]

overflow encountered in exp


invalid value encountered in divide



In [16]:
def _explain_heatmap(title_prefix, explain):
    """Build a heatmap of ``explain`` titled with its nuclear norm."""
    figure = go.Figure()
    figure.add_trace(go.Heatmap(z = explain, x = list(sense_feat_dict)))
    figure.update_layout(yaxis_title_text = 'Dimensions',
                         xaxis_title_text = 'Sense Features',
                         title_text = title_prefix + str(np.linalg.norm(explain, ord = 'nuc')))
    return figure


# Baseline embedding.
fig = _explain_heatmap('DGI : ', explain_og)
fig.show()

# Embedding trained with the extra penalty terms.
fig = _explain_heatmap('DGI+ : ', explain_plus)
fig.show()

In [17]:
def _node_explain_parts(Y):
    """Return every intermediate of the per-node explain tensor for embedding ``Y``.

    Output order: outer product, squared row norms of ``Y``, squared row norms
    of ``sense_features``, their product, and the min-max scaled tensor.
    """
    outer = tf.einsum('ij, ik -> ijk', Y, sense_features)
    y_sq = tf.linalg.diag_part(tf.matmul(Y, Y, transpose_b = True), k = 0)
    feat_sq = tf.linalg.diag_part(tf.matmul(sense_features, sense_features, transpose_b = True), k = 0)
    scale = y_sq * tf.cast(feat_sq, tf.float32)
    # Reverse the axes so the node axis is last and broadcasts against `scale`.
    D = tf.transpose(tf.transpose(outer) / scale)
    # Min-max rescale each node's matrix on its own.
    d_min = tf.reduce_min(D, axis = [-1, -2])
    d_max = tf.reduce_max(D, axis = [-1, -2])
    D = (D - tf.reshape(d_min, (-1, 1, 1))) / tf.reshape(d_max - d_min, (-1, 1, 1))
    return outer, y_sq, feat_sq, scale, D


Y_plus = embed_plus
sense_mat, Y_plus_norm, sense_norm, norm, D_plus = _node_explain_parts(Y_plus)


Y_og = embed_og
sense_mat, Y_og_norm, sense_norm, norm, D_og = _node_explain_parts(Y_og)


In [18]:
def _nuclear_norms(D):
    """Nuclear norm of ``D[node]`` for every node index of ``graph``, as a list."""
    return [np.linalg.norm(D[node, :, :], ord = 'nuc') for node in tqdm(range(len(graph)))]


norm_og = _nuclear_norms(D_og)
norm_plus = _nuclear_norms(D_plus)



100%|██████████████████████████████████████| 1186/1186 [00:02<00:00, 552.13it/s]
100%|██████████████████████████████████████| 1186/1186 [00:02<00:00, 570.13it/s]


In [19]:
# Per-node gap between the two nuclear norms (baseline minus DGI+XM).
diff = np.asarray(norm_og) - np.asarray(norm_plus)

fig = go.Figure(data = [go.Histogram(x = diff)])
fig.update_layout(
    font = dict(size = 30),
    paper_bgcolor = 'white',
    plot_bgcolor = 'white',
    yaxis_title_text = 'Frequency',
    xaxis_title_text = 'Difference In Nuclear Norms - DGI vs DGI+XM',
    title_text = 'Distribution of Difference In Nuclear Norms - DGI vs DGI+XM',
)
fig.show()

In [20]:
# Overlay the nuclear-norm distributions of both embeddings in one figure.
fig = go.Figure()
fig.add_traces([
    go.Histogram(x = norm_og, name = 'DGI'),
    go.Histogram(x = norm_plus, name = 'DGI+XM'),
])

fig.update_layout(
    font = dict(size = 20),
    plot_bgcolor = 'white',
    paper_bgcolor = 'white',
    yaxis_title_text = 'Frequency',
    xaxis_title_text = 'Nuclear Norm',
    title_text = 'Distribution of Nuclear Norm of Node Explain Matrix',
)
fig.show()

In [79]:
# Explain matrix of node 110 under the DGI+XM embedding; the trailing
# expression makes the notebook render the figure.
fig = go.Figure(data = [go.Heatmap(z = D_plus[110, :, :])])
fig

In [80]:
# Explain matrix of node 110 under the baseline DGI embedding; the trailing
# expression makes the notebook render the figure.
fig = go.Figure(data = [go.Heatmap(z = D_og[110, :, :])])
fig