In [None]:
"""Explore metadata distribution of correct/wrong predictions"""
# pylint: disable=line-too-long, redefined-outer-name, import-error, pointless-statement, use-dict-literal, expression-not-assigned, unused-import, too-many-lines, unreachable
from __future__ import annotations

from pathlib import Path
from typing import Dict, List

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from IPython.display import display

In [None]:
# Metadata column names used to slice the merged results table.

# Data description columns
ASSAY = "assay_epiclass"
TRACK = "track_type"

# Harmonized donor metadata
SEX = "harmonized_donor_sex"
LIFE_STAGE = "harmonized_donor_life_stage"

# Harmonized sample / biomaterial metadata
BIOMATERIAL_TYPE = "harmonized_biomaterial_type"
CELL_TYPE = "harmonized_sample_ontology_intermediate"
CANCER = "harmonized_sample_cancer_high"
DISEASE = "harmonized_sample_disease_high"

In [None]:
# Merged prediction results (indexed by dataset md5sum), including chrY z-score columns.
results_path = Path.home().joinpath(
    "projects",
    "epilap",
    "output",
    "logs",
    "epiatlas-dfreeze-v2.1",
    "merged_results",
    "epiatlas",
    "with_split_nb",
    "merged_pred_results_all_2.1_chrY_zscores.csv",
)
results_df = pd.read_csv(results_path, header=0, index_col="md5sum", low_memory=False)

In [None]:
(len(results_df), results_df.columns.size)  # (n rows, n columns)

In [None]:
# Metadata category whose predictions are analysed in this notebook.
analysis_target = SEX

# List every result column tied to the analysis target (predictions, true labels, ...).
# Use `analysis_target` rather than a hard-coded constant so changing the target above
# also changes this listing.
print([column for column in results_df.columns if analysis_target in column])

In [None]:
classifier_name = "harmonized_donor_sex_1l_3000n_w-mixed_10fold-oversample"

# Per-classifier result columns: predicted label, true label, and match flag.
# NOTE: the "classifer" spelling is kept because later cells refer to these names.
classifier_preds_colname, classifer_correct_colname, classifer_same_colname = (
    f"{prefix} {classifier_name}" for prefix in ("Predicted class", "True class", "Same?")
)

# Keep only the datasets that received a prediction from this classifier.
classifier_df = results_df.loc[results_df[classifier_preds_colname].notna()]
print(results_df.shape, classifier_df.shape)

In [None]:
# Class balance of predictions, true labels, and correct/incorrect flags, shown in that order.
display(
    *(
        classifier_df[colname].value_counts()
        for colname in (
            classifier_preds_colname,
            classifer_correct_colname,
            classifer_same_colname,
        )
    )
)

In [None]:
def create_error_count_df(
    groupby_cols: List[str],
    classifier_df: pd.DataFrame,
    same_colname: str | None = None,
) -> pd.DataFrame:
    """Count classifier errors per metadata group.

    Args:
        groupby_cols: Columns defining the groups. Rows with a missing value in
            any of these columns are left out (default `groupby` behaviour).
        classifier_df: Prediction results to summarize.
        same_colname: Column flagging whether prediction and true class agree;
            rows where it equals False are counted as errors. Defaults to the
            notebook-level `classifer_same_colname`.

    Returns:
        One row per group: the `groupby_cols` values followed by
        "error rate", "n error" and "n total".

    Raises:
        ValueError: If the group totals do not add up to the number of rows
            whose group keys are all present.
    """
    group_keys = list(groupby_cols)  # accept tuples or any other sequence of names
    if same_colname is None:
        # Backward-compatible fallback on the notebook global defined with the classifier columns.
        same_colname = classifer_same_colname

    # `.eq(False)` keeps the original `== False` semantics: missing flags are not errors.
    is_error = classifier_df[same_colname].eq(False)
    grouped = is_error.groupby([classifier_df[col] for col in group_keys])
    n_error = grouped.sum()
    n_total = grouped.size()

    # Vectorized: works for any key dtype (int, str, ...) and any number of group keys.
    error_count_df = pd.DataFrame(
        {"error rate": n_error / n_total, "n error": n_error, "n total": n_total}
    ).reset_index()

    # Every row with complete group keys must land in exactly one group.
    n_complete_rows = int(classifier_df[group_keys].notna().all(axis=1).sum())
    if error_count_df["n total"].sum() != n_complete_rows:
        raise ValueError("Error: n total does not match global count")

    return error_count_df

In [None]:
# Groupings to inspect: all metadata columns jointly, then each column on its own.
groupby_cols = [ASSAY, TRACK, CELL_TYPE, BIOMATERIAL_TYPE] + [analysis_target]
groupby_selections = [groupby_cols] + [[col] for col in groupby_cols]

# Reference line on the group-size axis (presumably the smallest group whose
# error rate is considered meaningful -- TODO confirm threshold).
MIN_GROUP_SIZE_LINE = 25
# Set to True to write one error-rate table per grouping next to the results file.
SAVE_ERROR_TABLES = False

for groupby_selection in groupby_selections:
    error_df = create_error_count_df(groupby_selection, classifier_df)
    if error_df.empty:
        print(f"No complete groups for {groupby_selection}, skipping.")
        continue

    display(error_df.sort_values(by=["n error", "error rate"], ascending=False).head(10))

    if SAVE_ERROR_TABLES:
        # One file per grouping, so successive selections do not overwrite each other.
        selection_tag = "-".join(groupby_selection)
        error_df.to_csv(
            results_path.parent / f"{classifier_name}_error_rate_{selection_tag}.csv",
            index=False,
        )

    fig = px.scatter(
        title=f"Error rate by {groupby_selection}",
        data_frame=error_df,
        x="n total",
        y="error rate",
        hover_data=groupby_selection,
        range_x=[0, error_df["n total"].max() * 1.01],
    )
    fig.add_vline(x=MIN_GROUP_SIZE_LINE)

    fig.show()

In [None]:
# chrY z-score of each dataset, split by predicted sex and coloured by the sex metadata,
# for a single assay / track / cell type slice.
CHRY_ZSCORE_COL = "expected_assay_track_chrY_z-score"

filter_cond = (
    (classifier_df[ASSAY] == "h3k27ac")
    & (classifier_df[TRACK] == "pval")
    & (classifier_df[CELL_TYPE] == "brain")
)
filtered_df = classifier_df[filter_cond]

# Reuse the column-name variables defined with the classifier instead of repeating
# the literal strings, so the plot follows any change of classifier.
fig = px.violin(
    filtered_df,
    x=classifier_preds_colname,
    y=CHRY_ZSCORE_COL,
    color=SEX,
    points="all",
    box=True,
    title="h3k27ac / pval / brain: chrY z-score by predicted sex",
)
fig.update_traces(marker=dict(size=3))
fig.show()

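### Suspicious EpiRRs

Manually inspect the sex predictions of every dataset belonging to the EpiRRs listed in `to_verify_epirr`.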

In [None]:
# Column holding the EpiRR identifier with its version suffix stripped.
epirr_cat = "epirr_no_version"

# EpiRRs (unversioned) whose datasets are inspected one by one below.
to_verify_epirr = """
    IHECRE00004623 IHECRE00000171 IHECRE00001957 IHECRE00000152 IHECRE00001531
    IHECRE00000951 IHECRE00001965 IHECRE00000099 IHECRE00000316 IHECRE00004877
    IHECRE00003706 IHECRE00001370 IHECRE00001001 IHECRE00000954 IHECRE00004890
""".split()

In [None]:
print(f"{len(to_verify_epirr)}")  # number of EpiRRs to verify

In [None]:
# Strip the version suffix: keep everything before the first "." of the EpiRR id.
results_df[epirr_cat] = results_df["EpiRR"].str.split(".", n=1).str.get(0)

In [None]:
# Datasets that carry a sex label.
sex_df = results_df.dropna(subset=[SEX])

In [None]:
sex_df.loc[:, SEX].value_counts()  # label distribution among labelled datasets

In [None]:
# Labelled datasets that belong to one of the EpiRRs under verification.
sus_df = sex_df.loc[sex_df[epirr_cat].isin(to_verify_epirr)]

In [None]:
# sus_df[epirr_cat].value_counts()

In [None]:
# Column names for the 10-fold sex classifier.
name_10fold = "harmonized_donor_sex_1l_3000n_w-mixed_10fold-oversample"
pred_val_label_10fold = f"Max pred {name_10fold}"
pred_class_label_10fold = f"Predicted class {name_10fold}"
split_nb_10fold = f"split_nb {name_10fold}"

# Sentinel for datasets without a split number, so the column can be cast to int.
MISSING_SPLIT_NB = -666

# sus_df is a boolean-filtered slice of sex_df: build a new frame with `assign`
# instead of writing into the slice, which triggers pandas' SettingWithCopyWarning.
sus_df = sus_df.assign(
    **{split_nb_10fold: sus_df[split_nb_10fold].fillna(MISSING_SPLIT_NB).astype(int)}
)

In [None]:
# Per EpiRR: assay, predicted class, max prediction and split number of each dataset.
for epirr in to_verify_epirr:
    epirr_df = sus_df.loc[sus_df[epirr_cat].eq(epirr)]
    print(epirr)
    print(
        epirr_df.loc[
            :, [ASSAY, pred_class_label_10fold, pred_val_label_10fold, split_nb_10fold]
        ]
        .sort_values(by=ASSAY)
        .to_numpy(),
        "\n",
    )

In [None]:
list(filter(lambda colname: "complete" in colname, sus_df.columns))

In [None]:
# Sex classifier trained on the complete data (no validation split).
sex_new_preds_name = "harmonized_donor_sex_1l_3000n_w-mixed_complete_no_valid_oversample_predictions"
pred_val_label_new_preds, pred_class_label_new_preds = (
    f"{prefix} {sex_new_preds_name}" for prefix in ("Max pred", "Predicted class")
)

In [None]:
# Markdown table of the complete-data model's predictions, per EpiRR that has any.
for epirr in to_verify_epirr:
    epirr_df = sus_df[sus_df[epirr_cat] == epirr]
    if not epirr_df[pred_val_label_new_preds].notnull().any():
        continue  # no dataset of this EpiRR was scored by this model
    print(epirr)
    partial_df = (
        epirr_df[[ASSAY, pred_class_label_new_preds, pred_val_label_new_preds]]
        .sort_values(ASSAY)
        .rename(
            columns={
                pred_class_label_new_preds: "Predicted class",
                pred_val_label_new_preds: "Max pred",
            }
        )
    )
    print(partial_df.to_markdown(), "\n")

In [None]:
# Minimum max-prediction values at which the per-split vote is summarized.
MAX_PRED_THRESHOLDS = [0, 0.7, 0.9]

# For each EpiRR and threshold: count predicted classes per split (10-fold model).
for epirr in to_verify_epirr:
    epirr_df = sus_df[sus_df[epirr_cat] == epirr]
    total_n = epirr_df.shape[0]
    print(f"{epirr}")
    for max_pred in MAX_PRED_THRESHOLDS:
        subset_df = epirr_df[epirr_df[pred_val_label_10fold] >= max_pred]
        if subset_df.empty:
            # Nothing to tabulate; report it rather than depend on pivot_table's
            # behaviour on empty input (and `pivot["All"]` would not exist).
            print(f"pred>={max_pred} (n=0/{total_n}): no datasets\n")
            continue

        try:
            pivot = (
                subset_df.pivot_table(
                    values=pred_val_label_10fold,
                    index=pred_class_label_10fold,
                    columns=split_nb_10fold,
                    aggfunc="count",
                    margins=True,
                )
                .fillna(0)
                .astype(int)
            )
        except ValueError as err:
            # Report the failure instead of silently skipping the remaining thresholds.
            print(f"pred>={max_pred} (n={subset_df.shape[0]}/{total_n}): pivot failed ({err})\n")
            continue

        count_pred = pivot["All"]  # margin column: per-class totals plus grand total
        f_count = count_pred.get("female", default=0)
        m_count = count_pred.get("male", default=0)
        mix_count = count_pred.get("mixed", default=0)
        count = count_pred["All"]

        counts_msg = f"F={f_count}, M={m_count}"
        if mix_count != 0:
            counts_msg += f", mix={mix_count}"
        # Label uses ">=" to match the filter above.
        print(f"pred>={max_pred} (n={count}/{total_n}): ({counts_msg})")
        print(f"Splits: {pivot.shape[1] - 1}")  # minus the "All" margin column
        print(pivot.to_string(), "\n")
    print("\n")