In [1]:
"""Initial analysis of shap values behavior."""
# pylint: disable=redefined-outer-name, expression-not-assigned, import-error, not-callable, pointless-statement, no-value-for-parameter, undefined-variable, unused-argument, line-too-long
from __future__ import annotations

import copy
import itertools
from collections import Counter
from pathlib import Path
from typing import Dict, List, Sequence, Set, Tuple

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.io as pio
from IPython.display import display
from scipy.special import softmax

pio.renderers.default = "notebook"

from epi_ml.core import metadata
from epi_ml.core.analysis import bins_to_bed_ranges, write_to_bed

# from epi_ml.core.data import UnknownData
# from epi_ml.core.hdf5_loader import Hdf5Loader
# from epi_ml.core.model_pytorch import LightningDenseClassifier
# from epi_ml.core.shap_values import SHAP_Analyzer, SHAP_Handler

In [5]:
%matplotlib inline

In [6]:
def load_chroms(chrom_file):
    """Read a chrom.sizes-style file into a sorted list of (name, size) tuples.

    Every non-blank line must hold exactly two whitespace-separated fields:
    the chromosome name and its length. The result uses default tuple ordering,
    i.e. it is sorted lexicographically by name (then by size).

    Args:
        chrom_file: Path to the two-column chromosome sizes file.

    Returns:
        list[tuple[str, int]]: Sorted (chromosome name, size) pairs.
    """
    entries = []
    with open(chrom_file, "r", encoding="utf-8") as handle:
        for raw_line in handle:
            stripped = raw_line.rstrip()
            if not stripped:
                continue
            chrom_name, chrom_size = stripped.split()
            entries.append((chrom_name, int(chrom_size)))
    return sorted(entries)

In [7]:
# Chromosome (name, size) pairs later used to turn bin indices into BED ranges.
# NOTE(review): the "noy" suffix presumably means chrY is excluded — confirm.
chrom_sizes_path = "/home/local/USHERBROOKE/rabj2301/Projects/epilap/input/chromsizes/hg38.noy.chrom.sizes"
chroms = load_chroms(chrom_sizes_path)

In [8]:
# Project root; every input/output location below is derived from it.
home = Path("/home/local/USHERBROOKE/rabj2301/Projects")
input_dir, output = home / "epilap/input", home / "epilap/output"

metadata_path = input_dir / "metadata/hg38_2023_epiatlas_dfreeze_plus_encode_noncore_formatted_JR.json"
logdir = output / "models/SHAP/harmonized_donor_sex_1l_3000n/10fold-only_l1-split9"

my_meta = metadata.Metadata(metadata_path)

In [9]:
def select_shap_samples(shap_dict, n: int) -> Dict[str, List[np.ndarray]]:
    """Return a subset of shap values and their ids."""
    selected_shap_samples = {"shap": [], "ids": []}
    total_samples = len(shap_dict["ids"])
    selected_indices = np.random.choice(total_samples, n, replace=False)

    for class_shap_values in shap_dict["shap"]:
        selected_shap_samples["shap"].append(class_shap_values[selected_indices, :])

    selected_shap_samples["ids"] = [shap_dict["ids"][idx] for idx in selected_indices]

    return selected_shap_samples

In [10]:
def get_archives(shap_values_dir: Path):
    """Load the SHAP values archive and the explainer background archive from npz files.

    The directory must contain a file matching ``*evaluation*.npz`` (SHAP values)
    and one matching ``*explainer_background*.npz``. If several files match a
    pattern, the first one in sorted order is used so the choice is deterministic.

    Args:
        shap_values_dir (Path): Directory holding the SHAP npz files.

    Returns:
        Tuple[dict, dict]: (SHAP values archive, explainer background archive),
        each mapping the array names stored in the npz file to loaded arrays.

    Raises:
        FileNotFoundError: If no file matches one of the expected patterns.
    """

    def _find_first(pattern: str) -> Path:
        # Sorted so repeated runs pick the same file when several match.
        matches = sorted(Path(shap_values_dir).glob(pattern))
        if not matches:
            raise FileNotFoundError(
                f"No file matching '{pattern}' in {shap_values_dir}"
            )
        return matches[0]

    def _load_npz(path: Path) -> dict:
        # The NpzFile context manager closes the zip file; dict() reads every
        # array before that happens.
        with np.load(path) as archive:
            return dict(archive.items())

    shap_values_path = _find_first("*evaluation*.npz")
    background_info_path = _find_first("*explainer_background*.npz")

    shap_values_archive = _load_npz(shap_values_path)
    explainer_background = _load_npz(background_info_path)

    return shap_values_archive, explainer_background

In [11]:
def average_impact(shap_values_matrices):
    """Return the element-wise mean of the absolute values of several SHAP matrices.

    Accumulation starts from a float64 zero array shaped like the first matrix,
    so the result is float64.
    """
    n_matrices = len(shap_values_matrices)
    accumulator = sum(
        (np.abs(matrix) for matrix in shap_values_matrices),
        np.zeros(shap_values_matrices[0].shape),
    )
    return accumulator / n_matrices

In [12]:
def n_most_important_features(sample_shaps, n):
    """Return indices of the ``n`` features with the largest absolute SHAP value.

    For a 1D input, indices are ordered from most to least important.
    """
    magnitudes = np.absolute(sample_shaps)
    ascending_order = np.argsort(magnitudes)
    descending_order = np.flip(ascending_order)
    return descending_order[:n]

In [13]:
def subsample_md5s(
    md5s: List[str], metadata: metadata.Metadata, category_label: str, labels: List[str]
) -> List[int]:
    """Return the positions in ``md5s`` of samples kept by a metadata category filter.

    A deep copy of ``metadata`` is filtered with
    ``select_category_subsets(category_label, labels)``, so the caller's object
    is not modified. An md5 is kept when it is still present in the filtered copy.

    Args:
        md5s (list): A list of MD5 hashes.
        metadata (Metadata): Metadata object to filter (left unchanged).
        category_label (str): The category used for filtering.
        labels (list): Labels of ``category_label`` to keep.

    Returns:
        list: Indices into ``md5s`` of the kept samples, in ascending order.
    """
    filtered = copy.deepcopy(metadata)
    filtered.select_category_subsets(category_label, labels)
    return [position for position, md5 in enumerate(md5s) if md5 in filtered]

In [14]:
def get_most_frequent_feature(
    pairwise_intersections: List[Set[int]], quantile_list: List[int]
) -> Dict[int, List[int]]:
    """
    Get the most frequent features across several feature sets, for given quantiles.

    Each feature is counted once per set it appears in. For every quantile ``q`` in
    ``quantile_list``, the ``q``-th percentile of those counts is computed, and all
    features whose count is at least that value are returned.

    Args:
        pairwise_intersections (List[Set[int]]): Feature sets to count over (in this
            notebook, the pairwise intersections of per-sample top features).
        quantile_list (List[int]): Percentiles in [0, 100] used as count thresholds.

    Returns:
        Dict[int, List[int]]: For each requested quantile, the features whose count is
        at least the quantile threshold. When no feature appears in any set, every
        quantile maps to an empty list.

    Raises:
        ValueError: If a quantile is outside [0, 100].
    """
    for quantile in quantile_list:
        if quantile < 0 or quantile > 100:
            raise ValueError("Quantile values must be between 0 and 100.")

    # Number of sets each feature appears in.
    intersection_counter = Counter()
    for feature_set in pairwise_intersections:
        intersection_counter.update(feature_set)

    # With no features the frame below would have a single column and the
    # two-name column assignment would fail.
    if not intersection_counter:
        return {quantile: [] for quantile in quantile_list}

    df = pd.DataFrame.from_dict(data=intersection_counter, orient="index").reset_index()
    df.columns = ["Feature", "Count"]

    quantile_features_dict = {}
    for quantile in quantile_list:
        # Count value at this quantile (pandas default linear interpolation).
        curr_q = df["Count"].quantile(quantile / 100)
        # Every feature at least as frequent as that threshold.
        curr_choice = df[df["Count"] >= curr_q]
        quantile_features_dict[quantile] = curr_choice["Feature"].tolist()

    return quantile_features_dict


def feature_overlap_stats(
    feature_lists: List[List[int]], quantile_list: list[int]
) -> Tuple[Set[int], Set[int], Dict[int, List]]:
    """
    Calculate overlap statistics between multiple feature lists.

    Displays ``describe()`` statistics of the sizes of all pairwise overlaps, finds
    the most frequent features across those pairwise overlaps for each requested
    quantile, and computes the union and intersection of all feature lists.

    Args:
        feature_lists (List[List[int]]): Feature lists, each containing feature indices.
        quantile_list (List[int]): Quantile values (0-100) for which the most frequent
            features will be returned.

    Returns:
        Tuple[Set[int], Set[int], Dict[int, List]]: 1) intersection of all features,
        2) union of all features, 3) a dict mapping each quantile to its features.
        With a single feature list there are no pairwise overlaps, so every quantile
        maps to an empty list.

    Raises:
        ValueError: If ``feature_lists`` is empty.
    """
    if len(feature_lists) == 0:
        raise ValueError("feature_lists must contain at least one feature list.")

    # Convert each list to a set once instead of once per pair.
    feature_sets = [set(features) for features in feature_lists]

    all_pairwise_overlaps = [
        set1 & set2 for set1, set2 in itertools.combinations(feature_sets, 2)
    ]

    if all_pairwise_overlaps:
        all_pairwise_overlaps_len = [len(x) for x in all_pairwise_overlaps]
        print("Pairwise feature overlap statistics:")
        display(pd.DataFrame(all_pairwise_overlaps_len).describe())

        # Most frequent features (per quantile)
        frequent_features = get_most_frequent_feature(
            all_pairwise_overlaps, quantile_list
        )
    else:
        # describe() cannot summarize an empty frame, and there is nothing to count.
        print("Pairwise feature overlap statistics: fewer than two feature lists.")
        frequent_features = {quantile: [] for quantile in quantile_list}

    # Union and intersection of all features
    all_features_union: Set[int] = set()
    all_features_intersection: Set[int] = set(feature_sets[0])
    for feature_set in feature_sets:
        all_features_union |= feature_set
        all_features_intersection &= feature_set

    return all_features_intersection, all_features_union, frequent_features

In [15]:
def get_shap_matrix(
    meta: metadata.Metadata,
    shap_matrices: List[np.ndarray],
    eval_md5s: List[str],
    assay_subsample: List[str] | None,
    class_idx: int = 0,
    category_label: str = "harmonized_donor_sex",
    labels: Sequence[str] = ("female",),
):
    """Generates a SHAP matrix corresponding to a selected subset of samples.

    The metadata is deep-copied, so the caller's object is left untouched. The copy
    is optionally restricted to the assays in ``assay_subsample`` and then filtered
    on ``category_label`` / ``labels``. The evaluation samples (identified by md5)
    that survive the filtering are taken from the SHAP matrix of class ``class_idx``.

    Args:
        meta (metadata.Metadata): Metadata object containing information about the samples.
        shap_matrices (List[np.ndarray]): List of SHAP matrices for each class, with
            rows aligned with ``eval_md5s``.
        eval_md5s (List[str]): List of md5 hashes identifying the evaluation samples.
        assay_subsample (List[str] | None): "assay_epiclass" values to keep.
            If None or empty, all assays are considered.
        class_idx (int): Index of the class whose SHAP matrix is used. Defaults to 0.
        category_label (str): Metadata category used to select samples.
            Defaults to "harmonized_donor_sex".
        labels (Sequence[str]): Labels of ``category_label`` to keep.
            Defaults to ("female",).

    Returns:
        np.ndarray: Rows of the class ``class_idx`` SHAP matrix for the selected samples.

    Raises:
        IndexError: If ``class_idx`` is out of bounds for ``shap_matrices``.
    """
    my_meta = copy.deepcopy(meta)

    # Filter metadata
    if assay_subsample:
        my_meta.select_category_subsets("assay_epiclass", assay_subsample)
        print(f"Subsampled metadata with {assay_subsample}")
        my_meta.display_labels(category_label)

    chosen_idxs = subsample_md5s(eval_md5s, my_meta, category_label, list(labels))
    print(len(chosen_idxs))

    try:
        class_shap = shap_matrices[class_idx]
    except IndexError as err:
        raise IndexError(f"Class index {class_idx} is out of bounds.") from err

    selected_shap = class_shap[chosen_idxs, :]
    print(selected_shap.shape)
    # Selected rows vs. all evaluation samples.
    print(f"Chose {selected_shap.shape[0]} samples from {class_shap.shape[0]} samples")
    return selected_shap

In [16]:
def print_feature_overlap_stats(feature_stats: Sequence):
    """Print a summary of feature overlap statistics.

    Args:
        feature_stats (Sequence): The 3-item result of ``feature_overlap_stats``:
            (intersection of all features, union of all features,
            dict mapping each quantile to its most frequent features).
    """
    intersection, union, per_quantile = feature_stats
    summary_lines = [
        f"Intersection of all features: {len(intersection)} features",
        f"Union of all features: {len(union)} features\n",
    ]
    for summary_line in summary_lines:
        print(summary_line)
    print(list(intersection))

    for quantile, features in per_quantile.items():
        print(f"Most frequent features in {quantile}th quantile: {len(features)} features")

In [17]:
def print_importance_info(feature_selection: List[int], shap_matrix: np.ndarray):
    """Prints the feature importance information.

    Each sample's SHAP row goes through a softmax over features, and the results
    (x100) are treated as per-feature "importance" percentages. The function prints
    the share the selected features, and a single feature, would get if importance
    were uniform across all features. It then displays ``describe()`` summaries of:
    (1) the summed importance of the selected features per sample, and
    (2) the importance of each selected feature.

    Args:
        feature_selection (List[int]): The indices of the selected features.
        shap_matrix (np.ndarray): SHAP values matrix, samples x features.
    """
    N = len(feature_selection)
    # Derive the feature count from the data instead of hard-coding 30321 bins.
    n_features = shap_matrix.shape[1]
    print(
        f"Average expected contribution of {N} feature if uniform importance:{N/n_features*100:.5f}%"
    )
    print(
        f"Average expected contribution of 1 feature if uniform importance:{1/n_features*100:.5f}%"
    )

    # NOTE(review): softmax of raw (signed) SHAP values is not the same as each
    # feature's share of total |SHAP|; confirm this is the intended measure.
    selected_softmax = softmax(shap_matrix, axis=1)[:, list(feature_selection)]
    display(pd.DataFrame(selected_softmax.sum(axis=1) * 100).describe())
    display(pd.DataFrame(selected_softmax * 100).describe())

In [20]:
def run_the_whole_thing(
    metadata: metadata.Metadata,
    shap_logdir: Path,
    chosen_assay: str = "h3k27ac",
    top_n: int = 100,
    quantile_list: Sequence[int] = (0, 90, 95, 99),
    selected_quantile: int = 99,
    chrom_sizes: List[Tuple[str, int]] | None = None,
    resolution: int = 100 * 1000,
):
    """Execute the complete process of SHAP value analysis.

    Steps:
    1. Load the SHAP value archive from ``shap_logdir`` and print basic statistics.
    2. Restrict a copy of the metadata to the samples present in the archive.
    3. Select the SHAP values of ``chosen_assay`` samples through ``get_shap_matrix``
       (its defaults: class index 0, "female" donors).
    4. Determine the ``top_n`` most important features of each sample.
    5. Compute feature overlap statistics and print them.
    6. Print importance information for the features of ``selected_quantile``.
    7. Convert those bin indices to genomic ranges and write them to a BED file
       in ``shap_logdir``.
    8. Describe the importance distribution of the first sample and save a violin
       plot of it as a PNG in ``shap_logdir``.

    Args:
        metadata (metadata.Metadata): The metadata for the samples (not modified).
        shap_logdir (Path): The directory path where SHAP value files are stored.
        chosen_assay (str): Assay whose samples are analyzed.
        top_n (int): Number of top features kept per sample.
        quantile_list (Sequence[int]): Quantiles (0-100) passed to ``feature_overlap_stats``.
        selected_quantile (int): Quantile whose features are exported to BED.
        chrom_sizes (List[Tuple[str, int]] | None): Chromosome (name, size) pairs for
            bin-to-range conversion. Defaults to the notebook-level ``chroms``.
        resolution (int): Bin size used by ``bins_to_bed_ranges``.

    Returns:
        Dict[int, List[int]]: Most frequent features for each quantile in ``quantile_list``.

    Raises:
        ValueError: If ``selected_quantile`` is not in ``quantile_list``.
    """
    quantile_list = list(quantile_list)
    # Fail early, before loading archives, if the BED export key will be missing.
    if selected_quantile not in quantile_list:
        raise ValueError(
            f"selected_quantile={selected_quantile} must be one of {quantile_list}."
        )
    if chrom_sizes is None:
        chrom_sizes = chroms  # notebook-level chromosome sizes

    metadata = copy.deepcopy(metadata)
    shap_values_archive, _ = get_archives(shap_logdir)

    eval_md5s = shap_values_archive["evaluation_md5s"]
    shap_matrices = shap_values_archive["shap_values"]

    print(f"nb classes: {len(shap_matrices)}")
    print(f"nb samples: {len(eval_md5s)}")
    print(f"dim shap value matrix: {shap_matrices[0].shape}")
    print(shap_values_archive["classes"])

    # Keep only evaluated samples; the lookup set is built once, not per md5.
    eval_md5s_set = set(eval_md5s)
    for md5 in list(metadata.md5s):
        if md5 not in eval_md5s_set:
            del metadata[md5]

    metadata.display_labels("assay_epiclass")
    metadata.display_labels("harmonized_donor_sex")

    shap_matrix = get_shap_matrix(metadata, shap_matrices, eval_md5s, [chosen_assay])

    top_features = [
        list(n_most_important_features(sample, top_n)) for sample in shap_matrix
    ]

    some_stats = feature_overlap_stats(top_features, quantile_list)
    frequent_features = some_stats[2]

    print_feature_overlap_stats(some_stats)

    feature_selection = frequent_features[selected_quantile]

    print_importance_info(feature_selection, shap_matrix)

    bed_vals = bins_to_bed_ranges(
        sorted(feature_selection), chrom_sizes, resolution=resolution
    )
    write_to_bed(
        bed_vals,
        shap_logdir / f"frequent_features_{selected_quantile}_female_{chosen_assay}.bed",
    )

    print("One sample")
    probs_1sample = pd.DataFrame(softmax(shap_matrix, axis=1)[0, :] * 100)
    display(probs_1sample.describe())
    fig_title = "Importance distribution - One sample"
    fig = px.violin(probs_1sample, box=True, points="all", title=fig_title)
    fig.write_image(shap_logdir / "importance_dist_1sample.png")

    return frequent_features

In [21]:
gen_shap_logdir = output / "models" / "harmonized_donor_sex_1l_3000n"
names = ["10fold-binary-l1-500_dropout-0.10", "10fold-binary-l1-50_dropout-0.25"]

# Keep every model's result instead of overwriting it on each iteration;
# `frequent_features` still refers to the last model analyzed.
frequent_features_by_model = {}
for name in names:
    print(name)
    path = gen_shap_logdir / name / "split0" / "SHAP"
    frequent_features = run_the_whole_thing(my_meta, path)
    frequent_features_by_model[name] = frequent_features

10fold-binary-l1-500_dropout-0.10
nb classes: 2
nb samples: 770
dim shap value matrix: (770, 30321)
[['0' 'female']
 ['1' 'male']]

Label breakdown for assay_epiclass
0 labels missing and ignored from count
input: 70
h3k27ac: 70
h3k27me3: 70
h3k36me3: 70
h3k4me1: 70
h3k4me3: 70
h3k9me3: 70
wgbs-standard: 70
rna_seq: 70
mrna_seq: 70
wgbs-pbat: 70
For a total of 770 examples


Label breakdown for harmonized_donor_sex
0 labels missing and ignored from count
female: 393
male: 293
unknown: 74
mixed: 10
For a total of 770 examples

Subsampled metadata with ['h3k27ac']

Label breakdown for harmonized_donor_sex
0 labels missing and ignored from count
female: 41
male: 25
unknown: 4
For a total of 70 examples

41
(41, 30321)
Chose 41 samples from 770 samples
Pairwise feature overlap statistics:


Unnamed: 0,0
count,820.0
mean,59.781707
std,12.225592
min,21.0
25%,52.75
50%,61.0
75%,68.0
max,85.0


Intersection of all features: 10 features
Union of all features: 371 features

[29956, 28774, 28775, 16809, 15888, 5651, 15219, 28889, 11325, 8574]
Most frequent features in 0th quantile: 272 features
Most frequent features in 90th quantile: 29 features
Most frequent features in 95th quantile: 17 features
Most frequent features in 99th quantile: 10 features
Average expected contribution of 10 feature if uniform importance:0.03298%
Average expected contribution of 1 feature if uniform importance:0.00330%


Unnamed: 0,0
count,41.0
mean,0.033568
std,0.000454
min,0.032125
25%,0.033325
50%,0.03366
75%,0.033872
max,0.034392


Unnamed: 0,0,1,2,3,4,5,6,7,8,9
count,41.0,41.0,41.0,41.0,41.0,41.0,41.0,41.0,41.0,41.0
mean,0.003891,0.002805,0.003593,0.003589,0.003553,0.003605,0.002939,0.002857,0.002982,0.003755
std,0.000126,0.000156,8.4e-05,6.4e-05,5.8e-05,7.5e-05,0.000177,0.000104,6.7e-05,0.000103
min,0.003618,0.002407,0.003447,0.003441,0.003434,0.003436,0.002348,0.002618,0.002842,0.003442
25%,0.003805,0.002709,0.003544,0.003544,0.00351,0.003539,0.002867,0.002787,0.002933,0.003688
50%,0.003883,0.002839,0.003587,0.003595,0.003539,0.003595,0.002953,0.002865,0.002984,0.003754
75%,0.003993,0.002911,0.003635,0.003617,0.003595,0.003672,0.003081,0.002941,0.003037,0.00382
max,0.00414,0.003025,0.003802,0.003707,0.00367,0.003748,0.003143,0.003019,0.00311,0.00394


One sample


Unnamed: 0,0
count,30321.0
mean,0.003298
std,2.4e-05
min,0.002654
25%,0.003292
50%,0.003298
75%,0.003304
max,0.003891


10fold-binary-l1-50_dropout-0.25
nb classes: 2
nb samples: 770
dim shap value matrix: (770, 30321)
[['0' 'female']
 ['1' 'male']]

Label breakdown for assay_epiclass
0 labels missing and ignored from count
input: 70
h3k27ac: 70
h3k27me3: 70
h3k36me3: 70
h3k4me1: 70
h3k4me3: 70
h3k9me3: 70
wgbs-standard: 70
rna_seq: 70
mrna_seq: 70
wgbs-pbat: 70
For a total of 770 examples


Label breakdown for harmonized_donor_sex
0 labels missing and ignored from count
female: 393
male: 293
unknown: 74
mixed: 10
For a total of 770 examples

Subsampled metadata with ['h3k27ac']

Label breakdown for harmonized_donor_sex
0 labels missing and ignored from count
female: 41
male: 25
unknown: 4
For a total of 70 examples

41
(41, 30321)
Chose 41 samples from 770 samples
Pairwise feature overlap statistics:


Unnamed: 0,0
count,820.0
mean,60.407317
std,12.444944
min,20.0
25%,54.0
50%,62.0
75%,69.0
max,86.0


Intersection of all features: 11 features
Union of all features: 379 features

[29956, 28774, 28775, 16809, 18061, 15888, 5651, 15219, 28889, 11325, 8574]
Most frequent features in 0th quantile: 273 features
Most frequent features in 90th quantile: 31 features
Most frequent features in 95th quantile: 19 features
Most frequent features in 99th quantile: 11 features
Average expected contribution of 11 feature if uniform importance:0.03628%
Average expected contribution of 1 feature if uniform importance:0.00330%


Unnamed: 0,0
count,41.0
mean,0.036886
std,0.000498
min,0.035357
25%,0.036651
50%,0.037012
75%,0.037184
max,0.037721


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10
count,41.0,41.0,41.0,41.0,41.0,41.0,41.0,41.0,41.0,41.0,41.0
mean,0.003955,0.00307,0.00289,0.003616,0.003621,0.003599,0.003625,0.002898,0.002839,0.003032,0.003741
std,0.000144,5.5e-05,0.00013,9e-05,7.1e-05,7.2e-05,8e-05,0.000196,0.000111,5.4e-05,0.0001
min,0.003648,0.00293,0.002544,0.00346,0.003463,0.003455,0.003448,0.002248,0.002582,0.002919,0.003439
25%,0.00385,0.003043,0.002815,0.003558,0.003572,0.003551,0.003553,0.002817,0.00276,0.002995,0.003681
50%,0.003948,0.003078,0.002927,0.003606,0.00362,0.003581,0.003616,0.002913,0.002851,0.003033,0.003741
75%,0.00406,0.003104,0.002988,0.003653,0.00365,0.003653,0.003689,0.003056,0.002928,0.003076,0.003806
max,0.004243,0.003155,0.003075,0.003838,0.00376,0.003737,0.003772,0.003126,0.003006,0.00314,0.003922


One sample


Unnamed: 0,0
count,30321.0
mean,0.003298
std,2.6e-05
min,0.002676
25%,0.003292
50%,0.003298
75%,0.003304
max,0.003955


In [5]:
def read_bed_records(bed_path) -> set:
    """Return the set of records in a BED file, each a tuple of tab-separated fields.

    Each line is stripped of surrounding whitespace before splitting. Blank lines
    are skipped so they do not add a spurious ``('',)`` record.
    """
    with open(bed_path, encoding="utf8") as bed_file:
        return {tuple(line.strip().split("\t")) for line in bed_file if line.strip()}


# BED files of the 99th-quantile frequent features (female, h3k27ac) for two models.
bed1_path = "/home/local/USHERBROOKE/rabj2301/Projects/epilap/output/models/harmonized_donor_sex_1l_3000n/10fold-binary-l1-50_dropout-0.25/split0/SHAP/frequent_features_99_female_h3k27ac.bed"
bed2_path = "/home/local/USHERBROOKE/rabj2301/Projects/epilap/output/models/harmonized_donor_sex_1l_3000n/10fold-binary-l1-500_dropout-0.10/split0/SHAP/frequent_features_99_female_h3k27ac.bed"
bed1 = read_bed_records(bed1_path)
bed2 = read_bed_records(bed2_path)

print(len(bed1 & bed2))  # records present in both files
print(len(bed1 | bed2))  # records present in either file

10
11
