In [None]:
"""Plot UMAP embeddings for various datasets."""
# pylint: disable=import-error, redefined-outer-name, use-dict-literal, too-many-lines, unused-import, unused-argument, too-many-branches

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from __future__ import annotations

import copy
import pickle
from pathlib import Path
from typing import Dict

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import umap
from IPython.display import display
from plotly.subplots import make_subplots
from sklearn.preprocessing import StandardScaler
from umap.umap_ import nearest_neighbors

from epi_ml.core.hdf5_loader import Hdf5Loader
from epi_ml.core.metadata import Metadata
from epi_ml.utils.notebooks.paper.paper_utilities import (
    ASSAY,
    ASSAY_MERGE_DICT,
    ASSAY_ORDER,
    MetadataHandler,
)

In [None]:
# The first seven entries of ASSAY_ORDER; used below to restrict the C-A samples that get plotted.
CORE7_ASSAYS = ASSAY_ORDER[:7]

In [None]:
# Name of the metadata field/column that holds each sample's plot group
# (used for trace color and legend entries in the UMAP figures).
UMAP = "plot_label"

In [None]:
# Root of the paper outputs; the data and figure folders live directly under it.
base_dir = Path.home().joinpath("Projects/epiclass/output/paper")
paper_dir = base_dir
base_data_dir = base_dir / "data"
base_fig_dir = base_dir / "figures"

# Fail early when the expected figure folder is missing.
if not base_fig_dir.exists():
    raise FileNotFoundError(f"Directory {base_fig_dir} does not exist.")

In [None]:
# Load the "v2" metadata through the project's MetadataHandler, rooted at paper_dir.
metadata_handler = MetadataHandler(paper_dir)
metadata_v2 = metadata_handler.load_metadata("v2")

In [None]:
# Hdf5Loader built from the hg38 chromosome sizes file, with normalization enabled.
chromsize_path = base_data_dir / "chromsizes" / "hg38.noy.chrom.sizes"
hdf5_loader = Hdf5Loader(chrom_file=chromsize_path, normalization=True)

In [None]:
# Tab-separated table of C-A metadata and predictions (one row per sample).
ca_pred_path = base_data_dir.joinpath(
    "training_results",
    "predictions",
    "C-A",
    "assay_epiclass",
    "CA_metadata_4DB+all_pred_subset.20240606_mod2.tsv",
)
ca_pred_df = pd.read_csv(ca_pred_path, sep="\t", low_memory=False)

## C-A input peak metrics

In [None]:
# peak_metrics_path = Path.home() / "Downloads" / "temp" / "peak_metrics" / "metadata.peak_chip_atlas.tsv"
# peak_metrics_df = pd.read_csv(peak_metrics_path, sep="\t", index_col=0)

In [None]:
# scaler = StandardScaler()
# scaled_peak_metrics = scaler.fit_transform(peak_metrics_df.values)
# ids = peak_metrics_df.index.tolist()

In [None]:
# precomputed_knn = nearest_neighbors(
#     X=scaled_peak_metrics,
#     n_neighbors=100,
#     metric="correlation",
#     random_state=42,
#     low_memory=False,
#     metric_kwds=None,
#     angular=None,
# )

# with open(peak_metrics_path.parent / "peak_metrics_precomputed_knn.pkl", "wb") as f:
#     pickle.dump(precomputed_knn, f)

In [None]:
# # UMAP parameters
# nn_default = 15
# nn_bigger = 30
# nn_biggest = 100
# embedding_params = {}
# for nn_size in [nn_default, nn_bigger, nn_biggest]:
#     embedding_params[f"standard_3D_nn{nn_size}"] = {
#         "n_neighbors": nn_size,
#         "min_dist": 0.1,
#         "n_components": 3,
#         "low_memory": False,
#     }
#     embedding_params[f"densmap_3D_nn{nn_size}"] = {
#         "n_neighbors": nn_size,
#         "min_dist": 0.1,
#         "n_components": 3,
#         "low_memory": False,
#         "densmap": True,
#     }

In [None]:
# for name, params in embedding_params.items():
#     embedding = umap.UMAP(
#         **params, random_state=42, precomputed_knn=precomputed_knn
#     ).fit_transform(X=scaled_peak_metrics)

#     with open(peak_metrics_path.parent / f"embedding_{name}.pkl", "wb") as f:
#         pickle.dump(
#             obj={"ids": ids, "embedding": embedding, "params": params}, file=f
#         )
#         print(f"Saved embedding_{name}.pkl")

## C-A hdf5 100kb UMAP

### Function definitions: metadata and plotting

In [None]:
def _umap_group_hovertext(group_df: pd.DataFrame) -> list[str]:
    """Build one '<id>: <assay> (<track_type>)' hover label per row of group_df."""
    return [
        f"{id_label}: {assay} ({track_type})"
        for id_label, assay, track_type in zip(
            group_df["ids"],
            group_df[ASSAY],
            group_df["track_type"],
        )
    ]


def create_umap_graphs(
    embeddings_dir: Path,
    output_dir: Path,
    metadata: pd.DataFrame,
    include_3D: bool = True,
    force_color: Dict[str, str] | None = None,
) -> None:
    """Create UMAP graphs for all embeddings 2D/3D pairs in the given directory.

    Embedding files are discovered with the glob 'embedding_*.pkl'. Each file is a
    pickled dict with (at least) the keys 'embedding' and 'ids'. One HTML file is
    written per embedding, plus 'all_embeddings_2D.html' which shows every 2D
    embedding in a grid of subplots (3 columns, at least 2 rows).

    Only uses samples present in metadata.

    metadata must contain the following columns:
        - id: Sample IDs (matched against the 'ids' stored with each embedding)
        - Value of 'ASSAY' constant: experiment target (shown in hover text)
        - track_type: Track type (shown in hover text)
        - Value of 'UMAP' constant: UMAP group for color and legend.

    Args:
        embeddings_dir (Path): Directory containing the embedding files.
        output_dir (Path): Directory where the output files will be saved.
        metadata (pd.Dataframe): metadata dataframe. A Metadata object is also
            accepted and converted with MetadataHandler.metadata_to_df.
        include_3D (bool): Whether to include 3D embeddings. Defaults to True.
        force_color (Dict[str, str]): Dictionary mapping UMAP groups to colors.
            Groups missing from the mapping are not drawn. Defaults to None, in
            which case one Dark24 color is assigned per group (cycling if needed).

    Raises:
        FileNotFoundError: If no 2D embedding file is found, or an expected file
            is missing.
        ValueError: If 2D and 3D embeddings cannot be paired, or their IDs differ.
    """
    metadata = copy.deepcopy(metadata)
    if isinstance(metadata, Metadata):
        metadata = MetadataHandler.metadata_to_df(metadata)

    # Embedding names are the file stems without the leading "embedding_".
    names = [
        file.stem.split("_", maxsplit=1)[1]
        for file in embeddings_dir.glob("embedding_*.pkl")
    ]
    names = sorted(names, key=lambda x: (x.split("_")[0], int(x.split("nn")[1])))
    names_2D = [name for name in names if "2D" in name]
    if not names_2D:
        # Without this guard the function would fail later with an obscure
        # NameError on color_dict.
        raise FileNotFoundError(f"No 2D embedding files found in {embeddings_dir}")

    if include_3D:
        names_3D = [name for name in names if "3D" in name]
        if len(names_2D) != len(names_3D):
            raise ValueError("Number of 2D and 3D embeddings do not match.")
        if len(names_2D) + len(names_3D) != len(names):
            raise ValueError("Some embeddings don't specify 2D or 3D.")
    else:
        names_3D = [None] * len(names_2D)

    # Global figure for 2D UMAPs: keep the historical 2x3 layout, but add rows
    # when there are more than six embeddings instead of failing.
    n_cols = 3
    n_rows = max(2, -(-len(names_2D) // n_cols))  # ceil division
    fignames = [name.replace("_2D_", "_") for name in names_2D]
    global_fig = make_subplots(
        rows=n_rows,
        cols=n_cols,
        subplot_titles=fignames,
    )

    palette = px.colors.qualitative.Dark24

    # Create each 2D/3D plot.
    for i, (name_2D, name_3D) in enumerate(zip(names_2D, names_3D)):
        # Searching
        file_2D = embeddings_dir / f"embedding_{name_2D}.pkl"
        if not file_2D.exists():
            raise FileNotFoundError(f"Could not find {file_2D}")

        if include_3D:
            file_3D = embeddings_dir / f"embedding_{name_3D}.pkl"
            if not file_3D.exists():
                raise FileNotFoundError(f"Could not find {file_3D}")

        # Open embeddings.
        # NOTE: pickle.load executes arbitrary code; only use trusted embedding files.
        with open(file_2D, "rb") as f:
            data_2D = pickle.load(f)

        if include_3D:
            with open(file_3D, "rb") as f:
                data_3D = pickle.load(f)

        # Create dataframes (coordinate columns are the integers 0, 1[, 2])
        df_2D = pd.DataFrame(data_2D["embedding"])
        df_2D["ids"] = data_2D["ids"]

        if include_3D:
            df_3D = pd.DataFrame(data_3D["embedding"])
            df_3D["ids"] = data_3D["ids"]

            if set(df_2D["ids"]) != set(df_3D["ids"]):
                raise ValueError("IDs do not match between 2D and 3D embeddings.")

        # Explicitly filter out samples not in metadata
        initial_size = len(df_2D)
        df_2D = df_2D[df_2D["ids"].isin(metadata["id"])]

        if include_3D:
            df_3D = df_3D[df_3D["ids"].isin(metadata["id"])]

        # Report sample counts once, on the first embedding.
        if i == 0:
            print(f"{len(df_2D['ids'])} samples found in 2D embedding")
            print(f"Filtered out {initial_size - len(df_2D)} samples not in metadata.")

        # Merge metadata
        df_2D = df_2D.merge(metadata, left_on="ids", right_on="id", how="inner")
        if include_3D:
            df_3D = df_3D.merge(metadata, left_on="ids", right_on="id", how="inner")

        if force_color:
            color_dict = force_color
        else:
            # Cycle through the palette so more than 24 groups do not raise.
            color_dict = {
                group: palette[idx % len(palette)]
                for idx, group in enumerate(sorted(df_2D[UMAP].unique()))
            }

        # 2D plotly
        print(f"Processing {name_2D}...")
        fig = go.Figure()
        grid_row, grid_col = divmod(i, n_cols)

        for group, color in color_dict.items():
            filtered_df = df_2D[df_2D[UMAP] == group]
            hovertext = _umap_group_hovertext(filtered_df)

            trace = go.Scatter(
                x=filtered_df[0],
                y=filtered_df[1],
                mode="markers",
                marker=dict(
                    size=2,
                    color=color,
                    opacity=0.8,
                ),
                hovertemplate="%{text}",
                text=hovertext,
                name=group,
                showlegend=True,
            )

            # Same points for the global grid; the legend entry is provided
            # once per group by the placeholder traces added after the loop.
            trace_global = go.Scatter(
                x=filtered_df[0],
                y=filtered_df[1],
                mode="markers",
                marker=dict(
                    size=2,
                    color=color,
                    opacity=0.8,
                ),
                hovertemplate="%{text}",
                text=hovertext,
                legendgroup=group,
                showlegend=False,
            )

            fig.add_trace(trace)
            global_fig.add_trace(
                trace_global,
                row=grid_row + 1,
                col=grid_col + 1,
            )

        fig.update_layout(
            title=f"2D UMAP Embeddings - {name_2D.replace('_2D', '')}",
            xaxis_title="UMAP 1",
            yaxis_title="UMAP 2",
            legend={"itemsizing": "constant"},
        )

        fig.write_html(output_dir / f"embedding_{name_2D}.html")

        # 3D plotly
        if not include_3D:
            continue

        print(f"Processing {name_3D}...")
        fig = go.Figure()

        for group, color in color_dict.items():
            filtered_df = df_3D[df_3D[UMAP] == group]

            fig.add_trace(
                go.Scatter3d(
                    x=filtered_df[0],
                    y=filtered_df[1],
                    z=filtered_df[2],
                    mode="markers",
                    marker=dict(
                        size=1,
                        color=color,
                        opacity=0.8,
                    ),
                    hovertemplate="%{text}",
                    text=_umap_group_hovertext(filtered_df),
                    name=group,
                    showlegend=True,
                )
            )

        fig.update_layout(
            title=f"3D UMAP Embeddings - {name_3D.replace('_3D', '')}",
            scene=dict(
                xaxis_title="UMAP 1",
                yaxis_title="UMAP 2",
                zaxis_title="UMAP 3",
            ),
            legend={"itemsizing": "constant"},
        )

        fig.write_html(output_dir / f"embedding_{name_3D}.html")

    # One empty placeholder trace per group gives the global figure a single
    # legend entry per group. `name` is required, otherwise plotly labels the
    # entries "trace N". Uses the color mapping of the last embedding.
    for group, color in color_dict.items():
        global_fig.add_trace(
            go.Scatter(
                x=[None],
                y=[None],
                mode="markers",
                marker=dict(
                    size=10,
                    color=color,
                    opacity=0.8,
                ),
                name=group,
                legendgroup=group,
                showlegend=True,
            ),
            row=1,
            col=1,
        )

    # Update the layout for the global figure and save it
    global_fig.update_layout(
        title_text="2D UMAP Embeddings",
        legend={"itemsizing": "constant"},
    )

    output_file = output_dir / "all_embeddings_2D"
    print(f"Writing global fig {output_file}.ext")
    global_fig.write_html(f"{output_file}.html")
    # global_fig.write_image(f"{output_file}.png", width=3 * 800, height=2 * 800)
    # global_fig.write_image(f"{output_file}.svg", width=3*800, height=2*800)

In [None]:
def examine_umap_vals(embeddings_dir: Path, metadata: Metadata) -> None:
    """Export the IDs of suspicious WGBS samples found in specific UMAP regions.

    Only embeddings whose name contains both "standard" and "nn15" are examined.
    For each of them, samples are annotated with assay and track type (from
    metadata when available, placeholder labels otherwise), restricted to assays
    containing "wgb" but not "input", and the IDs falling inside two hard-coded
    coordinate windows are written to '.md5' text files next to the embeddings.

    Args:
        embeddings_dir (Path): Directory containing the embedding files.
        metadata (Metadata): Metadata object.
    """

    def annotate(sample_id):
        """Return the (assay, track_type) pair for one sample ID."""
        if sample_id in metadata:
            return metadata[sample_id][ASSAY], metadata[sample_id]["track_type"]
        # Unknown sample: IDs of 32 characters are labelled as EpiAtlas,
        # everything else as C-A without prediction.
        if len(sample_id) == 32:
            return "epiatlas_NA", "NA"
        return "C-A_no-pred", "ctl_raw"

    embedding_names = sorted(
        (
            path.stem.split("_", maxsplit=1)[1]
            for path in embeddings_dir.glob("embedding_*.pkl")
        ),
        key=lambda x: (x.split("_")[0], int(x.split("nn")[1])),
    )

    for name in embedding_names:
        print(f"Processing {name}...")

        if "standard" not in name or "nn15" not in name:
            continue

        pkl_path = embeddings_dir / f"embedding_{name}.pkl"
        if not pkl_path.exists():
            raise FileNotFoundError(f"Could not find {pkl_path}")

        with open(pkl_path, "rb") as handle:
            payload = pickle.load(handle)

        coords = pd.DataFrame(payload["embedding"])
        coords["ids"] = payload["ids"]

        # Add custom metadata to the dataframe
        annotations = [annotate(sample_id) for sample_id in coords["ids"]]
        coords[ASSAY] = [assay for assay, _ in annotations]
        coords["track_type"] = [track_type for _, track_type in annotations]

        # Keep WGBS-like assays, then drop anything mentioning "input"
        wgbs_df = coords[coords[ASSAY].str.contains("wgb")]
        wgbs_df = wgbs_df[~wgbs_df[ASSAY].str.contains("input")]
        display(wgbs_df[ASSAY].value_counts())

        # First cluster: 5 < x < 7.5 and 0 < y < 2
        in_window = (
            (wgbs_df[0] > 5) & (wgbs_df[0] < 7.5) & (wgbs_df[1] > 0) & (wgbs_df[1] < 2)
        )
        first_cluster = wgbs_df[in_window]
        first_path = embeddings_dir / f"embedding_{name}_sus_wgbs.md5"
        with open(first_path, "w", encoding="utf8") as handle:
            handle.writelines(f"{sample_id}\n" for sample_id in first_cluster["ids"])

        # Second cluster: y > 4.1
        second_cluster = wgbs_df[wgbs_df[1] > 4.1]
        second_path = embeddings_dir / f"embedding_{name}_sus_wgbs_clus2.md5"
        with open(second_path, "w", encoding="utf8") as handle:
            handle.writelines(f"{sample_id}\n" for sample_id in second_cluster["ids"])
        print(second_cluster.shape)

### Plotting

In [None]:
# Load the "v2" metadata; later cells assign a plot label (UMAP field) to each of its datasets.
metadata = metadata_handler.load_metadata("v2")

#### C-A + EpiAtlas

In [None]:
# Directory holding the precomputed C-A + EpiAtlas embedding pickles.
input_dir = Path.home() / "mounts/narval-mount/scratch/umap/C-A_epiatlas"
if not input_dir.exists():
    raise FileNotFoundError(f"Could not find {input_dir}")

# The output path adds two levels below input_dir ("graphs/C-A_epiatlas"), so
# parents=True is needed when "graphs" does not exist yet. input_dir itself is
# already known to exist, so nothing is created above it.
output_dir = input_dir / "graphs" / "C-A_epiatlas"
output_dir.mkdir(parents=True, exist_ok=True)

In [None]:
# Tag every EpiAtlas dataset with its plot group.
for _md5, dset in metadata.items:
    dset[UMAP] = "epiatlas_" + dset[ASSAY]

# Tag C-A samples with their plot group, then keep only the core assays.
ca_pred_df[UMAP] = ca_pred_df["manual_target_consensus"].map("C-A_{}".format)
is_core7 = ca_pred_df["manual_target_consensus"].isin(CORE7_ASSAYS)
ca_df = ca_pred_df[is_core7]

graph_metadata = MetadataHandler.uniformize_metadata_for_plotting(metadata, ca_df)

   JobID                                                                                              JobName  Timelimit    Elapsed NCPUS  ReqMem ExitCode    State Priority       Submit    Account 
-------- ---------------------------------------------------------------------------------------------------- ---------- ---------- ----- ------- -------- -------- -------- ------------ ---------- 
34407394                                                                                     CA_epiatlas_umap   02:00:00   00:00:00     0     55G      0:0  PENDING  1520919  04 20:05:02 rrg-jacqu+ 

In [None]:
# Write the UMAP HTML figures for the C-A + EpiAtlas embeddings into output_dir.
create_umap_graphs(input_dir, output_dir, graph_metadata)

#### Input predictions

In [None]:
# Directory holding the precomputed ENCODE + C-A + EpiAtlas embedding pickles.
input_dir = Path.home() / "mounts/narval-mount/scratch/umap/all_enc_CA_epiatlas"
if not input_dir.exists():
    raise FileNotFoundError(f"Could not find {input_dir}")

# Minimum prediction score a C-A input sample needs to be plotted.
min_pred_score = 0.0
output_dir = (
    input_dir / "graphs" / "c-a_input_predictions" / f"graph_minPred_{min_pred_score}"
)
# The output path is three levels below input_dir, so the intermediate folders
# must be created too (parents=False fails when any of them is missing).
output_dir.mkdir(parents=True, exist_ok=True)

In [None]:
# Filter the input DataFrame
ca_input_mask = ca_pred_df["manual_target_consensus"] == "input"
input_N = ca_input_mask.sum()
print(f"Number of input samples: {input_N}")
pred_score_mask = ca_pred_df["Max_pred_assay11"].astype(float) >= min_pred_score
filtered_N = (ca_input_mask & pred_score_mask).sum()
print(
    f"Number of input samples with score >= {min_pred_score}: {filtered_N}/{input_N} ({filtered_N/input_N:.2%})"
)

ca_input_df = ca_pred_df[ca_input_mask & pred_score_mask].copy()
ca_input_df[UMAP] = "C-A_input_pred_" + ca_input_df["Predicted_class_assay11"]

In [None]:
# Tag every EpiAtlas dataset with its plot group.
for _md5, dset in metadata.items:
    dset[UMAP] = "epiatlas_" + dset[ASSAY]

# Combine EpiAtlas metadata with the filtered C-A input samples for plotting.
graph_metadata = MetadataHandler.uniformize_metadata_for_plotting(metadata, ca_input_df)

In [None]:
# print(len(graph_metadata))
# display(graph_metadata[UMAP].value_counts(dropna=False))

In [None]:
# create_umap_graphs(input_dir, output_dir, graph_metadata)

#### C-A + EpiAtlas + imputed 

In [None]:
# Replace `metadata` with the dfreeze v2.1 metadata file that also lists imputed datasets.
metadata_path = (
    Path.home()
    / "Projects/epiclass/input/metadata/dfreeze-v2/hg38_2023-epiatlas-dfreeze_v2.1_w_imputed.json"
)
metadata = Metadata(metadata_path)

In [None]:
# Datasets lacking an "upload_date" entry get the "_imputed" suffix in their plot group.
for _md5, dset in list(metadata.items):
    assay = dset[ASSAY]
    suffix = "" if "upload_date" in dset else "_imputed"
    dset[UMAP] = f"epiatlas_{assay}{suffix}"

# C-A samples are grouped by their manual target label.
ca_pred_df[UMAP] = "C-A_" + ca_pred_df["manual_target_consensus"]

In [None]:
# Combine the EpiAtlas (+ imputed) metadata with all C-A samples for plotting.
graph_metadata = MetadataHandler.uniformize_metadata_for_plotting(metadata, ca_pred_df)

In [None]:
# Three-color scheme by data source. "imputed" takes precedence over the
# "epiatlas" / "C-A" match; groups matching none of the three get no color.
colors = px.colors.qualitative.Dark24
force_color = {}
for group in graph_metadata[UMAP].unique():
    if "imputed" in group:
        force_color[group] = colors[2]
    elif "epiatlas" in group:
        force_color[group] = colors[0]
    elif "C-A" in group:
        force_color[group] = colors[1]

In [None]:
# Directory holding the precomputed embeddings that include imputed datasets.
input_dir = Path.home() / "mounts/narval-mount/scratch/umap/imputed"
if not input_dir.exists():
    raise FileNotFoundError(f"Could not find {input_dir}")

# Unlike the other sections, the output folder is not created here: it must already exist.
# NOTE(review): "downloads" is lowercase here while the commented-out peak-metrics
# cell uses "Downloads" - confirm the intended folder on case-sensitive filesystems.
output_dir = Path.home() / "downloads" / "temp" / "umap" / "imputed"
if not output_dir.exists():
    raise FileNotFoundError(f"Could not find {output_dir}")

In [None]:
# create_umap_graphs(input_dir, output_dir, graph_metadata, include_3D=False, force_color=force_color)

## ENCODE + C-A mislabels position within EpiAtlas UMAP

Uses UMAP coordinates from the global kNN computations (ENCODE + C-A + EpiAtlas).

In [None]:
# Load the table of C-A and ENCODE samples flagged as mislabels by the assay7 predictions.
mislabels_dir = base_data_dir / "training_results" / "predictions"
mislabels_path = mislabels_dir / "mislabels_C-A&ENCODE_assay7.csv"
mislabels_df = pd.read_csv(mislabels_path, sep=",", low_memory=False)
print(mislabels_df.shape)

In [None]:
# Keep only mislabels whose top assay7 prediction score is at least min_pred.
# Note: this overwrites mislabels_df; re-run the loading cell to get the unfiltered table back.
min_pred = 0.6
mislabels_df = mislabels_df[mislabels_df["Max_pred_assay7"] >= min_pred]

In [None]:
def _epiatlas_hover_text(assay_df: pd.DataFrame) -> list[str]:
    """Build one '<id>: <assay> (<track_type>)' hover label per EpiAtlas row."""
    return [
        f"{id_label}: {assay} ({track_type})"
        for id_label, assay, track_type in zip(
            assay_df["ids"],
            assay_df[ASSAY],
            assay_df["track_type"],
        )
    ]


def _mislabel_points(
    not_epiatlas_umap: pd.DataFrame, mislabels_df: pd.DataFrame, assay_type: str
):
    """Return (coords, hover_text) for mislabels predicted as assay_type, or None.

    None is returned when no mislabel is predicted as assay_type. The hover text
    is built in the row order of the returned coordinates, so each label matches
    its point even if the mislabel table is ordered differently from the
    embedding or lists samples absent from it.
    """
    pred_df = mislabels_df[mislabels_df["Predicted_class_assay7"] == assay_type]
    if pred_df.shape[0] == 0:
        return None

    # Index predictions by sample ID so they can be looked up in embedding order.
    pred_by_id = pred_df.drop_duplicates(subset="Experimental-id").set_index(
        "Experimental-id"
    )
    umap_df = not_epiatlas_umap[not_epiatlas_umap["ids"].isin(pred_by_id.index)]
    matched = pred_by_id.loc[umap_df["ids"].tolist()]

    text = [
        f"{sample_id}: Label={true_label}. Pred={pred_label} ({score:.2f})"
        for sample_id, true_label, pred_label, score in zip(
            umap_df["ids"],
            matched["manual_target_consensus"],
            matched["Predicted_class_assay7"],
            matched["Max_pred_assay7"],
        )
    ]
    return umap_df, text


def plot_UMAP_mislabels(
    embeddings_dir: Path,
    output_dir: Path,
    epiatlas_metadata: Metadata,
    mislabels_df: pd.DataFrame,
) -> None:
    """Create UMAP graphs for all embeddings in the given directory.

    EpiAtlas samples are drawn colored by assay; samples listed in mislabels_df
    (and absent from the EpiAtlas metadata) are overlaid in black, grouped by
    their predicted assay. For every 'embedding_*.pkl' file a 2D HTML figure
    (first two coordinates) and a 3D HTML figure are written, plus
    'all_embeddings_2D.html' showing all 2D views in a grid of subplots.

    mislabels_df must contain the columns 'Experimental-id',
    'manual_target_consensus', 'Predicted_class_assay7' and 'Max_pred_assay7'.

    Args:
        embeddings_dir (Path): Directory containing the embedding files.
        output_dir (Path): Directory where the output files will be saved.
        epiatlas_metadata (Metadata): Metadata object for the EpiAtlas data.
        mislabels_df (pd.DataFrame): DataFrame containing the mislabels info.

    Raises:
        FileNotFoundError: If an expected embedding file is missing.
    """
    # Embedding names are the file stems without the leading "embedding_".
    names = [
        file.stem.split("_", maxsplit=1)[1]
        for file in embeddings_dir.glob("embedding_*.pkl")
    ]
    names = sorted(names, key=lambda x: (x.split("_")[0], int(x.split("nn")[1])))

    # Global figure for 2D UMAPs: keep the historical 2x3 layout, but add rows
    # when there are more than six embeddings instead of failing.
    n_cols = 3
    n_rows = max(2, -(-len(names) // n_cols))  # ceil division
    fignames = [name.replace("_3D_", "_") for name in names]
    global_fig = make_subplots(
        rows=n_rows,
        cols=n_cols,
        subplot_titles=fignames,
    )

    palette = px.colors.qualitative.Dark24
    color_dict = {
        assay: palette[idx % len(palette)]
        for idx, assay in enumerate(sorted(ASSAY_ORDER))
    }

    # Create each 2D/3D plot.
    for i, name in enumerate(names):
        print(f"Processing {name}...")

        file = embeddings_dir / f"embedding_{name}.pkl"
        if not file.exists():
            raise FileNotFoundError(f"Could not find {file}")

        # NOTE: pickle.load executes arbitrary code; only use trusted embedding files.
        with open(file, "rb") as f:
            data = pickle.load(f)

        global_umap_df = pd.DataFrame(data["embedding"])
        global_umap_df.loc[:, "ids"] = data["ids"]

        # Split the embedding into EpiAtlas samples and everything else.
        global_umap_df.loc[:, "is_epiatlas"] = global_umap_df["ids"].apply(
            lambda x: x in epiatlas_metadata
        )
        epiatlas_mask = global_umap_df["is_epiatlas"]
        epiatlas_umap = global_umap_df[epiatlas_mask].copy()
        not_epiatlas_umap = global_umap_df[~epiatlas_mask].copy()

        epiatlas_umap.loc[:, ASSAY] = epiatlas_umap["ids"].apply(
            lambda x: epiatlas_metadata[x][ASSAY]
        )
        epiatlas_umap.loc[:, "track_type"] = epiatlas_umap["ids"].apply(
            lambda x: epiatlas_metadata[x]["track_type"]
        )

        # Compute the per-assay point sets once; they are shared by the 2D,
        # global and 3D figures.
        epiatlas_by_assay = {}
        mislabels_by_assay = {}
        for assay_type in color_dict:
            assay_df = epiatlas_umap[epiatlas_umap[ASSAY] == assay_type]
            if assay_df.shape[0] > 0:
                epiatlas_by_assay[assay_type] = (
                    assay_df,
                    _epiatlas_hover_text(assay_df),
                )

            points = _mislabel_points(not_epiatlas_umap, mislabels_df, assay_type)
            if points is not None:
                mislabels_by_assay[assay_type] = points

        grid_row, grid_col = divmod(i, n_cols)

        # 2D plotly + matching subplot of the global figure.
        # Traces are interleaved per assay: EpiAtlas points, then mislabels.
        fig = go.Figure()
        for assay_type, color in color_dict.items():
            if assay_type in epiatlas_by_assay:
                assay_df, hover = epiatlas_by_assay[assay_type]
                fig.add_trace(
                    go.Scatter(
                        x=assay_df[0],
                        y=assay_df[1],
                        mode="markers",
                        marker=dict(
                            size=2,
                            color=color,
                            opacity=0.8,
                        ),
                        hovertemplate="%{text}",
                        text=hover,
                        name=assay_type,
                        showlegend=True,
                    )
                )
                global_fig.add_trace(
                    go.Scatter(
                        x=assay_df[0],
                        y=assay_df[1],
                        mode="markers",
                        marker=dict(
                            size=2,
                            color=color,
                            opacity=0.8,
                        ),
                        hovertemplate="%{text}",
                        text=hover,
                        name=f"{assay_type}",
                        showlegend=(i == 0),
                        legendgroup=assay_type,
                    ),
                    row=grid_row + 1,
                    col=grid_col + 1,
                )

            # Mislabels are drawn even when the assay has no EpiAtlas sample,
            # consistently with the 3D figure.
            if assay_type in mislabels_by_assay:
                umap_df, text = mislabels_by_assay[assay_type]
                fig.add_trace(
                    go.Scatter(
                        x=umap_df[0],
                        y=umap_df[1],
                        mode="markers",
                        marker=dict(
                            size=3,
                            color="black",
                            opacity=0.9,
                        ),
                        hovertemplate="%{text}",
                        text=text,
                        name=f"assay7 pred={assay_type}",
                        showlegend=True,
                    )
                )
                global_fig.add_trace(
                    go.Scatter(
                        x=umap_df[0],
                        y=umap_df[1],
                        mode="markers",
                        marker=dict(
                            size=3,
                            color="black",
                            opacity=0.9,
                        ),
                        hovertemplate="%{text}",
                        text=text,
                        name=f"CA&ENC: assay7 pred={assay_type}",
                        showlegend=(i == 0),
                        legendgroup=assay_type,
                    ),
                    row=grid_row + 1,
                    col=grid_col + 1,
                )

        fig.update_layout(
            title=f"2D UMAP Embeddings - {name}",
            xaxis_title="UMAP 1",
            yaxis_title="UMAP 2",
            legend={"itemsizing": "constant"},
        )

        fig.write_html(output_dir / f"embedding_{name}_2D.html")

        # 3D plotly: all EpiAtlas traces first, then all mislabel traces.
        fig = go.Figure()

        for assay_type, (assay_df, hover) in epiatlas_by_assay.items():
            fig.add_trace(
                go.Scatter3d(
                    x=assay_df[0],
                    y=assay_df[1],
                    z=assay_df[2],
                    mode="markers",
                    marker=dict(
                        size=1,
                        color=color_dict[assay_type],
                        opacity=0.8,
                    ),
                    hovertemplate="%{text}",
                    text=hover,
                    name=assay_type,
                    showlegend=True,
                )
            )

        for assay_type, (umap_df, text) in mislabels_by_assay.items():
            fig.add_trace(
                go.Scatter3d(
                    x=umap_df[0],
                    y=umap_df[1],
                    z=umap_df[2],
                    mode="markers",
                    marker=dict(
                        size=3,
                        color="black",
                        opacity=0.9,
                    ),
                    hovertemplate="%{text}",
                    text=text,
                    name=f"CA&ENC: assay7 pred={assay_type}",
                    showlegend=True,
                )
            )

        fig.update_layout(
            title=f"3D UMAP Embeddings - {name}",
            scene=dict(
                xaxis_title="UMAP 1",
                yaxis_title="UMAP 2",
                zaxis_title="UMAP 3",
            ),
            legend={"itemsizing": "constant"},
        )

        fig.write_html(output_dir / f"embedding_{name}_3D.html")

    # Update the layout for the global figure and save it
    global_fig.update_layout(
        title_text="2D UMAP Embeddings",
        legend={"itemsizing": "constant"},
    )
    output_file = output_dir / "all_embeddings_2D"
    print(f"Writing {output_file}")
    global_fig.write_html(f"{output_file}.html")
    # global_fig.write_image(f"{output_file}.svg", width=3*800, height=2*800)
    # global_fig.write_image(f"{output_file}.png", width=3*800, height=2*800)

In [None]:
# Columns read by plot_UMAP_mislabels; none of them may contain missing values.
mislabel_cols = [
    "Experimental-id",
    "manual_target_consensus",
    "Predicted_class_assay7",
    "Max_pred_assay7",
]
assert not mislabels_df[mislabel_cols].isna().to_numpy().any()

In [None]:
# Directory holding the precomputed ENCODE + C-A + EpiAtlas embedding pickles.
# Checked explicitly, like in the other sections, so a missing mount gives a
# clear error instead of a failing mkdir.
input_dir = Path.home() / "mounts/narval-mount/scratch/umap/all_enc_CA_epiatlas"
if not input_dir.exists():
    raise FileNotFoundError(f"Could not find {input_dir}")

# "graphs/mislabels" is two levels below input_dir: create "graphs" too if needed.
output_dir = input_dir / "graphs" / "mislabels"
output_dir.mkdir(parents=True, exist_ok=True)

# Reload the "v2" metadata and merge assay classes according to ASSAY_MERGE_DICT.
metadata = metadata_handler.load_metadata("v2")
metadata.convert_classes(ASSAY, ASSAY_MERGE_DICT)

In [None]:
# Write the 2D/3D mislabel overlay figures for every embedding found in input_dir.
plot_UMAP_mislabels(input_dir, output_dir, metadata, mislabels_df)