In [None]:
import json
from pathlib import Path
from pprint import pprint

import torch
import matplotlib.pyplot as plt

from pami import summarize, ratio_std_mean


def get_all_results(results_file='./results.json'):
    """Load the summarized experiment results, building the JSON cache if needed.

    If ``results_file`` exists, its JSON content is returned as-is.
    Otherwise every file matching ``results/*/t*.pt`` next to
    ``results_file`` is passed to ``summarize``, the collected summaries are
    written to ``results_file`` as JSON, and the list is returned.

    Args:
        results_file: Path (str or ``Path``) of the JSON cache file.

    Returns:
        The decoded JSON content of the cache, or the freshly built list of
        ``summarize`` outputs when the cache did not exist yet.
    """
    path = Path(results_file)
    if path.exists():
        # Context manager: the handle is closed deterministically instead of
        # whenever the garbage collector gets to it.
        with open(path, 'r') as cache:
            return json.load(cache)

    # Sorted so the cache content does not depend on filesystem listing order.
    out = [summarize(experiment)
           for experiment in sorted(path.parent.glob('results/*/t*.pt'))]
    # Closing the handle guarantees the JSON is flushed to disk before return.
    with open(path, 'w') as cache:
        json.dump(out, cache)
    return out


def process(result_dict, average=False, precision=3):
    """Return a copy of ``result_dict`` with ``*_ratios`` entries turned into ``*_error`` entries.

    Every value stored under a key ending in ``_ratios`` is read through its
    ``'mean'`` and ``'std'`` sequences, removed from the copy, and re-added
    (at the end of the mapping) under the same key with the ``_ratios``
    suffix replaced by ``_error``.

    Args:
        result_dict: Mapping to convert; it is left untouched.
        average: When true, store a single ``(mean, std)`` pair holding the
            rounded arithmetic mean of each sequence.  Otherwise store a
            tuple of per-element ``(mean, std)`` pairs, each value rounded.
        precision: Number of decimals handed to ``round``.

    Returns:
        The converted shallow copy.
    """
    suffix = '_ratios'
    converted = result_dict.copy()
    # Snapshot the matching keys first: the mapping is modified in the loop.
    ratio_keys = [name for name in converted if name.endswith(suffix)]
    for name in ratio_keys:
        stats = converted.pop(name)
        means, stds = stats['mean'], stats['std']
        if average:
            summary = (
                round(sum(means) / len(means), precision),
                round(sum(stds) / len(stds), precision),
            )
        else:
            summary = tuple(zip(
                [round(value, precision) for value in means],
                [round(value, precision) for value in stds],
            ))
        converted[name[:-len(suffix)] + '_error'] = summary
    # converted['trace'] = round(784 * converted['std']**2, 2 * precision)
    return converted


def _class_indices(df, default=10):
    """Build, for every row of ``df``, the tuple of class indices to explode alongside it."""
    error_columns = [col for col in df.columns if str(col).endswith('_error')]
    if not error_columns:
        # Nothing to infer the class count from: keep the historical default.
        return [tuple(range(default))] * len(df)
    # One index per entry of the row's first ``*_error`` cell, so the table
    # is not tied to a fixed number of classes.
    return [tuple(range(len(cell))) for cell in df[error_columns[0]]]


def sort(result_list, average=False, pandas=False, key=('std', 'k', 'baseline')):
    """Convert every result with ``process`` and order them by ``key``.

    Args:
        result_list: Iterable of result dicts; each is passed through
            ``process(x, average=average)``.  It is consumed exactly once, so
            one-shot iterables such as ``filter`` objects are accepted.
        average: Forwarded to ``process``.  With ``pandas=True`` and
            ``average=False`` the per-class ``*_error`` tuples are exploded
            into one row per class, labelled by an extra ``'g(x)'`` index
            level holding the class index.
        pandas: When true, return a ``DataFrame`` indexed and sorted by
            ``key``; otherwise return a sorted list of dicts.
        key: Names of the fields to sort by, in priority order.

    Returns:
        A sorted list of processed dicts, or a ``DataFrame`` whose index is
        ``key`` (plus ``'g(x)'`` when rows were exploded).
    """
    processed = [process(x, average=average) for x in result_list]
    key = list(key)  # private copy: 'g(x)' may be appended below
    if not pandas:
        return sorted(processed, key=lambda row: tuple(row[k] for k in key))

    import pandas as pd  # local import keeps pandas optional for the list path

    df = pd.DataFrame(processed)
    if not average:
        df['g(x)'] = _class_indices(df)
        # Explode every non-key column in lockstep: one row per class.
        df = df.set_index(key).apply(pd.Series.explode).reset_index()
        key.append('g(x)')
    return df.set_index(key).sort_values(key)


def to_latex(df, baseline=False):
    """Render a table produced by ``sort(..., pandas=True)`` as LaTeX.

    The three data columns are relabelled ``mean``, ``old var`` and
    ``new var`` and every cell is formatted as ``$a \\pm b$`` from its first
    two elements.  The caller's ``df`` is never modified, so the function can
    be called repeatedly on the same table (e.g. once per ``baseline`` value).

    Args:
        df: DataFrame with exactly three columns whose cells are
            ``(value, deviation)`` pairs.  Its index must contain the levels
            ``'baseline'`` and, when a ``'g(x)'`` level is present, ``'k'``;
            two index levels must remain once those are removed.
        baseline: Only used when there is no ``'g(x)'`` level: keep the rows
            whose ``'baseline'`` index value equals this flag.

    Returns:
        The LaTeX source returned by ``DataFrame.to_latex``.
    """
    # Work on a private copy: every step below used to mutate the argument.
    table = df.copy()
    table.columns = ['mean', 'old var', 'new var']
    if 'g(x)' in table.index.names:
        table = table.reset_index(['baseline', 'k'], drop=True)
        table.index = table.index.rename(
            ['$\\sigma$', '$\\mathbf{g}_i(\\mathbf{x})$'])
    else:
        table = table.reset_index(['baseline'])
        table = table[table['baseline'] == baseline].drop(columns='baseline')
        table.index = table.index.rename(['$\\sigma$', '$k$'])
    formatter = lambda x: f'${x[0]:0.3f} \\pm {x[1]:0.3f}$'
    return table.to_latex(formatters=[formatter] * 3, escape=False, caption='caption')


def plot_histograms(result_file, save=False, show=True):
    """Plot overlaid histograms of the three variance estimates stored in a result file.

    The file is loaded with ``torch.load`` and must contain the entries
    ``'out_var'``, ``'fg_var'`` and ``'fg_bad'``, each a sequence of tensors
    that ``torch.stack`` can combine.  One panel of a 2x5 grid is drawn per
    slice of the transposed stacks (presumably one slice per output class;
    extra slices beyond ten are ignored by ``zip``).  The legend of each
    panel carries the matching entry of ``ratio_std_mean(...)[1]``.

    Args:
        result_file: Path of the ``.pt`` result file.
        save: When true, write the figure to ``<parent directory>.pdf``, i.e.
            a PDF named after the directory containing ``result_file`` and
            placed next to that directory.
        show: When true, display the figure with ``plt.show()``; otherwise
            the figure is closed before returning so that repeated calls do
            not accumulate open figures.
    """
    # NOTE(review): torch.load unpickles the file; only open trusted results.
    data = torch.load(result_file, 'cpu')
    out_var = torch.stack(data['out_var'])
    fg_var = torch.stack(data['fg_var'])
    fg_bad = torch.stack(data['fg_bad'])
    var_ratios = ratio_std_mean(fg_var, out_var)[1]
    bad_ratios = ratio_std_mean(fg_bad, out_var)[1]
    fig, axes = plt.subplots(2, 5, figsize=(17, 5))
    for i, (ax, m, n, o) in enumerate(zip(axes.flatten(), out_var.T, fg_var.T, fg_bad.T)):
        ax.hist(m, 50, alpha=1.0, color='black', label='MonteCarlo')
        ax.hist(n, 50, alpha=0.7, color='blue', label=f'New[{var_ratios[i]:.3f}]')
        ax.hist(o, 50, alpha=0.7, color='red', label=f'Old[{bad_ratios[i]:.3f}]')
        ax.xaxis.set_major_locator(plt.LinearLocator(numticks=3))
        ax.set_title(f'class {i}')
        ax.legend(prop={'family': 'monospace'})
    fig.subplots_adjust(hspace=0.5, wspace=0.3)
    # fig.suptitle(title, y=1.1)
    fig.tight_layout()
    if save:
        path = Path(result_file).absolute()
        # Save this figure explicitly rather than whatever pyplot considers current.
        fig.savefig(str(path.parent) + '.pdf', bbox_inches='tight')
    if show:
        # pyplot.show() takes no figure argument; its only parameter is the
        # keyword ``block``.
        plt.show()
    else:
        plt.close(fig)

In [None]:
# Load the experiment summaries from ./results.json, building that cache
# from the files under ./results on the first run.
out = get_all_results()

In [None]:
# Per-class table, restricted to the runs whose 'k' equals 10000.
table2 = sort(
    [result for result in out if result['k'] == 10000],
    average=False,
    pandas=True,
)
# print(to_latex(table2, baseline=False))  # only if pandas is True
table2

In [None]:
# Table over all runs with the per-class errors averaged into one
# (mean, std) pair per column.
table3 = sort(out, average=True, pandas=True)
# print(to_latex(table3, baseline=False))  # only if pandas is True
# print(to_latex(table3, baseline=True))  # only if pandas is True
table3

In [None]:
# Draw the histogram grid for the t2.pt file of every experiment directory,
# in sorted path order; figures are displayed but not written to disk.
for result_file in sorted(Path('./results').glob('*/t2.pt')):
    plot_histograms(result_file, save=False, show=True)