In [None]:
"""Top SHAP features: details the following

- Features unique to ChIP raw tracks
- Features coming from small file subsets (<10)
"""

# pylint: disable=duplicate-code

## SETUP

In [None]:
from __future__ import annotations

from collections import defaultdict
from pathlib import Path
from typing import Dict, List

import pandas as pd
from IPython.display import display

from epiclass.utils.notebooks.paper.paper_utilities import EPIATLAS_16_CT, MetadataHandler
from epiclass.utils.shap.subset_features_handling import (
    collect_features_from_feature_count_file,
)

# Names of metadata columns. Only CELL_TYPE, ASSAY and TRACK are used by the
# cells visible in this part of the notebook (as DataFrame column / groupby
# keys); the others are presumably columns of the same metadata table — TODO confirm.
BIOMATERIAL_TYPE = "harmonized_biomaterial_type"
CELL_TYPE = "harmonized_sample_ontology_intermediate"
ASSAY = "assay_epiclass"
SEX = "harmonized_donor_sex"
CANCER = "harmonized_sample_cancer_high"
DISEASE = "harmonized_sample_disease_high"
LIFE_STAGE = "harmonized_donor_life_stage"
TRACK = "track_type"

In [None]:
def format_label(label: str):
    """Normalize a label: lowercase it, then turn each space and hyphen into an underscore."""
    # One translation table handles both separators in a single pass.
    separators_to_underscore = str.maketrans({" ": "_", "-": "_"})
    return label.lower().translate(separators_to_underscore)

## Cell type top SHAP features details

In [None]:
# Directory handed to MetadataHandler; stop right away when it is missing.
paper_dir = Path.home().joinpath("projects/epiclass/output/paper")
paper_dir_found = paper_dir.exists()
if not paper_dir_found:
    raise ValueError(f"{paper_dir} does not exist.")

# Load the "v2" metadata table with merge_assays disabled.
metadata_handler = MetadataHandler(paper_dir)
meta_df = metadata_handler.load_metadata_df("v2", merge_assays=False)

In [None]:
# Normalize the cell-type column in place with format_label so its values can be
# compared with the formatted labels built from EPIATLAS_16_CT.
meta_df[CELL_TYPE] = meta_df[CELL_TYPE].apply(format_label)

In [None]:
# Formatted versions of the EPIATLAS_16_CT labels, in their original order.
cts = list(map(format_label, EPIATLAS_16_CT))

In [None]:
# Sanity check: exactly 16 of the formatted labels in `cts` occur in the metadata.
assert len(set(cts).intersection(meta_df[CELL_TYPE].unique())) == 16

In [None]:
# Keep only the rows whose cell type is one of the labels in `cts`.
is_selected_cell_type = meta_df[CELL_TYPE].isin(cts)
cell_type_df = meta_df.loc[is_selected_cell_type]
print(meta_df.shape, cell_type_df.shape)

### Find all possible file subsets and the folder names that correspond to them

In [None]:
# Finest-grained subsets: one key per (assay, track, cell type) found in the metadata.
subset_sizes = cell_type_df.groupby([ASSAY, TRACK, CELL_TYPE]).size()
meta_keys = subset_sizes.index

# Subsets pooling every track of an assay use the placeholder track "ALL".
all_assay_celltype = cell_type_df.groupby([ASSAY, CELL_TYPE]).size().index
all_keys_with_all = [
    (assay_name, "ALL", ct_name) for assay_name, ct_name in all_assay_celltype
]

# Deduplicated union of both key families.
all_keys = list(set(meta_keys).union(all_keys_with_all))
print(len(meta_keys))

In [None]:
# folder name -> list of the (assay, track, cell_type) keys associated with it
folder_map = defaultdict(list)
for subset_key in all_keys:
    assay, track, cell_type = subset_key
    # Missing or "ALL" tracks map to a folder named after the assay alone;
    # any other track is appended as a suffix.
    has_track_suffix = not (pd.isna(track) or track.upper() == "ALL")
    folder_name = f"{assay}_{track}" if has_track_suffix else assay
    folder_map[folder_name].append((assay, track, cell_type))

# The "mixed_samples" folder is keyed with "ALL" for both assay and track.
folder_map["mixed_samples"] = [("ALL", "ALL", ct_label) for ct_label in cts]

### Feature count per cell type class

In [None]:
base_path = Path.home() / "scratch/epiclass/join_important_features"
feature_count_general_dir = base_path.joinpath(
    "hg38_100kb_all_none",
    "harmonized_sample_ontology_intermediate_1l_3000n",
    "10fold-oversampling",
    "global_shap_analysis",
)

top303_dir = feature_count_general_dir / "top303"

# Folder-name fragments: a folder whose name contains any of them is left out
# of the "limited" union below.
ignore = {"fc", "raw", "pval", "mixed_samples"}


all_features_per_class = defaultdict(set)  # union over every folder with features
features_per_class = defaultdict(set)  # union over folders not matched by `ignore`
for folder in sorted(top303_dir.iterdir()):
    if not folder.is_dir():
        continue

    feature_count_path = folder / "feature_count.json"
    if not feature_count_path.exists():
        print(f"File {feature_count_path} does not exist.")
        continue

    # n=8 is forwarded to the helper as-is; its meaning is defined there.
    features: Dict[str, List[int]] = collect_features_from_feature_count_file(
        feature_count_path, n=8
    )

    if not features:
        print(f"No features passing threshold found in {folder.name}")
        continue

    is_ignored = any(fragment in folder.name for fragment in ignore)
    if is_ignored:
        print(f"Skipping folder {folder.name}")
    else:
        print(folder.name)

    for class_label, class_features in features.items():
        all_features_per_class[class_label].update(class_features)
        if not is_ignored:
            features_per_class[class_label].update(class_features)

In [None]:
# Empty table; the following cells fill it one (class_label, column) entry at a time.
df_top_features_count = pd.DataFrame(dtype=int)

In [None]:
# count_all: number of distinct features per class, over every folder that yielded features.
for class_label, class_features in all_features_per_class.items():
    df_top_features_count.loc[class_label, "count_all"] = len(class_features)

In [None]:
# count_limited: number of distinct features per class, over folders not matched by `ignore`.
for class_label, class_features in features_per_class.items():
    df_top_features_count.loc[class_label, "count_limited"] = len(class_features)

In [None]:
# A class absent from one of the unions has no count in that column: treat it
# as 0. NaN placeholders force float columns, so cast back to int so that counts
# are displayed and exported as integers (12, not 12.0). Rebinding instead of
# `inplace=True` keeps the cell free of hidden in-place mutation.
df_top_features_count = df_top_features_count.fillna(0).astype(int)

# Number of features per class that only appear in the folders matched by `ignore`.
df_top_features_count["diff"] = (
    df_top_features_count["count_all"] - df_top_features_count["count_limited"]
)
display(df_top_features_count)

In [None]:
# Write the per-class count table next to the SHAP analysis outputs.
unique_features_csv = feature_count_general_dir / "unique_features_count_union.csv"
df_top_features_count.to_csv(unique_features_csv, index_label="class_label")

### Features unique to ChIP raw tracks

In [None]:
# Per-class feature unions, split by whether the folder is a ChIP raw-track folder.
not_chip_raw_features_per_class = defaultdict(set)
chip_raw_features_per_class = defaultdict(set)
for folder in sorted(top303_dir.iterdir()):
    if not folder.is_dir():
        continue

    feature_count_path = folder / "feature_count.json"
    if not feature_count_path.exists():
        print(f"File {feature_count_path} does not exist.")
        continue

    features: Dict[str, List[int]] = collect_features_from_feature_count_file(
        feature_count_path, n=8
    )

    if not features:
        print(f"No features passing threshold found in {folder.name}")
        continue

    # A folder counts as "ChIP raw" when its name contains "h3" and ends with "_raw".
    is_chip_raw = folder.name.endswith("_raw") and "h3" in folder.name
    if is_chip_raw:
        destination = chip_raw_features_per_class
    else:
        destination = not_chip_raw_features_per_class

    for class_label, class_features in features.items():
        destination[class_label].update(class_features)

    print(folder.name)

In [None]:
# Per class: how many top features appear only in ChIP raw-track folders.
unique_raw_count = defaultdict(int)

# Visit classes from BOTH mappings (original order first), so that a class seen
# only in ChIP raw folders is reported instead of being silently dropped.
class_labels = list(not_chip_raw_features_per_class)
class_labels += [
    label
    for label in chip_raw_features_per_class
    if label not in not_chip_raw_features_per_class
]

for class_label in class_labels:
    # `.get` avoids inserting empty entries into the defaultdicts as a lookup side effect.
    raw_features = set(chip_raw_features_per_class.get(class_label, ()))
    not_raw_features = set(not_chip_raw_features_per_class.get(class_label, ()))

    raw_specific_features = raw_features - not_raw_features
    N_unique = len(raw_specific_features)
    N_total = len(raw_features | not_raw_features)
    # A class with no features at all would otherwise raise ZeroDivisionError.
    unique_ratio = N_unique / N_total if N_total else 0.0
    print(
        f"{class_label}: {N_unique}/{N_total} = {unique_ratio:.2%} features unique to raw tracks"
    )

    unique_raw_count[class_label] = N_unique

### Associate feature count and subset size for all combinations

In [None]:
# (assay, track, cell_type) -> number of features returned for that cell type
# in the folder associated with the key.
feature_count_dict = {}

for folder in sorted(top303_dir.iterdir()):
    if not folder.is_dir():
        continue

    folder_name = folder.name
    if folder_name not in folder_map:
        print(f"Unknown folder: {folder_name}")
        continue

    feature_count_path = folder / "feature_count.json"
    if not feature_count_path.exists():
        continue

    print(folder_name)
    features = collect_features_from_feature_count_file(feature_count_path, n=8)
    # Re-key by formatted label so lookups match the formatted cell types in folder_map.
    features = {format_label(label): values for label, values in features.items()}
    if not features:
        continue

    for subset_key in folder_map[folder_name]:
        assay, track, cell_type = subset_key
        # Cell types missing from this folder's features count as 0.
        feature_count_dict[(assay, track, cell_type)] = len(
            features.get(cell_type, [])
        )

In [None]:
# Ensure the index is a proper MultiIndex for feature_count_df
feature_count_df = pd.DataFrame.from_dict(
    feature_count_dict, orient="index", columns=["feature_count"]
)

# MultiIndex needs to have same names as other df for merge
feature_count_df.index = pd.MultiIndex.from_tuples(
    feature_count_df.index, names=[ASSAY, TRACK, CELL_TYPE]
)

In [None]:
subset_level_names = [ASSAY, TRACK, CELL_TYPE]

# File counts for the finest-grained subsets.
ct_count_df = cell_type_df.groupby([ASSAY, TRACK, CELL_TYPE]).size()

# File counts per (assay, cell type), re-keyed with the placeholder track "ALL".
all_count_df = cell_type_df.groupby([ASSAY, CELL_TYPE]).size()
all_track_keys = [
    (assay_name, "ALL", ct_name) for assay_name, ct_name in all_count_df.index
]
all_count_df.index = pd.MultiIndex.from_tuples(
    all_track_keys, names=subset_level_names
)

# File counts per cell type, re-keyed like the mixed_samples entries
# (assay="ALL", track="ALL").
mixed_count_df = cell_type_df.groupby([CELL_TYPE]).size()
mixed_keys = [("ALL", "ALL", ct_name) for ct_name in mixed_count_df.index]
mixed_count_df.index = pd.MultiIndex.from_tuples(
    mixed_keys, names=subset_level_names
)

# Stack the three tables into a single "file_count" frame.
file_count_df = pd.concat(
    [
        counts.rename("file_count").to_frame()
        for counts in (ct_count_df, all_count_df, mixed_count_df)
    ]
)

In [None]:
# Merge
# Outer join keeps subsets present in only one of the two tables; their missing
# value is replaced by 0 before casting everything to int.
merged_df = feature_count_df.join(file_count_df, how="outer").fillna(0).astype(int)

In [None]:
# Save the merged feature-count / file-count table alongside the SHAP analysis outputs.
merged_df.to_csv(feature_count_general_dir / "feature_count_per_subset.csv")