In [None]:
"""Workbook to create figures destined for the paper."""
# pylint: disable=import-error, redefined-outer-name, use-dict-literal, too-many-lines

In [None]:
# Reload edited project modules (e.g. the epi_ml paper utilities) automatically while iterating.
%load_ext autoreload
%autoreload 2

In [None]:
from __future__ import annotations

import itertools
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from sklearn.metrics import confusion_matrix as sk_cm

from epi_ml.core.confusion_matrix import ConfusionMatrixWriter
from epi_ml.utils.notebooks.paper.paper_utilities import (
    ASSAY,
    ASSAY_ORDER,
    CELL_TYPE,
    IHECColorMap,
    MetadataHandler,
    SplitResultsHandler,
    merge_similar_assays,
)

In [None]:
# Root of all paper outputs; input data and figures live in sibling sub-directories.
base_dir = Path.home().joinpath("Projects", "epiclass", "output", "paper")
base_data_dir, base_fig_dir = (base_dir / sub_dir for sub_dir in ("data", "figures"))
paper_dir = base_dir

In [None]:
# NOTE(review): this rebinds the imported class name `IHECColorMap` to an instance, so
# re-running this cell calls the instance instead of the class. Renaming the instance
# (e.g. `ihec_color_map`) would make the cell re-runnable, but later cells may refer to
# `IHECColorMap` as the instance — check before renaming.
IHECColorMap = IHECColorMap(base_fig_dir)
# Colour lookups indexed by label (e.g. assay_colors[label] in the plotting functions below).
assay_colors = IHECColorMap.assay_color_map
cell_type_colors = IHECColorMap.cell_type_color_map

In [None]:
# Derived from `base_dir` (not from `base_fig_dir` itself) so that re-running this cell
# does not keep appending "fig1_EpiAtlas_assay" to the path.
base_fig_dir = base_dir / "figures" / "fig1_EpiAtlas_assay"

In [None]:
# Helper used below to gather and concatenate per-split classifier results.
split_results_handler = SplitResultsHandler()

In [None]:
# Metadata ("v2" version, presumably matching the dfreeze_v2 results) joined onto predictions below.
metadata_handler = MetadataHandler(paper_dir)
metadata_v2 = metadata_handler.load_metadata("v2")

### Prepare assay predictions data

In [None]:
# Assay-classification split results for every classifier (not only the NN), 100kb bins.
data_dir_100kb = base_data_dir.joinpath(
    "training_results", "dfreeze_v2", "hg38_100kb_all_none"
)
all_split_dfs = split_results_handler.gather_split_results_across_methods(
    only_NN=False,
    label_category=ASSAY,
    results_dir=data_dir_100kb,
)

In [None]:
# One DataFrame per classifier (all splits concatenated), with similar assays merged.
full_dfs = split_results_handler.concatenate_split_results(all_split_dfs)
merged_dfs = {
    clf_name: merge_similar_assays(clf_df) for clf_name, clf_df in full_dfs.items()
}

# The assay labels of the first classifier double as the names of the score columns.
assays = next(iter(merged_dfs.values()))["True class"].unique()

# Per classifier: add the highest class score as "Max pred", then join the v2 metadata.
for classifier, df in list(merged_dfs.items()):
    df["Max pred"] = df[assays].max(axis=1)
    merged_dfs[classifier] = metadata_handler.join_metadata(df, metadata_v2)

### Prediction score per assay

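Violin plot of the average prediction score per assay.
One point per EpiRR (majority-vote predicted class, mean score of the files voting for it); track types are averaged, combining the 2 RNA and the 2 WGBS assays.
Points use 3 colors:
- black: predicted class matches the expected class
- red: predicted class differs from the expected class (possible mislabel)
- orange: bad quality (IHEC flag; removed in later stages) — not yet drawn by `fig1_a`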

In [None]:
def fig1_a(
    NN_results: pd.DataFrame,
    logdir: Path,
    name: str,
    merge_assay_pairs: bool,
    min_y: float = 0.7,
    title: str | None = None,
    seed: int | None = None,
) -> None:
    """
    Creates a Plotly figure with one violin plot per assay class and one jittered
    scatter point per EpiRR. Red scatter points, indicating a mismatch, are drawn
    on top of the black (match) points and have a larger size.

    Each EpiRR's predicted class is the majority vote over its rows, and its score is
    the mean "Max pred" of the rows that voted for that class.

    Args:
        NN_results (pd.DataFrame): The neural network results, including metadata.
            Must contain the "EpiRR", "True class", "Predicted class" and "Max pred" columns.
        logdir (Path): The directory where the figure will be saved (created if missing).
        name (str): The base file name of the figure (saved as html, svg and png).
        merge_assay_pairs (bool): Whether to merge similar assays (mrna/rna, wgbs-pbat/wgbs-standard)
        min_y (float): Lower bound of the y-axis; the upper bound is 1.001.
        title (str | None): Optional suffix appended to the figure title.
        seed (int | None): Seed of the scatter jitter. None gives a different jitter on every call.

    Returns:
        None: Saves and displays the plotly figure.
    """
    fig = go.Figure()
    # Combine similar assays
    if merge_assay_pairs:
        NN_results = merge_similar_assays(NN_results)

    class_index = {label: i for i, label in enumerate(ASSAY_ORDER)}
    rng = np.random.default_rng(seed)
    scatter_offset = 0.05  # Half-width of the scatter jitter

    for label in ASSAY_ORDER:
        df = NN_results[NN_results["True class"] == label]

        # Majority vote per EpiRR: keep the most frequent predicted class and its mean score
        groupby_epirr = (
            df.groupby(["EpiRR", "Predicted class"])["Max pred"]
            .aggregate(["size", "mean"])
            .reset_index()
            .sort_values(["EpiRR", "size"], ascending=[True, False])
            .drop_duplicates(subset="EpiRR", keep="first")
        )
        assert groupby_epirr["EpiRR"].is_unique

        mean_pred = groupby_epirr["mean"].to_numpy()
        x_center = class_index[label]

        fig.add_trace(
            go.Violin(
                x=[x_center] * len(mean_pred),
                y=mean_pred,
                name=label,
                spanmode="hard",
                box_visible=True,
                meanline_visible=True,
                points=False,
                fillcolor=assay_colors[label],
                line_color="white",
                line=dict(width=0.5),
                marker=dict(size=2, color="black"),
                showlegend=False,
            )
        )

        # Scatter points sit left of the violin, with a small horizontal jitter
        jittered_x = (
            rng.uniform(-scatter_offset, scatter_offset, size=len(mean_pred))
            + x_center
            - 0.25
        )
        is_match = (groupby_epirr["Predicted class"] == label).to_numpy(dtype=bool)
        hover_text = np.array(
            [
                f"EpiRR: {epirr}, Pred class: {pred_class}, Mean pred: {score:.3f}"
                for epirr, pred_class, score in zip(
                    groupby_epirr["EpiRR"], groupby_epirr["Predicted class"], mean_pred
                )
            ],
            dtype=object,
        )

        # Matches first (black, small), then mismatches (red, larger) so they are drawn on top
        for mask, color, size, prefix, group in (
            (is_match, "black", 1, "Match", "match"),
            (~is_match, "red", 3, "Mismatch", "mismatch"),
        ):
            fig.add_trace(
                go.Scatter(
                    x=jittered_x[mask].tolist(),
                    y=mean_pred[mask].tolist(),
                    mode="markers",
                    name=f"{prefix} {label}",
                    marker=dict(color=color, size=size),
                    hovertemplate="%{text}",
                    text=hover_text[mask].tolist(),
                    showlegend=False,
                    legendgroup=group,
                )
            )

    # Update layout to improve visualization
    fig.update_yaxes(range=[min_y, 1.001])
    fig.update_xaxes(tickvals=np.arange(len(ASSAY_ORDER)), ticktext=ASSAY_ORDER)

    title_text = "Prediction score distribution per assay class"
    if title is not None:
        title_text += f" - {title}"
    fig.update_layout(
        title_text=title_text,
        yaxis_title="Avg. prediction score (majority class)",
        xaxis_title="Expected class label",
    )

    # Dummy (empty) traces that only provide the legend entries
    for legend_name, color, group in (
        ("Match", "black", "match"),
        ("Mismatch", "red", "mismatch"),
    ):
        fig.add_trace(
            go.Scatter(
                x=[None],
                y=[None],
                mode="markers",
                name=legend_name,
                marker=dict(color=color, size=10),
                showlegend=True,
                legendgroup=group,
            )
        )

    # Update the layout to adjust the legend
    fig.update_layout(
        legend=dict(
            title_text="Legend",
            itemsizing="constant",
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1,
        )
    )

    # Save figure (writes fail if the directory does not exist)
    logdir.mkdir(parents=True, exist_ok=True)
    fig.write_html(logdir / f"{name}.html")
    fig.write_image(logdir / f"{name}.svg")
    fig.write_image(logdir / f"{name}.png")

    fig.show()

In [None]:
# Neural-network results (assays merged, with "Max pred" and metadata columns) for fig1_a.
NN_results = merged_dfs["NN"]

In [None]:
logdir = base_fig_dir / "fig1--pred_score_per_assay"
# Uncomment to regenerate the figure. NN_results comes from `merged_dfs`, whose assays
# are already merged.
# fig1_a(NN_results, logdir=logdir, name="pred_score_per_assay_internal", merge_assay_pairs=False)

### Performance of classification algorithm (boxplot)

Box plots (10 folds) of the overall performance of each model (NN, LR, RF, LGBM, LinearSVC).  
For each model, 4 box plots (one point per split):
  - Accuracy
  - F1 (macro)
  - AUROC (OvR, both micro and macro)

Source files:  
epiatlas-dfreeze-v2.1/hg38_100kb_all_none (oversampling)

In [None]:
# Classifier names compared in this notebook; plot_split_metrics_one_algo validates
# its `classifier_name` argument against this list.
ALL_CLASSIFIERS = ["NN", "LR", "LGBM", "LinearSVC", "RF"]

In [None]:
def plot_split_metrics_one_algo(
    split_metrics: Dict[str, Dict[str, Dict[str, float]]],
    label_category: str,
    logdir: Path,
    name: str,
    classifier_name: str = "NN",
) -> None:
    """Render box plots of the per-split metrics of a single classifier.

    Accuracy and F1_macro share the left subplot; AUC_micro and AUC_macro share the
    right subplot. Each box holds one point per split.

    Args:
        split_metrics: Metric scores, indexed as split_metrics[split][classifier][metric].
        label_category: Classification task; selects the y-axis ranges and the title.
        logdir: Directory where the figure is saved (created if missing).
        name: Base file name of the saved figure (svg, png and html).
        classifier_name: Classifier whose metrics are plotted; must be in ALL_CLASSIFIERS.

    Raises:
        ValueError: If classifier_name is not in ALL_CLASSIFIERS.
    """
    if classifier_name not in ALL_CLASSIFIERS:
        raise ValueError(f"Invalid classifier name: {classifier_name}")
    metrics = ["Accuracy", "F1_macro", "AUC_micro", "AUC_macro"]
    split_names = list(split_metrics)

    # Two side-by-side subplots: accuracy-like metrics (col 1), AUC metrics (col 2)
    fig = make_subplots(rows=1, cols=2, subplot_titles=["Accuracy/F1", "AUC micro/macro"])

    for i, metric in enumerate(metrics):
        # Index with the requested classifier (this used to read a leaked global `classifier`)
        values = [split_metrics[split][classifier_name][metric] for split in split_names]

        fig.add_trace(
            go.Box(
                y=values,
                name=metric,
                line=dict(color="black", width=1.5),
                marker=dict(size=3, color="black"),
                boxmean=True,
                boxpoints="all",  # or "outliers" to show only outliers
                pointpos=-1.4,
                showlegend=False,
                hovertemplate="%{text}",
                text=[f"{split}: {value:.4f}" for split, value in zip(split_names, values)],
            ),
            row=1,
            col=1 if i < 2 else 2,
        )

    fig.update_layout(
        title_text=f"{label_category} {classifier_name} classification - Metric distribution for 10fold cross-validation",
        yaxis_title="Value",
    )

    # Adjust y-axis
    if label_category == ASSAY:
        range_acc = [0.98, 1.001]
        range_AUC = [0.996, 1.0001]
    elif label_category == CELL_TYPE:
        range_acc = [0.93, 1]
        range_AUC = [0.996, 1]
    else:
        range_acc = [0.6, 1.001]
        range_AUC = [0.9, 1.0001]

    fig.update_layout(yaxis=dict(range=range_acc))
    fig.update_layout(yaxis2=dict(range=range_AUC))

    # Save figure (writes fail if the directory does not exist)
    logdir.mkdir(parents=True, exist_ok=True)
    fig.write_image(logdir / f"{name}.svg")
    fig.write_image(logdir / f"{name}.png")
    fig.write_html(logdir / f"{name}.html")

    fig.show()

In [None]:
def plot_split_metrics(
    split_metrics: Dict[str, Dict[str, Dict[str, float]]],
    label_category: str,
    logdir: Path,
    name: str,
) -> None:
    """Render box plots of each metric per classifier, one subplot per metric.

    Args:
        split_metrics: Metric scores, indexed as split_metrics[split][classifier][metric].
            The classifiers present in the first split determine which boxes are drawn.
        label_category: Classification task; selects the y-axis ranges and the title.
        logdir: Directory where the figure is saved (created if missing).
        name: Base file name of the saved figure (svg, png and html).

    Raises:
        ValueError: If split_metrics is empty.
    """
    if not split_metrics:
        raise ValueError("split_metrics is empty")

    metrics = ["Accuracy", "F1_macro", "AUC_micro", "AUC_macro"]
    # Fixed display order, restricted to the classifiers actually present
    # (a missing classifier used to raise a KeyError).
    display_order = ["NN", "LR", "LGBM", "LinearSVC", "RF"]
    available = next(iter(split_metrics.values())).keys()
    classifier_names = [clf for clf in display_order if clf in available]
    split_names = list(split_metrics)

    # One column per metric
    fig = make_subplots(
        rows=1,
        cols=len(metrics),
        subplot_titles=metrics,
        horizontal_spacing=0.075,
    )

    for i, metric in enumerate(metrics):
        for classifier in classifier_names:
            values = [split_metrics[split][classifier][metric] for split in split_names]

            fig.add_trace(
                go.Box(
                    y=values,
                    name=classifier,
                    line=dict(color="black", width=1.5),
                    marker=dict(size=3, color="black"),
                    boxmean=True,
                    boxpoints="all",  # or "outliers" to show only outliers
                    pointpos=-1.4,
                    showlegend=False,
                    width=0.5,
                    hovertemplate="%{text}",
                    text=[
                        f"{split}: {value:.4f}"
                        for split, value in zip(split_names, values)
                    ],
                ),
                row=1,
                col=i + 1,
            )

    fig.update_layout(
        title_text=f"{label_category} classification - Metric distribution for 10fold cross-validation",
        yaxis_title="Value",
        boxmode="group",
    )

    # Adjust y-axis
    if label_category == ASSAY:
        range_acc = [0.98, 1.001]
        range_AUC = [0.996, 1.0001]
    elif label_category == CELL_TYPE:
        range_acc = [0.81, 1]
        range_AUC = [0.96, 1]
    else:
        range_acc = [0.6, 1.001]
        range_AUC = [0.9, 1.0001]

    # Accuracy/F1 subplots share one range, both AUC subplots share another
    fig.update_layout(yaxis=dict(range=range_acc))
    fig.update_layout(yaxis2=dict(range=range_acc))
    fig.update_layout(yaxis3=dict(range=range_AUC))
    fig.update_layout(yaxis4=dict(range=range_AUC))

    # Save figure (writes fail if the directory does not exist)
    logdir.mkdir(parents=True, exist_ok=True)
    fig.write_image(logdir / f"{name}.svg")
    fig.write_image(logdir / f"{name}.png")
    fig.write_html(logdir / f"{name}.html")

    fig.show()

In [None]:
data_dir_100kb = base_data_dir / "training_results" / "dfreeze_v2" / "hg38_100kb_all_none"
fig_dir = base_fig_dir / "fig1--boxplot_perf_all_algorithms"
merge_assays = True  # merge similar assays (mrna/rna, wgbs-pbat/wgbs-standard) for the assay task

# fig dir path
# Suffix presumably encodes the number of assay classes: 9 when merged, 11 otherwise.
assay_fig_dir = fig_dir / "assay_computed"
if merge_assays:
    assay_fig_dir = Path(f"{assay_fig_dir}_9c")
else:
    assay_fig_dir = Path(f"{assay_fig_dir}_11c")

# NOTE(review): the metric computation and plotting calls below are commented out, so this
# loop currently only reloads split results. It also rebinds the module-level
# `all_split_dfs` (the assay results from the earlier cell) to the last category's results.
for label_category in [ASSAY, CELL_TYPE]:
    all_split_dfs = split_results_handler.gather_split_results_across_methods(
        results_dir=data_dir_100kb, label_category=label_category
    )

    this_fig_dir = assay_fig_dir if label_category == ASSAY else fig_dir

    name_base = f"{label_category}_10fold_metrics"
    if merge_assays and label_category == ASSAY:
        name_base += "_merged_assays"

        # Merge similar assays in every split/classifier frame (replaces existing keys only)
        for split_name, split_dfs in all_split_dfs.items():
            for classifier_type, df in split_dfs.items():
                all_split_dfs[split_name][classifier_type] = merge_similar_assays(df)

    # split_metrics = split_results_handler.compute_split_metrics(all_split_dfs)

    # plot_split_metrics(
    #     split_metrics,
    #     label_category=label_category,
    #     logdir=this_fig_dir,
    #     name=f"{name_base}_all_classifiers_y0.98",
    # )

    # plot_split_metrics_one_algo(
    #     split_metrics,
    #     label_category=label_category,
    #     logdir=this_fig_dir,
    #     name=f"{name_base}_NN_y0.98",
    # )
    # break

### Prediction score per assay across classifiers

Per model, compute the score distribution per assay (one violin per assay). LinearSVC and RF are excluded. Black points: the prediction agrees with the expected class; red points: it disagrees.

In [None]:
def fig1_supp_B(
    df_dict: Dict[str, pd.DataFrame],
    logdir: Path,
    name: str,
    excluded_classifiers: Tuple[str, ...] = ("LinearSVC", "RF"),
    seed: int | None = None,
) -> None:
    """
    Creates a Plotly figure with subplots for each assay, each containing violin plots for different classifiers
    and associated scatter plots for matches (in black) and mismatches (in red).

    One point per EpiRR: majority-vote predicted class, mean "Max pred" of the rows voting for it.

    Args:
        df_dict (Dict[str, pd.DataFrame]): Dictionary with the DataFrame containing the results for each classifier.
        logdir (Path): The directory path for saving the figures (created if missing).
        name (str): The name for the saved figures.
        excluded_classifiers (Tuple[str, ...]): Classifiers of df_dict that are not plotted.
        seed (int | None): Seed of the scatter jitter. None gives a different jitter on every call.

    Returns:
        None: Saves and displays the plotly figure.
    """
    class_labels_sorted = ASSAY_ORDER
    num_assays = len(class_labels_sorted)

    # Exclude each classifier independently (the former try/remove chain skipped
    # removing "RF" whenever "LinearSVC" was absent).
    classifiers = [clf for clf in df_dict if clf not in excluded_classifiers]
    classifier_index = {clf: i for i, clf in enumerate(classifiers)}
    num_classifiers = len(classifiers)

    rng = np.random.default_rng(seed)
    scatter_offset = 0.05  # Half-width of the scatter jitter

    # Square grid big enough for every assay
    grid_size = int(np.ceil(np.sqrt(num_assays)))
    rows, cols = grid_size, grid_size

    fig = make_subplots(
        rows=rows,
        cols=cols,
        subplot_titles=class_labels_sorted,
        shared_yaxes="all",  # type: ignore
        horizontal_spacing=0.05,
        vertical_spacing=0.05,
        y_title="Average prediction score",
    )
    for idx, label in enumerate(class_labels_sorted):
        subplot_row, subplot_col = divmod(idx, grid_size)
        subplot_pos = dict(row=subplot_row + 1, col=subplot_col + 1)  # Plotly is 1-indexed
        for classifier_name in classifiers:
            df = df_dict[classifier_name]
            df = df[df["True class"] == label]

            # Majority vote per EpiRR: keep the most frequent predicted class and its mean score
            groupby_epirr = (
                df.groupby(["EpiRR", "Predicted class"])["Max pred"]
                .aggregate(["size", "mean"])
                .reset_index()
                .sort_values(["EpiRR", "size"], ascending=[True, False])
                .drop_duplicates(subset="EpiRR", keep="first")
            )
            assert groupby_epirr["EpiRR"].is_unique

            mean_pred = groupby_epirr["mean"].to_numpy()
            classifier_pos = classifier_index[classifier_name]

            # Violin centred on the classifier's integer x position
            fig.add_trace(
                go.Violin(
                    x=classifier_pos * np.ones(len(mean_pred)),
                    y=mean_pred,
                    name=label,
                    spanmode="hard",
                    box_visible=True,
                    meanline_visible=True,
                    points=False,
                    fillcolor="grey",
                    line_color="black",
                    line=dict(width=0.8),
                    showlegend=False,
                ),
                **subplot_pos,
            )

            # Scatter points sit left of the violin, with a small horizontal jitter
            jittered_x = (
                rng.uniform(-scatter_offset, scatter_offset, size=len(mean_pred))
                + classifier_pos
                - 0.3
            )
            is_match = (groupby_epirr["Predicted class"] == label).to_numpy(dtype=bool)

            # Matches first (black, small), then mismatches (red, larger) so they are drawn on top
            for mask, color, size, prefix, group in (
                (is_match, "black", 1, "Match", "match"),
                (~is_match, "red", 3, "Mismatch", "mismatch"),
            ):
                fig.add_trace(
                    go.Scatter(
                        x=jittered_x[mask].tolist(),
                        y=mean_pred[mask].tolist(),
                        mode="markers",
                        marker=dict(color=color, size=size),
                        showlegend=False,
                        name=f"{prefix} {classifier_name}",
                        legendgroup=group,
                    ),
                    **subplot_pos,
                )

    # Dummy (empty) traces that only provide the legend entries
    for legend_name, color, group in (
        ("Match", "black", "match"),
        ("Mismatch", "red", "mismatch"),
    ):
        fig.add_trace(
            go.Scatter(
                x=[None],
                y=[None],
                mode="markers",
                name=legend_name,
                marker=dict(color=color, size=10),
                showlegend=True,
                legendgroup=group,
            )
        )

    # Update the layout to adjust the legend
    fig.update_layout(
        legend=dict(
            title_text="Legend",
            itemsizing="constant",
            orientation="h",
            yanchor="bottom",
            y=1.025,
            xanchor="right",
            x=1,
        )
    )

    fig.update_layout(
        title_text="Prediction score distribution per assay across classifiers",
        height=1500,
        width=1500,
    )

    # One tick per classifier at its 0-based x position (previously one extra, unlabeled tick)
    tickvals = list(range(num_classifiers))
    for i, j in itertools.product(range(rows), range(cols)):
        fig.update_xaxes(tickvals=tickvals, ticktext=classifiers, row=i + 1, col=j + 1)

    # Save figure (writes fail if the directory does not exist)
    logdir.mkdir(parents=True, exist_ok=True)
    fig.write_image(logdir / f"{name}.svg")
    fig.write_image(logdir / f"{name}.png")
    fig.write_html(logdir / f"{name}.html")

    fig.show()

In [None]:
# print(len(merged_dfs['NN']))
# for metadata_category in ["EpiRR", "uuid"]:
#     print(f"{metadata_category}: {len(merged_dfs['NN'][metadata_category].unique())}")

# display(merged_dfs["NN"][ASSAY].value_counts())
# display(merged_dfs["NN"]["track_type"].value_counts())

In [None]:
logdir = base_fig_dir / "fig1--pred_score_per_assay_across_classifiers"
# Uncomment to regenerate the figure (LinearSVC and RF are not plotted).
# fig1_supp_B(merged_dfs, logdir=logdir, name="fig1_supp_B_3classifiers")

In [None]:
def fig1_supp_B_2(
    df_dict: Dict[str, pd.DataFrame], logdir: Path, name: str, min_y: float = 0.7
) -> None:
    """
    pred_score_per_assay_across_classifiers: 1 graph per classifier, one violin per assay.

    Points are one per EpiRR: majority-vote predicted class, mean "Max pred" of the
    rows voting for it.

    Args:
        df_dict (Dict[str, pd.DataFrame]): Dictionary with the DataFrame containing the results for each classifier.
        logdir (Path): The directory path for saving the figures (created if missing).
        name (str): Base name for the saved figures; "_<classifier>" is appended.
        min_y (float): Lower bound of the y-axis; the upper bound is 1.001.

    Returns:
        None: Saves and displays one plotly figure per classifier.
    """
    # Writes fail if the directory does not exist
    logdir.mkdir(parents=True, exist_ok=True)

    for classifier_name, classifier_df in df_dict.items():
        fig = go.Figure()
        for assay in ASSAY_ORDER:
            assay_df = classifier_df[classifier_df["True class"] == assay]

            # Majority vote per EpiRR: keep the most frequent predicted class and its mean score
            groupby_epirr = (
                assay_df.groupby(["EpiRR", "Predicted class"])["Max pred"]
                .aggregate(["size", "mean"])
                .reset_index()
                .sort_values(["EpiRR", "size"], ascending=[True, False])
                .drop_duplicates(subset="EpiRR", keep="first")
            )
            assert groupby_epirr["EpiRR"].is_unique

            fig.add_trace(
                go.Violin(
                    y=groupby_epirr["mean"],
                    name=assay,
                    spanmode="hard",
                    box_visible=True,
                    meanline_visible=True,
                    points="all",
                    pointpos=0,
                    fillcolor=assay_colors[assay],
                    line_color="white",
                    line=dict(width=0.5),
                    marker=dict(size=2, color="black"),
                    showlegend=False,
                )
            )

        fig.update_yaxes(range=[min_y, 1.001])
        fig.update_layout(
            title_text=f"{classifier_name} - Prediction score distribution per assay",
        )

        # Save figure
        this_name = f"{name}_{classifier_name}"
        fig.write_html(logdir / f"{this_name}.html")
        fig.write_image(logdir / f"{this_name}.svg")
        fig.write_image(logdir / f"{this_name}.png")

        fig.show()

In [None]:
logdir = base_fig_dir / "fig1--pred_score_per_assay_across_classifiers"
# Uncomment to regenerate one figure per classifier in `merged_dfs`.
# fig1_supp_B_2(merged_dfs, logdir=logdir, name="fig1_supp_B_2")

### Confusion matrices across classification algorithms

For each classifier type:

Confusion matrix (1 point = 1 uuid) for observed datasets with an average score > 0.9.
- Goal: represent global predictions/mislabels, using the 11 (unmerged) assay classes.

In [None]:
def create_confusion_matrix(
    df: pd.DataFrame,
    min_pred_score: float,
    logdir: Path,
    name: str,
    majority: bool = False,
) -> None:
    """Create and save a confusion matrix for the given prediction results.

    The input DataFrame is not modified.

    Args:
        df (pd.DataFrame): Prediction results with "True class" and "Predicted class"
            columns, plus either a "Max pred" column or one score column per class
            (named after the class labels, used to compute "Max pred").
        min_pred_score (float): Only rows with "Max pred" strictly above this value are kept.
        logdir (Path): The directory path for saving the figures.
        name (str): Base name for the saved files; `_n{number of counted rows}` is appended.
        majority (bool): Whether to use majority vote (uuid-wise) for the predicted class.

    Raises:
        ValueError: With `majority=True`, if a (uuid, true class, predicted class)
            group has more than three rows.
    """
    classes = sorted(df["True class"].unique())

    # Use a local Series for the max score instead of adding a column to the
    # caller's DataFrame (which would leak state between notebook cells).
    if "Max pred" in df.columns:
        max_pred = df["Max pred"]
    else:
        max_pred = df[classes].max(axis=1)  # type: ignore
    keep = (max_pred > min_pred_score).to_numpy()
    filtered_df = df.assign(**{"Max pred": max_pred}).loc[keep]

    if majority:
        # Majority vote for predicted class
        groupby_uuid = filtered_df.groupby(["uuid", "True class", "Predicted class"])[
            "Max pred"
        ].aggregate(["size", "mean"])

        if groupby_uuid["size"].max() > 3:
            raise ValueError("More than three predictions for the same uuid")

        # Most frequent predicted class first, then keep one row per (uuid, true class).
        groupby_uuid = groupby_uuid.reset_index().sort_values(
            ["uuid", "True class", "size"], ascending=[True, True, False]
        )
        groupby_uuid = groupby_uuid.drop_duplicates(
            subset=["uuid", "True class"], keep="first"
        )
        filtered_df = groupby_uuid

    confusion_mat = sk_cm(
        filtered_df["True class"], filtered_df["Predicted class"], labels=classes
    )

    mat_writer = ConfusionMatrixWriter(labels=classes, confusion_matrix=confusion_mat)
    mat_writer.to_all_formats(logdir, name=f"{name}_n{len(filtered_df)}")

In [None]:
# Supplementary confusion matrices: per file and per uuid (majority vote), at
# minimum prediction scores 0 and 0.9, for every classifier in `full_dfs`.
# Disabled by default; set to True to regenerate the files.
GENERATE_CONFUSION_MATRICES = False

if GENERATE_CONFUSION_MATRICES:
    cm_base_dir = base_fig_dir / "fig1_supp_F-assay_c11_confusion_matrices"
    for cm_min_pred_score, cm_majority in itertools.product([0, 0.9], [True, False]):
        for cm_classifier_name, cm_df in full_dfs.items():
            cm_df_with_meta = metadata_handler.join_metadata(cm_df, metadata_v2)
            assert "Predicted class" in cm_df_with_meta.columns

            cm_name = f"{cm_classifier_name}_pred>{cm_min_pred_score}"
            if cm_classifier_name == "LinearSVC":
                # LinearSVC file names carry no threshold tag.
                cm_name = f"{cm_classifier_name}"

            cm_logdir = cm_base_dir / ("per_uuid" if cm_majority else "per_file")

            create_confusion_matrix(
                df=cm_df_with_meta,
                min_pred_score=cm_min_pred_score,
                logdir=cm_logdir,
                name=cm_name,
                majority=cm_majority,
            )

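### Imputed prediction score & accuracy

Compare a model trained on observed data and applied to imputed data with a model trained on imputed data and applied to observed data. For each, plot violin plots of the prediction score per assay (like Fig. 1A) and the accuracy per assay. Both models are also applied to C-A data.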

In [None]:
# Root output folder for the imputation-impact figures (must already exist).
section_folder = base_fig_dir.joinpath(
    "fig1--imputation_impact", "complete_no_valid_oversample"
)
if not section_folder.exists():
    raise FileNotFoundError(f"Folder {section_folder} does not exist")

In [None]:
# C-A metadata table; its first column holds the dataset identifier.
ca_metadata_path = base_data_dir.joinpath(
    "training_results",
    "predictions",
    "C-A",
    "assay_epiclass",
    "CA_metadata_4DB+all_pred_subset.20240606_mod2.tsv",
)
ca_metadata = pd.read_csv(ca_metadata_path, sep="\t")
ca_id_col = ca_metadata.columns[0]
ca_target = ca_metadata.loc[:, [ca_id_col, "manual_target_consensus"]]

In [None]:
data_dir = base_data_dir / "training_results" / "imputation"


def _first_csv(folder: Path) -> Path:
    """Return the first `*.csv` file (sorted by name) directly inside `folder`.

    Sorting makes the choice deterministic if several CSV files are present.

    Raises:
        FileNotFoundError: If `folder` contains no CSV file.
    """
    csv_files = sorted(folder.glob("*.csv"))
    if not csv_files:
        raise FileNotFoundError(f"No CSV file found in {folder}")
    return csv_files[0]


def _read_predictions(path: Path) -> pd.DataFrame:
    """Read a predictions CSV, using its first column as the index."""
    return pd.read_csv(path, header=0, index_col=0, low_memory=False)


def _with_ca_true_class(pred_df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of `pred_df` whose "True class" is the C-A manual target consensus.

    Inner merge of the prediction index with the C-A id column (`ca_id_col`), so
    predictions without C-A metadata are dropped. The helper column
    "manual_target_consensus" is removed afterwards.
    """
    merged = pred_df.merge(ca_target, left_index=True, right_on=ca_id_col)
    merged["True class"] = merged["manual_target_consensus"]
    return merged.drop(columns="manual_target_consensus")


# Load data
# Training on observed data
observed_dir = (
    data_dir
    / "hg38_100kb_all_none"
    / "assay_epiclass_1l_3000n"
    / "chip-seq-only"
    / "complete_no_valid_oversample"
)
observed_inf_imputed_path = _first_csv(observed_dir / "predict_imputed")
observed_inf_CA_path = _first_csv(observed_dir / "predict_C-A")

observed_inf_imputed_df = _read_predictions(observed_inf_imputed_path)
observed_inf_CA_df = _read_predictions(observed_inf_CA_path)

# Training on imputed data
imputed_dir = (
    data_dir
    / "hg38_100kb_all_none_imputed"
    / "assay_epiclass_1l_3000n"
    / "chip-seq-only"
    / "complete_no_valid_oversample"
)
imputed_inf_observed_path = _first_csv(imputed_dir / "predict_epiatlas_pval_chip-seq")
imputed_inf_CA_path = _first_csv(imputed_dir / "predict_C-A")

imputed_inf_observed_df = _read_predictions(imputed_inf_observed_path)
imputed_inf_CA_df = _read_predictions(imputed_inf_CA_path)

# C-A predictions take their true class from the C-A metadata.
imputed_inf_CA_df = _with_ca_true_class(imputed_inf_CA_df)
observed_inf_CA_df = _with_ca_true_class(observed_inf_CA_df)

In [None]:
assay_labels = observed_inf_imputed_df["True class"].unique()

# (name, predictions) pairs; names follow "<training data>_inf_<inference data>".
_prediction_sets = [
    ("observed_inf_imputed", observed_inf_imputed_df),
    ("observed_inf_C-A", observed_inf_CA_df),
    ("imputed_inf_observed", imputed_inf_observed_df),
    ("imputed_inf_C-A", imputed_inf_CA_df),
]
var_names = [set_name for set_name, _ in _prediction_sets]
var_list = [set_df for _, set_df in _prediction_sets]

In [None]:
# Get C-A core7 all track types preds (columns stored in the C-A metadata table).
wanted_cols = ["True class", "Predicted class", "Max pred"]
ca_metadata = ca_metadata.rename(
    columns=dict(
        zip(
            ["manual_target_consensus", "Predicted_class_assay7", "Max_pred_assay7"],
            wanted_cols,
        )
    )
)

is_known_assay = ca_metadata["True class"].isin(assay_labels)
ca_df = ca_metadata.loc[is_known_assay, ["Experimental-id"] + wanted_cols]

# Exclude predictions as 'input'. Copy so later column assignments do not act on a
# slice of ca_metadata (SettingWithCopyWarning).
ca_df = ca_df[ca_df["Predicted class"] != "input"].copy()
name = "obs_core7_inf_C-A"

# Register this prediction set; replace it instead of appending a duplicate when
# the cell is re-run.
if name in var_names:
    var_list[var_names.index(name)] = ca_df
else:
    var_names.append(name)
    var_list.append(ca_df)

#### Graph prediction scores

In [None]:
this_fig_dir = section_folder / "pred_scores"
if not this_fig_dir.exists():
    raise FileNotFoundError(f"Folder {this_fig_dir} does not exist")

# Add, in place on every prediction set, the columns used downstream:
# "EpiRR" (copy of the index), ASSAY (copy of "True class") and "Max pred".
for name, df in zip(var_names, var_list):
    print(df.shape)
    df["EpiRR"] = df.index.to_list()
    df[ASSAY] = df["True class"]
    if "Max pred" not in df.columns:
        df["Max pred"] = df[assay_labels].max(axis=1)

    name = f"train_{name}"
    # Plotting is disabled. To regenerate, run for min_y in (0.7, 0.9):
    # fig1_a(
    #     df,
    #     logdir=this_fig_dir,
    #     name=f"imputation_impact-{name}_n{len(df)}_minY{min_y:.1f}",
    #     merge_assay_pairs=False,
    #     min_y=min_y,
    #     title=name,
    # )

In [None]:
this_fig_dir = section_folder / "acc_per_assay"
# Fail early (like the other output folders) instead of at the first file write.
if not this_fig_dir.exists():
    raise FileNotFoundError(f"Folder {this_fig_dir} does not exist")

In [None]:
# {task_name: {assay: [(min_pred, acc, nb_samples), ...], ...}, ...}
all_acc_per_assay = {}
for name, df in zip(var_names, var_list):
    if "Max pred" not in df.columns:
        df["Max pred"] = df[assay_labels].max(axis=1)

    name = f"train_{name}"
    acc_per_assay: Dict[str, List[Tuple[str, float, int]]] = {}
    for label in assay_labels:
        label_df = df[df["True class"] == label]
        entries = []
        for min_pred in ["0.0", "0.6", "0.9"]:
            # Accuracy restricted to predictions strictly above the threshold.
            confident = label_df[label_df["Max pred"] > float(min_pred)]
            is_correct = confident["True class"] == confident["Predicted class"]
            entries.append((min_pred, is_correct.mean(), len(confident)))
        acc_per_assay[label] = entries

    all_acc_per_assay[name] = acc_per_assay

In [None]:
# Flatten all_acc_per_assay into a table, one row per (task, assay, min pred score):
# cols = [classifier+task, assay, min_pred, acc, nb_samples]
rows = [
    [task_name, assay, min_pred, acc, nb_samples]
    for task_name, per_assay in all_acc_per_assay.items()
    for assay, values in per_assay.items()
    for min_pred, acc, nb_samples in values
]
df_acc_per_assay = pd.DataFrame(
    data=rows, columns=["task_name", "assay", "min_predScore", "acc", "nb_samples"]
)
acc_table_path = this_fig_dir / "imputation_impact_acc_per_assay.tsv"
df_acc_per_assay.to_csv(acc_table_path, sep="\t", index=False)

In [None]:
min_predScore_color_map = {"0.0": "blue", "0.6": "orange", "0.9": "red"}

# Short task label for the x axis, e.g. "train_observed_inf_C-A" -> "obs_inf_C-A".
short_task_names = df_acc_per_assay["task_name"]
for pattern, replacement in [("train_", ""), ("imputed", "imp"), ("observed", "obs")]:
    short_task_names = short_task_names.replace(pattern, replacement, regex=True)
df_acc_per_assay["scatter_name"] = short_task_names

# Inference target = last "_"-separated token of the short name (e.g. "C-A").
df_acc_per_assay["inf_target"] = df_acc_per_assay["scatter_name"].str.split("_").str[-1]
df_acc_per_assay["inf_target"] = df_acc_per_assay["scatter_name"].str.split("_").str[-1]

In [None]:
# Plotting order: assay, then score threshold, then task.
sort_keys = ["assay", "min_predScore", "scatter_name"]
df_acc_per_assay = df_acc_per_assay.sort_values(by=sort_keys)

In [None]:
# Graph to draw: "no_C-A", "all", "only_C-A" or "only_C-A+core7".
graph_type = "only_C-A+core7"

# graph_type -> (minY, maxY, trace_per_assay)
acc_graph_settings = {
    "no_C-A": (0.97, 1.001, 2),
    "all": (0.7, 1.005, 4),
    "only_C-A": (0.7, 1.005, 2),
    "only_C-A+core7": (0.7, 1.005, 3),
}
if graph_type not in acc_graph_settings:
    raise ValueError(f"Unknown graph_type: {graph_type}")
minY, maxY, trace_per_assay = acc_graph_settings[graph_type]

graph_df = df_acc_per_assay.copy()

if graph_type == "no_C-A":
    graph_df = graph_df[graph_df["inf_target"] != "C-A"]
elif "only_C-A" in graph_type:
    graph_df = graph_df[graph_df["inf_target"] == "C-A"]

# The core7 task only appears in the "only_C-A+core7" graph.
if graph_type != "only_C-A+core7":
    graph_df = graph_df[~graph_df["scatter_name"].str.contains("core7")]

# Calculate average over assays
avg_df = graph_df.groupby(["min_predScore", "scatter_name"])["acc"].mean().reset_index()
avg_df["assay"] = "Average"

fig = go.Figure()

for min_pred in ["0.0", "0.6", "0.9"]:
    df_subset = graph_df[graph_df["min_predScore"] == min_pred]
    avg_subset = avg_df[avg_df["min_predScore"] == min_pred]

    # Add average over assay trace
    fig.add_trace(
        go.Scatter(
            x=["Average - " + name for name in avg_subset["scatter_name"]],
            y=avg_subset["acc"],
            mode="markers",
            name=f"Avg Min Pred Score: {min_pred}",
            marker=dict(
                color=min_predScore_color_map[min_pred],
                size=9,
                symbol="star",
            ),
            hoverinfo="y+x",
            showlegend=False,
        )
    )

    # Add individual assay traces (Plotly expects one hover string per point).
    hovertext = [
        f"{assay}<br>Samples: {nb_samples}"
        for assay, nb_samples in zip(df_subset["assay"], df_subset["nb_samples"])
    ]
    fig.add_trace(
        go.Scatter(
            x=df_subset["assay"] + " - " + df_subset["scatter_name"],
            y=df_subset["acc"],
            mode="markers",
            name=f"Min Pred Score: {min_pred}",
            marker=dict(
                color=min_predScore_color_map[min_pred],
                size=9,
            ),
            text=hovertext,
            hoverinfo="text+y+x",
        )
    )

# Modify x-axis tick labels.
# Categories are registered in trace order: the "Average" group first, then one
# group per assay (sorted); every group lists the same task names, sorted.
task_names = sorted(graph_df["scatter_name"].unique())
tick_group = []
for tick in task_names:
    train, inf = tick.split("_inf_")
    tick_group.append(f"<b>{train}</b> \u2192 <b>{inf}</b>")

nb_groups = len(assay_labels) + 1  # +1 for the "Average" group
ticktext = tick_group * nb_groups

fig.update_xaxes(tickmode="array", ticktext=ticktext, tickvals=list(range(len(ticktext))))

# Add assay labels on top + vertical lines between assay groups
fig.add_annotation(
    x=len(tick_group) / 2 - 0.5,
    y=1.05,
    yref="paper",
    text="Average",
    showarrow=False,
    font=dict(size=14),
)

fig.add_vline(
    x=len(tick_group) - 0.5, line_width=2, line_dash="solid", line_color="black"
)

for i, label in enumerate(sorted(assay_labels)):
    fig.add_annotation(
        x=(i + 1) * len(tick_group) + len(tick_group) / 2 - 0.5,
        y=1.05,
        yref="paper",
        text=label,
        showarrow=False,
        font=dict(size=14),
    )
    fig.add_vline(
        x=(i + 1) * len(tick_group) - 0.5,
        line_width=1,
        line_dash="dash",
        line_color="black",
    )

fig.add_annotation(
    x=1.15,
    y=0.6,
    yref="paper",
    xref="paper",
    text="obs = observed<br>imp = imputed",
    showarrow=False,
    font=dict(size=14),
)

# titles + yaxis range
fig.update_layout(
    title="Accuracy per Assay and Task",
    xaxis_title="Assay - Task (training data \u2192 inference data)",
    yaxis_title="Accuracy",
    xaxis_tickangle=-45,
    showlegend=True,
    height=600,
    width=1200,
    yaxis=dict(tickformat=".2%", range=[minY, maxY]),
)

fig.add_hline(y=1, line_width=1, line_color="black")

# Show/Write the plot
print(f"Graphing {graph_type}")
figname = f"imputation_impact_{graph_type}_acc_per_assay_minY{minY:.2f}"
fig.write_html(this_fig_dir / f"{figname}.html")
fig.write_image(this_fig_dir / f"{figname}.png")
fig.write_image(this_fig_dir / f"{figname}.svg")
fig.show()

del fig

#### Graph accuracy boxplot

In [None]:
# Output folder for the assay-colored accuracy boxplots (must already exist).
this_fig_dir = section_folder.joinpath("acc_per_assay", "boxplot", "assay_colored")
if not this_fig_dir.exists():
    raise FileNotFoundError(f"Folder {this_fig_dir} does not exist")

In [None]:
# Graph to draw: "no_C-A", "all", "only_C-A" or "only_C-A+core7".
graph_type = "only_C-A+core7"

# graph_type -> (minY, maxY). Every type is listed so the y range never depends on
# values left over from another cell.
boxplot_y_ranges = {
    "no_C-A": (0.97, 1.001),
    "all": (0.7, 1.005),
    "only_C-A": (0.7, 1.005),
    "only_C-A+core7": (0.7, 1.005),
}
if graph_type not in boxplot_y_ranges:
    raise ValueError(f"Unknown graph_type: {graph_type}")
minY, maxY = boxplot_y_ranges[graph_type]

graph_df = df_acc_per_assay.copy()
graph_df = graph_df.sort_values(["inf_target", "scatter_name"])
if graph_type == "no_C-A":
    graph_df = graph_df[graph_df["inf_target"] != "C-A"]
elif "only_C-A" in graph_type:
    graph_df = graph_df[graph_df["inf_target"] == "C-A"]

# The core7 task only appears in the "only_C-A+core7" graph.
if graph_type != "only_C-A+core7":
    graph_df = graph_df[~graph_df["scatter_name"].str.contains("core7")]


# Prepare boxplot data
tick_group = graph_df["scatter_name"].unique()
scatter_name_to_position = {name: i for i, name in enumerate(tick_group)}

min_pred_values = ["0.0", "0.6", "0.9"]
offset = [-0.25, 0, 0.25]  # Offset for each min_pred within a tick group

# Seeded generator so the horizontal jitter of the points is reproducible.
jitter_rng = np.random.default_rng(42)

fig = go.Figure()
for name in tick_group:
    group = graph_df[graph_df["scatter_name"] == name]

    for i, min_pred in enumerate(min_pred_values):
        df_subset = group[group["min_predScore"] == min_pred]

        x_position = scatter_name_to_position[name] + offset[i]
        x_positions = [x_position] * len(df_subset)
        y_values = df_subset["acc"]
        hover_texts = [
            f"{row['assay']}<br>Samples: {row['nb_samples']}"
            for _, row in df_subset.iterrows()
        ]
        colors = [assay_colors[assay] for assay in df_subset["assay"]]

        # Add box plot without points
        fig.add_trace(
            go.Box(
                x=x_positions,
                y=y_values,
                name=f"{name} - Min Pred Score: {min_pred}",
                line=dict(
                    color=min_predScore_color_map[min_pred],
                ),
                boxpoints=False,
                boxmean=True,
                showlegend=False,
            )
        )
        # Add scatter plot for individual points (small jitter to reduce overlap)
        fig.add_trace(
            go.Scatter(
                x=[x + jitter_rng.uniform(-0.01, 0.01) for x in x_positions],
                y=y_values,
                mode="markers",
                marker=dict(color=colors, size=8, line=dict(color="Black", width=1)),
                name=f"{name} - Min Pred Score: {min_pred}",
                showlegend=False,
                text=hover_texts,
                hoverinfo="text+y",
            )
        )

# Update x-axis tick labels
ticktext = []
for tick in tick_group:
    train, inf = tick.split("_inf_")
    ticktext.append(f"<b>{train}</b> \u2192 <b>{inf}</b>")

fig.update_xaxes(tickmode="array", ticktext=ticktext, tickvals=list(range(len(ticktext))))

# Update layout
fig.update_layout(
    title="Accuracy per Task (6 core assays)",
    xaxis_title="Task (training data \u2192 inference data)",
    yaxis_title="Accuracy",
    showlegend=True,
    height=600,
    width=1000,
    yaxis=dict(tickformat=".2%", range=[minY, maxY]),
)

# Add a legend for minPred colors
for val, color in min_predScore_color_map.items():
    fig.add_trace(
        go.Scatter(
            x=[None],
            y=[None],
            mode="markers",
            marker=dict(size=10, color=color, symbol="square"),
            name=f"Min Pred Score: {val}",
            showlegend=True,
        )
    )

# Add a legend for assay colors
for assay in sorted(assay_labels):
    color = assay_colors[assay]
    fig.add_trace(
        go.Scatter(
            x=[None],
            y=[None],
            mode="markers",
            marker=dict(size=10, color=color),
            name=assay,
            legendgroup="assays",
            showlegend=True,
        )
    )

# Add legend for obs and imp
fig.add_annotation(
    x=1.2,
    y=0.3,
    yref="paper",
    xref="paper",
    text="obs = observed<br>imp = imputed",
    showarrow=False,
    font=dict(size=14),
)

# Show/Write the plot
figname = f"imputation_impact_{graph_type}_boxplot_minY{minY:.2f}"
fig.write_html(this_fig_dir / f"{figname}.html")
fig.write_image(this_fig_dir / f"{figname}.png")
fig.write_image(this_fig_dir / f"{figname}.svg")
fig.show()

del fig

#### Confidence threshold VS samples conserved

In [None]:
this_fig_dir = section_folder / "samples_conserved"
# A missing folder is a filesystem problem: raise FileNotFoundError like the other cells.
if not this_fig_dir.exists():
    raise FileNotFoundError(f"Directory {this_fig_dir} does not exist")

In [None]:
samples_left_dict = {}

# For each prediction set, over 200 minimum-prediction-score thresholds in [0, 1]:
# fraction of samples with "Max pred" >= threshold, and accuracy on those samples.
for name, df in zip(var_names, var_list):
    df[ASSAY] = df["True class"]

    nb_samples = len(df)
    x_vals = np.linspace(0, 1, 200)
    y_vals = np.zeros_like(x_vals)
    y_vals_acc = np.zeros_like(x_vals)
    for i, min_pred in enumerate(x_vals):
        sub_df = df[df["Max pred"] >= min_pred]

        # NaN instead of ZeroDivisionError for an empty prediction set.
        y_vals[i] = len(sub_df) / nb_samples if nb_samples else np.nan
        # Mean of an empty selection is NaN.
        y_vals_acc[i] = (sub_df["True class"] == sub_df["Predicted class"]).mean()

    samples_left_dict[name] = (x_vals, y_vals, y_vals_acc)

In [None]:
fixed_min_pred_score_metrics = {}

# Per prediction set and per assay, over 200 thresholds in [0, 1]:
# fraction of the assay's samples kept and accuracy on the kept samples.
for name, df in zip(var_names, var_list):
    fixed_min_pred_score_metrics[name] = {}
    for assay in assay_labels:
        assay_df = df[df["True class"] == assay]
        nb_samples = len(assay_df)

        x_vals = np.linspace(0, 1, 200)
        y_vals = np.zeros_like(x_vals)
        y_vals_acc = np.zeros_like(x_vals)
        for i, min_pred in enumerate(x_vals):
            # Always filter the full assay subset, not the previous iteration's result.
            sub_df = assay_df[assay_df["Max pred"] >= min_pred]

            # NaN instead of ZeroDivisionError when the assay has no samples.
            y_vals[i] = len(sub_df) / nb_samples if nb_samples else np.nan
            y_vals_acc[i] = (sub_df["True class"] == sub_df["Predicted class"]).mean()

        fixed_min_pred_score_metrics[name][assay] = (x_vals, y_vals, y_vals_acc)

In [None]:
fixed_sample_nb_metrics = {}

# Per prediction set and per assay, for 200 ratios of samples to keep in [0, 1]:
# keep the top-scoring samples and record their min "Max pred" and accuracy.
for name, df in zip(var_names, var_list):
    fixed_sample_nb_metrics[name] = {}
    for assay in assay_labels:
        sub_df = df[df["True class"] == assay]
        nb_samples = len(sub_df)

        # Sort once by prediction score, highest first (invariant across ratios).
        sorted_df = sub_df.sort_values("Max pred", ascending=False)

        # Define the fixed ratios of samples to keep
        fixed_ratios = np.linspace(0, 1, 200)

        min_pred_scores = []
        accuracies = []

        for ratio in fixed_ratios:
            # Keep at least one sample (when available) so the metrics are defined.
            samples_to_keep = max(int(nb_samples * ratio), 1)
            kept_df = sorted_df.head(samples_to_keep)

            # Min prediction score and accuracy of the kept samples
            min_pred_scores.append(kept_df["Max pred"].min())
            accuracies.append((kept_df["True class"] == kept_df["Predicted class"]).mean())

        fixed_sample_nb_metrics[name][assay] = (fixed_ratios, min_pred_scores, accuracies)

In [None]:
fig = go.Figure()

colors = px.colors.qualitative.Plotly

# Keep only the C-A inference sets. The color index stays the set's position in
# samples_left_dict, so colors are unchanged by the filtering.
ca_curves = [
    (i, name, curves)
    for i, (name, curves) in enumerate(samples_left_dict.items())
    if name.endswith("C-A")
]

for i, name, (x_vals, y_vals, y_vals_acc) in ca_curves:
    legend_name = (
        name.replace("_inf_C-A", "")
        .replace("observed", "train_observed")
        .replace("imputed", "train_imputed")
    )
    trace_color = colors[i]

    # Solid line: accuracy on the kept samples (left axis)
    fig.add_trace(
        go.Scatter(
            x=x_vals,
            y=y_vals_acc,
            name=f"{legend_name}: Acc",
            mode="lines",
            line=dict(color=trace_color),
            showlegend=True,
        )
    )

    # Dashed line: fraction of samples kept (right axis)
    fig.add_trace(
        go.Scatter(
            x=x_vals,
            y=y_vals,
            name=f"{legend_name}: % left",
            mode="lines",
            line=dict(color=trace_color, dash="dash"),
            yaxis="y2",
            showlegend=True,
        )
    )

fig.update_layout(
    title="Accuracy and samples conserved for C-A predictions.",
    xaxis_title="Minimum prediction score",
    yaxis=dict(title="Accuracy (%)", tickformat=".2%"),
    yaxis2=dict(title="Subset Size (%)", overlaying="y", side="right", tickformat=".2%"),
    legend=dict(orientation="v", x=1.08, y=1),
    showlegend=True,
    height=600,
    width=1000,
)
fig.update_yaxes(range=[0, 1])
fig.update_xaxes(range=[0, 1])

# Write each output format, then display
figname = "C-A_predScore_samples_left_w_core7"
fig.write_html(this_fig_dir / f"{figname}.html")
for extension in ("png", "svg"):
    fig.write_image(this_fig_dir / f"{figname}.{extension}")
fig.show()

del fig

In [None]:
fig = go.Figure()

for assay in assay_labels:
    color = assay_colors[assay]
    for name in var_names:
        if not name.endswith("C-A"):
            continue

        # Marker symbol encodes the training data: observed (circle) vs imputed (x).
        if "obs" in name:
            marker = "circle"
        elif "imp" in name:
            marker = "x"
        else:
            raise ValueError(f"Invalid name: {name}")

        x_vals, y_vals, y_vals_acc = fixed_min_pred_score_metrics[name][assay]

        label = (
            name.replace("_inf_C-A", "")
            .replace("observed", "train_observed")
            .replace("imputed", "train_imputed")
        )
        marker_style = dict(
            color=color, symbol=marker, line_color="black", line_width=1.5, size=8
        )

        # Accuracy: solid line + sparse markers (left axis)
        fig.add_trace(
            go.Scatter(
                x=x_vals,
                y=y_vals_acc,
                mode="lines",
                line=dict(color=color),
                showlegend=False,
                name=f"{assay},{label},Acc",
                legendgroup=f"{assay}",
            )
        )
        fig.add_trace(
            go.Scatter(
                x=x_vals[::10],
                y=y_vals_acc[::10],
                mode="markers",
                marker=marker_style,
                showlegend=False,
                name=f"{assay},{label},Acc",
                legendgroup=f"{assay}",
            )
        )

        # Fraction of samples kept: dashed line + sparse markers (right axis)
        fig.add_trace(
            go.Scatter(
                x=x_vals,
                y=y_vals,
                mode="lines",
                line=dict(color=color, dash="dash"),
                yaxis="y2",
                showlegend=False,
                name=f"{assay},{label},% samples",
                legendgroup=f"{assay}",
            )
        )
        fig.add_trace(
            go.Scatter(
                x=x_vals[::10],
                y=y_vals[::10],
                mode="markers",
                marker=marker_style,
                yaxis="y2",
                showlegend=False,
                name=f"{assay},{label},% samples",
                legendgroup=f"{assay}",
            )
        )

        # Legend-only entry; line color set so it matches the assay color.
        fig.add_trace(
            go.Scatter(
                x=[None],
                y=[None],
                mode="lines+markers",
                line=dict(color=color),
                marker=marker_style,
                name=f"{label} - {assay}",
                showlegend=True,
                legendgroup=f"{assay}",
            )
        )

# Update layout (once, after all traces are added)
fig.update_layout(
    title="Accuracy and samples conserved for C-A predictions.",
    xaxis_title="Minimum prediction score",
    yaxis=dict(title="Accuracy (%)", tickformat=".2%"),
    yaxis2=dict(title="Subset Size (%)", overlaying="y", side="right", tickformat=".2%"),
    legend=dict(orientation="v", x=1.15, y=1),
    height=600,
    width=1000,
    showlegend=True,
)

fig.update_yaxes(range=[0, 1])
fig.update_xaxes(range=[0, 1])

# fig.add_vline(x=0.603, line_dash="dash", line_color="black")
# fig.add_hline(y=87.96/100, line_dash="dash", line_color="black", yref="y2")

# fig.add_vline(x=0.7186, line_dash="dash", line_color="black")

# Show/Write the plot
figname = "C-A_predScore_samples_left_per_assay"
fig.write_html(this_fig_dir / f"{figname}.html")
fig.write_image(this_fig_dir / f"{figname}.png")
fig.write_image(this_fig_dir / f"{figname}.svg")
fig.show()

del fig

In [None]:
fig = go.Figure()

for assay in assay_labels:
    color = assay_colors[assay]
    for name in var_names:
        if not name.endswith("C-A"):
            continue

        # Marker symbol encodes the training data: observed (circle) vs imputed (x).
        if "obs" in name:
            marker = "circle"
        elif "imp" in name:
            marker = "x"
        else:
            raise ValueError(f"Invalid name: {name}")

        fixed_ratios, min_pred_scores, accuracies = fixed_sample_nb_metrics[name][assay]
        label = (
            name.replace("_inf_C-A", "")
            .replace("observed", "train_observed")
            .replace("imputed", "train_imputed")
        )
        marker_style = dict(
            color=color, symbol=marker, line_color="black", line_width=1.5, size=8
        )

        # Accuracy trace (left axis)
        fig.add_trace(
            go.Scatter(
                x=fixed_ratios,
                y=accuracies,
                mode="lines",
                line=dict(color=color),
                showlegend=False,
                name=f"{assay},{label},Acc",
                legendgroup=f"{assay}",
            )
        )
        fig.add_trace(
            go.Scatter(
                x=fixed_ratios[::10],
                y=accuracies[::10],
                mode="markers",
                marker=marker_style,
                showlegend=False,
                name=f"{assay},{label},Acc",
                legendgroup=f"{assay}",
            )
        )

        # Minimum prediction score trace (right axis)
        fig.add_trace(
            go.Scatter(
                x=fixed_ratios,
                y=min_pred_scores,
                mode="lines",
                line=dict(color=color, dash="dash"),
                yaxis="y2",
                showlegend=False,
                name=f"{assay},{label},Min pred score",
                legendgroup=f"{assay}",
            )
        )
        fig.add_trace(
            go.Scatter(
                x=fixed_ratios[::10],
                y=min_pred_scores[::10],
                mode="markers",
                marker=marker_style,
                yaxis="y2",
                showlegend=False,
                name=f"{assay},{label},Min pred score",
                legendgroup=f"{assay}",
            )
        )

        # Legend-only entry; line color set so it matches the assay color.
        fig.add_trace(
            go.Scatter(
                x=[None],
                y=[None],
                mode="lines+markers",
                line=dict(color=color),
                marker=marker_style,
                name=f"{label} - {assay}",
                showlegend=True,
                legendgroup=f"{assay}",
            )
        )

# Update layout
fig.update_layout(
    title="Accuracy and minimum prediction score for C-A predictions.",
    xaxis_title="Ratio of samples kept",
    yaxis=dict(title="Accuracy (%)", tickformat=".2%"),
    yaxis2=dict(
        title="Minimum prediction score", overlaying="y", side="right", tickformat=".3f"
    ),
    legend=dict(orientation="v", x=1.15, y=1),
    height=600,
    width=1000,
    showlegend=True,
)
fig.update_yaxes(range=[0, 1])
fig.update_xaxes(range=[0, 1])

# Show/Write the plot
# figname = "C-A_sampleRatio_minPredScore_per_assay"
# fig.write_html(this_fig_dir / f"{figname}.html")
# fig.write_image(this_fig_dir / f"{figname}.png")
# fig.write_image(this_fig_dir / f"{figname}.svg")
fig.show()
del fig

### NN - Accuracy per assay (boxplot 10fold)

In [None]:
def fig1_acc_per_assay(
    all_split_dfs: Dict[str, Dict[str, pd.DataFrame]], logdir: Path, name: str
) -> None:
    """
    Creates a Plotly figure with a boxplot for the accuracy of each assay over all splits.

    Accuracy for an assay in a split is the fraction of its samples whose
    "Predicted class" equals its "True class", after `merge_similar_assays`.

    Args:
        all_split_dfs (Dict[str, Dict[str, pd.DataFrame]]): {split_name: {classifier: results}}.
            Only the "NN" results of each split are used. They need "True class" and
            "Predicted class" columns.
        logdir (Path): The directory where the figure will be saved (SVG, PNG, HTML).
        name (str): The name of the figure.
    Returns:
        None: Saves and displays the plotly figure.
    """
    fig = go.Figure()

    # Compute accuracy per assay per split: {split_name: {assay: accuracy}}
    assay_accuracies: Dict[str, Dict[str, float]] = {}
    for split_name, classifier_dict in all_split_dfs.items():
        df = merge_similar_assays(classifier_dict["NN"])
        # Fraction of correct predictions per true class. (The max predicted-class
        # fraction is only equal to this when the correct class is the plurality.)
        is_correct = df["Predicted class"] == df["True class"]
        assay_accuracies[split_name] = is_correct.groupby(df["True class"]).mean().to_dict()

    # invert the dictionary: {assay: {split_name: accuracy}}
    assay_accuracies_inv: Dict[str, Dict[str, float]] = defaultdict(dict)
    for split_name, assay_acc in assay_accuracies.items():
        for assay, acc in assay_acc.items():
            assay_accuracies_inv[assay][split_name] = acc

    for label in ASSAY_ORDER:
        split_accs = assay_accuracies_inv[label]
        hovertext = [f"{split_name}: {acc:.4f}" for split_name, acc in split_accs.items()]
        fig.add_trace(
            go.Box(
                name=label,
                y=list(split_accs.values()),
                boxmean=True,
                boxpoints="all",
                fillcolor=assay_colors[label],
                line_color="black",
                marker=dict(size=3, color=assay_colors[label]),
                showlegend=True,
                hovertemplate="%{text}",
                text=hovertext,
            )
        )

    fig.update_yaxes(range=[0.985, 1.001])

    fig.update_layout(
        title_text="Neural network assay classification 10-fold accuracy",
        yaxis_title="Accuracy",
        xaxis_title="Assay experiment",
    )

    # Save figure
    fig.write_image(logdir / f"{name}.svg")
    fig.write_image(logdir / f"{name}.png")
    fig.write_html(logdir / f"{name}.html")

    fig.show()

In [None]:
# Check the output folder first so a missing folder fails before the (slow) load.
logdir = base_fig_dir / "fig1--acc_per_assay"
if not logdir.exists():
    raise FileNotFoundError(f"Directory {logdir} does not exist")
name = "fig1--acc_per_assay"

# Same results folder as the 100kb across-methods load at the top of the notebook.
all_split_dfs = split_results_handler.gather_split_results_across_methods(
    results_dir=base_data_dir / "training_results" / "dfreeze_v2" / "hg38_100kb_all_none",
    label_category=ASSAY,
    only_NN=True,
)

fig1_acc_per_assay(all_split_dfs, logdir=logdir, name=name)

In [None]:
# Check the output folder first so a missing folder fails before the (slow) load.
logdir = paper_dir / "figures" / "fig2_EpiAtlas_other" / "fig2--10kb"
if not logdir.exists():
    raise FileNotFoundError(f"Directory {logdir} does not exist")

# Results live under "training_results/dfreeze_v2", like the 100kb results.
all_split_dfs = split_results_handler.gather_split_results_across_methods(
    results_dir=base_data_dir / "training_results" / "dfreeze_v2" / "hg38_10kb_all_none",
    label_category=ASSAY,
    only_NN=True,
)

name = f"fig2--{ASSAY}_acc_per_assay_10kb"
fig1_acc_per_assay(all_split_dfs, logdir=logdir, name=name)