### SETUP

In [None]:
"""Workbook to analyse classifier predictions on recount3 data.
"""

In [None]:
from __future__ import annotations

from pathlib import Path

import numpy as np
import pandas as pd
from IPython.display import display
from sklearn.metrics import classification_report, confusion_matrix as sk_cm

from epi_ml.utils.notebooks.paper.paper_utilities import ASSAY, LIFE_STAGE, SEX

In [None]:
# Names of harmonized metadata categories. In the cells below they are used as
# column keys of `full_df` and embedded in the "Predicted class (...)" /
# "Expected class (...)" column labels.
DISEASE = "harmonized_sample_disease_high"  # NOTE(review): not referenced in the visible cells
CANCER = "harmonized_sample_cancer_high"
BIOMAT = "harmonized_biomaterial_type"

In [None]:
# Root folder of the paper outputs; `paper_dir` is a second name for the same path.
paper_dir = base_dir = Path.home().joinpath("Projects/epiclass/output/paper")

# Where figures and tables are stored.
base_fig_dir = base_dir.joinpath("figures")
table_dir = base_dir.joinpath("tables")

# Input data: metadata and classifier predictions.
base_data_dir = base_dir.joinpath("data")
metadata_dir = base_data_dir.joinpath("metadata")
predictions_dir = base_data_dir.joinpath("training_results", "predictions")

# Folder holding the recount3 prediction files read by this notebook.
recount3_folder = predictions_dir.joinpath("recount3", "hg38_100kb_all_none")

In [None]:
# Metadata version tag embedded in the merged predictions file name.
meta_name = "harmonized_metadata_20250122_leuk2"
preds_path = recount3_folder.joinpath(f"recount3_merged_preds_{meta_name}.tsv.gz")

# Tab-separated table of merged predictions; every later cell reads from `full_df`.
full_df = pd.read_csv(filepath_or_buffer=preds_path, sep="\t")

### Assay prediction details

In [None]:
# Assay-classifier performance on samples whose assay label is known.
#
# For several minimum prediction-score thresholds this cell prints:
#   1. the fraction of samples kept and the fraction predicted as an RNA class,
#   2. prediction counts per (assay label, predicted class),
#   3. the same counts with every assay label and the "mrna_seq" prediction
#      collapsed into "rna_seq",
#   4. a per-assay-label accuracy and prediction breakdown.

# Column names of the assay classifier outputs.
ASSAY_PRED_COL = f"Predicted class ({ASSAY})"
ASSAY_SCORE_COL = f"Max pred ({ASSAY})"

# Predicted classes counted as correct, whatever the sample's own assay label is.
RNA_PRED_CLASSES = ["rna_seq", "mrna_seq"]


def _rna_fraction(class_counts: pd.Series) -> float:
    """Return the fraction of predictions falling in one of `RNA_PRED_CLASSES`.

    A class that was never predicted counts as zero (direct indexing would raise
    KeyError); an empty `class_counts` yields NaN.
    """
    total = class_counts.sum()
    if total == 0:
        return float("nan")
    return class_counts.reindex(RNA_PRED_CLASSES, fill_value=0).sum() / total


def _count_by_assay_and_prediction(df: pd.DataFrame) -> pd.DataFrame:
    """Count rows per (assay label, predicted class), largest counts first within each assay."""
    return (
        df.groupby([ASSAY, ASSAY_PRED_COL])
        .size()
        .reset_index(name="Count")
        .sort_values(by=[ASSAY, "Count"], ascending=[True, False])
    )


assay_df = full_df[full_df[ASSAY] != "unknown"]
N = assay_df.shape[0]

for max_pred in [0, 0.6, 0.8]:
    # `max_pred` is used as a lower bound on the classifier's max prediction score.
    subset = assay_df[assay_df[ASSAY_SCORE_COL] >= max_pred]
    counts = subset[ASSAY_PRED_COL].value_counts()

    N_subset = counts.sum()
    correct_perc = _rna_fraction(counts)
    print(f"min_PredScore >= {max_pred} ({N_subset/N:.2%} left): {correct_perc:.2%}\n")

    print("Predictions grouped, assay types left as is")
    groupby = _count_by_assay_and_prediction(subset)
    print(groupby, "\n")

    print("Predictions grouped, all rna types = rna")
    tmp_df = subset.copy()
    tmp_df[ASSAY] = "rna_seq"
    # Assign the replaced column back: an in-place `replace` on a `.loc` selection
    # is chained assignment and is not guaranteed to modify `tmp_df`.
    tmp_df[ASSAY_PRED_COL] = tmp_df[ASSAY_PRED_COL].replace("mrna_seq", "rna_seq")
    groupby = _count_by_assay_and_prediction(tmp_df)
    print(groupby, "\n")

    print("Breakdown by assay type")
    assay_breakdown = subset[ASSAY].value_counts(dropna=False)
    print(assay_breakdown / assay_breakdown.sum(), "\n")
    for assay_type in assay_breakdown.index:
        # `value_counts(dropna=False)` can yield a NaN key, which `==` never matches.
        if pd.isna(assay_type):
            assay_mask = subset[ASSAY].isna()
        else:
            assay_mask = subset[ASSAY] == assay_type
        assay_type_subset = subset[assay_mask]

        counts = assay_type_subset[ASSAY_PRED_COL].value_counts()
        N_subset = counts.sum()
        counts_perc = counts / N_subset
        correct_perc = _rna_fraction(counts)
        print(f"{assay_type} acc: {correct_perc:.2%}\n")
        print(f"{assay_type} preds:\n{counts_perc}\n")
    print()

### Other metadata categories

In [None]:
# Show the label distribution (missing values included) of each metadata category.
for cat in (SEX, CANCER, LIFE_STAGE, BIOMAT):
    category_counts = full_df[cat].value_counts(dropna=False)
    display(category_counts)

In [None]:
# Per-category classifier performance (sex, cancer status, life stage, biomaterial).
#
# For each minimum prediction-score threshold and each category, prints the
# match rate, the predicted/expected label distributions, a classification
# report and a row-normalized confusion matrix, using only samples whose
# expected label is known.
df = full_df.copy(deep=True)
for max_pred in [0, 0.6, 0.8]:
    # NOTE(review): the score filter uses the ASSAY classifier's max prediction
    # for every category, not the category's own score — confirm this is intended.
    subset = df[df[f"Max pred ({ASSAY})"] >= max_pred]
    print(f"min_PredScore >= {max_pred}\n")

    for cat in [SEX, CANCER, LIFE_STAGE, BIOMAT]:
        pred_label = f"Predicted class ({cat})"
        true_label = f"Expected class ({cat})"

        cat_subset = subset
        if cat == CANCER:
            # Relabel "healthy" as "non-cancer" only in the two columns compared
            # here, on a new frame, so the relabelling cannot leak into the
            # categories processed afterwards.
            cat_subset = subset.assign(
                **{
                    col: subset[col].replace("healthy", "non-cancer")
                    for col in (pred_label, true_label)
                }
            )

        # Samples without a usable expected label cannot be scored.
        expected = cat_subset[true_label]
        known_pred = cat_subset[expected.notna() & ~expected.isin(["unknown", "other"])]
        if cat == LIFE_STAGE:
            is_cell_line = known_pred[BIOMAT] == "cell line"
            diff = is_cell_line.sum()
            known_pred = known_pred[~is_cell_line]
            print(f"Excluded cell lines for {cat} predictions: {diff}")

        N_known = known_pred.shape[0]
        if N_known == 0:
            # Nothing to score: avoid a division by zero and sklearn errors.
            print(f"{cat}: no samples with a known expected class, skipping\n")
            continue

        y_pred = known_pred[pred_label]
        y_true = known_pred[true_label]

        # Every label seen in either the expected or the predicted column, so
        # report rows and confusion-matrix rows cover true-only classes too.
        classes = sorted(set(y_true.unique()) | set(y_pred.unique()))

        N_correct = (y_pred == y_true).sum()
        print(f"{cat} prediction match (%): {N_correct/N_known*100:.2f}\n")
        print(classes)
        print(y_pred.value_counts(dropna=False), "\n")
        print(y_true.value_counts(dropna=False), "\n")

        # `labels` pins the row order so that `target_names` lines up with it.
        print(classification_report(y_true, y_pred, labels=classes, target_names=classes, zero_division=0) + "\n")  # type: ignore

        print(f"confusion matrix classes row order: {classes}")
        cm = sk_cm(y_true, y_pred, normalize="true", labels=classes)
        with np.printoptions(precision=3):
            print(str(cm) + "\n\n")

    print("-----")