In [None]:
"""Workbook to create figures destined for the paper."""
# pylint: disable=import-error, redefined-outer-name, use-dict-literal

In [None]:
from __future__ import annotations

import itertools
import json
from collections import defaultdict
from pathlib import Path
from typing import Dict, List

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from sklearn.metrics import (
    accuracy_score,
    confusion_matrix as sk_cm,
    f1_score,
    roc_auc_score,
)

from epi_ml.core.confusion_matrix import ConfusionMatrixWriter
from epi_ml.core.metadata import Metadata

In [None]:
# Root of the paper outputs: input data and rendered figures live underneath it.
base_dir = Path.home().joinpath("Projects", "epiclass", "output", "paper")
base_data_dir = base_dir.joinpath("data")
base_fig_dir = base_dir.joinpath("figures")

In [None]:
# Global variables
ASSAY = "assay_epiclass"  # assay label category
CELL_TYPE = "harmonized_sample_ontology_intermediate"  # cell type label category
# Assay labels folded together when merging similar assays (RNA-seq and WGBS variants)
ASSAY_MERGE_DICT = {
    "mrna_seq": "rna_seq",
    **dict.fromkeys(("wgbs-pbat", "wgbs-standard"), "wgbs"),
}

### Figure colors management

In [None]:
color_map_path = base_fig_dir / "IHEC_IA_colors_jan22_2024.json"
# Parsed IHEC color definitions; the color-map builders index into it as a list.
ihec_color_map = json.loads(color_map_path.read_text(encoding="utf8"))

In [None]:
def create_assay_color_map(ihec_color_map: List[Dict]) -> Dict[str, str]:
    """Create an RGB color map for IHEC core assays.

    Reads the "histone" entry of the first element of the IHEC color map, whose
    values are "r,g,b" strings, and converts each one to a plotly "rgb(r,g,b)"
    string keyed by the lowercased assay name. Fixed colors are then added for
    "rna_seq" and "wgbs".

    Args:
        ihec_color_map: Parsed IHEC color JSON (list of category dicts).

    Returns:
        Dict[str, str]: {lowercase assay name: "rgb(r,g,b)"}. Only converted
        entries are returned; the raw "r,g,b" values are not kept.
    """
    histone_colors = ihec_color_map[0]["histone"][0]

    # Build a fresh dict so mixed-case source keys do not linger with their raw
    # "r,g,b" values next to the lowercased, plotly-ready entries.
    colors = {}
    for name, color in histone_colors.items():
        rgb = color.split(",")
        colors[name.lower()] = f"rgb({rgb[0]},{rgb[1]},{rgb[2]})"

    colors.update({"rna_seq": "rgb(0,204,150)", "wgbs": "rgb(171,99,250)"})
    return colors

In [None]:
def create_cell_type_color_map(ihec_color_map: List[Dict]) -> Dict[str, str]:
    """Build a plotly color map for IHEC cell types.

    The fourth element of the IHEC color map holds "r,g,b" strings keyed by cell
    type; each is rewritten as "rgb(r,g,b)". Key case is preserved.

    Args:
        ihec_color_map: Parsed IHEC color JSON (list of category dicts).

    Returns:
        Dict[str, str]: {cell type: "rgb(r,g,b)"}
    """
    raw_colors = ihec_color_map[3]["harmonized_sample_ontology_intermediate"][0]

    def to_rgb(triplet: str) -> str:
        parts = triplet.split(",")
        return f"rgb({parts[0]},{parts[1]},{parts[2]})"

    return {cell_type: to_rgb(triplet) for cell_type, triplet in raw_colors.items()}

In [None]:
# Plotly-ready "rgb(r,g,b)" color maps: one keyed by cell type, one by lowercase assay name.
cell_type_colors = create_cell_type_color_map(ihec_color_map)
assay_colors = create_assay_color_map(ihec_color_map)

In [None]:
def merge_similar_assays(df: pd.DataFrame) -> pd.DataFrame:
    """Merge the rna-seq and wgbs assay categories, including prediction scores.

    "mrna_seq" scores are added into "rna_seq", and "wgbs-standard" + "wgbs-pbat"
    scores are summed into "wgbs". The source score columns are then dropped, and
    the label columns ("True class", "Predicted class" and, if present, the ASSAY
    column) are relabelled with ASSAY_MERGE_DICT. "Max pred" is recomputed over
    the remaining true-class score columns when that column exists.

    Args:
        df: Prediction results with one score column per assay label.

    Returns:
        A new DataFrame; the input is not modified.

    Raises:
        ValueError: If the assay-specific score columns are missing.
    """
    df = df.copy(deep=True)
    try:
        df["rna_seq"] = df["rna_seq"] + df["mrna_seq"]
        df["wgbs"] = df["wgbs-standard"] + df["wgbs-pbat"]
    except KeyError as exc:
        raise ValueError(
            "Wrong results dataframe, label category is not assay specific."
        ) from exc
    df = df.drop(columns=["mrna_seq", "wgbs-standard", "wgbs-pbat"])

    # Assign the replaced Series back: a chained `df[col].replace(..., inplace=True)`
    # does not propagate to `df` under pandas Copy-on-Write.
    label_columns = ["True class", "Predicted class"]
    if ASSAY in df.columns:
        label_columns.append(ASSAY)
    for column in label_columns:
        df[column] = df[column].replace(ASSAY_MERGE_DICT)

    # Recompute Max pred if it exists
    if "Max pred" in df.columns:
        classes = df["True class"].unique()
        df["Max pred"] = df[classes].max(axis=1)
    return df

### Figure 1.A

Violin plot of the average prediction score distribution per assay.
One point per UUID; track types are averaged (2xRNA and 2xWGBS combined).
Points use 3 colors:
- black for pred same class
- red for pred different class/mislabel
- orange bad qual (IHEC flag, was removed in later stages)

Graph version with color saturation gradient using max_pred/input score ratio

Using [EpiClass_EA-21606_Assay11_100kb](https://drive.google.com/drive/folders/1SzyTFCVk2Cyw7NXW08sSYB_k49y-1KoJ) : EA_NN--full-10fold-validation_prediction_augmented-all

In [None]:
# for col in NN_results.columns:
#     print(col)

In [None]:
def fig1_a(
    NN_results: pd.DataFrame,
    logdir: Path,
    name: str,
    merge_assay_pairs: bool = True,
) -> None:
    """
    Creates a Plotly figure with violin plots and associated scatter plots for each class.
    Red scatter points, indicating a mismatch, appear on top and have a larger size.

    One point is drawn per EpiRR. Its files are reduced by majority vote on
    "Predicted class", and the point's value is the mean "Max pred" of the files
    in the winning class. A point is a match when that majority class equals
    the violin's assay label.

    Args:
        NN_results (pd.DataFrame): The DataFrame containing the neural network results.
            Needs "True class", "Predicted class", "EpiRR", "Max pred" and ASSAY columns.
        logdir (Path): The directory where the figure will be saved (created if missing).
        name (str): The name of the figure (file stem of the .svg/.png/.html outputs).
        merge_assay_pairs (bool): Whether to merge similar assays
            (mrna/rna, wgbs-pbat/wgbs-standard). Defaults to True.
    Returns:
        None: Displays the plotly figure.
    """
    fig = go.Figure()

    # Combine similar assays
    if merge_assay_pairs:
        NN_results = merge_similar_assays(NN_results)

    # Class ordering: one integer x position per sorted class label
    class_labels_sorted = sorted(NN_results["True class"].unique())
    class_index = {label: i for i, label in enumerate(class_labels_sorted)}

    scatter_offset = 0.05  # Scatter plot jittering (half-width)
    scatter_shift = -0.25  # Points are drawn to the left of their violin

    def hover_texts(frame: pd.DataFrame, precision: int) -> List[str]:
        """Hover labels for the per-EpiRR points contained in `frame`."""
        return [
            f"EpiRR: {epirr}, Pred class: {pred_class}, Mean pred: {mean:.{precision}f}"
            for epirr, pred_class, mean in zip(
                frame["EpiRR"], frame["Predicted class"], frame["mean"]
            )
        ]

    for label in class_labels_sorted:
        df = NN_results[NN_results[ASSAY] == label]

        # Majority vote per EpiRR: keep the predicted class with the most files,
        # together with the mean prediction score of those files.
        groupby_epirr = (
            df.groupby(["EpiRR", "Predicted class"])["Max pred"]
            .aggregate(["size", "mean"])
            .reset_index()
            .sort_values(["EpiRR", "size"], ascending=[True, False])
            .drop_duplicates(subset="EpiRR", keep="first")
        )
        assert groupby_epirr["EpiRR"].is_unique

        mean_pred = groupby_epirr["mean"]

        # Add violin plot with integer x positions
        fig.add_trace(
            go.Violin(
                x=[class_index[label]] * len(mean_pred),
                y=mean_pred,
                name=label,
                spanmode="hard",
                box_visible=True,
                meanline_visible=True,
                points=False,
                fillcolor=assay_colors[label],
                line_color="white",
                line=dict(width=0.8),
                showlegend=False,
            )
        )

        # Jitter the points, then split them by vote outcome with a single
        # boolean mask instead of repeated iterrows() scans.
        jittered_x_positions = (
            np.random.uniform(-scatter_offset, scatter_offset, size=len(mean_pred))
            + class_index[label]
            + scatter_shift
        )
        is_match = (groupby_epirr["Predicted class"] == label).to_numpy()
        matches = groupby_epirr[is_match]
        mismatches = groupby_epirr[~is_match]

        # Add scatter plots for matches in black
        fig.add_trace(
            go.Scatter(
                x=jittered_x_positions[is_match],
                y=matches["mean"],
                mode="markers",
                name=f"Match {label}",
                marker=dict(
                    color="black",
                    size=1,  # Standard size for matches
                ),
                hoverinfo="text",
                hovertext=hover_texts(matches, precision=2),
                showlegend=False,
            )
        )

        # Add scatter plots for mismatches in red, with larger size
        fig.add_trace(
            go.Scatter(
                x=jittered_x_positions[~is_match],
                y=mismatches["mean"],
                mode="markers",
                name=f"Mismatch {label}",
                marker=dict(
                    color="red",
                    size=3,  # Larger size for mismatches
                ),
                hoverinfo="text",
                hovertext=hover_texts(mismatches, precision=3),
                showlegend=False,
            )
        )

    # Update layout to improve visualization
    fig.update_layout(
        title_text="Prediction score distribution per assay class",
        yaxis_title="Average prediction score (majority class)",
        xaxis_title="Expected class label",
    )
    fig.update_yaxes(range=[0.25, 1.01])
    fig.update_xaxes(tickvals=list(class_index.values()), ticktext=class_labels_sorted)

    # Dummy scatter plots (no data) that only provide legend entries
    for legend_name, legend_color in (("Match", "black"), ("Mismatch", "red")):
        fig.add_trace(
            go.Scatter(
                x=[None],
                y=[None],
                mode="markers",
                name=legend_name,
                marker=dict(color=legend_color, size=10),
                showlegend=True,
                legendgroup=legend_name.lower(),
            )
        )

    # Update the layout to adjust the legend
    fig.update_layout(
        legend=dict(
            title_text="Legend",
            itemsizing="constant",
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1,
        )
    )

    # Save figure (the output folder may be a new per-figure subdirectory)
    logdir.mkdir(parents=True, exist_ok=True)
    fig.write_image(logdir / f"{name}.svg")
    fig.write_image(logdir / f"{name}.png")
    fig.write_html(logdir / f"{name}.html")

    fig.show()

In [None]:
# Neural network 10-fold validation predictions (augmented), indexed by md5sum
NN_results_path = base_data_dir.joinpath(
    "EpiClass_EA-21606_Assay11_100kb",
    "NN",
    "full-10fold-validation_prediction_augmented-all.csv",
)
NN_results = pd.read_csv(
    NN_results_path, header=0, index_col="md5sum", low_memory=False
)

# fig1_a(NN_results, logdir=base_fig_dir, name="fig1_a", merge_assay_pairs=True)

### Figure 1.supp.A

Violin plot (10 folds) of overall accuracy for each model (NN, LR, RF, LGBM, SVM).  
For each split, 4 box plots per model:
  - Acc
  - F1
  - AUROC (OvR, both micro/macro)


Source files:
~~~bash
cd ~/mounts/narval-mount/projects/rrg-jacquesp-ab/rabyj/epiclass-project/output/epiclass-logs/2023-01-epiatlas-freeze/  
find assay_epiclass* -type f -name *validation_prediction.csv -print0 | rsync -av --files-from=- --from0 ./ ~/Projects/epiclass/output/paper/data/EpiClass_EA-21606_Assay11_100kb/all_splits/
~~~

In [None]:
merged_results_path = (
    base_data_dir / "EpiClass_EA-21606_Assay11_100kb" / "all-predictions-merged.csv"
)
# Load the merged predictions file itself: the checks below compare it against
# NN_results, so it must not be read from NN_results_path.
merged_results = pd.read_csv(
    merged_results_path, header=0, index_col="md5sum", low_memory=False
)

#### Verifying that results are from metadata v1.0

In [None]:
# Verify that every EpiRR and md5sum of NN_results is also present in merged_results
shared_epirrs = set(merged_results["epirr_id"]).intersection(NN_results["epirr_id"])
assert len(shared_epirrs) == NN_results["epirr_id"].nunique()

nn_md5sums = set(NN_results.index)
assert len(nn_md5sums.intersection(merged_results.index)) == len(nn_md5sums)

In [None]:
# Metadata v1.0 sanity check: this EpiRR is expected exactly 3 times in the merged results.
assert (merged_results["epirr_id"] == "IHECRE00003355.2").sum() == 3

In [None]:
# merged_results was only needed for the metadata-version checks above;
# this figure uses the separate per-split result files instead.
del merged_results  # using separate split results for this figure

#### Figure

In [None]:
def gather_split_results(
    results_dir: Path, label_category: str, only_NN: bool = False
) -> Dict[str, Dict[str, pd.DataFrame]]:
    """Gather split results for each classifier.

    Args:
        results_dir: Root directory containing the per-classifier result folders.
        label_category: ASSAY or CELL_TYPE; selects the result folder layout.
        only_NN: If True, only the neural network ("NN") results are loaded.

    Returns:
        Dict[str, Dict[str, pd.DataFrame]]: {split_name:{classifier_name: results_df}}

    Raises:
        ValueError: If `label_category` is neither ASSAY nor CELL_TYPE.
        FileNotFoundError: If the other classifiers' result folder is missing.
        AssertionError: If a split does not have exactly 4 other classifier files,
            or if the classifiers' results do not cover exactly the same md5sums.
    """
    # Folder suffix depends on the label category (cell type results are dfreeze v2).
    if label_category == ASSAY:
        second_dir_end = ""
    elif label_category == CELL_TYPE:
        second_dir_end = "-dfreeze-v2"
    else:
        raise ValueError(
            f"Unsupported label category: {label_category}. "
            f"Expected {ASSAY} or {CELL_TYPE}."
        )

    # These roots do not depend on the split.
    NN_root = results_dir / f"{label_category}_1l_3000n" / f"10fold{second_dir_end}"
    other_csv_root = (
        results_dir / f"{label_category}" / f"predict-10fold{second_dir_end}"
    )

    all_split_dfs = {}
    for split in [f"split{i}" for i in range(10)]:
        # Locate the other classifiers' csv files, sorted so that classifier
        # order does not depend on filesystem iteration order.
        other_csv_paths = []
        if not only_NN:
            if not other_csv_root.exists():
                raise FileNotFoundError(f"Could not find {other_csv_root}")
            other_csv_paths = sorted(
                other_csv_root.glob(f"*/*_{split}_validation_prediction.csv")
            )
            if len(other_csv_paths) != 4:
                raise AssertionError(
                    f"Expected 4 other_csv_paths, got {len(other_csv_paths)}"
                )

        # Load the dataframes
        dfs = {
            "NN": pd.read_csv(
                NN_root / split / "validation_prediction.csv",
                header=0,
                index_col=0,
                low_memory=False,
            )
        }
        for path in other_csv_paths:
            # Classifier name is the file name prefix before the first "_"
            classifier = path.name.split("_", maxsplit=1)[0]
            dfs[classifier] = pd.read_csv(path, header=0, index_col=0, low_memory=False)

        # Verify that all dataframes have exactly the same md5sums
        # (a subset check would miss extra md5sums in non-NN results).
        base_md5s = set(dfs["NN"].index)
        if any(set(df.index) != base_md5s for df in dfs.values()):
            raise AssertionError("Not all dataframes have the same md5sums")

        all_split_dfs[split] = dfs

    return all_split_dfs

In [None]:
def compute_split_metrics(
    all_split_dfs: Dict[str, Dict[str, pd.DataFrame]]
) -> Dict[str, Dict[str, Dict[str, float]]]:
    """Compute desired metrics for each split and classifier.

    Metrics: accuracy, macro F1, and one-vs-rest ROC AUC (micro and macro average).

    Args:
        all_split_dfs: {split_name: {classifier_name: results_df}}, e.g. the output
            of gather_split_results. Every split present is processed, in order.

    Returns:
        {split_name: {classifier_name: {metric_name: value}}}
    """
    split_metrics = {}
    for split, dfs in all_split_dfs.items():
        # Compute metrics for the split
        metrics = {}
        for key, df in dfs.items():
            # NOTE(review): assumes columns are ["True class", "Predicted class",
            # <one score column per class>...] for every classifier output.
            classes_order = df.columns[2:]

            # One-hot encode true classes, aligned on the score columns
            onehot_true = (
                pd.get_dummies(df["True class"], dtype=int)
                .reindex(columns=classes_order, fill_value=0)
                .values
            )
            pred_probs = df[classes_order].values

            ravel_true = np.argmax(onehot_true, axis=1)
            ravel_pred = np.argmax(pred_probs, axis=1)

            metrics[key] = {
                "Accuracy": accuracy_score(ravel_true, ravel_pred),
                "F1_macro": f1_score(ravel_true, ravel_pred, average="macro"),
                "AUC_micro": roc_auc_score(
                    onehot_true, pred_probs, multi_class="ovr", average="micro"
                ),
                "AUC_macro": roc_auc_score(
                    onehot_true, pred_probs, multi_class="ovr", average="macro"
                ),
            }

        # Store once per split, after all classifiers have been scored
        split_metrics[split] = metrics

    return split_metrics

In [None]:
def plot_split_metrics(
    split_metrics: Dict[str, Dict[str, Dict[str, float]]],
    label_category: str,
    logdir: Path,
    name: str,
) -> None:
    """Render box plots of the metrics per classifier and split, one subplot per metric.

    Args:
        split_metrics: {split_name: {classifier_name: {metric_name: value}}}, e.g. the
            output of compute_split_metrics. Classifiers are taken from the first split.
        label_category: Label category, shown in the figure title.
        logdir: Directory where the .svg/.png/.html figures are written (created if missing).
        name: File stem of the saved figures.
    """
    metrics = ["Accuracy", "F1_macro", "AUC_micro", "AUC_macro"]
    classifiers = list(next(iter(split_metrics.values())).keys())

    # Create subplots, one column for each metric
    fig = make_subplots(rows=1, cols=len(metrics), subplot_titles=metrics)

    # Cycle through the qualitative palette so that having more classifiers
    # than palette colors does not raise an IndexError.
    palette = px.colors.qualitative.Plotly
    colors = {
        classifier: palette[i % len(palette)]
        for i, classifier in enumerate(classifiers)
    }

    for i, metric in enumerate(metrics):
        for classifier in classifiers:
            # One value per split for this classifier/metric pair
            values = [split_metrics[split][classifier][metric] for split in split_metrics]

            fig.add_trace(
                go.Box(
                    y=values,
                    name=classifier,
                    marker_color=colors[classifier],
                    line=dict(color="black", width=1),
                    marker=dict(size=2),
                    boxmean=True,
                    boxpoints="all",  # or "outliers" to show only outliers
                    pointpos=-1.4,
                    showlegend=False,
                    width=0.5,
                    hoverinfo="text",
                    hovertext=[
                        f"{split}: {value:.4f}"
                        for split, value in zip(split_metrics, values)
                    ],
                ),
                row=1,
                col=i + 1,
            )

    fig.update_layout(
        title_text=f"{label_category} classification - Metric distribution for 10fold cross-validation",
        yaxis_title="Value",
        boxmode="group",
    )

    # Save figure
    logdir.mkdir(parents=True, exist_ok=True)
    fig.write_image(logdir / f"{name}.svg")
    fig.write_image(logdir / f"{name}.png")
    fig.write_html(logdir / f"{name}.html")

    fig.show()

In [None]:
path_results_assay = base_data_dir / "EpiClass_EA-21606_Assay11_100kb" / "all_splits"
path_results_cell_type = base_data_dir / CELL_TYPE / "all_splits"

# One metrics figure per label category; cell type results come from dfreeze v2.
for label_category, path in (
    (ASSAY, path_results_assay),
    (CELL_TYPE, path_results_cell_type),
):
    dfreeze_version = "dfreeze-v2" if label_category == CELL_TYPE else "dfreeze-v1"
    all_split_dfs = gather_split_results(results_dir=path, label_category=label_category)
    split_metrics = compute_split_metrics(all_split_dfs)
    plot_split_metrics(
        split_metrics,
        label_category=label_category,
        logdir=base_fig_dir,
        name=f"{label_category}_10fold_metrics_all_classifiers_{dfreeze_version}",
    )

### Figure 1.supp.B

Per model, compute the score distribution per assay (1 violin per assay). No SVM. Agreeing predictions are shown in black, disagreeing ones in red.

In [None]:
def fig1_supp_B(
    df_dict: Dict[str, pd.DataFrame], logdir: Path, name: str, y_min: float = 0.9
) -> None:
    """
    Creates a Plotly figure with subplots for each assay, each containing violin plots for different classifiers
    and associated scatter plots for matches (in black) and mismatches (in red).

    LinearSVC and RF results are excluded. One point is drawn per EpiRR: majority
    vote on "Predicted class", valued at the mean "Max pred" of the winning class.

    Args:
        df_dict (Dict[str, pd.DataFrame]): Dictionary with the DataFrame containing the results for each classifier.
            The dictionary itself is not modified.
        logdir (Path): The directory path for saving the figures (created if missing).
        name (str): The name for the saved figures; files are saved as f"{name}_min{y_min}.<ext>".
        y_min (float): Lower bound of the shared y-axis range (upper bound is 1.01).

    Returns:
        None: Displays the plotly figure.
    """
    # Ignore LinearSVC and RandomForest for this figure. Build a filtered copy
    # rather than deleting keys, so the caller's dict is left untouched.
    excluded_classifiers = {"LinearSVC", "RF"}
    df_dict = {
        classifier: results
        for classifier, results in df_dict.items()
        if classifier not in excluded_classifiers
    }

    # Assuming all classifiers have the same assays for simplicity
    first_key = next(iter(df_dict))
    class_labels_sorted = sorted(df_dict[first_key]["True class"].unique())
    num_assays = len(class_labels_sorted)

    classifiers = list(df_dict.keys())
    classifier_index = {classifier: i for i, classifier in enumerate(classifiers)}
    num_classifiers = len(classifiers)

    scatter_offset = 0.05  # Scatter plot jittering (half-width)
    scatter_shift = -0.3  # Points are drawn to the left of their violin

    # Calculate the size of the grid
    grid_size = int(np.ceil(np.sqrt(num_assays)))
    rows, cols = grid_size, grid_size

    # Create subplots with a square grid
    fig = make_subplots(
        rows=rows,
        cols=cols,
        subplot_titles=class_labels_sorted,
        shared_yaxes="all",  # type: ignore
        horizontal_spacing=0.05,
        vertical_spacing=0.05,
        y_title="Average prediction score",
    )
    for idx, label in enumerate(class_labels_sorted):
        row, col = divmod(idx, grid_size)
        for classifier_name, classifier_df in df_dict.items():
            df = classifier_df[classifier_df["True class"] == label]

            # Majority vote per EpiRR, with the mean prediction score of the winning class
            groupby_epirr = (
                df.groupby(["EpiRR", "Predicted class"])["Max pred"]
                .aggregate(["size", "mean"])
                .reset_index()
                .sort_values(["EpiRR", "size"], ascending=[True, False])
                .drop_duplicates(subset="EpiRR", keep="first")
            )
            assert groupby_epirr["EpiRR"].is_unique

            mean_pred = groupby_epirr["mean"]
            classifier_pos = classifier_index[classifier_name]

            # Add violin plot with integer x positions
            fig.add_trace(
                go.Violin(
                    x=classifier_pos * np.ones(len(mean_pred)),
                    y=mean_pred,
                    name=label,
                    spanmode="hard",
                    box_visible=True,
                    meanline_visible=True,
                    points=False,
                    fillcolor="grey",
                    line_color="black",
                    line=dict(width=0.8),
                    showlegend=False,
                ),
                row=row + 1,  # Plotly rows are 1-indexed
                col=col + 1,
            )

            # Jitter the points, then split them by vote outcome with one mask
            jittered_x_positions = (
                np.random.uniform(-scatter_offset, scatter_offset, size=len(mean_pred))
                + classifier_pos
                + scatter_shift
            )
            is_match = (groupby_epirr["Predicted class"] == label).to_numpy()

            # Add scatter plots for matches in black
            fig.add_trace(
                go.Scatter(
                    x=jittered_x_positions[is_match],
                    y=mean_pred[is_match],
                    mode="markers",
                    marker=dict(color="black", size=1),
                    showlegend=False,
                    name=f"Match {classifier_name}",
                ),
                row=row + 1,  # Plotly rows are 1-indexed
                col=col + 1,
            )

            # Add scatter plots for mismatches in red
            fig.add_trace(
                go.Scatter(
                    x=jittered_x_positions[~is_match],
                    y=mean_pred[~is_match],
                    mode="markers",
                    marker=dict(color="red", size=3),
                    showlegend=False,
                    name=f"Mismatch {classifier_name}",
                ),
                row=row + 1,  # Plotly rows are 1-indexed
                col=col + 1,
            )

    # Dummy scatter plots (no data) that only provide legend entries
    for legend_name, legend_color in (("Match", "black"), ("Mismatch", "red")):
        fig.add_trace(
            go.Scatter(
                x=[None],
                y=[None],
                mode="markers",
                name=legend_name,
                marker=dict(color=legend_color, size=10),
                showlegend=True,
                legendgroup=legend_name.lower(),
            )
        )

    # Update the layout to adjust the legend
    fig.update_layout(
        legend=dict(
            title_text="Legend",
            itemsizing="constant",
            orientation="h",
            yanchor="bottom",
            y=1.025,
            xanchor="right",
            x=1,
        )
    )

    # Update layout to improve visualization, adjust if needed for better appearance with multiple classifiers
    fig.update_layout(
        title_text="Prediction score distribution per assay across classifiers",
        height=1500,  # Adjust the height as necessary
        width=1500,  # Adjust the width based on the number of assays
    )

    # Apply the range to every (shared) y-axis, not just one of the matched axes
    fig.update_yaxes(range=[y_min, 1.01])

    # One tick per classifier, at the 0-indexed x position used for its violin
    fig.update_xaxes(tickvals=list(range(num_classifiers)), ticktext=classifiers)

    # Save figure
    logdir.mkdir(parents=True, exist_ok=True)
    fig.write_image(logdir / f"{name}_min{y_min}.svg")
    fig.write_image(logdir / f"{name}_min{y_min}.png")
    fig.write_html(logdir / f"{name}_min{y_min}.html")

    fig.show()

In [None]:
def concatenate_split_results(
    split_dfs: Dict[str, Dict[str, pd.DataFrame]]
) -> Dict[str, pd.DataFrame]:
    """Stack the per-split results of each classifier into one DataFrame per classifier.

    Args:
        split_dfs (Dict[str, Dict[str, pd.DataFrame]]): {split_name:{classifier_name: results_df}}

    Returns:
        Dict[str, pd.DataFrame]: {classifier_name: concatenated_df}, classifiers in
        first-seen order, rows in split order.

    Raises:
        AssertionError: If a concatenated frame's first index entry is not a string
            (i.e. the md5sum index was lost).
    """
    frames_per_classifier: Dict[str, list] = {}
    for per_classifier in split_dfs.values():
        for classifier_name, frame in per_classifier.items():
            frames_per_classifier.setdefault(classifier_name, []).append(frame)

    concatenated_dfs = {
        classifier_name: pd.concat(frames, axis=0)
        for classifier_name, frames in frames_per_classifier.items()
    }

    # The md5sum index must survive the concatenation
    if any(not isinstance(frame.index[0], str) for frame in concatenated_dfs.values()):
        raise AssertionError("Index is not md5sum")

    return concatenated_dfs

In [None]:
def join_metadata(
    df: pd.DataFrame, metadata: Metadata, error_dir: Path | None = None
) -> pd.DataFrame:
    """Join the metadata to the results dataframe.

    Args:
        df: Results indexed by md5sum.
        metadata: Metadata whose `datasets` records each contain an "md5sum" field.
        error_dir: Directory where "join_missing_md5sums.csv" is written when some
            md5sums of `df` are missing from the metadata. Defaults to `base_data_dir`.

    Returns:
        `df` left-joined with the metadata columns (same rows, same length).

    Raises:
        AssertionError: If md5sums of `df` are missing from the metadata, or if the
            join changes the number of rows (e.g. duplicated md5sums in the metadata).
    """
    metadata_df = pd.DataFrame(metadata.datasets).set_index("md5sum")

    diff_set = set(df.index) - set(metadata_df.index)
    if diff_set:
        if error_dir is None:
            error_dir = base_data_dir
        # pandas refuses to build a DataFrame from an (unordered) set, so sort first;
        # this also makes the saved file deterministic.
        err_df = pd.DataFrame(sorted(diff_set, key=str), columns=["md5sum"])
        err_df.to_csv(error_dir / "join_missing_md5sums.csv", index=False)
        raise AssertionError(f"{len(diff_set)} md5sums in the results dataframe are not present in the metadata dataframe. Saved error md5sums to join_missing_md5sums.csv.")

    merged_df = df.merge(metadata_df, how="left", left_index=True, right_index=True)
    if len(merged_df) != len(df):
        raise AssertionError("Merged dataframe has different length than original dataframe")
    return merged_df

In [None]:
# all_split_dfs = gather_split_results(path_results_assay, ASSAY)
# full_dfs = concatenate_split_results(all_split_dfs)
# merged_dfs = {classifier: merge_similar_assays(df) for classifier, df in full_dfs.items()}
# assays = merged_dfs[next(iter(merged_dfs))]["True class"].unique()

# # Add Max pred
# for classifier, df in merged_dfs.items():
#     df["Max pred"] = df[assays].max(axis=1)

# # Join metadata
# metadata_path = (
#     base_data_dir / "metadata" / "hg38_2023-epiatlas_dfreeze_formatted_JR.json"
# )
# metadata_dfreeze1 = Metadata(metadata_path)
# metadata_dfreeze1_df = pd.DataFrame(metadata_dfreeze1.datasets)

# for classifier, df in merged_dfs.items():
#     merged_dfs[classifier] = df.merge(
#         metadata_dfreeze1_df, how="left", left_index=True, right_on="md5sum"
#     )

In [None]:
# fig1_supp_B(merged_dfs, logdir=base_fig_dir, name="fig1_supp_B")

### Figure 1.supp.F

For each classifier type

Confusion matrix (1 point = 1 UUID) for observed datasets with an average score > 0.9.
- Goal: represent global predictions/mislabels (11 assay classes).

In [None]:
def create_confusion_matrix(
    df: pd.DataFrame,
    min_pred_score: float,
    logdir: Path,
    name: str,
    majority: bool = False,
) -> None:
    """Create a confusion matrix for the given DataFrame.

    Without `majority`, each file (row) with "Max pred" > `min_pred_score` is one
    entry. With `majority`, files are grouped per ("uuid", "True class"), the most
    frequent "Predicted class" wins, and the UUID is kept only if the mean
    "Max pred" of the files voting for that class is > `min_pred_score`.

    Args:
        df (pd.DataFrame): The DataFrame containing the results. Needs "True class"
            and "Predicted class" (plus "uuid" when `majority`), and either a
            "Max pred" column or one score column per true class. Not modified.
        min_pred_score (float): The minimum prediction score to consider (exclusive).
        logdir (Path): The directory path for saving the figures.
        name (str): The name for the saved figures; "_n{entry count}" is appended.
        majority (bool): Whether to use majority vote (uuid-wise) for the predicted class.
    """
    classes = sorted(df["True class"].unique())
    if "Max pred" not in df.columns:
        # assign() returns a new frame, so the caller's DataFrame is not modified.
        df = df.assign(**{"Max pred": df[classes].max(axis=1)})

    if majority:
        # Majority vote for predicted class (uuid-wise), keeping the mean score
        # of the winning class.
        groupby_uuid = (
            df.groupby(["uuid", "True class", "Predicted class"])["Max pred"]
            .aggregate(["size", "mean"])
            .reset_index()
            .sort_values(["uuid", "True class", "size"], ascending=[True, True, False])
            .drop_duplicates(subset=["uuid", "True class"], keep="first")
        )
        # The score threshold must also apply here, on the average score per UUID.
        filtered_df = groupby_uuid[groupby_uuid["mean"] > min_pred_score]
    else:
        filtered_df = df[df["Max pred"] > min_pred_score]

    # Compute confusion matrix
    confusion_mat = sk_cm(
        filtered_df["True class"], filtered_df["Predicted class"], labels=classes
    )

    mat_writer = ConfusionMatrixWriter(labels=classes, confusion_matrix=confusion_mat)
    mat_writer.to_all_formats(logdir, name=f"{name}_n{len(filtered_df)}")

In [None]:
# min_pred_score = 0.9
# majority = True

# for classifier_name, df in full_dfs.items():
#     df_with_meta = df.merge(
#         metadata_dfreeze1_df, how="left", left_index=True, right_on="md5sum"
#     )
#     assert "Predicted class" in df_with_meta.columns

#     name = f"{classifier_name}_pred>{min_pred_score}"
#     if classifier_name == "LinearSVC":
#         name = f"{classifier_name}"

#     logdir = base_fig_dir / "fig1_supp_F-assay_c11_confusion_matrices"
#     if majority:
#         logdir = logdir / "per_uuid"
#     else:
#         logdir = logdir / "per_file"

    # create_confusion_matrix(
    #     df=df_with_meta,
    #     min_pred_score=min_pred_score,
    #     logdir=logdir,
    #     name=name,
    #     majority=majority
    #     )

### Figure 1.supp.D

Inference on imputed data: violin plot of prediction scores per assay (like Fig. 1A).

In [None]:
fig_dir = base_fig_dir / "fig1_supp_D"
this_data_dir = base_data_dir / "imputation"


def find_prediction_csv(model_dir: Path) -> Path:
    """Return the prediction CSV found (recursively) under `model_dir`.

    The first path in sorted order is returned, so the choice does not depend on
    filesystem iteration order when several CSV files are present.

    Raises:
        FileNotFoundError: If `model_dir` contains no CSV file.
    """
    csv_paths = sorted(model_dir.rglob("*.csv"))
    if not csv_paths:
        raise FileNotFoundError(f"No prediction csv found under {model_dir}")
    return csv_paths[0]


# Load data: model trained on normal data, inference on imputed data...
normal_inf_imputed_path = find_prediction_csv(
    this_data_dir / "hg38_100kb_all_none" / "assay_epiclass_1l_3000n"
)
normal_inf_imputed_df = pd.read_csv(
    normal_inf_imputed_path, header=0, index_col=0, low_memory=False
)

# ...and model trained on imputed data, inference on normal data.
imputed_inf_normal_path = find_prediction_csv(
    this_data_dir / "hg38_100kb_all_none_imputed" / "assay_epiclass_1l_3000n"
)
imputed_inf_normal_df = pd.read_csv(
    imputed_inf_normal_path, header=0, index_col=0, low_memory=False
)

assay_labels = normal_inf_imputed_df["True class"].unique()
for name, df in zip(
    ["train_normal_inf_imputed", "train_imputed_inf_normal"],
    [normal_inf_imputed_df, imputed_inf_normal_df],
):
    # Add the columns fig1_a expects; using the index as "EpiRR" gives one point per row.
    df["EpiRR"] = list(df.index)
    df[ASSAY] = df["True class"]
    df["Max pred"] = df[assay_labels].max(axis=1)
    # fig1_a(
    #     df, logdir=fig_dir, name=f"fig1_supp_D-{name}_n{len(df)}", merge_assay_pairs=False
    # )

### Flagship paper figure

Cell type classifier:  

  for each assay, a violin plot of the per-cell-type accuracy (16 points, one per cell type).

In [None]:
def fig_flagship_ct(
    cell_type_df: pd.DataFrame,
    logdir: Path,
    name: str,
    expected_nb_classes: int | None = 16,
    seed: int | None = 42,
) -> None:
    """
    Plot, for each assay, the distribution of the cell type classifier's per-class accuracy.

    The figure has one subplot per assay, arranged in a near-square grid. Each subplot shows:
    - a violin of the per-class accuracies, with one value per true cell type present
      for that assay;
    - jittered points overlaid on the violin, colored by cell type.

    A legend maps colors to cell types. The figure is written as SVG, PNG and HTML,
    then displayed.

    Args:
        cell_type_df (pd.DataFrame): Cell type prediction results. Must contain the
            ASSAY column, "True class" and "Predicted class".
        logdir (Path): Directory where the figure files are written (created if missing).
        name (str): Base file name (without extension) of the saved figure.
        expected_nb_classes (int | None): Number of distinct true cell types expected
            in the whole DataFrame. None disables the check.
        seed (int | None): Seed of the scatter jitter, for reproducible figures.
            None gives a different jitter on each call.

    Returns:
        None: Saves and displays the plotly figure.

    Raises:
        AssertionError: If the number of cell type labels differs from expected_nb_classes.
        ValueError: If `cell_type_df` contains no assay.
    """
    assay_labels = sorted(cell_type_df[ASSAY].unique())
    num_assays = len(assay_labels)
    if num_assays == 0:
        raise ValueError("cell_type_df contains no assay to plot.")

    ct_labels = sorted(cell_type_df["True class"].unique())
    if expected_nb_classes is not None and len(ct_labels) != expected_nb_classes:
        raise AssertionError(
            f"Expected {expected_nb_classes} cell type labels, got {len(ct_labels)}"
        )

    rng = np.random.default_rng(seed)
    scatter_offset = 0.1  # Half-width of the horizontal scatter jitter

    # Near-square grid: ceil(sqrt(n)) columns and only as many rows as needed.
    cols = int(np.ceil(np.sqrt(num_assays)))
    rows = int(np.ceil(num_assays / cols))

    fig = make_subplots(
        rows=rows,
        cols=cols,
        subplot_titles=assay_labels,
        shared_yaxes="all",  # type: ignore
        horizontal_spacing=0,
        vertical_spacing=0.02,
        y_title="Cell type subclass accuracy",
    )
    # x positions are only used for layout; hide their tick labels on every subplot.
    fig.update_xaxes(showticklabels=False)

    for idx, assay_label in enumerate(assay_labels):
        row, col = divmod(idx, cols)
        assay_df = cell_type_df[cell_type_df[ASSAY] == assay_label]

        # Per true class: fraction of correctly predicted files ("mean") and class size.
        # A class with no correct prediction gets 0.0 instead of raising KeyError.
        is_correct = assay_df["True class"] == assay_df["Predicted class"]
        class_stats = is_correct.groupby(assay_df["True class"], observed=True).agg(
            ["mean", "size"]
        )
        # Keep the global sorted order; cell types absent for this assay are skipped.
        present_labels = [ct for ct in ct_labels if ct in class_stats.index]
        acc_values = class_stats.loc[present_labels, "mean"].tolist()
        sizes = class_stats.loc[present_labels, "size"].tolist()

        # Violin of the per-class accuracies, placed at an integer x position.
        fig.add_trace(
            go.Violin(
                x=[idx] * len(acc_values),
                y=acc_values,
                name=assay_label,
                spanmode="hard",
                box_visible=True,
                meanline_visible=True,
                points=False,
                fillcolor=assay_colors[assay_label],
                line_color="white",
                line=dict(width=0.8),
                showlegend=False,
            ),
            row=row + 1,  # Plotly rows are 1-indexed
            col=col + 1,
        )

        # Jittered points, slightly left of the violin, one per cell type, colored by cell type.
        jittered_x_positions = (
            rng.uniform(-scatter_offset, scatter_offset, size=len(acc_values)) + idx - 0.4
        )
        fig.add_trace(
            go.Scatter(
                x=jittered_x_positions,
                y=acc_values,
                mode="markers",
                marker=dict(
                    size=3,
                    color=[cell_type_colors[ct_label] for ct_label in present_labels],
                ),
                hoverinfo="text",
                hovertext=[
                    f"{ct_label} ({acc:.3f}, n={size})"
                    for ct_label, acc, size in zip(present_labels, acc_values, sizes)
                ],
                showlegend=False,
            ),
            row=row + 1,  # Plotly rows are 1-indexed
            col=col + 1,
        )

    # Invisible traces, one per cell type, so the legend maps colors to cell types.
    for ct_label in ct_labels:
        fig.add_trace(
            go.Scatter(
                x=[None],
                y=[None],
                mode="markers",
                name=ct_label,
                marker=dict(color=cell_type_colors[ct_label], size=3),
                showlegend=True,
            )
        )

    fig.update_layout(
        title="cell type classifier: accuracy per output class for each assay",
        height=1500,
        width=1500,
    )

    # Save figure
    logdir.mkdir(parents=True, exist_ok=True)
    fig.write_image(logdir / f"{name}.svg")
    fig.write_image(logdir / f"{name}.png")
    fig.write_html(logdir / f"{name}.html")

    fig.show()

In [None]:
# Gather the per-split cell type classification results, concatenate them, and keep
# the "NN" entry (presumably results keyed by classifier name — verify in
# concatenate_split_results).
ct_split_dfs = gather_split_results(path_results_cell_type, CELL_TYPE)
ct_full_df = concatenate_split_results(ct_split_dfs)["NN"]

In [None]:
# Metadata used to annotate the cell type predictions (see join_metadata below).
# The commented-out alternative ("hg38_2023-epiatlas_dfreeze_formatted_JR.json") and
# the DataFrame/CSV export experiments were removed as dead code.
metadata_path = (
    base_data_dir / "metadata" / "hg38_2023-epiatlas-dfreeze-pospurge-nodup_filterCtl.json"
)
metadata_2 = Metadata(metadata_path)

In [None]:
# Annotate the predictions with their metadata, then merge assay variants
# according to ASSAY_MERGE_DICT (e.g. "mrna_seq" -> "rna_seq").
ct_full_df = join_metadata(ct_full_df, metadata_2)
# Assign the result back: `ct_full_df[ASSAY].replace(..., inplace=True)` is chained
# assignment and does not modify ct_full_df under pandas Copy-on-Write.
ct_full_df[ASSAY] = ct_full_df[ASSAY].replace(ASSAY_MERGE_DICT)

In [None]:
# NOTE(review): leftover display of `fig_dir`. At this point it still holds the
# fig1_supp_D directory assigned earlier; the flagship directory is assigned in the next cell.
fig_dir

In [None]:
# Flagship figure: per-assay distribution of the cell type classifier's per-class accuracy.
fig_dir = base_fig_dir / "flagship"
fig_flagship_ct(cell_type_df=ct_full_df, logdir=fig_dir, name="ct_assay_accuracy")



10) Cell type classification: check (input, cell type) pairs for enrichment in any metadata category
    - e.g. whether all (input, myeloid) files are cancer samples, or all come from a single data_generating_center
    - use biomaterial_type, sex, cancer status, strandedness-related fields, data_generating_center, and the other categories we have tested

In [None]:
def calculate_metadata_distribution(df: pd.DataFrame, columns: List[str]) -> Dict[str, pd.Series]:
    """
    Compute, for each requested column, the share (in percent) of every distinct label.

    Missing values are counted as a label of their own, and the denominator is the
    total number of rows (treated as samples). Each resulting Series therefore sums
    to 100, unless `df` is empty.

    Args:
        df: DataFrame to summarize.
        columns: Names of the columns to analyze.

    Returns:
        A mapping from column name to a Series indexed by label, holding the
        percentage of rows carrying that label, ordered by decreasing frequency.
    """
    total_rows = len(df)
    return {
        column: df[column].value_counts(dropna=False).div(total_rows).mul(100)
        for column in columns
    }

In [None]:
def compare_label_ratios(target_distribution: Dict[str, pd.Series],
                         comparison_distributions: List[Dict[str, pd.Series]],
                         labels: List[str]) -> Dict[str, pd.DataFrame]:
    """
    Compare label ratios of a target distribution against multiple comparison distributions.

    For each metadata category, compute the difference in percentage points for
    each label between the target distribution and each comparison distribution.

    Every label seen in the target or in any comparison gets a row. A label missing
    from one side counts as 0% on that side, and a category missing from a
    comparison distribution counts as 0% for all its labels.

    Args:
        target_distribution: A dictionary of Series representing the target distribution for comparison.
        comparison_distributions: A list of dictionaries of Series, where each dictionary
                                  represents a distribution (e.g., assay, cell type, global) for comparison.
        labels: A list of labels corresponding to each distribution in `comparison_distributions`,
                used to name the result columns ("Difference_vs_<label>").

    Returns:
        A dictionary of DataFrames, keyed by category. Each DataFrame holds the
        difference in percentage points (target - comparison) for every label of that
        category, with one column per comparison distribution.

    Raises:
        ValueError: If `labels` and `comparison_distributions` differ in length.
    """
    if len(labels) != len(comparison_distributions):
        raise ValueError(
            f"Got {len(labels)} labels for {len(comparison_distributions)} comparison distributions."
        )

    comparison_results = {}
    for category, target_series in target_distribution.items():
        comparison_series_list = [
            distribution.get(category, pd.Series(dtype="float64"))
            for distribution in comparison_distributions
        ]

        # Union of every label seen in the target or in any comparison. Aligning each
        # comparison separately and assigning into a DataFrame indexed by the first
        # alignment silently dropped labels that only appeared in later comparisons.
        all_labels = target_series.index
        for comparison_series in comparison_series_list:
            all_labels = all_labels.union(comparison_series.index)

        aligned_target = target_series.reindex(all_labels, fill_value=0)
        comparison_df = pd.DataFrame(index=all_labels)
        for label, comparison_series in zip(labels, comparison_series_list):
            aligned_comparison = comparison_series.reindex(all_labels, fill_value=0)
            # Difference in percentage points
            comparison_df[f"Difference_vs_{label}"] = aligned_target - aligned_comparison

        comparison_results[category] = comparison_df

    return comparison_results

In [None]:
# Metadata columns whose label distributions are compared below.
metadata_categories = [
    "harmonized_donor_life_stage",
    "harmonized_donor_sex",
    "harmonized_sample_disease_high",
    "harmonized_biomaterial_type",
    "paired_end",
    "project",
]

In [None]:
# Percentage of each metadata label: globally, per assay, per cell type and per
# (assay, cell type) pair. Each pair is then compared against its two marginals and
# the global distribution, and everything is exported to one CSV.
global_dist = calculate_metadata_distribution(ct_full_df, metadata_categories)

# Marginals are kept in separate dicts so that an assay and a cell type sharing a
# label cannot overwrite each other's distribution.
assay_distributions = {
    assay: calculate_metadata_distribution(sub_df, metadata_categories)
    for assay, sub_df in ct_full_df.groupby(ASSAY)
}
cell_type_distributions = {
    cell_type: calculate_metadata_distribution(sub_df, metadata_categories)
    for cell_type, sub_df in ct_full_df.groupby(CELL_TYPE)
}
# Combined lookup (assay, cell type and (assay, cell type) keys), as before.
subclass_distributions = {**assay_distributions, **cell_type_distributions}
comparison_results = {}  # (assay, cell type) -> {category: comparison DataFrame}

for (assay, cell_type), sub_df in ct_full_df.groupby([ASSAY, CELL_TYPE]):
    pair_subclass_dist = calculate_metadata_distribution(sub_df, metadata_categories)
    subclass_distributions[(assay, cell_type)] = pair_subclass_dist

    # Fixed comparison labels: the resulting column names stay distinct even when
    # the assay and cell type labels are identical.
    comparison_results[(assay, cell_type)] = compare_label_ratios(
        target_distribution=pair_subclass_dist,
        comparison_distributions=[
            assay_distributions[assay],
            cell_type_distributions[cell_type],
            global_dist,
        ],
        labels=["Assay", "Cell Type", "Global"],
    )

DIFF_COLUMNS = ["Difference vs Assay", "Difference vs Cell Type", "Difference vs Global"]
SUBCLASS_PCT_COLUMN = "(assay, ct) subclass %"
# Column order of the exported table.
OUTPUT_COLUMNS = ["Label", "Category", SUBCLASS_PCT_COLUMN, *DIFF_COLUMNS, "Assay", "Cell Type"]

pair_dfs = []
for (assay, cell_type), comparisons in comparison_results.items():
    category_dfs = []
    for category, df_comparison in comparisons.items():
        category_dfs.append(
            # New frame: the DataFrames stored in comparison_results are left untouched.
            df_comparison.set_axis(DIFF_COLUMNS, axis=1)
            .assign(
                **{
                    "Assay": assay,
                    "Cell Type": cell_type,
                    "Category": category,
                    # Aligned on the label index; labels absent from the pair become NaN.
                    SUBCLASS_PCT_COLUMN: subclass_distributions[(assay, cell_type)][category],
                }
            )
            # Explicit label column name: reset_index() would otherwise name it after the
            # index name, which depends on the pandas version ("index" or the metadata column).
            .rename_axis("Label")
            .reset_index()
        )
    pair_dfs.append(pd.concat(category_dfs, ignore_index=True))

all_pairs_df = pd.concat(pair_dfs, axis=0, ignore_index=True)
# Fill only the numeric columns: a missing-value label (NaN) must stay missing
# instead of being turned into the label 0.
numeric_columns = [SUBCLASS_PCT_COLUMN, *DIFF_COLUMNS]
all_pairs_df[numeric_columns] = all_pairs_df[numeric_columns].fillna(0)
all_pairs_df = all_pairs_df[OUTPUT_COLUMNS]

file_path = base_fig_dir / "metadata_comparison_all.csv"
all_pairs_df.to_csv(file_path, index=False)

TODO:
- Collect all general run parameters and fill in the oversampling parameter where it is missing, i.e. create a new `all_results_cometml_filtered_oversampling-fixed.csv`.
- Get the difference in content between metadata groups (diff of the md5sums): create a new metadata object containing only the differing files, and display its labels the usual way.