In [None]:
"""Workbook to create figures destined for the paper."""
# pylint: disable=import-error, redefined-outer-name, use-dict-literal, too-many-lines, too-many-branches

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from __future__ import annotations

import json
from collections import defaultdict
from pathlib import Path
from typing import Dict, List

import numpy as np
import pandas as pd
import plotly.graph_objects as go
from IPython.display import display  # pylint: disable=unused-import
from plotly.subplots import make_subplots

# from epi_ml.core.metadata import UUIDMetadata
from epi_ml.utils.notebooks.paper.paper_utilities import (
    ASSAY,
    ASSAY_MERGE_DICT,
    ASSAY_ORDER,
    CELL_TYPE,
    IHECColorMap,
    MetadataHandler,
    SplitResultsHandler,
)

In [None]:
# Root folder of the paper outputs, located under the user's home directory.
base_dir = Path.home().joinpath("Projects", "epiclass", "output", "paper")
base_data_dir = base_dir.joinpath("data")
base_fig_dir = base_dir.joinpath("figures")
# Alias of the root folder, passed to the handlers/functions below.
paper_dir = base_dir

In [None]:
# NOTE(review): this rebinds the imported class name `IHECColorMap` to an instance.
# Re-running this cell without re-running the import cell will presumably fail
# (the instance is unlikely to be callable) — consider a distinct instance name.
IHECColorMap = IHECColorMap(base_fig_dir)
# Color lookups used by the figures below, indexed by assay label / cell type label.
assay_colors = IHECColorMap.assay_color_map
cell_type_colors = IHECColorMap.cell_type_color_map

## Flagship assay/ct accuracy

Cell type classifier:

  For each assay, draw a violin plot of the per-cell-type accuracy (16 points, one per cell type).

In [None]:
# Folder receiving the flagship figures.
fig_dir = base_fig_dir.joinpath("flagship")

In [None]:
# Project helpers: one loads metadata and joins it to result frames,
# the other reads split results and computes their metrics.
metadata_handler = MetadataHandler(paper_dir)
split_results_handler = SplitResultsHandler()

In [None]:
# Folder holding the `hg38_10kb_all_none` training results.
path_results_cell_type = base_data_dir.joinpath(
    "training_results", "dfreeze_v2", "hg38_10kb_all_none"
)
# Stop right away with an explicit message if the folder is missing.
if not path_results_cell_type.exists():
    raise FileNotFoundError(f"{path_results_cell_type} does not exist.")

In [None]:
# Gather the split results of every category, keep the ones of the cell type
# classifier, then stack all splits into one combined dataframe.
split_results_per_category = (
    split_results_handler.gather_split_results_across_categories(
        path_results_cell_type
    )
)
ct_split_dfs = split_results_per_category[
    "harmonized_sample_ontology_intermediate_1l_3000n_10fold-oversampling"
]
ct_full_df = pd.concat(list(ct_split_dfs.values()), axis="index")

In [None]:
# Load metadata and join with split results
metadata_2 = metadata_handler.load_metadata("v2")
ct_full_df = metadata_handler.join_metadata(ct_full_df, metadata_2)
# Merge assay labels by assigning the replaced column back. Calling
# `.replace(..., inplace=True)` on a column selection is a chained inplace
# operation: under pandas Copy-on-Write it does not modify `ct_full_df`.
ct_full_df[ASSAY] = ct_full_df[ASSAY].replace(ASSAY_MERGE_DICT)

### violin version

In [None]:
def _ct_accuracy_for_assay(
    assay_df: pd.DataFrame, ct_labels: List[str]
) -> tuple[Dict[str, float], Dict[str, int]]:
    """Compute per-cell-type accuracy and sample counts for the rows of one assay.

    Args:
        assay_df: Prediction rows with "True class" and "Predicted class" columns.
        ct_labels: Cell type labels to report, in output order.

    Returns:
        Two dicts keyed by cell type label (in `ct_labels` order): the fraction of
        samples of that true class predicted as that class, and the number of
        samples of that true class. A class without any correct prediction gets 0.0
        and a class absent from `assay_df` gets NaN (count 0), instead of a KeyError.
    """
    class_sizes = assay_df.groupby("True class").size()
    confusion = assay_df.groupby(["True class", "Predicted class"]).size()

    accuracies: Dict[str, float] = {}
    sizes: Dict[str, int] = {}
    for ct_label in ct_labels:
        n_samples = int(class_sizes.get(ct_label, 0))
        # The (true, predicted) pair is missing when nothing was predicted correctly
        n_correct = int(confusion.get((ct_label, ct_label), 0))
        sizes[ct_label] = n_samples
        accuracies[ct_label] = n_correct / n_samples if n_samples else float("nan")
    return accuracies, sizes


def fig_flagship_ct(
    cell_type_df: pd.DataFrame, logdir: Path, name: str, seed: int = 42
) -> None:
    """
    Create a figure showing the cell type classifier performance on different assays.

    One subplot is drawn per assay of ASSAY_ORDER: a violin of the per-cell-type
    accuracies, plus one jittered marker per cell type.

    Args:
        cell_type_df (pd.DataFrame): DataFrame containing the cell type prediction results.
            Needs the "True class", "Predicted class" and ASSAY columns.
        logdir (Path): Existing directory in which the figure is saved.
        name (str): File name (without extension) for the saved figure.
        seed (int): Seed of the marker jitter, so that saved figures are reproducible.

    Raises:
        AssertionError: If the number of distinct true classes is not 16.
        FileNotFoundError: If logdir does not exist.

    Returns:
        None: Saves the figure as svg, png and html, then displays it.
    """
    # Assuming all classifiers have the same assays for simplicity
    assay_labels = ASSAY_ORDER

    ct_labels = sorted(cell_type_df["True class"].unique())
    if len(ct_labels) != 16:
        raise AssertionError(f"Expected 16 cell type labels, got {len(ct_labels)}")
    ct_colors = [cell_type_colors[ct_label] for ct_label in ct_labels]

    # Fail before building the figure rather than at export time
    if not logdir.exists():
        raise FileNotFoundError(f"Could not find {logdir}")

    scatter_offset = 0.1  # Scatter plot jittering
    scatter_marker_size = 10  # Shared by the data markers and the legend markers
    rng = np.random.default_rng(seed)

    # Square grid large enough to hold one subplot per assay
    grid_size = int(np.ceil(np.sqrt(len(assay_labels))))

    # Compute accuracies and class sizes of every assay before plotting
    assay_acc_dict = {}
    subclass_sizes = {}
    for assay_label in assay_labels:
        assay_df = cell_type_df[cell_type_df[ASSAY] == assay_label]
        accuracies, sizes = _ct_accuracy_for_assay(assay_df, ct_labels)
        assay_acc_dict[assay_label] = accuracies
        subclass_sizes[assay_label] = sizes

    # Create subplots with a square grid
    fig = make_subplots(
        rows=grid_size,
        cols=grid_size,
        subplot_titles=assay_labels,
        shared_yaxes="all",  # type: ignore
        horizontal_spacing=0,
        vertical_spacing=0.02,
        y_title="Cell type subclass accuracy",
    )

    for idx, assay_label in enumerate(assay_labels):
        row, col = divmod(idx, grid_size)
        ct_accuracies = assay_acc_dict[assay_label]
        acc_values = list(ct_accuracies.values())

        # Add violin plot with integer x positions
        fig.add_trace(
            go.Violin(
                x=[idx] * len(acc_values),
                y=acc_values,
                name=assay_label,
                spanmode="hard",
                box_visible=True,
                meanline_visible=True,
                points=False,
                fillcolor=assay_colors[assay_label],
                line_color="white",
                line=dict(width=0.8),
                showlegend=False,
            ),
            row=row + 1,  # Plotly rows are 1-indexed
            col=col + 1,
        )

        # Markers are jittered and shifted to the left of the violin
        jitter = rng.uniform(-scatter_offset, scatter_offset, size=len(acc_values))
        jittered_x_positions = jitter + idx - 0.4

        fig.add_trace(
            go.Scatter(
                x=jittered_x_positions,
                y=acc_values,
                mode="markers",
                marker=dict(size=scatter_marker_size, color=ct_colors),
                hovertemplate="%{text}",
                text=[
                    f"{ct_label} ({acc:.3f}, n={subclass_sizes[assay_label][ct_label]})"
                    for ct_label, acc in ct_accuracies.items()
                ],
                showlegend=False,
            ),
            row=row + 1,  # Plotly rows are 1-indexed
            col=col + 1,
        )

    fig.update_xaxes(showticklabels=False)

    # Add a dummy scatter plot for legend
    for ct_label in ct_labels:
        fig.add_trace(
            go.Scatter(
                x=[None],
                y=[None],
                mode="markers",
                name=ct_label,
                marker=dict(color=cell_type_colors[ct_label], size=scatter_marker_size),
                showlegend=True,
            )
        )

    fig.update_yaxes(range=[0.34, 1.01])

    title_text = f"{CELL_TYPE.replace('_', ' ').title()} classifier: Accuracy per assay"
    fig.update_layout(
        title=title_text,
        height=1500,
        width=1500,
    )

    # Save figure
    fig.write_image(logdir / f"{name}.svg")
    fig.write_image(logdir / f"{name}.png")
    fig.write_html(logdir / f"{name}.html")

    fig.show()

In [None]:
# Output folder for the cell type accuracy-per-assay figures.
fig_dir = base_fig_dir.joinpath("flagship", "ct_assay_accuracy")
# fig_flagship_ct(ct_full_df, logdir=fig_dir, name="ct_assay_accuracy_violin_10kb")

### boxplot version

In [None]:
def fig_flagship_ct_boxplot(cell_type_df: pd.DataFrame, logdir: Path, name: str) -> None:
    """
    Generates a boxplot for cell type classification accuracy across different assays.

    This function creates a single figure with one boxplot per assay of ASSAY_ORDER.
    Each boxplot represents the distribution, across cell types, of the accuracy
    obtained on the samples of that assay.

    A cell type without any correct prediction in an assay counts as 0.0, and a cell
    type absent from an assay as NaN, instead of raising a KeyError.

    Args:
        cell_type_df: DataFrame containing the cell type prediction results.
            Needs the "True class", "Predicted class" and ASSAY columns.
        logdir: Existing directory in which the figure is saved.
        name: File name (without extension) for the saved figure.

    Raises:
        AssertionError: If the number of distinct true classes is not 16.
        FileNotFoundError: If logdir does not exist.
    """
    ct_labels = sorted(cell_type_df["True class"].unique())
    if len(ct_labels) != 16:
        raise AssertionError(f"Expected 16 cell type labels, got {len(ct_labels)}")

    # Fail before building the figure rather than at export time
    if not logdir.exists():
        raise FileNotFoundError(f"Could not find {logdir}")

    # Create the boxplot, one box per assay
    fig = go.Figure()
    for assay_label in ASSAY_ORDER:
        assay_df = cell_type_df[cell_type_df[ASSAY] == assay_label]

        # cell type subclass accuracy
        class_sizes = assay_df.groupby("True class").size()
        confusion = assay_df.groupby(["True class", "Predicted class"]).size()

        assay_accuracies: List[float] = []
        for ct_label in ct_labels:
            n_samples = int(class_sizes.get(ct_label, 0))
            # The (true, predicted) pair is missing when nothing was predicted correctly
            n_correct = int(confusion.get((ct_label, ct_label), 0))
            assay_accuracies.append(
                n_correct / n_samples if n_samples else float("nan")
            )

        fig.add_trace(
            go.Box(
                y=assay_accuracies,
                name=assay_label,
                boxpoints="outliers",
                boxmean=True,
                fillcolor=assay_colors[assay_label],
                line_color="black",
                showlegend=False,
                marker=dict(size=2),
            )
        )

    fig.update_yaxes(range=[0.34, 1.01])

    title_text = f"{CELL_TYPE.replace('_', ' ').title()} classifier: Accuracy per assay"
    fig.update_layout(
        title=title_text,
        yaxis_title="Accuracy",
        xaxis_title="Assay",
        height=600,
        width=1000,
    )

    # Save and display the figure
    fig.write_image(logdir / f"{name}.svg")
    fig.write_image(logdir / f"{name}.png")
    fig.write_html(logdir / f"{name}.html")

    fig.show()

In [None]:
# fig_flagship_ct_boxplot(ct_full_df, logdir=fig_dir, name="ct_assay_accuracy_boxplot_10kb")

## Input classification - Hdf5 values at top SHAP features

Compare the input signal across cell types (e.g. boxplots of the 40 "input & t cell" regions vs. the 9 "input & lymphocyte of b lineage" regions in these 2 cell types (so 4 boxplots), or even add an unrelated cell type such as muscle as a negative control (6 boxplots)).

all input, check 40 and 9 regions in
- t cell
- b cell
- muscle

In [None]:
def flagship_supp_shap_input(paper_dir: Path) -> None:
    """
    Plot hdf5 values of different (input, cell type) pairs for top SHAP values
    of T cell and B cell. (only ct which produced common top shap values features)

    For each group of top SHAP bins listed in the features file, one violin per cell
    type shows the mean value across those bins, with one point per input file (md5).

    Args:
        paper_dir (Path): The directory containing the paper data (given to MetadataHandler).

    Returns:
        None: Saves the figure as svg, png and html, then displays it.
    """
    # NOTE(review): data and figure locations come from the module-level
    # `base_data_dir` / `base_fig_dir`, not from `paper_dir`.

    # Load relevant metadata: input files of the selected cell types
    ct_labels = ["T cell", "lymphocyte of B lineage", "neutrophil", "muscle organ"]
    input_metadata = MetadataHandler(paper_dir).load_metadata("v2")
    input_metadata.select_category_subsets(ASSAY, ["input"])
    input_metadata.select_category_subsets(CELL_TYPE, ct_labels)
    md5_per_ct = input_metadata.md5_per_class(CELL_TYPE)

    # Load feature bins
    hdf5_val_dir = (
        base_data_dir
        / "harmonized_sample_ontology_intermediate/all_splits/harmonized_sample_ontology_intermediate_1l_3000n/10fold-dfreeze-v2/global_shap_analysis/top303/input"
    )
    feature_filepath = hdf5_val_dir / "features_n8.json"
    with open(feature_filepath, "r", encoding="utf8") as f:
        features: Dict[str, List[int]] = json.load(f)

    # Bins are converted to strings to select the matching dataframe columns
    bin_columns_per_group = {
        group_name: [str(b) for b in bins] for group_name, bins in features.items()
    }

    # Load feature values
    hdf5_val_path = hdf5_val_dir / "hdf5_values_100kb_all_none_input_4ct_features_n8.csv"
    df = pd.read_csv(hdf5_val_path, index_col=0, header=0)

    df_ct_dict = {ct: df.loc[md5_per_ct[ct]] for ct in ct_labels}

    # One group of violins per set of top SHAP bins, one violin per cell type.
    # Each violin takes the per-file mean over the columns of that set of bins.
    fig = go.Figure()
    for ct_label, df_ct in df_ct_dict.items():
        for i, (cell_type, top_shap_bins) in enumerate(bin_columns_per_group.items()):
            mean_bin_values_per_md5 = df_ct[top_shap_bins].mean(axis=1)

            # Two decimals in the hover text (the previous "02f" spec printed six)
            hovertext = [
                f"{md5}. {val:.2f}"
                for md5, val in zip(df_ct.index, mean_bin_values_per_md5)
            ]

            fig.add_trace(
                go.Violin(
                    x=[f"Top SHAP '{cell_type}' bins"] * len(mean_bin_values_per_md5),
                    y=mean_bin_values_per_md5,
                    name=f"{ct_label} (n={len(hovertext)})",
                    points="all",
                    box_visible=True,
                    meanline_visible=True,
                    spanmode="hard",
                    fillcolor=cell_type_colors[ct_label],
                    line_color="black",
                    showlegend=i == 0,  # One legend entry per cell type
                    marker=dict(size=2),
                    hovertemplate="%{text}",
                    text=hovertext,
                    legendgroup=ct_label,
                )
            )

    fig.update_layout(
        title="Input files bin z-score values for top SHAP features",
        yaxis_title="Mean z-score across bins",
        xaxis_title="Top SHAP features groups",
        boxmode="group",
        violinmode="group",
        height=1000,
        width=1000,
    )

    logdir = base_fig_dir / "flagship"
    name = "input_feature_values_top_shap_bins_per_md5"
    fig.write_image(logdir / f"{name}.svg")
    fig.write_image(logdir / f"{name}.png")
    fig.write_html(logdir / f"{name}.html")

    fig.show()

In [None]:
# flagship_supp_shap_input(paper_dir)

## Cell type classification - Training per assay results

In [None]:
base_results_dir = base_data_dir / "training_results" / "dfreeze_v2"
# One sub-folder per assay (`*_only`) is expected in this directory
results_dir_1 = (
    base_results_dir
    / "hg38_100kb_all_none"
    / "harmonized_sample_ontology_intermediate_1l_3000n"
    / "10fold-oversampling-unique_assay"
)
# Extra wgbs-only training
results_dir_2 = (
    base_results_dir
    / "hg38_cpg_topvar_200bp_10kb_coord_n30k/harmonized_sample_ontology_intermediate_1l_3000n/10fold-oversampling_wgbs-only"
)

# Fail explicitly on a wrong path: globbing a missing folder would otherwise
# just give no per-assay results at all.
for required_dir in (results_dir_1, results_dir_2):
    if not required_dir.exists():
        raise FileNotFoundError(f"{required_dir} does not exist.")

# Split results of each per-assay training, keyed by folder name
assay_results = {
    assay_folder.name: split_results_handler.read_split_results(assay_folder)
    for assay_folder in results_dir_1.glob("*_only")
}
assay_results[
    "wgbs_only-cpg_topvar_200bp_n30321"
] = split_results_handler.read_split_results(results_dir_2)

assay_metrics = split_results_handler.compute_split_metrics(
    assay_results, concat_first_level=True
)
assay_metrics = split_results_handler.invert_metrics_dict(assay_metrics)

In [None]:
def NN_performance_per_assay(
    assay_metrics: Dict[str, Dict[str, Dict[str, float]]],
    logdir: Path,
    name: str,
    title_end: str = "",
    y_range: None | List[float] = None,
) -> None:
    """Create a box plot of assay accuracy for each classifier.

    One box is drawn per assay of ASSAY_ORDER (entry "<assay>_only" of
    `assay_metrics`), followed by one box for the extra
    "wgbs_only-cpg_topvar_200bp_n30321" entry, shown as "wgbs_cpg*".
    Entries missing from `assay_metrics` are reported and skipped.

    Args:
        assay_metrics: Metrics indexed as {classifier name: {split: {metric: value}}}.
            Only the "Accuracy" metric is read.
        logdir: Existing directory in which the figure is saved.
        name: File name (without extension) for the saved figure.
        title_end: Optional text appended to the figure title.
        y_range: Optional [min, max] range of the accuracy axis.

    Raises:
        FileNotFoundError: If logdir does not exist.
    """
    # Fail before building the figure rather than at export time
    if not logdir.exists():
        raise FileNotFoundError(f"Could not find {logdir}")

    # (key in assay_metrics, displayed trace name, key in assay_colors)
    trace_specs = [(f"{assay}_only", assay, assay) for assay in ASSAY_ORDER]
    trace_specs.append(("wgbs_only-cpg_topvar_200bp_n30321", "wgbs_cpg*", "wgbs"))

    fig = go.Figure()
    for metrics_key, trace_name, color_key in trace_specs:
        try:
            metrics = assay_metrics[metrics_key]
            split_acc = {split: metrics[split]["Accuracy"] for split in metrics}
        except KeyError:
            print(f"KeyError. Skipping '{trace_name}'")
            continue

        fig.add_trace(
            go.Box(
                y=list(split_acc.values()),
                name=trace_name,
                boxmean=True,
                boxpoints="all",
                showlegend=True,
                marker=dict(size=3, color="black"),
                line=dict(width=1, color="black"),
                fillcolor=assay_colors[color_key],
                hovertemplate="%{text}",
                text=[f"{split}: {value:.4f}" for split, value in split_acc.items()],
            )
        )

    if y_range is not None:
        fig.update_yaxes(range=y_range)

    title_text = (
        "NN classification (100kb_all_none) - Sample ontology - Training per assay"
    )
    if title_end:
        title_text += f" - {title_end}"
    fig.update_layout(
        title_text=title_text,
        yaxis_title="Accuracy",
        xaxis_title="Assay",
        width=1000,
        height=700,
    )

    # Save figure
    fig.write_image(logdir / f"{name}.svg")
    fig.write_image(logdir / f"{name}.png")
    fig.write_html(logdir / f"{name}.html")

    fig.show()

In [None]:
# Output folder for the per-assay training figure
logdir = base_fig_dir / "flagship" / "cell_type_unique_assay_training"
# Lower bound of the accuracy axis; also written into the output file name
y_min = 0.4
NN_performance_per_assay(
    assay_metrics=assay_metrics,
    logdir=logdir,
    name=f"cell_type_training_per_assay_Y_{y_min:.2f}",
    y_range=[y_min, 1.001],
)