In [None]:
"""Code related to SHAP analyses for flagship paper + supplementary figures."""
# pylint: disable=line-too-long, redefined-outer-name, import-error, duplicate-code, unreachable, unused-argument, use-dict-literal, unused-import

How the relevant files were downloaded:

~~~bash
base_path="/home/rabyj/scratch/epilap-logs/epiatlas-dfreeze-v2.1/hg38_100kb_all_none/harmonized_sample_ontology_intermediate_1l_3000n/10fold-oversampling"
rsync --info=progress2 -aR narval:${base_path}/./split*/*.md5 .

# rsync --info=progress2 -aR --exclude "*.npz" --exclude "analysis_n*_f80.00/" narval:${base_path}/./split*/shap .
rsync --info=progress2 -aR narval:${base_path}/./split*/shap/*background*.npz .
~~~

In [None]:
# Reload imported modules (e.g. edited epi_ml code) automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
import json
import tarfile
from collections import Counter, defaultdict
from math import floor
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from scipy.special import softmax

from epi_ml.core.data_source import EpiDataSource
from epi_ml.utils.bed_utils import bed_ranges_to_bins, read_bed_to_ranges
from epi_ml.utils.notebooks.paper.paper_utilities import ASSAY, CELL_TYPE, MetadataHandler

In [None]:
# Name of the cancer classification target (same naming scheme as CELL_TYPE);
# used to locate that classifier's training results directory.
CANCER = "harmonized_sample_cancer_high"

In [None]:
# Root of the paper outputs; `base_dir` and `paper_dir` refer to the same path object.
paper_dir = Path.home() / "Projects" / "epiclass" / "output" / "paper"
base_dir = paper_dir

# Training results for the 100kb-resolution hg38 dataset (data freeze v2).
base_data_dir = base_dir.joinpath(
    "data", "training_results", "dfreeze_v2", "hg38_100kb_all_none"
)

In [None]:
# Chromosome sizes for hg38 (file name suggests chrY excluded).
chromsize_path = paper_dir.joinpath("data", "chromsizes", "hg38.noy.chrom.sizes")
chroms = EpiDataSource.load_external_chrom_file(chromsize_path)

In [None]:
# Load the v2 metadata and index it by md5sum, keeping md5sum as a regular column too.
metadata_handler = MetadataHandler(paper_dir)
metadata = metadata_handler.load_metadata("v2")
metadata_df = pd.DataFrame.from_records(list(metadata.datasets)).set_index("md5sum")
metadata_df["md5sum"] = metadata_df.index
del metadata  # only the DataFrame is used from here on

In [None]:
# 10-fold (oversampled) training results of the 1-layer/3000-node models, one per target.
cell_type_dir = base_data_dir.joinpath(f"{CELL_TYPE}_1l_3000n", "10fold-oversampling")
cancer_dir = base_data_dir.joinpath(f"{CANCER}_1l_3000n", "10fold-oversampling")

In [None]:
# Fail early, naming the offending path, if a results directory is missing.
for folder in (cell_type_dir, cancer_dir):
    assert folder.is_dir(), f"Expected results directory not found: {folder}"

### Background selection details

In [None]:
def compare_all_training_vs_background(parent_dir: Path):
    """Compare the training data to the SHAP background data for each split.

    Work in progress: for every ``split*`` directory this currently only checks
    that exactly one training ``*training*.md5`` list and one
    ``shap/*background*.npz`` archive exist, and that the background md5s are a
    subset of the training md5s. The comparison table is not implemented yet.

    Args:
        parent_dir: The directory containing each split/fold directory.

    Raises:
        ValueError: If a split does not contain exactly one training or
            background file, or if the background has md5s absent from training.
        NotImplementedError: Always, once every split passes the sanity checks.
    """
    split_composition_dfs = {}  # pylint: disable=unused-variable
    # Sorted so splits are always checked (and reported) in the same order.
    for split_dir in sorted(parent_dir.glob("split*")):
        if not split_dir.is_dir():
            continue
        split = split_dir.name  # pylint: disable=unused-variable

        training_data_paths = list(split_dir.glob("*training*.md5"))
        if len(training_data_paths) != 1:
            raise ValueError(
                f"Expected exactly one training data file in {split_dir}, found {len(training_data_paths)}"
            )
        training_data_path = training_data_paths[0]

        background_data_paths = list(split_dir.glob("shap/*background*.npz"))
        if len(background_data_paths) != 1:
            raise ValueError(
                f"Expected exactly one background data file in {split_dir}, found {len(background_data_paths)}"
            )
        background_data_path = background_data_paths[0]

        # dtype=str keeps md5s as text even if pandas would otherwise infer numbers.
        training_md5s = set(
            pd.read_csv(training_data_path, index_col=False, header=None, dtype=str)[
                0
            ].values
        )
        # The context manager closes the file handle held by the .npz archive.
        with np.load(background_data_path) as background_data:
            background_md5s = set(background_data["background_md5s"])

        # Sanity check: background data is a subset of training data
        diff = background_md5s - training_md5s
        if diff:
            raise ValueError(
                f"Background data is not a subset of the training data in {split_dir}: {len(diff)} md5s unique to background data."
            )

        # TODO: for diff ratio, do (ratio background - ratio training) * 100 -> positive if background has more
        # do assay / cell type / output class ratios
        # table should have full number + ratios + diff ratio for each metadata category
        # do for each split, save to a file, then average across splits

    raise NotImplementedError("Finish this function")

In [None]:
# for folder in [cell_type_dir, cancer_dir]:
#     compare_all_training_vs_background(folder)

### SHAP values ranking

SHAP rank significance analysis (sample ontology)

- Gather the rank of every bin for each sample's output class (20k x 30k matrix).
- For each 1% chunk of the highest values, from the top down to 10%, compute the mean SHAP value (20k x 10 matrix).
- Create a graph of the ratio between consecutive chunks (0-1% / 1-2%, 1-2% / 2-3%, etc., up to 10%) for each sample: one violin or box plot per ratio (9 plots for 10 chunks), 20k points per plot.

Then, with flagship cell GO, get the rank distribution of features "unique" to group 1 vs. group 2 (T cell, neutrophil: H3K27ac vs. H3K27me3).

#### Compute SHAP rankings / 1% chunk means

In [None]:
def load_data(shap_dir: Path) -> Tuple[Dict, Dict]:
    """Load the SHAP evaluation and background archives from ``shap_dir``.

    Exactly one ``*.npz`` file whose name contains "evaluation" and exactly one
    whose name contains "background" must be present. Each archive is fully
    read into a plain dict and closed before returning.

    Args:
        shap_dir (Path): Directory containing the SHAP data files.

    Returns:
        tuple: (evaluation results, background data), each mapping array name -> array.

    Raises:
        ValueError: If there is not exactly one evaluation or one background file.
    """
    # Find all npz files once
    npz_files = list(shap_dir.glob("*.npz"))
    eval_files = [f for f in npz_files if "evaluation" in f.name]
    background_files = [f for f in npz_files if "background" in f.name]

    if len(eval_files) != 1:
        raise ValueError(
            f"Expected one evaluation file, found {len(eval_files)} in {shap_dir}"
        )
    if len(background_files) != 1:
        raise ValueError(
            f"Expected one background file, found {len(background_files)} in {shap_dir}"
        )

    # NOTE: allow_pickle=True can execute arbitrary code from pickled objects;
    # only use on trusted, self-generated files.
    # `with` closes each archive's file handle once its arrays are copied into a dict.
    with np.load(eval_files[0], allow_pickle=True) as eval_npz:
        eval_results = dict(eval_npz)
    with np.load(background_files[0], allow_pickle=True) as background_npz:
        background_data = dict(background_npz)

    return eval_results, background_data


def compute_mean_shap_values(
    shap_matrices: List[np.ndarray],
    md5_indices: List[str],
    classes: Dict[str, int],
    metadata_df: pd.DataFrame,
    class_column: Optional[str] = None,
) -> Dict[str, Dict]:
    """Rank SHAP magnitudes per sample and average them over the top ten 1% chunks.

    For each sample, the SHAP vector of the sample's own class (read from
    ``metadata_df[class_column]``) is taken in absolute value and sorted from
    most to least important. It is then averaged over ten consecutive 1%-wide
    chunks (top 0-1%, 1-2%, ..., 9-10%). The same chunk means are also computed
    on the softmax of the absolute SHAP values, allowing comparison between raw
    and transformed importance.

    Args:
        shap_matrices (List[np.ndarray]): One matrix per class index;
            ``shap_matrices[c][i]`` is the SHAP vector of sample ``md5_indices[i]`` for class ``c``.
        md5_indices (List[str]): md5sum of each evaluated sample, in SHAP matrix row order.
        classes (dict): Mapping of class names to class indices in ``shap_matrices``.
        metadata_df (pd.DataFrame): Metadata indexed by md5sum, containing ``class_column``.
        class_column (str, optional): Metadata column holding each sample's class.
            Defaults to ``CELL_TYPE``.

    Returns:
        Dict[str, Dict]: ``{md5: {"ranks", "mean_10perc_vals", "mean_10perc_vals_softmax"}}``.
            "ranks" holds feature indices ordered from largest to smallest |SHAP|.
            The other two entries each hold 10 chunk means.

    Raises:
        ValueError: If there are fewer than 100 features (1% chunks would be empty).
    """
    if class_column is None:
        class_column = CELL_TYPE

    input_size = len(shap_matrices[0][0])
    chunk_width = input_size // 100  # number of features in one 1% chunk
    if chunk_width == 0:
        raise ValueError(
            f"Need at least 100 features to form 1% chunks, got {input_size}."
        )
    chunks_1perc_idx = [(chunk_width * i, chunk_width * (i + 1)) for i in range(10)]

    shap_details = {}
    for md5_idx, md5 in enumerate(md5_indices):
        sample_class: str = metadata_df.at[md5, class_column]  # type: ignore
        class_idx = classes[sample_class]

        sample_shaps = np.abs(shap_matrices[class_idx][md5_idx])  # magnitude only
        softmax_shaps = softmax(sample_shaps)
        ranks = np.argsort(sample_shaps)[::-1]  # descending order, greatest to smallest

        sorted_shap_vals = sample_shaps[ranks]
        sorted_softmax_shap_vals = softmax_shaps[ranks]

        shap_details[md5] = {
            "ranks": ranks,
            "mean_10perc_vals": [
                np.mean(sorted_shap_vals[start:stop], dtype=np.float64)
                for start, stop in chunks_1perc_idx
            ],
            "mean_10perc_vals_softmax": [
                np.mean(sorted_softmax_shap_vals[start:stop], dtype=np.float64)
                for start, stop in chunks_1perc_idx
            ],
        }

    return shap_details

In [None]:
def create_shap_tables(shap_details: Dict[str, Dict]) -> Dict[str, pd.DataFrame]:
    """Turn per-sample SHAP details into one DataFrame per detail category.

    Args:
        shap_details (Dict[str, Dict]): Per-sample details, formatted as
            {md5: {"ranks": vals, "mean_10perc_vals": vals, "mean_10perc_vals_softmax": vals}}

    Returns:
        Dict[str, pd.DataFrame]: Tables keyed by category, each indexed by md5.
            "ranks" is int32. "mean_10perc_vals" and "mean_10perc_vals_softmax"
            are float64, with one column per 1% chunk.
    """

    def _per_sample_frame(key: str, dtype: str) -> pd.DataFrame:
        # One row per md5; columns are the successive entries of details[key].
        return pd.DataFrame.from_dict(
            data={md5: info[key] for md5, info in shap_details.items()},
            orient="index",
            dtype=dtype,
        )

    chunk_labels = [f"mean(top {pct}% to {pct + 1}%)" for pct in range(10)]

    ranks_table = _per_sample_frame("ranks", "int32")

    means_table = _per_sample_frame("mean_10perc_vals", "float64")
    means_table.columns = chunk_labels

    softmax_means_table = _per_sample_frame("mean_10perc_vals_softmax", "float64")
    softmax_means_table.columns = means_table.columns

    return dict(
        [
            ("ranks", ranks_table),
            ("mean_10perc_vals", means_table),
            ("mean_10perc_vals_softmax", softmax_means_table),
        ]
    )

In [None]:
def save_shap_tables(
    tables: Dict[str, pd.DataFrame], output_dir: Path, verbose: bool = True
):
    """Write each SHAP table into ``output_dir``.

    The "ranks" table goes to a compressed ``shap_abs_ranks.npz`` archive with
    ``index``, ``columns`` and ``values`` arrays. Every other table is written
    as ``shap_table_<name>.csv``.

    Args:
        tables (Dict[str, pd.DataFrame]): Tables keyed by name.
        output_dir (Path): Destination directory (not created here).
        verbose (bool): Print the path of each written file.
    """
    for table_name, frame in tables.items():
        if table_name != "ranks":
            csv_path = output_dir / f"shap_table_{table_name}.csv"
            frame.to_csv(csv_path)
            if verbose:
                print(f"Saved SHAP details to {csv_path}")
            continue

        npz_path = output_dir / "shap_abs_ranks.npz"
        np.savez_compressed(
            npz_path,
            index=frame.index.values,
            columns=frame.columns.values,
            values=frame.values,
        )
        if verbose:
            print(f"Saved SHAP ranks to {npz_path}")

In [None]:
# Split whose SHAP details are computed / loaded in the cells below.
split_name = "split0"

In [None]:
# shap_details = defaultdict(dict)
# for shap_dir in cell_type_dir.glob(f"{split_name}/shap"):
#     if not shap_dir.is_dir():
#         continue

#     try:
#         eval_results, background_data = load_data(shap_dir)
#     except ValueError as e:
#         print(e)
#         continue

#     classes = {class_name: int(idx) for idx, class_name in background_data["classes"]}

#     shap_matrices = eval_results["shap_values"]

#     shap_details.update(
#         compute_mean_shap_values(
#             shap_matrices, eval_results["evaluation_md5s"], classes, metadata_df
#         )
#     )

# del shap_matrices, eval_results, background_data


# shap_tables = create_shap_tables(shap_details)


# output_dir = cell_type_dir / "global_shap_analysis"
# if not output_dir.exists():
#     raise ValueError(f"Output directory {output_dir} does not exist.")

# output_dir = output_dir / f"{split_name}_details"
# output_dir.mkdir(exist_ok=True)
# save_shap_tables(shap_tables, output_dir)

# print(f"Processed {len(shap_details)} samples.")

#### Graph ratios of computed means

In [None]:
def graph_shap_10perc(shap_table_mean_path: Path, output_dir: Path):
    """Plot ratios of mean SHAP values between consecutive 1% segments (top 10).

    Reads a table of per-sample segment means (one row per sample, one column
    per 1% segment). It draws one violin per pair of adjacent columns, showing
    ``column[i] / column[i + 1]`` for every sample. The figure is saved as HTML,
    PNG and SVG in ``output_dir``, named after the input file stem, and then
    displayed.

    Args:
        shap_table_mean_path: CSV of segment means with the sample id as first column.
        output_dir: Directory where the figure files are written.
    """
    df = pd.read_csv(shap_table_mean_path, index_col=0)
    fig = go.Figure()
    # Pair each segment with the next one; the last segment has no successor.
    for idx in range(len(df.columns) - 1):
        col, next_col = df.columns[idx], df.columns[idx + 1]
        ratio = df.iloc[:, idx] / df.iloc[:, idx + 1]
        fig.add_trace(
            go.Violin(
                y=ratio,
                name=f"Ratio {col} / {next_col}",
                points="all",
                marker=dict(size=1),
                hovertemplate="%{text}",
                text=[f"{md5}: {value}" for md5, value in zip(df.index, ratio)],
                box_visible=True,
                meanline_visible=True,
                spanmode="hard",
            )
        )
    fig.update_layout(
        title="Ratio of mean SHAP values between 1% segments",
        xaxis_title="1% segment",
        yaxis_title="Ratio",
    )

    # Save figure
    output = output_dir / shap_table_mean_path.stem

    fig.write_html(f"{output}.html")
    fig.write_image(f"{output}.png")
    fig.write_image(f"{output}.svg")

    fig.show()

In [None]:
# mean_vals_path = cell_type_dir / "global_shap_analysis" / "shap_table_mean_10perc_vals_softmax.csv"
# if not mean_vals_path.exists():
#     raise FileNotFoundError(mean_vals_path)
# graph_shap_10perc(mean_vals_path, output_dir)

In [None]:
def beds_to_bins(
    feature_beds_tar: Path,
    chroms: List[Tuple[str, int]],
    cell_type_keys: Tuple[str, ...] = ("T_cell", "neutrophil"),
    assay_keys: Tuple[str, ...] = ("h3k27ac", "h3k27me3"),
    resolution: int = 100 * 1000,
) -> Dict[str, List[int]]:
    """Convert selected .bed files from a tar.gz archive into genomic bins.

    A member is kept when it is a regular file ending in ``.bed`` and its path
    contains at least one of ``cell_type_keys`` and at least one of
    ``assay_keys``. The defaults select T cell / neutrophil files with
    h3k27ac or h3k27me3 markers.

    Args:
        feature_beds_tar (Path): The path to the tar.gz file containing .bed files.
        chroms (List[Tuple[str, int]]): Chromosome names and their lengths.
        cell_type_keys: Substrings, at least one of which must appear in the member path.
        assay_keys: Substrings, at least one of which must appear in the member path.
        resolution: Bin size in base pairs passed to ``bed_ranges_to_bins``.

    Returns:
        Dict[str, List[int]]: Keys are the bed file stems with "_feature" removed.
            Values are the bins returned by ``bed_ranges_to_bins`` for that file.
    """
    feature_bins = {}

    with tarfile.open(feature_beds_tar, "r:gz") as tar:
        for member in tar.getmembers():
            member_name = member.name
            is_selected = (
                member.isfile()  # extractfile() returns None for non-file members
                and member_name.endswith(".bed")
                and any(key in member_name for key in cell_type_keys)
                and any(key in member_name for key in assay_keys)
            )
            if not is_selected:
                continue

            # NOTE(review): extractfile yields a binary file object; assumes
            # read_bed_to_ranges accepts it (as in the original code).
            with tar.extractfile(member) as bed_file:  # type: ignore
                bed_ranges = read_bed_to_ranges(bed_file)
                bed_bins = bed_ranges_to_bins(
                    bed_ranges, chroms=chroms, resolution=resolution
                )
            feature_name = Path(member_name).stem.replace("_feature", "")
            feature_bins[feature_name] = bed_bins

    return feature_bins

In [None]:
# ((metadata_df[CELL_TYPE] == "T cell") & (metadata_df[ASSAY] == "h3k27ac")).value_counts()

In [None]:
# metadata_df[CELL_TYPE].value_counts()

#### Graph ranks of relevant important features from previous analyses

In [None]:
# Bins covered by the selected T cell / neutrophil feature .bed files (top-303 analysis).
feature_beds_tar = cell_type_dir.joinpath(
    "global_shap_analysis", "select_beds_top303.tar.gz"
)
feature_bins_dict = beds_to_bins(feature_beds_tar, chroms)

In [None]:
# feature_bins_dict.keys()

In [None]:
print("Loading rank values.")
output_dir = cell_type_dir.joinpath("global_shap_analysis", f"{split_name}_details")
rank_file = output_dir / "shap_abs_ranks.npz"
print(rank_file)

# allow_pickle: the md5 index is stored as an object array by save_shap_tables.
with np.load(rank_file, allow_pickle=True) as rank_npz:
    rank_values, rank_md5s, rank_bins = (
        rank_npz["values"],
        rank_npz["index"],
        rank_npz["columns"],
    )

In [None]:
def graph_ranks_sample_ontology(
    rank_file: Path, feature_bins: List[int], output_dir: Path, metadata_df: pd.DataFrame
):
    """Box-plot the SHAP rank of selected feature bins for T cell / neutrophil samples.

    Samples are grouped by cell type ("T cell", "neutrophil") x assay
    ("h3k27ac", "h3k27me3"). For each group and each bin in ``feature_bins``,
    one box shows that bin's rank across the group's samples. Rank 0 is the
    bin with the largest |SHAP|.

    Args:
        rank_file: ``shap_abs_ranks.npz`` file (``values`` / ``index`` arrays).
            If None, the module-level ``rank_values`` / ``rank_md5s`` arrays
            already loaded in the notebook are used. This avoids holding a
            second copy of the matrix in memory.
        feature_bins: Bin indices to plot.
        output_dir: Currently unused; kept for interface compatibility.
        metadata_df: Metadata with "md5sum", CELL_TYPE and ASSAY columns.
    """
    if rank_file is None:
        all_rank_values, all_rank_md5s = rank_values, rank_md5s
    else:
        with np.load(rank_file, allow_pickle=True) as rank_npz:
            all_rank_values = rank_npz["values"]
            all_rank_md5s = rank_npz["index"]

    # metadata groups: t-cell/neutrophil with h3k27ac/h3k27me3
    cell_types = ["T cell", "neutrophil"]
    assays = ["h3k27ac", "h3k27me3"]
    md5s_dict = {
        f"{cell}_{assay}_md5s": list(
            metadata_df[(metadata_df[CELL_TYPE] == cell) & (metadata_df[ASSAY] == assay)][
                "md5sum"
            ].values
        )
        for cell in cell_types
        for assay in assays
    }

    # One boxplot group per metadata group, with the sample ranks for each selected bin.
    fig = go.Figure()
    for name, md5s in md5s_dict.items():
        shared_ranks = all_rank_values[np.isin(all_rank_md5s, md5s)]
        print(f"{name}: {shared_ranks.shape}")

        # Each row lists bin indices from most to least important. Invert every
        # row permutation once so positions[i, b] is the rank of bin b in sample i.
        # This replaces a full np.where scan per (sample, bin) pair.
        positions = np.empty_like(shared_ranks)
        row_idx = np.arange(shared_ranks.shape[0])[:, None]
        positions[row_idx, shared_ranks] = np.arange(shared_ranks.shape[1])

        group_label = name.replace("_md5s", "")
        # Each bin is a different boxplot within the group
        for bin_index in sorted(feature_bins):
            feature_ranks = positions[:, bin_index]
            fig.add_trace(
                go.Box(
                    x=[f"{group_label} (n={len(feature_ranks)})"] * len(feature_ranks),
                    y=feature_ranks,
                    name=f"{name}_bin_{bin_index}",
                    line=dict(color="black", width=1),
                    showlegend=False,
                    boxpoints=False,
                )
            )

    fig.update_layout(
        title=f"SHAP rank of selected feature bins per sample group (n={len(feature_bins)} bins)",
        xaxis_title="Sample group (cell type, assay)",
        yaxis_title="Rank",
        boxmode="group",
        width=1000,
        height=1000,
    )

    fig.show()

#### Compare important features from previous results with currently computed ranks

In [None]:
# md5sums of every (cell type, assay) group used in the comparisons below.
cell_types = ["T cell", "neutrophil"]
assays = ["h3k27ac", "h3k27me3"]
md5s_dict = {
    f"{cell}_{assay}_md5s": metadata_df.loc[
        metadata_df[CELL_TYPE].eq(cell) & metadata_df[ASSAY].eq(assay), "md5sum"
    ].tolist()
    for cell in cell_types
    for assay in assays
}

In [None]:
split_dir = cell_type_dir / split_name
shap_dir = split_dir / "shap"
analysis_dir = shap_dir / "analysis_n303_f80.00"

# Group of interest for this comparison.
cell_type = "T cell"
assay = "h3k27ac"

important_features_path = analysis_dir.joinpath(assay, "important_features.json")
important_features = json.loads(important_features_path.read_text(encoding="utf8"))

# important_features_selection = important_features[cell_type]["80.0"]
# Bins of the T cell H3K27ac feature set that are not in the T cell H3K27me3 set.
selected_features = list(
    set(feature_bins_dict["h3k27ac_T_cells"])
    - set(feature_bins_dict["h3k27me3_T_cells"])
)

graph_ranks_sample_ontology(
    rank_file=None,  # reuse the rank arrays already loaded above
    feature_bins=selected_features,
    output_dir=output_dir,
    metadata_df=metadata_df,
)

In [None]:
relevant_md5s = md5s_dict[f"{cell_type}_{assay}_md5s"]
sample_ranks = rank_values[np.isin(rank_md5s, relevant_md5s)]
print(f"Number of files in subset: {sample_ranks.shape[0]}")

# Each row of sample_ranks lists bin positions from most to least important,
# so the first 303 entries of a row are that sample's top-303 bins.
top_features_n303 = [rank_bins[sample_rank][0:303] for sample_rank in sample_ranks]

top_n303_frequency = Counter(
    [feature for sublist in top_features_n303 for feature in sublist]
)

feature_most_frequently_in_top303 = top_n303_frequency.most_common(1)[0][0]

# Rank (0 = most important) of that bin in each sample: the column where it appears.
# Indexing sample_ranks[:, feature] would instead return whichever bin sits at
# rank `feature`, which is not what we want.
feature_rank_per_sample = np.argmax(
    rank_bins[sample_ranks] == feature_most_frequently_in_top303, axis=1
)
print(feature_rank_per_sample)


# cutoff = "95"
# N = len(important_features[cell_type][cutoff])
# print(f"Top 303 features (features within the top303 in {cutoff}% of samples)): {N}")
# most_frequent_features_in_top303 = [feature for feature, _ in top_n303_frequency.most_common(N)]
# for feature in most_frequent_features_in_top303:
#     if feature not in important_features[cell_type][cutoff]:
#         print(f"Feature {feature} not in important_features_selection")