# Analyze shaps

## SETUP

In [None]:
"""Initial analysis of shap values behavior."""
# pylint: disable=redefined-outer-name, expression-not-assigned, import-error, not-callable, pointless-statement, no-value-for-parameter, undefined-variable, unused-argument, line-too-long, use-dict-literal, too-many-lines, unused-import, unused-variable
from __future__ import annotations

import copy
import itertools
import json
import re
import shutil
from collections import Counter, defaultdict
from pathlib import Path
from typing import Dict, List, Sequence, Set, Tuple

import matplotlib.patheffects as path_effects
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px  # type: ignore
import plotly.graph_objects as go  # type: ignore
import plotly.io as pio  # type: ignore
import upsetplot  # type: ignore
from IPython.display import display
from scipy.special import softmax  # type: ignore

pio.renderers.default = "notebook"

from epi_ml.core import metadata
from epi_ml.core.data_source import EpiDataSource
from epi_ml.utils.bed_utils import bins_to_bed_ranges, write_to_bed
from epi_ml.utils.general_utility import get_valid_filename
from epi_ml.utils.metadata_utils import count_combinations, count_pairs
from epi_ml.utils.shap.analyze_shaps_kfold import compare_kfold_shap_analysis
from epi_ml.utils.shap.shap_analysis import (
    DECILES,
    feature_overlap_stats,
    print_feature_overlap_stats,
    print_importance_info,
)
from epi_ml.utils.shap.shap_utils import (
    extract_shap_values_and_info,
    get_archives,
    get_shap_matrix,
    n_most_important_features,
)
from epi_ml.utils.time import time_now_str

# Names of the metadata categories used throughout this notebook, e.g. to pick
# the label column of a classification task or to build results paths.
# NOTE(review): presumably these match the keys of the metadata JSON - confirm.
BIOMATERIAL_TYPE = "harmonized_biomaterial_type"
CELL_TYPE = "harmonized_sample_ontology_intermediate"
ASSAY = "assay_epiclass"
SEX = "harmonized_donor_sex"
CANCER = "harmonized_sample_cancer_high"
DISEASE = "harmonized_sample_disease_high"
LIFE_STAGE = "harmonized_donor_life_stage"
TRACK = "track_type"

In [None]:
%matplotlib inline

In [None]:
# Local input layout: everything is read from ~/Projects/epilap/input.
home = Path.home().joinpath("Projects")
input_dir = home.joinpath("epilap/input")

# Metadata of the dataset and the chromosome sizes used to map bins to ranges.
metadata_path = input_dir.joinpath(
    "metadata/dfreeze-v2/hg38_2023-epiatlas-dfreeze_v2.1_w_encode_noncore_2.json"
)
my_meta = metadata.Metadata(metadata_path)

chrom_sizes_path = input_dir.joinpath("chromsizes/hg38.noy.chrom.sizes")
chroms = EpiDataSource.load_external_chrom_file(chrom_sizes_path)

In [None]:
# Bin resolution handed to the bin -> BED range conversions below.
# NOTE(review): presumably base pairs per bin (100 kb) - confirm.
RESOLUTION = 100_000

In [None]:
# Metadata category the following cells operate on.
category = CELL_TYPE

In [None]:
# NOTE(review): presumably removes, in place, the samples that have no label for
# `category` - confirm against the Metadata implementation.
my_meta.remove_missing_labels(category)
# NOTE(review): presumably prints the remaining labels and their counts - confirm.
my_meta.display_labels(category)

In [None]:
# count_combinations(my_meta, [ASSAY, CELL_TYPE, TRACK])

## Analyze important features over all folds

In [None]:
# Root of the training logs (a mounted remote folder); fail early when it is
# not reachable, since every analysis below reads from it.
logs_dir = Path.home().joinpath(
    "mounts/narval-mount/project-rabyj/epilap/output/logs/epiatlas-dfreeze-v2.1/hg38_100kb_all_none"
)
if logs_dir.exists() is False:
    raise ValueError(f"Logs dir {logs_dir} does not exist")

### Collect information from all splits

In [None]:
def join_important_features(
    parent_folder: Path, top_n: int, frequency_threshold: str
) -> Dict[str, Dict[str, Dict[str, Dict[str, List[int]]]]]:
    """Join important features from all folds.

    Every entry matching
    ``split*/shap/analysis_n{top_n}_f{frequency_threshold}/*`` under
    `parent_folder` is visited and its ``important_features.json`` is loaded.

    WARNING: a matched folder that contains no ``important_features.json`` is
    deleted from disk (``shutil.rmtree``) and skipped.

    Args:
        parent_folder (Path): Parent folder of split folders.
        top_n (int): Number of top features to consider for each sample,
                    used to select the analysis folder.
        frequency_threshold (str): Frequency threshold used to select the analysis folder.

    Returns:
        Dict: all_important_features
        - level 1: folder_name (str): split_dict
        - level 2: split_name (str) : all_class_features (dict)
        - level 3: class_label (classifier outputs, str): important features for each frequency threshold (dict)
        - level 4: frequency_threshold (str, 0 to 100): list of features (List[int])

    Raises:
        ValueError: If the folder three levels above a matched entry is not a split folder.
    """
    print("WARNING: Consider using a local copy of the data to speed up the process.")
    all_important_features = defaultdict(dict)
    pattern = f"split*/shap/analysis_n{top_n}_f{frequency_threshold}/*"
    for important_features_path in parent_folder.glob(pattern):
        # The trailing "*" also matches plain files (logs, images, ...);
        # opening "<file>/important_features.json" would raise NotADirectoryError.
        if not important_features_path.is_dir():
            continue

        # <parent>/<split>/shap/<analysis>/<folder> -> parents[2] is the split folder
        split = important_features_path.parents[2].name
        if "split" not in split:
            raise ValueError(f"Split not found in {split}")
        folder = important_features_path.name

        json_path = important_features_path / "important_features.json"
        try:
            with open(json_path, "r", encoding="utf8") as f:
                features = json.load(f)
        except FileNotFoundError:
            # Folder without results: remove it so it is not visited again.
            shutil.rmtree(important_features_path)
            continue

        all_important_features[folder][split] = features

    return all_important_features

In [None]:
def perform_global_analysis(
    all_important_features: Dict[str, Dict[str, Dict[str, List[int]]]],
    resolution: int,
    chromsizes: List[Tuple[str, int]],
    output_folder: Path,
    chosen_percentile: float = 80.0,
    minimum_count: int = 8,
    top_n: int = 303,
) -> None:
    """Run the cross-fold SHAP analysis for every subsampling.

    For each subsampling (first-level key of `all_important_features`),
    `compare_kfold_shap_analysis` is called on its per-split features and the
    value it returns is saved as ``feature_count.json`` under
    ``<output_folder>/global_shap_analysis/top<top_n>/<subsampling>``.

    `resolution`, `chromsizes`, `chosen_percentile` and `minimum_count` are
    forwarded untouched to `compare_kfold_shap_analysis`.

    Raises:
        ValueError: If ``<output_folder>/global_shap_analysis`` does not exist
            (only checked when there is at least one subsampling to process).
    """
    analysis_root = output_folder / "global_shap_analysis"
    for subsampling, per_split_features in all_important_features.items():
        if analysis_root.exists() is False:
            raise ValueError(f"Parent folder {analysis_root} does not exist")

        destination = analysis_root.joinpath(f"top{top_n}", subsampling)
        destination.mkdir(parents=True, exist_ok=True)

        print(f"Performing global analysis for {subsampling}")
        counts = compare_kfold_shap_analysis(
            important_features_all_splits=per_split_features,
            resolution=resolution,
            chromsizes=chromsizes,
            output_folder=destination,
            name=subsampling,
            chosen_percentile=chosen_percentile,
            minimum_count=minimum_count,
        )

        count_file = destination / "feature_count.json"
        with count_file.open("w", encoding="utf8") as handle:
            json.dump(counts, handle, indent=4)

In [None]:
# One (category, results sub-path) pair per classification task.
pairs = [
    (CELL_TYPE, "10fold-oversampling"),
    (LIFE_STAGE, "no-unknown/10fold-oversampling"),
    (SEX, "w-mixed/10fold-oversample"),
    (CANCER, "10fold-oversampling"),
]
categories = [task for task, _ in pairs]
rest_of_paths = [sub_path for _, sub_path in pairs]

# Variables used to find the right analysis folder and write new files
chosen_percentile = 80.0
minimum_count = 8
top_n = 303
# The analysis folder names carry the percentile with two decimals (e.g. "80.00").
frequency_threshold = f"{chosen_percentile:.2f}"

In [None]:
for category, rest_of_path in pairs:
    # Only the cancer task is processed by this cell; other tasks are skipped.
    if category == CANCER:
        results_dir = logs_dir.joinpath(f"{category}_1l_3000n/{rest_of_path}")
        if results_dir.is_dir() is False:
            raise ValueError(f"Directory {results_dir} does not exist.")

        # Gather the per-split important features, then aggregate them over folds.
        all_important_features = join_important_features(
            results_dir,
            top_n=top_n,
            frequency_threshold=frequency_threshold,
        )
        perform_global_analysis(
            all_important_features=all_important_features,
            resolution=RESOLUTION,
            chromsizes=chroms,
            output_folder=results_dir,
            chosen_percentile=chosen_percentile,
            minimum_count=minimum_count,
            top_n=top_n,
        )

In [None]:
# print(results_dir)
# Bare expression on purpose: the notebook displays the returned mapping.
join_important_features(results_dir, top_n, frequency_threshold)

In [None]:
# Re-run the cross-fold aggregation on the features collected by the loop above.
perform_global_analysis(
    all_important_features,
    RESOLUTION,
    chroms,
    results_dir,
    chosen_percentile,
    minimum_count,
    top_n,
)

### Intersection matrix

In [None]:
def compare_samplings_shap_analysis(
    jsons_parent_folder: Path, minimum_count: int = 8
) -> pd.DataFrame:
    """Compare the SHAP analysis results from multiple subsamplings and make a matrix out of them.

    Each subfolder of `jsons_parent_folder` is one subsampling and is expected
    to hold a ``feature_count.json`` mapping each class label to a list of
    ``(feature, count)`` pairs. Subfolders without that file are reported and
    skipped. Subfolders are visited in sorted order so the label order of the
    result does not depend on the filesystem.

    Args:
        jsons_parent_folder (Path): Parent folder containing subfolders with the feature count jsons.
        minimum_count (int, optional): For a given feature, presence in how many training folds is required
                                       to be selected. Defaults to 8.

    Returns:
        pd.DataFrame: Intersection matrix (dtype=int) of the feature sets.
        Labels are lower-cased "<subsampling> & <class label>" strings, the
        diagonal is the size of each feature set, and groups with no selected
        feature are removed from both rows and columns.
    """
    all_features_counts = {}
    output_classes = set()
    for folder in sorted(jsons_parent_folder.iterdir()):
        if not folder.is_dir():
            continue

        json_path = folder / "feature_count.json"
        try:
            with open(json_path, "r", encoding="utf8") as json_file:
                feature_count = json.load(json_file)
        except FileNotFoundError:
            print(f"Skipping {folder.name}: {json_path.name} not found.")
            continue

        all_features_counts[folder.name] = feature_count
        output_classes.update(feature_count)

    output_classes = sorted(output_classes)

    print(f"Output classes: {output_classes}")
    print(f"Number of output classes: {len(output_classes)}")

    # Filter features using minimum_count. Sets are built once here rather than
    # for every pair of groups in the intersection loop below.
    filtered_class_features = {
        f"{subsampling_name} & {class_label}": {
            feature for feature, count in class_features if count >= minimum_count
        }
        for subsampling_name, subsampling_feature_count in all_features_counts.items()
        for class_label, class_features in subsampling_feature_count.items()
    }

    # Intersection matrix: symmetric, diagonal is the size of each feature set.
    feature_sets = list(filtered_class_features.values())
    nb_sets = len(feature_sets)
    intersection_matrix = np.zeros((nb_sets, nb_sets), dtype=int)
    for i, features1 in enumerate(feature_sets):
        for j in range(i, nb_sets):
            shared = len(features1 & feature_sets[j])
            intersection_matrix[i, j] = shared
            intersection_matrix[j, i] = shared

    labels = [label.lower() for label in filtered_class_features]
    matrix_df = pd.DataFrame(data=intersection_matrix, index=labels, columns=labels)

    # Remove completely empty rows and columns. The matrix is symmetric and
    # non-negative, so an all-zero row means an empty feature set.
    empty_mask = (matrix_df.sum(axis=1) == 0).to_numpy()
    empty_labels = matrix_df.index[empty_mask]
    matrix_df = matrix_df.drop(index=empty_labels, columns=empty_labels)

    return matrix_df

In [None]:
# NOTE(review): `base_shap_results_dir` is not defined in the visible part of this
# notebook (the cells above use `results_dir`), and this reads "top100" while the
# global analysis above is written under f"top{top_n}" with top_n = 303 - confirm both.
global_analysis_folder = base_shap_results_dir / "global_shap_analysis" / "top100"
# parents=False: only the last path component may be created, missing parents raise.
global_analysis_folder.mkdir(parents=False, exist_ok=True)
matrix_df = compare_samplings_shap_analysis(global_analysis_folder)

In [None]:
# Add sample count to the labels
# Add sample count to the labels.
# The matrix labels are fully lower-cased "<subsampling> & <class label>" strings
# (see compare_samplings_shap_analysis), so the lookup keys must be lower-cased
# as a whole; lower-casing only the second name misses mixed-case first names.
label_pairs = count_pairs(my_meta, ASSAY, CELL_TYPE)
label_pairs = {
    f"{names[0]} & {names[1]}".lower(): count for names, count in label_pairs.items()
}
# Only labels with a known sample count are renamed; the others are left as is.
new_index = [label for label in matrix_df.index if label in label_pairs]
new_index_mapping = {label: f"{label} ({label_pairs[label]})" for label in new_index}
matrix_df.rename(index=new_index_mapping, inplace=True)

In [None]:
def sorting_key(value: str):
    """Build the sort key of a "something & something_else (count)" label.

    The key is the text after the last " & " and before the first "(",
    stripped of surrounding whitespace and lower-cased. A label without " & "
    is used in full.
    """
    _, _, after_ampersand = value.rpartition(" & ")
    before_count, _, _ = after_ampersand.partition("(")
    return before_count.strip().lower()


# Order rows and columns by the label part that follows "&" (see sorting_key).
sorted_index = sorted(matrix_df.index, key=sorting_key)
sorted_columns = sorted(matrix_df.columns, key=sorting_key)
matrix_df = matrix_df.reindex(sorted_index).loc[:, sorted_columns]

In [None]:
def create_intersection_matrix_heatmap(matrix_df: pd.DataFrame, category: str) -> None:
    """Create a heatmap of the intersection matrix.

    The figure is saved as a PNG inside the module-level
    `global_analysis_folder` (file name derived from the plot title) and then
    shown.

    Args:
        matrix_df (pd.DataFrame): Intersection matrix. Index and columns labels used as is.
        category (str): Category name. Used in the plot title.
    """
    # Create the figure and axis objects
    fig, ax = plt.subplots(figsize=(25, 20))  # Adjust the size as needed

    # Mask the zero cells and use a colormap copy that draws masked ('bad') values in white.
    # `plt.get_cmap` is used because only `matplotlib.pyplot` is imported in this file
    # (there is no `mpl` name) and `matplotlib.cm.get_cmap` no longer exists in recent Matplotlib.
    masked_array = np.ma.array(matrix_df, mask=(matrix_df == 0))  # pylint: disable=C0325
    cmap = plt.get_cmap("viridis").copy()
    cmap.set_bad(color="white")

    # Display the data
    cax = ax.imshow(masked_array, cmap=cmap, interpolation="none")

    # Add colorbar
    fig.colorbar(cax)

    # Annotate each non-zero cell with its value; zero cells are skipped to
    # keep the white background clean.
    text_outline = [path_effects.withStroke(linewidth=1.5, foreground="black")]
    for (i, j), val in np.ndenumerate(matrix_df):
        if val != 0:
            ax.text(
                j,
                i,
                int(val),
                ha="center",
                va="center",
                color="white",
                path_effects=text_outline,
            )  # type: ignore

    # Put the major ticks on the cell borders so the grid lines separate the cells
    ax.grid(which="major", color="black", linestyle="-", linewidth=1)
    ax.set_xticks(np.arange(len(matrix_df.columns)) - 0.5)
    ax.set_yticks(np.arange(len(matrix_df.index)) - 0.5)

    # The border ticks carry no label and no tick mark
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.tick_params(axis="both", which="both", length=0)  # Hide the tick marks

    # Secondary axes carry the labels, centered on each cell
    ax2 = ax.secondary_xaxis("bottom")
    ax2.set_xticks(np.arange(len(matrix_df.columns)))
    ax2.set_xticklabels(matrix_df.columns, rotation=45, ha="right")

    ax3 = ax.secondary_yaxis("left")
    ax3.set_yticks(np.arange(len(matrix_df.index)))
    ax3.set_yticklabels(matrix_df.index)

    # Set the title, save and show the plot
    title = f"{category} important features groups intersection heatmap (top100, f80, split_count=8)"
    ax.set_title(title)
    plt.tight_layout()
    img_path = global_analysis_folder / (get_valid_filename(title) + ".png")
    plt.savefig(img_path, dpi=300)
    plt.show()

In [None]:
# create_intersection_matrix_heatmap(matrix_df, category)

### Venn diagram over multiple classification tasks

In [None]:
# For each classification task, take the union (over subsamplings and output
# classes) of the features selected in at least 8 training folds.
categories = [CANCER, LIFE_STAGE, CELL_TYPE]
union_important_features = defaultdict(set)
for category in categories:
    print(f"Category: {category}")
    category_dir = logs_dir / f"{category}_1l_3000n"

    # The analysis folder is either directly under the 10fold folder or one level deeper.
    # sorted(): glob order is arbitrary, make the choice of path[0] reproducible.
    path = sorted(category_dir.glob("10fold-oversampl*/global_shap_analysis/top100*"))
    if not path:
        path = sorted(
            category_dir.glob("*/10fold-oversampl*/global_shap_analysis/top100*")
        )
    if not path:
        raise ValueError(
            f"No 'global_shap_analysis/top100*' folder found under {category_dir}."
        )
    path = path[0]
    print(f"Path: {path}")

    for folder in path.iterdir():
        if not folder.is_dir():
            continue

        json_path = folder / "feature_count.json"

        with open(json_path, "r", encoding="utf8") as json_file:
            feature_count = json.load(json_file)

        # Each value is a list of (feature, count) pairs for one output class.
        for class_features in feature_count.values():
            union_important_features[category].update(
                feature for feature, count in class_features if count >= 8
            )

In [None]:
from matplotlib_venn import venn3  # type: ignore

# One circle per classification task, labelled with the task's category name.
venn_labels = union_important_features.keys()
venn_sets = list(union_important_features.values())
v = venn3(venn_sets, set_labels=venn_labels)

## Function definitions

In [None]:
def plot_feature_importance(
    sample_shap_values: np.ndarray,
    important_features: list,
    title: str,
    plot_type: str,
    logdir: str | Path,
) -> None:
    """Scatter-plot the feature importance of one sample with Plotly and save it as PNG.

    All features are drawn in blue and the important ones are drawn again in
    red on top. The image is written to ``<logdir>/<title> (<plot_type>).png``.

    Args:
        sample_shap_values (np.ndarray): The SHAP values for a single sample.
        important_features (list): List of indices corresponding to important features.
        title (str): The title for the plot.
        plot_type (str): Type of plot ("raw", "softmax", or "rank").
            "rank" ranks features by decreasing absolute SHAP value (0 = largest).
        logdir (str | Path): Folder where the image is written.

    Raises:
        ValueError: If `plot_type` is not one of the supported types.
    """
    value_transforms = {
        "raw": lambda values: values,
        "softmax": softmax,
        # argsort of argsort gives the rank of each element; negated absolute
        # values put the largest magnitude first.
        "rank": lambda values: np.argsort(np.argsort(-np.abs(values))),
    }
    if plot_type not in value_transforms:
        raise ValueError("Invalid plot_type.")
    plot_values = value_transforms[plot_type](sample_shap_values)

    title = f"{title} ({plot_type})"

    every_feature = go.Scatter(
        x=list(range(len(plot_values))),
        y=plot_values,
        mode="markers",
        marker={"color": "blue"},
        name="All Features",
    )
    highlighted = go.Scatter(
        x=important_features,
        y=[plot_values[idx] for idx in important_features],
        mode="markers",
        marker={"color": "red"},
        name="Important Features",
    )

    figure = go.Figure(
        data=[every_feature, highlighted],
        layout=go.Layout(
            title=title,
            xaxis={"title": "Feature index"},
            yaxis={"title": plot_type},
        ),
    )
    figure.write_image(Path(logdir) / f"{title}.png")

## Main functions definitions

In [None]:
def run_the_whole_thing(
    metadata: metadata.Metadata,
    shap_dir: Path,
    output_dir: Path,
    label_category: str,
    top_n: int = 100,
) -> Dict[str, Dict[int, List[int]]]:
    """Execute the complete process of SHAP value analysis.

    This function loads the SHAP value archives found in `shap_dir` and then
    delegates to `run_full_shap_analysis` with metadata filtering enabled, so
    both entry points share one implementation. The steps carried out are:
    1. Load the SHAP value archives.
    2. Filter a copy of the metadata to match the samples in SHAP value archives
       (the given `metadata` object is not modified).
    For each output class:
        3. Extract SHAP values.
        4. Determine the top N features for each sample.
        5. Compute feature overlap statistics and save the frequency histogram.
        6. Convert bin indices to genomic ranges and write to a BED file.

    Args:
        metadata (metadata.Metadata): The metadata for the samples.
        shap_dir (Path): The directory path where SHAP value files are stored.
        output_dir (Path): The directory path where the results will be stored.
        label_category (str): The name of the classifier output category that computed the shaps.
        top_n (int): The number of top features to be selected for each sample. Defaults to 100.

    Returns:
        Dict[str, Dict[int, List[int]]]: Dictionary where keys are class labels and values are the
        most frequently occurring important (high shap) features for that class, for each computed
        percentile (see feature_overlap_stats function for more details).
    """
    # Extract shap values and md5s from archive
    shap_values_and_info = extract_shap_values_and_info(shap_dir)

    return run_full_shap_analysis(
        shap_values_and_info=shap_values_and_info,
        metadata=metadata,
        output_dir=output_dir,
        label_category=label_category,
        top_n=top_n,
        filter_metdata=True,
    )

In [None]:
def run_full_shap_analysis(
    shap_values_and_info: Tuple[np.ndarray, List[str], List[Tuple[str, str]]],
    metadata: metadata.Metadata,
    output_dir: Path,
    label_category: str,
    top_n: int = 100,
    filter_metdata: bool = True,
) -> Dict[str, Dict[int, List[int]]]:
    """Execute the complete process of SHAP value analysis, without the value extraction part.

    Extraction and metadata filtering take a long time, this is intended to not repeat those parts if possible.

    This function performs the complete SHAP value analysis given
    - shap values archive information (extract_shap_values_and_info output)
    - metadata
    - output directory
    It carries out the following steps:
    1. Optionally filter a copy of the metadata to match the samples in SHAP value archives.
    For each output class:
        2. Extract SHAP values of the samples labelled with that class.
        3. Determine the top N features for each sample.
        4. Compute feature overlap statistics and save the frequency histogram as PNG.
        5. Convert bin indices to genomic ranges and write to a BED file.

    Classes with fewer than 5 samples, or whose BED file already exists in
    `output_dir`, are skipped and absent from the returned dictionary.

    The module-level `chroms` and `RESOLUTION` are used for the BED conversion.

    Args:
        shap_values_and_info : Tuple containing the shap values archive information
            (shap matrices, evaluation md5s, (class index, class label) pairs).
        metadata (metadata.Metadata): The metadata for the samples.
        output_dir (Path): The directory path where the results will be stored.
        label_category (str): The name of the classifier output category that computed the shaps.
        top_n (int): The number of top features to be selected for each sample. Defaults to 100.
        filter_metdata (bool): Whether to filter the metadata to match the samples in SHAP value archives.
            When True the given `metadata` object is deep-copied first and left untouched.

    Returns:
        Dict[str, Dict[int, List[int]]]: Dictionary where keys are class labels and values are the
        most frequently occurring important (high shap) features for that class, for each computed
        percentile (see feature_overlap_stats function for more details).
    """
    shap_matrices, eval_md5s, classes = shap_values_and_info

    # Filter metadata to include only the samples that exist in the SHAP value archives
    if filter_metdata:
        metadata = copy.deepcopy(metadata)
        known_md5s = set(eval_md5s)  # built once, not once per metadata entry
        for md5 in list(metadata.md5s):
            if md5 not in known_md5s:
                del metadata[md5]
    else:
        print("Warning: Skipping metadata filtering.")

    metadata.display_labels("assay_epiclass")
    metadata.display_labels("harmonized_donor_sex")

    # Percentiles for which frequent features are computed, and the one written to BED
    percentiles = [90, 95, 99]
    chosen_percentile = 90
    min_samples = 5

    # Loop over each class to perform SHAP value analysis
    important_features = {}
    for class_int, class_label in classes:
        class_int = int(class_int)
        print(f"\n\nClass: {class_label} ({class_int})")

        # Get the SHAP matrix for the current class,
        # and only select samples that also correspond to that class
        shap_matrix, chosen_idxs = get_shap_matrix(
            meta=metadata,
            shap_matrices=shap_matrices,
            eval_md5s=eval_md5s,
            label_category=label_category,
            selected_labels=[class_label],
            class_idx=class_int,
        )

        if len(chosen_idxs) < min_samples:
            print(f"Not enough samples (5) to perform analysis on {class_label}.")
            continue

        result_bed_filename = get_valid_filename(
            f"frequent_features_{chosen_percentile}_{class_label}.bed"
        )

        # Don't redo work
        if (output_dir / result_bed_filename).is_file():
            print(f"Skipping {class_label} because {result_bed_filename} already exists.")
            continue

        # Computing statistics of feature overlap
        print(
            f"Selecting features with top {top_n} SHAP values for each sample of {class_label}."
        )
        top_n_features = [
            list(n_most_important_features(sample, top_n)) for sample in shap_matrix
        ]

        # Intersection and union of the per-sample feature sets are not used here
        _, _, frequent_features, hist_fig = feature_overlap_stats(
            top_n_features, percentiles
        )
        important_features[class_label] = frequent_features

        hist_fig.write_image(
            file=output_dir / f"top{top_n}_feature_frequency_{class_label}.png",
            format="png",
        )

        feature_selection = frequent_features[chosen_percentile]

        # Convert bin indices to genomic ranges and write to a BED file
        bed_vals = bins_to_bed_ranges(
            sorted(feature_selection), chroms, resolution=RESOLUTION
        )

        write_to_bed(
            bed_vals,
            output_dir / result_bed_filename,
            verbose=True,
        )

    return important_features

In [None]:
def run_alternative_analysis(
    metadata: metadata.Metadata,
    shap_logdir: Path,
    label_category: str,
    selected_classes: list[str],
    top_n: int = 100,
) -> None:
    """Save the ranks of each selected class's important features, per sample, as CSV files.

    This function performs the following steps:
    1. Extracts SHAP values and associated information from the archives in `shap_logdir`.
    2. Removes from a copy of the metadata the samples that are absent from the archives.
    3. For each selected class, collects the features stored under the 90th percentile
       by `feature_overlap_stats`, computed from the `top_n` features of each sample.
    4. For each (class of interest, comparison class) pair, ranks all features of every
       comparison-class sample by descending absolute SHAP value (rank 0 = largest) and
       writes the ranks of the class-of-interest features to
       `shap_logdir/feature_rank_analysis/important_<A>_features_in_<B>_shap_matrix.csv`.

    Args:
        metadata ("metadata.Metadata"): The metadata object containing sample information. It is deep-copied, the caller's object is left untouched.
        shap_logdir (Path): The directory where the SHAP value archives are stored.
        label_category (str): The category of the label to be used for class selection.
        selected_classes (List[str]): A list of classes for which the analysis should be run.
        top_n (int, optional): The top N most important features to consider. Default is 100.

    Raises:
        KeyError: If a selected class is not among the classes found in the archives.
        ValueError: If sample indices are not unique across classes.
    """
    metadata = copy.deepcopy(metadata)

    # Extract shap values and md5s from archive
    shap_matrices, eval_md5s, classes = extract_shap_values_and_info(shap_logdir)

    # Filter metadata to include only the samples that exist in the SHAP value archives.
    # The lookup set is built once instead of once per metadata entry.
    eval_md5s_set = set(eval_md5s)
    for md5 in list(metadata.md5s):
        if md5 not in eval_md5s_set:
            del metadata[md5]

    # collect important features for selected classes
    selected_percentile = 90
    classes_dict = {
        class_label: int(class_int)
        for class_int, class_label in classes
        if class_label in selected_classes
    }

    # Fail before any heavy computation rather than with a bare KeyError afterwards.
    missing_classes = [label for label in selected_classes if label not in classes_dict]
    if missing_classes:
        raise KeyError(f"Selected classes not found in SHAP archives: {missing_classes}")

    important_features = {}
    sample_idxs = {}
    for class_label, class_int in classes_dict.items():
        print(f"\n\nClass: {class_label} ({class_int})")

        # Get the SHAP matrix for the current class,
        # and only select samples that also correspond to that class
        shap_matrix, chosen_idxs = get_shap_matrix(
            meta=metadata,
            shap_matrices=shap_matrices,
            eval_md5s=eval_md5s,
            label_category=label_category,
            selected_labels=[class_label],
            class_idx=class_int,
        )
        sample_idxs[class_label] = chosen_idxs

        # Computing statistics of feature overlap
        top_n_features = [
            list(n_most_important_features(sample, top_n)) for sample in shap_matrix
        ]
        # Third element of the returned tuple: frequent features keyed by percentile
        frequent_features = feature_overlap_stats(top_n_features, [selected_percentile])[2]
        important_features[class_label] = frequent_features[selected_percentile]

    all_chosen_idxs = set()
    for idxs in sample_idxs.values():
        all_chosen_idxs.update(idxs)

    if len(all_chosen_idxs) != sum(len(idxs) for idxs in sample_idxs.values()):
        raise ValueError("Sample indices are not unique across classes.")

    # Set views of the chosen indices, for O(1) membership tests in the loops below
    sample_idx_sets = {label: set(idxs) for label, idxs in sample_idxs.items()}

    logdir = shap_logdir / "feature_rank_analysis"
    logdir.mkdir(parents=True, exist_ok=True)

    for class_of_interest in selected_classes:
        important_feats = important_features[class_of_interest]
        for comparison_class in selected_classes:
            shap_matrix = shap_matrices[classes_dict[comparison_class]]
            comparison_idxs = sample_idx_sets[comparison_class]

            # Lists to hold md5s, metadata category, and feature ranks for each sample
            md5_list = []
            cell_type_list = []
            ranks_list = []

            for i, sample_shap_values in enumerate(shap_matrix):
                if i not in comparison_idxs:
                    continue

                md5 = eval_md5s[i]
                md5_list.append(md5)
                cell_type_list.append(metadata[md5][CELL_TYPE])

                # Double argsort turns values into ranks:
                # rank 0 is the feature with the largest absolute SHAP value
                ranks = np.argsort(np.argsort(-np.abs(sample_shap_values)))
                ranks_list.append(ranks[important_feats])

            # Combine all the lists into a DataFrame, one column per important feature
            ranks_df = pd.DataFrame(
                {
                    "md5sum": md5_list,
                    CELL_TYPE: cell_type_list,
                    **{
                        f"Feature_{feat}": [sample_ranks[col] for sample_ranks in ranks_list]
                        for col, feat in enumerate(important_feats)
                    },
                }
            )

            # Save the DataFrame to CSV
            title = f"important_{class_of_interest}_features_in_{comparison_class}_shap_matrix.csv".replace(
                " ", "_"
            )
            ranks_df.to_csv(logdir / title, index=True)
            print(
                f"Feature ranks for '{class_of_interest}' features in '{comparison_class}' samples have been saved."
            )

## Analysis

In [None]:
# print(logdir1, logdir2)
# print(logdir1)

In [None]:
# Directory of the SHAP archives analysed in the cells below; refuse to continue without it.
logdir1 = output / f"{category}_1l_3000n" / "10fold" / "split0" / "shap" / "rna_only"
if not logdir1.exists():
    raise ValueError(f"{logdir1} does not exist.")

In [None]:
# Load, from the archives in `logdir1`, the SHAP matrices, the md5s of the
# samples they cover, and the (class index, class label) pairs.
shap_matrices, eval_md5s, classes = extract_shap_values_and_info(logdir1, verbose=False)

In [None]:
# Metadata filtering: keep only the samples present in the loaded SHAP archives.
eval_md5s_set = set(eval_md5s)
md5s_without_shap = [md5 for md5 in my_meta.md5s if md5 not in eval_md5s_set]
for md5 in md5s_without_shap:
    del my_meta[md5]

In [None]:
# Show the label counts for assay and cell type, then for the analysed
# category when it is neither of those two.
for shown_category in (ASSAY, CELL_TYPE):
    my_meta.display_labels(shown_category)
if category != ASSAY and category != CELL_TYPE:
    my_meta.display_labels(category)

In [None]:
# Deliberate stop: the setup cells above can be run in one shot ("Run all")
# without needlessly executing the analysis cells that follow.
raise ValueError("Stop here")  # pylint: disable=unreachable

### Donor sex

In [None]:
# pylint: disable=unreachable
# One group per individual assay, plus two merged groups (RNA and WGBS).
# The WGBS label previously read "wbgs-pbat", which looks like a typo given its
# "wgbs-standard" partner; a misspelled label cannot match the intended samples.
assay_labels = (
    [[assay_label] for assay_label in my_meta.unique_classes(label_category=ASSAY)]
    + [["rna_seq", "mrna_seq"]]
    + [["wgbs-pbat", "wgbs-standard"]]
)
cell_type_labels = [
    "T cell",
    "neutrophil",
    "monocyte",
    "lymphocyte of B lineage",
    "brain",
    "myeloid cell",
]

for assay_label_list in assay_labels:
    for ct_label in cell_type_labels:
        print(assay_label_list, ct_label)

        # Restrict a private copy of the metadata to this (assay group, cell type) pair
        meta = copy.deepcopy(my_meta)
        meta.select_category_subsets(ASSAY, assay_label_list)
        meta.select_category_subsets(CELL_TYPE, [ct_label])

        if len(meta) < 5:
            print("Not enough samples for this combination. Continuing")
            continue

        name = "_".join(assay_label_list) + "_" + ct_label
        run_logdir = logdir1 / "shap_analysis_and_go" / get_valid_filename(name)
        run_logdir.mkdir(parents=False, exist_ok=True)

        # Already have female/male results bed
        if len(list(run_logdir.glob("*.bed"))) >= 2:
            continue

        print(run_logdir)
        important_features = run_the_whole_thing(
            metadata=meta,
            shap_dir=logdir1,
            output_dir=run_logdir,
            label_category=category,
            top_n=100,
        )

        # Keep, for each class, the features stored under the 90th percentile.
        # The comprehension variable no longer reuses `name`, the run name above.
        feat_90 = {
            class_name: sets[90] for class_name, sets in important_features.items()
        }

        # An UpSet plot needs at least two sets to compare
        if len(feat_90) <= 1:
            continue

        upset_features = upsetplot.from_contents(feat_90)

        upsetplot.UpSet(
            upset_features,
            subset_size="count",
            show_counts=True,  # type: ignore
            show_percentages=True,
            sort_by="cardinality",
            sort_categories_by="cardinality",
        ).plot()

        plt.savefig(run_logdir / f"upset_{name}.png", bbox_inches="tight")

In [None]:
# feat_intersection = set.intersection(
#     *[set(feat_list[90]) for feat_list in important_features.values()]
# )

# male_feat = set(important_features["male"][90])
# female_feat = set(important_features["female"][90])

# only_male = male_feat - female_feat
# only_female = female_feat - male_feat

# for feat_set, name in zip([only_female, only_male], ["only_female", "only_male"]):
#     bed_vals = bins_to_bed_ranges(sorted(feat_set), chroms, resolution=RESOLUTION)
#     var_name = f"{feat_set=}".split("=")[0]
#     write_to_bed(
#         bed_vals,
#         logdir / f"frequent_features_{90}_{name}.bed",
#         verbose=False,
#     )

In [None]:
# run_alternative_analysis(
#     metadata=my_meta,
#     shap_logdir=logdir,
#     label_category=SEX,
#     selected_classes=["female", "male"],
#     top_n=100,
# )

### Sample ontology - 6hist_6ct

In [None]:
# Filled by the "Sample ontology" cells below:
# (assay label, cell type name) -> {class label: features stored under percentile key 90}
important_feat_dict = {}

In [None]:
# Every assay and every cell type is analysed on its own (single-label lists).
assay_labels = [[label] for label in my_meta.unique_classes(label_category=ASSAY)]
cell_type_labels = [
    [label] for label in my_meta.unique_classes(label_category=CELL_TYPE)
]

general_logdir = logdir1 / "shap_analysis_and_go" / "feature_frequency_method"
general_logdir.mkdir(exist_ok=True)

for assay_label_list in assay_labels:
    # One output directory per assay
    assay_label = "_".join(assay_label_list)
    assay_logdir = general_logdir / assay_label
    assay_logdir.mkdir(exist_ok=True)

    for ct_label_list in cell_type_labels:
        print(assay_label_list, ct_label_list)

        # Work on a copy narrowed down to the current (assay, cell type) pair
        meta = copy.deepcopy(my_meta)
        selections = ((ASSAY, assay_label_list), (CELL_TYPE, ct_label_list))
        for subset_category, subset_labels in selections:
            meta.select_category_subsets(subset_category, subset_labels)
        for subset_category, _ in selections:
            meta.display_labels(subset_category)

        if len(meta) < 5:
            print("Not enough samples for this combination. Continuing")
            continue

        important_features = run_full_shap_analysis(
            shap_values_and_info=(shap_matrices, eval_md5s, classes),
            metadata=meta,
            output_dir=assay_logdir,
            label_category=category,
            top_n=100,
            filter_metdata=False,
        )

        # Record, per class, the features stored under the 90th percentile
        cell_type_name = get_valid_filename("_".join(ct_label_list))
        features_at_90 = {}
        for class_name, percentile_sets in important_features.items():
            features_at_90[class_name] = percentile_sets[90]
        important_feat_dict[(assay_label, cell_type_name)] = features_at_90

In [None]:
# Summary: key, class labels, and size of the first class's feature collection
for combo, class_features in important_feat_dict.items():
    first_features = list(class_features.values())[0]
    print(combo, class_features.keys(), len(first_features))

### Sample ontology - (m)rna-seq

In [None]:
# Each assay on its own, plus one merged group for both RNA-seq assays.
single_assay_groups = [
    [label] for label in my_meta.unique_classes(label_category=ASSAY)
]
assay_labels = single_assay_groups + [["rna_seq", "mrna_seq"]]
cell_type_labels = [
    [label] for label in my_meta.unique_classes(label_category=CELL_TYPE)
]

general_logdir = logdir1 / "frequent_features" / "feature_frequency_method"
general_logdir.mkdir(exist_ok=True)

for assay_label_list in assay_labels:
    # One output directory per assay group
    assay_label = "_".join(assay_label_list)
    assay_logdir = general_logdir / assay_label
    assay_logdir.mkdir(exist_ok=True)

    for ct_label_list in cell_type_labels:
        print(assay_label_list, ct_label_list)

        # Work on a copy narrowed down to the current (assay group, cell type) pair
        meta = copy.deepcopy(my_meta)
        selections = ((ASSAY, assay_label_list), (CELL_TYPE, ct_label_list))
        for subset_category, subset_labels in selections:
            meta.select_category_subsets(subset_category, subset_labels)
        for subset_category, _ in selections:
            meta.display_labels(subset_category)

        if len(meta) < 5:
            print("Not enough samples for this combination. Continuing")
            continue

        important_features = run_full_shap_analysis(
            shap_values_and_info=(shap_matrices, eval_md5s, classes),
            metadata=meta,
            output_dir=assay_logdir,
            label_category=category,
            top_n=100,
            filter_metdata=False,
        )

        # Record, per class, the features stored under the 90th percentile
        cell_type_name = get_valid_filename("_".join(ct_label_list))
        features_at_90 = {}
        for class_name, percentile_sets in important_features.items():
            features_at_90[class_name] = percentile_sets[90]
        important_feat_dict[(assay_label, cell_type_name)] = features_at_90

In [None]:
# Summary: key, class labels, and size of the first class's feature collection
for combo, class_features in important_feat_dict.items():
    first_features = list(class_features.values())[0]
    print(combo, class_features.keys(), len(first_features))

### Upset plots for sample_ontology

In [None]:
def _save_upset_plot(contents: dict, out_path: Path) -> None:
    """Draw an UpSet plot from `contents` and save the current figure to `out_path`."""
    upset_features = upsetplot.from_contents(contents)
    upsetplot.UpSet(
        upset_features,
        subset_size="count",
        show_counts=True,  # type: ignore
        show_percentages=True,
        sort_by="cardinality",
        sort_categories_by="cardinality",
    ).plot()

    print(f"Saving UpSet to {out_path}")
    plt.savefig(out_path, bbox_inches="tight")


# NOTE(review): in both loops each entry of `important_feat_dict` is itself a
# {class label: features} dict, so `.values()` hands upsetplot a view of
# per-class feature collections rather than a flat collection of features.
# Kept as-is; confirm this is the intended input for `upsetplot.from_contents`.

# -- assay upset plot --
for assay_label_list in assay_labels:
    assay_label = "_".join(assay_label_list)
    assay_logdir = general_logdir / assay_label
    if not assay_logdir.exists():
        raise ValueError(f"{assay_logdir} does not exist.")

    ct_labels = [pair[1] for pair in important_feat_dict if pair[0] == assay_label]
    if not ct_labels:
        # Nothing was recorded for this assay (e.g. every combination was skipped)
        print(f"No important features recorded for {assay_label}. Skipping UpSet plot.")
        continue

    feat_90 = {
        ct_label: important_feat_dict[(assay_label, ct_label)].values()
        for ct_label in ct_labels
    }
    _save_upset_plot(feat_90, assay_logdir / f"upset_{assay_label}.png")


# -- cell type upset plot --

# cell_types = set(pair[1] for pair in important_feat_dict if pair[0][0:2] == "h3") # mixing rna-only with 6hist_6ct results

for ct_label_list in cell_type_labels:
    # Keys of `important_feat_dict` were built with get_valid_filename(...), so the
    # same transformation is needed here for the lookup to match.
    ct_label = get_valid_filename("_".join(ct_label_list))

    # Local name on purpose: reassigning the global `assay_labels` (a list of
    # label lists) would break the assay loop above when the cell is re-run.
    matching_assays = [pair[0] for pair in important_feat_dict if pair[1] == ct_label]
    if not matching_assays:
        print(f"No important features recorded for {ct_label}. Skipping UpSet plot.")
        continue

    feat_90 = {
        assay_label: important_feat_dict[(assay_label, ct_label)].values()
        for assay_label in matching_assays
    }
    _save_upset_plot(feat_90, general_logdir / f"upset_{ct_label}_w_rna.png")

### Violin plots

In [None]:
def rank_violin_plots(logdir: str | Path, selected_classes: list[str]) -> None:
    """Create violin plots of important feature ranks for selected classes.

    For each class of interest, one figure is built with one violin per comparison
    class. Each violin shows all the ranks found in the "Feature_*" columns of
    `important_<class_of_interest>_features_in_<comparison_class>_shap_matrix.csv`.
    Each figure is written to `logdir` as PNG and HTML, then shown.

    Args:
        logdir (str | Path): Directory containing the rank CSV files; also the output directory.
        selected_classes (list[str]): Classes to plot, used both as class of interest and as comparison class.
    """
    logdir = Path(logdir)
    palette = px.colors.qualitative.Plotly

    # Iterate through each class of interest
    for class_of_interest in selected_classes:
        fig = go.Figure()
        n_features = 0

        # Iterate through each class to add its violin plot to the figure
        for trace_idx, comparison_class in enumerate(selected_classes):
            # Same name transformation as where the CSV is written
            # (spaces replaced by underscores), otherwise the file is not found.
            csv_name = f"important_{class_of_interest}_features_in_{comparison_class}_shap_matrix.csv".replace(
                " ", "_"
            )
            # Remove non-feature columns from DataFrame for plotting
            df = pd.read_csv(logdir / csv_name).filter(like="Feature_")
            n_features = df.shape[1]

            # Create violin plot for the current comparison_class
            violin = go.Violin(
                y=df.values.flatten(),  # Flattened feature ranks
                name=f"{comparison_class} (n={df.shape[0]})",  # Name of the violin plot
                box_visible=True,  # Display box inside the violin
                # Cycle through the palette so that many classes cannot overrun it
                line_color=palette[trace_idx % len(palette)],
                points="all",  # Display all points
            )
            fig.add_trace(violin)

            print(df.shape)

        # Set title and axis labels
        fig.update_layout(
            title=f"Violin plot of ranks for important features of '{class_of_interest}' ({n_features} features)",
            xaxis_title="Source Matrix",
            yaxis_title="Feature Rank",
        )

        # Save, then show the figure
        fig.write_image(
            logdir / f"important_{class_of_interest}_feature_ranks_violin_plot.png"
        )
        fig.write_html(
            logdir / f"important_{class_of_interest}_feature_ranks_violin_plot.html"
        )
        fig.show()

In [None]:
# rank_violin_plots(
#     logdir=logdir / "feature_rank_analysis", selected_classes=["female", "male"]
# )

### Samples ontology class pairs important features overlap

In [None]:
# feature_union = set()
# feature_intersection = set(list(important_features.values())[0][90])
# for label, features in important_features.items():
#     features_90 = features[90]
#     print(f"\n\nClass: {label}")
#     print(f"Most frequent features in 90th quantile: {features_90}")
#     feature_union.update(features_90)
#     feature_intersection &= set(features_90)

# print(f"\n\nUnion of all features: {len(feature_union)} features")
# print(
#     f"\n\nIntersection of all features: {len(feature_intersection)} features: {feature_intersection}"
# )

In [None]:
def compute_intersections(important_features: dict, percentile: int = 90) -> pd.DataFrame:
    """Compute the intersection of important features for every pair of classes.

    Args:
        important_features (dict): Dictionary where keys are class labels and values map
            a percentile to the collection of important features for that percentile.
        percentile (int, optional): Key selecting which feature collection of each class
            is compared. Default is 90.

    Returns:
        pd.DataFrame: One row per unordered pair of classes (in input order), with the
        columns "Set1_Label", "Set2_Label", "Set1_Size", "Set2_Size", "Intersection"
        (a set) and "Intersection_Size". With fewer than two classes the DataFrame is
        empty but still has these columns.

    Raises:
        KeyError: If a class has no entry for `percentile`.
    """
    columns = [
        "Set1_Label",
        "Set2_Label",
        "Set1_Size",
        "Set2_Size",
        "Intersection",
        "Intersection_Size",
    ]
    feature_sets = {
        label: set(quantile_features[percentile])
        for label, quantile_features in important_features.items()
    }

    records = []
    for (label1, set1), (label2, set2) in itertools.combinations(feature_sets.items(), 2):
        intersection = set1 & set2
        records.append(
            {
                "Set1_Label": label1,
                "Set2_Label": label2,
                "Set1_Size": len(set1),
                "Set2_Size": len(set2),
                "Intersection": intersection,
                "Intersection_Size": len(intersection),
            }
        )

    # Explicit columns keep the schema stable even when there is no pair to report
    return pd.DataFrame(records, columns=columns)

In [None]:
# df = compute_intersections(important_features)
# df.to_csv(logdir / "feature_intersections_q90.csv", index=False)

In [None]:
# counter = Counter()
# for feature_set in df["Intersection"]:
#     counter.update(feature_set)

# for k, v in counter.most_common():
#     print(f"{k} {v}")

In [None]:
# feature_selection = []
# sort_order = np.argsort(feature_selection)

# bed_vals = bins_to_bed_ranges(sorted(feature_selection), chroms, resolution=RESOLUTION)

# write_to_bed(bed_ranges=bed_vals, bed_path=logdir / "frequent_features.bed", verbose=True)

In [None]:
# for val in np.array(bed_vals)[sort_order.argsort()]:
#     print(val)