In [None]:
import json

import matplotlib.pyplot as plt
import numpy as np

from soft_label_learning.config import path_output
from soft_label_learning.experiments.experiment_settings import real_world_settings
from soft_label_learning.experiments.process_synthetic_data import replace_list_item

#### Analyse experiment results

In [None]:
## Settings
# Local (shallow) copy of the shared real-world experiment settings, so that
# top-level reassignments in this notebook do not touch the imported object.
exp_settings = real_world_settings.copy()
# Training fractions to compare; they are used (as strings) to index the results.
train_fractions = [5, 10, 20, 40, 60, 80]
# Opacity of the plotted bars (passed as `alpha` to `ax.bar`).
alpha = 1

#### Load experiment results

In [None]:
# TODO set to the datetime string of the result
time_string = "date_hh_mm_ss"
complete_path = path_output / "real_world" / f"{time_string}_result_dict.json"

# Nested results, indexed further down as
# result_dict_loaded[str(train_frac)][clf][method][eval_setting]["mean"].
try:
    # Explicit UTF-8 so loading does not depend on the platform's default encoding.
    with open(complete_path, encoding="utf-8") as f:
        result_dict_loaded = json.load(f)
except FileNotFoundError as err:
    # The most common cause is the unedited `time_string` placeholder above.
    raise FileNotFoundError(
        f"No result file found at {complete_path}. "
        "Set `time_string` to the timestamp prefix of an existing "
        "*_result_dict.json file."
    ) from err

#### Analyse results

In [None]:
# y-axis label for each evaluation setting; both AUC variants share one label.
metric_dict = {
    **dict.fromkeys(("auc-test-soft_prop-pv", "auc-test-soft_prop-samp"), "AUC"),
    "tvd-test-soft_prop-soft": r"$\overline{TVD}$",
}

In [None]:
# One figure per evaluation setting; the keys of `metric_dict` are the settings
# to plot and its values are the matching y-axis labels.
eval_settings = list(metric_dict)

# Subplot rows, top to bottom; also keys into result_dict_loaded[<train_frac>].
CLASSIFIERS = ["LR", "SGD", "GNB", "DT"]
# Bar layout in x-axis units: method `m` has its tick at m * GROUP_SPACING, and
# the bar for the k-th training fraction sits at tick + BAR_OFFSET + k * BAR_STEP.
GROUP_SPACING = 1.15
BAR_OFFSET = -0.5
BAR_STEP = 0.15
BAR_WIDTH = 0.13
# Padding between the outermost ticks and the x-axis limits. With 14 methods
# this reproduces the previously hardcoded limits (-0.85, 15.5).
X_PAD_LEFT = 0.85
X_PAD_RIGHT = 0.55
# Padding added below the global minimum and above the global maximum.
Y_PAD = 0.02
# Style of the dashed reference lines marking each row's min and max.
MINMAX_LINE_STYLE = {"color": "grey", "linestyle": (0, (10, 10)), "linewidth": 1}


def plot_eval_setting(eval_setting):
    """Build a grouped bar chart figure for one evaluation setting.

    The figure has one subplot row per entry in ``CLASSIFIERS``. Within a row,
    every method in ``exp_settings["method"]`` gets a group of bars, one bar
    per entry of ``train_fractions`` (in list order). The bar height is
    ``result_dict_loaded[str(train_frac)][clf][method][eval_setting]["mean"]``.

    Each row gets two dashed grey lines at its own min and max value. All rows
    share one y-range: the global min/max over every row, padded by ``Y_PAD``.
    The bar opacity is the notebook-level ``alpha`` from the Settings cell.

    Parameters
    ----------
    eval_setting : str
        Key of the evaluation setting; must also be a key of ``metric_dict``.

    Returns
    -------
    matplotlib.figure.Figure
        The populated figure. It is not closed, so it can still be displayed
        or saved.
    """
    methods = exp_settings["method"].copy()
    # Pass a copy: whether replace_list_item mutates its argument is not visible
    # here, and `methods` must keep the raw keys used to index the results.
    tick_labels = replace_list_item(list(methods), "base_clf_s", "SampleClf")
    tick_positions = np.arange(len(methods)) * GROUP_SPACING
    last_frac_idx = len(train_fractions) - 1

    # squeeze=False keeps `axs` 2-D even when there is only one classifier.
    fig, axs = plt.subplots(
        len(CLASSIFIERS), 1, sharex=True, figsize=(12, 8), squeeze=False
    )
    axs = axs[:, 0]

    # uncomment to get title with plot
    # fig.suptitle("Evaluation setting: " + eval_setting)

    minval, maxval = np.inf, -np.inf
    for row, (ax, clf) in enumerate(zip(axs, CLASSIFIERS)):
        # One colour per classifier, starting at C1 (C0 is deliberately unused).
        color = f"C{row + 1}"
        plot_values = []

        for frac_idx, train_frac in enumerate(train_fractions):
            results_for_frac = result_dict_loaded[str(train_frac)][clf]
            for method_idx, method in enumerate(methods):
                result = results_for_frac[method][eval_setting]["mean"]
                plot_values.append(result)
                # Label only the last fraction's bars, so a legend would list
                # each method exactly once.
                bar_kwargs = {"label": method} if frac_idx == last_frac_idx else {}
                ax.bar(
                    method_idx * GROUP_SPACING + BAR_OFFSET + frac_idx * BAR_STEP,
                    result,
                    width=BAR_WIDTH,
                    alpha=alpha,
                    color=color,
                    **bar_kwargs,
                )

        minval = min(minval, min(plot_values))
        maxval = max(maxval, max(plot_values))

        # Method names as rotated x tick labels.
        ax.set_xticks(tick_positions)
        ax.set_xticklabels(tick_labels, rotation=30, ha="right")
        ax.set_title(clf)
        ax.set_ylabel(metric_dict[eval_setting])

        # Dashed reference lines at this classifier's best and worst value;
        # axhline always spans the full axes width, whatever the x-limits are.
        ax.axhline(max(plot_values), **MINMAX_LINE_STYLE)
        ax.axhline(min(plot_values), **MINMAX_LINE_STYLE)

    # sharex=True, so setting the limits on one axis applies to every row.
    axs[-1].set_xlim(-X_PAD_LEFT, tick_positions[-1] + X_PAD_RIGHT)
    fig.tight_layout()

    # A common y-range makes the rows directly comparable.
    for ax in axs:
        ax.set_ylim(minval - Y_PAD, maxval + Y_PAD)

    fig.subplots_adjust(hspace=0.25)
    return fig


figure_dict = {setting: plot_eval_setting(setting) for setting in eval_settings}

In [None]:
# Write each figure to <path_output>/real_world/<eval_setting>.png at 300 dpi.
output_dir = path_output / "real_world"
for setting_name, fig_to_save in figure_dict.items():
    fig_to_save.savefig(output_dir / f"{setting_name}.png", dpi=300)