In [46]:
%load_ext autoreload
%autoreload 2

import os
import sys

%store -r DISCO_ROOT_FOLDER
if "DISCO_ROOT_FOLDER" in globals():
    os.chdir(DISCO_ROOT_FOLDER)
    sys.path.append(DISCO_ROOT_FOLDER)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [47]:
from src.utils.metrics import METRIC_ABBREV_TABLES

In [48]:
import glob
import itertools
import pandas as pd
import numpy as np


from collections import defaultdict
from collections.abc import Mapping


def deep_update(dict, dict_update):
    """Recursively merge ``dict_update`` into ``dict`` in place and return ``dict``.

    Values that are mappings are merged key by key into the existing entry
    (a fresh dict is started when the key is missing); any other value simply
    overwrites the existing entry.
    """
    for key, new_value in dict_update.items():
        if not isinstance(new_value, Mapping):
            dict[key] = new_value
            continue
        dict[key] = deep_update(dict.get(key, {}), new_value)
    return dict


def group_by_column_and_aggregate_values(df, column, aggregation_func):
    """Aggregate ``column`` per (dataset, measure) pair.

    ``aggregation_func`` is the name of a pandas GroupBy method (e.g. ``"mean"``).
    The result is a ``defaultdict`` keyed by ``(dataset, aggregation_func)``
    whose values map ``(measure, column)`` to the aggregated value.
    """
    grouped_column = df.groupby(["dataset", "measure"])[column]
    aggregate = getattr(grouped_column, aggregation_func)

    result = defaultdict(dict)
    for (dataset, measure), value in aggregate().to_dict().items():
        row_key = (dataset, aggregation_func)
        result[row_key][(measure, column)] = value
    return result


def calc_pairwise_column_aggregation_func_dict(df, columns, aggregation_funcs):
    """Aggregate every column with every aggregation function and merge the results.

    Returns a nested dict ``{(dataset, aggregation_func): {(measure, column): value}}``
    built by deep-merging one ``group_by_column_and_aggregate_values`` result
    per (column, aggregation function) combination.
    """
    merged = {}
    for column, aggregation_func in itertools.product(columns, aggregation_funcs):
        partial = group_by_column_and_aggregate_values(df, column, aggregation_func)
        deep_update(merged, partial)
    return merged


def gather_and_aggregate_data(path, columns, aggregation_funcs):
    """Load all result CSVs below ``path`` and aggregate them per dataset and measure.

    ``path`` is used as a raw glob prefix: every file matching ``f"{path}*/*"``
    is read with ``pd.read_csv`` and the frames are concatenated.

    Parameters
    ----------
    path : str
        Glob prefix of the results directory (e.g. ``"results/density/"``).
    columns : iterable of str
        Result columns to aggregate (e.g. ``"value"``, ``"time"``).
    aggregation_funcs : iterable of str
        Names of pandas GroupBy methods to apply (e.g. ``"mean"``, ``"std"``).

    Returns
    -------
    pandas.DataFrame
        One row per ``(dataset, aggregation_func)`` and one column per
        ``(measure, column)``.

    Raises
    ------
    ValueError
        If no file matches the glob pattern.
    """
    pattern = f"{path}*/*"
    # glob returns files in filesystem-dependent order; sort so the rows are
    # always concatenated (and therefore aggregated) in the same order.
    csv_paths = sorted(glob.glob(pattern))
    if not csv_paths:
        raise ValueError(f"No result files found matching {pattern!r}")

    df = pd.concat([pd.read_csv(csv_path) for csv_path in csv_paths])

    data_dict = calc_pairwise_column_aggregation_func_dict(df, columns, aggregation_funcs)
    return pd.DataFrame.from_dict(data_dict, orient="index")

In [49]:
def reindex_df(df, row_index, column_index, precision=3):
    """Reorder ``df`` to ``row_index`` x ``column_index`` and round to ``precision`` decimals.

    Labels missing from ``df`` become NaN rows/columns. The column target is
    built with ``Index.reindex`` on the existing columns, which yields the new
    column index directly.
    """
    rows_aligned = df.reindex(row_index)
    target_columns, _indexer = rows_aligned.columns.reindex(column_index)
    return rows_aligned.reindex(columns=target_columns).round(precision)


def calc_indices_and_reindex(df, dataset_names, aggregation_funcs, metrics, selection, precision=3):
    """Lay out ``df`` as (dataset, aggregation) rows x (metric, selection) columns.

    Parameters
    ----------
    df : pandas.DataFrame
        Aggregated results with ``(dataset, aggregation_func)`` row keys and
        ``(metric, selection)`` column tuples.
    dataset_names : iterable of str
        Datasets in the desired row order.
    aggregation_funcs : iterable of str
        Aggregations (e.g. ``"mean"``, ``"std"``) in the desired row order.
    metrics : list of str or None
        Metrics in the desired column order. ``None`` uses the metrics present
        in ``df``, each once, in order of first appearance.
    selection : iterable of str
        Result columns (e.g. ``"value"``, ``"time"``) shown per metric.
    precision : int
        Number of decimals passed on to ``reindex_df`` for rounding.

    Returns
    -------
    pandas.DataFrame
        The reindexed and rounded frame; missing combinations are NaN.
    """
    if metrics is None:
        # A metric occurs once per selected result column, so keep each name
        # only once (first-seen order); otherwise the product below would
        # request every (metric, selection) column several times.
        metrics = list(dict.fromkeys(column[0] for column in df.columns))
    row_index = list(itertools.product(dataset_names, aggregation_funcs))
    column_index = list(itertools.product(metrics, selection))
    return reindex_df(df, row_index=row_index, column_index=column_index, precision=precision)

In [50]:
def run_regex(expr_list, path):
    for expr in expr_list:
        expr_ = '\'' + expr + '\''
        path_ = '"' + path + '"'
        !perl -pwi -e {expr_} {path_}

def _escape_perl_replacement(text):
    r"""Return ``text`` escaped for literal use in the replacement part of Perl's ``s///``.

    Backslashes, the ``/`` delimiter and the interpolation sigils ``$`` and
    ``@`` are prefixed with a backslash, so Perl neither interprets them nor
    ends the expression early.
    """
    return "".join("\\" + char if char in "\\/$@" else char for char in str(text))


def regex_file(path, caption, categories=[], metric_abbrev={}):
    r"""Post-process a generated LaTeX table file in place via Perl substitutions.

    Order of operations:

    1. Remove the ``\textbf{Metric}`` header cell, one ``l|`` column from the
       tabular spec and every ``\midrule`` line (second row-index level).
    2. Replace the ``\caption{TODO}`` placeholder with ``caption``.
    3. With ``categories``: add a leading column holding rotated multirow
       group labels, each group preceded by a ``\midrule``. Without: insert a
       ``\midrule`` after the ``\textbf{Dataset}`` header row.
    4. Rename metrics to their abbreviations from ``metric_abbrev``.
    5. General cleanup: ``table`` -> ``table*``, escape underscores, add a
       space before trailing ``\\``, NaN cells -> ``-``, drop ``\pm 0.0(0)``,
       pad one-decimal numbers to two decimals, and wrap the table in
       ``\arraystretch`` settings.

    Parameters
    ----------
    path : str
        LaTeX file to rewrite in place.
    caption : str
        Caption text; escaped so it is inserted literally (it may contain
        ``/``, ``$``, ``@`` or LaTeX backslashes).
    categories : list of tuple
        ``(first_dataset, n_datasets, label)`` triples. ``first_dataset`` is
        matched (as a Perl regex) at the start of a data row; ``label`` is
        inserted literally.
    metric_abbrev : dict
        Maps metric names to the abbreviation shown in the table.
    """
    general_cleanup_regex = [
        r's/table/table\*/g',
        r's/_/\\_/g',
        r's/\\\\\n/ \\\\\n/g',
        r'1 while s/\$nan( \\pm nan)?\$/-/g',
        # Two digits: drop "\pm 0.0(0)" and pad one-decimal numbers
        r's/\\pm 0.00?\$/\$/g',
        r's/(\d\.\d)( |\$)/${1}0$2/g',
        # arraystretch: larger rows inside the table, reset afterwards
        r's/(\\begin\{table\*\})/\\renewcommand\{\\arraystretch\}\{1.2\}\n\n\n$1/g',
        r's/(\\end\{table\*\}.*$)/$1\n\n\\renewcommand\{\\arraystretch\}\{1\}\n/g',
    ]

    remove_second_row_index_regex = [
        # Remove second level of row index
        r's/ & \\textbf\{Metric\}//g',
        r's/(\{tabular\}\{l\|)l\|/$1/g',
        # Remove all midrules
        r's/\\midrule\n//g',
    ]

    caption_regex = [
        r's/\\caption\{TODO\}/\\caption\{%s\}/g' % _escape_perl_replacement(caption),
    ]

    categories_regex = [
        # General category setup: extra leading "r" column and an empty first
        # cell on every row that contains a column separator
        r's/(\{tabular\}\{)/$1r/g',
        r's/^(.*?& )/& $1/g',
    ] + [
        # One rotated multirow label per category, preceded by a midrule
        r"s/(^& %s)/\\midrule\n\\parbox[t]\{2mm\}\{\\multirow\{%s\}\{*\}\{\\rotatebox[origin=c]\{90\}\{%s\}\}\}\n$1/g"
        % (
            first_dataset_in_category,
            nr_of_datasets_in_category,
            _escape_perl_replacement(category_name),
        )
        for first_dataset_in_category, nr_of_datasets_in_category, category_name in categories
    ]

    run_regex(remove_second_row_index_regex, path)
    run_regex(caption_regex, path)
    if len(categories) > 0:
        run_regex(categories_regex, path)
    else:
        # Insert midrule after headline
        run_regex([r's/(^\\textbf\{Dataset\}.*?$)/$1\n\\midrule/g'], path)

    if metric_abbrev:
        for metric, abbrev in metric_abbrev.items():
            # NOTE(review): only "$" is escaped here; a backslash in an
            # abbreviation would be read by Perl as an escape sequence —
            # confirm METRIC_ABBREV_TABLES contains none (or that it is intended).
            abbrev = abbrev.replace("$", "\\$")
            run_regex([r's/& %s( |\\)/& %s$1/g' % (metric, abbrev)], path)

    # Runs last: the patterns above match raw names (e.g. containing "_")
    # before underscores get escaped here.
    run_regex(general_cleanup_regex, path)

In [51]:
def latex_coloring(
    path,
    skiprows=9,
    axis=0,
    metric_selection=None,
    higher_is_better=None,
    lower_is_better=[],
):
    r"""Shade the numeric cells of a generated LaTeX table green, in place.

    The table body in ``path`` is parsed back into a DataFrame (fields split
    on ``&``). Each cell gets a ``\cellcolor{Green!<s>}`` whose saturation
    ``s`` reflects how good its value is relative to the min/max along
    ``axis``. Every data row in the file is then rewritten via ``run_regex``.

    Parameters
    ----------
    path : str
        LaTeX file to modify in place.
    skiprows : int
        Number of leading lines before the header row with the metric names.
        NOTE(review): depends on the exact preamble written by
        ``evaluation_df_to_latex_table``; ``generate_latex_file`` passes 6.
    axis : int
        ``0`` normalizes per column (per metric), ``1`` per row (per dataset).
    metric_selection : list of str or None
        Columns used to compute the min/max; ``None`` means all columns.
        Columns outside the selection end up with saturation 0 (no color).
    higher_is_better : list of str or None
        Columns where larger values get more saturation; ``None`` means all.
    lower_is_better : list of str
        Columns where smaller values get more saturation. Applied after
        ``higher_is_better`` so it overrides it; names not in the table are
        ignored.
    """
    # Parse the tabular body as '&'-separated text: the first field (dataset
    # name) becomes the index and the last 3 lines (table footer) are skipped.
    df = pd.read_csv(
        path, sep="&", header=0, index_col=0, skiprows=skiprows, skipfooter=3, engine="python"
    )
    # Drop the second index column.
    # NOTE(review): presumably the aggregation/"Metric" level that regex_file removes later.
    df = df.drop(df.columns[0], axis=1)
    # Discard rows without a dataset label and stray \midrule lines.
    if None in df.index:
        df = df.drop(index=[None], axis=0)
    if "\\midrule" in df.index:
        df = df.drop(index=["\\midrule"], axis=0)
    # Normalize header names (drop backslashes and padding) so they can be
    # matched against metric_selection / higher_is_better / lower_is_better.
    df.columns = df.columns.str.replace("\\", "")
    df.columns = df.columns.str.strip()
    # Cells are expected to look like "$<mean> \pm <std>$ ... \\": keep the
    # "\pm <std>" part separately for re-assembly ...
    df_std = df.copy()
    df_std = df_std.replace(r"\$(.*?) ?(\\pm.*?)?\$(.*\\\\)?", value=r"\2", regex=True)
    df_std = df_std.replace(r" $", value="", regex=True)
    # ... and reduce df itself to the numeric mean value.
    df = df.replace(r"\$(.*?)( \\pm.*?)?\$(.*\\\\)?", value=r"\1", regex=True)
    df = df.astype(float)

    if metric_selection is None:
        metric_selection = df.columns
    df_selected = df[metric_selection]

    # Min/max along `axis`, broadcast back to the table shape for the selected
    # columns. Unselected columns keep their own values, so their normalized
    # value below becomes 0/0 = NaN and finally saturation 0.
    df_min = df.copy()
    df_max = df.copy()
    df_min.loc[:, metric_selection] = np.expand_dims(df_selected.min(axis=axis, skipna=True).values, axis=axis)  # type: ignore
    df_max.loc[:, metric_selection] = np.expand_dims(df_selected.max(axis=axis, skipna=True).values, axis=axis)  # type: ignore

    # Normalize to [0, 1] with 1 = best value.
    df_color_saturation = df.copy()
    df_color_saturation.loc[:,:] = 0
    if higher_is_better is None:
        higher_is_better = df.columns
    df_color_saturation.loc[:, higher_is_better] = (
        df.loc[:, higher_is_better] - df_min.loc[:, higher_is_better]
    ) / (df_max.loc[:, higher_is_better] - df_min.loc[:, higher_is_better])
    lower_is_better = [metric for metric in lower_is_better if metric in df.columns]
    # Reversed scale for metrics where smaller is better (overrides the above).
    df_color_saturation.loc[:, lower_is_better] = (
        df_max.loc[:, lower_is_better] - df.loc[:, lower_is_better]
    ) / (df_max.loc[:, lower_is_better] - df_min.loc[:, lower_is_better])
    # Map [0, 1] to a Green tint of 5..70; NaN (zero range, missing value or
    # unselected column) becomes 0, i.e. no color.
    df_color_saturation = df_color_saturation * 65 + 5
    df_color_saturation.replace(np.nan, 0, inplace=True)
    df_color_saturation = df_color_saturation.astype(int)

    # Rebuild every cell as "\cellcolor{Green!<s>} $<mean><std part>$".
    df_latex = df.astype(str).combine(df_color_saturation.astype(str), lambda value, color_saturation: "\\cellcolor{" + "Green" + "!" + color_saturation + r"} $" + value)
    df_latex = df_latex + df_std
    df_latex = df_latex.replace(r" $", value="", regex=True)
    df_latex = df_latex + "$"
    # Prepend the dataset name and join each row into one '&'-separated line.
    df_latex.insert(0, "dataset", df_latex.index.str.strip())
    df_joined_columns = df_latex[df_latex.columns[:]].apply(lambda x: " & ".join(x), axis=1)
    # Double backslashes so Perl treats them literally, both in the
    # replacement (row text) and in the pattern (dataset name).
    df_joined_columns = df_joined_columns.replace("\\\\", "\\\\\\\\", regex=True)
    df_joined_columns.index = df_joined_columns.index.str.replace("\\", "\\\\")

    # Replace each dataset's line, from the dataset name up to the last "\\",
    # with the recolored row.
    # NOTE(review): the dataset name is used as an unanchored, unescaped regex;
    # a name that is a substring of another line (e.g. another dataset name)
    # could rewrite the wrong line — confirm dataset names are distinct.
    for dataset, row in df_joined_columns.items():
        row = row.replace("$", "\\$")
        run_regex([r's/%s.*\\\\/%s \\\\/g'%(dataset, row)], path)

In [52]:
from clustpy.utils import evaluation_df_to_latex_table
from src.utils.metrics import METRIC_ABBREV_TABLES


def generate_latex_file(
    path,
    latex_path,
    dataset_names,
    aggregation_funcs=["mean", "std"],
    metrics=None,
    metric_abbrev=METRIC_ABBREV_TABLES,
    selection=["value", "time", "process_time"],
    caption="TODO",
    categories=[],
    precision=3,
    latex_coloring_axis=None,
    latex_coloring_selection=None,
    higher_is_better=None,
    lower_is_better=[],
):
    """Build one LaTeX results table from the per-run CSVs below ``path``.

    Pipeline:

    1. aggregate the CSVs (``gather_and_aggregate_data``),
    2. lay them out as (dataset, aggregation) x (metric, selection)
       (``calc_indices_and_reindex``),
    3. export with ``evaluation_df_to_latex_table`` to ``latex_path``,
    4. color the cells (``latex_coloring``), only if ``latex_coloring_axis``
       is not ``None``,
    5. post-process the file (``regex_file``) and report the written path.

    The parameter names must stay in sync with the keys of the ``configs``
    dictionaries, which are passed in as keyword arguments.
    """
    aggregated = gather_and_aggregate_data(
        path, columns=selection, aggregation_funcs=aggregation_funcs
    )
    table = calc_indices_and_reindex(
        aggregated, dataset_names, aggregation_funcs, metrics, selection, precision=precision
    )

    # Plain export; highlighting is done by the coloring / regex steps below.
    export_options = {
        "best_in_bold": False,
        "second_best_underlined": False,
        "in_percent": False,
        "decimal_places": 2,
    }
    evaluation_df_to_latex_table(table, latex_path, **export_options)

    if latex_coloring_axis is not None:
        latex_coloring(
            latex_path,
            skiprows=6,
            axis=latex_coloring_axis,
            metric_selection=latex_coloring_selection,
            higher_is_better=higher_is_better,
            lower_is_better=lower_is_better,
        )

    regex_file(latex_path, caption=caption, categories=categories, metric_abbrev=metric_abbrev)
    print(f"Generated: `{latex_path}`")

In [53]:
from datasets.real_world_datasets import Datasets as RealWorldDatasets
from datasets.density_datasets import Datasets as DensityDatasets
# from src.utils.metrics import METRICS, SELECTED_METRICS

# Metrics shown as table columns, in display order. Commented-out entries are
# currently excluded from the tables.
METRICS = [
    "DISCO",
    # "DC_DUNN",
    # Competitors
    "DBCV",
    "DBCV_eucl",
    "DCSI",
    "LCCV",
    "VIASCKDE",
    "CVDD",
    "CDBW",
    "CVNN",
    # "DSI",
    # Gauss
    "SILHOUETTE",
    "S_DBW",
    # "DUNN",
    # "DB",
    # "CH",
]

# Metrics where a smaller value is better; latex_coloring inverts their color
# scale and ignores names that are not table columns.
LOWER_IS_BETTER = ["CVNN", "DCVI", "S_DBW"]

# (first dataset of the group, number of datasets in the group, group label)
REAL_WORLD_CATEGORIES = [
    ("Synth_low", 8, "Tabular data"),
    ("Weizmann", 2, "Video"),
    ("COIL20", 3, "Image"),
    ("Optdigits", 5, "MNIST"),
]


def _column_wise_config(path, latex_path, dataset_enum, aggregation_funcs, caption, categories=()):
    """Keyword arguments for ``generate_latex_file`` yielding a column-wise colored table."""
    return {
        "path": path,
        "latex_path": latex_path,
        "dataset_names": [dataset.name for dataset in dataset_enum],
        "aggregation_funcs": list(aggregation_funcs),
        "metrics": METRICS,
        "lower_is_better": list(LOWER_IS_BETTER),
        "selection": ["value"],
        "caption": caption,
        "categories": list(categories),
        "latex_coloring_axis": 0,
        "latex_coloring_selection": None,
    }


# Row-wise colored variants of any entry can be produced by overriding two keys:
#     {**_column_wise_config(...), "latex_coloring_axis": 1,
#      "latex_coloring_selection": ["DISCO", "DBCV", "DCSI", "S_DBW", "DSI", "SILHOUETTE", "DUNN"]}
configs = {
    # Real World Datasets
    "real_world_colored_column_wise": _column_wise_config(
        path="results/real_world/",
        latex_path="latex/real_world_experiments.tex",
        dataset_enum=RealWorldDatasets,
        aggregation_funcs=["mean"],
        caption="Evaluating on real-world datasets. Column-wise Green.",
        categories=REAL_WORLD_CATEGORIES,
    ),
    "real_world_standardized_colored_column_wise": _column_wise_config(
        path="results/real_world_standardized/",
        latex_path="latex/real_world_experiments_standardized.tex",
        dataset_enum=RealWorldDatasets,
        aggregation_funcs=["mean"],
        caption="Evaluating on real-world datasets (standardized). Column-wise Green.",
        categories=REAL_WORLD_CATEGORIES,
    ),
    # Density Datasets
    "density_colored_column_wise": _column_wise_config(
        path="results/density/",
        latex_path="latex/density_experiments.tex",
        dataset_enum=DensityDatasets,
        aggregation_funcs=["mean", "std"],
        caption="Evaluating on density datasets. Column-wise Green.",
    ),
    "density_standardized_colored_column_wise": _column_wise_config(
        path="results/density_standardized/",
        latex_path="latex/density_experiments_standardized.tex",
        dataset_enum=DensityDatasets,
        aggregation_funcs=["mean", "std"],
        caption="Evaluating on density datasets (standardized). Column-wise Green.",
    ),
}

In [54]:
from mpire.pool import WorkerPool

# Generate all tables in parallel; each config dict is passed to
# generate_latex_file as keyword arguments. The worker count is capped at the
# number of configs (extra workers would idle), and the context manager
# terminates the worker processes even if one of the configs raises.
with WorkerPool(n_jobs=max(1, min(30, len(configs))), use_dill=True) as pool:
    pool.map_unordered(generate_latex_file, configs.values())

Generated: `latex/real_world_experiments_standardized.tex`
Generated: `latex/real_world_experiments.tex`
Generated: `latex/density_experiments.tex`
Generated: `latex/density_experiments_standardized.tex`


In [55]:
from src.utils.metrics import DB, CH, S_DBW

X, l = DensityDatasets.dartboard1.standardized_data_cached
# Drop points labelled -1 (presumably noise) before scoring.
is_labelled = l != -1
S_DBW(X[is_labelled], l[is_labelled])

1.6229067378440645

In [56]:
X, l = DensityDatasets.cluto_t5_8k.data_cached
# Drop points labelled -1 (presumably noise) before scoring.
is_labelled = l != -1
CH(X[is_labelled], l[is_labelled])

47275.21919050223

In [57]:
X, l = DensityDatasets.dartboard1.standardized_data_cached

# Ten row permutations of the same data, each drawn from its own reproducible
# seed, used below to check how a metric reacts to row order.
np.random.seed(0)
seeds = np.random.choice(10_000, size=10, replace=False)
datasets = {}
for run, seed in enumerate(seeds):
    np.random.seed(seed)
    permutation = np.random.choice(len(X), size=len(X), replace=False)
    datasets[run] = (X[permutation], l[permutation])

In [58]:
# S_DBW on every row permutation; a row-order-invariant metric would return
# the same score for all of them.
[S_DBW(features, labels) for features, labels in datasets.values()]

[1.7125248320227824,
 1.7491498886700856,
 1.5734443722526668,
 1.6824502635722405,
 1.6564048633332913,
 1.6882179890285087,
 1.5548063794211258,
 1.7755537118481022,
 1.7491498886700858,
 1.712524832022782]