In [None]:
"""Plot accuracy, precision, and subset size for different probability thresholds."""

# pylint: disable=line-too-long, redefined-outer-name, import-error, pointless-statement, use-dict-literal, expression-not-assigned, unused-import, too-many-lines, too-many-branches

## SETUP

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from __future__ import annotations

from pathlib import Path
from typing import Dict, List, Sequence, Tuple

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from IPython.display import display
from PIL import ImageColor
from plotly.subplots import make_subplots
from sklearn.metrics import accuracy_score, f1_score

from epi_ml.utils.general_utility import get_valid_filename
from epi_ml.utils.notebooks.paper.paper_utilities import (
    ASSAY,
    ASSAY_MERGE_DICT,
    ASSAY_ORDER,
    BIOMATERIAL_TYPE,
    CANCER,
    CELL_TYPE,
    LIFE_STAGE,
    SEX,
    SplitResultsHandler,
    format_labels,
    merge_life_stages,
    rename_columns,
)

In [None]:
# Type alias for threshold collections: a NumPy float array or any sequence of Python/NumPy floats.
float_seq = np.typing.NDArray[np.floating] | Sequence[float | np.floating]

In [None]:
base_dir = Path.home() / "Projects/epiclass/output/paper"
paper_dir = base_dir
if not paper_dir.exists():
    raise FileNotFoundError(f"Directory {paper_dir} does not exist.")

base_data_dir = base_dir / "data"
base_fig_dir = base_dir / "figures"
table_dir = paper_dir / "tables"

In [None]:
core7_assays = ASSAY_ORDER[:7]
core9_assays = ASSAY_ORDER

## Confidence threshold impact on accuracy

In [None]:
# DB: {"results": dict, "other_info": dict}
all_threshold_results = {}

### Graphing and co. functions

In [None]:
def compute_metrics(
    df: pd.DataFrame,
    threshold: float,
    true_col: str,
    pred_col: str,
    pred_prob_cols: List[str],
    target_class: str | None,
) -> Tuple[float, float, float, float]:
    """
    Compute accuracy, precision, and subset size for a given probability threshold and class.

    Parameters:
    df (pd.DataFrame): The input DataFrame containing the true labels, predicted labels, and predicted probabilities.
    threshold (float): The probability threshold for filtering the DataFrame.
    true_col (str): The column name containing the true labels.
    pred_col (str): The column name containing the predicted labels.
    pred_prob_cols (List[str]): List of column names containing the predicted probabilities.
    target_class (str|None): The class for which precision is to be calculated. Return np.nan if None.

    Considers target class for computations if given, otherwise considers all samples.

    Returns:
    Tuple[float, float, float, float]: A tuple containing the threshold, the calculated accuracy (%), the calculated precision (%),
                                       and the subset size (%) respectively.
    """
    # Targeting a class or not
    if target_class in [None, "all"]:
        total_size = len(df)
    else:
        total_size = len(df[true_col] == target_class)

    # Filter rows where the max predicted probability is above the threshold
    try:
        subset_df = df[df[pred_prob_cols].max(axis=1) >= threshold]
    except TypeError as e:
        print(
            f"Error: Could not filter rows.\npred_cols: {pred_prob_cols}\nthreshold: {threshold}"
        )
        raise e

    if len(subset_df) == 0:
        return np.nan, np.nan, np.nan, np.nan

    # Calculate the accuracy for this subset
    if target_class in [None, "all"]:
        correct_preds = np.sum(subset_df[true_col] == subset_df[pred_col])
        subset_size = len(subset_df)
    else:
        correct_preds = np.sum(
            (subset_df[true_col] == subset_df[pred_col])
            & (subset_df[true_col] == target_class)
        )
        subset_size = np.sum(subset_df[true_col] == target_class)
    accuracy = (correct_preds / subset_size) * 100
    subset_size_percent = (subset_size / total_size) * 100

    # Calculate precision for the target class
    if target_class in [None, "all"]:
        precision = np.nan
        return threshold, accuracy, precision, subset_size_percent

    true_positives = np.sum(
        (subset_df[true_col] == target_class) & (subset_df[pred_col] == target_class)
    )
    false_positives = np.sum(
        (subset_df[true_col] != target_class) & (subset_df[pred_col] == target_class)
    )

    if true_positives + false_positives == 0:
        precision = np.nan
    else:
        precision = (true_positives / (true_positives + false_positives)) * 100

    return threshold, accuracy, precision, subset_size_percent

In [None]:
def compute_metrics_global(
    df: pd.DataFrame,
    threshold: float | np.floating,
    true_col: str,
    pred_col: str,
    pred_prob_cols: List[str] | str,
) -> Tuple[float, float, float, float]:
    """
    Compute accuracy, macro F1-score, and relative subset size for a given probability threshold.

    Rows are kept when their maximum value over `pred_prob_cols` is >= `threshold`.

    Parameters:
    df (pd.DataFrame): The input DataFrame containing the true labels, predicted labels, and predicted probabilities.
    threshold (float): The probability threshold for filtering the DataFrame.
    true_col (str): The column name containing the true labels.
    pred_col (str): The column name containing the predicted labels.
    pred_prob_cols (List[str]|str): List of column names containing the predicted probabilities,
                                    OR the name of a single Max PredScore column.

    Returns:
    Tuple[float, float, float, float]: (threshold, accuracy, macro F1-score, relative subset size).
        All three metrics are fractions in [0, 1], not percentages.
        The macro F1-score averages over every label present in `df[true_col]`, including
        labels absent from the thresholded subset.
        When no row passes the threshold, the three metrics are np.nan (the threshold itself
        is still returned).
    """
    total_size = len(df)

    # Filter rows where the max predicted probability is above the threshold.
    # Normally expecting a matrix of probabilities, but a single Max PredScore column also works.
    if isinstance(pred_prob_cols, str):
        pred_prob_cols = [pred_prob_cols]
    try:
        subset_df = df[df[pred_prob_cols].max(axis=1) >= threshold]
    except TypeError as e:
        print(
            f"Error: Could not filter rows.\npred_cols: {pred_prob_cols}\nthreshold: {threshold}"
        )
        raise e

    N = len(subset_df)
    if N == 0:
        # Keep the threshold so the row remains identifiable in result tables.
        return float(threshold), np.nan, np.nan, np.nan

    # Metrics; labels come from the full df so the macro average is over the same classes at every threshold.
    existing_labels = sorted(df[true_col].unique())
    acc: float = accuracy_score(subset_df[true_col], subset_df[pred_col])  # type: ignore
    f1: float = f1_score(subset_df[true_col], subset_df[pred_col], average="macro", labels=existing_labels)  # type: ignore
    relative_size = N / total_size

    return float(threshold), acc, f1, relative_size

In [None]:
ACCURACY_NAME = "rec"
PRECISION_NAME = "prec"
SUBSET_SIZE_NAME = "sz"


def find_columns(df: pd.DataFrame, verbose: bool = False) -> Dict[str, List[str] | str]:
    """
    Find the columns containing true labels, predicted labels, and predicted probabilities in a DataFrame.
    """
    df_cols = df.columns
    df_cols = [col for col in df_cols if str(col) not in ["TRUE", "FALSE"]]

    likely_true_class_cols = [
        col for col in df_cols if "true" in col.lower() or "expected" in col.lower()
    ]
    likely_pred_class_cols = [col for col in df_cols if "pred" in col.lower()]

    if not likely_true_class_cols or not likely_pred_class_cols:
        raise ValueError(
            "Could not automatically detect 'True class' or 'Predicted class' columns."
        )

    true_col = likely_true_class_cols[0]
    pred_col = likely_pred_class_cols[0]
    if df[true_col].dtype != object or df[pred_col].dtype != object:
        print(f"{true_col} and {pred_col} are not string columns. Could cause issues.")

    if verbose:
        print(f"True class: {true_col}")
        print(f"Predicted class: {pred_col}")

    classes = df[true_col].unique().tolist() + ["all"]
    pred_prob_cols = classes[0:-1]

    if verbose:
        print(f"Classes: {classes}")
        print(f"Predicted probability columns: {pred_prob_cols}")

    for col in pred_prob_cols:
        if df[col].dtype != float:
            print(f"{col} is not a float column ({df[col].dtype}). Could cause issues.")

    return {
        "true_col": true_col,
        "pred_col": pred_col,
        "classes": classes,
        "pred_prob_cols": pred_prob_cols,
    }


def evaluate_thresholds(
    df: pd.DataFrame, thresholds: List[float], verbose: bool = False
) -> Dict[str, pd.DataFrame]:
    """
    Evaluate accuracy, precision and subset size at each probability threshold, per class.

    Columns are detected automatically with `find_columns`. For each class label, the rows
    whose true OR predicted label is that class are evaluated with `compute_metrics`; the
    special label "all" uses every row.

    Parameters:
    df (pd.DataFrame): The dataframe containing true labels and predicted probabilities.
    thresholds (list): List of probability thresholds to evaluate.
    verbose (bool): Whether to print the detected columns.

    Returns:
    Dict[str, pd.DataFrame]: One dataframe per class label (plus "all"), each with a
    "Threshold" column followed by accuracy, precision and subset-size columns.
    """
    detected = find_columns(df, verbose=verbose)
    true_col: str = detected["true_col"]  # type: ignore
    pred_col: str = detected["pred_col"]  # type: ignore
    classes: List[str] = detected["classes"]  # type: ignore
    pred_prob_cols: List[str] = detected["pred_prob_cols"]  # type: ignore

    results_dfs = {}
    for class_label in classes:
        if class_label == "all":
            class_df = df
        else:
            # Keep rows involving the class on either side, so false positives count for precision.
            involves_class = (df[true_col] == class_label) | (df[pred_col] == class_label)
            class_df = df[involves_class]

        rows = []
        for thresh in thresholds:
            try:
                metrics = compute_metrics(
                    class_df,
                    thresh,
                    true_col,
                    pred_col,
                    pred_prob_cols,
                    target_class=class_label,
                )
            except Exception as e:
                print(
                    f"Error. Could not compute metric with class {class_label}.\ntrue_col: {true_col}\npred_col: {pred_col}\npred_prob_cols: {pred_prob_cols}\n"
                )
                raise e
            rows.append(metrics)

        short_label = class_label[0:10]
        column_names = [
            "Threshold",
            f"{ACCURACY_NAME}_{short_label} (%)",
            f"{PRECISION_NAME}_{short_label} (%)",
            f"{SUBSET_SIZE_NAME}_{short_label} (%) ({class_df.shape[0]})",
        ]
        results_dfs[class_label] = pd.DataFrame(rows, columns=column_names)

    return results_dfs

In [None]:
def evaluate_thresholds_global(
    df: pd.DataFrame,
    thresholds: float_seq,
    verbose: bool = False,
    columns: Dict[str, List[str] | str] | None = None,
) -> pd.DataFrame:
    """
    Evaluate global (all-class) metrics at each prediction-score threshold.

    Parameters:
        df (pd.DataFrame): The dataframe containing true labels and predicted probabilities.
        thresholds (sequence of float): Thresholds to evaluate, one output row each, in order.
        verbose (bool): Passed to `find_columns`; only used when `columns` is None.
        columns (dict | None): Column names to use instead of automatic detection.
                               Expecting entries "true_col", "pred_col", and either
                               "pred_prob_cols" (probability columns) or "max_pred"
                               (a single max-score column).
    Returns:
        pd.DataFrame: One row per threshold with columns "Threshold", "Accuracy (%)",
        "F1-score" and "Subset size (%) (<number of rows of df>)", filled by
        `compute_metrics_global`.
    """
    if columns is None:
        columns = find_columns(df, verbose=verbose)

    true_col: str = columns["true_col"]  # type: ignore
    pred_col: str = columns["pred_col"]  # type: ignore
    if "pred_prob_cols" in columns:
        pred_prob_cols: List[str] = columns["pred_prob_cols"]  # type: ignore
    else:
        pred_prob_cols = [columns["max_pred"]]  # type: ignore

    def _metrics_at(threshold):
        # Log the column setup before re-raising, to ease debugging of bad column mappings.
        try:
            return compute_metrics_global(df, threshold, true_col, pred_col, pred_prob_cols)
        except Exception as e:
            print(
                f"Error. Could not compute metrics.\ntrue_col: {true_col}\npred_col: {pred_col}\npred_prob_cols: {pred_prob_cols}\n"
            )
            raise e

    rows = [_metrics_at(threshold) for threshold in thresholds]

    return pd.DataFrame(
        rows,
        columns=[
            "Threshold",
            "Accuracy (%)",
            "F1-score",
            f"Subset size (%) ({df.shape[0]})",
        ],
    )

In [None]:
def create_thresholds_graph_global_plotly(
    metrics_df: pd.DataFrame, name: str, xrange: Tuple[float, float] | None = None
):
    """
    Return graph of the accuracy, F1-score and subset size at different probability thresholds for global results.

    Accuracy and F1-score share the left y-axis; subset size is drawn on a secondary
    right y-axis whose range spans from just below its minimum value up to 1.001.

    Parameters:
    metrics_df (pd.DataFrame): DataFrame with a "Threshold" column and columns whose names contain
                               "Acc", "F1" and "Subset" (the first match of each is used).
    name (str): Appended to the graph title, on a second line.
    xrange (Tuple[float, float] | None): x-axis range; (-0.001, 1.001) when not given.

    Returns:
    go.Figure: Plotly figure object with the plotted graph.
    """
    # Color-blind friendly colors
    black, blue, red = "#000000", "#005AB5", "#DC3220"

    acc_label = metrics_df.filter(like="Acc").columns[0]
    f1_score_label = metrics_df.filter(like="F1").columns[0]
    subset_size_label = metrics_df.filter(like="Subset").columns[0]

    thresholds_x = metrics_df["Threshold"]
    subset_vals = metrics_df[subset_size_label]

    trace_specs = [
        # Accuracy: solid red line, open squares
        dict(
            y=metrics_df[acc_label],
            name=acc_label,
            line=dict(color=red),
            marker_symbol="square-open",
        ),
        # F1-score: dotted blue line, open crosses
        dict(
            y=metrics_df[f1_score_label],
            name=f1_score_label,
            line=dict(color=blue, dash="dot"),
            marker_symbol="cross-open",
        ),
        # Subset size: dashed black line, circles, on the secondary y-axis
        dict(
            y=subset_vals,
            name=subset_size_label.split("(")[0].strip(),
            line=dict(color=black, dash="dash"),
            marker_symbol="circle",
            yaxis="y2",
        ),
    ]

    fig = go.Figure()
    for spec in trace_specs:
        fig.add_trace(go.Scatter(x=thresholds_x, mode="lines+markers", **spec))

    tick_positions = np.linspace(0, 1, 11)
    fig.update_layout(
        title=f"Metrics at Different Pred. Score Thresholds<br>{name}",
        xaxis_title="Prediction Score Threshold",
        xaxis=dict(
            tickvals=tick_positions,
            ticktext=[f"{tick:.1f}" for tick in tick_positions],
        ),
        yaxis_title="Accuracy / F1-score (%)",
        yaxis2=dict(title="Subset Size (%)", overlaying="y", side="right"),
        legend=dict(orientation="v", x=1.1, y=1),
        height=500,
        width=500,
        yaxis2_range=[subset_vals.min() - 0.001, 1.001],
    )

    fig.update_xaxes(range=xrange or (-0.001, 1.001))
    fig.update_traces(line={"width": 1})

    return fig

In [None]:
def create_thresholds_graph_plotly(threshold_dfs: Dict[str, pd.DataFrame], name: str):
    """
    Return graph of the accuracy, precision and subset size at different probability thresholds for all classes.

    Each class gets one color: a solid line for accuracy, a dotted line for precision
    (skipped when all-NaN), and a dashed line on a secondary y-axis for subset size.

    Parameters:
    threshold_dfs (Dict[str, pd.DataFrame]): One dataframe per class label (and "all"), as produced by
        `evaluate_thresholds`: a "Threshold" column plus metric columns whose names start with
        f"{ACCURACY_NAME}_", f"{PRECISION_NAME}_" and f"{SUBSET_SIZE_NAME}_".
    name (str): Appended to the graph title, on a second line.

    Returns:
    go.Figure: Plotly figure object with the plotted graph.
    """
    colors = px.colors.qualitative.Dark24
    marker1 = "circle"
    marker2 = "cross-open"
    marker3 = "circle-open"

    def _metric_column(metrics: pd.DataFrame, prefix: str) -> str:
        # Match on the "<prefix>_" start of the name: a substring search can pick the wrong
        # column, e.g. "rec" is inside "prec_...", and a class label like "precursor"
        # puts "prec" inside the accuracy column name.
        return [col for col in metrics.columns if str(col).startswith(f"{prefix}_")][0]

    fig = go.Figure()
    for idx, threshold_metrics in enumerate(threshold_dfs.values()):
        color = colors[idx % len(colors)]

        acc_label = _metric_column(threshold_metrics, ACCURACY_NAME)
        acc_subset = _metric_column(threshold_metrics, SUBSET_SIZE_NAME)
        prec_label = _metric_column(threshold_metrics, PRECISION_NAME)

        # Plot accuracy
        fig.add_trace(
            go.Scatter(
                x=threshold_metrics["Threshold"],
                y=threshold_metrics[acc_label],
                name=acc_label,
                line=dict(color=color),
                marker_symbol=marker1,
                mode="lines+markers",
            )
        )

        # Plot precision (all-NaN for the "all" case, where precision is undefined)
        prec_vals = threshold_metrics[prec_label]
        if not prec_vals.isna().all():
            fig.add_trace(
                go.Scatter(
                    x=threshold_metrics["Threshold"],
                    y=prec_vals,
                    name=prec_label,
                    line=dict(color=color, dash="dot"),
                    marker_symbol=marker2,
                    mode="lines+markers",
                )
            )

        # Plot subset size on secondary Y-axis
        fig.add_trace(
            go.Scatter(
                x=threshold_metrics["Threshold"],
                y=threshold_metrics[acc_subset],
                name=acc_subset,
                line=dict(color=color, dash="dash"),
                marker_symbol=marker3,
                yaxis="y2",
                mode="lines+markers",
            )
        )

    # Adjusting the layout
    fig.update_layout(
        title=f"Accuracy and Subset Size at Different Probability Thresholds<br>{name}",
        xaxis_title="Probability Threshold",
        xaxis=dict(
            tickvals=np.linspace(0, 1, 11),
            ticktext=[f"{x:.1f}" for x in np.linspace(0, 1, 11)],
        ),
        yaxis_title="Accuracy (%)",
        yaxis2=dict(title="Subset Size (%)", overlaying="y", side="right"),
        legend=dict(orientation="v", x=1.05, y=1),
        height=1000,
        width=1600,
    )
    fig.update_xaxes(range=[-0.001, 1.001])
    fig.update_traces(line={"width": 1})

    return fig

In [None]:
# 20 evenly spaced thresholds 0.00, 0.05, ..., 0.95, plus a strict 0.99 cut-off.
thresholds: List[float] = [*map(float, np.arange(0, 1, 1 / 20)), 0.99]

### MLP EpiAtlas cross-validation results

In [None]:
# Map every task / column-name variant used across result sets to its canonical category name.
category_remapper = {
    alias: canonical
    for canonical, aliases in [
        (ASSAY, ["assay", "assay7", f"{ASSAY}_11c", ASSAY]),
        (SEX, ["sex", "sex3", SEX, "harmonized_donor_sex_w-mixed"]),
        (CANCER, ["cancer", CANCER]),
        (BIOMATERIAL_TYPE, ["biomat", BIOMATERIAL_TYPE]),
    ]
    for alias in aliases
}

# Life stage aliases, with and without the "_merged" suffix.
for stage_alias in ["donorlife", "lifestage", LIFE_STAGE]:
    category_remapper[stage_alias] = LIFE_STAGE
    category_remapper[f"{stage_alias}_merged"] = LIFE_STAGE

In [None]:
categories = [
    ASSAY,
    CELL_TYPE,
    SEX,
    LIFE_STAGE,
    BIOMATERIAL_TYPE,
    CANCER,
    "paired_end",
    "project",
]
split_results_handler = SplitResultsHandler()

data_dir_100kb = base_data_dir / "training_results" / "dfreeze_v2" / "hg38_100kb_all_none"

In [None]:
# # Select 10-fold oversampling runs
# all_split_dfs = split_results_handler.general_split_metrics(
#     results_dir=data_dir_100kb,
#     merge_assays=False,
#     include_categories=categories,
#     exclude_names=["reg", "no-mixed", "chip", "16ct", "27ct"],
#     return_type="split_results",
#     oversampled_only=True,
#     verbose=False,
# )
# all_split_dfs_concat: Dict = split_results_handler.concatenate_split_results(all_split_dfs, concat_first_level=True)  # type: ignore

Fixing the special case "paired_end", whose boolean label values are not treated as strings.

In [None]:
# cols = ["True class", "Predicted class"]
# df = all_split_dfs_concat["paired_end"].copy()

# # labels: bool -> str
# df[cols] = df[cols].astype(str)
# for col in cols:
#     df[col] = df[col].str.lower()

# # make sure column names = class names
# df = df.rename(columns={"TRUE": "true", "FALSE": "false"})

# all_split_dfs_concat["paired_end"] = df

Computing all values separately from graphing

In [None]:
# threshold_dfs = {}
# for task_name, df in all_split_dfs_concat.items():
#     print("TASK:",task_name)
#     threshold_dfs[task_name] = evaluate_thresholds(df, thresholds)

In [None]:
# output_dir = base_fig_dir / "threshold_graphs" / "100kb_all_none"
# if not output_dir.exists():
#     output_dir.mkdir(parents=True, exist_ok=True)

# for task_name, df in all_split_dfs_concat.items():
#     print("TASK:", task_name)
#     nb_samples = len(df)
#     nb_classes = df["True class"].nunique()

#     df = threshold_dfs[task_name]

#     # create figure
#     name = f"{task_name} - {nb_classes} classes"
#     fig = create_thresholds_graph_plotly(df, f"{name} - n={nb_samples}")
#     fig.show()

#     # # save
#     filename = f"threshold_impact_graph_full_{get_valid_filename(name)}".replace(
#         "_-_", "-"
#     )
#     fig.write_image(output_dir / f"{filename}.png")
#     fig.write_image(output_dir / f"{filename}.svg")
#     fig.write_html(output_dir / f"{filename}.html")

In [None]:
# threshold_dfs = {}
# other_info = {}
# for task_name, df in all_split_dfs_concat.items():
#     print("TASK:", task_name)
#     nb_samples = len(df)
#     nb_classes = df["True class"].nunique()

#     other_info[task_name] = {"nb_samples": nb_samples, "nb_classes": nb_classes}

#     threshold_dfs[task_name] = evaluate_thresholds_global(df, thresholds)

In [None]:
# output_dir = base_fig_dir / "threshold_graphs" / "100kb_all_none" / "EpiATLAS"
# if not output_dir.exists():
#     output_dir.mkdir(parents=True, exist_ok=True)

# for task_name, df in all_split_dfs_concat.items():
#     print("TASK:", task_name)
#     nb_samples = len(df)
#     nb_classes = df["True class"].nunique()

#     df = threshold_dfs[task_name]

#     # create figure
#     name = f"{task_name} - {nb_classes} classes"
#     fig = create_thresholds_graph_global_plotly(df, f"{name} - n={nb_samples}", xrange=(max(0, 1.0/nb_classes-0.05), 1.001))
#     # fig.show()

#     # # save
#     filename = f"threshold_impact_graph_global_{get_valid_filename(name)}".replace(
#         "_-_", "-"
#     )
#     fig.write_image(output_dir / f"{filename}.png")
#     fig.write_image(output_dir / f"{filename}.svg")
#     fig.write_html(output_dir / f"{filename}.html")

Rename / drop classifier metrics for future graphing

In [None]:
# for label in [f"{ASSAY}_7c", "project", "paired_end"]:
#     threshold_dfs.pop(label, None)

# for name in list(threshold_dfs.keys()):
#     try:
#         new_name = category_remapper[name]
#     except KeyError:
#         # Undesired category for rest
#         del threshold_dfs[name]
#         continue

#     threshold_dfs[new_name] = threshold_dfs.pop(name)

# all_threshold_results["EpiATLAS"] = {"results": threshold_dfs, "other_info": other_info}

### ENCODE, ChIP-Atlas and recount3 inference results

In [None]:
# 'other'/'unknown' are too undefined, we exclude from life stage predictions
cell_line_vals = ["cell_line", "cell line", "unknown", "other"]

unmerged_life_stages = [
    "embryonic",
    "fetal",
    "newborn",
    "embryo",
]

unknown_values = ["unknown", "other", "indeterminate"]

In [None]:
predictions_dir = table_dir / "dfreeze_v2" / "predictions"

We do not apply the `life stage classifier` to `cell line` samples because cell lines were not part of its training data,
and the notion of life stage for a cell line is dubious.

Also, for the public DB inferences, we merge the `perinatal stages` (embryonic, fetal, newborn).

In [None]:
def format_category_labels(
    df: pd.DataFrame, categories: List[str], verbose: bool = False
) -> pd.DataFrame:
    """
    Uniformize class labels of the label columns related to the given categories.

    A column is selected when its name contains one of `categories` (case-insensitive),
    AND either its name contains "true", "expected" or "predicted" (case-insensitive),
    or it is exactly one of `categories`. Selected columns are passed to `format_labels`.

    Parameters:
    df (pd.DataFrame): DataFrame with label columns.
    categories (List[str]): Category names, or full column names, used to select columns.
    verbose (bool): Whether to print the selected columns.

    Returns:
    pd.DataFrame: The DataFrame returned by `format_labels`.
    """
    # Compare case-insensitively on both sides; comparing raw categories against a
    # lowercased column name silently missed any category containing uppercase letters.
    lowered_categories = [category.lower() for category in categories]
    label_keywords = ["true", "expected", "predicted"]

    to_format = []
    for col in df.columns:
        # str() guards against non-string column labels.
        col_lower = str(col).lower()
        mentions_category = any(category in col_lower for category in lowered_categories)
        is_label_col = any(keyword in col_lower for keyword in label_keywords)
        if mentions_category and (is_label_col or col in categories):
            if verbose:
                print(f"Formatting {col}")
            to_format.append(col)

    if verbose:
        print(f"Formatting {len(to_format)} columns: {to_format}")

    return format_labels(df=df, columns=to_format)

#### ChIP-Atlas

In [None]:
preds_path = (
    predictions_dir / "ChIP-Atlas_predictions_20240606_merge_metadata_freeze1.csv.xz"
)
pred_df = pd.read_csv(preds_path, sep=",", low_memory=False, compression="xz")
print(pred_df.shape)

In [None]:
to_drop = [
    col
    for col in pred_df.columns
    if any(l in col.lower() for l in ["disease", "assay11", "assay13"])
]
pred_df = pred_df.drop(columns=to_drop)
print(pred_df.shape)

In [None]:
pred_df = pred_df[pred_df["is_EpiAtlas_EpiRR"].astype(str) == "0"]
print(pred_df.shape)

In [None]:
pred_df = pred_df[
    ~pred_df["core7_DBs_consensus"].isin(
        ["Ignored - Potential non-core", "non-core/CTCF"]
    )
]
print(pred_df.shape)

In [None]:
pred_df = pred_df.fillna("unknown")

In [None]:
pred_df[BIOMATERIAL_TYPE] = pred_df["expected_biomat"]

to_replace = {
    "sex3": "sex",
    "assay7": "assay",
    "donorlife": "lifestage",
}
pred_df = rename_columns(
    df=pred_df,
    remapper=to_replace,
    exact_match=False,
    verbose=True,
)

In [None]:
col_mapper_template = {
    "true_col": "expected_{}",
    "pred_col": "Predicted_class_{}",
    "max_pred": "Max_pred_{}",
}

In [None]:
verbose = True

categories = ["assay", "sex", "cancer", "lifestage", "biomat"]

pred_df = format_category_labels(pred_df, categories, verbose=False)

In [None]:
threshold_dfs = {}
other_info = {}
for category in categories:
    print("TASK:", category)
    col_mapper = {k: v.format(category) for k, v in col_mapper_template.items()}

    df = pred_df.copy()

    # Filter unknown/NA
    df = df[~df[col_mapper["true_col"]].isin(unknown_values)]

    if category == "assay":
        # Restrict to the core 7 assays. Filter `df` (not `pred_df`) so the
        # unknown/NA filter above is not silently discarded.
        df = df[df[col_mapper["true_col"]].isin(core7_assays)]

    elif category == "lifestage":
        # Cell lines are excluded from life stage evaluation.
        df = df[~df[BIOMATERIAL_TYPE].isin(cell_line_vals)]
        life_stages = set(df[col_mapper["true_col"]].unique()) | set(
            df[col_mapper["pred_col"]].unique()
        )
        if any(label in life_stages for label in unmerged_life_stages):
            df = merge_life_stages(
                df=df,
                lifestage_column_name=category,
                column_name_templates=list(col_mapper.values()),
            )
            # Presumably merge_life_stages adds "<col>_merged" columns; verify against the helper.
            category = f"{category}_merged"
            col_mapper = {k: v.format(category) for k, v in col_mapper_template.items()}

    cat_name = category_remapper[category]

    nb_samples = df.shape[0]
    N_true_classes = len(set(df[col_mapper["true_col"]]))
    total_N_classes = len(
        set(df[col_mapper["pred_col"]]) | set(df[col_mapper["true_col"]])
    )
    other_info[cat_name] = {
        "nb_samples": nb_samples,
        "nb_classes": N_true_classes,
        "total_possible_classes": total_N_classes,
    }

    if verbose:
        for col in [col_mapper["true_col"], col_mapper["pred_col"]]:
            print(df[col].value_counts(dropna=False), "\n")

    threshold_dfs[cat_name] = evaluate_thresholds_global(
        df, thresholds, verbose=verbose, columns=col_mapper  # type: ignore
    )

all_threshold_results["ChIP-Atlas"] = {"results": threshold_dfs, "other_info": other_info}

In [None]:
# for task_name in categories:
#     df = threshold_dfs[task_name]
#     print("TASK:", task_name)
#     nb_samples = other_info[task_name]["nb_samples"]
#     nb_classes = other_info[task_name]["nb_classes"]

#     # create figure
#     name = f"{task_name} - {nb_classes} classes"
#     fig = create_thresholds_graph_global_plotly(df, f"{name} - n={nb_samples}", xrange=(max(0, 1.0/nb_classes-0.05), 1.001))
#     fig.show()

#### ENCODE

In [None]:
preds_path = predictions_dir / "encode_predictions_merge_metadata_2025-02_freeze1.csv.xz"

pred_df = pd.read_csv(preds_path, sep=",", low_memory=False, compression="xz")
print(pred_df.shape)

In [None]:
for col in pred_df.columns:
    print(col)

In [None]:
to_drop = [
    col
    for col in pred_df.columns
    if any(
        l in col.lower()
        for l in ["disease", "assay_epiclass_7c", "assay13", "biospecimen"]
    )
]

pred_df = pred_df.drop(columns=to_drop)
print(pred_df.shape)

In [None]:
# Strip the 11-class suffix: "assay_epiclass_11c" -> "assay_epiclass" in every column whose name contains "11c".
pred_df = pred_df.rename(
    columns={
        col: col.replace("assay_epiclass_11c", "assay_epiclass")
        for col in pred_df.columns
        if "11c" in col
    }
)

In [None]:
pred_df = pred_df[~pred_df["in_epiatlas"]]
print(pred_df.shape)

In [None]:
col_mapper_template = {
    "true_col": "{}",
    "pred_col": "Predicted class ({})",
    "max_pred": "Max pred ({})",
}

In [None]:
categories = [ASSAY, SEX, CANCER, LIFE_STAGE, BIOMATERIAL_TYPE]

# True-label and predicted-label column names, for each category in order.
relevant_columns = [
    col_mapper_template[key].format(category)
    for category in categories
    for key in ("true_col", "pred_col")
]

pred_df = format_category_labels(pred_df, relevant_columns, verbose=False)

In [None]:
verbose = True

threshold_dfs_core = {}
threshold_dfs_noncore = {}
other_info_core = {}
other_info_noncore = {}

for category in categories:
    print("TASK:", category)
    col_mapper = {k: v.format(category) for k, v in col_mapper_template.items()}

    df: pd.DataFrame = pred_df.copy()  # type: ignore
    df.fillna("unknown", inplace=True)

    # Filter unknown/NA
    df = df[~(df[col_mapper["true_col"]].isin(unknown_values))]

    # Merge rna / wgbs pairs
    if category == ASSAY:
        true, pred = col_mapper["true_col"], col_mapper["pred_col"]
        df.loc[:, [true, pred]] = df.loc[:, [true, pred]].replace(
            ASSAY_MERGE_DICT, inplace=False
        )
    elif category == LIFE_STAGE:
        df = df[~df[BIOMATERIAL_TYPE].isin(cell_line_vals)]

        life_stages = set(df[col_mapper["true_col"]].unique()) | set(
            df[col_mapper["pred_col"]].unique()
        )
        if any(label in life_stages for label in unmerged_life_stages):
            df = merge_life_stages(
                df=df,
                lifestage_column_name=category,
                column_name_templates=list(col_mapper.values()),
            )
            category = f"{category}_merged"
            col_mapper = {k: v.format(category) for k, v in col_mapper_template.items()}
            if verbose:
                print("Biomaterial type and assay post cell line filter:")
                print(df[BIOMATERIAL_TYPE].value_counts(dropna=False), "\n")
                print(df[ASSAY].value_counts(dropna=False), "\n")

    # split core/non-core
    df.loc[:, ASSAY] = df.loc[:, ASSAY].replace(ASSAY_MERGE_DICT, inplace=False)
    mask = df[ASSAY].isin(core9_assays)

    df_core = df[mask]
    df_noncore = df[~mask]

    # Compute all thresholds
    cat_name = category_remapper[category]
    for name, container_results, container_other_info, set_df in zip(
        ["core", "noncore"],
        [threshold_dfs_core, threshold_dfs_noncore],
        [other_info_core, other_info_noncore],
        [df_core, df_noncore],
    ):
        if cat_name == ASSAY and "ctcf" in set_df[ASSAY].unique():
            if verbose:
                print("\nSkipping assay non-core\n")
            continue

        nb_samples = set_df.shape[0]
        N_true_classes = len(set(set_df[col_mapper["true_col"]]))
        total_N_classes = len(
            set(set_df[col_mapper["pred_col"]]) | set(set_df[col_mapper["true_col"]])
        )
        container_other_info[cat_name] = {
            "nb_samples": nb_samples,
            "nb_classes": N_true_classes,
            "total_possible_classes": total_N_classes,
        }

        if verbose:
            print(f"Set: {name}")
            for col in [col_mapper["true_col"], col_mapper["pred_col"]]:
                print(set_df[col].value_counts(dropna=False), "\n")

        container_results[cat_name] = evaluate_thresholds_global(
            set_df, thresholds, verbose=False, columns=col_mapper  # type: ignore
        )

# Register the ENCODE results, split into core and non-core assay sets.
all_threshold_results.update(
    {
        "ENCODE_core": {
            "results": threshold_dfs_core,
            "other_info": other_info_core,
        },
        "ENCODE_non-core": {
            "results": threshold_dfs_noncore,
            "other_info": other_info_noncore,
        },
    }
)

In [None]:
# for task_name in categories:
#     df = threshold_dfs[task_name]
#     print("TASK:", task_name)
#     nb_samples = other_info[task_name]["nb_samples"]
#     nb_classes = other_info[task_name]["nb_classes"]

#     # create figure
#     name = f"{task_name} - {nb_classes} classes"
#     fig = create_thresholds_graph_global_plotly(df, f"{name} - n={nb_samples}", xrange=(max(0, 1.0/nb_classes-0.05), 1.001))
#     fig.show()

#### recount3

In [None]:
# recount3 predictions merged with metadata, stored as an xz-compressed CSV.
preds_path = predictions_dir / "recount3_merged_preds_metadata_freeze1.csv.xz"
pred_df = pd.read_csv(
    preds_path,
    compression="xz",
    low_memory=False,
    sep=",",
)
print(pred_df.shape)

In [None]:
# for col in pred_df.columns:
#     print(col)

In [None]:
# Column-name templates; "{}" is filled with the classification task name.
col_mapper_template = dict(
    true_col="{}",
    pred_col="Predicted class ({})",
    max_pred="Max pred ({})",
)

In [None]:
# Tasks evaluated on recount3; life stage is evaluated on its "_merged" column.
categories = [ASSAY, SEX, CANCER, f"{LIFE_STAGE}_merged", BIOMATERIAL_TYPE]
# The raw LIFE_STAGE column is formatted too, since it is merged in a later cell.
pred_df = format_category_labels(pred_df, [*categories, LIFE_STAGE])

In [None]:
# Show the distinct values of every object-dtype column related to donor life stage.
for col in pred_df.columns:
    if "donor_life" not in col.lower() or pred_df[col].dtype != "object":
        continue
    print(col)
    print(pred_df[col].unique())

In [None]:
# Merge life stages for the true label and for the predicted-class / max-pred
# columns built from these name templates. Presumably this produces the
# f"{LIFE_STAGE}_merged" columns that `categories` refers to — verify in helper.
pred_df = merge_life_stages(
    df=pred_df,
    column_name_templates=["Predicted class ({})", "Max pred ({})"],
    lifestage_column_name=LIFE_STAGE,
)

In [None]:
# categories = [f"{LIFE_STAGE}_merged"]

assay_pred_col = f"Predicted class ({ASSAY})"
assay_max_pred_col = f"Max pred ({ASSAY})"

# Minimum assay-classifier max prediction score for a recount3 sample to be
# kept as "similar to training" in the non-assay tasks.
rna_like_min_score = 0.6

verbose = True

threshold_dfs = {}
other_info = {}

for category in categories:
    print("TASK:", category)
    col_mapper = {k: v.format(category) for k, v in col_mapper_template.items()}

    # Fresh frame per task: NaN in every column becomes the "unknown" label.
    # NOTE(review): this also fills numeric columns such as `assay_max_pred_col`;
    # if that column ever holds NaN it becomes object dtype and the
    # `> rna_like_min_score` comparison below would raise — confirm it has no NaN.
    df = pred_df.fillna("unknown")

    # Filter unknown/NA ground-truth labels. `.copy()` makes the result an
    # independent frame so the `.loc` writes below don't emit SettingWithCopyWarning.
    df = df[~df[col_mapper["true_col"]].isin(unknown_values)].copy()

    if verbose:
        print("Known labels distribution:")
        print(df[col_mapper["true_col"]].value_counts(dropna=False), "\n")

    if category == ASSAY:
        # Apply the same assay merge used for the ENCODE labels to the predictions.
        pred_col = col_mapper["pred_col"]
        df.loc[:, pred_col] = df.loc[:, pred_col].replace(ASSAY_MERGE_DICT)

        # All supposed to be rna-seq-like assays
        true_col = col_mapper["true_col"]
        df.loc[:, true_col] = "rna_seq"
    else:
        # Only keep "similar to training" dsets: predicted as m/rna-seq by the
        # assay classifier with a max pred above `rna_like_min_score`.
        cond1 = df[assay_pred_col].isin(["rna_seq", "mrna_seq"])
        cond2 = df[assay_max_pred_col] > rna_like_min_score
        df = df[cond1 & cond2]

    if verbose:
        print("All labels distribution after 11c filter:")
        print(df[col_mapper["true_col"]].value_counts(dropna=False), "\n")

    if LIFE_STAGE in category:
        if verbose:
            print(f"Filtering out cell lines for `{category}`...")

        # Life stage is not evaluated on cell lines.
        df = df[~df[BIOMATERIAL_TYPE].isin(cell_line_vals)]

        if verbose:
            print("Life stage labels distribution after cell line filter:")
            for col in [col_mapper["true_col"], col_mapper["pred_col"]]:
                print(df[col].value_counts(dropna=False), "\n")
            print(df[BIOMATERIAL_TYPE].value_counts(dropna=False), "\n")

    cat_name = category_remapper[category]

    # Sample count, number of distinct true labels, and number of distinct
    # labels seen in either the true or the predicted column.
    nb_samples = df.shape[0]
    N_true_classes = len(set(df[col_mapper["true_col"]]))
    total_N_classes = len(
        set(df[col_mapper["pred_col"]]) | set(df[col_mapper["true_col"]])
    )
    other_info[cat_name] = {
        "nb_samples": nb_samples,
        "nb_classes": N_true_classes,
        "total_possible_classes": total_N_classes,
    }

    threshold_dfs[cat_name] = evaluate_thresholds_global(
        df, thresholds, verbose=verbose, columns=col_mapper  # type: ignore
    )

all_threshold_results["recount3"] = {"results": threshold_dfs, "other_info": other_info}

In [None]:
# for task_name in categories:
#     df = threshold_dfs[task_name]
#     print("TASK:", task_name)
#     nb_samples = other_info[task_name]["nb_samples"]
#     nb_classes = other_info[task_name]["nb_classes"]

#     # create figure
#     name = f"{task_name} - {nb_classes} classes"
#     fig = create_thresholds_graph_global_plotly(df, f"{name} - n={nb_samples}", xrange=(max(0, 1.0/nb_classes-0.05), 1.001))
#     fig.show()

#### Graph results for training and inference per database

In [None]:
def rgb2hex(r, g, b):
    """Return the lowercase ``#rrggbb`` hex string for integer RGB components."""
    return "#" + "".join(format(channel, "02x") for channel in (r, g, b))


def hex2rgb(hex_str):
    """Parse a color string such as ``"#005AB5"`` into a tuple of integer channels (via PIL)."""
    channels = ImageColor.getrgb(hex_str)
    return channels


def _darken_hex_color(hex_color: str, amount: int) -> str:
    """Return ``hex_color`` with ``amount`` subtracted from each RGB channel, clamped to [0, 255]."""
    # [:3] drops an alpha channel if the color string carried one.
    channels = hex2rgb(hex_color)[:3]
    return rgb2hex(*(min(max(channel - amount, 0), 255) for channel in channels))


def add_acc_f1(
    fig: go.Figure,
    df: pd.DataFrame,
    row: int,
    col: int,
    colors: List[str],
    show_legend: bool = True,
    label_modifier: str = "",
    color_mod: int = 0,
) -> None:
    """Add accuracy and F1-score curves (against "Threshold") to one subplot.

    Args:
        fig: The figure to add the traces to.
        df: Threshold results. Must contain a "Threshold" column, a column whose
            name contains "Acc" and one whose name contains "F1" (the first match
            of each is used and also serves as the trace name).
        row: The row of the subplot (1-indexed).
        col: The column of the subplot (1-indexed).
        colors: Hex colors; colors[1] is used for accuracy and colors[2] for F1.
        show_legend: Whether the traces appear in the legend.
        label_modifier: Suffix appended to both trace names. When non-empty, both
            colors are also darkened by ``color_mod``.
        color_mod: Amount subtracted from each RGB channel (result clamped to
            [0, 255]) when ``label_modifier`` is set.
    """
    acc_label = df.filter(like="Acc").columns[0]
    f1_label = df.filter(like="F1").columns[0]

    name_acc, name_f1 = acc_label, f1_label
    color_acc, color_f1 = colors[1], colors[2]

    if label_modifier:
        name_acc = f"{acc_label} {label_modifier}"
        name_f1 = f"{f1_label} {label_modifier}"
        color_acc = _darken_hex_color(color_acc, color_mod)
        color_f1 = _darken_hex_color(color_f1, color_mod)

    # Accuracy is always drawn.
    fig.add_trace(
        go.Scatter(
            x=df["Threshold"],
            y=df[acc_label],
            name=name_acc,
            line=dict(color=color_acc, dash="solid"),
            mode="lines",
            showlegend=show_legend,
            legendgroup="Accuracy",
        ),
        row=row,
        col=col,
    )

    # F1 is skipped when it is undefined (NaN) at every threshold.
    f1_vals = df[f1_label]
    if not f1_vals.isna().all():
        fig.add_trace(
            go.Scatter(
                x=df["Threshold"],
                y=f1_vals,
                name=name_f1,
                line=dict(color=color_f1, dash="dot"),
                mode="lines",
                showlegend=show_legend,
                legendgroup="F1-score",
            ),
            row=row,
            col=col,
        )


def add_subset_size(
    fig: go.Figure,
    df: pd.DataFrame,
    row: int,
    col: int,
    colors: List[str],
    show_legend: bool = True,
    label_modifier: str = "",
    color_mod: int = 1,
) -> None:
    """Add the relative subset-size curve (against "Threshold") to one subplot.

    Args:
        fig: The figure to add the trace to.
        df: Threshold results with a "Threshold" column and a column whose name
            contains "Subset" (first match used). The trace is named after the
            part of that column name before any "(", followed by " (%)".
        row: The row of the subplot (1-indexed).
        col: The column of the subplot (1-indexed).
        colors: Hex colors; colors[0] is used for this trace.
        show_legend: Whether the trace appears in the legend.
        label_modifier: Suffix appended to the trace name. When non-empty, the
            color is lightened by adding ``color_mod`` to each RGB channel (capped at 255).
        color_mod: Amount added to each RGB channel when ``label_modifier`` is set.
    """
    subset_label = df.filter(like="Subset").columns[0]
    base_name = subset_label.split("(")[0].strip()
    trace_name = f"{base_name} (%)"
    trace_color = colors[0]

    if label_modifier:
        trace_name = f"{trace_name} {label_modifier}"
        lightened = [min(channel + color_mod, 255) for channel in hex2rgb(trace_color)]
        trace_color = rgb2hex(*lightened)

    # NOTE(review): when row/col are passed, plotly likely assigns the subplot's own
    # axes to the trace, so `yaxis="y2"` may have no effect — confirm a secondary
    # axis is really intended here.
    subset_trace = go.Scatter(
        x=df["Threshold"],
        y=df[subset_label],
        name=trace_name,
        line=dict(color=trace_color, dash="dash"),
        yaxis="y2",
        mode="lines",
        showlegend=show_legend,
        legendgroup="Subset Size",
    )
    fig.add_trace(subset_trace, row=row, col=col)


def graph_all_DB_threshold_graphs(
    results_dict: Dict[str, Dict],
    output_dir: Path | None = None,
    filename: str | None = None,
):
    """
    Create a grid of threshold graphs: one row per DB, one column per classifier.

    Each populated subplot shows accuracy, F1-score and subset size against the
    prediction score threshold. It is annotated with the sample count (N) and
    with the number of true classes over all observed true/predicted classes (C).
    The "EpiATLAS" row is currently left empty (placeholder traces only).

    Args:
        results_dict: Maps DB name -> {"results": {category: threshold_df},
            "other_info": {category: {"nb_samples", "nb_classes",
            "total_possible_classes"}}}. Missing DBs or categories are reported
            and skipped.
        output_dir: If given, the figure is saved there as SVG, PNG and HTML
            (the directory is created if needed).
        filename: Base file name (without extension) used when saving.
            Defaults to "all_DBs_5_classifiers_thresholds".
    """
    category_order = [ASSAY, SEX, CANCER, LIFE_STAGE, BIOMATERIAL_TYPE]
    DBs_order = ["EpiATLAS", "ENCODE_core", "ENCODE_non-core", "ChIP-Atlas", "recount3"]
    graph_renamer = {
        ASSAY: "Assay",
        SEX: "Sex",
        CANCER: "Cancer status",
        LIFE_STAGE: "Life stage",
        BIOMATERIAL_TYPE: "Biomaterial type",
    }

    # color-blind friendly
    # black (subset size), blue (accuracy), red (F1)
    colors = ["#000000", "#005AB5", "#DC3220"]

    n_rows, n_cols = len(DBs_order), len(category_order)
    fig = make_subplots(
        rows=n_rows,
        cols=n_cols,
        row_titles=DBs_order,
        column_titles=[graph_renamer[category] for category in category_order],
        shared_xaxes=True,
        vertical_spacing=0.025,
        horizontal_spacing=0.04,
        x_title="Prediction Score Threshold",
        y_title="Metric value",
    )

    y_ranges = {
        "EpiATLAS": [0.7, 1.01],
        "ChIP-Atlas": [0.1, 1.01],
        "ENCODE_core": [0.45, 1.01],
        "ENCODE_non-core": [0.30, 1.01],
        "recount3": [0, 1.01],
    }

    # Traces share a legendgroup per metric, so only the first subplot that
    # actually receives traces shows legend entries. (Keying this on i == 0 never
    # worked: row 0 is the EpiATLAS placeholder row, which is skipped.)
    legend_pending = True

    for i, DB in enumerate(DBs_order):
        row = i + 1
        # Add empty subplot row, temporary
        if DB == "EpiATLAS":
            for j in range(n_cols):
                fig.add_trace(
                    go.Scatter(x=[], y=[], name="", showlegend=False),
                    row=row,
                    col=j + 1,
                )
            continue

        data = results_dict.get(DB)
        if data is None:
            print(f"Could not find results for {DB}, leaving its row empty.")
            continue

        for j, category in enumerate(category_order):
            col = j + 1

            try:
                threshold_df = data["results"][category]
            except KeyError as e:
                print(f"Could not find results for {DB} {category}: {e}")
                continue

            add_acc_f1(fig, threshold_df, row, col, colors, legend_pending)
            add_subset_size(fig, threshold_df, row, col, colors, legend_pending)
            legend_pending = False

            # Nb files + classes
            try:
                info = data["other_info"][category]
            except KeyError as e:
                print(f"Could not find other info for {DB} {category}: {e}")
                continue

            annotation_text = (
                f"N = {info['nb_samples']}"
                f"<br>C = {info['nb_classes']}/{info['total_possible_classes']}"
            )
            fig.add_annotation(
                text=annotation_text,
                showarrow=False,
                font=dict(size=10, color="black"),
                row=row,
                col=col,
            )

    # Set y-axis ranges; the tighter ranges get finer ticks.
    for row, DB in enumerate(DBs_order, start=1):
        dtick = 0.1 if DB in ["ENCODE_core", "ENCODE_non-core", "EpiATLAS"] else 0.2
        for col in range(1, n_cols + 1):
            fig.update_yaxes(range=y_ranges[DB], row=row, col=col, dtick=dtick)

    fig.update_xaxes(range=[0.1, 1.01], dtick=0.2)

    fig.update_layout(
        width=800,
        height=800,
        title="All Databases - 5 classifiers - Metrics at Different Pred. Score Thresholds",
    )

    fig.update_layout(hovermode="x unified", hoverlabel_namelength=-1)

    fig.show()

    if output_dir:
        output_dir.mkdir(parents=True, exist_ok=True)
        filename = filename or "all_DBs_5_classifiers_thresholds"
        fig.write_image(output_dir / f"{filename}.svg")
        fig.write_image(output_dir / f"{filename}.png")
        fig.write_html(output_dir / f"{filename}.html")

In [None]:
# Render the DB x classifier threshold grid and write it to the figures directory.
output_dir = base_fig_dir / "threshold_graphs" / "100kb_all_none"
graph_all_DB_threshold_graphs(
    all_threshold_results,
    filename="3DBs_5_classifiers_thresholds_w_ENCODE_split",
    output_dir=output_dir,
)