In [None]:
"""Workbooks to analyze metadata."""

# pylint: disable=import-error, redefined-outer-name, unused-import

In [None]:
%load_ext autoreload
%autoreload 2

### SETUP

In [None]:
from __future__ import annotations

from collections import Counter, defaultdict
from pathlib import Path
from typing import DefaultDict, Dict, List

import pandas as pd
from IPython.display import display

from epi_ml.core.metadata import Metadata, UUIDMetadata
from epi_ml.utils.general_utility import write_hdf5_paths_to_file, write_md5s_to_file
from epi_ml.utils.modify_metadata import filter_by_pairs
from epi_ml.utils.notebooks.paper.paper_utilities import (
    ASSAY,
    ASSAY_ORDER,
    BIOMATERIAL_TYPE,
    CANCER,
    CELL_TYPE,
    DISEASE,
    EPIATLAS_16_CT,
    LIFE_STAGE,
    SEX,
    TRACK,
    MetadataHandler,
)

In [None]:
# "Core 7" assays: the first seven entries of the canonical assay ordering.
CORE7_ASSAYS = ASSAY_ORDER[0:7]

In [None]:
# Assay sub-variants mapped onto the single parent assay label they are merged into.
ASSAY_MERGE_DICT: Dict[str, str] = {
    **dict.fromkeys(("rna_seq", "mrna_seq"), "rna"),
    **dict.fromkeys(("wgbs-pbat", "wgbs-standard"), "wgbs"),
}

In [None]:
# Root folder of the paper outputs, and its metadata sub-folder.
paper_dir = Path.home().joinpath("Projects", "epiclass", "output", "paper")
paper_meta_dir = paper_dir.joinpath("data", "metadata")

In [None]:
# Input metadata folder; only the first `path` is loaded, the others are kept for reference.
base = Path.home() / "Projects" / "epiclass" / "input" / "metadata"
path = base.joinpath(
    "dfreeze-v2", "hg38_2023-epiatlas-dfreeze-pospurge-nodup_filterCtl.json"
)
# Alternative metadata files:
# path = base / "dfreeze-v1.0" / "hg38_2023-epiatlas_dfreeze_formatted_JR.json"
# path = base / "dfreeze-v1.0" / "hg38_2023-epiatlas_dfreeze_plus_encode_noncore_formatted_JR.json"
# path = base / "dfreeze-v2" / "hg38_2023-epiatlas-dfreeze_v2.1_w_encode_noncore_2.json"
my_metadata = Metadata(path)
meta_df = my_metadata.to_df()

In [None]:
def display_gen_info(metadata: Metadata):
    """Display the label counts of a few general metadata categories.

    Categories are shown in this order: "track_type", ASSAY, CELL_TYPE, SEX, TRACK.
    CANCER, DISEASE and LIFE_STAGE can be appended to the tuple below when needed.

    Args:
        metadata (Metadata): Object whose `display_labels` method is called once per category.
    """
    for category in ("track_type", ASSAY, CELL_TYPE, SEX, TRACK):
        metadata.display_labels(category)

In [None]:
def count_trios(metadata: Metadata) -> Counter:
    """
    Count each distinct (track_type, assay, cell_type) combination in the metadata.

    Args:
        metadata (Metadata): Metadata whose `datasets` are iterated.

    Returns:
        Counter: Number of datasets per ("track_type", ASSAY, CELL_TYPE) value trio.
    """
    return Counter(
        (dataset["track_type"], dataset[ASSAY], dataset[CELL_TYPE])
        for dataset in metadata.datasets
    )

In [None]:
def count_pairs_w_assay(metadata: Metadata, category: str) -> DefaultDict[str, Counter]:
    """
    Count, for each assay, how many datasets carry each label of `category`.

    Args:
        metadata (Metadata): Metadata whose `datasets` are iterated.
        category (str): Metadata key whose labels are counted (e.g. CELL_TYPE).

    Returns:
        defaultdict(Counter): {assay: Counter({label: n_datasets})}, assays in first-seen order.
    """
    counts_by_assay: DefaultDict[str, Counter] = defaultdict(Counter)
    for dataset in metadata.datasets:
        counts_by_assay[dataset[ASSAY]][dataset[category]] += 1
    return counts_by_assay


def select_cell_types(
    metadata: Metadata, n=70, first_cell_type: str = "T cell"
) -> DefaultDict[str, List]:
    """
    Determine, for each assay, which cell types are needed to reach `n` datasets.

    `first_cell_type` is always selected first. The remaining cell types are then
    taken from most to least common until the running dataset count reaches `n`.
    Each cell type contributes at most the number of datasets still missing.

    Args:
        metadata (Metadata): A Metadata object containing dataset metadata.
        n (int, optional): Target number of datasets per assay (a number of
            datasets, not of cell types). Defaults to 70.
        first_cell_type (str, optional): Cell type always selected first, even
            when the assay has no dataset for it. Defaults to "T cell".

    Returns:
        defaultdict(list): Selected cell types for each assay, in selection order.
    """
    cell_count = count_pairs_w_assay(metadata, CELL_TYPE)

    selected_ct = defaultdict(list)
    for assay, counter in cell_count.items():
        selected_ct[assay].append(first_cell_type)
        n_datasets = min(counter[first_cell_type], n)

        # most_common() sorts by decreasing count (ties keep insertion order), so a
        # single pass gives the same order as repeatedly taking the current maximum.
        for cell_type, count in counter.most_common():
            if n_datasets >= n:
                break
            if cell_type == first_cell_type:
                continue
            n_datasets += min(count, n - n_datasets)
            selected_ct[assay].append(cell_type)

        if n_datasets < n:
            print(
                f"There is not at least {n} files for {assay}. Final number={n_datasets}"
            )

    return selected_ct

In [None]:
# my_metadata.select_category_subsets(ASSAY, CORE7_ASSAYS)
# logdir = Path.home() / "scratch/pca"
# write_md5s_to_file(
#     md5s=my_metadata.md5s,
#     logdir=logdir,
#     name="epiatlas_chip",
# )

In [None]:
# my_metadata.select_category_subsets(BIOMATERIAL_TYPE, ["cell line"])
# df = my_metadata.to_df()
# print(df["epirr_id_without_version"].unique().shape)

### Create new metadata (for imputed files)

In [None]:
# Same locations as in the SETUP section, redefined so this section can run on its own.
paper_dir = Path.home() / "Projects" / "epiclass" / "output" / "paper"
paper_meta_dir = paper_dir.joinpath("data", "metadata")

In [None]:
# Copy of the metadata stored with the paper data (same file name as in SETUP).
path = paper_meta_dir.joinpath("hg38_2023-epiatlas-dfreeze-pospurge-nodup_filterCtl.json")
my_metadata = Metadata(path)

In [None]:
# Restrict to pval tracks of six histone marks.
histone_marks = ["h3k27ac", "h3k27me3", "h3k36me3", "h3k4me1", "h3k4me3", "h3k9me3"]
my_metadata.select_category_subsets(TRACK, ["pval"])
my_metadata.select_category_subsets(ASSAY, histone_marks)

# One row per dataset, indexed by the dataset's EpiRR id.
dataset_records = list(my_metadata.datasets)
df = pd.DataFrame.from_records(dataset_records, index=["epirr_id"])

print(df.shape, len(my_metadata))

In [None]:
# remove all assay specific columns, only want epirr metadata
df.drop(
    columns=[
        "uuid",
        "md5sum",
        "assay_type",
        "assay_epiclass",
        "experiment_type",
        "antibody",
        "inputs",
        "inputs_ctl",
        "data_file_path",
        "upload_date",
        "paired_end",
        "analyzed_as_stranded",
        "status",
    ],
    inplace=True,
    errors="ignore",
)
problematics_columns = df.filter(like="read_len").columns.to_list()
df.drop(columns=problematics_columns, inplace=True, errors="ignore")
df.drop_duplicates(inplace=True)
df.dropna(axis=0, how="all", inplace=True)

In [None]:
# Rows vs distinct EpiRR ids: a difference means some EpiRR still has several rows
# that differ in their remaining metadata columns.
n_distinct_epirr = len(set(df.index))
print(df.shape, n_distinct_epirr)
df.head()

In [None]:
# Lines are expected as "<md5sum>  <filename>" (two-space separated).
imputed_ids_path = paper_meta_dir.joinpath("all_imputed_files_md5.list")
imputed_ids_df = pd.read_csv(
    imputed_ids_path,
    sep="  ",  # multi-character separator, hence the python engine
    engine="python",
    header=None,
    names=["md5sum", "filename"],
)

In [None]:
# File names are expected to look like "impute_<epirr_id>_<H3 mark>.pval.bw".
# Extract both parts with one pattern; dots are escaped so only a literal ".pval.bw" matches.
filename_parts = imputed_ids_df["filename"].str.extract(
    r"impute_(?P<epirr_id>.+)_(?P<assay_epiclass>H3.+)\.pval\.bw"
)
n_unparsed = int(filename_parts.isna().any(axis=1).sum())
if n_unparsed:
    print(f"WARNING: {n_unparsed} file names did not match the expected pattern.")

imputed_ids_df["epirr_id"] = filename_parts["epirr_id"]
imputed_ids_df["assay_epiclass"] = filename_parts["assay_epiclass"].str.lower()
# The md5sum is reused as the uuid of each imputed file.
imputed_ids_df["uuid"] = imputed_ids_df["md5sum"]

In [None]:
n_imputed_rows, n_imputed_cols = imputed_ids_df.shape
print((n_imputed_rows, n_imputed_cols))
display(imputed_ids_df.head(5))

In [None]:
# Shape of the distinct EpiRR ids among imputed files (a missing id counts once, like unique()).
print(imputed_ids_df["epirr_id"].drop_duplicates().shape)

In [None]:
set_og = set(df.index)
set_imputed = set(imputed_ids_df["epirr_id"])
only_in_imputed = set_imputed - set_og

# Sizes of the union and of each EpiRR id set, then the ids only present in imputed files.
union = set_og | set_imputed
print(len(union), len(set_og), len(set_imputed))
print(only_in_imputed)

for epirr_id in sorted(only_in_imputed):
    print(epirr_id)

In [None]:
# Right join: keep every imputed file and attach its EpiRR-level metadata when available.
merged_imputed_df = pd.merge(
    df,
    imputed_ids_df,
    how="right",
    left_index=True,
    right_on="epirr_id",
)
print(merged_imputed_df.shape)

In [None]:
# Missing values become "" so that metadata fields do not end up as float NaN.
merged_imputed_df = merged_imputed_df.fillna("")

In [None]:
# Records keyed by md5sum (if an md5sum repeats, the last record wins).
new_dict = merged_imputed_df.to_dict(orient="records")
meta_dict = {}
for record in new_dict:
    meta_dict[record["md5sum"]] = record
new_metadata = Metadata.from_dict(meta_dict)
new_metadata.save(paper_meta_dir.joinpath("hg38_epiatlas_imputed_pval_chip_2024-02.json"))

### Sanity check: imputed and observed pval datasets have similar cell type distributions

In [None]:
paper_data_dir = Path(paper_dir, "data")

In [None]:
def read_md5s_from_hdf5_list(list_path: Path) -> set:
    """Return the set of md5sums found in an hdf5 list file.

    Each non-blank line is expected to be a file path whose base name starts with
    "<md5>_"; the part of the base name before the first underscore is kept.
    Blank lines are ignored.
    """
    with open(list_path, "r", encoding="utf8") as list_file:
        return {
            line.split("/")[-1].split("_")[0]
            for line in list_file.read().splitlines()
            if line.strip()
        }


path_metadata_observed = (
    paper_meta_dir / "hg38_2023-epiatlas-dfreeze_v2.1_w_encode_noncore_2.json"
)
obs_metadata = Metadata(path_metadata_observed)
path_obs_md5 = paper_data_dir / "hdf5_list" / "100kb_all_none_pval_chip-seq.list"
obs_md5 = read_md5s_from_hdf5_list(path_obs_md5)

path_metadata_imputed = paper_meta_dir / "hg38_epiatlas_imputed_pval_chip_2024-02.json"
imp_metadata = Metadata(path_metadata_imputed)
path_imputed_md5 = paper_data_dir / "hdf5_list" / "100kb_all_none_chip-seq_imputed.list"
imp_md5 = read_md5s_from_hdf5_list(path_imputed_md5)

In [None]:
# Keep only the datasets that appear in the corresponding hdf5 list.
for metadata_to_filter, available_md5s in (
    (obs_metadata, obs_md5),
    (imp_metadata, imp_md5),
):
    # Iterate over a copy: entries are deleted while looping.
    for md5 in list(metadata_to_filter.md5s):
        if md5 not in available_md5s:
            del metadata_to_filter[md5]

In [None]:
# Show the assay label counts of each filtered metadata and keep a DataFrame version of it.
meta_dfs = {}
for name, metadata in {"observed": obs_metadata, "imputed": imp_metadata}.items():
    print(name)
    metadata.display_labels(ASSAY)
    meta_dfs[name] = metadata.to_df()

In [None]:
obs_df = meta_dfs["observed"]
imp_df = meta_dfs["imputed"]

N_TOP_CELL_TYPES = 20  # number of most common observed cell types to compare

obs_df_cell_type = obs_df[CELL_TYPE].value_counts(dropna=False)
relative_obs_df_cell_type = obs_df_cell_type / obs_df_cell_type.sum()

imp_df_cell_type = imp_df[CELL_TYPE].value_counts(dropna=False)
relative_imp_df_cell_type = imp_df_cell_type / imp_df_cell_type.sum()

# value_counts() is already sorted by decreasing count, so head() gives the top cell types.
for cell_type, obs_perc in relative_obs_df_cell_type.head(N_TOP_CELL_TYPES).items():
    # A cell type can be absent from the imputed data: report 0 instead of raising KeyError.
    imp_count = imp_df_cell_type.get(cell_type, 0)
    imp_perc = relative_imp_df_cell_type.get(cell_type, 0.0)
    print(cell_type)
    print(f"obs: {obs_df_cell_type[cell_type]}, imp: {imp_count}")
    print(f"obs: {obs_perc:.2%}, imp: {imp_perc:.2%}")
    diff = obs_perc - imp_perc
    print(f"diff: {diff:.2%}")
    print()

## New cell type classes (Martin class v3) added to the v2 metadata

In [None]:
metadata_handler = MetadataHandler(paper_dir)

metadata_v2_df = metadata_handler.load_metadata_df("v2").reset_index(drop=False)
print(metadata_v2_df.shape)

new_cell_type_path = paper_meta_dir / "Martin_class_v3_041224.tsv"
new_cell_type_df = pd.read_csv(
    new_cell_type_path,
    sep="\t",
    names=["epirr_id_without_version", "cell_type_martin", "cell_type_PE"],
)
print(new_cell_type_df.shape)

# validate="many_to_one" fails loudly if an EpiRR appears more than once in the class
# file; otherwise the left join would silently duplicate datasets (and md5sum keys).
merged_metadata = metadata_v2_df.merge(
    new_cell_type_df,
    on="epirr_id_without_version",
    how="left",
    validate="many_to_one",
)
print(merged_metadata.shape)

In [None]:
# Records keyed by md5sum (if an md5sum repeats, the last record wins).
new_meta = {}
for record in merged_metadata.to_dict(orient="records"):
    new_meta[record["md5sum"]] = record
new_meta_dict = Metadata.from_dict(new_meta)  # a Metadata object, despite its name
new_meta_dict.save(
    paper_meta_dir.joinpath("hg38_2023-epiatlas-dfreeze-pospurge-nodup_filterCtl_newCT.json")
)

In [None]:
# for col in ["md5sum", "uuid", "epirr_id_without_version"]:
#     print(col, merged_metadata[col].nunique())

# merged_metadata = merged_metadata.drop_duplicates("uuid")
# print(merged_metadata.shape)

In [None]:
# for pivot_col in ["cell_type_martin", "cell_grouping_PE"]:
#     print(pivot_col)
#     pair_count_df = merged_metadata.groupby([pivot_col, ASSAY]).agg({"uuid": "count"}).reset_index()
#     assay_count_df = pair_count_df[pair_count_df["uuid"] >= 10].groupby(pivot_col).agg({ASSAY: "count"}).reset_index().sort_values(ASSAY, ascending=False)
#     print(assay_count_df.reset_index(drop=True))

## Merge pre-purge predictions with the official BadQual metadata

In [None]:
preds_path = paper_dir.joinpath(
    "data",
    "training_results",
    "pre-purge_n21606",
    "10fold",
    "full-10fold-validation_prediction_augmented-all.csv",
)
# low_memory=False: parse the whole file at once to avoid mixed-type column inference.
preds_df = pd.read_csv(preds_path, sep=",", low_memory=False)
preds_df.head()

In [None]:
# Presumably the classifier whose predictions are in preds_df (kept for reference; not used below).
classifier = "assay_epiclass_1l_3000n_11c_10fold-oversampling"
cols = ["uuid", ASSAY, "track_type", "Predicted class", "Max pred"]
preds_df = preds_df[cols]
preds_df.head()

In [None]:
bad_qual_path = paper_meta_dir.joinpath(
    "epiatlas", "official", "BadQual-mislabels", "official_BadQual.csv"
)
bad_qual_df = pd.read_csv(bad_qual_path)
display(bad_qual_df.head())

# uuids of the files listed in the official BadQual table.
bad_uuid = set(bad_qual_df["uuid"])

In [None]:
merged_df = pd.merge(bad_qual_df, preds_df, how="right", on=["uuid"])

# Use a dedicated loop name: reusing `df` would clobber the EpiRR-level frame defined earlier.
for frame in [bad_qual_df, preds_df, merged_df]:
    print(frame.shape)

# Keep only the predictions of files listed in the BadQual table.
merged_df = merged_df[merged_df["uuid"].isin(bad_uuid)]
print(merged_df.shape)

In [None]:
# Select relevant columns
to_pivot = merged_df[["uuid", "track_type", "Max pred", "Predicted class"]]

# Pivot longer to wider format using two value columns
wide_df = to_pivot.pivot(
    index="uuid", columns="track_type", values=["Max pred", "Predicted class"]
)

# Flatten MultiIndex columns
wide_df.columns = [
    f"{val.lower().replace(' ', '_')}_{track}" for val, track in wide_df.columns
]

# Reset index so uuid becomes a column again
wide_df = wide_df.reset_index()

print(wide_df.shape)  # Should be roughly (134, 7-9)
display(wide_df.head())

# remerge with bad_qual_df
merged_df = pd.merge(bad_qual_df, wide_df, how="left", on=["uuid"])
print(merged_df.shape)
display(merged_df.head())

In [None]:
# save
new_path = bad_qual_path.parent / "official_BadQual_augmented.csv"
merged_df.to_csv(new_path, index=False)