# Plot and analyze the results of the translation and rotation errors

In [None]:
%reload_ext autoreload
%autoreload 2

In [None]:
import yaml
import casino
import numpy as np
import matplotlib.pyplot as plt


from DITTO.config import RESULTS_DIR

In [None]:
def load_results(singles_dir):
    """Load every result file located directly inside ``singles_dir``.

    Args:
        singles_dir: Directory object supporting ``glob`` (e.g. a
            ``pathlib.Path``) that holds one YAML result file per run.

    Returns:
        dict mapping each file's stem to the object parsed from that file.
        Entries are inserted in sorted path order.
    """
    all_results = {}
    # Sort so the insertion order (and therefore "the first result") is
    # reproducible; the raw glob order depends on the filesystem.
    for single_result_file in sorted(singles_dir.glob("*")):
        if not single_result_file.is_file():
            # "*" also matches sub-directories, which cannot be opened as files.
            continue
        with single_result_file.open("r") as f:
            # NOTE(review): yaml.Loader may construct arbitrary Python objects.
            # Only point this at trusted result files; consider yaml.safe_load
            # if the results contain plain data only.
            tmp = yaml.load(f, Loader=yaml.Loader)

        all_results[single_result_file.stem] = tmp
    return all_results

In [None]:
def print_full_results(all_results):
    """Print the mean and the maximum of every accumulated metric.

    Args:
        all_results: Mapping whose values are the per-run result dicts; each
            one is fed into a ``casino`` ``AccumulatorDict`` (presumably
            collecting the values per metric key -- confirm against casino).
    """
    per_run_results = list(all_results.values())
    print(f"Parsing {len(per_run_results)}")

    accumulator = casino.special_dicts.AccumulatorDict()
    for per_run in per_run_results:
        accumulator.increment_dict(per_run)

    # Same report layout for both statistics: a header, then one line per metric.
    summaries = (("### Mean ###", np.mean), ("### Max ###", np.max))
    for header, reduce_fn in summaries:
        print(header)
        for metric, vals in accumulator.items():
            print(f"{metric = }: {reduce_fn(vals):.4f}")
        

In [None]:
from pathlib import Path

# Directory holding one result file per single run of the inter-pose evaluation.
relative_poses_result_dir = RESULTS_DIR / "inter_poses" / "single_runs"
all_results = load_results(relative_poses_result_dir)

# Take the metric names from the first loaded result and drop the timing
# ("duration") entries, which are not plotted.
metrics_list = [
    metric
    for metric in next(iter(all_results.values())).keys()
    if "duration" not in metric
]

print(f"Metrics: {metrics_list}")
print_full_results(all_results)

In [None]:
MAX_VALUE = 1.2  # Upper edge of the histogram bins used in plot_metrics


def plot_metrics(results, metrics_list):
    """Plot one histogram per metric of the row-wise mean error values.

    Translation metrics are drawn in the top row and rotation metrics in the
    bottom row, one column per metric.

    Args:
        results: Mapping from episode id to a per-episode dict that maps a
            metric name to a 2-D array-like of error values.
        metrics_list: Metric names to plot. Only names containing
            "translation" or "rotation" are drawn.
    """
    translation_lists = [metric for metric in metrics_list if "translation" in metric]
    rotation_lists = [metric for metric in metrics_list if "rotation" in metric]

    # Size the grid by the longer row so an uneven translation/rotation split
    # cannot index past the last column; keep at least one column so that
    # plt.subplots gets a valid grid.
    n_cols = max(len(translation_lists), len(rotation_lists), 1)
    fig, axs = plt.subplots(
        nrows=2,
        ncols=n_cols,
        figsize=(10, 10),
        sharey="row",
        squeeze=False,  # always a 2-D array of axes, even for a single column
    )

    # Shared bin edges. Values above MAX_VALUE fall outside the bins and are
    # therefore not shown in the histograms.
    bins = np.linspace(0.0, MAX_VALUE, 20)

    for row_idx, metric_list in enumerate([translation_lists, rotation_lists]):
        for column_idx, metric in enumerate(metric_list):
            values = []
            for episode_results in results.values():
                metric_values = np.array(
                    episode_results[metric]
                )  # (N-1) x (2*T); N: episodes per task; T: time steps
                values.extend(
                    np.mean(metric_values, axis=1)
                )  # Get the mean across all bi-directional timesteps

            ax = axs[row_idx][column_idx]
            ax.hist(values, bins=bins)
            # Put each "/"-separated part of the metric name on its own line.
            broken_lines = "\n".join(metric.split("/"))
            ax.set_title(f"Metric: {broken_lines}")
            ax.set_xlabel("Error Values")

    for ax in axs.flatten():
        ax.set_ylabel("Frequency")
    plt.tight_layout()
    plt.show()


# Draw the error histograms for all loaded results.
plot_metrics(results=all_results, metrics_list=metrics_list)