In [None]:
import os
import json

import numpy as np
import pandas as pd
import scipy.stats as ss

import matplotlib as mp
%matplotlib inline
import matplotlib.pyplot as plt

import pytorch_lightning as pl

import utils as ut

In [None]:
# Run from the repository root so the relative paths used below (e.g. 'data/evals/') resolve.
# The checkout location can be overridden with the CONGRAT_ROOT environment variable; the
# default keeps the original layout.
os.chdir(os.path.expanduser(os.environ.get('CONGRAT_ROOT', '~/github/congrat')))

In [None]:
# Seed the global RNGs through Lightning's helper so any stochastic step is repeatable.
SEED = 2969591811
pl.seed_everything(SEED)

# Load results

In [None]:
# Directory holding the evaluation JSON files, and which data split to load.
prefix = 'data/evals/'
split = 'test'

# File suffixes are '<lmtype>-<variant>'. The LM type is causal or masked. The variant is
# later labeled base -> alpha = 0 and sim10 -> alpha = 0.1 in the summary tables.
undirected_files = [
    'causal-base',
    'causal-sim10',

    'masked-base',
    'masked-sim10',
]

# Directed graphs were only evaluated with the alpha = 0 ('base') variants.
directed_files = [
    'causal-base',
    'masked-base',
]

# dataset key -> display name and which result files exist for it.
datasets = {
    'pubmed': {
        'name': 'Pubmed (Undirected)',
        'files': undirected_files,
    },

    'trex': {
        'name': 'TRex (Undirected)',
        'files': undirected_files,
    },

    'twitter-small': {
        'name': 'Twitter (Undirected)',
        'files': undirected_files,
    },

    'pubmed-directed': {
        'name': 'Pubmed (Directed)',
        'files': directed_files,
    },

    'trex-directed': {
        'name': 'TRex (Directed)',
        'files': directed_files,
    },

    'twitter-small-directed': {
        'name': 'Twitter (Directed)',
        'files': directed_files,
    },
}

# dataset key -> human-readable name (used to relabel table indexes).
dataset_names = {k: v['name'] for k, v in datasets.items()}

# evals[dataset][file] -> parsed JSON evaluation results, read from
# '<prefix>/<dataset>-<split>-<file>.json'.
evals = {}
for dataset, obj in datasets.items():
    evals[dataset] = {}
    for file in obj['files']:
        path = os.path.join(prefix, f'{dataset}-{split}-{file}.json')
        # JSON is UTF-8 by spec; don't depend on the platform's default locale encoding.
        with open(path, 'rt', encoding='utf-8') as f:
            evals[dataset][file] = json.load(f)

# Prepare data

## Top-k accuracy

In [None]:
# Exact-point top-k accuracies: one row per (dataset, model file, k), sorted by that index.
exact_topk = (
    pd.concat(
        [
            pd.DataFrame(evals[ds][model]['eval_topk_accuracy']['exact_point'])
            .assign(dataset=ds, model=model)
            for ds, spec in datasets.items()
            for model in spec['files']
        ],
        axis=0,
    )
    .set_index(['dataset', 'model', 'k'])
    .sort_index()
)

In [None]:
# Resampled top-k accuracies: one row per (dataset, model file, run, k), sorted by that index.
runs_topk = (
    pd.concat(
        [
            pd.DataFrame(evals[ds][model]['eval_topk_accuracy']['resample'])
            .assign(dataset=ds, model=model)
            for ds, spec in datasets.items()
            for model in spec['files']
        ],
        axis=0,
    )
    .set_index(['dataset', 'model', 'run', 'k'])
    .sort_index()
)

## Within-node similarity pre-post

In [None]:
# Exact-point within-node similarity metrics. Each (dataset, model file) dict becomes a column
# named by that tuple; transposing gives one row per (dataset, model) and one column per metric.
exact_node_recovery = pd.concat(
    [
        pd.Series(evals[ds][model]['eval_within_node_dist_pre_post']['exact_point']).rename((ds, model))
        for ds, spec in datasets.items()
        for model in spec['files']
    ],
    axis=1,
).T

exact_node_recovery.index = exact_node_recovery.index.set_names(['dataset', 'model'])

In [None]:
# Resampled within-node similarity metrics: one row per (dataset, model file, run).
runs_node_recovery = pd.concat(
    [
        pd.DataFrame(evals[ds][model]['eval_within_node_dist_pre_post']['resample'])
        .assign(dataset=ds, model=model)
        for ds, spec in datasets.items()
        for model in spec['files']
    ],
    axis=0,
).set_index(['dataset', 'model', 'run'])

## Correlation of distances

In [None]:
# Resampled embedding-distance coupling correlations: one row per (dataset, model file, run).
runs_dist_corr = pd.concat(
    [
        pd.DataFrame(evals[ds][model]['eval_emb_dist_coupling']['resample'])
        .assign(dataset=ds, model=model)
        for ds, spec in datasets.items()
        for model in spec['files']
    ],
    axis=0,
).set_index(['dataset', 'model', 'run'])

## Embedding distances vs graph distances

In [None]:
# Resampled embedding-vs-graph distance correlations: one row per (dataset, model file, run).
runs_emb_vs_graph = pd.concat(
    [
        pd.DataFrame(evals[ds][model]['eval_emb_dist_vs_graph_dist']['resample'])
        .assign(dataset=ds, model=model)
        for ds, spec in datasets.items()
        for model in spec['files']
    ],
    axis=0,
).set_index(['dataset', 'model', 'run'])

# Top-k accuracy

In [None]:
# Long-format exact-point top-k accuracy for plotting. There is one row per
# (dataset, lmtype, model, k), with columns model, k, acc, lmtype, dataset, directed.
# 'model' is the file variant with the '<lmtype>-' prefix stripped (e.g. 'base', 'sim10').
# It can also be 'baseline', which is comp_acc taken from the '<lmtype>-base' run.
topk_frames = []

for dataset in datasets:
    # Rows: k; columns: (metric, model file) for this dataset only.
    by_k = exact_topk.loc[dataset, :].unstack(0)

    for lmtype in ['masked', 'causal']:
        lm_prefix = f'{lmtype}-'
        lm_files = [f for f in datasets[dataset]['files'] if f.startswith(lmtype)]

        baseline = by_k.loc[:, ('comp_acc', f'{lmtype}-base')].rename('baseline')

        trained = by_k.loc[:, pd.IndexSlice['trained_acc', lm_files]]
        trained.columns = [
            c.replace(lm_prefix, '') if c.startswith(lm_prefix) else c
            for c in trained.columns.droplevel(0)
        ]

        # Wide (k x model) -> long rows of (model, k, acc).
        long_form = (
            pd.concat([trained, baseline], axis=1)
            .unstack(0)
            .reset_index()
            .rename({'level_0': 'model', 0: 'acc'}, axis=1)
            .assign(lmtype=lmtype, dataset=dataset)
        )
        topk_frames.append(long_form)

topk_stats = pd.concat(topk_frames, axis=0)

# Dataset keys ending in '-directed' are the directed-graph variants.
topk_stats['directed'] = topk_stats['dataset'].str.endswith('-directed').map({True: 'directed', False: 'undirected'})
topk_stats['dataset'] = topk_stats['dataset'].str.replace('-directed', '', regex=False)
topk_stats['dataset'] = topk_stats['dataset'].map({
    'pubmed': 'pubmed',
    'trex': 'trex',
    'twitter-small': 'twitter',
})
# A dataset key missing from the map above would silently become NaN and vanish from plots.
assert topk_stats['dataset'].notna().all(), 'unmapped dataset key in topk_stats'

In [None]:
ncol = 2
nrow = topk_stats['dataset'].nunique()

# (lmtype value in topk_stats, panel title): one column of panels per LM type.
lm_panels = [('masked', 'Masked'), ('causal', 'Causal')]
# Display names consistent with the summary tables ('TRex', not 'Trex').
panel_titles = {'pubmed': 'Pubmed', 'trex': 'TRex', 'twitter': 'Twitter'}

fig = plt.figure(figsize=(5 * ncol, 5 * nrow), constrained_layout=True)
fig.suptitle('Top-k Accuracy: Predicting Origin Node for Text')

# squeeze=False always returns a 2-D array, so zip() works even with a single dataset row
# (the default would return a bare SubFigure in that case).
subfigs = fig.subfigures(nrows=nrow, ncols=1, squeeze=False)[:, 0]
for subfig, dataset in zip(subfigs, topk_stats['dataset'].unique()):
    ds_stats = topk_stats.loc[topk_stats['dataset'] == dataset, :].drop('dataset', axis=1)

    subfig.suptitle(panel_titles.get(dataset, dataset.title()))
    axes = subfig.subplots(nrows=1, ncols=ncol)

    for ax, (lmtype, title) in zip(axes, lm_panels):
        lm_stats = ds_stats.loc[ds_stats['lmtype'] == lmtype, :]

        # Baseline and alpha = 0 curves exist for both graph types: one column per graph type, indexed by k.
        for model, label_suffix in [('baseline', ' Baseline'), ('base', r', $\alpha = 0$')]:
            curves = lm_stats.loc[lm_stats['model'] == model, :] \
                .pivot(index='k', columns='directed', values='acc')
            curves['directed'].plot(ax=ax, label='Directed' + label_suffix)
            curves['undirected'].plot(ax=ax, label='Undirected' + label_suffix)

        # alpha = 0.1 ('sim10') results are only loaded for undirected graphs.
        sim10 = lm_stats.loc[(lm_stats['model'] == 'sim10') & (lm_stats['directed'] == 'undirected'), :] \
            .set_index('k')['acc']
        if not sim10.empty:
            sim10.plot(ax=ax, label=r'Undirected, $\alpha = 0.1$')

        ax.set_title(title)
        ax.set_xlabel('k')
        ax.set_ylabel('Top-k accuracy')
        ax.legend()

In [None]:
# Exact-point top-k accuracy at k in {1, 5, 10}. Rows are (dataset, k); columns are (metric, model file).
topk_table = exact_topk.loc[pd.IndexSlice[:, :, [1, 5, 10]], ['trained_acc', 'comp_acc']].unstack(1)

# Keep every trained model, plus comp_acc from the '*-base' runs (shown as 'Baseline').
keep = [c for c in topk_table.columns if c[0] == 'trained_acc' or c[1] in ('causal-base', 'masked-base')]
topk_table = topk_table.loc[:, keep]

# Flatten to '<metric>_<lmtype>-<variant>', then strip the metric so headers read '<lmtype>-<variant>'.
topk_table.columns = [metric + '_' + model for metric, model in topk_table.columns]
topk_table = topk_table.rename({
    'comp_acc_causal-base': 'trained_acc_causal-baseline',
    'comp_acc_masked-base': 'trained_acc_masked-baseline',
}, axis=1)
topk_table.columns = [c.replace('trained_acc_', '') if c.startswith('trained_acc_') else c for c in topk_table.columns]
topk_table = topk_table[sorted(topk_table.columns)]

# Split '<lmtype>-<variant>' into a two-level header. pandas >= 2.0 requires n as a keyword.
topk_table.columns = topk_table.columns.str.split('-', n=1, expand=True)
topk_table.columns = topk_table.columns.set_levels(topk_table.columns.levels[0].str.title(), level=0)
topk_table.columns = topk_table.columns.set_levels(topk_table.columns.levels[1].map({
    'base': 'sim0',
    'sim10': 'sim10',
    'sim50': 'sim50',
    'simexp': 'sim10_exp',
    'simexptt': 'sim10_exptt',
    'baseline': 'Baseline',
}), level=1)
topk_table = topk_table[sorted(topk_table.columns)]

topk_table.index = topk_table.index.set_levels(topk_table.index.levels[0].map(dataset_names), level=0)

# Move LM type into the row index: rows become (dataset, lmtype, k), columns the variants.
topk_table = topk_table.stack(0).reorder_levels([0, 2, 1]).sort_index()

tab = topk_table.style \
    .format(precision=3) \
    .apply(ut.bold_largest_by_row, axis=1)

display(tab)

In [None]:
# LaTeX version of the top-k table styled in the previous cell.
topk_latex = tab.to_latex(
    hrules=True,
    # NOTE(review): confirm this spec still matches the table's index-level + data-column count.
    column_format='lr|rrrrrr',
    position='ht',
    label='tab:topk_acc',
    multicol_align='|c',
    position_float='centering',
    environment='table',
    convert_css=True,
)
print(topk_latex)

# Summarize other metrics

In [None]:
# Mean over resamples of each cross-modality metric per (dataset, model file).
# The trained model's value is relabelled 'Joint' and the comp_* value 'Baseline'.
# Select the metric columns *before* averaging, so a non-numeric column in the runs frames
# cannot make .mean() raise (pandas >= 2.0 no longer drops such columns silently).
metric_means = {
    'Distance Coupling': runs_dist_corr.groupby(level=[0, 1])[['trained_corr', 'comp_corr']].mean()
                         .rename({'trained_corr': 'Joint', 'comp_corr': 'Baseline'}, axis=1),

    'Emb. vs Graph: Text': runs_emb_vs_graph.groupby(level=[0, 1])[['trained_text_corr', 'comp_text_corr']].mean()
                           .rename({'trained_text_corr': 'Joint', 'comp_text_corr': 'Baseline'}, axis=1),
}

stats = pd.concat(list(metric_means.values()), keys=list(metric_means), axis=1).reset_index()

# Dataset keys end in '-directed' for directed graphs (same rule as topk_stats).
# Model files are named '<lmtype>-<variant>'.
stats['directed'] = stats['dataset'].str.endswith('-directed').map({True: 'Directed', False: 'Undirected'})
model_parts = stats['model'].str.split('-')
stats['lmtype'] = model_parts.str[0].str.title()
stats['model'] = model_parts.str[1].map({'base': r'$\alpha = 0.0$', 'sim10': r'$\alpha = 0.1$'})
stats['dataset'] = stats['dataset'].map({
    'pubmed': 'Pubmed',
    'trex': 'TRex',
    'twitter-small': 'Twitter',
    'pubmed-directed': 'Pubmed',
    'trex-directed': 'TRex',
    'twitter-small-directed': 'Twitter',
})
# Unmapped names would otherwise become NaN index labels in the paper table.
assert stats['dataset'].notna().all(), 'unmapped dataset key in stats'
assert stats['model'].notna().all(), 'unmapped model variant in stats'

stats = stats.set_index(['dataset', 'directed', 'lmtype', 'model']).sort_index(ascending=True)
stats.index.names = ['Dataset', 'Directed', 'LM Type', 'Sim.']

In [None]:
tab = (
    stats
    .loc[:, ['Distance Coupling', 'Emb. vs Graph: Text']]
    .style
    .format(precision=3)
    .apply(ut.bold_largest_by_metric, axis=1)
)

# MathJax lets the '$\alpha = ...$' index labels render as math in the HTML view.
with pd.option_context('display.html.use_mathjax', True):
    display(tab)

## LaTeX tables for paper

In [None]:
# LaTeX version of the cross-modality table for the paper, as a 'table*' float.
cross_modality_latex = (
    stats
    .loc[:, ['Distance Coupling', 'Emb. vs Graph: Text']]
    .style
    .format(precision=3)
    .apply(ut.bold_largest_by_metric, axis=1)
    .to_latex(
        hrules=True,
        column_format='lllr|rr|rr',
        position='ht',
        label='tab:cross-modality-results',
        multicol_align='|c',
        position_float='centering',
        environment='table*',
        convert_css=True,
    )
)
print(cross_modality_latex)

# Hypothesis tests

## Top-k accuracy

In [None]:
topk_metrics = ['trained_acc', 'comp_acc', 'diff_acc']

# Sanity check: for each (dataset, model, k), compare the mean of each metric over resamples
# with its exact-point value. .mean() replaces describe()[..., 'mean'], which computed every
# other summary statistic only to discard it.
topk_resample_check = (
    runs_topk
    .groupby(level=['dataset', 'model', 'k'])[topk_metrics]
    .mean()
    .add_suffix('_resample')
    .merge(exact_topk, how='inner', left_index=True, right_index=True)
)

# Every exact-point row must have a resample counterpart.
assert topk_resample_check.shape[0] == exact_topk.shape[0]

for metric in topk_metrics:
    topk_resample_check[metric + '_diff'] = topk_resample_check[metric] - topk_resample_check[metric + '_resample']

topk_resample_check[[m + '_diff' for m in topk_metrics]].describe()

In [None]:
# Resampling p-values for top-k accuracy.
# pvals holds the fraction of resamples in which the trained model does worse than the
# baseline (diff_acc < 0), computed within each (dataset, model, k) group. The denominator must
# be that group's own run count; the previous version divided by a count pooled across datasets.
pvals = (runs_topk['diff_acc'] < 0).groupby(level=['dataset', 'model', 'k']).mean().unstack(0)

# Two-sided p-value: twice the smaller tail (never exceeds 1); NaN cells stay NaN.
2 * np.minimum(pvals, 1 - pvals)

## Within-node similarity pre-post

In [None]:
sim_cols = ['trained_sim_avg', 'comp_sim_avg', 'diff_sim_avg']

# For each (dataset, model), compare the mean over resamples with the exact-point value of each
# within-node similarity metric. This is presumably a consistency check on the resampling.
tmp = (
    runs_node_recovery
    .groupby(level=['dataset', 'model'])[sim_cols]
    .mean()
    .add_suffix('_resample')
    .merge(exact_node_recovery, how='inner', left_index=True, right_index=True)
)

for col in sim_cols:
    tmp[col + '_diff'] = tmp[col] - tmp[col + '_resample']

tmp[[col + '_diff' for col in sim_cols]]

In [None]:
# Fraction of resamples, per (dataset, model), where the trained model's within-node similarity is below the baseline's.
pvals = (runs_node_recovery['diff_sim_avg'] < 0).groupby(level=[0, 1]).mean()
# Two-sided p-value: twice the smaller tail (the old formula returned the one-sided p when p > 0.5).
2 * np.minimum(pvals, 1 - pvals)

## Correlation of distances

In [None]:
# Fraction of resamples, per (dataset, model), where the trained distance-coupling correlation is below the baseline's.
pvals = (runs_dist_corr['diff_corr'] < 0).groupby(level=[0, 1]).mean()
# Two-sided p-value: twice the smaller tail (the old formula returned the one-sided p when p > 0.5).
2 * np.minimum(pvals, 1 - pvals)

## Embedding distances vs graph distances

In [None]:
# Fraction of resamples, per (dataset, model), where the trained node-embedding/graph-distance correlation is below the baseline's.
pvals = (runs_emb_vs_graph['diff_node_corr'] < 0).groupby(level=[0, 1]).mean()
# Two-sided p-value: twice the smaller tail (the old formula returned the one-sided p when p > 0.5).
2 * np.minimum(pvals, 1 - pvals)

In [None]:
# Fraction of resamples, per (dataset, model), where the trained text-embedding/graph-distance correlation is below the baseline's.
pvals = (runs_emb_vs_graph['diff_text_corr'] < 0).groupby(level=[0, 1]).mean()
# Two-sided p-value: twice the smaller tail (the old formula returned the one-sided p when p > 0.5).
2 * np.minimum(pvals, 1 - pvals)