In [None]:
"""Plot UMAP embeddings for various datasets."""
# pylint: disable=import-error, redefined-outer-name, use-dict-literal, too-many-lines, unused-import, unused-argument, too-many-branches

In [None]:
# Enable IPython's autoreload extension in mode 2, so edited modules
# (e.g. the imported epi_ml code) are reloaded before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
from __future__ import annotations

import pickle
from pathlib import Path

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import umap
from IPython.display import display
from plotly.subplots import make_subplots
from sklearn.preprocessing import StandardScaler
from umap.umap_ import nearest_neighbors

from epi_ml.core.hdf5_loader import Hdf5Loader
from epi_ml.core.metadata import Metadata
from epi_ml.utils.notebooks.paper.paper_utilities import ASSAY, MetadataHandler

In [None]:
# Root of the paper outputs; `base_dir` and `paper_dir` are two names for the same path.
paper_dir = Path.home() / "Projects/epiclass/output/paper"
base_dir = paper_dir

# Sub-directories for input tables and generated figures.
base_data_dir = base_dir.joinpath("data")
base_fig_dir = base_dir.joinpath("figures")

# Fail fast when the figures directory is missing.
if not base_fig_dir.exists():
    raise FileNotFoundError(f"Directory {base_fig_dir} does not exist.")

In [None]:
# Build the metadata handler from the paper directory and load the "v2" metadata.
# NOTE(review): `metadata_v2` is not referenced again in the visible cells; a second
# "v2" load is assigned to `metadata` further down before it is modified.
metadata_handler = MetadataHandler(paper_dir)
metadata_v2 = metadata_handler.load_metadata("v2")

In [None]:
# Machine-specific absolute path to the input directory; the chromosome sizes
# file is resolved relative to it. Adjust when running on another machine.
input_basedir = Path(
    "/home/local/USHERBROOKE/rabj2301/mounts/narval-mount/project-rabyj/epilap/input"
)

In [None]:
# Chromosome sizes file, resolved under the input directory.
chromsize_path = input_basedir.joinpath("chromsizes", "hg38.noy.chrom.sizes")

# HDF5 loader configured with that chromosome file and normalization enabled.
hdf5_loader = Hdf5Loader(normalization=True, chrom_file=chromsize_path)

In [None]:
# Tab-separated table of C-A metadata and predictions (20240606 subset file).
ca_pred_path = base_data_dir.joinpath(
    "training_results",
    "C-A",
    "CA_metadata_4DB+all_pred_subset.20240606.tsv",
)
# low_memory=False: read the file in one pass instead of in chunks.
ca_pred_df = pd.read_csv(ca_pred_path, low_memory=False, sep="\t")

## C-A input peak metrics

In [None]:
# peak_metrics_path = Path.home() / "Downloads" / "temp" / "peak_metrics" / "metadata.peak_chip_atlas.tsv"
# peak_metrics_df = pd.read_csv(peak_metrics_path, sep="\t", index_col=0)

In [None]:
# scaler = StandardScaler()
# scaled_peak_metrics = scaler.fit_transform(peak_metrics_df.values)
# ids = peak_metrics_df.index.tolist()

In [None]:
# precomputed_knn = nearest_neighbors(
#     X=scaled_peak_metrics,
#     n_neighbors=100,
#     metric="correlation",
#     random_state=42,
#     low_memory=False,
#     metric_kwds=None,
#     angular=None,
# )

# with open(peak_metrics_path.parent / "peak_metrics_precomputed_knn.pkl", "wb") as f:
#     pickle.dump(precomputed_knn, f)

In [None]:
# # UMAP parameters
# nn_default = 15
# nn_bigger = 30
# nn_biggest = 100
# embedding_params = {}
# for nn_size in [nn_default, nn_bigger, nn_biggest]:
#     embedding_params[f"standard_3D_nn{nn_size}"] = {
#         "n_neighbors": nn_size,
#         "min_dist": 0.1,
#         "n_components": 3,
#         "low_memory": False,
#     }
#     embedding_params[f"densmap_3D_nn{nn_size}"] = {
#         "n_neighbors": nn_size,
#         "min_dist": 0.1,
#         "n_components": 3,
#         "low_memory": False,
#         "densmap": True,
#     }

In [None]:
# for name, params in embedding_params.items():
#     embedding = umap.UMAP(
#         **params, random_state=42, precomputed_knn=precomputed_knn
#     ).fit_transform(X=scaled_peak_metrics)

#     with open(peak_metrics_path.parent / f"embedding_{name}.pkl", "wb") as f:
#         pickle.dump(
#             obj={"ids": ids, "embedding": embedding, "params": params}, file=f
#         )
#         print(f"Saved embedding_{name}.pkl")

## C-A input hdf5 100kb UMAP

In [None]:
def _hover_labels(assay_df: pd.DataFrame) -> list:
    """Build one "<id>: <assay> (<track_type>)" hover string per row of `assay_df`."""
    return [
        f"{id_label}: {assay} ({track_type})"
        for id_label, assay, track_type in zip(
            assay_df["ids"],
            assay_df[ASSAY],
            assay_df["track_type"],
        )
    ]


def _assay_scatter(assay_df: pd.DataFrame, assay_type, color, *, three_d=False, **trace_kwargs):
    """Return a marker-only scatter trace (2D, or 3D when `three_d`) for one assay label.

    Columns 0/1 (and 2 for 3D) of `assay_df` are used as coordinates. Extra keyword
    arguments (e.g. `showlegend`, `legendgroup`) are forwarded to the plotly trace.
    """
    coords = dict(x=assay_df[0], y=assay_df[1])
    if three_d:
        coords["z"] = assay_df[2]
    trace_cls = go.Scatter3d if three_d else go.Scatter
    return trace_cls(
        **coords,
        mode="markers",
        marker=dict(
            size=1 if three_d else 2,
            color=color,
            opacity=0.8,
        ),
        hovertemplate="%{text}",
        text=_hover_labels(assay_df),
        name=assay_type,
        **trace_kwargs,
    )


def create_umap_graphs(embeddings_dir: Path, output_dir: Path, metadata) -> None:
    """Create UMAP graphs for all embeddings in the given directory.

    For every `embedding_<name>.pkl` file found in `embeddings_dir`, writes
    `embedding_<name>_2D.html` and `embedding_<name>_3D.html` to `output_dir`.
    A combined figure with one 2D subplot per embedding is written to
    `all_embeddings_2D.html`.

    Each pickle is expected to hold a dict with an "embedding" entry (at least
    three columns, since columns 0, 1 and 2 are plotted) and an "ids" entry.

    Args:
        embeddings_dir (Path): Directory containing the embedding files.
        output_dir (Path): Directory where the output files will be saved.
            It is not created here and must already exist.
        metadata (Metadata): Object supporting `id in metadata` and
            `metadata[id]`, where each entry provides the ASSAY and
            "track_type" keys.

    Raises:
        FileNotFoundError: If an expected embedding file is missing.
    """
    n_cols = 3
    palette = px.colors.qualitative.Dark24

    # Embedding names are the file stems without the leading "embedding_" prefix.
    names = [
        file.stem.split("_", maxsplit=1)[1]
        for file in embeddings_dir.glob("embedding_*.pkl")
    ]
    # Sort by method prefix, then numerically by the neighbour count after "nn".
    names = sorted(names, key=lambda x: (x.split("_")[0], int(x.split("nn")[1])))
    fignames = [f"{name.replace('_3D_', '_')}" for name in names]

    # Keep the historical 2x3 grid as a minimum, but add rows when more than six
    # embeddings are present so that every subplot has a valid cell.
    n_rows = max(2, -(-len(names) // n_cols))
    global_fig = make_subplots(
        rows=n_rows,
        cols=n_cols,
        subplot_titles=fignames,
    )

    # Assay labels that already have an entry in the global figure's legend.
    legend_shown = set()

    # Create each 2D/3D plot.
    for idx, name in enumerate(names):
        print(f"Processing {name}...")

        file = embeddings_dir / f"embedding_{name}.pkl"
        if not file.exists():
            raise FileNotFoundError(f"Could not find {file}")

        # NOTE: pickle.load can execute arbitrary code; only use trusted embedding files.
        with open(file, "rb") as f:
            data = pickle.load(f)

        df = pd.DataFrame(data["embedding"])
        df["ids"] = data["ids"]

        # Add custom metadata to the dataframe
        assays = []
        track_types = []
        for id_label in df["ids"]:
            if id_label in metadata:
                entry = metadata[id_label]
                assays.append(entry[ASSAY])
                track_types.append(entry["track_type"])
            else:
                # Unknown ids: 32-character ids are labelled as EpiAtlas samples
                # (presumably md5sums -- TODO confirm), anything else as an
                # unlabelled C-A sample.
                is_epiatlas = len(id_label) == 32
                assays.append("epiatlas_NA" if is_epiatlas else "C-A_no-pred")
                track_types.append("NA" if is_epiatlas else "ctl_raw")

        df[ASSAY] = assays
        df["track_type"] = track_types

        # Cycle through the palette so more labels than colors does not raise IndexError.
        color_dict = {
            assay: palette[rank % len(palette)]
            for rank, assay in enumerate(sorted(df[ASSAY].unique()))
        }

        # Split the dataframe once per assay label; reused by the three figures below.
        assay_frames = [
            (assay_type, color, df[df[ASSAY] == assay_type])
            for assay_type, color in color_dict.items()
        ]

        # 2D plotly
        fig_2d = go.Figure()
        for assay_type, color, assay_df in assay_frames:
            fig_2d.add_trace(
                _assay_scatter(assay_df, assay_type, color, showlegend=True)
            )

        fig_2d.update_layout(
            title=f"2D UMAP Embeddings - {name}",
            xaxis_title="UMAP 1",
            yaxis_title="UMAP 2",
            legend={"itemsizing": "constant"},
        )

        fig_2d.write_html(output_dir / f"embedding_{name}_2D.html")

        # Add subplot to the global figure
        row_idx, col_idx = divmod(idx, n_cols)
        for assay_type, color, assay_df in assay_frames:
            global_fig.add_trace(
                _assay_scatter(
                    assay_df,
                    assay_type,
                    color,
                    # One legend entry per label, shown at its first occurrence,
                    # so labels absent from the first embedding are still listed.
                    showlegend=assay_type not in legend_shown,
                    legendgroup=assay_type,
                ),
                row=row_idx + 1,
                col=col_idx + 1,
            )
            legend_shown.add(assay_type)

        # 3D plotly
        fig_3d = go.Figure()
        for assay_type, color, assay_df in assay_frames:
            fig_3d.add_trace(
                _assay_scatter(
                    assay_df, assay_type, color, three_d=True, showlegend=True
                )
            )

        fig_3d.update_layout(
            title=f"3D UMAP Embeddings - {name}",
            scene=dict(
                xaxis_title="UMAP 1",
                yaxis_title="UMAP 2",
                zaxis_title="UMAP 3",
            ),
            legend={"itemsizing": "constant"},
        )

        fig_3d.write_html(output_dir / f"embedding_{name}_3D.html")

    # Update the layout for the global figure and save it
    global_fig.update_layout(
        title_text="2D UMAP Embeddings - All Files",
        legend={"itemsizing": "constant"},
    )
    output_file = output_dir / "all_embeddings_2D"
    print(f"Writing {output_file}")
    global_fig.write_html(f"{output_file}.html")
    # global_fig.write_image(f"{output_file}.svg", width=3*800, height=2*800)
    # global_fig.write_image(f"{output_file}.png", width=3*800, height=2*800)

In [None]:
def create_custom_ca_metadata(metadata, ca_pred_df: pd.DataFrame, pred_score: float = 0):
    """Add ChIP-Atlas input samples, labelled by predicted assay, to `metadata`.

    Rows of `ca_pred_df` whose "manual_target_consensus" is "input" and whose
    "Max_pred_assay11" is at least `pred_score` are inserted into `metadata`
    under their "Experimental-id", with ASSAY set to
    "input_pred_<Predicted_class_assay11>" and "track_type" set to "ctl_raw".

    Args:
        metadata: Object supporting item assignment (`metadata[id] = {...}`).
            It is modified in place.
        ca_pred_df (pd.DataFrame): ChIP-Atlas prediction table with the
            columns named above.
        pred_score (float): Minimum prediction score to keep a sample.

    Returns:
        The same `metadata` object that was passed in (not a copy).
    """
    # Filter the input DataFrame
    is_input = ca_pred_df["manual_target_consensus"] == "input"
    ca_input_df = ca_pred_df[is_input]
    input_nb = ca_input_df.shape[0]
    print(f"Number of input samples: {input_nb}")

    score_ok = ca_input_df["Max_pred_assay11"].astype(float) >= pred_score
    ca_input_df = ca_input_df[score_ok]
    kept_nb = ca_input_df.shape[0]
    # With no input sample at all, report 0% instead of dividing by zero.
    kept_ratio = kept_nb / input_nb if input_nb else 0.0
    print(
        f"Number of input samples with score >= {pred_score}: {kept_nb}/{input_nb} ({kept_ratio:.2%})"
    )

    # Create a dictionary of predictions
    pred_dict = ca_input_df.set_index("Experimental-id")[
        "Predicted_class_assay11"
    ].to_dict()

    # Update metadata
    for id_label, pred_val in pred_dict.items():
        metadata[id_label] = {ASSAY: f"input_pred_{pred_val}", "track_type": "ctl_raw"}

    return metadata

In [None]:
# Load the "v2" metadata again under a separate name: create_custom_ca_metadata
# adds entries to the object it receives, and this is the object passed to it.
# NOTE(review): presumably load_metadata returns a fresh object on each call -- confirm.
metadata = metadata_handler.load_metadata("v2")

In [None]:
# Minimum prediction score for a C-A input sample to be labelled.
min_pred_score = 0.8

# create_custom_ca_metadata updates `metadata` in place and returns it, so
# `filtered_metadata` and `metadata` refer to the same object.
filtered_metadata = create_custom_ca_metadata(
    metadata=metadata,
    ca_pred_df=ca_pred_df,
    pred_score=min_pred_score,
)

In [None]:
# Directory holding the precomputed embedding pickles (machine-specific path).
input_dir = Path(
    "/home/local/USHERBROOKE/rabj2301/mounts/narval-mount/scratch/other_data/C-A/hdf5/umap-input/epiatlas_all/nn100"
)
if not input_dir.exists():
    raise FileNotFoundError(f"Could not find {input_dir}")

# Graphs go to a sub-directory named after the score threshold; only that last
# level is created (parent directories are not).
output_dir = input_dir.joinpath(f"graph_minPred_{min_pred_score}")
output_dir.mkdir(exist_ok=True)

In [None]:
# input_dir = Path("/home/local/USHERBROOKE/rabj2301/Downloads/temp/peak_metrics")
# if not input_dir.exists():
#     raise FileNotFoundError(f"Could not find {input_dir}")
# output_dir = input_dir

In [None]:
# Render the 2D/3D HTML graphs for every embedding found in `input_dir`.
create_umap_graphs(
    embeddings_dir=input_dir,
    output_dir=output_dir,
    metadata=filtered_metadata,
)