In [None]:
"""Plot PCA representation for various datasets."""

# pylint: disable=redefined-outer-name, use-dict-literal, import-error, too-many-lines, duplicate-code

## SETUP

In [None]:
# Automatically reload imported project modules when their source changes (IPython autoreload).
%load_ext autoreload
%autoreload 2

In [None]:
from __future__ import annotations

import os
import re
from pathlib import Path
from typing import List

import colorcet as cc
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import skops.io as skio
from IPython.display import display  # pylint: disable=unused-import
from plotly.subplots import make_subplots

from epiclass.utils.notebooks.paper.paper_utilities import (
    ASSAY,
    ASSAY_ORDER,
    CELL_TYPE,
    MetadataHandler,
    set_file_id,
)

In [None]:
def rgb_to_plotly(rgb_list: List[List[float]]) -> List[str]:
    """Transform [r, g, b] triplets into plotly color strings 'rgb(r,g,b)'.

    Plotly interprets 'rgb(...)' components on a 0-255 scale. Palettes given as
    floats in [0, 1] (such as the raw colorcet float palettes) are therefore
    scaled to 0-255 and rounded; any other palette is formatted unchanged.

    Args:
        rgb_list: Sequence of [r, g, b] triplets.

    Returns:
        One 'rgb(r,g,b)' string per input triplet, in the same order.
    """
    components = [c for rgb_seq in rgb_list for c in rgb_seq]
    # Decide once for the whole palette: normalized only if it contains floats
    # and every component lies in [0, 1]. Integer palettes are left untouched.
    is_normalized = any(isinstance(c, float) for c in components) and all(
        0 <= c <= 1 for c in components
    )

    new_vals = []
    for rgb_seq in rgb_list:
        r, g, b = rgb_seq
        if is_normalized:
            r, g, b = (int(round(c * 255)) for c in (r, g, b))
        new_vals.append(f"rgb({r},{g},{b})")
    return new_vals

In [None]:
# First 7 entries of ASSAY_ORDER; presumably the six core histone marks plus input
# (cf. core7_labels below) — TODO confirm against ASSAY_ORDER.
CORE_ASSAYS = ASSAY_ORDER[0:7]

In [None]:
# ASSAY_ORDER extended with mRNA/WGBS label variants and the generic "core7" label.
ALL_CORE_ASSAYS = ASSAY_ORDER + ["mrna_seq", "wgbs-standard", "wgbs-pbat", "core7"]

In [None]:
# Root of all paper inputs/outputs. Set the EPICLASS_PAPER_DIR environment variable
# to use another location; otherwise the original default path is used.
base_dir = Path(
    os.environ.get("EPICLASS_PAPER_DIR", Path.home() / "Projects/epiclass/output/paper")
).expanduser()
base_data_dir = base_dir / "data"

metadata_dir = base_data_dir / "metadata"
pca_data_dir = base_data_dir / "pca"

base_fig_dir = base_dir / "figures"
general_logdir = base_fig_dir / "pca"

paper_dir = base_dir  # passed to MetadataHandler below

if not base_fig_dir.exists():
    raise FileNotFoundError(f"Directory {base_fig_dir} does not exist.")

In [None]:
# Project helper wrapping the paper metadata; "v2" presumably selects the data
# freeze used in the "dfreeze-v2" file lists below — TODO confirm.
metadata_handler = MetadataHandler(paper_dir)
metadata_v2 = metadata_handler.load_metadata("v2")

### Metadata setup

Also create the file/ID lists used for the PCA computation.

In [None]:
# Output folder for the file/ID lists written below. It is created if missing so
# the to_csv calls do not fail with a missing-directory error on a fresh setup.
file_list_dir = pca_data_dir / "file_lists"
file_list_dir.mkdir(parents=True, exist_ok=True)

In [None]:
metadata_v2_df = metadata_v2.to_df()

# Shared layout of every list: one md5sum per line, no header, no index.
md5_csv_kwargs = dict(sep=",", index=False, header=False)

# All EpiATLAS files.
metadata_v2_df["md5sum"].to_csv(
    file_list_dir / "epiatlas_files_dfreeze-v2.md5", **md5_csv_kwargs
)

# Core ChIP assays only.
is_core_chip = metadata_v2_df[ASSAY].isin(CORE_ASSAYS)
metadata_v2_df.loc[is_core_chip, "md5sum"].to_csv(
    file_list_dir / "epiatlas_chip_dfreeze-v2.md5", **md5_csv_kwargs
)

# Any assay whose name contains "rna".
is_rna = metadata_v2_df[ASSAY].str.contains("rna")
metadata_v2_df.loc[is_rna, "md5sum"].to_csv(
    file_list_dir / "epiatlas_rna_dfreeze-v2.md5", **md5_csv_kwargs
)

Load ChIP-Atlas predictions; exclude files that overlap with EpiATLAS and files that are potentially non-core assays.

In [None]:
# ChIP-Atlas (C-A) prediction table, stored under a folder named after ASSAY.
ca_pred_path = base_data_dir.joinpath(
    "training_results",
    "predictions",
    "C-A",
    ASSAY,
    "CA_metadata_4DB+all_pred.20240606_mod3.0.tsv",
)
ca_pred_df = pd.read_csv(ca_pred_path, sep="\t", low_memory=False)
print(ca_pred_df.shape)

In [None]:
# A file overlaps EpiATLAS when its EpiRR flag is anything other than "0".
ca_pred_df["in_epiatlas"] = ca_pred_df["is_EpiAtlas_EpiRR"].astype(str).ne("0")

overlap_ids = ca_pred_df.loc[ca_pred_df["in_epiatlas"], "Experimental-id"].unique()
in_epiatlas_ca = set(overlap_ids)
print(f"C-A overlap w EpiATLAS: {len(in_epiatlas_ca)}")

ca_pred_df = ca_pred_df.loc[~ca_pred_df["in_epiatlas"]]
print(f"After removing overlap w EpiATLAS, {len(ca_pred_df)} rows remain.")

In [None]:
# Drop files whose database consensus flags them as ignored or potentially
# non-core (incl. CTCF). na=False treats a missing consensus as "not flagged";
# without it the mask would contain NaN and `~mask` would raise.
N = len(ca_pred_df)
is_non_core = ca_pred_df["core7_DBs_consensus"].str.contains(
    pat="Ignored|non-core", regex=True, na=False
)
ca_pred_df = ca_pred_df[~is_non_core]
print(f"Removed {N - len(ca_pred_df)} rows that are potentially non-core/CTCF.")
print(f"After this, {len(ca_pred_df)} rows remain.")

In [None]:
# Experiment IDs of the ChIP-Atlas files kept after the two filters above, one per line.
ca_pred_df["Experimental-id"].to_csv(
    file_list_dir / "ChIP-Atlas_confirmed_chip_ids_no_epiatlas.list",
    sep=",",
    index=False,
    header=False,
)

In [None]:
# no_consensus files still only have core7 assays potential values (except maybe an input file from a non-core assay)
ca_pred_df["manual_target_consensus"].replace("no_consensus", "core7", inplace=True)
print(ca_pred_df["manual_target_consensus"].value_counts(dropna=False))

Load ENCODE predictions; exclude files that overlap with EpiATLAS (EpiRR overlap).

In [None]:
# ENCODE predictions merged with metadata, stored as a gzipped CSV.
enc_merged_preds_path = base_data_dir.joinpath(
    "training_results",
    "predictions",
    "encode",
    "complete_encode_predictions_augmented_2025-02_metadata.csv.gz",
)
enc_pred_df = pd.read_csv(
    enc_merged_preds_path, compression="gzip", sep=",", low_memory=False
)
print(enc_pred_df.shape)

In [None]:
# enc_pred_df[ASSAY].value_counts(dropna=False)

In [None]:
# Row mask of ENCODE files flagged as also present in EpiATLAS
# (assumes "in_epiatlas" is a boolean column — TODO confirm dtype).
mask_in_epiatlas = enc_pred_df["in_epiatlas"]

# Overlapping files, kept aside for inspection.
in_epiatlas_df = enc_pred_df[mask_in_epiatlas]

In [None]:
# print(in_epiatlas_df[ASSAY].value_counts(dropna=False))

In [None]:
print(enc_pred_df["in_epiatlas"].value_counts(dropna=False))

# Keep only the ENCODE files not in EpiATLAS; the unfiltered frame is not kept.
enc_pred_df = enc_pred_df[~mask_in_epiatlas]

In [None]:
# print(enc_pred_df[ASSAY].value_counts(dropna=False))

In [None]:
# Shared layout of every list: one accession per line, no header, no index.
id_list_csv_kwargs = dict(sep=",", index=False, header=False)

# All ENCODE files left after removing the EpiATLAS overlap.
enc_pred_df["FILE_accession"].to_csv(
    file_list_dir / "ENCODE_no_epiatlas_ids.list", **id_list_csv_kwargs
)

# RNA subset of those files (any assay name containing "rna").
enc_rna_overlap = enc_pred_df.loc[enc_pred_df[ASSAY].str.contains("rna")]
enc_rna_overlap["FILE_accession"].to_csv(
    file_list_dir / "ENCODE_rna_no_epiatlas_overlap_ids.list", **id_list_csv_kwargs
)

In [None]:
# ENCODE files (EpiATLAS overlap already removed) restricted to the core ChIP assays.
enc_chip_df = enc_pred_df[enc_pred_df[ASSAY].isin(CORE_ASSAYS)]

enc_chip_df["FILE_accession"].to_csv(
    file_list_dir / "ENCODE_chip_core_ids_no_epiatlas.list",
    sep=",",
    index=False,
    header=False,
)

In [None]:
recount3_metadata = pd.read_csv(
    metadata_dir / "recount3" / "recount_harmonized_metadata_20250122_leuk2.tsv",
    sep="\t",
    low_memory=False,
)

# Keep only ID + assay and fill missing values on a new frame: calling
# fillna(inplace=True) on a column subset triggers SettingWithCopyWarning.
recount3_metadata: pd.DataFrame = recount3_metadata[["ID", "harmonized_assay"]].fillna("unknown")  # type: ignore
print(recount3_metadata.shape)

In [None]:
# Combine EpiATLAS, ChIP-Atlas, ENCODE and recount3 metadata into one plotting table.
graph_metadata = MetadataHandler.uniformize_metadata_for_plotting(
    metadata_v2, ca_pred_df, enc_pred_df, recount3_metadata
)
display(graph_metadata["source"].value_counts(dropna=False))

# "id" is the merge key with the PCA coordinates, so it must be unique.
if not graph_metadata["id"].is_unique:
    raise ValueError("IDs are not unique.")

In [None]:
PLOT_LABEL = "plot_label"

# Default grouping label for plots: "<source>_<lower-cased assay_epiclass>".
graph_metadata[PLOT_LABEL] = (
    graph_metadata["source"] + "_" + graph_metadata["assay_epiclass"].str.lower()
)

In [None]:
# Lower-case assay labels so they match the lower-case labels used in the mappings below.
graph_metadata[ASSAY] = graph_metadata[ASSAY].str.lower()
print(graph_metadata[ASSAY].value_counts(dropna=False))

### PCA results loading

In [None]:
# Maps each PCA name to its number of files N, so a different PCA can be selected
# easily: result files are expected to be named 'X_IPCA_n[N].skops' (and 'IPCA_fit_n[N].skops').
pca_name_to_N = {
    "chip_3projects": 65960,
    # "recount3_epiatlas": 71738,
    "rna_enc_epi_recount3": 54531,
    "epiatlas_encode": 29699,
}

In [None]:
# PCA result set to plot; must be a key of pca_name_to_N.
pca_name = "rna_enc_epi_recount3"
pca_dir = pca_data_dir / pca_name

expected_n_files = pca_name_to_N[pca_name]

In [None]:
# Fitted incremental PCA and the projected coordinates, both saved with skops.
# NOTE(review): depending on the skops version, skio.load may require a `trusted=`
# argument for non-default types — confirm.
pca_fit = skio.load(pca_dir / f"IPCA_fit_n{expected_n_files}.skops")
pca_results = skio.load(pca_dir / f"X_IPCA_n{expected_n_files}.skops")

In [None]:
pca_data = pca_results["X_ipca"]
# Explicit checks rather than `assert` (stripped under `python -O`): the expected
# number of files is encoded in the result filenames.
if len(pca_data) != expected_n_files:
    raise ValueError("PCA data length does not match filename.")

ipca_fit = pca_fit["ipca_fit"]
pca_filenames = pca_fit["file_names"]
explained_variance = ipca_fit.explained_variance_ratio_
if len(pca_filenames) != expected_n_files:
    raise ValueError("Number of PCA file names does not match filename.")

In [None]:
# One row per file: PC1..PCn coordinates plus the file ID.
global_pca_df = (
    pd.DataFrame(pca_data)
    .rename(columns=lambda col: f"PC{col + 1}")
    .assign(id=pca_filenames)
)
global_pca_df = set_file_id(global_pca_df, "id", "id")

In [None]:
# How PCA coordinates are joined with the plotting metadata depends on the PCA:
# - inner join: keep only files that have metadata;
# - left join: keep every PCA file; files without metadata are labeled recount3.
inner_join_pcas = {"3projects", "chip_3projects", "epiatlas_encode"}
left_join_pcas = {"recount3", "rna_enc_epi_recount3"}

if pca_name in inner_join_pcas:
    final_pca_df = global_pca_df.merge(graph_metadata, on="id", how="inner")
elif pca_name in left_join_pcas:
    final_pca_df = global_pca_df.merge(graph_metadata, on="id", how="left")
    # Reassign instead of a chained fillna(inplace=True), which does not modify
    # the frame under pandas Copy-on-Write.
    final_pca_df["source"] = final_pca_df["source"].fillna("recount3")
else:
    raise ValueError(f"Unknown PCA name: {pca_name}")

display(final_pca_df["source"].value_counts(dropna=False))

## Plotting

### Initial setup (functions)

In [None]:
# Lower-cased assay labels outside the core7 ChIP set.
non_core_labels = [
    "non-core",
    "ctcf",
    "wgbs-standard",
    "wgbs-pbat",
    "rna_seq",
    "mrna_seq",
    "wgbs",
    "other",
]
# Core7 ChIP labels: six histone marks, input, and the generic "core7" label
# (given above to ChIP-Atlas files without a target consensus).
core7_labels = [
    "h3k27ac",
    "h3k4me1",
    "h3k4me3",
    "h3k36me3",
    "h3k27me3",
    "h3k9me3",
    "input",
    "core7",
]

# print(final_pca_df["assay_epiclass"].value_counts(dropna=False))

In [None]:
# Assay label -> coarse assay category used to color the "by assay type" plot.
_assay_type_groups = {
    "ChIP, core7": core7_labels,
    "ChIP, non-core": ["non-core", "ctcf"],
    "ChIP, unknown": ["no_consensus"],
    "WGBS": ["wgbs-pbat", "wgbs-standard", "wgbs"],
    "RNA_seq": ["rna_seq", "mrna_seq"],
}
map_assay_type = {
    assay: assay_type
    for assay_type, assays in _assay_type_groups.items()
    for assay in assays
}

In [None]:
# "core9" for every core assay label, "non-core" for non-core ChIP targets
# (the non-core entries win if a label appears in both).
map_general_assay_type = {
    **dict.fromkeys(ALL_CORE_ASSAYS, "core9"),
    **dict.fromkeys(["non-core", "ctcf"], "non-core"),
}

In [None]:
# Core7 ChIP files only; used by the commented-out density-contour plots below.
core_assay_df = final_pca_df[final_pca_df[ASSAY].isin(core7_labels)]

In [None]:
def convert_to_filename(name: str) -> str:
    """Convert a figure name into a filename-safe string.

    "&" becomes "and", dashes and spaces become underscores, parentheses are
    dropped, and runs of underscores are collapsed into a single one.

    Args:
        name: Human-readable figure name, e.g. "EpiATLAS & ENCODE - All assays".

    Returns:
        The filename-safe string, e.g. "EpiATLAS_and_ENCODE_All_assays".
    """
    name = (
        name.replace("&", "and")
        .replace("-", "_")
        .replace(" ", "_")
        .replace("(", "")
        .replace(")", "")
    )
    # A single str.replace("__", "_") pass leaves "__" behind for runs of three
    # (e.g. " - " -> "___"), so collapse all underscore runs at once.
    return re.sub(r"_+", "_", name)

In [None]:
def adjust_axes(fig: go.Figure, scale: float, full_df: pd.DataFrame) -> None:
    """Resize `fig` so one PC unit spans about as many pixels on x as on y.

    Args:
        fig (go.Figure): Figure whose layout height/width are set.
        scale (float): Pixels per PC unit (general figure size factor).
        full_df (pd.DataFrame): Data providing the "PC1" (x) and "PC2" (y) ranges.
    """
    pc2_span = full_df["PC2"].max() - full_df["PC2"].min()
    pc1_span = full_df["PC1"].max() - full_df["PC1"].min()

    # The legend takes horizontal room; shrinking the height by ~15% keeps the
    # plotting area close to 1:1 without widening the figure.
    fig.update_layout(height=pc2_span * scale * 0.85, width=pc1_span * scale)

In [None]:
def create_graph_from_groups(
    df: pd.DataFrame, name: str, logdir: Path | None = None, scale: float = 1
) -> None:
    """Scatter PC1 vs PC2 with one trace per value of the `PLOT_LABEL` column.

    Args:
        df: PCA coordinates merged with metadata. Must contain the `PLOT_LABEL`
            column, plus "PC1", "PC2", "id", "assay_epiclass" and "source".
        name: Used in the title ("PCA - {name}") and, made filename-safe, as the
            base name of the saved files.
        logdir: If given, the figure is saved there as HTML, PNG and SVG.
        scale: Figure size scaling factor, forwarded to `adjust_axes`.

    Raises:
        ValueError: If the `PLOT_LABEL` column is missing from `df`.
    """
    df = df.copy(deep=True)
    if PLOT_LABEL not in df.columns:
        raise ValueError(f"Column {PLOT_LABEL} not found in the DataFrame.")

    palette = px.colors.qualitative.Dark24
    # Cycle through the palette so more groups than colors do not raise IndexError.
    color_dict = {
        group: palette[i % len(palette)]
        for i, group in enumerate(sorted(df[PLOT_LABEL].unique()))
    }

    fig = go.Figure()
    for label, color in color_dict.items():
        filtered_df = df[df[PLOT_LABEL] == label]
        print(label, filtered_df.shape)
        fig.add_trace(
            go.Scatter(
                x=filtered_df["PC1"],
                y=filtered_df["PC2"],
                mode="markers",
                marker=dict(
                    size=1.5,
                    color=color,
                    opacity=0.8,
                ),
                hovertemplate="%{text}",
                # Hover shows each file's source (not its group label).
                text=[
                    f"{id_label}: {assay} ({source})"
                    for id_label, assay, source in zip(
                        filtered_df["id"],
                        filtered_df["assay_epiclass"],
                        filtered_df["source"],
                    )
                ],
                name=f"{label} (N={filtered_df.shape[0]})",
                showlegend=True,
            )
        )

    # explained_variance: module-level explained variance ratios of the loaded PCA.
    axis_titles = [f"<b>PC {i+1}</b> ({explained_variance[i]:.2%})" for i in range(2)]

    fig.update_layout(
        title=f"PCA - {name}",
        xaxis_title=axis_titles[0],
        yaxis_title=axis_titles[1],
        legend={"itemsizing": "constant"},
    )

    adjust_axes(fig, scale, df)

    if logdir:
        figname = convert_to_filename(name) + "_2D"
        fig.write_html(logdir / f"{figname}.html")
        fig.write_image(logdir / f"{figname}.png")
        fig.write_image(logdir / f"{figname}.svg")
    fig.show()
    del fig

In [None]:
def graph_all_samples_2D_per_source(
    df: pd.DataFrame, name: str, logdir: Path | None = None, scale: float = 1
) -> None:
    """Scatter PC1 vs PC2 with one trace per data source (the "source" column).

    The title suffix depends on the module-level `pca_name`
    (ChIP core / RNA / all samples).

    Args:
        df: PCA coordinates merged with metadata ("PC1", "PC2", "id",
            "assay_epiclass", "source").
        name: Used in the title ("PCA - {name} - ...") and, made filename-safe,
            as the base name of the saved files.
        logdir: If given, the figure is saved there as HTML, PNG and SVG.
        scale: Figure size scaling factor, forwarded to `adjust_axes`.
    """
    df = df.copy(deep=True)

    palette = px.colors.qualitative.Dark24
    # Cycle through the palette so more sources than colors do not raise IndexError.
    color_dict = {
        db_label: palette[i % len(palette)]
        for i, db_label in enumerate(sorted(df["source"].unique()))
    }

    fig = go.Figure()
    for db_label, color in color_dict.items():
        filtered_df = df[df["source"] == db_label]
        print(db_label, filtered_df.shape)
        fig.add_trace(
            go.Scatter(
                x=filtered_df["PC1"],
                y=filtered_df["PC2"],
                mode="markers",
                marker=dict(
                    size=1.5,
                    color=color,
                    opacity=0.8,
                ),
                hovertemplate="%{text}",
                # Every row of filtered_df has source == db_label.
                text=[
                    f"{id_label}: {assay} ({db_label})"
                    for id_label, assay in zip(
                        filtered_df["id"],
                        filtered_df["assay_epiclass"],
                    )
                ],
                name=f"{db_label} (N={filtered_df.shape[0]})",
                showlegend=True,
            )
        )

    # explained_variance: module-level explained variance ratios of the loaded PCA.
    axis_titles = [f"<b>PC {i+1}</b> ({explained_variance[i]:.2%})" for i in range(2)]

    title_suffixes = {
        "chip_3projects": " - ChIP core samples",
        "rna_enc_epi_recount3": " - RNA samples",
    }
    title = f"PCA - {name}" + title_suffixes.get(pca_name, " - all samples")

    fig.update_layout(
        title=title,
        xaxis_title=axis_titles[0],
        yaxis_title=axis_titles[1],
        legend={"itemsizing": "constant"},
    )

    adjust_axes(fig, scale, df)

    if logdir:
        figname = convert_to_filename(name) + "_2D"
        fig.write_html(logdir / f"{figname}.html")
        fig.write_image(logdir / f"{figname}.png")
        fig.write_image(logdir / f"{figname}.svg")
    fig.show()
    del fig

In [None]:
def graph_all_samples_2D_per_source_and_assay_type(
    df: pd.DataFrame, name: str, logdir: Path | None = None, scale: float = 5
) -> None:
    """Scatter PC1 vs PC2 with one trace per (source, assay type) pair.

    The assay type of each file comes from mapping its `ASSAY` label through
    `map_assay_type`. The title suffix depends on the module-level `pca_name`.

    Args:
        df: PCA coordinates merged with metadata ("PC1", "PC2", "id",
            "assay_epiclass", "source" and the `ASSAY` column).
        name: Used in the title and, made filename-safe, as the base name of the
            saved files.
        logdir: If given, the figure is saved there as HTML, PNG and SVG.
        scale: Figure size scaling factor, forwarded to `adjust_axes`.

    Raises:
        ValueError: If an assay label of `df` has no entry in `map_assay_type`.
    """
    df = df.copy(deep=True)

    # Setup color groups. Series.map(dict) never raises KeyError: unknown assays
    # silently become NaN (which then breaks sorting the pairs), so check explicitly.
    df["assay_type"] = df[ASSAY].map(map_assay_type)
    unmapped = df.loc[df["assay_type"].isna(), ASSAY]
    if not unmapped.empty:
        missing = set(unmapped.unique())
        raise ValueError(f"An assay is not present in the assay mapper: {missing}")

    color_list = rgb_to_plotly(cc.glasbey_bw_minc_20_maxl_70)

    unique_pairs = df[["source", "assay_type"]].drop_duplicates().values
    unique_pairs = sorted(unique_pairs, key=lambda x: (x[0], x[1]))
    color_dict = {
        (db_label, assay_type): color_list[i % len(color_list)]
        for i, (db_label, assay_type) in enumerate(unique_pairs)
    }

    # plot
    fig = go.Figure()
    for (db_label, assay_type), color in color_dict.items():
        filtered_df = df[(df["source"] == db_label) & (df["assay_type"] == assay_type)]
        print(db_label, filtered_df.shape)
        fig.add_trace(
            go.Scatter(
                x=filtered_df["PC1"],
                y=filtered_df["PC2"],
                mode="markers",
                marker=dict(
                    size=1.5,
                    color=color,
                    opacity=0.8,
                ),
                hovertemplate="%{text}",
                # Every row of filtered_df has source == db_label.
                text=[
                    f"{id_label}: {assay} ({db_label})"
                    for id_label, assay in zip(
                        filtered_df["id"],
                        filtered_df["assay_epiclass"],
                    )
                ],
                name=f"{db_label}, {assay_type} (N={filtered_df.shape[0]})",
                showlegend=True,
            )
        )

    # explained_variance: module-level explained variance ratios of the loaded PCA.
    axis_titles = [f"<b>PC {i+1}</b> ({explained_variance[i]:.2%})" for i in range(2)]

    title_suffixes = {
        "chip_3projects": " - ChIP core samples",
        "rna_enc_epi_recount3": " - RNA samples",
    }
    title = f"PCA - {name}" + title_suffixes.get(pca_name, " - all samples")

    fig.update_layout(
        title=title,
        xaxis_title=axis_titles[0],
        yaxis_title=axis_titles[1],
        legend={"itemsizing": "constant"},
    )

    adjust_axes(fig, scale, df)

    if logdir:
        figname = convert_to_filename(name) + "_2D"
        fig.write_html(logdir / f"{figname}.html")
        fig.write_image(logdir / f"{figname}.png")
        fig.write_image(logdir / f"{figname}.svg")
    fig.show()
    del fig

### Samples by database source

In [None]:
# Output folder and figure name of the "samples by source" plot, per supported PCA.
pca_config = {
    "epiatlas_encode": {
        "logdir": general_logdir / "EpiATLAS_ENCODE" / "no_epiatlas_overlap",
        "name": "EpiATLAS & ENCODE",
    },
    "chip_3projects": {
        "logdir": general_logdir / "chip_C-A_epiatlas_ENC" / "no_epiatlas_overlap",
        "name": "EpiATLAS & ENCODE & ChIP-Atlas",
    },
    "rna_enc_epi_recount3": {
        "logdir": general_logdir / "RNA_EpiATLAS_ENCODE_recount3" / "no_epiatlas_overlap",
        "name": "EpiATLAS & ENCODE & recount3",
    },
}

if pca_name not in pca_config:
    print(f"Unknown pca_name: {pca_name}")
else:
    config = pca_config[pca_name]
    logdir, name = config["logdir"], config["name"]
    scale = 4  # figure size factor, same for every configured PCA

    logdir.mkdir(parents=True, exist_ok=True)

    graph_all_samples_2D_per_source(
        df=final_pca_df, name=name, logdir=logdir, scale=scale
    )

### Samples by source and assay type

In [None]:
# display(final_pca_df["source"].value_counts(dropna=False))
# display(final_pca_df[ASSAY].value_counts(dropna=False))

In [None]:
# Output folder and figure name of the "by source and assay type" plot; only the
# EpiATLAS & ENCODE PCA gets this plot.
pca_config = {
    "epiatlas_encode": {
        "logdir": general_logdir / "EpiATLAS_ENCODE" / "no_epiatlas_overlap",
        "name": "EpiATLAS & ENCODE - by assay type",
    },
}

# Handle known PCA configurations
if pca_name in pca_config:
    config = pca_config[pca_name]
    logdir = config["logdir"]
    name = config["name"]
    scale = 4  # figure size factor, same as the per-source plot

    # parents=True: the nested folder may not exist when this cell runs on its own.
    logdir.mkdir(exist_ok=True, parents=True)

    graph_all_samples_2D_per_source_and_assay_type(
        df=final_pca_df,
        name=name,
        logdir=logdir,
        scale=scale,
    )
else:
    print(f"Unknown pca_name: {pca_name}")

In [None]:
# fig = px.density_contour(
#     core_assay_df,
#     x="PC1",
#     y="PC2",
#     color="source",
#     height=800,
#     width=800,
#     )

# fig.update_traces(line=dict(width=1))

# fig.update_layout(
#     title="2D PCA - epiATLAS+ChiP-Atlas+ENC - core7 samples",
#     xaxis_title=axis_titles[0],
#     yaxis_title=axis_titles[1],
#     legend={"itemsizing": "constant"},
#     )
# fig.show()


# fig = px.density_contour(
#     core_assay_df,
#     x="PC1",
#     y="PC2",
#     marginal_x="histogram",
#     marginal_y="histogram",
#     color="source",
#     height=800,
#     width=800,
#     )

# fig.show()

### PCA with custom groups

In [None]:
if pca_name == "epiatlas_encode":
    logdir = general_logdir / "EpiATLAS_ENCODE" / "no_epiatlas_overlap"
    logdir.mkdir(parents=True, exist_ok=True)

    # Group files by source and general assay type ("core9" / "non-core").
    # Assays absent from map_general_assay_type would get a NaN label, which makes
    # the label sorting in create_graph_from_groups fail; label them "other".
    general_assay_type = final_pca_df[ASSAY].map(map_general_assay_type).fillna("other")
    final_pca_df[PLOT_LABEL] = final_pca_df["source"] + "_" + general_assay_type

    create_graph_from_groups(
        df=final_pca_df,
        name="EpiATLAS & ENCODE - All assays",
        logdir=logdir,
        # Explicit, instead of relying on `scale` left over from an earlier cell.
        scale=4,
    )

### recount3 vs EpiATLAS

In [None]:
def recount3_vs_epiatlas_pca(
    final_pca_df: pd.DataFrame,
    explained_variance: List[float],
    logdir: Path | None = None,
    hist_ymax: float = 10,
) -> None:
    """Plot all EpiATLAS files (ChIP, RNA, WGBS) vs recount3 RNA files.

    Row 3 holds the PC1/PC2 scatter; rows 1 and 2 hold PC1 histograms (in % of
    files) of the EpiATLAS RNA and recount3 RNA files.

    Args:
        final_pca_df: PCA coordinates merged with metadata ("PC1", "PC2", "id",
            "source", "assay_epiclass").
        explained_variance: Explained variance ratio per PC (first two are used).
        logdir: If given, the figure is saved there as "pca_epiatlas_recount3_2D"
            (HTML, PNG, SVG).
        hist_ymax: Upper y-limit (%) shared by both histograms; taller bars are
            clipped. Defaults to the previously hard-coded 10.
    """
    # Create a new color dictionary (its order is the trace order)
    color_dict = {
        "EpiATLAS_ChIP": px.colors.qualitative.Dark24[0],
        "EpiATLAS_RNA": px.colors.qualitative.Dark24[6],
        "EpiATLAS_WGB": px.colors.qualitative.Dark24[2],
        "recount3_RNA": px.colors.qualitative.Dark24[3],
    }

    is_epiatlas = final_pca_df["source"] == "epiatlas"
    assays = final_pca_df["assay_epiclass"]
    # fig_label -> (legend/hover label, row mask of the group)
    groups = {
        "EpiATLAS_ChIP": ("EpiATLAS ChIP", is_epiatlas & assays.isin(CORE_ASSAYS)),
        "EpiATLAS_RNA": (
            "EpiATLAS RNA",
            is_epiatlas & assays.isin(["mrna_seq", "rna_seq"]),
        ),
        "EpiATLAS_WGB": (
            "EpiATLAS WGB",
            is_epiatlas & assays.isin(["wgbs-standard", "wgbs-pbat"]),
        ),
        "recount3_RNA": ("recount3 RNA", final_pca_df["source"] == "recount3"),
    }
    # fig_label -> subplot row of its PC1 histogram (only the RNA groups have one)
    histogram_rows = {"EpiATLAS_RNA": 1, "recount3_RNA": 2}

    axis_titles = [f"<b>PC{i+1}</b> ({explained_variance[i]:.2%})" for i in range(2)]

    fig = make_subplots(
        rows=3,
        cols=1,
        shared_xaxes=True,
        vertical_spacing=0.1,
        row_heights=[0.2, 0.2, 0.6],
        subplot_titles=(
            "EpiATLAS RNA PC1 distribution (%)",
            "recount3 RNA PC1 distribution (%)",
            "PCA",
        ),
        x_title=axis_titles[0],
    )

    for fig_label, color in color_dict.items():
        display_label, mask = groups[fig_label]
        filtered_df = final_pca_df[mask]

        # 2D scatter
        fig.add_trace(
            go.Scatter(
                x=filtered_df["PC1"],
                y=filtered_df["PC2"],
                mode="markers",
                marker=dict(
                    size=1.5,
                    color=color,
                    opacity=0.8,
                ),
                hovertemplate="%{text}",
                text=[
                    f"{id_label}: {assay} ({display_label})"
                    for id_label, assay in zip(
                        filtered_df["id"],
                        filtered_df["assay_epiclass"],
                    )
                ],
                name=f"{display_label} (N={filtered_df.shape[0]})",
                showlegend=True,
            ),
            row=3,
            col=1,
        )

        # PC1 histogram
        if fig_label in histogram_rows:
            fig.add_trace(
                go.Histogram(
                    x=filtered_df["PC1"],
                    histnorm="percent",
                    name=f"{display_label} (N={filtered_df.shape[0]})",
                    showlegend=False,
                    marker=dict(color=color),
                ),
                row=histogram_rows[fig_label],
                col=1,
            )

    title = "PCA - EpiATLAS & recount3"
    fig.update_layout(
        title=title,
        legend={"itemsizing": "constant"},
    )

    # add y-axis title to last row
    fig.update_yaxes(title_text=axis_titles[1], row=3, col=1)

    # set histograms to same yrange
    for row in histogram_rows.values():
        fig.update_yaxes(range=[-0.001, hist_ymax], row=row, col=1, nticks=4)

    if logdir:
        name = "pca_epiatlas_recount3_2D"
        fig.write_html(logdir / f"{name}.html")
        fig.write_image(logdir / f"{name}.png", scale=1.5)
        fig.write_image(logdir / f"{name}.svg")
    fig.show()

In [None]:
# recount3_vs_epiatlas_pca(
#     df=global_pca_df,
#     explained_variance=explained_variance,
# )

### RNA: recount3, ENCODE, EpiATLAS

In [None]:
# Sanity counts: EpiATLAS IDs are 32-character md5sums, ENCODE IDs start with "ENC".
print(f"nb epiatlas files: {(global_pca_df['id'].str.len() == 32).sum()}")
print(f"nb ENCODE files: {global_pca_df['id'].str.startswith('ENC').sum()}")

# Re-attach metadata to every PCA file (left join keeps files without metadata).
final_pca_df = global_pca_df.merge(graph_metadata, on="id", how="left")

In [None]:
# Files without metadata get source "encode". Reassign instead of a chained
# fillna(inplace=True), which does not modify the frame under pandas Copy-on-Write.
final_pca_df["source"] = final_pca_df["source"].fillna("encode")
display(final_pca_df["source"].value_counts(dropna=False))
display(final_pca_df[ASSAY].value_counts(dropna=False))

#### No histogram

In [None]:
def recount3_encode_RNA_vs_epiatlas_pca(
    df: pd.DataFrame,
    explained_variance: List[float],
    logdir: Path | None = None,
    scale: float = 4,
) -> None:
    """Scatter PC1 vs PC2 of RNA files, one trace per source (EpiATLAS, ENCODE, recount3).

    Args:
        df: PCA coordinates merged with metadata ("PC1", "PC2", "id", "source",
            "assay_epiclass").
        explained_variance: Explained variance ratio per PC (first two are used).
        logdir: If given, the figure is saved there as "all_rna_samples_2D"
            (HTML, PNG, SVG).
        scale: Figure size scaling factor, forwarded to `adjust_axes`.
    """
    df = df.copy(deep=True)
    palette = px.colors.qualitative.Dark24

    # legend label -> (value of the "source" column, trace color); order = trace order
    groups = {
        "EpiATLAS": ("epiatlas", palette[0]),
        "ENCODE": ("encode", palette[2]),
        "recount3": ("recount3", palette[3]),
    }

    fig = go.Figure()
    for display_label, (source_value, color) in groups.items():
        subset = df[df["source"] == source_value]
        hover_text = [
            f"{file_id}: {assay} ({display_label})"
            for file_id, assay in zip(subset["id"], subset["assay_epiclass"])
        ]
        fig.add_trace(
            go.Scatter(
                x=subset["PC1"],
                y=subset["PC2"],
                mode="markers",
                marker={"size": 1.5, "color": color, "opacity": 0.8},
                hovertemplate="%{text}",
                text=hover_text,
                name=f"{display_label} (N={subset.shape[0]})",
                showlegend=True,
            ),
        )

    x_title, y_title = (
        f"<b>PC {i + 1}</b> ({explained_variance[i]:.2%})" for i in range(2)
    )
    fig.update_layout(
        title="PCA - RNA files - EpiATLAS & ENCODE & recount3",
        xaxis_title=x_title,
        yaxis_title=y_title,
        legend={"itemsizing": "constant"},
    )

    adjust_axes(fig, scale, df)

    if logdir:
        stem = "all_rna_samples_2D"
        fig.write_html(logdir / f"{stem}.html")
        fig.write_image(logdir / f"{stem}.png")
        fig.write_image(logdir / f"{stem}.svg")
    fig.show()
    del fig

In [None]:
logdir = general_logdir / "rna_enc_epi_recount3"
# parents=True: general_logdir itself is never created earlier (only base_fig_dir is checked).
logdir.mkdir(exist_ok=True, parents=True)

recount3_encode_RNA_vs_epiatlas_pca(
    df=final_pca_df,
    explained_variance=explained_variance,
    logdir=logdir,
    scale=4,
)

#### With PC1 histograms

In [None]:
def recount3_encode_RNA_vs_epiatlas_pca_w_histogram(
    df: pd.DataFrame,
    explained_variance: List[float],
    logdir: Path | None = None,
    scale: float = 4,
    hist_ymax: float = 10,
) -> None:
    """Scatter PC1 vs PC2 of RNA files by source, with one PC1 histogram per source.

    Rows 1-3 hold the PC1 histograms (in % of each source's files) of ENCODE,
    EpiATLAS and recount3 (sorted label order); row 4 holds the PC1/PC2 scatter.

    Args:
        df: PCA coordinates merged with metadata ("PC1", "PC2", "id", "source",
            "assay_epiclass").
        explained_variance: Explained variance ratio per PC (first two are used).
        logdir: If given, the figure is saved there as
            "all_rna_samples_2D_with_histogram" (HTML, PNG, SVG).
        scale: Pixels per PC unit used to size the figure.
        hist_ymax: Upper y-limit (%) shared by the histograms; taller bars are
            clipped. Defaults to the previously hard-coded 10.
    """
    df = df.copy(deep=True)

    # Create a new color dictionary
    color_dict = {
        "EpiATLAS": px.colors.qualitative.Dark24[0],
        "ENCODE": px.colors.qualitative.Dark24[2],
        "recount3": px.colors.qualitative.Dark24[3],
    }
    # Plot label -> value of the "source" column.
    source_values = {"EpiATLAS": "epiatlas", "ENCODE": "encode", "recount3": "recount3"}

    axis_titles = [f"<b>PC{i+1}</b> ({explained_variance[i]:.2%})" for i in range(2)]
    subplot_titles = [
        label + " PC1 file distribution (%)" for label in sorted(color_dict.keys())
    ] + ["PCA"]

    fig = make_subplots(
        rows=4,
        cols=1,
        shared_xaxes=True,
        vertical_spacing=0.05,
        subplot_titles=subplot_titles,
        x_title=axis_titles[0],
        row_heights=[0.1, 0.1, 0.1, 0.7],
    )

    # Histograms fill rows 1..3 in the same sorted order as subplot_titles.
    for histogram_row, (fig_label, color) in enumerate(
        sorted(color_dict.items()), start=1
    ):
        filtered_df = df[df["source"] == source_values[fig_label]]
        display_label = fig_label

        # PCA 2D
        fig.add_trace(
            go.Scatter(
                x=filtered_df["PC1"],
                y=filtered_df["PC2"],
                mode="markers",
                marker=dict(
                    size=1.5,
                    color=color,
                    opacity=0.8,
                ),
                hovertemplate="%{text}",
                text=[
                    f"{id_label}: {assay} ({display_label})"
                    for id_label, assay in zip(
                        filtered_df["id"],
                        filtered_df["assay_epiclass"],
                    )
                ],
                name=f"{display_label} (N={filtered_df.shape[0]})",
                showlegend=True,
            ),
            row=4,
            col=1,
        )

        # PC1 histogram
        fig.add_trace(
            go.Histogram(
                x=filtered_df["PC1"],
                histnorm="percent",
                name=f"{display_label} (N={filtered_df.shape[0]})",
                marker=dict(color=color),
                showlegend=False,
            ),
            row=histogram_row,
            col=1,
        )

    title = "PCA - RNA files - EpiATLAS & ENCODE & recount3"
    # No fixed height here: the figure size is always set from the PC ranges below.
    fig.update_layout(
        title=title,
        legend={"itemsizing": "constant"},
    )

    # add y-axis title to last row
    fig.update_yaxes(title_text=axis_titles[1], row=4, col=1)

    # set histograms to same yrange
    for row in [1, 2, 3]:
        fig.update_yaxes(range=[-0.001, hist_ymax], row=row, col=1, nticks=4)

    # Size the figure proportionally to the PC1/PC2 ranges (unlike adjust_axes,
    # no legend correction is applied here).
    height = (df["PC2"].max() - df["PC2"].min()) * scale
    width = (df["PC1"].max() - df["PC1"].min()) * scale
    fig.update_layout(height=height, width=width)

    if logdir:
        name = "all_rna_samples_2D_with_histogram"
        fig.write_html(logdir / f"{name}.html")
        fig.write_image(logdir / f"{name}.png")
        fig.write_image(logdir / f"{name}.svg")
    fig.show()
    del fig

In [None]:
# Same RNA PCA with one PC1 histogram per source; saved into the `logdir` set in the
# previous plotting cell (general_logdir / "rna_enc_epi_recount3").
recount3_encode_RNA_vs_epiatlas_pca_w_histogram(
    df=final_pca_df,
    explained_variance=explained_variance,
    logdir=logdir,
    scale=4,
)

#### EpiATLAS cell type coloring

In [None]:
def rna_cell_type_epiatlas_pca(
    df: pd.DataFrame,
    explained_variance: List[float],
    logdir: Path | None = None,
    scale: float = 4,
    min_files_per_cell_type: int = 10,
) -> None:
    """
    Plot the 2D PCA (PC1 vs PC2) of EpiATLAS RNA files, colored by cell type.

    Only rows with ``df["source"] == "epiatlas"`` are plotted. They are left-joined
    to the v2 EpiATLAS metadata on ``id`` == ``md5sum`` to obtain the cell type
    (``CELL_TYPE``) and ``track_type`` of each file. When a column exists in both
    frames, the value from ``df`` is kept. Only cell types with at least
    ``min_files_per_cell_type`` files get a trace. Files with a missing cell type
    are not plotted.

    Args:
        df: PCA coordinates with at least "id", "source", "PC1" and "PC2" columns.
            "assay_epiclass" and "track_type" must come from ``df`` or the
            metadata, since they are used in the hover text. ``df`` is not modified.
        explained_variance: Explained variance ratio per component. The first two
            entries are used in the axis titles.
        logdir: If given, the figure is written there as HTML, PNG and SVG.
        scale: Figure size in pixels per PC unit. The extent is computed from
            all rows of ``df`` (every source), not only the EpiATLAS subset.
        min_files_per_cell_type: Minimum number of files a cell type needs to be
            plotted.
    """
    epiatlas_df = df[df["source"] == "epiatlas"]
    epiatlas_metadata = metadata_handler.load_metadata_df("v2")

    # Left join keeps every PCA row; duplicated metadata columns get the
    # "_DROP" suffix and are removed so the PCA dataframe's values win.
    epiatlas_df = pd.merge(
        epiatlas_df,
        epiatlas_metadata,
        how="left",
        left_on="id",
        right_on="md5sum",
        suffixes=("", "_DROP"),
    )
    epiatlas_df = epiatlas_df.drop(
        columns=[col for col in epiatlas_df.columns if "_DROP" in col]
    )

    # value_counts() drops NaN while unique() keeps it, so NaN cell types must be
    # skipped explicitly; otherwise count_ct[nan] raises a KeyError.
    count_ct = epiatlas_df[CELL_TYPE].value_counts()
    palette = px.colors.qualitative.Dark24
    # Colors are indexed by position among *all* unique cell types (including
    # filtered-out ones), so the color assignment is unchanged by the threshold.
    color_dict = {
        cell_type: palette[i % len(palette)]
        for i, cell_type in enumerate(epiatlas_df[CELL_TYPE].unique())
        if pd.notna(cell_type) and count_ct[cell_type] >= min_files_per_cell_type
    }

    axis_titles = [f"<b>PC{i+1}</b> ({explained_variance[i]:.2%})" for i in range(2)]

    fig = go.Figure()
    for display_label, color in color_dict.items():
        filtered_df = epiatlas_df[epiatlas_df[CELL_TYPE] == display_label]

        fig.add_trace(
            go.Scatter(
                x=filtered_df["PC1"],
                y=filtered_df["PC2"],
                mode="markers",
                marker=dict(
                    size=1.5,
                    color=color,
                    opacity=0.8,
                ),
                hovertemplate="%{text}",
                text=[
                    f"{id_label}: {assay},{track_type} ({display_label})"
                    for id_label, assay, track_type in zip(
                        filtered_df["id"],
                        filtered_df["assay_epiclass"],
                        filtered_df["track_type"],
                    )
                ],
                name=f"{display_label} (N={filtered_df.shape[0]})",
                showlegend=True,
            ),
        )

    # Size the figure proportionally to the PC extent of the full input.
    height = (df["PC2"].max() - df["PC2"].min()) * scale
    width = (df["PC1"].max() - df["PC1"].min()) * scale

    title = (
        f"PCA - RNA files - EpiATLAS (files>={min_files_per_cell_type} per cell type)"
    )
    fig.update_layout(
        title=title,
        legend={"itemsizing": "constant"},
        xaxis_title=axis_titles[0],
        yaxis_title=axis_titles[1],
        height=height,
        width=width,
    )

    if logdir:
        name = "pca_RNA_epiatlas_cell_type_2D"
        fig.write_html(logdir / f"{name}.html")
        fig.write_image(logdir / f"{name}.png")
        fig.write_image(logdir / f"{name}.svg")
    fig.show()

In [None]:
# EpiATLAS RNA files only, colored by cell type (frequent cell types only).
logdir = base_fig_dir / "pca"
rna_cell_type_epiatlas_pca(
    final_pca_df,
    explained_variance,
    scale=4,
    logdir=logdir,
)

#### recount3 assay coloring

In [None]:
def rna_assay_recount3_pca(
    final_pca_df: pd.DataFrame,
    explained_variance: List[float],
    logdir: Path | None = None,
) -> None:
    """
    Plot the 2D PCA (PC1 vs PC2) of recount3 RNA files, colored by assay.

    Only rows with ``final_pca_df["source"] == "recount3"`` are plotted, with
    one trace per distinct value of the ``ASSAY`` column.

    Args:
        final_pca_df: PCA coordinates with at least "source", "id", "PC1", "PC2",
            "assay_epiclass" and ``ASSAY`` columns. It is not modified.
        explained_variance: Explained variance ratio per component. The first two
            entries are used in the axis titles.
        logdir: If given, the interactive figure is written there as HTML.
    """
    source_df = final_pca_df[final_pca_df["source"] == "recount3"]

    palette = px.colors.qualitative.Dark24
    color_dict = {
        assay_label: palette[i % len(palette)]
        for i, assay_label in enumerate(source_df[ASSAY].unique())
    }

    axis_titles = [f"<b>PC{i+1}</b> ({explained_variance[i]:.2%})" for i in range(2)]

    fig = go.Figure()
    for display_label, color in color_dict.items():
        filtered_df = source_df[source_df[ASSAY] == display_label]

        fig.add_trace(
            go.Scatter(
                x=filtered_df["PC1"],
                y=filtered_df["PC2"],
                mode="markers",
                marker=dict(
                    size=1.5,
                    color=color,
                    opacity=0.8,
                ),
                hovertemplate="%{text}",
                text=[
                    f"{id_label}: {assay} ({display_label})"
                    for id_label, assay in zip(
                        filtered_df["id"],
                        filtered_df["assay_epiclass"],
                    )
                ],
                name=f"{display_label} (N={filtered_df.shape[0]})",
                showlegend=True,
            ),
        )

    title = "PCA - RNA files - recount3"
    fig.update_layout(
        title=title,
        legend={"itemsizing": "constant"},
        height=1200,
        xaxis_title=axis_titles[0],
        yaxis_title=axis_titles[1],
    )

    # Only the interactive HTML version of this figure is written.
    if logdir:
        figname = "pca_RNA_assay_recount3_2D"
        fig.write_html(logdir / f"{figname}.html")

    fig.show()

In [None]:
# rna_assay_recount3_pca(
#     final_pca_df,
#     explained_variance,
# )