In [None]:
"""Notebook to work on proper way to merge a lot of already augmented output files."""
# pylint: disable=line-too-long, redefined-outer-name, import-error

In [None]:
import functools
import re
from pathlib import Path
from typing import Dict, List

import numpy as np
import pandas as pd
from IPython.display import display

In [None]:
# Root of the dfreeze v2.1 logs, and the two result trees scanned below.
gen_base_dir = (Path.home() / "projects/epilap/output/logs/epiatlas-dfreeze-v2.1").resolve()
base_dir_1 = gen_base_dir / "hg38_100kb_all_none/general_results"
base_dir_2 = gen_base_dir / "hg38_100kb_all_none_w_encode_noncore"

# Every augmented 10-fold validation prediction file, base_dir_1 results first.
valid_pred_files = [
    pred_file
    for base_dir in (base_dir_1, base_dir_2)
    for pred_file in base_dir.rglob("full-10fold-validation_prediction_augmented-all.csv")
]

In [None]:
# Discard every file whose full path contains one of these substrings.
# NOTE(review): matching is on the whole path string, so short entries such as
# "l1" or "raw" can also match unrelated path components — confirm intended.
invalid_dirs = ["oversampling", "noFC", "raw", "pval", "l1", "11c", "w-mix"]
valid_pred_files = [
    pred_file
    for pred_file in valid_pred_files
    if not any(name in str(pred_file) for name in invalid_dirs)
]
print(len(valid_pred_files))

In [None]:
# List the run directory of each kept prediction file.
for pred_file in valid_pred_files:
    print(pred_file.parent)

In [None]:
OUTPUT_PATH = Path(gen_base_dir, "merged_pred_results.csv")  # merged table destination

In [None]:
# The merged column order follows the order of `valid_pred_files`.
# `sorted_paths` (rule-based fuzzy order) is created by the path-ordering cells
# near the end of the notebook, so on a fresh kernel it does not exist yet when
# this cell runs. Fall back to a deterministic lexical order instead of failing
# with NameError. NOTE: if `sorted_paths` was computed from a different file
# list, rerun the ordering cells before this one.
try:
    valid_pred_files = sorted_paths
except NameError:
    valid_pred_files = sorted(valid_pred_files)

In [None]:
def create_filename(path: Path) -> str:
    """Build the label used to name a prediction file's dataframe.

    The label is ``<category>`` or ``<category>_<extra>`` where:

    - ``<category>`` is the name of the grandparent directory (``path.parents[1]``);
    - ``<extra>`` is the text following ``10fold-`` in the path, up to the last
      ``/`` (empty for a plain ``10fold/`` directory), with ``"encode"`` appended
      when the path contains ``"encode"``.

    Args:
        path: Path to a prediction CSV stored under a ``10fold[-<extra>]`` directory.

    Returns:
        The label string.

    Raises:
        ValueError: If the path contains no ``10fold`` segment followed by ``/``.
    """
    category = path.parents[1].name
    # as_posix() so the "/" in the pattern also matches on Windows paths.
    path_str = path.as_posix()

    match = re.search(r"10fold-?(.+)?/", path_str)
    if match is None:
        raise ValueError(f"Cannot find a '10fold' directory in path: {path_str}")
    extra_info = match.group(1) or ""

    # NOTE(review): no separator is inserted, e.g. "oversampling" -> "oversamplingencode";
    # kept as-is so the merged column names do not change.
    if "encode" in path_str:
        extra_info += "encode"

    return f"{category}_{extra_info}" if extra_info else category

In [None]:
# Load every prediction file into `dfs`, keyed by its unique label.
dfs = {}
for input_file in valid_pred_files:
    input_file = Path(input_file)

    df_name = create_filename(input_file)
    # Detect label collisions before paying for the CSV read.
    if df_name in dfs:
        raise ValueError(
            f"Conflicting names from {input_file}: {df_name} file already exists."
        )

    df = pd.read_csv(input_file, index_col="md5sum", low_memory=False)
    df.name = df_name
    # Entirely empty columns are removed later, once placeholder values have
    # been normalised to NaN (see the "Drop useless columns" cell).

    dfs[df_name] = df

In [None]:
def print_cols(dfs: Dict):
    """Print the shape, then each column name, of the first dataframe in ``dfs``.

    Args:
        dfs: Mapping of labels to dataframes. Only the first value (insertion
            order) is inspected.

    When ``dfs`` is empty, a short notice is printed instead of raising.
    """
    if not dfs:
        print("No dataframes to show.")
        return

    a_df = next(iter(dfs.values()))
    print(a_df.shape)
    for column in a_df.columns:
        print(column)

In [None]:
# Remove detail of prediction probabilities: drop every column strictly between
# `col1` and `col2`, plus the "EpiRR" and "md5sum.1" columns.
col1 = "1rst/2nd prob ratio"
col2 = "files/epiRR"
for cat, df in dfs.items():
    column_names = df.columns
    if col1 in column_names and col2 in column_names:
        cut_pos_1 = column_names.get_loc(col1)
        cut_pos_2 = column_names.get_loc(col2)
        df = df.drop(columns=column_names[cut_pos_1 + 1 : cut_pos_2])
    else:
        print("df seems already reduced")

    # errors="ignore": drop whichever of the two is present, instead of raising
    # (and skipping both) when only one of them is missing.
    df = df.drop(columns=["EpiRR", "md5sum.1"], errors="ignore")

    dfs[cat] = df
    print(df.shape)

In [None]:
# Drop useless columns: turn placeholder values into NaN, then discard columns
# that contain no value at all.
placeholder_values = ["--empty--", "", "NA", None]
for cat, df in dfs.items():
    df.replace(to_replace=placeholder_values, value=np.nan, inplace=True)
    dfs[cat] = df = df.dropna(axis=1, how="all")

In [None]:
# Make all different columns have unique relevant names: append the dataframe
# label to each of the last `nb_diff_columns` columns, except names ending in "n".
# https://stackoverflow.com/questions/38101009/changing-multiple-column-names-but-not-all-of-them-pandas-python
nb_diff_columns = 13
for cat, df in dfs.items():
    trailing_columns = df.columns[-nb_diff_columns:]
    # Map each name individually so that skipping a name ending in "n" cannot
    # shift the pairing of the remaining old/new names (as zipping a filtered
    # list against the unfiltered one did).
    rename_map = {
        name: f"{name} {cat}"
        for name in trailing_columns
        if not str(name).endswith("n")
    }
    df = df.rename(columns=rename_map)
    df.name = cat
    dfs[cat] = df

In [None]:
# Merge all dataframes (outer join on md5sum, so every file's samples are kept)
def _merge_pair(left: pd.DataFrame, right: pd.DataFrame) -> pd.DataFrame:
    """Outer-merge two frames on ``md5sum``, keeping a single copy of shared columns.

    For a column present in both frames, ``left``'s value is kept. Where ``left``
    has NaN (e.g. samples present only in ``right``), the value is filled from
    ``right``. The temporary ``<col>_x`` copies coming from ``right`` are dropped.

    NOTE(review): assumes every column shared between files is sample metadata,
    i.e. identical across files for a given md5sum, and that md5sum is unique
    within each file — confirm.
    """
    merged = pd.merge(left, right, on="md5sum", how="outer", suffixes=(None, "_x"))
    shared = left.columns.intersection(right.columns).drop("md5sum", errors="ignore")
    for col in shared:
        merged[col] = merged[col].combine_first(merged[f"{col}_x"])
    return merged.drop(columns=[f"{col}_x" for col in shared])


df_list = list(dfs.values())
df_final = functools.reduce(_merge_pair, df_list)

In [None]:
# Remove duplicate metadata columns, i.e. the right-hand copies that the merge
# suffixed with "_x". Only names ENDING with "_x" are dropped; names that merely
# contain "_x" elsewhere are kept.
filtered_final = df_final.loc[:, [not str(col).endswith("_x") for col in df_final.columns]]

In [None]:
# Show how many columns the merged table has, then their names.
cols = filtered_final.columns.to_numpy()
display(cols.size, cols)

In [None]:
filtered_final.to_csv(path_or_buf=OUTPUT_PATH)  # index is written as the first column

## Extra: fuzzy ordering of the prediction files (defines `sorted_paths`, which the merge cells above use)

In [None]:
from fuzzywuzzy import process


def parse_instructions(instructions: str) -> Dict[str, int]:
    """
    Parse ordering rules written as ``#<order> <key> [free text]``, one per line.

    Surrounding whitespace on each line is ignored. A line that starts with ``#``
    followed by digits is a rule: the number is the order, and the key is the first
    identifier (letters, digits, underscores) that follows a space or ``*`` after
    the number. Any text after the key, e.g. ``(binary)``, is ignored. All other
    lines are skipped. A later rule for the same key overrides an earlier one.

    Args:
        instructions (str): Multi-line rule text.

    Returns:
        Dict[str, int]: Mapping of each key to its order.

    Raises:
        ValueError: If a ``#<order>`` line has no key after the number.
    """
    order_dict: Dict[str, int] = {}
    for raw_line in instructions.splitlines():
        line = raw_line.strip()
        order_match = re.match(r"#(\d+)", line)
        if order_match is None:
            continue
        # Digits are allowed in keys (e.g. "random_16c").
        key_match = re.search(r"[* ]([A-Za-z0-9_]+)", line[order_match.end() :])
        if key_match is None:
            raise ValueError(f"No key found after the order number in line: {raw_line!r}")
        order_dict[key_match.group(1)] = int(order_match.group(1))
    return order_dict


def fuzzy_sort_paths(
    paths: List[Path], order_dict: Dict[str, int], score_cutoff: int = 0
) -> List[Path]:
    """
    Sort paths by the order of the key that best fuzzy-matches their category.

    A path's category is the name of its grandparent directory
    (``path.parents[1].name``). It is matched against the keys of ``order_dict``
    with ``fuzzywuzzy.process.extractOne``, and the path is ranked by the matched
    key's order. Paths without a match (empty ``order_dict``, or best score below
    ``score_cutoff``) are ranked last. The sort is stable, so ties keep their
    input order.

    Args:
        paths (List[Path]): The paths to sort.
        order_dict (Dict[str, int]): Mapping of keys to their order.
        score_cutoff (int): Minimum fuzzywuzzy score (0-100) a key must reach to
            count as a match. The default, 0, accepts any best match.

    Returns:
        List[Path]: A new list of the paths, sorted by their matched order.
    """
    unmatched_order = 9999
    choices = list(order_dict)

    def get_order(path: Path) -> int:
        if not choices:
            return unmatched_order
        result = process.extractOne(
            path.parents[1].name, choices, score_cutoff=score_cutoff
        )
        if result is None:  # no key reached score_cutoff
            return unmatched_order
        return order_dict[result[0]]

    return sorted(paths, key=get_order)

In [None]:
instructions = """
#1 assay_epiclass
#2 assay_epiclass_encode
#9 harmonized_biomaterial_type
#3 harmonized_donor_sex (binary)
#6 harmonized_sample_disease_high
#10 paired_end
#5 groups_second_level_name, no “mixed.mixed”
#4 harmonized_sample_ontology_intermediate
#12 random_16c
#8 project
#11 track_type
#7 harmonized_donor_life_stage
"""

In [None]:
# Turn the ranking text into {key: order}, then order the prediction files with it.
order_dict = parse_instructions(instructions)
sorted_paths = fuzzy_sort_paths(paths=valid_pred_files, order_dict=order_dict)

In [None]:
# Display the parsed ranking, lowest order first.
for key, rank in sorted(order_dict.items(), key=lambda item: item[1]):
    print((key, rank))

In [None]:
# Show the final ordering: position, category directory, run directory.
for position, pred_path in enumerate(sorted_paths):
    print(f"{position} {pred_path.parents[1].name} {pred_path.parent.name}")