In [None]:
"""Analysis of relation between UMAP and QC metrics, following WGBS grouped with input."""

# pylint: disable=import-error, redefined-outer-name, use-dict-literal, too-many-lines, duplicate-code

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from __future__ import annotations

import copy
from pathlib import Path
from typing import List

import pandas as pd
import plotly.graph_objects as go
from IPython.display import display

from epiclass.utils.notebooks.paper.paper_utilities import (
    ASSAY,
    ASSAY_MERGE_DICT,
    CELL_TYPE,
    IHECColorMap,
    MetadataHandler,
)

In [None]:
# Root folder of the paper outputs; data and figures live in its sub-folders.
base_dir = Path.home().joinpath("Projects/epiclass/output/paper")
base_data_dir = base_dir.joinpath("data")
base_fig_dir = base_dir.joinpath("figures")
paper_dir = base_dir

# Fail early if the figure folder is missing.
if not base_fig_dir.exists():
    raise FileNotFoundError(f"Directory {base_fig_dir} does not exist.")

In [None]:
# Load the "v2" metadata through the project handler rooted at `paper_dir`.
metadata_handler = MetadataHandler(paper_dir)
metadata = metadata_handler.load_metadata("v2")
# Order matters: assay labels are first remapped with ASSAY_MERGE_DICT, then the
# selection is made on the remapped "wgbs" label. Return values are unused, so both
# calls presumably modify `metadata` in place — TODO confirm in the Metadata class.
metadata.convert_classes(ASSAY, ASSAY_MERGE_DICT)
metadata.select_category_subsets(ASSAY, ["wgbs"])

In [None]:
# Second, independent load of the "v2" metadata. No label merge is applied here, so
# the selection uses the original assay labels "wgbs-pbat" and "wgbs-standard".
meta_og = metadata_handler.load_metadata("v2")
meta_og.select_category_subsets(ASSAY, ["wgbs-pbat", "wgbs-standard"])

In [None]:
# Keep the instance under its own name: rebinding `IHECColorMap` to the instance
# shadows the imported class, so re-running this cell would try to call the
# instance and fail.
ihec_color_map = IHECColorMap(base_fig_dir)
# Colour lookup indexed by assay label (e.g. "WGBS") in the plotting cells.
assay_colors = ihec_color_map.assay_color_map

In [None]:
def display_labels(meta, categories: List[str]):
    """Show the label summary of every requested metadata category.

    Args:
        meta: Object exposing a `display_labels(category)` method.
        categories: Category names, handled in the order given.
    """
    for label_category in categories:
        meta.display_labels(label_category)

In [None]:
# Metadata categories whose labels are displayed for each metadata subset.
cats = [ASSAY, "track_type", CELL_TYPE]

## Read relevant files

In [None]:
# WGBS QC
qc_dir = base_data_dir.joinpath("experiment_metadata")
wgbs_qc_path = qc_dir.joinpath("EpiATLAS_wgbs_qc_summary.csv")
wgbs_qc = pd.read_csv(wgbs_qc_path)
# Preview the first rows, then the table dimensions.
display(wgbs_qc.head(), wgbs_qc.shape)

In [None]:
# Every column from the 4th onward is plotted as a QC metric.
# NOTE(review): assumes the first three CSV columns are identifiers — confirm against the file header.
metric_cols = wgbs_qc.columns[3:]

In [None]:
# Guard: the number of distinct `uuid` values must match the row count.
if wgbs_qc["uuid"].nunique() != len(wgbs_qc):
    raise ValueError("UUIDs are not unique.")

In [None]:
# UMAP WGBS SUSPICIOUS CLUSTERS
umap_dir = base_data_dir / "umap"
path_template = "embedding_standard_3D_nn100_sus_wgbs_clus{i}.md5"
# One list of md5sums per cluster file (clusters 1 and 2), read from the file's index column.
clusters_md5s = [
    pd.read_csv(umap_dir / path_template.format(i=cluster_id), index_col=0).index.to_list()
    for cluster_id in (1, 2)
]

In [None]:
# md5sums that belong to either suspicious cluster.
sus_md5s = set(clusters_md5s[0]) | set(clusters_md5s[1])

In [None]:
# Map each suspicious md5sum found in the WGBS metadata to its experiment uuid.
md5_to_uuid = {}
for dset in metadata.datasets:
    md5sum = dset["md5sum"]
    if md5sum in sus_md5s:
        md5_to_uuid[md5sum] = dset["uuid"]

# Distinct uuids over both suspicious clusters.
all_bad_uuids = set(md5_to_uuid.values())

In [None]:
# Distinct uuids per cluster; an md5sum missing from `md5_to_uuid` raises KeyError.
uuid_1 = {md5_to_uuid[md5sum] for md5sum in clusters_md5s[0]}
uuid_2 = {md5_to_uuid[md5sum] for md5sum in clusters_md5s[1]}

In [None]:
# Number of distinct uuids in cluster 1 and in cluster 2.
print(len(uuid_1), len(uuid_2))

In [None]:
# QC rows that have no value in the `BS_conversion_Rate` column.
missing_conversion_mask = wgbs_qc["BS_conversion_Rate"].isna()
no_conversion = wgbs_qc.loc[missing_conversion_mask]
display(no_conversion.shape)

In [None]:
# Restrict a copy of the original-label WGBS metadata to the experiments that
# have no `BS_conversion_Rate` value in the QC table.
# The uuid set is built once: rebuilding and scanning a list on every iteration
# made the filter quadratic.
no_conversion_uuids = set(no_conversion["uuid"])
meta_no_conversion = copy.deepcopy(meta_og)
# Iterate over a snapshot because entries are deleted during the loop.
for md5, dset in list(meta_no_conversion.items):
    if dset["uuid"] not in no_conversion_uuids:
        del meta_no_conversion[md5]

In [None]:
# Label summary, per category in `cats`, of the datasets without a `BS_conversion_Rate` value.
display_labels(meta_no_conversion, cats)

In [None]:
# One violin per group: the two suspicious UMAP clusters, then every other experiment.
# Each entry is (trace name, row mask over `wgbs_qc`, marker size).
violin_groups = [
    (f"Cluster {i+1}", wgbs_qc["uuid"].isin(uuids), 5)
    for i, uuids in enumerate([uuid_1, uuid_2])
]
violin_groups.append(("Other", ~wgbs_qc["uuid"].isin(all_bad_uuids), 3))

# Styling shared by every trace.
shared_violin_style = dict(
    box_visible=True,
    line_color="black",
    meanline_visible=True,
    points="all",
    spanmode="hard",
    fillcolor=assay_colors["WGBS"],
    opacity=0.6,
)

# One figure per QC metric.
for metric in metric_cols:
    fig = go.Figure()
    for group_name, group_mask, point_size in violin_groups:
        fig.add_trace(
            go.Violin(
                y=wgbs_qc.loc[group_mask, metric],
                marker=dict(size=point_size),
                name=group_name,
                **shared_violin_style,
            )
        )
    fig.update_layout(
        title=f"{metric} - WGBS grouped with input (UMAP)",
        yaxis_title=metric,
        xaxis_title="Cluster",
    )
    fig.show()

In [None]:
# Among experiments without a `BS_conversion_Rate` value, compare cluster 1 to the rest.
sub_df = wgbs_qc.loc[wgbs_qc["BS_conversion_Rate"].isna()]
mask_1 = sub_df["uuid"].isin(uuid_1)

# Each entry is (trace name, row mask over `sub_df`, fill colour).
no_conversion_groups = [
    ("Cluster 1", mask_1, "red"),
    ("Other", ~mask_1, assay_colors["WGBS"]),
]

# Styling shared by both traces.
no_conversion_style = dict(
    box_visible=True,
    line_color="black",
    meanline_visible=True,
    points="all",
    spanmode="hard",
    marker=dict(size=3),
    opacity=0.9,
)

# One figure per QC metric.
for metric in metric_cols:
    fig = go.Figure()
    for group_name, group_mask, fill in no_conversion_groups:
        fig.add_trace(
            go.Violin(
                y=sub_df.loc[group_mask, metric],
                fillcolor=fill,
                name=group_name,
                **no_conversion_style,
            )
        )
    fig.update_layout(
        title=f"{metric} - WGBS no conversion rate",
        yaxis_title=metric,
    )
    fig.show()

In [None]:
# Restrict a copy of the original-label WGBS metadata to the md5sums of cluster 1.
# A set gives constant-time membership tests; testing against the raw list inside
# the loop was quadratic.
cluster_1_md5s = set(clusters_md5s[0])
meta = copy.deepcopy(meta_og)
# Iterate over a snapshot because entries are deleted during the loop.
for md5 in list(meta.md5s):
    if md5 not in cluster_1_md5s:
        del meta[md5]

display_labels(meta, cats)