In [None]:
"""Graph results from sample_ontology_shap_ranks.py"""
# pylint: disable=import-error,redefined-outer-name,use-dict-literal

In [None]:
# Reload edited project modules (e.g. the epi_ml helpers imported below)
# automatically before each cell runs.
%load_ext autoreload
%autoreload 2

## SETUP

In [None]:
from __future__ import annotations

from pathlib import Path
from typing import Dict

import numpy as np  # pylint: disable=unused-import
import pandas as pd
import plotly.graph_objects as go
from IPython.display import display  # pylint: disable=unused-import

from epi_ml.utils.notebooks.paper.paper_utilities import CELL_TYPE, IHECColorMap

In [None]:
# Root of all paper outputs (figures and training results live underneath it).
base_dir = Path.home().joinpath("Projects", "epiclass", "output", "paper")
paper_dir = base_dir

base_fig_dir = paper_dir.joinpath("figures")
base_data_dir = base_dir.joinpath(
    "data", "training_results", "dfreeze_v2", "hg38_100kb_all_none"
)

In [None]:
# Bind the instance to its own name: rebinding `IHECColorMap` would shadow the
# imported class, so re-running this cell would no longer construct a color map.
ihec_color_map = IHECColorMap(base_fig_dir)
cell_type_colors = ihec_color_map.cell_type_color_map

In [None]:
# Output folder of the `{CELL_TYPE}_1l_3000n` / "10fold-oversampling" run;
# stop immediately if it is missing since every later cell reads from it.
cell_type_dir = base_data_dir.joinpath(
    f"{CELL_TYPE}_1l_3000n", "10fold-oversampling"
)
if not cell_type_dir.exists():
    raise ValueError(f"Directory {cell_type_dir} does not exist")

In [None]:
# Folder holding the "merge_samplings" SHAP rank tables loaded below; the
# figures produced by this notebook are written there too.
ranks_folder = cell_type_dir.joinpath(
    "global_shap_analysis", "shap_ranks", "merge_samplings"
)
if not ranks_folder.exists():
    raise ValueError(f"Directory {ranks_folder} does not exist")

## Loading data

In [None]:
# One tab-separated table of median ranks per reference cell type.
# Files are read in sorted order: `Path.glob` yields them in arbitrary
# filesystem order, which would make the order of `results` irreproducible.
results: Dict[str, pd.DataFrame] = {}
for median_rank_file in sorted(ranks_folder.glob("*median_ranks*")):
    # "merge_samplings_<cell type>_feature_set_median_ranks" -> "<cell type>"
    cell_type = median_rank_file.stem.replace("merge_samplings_", "").replace(
        "_feature_set_median_ranks", ""
    )
    results[cell_type] = pd.read_csv(median_rank_file, sep="\t")

# Fail loudly instead of silently plotting empty figures further down.
if not results:
    raise ValueError(f"No '*median_ranks*' files found in {ranks_folder}")

## Graph median ranks

In [None]:
def graph_median_ranks(
    ranks: Dict[str, pd.DataFrame],
    colors: Dict[str, str],
    logdir: Path,
    name: str = "non-unique features",
) -> None:
    """Graph the average median SHAP rank of each cell type's important features.

    For every reference cell type (one DataFrame in `ranks`), the median-rank
    columns are averaged row-wise, giving one value per row; rows are labelled
    by the DataFrame's "CellType" column. Each reference cell type is drawn as
    a box plot of those averages, next to a strip of the individual values in
    which the reference cell type's own row is drawn with a larger marker.

    Args:
        ranks (Dict[str, pd.DataFrame]): Reference cell type -> DataFrame with a
            "CellType" column and one median-rank column per feature (column
            names containing "med"). Columns containing "iqr" are ignored.
        colors (Dict[str, str]): Cell type -> color. Must contain every key of
            `ranks` and every value found in the "CellType" columns.
        logdir (Path): Directory where the HTML, PNG and SVG figures are saved
            (created if missing).
        name (str): Name of the result set, used in the title and file names.
            Default is "non-unique features".

    Returns:
        None
    """
    fig = go.Figure()

    # Median-rank columns of each reference cell type. The same list is used for
    # the feature count in the tick label and for the average, so the two can
    # never disagree (e.g. if a stray numeric column is present).
    feature_cols: Dict[str, list] = {
        ct: [col for col in df.columns if "med" in col and "iqr" not in col]
        for ct, df in ranks.items()
    }

    # Largest feature sets first; ties broken by name so the x order does not
    # depend on dict insertion order.
    ordered_cell_types = sorted(
        feature_cols, key=lambda ct: (-len(feature_cols[ct]), ct)
    )

    trace_names = []
    for i, cell_type in enumerate(ordered_cell_types):
        df = ranks[cell_type]
        cols = feature_cols[cell_type]

        # Average median rank over all features, one value per row.
        avg_median_ranks = df[cols].mean(axis=1, numeric_only=True)
        ct_order = df["CellType"]

        # for x-ticks
        trace_name = f"{cell_type} ({len(cols)} features)"
        trace_names.append(trace_name)

        fig.add_trace(
            go.Box(
                x=[i] * len(avg_median_ranks),
                y=avg_median_ranks,
                name=trace_name,
                boxpoints=False,
                boxmean=True,
                line=dict(color=colors[cell_type]),
                showlegend=True,
            )
        )

        # Individual values, offset to the left of the box; the reference cell
        # type's own row is emphasized with a bigger marker.
        hovertext = [f"{ct}: {val}" for ct, val in zip(ct_order, avg_median_ranks)]
        fig.add_trace(
            go.Scatter(
                x=[i - 0.4] * len(avg_median_ranks),
                y=avg_median_ranks,
                name=trace_name,
                mode="markers",
                marker=dict(
                    size=[6 if ct == cell_type else 3 for ct in ct_order],
                    color=[colors[ct] for ct in ct_order],
                ),
                hoverinfo="text",
                hovertext=hovertext,
                showlegend=False,
            )
        )

    # Replace integer x positions with the cell type labels.
    fig.update_xaxes(
        tickvals=list(range(len(ordered_cell_types))), ticktext=trace_names
    )

    fig.update_layout(
        title=f"Average median SHAP rank for each important cell type features<br>({name})",
        xaxis_title="Reference cell type",
        yaxis_title="Average median SHAP rank",
        height=800,
    )

    # Save
    logdir.mkdir(parents=True, exist_ok=True)
    figname = f"global_shap_ranks_{name}"
    fig.write_html(logdir / f"{figname}.html")
    fig.write_image(logdir / f"{figname}.png")
    fig.write_image(logdir / f"{figname}.svg")

    fig.show()

In [None]:
# Every feature of each reference cell type, including those shared with others.
graph_median_ranks(
    ranks=results,
    colors=cell_type_colors,
    logdir=ranks_folder,
    name="non-unique features",
)

In [None]:
# Median-rank columns only (IQR columns dropped). Build a new dict instead of
# overwriting `results`, so this cell is idempotent and `results` still holds
# the data the cells above were run with.
median_results: Dict[str, pd.DataFrame] = {
    ct: df.drop(columns=[col for col in df.columns if "iqr" in col])
    for ct, df in results.items()
}

# Features of each cell type: every column except the "CellType" label.
# Selecting by name (rather than `df.columns[1:]`) does not rely on column order.
features_by_ct: Dict[str, set[str]] = {
    ct: {col for col in df.columns if col != "CellType"}
    for ct, df in median_results.items()
}

# A feature is unique to a cell type when no other cell type has it.
unique_ct_features: Dict[str, set[str]] = {
    ct: features.difference(
        *(other for other_ct, other in features_by_ct.items() if other_ct != ct)
    )
    for ct, features in features_by_ct.items()
}

# Restrict each cell type's results to its unique features, keeping the
# "CellType" column (read by graph_median_ranks) in its original position.
unique_features_results: Dict[str, pd.DataFrame] = {
    ct: df.loc[
        :,
        [col for col in df.columns if col == "CellType" or col in unique_ct_features[ct]],
    ]
    for ct, df in median_results.items()
}

In [None]:
# Only the features that no other reference cell type has.
graph_median_ranks(
    ranks=unique_features_results,
    colors=cell_type_colors,
    logdir=ranks_folder,
    name="unique features",
)