In [None]:
"""Code related to SHAP analyses for flagship paper + supplementary figures."""

# pylint: disable=line-too-long, redefined-outer-name, import-error, duplicate-code, unreachable, unused-argument, use-dict-literal, unused-import

How the relevant files were downloaded:

~~~bash
base_path="/home/rabyj/scratch/epilap-logs/epiatlas-dfreeze-v2.1/hg38_100kb_all_none/harmonized_sample_ontology_intermediate_1l_3000n/10fold-oversampling"
rsync --info=progress2 -aR narval:${base_path}/./split*/*.md5 .

# rsync --info=progress2 -aR --exclude "*.npz" --exclude "analysis_n*_f80.00/" narval:${base_path}/./split*/shap .
rsync --info=progress2 -aR narval:${base_path}/./split*/shap/*background*.npz .
~~~

In [None]:
# Automatically reload edited Python modules (e.g. epi_ml) before executing each cell.
%load_ext autoreload
%autoreload 2

In [None]:
from __future__ import annotations

import gzip
import json
import tarfile
from collections import Counter, defaultdict
from math import floor
from pathlib import Path
from typing import Collection, Dict, List, Tuple

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from IPython.display import display
from plotly.subplots import make_subplots
from scipy.special import softmax

from epi_ml.core import metadata
from epi_ml.core.data_source import EpiDataSource
from epi_ml.core.epiatlas_treatment import EpiAtlasFoldFactory, EpiAtlasMetadata
from epi_ml.utils import modify_metadata
from epi_ml.utils.bed_utils import bed_ranges_to_bins, read_bed_to_ranges
from epi_ml.utils.notebooks.paper.paper_utilities import (
    ASSAY,
    CANCER,
    CELL_TYPE,
    SEX,
    MetadataHandler,
)

In [None]:
# Root of the paper outputs; paper_dir is simply an alias of base_dir.
base_dir = Path.home() / "Projects" / "epiclass" / "output" / "paper"
paper_dir = base_dir

base_fig_dir = paper_dir / "figures"
base_data_dir = base_dir.joinpath(
    "data", "training_results", "dfreeze_v2", "hg38_100kb_all_none"
)
if not base_data_dir.exists():
    raise FileNotFoundError(f"Directory {base_data_dir} does not exist.")

In [None]:
# Chromosome sizes file (hg38; "noy" in the name presumably means chrY excluded).
chromsize_path = paper_dir.joinpath("data", "chromsizes", "hg38.noy.chrom.sizes")
chroms = EpiDataSource.load_external_chrom_file(chromsize_path)

In [None]:
metadata_handler = MetadataHandler(paper_dir)
metadata_df = metadata_handler.load_metadata_df("v2")
# Sanity check: table dimensions, then the name of the index.
print(metadata_df.shape, metadata_df.index.name, sep="\n")

### Background selection details

In [None]:
def compare_all_training_vs_background(
    parent_dir: Path,
    label_category: str,
    dataset_handler: EpiAtlasFoldFactory,
    output_dir: Path,
):
    """Create tables that compare the training data to the SHAP background data for each split.

    For every ``split*`` directory in ``parent_dir``, the saved training md5 list and the
    SHAP background md5s are checked against the oversampled training set regenerated by
    ``dataset_handler``. Per-class counts and ratios are then computed for the assay, cell
    type and ``label_category`` metadata categories, averaged across splits, and written to
    ``output_dir / f"shap_metadata_composition_{category}.csv"`` (one file per category).

    Args:
        parent_dir: The directory containing each split/fold directory.
        label_category: The output label category; compared in addition to assay and cell type.
        dataset_handler: Fold factory used to regenerate the (oversampled) training sets.
        output_dir: Directory where the per-category CSV tables are written (created if missing).

    Raises:
        ValueError: If a split directory does not contain exactly one training md5 file or
            exactly one background npz file, if the background md5s are not a subset of the
            training md5s, or if the saved training md5s differ from the regenerated ones.
    """
    my_metadata = dataset_handler.epiatlas_dataset.metadata

    # Oversampled training md5s regenerated by the dataset handler (may contain duplicates).
    oversample_md5_training_composition: Dict[str, List[str]] = defaultdict(list)
    for i, my_data in enumerate(dataset_handler.yield_split(oversample=True)):
        oversample_md5_training_composition[f"split{i}"] = list(my_data.train.ids)

    # md5s saved to file during training (no oversampling info).
    split_md5_composition: Dict[str, Dict[str, List[str]]] = defaultdict(dict)
    # {category: {split: per-class composition table}}
    metadata_split_dfs: Dict[str, Dict[str, pd.DataFrame]] = defaultdict(dict)
    categories = {ASSAY, CELL_TYPE, label_category}

    for split_dir in parent_dir.glob("split*"):
        if not split_dir.is_dir():
            continue
        split = split_dir.name

        training_data_paths = list(split_dir.glob("*training*.md5"))
        if len(training_data_paths) != 1:
            raise ValueError(
                f"Expected exactly one training data file in {split_dir}, found {len(training_data_paths)}."
            )

        background_data_paths = list(split_dir.glob("shap/*background*.npz"))
        if len(background_data_paths) != 1:
            raise ValueError(
                f"Expected exactly one background data file in {split_dir}, found {len(background_data_paths)}."
            )

        training_md5s = set(
            pd.read_csv(training_data_paths[0], index_col=False, header=None)[0].values
        )
        # Context manager closes the npz archive once the md5s have been read.
        with np.load(background_data_paths[0]) as background_data:
            background_md5s = set(background_data["background_md5s"])

        # Sanity check: background data is a subset of training data
        diff = background_md5s - training_md5s
        if diff:
            raise ValueError(
                f"Background data is not a subset of the training data in {split_dir}: {len(diff)} md5s unique to background data."
            )

        # Sanity check: training data set == oversampled training data set
        full_training_md5s = oversample_md5_training_composition[split]
        if training_md5s != set(full_training_md5s):
            raise ValueError(
                f"Training data (unique={len(training_md5s)}) in {split_dir} does not match oversampled training data (unique={len(full_training_md5s)})."
            )

        split_md5_composition[split] = {
            "training": list(full_training_md5s),
            "background": list(background_md5s),
        }
        print(
            f"Split {split} has {len(full_training_md5s)} training md5s (unique={len(set(full_training_md5s))}) and {len(background_md5s)} background md5s."
        )

        # Per-class counts and ratios for each metadata category.
        # Diff ratio = (background ratio - training ratio) * 100, positive when the class is
        # over-represented in the background.
        background_list = split_md5_composition[split]["background"]
        oversampled_training = split_md5_composition[split]["training"]

        for category in categories:
            background_composition = Counter(
                my_metadata[md5][category] for md5 in background_list
            )
            training_composition = Counter(
                my_metadata[md5][category] for md5 in oversampled_training
            )
            pre_oversampling_composition = Counter(
                my_metadata[md5][category] for md5 in set(oversampled_training)
            )

            background_ratios = {
                key: value / len(background_list)
                for key, value in background_composition.items()
            }
            training_ratios = {
                key: value / len(oversampled_training)
                for key, value in training_composition.items()
            }

            # The background is only a subset of training, so a class may be missing from
            # one of the two; count it as a ratio of 0 there.
            diff_ratios = {
                key: (background_ratios.get(key, 0.0) - training_ratios.get(key, 0.0))
                * 100
                for key in set(background_ratios) | set(training_ratios)
            }

            # One table per (category, split), keyed by the category it describes.
            metadata_split_dfs[category][split] = pd.DataFrame(
                {
                    "Background size": background_composition,
                    "Pre-oversampling training size": pre_oversampling_composition,
                    "Training size": training_composition,
                    "Background Ratio": background_ratios,
                    "Training Ratio": training_ratios,
                    "(Bg ratio - Tr ratio)*100": diff_ratios,
                }
            )

    splits = list(oversample_md5_training_composition.keys())
    full_split_sizes = pd.DataFrame(
        {
            "Training size": {
                split: len(set(split_md5_composition[split]["training"]))
                for split in splits
            },
            "Background size": {
                split: len(set(split_md5_composition[split]["background"]))
                for split in splits
            },
        }
    )
    full_split_sizes.loc["all", :] = full_split_sizes.mean(axis=0)

    output_dir.mkdir(parents=True, exist_ok=True)

    # Average each per-class table across splits, per category.
    for category, split_dfs in metadata_split_dfs.items():
        mean_df = (
            pd.concat(split_dfs.values())
            .groupby(level=0)
            .mean()
            .sort_values("Background Ratio", ascending=False)
            .rename(
                columns={
                    "Background size": "Avg. Background size",
                    "Pre-oversampling training size": "Avg. Pre-oversampling training size",
                    "Training size": "Avg. Training size",
                }
            )
        )

        mean_df.loc["all", "Avg. Background size"] = full_split_sizes.loc[
            "all", "Background size"
        ]
        mean_df.loc["all", "Avg. Pre-oversampling training size"] = full_split_sizes.loc[
            "all", "Training size"
        ]

        mean_df.to_csv(
            output_dir / f"shap_metadata_composition_{category}.csv", index_label=category
        )

In [None]:
# For each classification task, rebuild the dataset with the same metadata pre-processing
# as epiatlas_training.py (per the original note; not re-verified here), then compare each
# split's training set with its SHAP background set (see compare_all_training_vs_background).
# Metadata version "v2-encode" is used for the cancer task: it has sample_cancer_high as a
# label category.
# Previous version of the loop:
# for folder, metadata_version in zip([cell_type_dir, cancer_dir], ["v2", "v2-encode"]):
for task, metadata_version in zip([SEX, CELL_TYPE, CANCER], ["v2", "v2", "v2-encode"]):
    print(f"Processing {task} with metadata version {metadata_version}")
    label_category = task

    # Training output folder of this task (contains the split*/ directories).
    folder = base_data_dir / f"{task}_1l_3000n" / "10fold-oversampling"
    if not folder.exists():
        raise ValueError(f"Folder {folder} does not exist.")

    # special hdf5 filepath with no real path: the .md5 list next to the metadata file is
    # passed as `hdf5` to EpiDataSource below.
    metadata_filename = Path(metadata_handler.version_names[metadata_version])
    metadata_path = paper_dir / "data" / "metadata" / "epiatlas" / metadata_filename
    md5_filepath = metadata_path.parent / f"{metadata_filename.stem}.md5"

    # epiatlas training treatment
    my_datasource = EpiDataSource(
        hdf5=md5_filepath,
        chromsize=chromsize_path,
        metadata=metadata_path,
    )

    my_metadata = metadata.UUIDMetadata(my_datasource.metadata_file)

    # Presumably drops samples whose label is "" / "unknown", "Unique.raw" tracks, and
    # samples missing the label (semantics defined in epi_ml.core.metadata).
    my_metadata.remove_category_subsets(
        label_category=label_category, labels=["", "unknown"]
    )
    my_metadata.remove_category_subsets(
        label_category="track_type", labels=["Unique.raw"]
    )
    my_metadata.remove_missing_labels(label_category)

    # Cell type tasks get an extra (assay, cell type) pair filter; the exact rule is defined
    # by modify_metadata.filter_by_pairs (nb_pairs=9, min_per_pair=10).
    if label_category in set(
        [
            "harmonized_sample_ontology_intermediate",
            "harm_sample_ontology_intermediate",
            "cell_type",
        ]
    ):
        # The assay column name differs between metadata versions.
        categories = set(my_metadata.get_categories())
        if "assay_epiclass" in categories:
            assay_cat = "assay_epiclass"
        elif "assay" in categories:
            assay_cat = "assay"
        else:
            raise ValueError("Cannot find assay category for class pairs.")
        my_metadata = modify_metadata.filter_by_pairs(
            my_metadata,
            assay_cat=assay_cat,
            cat2=label_category,
            nb_pairs=9,
            min_per_pair=10,
        )

    label_list = metadata.env_filtering(my_metadata, label_category)

    # ratio comparison
    print(f"Creating dataset for {label_category}.")
    full_dataset = EpiAtlasMetadata(
        datasource=my_datasource,
        metadata=my_metadata,
        label_category=label_category,
        label_list=label_list,
        min_class_size=10,
        force_filter=True,
    )
    print("Creating dataset handler.")
    # 10 folds and no held-out test set; split{i} in the handler is matched to folder/split{i}.
    dataset_handler = EpiAtlasFoldFactory(
        epiatlas_dataset=full_dataset,
        n_fold=10,
        test_ratio=0,
    )

    output_dir = folder / "shap_analysis"
    output_dir.mkdir(exist_ok=True, parents=True)

    print("Entering compare_all_training_vs_background")
    compare_all_training_vs_background(
        parent_dir=folder,
        label_category=label_category,
        dataset_handler=dataset_handler,
        output_dir=output_dir,
    )

### SHAP values ranking

SHAP rank significance analysis (sample ontology)

- Gather the rank of all bins for each sample's output class (20k x 30k matrix: samples x bins).
- For each 1% chunk of the highest-ranked values, from the top down to 10%, compute the mean SHAP value (20k x 10 matrix).
- Create a graph of the ratio between consecutive chunks (0-1% vs 1-2%, and so on up to 10%) for each sample: one violin or box plot per chunk pair (9 plots), 20k points per plot.

Then, using the flagship cell GO results, get the rank distribution of features that are "unique" to group 1 vs group 2 (T cell, neutrophil: H3K27ac vs H3K27me3).

#### Compute SHAP rankings and 1% chunk means

In [None]:
def load_data(shap_dir: Path) -> Tuple[Dict, Dict]:
    """Load evaluation and background data from specified directory.

    Exactly one ``*.npz`` file whose name contains "evaluation" and exactly one whose name
    contains "background" must be present in ``shap_dir``. Every array is read eagerly and
    the archives are closed before returning.

    Args:
        shap_dir (Path): Directory containing the SHAP data files.

    Returns:
        tuple: ``(eval_results, background_data)``, each a dict mapping array name to array.

    Raises:
        ValueError: If the number of evaluation or background files is not exactly one.
    """
    # Find all npz files once
    npz_files = list(shap_dir.glob("*.npz"))
    eval_files = [f for f in npz_files if "evaluation" in f.name]
    background_files = [f for f in npz_files if "background" in f.name]

    if len(eval_files) != 1:
        raise ValueError(
            f"Expected one evaluation file, found {len(eval_files)} in {shap_dir}"
        )
    if len(background_files) != 1:
        raise ValueError(
            f"Expected one background file, found {len(background_files)} in {shap_dir}"
        )

    # allow_pickle is required for object arrays; only use it on trusted, self-generated files.
    # dict(...) loads every array, so the NpzFile handle can be closed right away.
    with np.load(eval_files[0], allow_pickle=True) as eval_npz:
        eval_results = dict(eval_npz)
    with np.load(background_files[0], allow_pickle=True) as background_npz:
        background_data = dict(background_npz)

    return eval_results, background_data


def compute_mean_shap_values(
    shap_matrices: List[np.ndarray],
    md5_indices: List[str],
    classes: Dict[str, int],
    metadata_df: pd.DataFrame,
) -> Dict[str, Dict]:
    """Rank amplitude of shap values and computes the mean SHAP values for specific segments of feature importance rankings
    for each sample, and aggregates this information along with their softmax transformations.

    This function processes SHAP values for given classes, extracts the relevant SHAP values
    for each sample using its MD5 index, ranks these values, and calculates the mean of these values
    in the top 10% segments. Additionally, it applies a softmax transformation to the SHAP values
    and computes the mean for these as well, allowing for comparison between raw and transformed importance.

    Args:
        shap_matrices (List[np.ndarray]): SHAP values for all classes.
        md5s (List[str]): Indices corresponding to MD5 hashes.
        classes (dict): Mapping of class names to indices.
        metadata_df (pd.DataFrame): DataFrame containing metadata.

    Returns:
        Dict[str, Dict]: Dictionary of SHAP details for each sample ({md5: details}).
    """
    input_size = len(shap_matrices[0][0])
    chunks_10perc_idx = [
        (floor(input_size / 100) * i, floor(input_size / 100) * (i + 1))
        for i in range(10)
    ]
    shap_details = {}

    for md5_idx, md5 in enumerate(md5_indices):
        cell_type: str = metadata_df.loc[md5][CELL_TYPE]  # type: ignore
        class_idx = classes[cell_type]
        sample_shaps = shap_matrices[class_idx][md5_idx]

        sample_shaps = np.abs(sample_shaps)  # magnitude only
        softmax_shaps = softmax(sample_shaps)
        ranks = np.argsort(sample_shaps)[::-1]  # descending order, greatest to smallest

        sorted_shap_vals = sample_shaps[ranks]
        sorted_softmax_shap_vals = softmax_shaps[ranks]

        mean_10perc_vals = [
            np.mean(sorted_shap_vals[idx1:idx2], dtype=np.float64)
            for idx1, idx2 in chunks_10perc_idx
        ]
        mean_10perc_vals_softmax = [
            np.mean(sorted_softmax_shap_vals[idx1:idx2], dtype=np.float64)
            for idx1, idx2 in chunks_10perc_idx
        ]

        shap_details[md5] = {
            "ranks": ranks,
            "mean_10perc_vals": mean_10perc_vals,
            "mean_10perc_vals_softmax": mean_10perc_vals_softmax,
        }

    return shap_details

In [None]:
def create_shap_tables(shap_details: Dict[str, Dict]) -> Dict[str, pd.DataFrame]:
    """Create tables containing SHAP details for each sample.

    Args:
        shap_details (Dict[str, Dict]): Dictionary of SHAP details for each sample
            format: {md5: {"ranks": vals, "mean_10perc_vals": vals, "mean_10perc_vals_softmax": vals}}

    Returns:
        Dict[str, pd.DataFrame]: Tables indexed by md5, keyed by detail category:
            - "ranks": int32 sorting indices (one column per rank position; not true ranks).
            - "mean_10perc_vals": float64 means of the ten top 1% segments.
            - "mean_10perc_vals_softmax": same, computed on softmax-transformed values.
        An empty ``shap_details`` yields empty tables; the two mean tables keep their
        10 named columns.
    """
    mean_columns = [f"mean(top {i}% to {i+1}%)" for i in range(10)]

    def _table(key: str, dtype: str, columns=None) -> pd.DataFrame:
        """Build a one-row-per-md5 table from ``details[key]``."""
        # Passing `columns` to from_dict (instead of assigning afterwards) also works
        # when shap_details is empty.
        return pd.DataFrame.from_dict(
            data={md5: details[key] for md5, details in shap_details.items()},
            orient="index",
            dtype=dtype,
            columns=columns,
        )

    return {
        "ranks": _table("ranks", "int32"),  # actually sorting indices, not ranks
        "mean_10perc_vals": _table("mean_10perc_vals", "float64", mean_columns),
        "mean_10perc_vals_softmax": _table(
            "mean_10perc_vals_softmax", "float64", mean_columns
        ),
    }

In [None]:
def save_shap_tables(
    tables: Dict[str, pd.DataFrame], output_dir: Path, verbose: bool = True
):
    """Write every SHAP table into ``output_dir``.

    The "ranks" table is stored as a compressed ``shap_abs_ranks.npz`` archive holding its
    index, columns and values arrays; any other table becomes ``shap_table_{name}.csv``.

    Args:
        tables (Dict[str, pd.DataFrame]): Tables to save, keyed by table name.
        output_dir (Path): Directory to save the tables into.
        verbose (bool): When True, print the path of each written file.
    """
    for table_name, frame in tables.items():
        if table_name != "ranks":
            csv_path = output_dir / f"shap_table_{table_name}.csv"
            frame.to_csv(csv_path)
            if verbose:
                print(f"Saved SHAP details to {csv_path}")
            continue

        npz_path = output_dir / "shap_abs_ranks.npz"
        np.savez_compressed(
            npz_path,
            index=frame.index.values,
            columns=frame.columns.values,
            values=frame.values,
        )
        if verbose:
            print(f"Saved SHAP ranks to {npz_path}")

In [None]:
# Name of the split directory whose SHAP outputs are processed (used as f"{split_name}/shap").
split_name = "split0"

In [None]:
# shap_details = defaultdict(dict)
# for shap_dir in cell_type_dir.glob(f"{split_name}/shap"):
#     if not shap_dir.is_dir():
#         continue

#     try:
#         eval_results, background_data = load_data(shap_dir)
#     except ValueError as e:
#         print(e)
#         continue

#     classes = {class_name: int(idx) for idx, class_name in background_data["classes"]}

#     shap_matrices = eval_results["shap_values"]

#     shap_details.update(
#         compute_mean_shap_values(
#             shap_matrices, eval_results["evaluation_md5s"], classes, metadata_df
#         )
#     )

# del shap_matrices, eval_results, background_data


# shap_tables = create_shap_tables(shap_details)


# output_dir = cell_type_dir / "global_shap_analysis"
# if not output_dir.exists():
#     raise ValueError(f"Output directory {output_dir} does not exist.")

# output_dir = output_dir / f"{split_name}_details"
# output_dir.mkdir(exist_ok=True)
# save_shap_tables(shap_tables, output_dir)

# print(f"Processed {len(shap_details)} samples.")

#### Graph ratios of the computed means

In [None]:
def graph_shap_10perc(shap_table_mean_path: Path, output_dir: Path):
    """Plot the variation of mean SHAP values across consecutive 1% segments (top 10).

    Reads a CSV with one row per sample (first column used as index) and one column per
    segment, then draws one violin per pair of consecutive columns showing
    ``column[i] / column[i + 1]`` across samples. The figure is written to ``output_dir``
    as HTML, PNG and SVG, named after the input file stem, then displayed.

    Args:
        shap_table_mean_path: CSV of per-sample segment means.
        output_dir: Directory for the figure files; created if missing.
    """
    df = pd.read_csv(shap_table_mean_path, index_col=0)
    fig = go.Figure()
    # One violin per pair of consecutive columns; the last column has no successor.
    for idx in range(len(df.columns) - 1):
        col, next_col = df.columns[idx], df.columns[idx + 1]
        ratio = df.iloc[:, idx] / df.iloc[:, idx + 1]
        fig.add_trace(
            go.Violin(
                y=ratio,
                name=f"Ratio {col} / {next_col}",
                points="all",
                marker=dict(size=1),
                hovertemplate="%{text}",
                text=[f"{md5}: {diff}" for md5, diff in zip(df.index, ratio)],
                box_visible=True,
                meanline_visible=True,
                spanmode="hard",
            )
        )
    fig.update_layout(
        title="Ratio of mean SHAP values between 1% segments",
        xaxis_title="1% segment",
        yaxis_title="Ratio",
    )

    # Save figure
    output_dir.mkdir(parents=True, exist_ok=True)
    output = output_dir / shap_table_mean_path.stem

    fig.write_html(f"{output}.html")
    fig.write_image(f"{output}.png")
    fig.write_image(f"{output}.svg")

    fig.show()

In [None]:
# fig_output_dir = base_fig_dir / "flagship" / "features_SHAP_ranks"

# mean_vals_path = (
#     cell_type_dir / "global_shap_analysis" / "shap_details_mean_10perc_vals.csv"
# )
# if not mean_vals_path.exists():
#     raise FileNotFoundError(mean_vals_path)
# graph_shap_10perc(mean_vals_path, fig_output_dir)

In [None]:
def beds_to_bins(
    feature_beds_tar: Path,
    chroms: List[Tuple[str, int]],
    resolution: int = 100 * 1000,
) -> Dict[str, List[int]]:
    """Extract and convert selected .bed files from a tar.gz archive into bins.

    Only regular-file members ending in ".bed" whose path contains "T_cell" or "neutrophil"
    and one of "h3k27ac", "h3k27me3", "h3k9me3" are processed.

    Args:
        feature_beds_tar (Path): The path to the tar.gz file containing .bed files.
        chroms (List[Tuple[str, int]]): List of tuples containing chromosome names and their lengths.
        resolution (int): Bin size passed to ``bed_ranges_to_bins`` (default 100 * 1000,
            i.e. 100 kb bins).

    Returns:
        Dict[str, List[int]]: Maps each member's file stem (with "_feature" removed) to the
            bins derived from its .bed ranges.
    """
    target_cells = ("T_cell", "neutrophil")
    target_markers = ("h3k27ac", "h3k27me3", "h3k9me3")

    def is_target_member(member: tarfile.TarInfo) -> bool:
        """Check that the member is a regular .bed file for a target cell type and marker."""
        name = member.name
        return (
            member.isfile()  # extractfile() returns None for non-regular members
            and name.endswith(".bed")
            and any(cell in name for cell in target_cells)
            and any(marker in name for marker in target_markers)
        )

    feature_bins = {}
    with tarfile.open(feature_beds_tar, "r:gz") as tar:
        for member in tar.getmembers():
            if not is_target_member(member):
                continue
            with tar.extractfile(member) as f:  # type: ignore
                bed_ranges = read_bed_to_ranges(f)
                bed_bins = bed_ranges_to_bins(
                    bed_ranges, chroms=chroms, resolution=resolution
                )
            name = Path(member.name).stem.replace("_feature", "")
            feature_bins[name] = bed_bins

    return feature_bins

#### Graph ranks of relevant important features from previous analyses

In [None]:
# Output directory for the feature SHAP rank figures.
fig_dir = base_fig_dir.joinpath("flagship", "features_SHAP_ranks")

In [None]:
# feature_beds_tar = cell_type_dir / "global_shap_analysis" / "select_beds_top303.tar.gz"

# feature_bins_dict = beds_to_bins(
#     feature_beds_tar=feature_beds_tar,
#     chroms=chroms,
# )

# feature_bins_dict.keys()

In [None]:
# print("Loading rank values.")
# output_dir = cell_type_dir / "global_shap_analysis"
# rank_file = output_dir / "shap_abs_ranks.npz"
# print(rank_file)

# rank_data = {}
# with np.load(rank_file, allow_pickle=True) as f:
#     rank_data["md5s"] = f["index"]
#     rank_data["feature_index"] = f["columns"]
#     rank_data["values"] = f["values"]

In [None]:
def graph_ranks_sample_ontology(
    rank_data: Dict[str, np.ndarray],
    feature_bins: Collection[int],
    metadata_df: pd.DataFrame,
    output_dir: Path,
    title_name: str,
):
    """Graph the rank of specific sample ontology features, from the perspective of different classes.

    Samples are grouped by (cell type, assay) for T cell / neutrophil x H3K27ac / H3K27me3.
    Two figures are written to ``output_dir`` (HTML, PNG, SVG), then displayed:

    - ``feature_ranks_{title_name}``: within each sample group, one box per feature bin
      showing that bin's rank across the group's samples.
    - ``all_feature_ranks_{title_name}``: one violin per sample group, with the ranks of all
      requested bins pooled together.

    Args:
        rank_data: Dict with "md5s" (sample ids, one per row of "values") and "values"
            (per-sample sorting indices: the feature index at each rank position, as saved
            in shap_abs_ranks.npz — assumed to be full permutations of the feature indices).
        feature_bins: Feature (bin) indices whose ranks are plotted.
        metadata_df: Metadata with CELL_TYPE, ASSAY and "md5sum" columns.
        output_dir: Directory for the figure files; created if missing.
        title_name: Used in the figure titles and output file names.
    """
    # metadata groups: t-cell/neutrophil with h3k27ac/h3k27me3
    cell_types = ["T cell", "neutrophil"]
    assays = ["h3k27ac", "h3k27me3"]
    md5s_dict = {
        f"{cell}_{assay}_md5s": list(
            metadata_df[(metadata_df[CELL_TYPE] == cell) & (metadata_df[ASSAY] == assay)][
                "md5sum"
            ].values
        )
        for cell in cell_types
        for assay in assays
    }

    sorted_bins = sorted(feature_bins)
    fig1 = go.Figure()  # one box per bin, grouped by sample group
    fig2 = go.Figure()  # pooled ranks, one violin per sample group
    for name, md5s in md5s_dict.items():
        group_label = name.replace("_md5s", "")
        shared_ranks = rank_data["values"][np.isin(rank_data["md5s"], md5s)]
        nb_samples = shared_ranks.shape[0]

        # Rows are sorting indices (feature index at each rank position). For a permutation,
        # argsort gives the inverse permutation, i.e. the rank position of every feature.
        # Computed once per group instead of scanning every row for every bin.
        feature_positions = np.argsort(shared_ranks, axis=1)

        all_ranks = []
        for bin_index in sorted_bins:
            feature_ranks = feature_positions[:, bin_index].tolist()
            all_ranks.extend(feature_ranks)

            fig1.add_trace(
                go.Box(
                    x=[f"{group_label} (f={len(feature_ranks)})"] * len(feature_ranks),
                    y=feature_ranks,
                    name=f"{name}_bin_{bin_index}",
                    line=dict(color="black", width=0.1),
                    showlegend=False,
                    boxpoints=False,
                    hovertext=[f"bin={bin_index},Rank={rank}" for rank in feature_ranks],
                ),
            )

        fig2.add_trace(
            go.Violin(
                y=all_ranks,
                name=f"{group_label} (f={nb_samples})",
                spanmode="hard",
                box_visible=True,
                meanline_visible=True,
                showlegend=True,
                fillcolor="grey",
                line=dict(color="black", width=1),
                points=False,
            ),
        )

    # Shared y range and "Top N" reference lines for both figures.
    for fig in (fig1, fig2):
        fig.update_layout(yaxis=dict(range=[-5, 20000], autorange=False))
        for y_val in [303, 909]:
            fig.add_hline(
                y=y_val,
                line_width=1,
                line_dash="dash",
                line_color="black",
                annotation_text=f"Top {y_val}",
            )

    fig1.update_layout(
        title=f"{title_name} features ranks (n={len(feature_bins)} bins) - top303 abs(SHAP),80% samples,8/10 fold",
        xaxis_title=f"Sample group ({len(feature_bins)} features, f files)",
        yaxis_title="Rank",
        boxmode="group",
        width=1000,
        height=1000,
    )
    fig2.update_layout(
        title=f"{title_name} features ranks (n={len(feature_bins)} bins mixed) - top303 abs(SHAP),80% samples,8/10 fold",
        xaxis_title="Sample group (f files)",
        yaxis_title="Rank",
        width=1000,
        height=1000,
    )

    # Save figures
    output_dir.mkdir(parents=True, exist_ok=True)
    for fig, prefix in ((fig1, "feature_ranks"), (fig2, "all_feature_ranks")):
        output = output_dir / f"{prefix}_{title_name}"
        fig.write_html(f"{output}.html")
        fig.write_image(f"{output}.png")
        fig.write_image(f"{output}.svg")
        fig.show()

In [None]:
# shared_k27ac_k27me3_bins = set(feature_bins_dict["h3k27ac_T_cells"]) & set(
#     feature_bins_dict["h3k27me3_T_cells"]
# )
# shared_k27ac_k9me3_bins = set(feature_bins_dict["h3k27ac_T_cells"]) & set(
#     feature_bins_dict["h3k9me3_T_cells"]
# )

# unique_k27ac_v_k27me3_bins = set(feature_bins_dict["h3k27ac_T_cells"]) - set(
#     feature_bins_dict["h3k27me3_T_cells"]
# )
# unique_k27ac_v_k9me3_bins = set(feature_bins_dict["h3k27ac_T_cells"]) - set(
#     feature_bins_dict["h3k9me3_T_cells"]
# )

# unique_k27me3_v_k27ac_bins = set(feature_bins_dict["h3k27me3_T_cells"]) - set(
#     feature_bins_dict["h3k27ac_T_cells"]
# )
# unique_k9me3_v_k27ac_bins = set(feature_bins_dict["h3k9me3_T_cells"]) - set(
#     feature_bins_dict["h3k27ac_T_cells"]
# )

In [None]:
# fig_output_dir = fig_dir / "k27ac_k27me3"

# for var_name in [
#     "shared_k27ac_k27me3_bins",
#     "unique_k27ac_v_k27me3_bins",
#     "unique_k27me3_v_k27ac_bins",
# ]:
#     graph_ranks_sample_ontology(
#         rank_data=rank_data,
#         feature_bins=globals()[var_name],
#         metadata_df=metadata_df,
#         output_dir=fig_output_dir,
#         title_name=var_name,
#     )

In [None]:
# fig_output_dir = fig_dir / "k27ac_k9me3"

# for var_name in [
#     "shared_k27ac_k9me3_bins",
#     "unique_k27ac_v_k9me3_bins",
#     "unique_k9me3_v_k27ac_bins",
# ]:
#     graph_ranks_sample_ontology(
#         rank_data=rank_data,
#         feature_bins=globals()[var_name],
#         metadata_df=metadata_df,
#         output_dir=fig_output_dir,
#         title_name=var_name,
#     )

## ChromHMM regulatory features

In [None]:
# Important-feature results (folder "top303") for the regulatory-region feature set.
base_shap_folder = (
    Path.home()
    / "scratch/epiclass/join_important_features/hg38_regulatory_regions_n30321_100kb_coord/harmonized_sample_ontology_intermediate_1l_3000n/10fold-oversampling/global_shap_analysis/top303"
)

# Same location as Path.home() / "Projects/epiclass/output/paper/data/ChromHMM", derived
# from paper_dir so the paper root is configured in one place only.
chrom_hmm_folder = paper_dir / "data" / "ChromHMM"

for path in (base_shap_folder, chrom_hmm_folder):
    if not path.exists():
        raise FileNotFoundError(path)

corr_path = chrom_hmm_folder / "StackedChromHMM_hg38_EnhancerMaxK27acCorrelations.txt.gz"
other_info_path = (
    chrom_hmm_folder
    / "StackedChromHMM_hg38_Xie_AllInteractionsBackgroundCorr_02PearAdj_JointCorrelations_gABCOnly_RemPromPCorr_REMLabels.txt.gz"
)

# TODO: cross-reference both ChromHMM files to find negative correlation values among the 30k (n30321) regions.
# TODO: basic most-important-features analysis: take core6 pval + other marks and do an UpSet plot of the most important features, ignoring cell types.
# TODO: compare important features with negative correlation values from ChromHMM.

### Merge ChromHMM files

In [None]:
# corr_path and other_info_path already include chrom_hmm_folder; no need to join it again.
with gzip.open(corr_path, "rt") as f:
    corr_df = pd.read_csv(f, sep="\t", header=0)

# This file has no header row: columns are referred to by position (0, 1, 2, 3) below.
with gzip.open(other_info_path, "rt") as f:
    other_info_df = pd.read_csv(f, sep="\t", header=None)

In [None]:
# Dimensions of both raw ChromHMM tables, printed on one line.
print(*(frame.shape for frame in (corr_df, other_info_df)))

In [None]:
# First three columns of the correlation table, used as merge keys (presumably chrom/start/end).
bed_regions = corr_df.columns[:3].tolist()

In [None]:
# Inner-join the extra info on the three region columns, drop the duplicated key columns,
# and keep column 3 as a string "tag" column.
full_corr_df = (
    corr_df.merge(other_info_df, left_on=bed_regions, right_on=[0, 1, 2])
    .drop(columns=[0, 1, 2])
    .assign(tag=lambda merged: merged[3].astype(str))
    .drop(columns=[3])
)

In [None]:
# Summary statistics of the "max(abs(spear_r))" column.
full_corr_df.loc[:, "max(abs(spear_r))"].describe()

In [None]:
# Number of regions in the feature set (matches "n30321" in base_shap_folder's name).
N_REGULATORY_REGIONS = 30321
# Assumes rows 0..N-1 of the merged table are those regions, in feature-index order.
if len(full_corr_df) < N_REGULATORY_REGIONS:
    raise ValueError(
        f"Expected at least {N_REGULATORY_REGIONS} rows in full_corr_df, got {len(full_corr_df)}."
    )
used_corr_df = full_corr_df.loc[range(0, N_REGULATORY_REGIONS), :]

In [None]:
# Tag distribution (%) over every region of the merged ChromHMM table.
baseline_tag_counts = full_corr_df["tag"].value_counts()
baseline_tag_dist = baseline_tag_counts.div(baseline_tag_counts.sum()).mul(100)
display(baseline_tag_dist)

In [None]:
# Tag distribution (%) and raw counts over the regions kept in used_corr_df.
used_tag_counts = used_corr_df["tag"].value_counts()
used_tag_dist = used_tag_counts.div(used_tag_counts.sum()).mul(100)
display(used_tag_dist, used_tag_counts)

### Collect important regulatory-region NN features and count their ChromHMM tags

In [None]:
# For every subsampling folder, compare the ChromHMM tag distribution of its important
# features with the distribution over the regions in used_corr_df (used_tag_dist).
data: List[Dict] = []
for folder in sorted(base_shap_folder.iterdir()):
    important_features_path = folder / "features_n8_all.json"
    if not folder.is_dir() or not important_features_path.exists():
        continue

    with open(important_features_path, "r", encoding="utf8") as f:
        important_features = json.load(f)

    # Presumably a list of feature (row) indices. Index used_corr_df (the first 30321 rows
    # of full_corr_df, the same table used_tag_dist is computed from) so that an
    # out-of-range index fails loudly instead of silently picking an unused region.
    features_details = used_corr_df.iloc[important_features, :]

    tag_count = features_details["tag"].value_counts()
    tag_dist = tag_count / tag_count.sum() * 100
    dist_diff = tag_dist - used_tag_dist
    fold_change = tag_dist / used_tag_dist

    # Store results
    folder_data = {"subsampling": folder.name}
    for tag in tag_count.index:
        folder_data[f"{tag}_count"] = tag_count[tag]
        folder_data[f"{tag}_perc"] = tag_dist[tag]
        folder_data[f"{tag}_perc_diff"] = dist_diff[tag]
        folder_data[f"{tag}_fc"] = fold_change[tag]

    data.append(folder_data)

if not data:
    raise FileNotFoundError(
        f"No subsampling folder with 'features_n8_all.json' found in {base_shap_folder}"
    )

# Tags absent from a folder's important features get 0 in all of their columns.
df = pd.DataFrame(data).fillna(0).set_index("subsampling")
df.to_csv(
    base_shap_folder / "important_features_vs_chromHMM_tags.csv", header=True, sep=","
)