In [7]:
"""Workbook to create figures destined for the paper."""
# pylint: disable=import-error, redefined-outer-name, use-dict-literal, too-many-lines, too-many-branches

'Workbook to create figures destined for the paper.'

## SETUP

In [8]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [9]:
from __future__ import annotations

import json
from collections import defaultdict
from pathlib import Path
from typing import Dict, Iterable, List

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from IPython.display import display  # pylint: disable=unused-import
from plotly.subplots import make_subplots

from epi_ml.utils.notebooks.paper.paper_utilities import (
    ASSAY,
    ASSAY_MERGE_DICT,
    ASSAY_ORDER,
    CELL_TYPE,
    IHECColorMap,
    MetadataHandler,
    SplitResultsHandler,
)

In [10]:
base_dir = Path.home().joinpath("Projects", "epiclass", "output", "paper")
paper_dir = base_dir
base_data_dir = base_dir.joinpath("data")
base_fig_dir = base_dir.joinpath("figures")

# Fail fast: every later cell reads from these directories.
if not base_data_dir.exists():
    raise FileNotFoundError(f"Directory {base_data_dir} does not exist.")

base_results_dir = base_data_dir.joinpath("training_results", "dfreeze_v2")
if not base_results_dir.exists():
    raise FileNotFoundError(f"Directory {base_results_dir} does not exist.")

In [11]:
# Bind the instance to its own name. Rebinding `IHECColorMap` to the instance
# shadowed the imported class, so re-running this cell raised "object is not callable".
ihec_color_map = IHECColorMap(base_fig_dir)
assay_colors = ihec_color_map.assay_color_map
cell_type_colors = ihec_color_map.cell_type_color_map

In [12]:
# Shared helpers: split-results reader and paper metadata loader.
split_results_handler = SplitResultsHandler()
metadata_handler = MetadataHandler(paper_dir)

## Flagship assay/ct accuracy

Cell type classifier:

  for each assay, make a violin plot of the accuracy per cell type (16 points, one per cell type)

In [20]:
fig_dir = base_fig_dir.joinpath("flagship")  # output folder for flagship figures

In [22]:
path_results_cell_type = base_data_dir.joinpath(
    "training_results", "dfreeze_v2", "hg38_100kb_all_none"
)
if not path_results_cell_type.exists():
    raise FileNotFoundError(f"{path_results_cell_type} does not exist.")

In [23]:
# Load split results into one combined dataframe
# Gather split results for every category, keep the sample-ontology classifier,
# then stack its per-split frames into one dataframe.
ct_task_key = "harmonized_sample_ontology_intermediate_1l_3000n_10fold-oversampling"
split_results_per_category = split_results_handler.gather_split_results_across_categories(
    path_results_cell_type
)
ct_split_dfs = split_results_per_category[ct_task_key]
ct_full_df = pd.concat(ct_split_dfs.values(), axis=0)

In [24]:
# Load metadata and join with split results
# Load metadata and join it onto the split results.
metadata_2 = metadata_handler.load_metadata("v2")
ct_full_df = metadata_handler.join_metadata(ct_full_df, metadata_2)
# Merge equivalent assay names. Assign the result back: `replace(..., inplace=True)`
# on a column selection is chained assignment, and under pandas Copy-on-Write it
# does not update `ct_full_df`.
ct_full_df[ASSAY] = ct_full_df[ASSAY].replace(ASSAY_MERGE_DICT)

### violin version

In [45]:
def _ct_accuracy_per_assay(
    cell_type_df: pd.DataFrame,
    assay_labels: Iterable[str],
    ct_labels: List[str],
) -> tuple[Dict[str, Dict[str, float]], Dict[str, pd.Series], Dict[str, int]]:
    """Compute the accuracy of every cell type label within every assay.

    Args:
        cell_type_df: Prediction results with ASSAY, "True class" and "Predicted class" columns.
        assay_labels: Assays to evaluate.
        ct_labels: Cell type labels to report, in the desired order.

    Returns:
        A tuple of:
        - assay -> {cell type -> accuracy}, NaN when the cell type has no sample in that assay;
        - assay -> number of samples per true class (Series indexed by class);
        - cell type -> number of assays in which it has at least one sample.
    """
    assay_acc_dict: Dict[str, Dict[str, float]] = {}
    subclass_sizes: Dict[str, pd.Series] = {}
    ct_assay_counts = {ct_label: 0 for ct_label in ct_labels}

    for assay_label in assay_labels:
        assay_df = cell_type_df[cell_type_df[ASSAY] == assay_label]
        sizes = assay_df.groupby("True class").size()
        n_correct = (
            assay_df[assay_df["True class"] == assay_df["Predicted class"]]
            .groupby("True class")
            .size()
        )

        ct_accuracies = {
            ct_label: (
                n_correct.get(ct_label, 0) / sizes[ct_label]
                if ct_label in sizes.index
                else np.nan
            )
            for ct_label in ct_labels
        }
        for ct_label, acc in ct_accuracies.items():
            if not np.isnan(acc):
                ct_assay_counts[ct_label] += 1

        subclass_sizes[assay_label] = sizes
        assay_acc_dict[assay_label] = ct_accuracies

    return assay_acc_dict, subclass_sizes, ct_assay_counts


def fig_flagship_ct(
    cell_type_df: pd.DataFrame,
    logdir: Path | None = None,
    name: str | None = None,
    normal_ct: bool = True,
    task_name: str = CELL_TYPE,
    jitter_seed: int | None = None,
) -> None:
    """
    Create a figure showing the cell type classifier performance on different assays.

    Assays are laid out on a square grid with one subplot each. Each subplot has a
    violin of the per-cell-type accuracies plus one jittered point per cell type.
    Handles cases where not all cell types have samples in all assays: such
    (assay, cell type) pairs are left out of that assay's subplot.

    Args:
        cell_type_df (pd.DataFrame): DataFrame containing the cell type prediction results.
        logdir (Path): The directory for saving the figure (SVG, PNG, HTML). Created if missing.
        name (str): The name for the saved figure. If it contains "minPredScore",
            the text after it is appended to the title.
        normal_ct (bool): Whether to use predefined colors for normal cell types.
            When True, exactly 16 cell type labels are required.
        task_name (str): Classification task name, used in the title.
        jitter_seed (int | None): Seed for the horizontal jitter of the points.
            None gives a different jitter on every call.

    Raises:
        AssertionError: If normal_ct is True and there are not exactly 16 labels.

    Returns:
        None: Displays the plotly figure.
    """
    assay_labels = ASSAY_ORDER
    grid_size = int(np.ceil(np.sqrt(len(assay_labels))))

    # Use a sorted list, not a set: set iteration order of strings changes between
    # Python sessions (hash randomization), which made the fallback colors unstable.
    ct_labels = sorted(
        set(cell_type_df["True class"].unique())
        | set(cell_type_df["Predicted class"].unique())
    )
    if normal_ct and len(ct_labels) != 16:
        raise AssertionError(f"Expected 16 cell type labels, got {len(ct_labels)}")

    if normal_ct:
        ct_colors = {ct_label: cell_type_colors[ct_label] for ct_label in ct_labels}
    else:
        palette = px.colors.qualitative.Dark24
        # Cycle the palette so that more than 24 labels does not raise IndexError.
        ct_colors = {
            ct_label: palette[i % len(palette)] for i, ct_label in enumerate(ct_labels)
        }

    scatter_offset = 0.1  # half-width of the horizontal jitter of the points
    # Defined up front because the legend entries below also use it.
    scatter_marker_size = 10
    rng = np.random.default_rng(jitter_seed)

    assay_acc_dict, subclass_sizes, ct_assay_counts = _ct_accuracy_per_assay(
        cell_type_df, assay_labels, ct_labels
    )

    fig = make_subplots(
        rows=grid_size,
        cols=grid_size,
        subplot_titles=assay_labels,
        shared_yaxes="all",  # type: ignore
        horizontal_spacing=0,
        vertical_spacing=0.02,
        y_title="Cell type subclass accuracy",
    )

    for idx, assay_label in enumerate(assay_labels):
        row, col = divmod(idx, grid_size)
        ct_accuracies = assay_acc_dict[assay_label]

        # Cell types without samples in this assay (NaN accuracy) are not drawn.
        valid_ct_labels = [ct for ct in ct_labels if not np.isnan(ct_accuracies[ct])]
        if not valid_ct_labels:
            continue
        acc_values = [ct_accuracies[ct] for ct in valid_ct_labels]

        fig.add_trace(
            go.Violin(
                x=[idx] * len(acc_values),
                y=acc_values,
                name=assay_label,
                spanmode="hard",
                box_visible=True,
                meanline_visible=True,
                points=False,
                fillcolor=assay_colors[assay_label],
                line_color="white",
                line=dict(width=0.8),
                showlegend=False,
            ),
            row=row + 1,
            col=col + 1,
        )

        # Points sit left of the violin (centred on idx - 0.35), randomly jittered.
        jittered_x_positions = (
            rng.uniform(-scatter_offset, scatter_offset, size=len(valid_ct_labels))
            + idx
            - 0.35
        )
        hover_texts = [
            f"{ct_label} ({ct_accuracies[ct_label]:.3f}, "
            f"n={subclass_sizes[assay_label].get(ct_label, 0)})"
            for ct_label in valid_ct_labels
        ]
        fig.add_trace(
            go.Scatter(
                x=jittered_x_positions,
                y=acc_values,
                mode="markers",
                marker=dict(
                    size=scatter_marker_size,
                    color=[ct_colors[ct] for ct in valid_ct_labels],
                ),
                hovertemplate="%{text}",
                text=hover_texts,
                showlegend=False,
            ),
            row=row + 1,
            col=col + 1,
        )

    fig.update_xaxes(showticklabels=False)

    # Marker-only dummy traces that build the cell type legend.
    for ct_label in ct_labels:
        fig.add_trace(
            go.Scatter(
                x=[None],
                y=[None],
                mode="markers",
                name=f"{ct_label} ({ct_assay_counts[ct_label]} assays)",
                marker=dict(color=ct_colors[ct_label], size=scatter_marker_size),
                showlegend=True,
            )
        )

    fig.update_yaxes(range=[0, 1.01])

    title_text = f"{task_name.replace('_', ' ')} classifier: Accuracy per assay"
    if name and "minPredScore" in name:
        title_text += f" (minPredScore={name.split('minPredScore', 1)[1]})"

    fig.update_layout(
        title_text=title_text,
        height=1500,
        width=1500,
    )

    if logdir:
        logdir.mkdir(parents=True, exist_ok=True)
        name = name if name else "ct_assay_accuracy_violin"
        fig.write_image(logdir / f"{name}.svg")
        fig.write_image(logdir / f"{name}.png")
        fig.write_html(logdir / f"{name}.html")

    fig.show()

In [46]:
fig_dir = base_fig_dir.joinpath("flagship", "ct_assay_accuracy")
# To save to disk instead:
# fig_flagship_ct(ct_full_df, logdir=fig_dir, name="ct_assay_accuracy_violin_10kb")
fig_flagship_ct(cell_type_df=ct_full_df)

In [27]:
gen_data_dir = base_data_dir.joinpath(
    "training_results", "dfreeze_v2", "hg38_100kb_all_none"
)
# Cell type tasks to load, and name fragments whose result sets are skipped.
ct_task_categories = [CELL_TYPE, "cell_type_martin", "cell_type_PE"]
excluded_name_parts = ["16ct", "27ct", "7c", "chip-seq-only"]

all_split_results = split_results_handler.general_split_metrics(
    results_dir=gen_data_dir,
    merge_assays=True,
    return_type="split_results",
    include_categories=ct_task_categories,
    exclude_names=excluded_name_parts,
)
all_concat_results = split_results_handler.concatenate_split_results(
    all_split_results,  # type: ignore
    concat_first_level=True,
)

In [28]:
new_ct_meta_path = base_data_dir.joinpath("metadata", "Martin_class_v3_041224.tsv")
# Column names are supplied explicitly (no header row is read from the file).
new_ct_meta_columns = ["epirr_id_without_version", "cell_type_martin", "cell_type_PE"]
new_ct_meta_df = pd.read_csv(new_ct_meta_path, sep="\t", names=new_ct_meta_columns)

In [None]:
fig_dir = base_fig_dir.joinpath(
    "flagship", "ct_assay_accuracy", "other_cell_type_groupings", "violin"
)
fig_dir.mkdir(parents=True, exist_ok=True)

# Minimum "Max pred" thresholds to plot; e.g. [0, 0.6, 0.8] for confidence-filtered figures.
min_pred_scores = [0]

for task_name, task_df in all_concat_results.items():
    this_logdir = fig_dir / str(task_name)
    this_logdir.mkdir(parents=False, exist_ok=True)

    # Attach metadata and the alternative cell type groupings.
    full_df = metadata_handler.join_metadata(task_df, metadata_2)  # type: ignore
    full_df = full_df.merge(new_ct_meta_df, on="epirr_id_without_version")
    # Assign back instead of `replace(..., inplace=True)` on a column selection,
    # which is chained assignment and does not update `full_df` under Copy-on-Write.
    full_df[ASSAY] = full_df[ASSAY].replace(ASSAY_MERGE_DICT)

    if "Max pred" not in full_df.columns:
        if any(label not in full_df.columns for label in ["Predicted class", "split"]):
            raise ValueError(
                "Cannot find predicted class or split column, cannot compute max pred."
            )
        full_df = full_df.copy(deep=True)
        # Assumes the per-class prediction columns sit between "Predicted class" and "split".
        idx1: int = full_df.columns.get_loc("Predicted class")  # type: ignore
        idx2: int = full_df.columns.get_loc("split")  # type: ignore
        full_df["Max pred"] = full_df[full_df.columns[idx1 + 1 : idx2]].max(axis=1)

    for min_pred_score in min_pred_scores:
        # Filter into a new frame. Reassigning `full_df` here made every threshold
        # apply on top of the previous ones.
        filtered_df = full_df[full_df["Max pred"] >= min_pred_score]
        fig_flagship_ct(
            filtered_df,
            normal_ct=False,
            task_name=task_name,  # type: ignore
            logdir=this_logdir,
            name=f"{task_name}_ct_assay_accuracy_violin_minPredScore{min_pred_score:.2f}",
        )

### boxplot version

In [15]:
def fig_flagship_ct_boxplot(
    cell_type_df: pd.DataFrame,
    name: str | None = None,
    logdir: Path | None = None,
    y_range: List[float] | None = None,
) -> None:
    """
    Generates a boxplot for cell type classification accuracy across different assays.

    The figure has one box per assay (in ASSAY_ORDER). Each box summarizes the
    accuracy of the 16 cell types within that assay. A cell type with no sample
    in an assay is recorded as NaN instead of raising KeyError.

    Args:
        cell_type_df: DataFrame containing the cell type prediction results.
        name: The name for the saved figure.
        logdir: The directory for saving the figure (SVG, PNG, HTML). Created if missing.
        y_range: y-axis range. Defaults to [0.34, 1.01].

    Raises:
        AssertionError: If there are not exactly 16 true cell type labels.
    """
    ct_labels = sorted(cell_type_df["True class"].unique())
    if len(ct_labels) != 16:
        raise AssertionError(f"Expected 16 cell type labels, got {len(ct_labels)}")

    assay_acc_dict: Dict[str, List[float]] = {}
    for assay_label in ASSAY_ORDER:
        assay_df = cell_type_df[cell_type_df[ASSAY] == assay_label]

        # Cell type subclass accuracy = correct predictions / samples of that class.
        # `.get` / membership checks avoid the KeyError that direct confusion-matrix
        # indexing raised when a class had no correct prediction or no sample.
        subclass_size = assay_df.groupby("True class").size()
        n_correct = (
            assay_df[assay_df["True class"] == assay_df["Predicted class"]]
            .groupby("True class")
            .size()
        )
        assay_acc_dict[assay_label] = [
            (
                n_correct.get(ct_label, 0) / subclass_size[ct_label]
                if ct_label in subclass_size.index
                else np.nan
            )
            for ct_label in ct_labels
        ]

    fig = go.Figure()
    for assay_label in ASSAY_ORDER:
        fig.add_trace(
            go.Box(
                y=assay_acc_dict[assay_label],
                name=assay_label,
                boxpoints="outliers",
                boxmean=True,
                fillcolor=assay_colors[assay_label],
                line_color="black",
                showlegend=False,
                marker=dict(size=2),
            )
        )

    fig.update_yaxes(range=y_range if y_range is not None else [0.34, 1.01])

    title_text = f"{CELL_TYPE.replace('_', ' ').title()} classifier: Accuracy per assay"
    fig.update_layout(
        title=title_text,
        yaxis_title="Accuracy",
        xaxis_title="Assay",
        height=600,
        width=1000,
    )

    if logdir:
        logdir.mkdir(parents=True, exist_ok=True)
        this_name = name if name else "ct_assay_accuracy_boxplot"
        fig.write_image(logdir / f"{this_name}.svg")
        fig.write_image(logdir / f"{this_name}.png")
        fig.write_html(logdir / f"{this_name}.html")

    fig.show()

In [16]:
# To save to disk instead:
# fig_flagship_ct_boxplot(ct_full_df, logdir=fig_dir, name="ct_assay_accuracy_boxplot_10kb")
fig_flagship_ct_boxplot(cell_type_df=ct_full_df)

### Subclass (assay, ct, life_stage) accuracy

In [17]:
LIFE_STAGE = "harmonized_donor_life_stage"

assay_labels = sorted(ct_full_df[ASSAY].unique())
ct_labels = sorted(ct_full_df[CELL_TYPE].unique())
# Exclude the "unknown" life stage. Filtering (instead of list.remove) does not
# raise ValueError when no sample has an unknown life stage.
life_stages = [
    stage for stage in ct_full_df[LIFE_STAGE].unique().tolist() if stage != "unknown"
]

# One row per (assay, cell type, life stage) combination, including empty
# combinations (Size 0, Accuracy NaN since the mean of an empty series is NaN).
acc_list = []
for assay_label in assay_labels:
    assay_df = ct_full_df[ct_full_df[ASSAY] == assay_label]
    for ct in ct_labels:
        ct_df = assay_df[assay_df[CELL_TYPE] == ct]
        for life_stage in life_stages:
            life_stage_df = ct_df[ct_df[LIFE_STAGE] == life_stage]
            acc = life_stage_df["Predicted class"].eq(life_stage_df["True class"]).mean()
            acc_list.append((assay_label, ct, life_stage, len(life_stage_df), acc))

acc_df = pd.DataFrame(
    acc_list, columns=["Assay", "Cell Type", "Life Stage", "Size", "Accuracy"]
)

In [18]:
# acc_df.to_csv(base_fig_dir / "flagship" / "assay_ct_life_stage_accuracy.csv")

## Input classification - Hdf5 values at top SHAP features

Compare the input signal across cell types: e.g. boxplots of the 40 "input & T cell" regions vs. the 9 "input & lymphocyte of B lineage" regions in these 2 cell types (so 4 boxplots), or also add an external cell type such as muscle as a negative control (6 boxplots).

Using all input files, check the 40 and 9 regions in:
- T cell
- B cell
- muscle

In [19]:
def flagship_supp_shap_input(
    paper_dir: Path, base_results_dir: Path, logdir: Path | None = None
) -> None:
    """
    Plot hdf5 values of different (input, cell type) pairs for top SHAP values
    of T cell and B cell. (only ct which produced common top shap values features)

    Input files of four cell types (T cell, lymphocyte of B lineage, neutrophil,
    muscle organ) are selected. For each group of top SHAP bins read from
    `features_n8.json`, each file's mean value over the group's bins is computed.
    This gives one violin per (bin group, cell type) with one point per md5sum.

    Args:
        paper_dir (Path): The directory containing the paper data (metadata source).
        base_results_dir (Path): Root training results directory. Feature files are read
            from hg38_100kb_all_none/<CELL_TYPE>_1l_3000n/10fold-oversampling/
            global_shap_analysis/top303/input.
        logdir (Path, optional): The directory to save the figure (SVG, PNG, HTML).
            Defaults to None.

    Raises:
        FileNotFoundError: If the task directory or the SHAP values directory is missing.

    Returns:
        None: Displays the plotly figure.
    """
    # Load relevant metadata: input files of the selected cell types only.
    ct_labels = ["T cell", "lymphocyte of B lineage", "neutrophil", "muscle organ"]
    metadata_2 = MetadataHandler(paper_dir).load_metadata("v2")
    metadata_2.select_category_subsets(ASSAY, ["input"])
    metadata_2.select_category_subsets(CELL_TYPE, ct_labels)
    md5_per_ct = metadata_2.md5_per_class(CELL_TYPE)

    # Locate the feature bins.
    task_dir = (
        base_results_dir
        / "hg38_100kb_all_none"
        / f"{CELL_TYPE}_1l_3000n"
        / "10fold-oversampling"
    )
    if not task_dir.exists():
        raise FileNotFoundError(f"Directory {task_dir} does not exist.")
    hdf5_val_dir = task_dir / "global_shap_analysis" / "top303" / "input"
    if not hdf5_val_dir.exists():
        raise FileNotFoundError(f"Directory {hdf5_val_dir} does not exist.")

    # Group name -> list of bin positions.
    feature_filepath = hdf5_val_dir / "features_n8.json"
    with open(feature_filepath, "r", encoding="utf8") as f:
        features: Dict[str, List[int]] = json.load(f)

    # Feature values: rows indexed by md5sum, columns are bin positions (as strings).
    hdf5_val_path = hdf5_val_dir / "hdf5_values_100kb_all_none_input_4ct_features_n8.csv"
    df = pd.read_csv(hdf5_val_path, index_col=0, header=0)

    df_ct_dict = {}
    for ct in ct_labels:
        expected_md5s = md5_per_ct[ct]
        # `df.loc` raises KeyError if any md5 is missing; keep the available ones and report.
        md5s = [md5 for md5 in expected_md5s if md5 in df.index]
        n_missing = len(expected_md5s) - len(md5s)
        if n_missing:
            print(f"{ct}: {n_missing} md5s not found in {hdf5_val_path.name}, skipped.")
        df_ct_dict[ct] = df.loc[md5s]

    # One violin group per top-SHAP bin group, one violin per cell type within it.
    fig = go.Figure()
    for ct_label, df_ct in df_ct_dict.items():
        for i, (cell_type, top_shap_bins) in enumerate(features.items()):
            top_shap_bins = [str(b) for b in top_shap_bins]
            mean_bin_values_per_md5 = df_ct[top_shap_bins].mean(axis=1)

            # `.2f` (two decimals); the previous `02f` meant width 2 with six decimals.
            hovertext = [
                f"{md5}. {val:.2f}"
                for md5, val in zip(df_ct.index, mean_bin_values_per_md5)
            ]

            fig.add_trace(
                go.Violin(
                    x=[f"Top SHAP '{cell_type}' bins"] * len(mean_bin_values_per_md5),
                    y=mean_bin_values_per_md5,
                    name=f"{ct_label} (n={len(hovertext)})",
                    points="all",
                    box_visible=True,
                    meanline_visible=True,
                    spanmode="hard",
                    fillcolor=cell_type_colors[ct_label],
                    line_color="black",
                    showlegend=i == 0,  # one legend entry per cell type
                    marker=dict(size=2),
                    hovertemplate="%{text}",
                    text=hovertext,
                    legendgroup=ct_label,
                )
            )

    fig.update_layout(
        title="Input files bin z-score values for top SHAP features",
        yaxis_title="Mean z-score across bins",
        xaxis_title="Top SHAP features groups",
        boxmode="group",
        violinmode="group",
        height=1000,
        width=1000,
    )

    if logdir:
        logdir.mkdir(parents=True, exist_ok=True)
        name = "input_feature_values_top_shap_bins_per_md5"
        fig.write_image(logdir / f"{name}.svg")
        fig.write_image(logdir / f"{name}.png")
        fig.write_html(logdir / f"{name}.html")

    fig.show()

In [20]:
flagship_supp_shap_input(paper_dir=paper_dir, base_results_dir=base_results_dir)
# To save to disk instead:
# flagship_supp_shap_input(paper_dir, base_results_dir=base_results_dir, logdir=base_fig_dir / "flagship")

## Cell type classification - Training per assay results

Training on a single assay (one model per assay)

In [20]:
results_dir_1 = base_results_dir.joinpath(
    "hg38_100kb_all_none",
    "harmonized_sample_ontology_intermediate_1l_3000n",
    "10fold-oversampling-unique_assay",
)
results_dir_2 = base_results_dir.joinpath(
    "hg38_cpg_topvar_200bp_10kb_coord_n30k",
    "harmonized_sample_ontology_intermediate_1l_3000n",
    "10fold-oversampling_wgbs-only",
)

# One results folder per single-assay model, named "<assay>_only".
assay_results = {
    folder.name: split_results_handler.read_split_results(folder)
    for folder in results_dir_1.glob("*_only")
}
# Extra WGBS-only model trained on the hg38_cpg_topvar_200bp_10kb_coord_n30k features.
assay_results["wgbs_only-cpg_topvar_200bp_n30321"] = (
    split_results_handler.read_split_results(results_dir_2)
)

assay_metrics = split_results_handler.invert_metrics_dict(
    split_results_handler.compute_split_metrics(assay_results, concat_first_level=True)
)

Training on all assays mixed (accuracy then computed per assay)

In [None]:
result_dir_3 = base_results_dir.joinpath(
    "hg38_100kb_all_none",
    "harmonized_sample_ontology_intermediate_1l_3000n",
    "10fold-oversampling",
)

# Model trained on all assays together; its accuracy is then split per assay.
split_results = split_results_handler.read_split_results(result_dir_3)
metadata_2_df = MetadataHandler(paper_dir).load_metadata_df("v2")
acc_per_assay = split_results_handler.compute_acc_per_assay(split_results, metadata_2_df)

In [None]:
# Convert to save format as compute_split_metrics outputs.
for col in acc_per_assay.columns:
    split_accs = acc_per_assay[col].to_dict()
    for split, acc in list(split_accs.items()):
        split_accs[split] = {"Accuracy": acc}
    assay_metrics[col] = split_accs

Plot

In [27]:
def _split_accuracies(split_metrics: Dict[str, Dict[str, float]]) -> Dict[str, float]:
    """Extract {split: accuracy} from a {split: {metric name: value}} mapping."""
    return {split: metrics["Accuracy"] for split, metrics in split_metrics.items()}


def _accuracy_box(split_acc: Dict[str, float], name: str, fillcolor: str) -> go.Box:
    """Build one box trace of per-split accuracies, with every split shown as a point."""
    return go.Box(
        y=list(split_acc.values()),
        name=name,
        boxmean=True,
        boxpoints="all",
        showlegend=True,
        marker=dict(size=3, color="black"),
        line=dict(width=1, color="black"),
        fillcolor=fillcolor,
        hovertemplate="%{text}",
        text=[f"{split}: {value:.4f}" for split, value in split_acc.items()],
    )


def NN_performance_per_assay(
    assay_metrics: Dict[str, Dict[str, Dict[str, float]]],
    name: str | None = None,
    logdir: Path | None = None,
    title_end: str = "",
    y_range: None | List[float] = None,
    extra_wgbs_key: str = "wgbs_only-cpg_topvar_200bp_n30321",
) -> None:
    """Create a box plot of assay accuracy for each classifier.

    Each assay in ASSAY_ORDER gets two boxes of per-split accuracy: the model trained
    on that assay only (key "<assay>_only") and the mixed-assay model evaluated on
    that assay (key "<assay>"). Assays missing either key are skipped with a message.
    An extra WGBS box ("wgbs_cpg*") is added from `extra_wgbs_key` when present.

    Args:
        assay_metrics: {result name: {split: {metric name: value}}}, with an "Accuracy" metric.
        name: File name (without extension) for the saved figure.
        logdir: Directory for saving the figure (SVG, PNG, HTML). Created if missing.
        title_end: Text appended to the title.
        y_range: Optional y-axis range.
        extra_wgbs_key: Key of the additional WGBS-only results to plot.
    """
    fig = go.Figure()
    for assay in ASSAY_ORDER:
        try:
            unique_assay_acc = _split_accuracies(assay_metrics[f"{assay}_only"])
            mixed_assay_acc = _split_accuracies(assay_metrics[f"{assay}"])
        except KeyError:
            print(f"KeyError. Skipping '{assay}'")
            continue

        fig.add_trace(_accuracy_box(unique_assay_acc, f"{assay}_only", assay_colors[assay]))
        fig.add_trace(_accuracy_box(mixed_assay_acc, f"{assay}_mixed", assay_colors[assay]))

    # Previously a missing key crashed here, unlike the per-assay skips above.
    if extra_wgbs_key in assay_metrics:
        wgbs_acc = _split_accuracies(assay_metrics[extra_wgbs_key])
        fig.add_trace(_accuracy_box(wgbs_acc, "wgbs_cpg*", assay_colors["wgbs"]))
    else:
        print(f"KeyError. Skipping '{extra_wgbs_key}'")

    if y_range is not None:
        fig.update_yaxes(range=y_range)

    title_text = "NN classification (100kb_all_none) - Sample ontology - Training per assay VS mixed"
    if title_end:
        title_text += f" - {title_end}"
    fig.update_layout(
        title_text=title_text,
        yaxis_title="Accuracy",
        xaxis_title="Assay",
        width=1000,
        height=700,
    )

    if logdir:
        logdir.mkdir(parents=True, exist_ok=True)
        name = name if name else "acc_training_per_assay"
        fig.write_image(logdir / f"{name}.svg")
        fig.write_image(logdir / f"{name}.png")
        fig.write_html(logdir / f"{name}.html")

    fig.show()

In [31]:
logdir = base_fig_dir / "flagship" / "cell_type_unique_assay_training"
# Plotly's write_image fails if the target folder does not exist.
logdir.mkdir(parents=True, exist_ok=True)

# Same figure saved with two y-axis lower bounds.
for y_min in [0.4, 0.5]:
    NN_performance_per_assay(
        assay_metrics=assay_metrics,
        name=f"cell_type_training_per_assay_Y_{y_min:.2f}",
        logdir=logdir,
        y_range=[y_min, 1.001],
    )

KeyError. Skipping 'rna_seq'


KeyError. Skipping 'rna_seq'


## Metrics for selected bins across the whole dataset

In [None]:
def plot_important_features_metrics(
    important_features: Dict[str, List[int]],
    npz_file_path: Path,
    logdir: Path | None = None,
    include_categories: Iterable[str] | None = None,
    seed: int | None = None,
) -> None:
    """Violin-plot per-bin mean values (from an npz file) at each category's important features.

    Each selected category gets one figure with three violins: the category's
    important bins, a random set with the same number of bins (drawn without
    replacement), and the distribution over all bins.

    Args:
        important_features: Category names mapped to lists of feature (bin) positions.
        npz_file_path: npz file containing the bin metrics; its "mean" array is plotted.
        logdir: If given, each figure is written there as HTML and PNG (folder created if missing).
        include_categories: If given and non-empty, only these categories are plotted.
        seed: Seed for the random bin selection, for a reproducible comparison set.
            None draws a new random set on every call.

    Returns:
        None. Figures are displayed (and optionally saved). No statistical test is computed.
    """
    # Only the "mean" metric is used.
    with np.load(npz_file_path) as data:
        means = data["mean"]

    # Materialize once so that a generator argument is not consumed by the first
    # membership test.
    selected_categories = set(include_categories) if include_categories else None

    rng = np.random.default_rng(seed)
    dataset_name = npz_file_path.stem.split("_")[0]
    if logdir:
        logdir.mkdir(parents=True, exist_ok=True)

    for category_name, features_pos in important_features.items():
        if selected_categories and category_name not in selected_categories:
            continue

        n_features = len(features_pos)
        fig = go.Figure()

        fig.add_trace(
            go.Violin(
                y=means[features_pos],
                name=f"{category_name} features (N={n_features})",
                box_visible=True,
                meanline_visible=True,
                points="all",
                hovertemplate="%{text}",
                text=[f"Bin {pos}" for pos in features_pos],
            )
        )

        # Random features comparison (same size, no repeated bins).
        random_features_pos = rng.choice(len(means), size=n_features, replace=False)
        fig.add_trace(
            go.Violin(
                y=means[random_features_pos],
                name=f"Random features (N={n_features})",
                box_visible=True,
                meanline_visible=True,
                points="all",
                hovertemplate="%{text}",
                text=[f"Bin {pos}" for pos in random_features_pos],
            )
        )

        # Global distribution comparison.
        fig.add_trace(
            go.Violin(
                y=means,
                name=f"All features N={len(means)}",
                box_visible=True,
                meanline_visible=True,
                points=False,
            )
        )

        fig.update_traces(marker=dict(size=2))  # small points
        fig.update_layout(
            title=f"{dataset_name} mean values for {category_name} features",
            xaxis_title="Feature set",
            yaxis_title="Mean values",
            violinmode="group",
        )

        if logdir:
            fig.write_html(logdir / f"{npz_file_path.stem}_{category_name}_violin.html")
            fig.write_image(logdir / f"{npz_file_path.stem}_{category_name}_violin.png")

        fig.show()

### ChromScore / Chromatin Activity

Using ChromScore values from 1698 EpiRRs.

In [43]:
important_features_path = base_data_dir.joinpath(
    "training_results",
    "dfreeze_v2",
    "hg38_100kb_all_none",
    "global_shap_info",
    "global_task_features.json",
)
if not important_features_path.exists():
    raise FileNotFoundError(f"File {important_features_path} does not exist.")

# Presumably category name -> important feature positions, as plot_important_features_metrics expects.
important_features = json.loads(important_features_path.read_text(encoding="utf8"))

npz_filepath = base_data_dir.joinpath("ChromScore", "ChromScore_metrics_raw.npz")
if not npz_filepath.exists():
    raise FileNotFoundError(f"File {npz_filepath} does not exist.")

In [41]:
# Only the cell type task's important features are plotted.
plot_important_features_metrics(
    important_features,
    npz_filepath,
    include_categories=[CELL_TYPE],
)