In [None]:
"""Workbook to create figures (fig2) destined for the paper.

Please use dfreeze v2 for these. v1 is only for fig1.
"""
# pylint: disable=import-error, redefined-outer-name, use-dict-literal, too-many-lines, unused-import

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from __future__ import annotations

import itertools
import os
from collections import defaultdict
from pathlib import Path
from typing import Dict, List

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from IPython.display import display
from plotly.subplots import make_subplots
from scipy.stats import zscore
from sklearn.metrics import confusion_matrix as sk_cm

from epi_ml.core.confusion_matrix import ConfusionMatrixWriter
from epi_ml.utils.notebooks.paper.paper_utilities import (
    ASSAY,
    ASSAY_MERGE_DICT,
    CELL_TYPE,
    IHECColorMap,
    MetadataHandler,
    SplitResultsHandler,
)

In [None]:
# Root of the paper inputs/outputs, resolved under the current user's home directory.
base_dir = Path.home() / "Projects/epiclass/output/paper"
base_data_dir = base_dir / "data"
base_fig_dir = base_dir / "figures"
paper_dir = base_dir

# Fail early if the figure output root is missing.
if not base_fig_dir.exists():
    raise FileNotFoundError(f"Directory {base_fig_dir} does not exist.")

In [None]:
# NOTE(review): this rebinds the imported class name `IHECColorMap` to an instance,
# so re-running this cell without re-running the import cell calls the instance
# instead of the class. Consider using a distinct variable name.
IHECColorMap = IHECColorMap(base_fig_dir)
assay_colors = IHECColorMap.assay_color_map
cell_type_colors = IHECColorMap.cell_type_color_map

In [None]:
# Module-level handler instance; fig2_a_content() creates its own local instance.
split_results_handler = SplitResultsHandler()

## Fig 2 - EpiClass results on EpiAtlas other metadata

For the following figures, use v1.1 of the sample metadata (called v2.1 internally), i.e. dfreeze 2.

A) Histogram of performance (accuracy and F1 scores) for each category 
B) Violin plot of average z-score on chrY per sex, black dots for pred same class and red for pred different class.  
- Do the split male female violin per assay (only FC, merge 2xwgbs and 2xrna, no rna unique_raw). 
- Use scatter for points on each side, agree same color as violin, disagree other.
- Point labels: uuid, epirr

C) ---  
D) ---  
E) --- 

### Fig 2.A

Check if all training runs were done with oversampling on.

In [None]:
# Results directory whose job stdout files are scanned by check_for_oversampling().
v1_results_dir = base_data_dir / "dfreeze_v1"
if not v1_results_dir.exists():
    raise FileNotFoundError(f"Directory {v1_results_dir} does not exist.")

In [None]:
def check_for_oversampling(base_data_dir: Path):
    """Check that every identified experiment was run with oversampling.

    Experiment keys are collected from the job stdout files
    (`<category>/*/output_job*.o`) found under the module-level
    `v1_results_dir`, then looked up in the run metadata table
    `all_results_cometml_filtered_oversampling-fixed.csv`.

    Args:
        base_data_dir (Path): Directory containing the run metadata CSV.

    Raises:
        ValueError: If none of the collected experiment keys is present in the
            run metadata (nothing could be checked), or if at least one
            matching experiment does not have oversampling enabled.
    """
    # Identify experiments
    exp_key_line = "The current experiment key is"
    exp_keys_dict = defaultdict(list)
    for category in v1_results_dir.iterdir():
        for stdout_file in category.glob("*/output_job*.o"):
            with open(stdout_file, "r", encoding="utf8") as f:
                # The key is whatever follows the marker text on the line.
                exp_keys = [
                    line.split(exp_key_line)[1].strip()
                    for line in f
                    if exp_key_line in line
                ]
            exp_keys_dict[category.name].extend(exp_keys)

    # Get all hparam values
    gen_run_metadata = (
        base_data_dir / "all_results_cometml_filtered_oversampling-fixed.csv"
    )
    run_metadata = pd.read_csv(gen_run_metadata, header=0)

    # Check oversampling values
    all_exp_keys = set(itertools.chain.from_iterable(exp_keys_dict.values()))

    df = run_metadata[run_metadata["experimentKey"].isin(all_exp_keys)]
    if df.empty:
        # `.all()` on an empty selection is True, which would silently pass the check.
        raise ValueError(
            "No identified experiment key was found in the run metadata; "
            "oversampling status could not be checked."
        )
    if not df["hparams/oversampling"].all():
        raise ValueError("Not all experiments have oversampling.")

In [None]:
# check_for_oversampling(base_data_dir)

Histogram of performance (accuracy and F1 scores) for each category (using metadata v1)  

In [None]:
def fig2_a_content(
    exclude_categories: List[str], exclude_names: List[str]
) -> Dict[str, Dict[str, Dict[str, float]]]:
    """Create the content data for figure 2a. (get metrics for each task)

    Walks `base_data_dir / "dfreeze_v2"` and only reads results from
    directories named "10fold-oversampling" (oversampled runs).

    The task name is built from the first-level directory name, with its
    "_1l_3000n" suffix removed, followed by the optional intermediate folder
    name (e.g. "<category>_<folder>").

    Args:
        exclude_categories (List[str]): Task categories to exclude (first level directory names).
        exclude_names (List[str]): Names of folders to exclude (ex: 7c or no-mix).
            A run is skipped when any of these is a substring of its folder name.

    Returns:
        Dict[str, Dict[str, Dict[str, float]]] A metrics dictionary with the following structure:
            {split_name: {task_name: metrics_dict}}

    Raises:
        ValueError: If more than one folder sits between the category
            directory and the "10fold-oversampling" directory.
    """
    category_suffix = "_1l_3000n"
    run_dirname = "10fold-oversampling"

    all_split_results = {}
    split_results_handler = SplitResultsHandler()

    # Get the data
    results_dir = base_data_dir / "dfreeze_v2"
    for parent, _, _ in os.walk(results_dir):
        # Looking for oversampling only results
        parent = Path(parent)
        if parent.name != run_dirname:
            continue

        # Get the category
        relpath = parent.relative_to(results_dir)
        category = relpath.parts[0]
        # Remove the exact suffix; str.rstrip() would strip a *set of characters*
        # and could eat into category names ending in '_', '1', 'l', '3', '0' or 'n'.
        if category.endswith(category_suffix):
            category = category[: -len(category_suffix)]
        if category in exclude_categories:
            continue

        # Get the rest of the name, ignore certain runs
        rest_parts = list(relpath.parts[1:])
        rest_parts.remove(run_dirname)

        if len(rest_parts) > 1:
            raise ValueError(f"Too many parts in the name: {rest_parts}")
        rest_of_name = rest_parts[0] if rest_parts else ""

        if rest_of_name and any(name in rest_of_name for name in exclude_names):
            continue

        full_task_name = category
        if rest_of_name:
            full_task_name += f"_{rest_of_name}"

        # Get the split results
        split_results = split_results_handler.read_split_results(parent)
        all_split_results[full_task_name] = split_results

    split_results_metrics = split_results_handler.compute_split_metrics(
        all_split_results, concat_first_level=True
    )
    return split_results_metrics

In [None]:
# Task categories (first-level directory names) and folder-name fragments to leave out.
exclude_categories = ["groups_second_level_name", "track_type"]
exclude_names = ["chip-seq", "7c"]

# NOTE(review): this rebinds the name `fig2_a_content` from the function to its
# result, so this cell cannot be re-run without first re-running the definition cell.
fig2_a_content = fig2_a_content(exclude_categories, exclude_names)

In [None]:
def fig2_a(
    split_metrics: Dict[str, Dict[str, Dict[str, float]]],
    logdir: Path,
    name: str,
) -> None:
    """Render box plots of metrics per classifier and split, each in its own subplot.

    This function generates a figure with one subplot (column) per metric.
    Each subplot contains one box per classifier, ordered by decreasing mean
    accuracy over the splits. Every point of a box is one split.

    The figure is written to `logdir` as svg, png and html, then displayed.

    Args:
        split_metrics: A nested dictionary with structure {split: {classifier: {metric: score}}}.
            The classifier names are taken from the first split.
        logdir: The directory path to save the output plots.
        name: The base name for the output plot files.
    """
    metrics = ["Accuracy", "F1_macro", "AUC_micro", "AUC_macro"]
    classifier_names = list(next(iter(split_metrics.values())).keys())

    # sort classifiers by accuracy
    mean_acc = {}
    for classifier in classifier_names:
        mean_acc[classifier] = np.mean(
            [split_metrics[split][classifier]["Accuracy"] for split in split_metrics]
        )
    classifier_names = sorted(classifier_names, key=lambda x: mean_acc[x], reverse=True)

    # Create subplots, one column for each metric
    fig = make_subplots(
        rows=1,
        cols=len(metrics),
        subplot_titles=metrics,
        horizontal_spacing=0.03,
    )

    # Cycle through the palette so having more classifiers than colors does not fail.
    palette = px.colors.qualitative.Alphabet
    colors = {
        classifier: palette[i % len(palette)]
        for i, classifier in enumerate(classifier_names)
    }

    for i, metric in enumerate(metrics):
        for classifier_name in classifier_names:
            values = [
                split_metrics[split][classifier_name][metric] for split in split_metrics
            ]

            # This classifier is displayed under a different legend label.
            label_name = classifier_name
            if classifier_name == "random_10fold":
                label_name = "random23c_10fold"

            fig.add_trace(
                go.Box(
                    y=values,
                    name=label_name,
                    fillcolor=colors[classifier_name],
                    line=dict(color="black", width=1),
                    marker=dict(size=2),
                    marker_color=colors[classifier_name],
                    boxmean=True,
                    boxpoints="all",  # or "outliers" to show only outliers
                    pointpos=-1.4,
                    showlegend=i == 0,  # Only show legend in the first subplot
                    width=0.5,
                    hoverinfo="text",
                    hovertext=[
                        f"{split}: {value:.4f}"
                        for split, value in zip(split_metrics, values)
                    ],
                    legendgroup=classifier_name,
                ),
                row=1,
                col=i + 1,
            )

    fig.update_layout(
        title_text="Neural network classification - Metric distribution for 10-fold cross-validation",
        yaxis_title="Value",
        boxmode="group",
        height=1000,
        width=1750,
    )

    # Save figure
    fig.write_image(logdir / f"{name}.svg")
    fig.write_image(logdir / f"{name}.png")
    fig.write_html(logdir / f"{name}.html")

    fig.show()

In [None]:
# Output directory and base file name for figure 2A.
# parents=False: the parent `fig2` directory must already exist.
fig_logdir = base_fig_dir / "fig2" / "fig2_A"
fig_logdir.mkdir(parents=False, exist_ok=True)
fig_name = "fig2_A"

In [None]:
# for split_name, split_metrics in all_split_metrics.items():
#     print(split_name)
#     for task_name, task_metrics in split_metrics.items():
#         print(task_name)
#         print(task_metrics)
#     print()

In [None]:
# Render figure 2A and save it (svg, png, html) in fig_logdir.
fig2_a(fig2_a_content, fig_logdir, fig_name)

### Fig 2.B

Violin plot of average z-score on chrY per sex, black dots for pred same class and red for pred different class.  

- Do the split male female violin per assay (only FC, merge 2xwgbs and 2xrna, no rna unique_raw). 
- Use scatter for points on each side, agree same color as violin, disagree other.
- Point labels: uuid, epirr

Compute chrY coverage z-score VS assay distribution, using dfreeze v1 files.

In [None]:
def compute_chrY_zscores():
    """Compute z-scores for chrY coverage data, per assay distribution.
    Use metadata v1.

    Reads `chrXY_coverage_all.csv` from `base_data_dir / "chrY_coverage"`,
    keeps the files present in metadata v1 (excluding the "raw", "pval" and
    "Unique_raw" track types), merges assay labels with `ASSAY_MERGE_DICT`,
    and writes two files in the `dfreeze_v1_stats` subdirectory:
        - chrY_coverage_stats.csv: mean/std/count of chrY coverage per assay.
        - chrY_coverage_zscore_vs_assay.csv: per-file chrY coverage and its
          z-score relative to the other files of the same assay.

    Note: the "std" column of the stats file comes from pandas (sample std,
    ddof=1 by default) while `scipy.stats.zscore` uses ddof=0 by default.

    Raises:
        FileNotFoundError: If the chrY coverage directory does not exist.
        AssertionError: If any retained file has a chrY coverage of zero.
    """
    # Get chrY coverage data
    chrY_coverage_dir = base_data_dir / "chrY_coverage"
    if not chrY_coverage_dir.exists():
        raise FileNotFoundError(f"Directory {chrY_coverage_dir} does not exist.")
    chrY_coverage_df = pd.read_csv(chrY_coverage_dir / "chrXY_coverage_all.csv", header=0)

    # Filter out md5s not in metadata v1
    metadata_v1 = MetadataHandler(paper_dir).load_metadata("v1")
    v1_md5s = set(metadata_v1.md5s)
    chrY_coverage_df = chrY_coverage_df[chrY_coverage_df["filename"].isin(v1_md5s)]

    # Make sure all values are non-zero
    assert (chrY_coverage_df["chrY"] != 0).all()

    # These tracks are excluded from z-score computation
    metadata_v1.remove_category_subsets("track_type", ["raw", "pval", "Unique_raw"])
    metadata_df = pd.DataFrame.from_records(list(metadata_v1.datasets))
    # Assign the result back: a chained `df[col].replace(..., inplace=True)` acts on
    # a temporary object and is not guaranteed to modify `metadata_df`.
    metadata_df[ASSAY] = metadata_df[ASSAY].replace(ASSAY_MERGE_DICT)

    # Merge with metadata (inner join: files of the excluded track types are dropped)
    chrY_coverage_df = chrY_coverage_df.merge(
        metadata_df[["md5sum", ASSAY]], left_on="filename", right_on="md5sum"
    )

    output_dir = chrY_coverage_dir / "dfreeze_v1_stats"
    output_dir.mkdir(parents=True, exist_ok=True)

    # Save coverage distribution stats per assay
    chrY_dists = chrY_coverage_df.groupby(ASSAY).agg({"chrY": ["mean", "std", "count"]})
    chrY_dists.to_csv(output_dir / "chrY_coverage_stats.csv")

    # Compute z-score per assay group, aligned on the original rows, save results
    metric_name = "chrY_zscore_vs_assay"
    chrY_coverage_df[metric_name] = chrY_coverage_df.groupby(ASSAY)["chrY"].transform(
        zscore
    )

    output_cols = ["filename", "chrY", metric_name, ASSAY]
    chrY_coverage_df[output_cols].to_csv(
        output_dir / "chrY_coverage_zscore_vs_assay.csv",
        index=False,
    )

In [None]:
# Writes the per-assay stats and z-score CSV files; prepare_fig_2B_data() reads the z-score file.
compute_chrY_zscores()

### Plot z-scores according to sex

In [None]:
# Metadata category name; used to build the path of the sex classifier results.
SEX = "harmonized_donor_sex"
# Name of the z-score column, as written by compute_chrY_zscores().
metric_label = "chrY_zscore_vs_assay"

In [None]:
def prepare_fig_2B_data() -> pd.DataFrame:
    """Assemble the per-file table used to draw figure 2b.

    Joins three sources on the file md5sum:
        - metadata v1 (only the "md5sum" and "EpiRR" fields are kept),
        - the chrY z-scores saved in `chrY_coverage/dfreeze_v1_stats`,
        - the 10-fold validation predictions of the sex classifier (dfreeze v2).

    Returns:
        pd.DataFrame: The merged table, indexed by "md5sum", with an extra
            "Max pred" column (row-wise max of the "female" and "male" columns).
    """
    # EpiRR identifiers from metadata v1
    metadata_v1 = MetadataHandler(paper_dir).load_metadata("v1")
    epirr_df = pd.DataFrame.from_records(list(metadata_v1.datasets))[["md5sum", "EpiRR"]]

    # chrY z-scores
    stats_dir = base_data_dir / "chrY_coverage" / "dfreeze_v1_stats"
    coverage_zscores = pd.read_csv(
        stats_dir / "chrY_coverage_zscore_vs_assay.csv", header=0
    )

    # Sex classifier predictions (first column is used as index)
    predictions_path = (
        base_data_dir
        / "dfreeze_v2"
        / f"{SEX}_1l_3000n"
        / "w-mixed"
        / "10fold-oversampling"
        / "full-10fold-validation_prediction.csv"
    )
    predictions = pd.read_csv(predictions_path, header=0, index_col=0)

    # Join everything on the file md5sum
    merged = coverage_zscores.merge(epirr_df, left_on="filename", right_on="md5sum")
    merged = merged.merge(predictions, left_on="filename", right_index=True)
    merged["Max pred"] = merged[["female", "male"]].max(axis=1)
    return merged.set_index("md5sum")

In [None]:
def fig2_B(zscore_df: pd.DataFrame) -> None:
    """Create figure 2B.

    One subplot per assay. Each subplot holds a split violin of the chrY
    z-scores (female on the negative side in red, male on the positive side in
    blue), plus one scatter per kind of misclassification, with marker size
    growing with the prediction score.

    The figure is written as svg, png and html in `base_fig_dir/fig2/fig2_B`
    (the `fig2` directory must already exist), then displayed.

    Args:
        zscore_df: The dataframe with z-score data. Uses the columns `ASSAY`,
            `metric_label`, "EpiRR", "Max pred", "True class" and
            "Predicted class".
    """
    assay_sizes = zscore_df[ASSAY].value_counts()
    assays = sorted(assay_sizes.index)

    x_title = "Assay+Sex z-score distributions - Male/Female classification disagreement separate"
    fig = make_subplots(
        rows=1,
        cols=len(assays),
        shared_yaxes=False,
        x_title=x_title,
        y_title="z-score",
        horizontal_spacing=0.02,
        subplot_titles=[
            f"{assay_label} ({assay_sizes[assay_label]})" for assay_label in assays
        ],
    )

    # (true class, violin side, color)
    violin_specs = [("female", "negative", "red"), ("male", "positive", "blue")]
    # (predicted class, true class, x offset from the violin center, color)
    # NOTE(review): both offsets place the points left of the violin center;
    # confirm this is the intended placement.
    misclassified_specs = [
        ("female", "male", -0.2, "red"),
        ("male", "female", -0.25, "blue"),
    ]

    for i, assay_label in enumerate(assays):
        sub_df = zscore_df[zscore_df[ASSAY] == assay_label]

        # Work on plain arrays: `sub_df` may be label-indexed (e.g. by md5sum), and
        # selecting Series values by integer position through `[]` is deprecated.
        y_values = sub_df[metric_label].to_numpy()
        max_pred = sub_df["Max pred"].to_numpy()
        hovertext = np.array(
            [
                f"{epirr}: z-score={z_score:.3f}, pred={pred:.3f}"
                for epirr, pred, z_score in zip(sub_df["EpiRR"], max_pred, y_values)
            ]
        )

        # Split violin: one half per true class
        for true_class, side, color in violin_specs:
            is_class = (sub_df["True class"] == true_class).to_numpy()
            fig.add_trace(
                go.Violin(
                    name="",
                    x0=i,
                    y=y_values[is_class],
                    box_visible=True,
                    meanline_visible=True,
                    points="all",
                    hovertemplate="%{text}",
                    text=hovertext[is_class],
                    side=side,
                    line_color=color,
                    spanmode="hard",
                    showlegend=False,
                    marker=dict(size=1),
                ),
                row=1,
                col=i + 1,
            )

        # Misclassified files, colored according to the predicted class
        for predicted_class, true_class, x_offset, color in misclassified_specs:
            is_misclassified = (
                (sub_df["Predicted class"] == predicted_class)
                & (sub_df["True class"] == true_class)
            ).to_numpy()
            misclassified_y = y_values[is_misclassified]
            fig.add_trace(
                go.Scatter(
                    name="",
                    x=[i + x_offset] * len(misclassified_y),
                    y=misclassified_y,
                    mode="markers",
                    marker=dict(color=color, size=1 + 5 * max_pred[is_misclassified]),
                    showlegend=False,
                    hovertemplate="%{text}",
                    text=hovertext[is_misclassified],
                ),
                row=1,
                col=i + 1,
            )

    # Add dummy (empty) scatter plots, only used to create legend entries
    for legend_label, color in (("Female", "red"), ("Male", "blue")):
        fig.add_trace(
            go.Scatter(
                x=[None],
                y=[None],
                mode="markers",
                name=legend_label,
                marker=dict(color=color, size=20),
                showlegend=True,
                legendgroup=legend_label,
            )
        )

    fig.update_xaxes(showticklabels=False)
    title = "z-score(mean chrY coverage per file) distribution per assay"
    fig.update_layout(
        title_text=f"{title}",
        width=3000,
        height=1000,
    )

    # Save figure
    logdir = base_fig_dir / "fig2" / "fig2_B"
    logdir.mkdir(parents=False, exist_ok=True)
    name = "fig2_B"

    fig.write_image(logdir / f"{name}.svg")
    fig.write_image(logdir / f"{name}.png")
    fig.write_html(logdir / f"{name}.html")

    fig.show()

In [None]:
# Merge z-scores, EpiRR identifiers and sex predictions into one table.
zscore_df = prepare_fig_2B_data()

In [None]:
# Render and save figure 2B.
fig2_B(zscore_df)