In [None]:
"""Graph results from sample_ontology_shap_ranks.py"""

# pylint: disable=import-error,redefined-outer-name,use-dict-literal

In [None]:
%load_ext autoreload
%autoreload 2

## SETUP

In [None]:
from __future__ import annotations

from pathlib import Path
from typing import Dict, Set

import pandas as pd
import plotly.graph_objects as go
from IPython.display import display  # pylint: disable=unused-import

from epi_ml.utils.notebooks.paper.paper_utilities import CELL_TYPE, IHECColorMap

In [None]:
# Root of the paper outputs; figure and data locations are resolved relative to it.
paper_dir = base_dir = Path.home() / "Projects/epiclass/output/paper"

base_fig_dir = paper_dir.joinpath("figures")
base_data_dir = base_dir.joinpath("data", "SHAP", "hg38_100kb_all_none")

In [None]:
# Bind the instance to its own name instead of overwriting the imported
# `IHECColorMap` class: rebinding the class name made this cell fail when
# re-run without restarting the kernel (the instance is not callable).
ihec_color_map = IHECColorMap(base_fig_dir)
cell_type_colors = ihec_color_map.cell_type_color_map

In [None]:
# Results directory of the CELL_TYPE classification task; stop early if it is missing.
cell_type_dir = base_data_dir.joinpath(f"{CELL_TYPE}_1l_3000n", "10fold-oversampling")
if not cell_type_dir.exists():
    raise ValueError(f"Directory {cell_type_dir} does not exist")

In [None]:
# Folder with the SHAP rank tables merged over samplings; stop early if it is missing.
ranks_folder = cell_type_dir.joinpath("shap_ranks", "merge_samplings")
if not ranks_folder.exists():
    raise ValueError(f"Directory {ranks_folder} does not exist")

In [None]:
# Folder with the assay-level ("core7") SHAP rank tables; stop early if it is missing.
ranks_folder_assay_level = cell_type_dir.joinpath(
    "global_shap_analysis", "shap_ranks", "core7"
)
if not ranks_folder_assay_level.exists():
    raise ValueError(f"Directory {ranks_folder_assay_level} does not exist")

## Loading data

In [None]:
# Load every merged-sampling median rank table, keyed by the set name parsed
# from the file name.
# Files are visited in sorted order: `Path.glob` yields entries in arbitrary
# order, which would make the dict order (and the order of equally sized sets
# in the plots built from it) vary between runs.
results: Dict[str, pd.DataFrame] = {}
for median_rank_file in sorted(ranks_folder.glob("*median_ranks*")):
    median_ranks_df = pd.read_csv(median_rank_file, sep="\t")

    # Strip the fixed prefix/suffix of the file name to recover the set name.
    filename = median_rank_file.stem
    cell_type = filename.replace("merge_samplings_", "").replace(
        "_feature_set_median_ranks", ""
    )
    results[cell_type] = median_ranks_df

# Fail fast rather than producing empty figures and tables further down.
if not results:
    raise ValueError(f"No median rank files found in {ranks_folder}")

In [None]:
# Load every assay-level median rank table, keyed by the set name parsed from
# the file name.
# Files are visited in sorted order: `Path.glob` yields entries in arbitrary
# order, which would make the dict order (and the order of equally sized sets
# in the plots built from it) vary between runs.
results_assay_level: Dict[str, pd.DataFrame] = {}
for median_rank_file in sorted(ranks_folder_assay_level.glob("*median_ranks*")):
    median_ranks_df = pd.read_csv(median_rank_file, sep="\t")

    # Strip the fixed suffix of the file name to recover the set name.
    name = median_rank_file.stem.replace("_feature_set_median_ranks", "")
    results_assay_level[name] = median_ranks_df

# Fail fast rather than producing empty figures further down.
if not results_assay_level:
    raise ValueError(f"No median rank files found in {ranks_folder_assay_level}")

## Graph median ranks

### Define functions

In [None]:
def graph_median_ranks(
    ranks: Dict[str, pd.DataFrame],
    colors: Dict[str, str],
    logdir: Path,
    name: str = "non-unique features",
) -> None:
    """Graph the average median rank of each set's important features, per row.

    For every set, the "iqr" columns are dropped and the remaining numeric
    columns are averaged row-wise. Each set is drawn as one box plus one column
    of markers (one marker per row, colored by the row's "CellType"). Sets are
    ordered on the x-axis by decreasing number of "med" columns.

    A set whose name is a key of `colors` is treated as a cell type set.
    Any other set name is expected to look like "<prefix>_<cell type>"; its
    color is looked up with the part after the first underscore, and its
    points are labeled "<Assay>_<CellType>".

    Args:
        ranks (Dict[str, pd.DataFrame]): Dictionary of set name to DataFrame,
            containing median ranks for each feature. Each DataFrame needs a
            "CellType" column, plus an "Assay" column when there are more than
            16 sets or when a set name is not a key of `colors`.
        colors (Dict[str, str]): Dictionary of cell type to color.
        logdir (Path): Directory to save the figure (html, png and svg).
        name (str): Name of the global results set, used in the title and in
            the output file names. Default is "non-unique features".

    Returns:
        None

    Raises:
        ValueError: If the Assay+CellType row order differs between DataFrames
            (only checked when that order is needed, see `ranks`).
    """
    # Above this number of sets, the row order is validated and the figure widened.
    max_sets_default_width = 16

    def _row_labels(df: pd.DataFrame) -> pd.Series:
        """Return one "<Assay>_<CellType>" label per row."""
        return df[["Assay", "CellType"]].apply("_".join, axis=1)

    fig = go.Figure()

    # Get the number of features for each set
    set_size: Dict[str, int] = {
        set_name: len([col for col in df.columns if "med" in col])
        for set_name, df in ranks.items()
    }

    # Row labels are needed for every set that is not itself a cell type.
    # They used to be computed only for more than 16 sets, so a smaller input
    # containing such a set crashed with an unbound `inner_trace_order`.
    inner_trace_order = None
    needs_row_labels = len(ranks) > max_sets_default_width or any(
        set_name not in colors for set_name in ranks
    )
    if needs_row_labels:
        inner_trace_order = _row_labels(next(iter(ranks.values())))
        if not all(
            inner_trace_order.equals(_row_labels(df)) for df in ranks.values()
        ):
            raise ValueError("Assay+Cell type order is not the same for all dataframes")

    # Sort sets by number of features (stable: ties keep the order of `ranks`)
    trace_names = []
    sorted_set = sorted(set_size.items(), key=lambda x: x[1], reverse=True)

    for i, (set_name, n_features) in enumerate(sorted_set):
        df = ranks[set_name]

        # Average median rank over all features, for each row
        df = df.drop(columns=[col for col in df.columns if "iqr" in col])
        avg_median_ranks = df.mean(axis=1, numeric_only=True)
        ct_order = df["CellType"]

        is_cell_type_set = set_name in colors
        cell_type = (
            set_name if is_cell_type_set else set_name.split("_", maxsplit=1)[1]
        )

        # for x-ticks
        trace_name = f"{set_name} ({n_features} features)"
        trace_names.append(trace_name)

        fig.add_trace(
            go.Box(
                x=[i] * len(avg_median_ranks),
                y=avg_median_ranks,
                name=trace_name,
                boxpoints=False,
                boxmean=True,
                line=dict(color=colors[cell_type]),
                showlegend=True,
            )
        )

        # One label per point: the cell type for cell type sets, otherwise
        # the "<Assay>_<CellType>" label of the row.
        point_labels = ct_order if is_cell_type_set else inner_trace_order

        # The point matching the set itself is drawn larger.
        marker_sizes = [6 if label == set_name else 3 for label in point_labels]
        hovertext = [
            f"{label}: {val}" for label, val in zip(point_labels, avg_median_ranks)
        ]

        # Markers are shifted left of the box so they do not overlap it.
        fig.add_trace(
            go.Scatter(
                x=[i - 0.4] * len(avg_median_ranks),
                y=avg_median_ranks,
                name=trace_name,
                mode="markers",
                marker=dict(
                    color=[colors[ct] for ct in ct_order],
                    size=marker_sizes,
                ),
                hoverinfo="text",
                hovertext=hovertext,
                showlegend=False,
            )
        )

    # Modify integer ticks to set names
    fig.update_xaxes(tickvals=list(range(len(set_size))), ticktext=trace_names)

    width = 800
    if len(set_size) > max_sets_default_width:
        width = 50 * len(set_size)
    fig.update_layout(
        title=f"Average median SHAP rank for each important cell type features<br>({name})",
        xaxis_title="Reference cell type",
        yaxis_title="Average median SHAP rank",
        height=800,
        width=width,
    )

    # Save
    figname = f"global_shap_ranks_{name}"
    fig.write_html(logdir / f"{figname}.html")
    fig.write_image(logdir / f"{figname}.png")
    fig.write_image(logdir / f"{figname}.svg")

    fig.show()

In [None]:
def return_unique_features(ranks: Dict[str, pd.DataFrame]) -> Dict[str, pd.DataFrame]:
    """Return rank columns for unique subset features only.

    A feature column is any column whose name contains "med". It is unique to
    a set when no other set in `ranks` has a column with the same name.

    Args:
        ranks (Dict[str, pd.DataFrame]): Dictionary of set name to DataFrame, containing median ranks for each important feature in each set.

    Returns:
        Dict[str, pd.DataFrame]: New dictionary with the same keys. Each
        DataFrame keeps only the "CellType" column (when present) and the
        feature columns unique to its set; "iqr" columns and every other
        column (e.g. "Assay") are dropped. The input is not modified.
    """
    # Feature columns are read from `ranks` itself. This used to read the
    # notebook-level `results` variable, which broke the function for any
    # other input.
    features_by_set: Dict[str, Set[str]] = {
        set_name: {col for col in df.columns if "med" in col}
        for set_name, df in ranks.items()
    }

    # Number of sets in which each feature column appears.
    set_count: Dict[str, int] = {}
    for features in features_by_set.values():
        for feature in features:
            set_count[feature] = set_count.get(feature, 0) + 1

    unique_features_results: Dict[str, pd.DataFrame] = {}
    for set_name, df in ranks.items():
        kept_columns = {
            feature for feature in features_by_set[set_name] if set_count[feature] == 1
        }
        # Row identifier, always kept; a missing "CellType" column no longer raises.
        kept_columns.add("CellType")

        col_to_drop = [
            col for col in df.columns if "iqr" in col or col not in kept_columns
        ]
        unique_features_results[set_name] = df.drop(columns=col_to_drop)

    return unique_features_results

### Graph merge_samplings results

In [None]:
# Plot every important feature of each set, shared features included.
graph_median_ranks(
    ranks=results,
    colors=cell_type_colors,
    logdir=ranks_folder,
    name="non-unique (all) features",
)

In [None]:
# Restrict each set to the rank columns returned by `return_unique_features`.
unique_features_results = return_unique_features(ranks=results)

In [None]:
# Same input as the "(all)" plot; only the title and output file names differ.
graph_median_ranks(
    ranks=results,
    colors=cell_type_colors,
    logdir=ranks_folder,
    name="non-unique features",
)

In [None]:
# Plot the unique-feature subsets computed from `results`.
graph_median_ranks(
    ranks=unique_features_results,
    colors=cell_type_colors,
    logdir=ranks_folder,
    name="unique features",
)

### Graph assay+ct subsets results

In [None]:
# remove non-core rows: drop every row whose "Assay" value matches "rna" or "wgb"
for set_name in list(results_assay_level):
    df = results_assay_level[set_name]
    is_non_core = df["Assay"].str.contains("rna|wgb")
    results_assay_level[set_name] = df[~is_non_core]

In [None]:
# Plot the assay-level sets after the non-core rows were removed.
graph_median_ranks(
    ranks=results_assay_level,
    colors=cell_type_colors,
    logdir=ranks_folder_assay_level,
    name="non-unique (all) features",
)

In [None]:
# unique_features_results_assay_level = return_unique_features(results_assay_level)

In [None]:
# graph_median_ranks(
#     unique_features_results_assay_level, cell_type_colors, logdir=ranks_folder_assay_level, name="unique features"
# )

## Find the 3 most important features (lowest median rank) per cell type

In [None]:
# Path to the headerless, tab-separated BED file of regions
# (despite its name, `feature_mapping_dir` points to a file).
feature_mapping_dir = paper_dir.joinpath("data", "regions", "hg38.noy.100kb.bed")

feature_mapping = pd.read_csv(feature_mapping_dir, header=None, sep="\t")
feature_mapping.columns = ["chrom", "start", "end"]

In [None]:
# Collect (cell type, region, median rank, rank IQR) for the three features
# with the smallest median rank in each set's own cell type row.
top3_per_cell_type = []
for cell_type, df in results.items():
    # Row of this set's table whose "CellType" is the set itself.
    own_row = df[df["CellType"] == cell_type].drop(columns=["CellType"])

    cols_median = [col for col in own_row.columns if "med" in col]
    cols_iqr = [col for col in own_row.columns if "iqr" in col]

    median_values = pd.Series([own_row[col].values[0] for col in cols_median])
    iqr_values = pd.Series([own_row[col].values[0] for col in cols_iqr])

    # Positions of the three smallest median ranks.
    top_3_idx = median_values.argsort()[0:3].tolist()

    top3_med = median_values[top_3_idx].tolist()
    # Assumes "iqr" columns follow the same feature order as "med" columns - TODO confirm.
    top3_iqr = iqr_values[top_3_idx].to_list()

    formatted_regions = []
    for idx in top_3_idx:
        # The second "_"-separated token of the column name is used as the
        # row position of the feature in `feature_mapping`.
        feature_position = int(cols_median[idx].split("_")[1])
        chrom, start, end = feature_mapping.iloc[feature_position, :]
        formatted_regions.append(f"{chrom}:{start}-{end}")

    top3_per_cell_type.extend(
        (cell_type, feature, med, iqr)
        for feature, med, iqr in zip(formatted_regions, top3_med, top3_iqr)
    )

In [None]:
# One row per (cell type, top feature) pair.
top3_columns = ["cell_type", "feature", "rank median", "rank IQR"]
df_top3 = pd.DataFrame(data=top3_per_cell_type, columns=top3_columns)

In [None]:
# Save the table next to the merged-sampling rank files.
top3_csv_path = ranks_folder / "top3_median_per_cell_type.csv"
df_top3.to_csv(top3_csv_path, index=False)