In [None]:
# Create plot of method execution time

In [None]:
from typing import Optional, Collection
import numpy as np
from matplotlib import pyplot as plt
import os
import pickle as pkl

In [None]:
def gather_results(results_dir) -> list:
    """Load every pickled result in a directory and return the contents as a list.

    Parameters
    ----------
    results_dir : str or path-like
        Directory to scan (non-recursively) for files whose name ends in ".pkl".

    Returns
    -------
    list
        One unpickled object per ".pkl" file, ordered by file name. The list is
        empty when the directory holds no ".pkl" files.
    """
    results = []
    # os.listdir returns entries in arbitrary order; sort so the result order is
    # the same on every run and every filesystem.
    for file in sorted(os.listdir(results_dir)):
        if file.endswith(".pkl"):
            with open(os.path.join(results_dir, file), "rb") as f:
                # NOTE: pickle.load can execute arbitrary code; only point this
                # at result files you produced yourself.
                results.append(pkl.load(f))
    return results


def extract_grouped_values(data: list,
                           independent_key: str,
                           dependent_key: str,
                           filter: Optional[dict] = None):
    """
    Group the dependent values of matching records by their independent value.

    Parameters
    ----------
    data : List[dict]
        List of dicts to parse.
    independent_key : str
        Key to use for the independent variable. A non-empty, non-string
        collection stored under this key is reduced to its first element.
        Records where the key is missing (or maps to None) are dropped.
    dependent_key : str
        Key to use for the dependent variable. A record lacking this key
        contributes None to its group.
    filter : dict, optional
        Restrict the data to only those records matching every given
        ``key: value`` condition. A non-empty list stored in the record is
        compared by its first element.

    Returns
    -------
    independent_values : list
        The unique independent values, in the order they were first seen.
    dependent_values : list of list
        ``dependent_values[i]`` holds the dependent values of every kept record
        whose independent value is ``independent_values[i]``.
    """
    conditions = {} if filter is None else filter

    # Maps each independent value to the list of dependent values seen for it.
    grouped_data = {}
    for record in data:
        matches = True
        for filt_key, filt_value in conditions.items():
            record_value = record.get(filt_key)
            if isinstance(record_value, list) and len(record_value) >= 1:
                record_value = record_value[0]
            if record_value != filt_value:
                matches = False
                break  # one failed condition is enough to reject the record
        if not matches:
            continue

        indep_val = record.get(independent_key)
        # Unwrap single-value containers such as [100] or (100,). Strings and
        # bytes are Collections too, but indexing them would truncate the value
        # to its first character, so they are kept whole.
        if (isinstance(indep_val, Collection)
                and not isinstance(indep_val, (str, bytes))
                and len(indep_val) >= 1):
            indep_val = indep_val[0]
        if indep_val is None:
            continue
        grouped_data.setdefault(indep_val, []).append(record.get(dependent_key))

    # Dicts preserve insertion order, so keys and values line up and follow the
    # order in which each independent value was first encountered (not sorted).
    independent_values = list(grouped_data.keys())
    dependent_values = list(grouped_data.values())

    return independent_values, dependent_values

In [None]:
# Load the first batch of timings from results_exectime2, the directory holding the tradeSeq and GPfates results.
# For each method, ind[method] holds the cell counts and dep[method] the matching lists of execution times.
data = gather_results("results_exectime2/")
ind, dep = {}, {}
for method in ("scTransient", "tradeSeq", "GPfates"):
    ind[method], dep[method] = extract_grouped_values(
        data,
        independent_key="n_cells",
        dependent_key="time",
        filter={"method": method},
    )

In [None]:
# The second batch (results_exectime3) is for scTransient and replaces its entries in ind/dep.
# Because scTransient executes more quickly, it was run with parameters beyond what the other
# methods can process in a reasonable timeframe (e.g. 500k cells).
data2 = gather_results("results_exectime3/")
for method in ("scTransient",):
    ind[method], dep[method] = extract_grouped_values(
        data2,
        independent_key="n_cells",
        dependent_key="time",
        filter={"method": method},
    )

In [None]:
# Create the plot: mean execution time per cell count for each method, on log-log axes,
# with error bars showing the standard deviation across repeated runs.

fig, axs = plt.subplots(dpi=1200)
axs.set_yscale('log')
axs.set_xscale('log')
for method in ["scTransient", "GPfates", "tradeSeq"]:
    # One mean / standard deviation per distinct cell count.
    m = np.array([np.mean(d) for d in dep[method]])
    s = np.array([np.std(d) for d in dep[method]])
    # Cell counts are stored in first-seen order; sort them so the line runs left to right.
    reorder = np.argsort(ind[method])
    axs.errorbar(x=np.array(ind[method])[reorder], y=m[reorder], yerr=s[reorder], label=method, marker=".")
fig.legend(loc=(0.76,0.80))
axs.set_xlabel("Cell count", fontdict={"size": 20})
axs.set_ylabel("Execution time (s)", fontdict={"size": 20})
axs.set_title("Method execution time for TE detection", fontdict={"size": 20})
axs.set_xlim(10,600000)
# tight_layout only accounts for the labels and title that exist when it is called,
# so it must run after they are set; otherwise the large text can be clipped on save.
fig.tight_layout()
fig.savefig("method_exectime.png")
fig.savefig("method_exectime.eps")