In [None]:
"""Initial analysis of shap values behavior."""
# pylint: disable=redefined-outer-name, expression-not-assigned, import-error, not-callable, pointless-statement, no-value-for-parameter, undefined-variable, unused-argument, line-too-long
from __future__ import annotations

import copy
import itertools
from collections import Counter
from pathlib import Path
from typing import Dict, List, Sequence, Set, Tuple

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.io as pio
from IPython.display import display
from scipy.special import softmax  # type: ignore

pio.renderers.default = "notebook"

from epi_ml.core import metadata
from epi_ml.core.analysis import bins_to_bed_ranges, write_to_bed

# from epi_ml.core.data import UnknownData
# from epi_ml.core.hdf5_loader import Hdf5Loader
# from epi_ml.core.model_pytorch import LightningDenseClassifier
# from epi_ml.core.shap_values import SHAP_Analyzer, SHAP_Handler

In [None]:
%matplotlib inline

In [None]:
def load_chroms(chrom_file):
    """Read a chromosome sizes file into a sorted list of (name, size) tuples.

    Every non-blank line must hold exactly two whitespace-separated fields:
    the chromosome name followed by its integer size.

    Args:
        chrom_file: Path of the chromosome sizes file, read as UTF-8 text.

    Returns:
        list: (name, size) tuples in ascending tuple order. Names are compared
        as plain strings, so "chr10" sorts before "chr2".
    """
    entries = []
    with open(chrom_file, "r", encoding="utf-8") as handle:
        for raw_line in handle:
            stripped = raw_line.rstrip()
            if not stripped:
                # Blank (or whitespace-only) lines carry no chromosome.
                continue
            chrom_name, chrom_size = stripped.split()
            entries.append((chrom_name, int(chrom_size)))
    return sorted(entries)

In [None]:
# Sorted (name, size) chromosome list; run_the_whole_thing reads this notebook-level
# variable when converting bin indices to BED ranges.
# NOTE(review): absolute path on one workstation - adjust before running elsewhere.
chroms = load_chroms(
    "/home/local/USHERBROOKE/rabj2301/Projects/epilap/input/chromsizes/hg38.noy.chrom.sizes"
)

In [None]:
# Project locations.
# NOTE(review): these are absolute paths on one workstation - adjust before running elsewhere.
home = Path("/home/local/USHERBROOKE/rabj2301/Projects")
input_dir = home / "epilap/input"
metadata_path = (
    input_dir
    / "metadata/hg38_2023_epiatlas_dfreeze_plus_encode_noncore_formatted_JR.json"
)

# Output root; the final cell derives its SHAP directories from it.
output = home / "epilap/output"
# A single SHAP run directory; no other visible cell references it.
logdir = (
    output / "models/SHAP" / "harmonized_donor_sex_1l_3000n/100kb_all_none_blklst/split0/"
)

# Metadata object handed to run_the_whole_thing in the final cell.
my_meta = metadata.Metadata(metadata_path)
# meta_copy = copy.deepcopy(my_meta)

In [None]:
def select_shap_samples(
    shap_dict, n: int, rng: np.random.Generator | None = None
) -> Dict[str, List[np.ndarray]]:
    """Return a random subset of shap values together with the matching ids.

    The same row indices are drawn once (without replacement) and applied to
    every per-class matrix, so rows stay aligned with the returned ids.

    Args:
        shap_dict: Mapping with a "shap" entry (iterable of 2D arrays, one per
            class, indexed as matrix[rows, :]) and an "ids" entry (sequence of
            sample ids, one per row).
        n (int): Number of samples to draw.
        rng (np.random.Generator | None): Generator used for the draw. Pass a
            seeded generator for a reproducible selection. When None, the
            legacy global numpy random state is used.

    Returns:
        dict: {"shap": list of row-subsetted matrices, "ids": list of the
        selected ids, in the order they were drawn}.

    Raises:
        ValueError: If n is negative or larger than the number of ids.
    """
    sample_ids = shap_dict["ids"]
    total_samples = len(sample_ids)
    if not 0 <= n <= total_samples:
        raise ValueError(
            f"Cannot select {n} samples without replacement from {total_samples} available samples."
        )

    if rng is None:
        selected_indices = np.random.choice(total_samples, n, replace=False)
    else:
        selected_indices = rng.choice(total_samples, size=n, replace=False)

    return {
        "shap": [
            class_shap_values[selected_indices, :]
            for class_shap_values in shap_dict["shap"]
        ],
        "ids": [sample_ids[idx] for idx in selected_indices],
    }

In [None]:
def get_archives(shap_values_dir: Path):
    """Load the shap values and explainer background archives from npz files.

    The first file matching "*evaluation*.npz" and the first file matching
    "*explainer_background*.npz" inside the directory are read, and every
    array they contain is loaded into a plain dict.

    Args:
        shap_values_dir (Path): Directory holding both npz archives.

    Returns:
        tuple: (shap values dict, explainer background dict), each mapping the
        archive's array names to the loaded arrays.

    Raises:
        FileNotFoundError: If either archive cannot be found in the directory.
    """
    archive_patterns = ("*evaluation*.npz", "*explainer_background*.npz")

    # Locate both archives before reading anything.
    archive_paths = []
    try:
        for pattern in archive_patterns:
            archive_paths.append(next(shap_values_dir.glob(pattern)))
    except StopIteration as err:
        raise FileNotFoundError(
            f"Could not find shap values or explainer background archives in {shap_values_dir}"
        ) from err

    loaded_archives = []
    for archive_path in archive_paths:
        with open(archive_path, "rb") as handle:
            # Materialise every array while the file handle is still open.
            loaded_archives.append(dict(np.load(handle).items()))

    shap_values_archive, explainer_background = loaded_archives
    return shap_values_archive, explainer_background

In [None]:
def average_impact(shap_values_matrices):
    """Return the element-wise mean of the absolute shap values.

    Args:
        shap_values_matrices: Non-empty sequence of equally shaped arrays.

    Returns:
        np.ndarray: Array with the shape of the first matrix, holding the
        average of the absolute values over all given matrices.
    """
    n_matrices = len(shap_values_matrices)
    # Accumulate into a fresh array so the inputs are left untouched.
    accumulated = np.zeros(shap_values_matrices[0].shape)
    for class_matrix in shap_values_matrices:
        accumulated += np.abs(class_matrix)
    return accumulated / n_matrices

In [None]:
def n_most_important_features(sample_shaps, n):
    """Return the indices of the n largest absolute shap values.

    Args:
        sample_shaps: Shap values of one sample.
        n: Number of indices to keep.

    Returns:
        np.ndarray: Feature indices ordered from the largest absolute value to
        the smallest, truncated to the first n entries.
    """
    magnitudes = np.abs(sample_shaps)
    ascending_order = np.argsort(magnitudes)
    # Reverse the ascending argsort to rank the largest magnitudes first.
    descending_order = np.flip(ascending_order)
    return descending_order[:n]

In [None]:
def subsample_md5s(
    md5s: List[str], metadata: metadata.Metadata, category_label: str, labels: List[str]
) -> List[int]:
    """Return the positions of the md5s that survive a metadata category filter.

    A deep copy of the metadata is restricted to the requested labels of the
    given category, so the caller's metadata object is not modified.

    Args:
        md5s (list): A list of MD5 hashes.
        metadata (Metadata): Metadata object used for the filtering.
        category_label (str): The metadata category to filter on.
        labels (list): The labels of that category to keep.

    Returns:
        list: Indices (into `md5s`) of the hashes still present in the
        filtered metadata, in their original order.
    """
    filtered_meta = copy.deepcopy(metadata)
    filtered_meta.select_category_subsets(category_label, labels)
    return [position for position, md5 in enumerate(md5s) if md5 in filtered_meta]

In [None]:
def get_most_frequent_feature(
    pairwise_intersections: List[Set[int]], quantile_list: List[int]
) -> Dict[int, List[int]]:
    """
    Get the most frequent features across several feature sets, for some quantiles.

    The number of sets each feature appears in is counted. For every requested
    quantile, the features whose count is at least the count value found at that
    quantile are returned.

    Args:
        pairwise_intersections (List[Set[int]]): Feature sets, each containing feature indices.
        quantile_list (List[int]): Quantiles, expressed as values between 0 and 100.

    Returns:
        Dict[int, List[int]]: For each requested quantile, the list of features at least
        as frequent as that quantile. Every list is empty when no feature was counted at all.

    Raises:
        ValueError: If a quantile is outside the [0, 100] range.
    """
    for quantile in quantile_list:
        if quantile < 0 or quantile > 100:
            raise ValueError("Quantile values must be between 0 and 100.")

    # Count in how many sets each feature appears.
    intersection_counter = Counter()
    for feature_set in pairwise_intersections:
        intersection_counter.update(feature_set)

    # Without any counted feature the frame below would have a single column and the
    # column renaming would fail; there is nothing to rank, so answer with empty lists.
    if not intersection_counter:
        return {quantile: [] for quantile in quantile_list}

    df = pd.DataFrame.from_dict(data=intersection_counter, orient="index").reset_index()
    df.columns = ["Feature", "Count"]

    quantile_features_dict = {}
    for quantile in quantile_list:
        # Count value located at the requested quantile.
        curr_q = df["Count"].quantile(quantile / 100)
        # Keep every feature whose count reaches that value.
        curr_choice = df[df["Count"] >= curr_q]
        quantile_features_dict[quantile] = curr_choice["Feature"].tolist()

    return quantile_features_dict


def feature_overlap_stats(
    feature_lists: List[List[int]], quantile_list: list[int]
) -> Tuple[Set[int], Set[int], Dict[int, List]]:
    """
    Summarise how much several feature lists overlap.

    The size of every pairwise intersection is computed and a descriptive
    statistics table of those sizes is displayed. The pairwise intersections
    are then passed to `get_most_frequent_feature`, and the union and
    intersection of all lists are computed.

    Args:
        feature_lists (List[List[int]]): A list of feature lists, where each inner list contains feature indices.
        quantile_list (List[int]): The quantile values forwarded to `get_most_frequent_feature`.

    Returns:
        Tuple[Set[int], Set[int], Dict[int, List]]: 1) the intersection of all feature lists,
        2) the union of all feature lists, 3) the dict returned by `get_most_frequent_feature`.
    """
    # Convert each list once; every later step works on these sets.
    feature_sets = [set(features) for features in feature_lists]

    pairwise_overlaps = [
        first & second for first, second in itertools.combinations(feature_sets, 2)
    ]
    overlap_sizes = list(map(len, pairwise_overlaps))
    print("Pairwise feature overlap statistics:")
    display(pd.DataFrame(overlap_sizes).describe())

    # Most frequent features (per quantile), counted over the pairwise overlaps.
    frequent_features = get_most_frequent_feature(pairwise_overlaps, quantile_list)

    features_union: Set[int] = set().union(*feature_sets)
    features_intersection: Set[int] = set.intersection(*feature_sets)

    return features_intersection, features_union, frequent_features  # type: ignore

In [None]:
def get_shap_matrix(
    meta: metadata.Metadata,
    shap_matrices: List[np.ndarray],
    eval_md5s: List[str],
    assay_subsample: List[str] | None,
    class_idx: int = 0,
) -> Tuple[np.ndarray, List[int]]:
    """Build the SHAP matrix of the "female" samples for one output class.

    Works on a deep copy of the metadata. When `assay_subsample` is given, the
    copy is first restricted to those "assay_epiclass" labels. The evaluation
    md5s whose "harmonized_donor_sex" is "female" are then located, and the
    matching rows are taken from the SHAP matrix at position `class_idx`.

    Args:
        meta (metadata.Metadata): Metadata object containing information about the samples.
        shap_matrices (List[np.ndarray]): SHAP matrices, one per output class.
        eval_md5s (List[str]): md5 hashes identifying the evaluation samples,
                               in the row order of the SHAP matrices.
        assay_subsample (List[str] | None): Assay labels to keep.
                                            If None (or empty), no assay filtering is done.
        class_idx (int): Position in `shap_matrices` of the matrix to read.
                         NOTE(review): the default 0 presumably is the "female"
                         output class - confirm against the classifier's class order.

    Returns:
        np.ndarray: Rows of the chosen class matrix for the selected samples.
        List[int]: The indices of those samples in `eval_md5s`.

    Raises:
        IndexError: If `class_idx` is not a valid position in `shap_matrices`.
    """
    label_category = "harmonized_donor_sex"
    selected_class = "female"

    # Filter a private copy so the caller's metadata is left untouched.
    filtered_meta = copy.deepcopy(meta)
    if assay_subsample:
        filtered_meta.select_category_subsets("assay_epiclass", assay_subsample)
        print(f"Subsampled metadata with {assay_subsample}")
        filtered_meta.display_labels(label_category)

    chosen_idxs = subsample_md5s(
        eval_md5s, filtered_meta, label_category, [selected_class]
    )

    try:
        class_shap = shap_matrices[class_idx]
    except IndexError as err:
        raise IndexError(f"Class index {class_idx} is out of bounds.") from err

    selected_class_shap = np.array(class_shap[chosen_idxs, :])
    print(
        f"Shape of selected class ({selected_class}) shap values: {selected_class_shap.shape}"
    )
    print(f"Chose {len(chosen_idxs)} samples from {class_shap.shape[0]} samples")
    return selected_class_shap, chosen_idxs

In [None]:
def print_feature_overlap_stats(feature_stats: Sequence):
    """Print a summary of feature overlap statistics.

    Reports the size and content of the intersection, the size of the union,
    and the number of frequent features found for each quantile.

    Args:
        feature_stats (Sequence): Exactly three items, in this order: the
                                  intersection of all features, the union of
                                  all features, and a dict mapping each
                                  quantile to its list of frequent features.
    """
    shared_features, combined_features, features_by_quantile = feature_stats

    print(f"Intersection of all features: {len(shared_features)} features")
    print(f"Fully intersecting features: {list(shared_features)}")
    print(f"Union of all features: {len(combined_features)} features\n")

    for quantile, quantile_features in features_by_quantile.items():
        n_frequent = len(quantile_features)
        print(f"Most frequent features in {quantile}th quantile: {n_frequent} features")

In [None]:
def print_importance_info(feature_selection: List[int], shap_matrix: np.ndarray):
    """Print feature importance information for a selection of features.

    Each row of the SHAP matrix is turned into a distribution with a softmax
    over its columns. The function prints the share the selected features (and
    a single feature) would get if importance were uniform over all columns of
    `shap_matrix`, then displays descriptive statistics of the summed and of
    the individual shares of the selected features, in percent.

    Args:
        feature_selection (List[int]): The column indices of the selected features.
        shap_matrix (np.ndarray): The SHAP values matrix, one row per file and
                                  one column per feature.
    """
    N = len(feature_selection)
    nb_files = shap_matrix.shape[0]
    # The uniform baseline must use the actual number of columns the softmax runs over,
    # not a fixed feature count, or it is wrong for any other matrix width.
    nb_features = shap_matrix.shape[1]
    print(
        f"Average expected contribution of {N} feature if uniform importance:{N/nb_features*100:.5f}%"
    )
    print(
        f"Average expected contribution of 1 feature if uniform importance:{1/nb_features*100:.5f}%"
    )

    # Compute the softmax once and reuse it for both tables.
    selected_importance = softmax(shap_matrix, axis=1)[:, list(feature_selection)]

    print(f"Average contribution of selected features for {nb_files} files:")
    display(pd.DataFrame(selected_importance.sum(axis=1) * 100).describe())
    print(f"Individual contribution of selected features for {nb_files} files:")
    display(pd.DataFrame(selected_importance * 100).describe())

In [None]:
def run_the_whole_thing(metadata: metadata.Metadata, shap_logdir: Path, top_n: int = 100):
    """Execute the complete process of SHAP value analysis.

    This function performs the complete SHAP value analysis given the metadata and the directory
    of the SHAP value files. It carries out the following steps:
    1. Load the SHAP value archive and print basic statistics.
    2. Remove from a private copy of the metadata every md5 absent from the archive.
    3. Select the SHAP values of the "h3k27ac" samples through `get_shap_matrix`.
    4. Determine the top N features for each selected sample.
    5. Compute feature overlap statistics and print them.
    6. Print importance information for the features of the 99th quantile.
    7. Convert those bin indices to genomic ranges and write them to a BED file in `shap_logdir`.
    8. Write a violin plot of the importance distribution of the first selected sample
       to "importance_dist_1sample.png" in `shap_logdir`.

    The bin-to-range conversion reads the notebook-level `chroms` variable, which must be
    defined before this function is called.

    Args:
        metadata (metadata.Metadata): The metadata for the samples. It is deep-copied, so the
            caller's object is not modified.
        shap_logdir (Path): The directory path where SHAP value files are stored, and where
            the BED file and the plot are written.
        top_n (int): Number of highest absolute SHAP value features kept per sample.

    Returns:
        Dict[int, List[int]]: The frequent features for the quantiles 0, 90, 95 and 99, as
        returned by `feature_overlap_stats`.
    """
    metadata = copy.deepcopy(metadata)
    shap_values_archive, _ = get_archives(shap_logdir)

    eval_md5s: List[str] = shap_values_archive["evaluation_md5s"]
    shap_matrices = shap_values_archive["shap_values"]

    print(f"nb classes: {len(shap_matrices)}")
    print(f"nb samples: {len(eval_md5s)}")
    print(f"dim shap value matrix: {shap_matrices[0].shape}")
    print(f"Output classes of classifier:\n {shap_values_archive['classes']}")

    # Build the lookup set once instead of once per metadata entry.
    eval_md5_set = set(eval_md5s)
    # Iterate over a snapshot of the md5s since entries are deleted along the way.
    for md5 in list(metadata.md5s):
        if md5 not in eval_md5_set:
            del metadata[md5]

    metadata.display_labels("assay_epiclass")
    metadata.display_labels("harmonized_donor_sex")

    chosen_assay = "h3k27ac"
    shap_matrix, chosen_idxs = get_shap_matrix(
        metadata, shap_matrices, eval_md5s, [chosen_assay]
    )

    print(f"Selecting features with top {top_n} SHAP values for each sample.")
    top_n_features = [
        list(n_most_important_features(sample, top_n)) for sample in shap_matrix
    ]

    some_stats = feature_overlap_stats(top_n_features, [0, 90, 95, 99])
    frequent_features = some_stats[2]

    print_feature_overlap_stats(some_stats)

    feature_selection = frequent_features[99]

    print_importance_info(feature_selection, shap_matrix)

    # Resolution is expressed in base pairs (100kb bins).
    bed_vals = bins_to_bed_ranges(
        sorted(feature_selection), chroms, resolution=100 * 1000
    )
    write_to_bed(
        bed_vals,
        shap_logdir / f"frequent_features_99_female_{chosen_assay}.bed",
        verbose=True,
    )

    print("One sample")
    probs_1sample = pd.DataFrame(softmax(shap_matrix, axis=1)[0, :] * 100)
    display(probs_1sample.describe())
    fig_title = f"Importance distribution - One sample - {eval_md5s[chosen_idxs[0]]}"
    fig = px.violin(probs_1sample, box=True, points="all", title=fig_title)
    fig.write_image(shap_logdir / "importance_dist_1sample.png")

    return frequent_features

In [None]:
# Run the full analysis on split0 of each listed training configuration.
gen_shap_logdir = output / "models" / "SHAP" / "harmonized_donor_sex_1l_3000n"
names = [
    "100kb_all_none_0blklst/10fold-l1-50_dropout-0.25",
    "100kb_all_none_0blklst_winsorized/10fold-l1-50_dropout-0.25",
]
for name in names:
    print(name)
    path = gen_shap_logdir / name / "split0"
    # frequent_features is rebound on every pass, so only the last run's result remains afterwards.
    frequent_features = run_the_whole_thing(my_meta, path)
    print()