# imports

In [1]:
import os
import os.path
import pickle
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
from typing import Callable
import re
from math import floor
import statistics

from test_runner import *
from test_runner.translators import *
from test_runner.analysers import SearchResult


# from test_runner import TestCase, BaseTestRunner, LiftedPlanningRunner, GroundedPlanningRunner
# from test_runner.tapaal_caller import QueryResult

from parse_results import translator_result_type, search_result_type#, load_translator_results, load_search_results, 


# Style Config

In [None]:
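# Plot styles, assigned per run configuration in the styling cell below:
# downward (lama_first) = circle ("o-"), purple (#7570b3)
# grounded (rpfs)       = triangle down ("v-"), orange (#d95f02)
# colored (rpfs)        = square ("s-"), green (#1b9e77)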

# Load Data

In [2]:

# Sentinel for "no data" (e.g. a run that found no plan); plotted as a gap.
infinite = float("inf")

# Location of the pickled results and of the rendered plots.
results_dir = "./results"
plot_save_dir = f"{results_dir}/plots"

# Create the plot directory up front so every savefig() call can succeed.
os.makedirs(plot_save_dir, exist_ok=True)

# Directory read by the load_*_results() helpers (same as results_dir).
results_path = results_dir

def load_translator_results(directory: "str | None" = None) -> "translator_result_type":
    """Load the pickled translator results.

    Reads ``translator_results.pickle`` from *directory*, which defaults to
    the module-level ``results_path``.

    Note: ``pickle.load`` can execute arbitrary code, so only point this at
    result files produced by this project's own test runner.
    """
    if directory is None:
        directory = results_path
    with open(os.path.join(directory, "translator_results.pickle"), "rb") as f:
        return pickle.load(f)


def load_search_results(directory: "str | None" = None) -> "search_result_type":
    """Load the pickled search results.

    Reads ``search_results.pickle`` from *directory*, which defaults to the
    module-level ``results_path``.

    Note: ``pickle.load`` can execute arbitrary code, so only point this at
    result files produced by this project's own test runner.
    """
    if directory is None:
        directory = results_path
    with open(os.path.join(directory, "search_results.pickle"), "rb") as f:
        return pickle.load(f)
    

translator_results: translator_result_type = load_translator_results()
search_results: search_result_type = load_search_results()

# Re-index the raw search results (translator -> test_case -> searcher) two ways:
#   searcher  -> translator -> test_case  -> [SearchResult]
#   test_case -> searcher   -> translator -> [SearchResult]
# The loop variable is deliberately NOT named `translator_results`: that would
# overwrite the translator results loaded just above.
search_results_search_translator_test_case: dict["BaseSearcher", dict["BaseTranslator", dict["TestCase", list["SearchResult"]]]] = dict()
search_results_test_case_searcher_translator: dict["TestCase", dict["BaseSearcher", dict["BaseTranslator", list["SearchResult"]]]] = dict()
for translator, results_by_test_case in search_results.items():
    for test_case, results_by_searcher in results_by_test_case.items():
        for search, results in results_by_searcher.items():
            (search_results_search_translator_test_case
                .setdefault(search, dict())
                .setdefault(translator, dict()))[test_case] = results
            (search_results_test_case_searcher_translator
                .setdefault(test_case, dict())
                .setdefault(search, dict()))[translator] = results


#def median(numbers: list[float], total_length: int) -> float:


class ResultCollection:
    """Aggregated statistics for the repeated runs of one (translator, searcher, task).

    Every metric defaults to ``inf``, meaning "no data": the run found no
    plan, or the metric is not computed for this translator. The plotting
    code filters out all-``inf`` series, and matplotlib leaves gaps at
    ``inf`` points.
    """

    # Set for results of the TAPAAL-based translators ("colored"/"grounded").
    is_tapaal: bool = False
    # The individual runs the statistics below were computed from.
    raw: list["SearchResult"]
    # Median / minimum of system + user CPU time over the runs.
    median_time: float = float("inf")
    min_time: float = float("inf")
    # median_time minus the unfolding phase; only differs for "colored".
    median_time_minus_unfolding: float = float("inf")
    median_unfolding_time: float = float("inf")
    median_verification_time: float = float("inf")
    # Median plan length. This is a float because unsolved runs count as inf.
    median_trace_length: float = float("inf")
    

domain_regex = re.compile(r"^(?P<domain>.+?)_(?P<id>\d\d)$", re.MULTILINE)

# Slots pre-allocated per (domain, translator, searcher). Task ids are 1-based
# (slot = id - 1); the list grows if a larger id appears.
default_problems_per_domain = 30


def _empty_result_collection() -> ResultCollection:
    """Placeholder for a task without results: all metrics inf, ``raw`` empty."""
    placeholder = ResultCollection()
    placeholder.raw = []
    return placeholder


def _total_time(result) -> float:
    """System + user CPU time of one run, or inf when the run found no plan."""
    if result.get("has_plan", False):
        return result.time.seconds_system + result.time.seconds_user
    return infinite


def _summarize_runs(translator, results) -> ResultCollection:
    """Aggregate the repeated runs of one (translator, searcher, task).

    Runs without a plan count as inf. A median is therefore finite only when
    more than half of the runs found a plan.
    """
    collection = ResultCollection()
    collection.raw = results

    times = [_total_time(result) for result in results]
    collection.median_time = statistics.median(times)
    collection.median_time_minus_unfolding = collection.median_time
    collection.min_time = min(times)

    if translator.name in ("colored", "grounded"):
        # Mark the collection itself (not a stray module-level variable).
        collection.is_tapaal = True
        collection.median_verification_time = statistics.median(
            [result["time_verification"] if result.get("has_plan", False) else infinite for result in results]
        )

    if translator.name == "colored":
        collection.median_trace_length = statistics.median(
            [len(result.plan.actions) if result.get("has_plan", False) else infinite for result in results]
        )
        collection.median_unfolding_time = statistics.median(
            [result["time_unfolding"] if result.get("has_plan", False) else infinite for result in results]
        )
        collection.median_time_minus_unfolding = statistics.median(
            [_total_time(result) - result["time_unfolding"] if result.get("has_plan", False) else infinite
             for result in results]
        )

    return collection


# Domain -> translator -> searcher -> [ResultCollection], indexed by task id - 1.
# Tasks without results hold an empty ResultCollection (never a bare float),
# so any reduction can be applied to every slot.
search_results_domain_translator_medians: dict[str, dict["BaseTranslator", dict["BaseSearcher", list["ResultCollection"]]]] = dict()
for translator, results_by_test_case in search_results.items():
    for test_case, results_by_searcher in results_by_test_case.items():
        for search, results in results_by_searcher.items():
            if len(results) == 0:
                continue

            match = domain_regex.match(test_case.name)
            if match is None:
                raise ValueError(f"test case name {test_case.name!r} is not of the form '<domain>_<NN>'")
            domain_name = match["domain"]
            domain_problem_id = int(match["id"])
            if domain_problem_id < 1:
                # Ids are 1-based; id 00 would otherwise land in slot -1 (the last one).
                raise ValueError(f"test case {test_case.name!r} has task id 0; ids must start at 1")

            slots = (
                search_results_domain_translator_medians
                .setdefault(domain_name, dict())
                .setdefault(translator, dict())
                .setdefault(search, [])
            )
            while len(slots) < max(default_problems_per_domain, domain_problem_id):
                slots.append(_empty_result_collection())
            slots[domain_problem_id - 1] = _summarize_runs(translator, results)


# NOTE: a draft "IPCMedianPlan" aggregation loop used to live here. It was
# disabled by an unconditional `break` at the top of its innermost loop, so
# its body never ran, although the outer loops still iterated and rebound
# module-level names.
# It was removed. As written, it would have overwritten the ResultCollection
# slots with bare floats and indexed them by `domain_problem_id` instead of
# `domain_problem_id - 1`.
# For the "colored" translator, the median plan length is already stored in
# ResultCollection.median_trace_length by the aggregation loop above.



In [None]:
# Give every searcher a matplotlib format string (`style`) and a `color`.
# Searchers whose style is None are skipped by the plot functions.
# NOTE(review): the style is stored on the searcher object itself. If one
# searcher instance is shared by several translators (e.g. "rpfs" under both
# "colored" and "grounded"), the last assignment wins. Confirm they are
# distinct objects.
run_config_styles = {
    # (translator name, searcher name): (format string, colour)
    ("colored", "rpfs"): ("s-", "#1b9e77"),
    ("grounded", "rpfs"): ("v-", "#d95f02"),
    ("downward", "lama_first"): ("o-", "#7570b3"),
}

unstyled_configs = set()
for translator, results_by_test_case in search_results.items():
    for results_by_searcher in results_by_test_case.values():
        for search in results_by_searcher:
            if search.name == "no_color_optimizations":
                search.style = None
                continue

            style = run_config_styles.get((translator.name, search.name))
            if style is None:
                unstyled_configs.add((translator.name, search.name))
                continue
            search.style, search.color = style

# Report each unstyled configuration once, rather than once per test case.
for translator_name, search_name in sorted(unstyled_configs):
    print(f"Unstyled config: {translator_name}({search_name})")




# Plots

## General Setup

## Cactus Plot

In [None]:
# System + User Time

def make_cactus_plot(name: str, description: str, reduction: Callable[["SearchResult"], float]):
    """Save a grid of per-test-case cactus plots to ``<plot_save_dir>/<name>.png``.

    There is one subplot per test case, in a grid of 6 columns and at least
    8 rows. Each subplot has one line per (searcher, translator) pair. The
    line shows ``reduction(result)`` for every run that found a plan, sorted
    ascending, on a log y-axis. The legend entry reports the median of the
    plotted values.

    Args:
        name: file stem of the saved image.
        description: prefix of every subplot title.
        reduction: maps one SearchResult to the plotted value (e.g. seconds).
    """
    fig_cols = 6
    # Keep the 8x6 layout, but grow it instead of crashing on > 48 test cases.
    fig_rows = max(8, -(-len(search_results_test_case_searcher_translator) // fig_cols))

    plt.figure(figsize=(20, 12))
    for i, (test_case, test_case_results) in enumerate(search_results_test_case_searcher_translator.items(), start=1):
        subplt = plt.subplot(fig_rows, fig_cols, i)
        for searcher, searcher_results in test_case_results.items():
            for translator, results in searcher_results.items():
                # Same "found a plan" test as the rest of the notebook; checking
                # only for the key would also keep runs with has_plan=False.
                times = sorted(reduction(res) for res in results if res.get("has_plan", False))
                median_str = f"{statistics.median(times):.4f}" if times else "N/A"
                plt.plot(times, 'o-', label=f"{translator.name}({searcher.name}) - median={median_str}")

        plt.title(f'{description} - {test_case.name}')
        plt.gca().legend(loc='best')
        plt.xlabel('Index')
        plt.ylabel('Time (sec)')
        subplt.set_yscale("log")

    plt.savefig(os.path.join(plot_save_dir, f"{name}.png"))

# make_cactus_plot("total_time", "System + User time", lambda res: res.time.seconds_system + res.time.seconds_user)


In [None]:
# System + User Time

def make_cactus_plot(name: str, description: str, reduction: Callable[["SearchResult"], float]):
    """Save an unlabelled grid of per-test-case cactus plots to ``<plot_save_dir>/<name>.png``.

    There is one subplot per test case (titled with its name), in a grid of
    6 columns and at least 8 rows. Each subplot has one line per
    (searcher, translator) pair. The line shows ``reduction(result)`` for
    every run that found a plan, sorted ascending, on a log y-axis.
    *description* becomes the figure-level title.

    Args:
        name: file stem of the saved image.
        description: title of the whole figure.
        reduction: maps one SearchResult to the plotted value (e.g. seconds).
    """
    fig_cols = 6
    # Keep the 8x6 layout, but grow it instead of crashing on > 48 test cases.
    fig_rows = max(8, -(-len(search_results_test_case_searcher_translator) // fig_cols))

    fig = plt.figure(figsize=(20, 12))
    for i, (test_case, test_case_results) in enumerate(search_results_test_case_searcher_translator.items(), start=1):
        subplt = plt.subplot(fig_rows, fig_cols, i)
        for searcher_results in test_case_results.values():
            for results in searcher_results.values():
                # Same "found a plan" test as the rest of the notebook.
                times = sorted(reduction(res) for res in results if res.get("has_plan", False))
                # Lines are intentionally unlabelled, so no legend is drawn.
                plt.plot(times, 'o-')

        plt.title(test_case.name)
        plt.xlabel('Index')
        plt.ylabel('Time (sec)')
        subplt.set_yscale("log")

    # Figure-level title; plt.title() here would overwrite the last subplot's title.
    fig.suptitle(description)
    plt.tight_layout()
    plt.savefig(os.path.join(plot_save_dir, f"{name}.png"))

# make_cactus_plot("total_time", "System + User time", lambda res: res.time.seconds_system + res.time.seconds_user)


# Overview Plot
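subplot = domain; each line = one run configuration (translator + searcher)

x = IPC instance (task) id, labelled P01, P02, ...

y = median time (or another per-task metric, chosen by the reduction passed in)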

In [None]:
# System + User Time

# Import the `ticker` submodule explicitly. A bare `import matplotlib` does not
# guarantee it is loaded. This statement also binds the name `matplotlib`.
import matplotlib.ticker

# Label x ticks as IPC task names: 0-based index 0 -> "P01", 1 -> "P02", ...
xtickFormatter = matplotlib.ticker.FuncFormatter(lambda x, pos: f"P{1+x:02.0f}")
# NOTE: a Locator must not be shared between axes. The plot functions create
# a fresh MaxNLocator per subplot instead of reusing this instance.
xtickLocator = matplotlib.ticker.MaxNLocator(integer=True)


def make_overview_cactus_plot(name: str, description: str, reduction: Callable[["ResultCollection"], float]):
    """Save one figure with a subplot per domain to ``<plot_save_dir>/<name>.png``.

    Inside each subplot, every styled (translator, searcher) configuration is
    one line. x is the IPC task index (ticks labelled P01, P02, ...) and y is
    ``reduction(collection)`` on a log scale. A configuration is skipped when
    its searcher has no ``style`` (or ``style`` is None), or when every value
    is inf. Slots that are not ResultCollections are plotted as inf (a gap).

    Args:
        name: file stem of the saved image.
        description: y-axis label, shown on the left-most column.
        reduction: maps one ResultCollection to the plotted value. It should
            return ``infinite`` when there is no data.
    """
    fig_cols = 3
    # Keep the 3x3 layout, but grow it instead of crashing on > 9 domains.
    fig_rows = max(3, -(-len(search_results_domain_translator_medians) // fig_cols))

    scale = 10
    plt.figure(figsize=(scale, scale))
    for i, (domain_name, domain_results) in enumerate(search_results_domain_translator_medians.items(), start=1):
        subplt = plt.subplot(fig_rows, fig_cols, i)

        for translator, translator_results in domain_results.items():
            for searcher, resultCollections in translator_results.items():
                style = getattr(searcher, "style", None)
                if style is None:
                    continue

                reduced_data = [reduction(x) if isinstance(x, ResultCollection) else infinite
                                for x in resultCollections]
                if all(x == infinite for x in reduced_data):
                    continue

                plt.plot(reduced_data, style, color=searcher.color, label=translator.name)
                subplt.xaxis.set_major_formatter(xtickFormatter)
                # Fresh locator per axis: locators must not be shared between axes.
                subplt.xaxis.set_major_locator(matplotlib.ticker.MaxNLocator(integer=True))

        plt.title(domain_name)
        plt.gca().legend(loc='best')
        plt.xlabel('Task index')
        if (i - 1) % fig_cols == 0:
            plt.ylabel(description)
        subplt.set_yscale("log")

    plt.tight_layout()
    plt.savefig(os.path.join(plot_save_dir, f"{name}.png"))

# Alternative reductions, kept for reference:
#make_overview_cactus_plot("overview_min_time", "Min Total Time", lambda res: res.min_time)
#make_overview_cactus_plot("overview_min_verification_time", 'Min verification time', lambda resultCollection: min([res.get("time_verification", infinite) for res in resultCollection.raw]))

make_overview_cactus_plot("overview_time", 'Median Time (sec)', lambda res: res.median_time)
#make_overview_cactus_plot("overview_time_minus_unfolding", 'Median Time (sec) - minus unfolding', lambda res: res.median_time_minus_unfolding)
#make_overview_cactus_plot("overview_unfolding_time", 'Median unfolding Time (sec)', lambda res: res.median_unfolding_time)
make_overview_cactus_plot("overview_verification_time", "Median Verification Time", lambda res: res.median_verification_time)

# Net sizes are read from the first run only (presumably identical across
# repeated runs; TODO confirm). Tasks without any run map to inf (a gap)
# instead of failing on raw[0].
make_overview_cactus_plot(
    "overview_unfolded_places",
    "Unfolded Place Count",
    lambda resultCollection: resultCollection.raw[0].get("place_count_after_reduction", infinite) if resultCollection.raw else infinite,
)
make_overview_cactus_plot(
    "overview_unfolded_transitions",
    "Unfolded Transition Count",
    lambda resultCollection: resultCollection.raw[0].get("transition_count_after_reduction", infinite) if resultCollection.raw else infinite,
)


# Single Domain Overview

In [None]:
# System + User Time

# Re-declared so this cell can run on its own. Import the `ticker` submodule
# explicitly (a bare `import matplotlib` does not guarantee it is loaded).
# This statement also binds the name `matplotlib`.
import matplotlib.ticker

# Label x ticks as IPC task names: 0-based index 0 -> "P01", 1 -> "P02", ...
xtickFormatter = matplotlib.ticker.FuncFormatter(lambda x, pos: f"P{1+x:02.0f}")
# NOTE: a Locator must not be shared between axes; plots create their own.
xtickLocator = matplotlib.ticker.MaxNLocator(integer=True)


def make_domain_cactus_plot(name: str, description: str, rows, cols, reductions: list[tuple[str, Callable[["ResultCollection"], float]]]):
    """Save one figure per domain, with one subplot per reduction.

    Writes ``<plot_save_dir>/<name>_<domain>.png`` for every domain. Subplot
    k of the ``rows`` x ``cols`` grid is titled ``reductions[k][0]``. It shows
    ``reductions[k][1](collection)`` for every styled (translator, searcher)
    configuration: x is the IPC task index, y is on a log scale. The domain
    name labels the y-axis of the left-most column.

    Args:
        name: file-name prefix.
        description: currently unused; kept for interface compatibility.
        rows, cols: subplot grid shape; ``rows * cols`` must be at least
            ``len(reductions)``.
        reductions: (subplot title, reduction) pairs, plotted in order.

    Raises:
        ValueError: if the grid has fewer cells than there are reductions.
    """
    reductions = list(reductions)
    if rows * cols < len(reductions):
        raise ValueError(f"a {rows}x{cols} grid cannot hold {len(reductions)} reductions")

    scale = 10
    for domain_name, domain_results in search_results_domain_translator_medians.items():
        plt.figure(figsize=(scale, scale))
        for i, (reduction_name, reduction) in enumerate(reductions, start=1):
            subplt = plt.subplot(rows, cols, i)
            for translator, translator_results in domain_results.items():
                for searcher, resultCollections in translator_results.items():
                    style = getattr(searcher, "style", None)
                    if style is None:
                        continue

                    # Slots that are not ResultCollections are plotted as gaps.
                    reduced_data = [reduction(x) if isinstance(x, ResultCollection) else infinite
                                    for x in resultCollections]
                    if all(x == infinite for x in reduced_data):
                        continue

                    plt.plot(reduced_data, style, color=searcher.color, label=translator.name)
                    subplt.xaxis.set_major_formatter(xtickFormatter)
                    # Fresh locator per axis: locators must not be shared between axes.
                    subplt.xaxis.set_major_locator(matplotlib.ticker.MaxNLocator(integer=True))

            plt.title(reduction_name)
            plt.gca().legend(loc='best')
            plt.xlabel('Task index')
            # Label the left-most column of the grid, whatever its width.
            if (i - 1) % cols == 0:
                plt.ylabel(domain_name)
            subplt.set_yscale("log")

        plt.tight_layout()
        plt.savefig(os.path.join(plot_save_dir, f"{name}_{domain_name}.png"))

#make_overview_cactus_plot("overview_min_time", "Min Total Time", lambda res: res.min_time)
#make_overview_cactus_plot("overview_min_verificaiton_time", 'Min verification time', lambda resultCollection: min([res.get("time_verification", infinite) for res in resultCollection.raw])) 

#make_overview_cactus_plot("overview_time", 'Median Time (sec)', lambda res: res.median_time)
#make_overview_cactus_plot("overview_time_minus_unfolding", 'Median Time (sec) - minus unfolding', lambda res: res.median_time_minus_unfolding)
#make_overview_cactus_plot("overview_time_minus_unfolding", 'Median unfolding Time (sec)', lambda res: res.median_unfolding_time)

# A list, not a set: a set of (title, lambda) tuples has no stable order
# (string hashes are randomized and lambdas hash by id), so the subplot
# layout would change from run to run.
reductions: list[tuple[str, Callable[["ResultCollection"], float]]] = [
    ('Median Time (sec)', lambda res: res.median_time),
    ("Median Verification Time", lambda res: res.median_verification_time),
    # Net sizes come from the first run only; tasks without runs map to inf.
    ("Unfolded Place Count", lambda resultCollection: resultCollection.raw[0].get("place_count_after_reduction", infinite) if resultCollection.raw else infinite),
    ("Unfolded Transition Count", lambda resultCollection: resultCollection.raw[0].get("transition_count_after_reduction", infinite) if resultCollection.raw else infinite),
]

make_domain_cactus_plot("name", "description", 2, 2, reductions)

## Reductions

In [None]:
# min_time              ", "Min Total Time                     ", lambda res: res.min_time)




# min_verification_time ", 'Min verification time              ', lambda resultCollection: min([res.get("time_verification", infinite) for res in resultCollection.raw])) 




# time                  ", 'Median Time (sec)                  ', lambda res: res.median_time)



# time_minus_unfolding  ", 'Median Time (sec) - minus unfolding', lambda res: res.median_time_minus_unfolding)

# make_overview_cactus_plot("overview_time_minus_unfolding", 'Median Time (sec) - minus unfolding', lambda res: res.median_time_minus_unfolding)

def get_median_time_minus_unfolding(res: "ResultCollection") -> float:
    """Return the median of (system + user time - unfolding time) over *res*'s runs.

    Computed from ``res.raw``:
    - a run without a plan counts as ``inf``;
    - a run with no ``time_unfolding`` entry has nothing subtracted, so for
      such runs this is just the total time;
    - an empty (or missing) ``raw`` yields ``inf``.
    """
    times = [
        result.time.seconds_system + result.time.seconds_user - result.get("time_unfolding", 0)
        if result.get("has_plan", False) else float("inf")
        for result in getattr(res, "raw", None) or []
    ]
    return statistics.median(times) if times else float("inf")



# unfolding_time        ", 'Median unfolding Time (sec)        ', lambda res: res.median_unfolding_time)


# verification_time     ", "Median Verification Time           ", lambda res: res.median_verification_time)


# unfolded_places       ", "Unfolded Place Count               ", lambda resultCollection: resultCollection.raw[0].get("place_count_after_reduction", infinite))


# unfolded_transitions  ", "Unfolded Transition Count          ", lambda resultCollection: resultCollection.raw[0].get("transition_count_after_reduction", infinite))
