In [None]:
"""Workbook to create figures destined for the paper."""
# pylint: disable=import-error, redefined-outer-name, use-dict-literal, too-many-lines, unused-import

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from __future__ import annotations

import itertools
from collections import defaultdict
from pathlib import Path
from typing import Dict

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from IPython.display import display
from plotly.subplots import make_subplots
from sklearn.metrics import confusion_matrix as sk_cm

from epi_ml.core.confusion_matrix import ConfusionMatrixWriter
from epi_ml.utils.notebooks.paper.paper_utilities import (
    ASSAY,
    CELL_TYPE,
    IHECColorMap,
    MetadataHandler,
    SplitResultsHandler,
    merge_similar_assays,
)

In [None]:
# Root folder for everything related to the paper (located under the user's home directory).
base_dir = Path.home() / "Projects/epiclass/output/paper"
base_data_dir = base_dir / "data"  # results and run metadata read by the cells below
base_fig_dir = base_dir / "figures"  # figure outputs are written below this folder
paper_dir = base_dir  # alias of base_dir, handed to MetadataHandler

# Fail fast if the figures folder is missing.
if not base_fig_dir.exists():
    raise FileNotFoundError(f"Directory {base_fig_dir} does not exist.")

In [None]:
# The imported class name is rebound to an instance below (other cells may rely on that
# binding, so the name is kept). Resolve the class from whichever of the two the name
# currently holds, so that re-running this cell does not try to call the instance.
_ihec_color_map_cls = (
    IHECColorMap if isinstance(IHECColorMap, type) else type(IHECColorMap)
)
IHECColorMap = _ihec_color_map_cls(base_fig_dir)
assay_colors = IHECColorMap.assay_color_map
cell_type_colors = IHECColorMap.cell_type_color_map

In [None]:
# Notebook-level SplitResultsHandler instance.
split_results_handler = SplitResultsHandler()

## Fig 2 - EpiClass results on EpiAtlas other metadata

For all sub-figures of Fig 2 and Fig 3, use v1.1 of the sample metadata (called v2.1 internally).

A) Histogram of performance (accuracy and F1 scores) for each category (using metadata v1)  
B) Violin plot of average z-score on chrY per sex, black dots for pred same class and red for pred different class.  
- Do the split male female violin per assay (only FC, merge 2xwgbs and 2xrna, no rna unique_raw). 
- Use a scatter for the points on each side: same color as the violin when the prediction agrees, the other color when it disagrees.
- Point labels: uuid, epirr  

C) ---  
D) ---  
E) SHAP cell-types GO  


### Fig 2.A

Check if all training runs were done with oversampling on.

In [None]:
# Folder holding the "dfreeze_v1" results; fail fast if it is missing.
v1_results_dir = base_data_dir / "dfreeze_v1"
if not v1_results_dir.exists():
    raise FileNotFoundError(f"Directory {v1_results_dir} does not exist.")

In [None]:
def check_for_oversampling(base_data_dir: Path) -> None:
    """Verify that the experiments found in the v1 results were run with oversampling.

    Experiment keys are scraped from the job stdout files
    (``<base_data_dir>/dfreeze_v1/<category>/*/output_job*.o``) and looked up in
    ``<base_data_dir>/all_results_cometml_filtered_oversampling-fixed.csv``.
    Keys that are absent from the CSV are not checked.

    Args:
        base_data_dir: Directory containing both the ``dfreeze_v1`` results folder
            and the run metadata CSV.

    Raises:
        ValueError: If no scraped experiment key is present in the run metadata
            (the check would otherwise pass vacuously), or if at least one matched
            experiment does not have oversampling on.
    """
    exp_key_line = "The current experiment key is"

    # Derive the results folder from the argument instead of a notebook global,
    # so the function checks the directory it was actually given.
    results_dir = base_data_dir / "dfreeze_v1"

    # Identify experiments
    all_exp_keys = set()
    for category in results_dir.iterdir():
        if not category.is_dir():
            continue  # stray files next to the category folders hold no runs
        for stdout_file in category.glob("*/output_job*.o"):
            with open(stdout_file, "r", encoding="utf8") as f:
                all_exp_keys.update(
                    line.split(exp_key_line)[1].strip()
                    for line in f
                    if exp_key_line in line
                )

    # Get all hparam values
    gen_run_metadata = (
        base_data_dir / "all_results_cometml_filtered_oversampling-fixed.csv"
    )
    run_metadata = pd.read_csv(gen_run_metadata, header=0)

    # Check oversampling values
    df = run_metadata[run_metadata["experimentKey"].isin(all_exp_keys)]
    if df.empty:
        # `.all()` on an empty selection is True; do not report success on no data.
        raise ValueError(
            f"None of the {len(all_exp_keys)} experiment keys found in {results_dir} "
            f"are present in {gen_run_metadata}."
        )
    if not df["hparams/oversampling"].all():
        raise ValueError("Not all experiments have oversampling.")

In [None]:
# Abort early (ValueError) unless all experiments were trained with oversampling.
check_for_oversampling(base_data_dir)

Histogram of performance (accuracy and F1 scores) for each category (using metadata v1)  

In [None]:
def fig2_a_content() -> Dict[str, Dict[str, Dict[str, float]]]:
    """Gather the metrics of every task for each split (the data behind figure 2a).

    Before the metrics are computed, the index of every task's concatenated results
    is checked against the md5s of metadata v1.

    Returns:
        Dict[str, Dict[str, Dict[str, float]]] Nested metrics dictionary:
            {split_name: {task_name: metrics_dict}}

    Raises:
        ValueError: If the results of a task contain md5s absent from metadata v1.
    """
    known_md5s = set(MetadataHandler(paper_dir).load_metadata("v1").md5s)

    # Collect the raw results
    handler = SplitResultsHandler()
    results_per_split = handler.gather_split_results_across_categories(
        base_data_dir / "dfreeze_v1"
    )

    # Sanity check: no result may refer to an md5 unknown to metadata v1
    results_per_task = handler.concatenate_split_results(
        results_per_split, concat_first_level=True
    )
    for task_name, task_results in results_per_task.items():
        problematic_md5s = set(task_results.index) - known_md5s
        if problematic_md5s:
            raise ValueError(
                f"Some md5s not in metadata v1 for {task_name}: {problematic_md5s}"
            )

    return handler.compute_split_metrics(results_per_split, concat_first_level=True)

In [None]:
def fig2_a(
    split_metrics: Dict[str, Dict[str, Dict[str, float]]],
    logdir: Path,
    name: str,
) -> None:
    """Render box plots of metrics per classifier and split, each in its own subplot.

    This function generates a figure with one subplot (column) per metric. Each
    subplot contains one box per classifier, ordered by decreasing mean accuracy
    across splits. The figure is written to ``logdir`` as ``{name}.svg``,
    ``{name}.png`` and ``{name}.html``, then displayed.

    Args:
        split_metrics: A nested dictionary with structure {split: {classifier: {metric: score}}}.
            The classifier list is taken from the first split.
        logdir: The directory path to save the output plots.
        name: The base name for the output plot files.

    Raises:
        ValueError: If ``split_metrics`` is empty.
    """
    if not split_metrics:
        raise ValueError("split_metrics is empty: there is nothing to plot.")

    metrics = ["Accuracy", "F1_macro", "AUC_micro", "AUC_macro"]
    classifier_names = list(next(iter(split_metrics.values())).keys())

    # Sort classifiers by mean accuracy across splits, best first
    mean_acc = {
        classifier: np.mean(
            [split_metrics[split][classifier]["Accuracy"] for split in split_metrics]
        )
        for classifier in classifier_names
    }
    classifier_names = sorted(classifier_names, key=lambda x: mean_acc[x], reverse=True)

    # Create subplots, one column for each metric
    fig = make_subplots(
        rows=1,
        cols=len(metrics),
        subplot_titles=metrics,
        horizontal_spacing=0.03,
    )

    # Cycle through the palette so that a classifier count larger than the palette
    # size reuses colors instead of raising an IndexError.
    palette = px.colors.qualitative.Alphabet
    colors = {
        classifier: palette[i % len(palette)]
        for i, classifier in enumerate(classifier_names)
    }

    for i, metric in enumerate(metrics):
        for classifier_name in classifier_names:
            values = [
                split_metrics[split][classifier_name][metric] for split in split_metrics
            ]

            # Displayed label differs from the internal name for this one classifier
            label_name = classifier_name
            if classifier_name == "random_10fold":
                label_name = "random23c_10fold"

            fig.add_trace(
                go.Box(
                    y=values,
                    name=label_name,
                    fillcolor=colors[classifier_name],
                    line=dict(color="black", width=1),
                    marker=dict(size=2),
                    marker_color=colors[classifier_name],
                    boxmean=True,
                    boxpoints="all",  # or "outliers" to show only outliers
                    pointpos=-1.4,
                    showlegend=i == 0,  # Only show legend in the first subplot
                    width=0.5,
                    hoverinfo="text",
                    hovertext=[
                        f"{split}: {value:.4f}"
                        for split, value in zip(split_metrics, values)
                    ],
                    legendgroup=classifier_name,
                ),
                row=1,
                col=i + 1,
            )

    fig.update_layout(
        title_text="Neural network classification - Metric distribution for 10-fold cross-validation",
        yaxis_title="Value",
        boxmode="group",
        height=1000,
        width=1750,
    )

    # Save figure
    fig.write_image(logdir / f"{name}.svg")
    fig.write_image(logdir / f"{name}.png")
    fig.write_html(logdir / f"{name}.html")

    fig.show()

In [None]:
# Output folder and base file name for Fig 2.A.
fig_logdir = base_fig_dir / "fig2" / "fig2_A"
# parents=False: the parent "fig2" folder must already exist, otherwise mkdir raises.
fig_logdir.mkdir(parents=False, exist_ok=True)
fig_name = "fig2_A"

In [None]:
# Metrics as returned by fig2_a_content: {split_name: {task_name: metrics_dict}}.
all_split_metrics = fig2_a_content()

In [None]:
# for split_name, split_metrics in all_split_metrics.items():
#     print(split_name)
#     for task_name, task_metrics in split_metrics.items():
#         print(task_name)
#         print(task_metrics)
#     print()

In [None]:
# Draw the Fig 2.A box plots; files are written to fig_logdir as svg, png and html.
fig2_a(all_split_metrics, fig_logdir, fig_name)