In [None]:
"""Notebook to merge various augmented (with no metadata) prediction results into a single file. Mostly working with ENCODE metadata"""
# pylint: disable=line-too-long, redefined-outer-name, import-error, unused-import, pointless-statement, unreachable, unnecessary-lambda

In [None]:
from __future__ import annotations

import collections
import functools
import json
import os
import pprint
import re
import subprocess
import sys
from pathlib import Path

import numpy as np
import pandas as pd
from IPython.display import display

from epi_ml.core.metadata import Metadata
from epi_ml.utils.classification_merging_utils import (
    merge_dataframes,
    merge_two_columns,
    remove_pred_vector,
)

### Collect relevant files

In [None]:
# Root folder of the merged results; stop immediately if it is missing.
gen_base_dir = Path.home().joinpath(
    "projects/epilap/output/logs/epiatlas-dfreeze-v2.1/merged_results"
).resolve()
if not gen_base_dir.exists():
    raise ValueError(f"Path {gen_base_dir} does not exist.")

# Input folder for the ENCODE files used below.
prediction_results_dir = gen_base_dir.joinpath("encode", "input")

In [None]:
# Load the ENCODE metadata file and tabulate its datasets, indexed by md5sum.
metadata_path = prediction_results_dir / "hg38_ENCODE_total_final.json"
metadata = Metadata(metadata_path)
dataset_records = list(metadata.datasets)
meta_df = pd.DataFrame.from_records(dataset_records, index="md5sum")

In [None]:
# Overview: table size, then label counts (NaN included) for a few categories.
print(meta_df.shape)
for col in ("assay_epiclass", "sex", "donor_sex", "life_stage", "biosample_type"):
    print(col)
    if col not in meta_df.columns:
        print("Not in metadata")
        continue
    display(meta_df[col].value_counts(dropna=False))

In [None]:
# Show label counts for every metadata column whose name contains "sex".
sex_categories = [name for name in meta_df.columns.values if "sex" in name]
for category in sex_categories:
    print(category)
    display(meta_df[category].value_counts())

In [None]:
# Show label counts for every metadata column whose name contains "life" or "age"
# (plain substring match, so e.g. a name containing "stage" also qualifies).
for category in meta_df.columns.values:
    if not any(token in category for token in ("life", "age")):
        continue
    print(category)
    display(meta_df[category].value_counts())

### Correct some metadata

In [None]:
# Pairs of (lowercase column, capitalised twin) to combine when both exist.
# NOTE(review): the exact merge rule lives in merge_two_columns — confirm which
# column survives and how conflicts are resolved.
column_pairs = [
    ("biosample_type", "Biosample_type"),
    ("project_x", "Project_x"),
    ("project_y", "Project_y"),
]
for col1, col2 in column_pairs:
    if not {col1, col2}.issubset(meta_df.columns):
        print(f"Column {col1} or {col2} not in metadata.")
        continue
    meta_df = merge_two_columns(meta_df, col1, col2)

In [None]:
# Label counts (NaN included) for the three columns handled in the merge step.
for merged_col in ("biosample_type", "project_x", "project_y"):
    display(meta_df[merged_col].value_counts(dropna=False))

In [None]:
# Discard the project_y column; rebinding instead of mutating in place.
meta_df = meta_df.drop(columns=["project_y"])

### Add more metadata

In [None]:
# Read the additional ENCODE metadata JSON from the user's downloads folder.
extra_metadata_path = Path.home().joinpath("downloads", "encodeproject.json")
with extra_metadata_path.open(encoding="utf8") as f:
    extra_metadata_dict = json.load(f)

#### Explore extra metadata

In [None]:
# For each top-level key print its length, or the value itself when it has no length.
for k, v in extra_metadata_dict.items():
    try:
        summary = len(v)
    except TypeError:
        summary = v
    print(f"{k}: {summary}")

In [None]:
# Count how many "@graph" records carry each key, then list keys by frequency.
key_counter = collections.Counter(
    record_key
    for val in extra_metadata_dict["@graph"]
    for record_key in val.keys()
)

for key in key_counter.most_common():
    print(key)

In [None]:
# Print the keys and every key/value pair of the first "@graph" record only.
i = 0
for graph in extra_metadata_dict["@graph"][:1]:
    print(graph.keys())
    for k, value in graph.items():
        print(f"{k}: {value}\n")

In [None]:
# Count occurrences of each file identifier (third "/"-separated field of every
# "original_files" entry) across all "@graph" records.
files_all = collections.Counter(
    file_id.split("/")[2]
    for record in extra_metadata_dict["@graph"]
    for file_id in record["original_files"]
)

In [None]:
# The most frequent file identifier must appear exactly once, i.e. no duplicates.
top_count = files_all.most_common(1)[0][1]
assert top_count == 1

In [None]:
# Number of metadata index entries that also appear among the extra-metadata file identifiers.
len(set(meta_df.index.values).intersection(files_all.keys()))

In [None]:
# reduced_json = [elem for idx, elem in enumerate(extra_metadata_dict["@graph"]) if idx < 5000]
# for graph in reduced_json:
#     try:
#         del graph["revoked_files"]
#     except KeyError:
#         continue

# with open(extra_metadata_path.parent, "w", encoding="utf8") as f:
#     json.dump(reduced_json, f)

In [None]:
# Tally the sex label of the first replicate's biosample and the "cell_slims"
# entries of each "@graph" record.
sex_all = collections.Counter()
ct_all = collections.Counter()
for graph in extra_metadata_dict["@graph"]:
    donor_sex = graph["replicates"][0]["library"]["biosample"]["sex"]
    cell_slims = graph["biosample_ontology"]["cell_slims"]

    sex_all[donor_sex] += 1
    ct_all.update(cell_slims)

In [None]:
# Show both tallies, most frequent label first.
for tally in (sex_all, ct_all):
    display(tally.most_common())

#### Get new metadata values and integrate

In [None]:
# Build one record per file identifier carrying the donor sex and a cancer flag.
# - donor sex: taken from the first replicate's biosample
# - cancer flag: "cancer" when "cancer cell" is listed in the record's cell_slims,
#   otherwise "non-cancer"
new_extra_metadata = []
for graph in extra_metadata_dict["@graph"]:
    donor_sex = graph["replicates"][0]["library"]["biosample"]["sex"]
    cell_slims = graph["biosample_ontology"]["cell_slims"]
    cancer_status = "cancer" if "cancer cell" in cell_slims else "non-cancer"

    file_ids = [file_id.split("/")[2] for file_id in graph["original_files"]]
    new_extra_metadata.extend(
        [
            {"md5sum": file_id, "donor_sex": donor_sex, "cancer": cancer_status}
            for file_id in file_ids
        ]
    )

In [None]:
# Tabulate the per-file records, using each record's "md5sum" entry as the index.
new_extra_metadata_df = pd.DataFrame.from_records(new_extra_metadata, index="md5sum")

In [None]:
# Size of the extra-metadata table as (rows, columns).
print(new_extra_metadata_df.shape)
# Optional inspection, left disabled:
# display(new_extra_metadata_df.donor_life_stage.value_counts())
# display(new_extra_metadata_df.donor_sex.value_counts())

In [None]:
# Combine the existing metadata table with the newly built extra-metadata table.
# NOTE(review): join semantics (index alignment, handling of columns present in
# both tables such as donor_sex) are defined in merge_dataframes — confirm there.
meta_df = merge_dataframes(meta_df, new_extra_metadata_df)

In [None]:
# Keep only the rows that have a non-null assay_epiclass value.
has_assay_label = meta_df["assay_epiclass"].notna()
meta_df = meta_df[has_assay_label]

In [None]:
# Normalise missing information: both NaN and empty strings become "unknown".
meta_df = meta_df.fillna("unknown").replace(to_replace="", value="unknown")

In [None]:
# Table size, then label counts (NaN included) for the three columns of interest.
print(meta_df.shape)
for label_col in ("donor_sex", "assay_epiclass", "cancer"):
    display(meta_df[label_col].value_counts(dropna=False))

In [None]:
# display(meta_df["life_stage"].value_counts())
# sum(meta_df["life_stage"].value_counts())

In [None]:
# display(meta_df["donor_life_stage"].value_counts())
# sum(meta_df["donor_life_stage"].value_counts())

### Update prediction files as needed

In [None]:
# Load every prediction CSV of the input folder into `dfs`, keyed by a shortened
# file stem. Files lacking a "Max pred" column are first passed through the
# augment_predict_file.py script, and the resulting "*_augmented.csv" is loaded instead.

# The helper script is expected one level above the current working directory.
augment_script = Path(os.path.abspath("")).parent / "augment_predict_file.py"

dfs = {}
# sorted(): glob order is filesystem-dependent; sorting keeps runs reproducible.
for pred_file in sorted(prediction_results_dir.glob("*.csv")):
    df = pd.read_csv(pred_file, sep=",", index_col="md5sum", dtype=str)
    df_name = pred_file.stem.replace("_prediction_100kb_all_none_augmented", "")

    # # Add true class if inexistent
    # if "True class" not in df.columns:
    #     print(f"Adding 'True class' to {df_name}")
    #     df.insert(0, "True class", "unknown")
    #     df.to_csv(pred_file, sep=",", index=True)

    # # Add 'Same?' if inexistent
    # if "Same?" not in df.columns:
    #     print(f"Adding 'Same?' to {df_name}")
    #     df.insert(2, "Same?", "False")
    #     df.to_csv(pred_file, sep=",", index=True)

    # Augment if not already done
    if "Max pred" not in df.columns:
        print(f"Augmenting {df_name}")
        # sys.executable: run the script with the same interpreter (and therefore
        # the same installed packages) as this kernel, rather than whichever
        # "python" happens to be first on PATH. Raises CalledProcessError on failure.
        subprocess.check_output(
            args=[
                sys.executable,
                str(augment_script),
                str(pred_file),
                str(metadata_path),
            ]
        )
        new_name = str(pred_file).replace(".csv", "_augmented.csv")
        # dtype=str keeps freshly augmented tables consistent with the tables
        # that were already augmented and loaded as strings above.
        df = pd.read_csv(new_name, sep=",", index_col=0, dtype=str)

    dfs[df_name] = df

In [None]:
# Alphabetical list of the loaded prediction table names.
sorted(dfs)

#### Add 'True class' values

In [None]:
# Maps each prediction table name to the metadata column that holds its true label.
# Entries sharing a metadata column are grouped; key order is unchanged.
true_class_dict = {
    **dict.fromkeys(
        (
            "predict_assay7_oversample_test",
            "predict_assay7_test",
            "predict_assay11_test",
            "predict_assay13_test",
        ),
        "assay_epiclass",
    ),
    "predict_biomat_test": "biosample_type",
    "predict_donorlife_oversample_test": "life_stage",
    "predict_project_oversample_test": "project_x",
    **dict.fromkeys(
        ("predict_sex2_test", "predict_sex3_oversample_test"),
        "donor_sex",
    ),
    **dict.fromkeys(
        (
            "predict_cancer_oversample_test",
            "predict_disease_oversample_test",
            "predict_disease_test",
        ),
        "cancer",
    ),
}

In [None]:
# Union of the index labels of every prediction table listed in true_class_dict.
# Sorted so the resulting order is reproducible: iteration order of a set of
# strings can differ between interpreter runs (hash randomisation), which would
# otherwise leak into the row order of anything selected with this list.
samples = sorted(
    set().union(*(dfs[df_name].index.values for df_name in true_class_dict))
)

In [None]:
# Restrict the metadata to the labels in `samples`, in that order.
# NOTE(review): label-based .loc with a list is expected to raise KeyError when
# any label is absent from meta_df — confirm this is the intended guard.
meta_df = meta_df.loc[samples]

In [None]:
# For each mapped prediction table: copy the true label from the metadata
# (aligned on the index) and flag case-insensitive agreement with the prediction.
for df_name, class_label in true_class_dict.items():
    df = dfs[df_name]
    df["True class"] = meta_df[class_label]
    try:
        true_labels = df["True class"].str.lower()
        predicted_labels = df["Predicted class"].str.lower()
        df["Same?"] = true_labels == predicted_labels
    except KeyError as err:
        # A required column is missing: report it and show what is available.
        print(err)
        print(df.columns.values)

### Merge dataframes

In [None]:
# Apply remove_pred_vector to every table; a table whose processing raises
# KeyError is reported and left as it was.
for df_name in list(dfs):
    try:
        dfs[df_name] = remove_pred_vector(dfs[df_name])
    except KeyError:
        print(f"Could not remove pred vector from {df_name}")

In [None]:
# Treat placeholder values as missing, then discard columns with no data at all.
placeholder_values = ["--empty--", "", "NA", None]
for name in dfs:
    cleaned = dfs[name].replace(to_replace=placeholder_values, value=np.nan)
    dfs[name] = cleaned.dropna(axis=1, how="all")

In [None]:
# for df_name, df in dfs.items():
#     print(df.columns.values, df_name)

In [None]:
# Suffix the result columns with the table name so they stay distinct after merging.
# The column names of the first table serve as the reference set of names to rename.
old_column_names = list(dfs.values())[0].columns.values
for df_name, df in dfs.items():
    n_columns = df.shape[1]
    if n_columns != 7:
        raise ValueError(f"Wrong number of columns in {df_name}. {df.columns.values}")
    renaming = {
        old_name: " ".join((old_name, df_name)) for old_name in old_column_names
    }
    df.rename(columns=renaming, inplace=True)
    # Tag the frame with its name.
    df.name = df_name
    dfs[df_name] = df

In [None]:
# Merge the metadata with every prediction table, tables taken in name order.
ordered_prediction_dfs = [dfs[df_name] for df_name in sorted(dfs)]
df_list = [meta_df, *ordered_prediction_dfs]
df_final = functools.reduce(merge_dataframes, df_list)

In [None]:
# Remove duplicate metadata columns: keep only columns whose name does not
# contain "_delete" anywhere (the pattern rejects the substring at any position).
no_delete_pattern = r"^(?:(?!_delete).)+$"
df_final = df_final.filter(regex=no_delete_pattern)

In [None]:
# Re-arrange columns: metadata columns first, result columns last.
all_columns = df_final.columns.tolist()

# A column is a result column when the part before its last space is one of the
# original result column names; everything else is metadata.
result_columns = []
meta_columns = []
for col in all_columns:
    base_name = col.rsplit(" ", 1)[0]
    if base_name in old_column_names:
        result_columns.append(col)
    else:
        meta_columns.append(col)

new_order = meta_columns + result_columns
df_final = df_final[new_order]

In [None]:
# Discard every column whose values are all "unknown".
all_unknown_columns = [
    column for column in df_final.columns if all(df_final[column] == "unknown")
]
df_final = df_final.drop(columns=all_unknown_columns)

In [None]:
# Write the merged table next to the input folder.
merged_results_path = (
    prediction_results_dir.parent / "encode_predictions_merged_results_V2.csv"
)
df_final.to_csv(merged_results_path)