# Debugging the analysis module

In [3]:
## imports
import polars as pl
from sklearn.metrics import (
    average_precision_score,
    balanced_accuracy_score,
    confusion_matrix,
    f1_score,
    roc_auc_score,
)
from tqdm import tqdm


# Helper functions that compute the evaluation metrics for each classifier's predictions
def compute_aubprc(auprc, prior):
    """Return the balanced AUPRC for a given positive-class prior.

    The raw ``auprc`` is re-weighted by ``prior`` so that the class imbalance
    is incorporated into the score.
    """
    # Weight the "hit" mass by the negative-class share and the "miss" mass
    # by the positive-class share, then renormalise.
    weighted_hit = auprc * (1 - prior)
    weighted_miss = (1 - auprc) * prior
    return weighted_hit / (weighted_hit + weighted_miss)


def compute_metrics(group):
    """Calculate all the metrics for evaluating one classifier.

    Args:
        group: Table holding the predictions of a single classifier. The
            columns ``Label``, ``Prediction`` and ``Classifier_ID`` are read.
            Labels are assumed to be encoded as 0/1 (the hard predictions are
            built as 0/1 and the prior counts ``Label == 1``) -- TODO confirm.

    Returns:
        dict with the keys ``AUROC``, ``AUPRC``, ``AUBPRC``, ``Macro_F1``,
        ``Sensitivity``, ``Specificity``, ``Balanced_Accuracy`` and
        ``Classifier_ID``.
    """
    y_true = group["Label"].to_numpy()
    y_prob = group["Prediction"].to_numpy()
    # Hard class calls use a fixed 0.5 cut-off on the predicted score.
    y_pred = (y_prob > 0.5).astype(int)
    # Fraction of samples labelled 1; vectorised instead of the Python-level
    # ``sum`` over a NumPy array.
    prior = (y_true == 1).mean()

    # Only the first unique ID is kept; the group is expected to hold one.
    class_ID = group["Classifier_ID"].unique()[0]

    # Compute AUROC
    auroc = roc_auc_score(y_true, y_prob)

    # Compute AUPRC and balanced AUPRC
    auprc = average_precision_score(y_true, y_prob)
    aubprc = compute_aubprc(auprc, prior)

    # Compute macro-averaged F1 score
    macro_f1 = f1_score(y_true, y_pred, average="macro")

    # Compute sensitivity and specificity.
    # ``labels=[0, 1]`` forces a 2x2 matrix even when only one class occurs in
    # both y_true and y_pred; without it the matrix collapses to 1x1 and the
    # four-way unpacking below fails.
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    sensitivity = tp / (tp + fn)
    specificity = tn / (tn + fp)

    # Compute balanced accuracy
    balanced_acc = balanced_accuracy_score(y_true, y_pred)

    return {
        "AUROC": auroc,
        "AUPRC": auprc,
        "AUBPRC": aubprc,
        "Macro_F1": macro_f1,
        "Sensitivity": sensitivity,
        "Specificity": specificity,
        "Balanced_Accuracy": balanced_acc,
        "Classifier_ID": class_ID,
    }


In [8]:
def calculate_class_metrics(classifier_info: str, predictions: str, metrics_file: str):
    """Compute per-classifier metrics for one batch and save them as CSV.

    Args:
        classifier_info: Path to the classifier-info CSV. One of its "/"
            separated components must contain "Batch"; the text after
            "Batch_" becomes the batch ID (e.g. ".../2025_01_27_Batch_13/..."
            gives "B13"). The columns ``Classifier_ID``, ``trainsize_0``,
            ``trainsize_1``, ``testsize_0`` and ``testsize_1`` are used.
        predictions: Path to the predictions parquet file. The columns
            ``Classifier_ID`` and ``Metadata_Protein`` are used here; each
            per-classifier table is passed on to ``compute_metrics``.
        metrics_file: Path the resulting metrics CSV is written to.

    Returns:
        The metrics table joined with the classifier info, including the
        ``Training_imbalance`` and ``Testing_imbalance`` ratios.

    Raises:
        IndexError: If no component of ``classifier_info`` contains "Batch".
    """
    batch_dir = [subdir for subdir in classifier_info.split("/") if "Batch" in subdir][0]
    batch_id = f"B{batch_dir.split('Batch_')[-1]}"

    ## read in classifier info
    class_info = pl.read_csv(classifier_info)

    ## calculate the class imbalance in training and testing per each classifier
    class_info = class_info.with_columns(
        (pl.col("trainsize_1") / (pl.col("trainsize_0") + pl.col("trainsize_1"))).alias(
            "train_prob_1"
        ),
        (pl.col("testsize_1") / (pl.col("testsize_0") + pl.col("testsize_1"))).alias(
            "test_prob_1"
        ),
    )

    ## read in predictions and tag every row with its batch
    preds = pl.scan_parquet(predictions)
    preds = preds.with_columns(pl.lit(batch_id).alias("Batch")).collect()
    ## build an ID of the form <Classifier_ID>_<Metadata_Protein>_<Batch>
    preds = preds.with_columns(
        pl.concat_str(
            [pl.col("Classifier_ID"), pl.col("Metadata_Protein"), pl.col("Batch")],
            separator="_",
        ).alias("Full_Classifier_ID")
    )

    # Split the predictions once into one table per classifier instead of
    # re-filtering the full table for every ID.
    # NOTE: this holds a partitioned copy of the predictions in memory.
    per_classifier = preds.partition_by("Full_Classifier_ID", maintain_order=True)

    # Compute the metrics of every classifier
    results = []
    for group in tqdm(per_classifier):
        metrics = compute_metrics(group)
        metrics["Full_Classifier_ID"] = group["Full_Classifier_ID"][0]
        results.append(metrics)

    # Convert the results to a Polars DataFrame
    metrics_df = pl.DataFrame(results)

    # Add classifier info and save
    metrics_df = metrics_df.join(class_info, on="Classifier_ID")
    # Imbalance is the ratio of the larger class to the smaller class (>= 1)
    metrics_df = metrics_df.with_columns(
        (
            pl.max_horizontal(["trainsize_0", "trainsize_1"])
            / pl.min_horizontal(["trainsize_0", "trainsize_1"])
        ).alias("Training_imbalance"),
        (
            pl.max_horizontal(["testsize_0", "testsize_1"])
            / pl.min_horizontal(["testsize_0", "testsize_1"])
        ).alias("Testing_imbalance"),
    )
    metrics_df.write_csv(metrics_file)
    return metrics_df


def _count_classifiers(metrics_df, trn_imbal_thres=None):
    """Count non-control localization classifiers per allele, one column per batch.

    Only localization classifiers are counted; per the original author's note
    the counts are expected to be the same for morphology classifiers.
    When ``trn_imbal_thres`` is given, classifiers whose ``Training_imbalance``
    is not below it are left out of the count.
    """
    keep = (~pl.col("Metadata_Control")) & (pl.col("Classifier_type") == "localization")
    if trn_imbal_thres is not None:
        keep = (pl.col("Training_imbalance") < trn_imbal_thres) & keep

    counts = (
        metrics_df.filter(keep)
        .group_by(["allele_0", "Allele_set", "Batch", "allele_1"])
        .agg([pl.len().alias("Number_classifiers")])
    )
    return counts.pivot(
        index=["allele_0", "allele_1", "Allele_set"],
        on="Batch",
        values="Number_classifiers",
    )


def _print_count_summary(classifier_count):
    """Print the number of rows, variant alleles and WT genes in a count table."""
    print("Total number of unique classifiers:", classifier_count.shape[0])
    print("Total number of unique variant alleles:", classifier_count["allele_0"].n_unique())
    print("Total number of unique WT genes:", classifier_count["allele_1"].n_unique())


def compute_hits(metrics_file: str, metrics_summary_file: str, trn_imbal_thres: int, min_num_classifier: int):
    """Summarise per-classifier metrics into per-allele mean metrics.

    Steps:
      1. Read the metrics CSV and derive ``Allele_set``, ``Classifier_type``
         and ``Batch`` columns.
      2. Compute, per classifier type and batch, the 0.99 quantile of the
         AUROC of control classifiers with acceptable training imbalance and
         attach it as ``AUROC_thresh``.
      3. Print how many classifiers / alleles / WT genes remain after each
         filtering step.
      4. Drop non-control alleles that have fewer than ``min_num_classifier``
         classifiers in this batch.
      5. Average the numeric columns of the non-control classifiers per
         classifier type, allele, allele set and batch.

    Args:
        metrics_file: Path to the metrics CSV. One of its "/" separated
            components must contain "Batch"; the text after "Batch_" becomes
            the batch ID (e.g. "2025_01_27_Batch_13" gives "B13").
        metrics_summary_file: Currently unused -- writing the summary to disk
            is disabled and the summary is only returned.
        trn_imbal_thres: Classifiers need ``Training_imbalance`` strictly
            below this value to count towards the control threshold and the
            per-allele classifier count.
        min_num_classifier: Minimum number of classifiers an allele needs in
            this batch to be kept.

    Returns:
        A polars DataFrame with one row per (Classifier_type, allele_0,
        Allele_set, Batch, AUROC_thresh) and ``*_mean`` columns.

    Raises:
        IndexError: If no component of ``metrics_file`` contains "Batch".
    """
    metrics_df = pl.read_csv(metrics_file)

    batch_dir = [subdir for subdir in metrics_file.split("/") if "Batch" in subdir][0]
    batch_id = f"B{batch_dir.split('Batch_')[-1]}"

    # Add useful columns (type, batch)
    metrics_df = metrics_df.with_columns(
        #  Group the same alleles from platemap together, by extracting the substring that:
        #  1. Has a digit (\d) immediately before it (anchors the match at a number)
        #  2. Starts with 'A' and then as few characters as needed (A.*?), captured as group 1
        #  3. Stops right before the literal 'T'
        pl.col("Plate").str.extract(r"\d(A.*?)T", 1).alias("Allele_set"),
    )

    metrics_df = metrics_df.with_columns(
        # IDs containing "true" are labelled localization, all others morphology
        pl.when(pl.col("Full_Classifier_ID").str.contains("true"))
        .then(pl.lit("localization"))
        .otherwise(pl.lit("morphology"))
        .alias("Classifier_type"),
        # The batch is the last "_"-separated token of the full classifier ID
        pl.col("Full_Classifier_ID").str.split("_").list.last().alias("Batch"),
    )

    # Control threshold: 0.99 quantile of the control AUROCs, using only
    # control classifiers with acceptable training imbalance
    metrics_ctrl = (
        metrics_df.filter(
            (pl.col("Training_imbalance") < trn_imbal_thres) & (pl.col("Metadata_Control"))
        )
        .select(["Classifier_type", "Batch", "AUROC"])
        .group_by(["Classifier_type", "Batch"])
        .quantile(0.99)
    ).rename({"AUROC": "AUROC_thresh"})

    # Attach the threshold to every classifier of the same type and batch
    metrics_df = metrics_df.join(metrics_ctrl, on=["Classifier_type", "Batch"])

    # Report the counts before any filtering
    _print_count_summary(_count_classifiers(metrics_df))
    print("==========================================================================")

    # Report the counts after dropping classifiers with too much training imbalance
    classifier_count = _count_classifiers(metrics_df, trn_imbal_thres)
    print(f"After filtering out classifiers with training imbalance >= {trn_imbal_thres}:")
    _print_count_summary(classifier_count)
    print("==========================================================================")

    # Must be at least min_num_classifier classifiers in this batch.
    # The pivot created one column per batch, so the batch ID derived from the
    # path is used as the column name (assumes it matches the "Batch" values).
    classifier_count = classifier_count.filter(pl.col(batch_id) >= min_num_classifier)
    print(f"After filtering out alleles with available number of classifiers < {min_num_classifier}:")
    _print_count_summary(classifier_count)

    # Keep every control classifier, plus the non-control classifiers whose
    # allele passed the count filter
    keep_alleles = classifier_count.select("allele_0").to_series().unique().to_list()
    metrics_df = metrics_df.filter(
        pl.col("Metadata_Control") | pl.col("allele_0").is_in(keep_alleles)
    )

    # Average the metrics of the non-control classifiers per allele.
    # The per-classifier training-imbalance filter is NOT applied at this step.
    group_keys = ["Classifier_type", "allele_0", "Allele_set", "Batch", "AUROC_thresh"]
    metrics_wtvar = (
        metrics_df.filter(~pl.col("Metadata_Control"))
        .select([
            "AUROC",
            "Classifier_type",
            "Batch",
            "AUROC_thresh",
            "allele_0",
            "trainsize_0",
            "testsize_0",
            "trainsize_1",
            "testsize_1",
            "Allele_set",
            "Training_imbalance",
        ])
        .group_by(group_keys)
        .agg([
            pl.all()
            .exclude(group_keys)
            .mean()
            .name.suffix("_mean")
        ])
    )

    # NOTE: writing the summary to metrics_summary_file is currently disabled;
    # the summary is only returned to the caller.
    return metrics_wtvar

In [9]:
# Summarise the Batch 13 metrics: training-imbalance cut-off of 3 and at least
# 2 classifiers per allele. No summary path is supplied (empty string).
metrics_file = "../outputs/classification_analyses/2025_01_27_Batch_13/profiles_tcdropped_filtered_var_mad_outlier_featselect/metrics.csv"
compute_hits(
    metrics_file=metrics_file,
    metrics_summary_file="",
    trn_imbal_thres=3,
    min_num_classifier=2,
)

Total number of unique classifiers: 439
Total number of unique variant alleles: 439
Total number of unique WT genes: 23
After filtering out classifiers with training imbalance > 3:
Total number of unique classifiers: (385, 4)
Total number of unique variant alleles: 385
Total number of unique WT genes: 23
After filtering out alleles with available number of classifiers < 2:
Total number of unique classifiers: (360, 4)
Total number of unique variant alleles: 360
Total number of unique WT genes: 23


Classifier_type,allele_0,Allele_set,Batch,AUROC_thresh,AUROC_mean,trainsize_0_mean,testsize_0_mean,trainsize_1_mean,testsize_1_mean,Training_imbalance_mean
str,str,str,str,f64,f64,f64,f64,f64,f64,f64
"""localization""","""KRAS_Gly12Asp""","""A7A8P1_""","""B13""",0.923142,0.75201,1269.0,423.0,1374.0,458.0,1.147595
"""localization""","""CCM2_Gly407Asp""","""A7A8P2_""","""B13""",0.923142,0.783035,1341.75,447.25,3903.0,1301.0,2.915331
"""localization""","""F9_Gly442Arg""","""A7A8P1_""","""B13""",0.923142,0.549513,1122.0,374.0,757.5,252.5,1.484793
"""morphology""","""RET_Val262Ala""","""A7A8P2_""","""B13""",0.937407,0.835261,870.75,290.25,1044.0,348.0,1.206279
"""localization""","""BRAF_Gly596Val""","""A7A8P1_""","""B13""",0.923142,0.893082,1570.5,523.5,1419.0,473.0,1.101321
…,…,…,…,…,…,…,…,…,…,…
"""morphology""","""BRCA1_Ala102Gly""","""A7A8P1_""","""B13""",0.937407,0.877495,1113.75,371.25,1239.0,413.0,1.170862
"""morphology""","""BRCA1_Thr587Arg""","""A7A8P1_""","""B13""",0.937407,0.905352,1232.25,410.75,1239.0,413.0,1.145495
"""morphology""","""RPS19_Arg62Trp""","""A7A8P2_""","""B13""",0.937407,0.725234,420.75,140.25,442.5,147.5,1.075846
"""localization""","""SDHD_His102Pro""","""A7A8P2_""","""B13""",0.923142,0.623399,1325.25,441.75,1876.5,625.5,1.418962


In [None]:
# Load the metrics summary CSV stored on disk for Batch 13 and display it.
# NOTE(review): this rebinds `metrics_file` to the summary path, so the name no
# longer points at metrics.csv after this cell runs.
# NOTE(review): in the output recorded below the Allele_set column is empty --
# presumably the file was written by an earlier version of the analysis; confirm.
metrics_file = "../outputs/classification_analyses/2025_01_27_Batch_13/profiles_tcdropped_filtered_var_mad_outlier_featselect/metrics_summary.csv"
metrics_summ = pl.read_csv(metrics_file)
# Last expression of the cell: rendered by the notebook as a table
metrics_summ

Classifier_type,allele_0,Allele_set,Batch,AUROC_thresh,AUROC_mean,trainsize_0_mean,testsize_0_mean,trainsize_1_mean,testsize_1_mean,Training_imbalance_mean
str,str,str,str,f64,f64,f64,f64,f64,f64,f64
"""morphology""","""F9_Arg3His""",,"""B13""",0.937407,0.619266,554.25,184.75,757.5,252.5,1.366262
"""morphology""","""CCM2_Gly5Ser""",,"""B13""",0.937407,0.76576,1458.75,486.25,3903.0,1301.0,2.684427
"""morphology""","""CTCF_Asp46Asn""",,"""B13""",0.937407,0.79788,363.75,121.25,720.75,240.25,2.013674
"""localization""","""BRAF_Gln257Arg""",,"""B13""",0.923142,0.709575,810.0,270.0,1419.0,473.0,1.843124
"""localization""","""RAD51D_Val56Gly""",,"""B13""",0.923142,0.696832,1219.5,406.5,984.0,328.0,1.237085
…,…,…,…,…,…,…,…,…,…,…
"""morphology""","""CCM2_Arg412Gln""",,"""B13""",0.937407,0.653131,2172.75,724.25,3903.0,1301.0,1.822314
"""morphology""","""F9_Ile7Phe""",,"""B13""",0.937407,0.626554,598.5,199.5,757.5,252.5,1.268307
"""morphology""","""KRAS_Gly12Asp""",,"""B13""",0.937407,0.786289,1269.0,423.0,1374.0,458.0,1.147595
"""localization""","""F9_Glu79Asp""",,"""B13""",0.923142,0.583162,494.25,164.75,757.5,252.5,1.528803
