In [None]:
"""Notebook analyzing cell type mis-predictions across several cell type metadata groupings.

For each grouping ("cell_type_PE", "cell_type_martin" and the CELL_TYPE category),
it loads the 10-fold validation predictions and joins them with the IHEC harmonized
metadata (v1.2 extended). It then writes per-EpiRR prediction summaries and
confusion matrices at several minimum prediction-score thresholds.
"""
# pylint: disable=import-error, redefined-outer-name, use-dict-literal, too-many-lines, unused-import, unused-argument, too-many-branches

## SETUP

In [None]:
# Re-import edited project modules (e.g. epi_ml) before each cell runs, without a kernel restart.
%load_ext autoreload
%autoreload 2

In [39]:
from __future__ import annotations

import functools
from pathlib import Path
from typing import Dict

import pandas as pd
from sklearn.metrics import confusion_matrix as sk_cm

from epi_ml.core.confusion_matrix import ConfusionMatrixWriter
from epi_ml.utils.classification_merging_utils import merge_dataframes
from epi_ml.utils.notebooks.paper.paper_utilities import (
    ASSAY,
    CELL_TYPE,
    MetadataHandler,
    SplitResultsHandler,
)

In [40]:
# Every input and output of this notebook lives under one root in the home directory.
base_dir = Path.home() / "Projects/epiclass/output/paper"
paper_dir = base_dir  # root handed to MetadataHandler
base_data_dir, base_fig_dir = (base_dir / subdir for subdir in ("data", "figures"))

# Fail fast: the confusion matrices at the end of the notebook are written under base_fig_dir.
if not base_fig_dir.exists():
    raise FileNotFoundError(f"Directory {base_fig_dir} does not exist.")

In [41]:
# Helpers from the paper utilities: split-results loading/aggregation and metadata access.
split_results_handler = SplitResultsHandler()
metadata_handler = MetadataHandler(paper_dir)

In [42]:
# Per-file identifiers (uuid, md5sum, track type, assay) plus the unversioned EpiRR id,
# which is the key used to join the IHEC metadata further down.
file_columns = ["epirr_id_without_version", "uuid", "md5sum", "track_type", ASSAY]
file_info = metadata_handler.load_metadata_df("v2", merge_assays=False).reset_index(
    drop=False
)
file_info = file_info[file_columns]

In [43]:
# IHEC harmonized metadata, version 1.2 (extended columns).
official_metadata_dir = base_data_dir.joinpath("metadata", "official")
metadata_v1_2_path = official_metadata_dir.joinpath(
    "IHEC_metadata_harmonization.v1.2.extended.csv"
)
metadata_v1_2 = pd.read_csv(metadata_v1_2_path, index_col=False)

In [44]:
# Left join: every file is kept, even when its EpiRR has no row in the IHEC metadata.
file_metadata = pd.merge(
    file_info, metadata_v1_2, how="left", on="epirr_id_without_version"
)

In [None]:
data_dir_100kb = base_data_dir.joinpath(
    "training_results", "dfreeze_v2", "hg38_100kb_all_none"
)

# The three cell type groupings to compare, oversampled runs only.
# NOTE(review): exclude_names presumably filters result sets by name — confirm in SplitResultsHandler.
split_metric_options = dict(
    merge_assays=False,
    include_categories=["cell_type_PE", "cell_type_martin", CELL_TYPE],
    exclude_names=["27", "16"],
    return_type="split_results",
    oversampled_only=True,
    verbose=False,
)
all_split_dfs = split_results_handler.general_split_metrics(
    results_dir=data_dir_100kb, **split_metric_options
)

In [53]:
all_split_dfs_concat: Dict[str, pd.DataFrame] = (
    split_results_handler.concatenate_split_results(
        all_split_dfs, concat_first_level=True  # type: ignore
    )
)

# Normalize every grouping's frame: "split" first, "Max pred" at position 3,
# and "Expected class" instead of "True class" for the ground truth label.
for category in list(all_split_dfs_concat):
    frame = all_split_dfs_concat[category]
    frame.insert(0, "split", frame.pop("split"))
    frame = split_results_handler.add_max_pred(frame)
    frame.insert(3, "Max pred", frame.pop("Max pred"))
    frame.rename(columns={"True class": "Expected class"}, inplace=True)
    all_split_dfs_concat[category] = frame

## For each category: pivot per predicted class

In [54]:
groupby_cols = ["EpiRR", "Expected class", "Predicted class"]

# For each cell type grouping, summarize the prediction scores of every
# (EpiRR, expected class, predicted class) combination and write it to CSV.
for name, df in all_split_dfs_concat.items():
    new_df = (
        df.reset_index(drop=True)
        .merge(file_metadata, how="left", on="md5sum")
        .rename(columns={"Max pred": "pred_score"})
    )

    # Number of prediction rows (files) per EpiRR: denominator of count_ratio.
    epirr_counts = new_df.groupby("EpiRR").size().reset_index(name="total_epirr_files")

    # Score statistics per (EpiRR, expected, predicted) group.
    # No `axis` argument: a groupby aggregation has no axis to choose, and some
    # pandas versions forward it to each reducer (e.g. mean(axis=1)), which fails.
    groupby = new_df.groupby(groupby_cols).agg(
        {"pred_score": ["mean", "median", "std", "count", "min", "max"]}
    )
    # Flatten the (column, statistic) MultiIndex into "pred_score,mean"-style names;
    # the former index levels come back with an empty second level and keep their name.
    groupby = groupby.reset_index()
    groupby.columns = [
        col[0] if col[1] == "" else f"{col[0]},{col[1]}" for col in groupby.columns
    ]
    groupby = groupby.merge(epirr_counts, on="EpiRR")

    # Percentage of the EpiRR's prediction rows that fall in this group.
    groupby["count_ratio"] = (
        groupby["pred_score,count"] / groupby["total_epirr_files"] * 100
    ).round(2)

    groupby = groupby.sort_values(["EpiRR", "count_ratio"], ascending=[True, False])
    groupby_w_metadata = groupby.merge(metadata_v1_2, how="left", on="EpiRR")

    groupby_w_metadata.to_csv(
        data_dir_100kb / f"{name}_pivot_predicted_class.csv", index=False
    )

## Merge all dfs

In [48]:
def merge_all_dfs(pred_dfs: Dict[str, pd.DataFrame]) -> pd.DataFrame:
    """Merge the predictions of every cell type grouping into a single DataFrame.

    Each input frame is expected to start with four columns shared by all
    groupings (e.g. split, expected class, predicted class, max pred), followed
    by prediction-vector columns, plus an "md5sum" column used as merge key.
    To avoid name collisions between groupings:
      - the four leading columns become "<col> (<category>)";
      - every other column except "md5sum" becomes "<col>_df<i>", where i is
        the position of the grouping in `pred_dfs`.

    Args:
        pred_dfs: Mapping of category name -> predictions DataFrame.
            The input frames are not modified.

    Returns:
        The renamed frames combined left to right with `merge_dataframes`,
        with "md5sum" moved to the first column.

    Raises:
        ValueError: If `pred_dfs` is empty.
    """
    if not pred_dfs:
        raise ValueError("pred_dfs is empty: nothing to merge.")

    # Number of leading columns shared by all frames (suffixed with the category).
    same_col_len = 4
    renamed_dfs = []
    for i, (cat, df) in enumerate(pred_dfs.items()):
        leading_cols = df.columns[:same_col_len]
        # Exclude the merge key by name rather than assuming it is the last column.
        pred_vector_cols = [
            col for col in df.columns[same_col_len:] if col != "md5sum"
        ]
        mapping = {old: f"{old} ({cat})" for old in leading_cols}
        mapping.update({col: f"{col}_df{i}" for col in pred_vector_cols})
        renamed_dfs.append(df.rename(columns=mapping).reset_index(drop=True))

    full_merged_df = functools.reduce(merge_dataframes, renamed_dfs)
    full_merged_df.insert(0, "md5sum", full_merged_df.pop("md5sum"))
    return full_merged_df

In [49]:
# Combine the per-grouping prediction frames into a single wide frame.
full_merged_df = merge_all_dfs(all_split_dfs_concat)

In [50]:
# Left join: keep every merged prediction row even when its md5sum has no file metadata.
final_df = pd.merge(full_merged_df, file_metadata, how="left", on="md5sum")

In [15]:
# final_df.to_csv(
#     data_dir_100kb / "all_custom_cell_type_predictions_augmented.csv",
# )

### Confusion matrices

In [55]:
# Minimum "Max pred" values for which a confusion matrix is written.
MIN_PRED_SCORES = (0, 0.6, 0.8)

logdir = (
    base_fig_dir
    / "flagship"
    / "ct_assay_accuracy"
    / "other_cell_type_groupings"
    / "confusion_matrices"
)
logdir.mkdir(parents=True, exist_ok=True)

for name, df in all_split_dfs_concat.items():
    # Labels come from the unfiltered predictions, so every threshold's matrix
    # has the same rows/columns even when a class vanishes after filtering.
    labels = sorted(set(df["Expected class"]) | set(df["Predicted class"]))

    this_logdir = logdir / name
    this_logdir.mkdir(parents=False, exist_ok=True)

    for min_pred_score in MIN_PRED_SCORES:
        # Always filter the full frame (never rebind `df`), so each matrix is
        # independent of the other thresholds and of their order.
        confident_df = df[df["Max pred"] >= min_pred_score]

        cm = sk_cm(
            confident_df["Expected class"],
            confident_df["Predicted class"],
            normalize=None,
            labels=labels,
        )
        cm_writer = ConfusionMatrixWriter(labels=labels, confusion_matrix=cm)

        filename = (
            "full-10fold-validation_prediction-confusion-matrix-threshold-"
            f"{min_pred_score:.2f}_{name}"
        )
        cm_writer.to_all_formats(logdir=this_logdir, name=filename)