In [None]:
"""Workbook to create figures (fig2) destined for the paper.

Please use dfreeze v2 for these.
"""
# pylint: disable=import-error, redefined-outer-name, use-dict-literal, too-many-lines, unused-import, unused-argument, too-many-branches

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from __future__ import annotations

import copy
import itertools
import logging
import os
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple

logging.basicConfig(level=logging.DEBUG)

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from IPython.display import display
from plotly.subplots import make_subplots
from scipy.stats import zscore
from sklearn.metrics import confusion_matrix as sk_cm

from epi_ml.core.confusion_matrix import ConfusionMatrixWriter
from epi_ml.utils.notebooks.paper.paper_utilities import (
    ASSAY,
    ASSAY_MERGE_DICT,
    ASSAY_ORDER,
    CELL_TYPE,
    LIFE_STAGE,
    SEX,
    IHECColorMap,
    MetadataHandler,
    SplitResultsHandler,
    merge_similar_assays,
)

In [None]:
# Root of the paper outputs; the data and figure folders are direct children of it.
base_dir = Path.home().joinpath("Projects/epiclass/output/paper")
base_data_dir = base_dir.joinpath("data")
base_fig_dir = base_dir.joinpath("figures")
paper_dir = base_dir

# Figures are saved below this directory, so stop early when it is missing.
if base_fig_dir.exists() is False:
    raise FileNotFoundError(f"Directory {base_fig_dir} does not exist.")

In [None]:
# NOTE(review): this rebinds the imported class name `IHECColorMap` to an instance,
# so running this cell a second time in the same kernel calls the instance instead
# of the class. Consider a distinct instance name if no other cell relies on this one.
IHECColorMap = IHECColorMap(base_fig_dir)
# Colour maps; `assay_colors` and `sex_colors` are read by the plotting functions below.
assay_colors = IHECColorMap.assay_color_map
cell_type_colors = IHECColorMap.cell_type_color_map
sex_colors = IHECColorMap.sex_color_map

In [None]:
# Shared handler used by the top-level cells to read split results and compute metrics.
split_results_handler = SplitResultsHandler()

In [None]:
# Load the v2 metadata as a dataframe indexed by md5sum.
metadata_handler = MetadataHandler(paper_dir)
metadata_v2 = metadata_handler.load_metadata("v2")
metadata_v2_df = pd.DataFrame.from_records(list(metadata_v2.datasets)).set_index(
    "md5sum"
)
# Assign the result back instead of `df[col].replace(..., inplace=True)`: the
# chained in-place form acts on a temporary Series and does not update the frame
# when pandas Copy-on-Write is active.
metadata_v2_df[ASSAY] = metadata_v2_df[ASSAY].replace(ASSAY_MERGE_DICT)

In [None]:
# Run-level table; `check_for_oversampling` reads its "experimentKey", "Name"
# and "hparams/oversampling" columns.
parameters_metadata_path = base_data_dir.joinpath(
    "all_results_cometml_filtered_oversampling-fixed.csv"
)
RUN_METADATA = pd.read_csv(filepath_or_buffer=parameters_metadata_path)

## Fig 2 - EpiClass results on EpiAtlas other metadata

For the following figures, use v1.1 of the sample metadata (called v2.1 internally), i.e. dfreeze 2.

A) Histogram of performance (accuracy and F1 scores) for each category  

B) Violin plot of average z-score on chrY per sex, black dots for pred same class and red for pred different class.  
- Do the split male female violin per assay (only FC, merge 2xwgbs and 2xrna, no rna unique_raw). 
- Use scatter for points on each side, agree same color as violin, disagree other.
- Point labels: uuid, epirr

C) ---  
D) ---  
E) --- 

### Neural network performance across metadata categories

#### Check if oversampling is uniform

In [None]:
def check_for_oversampling(parent_dir: Path):
    """Check the oversampling status of the experiments found under `parent_dir`.

    Every sub-directory of `parent_dir` is scanned for `output_job*.o` files, from
    which experiment keys are extracted. The keys are then looked up in the
    module-level `RUN_METADATA` table.

    Nothing is raised or returned. A warning is logged when some experiments do
    not have oversampling enabled, and another one when the number of experiments
    with a known oversampling value differs from 10 per visited folder.

    Args:
        parent_dir (Path): Directory whose sub-directories hold the job output files.
    """
    # Identify experiments
    exp_key_line = "The current experiment key is"
    exp_keys_dict = defaultdict(list)
    # Only directories are experiment folders; stray files would otherwise
    # inflate the folder count used in the final sanity check.
    folders = [folder for folder in parent_dir.iterdir() if folder.is_dir()]
    for folder in folders:
        for stdout_file in folder.glob("output_job*.o"):
            with open(stdout_file, "r", encoding="utf8") as f:
                lines = [l.rstrip() for l in f if exp_key_line in l]
            exp_keys = [l.split(exp_key_line)[1].strip() for l in lines]
            exp_keys_dict[folder.name].extend(exp_keys)

    # Filter metadata to only include experiments in the results
    all_exp_keys = set()
    for exp_keys in exp_keys_dict.values():
        all_exp_keys.update(exp_keys)

    df = RUN_METADATA[RUN_METADATA["experimentKey"].isin(all_exp_keys)]

    # Check oversampling values, ignore nan.
    # Work on an explicit copy: a column is added below and must not be written
    # through a slice of RUN_METADATA.
    df = df[df["hparams/oversampling"].notna()].copy()
    # Compare as lower-cased text so that both the string "true" and values that
    # pandas parsed as booleans are recognised.
    is_oversampled = df["hparams/oversampling"].astype(str).str.lower() == "true"
    if not is_oversampled.all():
        df["general_name"] = df["Name"].str.replace(r"[_-]?split\d+$", "", regex=True)
        err_df = df.groupby(["general_name", "hparams/oversampling"]).agg("size")
        logging.warning(
            "Not all experiments have oversampling:\n%s",
            err_df,
        )

    print(f"Checked {len(folders)} folders and {len(df)} experiments.")
    # Presumably one experiment per fold of a 10-fold cross-validation is expected.
    if len(folders) * 10 != len(df):
        logging.warning("Could not read oversampling value of all visited experiments")

In [None]:
# Result folders of the assay and cell type classifiers (dfreeze v2, 100kb resolution).
assay_dir = base_data_dir.joinpath(
    "dfreeze_v2", "hg38_100kb_all_none", "assay_epiclass_1l_3000n", "11c"
)
ct_dir = base_data_dir.joinpath(
    "dfreeze_v2",
    "hg38_100kb_all_none",
    "harmonized_sample_ontology_intermediate_1l_3000n",
)
# check_for_oversampling(assay_dir)
# check_for_oversampling(ct_dir)

In [None]:
# v2_results_dir = base_data_dir / "dfreeze_v2"
# check_for_oversampling_global(base_data_dir)

Histogram of performance (accuracy and F1 scores) for each category

In [None]:
def create_mislabel_corrector():
    """Gather what is needed to correct sex and life_stage mislabels.

    The md5sum/EpiRR association comes from the v2 metadata, the corrected labels
    from the `official_*_mislabeled.csv` files of the metadata directory.

    Returns:
        Dict[str, str]: {md5sum: EpiRR_no-v}
        Dict[str, Dict[str, str]]: {label_category: {EpiRR_no-v: corrected_label}}
    """
    epirr_col = "EpiRR_no-v"

    # md5sum -> EpiRR (no version)
    v2_metadata = MetadataHandler(paper_dir).load_metadata("v2")
    md5sum_to_epirr = (
        pd.DataFrame.from_records(list(v2_metadata.datasets))
        .set_index("md5sum")[epirr_col]
        .to_dict()
    )

    # (label category, mislabel file, column holding the corrected label)
    correction_sources = (
        (SEX, "official_Sex_mislabeled.csv", "EpiClass_pred_Sex"),
        (LIFE_STAGE, "official_Life_stage_mislabeled.csv", "EpiClass_pred_Life_stage"),
    )
    metadata_dir = base_data_dir / "metadata"
    epirr_to_corrections = {}
    for label_category, filename, corrected_col in correction_sources:
        mislabeled_df = pd.read_csv(metadata_dir / filename)
        epirr_to_corrections[label_category] = mislabeled_df.set_index(epirr_col)[
            corrected_col
        ].to_dict()

    return md5sum_to_epirr, epirr_to_corrections

In [None]:
def fig2_a_content(
    results_dir: Path,
    exclude_categories: List[str] | None = None,
    exclude_names: List[str] | None = None,
) -> Tuple[Dict[str, Dict[str, Dict[str, float]]], Dict[str, Dict[str, pd.DataFrame]]]:
    """Create the content data for figure 2a. (get metrics for each task)

    Currently only using oversampled runs, i.e. folders named "10fold-oversampling".
    For sex and life_stage tasks, the "True class" column of each split is
    replaced by the corrected labels given by `create_mislabel_corrector`.

    Args:
        results_dir (Path): Directory containing the results. Needs to be parent over category folders.
        exclude_categories (List[str]): Task categories to exclude (first level directory names).
        exclude_names (List[str]): Names of folders to exclude (ex: 7c or no-mix).

    Returns:
        Dict[str, Dict[str, Dict[str, float]]] A metrics dictionary with the following structure:
            {split_name: {task_name: metrics_dict}}
        Dict[str, Dict[str, pd.DataFrame]] A split results dictionary with the following structure:
            {task_name: {split_name: split_results_df}}

    Raises:
        ValueError: If more than one folder level sits between the category folder
            and "10fold-oversampling", or if a results folder holds no split results.
        KeyError: Re-raised (after logging) when the metrics computation fails.
    """
    all_split_results = {}
    split_results_handler = SplitResultsHandler()

    md5sum_to_epirr, epirr_to_corrections = create_mislabel_corrector()

    category_suffix = "_1l_3000n"
    for parent, _, _ in os.walk(results_dir):
        # Looking for oversampling only results
        parent = Path(parent)
        if parent.name != "10fold-oversampling":
            continue

        # Get the category.
        # str.rstrip() removes a *set of characters*, not a suffix, and would also
        # eat trailing "n", "l", "0", "1", "3" or "_" from the category name itself,
        # so the suffix is removed explicitly.
        relpath = parent.relative_to(results_dir)
        category = relpath.parts[0]
        if category.endswith(category_suffix):
            category = category[: -len(category_suffix)]
        if exclude_categories is not None:
            if any(exclude_str in category for exclude_str in exclude_categories):
                continue

        # Get the rest of the name, ignore certain runs
        name_parts = list(relpath.parts[1:])
        name_parts.remove("10fold-oversampling")

        if len(name_parts) > 1:
            raise ValueError(
                f"Too many parts in the name: {name_parts}. Path: {relpath}"
            )
        # Empty string when the results sit directly under the category folder
        rest_of_name = name_parts[0] if name_parts else ""

        if exclude_names is not None and rest_of_name:
            if any(name in rest_of_name for name in exclude_names):
                continue

        full_task_name = category
        if rest_of_name:
            full_task_name += f"_{rest_of_name}"

        # Get the split results
        split_results = split_results_handler.read_split_results(parent)
        if not split_results:
            raise ValueError(f"No split results found in {parent}")

        if "sex" in full_task_name or "life_stage" in full_task_name:
            corrections = epirr_to_corrections[category]
            for split_result_df in split_results.values():
                # Row index holds md5sums; keep the current label when the
                # EpiRR has no correction.
                split_result_df["True class"] = [
                    corrections.get(md5sum_to_epirr[md5sum], label)
                    for md5sum, label in split_result_df["True class"].items()
                ]

        all_split_results[full_task_name] = split_results

    try:
        split_results_metrics = split_results_handler.compute_split_metrics(
            all_split_results, concat_first_level=True
        )
    except KeyError as e:
        logging.error("KeyError: %s", e)
        logging.error("all_split_results: %s", all_split_results)
        logging.error("check folder: %s", results_dir)
        raise
    return split_results_metrics, all_split_results

In [None]:
def fig2_a(
    split_metrics: Dict[str, Dict[str, Dict[str, float]]],
    logdir: Path,
    name: str,
    exclude_categories: List[str],
    y_range: List[float] | None = None,
    sort_by_acc: bool = False,
) -> None:
    """Render box plots of metrics per classifier and split, each in its own subplot.

    This function generates a figure with subplots, each representing a different
    metric. Each subplot contains one box per classifier, with one point per split.
    The figure is saved as svg, png and html in `logdir`, then displayed.

    Args:
        split_metrics: A nested dictionary with structure {split: {classifier: {metric: score}}}.
        logdir: The directory path to save the output plots.
        name: The base name for the output plot files.
        exclude_categories: Task categories to exclude from the plot (substring match
            on the classifier name).
        y_range: If given, y-axis range applied to every subplot instead of the
            per-metric defaults.
        sort_by_acc: If True, order classifiers by decreasing mean accuracy over splits.

    Raises:
        ValueError: If `split_metrics` is empty.
    """
    if not split_metrics:
        raise ValueError("split_metrics is empty, nothing to plot.")

    metrics = ["Accuracy", "F1_macro"]
    # metrics = ["AUC_micro", "AUC_macro"]
    # metrics = ["Accuracy", "F1_macro", "AUC_micro", "AUC_macro"]

    # Default y-axis range per metric, keyed by metric name so that the right
    # range is applied whichever `metrics` list is active above.
    range_acc = [0.86, 1.001]
    range_auc = [0.986, 1.0001]
    default_ranges = {
        "Accuracy": range_acc,
        "F1_macro": range_acc,
        "AUC_micro": range_auc,
        "AUC_macro": range_auc,
    }

    # Exclude some categories (classifier names are taken from the first split)
    first_split_metrics = next(iter(split_metrics.values()))
    classifier_names = [
        classifier
        for classifier in first_split_metrics
        if not any(category in classifier for category in exclude_categories)
    ]

    # Sort classifiers by accuracy
    if sort_by_acc:
        mean_acc = {
            classifier: np.mean(
                [split_metrics[split][classifier]["Accuracy"] for split in split_metrics]
            )
            for classifier in classifier_names
        }
        classifier_names = sorted(
            classifier_names, key=lambda x: mean_acc[x], reverse=True
        )

    # Create subplots, one column for each metric
    fig = make_subplots(
        rows=1,
        cols=len(metrics),
        subplot_titles=metrics,
        horizontal_spacing=0.03,
    )

    # Cycle through the palette so that having more classifiers than palette
    # colours does not raise an IndexError.
    palette = px.colors.qualitative.Plotly
    colors = {
        classifier: palette[i % len(palette)]
        for i, classifier in enumerate(classifier_names)
    }

    for i, metric in enumerate(metrics):
        for classifier_name in classifier_names:
            values = [
                split_metrics[split][classifier_name][metric] for split in split_metrics
            ]

            fig.add_trace(
                go.Box(
                    y=values,
                    name=classifier_name,
                    fillcolor=colors[classifier_name],
                    line=dict(color="black", width=1.5),
                    marker=dict(size=3, color="black"),
                    boxmean=True,
                    boxpoints="all",
                    pointpos=0,
                    showlegend=i == 0,  # Only show legend in the first subplot
                    hovertemplate="%{text}",
                    text=[
                        f"{split}: {value:.4f}"
                        for split, value in zip(split_metrics, values)
                    ],
                    legendgroup=classifier_name,
                    width=0.5,
                ),
                row=1,
                col=i + 1,
            )

    fig.update_layout(
        title_text="Neural network classification - Metric distribution for 10-fold cross-validation",
        yaxis_title="Value",
        boxmode="group",
        height=1200 * 0.8,
        width=1750 * 0.8,
    )

    # Per-metric default ranges, overridden for all subplots when y_range is given
    for i, metric in enumerate(metrics):
        if metric in default_ranges:
            fig.update_yaxes(range=default_ranges[metric], row=1, col=i + 1)

    if y_range is not None:
        fig.update_yaxes(range=y_range)

    # Save figure
    fig.write_image(logdir / f"{name}.svg")
    fig.write_image(logdir / f"{name}.png")
    fig.write_html(logdir / f"{name}.html")

    fig.show()

In [None]:
# for split_name, split_metrics in all_split_metrics.items():
#     print(split_name)
#     for task_name, task_metrics in split_metrics.items():
#         print(task_name)
#         print(task_metrics)
#     print()

In [None]:
# Substrings matched against task category / run folder names; passed to
# `fig2_a_content` in the (currently commented-out) cells below.
exclude_categories = ["groups_second_level_name", "track_type", "disease"]
exclude_names = ["chip-seq", "7c"]

In [None]:
# # 100kb
# results_dir = base_data_dir / "dfreeze_v2" / "hg38_100kb_all_none"
# if not results_dir.exists():
#     raise FileNotFoundError(f"Directory {results_dir} does not exist.")
# split_results_metrics, all_split_results = fig2_a_content(results_dir, exclude_categories, exclude_names)

In [None]:
# # 10kb
# results_dir = base_data_dir / "dfreeze_v2" / "hg38_10kb_all_none"
# if not results_dir.exists():
#     raise FileNotFoundError(f"Directory {results_dir} does not exist.")
# split_results_metrics, all_split_results = fig2_a_content(results_dir, exclude_categories, exclude_names)

In [None]:
# concatenated_results = split_results_handler.concatenate_split_results(all_split_results, concat_first_level=True)

In [None]:
# all_split_results.keys()
# split_results_metrics["split0"].keys()

In [None]:
# Output location and base file name given to the (commented-out) fig2_a call below.
fig_logdir = base_fig_dir / "fig2_EpiAtlas_other" / "fig2--NN_perf_across_categories"
# parents=False: the parent "fig2_EpiAtlas_other" directory must already exist.
fig_logdir.mkdir(parents=False, exist_ok=True)
fig_name = "fig2_A"

# NOTE(review): rebinds `exclude_categories` with a different list than the earlier cell.
exclude_categories = ["disease", "sex_no-mixed"]
# fig2_a(split_results_metrics, fig_logdir, fig_name, exclude_categories)

### Neural network performance per assay across metadata categories

In [None]:
def NN_performance_per_assay_across_categories(
    all_split_results: Dict[str, Dict[str, pd.DataFrame]],
    logdir: Path,
    name: str,
    exclude_categories: List[str],
    y_range: None | List[float] = None,
):
    """Create a box plot of assay accuracy for each classifier.

    One figure is produced per task: one box per assay (following ASSAY_ORDER),
    one point per split. Each figure is saved as svg, png and html in `logdir`
    under the name `{name}_{task_name}`, then displayed.

    Args:
        all_split_results: {task_name: {split_name: split_results_df}}. The input
            is deep-copied and not modified.
        logdir: The directory path to save the output plots.
        name: The base name for the output plot files.
        exclude_categories: Tasks whose name contains one of these strings are skipped.
        y_range: If given, y-axis range used for every figure instead of the
            task-dependent default.
    """
    all_split_results = copy.deepcopy(all_split_results)

    # Exclude some categories
    classifier_names = [
        task_name
        for task_name in all_split_results
        if not any(category in task_name for category in exclude_categories)
    ]

    # Only the (merged) assay label of each file is needed from the metadata
    metadata_df = MetadataHandler(paper_dir).load_metadata_df("v2")
    assay_labels = metadata_df[ASSAY].replace(ASSAY_MERGE_DICT)

    # One graph per metadata category
    for task_name in classifier_names:
        split_results = all_split_results[task_name]
        if ASSAY in task_name:
            for split_name in split_results:
                split_results[split_name] = merge_similar_assays(
                    split_results[split_name]
                )

        assay_acc = defaultdict(dict)
        for split_name, split_result_df in split_results.items():
            # Merge metadata
            split_result_df = split_result_df.merge(
                assay_labels, left_index=True, right_index=True
            )

            # Compute accuracy per assay (case-insensitive label comparison)
            assay_groupby = split_result_df.groupby(ASSAY)
            for assay, assay_df in assay_groupby:
                assay_acc[assay][split_name] = np.mean(
                    assay_df["True class"].astype(str).str.lower()
                    == assay_df["Predicted class"].astype(str).str.lower()
                )

        assay_acc_df = pd.DataFrame(assay_acc)
        fig = go.Figure()
        for assay in ASSAY_ORDER:
            try:
                assay_accuracies = assay_acc_df[assay]
            except KeyError:
                continue

            fig.add_trace(
                go.Box(
                    y=assay_accuracies.values,
                    name=assay,
                    boxmean=True,
                    boxpoints="all",
                    showlegend=True,
                    marker=dict(size=3, color="black"),
                    line=dict(width=1, color="black"),
                    fillcolor=assay_colors[assay],
                    hovertemplate="%{text}",
                    text=[
                        f"{split}: {value:.4f}"
                        for split, value in assay_accuracies.items()
                    ],
                )
            )

        if "sample_ontology" in task_name:
            yrange = [0.59, 1.001]
        elif ASSAY in task_name:
            yrange = [0.985, 1.001]
        else:
            # DataFrame.min() gives one value per column; reduce again to get the
            # single lowest accuracy, since an axis bound must be a scalar.
            yrange = [float(assay_acc_df.min().min()), 1.001]

        if y_range is not None:
            yrange = y_range

        fig.update_yaxes(range=yrange)

        fig.update_layout(
            title_text=f"Neural network classification - {task_name} - Assay accuracy",
            yaxis_title="Accuracy",
            xaxis_title="Assay",
        )

        # Save figure
        this_name = name + f"_{task_name}"
        fig.write_image(logdir / f"{this_name}.svg")
        fig.write_image(logdir / f"{this_name}.png")
        fig.write_html(logdir / f"{this_name}.html")

        fig.show()

In [None]:
# Task-name substrings to leave out of the per-assay plots (see the commented-out call below).
exclude_categories = ["disease", "sex_no-mixed"]

In [None]:
# logdir = base_fig_dir / "fig2_EpiAtlas_other" / "fig2--NN_perf_across_categories" / "per_assay"
# NN_performance_per_assay_across_categories(all_split_results, logdir, fig_name, exclude_categories)

### Track type effect on NN performance

In [None]:
parent_dir = base_data_dir.joinpath("dfreeze_v2", "hg38_100kb_all_none")
assay_parent_dir = parent_dir.joinpath("assay_epiclass_1l_3000n", "11c")
ct_parent_dir = parent_dir.joinpath("harmonized_sample_ontology_intermediate_1l_3000n")

# Split results of every run folder, keyed by folder name.
# Folders whose name contains "chip" (assay) or "l1" (cell type) are not read.
assay_results = {
    run_dir.name: split_results_handler.read_split_results(run_dir)
    for run_dir in assay_parent_dir.iterdir()
    if run_dir.name.find("chip") == -1
}
ct_results = {
    run_dir.name: split_results_handler.read_split_results(run_dir)
    for run_dir in ct_parent_dir.iterdir()
    if run_dir.name.find("l1") == -1
}

# Drop these runs from the dictionaries; a missing key raises KeyError.
_ = assay_results.pop("10fold-oversampling")
_ = ct_results.pop("10fold-oversampling")
_ = ct_results.pop("10fold-oversampling_chip-seq-only")

In [None]:
# Work on a deep copy so that `assay_results` keeps its original labels.
corrected_assay_results = copy.deepcopy(assay_results)
for task_name, split_dfs in list(corrected_assay_results.items()):
    for split_name in split_dfs:
        unmerged_df = split_dfs[split_name]
        split_dfs[split_name] = merge_similar_assays(unmerged_df)

In [None]:
# Same metrics computation for both result sets (assay first, then cell type).
assay_metrics, ct_metrics = (
    split_results_handler.compute_split_metrics(results, concat_first_level=True)
    for results in (corrected_assay_results, ct_results)
)

In [None]:
# Output directory and file name for the assay track type figure (call commented out).
logdir = base_fig_dir / "fig2_EpiAtlas_other"
name = f"{ASSAY}_global_track_type_effect"
# fig2_a(
#     assay_metrics,
#     logdir,
#     name,
#     exclude_categories=[],
#     y_range=[0.99, 1.0001],
#     sort_by_acc=False,
# )

# Same directory; `name` is overwritten for the cell type figure (call commented out).
logdir = base_fig_dir / "fig2_EpiAtlas_other"
name = f"{CELL_TYPE}_global_track_type_effect"
# fig2_a(
#     ct_metrics,
#     logdir,
#     name,
#     exclude_categories=[],
#     y_range=[0.91, 1.001],
#     sort_by_acc=False,
# )

In [None]:
# Base file name for the per-assay track type figures (call commented out).
name = f"{ASSAY}_global_track_type_effect_per_assay"
# NN_performance_per_assay_across_categories(
#     corrected_assay_results, logdir, name, exclude_categories=[], y_range=[0.96, 1.001]
# )

In [None]:
def only_keep_core_assays(
    results_dfs: Dict[str, Dict[str, pd.DataFrame]]
) -> Dict[str, Dict[str, pd.DataFrame]]:
    """Return a deep copy of the split results restricted to core assays.

    Non core-assays are excluded, input included: accepted assays are all
    entries of ASSAY_ORDER except its last three. The assay of each row is looked
    up in `metadata_v2_df` through the row index. Split dataframes that already
    have an ASSAY column are returned unfiltered.
    """
    core_assays = ASSAY_ORDER[:-3]
    filtered_results = copy.deepcopy(results_dfs)
    for split_dfs in filtered_results.values():
        for split_name, split_df in list(split_dfs.items()):
            if ASSAY in split_df.columns:
                continue
            with_metadata = split_df.merge(
                metadata_v2_df, how="left", left_index=True, right_index=True
            )
            is_core_assay = with_metadata[ASSAY].isin(core_assays)
            split_dfs[split_name] = split_df[is_core_assay]
    return filtered_results

In [None]:
# # Recompute metrics considering only histones
# for result_df, category_name, y_range in zip(
#     [corrected_assay_results, ct_results],
#     [ASSAY, CELL_TYPE],
#     [[0.85, 1.001], [0.91, 1.001]],
# ):
#     print(category_name)
#     name = f"{category_name}_core6c_track_type_effect"

#     core_result_df = only_keep_core_assays(result_df)
#     metrics = split_results_handler.compute_split_metrics(
#         core_result_df, concat_first_level=True
#     )

#     fig2_a(metrics, logdir, name, exclude_categories=[], y_range=y_range)

#     if category_name == ASSAY:
#         name = f"{ASSAY}_core6_track_type_effect_per_assay"
#         NN_performance_per_assay_across_categories(
#             core_result_df, logdir, name, exclude_categories=[], y_range=[0.97, 1.001]
#         )

### Sex chrY z-score distribution vs predictions

Violin plot of average z-score on chrY per sex, black dots for pred same class and red for pred different class.  

- Do the split male female violin per assay (only FC, merge 2xwgbs and 2xrna, no rna unique_raw). 
- Use scatter for points on each side, agree same color as violin, disagree other.
- Point labels: uuid, epirr

Compute chrY coverage z-score VS assay distribution

In [None]:
def compute_chrY_zscores(version: str):
    """Compute z-scores for chrY coverage data, per assay distribution.

    Excludes raw, pval, and Unique_raw tracks.

    Two CSV files are written to `chrY_coverage/dfreeze_{version}_stats` under
    `base_data_dir`: `chrY_coverage_stats.csv` (mean, std and count of the chrY
    coverage per assay) and `chrY_coverage_zscore_vs_assay.csv` (one row per
    file, with its z-score relative to the files of the same assay).

    Args:
        version (str): Metadata version given to `MetadataHandler.load_metadata`.

    Raises:
        FileNotFoundError: If the chrY coverage directory does not exist.
        AssertionError: If a retained file has a chrY coverage of exactly 0.
    """
    # Get chrY coverage data
    chrY_coverage_dir = base_data_dir / "chrY_coverage"
    if not chrY_coverage_dir.exists():
        raise FileNotFoundError(f"Directory {chrY_coverage_dir} does not exist.")
    chrY_coverage_df = pd.read_csv(chrY_coverage_dir / "chrXY_coverage_all.csv", header=0)

    # Filter out md5s not in metadata version
    metadata = MetadataHandler(paper_dir).load_metadata(version)
    md5s = set(metadata.md5s)
    chrY_coverage_df = chrY_coverage_df[chrY_coverage_df["filename"].isin(md5s)]

    # Make sure all values are non-zero
    assert (chrY_coverage_df["chrY"] != 0).all()

    # These tracks are excluded from z-score computation
    metadata.remove_category_subsets("track_type", ["raw", "pval", "Unique_raw"])
    metadata_df = pd.DataFrame.from_records(list(metadata.datasets))
    # Assign back rather than using a chained `inplace=True` replace, which acts
    # on a temporary Series and is not applied when pandas Copy-on-Write is active.
    metadata_df[ASSAY] = metadata_df[ASSAY].replace(ASSAY_MERGE_DICT)

    # Merge with metadata (inner join: files of excluded tracks are dropped here)
    chrY_coverage_df = chrY_coverage_df.merge(
        metadata_df[["md5sum", ASSAY]], left_on="filename", right_on="md5sum"
    )

    # Save the per-assay distribution parameters
    chrY_dists = chrY_coverage_df.groupby(ASSAY).agg({"chrY": ["mean", "std", "count"]})

    output_dir = chrY_coverage_dir / f"dfreeze_{version}_stats"
    output_dir.mkdir(parents=False, exist_ok=True)
    chrY_dists.to_csv(output_dir / "chrY_coverage_stats.csv")

    # Compute z-score per assay group, directly as a column of the dataframe.
    # NOTE: scipy's zscore uses the population std (ddof=0) by default, whereas
    # the "std" saved above is the pandas sample std (ddof=1).
    metric_name = "chrY_zscore_vs_assay"
    chrY_coverage_df[metric_name] = chrY_coverage_df.groupby(ASSAY)["chrY"].transform(
        zscore
    )

    output_cols = ["filename", "chrY", metric_name, ASSAY]
    chrY_coverage_df[output_cols].to_csv(
        output_dir / "chrY_coverage_zscore_vs_assay.csv", index=False
    )

In [None]:
# Writes the per-assay chrY statistics and z-score CSV files for the "v2" metadata.
compute_chrY_zscores("v2")

Plot z-scores according to sex

main Fig: chrY per EpiRR (excluding WGBS): only boxplot with all points

In [None]:
# Name of the z-score column written by `compute_chrY_zscores`; read as a global by the fig2_B functions.
metric_label = "chrY_zscore_vs_assay"

In [None]:
def prepare_fig_2B_data(version: str, prediction_data_dir: Path) -> pd.DataFrame:
    """Assemble the dataframe used for figure 2b.

    The chrY z-scores saved for `version` are joined with the metadata
    (md5sum, EpiRR, sex) and with the predictions read from
    `prediction_data_dir`. A "Max pred" column holds the highest of the
    female/male/mixed prediction values. The result is indexed by md5sum.
    """
    # Metadata, restricted to the columns needed downstream
    metadata = MetadataHandler(paper_dir).load_metadata(version)
    metadata_df = pd.DataFrame.from_records(list(metadata.datasets))[
        ["md5sum", "EpiRR", SEX]
    ]

    # z-scores
    zscore_path = (
        base_data_dir
        / "chrY_coverage"
        / f"dfreeze_{version}_stats"
        / "chrY_coverage_zscore_vs_assay.csv"
    )
    zscores = pd.read_csv(zscore_path, header=0)

    # NN predictions, indexed by the first column of the file
    predictions = pd.read_csv(
        prediction_data_dir / "full-10fold-validation_prediction.csv",
        header=0,
        index_col=0,
    )

    # Join everything on the file identifier
    merged_df = zscores.merge(
        metadata_df, left_on="filename", right_on="md5sum"
    ).merge(predictions, left_on="filename", right_index=True)
    merged_df["Max pred"] = merged_df[["female", "male", "mixed"]].max(axis=1)
    return merged_df.set_index("md5sum")

In [None]:
def fig2_B(zscore_df: pd.DataFrame, logdir: Path, name: str) -> None:
    """Create figure 2B: chrY z-score box plots per assay, split by sex.

    One subplot is drawn per assay (sorted by name), each holding one box for
    the files whose "True class" is female and one for male, with all points
    displayed. The figure is saved as svg, png and html in `logdir`, then displayed.

    The z-score column is selected through the module-level `metric_label`.

    Args:
        zscore_df: The dataframe with z-score data. Columns read here: ASSAY,
            `metric_label`, "EpiRR", "Max pred" and "True class".
        logdir: The directory path to save the output plots.
        name: The base name for the output plot files.
    """
    assay_sizes = zscore_df[ASSAY].value_counts()
    assays = sorted(assay_sizes.index)

    x_title = "Assay+Sex z-score distributions"
    fig = make_subplots(
        rows=1,
        cols=len(assays),
        shared_yaxes=True,
        x_title=x_title,
        y_title="z-score",
        horizontal_spacing=0.02,
        subplot_titles=[
            f"{assay_label} ({assay_sizes[assay_label]})" for assay_label in assays
        ],
    )

    # (label found in "True class", legend group) for each box of a subplot
    sexes = (("female", "Female"), ("male", "Male"))

    for i, assay_label in enumerate(assays):
        sub_df = zscore_df[zscore_df[ASSAY] == assay_label]

        # Plain arrays + boolean masks: selecting rows of a Series with integer
        # positions through `[]` is deprecated when the index is not integer-based.
        y_values = sub_df[metric_label].to_numpy()
        hovertext = np.array(
            [
                f"{epirr}: z-score={z_score:.3f}, pred={pred:.3f}"
                for epirr, pred, z_score in zip(
                    sub_df["EpiRR"],
                    sub_df["Max pred"],
                    sub_df[metric_label],
                )
            ]
        )

        for sex_label, legend_group in sexes:
            sex_mask = (sub_df["True class"] == sex_label).to_numpy()
            fig.add_trace(
                go.Box(
                    name=assay_label,
                    y=y_values[sex_mask],
                    boxmean=True,
                    boxpoints="all",
                    hovertemplate="%{text}",
                    text=hovertext[sex_mask],
                    marker=dict(
                        size=2,
                        color=sex_colors[sex_label],
                        line=dict(width=0.5, color="black"),
                    ),
                    fillcolor=sex_colors[sex_label],
                    line=dict(width=1, color="black"),
                    showlegend=False,
                    legendgroup=legend_group,
                ),
                row=1,
                col=i + 1,
            )

    # Add dummy scatter plots for the legend (the boxes themselves are hidden from it)
    for sex_label, legend_group in sexes:
        fig.add_trace(
            go.Scatter(
                x=[None],
                y=[None],
                mode="markers",
                name=legend_group,
                marker=dict(color=sex_colors[sex_label], size=20),
                showlegend=True,
                legendgroup=legend_group,
            )
        )

    fig.update_xaxes(showticklabels=False)
    fig.update_yaxes(range=[-1.5, 3])
    title = "z-score(mean chrY coverage per file) distribution per assay"
    fig.update_layout(
        title_text=title,
        width=3000,
        height=1000,
    )

    # Save figure
    fig.write_image(logdir / f"{name}.svg")
    fig.write_image(logdir / f"{name}.png")
    fig.write_html(logdir / f"{name}.html")

    fig.show()

In [None]:
# Output directory and base file name for the chrY z-score box plot figure.
logdir = base_fig_dir / "fig2_EpiAtlas_other" / "fig2--sex_chrY_zscore"
# parents=False: the parent "fig2_EpiAtlas_other" directory must already exist.
logdir.mkdir(parents=False, exist_ok=True)
name = "fig2--sex_chrY_zscore_only_box"

In [None]:
version = "v2"
# Folder holding the sex classifier predictions (runs including the "mixed" class).
pred_data_dir = (
    base_data_dir
    / f"dfreeze_{version}"
    / "hg38_100kb_all_none"
    / f"{SEX}_1l_3000n"
    / "w-mixed"
    / "10fold-oversampling"
)
# A Path object is always truthy, so the existence check must go through .exists();
# `if not pred_data_dir:` could never raise.
if not pred_data_dir.exists():
    raise FileNotFoundError(f"Directory {pred_data_dir} does not exist.")
# zscore_df = prepare_fig_2B_data(version, pred_data_dir)

In [None]:
# fig2_B(zscore_df, logdir, name)

Plot z-score according to sex, merge assays except wgbs (1 violin plot, 1 point = 1 epirr)

In [None]:
def fig2_B_merged_assays(
    zscore_df: pd.DataFrame,
    sex_mislabels: Dict[str, str],
    logdir: Path,
    name: str,
    min_pred: float | None = None,
) -> None:
    """Create figure 2B: per-EpiRR average chrY z-score, one box per sex.

    Assay labels are merged using ASSAY_MERGE_DICT, wgbs files are dropped, and the
    z-score and 'Max pred' columns are averaged per (EpiRR, sex) pair. EpiRRs found in
    `sex_mislabels` with a binary corrected label are overlaid as larger markers,
    coloured with the colour of the opposite sex.

    The figure is written to `logdir` as svg, png and html (the file name is suffixed
    with the number of EpiRRs plotted) and then displayed.

    Note: relies on the notebook-level variables `metric_label` (name of the z-score
    column) and `sex_colors`.

    Args:
        zscore_df (pd.DataFrame): The dataframe with z-score data. Needs the ASSAY, SEX,
            "EpiRR", "Max pred" and `metric_label` columns.
        sex_mislabels (Dict[str, str]): {EpiRR_no-v: corrected_sex_label}
        logdir (Path): The directory path to save the output plots.
        name (str): The base name for the output plot files.
        min_pred (float|None): Minimum prediction value to include in the plot. Used on average EpiRR 'Max pred' values.

    Raises:
        ValueError: If an EpiRR appears with more than one sex label.
    """
    zscore_df = zscore_df.copy(deep=True)
    # Assign the result back: an in-place replace on a column selection is a chained
    # assignment, which leaves the dataframe untouched under pandas copy-on-write.
    zscore_df[ASSAY] = zscore_df[ASSAY].replace(ASSAY_MERGE_DICT)

    # wgbs reverses male/female chrY tendency, so removed here
    zscore_df = zscore_df[zscore_df[ASSAY] != "wgbs"]

    # Average chrY z-score values
    mean_chrY_values_df = (
        zscore_df.groupby(["EpiRR", SEX])
        .agg({metric_label: "mean", "Max pred": "mean"})
        .reset_index()
    )
    if not mean_chrY_values_df["EpiRR"].is_unique:
        raise ValueError("EpiRR is not unique.")

    # Filter out low prediction values
    if min_pred is not None:
        mean_chrY_values_df = mean_chrY_values_df[
            mean_chrY_values_df["Max pred"] > min_pred
        ]

    # From here on, row positions are used to select values.
    mean_chrY_values_df = mean_chrY_values_df.reset_index(drop=True)
    chrY_values = mean_chrY_values_df[metric_label]
    sex_labels = mean_chrY_values_df[SEX]
    female_idx = np.argwhere((sex_labels == "female").values).flatten()  # type: ignore
    male_idx = np.argwhere((sex_labels == "male").values).flatten()  # type: ignore

    # Mislabels: only keep corrections towards a binary label
    binary_mislabels = set(
        epirr
        for epirr, label in sex_mislabels.items()
        if label in ["male", "female"]
    )
    # Strip the version suffix; the separator dot is escaped so that an identifier
    # without a version yields NaN instead of a truncated identifier.
    epirr_no_v = mean_chrY_values_df["EpiRR"].str.extract(pat=r"(\w+\d+)\.\d+")[0]
    mislabels_idx = np.argwhere(
        epirr_no_v.isin(binary_mislabels).values  # type: ignore
    ).flatten()

    # A mislabeled point gets the colour of the sex opposite to its current label.
    mislabel_color_dict = {"female": sex_colors["male"], "male": sex_colors["female"]}
    mislabel_colors = [
        mislabel_color_dict[sex] for sex in sex_labels.iloc[mislabels_idx]
    ]

    # Hovertext
    hovertext = [
        f"{epirr}: <z-score>={z_score:.3f}"
        for epirr, z_score in zip(
            mean_chrY_values_df["EpiRR"],
            mean_chrY_values_df[metric_label],
        )
    ]
    hovertext = np.array(hovertext)

    # Create figure
    fig = go.Figure()
    fig.add_trace(
        go.Box(
            name="Female",
            y=chrY_values.iloc[female_idx],
            boxmean=True,
            boxpoints="all",
            pointpos=0,
            hovertemplate="%{text}",
            text=hovertext[female_idx],
            marker=dict(size=1, color="black"),
            line=dict(width=1, color="black"),
            fillcolor=sex_colors["female"],
        ),
    )

    fig.add_trace(
        go.Box(
            name="Male",
            y=chrY_values.iloc[male_idx],
            boxmean=True,
            boxpoints="all",
            pointpos=0,
            hovertemplate="%{text}",
            text=hovertext[male_idx],
            marker=dict(size=1, color="black"),
            line=dict(width=1, color="black"),
            fillcolor=sex_colors["male"],
        ),
    )

    fig.add_trace(
        go.Scatter(
            name="Mislabel",
            x=np.zeros(len(mislabels_idx)),
            y=chrY_values.iloc[mislabels_idx],
            mode="markers",
            marker=dict(size=4, color=mislabel_colors, line=dict(width=1, color="black")),
            showlegend=False,
            hovertemplate="%{text}",
            text=hovertext[mislabels_idx],
        ),
    )

    fig.update_yaxes(range=[-1.5, 3])
    title = "z-score(mean chrY coverage per file) distribution - z-scores averaged over assays"
    if min_pred is not None:
        title += f"<br>avg_maxPred>{min_pred}"

    fig.update_layout(
        title=dict(text=title, x=0.5),
        xaxis_title=SEX,
        yaxis_title="Average z-score",
        width=750,
        height=750,
    )

    # Save figure
    this_name = f"{name}_n{mean_chrY_values_df.shape[0]}"
    fig.write_image(logdir / f"{this_name}.svg")
    fig.write_image(logdir / f"{this_name}.png")
    fig.write_html(logdir / f"{this_name}.html")

    fig.show()

In [None]:
# `create_mislabel_corrector` is defined outside this cell; only its second
# return value is kept.
_, epirr_mislabels = create_mislabel_corrector()
# Entry for the SEX category. Intended as the `sex_mislabels` argument of
# `fig2_B_merged_assays`, which documents it as {EpiRR_no-v: corrected_sex_label}.
sex_mislabels = epirr_mislabels[SEX]

In [None]:
# Threshold on the per-EpiRR average 'Max pred'; None means no filtering.
min_pred = None
# The threshold, when set, is made visible in the output file name.
name = (
    "fig2--sex_chrY_zscore_merged_assays"
    if min_pred is None
    else f"fig2--sex_chrY_zscore_merged_assays_avg_maxPred>{min_pred}"
)

# Output folder shared by the chrY z-score figures.
logdir = base_fig_dir.joinpath("fig2_EpiAtlas_other", "fig2--sex_chrY_zscore")
# fig2_B_merged_assays(zscore_df, sex_mislabels, logdir, name, min_pred=min_pred)

### Sex: prediction score

In [None]:
def pred_score_violin(
    results_df: pd.DataFrame, logdir: Path, name: str, min_y: float | None = None
) -> None:
    """Plot, for each expected class, the distribution of per-EpiRR prediction scores.

    For every class, the files whose "True class" is that class are grouped by EpiRR.
    Each EpiRR is assigned its most frequent predicted class (majority vote) together
    with the mean "Max pred" of the files that received that prediction. A violin shows
    the distribution of these means. Small black markers are EpiRRs whose majority
    class equals the expected class; larger red markers are those where it differs.
    Mismatching EpiRRs with a mean score above 0.9 are also displayed as a table.

    Note: relies on the notebook-level `sex_colors` mapping for the violin colours.

    Args:
        results_df (pd.DataFrame): The DataFrame containing the neural network results, including metadata.
            Needs the "True class", "Predicted class", "EpiRR" and "Max pred" columns.
        logdir (Path): The directory where the figure would be saved. Currently unused,
            since the saving code is commented out.
        name (str): The name of the figure. Currently unused, for the same reason.
        min_y (float|None): Lower bound of the y axis. Defaults to the smallest
            "Max pred" value of `results_df`.
    Returns:
        None: Displays the plotly figure.
    """
    fig = go.Figure()

    # Class ordering
    class_labels = (
        results_df["True class"].unique().tolist()
        + results_df["Predicted class"].unique().tolist()
    )
    class_labels_sorted = sorted(set(class_labels))
    class_index = {label: i for i, label in enumerate(class_labels_sorted)}

    for label in class_labels_sorted:
        df = results_df[results_df["True class"] == label]
        x_pos = class_index[label]

        # Majority vote, mean prediction score
        groupby_epirr = df.groupby(["EpiRR", "Predicted class"])["Max pred"].aggregate(
            ["size", "mean"]
        )

        # Largest group first within each EpiRR, so keeping the first row keeps the majority class.
        groupby_epirr = groupby_epirr.reset_index().sort_values(
            ["EpiRR", "size"], ascending=[True, False]
        )
        groupby_epirr = groupby_epirr.drop_duplicates(subset="EpiRR", keep="first")
        assert groupby_epirr["EpiRR"].is_unique

        mean_pred = groupby_epirr["mean"]

        fig.add_trace(
            go.Violin(
                x=[x_pos] * len(mean_pred),
                y=mean_pred,
                name=label,
                spanmode="hard",
                box_visible=True,
                meanline_visible=True,
                points=False,
                fillcolor=sex_colors[label],
                line_color="black",
                showlegend=False,
            )
        )

        # Split EpiRRs once, according to whether the majority class is the expected one.
        is_match = groupby_epirr["Predicted class"] == label
        match_info = groupby_epirr[is_match]
        mismatch_info = groupby_epirr[~is_match]

        match_pred = match_info["mean"].tolist()
        mismatch_pred = mismatch_info["mean"].tolist()

        # Small horizontal jitter so that matching points do not all overlap.
        jitter_match = np.random.uniform(-0.01, 0.01, len(match_pred))

        # Add scatter plots for matches in black
        fig.add_trace(
            go.Scatter(
                x=x_pos + jitter_match,
                y=match_pred,
                mode="markers",
                name=label,
                marker=dict(
                    color="black",
                    size=2,  # Standard size for matches
                ),
                hovertemplate="%{text}",
                text=[
                    f"EpiRR: {epirr}, Pred class: {pred_class}, Mean pred: {mean:.2f}"
                    for epirr, pred_class, mean in zip(
                        match_info["EpiRR"],
                        match_info["Predicted class"],
                        match_info["mean"],
                    )
                ],
                showlegend=False,
                legendgroup="match",
            )
        )

        # Add scatter plots for mismatches in red, with larger size
        strong_mismatch = mismatch_info[mismatch_info["mean"] > 0.9]
        display(strong_mismatch)
        fig.add_trace(
            go.Scatter(
                x=[x_pos] * len(mismatch_pred),
                y=mismatch_pred,
                mode="markers",
                name=label,
                marker=dict(
                    color="red",
                    size=5,  # Larger size for mismatches
                ),
                hovertemplate="%{text}",
                text=[
                    f"EpiRR: {epirr}, Pred class: {pred_class}, Mean pred: {mean:.3f} (n={size})"
                    for epirr, pred_class, mean, size in zip(
                        mismatch_info["EpiRR"],
                        mismatch_info["Predicted class"],
                        mismatch_info["mean"],
                        mismatch_info["size"],
                    )
                ],
                showlegend=False,
                legendgroup="mismatch",
            )
        )

    # Add dummy scatter plots for the legend: black points, then red points
    for legend_name, color, group in (
        ("Match", "black", "match"),
        ("Mismatch", "red", "mismatch"),
    ):
        fig.add_trace(
            go.Scatter(
                x=[None],
                y=[None],
                mode="markers",
                name=legend_name,
                marker=dict(color=color, size=10),
                showlegend=True,
                legendgroup=group,
            )
        )

    tickvals = list(class_index.values())
    ticktext = list(class_index.keys())
    fig.update_xaxes(tickvals=tickvals, ticktext=ticktext)

    if min_y is None:
        # Series.min skips NaN, whereas the builtin min is order-dependent when NaN is present.
        min_y = results_df["Max pred"].min()

    fig.update_yaxes(range=[min_y, 1.001])

    fig.update_layout(
        title_text="Prediction score distribution class",
        yaxis_title="Avg. prediction score (majority class)",
        xaxis_title="Expected class label",
        width=750,
        legend=dict(
            title_text="Legend",
            itemsizing="constant",
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1,
        ),
    )

    # Save figure
    # fig.write_html(logdir / f"{name}.html")
    # fig.write_image(logdir / f"{name}.svg")
    # fig.write_image(logdir / f"{name}.png")

    fig.show()

In [None]:
# sex_df = concatenated_results["harmonized_donor_sex_w-mixed"]
# sex_df = split_results_handler.add_max_pred(sex_df)
# sex_df = metadata_handler.join_metadata(sex_df, metadata_v2)

In [None]:
# Base file name and output folder for the sex prediction score figure.
logdir = base_fig_dir.joinpath("fig2_EpiAtlas_other", "fig2--sex_pred_score")
name = "fig2--sex_pred_score_post_correction"
# pred_score_violin(sex_df, logdir, name, min_y=0.485)

PRIORITY

- REDUCED FEATURES: Graph of performance metrics per feature reduction for assay and cellType (100-0.1%)
--> naive version for now

- Blacklisted and Winsorization

THEN

- CELL TYPE: Graph of performance metrics on cell-types after removing some assays --> waiting for details

- SEX: Distribution of pred scores, black dots for pred same class and red for pred different class

- SEX: Confusion matrix with pred scores>0.9 to identify potential mislabels and for complementation (including Unknowns) --> waiting for details


### Reduced features sets NN metrics

In [None]:
def graph_task_metrics(df: pd.DataFrame, category: str, output_dir: Path) -> None:
    """Write one box plot per validation metric for a task.

    For "val_Accuracy" and "val_F1Score", the metric is plotted against the
    "HDF5 filter" column, coloured by "HDF5 Resolution", with all points shown and the
    mean marked on each box. Each figure is saved in `output_dir` as
    `{category}_{metric}.html` and `{category}_{metric}.png`.

    Args:
        df (pd.DataFrame): Run results holding the metric columns as well as the
            "HDF5 filter" and "HDF5 Resolution" columns.
        category (str): Task name, used in the figure title and file names.
        output_dir (Path): The directory where the figures are saved.
    """
    # Display order of the feature sets on the x axis.
    filter_order = (
        "all",
        "global_tasks_union",
        "random_n4510",
        "global_tasks_intersection",
        "random_n118",
    )
    resolution_order = ("1.0kb", "10.0kb", "100.0kb")

    for metric in ("val_Accuracy", "val_F1Score"):
        box_fig = px.box(
            df,
            x="HDF5 filter",
            y=metric,
            color="HDF5 Resolution",
            points="all",
            title=f"{category}: {metric}",
            category_orders={
                "HDF5 filter": list(filter_order),
                "HDF5 Resolution": list(resolution_order),
            },
            color_discrete_sequence=px.colors.qualitative.Safe,
            width=800,
            height=800,
        )
        box_fig.update_traces(boxmean=True)
        box_fig.write_html(output_dir / f"{category}_{metric}.html")
        box_fig.write_image(output_dir / f"{category}_{metric}.png")

In [None]:
# Root folder of the dfreeze v2 results; every entry in it is treated as one feature set.
gen_data_dir = base_data_dir / "dfreeze_v2"
# {feature set folder name: output of `invert_metrics_dict` for that folder}
all_metrics = {}
# NOTE(review): not filled in this cell; it is created again in the following cell.
input_sizes = defaultdict(set)
for folder in gen_data_dir.iterdir():
    feature_set = folder.name
    # Only the first value returned by `fig2_a_content` is used here.
    split_results_metrics, _ = fig2_a_content(folder)
    inverted_dict = split_results_handler.invert_metrics_dict(split_results_metrics)
    all_metrics[feature_set] = inverted_dict

In [None]:
# Identify input size
import re

# Substring identifying the line of interest in the job stdout files.
input_size_line = "(1): Linear"
# {feature set folder name: set of `in_features` values found in its job stdout files}
input_sizes = defaultdict(set)
# NOTE(review): `folders` is not used in this cell and `parent_dir` is not defined in
# the visible part of the notebook - confirm this line is still needed.
folders = list(parent_dir.iterdir())
for folder in gen_data_dir.iterdir():
    for stdout_file in folder.rglob("output_job*.o"):
        # Read the matching lines, then close the file before parsing them.
        with open(stdout_file, "r", encoding="utf8") as f:
            line = [l.rstrip() for l in f if input_size_line in l]
        if len(line) == 0:
            # print(f"Skipping {stdout_file.name}, no model description found.")
            continue
        if len(line) > 1:
            raise ValueError(
                f"Incorrect file reading, captured more than one line in {stdout_file.name}: {line}"
            )
        in_features_match = re.match(pattern=r".*in_features=(\d+).*", string=line[0])
        if in_features_match is None:
            # Fail with an explicit message rather than an AttributeError on None.
            raise ValueError(
                f"No 'in_features' value found in {stdout_file.name}: {line[0]}"
            )
        input_size = int(in_features_match.group(1))
        input_sizes[folder.name].add(input_size)

In [None]:
# # Determine input size of runs after epigeec_filter was fixed, according to run metadata
# run_metadata = RUN_METADATA.copy(deep=True)
# run_metadata = run_metadata[run_metadata["startTimeMillis"] > 1706943404420]

# run_metadata["feature_set"] = run_metadata["Name"].str.extract(pat=r"(^hg38_1[0]{0,2}kb_.*_none).*$")

# input_sizes_count = run_metadata.groupby(["feature_set", "input_size"]).size()
# accepted_input_sizes = {idx[0]:int(idx[1]) for idx in input_sizes_count.index}

# assert len(input_sizes_count) == len(accepted_input_sizes)
# accepted_input_sizes.update({"hg38_100kb_all_none": 30321})

# for k,v in sorted(accepted_input_sizes.items()):
#     print(k,v)

In [None]:
# Inspect the keys stored for the "hg38_100kb_all_none" feature set.
print(all_metrics["hg38_100kb_all_none"].keys())

In [None]:
# Names selected for the "hg38_100kb_all_none" feature set.
# NOTE(review): presumably a subset of the keys of all_metrics["hg38_100kb_all_none"] - confirm.
hg38_100kb_all_none_models = [
    "assay_epiclass_11c",
    "harmonized_donor_life_stage",
    "harmonized_donor_sex_w-mixed",
    "harmonized_sample_cancer_high",
    "harmonized_sample_ontology_intermediate",
]