In [None]:
"""Workbooks to analyze metadata."""
# pylint: disable=import-error, redefined-outer-name

In [None]:
from __future__ import annotations

import copy
from collections import Counter, defaultdict
from pathlib import Path
from typing import DefaultDict, Dict, List

import pandas as pd
from IPython.display import display

from epi_ml.core.metadata import Metadata, UUIDMetadata
from epi_ml.utils.general_utility import write_hdf5_paths_to_file, write_md5s_to_file
from epi_ml.utils.modify_metadata import filter_by_pairs

# Names of the metadata categories used in this notebook. They are used both as
# keys of the individual dataset records and as arguments to the Metadata methods.
BIOMATERIAL_TYPE = "harmonized_biomaterial_type"
CELL_TYPE = "harmonized_sample_ontology_intermediate"
ASSAY = "assay_epiclass"
SEX = "harmonized_donor_sex"
CANCER = "harmonized_sample_cancer_high"
DISEASE = "harmonized_sample_disease_high"
LIFE_STAGE = "harmonized_donor_life_stage"
TRACK = "track_type"

In [None]:
# Maps fine-grained assay labels to a merged label ("rna" / "wgbs").
# In the visible cells it is only referenced by a commented-out `convert_classes` call.
ASSAY_MERGE_DICT: Dict[str, str] = {
    "rna_seq": "rna",
    "mrna_seq": "rna",
    "wgbs-pbat": "wgbs",
    "wgbs-standard": "wgbs",
}

In [None]:
# Folder holding the metadata JSON files on the local machine.
# `Path.home()` is a classmethod; call it on the class, as the rest of the notebook does.
base = Path.home() / "Projects/epiclass/input/metadata"

# Exactly one `path` assignment is active; the others are alternative metadata versions.
# path = base / "hg38_2023_epiatlas_dfreeze_plus_encode_noncore_formatted_JR.json"
# path = base / "hg38_2023_epiatlas_dfreeze_formatted_JR.json"
path = base / "dfreeze-v2" / "hg38_2023-epiatlas-dfreeze-pospurge-nodup_filterCtl.json"
# path = base / "dfreeze-v1.0" / "hg38_2023-epiatlas_dfreeze_formatted_JR.json"
# path = base / "dfreeze-v2" / "hg38_2023-epiatlas-dfreeze_v2.1_w_encode_noncore_2.json"
my_metadata = Metadata(path)

In [None]:
def display_gen_info(metadata: Metadata):
    """Display the class counts of the main metadata categories.

    Calls `metadata.display_labels` once per category, in this order:
    track type, assay, cell type, donor sex.

    Args:
        metadata (Metadata): Metadata object whose label counts are displayed.
    """
    for category in (TRACK, ASSAY, CELL_TYPE, SEX):
        metadata.display_labels(category)
    # Currently disabled categories; add them to the tuple above to display them:
    # CANCER, DISEASE, LIFE_STAGE

In [None]:
def count_trios(metadata: Metadata) -> Counter:
    """
    Count how many datasets share each (track_type, assay, cell_type) trio.

    Args:
        metadata (Metadata): Metadata object; every item of `metadata.datasets` is
            indexed with the TRACK, ASSAY and CELL_TYPE category names.

    Returns:
        Counter: Maps each (track_type, assay, cell_type) tuple to its number of datasets.
    """
    # A generator is enough: Counter consumes any iterable, no intermediate list needed.
    return Counter(
        (dset[TRACK], dset[ASSAY], dset[CELL_TYPE]) for dset in metadata.datasets
    )

In [None]:
def count_pairs_w_assay(metadata: Metadata, category: str) -> DefaultDict[str, Counter]:
    """
    Count, for each assay, how many datasets carry each label of `category`.

    Args:
        metadata (Metadata): Metadata object whose `datasets` are tallied.
        category (str): Name of the metadata category to pair with the assay.

    Returns:
        defaultdict(Counter): Maps each assay to a Counter of `category` labels.
    """
    counts_per_assay: DefaultDict[str, Counter] = defaultdict(Counter)
    for record in metadata.datasets:
        counts_per_assay[record[ASSAY]][record[category]] += 1
    return counts_per_assay


def select_cell_types(metadata: Metadata, n: int = 70) -> DefaultDict[str, List]:
    """
    Determine, for each assay, which cell types are needed to reach n datasets.

    "T cell" is always selected first (even when the assay has no T cell dataset).
    Cell types are then added from most to least common until n datasets are
    covered or no cell type is left. A message is printed for each assay that
    cannot reach n datasets.

    Args:
        metadata (Metadata): A Metadata object containing dataset metadata.
        n (int, optional): Target number of datasets to cover for each assay.
            Defaults to 70.

    Returns:
        defaultdict(list): Selected cell types for each assay, in selection order.
    """
    cell_count = count_pairs_w_assay(metadata, CELL_TYPE)

    selected_ct = defaultdict(list)
    for assay, counter in cell_count.items():
        selected_ct[assay].append("T cell")
        # Remove "T cell" so it is not picked a second time below; default 0 if absent.
        i = min(counter.pop("T cell", 0), n)

        # `most_common()` is already sorted by decreasing count, so one pass is
        # enough; there is no need to re-sort after each pick.
        for cell_type, count in counter.most_common():
            if i >= n:
                break
            i += min(count, n - i)
            selected_ct[assay].append(cell_type)

        if i < n:
            print(f"There is not at least {n} files for {assay}. Final number={i}")

    return selected_ct

In [None]:
# display_gen_info(my_metadata)
# my_metadata.get_categories()

In [None]:
# Drop datasets without a cell type label from the shared metadata.
my_metadata.remove_missing_labels(CELL_TYPE)

# Inspect rna_seq on a copy. `select_category_subsets` appears to filter in place
# (its return value is never used in this notebook; confirm). Filtering `my_metadata`
# itself would leave no wgbs / ChIP-seq datasets for the cells further down when the
# notebook is run top to bottom.
rna_seq_metadata = copy.deepcopy(my_metadata)
rna_seq_metadata.select_category_subsets(ASSAY, ["rna_seq"])
rna_seq_metadata.display_labels(CELL_TYPE)

In [None]:
# my_metadata.remove_missing_labels(CELL_TYPE)
# my_metadata.convert_classes(ASSAY, ASSAY_MERGE_DICT)
# for md5 in list(my_metadata.md5s):
#     assay = my_metadata[md5][ASSAY]
#     ct = my_metadata[md5][CELL_TYPE]
#     new_label = f"{assay}_{ct}"
#     my_metadata[md5][f"{ASSAY}_{CELL_TYPE}"] = new_label

# my_metadata.display_labels(f"{ASSAY}_{CELL_TYPE}")
# my_metadata.save(base / "dfreeze-v2" / "hg38_2023-epiatlas-dfreeze-pospurge-nodup_filterCtl_w_merged_assay-ct.json")

In [None]:
# my_metadata = filter_by_pairs(my_metadata, assay_cat=ASSAY, cat2=CELL_TYPE, nb_pairs=9, min_per_pair=10)
# my_metadata.display_labels(CELL_TYPE)
# my_metadata.display_labels(ASSAY)
# my_metadata.display_labels(f"{ASSAY}_{CELL_TYPE}")

# filepath = Path.home() / "downloads" / "100kb_all_none_16ct_pair9.list"
# write_hdf5_paths_to_file(my_metadata.md5s, parent="", suffix="100kb_all_none_value", filepath=filepath)

In [None]:
# NOTE(review): absolute, user-specific path; this cell only runs where that mount exists.
sus_md5_path = "/home/local/USHERBROOKE/rabj2301/mounts/narval-mount/scratch/other_data/C-A/hdf5/umap-input/epiatlas_all/nn100/embedding_standard_3D_nn15_sus_wgbs.md5"

# One md5sum per line.
sus_md5s = Path(sus_md5_path).read_text(encoding="utf8").splitlines()

In [None]:
# Keep only the datasets whose md5sum is listed in `sus_md5s`, working on a copy so
# that `my_metadata` stays untouched.
# A set gives constant-time membership tests instead of scanning the list for every md5.
sus_md5_set = set(sus_md5s)

meta = copy.deepcopy(my_metadata)
for md5 in list(meta.md5s):  # snapshot of the keys: entries are deleted while looping
    if md5 not in sus_md5_set:
        del meta[md5]

In [None]:
# Inspect the wgbs-standard datasets on a copy. `select_category_subsets` appears to
# filter in place (its return value is never used in this notebook; confirm), and
# `my_metadata` is filtered again for other assays further down.
wgbs_metadata = copy.deepcopy(my_metadata)
wgbs_metadata.select_category_subsets(ASSAY, ["wgbs-standard"])

# The former `if cat not in ["uuid", "md5sum"]` guard was always true for this
# literal list, so it has been removed.
for cat in [
    CELL_TYPE,
    "project",
    "data_generating_centre",
    DISEASE,
    LIFE_STAGE,
    "harmonized_sample_disease_ontology_curie",
]:
    wgbs_metadata.display_labels(cat)

In [None]:
# Label counts of the md5-filtered subset (`meta`) for each category of interest.
# The former `if cat not in ["uuid", "md5sum"]` guard was always true for this
# literal list, so it has been removed.
for cat in [
    ASSAY,
    CELL_TYPE,
    TRACK,
    "project",
    "data_generating_centre",
    DISEASE,
    LIFE_STAGE,
    "harmonized_sample_disease_ontology_curie",
]:
    meta.display_labels(cat)

In [None]:
# my_metadata.select_category_subsets(TRACK, ["pval"])
# my_metadata.select_category_subsets(ASSAY, ["h3k27ac", "h3k27me3", "h3k36me3", "h3k4me1", "h3k4me3", "h3k9me3"])
# display_gen_info(my_metadata)
# write_hdf5_paths_to_file(
#     md5s=my_metadata.md5s,
#     parent="/lustre07/scratch/rabyj/local_ihec_data/epiatlas/hg38/hdf5",
#     suffix="100kb_all_none",
#     filepath=Path.home() / "downloads/temp" / "epiatlas_pval_chip-seq_100kb_all_none.list",
# )

In [None]:
# # my_metadata.display_labels("project")

# my_metadata.select_category_subsets(ASSAY, ["input"])
# my_metadata.select_category_subsets(CELL_TYPE, ["T cell", "lymphocyte of B lineage", "neutrophil", "muscle organ"])
# my_metadata.display_labels(CELL_TYPE)
# write_md5s_to_file(
#     md5s=my_metadata.md5s,
#     logdir=Path.home() / "downloads/temp",
#     name="input_4ct",
# )
# write_hdf5_paths_to_file(
#     md5s=my_metadata.md5s,
#     parent="",
#     suffix="100kb_all_none",
#     filepath=Path.home() / "downloads/temp" / "100kb_all_none_input_4ct.list",
# )

### Create new metadata (for imputed files)

In [None]:
# Restrict the metadata to "pval" tracks of the six listed h3k* assays.
# NOTE(review): the filtering presumably happens in place (return values are unused).
histone_assays = ["h3k27ac", "h3k27me3", "h3k36me3", "h3k4me1", "h3k4me3", "h3k9me3"]
my_metadata.select_category_subsets(TRACK, ["pval"])
my_metadata.select_category_subsets(ASSAY, histone_assays)

# One row per remaining dataset, indexed by its "epirr_id" field.
dataset_records = list(my_metadata.datasets)
df = pd.DataFrame.from_records(dataset_records, index=["epirr_id"])

print(df.shape, len(my_metadata))

In [None]:
# Preview the first rows of the per-dataset table.
df.head()

In [None]:
# Remove all assay-specific columns; only the EpiRR-level metadata is wanted.
# Non-inplace calls keep the cell free of hidden mutation; `errors="ignore"` makes
# re-running it safe.
df = df.drop(
    columns=[
        "uuid",
        "md5sum",
        "assay_type",
        "assay_epiclass",
        "experiment_type",
        "antibody",
        "inputs",
        "inputs_ctl",
        "data_file_path",
        "upload_date",
        "paired_end",
        "analyzed_as_stranded",
        "status",
    ],
    errors="ignore",
)
problematics_columns = df.filter(like="read_len").columns.to_list()
df = df.drop(columns=problematics_columns, errors="ignore")

# `DataFrame.drop_duplicates` ignores the index, so rows are compared together with
# their epirr_id. Otherwise two different EpiRRs with identical metadata would be
# collapsed into a single row and one id would be lost. The first occurrence is kept.
is_repeated_row = df.reset_index().duplicated().to_numpy()
df = df.loc[~is_repeated_row]

# Drop rows in which every remaining column is missing.
df = df.dropna(axis=0, how="all")

In [None]:
# Sanity check: table shape vs. number of distinct index values.
# If the row count is larger, some index value still has several distinct rows.
print(df.shape, len(set(df.index)))
df.head()

In [None]:
# df[CELL_TYPE].value_counts(dropna=False)

In [None]:
# List of "<md5sum>  <filename>" lines (two-space separator) for the imputed bigWigs.
imputed_ids_path = (
    Path.home()
    / "mounts/narval-mount"
    / "scratch/local_ihec_data/epiatlas/hg38/bw/chip-seq_imputed/all_md5sums.list"
)

# A separator longer than one character is only supported by the python engine;
# requesting it explicitly avoids the ParserWarning raised by the silent fallback.
# `dtype=str` keeps md5sums as text even if one happens to contain only digits.
imputed_ids_df = pd.read_csv(
    imputed_ids_path,
    sep="  ",
    header=None,
    names=["md5sum", "filename"],
    engine="python",
    dtype=str,
)

In [None]:
# Derive the EpiRR id and the assay from filenames of the form
# "impute_<epirr_id>_<H3...>.pval.bw" (form taken from the patterns below).
# - The dots of the ".pval.bw" suffix are escaped so they only match a literal dot.
# - `expand=False` makes `extract` return a Series for the single capture group.
# A stray `imputed_ids_df.head()` that displayed nothing (it was not the last
# expression of the cell) has been removed.
imputed_ids_df["epirr_id"] = imputed_ids_df["filename"].str.extract(
    r"impute_(.+)_H3.+\.pval\.bw", expand=False
)
imputed_ids_df["assay_epiclass"] = (
    imputed_ids_df["filename"]
    .str.extract(r"impute_.+_(H3.+)\.pval\.bw", expand=False)
    .str.lower()
)
# The md5sum is reused as the uuid of each imputed file.
imputed_ids_df["uuid"] = imputed_ids_df["md5sum"]

In [None]:
# Shape and preview of the imputed-file table.
print(imputed_ids_df.shape)
imputed_ids_df.head()

In [None]:
# Number of distinct epirr_id values among the imputed files (printed as a 1-tuple).
print(imputed_ids_df["epirr_id"].unique().shape)

In [None]:
# Compare the EpiRR ids of the original metadata with those of the imputed files.
set_og = set(df.index)
set_imputed = set(imputed_ids_df["epirr_id"])
union = set_og | set_imputed
print(len(union), len(set_og), len(set_imputed))

# Ids that only exist among the imputed files: first as a raw set, then one per line.
only_imputed = set_imputed - set_og
print(only_imputed)
for item in sorted(only_imputed):
    print(item)

In [None]:
# Attach the EpiRR-level metadata (index of `df`) to every imputed file.
# The right join keeps all imputed files, including those without matching metadata.
merged_imputed_df = pd.merge(
    df,
    imputed_ids_df,
    how="right",
    left_index=True,
    right_on="epirr_id",
)

In [None]:
# Shape of the merged table (right join on the imputed-file table).
print(merged_imputed_df.shape)

In [None]:
# merged_imputed_df[CELL_TYPE].value_counts(dropna=False)

In [None]:
# Replace missing values with "" in place: necessary to not end up with "float" types
# (NaN is a float) in the records built from this table.
merged_imputed_df.fillna("", inplace=True)

In [None]:
# merged_imputed_df.to_csv(Path.home() / "downloads" / "temp"/ "hg38_epiatlas_imputed_pval_chip_2024-02.csv")

In [None]:
# One record (dict) per imputed file; despite its name, `new_dict` is a list.
new_dict = merged_imputed_df.to_dict(orient="records")

# Index the records by md5sum. A repeated md5sum silently keeps the last record.
meta_dict = {}
for record in new_dict:
    meta_dict[record["md5sum"]] = record

output_path = (
    Path.home() / "downloads" / "temp" / "hg38_epiatlas_imputed_pval_chip_2024-02.json"
)
new_metadata = Metadata.from_dict(meta_dict)
new_metadata.save(output_path)