In [1]:
"""Code related to SHAP analyses for flagship paper + supplementary figures."""
# pylint: disable=line-too-long, redefined-outer-name, import-error, duplicate-code, unreachable

'Code related to SHAP analyses for flagship paper + supplementary figures.'

How relevant files were downloaded

~~~bash
base_path="/home/rabyj/scratch/epilap-logs/epiatlas-dfreeze-v2.1/hg38_100kb_all_none/harmonized_sample_ontology_intermediate_1l_3000n/10fold-oversampling"
rsync --info=progress2 -aR narval:${base_path}/./split*/*.md5 .

# rsync --info=progress2 -aR --exclude "*.npz" --exclude "analysis_n*_f80.00/" narval:${base_path}/./split*/shap .
rsync --info=progress2 -aR narval:${base_path}/./split*/shap/*background*.npz .
~~~

In [2]:
from collections import defaultdict
from math import floor
from pathlib import Path
from typing import Dict, List

import numpy as np
import pandas as pd
from scipy.special import softmax

from epi_ml.utils.notebooks.paper.paper_utilities import ASSAY, CELL_TYPE, MetadataHandler

In [3]:
# Name of the cancer-related metadata category; it is used below to build the
# `<CANCER>_1l_3000n` training results directory (see `cancer_dir`).
CANCER = "harmonized_sample_cancer_high"

In [4]:
# Root of the paper outputs, located under the user's home directory.
base_dir = Path.home() / "Projects/epiclass/output/paper"
# Directory holding the training results for the hg38_100kb_all_none feature set.
base_data_dir = (
    base_dir / "data" / "training_results" / "dfreeze_v2" / "hg38_100kb_all_none"
)
# The metadata handler below is initialized from this same paper directory.
paper_dir = base_dir

In [5]:
# Load the "v2" metadata and expose it as a DataFrame indexed by md5sum,
# so that samples can be looked up directly with `metadata_df.loc[md5]`.
metadata_handler = MetadataHandler(paper_dir)
metadata = metadata_handler.load_metadata("v2")
metadata_df = pd.DataFrame.from_records(list(metadata.datasets)).set_index("md5sum")
# The metadata object is not needed once the DataFrame exists.
del metadata

In [6]:
# One results directory per classification task; both point at the
# "10fold-oversampling" run of the `<category>_1l_3000n` configuration.
cell_type_dir = base_data_dir / f"{CELL_TYPE}_1l_3000n" / "10fold-oversampling"
cancer_dir = base_data_dir / f"{CANCER}_1l_3000n" / "10fold-oversampling"

In [7]:
# Fail early, naming the missing directory, instead of hitting an
# uninformative bare AssertionError or a late failure further down.
for folder in [cell_type_dir, cancer_dir]:
    assert folder.exists(), f"Missing training results directory: {folder}"

### Background selection details

In [8]:
def compare_all_training_vs_background(parent_dir: Path):
    """Create a table that compares the training data to the SHAP background data for each split.

    NOTE: this function is unfinished. For now, it only validates the input files
    of a split (exactly one training md5 list, exactly one background npz, and
    background md5s being a subset of the training md5s), then raises
    NotImplementedError.

    Args:
        parent_dir: The directory containing each split/fold directory.

    Raises:
        ValueError: If a split does not contain exactly one training md5 file or
            exactly one background file, or if the background md5s are not a
            subset of the training md5s.
        NotImplementedError: Once the first split passes validation, since the
            comparison table itself is not implemented yet.
    """
    # Sorted so that the split being examined (and reported on error) is deterministic.
    for split_dir in sorted(parent_dir.glob("split*")):
        if not split_dir.is_dir():
            continue

        training_data_paths = list(split_dir.glob("*training*.md5"))
        if len(training_data_paths) != 1:
            raise ValueError(
                f"Expected exactly one training md5 file in {split_dir}, found {len(training_data_paths)}"
            )
        training_data_path = training_data_paths[0]

        background_data_paths = list(split_dir.glob("shap/*background*.npz"))
        if len(background_data_paths) != 1:
            raise ValueError(
                f"Expected exactly one background data file in {split_dir}, found {len(background_data_paths)}"
            )
        background_data_path = background_data_paths[0]

        # The md5 file has no header; the md5s are read from its first column.
        training_md5s = set(
            pd.read_csv(training_data_path, index_col=False, header=None)[0].values
        )
        # Context manager closes the underlying npz archive once the md5s are extracted.
        with np.load(background_data_path) as background_data:
            background_md5s = set(background_data["background_md5s"])

        # Sanity check: background data is a subset of training data
        diff = background_md5s - training_md5s
        if diff:
            raise ValueError(
                f"Background data is not a subset of the training data in {split_dir}: {len(diff)} md5s unique to background data."
            )

        # TODO:
        # for diff ratio, do ((ratio background - ratio training))*100 -> positive if background has more
        # do assay / cell type / output class ratios
        # table should have full number + ratios + diff ratio for each metadata category
        # do for each split, save to a file, then average across splits

        raise NotImplementedError("Finish this function")

In [9]:
# for folder in [cell_type_dir, cancer_dir]:
#     compare_all_training_vs_background(folder)

### SHAP values ranking

SHAP rank significance analysis (sample ontology)

- Gather rank of all bins for sample output class (20k x 30k matrix)
- For each 1% chunk of the highest-ranked values, from the top down to 10% (0-1%, 1-2%, ..., 9-10%), compute the mean SHAP value (20k x 10 matrix)
- Create a graph that represents the ratio between consecutive chunks (0-1% vs 1-2%, etc., up to 10%) for each sample: 1 violin or boxplot per chunk (10 plots), 20k points per plot.

Then, with flagship cell GO, get rank distribution of features "unique" in group1 VS group2 (t-cell, neutrophil: k27ac vs k27me3)

In [44]:
def load_data(shap_dir: Path):
    """Open the SHAP evaluation and background archives found in a directory.

    The directory is expected to hold exactly one `.npz` file whose name contains
    "evaluation" and exactly one whose name contains "background".

    Args:
        shap_dir (Path): Directory containing the SHAP data files.

    Returns:
        tuple: (evaluation results, background data), each as returned by `np.load`.

    Raises:
        ValueError: If the number of evaluation files, or of background files,
            is not exactly one.
    """
    # Single pass over the npz files, sorting them by the keyword in their name.
    evaluation_paths = []
    background_paths = []
    for npz_path in shap_dir.glob("*.npz"):
        if "evaluation" in npz_path.name:
            evaluation_paths.append(npz_path)
        if "background" in npz_path.name:
            background_paths.append(npz_path)

    n_evaluation = len(evaluation_paths)
    if n_evaluation != 1:
        raise ValueError(
            f"Expected one evaluation file, found {n_evaluation} in {shap_dir}"
        )

    n_background = len(background_paths)
    if n_background != 1:
        raise ValueError(
            f"Expected one background file, found {n_background} in {shap_dir}"
        )

    (evaluation_path,) = evaluation_paths
    (background_path,) = background_paths

    # NOTE(review): allow_pickle=True can execute arbitrary code when reading
    # untrusted archives; only use on files produced by our own pipeline.
    return (
        np.load(evaluation_path, allow_pickle=True),
        np.load(background_path, allow_pickle=True),
    )


def compute_mean_shap_values(
    shap_matrices: List[np.ndarray],
    md5_indices: List[str],
    classes: Dict[str, int],
    metadata_df: pd.DataFrame,
) -> Dict[str, Dict]:
    """Summarize, for each sample, the magnitude of its most important SHAP values.

    For every sample, the SHAP values of the sample's own class (as given by the
    CELL_TYPE column of the metadata) are taken in absolute value and ordered from
    largest to smallest. The mean is then computed over each of the ten first
    consecutive chunks of that ordering, each chunk spanning floor(n_features / 100)
    features (i.e. top 0-1%, 1-2%, ..., 9-10%). The same chunk means are computed
    on the softmax of the absolute values.

    Args:
        shap_matrices (List[np.ndarray]): SHAP values, one matrix per class index,
            indexed as shap_matrices[class_idx][sample_idx].
        md5_indices (List[str]): md5 of each sample; the position in this list is
            the sample's row in each SHAP matrix.
        classes (dict): Mapping of class names to class indices.
        metadata_df (pd.DataFrame): Metadata indexed by md5, with a CELL_TYPE column.

    Returns:
        Dict[str, Dict]: {md5: details}, where details has the keys
            "ranks": feature indices ordered by descending absolute SHAP value,
            "mean_10perc_vals": the ten chunk means of the absolute SHAP values,
            "mean_10perc_vals_softmax": the ten chunk means of the softmax values.
    """
    n_features = len(shap_matrices[0][0])
    one_percent = floor(n_features / 100)
    percent_chunks = [
        slice(one_percent * k, one_percent * (k + 1)) for k in range(10)
    ]

    def _chunk_means(ordered_vals: np.ndarray) -> list:
        """Mean of `ordered_vals` over each 1% chunk."""
        return [
            np.mean(ordered_vals[chunk], dtype=np.float64) for chunk in percent_chunks
        ]

    per_sample = {}
    for row, md5 in enumerate(md5_indices):
        label: str = metadata_df.loc[md5][CELL_TYPE]  # type: ignore

        # Only the SHAP values of the sample's own class are used; magnitude only.
        magnitudes = np.abs(shap_matrices[classes[label]][row])

        # Feature indices, from greatest to smallest magnitude.
        descending_order = np.argsort(magnitudes)[::-1]

        per_sample[md5] = {
            "ranks": descending_order,
            "mean_10perc_vals": _chunk_means(magnitudes[descending_order]),
            "mean_10perc_vals_softmax": _chunk_means(
                softmax(magnitudes)[descending_order]
            ),
        }

    return per_sample

In [46]:
def create_shap_tables(shap_details: Dict[str, Dict]) -> Dict[str, pd.DataFrame]:
    """Turn per-sample SHAP details into one table per detail category.

    Each table has one row per sample (indexed by md5).

    Args:
        shap_details (Dict[str, Dict]): Dictionary of SHAP details for each sample
            format: {md5: {"ranks": vals, "mean_10perc_vals": vals, "mean_10perc_vals_softmax": vals}}

    Returns:
       Dict[str, pd.DataFrame]: Mapping from detail category ("ranks",
            "mean_10perc_vals", "mean_10perc_vals_softmax") to its table.
    """

    def _table(field: str, dtype: str) -> pd.DataFrame:
        """Gather one detail field of every sample into a DataFrame."""
        per_sample = {md5: details[field] for md5, details in shap_details.items()}
        return pd.DataFrame.from_dict(data=per_sample, orient="index", dtype=dtype)

    rank_table = _table("ranks", "int32")

    raw_means = _table("mean_10perc_vals", "float64")
    raw_means.columns = [f"mean(top {i}% to {i+1}%)" for i in range(10)]

    # Both mean tables describe the same chunks, so they share column labels.
    softmax_means = _table("mean_10perc_vals_softmax", "float64")
    softmax_means.columns = raw_means.columns

    return dict(
        zip(
            ("ranks", "mean_10perc_vals", "mean_10perc_vals_softmax"),
            (rank_table, raw_means, softmax_means),
        )
    )

In [59]:
def save_shap_tables(
    tables: Dict[str, pd.DataFrame], output_dir: Path, verbose: bool = True
):
    """Write SHAP tables to disk.

    The "ranks" table is written as a compressed npz archive ("shap_abs_ranks.npz")
    holding its index, columns and values as separate arrays. Every other table is
    written as "shap_details_<name>.csv".

    Args:
        tables (Dict[str, pd.DataFrame]): Dictionary of tables to save.
        output_dir (Path): Directory to save the tables.
        verbose (bool): Whether to print the path of each saved file.
    """
    for table_name, frame in tables.items():
        if table_name != "ranks":
            csv_path = output_dir / f"shap_details_{table_name}.csv"
            frame.to_csv(csv_path)
            if verbose:
                print(f"Saved SHAP details to {csv_path}")
            continue

        npz_path = output_dir / "shap_abs_ranks.npz"
        np.savez_compressed(
            npz_path,
            index=frame.index.values,
            columns=frame.columns.values,
            values=frame.values,
        )
        if verbose:
            print(f"Saved SHAP ranks to {npz_path}")

In [None]:
# Gather the SHAP details of every evaluated sample, across all splits.
shap_details = defaultdict(dict)
# Sorted so that splits are processed, and rows written to the output tables,
# in a reproducible order.
for shap_dir in sorted(cell_type_dir.glob("split*/shap")):
    if not shap_dir.is_dir():
        continue

    try:
        eval_results, background_data = load_data(shap_dir)
    except ValueError as e:
        # Skip this split. Without `continue`, the code below would either fail
        # with a NameError (first iteration) or silently reprocess the data
        # loaded for the previous split.
        print(e)
        continue

    # background_data["classes"] holds (index, class name) pairs.
    classes = {class_name: int(idx) for idx, class_name in background_data["classes"]}

    shap_matrices = eval_results["shap_values"]

    shap_details.update(
        compute_mean_shap_values(
            shap_matrices, eval_results["evaluation_md5s"], classes, metadata_df
        )
    )

shap_tables = create_shap_tables(shap_details)

save_shap_tables(shap_tables, cell_type_dir)

print(f"Processed {len(shap_details)} samples.")