In [None]:
"""See markdown"""
# pylint: disable=line-too-long, redefined-outer-name, import-error, duplicate-code

# Prepare the background and evaluation dataset lists (HDF5 path files) for SHAP analysis.

In [7]:
from __future__ import annotations

import copy
import itertools
import random
from collections import defaultdict
from pathlib import Path
from typing import List

from epi_ml.core.metadata import Metadata
from epi_ml.utils.general_utility import write_hdf5_paths_to_file

# Metadata category (label) names used to query `Metadata` objects in this notebook.
BIOMATERIAL_TYPE: str = "harmonized_biomaterial_type"
# Cell type category; part of the (cell type, assay, track type) trio grouping below.
CELL_TYPE: str = "harmonized_sample_ontology_intermediate"
# Assay category; also used to filter out assays from the evaluation set.
ASSAY: str = "assay_epiclass"
SEX: str = "harmonized_donor_sex"
CANCER: str = "harmonized_sample_cancer_high"
DISEASE: str = "harmonized_sample_disease_high"
# Same category name as the one appearing in the model directory path below.
LIFE_STAGE: str = "harmonized_donor_life_stage"

In [5]:
def display_gen_info(metadata: Metadata, extra_categories: List[str] | None = None):
    """Show label counts for the core categories of `metadata`.

    The categories are shown in this order: "track_type", the assay category
    (ASSAY), the cell type category (CELL_TYPE), then any extra categories.

    Args:
        metadata: Object whose `display_labels(category)` is called once per category.
        extra_categories: Optional additional category names, shown after the
            default ones. `None` or an empty list adds nothing.
    """
    requested = ["track_type", ASSAY, CELL_TYPE, *(extra_categories or [])]
    for category_name in requested:
        metadata.display_labels(category_name)

In [None]:
def select_datasets(
    metadata: Metadata, n=5, rng: random.Random | None = None
) -> List[str]:
    """
    Select a random subset of up to n datasets for each unique (track_type, assay, cell_type) trio.

    Trios with fewer than n datasets contribute all of their datasets. Without
    this cap, `random.sample` would raise ValueError ("Sample larger than
    population") as soon as any trio is small.

    Args:
        metadata: Object whose `items` attribute is iterated as (md5sum, dataset) pairs.
            NOTE(review): `items` is used without a call; confirm it is a property on Metadata.
        n: Maximum number of datasets sampled per trio.
        rng: Optional object with a `sample(population, k)` method (e.g. a seeded
            `random.Random`) for reproducible selection. Defaults to the `random` module.

    Returns:
        list: A list of sampled md5sums of the selected datasets, grouped by trio.
    """
    trio_files = defaultdict(list)
    for md5sum, dset in metadata.items:
        trio = (dset["track_type"], dset[ASSAY], dset[CELL_TYPE])
        trio_files[trio].append(md5sum)
    print(len(trio_files))

    sampler = random if rng is None else rng
    sampled_md5s = list(
        itertools.chain.from_iterable(
            # Cap k at the trio size so small trios are taken whole instead of raising.
            sampler.sample(md5_list, min(n, len(md5_list)))
            for md5_list in trio_files.values()
        )
    )
    return sampled_md5s

In [None]:
# Machine-specific directory holding the metadata freezes.
base = Path.home() / "Projects/epilap/input/metadata"
# dfreeze v2.1 metadata file (name indicates ENCODE non-core datasets are included).
path = base.joinpath("dfreeze-v2", "hg38_2023-epiatlas-dfreeze_v2.1_w_encode_noncore_2.json")
base_metadata = Metadata(path)

In [None]:
# Output directory of the trained model whose SHAP inputs are prepared here.
# Path components name the run (hg38_100kb_all_none), the target category
# (harmonized_donor_life_stage) and the fold (split0).
model_path = (
    Path.home()
    / "mounts/narval-mount/project-rabyj/epilap/output/logs/epiatlas-dfreeze-v2.1/hg38_100kb_all_none"
    / "harmonized_donor_life_stage_1l_3000n/no-unknown/10fold-oversampling/split0"
)

In [None]:
def _find_unique(directory: Path, pattern: str) -> Path:
    """Return the single file in `directory` matching glob `pattern`.

    Raises:
        FileNotFoundError: if nothing matches.
        ValueError: if several files match; glob order is arbitrary, so picking
            the first match would silently select an unpredictable file.
    """
    matches = sorted(directory.glob(pattern))
    if not matches:
        raise FileNotFoundError(f"No file matching {pattern!r} in {directory}")
    if len(matches) > 1:
        raise ValueError(
            f"Several files match {pattern!r} in {directory}: {[m.name for m in matches]}"
        )
    return matches[0]


def _read_lines(file_path: Path) -> list[str]:
    """Return the non-blank lines of a UTF-8 text file, without line endings."""
    with open(file_path, "r", encoding="utf8") as f:
        return [line for line in f.read().splitlines() if line.strip()]


training_md5_path = _find_unique(model_path, "split0_training_*.md5")
valid_md5_path = _find_unique(model_path, "split0_validation_*.md5")
training_mapping_path = model_path / "training_mapping.tsv"

training_md5 = set(_read_lines(training_md5_path))
valid_md5 = set(_read_lines(valid_md5_path))
# Each mapping line must have exactly two tab-separated fields (key, value);
# blank lines are skipped instead of crashing dict().
training_mapping = dict(
    line.split("\t") for line in _read_lines(training_mapping_path)
)

In [None]:
def _subset_metadata(metadata: Metadata, keep_md5s) -> Metadata:
    """Return a deep copy of `metadata` restricted to the md5sums in `keep_md5s`.

    The input object is not modified. A warning is printed when some requested
    md5sums are absent from `metadata`. For example, this happens if the split
    files and the metadata file come from different versions.
    """
    subset = copy.deepcopy(metadata)
    keep = set(keep_md5s)
    # Materialize the key list first: deleting while iterating md5s is unsafe.
    for md5 in list(subset.md5s):
        if md5 not in keep:
            del subset[md5]
    missing = keep - set(subset.md5s)
    if missing:
        print(f"Warning: {len(missing)} requested md5sums not found in metadata.")
    return subset


training_metadata = _subset_metadata(base_metadata, training_md5)
valid_metadata = _subset_metadata(base_metadata, valid_md5)

### Background list

In [None]:
# Max number of training datasets kept per (cell type, assay, track type) trio in the SHAP background.
n_per_trio: int = 3

In [None]:
display_gen_info(training_metadata)

# Group training md5sums by (cell type, assay, track type), preserving metadata order.
trios_md5_dict = defaultdict(list)
for dset in training_metadata.datasets:
    trio_key = (dset[CELL_TYPE], dset[ASSAY], dset["track_type"])
    trios_md5_dict[trio_key].append(dset["md5sum"])

print(f"{len(trios_md5_dict)} entries/trios")

# Deterministic selection: the first `n_per_trio` datasets of each trio (not random).
background_md5s = {
    md5
    for trio_md5s in trios_md5_dict.values()
    for md5 in trio_md5s[:n_per_trio]
}

# training_metadata is filtered in place down to the background datasets.
md5s_to_drop = [md5 for md5 in training_metadata.md5s if md5 not in background_md5s]
for md5 in md5s_to_drop:
    del training_metadata[md5]

display_gen_info(training_metadata)

In [None]:
# f-string so the file name embeds the actual count (e.g. "3pertrio");
# the original plain string wrote a literal "{n_per_trio}" into the file name.
name = f"{n_per_trio}pertrio"
shap_dir = model_path / "shap"
# Only create the final "shap" level: if model_path itself is missing (e.g. the
# remote mount is down) this still fails instead of creating local directories.
shap_dir.mkdir(exist_ok=True)
write_hdf5_paths_to_file(
    md5s=training_metadata.md5s,
    parent=".",
    suffix="100kb_all_none",
    filepath=shap_dir / f"shap_background_{name}.list",
)

### Evaluation list

In [None]:
# Label counts of the validation split, before any assay filtering.
display_gen_info(valid_metadata)

In [None]:
# Assay labels removed from the evaluation set (input and both WGBS variants).
excluded_assays = ["input", "wgbs-standard", "wgbs-pbat"]
valid_metadata.remove_category_subsets(ASSAY, excluded_assays)
# Optional: also restrict the evaluation set to selected cell types.
# valid_metadata.select_category_subsets(CELL_TYPE, ["T cell", "lymphocyte of B lineage", "muscle organ", "monocyte", "neutrophil", "myeloid cell"])

In [None]:
# Label embedded in the evaluation list file name.
name = "6hist_w_rna"
shap_dir = model_path / "shap"
# Create only the "shap" subdirectory; a missing model_path still raises.
shap_dir.mkdir(exist_ok=True)
write_hdf5_paths_to_file(
    md5s=valid_metadata.md5s,
    parent=".",
    suffix="100kb_all_none",
    filepath=shap_dir / f"shap_eval_{name}.list",
)