In [None]:
"""Notebook to work on proper way to merge a lot of already augmented output files."""
# pylint: disable=line-too-long, redefined-outer-name, import-error, unused-import, pointless-statement, unreachable,unnecessary-lambda

In [None]:
from __future__ import annotations

import collections
import functools
import os
import re
import subprocess
import sys
from pathlib import Path
from typing import Dict, List

import numpy as np
import pandas as pd
from fuzzywuzzy import process
from IPython.display import display

from epi_ml.core.metadata import Metadata
from epi_ml.utils.classification_merging_utils import merge_dataframes, remove_pred_vector
from epi_ml.utils.ssh_utils import createSCPClient, createSSHClient, run_commands_via_ssh

In [None]:
# Directory the notebook kernel runs from; sibling helper scripts are located from its parent.
CURRENT_DIR = Path.cwd()

## Collect relevant files

In [None]:
gen_base_dir = (Path.home() / "mounts/narval-mount/logs-dfreeze-2.1").resolve()

In [None]:
valid_pred_files = []

In [None]:
narval_base_dir = "~/logs-dfreeze-2.1"

In [None]:
to_merge_dir = (
    Path.home()
    / "projects/epilap/output/logs/epiatlas-dfreeze-v2.1/merged_results/epiatlas"
)
# to_merge_files = to_merge_dir / "valid_pred_files_all.list"
to_merge_files = to_merge_dir / "valid_pred_files_non_augmented_all.list"

In [None]:
# Build (or reload) the list of prediction files to merge. Listing Narval over SSH is slow,
# so the listing is cached in `to_merge_files` and reused whenever that file exists.
if not to_merge_files.exists():
    # The -name patterns are quoted so the remote shell passes them to `find` verbatim.
    # Unquoted, `*test*.csv` would first be glob-expanded against the directory `cd` just
    # entered, silently changing (or breaking) the search whenever a file there matches.
    cmd1 = f"cd {narval_base_dir} && find . -mindepth 3 -maxdepth 5 -type f -name 'full-10fold-validation_prediction.csv'"
    cmd2 = f"cd {narval_base_dir} && find . -mindepth 5 -maxdepth 6 -type f -name '*test*.csv'"
    cmd_results = run_commands_via_ssh(
        cmds=[cmd1, cmd2],
        username="rabyj",
        hostname="narval.computecanada.ca",
        port=22,
    )
    # Join the outputs of both commands: one relative path per output line.
    valid_pred_files = [
        line for cmd_result in cmd_results for line in cmd_result.splitlines()
    ]

    with open(to_merge_files, "w", encoding="utf8") as f:
        f.write("\n".join(valid_pred_files))
else:
    with open(to_merge_files, "r", encoding="utf8") as f:
        valid_pred_files = f.read().splitlines()

In [None]:
# for file in valid_pred_files:
#     print(file)

In [None]:
# Destination of the merged CSV. NOTE(review): the file name looks like a placeholder — rename before a real run.
OUTPUT_PATH = Path.home().joinpath("downloads", "merged_pred_results_blblbllblb.csv")

In [None]:
# Path fragments marking result folders that must not be merged.
invalid_dirs = [
    "noFC",
    "raw",
    "pval",
    "l1",
    "harmonized_donor_sex_1l_3000n/no-mixed",
    "groups_second_level_name_1l_3000n/w-mix",
    "w-unknown",
    "10fold-2",
    "10fold-oversampling2",
    "10fold-oversample2",
    "random_1l_3000n/10fold-11c",
]
# Convert to Path first (normalizes e.g. a leading "./"), then keep paths matching no fragment.
valid_pred_files = [
    candidate
    for candidate in map(Path, valid_pred_files)
    if not any(fragment in str(candidate) for fragment in invalid_dirs)
]

In [None]:
categories = collections.defaultdict(list)
for file in valid_pred_files:
    categories[file.parent.parent].append(file.parent.name)

In [None]:
categories

In [None]:
oversampling_dirs = []
for folder, result_list in categories.items():
    if any(
        result in ["10fold-oversampling", "10fold-oversample"] for result in result_list
    ):
        oversampling_dirs.append(folder)

In [None]:
oversampling_dirs

In [None]:
# Remove plain "10fold" results from experiment folders where an oversampling run also exists.
for pred_file in valid_pred_files.copy():
    fold_dir = pred_file.parent
    is_plain_10fold = fold_dir.name == "10fold"

    # Sanity check: a folder name equals "10fold" exactly when its stem ends with "10fold".
    if is_plain_10fold != fold_dir.stem.endswith("10fold"):
        raise ValueError(f"wat: {str(pred_file)}")

    if is_plain_10fold and fold_dir.parent in oversampling_dirs:
        print(f"Removing {pred_file}")
        valid_pred_files.remove(pred_file)

In [None]:
print(len(valid_pred_files))
# for file in valid_pred_files:
#     print(file)

In [None]:
results_base_dir = to_merge_dir / "input_non_augmented"

In [None]:
with open(
    results_base_dir.parent / "valid_pred_files_non_augmented_filtered.list",
    "w",
    encoding="utf8",
) as f:
    f.write("\n".join([str(path) for path in valid_pred_files]))

### Extra reprocessing to add split_nb to predictions

In between, all splits were downloaded using the commands in `valid_pred_files_non_augmented_filtered_splits.sh`.

In [None]:
# Merge the per-split validation predictions of each fold directory into a single CSV.
# results_base_dir is re-declared so this cell can run on its own once the splits are downloaded.
results_base_dir = to_merge_dir / "input_non_augmented"

script = CURRENT_DIR.parent / "merge_validation_predictions.py"

# Each split file is assumed to live at <folds_dir>/<split>/validation_prediction.csv.
# Collect each fold dir once (order preserved) before running anything, rather than
# walking the tree lazily while the merge script writes new files into it.
folds_dirs = dict.fromkeys(
    split_file.parent.parent
    for split_file in results_base_dir.rglob("validation_prediction.csv")
)
for folds_dir in folds_dirs:
    # A CSV directly inside the fold dir is treated as "already merged".
    if any(folds_dir.glob("*.csv")):
        continue

    # Use the kernel's own interpreter so the script runs in the same environment
    # (the bare "python" on PATH may be a different interpreter without epi_ml).
    subprocess.check_output(
        args=[sys.executable, str(script), str(folds_dir), "-n", "10"]
    )

In [None]:
valid_pred_files = list(results_base_dir.rglob("complete*unknown.csv"))
valid_pred_files += list(results_base_dir.rglob("full*prediction.csv"))

In [None]:
valid_pred_files

In [None]:
print(len(valid_pred_files))

## Order paths in desired order

In [None]:
def parse_instructions(instructions: str) -> Dict[str, int]:
    """
    Parse ordering instructions into a mapping of key -> rank.

    Each relevant line has the form ``#<rank> <key> [free text]``, e.g.
    ``#3 harmonized_donor_sex (trinary)``. The key is the first run of letters and
    underscores that follows the rank after a space or ``*``; any trailing text is
    ignored. Lines not starting with ``#<digits>`` are skipped. Several keys may share
    a rank; if a key is listed twice, its last rank wins.

    Args:
        instructions (str): Multi-line instruction text.

    Returns:
        Dict[str, int]: Mapping from each key to its rank.

    Raises:
        ValueError: If a ``#<rank>`` line does not contain a key.
    """
    order_dict: Dict[str, int] = {}
    for line in instructions.strip().splitlines():
        rank_match = re.match(r"#(\d+)", line)
        if rank_match is None:
            continue

        key_match = re.search(r"[* ]([a-zA-Z_]+)", line[rank_match.end() :])
        if key_match is None:
            # Previously an opaque AttributeError on `None.group`.
            raise ValueError(f"No key found in instruction line: {line!r}")

        order_dict[key_match.group(1)] = int(rank_match.group(1))
    return order_dict


def fuzzy_sort_paths(paths: List[Path], order_dict: Dict[str, int]) -> List[Path]:
    """
    Sort paths by the rank of the ``order_dict`` key that best fuzzy-matches each path.

    The string matched for a path is its three nearest parent folder names, joined
    outermost-first with "/" (``a/b/c`` for ``a/b/c/file.csv``). The sort is stable,
    so paths with the same rank keep their input order.

    Args:
        paths (List[Path]): Paths to sort.
        order_dict (Dict[str, int]): Mapping of key -> rank, e.g. from ``parse_instructions``.

    Returns:
        List[Path]: New list of the paths sorted by ascending rank. With an empty
        ``order_dict`` the paths are returned in their original order.
    """
    if not order_dict:
        # process.extractOne returns None for an empty choice list, which would crash
        # on unpacking below; with no keys there is nothing to rank by anyway.
        return list(paths)

    keys = list(order_dict)

    def get_order(path: Path) -> int:
        parent_names = [parent.name for parent in path.parents]
        key = "/".join(parent_names[0:3][::-1])
        best_match, _ = process.extractOne(key, keys)
        return order_dict[best_match]

    return sorted(paths, key=get_order)

In [None]:
# Desired ordering of the result files: one "#<rank> <key> [notes]" rule per line.
# parse_instructions() keeps the rank and the first letters/underscores token after it as the
# key; trailing notes such as "(trinary)" are ignored. Ranks may be shared (two keys use #6).
instructions = """
#1 assay_epiclass
#2 assay_epiclass_encode
#9 harmonized_biomaterial_type
#3 harmonized_donor_sex (trinary)
#6 harmonized_sample_disease_high
#6 harmonized_sample_cancer_high
#10 paired_end
#5 groups_second_level_name, no “mixed.mixed”
#4 harmonized_sample_ontology_intermediate
#12 random_16c
#8 project
#11 track_type
#7 harmonized_donor_life_stage
#13 complete_no_valid_oversample/predictions
"""

In [None]:
# Rank each prediction file by the instruction key its parent folder names best fuzzy-match.
order_dict = parse_instructions(instructions)
sorted_paths = fuzzy_sort_paths(valid_pred_files, order_dict)

In [None]:
# for elem in sorted(order_dict.items(), key=lambda x: x[1]):
#     print(elem)

# for i, path in enumerate(sorted_paths):
#     print(i, str(path).split("/")[-4:-1])

In [None]:
def create_filename(path: Path) -> str:
    """
    Build a short, underscore-joined identifier from a result file's parent folders.

    The nearest three parent folder names are used (four when the path contains
    "predictions"), ordered outermost-first. Then:
      - every folder name containing "encode" is dropped and a leading "encode" is added;
      - every folder name containing "hg38_100kb_all_none" is dropped.

    Args:
        path (Path): Path to a prediction CSV.

    Returns:
        str: Identifier used as the dataframe key, e.g.
        ``"encode_assay_epiclass_1l_3000n_10fold-oversampling"``.
    """
    nb_parents = 4 if "predictions" in str(path) else 3
    folder_names = [parent.name for parent in list(path.parents)[:nb_parents]][::-1]

    # Build new lists instead of removing items from the list being iterated over:
    # in-loop removal skipped the folder right after a dropped "hg38_100kb_all_none"
    # entry, so e.g. an adjacent "*encode*" folder was never normalized.
    encode_prefix = ["encode" for name in folder_names if "encode" in name]
    kept_names = [
        name
        for name in folder_names
        if "encode" not in name and "hg38_100kb_all_none" not in name
    ]
    return "_".join(encode_prefix + kept_names)

In [None]:
# Preview the identifier each sorted file will be stored under.
for sorted_path in sorted_paths:
    print(create_filename(sorted_path))

In [None]:
# scp_client = None

# new_sorted_paths = []
# for input_file in sorted_paths:
#     input_file = Path(input_file)
# new_filename = f"{create_filename(input_file)}.csv"
#     if not (results_base_dir / new_filename).is_file():
#         files = [f"{narval_base_dir}/{input_file}", f"{results_base_dir}/{new_filename}"]
#         try:
#             scp_client.get(*files)
#         except AttributeError:
#             print("Creating new scp client")
#             scp_client = createSCPClient(
#                 createSSHClient("narval.computecanada.ca", 22, "rabyj")
#             )
#             scp_client.get(*files)
#     new_path = results_base_dir / new_filename
#     new_sorted_paths.append(new_path)

In [None]:
# Show the input files in their final merge order.
for ordered_file in sorted_paths:
    print(ordered_file)

In [None]:
# Helper script run on each prediction file below (presumably adds metadata/coherence columns).
python_script = CURRENT_DIR.parent / "augment_predict_file.py"
metadata_file = Path.home().joinpath(
    "projects",
    "epilap",
    "input",
    "metadata",
    "dfreeze-v2",
    "hg38_2023-epiatlas-dfreeze_v2.1_w_encode_noncore_2.json",
)
# One row per metadata record, indexed by md5sum.
metadata_records = list(Metadata(metadata_file).datasets)
meta_df = pd.DataFrame.from_records(metadata_records, index="md5sum")

In [None]:
# Augment each sorted prediction file, reusing an existing "<stem>_augmented.csv" if present.
new_sorted_paths = []
for input_file in sorted_paths:
    new_file = input_file.with_name(f"{input_file.stem}_augmented.csv")

    if not new_file.is_file():
        print(f"Creating {new_file}")
        args = [
            # The kernel's own interpreter, so the script sees the same environment (epi_ml).
            sys.executable,
            str(python_script),
            str(input_file),
            str(metadata_file),
            "--compute-coherence",
        ]
        subprocess.check_output(args=args)

        if not new_file.is_file():
            raise FileNotFoundError(f"Did not create {new_file}.")

    new_sorted_paths.append(new_file)

### Merge files

In [None]:
# Load every augmented file into `dfs`, keyed by its short identifier.
dfs = {}
for input_file in new_sorted_paths:
    df_name = create_filename(input_file)
    try:
        df = pd.read_csv(input_file, index_col="md5sum", low_memory=False)
    except ValueError as err:
        # e.g. no "md5sum" column: report and skip the file rather than abort the whole merge.
        print(f"Error reading {input_file}: {err}")
        continue

    # dropna returns a new frame; the previous bare call discarded its result (a no-op).
    df = df.dropna(axis=1, how="all")
    if df_name in dfs:
        raise ValueError(
            f"Conflicting names from {input_file}: {df_name} file already exists."
        )

    dfs[df_name] = df

In [None]:
for name, df in list(dfs.items()):
    df = remove_pred_vector(df)
    dfs[name] = df

In [None]:
# Drop useless columns
for name, df in dfs.items():
    df.replace(to_replace=["--empty--", "", "NA", None], value=np.nan, inplace=True)
    df = df.dropna(axis=1, how="all")
    dfs[name] = df

In [None]:
# Make all different columns have unique relevant names, hardcoded 13 work only on output of augmented files with added coherence columns
# https://stackoverflow.com/questions/38101009/changing-multiple-column-names-but-not-all-of-them-pandas-python
# 13 without split_nb col, 14 with.
nb_diff_columns = 14
old_names = list(dfs.values())[0].columns[-nb_diff_columns:]
for cat, df in dfs.items():
    new_names = [old_name + f" {cat}" for old_name in old_names if name[-1] != "n"]
    df.rename(columns=dict(zip(old_names, new_names)), inplace=True)
    dfs[cat] = df
    # print(df.columns)

Merge the ENCODE and EpiATLAS dataframes first; the ENCODE metadata is not redundant with the EpiATLAS metadata.

In [None]:
# Merge the EpiATLAS and ENCODE assay_epiclass results first, stored as "partial_merge".
df_key1, df_key2, df_key3 = (
    "assay_epiclass_1l_3000n_11c_10fold-oversampling",
    "encode_assay_epiclass_1l_3000n_10fold-oversampling",
    "partial_merge",
)

partial_merge = merge_dataframes(dfs[df_key1], dfs[df_key2])
dfs[df_key3] = partial_merge

In [None]:
# for name in [df_key1, df_key2, df_key3]:
#     df = dfs[name]
#     print(name, df.shape)
#     # print(df.index.name)
#     display(df["assay_epiclass"].value_counts(dropna=False))

In [None]:
# Deliberate stop: inspect `partial_merge` before the next cell deletes its two source frames
# from `dfs`. Skip this cell to run the rest of the merge.
raise ValueError("stop here")

In [None]:
for df_name in [df_key1, df_key2]:
    try:
        del dfs[df_name]
    except KeyError:
        continue

In [None]:
for df_name, df in dfs.items():
    if any(df["assay_epiclass"].isnull()):
        print(f"assay_epiclass is null in {df_name}")

Merge all the remaining dataframes,

starting with the biggest ones first.

In [None]:
# Fold all remaining frames together, largest (most rows) first.
df_list = sorted(dfs.values(), key=len, reverse=True)
df_final = functools.reduce(merge_dataframes, df_list)

In [None]:
# for column in df_final.columns:
#     print(column)

In [None]:
df_final = df_final.merge(
    meta_df, left_index=True, right_index=True, how="inner", suffixes=("", "_delete")
)

In [None]:
# Remove duplicate metadata columns (those that end by _delete)
df_final = df_final.filter(regex=r"^(?:(?!_delete).)+$")

In [None]:
df_final.head()

In [None]:
# Re-arrange columns
all_columns = df_final.columns.tolist()

# Separate metadata and result columns
result_columns = [col for col in all_columns if col.rsplit(" ", 1)[0] in old_names]
meta_columns = [col for col in all_columns if col not in result_columns]

new_order = meta_columns + result_columns
df_final = df_final[new_order]

In [None]:
# Write the merged table (metadata columns first, then per-result columns).
df_final.to_csv(OUTPUT_PATH)

In [None]:
# Deliberate stop: the chrY step below re-reads the saved CSV and is presumably run separately.
raise ValueError("STOP HERE")

### Add ChrY/X coverage

In [None]:
# Show which merged CSV the chrY step below reads.
OUTPUT_PATH

In [None]:
df_final = pd.read_csv(OUTPUT_PATH, index_col="md5sum", low_memory=False)

chrY_path = (
    Path.home()
    / "Projects/epilap/output/logs/epiatlas-dfreeze-v2.1/chrY_coverage_results/chrY_coverage_zscores.csv"
)
df_chrY = pd.read_csv(chrY_path, index_col="filename")

In [None]:
print(df_final.shape, df_chrY.shape)

In [None]:
new_final = df_final.join(df_chrY, how="left")
assert new_final.shape == (
    df_final.shape[0],
    df_final.shape[1] + df_chrY.shape[1],
)  # same number as og samples, but more columns
new_final.to_csv(OUTPUT_PATH.parent / "merged_pred_results_all_2.1_chrY_zscores.csv")