In [None]:
# pylint: disable=import-error, redefined-outer-name
"""Workbooks to analyze metadata."""
from __future__ import annotations

import itertools
import os
import random
from collections import Counter, defaultdict
from pathlib import Path
from typing import DefaultDict, List

import pandas as pd
from IPython.display import display

from epi_ml.core.metadata import Metadata
from epi_ml.utils.modify_metadata import TRACKS_MAPPING, epiatlas_assays, filter_by_pairs

# Metadata category used as the cell-type label throughout this notebook.
CELL_TYPE = "harmonized_sample_ontology_intermediate"
# Metadata category used as the assay label throughout this notebook.
ASSAY = "assay_epiclass"

# Seed the global RNG so the md5 lists drawn with `random.sample` later in the
# notebook (and written to .list files) are reproducible on Restart & Run All.
RANDOM_SEED = 42
random.seed(RANDOM_SEED)

In [None]:
# Metadata directory. Set the EPIATLAS_METADATA_DIR environment variable to point
# elsewhere instead of editing this user-specific default path.
base = Path(
    os.environ.get(
        "EPIATLAS_METADATA_DIR",
        "/home/local/USHERBROOKE/rabj2301/Projects/epilap/input/metadata",
    )
)
# path = base / "hg38_2023_epiatlas_dfreeze_plus_encode_noncore_formatted_JR.json"
path = base / "hg38_2023_epiatlas_dfreeze_formatted_JR.json"
my_metadata = Metadata(path)

In [None]:
# my_metadata.get_categories()

In [None]:
def display_gen_info(metadata: Metadata):
    """
    Show the label counts of the three categories this notebook filters on.

    Args:
        metadata (Metadata): Metadata object whose ``display_labels`` is called once
            per category, in order: track type, assay, then cell type.
    """
    for category in ("track_type", ASSAY, CELL_TYPE):
        metadata.display_labels(category)

In [None]:
# Inspect the available metadata categories and the track type / assay / cell type
# label counts before any of the filtering cells below are run.
print(my_metadata.get_categories())
display_gen_info(my_metadata)

In [None]:
# Keep only track types that are keys of TRACKS_MAPPING and assays in epiatlas_assays.
# NOTE(review): the return values are discarded, so these calls presumably modify
# my_metadata in place — re-running this cell re-applies the filter to the
# already-filtered object.
my_metadata.select_category_subsets("track_type", TRACKS_MAPPING.keys())
my_metadata.select_category_subsets(ASSAY, epiatlas_assays)

In [None]:
# Drop datasets whose cell-type label is the empty string, then drop small cell-type
# classes (presumably those with fewer than 10 datasets — confirm the threshold
# semantics of remove_small_classes).
my_metadata.remove_category_subsets(CELL_TYPE, [""])
my_metadata.remove_small_classes(10, CELL_TYPE)

In [None]:
# Label counts after the track type, assay and cell type filters above.
display_gen_info(my_metadata)

In [None]:
def count_trios(metadata: Metadata) -> Counter:
    """
    Tally how many datasets share each (track_type, assay, cell_type) combination.

    Args:
        metadata (Metadata): Metadata whose ``datasets`` are indexed by category name.

    Returns:
        Counter: Maps each (track_type, assay, cell_type) tuple to its dataset count.
    """
    trio_keys = ("track_type", ASSAY, CELL_TYPE)
    return Counter(
        tuple(dataset[key] for key in trio_keys) for dataset in metadata.datasets
    )

In [None]:
# Filter on (assay, cell type) pairs with filter_by_pairs. Presumably this keeps only
# pairs with at least `min_per_pair` datasets; check epi_ml.utils.modify_metadata for
# what `nb_pairs` controls. Unlike the cells above, the result is a returned object
# that replaces my_metadata.
my_metadata = filter_by_pairs(
    my_metadata=my_metadata, assay_cat=ASSAY, cat2=CELL_TYPE, nb_pairs=1, min_per_pair=10
)

In [None]:
# Label counts after the (assay, cell type) pair filter.
display_gen_info(my_metadata)

In [None]:
# Number of unique (track_type, assay, cell_type) trios, and the background size if
# every trio supplies 5 datasets (5 is the default `n` of select_datasets).
trios = count_trios(my_metadata)
print(len(trios), len(trios) * 5)

In [None]:
# display(trios.most_common())

In [None]:
# NOTE(review): my_metadata is unchanged since the last display_gen_info call
# (count_trios only reads it), so this repeats that output; only the donor-sex
# breakdown below is new.
display_gen_info(my_metadata)
my_metadata.display_labels("harmonized_donor_sex")

In [None]:
def select_datasets(metadata: Metadata, n=5, seed: int | None = None) -> List[str]:
    """
    Select a random subset of n datasets for each unique (track_type, assay, cell_type) trio.

    A trio with fewer than n datasets contributes all of its datasets. Previously
    ``random.sample`` raised ``ValueError`` when asked for more items than the trio has.

    Args:
        metadata (Metadata): Metadata whose ``items`` yields (md5sum, dataset) pairs.
        n (int, optional): Number of datasets to sample per trio. Defaults to 5.
        seed (int | None, optional): Seed for a dedicated RNG, making the selection
            reproducible. When None (default), the global ``random`` module state is
            used, as before.

    Returns:
        list: Sampled md5sums, grouped by trio in the order trios are first seen.
    """
    # The `random` module exposes the same `sample` API as a `random.Random` instance.
    rng = random if seed is None else random.Random(seed)

    trio_files: DefaultDict[tuple, List[str]] = defaultdict(list)
    for md5sum, dset in metadata.items:
        trio = (dset["track_type"], dset[ASSAY], dset[CELL_TYPE])
        trio_files[trio].append(md5sum)
    print(f"{len(trio_files)} unique (track_type, assay, cell_type) trios")

    sampled_md5s: List[str] = []
    for trio, md5_list in trio_files.items():
        if len(md5_list) < n:
            print(f"Trio {trio} has only {len(md5_list)} datasets (< {n}); keeping all.")
        sampled_md5s.extend(rng.sample(md5_list, min(n, len(md5_list))))
    return sampled_md5s

In [None]:
# Draw the SHAP background set from the filtered metadata, sampling per
# (track_type, assay, cell_type) trio with select_datasets' default n (5).
md5s = select_datasets(my_metadata)
print(len(md5s))

In [None]:
# Save the background md5sums, one per line with no header or index, to a file in
# the current working directory.
pd.DataFrame(md5s).to_csv("md5_shap_assay_background.list", index=False, header=False)

In [None]:
# Build a second metadata view from which the files to explain are drawn. It is
# reloaded from `path` because my_metadata has already been filtered above (a
# deepcopy of it would carry those filters over).
# import copy
# meta2 = copy.deepcopy(my_metadata)
meta2 = Metadata(path)

# Filtering differs from my_metadata: only the "Unique_raw" track type is removed
# (instead of selecting TRACKS_MAPPING keys), and filter_by_pairs is not applied.
meta2.remove_category_subsets("track_type", ["Unique_raw"])
meta2.select_category_subsets(ASSAY, epiatlas_assays)

# Same cell-type cleanup as for my_metadata: drop empty labels and small classes.
meta2.remove_category_subsets(CELL_TYPE, [""])
meta2.remove_small_classes(10, CELL_TYPE)

# meta2.select_category_subsets(CELL_TYPE, ["T cell"])
# meta2.display_labels(ASSAY)

# meta2.select_category_subsets(ASSAY, ["wgbs-pbat"])
# meta2.display_labels(CELL_TYPE)

display_gen_info(meta2)

In [None]:
def count_cell_types(metadata: Metadata) -> DefaultDict[str, Counter]:
    """
    Tally, for every assay, how many datasets belong to each cell type.

    Args:
        metadata (Metadata): Metadata whose ``datasets`` are indexed by category name.

    Returns:
        defaultdict(Counter): Maps each assay to a Counter of cell type -> dataset count.
    """
    per_assay = defaultdict(Counter)
    assay_ct_pairs = ((dset[ASSAY], dset[CELL_TYPE]) for dset in metadata.datasets)
    for assay, cell_type in assay_ct_pairs:
        per_assay[assay][cell_type] += 1
    return per_assay


def select_cell_types(
    metadata: Metadata, n=70, priority_cell_type: str = "T cell"
) -> DefaultDict[str, List]:
    """
    Determine which cell types are needed to attain n datasets, for each assay.

    The priority cell type is always selected first, even when an assay has no dataset
    of that type. The remaining cell types are then added from most to least common
    until their cumulative dataset count reaches n.

    Args:
        metadata (Metadata): A Metadata object containing dataset metadata.
        n (int, optional): Target number of datasets per assay. Counts are capped at
            n, so the last selected cell type may only partially count. Defaults to 70.
        priority_cell_type (str, optional): Cell type selected first for every assay.
            Defaults to "T cell".

    Returns:
        defaultdict(list): Selected cell types for each assay, in selection order.
    """
    cell_count = count_cell_types(metadata)

    selected_ct = defaultdict(list)
    for assay, counter in cell_count.items():
        selected_ct[assay].append(priority_cell_type)
        # Counter returns 0 for a missing key, so an absent priority type adds nothing.
        total = min(counter[priority_cell_type], n)
        # One most_common() pass (count descending, ties in first-seen order) gives the
        # same order as repeatedly taking the most common remaining cell type.
        for cell_type, count in counter.most_common():
            if total >= n:
                break
            if cell_type == priority_cell_type:
                continue
            total += min(count, n - total)
            selected_ct[assay].append(cell_type)
        if total < n:
            print(f"There is not at least {n} files for {assay}. Final number={total}")

    return selected_ct

In [None]:
# For each assay, choose the cell types needed to reach 70 files ("T cell" first).
# NOTE(review): cell types are chosen from my_metadata (pair-filtered), but the
# files are later drawn from meta2, which is filtered differently — confirm this
# mismatch is intended.
# count_cell_types(my_metadata)
selected_ct = select_cell_types(my_metadata, n=70)
display(selected_ct)

In [None]:
def select_explain_files(
    metadata: Metadata,
    selected_cell_types: DefaultDict[str, List],
    n: int,
    seed: int | None = None,
) -> List[str]:
    """
    Sample 'n' random datasets for each assay from the given cell types.

    An assay with fewer than n matching datasets contributes all of them. Previously
    ``random.sample`` raised ``ValueError`` in that case. ``selected_cell_types`` is
    only read; assays missing from it are skipped without being inserted into it.

    Args:
        metadata (Metadata): Metadata whose ``items`` yields (md5sum, dataset) pairs.
        selected_cell_types (defaultdict(list)): A dictionary with selected cell types for each assay.
        n (int): Number of files to select for each assay.
        seed (int | None, optional): Seed for a dedicated RNG, making the selection
            reproducible. When None (default), the global ``random`` module state is used.

    Returns:
        list: Sampled md5sums, grouped by assay in the order assays are first seen.
    """
    # The `random` module exposes the same `sample` API as a `random.Random` instance.
    rng = random if seed is None else random.Random(seed)
    # Read-only lookup: indexing the caller's defaultdict directly would insert an
    # empty list for every unselected assay. Sets give O(1) membership tests.
    wanted = {assay: set(cts) for assay, cts in selected_cell_types.items()}

    md5sums: DefaultDict[str, List[str]] = defaultdict(list)
    for md5sum, dset in metadata.items:
        assay, ct = dset[ASSAY], dset[CELL_TYPE]
        if ct in wanted.get(assay, ()):
            md5sums[assay].append(md5sum)

    sampled_md5s: List[str] = []
    for assay, md5_list in md5sums.items():
        if len(md5_list) < n:
            print(f"Only {len(md5_list)} files available for {assay} (< {n}); keeping all.")
        sampled_md5s.extend(rng.sample(md5_list, min(n, len(md5_list))))

    return sampled_md5s

In [None]:
# Draw 70 files per assay from meta2, restricted to each assay's selected cell types.
files_to_explain = select_explain_files(meta2, selected_ct, 70)
len(files_to_explain)

In [None]:
# Save the md5sums of the files to explain, one per line with no header or index,
# to a file in the current working directory.
pd.DataFrame(files_to_explain).to_csv(
    "md5_shap_assay_explain.list", index=False, header=False
)