In [None]:
import os
import glob

import numpy as np
import pandas as pd

from sklearn.metrics import auc

import matplotlib
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

from tqdm.auto import tqdm

In [None]:
%matplotlib inline

matplotlib.rcParams["figure.figsize"] = [4, 4]
matplotlib.rcParams["figure.dpi"] = 125
matplotlib.rcParams["image.cmap"] = "Spectral_r"
# no bounding boxes or axis:
matplotlib.rcParams["axes.spines.bottom"] = "on"
matplotlib.rcParams["axes.spines.top"] = "off"
matplotlib.rcParams["axes.spines.left"] = "on"
matplotlib.rcParams["axes.spines.right"] = "off"

In [None]:
# Identifiers of the benchmark datasets; each one is also the name of a
# sub-directory of `result_path` that holds its result CSV files.
data_ids = [
    "linear",
    "branch",
    "cluster",
    "covid19-pbmc",
    "bcr-xl",
]

# Root directory of the benchmark results (relative to this notebook).
result_path = "../benchmark"

In [None]:
# Walk over every result CSV of every dataset and collect
#   allAUCs[dataset_id][enrichment][method] -> list of ROC AUC values
#   runtimes[dataset_id][method]            -> list of runtimes
runtimes = {}
allAUCs = {}
for dataset_id in tqdm(data_ids, desc="dataset"):
    directory_path = os.path.join(result_path, dataset_id)

    dataset_aucs = {}
    dataset_runtimes = runtimes.get(dataset_id, {})
    for csv_file in tqdm(glob.glob(f"{directory_path}/*.csv"), desc=dataset_id):
        # Files whose path contains "_batch.csv" are not part of this comparison.
        if "_batch.csv" in csv_file:
            continue

        roc = pd.read_csv(csv_file)
        filename = os.path.basename(csv_file)

        # The method name is the dot-separated field right before the extension.
        method = filename.split(".")[-2]

        # The enrichment is taken from the first "_"-separated token starting
        # with "enr"; when no such token exists it falls back to 0.
        enrichment_token = next(
            (token for token in filename.split("_") if token.startswith("enr")),
            "enr0",
        )
        enrichment = float(enrichment_token[3:])

        # Force the curve to be monotone once as a running maximum and once as a
        # running minimum and keep the larger area -- presumably because the rows
        # may be stored in either direction (TODO confirm against the writer).
        fpr, tpr = roc["FPR"], roc["TPR"]
        auc_running_max = auc(np.maximum.accumulate(fpr), np.maximum.accumulate(tpr))
        auc_running_min = auc(np.minimum.accumulate(fpr), np.minimum.accumulate(tpr))
        roc_auc = max(auc_running_max, auc_running_min)
        dataset_aucs.setdefault(enrichment, {}).setdefault(method, []).append(roc_auc)

        file_runtimes = roc["runtime"].to_list()
        dataset_runtimes.setdefault(method, []).extend(file_runtimes)

    runtimes[dataset_id] = dataset_runtimes
    allAUCs[dataset_id] = dataset_aucs

In [None]:
# Grouped box plots of the ROC AUC: one panel per dataset, one group of boxes
# per enrichment level and, inside a group, one box per method ordered by
# decreasing median AUC.
colormap = matplotlib.colormaps["Paired"]
width, height = matplotlib.rcParams["figure.figsize"]
nplot = len(allAUCs)
# squeeze=False always yields a 2-D array of axes, so iterating over the row
# also works when there is only a single dataset (a bare Axes is not iterable).
fig, axs = plt.subplots(
    ncols=nplot, figsize=(nplot * width, height), squeeze=False
)
axs = axs[0]

# Extract unique methods and map to colors
unique_methods = {
    method
    for AUCs in allAUCs.values()
    for enrichment_aucs in AUCs.values()
    for method in enrichment_aucs.keys()
}
# Wrap around the palette instead of silently dropping methods when there are
# more methods than colors (a dropped method would raise a KeyError below).
palette = colormap.colors
method_colors = {
    method: palette[index % len(palette)]
    for index, method in enumerate(sorted(unique_methods))
}

for ax, (dataset_id, AUCs) in zip(axs, allAUCs.items()):
    sorted_enrichments = sorted(AUCs.keys())

    # x position of the next box to draw
    position = 0

    # Track tick positions and labels for enrichments
    tick_positions = []
    tick_labels = []

    for enrichment in sorted_enrichments:
        enrichment_aucs = AUCs[enrichment]
        # Best method (highest median AUC) first.
        methods = sorted(
            enrichment_aucs.keys(), key=lambda x: -np.median(enrichment_aucs[x])
        )
        data = [enrichment_aucs[method] for method in methods]

        bp = ax.boxplot(
            data,
            positions=range(position, position + len(methods)),
            widths=1,
            patch_artist=True,
        )

        for patch, method in zip(bp["boxes"], methods):
            patch.set_facecolor(method_colors[method])
            patch.set_edgecolor("black")

        # The boxes sit at position .. position + len(methods) - 1, so the
        # centre of the group is the midpoint of those two ends.
        tick_positions.append(position + (len(methods) - 1) / 2)
        tick_labels.append(str(enrichment))

        position += len(methods)

    ax.set_xticks(tick_positions, tick_labels)

    ax.set_xlim(-1, position)

    ax.set_xlabel("Enrichment")
    ax.set_ylabel("AUC")
    ax.set_title(f"dataset = {dataset_id}")

legend_elements = [
    mpatches.Patch(facecolor=color, label=method, edgecolor="black")
    for method, color in method_colors.items()
]

# One shared legend, anchored to the right of the last panel.
axs[-1].legend(handles=legend_elements, loc="center left", bbox_to_anchor=(1, 0.5))

plt.show()

In [None]:
# Box plots of the recorded runtimes: one panel per dataset, one box per
# method, colored with the `method_colors` mapping built for the AUC figure.
width, height = matplotlib.rcParams["figure.figsize"]
nplot = len(runtimes)
# squeeze=False always yields a 2-D array of axes, so iterating over the row
# also works when there is only a single dataset (a bare Axes is not iterable).
fig, axs = plt.subplots(
    ncols=nplot, figsize=(nplot * width, height), squeeze=False
)
axs = axs[0]

for ax, (dataset_id, rt) in zip(axs, runtimes.items()):
    # Sort methods by mean runtime, fastest first
    methods = sorted(rt.keys(), key=lambda x: np.mean(rt[x]))
    data = [rt[method] for method in methods]

    bp = ax.boxplot(data, patch_artist=True)

    # Set the colors for the boxes
    for patch, method in zip(bp["boxes"], methods):
        patch.set_facecolor(method_colors[method])
        patch.set_edgecolor("black")

    # boxplot places the boxes at x = 1..n by default; label them explicitly
    # rather than through the `labels=` keyword, which newer Matplotlib
    # releases deprecate in favour of `tick_labels=`.
    ax.set_xticks(range(1, len(methods) + 1), methods, rotation=90)

    ax.set_xlabel("Methods")
    ax.set_ylabel("Runtime")
    ax.set_title(f"dataset = {dataset_id}")

    ax.set_yscale("log")

# Build the legend from every method that occurs in any dataset; using only
# the artists of the last panel would omit methods missing from that dataset.
runtime_methods = sorted({method for rt in runtimes.values() for method in rt})
legend_elements = [
    mpatches.Patch(facecolor=method_colors[method], label=method, edgecolor="black")
    for method in runtime_methods
]
axs[-1].legend(handles=legend_elements, loc="center left", bbox_to_anchor=(1, 0.5))

plt.show()