In [None]:
"""Workbook to create supplementary prediction files destined for the paper.

Must include all data predictions used to create paper figures.
"""
# pylint: disable=import-error, redefined-outer-name, use-dict-literal, too-many-lines

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from __future__ import annotations

import sys
from pathlib import Path
from typing import Dict, List

import pandas as pd

from epi_ml.utils.notebooks.paper.paper_utilities import (
    ASSAY,
    CELL_TYPE,
    LIFE_STAGE,
    SEX,
    SplitResultsHandler,
)

In [None]:
# Root of all paper outputs, located under the current user's home directory.
base_dir = Path.home().joinpath("Projects", "epiclass", "output", "paper")

# Sub-folders of the paper root: input data and generated figures.
base_data_dir = base_dir.joinpath("data")
base_fig_dir = base_dir.joinpath("figures")

# `paper_dir` is an alias of the root; tables are written below it.
paper_dir = base_dir
table_dir = paper_dir.joinpath("tables")

In [None]:
# Single handler instance reused by the cells below to gather and concatenate split results.
split_results_handler = SplitResultsHandler()

### assay_epiclass + sample ontology for all 5 model types - 100kb_all_none

In [None]:
def prepare_df_for_save(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of `df` laid out for saving to CSV.

    The "True class" column is renamed "Expected class" and moved to the first
    position, and the "md5sum" column becomes the index.

    The input DataFrame is not modified, so a cell calling this function can be
    re-run on the same frame without failing on an already-removed column.

    Raises KeyError if the "True class" or "md5sum" column is missing.
    """
    # Work on a copy: pop/insert mutate the frame they are called on.
    prepared = df.copy()
    prepared.insert(0, "Expected class", prepared.pop("True class"))
    return prepared.set_index("md5sum")

In [None]:
# Input: training results for the hg38_100kb_all_none feature set.
data_dir_100kb = base_data_dir / "training_results" / "dfreeze_v2" / "hg38_100kb_all_none"

# Output folder for the tables of this section.
# exist_ok=True makes the creation idempotent and avoids a check-then-create race.
logdir = table_dir / "dfreeze_v2" / "100kb_all_none"
logdir.mkdir(parents=True, exist_ok=True)

# NOTE(review): `split_md5sums` is not used anywhere in this cell; kept in case
# another cell relies on it.
split_md5sums = []
for category in [ASSAY, CELL_TYPE]:
    # only_NN=False: presumably gathers results for every method, not just "NN" — TODO confirm
    all_split_dfs = split_results_handler.gather_split_results_across_methods(
        results_dir=data_dir_100kb,
        label_category=category,
        only_NN=False,
    )

    # Sanity check, same shape, same input files for each method.
    # Within each split, the "NN" results are the reference every method is compared to.
    for split_dict in all_split_dfs.values():
        ref_dict = split_dict["NN"]
        ref_md5sums = sorted(ref_dict.index.values.tolist())
        ref_shape = ref_dict.shape
        for method, df in split_dict.items():
            if ref_md5sums != sorted(df.index.values.tolist()):
                raise ValueError("MD5sums do not match")
            if ref_shape != df.shape:
                raise ValueError("Shapes do not match")

    all_split_dfs_concat = split_results_handler.concatenate_split_results(all_split_dfs)

    # Save to file
    # NOTE: the CSV write itself is currently commented out; this loop only
    # prepares each frame and builds its target filename.
    for method, df in all_split_dfs_concat.items():
        df = prepare_df_for_save(df)

        # Files are named after "MLP" rather than the internal "NN" key.
        if method == "NN":
            method = "MLP"

        filename = f"10fold_predictions_{category}_{method}.csv"
        # df.to_csv(logdir / filename, index=True, sep=",", float_format="%.4f")

### Other MLP results - 100kb_all_none

In [None]:
# Label categories whose predictions are collected in this cell.
categories = [
    "paired_end",
    "harmonized_sample_cancer_high",
    LIFE_STAGE,
    SEX,
    "harmonized_biomaterial_type",
    "project",
]

# Select 10-fold oversampling runs
# Reads from `data_dir_100kb`, which is defined in the previous code cell.
# NOTE(review): how `exclude_names` entries are matched is decided inside
# SplitResultsHandler.general_split_metrics, which is not visible here — confirm before editing.
all_split_dfs = split_results_handler.general_split_metrics(
    results_dir=data_dir_100kb,
    merge_assays=False,
    include_categories=categories,
    exclude_names=["reg", "no-mixed", "chip"],
    return_type="split_results",
    oversampled_only=True,
    verbose=False,
)
all_split_dfs_concat = split_results_handler.concatenate_split_results(all_split_dfs, concat_first_level=True)  # type: ignore

# Save to file
# NOTE: the CSV write itself is currently commented out; this loop only
# prepares each frame and builds its target filename. `logdir` (used by the
# commented-out write) is the folder set in the previous code cell.
for category, df in all_split_dfs_concat.items():
    df = prepare_df_for_save(df)

    filename = f"10fold_predictions_{category}_MLP.csv"
    # df.to_csv(logdir / filename, index=True, sep=",", float_format="%.4f")

### Results for other feature sets (MLP)

In [None]:
def verify_splits_identity(
    all_results: Dict[str, Dict[str, Dict[str, pd.DataFrame]]],
    task_names: List[str],
    verbose: bool | None = None,
    reference_feature_set: str = "hg38_100kb_all_none",
) -> None:
    """Verify that the splits are identical between feature sets for each task.

    For each task, every feature set is compared against the reference feature
    set on three points: the set of split names, and, for each split present in
    both, the DataFrame shape and the sorted index values (md5sums).
    Any difference is reported as a WARNING line on stderr; a mismatch never raises.

    all_results: {feature_set: {task_name: {split_name: results_dataframe}}}
    task_names: list of task names to verify
    verbose: print additional information
    reference_feature_set: key of `all_results` whose splits are the comparison baseline

    Raises KeyError if `reference_feature_set` is not in `all_results`, or if a
    feature set does not contain one of the requested tasks.
    """
    # Sanity check : MD5sums and shapes should match between reference and other feature sets, for each split
    for task_name in task_names:
        if verbose:
            print(f"Verifying task '{task_name}'")
        # Use the reference feature set's splits as the baseline for comparison
        reference_splits = all_results[reference_feature_set][task_name]

        # Create reference MD5sums and shapes for each split in the reference feature set
        reference_md5sums = {
            split_name: sorted(df.index.tolist())
            for split_name, df in reference_splits.items()
        }
        reference_shapes = {
            split_name: df.shape for split_name, df in reference_splits.items()
        }

        # Iterate over each feature set and compare its splits against the reference
        # (the reference itself is included and trivially matches).
        for feature_set_name, tasks_dict in all_results.items():
            if verbose:
                print(
                    f"Verifying feature set '{feature_set_name}' against reference feature set '{reference_feature_set}'"
                )
            splits = tasks_dict[task_name]

            # A split absent on either side is itself a mismatch: report it
            # instead of ignoring it (missing) or failing on a lookup (unexpected).
            missing_splits = sorted(set(reference_splits) - set(splits), key=str)
            unexpected_splits = sorted(set(splits) - set(reference_splits), key=str)
            if missing_splits or unexpected_splits:
                print(
                    f"WARNING: Split names mismatch in task '{task_name}', "
                    f"between reference feature set '{reference_feature_set}' and feature set '{feature_set_name}' "
                    f"(missing: {missing_splits}, unexpected: {unexpected_splits})",
                    file=sys.stderr,
                )

            for split_name, df in splits.items():
                if split_name not in reference_splits:
                    # Already reported above; nothing to compare against.
                    continue
                if reference_shapes[split_name] != df.shape:
                    print(
                        f"WARNING: Shape mismatch in task '{task_name}', split '{split_name}', "
                        f"between reference feature set '{reference_feature_set}' and feature set '{feature_set_name}'",
                        file=sys.stderr,
                    )
                if reference_md5sums[split_name] != sorted(df.index.tolist()):
                    print(
                        f"WARNING: MD5sums mismatch in task '{task_name}', split '{split_name}', "
                        f"between reference feature set '{reference_feature_set}' and feature set '{feature_set_name}'",
                        file=sys.stderr,
                    )

In [None]:
# Tasks collected for every feature set listed below.
categories = [ASSAY, CELL_TYPE]
# Feature sets to collect; "hg38_100kb_all_none" is the one used as reference
# by verify_splits_identity.
include_sets = [
    "hg38_10mb_all_none_1mb_coord",
    "hg38_100kb_random_n316_none",
    "hg38_1mb_all_none",
    "hg38_100kb_random_n3044_none",
    "hg38_100kb_all_none",
    "hg38_gene_regions_100kb_coord_n19864",
    "hg38_10kb_random_n30321_none",
    "hg38_regulatory_regions_n30321",
    "hg38_1kb_random_n30321_none",
    "hg38_cpg_topvar_200bp_10kb_coord_n30k",
    "hg38_10kb_all_none",
    "hg38_regulatory_regions_n303114",
    "hg38_1kb_random_n303114_none",
    "hg38_cpg_topvar_200bp_10kb_coord_n300k",
]
# NOTE(review): how these names are matched against run folders is decided by
# SplitResultsHandler.obtain_all_feature_set_data, not visible here — confirm before editing.
exclude_names = ["7c", "chip-seq-only", "27ct", "16ct"]

# Select 10-fold oversampling runs
# expected result shape: {feature_set: {task_name: {split_name: results_dataframe}}}
# `parent_folder` is the directory that contains the hg38_100kb_all_none folder
# (data_dir_100kb is defined in an earlier cell).
all_results: Dict[
    str, Dict[str, Dict[str, pd.DataFrame]]
] = split_results_handler.obtain_all_feature_set_data(
    return_type="split_results",
    parent_folder=data_dir_100kb.parent,
    merge_assays=False,
    include_categories=categories,
    include_sets=include_sets,
    exclude_names=exclude_names,
    verbose=False,
)  # type: ignore

In [None]:
# Re-key the assay task: results stored under "assay_epiclass_11c" are moved to
# the ASSAY key. Feature sets that do not have that key are left unchanged.
for feature_set_name in all_results:
    if "assay_epiclass_11c" in all_results[feature_set_name]:
        all_results[feature_set_name][ASSAY] = all_results[feature_set_name]["assay_epiclass_11c"]  # type: ignore
        del all_results[feature_set_name]["assay_epiclass_11c"]

In [None]:
# Check every collected feature set against the reference for each task in `categories`;
# mismatches are printed to stderr as warnings.
verify_splits_identity(all_results, categories)

In [None]:
# Output folder for the prediction tables of the non-reference feature sets.
logdir = table_dir.joinpath("dfreeze_v2", "other_feature_sets")
logdir.mkdir(parents=True, exist_ok=True)

for feature_set_name, tasks_dict in all_results.items():
    # "hg38_100kb_all_none" is not processed in this loop.
    if feature_set_name != "hg38_100kb_all_none":
        all_split_dfs_concat = split_results_handler.concatenate_split_results(
            tasks_dict, concat_first_level=True
        )
        # NOTE: the CSV write itself is currently commented out; only the
        # frame preparation and the filename construction are executed.
        for task_name, df in all_split_dfs_concat.items():
            df = prepare_df_for_save(df)

            filename = f"{feature_set_name}_10fold_predictions_{task_name}.csv"
            # df.to_csv(logdir / filename, index=True, sep=",", float_format="%.4f")

### Winsorized files and/or blacklist zeroed

In [None]:
# Tasks compared across the feature sets of this section.
categories = [ASSAY, CELL_TYPE, SEX, "harmonized_biomaterial_type"]

# Feature sets compared in this section.
include_sets = [
    "hg38_100kb_all_none",
    "hg38_100kb_all_none_0blklst",
    "hg38_100kb_all_none_0blklst_winsorized",
]

# Training results for this section live in a different folder than dfreeze_v2;
# stop right away when it is absent.
results_folder = base_data_dir.joinpath("training_results", "2023-01-epiatlas-freeze")
if not results_folder.exists():
    raise FileNotFoundError(f"Folder '{results_folder}' not found")

In [None]:
# Select 10-fold oversampling runs
# expected result shape: {feature_set: {task_name: {split_name: results_dataframe}}}
# NOTE: this rebinds `all_results`, replacing the value collected in the
# "other feature sets" section above. Unlike that earlier call, no
# `exclude_names` is passed and `oversampled_only=False` is set explicitly.
all_results: Dict[
    str, Dict[str, Dict[str, pd.DataFrame]]
] = split_results_handler.obtain_all_feature_set_data(
    return_type="split_results",
    parent_folder=results_folder,
    merge_assays=False,
    include_categories=categories,
    include_sets=include_sets,
    oversampled_only=False,
    verbose=False,
)  # type: ignore

# Last expression of the cell: displays the feature set names that were collected.
all_results.keys()

In [None]:
# Task names available for the reference feature set (iterating a dict yields its keys).
tasks_collected = list(all_results["hg38_100kb_all_none"])

In [None]:
# Check every collected feature set against the reference for each collected task.
# verbose=True prints the task and feature set being verified; mismatches go to stderr.
verify_splits_identity(all_results, tasks_collected, verbose=True)