# Debugging the classify module

In [19]:
"""Classification pipeline"""

import os
import sys
import warnings
from itertools import combinations
from typing import Union
import cupy as cp
import numpy as np
import pandas as pd
import polars as pl
import pyarrow as pa
import pyarrow.parquet as pq
import xgboost as xgb
from sklearn.preprocessing import LabelEncoder
from tqdm import tqdm
from tqdm.contrib.concurrent import thread_map

warnings.filterwarnings("ignore")
sys.path.append("..")
from utils import find_feat_cols, find_meta_cols, remove_nan_infs_columns


def classifier(df_train, df_test, log_file, target="Label", shuffle=False):
    """
    Train one XGBoost binary classifier and evaluate it on held-out cells.

    Parameters
    ----------
    df_train, df_test : pd.DataFrame
        Training and test cells. Feature columns are ``find_feat_cols(df_train)``
        minus ``target``. Both frames must also carry the metadata columns read
        below (Metadata_Plate, Metadata_symbol, Metadata_well_position,
        Metadata_gene_allele, Metadata_ImageNumber, Metadata_ObjectNumber).
    log_file : file-like
        Text stream that receives a message whenever the classifier is skipped.
    target : str
        Name of the label column; positives are 1 and negatives 0.
    shuffle : bool
        If True, the model is fit on a random permutation of the training
        labels (label-shuffled null model). Class counts are unchanged.

    Returns
    -------
    tuple(pd.Series, pd.DataFrame | None, pd.DataFrame | None)
        ``(feat_importances, classifier_df, pred_df)``. The classifier is
        skipped in three cases:

        - a class is missing from the training set;
        - the training class ratio lies outside [0.01, 100];
        - a class is missing from the test set.

        When skipped, the importances are all NaN and both frames are None.
        The tuple always has three items so callers can unpack it
        unconditionally.
    """
    feat_col = [col for col in find_feat_cols(df_train) if col != target]

    def _skip():
        # Same 3-item shape as the success path; pd.concat drops None entries.
        return pd.Series(np.nan, index=feat_col), None, None

    num_pos = int((df_train[target] == 1).sum())
    num_neg = int((df_train[target] == 0).sum())

    if (num_pos == 0) or (num_neg == 0):
        log_file.write(
            f"Missing positive/negative labels in {df_train['Metadata_Plate'].unique()}, "
            f"{df_train['Metadata_symbol'].unique()} wells: "
            f"{df_train['Metadata_well_position'].unique()}\n"
        )
        log_file.write(f"Size of pos: {num_pos}, Size of neg: {num_neg}\n")
        print(f"size of pos: {num_pos}, size of neg: {num_neg}")
        return _skip()

    scale_pos_weight = num_neg / num_pos

    if (scale_pos_weight > 100) or (scale_pos_weight < 0.01):
        log_file.write(
            f"Extreme class imbalance in {df_train['Metadata_Plate'].unique()}, "
            f"{df_train['Metadata_symbol'].unique()} wells: "
            f"{df_train['Metadata_well_position'].unique()}\n"
        )
        log_file.write(f"Scale_pos_weight: {scale_pos_weight}, Size of pos: {num_pos}, Size of neg: {num_neg}\n")
        print(
            f"scale_pos_weight: {scale_pos_weight}, size of pos: {num_pos}, size of neg: {num_neg}"
        )
        return _skip()

    # The summary below needs one test cell of each class; check before
    # spending GPU time on training.
    test_neg = df_test[df_test[target] == 0]
    test_pos = df_test[df_test[target] == 1]
    if test_neg.empty or test_pos.empty:
        log_file.write(
            f"Missing positive/negative test labels in {df_test['Metadata_Plate'].unique()}, "
            f"wells: {df_test['Metadata_well_position'].unique()}\n"
        )
        return _skip()

    x_train = cp.array(df_train[feat_col].to_numpy())
    x_test = cp.array(df_test[feat_col].to_numpy())

    # Fit the encoder on the training labels only and reuse that mapping for
    # the test labels, so both sets share the same integer codes.
    le = LabelEncoder()
    y_train = cp.array(le.fit_transform(df_train[target].to_numpy()))
    y_test = cp.array(le.transform(df_test[target].to_numpy()))

    if shuffle:
        # Null model: break the feature/label association, keep class counts.
        y_train = cp.random.permutation(y_train)

    model = xgb.XGBClassifier(
        objective="binary:logistic",
        n_estimators=150,
        tree_method="hist",
        device="cuda",
        learning_rate=0.05,
        scale_pos_weight=scale_pos_weight,
    ).fit(x_train, y_train, verbose=False)

    # Predicted probability of the positive class (encoded 1) per test cell
    pred_score = model.predict_proba(x_test)[:, 1]

    y_train_host = y_train.get()
    y_test_host = y_test.get()

    # Return classifier info
    info_0 = test_neg.iloc[0]
    info_1 = test_pos.iloc[0]
    class_ID = (
        f"{info_0['Metadata_Plate']}_{info_0['Metadata_well_position']}"
        f"_{info_1['Metadata_well_position']}"
    )
    classifier_df = pd.DataFrame({
        "Classifier_ID": [class_ID],
        "Plate": [info_0["Metadata_Plate"]],
        "trainsize_0": [int((y_train_host == 0).sum())],
        "testsize_0": [int((y_test_host == 0).sum())],
        "well_0": [info_0["Metadata_well_position"]],
        "allele_0": [info_0["Metadata_gene_allele"]],
        "trainsize_1": [int((y_train_host == 1).sum())],
        "testsize_1": [int((y_test_host == 1).sum())],
        "well_1": [info_1["Metadata_well_position"]],
        "allele_1": [info_1["Metadata_gene_allele"]],
    })

    # Store feature importance
    feat_importances = pd.Series(model.feature_importances_, index=feat_col)

    # Cell IDs "<plate>_<well>_<image>_<object>", built column-wise instead of
    # with a per-row apply.
    cell_id = (
        df_test["Metadata_Plate"].astype(str)
        + "_" + df_test["Metadata_well_position"].astype(str)
        + "_" + df_test["Metadata_ImageNumber"].astype(str)
        + "_" + df_test["Metadata_ObjectNumber"].astype(str)
    ).to_list()

    pred_df = pd.DataFrame({
        "Classifier_ID": class_ID,
        "CellID": cell_id,
        "Label": y_test_host,
        "Prediction": pred_score,
    })

    return feat_importances, classifier_df, pred_df


def get_classifier_features(dframe: pd.DataFrame, protein_feat: bool):
    """Return ``dframe`` restricted to its metadata plus one feature family.

    When ``protein_feat`` is true, the kept features mention "GFP" and none
    of "DNA", "AGP", "Mito" or "Brightfield". Otherwise the kept features
    mention neither "GFP" nor "Brightfield". Metadata columns come first,
    followed by the selected features in their original order.
    """
    all_feats = find_feat_cols(dframe)
    meta_cols = find_meta_cols(dframe)

    if not protein_feat:
        # To test without AGP features, also require '"AGP" not in name' here.
        chosen = [
            name for name in all_feats
            if not ("GFP" in name or "Brightfield" in name)
        ]
    else:
        other_channels = ("DNA", "AGP", "Mito", "Brightfield")
        chosen = [
            name for name in all_feats
            if "GFP" in name and not any(tag in name for tag in other_channels)
        ]

    return pd.concat([dframe[meta_cols], dframe[chosen]], axis=1)


def stratify_by_plate(df_sampled: pd.DataFrame, plate: str):
    """Split ``df_sampled`` into leave-one-plate-out train/test sets.

    The test set is every row on ``plate``. The training set is every row on
    the other plates sharing ``plate``'s platemap ("Metadata_plate_map_name").

    The platemap is looked up from rows whose "Metadata_Plate" contains the
    part of ``plate`` before its first "_T". That prefix is shared by its
    replicate plates, e.g. "..._T1" to "..._T4".

    Raises
    ------
    ValueError
        If those rows do not map to exactly one platemap.

    Returns
    -------
    (df_train, df_test) : tuple of pd.DataFrame, each with a fresh RangeIndex.
    """
    plate_prefix = plate.split("_T")[0]
    # Literal substring match: plate names are IDs, not regex patterns.
    same_prefix = df_sampled["Metadata_Plate"].str.contains(
        plate_prefix, regex=False, na=False
    )
    platemaps = df_sampled.loc[same_prefix, "Metadata_plate_map_name"].unique()
    if len(platemaps) != 1:
        # Explicit raise (not assert) so the check survives `python -O`.
        raise ValueError(f"Only one platemap should be associated with plate: {plate}.")
    platemap = platemaps[0]

    # Train on data from same platemap but other plates
    df_train = df_sampled[
        (df_sampled["Metadata_plate_map_name"] == platemap)
        & (df_sampled["Metadata_Plate"] != plate)
    ].reset_index(drop=True)

    df_test = df_sampled[df_sampled["Metadata_Plate"] == plate].reset_index(drop=True)

    return df_train, df_test


def experimental_runner(
    exp_dframe: pd.DataFrame,
    pq_writer,
    log_file,
    protein=True,
    group_key_one="Metadata_symbol",
    group_key_two="Metadata_gene_allele",
    threshold_key="Metadata_node_type",
):
    """
    Run Reference v.s. Variant experiments.

    The analysis covers every gene (``group_key_one``) whose ``threshold_key``
    column has at least two distinct values. Within such a gene, the
    "disease_wt" cells (reference, Label=1) are compared with the cells of
    each "allele" variant (``group_key_two``, Label=0).

    One classifier is trained for every (reference well, variant well) pair on
    every plate shared by the two groups. ``stratify_by_plate`` holds that
    plate out for testing.

    Cell-level predictions are appended to ``pq_writer``. An exception raised
    while processing one well pair is printed and logged to ``log_file``, and
    that pair is skipped.

    Returns
    -------
    df_feat : pd.DataFrame
        One row per classifier: "Group1" (gene), "Group2" ("<gene>_<allele>"),
        one column per feature importance, "Metadata_Protein" and
        "Metadata_Control" (False).
    df_result : pd.DataFrame
        Concatenated per-classifier summaries with "Metadata_Control" = False
        (no summary rows if no classifier produced one).
    """
    exp_dframe = get_classifier_features(exp_dframe, protein)

    group_list = []
    pair_list = []
    feat_list = []
    info_list = []

    log_file.write(f"Running XGBboost classifiers w/ protein {protein} on target variants:\n")
    groups = exp_dframe.groupby(group_key_one).groups
    for key in tqdm(groups.keys()):
        dframe_grouped = exp_dframe.loc[groups[key]].reset_index(drop=True)

        # Ensure this gene has both reference and variants
        if dframe_grouped[threshold_key].unique().size < 2:
            continue

        df_group_one = dframe_grouped[
            dframe_grouped[threshold_key] == "disease_wt"
        ].reset_index(drop=True)
        df_group_one["Label"] = 1

        subgroups = (
            dframe_grouped[dframe_grouped[threshold_key] == "allele"]
            .groupby(group_key_two)
            .groups
        )

        for subkey in subgroups.keys():
            df_group_two = dframe_grouped.loc[subgroups[subkey]].reset_index(drop=True)
            df_group_two["Label"] = 0
            plate_list = get_common_plates(df_group_one, df_group_two)

            ref_wells = df_group_one["Metadata_well_position"].unique()
            var_wells = df_group_two["Metadata_well_position"].unique()
            ref_var_pairs = [(ref_well, var_well) for ref_well in ref_wells for var_well in var_wells]
            df_sampled_ = pd.concat([df_group_one, df_group_two], ignore_index=True)

            for ref_var in ref_var_pairs:
                df_sampled = df_sampled_[df_sampled_["Metadata_well_position"].isin(ref_var)]

                def classify_by_plate_helper(plate):
                    df_train, df_test = stratify_by_plate(df_sampled, plate)
                    # Unpack here so a malformed result fails inside thread_map,
                    # before any result of this pair is recorded.
                    feat_importances, classifier_info, predictions = classifier(
                        df_train, df_test, log_file
                    )
                    return feat_importances, classifier_info, predictions

                try:
                    results = thread_map(classify_by_plate_helper, plate_list)
                    pred_list = []
                    for feat_importances, classifier_info, predictions in results:
                        feat_list.append(feat_importances)
                        group_list.append(key)
                        pair_list.append(f"{key}_{subkey}")
                        info_list.append(classifier_info)
                        pred_list.append(predictions)

                    # Skipped classifiers return None predictions; write only
                    # when there is something to write.
                    pred_frames = [p for p in pred_list if p is not None]
                    if pred_frames:
                        cell_preds = pd.concat(pred_frames, axis=0)
                        cell_preds["Metadata_Protein"] = protein
                        cell_preds["Metadata_Control"] = False
                        table = pa.Table.from_pandas(cell_preds, preserve_index=False)
                        pq_writer.write_table(table)
                except Exception as e:
                    print(e)
                    log_file.write(f"{key}, {subkey} error: {e}\n")

    # Store feature importance
    df_feat_one = pd.DataFrame({"Group1": group_list, "Group2": pair_list})
    df_feat_two = pd.DataFrame(feat_list)
    df_feat = pd.concat([df_feat_one, df_feat_two], axis=1)
    df_feat["Metadata_Protein"] = protein
    df_feat["Metadata_Control"] = False

    # process classifier info (skipped classifiers contribute None)
    info_frames = [info for info in info_list if info is not None]
    df_result = pd.concat(info_frames, ignore_index=True) if info_frames else pd.DataFrame()
    df_result["Metadata_Control"] = False

    log_file.write(f"Finished running XGBboost classifiers w/ protein {protein} on target variants.\n")
    log_file.write(f"===========================================================================\n\n")
    return df_feat, df_result


def get_common_plates(dframe1, dframe2):
    """Return the "Metadata_Plate" values present in both dataframes, sorted.

    Sorting makes plate iteration order reproducible across runs. That order
    determines the order of classifiers and of written predictions. Iteration
    order of a set of strings changes with Python's hash randomization.
    """
    shared = set(dframe1["Metadata_Plate"].unique()) & set(
        dframe2["Metadata_Plate"].unique()
    )
    # key=str keeps sorting safe even if a non-string (e.g. NaN) slips in.
    return sorted(shared, key=str)


def control_group_runner(
    ctrl_dframe: pd.DataFrame,
    pq_writer,
    log_file,
    group_key_one="Metadata_gene_allele",
    group_key_two="Metadata_plate_map_name",
    group_key_three="Metadata_well_position",
    threshold_key="Metadata_well_position",
    protein=True,
    min_plates_per_well=4,
):
    """
    Run null control experiments.

    The analysis covers every control allele (``group_key_one``) found in at
    least two distinct ``threshold_key`` values, on each platemap
    (``group_key_two``) it appears on. On that platemap, wells
    ("Metadata_Well") present on fewer than ``min_plates_per_well`` plates
    are dropped.

    Every pair of remaining wells (``group_key_three``) of the allele is then
    classified against each other: first well Label=1, second well Label=0.
    This runs once per plate shared by both wells, with that plate held out
    via ``stratify_by_plate``.

    Cell-level predictions are appended to ``pq_writer``. An exception raised
    while processing one well pair is printed and logged to ``log_file``, and
    that pair is skipped.

    Returns
    -------
    df_feat : pd.DataFrame
        One row per classifier: "Group1" (allele), "Group2" ("<well1>_<well2>"),
        one column per feature importance, "Metadata_Protein" and
        "Metadata_Control" (True).
    df_result : pd.DataFrame
        Concatenated per-classifier summaries with "Metadata_Control" = True
        (no summary rows if no classifier produced one).
    """
    ctrl_dframe = get_classifier_features(ctrl_dframe, protein)

    group_list = []
    pair_list = []
    feat_list = []
    info_list = []

    log_file.write(f"Running XGBboost classifiers w/ protein {protein} on control alleles:\n")
    groups = ctrl_dframe.groupby(group_key_one).groups
    for key in tqdm(groups.keys()):
        # groupby alleles
        dframe_grouped = ctrl_dframe.loc[groups[key]].reset_index(drop=True)

        # Skip controls with no replicates
        if dframe_grouped[threshold_key].unique().size < 2:
            continue

        # group by platemap
        subgroups = dframe_grouped.groupby(group_key_two).groups
        for key_two in subgroups.keys():
            dframe_grouped_two = dframe_grouped.loc[subgroups[key_two]].reset_index(
                drop=True
            )
            # Drop wells that are not present on min_plates_per_well plates
            well_count = dframe_grouped_two.groupby(["Metadata_Well"])[
                "Metadata_Plate"
            ].nunique()
            well_to_drop = well_count[well_count < min_plates_per_well].index
            dframe_grouped_two = dframe_grouped_two[
                ~dframe_grouped_two["Metadata_Well"].isin(well_to_drop)
            ].reset_index(drop=True)

            # group by well
            sub_sub_groups = dframe_grouped_two.groupby(group_key_three).groups
            sampled_pairs = list(combinations(list(sub_sub_groups.keys()), r=2))

            for idx1, idx2 in sampled_pairs:
                df_group_one = dframe_grouped_two.loc[sub_sub_groups[idx1]].reset_index(
                    drop=True
                )
                df_group_one["Label"] = 1
                df_group_two = dframe_grouped_two.loc[sub_sub_groups[idx2]].reset_index(
                    drop=True
                )
                df_group_two["Label"] = 0
                df_sampled = pd.concat([df_group_one, df_group_two], ignore_index=True)

                try:
                    plate_list = get_common_plates(df_group_one, df_group_two)

                    def classify_by_plate_helper(plate):
                        df_train, df_test = stratify_by_plate(df_sampled, plate)
                        # Unpack here so a malformed result fails inside
                        # thread_map, before any result of this pair is recorded.
                        feat_importances, classifier_info, predictions = classifier(
                            df_train, df_test, log_file
                        )
                        return feat_importances, classifier_info, predictions

                    results = thread_map(classify_by_plate_helper, plate_list)

                    pred_list = []
                    for feat_importances, classifier_info, predictions in results:
                        feat_list.append(feat_importances)
                        group_list.append(key)
                        pair_list.append(f"{idx1}_{idx2}")
                        info_list.append(classifier_info)
                        pred_list.append(predictions)

                    # Skipped classifiers return None predictions; write only
                    # when there is something to write.
                    pred_frames = [p for p in pred_list if p is not None]
                    if pred_frames:
                        cell_preds = pd.concat(pred_frames, axis=0)
                        cell_preds["Metadata_Protein"] = protein
                        cell_preds["Metadata_Control"] = True
                        table = pa.Table.from_pandas(cell_preds, preserve_index=False)
                        pq_writer.write_table(table)
                except Exception as e:
                    print(e)
                    log_file.write(f"{key}, {key_two} error: {e}, wells per ctrl: {sub_sub_groups}\n")

    # Store feature importance
    df_feat_one = pd.DataFrame({"Group1": group_list, "Group2": pair_list})
    df_feat_two = pd.DataFrame(feat_list)
    df_feat = pd.concat([df_feat_one, df_feat_two], axis=1)
    df_feat["Metadata_Protein"] = protein
    df_feat["Metadata_Control"] = True

    # process classifier info (skipped classifiers contribute None)
    info_frames = [info for info in info_list if info is not None]
    df_result = pd.concat(info_frames, ignore_index=True) if info_frames else pd.DataFrame()
    df_result["Metadata_Control"] = True

    log_file.write(f"Finished running XGBboost classifiers w/ protein {protein} on control alleles.\n")
    log_file.write(f"===========================================================================\n\n")
    return df_feat, df_result


def control_type_helper(col_annot: str):
    """Map a "Metadata_node_type" value to its "Metadata_control" flag.

    TC, NC and PC map to True. They are the only types used to build the null
    distribution, because they have multiple duplicates. disease_wt, allele,
    cPC and cNC map to False (not controls). Any other value yields None.
    """
    null_controls = ("TC", "NC", "PC")
    non_controls = ("disease_wt", "allele", "cPC", "cNC")
    for members, flag in ((null_controls, True), (non_controls, False)):
        if col_annot in members:
            return flag
    return None


def add_control_annot(dframe):
    """Add a "Metadata_control" column derived from "Metadata_node_type".

    Each value comes from ``control_type_helper`` (True / False / None). The
    column is written into ``dframe`` in place, and only if it does not
    already exist. The same frame is returned.
    """
    if "Metadata_control" in dframe.columns:
        return dframe
    node_types = dframe["Metadata_node_type"]
    dframe["Metadata_control"] = node_types.apply(control_type_helper)
    return dframe


def drop_low_cc_wells(dframe, cc_thresh, log_file):
    """Drop every (plate, well) whose cell count is below ``cc_thresh``.

    Parameters
    ----------
    dframe : pd.DataFrame
        Single-cell rows with "Metadata_Plate", "Metadata_Well" and
        "Metadata_gene_allele" columns. The caller's frame is not modified.
    cc_thresh : int
        Minimum number of cells a well needs in order to be kept.
    log_file : file-like
        Text stream. It receives the number of dropped wells, then one
        "<plate>, <well>:<allele>" line per dropped well.

    Returns
    -------
    pd.DataFrame
        Rows from wells with at least ``cc_thresh`` cells, with a fresh
        RangeIndex. An added "Metadata_Cell_ID" column holds each row's index
        label in the input frame.
    """
    well_keys = ["Metadata_Plate", "Metadata_Well"]
    # assign() returns a copy, so the caller's frame is left untouched.
    dframe = dframe.assign(Metadata_Cell_ID=dframe.index)

    ## get the cell counts per well per plate
    cell_count = (
        dframe.groupby(well_keys)["Metadata_Cell_ID"]
        .count()
        .reset_index(name="Metadata_Cell_Count")
    )
    dframe = dframe.merge(cell_count, on=well_keys)

    low_cc = dframe["Metadata_Cell_Count"] < cc_thresh
    dropped_wells = dframe.loc[low_cc].drop_duplicates(subset=well_keys)

    ## keep track of the dropped wells (and their alleles) in the log file;
    ## the count matches the number of wells listed below it.
    log_file.write(f"Number of wells dropped due to cell counts < {cc_thresh}: {len(dropped_wells)}\n")
    for plate, well, allele in dropped_wells[
        ["Metadata_Plate", "Metadata_Well", "Metadata_gene_allele"]
    ].itertuples(index=False, name=None):
        log_file.write(f"{plate}, {well}:{allele}\n")

    ## keep only the wells with cc >= cc_thresh
    return (
        dframe.loc[~low_cc]
        .drop(columns=["Metadata_Cell_Count"])
        .reset_index(drop=True)
    )


def run_classify_workflow(
    input_path: str,
    feat_output_path: str,
    info_output_path: str,
    preds_output_path: str,
    cc_threshold: int,
    use_gpu: Union[str, None] = "0,1",
):
    """
    Run workflow for single-cell classification.

    Steps:

    1. Load single-cell profiles from ``input_path`` and add a
       "Metadata_CellID" column.
    2. Drop feature columns containing NaN/inf via ``remove_nan_infs_columns``.
    3. Split the cells into null controls (TC excluded) and
       reference/variant cells.
    4. Drop wells with fewer than ``cc_threshold`` cells.
    5. Run the control and reference-vs-variant classifiers on protein and
       non-protein features.

    Outputs:

    - cell-level predictions in ``preds_output_path`` (Parquet, overwritten);
    - feature importances in ``feat_output_path`` (CSV);
    - classifier summaries in ``info_output_path`` (CSV);
    - a log file "classify.log" in the directory of ``preds_output_path``.

    ``use_gpu``, when not None, is written to CUDA_VISIBLE_DEVICES.
    """
    if use_gpu is not None:
        # NOTE(review): likely only effective if CUDA has not been initialised
        # in this process yet (cupy is imported at module load) -- confirm.
        os.environ["CUDA_VISIBLE_DEVICES"] = use_gpu

    # The log file lives next to the predictions file; make sure the
    # directory exists.
    output_dir = os.path.dirname(preds_output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    logfile_path = os.path.join(output_dir, "classify.log")

    # Start from a fresh predictions file
    if os.path.exists(preds_output_path):
        os.remove(preds_output_path)

    schema = pa.schema([
        ("Classifier_ID", pa.string()),
        ("CellID", pa.string()),
        ("Label", pa.int64()),
        ("Prediction", pa.float32()),
        ("Metadata_Protein", pa.bool_()),
        ("Metadata_Control", pa.bool_()),
    ])

    # Load profiles and add CellID column
    dframe = (
        pl.scan_parquet(input_path)
        .with_columns(
            pl.concat_str(
                [
                    "Metadata_Plate",
                    "Metadata_well_position",
                    "Metadata_ImageNumber",
                    "Metadata_ObjectNumber",
                ],
                separator="_",
            ).alias("Metadata_CellID")
        )
        .collect()
        .to_pandas()
    )

    feat_col = find_feat_cols(dframe)
    # np.isfinite is False for both NaN and +/-inf. This is an explicit check
    # rather than `assert`, which is skipped entirely under `python -O`.
    if not np.isfinite(dframe[feat_col]).all().all():
        dframe = remove_nan_infs_columns(dframe)

    # Filter rows with NaN Metadata (copy so add_control_annot writes to a new frame)
    dframe = dframe[~dframe["Metadata_well_position"].isna()].copy()
    dframe = add_control_annot(dframe)
    dframe = dframe[~dframe["Metadata_control"].isna()]

    # Split data into controls (removing any remaining TC) and alleles
    is_control = dframe["Metadata_control"].astype("bool")
    df_exp = dframe[~is_control].reset_index(drop=True)
    df_control = dframe[
        is_control & (dframe["Metadata_node_type"] != "TC")
    ].reset_index(drop=True)

    writer = pq.ParquetWriter(preds_output_path, schema, compression="gzip")
    try:
        with open(logfile_path, "w") as log_file:
            log_file.write(f"===============================================================================================================================================================\n")
            # Filter out wells with fewer than the cell count threshold
            log_file.write("Dropping low cell count wells in control alleles:\n")
            print("Dropping low cell count wells in control alleles:\n")
            df_control = drop_low_cc_wells(df_control, cc_threshold, log_file)

            log_file.write("Dropping low cell count wells in ref. vs variant alleles:\n")
            print("Dropping low cell count wells in ref. vs variant alleles:")
            df_exp = drop_low_cc_wells(df_exp, cc_threshold, log_file)
            log_file.write(f"===============================================================================================================================================================\n\n")

            # Protein feature analysis
            df_feat_pro_con, df_result_pro_con = control_group_runner(
                df_control, pq_writer=writer, log_file=log_file, protein=True
            )
            df_feat_pro_exp, df_result_pro_exp = experimental_runner(
                df_exp, pq_writer=writer, log_file=log_file, protein=True
            )

            # Non-protein feature analysis
            df_feat_no_pro_con, df_result_no_pro_con = control_group_runner(
                df_control, pq_writer=writer, log_file=log_file, protein=False
            )
            df_feat_no_pro_exp, df_result_no_pro_exp = experimental_runner(
                df_exp, pq_writer=writer, log_file=log_file, protein=False
            )
    finally:
        # Always finalize the Parquet footer, even if a runner raises.
        writer.close()

    # Concatenate results for both protein and non-protein
    df_feat = pd.concat(
        [df_feat_pro_con, df_feat_no_pro_con, df_feat_pro_exp, df_feat_no_pro_exp],
        ignore_index=True,
    )
    df_result = pd.concat(
        [
            df_result_pro_con,
            df_result_no_pro_con,
            df_result_pro_exp,
            df_result_no_pro_exp,
        ],
        ignore_index=True,
    )
    df_result = df_result.drop_duplicates()

    # Write out feature importance and classifier info
    df_feat.to_csv(feat_output_path, index=False)
    df_result.to_csv(info_output_path, index=False)

In [20]:
input_path = "../outputs/batch_profiles/2025_01_27_Batch_13/profiles_tcdropped_filtered_var_mad_outlier_featselect.parquet"
# Load only the CCM2 cells and add a CellID column. The filter is applied to
# the lazy frame so Polars can push it down into the Parquet scan instead of
# materializing every gene before filtering.
dframe = (
    pl.scan_parquet(input_path)
    .filter(pl.col("Metadata_symbol") == "CCM2")
    .with_columns(
        pl.concat_str(
            [
                "Metadata_Plate",
                "Metadata_well_position",
                "Metadata_ImageNumber",
                "Metadata_ObjectNumber",
            ],
            separator="_",
        ).alias("Metadata_CellID")
    )
    .collect()
    .to_pandas()
)
feat_col = find_feat_cols(dframe)

In [21]:
# Preview the first rows of the filtered single-cell profiles.
display(dframe.head())

Unnamed: 0,Metadata_plate_map_name,Metadata_well_position,Metadata_symbol,Metadata_gene_allele,Metadata_imaging_well,Metadata_imaging_plate_R1,Metadata_imaging_plate_R2,Metadata_node_type,Metadata_orf_id_wt,Metadata_ccsb_mutation_id,...,Cells_RadialDistribution_FracAtD_mito_tubeness_7of20,Nuclei_ObjectSkeleton_TotalObjectSkeletonLength_mito_skel,Cytoplasm_RadialDistribution_RadialCV_DNA_4of10,Cells_RadialDistribution_RadialCV_GFP_3of10,Nuclei_RadialDistribution_FracAtD_DNA_6of10,Nuclei_Texture_Contrast_AGP_20_00_256,Nuclei_Texture_SumAverage_GFP_5_01_256,Cytoplasm_Texture_InfoMeas1_Mito_20_02_256,Cytoplasm_Granularity_1_AGP,Metadata_CellID
0,B13A7A8P2_R1,M10,CCM2,CCM2,M10,B13A7A8P2_R1,B14A7A8P2_R2,disease_wt,3928.0,,...,2.239665,0.799854,3.747274,-0.160807,-2.421128,0.825395,-0.262621,-2.584699,-0.302897,2025_01_27_B13A7A8P2_T1_M10_2676_6
1,B13A7A8P2_R1,M10,CCM2,CCM2,M10,B13A7A8P2_R1,B14A7A8P2_R2,disease_wt,3928.0,,...,1.469021,-0.796329,1.919917,-0.198846,0.001183,-0.49864,-1.013328,-0.398445,-2.208567,2025_01_27_B13A7A8P2_T1_M10_2677_6
2,B13A7A8P2_R1,M10,CCM2,CCM2,M10,B13A7A8P2_R1,B14A7A8P2_R2,disease_wt,3928.0,,...,-1.130079,0.175046,1.994589,-0.937285,0.343746,0.12478,-0.433273,-1.471504,0.066759,2025_01_27_B13A7A8P2_T1_M10_2678_6
3,B13A7A8P2_R1,M10,CCM2,CCM2,M10,B13A7A8P2_R1,B14A7A8P2_R2,disease_wt,3928.0,,...,1.094945,2.568591,-2.04254,0.553863,-0.399926,0.771891,0.125791,1.048846,10.353196,2025_01_27_B13A7A8P2_T1_M10_2679_6
4,B13A7A8P2_R1,M10,CCM2,CCM2,M10,B13A7A8P2_R1,B14A7A8P2_R2,disease_wt,3928.0,,...,-1.429298,0.72838,4.118041,-0.258514,1.336622,2.344744,9.8e-05,1.436391,1.827263,2025_01_27_B13A7A8P2_T1_M10_2680_6


In [22]:
# Initialize parquet for cell-level predictions
batch = "2025_01_27_Batch_13"
pipeline = "profiles_tcdropped_filtered_var_mad_outlier_featselect_filtcells"
results_dir = f"../outputs/results/{batch}/{pipeline}"
preds_output_path = f"{results_dir}/predictions.parquet"
cc_threshold = 20

# Create the output directory (no-op if it exists) and start from a fresh
# predictions file.
os.makedirs(results_dir, exist_ok=True)
if os.path.exists(preds_output_path):
    os.remove(preds_output_path)

logfile_path = os.path.join(results_dir, "classify.log")

schema = pa.schema([
    ("Classifier_ID", pa.string()),
    ("CellID", pa.string()),
    ("Label", pa.int64()),
    ("Prediction", pa.float32()),
    ("Metadata_Protein", pa.bool_()),
    ("Metadata_Control", pa.bool_()),
])

# Drop feature columns containing NaN/inf (np.isfinite is False for both).
# Explicit check instead of `assert`, which is skipped under `python -O`.
if not np.isfinite(dframe[feat_col]).all().all():
    dframe = remove_nan_infs_columns(dframe)

# Filter rows with NaN Metadata (copy so add_control_annot writes to a new frame)
dframe = dframe[~dframe["Metadata_well_position"].isna()].copy()
dframe = add_control_annot(dframe)
dframe = dframe[~dframe["Metadata_control"].isna()]

# Split data into controls (removing any remaining TC) and alleles
is_control = dframe["Metadata_control"].astype("bool")
df_exp = dframe[~is_control].reset_index(drop=True)
df_control = dframe[
    is_control & (dframe["Metadata_node_type"] != "TC")
].reset_index(drop=True)

writer = pq.ParquetWriter(preds_output_path, schema, compression="gzip")
try:
    with open(logfile_path, "w") as log_file:
        log_file.write(f"===============================================================================================================================================================\n")
        # Filter out wells with fewer than the cell count threshold
        log_file.write("Dropping low cell count wells in control alleles:\n")
        print("Dropping low cell count wells in control alleles:\n")
        df_control = drop_low_cc_wells(df_control, cc_threshold, log_file)

        log_file.write("Dropping low cell count wells in ref. vs variant alleles:\n")
        print("Dropping low cell count wells in ref. vs variant alleles:")
        df_exp = drop_low_cc_wells(df_exp, cc_threshold, log_file)
        log_file.write(f"===============================================================================================================================================================\n\n")

        # Protein feature analysis (ref. vs variant only, while debugging)
        df_feat_pro_exp, df_result_pro_exp = experimental_runner(
            df_exp, pq_writer=writer, log_file=log_file, protein=True
        )

        # Other analyses, disabled while debugging:
        # df_feat_pro_con, df_result_pro_con = control_group_runner(
        #     df_control, pq_writer=writer, log_file=log_file, protein=True
        # )
        # df_feat_no_pro_con, df_result_no_pro_con = control_group_runner(
        #     df_control, pq_writer=writer, log_file=log_file, protein=False
        # )
        # df_feat_no_pro_exp, df_result_no_pro_exp = experimental_runner(
        #     df_exp, pq_writer=writer, log_file=log_file, protein=False
        # )
finally:
    # Close the writer so predictions.parquet gets its footer and is readable.
    writer.close()

Dropping low cell count wells in control alleles:

Dropping low cell count wells in ref. vs variant alleles:


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/1 [00:07<?, ?it/s]


In [23]:
# Inspect the per-classifier summary table returned by experimental_runner.
df_result_pro_exp

Unnamed: 0,Classifier_ID,Plate,trainsize_0,testsize_0,well_0,allele_0,trainsize_1,testsize_1,well_1,allele_1,Metadata_Control
0,2025_01_27_B13A7A8P2_T1_A24_M10,2025_01_27_B13A7A8P2_T1,1153,392,A24,CCM2_Ala123Val,1261,448,M10,CCM2,False
1,2025_01_27_B13A7A8P2_T2_A24_M10,2025_01_27_B13A7A8P2_T2,1258,287,A24,CCM2_Ala123Val,1327,382,M10,CCM2,False
2,2025_01_27_B13A7A8P2_T3_A24_M10,2025_01_27_B13A7A8P2_T3,1022,523,A24,CCM2_Ala123Val,1249,460,M10,CCM2,False
3,2025_01_27_B13A7A8P2_T4_A24_M10,2025_01_27_B13A7A8P2_T4,1202,343,A24,CCM2_Ala123Val,1290,419,M10,CCM2,False
4,2025_01_27_B13A7A8P2_T1_A24_B01,2025_01_27_B13A7A8P2_T1,1153,392,A24,CCM2_Ala123Val,1933,539,B01,CCM2,False
5,2025_01_27_B13A7A8P2_T2_A24_B01,2025_01_27_B13A7A8P2_T2,1258,287,A24,CCM2_Ala123Val,1851,621,B01,CCM2,False
6,2025_01_27_B13A7A8P2_T3_A24_B01,2025_01_27_B13A7A8P2_T3,1022,523,A24,CCM2_Ala123Val,1870,602,B01,CCM2,False
7,2025_01_27_B13A7A8P2_T4_A24_B01,2025_01_27_B13A7A8P2_T4,1202,343,A24,CCM2_Ala123Val,1762,710,B01,CCM2,False
8,2025_01_27_B13A7A8P2_T1_A24_B02,2025_01_27_B13A7A8P2_T1,1153,392,A24,CCM2_Ala123Val,771,252,B02,CCM2,False
9,2025_01_27_B13A7A8P2_T2_A24_B02,2025_01_27_B13A7A8P2_T2,1258,287,A24,CCM2_Ala123Val,742,281,B02,CCM2,False
