# Compute hits

Apply filters based on class imbalance and minimum number of classifiers per allele pair, compute the mean AUROC per batch, and determine whether each batch passed its respective threshold.

In [1]:
# imports
import polars as pl

In [11]:
# Tunable parameters.
# `thresh` is the exclusive upper bound applied to the Training_imbalance column;
# `min_class_num` is the minimum classifier count required in each batch.
thresh = 3  # previously 10
min_class_num = 2

# Directory holding the classification results; metrics.csv is read from it.
metrics_dir = "/dgx1nas1/storage/data/jess/varchamp/sc_data/classification_results/B7B8_1percent_updatedmeta"
metrics_df = pl.read_csv(metrics_dir + "/metrics.csv")

In [12]:
# Derive three helper columns in one pass (none of them depends on another):
#   Allele_set      - 7 characters of Plate starting at offset 13, with the
#                     first match of the regex "R.*_" removed
#   Classifier_type - "localization" if Full_Classifier_ID contains "true",
#                     otherwise "morphology"
#   Batch           - "batch7" if Full_Classifier_ID contains "B7A",
#                     otherwise "batch8"
metrics_df = metrics_df.with_columns(
    Allele_set=pl.col("Plate").str.slice(13, 7).str.replace("R.*_", ""),
    Classifier_type=(
        pl.when(pl.col("Full_Classifier_ID").str.contains("true"))
        .then(pl.lit("localization"))
        .otherwise(pl.lit("morphology"))
    ),
    Batch=(
        pl.when(pl.col("Full_Classifier_ID").str.contains("B7A"))
        .then(pl.lit("batch7"))
        .otherwise(pl.lit("batch8"))
    ),
)

# AUROC threshold per (Classifier_type, Batch): the 0.99 quantile of AUROC over
# control classifiers whose Training_imbalance is below `thresh`.
metrics_ctrl = (
    metrics_df.filter(
        pl.col("Metadata_Control"),
        pl.col("Training_imbalance") < thresh,
    )
    .group_by(["Classifier_type", "Batch"])
    .agg(pl.col("AUROC").quantile(0.99).alias("AUROC_thresh"))
)

print(metrics_ctrl)

shape: (4, 3)
┌─────────────────┬────────┬──────────────┐
│ Classifier_type ┆ Batch  ┆ AUROC_thresh │
│ ---             ┆ ---    ┆ ---          │
│ str             ┆ str    ┆ f64          │
╞═════════════════╪════════╪══════════════╡
│ morphology      ┆ batch8 ┆ 0.991648     │
│ localization    ┆ batch8 ┆ 0.819394     │
│ morphology      ┆ batch7 ┆ 0.971686     │
│ localization    ┆ batch7 ┆ 0.719226     │
└─────────────────┴────────┴──────────────┘


In [13]:
# Attach the matching AUROC_thresh to every classifier row.
# Inner join: rows whose (Classifier_type, Batch) has no threshold are dropped.
metrics_df = metrics_df.join(
    metrics_ctrl,
    on=["Classifier_type", "Batch"],
    how="inner",
)

In [14]:
# Count classifiers per allele pair in each batch; these counts are later
# compared against min_class_num.
# The number of classifiers is taken to be the same for localization and
# morphology, so only the localization rows are counted.
classifier_count = (
    metrics_df.filter(
        pl.col("Classifier_type") == "localization",
        ~pl.col("Metadata_Control"),
        pl.col("Training_imbalance") < thresh,
    )
    .group_by(["allele_0", "Allele_set", "Batch", "allele_1"])
    .agg(pl.len().alias("Number_classifiers"))
    # Reshape to one row per allele pair / allele set and one column per batch.
    .pivot(
        index=["allele_0", "allele_1", "Allele_set"],
        columns="Batch",
        values="Number_classifiers",
    )
)

# Before filtering: table shape and number of distinct alleles on each side.
print(classifier_count.shape)
print(classifier_count["allele_0"].n_unique())
print(classifier_count["allele_1"].n_unique())

# Keep only rows with at least min_class_num classifiers in BOTH batches
# (rows where either batch count is null are dropped by the filter as well).
classifier_count = classifier_count.filter(
    pl.col("batch7") >= min_class_num,
    pl.col("batch8") >= min_class_num,
)

# After filtering: same three figures for comparison.
print(classifier_count.shape)
print(classifier_count["allele_0"].n_unique())
print(classifier_count["allele_1"].n_unique())

# Restrict metrics_df to the surviving alleles: every control row is kept, and
# a non-control row is kept only if its allele_0 is still in classifier_count.
keep_alleles = classifier_count["allele_0"].unique().to_list()
metrics_df = metrics_df.filter(
    pl.col("Metadata_Control") | pl.col("allele_0").is_in(keep_alleles)
)

(573, 5)
573
113
(471, 5)
471
102


In [15]:
# Per-allele summary: keep non-control classifiers whose Training_imbalance is
# below `thresh`, then average every remaining metric within each
# (Classifier_type, allele_0, Allele_set, Batch, AUROC_thresh) group.
# Each averaged column is written with a "_mean" suffix (e.g. AUROC_mean).
summary_keys = ["Classifier_type", "allele_0", "Allele_set", "Batch", "AUROC_thresh"]
metrics_wtvar = (
    metrics_df.filter(
        (pl.col("Training_imbalance") < thresh) & (~pl.col("Metadata_Control"))
    )
    .select([
        "AUROC",
        "Classifier_type",
        "Batch",
        "AUROC_thresh",
        "allele_0",
        "trainsize_0",
        "testsize_0",
        "trainsize_1",
        "testsize_1",
        "Allele_set",
        "Training_imbalance",
    ])
    .group_by(summary_keys)
    .agg([
        pl.all()
        .exclude(summary_keys)
        .mean()
        # Expr.suffix is deprecated; Expr.name.suffix is its replacement and
        # produces the same "<column>_mean" output names.
        .name.suffix("_mean")
    ])
)

# Write out results
metrics_wtvar.write_csv(f"{metrics_dir}/metrics_summary.csv")

