In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
"""Initial analysis of shap values behavior."""
# pylint: disable=redefined-outer-name, expression-not-assigned, import-error, not-callable, pointless-statement, no-value-for-parameter, undefined-variable, unused-argument
from __future__ import annotations

import copy
import itertools
from collections import Counter
from pathlib import Path
from typing import Dict, List, Set, Tuple

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.io as pio
from IPython.display import display
from scipy.special import softmax

pio.renderers.default = "notebook"

from epi_ml.core import metadata

# from epi_ml.core.data import UnknownData
# from epi_ml.core.hdf5_loader import Hdf5Loader
# from epi_ml.core.model_pytorch import LightningDenseClassifier
# from epi_ml.core.shap_values import SHAP_Analyzer, SHAP_Handler

In [3]:
%matplotlib inline

In [36]:
def load_chroms(chrom_file):
    """Return sorted chromosome names list."""
    with open(chrom_file, "r", encoding="utf-8") as file:
        chroms = []
        for line in file:
            line = line.rstrip()
            if line:
                chroms.append(line.split()[0])
    chroms.sort()
    return chroms


load_chroms(
    "/home/local/USHERBROOKE/rabj2301/Projects/epilap/input/chromsizes/hg38.noy.chrom.sizes"
)

['chr1',
 'chr10',
 'chr11',
 'chr12',
 'chr13',
 'chr14',
 'chr15',
 'chr16',
 'chr17',
 'chr18',
 'chr19',
 'chr2',
 'chr20',
 'chr21',
 'chr22',
 'chr3',
 'chr4',
 'chr5',
 'chr6',
 'chr7',
 'chr8',
 'chr9',
 'chrX']

In [4]:
home = Path("/home/local/USHERBROOKE/rabj2301/Projects")
input_dir = home / "epilap/input"
metadata_path = (
    input_dir
    / "metadata/hg38_2023_epiatlas_dfreeze_plus_encode_noncore_formatted_JR.json"
)

output = home / "epilap/output"
# logdir = output / "logs/hg38_2022-epiatlas/shap"
# model_dir = output / "models/split0"

logdir = output / "models/SHAP/harmonized_donor_sex_1l_3000n-no_validation-binary"
model_dir = logdir

my_meta = metadata.Metadata(metadata_path)
meta_copy = copy.deepcopy(my_meta)

In [5]:
shap_values_path = (
    logdir / "shap_explain_harmonized_donor_sex_evaluation_2023-05-05_00-10-11.npz"
)
background_info_path = (
    logdir
    / "shap_explain_harmonized_donor_sex_explainer_background_2023-05-04_22-48-55.npz"
)

In [6]:
def select_shap_samples(shap_dict, n: int) -> Dict[str, List[np.ndarray]]:
    """Return a subset of shap values and their ids."""
    selected_shap_samples = {"shap": [], "ids": []}
    total_samples = len(shap_dict["ids"])
    selected_indices = np.random.choice(total_samples, n, replace=False)

    for class_shap_values in shap_dict["shap"]:
        selected_shap_samples["shap"].append(class_shap_values[selected_indices, :])

    selected_shap_samples["ids"] = [shap_dict["ids"][idx] for idx in selected_indices]

    return selected_shap_samples

In [7]:
with open(shap_values_path, "rb") as f:
    shap_values_archive = np.load(f)
    dict(shap_values_archive.items())

list(shap_values_archive.keys())

['evaluation_md5s', 'shap_values', 'classes']

In [8]:
with open(background_info_path, "rb") as f:
    explainer_background = np.load(f)
    explainer_background = dict(explainer_background.items())

list(explainer_background.keys())

['background_md5s', 'background_expectation', 'classes']

In [9]:
eval_md5s = shap_values_archive["evaluation_md5s"]
shap_matrices = shap_values_archive["shap_values"]

print(f"nb classes: {len(shap_matrices)}")
print(f"nb samples: {len(eval_md5s)}")
print(f"dim 1 shap value matrix: {shap_matrices[0].shape}")
print(shap_values_archive["classes"])

nb classes: 2
nb samples: 770
dim 1 shap value matrix: (770, 30321)
[['0' 'female']
 ['1' 'male']]


In [10]:
for md5 in list(my_meta.md5s):
    if md5 not in set(eval_md5s):
        del my_meta[md5]

In [11]:
my_meta.display_labels("assay_epiclass")
my_meta.display_labels("harmonized_donor_sex")


Label breakdown for assay_epiclass
0 labels missing and ignored from count
input: 70
h3k27ac: 70
h3k27me3: 70
h3k36me3: 70
h3k4me1: 70
h3k4me3: 70
h3k9me3: 70
wgbs-standard: 70
rna_seq: 70
mrna_seq: 70
wgbs-pbat: 70
For a total of 770 examples


Label breakdown for harmonized_donor_sex
0 labels missing and ignored from count
female: 393
male: 293
unknown: 74
mixed: 10
For a total of 770 examples



In [12]:
# shap_values_dict_new = select_shap_samples(shap_values_dict, 2)
# md5sums = shap_values_dict_new["ids"]
# /lustre06/project/6007017/rabyj/epilap/input/hdf5_list/hg38_2023-01-epiatlas-freeze/shap_assay_background.list
# /lustre06/project/6007017/rabyj/epilap/input/hdf5_list/hg38_2023-01-epiatlas-freeze/shap_assay_explain.list
# my_meta[md5sums[0]]

In [13]:
# pd.DataFrame(shap_values_dict['ids']).to_csv(logdir / "shap_sample_md5sums.list", index=False, header=False)

In [14]:
# hdf5_loader = Hdf5Loader(
#     chrom_file=input_dir / "chromsizes/hg38.noy.chrom.sizes", normalization=True
# )
# hdf5_loader.load_hdf5s(input_dir / "hdf5_list/100kb_all_none.list", md5s=md5sums)
# print(len(hdf5_loader.signals))
# dset = UnknownData(
#     md5sums, [hdf5_loader.signals[md5] for md5 in md5sums], y=None, y_str=None
# )
# model = LightningDenseClassifier.restore_model(model_dir)
# model.mapping

In [15]:
# my_meta.remove_category_subsets("track_type", ["raw", "fc", "Unique_raw"])
# # len(my_meta)
# my_meta.remove_small_classes(10, "assay")
# my_meta.display_labels("assay")
# my_meta.display_labels("track_type")

In [16]:
def average_impact(shap_values_matrices):
    """Return average absolute shap values."""
    shap_abs = np.zeros(shap_values_matrices[0].shape)
    for matrix in shap_values_matrices:
        shap_abs += np.absolute(matrix)
    shap_abs /= len(shap_values_matrices)
    return shap_abs

In [17]:
def n_most_important_features(sample_shaps, n):
    """Return features with highest absolute shap values."""
    return np.flip(np.argsort(np.absolute(sample_shaps)))[:n]

In [18]:
def box_plot(avg_impact):
    """Print a box plot"""
    px.box(y=avg_impact.sum(axis=0)).show()

In [19]:
def subsample_md5s(
    md5s: List[str], metadata: metadata.Metadata, category_label: str, labels: List[str]
) -> List[int]:
    """Subsample md5s based on metadata filtering provided, for a given category and filtering labels.

    Args:
            md5s (list): A list of MD5 hashes.
            metadata (Metadata): A metadata object containing the data to be filtered.
            category_label (str): The category label to be used for filtering the metadata.
            labels (list): A list of labels to be used for selecting category subsets in the metadata.

    Returns:
            list: A list of indices corresponding to the selected md5s.
    """
    meta = copy.deepcopy(metadata)
    meta.select_category_subsets(category_label, labels)
    chosen_idxs = []
    for i, md5 in enumerate(md5s):
        if md5 in meta:
            chosen_idxs.append(i)
    return chosen_idxs

In [20]:
my_meta.select_category_subsets("assay_epiclass", ["h3k27ac"])
# meta_copy.select_category_subsets("assay_epiclass", ["h3k27me3"])
# chosen_idxs_h3k27me3 = []
# for i, md5 in enumerate(eval_md5s):
#     if md5 in meta_copy:
#         chosen_idxs_h3k27me3.append(i)

# first_class_shap = shap_matrices[0]
# first_class_shap = first_class_shap[chosen_idxs,:]
# h3k27me3_class_shap = shap_matrices[0][chosen_idxs_h3k27me3,:]

In [44]:
chosen_idxs = subsample_md5s(eval_md5s, my_meta, "harmonized_donor_sex", ["female"])
print(len(chosen_idxs))

first_class_shap = shap_matrices[0]
selected_first_class_shap = first_class_shap[chosen_idxs, :]
print(selected_first_class_shap.shape)

41
(41, 30321)


In [45]:
# top_500_features = []
# for sample in selected_first_class_shap:
#     top_500_features.append(n_most_important_features(sample, 500))

# top_500_features_all = []
# for sample in first_class_shap:
#     top_500_features_all.append(n_most_important_features(sample, 500))

top_100_features = []
for sample in selected_first_class_shap:
    top_100_features.append(n_most_important_features(sample, 100))

top_100_features_all = []
for sample in first_class_shap:
    top_100_features_all.append(n_most_important_features(sample, 100))

In [34]:
def get_features_in_centile(
    pairwise_intersections: List[List[int]], centile_list: List[int]
) -> Dict[int, List[int]]:
    """
    Get a list of features in the specified frequency centile from multiple feature lists.

    This function takes a list of feature lists and a centile value. It calculates the occurrence frequency
    of each feature and returns the list of features in the specified frequency centile.

    Args:
        feature_lists (List[List[int]]): A list of feature lists, where each inner list contains feature indices.
        centile_list (List[int]: The centile values for which the features will be returned.

    Returns:
        Dict[int:List[int]]: A dict containing the list of features in each specified frequency centile.
    """
    # Compute the features in the specified centiles
    intersection_counter = Counter()
    for feature_set in pairwise_intersections:
        intersection_counter.update(feature_set)
    sorted_features = sorted(
        intersection_counter.items(), key=lambda x: x[1], reverse=True
    )

    centile_features_dict = {}
    for centile in centile_list:
        num_features = len(sorted_features)
        centile_start = int(np.percentile(range(num_features), centile))
        centile_end = int(np.percentile(range(num_features), centile + 1))
        centile_features = [
            feature for feature, _ in sorted_features[centile_start:centile_end]
        ]

        centile_features_dict[centile] = centile_features

    return centile_features_dict


def feature_overlap_stats(
    feature_lists: List[List[int]], centile_list: list[int]
) -> Tuple[float, float, Set[int], Set[int], Dict[int, List[int]]]:
    """
    Calculate the statistics of feature overlap between multiple feature lists.

    This function takes a list of feature lists and calculates the median and average
    pairwise overlaps between them. It also computes the union and intersection of all features
    in the given feature lists.

    Args:
        feature_lists (List[List[int]]): A list of feature lists, where each inner list contains feature indices.

    Returns:
        Tuple[float, float, Set[int], Set[int]]: A tuple containing the median overlap, average overlap,
                                                 intersection of all features, and union of all features.
    """
    # Compute the overlap between two feature lists
    all_pairwise_overlaps = [
        set(sample1) & set(sample2)
        for sample1, sample2 in itertools.combinations(feature_lists, 2)
    ]
    all_pairwise_overlaps_len = [len(x) for x in all_pairwise_overlaps]

    # Compute the median and average overlap
    median_overlap = np.median(all_pairwise_overlaps_len)
    average_overlap = np.mean(all_pairwise_overlaps_len)

    centile_dict = {}

    # Union and intersection of all features
    all_features_union: Set[int] = set()
    all_features_intersection: Set[int] = set(feature_lists[0])
    for feature_set in feature_lists:
        all_features_union.update(feature_set)
        all_features_intersection &= set(feature_set)

    return median_overlap, average_overlap, all_features_intersection, all_features_union, centile_dict  # type: ignore

In [53]:
(
    median_overlap,
    average_overlap,
    features_intersection,
    features_union,
    _,
) = feature_overlap_stats(top_100_features, [])
print(f"{len(chosen_idxs)} samples stats:")
print("Median overlap:", median_overlap)
print(f"Average overlap: {average_overlap:.2f}")
print(f"Intersection of all features: {len(features_intersection)} features")
print(f"Union of all features: {len(features_union)} features\n")
print(list(features_intersection)[8])

41 samples stats:
Median overlap: 66.5
Average overlap: 65.75
Intersection of all features: 12 features
Union of all features: 314 features

28917


In [None]:
(
    median_overlap,
    average_overlap,
    features_intersection,
    features_union,
    _,
) = feature_overlap_stats(top_100_features_all, [])
print("All 770 samples stats:")
print("Median overlap:", median_overlap)
print(f"Average overlap: {average_overlap:.2f}")
print(f"Intersection of all features: {len(features_intersection)} features")
print(f"Union of all features: {len(features_union)} features")

In [None]:
(
    median_overlap,
    average_overlap,
    features_intersection,
    features_union,
    _,
) = feature_overlap_stats(top_500_features, [])
print(f"{len(chosen_idxs)} samples stats:")
print("Median overlap:", median_overlap)
print(f"Average overlap: {average_overlap:.2f}")
print(f"Intersection of all features: {len(features_intersection)} features")
print(f"Union of all features: {len(features_union)} features\n")

393 samples stats:
Median overlap: 127.0
Average overlap: 151.05992626058057
Intersection of all features: 0 features
Union of all features: 6733 features



In [None]:
(
    median_overlap,
    average_overlap,
    features_intersection,
    features_union,
    _,
) = feature_overlap_stats(top_500_features_all, [])
print("All 770 samples stats:")
print("Median overlap:", median_overlap)
print(f"Average overlap: {average_overlap:.2f}")
print(f"Intersection of all features: {len(features_intersection)} features")
print(f"Union of all features: {len(features_union)} features")

All 770 samples stats:
Median overlap: 130.0
Average overlap: 150.47419992231437
Intersection of all features: 0 features
Union of all features: 7729 features


In [56]:
# print(first_class_shap[0,:])
N = len(features_intersection)
print(
    f"Average expected contribution of {N} feature if uniform importance:{N/30321*100:.5f}%"
)
print(
    f"Average expected contribution of 1 feature if uniform importance:{1/30321*100:.5f}%"
)
display(
    pd.DataFrame(
        softmax(selected_first_class_shap, axis=1)[:, list(features_intersection)].sum(
            axis=1
        )
        * 100
    ).describe()
)
display(
    pd.DataFrame(
        softmax(selected_first_class_shap, axis=1)[:, list(features_intersection)] * 100
    ).describe()
)

# display(pd.DataFrame(softmax(np.absolute(first_class_shap), axis=1)[5,:]*100).describe())
# display(pd.DataFrame(first_class_shap[0,:]).describe())
# display(pd.DataFrame(softmax(first_class_shap[0,:]*100)).describe())
# sum(softmax(first_class_shap[0,:]))

print("One sample")
probs_1sample = pd.DataFrame(softmax(first_class_shap, axis=1)[0, :] * 100)
display(probs_1sample.describe())
# print(np.percentile(probs_1sample, 90))
# print(np.percentile(probs_1sample, 95))
# print(np.percentile(probs_1sample, 99))
# display(pd.DataFrame(selected_first_class_shap[0,:]).describe())

Average expected contribution of 12 feature if uniform importance:0.03958%
Average expected contribution of 1 feature if uniform importance:0.00330%


Unnamed: 0,0
count,41.0
mean,0.042809
std,0.000596
min,0.041562
25%,0.042539
50%,0.042792
75%,0.043162
max,0.044316


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11
count,41.0,41.0,41.0,41.0,41.0,41.0,41.0,41.0,41.0,41.0,41.0,41.0
mean,0.003523,0.004291,0.002748,0.00268,0.003651,0.003895,0.002838,0.003501,0.004723,0.003699,0.003522,0.003739
std,4.7e-05,0.000219,0.000255,0.000144,7.5e-05,0.000218,0.000149,4.7e-05,0.000627,0.000153,5.5e-05,0.000105
min,0.003423,0.003811,0.001918,0.002352,0.003509,0.003459,0.00245,0.003418,0.003799,0.003419,0.003405,0.003429
25%,0.003498,0.004144,0.002634,0.002577,0.003594,0.003754,0.002746,0.003475,0.004371,0.003602,0.003483,0.003672
50%,0.003516,0.004268,0.002766,0.002691,0.003654,0.003874,0.002868,0.003492,0.004546,0.003682,0.003512,0.003736
75%,0.003539,0.004463,0.002955,0.002786,0.003694,0.004054,0.002936,0.003532,0.004812,0.003792,0.003564,0.003805
max,0.003617,0.004729,0.003048,0.002915,0.003846,0.004517,0.00305,0.003595,0.006573,0.004163,0.003629,0.003932


One sample


Unnamed: 0,0
count,30321.0
mean,0.003298
std,1.3e-05
min,0.003004
25%,0.003294
50%,0.003298
75%,0.003302
max,0.003674


In [None]:
# px.box(first_class_shap.copy())
# px.box(softmax(np.absolute(first_class_shap)))