In [None]:
"""Analysis of relation between predictions and correlation values."""
# pylint: disable=import-error, redefined-outer-name, use-dict-literal, too-many-lines

In [None]:
from __future__ import annotations

from pathlib import Path

import pandas as pd
import plotly.graph_objects as go
from IPython.display import display

from epi_ml.utils.notebooks.paper.paper_utilities import (
    ASSAY,
    ASSAY_MERGE_DICT,
    ASSAY_ORDER,
    CELL_TYPE,
    IHECColorMap,
    MetadataHandler,
)

In [None]:
# Root folder of the paper outputs; `paper_dir` is an alias of the same path.
base_dir = Path.home() / "Projects/epiclass/output/paper"
paper_dir = base_dir

# Sub-folders holding the input data and the generated figures.
base_data_dir = base_dir.joinpath("data")
base_fig_dir = base_dir.joinpath("figures")

# Stop right away when the figure folder is missing.
if not base_fig_dir.exists():
    raise FileNotFoundError(f"Directory {base_fig_dir} does not exist.")

In [None]:
# Load the "v2" metadata through a handler rooted at `paper_dir`.
metadata_handler = MetadataHandler(paper_dir)
metadata = metadata_handler.load_metadata("v2")
# Remap the ASSAY labels with ASSAY_MERGE_DICT. The return value is not used,
# so this presumably updates `metadata` in place — TODO confirm.
metadata.convert_classes(ASSAY, ASSAY_MERGE_DICT)

In [None]:
# Bind the colour-map instance to its own name. Rebinding `IHECColorMap` itself
# would replace the imported class with an instance, so re-running this cell
# (without re-running the import cell) would call the instance instead of the
# class and break the notebook.
ihec_color_map = IHECColorMap(base_fig_dir)

# Colours used for the violin fills, looked up by assay label.
assay_colors = ihec_color_map.assay_color_map

## Read matrix, keeping only EpiAtlas correlations and public-source mislabels

In [None]:
# Folder with the training predictions, and the mislabel table read from it.
pred_dir = base_data_dir.joinpath("training_results", "predictions")
pred_file = pred_dir.joinpath("mislabels_C-A&ENCODE_assay7.csv")
pred_df = pd.read_csv(pred_file)

In [None]:
# pred_df.columns[0:10]

In [None]:
# Minimum prediction score a row must reach to be kept.
min_pred = 0.6
confident_rows = pred_df["Max_pred_assay7"].ge(min_pred)
pred_df = pred_df.loc[confident_rows]

In [None]:
# Tab-separated matrix: first line is the header, first column is the row index.
matrix_file = pred_dir.joinpath(
    "mislabels_C-A&ENCODE_assay7_100kb_all_none_epiatlas.mat"
)
matrix_df = pd.read_csv(
    matrix_file,
    sep="\t",
    header=0,
    index_col=0,
    low_memory=False,
)

In [None]:
# Shorten every column and row label to the text before its first underscore
# (presumably the file identifier that the metadata is keyed on — TODO confirm).
matrix_files = [label.partition("_")[0] for label in matrix_df.columns]
matrix_df.columns = matrix_files

matrix_df.index = [label.partition("_")[0] for label in matrix_df.index]  # type: ignore

In [None]:
# Keep only the columns whose label is listed in the metadata md5s.
epiatlas_md5s = metadata.md5s
non_epiatlas_columns = [
    label for label in matrix_df.columns if label not in epiatlas_md5s
]
matrix_df = matrix_df.drop(columns=non_epiatlas_columns)

In [None]:
# matrix_df.to_csv(
#     pred_dir / "mislabels_C-A&ENCODE_assay7_100kb_all_none_epiatlas.mat", sep="\t"
# )

## Plot correlation violins per assay

In [None]:
# Destination folder for the correlation graphs; it must already exist.
output_dir = pred_dir.joinpath("mislabel_correlation_graphs")
if not output_dir.exists():
    raise FileNotFoundError(f"Directory {output_dir} does not exist.")

In [None]:
# Result of `md5_per_class(ASSAY)`, indexed below by assay label to select
# matrix columns. The averaged-correlation cell further down reuses it.
assay_md5s = metadata.md5_per_class(ASSAY)

violin_columns = [
    "Experimental-id",
    "manual_target_consensus",
    "Predicted_class_assay7",
    "Max_pred_assay7",
]

# One violin figure per retained sample: its correlations against the files of
# every assay in ASSAY_ORDER.
for sample, true_class, pred_val, pred_score in pred_df[violin_columns].values.tolist():
    # Skip anything involving "input", and keep only predictions of "h3k27ac"
    # whose score is not below 0.9.
    involves_input = "input" in (pred_val, true_class)
    if involves_input or pred_score < 0.9 or pred_val != "h3k27ac":
        continue
    # Previous filter, kept for reference:
    # if pred_score <= 0.8 or list(output_dir.glob(sample + "*")):
    #     continue

    df_row = matrix_df.loc[sample]

    fig = go.Figure()
    for assay_label in ASSAY_ORDER:
        assay_corrs = df_row.loc[assay_md5s[assay_label]]
        violin = go.Violin(
            y=assay_corrs.values.flatten(),
            name=assay_label,
            points="all",
            box_visible=True,
            meanline_visible=True,
            fillcolor=assay_colors[assay_label],
            line_color="black",
            opacity=0.6,
            marker={"size": 2},
            spanmode="hard",
        )
        fig.add_trace(violin)

    fig.update_layout(
        title=f"EpiAtlas correlations with {sample} (labeled {true_class}, predicted {pred_val}), pred_score={pred_score:.2f}",
        xaxis_title="Assay",
        yaxis_title="Correlation",
        showlegend=True,
    )

    name = f"{sample}_label-{true_class}_pred-{pred_val}_score{pred_score:.2f}_correlation_violin"

    # NOTE(review): the export below is disabled, so `fig` is currently built
    # but neither displayed nor written to disk.
    # print(f"Saving {name}")
    # logdir = output_dir / "all"
    # fig.write_html(logdir/ f"{name}.html")
    # fig.write_image(logdir/ f"{name}.png")
    # fig.write_image(logdir / f"{name}.svg")

In [None]:
# For every score threshold and every predicted assay whose name starts with
# "h3", plot the per-sample mean correlation against each assay in ASSAY_ORDER
# and write the figure as HTML, PNG and SVG into `output_dir`.
#
# The loop variable is deliberately not called `min_pred`: that name holds the
# notebook-level threshold used to filter `pred_df`, and rebinding it here would
# silently leave it at 0.9 for any cell run afterwards.
for score_threshold in [0.6, 0.8, 0.9]:
    for pred_assay in ASSAY_ORDER:
        if not pred_assay.startswith("h3"):
            continue

        # Samples predicted as `pred_assay` with a sufficient score, excluding
        # those whose expected label is "input".
        mask_pred_assay = pred_df["Predicted_class_assay7"] == pred_assay
        mask_label_input = pred_df["manual_target_consensus"] == "input"
        mask_pred_score = pred_df["Max_pred_assay7"] >= score_threshold
        mislabel_samples = pred_df[
            mask_pred_assay & ~mask_label_input & mask_pred_score
        ]

        # Expected class composition, shown as an annotation next to the plot.
        manual_target_count = (
            mislabel_samples["manual_target_consensus"].value_counts().to_dict()
        )
        composition_text = "<br>".join(
            f"{key}: {value}" for key, value in manual_target_count.items()
        )

        # Mean correlation of each selected sample against every assay. The
        # sample order is the same for all assays, so one ID list is enough.
        sample_ids = mislabel_samples["Experimental-id"].tolist()
        avg_correlations = {assay_label: [] for assay_label in ASSAY_ORDER}

        for sample in sample_ids:
            df_row = matrix_df.loc[sample]
            for assay_label in ASSAY_ORDER:
                assay_corrs = df_row.loc[assay_md5s[assay_label]]
                avg_correlations[assay_label].append(
                    assay_corrs.values.flatten().mean()
                )

        # One violin per assay; hovering a point shows "<sample>:<mean>".
        fig = go.Figure()
        for assay_label in ASSAY_ORDER:
            assay_means = avg_correlations[assay_label]
            fig.add_trace(
                go.Violin(
                    y=assay_means,
                    name=assay_label,
                    points="all",
                    box_visible=True,
                    meanline_visible=True,
                    spanmode="hard",
                    fillcolor=assay_colors[assay_label],
                    line_color="black",
                    opacity=0.6,
                    marker=dict(size=2),
                    hovertemplate="%{text}",
                    text=[
                        f"{sample}:{corr:.2f}"
                        for corr, sample in zip(assay_means, sample_ids)
                    ],
                )
            )

        fig.update_layout(
            title=f"Average Correlation for Mislabels Predicted as {pred_assay} (pred_score >= {score_threshold:.2f}) - 100kb resolution",
            xaxis_title="Assay",
            yaxis_title="Average Correlation",
            showlegend=True,
            annotations=[
                go.layout.Annotation(
                    text=f"Expected class:<br>{composition_text}",
                    showarrow=False,
                    xref="paper",
                    yref="paper",
                    x=1.15,
                    y=0.20,
                    xanchor="right",
                    yanchor="auto",
                    xshift=0,
                    yshift=0,
                    font=dict(size=10),
                )
            ],
        )

        # Save the figure
        output_name = (
            f"average_correlation_mislabels_{pred_assay}_pred{score_threshold:.2f}"
        )
        fig.write_html(output_dir / f"{output_name}.html")
        fig.write_image(output_dir / f"{output_name}.png")
        fig.write_image(output_dir / f"{output_name}.svg")