In [None]:
"""Plot PCA representation for various datasets."""
# pylint: disable=redefined-outer-name,use-dict-literal,import-error

## SETUP

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
from __future__ import annotations

import copy
from pathlib import Path

import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import skops.io as skio
from IPython.display import display  # pylint: disable=unused-import
from plotly.subplots import make_subplots

from epi_ml.core.hdf5_loader import Hdf5Loader
from epi_ml.utils.notebooks.paper.paper_utilities import ASSAY_ORDER, MetadataHandler

In [4]:
# First seven entries of ASSAY_ORDER (presumably the core assays -- confirm in paper_utilities).
CORE_ASSAYS = ASSAY_ORDER[:7]

In [5]:
# Root of the paper outputs, resolved under the current user's home directory.
base_dir = Path.home().joinpath("Projects/epiclass/output/paper")
paper_dir = base_dir
base_data_dir, base_fig_dir = base_dir / "data", base_dir / "figures"

# Stop right away when the figure folder is missing.
if not base_fig_dir.exists():
    raise FileNotFoundError(f"Directory {base_fig_dir} does not exist.")

In [6]:
# Handler built on the paper directory; used here to load the "v2" metadata.
metadata_handler = MetadataHandler(paper_dir)
metadata_v2 = metadata_handler.load_metadata("v2")

In [7]:
# Chromosome sizes file handed to the HDF5 loader (normalization switched on).
chromsize_path = base_data_dir.joinpath("chromsizes", "hg38.noy.chrom.sizes")
hdf5_loader = Hdf5Loader(normalization=True, chrom_file=chromsize_path)

### Metadata setup

In [8]:
# Tab-separated prediction table stored under predictions/C-A.
ca_pred_path = base_data_dir.joinpath(
    "training_results",
    "predictions",
    "C-A",
    "assay_epiclass",
    "CA_metadata_4DB+all_pred.20240606_mod2.2.tsv",
)
ca_pred_df = pd.read_csv(ca_pred_path, sep="\t", low_memory=False)

In [9]:
# Comma-separated prediction table stored under predictions/encode.
enc_merged_preds_path = base_data_dir.joinpath(
    "training_results",
    "predictions",
    "encode",
    "encode_predictions_augmented_merged.csv",
)
enc_pred_df = pd.read_csv(enc_merged_preds_path, sep=",", low_memory=False)

In [10]:
# Name of the metadata column holding the combined "<source>_<assay_epiclass>" label.
PLOT_LABEL = "plot_label"

In [11]:
# Set of "Experimental-id" values whose "is_EpiAtlas_EpiRR" field differs from the string "0".
in_epiatlas_ca = set(
    ca_pred_df.loc[
        ca_pred_df["is_EpiAtlas_EpiRR"].ne("0"), "Experimental-id"
    ].unique()
)

In [None]:
# Set of "ENC_ID" values on rows selected by the "in_EpiAtlas" column.
in_epiatlas_enc = set(
    enc_pred_df.loc[enc_pred_df["in_EpiAtlas"], "ENC_ID"].unique()
)
# Size of each overlap set (C-A first, then ENCODE).
print(len(in_epiatlas_ca), len(in_epiatlas_enc))

In [None]:
# Single metadata table built from the v2 metadata and both prediction tables.
graph_metadata = MetadataHandler.uniformize_metadata_for_plotting(
    metadata_v2,
    ca_pred_df,
    enc_pred_df,
)

# Per-source sample counts, missing values included.
display(graph_metadata["source"].value_counts(dropna=False))

# The later merges are keyed on "id", so duplicated IDs are rejected here.
if not graph_metadata["id"].is_unique:
    raise ValueError("IDs are not unique.")

In [14]:
# Combined "<source>_<assay_epiclass>" label, stored in the PLOT_LABEL column.
graph_metadata[PLOT_LABEL] = (
    graph_metadata["source"] + "_" + graph_metadata["assay_epiclass"]
)

# graph_metadata[PLOT_LABEL].value_counts(dropna=False)

In [None]:
# IDs flagged as EpiATLAS overlap in either prediction table.
in_epiatlas = in_epiatlas_ca.union(in_epiatlas_enc)

# Boolean indexing already builds a new frame, so deep-copying the whole
# metadata table first is unnecessary. The trailing .copy() keeps the result
# independent of graph_metadata for any later in-place edits.
graph_metadata_no_overlap = graph_metadata[
    ~graph_metadata["id"].isin(in_epiatlas)
].copy()

# Sanity check: the number of rows removed must equal the number of overlap
# IDs, i.e. every overlap ID was found exactly once in graph_metadata.
diff = graph_metadata.shape[0] - graph_metadata_no_overlap.shape[0]
if diff != len(in_epiatlas):
    print(graph_metadata.shape, graph_metadata_no_overlap.shape)
    raise ValueError(
        f"Wrong number of samples: {diff} samples removed, {len(in_epiatlas)} samples in EpiAtlas overlap."
    )

# Per-source sample counts after removing the overlap, missing values included.
display(graph_metadata_no_overlap["source"].value_counts(dropna=False))

### PCA results loading

In [16]:
# n_3projects = 88777
n_recount3 = 71738  # sample count embedded in the .skops file names below
pca_dir = base_data_dir.joinpath("pca", "recount3")

# Both artefacts share the same "n<count>" suffix.
pca_fit = skio.load(pca_dir.joinpath(f"IPCA_fit_n{n_recount3}.skops"))
pca_results = skio.load(pca_dir.joinpath(f"X_IPCA_n{n_recount3}.skops"))

In [17]:
# Projected data stored under "X_ipca"; its length must agree with the count in the file name.
pca_data = pca_results["X_ipca"]
if n_recount3 != len(pca_data):
    raise ValueError("PCA data length does not match filename.")

In [18]:
# Unpack the saved fit: the fitted estimator, the sample file names, and the
# per-component explained variance ratio used later for axis titles.
ipca_fit = pca_fit["ipca_fit"]
pca_filenames = pca_fit["file_names"]
explained_variance = ipca_fit.explained_variance_ratio_

# Explicit check instead of `assert`: assertions are stripped when Python runs
# with -O, and the PCA data length check raises ValueError in the same way.
if len(pca_filenames) != n_recount3:
    raise ValueError("PCA file names length does not match filename.")

In [19]:
# One row per sample; components are named PC1, PC2, ... in column order.
global_pca_df = pd.DataFrame(pca_data)
global_pca_df.columns = [
    f"PC{rank}" for rank in range(1, global_pca_df.shape[1] + 1)
]
# Sample identifiers, taken from the file names saved with the fit.
global_pca_df["id"] = pca_filenames

In [None]:
# Selects how the PCA coordinates are joined with the metadata.
pca_name = "recount3"
# pca_name = "3projects"

if pca_name == "3projects":
    # Inner merges: keep only PCA samples that have a metadata row.
    final_pca_df = global_pca_df.merge(graph_metadata, on="id", how="inner")

    final_pca_df_no_overlap = global_pca_df.merge(
        graph_metadata_no_overlap, on="id", how="inner"
    )

    display(final_pca_df["source"].value_counts(dropna=False))
    display(final_pca_df_no_overlap["source"].value_counts(dropna=False))

elif pca_name == "recount3":
    # Left merge: every PCA sample is kept; samples without a metadata row get
    # a missing "source", which is then labelled "recount3".
    final_pca_df = global_pca_df.merge(graph_metadata, on="id", how="left")
    # Assign the result back rather than calling fillna(inplace=True) on the
    # selected column: that chained in-place form warns on pandas 2.x and does
    # not modify the frame under Copy-on-Write.
    final_pca_df["source"] = final_pca_df["source"].fillna("recount3")
    display(final_pca_df["source"].value_counts(dropna=False))

else:
    raise ValueError(f"Unknown PCA name: {pca_name}")

## Plotting

In [21]:
# Folder below the figure directory where the PCA figures are written.
output_dir = base_fig_dir / "pca"

In [22]:
# non_core_labels = ["non-core", "CTCF", "wgbs-standard", "wgbs-pbat", "rna_seq", "mrna_seq"]
# core_assay_df = final_pca_df[~final_pca_df["assay_epiclass"].isin(non_core_labels)]
# core_assay_df_no_overlap = final_pca_df_no_overlap[~final_pca_df_no_overlap["assay_epiclass"].isin(non_core_labels)]

# core_assay_df = final_pca_df[final_pca_df["assay_epiclass"].isin(CORE_ASSAYS)]
# core_assay_df_no_overlap = final_pca_df_no_overlap[
#     final_pca_df_no_overlap["assay_epiclass"].isin(CORE_ASSAYS)
# ]

In [23]:
# color_dict = {
#     "C-A": px.colors.qualitative.Dark24[0],
#     "epiatlas": px.colors.qualitative.Dark24[1],
#     "encode": px.colors.qualitative.Dark24[2],
# }

# fig = go.Figure()
# for db_label, color in color_dict.items():
#     filtered_df = core_assay_df_no_overlap[core_assay_df_no_overlap["source"] == db_label]
#     fig.add_trace(
#         go.Scatter3d(
#             x=filtered_df["PC1"],
#             y=filtered_df["PC2"],
#             z=filtered_df["PC3"],
#             mode="markers",
#             marker=dict(
#                 size=1,
#                 color=color,
#                 opacity=0.5,
#             ),
#             hovertemplate="%{text}",
#             text=[
#                 f"{id_label}: {assay} ({db_label})"
#                 for id_label, assay, db_label in zip(
#                     filtered_df["id"],
#                     filtered_df["assay_epiclass"],
#                     filtered_df["source"],
#                 )
#             ],
#             name=f"{db_label} (N={filtered_df.shape[0]})",
#             showlegend=True,
#         )
#     )

# axis_titles = [f"PC {i+1} ({explained_variance[i]:.2%})" for i in range(3)]

# fig.update_layout(
#     title="3D PCA - epiATLAS and ChiP-Atlas - all samples",
#     scene=dict(
#         xaxis_title=axis_titles[0],
#         yaxis_title=axis_titles[1],
#         zaxis_title=axis_titles[2],
#     ),
#     legend={"itemsizing": "constant"},
# )

# fig.write_html(output_dir / "pca_3projects_no_epiatlas_overlap_3D.html")
# del fig

In [24]:
# color_dict = {
#     plot_label: px.colors.qualitative.Dark24[i]
#     for i, plot_label in enumerate(final_pca_df[PLOT_LABEL].unique())
# }

# fig = go.Figure()
# for plot_label, color in color_dict.items():
#     filtered_df = core_assay_df_no_overlap[
#         core_assay_df_no_overlap[PLOT_LABEL] == plot_label
#     ]
#     fig.add_trace(
#         go.Scatter(
#             x=filtered_df["PC1"],
#             y=filtered_df["PC2"],
#             mode="markers",
#             marker=dict(
#                 size=1,
#                 color=color,
#                 opacity=0.5,
#             ),
#             hovertemplate="%{text}",
#             text=[
#                 f"{id_label}: {assay} ({db_label})"
#                 for id_label, assay, db_label in zip(
#                     filtered_df["id"],
#                     filtered_df["assay_epiclass"],
#                     filtered_df["source"],
#                 )
#             ],
#             name=f"{plot_label} (N={filtered_df.shape[0]})",
#             showlegend=True,
#         )
#     )

# axis_titles = [f"PC {i+1} ({explained_variance[i]:.2%})" for i in range(2)]

# fig.update_layout(
#     title="2D PCA - epiATLAS and ChiP-Atlas - core7 samples",
#     xaxis_title=axis_titles[0],
#     yaxis_title=axis_titles[1],
#     legend={"itemsizing": "constant"},
# )

# name = "pca_core7_per_assay_C-A_epiatlas_2D"
# fig.write_html(output_dir / f"{name}.html")
# fig.write_image(output_dir / f"{name}.png")
# fig.write_image(output_dir / f"{name}.svg")
# del fig

In [25]:
# color_dict = {
#     # "C-A": px.colors.qualitative.Dark24[0],
#     "epiatlas": px.colors.qualitative.Dark24[1],
#     "encode": px.colors.qualitative.Dark24[2],
# }


# fig = go.Figure()
# for db_label, color in color_dict.items():
#     filtered_df = core_assay_df_no_overlap[core_assay_df_no_overlap["source"] == db_label]
#     fig.add_trace(
#         go.Scatter(
#             x=filtered_df["PC1"],
#             y=filtered_df["PC2"],
#             mode="markers",
#             marker=dict(
#                 size=1.5,
#                 color=color,
#                 opacity=0.8,
#             ),
#             hovertemplate="%{text}",
#             text=[
#                 f"{id_label}: {assay} ({db_label})"
#                 for id_label, assay, db_label in zip(
#                     filtered_df["id"],
#                     filtered_df["assay_epiclass"],
#                     filtered_df["source"],
#                 )
#             ],
#             name=f"{db_label} (N={filtered_df.shape[0]})",
#             showlegend=True,
#         )
#     )

# axis_titles = [f"<b>PC {i+1}</b> ({explained_variance[i]:.2%})" for i in range(2)]

# # title = "PCA - epiATLAS & ENCODE & ChIP-Atlas - core7 samples"
# title = "PCA - epiATLAS & ENCODE - core7 samples"
# fig.update_layout(
#     title=title,
#     xaxis_title=axis_titles[0],
#     yaxis_title=axis_titles[1],
#     legend={"itemsizing": "constant"},
# )

# # name = "pca_3projects_no_epiatlas_overlap_2D"
# name = "pca_epiatlas_ENC_no_epiatlas_overlap_2D"
# fig.write_html(output_dir / f"{name}.html")
# fig.write_image(output_dir / f"{name}.png", scale=1.5)
# fig.write_image(output_dir / f"{name}.svg")
# fig.show()
# del fig

In [26]:
# color_dict = {
#     "C-A": px.colors.qualitative.Dark24[0],
#     "epiatlas": px.colors.qualitative.Dark24[1],
#     "encode": px.colors.qualitative.Dark24[2],
# }


# fig = go.Figure()
# for db_label, color in color_dict.items():
#     filtered_df = final_pca_df_no_overlap[final_pca_df_no_overlap["source"] == db_label]
#     fig.add_trace(
#         go.Scatter(
#             x=filtered_df["PC1"],
#             y=filtered_df["PC2"],
#             mode="markers",
#             marker=dict(
#                 size=1.5,
#                 color=color,
#                 opacity=0.8,
#             ),
#             hovertemplate="%{text}",
#             text=[
#                 f"{id_label}: {assay} ({db_label})"
#                 for id_label, assay, db_label in zip(
#                     filtered_df["id"],
#                     filtered_df["assay_epiclass"],
#                     filtered_df["source"],
#                 )
#             ],
#             name=f"{db_label} (N={filtered_df.shape[0]})",
#             showlegend=True,
#         )
#     )

# axis_titles = [f"<b>PC {i+1}</b> ({explained_variance[i]:.2%})" for i in range(2)]

# title = "PCA - EpiATLAS & ENCODE & ChIP-Atlas - all samples (no overlap)"
# fig.update_layout(
#     title=title,
#     xaxis_title=axis_titles[0],
#     yaxis_title=axis_titles[1],
#     legend={"itemsizing": "constant"},
# )

# name = "pca_3projects_no-epiatlas-overlap_all-samples_2D"
# fig.write_html(output_dir / f"{name}.html")
# fig.write_image(output_dir / f"{name}.png", scale=1.5)
# fig.write_image(output_dir / f"{name}.svg")
# fig.show()
# del fig

In [27]:
# # Create a new color dictionary
# color_dict = {
#     "C-A": px.colors.qualitative.Dark24[0],
#     "epiatlas": px.colors.qualitative.Dark24[1],
#     "encode_core": px.colors.qualitative.Dark24[2],
#     "encode_non_core": px.colors.qualitative.Dark24[5],
# }

# fig = go.Figure()

# for db_label, color in color_dict.items():
#     if db_label == "encode_core":
#         filtered_df = final_pca_df_no_overlap[
#             (final_pca_df_no_overlap["source"] == "encode")
#             & (final_pca_df_no_overlap["assay_epiclass"].isin(CORE_ASSAYS))
#         ]
#         display_label = "ENCODE (core)"
#     elif db_label == "encode_non_core":
#         filtered_df = final_pca_df_no_overlap[
#             (final_pca_df_no_overlap["source"] == "encode")
#             & (~final_pca_df_no_overlap["assay_epiclass"].isin(CORE_ASSAYS))
#         ]
#         display_label = "ENCODE (non-core)"
#     else:
#         filtered_df = final_pca_df_no_overlap[
#             final_pca_df_no_overlap["source"] == db_label
#         ]
#         display_label = db_label

#     fig.add_trace(
#         go.Scatter(
#             x=filtered_df["PC1"],
#             y=filtered_df["PC2"],
#             mode="markers",
#             marker=dict(
#                 size=1.5,
#                 color=color,
#                 opacity=0.8,
#             ),
#             hovertemplate="%{text}",
#             text=[
#                 f"{id_label}: {assay} ({display_label})"
#                 for id_label, assay in zip(
#                     filtered_df["id"],
#                     filtered_df["assay_epiclass"],
#                 )
#             ],
#             name=f"{display_label} (N={filtered_df.shape[0]})",
#             showlegend=True,
#         )
#     )

# axis_titles = [f"<b>PC {i+1}</b> ({explained_variance[i]:.2%})" for i in range(2)]

# title = "PCA - EpiATLAS & ENCODE & ChIP-Atlas - all samples (no overlap)"
# fig.update_layout(
#     title=title,
#     xaxis_title=axis_titles[0],
#     yaxis_title=axis_titles[1],
#     legend={"itemsizing": "constant"},
# )

# name = "pca_4colors_no-epiatlas-overlap_all-samples_2D"
# fig.write_html(output_dir / f"{name}.html")
# fig.write_image(output_dir / f"{name}.png", scale=1.5)
# fig.write_image(output_dir / f"{name}.svg")
# fig.show()

In [28]:
# fig = px.density_contour(
#     core_assay_df_no_overlap,
#     x="PC1",
#     y="PC2",
#     color="source",
#     height=800,
#     width=800,
#     )

# fig.update_traces(line=dict(width=1))

# fig.update_layout(
#     title="2D PCA - epiATLAS+ChiP-Atlas+ENC - core7 samples",
#     xaxis_title=axis_titles[0],
#     yaxis_title=axis_titles[1],
#     legend={"itemsizing": "constant"},
#     )
# fig.show()


# fig = px.density_contour(
#     core_assay_df_no_overlap,
#     x="PC1",
#     y="PC2",
#     marginal_x="histogram",
#     marginal_y="histogram",
#     color="source",
#     height=800,
#     width=800,
#     )

# fig.show()

### recount3

In [None]:
# Trace colours keyed by group identifier, picked from plotly's Dark24 palette.
color_dict = {
    group_key: px.colors.qualitative.Dark24[palette_idx]
    for group_key, palette_idx in (
        ("EpiATLAS_ChIP", 0),
        ("EpiATLAS_RNA", 6),
        ("EpiATLAS_WGB", 2),
        ("recount3_RNA", 3),
    )
}

# Axis titles for PC1 and PC2, each with its explained variance ratio as a percentage.
axis_titles = [
    f"<b>PC{i+1}</b> ({explained_variance[i]:.2%})" for i in (0, 1)
]

# Three stacked panels sharing the x axis: two PC1 histograms on top and the
# PCA scatter at the bottom (the scatter takes 60% of the height).
fig = make_subplots(
    cols=1,
    rows=3,
    row_heights=[0.2, 0.2, 0.6],
    vertical_spacing=0.1,
    shared_xaxes=True,
    x_title=axis_titles[0],
    subplot_titles=(
        "EpiATLAS RNA PC1 distribution (%)",
        "recount3 RNA PC1 distribution (%)",
        "PCA",
    ),
)

# Selection rule for every trace: legend label, the "source" value to keep, and
# the "assay_epiclass" values to keep (None keeps every sample of that source).
group_specs = {
    "EpiATLAS_ChIP": ("EpiATLAS ChIP", "epiatlas", CORE_ASSAYS),
    "EpiATLAS_RNA": ("EpiATLAS RNA", "epiatlas", ["mrna_seq", "rna_seq"]),
    "EpiATLAS_WGB": ("EpiATLAS WGB", "epiatlas", ["wgbs-standard", "wgbs-pbat"]),
    "recount3_RNA": ("recount3 RNA", "recount3", None),
}
# Groups that additionally get a PC1 histogram, mapped to their subplot row.
histogram_rows = {"EpiATLAS_RNA": 1, "recount3_RNA": 2}

for fig_label, color in color_dict.items():
    # Fail loudly on a colour key with no selection rule, instead of silently
    # reusing the previous iteration's data and label.
    if fig_label not in group_specs:
        raise ValueError(f"Unknown figure label: {fig_label}")
    display_label, source_name, assays = group_specs[fig_label]

    # filter/create correct groups
    group_mask = final_pca_df["source"] == source_name
    if assays is not None:
        group_mask = group_mask & final_pca_df["assay_epiclass"].isin(assays)
    filtered_df = final_pca_df[group_mask]
    trace_name = f"{display_label} (N={filtered_df.shape[0]})"

    # Scatter of PC1 vs PC2 in the bottom panel; hover text shows id, assay and group.
    fig.add_trace(
        go.Scatter(
            x=filtered_df["PC1"],
            y=filtered_df["PC2"],
            mode="markers",
            marker=dict(
                size=1.5,
                color=color,
                opacity=0.8,
            ),
            hovertemplate="%{text}",
            text=[
                f"{id_label}: {assay} ({display_label})"
                for id_label, assay in zip(
                    filtered_df["id"],
                    filtered_df["assay_epiclass"],
                )
            ],
            name=trace_name,
            showlegend=True,
        ),
        row=3,
        col=1,
    )

    # PC1 distribution (percent-normalised) for the groups that have a histogram
    # panel; hidden from the legend because the scatter trace already lists the group.
    hist_row = histogram_rows.get(fig_label)
    if hist_row is not None:
        fig.add_trace(
            go.Histogram(
                x=filtered_df["PC1"],
                histnorm="percent",
                name=trace_name,
                showlegend=False,
                marker=dict(color=color),
            ),
            row=hist_row,
            col=1,
        )

title = "PCA - EpiATLAS & recount3"
fig.update_layout(
    title=title,
    # Constant legend marker size, independent of the 1.5px trace markers.
    legend={"itemsizing": "constant"},
)

# add y-axis title to last row
fig.update_yaxes(title_text=axis_titles[1], row=3, col=1)

# set histograms to same yrange
for hist_row in (1, 2):
    fig.update_yaxes(range=[-0.001, 10], row=hist_row, col=1, nticks=4)

name = "pca_epiatlas_recount3_2D"
output_dir = base_fig_dir / "pca"
# Only base_fig_dir is checked for existence earlier; create the "pca"
# sub-folder if needed so the writes below do not fail on a fresh checkout.
output_dir.mkdir(parents=True, exist_ok=True)
fig.write_html(output_dir / f"{name}.html")
fig.write_image(output_dir / f"{name}.png", scale=1.5)
fig.write_image(output_dir / f"{name}.svg")
fig.show()