In [None]:
"""Workbooks to analyze metadata."""
# pylint: disable=import-error, redefined-outer-name, unused-import

In [None]:
# Automatically reload edited modules (e.g. the local epi_ml package) before each cell runs.
%load_ext autoreload
%autoreload 2

### SETUP

In [None]:
from __future__ import annotations

import copy
import io
import json
from collections import Counter, defaultdict
from pathlib import Path
from typing import DefaultDict, Dict, List

import numpy as np
import pandas as pd
import requests
from IPython.display import display

from epi_ml.core.metadata import Metadata, UUIDMetadata
from epi_ml.utils.general_utility import write_hdf5_paths_to_file, write_md5s_to_file
from epi_ml.utils.modify_metadata import filter_by_pairs

# Metadata field names: used as keys of each dataset record and as DataFrame column names.
BIOMATERIAL_TYPE = "harmonized_biomaterial_type"
CELL_TYPE = "harmonized_sample_ontology_intermediate"
ASSAY = "assay_epiclass"
SEX = "harmonized_donor_sex"
CANCER = "harmonized_sample_cancer_high"  # only referenced in commented-out code
DISEASE = "harmonized_sample_disease_high"
LIFE_STAGE = "harmonized_donor_life_stage"
TRACK = "track_type"

In [None]:
# Maps assay labels to merged labels: RNA-seq variants -> "rna", WGBS variants -> "wgbs".
# NOTE(review): not referenced by any other visible cell.
ASSAY_MERGE_DICT: Dict[str, str] = {
    "rna_seq": "rna",
    "mrna_seq": "rna",
    "wgbs-pbat": "wgbs",
    "wgbs-standard": "wgbs",
}

In [None]:
# Cell types of interest, normalized to lowercase for case-insensitive comparisons.
accepted_cts = [
    cell_type.lower()
    for cell_type in (
        "T cell",
        "neutrophil",
        "brain",
        "monocyte",
        "lymphocyte of B lineage",
        "myeloid cell",
        "venous blood",
        "macrophage",
        "mesoderm-derived structure",
        "endoderm-derived structure",
        "colon",
        "connective tissue cell",
        "hepatocyte",
        "mammary gland epithelial cell",
        "muscle organ",
        "extraembryonic cell",
    )
]

In [None]:
# Root of the local metadata inputs. The active path is the filtered dfreeze-v2 file;
# alternative metadata files are kept commented out.
base = Path().home() / "Projects/epiclass/input/metadata"
path = base / "dfreeze-v2" / "hg38_2023-epiatlas-dfreeze-pospurge-nodup_filterCtl.json"
# path = base / "dfreeze-v1.0" / "hg38_2023-epiatlas_dfreeze_formatted_JR.json"
# path = base / "dfreeze-v1.0" / "hg38_2023-epiatlas_dfreeze_plus_encode_noncore_formatted_JR.json"
# path = base / "dfreeze-v2" / "hg38_2023-epiatlas-dfreeze_v2.1_w_encode_noncore_2.json"
my_metadata = Metadata(path)
meta_df = my_metadata.to_df()  # NOTE(review): not referenced by later visible cells.

In [None]:
def display_gen_info(metadata: Metadata):
    """Display class counts for the main metadata categories.

    Shows, in order, label counts for track type, assay, cell type and sex.
    Cancer, disease and life-stage displays are left commented out.

    Args:
        metadata (Metadata): Metadata object whose label counts are displayed.
    """
    # TRACK is the "track_type" field; it is displayed exactly once.
    metadata.display_labels(TRACK)
    metadata.display_labels(ASSAY)
    metadata.display_labels(CELL_TYPE)
    metadata.display_labels(SEX)
    # metadata.display_labels(CANCER)
    # metadata.display_labels(DISEASE)
    # metadata.display_labels(LIFE_STAGE)

In [None]:
def count_trios(metadata: Metadata) -> Counter:
    """
    Count the occurrences of unique (track_type, assay, cell_type) trios in the metadata.

    Args:
        metadata (Metadata): Metadata object whose ``datasets`` are iterated.

    Returns:
        Counter: Maps each (TRACK, ASSAY, CELL_TYPE) value tuple to its number of datasets.
    """
    # Use the module-level field constants, like the other helpers in this notebook.
    return Counter(
        (dset[TRACK], dset[ASSAY], dset[CELL_TYPE]) for dset in metadata.datasets
    )

In [None]:
def count_pairs_w_assay(metadata: Metadata, category: str) -> DefaultDict[str, Counter]:
    """
    Tally, for every assay, how many datasets carry each value of ``category``.

    Args:
        metadata (Metadata): Metadata object whose ``datasets`` are iterated.
        category (str): Dataset field to count within each assay (e.g. CELL_TYPE).

    Returns:
        defaultdict(Counter): assay label -> Counter of ``category`` values.
    """
    counts_per_assay: DefaultDict[str, Counter] = defaultdict(Counter)
    for dataset in metadata.datasets:
        counts_per_assay[dataset[ASSAY]][dataset[category]] += 1
    return counts_per_assay


def select_cell_types(metadata: Metadata, n=70) -> DefaultDict[str, List]:
    """
    Determine which cell types are needed to reach ``n`` datasets for each assay.

    For every assay, "T cell" is always selected first (even when that assay has no
    T cell datasets). The remaining cell types are then added from most to least
    common until their cumulative dataset count reaches ``n``. A message is printed
    for assays that cannot reach ``n`` datasets.

    Args:
        metadata (Metadata): A Metadata object containing dataset metadata.
        n (int, optional): Target number of datasets per assay. Defaults to 70.

    Returns:
        defaultdict(list): assay label -> selected cell types, in selection order.
    """
    cell_count = count_pairs_w_assay(metadata, CELL_TYPE)

    selected_ct = defaultdict(list)
    for assay, counter in cell_count.items():
        selected_ct[assay].append("T cell")
        # A Counter returns 0 for a missing key, and `del` of a missing key is a no-op.
        total = min(counter["T cell"], n)
        del counter["T cell"]
        while total < n and counter:
            # Greedily take the most common remaining cell type.
            cell_type, count = counter.most_common(1)[0]
            total += min(count, n - total)
            selected_ct[assay].append(cell_type)
            del counter[cell_type]
        if total < n:
            print(f"There is not at least {n} files for {assay}. Final number={total}")

    return selected_ct

### Our metadata VS official metadata

#### Metadata we use for training

In [None]:
# Reload the training metadata and keep a single row per epirr_id_without_version.
path = base / "dfreeze-v2" / "hg38_2023-epiatlas-dfreeze-pospurge-nodup_filterCtl.json"
my_metadata = Metadata(path)
my_meta_df = my_metadata.to_df().drop_duplicates(subset=["epirr_id_without_version"])

In [None]:
# Keep only the EpiRR key plus the categories compared against the official metadata.
relevants_cols = [CELL_TYPE, BIOMATERIAL_TYPE, SEX, DISEASE, LIFE_STAGE]
my_meta_df = my_meta_df.loc[:, ["epirr_id_without_version", *relevants_cols]]
my_epirrs = set(my_meta_df["epirr_id_without_version"].unique())

# Index by EpiRR so rows can be aligned with the official tables.
my_meta_df = my_meta_df.set_index("epirr_id_without_version")

#### Official metadata

In [None]:
dfs = {}

url_template = "https://raw.githubusercontent.com/IHEC/epiATLAS-metadata-harmonization/refs/heads/main/openrefine/{version}/IHEC_metadata_harmonization.{version}.extended.csv"
for version in ["v1.0", "v1.1", "v1.2"]:
    myurl = url_template.format(version=version)
    print(f"Downloading version {version}: {myurl}")

    try:
        # Download the file; the timeout prevents a stalled connection from hanging the notebook.
        response = requests.get(myurl, timeout=60)
        response.raise_for_status()  # Raise an error for bad responses (4xx, 5xx)
    except requests.exceptions.RequestException as e:
        print(f"Error downloading {myurl}: {e}")
        # Skip this version: storing `df` here would reuse the previous version's table
        # (or raise NameError on the first iteration).
        continue

    # Load file as a DataFrame
    df = pd.read_csv(io.StringIO(response.content.decode("utf-8")))
    dfs[version] = df

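Index each official table by its EpiRR ID without the version suffix, and replace missing values with "unknown", so the tables can be compared with our metadata.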

In [None]:
for v, df in dfs.items():
    # Strip everything from the first "." onward (the EpiRR version suffix).
    df["epirr_id_without_version"] = df["EpiRR"].str.split(".").str[0]
    df = df.set_index("epirr_id_without_version").fillna("unknown")
    dfs[v] = df

#### Creating json of differences

In [None]:
# For each category, collect EpiRRs whose training value differs from v1.0 or v1.1.
problematic_idxs = defaultdict(set)
for cat in relevants_cols:
    for version in ["v1.0", "v1.1"]:
        # Reorder the official rows to match the training index; .loc raises KeyError
        # if a training EpiRR is absent from this official version.
        meta = dfs[version].loc[my_meta_df.index]

        # `!=` (not `.ne`) requires identically-labeled Series and treats NaN as different.
        diff = meta[cat] != my_meta_df[cat]
        diff_idxs = diff.index[diff.to_numpy()]

        if len(diff_idxs) > 0:
            problematic_idxs[cat].update(diff_idxs)

In [None]:
# Collect, per category, the conflicting values for each flagged EpiRR.
# `.get` is used instead of `[]`: indexing the defaultdict would insert empty sets for
# categories without differences, making them look "changed" in later checks.
all_changes = {}
for col in relevants_cols:
    cat_idxs = problematic_idxs.get(col)
    if not cat_idxs:
        continue
    # Sort so the JSON output is reproducible (set iteration order is not);
    # key=str keeps sorting safe even if a non-string label is present.
    all_changes[col] = {
        idx: {
            "training": my_meta_df.loc[idx, col],
            "v1.0-official": dfs["v1.0"].loc[idx, col],
            "v1.1-official": dfs["v1.1"].loc[idx, col],
        }
        for idx in sorted(cat_idxs, key=str)
    }

In [None]:
for col in relevants_cols:
    # An absent key and an empty set (e.g. one inserted by a defaultdict lookup)
    # both mean "no changes".
    n_changes = len(problematic_idxs.get(col, ()))
    if n_changes:
        print(f"Changes in {col}: {n_changes}")
    else:
        print(f"No changes in {col}")

In [None]:
filename = "training_metadata_vs_official.json"
path = base / filename

# Serialize first, so a failure (e.g. a NaN rejected by allow_nan=False) raises before
# the output file is opened, instead of leaving a truncated JSON file behind.
serialized = json.dumps(all_changes, indent=4, allow_nan=False)
with open(path, "w", encoding="utf8") as f:
    f.write(serialized)

### Sanity check: SEX v1.2 = SEX v1.3

In [None]:
official_metadata_dir = (
    Path.home() / "Projects" / "epiclass" / "output" / "paper" / "data" / "metadata" / "official"
)

# One extended harmonization CSV per official metadata release, keyed by version.
official_metadata_dfs = {}
for version in ("v1.1", "v1.2", "v1.3"):
    path = official_metadata_dir / f"IHEC_metadata_harmonization.{version}.extended.csv"
    official_metadata_dfs[version] = df = pd.read_csv(path)

In [None]:
# SEX is already defined (same value) in the setup cell, so it is not redefined here.
sex_mislabels_path = (
    official_metadata_dir / "BadQual-mislabels" / "official_Sex_mislabeled.csv"
)
sex_mislabels_df = pd.read_csv(sex_mislabels_path, sep=",")

In [None]:
# NOTE(review): `sex_epirrs` is not referenced by any other visible cell.
sex_epirrs = {}
# Start from the mislabel table and right-join one official version at a time, so every
# mislabel row is kept (how="right") and gains that version's sex label.
subset_df = sex_mislabels_df
for version, df in official_metadata_dfs.items():
    relevant_df = df.loc[:, ["epirr_id_without_version", SEX]]
    # Columns present on both sides get f"_{version}" on the left (this version) and keep
    # their name on the right (accumulated table). Once the accumulated table has a SEX
    # column, each version's labels therefore land in f"{SEX}_{version}".
    # Assumes sex_mislabels_df has an "EpiRR_no-v" column; presumably it also has a SEX
    # column (the following cell drops a bare SEX column) — verify.
    subset_df = relevant_df.merge(
        subset_df,
        left_on="epirr_id_without_version",
        right_on="EpiRR_no-v",
        how="right",
        suffixes=(f"_{version}", ""),
    )

In [None]:
# Drop every join-key column (names starting with "epirr_id") and the bare SEX column,
# keeping the per-version SEX columns.
subset_df = subset_df.drop(
    columns=[col for col in subset_df.columns if col.startswith("epirr_id")] + [SEX]
)

In [None]:
# No mislabeled EpiRR may differ between v1.2 and v1.3 (NaN on either side counts as a difference).
assert not (subset_df[f"{SEX}_v1.3"] != subset_df[f"{SEX}_v1.2"]).any()

In [None]:
# EpiRRs present in both releases, with overlapping columns suffixed by version.
merged_df = pd.merge(
    official_metadata_dfs["v1.2"],
    official_metadata_dfs["v1.3"],
    how="inner",
    on="epirr_id_without_version",
    suffixes=("_v1.2", "_v1.3"),
)

In [None]:
# Every EpiRR shared by v1.2 and v1.3 must carry the same sex label (NaN counts as a difference).
assert not (merged_df[f"{SEX}_v1.3"] != merged_df[f"{SEX}_v1.2"]).any()

### Create new metadata (for imputed files)

In [None]:
# Reload the training metadata from scratch; the next cell subsets it.
path = base / "dfreeze-v2" / "hg38_2023-epiatlas-dfreeze-pospurge-nodup_filterCtl.json"
my_metadata = Metadata(path)

In [None]:
# Restrict to pval tracks of six histone marks. select_category_subsets presumably
# filters my_metadata in place, since its return value is not used.
my_metadata.select_category_subsets(TRACK, ["pval"])
my_metadata.select_category_subsets(
    ASSAY,
    [
        "h3k27ac",
        "h3k27me3",
        "h3k36me3",
        "h3k4me1",
        "h3k4me3",
        "h3k9me3",
    ],
)

# One row per dataset record, indexed by its "epirr_id" field.
df = pd.DataFrame.from_records(list(my_metadata.datasets), index=["epirr_id"])

print(df.shape, len(my_metadata))

In [None]:
# remove all assay specific columns, only want epirr metadata
df.drop(
    columns=[
        "uuid",
        "md5sum",
        "assay_type",
        "assay_epiclass",
        "experiment_type",
        "antibody",
        "inputs",
        "inputs_ctl",
        "data_file_path",
        "upload_date",
        "paired_end",
        "analyzed_as_stranded",
        "status",
    ],
    inplace=True,
    errors="ignore",
)
problematics_columns = df.filter(like="read_len").columns.to_list()
df.drop(columns=problematics_columns, inplace=True, errors="ignore")
df.drop_duplicates(inplace=True)
df.dropna(axis=0, how="all", inplace=True)

In [None]:
# Compare the row count (first value of shape) with the number of distinct EpiRR IDs:
# equal counts mean one metadata row per EpiRR.
print(df.shape, len(set(df.index)))
df.head()

In [None]:
# md5sum list of the imputed ChIP-seq bigwig files on the narval mount.
imputed_ids_path = (
    Path.home()
    / "mounts/narval-mount"
    / "rrg-ihec-share/local_ihec_data/epiatlas/hg38/bw/chip-seq_imputed"
    / "all_md5sums.list"
)

# Presumably `md5sum` output ("<md5>  <filename>", two-space separated). A separator
# longer than one character requires the python parser engine.
imputed_ids_df = pd.read_csv(
    imputed_ids_path, sep="  ", header=None, names=["md5sum", "filename"], engine="python"
)

In [None]:
# Parse the EpiRR ID and the H3 mark from names like "impute_<epirr>_<H3...>.pval.bw".
# Dots are escaped so they only match a literal ".".
imputed_ids_df["epirr_id"] = imputed_ids_df["filename"].str.extract(
    r"impute_(.+)_H3.+\.pval\.bw", expand=False
)
imputed_ids_df["assay_epiclass"] = (
    imputed_ids_df["filename"]
    .str.extract(r"impute_.+_(H3.+)\.pval\.bw", expand=False)
    .str.lower()
)
# The md5sum doubles as the dataset UUID for imputed files.
imputed_ids_df["uuid"] = imputed_ids_df["md5sum"]

In [None]:
# Inspect the parsed imputed-file table.
print(imputed_ids_df.shape)
imputed_ids_df.head()

In [None]:
# Number of distinct EpiRR IDs among imputed files (unparsed filenames contribute one NaN entry).
print(imputed_ids_df["epirr_id"].unique().shape)

In [None]:
set_og = set(df.index)
set_imputed = set(imputed_ids_df["epirr_id"])

union = set_og | set_imputed
print(len(union), len(set_og), len(set_imputed))

# EpiRRs that have imputed files but no observed metadata row. key=str keeps sorting safe
# when an unparsed filename left a NaN (float) among the string IDs.
imputed_only = sorted(set_imputed - set_og, key=str)
print(len(imputed_only))
for item in imputed_only:
    print(item)

In [None]:
# Right join: keep every imputed file, attaching EpiRR-level metadata where available.
merged_imputed_df = pd.merge(
    df,
    imputed_ids_df,
    how="right",
    left_index=True,
    right_on="epirr_id",
)
print(merged_imputed_df.shape)

In [None]:
# Replace missing values with "" so columns do not end up as float (NaN) types.
merged_imputed_df = merged_imputed_df.fillna("")

In [None]:
# merged_imputed_df.to_csv(Path.home() / "downloads" / "temp"/ "hg38_epiatlas_imputed_pval_chip_2024-02.csv")

In [None]:
# new_dict = merged_imputed_df.to_dict(orient="records")
# meta_dict = {dset["md5sum"]: dset for dset in new_dict}
# new_metadata = Metadata.from_dict(meta_dict)
# new_metadata.save(
#     Path.home() / "downloads" / "temp" / "hg38_epiatlas_imputed_pval_chip_2024-02.json"
# )

### Sanity check: imputed vs obs pval datasets are similar

In [None]:
def _read_md5_set(list_path: Path) -> set:
    """Return the leading token (presumably the md5sum) of each file listed in ``list_path``.

    The file holds one path per line; each line's basename is split on "_" and its
    first piece is kept.
    """
    with open(list_path, "r", encoding="utf8") as list_file:
        return {line.split("/")[-1].split("_")[0] for line in list_file.read().splitlines()}


base_paper_dir = Path.home() / "Projects/epiclass/output/paper"
base_metadata_dir = base_paper_dir / "data/metadata"
path_metadata_observed = (
    base_metadata_dir / "hg38_2023-epiatlas-dfreeze_v2.1_w_encode_noncore_2.json"
)
obs_metadata = Metadata(path_metadata_observed)

path_obs_md5 = base_paper_dir / "data/hdf5_list" / "100kb_all_none_pval_chip-seq.list"
obs_md5 = _read_md5_set(path_obs_md5)

path_metadata_imputed = base_metadata_dir / "hg38_epiatlas_imputed_pval_chip_2024-02.json"
imp_metadata = Metadata(path_metadata_imputed)

path_imputed_md5 = (
    base_paper_dir / "data/hdf5_list" / "100kb_all_none_chip-seq_imputed.list"
)
imp_md5 = _read_md5_set(path_imputed_md5)

In [None]:
# Keep only datasets whose md5 appears in the matching hdf5 file list.
# list(...) snapshots the md5s so entries can be deleted while looping.
for metadata, md5s_to_keep in ((obs_metadata, obs_md5), (imp_metadata, imp_md5)):
    for md5 in list(metadata.md5s):
        if md5 not in md5s_to_keep:
            del metadata[md5]

In [None]:
# Show assay counts per source and keep each metadata as a DataFrame for comparison.
meta_dfs = {}
for name, metadata in (("observed", obs_metadata), ("imputed", imp_metadata)):
    print(name)
    metadata.display_labels(ASSAY)
    meta_dfs[name] = metadata.to_df()

In [None]:
obs_df = meta_dfs["observed"]
imp_df = meta_dfs["imputed"]

obs_df_cell_type = obs_df[CELL_TYPE].value_counts(dropna=False)
relative_obs_df_cell_type = obs_df_cell_type / obs_df_cell_type.sum()

imp_df_cell_type = imp_df[CELL_TYPE].value_counts(dropna=False)
relative_imp_df_cell_type = imp_df_cell_type / imp_df_cell_type.sum()

# Compare the 20 most frequent observed cell types. A cell type absent from the imputed
# set is reported as 0 instead of raising KeyError.
for cell_type, perc in sorted(
    relative_obs_df_cell_type.items(), key=lambda x: x[1], reverse=True
)[0:20]:
    imp_count = imp_df_cell_type.get(cell_type, 0)
    imp_perc = relative_imp_df_cell_type.get(cell_type, 0.0)
    print(cell_type)
    print(f"obs: {obs_df_cell_type[cell_type]}, imp: {imp_count}")
    print(f"obs: {perc:.2%}, imp: {imp_perc:.2%}")
    diff = perc - imp_perc
    print(f"diff: {diff:.2%}")
    print()

## New cell type

In [None]:
from epi_ml.utils.notebooks.paper.paper_utilities import MetadataHandler

paper_dir = Path.home() / "Projects/epiclass/output/paper"
metadata_dir = paper_dir / "data/metadata"

# v2 metadata, with its index restored as a regular column.
metadata_handler = MetadataHandler(paper_dir)
metadata_v2_df = metadata_handler.load_metadata_df("v2").reset_index(drop=False)
print(metadata_v2_df.shape)

# New cell type labels per EpiRR; read without a header row (column names supplied here).
new_cell_type_path = metadata_dir / "Martin_class_v3_041224.tsv"
new_cell_type_df = pd.read_csv(
    new_cell_type_path,
    sep="\t",
    names=["epirr_id_without_version", "cell_type_martin", "cell_type_PE"],
)
print(new_cell_type_df.shape)

# Left join: every metadata row is kept; rows multiply if an EpiRR repeats in the TSV.
merged_metadata = pd.merge(
    metadata_v2_df, new_cell_type_df, how="left", on="epirr_id_without_version"
)
print(merged_metadata.shape)

In [None]:
# Rebuild a Metadata object keyed by md5sum; a later record with the same md5sum
# overwrites an earlier one.
new_meta = {
    record["md5sum"]: record for record in merged_metadata.to_dict(orient="records")
}
new_meta_dict = Metadata.from_dict(new_meta)
# new_meta_dict.save(
#     metadata_dir / "hg38_2023-epiatlas-dfreeze-pospurge-nodup_filterCtl_newCT.json"
# )

In [None]:
# for col in ["md5sum", "uuid", "epirr_id_without_version"]:
#     print(col, merged_metadata[col].nunique())

# merged_metadata = merged_metadata.drop_duplicates("uuid")
# print(merged_metadata.shape)

In [None]:
# for pivot_col in ["cell_type_martin", "cell_grouping_PE"]:
#     print(pivot_col)
#     pair_count_df = merged_metadata.groupby([pivot_col, ASSAY]).agg({"uuid": "count"}).reset_index()
#     assay_count_df = pair_count_df[pair_count_df["uuid"] >= 10].groupby(pivot_col).agg({ASSAY: "count"}).reset_index().sort_values(ASSAY, ascending=False)
#     print(assay_count_df.reset_index(drop=True))