In [None]:
"""Workbook to quantify bias present in metadata
Q: Can certain labels be identified from the other metadata alone?
e.g. find the cell type using project + assay + other categories
"""
# pylint: disable=import-error, redefined-outer-name, use-dict-literal, too-many-lines, unused-import, unused-argument, too-many-branches, pointless-statement

In [None]:
%load_ext autoreload
%autoreload 2

In [33]:
from __future__ import annotations

from pathlib import Path

import numpy as np
from IPython.display import display
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score
from sklearn.preprocessing import LabelEncoder, OneHotEncoder
from sklearn.svm import SVC

from epi_ml.utils.notebooks.paper.paper_utilities import (
    ASSAY,
    CELL_TYPE,
    LIFE_STAGE,
    SEX,
    MetadataHandler,
    SplitResultsHandler,
    create_mislabel_corrector,
)

In [34]:
# Root folder of the paper outputs; `paper_dir` is an alias of the same location.
base_dir = Path.home().joinpath("Projects/epiclass/output/paper")
paper_dir = base_dir

# Sub-folders for the input data and the generated figures.
base_data_dir = base_dir.joinpath("data")
base_fig_dir = base_dir.joinpath("figures")

# Fail early when the figure folder is missing.
if not base_fig_dir.exists():
    raise FileNotFoundError(f"Directory {base_fig_dir} does not exist.")

In [35]:
# Load the "v2" metadata through one handler, in two forms: `metadata_df`
# (used further down as a DataFrame) and `metadata` (passed to `join_metadata`).
metadata_handler = MetadataHandler(paper_dir)
metadata_df = metadata_handler.load_metadata_df("v2")
metadata = metadata_handler.load_metadata("v2")

# Handler used below to collect and concatenate the classification split results.
split_results_handler = SplitResultsHandler()

## Evaluate bias in input samples classification

### Collect observed average accuracy

In [36]:
# Folder containing the training results that are evaluated in this notebook.
results_dir = base_data_dir / "training_results" / "dfreeze_v2" / "hg38_100kb_all_none"

# Filters forwarded to `general_split_metrics` as `exclude_categories` / `exclude_names`.
# NOTE(review): presumably matched against category / run names found under
# `results_dir` — confirm the exact matching rule in SplitResultsHandler.
exclusion = ["cancer", "random", "track", "disease", "second", "end"]
exclude_names = ["chip", "no-mixed", "ct", "7c"]

# Collect the split results for everything that is not excluded, with the mislabel
# corrector built from `paper_dir` supplied as `mislabel_corrections`.
all_split_results = split_results_handler.general_split_metrics(
    results_dir=results_dir,
    exclude_categories=exclusion,
    exclude_names=exclude_names,
    merge_assays=True,
    mislabel_corrections=create_mislabel_corrector(paper_dir),
    return_type="split_results",
)

In [37]:
# Concatenate the collected split results; the result is used below as a mapping
# from a category name to one results table.
concat_split_results = split_results_handler.concatenate_split_results(all_split_results, concat_first_level=True)  # type: ignore

In [38]:
# Attach the sample metadata to every per-category results table.
# All joined tables are built first, then written back into the same dict object.
concat_split_results.update(
    {
        cat_name: metadata_handler.join_metadata(df, metadata)
        for cat_name, df in concat_split_results.items()
    }
)

In [39]:
# Observed accuracy per category: share of rows whose predicted class equals the
# true class. Despite the "input" in the name, no assay filter is applied here;
# every sample of each results table is counted.
avg_input_acc = {
    cat_name: (df["True class"] == df["Predicted class"]).sum() / len(df)
    for cat_name, df in concat_split_results.items()
}

In [None]:
# Show the observed accuracy computed for each category.
display(avg_input_acc)

In [41]:
# Make the sex and assay results reachable under the shared SEX / ASSAY constants,
# i.e. the same names that are used in the bias category lists.
for alias, result_name in (
    (SEX, "harmonized_donor_sex_w-mixed"),
    (ASSAY, "assay_epiclass_11c"),
):
    avg_input_acc[alias] = avg_input_acc[result_name]
    concat_split_results[alias] = concat_split_results[result_name]

### Compute max bias accuracy using metadata as input

In [42]:
# Metadata label that the bias models below try to predict from the other categories.
output_category = CELL_TYPE

In [43]:
# Candidate sets of metadata categories used as predictors. Every set starts from
# the same four categories; sets 2-4 add LIFE_STAGE, SEX, or both.
bias_categories_1 = [ASSAY, "project", "harmonized_biomaterial_type", CELL_TYPE]
bias_categories_2 = [*bias_categories_1, LIFE_STAGE]
bias_categories_3 = [*bias_categories_1, SEX]
bias_categories_4 = [*bias_categories_1, SEX, LIFE_STAGE]

all_bias_categories = [
    bias_categories_1,
    bias_categories_2,
    bias_categories_3,
    bias_categories_4,
]

# The label being predicted must not be part of its own predictors.
for bias_categories in all_bias_categories:
    if output_category in bias_categories:
        bias_categories.remove(output_category)

In [None]:
# only consider input values that were actually in the training set
input_df = metadata_df
# input_df = metadata_df[metadata_df[ASSAY] == "input"]
input_df["md5sum"] = input_df.index

filtered_input_df = input_df[
    input_df["md5sum"].isin(concat_split_results[output_category]["md5sum"])
]
print(input_df.shape, filtered_input_df.shape)

In [None]:
# Class distribution of the target label among the retained samples.
display(filtered_input_df[output_category].value_counts())

In [46]:
# Candidate classifiers compared in the cross-validation below: two logistic
# regressions (multinomial and one-vs-rest), a random forest, and linear / RBF SVMs.
# All use random_state=42 so repeated runs are comparable.
# NOTE(review): `multi_class` is deprecated in recent scikit-learn releases (1.5+);
# confirm against the installed version before upgrading.
lr_model_1 = LogisticRegression(
    solver="lbfgs", max_iter=1000, multi_class="multinomial", random_state=42
)
lr_model_2 = LogisticRegression(
    solver="lbfgs", max_iter=1000, multi_class="ovr", random_state=42
)
rf_model = RandomForestClassifier(n_estimators=100, random_state=42)
svm_model = SVC(kernel="linear", random_state=42)
svm_model_rbf = SVC(kernel="rbf", random_state=42)

In [None]:
# For every candidate category set, keep the best mean 10-fold cross-validated
# accuracy reached by any of the models when predicting `output_category` from the
# one-hot encoded metadata. Keys are the category sets as tuples.
max_bias_dict = {}
for bias_categories in all_bias_categories:
    print(f"Using bias categories: {bias_categories}")

    # Dense one-hot matrix of the predictors and integer-coded target labels.
    features = filtered_input_df[bias_categories]
    X_encoded = OneHotEncoder().fit_transform(features).toarray()  # type: ignore
    y_encoded = LabelEncoder().fit_transform(filtered_input_df[output_category])

    best_acc = 0
    for model in (lr_model_1, lr_model_2, rf_model, svm_model, svm_model_rbf):
        fold_scores = cross_val_score(
            model, X_encoded, y_encoded, cv=10, scoring="accuracy"
        )
        mean_acc = np.mean(fold_scores)
        std_acc = np.std(fold_scores)
        print(f"Model: {model}")
        print(f"Accuracy: {mean_acc:.2f} (+/- {std_acc:.2f})")

        # Record only strict improvements over the best score seen for this set.
        if mean_acc > best_acc:
            best_acc = mean_acc
            max_bias_dict[tuple(bias_categories)] = best_acc

In [None]:
# Category set with the highest cross-validated accuracy (the first one wins on ties).
max_bias_cats = max(max_bias_dict, key=max_bias_dict.get)
max_bias_acc = max_bias_dict[max_bias_cats]
display(max_bias_cats, max_bias_acc)

### Max acc estimation

In [49]:
# Observed accuracies of the classifiers whose category belongs to the most
# predictive metadata set.
acc_to_compare = [acc for cat, acc in avg_input_acc.items() if cat in max_bias_cats]
if not acc_to_compare:
    # np.mean([]) would only emit a RuntimeWarning and push NaN into the summary.
    raise ValueError(
        f"No observed accuracy available for any of {max_bias_cats}; "
        f"known categories: {list(avg_input_acc)}"
    )

# Estimate: best metadata-only accuracy scaled by the mean observed accuracy of
# the categories it relies on.
max_acc_with_bias = max_bias_acc * np.mean(acc_to_compare)

In [None]:
# Summary: observed accuracy for the target label vs. the estimate obtained from
# metadata bias, and the difference between the two.
print("INPUT CLASSIFICATION ACCURACY")
observed_acc = avg_input_acc[output_category]
print(f"Average {output_category} observed acc: {observed_acc:.1%}")
print(f"Max avg acc with bias from ({max_bias_cats}): {max_acc_with_bias:.1%}")
unexplained_acc = observed_acc - max_acc_with_bias
print(f"Not accounted for: {unexplained_acc:.1%}")