In [None]:
"""Plot PCA representation for various datasets."""
# pylint: disable=redefined-outer-name,use-dict-literal,import-error

## SETUP

In [None]:
%load_ext autoreload
%autoreload 2

In [153]:
from __future__ import annotations

import copy
from pathlib import Path

import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import skops.io as skio
from IPython.display import display  # pylint: disable=unused-import

from epi_ml.core.hdf5_loader import Hdf5Loader
from epi_ml.utils.notebooks.paper.paper_utilities import ASSAY_ORDER, MetadataHandler

In [154]:
# The "core7" assay set referred to in the figure titles: the first seven entries of ASSAY_ORDER.
CORE_ASSAYS = ASSAY_ORDER[:7]

In [155]:
# Root of the paper outputs: inputs are read from data/, figures are written to figures/.
base_dir = Path.home() / "Projects/epiclass/output/paper"
base_data_dir = base_dir / "data"
base_fig_dir = base_dir / "figures"
paper_dir = base_dir  # passed to MetadataHandler in the metadata cell

# Fail fast with a clear message if either the input or the output root is missing,
# instead of an obscure error deep inside a later loading/plotting cell.
for required_dir in (base_data_dir, base_fig_dir):
    if not required_dir.exists():
        raise FileNotFoundError(f"Directory {required_dir} does not exist.")

In [156]:
# Metadata handler rooted at the paper directory; load the metadata tagged "v2".
metadata_handler = MetadataHandler(paper_dir)
metadata_v2 = metadata_handler.load_metadata(
    "v2"
)

In [157]:
# hg38 chromosome sizes file (presumably without chrY, per the "noy" name) for the HDF5 loader.
# NOTE(review): hdf5_loader is not used by the cells in this notebook section.
chromsize_path = Path(base_data_dir, "chromsizes", "hg38.noy.chrom.sizes")
hdf5_loader = Hdf5Loader(chrom_file=chromsize_path, normalization=True)

### Metadata setup

In [158]:
# ChIP-Atlas (C-A) metadata merged with assay_epiclass predictions (tab-separated).
ca_pred_path = base_data_dir.joinpath(
    "training_results",
    "predictions",
    "C-A",
    "assay_epiclass",
    "CA_metadata_4DB+all_pred.20240606_mod2.2.tsv",
)
# low_memory=False: infer each column's dtype from the whole file, not per chunk.
ca_pred_df = pd.read_csv(ca_pred_path, sep="\t", low_memory=False)

In [159]:
# Name of the column holding the "<source>_<assay_epiclass>" legend label.
PLOT_LABEL: str = "plot_label"

In [160]:
# Experimental ids of C-A samples whose "is_EpiAtlas_EpiRR" value is not the string "0",
# i.e. samples overlapping EpiAtlas; used to build the *_no_overlap frames.
# NOTE(review): the comparison is string-based (and NaN != "0" is True); if pandas
# ever parsed this column as numeric, every row would pass — confirm the dtype.
in_epiatlas = ca_pred_df.loc[
    ca_pred_df["is_EpiAtlas_EpiRR"] != "0", "Experimental-id"
].unique()

In [161]:
# Combine EpiAtlas (v2) metadata and C-A predictions into one plotting frame;
# later cells rely on its "id", "source" and "assay_epiclass" columns.
graph_metadata = MetadataHandler.uniformize_metadata_for_plotting(
    metadata_v2,
    ca_pred_df,
)

In [162]:
# Legend label "<source>_<assay_epiclass>": one entry per (database, assay) pair.
graph_metadata.loc[:, PLOT_LABEL] = (
    graph_metadata["source"].add("_").add(graph_metadata["assay_epiclass"])
)

In [None]:
# Sample count per legend label, NaN labels included.
graph_metadata.loc[:, PLOT_LABEL].value_counts(dropna=False)

In [None]:
# Number of ChIP-Atlas ("C-A") rows in the combined metadata.
graph_metadata["source"].eq("C-A").sum()

In [None]:
# Metadata without the C-A samples that overlap EpiAtlas.
# Boolean indexing already returns a new frame; `.copy()` detaches it from
# graph_metadata (no SettingWithCopyWarning on later edits) while copying only the
# kept rows, instead of deep-copying the whole frame before filtering.
graph_metadata_no_overlap = graph_metadata.loc[
    ~graph_metadata["id"].isin(in_epiatlas)
].copy()
print(graph_metadata.shape, graph_metadata_no_overlap.shape)

### PCA results loading

In [166]:
pca_dir = base_data_dir.joinpath("pca")
# IPCA fit (model + sample file names) and the projected coordinates, stored with skops.
pca_fit = skio.load(pca_dir.joinpath("IPCA_fit_n88777.skops"))
pca_results = skio.load(pca_dir.joinpath("X_IPCA_n88777.skops"))

In [167]:
# Projected coordinates: rows are samples (paired with pca_filenames), columns are components.
pca_data = pca_results[
    "X_ipca"
]

In [168]:
pca_filenames = pca_fit["file_names"]  # sample ids, row-aligned with pca_data
ipca_fit = pca_fit["ipca_fit"]
explained_variance = ipca_fit.explained_variance_ratio_  # fraction of variance per component

In [169]:
# One row per sample; columns PC1..PCk followed by the sample "id".
global_pca_df = pd.DataFrame(pca_data)
global_pca_df = global_pca_df.set_axis(
    [f"PC{k}" for k in range(1, global_pca_df.shape[1] + 1)], axis="columns"
)
global_pca_df["id"] = pca_filenames

In [170]:
# Attach metadata to the PCA coordinates; the inner join keeps only ids present in both frames.
final_pca_df = pd.merge(global_pca_df, graph_metadata, on="id", how="inner")

In [None]:
# PCA rows without the C-A samples that overlap EpiAtlas.
final_pca_df_no_overlap = final_pca_df.loc[~final_pca_df["id"].isin(in_epiatlas)]
print(final_pca_df.shape, final_pca_df_no_overlap.shape)

In [None]:
# Samples per database, all samples.
final_pca_df.loc[:, "source"].value_counts(dropna=True)

In [None]:
# Samples per database once the EpiAtlas-overlapping C-A samples are removed.
final_pca_df_no_overlap.loc[:, "source"].value_counts(dropna=True)

## Plotting

In [182]:
output_dir = base_fig_dir / "pca"
# The plotting cells write straight into this directory; create it (idempotently)
# so a fresh checkout does not fail with FileNotFoundError at the first write.
output_dir.mkdir(parents=True, exist_ok=True)

In [183]:
# Restrict to the "core7" assays (CORE_ASSAYS), both for all samples and for the
# samples that do not overlap EpiAtlas.
core_assay_df = final_pca_df[final_pca_df["assay_epiclass"].isin(CORE_ASSAYS)]
core_assay_df_no_overlap = final_pca_df_no_overlap[
    final_pca_df_no_overlap["assay_epiclass"].isin(CORE_ASSAYS)
]

In [184]:
# One colour per data source ("C-A" = ChIP-Atlas, "epiatlas" = EpiAtlas).
color_dict = {
    "C-A": px.colors.qualitative.Dark24[0],
    "epiatlas": px.colors.qualitative.Dark24[1],
}

# 3D scatter of PC1-PC3, one trace per source, over ALL samples
# (C-A samples overlapping EpiAtlas included).
fig = go.Figure()
for db_label, color in color_dict.items():
    filtered_df = final_pca_df[final_pca_df["source"] == db_label]
    fig.add_trace(
        go.Scatter3d(
            x=filtered_df["PC1"],
            y=filtered_df["PC2"],
            z=filtered_df["PC3"],
            mode="markers",
            marker=dict(size=1, color=color, opacity=0.5),
            hovertemplate="%{text}",
            # Hover label: "<id>: <assay> (<source>)".
            text=[
                f"{sample_id}: {assay} ({source})"
                for sample_id, assay, source in zip(
                    filtered_df["id"],
                    filtered_df["assay_epiclass"],
                    filtered_df["source"],
                )
            ],
            name=f"{db_label} (N={filtered_df.shape[0]})",
            showlegend=True,
        )
    )

# Axis titles include the fraction of variance explained by each component.
axis_titles = [f"PC {i+1} ({explained_variance[i]:.2%})" for i in range(3)]

fig.update_layout(
    title="3D PCA - epiATLAS and ChIP-Atlas - all samples",
    scene=dict(
        xaxis_title=axis_titles[0],
        yaxis_title=axis_titles[1],
        zaxis_title=axis_titles[2],
    ),
    # Constant legend symbol size, so the size-1 markers stay visible in the legend.
    legend={"itemsizing": "constant"},
)

fig.write_html(output_dir / "pca_all_samples_C-A_epiatlas_3D.html")
del fig

In [185]:
# One colour per data source ("C-A" = ChIP-Atlas, "epiatlas" = EpiAtlas).
color_dict = {
    "C-A": px.colors.qualitative.Dark24[0],
    "epiatlas": px.colors.qualitative.Dark24[1],
}

# 3D scatter of PC1-PC3, one trace per source, over core7 samples that do not
# overlap EpiAtlas.
fig = go.Figure()
for db_label, color in color_dict.items():
    filtered_df = core_assay_df_no_overlap[core_assay_df_no_overlap["source"] == db_label]
    fig.add_trace(
        go.Scatter3d(
            x=filtered_df["PC1"],
            y=filtered_df["PC2"],
            z=filtered_df["PC3"],
            mode="markers",
            marker=dict(size=1, color=color, opacity=0.5),
            hovertemplate="%{text}",
            # Hover label: "<id>: <assay> (<source>)".
            text=[
                f"{sample_id}: {assay} ({source})"
                for sample_id, assay, source in zip(
                    filtered_df["id"],
                    filtered_df["assay_epiclass"],
                    filtered_df["source"],
                )
            ],
            name=f"{db_label} (N={filtered_df.shape[0]})",
            showlegend=True,
        )
    )

# Axis titles include the fraction of variance explained by each component.
axis_titles = [f"PC {i+1} ({explained_variance[i]:.2%})" for i in range(3)]

fig.update_layout(
    title="3D PCA - epiATLAS and ChIP-Atlas - core7 samples",
    scene=dict(
        xaxis_title=axis_titles[0],
        yaxis_title=axis_titles[1],
        zaxis_title=axis_titles[2],
    ),
    # Constant legend symbol size, so the size-1 markers stay visible in the legend.
    legend={"itemsizing": "constant"},
)

fig.write_html(output_dir / "pca_core7_C-A_epiatlas_3D.html")
del fig

In [186]:
# One colour per "<source>_<assay>" label. Colours are assigned over all labels of
# final_pca_df (the same scheme as the 2D per-label figure). The index wraps around
# Dark24's palette instead of raising IndexError when there are more than 24 labels.
color_dict = {
    plot_label: px.colors.qualitative.Dark24[i % len(px.colors.qualitative.Dark24)]
    for i, plot_label in enumerate(final_pca_df[PLOT_LABEL].unique())
}

# 3D scatter of PC1-PC3 over core7 samples that do not overlap EpiAtlas, one trace per label.
fig = go.Figure()
for plot_label, color in color_dict.items():
    filtered_df = core_assay_df_no_overlap[
        core_assay_df_no_overlap[PLOT_LABEL] == plot_label
    ]
    # Non-core7 (or NaN) labels have no rows in this subset; skip them rather than
    # adding empty "(N=0)" entries to the legend.
    if filtered_df.empty:
        continue
    fig.add_trace(
        go.Scatter3d(
            x=filtered_df["PC1"],
            y=filtered_df["PC2"],
            z=filtered_df["PC3"],
            mode="markers",
            marker=dict(size=1, color=color, opacity=0.5),
            hovertemplate="%{text}",
            # Hover label: "<id>: <assay> (<source>)".
            text=[
                f"{sample_id}: {assay} ({source})"
                for sample_id, assay, source in zip(
                    filtered_df["id"],
                    filtered_df["assay_epiclass"],
                    filtered_df["source"],
                )
            ],
            name=f"{plot_label} (N={filtered_df.shape[0]})",
            showlegend=True,
        )
    )

# Axis titles include the fraction of variance explained by each component.
axis_titles = [f"PC {i+1} ({explained_variance[i]:.2%})" for i in range(3)]

fig.update_layout(
    title="3D PCA - epiATLAS and ChIP-Atlas - core7 samples",
    scene=dict(
        xaxis_title=axis_titles[0],
        yaxis_title=axis_titles[1],
        zaxis_title=axis_titles[2],
    ),
    # Constant legend symbol size, so the size-1 markers stay visible in the legend.
    legend={"itemsizing": "constant"},
)

fig.write_html(output_dir / "pca_core7_per_assay_C-A_epiatlas_3D.html")
del fig

In [187]:
# One colour per "<source>_<assay>" label. Colours are assigned over all labels of
# final_pca_df (the same scheme as the 3D per-label figure). The index wraps around
# Dark24's palette instead of raising IndexError when there are more than 24 labels.
color_dict = {
    plot_label: px.colors.qualitative.Dark24[i % len(px.colors.qualitative.Dark24)]
    for i, plot_label in enumerate(final_pca_df[PLOT_LABEL].unique())
}

# 2D scatter (PC1 vs PC2) over core7 samples that do not overlap EpiAtlas, one trace per label.
fig = go.Figure()
for plot_label, color in color_dict.items():
    filtered_df = core_assay_df_no_overlap[
        core_assay_df_no_overlap[PLOT_LABEL] == plot_label
    ]
    # Non-core7 (or NaN) labels have no rows in this subset; skip them rather than
    # adding empty "(N=0)" entries to the legend.
    if filtered_df.empty:
        continue
    fig.add_trace(
        go.Scatter(
            x=filtered_df["PC1"],
            y=filtered_df["PC2"],
            mode="markers",
            marker=dict(size=1, color=color, opacity=0.5),
            hovertemplate="%{text}",
            # Hover label: "<id>: <assay> (<source>)".
            text=[
                f"{sample_id}: {assay} ({source})"
                for sample_id, assay, source in zip(
                    filtered_df["id"],
                    filtered_df["assay_epiclass"],
                    filtered_df["source"],
                )
            ],
            name=f"{plot_label} (N={filtered_df.shape[0]})",
            showlegend=True,
        )
    )

# Axis titles include the fraction of variance explained by each component.
axis_titles = [f"PC {i+1} ({explained_variance[i]:.2%})" for i in range(2)]

fig.update_layout(
    title="2D PCA - epiATLAS and ChIP-Atlas - core7 samples",
    xaxis_title=axis_titles[0],
    yaxis_title=axis_titles[1],
    # Constant legend symbol size, so the size-1 markers stay visible in the legend.
    legend={"itemsizing": "constant"},
)

# Interactive HTML plus static PNG/SVG (static export needs a plotly image engine such as kaleido).
name = "pca_core7_per_assay_C-A_epiatlas_2D"
fig.write_html(output_dir / f"{name}.html")
fig.write_image(output_dir / f"{name}.png")
fig.write_image(output_dir / f"{name}.svg")
del fig

In [188]:
# One colour per data source ("C-A" = ChIP-Atlas, "epiatlas" = EpiAtlas).
color_dict = {
    "C-A": px.colors.qualitative.Dark24[0],
    "epiatlas": px.colors.qualitative.Dark24[1],
}

# 2D scatter (PC1 vs PC2) over core7 samples that do not overlap EpiAtlas, one trace per source.
fig = go.Figure()
for db_label, color in color_dict.items():
    filtered_df = core_assay_df_no_overlap[core_assay_df_no_overlap["source"] == db_label]
    fig.add_trace(
        go.Scatter(
            x=filtered_df["PC1"],
            y=filtered_df["PC2"],
            mode="markers",
            # NOTE: opacity 0.8 here, vs 0.5 in the other PCA figures.
            marker=dict(size=1, color=color, opacity=0.8),
            hovertemplate="%{text}",
            # Hover label: "<id>: <assay> (<source>)".
            text=[
                f"{sample_id}: {assay} ({source})"
                for sample_id, assay, source in zip(
                    filtered_df["id"],
                    filtered_df["assay_epiclass"],
                    filtered_df["source"],
                )
            ],
            name=f"{db_label} (N={filtered_df.shape[0]})",
            showlegend=True,
        )
    )

# Axis titles include the fraction of variance explained by each component.
axis_titles = [f"PC {i+1} ({explained_variance[i]:.2%})" for i in range(2)]

fig.update_layout(
    title="2D PCA - epiATLAS and ChIP-Atlas - core7 samples",
    xaxis_title=axis_titles[0],
    yaxis_title=axis_titles[1],
    # Constant legend symbol size, so the size-1 markers stay visible in the legend.
    legend={"itemsizing": "constant"},
)

# Interactive HTML plus static PNG/SVG (static export needs a plotly image engine such as kaleido).
name = "pca_core7_C-A_epiatlas_2D"
fig.write_html(output_dir / f"{name}.html")
fig.write_image(output_dir / f"{name}.png")
fig.write_image(output_dir / f"{name}.svg")
del fig