In [None]:
"""Workbook to analyse classifier predictions on recount3 data."""

# pylint: disable=duplicate-code

In [None]:
# IPython autoreload: mode 2 re-imports edited modules before each cell runs,
# so changes to the project package are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

### SETUP

In [None]:
from __future__ import annotations

from pathlib import Path

import numpy as np
import pandas as pd
from IPython.display import display  # pylint: disable=unused-import
from sklearn.metrics import classification_report, confusion_matrix as sk_cm

from epi_ml.utils.notebooks.paper.metrics_per_assay import MetricsPerAssay
from epi_ml.utils.notebooks.paper.paper_utilities import (
    ASSAY,
    BIOMATERIAL_TYPE,
    CANCER,
    LIFE_STAGE,
    SEX,
    merge_life_stages,
)

In [None]:
# Root of the paper outputs; every other location below is derived from it.
base_dir = Path.home() / "Projects" / "epiclass" / "output" / "paper"
paper_dir = base_dir

# Output locations (figures, tables).
base_fig_dir = base_dir / "figures"
table_dir = base_dir / "tables"

# Input locations (data, recount3 metadata).
base_data_dir = base_dir / "data"
metadata_dir = base_data_dir / "metadata" / "recount3"

# Folder containing the prediction tables read by this notebook.
predictions_dir = table_dir / "dfreeze_v2" / "predictions"

In [None]:
# Load the recount3 prediction table (gzip-compressed, tab-separated) that
# corresponds to the metadata version named below.
meta_name = "harmonized_metadata_20250122_leuk2"
preds_filename = f"recount3_merged_preds_{meta_name}.tsv.gz"
preds_path = predictions_dir / preds_filename

full_df = pd.read_csv(preds_path, compression="gzip", sep="\t")
print(full_df.shape)

In [None]:
# Biomaterial-type values treated as (possible) cell lines when filtering.
cell_line_vals = [
    "cell_line",
    "cell line",
    "unknown",
]

In [None]:
# Sanity check of the life stage filtering (disabled; uncomment to print the counts after each filtering step).

# print(full_df[BIOMATERIAL_TYPE].value_counts(dropna=False))

# assay_pred_col = "Predicted class (assay_epiclass)"
# assay_max_pred_col = "Max pred (assay_epiclass)"

# cond1 = full_df[assay_pred_col].isin(["rna_seq", "mrna_seq"])
# cond2 = full_df[assay_max_pred_col] > 0.6
# df = full_df[cond1 & cond2]
# print("After 11c filtering (m/rna > 0.6)")
# print(df.shape)
# print(df[LIFE_STAGE].value_counts(dropna=False), "\n")

# cond3 = df[LIFE_STAGE] != "unknown"
# df = df[cond3]
# print(df.shape)
# print("After unknown filtering")
# print(df[LIFE_STAGE].value_counts(dropna=False), "\n")

# cond4 = ~(df[BIOMATERIAL_TYPE].isin(cell_line_vals))
# df = df[cond4]
# print("After cell line filtering")
# print(df.shape)
# print(df[LIFE_STAGE].value_counts(dropna=False), "\n")
# print(df[BIOMATERIAL_TYPE].value_counts(dropna=False), "\n")

### Assay predictions details

In [None]:
# Breakdown of the assay predictions for samples with a known assay label,
# at several minimum prediction-score thresholds.
# A prediction counts as "correct" when it is either rna_seq or mrna_seq.
assay_pred_col = f"Predicted class ({ASSAY})"
assay_max_pred_col = f"Max pred ({ASSAY})"
rna_labels = ("rna_seq", "mrna_seq")

assay_df = full_df[full_df[ASSAY] != "unknown"]
N = assay_df.shape[0]

for max_pred in [0, 0.6, 0.8]:
    subset = assay_df[assay_df[assay_max_pred_col] >= max_pred]
    counts = subset[assay_pred_col].value_counts()

    N_subset = counts.sum()
    counts_perc = counts / N_subset
    # .get(..., 0.0): a label that is never predicted in this subset counts as
    # 0% instead of raising a KeyError.
    correct_perc = sum(counts_perc.get(label, 0.0) for label in rna_labels)
    print(f"min_PredScore >= {max_pred} ({N_subset/N:.2%} left): {correct_perc:.2%}\n")

    print("Predictions grouped, assay types left as is")
    groupby = (
        subset.groupby([ASSAY, assay_pred_col])
        .size()
        .reset_index()
        .rename(columns={0: "Count"})
        .sort_values(by=[ASSAY, "Count"], ascending=[True, False])
    )
    print(groupby, "\n")

    print("Predictions grouped, all rna types = rna")
    tmp_df = subset.copy()
    tmp_df.loc[:, ASSAY] = "rna_seq"
    # Assign the result back: an in-place replace on a `.loc` selection acts on
    # a temporary object and may leave `tmp_df` untouched (always the case
    # under pandas Copy-on-Write).
    tmp_df[assay_pred_col] = tmp_df[assay_pred_col].replace("mrna_seq", "rna_seq")
    groupby = (
        tmp_df.groupby([ASSAY, assay_pred_col])
        .size()
        .reset_index()
        .rename(columns={0: "Count"})
        .sort_values(by=[ASSAY, "Count"], ascending=[True, False])
    )
    print(groupby, "\n")

    print("Breakdown by assay type")
    assay_breakdown = subset[ASSAY].value_counts(dropna=False)
    print(assay_breakdown / assay_breakdown.sum(), "\n")
    for assay_type in assay_breakdown.index:
        assay_type_subset = subset[subset[ASSAY] == assay_type].copy()

        counts = assay_type_subset[assay_pred_col].value_counts()
        N_subset = counts.sum()
        counts_perc = counts / N_subset
        # Same tolerant lookup as above: a given assay type may never be
        # predicted as one of the two RNA labels.
        correct_perc = sum(counts_perc.get(label, 0.0) for label in rna_labels)
        print(f"{assay_type} acc: {correct_perc:.2%}\n")
        print(f"{assay_type} preds:\n{counts_perc}\n")
    print()

### Other metadata categories

In [None]:
# Label distribution (missing values included) of each metadata category.
metadata_categories = [SEX, CANCER, LIFE_STAGE, BIOMATERIAL_TYPE]
for cat in metadata_categories:
    label_counts = full_df[cat].value_counts(dropna=False)
    print(label_counts, "\n")

In [None]:
# Per-category evaluation (sex, cancer, life stage, biomaterial type) at
# several minimum assay prediction-score thresholds: match rate, label
# distributions, classification report and row-normalized confusion matrix.
# `df` is reused by the metrics cells further down.
df = full_df.copy(deep=True)

for max_pred in [0, 0.6, 0.8]:
    subset = df[df[f"Max pred ({ASSAY})"] >= max_pred]
    print(f"min_PredScore >= {max_pred}\n")

    for cat in [SEX, CANCER, LIFE_STAGE, BIOMATERIAL_TYPE]:
        pred_label = f"Predicted class ({cat})"
        true_label = f"Expected class ({cat})"

        # Only samples with a usable expected label can be scored.
        known_pred = subset[~subset[true_label].isin(["unknown", "other"])]
        if cat == LIFE_STAGE:
            # NOTE(review): only the exact value "cell line" is excluded here,
            # whereas the metrics cells below also drop "unknown" biomaterial
            # types. Confirm this difference is intended.
            diff = len(known_pred)
            known_pred = known_pred[known_pred[BIOMATERIAL_TYPE] != "cell line"]
            diff -= len(known_pred)
            print(f"Excluded cell lines for {cat} predictions: {diff}")

        y_pred = known_pred[pred_label]
        y_true = known_pred[true_label]
        if cat == CANCER:
            # Harmonize the label on the two compared series only. Replacing on
            # the whole frame would rewrite every column and persist into the
            # categories processed after this one.
            y_pred = y_pred.replace("healthy", "non-cancer")
            y_true = y_true.replace("healthy", "non-cancer")

        # Normalize case and spacing so that both sides use the same spelling.
        y_pred = y_pred.str.lower().str.replace(" ", "_")
        y_true = y_true.str.lower().str.replace(" ", "_")

        N_known = known_pred.shape[0]
        N_unknown = subset.shape[0] - N_known
        # print(f"Unknown (%): {(N_unknown)/subset.shape[0]*100:.2f}")

        if N_known == 0:
            # Nothing to score: avoid a division by zero and a scikit-learn
            # failure on empty inputs.
            print(f"{cat}: no sample with a known expected class, skipping.\n")
            continue

        classes = sorted(set(y_pred.unique()) | set(y_true.unique()))

        N_correct = (y_pred == y_true).sum()
        print(f"{cat} prediction match (%): {N_correct/N_known*100:.2f}\n")
        print(classes)
        print(y_pred.value_counts(dropna=False), "\n")
        print(y_true.value_counts(dropna=False), "\n")

        # `labels` is passed explicitly so that `target_names` is guaranteed to
        # line up with the label order used by scikit-learn.
        report = classification_report(
            y_true,
            y_pred,
            labels=classes,
            target_names=classes,
            zero_division=0,
        )
        print(report + "\n")  # type: ignore

        print(f"confusion matrix classes row order: {classes}")
        cm = sk_cm(y_true, y_pred, normalize="true", labels=classes)
        with np.printoptions(precision=3):
            print(str(cm) + "\n\n")

    print("-----")

### Accuracy and F1-score summary.

In [None]:
# Shape check of the working copy `df` before computing the metric tables.
print(df.shape)

In [None]:
# Single handler instance reused by every metrics computation below.
metrics_handler = MetricsPerAssay()

In [None]:
# Metric tables are written next to the prediction tables they are computed from.
output_dir = predictions_dir / "metrics"

Metrics computed on all files (no filtering).

In [None]:
# Shared configuration for the per-assay metrics computations below.
categories = [CANCER, SEX, BIOMATERIAL_TYPE]

# Column-name templates; the "{}" placeholder is presumably filled with the
# category name by the metrics handler (to be confirmed against its code).
column_templates = dict(
    [
        ("True", "Expected class ({})"),
        ("Predicted", "Predicted class ({})"),
        ("Max pred", "Max pred ({})"),
    ]
)

# Every assay label present in the data is passed as a core assay.
observed_assays = df[ASSAY].unique().tolist()

compute_fct_kwargs = dict(
    no_epiatlas=False,
    merge_assays=False,
    categories=categories,
    column_templates=column_templates,
    core_assays=observed_assays,
    non_core_assays=[],
)

In [None]:
# Metrics on the complete prediction table; results are written to `output_dir`.
base_filename = "recount3_metrics_per_assay"

metrics_handler.compute_multiple_metric_formats(
    preds=df,
    general_filename=base_filename,
    folders_to_save=[output_dir],
    compute_fct_kwargs=compute_fct_kwargs,
    return_df=False,
    verbose=False,
)

Only files where the assay prediction is (m)rna-seq with predScore >= 0.6.

In [None]:
# Metrics restricted to samples confidently predicted as (m)RNA-seq.
base_filename = "recount3_metrics_per_assay_assay11c-filtered"

is_confident = df[f"Max pred ({ASSAY})"] >= 0.6
is_rna_prediction = df[f"Predicted class ({ASSAY})"].isin(["rna_seq", "mrna_seq"])
filtered_df = df[is_confident & is_rna_prediction]

metrics_handler.compute_multiple_metric_formats(
    preds=filtered_df,  # type: ignore
    general_filename=base_filename,
    folders_to_save=[output_dir],
    compute_fct_kwargs=compute_fct_kwargs,
    return_df=False,
    verbose=False,
)

Merging messenger and total RNA for a new assay_epiclass label.

In [None]:
# Recompute the metrics after collapsing mRNA-seq and total RNA-seq into one
# assay label, for both the assay-filtered predictions and the full table.
merged_rna_label = "messenger_or_total_rna"
assay_merge_map = {"mrna_seq": merged_rna_label, "rna_seq": merged_rna_label}

# The loop variable is deliberately not called `df`, so the notebook-level `df`
# used by other cells is not overwritten.
for source_df, filename in zip(
    [filtered_df, full_df],
    [
        "recount3_metrics_per_assay_merge_total_mrna_assay11c-filtered",
        "recount3_metrics_per_assay_merge_total_mrna",
    ],
):
    # Work on a copy and assign the replaced column back: an in-place replace
    # on a column selection may not propagate to the frame (never under pandas
    # Copy-on-Write).
    merged_assay_df = source_df.copy()
    merged_assay_df[ASSAY] = merged_assay_df[ASSAY].replace(assay_merge_map)

    # Per-iteration copy of the shared kwargs, so the merged assay list does
    # not leak into later cells that work on the unmerged assay labels.
    merged_kwargs = {
        **compute_fct_kwargs,
        "core_assays": merged_assay_df[ASSAY].unique().tolist(),
    }

    # Pass the frame prepared above (the original passed a stale `tmp_df` left
    # over from the assay breakdown cell).
    metrics_handler.compute_multiple_metric_formats(
        preds=merged_assay_df,  # type: ignore
        folders_to_save=[output_dir],
        general_filename=filename,
        verbose=False,
        return_df=False,
        compute_fct_kwargs=merged_kwargs,
    )

Excluding cell lines (for life stage metrics).

In [None]:
# Switch the shared kwargs to the life stage categories (raw and merged).
compute_fct_kwargs["categories"] = [LIFE_STAGE, f"{LIFE_STAGE}_merged"]

# Gotta exclude 'unknown' biomaterial type since it could be cell lines
excluded_biomaterials = ["cell line", "unknown"]
is_excluded = filtered_df[BIOMATERIAL_TYPE].isin(excluded_biomaterials)
no_cell_line_df = filtered_df[~is_excluded].copy()

biomaterial_counts = no_cell_line_df[BIOMATERIAL_TYPE].value_counts(dropna=False)
life_stage_counts = no_cell_line_df[LIFE_STAGE].value_counts(dropna=False)
print(biomaterial_counts, "\n")
print(life_stage_counts)

In [None]:
# Apply the life stage merge to the expected / predicted / max-pred columns and
# to the raw category column (the bare "{}" template). The helper presumably
# adds the "<life stage>_merged" columns read just below (to be confirmed).
life_stage_templates = [*column_templates.values(), "{}"]
no_cell_line_df = merge_life_stages(
    no_cell_line_df,
    column_name_templates=life_stage_templates,
)

merged_life_stage_col = f"{LIFE_STAGE}_merged"
print(no_cell_line_df[merged_life_stage_col].value_counts(dropna=False))

In [None]:
# Life stage metrics on the assay-filtered samples, cell lines excluded.
base_filename = "recount3_metrics_per_assay_assay11c-filtered_no_cell_line"

metrics_handler.compute_multiple_metric_formats(
    preds=no_cell_line_df,  # type: ignore
    general_filename=base_filename,
    folders_to_save=[output_dir],
    compute_fct_kwargs=compute_fct_kwargs,
    return_df=False,
    verbose=False,
)