In [None]:
"""Workbook to analyse ENCODE predictions.

Sections:
- CELL TYPE: compare cell type predictions on ENCODE files with the harmonized
  sample ontology (intermediate) terms mapped from ENCODE biosample metadata,
  restricted to non-core assays.
- ASSAY: evaluate assay predictions of the 7c, 11c and 13c models on ENCODE files,
  for core assays and for non-core assays grouped by assay category.
"""
# pylint: disable=import-error, redefined-outer-name, use-dict-literal, too-many-lines, unused-import, unused-argument, too-many-branches, pointless-statement

In [None]:
# Automatically reload edited project modules (e.g. paper_utilities) before each cell.
%load_ext autoreload
%autoreload 2

In [None]:
from __future__ import annotations

from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import seaborn as sns
from IPython.display import display
from sklearn.metrics import confusion_matrix

from epi_ml.utils.notebooks.paper.paper_utilities import (
    ASSAY,
    ASSAY_MERGE_DICT,
    ASSAY_ORDER,
    CELL_TYPE,
    IHECColorMap,
    MetadataHandler,
    SplitResultsHandler,
    add_second_highest_prediction,
    display_perc,
)

# from plotly.subplots import make_subplots

In [None]:
base_dir = Path.home() / "Projects/epiclass/output/paper"
paper_dir = base_dir
base_data_dir, base_fig_dir = (base_dir / subdir for subdir in ("data", "figures"))

# Fail early if the figure directory is missing.
if not base_fig_dir.exists():
    raise FileNotFoundError(f"Directory {base_fig_dir} does not exist.")

In [None]:
# Project helpers from paper_utilities; the metadata handler is built from paper_dir.
# NOTE(review): neither handler is used in the cells shown here — confirm they are still needed.
metadata_handler = MetadataHandler(paper_dir)
split_results_handler = SplitResultsHandler()

## CELL TYPE

#### Getting GO info

In [None]:
encode_metadata_dir = base_data_dir / "metadata" / "encode"

# CURIE code -> term -> harmonized cell type table, read without a header row.
curie_def_df = pd.read_csv(
    encode_metadata_dir / "EpiAtlas_list-curie_term_HSOI.tsv",
    names=["code", "term", CELL_TYPE],
    sep="\t",
)

# ENCODE files with their biosample ontology terms and assay.
encode_ontology_df = pd.read_csv(
    encode_metadata_dir / "encode_ontol+assay.tsv", sep="\t"
)

In [None]:
# (rows, columns) of the ENCODE ontology + assay table.
encode_ontology_df.shape

In [None]:
# Attach the harmonized cell type to each ENCODE row through its biosample term id;
# the mapping's own "code"/"term" columns are not needed afterwards.
metadata_df = pd.merge(
    encode_ontology_df,
    curie_def_df,
    how="left",
    left_on="Biosample term id",
    right_on="code",
).drop(columns=["code", "term"])

In [None]:
# Left merge keeps every ENCODE row; more rows than before would mean duplicated codes.
display(metadata_df.shape)

In [None]:
# Quick look at the merged metadata.
metadata_df.head()

In [None]:
counts = metadata_df[CELL_TYPE].value_counts(dropna=False)
# Fraction of rows per harmonized cell type (NaN = no mapping found).
display(counts.div(counts.sum()))

#### Missing harmonized_sample_ontology_intermediate details

In [None]:
biosample_cols = ["Biosample term id", "Biosample term name"]

# Rows whose biosample term could not be mapped to a harmonized cell type.
missing_cell_type = metadata_df.loc[metadata_df[CELL_TYPE].isna()]
print(missing_cell_type.shape)

missing_count = missing_cell_type[biosample_cols].value_counts()
display(missing_count.shape)

# Share (%) of each unmapped biosample term among the unmapped rows.
missing_perc = missing_count.div(missing_count.sum()).mul(100)
with pd.option_context(
    "display.float_format",
    "{:.2f}".format,  # pylint: disable=consider-using-f-string
    "display.max_rows",
    None,
):
    display(missing_perc)

In [None]:
# Unmapped biosample term names that mention T cells or B cells.
missing_term_names = missing_cell_type["Biosample term name"].unique()
t_cell_types = [term for term in missing_term_names if "T cell" in term]
b_cell_types = [term for term in missing_term_names if "B cell" in term]

In [None]:
is_t_cell = missing_cell_type["Biosample term name"].isin(t_cell_types)
t_cell_count = missing_cell_type.loc[is_t_cell, biosample_cols].value_counts()
display(t_cell_count, t_cell_count.sum())

In [None]:
is_b_cell = missing_cell_type["Biosample term name"].isin(b_cell_types)
b_cell_count = missing_cell_type.loc[is_b_cell, biosample_cols].value_counts()
display(b_cell_count, b_cell_count.sum())

# Share of all unmapped rows that are T or B cell terms.
n_t_and_b = t_cell_count.sum() + b_cell_count.sum()
perc_missing = n_t_and_b / missing_cell_type.shape[0] * 100
print(f"t+b cells, percentage of missing cell types: {perc_missing:.2f}%")

#### Match predictions from various trainings with GO info

In [None]:
# Cell type (harmonized_sample_ontology_intermediate) training results; one sub-folder per run.
pred_folder = base_data_dir.joinpath(
    "training_results",
    "dfreeze_v2",
    "hg38_100kb_all_none",
    "harmonized_sample_ontology_intermediate_1l_3000n",
    "complete-no_valid-oversampling",
)

In [None]:
# Lowercase assay names before comparing them with ASSAY_ORDER (assumed lowercase).
metadata_df["Assay"] = metadata_df["Assay"].str.lower()

# Keep rows with both a harmonized cell type and an assay, then only non-core assays.
df = metadata_df.dropna(subset=[CELL_TYPE, "Assay"])
non_core_metadata_df = df.loc[~df["Assay"].isin(ASSAY_ORDER)]

In [None]:
# Columns available in the non-core subset.
non_core_metadata_df.columns

In [None]:
# Disabled: export raw assay counts over all ENCODE rows to ~/downloads as CSV.
# counts = metadata_df["Assay"].value_counts(dropna=False)
# print(len(counts))
# counts.to_csv(
#     path_or_buf=Path().home() / "downloads" / "encode_assay_counts.csv",
#     sep=",",
#     header=True,
# )

In [None]:
# Harmonized cell type distribution among non-core assay rows.
display(non_core_metadata_df[CELL_TYPE].value_counts(dropna=False))

In [None]:
# The 16 cell type classes kept for evaluation (also the confusion matrix labels below).
accepted_ct = [
    "T cell",
    "neutrophil",
    "brain",
    "monocyte",
    "lymphocyte of B lineage",
    "myeloid cell",
    "venous blood",
    "macrophage",
    "mesoderm-derived structure",
    "endoderm-derived structure",
    "colon",
    "connective tissue cell",
    "hepatocyte",
    "mammary gland epithelial cell",
    "muscle organ",
    "extraembryonic cell",
]
print(non_core_metadata_df.shape)
in_accepted_ct = non_core_metadata_df[CELL_TYPE].isin(accepted_ct)
metadata_16ct = non_core_metadata_df.loc[in_accepted_ct]
print(metadata_16ct.shape)

In [None]:
# Non-core assays remaining for the 16 accepted cell types.
display(metadata_16ct["Assay"].value_counts(dropna=False))

In [None]:
pred_dfs_dict = {}
for run_dir in pred_folder.glob("*"):
    if not run_dir.is_dir():
        print(f"Skipping {run_dir}")
        continue

    # Each run folder must contain exactly one predictions CSV.
    csv_files = list(run_dir.glob("predictions/*.csv"))
    if not csv_files:
        print(f"No prediction file found in {run_dir}")
        continue
    if len(csv_files) > 1:
        print(f"More than one prediction file found in {run_dir}")
        continue

    pred_df = pd.read_csv(csv_files[0])
    name = run_dir.name.replace("complete_no_valid_oversample_", "")
    pred_dfs_dict[name] = pred_df

In [None]:
groupby_cols = ["Assay", CELL_TYPE, "Predicted class", "correct_pred"]
ct_min_pred = 0.8  # confidence threshold for the filtered accuracy

for name, pred_df in sorted(pred_dfs_dict.items()):
    print(name)
    # Inner merge: keep only predicted files belonging to the 16 accepted cell types.
    pred_w_ct = pred_df.merge(
        metadata_16ct, left_on="md5sum", right_on="ENC_ID", how="inner"
    )
    pred_w_ct["correct_pred"] = pred_w_ct["Predicted class"] == pred_w_ct[CELL_TYPE]
    N = pred_w_ct.shape[0]

    # Per (assay, cell type, prediction, correctness) breakdown, for the optional display below.
    counts = pred_w_ct.groupby(groupby_cols).size().sort_values(ascending=False)

    # Sum the boolean column directly: selecting `counts.loc[:, :, :, True]` raises
    # a KeyError when a run has no correct prediction at all.
    total_correct = int(pred_w_ct["correct_pred"].sum())
    perc = total_correct / N if N else float("nan")
    print(f"Acc (pred>0.0) {total_correct}/{N} ({perc:.2%})")

    # Same accuracy restricted to predictions with "Max pred" > ct_min_pred.
    pred_w_ct_filtered = pred_w_ct[pred_w_ct["Max pred"] > ct_min_pred]
    n_filtered = pred_w_ct_filtered.shape[0]
    total_correct_filtered = int(pred_w_ct_filtered["correct_pred"].sum())
    perc_filtered = total_correct_filtered / n_filtered if n_filtered else float("nan")
    print(
        f"Acc (pred>{ct_min_pred}): {total_correct_filtered}/{n_filtered} ({perc_filtered:.2%})"
    )
    diff = N - n_filtered
    ignored_frac = diff / N if N else float("nan")
    print(f"Samples ignored at {ct_min_pred}: {diff} ({ignored_frac:.2%})\n")

    # Uncomment the following lines if you want to display additional information
    # if "assay" in name.lower():
    #     with pd.option_context(
    #         "display.float_format",
    #         "{:.3f}".format,
    #         "display.max_rows",
    #         None,
    #     ):
    #         values_count = pred_w_ct["Assay"].value_counts()
    #         # display(values_count)
    #         display(values_count / values_count.sum())
    #         display(counts)

In [None]:
def sns_confusion_matrix(pred_w_ct: pd.DataFrame, name: str | None = None):
    """Plot a row-normalized confusion matrix of cell type predictions (seaborn).

    Only rows with a known cell type, an assay outside ASSAY_ORDER (compared in
    lowercase) and "Max pred" > 0.8 are used. Both axes follow `accepted_ct`:
    rows are the actual cell types, labelled with their sample count, and columns
    are the predicted ones. Each row is normalized to sum to 1. Cell types without
    samples give an all-NaN (blank) row. The accuracy over the matrix is printed
    and the figure is shown.

    Args:
        pred_w_ct: Predictions merged with metadata, with "Assay", CELL_TYPE,
            "Predicted class" and "Max pred" columns. The caller's frame is not
            modified.
        name: Optional label (e.g. the training run name) appended to the title.
    """
    # Work on a new frame so the caller's DataFrame is left untouched.
    pred_w_ct = pred_w_ct.assign(Assay=pred_w_ct["Assay"].str.lower())
    pred_w_ct = pred_w_ct.dropna(subset=[CELL_TYPE])  # drop rows with missing cell type
    pred_w_ct = pred_w_ct[~pred_w_ct["Assay"].isin(ASSAY_ORDER)]  # non-core assays only
    pred_w_ct = pred_w_ct[pred_w_ct["Max pred"] > 0.8]  # confident predictions only

    # sklearn's confusion_matrix raises when none of `labels` occurs in y_true.
    if not pred_w_ct[CELL_TYPE].isin(accepted_ct).any():
        print("No confident prediction left for the accepted cell types; nothing to plot.")
        return

    # Number of real samples per (actual) cell type
    real_samples_count = pred_w_ct[CELL_TYPE].value_counts()

    cm = confusion_matrix(
        pred_w_ct[CELL_TYPE], pred_w_ct["Predicted class"], labels=accepted_ct
    )

    # Row-normalize. Rows without samples give 0/0 = NaN (drawn blank) without
    # emitting RuntimeWarnings.
    with np.errstate(divide="ignore", invalid="ignore"):
        cm_percentage = cm.astype("float") / cm.sum(axis=1, keepdims=True)

    # Actual cell types are the rows, so the sample counts label the y-axis.
    yticklabels_w_count = [
        f"{ct}\n(n={real_samples_count.get(ct, 0)})" for ct in accepted_ct
    ]

    fig, ax = plt.subplots(figsize=(24, 20))
    sns.heatmap(
        cm_percentage,
        annot=True,
        fmt=".2%",
        cmap="Blues",
        xticklabels=accepted_ct,
        yticklabels=yticklabels_w_count,
        vmin=0,
        vmax=1,
        annot_kws={"size": 10},
        cbar_kws={"shrink": 0.8},
        ax=ax,
    )

    title = (
        "Confusion Matrix (%)" if name is None else f"Confusion Matrix (%) for {name}"
    )
    ax.set_title(title, fontsize=20)
    ax.set_xlabel("Predicted", fontsize=16)
    ax.set_ylabel("Actual", fontsize=16)
    plt.setp(ax.get_xticklabels(), fontsize=10, rotation=90, ha="center")
    plt.setp(ax.get_yticklabels(), fontsize=12, rotation=0)

    # tight_layout makes room for the long tick labels (a preceding
    # subplots_adjust would just be overridden by it).
    fig.tight_layout()

    n_correct = np.trace(cm)
    n_total = np.sum(cm)
    accuracy = n_correct / n_total if n_total else float("nan")
    print(f"Accuracy: {accuracy:.2%} ({n_correct} / {n_total})")

    plt.show()

In [None]:
# Attach the full metadata to each run's predictions (plotting is currently disabled).
for name, pred_df in sorted(pred_dfs_dict.items()):
    print(name)
    pred_df_w_ct = pd.merge(
        pred_df, metadata_df, how="left", left_on="md5sum", right_on="ENC_ID"
    )
    # sns_confusion_matrix(pred_df_w_ct)

## ASSAY

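Download notes: commands used to fetch the assay training results from the `narval` cluster with `rsync`, then strip `.list` from the prediction CSV file names.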
~~~bash
paper_dir="/home/local/USHERBROOKE/rabj2301/Projects/epiclass/output/paper/data/training_results/dfreeze_v2/hg38_100kb_all_none/assay_epiclass_1l_3000n"
cd $paper_dir
base_path="/lustre06/project/6007515/rabyj/epiclass-project/output/epiclass-logs/epiatlas-dfreeze-v2.1/hg38_100kb_all_none/assay_epiclass_1l_3000n"
rsync -avR --exclude "*/EpiLaP/" --exclude "*.png" --exclude "*confusion*" --exclude "*.md5" narval:${base_path}/./*c/complete_no_valid_oversample .

paper_dir="/home/local/USHERBROOKE/rabj2301/Projects/epiclass/output/paper/data/training_results/dfreeze_v2"
cd $paper_dir
base_path="/lustre06/project/6007515/rabyj/epiclass-project/output/epiclass-logs/epiatlas-dfreeze-v2.1"
rsync -avR --exclude "*/EpiLaP/" --exclude "*.png" --exclude "*confusion*" --exclude "*.md5" narval:${base_path}/./hg38_100kb_all_none_w_encode_noncore/assay_epiclass_1l_3000n/complete_no_valid_oversample-0 .

find -type f -name "*.list*.csv" -print0 | xargs -0 rename 's/\.list//g'
~~~

In [None]:
data_dir = base_data_dir / "training_results" / "dfreeze_v2"

# 7c and 11c runs share one training folder. The 13c run lives under the
# `_w_encode_noncore` variant of the dataset folder.
assay_base_folder = data_dir / f"hg38_100kb_all_none/{ASSAY}_1l_3000n"
assay7_folder = assay_base_folder / "7c" / "complete_no_valid_oversample"
assay11_folder = assay_base_folder / "11c" / "complete_no_valid_oversample"
assay13_folder = (
    data_dir
    / f"hg38_100kb_all_none_w_encode_noncore/{ASSAY}_1l_3000n/complete_no_valid_oversample-0"
)

In [None]:
# ENCODE <-> IHEC key table; its "assay_epiclass" column is the assay ground truth below.
encode_metadata_path = encode_metadata_dir.joinpath("ENCODE_IHEC_keys.tsv")
core_metadata_df = pd.read_csv(encode_metadata_path, sep="\t")

In [None]:
# First rows and shape of the ENCODE <-> IHEC key table.
display(core_metadata_df.head())
print(core_metadata_df.shape)

In [None]:
# Distribution of the assay labels used as "True class" below.
core_metadata_df["assay_epiclass"].value_counts(dropna=False)

In [None]:
pred_dfs_dict = {}
for name, folder in zip(
    ["7c", "11c", "13c"], [assay7_folder, assay11_folder, assay13_folder]
):
    if not folder.exists():
        print(f"Folder {folder} does not exist.")
        continue

    # Dedicated name: rebinding `pred_folder` here would overwrite the cell type
    # prediction folder defined for the CELL TYPE section.
    encode_pred_folder = folder / "predictions" / "encode"
    if not encode_pred_folder.exists():
        print(f"Folder {encode_pred_folder} does not exist.")
        continue

    # Exactly one ENCODE predictions CSV is expected per model.
    pred_files = list(encode_pred_folder.glob("*.csv"))
    if len(pred_files) != 1:
        print(f"Found {len(pred_files)} files in {encode_pred_folder}.")
        continue

    pred_df = pd.read_csv(pred_files[0], sep=",")
    # The "Same?" column is optional; drop it when present.
    pred_df = pred_df.drop(columns=["Same?"], errors="ignore")

    # Add assay metadata (left merge keeps predictions without metadata).
    pred_df = pred_df.merge(
        core_metadata_df, left_on="md5sum", right_on="ENC_ID", how="left"
    )

    pred_df["True class"] = pred_df["assay_epiclass"]
    pred_dfs_dict[name] = pred_df

### Core7 preds

In [None]:
# Destination of the (optional) CSV exports of this cell. Not chain-assigned to
# `data_dir`, which holds the training results folder.
output_dir = base_data_dir / "training_results" / "encode_predictions"
core_min_pred = 0.6  # confidence threshold for the filtered accuracy

for name, df in pred_dfs_dict.items():
    print(name)

    # Only consider non-EpiAtlas files already labeled with a core7 assay.
    df = df[df[ASSAY].isin(ASSAY_ORDER) & df["is_EpiAtlas_EpiRR"].isna()]

    # df.to_csv(output_dir / f"encode_only-core-{name}_predictions.csv", index=False)

    # Accuracy over all predictions
    correct_pred = df["Predicted class"] == df["True class"]
    total_correct = int(correct_pred.sum())
    total = df.shape[0]
    perc = total_correct / total if total else float("nan")
    print(f"Acc (pred>=0.0) {total_correct}/{total} ({perc:.2%})")

    # Accuracy over predictions with "Max pred" >= core_min_pred
    df_filtered = df[df["Max pred"] >= core_min_pred]
    correct_pred_filtered = df_filtered["Predicted class"] == df_filtered["True class"]
    total_correct_filtered = int(correct_pred_filtered.sum())
    total_filtered = df_filtered.shape[0]
    perc_filtered = (
        total_correct_filtered / total_filtered if total_filtered else float("nan")
    )
    print(
        f"Acc (pred>={core_min_pred}): {total_correct_filtered}/{total_filtered} ({perc_filtered:.2%})"
    )

    # Most frequent (true, predicted) confusions among confident wrong predictions.
    df_filtered_wrong = df_filtered[~correct_pred_filtered]
    groupby = (
        df_filtered_wrong.groupby(["True class", "Predicted class"])
        .size()
        .sort_values(ascending=False)
    )
    display("Mislabels:", groupby)

    # df_filtered_wrong.to_csv(
    #     output_dir / f"encode_only_mislabels_minPred0.6_{name}.csv", index=False
    # )

### non-core 7c preds

In [None]:
# 7c predictions restricted to non-core assays (ENCODE "Assay" not in ASSAY_ORDER).
name = "7c"
df = (
    pred_dfs_dict[name]
    .merge(metadata_df, left_on="md5sum", right_on="ENC_ID", how="left")
    .loc[lambda frame: ~frame["Assay"].isin(ASSAY_ORDER)]
)

In [None]:
assay_counts = df["Assay"].value_counts(dropna=False)
print(df.columns)
display(assay_counts)

In [None]:
# Output folder for the non-core prediction tables and figures written in later
# cells. Not chain-assigned to `data_dir`, which holds the training results folder.
output_dir = (
    base_data_dir / "training_results" / "predictions" / "encode" / "assay_epiclass"
)
# Later cells write CSVs and PNGs here; make sure the folder exists.
output_dir.mkdir(parents=True, exist_ok=True)

# Counts of (predicted class, actual non-core assay) at a few confidence thresholds.
for min_pred in [0, 0.6, 0.8]:
    df_filtered = df[df["Max pred"] >= min_pred]
    groupby = (
        df_filtered.groupby(["Predicted class", "Assay"])
        .size()
        .reset_index(name="Count")
        .sort_values(["Predicted class", "Count"], ascending=[True, False])
        .set_index(["Predicted class", "Assay"])["Count"]
    )
    # groupby.to_csv(
    #     output_dir / f"encode_non-core_{name}_predictions_minPred{min_pred}.csv"
    # )

In [None]:
encode_metadata_dir = base_data_dir / "metadata" / "encode"
non_core_categories_path = encode_metadata_dir.joinpath(
    "non-core_encode_assay_counts_v2.tsv"
)
# Fail early with a clear message rather than a pandas read error.
if not non_core_categories_path.exists():
    raise FileNotFoundError(f"File {non_core_categories_path} does not exist.")

non_core_categories_df = pd.read_csv(non_core_categories_path, sep="\t")
print(non_core_categories_df.columns)

In [None]:
# Attach each non-core assay's category ("assay" is the category table's key column).
assay_category_lookup = non_core_categories_df[["assay", "assay_category"]]
df_w_cats = pd.merge(
    df, assay_category_lookup, how="left", left_on="Assay", right_on="assay"
)

In [None]:
# "assay" duplicates the "Assay" merge key. Rebind instead of dropping in place,
# and ignore a missing column, so that re-running this cell does not raise KeyError.
df_w_cats = df_w_cats.drop(columns=["assay"], errors="ignore")

In [None]:
# Non-core assay categories behind each predicted class (table display is disabled).
min_pred = 0.6
for predicted_class, group in df_w_cats.groupby("Predicted class"):
    print(predicted_class, group.shape[0])
    confident = group[group["Max pred"] >= min_pred]
    print(f"min_pred={min_pred}: {confident.shape[0]} samples left")
    groupby = (
        confident.groupby(["assay_category", "Assay"])
        .size()
        .reset_index(name="Count")
        .sort_values(["assay_category", "Count"], ascending=[True, False])
        .set_index(["assay_category", "Assay"])["Count"]
    )
    with pd.option_context("display.max_rows", None):
        # display(groupby)
        pass

In [None]:
def create_non_core_preds_df(df: pd.DataFrame, min_pred: float = 0.6):
    """Count, per assay, how many confident predictions fall in each predicted class.

    Args:
        df: Predictions with "Assay", "assay_category", "Predicted class" and
            "Max pred" columns.
        min_pred: Only rows with "Max pred" >= min_pred are counted.

    Returns:
        Integer DataFrame with one row per assay (index) and one column per
        predicted class (absent combinations are 0), plus an "Assay category"
        column.
    """
    category_of_assay = dict(zip(df["Assay"], df["assay_category"]))

    def _class_counts(assay_rows: pd.DataFrame) -> dict:
        # Predicted class -> count among confident rows, largest count first.
        confident = assay_rows[assay_rows["Max pred"] >= min_pred]
        per_class = (
            confident.groupby(["Predicted class"])
            .size()
            .reset_index(name="Count")  # type: ignore
            .sort_values(["Count"], ascending=False)
        )
        return dict(zip(per_class["Predicted class"], per_class["Count"]))

    counts_per_assay = {
        assay: _class_counts(assay_rows) for assay, assay_rows in df.groupby("Assay")
    }

    # Built with assays as columns, then transposed so each assay is a row.
    table = pd.DataFrame(counts_per_assay).fillna(0).astype(int).T
    table["Assay category"] = table.index.map(category_of_assay)
    return table

In [None]:
# One threshold drives both the filter and the file name, so the name always
# matches the threshold actually applied, whatever `min_pred` earlier cells left behind.
min_pred = 0.6
predicted_classes_df = create_non_core_preds_df(df_w_cats, min_pred=min_pred)
predicted_classes_df.to_csv(
    output_dir / f"encode_non-core_7c_predictions_per_assay_minPred{min_pred:.2f}.csv"
)

In [None]:
def create_structured_dataframe(df_w_cats, min_preds=None):
    """Tabulate the assay category composition of each predicted class across thresholds.

    For every predicted class and every minimum-prediction threshold, keep the rows
    whose "Max pred" is >= the threshold. Then record, for each assay category
    (NaN included): its count, its percentage of the kept rows (rounded to 2
    decimals) and the number of kept rows.

    Args:
        df_w_cats: Predictions with "Predicted class", "Max pred" and
            "assay_category" columns.
        min_preds: Thresholds to evaluate. Defaults to 0.00, 0.05, ..., 0.95
            followed by 0.99.

    Returns:
        DataFrame indexed by ("Predicted class", "Min pred", "assay_category") with
        "Percentage", "Count" and "Total samples" columns. A threshold that keeps no
        row contributes no entry.
    """
    if min_preds is None:
        # np.arange accumulates float error (e.g. 0.15000000000000002). Rounding keeps
        # the "Min pred" index made of clean values that can be looked up exactly.
        min_preds = list(np.round(np.arange(0, 1, 0.05), 2)) + [0.99]

    records = []
    for predicted_class, group in df_w_cats.groupby("Predicted class"):
        for min_pred in min_preds:
            kept = group[group["Max pred"] >= min_pred]
            counts = kept["assay_category"].value_counts(dropna=False)
            total = counts.sum()
            percentages = (counts / total * 100).round(2)

            # Walk both Series positionally: looking up a NaN category by label is fragile.
            for (assay_category, count), percentage in zip(counts.items(), percentages):
                records.append(
                    {
                        "Predicted class": predicted_class,
                        "Min pred": min_pred,
                        "assay_category": assay_category,
                        "Percentage": percentage,
                        "Count": count,
                        "Total samples": total,
                    }
                )

    # Explicit columns so that an empty result can still be indexed.
    df_structured = pd.DataFrame(
        records,
        columns=[
            "Predicted class",
            "Min pred",
            "assay_category",
            "Percentage",
            "Count",
            "Total samples",
        ],
    )
    return df_structured.set_index(["Predicted class", "Min pred", "assay_category"])

In [None]:
# Assay category composition of each predicted class, over prediction thresholds.
assay_category_df = create_structured_dataframe(df_w_cats=df_w_cats)

# assay_category_df.to_csv(
#     output_dir / "encode_non-core_7c_predictions_assay_category.csv"
# )

In [None]:
def create_assay_category_graphs(df, output_dir: Path):
    """Plot, per predicted class, the assay category composition versus "Min pred".

    Each predicted class gets one plotly figure. It has one line per assay category
    (its "Percentage" at each threshold) plus a dashed line giving the share of the
    threshold-0 samples still kept. Each figure is shown and written as a PNG in
    `output_dir`.

    Args:
        df: Frame indexed by ("Predicted class", "Min pred", "assay_category") with
            "Percentage" and "Total samples" columns, as produced by
            `create_structured_dataframe`.
        output_dir: Directory where the PNG files are written (created if missing).
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    # "assay_category" is an index level of `df`, not a column.
    all_assay_categories = df.index.get_level_values("assay_category").unique()
    palette = px.colors.qualitative.Safe
    # Cycle through the palette so more categories than colors do not raise IndexError.
    graph_colors = {
        cat: palette[i % len(palette)] for i, cat in enumerate(all_assay_categories)
    }

    predicted_classes = df.index.get_level_values("Predicted class").unique()

    # One figure per predicted class
    for predicted_class in predicted_classes:
        df_class = df.loc[predicted_class]

        assay_categories = df_class.index.get_level_values("assay_category").unique()

        # Reference size for the "Samples Conserved" curve
        total_samples_at_zero = df_class.xs(0, level="Min pred")["Total samples"].iloc[0]

        fig = go.Figure()

        for assay_category in assay_categories:
            df_assay = df_class.xs(assay_category, level="assay_category")

            fig.add_trace(
                go.Scatter(
                    x=df_assay.index,
                    y=df_assay["Percentage"],
                    mode="lines+markers",
                    name=assay_category,
                    # .get: a NaN category may not hash back to its dict key
                    marker=dict(color=graph_colors.get(assay_category)),
                )
            )

        conserved_percentages = (
            df_class.groupby("Min pred")["Total samples"].first()
            / total_samples_at_zero
            * 100
        )
        fig.add_trace(
            go.Scatter(
                x=conserved_percentages.index,
                y=conserved_percentages.values,
                mode="lines+markers",
                name="Samples Conserved",
                line=dict(dash="dash", color="black"),
            )
        )

        fig.update_layout(
            title=f"Composition for Predicted Class: {predicted_class}",
            xaxis_title="Min pred",
            yaxis_title="Percentage Composition",
            legend_title="Assay Category",
            hovermode="x unified",
        )

        fig.update_xaxes(range=[-0.01, 1.01])
        fig.update_yaxes(range=[0, 100])

        fig.show()

        fig.write_image(
            output_dir
            / f"encode_non-core_7c_predictions_assay_category_{predicted_class}.png"
        )

In [None]:
# Plot the composition curves of `assay_category_df` (computed above) into output_dir.
create_assay_category_graphs(assay_category_df, output_dir)