In [None]:
"""Notebook to quantify the bias present in the metadata.

Question: can certain labels be identified using only other metadata?
e.g. predict the cell type from project + assay + other metadata.
"""
# pylint: disable=import-error, redefined-outer-name, use-dict-literal, too-many-lines, unused-import, unused-argument, too-many-branches

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
from __future__ import annotations

from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Tuple

import numpy as np
import pandas as pd
from IPython.display import display  # pylint: disable=unused-import
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score
from sklearn.preprocessing import LabelEncoder, OneHotEncoder
from sklearn.svm import SVC

from epi_ml.utils.notebooks.paper.paper_utilities import (
    ASSAY,
    ASSAY_MERGE_DICT,
    CELL_TYPE,
    LIFE_STAGE,
    SEX,
    MetadataHandler,
    SplitResultsHandler,
    create_mislabel_corrector,
)

In [4]:
# Root of the paper outputs; every data / figure path below is derived from it.
base_dir = Path.home() / "Projects/epiclass/output/paper"
base_data_dir = base_dir / "data"
base_fig_dir = base_dir / "figures"
paper_dir = base_dir

# Fail fast here instead of deep inside a later cell: both the data directory
# (read by every loader) and the figure directory (written to) must exist.
for required_dir in (base_data_dir, base_fig_dir):
    if not required_dir.exists():
        raise FileNotFoundError(f"Directory {required_dir} does not exist.")

In [5]:
# Single source of truth for the metadata freeze, so the DataFrame and the
# metadata object can never be loaded from different versions.
METADATA_VERSION = "v2"

metadata_handler = MetadataHandler(paper_dir)
metadata_df = metadata_handler.load_metadata_df(METADATA_VERSION)
metadata = metadata_handler.load_metadata(METADATA_VERSION)

split_results_handler = SplitResultsHandler()

## Evaluate bias in input samples classification

### Collect observed average accuracy

In [6]:
results_dir = base_data_dir / "training_results" / "dfreeze_v2" / "hg38_100kb_all_none"
if not results_dir.exists():
    raise FileNotFoundError(f"{results_dir} does not exist.")

# Categories / run names to leave out; how they are matched is defined by
# SplitResultsHandler.general_split_metrics.
exclusion = ["cancer", "random", "track", "disease", "second", "end"]
exclude_names = ["chip", "no-mixed", "ct", "7c"]

all_split_results = split_results_handler.general_split_metrics(
    results_dir=results_dir,
    exclude_categories=exclusion,
    exclude_names=exclude_names,
    merge_assays=True,
    mislabel_corrections=create_mislabel_corrector(paper_dir),
    return_type="split_results",
)

In [7]:
# Merge the per-split results so each category key maps to a single DataFrame.
concat_split_results: Dict[str, pd.DataFrame]
concat_split_results = split_results_handler.concatenate_split_results(  # type: ignore
    all_split_results,
    concat_first_level=True,
)

In [8]:
# Attach sample metadata to every category's results, replacing the entry in place.
for cat_name in list(concat_split_results):
    concat_split_results[cat_name] = metadata_handler.join_metadata(
        concat_split_results[cat_name], metadata
    )

In [9]:
# Observed classification accuracy per category, computed over all samples.
# To restrict it to input samples, filter each frame with
# `split_df[split_df[ASSAY] == "input"]` before comparing.
avg_input_acc = {
    cat_name: (split_df["True class"] == split_df["Predicted class"]).sum()
    / len(split_df)
    for cat_name, split_df in concat_split_results.items()
}

In [None]:
# Show the observed accuracy computed above for each classification category.
display(avg_input_acc)

In [11]:
# Expose the chosen category variants under the generic names (SEX, ASSAY)
# used by the bias analysis below.
for generic_name, variant_name in (
    (SEX, "harmonized_donor_sex_w-mixed"),
    (ASSAY, "assay_epiclass_11c"),
):
    avg_input_acc[generic_name] = avg_input_acc[variant_name]
    concat_split_results[generic_name] = concat_split_results[variant_name]

### Compute max bias accuracy using metadata as input

In [13]:
def define_input_bias_categories(target_category: str) -> List[List[str]]:
    """Define the metadata input combinations used for the bias analysis.

    Four combinations are built from a shared base (assay, project, biomaterial
    type, cell type), optionally extended with life stage and/or sex. The target
    category is removed from every combination so it is never used to predict
    itself. Combinations that become identical after that removal (e.g. when the
    target is sex or life stage) are returned only once, which avoids running the
    same expensive cross-validation twice.

    Args:
        target_category (str): Classification target category. Is excluded from input lists.

    Returns:
        List[List[str]]: Unique bias category lists, in definition order.
    """
    base = [ASSAY, "project", "harmonized_biomaterial_type", CELL_TYPE]
    candidates = [
        base,
        base + [LIFE_STAGE],
        base + [SEX],
        base + [SEX, LIFE_STAGE],
    ]

    unique_categories: List[List[str]] = []
    seen = set()
    for categories in candidates:
        # Fresh list each time: never mutate `base`, which the other candidates share.
        filtered = [cat for cat in categories if cat != target_category]
        key = tuple(filtered)
        if key not in seen:
            seen.add(key)
            unique_categories.append(filtered)
    return unique_categories

In [14]:
def create_models() -> List:
    """Create the candidate classifiers evaluated in the bias analysis.

    Returns five unfitted estimators: multinomial logistic regression,
    one-vs-rest logistic regression, a random forest, and linear / RBF SVMs.

    The `multi_class` argument of LogisticRegression is deprecated since
    scikit-learn 1.5. Multinomial is already the lbfgs behaviour for multiclass
    targets, and one-vs-rest is obtained by wrapping the estimator in
    OneVsRestClassifier, as the scikit-learn deprecation note recommends.
    """
    from sklearn.multiclass import (  # pylint: disable=import-outside-toplevel
        OneVsRestClassifier,
    )

    lr_model_1 = LogisticRegression(solver="lbfgs", max_iter=1000, random_state=42)
    lr_model_2 = OneVsRestClassifier(
        LogisticRegression(solver="lbfgs", max_iter=1000, random_state=42)
    )
    rf_model = RandomForestClassifier(n_estimators=100, random_state=42)
    svm_model = SVC(kernel="linear", random_state=42)
    svm_model_rbf = SVC(kernel="rbf", random_state=42)
    return [lr_model_1, lr_model_2, rf_model, svm_model, svm_model_rbf]

In [19]:
def filter_samples(
    metadata_df: pd.DataFrame,
    target_category: str,
    verbose: bool = True,
    reference_md5s: Optional[Iterable[str]] = None,
) -> pd.DataFrame:
    """Filter samples based on the output category to match the original training set.

    Args:
        metadata_df: Metadata table. If it has no "md5sum" column, its index is
            used as the md5sum.
        target_category: Category whose training samples define the subset.
        verbose: Print the shapes and display the target label counts.
        reference_md5s: md5sums to keep. Defaults to the samples present in the
            notebook-level `concat_split_results[target_category]`.

    Returns:
        A filtered copy of `metadata_df`, always with an "md5sum" column;
        the input frame is not modified.
    """
    df = metadata_df.copy(deep=True)

    if "md5sum" not in df.columns:
        df["md5sum"] = df.index

    if reference_md5s is None:
        reference_md5s = concat_split_results[target_category]["md5sum"]

    # `isin` needs a list-like, not an arbitrary iterable such as a generator.
    df = df[df["md5sum"].isin(list(reference_md5s))]

    if verbose:
        print("Metadata shape:", metadata_df.shape)
        print("Filtered shape:", df.shape)
        display(df[target_category].value_counts())

    return df  # type: ignore


def find_max_bias(
    metadata_df: pd.DataFrame, target_category: str, verbose: bool = True
) -> Dict[Tuple[str, ...], float]:
    """Find the best metadata-only accuracy for each bias category combination.

    For every combination from `define_input_bias_categories`, the inputs are
    one-hot encoded and each model from `create_models` is scored with 10-fold
    cross-validation. The best mean accuracy over the models is kept.

    Args:
        metadata_df: Metadata table (filtered with `filter_samples` first).
        target_category: Category to predict.
        verbose: Print progress and per-model scores; forwarded to `filter_samples`.

    Returns:
        Mapping from the tuple of bias categories to the best mean CV accuracy.
    """
    filtered_df = filter_samples(metadata_df, target_category, verbose=verbose)

    # The target does not depend on the input combination: encode it once.
    y_encoded = LabelEncoder().fit_transform(filtered_df[target_category])

    max_bias_dict: Dict[Tuple[str, ...], float] = {}
    for bias_categories in define_input_bias_categories(target_category):
        if verbose:
            print(f"Using bias categories: {bias_categories}")

        # one-hot encode the data
        X_encoded = OneHotEncoder().fit_transform(filtered_df[bias_categories]).toarray()  # type: ignore

        best_acc = -np.inf
        for model in create_models():
            scores = cross_val_score(
                model, X_encoded, y_encoded, cv=10, scoring="accuracy", n_jobs=-1
            )
            mean_acc = float(np.mean(scores))
            if verbose:
                print(f"Model: {model}")
                print(f"Accuracy: {mean_acc:.2f} (+/- {np.std(scores):.2f})")
            best_acc = max(best_acc, mean_acc)

        max_bias_dict[tuple(bias_categories)] = best_acc

    return max_bias_dict

In [None]:
def compute_all_max_bias(
    metadata_df: pd.DataFrame, target_categories: List[str], verbose: bool = True
) -> Dict[str, Any]:
    """Compute the max metadata bias for all target categories.

    For each target:
    1. `find_max_bias` gives the best cross-validated accuracy reachable from
       metadata alone.
    2. That accuracy is scaled by the mean observed accuracy on the bias
       categories used as inputs.
    3. The result is compared with the observed accuracy on the target itself.

    Reads the notebook-level `avg_input_acc` dict (observed accuracy per category).

    Args:
        metadata_df: Metadata table the inputs and targets are drawn from.
        target_categories: Categories to predict from the other metadata.
        verbose: Print intermediate results; also forwarded to `find_max_bias`.

    Returns:
        Per target category, a dict with keys "max_bias_cats", "max_bias_acc",
        "MLP_acc", "bias_avg_MLP_acc", "max_bias_acc_corrected" and "acc_diff".
    """
    final_results: Dict[str, Any] = {}
    for target_category in target_categories:
        if verbose:
            print(f"Target category: {target_category}")

        max_bias_dict = find_max_bias(metadata_df, target_category, verbose=verbose)
        max_bias_cats, max_bias_acc = max(max_bias_dict.items(), key=lambda x: x[1])
        if verbose:
            print(f"Max bias categories: {max_bias_cats}")
            print(f"Max bias acc: {max_bias_acc}")

        MLP_acc = avg_input_acc[target_category]

        # Observed accuracy of the bias (input) categories. Categories without an
        # observed accuracy cannot be compared and are skipped (reported if verbose).
        acc_to_compare = [
            avg_input_acc[cat] for cat in max_bias_cats if cat in avg_input_acc
        ]
        skipped = [cat for cat in max_bias_cats if cat not in avg_input_acc]
        if verbose and skipped:
            print(f"No observed accuracy for bias categories: {skipped}")
        # NaN when nothing is comparable (avoids numpy's empty-mean RuntimeWarning).
        avg_MLP_acc = np.mean(acc_to_compare) if acc_to_compare else np.nan
        max_acc_with_bias = max_bias_acc * avg_MLP_acc

        if verbose:
            print("CLASSIFICATION ACCURACY")
            print(f"Average {target_category} observed acc: {MLP_acc:.1%}")
            print(f"Average MLP acc on bias categories: {avg_MLP_acc:.1%}")
            print(
                f"Max avg acc with bias from ({max_bias_cats}): {max_acc_with_bias:.1%}"
            )
            print(f"Not accounted for: {MLP_acc - max_acc_with_bias:.1%}\n")

        final_results[target_category] = {
            "max_bias_cats": max_bias_cats,
            "max_bias_acc": max_bias_acc,
            "MLP_acc": MLP_acc,
            "bias_avg_MLP_acc": avg_MLP_acc,
            "max_bias_acc_corrected": max_acc_with_bias,
            "acc_diff": MLP_acc - max_acc_with_bias,
        }

    return final_results

In [None]:
# Run the metadata bias analysis for each target and save one summary row per target.
target_categories = [
    "project",
    "harmonized_biomaterial_type",
    CELL_TYPE,
    SEX,
    LIFE_STAGE,
]
final_results = compute_all_max_bias(metadata_df, target_categories)

final_results_df = pd.DataFrame.from_dict(data=final_results, orient="index")
final_results_df.to_csv(path_or_buf="metadata_bias_analysis_results.csv")

## Correlation between results and metadata labels

Cell type classification: check (input, cell type) pairs, and more generally (assay, cell type) pairs, for enrichment in any metadata category.  
&emsp;Categories used:  
  - harmonized_donor_life_stage  
  - harmonized_donor_sex  
  - harmonized_sample_disease_high  
  - harmonized_biomaterial_type  
  - paired_end  
  - project  

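&emsp;Define a "third factor" metric: if the metadata distribution of an (assay, cell type) subclass differs strongly from the global distribution, the classifier could exploit that metadata as a third factor. The analysis is done separately for each assay.  
  - One score per (assay, cell type) pair; compute its Pearson correlation with the accuracy vector (one vector per assay).  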

In [None]:
path_results_cell_type = (
    base_data_dir / "training_results" / "dfreeze_v2" / "hg38_10kb_all_none"
)
if not path_results_cell_type.exists():
    raise FileNotFoundError(f"{path_results_cell_type} does not exist.")

# Load split results into one combined dataframe
ct_split_dfs = split_results_handler.gather_split_results_across_categories(
    path_results_cell_type
)["harmonized_sample_ontology_intermediate_1l_3000n_10fold-oversampling"]
ct_full_df = pd.concat(ct_split_dfs.values(), axis=0)

# Load metadata and join with split results
ct_full_df = metadata_handler.join_metadata(ct_full_df, metadata)
# Assign back explicitly: `ct_full_df[ASSAY].replace(..., inplace=True)` is a chained
# in-place call on an intermediate object and has no effect under pandas Copy-on-Write.
ct_full_df[ASSAY] = ct_full_df[ASSAY].replace(ASSAY_MERGE_DICT)

In [None]:
def calculate_metadata_distribution(
    df: pd.DataFrame, columns: List[str]
) -> Dict[str, pd.Series]:
    """
    Compute, for each requested column, the share (in %) of rows holding each label.

    Missing values are counted as a label of their own, and percentages are taken
    relative to the total number of rows, so each Series sums to ~100 for non-empty input.

    Args:
        df: A pandas DataFrame containing the data.
        columns: Column names to analyze.

    Returns:
        Mapping from column name to a Series of label -> percentage of rows.
    """
    total_rows = len(df)
    return {
        column: df[column].value_counts(dropna=False) / total_rows * 100
        for column in columns
    }

In [None]:
def compare_label_ratios(
    target_distribution: Dict[str, pd.Series],
    comparison_distributions: List[Dict[str, pd.Series]],
    labels: List[str],
) -> Dict[str, pd.DataFrame]:
    """
    Compares label ratios of a target distribution against multiple comparison distributions,
    calculating the difference (target - comparison) for each label within each metadata category.
    When the inputs are percentages (as produced by `calculate_metadata_distribution`), the
    differences are in percentage points.

    Args:
        target_distribution: A dictionary of Series representing the target distribution for comparison.
        comparison_distributions: A list of dictionaries of Series, where each dictionary
                                  represents a distribution (e.g., assay, cell type, global) for comparison.
        labels: A list of labels corresponding to each distribution in `comparison_distributions`,
                used for naming the result columns ("Difference_vs_<label>"). Paired with
                `comparison_distributions` via zip, so extra items in the longer list are ignored.

    Returns:
        A dictionary keyed by metadata category (the keys of `target_distribution`). Each
        DataFrame has one "Difference_vs_<label>" column per comparison distribution.

    Note:
        A label missing on either side of a comparison is treated as 0 (align with fill_value=0).
        NOTE(review): the DataFrame starts empty, so pandas sets its index from the first
        assigned column and reindexes later columns to it. Labels that appear only in later
        comparisons therefore seem to be dropped. Confirm this is intended.
    """
    comparison_results = {}
    for category, target_series in target_distribution.items():
        # Initialize a DataFrame to store comparison results for this category
        comparison_df = pd.DataFrame()

        for label, comparison_distribution in zip(labels, comparison_distributions):
            # A category absent from the comparison distribution becomes an empty Series;
            # align() then puts target and comparison on the union of labels, filling gaps with 0.
            comparison_series = comparison_distribution.get(
                category, pd.Series(dtype="float64")
            )
            aligned_target, aligned_comparison = target_series.align(
                comparison_series, fill_value=0
            )

            # Calculate difference in percentage points
            difference = aligned_target - aligned_comparison

            # Store the results in the comparison DataFrame
            comparison_df[f"Difference_vs_{label}"] = difference

        comparison_results[category] = comparison_df

    return comparison_results

In [None]:
# Metadata columns whose label distributions are compared within each
# (assay, cell type) pair in the third-factor analysis.
metadata_categories = [
    "harmonized_donor_life_stage",
    "harmonized_donor_sex",
    "harmonized_sample_disease_high",
    "harmonized_biomaterial_type",
    "paired_end",
    "project",
]

In [None]:
def compute_third_factor_correlation(
    ct_full_df: pd.DataFrame,
    metadata_categories: List[str],
    save_full_details: bool = False,
    metric: str = "min",
) -> Dict[str, float]:
    """
    Correlate a "third factor" score with cell type classification accuracy, per assay.

    For every (assay, cell type) pair, the pair's metadata label distribution is compared
    with the distribution of its assay, its cell type, and the whole dataset. The
    "Difference vs Global" column is reduced to one score per pair with `metric`. For
    each assay, the function returns the Pearson correlation between two vectors indexed
    by cell type: per-cell-type accuracy and per-pair score.

    Args:
        ct_full_df (pd.DataFrame): Cell type classification results joined with metadata.
            Must contain the ASSAY, CELL_TYPE, "True class" and "Predicted class" columns
            and every column in `metadata_categories`.
        metadata_categories (List[str]): A list of metadata categories to analyze.
        save_full_details (bool): If True, write the full per-pair comparison table (with
            accuracies) to `base_fig_dir / "flagship" / "metadata_comparison_all.csv"`.
        metric (str): Reduction applied to the "Difference vs Global" column of each pair.
            One of "max", "min", "abs_max", "abs_min", "abs_sum".

    Returns:
        Dict[str, float]: Pearson correlation per assay label. Cell types without samples
        for an assay have NaN accuracy/score; `Series.corr` skips those entries.

    Raises:
        ValueError: If `metric` is not one of the supported values.
    """
    metric_functions = {
        "max": lambda diff: diff.max(),
        "min": lambda diff: diff.min(),
        "abs_max": lambda diff: diff.abs().max(),
        "abs_min": lambda diff: diff.abs().min(),
        "abs_sum": lambda diff: diff.abs().sum(),
    }
    if metric not in metric_functions:
        raise ValueError(
            f"Unknown metric '{metric}', expected one of {sorted(metric_functions)}."
        )
    reduce_metric = metric_functions[metric]

    global_dist = calculate_metadata_distribution(ct_full_df, metadata_categories)
    # Distributions keyed by assay label, by cell type label, and by (assay, ct) tuple.
    subclass_distributions = {}
    comparison_results = {}  # (assay, ct) -> {category: comparison DataFrame}

    for assay_label, sub_df in ct_full_df.groupby(ASSAY):
        subclass_distributions[assay_label] = calculate_metadata_distribution(
            sub_df, metadata_categories
        )

    for ct_label, sub_df in ct_full_df.groupby(CELL_TYPE):
        subclass_distributions[ct_label] = calculate_metadata_distribution(
            sub_df, metadata_categories
        )

    # Compare each (assay, cell type) pair to its assay, its cell type and the global set.
    for (assay, cell_type), sub_df in ct_full_df.groupby([ASSAY, CELL_TYPE]):  # type: ignore
        pair_subclass_dist = calculate_metadata_distribution(sub_df, metadata_categories)
        subclass_distributions[(assay, cell_type)] = pair_subclass_dist

        comparison_results[(assay, cell_type)] = compare_label_ratios(
            target_distribution=pair_subclass_dist,
            comparison_distributions=[
                subclass_distributions[assay],
                subclass_distributions[cell_type],
                global_dist,
            ],
            labels=[assay, cell_type, "global"],
        )

    pair_dfs = {}
    pairs_3rd_factor = {}
    for (assay, cell_type), comparisons in comparison_results.items():
        dfs_to_concat = []
        for category, df_comparison in comparisons.items():
            df_comparison.columns = [
                "Difference vs Assay",
                "Difference vs Cell Type",
                "Difference vs Global",
            ]
            # Add identifiers for the assay, cell type, and category
            df_comparison["Assay"] = assay
            df_comparison["Cell Type"] = cell_type
            df_comparison["Category"] = category
            df_comparison["(assay, ct) subclass %"] = subclass_distributions[
                (assay, cell_type)
            ][category]

            # Name the label column "index" explicitly: since pandas 2.0 value_counts
            # names its index after the source column, so a bare reset_index() would
            # create a differently named column for each category.
            dfs_to_concat.append(df_comparison.rename_axis("index").reset_index())

        final_df = pd.concat(dfs_to_concat, ignore_index=True).fillna(0)

        leading_columns = ["index", "Category", "(assay, ct) subclass %"]
        other_columns = [col for col in final_df.columns if col not in leading_columns]
        final_df = final_df[leading_columns + other_columns]

        pair_dfs[(assay, cell_type)] = final_df
        pairs_3rd_factor[(assay, cell_type)] = reduce_metric(
            final_df["Difference vs Global"]
        )

    # Per-cell-type accuracy within each assay, as a grouped mean of correct predictions:
    # a true class never predicted correctly gets 0.0 (a direct confusion-matrix lookup
    # would raise KeyError), and a cell type absent from the assay gets NaN.
    assay_labels = sorted(ct_full_df[ASSAY].unique())
    ct_labels = sorted(ct_full_df[CELL_TYPE].unique())
    assay_accuracies = {}
    for assay_label in assay_labels:
        assay_df = ct_full_df[ct_full_df[ASSAY] == assay_label]
        is_correct = assay_df["True class"] == assay_df["Predicted class"]
        per_class_acc = is_correct.groupby(assay_df["True class"]).mean()
        assay_accuracies[assay_label] = {
            ct_label: per_class_acc.get(ct_label, np.nan) for ct_label in ct_labels
        }

    if save_full_details:
        all_pairs_df = pd.concat(pair_dfs, axis=0, ignore_index=True)
        for assay in assay_labels:
            for ct in ct_labels:
                all_pairs_df.loc[
                    (all_pairs_df["Assay"] == assay) & (all_pairs_df["Cell Type"] == ct),
                    "Accuracy",
                ] = assay_accuracies[assay][ct]

        all_pairs_df = all_pairs_df.rename(columns={"index": "Label"})
        file_path = base_fig_dir / "flagship" / "metadata_comparison_all.csv"
        file_path.parent.mkdir(parents=True, exist_ok=True)
        all_pairs_df.to_csv(file_path, index=False)

    pearson_3rd_factor = {}
    for assay, acc_dict in assay_accuracies.items():
        acc_vector = pd.Series({ct: acc_dict[ct] for ct in ct_labels})
        diff_metric = pd.Series(
            {ct: pairs_3rd_factor.get((assay, ct), np.nan) for ct in ct_labels}
        )
        pearson_3rd_factor[assay] = acc_vector.corr(diff_metric, method="pearson")

    return pearson_3rd_factor

In [None]:
# _ = compute_third_factor_correlation(ct_full_df, metadata_categories, save_full_details=True)

In [None]:
def compute_all_third_factor_correlations(
    ct_full_df: pd.DataFrame, metadata_categories: List[str]
) -> Dict[str, pd.DataFrame]:
    """Compute third factor correlations for all supported metrics and save them to CSV.

    For each metric, the correlation is computed separately for every metadata category,
    giving one row per category and one column per assay. A "Max" column and a "Max" row
    hold the maximum absolute value per row / per column. Each table is written to
    `base_fig_dir / "flagship" / f"3rd_factor_{metric}_pearson_correlation.csv"`, and the
    output directory is created if needed.

    Returns:
        Dict[str, pd.DataFrame]: The correlation table for each metric.
    """
    output_dir = base_fig_dir / "flagship"
    output_dir.mkdir(parents=True, exist_ok=True)

    all_tables: Dict[str, pd.DataFrame] = {}
    for metric in ["max", "min", "abs_max", "abs_min", "abs_sum"]:
        per_category_rows = [
            pd.DataFrame(
                compute_third_factor_correlation(
                    ct_full_df, [metadata_category], save_full_details=False, metric=metric
                ),
                index=[metadata_category],
            )
            for metadata_category in metadata_categories
        ]
        full_pearson_df = pd.concat(per_category_rows, axis=0)

        # Add max (absolute) for each row, then for each column including the new "Max".
        full_pearson_df["Max"] = full_pearson_df.abs().max(axis=1)
        max_row = full_pearson_df.abs().max(axis=0)
        max_row.name = "Max"
        full_pearson_df = pd.concat([full_pearson_df, max_row.to_frame().T], axis=0)

        output_path = output_dir / f"3rd_factor_{metric}_pearson_correlation.csv"
        full_pearson_df.to_csv(output_path, float_format="%.3f")
        all_tables[metric] = full_pearson_df

    return all_tables

In [None]:
# compute_all_third_factor_correlations(ct_full_df, metadata_categories)

#### Various metadata proportion counts

In [None]:
# input_df = ct_full_df[ct_full_df[ASSAY] == "input"]
# ct_df = input_df[input_df[CELL_TYPE] == "neutrophil"]
# for cat in metadata_categories:
#     print(ct_df.value_counts(cat))

In [None]:
# metadata = metadata_handler.load_metadata("v2")
# metadata_2_uuid = UUIDMetadata.from_metadata(metadata)

# for md5 in list(metadata_2_uuid.md5s):
#     if md5 not in ct_full_df.index:
#         del metadata_2_uuid[md5]

# metadata_2_uuid.select_category_subsets("track_type", ["pval", "ctl_raw", "Unique_minusRaw", "gembs_neg", "Unique_raw"])
# uuid_df = pd.DataFrame.from_records(list(metadata_2_uuid.datasets))
# uuid_df.set_index("md5sum", inplace=True)
# uuid_df[ASSAY].replace(ASSAY_MERGE_DICT, inplace=True)
# uuid_df.groupby([CELL_TYPE, ASSAY]).aggregate("size").unstack().fillna(0).to_csv(base_fig_dir/"flagship"/"ct_assay_counts_unique_uuid.csv")

In [None]:
def _display_ct_percentages(title: str, counts: pd.Series) -> pd.Series:
    """Print `title`, then display the shape and percentage breakdown of `counts`."""
    print(title)
    percentages = counts / counts.sum() * 100
    display(percentages.shape)
    display(percentages)
    return percentages


df = _display_ct_percentages(
    "harmonized_biomaterial_type == primary cell",
    ct_full_df[ct_full_df["harmonized_biomaterial_type"] == "primary cell"].value_counts(
        CELL_TYPE
    ),
)

# Alternative project selections:
#   "project == BLUEPRINT" -> ct_full_df[ct_full_df["project"] == "BLUEPRINT"]
#   "project == CEEHRC"    -> ct_full_df[ct_full_df["project"] == "CEEHRC"]
df = _display_ct_percentages(
    "project == NIH Roadmap Epigenomics",
    ct_full_df[ct_full_df["project"] == "NIH Roadmap Epigenomics"].value_counts(
        CELL_TYPE
    ),
)

df = _display_ct_percentages(
    "paired_end == FALSE",
    ct_full_df[ct_full_df["paired_end"] == "FALSE"].value_counts(CELL_TYPE),
)

df = _display_ct_percentages(
    "harmonized_sample_disease_high == Healthy/None",
    ct_full_df[
        ct_full_df["harmonized_sample_disease_high"] == "Healthy/None"
    ].value_counts(CELL_TYPE),
)

df = _display_ct_percentages(
    "harmonized_biomaterial_type distribution",
    ct_full_df["harmonized_biomaterial_type"].value_counts(),
)

In [None]:
# Cell type breakdown (%) of input samples that are primary cells from BLUEPRINT,
# with paired_end == "FALSE" and disease status "Healthy/None".
intersection = ct_full_df.loc[
    ct_full_df[ASSAY].eq("input")
    & ct_full_df["harmonized_biomaterial_type"].eq("primary cell")
    & ct_full_df["project"].eq("BLUEPRINT")
    & ct_full_df["paired_end"].eq("FALSE")
    & ct_full_df["harmonized_sample_disease_high"].eq("Healthy/None")
].value_counts(CELL_TYPE)
intersection = intersection.div(intersection.sum()).mul(100)
display(intersection)

In [None]:
# Same BLUEPRINT primary-cell / "FALSE" paired_end / healthy selection, over all assays.
intersection = ct_full_df.loc[
    ct_full_df["harmonized_biomaterial_type"].eq("primary cell")
    & ct_full_df["project"].eq("BLUEPRINT")
    & ct_full_df["paired_end"].eq("FALSE")
    & ct_full_df["harmonized_sample_disease_high"].eq("Healthy/None")
].value_counts(CELL_TYPE)
intersection = intersection.div(intersection.sum()).mul(100)
display(intersection)

In [None]:
# Cell type breakdown (%) of healthy primary-cell input samples from Roadmap or CEEHRC.
intersection = ct_full_df.loc[
    ct_full_df[ASSAY].eq("input")
    & ct_full_df["harmonized_biomaterial_type"].eq("primary cell")
    & ct_full_df["harmonized_sample_disease_high"].eq("Healthy/None")
    & ct_full_df["project"].isin(["NIH Roadmap Epigenomics", "CEEHRC"])
].value_counts(CELL_TYPE)
intersection = intersection.div(intersection.sum()).mul(100)
display(intersection)

In [None]:
# Same healthy primary-cell Roadmap/CEEHRC selection, over all assays.
intersection = ct_full_df.loc[
    ct_full_df["harmonized_biomaterial_type"].eq("primary cell")
    & ct_full_df["harmonized_sample_disease_high"].eq("Healthy/None")
    & ct_full_df["project"].isin(["NIH Roadmap Epigenomics", "CEEHRC"])
].value_counts(CELL_TYPE)
intersection = intersection.div(intersection.sum()).mul(100)
display(intersection)

In [None]:
# Cell type breakdown (%) of healthy primary-cell samples, all assays and projects.
intersection = ct_full_df.loc[
    ct_full_df["harmonized_biomaterial_type"].eq("primary cell")
    & ct_full_df["harmonized_sample_disease_high"].eq("Healthy/None")
].value_counts(CELL_TYPE)
intersection = intersection.div(intersection.sum()).mul(100)
display(intersection)

In [None]:
# Cell type breakdown (%) of primary-tissue input samples from Roadmap or CEEHRC.
intersection = ct_full_df.loc[
    ct_full_df[ASSAY].eq("input")
    & ct_full_df["harmonized_biomaterial_type"].eq("primary tissue")
    & ct_full_df["project"].isin(["NIH Roadmap Epigenomics", "CEEHRC"])
].value_counts(CELL_TYPE)
intersection = intersection.div(intersection.sum()).mul(100)
display(intersection)

In [None]:
# Cell type breakdown (%) of input samples that are healthy primary cells
# (computed and displayed once).
intersection = ct_full_df[
    (ct_full_df[ASSAY] == "input")
    & (ct_full_df["harmonized_biomaterial_type"] == "primary cell")
    & (ct_full_df["harmonized_sample_disease_high"] == "Healthy/None")
].value_counts(CELL_TYPE)
intersection = intersection / intersection.sum() * 100
display(intersection)

In [None]:
# Cell type breakdown (%) of all primary-tissue samples.
intersection = ct_full_df.loc[
    ct_full_df["harmonized_biomaterial_type"].eq("primary tissue")
].value_counts(CELL_TYPE)
intersection = intersection.div(intersection.sum()).mul(100)
display(intersection)

### Subclass (assay, ct, life_stage) accuracy

In [None]:
assay_labels = sorted(ct_full_df[ASSAY].unique())
ct_labels = sorted(ct_full_df[CELL_TYPE].unique())
life_stages = ct_full_df["harmonized_donor_life_stage"].unique().tolist()
# Drop the "unknown" label only if present; list.remove raises ValueError otherwise.
if "unknown" in life_stages:
    life_stages.remove("unknown")

# One row per (assay, cell type, life stage) combination, including empty ones:
# those get Size 0 and a NaN Accuracy (mean of an empty selection).
acc_list = []
for assay_label in assay_labels:
    assay_df = ct_full_df[ct_full_df[ASSAY] == assay_label]
    for ct in ct_labels:
        ct_df = assay_df[assay_df[CELL_TYPE] == ct]
        for life_stage in life_stages:
            life_stage_df = ct_df[ct_df["harmonized_donor_life_stage"] == life_stage]
            acc = life_stage_df["Predicted class"].eq(life_stage_df["True class"]).mean()
            size = len(life_stage_df)
            acc_list.append((assay_label, ct, life_stage, size, acc))

acc_df = pd.DataFrame(
    acc_list, columns=["Assay", "Cell Type", "Life Stage", "Size", "Accuracy"]
)

In [None]:
# acc_df.to_csv(base_fig_dir / "flagship" / "assay_ct_life_stage_accuracy.csv")