In [None]:
import os

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F

import pytorch_lightning as pl
import torch_geometric as pyg

import sklearn.metrics as mt

from tqdm.notebook import tqdm

import clip_graph as cg

import utils as ut

In [None]:
# Work from the repository checkout so that the relative config, checkpoint
# and data paths used in the cells below resolve.
repo_root = os.path.expanduser('~/github/congrat')
os.chdir(repo_root)

In [None]:
# Device that every tensor and model in the evaluation loop is moved to.
device = 'cpu'

In [None]:
# Fix the global random seed so that stochastic steps (e.g. the randomly
# initialised baseline model built in the evaluation loop) are repeatable.
pl.seed_everything(2969591811)

# What should we evaluate?

In [None]:
# Evaluation targets, keyed by dataset name. For each dataset:
#   'svd_init_dataset'  -- YAML config passed to cg.utils.datamodule_from_yaml
#   'svd_init_baseline' -- checkpoint directory of the GNN baseline
#   'svd_init_key'      -- name of the graph_data attribute used as node features
#   'models'            -- {model type: {variant: checkpoint directory}}; in the
#                          paper table 'base' is shown as alpha = 0 and 'sim10'
#                          as alpha = 0.1
# The directed datasets only list a 'base' variant.
datasets = {
    'pubmed': {
        'svd_init_dataset': 'configs/eval-datasets/pubmed/causal.yaml',
        'svd_init_baseline': 'lightning_logs/gnn-pretrain/pubmed/version_5/',
        'svd_init_key': 'x',
        
        'models': {
            'causal': {
                'base': 'lightning_logs/clip-graph/inductive-causal/pubmed/version_0/',
                'sim10': 'lightning_logs/clip-graph/inductive-causal/pubmed/version_2/',
            },
            'masked': {
                'base': 'lightning_logs/clip-graph/inductive-masked/pubmed/version_0/',
                'sim10': 'lightning_logs/clip-graph/inductive-masked/pubmed/version_2/',
            },
        },
    },

    'trex': {
        'svd_init_dataset': 'configs/eval-datasets/trex/causal.yaml',
        'svd_init_baseline': 'lightning_logs/gnn-pretrain/trex/version_5/',
        'svd_init_key': 'x',
        
        'models': {
            'causal': {
                'base': 'lightning_logs/clip-graph/inductive-causal/trex/version_0/',
                'sim10': 'lightning_logs/clip-graph/inductive-causal/trex/version_2/',
            },
            'masked': {
                'base': 'lightning_logs/clip-graph/inductive-masked/trex/version_0/',
                'sim10': 'lightning_logs/clip-graph/inductive-masked/trex/version_2/',
            },
        },
    },

    'twitter_small': {  # don't use a hyphen! things will break!
        'svd_init_dataset': 'configs/eval-datasets/twitter-small/causal.yaml',
        'svd_init_baseline': 'lightning_logs/gnn-pretrain/twitter-small/version_3/',
        'svd_init_key': 'tsvd',
        
        'models': {
            'causal': {
                'base': 'lightning_logs/clip-graph/inductive-causal/twitter-small/version_5/',
                'sim10': 'lightning_logs/clip-graph/inductive-causal/twitter-small/version_6/',
            },
            'masked': {
                'base': 'lightning_logs/clip-graph/inductive-masked/twitter-small/version_5/',
                'sim10': 'lightning_logs/clip-graph/inductive-masked/twitter-small/version_6/',
            },
        },
    },
    
    'pubmed_directed': {
        'svd_init_dataset': 'configs/eval-datasets/pubmed/causal-directed.yaml',
        'svd_init_baseline': 'lightning_logs/gnn-pretrain-directed/pubmed/version_5/',
        'svd_init_key': 'x',
        
        'models': {
            'causal': {
                'base': 'lightning_logs/clip-graph-directed/inductive-causal/pubmed/version_0/',
            },
            'masked': {
                'base': 'lightning_logs/clip-graph-directed/inductive-masked/pubmed/version_0/',
            },
        },
    },

    'trex_directed': {
        'svd_init_dataset': 'configs/eval-datasets/trex/causal-directed.yaml',
        'svd_init_baseline': 'lightning_logs/gnn-pretrain-directed/trex/version_5/',
        'svd_init_key': 'x',
        
        'models': {
            'causal': {
                'base': 'lightning_logs/clip-graph-directed/inductive-causal/trex/version_0/',
            },
            'masked': {
                'base': 'lightning_logs/clip-graph-directed/inductive-masked/trex/version_0/',
            },
        },
    },

    'twitter_small_directed': {  # don't use a hyphen! things will break!
        'svd_init_dataset': 'configs/eval-datasets/twitter-small/causal-directed.yaml',
        'svd_init_baseline': 'lightning_logs/gnn-pretrain-directed/twitter-small/version_3/',
        'svd_init_key': 'tsvd',

        'models': {
            'causal': {
                'base': 'lightning_logs/clip-graph-directed/inductive-causal/twitter-small/version_1/',
            },
            'masked': {
                'base': 'lightning_logs/clip-graph-directed/inductive-masked/twitter-small/version_1/',
            },
        },
    },
}

# Do the evaluation

In [None]:
def test(z, pos_edge_index, neg_edge_index=None, eps=1e-15):
    """Score node embeddings on link prediction with an inner-product decoder.

    Parameters
    ----------
    z : torch.Tensor
        Node embeddings; ``z.size(0)`` is used as the number of nodes.
    pos_edge_index : torch.Tensor
        Edges that exist in the graph; edges are counted along dimension 1.
    neg_edge_index : torch.Tensor, optional
        Edges that do not exist in the graph. When omitted, negatives are
        drawn with ``pyg.utils.negative_sampling``.
    eps : float
        Added inside the logarithms of the reconstruction loss to avoid
        ``log(0)``.

    Returns
    -------
    dict
        ``'auc'`` (ROC AUC), ``'ap'`` (average precision) and ``'recon'``
        (mean negative log-score of the positive edges plus mean negative
        log of one minus the score of the negative edges).
    """
    if neg_edge_index is None:
        neg_edge_index = pyg.utils.negative_sampling(pos_edge_index, z.size(0))

    decoder = pyg.nn.models.autoencoder.InnerProductDecoder()

    # Evaluation only: embeddings may arrive with requires_grad=True, so turn
    # autograd off rather than build (and keep alive) a graph nobody uses.
    with torch.no_grad():
        pos_dec = decoder(z, pos_edge_index, sigmoid=True)
        neg_dec = decoder(z, neg_edge_index, sigmoid=True)

        recon = (
            -torch.log(pos_dec + eps).mean() +
            -torch.log(1 - neg_dec + eps).mean()
        ).item()

    # Labels: 1 for every positive edge followed by 0 for every negative edge,
    # in the same order as the concatenated scores.
    y = np.concatenate([
        np.ones(pos_edge_index.size(1), dtype=np.int64),
        np.zeros(neg_edge_index.size(1), dtype=np.int64),
    ])
    pred = torch.cat([pos_dec, neg_dec], dim=0).detach().cpu().numpy()

    # Thresholded metrics (accuracy / precision / recall / F1 at 0.5) are
    # deliberately left out: the scores rank edges very well but are poorly
    # calibrated, so only the ranking metrics are reported.
    return {
        'auc': mt.roc_auc_score(y, pred),
        'ap': mt.average_precision_score(y, pred),
        'recon': recon,
    }

In [None]:
# One record per (dataset, model). Dataset and model names are kept as
# separate fields instead of being joined into a single '-'-separated key and
# split apart again, so a hyphen in a name cannot corrupt the index.
records = []

for dataset, paths in tqdm(datasets.items()):
    #
    # Dataset and specific objects to input to models
    #

    dm = cg.utils.datamodule_from_yaml(paths['svd_init_dataset'])['dm']

    # Only the test split is evaluated.
    test_graph = dm.test_dataset.dataset.graph_data
    vx = getattr(test_graph, paths['svd_init_key']).to(device)
    vei = test_graph.edge_index.to(device)
    vnei = test_graph.neg_edge_index.to(device)

    #
    # Baselines
    #

    # Load the baseline checkpoint once and reuse it for both baselines.
    # NOTE(review): if interpret_ckpt_dir draws from the global RNG, loading
    # once instead of twice shifts the random initialisation of the untrained
    # baseline relative to earlier runs of this notebook.
    baseline_ckpt = cg.scoring.interpret_ckpt_dir(paths['svd_init_baseline'], dm)

    ## Fine-tuned for graph autoencoding
    gn_model = baseline_ckpt['model'].model.encoder.to(device).eval()

    ## Same architecture, randomly initialized, totally untrained
    init_args = baseline_ckpt['config']['model']['init_args']
    cls = getattr(cg.models, init_args['model_class_name'])
    # A freshly constructed module starts in training mode; switch to eval so
    # layers such as dropout (if the architecture has any) are inactive.
    bl_model = cls(**init_args['model_params']).to(device).eval()

    #
    # Generate embeddings
    #

    embs = {}

    # Inference only, so no autograd graph is needed for any of the models.
    with torch.no_grad():
        ## First, baselines
        embs['baseline'] = gn_model(vx, vei)['output']
        embs['untrained'] = bl_model(vx, vei)['output']

        ## Other models
        for lmtype, variants in tqdm(paths['models'].items()):
            for mod, path in tqdm(variants.items()):
                cg_model = cg.scoring.interpret_ckpt_dir(path, dm)['model'].model
                cg_model = cg_model.to(device).eval()

                # Model names keep the '<lmtype>_<variant>' form that the
                # table-building cell splits on '_'.
                embs[f'{lmtype}_{mod}'] = F.normalize(cg_model.embed_nodes(vx, vei), p=2, dim=1)

    #
    # Score every set of embeddings on the held-out edges
    #

    for model_name, z in tqdm(embs.items()):
        records.append({
            'dataset': dataset,
            'model': model_name,
            **test(z, vei, vnei),
        })

results = pd.DataFrame(records).set_index(['dataset', 'model']).sort_index()

# Make sure the output directory exists so a long run is not lost at the
# final write.
os.makedirs('data', exist_ok=True)
results.to_csv('data/link-prediction-eval.csv', index=True)

# Examine results

In [None]:
# Reload the saved scores so the analysis below can run without repeating
# the evaluation; rows are indexed by (dataset, model).
results = pd.read_csv(
    'data/link-prediction-eval.csv',
    index_col=['dataset', 'model'],
)

## Display

In [None]:
# Lift pandas' row limit for this one display so the full table is shown.
with pd.option_context('display.max_rows', None):
    display(results)

## Tables for paper

In [None]:
# Models to pull from the results, named as in the 'model' index level.
mods = [
    'baseline', 'untrained',
    'causal_base', 'masked_base',
    'causal_sim10', 'masked_sim10',
]

# Long-format AUC scores: one row per (dataset, model).
tmp = results.loc[pd.IndexSlice[:, mods], :].sort_index()
tmp = tmp['auc'].reset_index().copy()

# Rename the two baselines to the same '<type>_<variant>' pattern as the
# ConGraT models so that every name splits cleanly on '_'.
model_map = {
    **{
        k : k
        for k in tmp['model'].unique()
        if k not in ('baseline', 'untrained')
    },
    'baseline': 'baseline_svd',
    'untrained': 'baseline_untrained',
}

name_parts = tmp['model'].map(model_map).str.split('_')
tmp['type'] = name_parts.str[0]
tmp['model'] = name_parts.str[1]

# The untrained GNN is evaluated but left out of the paper table.
tmp = tmp.loc[tmp['model'] != 'untrained', :]

# Pivot to rows = (type, variant) and columns = dataset.
tmp = tmp.set_index(['dataset', 'type', 'model']).sort_index()
tmp = tmp.unstack(0)
tmp.columns = tmp.columns.droplevel(0)
tmp = tmp.loc[['causal', 'masked', 'baseline'], :]

# Human-readable names for both row levels.
tmp.index = tmp.index.set_levels(tmp.index.levels[0].map({
    'causal': 'Causal',
    'masked': 'Masked',
    'baseline': 'GNN Autoencoder',
}), level=0)

tmp.index = tmp.index.set_levels(tmp.index.levels[1].map({
    'base': r'$\alpha = 0$',
    'sim10': r'$\alpha = 0.1$',
    'svd': 'SVD',
    'untrained': 'Untrained GNN',
}), level=1)

# Order rows by variant first, then by model type. The '$...$' labels sort
# before 'SVD', which leaves the baseline as the last row.
tmp.index = tmp.index.swaplevel()
tmp = tmp.sort_index()

# Fix the column order and replace dataset keys with a two-level
# (graph kind, dataset) header.
column_labels = {
    'pubmed': ('Undirected', 'Pubmed'),
    'trex': ('Undirected', 'TRex'),
    'twitter_small': ('Undirected', 'Twitter'),

    'pubmed_directed': ('Directed', 'Pubmed'),
    'trex_directed': ('Directed', 'TRex'),
    'twitter_small_directed': ('Directed', 'Twitter'),
}

tmp = tmp[list(column_labels)]
tmp.columns = pd.MultiIndex.from_tuples(
    [column_labels[c] for c in tmp.columns],
    names=['', ''],
)

# Collapse the two row levels into a single display label. The baseline gets
# its own label; previously it was relabelled by comparing against
# 'ConGraT-GNN Autoencoder (Baseline)', a string that was never produced
# (the variant label is 'SVD'), so the baseline row kept a 'ConGraT-' prefix.
row_labels = [
    'GAT Autoencoder (Baseline)' if kind == 'GNN Autoencoder'
    else 'ConGraT-' + kind + ' (' + variant + ')'
    for variant, kind in tmp.index
]

tmp.index = pd.Index(row_labels, name='')

In [None]:
def bold_except_last_row(s):
    """Column-wise Styler callback: style every row except the last.

    All rows but the last are passed to ``ut.bold_above_thresh`` together
    with the value in the second-to-last row as the threshold (presumably it
    returns a CSS string per entry -- see ``utils``). The last row always
    gets an empty style.

    Positions are taken with ``.iloc``: the table is indexed by string
    labels, and integer keys in plain ``s[...]`` are deprecated as positions
    in recent pandas.

    NOTE(review): with the row order built by the table cell, the baseline is
    the last row, so the threshold is the last ConGraT row rather than the
    baseline (and is NaN for datasets without a 'sim10' run). Confirm that
    this is the intended reference value.
    """
    styled = ut.bold_above_thresh(s.iloc[:-1], s.iloc[-2])
    unstyled_last = pd.Series([''], index=[s.index[-1]])

    return pd.concat([styled, unstyled_last])

# Three-decimal scores, '--' for missing model/dataset combinations, and
# per-column highlighting via bold_except_last_row.
tab = (
    tmp.style
    .format(precision=3, na_rep='--')
    .apply(bold_except_last_row, axis=0)
)

# MathJax is needed for the '$\alpha$' row labels to render in the notebook.
with pd.option_context('display.html.use_mathjax', True):
    display(tab)

In [None]:
# Export the styled table as a LaTeX 'table*' float for the paper;
# convert_css carries the Styler's CSS highlighting over to LaTeX commands.
latex_table = tab.to_latex(
    hrules=True,
    column_format='lcccccc',
    position='ht',
    label='tab:link-prediction',
    multicol_align='|c',
    position_float='centering',
    environment='table*',
    convert_css=True,
)

print(latex_table)