In [None]:
"""Initial analysis of shap values behavior."""
# pylint: disable=redefined-outer-name, expression-not-assigned, import-error, not-callable, pointless-statement, no-value-for-parameter, undefined-variable, unused-argument, line-too-long, use-dict-literal, too-many-lines, unused-import, unused-variable
from __future__ import annotations

import copy
import itertools
import re
from collections import Counter
from pathlib import Path
from typing import Dict, List, Sequence, Set, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio
import upsetplot
from IPython.display import display
from scipy.special import softmax  # type: ignore

pio.renderers.default = "notebook"

from epi_ml.core import metadata
from epi_ml.utils.bed_utils import bins_to_bed_ranges, write_to_bed
from epi_ml.utils.general_utility import get_valid_filename
from epi_ml.utils.shap_utils import n_most_important_features

# Names of the metadata categories used in this notebook. CELL_TYPE and SEX are
# the ones actually passed to the Metadata filtering calls further down; the
# others follow the same naming and are presumably categories of the same
# metadata file (not used in the visible cells).
BIOMATERIAL_TYPE = "harmonized_biomaterial_type"
CELL_TYPE = "harmonized_sample_ontology_intermediate"
ASSAY = "assay_epiclass"
SEX = "harmonized_donor_sex"
CANCER = "harmonized_sample_cancer_high"
DISEASE = "harmonized_sample_disease_high"
LIFE_STAGE = "harmonized_donor_life_stage"

In [None]:
%matplotlib inline

In [None]:
# Cut points 0.1, 0.2, ..., 0.9, passed as `percentiles` to DataFrame.describe
# in print_importance_info.
DECILES = list(np.arange(10, 100, 10) / 100)

In [None]:
def load_chroms(chrom_file):
    """Read a chromosome sizes file into a sorted list of (name, size) tuples.

    Every non-blank line must hold exactly two whitespace-separated fields:
    the chromosome name and its integer size. The result uses the default
    tuple ordering, i.e. it is sorted by name (as a string), then by size.
    """
    with open(chrom_file, "r", encoding="utf-8") as handle:
        stripped_lines = [raw_line.rstrip() for raw_line in handle]
    fields_per_line = [entry.split() for entry in stripped_lines if entry]
    # The two-target unpacking rejects lines that do not have exactly two fields.
    return sorted(
        (chrom_name, int(chrom_size)) for chrom_name, chrom_size in fields_per_line
    )

In [None]:
# All inputs and outputs are looked up under ~/Projects; adjust `home` when
# running this notebook on another machine.
home = Path().home() / "Projects"
input_dir = home / "epilap/input"
metadata_path = (
    input_dir
    / "metadata/dfreeze-v2/hg38_2023-epiatlas-dfreeze_v2.1_w_encode_noncore_2.json"
)
# Full metadata object; a later cell prunes it in place down to the samples
# that are present in the SHAP archive.
my_meta = metadata.Metadata(metadata_path)

In [None]:
# Sorted (name, size) tuples; read as a global by run_the_whole_thing when
# converting bin indices to BED ranges.
chroms = load_chroms(input_dir / "chromsizes/hg38.noy.chrom.sizes")

In [None]:
# Metadata category of the classifier whose SHAP values are analysed. It
# selects the log directory below and is passed as `label_category` to the
# analysis loop.
category = SEX

In [None]:
# Root directory of the SHAP logs.
output = home / "epilap/output/logs/epiatlas-dfreeze-v2.1/hg38_100kb_all_none/shap"
# Previously used log directories, kept for reference (note that `logdir`
# itself is therefore NOT defined by this cell).
# logdir = (
#     output / "models/SHAP" / "harmonized_donor_sex_1l_3000n/100kb_all_none_blklst/split0/"
# )
# logdir = output / "2023-01-epiatlas-freeze/hg38_1kb_all_none/harmonized_donor_sex_1l_200n/10fold-l1-100_l2-0.01_dropout-0.50/split0/SHAP/"
# logdir = output / "2023-01-epiatlas-freeze/hg38_100kb_all_none/harmonized_donor_sex/predict-10fold-binary/lgbm-dart/lgbm-l1-0.01-l2-0.01/SHAP/split0/"
# Directory handed to get_archives, which looks there for the
# "*evaluation*.npz" and "*explainer_background*.npz" files.
logdir1 = output / f"{category}_1l_3000n/w-mixed/10fold-oversample/split0/shap/"
# logdir2 = output / f"{category}_1l_3000n/10fold/split0/shap/6hist_6ct/"

In [None]:
# Bin size passed to bins_to_bed_ranges: 100 000, presumably base pairs
# (consistent with the "hg38_100kb" log path) — TODO confirm the unit.
RESOLUTION = 100 * 1000

In [None]:
def get_archives(shap_values_dir: str | Path):
    """Load the SHAP values and explainer background archives of a directory.

    The first file matching "*evaluation*.npz" and the first file matching
    "*explainer_background*.npz" are each read completely into a plain dict.

    Args:
        shap_values_dir (str | Path): Directory containing both npz archives.

    Returns:
        tuple: (shap values dict, explainer background dict), each mapping an
        archive entry name to its array.

    Raises:
        FileNotFoundError: If either archive cannot be found in the directory.
    """
    shap_values_dir = Path(shap_values_dir)

    def _read_npz(npz_path):
        """Materialise every entry of one npz archive while its file is open."""
        with open(npz_path, "rb") as stream:
            return dict(np.load(stream).items())

    located = []
    try:
        for pattern in ("*evaluation*.npz", "*explainer_background*.npz"):
            located.append(next(shap_values_dir.glob(pattern)))
    except StopIteration as err:
        raise FileNotFoundError(
            f"Could not find shap values or explainer background archives in {shap_values_dir}"
        ) from err

    values_path, background_path = located
    return _read_npz(values_path), _read_npz(background_path)

In [None]:
# Quick check of the output classes stored in the SHAP values archive.
print(get_archives(logdir1)[0]["classes"])
# print(get_archives(logdir2)[0]["classes"])

In [None]:
def subsample_md5s(
    md5s: List[str], metadata: metadata.Metadata, category_label: str, labels: List[str]
) -> List[int]:
    """Return the positions of the md5s that survive a metadata label filter.

    The filter is applied to a deep copy, so the given metadata is untouched.

    Args:
        md5s (list): MD5 hashes to test, in the order used for the indices.
        metadata (Metadata): Metadata object to filter.
        category_label (str): Metadata category the filter applies to.
        labels (list): Labels of that category to keep.

    Returns:
        list: Indices (into `md5s`, ascending) of the md5s still present in
        the filtered metadata.
    """
    filtered_meta = copy.deepcopy(metadata)
    filtered_meta.select_category_subsets(category_label, labels)
    return [position for position, md5 in enumerate(md5s) if md5 in filtered_meta]

In [None]:
def feature_overlap_stats(
    feature_lists: List[List[int]], percentile_list: list[int]
) -> Tuple[Set[int], Set[int], Dict[int, List[int]], go.Figure]:
    """Calculate the statistics of feature overlap between multiple feature lists.

    Counts in how many lists each feature occurs, displays summary statistics
    of those counts, and selects, for each requested percentile `p`, the
    features that occur in at least `p` % of the lists. The union and the
    intersection of all lists are computed as well.

    Args:
        feature_lists (List[List[int]]): One list of feature indices per file/sample.
        percentile_list (List[int]): Percentages (0-100) of files a feature must
            reach to be reported under that key.

    Returns:
        Tuple[Set[int], Set[int], Dict[int, List[int]], go.Figure]: A tuple containing
        1) the intersection of all feature lists
        2) the union of all feature lists
        3) a dict mapping each requested percentile to the features present in
           at least that percentage of the lists
        4) a plotly histogram of the per-feature file counts

    Raises:
        ValueError: If `feature_lists` is empty, if no list contains any
            feature, or if a percentile is outside [0, 100].
    """
    if not feature_lists:
        raise ValueError("Input list must not be empty.")
    nb_files = len(feature_lists)

    for percentile in percentile_list:
        if percentile < 0 or percentile > 100:
            raise ValueError("Percentile values must be between 0 and 100.")

    # Number of lists in which each feature appears
    feature_counter = Counter()
    for feature_list in feature_lists:
        feature_counter.update(feature_list)

    # Without any feature there is nothing to count; fail with a clear message
    # instead of an obscure column-length mismatch further down.
    if not feature_counter:
        raise ValueError("Feature lists must contain at least one feature.")

    df = pd.DataFrame(list(feature_counter.items()), columns=["Feature", "Count"])

    # Histogram of feature frequency
    nb_features = len(feature_counter)
    nbins = int(np.sqrt(nb_features))
    hist = px.histogram(
        df,
        x="Count",
        title=f"Top N features: frequency of {nb_features} features in {nb_files} files",
        nbins=nbins,
        range_x=[0, nb_files],
    )
    hist.update_layout(xaxis_title="Nb files", yaxis_title="Feature count")

    # Feature frequency stats. `describe` rejects duplicated percentiles, so
    # requested values that coincide with the quartiles are merged via a set.
    describe_percentiles = sorted({0.25, 0.5, 0.75, *(p / 100 for p in percentile_list)})
    count_stats = pd.DataFrame(df["Count"].describe(percentiles=describe_percentiles))
    count_stats["% of files"] = count_stats["Count"] / nb_files * 100
    # The "count" row is a number of features, not a number of files: blank it.
    # A single .loc assignment is used because chained indexing does not write
    # back to the frame when pandas copy-on-write is active.
    count_stats.loc["count", "% of files"] = np.nan
    display(count_stats)

    # Features present in at least `percentile` % of the files
    percentile_features_dict = {
        percentile: df.loc[df["Count"] >= nb_files * percentile / 100, "Feature"].tolist()
        for percentile in percentile_list
    }

    # Union and intersection of all features
    all_features_union: Set[int] = set().union(*feature_lists)
    all_features_intersection: Set[int] = set(feature_lists[0]).intersection(
        *feature_lists[1:]
    )

    return all_features_intersection, all_features_union, percentile_features_dict, hist  # type: ignore

In [None]:
def get_shap_matrix(
    meta: metadata.Metadata,
    shap_matrices: np.ndarray,
    eval_md5s: List[str],
    label_category: str,
    selected_labels: List[str],
    class_idx: int,
) -> Tuple[np.ndarray, List[int]]:
    """Generates a SHAP matrix corresponding to a selected subset of samples.

    Samples are selected through their md5: only those whose `label_category`
    value is in `selected_labels` are kept. The SHAP values of these samples
    are then taken from the matrix of the requested class.

    Args:
        meta (metadata.Metadata): Metadata object containing information about the samples.
        shap_matrices (np.ndarray): Either a 3D array indexed by class first
            (one SHAP matrix per class), or a single 2D SHAP matrix.
        eval_md5s (List[str]): List of md5 hashes identifying the evaluation samples.
        label_category (str): Name of the category in the metadata that contains the desired labels.
        selected_labels (List[str]): Name of the classes for which samples will be considered.
        class_idx (int): Index of the class for which the shap values matrix will be used.
            Ignored when `shap_matrices` is 2D.

    Returns:
        np.ndarray: The SHAP values of the chosen samples, taken from the
                    matrix of class `class_idx` (or from the 2D matrix).
        List[int]: The indices of the chosen samples in the original SHAP matrix.

    Raises:
        IndexError: If the `class_idx` is out of bounds for the `shap_matrices`.
    """
    # subsample_md5s filters its own deep copy of the metadata, so `meta` can
    # be passed directly; copying it here as well would only duplicate work.
    chosen_idxs = subsample_md5s(
        md5s=eval_md5s,
        metadata=meta,
        category_label=label_category,
        labels=selected_labels,
    )
    if shap_matrices.ndim == 3:  # deepSHAP
        try:
            class_shap = shap_matrices[class_idx]
        except IndexError as err:
            raise IndexError(f"Class index {class_idx} is out of bounds.") from err

        selected_class_shap = np.array(class_shap[chosen_idxs, :])
    else:  # TreeExplainer 2D
        class_shap = shap_matrices
        selected_class_shap = class_shap[chosen_idxs]
    print(
        f"Shape of selected class ({selected_labels}) shap values: {selected_class_shap.shape}"
    )
    print(f"Chose {len(chosen_idxs)} samples from {class_shap.shape[0]} samples")
    return selected_class_shap, chosen_idxs

In [None]:
def print_feature_overlap_stats(feature_stats: Sequence):
    """Prints the statistics of feature overlap.

    Prints the size and content of the feature intersection, the size of the
    feature union, and the number of frequent features for each percentile key.

    Args:
        feature_stats (Sequence): Sequence whose first three items are the
                                  intersection, the union and the dict of
                                  frequent features per percentile. Extra
                                  trailing items (such as the histogram figure
                                  returned by feature_overlap_stats) are ignored.
    """
    # feature_overlap_stats returns a 4-tuple (with the figure last); only the
    # first three entries are needed here.
    features_intersection, features_union, frequent_features = feature_stats[:3]
    print(f"Intersection of all features: {len(features_intersection)} features")
    print(f"Fully intersecting features: {list(features_intersection)}")
    print(f"Union of all features: {len(features_union)} features\n")
    for k, v in frequent_features.items():
        print(f"Most frequent features in {k}th quantile: {len(v)} features")

In [None]:
def print_importance_info(feature_selection: List[int], shap_matrix: np.ndarray):
    """Print and display how much importance the selected features carry.

    SHAP values are turned into per-sample shares with a row-wise softmax.
    The function prints the share expected under uniform importance (for all
    selected features and for a single one), then displays decile summaries
    of the summed share and of the individual shares of the selected features.

    Args:
        feature_selection (List[int]): The indices of the selected features.
        shap_matrix (np.ndarray): The SHAP values matrix (files x bins).

    """
    N = len(feature_selection)
    nb_files, nb_bins = shap_matrix.shape

    # Softmax shares restricted to the selected columns, reused by both tables.
    selected_shares = softmax(shap_matrix, axis=1)[:, list(feature_selection)]

    print(
        f"Average expected contribution of {N} feature if uniform importance:{N/nb_bins*100:.5f}%"
    )
    print(
        f"Average expected contribution of 1 feature if uniform importance:{1/nb_bins*100:.5f}%"
    )

    print(f"Average contribution of selected features for {nb_files} files:")
    summed_percent = pd.DataFrame(selected_shares.sum(axis=1) * 100)
    display(summed_percent.describe(percentiles=DECILES))

    print(f"Individual contribution of selected features for {nb_files} files:")
    individual_percent = pd.DataFrame(selected_shares * 100)
    display(individual_percent.describe(percentiles=DECILES))

In [None]:
def extract_shap_values_and_info(
    shap_logdir: str | Path, verbose: bool = True
) -> Tuple[np.ndarray, List[str], List[Tuple[str, str]]]:
    """Read SHAP values, sample ids and classes from the archive of a directory.

    Sample ids are taken from the "evaluation_md5s" entry, or from
    "evaluation_ids" when the former is absent.

    Args:
        shap_logdir (str): The directory where the SHAP values archive is located.
        verbose (bool): Whether to print basic statistics about the SHAP values.

    Returns:
        shap_matrices (np.ndarray): The archive's "shap_values" entry.
        eval_md5s (List[str]): Identifiers of the evaluation samples.
        classes (List[Tuple[str, str]]): The archive's "classes" entry; callers
            unpack each item as (class index, class label).
    """
    shap_values_archive = get_archives(shap_logdir)[0]

    # Newer/older archives name the sample identifiers differently.
    if "evaluation_md5s" in shap_values_archive:
        id_key = "evaluation_md5s"
    else:
        id_key = "evaluation_ids"
    eval_md5s: List[str] = shap_values_archive[id_key]
    shap_matrices: np.ndarray = shap_values_archive["shap_values"]

    if not verbose:
        return shap_matrices, eval_md5s, shap_values_archive["classes"]

    # Basic statistics about the loaded SHAP values
    print(f"nb classes: {len(shap_matrices)}")
    print(f"nb samples: {len(eval_md5s)}")
    print(f"dim shap value matrix: {shap_matrices[0].shape}")
    print(f"Output classes of classifier:\n {shap_values_archive['classes']}")

    return shap_matrices, eval_md5s, shap_values_archive["classes"]

In [None]:
def run_the_whole_thing(
    metadata: metadata.Metadata,
    shap_dir: Path,
    output_dir: Path,
    label_category: str,
    top_n: int = 100,
) -> Dict[str, Dict[int, List[int]]]:
    """Execute the complete process of SHAP value analysis.

    Steps:
    1. Load the SHAP value archives and print basic statistics.
    2. Filter a copy of the metadata to the samples in the SHAP value archives.
    For each output class:
        3. Extract the SHAP values of the samples labelled with that class.
        4. Determine the top N features for each sample.
        5. Compute feature overlap statistics and save the frequency histogram.
        6. Convert the bins of the features found in at least 90% of the
           samples to genomic ranges and write them to a BED file.

    Classes with fewer than 5 samples are skipped. The module-level `chroms`
    and `RESOLUTION` are used for the bin-to-range conversion.

    Args:
        metadata (metadata.Metadata): The metadata for the samples (not modified).
        shap_dir (Path): The directory path where SHAP value files are stored.
        output_dir (Path): Existing directory receiving the histogram images and BED files.
        label_category (str): The name of the classifier output category that computed the shaps.
        top_n (int): The number of top features to be selected for each sample. Defaults to 100.

    Returns:
        Dict[str, Dict[int, List[int]]]: Dictionary where keys are class labels and values
        map each computed percentile (90, 95, 99) to the features frequently found
        among the top features of that class
        (see feature_overlap_stats function for more details).

    Raises:
        KeyError: If "evaluation_md5s" or "evaluation_ids" are not found in the loaded SHAP value archives.
    """
    metadata = copy.deepcopy(metadata)

    # Extract shap values and md5s from archive
    shap_matrices, eval_md5s, classes = extract_shap_values_and_info(shap_dir)

    # Filter metadata to include only the samples that exist in the SHAP value
    # archives. The lookup set is built once, not once per metadata entry.
    archive_md5s = set(eval_md5s)
    for md5 in list(metadata.md5s):
        if md5 not in archive_md5s:
            del metadata[md5]

    metadata.display_labels("assay_epiclass")
    metadata.display_labels("harmonized_donor_sex")

    percentiles = [90, 95, 99]
    chosen_percentile = 90  # percentile whose features are exported to BED

    # Loop over each class to perform SHAP value analysis
    important_features = {}
    for class_int, class_label in classes:
        class_int = int(class_int)
        print(f"\n\nClass: {class_label} ({class_int})")

        # Get the SHAP matrix for the current class,
        # and only select samples that also correspond to that class
        shap_matrix, chosen_idxs = get_shap_matrix(
            meta=metadata,
            shap_matrices=shap_matrices,
            eval_md5s=eval_md5s,
            label_category=label_category,
            selected_labels=[class_label],
            class_idx=class_int,
        )

        if len(chosen_idxs) < 5:
            print(f"Not enough samples (5) to perform analysis on {class_label}.")
            continue

        # Computing statistics of feature overlap
        print(
            f"Selecting features with top {top_n} SHAP values for each sample of {class_label}."
        )
        top_n_features = [
            list(n_most_important_features(sample, top_n)) for sample in shap_matrix
        ]

        # Intersection and union are not needed here.
        _, _, frequent_features, hist_fig = feature_overlap_stats(
            top_n_features, percentiles
        )
        important_features[class_label] = frequent_features

        hist_fig.write_image(
            file=output_dir / f"top{top_n}_feature_frequency_{class_label}.png",
            format="png",
        )

        # Optional diagnostics (disabled): print_feature_overlap_stats and
        # print_importance_info can be called here on the results above.
        feature_selection = frequent_features[chosen_percentile]

        # Convert bin indices to genomic ranges and write to a BED file
        bed_vals = bins_to_bed_ranges(
            sorted(feature_selection), chroms, resolution=RESOLUTION
        )
        bed_filename = get_valid_filename(
            f"frequent_features_{chosen_percentile}_{class_label}.bed"
        )
        write_to_bed(
            bed_vals,
            output_dir / bed_filename,
            verbose=True,
        )

    return important_features

In [None]:
def run_alternative_analysis(
    metadata: metadata.Metadata,
    shap_logdir: Path,
    label_category: str,
    selected_classes: list[str],
    top_n: int = 100,
) -> None:
    """Rank the important features of selected classes in each class' SHAP matrix.

    This function performs the following steps:
    1. Extracts SHAP values and associated metadata.
    2. Filters out samples from a copy of the metadata that are not present in the SHAP value archives.
    3. Collects, for each selected class, the features found among the top N of
       at least 90% of its samples.
    4. For every (class of interest, comparison class) pair, ranks the features
       of the class of interest by absolute SHAP value in the comparison class
       matrix, for the comparison class samples, and saves the ranks as CSV
       under `shap_logdir / "feature_rank_analysis"`.

    Args:
        metadata ("metadata.Metadata"): The metadata object containing sample information.
        shap_logdir (Path): The directory where the SHAP value archives are stored.
        label_category (str): The category of the label to be used for class selection.
        selected_classes (List[str]): A list of classes for which the analysis should be run.
        top_n (int, optional): The top N most important features to consider. Default is 100.

    Raises:
        KeyError: If a selected class is not an output class of the archive.
        ValueError: If sample indices are not unique across classes.
    """
    metadata = copy.deepcopy(metadata)

    # Extract shap values and md5s from archive
    shap_matrices, eval_md5s, classes = extract_shap_values_and_info(shap_logdir)

    # Filter metadata to include only the samples that exist in the SHAP value
    # archives. The lookup set is built once, not once per metadata entry.
    archive_md5s = set(eval_md5s)
    for md5 in list(metadata.md5s):
        if md5 not in archive_md5s:
            del metadata[md5]

    # collect important features for selected classes
    selected_percentile = 90
    classes_dict = {
        class_label: int(class_int)
        for class_int, class_label in classes
        if class_label in selected_classes
    }

    # Fail before any computation or file output if a requested class is unknown
    # (it would otherwise surface later as a bare KeyError).
    unknown_classes = [label for label in selected_classes if label not in classes_dict]
    if unknown_classes:
        raise KeyError(
            f"Selected classes not found in classifier output classes: {unknown_classes}"
        )

    important_features = {}
    sample_idxs = {}
    for class_label, class_int in classes_dict.items():
        print(f"\n\nClass: {class_label} ({class_int})")

        # Get the SHAP matrix for the current class,
        # and only select samples that also correspond to that class
        shap_matrix, chosen_idxs = get_shap_matrix(
            meta=metadata,
            shap_matrices=shap_matrices,
            eval_md5s=eval_md5s,
            label_category=label_category,
            selected_labels=[class_label],
            class_idx=class_int,
        )
        sample_idxs[class_label] = chosen_idxs

        # Computing statistics of feature overlap
        top_n_features = [
            list(n_most_important_features(sample, top_n)) for sample in shap_matrix
        ]

        some_stats = feature_overlap_stats(top_n_features, [selected_percentile])
        frequent_features = some_stats[2]
        important_features[class_label] = frequent_features[selected_percentile]

    # A sample must belong to exactly one of the selected classes.
    all_chosen_idxs = set()
    for idxs in sample_idxs.values():
        all_chosen_idxs.update(idxs)

    if len(all_chosen_idxs) != sum(len(idxs) for idxs in sample_idxs.values()):
        raise ValueError("Sample indices are not unique across classes.")

    logdir = shap_logdir / "feature_rank_analysis"
    logdir.mkdir(parents=True, exist_ok=True)

    for class_of_interest in selected_classes:
        important_feats = important_features[class_of_interest]
        for comparison_class in selected_classes:
            shap_matrix = shap_matrices[classes_dict[comparison_class]]
            # Indices come from get_shap_matrix in ascending order, so walking
            # them directly visits the same rows, in the same order, as
            # scanning the whole matrix and skipping the other samples.
            comparison_idxs = sample_idxs[comparison_class]

            md5_list = [eval_md5s[idx] for idx in comparison_idxs]
            cell_type_list = [metadata[md5][CELL_TYPE] for md5 in md5_list]

            ranks_list = []
            for idx in comparison_idxs:
                # Rank 0 = largest absolute SHAP value of the sample
                sample_ranks = np.argsort(np.argsort(-np.abs(shap_matrix[idx])))
                # Keep only the ranks of the important features
                ranks_list.append(sample_ranks[important_feats])

            # One column per important feature, one row per sample
            feature_columns = {
                f"Feature_{feat}": [sample_ranks[col] for sample_ranks in ranks_list]
                for col, feat in enumerate(important_feats)
            }
            ranks_df = pd.DataFrame(
                {
                    "md5sum": md5_list,
                    CELL_TYPE: cell_type_list,
                    **feature_columns,
                }
            )

            # Save the DataFrame to CSV
            title = f"important_{class_of_interest}_features_in_{comparison_class}_shap_matrix.csv".replace(
                " ", "_"
            )
            ranks_df.to_csv(logdir / title, index=True)
            print(
                f"Feature ranks for '{class_of_interest}' features in '{comparison_class}' samples have been saved."
            )

In [None]:
# Per-group important features, read by the upset-plot cell over
# `relevant_cell_types` further down.
# NOTE(review): no visible cell ever fills this dict — confirm whether the
# analysis loop was meant to store its results here.
important_feat_dict = {}

In [None]:
# print(logdir1, logdir2)
# Show which SHAP log directory the following cells operate on.
print(logdir1)

In [None]:
# Restrict the global metadata to the samples that have SHAP values.
# Note: `my_meta` is modified in place, and `eval_md5s` is rebound from the
# archive's id sequence to a set for fast membership tests.
_, eval_md5s, _ = extract_shap_values_and_info(logdir1, verbose=False)
eval_md5s = set(eval_md5s)
for md5 in list(my_meta.md5s):
    if md5 not in eval_md5s:
        del my_meta[md5]

In [None]:
# Show the CELL_TYPE labels of the samples kept by the filtering cell above.
my_meta.display_labels(CELL_TYPE)

In [None]:
# Run the per-class SHAP analysis separately for each cell type below, and
# compare the resulting feature sets of the output classes in an upset plot.
# assay_labels = [[assay_label] for assay_label in my_meta.unique_classes(label_category=ASSAY)] + [["rna_seq", "mrna_seq"]] + [["wbgs-pbat", "wgbs-standard"]]
cell_type_labels = [
    "T cell",
    "neutrophil",
    "monocyte",
    "lymphocyte of B lineage",
    "brain",
    "myeloid cell",
]

# for assay_label_list in assay_labels:
for ct_label in cell_type_labels:
    # Work on a copy so that the filtering never affects `my_meta`.
    meta = copy.deepcopy(my_meta)
    # meta.select_category_subsets(ASSAY, assay_label_list)
    meta.select_category_subsets(CELL_TYPE, [ct_label])

    # Skip cell types with fewer than 5 samples.
    if len(meta) < 5:
        continue

    # if len(assay_label_list) == 2:
    #     continue

    # name = "_".join(assay_label_list) + "_" + ct_label
    name = ct_label
    # One output sub-directory per cell type; with parents=False, `logdir1`
    # itself must already exist.
    run_logdir = logdir1 / name
    run_logdir.mkdir(parents=False, exist_ok=True)

    important_features = run_the_whole_thing(
        metadata=meta,
        shap_dir=logdir1,
        output_dir=run_logdir,
        label_category=category,
        top_n=100,
    )

    # Keep, for each output class, the features stored under the 90 key.
    # (The comprehension's own `name` does not overwrite the outer `name`.)
    feat_90 = {name: sets[90] for name, sets in important_features.items()}

    print(len(feat_90))
    # At least two classes are needed for a comparison.
    if len(feat_90) <= 1:
        continue

    upset_features = upsetplot.from_contents(feat_90)

    upsetplot.UpSet(
        upset_features,
        subset_size="count",
        show_counts=True,
        show_percentages=True,
        sort_by="cardinality",
        sort_categories_by="cardinality",
    ).plot()

    # Saves the current matplotlib figure, i.e. the upset plot drawn just above.
    plt.savefig(run_logdir / f"upset_{name}.png", bbox_inches="tight")

In [None]:
# print(important_feat_dict)

In [None]:
# new_dict = {
#     f"{assay}_{output_class}": class_set
#     for assay, dicts in important_feat_dict.items()
#     for output_class, class_set in dicts.items()
# }

# # add merge of output classes
# assays = set(assay.split("_")[0] for assay in new_dict.keys())
# for assay in assays:
#     merged_features = set()

#     for key, features in new_dict.items():
#         if key.startswith(assay):
#             merged_features.update(features)

#     new_dict[f"{assay}_all_classes"] = list(merged_features)
#     # del new_dict[f"{assay}_cancer"]
#     # del new_dict[f"{assay}_non-cancer"]

In [None]:
# new_dict.keys()

In [None]:
# upset_features = upsetplot.from_contents(new_dict)
# upsetplot.UpSet(
#     upset_features, subset_size="count", show_counts=True, show_percentages=True, sort_by="cardinality", sort_categories_by="cardinality"
# ).plot()
# plt.savefig(run_logdir / "upset_assay.png", bbox_inches='tight')

In [None]:
# One upset plot per cell type, comparing the feature lists stored for that
# cell type under each key of `important_feat_dict`.
relevant_cell_types = [
    "T cell",
    "lymphocyte of B lineage",
    "monocyte",
    "muscle organ",
    "myeloid cell",
    "neutrophil",
]
for cell_type in relevant_cell_types:
    new_dict = {
        hist: hist_dict.get(cell_type, [])
        for hist, hist_dict in important_feat_dict.items()
    }

    # With no feature at all (in particular while `important_feat_dict` is
    # still empty) there is nothing to plot; skip instead of failing.
    if not any(len(features) > 0 for features in new_dict.values()):
        print(f"No important features collected for '{cell_type}', skipping upset plot.")
        continue

    upset_features = upsetplot.from_contents(new_dict)

    plot_filename = f"upset_{get_valid_filename(cell_type)}.png"
    upsetplot.UpSet(
        upset_features,
        subset_size="count",
        show_counts=True,
        show_percentages=True,
        sort_by="cardinality",
        sort_categories_by="cardinality",
    ).plot()
    # Saves the current matplotlib figure, i.e. the upset plot drawn just above.
    plt.savefig(logdir1.parent / plot_filename, dpi=300)

## Sample ontology specific analysis

In [None]:
# feat_intersection = set.intersection(
#     *[set(feat_list[90]) for feat_list in important_features.values()]
# )

# male_feat = set(important_features["male"][90])
# female_feat = set(important_features["female"][90])

# only_male = male_feat - female_feat
# only_female = female_feat - male_feat

# for feat_set, name in zip([only_female, only_male], ["only_female", "only_male"]):
#     bed_vals = bins_to_bed_ranges(sorted(feat_set), chroms, resolution=RESOLUTION)
#     var_name = f"{feat_set=}".split("=")[0]
#     write_to_bed(
#         bed_vals,
#         logdir / f"frequent_features_{90}_{name}.bed",
#         verbose=False,
#     )

In [None]:
# run_alternative_analysis(
#     metadata=my_meta,
#     shap_logdir=logdir,
#     label_category=SEX,
#     selected_classes=["female", "male"],
#     top_n=100,
# )

In [None]:
def rank_violin_plots(logdir: str | Path, selected_classes: list[str]) -> None:
    """Create violin plots of important feature ranks for selected classes.

    For each class of interest, one figure is produced with one violin per
    comparison class. The ranks are read from the
    "important_<class>_features_in_<comparison>_shap_matrix.csv" files of
    `logdir` (spaces replaced by underscores, as done by
    run_alternative_analysis when writing them). Each figure is saved as PNG
    and HTML in `logdir`, then shown.

    Args:
        logdir (str | Path): Directory containing the rank CSV files; also
            receives the plots.
        selected_classes (list[str]): Classes to plot, both as class of
            interest and as comparison class.
    """
    logdir = Path(logdir)
    palette = px.colors.qualitative.Plotly

    # Iterate through each class of interest
    for class_of_interest in selected_classes:
        fig = go.Figure()
        nb_features = 0

        # Iterate through each class to add its violin plot to the figure
        for comparison_class in selected_classes:
            # Same space normalisation as the CSV writer, otherwise class names
            # containing spaces would never be found.
            csv_name = f"important_{class_of_interest}_features_in_{comparison_class}_shap_matrix.csv".replace(
                " ", "_"
            )
            df = pd.read_csv(logdir / csv_name)
            # Remove non-feature columns from DataFrame for plotting
            df = df.filter(like="Feature_")
            nb_features = df.shape[1]

            # Create violin plot for the current comparison_class
            violin = go.Violin(
                y=df.values.flatten(),  # Flattened feature ranks
                name=f"{comparison_class} (n={df.shape[0]})",  # Name of the violin plot
                box_visible=True,  # Display box inside the violin
                # Different color for each violin; cycle when there are more
                # classes than palette entries instead of raising IndexError.
                line_color=palette[len(fig.data) % len(palette)],
                points="all",  # Display all points
            )
            fig.add_trace(violin)

            print(df.shape)

        # Set title and axis labels (feature count taken from the last CSV read)
        fig.update_layout(
            title=f"Violin plot of ranks for important features of '{class_of_interest}' ({nb_features} features)",
            xaxis_title="Source Matrix",
            yaxis_title="Feature Rank",
        )

        # Save, then show the figure
        fig.write_image(
            logdir / f"important_{class_of_interest}_feature_ranks_violin_plot.png"
        )
        fig.write_html(
            logdir / f"important_{class_of_interest}_feature_ranks_violin_plot.html"
        )
        fig.show()

In [None]:
# Directory holding the SHAP archives analysed in this notebook. (`logdir` is
# only assigned in commented-out lines, so it is undefined on a fresh kernel.)
logdir1

In [None]:
# rank_violin_plots(
#     logdir=logdir / "feature_rank_analysis", selected_classes=["female", "male"]
# )

## Overlap of important features between pairs of sample ontology classes

In [None]:
# feature_union = set()
# feature_intersection = set(list(important_features.values())[0][90])
# for label, features in important_features.items():
#     features_90 = features[90]
#     print(f"\n\nClass: {label}")
#     print(f"Most frequent features in 90th quantile: {features_90}")
#     feature_union.update(features_90)
#     feature_intersection &= set(features_90)

# print(f"\n\nUnion of all features: {len(feature_union)} features")
# print(
#     f"\n\nIntersection of all features: {len(feature_intersection)} features: {feature_intersection}"
# )

In [None]:
def compute_intersections(important_features: dict) -> pd.DataFrame:
    """Tabulate the pairwise intersections of the 90-key feature sets.

    Args:
        important_features (dict): Maps each class label to a dict whose key
            90 holds the important features of that class.

    Returns:
        pd.DataFrame: One row per unordered pair of labels (in input order),
        with both labels, both set sizes, the intersection set and its size.
    """
    labelled_sets = [
        (label, set(per_quantile[90]))
        for label, per_quantile in important_features.items()
    ]

    rows = []
    for first, second in itertools.combinations(labelled_sets, 2):
        first_label, first_set = first
        second_label, second_set = second
        shared = first_set & second_set
        rows.append(
            {
                "Set1_Label": first_label,
                "Set2_Label": second_label,
                "Set1_Size": len(first_set),
                "Set2_Size": len(second_set),
                "Intersection": shared,
                "Intersection_Size": len(shared),
            }
        )

    return pd.DataFrame(rows)

In [None]:
# df = compute_intersections(important_features)
# df.to_csv(logdir / "feature_intersections_q90.csv", index=False)

In [None]:
# counter = Counter()
# for feature_set in df["Intersection"]:
#     counter.update(feature_set)

# for k, v in counter.most_common():
#     print(f"{k} {v}")

In [None]:
# feature_selection = []
# sort_order = np.argsort(feature_selection)

# bed_vals = bins_to_bed_ranges(sorted(feature_selection), chroms, resolution=RESOLUTION)

# write_to_bed(bed_ranges=bed_vals, bed_path=logdir / "frequent_features.bed", verbose=True)

In [None]:
# chroms

In [None]:
# for val in np.array(bed_vals)[sort_order.argsort()]:
#     print(val)