In [None]:
"""Workbook to create figures (fig2) destined for the paper.
"""
# pylint: disable=import-error, redefined-outer-name, use-dict-literal, too-many-lines, unused-import, unused-argument, too-many-branches

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from __future__ import annotations

import copy
import itertools
import logging
import os
import re
from collections import defaultdict
from pathlib import Path
from typing import Callable, Dict, List, Set, Tuple

logging.basicConfig(level=logging.DEBUG)

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from IPython.display import display
from plotly.subplots import make_subplots
from scipy.stats import zscore
from sklearn.metrics import confusion_matrix as sk_cm

from epi_ml.core.confusion_matrix import ConfusionMatrixWriter
from epi_ml.utils.notebooks.paper.paper_utilities import (
    ASSAY,
    ASSAY_MERGE_DICT,
    ASSAY_ORDER,
    CELL_TYPE,
    LIFE_STAGE,
    SEX,
    IHECColorMap,
    MetadataHandler,
    SplitResultsHandler,
    extract_experiment_keys_from_output_files,
    extract_input_sizes_from_output_files,
    extract_node_jobs_from_error_files,
    merge_similar_assays,
)

In [None]:
# Root of the paper outputs; data and figures live in fixed sub-folders of it.
base_dir = Path.home() / "Projects/epiclass/output/paper"
paper_dir = base_dir
base_data_dir, base_fig_dir = base_dir / "data", base_dir / "figures"

# Every figure cell writes under base_fig_dir, so fail early if it is missing.
if not base_fig_dir.exists():
    raise FileNotFoundError(f"Directory {base_fig_dir} does not exist.")

In [None]:
# NOTE: this rebinds the imported class name `IHECColorMap` to an instance (later
# cells may rely on that name). The isinstance guard keeps the cell re-runnable:
# without it, a second run would try to call the instance instead of the class.
if isinstance(IHECColorMap, type):
    IHECColorMap = IHECColorMap(base_fig_dir)
assay_colors = IHECColorMap.assay_color_map
cell_type_colors = IHECColorMap.cell_type_color_map
sex_colors = IHECColorMap.sex_color_map

In [None]:
# Shared helpers: split-result reading/metrics, and metadata loading for paper_dir.
split_results_handler, metadata_handler = SplitResultsHandler(), MetadataHandler(paper_dir)

In [None]:
# v2 metadata loaded with merge_assays=True; only_keep_core_assays reads its ASSAY column.
metadata_v2_df = metadata_handler.load_metadata_df("v2", merge_assays=True)

In [None]:
# Per-run training parameters (comet.ml export); every column is kept as str.
parameters_metadata_path = base_data_dir.joinpath(
    "training_results", "all_results_cometml_filtered_oversampling-fixed.csv"
)
RUN_METADATA = pd.read_csv(parameters_metadata_path, dtype=str)

## Fig 2 - EpiClass results on EpiAtlas other metadata

### Neural network performance across metadata categories

#### Check if oversampling is uniform

In [None]:
def check_for_oversampling(parent_dir: Path, verbose: bool = False):
    """Check for oversampling status in the results, using "output_job*.o" files.

    Experiment keys extracted from the job output files under `parent_dir` are
    matched against the global `RUN_METADATA` table, and the
    "hparams/oversampling" value of every matched run is inspected. Findings
    are printed; nothing is returned.

    Args:
        parent_dir (Path): Parent directory of the results. (classifier type level, e.g. assay_epiclass_1l_3000n)
        verbose (bool): If True, also print how many matched runs belong to each
            experiment (run name with its trailing "splitN" removed).
    """
    # Identify experiments
    exp_keys_dict = extract_experiment_keys_from_output_files(parent_dir)

    # Filter metadata to only include experiments in the results
    all_exp_keys = set()
    for exp_keys in exp_keys_dict.values():
        all_exp_keys.update(exp_keys)

    # .copy() so the new column goes into an independent frame rather than a
    # filtered slice of RUN_METADATA (pandas SettingWithCopy pitfall).
    df = RUN_METADATA[RUN_METADATA["experimentKey"].isin(all_exp_keys)].copy()
    # Group the folds of one experiment under a single name by dropping "splitN".
    df["general_name"] = df["Name"].str.replace(r"[_-]?split\d+$", "", regex=True)
    if verbose:
        print(df["general_name"].value_counts())

    # Check oversampling values, ignore nan
    df_na = df[df["hparams/oversampling"].isna()]
    df = df[df["hparams/oversampling"].notna()]
    if not (df["hparams/oversampling"] == "true").all():
        err_df = df.groupby(["general_name", "hparams/oversampling"]).size()
        print(f"Not all experiments have oversampling:\n{err_df}")

    print(
        f"Checked {len(exp_keys_dict)} folders and found {len(df)} oversampling values."
    )
    if len(df_na) != 0:
        print(
            "Could not read oversampling value of all visited experiments. Values missing in:"
        )
        print(df_na[["general_name"] + [f"run_arg_{i}" for i in range(5)]].value_counts())

In [None]:
# base_data_dir

In [None]:
# check_for_oversampling(base_data_dir / "training_results" / "dfreeze_v2" / "hg38_10kb_all_none" / "assay_epiclass_1l_3000n", verbose=True)

In [None]:
# NOTE(review): other cells build result paths as
# base_data_dir / "training_results" / "dfreeze_v2" / ...; confirm whether the
# "training_results" level is intentionally absent from these two paths.
_dfreeze_100kb_dir = base_data_dir.joinpath("dfreeze_v2", "hg38_100kb_all_none")
assay_dir = _dfreeze_100kb_dir.joinpath("assay_epiclass_1l_3000n", "11c")
ct_dir = _dfreeze_100kb_dir.joinpath("harmonized_sample_ontology_intermediate_1l_3000n")
del _dfreeze_100kb_dir

In [None]:
def create_mislabel_corrector():
    """Build the lookups needed to correct known sex and life_stage mislabels.

    Returns:
        Dict[str, str]: {md5sum: EpiRR_no-v}
        Dict[str, Dict[str, str]]: {label_category: {EpiRR_no-v: corrected_label}}
    """
    epirr_col = "EpiRR_no-v"

    # md5sum -> EpiRR (without version), taken from the v2 metadata datasets
    datasets = MetadataHandler(paper_dir).load_metadata("v2").datasets
    md5sum_to_epirr = (
        pd.DataFrame.from_records(list(datasets)).set_index("md5sum")[epirr_col].to_dict()
    )

    # Per label category: EpiRR -> corrected label, from the official mislabel tables
    mislabels_dir = base_data_dir / "metadata" / "official" / "BadQual-mislabels"
    epirr_to_corrections = {}
    for label_category, csv_name, pred_col in (
        (SEX, "official_Sex_mislabeled.csv", "EpiClass_pred_Sex"),
        (LIFE_STAGE, "official_Life_stage_mislabeled.csv", "EpiClass_pred_Life_stage"),
    ):
        mislabeled_df = pd.read_csv(mislabels_dir / csv_name)
        epirr_to_corrections[label_category] = mislabeled_df.set_index(epirr_col)[
            pred_col
        ].to_dict()

    return md5sum_to_epirr, epirr_to_corrections

In [None]:
def general_split_metrics(
    results_dir: Path,
    merge_assays: bool,
    exclude_categories: List[str] | None = None,
    exclude_names: List[str] | None = None,
) -> Tuple[Dict[str, Dict[str, Dict[str, float]]], Dict[str, Dict[str, pd.DataFrame]]]:
    """Create the content data for figure 2a. (get metrics for each task)

    Currently only using oversampled runs.

    Args:
        results_dir (Path): Directory containing the results. Needs to be parent over category folders.
        merge_assays (bool): Merge similar assays (rna-seq x2, wgbs x2)
        exclude_categories (List[str]): Task categories to exclude (first level directory names).
        exclude_names (List[str]): Names of folders to exclude (ex: 7c or no-mix).

    Returns:
        Dict[str, Dict[str, Dict[str, float]]] A metrics dictionary with the following structure:
            {split_name: {task_name: metrics_dict}}
        Dict[str, Dict[str, pd.DataFrame]] A split results dictionary with the following structure:
            {task_name: {split_name: split_results_df}}
    """
    all_split_results = {}
    split_results_handler = SplitResultsHandler()

    md5sum_to_epirr, epirr_to_corrections = create_mislabel_corrector()

    for parent, _, _ in os.walk(results_dir):
        # Looking for oversampling only results
        parent = Path(parent)
        if parent.name != "10fold-oversampling":
            continue

        # Get the category
        relpath = parent.relative_to(results_dir)
        category = relpath.parts[0].rstrip("_1l_3000n")
        if exclude_categories is not None:
            if any(exclude_str in category for exclude_str in exclude_categories):
                continue

        # Get the rest of the name, ignore certain runs
        rest_of_name = list(relpath.parts[1:])
        rest_of_name.remove("10fold-oversampling")

        if len(rest_of_name) > 1:
            raise ValueError(
                f"Too many parts in the name: {rest_of_name}. Path: {relpath}"
            )
        if rest_of_name:
            rest_of_name = rest_of_name[0]

        if exclude_names is not None:
            if any(name in rest_of_name for name in exclude_names):
                continue

        full_task_name = category
        if rest_of_name:
            full_task_name += f"_{rest_of_name}"

        # Get the split results
        split_results = split_results_handler.read_split_results(parent)
        if not split_results:
            raise ValueError(f"No split results found in {parent}")

        if "sex" in full_task_name or "life_stage" in full_task_name:
            corrections = epirr_to_corrections[category]
            for split_name in split_results:
                split_result_df = split_results[split_name]
                current_true_class = split_result_df["True class"].to_dict()
                new_true_class = {
                    k: corrections.get(md5sum_to_epirr[k], v)
                    for k, v in current_true_class.items()
                }
                split_result_df["True class"] = new_true_class.values()

                split_results[split_name] = split_result_df

        if ("assay" in full_task_name) and ("11c" in full_task_name) and merge_assays:
            for split_name in split_results:
                split_result_df = merge_similar_assays(split_results[split_name])
                split_results[split_name] = split_result_df

        all_split_results[full_task_name] = split_results

    try:
        split_results_metrics = split_results_handler.compute_split_metrics(
            all_split_results, concat_first_level=True
        )
    except KeyError as e:
        logging.error("KeyError: %s", e)
        logging.error("all_split_results: %s", all_split_results)
        logging.error("check folder: %s", results_dir)
        raise e
    return split_results_metrics, all_split_results

In [None]:
# pylint: disable=dangerous-default-value
def fig2_a(
    split_metrics: Dict[str, Dict[str, Dict[str, float]]],
    logdir: Path,
    name: str,
    exclude_categories: List[str] | None = None,
    y_range: List[float] | None = None,
    sort_by_acc: bool = False,
    metric_names: List[str] = ["Accuracy", "F1_macro"],
    show_plot: bool = True,
) -> None:
    """Render box plots of metrics per classifier and split, each in its own subplot.

    This function generates a figure with subplots, each representing a different
    metric. Each subplot contains box plots for each classifier, ordered by accuracy.

    Args:
        split_metrics: A nested dictionary with structure {split: {classifier: {metric: score}}}.
        logdir: The directory path to save the output plots.
        name: The base name for the output plot files.
        exclude_categories: Task categories to exclude from the plot.
        y_range: The y-axis range for the plots.
        sort_by_acc: Whether to sort the classifiers by accuracy.
        metrics: The metrics to include in the plot.
    """
    # Exclude some categories
    classifier_names = list(split_metrics["split0"].keys())
    if exclude_categories is not None:
        for category in exclude_categories:
            classifier_names = [c for c in classifier_names if category not in c]

    available_metrics = list(split_metrics["split0"][classifier_names[0]].keys())
    if any(metric not in available_metrics for metric in metric_names):
        raise ValueError(f"Invalid metric. Metrics need to be in {available_metrics}")

    # Sort classifiers by accuracy
    if sort_by_acc:
        mean_acc = {}
        for classifier in classifier_names:
            mean_acc[classifier] = np.mean(
                [split_metrics[split][classifier]["Accuracy"] for split in split_metrics]
            )
        classifier_names = sorted(
            classifier_names, key=lambda x: mean_acc[x], reverse=True
        )

    # Create subplots, one column for each metric
    fig = make_subplots(
        rows=1,
        cols=len(metric_names),
        subplot_titles=metric_names,
        horizontal_spacing=0.03,
    )

    color_group = px.colors.qualitative.Plotly
    colors = {
        classifier: color_group[i % len(color_group)]
        for i, classifier in enumerate(classifier_names)
    }

    # point_pos = -1.35
    point_pos = 0
    for i, metric in enumerate(metric_names):
        for classifier_name in classifier_names:
            values = [
                split_metrics[split][classifier_name][metric] for split in split_metrics
            ]

            fig.add_trace(
                go.Box(
                    y=values,
                    name=classifier_name,
                    fillcolor=colors[classifier_name],
                    line=dict(color="black", width=1.5),
                    marker=dict(size=3, color="black"),
                    boxmean=True,
                    boxpoints="all",
                    pointpos=point_pos,
                    showlegend=i == 0,  # Only show legend in the first subplot
                    hovertemplate="%{text}",
                    text=[
                        f"{split}: {value:.4f}"
                        for split, value in zip(split_metrics, values)
                    ],
                    legendgroup=classifier_name,
                    width=0.5,
                ),
                row=1,
                col=i + 1,
            )

    fig.update_layout(
        title_text="Neural network classification - Metric distribution for 10-fold cross-validation",
        yaxis_title="Value",
        boxmode="group",
        height=1200 * 0.8,
        width=1750 * 0.8,
    )

    # Acc, F1
    range_acc = [0.86, 1.001]
    fig.update_layout(yaxis=dict(range=range_acc))
    fig.update_layout(yaxis2=dict(range=range_acc))

    # AUC
    range_auc = [0.986, 1.0001]
    fig.update_layout(yaxis3=dict(range=range_auc))
    fig.update_layout(yaxis4=dict(range=range_auc))

    if y_range is not None:
        fig.update_yaxes(range=y_range)

    # Save figure
    fig.write_image(logdir / f"{name}.svg")
    fig.write_image(logdir / f"{name}.png")
    fig.write_html(logdir / f"{name}.html")

    if show_plot:
        fig.show()

#### Compute class imbalance

In [None]:
def compute_class_imbalance(
    all_split_results: Dict[str, Dict[str, pd.DataFrame]]
) -> pd.DataFrame:
    """Compute class imbalance for each task and split.

    Args:
        all_split_results: A dictionary with structure {task_name: {split_name: split_results_df}}.

    Returns:
        pd.DataFrame: A DataFrame with the following columns:
            - avg(balance_ratio): The average balance ratio for each task.
            - n: The number of classes for each task (used for the average).
    """
    # combine md5 lists
    task_md5s = {
        classifier_task: [split_df.index for split_df in split_results.values()]
        for classifier_task, split_results in all_split_results.items()
    }
    task_md5s = {
        classifier_task: [list(split_md5s) for split_md5s in md5s]
        for classifier_task, md5s in task_md5s.items()
    }
    task_md5s = {
        classifier_task: list(itertools.chain(*md5s))
        for classifier_task, md5s in task_md5s.items()
    }

    # get metadata
    metadata_df = metadata_handler.load_metadata_df("v2-encode")

    label_counts = {}
    for classifier_task, md5s in task_md5s.items():
        try:
            label_counts[classifier_task] = metadata_df.loc[md5s][
                classifier_task
            ].value_counts()
        except KeyError as e:
            category_name = classifier_task.rsplit("_", maxsplit=1)[0]
            try:
                label_counts[classifier_task] = metadata_df.loc[md5s][
                    category_name
                ].value_counts()
            except KeyError as e:
                raise e

    # Compute average class ratio vs majority class
    # class_ratios = {}
    # for classifier_task, counts in label_counts.items():
    #     class_ratios[classifier_task] = (np.mean(counts / max(counts)), len(counts))

    # Compute Shannon Entropy
    class_balance = {}
    for classifier_task, counts in label_counts.items():
        total_count = counts.sum()
        k = len(counts)
        p_x = counts / total_count  # class proportions
        p_x = p_x.values
        shannon_entropy = -np.sum(p_x * np.log2(p_x))
        balance = shannon_entropy / np.log2(k)
        class_balance[classifier_task] = (balance, k)

    df_class_balance = pd.DataFrame.from_dict(
        class_balance, orient="index", columns=["Normalized Shannon Entropy", "k"]
    ).sort_index()

    return df_class_balance

In [None]:
# hdf5_type = "hg38_100kb_all_none"
# results_dir = base_data_dir / "training_results" / "dfreeze_v2" / hdf5_type
# if not results_dir.exists():
#     raise FileNotFoundError(f"Directory {results_dir} does not exist.")
# _, all_split_results = general_split_metrics(
# results_dir, exclude_categories=None, exclude_names=exclude_names, merge_assays=True
# )

In [None]:
# fig_logdir = base_fig_dir / "fig2_EpiAtlas_other" / "fig2--NN_perf_across_categories"
# df_class_balance = compute_class_imbalance(all_split_results)
# df_class_balance.to_csv(fig_logdir / "class_balance_Shannon.csv")

#### Graph performance per metadata category

In [None]:
# Tasks/folders left out of the across-categories figures
exclude_categories, exclude_names = ["track_type", "groups"], ["chip-seq", "7c"]

In [None]:
hdf5_type = "hg38_100kb_all_none"
results_dir = base_data_dir.joinpath("training_results", "dfreeze_v2", hdf5_type)
if not results_dir.exists():
    raise FileNotFoundError(f"Directory {results_dir} does not exist.")
split_results_metrics, all_split_results = general_split_metrics(
    results_dir,
    exclude_names=exclude_names,
    exclude_categories=exclude_categories,
    merge_assays=True,
)

In [None]:
fig_logdir = base_fig_dir / "fig2_EpiAtlas_other" / "fig2--NN_perf_across_categories"
fig_logdir.mkdir(parents=False, exist_ok=True)
fig_name = f"{hdf5_type}_perf_across_categories_full_internal"

# Accuracy, F1 and both AUC averages (micro and macro).
metrics = ["Accuracy", "F1_macro", "AUC_micro", "AUC_macro"]
# fig2_a(
#     split_results_metrics,
#     fig_logdir,
#     fig_name,
#     sort_by_acc=True,
#     metric_names=metrics,
#     exclude_categories=None,
#     show_plot=False,
# )

In [None]:
fig_logdir = base_fig_dir.joinpath("fig2_EpiAtlas_other", "fig2--NN_perf_across_categories")
fig_logdir.mkdir(parents=False, exist_ok=True)

# Metric groups and matching y-axis ranges (None means fig2_a does not override its defaults).
metrics_full = ["Accuracy", "F1_macro", "AUC_micro", "AUC_macro"]
metrics_AUC = ["AUC_micro", "AUC_macro"]
metrics_acc_F1 = ["Accuracy", "F1_macro"]
exclude_categories = ["sex_no-mixed", "disease"]
y_range_AUC = [0.986, 1.0001]
y_range_acc = [0.86, 1.001]

for name, metrics, y_range in (
    ("full", metrics_full, None),
    ("acc_F1", metrics_acc_F1, y_range_acc),
    ("AUC", metrics_AUC, y_range_AUC),
):
    fig_name = f"{hdf5_type}_perf_across_categories_{name}"
    # fig2_a(
    #     split_results_metrics,
    #     fig_logdir,
    #     fig_name,
    #     sort_by_acc=True,
    #     metric_names=metrics,
    #     exclude_categories=exclude_categories,
    #     show_plot=False,
    #     y_range=y_range,
    # )

In [None]:
fig_logdir = base_fig_dir.joinpath("fig2_EpiAtlas_other", "fig2--NN_perf_across_categories")
fig_logdir.mkdir(parents=False, exist_ok=True)
metrics_acc_F1 = ["Accuracy", "F1_macro"]
fig_name = f"{hdf5_type}_perf_across_categories_acc_F1"
# Accuracy/F1 figure, shown inline, reusing the exclude_categories defined above.
# fig2_a(
#     split_results_metrics,
#     fig_logdir,
#     fig_name,
#     sort_by_acc=True,
#     metric_names=metrics_acc_F1,
#     exclude_categories=exclude_categories,
#     show_plot=True,
# )

### Neural network performance per assay across metadata categories

In [None]:
def NN_performance_per_assay_across_categories(
    all_split_results: Dict[str, Dict[str, pd.DataFrame]],
    logdir: Path,
    name: str,
    title_end: str = "",
    exclude_categories: List[str] | None = None,
    y_range: None | List[float] = None,
):
    """Create a box plot of assay accuracy for each classifier."""
    all_split_results = copy.deepcopy(all_split_results)

    # Exclude some categories
    classifier_names = list(all_split_results.keys())
    if exclude_categories is not None:
        for category in exclude_categories:
            classifier_names = [c for c in classifier_names if category not in c]

    metadata_df = MetadataHandler(paper_dir).load_metadata_df("v2-encode")

    # One graph per metadata category
    for task_name in classifier_names:
        split_results = all_split_results[task_name]
        if ASSAY in task_name:
            for split_name in split_results:
                split_results[split_name] = merge_similar_assays(
                    split_results[split_name]
                )

        assay_acc_df = split_results_handler.compute_acc_per_assay(
            split_results, metadata_df
        )

        fig = go.Figure()
        for assay in ASSAY_ORDER:
            try:
                assay_accuracies = assay_acc_df[assay]
            except KeyError:
                continue

            fig.add_trace(
                go.Box(
                    y=assay_accuracies.values,
                    name=assay,
                    boxmean=True,
                    boxpoints="all",
                    showlegend=True,
                    marker=dict(size=3, color="black"),
                    line=dict(width=1, color="black"),
                    fillcolor=assay_colors[assay],
                    hovertemplate="%{text}",
                    text=[
                        f"{split}: {value:.4f}"
                        for split, value in assay_accuracies.items()
                    ],
                )
            )

        # if "sample_ontology" in task_name:
        #     yrange = [0.59, 1.001]
        # elif ASSAY in task_name:
        #     yrange = [0.985, 1.001]
        # else:
        yrange = [assay_acc_df.min(), 1.001]  # type: ignore

        if y_range is not None:
            yrange = y_range

        fig.update_yaxes(range=yrange)

        title_text = f"NN classification - {task_name}"
        if title_end:
            title_text += f" - {title_end}"
        fig.update_layout(
            title_text=title_text,
            yaxis_title="Accuracy",
            xaxis_title="Assay",
        )

        # Save figure
        this_name = name + f"_{task_name}"
        fig.write_image(logdir / f"{this_name}.svg")
        fig.write_image(logdir / f"{this_name}.png")
        fig.write_html(logdir / f"{this_name}.html")

        fig.show()

In [None]:
# Tasks/folders left out of the per-assay figures
exclude_categories, exclude_names = (
    ["track_type", "groups", "disease"],
    ["chip-seq", "7c", "no-mixed"],
)

In [None]:
N = 303114
# N = 30321
hdf5_type = f"hg38_regulatory_regions_n{N}"
results_dir = base_data_dir.joinpath("training_results", "dfreeze_v2", hdf5_type)
if not results_dir.exists():
    raise FileNotFoundError(f"Directory {results_dir} does not exist.")

split_results_metrics, all_split_results = general_split_metrics(
    results_dir,
    exclude_names=exclude_names,
    exclude_categories=exclude_categories,
    merge_assays=True,
)

In [None]:
# Display the task names loaded for this hdf5_type.
all_split_results.keys()

In [None]:
logdir = (
    base_fig_dir
    / "fig2_EpiAtlas_other"
    / "fig2--NN_perf_across_categories"
    / "per_assay"
    / hdf5_type
)
# parents=True: the "per_assay/<hdf5_type>" levels may not exist yet.
logdir.mkdir(parents=True, exist_ok=True)
fig_name = "perf_per_assay"

# exclude_categories = None
# exclude_categories = ["groups", "track_type", "harmonized", "project", "paired"] # only assay_epiclass

# NN_performance_per_assay_across_categories(
#     all_split_results, logdir, fig_name, exclude_categories=exclude_categories, y_range=None
# )

### Neural network performance per assay, scatterplot

For each pair of models X and Y, plot the per-assay accuracy of split n of model X against split n of model Y, for every split n.

In [None]:
def pairwise_performance_scatterplot(
    all_split_results: Dict[str, Dict[str, pd.DataFrame]],
    logdir: Path,
    name: str,
    label_category: str,
    verbose: bool = False,
) -> None:
    """
    For the two given classification tasks split results (need to be from same category),
    create a scatter plot of split performance per assay, split for split.

    Args:
        all_split_results: A dictionary with structure {task_name: {split_name: split_results_df}}.
        logdir (Path): The directory path to save the output plots.
        name (str): The base name for the output plot files.
        label_category (str): category used for labels, used for title and axis labels.
        verbose (bool): Print more information.
    """
    all_split_results = copy.deepcopy(all_split_results)
    metadata_df = MetadataHandler(paper_dir).load_metadata_df(
        "v2-encode", merge_assays=True
    )

    for task_name_1, task_name_2 in itertools.combinations(all_split_results.keys(), 2):
        if verbose:
            print(task_name_1, task_name_2)
        split_results_1 = all_split_results[task_name_1]
        split_results_2 = all_split_results[task_name_2]

        if ASSAY in task_name_1:
            for split_name in split_results_1:
                split_results_1[split_name] = merge_similar_assays(
                    split_results_1[split_name]
                )
                split_results_2[split_name] = merge_similar_assays(
                    split_results_2[split_name]
                )

        if split_results_1["split0"].shape != split_results_2["split0"].shape:
            raise ValueError(
                f"Split results for {task_name_1} and {task_name_2} do not have the same shape: {split_results_1['split0'].shape} != {split_results_2['split0'].shape}"
            )
        assay_acc_df_1 = split_results_handler.compute_acc_per_assay(
            split_results_1, metadata_df
        )
        assay_acc_df_2 = split_results_handler.compute_acc_per_assay(
            split_results_2, metadata_df
        )

        fig = go.Figure()
        min_x = 1
        min_y = 1
        for assay in ASSAY_ORDER:
            if verbose:
                print(assay)
            try:
                assay_accuracies_1 = assay_acc_df_1[assay]
                assay_accuracies_2 = assay_acc_df_2[assay]
            except KeyError as e:
                print(e)
                continue

            if verbose:
                print(f"{task_name_1}: {assay_accuracies_1}")
                print(f"{task_name_2}: {assay_accuracies_2}")

            hovertext = [
                f"{split}: ({assay_accuracies_1[split]:.4f},{assay_accuracies_2[split]:.4f})"
                for split in assay_accuracies_1.keys()
            ]

            x_gt_y = sum(assay_accuracies_1 > assay_accuracies_2)
            y_gt_x = sum(assay_accuracies_1 < assay_accuracies_2)
            trace_name = f"{assay} ({y_gt_x},{x_gt_y})"

            fig.add_trace(
                go.Scatter(
                    x=assay_accuracies_1.values,
                    y=assay_accuracies_2.values,
                    mode="markers",
                    name=trace_name,
                    marker=dict(size=5, color=assay_colors[assay]),
                    text=hovertext,
                    hovertemplate="%{text}",
                )
            )

            min_x = min(min_x, *assay_accuracies_1.values)
            min_y = min(min_y, *assay_accuracies_2.values)

        # diagonal line
        fig.add_trace(
            go.Scatter(
                x=[0, 1],
                y=[0, 1],
                mode="lines",
                line=dict(color="black", width=1, dash="dash"),
                showlegend=False,
            )
        )

        range_x = 1 - min_x
        range_y = 1 - min_y
        fig.update_xaxes(range=[min_x - 0.01 * range_x, 1 + 0.01 * range_x])
        fig.update_yaxes(range=[min_y - 0.01 * range_y, 1 + 0.01 * range_y])

        x_name = task_name_1.replace(f"_{label_category}", "")
        y_name = task_name_2.replace(f"_{label_category}", "")
        fig.update_layout(
            title_text=f"Neural network classification - {label_category} - 10-fold cross-validation",
            xaxis_title=f"{x_name} accuracy",
            yaxis_title=f"{y_name} accuracy",
        )

        fig.update_layout(legend_title_text="Assay: (y>x, x>y)")

        # Save figure
        this_name = f"{name}-{label_category}-{x_name}_VS_{y_name}"
        this_name = this_name.replace(ASSAY, "assay")
        this_name = this_name.replace(CELL_TYPE, "sample_ontology")
        fig.write_image(logdir / f"{this_name}.svg")
        fig.write_image(logdir / f"{this_name}.png")
        fig.write_html(logdir / f"{this_name}.html")

        # fig.show()

In [None]:
# Tasks/folders left out of the pairwise scatter plots
exclude_categories, exclude_names = (
    ["track_type", "groups", "disease"],
    ["chip-seq", "7c", "no-mixed"],
)

In [None]:
N1, N2 = 303114, 30321
hdf5_type_reg1, hdf5_type_reg2 = (
    f"hg38_regulatory_regions_n{N1}",
    f"hg38_regulatory_regions_n{N2}",
)
hdf5_type_100kb, hdf5_type_10kb = "hg38_100kb_all_none", "hg38_10kb_all_none"

# {f"{hdf5_type}_{task_name}": {split_name: split_results_df}} over all input types
scatter_fig_results = {}
for hdf5_type in (hdf5_type_reg1, hdf5_type_reg2, hdf5_type_100kb, hdf5_type_10kb):
    results_dir = base_data_dir.joinpath("training_results", "dfreeze_v2", hdf5_type)
    if not results_dir.exists():
        raise FileNotFoundError(f"Directory {results_dir} does not exist.")

    _, all_split_results = general_split_metrics(
        results_dir,
        exclude_names=exclude_names,
        exclude_categories=exclude_categories,
        merge_assays=True,
    )

    # Prefix each task with its input type so tasks of different inputs don't collide.
    scatter_fig_results.update(
        (f"{hdf5_type}_{task_name}", split_results)
        for task_name, split_results in all_split_results.items()
    )

In [None]:
for label_category in [ASSAY, CELL_TYPE]:
    results = {k: v for k, v in scatter_fig_results.items() if label_category in k}
    scatter_logdir = base_fig_dir / "flagship" / "pairwise_scatterplot_acc" / label_category
    # Make sure the output directory exists before any figure is written to it.
    scatter_logdir.mkdir(parents=True, exist_ok=True)
    pairwise_performance_scatterplot(
        results,
        logdir=scatter_logdir,
        name="acc_per_assay",
        label_category=label_category,
        verbose=False,
    )

### Track type effect on NN performance

In [None]:
parent_dir = base_data_dir / "training_results" / "dfreeze_v2" / "hg38_100kb_all_none"
assay_parent_dir = parent_dir / "assay_epiclass_1l_3000n" / "11c"
ct_parent_dir = parent_dir / "harmonized_sample_ontology_intermediate_1l_3000n"

# One entry per training variant sub-folder. Plain "10fold-oversampling" runs (and the
# cell type "chip-seq-only" run) are skipped before reading instead of read then popped;
# non-directory entries are ignored. Sorting makes the dict order, hence the figure
# order, independent of the filesystem listing order.
assay_results = {
    folder.name: split_results_handler.read_split_results(folder)
    for folder in sorted(assay_parent_dir.iterdir())
    if folder.is_dir()
    and "chip" not in folder.name
    and folder.name != "10fold-oversampling"
}
ct_results = {
    folder.name: split_results_handler.read_split_results(folder)
    for folder in sorted(ct_parent_dir.iterdir())
    if folder.is_dir()
    and "l1" not in folder.name
    and folder.name not in ("10fold-oversampling", "10fold-oversampling_chip-seq-only")
}

In [None]:
# Work on a deep copy so assay_results keeps the unmerged labels.
corrected_assay_results = copy.deepcopy(assay_results)
for task_name in corrected_assay_results:
    split_dfs = corrected_assay_results[task_name]
    for split_name in list(split_dfs):
        split_dfs[split_name] = merge_similar_assays(split_dfs[split_name])

In [None]:
# Split metrics of every track-type variant, for the assay and cell type tasks.
assay_metrics, ct_metrics = (
    split_results_handler.compute_split_metrics(results, concat_first_level=True)
    for results in (corrected_assay_results, ct_results)
)

In [None]:
logdir = base_fig_dir.joinpath("fig2_EpiAtlas_other")

name = f"{ASSAY}_global_track_type_effect"
# fig2_a(
#     assay_metrics,
#     logdir,
#     name,
#     exclude_categories=None,
#     y_range=[0.99, 1.0001],
#     sort_by_acc=False,
# )

name = f"{CELL_TYPE}_global_track_type_effect"
# fig2_a(
#     ct_metrics,
#     logdir,
#     name,
#     exclude_categories=None,
#     y_range=[0.91, 1.001],
#     sort_by_acc=False,
# )

In [None]:
# Output base name for the per-assay track type effect figure (call currently disabled).
name = f"{ASSAY}_global_track_type_effect_per_assay"
# NN_performance_per_assay_across_categories(
#     corrected_assay_results, logdir, name, exclude_categories=None, y_range=[0.96, 1.001]
# )

In [None]:
def only_keep_core_assays(
    results_dfs: Dict[str, Dict[str, pd.DataFrame]]
) -> Dict[str, Dict[str, pd.DataFrame]]:
    """Exclude non core-assays from split results. Also exclude input.

    Every split dataframe is filtered on its assay label. When the dataframe
    already has an ASSAY column, that column is used directly. Otherwise the
    assay is looked up in `metadata_v2_df` by joining on the index.
    (Previously, dataframes that already had an ASSAY column were returned
    unfiltered.)

    Args:
        results_dfs: {task_name: {split_name: split_results_df}}. Not modified.

    Returns:
        A deep copy of `results_dfs` where each split dataframe only keeps rows
        whose assay is in `ASSAY_ORDER[0:-3]`.
    """
    # NOTE(review): assumes the last 3 entries of ASSAY_ORDER are the non-core
    # assays (including input) — confirm against paper_utilities.
    accepted_assays = ASSAY_ORDER[0:-3]
    new_results = copy.deepcopy(results_dfs)
    for split_dfs in new_results.values():
        for split_name, df in list(split_dfs.items()):
            if ASSAY in df.columns:
                assay_labels = df[ASSAY]
            else:
                # Left join on index keeps df's rows/index; unknown md5s get NaN
                # and are therefore dropped by isin().
                merged_df = df.merge(
                    metadata_v2_df, how="left", left_index=True, right_index=True
                )
                assay_labels = merged_df[ASSAY]
            split_dfs[split_name] = df[assay_labels.isin(accepted_assays)]
    return new_results

In [None]:
# # Recompute metrics considering only histones
# for result_df, category_name, y_range in zip(
#     [corrected_assay_results, ct_results],
#     [ASSAY, CELL_TYPE],
#     [[0.85, 1.001], [0.91, 1.001]],
# ):
#     print(category_name)
#     name = f"{category_name}_core6c_track_type_effect"

#     core_result_df = only_keep_core_assays(result_df)
#     metrics = split_results_handler.compute_split_metrics(
#         core_result_df, concat_first_level=True
#     )

#     fig2_a(metrics, logdir, name, exclude_categories=None, y_range=y_range)

#     if category_name == ASSAY:
#         name = f"{ASSAY}_core6_track_type_effect_per_assay"
#         NN_performance_per_assay_across_categories(
#             core_result_df, logdir, name, exclude_categories=None, y_range=[0.97, 1.001]
#         )

### Sex chrY z-score distribution vs predictions

Violin plot of the average chrY z-score per sex: black dots for samples predicted as their labeled class, red dots for samples predicted as a different class.

- Draw split male/female violins per assay (FC tracks only; merge the two WGBS assays and the two RNA assays; exclude RNA Unique_raw).
- Use scatter points on each side: points where prediction and label agree use the violin's color; points where they disagree use the other sex's color.
- Point labels: uuid, EpiRR.

Compute chrY coverage z-scores relative to each assay's distribution

In [None]:
def compute_chrY_zscores(version: str):
    """Compute z-scores for chrY coverage data, per assay distribution.

    Excludes raw, pval, and Unique_raw tracks.

    Writes two files to `base_data_dir/chrY_coverage/dfreeze_{version}_stats/`:
      - chrY_coverage_stats.csv: per-assay mean/std/count of chrY coverage.
      - chrY_coverage_zscore_vs_assay.csv: per-file chrY coverage, its z-score
        within its assay ("chrY_zscore_vs_assay") and the assay label.

    Args:
        version: Metadata version to load (e.g. "v2").

    Raises:
        FileNotFoundError: If the chrY coverage directory does not exist.
        AssertionError: If any kept file has a chrY coverage of exactly 0.
    """
    # Get chrY coverage data
    chrY_coverage_dir = base_data_dir / "chrY_coverage"
    if not chrY_coverage_dir.exists():
        raise FileNotFoundError(f"Directory {chrY_coverage_dir} does not exist.")
    chrY_coverage_df = pd.read_csv(chrY_coverage_dir / "chrXY_coverage_all.csv", header=0)

    # Filter out md5s not in metadata version
    metadata = MetadataHandler(paper_dir).load_metadata(version)
    md5s = set(metadata.md5s)
    chrY_coverage_df = chrY_coverage_df[chrY_coverage_df["filename"].isin(md5s)]

    # Make sure all values are non-zero
    assert (chrY_coverage_df["chrY"] != 0).all()

    # These tracks are excluded from z-score computation
    metadata.remove_category_subsets("track_type", ["raw", "pval", "Unique_raw"])
    metadata_df = pd.DataFrame.from_records(list(metadata.datasets))
    # Plain reassignment: a chained `df[col].replace(..., inplace=True)` does not
    # write back to the frame under pandas Copy-on-Write.
    metadata_df[ASSAY] = metadata_df[ASSAY].replace(ASSAY_MERGE_DICT)

    # Inner merge with metadata: files of the excluded track types drop out here
    chrY_coverage_df = chrY_coverage_df.merge(
        metadata_df[["md5sum", ASSAY]], left_on="filename", right_on="md5sum"
    )

    # Per-assay summary statistics (pandas "std" uses ddof=1)
    chrY_dists = chrY_coverage_df.groupby(ASSAY).agg({"chrY": ["mean", "std", "count"]})

    output_dir = chrY_coverage_dir / f"dfreeze_{version}_stats"
    output_dir.mkdir(parents=False, exist_ok=True)
    chrY_dists.to_csv(output_dir / "chrY_coverage_stats.csv")

    # z-score within each assay group (scipy.stats.zscore uses ddof=0).
    # transform() keeps the original row alignment, avoiding writes into
    # groupby slices.
    metric_name = "chrY_zscore_vs_assay"
    chrY_coverage_df[metric_name] = chrY_coverage_df.groupby(ASSAY)["chrY"].transform(
        zscore
    )

    output_cols = ["filename", "chrY", metric_name, ASSAY]
    chrY_coverage_df[output_cols].to_csv(
        output_dir / "chrY_coverage_zscore_vs_assay.csv", index=False
    )

In [None]:
# Writes the chrY coverage stats and per-assay z-score CSVs for data freeze v2.
compute_chrY_zscores("v2")

Plot z-scores according to sex

main Fig: chrY per EpiRR (excluding WGBS): only boxplot with all points

In [None]:
# Column holding the per-assay chrY z-score (same name as written by compute_chrY_zscores).
metric_label = "chrY_zscore_vs_assay"

In [None]:
def prepare_fig_2B_data(version: str, prediction_data_dir: Path) -> pd.DataFrame:
    """Assemble the per-file table used by the figure 2B plots.

    Joins the following, by file md5sum:
      - the md5sum / EpiRR / sex metadata columns of `version`,
      - the chrY z-scores written by `compute_chrY_zscores(version)`,
      - the sex-classifier split predictions read from `prediction_data_dir`.
    Then adds a "Max pred" column: the row-wise max of the "female", "male" and
    "mixed" prediction columns.

    Args:
        version: Metadata / data-freeze version (e.g. "v2").
        prediction_data_dir: Folder with the sex classification split results.

    Returns:
        The merged DataFrame, indexed by md5sum.
    """
    metadata = MetadataHandler(paper_dir).load_metadata(version)
    metadata_df = pd.DataFrame.from_records(list(metadata.datasets))[
        ["md5sum", "EpiRR", SEX]
    ]

    stats_dir = base_data_dir / "chrY_coverage" / f"dfreeze_{version}_stats"
    chrY_zscores = pd.read_csv(stats_dir / "chrY_coverage_zscore_vs_assay.csv", header=0)

    sex_split_results = split_results_handler.read_split_results(prediction_data_dir)
    sex_predictions = split_results_handler.concatenate_split_results(
        {"sex": sex_split_results}, concat_first_level=True
    )["sex"]

    merged_df = chrY_zscores.merge(
        metadata_df, left_on="filename", right_on="md5sum"
    ).merge(sex_predictions, left_on="filename", right_index=True)
    merged_df["Max pred"] = merged_df[["female", "male", "mixed"]].max(axis=1)
    return merged_df.set_index("md5sum")

In [None]:
def fig2_B(zscore_df: pd.DataFrame, logdir: Path, name: str) -> None:
    """Create figure 2B: per-assay chrY z-score box plots, split by sex.

    There is one subplot per assay, sorted by name and titled with its file
    count. Each subplot holds one box for files whose "True class" is "female"
    and one for "male", with every point drawn. The figure is saved as svg,
    png and html, then shown.

    Args:
        zscore_df: The dataframe with z-score data. Must contain the ASSAY,
            `metric_label`, "EpiRR", "Max pred" and "True class" columns.
        logdir: Directory where the output files are written.
        name: Base name of the output files (without extension).
    """
    assay_sizes = zscore_df[ASSAY].value_counts()
    assays = sorted(assay_sizes.index)

    fig = make_subplots(
        rows=1,
        cols=len(assays),
        shared_yaxes=True,
        x_title="Assay+Sex z-score distributions",
        y_title="z-score",
        horizontal_spacing=0.02,
        subplot_titles=[
            f"{assay_label} ({assay_sizes[assay_label]})" for assay_label in assays
        ],
    )

    for i, assay_label in enumerate(assays):
        sub_df = zscore_df[zscore_df[ASSAY] == assay_label]

        y_values = sub_df[metric_label]
        hovertext = np.array(
            [
                f"{epirr}: z-score={z_score:.3f}, pred={pred:.3f}"
                for epirr, pred, z_score in zip(
                    sub_df["EpiRR"],
                    sub_df["Max pred"],
                    sub_df[metric_label],
                )
            ]
        )

        for sex_label in ["female", "male"]:
            # Select with a boolean mask. The index holds md5sum strings, so the
            # previous integer-position indexing of the Series relied on pandas'
            # deprecated positional fallback in Series.__getitem__.
            sex_mask = (sub_df["True class"] == sex_label).to_numpy()
            fig.add_trace(
                go.Box(
                    name=assay_label,
                    y=y_values[sex_mask],
                    boxmean=True,
                    boxpoints="all",
                    hovertemplate="%{text}",
                    text=hovertext[sex_mask],
                    marker=dict(
                        size=2,
                        color=sex_colors[sex_label],
                        line=dict(width=0.5, color="black"),
                    ),
                    fillcolor=sex_colors[sex_label],
                    line=dict(width=1, color="black"),
                    showlegend=False,
                    legendgroup=sex_label.capitalize(),
                ),
                row=1,
                col=i + 1,
            )

    # Empty scatter traces whose only purpose is to produce the sex legend
    for sex_label in ["female", "male"]:
        fig.add_trace(
            go.Scatter(
                x=[None],
                y=[None],
                mode="markers",
                name=sex_label.capitalize(),
                marker=dict(color=sex_colors[sex_label], size=20),
                showlegend=True,
                legendgroup=sex_label.capitalize(),
            )
        )

    fig.update_xaxes(showticklabels=False)
    fig.update_yaxes(range=[-1.5, 3])
    title = "z-score(mean chrY coverage per file) distribution per assay"
    fig.update_layout(
        title_text=title,
        width=3000,
        height=1000,
    )

    # Save figure
    fig.write_image(logdir / f"{name}.svg")
    fig.write_image(logdir / f"{name}.png")
    fig.write_html(logdir / f"{name}.html")

    fig.show()

In [None]:
# Output folder for the chrY z-score figures. With parents=False, mkdir raises
# FileNotFoundError if "fig2_EpiAtlas_other" does not already exist.
logdir = base_fig_dir / "fig2_EpiAtlas_other" / "fig2--sex_chrY_zscore"
logdir.mkdir(parents=False, exist_ok=True)
name = "fig2--sex_chrY_zscore_only_box"

In [None]:
# Sex classifier (with "mixed" class) predictions used to annotate chrY z-scores.
version = "v2"
pred_data_dir = (
    base_data_dir
    / "training_results"
    / f"dfreeze_{version}"
    / "hg38_100kb_all_none"
    / f"{SEX}_1l_3000n"
    / "w-mixed"
    / "10fold-oversampling"
)
# A Path object is always truthy, so `if not pred_data_dir` could never fire:
# check the filesystem explicitly.
if not pred_data_dir.is_dir():
    raise FileNotFoundError(f"Directory {pred_data_dir} does not exist.")
zscore_df = prepare_fig_2B_data(version, pred_data_dir)

In [None]:
# fig2_B(zscore_df, logdir, name)

Plot z-scores according to sex, merging all assays except WGBS (box plots; 1 point = 1 EpiRR)

In [None]:
def fig2_B_merged_assays(
    zscore_df: pd.DataFrame,
    sex_mislabels: Dict[str, str],
    logdir: Path,
    name: str,
    min_pred: float | None = None,
) -> None:
    """Create figure 2B (merged assays): EpiRR-averaged chrY z-scores per sex.

    Similar assays are merged (ASSAY_MERGE_DICT) and WGBS is dropped. The chrY
    z-score and "Max pred" are then averaged per (EpiRR, sex), giving one point
    per EpiRR in a female box and a male box. Known binary sex mislabels are
    overlaid on their box as larger points drawn in the opposite sex's color.

    Args:
        zscore_df (pd.DataFrame): The dataframe with z-score data.
        sex_mislabels (Dict[str, str]): {EpiRR_no-v: corrected_sex_label}
        logdir (Path): The directory path to save the output plots.
        name (str): The base name for the output plot files. The number of
            plotted EpiRRs is appended as "_n{count}".
        min_pred (float|None): Minimum prediction value to include in the plot.
            Used on average EpiRR 'Max pred' values (strictly greater than).

    Raises:
        ValueError: If an EpiRR is associated with more than one sex label.
    """
    zscore_df = zscore_df.copy(deep=True)
    # Reassign instead of a chained `replace(..., inplace=True)`, which does not
    # write back to the frame under pandas Copy-on-Write.
    zscore_df[ASSAY] = zscore_df[ASSAY].replace(ASSAY_MERGE_DICT)

    # wgbs reverses male/female chrY tendency, so removed here
    zscore_df = zscore_df[zscore_df[ASSAY] != "wgbs"]

    # Average chrY z-score values
    mean_chrY_values_df = zscore_df.groupby(["EpiRR", SEX]).agg(
        {metric_label: "mean", "Max pred": "mean"}
    )
    mean_chrY_values_df.reset_index(inplace=True)
    if not mean_chrY_values_df["EpiRR"].is_unique:
        raise ValueError("EpiRR is not unique.")

    # Filter out low prediction values
    if min_pred is not None:
        mean_chrY_values_df = mean_chrY_values_df[
            mean_chrY_values_df["Max pred"] > min_pred
        ]

    # Fresh RangeIndex: the positional indices computed below then match row labels.
    mean_chrY_values_df = mean_chrY_values_df.reset_index(drop=True)
    chrY_values = mean_chrY_values_df[metric_label]
    sex_labels = mean_chrY_values_df[SEX]
    female_idx = np.flatnonzero((sex_labels == "female").to_numpy())
    male_idx = np.flatnonzero((sex_labels == "male").to_numpy())

    # Mislabels: keep those whose corrected label is binary, restricted to EpiRRs
    # currently labelled female/male (the only two boxes drawn).
    binary_mislabels = {
        epirr_id
        for epirr_id, label in sex_mislabels.items()
        if label in ["male", "female"]
    }
    # Strip a trailing ".<version>". The dot is escaped so unversioned IDs are
    # kept whole instead of being truncated by an any-character match.
    epirr_no_v = mean_chrY_values_df["EpiRR"].str.replace(r"\.\d+$", "", regex=True)
    is_mislabel = epirr_no_v.isin(binary_mislabels) & sex_labels.isin(
        ["female", "male"]
    )
    mislabels_idx = np.flatnonzero(is_mislabel.to_numpy())
    mislabel_sexes = sex_labels.iloc[mislabels_idx].tolist()

    mislabel_color_dict = {"female": sex_colors["male"], "male": sex_colors["female"]}
    mislabel_colors = [mislabel_color_dict[sex] for sex in mislabel_sexes]
    # Put each mislabel on the box of its current label (box names are the
    # capitalized sex labels) rather than at a fixed numeric x position.
    mislabel_x = [sex.capitalize() for sex in mislabel_sexes]

    # Hovertext
    hovertext = np.array(
        [
            f"{epirr}: <z-score>={z_score:.3f}"
            for epirr, z_score in zip(
                mean_chrY_values_df["EpiRR"],
                mean_chrY_values_df[metric_label],
            )
        ]
    )

    # Create figure
    fig = go.Figure()
    fig.add_trace(
        go.Box(
            name="Female",
            y=chrY_values.iloc[female_idx],
            boxmean=True,
            boxpoints="all",
            pointpos=0,
            hovertemplate="%{text}",
            text=hovertext[female_idx],
            marker=dict(size=1, color="black"),
            line=dict(width=1, color="black"),
            fillcolor=sex_colors["female"],
        ),
    )

    fig.add_trace(
        go.Box(
            name="Male",
            y=chrY_values.iloc[male_idx],
            boxmean=True,
            boxpoints="all",
            pointpos=0,
            hovertemplate="%{text}",
            text=hovertext[male_idx],
            marker=dict(size=1, color="black"),
            line=dict(width=1, color="black"),
            fillcolor=sex_colors["male"],
        ),
    )

    fig.add_trace(
        go.Scatter(
            name="Mislabel",
            x=mislabel_x,
            y=chrY_values.iloc[mislabels_idx],
            mode="markers",
            marker=dict(size=4, color=mislabel_colors, line=dict(width=1, color="black")),
            showlegend=False,
            hovertemplate="%{text}",
            text=hovertext[mislabels_idx],
        ),
    )

    fig.update_yaxes(range=[-1.5, 3])
    title = "z-score(mean chrY coverage per file) distribution - z-scores averaged over assays"
    if min_pred is not None:
        title += f"<br>avg_maxPred>{min_pred}"

    fig.update_layout(
        title=dict(text=title, x=0.5),
        xaxis_title=SEX,
        yaxis_title="Average z-score",
        width=750,
        height=750,
    )

    # Save figure
    this_name = f"{name}_n{mean_chrY_values_df.shape[0]}"
    fig.write_image(logdir / f"{this_name}.svg")
    fig.write_image(logdir / f"{this_name}.png")
    fig.write_html(logdir / f"{this_name}.html")

    fig.show()

In [None]:
# create_mislabel_corrector is defined elsewhere in the notebook (not in this section).
# sex_mislabels is passed to fig2_B_merged_assays, which documents it as
# {EpiRR_no-v: corrected_sex_label}.
_, epirr_mislabels = create_mislabel_corrector()
sex_mislabels = epirr_mislabels[SEX]

In [None]:
# Optional minimum average "Max pred" per EpiRR (None = keep all EpiRRs).
# NOTE(review): when min_pred is set, the file name contains ">", which is not a
# valid file name character on Windows.
min_pred = None
name = "fig2--sex_chrY_zscore_merged_assays"
if min_pred is not None:
    name = f"fig2--sex_chrY_zscore_merged_assays_avg_maxPred>{min_pred}"

logdir = base_fig_dir / "fig2_EpiAtlas_other" / "fig2--sex_chrY_zscore"
# fig2_B_merged_assays(zscore_df, sex_mislabels, logdir, name, min_pred=min_pred)

In [None]:
def merged_assays_separation_distance(
    zscore_df: pd.DataFrame, logdir: Path, name: str
) -> None:
    """Complement to figure 2B, showing separation distance (mean, median)
    between male/female zscore clusters.

    Uses EpiRR-level averages: similar assays are merged, WGBS is dropped, and
    the chrY z-score and "Max pred" are averaged per (EpiRR, sex). For each
    threshold t in [0, 0.99] (step 0.01), EpiRRs with average "Max pred" > t are
    kept. The absolute difference between the female and male mean (and median)
    z-scores is then plotted; it is NaN when either group is empty. The number
    of EpiRRs kept is shown on a secondary y-axis.

    Args:
        zscore_df (pd.DataFrame): The dataframe with z-score data.
        logdir (Path): The directory path to save the output plots.
        name (str): The base name for the output plot files.

    Raises:
        ValueError: If an EpiRR is associated with more than one sex label.
    """
    zscore_df = zscore_df.copy(deep=True)
    # Reassign instead of a chained `replace(..., inplace=True)`, which does not
    # write back to the frame under pandas Copy-on-Write.
    zscore_df[ASSAY] = zscore_df[ASSAY].replace(ASSAY_MERGE_DICT)

    # wgbs reverses male/female chrY tendency, so removed here
    zscore_df = zscore_df[zscore_df[ASSAY] != "wgbs"]

    # Average chrY z-score values: one row per EpiRR
    mean_chrY_values_df = (
        zscore_df.groupby(["EpiRR", SEX])
        .agg({metric_label: "mean", "Max pred": "mean"})
        .reset_index()
    )
    if not mean_chrY_values_df["EpiRR"].is_unique:
        raise ValueError("EpiRR is not unique.")

    distances = {"mean": [], "median": []}
    min_preds = np.arange(0, 1.0, 0.01)
    sample_count = []
    for min_pred in min_preds:
        subset_df = mean_chrY_values_df[mean_chrY_values_df["Max pred"] > min_pred]
        # Rows are EpiRRs (not individual files)
        sample_count.append(subset_df.shape[0])

        # Compute separation distances
        chrY_vals_female = subset_df.loc[subset_df[SEX] == "female", metric_label]
        chrY_vals_male = subset_df.loc[subset_df[SEX] == "male", metric_label]

        if chrY_vals_female.empty or chrY_vals_male.empty:
            distances["mean"].append(np.nan)
            distances["median"].append(np.nan)
            continue

        distances["mean"].append(np.abs(chrY_vals_female.mean() - chrY_vals_male.mean()))
        distances["median"].append(
            np.abs(chrY_vals_female.median() - chrY_vals_male.median())
        )

    # Plotting the results
    fig = go.Figure()

    # Add traces for mean and median distances
    fig.add_trace(
        go.Scatter(
            x=min_preds,
            y=distances["mean"],
            mode="lines+markers",
            name="Mean Distance (left)",
            line=dict(color="blue"),
        )
    )
    fig.add_trace(
        go.Scatter(
            x=min_preds,
            y=distances["median"],
            mode="lines+markers",
            name="Median Distance (left)",
            line=dict(color="green"),
        )
    )

    # Add trace for the number of EpiRRs kept at each threshold
    fig.add_trace(
        go.Scatter(
            x=min_preds,
            y=sample_count,
            mode="lines+markers",
            name="Number of EpiRRs (right)",
            line=dict(color="red"),
            yaxis="y2",
        )
    )

    # Update layout for secondary y-axis
    fig.update_layout(
        title="Separation Distance of chrY z-scores male/female clusters",
        xaxis_title="Average Prediction Score minimum threshold",
        yaxis_title="Z-score Distance",
        yaxis2=dict(title="Number of EpiRRs", overlaying="y", side="right"),
        legend=dict(
            x=1.08,
        ),
    )
    # Save the plot
    fig.write_image(logdir / f"{name}.svg")
    fig.write_image(logdir / f"{name}.png")
    fig.write_html(logdir / f"{name}.html")

    fig.show()

In [None]:
# Male/female chrY z-score cluster separation vs. minimum average prediction score.
name = "sex_chrY_zscore_merged_assays_distance"
logdir = base_fig_dir / "fig2_EpiAtlas_other" / "fig2--sex_chrY_zscore"
merged_assays_separation_distance(zscore_df, logdir, name)

### Sex: prediction score

In [None]:
def pred_score_violin(
    results_df: pd.DataFrame,
    logdir: Path,
    name: str,
    min_y: float | None = None,
    seed: int | None = 42,
    save: bool = False,
) -> None:
    """
    Creates a Plotly figure with violin plots and associated scatter plots for each class.
    Red scatter points, indicating a mismatch, appear on top and have a larger size.

    Files are first reduced to one point per EpiRR within each expected class.
    The EpiRR keeps its majority predicted class and the mean "Max pred" of the
    files predicted as that class. Mismatching EpiRRs with a mean score > 0.9
    are displayed as a table.

    Args:
        results_df (pd.DataFrame): The DataFrame containing the neural network results, including metadata.
            Needs "True class", "Predicted class", "EpiRR" and "Max pred" columns.
        logdir (Path): The directory where the figure is saved when `save` is True.
        name (str): The name of the figure.
        min_y (float|None): Lower y-axis bound. Defaults to the minimum "Max pred".
        seed (int|None): Seed of the horizontal jitter applied to matching points,
            making the figure reproducible. None gives non-deterministic jitter.
        save (bool): If True, write html/svg/png files to `logdir`.
    Returns:
        None: Displays the plotly figure.
    """
    rng = np.random.default_rng(seed)
    fig = go.Figure()

    # Class ordering: every label seen as either an expected or a predicted class
    class_labels = (
        results_df["True class"].unique().tolist()
        + results_df["Predicted class"].unique().tolist()
    )
    class_labels_sorted = sorted(set(class_labels))
    class_index = {label: i for i, label in enumerate(class_labels_sorted)}

    for label in class_labels_sorted:
        x_pos = class_index[label]
        df = results_df[results_df["True class"] == label]

        # Majority vote, mean prediction score
        groupby_epirr = df.groupby(["EpiRR", "Predicted class"])["Max pred"].aggregate(
            ["size", "mean"]
        )
        groupby_epirr = groupby_epirr.reset_index().sort_values(
            ["EpiRR", "size"], ascending=[True, False]
        )
        groupby_epirr = groupby_epirr.drop_duplicates(subset="EpiRR", keep="first")
        assert groupby_epirr["EpiRR"].is_unique

        fig.add_trace(
            go.Violin(
                x=[x_pos] * len(groupby_epirr),
                y=groupby_epirr["mean"],
                name=label,
                spanmode="hard",
                box_visible=True,
                meanline_visible=True,
                points=False,
                fillcolor=sex_colors[label],
                line_color="black",
                showlegend=False,
            )
        )

        # Split EpiRRs by whether their majority prediction matches the label
        is_match = (groupby_epirr["Predicted class"] == label).to_numpy()
        match_df = groupby_epirr[is_match]
        mismatch_df = groupby_epirr[~is_match]

        # Small horizontal jitter so overlapping matching points stay visible
        jitter_match = rng.uniform(-0.01, 0.01, len(match_df))

        # Add scatter plots for matches in black
        fig.add_trace(
            go.Scatter(
                x=x_pos + jitter_match,
                y=match_df["mean"],
                mode="markers",
                name=label,
                marker=dict(
                    color="black",
                    size=2,  # Standard size for matches
                ),
                hovertemplate="%{text}",
                text=[
                    f"EpiRR: {epirr}, Pred class: {pred_class}, Mean pred: {mean:.2f}"
                    for epirr, pred_class, mean in zip(
                        match_df["EpiRR"], match_df["Predicted class"], match_df["mean"]
                    )
                ],
                showlegend=False,
                legendgroup="match",
            )
        )

        # Add scatter plots for mismatches in red, with larger size
        strong_mismatch = mismatch_df[mismatch_df["mean"] > 0.9]
        display(strong_mismatch)
        fig.add_trace(
            go.Scatter(
                x=[x_pos] * len(mismatch_df),
                y=mismatch_df["mean"],
                mode="markers",
                name=label,
                marker=dict(
                    color="red",
                    size=5,  # Larger size for mismatches
                ),
                hovertemplate="%{text}",
                text=[
                    f"EpiRR: {epirr}, Pred class: {pred_class}, Mean pred: {mean:.3f} (n={size})"
                    for epirr, pred_class, mean, size in zip(
                        mismatch_df["EpiRR"],
                        mismatch_df["Predicted class"],
                        mismatch_df["mean"],
                        mismatch_df["size"],
                    )
                ],
                showlegend=False,
                legendgroup="mismatch",
            )
        )

    # Empty scatter traces whose only purpose is the match/mismatch legend
    for legend_name, legend_color, group in [
        ("Match", "black", "match"),
        ("Mismatch", "red", "mismatch"),
    ]:
        fig.add_trace(
            go.Scatter(
                x=[None],
                y=[None],
                mode="markers",
                name=legend_name,
                marker=dict(color=legend_color, size=10),
                showlegend=True,
                legendgroup=group,
            )
        )

    tickvals = list(class_index.values())
    ticktext = list(class_index.keys())
    fig.update_xaxes(tickvals=tickvals, ticktext=ticktext)

    if min_y is None:
        min_y = results_df["Max pred"].min()

    fig.update_yaxes(range=[min_y, 1.001])

    fig.update_layout(
        title_text="Prediction score distribution class",
        yaxis_title="Avg. prediction score (majority class)",
        xaxis_title="Expected class label",
        width=750,
    )

    fig.update_layout(
        legend=dict(
            title_text="Legend",
            itemsizing="constant",
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1,
        )
    )

    # Save figure
    if save:
        fig.write_html(logdir / f"{name}.html")
        fig.write_image(logdir / f"{name}.svg")
        fig.write_image(logdir / f"{name}.png")

    fig.show()

In [None]:
# sex_df = concatenated_results["harmonized_donor_sex_w-mixed"]
# sex_df = split_results_handler.add_max_pred(sex_df)
# sex_df = metadata_handler.join_metadata(sex_df, metadata_v2)

In [None]:
# Plotting call disabled: `sex_df` is only built in the commented-out cell above.
name = "fig2--sex_pred_score_post_correction"
logdir = base_fig_dir / "fig2_EpiAtlas_other" / "fig2--sex_pred_score"
# pred_score_violin(sex_df, logdir, name, min_y=0.485)

### Reduced features sets NN metrics

Download of the regulatory-regions NN training data

~~~bash
# Download phase
paper_dir="${HOME}/Projects/epiclass/output/paper/data"
cd ${paper_dir}/training_results/dfreeze_v2

base_path="/lustre07/scratch/rabyj/epilap-logs/epiatlas-dfreeze-v2.1"
rsync --info=progress2 -av --exclude "*/EpiLaP/" --exclude "*.png" --exclude "validation_confusion*" --exclude "*.md5" --exclude "full*" narval:${base_path}/hg38_regulatory_regions_n* .

# Cleanup phase
# Remove files related to failed experiments
# Step 1: Find the error files of failed jobs and extract their job IDs
find hg38_regulatory_regions_n* -type f -name "*.e" -exec grep -l 'has non-string label of type' {} + | \
grep -oE "job[0-9]+" | grep -oE "[0-9]+" > failure_jobid.txt

# Step 2: Find and delete every file whose name contains an extracted job ID
cat failure_jobid.txt | xargs -I% sh -c 'find . -type f -name "*%*" -delete'
rm failure_jobid.txt
~~~

In [None]:
def obtain_all_feature_set_metrics(
    parent_folder: Path,
    merge_assays: bool,
) -> Dict[str, Dict[str, Dict[str, Dict[str, float]]]]:
    """Collect split metrics for every feature-set folder under `parent_folder`.

    Each direct sub-directory is treated as one feature set; plain files are
    ignored. Keys follow the order returned by `Path.iterdir()`.

    Args:
        parent_folder (Path): Folder whose sub-directories are the feature sets.
        merge_assays (bool): Forwarded to `general_split_metrics`.

    Returns:
        Dict[str, Dict[str, Dict[str, Dict[str, float]]]]:
            {feature_set: {task_name: {split_name: metric_dict}}}

    Raises:
        ValueError: If computing the metrics of a feature set raises a
            ValueError; the feature-set name is added to the message.
    """
    feature_set_dirs = (path for path in parent_folder.iterdir() if path.is_dir())

    all_metrics: Dict[str, Dict[str, Dict[str, Dict[str, float]]]] = {}
    for feature_set_dir in feature_set_dirs:
        try:
            split_metrics, _ = general_split_metrics(
                feature_set_dir, merge_assays=merge_assays
            )
        except ValueError as err:
            raise ValueError(f"Problem with {feature_set_dir.name}") from err
        all_metrics[feature_set_dir.name] = split_results_handler.invert_metrics_dict(
            split_metrics
        )
    return all_metrics

In [None]:
def get_input_sizes_from_metadata() -> Dict[str, int]:
    """Map feature-set names to model input sizes using the comet-ml run metadata file.

    Runs at or before the startTimeMillis cutoff below are ignored, because the
    epigeec_filter_1.4.5 runs have wrong input sizes. The feature-set name is
    parsed from each run "Name". The "hg38_100kb_all_none" size is forced to 30321.

    Returns:
        Dict[str, int]: {feature_set: input_size}

    Raises:
        AssertionError: If a feature set has more than one distinct input size.
    """
    # Filter out epigeec_filter_1.4.5 runs, wrong input sizes (cutoff in epoch ms).
    start_time_cutoff_ms = 1706943404420
    recent_runs = RUN_METADATA[RUN_METADATA["startTimeMillis"] > start_time_cutoff_ms]
    recent_runs = recent_runs.assign(
        feature_set=recent_runs["Name"].str.extract(
            pat=r"(^hg38_1[0]{0,2}kb_.*_none).*$"
        )
    )

    size_counts = recent_runs.groupby(["feature_set", "input_size"]).size()
    input_sizes = {feature_set: int(size) for feature_set, size in size_counts.index}

    # A feature set with several input sizes collapses to one dict key, so the
    # lengths differ exactly when some feature set is ambiguous.
    assert len(size_counts) == len(input_sizes)

    input_sizes["hg38_100kb_all_none"] = 30321
    return input_sizes

In [None]:
gen_data_dir = base_data_dir / "training_results" / "dfreeze_v2"
input_sizes = extract_input_sizes_from_output_files(gen_data_dir)  # type: ignore
# Keep only feature sets with exactly one recorded input size; the others are
# dropped here. v.pop() takes that single value out of its one-element
# collection (presumably a set of sizes per feature set — confirm).
input_sizes: Dict[str, int] = {k: v.pop() for k, v in input_sizes.items() if len(v) == 1}  # type: ignore

In [None]:
# {feature_set: {task_name: {split_name: metric_dict}}}; merge_assays=True is
# forwarded to general_split_metrics for every feature-set folder.
all_metrics = obtain_all_feature_set_metrics(gen_data_dir, merge_assays=True)

1 - contenant slmt 100kb_all + 10kb_all + 1kb_200k qui servira à la suppFig1  

1.1 - les 2 regulatory (30k puis 300k) à la fin

2 - enlever les 118 puis réordonner pour commencer par les 3 précédentes, suivies 4.5k_100kb et 4.5k_random, 45k_10kb et son random, puis  45k_1kb et son random; ce sera visuellement un peu moins intéressant mais bcp plus facile à décrire je crois  

2.1 - v2 en intégrant aussi regulatory donc dans l'ordre : enlever les 118 puis réordonner : 100kb_all, suivies 4.5k_100kb et 4.5k_random, puis 10kb_all, 45k_10kb et son random, puis 1kb_200k, 45k_1kb et son random + regul_30k et 300k. 

In [None]:
# Inspect which feature sets were found (their order defines the indices used below).
# print(all_metrics["hg38_100kb_all_none"].keys())
print(all_metrics.keys())

In [None]:
# Feature-set subsets, selected by substrings of the feature-set folder names.
# v1: "all_none" and "200k" feature sets.
v1_metrics = {
    name: vals
    for name, vals in all_metrics.items()
    if any(label in name for label in ["all_none", "200k"])
}
# v1 plus regulatory-region ("reg") feature sets.
v1_reg_metrics = {
    name: vals
    for name, vals in all_metrics.items()
    if any(label in name for label in ["all_none", "200k", "reg"])
}
# v1 plus "global_tasks_union", "4510" and "45k" feature sets.
v2_metrics = {
    **v1_metrics,
    **{
        name: vals
        for name, vals in all_metrics.items()
        if any(label in name for label in ["global_tasks_union", "4510", "45k"])
    },
}

In [None]:
# print(list(enumerate(v1_metrics.keys())))
# print(list(enumerate(v1_reg_metrics.keys())))
# print(list(enumerate(v2_metrics.keys())))

In [None]:
# NOTE(review): these positional orders index into dict key order, which comes
# from Path.iterdir() in obtain_all_feature_set_metrics. That order is arbitrary
# (filesystem dependent, not sorted). If folders are added/renamed or the data is
# read on another filesystem, these indices silently select other feature sets —
# consider ordering by explicit feature-set names instead.
desired_v1_order = [1, 2, 0]
desired_v1_reg_order = [1, 3, 0, 2, 4]
desired_v2_order = [1, 5, 7, 2, 8, 4, 0, 6, 3]

ordered_v1_metrics = {
    list(v1_metrics.keys())[i]: list(v1_metrics.values())[i] for i in desired_v1_order
}
ordered_v1_reg_metrics = {
    list(v1_reg_metrics.keys())[i]: list(v1_reg_metrics.values())[i]
    for i in desired_v1_reg_order
}
ordered_v2_metrics = {
    list(v2_metrics.keys())[i]: list(v2_metrics.values())[i] for i in desired_v2_order
}
# v2 ordering followed by every regulatory-region ("reg") feature set.
ordered_v2_reg_metrics = {
    **ordered_v2_metrics,
    **{k: v for k, v in all_metrics.items() if "reg" in k},
}

In [None]:
# print(ordered_v1_metrics.keys())
# print(ordered_v1_reg_metrics.keys())
# print(ordered_v2_metrics.keys())
# print(ordered_v2_reg_metrics.keys())

In [None]:
# A 4-entry `resolution_colors` mapping used to be defined here. It was dead code:
# the following cell redefines `resolution_colors` with the same first four
# entries plus additional resolutions, immediately overwriting it.

In [None]:
# Plot color for each feature-set resolution / region type, assigned in order
# from plotly's qualitative "Safe" palette.
resolution_colors = {
    resolution: px.colors.qualitative.Safe[i]
    for i, resolution in enumerate(
        ["100kb", "10kb", "1kb", "regulatory", "gene", "cpg", "1mb", "5mb", "10mb"]
    )
}

In [None]:
def graph_feature_set_metrics(
    all_metrics: Dict[str, Dict[str, Dict[str, Dict[str, float]]]],
    input_sizes: Dict[str, int],
    logdir: Path,
    sort_by_input_size: bool = False,
) -> None:
    """Create a graph of all metrics for all feature sets.

    One figure is produced per metadata category (task) of the reference feature set.
    Each figure has two panels (Accuracy, F1-score macro) with one box per feature set,
    each box holding one value per split. Figures are written to `logdir` as HTML, SVG
    and PNG files named `feature_set_metrics_{category}`.

    Args:
        all_metrics: Metrics in the format
            {feature_set: {task_name: {split_name: {metric_name: value}}}}.
            Must contain the reference feature set
            "hg38_1kb_global_tasks_union_UpResolution_1kb_sampled-200k_none",
            whose task names define the figures produced.
        input_sizes: Number of input features per feature set. Feature sets absent
            from this mapping are skipped (with a printed message).
        logdir: Directory where figures are written.
        sort_by_input_size: If True, draw feature sets by increasing input size.
            Feature sets without a known input size are placed last, then skipped.
    """
    reference_hdf5_type = "hg38_1kb_global_tasks_union_UpResolution_1kb_sampled-200k_none"
    metadata_categories = list(all_metrics[reference_hdf5_type].keys())

    # Some feature sets were trained with differently named assay/sex tasks.
    non_standard_names = {ASSAY: f"{ASSAY}_11c", SEX: f"{SEX}_w-mixed"}
    non_standard_assay_task_names = ["hg38_100kb_all_none"]
    non_standard_sex_task_name = [
        "hg38_100kb_all_none",
        "hg38_regulatory_regions_n30321",
        "hg38_regulatory_regions_n303114",
    ]

    # The drawing order is the same for every category: compute it once.
    order = list(all_metrics.keys())
    if sort_by_input_size:
        # Unknown sizes sort last instead of raising KeyError; they are skipped below.
        order = sorted(order, key=lambda name: input_sizes.get(name, float("inf")))

    # (metric key inside each split's metrics dict, subplot column)
    metric_columns = (("Accuracy", 1), ("F1_macro", 2))

    for task_name in metadata_categories:
        category_fig = make_subplots(
            rows=1,
            cols=2,
            shared_yaxes=True,
            subplot_titles=["Accuracy", "F1-score (macro)"],
            x_title="Feature set",
            y_title="Metric value",
        )

        trace_names = []
        used_resolutions = set()
        for feature_set_name in order:
            if feature_set_name not in input_sizes:
                print(f"Skipping {feature_set_name}, no input size found.")
                continue

            tasks_dicts = all_metrics[feature_set_name]
            try:
                task_dict = tasks_dicts[task_name]
            except KeyError:
                # Fall back on the legacy task name for known feature sets.
                if SEX in task_name and feature_set_name in non_standard_sex_task_name:
                    task_dict = tasks_dicts[non_standard_names[SEX]]
                elif (
                    ASSAY in task_name
                    and feature_set_name in non_standard_assay_task_names
                ):
                    task_dict = tasks_dicts[non_standard_names[ASSAY]]
                else:
                    print("Skipping", feature_set_name, task_name)
                    continue

            input_size = input_sizes[feature_set_name]

            # Shorter label for display; the first token is the resolution/type.
            display_name = feature_set_name.replace("_none", "").replace("hg38_", "")
            resolution = display_name.split("_")[0]
            used_resolutions.add(resolution)

            # Input size prefix allows numeric sorting of the x axis (see below).
            trace_name = f"{input_size}|{display_name}"
            trace_names.append(trace_name)

            for metric, col in metric_columns:
                y_vals = [split_metrics[metric] for split_metrics in task_dict.values()]
                hovertext = [
                    f"{split}: {split_metrics[metric]:.4f}"
                    for split, split_metrics in task_dict.items()
                ]
                category_fig.add_trace(
                    go.Box(
                        y=y_vals,
                        name=trace_name,
                        boxmean=True,
                        boxpoints="all",
                        showlegend=False,
                        marker=dict(size=3, color="black"),
                        line=dict(width=1, color="black"),
                        fillcolor=resolution_colors[resolution],
                        hovertemplate="%{text}",
                        text=hovertext,
                        legendgroup=resolution,
                    ),
                    row=1,
                    col=col,
                )

        # category_fig.update_yaxes(range=[0.65, 1.001])
        category_fig.update_layout(
            width=1500,
            height=1500,
            title=f"Neural network performance - {task_name}",
        )

        if logdir.name == "all":
            category_fig.update_xaxes(
                categoryorder="array",
                categoryarray=sorted(trace_names, key=lambda x: int(x.split("|")[0])),
            )

        # Dummy scatters: one legend entry per resolution actually drawn,
        # in the order of `resolution_colors`.
        for resolution, color in resolution_colors.items():
            if resolution not in used_resolutions:
                continue
            category_fig.add_trace(
                go.Scatter(
                    x=[None],
                    y=[None],
                    mode="markers",
                    name=resolution,
                    marker=dict(color=color, size=5),
                    showlegend=True,
                    legendgroup=resolution,
                )
            )

        # Save figure
        base_name = f"feature_set_metrics_{task_name}"
        category_fig.write_html(logdir / f"{base_name}.html")
        category_fig.write_image(logdir / f"{base_name}.svg")
        category_fig.write_image(logdir / f"{base_name}.png")

        # category_fig.show()

In [None]:
# One output folder per feature set selection.
for folder, metrics in zip(
    ["v1", "v1_reg", "v2", "v2_reg"],
    [
        ordered_v1_metrics,
        ordered_v1_reg_metrics,
        ordered_v2_metrics,
        ordered_v2_reg_metrics,
    ],
):
    logdir = base_fig_dir / "fig2_EpiAtlas_other" / "fig2--reduced_feature_sets" / folder
    # parents=True: intermediate figure folders may not exist yet on a fresh checkout.
    logdir.mkdir(parents=True, exist_ok=True)
    graph_feature_set_metrics(metrics, input_sizes, logdir)

### Effect of zeroing blacklisted regions and winsorizing input files

Data download

~~~bash
paper_dir="${HOME}/Projects/epiclass/output/paper/data"
cd ${paper_dir}/training_results/2023-01-epiatlas-freeze

base_path="/lustre07/scratch/rabyj/epilap-logs/2023-01-epiatlas-freeze"
rsync --info=progress2 -a --exclude "*/EpiLaP/" --exclude "*.png" --exclude "validation_confusion*" --exclude "*.md5" narval:${base_path}/hg38_100kb_all_none_0blklst* .
~~~

In [None]:
# Metadata categories (classification tasks) compared in the blacklist/winsorization analysis.
# Each one is expected as a "{category}_1l_3000n" folder inside every feature set folder.
BLKLST_CATEGORIES = list(
    (
        "assay_epiclass",
        "harmonized_biomaterial_type",
        "harmonized_donor_sex",
        "harmonized_sample_ontology_intermediate",
    )
)

#### Check oversampling

Make sure oversampling is the same in all training runs used.

In [None]:
def verify_2023_runs_oversampling():
    """Check if oversampling is on for all 2023 training runs used for blacklisted/winsorized metrics.

    For every feature set folder of the 2023-01-epiatlas-freeze training results, calls
    `check_for_oversampling` on each `{category}_1l_3000n` folder of BLKLST_CATEGORIES.

    Raises:
        FileNotFoundError: If a feature set folder lacks one of the category folders.
    """
    # Same location as in get_blklst_split_metrics and the data download instructions.
    data_dir = base_data_dir / "training_results" / "2023-01-epiatlas-freeze"
    for folder in sorted(data_dir.iterdir()):
        # Only feature set directories are relevant; ignore stray files.
        if folder.is_file():
            continue
        for category in BLKLST_CATEGORIES:
            category_parent_folder = folder / f"{category}_1l_3000n"

            if not category_parent_folder.exists():
                raise FileNotFoundError(f"Cannot find: {category_parent_folder}")

            print(f"Processing {category_parent_folder}")

            check_for_oversampling(category_parent_folder, verbose=False)
            print()

In [None]:
# verify_2023_runs_oversampling()

verify_2023_runs_oversampling result:
Oversampling is uniform across HDF5 types, but uncertain across metadata categories:
  - harmonized_biomaterial_type: On
  - harmonized_sample_ontology_intermediate: On
  - harmonized_donor_sex: Unknown, very probably On.
    All values are NaN, but the runs used the human_longer.json hyperparameters, the same as the other runs that have oversampling on.

#### Compute metrics

In [None]:
def get_blklst_split_metrics(
    verbose: bool = False,
) -> Dict[str, Dict[str, Dict[str, Dict[str, float]]]]:
    """Compute metrics on relevant categories and runs.

    Reads the split results of every task folder found under each
    `{category}_1l_3000n` folder (category in BLKLST_CATEGORIES) of every feature set
    folder of the 2023-01-epiatlas-freeze training results, then computes per-split
    metrics with `split_results_handler`.

    Args:
        verbose: If True, print each category folder as it is processed.

    Returns:
        Dict[str, Dict[str, Dict[str, Dict[str, float]]]]: A dictionary containing all metrics
            for all blklst related feature sets.
            Format: {feature_set: {task_name: {split_name: metric_dict}}}
            where task_name is "{category}_1l_3000n-{task_folder_name}".

    Raises:
        FileNotFoundError: If a feature set folder lacks one of the category folders.
    """
    data_dir = base_data_dir / "training_results" / "2023-01-epiatlas-freeze"
    feature_set_metrics_dict = {}
    # Sorted so the output order does not depend on filesystem listing order.
    for folder in sorted(data_dir.iterdir()):
        if folder.is_file():
            continue
        feature_set_name = folder.name

        tasks_dict = {}
        for category in BLKLST_CATEGORIES:
            category_parent_folder = folder / f"{category}_1l_3000n"

            if not category_parent_folder.exists():
                raise FileNotFoundError(f"Cannot find: {category_parent_folder}")

            if verbose:
                print(f"Processing {category_parent_folder}")

            for task_folder in sorted(category_parent_folder.iterdir()):
                if task_folder.is_file():
                    continue
                split_results = split_results_handler.read_split_results(task_folder)
                general_name = f"{category_parent_folder.name}-{task_folder.name}"
                tasks_dict[general_name] = split_results

        # NOTE(review): presumably yields {split: {task: metrics}}, which
        # invert_metrics_dict turns into {task: {split: metrics}} — confirm in paper_utilities.
        feature_set_metrics = split_results_handler.compute_split_metrics(
            tasks_dict, concat_first_level=True
        )
        feature_set_metrics_dict[
            feature_set_name
        ] = split_results_handler.invert_metrics_dict(feature_set_metrics)

    return feature_set_metrics_dict

In [None]:
# feature_set_metrics_dict = get_blklst_split_metrics(verbose=False)

#### Create graphs

In [None]:
def create_blklst_graphs(
    feature_set_metrics_dict: Dict[str, Dict[str, Dict[str, Dict[str, float]]]],
    logdir: Path,
    y_range: Tuple[float, float] = (0.85, 1.001),
) -> None:
    """Create boxplots for blacklisted related feature sets.

    One figure per task, with two panels (Accuracy, F1-score macro) and one box per
    feature set (observed / 0blklst / 0blklst_winsorized), each box holding one value
    per split. Figures are written to `logdir` as HTML, SVG and PNG, then shown.

    Args:
        feature_set_metrics_dict (Dict[str, Dict[str, Dict[str, Dict[str, float]]]]): The dictionary
            containing all metrics for all blklst related feature sets.
            format: {feature_set: {task_name: {split_name: metric_dict}}}
            Feature sets without a display name in `traces_names_dict` are skipped.
        logdir (Path): Directory where figures are written.
        y_range (Tuple[float, float]): y-axis range shared by both panels.
    """
    if not feature_set_metrics_dict:
        print("No feature set metrics given, nothing to plot.")
        return

    traces_names_dict = {
        "hg38_100kb_all_none": "observed",
        "hg38_100kb_all_none_0blklst": "0blklst",
        "hg38_100kb_all_none_0blklst_winsorized": "0blklst_winsorized",
    }

    # Task names are taken from the first feature set; assumed to exist in all of them.
    task_names = list(next(iter(feature_set_metrics_dict.values())).keys())

    # (metric key inside each split's metrics dict, subplot column)
    metric_columns = (("Accuracy", 1), ("F1_macro", 2))

    for task_name in task_names:
        category_fig = make_subplots(
            rows=1,
            cols=2,
            shared_yaxes=False,
            subplot_titles=["Accuracy", "F1-score (macro)"],
            x_title="Feature set",
            y_title="Metric value",
            horizontal_spacing=0.03,
        )
        for feature_set_name, tasks_dicts in feature_set_metrics_dict.items():
            trace_name = traces_names_dict.get(feature_set_name)
            if trace_name is None:
                print(f"Skipping {feature_set_name}, no display name defined.")
                continue
            task_dict = tasks_dicts[task_name]

            for metric, col in metric_columns:
                y_vals = [split_metrics[metric] for split_metrics in task_dict.values()]  # type: ignore
                hovertext = [
                    f"{split}: {split_metrics[metric]:.4f}"  # type: ignore
                    for split, split_metrics in task_dict.items()
                ]
                category_fig.add_trace(
                    go.Box(
                        y=y_vals,
                        name=trace_name,
                        boxmean=True,
                        boxpoints="all",
                        showlegend=False,
                        marker=dict(size=3, color="black"),
                        line=dict(width=1, color="black"),
                        hovertemplate="%{text}",
                        text=hovertext,
                    ),
                    row=1,
                    col=col,
                )

        # Fixed x order: observed first, then the blacklist variants.
        category_fig.update_xaxes(
            categoryorder="array",
            categoryarray=list(traces_names_dict.values()),
        )
        category_fig.update_yaxes(range=list(y_range))

        short_task_name = task_name.replace("_1l_3000n-10fold", "")
        category_fig.update_layout(
            title=f"Neural network performance - {short_task_name} - 100kb",
        )

        # Save figure
        base_name = f"metrics_{short_task_name}"
        category_fig.write_html(logdir / f"{base_name}.html")
        category_fig.write_image(logdir / f"{base_name}.svg")
        category_fig.write_image(logdir / f"{base_name}.png")

        category_fig.show()

In [None]:
logdir = base_fig_dir / "fig2_EpiAtlas_other" / "fig2--blklst_and_winsorized" / "y0.85"
# parents=True: intermediate figure folders may not exist yet on a fresh checkout.
logdir.mkdir(parents=True, exist_ok=True)
# create_blklst_graphs(feature_set_metrics_dict, logdir)