In [None]:
"""Analyze full prediction vector values."""
# pylint: disable=line-too-long, redefined-outer-name, import-error, pointless-statement, use-dict-literal, expression-not-assigned, unused-import, too-many-lines

## Setup

In [None]:
from __future__ import annotations

import itertools
import shutil
from pathlib import Path
from typing import Dict, List, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import sklearn.metrics
from IPython.display import display
from plotly.subplots import make_subplots

from epi_ml.core.confusion_matrix import ConfusionMatrixWriter
from epi_ml.utils.general_utility import get_valid_filename

In [None]:
# Metadata column names. ASSAY, CELL_TYPE, SEX and TRACK are used below to index
# the prediction tables; the others are presumably column names as well but are
# not referenced in the visible cells.
BIOMATERIAL_TYPE = "harmonized_biomaterial_type"
CELL_TYPE = "harmonized_sample_ontology_intermediate"
ASSAY = "assay_epiclass"
SEX = "harmonized_donor_sex"
CANCER = "harmonized_sample_cancer_high"
DISEASE = "harmonized_sample_disease_high"
LIFE_STAGE = "harmonized_donor_life_stage"
TRACK = "track_type"

In [None]:
# Local directory holding the input CSV files; generated tables and figures are
# written below it as well.
general_local_logdir = Path.home() / "downloads" / "temp"

## Pivot table on assay and cell type

In [None]:
# Load the prediction table used for the pivot; the first CSV column becomes the index.
results_path = (
    general_local_logdir
    / "sex3_oversample_full-10fold-validation_prediction_augmented-all.csv"
)
sex_df = pd.read_csv(results_path, index_col=0)

In [None]:
# Preview the first rows of the loaded prediction table.
sex_df.head()

In [None]:
def pivot_table(df: pd.DataFrame, category_label: str) -> pd.DataFrame:
    """Summarise prediction outcomes per (category, assay, cell type) group.

    Rows are indexed by `category_label`, ASSAY and CELL_TYPE; columns are the
    predicted classes. The "Same?" column is aggregated with both "count" and
    "mean", and a "Total" margin is added. From the "mean" aggregate only the
    overall ("mean", "Total") column is kept; per-class mean columns are dropped.

    Args:
        df: Prediction table containing "Predicted class", "Same?" and the
            three index columns.
        category_label: Name of the column used as the first index level.

    Returns:
        The pivot table holding the count columns and the overall mean column.
    """
    pivot = df.pivot_table(
        values="Same?",
        index=[category_label, ASSAY, CELL_TYPE],
        columns="Predicted class",
        aggfunc=["count", "mean"],
        fill_value=0,
        margins=True,
        margins_name="Total",
    )

    # Collect every "mean" column, then spare the overall margin from removal.
    per_class_means = [column for column in pivot.columns if column[0] == "mean"]
    per_class_means.remove(("mean", "Total"))

    return pivot.drop(columns=per_class_means)


# Build the pivot with donor sex as the first index level and save it in the
# same directory as the input file.
sex_pivot = pivot_table(sex_df, SEX)
sex_pivot.to_csv(
    general_local_logdir
    / "sex3_oversample_full-10fold-validation_prediction_augmented-all_pivot.csv"
)

## Different confusion matrices

In [None]:
# Root of the training logs (path under a mount point in the home directory).
base_logdir = (
    Path.home()
    / "mounts/narval-mount/project-rabyj/epilap/output/logs/epiatlas-dfreeze-v2.1/hg38_100kb_all_none"
)
# Classifier output folders to analyse; commented entries are currently disabled.
pred_folders = [
    base_logdir / name
    for name in [
        "harmonized_donor_life_stage_1l_3000n/no-unknown/10fold-oversampling",
        # "assay_epiclass_1l_3000n/11c/10fold-oversampling",
        # "harmonized_donor_sex_1l_3000n/w-mixed/10fold-oversample"
    ]
]
# One augmented prediction file is expected in each folder.
pred_files = [
    pred_folder / "full-10fold-validation_prediction_augmented-all.csv"
    for pred_folder in pred_folders
]

In [None]:
# Fail early and name the missing file rather than raising a bare AssertionError.
for pred_file in pred_files:
    assert pred_file.exists(), f"Prediction file not found: {pred_file}"

### Per sample, different thresholds

In [None]:
# for pred_file in pred_files:
#         df = pd.read_csv(
#             pred_file,
#             sep=",",
#             usecols=["True class", "Predicted class", "Max pred"],
#         )

#         for threshold in [0, 0.7, 0.9]:
#             sub_df  = df[df["Max pred"] >= threshold]

#             true, pred = sub_df.iloc[:, 0], sub_df.iloc[:, 1]
#             labels = sorted(set(true.unique().tolist() + pred.unique().tolist()))
#             confusion_mat = sklearn.metrics.confusion_matrix(true, pred, labels=labels)

#             writer = ConfusionMatrixWriter(labels=labels, confusion_matrix=confusion_mat)
#             writer.to_all_formats(
#                 logdir=pred_file.parent,
#                 name=str(pred_file.stem) + f"-confusion_matrix-{threshold*100}",
#             )
#             plt.close('all')

### Per EpiRR

In [None]:
def get_majority_class(df: pd.DataFrame) -> str:
    """
    Return the most frequently predicted class among the rows of one EpiRR.

    When several classes share the highest count, the tie is broken by the
    highest mean "Max pred" among the tied classes.

    Args:
        df (pd.DataFrame): Predictions for a single EpiRR, with
            "Predicted class" and "Max pred" columns.

    Returns:
        str: The majority class label for this EpiRR.
    """
    counts = df["Predicted class"].value_counts()
    is_top = (counts == counts.max()).to_numpy()
    top_classes = counts.index[is_top].tolist()

    if len(top_classes) > 1:
        # Tie on counts: compare the average top score of each tied class.
        mean_scores = df.groupby("Predicted class")["Max pred"].mean()
        return mean_scores.loc[top_classes].idxmax()  # type: ignore

    return top_classes[0]

In [None]:
# For each prediction file, reduce the per-sample predictions to one majority
# prediction per EpiRR at several score thresholds.
# NOTE(review): `df` stays bound to the last file read once this cell finishes;
# later cells that reference `df` will see that table.
for pred_file in pred_files:
    df = pd.read_csv(
        pred_file,
        sep=",",
    )

    # px.box(df, x="True class", y="Max pred", color="Predicted class", title=pred_file.stem).show()

    # Sanity check: every EpiRR must have a single true class, i.e. there are as
    # many distinct (EpiRR, True class) pairs as distinct EpiRRs.
    total_epirrs = df["EpiRR"].nunique()
    print("Total number of EpiRRs:", total_epirrs)
    assert (
        df[["EpiRR", "True class"]].value_counts().shape[0]
        == df[["EpiRR"]].value_counts().shape[0]
    )

    for threshold in [0, 0.7, 0.9]:
        # Keep only predictions whose top score reaches the threshold (inclusive).
        threshold_df = df[df["Max pred"] >= threshold]

        # One majority class per EpiRR, computed on the remaining predictions.
        majority_class_series = threshold_df.groupby("EpiRR").apply(get_majority_class)
        majority_class_series.name = "Predicted class"

        # One row per (EpiRR, True class), then attach the majority prediction.
        # The inner join keeps only EpiRRs that have a majority class.
        threshold_df = threshold_df[["EpiRR", "True class"]].drop_duplicates()
        epirr_df = threshold_df.join(majority_class_series, how="inner", on="EpiRR")
        epirr_df = epirr_df.set_index("EpiRR")

        # Show EpiRRs labelled "adult" whose majority prediction is "embryonic".
        print(
            epirr_df[
                (epirr_df["True class"] == "adult")
                & (epirr_df["Predicted class"] == "embryonic")
            ]
        )
        # true, pred = epirr_df["True class"], epirr_df["Predicted class"]
        # assert len(true) == len(pred)
        # assert len(true) <= total_epirrs

        # labels = sorted(set(true.unique().tolist() + pred.unique().tolist()))
        # confusion_mat = sklearn.metrics.confusion_matrix(true, pred, labels=labels)

        # writer = ConfusionMatrixWriter(labels=labels, confusion_matrix=confusion_mat)

        # out_logdir = pred_file.parent / "conf_per_epirr"
        # out_logdir.mkdir(exist_ok=True)
        # paths = writer.to_all_formats(
        #     logdir=out_logdir,
        #     name=str(pred_file.stem) + f"-confusion_matrix-epirr-t{threshold*100}",
        # )
        # print(paths[-1])
        # plt.close('all')

In [None]:
# # Count the number of predicted classes for each EpiRR
# class_counts = threshold_df.groupby("EpiRR")["Predicted class"].nunique()

# # Find EpiRRs with more than one predicted class
# epiRRs_with_multiple_classes = class_counts[class_counts > 1].index

# # Filter the DataFrame to only include these EpiRRs
# final_df = threshold_df[threshold_df["EpiRR"].isin(epiRRs_with_multiple_classes)]

In [None]:
def get_majority_class_2(group: pd.DataFrame) -> str:
    """
    Pick the top-ranked row of an aggregated table and return its class label.

    Rows are ranked by ("Max pred", "count") and then ("Max pred", "mean"),
    both descending. The label returned is the third element of the winning
    row's index tuple.

    Args:
        group (pd.DataFrame): Aggregated data with a multi-level index and
            ("Max pred", "count") / ("Max pred", "mean") columns.
            NOTE(review): the third index level is assumed to be
            "Predicted class" — confirm against the caller's groupby keys.

    Returns:
        str: The majority class label.
    """
    # Highest count first; the mean score decides between equal counts.
    ranked = group.sort_values(
        by=[("Max pred", "count"), ("Max pred", "mean")], ascending=False
    )

    winning_key = ranked.index[0]
    return winning_key[2]

### Check specific EpiRR

In [None]:
# for pred_file in pred_files:

#     df = pd.read_csv(
#         pred_file,
#         sep=",",
#     )

#     px.box(df, x="True class", y="Max pred", color="Predicted class", title=pred_file.stem).show()

#     # Check that there's only one true class for each EpiRR
#     assert df.groupby("EpiRR")["True class"].nunique().eq(1).all()
#     print("Total number of EpiRRs:", df["EpiRR"].nunique())

#     threshold_df = df.groupby(["EpiRR","harmonized_donor_type", "True class","Predicted class", "assay_epiclass"]).agg({'Max pred': ['mean', 'median', 'count']})
#     display(threshold_df.loc["IHECRE00003713.7"])
#     # classes = threshold_df.groupby("EpiRR").apply(get_majority_class_2)
#     # display(classes)
#     break

## Prediction distributions (per cell of confusion matrix)

Analyze the prediction scores of correct vs. incorrect predictions. Is there a prediction score threshold that lets us filter out the most important errors?

In [None]:
# logdir = Path.home() / "downloads" / "temp"

# path = logdir / "sex3_oversample_full-10fold-validation_prediction_augmented-all.csv"
# df = pd.read_csv(path, index_col=0, header=0)

In [None]:
# classes = df["True class"].unique()

In [None]:
# df["harmonized_donor_sex"].value_counts()

In [None]:
# for label in classes:
#     df_label = df[df["True class"] == label]
#     fig = go.Figure()

#     # Iterate classes each target and add a violin plot for it
#     for target in classes:
#         vals = df_label[df_label["Predicted class"] == target]["Max pred"]
#         print(df_label["assay_epiclass"].value_counts())

#         fig.add_trace(
#             go.Violin(
#                 y=vals,
#                 name=f"{target} ({len(vals)})",
#                 box_visible=True,
#                 meanline_visible=True,
#                 points="all",
#             )
#         )

#     fig.update_layout(
#         title_text=f"Predicted value distribution for {label} ({df_label.shape[0]})",
#         yaxis_title="Prediction score",
#         xaxis_title="Target",
#     )
#     fig.update_yaxes(range=[1 / len(classes), 1.01])

#     fig.show()

In [None]:
def get_assay_list(df: pd.DataFrame) -> List[List[str]]:
    """Build the assay groups to analyse.

    Each distinct value of the "assay_epiclass" column yields a one-element
    group (in order of first appearance), followed by two fixed pairs:
    the rna pair and the wgbs pair.
    """
    single_assay_groups = [[assay] for assay in df["assay_epiclass"].unique().tolist()]
    paired_assay_groups = [
        ["mrna_seq", "rna_seq"],
        ["wgbs-standard", "wgbs-pbat"],
    ]
    return single_assay_groups + paired_assay_groups

## Sex chrY coverage information

In [None]:
# Input directory for this section (same location as general_local_logdir).
logdir = Path.home() / "downloads" / "temp"

In [None]:
# Reload the prediction table; the first column is the index and the first row the header.
path = logdir / "sex3_oversample_full-10fold-validation_prediction_augmented-all.csv"
sex_df = pd.read_csv(path, index_col=0, header=0)

In [None]:
# Coverage table, indexed by its first column so it can be joined on the
# prediction table's index.
coverage_path = logdir / "chrXY_coverage_all.csv"
coverage_df = pd.read_csv(coverage_path, index_col=0, header=0)

In [None]:
# Preview the first rows of the coverage table.
coverage_df.head()

In [None]:
# Inner join on the index: keep only samples present in both tables.
merged_df = sex_df.merge(coverage_df, left_index=True, right_index=True, how="inner")

In [None]:
# merged_df.to_csv(
#     logdir / "sex3_oversample_full-10fold-validation_prediction_augmented-all-chrY.csv"
# )

In [None]:
# Compare table sizes before and after the inner merge. `sex_df` is the table
# that was merged (not the `df` left over from the per-EpiRR cells), and a
# dedicated loop name avoids rebinding `df`.
for frame in (sex_df, coverage_df, merged_df):
    print(frame.shape)

In [None]:
# Drop rows whose track type matches "pval" or "fc" (case-insensitive).
merged_df = merged_df[~merged_df[TRACK].str.contains(pat="pval|fc", case=False)]
# merged_df = merged_df[~merged_df[ASSAY].str.contains(pat="wgb|input", case=False)]

In [None]:
# Violin line colour for each class label; the plots look classes up by key.
COLORS_DICT = {"female": "red", "male": "blue", "mixed": "purple"}

In [None]:
# Assay groups to plot: one per assay found in merged_df, plus the fixed pairs
# appended by get_assay_list.
assay_labels = get_assay_list(merged_df)

### All samples (1 sample = 1 data point)

In [None]:
# Directory where the per-assay figures of this section are written.
logdir_10fold_per_assay = (
    general_local_logdir
    / "chrY_coverage_results"
    / "10fold_valid"
    / "conf_matrix_per_assay"
)

In [None]:
# One figure per assay group and score threshold. Each figure is a grid of violin
# plots: one row per true (reference) class, one column per predicted class.
coverage_label = "chrY"
classes = merged_df["True class"].unique()
n_classes = len(classes)

# write_html/write_image fail when the target directory is missing.
logdir_10fold_per_assay.mkdir(parents=True, exist_ok=True)

for assay_list in assay_labels:
    assay_df = merged_df[merged_df[ASSAY].isin(assay_list)]
    if assay_df.empty:
        # The fixed assay pairs appended by get_assay_list may match no sample;
        # taking the maximum of an empty column would raise further down.
        print(f"No samples for {','.join(assay_list)}, skipping.")
        continue

    # Same y-axis upper bound for every threshold of this assay group.
    y_max = assay_df[coverage_label].max()

    for threshold in [0, 0.7, 0.9]:
        # Grid size follows the number of classes so it always matches the titles.
        fig = make_subplots(
            rows=n_classes,
            cols=n_classes,
            shared_yaxes=True,
            x_title="Predicted class (nb of predictions)",
            y_title="Mean coverage",
            row_titles=list(classes),
            column_titles=list(classes),
            vertical_spacing=0.08,
            horizontal_spacing=0.01,
        )
        threshold_df = assay_df[assay_df["Max pred"] >= threshold]

        # Row index = true class, column index = predicted class.
        for row, label in enumerate(classes, start=1):
            df_label = threshold_df[threshold_df["True class"] == label]

            for col, target in enumerate(classes, start=1):
                sub_df = df_label[df_label["Predicted class"] == target]

                # Hover text names the assay only when the group mixes several assays.
                if len(assay_list) == 1:
                    hovertext = [
                        f"{md5sum}:(chrY={chrY_val:.3f}, pred={pred:.3f})"
                        for md5sum, pred, chrY_val in zip(
                            sub_df.index, sub_df["Max pred"], sub_df[coverage_label]
                        )
                    ]
                else:
                    hovertext = [
                        f"{md5sum},{assay}:(chrY={chrY_val:.3f}, pred={pred:.3f})"
                        for md5sum, pred, chrY_val, assay in zip(
                            sub_df.index,
                            sub_df["Max pred"],
                            sub_df[coverage_label],
                            sub_df[ASSAY],
                        )
                    ]
                fig.add_trace(
                    go.Violin(
                        y=sub_df[coverage_label],
                        name=f"{target} ({sub_df.shape[0]})",
                        box_visible=True,
                        meanline_visible=True,
                        points="all",
                        text=hovertext,
                        line_color=COLORS_DICT[target],
                        hovertemplate="%{text}",
                    ),
                    row=row,
                    col=col,
                )

        # Update global layout and traces
        fig.update_traces(marker=dict(size=1))
        fig.update_yaxes(range=[-0.001, y_max])

        # Directly using annotations param does not work with make_subplots,
        # so the extra label is appended to the annotations it generated.
        existing_annotations = fig.layout.annotations
        new_annotation = dict(
            x=1.01,  # Position on the x-axis
            y=0.5,  # Position on the y-axis
            showarrow=False,  # Do not show arrow
            text="Reference class",  # The text you want to display
            xref="paper",  # 'x' coordinate is set in relative coordinates
            yref="paper",  # 'y' coordinate is set in relative coordinates
            xanchor="left",  # Text starts from the left of the x-coordinate
            yanchor="middle",  # Middle aligned vertically
            font=dict(size=16),
            textangle=90,
        )
        updated_annotations = list(existing_annotations) + [new_annotation]

        # NOTE(review): the filter above is inclusive (>=) while the title reads
        # "pred>"; the title is left unchanged because it also determines the
        # output file names.
        title = f"Mean chrY coverage per file, {','.join(assay_list)} (pred>{threshold})<br>(no fc/pval)"

        fig.update_layout(
            title_text=f"{title} (n={threshold_df.shape[0]})",
            showlegend=False,
            annotations=updated_annotations,
        )

        fig.show()

        title = get_valid_filename(title).replace("_br_", "_")
        fig.write_html(logdir_10fold_per_assay / f"{title}.html")
        fig.write_image(logdir_10fold_per_assay / f"{title}.png", scale=2)

### chrY + chrX + ratio

In [None]:
# merged_df = merged_df[merged_df["Max pred"] > 0.9]
# merged_df = merged_df[
#     ~merged_df["assay_epiclass"].str.contains(pat="input|wgb", case=False)
# ]

# for label in classes:
#     df_label = merged_df[merged_df["Predicted class"] == label]
#     fig = go.Figure()

#     # Iterate classes each target and add a violin plot for it
#     for target in classes:
#         for coverage_label in ["chrY", "chrX", "chrY/chrX"]:
#             sub_df = df_label[df_label["True class"] == target]

#             fig.add_trace(
#                 go.Violin(
#                     y=sub_df[coverage_label],
#                     name=f"{target}: {coverage_label} ({sub_df.shape[0]})",
#                     box_visible=True,
#                     meanline_visible=True,
#                     points="all",
#                     text=sub_df.index,
#                 )
#             )

#     # title = f"Coverage distribution for prediction {label}"
#     title = f"Coverage distribution for prediction {label}, max_pred > 0.9"
#     fig.update_layout(
#         title_text=f"{title} ({df_label.shape[0]})",
#         yaxis_title="Mean coverage",
#         xaxis_title="True class"
#     )
#     fig.update_traces(marker=dict(size=1))
#     fig.update_yaxes(range=[-0.001, 2])


#     fig.show()

#     title = get_valid_filename(title)
#     # fig.write_html(logdir / f"{title}.html")
#     # fig.write_image(logdir / f"{title}.png", scale=2)

In [None]:
# for label in classes:
#     df_label = merged_df[merged_df["True class"] == label]
#     fig = go.Figure()

#     # Iterate classes each target and add a violin plot for it
#     for target in classes:
#         for coverage_label in ["chrY", "chrX", "chrY/chrX"]:
#             sub_df = df_label[df_label["Predicted class"] == target]

#             fig.add_trace(
#                 go.Violin(
#                     y=sub_df[coverage_label],
#                     name=f"{target}: {coverage_label} ({sub_df.shape[0]})",
#                     box_visible=True,
#                     meanline_visible=True,
#                     points="all",
#                     text=sub_df.index,
#                 )
#             )

#     # title = f"Coverage distribution for label {label}"
#     title = f"Coverage distribution for label {label}, max_pred > 0.9"
#     fig.update_layout(
#         title_text=f"{title} ({df_label.shape[0]})",
#         yaxis_title="Mean coverage",
#         xaxis_title="Predicted class",
#     )
#     fig.update_yaxes(range=[-0.001, 1.5])
#     fig.update_traces(marker=dict(size=1))

#     fig.show()

#     title = get_valid_filename(title)
#     # fig.write_html(logdir / f"{title}.html")
#     # fig.write_image(logdir / f"{title}.png", scale=2)

### epiRR version (1 epiRR ~ 1 data point)

In [None]:
# Classes that get a row/column in the per-EpiRR figures.
classes = merged_df["Predicted class"].unique()

# Aggregate per (EpiRR, true class, predicted class): mean/median of the top
# prediction score and of chrY coverage, plus the number of rows in the group.
# NOTE(review): the plotting cell for this table unpacks each index entry as
# `(epirr, count)`, so the index is presumably built from both the restored
# "EpiRR" key and the ("EpiRR", "count") column — confirm with the pandas
# version in use.
epirr_df = (
    merged_df.groupby(["EpiRR", "True class", "Predicted class"])
    .agg({"Max pred": ["mean", "median"], "chrY": ["mean", "median"], "EpiRR": ["count"]})
    .reset_index()
    .set_index("EpiRR")
)

In [None]:
# Display the per-EpiRR aggregation table for inspection.
epirr_df

In [None]:
# test = epirr_df[
#     (epirr_df["True class"] == "mixed")
#     & (epirr_df["Predicted class"] == "female")
#     & (~epirr_df["track_type"].str.contains("fc|pval"))
# ]
# display(test)
# print(test.index.value_counts().shape)

In [None]:
# needed_columns = ["True class", "Predicted class", "EpiRR", "Max pred"]
# merged_df[
#     (merged_df["True class"] == "mixed")
#     & (merged_df["Predicted class"] == "female")
#     & (~merged_df["track_type"].str.contains("fc|pval"))
# ][needed_columns].shape

In [None]:
# One figure per aggregation metric and score threshold. Each figure is a grid
# of violin plots: one row per true (reference) class, one column per predicted
# class, built from the per-EpiRR aggregates.
coverage_label = "chrY"
n_classes = len(classes)

for metric, pred_threshold in itertools.product(["mean", "median"], [0, 0.7, 0.9]):
    # Grid size follows the number of classes so it always matches the titles.
    fig = make_subplots(
        rows=n_classes,
        cols=n_classes,
        shared_yaxes=True,
        x_title="Predicted class (nb of epiRR)",
        y_title=f"{metric} coverage",
        row_titles=list(classes),
        column_titles=list(classes),
        vertical_spacing=0.08,
        horizontal_spacing=0.01,
    )

    # Keep groups whose aggregated top score is strictly above the threshold.
    threshold_sub_df = epirr_df[epirr_df["Max pred"][metric] > pred_threshold]

    # Row index = true class, column index = predicted class.
    for row, label in enumerate(classes, start=1):
        df_label = threshold_sub_df[threshold_sub_df["True class"] == label]

        for col, target in enumerate(classes, start=1):
            sub_df = df_label[df_label["Predicted class"] == target]

            # NOTE(review): each index entry is unpacked as (epirr, count); this
            # presumably relies on how epirr_df's index was built — confirm.
            hovertext = [
                f"{epirr} (n={count}) pred:{pred:.02f}"
                for (epirr, count), pred in zip(sub_df.index, sub_df["Max pred"][metric])
            ]
            fig.add_trace(
                go.Violin(
                    y=sub_df[coverage_label][metric],
                    name=f"{target}: {metric}({coverage_label}) ({sub_df.shape[0]})",
                    box_visible=True,
                    meanline_visible=True,
                    points="all",
                    line_color=COLORS_DICT[target],
                    text=hovertext,
                    hovertemplate="%{text}",
                ),
                row=row,
                col=col,
            )

    # Update global layout and traces
    fig.update_traces(marker=dict(size=1))
    fig.update_yaxes(range=[-0.001, 1.5])

    # Directly using annotations param does not work with make_subplots,
    # so the extra label is appended to the annotations it generated.
    existing_annotations = fig.layout.annotations
    new_annotation = dict(
        x=1.01,  # Position on the x-axis
        y=0.5,  # Position on the y-axis
        showarrow=False,  # Do not show arrow
        text="Reference class",  # The text you want to display
        xref="paper",  # 'x' coordinate is set in relative coordinates
        yref="paper",  # 'y' coordinate is set in relative coordinates
        xanchor="left",  # Text starts from the left of the x-coordinate
        yanchor="middle",  # Middle aligned vertically
        font=dict(size=16),
        textangle=90,
    )
    updated_annotations = list(existing_annotations) + [new_annotation]

    # Only the fc/pval track filter is active on merged_df (the input/wgb filter
    # is commented out), so the title states exactly that.
    title = f"Coverage distribution of {metric}({coverage_label}) per epiRR<br>{metric}(max_pred) > {pred_threshold} (no fc/pval)"

    fig.update_layout(
        title_text=f"{title} (n={threshold_sub_df.shape[0]})",
        showlegend=False,
        annotations=updated_annotations,
    )

    fig.show()

    # title = get_valid_filename(title).replace("_br_", "_")
    # fig.write_html(logdir / f"{title}.html")
    # fig.write_image(logdir / f"{title}.png", scale=2)

## chrY - unknown samples

In [None]:
# Reload the coverage table for this section, this time from the
# chrY_coverage_results sub-directory.
coverage_path = general_local_logdir / "chrY_coverage_results" / "chrXY_coverage_all.csv"
coverage_df = pd.read_csv(coverage_path, index_col=0, header=0)

In [None]:
# Prediction table for this section; the first column is used as index.
unknown_predict_path = (
    general_local_logdir
    / "sex3_complete_no_valid_oversample_test_prediction_100kb_all_none_dfreeze_v2.1_sex_mixed_unknown_augmented-all.csv"
)
unknown_predict_df = pd.read_csv(unknown_predict_path, index_col=0, header=0)

In [None]:
# Keep only rows whose true class is "unknown". `label` is also interpolated
# into a figure title further down.
label = "unknown"
unknown_predict_df = unknown_predict_df[unknown_predict_df["True class"] == label]

In [None]:
# Drop rows whose track type matches "pval" or "fc" (case-insensitive), then
# attach the coverage columns with an inner join on the index.
unknown_predict_df = unknown_predict_df[
    ~unknown_predict_df[TRACK].str.contains(pat="pval|fc", case=False)
]
unknown_predict_df = unknown_predict_df.merge(
    coverage_df, left_index=True, right_index=True, how="inner"
)

### All samples (1 sample = 1 data point)

In [None]:
# Directory where the per-assay figures for the unknown samples are written.
unknown_logdir = general_local_logdir / "chrY_coverage_results" / "unknown_per_assay"

In [None]:
# Predicted classes to plot and the coverage column shown on the y-axis.
classes = unknown_predict_df["Predicted class"].unique()
coverage_label = "chrY"

# Assay groups: one per assay present, plus the fixed pairs from get_assay_list.
assay_labels = get_assay_list(unknown_predict_df)

In [None]:
# One figure per assay group: one subplot row per score threshold, one violin
# per predicted class.
pred_thresholds = [0, 0.7, 0.9]

# write_html/write_image fail when the target directory is missing.
unknown_logdir.mkdir(parents=True, exist_ok=True)

for assay_list in assay_labels:
    # Rows and titles are both derived from the threshold list so they stay in sync.
    fig = make_subplots(
        rows=len(pred_thresholds),
        cols=1,
        subplot_titles=[f"pred>{threshold}" for threshold in pred_thresholds],
        vertical_spacing=0.075,
        x_title="Predicted class (nb of predictions)",
        y_title="Mean coverage",
    )

    assay_sub_df = unknown_predict_df[unknown_predict_df[ASSAY].isin(assay_list)]

    for row_idx, pred_threshold in enumerate(pred_thresholds, start=1):
        threshold_sub_df = assay_sub_df[assay_sub_df["Max pred"] > pred_threshold]

        for target in classes:
            sub_df = threshold_sub_df[threshold_sub_df["Predicted class"] == target]
            sample_count = sub_df.shape[0]

            if sample_count == 0:
                # No sample for this class: plot one placeholder point at the
                # mean coverage of the thresholded subset so the class keeps
                # its slot in the subplot.
                y_values = [threshold_sub_df[coverage_label].mean()]
                hovertext = ["PLACEHOLDER - NO DATA"]
            elif len(assay_list) == 1:
                y_values = sub_df[coverage_label]
                hovertext = [
                    f"{md5sum}:(chrY={chrY_val:.3f}, pred={pred:.3f})"
                    for md5sum, pred, chrY_val in zip(
                        sub_df.index, sub_df["Max pred"], sub_df[coverage_label]
                    )
                ]
            else:
                # Several assays in the group: name the assay in the hover text.
                y_values = sub_df[coverage_label]
                hovertext = [
                    f"{md5sum},{assay}:(chrY={chrY_val:.3f}, pred={pred:.3f})"
                    for md5sum, pred, chrY_val, assay in zip(
                        sub_df.index,
                        sub_df["Max pred"],
                        sub_df[coverage_label],
                        sub_df[ASSAY],
                    )
                ]

            fig.add_trace(
                go.Violin(
                    y=y_values,
                    name=f"{target}: {coverage_label} ({sample_count})",
                    box_visible=True,
                    meanline_visible=True,
                    points="all",
                    text=hovertext,
                    hovertemplate="%{text}",
                    line_color=COLORS_DICT[target],
                    legendgroup=target,
                ),
                row=row_idx,
                col=1,
            )

    title = f"Coverage distribution for {coverage_label} in {','.join(assay_list)} (no fc/pval)"
    fig.update_layout(title_text=f"{title}", height=1200)

    # Upper y bound: largest coverage in the assay group, or 1 when the group
    # has no usable value (Series.max returns NaN for an empty column).
    y_max = assay_sub_df[coverage_label].max()
    if pd.isna(y_max):
        y_max = 1
    fig.update_yaxes(range=[-0.001, y_max])

    fig.update_traces(marker=dict(size=1))

    fig.show()

    title = get_valid_filename(title)
    fig.write_html(unknown_logdir / f"{title}.html")
    fig.write_image(unknown_logdir / f"{title}.png", scale=2)

### epiRR version (1 epiRR ~ 1 data point)

In [None]:
# Drop rows whose assay label contains "wgb" (case-insensitive) before the
# per-EpiRR aggregation.
unknown_predict_df = unknown_predict_df[
    ~unknown_predict_df[ASSAY].str.contains(pat="wgb", case=False)
]

In [None]:
# Predicted classes remaining after the filter above.
classes = unknown_predict_df["Predicted class"].unique()

# Aggregate per (EpiRR, predicted class): mean/median of the top prediction
# score and of chrY coverage, plus the number of rows in the group.
# NOTE(review): the plotting cell for this table unpacks each index entry as
# `(epirr, count)`, so the index is presumably built from both the restored
# "EpiRR" key and the ("EpiRR", "count") column — confirm with the pandas
# version in use.
epirr_df = (
    unknown_predict_df.groupby(["EpiRR", "Predicted class"])
    .agg({"Max pred": ["mean", "median"], "chrY": ["mean", "median"], "EpiRR": ["count"]})
    .reset_index()
    .set_index("EpiRR")
)

In [None]:
# Compare the number of per-sample rows with the number of aggregated rows.
print(unknown_predict_df.shape)
print(epirr_df.shape)

In [None]:
# Preview the first ten aggregated rows.
epirr_df.head(n=10)

In [None]:
# Directory where the per-EpiRR figures for the unknown samples are written.
logdir_unknown_epirr = (
    general_local_logdir / "chrY_coverage_results" / "unknown_per_epirr"
)

In [None]:
# One figure per aggregation metric: one subplot row per score threshold, one
# violin per predicted class, built from the per-EpiRR aggregates.
coverage_label = "chrY"
thresholds = [0, 0.7, 0.9]

# write_html/write_image fail when the target directory is missing.
logdir_unknown_epirr.mkdir(parents=True, exist_ok=True)

for agg_metric in ["mean", "median"]:
    subplot_titles = [f"{agg_metric}(pred)>{threshold}" for threshold in thresholds]
    fig = make_subplots(
        rows=len(thresholds),
        cols=1,
        subplot_titles=subplot_titles,
        vertical_spacing=0.075,
        x_title="Predicted class (nb of predictions)",
        y_title="agg(Mean chrY coverage) ",
    )

    for row_idx, pred_threshold in enumerate(thresholds, start=1):
        # Keep groups whose aggregated top score is strictly above the threshold.
        threshold_sub_df = epirr_df[epirr_df["Max pred"][agg_metric] > pred_threshold]

        for target in classes:
            sub_df = threshold_sub_df[threshold_sub_df["Predicted class"] == target]
            sample_count = sub_df.shape[0]

            if sample_count == 0:
                # No group for this class: plot one placeholder point at the
                # mean of the thresholded subset so the class keeps its slot.
                y_values = [threshold_sub_df[coverage_label][agg_metric].mean()]
                sample_text = ["PLACEHOLDER - NO DATA"]
            else:
                y_values = sub_df[coverage_label][agg_metric]
                # NOTE(review): each index entry is unpacked as (epirr, count);
                # this presumably relies on how epirr_df's index was built — confirm.
                sample_text = [
                    (f"{value:.3f}", epirr, f"{agg_metric}={pred:.3f}(n={count})")
                    for value, (epirr, count), pred in zip(
                        y_values, sub_df.index, sub_df["Max pred"][agg_metric]
                    )
                ]

            fig.add_trace(
                go.Violin(
                    y=y_values,
                    name=f"{target}: {agg_metric}({coverage_label}) ({sample_count})",
                    box_visible=True,
                    meanline_visible=True,
                    points="all",
                    text=sample_text,
                    hovertemplate="%{text}",
                    line_color=COLORS_DICT[target],
                    legendgroup=target,
                ),
                row=row_idx,
                col=1,
            )

    # `label` is the true-class value selected earlier in this section.
    title = f"Coverage distribution of {agg_metric}({coverage_label}) for {label} per epiRR (no fc/pval/wgb)"
    fig.update_layout(
        title_text=f"{title}", height=1200  # Adjust the overall height of the figure
    )

    # Upper y bound: largest aggregated coverage, or 1 when there is no usable
    # value (Series.max returns NaN for an empty column).
    y_max = epirr_df[coverage_label][agg_metric].max()
    if pd.isna(y_max):
        y_max = 1
    fig.update_yaxes(range=[-0.001, y_max])

    fig.update_traces(marker=dict(size=1))

    fig.show()

    title = get_valid_filename(title).replace("_br_", "_")
    fig.write_html(logdir_unknown_epirr / f"{title}.html")
    fig.write_image(logdir_unknown_epirr / f"{title}.png", scale=2)