In [None]:
"""Workbook to create supplementary prediction files destined for the paper.

Includes most data predictions used to create paper figures.
"""
# pylint: disable=import-error, redefined-outer-name, use-dict-literal, too-many-lines

In [2]:
%load_ext autoreload
%autoreload 2

## SETUP

In [3]:
from __future__ import annotations

import functools
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Set, Tuple

import numpy as np
import pandas as pd
from IPython.display import display
from sklearn.metrics import classification_report, confusion_matrix as sk_cm

from epi_ml.utils.classification_merging_utils import merge_dataframes
from epi_ml.utils.notebooks.paper.paper_utilities import (
    ASSAY,
    CELL_TYPE,
    LIFE_STAGE,
    SEX,
    MetadataHandler,
    SplitResultsHandler,
)

In [4]:
# Metadata category name used for the cancer-status prediction task.
CANCER = "harmonized_sample_cancer_high"

In [5]:
# Every input and output of this notebook lives under one project folder in the user's home.
base_dir = Path.home() / "Projects/epiclass/output/paper"
base_data_dir = base_dir / "data"
base_fig_dir = base_dir / "figures"
metadata_dir = base_data_dir / "metadata"
paper_dir = base_dir  # alias of base_dir
table_dir = paper_dir / "tables"  # root folder for the tables written by this notebook

In [6]:
# Helper objects imported from paper_utilities; reused by the sections below.
split_results_handler = SplitResultsHandler()
metadata_handler = MetadataHandler(paper_dir)

### Official metadata

In [7]:
# Build index -> EpiRR id and index -> uuid lookups from the v2 metadata, then
# delete the frame since only the two dictionaries are needed afterwards.
# NOTE(review): the names assume meta_df is indexed by md5sum -- confirm against
# MetadataHandler.load_metadata_df.
meta_df = metadata_handler.load_metadata_df(version="v2", merge_assays=False)
md5sum_to_epirr = meta_df["epirr_id_without_version"].to_dict()
md5sum_to_uuid = meta_df["uuid"].to_dict()
del meta_df

In [8]:
official_metadata_dir = base_data_dir / "metadata" / "official"

# Official IHEC harmonized metadata (v1.1 and v1.2), each indexed by the
# version-less EpiRR id so rows can be looked up by EpiRR.
metadata_v1_1_path = official_metadata_dir.joinpath(
    "IHEC_metadata_harmonization.v1.1.extended.csv"
)
metadata_v1_1 = pd.read_csv(metadata_v1_1_path, index_col=False).set_index(
    "epirr_id_without_version"
)

metadata_v1_2_path = official_metadata_dir.joinpath(
    "IHEC_metadata_harmonization.v1.2.extended.csv"
)
metadata_v1_2 = pd.read_csv(metadata_v1_2_path, index_col=False).set_index(
    "epirr_id_without_version"
)

## Collect experiment keys for all trained classifiers

In [None]:
def extract_experiment_info(line: str) -> Tuple[str, str] | None:
    """Extract split and experiment key from a line containing checkpoint information.

    Line should have format: .../splitX/EpiLaP/[exp_key]/checkpoints/...

    Args:
        line: One raw line of a training log.

    Returns:
        A ``(split, exp_key)`` tuple built from the path components located
        right before and right after the first usable "EpiLaP" component, or
        None when the line has no "EpiLaP" component with both neighbours.
    """
    if "EpiLaP" not in line:
        return None

    parts = line.strip().split("/")
    # Only positions with a component on each side are usable. This skips an
    # "EpiLaP" that starts or ends the path instead of raising IndexError.
    for i in range(1, len(parts) - 1):
        if parts[i] == "EpiLaP":
            return (parts[i - 1], parts[i + 1])
    return None


def process_log_file(file_path: Path) -> Set[Tuple[str, str]]:
    """Collect every (split, experiment key) pair mentioned in one log file.

    Any error raised while reading or parsing is printed rather than
    propagated; the pairs gathered before the error are still returned.
    """
    found = set()
    try:
        with open(file_path, "r", encoding="utf8") as log_handle:
            for log_line in log_handle:
                info = extract_experiment_info(log_line)
                if info:
                    found.add(info)
    except Exception as err:  # pylint: disable=broad-exception-caught
        print(f"Error processing {file_path}: {err}")

    return found


def collect_exp_keys(folder: Path) -> Dict[Path, Set[Tuple[str, str]]]:
    """Collect experiment keys from log files (.o files), recursively from a given folder.

    Only log files located below a direct child of ``folder`` whose name ends
    with "1l_3000n" are scanned.

    Args:
        folder: Directory to search.

    Returns:
        Mapping from the directory containing a log file (a ``Path``) to the set
        of ``(split, exp_key)`` pairs found in the logs of that directory.
        Directories whose logs yield no pair are absent from the mapping.
    """
    experiments_keys: Dict[Path, Set[Tuple[str, str]]] = defaultdict(set)

    for log_file in folder.glob("*1l_3000n/**/*.o"):
        experiment_info = process_log_file(log_file)
        if experiment_info:
            experiments_keys[log_file.parent].update(experiment_info)

    return experiments_keys

In [None]:
def format_exp_key_context(
    experiments_keys_dict: Dict[str, Set[Tuple[str, str]]]
) -> pd.DataFrame:
    """Format experiment keys context for saving to a DataFrame.

    Args:
        experiments_keys_dict: Mapping from an experiment folder (str or Path,
            converted with ``str``) to a set of ``(split, exp_key)`` pairs.

    Returns:
        One row per ``(folder, split, exp_key)`` with columns, in order:
        release, feature_set_name, metadata_category, experiment_specification,
        split, exp_key, comet-url, complete_experiment_context.
        An empty mapping gives an empty frame with those same columns. Path
        components missing from a short context are left as missing values.
    """
    context_cols = [
        "release",
        "feature_set_name",
        "metadata_category",
        "experiment_specification",
    ]
    key_cols = ["split", "exp_key", "comet-url", "complete_experiment_context"]

    records = [
        (str(exp_folder), split, exp_key)
        for exp_folder, exp_keys in experiments_keys_dict.items()
        for split, exp_key in exp_keys
    ]
    if not records:
        # Nothing to split: return the expected schema instead of failing later.
        return pd.DataFrame(columns=context_cols + key_cols)

    df = pd.DataFrame(records, columns=["exp_folder", "split", "exp_key"])

    # comet-ml experiment url
    df["comet-url"] = "https://www.comet.com/rabyj/epiclass/" + df["exp_key"].astype(str)

    # Remove useless part of paths. regex=False: the path is literal text and
    # must not be read as a pattern, whatever the pandas default is.
    to_remove_path = (
        str(Path.home() / "Projects/epiclass/output/paper/data/training_results") + "/"
    )
    df["complete_experiment_context"] = df["exp_folder"].str.replace(
        to_remove_path, "", regex=False
    )
    df = df.drop(columns="exp_folder")

    # Split path into named parts: three leading components plus the remainder.
    # reindex guarantees the four columns exist even when all paths are short.
    path_parts = (
        df["complete_experiment_context"]
        .str.split("/", n=3, expand=True)
        .reindex(columns=range(4))
    )

    def _replace_if_text(value, old: str, new: str):
        """Apply str.replace to strings only; pass missing values through."""
        return value.replace(old, new) if isinstance(value, str) else value

    df["release"] = path_parts[0]
    df["feature_set_name"] = path_parts[1]
    # Remove redundant info (all MLP exp are 1 hidden layer 3000 nodes)
    df["metadata_category"] = path_parts[2].map(
        lambda value: _replace_if_text(value, "_1l_3000n", "")
    )
    df["experiment_specification"] = path_parts[3].map(
        lambda value: _replace_if_text(value, "/", ",")
    )

    # Context columns first, then the identifying columns.
    return df[context_cols + key_cols]

In [None]:
# Gather the experiment keys of every training-results release into one table.
all_exp_keys_dfs = []
for folder in ["dfreeze_v2", "2023-01-epiatlas-freeze", "imputation"]:
    data_dir = base_data_dir / "training_results" / folder
    for subfolder in data_dir.glob("*"):
        if subfolder.is_file():
            continue
        exp_key_dict = collect_exp_keys(subfolder)
        if not exp_key_dict:
            # No matching log below this folder: there is nothing to format,
            # and formatting an empty mapping may fail when splitting paths.
            continue
        df = format_exp_key_context(exp_key_dict)
        all_exp_keys_dfs.append(df)

if not all_exp_keys_dfs:
    # pd.concat would otherwise fail with an unrelated "No objects to concatenate".
    raise FileNotFoundError(
        f"No experiment log found under '{base_data_dir / 'training_results'}'"
    )

exp_keys_df = pd.concat(all_exp_keys_dfs, ignore_index=True)

In [None]:
# for col in ["release", "feature_set_name", "metadata_category", "split"]:
#     display(exp_keys_df[col].value_counts(dropna=False))

# The tables folder may not exist yet on a fresh checkout.
table_dir.mkdir(parents=True, exist_ok=True)

# Rows come from set iteration, so their order changes between runs; sort them
# so the saved file is reproducible. exp_keys_df itself is left untouched.
exp_keys_df.sort_values(
    ["complete_experiment_context", "split", "exp_key"]
).to_csv(table_dir / "training_experiment_keys.csv", index=False)

## assay_epiclass + sample ontology for all 5 model types - 100kb_all_none

In [None]:
def prepare_df_for_save(
    df: pd.DataFrame, md5sum_to_epirr: Dict[str, str], md5sum_to_uuid: Dict[str, str]
) -> pd.DataFrame:
    """Prepare DataFrame for saving to CSV. Return a modified copy, with:
    - Index set to md5sum
    - Expected class as the first column
    - epirr_without_version column added
    - uuid column added
    - Sorted by epirr_without_version and uuid

    The input frame is not modified, so the function can safely be called
    again on the same frame (e.g. when a cell is re-run).

    Args:
        df: Prediction results with at least "md5sum" and "True class" columns.
        md5sum_to_epirr: Lookup from md5sum to EpiRR id (without version).
        md5sum_to_uuid: Lookup from md5sum to uuid.

    Returns:
        A new DataFrame indexed by md5sum.

    Raises:
        KeyError: If "True class" or "md5sum" is not a column of ``df``.
    """
    out = df.copy()
    # "True class" is renamed and moved to the front in a single step.
    out.insert(0, "Expected class", out.pop("True class"))
    out = out.set_index("md5sum")

    out["epirr_without_version"] = out.index.map(md5sum_to_epirr)
    out["uuid"] = out.index.map(md5sum_to_uuid)
    return out.sort_values(["epirr_without_version", "uuid"])

In [None]:
data_dir_100kb = base_data_dir / "training_results" / "dfreeze_v2" / "hg38_100kb_all_none"
logdir = table_dir / "dfreeze_v2" / "100kb_all_none"
# Single race-free call, same idiom as the other output folders of this notebook.
logdir.mkdir(parents=True, exist_ok=True)

split_md5sums = []  # not used in this cell; kept in case later code relies on the name
for category in [ASSAY, CELL_TYPE]:
    all_split_dfs = split_results_handler.gather_split_results_across_methods(
        results_dir=data_dir_100kb,
        label_category=category,
        only_NN=False,
    )

    # Sanity check, same shape, same input files for each method.
    # The "NN" results are the reference every other method is compared to.
    for split_dict in all_split_dfs.values():
        ref_df = split_dict["NN"]
        ref_md5sums = sorted(ref_df.index.values.tolist())
        ref_shape = ref_df.shape
        for method, df in split_dict.items():
            if ref_md5sums != sorted(df.index.values.tolist()):
                raise ValueError("MD5sums do not match")
            if ref_shape != df.shape:
                raise ValueError("Shapes do not match")

    all_split_dfs_concat: Dict = split_results_handler.concatenate_split_results(all_split_dfs)  # type: ignore

    # Save to file, one CSV per method
    for method, df in all_split_dfs_concat.items():
        df = prepare_df_for_save(df, md5sum_to_epirr, md5sum_to_uuid)

        # Files use "MLP" for the "NN" results; the loop variable is left intact.
        method_name = "MLP" if method == "NN" else method

        filename = f"10fold_predictions_{category}_{method_name}.csv"
        df.to_csv(logdir / filename, index=True, sep=",", float_format="%.4f")

## Other MLP results - 100kb_all_none

In [None]:
categories = [
    "paired_end",
    CANCER,  # same value as the literal used before; avoids duplicating the name
    LIFE_STAGE,
    SEX,
    "harmonized_biomaterial_type",
    "project",
]

# Select 10-fold oversampling runs
all_split_dfs = split_results_handler.general_split_metrics(
    results_dir=data_dir_100kb,
    merge_assays=False,
    include_categories=categories,
    exclude_names=["reg", "no-mixed", "chip"],
    return_type="split_results",
    oversampled_only=True,
    verbose=False,
)
all_split_dfs_concat: Dict = split_results_handler.concatenate_split_results(all_split_dfs, concat_first_level=True)  # type: ignore

# Save to file
for category, df in all_split_dfs_concat.items():
    df = prepare_df_for_save(df, md5sum_to_epirr, md5sum_to_uuid)
    if category in [LIFE_STAGE, SEX]:
        # For these categories the single "Expected class" column is replaced by
        # the labels of both official metadata versions, looked up by EpiRR id.
        # The lookup key does not depend on the version, so compute it once.
        idx = df.index.map(md5sum_to_epirr).values
        # Each insert goes to position 0, so v1.1 ends up before v1.2.
        for version, metadata in [("v1.2", metadata_v1_2), ("v1.1", metadata_v1_1)]:
            values = metadata.loc[idx, category].to_list()  # type: ignore
            df.insert(loc=0, column=f"Expected class {version}", value=values)

        df = df.drop(columns="Expected class")

    filename = f"10fold_predictions_{category}_MLP.csv"
    df.to_csv(logdir / filename, index=True, sep=",", float_format="%.4f")

## Results for other feature sets (MLP)

In [None]:
def verify_splits_identity(
    all_results: Dict[str, Dict[str, Dict[str, pd.DataFrame]]],
    task_names: List[str],
    verbose: bool | None = None,
) -> None:
    """Check that every feature set uses the same splits as the reference one.

    For each task, the shape and the sorted index of every split are compared
    with those of the "hg38_100kb_all_none" feature set. Differences are only
    reported through printed warnings; nothing is raised or returned.

    all_results: {feature_set: {task_name: {split_name: results_dataframe}}}
    task_names: list of task names to verify
    verbose: print additional information
    """
    reference_feature_set = "hg38_100kb_all_none"

    for task_name in task_names:
        if verbose:
            print(f"Verifying task '{task_name}'")

        # Baseline per split: (shape, sorted index values) of the reference results.
        baseline = {
            split_name: (split_df.shape, sorted(split_df.index.tolist()))
            for split_name, split_df in all_results[reference_feature_set][
                task_name
            ].items()
        }

        for feature_set_name, tasks_dict in all_results.items():
            if verbose:
                print(
                    f"Verifying feature set '{feature_set_name}' against reference feature set '{reference_feature_set}'"
                )
            for split_name, df in tasks_dict[task_name].items():
                expected_shape, expected_md5sums = baseline[split_name]
                shape_differs = expected_shape != df.shape
                md5sums_differ = expected_md5sums != sorted(df.index.tolist())
                if shape_differs:
                    print(
                        f"WARNING: Shape mismatch in task '{task_name}', split '{split_name}', "
                        f"between reference feature set '{reference_feature_set}' and feature set '{feature_set_name}'",
                    )
                if md5sums_differ:
                    print(
                        f"WARNING: MD5sums mismatch in task '{task_name}', split '{split_name}', "
                        f"between reference feature set '{reference_feature_set}' and feature set '{feature_set_name}'",
                    )

In [None]:
# Tasks and feature sets to collect for the feature-set comparison.
categories = [ASSAY, CELL_TYPE]
include_sets = [
    "hg38_10mb_all_none_1mb_coord",
    "hg38_100kb_random_n316_none",
    "hg38_1mb_all_none",
    "hg38_100kb_random_n3044_none",
    "hg38_100kb_all_none",
    "hg38_gene_regions_100kb_coord_n19864",
    "hg38_10kb_random_n30321_none",
    "hg38_regulatory_regions_n30321",
    "hg38_1kb_random_n30321_none",
    "hg38_cpg_topvar_200bp_10kb_coord_n30k",
    "hg38_10kb_all_none",
    "hg38_regulatory_regions_n303114",
    "hg38_1kb_random_n303114_none",
    "hg38_cpg_topvar_200bp_10kb_coord_n300k",
]
# Passed to the handler as `exclude_names`.
# NOTE(review): presumably substrings of run names to skip -- confirm in SplitResultsHandler.
exclude_names = ["7c", "chip-seq-only", "27ct", "16ct"]

# Select 10-fold oversampling runs
# expected result shape: {feature_set: {task_name: {split_name: results_dataframe}}}
# The search starts from the parent folder of the 100kb results (the "dfreeze_v2" folder).
all_results: Dict[
    str, Dict[str, Dict[str, pd.DataFrame]]
] = split_results_handler.obtain_all_feature_set_data(
    return_type="split_results",
    parent_folder=data_dir_100kb.parent,
    merge_assays=False,
    include_categories=categories,
    include_sets=include_sets,
    exclude_names=exclude_names,
    verbose=False,
)  # type: ignore

In [None]:
# Store the "assay_epiclass_11c" results under the generic ASSAY task name, so
# every feature set exposes the same task keys. Feature sets without that key
# are left unchanged.
for feature_set_name in all_results.keys():
    feature_set_tasks = all_results[feature_set_name]
    if "assay_epiclass_11c" in feature_set_tasks:
        feature_set_tasks[ASSAY] = feature_set_tasks["assay_epiclass_11c"]  # type: ignore
        del feature_set_tasks["assay_epiclass_11c"]

In [None]:
# Sanity check (prints warnings only): splits must match the reference feature set.
verify_splits_identity(all_results, categories)

In [None]:
# Output folder for the predictions of the non-reference feature sets.
logdir = table_dir / "dfreeze_v2" / "other_feature_sets"
logdir.mkdir(parents=True, exist_ok=True)

for feature_set_name, tasks_dict in all_results.items():
    # The reference feature set is skipped here.
    if feature_set_name == "hg38_100kb_all_none":
        continue
    all_split_dfs_concat: Dict = split_results_handler.concatenate_split_results(
        tasks_dict, concat_first_level=True
    )  # type: ignore
    # One CSV per (feature set, task).
    for task_name, df in all_split_dfs_concat.items():
        df = prepare_df_for_save(df, md5sum_to_epirr, md5sum_to_uuid)

        filename = f"{feature_set_name}_10fold_predictions_{task_name}.csv"
        df.to_csv(logdir / filename, index=True, sep=",", float_format="%.4f")

## Winsorized files and/or blacklist zeroed

In [None]:
# Tasks and feature sets compared in this section (raw, blacklist-zeroed,
# blacklist-zeroed + winsorized).
categories = [ASSAY, CELL_TYPE, SEX, "harmonized_biomaterial_type"]
include_sets = [
    "hg38_100kb_all_none",
    "hg38_100kb_all_none_0blklst",
    "hg38_100kb_all_none_0blklst_winsorized",
]

results_folder = base_data_dir / "training_results" / "2023-01-epiatlas-freeze"
if not results_folder.exists():
    raise FileNotFoundError(f"Folder '{results_folder}' not found")

logdir = table_dir / "2023-01-epiatlas-freeze"
# parents=True: the tables folder itself may not exist yet; exist_ok=True makes
# the call safe to re-run. Same idiom as the other output folders.
logdir.mkdir(parents=True, exist_ok=True)

In [None]:
# Select 10-fold oversampling runs
# expected result shape: {feature_set: {task_name: {split_name: results_dataframe}}}
# NOTE(review): `oversampled_only=False` is passed here although the comment above
# mentions oversampling runs -- confirm which runs are intended.
all_results: Dict[
    str, Dict[str, Dict[str, pd.DataFrame]]
] = split_results_handler.obtain_all_feature_set_data(
    return_type="split_results",
    parent_folder=results_folder,
    merge_assays=False,
    include_categories=categories,
    include_sets=include_sets,
    oversampled_only=False,
    verbose=False,
)  # type: ignore

# Show which feature sets were actually found.
display(all_results.keys())

In [None]:
# Verify the splits for every task collected for the reference feature set.
tasks_collected = list(all_results["hg38_100kb_all_none"].keys())
verify_splits_identity(all_results, tasks_collected, verbose=True)

In [None]:
# save concatenated result
# One CSV per (feature set, task), written to the current `logdir`.
for feature_set_name, tasks_dict in all_results.items():
    concatenated_dfs = split_results_handler.concatenate_split_results(
        tasks_dict, concat_first_level=True
    )
    for task_name, concatenated_df in concatenated_dfs.items():
        concatenated_df = prepare_df_for_save(concatenated_df, md5sum_to_epirr, md5sum_to_uuid)  # type: ignore
        filename = f"{feature_set_name}_10fold_predictions_{task_name}.csv"
        print(f"Saving {filename}")
        concatenated_df.to_csv(
            logdir / filename, index=True, sep=",", float_format="%.4f"
        )

### Evaluate input dataset discrepancy in assay_epiclass

In [None]:
# Reload the v2 metadata: the frame loaded in SETUP was deleted after building the lookups.
metadata_handler = MetadataHandler(paper_dir)
metadata_df = metadata_handler.load_metadata_df("v2", merge_assays=False)

In [None]:
# For each feature set, count the assay labels (from the metadata) of the
# files that appear in its assay predictions.
values_counts = {}
for feature_set_name, tasks_dict in all_results.items():
    concatenated_dfs = split_results_handler.concatenate_split_results(
        tasks_dict, concat_first_level=True
    )
    # NOTE(review): this assumes the concatenated frame is indexed by md5sum, while
    # prepare_df_for_save expects "md5sum" as a column of the same kind of frame --
    # confirm what concatenate_split_results returns.
    md5sums = concatenated_dfs[ASSAY].index.tolist()
    print(f"{feature_set_name}: {len(md5sums)}")

    metadata_subset = metadata_df[metadata_df.index.isin(md5sums)]
    values_counts[feature_set_name] = metadata_subset[ASSAY].value_counts()

In [None]:
# Per-assay difference in file counts between the blacklist-zeroed set and the
# reference set. fill_value=0 treats an assay missing from one set as a count
# of 0, instead of producing NaN for it.
display(
    values_counts["hg38_100kb_all_none_0blklst"]
    .sub(values_counts["hg38_100kb_all_none"], fill_value=0)
    .astype(int)
)

## ENCODE predictions

### Metadata cleanup/merging

In [None]:
# Folder holding the ENCODE metadata files read in this section.
encode_metadata_dir = base_data_dir / "metadata" / "encode"

In [None]:
full_metadata_path = encode_metadata_dir / "encode_metadata_2023-10-25.csv"
# The "md5sum" column is re-exposed under the name "filename" (placed as the
# last column) and the row index is reset to a plain 0..n-1 range.
full_metadata_df = (
    pd.read_csv(full_metadata_path)
    .assign(filename=lambda frame: frame["md5sum"])
    .drop(columns="md5sum")
    .reset_index(drop=True)
)
print(full_metadata_df.shape)

In [None]:
# This mapping file has no entries for ctcf/non-core files, so the resulting
# missing values have to be filled in when merging it with the full metadata.
encode_epiatlas_mapping_path = encode_metadata_dir / "ENCODE_IHEC_keys.tsv"
encode_epiatlas_mapping_df = pd.read_csv(encode_epiatlas_mapping_path, sep="\t")
print(encode_epiatlas_mapping_df.shape)

In [None]:
# Merge the ENCODE/EpiAtlas mapping once; the guard makes the cell safe to re-run.
if "is_EpiAtlas_EpiRR" not in full_metadata_df.columns:
    full_metadata_df = full_metadata_df.merge(
        encode_epiatlas_mapping_df, left_on="filename", right_on="ENC_ID", how="left"
    )
    full_metadata_df.replace(np.nan, "unknown", inplace=True)

# merge duplicate column names from previous merging when possible
# Iterate over a snapshot of the names, since columns are dropped/renamed on the way.
for col in list(full_metadata_df.columns):
    if not col.endswith("_x"):
        continue

    # Strip only the trailing suffix: str.replace("_x", ...) would also rewrite
    # any "_x" found in the middle of a column name.
    base_name = col[: -len("_x")]
    name1 = col
    name2 = f"{base_name}_y"

    if name2 not in full_metadata_df.columns:
        full_metadata_df.rename(columns={name1: base_name}, inplace=True)
        continue

    col1 = full_metadata_df[name1]
    col2 = full_metadata_df[name2]
    diff_mask = col1 != col2

    if not diff_mask.any():
        # Identical content: keep a single copy.
        full_metadata_df.drop(columns=name2, inplace=True)
        full_metadata_df.rename(columns={name1: base_name}, inplace=True)
    elif (col1[diff_mask] == "unknown").all():
        # The "_x" side only differs where it is unknown: "_y" is the more complete one.
        full_metadata_df.drop(columns=name1, inplace=True)
        full_metadata_df.rename(columns={name2: base_name}, inplace=True)
    elif (col2[diff_mask] == "unknown").all():
        # The "_y" side only differs where it is unknown: "_x" is the more complete one.
        full_metadata_df.drop(columns=name2, inplace=True)
        full_metadata_df.rename(columns={name1: base_name}, inplace=True)
    else:
        # Real conflict: both columns are kept as they are, but say so.
        print(f"Columns '{name1}' and '{name2}' have conflicting values; both kept.")

In [None]:
# for col in full_metadata_df.columns:
#     if "lab" in col.lower():
#         print(full_metadata_df[col].value_counts(dropna=False), "\n")
# Drop whichever of the two conflicting "lab" columns are present. With
# errors="ignore" a missing column no longer prevents dropping the other one.
full_metadata_df.drop(columns=["lab_x", "lab_y"], inplace=True, errors="ignore")

In [None]:
# Table read without a header row; its three columns are named here, the third
# one being the CELL_TYPE label.
curie_def_df = pd.read_csv(
    encode_metadata_dir / "EpiAtlas_list-curie_term_HSOI.tsv",
    sep="\t",
    names=["code", "term", CELL_TYPE],
)
encode_ontology_df = pd.read_csv(encode_metadata_dir / "encode_ontol+assay.tsv", sep="\t")
# Attach the CELL_TYPE label to each ENCODE entry through its biosample term id.
# Left join: entries without a matching code keep a missing CELL_TYPE.
partial_meta = encode_ontology_df.merge(
    curie_def_df, left_on="Biosample term id", right_on="code", how="left"
)
partial_meta.drop(columns=["code", "term"], inplace=True)

In [None]:
# Add the cell type information once. For column names present on both sides,
# the version coming from `partial_meta` is the one that is kept.
if CELL_TYPE not in full_metadata_df.columns:
    full_metadata_df = full_metadata_df.merge(
        partial_meta,
        left_on="filename",
        right_on="ENC_ID",
        suffixes=("_DROP", ""),
        how="left",
    )
    overlapping_cols = [
        name for name in full_metadata_df.columns if name.endswith("_DROP")
    ]
    full_metadata_df = full_metadata_df.drop(columns=overlapping_cols)

### Prediction merging

In [None]:
encode_predictions_dir = base_data_dir / "training_results" / "predictions" / "encode"
pred_dfs = {}
for folder in sorted(encode_predictions_dir.glob("*1l_3000n")):
    if not folder.is_dir():
        continue
    # match categories with dir names of format: [cat_name]_1l_3000n
    cat = folder.name.split("_1l_3000n")[0]

    # Sorted so the selected file does not depend on filesystem listing order.
    candidate_files = sorted(folder.rglob("complete_no_valid_oversample_*.csv"))
    if not candidate_files:
        raise FileNotFoundError(
            f"No 'complete_no_valid_oversample_*.csv' prediction file found in '{folder}'"
        )
    if len(candidate_files) > 1:
        print(
            f"WARNING: {len(candidate_files)} prediction files found for '{cat}', "
            f"using '{candidate_files[0]}'"
        )

    pred_file = candidate_files[0]
    encode_df = pd.read_csv(pred_file)
    pred_dfs[cat] = encode_df

In [None]:
def merged_all_encode_preds(
    pred_dfs: Dict[str, pd.DataFrame], full_metadata_df: pd.DataFrame
) -> pd.DataFrame:
    """Combine the ENCODE predictions of the five tasks with the ENCODE metadata.

    Each task frame gets its expected label from the metadata column associated
    with the task, and its task-specific columns are suffixed with the task
    name. The task frames are then merged together, and finally joined with
    the complete metadata on "filename".
    """
    # Metadata column holding the expected label of each prediction task.
    metadata_mapping = {
        ASSAY: ASSAY,
        CELL_TYPE: CELL_TYPE,
        CANCER: "cancer_status",
        SEX: "donor_sex",
        LIFE_STAGE: "life_stage",
    }

    # Columns at positions 1 .. same_col_len - 2 are the ones renamed per task.
    # NOTE(review): relies on the column layout of the prediction files -- confirm.
    same_col_len = 8

    new_dfs = {}
    for cat, raw_preds in pred_dfs.items():
        metadata_colname = metadata_mapping[cat]
        labelled = (
            raw_preds.copy()
            .drop(columns=["Same?"])
            .merge(
                full_metadata_df[["filename", metadata_colname]],
                left_on="md5sum",
                right_on="filename",
                how="inner",
            )
        )
        # The expected label comes from the metadata; the helper columns are
        # removed once it has been copied.
        labelled["True class"] = labelled[metadata_colname]
        labelled = labelled.rename(columns={"True class": "Expected class"}).drop(
            columns=["filename", metadata_colname]
        )

        task_cols = labelled.columns[1 : same_col_len - 1]
        labelled = labelled.rename(
            columns={old_name: f"{old_name} ({cat})" for old_name in task_cols}
        )
        new_dfs[cat] = labelled

    ordered_dfs = [
        new_dfs[cat] for cat in (ASSAY, CELL_TYPE, SEX, LIFE_STAGE, CANCER)
    ]
    full_merged_df = functools.reduce(merge_dataframes, ordered_dfs)
    full_merged_df = full_merged_df.reset_index(drop=True)

    # When the identifier column holds ENCODE ids, expose it as "filename" so it
    # can be used as the join key with the metadata.
    if "md5sum" in full_merged_df.columns and "ENC" in full_merged_df.loc[0, "md5sum"]:
        full_merged_df = full_merged_df.rename(columns={"md5sum": "filename"})

    full_merged_df = full_merged_df.merge(
        full_metadata_df,
        left_on="filename",
        right_on="filename",
        how="inner",
        suffixes=("", "_DROP"),
    )
    # Columns already present on the prediction side win over the metadata ones.
    duplicated_cols = [
        name for name in full_merged_df.columns if name.endswith("_DROP")
    ]
    return full_merged_df.drop(columns=duplicated_cols)

In [None]:
# One table with the predictions of all five tasks plus the ENCODE metadata.
merged_all_encode_preds_df = merged_all_encode_preds(pred_dfs, full_metadata_df)

In [None]:
# Write the merged ENCODE predictions, creating the output folder if needed.
output_dir = table_dir / "dfreeze_v2" / "encode"
output_dir.mkdir(parents=True, exist_ok=True)
merged_all_encode_preds_df.to_csv(
    output_dir / "encode_predictions_5task.csv", index=False
)

## ChIP-Atlas predictions

## ChIP-Seq_imputed_with_RNA-Seq_only predictions

In [None]:
output_dir = table_dir / "dfreeze_v2" / "epiatlas_imputed"
# The predictions are saved into this folder at the end of the section, so it
# must exist (same idiom as the ENCODE output folder).
output_dir.mkdir(parents=True, exist_ok=True)

Predictions come from the MLP classifier trained on the complete epiclass_11c dataset (with oversampling).  
Training details: comet-ml experiment id 0f8e5eb996114868a17057bebe64f87c

In [None]:
# Load the predictions made on the imputed ChIP-Seq files.
pred_folder = base_data_dir / "training_results" / "predictions" / "epiatlas_imputed"
pred_file = "complete_no_valid_oversample_test_prediction_100kb_all_none_ChIP-Seq_imputed_with_RNA-Seq_only.csv"
pred_df = pd.read_csv(pred_folder / pred_file)
print(pred_df.shape)

In [None]:
# The unnamed first CSV column holds the file names.
pred_df.rename(columns={"Unnamed: 0": "filename"}, inplace=True)

# filename of format 'impute_[ihec-id]_[expected-class]_[resolution]_[filter_in]_[filter_out].csv'
# The expected class is the third "_"-separated field of the file name, lower-cased.
# NOTE(review): this assumes the ihec-id itself contains no "_" -- confirm.
pred_df["True class"] = pred_df["filename"].str.split("_", expand=True)[2].str.lower()
pred_df.rename(columns={"True class": "Expected class"}, inplace=True)

In [None]:
# Insert a boolean "Same?" column right after "Predicted class":
# True where the prediction equals the expected class.
idx_pred_col = np.where(pred_df.columns == "Predicted class")[0][0]
pred_df.insert(
    loc=int(idx_pred_col + 1),
    column="Same?",
    value=pred_df["Expected class"] == pred_df["Predicted class"],
)

In [None]:
# Accuracy = share of rows where the prediction equals the expected class.
print(f"Accuracy: {pred_df['Same?'].sum() / pred_df.shape[0]:.2%}")

In [None]:
# "Max pred" is the row-wise maximum over `nb_classes` consecutive columns
# starting at position `non_pred_vector_cols`; it is inserted at that position.
# NOTE(review): assumes the first 4 columns are not part of the prediction vector
# and that the next 11 columns are the per-class scores -- confirm against the CSV.
non_pred_vector_cols = 4
nb_classes = 11
pred_df.insert(
    loc=non_pred_vector_cols,
    column="Max pred",
    value=pred_df.iloc[:, non_pred_vector_cols : non_pred_vector_cols + nb_classes].max(
        axis=1
    ),
)

In [None]:
# Save the annotated predictions (row index not written).
pred_df.to_csv(output_dir / "epiatlas_imputed_w_rna_only_predictions.csv", index=False)

## recount3

In [105]:
# New handler instance, so this section does not depend on the one created in SETUP.
split_results_handler = SplitResultsHandler()

In [106]:
recount3_folder = (
    base_data_dir
    / "training_results"
    / "predictions"
    / "recount3"
    / "hg38_100kb_all_none"
)
if not recount3_folder.exists():
    # Name the missing folder, like the other existence checks of this notebook.
    raise FileNotFoundError(f"Folder '{recount3_folder}' not found")

In [107]:
# Locate the recount3 prediction files of each task.
split_pred_files = {}
for cat in [ASSAY, SEX, LIFE_STAGE, CANCER]:
    # Sorted so the later concatenation order does not depend on filesystem listing order.
    pred_files = sorted(recount3_folder.rglob(f"{cat}*/**/recount3/complete_*.csv"))
    split_pred_files[cat] = pred_files

assert len(split_pred_files) == 4
# The length check above holds even if nothing was found, so also require at
# least one file per task.
tasks_without_files = [cat for cat, files in split_pred_files.items() if not files]
assert not tasks_without_files, f"No prediction file found for: {tasks_without_files}"

In [108]:
# One frame per task: all prediction files of the task, stacked with a fresh index.
pred_dfs = {
    cat: pd.concat(
        [pd.read_csv(pred_file, low_memory=False) for pred_file in pred_files],
        ignore_index=True,
    )
    for cat, pred_files in split_pred_files.items()
}

In [None]:
for cat, pred_df in list(pred_dfs.items()):
    # Remove the "True class" column when present, and name the file name column.
    pred_df = pred_df.drop(columns="True class", errors="ignore").rename(
        columns={"Unnamed: 0": "filename"}
    )

    # Add max pred + move it to front
    pred_df = split_results_handler.add_max_pred(pred_df, target_label="Predicted class")
    pred_df.insert(2, "Max pred", pred_df.pop("Max pred"))
    pred_df = pred_df[pred_df["Max pred"] >= 0]

    # Get id columns: third "."-separated field of the file name, split on "_"
    id_cols = (
        pred_df["filename"].str.split(".", expand=True)[2].str.split("_", expand=True)
    )
    for position, id_name in enumerate(["id1", "id2"], start=1):
        pred_df.insert(position, id_name, id_cols[position - 1])

    pred_dfs[cat] = pred_df

In [111]:
# display(pred_dfs[ASSAY]["filename"].str.split(".", expand=True)[2].str.split("_",expand=True).head())
# display(pred_dfs[ASSAY]["id1"].nunique(), pred_dfs[ASSAY]["id2"].nunique())
# display(pred_dfs[ASSAY]["id2"].str.slice(0,3).value_counts())
# Every prediction row must have its own id2 value.
assert pred_dfs[ASSAY]["id2"].nunique() == pred_dfs[ASSAY].shape[0]

In [None]:
# Harmonized recount3 metadata; `meta_name` is reused in the name of the output file.
meta_name = "harmonized_metadata_20250110"
metadata_file = metadata_dir / f"recount_{meta_name}.tsv"
recount_metadata_df = pd.read_csv(metadata_file, sep="\t")

In [114]:
# Give the harmonized columns the task names used for the predictions, and
# mark every missing value as "unknown".
harmonized_to_task = {
    "harmonized_assay": ASSAY,
    "harmonized_lifestage": LIFE_STAGE,
    "harmonized_sex": SEX,
    "harmonized_cancer": CANCER,
}
recount_metadata_df = recount_metadata_df.rename(columns=harmonized_to_task).fillna(
    "unknown"
)

In [116]:
def merge_all_recount3_preds(
    pred_dfs: Dict[str, pd.DataFrame], full_metadata_df: pd.DataFrame
) -> pd.DataFrame:
    """Merge all recount3 predictions into a single DataFrame.

    Args:
        pred_dfs: {task name: prediction frame}. Each frame needs "id1" and
            "id2" columns; "id2" becomes the "ID" join key. Must contain the
            ASSAY, SEX, CANCER and LIFE_STAGE tasks.
        full_metadata_df: Metadata with an "ID" column and one column per task
            name, holding the expected label.

    Returns:
        One frame with the per-task columns suffixed by the task name, joined
        with the complete metadata on "ID". Only IDs present on both sides of
        each inner join are kept.
    """
    # Columns at positions 1 .. same_col_len - 2 are renamed per task.
    # NOTE(review): presumably "Expected class", "Predicted class" and "Max pred",
    # given the column order built by the caller -- confirm.
    same_col_len = 5
    # Make all different columns have unique relevant names except for the pred vector
    new_dfs = {}
    for cat, df in pred_dfs.items():
        df = df.copy()
        df["ID"] = df["id2"]
        df = df.drop(["id1", "id2"], axis=1)
        # "Same?" is optional in the input.
        try:
            df = df.drop(columns=["Same?"])
        except KeyError:
            pass
        # Bring in the expected label of this task from the metadata.
        df = df.merge(
            full_metadata_df[["ID", cat]],
            left_on="ID",
            right_on="ID",
            how="inner",
        )
        df.insert(1, "Expected class", df[cat])
        df = df.drop(columns=[cat])

        old_names = df.columns[1 : same_col_len - 1]
        new_names = [f"{old_name} ({cat})" for old_name in old_names]
        df.rename(columns=dict(zip(old_names, new_names)), inplace=True)

        new_dfs[cat] = df

    df_order = [ASSAY, SEX, CANCER, LIFE_STAGE]
    df_list = [new_dfs[cat] for cat in df_order]

    # NOTE(review): the frames are merged on "external_id", a column this function
    # never creates (it creates "ID") -- confirm that the prediction files provide
    # it and that it is the intended key for merge_dataframes.
    merge_dataframes_func = functools.partial(merge_dataframes, on="external_id")
    full_merged_df = functools.reduce(merge_dataframes_func, df_list)
    full_merged_df.reset_index(drop=True, inplace=True)

    # Join with the complete metadata; on name clashes the prediction-side
    # column is kept and the metadata one ("_DROP") is removed.
    full_merged_df = full_merged_df.merge(
        full_metadata_df,
        on="ID",
        how="inner",
        suffixes=("", "_DROP"),
    )
    for col in full_merged_df.columns:
        if col.endswith("_DROP"):
            full_merged_df.drop(columns=col, inplace=True)
    return full_merged_df

In [117]:
# Predictions of the four recount3 tasks, joined with the recount3 metadata.
final_df = merge_all_recount3_preds(pred_dfs, recount_metadata_df)

In [118]:
# Move the "ID" column to position 1 (right after the first column).
final_df.insert(1, "ID", final_df.pop("ID"))

In [None]:
# Save as a gzip-compressed TSV, next to the recount3 prediction files.
out_path = recount3_folder / f"recount3_merged_preds_{meta_name}.tsv.gz"
final_df.to_csv(out_path, sep="\t", index=False, compression="gzip")

### accuracy

In [None]:
# Read back the merged predictions written above (same path as `out_path`).
preds_path = recount3_folder / f"recount3_merged_preds_{meta_name}.tsv.gz"
full_df = pd.read_csv(preds_path, sep="\t")

In [None]:
# Share of files predicted as (m)RNA-seq, at several minimum prediction scores.
N = full_df.shape[0]
for max_pred in [0, 0.6, 0.8]:
    subset = full_df[full_df[f"Max pred ({ASSAY})"] >= max_pred]
    counts = subset[f"Predicted class ({ASSAY})"].value_counts()

    N_subset = counts.sum()
    if N_subset == 0:
        print(f"min_PredScore >= {max_pred}: no sample left")
        continue

    counts_perc = counts / N_subset
    # .get: a class that was never predicted in this subset counts as 0
    # instead of raising KeyError.
    correct_perc = counts_perc.get("rna_seq", 0) + counts_perc.get("mrna_seq", 0)
    # The ".2%" format already appends the percent sign.
    print(f"min_PredScore >= {max_pred} ({N_subset/N:.2%} left): {correct_perc:.2%}")

In [None]:
# Accuracy, classification report and confusion matrix of the sex, cancer and
# life stage predictions, for files kept at several minimum assay prediction scores.
for max_pred in [0, 0.6, 0.8]:
    subset = full_df[full_df[f"Max pred ({ASSAY})"] >= max_pred]
    print(f"min_PredScore >= {max_pred}")

    for cat in [SEX, CANCER, LIFE_STAGE]:
        pred_label = f"Predicted class ({cat})"
        true_label = f"Expected class ({cat})"

        # Only files with a known expected label can be evaluated.
        has_label = subset[true_label].notna() & (subset[true_label] != "unknown")
        known_pred = subset.loc[has_label, [true_label, pred_label]]

        if cat == CANCER:
            # Align label vocabulary. Done on this task's columns only, so the
            # replacement does not leak into `subset` for the next task.
            known_pred = known_pred.replace("healthy", "non-cancer")
        if cat == LIFE_STAGE:
            known_pred = known_pred[known_pred[true_label] != "children"]

        N_known = known_pred.shape[0]
        if N_known == 0:
            print(f"{cat}: no sample with a known expected class\n")
            continue

        y_pred = known_pred[pred_label]
        y_true = known_pred[true_label]

        # Every class seen in either the expected or the predicted labels.
        classes = sorted(set(y_true.unique()) | set(y_pred.unique()))

        N_correct = (y_pred == y_true).sum()
        print(f"{cat} prediction match (%): {N_correct/N_known*100:.2f}")

        # `labels` pins the row order, so `target_names` cannot be misaligned.
        print(
            classification_report(y_true, y_pred, labels=classes, target_names=classes)  # type: ignore
            + "\n"
        )

        print(f"confusion matrix classes row order: {classes}")
        cm = sk_cm(y_true, y_pred, normalize="true", labels=classes)
        print(str(cm) + "\n")

    print()