In [None]:
import os
import sys
import math
from statistics import mean
import pandas as pd
import numpy as np
from scipy.special import comb
from scipy import interp
import scipy
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from noesis import Noesis
import logging
import warnings
from sklearn.metrics import roc_auc_score

from sklearn import metrics

import networkx as nx
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.collections import LineCollection
from matplotlib.colors import ListedColormap, BoundaryNorm, LinearSegmentedColormap

# Plot styling: seaborn theme with a dotted light-grey grid for every figure below.
sns.set_theme()
sns.set_style("whitegrid", {"grid.color": ".6", "grid.linestyle": ":"})

# Silence noisy matplotlib loggers and all DeprecationWarnings.
# NOTE(review): the DeprecationWarning filter can also hide deprecation notices
# from scipy/sklearn that announce upcoming breakages.
logging.getLogger('matplotlib.font_manager').disabled = True
logging.getLogger('matplotlib.axes').disabled = True
warnings.filterwarnings("ignore", category=DeprecationWarning) 
# Register tqdm's pandas integration (progress_apply & co.).
tqdm.pandas()

# Make the project's model modules importable; must run before the local imports
# below. The path is relative to the notebook's working directory.
sys.path.append("../src/models/")


# NOTE(review): predict_dataset uses mutual_information, katz and random_walk,
# so they must be importable when it runs even though they are commented out here.
#import mutual_information
#import katz
#import random_walk
import tgn
import jodie
import dyrep
import tgn_viz
# Root of the processed datasets, laid out as {data_name}/{train,test}/...
data_path = "../data/processed"

# Shared Noesis instance; passed to the Katz / random-walk baselines in
# predict_dataset and ended later with ns.end().
ns = Noesis()


# Utils

In [None]:
def predict_dataset(data_name):
    """Score the test split of ``data_name`` with every model and save the result.

    Runs TGN, Jodie and DyRep (5 runs each, with and without the timestamp
    ablation), then the mutual-information, Katz and random-walk baselines.
    One column per model run is added to the test DataFrame, which is written
    to ``../data/results/test_{data_name}.csv`` and returned.

    Parameters
    ----------
    data_name : str
        Dataset folder under ``data_path`` (e.g. ``'wikipedia'``).

    Returns
    -------
    pandas.DataFrame
        The test split with the prediction columns appended.
    """
    # The baseline modules are commented out in the import cell; import them
    # here so the calls below do not fail with NameError
    # (sys.path already contains ../src/models/).
    import mutual_information
    import katz
    import random_walk

    train = pd.read_csv(os.path.join(data_path, f"{data_name}/train/ml_{data_name}.csv"), index_col=0)
    test = pd.read_csv(os.path.join(data_path, f"{data_name}/test/ml_{data_name}.csv"), index_col=0)
    # CM (presumably common neighbours) and Jaccard are constantly 0 because the
    # graph is bipartite, so they are not computed.

    # (model module, model name, column prefix, extra predict kwargs), in the
    # original training order. The Jodie/DyRep ablation prefixes have no "_"
    # before the run index; they are kept because plot_results and previously
    # saved CSVs use exactly these column names.
    neural_models = [
        (tgn, "tgn_ablation_time", "tgn_ablation_time_", {"ablation": "time"}),
        (jodie, "jodie_ablation_time", "jodie_ablation_time", {"ablation": "time"}),
        (dyrep, "dyrep_ablation_time", "dyrep_ablation_time", {"ablation": "time"}),
        (tgn, "tgn", "tgn_", {}),
        (jodie, "jodie", "jodie_", {}),
        (dyrep, "dyrep", "dyrep_", {}),
    ]
    for module, model_name, prefix, extra_kwargs in neural_models:
        pred, _ = module.predict(data_name, model_name, seed=0, n_runs=5, n_epoch=10, **extra_kwargs)
        for i, val in enumerate(pred):
            test[f"{prefix}{i}"] = val

    test["mutual"], test["mutual_normalized"] = mutual_information.predict(train, test, normalize=True)

    train_json = os.path.join(data_path, f"{data_name}/train/{data_name}_train.json")
    train_gml = os.path.join(data_path, f"{data_name}/train/{data_name}_train.gml")
    test['katz'] = katz.predict(train_json, train_gml, test, ns)
    for i, val in enumerate(random_walk.predict(train_json, train_gml, test, ns)):
        test[f"random_walk_{i}"] = val

    # Create the output folder on a fresh checkout instead of failing in to_csv.
    results_dir = "../data/results"
    os.makedirs(results_dir, exist_ok=True)
    test.to_csv(os.path.join(results_dir, f"test_{data_name}.csv"))

    return test

def _plot_mean_curve(curves, base_x, start_value, baseline_y, xlabel, ylabel, title):
    """Draw each (x, y) curve faintly, then their mean +/- 1 std resampled onto ``base_x``.

    curves : list of (x, y) array pairs; x must be increasing (np.interp requirement).
    start_value : y value forced at ``base_x[0]`` on every resampled curve.
    baseline_y : the two y values of the dashed reference line drawn from x=0 to x=1.
    """
    line_color = [52/255, 97/255, 120/255]
    band_color = [143/255, 195/255, 216/255]
    reference_color = [252/255, 97/255, 31/255]

    fig, ax = plt.subplots()
    resampled = []
    for xs, ys in curves:
        ax.plot(xs, ys, color=line_color, alpha=0.15)
        ys_on_base = np.interp(base_x, xs, ys)
        ys_on_base[0] = start_value
        resampled.append(ys_on_base)
    resampled = np.array(resampled)
    mean_y = resampled.mean(axis=0)
    std_y = resampled.std(axis=0)

    ax.plot(base_x, mean_y, color=line_color, lw=2)
    # Clip the +/-1 std band to the valid [0, 1] range on both sides.
    ax.fill_between(base_x, np.maximum(mean_y - std_y, 0), np.minimum(mean_y + std_y, 1),
                    color=band_color, alpha=0.3)
    ax.plot([0, 1], baseline_y, '--', color=reference_color)
    ax.set_xlim([-0.01, 1.01])
    ax.set_ylim([-0.01, 1.01])
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    plt.show()
    plt.close(fig)


def plot_results(test, n_runs=5):
    """Plot mean ROC and precision-recall curves (with +/-1 std bands) per model.

    Parameters
    ----------
    test : pandas.DataFrame
        Output of predict_dataset: a ``ground_truth`` column plus one score
        column per model run.
    n_runs : int, default 5
        Number of runs per multi-run model (matches predict_dataset's n_runs).

    Models whose score columns are missing from ``test`` are skipped with a message.
    """
    col_names = {"Mutual Information": ["mutual_normalized"],
                 "Random Walk": [f"random_walk_{x}" for x in range(n_runs)],
                 "TGN (ablate timestamps)": [f"tgn_ablation_time_{x}" for x in range(n_runs)],
                 "TGN": [f"tgn_{x}" for x in range(n_runs)],
                 "Jodie (ablate timestamps)": [f"jodie_ablation_time{x}" for x in range(n_runs)],
                 "Jodie": [f"jodie_{x}" for x in range(n_runs)],
                 "DyRep (ablate timestamps)": [f"dyrep_ablation_time{x}" for x in range(n_runs)],
                 "DyRep": [f"dyrep_{x}" for x in range(n_runs)]}

    label = test.ground_truth.values
    base_x = np.linspace(0, 1, 101)
    # A no-skill classifier's precision equals the positive rate (0.5 only for
    # balanced data); assumes ground_truth is binary 0/1.
    no_skill_precision = float(np.mean(label))

    for title, cols in col_names.items():
        missing = [c for c in cols if c not in test.columns]
        if missing:
            print(f"Skipping {title}: missing columns {missing}")
            continue

        roc_curves, roc_aucs = [], []
        pr_curves, pr_aucs = [], []
        for col in cols:
            scores = test[col].values
            fpr, tpr, _ = metrics.roc_curve(label, scores)
            roc_curves.append((fpr, tpr))
            roc_aucs.append(metrics.roc_auc_score(label, scores))

            precision, recall, _ = metrics.precision_recall_curve(label, scores)
            pr_aucs.append(metrics.auc(recall, precision))
            # precision_recall_curve returns decreasing recall; reverse so np.interp sees increasing x.
            pr_curves.append((recall[::-1], precision[::-1]))

        _plot_mean_curve(roc_curves, base_x, 0.0, [0, 1],
                         'False Positive Rate', 'True Positive Rate',
                         f"{title} ROC curve, AUC: {mean(roc_aucs):.4f}")
        _plot_mean_curve(pr_curves, base_x, 1.0, [no_skill_precision, no_skill_precision],
                         'Recall', 'Precision',
                         f"{title} PR curve, AUC: {mean(pr_aucs):.4f}")

# Whole-dataset prediction

In [None]:
# Score every model on lastfm; predict_dataset also writes ../data/results/test_lastfm.csv.
test = predict_dataset(data_name='lastfm')

In [None]:
# Score every model on mooc; predict_dataset also writes ../data/results/test_mooc.csv.
test = predict_dataset(data_name='mooc')

In [None]:
# Score every model on reddit; predict_dataset also writes ../data/results/test_reddit.csv.
test = predict_dataset(data_name='reddit')

In [None]:
# Score every model on wikipedia; predict_dataset also writes ../data/results/test_wikipedia.csv.
test = predict_dataset(data_name='wikipedia')

In [None]:
# End the Noesis instance created in the setup cell. Run this only after all
# predict_dataset calls, which pass `ns` to the Katz and random-walk baselines.
ns.end()

# Time-split dataset prediction

In [None]:
def get_last_models(path):
    """Map each run id to the checkpoint with the highest epoch found in ``path``.

    File names are expected to look like ``folder-model-run-epoch``, where the
    epoch part may carry an extension (e.g. ``wikipedia_0-tgn-2-9.pth``).
    Epochs are compared numerically, so epoch 10 beats epoch 9; a string
    comparison would pick "9.pth" over "10.pth". Entries without exactly four
    '-'-separated fields, or whose run/epoch is not an integer, are ignored.

    Parameters
    ----------
    path : str
        Directory containing the checkpoint files.

    Returns
    -------
    dict[int, str]
        ``{run: os.path.join(path, latest_checkpoint_name)}``. If several
        (folder, model) pairs share a run id, the one sorting last wins.
    """
    latest = {}  # (folder, model, run) -> (epoch number, file name)
    for fname in os.listdir(path):
        parts = fname.split('-')
        if len(parts) != 4:
            continue
        folder, model, run, epoch = parts
        epoch_stem = os.path.splitext(epoch)[0]
        if not (run.isdecimal() and epoch_stem.isdecimal()):
            continue
        key = (folder, model, run)
        epoch_num = int(epoch_stem)
        if key not in latest or epoch_num > latest[key][0]:
            latest[key] = (epoch_num, fname)
    # Iterate in sorted (folder, model, run) order so later groups override earlier ones deterministically.
    return {int(run): os.path.join(path, fname)
            for (_, _, run), (_, fname) in sorted(latest.items())}

In [None]:
def _split_index(folder):
    """Return the trailing integer of a split folder name like ``wikipedia_3``, or None."""
    suffix = folder.rsplit('_', 1)[-1]
    return int(suffix) if suffix.isdecimal() else None


def predict_split(dataset_name, n_runs=3, n_epoch=10):
    """Train TGN, Jodie and DyRep on every time split of ``dataset_name``.

    Splits live in ``../data/processed/split_data/{dataset_name}/{dataset_name}_{n}``.
    Splits are processed in increasing n because split n > 0 warm-starts from the
    latest checkpoints of split n-1 (``../models/{dataset_name}_{n-1}/{model}/saved_checkpoints``).
    A model is skipped for a split when ``../data/results/{split}/{model}``
    already exists. Entries of the split directory that are not ``*_<int>``
    folders are ignored.

    Parameters
    ----------
    dataset_name : str
        Name of the dataset whose splits are trained.
    n_runs, n_epoch : int
        Forwarded to each model's ``predict`` (defaults match the original runs).
    """
    path = f"../data/processed/split_data/{dataset_name}/"
    # (model module, model name, label printed before training, seed)
    models = [
        (tgn, "tgn", "TGN", 11),
        (jodie, "jodie", "Jodie", 112),
        (dyrep, "dyrep", "DyRep", 112),
    ]

    split_folders = []
    for folder in os.listdir(path):
        n_split = _split_index(folder)
        if n_split is None or not os.path.isdir(os.path.join(path, folder)):
            continue
        split_folders.append((n_split, folder))

    # os.listdir order is arbitrary; sort numerically so split n-1 is trained first.
    for n_split, folder in sorted(split_folders):
        datapath = os.path.join(path, folder)
        print(datapath, folder)
        old_folder = f"{dataset_name}_{n_split-1}" if n_split > 0 else None

        for module, model_name, label, seed in models:
            if os.path.isdir(f"../data/results/{folder}/{model_name}"):
                continue  # results for this split/model already exist
            if n_split > 0:
                models_to_load = get_last_models(f"../models/{old_folder}/{model_name}/saved_checkpoints")
            else:
                models_to_load = None
            print(f"Train model {label}...")
            module.predict(folder, model_name,
                           seed=seed, n_runs=n_runs, n_epoch=n_epoch,
                           data_path=datapath,
                           models_to_load=models_to_load)

In [None]:
# Train every time split of each dataset, one dataset after another.
split_datasets = ['wikipedia', 'reddit', 'mooc', 'lastfm']
for dataset_name in split_datasets:
    predict_split(dataset_name)

# TGN visualization training

In [None]:
def inter_from_256(x):
    """Rescale 8-bit channel value(s) from [0, 255] onto [0, 1].

    Accepts a scalar or array-like. Inputs outside [0, 255] are clamped to the
    nearest end point, following np.interp's behaviour.
    """
    channel_range = (0, 255)
    unit_range = (0, 1)
    return np.interp(x, channel_range, unit_range)

# Anchor colours (0-255 RGB) of the custom colormap, spaced evenly from 0 to 1.
rgb_list = [[52,97,120], [92, 166, 154], [252,97,31], [253, 193, 14]]
# Transpose the anchors into one list per channel.
all_red, all_green, all_blue = (list(channel) for channel in zip(*rgb_list))
# build each section
n_section = len(all_red) - 1
# segmentdata rows are (position, value below, value above); identical below/above
# values give a continuous gradient between anchors.
cdict = {
    channel_name: tuple(
        (1/n_section*i, inter_from_256(value), inter_from_256(value))
        for i, value in enumerate(channel_values)
    )
    for channel_name, channel_values in (('red', all_red), ('green', all_green), ('blue', all_blue))
}
red, green, blue = cdict['red'], cdict['green'], cdict['blue']
new_cmap = LinearSegmentedColormap('new_cmap', segmentdata=cdict)


def _padded_limits(values, frac=0.1):
    """Return (lo, hi) axis limits spanning ``values`` plus ``frac`` of their range on each side."""
    lo, hi = float(np.min(values)), float(np.max(values))
    pad = frac * (hi - lo)
    if pad == 0:
        pad = 0.5  # single point / constant coordinate: avoid identical (singular) limits
    return lo - pad, hi + pad


def plot_node_mvmt(x, y, t, cmap=None):
    """Plot one node's 2-D embedding trajectory, coloured by time.

    Consecutive points are joined by line segments, and each segment takes the
    colour of its starting time. Points are scattered on top with the same
    colormap, and a colorbar shows the time scale.

    Parameters
    ----------
    x, y : array-like, 1-D
        Coordinates of the trajectory, one entry per observation.
    t : array-like, 1-D
        Times for each observation, same length as ``x`` and ``y``.
    cmap : matplotlib colormap, optional
        Defaults to the notebook's ``new_cmap``.

    Nothing is drawn (and no figure is created) when the inputs are empty.
    """
    x, y, t = np.asarray(x), np.asarray(y), np.asarray(t)
    if x.size == 0:
        return
    cmap = new_cmap if cmap is None else cmap

    # Build (n-1) line segments [[x_i, y_i], [x_{i+1}, y_{i+1}]].
    points = np.array([x, y]).T.reshape(-1, 1, 2)
    segments = np.concatenate([points[:-1], points[1:]], axis=1)

    fig, ax = plt.subplots()

    # One continuous norm shared by the segments, the scatter and the colorbar.
    norm = plt.Normalize(t.min(), t.max())
    lc = LineCollection(segments, cmap=cmap, norm=norm, alpha=0.3)
    lc.set_array(t[:-1])  # one value per segment: the time at its start point
    lc.set_linewidth(1)
    ax.add_collection(lc)
    fig.colorbar(lc, ax=ax)

    sns.scatterplot(x=x, y=y, c=t, cmap=cmap, norm=norm, alpha=0.5, ax=ax)

    ax.set_xlim(*_padded_limits(x))
    ax.set_ylim(*_padded_limits(y))

    plt.show()
    plt.close(fig)
    

def calc_and_viz(data_name, test, affinity_merge_layer, node=7144, nodes_to_plot=None):
    """Run tgn_viz with one affinity merge layer, store its predictions and plot node trajectories.

    Calls ``tgn_viz.predict`` (1 run, 10 epochs) and adds one
    ``tgn_viz_{affinity_merge_layer}_{i}`` column per run to ``test``. For each
    node in ``nodes_to_plot``, it plots the node's trajectory twice: in the raw
    first two embedding dimensions, and in the first two principal components
    of the standardised embeddings.

    Parameters
    ----------
    data_name : str
        Dataset name forwarded to tgn_viz.predict.
    test : pandas.DataFrame
        Test split; prediction columns are added in place.
    affinity_merge_layer : str
        Variant name forwarded to tgn_viz.predict.
    node : int
        Currently unused; kept so existing ``node=...`` callers keep working.
    nodes_to_plot : iterable of int, optional
        Node ids to plot; defaults to the ids used so far.

    Returns
    -------
    pandas.DataFrame
        The same ``test`` object.
    """
    print(affinity_merge_layer.upper())
    if nodes_to_plot is None:
        nodes_to_plot = [2134, 7066, 3958, 7058, 4309, 6419]

    res = tgn_viz.predict(data_name, f"tgn_viz_{affinity_merge_layer}", seed=0, n_runs=1, n_epoch=10,
                          affinity_merge_layer=affinity_merge_layer)
    preds, nodes, times, embeds = res
    for i, val in enumerate(preds):
        test[f"tgn_viz_{affinity_merge_layer}_{i}"] = val

    nodes, times, embeds = np.asarray(nodes), np.asarray(times), np.asarray(embeds)
    # Standardise each dimension, then project onto 2 principal components.
    embed_pca = PCA(n_components=2).fit_transform(StandardScaler().fit_transform(embeds))

    for n in nodes_to_plot:
        print(n)
        mask = nodes == n
        if not mask.any():
            print(f"node {n} has no embeddings; skipped")
            continue
        try:
            plot_node_mvmt(embeds[mask, 0], embeds[mask, 1], times[mask])
            plot_node_mvmt(embed_pca[mask, 0], embed_pca[mask, 1], times[mask])
        except Exception as err:
            # Keep the sweep going, but report why this node could not be plotted.
            print(f"could not plot node {n}: {err!r}")
    return test

def predict_viz_dataset(data_name, node=7144, affinity_merge_layers=None):
    """Sweep the TGN visualisation model over several affinity merge layers.

    For each variant, in order, this calls calc_and_viz. That call runs
    tgn_viz.predict, adds ``tgn_viz_{variant}_{i}`` column(s) to the test
    DataFrame and plots node trajectories.

    Parameters
    ----------
    data_name : str
        Dataset folder under ``data_path``.
    node : int
        Forwarded unchanged to calc_and_viz.
    affinity_merge_layers : list of str, optional
        Variants to run; defaults to the full sweep below.

    Returns
    -------
    pandas.DataFrame
        The test split with one column per variant and run.
    """
    if affinity_merge_layers is None:
        # Trailing numbers were recorded from an earlier run (presumably test
        # AUC -- TODO confirm). No value was recorded for the last variant.
        # A "default" variant was also tried but is disabled.
        affinity_merge_layers = [
            "extra_layers",                         # 0.5407455559081602
            "extra_layers_had",                     # 0.7102923331217205
            "extra_layers_sincos",                  # 0.5
            "extra_layers_extra6",                  # 0.7357540433690669
            "extra_layers_extra6_sincoshad",        # 0.5464958168453279

            "extra_layers_relu",                    # 0.3875494501740585
            "extra_layers_had_relu",                # 0.5107732625871076
            "extra_layers_sincos_relu",             # 0.5688160155678049
            "extra_layers_extra6_relu",             # 0.5718021841862889
            "extra_layers_extra6_sincoshad_relu",   # 0.415627888526674

            "extra_layers_tanh",                    # 0.4294699816102592
            "extra_layers_had_tanh",                # 0.8038525607161809
            "extra_layers_sincos_tanh",             # 0.5
            "extra_layers_extra6_tanh",             # 0.4583873437927973
            "extra_layers_extra6_sincoshad_tanh",   # 0.6201307368437561

            "extra_layers_sigmoid",                 # 0.3564450016425934
            "extra_layers_had_sigmoid",             # 0.37128277873516746
            "extra_layers_sincos_sigmoid",          # 0.4470434151774143
            "extra_layers_extra6_sigmoid",          # 0.5001662306099629
            "extra_layers_extra6_sincoshad_sigmoid",
        ]

    test = pd.read_csv(os.path.join(data_path, f"{data_name}/test/ml_{data_name}.csv"), index_col=0)
    for layer in affinity_merge_layers:
        test = calc_and_viz(data_name, test, layer, node=node)
    return test

In [None]:
# Run the affinity-merge-layer sweep on mooc and keep the resulting test DataFrame.
test = predict_viz_dataset(data_name='mooc')