In [None]:
"""Workbook to analyse classifier predictions on recount3 data."""

# pylint: disable=duplicate-code

In [None]:
%load_ext autoreload
%autoreload 2

## SETUP

In [None]:
from __future__ import annotations

from pathlib import Path

import pandas as pd
from IPython.display import display  # pylint: disable=unused-import

from epi_ml.utils.notebooks.paper.metrics_per_assay import MetricsPerAssay
from epi_ml.utils.notebooks.paper.paper_utilities import (
    ASSAY,
    BIOMATERIAL_TYPE,
    CANCER,
    LIFE_STAGE,
    SEX,
    check_label_coherence,
    format_labels,
    merge_life_stages,
    rename_columns,
)

In [None]:
# Root of the paper outputs; every other location below is derived from it.
base_dir = Path.home().joinpath("Projects/epiclass/output/paper")
paper_dir = base_dir

# Top-level output folders.
base_fig_dir = base_dir.joinpath("figures")
table_dir = base_dir.joinpath("tables")
base_data_dir = base_dir.joinpath("data")

# recount3 metadata folder and the folder holding the prediction tables.
metadata_dir = base_data_dir.joinpath("metadata", "recount3")
preds_dir = table_dir.joinpath("dfreeze_v2", "predictions")

In [None]:
# Merged recount3 predictions + metadata table (xz-compressed, comma-separated).
full_preds_path = preds_dir / "recount3_merged_preds_metadata_freeze1.csv.xz"

# Missing values and "indeterminate" entries are both normalised to "unknown".
full_df = (
    pd.read_csv(full_preds_path, sep=",", compression="xz", low_memory=False)
    .fillna("unknown")
    .replace("indeterminate", "unknown")
)

In [None]:
# Biomaterial-type labels that are excluded by the "no cell line" analysis.
cell_line_vals = [
    "cell_line",
    "cell line",
    "unknown",
    "other",
]

### Reformat labels/columns

In [None]:
# Drop non-relevant columns: any column whose name contains one of these markers.
unwanted_markers = ("extracted_term", "combined_")
to_drop = [
    col for col in full_df.columns if any(marker in col for marker in unwanted_markers)
]
full_df = full_df.drop(columns=to_drop)

In [None]:
# Rename columns: map the short category names used in the "expected_*" column
# names to the category names imported from paper_utilities.
cat_remapping = {
    "assay": ASSAY,
    "sex": SEX,
    "cancer": CANCER,
    "lifestage": LIFE_STAGE,
    "biomat": BIOMATERIAL_TYPE,
}
categories = list(cat_remapping)
proper_categories = list(cat_remapping.values())

to_rename = {
    f"expected_{short_name}": f"Expected class ({proper_name})"
    for short_name, proper_name in cat_remapping.items()
}
full_df = rename_columns(full_df, to_rename, exact_match=True, verbose=True)

In [None]:
all_categories = proper_categories

# Column-name templates; "{}" is filled with a category name.
column_templates = {
    "True": "Expected class ({})",
    "Predicted": "Predicted class ({})",
}

# Expected/predicted column pair for every category, in category order.
all_columns = [
    column_templates[kind].format(category)
    for category in all_categories
    for kind in ("True", "Predicted")
]

full_df = format_labels(full_df, all_columns)

In [None]:
# Post-modification check
# ASSAY is left out of the coherence check. A new list is built instead of
# calling `.remove()`: `all_categories` is the same list object as
# `proper_categories`, so an in-place removal would also alter that list and
# would raise ValueError when this cell is run a second time.
all_categories = [category for category in all_categories if category != ASSAY]
check_label_coherence(full_df, all_categories, column_templates)

## Computing metrics

### Assay predictions details

In [None]:
true_col_label = column_templates["True"].format(ASSAY)
pred_col_label = column_templates["Predicted"].format(ASSAY)
max_pred_label = f"Max pred ({ASSAY})"

orderby_cols = [true_col_label, pred_col_label]

# Predicted labels that count as a correct prediction in the accuracy figures below.
rna_pred_labels = ["rna_seq", "mrna_seq"]


def _rna_fraction(pred_fractions: pd.Series) -> float:
    """Return the summed fraction of rna_seq and mrna_seq predictions.

    A label that is absent from `pred_fractions` counts as 0 instead of
    raising KeyError (e.g. a subset where nothing was predicted as mrna_seq).
    """
    return float(pred_fractions.reindex(rna_pred_labels, fill_value=0).sum())


def _count_label_pairs(preds: pd.DataFrame) -> pd.DataFrame:
    """Count rows per (expected, predicted) assay label pair.

    The result is sorted by expected label, then by descending count.
    """
    return (
        preds.groupby(orderby_cols)
        .size()
        .reset_index()
        .rename(columns={0: "Count"})
        .sort_values(by=[true_col_label, "Count"], ascending=[True, False])
    )


# Only keep rows whose expected assay label is known.
mask = full_df[true_col_label].isin(["unknown"])
assay_df = full_df[~mask].copy()

N = assay_df.shape[0]

# Report the same breakdown at increasing minimum prediction-score thresholds.
for max_pred in [0, 0.6, 0.8]:
    subset = assay_df[assay_df[max_pred_label] >= max_pred]
    counts = subset[pred_col_label].value_counts(dropna=False)

    N_subset = counts.sum()
    counts_perc = counts / N_subset
    correct_perc = _rna_fraction(counts_perc)
    print(f"min_PredScore >= {max_pred} ({N_subset/N:.2%} left): {correct_perc:.2%}\n")

    print("Predictions grouped, assay types left as is")
    groupby = _count_label_pairs(subset)
    print(groupby, "\n")

    print("Predictions grouped, all rna types = rna")
    tmp_df = subset.copy()
    tmp_df[true_col_label] = "rna_seq"
    # Assign the replaced column back: an in-place replace on the Series
    # returned by `.loc[:, col]` is not guaranteed to update the frame.
    tmp_df[pred_col_label] = tmp_df[pred_col_label].replace("mrna_seq", "rna_seq")
    groupby = _count_label_pairs(tmp_df)
    print(groupby, "\n")

    print("Breakdown by assay type")
    assay_breakdown = subset[true_col_label].value_counts(dropna=False)
    print(assay_breakdown / assay_breakdown.sum(), "\n")
    for assay_type in assay_breakdown.index:
        assay_type_subset = subset[subset[true_col_label] == assay_type]

        counts = assay_type_subset[pred_col_label].value_counts()
        N_subset = counts.sum()
        counts_perc = counts / N_subset
        correct_perc = _rna_fraction(counts_perc)
        print(f"{assay_type} acc: {correct_perc:.2%}\n")
        print(f"{assay_type} preds:\n{counts_perc}\n")
    print()

### Accuracy and F1-score summary.

In [None]:
# Independent snapshot of the predictions table (DataFrame.copy is deep by default).
df = full_df.copy()
print(df.shape)

In [None]:
# Handler whose `compute_multiple_metric_formats` method is called by the metric cells below.
metrics_handler = MetricsPerAssay()

In [None]:
# Folder passed as `folders_to_save` to every metrics computation below.
output_dir = table_dir.joinpath("dfreeze_v2", "predictions", "metrics")

All files (no filtering on the assay prediction)

In [None]:
categories = [CANCER, SEX, BIOMATERIAL_TYPE]

column_templates["Max pred"] = "Max pred ({})"

# ASSAY needs to exist in full_df
full_df[ASSAY] = full_df[true_col_label]

# Keyword arguments forwarded through `compute_fct_kwargs` to
# `metrics_handler.compute_multiple_metric_formats`.
compute_fct_kwargs = {
    "no_epiatlas": False,
    "merge_assays": False,
    "categories": categories,
    "column_templates": column_templates,
    # Read the labels from full_df, the frame actually passed as `preds`, so
    # this cell does not depend on the state of the separate `df` variable.
    "core_assays": full_df[true_col_label].unique().tolist(),
    "non_core_assays": [],  # no "non-core" assays
}

In [None]:
# Metrics computed on every row of full_df, without filtering on the assay prediction.
base_filename = "recount3_metrics_per_assay"

all_files_preds = full_df.copy()
metrics_handler.compute_multiple_metric_formats(
    preds=all_files_preds,
    folders_to_save=[output_dir],
    general_filename=base_filename,
    compute_fct_kwargs=compute_fct_kwargs,
    return_df=False,
    verbose=False,
)

Only files where Assay predictions are (m)rna-seq and predScore >= 0.6

In [None]:
base_filename = "recount3_metrics_per_assay_assay11c-filtered"

# Row filters: assay max prediction score of at least 0.6, and an assay
# prediction that is either rna_seq or mrna_seq.
confident_assay_pred = full_df[max_pred_label] >= 0.6
rna_like_pred = full_df[pred_col_label].isin(["rna_seq", "mrna_seq"])

print(full_df.shape)
filtered_df = full_df.loc[confident_assay_pred & rna_like_pred].copy()
print(filtered_df.shape)

metrics_handler.compute_multiple_metric_formats(
    preds=filtered_df.copy(),  # type: ignore
    folders_to_save=[output_dir],
    general_filename=base_filename,
    compute_fct_kwargs=compute_fct_kwargs,
    return_df=False,
    verbose=False,
)

Life stage metrics without cell lines (samples whose expected biomaterial type is cell line, unknown or other are excluded)

In [None]:
# Same settings as the previous metric runs, but evaluated on the merged life stage category only.
new_compute_fct_kwargs = compute_fct_kwargs.copy()
new_compute_fct_kwargs["categories"] = [f"{LIFE_STAGE}_merged"]

biomat_col = column_templates["True"].format(BIOMATERIAL_TYPE)

# A dedicated loop variable is used so that the notebook-level `df` snapshot
# created earlier is not overwritten by this loop.
for preds_df, filename in zip(
    [filtered_df, full_df],
    [
        "recount3_metrics_per_assay_assay11c-filtered_no_cell_line",
        "recount3_metrics_per_assay_no_cell_line",
    ],
):
    print(filename)
    # Drop rows whose expected biomaterial type is listed in cell_line_vals.
    # The explicit copy guarantees that merge_life_stages never operates on
    # filtered_df / full_df themselves.
    no_cell_line_df = preds_df[~preds_df[biomat_col].isin(cell_line_vals)].copy()
    print(no_cell_line_df.shape)

    merged_df = merge_life_stages(
        df=no_cell_line_df,
        lifestage_column_name=LIFE_STAGE,
        column_name_templates=list(column_templates.values()),
        verbose=True,
    )

    metrics_handler.compute_multiple_metric_formats(
        preds=merged_df,  # type: ignore
        folders_to_save=[output_dir],
        general_filename=filename,
        verbose=False,
        return_df=False,
        compute_fct_kwargs=new_compute_fct_kwargs,
    )