In [None]:
"""Workbooks to analyze metadata."""
# pylint: disable=import-error, redefined-outer-name

In [None]:
from __future__ import annotations

import copy
import os
from collections import Counter, defaultdict
from pathlib import Path
from typing import DefaultDict, List

from epi_ml.core.metadata import Metadata, UUIDMetadata, env_filtering
from epi_ml.utils.modify_metadata import filter_by_pairs

# import numpy as np
# import pandas as pd
# from IPython.display import display

# from epi_ml.core.epiatlas_treatment import ACCEPTED_TRACKS
# from epi_ml.core.metadata import Metadata, UUIDMetadata, env_filtering
# from epi_ml.utils.metadata_utils import EPIATLAS_ASSAYS, count_pairs, make_table
# from epi_ml.utils.modify_metadata import filter_by_pairs


# Names of metadata categories referenced in this notebook
# (e.g. `dset[ASSAY]`, `metadata.display_labels(LIFE_STAGE)`).
BIOMATERIAL_TYPE = "harmonized_biomaterial_type"
CELL_TYPE = "harmonized_sample_ontology_intermediate"
ASSAY = "assay_epiclass"
SEX = "harmonized_donor_sex"
CANCER = "harmonized_sample_cancer_high"
DISEASE = "harmonized_sample_disease_high"
LIFE_STAGE = "harmonized_donor_life_stage"

In [None]:
# Directory holding the metadata JSON files, relative to the user's home.
base = Path.home().joinpath("Projects/epilap/input/metadata")

# Alternative metadata files (uncomment exactly one `path` assignment):
# path = base / "hg38_2023_epiatlas_dfreeze_plus_encode_noncore_formatted_JR.json"
# path = base / "hg38_2023_epiatlas_dfreeze_formatted_JR.json"
# path = base / "dfreeze-v2" / "hg38_2023-epiatlas-dfreeze-pospurge-nodup_filterCtl.json"
path = base.joinpath(
    "dfreeze-v2", "hg38_2023-epiatlas-dfreeze_v2.1_w_encode_noncore_2.json"
)

# Metadata object analysed by the rest of the notebook.
my_metadata = Metadata(path)

In [None]:
def display_gen_info(metadata: Metadata):
    """Show label counts for track type, assay, cell type and life stage.

    Calls `metadata.display_labels` once per category, in the order listed.
    """
    categories = (
        "track_type",
        ASSAY,
        CELL_TYPE,
        # SEX,
        # CANCER,
        # DISEASE,
        LIFE_STAGE,
    )
    for category in categories:
        metadata.display_labels(category)

In [None]:
def count_trios(metadata: Metadata) -> Counter:
    """
    Tally how many datasets share each (track_type, assay, cell_type) combination.

    Args:
        metadata (Metadata): Object whose `datasets` are iterated; each dataset is
            indexed with "track_type", ASSAY and CELL_TYPE.

    Returns:
        Counter: Maps each (track_type, assay, cell_type) tuple to its number of datasets.
    """
    trio_counts: Counter = Counter()
    for dset in metadata.datasets:
        trio = (dset["track_type"], dset[ASSAY], dset[CELL_TYPE])
        trio_counts[trio] += 1
    return trio_counts

In [None]:
def count_pairs_w_assay(metadata: Metadata, category: str) -> DefaultDict[str, Counter]:
    """
    Count, for each assay, how many datasets carry each label of `category`.

    Args:
        metadata (Metadata): Object whose `datasets` are iterated.
        category (str): Metadata category whose labels are counted within each assay.

    Returns:
        defaultdict(Counter): Maps each assay to a Counter of `category` labels.
    """
    counts_per_assay: DefaultDict[str, Counter] = defaultdict(Counter)
    for dset in metadata.datasets:
        assay_label = dset[ASSAY]
        category_label = dset[category]
        counts_per_assay[assay_label][category_label] += 1
    return counts_per_assay


def select_cell_types(
    metadata: Metadata, n=70, first_cell_type: str = "T cell"
) -> DefaultDict[str, List]:
    """
    Determine, for each assay, which cell types are needed to attain n datasets.

    `first_cell_type` is always listed first (even when the assay has no dataset
    with that label). The remaining cell types are then added from most to least
    common until the running total of datasets reaches n.

    Args:
        metadata (Metadata): A Metadata object containing dataset metadata.
        n (int, optional): Target number of datasets per assay. Defaults to 70.
        first_cell_type (str, optional): Cell type selected before any other.
            Defaults to "T cell".

    Returns:
        defaultdict(list): Selected cell types for each assay, in selection order.

    Side effects:
        Prints a message for every assay whose datasets do not add up to n.
    """
    cell_count = count_pairs_w_assay(metadata, CELL_TYPE)

    selected_ct = defaultdict(list)
    for assay, counter in cell_count.items():
        selected = selected_ct[assay]
        selected.append(first_cell_type)
        # A Counter returns 0 for a missing label, so an absent first_cell_type adds nothing.
        nb_files = min(counter[first_cell_type], n)

        # most_common() is already sorted by decreasing count: a single pass suffices,
        # no need to re-sort and delete from the counter after each selection.
        for cell_type, count in counter.most_common():
            if nb_files >= n:
                break
            if cell_type == first_cell_type:
                continue  # already accounted for above
            nb_files += min(count, n - nb_files)
            selected.append(cell_type)

        if nb_files < n:
            print(f"There is not at least {n} files for {assay}. Final number={nb_files}")

    return selected_ct

In [None]:
# Show label counts for the full metadata loaded above.
display_gen_info(my_metadata)

In [None]:
# for cat in my_metadata.get_categories():
#     print(cat)

# my_metadata.display_labels("groups_second_level_name")
# other_meta = UUIDMetadata.from_metadata(my_metadata)
# other_meta.display_uuid_per_class("groups_second_level_name")
# my_metadata.display_labels("data_generating_centre")
# my_metadata.select_category_subsets("project", ["EpiHK"])

In [None]:
# Category paired with the assay in the filtering below.
# cat2 = "groups_second_level_name"
cat2 = CELL_TYPE  # reuse the shared constant instead of repeating its literal

# Filtering options set through environment variables.
# NOTE(review): presumably read by env_filtering -- confirm in epi_ml.core.metadata.
# os.environ["REMOVE_TRACKS"] = '["fc", "pval"]'
os.environ["REMOVE_TRACKS"] = '["fc", "raw"]'
# os.environ["REMOVE_TRACKS"] = '["fc"]'
os.environ["EXCLUDE_LIST"] = '["other", "--", "NA", ""]'
os.environ["MIN_CLASS_SIZE"] = "10"

# Work on a deep copy so that `my_metadata` stays as loaded
# (assumes the filtering below modifies its argument in place -- TODO confirm).
meta = UUIDMetadata.from_metadata(copy.deepcopy(my_metadata))

# meta.display_uuid_per_class("assay_epiclass")
# meta.display_labels("track_type")

env_filtering(meta, cat2)
# meta.remove_category_subsets("track_type", ["fc", "pval", "Unique_raw"])
# meta.display_labels("track_type")
# meta.display_labels("assay_epiclass")

# Filter on (assay, cat2) pairs; see filter_by_pairs for the meaning of the thresholds.
new_meta = filter_by_pairs(meta, ASSAY, cat2, nb_pairs=9, min_per_pair=10, use_uuid=True)
# new_meta.display_uuid_per_class("assay_epiclass")
# new_meta.display_labels("track_type")
# UUIDMetadata.from_metadata(new_meta).display_uuid_per_class(cat2)