In [None]:
"""Compute UMAP embedding for some input+wgbs data in epiatlas and chip-atlas datasets."""
# pylint: disable=import-error, redefined-outer-name, use-dict-literal, too-many-lines, unused-import, unused-argument, too-many-branches

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from __future__ import annotations

import pickle
from importlib import metadata
from pathlib import Path

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from IPython.display import display
from plotly.subplots import make_subplots

from epi_ml.core.hdf5_loader import Hdf5Loader
from epi_ml.utils.notebooks.paper.paper_utilities import ASSAY, MetadataHandler

In [None]:
# Root folder of the paper outputs; data and figures live in its sub-folders.
base_dir = Path.home().joinpath("Projects/epiclass/output/paper")
paper_dir = base_dir  # same folder, under the name handed to MetadataHandler
base_data_dir = base_dir.joinpath("data")
base_fig_dir = base_dir.joinpath("figures")

# Stop right away when the figures folder is missing.
if not base_fig_dir.exists():
    raise FileNotFoundError(f"Directory {base_fig_dir} does not exist.")

In [None]:
# Build the metadata handler from the paper directory and load the "v2" metadata.
# metadata_v2 is used further down to look up the assay of each embedded id.
metadata_handler = MetadataHandler(paper_dir)
metadata_v2 = metadata_handler.load_metadata("v2")

In [None]:
# Root folder for input files; the chromosome-sizes file is resolved under it.
# NOTE(review): machine/user-specific absolute path — adjust when running elsewhere.
input_basedir = Path(
    "/home/local/USHERBROOKE/rabj2301/mounts/narval-mount/project-rabyj/epilap/input"
)

In [None]:
# Chromosome-sizes file handed to the HDF5 loader.
# NOTE(review): "noy" presumably means chrY is excluded — confirm.
chromsize_path = input_basedir / "chromsizes" / "hg38.noy.chrom.sizes"
# Loader created with normalization switched on (see Hdf5Loader for what that entails).
hdf5_loader = Hdf5Loader(chrom_file=chromsize_path, normalization=True)

In [None]:
# Folder holding the pre-computed embedding pickles; the HTML figures are
# written next to them.
output_dir = Path(
    "/home/local/USHERBROOKE/rabj2301/mounts/narval-mount/scratch/other_data/C-A/hdf5/umap-input/epiatlas_all/nn30"
)
if not output_dir.exists():
    raise FileNotFoundError(f"Could not find {output_dir}")

# Earlier fixed palette, kept for reference (colors are now assigned per assay
# from the Dark24 palette):
# color_dict = {
#     "input": "black",
#     "input_C-A": "forestgreen",
#     "wgbs-standard": "blue",
#     "wgbs-pbat": "purple",
# }


def _assay_color_map(assays) -> dict:
    """Map each assay label to a color of the plotly Dark24 palette.

    Labels are sorted so the assignment is stable between runs. The palette is
    cycled when there are more labels than colors, instead of raising
    IndexError.
    """
    palette = px.colors.qualitative.Dark24
    return {
        assay: palette[i % len(palette)] for i, assay in enumerate(sorted(assays))
    }


def _embedding_figure(df: pd.DataFrame, color_dict: dict, name: str, n_dims: int):
    """Build a scatter figure of the embedding, with one trace per assay.

    Args:
        df: Embedding table. Columns 0..n_dims-1 hold the coordinates, "ids"
            the sample ids and ASSAY the assay label of each row.
        color_dict: Assay label -> color. Its iteration order is the trace order.
        name: Embedding name, only used in the figure title.
        n_dims: 2 for a go.Scatter figure, 3 for a go.Scatter3d figure.

    Returns:
        The plotly figure (go.Figure).

    Raises:
        ValueError: If n_dims is neither 2 nor 3.
    """
    if n_dims == 2:
        trace_cls, marker_size = go.Scatter, 2
    elif n_dims == 3:
        trace_cls, marker_size = go.Scatter3d, 1
    else:
        raise ValueError(f"n_dims must be 2 or 3, got {n_dims}")

    axes = "xyz"[:n_dims]
    fig = go.Figure()

    for assay_type, color in color_dict.items():
        filtered_df = df[df[ASSAY] == assay_type]
        # Embedding column i feeds the i-th axis (x, y[, z]).
        coords = {axis: filtered_df[col] for col, axis in enumerate(axes)}
        fig.add_trace(
            trace_cls(
                **coords,
                mode="markers",
                marker=dict(
                    size=marker_size,
                    color=color,
                    opacity=0.8,
                ),
                hovertemplate="%{text}",
                # Every row of this subset carries the same assay label.
                text=[f"{id_label}: {assay_type}" for id_label in filtered_df["ids"]],
                name=assay_type,
                showlegend=True,
            )
        )

    axis_titles = {
        f"{axis}axis_title": f"UMAP {i}" for i, axis in enumerate(axes, start=1)
    }
    # 3D axis titles belong to the scene; 2D ones sit directly on the layout.
    layout_axes = {"scene": axis_titles} if n_dims == 3 else axis_titles
    fig.update_layout(
        title=f"{n_dims}D UMAP Embeddings - {name}",
        legend={"itemsizing": "constant"},
        **layout_axes,
    )
    return fig


for name in ["standard_2", "standard_3", "densmap"]:
    embedding_path = output_dir / f"embedding_{name}.pkl"
    try:
        # NOTE: pickle.load can execute arbitrary code; only load embedding
        # files produced by a trusted pipeline.
        with open(embedding_path, "rb") as f:
            data = pickle.load(f)
    except FileNotFoundError:
        print(f"Could not find {embedding_path}")
        continue

    df = pd.DataFrame(data["embedding"])
    # Number of embedding components, read before metadata columns are added.
    n_components = df.shape[1]
    df["ids"] = data["ids"]
    # Ids missing from the loaded metadata are labelled "input_C-A".
    df[ASSAY] = [
        metadata_v2[id_label][ASSAY] if id_label in metadata_v2 else "input_C-A"
        for id_label in df["ids"]
    ]

    color_dict = _assay_color_map(df[ASSAY].unique())

    # 2D plotly
    fig = _embedding_figure(df, color_dict, name, n_dims=2)
    fig.write_html(output_dir / f"embedding_{name}_2D.html")

    # Skip the 3D figure for 2-component embeddings, whether flagged by the
    # name or detected from the data (which would otherwise raise KeyError).
    if "_2" in name or n_components < 3:
        continue

    # 3D plotly
    fig = _embedding_figure(df, color_dict, name, n_dims=3)
    fig.write_html(output_dir / f"embedding_{name}_3D.html")