In [1]:
%load_ext autoreload
%autoreload 2

import os
import sys

%store -r DISCO_ROOT_FOLDER
if "DISCO_ROOT_FOLDER" in globals():
    os.chdir(DISCO_ROOT_FOLDER)
    sys.path.append(DISCO_ROOT_FOLDER)

In [2]:
import glob
import itertools
import pandas as pd

from collections import defaultdict
from collections.abc import Mapping


def deep_update(dict, dict_update):
    """Recursively merge ``dict_update`` into ``dict`` in place and return it.

    Nested mappings are merged key by key; every other value from
    ``dict_update`` overwrites the corresponding entry. If ``dict`` holds a
    non-mapping value (e.g. ``None`` or a scalar) where ``dict_update`` holds a
    mapping, the old value is replaced by the merged mapping instead of raising.

    Args:
        dict: Target mapping, mutated in place. (The name shadows the builtin;
            it is kept for backward compatibility with keyword callers.)
        dict_update: Mapping whose entries are merged into ``dict``.

    Returns:
        The same ``dict`` object, updated.
    """
    for k, v in dict_update.items():
        if isinstance(v, Mapping):
            existing = dict.get(k)
            # Nothing mergeable to recurse into: start from an empty dict.
            if not isinstance(existing, Mapping):
                existing = {}
            dict[k] = deep_update(existing, v)
        else:
            dict[k] = v
    return dict


def group_by_column_and_aggregate_values(df, column, aggregation_func):
    """Aggregate ``column`` per (dataset, measure) group and regroup the result.

    ``aggregation_func`` names a pandas GroupBy method (e.g. ``"mean"``) that is
    looked up with ``getattr`` and called without arguments.

    Returns:
        A ``defaultdict(dict)`` keyed by ``(dataset, aggregation_func)``; each
        value maps ``(measure, column)`` to the aggregated number.
    """
    grouped = df.groupby(["dataset", "measure"])[column]
    aggregate = getattr(grouped, aggregation_func)

    result = defaultdict(dict)
    for group_key, value in aggregate().to_dict().items():
        dataset, measure = group_key
        result[dataset, aggregation_func][measure, column] = value
    return result


def calc_pairwise_column_aggregation_func_dict(df, columns, aggregation_funcs):
    """Aggregate every (column, aggregation) combination and merge the results.

    Each pair from the Cartesian product of ``columns`` and ``aggregation_funcs``
    is aggregated with ``group_by_column_and_aggregate_values``. The partial
    nested dicts are deep-merged into one plain dict, which is returned.
    """
    merged = {}
    for col, func_name in itertools.product(columns, aggregation_funcs):
        deep_update(merged, group_by_column_and_aggregate_values(df, col, func_name))
    return merged


def gather_and_aggregate_data(path, columns, aggregation_funcs):
    """Load all result files under ``path`` and aggregate them into a DataFrame.

    Files are matched with the glob pattern ``f"{path}*/*"``, i.e. every entry
    one directory level below ``path``. ``path`` is used as a plain string
    prefix, so it normally ends with a separator. Each file is read with
    ``pd.read_csv`` and the frames are concatenated before aggregation.

    Args:
        path: Prefix of the results folder, e.g. ``"results/real_world/"``.
        columns: Columns to aggregate.
        aggregation_funcs: Names of pandas GroupBy aggregation methods.

    Returns:
        DataFrame whose rows are keyed by ``(dataset, aggregation)`` and whose
        columns are keyed by ``(measure, column)``.

    Raises:
        ValueError: if no file matches the pattern.
    """
    pattern = f"{path}*/*"
    # Sort so concatenation order is deterministic; glob order is arbitrary.
    result_files = sorted(glob.glob(pattern))
    if not result_files:
        raise ValueError(f"No result files found matching {pattern!r}")

    df = pd.concat([pd.read_csv(result_file) for result_file in result_files])

    data_dict = calc_pairwise_column_aggregation_func_dict(df, columns, aggregation_funcs)
    return pd.DataFrame.from_dict(data_dict, orient="index")

In [3]:
def reindex_df(df, row_index, column_index, precision=3):
    """Reorder ``df`` to the given row/column labels and round its values.

    Labels missing from ``df`` become NaN rows/columns. The target columns are
    built with ``Index.reindex`` on the existing columns, and values are rounded
    to ``precision`` decimals.
    """
    aligned_columns, _indexer = df.columns.reindex(column_index)
    return (
        df.reindex(row_index)
        .reindex(columns=aligned_columns)
        .round(precision)
    )


def calc_indices_and_reindex(df, dataset_names, aggregation_funcs, metrics, selection, precision=3):
    """Lay ``df`` out on a (dataset x aggregation) by (metric x selection) grid.

    Both label lists are Cartesian products; the actual reordering and rounding
    is delegated to ``reindex_df``.
    """
    rows = [*itertools.product(dataset_names, aggregation_funcs)]
    cols = [*itertools.product(metrics, selection)]
    return reindex_df(df, row_index=rows, column_index=cols, precision=precision)

In [4]:
def run_regex(expr_list, path):
    """Apply Perl substitution programs to a file in place, one after another.

    Each entry of ``expr_list`` runs as its own ``perl -pwi -e EXPR PATH``
    invocation, so every expression sees the output of the previous one.

    The command is executed from an argument list, not a shell line. Expressions
    containing single quotes and paths containing spaces therefore reach perl
    verbatim. The function also does not rely on IPython's ``!`` magic, so it
    works in worker processes.

    Args:
        expr_list: Iterable of Perl ``-e`` program strings.
        path: File to rewrite in place.

    Raises:
        FileNotFoundError: if the ``perl`` executable cannot be found.
        subprocess.CalledProcessError: if perl exits with a non-zero status.
    """
    import subprocess

    for expr in expr_list:
        completed = subprocess.run(
            ["perl", "-pwi", "-e", expr, path],
            capture_output=True,
            text=True,
        )
        # Surface perl diagnostics (e.g. `-w` warnings) in the cell output.
        if completed.stderr:
            print(completed.stderr, end="")
        completed.check_returncode()

def regex_file(path, caption, categories=()):
    """Post-process a generated LaTeX table file in place with Perl substitutions.

    Steps, in order:
      1. General clean-up:
         - ``table`` -> ``table*``;
         - escape underscores;
         - replace ``$nan \\pm nan$`` cells with ``-``;
         - pad one-decimal numbers to two decimals;
         - wrap the table in an ``\\arraystretch`` override that is reset
           afterwards.
      2. Drop the second row-index level and all ``\\midrule`` lines.
      3. Replace the ``\\caption{TODO}`` placeholder with ``caption``.
      4. If ``categories`` is non-empty, add a leading column and insert a
         rotated multirow label before the first dataset of each category.

    Args:
        path: LaTeX file to rewrite in place.
        caption: Caption text. It is inserted literally: backslashes and the
            Perl-special characters ``/``, ``$`` and ``@`` are escaped first.
        categories: Iterable of ``(first_dataset_in_category,
            nr_of_datasets_in_category, category_name)`` tuples. The first
            element is used as a regex matched at the start of a table row.
    """

    def _perl_literal(text):
        # Escape what Perl would interpolate (`$`, `@`, `\`) or read as the
        # `s///` delimiter (`/`) so the text lands verbatim in a replacement.
        # Backslash must be escaped first so later escapes are not doubled.
        for char in ("\\", "/", "$", "@"):
            text = text.replace(char, "\\" + char)
        return text

    # Materialize once so generators work and emptiness can be tested.
    categories = list(categories)

    general_regex = [
        r's/table/table\*/g',
        r's/_/\\_/g',
        r'1 while s/\$nan \\pm nan\$/-/g',
        # Pad single-decimal numbers to two decimals
        r's/(\d\.\d) /${{1}}0 /g',
        # Enlarge \arraystretch around the table and reset it afterwards
        r's/(\\begin\{table\*\})/\\renewcommand\{\\arraystretch\}\{1.2\}\n\n\n$1/g',
        r's/(\\end\{table\*\}.*$)/$1\n\n\\renewcommand\{\\arraystretch\}\{1\}\n/g',
    ]

    remove_second_row_index_regex = [
        # Remove second level of row index
        r's/^(.*?& ).*?& /$1/g',
        r's/(\{tabular\}\{l\|)l\|/$1/g',
        # Remove all midrules
        r's/\\midrule\n//g',
    ]

    caption_regex = [
        r's/\\caption\{TODO\}/\\caption\{%s\}/g' % _perl_literal(str(caption)),
    ]

    categories_regex = [
        # Add a leading column that holds the rotated category labels
        r's/(\{tabular\}\{)/$1r/g',
        r's/^(.*?& )/& $1/g',
    ] + [
        # One rotated multirow label per category, placed before its first dataset
        r"s/(^& %s)/\\midrule\n\\parbox[t]\{2mm\}\{\\multirow\{%s\}\{*\}\{\\rotatebox[origin=c]\{90\}\{%s\}\}\}\n$1/g"
        % (first_dataset_in_category, nr_of_datasets_in_category, _perl_literal(str(category_name)))
        for first_dataset_in_category, nr_of_datasets_in_category, category_name in categories
    ]

    run_regex(general_regex, path)
    run_regex(remove_second_row_index_regex, path)
    run_regex(caption_regex, path)
    if categories:
        run_regex(categories_regex, path)


In [5]:
from clustpy.utils import evaluation_df_to_latex_table


# Evaluation measures, in the column order used for the LaTeX tables.
METRICS = (
    # Measures under study (everything after this group is labelled a competitor)
    ["DISCO", "DC_DUNN"]
    # Competitors
    + ["DBCV", "DCSI", "S_DBW", "CDBW", "CVDD", "CVNN", "DSI"]
    # "Gauss" group
    + ["SILHOUETTE", "DUNN", "DB", "CH"]
)


def generate_latex_file(
    path,
    latex_path,
    dataset_names,
    aggregation_funcs=("mean", "std"),
    metrics=METRICS,
    selection=("value", "time", "process_time"),
    caption="TODO",
    categories=(),
    precision=3,
    decimal_places=2,
):
    """Aggregate the results under ``path`` and write them as a LaTeX table.

    Pipeline:
      1. load and aggregate the result files;
      2. lay them out as a (dataset x aggregation) by (metric x selection) grid;
      3. render with clustpy's ``evaluation_df_to_latex_table``;
      4. post-process the file with ``regex_file``.

    Args:
        path: Results-folder prefix passed to ``gather_and_aggregate_data``.
        latex_path: Output ``.tex`` file; its parent folder is created if needed.
        dataset_names: Row order of datasets.
        aggregation_funcs: Names of pandas GroupBy aggregations (row sub-level).
        metrics: Column order of evaluation measures.
        selection: Result columns to aggregate (column sub-level).
        caption: Caption substituted for the ``TODO`` placeholder.
        categories: ``(first_dataset, n_datasets, label)`` tuples for the
            rotated category labels; empty for none.
        precision: Decimals the DataFrame is rounded to before rendering.
        decimal_places: Forwarded as ``decimal_places`` to
            ``evaluation_df_to_latex_table`` (presumably the printed digits;
            previously hard-coded to 2).
    """
    import os

    aggregated_df = gather_and_aggregate_data(
        path, columns=selection, aggregation_funcs=aggregation_funcs
    )
    table_df = calc_indices_and_reindex(
        aggregated_df, dataset_names, aggregation_funcs, metrics, selection, precision=precision
    )

    # Make sure the output folder exists before the table is written.
    latex_dir = os.path.dirname(latex_path)
    if latex_dir:
        os.makedirs(latex_dir, exist_ok=True)

    evaluation_df_to_latex_table(
        table_df,
        latex_path,
        best_in_bold=False,
        second_best_underlined=False,
        in_percent=False,
        decimal_places=decimal_places,
    )
    regex_file(latex_path, caption=caption, categories=categories)

In [6]:
from datasets.real_world_datasets import Datasets as RealWorldDatasets

# Category bands shared by the real-world tables: (first dataset of the band,
# number of datasets in the band, label shown as a rotated row header).
REAL_WORLD_CATEGORIES = [
    ("Synth\\_low", 8, "Tabular data"),
    ("Weizmann", 2, "Video"),
    ("COIL20", 3, "Image"),
    ("Optdigits", 5, "MNIST"),
]


def _real_world_config(suffix, caption):
    """Build the ``generate_latex_file`` keyword arguments for a real-world run.

    ``suffix`` is appended to both the results folder name and the ``.tex``
    file name (e.g. ``""`` or ``"_standardized"``). Fresh lists are created on
    every call so configs never share mutable state.
    """
    return {
        "path": f"results/real_world{suffix}/",
        "latex_path": f"latex/real_world_experiments{suffix}.tex",
        "dataset_names": [dataset.name for dataset in RealWorldDatasets],
        "aggregation_funcs": ["mean"],
        "metrics": METRICS,
        "selection": ["value"],
        "caption": caption,
        "categories": list(REAL_WORLD_CATEGORIES),
    }


configs = {
    "real_world": _real_world_config(
        "", "Evaluating on real-world datasets."
    ),
    "real_world_standardized": _real_world_config(
        "_standardized", "Evaluating on real-world datasets (standardized)."
    ),
}

In [9]:
from mpire.pool import WorkerPool

# Render one LaTeX table per config in parallel. Each config dict is passed to
# mpire, which unpacks dictionaries into keyword arguments of
# `generate_latex_file`. The context manager shuts the workers down even if a
# job raises. Never start more workers than there are jobs (at most 30).
with WorkerPool(n_jobs=max(1, min(30, len(configs))), use_dill=True) as pool:
    pool.map_unordered(generate_latex_file, configs.values())