In [None]:
"""Figure core creation: Fig1

Formatting of the figures may not be identical to the paper, but they contain the same data points.
"""
# pylint: disable=import-error, redefined-outer-name, use-dict-literal, too-many-lines, too-many-branches

'Figure core creation: Fig1\n\nFormatting of the figures may not be identical to the paper, but they contain the same data points.\n'

In [2]:
%load_ext autoreload
%autoreload 2

## Setup

In [None]:
from __future__ import annotations

import copy
from pathlib import Path
from typing import Dict, List

import numpy as np
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from sklearn.metrics import confusion_matrix as sk_cm

from epi_ml.core.confusion_matrix import ConfusionMatrixWriter
from epi_ml.utils.notebooks.paper.paper_utilities import (
    ASSAY,
    ASSAY_MERGE_DICT,
    ASSAY_ORDER,
    CELL_TYPE,
    IHECColorMap,
    MetadataHandler,
    SplitResultsHandler,
    merge_similar_assays,
)

In [4]:
# TODO: Have a slimmed-down data directory that only contains the files necessary for figure creation, compress it, and use it here.

In [5]:
# Root of the paper outputs, resolved relative to the current user's home directory.
base_dir = Path.home() / "Projects/epiclass/output/paper"
paper_dir = base_dir
# Fail early: every later cell reads from or writes below this directory.
if not paper_dir.exists():
    raise FileNotFoundError(f"Directory {paper_dir} does not exist.")

# Input data and figure locations (their existence is not checked here).
base_data_dir = base_dir / "data"
base_fig_dir = base_dir / "figures"

In [6]:
# `IHECColorMap` is imported as a class and rebound here to an instance of it.
# Guarding the instantiation keeps this cell idempotent: without the guard, running
# the cell a second time would try to call the instance and fail.
if isinstance(IHECColorMap, type):
    IHECColorMap = IHECColorMap(base_fig_dir)
assay_colors = IHECColorMap.assay_color_map
cell_type_colors = IHECColorMap.cell_type_color_map

# Assay sub-types share the colour of the assay they get merged into.
assay_colors["mrna_seq"] = assay_colors["rna_seq"]
assay_colors["wgbs-pbat"] = assay_colors["wgbs"]
assay_colors["wgbs-standard"] = assay_colors["wgbs"]

In [7]:
# Shared helpers used by every figure cell below.
split_results_handler = SplitResultsHandler()

# Metadata is loaded once ("v2" version) and reused when joining results with metadata.
metadata_handler = MetadataHandler(paper_dir)
metadata_v2 = metadata_handler.load_metadata("v2")

## Supp Fig 1

### Supp Fig 1A,1B - All classifiers metrics on EpiAtlas - Assay and Sample Ontology - 100kb resolution

Fig 1A,1B: data points are included here.

In [8]:
def plot_split_metrics(
    split_metrics: Dict[str, Dict[str, Dict[str, float]]],
    label_category: str,
    logdir: Path | None = None,
    filename: str = "fig1_all_classifiers_metrics",
) -> None:
    """Render box plots of the metrics per classifier and split, one subplot per metric.

    Args:
        split_metrics: Metric scores indexed as
            {split_name: {classifier_name: {metric_name: value}}}.
        label_category: The label category for the classification task. Selects the
            y-axis ranges and is shown in the figure title.
        logdir: The directory to save the figure to. If None, the figure is only displayed.
        filename: Base name (without extension) of the saved svg/png/html files.

    Returns:
        None: Displays the figure and saves it to the logdir if provided.
    """
    metrics = ["Accuracy", "F1_macro", "AUC_micro", "AUC_macro"]

    # Classifiers are drawn in this fixed order. Only those present in the results
    # (as listed by the first split) are kept, so a partial result set does not fail.
    preferred_order = ["NN", "LR", "LGBM", "LinearSVC", "RF"]
    available = set(next(iter(split_metrics.values()), {}))
    classifier_names = [name for name in preferred_order if name in available]

    # Result files use "NN" as key; the figures label that classifier "MLP".
    display_names = {"NN": "MLP"}
    split_names = list(split_metrics)

    # Create subplots, one column for each metric
    fig = make_subplots(
        rows=1,
        cols=len(metrics),
        subplot_titles=metrics,
        horizontal_spacing=0.075,
    )

    for col, metric in enumerate(metrics, start=1):
        for classifier in classifier_names:
            values = [split_metrics[split][classifier][metric] for split in split_names]
            fig.add_trace(
                go.Box(
                    y=values,
                    name=display_names.get(classifier, classifier),
                    line=dict(color="black", width=1.5),
                    marker=dict(size=3, color="black"),
                    boxmean=True,
                    boxpoints="all",  # or "outliers" to show only outliers
                    pointpos=-1.4,
                    showlegend=False,
                    width=0.5,
                    hovertemplate="%{text}",
                    text=[
                        f"{split}: {value:.4f}"
                        for split, value in zip(split_names, values)
                    ],
                ),
                row=1,
                col=col,
            )

    fig.update_layout(
        title_text=f"{label_category} classification - Metric distribution for 10fold cross-validation",
        yaxis_title="Value",
        boxmode="group",
    )

    # Adjust y-axis: the first two subplots (Accuracy, F1) share one range,
    # the last two (AUC micro/macro) share another.
    if label_category == ASSAY:
        range_acc = [0.95, 1.001]
        range_AUC = [0.992, 1.0001]
    elif label_category == CELL_TYPE:
        range_acc = [0.81, 1]
        range_AUC = [0.96, 1]
    else:
        range_acc = [0.6, 1.001]
        range_AUC = [0.9, 1.0001]

    fig.update_layout(
        yaxis=dict(range=range_acc),
        yaxis2=dict(range=range_acc),
        yaxis3=dict(range=range_AUC),
        yaxis4=dict(range=range_AUC),
    )

    # Save figure
    if logdir:
        fig.write_image(logdir / f"{filename}.svg")
        fig.write_image(logdir / f"{filename}.png")
        fig.write_html(logdir / f"{filename}.html")

    fig.show()

In [9]:
# Training results at 100kb resolution; also reused by later cells.
data_dir_100kb = base_data_dir / "training_results" / "dfreeze_v2" / "hg38_100kb_all_none"
if not data_dir_100kb.exists():
    raise FileNotFoundError(f"Directory {data_dir_100kb} does not exist.")

# Set to True to merge similar assays before computing the assay metrics.
merge_assays = False

# One figure per label category.
for label_category in [ASSAY, CELL_TYPE]:
    all_split_dfs = split_results_handler.gather_split_results_across_methods(
        results_dir=data_dir_100kb, label_category=label_category
    )

    # Optional merging only applies to the assay task; each result frame is replaced in place.
    if merge_assays and label_category == ASSAY:
        for split_name, split_dfs in all_split_dfs.items():
            for classifier_type, df in split_dfs.items():
                all_split_dfs[split_name][classifier_type] = merge_similar_assays(df)

    split_metrics = split_results_handler.compute_split_metrics(all_split_dfs)

    # No logdir is given, so the figure is displayed but not written to disk.
    plot_split_metrics(
        split_metrics,
        label_category=label_category,
    )

**Going forward, all results are for MLP classifiers.**

### Supp Fig 1C - Metrics for zeroed blacklist values and winsorized files - 100kb resolution

In [10]:
def create_blklst_graphs(
    feature_set_metrics_dict: Dict[str, Dict[str, Dict[str, Dict[str, float]]]],
    logdir: Path | None = None,
) -> None:
    """Create boxplots (Accuracy and F1 macro) for blacklist-related feature sets.

    One figure is produced per task, with one box per feature set in each subplot.

    Args:
        feature_set_metrics_dict (Dict[str, Dict[str, Dict[str, Dict[str, float]]]]): The dictionary containing all metrics for all blklst related feature sets.
            format: {feature_set: {task_name: {split_name: metric_dict}}}
            An empty dictionary produces no figure.
        logdir (Path, Optional): The directory to save the figure to. If None, the figure is only displayed.

    Raises:
        KeyError: If a task of the first feature set is missing from another feature set,
            or if a metric is missing from a split.
    """
    if not feature_set_metrics_dict:
        return

    # Task names are taken from the first feature set and assumed to exist in all of them.
    task_names = list(next(iter(feature_set_metrics_dict.values())))

    traces_names_dict = {
        "hg38_100kb_all_none": "observed",
        "hg38_100kb_all_none_0blklst": "0blklst",
        "hg38_100kb_all_none_0blklst_winsorized": "0blklst_winsorized",
    }
    # Feature sets without a short display name fall back to their full name.
    trace_names = {
        feature_set_name: traces_names_dict.get(feature_set_name, feature_set_name)
        for feature_set_name in feature_set_metrics_dict
    }
    # Known names keep their fixed order on the x-axis; unknown ones come after.
    known_names = list(traces_names_dict.values())
    category_order = known_names + [
        name for name in trace_names.values() if name not in known_names
    ]

    # (metric key in the metrics dict, subplot title), in subplot column order.
    subplot_metrics = [("Accuracy", "Accuracy"), ("F1_macro", "F1-score (macro)")]

    for task_name in task_names:
        category_fig = make_subplots(
            rows=1,
            cols=len(subplot_metrics),
            shared_yaxes=False,
            subplot_titles=[title for _, title in subplot_metrics],
            x_title="Feature set",
            y_title="Metric value",
            horizontal_spacing=0.1,
        )
        for feature_set_name, tasks_dicts in feature_set_metrics_dict.items():
            task_dict = tasks_dicts[task_name]
            trace_name = trace_names[feature_set_name]

            for col, (metric, _) in enumerate(subplot_metrics, start=1):
                y_vals = [metrics_dict[metric] for metrics_dict in task_dict.values()]
                hovertext = [
                    f"{split}: {metrics_dict[metric]:.4f}"
                    for split, metrics_dict in task_dict.items()
                ]
                category_fig.add_trace(
                    go.Box(
                        y=y_vals,
                        name=trace_name,
                        boxmean=True,
                        boxpoints="all",
                        showlegend=False,
                        marker=dict(size=3, color="black"),
                        line=dict(width=1, color="black"),
                        hovertemplate="%{text}",
                        text=hovertext,
                    ),
                    row=1,
                    col=col,
                )

        category_fig.update_xaxes(
            categoryorder="array",
            categoryarray=category_order,
        )
        category_fig.update_yaxes(range=[0.85, 1.001])

        # Shorter task name for the title and the file names.
        display_task_name = task_name.replace("_1l_3000n-10fold", "")
        category_fig.update_layout(
            title=f"MLP performance<br>{display_task_name}",
        )

        width = 500
        height = width * 1.5
        category_fig.update_layout(
            autosize=False,
            width=width,
            height=height,
        )
        # Save figure
        if logdir:
            base_name = f"metrics_{display_task_name}"
            category_fig.write_html(logdir / f"{base_name}.html")
            category_fig.write_image(logdir / f"{base_name}.svg")
            category_fig.write_image(logdir / f"{base_name}.png")

        category_fig.show()

In [11]:
# Feature sets compared in Supp Fig 1C.
include_sets = [
    "hg38_100kb_all_none",
    "hg38_100kb_all_none_0blklst",
    "hg38_100kb_all_none_0blklst_winsorized",
]

# Parent folder holding the training results of those feature sets.
results_folder_blklst = base_data_dir / "training_results" / "2023-01-epiatlas-freeze"
if not results_folder_blklst.exists():
    raise FileNotFoundError(f"Folder '{results_folder_blklst}' not found")

In [12]:
# Select 10-fold oversampling runs
# expected result shape: {feature_set: {task_name: {split_name: metrics_dict}}}
# Gather the metrics of the selected feature sets for the assay and cell type tasks.
# NOTE(review): an earlier comment here said "Select 10-fold oversampling runs", but
# `oversampled_only=False` is passed below - confirm which one is intended.
# expected result shape: {feature_set: {task_name: {split_name: metrics_dict}}}
all_metrics: Dict[
    str, Dict[str, Dict[str, Dict[str, float]]]
] = split_results_handler.obtain_all_feature_set_data(
    return_type="metrics",
    parent_folder=results_folder_blklst,
    merge_assays=False,
    include_categories=[ASSAY, CELL_TYPE],
    include_sets=include_sets,
    oversampled_only=False,
    verbose=False,
)  # type: ignore

In [13]:
# Display only: no logdir is given, so nothing is written to disk.
create_blklst_graphs(all_metrics)

### Supp Fig 1D - Accuracy per assay + confusion matrix

In [14]:
def NN_performance_per_assay_across_categories(
    all_split_results: Dict[str, Dict[str, pd.DataFrame]],
    title_end: str = "",
    exclude_categories: List[str] | None = None,
    y_range: None | List[float] = None,
    logdir: Path | None = None,
    verbose: bool = False,
):
    """Create one box plot of per-assay accuracy for each task.

    Each box holds the accuracy values of one assay (one value per split).

    Args:
        all_split_results (Dict[str, Dict[str, pd.DataFrame]]): The split results,
            indexed as {task_name: {split_name: results_df}}. The input is deep-copied
            and never modified.
        title_end (str, optional): The title to append to the figure title.
        exclude_categories (List[str], optional): Tasks whose name contains any of these
            strings are skipped.
        y_range (None | List[float], optional): The y-axis range for the figure. If None,
            the range goes from the lowest plotted accuracy to 1.001.
        logdir (Path, optional): The directory to save the figure to. If None, the figure is only displayed.
        verbose (bool, optional): Whether to print progress information.
    """
    all_split_results = copy.deepcopy(all_split_results)

    # Exclude some categories
    task_names = list(all_split_results.keys())
    if exclude_categories is not None:
        task_names = [
            task
            for task in task_names
            if not any(category in task for category in exclude_categories)
        ]

    metadata_df = MetadataHandler(paper_dir).load_metadata_df("v2-encode")

    # One graph per metadata category
    for task_name in task_names:
        if verbose:
            print(f"Processing {task_name}")
        split_results = all_split_results[task_name]
        if ASSAY in task_name:
            for split_name in split_results:
                try:
                    split_results[split_name] = merge_similar_assays(
                        split_results[split_name]
                    )
                except ValueError as e:
                    # Merging is best-effort: stop trying for this task at the first failure.
                    print(f"Skipping {task_name} assay merging: {e}")
                    break

        assay_acc_df = split_results_handler.compute_acc_per_assay(
            split_results, metadata_df
        )

        fig = go.Figure()
        lowest_acc = None  # lowest finite accuracy among the plotted assays
        for assay in ASSAY_ORDER:
            try:
                assay_accuracies = assay_acc_df[assay]
            except KeyError:
                # Assay absent from the results of this task.
                continue

            assay_min = assay_accuracies.min()
            if pd.notna(assay_min) and (lowest_acc is None or assay_min < lowest_acc):
                lowest_acc = assay_min

            fig.add_trace(
                go.Box(
                    y=assay_accuracies.values,
                    name=assay,
                    boxmean=True,
                    boxpoints="all",
                    showlegend=True,
                    marker=dict(size=3, color="black"),
                    line=dict(width=1, color="black"),
                    fillcolor=assay_colors[assay],
                    hovertemplate="%{text}",
                    text=[
                        f"{split}: {value:.4f}"
                        for split, value in assay_accuracies.items()
                    ],
                )
            )

        # The lower bound must be a single number. It is computed from the plotted
        # values, because `assay_acc_df.min()` is not guaranteed to be a scalar.
        if y_range is not None:
            fig.update_yaxes(range=y_range)
        elif lowest_acc is not None:
            fig.update_yaxes(range=[float(lowest_acc), 1.001])

        title_text = f"NN classification - {task_name}"
        if title_end:
            title_text += f" - {title_end}"
        fig.update_layout(
            title_text=title_text,
            yaxis_title="Accuracy",
            xaxis_title="Assay",
            width=1000,
            height=700,
        )

        # Save figure
        if logdir:
            filename = "NN_assay_performance_" + task_name
            fig.write_image(logdir / f"{filename}.svg")
            fig.write_image(logdir / f"{filename}.png")
            fig.write_html(logdir / f"{filename}.html")

        fig.show()

In [15]:
def create_confusion_matrix(
    df: pd.DataFrame,
    logdir: Path,
    name: str,
    min_pred_score: float = 0,
    majority: bool = False,
) -> None:
    """Create a confusion matrix for the given DataFrame and save it to the logdir.

    The input DataFrame is not modified.

    Args:
        df (pd.DataFrame): The DataFrame containing the results. Needs the columns
            "True class" and "Predicted class", plus either a "Max pred" column or one
            score column per true class (used to compute "Max pred"). A "uuid" column is
            also needed when `majority` is True.
        logdir (Path): The directory path for saving the figures.
        name (str): The name for the saved figures. The number of samples kept
            is appended to it.
        min_pred_score (float): Only predictions with a score strictly above this
            value are kept.
        majority (bool): Whether to use majority vote (uuid-wise) for the predicted class.

    Raises:
        ValueError: If `majority` is True and a (uuid, true class, predicted class)
            group contains more than three predictions.
    """
    # Compute confusion matrix
    classes = sorted(df["True class"].unique())
    if "Max pred" not in df.columns:
        # `assign` returns a new frame, so the caller's DataFrame stays untouched.
        df = df.assign(**{"Max pred": df[classes].max(axis=1)})  # type: ignore
    filtered_df = df[df["Max pred"] > min_pred_score]

    if majority:
        # Majority vote for predicted class
        groupby_uuid = filtered_df.groupby(["uuid", "True class", "Predicted class"])[
            "Max pred"
        ].aggregate(["size", "mean"])

        if groupby_uuid["size"].max() > 3:
            raise ValueError("More than three predictions for the same uuid.")

        # Keep, for each (uuid, true class), the predicted class with the most votes.
        groupby_uuid = groupby_uuid.reset_index().sort_values(
            ["uuid", "True class", "size"], ascending=[True, True, False]
        )
        groupby_uuid = groupby_uuid.drop_duplicates(
            subset=["uuid", "True class"], keep="first"
        )
        filtered_df = groupby_uuid

    confusion_mat = sk_cm(
        filtered_df["True class"], filtered_df["Predicted class"], labels=classes
    )

    mat_writer = ConfusionMatrixWriter(labels=classes, confusion_matrix=confusion_mat)
    files = mat_writer.to_all_formats(logdir, name=f"{name}_n{len(filtered_df)}")
    print(f"Saved confusion matrix to {logdir}:")
    for file in files:
        print(Path(file).name)

In [16]:
# MLP ("NN") assay results of every split, concatenated and joined with the metadata.
assay_split_dfs = split_results_handler.gather_split_results_across_methods(
    results_dir=data_dir_100kb,
    label_category=ASSAY,
    only_NN=True,
)
concat_assay_df = split_results_handler.concatenate_split_results(assay_split_dfs)["NN"]

df_with_meta = metadata_handler.join_metadata(concat_assay_df, metadata_v2)  # type: ignore
if "Predicted class" not in df_with_meta.columns:
    raise ValueError("`Predicted class` not in DataFrame")

# Confusion matrix settings.
classifier_name, min_pred_score, majority = "MLP", 0, False
name = f"{classifier_name}_pred>{min_pred_score}"

# Per-uuid (majority vote) and per-file matrices go to separate sub-folders.
logdir = (
    base_fig_dir
    / "fig1_EpiAtlas_assay"
    / "fig1_supp_D-assay_c11_confusion_matrices"
    / ("per_uuid" if majority else "per_file")
)
logdir.mkdir(parents=True, exist_ok=True)

In [17]:
# Writes the confusion matrix files to `logdir` and prints their names.
create_confusion_matrix(
    df=df_with_meta,
    min_pred_score=min_pred_score,
    logdir=logdir,
    name=name,
    majority=majority,
)

Saved confusion matrix to /home/local/USHERBROOKE/rabj2301/Projects/epiclass/output/paper/figures/fig1_EpiAtlas_assay/fig1_supp_D-assay_c11_confusion_matrices/per_file:
MLP_pred>0_n20922.csv
MLP_pred>0_n20922_relative.csv
MLP_pred>0_n20922.png


In [18]:
# Only the assay task is plotted here: keep the "NN" entry of the inverted results dict.
results_per_task: Dict[str, Dict[str, pd.DataFrame]] = {ASSAY: split_results_handler.invert_metrics_dict(assay_split_dfs)["NN"]}  # type: ignore

In [19]:
# Display only: no logdir is given, so nothing is written to disk.
NN_performance_per_assay_across_categories(all_split_results=results_per_task)

### Supp Fig 1E,1F,1G - Distribution of average prediction scores per assay

- E: Assay training 10-fold validation
- F: Assay complete training, predictions on imputed data
- G: Sample ontology 10-fold validation

In [20]:
def plot_prediction_scores_distribution(
    results_df: pd.DataFrame,
    merge_assay_pairs: bool = True,
    logdir: Path | None = None,
    name: str = "prediction_score_distribution",
    group_by_column: str = "True class",
    min_y: float = 0.7,
    use_aggregate_vote: bool = True,
    title: str | None = None,
) -> None:
    """
    Creates a Plotly figure with one violin plot and associated scatter plots per group.

    Points are drawn next to each violin: small black points when the predicted class
    matches the true class, larger red points when it does not.

    Args:
        results_df (pd.DataFrame): DataFrame containing prediction results and metadata.
            Needs the columns "True class", "Predicted class", "Max pred" and
            `group_by_column`, plus "EpiRR" when `use_aggregate_vote` is True
            or "md5sum" when it is False.
        merge_assay_pairs (bool): Whether to merge similar assays (mrna/rna, wgbs-pbat/wgbs-standard).
            If merging fails with a ValueError, a message is printed and the data is used as is.
        logdir (Path | None): Directory to save figures. If None, only displays the figure
        name (str): Base name for saved files
        group_by_column (str): Column name to use for grouping traces
        min_y (float): Minimum y-axis value
        use_aggregate_vote (bool): If True, aggregate by EpiRR. If False, use individual predictions
        title (str | None): Additional title text to append
    """
    fig = go.Figure()

    jitter_amplitude = 0.05  # Scatter plot jittering
    scatter_offset = 0.2  # Scatter plots offset

    if merge_assay_pairs:
        try:
            results_df = merge_similar_assays(results_df)
        except ValueError as e:
            print(f"Skipping assay merging: {e}")

    # Group ordering
    if group_by_column == ASSAY and merge_assay_pairs:
        group_labels = ASSAY_ORDER
        # Only labels of ASSAY_ORDER are drawn. Report the others instead of dropping
        # them silently (e.g. unmerged assay names when the merge above was skipped).
        ignored_labels = set(results_df[group_by_column].dropna().unique()) - set(
            group_labels
        )
        if ignored_labels:
            print(
                "Warning: labels not in ASSAY_ORDER are not plotted: "
                f"{sorted(map(str, ignored_labels))}"
            )
    else:
        group_labels = sorted(set(results_df[group_by_column].unique()))
    group_index = {label: i for i, label in enumerate(group_labels)}

    # Colors for each group
    if group_by_column == ASSAY:
        colors = assay_colors
    else:
        grey = "rgba(237, 231, 225, 1)"
        colors = {label: grey for label in group_labels}

    for label in group_labels:
        sub_df = results_df[results_df[group_by_column] == label]

        if use_aggregate_vote:
            # Aggregate by EpiRR with majority voting: keep, for each EpiRR, the
            # (predicted class, true class) pair with the most predictions.
            groupby = sub_df.groupby(["EpiRR", "Predicted class", "True class"])[
                "Max pred"
            ].aggregate(["size", "mean"])
            groupby = groupby.reset_index().sort_values(
                ["EpiRR", "size"], ascending=[True, False]
            )
            groupby = groupby.drop_duplicates(subset="EpiRR", keep="first")
            assert groupby["EpiRR"].is_unique
            mean_pred = groupby["mean"]

            # Compare predicted class against true class for matches
            is_match = (groupby["Predicted class"] == groupby["True class"]).to_numpy()

            hover_text = [
                f"EpiRR: {epirr}, Expected: {expected}, Pred: {predicted}, "
                f"Mean pred: {mean:.3f}, n={size}"
                for epirr, expected, predicted, mean, size in zip(
                    groupby["EpiRR"],
                    groupby["True class"],
                    groupby["Predicted class"],
                    groupby["mean"],
                    groupby["size"],
                )
            ]
        else:
            # Use individual predictions
            mean_pred = sub_df["Max pred"]
            is_match = (sub_df["Predicted class"] == sub_df["True class"]).to_numpy()

            hover_text = [
                f"ID: {md5sum}, Expected: {expected}, Pred: {predicted}, "
                f"Pred: {score:.3f}"
                for md5sum, expected, predicted, score in zip(
                    sub_df["md5sum"],
                    sub_df["True class"],
                    sub_df["Predicted class"],
                    sub_df["Max pred"],
                )
            ]

        match_pred = mean_pred[is_match]
        mismatch_pred = mean_pred[~is_match]
        n_points = len(mean_pred)

        # Add violin plot
        fig.add_trace(
            go.Violin(
                x=[group_index[label]] * n_points,
                y=mean_pred,
                name=label,
                spanmode="hard",
                box_visible=True,
                meanline_visible=True,
                points=False,
                fillcolor=colors[label],
                line_color="white",
                line=dict(width=0.5),
                showlegend=False,
            )
        )

        # Add jittered scatter plots for matches and mismatches.
        # A local generator seeded with 42 yields the same jitter as seeding the global
        # NumPy RNG would, without resetting the global random state.
        rng = np.random.RandomState(42)
        jitter = rng.uniform(-jitter_amplitude, jitter_amplitude, n_points)
        x_positions = np.full(n_points, group_index[label]) + jitter - scatter_offset

        # Plot matches (black points)
        if len(match_pred) > 0:
            fig.add_trace(
                go.Scatter(
                    x=x_positions[is_match],
                    y=match_pred,
                    mode="markers",
                    name=f"Match {label}",
                    marker=dict(color="black", size=1),
                    hovertemplate="%{text}",
                    text=[text for text, m in zip(hover_text, is_match) if m],
                    showlegend=False,
                    legendgroup="match",
                )
            )

        # Plot mismatches (red points)
        if len(mismatch_pred) > 0:
            fig.add_trace(
                go.Scatter(
                    x=x_positions[~is_match],
                    y=mismatch_pred,
                    mode="markers",
                    name=f"Mismatch {label}",
                    marker=dict(color="red", size=3),
                    hovertemplate="%{text}",
                    text=[text for text, m in zip(hover_text, is_match) if not m],
                    showlegend=False,
                    legendgroup="mismatch",
                )
            )

    # Add legend entries: empty traces, with markers large enough to be readable.
    for legend_name, legend_color, legend_size in [
        ("Match", "black", 10),
        ("Mismatch", "red", 10),
    ]:
        fig.add_trace(
            go.Scatter(
                x=[None],
                y=[None],
                mode="markers",
                name=legend_name,
                marker=dict(color=legend_color, size=legend_size),
                showlegend=True,
                legendgroup=legend_name.lower(),
            )
        )

    # Update layout
    title_text = "Prediction Score Distribution"
    if use_aggregate_vote:
        title_text += " (EpiRR majority vote)"
    if title:
        title_text += f" - {title}"

    fig.update_layout(
        title_text=title_text,
        yaxis_title="Prediction Score"
        if not use_aggregate_vote
        else "Avg. Prediction Score (majority class)",
        xaxis_title=group_by_column,
        yaxis_range=[min_y, 1.001],
        xaxis=dict(
            tickvals=list(group_index.values()),
            ticktext=list(group_index.keys()),
        ),
        legend=dict(
            title_text="Legend",
            itemsizing="constant",
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1,
        ),
    )

    # Save figure if logdir is provided
    if logdir:
        filename = f"{name}_epirr" if use_aggregate_vote else name
        fig.write_html(logdir / f"{filename}.html")
        fig.write_image(logdir / f"{filename}.svg")
        fig.write_image(logdir / f"{filename}.png")

    fig.show()

#### Supp fig 1E: ASSAY - Prediction scores for 10-fold cross-validation

In [21]:
# 10-fold results of the assay training ("11c" sub-folder).
fig_e_data_dir = data_dir_100kb / f"{ASSAY}_1l_3000n" / "11c" / "10fold-oversampling"
if not fig_e_data_dir.exists():
    raise FileNotFoundError(f"Directory {fig_e_data_dir} does not exist.")

# Read all splits, concatenate them, add the "Max pred" column, then join the metadata.
dfs = split_results_handler.read_split_results(fig_e_data_dir)
concat_df: pd.DataFrame = split_results_handler.concatenate_split_results(dfs, depth=1)  # type: ignore
concat_df = split_results_handler.add_max_pred(concat_df)
concat_df_w_meta = metadata_handler.join_metadata(concat_df, metadata_v2)

In [22]:
# Supp Fig 1E. Display only: no logdir is given, so nothing is written to disk.
plot_prediction_scores_distribution(
    results_df=concat_df_w_meta,
    group_by_column=ASSAY,
    merge_assay_pairs=True,
    min_y=0.4,
    title="11 classes assay training - Prediction scores for 10-fold cross-validation",
)

#### Supp fig 1F: ASSAY - Classifier=split0 from previous training. Predicting on imputed data (all pval)

In [23]:
# Predictions of the split0 assay classifier on imputed data.
imputation_dir = base_data_dir / "training_results" / "imputation"
if not imputation_dir.exists():
    raise FileNotFoundError(f"Directory {imputation_dir} does not exist.")

fig_f_pred_dir = (
    imputation_dir
    / "hg38_100kb_all_none"
    / f"{ASSAY}_1l_3000n"
    / "chip-seq-only"
    / "10fold-oversampling"
    / "split0"
    / "predict_imputed"
)
if not fig_f_pred_dir.exists():
    raise FileNotFoundError(f"Directory {fig_f_pred_dir} does not exist.")

df_pred = pd.read_csv(
    fig_f_pred_dir / "split0_test_prediction_100kb_all_none_chip-seq_imputed.csv",
    index_col=0,
)
df_pred = split_results_handler.add_max_pred(df_pred)
# No metadata is joined here, so the columns the plotting function needs are created
# from the predictions themselves: the index is used as "EpiRR" and the true class as
# the assay label.
# NOTE(review): presumably each index value identifies one imputed file - confirm.
df_pred["EpiRR"] = df_pred.index
df_pred[ASSAY] = df_pred["True class"]

In [24]:
# Supp Fig 1F. Display only: no logdir is given, so nothing is written to disk.
plot_prediction_scores_distribution(
    results_df=df_pred,
    group_by_column=ASSAY,
    merge_assay_pairs=True,
    min_y=0.9,
    title="split0 assay classifier, predicting on imputed data",
)

Skipping assay merging: Wrong results dataframe, rna or wgbs columns missing.


#### Supp fig 1G: Sample Ontology - Prediction scores for 10-fold cross-validation

In [25]:
# 10-fold results of the sample ontology (cell type) training.
fig_g_data_dir = data_dir_100kb / f"{CELL_TYPE}_1l_3000n" / "10fold-oversampling"
if not fig_g_data_dir.exists():
    raise FileNotFoundError(f"Directory {fig_g_data_dir} does not exist.")

# Read all splits, concatenate them, add the "Max pred" column, then join the metadata.
dfs = split_results_handler.read_split_results(fig_g_data_dir)
concat_df: pd.DataFrame = split_results_handler.concatenate_split_results(dfs, depth=1)  # type: ignore
concat_df = split_results_handler.add_max_pred(concat_df)
concat_df_w_meta = metadata_handler.join_metadata(concat_df, metadata_v2)
# Merge similar assay labels. The result is assigned back explicitly because an
# in-place `replace` on a selected column is not guaranteed to update the parent
# DataFrame (it does not under pandas Copy-on-Write).
concat_df_w_meta[ASSAY] = concat_df_w_meta[ASSAY].replace(ASSAY_MERGE_DICT)

In [26]:
# Supp Fig 1G: cell type predictions, grouped by assay.
# Display only: no logdir is given, so nothing is written to disk.
plot_prediction_scores_distribution(
    results_df=concat_df_w_meta,
    group_by_column=ASSAY,
    min_y=0,
    title="Sample Ontology training - Prediction scores for 10-fold cross-validation",
)

Skipping assay merging: Wrong results dataframe, rna or wgbs columns missing.
