In [None]:
"""Workbook to create figures destined for the paper."""
# pylint: disable=import-error, redefined-outer-name, use-dict-literal, too-many-lines

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from __future__ import annotations

import itertools
from collections import defaultdict
from pathlib import Path
from typing import Dict

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from sklearn.metrics import confusion_matrix as sk_cm

from epi_ml.core.confusion_matrix import ConfusionMatrixWriter
from epi_ml.utils.notebooks.paper.paper_utilities import (
    ASSAY,
    CELL_TYPE,
    IHECColorMap,
    MetadataHandler,
    SplitResultsHandler,
    merge_similar_assays,
)

In [None]:
# Root of the paper outputs; the data and figures directories are direct children of it.
base_dir = Path.home().joinpath("Projects/epiclass/output/paper")
paper_dir = base_dir
base_data_dir = base_dir.joinpath("data")
base_fig_dir = base_dir.joinpath("figures")

In [None]:
# Keep the instance under its own name. Rebinding `IHECColorMap` to an instance would
# shadow the imported class, so re-running this cell would call the instance instead
# of the class constructor.
ihec_color_map = IHECColorMap(base_fig_dir)
assay_colors = ihec_color_map.assay_color_map
cell_type_colors = ihec_color_map.cell_type_color_map

In [None]:
# Figures of this notebook go in a dedicated sub-directory of the figures directory.
# Derived from `base_dir` rather than from `base_fig_dir` itself so that re-running this
# cell does not nest "fig1_EpiAtlas_assay" one level deeper each time.
base_fig_dir = base_dir / "figures" / "fig1_EpiAtlas_assay"

In [None]:
# Handler used throughout this notebook to gather, concatenate and score split results.
split_results_handler = SplitResultsHandler()

In [None]:
# Load the "v2" metadata; it is joined to the prediction dataframes further down.
metadata_handler = MetadataHandler(paper_dir)
metadata_v2 = metadata_handler.load_metadata("v2")

### Prepare assay predictions data

In [None]:
# Gather the assay split results stored under "dfreeze_v2", not restricted to the NN
# (only_NN=False).
all_split_dfs = split_results_handler.gather_split_results_across_methods(
    results_dir=base_data_dir / "dfreeze_v2",
    label_category=ASSAY,
    only_NN=False,
)

In [None]:
# Concatenate the split results (one dataframe per classifier), then merge similar assays.
full_dfs = split_results_handler.concatenate_split_results(all_split_dfs)
merged_dfs = {}
for classifier, df in full_dfs.items():
    merged_dfs[classifier] = merge_similar_assays(df)

# Class labels are read from the first classifier's dataframe.
assays = next(iter(merged_dfs.values()))["True class"].unique()

# Add Max pred: highest score among the assay columns, per row
for df in merged_dfs.values():
    df["Max pred"] = df[assays].max(axis=1)

# Join metadata
merged_dfs = {
    classifier: metadata_handler.join_metadata(df, metadata_v2)
    for classifier, df in merged_dfs.items()
}

### Prediction score per assay

Violin plot of the distribution of average prediction scores, per assay.  
One point per UUID, with track types averaged (the 2 RNA and the 2 WGBS assays are combined).

Points use 3 colors:
- black: predicted class is the same as the expected class
- red: predicted class differs from the expected class (possible mislabel)
- orange: bad quality (IHEC flag, was removed in later stages)

In [None]:
def fig1_a(
    NN_results: pd.DataFrame, logdir: Path, name: str, merge_assay_pairs: bool
) -> None:
    """
    Creates a Plotly figure with one violin plot of prediction scores per assay class.

    For each expected class, rows are grouped by EpiRR. Each EpiRR is assigned the
    predicted class that has the most rows (majority vote) and contributes one point:
    the mean "Max pred" of the rows that received that majority class.

    Match/mismatch scatter overlays are not drawn: every point uses the assay color.

    Args:
        NN_results (pd.DataFrame): The DataFrame containing the neural network results, including metadata.
            The "True class", "Predicted class", "EpiRR" and "Max pred" columns are read.
        logdir (Path): The directory where the figure will be saved. Created if missing.
        name (str): The name of the figure (file name without extension).
        merge_assay_pairs (bool): Whether to merge similar assays (mrna/rna, wgbs-pbat/wgbs-standard)
    Returns:
        None: Saves the figure as svg, png and html, then displays it.
    """
    fig = go.Figure()

    y_range = [0.42, 1.01]  # Y-axis range for the plot

    # Combine similar assays
    if merge_assay_pairs:
        NN_results = merge_similar_assays(NN_results)

    # Class ordering: sorted labels, each mapped to an integer x position
    class_labels_sorted = sorted(NN_results["True class"].unique())
    class_index = {label: i for i, label in enumerate(class_labels_sorted)}

    line_color = "white"
    for label in class_labels_sorted:
        df = NN_results[NN_results["True class"] == label]

        # Majority vote, mean prediction score
        groupby_epirr = df.groupby(["EpiRR", "Predicted class"])["Max pred"].aggregate(
            ["size", "mean"]
        )

        # Largest group first within each EpiRR, then keep only that first row
        groupby_epirr = groupby_epirr.reset_index().sort_values(
            ["EpiRR", "size"], ascending=[True, False]
        )
        groupby_epirr = groupby_epirr.drop_duplicates(subset="EpiRR", keep="first")
        assert groupby_epirr["EpiRR"].is_unique

        mean_pred = groupby_epirr["mean"]

        # Add violin plot with integer x positions
        fig.add_trace(
            go.Violin(
                x=[class_index[label]] * len(mean_pred),
                y=mean_pred,
                name=label,
                spanmode="hard",
                box_visible=True,
                meanline_visible=True,
                points="all",
                fillcolor=assay_colors[label],
                line_color=line_color,
                line=dict(width=0.8),
                marker=dict(size=1, color=assay_colors[label]),
                showlegend=False,
            )
        )

    # Update layout to improve visualization
    fig.update_layout(
        title_text="Prediction score distribution per assay class",
        yaxis_title="Average prediction score (majority class)",
        xaxis_title="Expected class label",
    )
    fig.update_yaxes(range=y_range)
    fig.update_xaxes(tickvals=list(class_index.values()), ticktext=class_labels_sorted)

    # Save figure (the output directory may not exist yet)
    logdir.mkdir(parents=True, exist_ok=True)
    fig.write_image(logdir / f"{name}.svg")
    fig.write_image(logdir / f"{name}.png")
    fig.write_html(logdir / f"{name}.html")

    fig.show()

In [None]:
# Results of the "NN" classifier, taken from the merged dataframes built above.
NN_results = merged_dfs["NN"]

In [None]:
# Output directory for the prediction score figure (the call itself is currently disabled).
logdir = base_fig_dir / "fig1--pred_score_per_assay"
# fig1_a(NN_results, logdir=logdir, name="pred_score_per_assay", merge_assay_pairs=False)

### Performance of classification algorithm (boxplot)

Box plots (10 folds) of the overall performance of each model (NN, LR, RF, LGBM, SVM).  
Each box plot summarizes the splits; 4 box plots per model:
  - Acc
  - F1
  - AUROC (OvR, both micro/macro)

Source files:  
epiatlas-dfreeze-v2.1/hg38_100kb_all_none (oversampling)

#### Figure

In [None]:
def plot_split_metrics(
    split_metrics: Dict[str, Dict[str, Dict[str, float]]],
    label_category: str,
    logdir: Path,
    name: str,
) -> None:
    """Render to box plots the metrics per classifier and split, each metric in its own subplot.

    Args:
        split_metrics: Metric scores indexed as split_metrics[split][classifier][metric].
            The classifiers are read from the first split. Every split must provide the
            "Accuracy", "F1_macro", "AUC_micro" and "AUC_macro" values of each classifier.
        label_category: Label category name. Used in the figure title and to choose the
            y-axis ranges.
        logdir: The directory where the figure will be saved. Created if missing.
        name: The name of the figure (file name without extension).

    Raises:
        ValueError: If split_metrics is empty.
    """
    if not split_metrics:
        raise ValueError("split_metrics is empty: there is nothing to plot.")

    metrics = ["Accuracy", "F1_macro", "AUC_micro", "AUC_macro"]
    splits = list(split_metrics)
    classifiers = list(split_metrics[splits[0]].keys())

    # Create subplots: a single row, with one column for each metric
    fig = make_subplots(rows=1, cols=len(metrics), subplot_titles=metrics)

    # Cycle through the palette so that more classifiers than colors does not fail
    palette = px.colors.qualitative.Plotly
    colors = {
        classifier: palette[i % len(palette)] for i, classifier in enumerate(classifiers)
    }

    for col, metric in enumerate(metrics, start=1):
        for classifier in classifiers:
            values = [split_metrics[split][classifier][metric] for split in splits]

            fig.add_trace(
                go.Box(
                    y=values,
                    name=classifier,
                    marker_color=colors[classifier],
                    line=dict(color="black", width=1),
                    marker=dict(size=2),
                    boxmean=True,
                    boxpoints="all",  # or "outliers" to show only outliers
                    pointpos=-1.4,
                    showlegend=False,
                    width=0.5,
                    hoverinfo="text",
                    hovertext=[
                        f"{split}: {value:.4f}" for split, value in zip(splits, values)
                    ],
                ),
                row=1,
                col=col,
            )

    fig.update_layout(
        title_text=f"{label_category} classification - Metric distribution for 10fold cross-validation",
        yaxis_title="Value",
        boxmode="group",
    )

    # Adjust y-axis
    if label_category == ASSAY:
        range_acc = [0.955, 1.001]
        range_AUC = [0.992, 1.0001]
    elif label_category == CELL_TYPE:
        range_acc = [0.81, 1]
        range_AUC = [0.96, 1]
    else:
        range_acc = [0.6, 1.001]
        range_AUC = [0.9, 1.0001]

    # Bind each range to its metric name rather than to a hard-coded axis position
    y_ranges = {
        "Accuracy": range_acc,
        "F1_macro": range_acc,
        "AUC_micro": range_AUC,
        "AUC_macro": range_AUC,
    }
    for col, metric in enumerate(metrics, start=1):
        fig.update_yaxes(range=y_ranges[metric], row=1, col=col)

    # Save figure (the output directory may not exist yet)
    logdir.mkdir(parents=True, exist_ok=True)
    fig.write_image(logdir / f"{name}.svg")
    fig.write_image(logdir / f"{name}.png")
    fig.write_html(logdir / f"{name}.html")

    fig.show()

In [None]:
# Compute the per-split metrics for each label category. After the loop,
# `all_split_dfs` and `split_metrics` hold the values of the last category (CELL_TYPE).
path = base_data_dir / "dfreeze_v2"  # same results directory for every category
for label_category in (ASSAY, CELL_TYPE):
    all_split_dfs = split_results_handler.gather_split_results_across_methods(
        label_category=label_category,
        results_dir=path,
    )
    split_metrics = split_results_handler.compute_split_metrics(all_split_dfs)

    # plot_split_metrics(
    #     split_metrics,
    #     label_category=label_category,
    #     logdir=base_fig_dir,
    #     name=f"{label_category}_10fold_metrics_all_classifiers",
    # )

### Prediction score per assay across classifiers

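For each model, compute the prediction score distribution per assay (one subplot per assay, 1 violin per model). LinearSVC (SVM) and RF are excluded. Points are black when the predicted class agrees with the expected class, red when it disagrees.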

In [None]:
def fig1_supp_B(
    df_dict: Dict[str, pd.DataFrame],
    logdir: Path,
    name: str,
    seed: int | None = 42,
) -> None:
    """
    Creates a Plotly figure with subplots for each assay, each containing violin plots for different classifiers
    and associated scatter plots for matches (in black) and mismatches (in red).

    For each assay and classifier, rows are grouped by EpiRR. Each EpiRR is assigned the
    predicted class that has the most rows (majority vote) and contributes one point:
    the mean "Max pred" of the rows that received that majority class.

    The "LinearSVC" and "RF" entries are ignored. The given dictionary is not modified.

    Args:
        df_dict (Dict[str, pd.DataFrame]): Dictionary with the DataFrame containing the results for each classifier.
            The "True class", "Predicted class", "EpiRR" and "Max pred" columns are read.
        logdir (Path): The directory path for saving the figures. Created if missing.
        name (str): The name for the saved figures (file name without extension).
        seed (int | None): Seed of the random generator used for the horizontal jitter of
            the scatter points. None draws a different jitter on every call.

    Returns:
        None: Saves the figure as svg, png and html, then displays it.

    Raises:
        ValueError: If no classifier is left once LinearSVC and RF are ignored.
    """
    # Ignore LinearSVC and RandomForest for this figure. Filter into a new dict instead
    # of deleting keys, so that the caller's dictionary is left untouched.
    ignored_classifiers = {"LinearSVC", "RF"}
    df_dict = {
        classifier_name: classifier_df
        for classifier_name, classifier_df in df_dict.items()
        if classifier_name not in ignored_classifiers
    }
    if not df_dict:
        raise ValueError("No classifier left to plot once LinearSVC and RF are ignored.")

    # Assuming all classifiers have the same assays for simplicity
    first_df = next(iter(df_dict.values()))
    class_labels_sorted = sorted(first_df["True class"].unique())
    num_assays = len(class_labels_sorted)

    # Each classifier gets an integer x position, in dictionary order
    classifiers = list(df_dict.keys())
    classifier_index = {
        classifier_name: i for i, classifier_name in enumerate(classifiers)
    }
    num_classifiers = len(classifiers)

    scatter_offset = 0.05  # Scatter plot jittering
    rng = np.random.default_rng(seed)

    # Calculate the size of the grid
    grid_size = int(np.ceil(np.sqrt(num_assays)))
    rows, cols = grid_size, grid_size

    # Create subplots with a square grid
    fig = make_subplots(
        rows=rows,
        cols=cols,
        subplot_titles=class_labels_sorted,
        shared_yaxes="all",  # type: ignore
        horizontal_spacing=0.05,
        vertical_spacing=0.05,
        y_title="Average prediction score",
    )
    for idx, label in enumerate(class_labels_sorted):
        row, col = divmod(idx, grid_size)
        for classifier_name, classifier_df in df_dict.items():
            df = classifier_df[classifier_df["True class"] == label]

            # Majority vote, mean prediction score
            groupby_epirr = df.groupby(["EpiRR", "Predicted class"])[
                "Max pred"
            ].aggregate(["size", "mean"])
            # Largest group first within each EpiRR, then keep only that first row
            groupby_epirr = groupby_epirr.reset_index().sort_values(
                ["EpiRR", "size"], ascending=[True, False]
            )
            groupby_epirr = groupby_epirr.drop_duplicates(subset="EpiRR", keep="first")
            assert groupby_epirr["EpiRR"].is_unique

            mean_pred = groupby_epirr["mean"].to_numpy()
            classifier_pos = classifier_index[classifier_name]

            # Add violin plot with integer x positions
            fig.add_trace(
                go.Violin(
                    x=classifier_pos * np.ones(len(mean_pred)),
                    y=mean_pred,
                    name=label,
                    spanmode="hard",
                    box_visible=True,
                    meanline_visible=True,
                    points=False,
                    fillcolor="grey",
                    line_color="black",
                    line=dict(width=0.8),
                    showlegend=False,
                ),
                row=row + 1,  # Plotly rows are 1-indexed
                col=col + 1,
            )

            # Scatter points sit left of the violin (-0.3), with a small random jitter
            jittered_x_positions = (
                rng.uniform(-scatter_offset, scatter_offset, size=len(mean_pred))
                + classifier_pos
                - 0.3
            )
            is_match = (groupby_epirr["Predicted class"] == label).to_numpy()

            # Matches in black, then mismatches in red with a larger size
            for mask, color, size, prefix in (
                (is_match, "black", 1, "Match"),
                (~is_match, "red", 3, "Mismatch"),
            ):
                fig.add_trace(
                    go.Scatter(
                        x=jittered_x_positions[mask],
                        y=mean_pred[mask],
                        mode="markers",
                        marker=dict(color=color, size=size),
                        showlegend=False,
                        name=f"{prefix} {classifier_name}",
                    ),
                    row=row + 1,  # Plotly rows are 1-indexed
                    col=col + 1,
                )

    # Add dummy scatter plots for the legend: black (match) and red (mismatch) points
    for legend_name, color, legendgroup in (
        ("Match", "black", "match"),
        ("Mismatch", "red", "mismatch"),
    ):
        fig.add_trace(
            go.Scatter(
                x=[None],
                y=[None],
                mode="markers",
                name=legend_name,
                marker=dict(color=color, size=10),
                showlegend=True,
                legendgroup=legendgroup,
            )
        )

    # Update the layout to adjust the legend
    fig.update_layout(
        legend=dict(
            title_text="Legend",
            itemsizing="constant",
            orientation="h",
            yanchor="bottom",
            y=1.025,
            xanchor="right",
            x=1,
        )
    )

    # Update layout to improve visualization, adjust if needed for better appearance with multiple classifiers
    fig.update_layout(
        title_text="Prediction score distribution per assay across classifiers",
        height=1500,  # Adjust the height as necessary
        width=1500,  # Adjust the width based on the number of assays
    )

    # Adjust tick names: one tick per classifier, at the x positions used for the
    # violins (0 to num_classifiers - 1), on every subplot.
    fig.update_xaxes(tickvals=list(range(num_classifiers)), ticktext=classifiers)

    # Save figure (the output directory may not exist yet)
    logdir.mkdir(parents=True, exist_ok=True)
    fig.write_image(logdir / f"{name}.svg")
    fig.write_image(logdir / f"{name}.png")
    fig.write_html(logdir / f"{name}.html")

    fig.show()

In [None]:
# Figures are written under the figures directory, like in the other figure cells
# (this path was previously built from the data directory, and then left unused).
logdir = base_fig_dir / "fig1--pred_score_per_assay_across_classifiers"
# fig1_supp_B(merged_dfs, logdir=logdir, name="fig1_supp_B")

### Confusion matrices across classification algorithms

For each classifier type

Confusion matrix (1point=1 uuid) for observed datasets with average scores>0.9
- Goal: Represent global predictions/mislabels. 11c

In [None]:
# TODO fix input problem, file VS epirr incoherence

In [None]:
def create_confusion_matrix(
    df: pd.DataFrame,
    min_pred_score: float,
    logdir: Path,
    name: str,
    majority: bool = False,
) -> None:
    """Create a confusion matrix for the given DataFrame.

    Without majority vote, one row of `df` is one point, kept when its "Max pred" is
    above `min_pred_score`.

    With majority vote, each (uuid, true class) pair is one point. Its predicted class is
    the one that has the most rows, and the pair is kept when the mean "Max pred" of the
    rows that received that class is above `min_pred_score`.

    The given DataFrame is not modified.

    Args:
        df (pd.DataFrame): The DataFrame containing the neural network results.
            The "True class" and "Predicted class" columns are read, plus "uuid" when
            `majority` is set. If "Max pred" is missing, it is computed from the columns
            named after the true classes.
        min_pred_score (float): The minimum prediction score to consider (exclusive).
            Pass float("-inf") to keep every point.
        logdir (Path): The directory path for saving the figures.
        name (str): The name for the saved figures. The number of points in the matrix
            is appended to it as "_n{count}".
        majority (bool): Whether to use majority vote (uuid-wise) for the predicted class.
    """
    classes = sorted(df["True class"].unique())
    if "Max pred" not in df.columns:
        # assign() returns a new frame, so the caller's dataframe is left untouched
        df = df.assign(**{"Max pred": df[classes].max(axis=1)})  # type: ignore

    if majority:
        # Majority vote for predicted class: largest group first within each
        # (uuid, true class) pair, then keep only that first row
        groupby_uuid = (
            df.groupby(["uuid", "True class", "Predicted class"])["Max pred"]
            .aggregate(["size", "mean"])
            .reset_index()
            .sort_values(["uuid", "True class", "size"], ascending=[True, True, False])
            .drop_duplicates(subset=["uuid", "True class"], keep="first")
        )
        # Apply the score threshold to the average score of the majority class.
        # It was previously ignored when using the majority vote.
        filtered_df = groupby_uuid[groupby_uuid["mean"] > min_pred_score]
    else:
        filtered_df = df[df["Max pred"] > min_pred_score]

    # Compute confusion matrix
    confusion_mat = sk_cm(
        filtered_df["True class"], filtered_df["Predicted class"], labels=classes
    )

    mat_writer = ConfusionMatrixWriter(labels=classes, confusion_matrix=confusion_mat)
    mat_writer.to_all_formats(logdir, name=f"{name}_n{len(filtered_df)}")

In [None]:
min_pred_score = 0.9
majority = True

for classifier_name, df in full_dfs.items():
    df_with_meta = metadata_handler.join_metadata(df, metadata_v2)
    assert "Predicted class" in df_with_meta.columns

    # LinearSVC outputs are named without the score threshold
    name = (
        f"{classifier_name}"
        if classifier_name == "LinearSVC"
        else f"{classifier_name}_pred>{min_pred_score}"
    )

    # Per-uuid (majority vote) and per-file matrices go to separate sub-directories
    logdir = (
        base_fig_dir
        / "fig1_supp_F-assay_c11_confusion_matrices"
        / ("per_uuid" if majority else "per_file")
    )

# NOTE(review): the disabled call below sits outside the loop; re-enabled as is, it would
# only run with the values left by the last loop iteration.
# create_confusion_matrix(
#     df=df_with_meta,
#     min_pred_score=min_pred_score,
#     logdir=logdir,
#     name=name,
#     majority=majority,
# )

### Imputed prediction score

Inference on imputed data, with a model trained on observed data: violin plot of the prediction score per assay (like Fig1A)  
VS  
Inference on observed data, with a model trained on imputed data  

In [None]:
fig_dir = base_fig_dir / "fig1--imputed_pred_score"
data_dir = base_data_dir / "imputation"


def _first_csv(directory: Path) -> Path:
    """Return the first CSV file, in sorted path order, found recursively under `directory`.

    Raises:
        FileNotFoundError: If there is no CSV file under `directory`.
    """
    # Sorting makes the choice independent of the filesystem listing order
    csv_paths = sorted(directory.glob("**/*.csv"))
    if not csv_paths:
        raise FileNotFoundError(f"No CSV file found under {directory}")
    return csv_paths[0]


# Load data
normal_inf_imputed_path = _first_csv(
    data_dir / "hg38_100kb_all_none/assay_epiclass_1l_3000n"
)
normal_inf_imputed_df = pd.read_csv(
    normal_inf_imputed_path, header=0, index_col=0, low_memory=False
)

imputed_inf_normal_path = _first_csv(
    data_dir / "hg38_100kb_all_none_imputed/assay_epiclass_1l_3000n"
)
imputed_inf_normal_df = pd.read_csv(
    imputed_inf_normal_path, header=0, index_col=0, low_memory=False
)

# Plot
# The assay labels of the first dataframe select the score columns of both dataframes.
assay_labels = normal_inf_imputed_df["True class"].unique()
for name, df in zip(
    ["train_normal_inf_imputed", "train_imputed_inf_normal"],
    [normal_inf_imputed_df, imputed_inf_normal_df],
):
    # Add the columns read by fig1_a; the dataframe index is used as the EpiRR
    df["EpiRR"] = list(df.index)
    df[ASSAY] = df["True class"]
    df["Max pred"] = df[assay_labels].max(axis=1)
    # fig1_a(
    #     df, logdir=fig_dir, name=f"fig1_supp_D-{name}_n{len(df)}", merge_assay_pairs=False
    # )

### NN - Accuracy per assay (boxplot 10fold)

In [None]:
def fig1_acc_per_assay(
    all_split_dfs: Dict[str, Dict[str, pd.DataFrame]], logdir: Path, name: str
) -> None:
    """
    Creates a Plotly figure with a boxplot for the accuracy of each assay over all splits.

    The accuracy of an assay in a split is the fraction of the rows with that true class
    whose predicted class is equal to it, after merging similar assays.

    Args:
        all_split_dfs (Dict[str, Dict[str, pd.DataFrame]]): Results indexed as
            all_split_dfs[split_name][classifier_name]. Only the "NN" results are used;
            their "True class" and "Predicted class" columns are read.
        logdir (Path): The directory where the figure will be saved. Created if missing.
        name (str): The name of the figure (file name without extension).
    Returns:
        None: Saves the figure as svg, png and html, then displays it.
    """
    fig = go.Figure()

    # Compute accuracy per assay per split, stored as {assay: {split_name: accuracy}}
    assay_accuracies_inv = defaultdict(dict)
    for split_name, classifier_dict in all_split_dfs.items():
        df = merge_similar_assays(classifier_dict["NN"])

        # Fraction of correct predictions per true class. Taking the share of the most
        # frequent predicted class instead would overestimate the accuracy of an assay
        # that is mostly predicted as another class.
        true_class = df["True class"].to_numpy()
        is_correct = pd.Series(true_class == df["Predicted class"].to_numpy())
        assay_acc = is_correct.groupby(true_class).mean()

        for assay, acc in assay_acc.items():
            assay_accuracies_inv[assay][split_name] = acc

    class_labels_sorted = sorted(assay_accuracies_inv.keys())

    line_color = "black"
    for label in class_labels_sorted:
        split_accuracies = assay_accuracies_inv[label]
        hovertext = [
            f"{split_name}: {acc:.4f}" for split_name, acc in split_accuracies.items()
        ]
        fig.add_trace(
            go.Box(
                name=label,
                y=list(split_accuracies.values()),
                boxmean=True,
                boxpoints="all",
                fillcolor=assay_colors[label],
                line_color=line_color,
                marker=dict(size=3, color=assay_colors[label]),
                showlegend=True,
                hoverinfo="text",
                hovertext=hovertext,
            )
        )

    # Update layout to improve visualization
    fig.update_layout(
        title_text="Neural network assay classification 10-fold accuracy",
        yaxis_title="Accuracy",
        xaxis_title="Assay experiment",
    )

    # Save figure (the output directory may not exist yet)
    logdir.mkdir(parents=True, exist_ok=True)
    fig.write_image(logdir / f"{name}.svg")
    fig.write_image(logdir / f"{name}.png")
    fig.write_html(logdir / f"{name}.html")

    fig.show()

In [None]:
# Gather the assay split results again, this time with only_NN=True.
# This replaces the `all_split_dfs` left by the earlier cells.
all_split_dfs = split_results_handler.gather_split_results_across_methods(
    results_dir=base_data_dir / "dfreeze_v2",
    label_category=ASSAY,
    only_NN=True,
)

In [None]:
# Plot the per-assay accuracy of the NN over the splits gathered in the previous cell.
logdir = base_fig_dir / "fig1--acc_per_assay"
name = "fig1--acc_per_assay"
fig1_acc_per_assay(all_split_dfs, logdir=logdir, name=name)