# Analyze shaps

## SETUP

In [None]:
"""Initial analysis of shap values behavior."""
# pylint: disable=redefined-outer-name, expression-not-assigned, import-error, not-callable, pointless-statement, no-value-for-parameter, undefined-variable, unused-argument, line-too-long, use-dict-literal, too-many-lines, unused-import, unused-variable, too-many-branches
from __future__ import annotations

import copy
import itertools
import json
import re
import sys
from collections import Counter, defaultdict
from pathlib import Path
from typing import Any, Dict, List, Sequence, Set, Tuple

import matplotlib as mpl
import matplotlib.patheffects as path_effects
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px  # type: ignore
import plotly.graph_objects as go  # type: ignore
import plotly.io as pio  # type: ignore
import upsetplot  # type: ignore
from IPython.display import display
from matplotlib_venn import venn3  # type: ignore

pio.renderers.default = "notebook"

from epi_ml.core import metadata
from epi_ml.core.data_source import EpiDataSource
from epi_ml.utils.bed_utils import bins_to_bed_ranges, write_to_bed
from epi_ml.utils.general_utility import get_valid_filename
from epi_ml.utils.metadata_utils import count_combinations, count_pairs
from epi_ml.utils.shap.subset_features_handling import (
    collect_all_features_from_feature_count_file,
    collect_features_from_feature_count_file,
    process_all_subsamplings,
)

# Metadata category names. They are used both to filter metadata labels and to
# build the classifier result folder names (f"{category}_1l_3000n") further below.
BIOMATERIAL_TYPE = "harmonized_biomaterial_type"
CELL_TYPE = "harmonized_sample_ontology_intermediate"
ASSAY = "assay_epiclass"
SEX = "harmonized_donor_sex"
CANCER = "harmonized_sample_cancer_high"
DISEASE = "harmonized_sample_disease_high"
LIFE_STAGE = "harmonized_donor_life_stage"
TRACK = "track_type"

In [None]:
%matplotlib inline

In [None]:
# Local project layout: every input is resolved relative to ~/Projects.
home = Path().home() / "Projects"
input_dir = home / "epiclass/input"
metadata_path = (
    input_dir
    / "metadata/dfreeze-v2/hg38_2023-epiatlas-dfreeze_v2.1_w_encode_noncore_2.json"
)
# Metadata object used by the label filtering/counting helpers in this notebook.
my_meta = metadata.Metadata(metadata_path)
# Chromosome sizes, passed as `chroms`/`chromsizes` to the bin -> BED conversions below.
chroms = EpiDataSource.load_external_chrom_file(
    input_dir / "chromsizes/hg38.noy.chrom.sizes"
)

In [None]:
# Bin size given as `resolution` to bins_to_bed_ranges (100 * 1000 = 100000;
# presumably base pairs, i.e. 100 kb bins -- TODO confirm).
RESOLUTION = 100 * 1000

In [None]:
# Category inspected in the next cell. NOTE: `category` is reassigned by the
# analysis loops further down.
category = CELL_TYPE

In [None]:
# Drop the samples lacking a label for `category` (modifies my_meta in place),
# then show the remaining labels.
my_meta.remove_missing_labels(category)
my_meta.display_labels(category)

## Analyze important features over all folds

In [None]:
# Root folder of the results to analyze. The commented-out assignment is an
# alternative location kept for reference.
# logs_dir = (
#     Path.home()
#     / "mounts/narval-mount/project-rabyj/epilap/output/logs/epiatlas-dfreeze-v2.1/hg38_100kb_all_none"
# )
logs_dir = Path.home() / "scratch/epiclass/join_important_features/"

# Fail early rather than silently globbing an empty/nonexistent tree later on.
if not logs_dir.exists():
    raise ValueError(f"Logs dir {logs_dir} does not exist")

### Collect information from all splits

In [None]:
def _common_type_repr(items) -> str:
    """Return the shared type hint of `items`, or "Any" if they are empty or heterogeneous."""
    type_reprs = {type_hint_repr(item) for item in items}
    return next(iter(type_reprs)) if len(type_reprs) == 1 else "Any"


def type_hint_repr(obj: Any) -> str:
    """
    Generates a string representation of the type hint for the given object.

    Containers (dict, defaultdict, list, set, tuple) are described recursively.
    Homogeneous containers report their element type, while empty or
    heterogeneous ones report "Any". Tuples list the type of every element.

    Args:
    obj (Any): The object to generate type hint for.

    Returns:
    str: A string representation of the type hint, e.g. "Dict[str, List[int]]".
    """
    if isinstance(obj, dict):
        # defaultdict is a dict subclass, so it has to be tested explicitly:
        # isinstance(obj, dict) alone can never tell the two apart.
        dict_type = "defaultdict" if isinstance(obj, defaultdict) else "Dict"
        key_type = _common_type_repr(obj.keys())
        value_type = _common_type_repr(obj.values())
        return f"{dict_type}[{key_type}, {value_type}]"
    if isinstance(obj, list):
        return f"List[{_common_type_repr(obj)}]"
    if isinstance(obj, set):
        return f"Set[{_common_type_repr(obj)}]"
    if isinstance(obj, tuple):
        element_types = tuple(type_hint_repr(element) for element in obj)
        return f"Tuple[{', '.join(element_types)}]"

    # Any other object is described by its class name.
    return obj.__class__.__name__

In [None]:
def _remove_dir_if_possible(folder: Path) -> None:
    """Remove `folder` if it is empty; report (instead of raising) when it cannot be removed."""
    try:
        folder.rmdir()
    except OSError as err:
        print(f"Could not remove folder {folder}: {err}")


def join_important_features(
    parent_folder: Path, top_n: int, frequency_threshold: str
) -> Dict[str, Dict[str, Dict[str, Dict[str, List[int]]]]]:
    """Join important features from all folds

    Reads every
    "split*/shap/analysis_n{top_n}_f{frequency_threshold}/<folder>/important_features.json"
    file under parent_folder.

    WARNING: this function cleans up the input tree. Folders without an
    important_features.json file are removed when empty, and json files with
    empty content are deleted along with their folder. Folders that cannot be
    removed (e.g. because they still contain other files) are reported and skipped.

    Args:
        parent_folder (Path): Parent folder of split folders.
        top_n (int): Number of top features to consider for each sample,
                    used to select the analysis folder.
        frequency_threshold (str): Frequency threshold used to select the analysis folder.

    Returns:
        Dict: all_important_features
        - level 1: folder_name (str): split_dict
        - level 2: split_name (str) : all_class_features (dict)
        - level 3: class_label (classifier outputs, str): important features for each frequency threshold (dict)
        - level 4: frequency_threshold (str, 0 to 100): list of features (List[int])

    Raises:
        ValueError: If a matched path is not located under a "split*" folder.
    """
    print("WARNING: Consider using a local copy of the data to speed up the process.")
    all_important_features = defaultdict(dict)
    pattern = f"split*/shap/analysis_n{top_n}_f{frequency_threshold}/*"

    # Sorted so that the content order of the result does not depend on the
    # filesystem listing order.
    for important_features_path in sorted(parent_folder.glob(pattern)):
        # The glob also matches plain files sitting next to the result folders.
        if not important_features_path.is_dir():
            continue

        # <parent>/<split>/shap/<analysis>/<folder> -> parents[2] is the split folder
        split = important_features_path.parents[2].name
        if "split" not in split:
            raise ValueError(f"Split not found in {split}")
        folder = important_features_path.name

        json_path = important_features_path / "important_features.json"
        try:
            with open(json_path, "r", encoding="utf8") as f:
                features = json.load(f)
        except FileNotFoundError:
            print(f"File {json_path} not found")
            _remove_dir_if_possible(important_features_path)
            continue

        if not features:
            print(f"File {json_path} is empty. Removing folder.")
            json_path.unlink()
            _remove_dir_if_possible(important_features_path)
            continue

        all_important_features[folder][split] = features
    # print(type_hint_repr(all_important_features))
    return all_important_features

In [None]:
def compare_kfold_shap_analysis(
    important_features_all_splits: Dict[str, Dict[str, Dict[str, List[int]]]],
    resolution: int,
    chromsizes: List[Tuple[str, int]],
    output_folder: Path,
    name: str,
    chosen_percentile: float | int = 80,
    minimum_count: int = 8,
    verbose: bool = False,
) -> Dict[str, List[Tuple[int, int]]]:
    """Compare the SHAP analysis results from multiple splits and writes the results based on frequency.

    For each class label, count in how many splits each feature appears at the
    chosen percentile. Features seen in at least `minimum_count` splits are
    written to one BED file per class in `output_folder`.

    Args:
        important_features_all_splits: Dictionary containing important features for each split
            (split name -> class label -> percentile label -> features).
        resolution (int): Resolution for binning.
        chromsizes (List[Tuple[str, int]]): List with chromosome names and sizes.
        output_folder (Path): Output directory for writing BED files. Created if needed,
            but only when at least one class has selected features.
        name (str): Name of the analysis. Used in the BED filenames.
        chosen_percentile (float | int, optional): The chosen percentile for selecting important features.
            str(chosen_percentile) must match one of the percentile labels of the input.
        minimum_count (int, optional): The minimum count of splits that a feature must be present in.
        verbose (bool, optional): If True, report on stderr the classes without any selected feature.

    Returns:
        Dict[str, List[Tuple[int, int]]]: For each class label (sorted), the (feature, count) pairs
        over all splits for the chosen percentile, most frequent first.

    Raises:
        ValueError: If the chosen percentile is not a percentile label of the input.
    """
    percentile_key = str(chosen_percentile)

    class_labels = set()
    percentile_labels = set()
    # print(type_hint_repr(important_features_all_splits))
    for split_dict in important_features_all_splits.values():
        class_labels.update(split_dict)
        for class_dict in split_dict.values():
            percentile_labels.update(class_dict)

    if percentile_key not in percentile_labels:
        raise ValueError(
            f"Chosen percentile {chosen_percentile} not found in {percentile_labels}."
        )

    # Sorted labels: the key order of the result (and of any json dump made from
    # it) is then identical from one run to the next.
    class_features_frequency: Dict[str, List] = {}
    for class_label in sorted(class_labels):
        feature_counter = Counter()

        # Count the occurrence of each feature across all splits
        for split_dict in important_features_all_splits.values():
            current_features = split_dict.get(class_label, {}).get(percentile_key, [])
            feature_counter.update(current_features)

        class_features_frequency[class_label] = feature_counter.most_common()

    # Select features present in at least a certain count of splits (e.g., 8 out of 10)
    for class_label, feature_count_list in class_features_frequency.items():
        selected_features = {
            feature for feature, count in feature_count_list if count >= minimum_count
        }

        if not selected_features:
            if verbose:
                print(
                    f"No features meeting the required count for class {class_label}",
                    file=sys.stderr,
                )
            continue

        bed_vals = bins_to_bed_ranges(
            sorted(selected_features), chromsizes, resolution=resolution
        )

        bed_filename = get_valid_filename(
            f"selected_features_{name}_f{chosen_percentile:.2f}_count{minimum_count}_{class_label}.bed"
        )

        output_folder.mkdir(exist_ok=True, parents=True)
        write_to_bed(
            bed_vals,
            output_folder / bed_filename,
            verbose=False,
        )

    return class_features_frequency

In [None]:
def perform_global_analysis(
    all_important_features: Dict[str, Dict[str, Dict[str, Dict[str, List[int]]]]],
    resolution: int,
    chromsizes: List[Tuple[str, int]],
    output_folder: Path,
    chosen_percentile: float = 80.0,
    minimum_count: int = 8,
    top_n: int = 303,
    verbose: bool = True,
) -> None:
    """Run the cross-split feature frequency analysis for every subsampling.

    For each entry of all_important_features, the frequency of the features
    over all splits is computed and saved as
    "<output_folder>/global_shap_analysis/top<top_n>/<subsampling>/feature_count.json".
    The features respecting the chosen_percentile and minimum_count thresholds
    are written to BED files in the same folder.

    Raises:
        ValueError: If "<output_folder>/global_shap_analysis" does not exist.
    """
    global_dir = output_folder / "global_shap_analysis"
    for subsampling_name, per_split_features in all_important_features.items():
        # Checked on every iteration, so an empty input never raises.
        if not global_dir.exists():
            raise ValueError(f"Parent folder {global_dir} does not exist")

        target_dir = global_dir / f"top{top_n}" / subsampling_name
        target_dir.mkdir(parents=True, exist_ok=True)

        if verbose:
            print(f"Performing global analysis for {subsampling_name}")

        # The per-class messages of the comparison step are always silenced.
        counts = compare_kfold_shap_analysis(
            important_features_all_splits=per_split_features,
            resolution=resolution,
            chromsizes=chromsizes,
            output_folder=target_dir,
            name=subsampling_name,
            chosen_percentile=chosen_percentile,
            minimum_count=minimum_count,
            verbose=False,
        )

        count_file = target_dir / "feature_count.json"
        with open(count_file, "w", encoding="utf8") as out:
            json.dump(counts, out, indent=4)

In [None]:
# input_dir = (
#     logs_dir
#     / "harmonized_sample_ontology_intermediate_1l_3000n/10fold-oversampling/global_shap_analysis/top303"
# )
# input_dir = logs_dir / "harmonized_sample_cancer_high_1l_3000n/10fold-oversampling/global_shap_analysis/top303/"

In [None]:
# Switch the results root to one specific dataset folder. This overrides the
# logs_dir defined earlier in the notebook.
logs_dir = (
    Path.home()
    / "scratch/epiclass/join_important_features/hg38_regulatory_regions_n30321_100kb_coord"
)

In [None]:
# Categories to analyze, each paired (by position) with the sub-path of its
# results folder. The commented lists hold the full set of tasks.
# categories = [ASSAY, CELL_TYPE, LIFE_STAGE, SEX, CANCER]
# rest_of_paths = [
#     "11c/10fold-oversampling",
#     "10fold-oversampling",
#     "no-unknown/10fold-oversampling",
#     "w-mixed/10fold-oversample",
#     "10fold-oversampling",
# ]
categories = [CELL_TYPE]
rest_of_paths = ["10fold-oversampling"]

# Variables used to find the right analysis folder and write new files
top_n = 303
chosen_percentile = 80.0
minimum_count = 8
# Folder names use the percentile formatted with two decimals (e.g. "80.00").
frequency_threshold = f"{chosen_percentile:.2f}"

# (category, rest_of_path) pairs iterated by the analysis loops below.
pairs = list(zip(categories, rest_of_paths))

In [None]:
# For every (category, sub-path) pair: gather the per-split important features,
# then compute and save the cross-split feature frequencies.
for category, rest_of_path in pairs:
    # if category != ASSAY:
    #     continue

    print(f"Performing global analysis for classifiers of {category}")
    results_dir = logs_dir / f"{category}_1l_3000n/{rest_of_path}"
    if not results_dir.is_dir():
        raise ValueError(f"Directory {results_dir} does not exist.")

    all_important_features = join_important_features(
        results_dir, top_n=top_n, frequency_threshold=frequency_threshold
    )

    # perform_global_analysis raises if this folder is missing, so create it first.
    (results_dir / "global_shap_analysis").mkdir(parents=False, exist_ok=True)

    perform_global_analysis(
        all_important_features=all_important_features,
        resolution=RESOLUTION,
        chromsizes=chroms,
        output_folder=results_dir,
        chosen_percentile=chosen_percentile,
        minimum_count=minimum_count,
        top_n=top_n,
        verbose=False,
    )

### Intersection matrix

In [None]:
def create_subsamplings_intersection_matrix(
    flat_features_dict: Dict[str, Set[int]]
) -> pd.DataFrame:
    """Build the pairwise intersection-size matrix of the given feature sets.

    Args:
        flat_features_dict (Dict[str, Set[int]]): Feature sets, keyed by name.

    Returns:
        pd.DataFrame: Square matrix (dtype=int) where cell (i, j) is the number
        of features shared by sets i and j, so the diagonal holds the size of
        each set. Row and column labels are the lowercased keys of
        flat_features_dict. Rows summing to zero are dropped, together with the
        columns at the same positions.
    """
    all_sets = list(flat_features_dict.values())
    size = len(all_sets)

    # reshape keeps a proper (0, 0) matrix when no set is given
    shared_counts = np.array(
        [
            [len(row_set.intersection(col_set)) for col_set in all_sets]
            for row_set in all_sets
        ],
        dtype=int,
    ).reshape(size, size)

    lowered = [name.lower() for name in flat_features_dict]
    matrix_df = pd.DataFrame(data=shared_counts, index=lowered, columns=lowered)

    # Positions of the rows without any shared feature (not even with themselves)
    empty_positions = np.flatnonzero((matrix_df.sum(axis=1) == 0).to_numpy())
    empty_rows = matrix_df.index[empty_positions]
    empty_columns = matrix_df.columns[empty_positions]

    return matrix_df.drop(index=empty_rows).drop(columns=empty_columns)

In [None]:
def sorting_key_subsampling(value: str, invert: bool = False) -> Tuple[str, str]:
    """Return a case-insensitive sorting key for a subsampling name.

    The name is split on the first " & " into (prefix, suffix); a name without
    separator gets an empty suffix. With invert=True, the key is
    (suffix, prefix) instead of (prefix, suffix).
    """
    head, _, tail = value.partition(" & ")
    key = (head.lower(), tail.lower())
    return key[::-1] if invert else key


def sorting_key_subsampling_invert(value: str) -> Tuple[str, str]:
    """Single-argument variant of sorting_key_subsampling that puts the suffix first.

    Meant to be used directly as a `key=` function.
    """
    inverted_key = sorting_key_subsampling(value, True)
    return inverted_key


def add_count_to_labels(
    matrix_df: pd.DataFrame, my_meta: metadata.Metadata
) -> pd.DataFrame:
    """Append the sample count to the row labels, only works for simple subsamplings.

    Row labels matching an "<assay> & <cell type in lowercase>" pair counted in
    the metadata become "<label> (<count>)". Other rows and all columns are
    left as is. The dataframe is modified in place and also returned.
    """
    counts_by_label = {}
    for names, count in count_pairs(my_meta, ASSAY, CELL_TYPE).items():
        counts_by_label[f"{names[0]} & {names[1].lower()}"] = count

    renamed = {
        label: f"{label} ({counts_by_label[label]})"
        for label in matrix_df.index
        if label in counts_by_label
    }
    matrix_df.rename(index=renamed, inplace=True)
    return matrix_df


def sort_matrix_labels(matrix_df: pd.DataFrame, output_class_first: bool) -> pd.DataFrame:
    """Return a copy of the matrix with rows and columns sorted by subsampling name.

    With output_class_first=True, labels are ordered on the part after " & "
    first; otherwise on the part before it.
    """
    if output_class_first:
        key_func = sorting_key_subsampling_invert
    else:
        key_func = sorting_key_subsampling

    return matrix_df.reindex(
        index=sorted(matrix_df.index, key=key_func),
        columns=sorted(matrix_df.columns, key=key_func),
    )

In [None]:
def create_intersection_matrix_heatmap(
    matrix_df: pd.DataFrame,
    category: str,
    output_folder: Path,
    details: str = "top100, f80, split_count=8",
) -> Path:
    """Create a heatmap of the intersection matrix and save it as a png.

    Args:
        matrix_df (pd.DataFrame): Intersection matrix. Index and columns labels used as is.
        category (str): Category name. Used in the plot title.
        output_folder (Path): Folder where the image is saved.
        details (str, optional): Analysis parameters shown in parentheses at the end
            of the title (and therefore part of the filename).

    Returns:
        Path: Path of the saved image. The filename is derived from the plot title.
    """
    # Create the figure and axis objects
    fig, ax = plt.subplots(figsize=(45, 40))  # Adjust the size as needed

    try:
        # Create a colormap that sets the 'bad' values (masked) to white.
        # plt.get_cmap is used because mpl.cm.get_cmap was deprecated, then removed
        # from matplotlib. The copy avoids modifying the registered colormap.
        masked_array = np.ma.array(matrix_df, mask=(matrix_df == 0))  # pylint: disable=C0325
        cmap = plt.get_cmap("viridis").copy()
        cmap.set_bad(color="white")

        # Display the data
        cax = ax.imshow(masked_array, cmap=cmap, interpolation="none")

        # Add colorbar
        fig.colorbar(cax)

        # Annotate each cell with the numeric value
        # We iterate over the indices and the data in the DataFrame
        outline = [path_effects.withStroke(linewidth=1.5, foreground="black")]
        for (i, j), val in np.ndenumerate(matrix_df):
            if val != 0:  # Skip zero values to keep the white background clean
                ax.text(
                    j,
                    i,
                    int(val),  # type: ignore
                    ha="center",
                    va="center",
                    color="white",
                    path_effects=outline,
                )

        # Adjust the grid lines and labels
        ax.grid(which="major", color="black", linestyle="-", linewidth=1)
        # Shift the major ticks by half a cell so that the grid lines fall on cell borders
        ax.set_xticks(np.arange(len(matrix_df.columns)) - 0.5)
        ax.set_yticks(np.arange(len(matrix_df.index)) - 0.5)

        # The shifted ticks carry no labels, they only position the grid
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.tick_params(axis="both", which="both", length=0)  # Hide the tick marks

        # Add the labels on secondary axes, at the center of each cell
        ax2 = ax.secondary_xaxis("bottom")
        ax2.set_xticks(np.arange(len(matrix_df.columns)))
        ax2.set_xticklabels(matrix_df.columns, rotation=45, ha="right")

        ax3 = ax.secondary_yaxis("left")
        ax3.set_yticks(np.arange(len(matrix_df.index)))
        ax3.set_yticklabels(matrix_df.index)

        # Set the title and save the plot
        title = f"{category} important features groups intersection heatmap ({details})"
        ax.set_title(title)
        fig.tight_layout()
        img_path = output_folder / (get_valid_filename(title) + ".png")
        fig.savefig(str(img_path), dpi=150)
    finally:
        # Always release the (very large) figure, even when plotting or saving fails.
        plt.close(fig)
    return img_path

In [None]:
# Switch the results root again: the following cells read the "hg38_100kb_all_none" results.
logs_dir = Path.home() / "scratch/epiclass/join_important_features/hg38_100kb_all_none"

In [None]:
# Per category: all feature sets returned by process_all_subsamplings, and
# their "global_union" entry alone.
every_task_features = {}
global_union_features = {}
for category, rest_of_path in pairs:
    # Only the cell type task is processed here, so both dicts end up with
    # (at most) a CELL_TYPE entry.
    if category != CELL_TYPE:
        continue

    results_dir = logs_dir / f"{category}_1l_3000n/{rest_of_path}"

    # parents=False: the "global_shap_analysis" folder must already exist.
    global_analysis_folder = results_dir / "global_shap_analysis" / "top303"
    global_analysis_folder.mkdir(parents=False, exist_ok=True)
    print(f"Global analysis folder: {global_analysis_folder}")

    feature_sets = process_all_subsamplings(
        global_analysis_folder, aggregate=True, minimum_count=8, verbose=False
    )
    global_union_features[category] = feature_sets["global_union"]
    every_task_features[category] = feature_sets

    # # Create intersection matrix
    # matrix_df = create_subsamplings_intersection_matrix(feature_sets)
    # matrix_df = sort_matrix_labels(matrix_df, output_class_first=False)

    # matrix_df.to_csv(
    #     global_analysis_folder
    #     / f"intersection_matrix_{category}_top{top_n}_f{frequency_threshold}.csv"
    # )

    # counter = count_combinations(my_meta, set([ASSAY, CELL_TYPE, category]))
    # with open(
    #     global_analysis_folder / f"combination_count_{category}.csv", "w", encoding="utf8"
    # ) as f:
    #     for combination, count in sorted(counter.items()):
    #         f.write(f"{','.join(combination)},{count}\n")

In [None]:
# Map each feature to the names of the "merge_samplings" sets containing it
# (with the "merge_samplings_" prefix stripped from the set names).
feature_enrichment = defaultdict(set)
for key, value in every_task_features[CELL_TYPE].items():
    if "merge_samplings" in key:
        for feature in value:
            feature_enrichment[feature].add(key.replace("merge_samplings_", ""))

In [None]:
# Convert the features (bin indexes), in sorted order, to BED ranges.
bed_ranges = bins_to_bed_ranges(
    bin_indexes=sorted(feature_enrichment.keys()),
    chroms=chroms,
    resolution=RESOLUTION,
)

In [None]:
# Write one line per feature: BED range, feature index, and the sets containing it.
# bed_ranges was computed from the same sorted feature list; this assumes
# bins_to_bed_ranges returns one range per bin, in input order -- TODO confirm.
with open(
    global_analysis_folder / "cell_type_feature_enrichment_bed-details.tsv",
    "w",
    encoding="utf8",
) as f:
    for feature_idx, bed_range in zip(sorted(feature_enrichment.keys()), bed_ranges):
        chrom, start, end = bed_range
        # Sorted so that the written label order is reproducible (sets have no defined order).
        set_names = sorted(feature_enrichment[feature_idx])
        f.write(f"{chrom}\t{start}\t{end}\t{feature_idx}\t{set_names}\n")

In [None]:
# Global analysis folder of the cancer task. It must already exist.
global_cancer_folder = (
    logs_dir / f"{CANCER}_1l_3000n/10fold-oversampling/global_shap_analysis/top303"
)
if not global_cancer_folder.exists():
    raise ValueError(f"Folder {global_cancer_folder} does not exist.")

In [None]:
def make_special_bed_detail(every_task_features):
    """Write a TSV detailing the features shared by the cancer and non-cancer merged sets.

    The reference features are the intersection of the "merge_samplings_cancer"
    and "merge_samplings_non-cancer" sets of the CANCER task. For each of them,
    the "... & cancer" sets containing it are collected. The result is written,
    sorted by feature, to
    logs_dir / "cancer_intersection_merge_samplings_bed-details.tsv"
    with columns: chrom, start, end, feature index, list of origins.

    Relies on the notebook globals CANCER, chroms, RESOLUTION and logs_dir.
    """
    cancer_sets = every_task_features[CANCER]
    shared_features = (
        cancer_sets["merge_samplings_cancer"]
        & cancer_sets["merge_samplings_non-cancer"]
    )
    cancer_subsamplings = [name for name in cancer_sets if "& cancer" in name]

    # feature -> names (without the " & cancer" suffix) of the sets containing it
    origins = defaultdict(set)
    for feature in shared_features:
        for subsampling in cancer_subsamplings:
            if feature in cancer_sets[subsampling]:
                origins[feature].add(subsampling.replace(" & cancer", ""))

    # Every shared feature is expected to come from at least one "& cancer" set
    assert len(origins) == len(shared_features)
    for origin in origins.values():
        assert len(origin) != 0

    # feature -> its single BED range
    bed_by_feature = {}
    for feature in origins:
        ranges = bins_to_bed_ranges(
            bin_indexes=[feature],
            chroms=chroms,
            resolution=RESOLUTION,
        )
        assert len(ranges) == 1
        bed_by_feature[feature] = ranges[0]

    # Write the feature positions with origin detail, sorted by feature
    details_path = logs_dir / "cancer_intersection_merge_samplings_bed-details.tsv"
    with open(details_path, "w", encoding="utf8") as f:
        for feature in sorted(bed_by_feature):
            bed_val = bed_by_feature[feature]
            print(feature)
            f.write(
                f"{bed_val[0]}\t{bed_val[1]}\t{bed_val[2]}\t{feature}\t{sorted(origins[feature])}\n"
            )

In [None]:
# NOTE(review): the collection loop above only fills every_task_features for
# CELL_TYPE, while this function reads every_task_features[CANCER]. This call
# looks like it raises KeyError unless the loop is changed -- confirm.
make_special_bed_detail(every_task_features)

In [None]:
def upsetplot_global_union(global_union_features):
    """Draw an UpSet plot of the per-task global feature unions.

    Only the entries whose name is in the notebook-level `categories` list are plotted.
    """
    selected = {}
    for name, features in global_union_features.items():
        if name in categories:
            selected[name] = features

    memberships = upsetplot.from_contents(selected)
    upset = upsetplot.UpSet(
        memberships,
        subset_size="count",
        show_counts=True,  # type: ignore
        show_percentages=True,
        sort_by="cardinality",
        sort_categories_by="cardinality",
    )
    upset.plot()

In [None]:
# upsetplot_global_union(global_union_features)

In [None]:
def write_global_task_features(
    global_union_features: Dict[str, Set[int]]
) -> Dict[str, List[int]]:
    """Write the global union of features (union of all subsampling sets) for each task to json.

    The content is written to logs_dir / "global_task_features.json"
    (logs_dir being the notebook-level variable). Any existing file is overwritten.

    Args:
        global_union_features (Dict[str, Set[int]]): Global feature set of each task.

    Returns:
        Dict: Augmented global_union_features with total union ("global_tasks_union")
        and intersection ("global_tasks_intersection") over all tasks.
        All feature lists are sorted.
    """
    feature_sets = [set(features) for features in global_union_features.values()]

    global_task_union = set().union(*feature_sets)
    # Without any task, the intersection is defined as empty instead of failing.
    global_task_intersection = set.intersection(*feature_sets) if feature_sets else set()

    # print(len(global_task_union), len(global_task_intersection))

    # Sorted lists of plain ints: reproducible and json-serializable content.
    content = {
        category: sorted(int(feature) for feature in features)
        for category, features in global_union_features.items()
    }
    content["global_tasks_union"] = sorted(int(feature) for feature in global_task_union)
    content["global_tasks_intersection"] = sorted(
        int(feature) for feature in global_task_intersection
    )

    # The file used to be opened for writing (thus emptied) without anything
    # being written to it; the content is now actually dumped.
    with open(logs_dir / "global_task_features.json", "w", encoding="utf8") as f:
        json.dump(content, f, indent=4)

    return content

In [None]:
def create_new_global_beds(global_union_features):
    """Write BED files for the union and the intersection of the features of all tasks.

    For each of the two sets, a second BED file with the same number of randomly
    chosen bins (fixed seed, no repetition) is also written, named after the
    number of features. All files are written in logs_dir.
    """
    N_BINS_HDF5_100KB = 30321
    task_features = write_global_task_features(global_union_features)

    for set_name in ("global_tasks_union", "global_tasks_intersection"):
        # Actual features
        real_bins = task_features[set_name]
        n_real = len(real_bins)
        real_ranges = bins_to_bed_ranges(
            bin_indexes=real_bins,
            chroms=chroms,
            resolution=RESOLUTION,
        )
        real_bed_path = logs_dir / get_valid_filename(f"{set_name}.bed")
        write_to_bed(real_ranges, real_bed_path, verbose=False)

        # Random features of the same size. The generator is re-seeded for each
        # set, so the draw only depends on the requested size.
        rng = np.random.default_rng(42)
        random_bins = rng.choice(
            a=np.arange(0, N_BINS_HDF5_100KB), size=n_real, replace=False
        )
        assert len(random_bins) == len(set(random_bins))  # sanity check: no replace
        random_ranges = bins_to_bed_ranges(
            bin_indexes=list(random_bins),
            chroms=chroms,
            resolution=RESOLUTION,
        )
        random_bed_path = logs_dir / get_valid_filename(f"random_n{n_real}.bed")
        write_to_bed(random_ranges, random_bed_path, verbose=False)

In [None]:
# Deliberate stop: prevents a "run all" from executing the cells below.
raise ValueError("Stop here")

### Venn Diagram over multiple classification tasks

In [None]:
# # pylint: disable=unreachable
# union_important_features = defaultdict(set)
# for category in categories:
#     print(f"Category: {category}")
#     path = list(
#         (logs_dir / f"{category}_1l_3000n").glob(
#             "10fold-oversampl*/global_shap_analysis/top100*"
#         )
#     )
#     if not path:
#         path = list(
#             (logs_dir / f"{category}_1l_3000n").glob(
#                 "*/10fold-oversampl*/global_shap_analysis/top100*"
#             )
#         )
#         if not path:
#             raise ValueError(f"Path {path} does not exist.")
#     path = path[0]
#     print(f"Path: {path}")

#     for folder in path.iterdir():
#         if not folder.is_dir():
#             continue

#         json_path = folder / "feature_count.json"

#         with open(json_path, "r", encoding="utf8") as json_file:
#             feature_count = json.load(json_file)

#         for class_label, class_features in feature_count.items():
#             features = [feature for feature, count in class_features if count >= 8]
#             union_important_features[category].update(features)

In [None]:
# v = venn3(
#     list(union_important_features.values()), set_labels=union_important_features.keys()
# )

In [None]:
# logdir1 = logs_dir / f"{category}_1l_3000n/10fold-oversampling/split1/shap"
# if not logdir1.exists():
#     raise ValueError(f"{logdir1} does not exist.")

In [None]:
# my_meta.display_labels(ASSAY)
# my_meta.display_labels(CELL_TYPE)
# if category not in [ASSAY, CELL_TYPE]:
#     my_meta.display_labels(category)

In [None]:
# Deliberate stop: avoids needlessly running the cells below, while still
# allowing the beginning of the notebook to be run in one shot.
raise ValueError("Stop here")  # pylint: disable=unreachable

## Feature selection/filtering from full jsons

In [None]:
base_path = Path.home() / "scratch/epiclass/join_important_features"
# Path of one feature_count.json file, only used to locate the "top303" folder
# holding all subsampling folders.
feature_count_path = (
    base_path
    / "hg38_regulatory_regions_n30321_100kb_coord/harmonized_sample_ontology_intermediate_1l_3000n/10fold-oversampling/global_shap_analysis/top303/h3k27ac/feature_count.json"
)

feature_count_general_dir = feature_count_path.parent.parent

# For each subsampling folder that has a feature_count.json file, save two
# feature selections (both collected with n=8) next to it.
for folder in feature_count_general_dir.iterdir():
    if not folder.is_dir():
        continue

    # NOTE: overwrites the feature_count_path defined above.
    feature_count_path = folder / "feature_count.json"
    if not feature_count_path.exists():
        continue

    features = collect_features_from_feature_count_file(feature_count_path, n=8)
    with open(folder / "features_n8.json", "w", encoding="utf8") as f:
        json.dump(features, f)

    features_all = collect_all_features_from_feature_count_file(feature_count_path, n=8)
    with open(folder / "features_n8_all.json", "w", encoding="utf8") as f:
        json.dump(features_all, f)

### Union of frequent features

In [None]:
# Load the previously saved global task features (name -> features).
global_important_features_dir = (
    Path.home() / "Projects/epiclass/output/models/SHAP/global_task_features/global_info"
)
global_important_features_path = (
    global_important_features_dir / "global_task_features.json"
)
with open(global_important_features_path, "r", encoding="utf8") as f:
    global_important_features = json.load(f)

In [None]:
# Headerless TSV of the cancer intersection details. Presumably a copy of the
# file written by make_special_bed_detail (note the "_2" suffix) -- TODO confirm.
cancer_special_merge_path = (
    global_important_features_dir
    / "cancer/cancer_intersection_merge_samplings_bed-details_2.tsv"
)
df = pd.read_csv(cancer_special_merge_path, sep="\t", header=None)

In [None]:
# Fourth column (index 3) of the TSV; the feature index if the file follows the
# layout written by make_special_bed_detail -- TODO confirm.
cancer_special_merge_features = df[3].tolist()

In [None]:
# Add the cancer intersection features as a new named feature set.
global_important_features[
    "cancer_intersection_merge_samplings"
] = cancer_special_merge_features

In [None]:
# Save the augmented content back to the file it was loaded from (overwrites it).
with open(global_important_features_path, "w", encoding="utf8") as f:
    json.dump(global_important_features, f, indent=4)

In [None]:
# Write one BED file per named feature set of the global task features.
bed_dir = global_important_features_dir / "global_task_features_beds"
# Create the destination first, as compare_kfold_shap_analysis does before
# calling write_to_bed.
bed_dir.mkdir(parents=True, exist_ok=True)
for name, features in global_important_features.items():
    bed_ranges = bins_to_bed_ranges(
        bin_indexes=features, chroms=chroms, resolution=RESOLUTION
    )
    bed_path = bed_dir / f"{name}_features.bed"
    write_to_bed(bed_ranges, bed_path)