In [None]:
"""Workbook to analyse encode non-core predictions.
"""
# pylint: disable=import-error, redefined-outer-name, use-dict-literal, too-many-lines, unused-import, unused-argument, too-many-branches, pointless-statement

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from __future__ import annotations

from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from IPython.display import display
from sklearn.metrics import confusion_matrix

from epi_ml.utils.notebooks.paper.paper_utilities import (
    ASSAY,
    ASSAY_MERGE_DICT,
    ASSAY_ORDER,
    CELL_TYPE,
    IHECColorMap,
    MetadataHandler,
    SplitResultsHandler,
)

# import plotly.express as px
# import plotly.graph_objects as go
# from plotly.subplots import make_subplots

In [None]:
# Root of the paper outputs; the data and figure folders are resolved below it.
base_dir = Path.home() / "Projects/epiclass/output/paper"
base_data_dir = base_dir / "data"
base_fig_dir = base_dir / "figures"
paper_dir = base_dir

# Fail fast with an explicit message. The data folder is validated as well as
# the figure folder because the loading cells read from it; without this check
# a missing data folder only surfaces later as a less obvious read error.
for required_dir in (base_fig_dir, base_data_dir):
    if not required_dir.exists():
        raise FileNotFoundError(f"Directory {required_dir} does not exist.")

In [None]:
# Project helper objects, built once for the whole notebook. MetadataHandler
# is given the paper output directory; SplitResultsHandler takes no arguments.
# NOTE(review): neither object is referenced in the cells visible here —
# presumably they are used further down; confirm before removing.
metadata_handler = MetadataHandler(paper_dir)
split_results_handler = SplitResultsHandler()

#### Getting GO info

In [None]:
# Both lookup tables are read from the ENCODE metadata folder.
encode_metadata_dir = base_data_dir / "metadata" / "encode"

# Term lookup table: column names are supplied explicitly through `names`,
# the last column being the CELL_TYPE label.
curie_columns = ["code", "term", CELL_TYPE]
curie_def_df = pd.read_csv(
    encode_metadata_dir / "EpiAtlas_list-curie_term_HSOI.tsv",
    names=curie_columns,
    sep="\t",
)

# ENCODE ontology/assay table, read with the header found in the file.
encode_ontology_path = encode_metadata_dir / "encode_ontol+assay.tsv"
encode_ontology_df = pd.read_csv(encode_ontology_path, sep="\t")

In [None]:
# Attach the CELL_TYPE label to every ENCODE row by matching its biosample
# term id against the lookup "code". The left join keeps rows without a match
# (their CELL_TYPE is NaN); the lookup-only columns are dropped right away.
merged_df = encode_ontology_df.merge(
    curie_def_df,
    how="left",
    left_on="Biosample term id",
    right_on="code",
).drop(columns=["code", "term"])

In [None]:
# Quick visual check of the first rows of the merged table.
merged_df.head()

In [None]:
# Share of rows per CELL_TYPE label; dropna=False keeps the rows without a
# label (NaN) as their own entry so they are part of the total.
counts = merged_df[CELL_TYPE].value_counts(dropna=False)
display(counts.div(counts.sum()))

#### Missing harmonized_sample_ontology_intermediate details

In [None]:
# Rows that received no CELL_TYPE label from the merge.
is_unlabelled = merged_df[CELL_TYPE].isna()
missing_cell_type = merged_df.loc[is_unlabelled]
print(missing_cell_type.shape)

biosample_cols = ["Biosample term id", "Biosample term name"]

# Number of unlabelled rows for each distinct (term id, term name) pair.
missing_count = missing_cell_type[biosample_cols].value_counts()
display(missing_count.shape)

# Same breakdown as a percentage of all unlabelled rows, shown in full
# (no row limit) with two decimals.
missing_perc = missing_count / missing_count.sum() * 100
with pd.option_context(
    "display.max_rows",
    None,
    "display.float_format",
    "{:.2f}".format,  # pylint: disable=consider-using-f-string
):
    display(missing_perc)

In [None]:
# Distinct biosample names among the unlabelled rows, computed once for both
# lists. NaN names are dropped first: the substring test below would raise
# TypeError on a float NaN.
missing_term_names = missing_cell_type["Biosample term name"].dropna().unique()

# Names mentioning "T cell" / "B cell" (case-sensitive substring match).
t_cell_types = [name for name in missing_term_names if "T cell" in name]
b_cell_types = [name for name in missing_term_names if "B cell" in name]

In [None]:
# Unlabelled rows whose biosample name is one of the T cell names, counted
# per (term id, term name) pair, followed by their total.
is_t_cell = missing_cell_type["Biosample term name"].isin(t_cell_types)
t_cell_count = missing_cell_type.loc[is_t_cell, biosample_cols].value_counts()
display(t_cell_count, t_cell_count.sum())

In [None]:
# Same breakdown for the B cell names.
is_b_cell = missing_cell_type["Biosample term name"].isin(b_cell_types)
b_cell_count = missing_cell_type.loc[is_b_cell, biosample_cols].value_counts()
display(b_cell_count, b_cell_count.sum())

# Share of the unlabelled rows accounted for by T and B cell names together.
n_t_and_b_cells = t_cell_count.sum() + b_cell_count.sum()
perc_missing = n_t_and_b_cells / missing_cell_type.shape[0] * 100
print(f"t+b cells, percentage of missing cell types: {perc_missing:.2f}%")

#### Match predictions from various trainings with GO info

In [None]:
# Parent folder of the training runs to compare: each of its sub-folders is
# scanned for a single CSV under "predictions/" when the predictions are loaded.
pred_folder = (
    base_data_dir
    / "training_results/dfreeze_v2/hg38_100kb_all_none/harmonized_sample_ontology_intermediate_1l_3000n/complete-no_valid-oversampling"
)

In [None]:
# The 16 cell type labels retained for the comparison. Metadata rows with any
# other label, or with no label at all, are removed from merged_df below.
accepted_ct = [
    "T cell",
    "neutrophil",
    "brain",
    "monocyte",
    "lymphocyte of B lineage",
    "myeloid cell",
    "venous blood",
    "macrophage",
    "mesoderm-derived structure",
    "endoderm-derived structure",
    "colon",
    "connective tissue cell",
    "hepatocyte",
    "mammary gland epithelial cell",
    "muscle organ",
    "extraembryonic cell",
]

# Shape is printed before and after so the effect of the filter is visible.
print(merged_df.shape)
has_accepted_ct = merged_df[CELL_TYPE].isin(accepted_ct)
merged_df = merged_df.loc[has_accepted_ct]
print(merged_df.shape)

In [None]:
# Map each training run (a sub-folder of pred_folder) to its predictions table.
# An entry is skipped, with a message, unless it is a directory that holds
# exactly one CSV under "predictions/".
pred_dfs_dict = {}
for folder in pred_folder.glob("*"):
    if not folder.is_dir():
        print(f"Skipping {folder}")
        continue

    pred_files = list(folder.glob("predictions/*.csv"))
    n_found = len(pred_files)
    if n_found != 1:
        if n_found > 1:
            print(f"More than one prediction file found in {folder}")
        else:
            print(f"No prediction file found in {folder}")
        continue

    pred_file = pred_files[0]
    pred_df = pd.read_csv(pred_file)

    # Runs are keyed by folder name with the common prefix stripped.
    name = folder.name.replace("complete_no_valid_oversample_", "")
    pred_dfs_dict[name] = pred_df

In [None]:
# Overall accuracy of each training run, computed on the predictions that
# (a) match a metadata row with a CELL_TYPE label and (b) have an assay that
# is NOT listed in ASSAY_ORDER.
for name, pred_df in sorted(pred_dfs_dict.items()):
    print(name)
    pred_w_ct = pred_df.merge(merged_df, left_on="md5sum", right_on="ENC_ID", how="left")
    pred_w_ct["Assay"] = pred_w_ct["Assay"].str.lower()
    pred_w_ct = pred_w_ct.dropna(subset=[CELL_TYPE])  # drop rows with missing cell type
    # .copy(): a column is added below, and writing to a boolean-filtered
    # frame can emit SettingWithCopyWarning.
    pred_w_ct = pred_w_ct[~pred_w_ct["Assay"].isin(ASSAY_ORDER)].copy()

    pred_w_ct["correct_pred"] = pred_w_ct["Predicted class"] == pred_w_ct[CELL_TYPE]

    # Detailed breakdown, kept available for inspection after the loop.
    counts = (
        pred_w_ct.groupby(["Assay", CELL_TYPE, "Predicted class", "correct_pred"])
        .size()
        .sort_values(ascending=False)
    )

    # Count the hits from the boolean column itself: selecting the True level
    # of the grouped index raises KeyError when a run has no correct prediction.
    total_correct = int(pred_w_ct["correct_pred"].sum())
    n_samples = pred_w_ct.shape[0]

    # No sample left after filtering -> report NaN instead of dividing by zero.
    perc = total_correct / n_samples if n_samples else float("nan")
    print(f"Total correct: {total_correct}/{n_samples} ({perc:.2%})\n")

In [None]:
groupby_cols = ["Assay", CELL_TYPE, "Predicted class", "correct_pred"]

# Cut-off on the "Max pred" column used for the "confident predictions" accuracy.
max_pred_threshold = 0.8

# For each training run: accuracy over all retained predictions, accuracy over
# the predictions above the cut-off, and how many samples the cut-off discards.
for name, pred_df in sorted(pred_dfs_dict.items()):
    print(name)
    pred_w_ct = pred_df.merge(merged_df, left_on="md5sum", right_on="ENC_ID", how="left")
    pred_w_ct["Assay"] = pred_w_ct["Assay"].str.lower()
    pred_w_ct = pred_w_ct.dropna(subset=[CELL_TYPE])  # drop rows with missing cell type
    # .copy(): a column is added below, and writing to a boolean-filtered
    # frame can emit SettingWithCopyWarning.
    pred_w_ct = pred_w_ct[~pred_w_ct["Assay"].isin(ASSAY_ORDER)].copy()
    N = pred_w_ct.shape[0]

    # Calculate results for all predictions
    pred_w_ct["correct_pred"] = pred_w_ct["Predicted class"] == pred_w_ct[CELL_TYPE]
    counts = pred_w_ct.groupby(groupby_cols).size().sort_values(ascending=False)
    # Hits are counted from the boolean column: selecting the True level of the
    # grouped index raises KeyError when there is no correct prediction.
    total_correct = int(pred_w_ct["correct_pred"].sum())
    perc = total_correct / N if N else float("nan")
    print(f"Acc (pred>0.0) {total_correct}/{N} ({perc:.2%})")

    # Calculate results for predictions with max_pred above the cut-off
    pred_w_ct_filtered = pred_w_ct[pred_w_ct["Max pred"] > max_pred_threshold]
    n_filtered = pred_w_ct_filtered.shape[0]
    counts_filtered = (
        pred_w_ct_filtered.groupby(groupby_cols).size().sort_values(ascending=False)
    )
    total_correct_filtered = int(pred_w_ct_filtered["correct_pred"].sum())
    # The cut-off may leave no sample at all -> report NaN rather than divide by zero.
    perc_filtered = total_correct_filtered / n_filtered if n_filtered else float("nan")
    print(
        f"Acc (pred>{max_pred_threshold}): {total_correct_filtered}/{n_filtered} ({perc_filtered:.2%})"
    )
    diff = N - n_filtered
    perc_ignored = diff / N if N else float("nan")
    print(f"Samples ignored at {max_pred_threshold}: {diff} ({perc_ignored:.2%})\n")

    # Uncomment the following lines if you want to display additional information
    # if "assay" in name.lower():
    #     with pd.option_context(
    #         "display.float_format",
    #         "{:.3f}".format,
    #         "display.max_rows",
    #         None,
    #     ):
    #         values_count = pred_w_ct["Assay"].value_counts()
    #         # display(values_count)
    #         display(values_count / values_count.sum())
    #         display(counts)

In [None]:
# Cut-off on the "Max pred" column: only predictions above it are plotted.
max_pred_threshold = 0.8

# One row-normalised confusion matrix per training run, restricted to the
# predictions with an accepted CELL_TYPE label, an assay NOT listed in
# ASSAY_ORDER, and a "Max pred" above the cut-off.
for name, pred_df in sorted(pred_dfs_dict.items()):
    print(name)
    pred_w_ct = pred_df.merge(merged_df, left_on="md5sum", right_on="ENC_ID", how="left")
    pred_w_ct["Assay"] = pred_w_ct["Assay"].str.lower()
    pred_w_ct = pred_w_ct.dropna(subset=[CELL_TYPE])  # drop rows with missing cell type
    pred_w_ct = pred_w_ct[~pred_w_ct["Assay"].isin(ASSAY_ORDER)]
    pred_w_ct = pred_w_ct[pred_w_ct["Max pred"] > max_pred_threshold]

    # Count real samples for each cell type
    real_samples_count = pred_w_ct[CELL_TYPE].value_counts()

    # Create confusion matrix (rows: actual label, columns: predicted label)
    cm = confusion_matrix(
        pred_w_ct[CELL_TYPE], pred_w_ct["Predicted class"], labels=accepted_ct
    )

    # Convert to fractions (each non-empty row sums to 1). Cell types with no
    # sample left have a zero row total: they are set to NaN explicitly
    # instead of dividing by zero, which would emit a RuntimeWarning.
    row_totals = cm.sum(axis=1, keepdims=True)
    cm_percentage = np.divide(
        cm,
        row_totals,
        out=np.full(cm.shape, np.nan),
        where=row_totals > 0,
    )

    # Create y-axis (actual class) labels with sample counts
    ticklabels_w_count = [
        f"{ct}\n(n={real_samples_count.get(ct, 0)})" for ct in accepted_ct
    ]

    # Create a heatmap of the percentage-based confusion matrix on an explicit
    # figure/axes pair so the figure can be closed once displayed.
    fig, ax = plt.subplots(figsize=(24, 20))
    sns.heatmap(
        cm_percentage,
        annot=True,
        fmt=".2%",
        cmap="Blues",
        xticklabels=accepted_ct,
        yticklabels=ticklabels_w_count,
        vmin=0,
        vmax=1,
        annot_kws={"size": 10},
        cbar_kws={"shrink": 0.8},  # Adjust colorbar size
        ax=ax,
    )

    ax.set_title(f"Confusion Matrix (%) for {name}", fontsize=20)
    ax.set_xlabel("Predicted", fontsize=16)
    ax.set_ylabel("Actual", fontsize=16)
    plt.setp(ax.get_xticklabels(), fontsize=10, rotation=90, ha="center")
    plt.setp(ax.get_yticklabels(), fontsize=12, rotation=0)

    # tight_layout recomputes the subplot margins itself, so no manual
    # subplots_adjust call is made before it (it would be overridden).
    fig.tight_layout()

    # Accuracy over the samples counted in the matrix; NaN when it is empty.
    n_on_diagonal = int(np.trace(cm))
    n_in_matrix = int(np.sum(cm))
    accuracy = n_on_diagonal / n_in_matrix if n_in_matrix else float("nan")
    print(f"Accuracy: {accuracy:.2%} ({n_on_diagonal} / {n_in_matrix})")

    plt.show()
    # Release the figure: one is created per run and they would otherwise
    # accumulate in memory.
    plt.close(fig)