In [None]:
"""Workbook to create figures destined for the paper."""
# pylint: disable=import-error, redefined-outer-name, use-dict-literal, too-many-lines

In [None]:
# Reload edited modules (e.g. the epi_ml paper utilities) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
from __future__ import annotations

import json
from collections import defaultdict
from pathlib import Path
from typing import Dict, List

import numpy as np
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from epi_ml.utils.notebooks.paper.paper_utilities import (
    ASSAY,
    ASSAY_ORDER,
    ASSAY_MERGE_DICT,
    CELL_TYPE,
    IHECColorMap,
    MetadataHandler,
    SplitResultsHandler,
)

In [None]:
# Root of the paper outputs: input data and generated figures live in subfolders of it.
base_dir = Path.home().joinpath("Projects/epiclass/output/paper")
paper_dir = base_dir
base_data_dir, base_fig_dir = base_dir.joinpath("data"), base_dir.joinpath("figures")

In [None]:
# Bind the instance to its own name. Rebinding `IHECColorMap` to the instance shadowed the
# imported class, so re-running this cell would call the instance instead of the class.
ihec_color_map = IHECColorMap(base_fig_dir)
assay_colors = ihec_color_map.assay_color_map
cell_type_colors = ihec_color_map.cell_type_color_map

### Flagship paper figure

Cell type classifier:

For each assay, a violin plot of the accuracy per cell type (16 points, one per cell type).

In [None]:
# Output folder for this section's figures (later cells may point fig_dir to a subfolder).
fig_dir = base_fig_dir.joinpath("flagship")

In [None]:
# Results of the cell type classifier trained on 10kb bins (dfreeze v2).
path_results_cell_type = base_data_dir.joinpath("dfreeze_v2", "hg38_10kb_all_none")
metadata_handler, split_results_handler = MetadataHandler(paper_dir), SplitResultsHandler()

In [None]:
# Name of the cell type model whose split results are used in the figures below.
CT_MODEL_KEY = "harmonized_sample_ontology_intermediate_1l_3000n_10fold-oversampling"

# Gather the split results of that model and stack all splits into one dataframe.
ct_split_dfs = split_results_handler.gather_split_results_across_categories(
    path_results_cell_type
)[CT_MODEL_KEY]
ct_full_df = pd.concat(list(ct_split_dfs.values()), axis=0)

In [None]:
# Load metadata and join with split results
metadata_2 = metadata_handler.load_metadata("v2")
ct_full_df = metadata_handler.join_metadata(ct_full_df, metadata_2)
# Map assay labels through ASSAY_MERGE_DICT. Assign the result back rather than calling
# `ct_full_df[ASSAY].replace(..., inplace=True)`. That call acts on a column selection
# (chained assignment), is not guaranteed to modify ct_full_df, and never does under
# pandas Copy-on-Write.
ct_full_df[ASSAY] = ct_full_df[ASSAY].replace(ASSAY_MERGE_DICT)

#### violin version

In [None]:
def fig_flagship_ct(
    cell_type_df: pd.DataFrame, logdir: Path, name: str, seed: int | None = 42
) -> None:
    """
    Create a figure showing the cell type classifier performance on different assays.

    The figure has one subplot per assay (in ``ASSAY_ORDER``), laid out on a square grid.
    Each subplot shows a violin of the per-cell-type accuracies for that assay, plus one
    jittered scatter point per cell type coloured with the IHEC cell type colour map.

    Args:
        cell_type_df (pd.DataFrame): DataFrame containing the cell type prediction results.
            Must contain the "True class", "Predicted class" and ``ASSAY`` columns.
        logdir (Path): Existing directory where the figure is saved.
        name (str): Base file name (without extension) for the saved figure.
        seed (int | None): Seed of the scatter jitter, so that the figure is reproducible.
            Use None for unseeded jitter.

    Returns:
        None: Saves the figure as SVG, PNG and HTML in ``logdir``, then displays it.

    Raises:
        FileNotFoundError: If ``logdir`` does not exist.
        AssertionError: If the data does not contain exactly 16 cell type labels.
    """
    # Fail fast, before building the figure, if the output directory is missing.
    if not logdir.exists():
        raise FileNotFoundError(f"Could not find {logdir}")

    # Assuming all classifiers have the same assays for simplicity
    assay_labels = ASSAY_ORDER
    num_assays = len(assay_labels)

    ct_labels = sorted(cell_type_df["True class"].unique())
    if len(ct_labels) != 16:
        raise AssertionError(f"Expected 16 cell type labels, got {len(ct_labels)}")
    ct_colors = [cell_type_colors[ct_label] for ct_label in ct_labels]

    rng = np.random.default_rng(seed)  # seeded jitter -> reproducible figure
    scatter_offset = 0.1  # Scatter plot jittering
    scatter_marker_size = 10

    # Smallest square grid holding one subplot per assay
    grid_size = int(np.ceil(np.sqrt(num_assays)))

    # Per assay: accuracy of each cell type subclass (in ct_labels order, which matches
    # ct_colors), and subclass sizes for the hover text.
    assay_acc_dict: Dict[str, Dict[str, float]] = {}
    subclass_sizes: Dict[str, pd.Series] = {}
    for assay_label in assay_labels:
        assay_df = cell_type_df[cell_type_df[ASSAY] == assay_label]
        sizes = assay_df.groupby("True class").size()
        correct_counts = (
            assay_df[assay_df["True class"] == assay_df["Predicted class"]]
            .groupby("True class")
            .size()
        )
        subclass_sizes[assay_label] = sizes
        # A subclass without any correct prediction is absent from correct_counts. Its
        # accuracy is 0; indexing it directly used to raise a KeyError. A subclass absent
        # from the assay has no defined accuracy (NaN).
        assay_acc_dict[assay_label] = {
            ct_label: (
                float(correct_counts.get(ct_label, 0)) / float(sizes[ct_label])
                if sizes.get(ct_label, 0) > 0
                else float("nan")
            )
            for ct_label in ct_labels
        }

    # Create subplots with a square grid
    fig = make_subplots(
        rows=grid_size,
        cols=grid_size,
        subplot_titles=assay_labels,
        shared_yaxes="all",  # type: ignore
        horizontal_spacing=0,
        vertical_spacing=0.02,
        y_title="Cell type subclass accuracy",
    )

    for idx, assay_label in enumerate(assay_labels):
        row, col = divmod(idx, grid_size)
        ct_accuracies = assay_acc_dict[assay_label]
        acc_values = list(ct_accuracies.values())

        # Violin plot at integer x position `idx`
        fig.add_trace(
            go.Violin(
                x=[idx] * len(acc_values),
                y=acc_values,
                name=assay_label,
                spanmode="hard",
                box_visible=True,
                meanline_visible=True,
                points=False,
                fillcolor=assay_colors[assay_label],
                line_color="white",
                line=dict(width=0.8),
                showlegend=False,
            ),
            row=row + 1,  # Plotly rows are 1-indexed
            col=col + 1,
        )

        # One point per cell type, jittered and shifted to the left of the violin
        jittered_x_positions = (
            rng.uniform(-scatter_offset, scatter_offset, size=len(acc_values)) + idx - 0.4
        )
        sizes = subclass_sizes[assay_label]
        fig.add_trace(
            go.Scatter(
                x=jittered_x_positions,
                y=acc_values,
                mode="markers",
                marker=dict(size=scatter_marker_size, color=ct_colors),
                hovertemplate="%{text}",
                text=[
                    f"{ct_label} ({acc:.3f}, n={sizes.get(ct_label, 0)})"
                    for ct_label, acc in ct_accuracies.items()
                ],
                showlegend=False,
            ),
            row=row + 1,  # Plotly rows are 1-indexed
            col=col + 1,
        )

    # All subplot axes exist from make_subplots, so a single call covers them all.
    fig.update_xaxes(showticklabels=False)

    # Add a dummy scatter plot for legend
    for ct_label in ct_labels:
        fig.add_trace(
            go.Scatter(
                x=[None],
                y=[None],
                mode="markers",
                name=ct_label,
                marker=dict(color=cell_type_colors[ct_label], size=scatter_marker_size),
                showlegend=True,
            )
        )

    fig.update_yaxes(range=[0.34, 1.01])

    title_text = f"{CELL_TYPE.replace('_', ' ').title()} classifier: Accuracy per assay"
    fig.update_layout(
        title=title_text,
        height=1500,
        width=1500,
    )

    # Save figure
    fig.write_image(logdir / f"{name}.svg")
    fig.write_image(logdir / f"{name}.png")
    fig.write_html(logdir / f"{name}.html")

    fig.show()

In [None]:
fig_dir = base_fig_dir / "flagship" / "ct_assay_accuracy"
# Create the output directory up front. The figure cells below write into it, and
# fig_flagship_ct_boxplot raises FileNotFoundError when it is missing.
fig_dir.mkdir(parents=True, exist_ok=True)
fig_flagship_ct(ct_full_df, logdir=fig_dir, name="ct_assay_accuracy_violin_10kb")

#### boxplot version

In [None]:
def fig_flagship_ct_boxplot(cell_type_df: pd.DataFrame, logdir: Path, name: str) -> None:
    """
    Generates a boxplot for cell type classification accuracy across different assays.

    This function creates a single figure with one box per assay (in ``ASSAY_ORDER``).
    Each box summarizes that assay's 16 per-cell-type accuracies; the mean is shown and
    outliers are drawn as points.

    Args:
        cell_type_df: DataFrame containing the cell type prediction results
            ("True class", "Predicted class" and ``ASSAY`` columns).
        logdir: Existing directory where the figure is saved.
        name: Base file name (without extension) for the saved figure.

    Raises:
        FileNotFoundError: If ``logdir`` does not exist.
        AssertionError: If the data does not contain exactly 16 cell type labels.
    """
    # Fail fast, before any computation, if the output directory is missing.
    if not logdir.exists():
        raise FileNotFoundError(f"Could not find {logdir}")

    ct_labels = sorted(cell_type_df["True class"].unique())
    if len(ct_labels) != 16:
        raise AssertionError(f"Expected 16 cell type labels, got {len(ct_labels)}")

    # One accuracy per cell type (in ct_labels order) for each assay.
    assay_acc_dict: Dict[str, List[float]] = {}
    for assay_label in ASSAY_ORDER:
        assay_df = cell_type_df[cell_type_df[ASSAY] == assay_label]
        subclass_size = assay_df.groupby("True class").size()
        correct_counts = (
            assay_df[assay_df["True class"] == assay_df["Predicted class"]]
            .groupby("True class")
            .size()
        )
        # Cell types with no correct prediction are absent from correct_counts: accuracy 0
        # (direct indexing used to raise KeyError). Cell types absent from the assay: NaN.
        assay_acc_dict[assay_label] = [
            correct_counts.get(ct_label, 0) / subclass_size[ct_label]
            if subclass_size.get(ct_label, 0) > 0
            else float("nan")
            for ct_label in ct_labels
        ]

    # Create the boxplot
    fig = go.Figure()
    for assay_label in ASSAY_ORDER:
        fig.add_trace(
            go.Box(
                y=assay_acc_dict[assay_label],
                name=assay_label,
                boxpoints="outliers",
                boxmean=True,
                fillcolor=assay_colors[assay_label],
                line_color="black",
                showlegend=False,
                marker=dict(size=2),
            )
        )

    fig.update_yaxes(range=[0.34, 1.01])

    title_text = f"{CELL_TYPE.replace('_', ' ').title()} classifier: Accuracy per assay"
    fig.update_layout(
        title=title_text,
        yaxis_title="Accuracy",
        xaxis_title="Assay",
        height=600,
        width=1000,
    )

    # Save and display the figure
    fig.write_image(logdir / f"{name}.svg")
    fig.write_image(logdir / f"{name}.png")
    fig.write_html(logdir / f"{name}.html")

    fig.show()

In [None]:
# Boxplot variant, saved in the same folder as the violin figure.
fig_flagship_ct_boxplot(
    cell_type_df=ct_full_df,
    logdir=fig_dir,
    name="ct_assay_accuracy_boxplot_10kb",
)

#### Metadata 3rd factor

Cell type classification: check each (assay, cell type) pair for enrichment in any metadata category.

- Metadata categories used:
    - harmonized_donor_life_stage
    - harmonized_donor_sex
    - harmonized_sample_disease_high
    - harmonized_biomaterial_type
    - paired_end
    - project
- Find a "3rd factor" metric. If the metadata distribution of an (assay, cell type) subclass is very different from the global distribution, the classifier could be using that information as a 3rd factor. We look at each assay specifically.
    - One score per pair; compute its Pearson correlation with the accuracy vector (one vector per assay).

In [None]:
def calculate_metadata_distribution(
    df: pd.DataFrame, columns: List[str]
) -> Dict[str, pd.Series]:
    """
    Compute, for each requested column, the share (in %) of rows holding each label.

    Missing values are counted as a label of their own, so for a non-empty DataFrame
    each column's percentages add up to 100.

    Args:
        df: A pandas DataFrame containing the data.
        columns: Names of the columns to analyze.

    Returns:
        A dict mapping each column name to a Series indexed by that column's labels,
        holding the percentage of rows with each label.
    """
    total_rows = len(df)
    return {
        column: df[column].value_counts(dropna=False).div(total_rows).mul(100)
        for column in columns
    }

In [None]:
def compare_label_ratios(
    target_distribution: Dict[str, pd.Series],
    comparison_distributions: List[Dict[str, pd.Series]],
    labels: List[str],
) -> Dict[str, pd.DataFrame]:
    """
    Compares label ratios of a target distribution against multiple comparison distributions,
    calculating the difference in percentage points for each label within each metadata category.

    Args:
        target_distribution: A dictionary of Series representing the target distribution for comparison.
        comparison_distributions: A list of dictionaries of Series, where each dictionary
                                  represents a distribution (e.g., assay, cell type, global) for comparison.
        labels: A list of labels corresponding to each distribution in `comparison_distributions`,
                used for labeling the columns in the result.

    Returns:
        A dictionary of DataFrames, one per category of `target_distribution`. Each DataFrame has
        one column "Difference_vs_<label>" per comparison and one row per label found in the target
        or in ANY of the comparisons. A label missing from a distribution counts as 0%, so no
        difference is ever NaN. A category missing from a comparison distribution is treated as
        empty. When `labels` is empty, the DataFrame is empty.
    """
    comparison_results: Dict[str, pd.DataFrame] = {}
    for category, target_series in target_distribution.items():
        difference_columns: Dict[str, pd.Series] = {}

        for label, comparison_distribution in zip(labels, comparison_distributions):
            comparison_series = comparison_distribution.get(
                category, pd.Series(dtype="float64")
            )
            # Align target and comparison on the union of their labels (absent -> 0%).
            aligned_target, aligned_comparison = target_series.align(
                comparison_series, join="outer", fill_value=0
            )
            difference_columns[f"Difference_vs_{label}"] = aligned_target - aligned_comparison

        if difference_columns:
            # Outer-join all columns. Assigning them one by one to an empty DataFrame made the
            # first column fix the row index, silently dropping labels that only appear in a
            # later comparison (e.g. present globally but absent from the assay). A label
            # missing from one column is absent from both the target and that comparison,
            # so its difference there is 0.
            comparison_df = pd.concat(difference_columns, axis=1).fillna(0)
        else:
            comparison_df = pd.DataFrame()

        comparison_results[category] = comparison_df

    return comparison_results

In [None]:
# Metadata columns checked for (assay, cell type) enrichment; this order is also the
# row order of the correlation table built further down.
metadata_categories: List[str] = [
    "harmonized_donor_life_stage",
    "harmonized_donor_sex",
    "harmonized_sample_disease_high",
    "harmonized_biomaterial_type",
    "paired_end",
    "project",
]

In [None]:
def compute_third_factor_correlation(
    ct_full_df: pd.DataFrame,
    metadata_categories: List[str],
    save_full_details: bool = False,
) -> Dict[str, float]:
    """
    Correlate a "3rd factor" metadata score with cell type classification accuracy, per assay.

    For every (assay, cell type) pair, the metadata label distribution of the pair is compared
    (difference in percentage points) to the distributions of its assay, of its cell type and
    of the whole dataset. The 3rd factor score of a pair is the minimum "Difference vs Global"
    over all labels of all analyzed categories. For each assay, the per-cell-type scores are
    then correlated (Pearson) with the per-cell-type accuracies.

    Args:
        ct_full_df (pd.DataFrame): Cell type prediction results joined with metadata. Must
            contain the ``ASSAY``, ``CELL_TYPE``, "True class" and "Predicted class" columns,
            plus every column listed in ``metadata_categories``.
        metadata_categories (List[str]): A list of metadata categories to analyze.
        save_full_details (bool): If True, write every pair comparison table (with the pair
            accuracy) to ``base_fig_dir / "flagship" / "metadata_comparison_all.csv"``.

    Returns:
        Dict[str, float]: Pearson correlation per assay. It is NaN when it cannot be computed
        (e.g. constant vectors or too few valid cell types).
    """
    global_dist = calculate_metadata_distribution(ct_full_df, metadata_categories)

    # Distributions keyed by assay label, by cell type label and by (assay, cell type).
    # NOTE(review): assay and cell type labels share this dict; assumes no label is both.
    subclass_distributions: Dict = {}
    for assay_label, sub_df in ct_full_df.groupby(ASSAY):
        subclass_distributions[assay_label] = calculate_metadata_distribution(
            sub_df, metadata_categories
        )
    for ct_label, sub_df in ct_full_df.groupby(CELL_TYPE):
        subclass_distributions[ct_label] = calculate_metadata_distribution(
            sub_df, metadata_categories
        )

    # Compare each (assay, cell type) pair to its assay, its cell type and the global dist.
    comparison_results = {}
    for (assay, cell_type), sub_df in ct_full_df.groupby([ASSAY, CELL_TYPE]):  # type: ignore
        pair_subclass_dist = calculate_metadata_distribution(sub_df, metadata_categories)
        subclass_distributions[(assay, cell_type)] = pair_subclass_dist

        comparison_results[(assay, cell_type)] = compare_label_ratios(
            target_distribution=pair_subclass_dist,
            comparison_distributions=[
                subclass_distributions[assay],
                subclass_distributions[cell_type],
                global_dist,
            ],
            labels=[assay, cell_type, "global"],
        )

    pair_dfs = {}
    pairs_3rd_factor = {}
    first_columns = ["index", "Category", "(assay, ct) subclass %"]
    for (assay, cell_type), comparisons in comparison_results.items():
        dfs_to_concat = []
        for category, df_comparison in comparisons.items():
            df_comparison.columns = [
                "Difference vs Assay",
                "Difference vs Cell Type",
                "Difference vs Global",
            ]
            # Add identifiers for the assay, cell type, and category
            df_comparison["Assay"] = assay
            df_comparison["Cell Type"] = cell_type
            df_comparison["Category"] = category
            df_comparison["(assay, ct) subclass %"] = subclass_distributions[
                (assay, cell_type)
            ][category]

            # Name the label column explicitly. pandas >= 2 names value_counts indexes after
            # the source column, so a bare reset_index() would not create an "index" column.
            dfs_to_concat.append(df_comparison.rename_axis("index").reset_index())

        # Labels absent from the pair have no subclass % -> 0
        final_df = pd.concat(dfs_to_concat, ignore_index=True).fillna(0)
        final_df = final_df[
            first_columns + [col for col in final_df.columns if col not in first_columns]
        ]
        pair_dfs[(assay, cell_type)] = final_df

        # 3rd factor score: most under-represented label of the pair vs the global dist
        pairs_3rd_factor[(assay, cell_type)] = final_df["Difference vs Global"].min()

    # Subclass accuracy per assay
    assay_labels = sorted(ct_full_df[ASSAY].unique())
    ct_labels = sorted(ct_full_df[CELL_TYPE].unique())
    assay_accuracies: Dict[str, Dict[str, float]] = {}
    for assay_label in assay_labels:
        assay_df = ct_full_df[ct_full_df[ASSAY] == assay_label]
        subclass_size = assay_df.groupby("True class").size()
        correct_counts = (
            assay_df[assay_df["True class"] == assay_df["Predicted class"]]
            .groupby("True class")
            .size()
        )
        # No correct prediction -> accuracy 0 (direct indexing raised KeyError);
        # cell type absent from the assay -> NaN.
        assay_accuracies[assay_label] = {
            ct_label: (
                correct_counts.get(ct_label, 0) / subclass_size[ct_label]
                if subclass_size.get(ct_label, 0) > 0
                else np.nan
            )
            for ct_label in ct_labels
        }

    if save_full_details:
        # Concatenate all pair DataFrames along rows, with the pair accuracy on each row
        all_pairs_df = pd.concat(pair_dfs, axis=0, ignore_index=True)
        all_pairs_df["Accuracy"] = [
            assay_accuracies.get(assay, {}).get(ct, np.nan)
            for assay, ct in zip(all_pairs_df["Assay"], all_pairs_df["Cell Type"])
        ]
        all_pairs_df.columns = [
            "Label" if x == "index" else x for x in all_pairs_df.columns
        ]
        file_path = base_fig_dir / "flagship" / "metadata_comparison_all.csv"
        file_path.parent.mkdir(parents=True, exist_ok=True)
        all_pairs_df.to_csv(file_path, index=False)

    pearson_3rd_factor = {}
    for assay, acc_dict in assay_accuracies.items():
        acc_vector = pd.Series({ct: acc_dict[ct] for ct in ct_labels})
        # Pairs without any sample have no score; NaN entries are ignored by Series.corr.
        diff_metric = pd.Series(
            {ct: pairs_3rd_factor.get((assay, ct), np.nan) for ct in ct_labels}
        )
        pearson_3rd_factor[assay] = acc_vector.corr(diff_metric, method="pearson")

    return pearson_3rd_factor

In [None]:
# Correlations using all metadata categories at once; also writes the per-pair details CSV.
compute_third_factor_correlation(
    ct_full_df=ct_full_df,
    metadata_categories=metadata_categories,
    save_full_details=True,
)

In [None]:
# One row per metadata category: Pearson correlation per assay when using only that category.
pearson_frames = []
for metadata_category in metadata_categories:
    pearson_dict = compute_third_factor_correlation(ct_full_df, [metadata_category])
    pearson_frames.append(pd.DataFrame(pearson_dict, index=[metadata_category]))

full_pearson_df = pd.concat(pearson_frames, axis=0)

# Add the max absolute correlation for each row ("Max" column) and each column ("Max" row).
full_pearson_df["Max"] = full_pearson_df.abs().max(axis=1)
max_row = full_pearson_df.abs().max(axis=0)
max_row.name = "Max"
# DataFrame.append was removed in pandas 2.0; pd.concat works with pandas 1.x and 2.x.
full_pearson_df = pd.concat([full_pearson_df, max_row.to_frame().T])

output_dir = base_fig_dir / "flagship"
output_dir.mkdir(parents=True, exist_ok=True)
full_pearson_df.to_csv(output_dir / "3rd_factor_min_diff_pearson_correlation.csv")

In [None]:
# Accuracy per (assay, cell type, donor life stage), excluding the "unknown" life stage.
# Combinations without samples are kept with Size 0 and a NaN accuracy.
assay_labels = sorted(ct_full_df[ASSAY].unique())
ct_labels = sorted(ct_full_df[CELL_TYPE].unique())
# Filter instead of list.remove("unknown"), which raised ValueError when no sample had
# that label. Every other value (including a possible NaN) is kept, as before.
life_stages = [
    life_stage
    for life_stage in ct_full_df["harmonized_donor_life_stage"].unique().tolist()
    if life_stage != "unknown"
]

acc_list = []
for assay_label in assay_labels:
    assay_df = ct_full_df[ct_full_df[ASSAY] == assay_label]
    for ct in ct_labels:
        ct_df = assay_df[assay_df[CELL_TYPE] == ct]
        for life_stage in life_stages:
            life_stage_df = ct_df[ct_df["harmonized_donor_life_stage"] == life_stage]
            acc = life_stage_df["Predicted class"].eq(life_stage_df["True class"]).mean()
            acc_list.append((assay_label, ct, life_stage, len(life_stage_df), acc))

acc_df = pd.DataFrame(
    acc_list, columns=["Assay", "Cell Type", "Life Stage", "Size", "Accuracy"]
)

In [None]:
# Save the per (assay, cell type, life stage) accuracy table next to the other flagship outputs.
acc_df.to_csv(base_fig_dir.joinpath("flagship", "assay_ct_life_stage_accuracy.csv"))

#### Input classification - HDF5 values at top SHAP features

Compare the input signal across cell types, e.g. boxplots of the 40 "input & T cell" regions vs the 9 "input & lymphocyte of B lineage" regions in these 2 cell types (hence 4 boxplots). We could also add an external cell type such as muscle as a negative control (6 boxplots).

For all input files, check the 40 and 9 regions in:
- T cell
- B cell
- muscle

In [None]:
def flagship_supp_shap_input(paper_dir: Path) -> None:
    """
    Plot HDF5 bin values of input files, per cell type, at top SHAP feature bins.

    The bin groups are read from ``features_n8.json`` (one group of bins per cell type)
    in the ``global_shap_analysis/top303/input`` results. For each bin group, one violin
    is drawn per cell type among "T cell", "lymphocyte of B lineage", "neutrophil" and
    "muscle organ". Each point is one input file (md5sum); its value is the mean of the
    file's values over the bins of the group.

    Args:
        paper_dir (Path): The directory containing the paper data (used to load metadata v2).

    Returns:
        None: Saves the figure as SVG, PNG and HTML in ``base_fig_dir / "flagship"``
        and displays it.
    """
    # Select the input files of the cell types of interest
    ct_labels = ["T cell", "lymphocyte of B lineage", "neutrophil", "muscle organ"]
    metadata_2 = MetadataHandler(paper_dir).load_metadata("v2")
    metadata_2.select_category_subsets(ASSAY, ["input"])
    metadata_2.select_category_subsets(CELL_TYPE, ct_labels)
    md5_per_ct = metadata_2.md5_per_class(CELL_TYPE)

    # Load feature bins: {cell type: [bin index, ...]}
    hdf5_val_dir = (
        base_data_dir
        / "harmonized_sample_ontology_intermediate/all_splits/harmonized_sample_ontology_intermediate_1l_3000n/10fold-dfreeze-v2/global_shap_analysis/top303/input"
    )
    feature_filepath = hdf5_val_dir / "features_n8.json"
    with open(feature_filepath, "r", encoding="utf8") as f:
        features: Dict[str, List[int]] = json.load(f)

    # Load feature values (rows: md5sums, columns: bin indices as strings)
    hdf5_val_path = hdf5_val_dir / "hdf5_values_100kb_all_none_input_4ct_features_n8.csv"
    df = pd.read_csv(hdf5_val_path, index_col=0, header=0)

    df_ct_dict = {ct: df.loc[md5_per_ct[ct]] for ct in ct_labels}

    # One group of violins per bin group, one violin per cell type within each group
    fig = go.Figure()
    for ct_label, df_ct in df_ct_dict.items():
        for i, (cell_type, top_shap_bins) in enumerate(features.items()):
            bin_columns = [str(b) for b in top_shap_bins]
            mean_bin_values_per_md5 = df_ct[bin_columns].mean(axis=1)

            # Two decimals: the previous "{val:02f}" only set a minimum width of 2 and
            # printed 6 decimals.
            hovertext = [
                f"{md5}. {val:.2f}"
                for md5, val in zip(df_ct.index, mean_bin_values_per_md5)
            ]

            fig.add_trace(
                go.Violin(
                    x=[f"Top SHAP '{cell_type}' bins"] * len(mean_bin_values_per_md5),
                    y=mean_bin_values_per_md5,
                    name=f"{ct_label} (n={len(hovertext)})",
                    points="all",
                    box_visible=True,
                    meanline_visible=True,
                    spanmode="hard",
                    fillcolor=cell_type_colors[ct_label],
                    line_color="black",
                    showlegend=i == 0,  # one legend entry per cell type
                    marker=dict(size=2),
                    hovertemplate="%{text}",
                    text=hovertext,
                    legendgroup=ct_label,
                )
            )

    fig.update_layout(
        title="Input files bin z-score values for top SHAP features",
        yaxis_title="Mean z-score across bins",
        xaxis_title="Top SHAP features groups",
        boxmode="group",
        violinmode="group",
        height=1000,
        width=1000,
    )

    logdir = base_fig_dir / "flagship"
    logdir.mkdir(parents=True, exist_ok=True)
    name = "input_feature_values_top_shap_bins_per_md5"
    fig.write_image(logdir / f"{name}.svg")
    fig.write_image(logdir / f"{name}.png")
    fig.write_html(logdir / f"{name}.html")

    fig.show()

In [None]:
# Supplementary figure: input file signal at top SHAP bins, per cell type.
flagship_supp_shap_input(paper_dir=paper_dir)