In [1]:
"""Figure core creation: Fig2

Formatting of the figures may not be identical to the paper, but they contain the same data points.
"""
# pylint: disable=import-error, redefined-outer-name, use-dict-literal, too-many-lines, too-many-branches, consider-using-f-string

'Figure core creation: Fig2\n\nFormatting of the figures may not be identical to the paper, but they contain the same data points.\n'

In [2]:
%load_ext autoreload
# Reload imported modules (e.g. the epi_ml paper utilities) before each cell runs,
# so edits to those modules take effect without restarting the kernel.
%autoreload 2

## Setup

In [3]:
from __future__ import annotations

import itertools
import os
import re
import tarfile
from pathlib import Path
from typing import Dict, List, Sequence

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from IPython.display import display
from plotly.subplots import make_subplots
from scipy.stats import zscore

from epi_ml.utils.notebooks.paper.paper_utilities import (
    ASSAY,
    ASSAY_MERGE_DICT,
    CELL_TYPE,
    LIFE_STAGE,
    SEX,
    IHECColorMap,
    MetadataHandler,
    SplitResultsHandler,
    create_mislabel_corrector,
)

In [4]:
# TODO: Create a slimmed-down data directory containing only the files needed for figure creation, compress it, and use it here.

In [5]:
# Root of the paper outputs. Set the EPICLASS_PAPER_DIR environment variable to
# point elsewhere; by default the historical home-directory layout is used.
base_dir = Path(
    os.environ.get("EPICLASS_PAPER_DIR", Path.home() / "Projects/epiclass/output/paper")
)
paper_dir = base_dir
if not paper_dir.exists():
    raise FileNotFoundError(f"Directory {paper_dir} does not exist.")

base_data_dir = base_dir / "data"
base_fig_dir = base_dir / "figures"

In [6]:
# NOTE(review): this rebinds the imported class name `IHECColorMap` to an instance,
# so re-running this cell likely fails (unless instances are callable). Consider a
# distinct name such as `ihec_color_map` once all downstream uses are updated.
IHECColorMap = IHECColorMap(base_fig_dir)
# Color lookups used by the plots below (e.g. sex_colors["female"]).
assay_colors = IHECColorMap.assay_color_map
cell_type_colors = IHECColorMap.cell_type_color_map
sex_colors = IHECColorMap.sex_color_map

In [7]:
split_results_handler = SplitResultsHandler()

# Metadata version "v2", both as the handler's metadata object and as a DataFrame.
metadata_handler = MetadataHandler(paper_dir)
metadata_v2, metadata_v2_df = (
    metadata_handler.load_metadata("v2"),
    metadata_handler.load_metadata_df("v2"),
)

## Fig 2

### Fig 2A, 2B (Supp Fig 4A) - MLP performance on various classification tasks - 100kb resolution files

- A: Accuracy
- B: F1-Score

In [8]:
def NN_performance_across_classification_tasks(
    split_metrics: Dict[str, Dict[str, Dict[str, float]]],
    name: str | None = None,
    logdir: Path | None = None,
    exclude_categories: List[str] | None = None,
    y_range: List[float] | None = None,
    sort_by_acc: bool = False,
    metric_names: Sequence[str] = ("Accuracy", "F1_macro"),
    title: str | None = None,
) -> List[str]:
    """Render box plots of metrics per classifier and split, each metric in its own subplot.

    The figure has one subplot per metric. Each subplot contains one box plot per
    classifier (task), showing the distribution of the metric over the splits.

    Args:
        split_metrics: A nested dictionary with structure {split: {classifier: {metric: score}}}.
        name: The base name for the output plot files. Defaults to
            "MLP_metrics_various_tasks" when saving.
        logdir: The directory path to save the output plots (svg, png, html).
            If None, only display the plot.
        exclude_categories: Classifiers whose name contains any of these substrings
            are excluded from the plot.
        y_range: If given, overrides the y-axis range of every subplot.
        sort_by_acc: Whether to sort the classifiers by decreasing mean accuracy
            over the splits (requires an "Accuracy" metric).
        metric_names: The metrics to include in the plot, one subplot each.
        title: The figure title. Defaults to a generic 10-fold cross-validation title.

    Returns:
        The list of classifier names in the order they appear in the plot.

    Raises:
        ValueError: If `split_metrics` is empty, if every classifier is excluded,
            or if a requested metric is not available.
    """
    if not split_metrics:
        raise ValueError("split_metrics is empty.")

    # Reference split for the available classifiers and metrics: "split0" when
    # present, otherwise the first split.
    if "split0" in split_metrics:
        reference_split = split_metrics["split0"]
    else:
        reference_split = next(iter(split_metrics.values()))

    # Exclude some categories (substring match on the classifier name)
    categories_to_exclude = exclude_categories or []
    classifier_names = [
        classifier
        for classifier in reference_split
        if not any(category in classifier for category in categories_to_exclude)
    ]
    if not classifier_names:
        raise ValueError("No classifier left to plot after applying exclude_categories.")

    available_metrics = list(reference_split[classifier_names[0]].keys())
    if any(metric not in available_metrics for metric in metric_names):
        raise ValueError(f"Invalid metric. Metrics need to be in {available_metrics}")

    # Sort classifiers by mean accuracy over splits (stable sort, best first)
    if sort_by_acc:
        mean_acc = {
            classifier: np.mean(
                [split_metrics[split][classifier]["Accuracy"] for split in split_metrics]
            )
            for classifier in classifier_names
        }
        classifier_names = sorted(
            classifier_names, key=lambda x: mean_acc[x], reverse=True
        )

    # Create subplots, one column for each metric
    fig = make_subplots(
        rows=1,
        cols=len(metric_names),
        subplot_titles=metric_names,
        horizontal_spacing=0.03,
    )

    color_group = px.colors.qualitative.Plotly
    colors = {
        classifier: color_group[i % len(color_group)]
        for i, classifier in enumerate(classifier_names)
    }

    for i, metric in enumerate(metric_names):
        for classifier_name in classifier_names:
            values = [
                split_metrics[split][classifier_name][metric] for split in split_metrics
            ]

            fig.add_trace(
                go.Box(
                    y=values,
                    name=classifier_name,
                    fillcolor=colors[classifier_name],
                    line=dict(color="black", width=1.5),
                    marker=dict(size=3, color="white", line_width=1),
                    boxmean=True,
                    boxpoints="all",
                    pointpos=0,
                    showlegend=i == 0,  # Only show legend in the first subplot
                    hovertemplate="%{text}",
                    text=[
                        f"{split}: {value:.4f}"
                        for split, value in zip(split_metrics, values)
                    ],
                    legendgroup=classifier_name,
                    width=0.5,
                ),
                row=1,
                col=i + 1,
            )

    fig.update_layout(
        title_text=title
        or "Neural network classification - Metric distribution for 10-fold cross-validation",
        yaxis_title="Value",
        boxmode="group",
        height=1200 * 0.8,
        width=1750 * 0.8,
    )

    # Default zoom per metric. Ranges are keyed on the metric name, not on the
    # subplot position, so each range lands on the subplot showing that metric
    # whatever the order/length of `metric_names`. Unknown metrics autorange.
    default_ranges = {
        "Accuracy": [0.88, 1.001],
        "F1_macro": [0.80, 1.001],
        "AUC_micro": [0.986, 1.0001],
        "AUC_macro": [0.986, 1.0001],
    }
    for i, metric in enumerate(metric_names):
        if metric in default_ranges:
            fig.update_yaxes(range=default_ranges[metric], row=1, col=i + 1)

    if y_range is not None:
        fig.update_yaxes(range=y_range)

    # Save figure
    if logdir:
        if name is None:
            name = "MLP_metrics_various_tasks"
        fig.write_image(logdir / f"{name}.svg")
        fig.write_image(logdir / f"{name}.png")
        fig.write_html(logdir / f"{name}.html")

    fig.show()

    return classifier_names

In [9]:
# Task categories and task names left out of the cross-validation summary.
exclude_categories = ["track_type", "group", "disease"]
exclude_names = ["chip-seq", "7c", "16ct", "no-mixed"]

hdf5_type = "hg38_100kb_all_none"
results_dir = base_data_dir.joinpath("training_results", "dfreeze_v2", hdf5_type)
if not results_dir.exists():
    raise FileNotFoundError(f"Directory {results_dir} does not exist.")

# Toggle the application of known label corrections to the split results.
mislabel_correction = True
mislabel_corrector = (
    create_mislabel_corrector(paper_dir) if mislabel_correction else None
)

split_results_metrics, all_split_results = split_results_handler.general_split_metrics(
    results_dir,
    merge_assays=True,
    exclude_categories=exclude_categories,
    exclude_names=exclude_names,
    return_type="both",
    oversampled_only=True,
    mislabel_corrections=mislabel_corrector,
    verbose=False,
)

In [12]:
fig_logdir = base_fig_dir / "fig2_EpiAtlas_other" / "fig2--NN_perf_across_categories"
fig_logdir.mkdir(parents=False, exist_ok=True)

# Figures go to a sub-directory recording whether mislabels were corrected.
correction_subdir = (
    "post_mislabel_correction" if mislabel_correction else "no_mislabel_correction"
)
this_fig_logdir = fig_logdir / correction_subdir
this_fig_logdir.mkdir(parents=False, exist_ok=True)

metrics_full = ["Accuracy", "F1_macro", "AUC_micro", "AUC_macro"]
# NOTE(review): fig_name is not passed to the plotting call below.
fig_name = f"{hdf5_type}_perf_across_categories_full"
# Display only: no logdir is given, so this figure is not written to disk.
sorted_task_names = NN_performance_across_classification_tasks(
    split_results_metrics,  # type: ignore
    sort_by_acc=True,
    metric_names=metrics_full,
)

In [18]:
assert mislabel_corrector

# Same metrics for every threshold, so define them once.
metrics_full = ["Accuracy", "F1_macro", "AUC_micro", "AUC_macro"]

# Recompute the metrics ignoring predictions below a minimum score, and save one
# figure per threshold. Loop results use their own names so that the unfiltered
# `split_results_metrics` / `sorted_task_names` from the previous cells are not
# silently replaced (Fig 2C orders its bars with `sorted_task_names`).
for min_pred_score in [0, 0.6, 0.8]:
    min_pred_split_metrics = split_results_handler.general_split_metrics(
        results_dir,
        merge_assays=True,
        exclude_categories=exclude_categories,
        exclude_names=exclude_names,
        return_type="metrics",
        oversampled_only=True,
        mislabel_corrections=mislabel_corrector,
        verbose=False,
        min_pred_score=min_pred_score,
    )

    fig_name = f"{hdf5_type}_perf_across_categories_full_minPredScore{min_pred_score}"
    fig_title = f"MLP classification - Metric distribution for 10-fold cross-validation (minPredScore={min_pred_score})"
    NN_performance_across_classification_tasks(
        min_pred_split_metrics,  # type: ignore
        sort_by_acc=True,
        metric_names=metrics_full,
        name=fig_name,
        logdir=this_fig_logdir,
        title=fig_title,
    )

TRUE or FALSE in pred vector. Changing column names.
TRUE or FALSE in pred vector. Changing column names.
TRUE or FALSE in pred vector. Changing column names.
TRUE or FALSE in pred vector. Changing column names.
TRUE or FALSE in pred vector. Changing column names.
TRUE or FALSE in pred vector. Changing column names.
TRUE or FALSE in pred vector. Changing column names.
TRUE or FALSE in pred vector. Changing column names.
TRUE or FALSE in pred vector. Changing column names.
TRUE or FALSE in pred vector. Changing column names.


TRUE or FALSE in pred vector. Changing column names.
TRUE or FALSE in pred vector. Changing column names.
TRUE or FALSE in pred vector. Changing column names.
TRUE or FALSE in pred vector. Changing column names.
TRUE or FALSE in pred vector. Changing column names.
TRUE or FALSE in pred vector. Changing column names.
TRUE or FALSE in pred vector. Changing column names.
TRUE or FALSE in pred vector. Changing column names.
TRUE or FALSE in pred vector. Changing column names.
TRUE or FALSE in pred vector. Changing column names.


### Fig 2C: Normalized Shannon Entropy on various metadata categories

In [11]:
def compute_class_imbalance(
    all_split_results: Dict[str, Dict[str, pd.DataFrame]]
) -> pd.DataFrame:
    """Compute class imbalance for each task and split.

    Args:
        all_split_results: A dictionary with structure {task_name: {split_name: split_results_df}}.

    Returns:
        pd.DataFrame: A DataFrame with the following columns:
            - avg(balance_ratio): The average balance ratio for each task.
            - n: The number of classes for each task (used for the average).
    """
    # combine md5 lists
    task_md5s = {
        classifier_task: [split_df.index for split_df in split_results.values()]
        for classifier_task, split_results in all_split_results.items()
    }
    task_md5s = {
        classifier_task: [list(split_md5s) for split_md5s in md5s]
        for classifier_task, md5s in task_md5s.items()
    }
    task_md5s = {
        classifier_task: list(itertools.chain(*md5s))
        for classifier_task, md5s in task_md5s.items()
    }

    # get metadata
    metadata_df = metadata_handler.load_metadata_df("v2-encode")

    label_counts = {}
    for classifier_task, md5s in task_md5s.items():
        try:
            label_counts[classifier_task] = metadata_df.loc[md5s][
                classifier_task
            ].value_counts()
        except KeyError as e:
            category_name = classifier_task.rsplit("_", maxsplit=1)[0]
            try:
                label_counts[classifier_task] = metadata_df.loc[md5s][
                    category_name
                ].value_counts()
            except KeyError as e:
                raise e

    # Compute Shannon Entropy
    class_balance = {}
    for classifier_task, counts in label_counts.items():
        total_count = counts.sum()
        k = len(counts)
        p_x = counts / total_count  # class proportions
        p_x = p_x.values
        shannon_entropy = -np.sum(p_x * np.log2(p_x))
        balance = shannon_entropy / np.log2(k)
        class_balance[classifier_task] = (balance, k)

    df_class_balance = pd.DataFrame.from_dict(
        class_balance, orient="index", columns=["Normalized Shannon Entropy", "k"]
    ).sort_index()

    return df_class_balance

In [12]:
# Class balance of each classification task, over the files of all its splits.
df_class_balance = compute_class_imbalance(all_split_results=all_split_results)  # type: ignore

In [13]:
# Order the bars like the tasks of the performance figure (sorted by mean accuracy).
df_class_balance = df_class_balance.loc[sorted_task_names]
fig = px.bar(
    df_class_balance,
    x=df_class_balance.index,
    y="Normalized Shannon Entropy",
    # Show the number of classes on hover; `labels` gives it a readable name.
    hover_data=["k"],
    labels={
        "k": "Number of classes",
        "Normalized Shannon Entropy": "Normalized Shannon Entropy",
    },
    title="Class imbalance across tasks",
)
fig.update_layout(
    yaxis=dict(range=[0, 1]),
)

### Fig 2D: Donor Sex, before and after label correction

`IHEC_metadata_harmonization.v1.1.extended.csv` contains 314 EpiRRs with unknown sex. We applied a fully trained sex classifier to them.

TODO: reproduce the results of this spreadsheet: https://docs.google.com/spreadsheets/d/1nMCEOZd16pWHfGY63vXcI2JzeOqGBR-tCoBuRZLZgQM/edit?gid=608314669#gid=608314669

In [14]:
# Sex classifier results (100kb resolution, 1 layer of 3000 nodes, with mixed-sex files).
sex_results_dir = base_data_dir.joinpath(
    "training_results",
    "dfreeze_v2",
    "hg38_100kb_all_none",
    f"{SEX}_1l_3000n",
    "w-mixed",
)
if not sex_results_dir.exists():
    raise FileNotFoundError(f"Directory {sex_results_dir} does not exist.")

#### Unknown Predictions

In [15]:
official_metadata_dir = base_data_dir / "metadata" / "official"

# Official IHEC harmonized metadata, versions 1.1 and 1.2 (first column as index).
metadata_v1_1_path = official_metadata_dir / "IHEC_metadata_harmonization.v1.1.extended.csv"
metadata_v1_2_path = official_metadata_dir / "IHEC_metadata_harmonization.v1.2.extended.csv"
metadata_v1_1 = pd.read_csv(metadata_v1_1_path, index_col=0)
metadata_v1_2 = pd.read_csv(metadata_v1_2_path, index_col=0)

In [16]:
# Alias, not a copy: the "md5sum" column below is also added to metadata_v2_df.
full_metadata_df = metadata_v2_df
full_metadata_df["md5sum"] = full_metadata_df.index

# Both metadata versions must agree on the 314 EpiRRs with unknown sex.
n_unknown_epirr_v2 = metadata_v2_df.loc[
    metadata_v2_df[SEX].isin(["unknown"]), "EpiRR"
].nunique()
n_unknown_epirr_v1_1 = metadata_v1_1.index[
    (metadata_v1_1[SEX] == "unknown").to_numpy()
].nunique()
assert n_unknown_epirr_v2 == n_unknown_epirr_v1_1 == 314

In [17]:
sex_full_model_dir = sex_results_dir / "complete_no_valid_oversample"
if not sex_full_model_dir.exists():
    raise FileNotFoundError(f"Directory {sex_full_model_dir} does not exist")

# Predictions of the fully trained sex model; the next cell keeps the "unknown" rows.
pred_unknown_filename = "complete_no_valid_oversample_test_prediction_100kb_all_none_dfreeze_v2.1_sex_mixed_unknown.csv"
pred_unknown_file_path = sex_full_model_dir / "predictions" / pred_unknown_filename
pred_unknown_df = pd.read_csv(pred_unknown_file_path, header=0, index_col=0)

In [25]:
# NOTE(review): not idempotent; re-running this cell re-applies add_max_pred and
# join_metadata to an already-processed frame. Re-run the loading cell first.
is_unknown_sex = pred_unknown_df["True class"] == "unknown"
pred_unknown_df = metadata_handler.join_metadata(
    split_results_handler.add_max_pred(pred_unknown_df[is_unknown_sex]),  # type: ignore
    metadata_v2,
)
pred_unknown_df["md5sum"] = pred_unknown_df.index

#### 10fold cross-validation results

In [19]:
sex_10fold_dir = sex_results_dir / "10fold-oversampling"
if not sex_10fold_dir.exists():
    raise FileNotFoundError(f"Directory {sex_10fold_dir} does not exist")

# Predictions of every cross-validation split, concatenated into one frame, with
# add_max_pred applied and the v2 file metadata joined.
split_results: Dict[str, pd.DataFrame] = split_results_handler.read_split_results(
    sex_10fold_dir
)
concat_results_10fold: pd.DataFrame = split_results_handler.add_max_pred(
    split_results_handler.concatenate_split_results(split_results, depth=1)  # type: ignore
)
concat_results_10fold = metadata_handler.join_metadata(concat_results_10fold, metadata_v2)

#### Average chrY values z-score distributions

1. For each bigwig file, the average chrY value is computed (with the pyBigWig module, in `chrY_bigwig_mean.py`).  
2. For each assay, the z-score distribution of the mean chrY values of its file group is computed.  
3. Graph E is made by averaging, for each EpiRR, its z-scores across the assay distributions.  
---
- Step 1 outputs `chrXY_coverage_all.csv`.
- Step 2 outputs `chrY_coverage_zscores.csv`.

In [20]:
# Location of the precomputed chrX/chrY coverage tables.
chrY_coverage_dir = base_data_dir.joinpath("chrY_coverage")
if not chrY_coverage_dir.exists():
    raise FileNotFoundError(f"Directory {chrY_coverage_dir} does not exist")

In [21]:
def compute_chrY_zscores(
    chrY_coverage_dir: Path, version: str, save: bool = False
) -> pd.DataFrame:
    """Compute z-scores for chrY coverage data.

    Reads `chrXY_coverage_all.csv` from `chrY_coverage_dir`, keeps the files of the
    requested metadata version, merges their metadata, and computes two z-score
    distributions of the mean chrY value:
    1) Per assay group, excluding raw, pval, and Unique_raw tracks.
    2) Per assay+track group.

    In both cases, assays are merged with ASSAY_MERGE_DICT (rna-seq/mrna-seq and
    wgbs-standard/wgbs-pbat are put as one assay).

    Args:
        chrY_coverage_dir: The directory containing `chrXY_coverage_all.csv`. When
            `save` is True, outputs go to its `dfreeze_{version}_stats` sub-directory.
        version: The metadata version to use.
        save: Whether to save the distribution stats and z-scores as CSV files.

    Returns:
        pd.DataFrame: The chrY coverage data merged with metadata, with the z-score
        columns and group-size (`N_*`) columns appended. Files outside a
        distribution get group size 0; every other missing value is filled with
        the string "NA".

    Raises:
        FileNotFoundError: If `chrY_coverage_dir` does not exist.
        ValueError: If some chrY values are zero.
    """
    # Get chrY coverage data from the directory given by the caller.
    if not chrY_coverage_dir.exists():
        raise FileNotFoundError(f"Directory {chrY_coverage_dir} does not exist.")
    chrY_coverage_df = pd.read_csv(chrY_coverage_dir / "chrXY_coverage_all.csv", header=0)

    output_dir = chrY_coverage_dir / f"dfreeze_{version}_stats"
    if save:
        output_dir.mkdir(parents=False, exist_ok=True)

    # Filter out md5s not in metadata version (uses the notebook-level `paper_dir`).
    metadata = MetadataHandler(paper_dir).load_metadata(version)
    md5s = set(metadata.md5s)
    chrY_coverage_df = chrY_coverage_df[chrY_coverage_df["filename"].isin(md5s)]

    # Make sure all values are non-zero
    if not (chrY_coverage_df["chrY"] != 0).all():
        raise ValueError("Some chrY values are zero.")

    # Merge metadata. Assign instead of a chained inplace replace, which does not
    # propagate to the frame under pandas copy-on-write.
    metadata_df = pd.DataFrame.from_records(list(metadata.datasets))
    metadata_df[ASSAY] = metadata_df[ASSAY].replace(ASSAY_MERGE_DICT)
    chrY_coverage_df = chrY_coverage_df.merge(
        metadata_df, left_on="filename", right_on="md5sum"
    )

    # Compute stats for distributions
    metric_name_1 = "chrY_zscore_vs_assay_w_track_exclusion"
    metric_name_2 = "chrY_zscore_vs_assay_track"
    files1 = chrY_coverage_df[
        ~chrY_coverage_df["track_type"].isin(["raw", "pval", "Unique_raw"])
    ]
    files2 = chrY_coverage_df
    stats_aggregation = {"chrY": ["mean", "std", "count"]}
    dist1 = files1.groupby(ASSAY).agg(stats_aggregation)
    dist2 = files2.groupby([ASSAY, "track_type"]).agg(stats_aggregation)
    if save:
        dist1.to_csv(output_dir / "chrY_coverage_stats_assay_w_track_exclusion.csv")
        dist2.to_csv(output_dir / "chrY_coverage_stats_assay_and_track.csv")

    # Full z-score distributions: z-score of each file within its group, plus the
    # group size. Assignment aligns on the index, so files outside a distribution
    # get NaN (no writes into groupby slices).
    for metric_name, files, group_keys in (
        (metric_name_1, files1, ASSAY),
        (metric_name_2, files2, [ASSAY, "track_type"]),
    ):
        grouped_chrY = files.groupby(group_keys)["chrY"]
        chrY_coverage_df[metric_name] = grouped_chrY.transform(
            lambda values: zscore(values.to_numpy())
        )
        chrY_coverage_df[f"N_{metric_name}"] = grouped_chrY.transform("size")

    # Fill in missing values
    for N_name in (f"N_{metric_name_1}", f"N_{metric_name_2}"):
        chrY_coverage_df[N_name] = chrY_coverage_df[N_name].fillna(0).astype(int)
    chrY_coverage_df = chrY_coverage_df.fillna("NA")

    if save:
        output_cols = [
            "filename",
            ASSAY,
            "track_type",
            "chrY",
            metric_name_1,
            f"N_{metric_name_1}",
            metric_name_2,
            f"N_{metric_name_2}",
        ]
        chrY_coverage_df[output_cols].to_csv(
            output_dir / "chrY_coverage_zscores.csv", index=False
        )
    return chrY_coverage_df

In [22]:
# Also writes the stats/z-score CSVs under chrY_coverage_dir / "dfreeze_v2_stats".
chrY_coverage_df = compute_chrY_zscores(chrY_coverage_dir, version="v2", save=True)

#### Fig 2D - Inner portion

In [23]:
# Proportion of unknown, excluding mixed. same as v1.1 ihec metadata
no_mixed = full_metadata_df[full_metadata_df[SEX] != "mixed"]

with pd.option_context("display.float_format", "{:.4%}".format):
    print("file-wise:")
    display(no_mixed[SEX].value_counts() / no_mixed.shape[0])

    print("EpiRR-wise:")
    epirr_no_mixed = no_mixed.drop_duplicates(subset=["EpiRR"])
    display(epirr_no_mixed[SEX].value_counts() / epirr_no_mixed.shape[0])

file-wise:


female    44.8769%
male      42.4098%
unknown   12.7133%
Name: harmonized_donor_sex, dtype: float64

EpiRR-wise:


male      45.1408%
female    40.5995%
unknown   14.2598%
Name: harmonized_donor_sex, dtype: float64

#### Fig 2D - Outer portion

The outer ring represents the SEX metadata labels of v1.2 (without `mixed` labels), which include the following modifications:
- Some files with unknown SEX were labelled, using the (assay, track type) z-score in conjunction with the fully trained model predictions.
- Some mislabels were corrected, using the 10-fold cross-validation results.

In [24]:
# v1.2 proportions without "mixed"; presumably one row per EpiRR (to confirm).
meta_v1_2_no_mixed = metadata_v1_2.loc[metadata_v1_2[SEX] != "mixed"]
n_rows_v1_2 = meta_v1_2_no_mixed.shape[0]
with pd.option_context("display.float_format", "{:.4%}".format):
    print("EpiRR-wise:")
    display(meta_v1_2_no_mixed.value_counts(SEX) / n_rows_v1_2)

EpiRR-wise:


harmonized_donor_sex
male      51.4184%
female    46.5426%
unknown    2.0390%
dtype: float64

#### Unknown predictions analysis file

In [25]:
# Summary of the "unknown"-sex predictions, grouped by EpiRR and its labels.
index_cols = ["EpiRR", "project", "harmonized_donor_type", CELL_TYPE, SEX, "Predicted class"]
val_cols = ["Max pred", "chrY_zscore_vs_assay_track"]

pred_unknown_analysis = pred_unknown_df.merge(
    chrY_coverage_df, on="md5sum", suffixes=("", "_DROP")
)
# Discard the duplicated columns coming from chrY_coverage_df.
dropped_cols = [col for col in pred_unknown_analysis.columns if col.endswith("_DROP")]
pred_unknown_analysis = pred_unknown_analysis.drop(columns=dropped_cols)

# equivalent to [insert analysis file here]
# TODO: Add analysis file here
pivot_table = pred_unknown_analysis.pivot_table(
    index=index_cols, values=val_cols, aggfunc=["mean", "median", "std", "count"]
)

In [26]:
# display(pivot_table)

#### 10fold cross-validation predictions analysis file

In [27]:
# Summary of the 10-fold cross-validation predictions, grouped by EpiRR and its labels.
index_cols = ["EpiRR", "project", "harmonized_donor_type", CELL_TYPE, SEX, "Predicted class"]
val_cols = ["Max pred", "chrY_zscore_vs_assay_w_track_exclusion"]

# The 10-fold results are joined on their index against the "md5sum" column.
cross_val_analysis = concat_results_10fold.merge(
    chrY_coverage_df, left_index=True, right_on="md5sum", suffixes=("", "_DROP")
)  # type: ignore
cross_val_analysis = cross_val_analysis.drop(
    columns=[col for col in cross_val_analysis.columns if col.endswith("_DROP")]
)

# equivalent to [insert analysis file here]
# TODO: Add analysis file here
pivot_table = cross_val_analysis.pivot_table(
    index=index_cols, values=val_cols, aggfunc=["mean", "median", "std", "count"]
)

In [28]:
# display(pivot_table)

### Fig 2E - Average EpiRR z-score for 10fold on v1.1 IHEC harmonized metadata

In [29]:
def zscore_merged_assays(
    zscore_df: pd.DataFrame,
    sex_mislabels: Dict[str, str],
    name: str | None = None,
    logdir: Path | None = None,
    min_pred: float | None = None,
) -> None:
    """Male vs Female distribution of per-EpiRR average chrY z-scores, excluding wgbs.

    Uses the "chrY_zscore_vs_assay_w_track_exclusion" z-scores, whose value is
    "NA" for files outside that distribution (pval/raw tracks); those are dropped.

    Mislabels are highlighted as large markers beside the boxes: EpiRRs corrected
    to male at x=-0.5, EpiRRs corrected to female at x=0.5.

    Args:
        zscore_df (pd.DataFrame): Per-file data with ASSAY, "EpiRR", SEX,
            "Max pred" and the z-score column.
        sex_mislabels (Dict[str, str]): {EpiRR_no-v: corrected_sex_label}
        name (str): Prefix for the output plot file names.
        logdir (Path): The directory path to save the output plots. If None, only display the plot.
        min_pred (float|None): Keep only EpiRRs whose average 'Max pred' is strictly
            greater than this value.

    Raises:
        ValueError: If an EpiRR is associated with more than one sex label.
    """
    metric_label = "chrY_zscore_vs_assay_w_track_exclusion"

    zscore_df = zscore_df.copy(deep=True)
    zscore_df[ASSAY] = zscore_df[ASSAY].replace(ASSAY_MERGE_DICT)

    # wgbs reverses male/female chrY tendency, so removed here
    zscore_df = zscore_df[zscore_df[ASSAY] != "wgbs"]

    # Drop files without a z-score, then make the column numeric (it may still be
    # object dtype because of the "NA" placeholders).
    zscore_df = zscore_df[zscore_df[metric_label] != "NA"]
    zscore_df = zscore_df.assign(**{metric_label: pd.to_numeric(zscore_df[metric_label])})

    # Average chrY z-score values per EpiRR
    mean_chrY_values_df = (
        zscore_df.groupby(["EpiRR", SEX])
        .agg({metric_label: "mean", "Max pred": "mean"})
        .reset_index()
    )
    if not mean_chrY_values_df["EpiRR"].is_unique:
        raise ValueError("EpiRR is not unique.")

    # Filter out low prediction values
    if min_pred is not None:
        mean_chrY_values_df = mean_chrY_values_df[
            mean_chrY_values_df["Max pred"] > min_pred
        ]

    mean_chrY_values_df = mean_chrY_values_df.reset_index(drop=True)
    # Strip the version suffix (e.g. "X123.2" -> "X123"); the dot is escaped so
    # unversioned IDs are not truncated.
    mean_chrY_values_df["EpiRR_no_v"] = mean_chrY_values_df["EpiRR"].str.extract(
        pat=r"(\w+\d+)\.\d+"
    )[0]

    chrY_values = mean_chrY_values_df[metric_label].to_numpy()
    sex_labels = mean_chrY_values_df[SEX]
    female_idx = np.flatnonzero((sex_labels == "female").to_numpy())
    male_idx = np.flatnonzero((sex_labels == "male").to_numpy())

    # Mislabels
    predicted_as_female = {
        epirr_no_v for epirr_no_v, label in sex_mislabels.items() if label == "female"
    }
    predicted_as_male = {
        epirr_no_v for epirr_no_v, label in sex_mislabels.items() if label == "male"
    }
    epirr_no_v_values = mean_chrY_values_df["EpiRR_no_v"]
    predicted_as_female_idx = np.flatnonzero(
        epirr_no_v_values.isin(predicted_as_female).to_numpy()
    )
    predicted_as_male_idx = np.flatnonzero(
        epirr_no_v_values.isin(predicted_as_male).to_numpy()
    )

    # Hovertext
    hovertext = np.array(
        [
            f"{epirr}: <z-score>={z_score:.3f}"
            for epirr, z_score in zip(mean_chrY_values_df["EpiRR"], chrY_values)
        ]
    )

    # Create figure: one box per sex, each x list sized like its own y values.
    fig = go.Figure()
    for sex_label, x_position, idx in (
        ("female", 0, female_idx),
        ("male", 1, male_idx),
    ):
        fig.add_trace(
            go.Box(
                name=sex_label.capitalize(),
                x=[x_position] * len(idx),
                y=chrY_values[idx],
                boxmean=True,
                boxpoints="all",
                pointpos=-2,
                hovertemplate="%{text}",
                text=hovertext[idx],
                marker=dict(size=2),
                line=dict(width=1, color="black"),
                fillcolor=sex_colors[sex_label],
                showlegend=True,
            ),
        )

    # Highlight corrected labels, colored by the corrected sex.
    for sex_label, x_position, idx in (
        ("male", -0.5, predicted_as_male_idx),
        ("female", 0.5, predicted_as_female_idx),
    ):
        fig.add_trace(
            go.Scatter(
                name=sex_label.capitalize(),
                x=[x_position] * len(idx),
                y=chrY_values[idx],
                mode="markers",
                marker=dict(
                    size=10, color=sex_colors[sex_label], line=dict(width=1, color="black")
                ),
                hovertemplate="%{text}",
                text=hovertext[idx],
                showlegend=False,
            ),
        )

    fig.update_xaxes(showticklabels=False)

    fig.update_yaxes(range=[-1.5, 3])
    title = "z-score(mean chrY coverage per file) distribution - z-scores averaged over assays"
    if min_pred is not None:
        title += f"<br>avg_maxPred>{min_pred}"

    fig.update_layout(
        title=dict(text=title, x=0.5),
        xaxis_title=SEX,
        yaxis_title="Average z-score",
        width=750,
        height=750,
    )

    # Save figure as "{name}_{metric_label}_n{N}" (N = number of plotted EpiRRs).
    if logdir:
        this_name = f"{metric_label}_n{mean_chrY_values_df.shape[0]}"
        if name:
            this_name = f"{name}_{this_name}"
        fig.write_image(logdir / f"{this_name}.svg")
        fig.write_image(logdir / f"{this_name}.png")
        fig.write_html(logdir / f"{this_name}.html")

    fig.show()

In [30]:
# Sex label corrections as {EpiRR without version: corrected sex label}.
sex_label_corrections = create_mislabel_corrector(paper_dir)[1][SEX]
zscore_merged_assays(zscore_df=cross_val_analysis, sex_mislabels=sex_label_corrections)

### Supp Fig 4B - Average EpiRR z-score per assay (for 10fold on v1.1 IHEC harmonized metadata)

In [31]:
def zscore_per_assay(
    zscore_df: pd.DataFrame, logdir: Path | None = None, name: str | None = None
) -> None:
    """
    Plot the chrY z-score distributions per assay, split by sex.

    Draws one subplot per assay (after merging assays with ASSAY_MERGE_DICT). Each
    subplot holds a female and a male box plot of the
    "chrY_zscore_vs_assay_w_track_exclusion" metric. Rows where that metric is the
    string "NA" are dropped first.

    Does not include pval and raw tracks.

    Args:
        zscore_df: The dataframe with z-score data. The columns read are ASSAY, SEX,
            "EpiRR", "Max pred" and "chrY_zscore_vs_assay_w_track_exclusion". The
            frame is not modified.
        logdir: If given, the figure is saved there as svg, png and html.
        name: Base file name used when saving. Defaults to
            "zscore_distributions_per_assay".
    """
    metric_label = "chrY_zscore_vs_assay_w_track_exclusion"

    # Take an explicit copy of the kept rows and reassign the merged assay column.
    # Calling `replace(..., inplace=True)` on a column selection of a filtered frame
    # is a chained in-place update: it warns in pandas 2.x and does nothing under
    # copy-on-write, which would leave the assays unmerged.
    zscore_df = zscore_df[zscore_df[metric_label] != "NA"].copy()
    zscore_df[ASSAY] = zscore_df[ASSAY].replace(ASSAY_MERGE_DICT)

    assay_sizes = zscore_df[ASSAY].value_counts()
    assays = sorted(assay_sizes.index)

    fig = make_subplots(
        rows=1,
        cols=len(assays),
        shared_yaxes=True,
        x_title="Sex z-score distributions per assay",
        y_title="z-score",
        horizontal_spacing=0.02,
        subplot_titles=[
            f"{assay_label} ({assay_sizes[assay_label]})" for assay_label in assays
        ],
    )

    # Each entry is (value in the SEX column, legend group label). Female boxes are
    # drawn first.
    sex_groups = (("female", "Female"), ("male", "Male"))

    for col, assay_label in enumerate(assays, start=1):
        sub_df = zscore_df[zscore_df[ASSAY] == assay_label]

        hovertext = np.array(
            [
                f"{epirr}: z-score={z_score:.3f}, pred={pred:.3f}"
                for epirr, pred, z_score in zip(
                    sub_df["EpiRR"], sub_df["Max pred"], sub_df[metric_label]
                )
            ]
        )
        y_values = sub_df[metric_label].to_numpy()

        for sex, legend_label in sex_groups:
            # Positional boolean mask, applied to both y values and hover text.
            is_sex = (sub_df[SEX] == sex).to_numpy()
            fig.add_trace(
                go.Box(
                    name=assay_label,
                    y=y_values[is_sex],
                    boxmean=True,
                    boxpoints="all",
                    hovertemplate="%{text}",
                    text=hovertext[is_sex],
                    marker=dict(
                        size=2,
                        color=sex_colors[sex],
                        line=dict(width=0.5, color="black"),
                    ),
                    fillcolor=sex_colors[sex],
                    line=dict(width=1, color="black"),
                    showlegend=False,
                    legendgroup=legend_label,
                ),
                row=1,
                col=col,
            )

    # The box traces are hidden from the legend, so add one dummy marker per sex.
    for sex, legend_label in sex_groups:
        fig.add_trace(
            go.Scatter(
                x=[None],
                y=[None],
                mode="markers",
                name=legend_label,
                marker=dict(color=sex_colors[sex], size=20),
                showlegend=True,
                legendgroup=legend_label,
            )
        )

    fig.update_xaxes(showticklabels=False)
    fig.update_yaxes(range=[-1.5, 3])
    fig.update_layout(
        title_text="z-score(mean chrY coverage per file) distribution per assay",
        width=3000,
        height=1000,
    )

    # Save figure
    if logdir:
        if name is None:
            name = "zscore_distributions_per_assay"
        fig.write_image(logdir / f"{name}.svg")
        fig.write_image(logdir / f"{name}.png")
        fig.write_html(logdir / f"{name}.html")

    fig.show()

In [32]:
# No logdir is given, so the figure is only displayed and not saved.
zscore_per_assay(zscore_df=cross_val_analysis)

### Fig 2F - Donor life stage and GP-Age

#### Supp Fig 4C

In [33]:
# IHEC harmonized metadata v1.2 (extended). It is used below for the organ system
# and cell type annotations.
official_metadata_dir = base_data_dir.joinpath("metadata", "official")
meta_v1_2_df = pd.read_csv(
    official_metadata_dir.joinpath("IHEC_metadata_harmonization.v1.2.extended.csv")
)

In [34]:
# The GP-age prediction files are read from this directory; fail fast if it is absent.
gp_age_dir = base_data_dir.joinpath("GP_age")
if not gp_age_dir.exists():
    raise FileNotFoundError(f"Directory {gp_age_dir} does not exist.")

In [35]:
# GP-age predictions. "model30_wNA" is used as the predicted age, and "grouped"
# (presumably the MLP life stage prediction — confirm) as the predicted label.
df_gp_age = pd.read_csv(
    gp_age_dir.joinpath("GPage_prediction_with_annotation.20240405.tsv"), sep="\t"
).assign(
    graph_age=lambda frame: frame["model30_wNA"],
    epiclass_pred=lambda frame: frame["grouped"],
)
print(df_gp_age.shape)

(634, 17)


In [36]:
# Predicted life stage -> life stage bin used in the figures.
gp_age_categories = {
    "adult": "adult",
    "child": "pediatric",
    **dict.fromkeys(("embryonic", "fetal", "newborn"), "perinatal"),
    "unknown": "unknown",
}

# Strip the "_pred" suffix from the predicted labels, then bin them. Labels absent
# from the mapping become NaN.
df_gp_age.loc[:, "graph_age_categories"] = df_gp_age["epiclass_pred"].str.removesuffix(
    "_pred"
)
df_gp_age.loc[:, "graph_age_categories"] = df_gp_age["graph_age_categories"].map(
    gp_age_categories
)

In [37]:
df_gp_age.loc[:, "graph_age_categories"].value_counts()

adult        479
pediatric     54
perinatal     53
unknown       48
Name: graph_age_categories, dtype: int64

In [38]:
organ_type_col = "harmonized_sample_organ_system_order_AnetaMikulasova"

# Attach the organ system and cell type from the v1.2 metadata. That table can hold
# several rows per EpiRR, so only the first merged row per EpiRR is kept.
# Not idempotent: re-running this cell merges again and creates suffixed columns.
df_gp_age = df_gp_age.merge(
    meta_v1_2_df.loc[:, ["EpiRR", organ_type_col, CELL_TYPE]],
    how="inner",
    left_on="epirr_id",
    right_on="EpiRR",
).drop_duplicates(subset=["EpiRR"])

In [39]:
df_gp_age.loc[:, organ_type_col].value_counts(dropna=False)

Immune System                                   338
Reproductive                                     75
Digestive                                        72
Nervous                                          40
Endocrine                                        31
Muscular                                         27
Urinary (Renal) System (or Excretory System)     15
Cardiovascular System (Circulatory System)       11
Integumentary                                    11
Stem Cells and Derived Cell Lines                 5
Respiratory                                       5
Skeletal                                          4
Name: harmonized_sample_organ_system_order_AnetaMikulasova, dtype: int64

In [40]:
# Two plotting groups: "Immune System" versus everything else. Missing organ system
# values also fall into "other".
df_gp_age.loc[:, "tissue_group"] = df_gp_age[organ_type_col].where(
    df_gp_age[organ_type_col] == "Immune System", "other"
)

In [41]:
# Keep the samples that have a known predicted life stage category. Filter on that
# column only: replacing "unknown" frame-wide would also turn any "unknown" entry of
# other columns into NaN.
all_preds = df_gp_age[
    df_gp_age["graph_age_categories"].notna()
    & (df_gp_age["graph_age_categories"] != "unknown")
].copy()

In [42]:
# for df in [df_gp_age, all_preds]:
#     print(df.shape)
#     display(df["tissue_group"].value_counts())
#     display(df["graph_age_categories"].value_counts())
#     display(df["epiclass_pred"].value_counts())

In [43]:
def graph_gp_age(
    df_gp_age: pd.DataFrame, logdir: Path | None = None, name: str | None = None
) -> None:
    """
    Plot the GP age predictions.

    Draws box plots of the GP-age predicted age ("graph_age") for each life stage
    category ("graph_age_categories"). There is one box trace per "tissue_group"
    value, and the traces are grouped side by side. Life stages are ordered
    perinatal, pediatric, adult, and each tick label shows the number of samples
    in that category.

    Args:
        df_gp_age: The dataframe with GP age data. The columns read are
            "tissue_group", "graph_age_categories", "graph_age" and CELL_TYPE
            (hover text). The frame is not modified.
        logdir: If given, the figure is saved there as svg, png and html.
        name: Base file name used when saving. Defaults to "GP_age_predictions".
    """
    age_cat_label = "graph_age_categories"
    label_order = ["perinatal", "pediatric", "adult"]

    # Colors per tissue group. Groups without an entry fall back to gray instead
    # of raising a KeyError.
    tissue_colors = {"blood": "red", "Immune System": "red", "other": "gray"}
    default_tissue_color = "gray"

    fig = go.Figure()
    for tissue_group in df_gp_age["tissue_group"].unique():
        sub_df = df_gp_age[df_gp_age["tissue_group"] == tissue_group]
        fig.add_trace(
            go.Box(
                name=f"{tissue_group} (n={len(sub_df)})",
                x=sub_df[age_cat_label],
                y=sub_df["graph_age"],
                boxmean=True,
                boxpoints="all",
                hovertemplate="%{text}",
                text=[
                    f"{ct}: {age:.3f}"
                    for ct, age in zip(sub_df[CELL_TYPE], sub_df["graph_age"])
                ],
                marker=dict(
                    size=2,
                    color=tissue_colors.get(tissue_group, default_tissue_color),
                ),
                showlegend=True,
            ),
        )

    fig.update_layout(
        title="GP age predictions - Using MLP predicted labels",
        xaxis_title="Life stage",
        yaxis_title="GP-Age : Predicted age",
        width=750,
        height=750,
        boxmode="group",
    )

    # Order the x-axis and add the per-category sample count to each tick label.
    # With an "array" category order, category i of `label_order` sits at position
    # i. The tick positions are therefore derived from `label_order` so the two
    # cannot drift apart.
    axis_labels = [
        f"{age_cat} (n={(df_gp_age[age_cat_label] == age_cat).sum()})"
        for age_cat in label_order
    ]
    fig.update_xaxes(categoryorder="array", categoryarray=label_order)
    fig.update_xaxes(tickvals=list(range(len(label_order))), ticktext=axis_labels)

    # Save figure
    if logdir:
        if name is None:
            name = "GP_age_predictions"
        fig.write_image(logdir / f"{name}.svg")
        fig.write_image(logdir / f"{name}.png")
        fig.write_html(logdir / f"{name}.html")

    fig.show()

In [44]:
# All samples that have a known predicted life stage category.
graph_gp_age(df_gp_age=all_preds)

#### Fig 2F - GP-Age and MLP predictions for unknown harmonized_donor_life_stage

In [45]:
# Samples whose metadata life stage is "unknown" but whose predicted life stage
# category is not "unknown".
only_unknown_df = df_gp_age[
    (df_gp_age[LIFE_STAGE] == "unknown")
    & (df_gp_age["graph_age_categories"] != "unknown")
]

In [46]:
graph_gp_age(df_gp_age=only_unknown_df)

# TODO: find out why the 'epiclass_pred' column still contains some 'unknown' values.

### Fig 2G - Gene Ontology and SHAP values

See profile_bed.ipynb for the creation of the gene ontology files.  
That notebook intersects, for each cell type, the regions with important SHAP values with a gene GFF annotation, and analyzes the resulting genes with g:Profiler.  

In [47]:
# TODO: create a final paper version of profile_bed that works with paper data?

In [48]:
# Cell types whose merged-sampling g:Profiler results are loaded below. The list
# order is also the column order of the final GO table.
selected_cell_types = [
    "hepatocyte",
    "brain",
    "lymphocyte_of_B_lineage",
    "neutrophil",
    "T_cell",
]
# GO term names kept in the final table; the list order is the row order. Each term
# must be present in at least one loaded g:Profiler file, otherwise the check below
# raises a ValueError.
go_terms = [
    "T cell receptor complex",
    "plasma membrane signaling receptor complex",
    "adaptive immune response",
    "receptor complex",
    "secretory granule",
    "secretory vesicle",
    "secretory granule membrane",
    "intracellular vesicle",
    "immunoglobulin complex",
    "immune response",
    "immune system process",
    "homophilic cell adhesion via plasma membrane adhesion molecules",
    "DNA binding",
    "cell-cell adhesion via plasma-membrane adhesion molecules",
    "RNA polymerase II cis-regulatory region sequence-specific DNA binding",
    "blood microparticle",
    "platelet alpha granule lumen",
    "fibrinogen complex",
    "endoplasmic reticulum lumen",
]

In [49]:
# Global SHAP analysis outputs of the cell type classifier ("10fold-oversampling" run).
# NOTE: `hdf5_type` is not set in this part of the notebook; it must already be defined.
cell_type_shap_dir = base_data_dir.joinpath(
    "training_results",
    "dfreeze_v2",
    hdf5_type,
    f"{CELL_TYPE}_1l_3000n",
    "10fold-oversampling",
    "global_shap_analysis",
)
beds_file = cell_type_shap_dir.joinpath("select_beds_top303.tar.gz")
if not beds_file.exists():
    raise FileNotFoundError(f"File {beds_file} does not exist.")

In [50]:
# Read the merged-sampling g:Profiler tables of the selected cell types straight
# from the archive, keyed by their path inside the tarball.
go_dfs: Dict[str, pd.DataFrame] = {}
with tarfile.open(beds_file, "r:gz") as tar:
    for member in tar.getmembers():
        filename = member.name
        if not (filename.endswith("profiler.tsv") and "merge_samplings" in filename):
            continue
        # NOTE(review): cell types are matched as substrings of the path, so e.g.
        # "T_cell" would also match any other cell type name containing it — confirm.
        if not any(cell_type in filename for cell_type in selected_cell_types):
            continue
        with tar.extractfile(member) as f:  # type: ignore
            go_df = pd.read_csv(f, sep="\t", index_col=0)
            go_dfs[member.name] = go_df

In [51]:
# Every requested GO term must appear in at least one loaded g:Profiler table.
all_names = set()
for go_df in go_dfs.values():
    all_names.update(go_df["name"].values)

# Report which terms are missing, not only that some are.
missing_terms = [term for term in go_terms if term not in all_names]
if missing_terms:
    raise ValueError(f"Some terms not found: {missing_terms}")

In [52]:
# Keep only the selected GO terms of each g:Profiler table. Tag the rows with the
# SHAP source (cell type) parsed from the file path, and compute the
# -log10(p-value) shown in the table.
shap_source_pattern = re.compile(
    r".*/merge_samplings_(.*)_features_intersect_gff_gprofiler\.tsv"
)
for name, df in go_dfs.items():
    shap_source_match = shap_source_pattern.match(name)
    if shap_source_match is None:
        # Fail with the offending path instead of an AttributeError on None.
        raise ValueError(f"Cannot parse SHAP source from file name: {name}")
    sub_df = df[df["name"].isin(go_terms)].copy()
    sub_df.loc[:, "shap_source"] = shap_source_match.group(1)
    sub_df.loc[:, "table_val"] = -np.log10(sub_df.loc[:, "p_value"])
    # Replacing the values of existing keys while iterating is safe: the dict size
    # does not change.
    go_dfs[name] = sub_df

In [53]:
concat_df = pd.concat(list(go_dfs.values()))

In [54]:
# GO term x cell type table of -log10(p-value), averaged if a term appears more
# than once for the same source.
table = concat_df.pivot_table(
    index="name", columns="shap_source", values="table_val", aggfunc="mean"
)
# Reindex rather than `.loc[...]`. A selected cell type (or term) with no hit has no
# column (or row) in the pivot, which would make `.loc` raise a KeyError.
# Reindexing inserts it as NaN, which is then shown as 0 like any other absent hit.
table = table.reindex(index=go_terms, columns=selected_cell_types).fillna(0)

In [55]:
# Heat-map rendering of the table with two decimals.
print("GO term -log10(p-value) within SHAP values")
style = table.style.format(precision=2)
style = style.background_gradient(cmap="Blues")
display(style)

GO term -log10(p-value) within SHAP values


shap_source,hepatocyte,brain,lymphocyte_of_B_lineage,neutrophil,T_cell
name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
T cell receptor complex,0.0,0.0,0.0,0.0,95.99
plasma membrane signaling receptor complex,0.0,0.0,0.0,0.0,65.06
adaptive immune response,0.0,0.0,20.41,0.0,55.2
receptor complex,0.0,0.0,0.0,0.0,41.23
secretory granule,0.0,0.0,0.0,20.96,0.0
secretory vesicle,0.0,0.0,0.0,19.11,0.0
secretory granule membrane,0.0,0.0,0.0,13.05,0.0
intracellular vesicle,0.0,0.0,0.0,11.37,0.0
immunoglobulin complex,0.0,0.0,36.45,0.0,0.0
immune response,0.0,0.0,20.96,0.0,33.95


### Fig 2G - Copy Number Variant (CNV) Signatures and SHAP values

See CNV_treatment.ipynb for the creation of the CNV stats.

In [56]:
# TODO: create a final paper version of CNV_treatment that works with paper data?

In [57]:
cnv_dir = base_data_dir.joinpath("CNV")
cnv_intersection_results = cnv_dir.joinpath(
    "important_cancer_features_z_scores_vs_random200.tsv"
)
if not cnv_intersection_results.exists():
    raise FileNotFoundError(f"File {cnv_intersection_results} does not exist.")

cnv_df = pd.read_csv(cnv_intersection_results, sep="\t", index_col=0)
# The file stem is stored on the frame; plot_cnv_zscores parses the number of
# random feature sets from it.
cnv_df.name = cnv_intersection_results.stem

In [58]:
def plot_cnv_zscores(cnv_df: pd.DataFrame, logdir: Path | None = None) -> None:
    """Plot z-scores of top SHAP features vs random feature sets.

    CN signatures (the index of `cnv_df`, e.g. "CN17") are binned into fixed
    groups, and the groups are ordered by decreasing median z-score. Each group is
    drawn as a box plot (mean marked) overlaid with its individual signatures as
    points.

    Args:
        cnv_df: The DataFrame with z-scores, indexed by CN signature name and with
            a "z_score" column. It must carry a `name` attribute containing
            "random<N>", where N is the number of random feature sets. The frame is
            not modified.
        logdir: The output directory to save the plot (png, svg, html). If None,
            the figure is only shown.
    """
    # `name` is a plain attribute set by the caller and is not carried over by
    # `.copy()`, so parse it from the original frame.
    n_beds = int(cnv_df.name.split("random")[1])
    signature_subset_name = "EpiATLAS cancer types"

    # Copy so the caller's frame does not gain a "group" column as a side effect.
    cnv_df = cnv_df.copy()

    # Group label -> signature numbers it covers. Signatures outside all ranges are
    # labelled "Other".
    cn_group_ranges = {
        "CN1-CN3": range(1, 4),
        "CN9-CN12": range(9, 13),
        "CN13-CN16": range(13, 17),
        "CN17": range(17, 18),
        "CN18-CN21": range(18, 22),
        "CN4-CN8": range(4, 9),
    }
    cnv_df["group"] = "Other"
    for group_name, cn_numbers in cn_group_ranges.items():
        signatures = [f"CN{i}" for i in cn_numbers]
        cnv_df.loc[cnv_df.index.isin(signatures), "group"] = group_name

    # Order groups by decreasing median z-score.
    group_medians = (
        cnv_df.groupby("group")["z_score"].median().sort_values(ascending=False)
    )
    sorted_CN_names = group_medians.index.tolist()

    fig = go.Figure()
    for group in sorted_CN_names:
        group_data = cnv_df[cnv_df["group"] == group]
        # Slightly larger points for the single-signature CN17 group.
        marker_size = 4 if group != "CN17" else 6

        # Box plot without points; the points are drawn by the scatter trace below.
        fig.add_trace(
            go.Box(
                y=group_data["z_score"],
                name=group,
                boxmean=True,
                boxpoints=False,
                line=dict(color="black"),
                fillcolor="rgba(255,255,255,0)",
                showlegend=False,
            )
        )

        # One point per signature, with the CN name shown on hover.
        fig.add_trace(
            go.Scatter(
                x=[group] * len(group_data),
                y=group_data["z_score"],
                mode="markers",
                marker=dict(
                    color="red",
                    size=marker_size,
                ),
                name=group,
                showlegend=False,
                text=group_data.index,
                hoverinfo="text+y",
            )
        )

    # NOTE(review): the feature count "N=336" in the title is hard-coded; confirm it
    # matches the feature set behind `cnv_df`.
    fig.update_layout(
        title={
            "text": f"Z-scores of top SHAP features (N=336) vs {n_beds} random feature sets of same size<br>on {signature_subset_name}"
        },
        xaxis_title="Cancer Type Group",
        yaxis_title="Z-score",
    )
    fig.add_hline(y=0, line_color="grey", line_width=0.8)

    # Show and save the figure
    if logdir:
        name = "important_cancer_features_z_scores_boxplot"
        fig.write_image(logdir / f"{name}.png")
        fig.write_image(logdir / f"{name}.svg")
        fig.write_html(logdir / f"{name}.html")

    fig.show()

In [59]:
# No logdir is given, so the figure is only displayed and not saved.
plot_cnv_zscores(cnv_df=cnv_df)