In [None]:
"""Explore metadata distribution of correct/wrong predictions"""
# pylint: disable=line-too-long, redefined-outer-name, import-error, pointless-statement, use-dict-literal, expression-not-assigned, unused-import, too-many-lines, unreachable
from __future__ import annotations

from pathlib import Path
from typing import Dict, List

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from IPython.display import display

In [None]:
# Column names in the merged prediction-results table that the analysis
# below stratifies or filters on.

# Experiment-level descriptors.
TRACK = "track_type"
ASSAY = "assay_epiclass"

# Harmonized sample / donor descriptors.
SEX = "harmonized_donor_sex"
LIFE_STAGE = "harmonized_donor_life_stage"
DISEASE = "harmonized_sample_disease_high"
CELL_TYPE = "harmonized_sample_ontology_intermediate"
CANCER = "harmonized_sample_cancer_high"
BIOMATERIAL_TYPE = "harmonized_biomaterial_type"

In [None]:
# Merged prediction results, indexed by the "md5sum" column.
# NOTE(review): home-relative local path; adjust for your machine.
results_dir = Path.home() / "downloads" / "temp"
results_path = results_dir / "merged_pred_results_all_2.1_chrY_zscores.csv"
results_df = pd.read_csv(
    results_path,
    header=0,
    index_col="md5sum",
    low_memory=False,  # read the whole file at once so column dtypes are inferred consistently
)

In [None]:
# Peek at the first rows of the loaded results table.
results_df.head(n=5)

In [None]:
# Metadata column whose classifier predictions are being analysed.
analysis_target = SEX

# List every results column mentioning the analysis target (presumably the
# per-classifier "Predicted class" / "True class" / "Same?" columns).
# Uses `analysis_target` rather than a hard-coded constant so that changing
# the target above changes the listing too.
print([column for column in results_df.columns if analysis_target in column])

In [None]:
# Results columns produced by one specific classifier run.
classifier_name = "harmonized_donor_sex_1l_3000n_w-mixed_10fold-oversample"

# NOTE: the misspelled "classifer_*" names are kept because later cells use them.
classifier_preds_colname, classifer_correct_colname, classifer_same_colname = (
    f"{prefix} {classifier_name}"
    for prefix in ("Predicted class", "True class", "Same?")
)

# Keep only the rows for which this classifier produced a prediction.
has_prediction = results_df[classifier_preds_colname].notna()
classifier_df = results_df.loc[has_prediction]
print(results_df.shape, classifier_df.shape)

In [None]:
# Label distribution of the predicted class, the true class, and the
# "Same?" column for this classifier.
for summary_colname in (
    classifier_preds_colname,
    classifer_correct_colname,
    classifer_same_colname,
):
    display(classifier_df[summary_colname].value_counts())

In [None]:
# Metadata columns to stratify by; the analysis target is appended last.
groupby_cols = [
    ASSAY,
    TRACK,
    CELL_TYPE,
    BIOMATERIAL_TYPE,
] + [analysis_target]

# Number of predicted samples in each metadata combination.
global_metadata_distribution = classifier_df.groupby(by=groupby_cols).size()

# Number of *misclassified* samples in each metadata combination, i.e. rows
# whose "Same?" flag is False (prediction differs from the true class).
# Filtering on True here would count correct predictions instead, and the
# "error rate" derived downstream would then be an accuracy.
# pylint: disable=singleton-comparison
error_metadata_distribution = (
    classifier_df[classifier_df[classifer_same_colname] == False]
    .groupby(by=groupby_cols)
    .size()
)

In [None]:
# Total predicted samples vs. total samples counted in the error distribution.
n_samples_total = global_metadata_distribution.sum()
n_errors_total = error_metadata_distribution.sum()
print(n_samples_total, n_errors_total)

In [None]:
# Per-group error statistics. Align the error counts on the full set of
# metadata groups (groups with no error get 0), then derive the error rate.
# Vectorized: avoids reusing `error_count_df` for both a list and a frame, and
# works for any number of grouping columns (a per-row `list(labels)` would
# split a single string label into characters).
n_error_per_group = error_metadata_distribution.reindex(
    global_metadata_distribution.index, fill_value=0
)
error_count_df = pd.DataFrame(
    {
        "error rate": n_error_per_group / global_metadata_distribution,
        "n error": n_error_per_group,
        "n total": global_metadata_distribution,
    }
).reset_index()  # index levels (named after groupby_cols) become the leading columns

# Sanity checks: expected layout, and every sample / every error counted once.
assert list(error_count_df.columns) == groupby_cols + ["error rate", "n error", "n total"]
assert error_count_df["n total"].sum() == global_metadata_distribution.sum()
assert error_count_df["n error"].sum() == error_metadata_distribution.sum()

error_count_df.to_csv(
    results_path.parent / f"{classifier_name}_error_rate.csv", index=False
)

In [None]:
# Error rate vs. group size, one point per metadata combination. Hovering a
# point shows its full combination of grouping columns. Passing `groupby_cols`
# explicitly avoids relying on a positional column slice.
px.scatter(
    data_frame=error_count_df,
    x="n total",
    y="error rate",
    hover_data=groupby_cols,
    color=TRACK,
    title=f"Error rate per metadata group ({classifier_name})",
)