In [1]:
# ============================================================
# ResFinder Evaluation for Salmonella enterica and Staphylococcus aureus
# ============================================================

from google.colab import drive
import shutil, os, glob
import pandas as pd
import numpy as np
from sklearn.metrics import (
    balanced_accuracy_score, precision_score, recall_score,
    f1_score, fbeta_score, roc_auc_score, jaccard_score,
    matthews_corrcoef, confusion_matrix
)

# ============================================================
# STEP 0: Mount Google Drive and copy files to /content/
# ============================================================

# Attach Google Drive so the project folder is reachable from this runtime.
drive.mount('/content/drive')

# Each entry is (source path on Drive, file name to create under /content).
FILES_TO_COPY = [
    # ABRicate raw predictions (same file for both species)
    ('/content/drive/MyDrive/MRSA datasets/Metadata Tables/Predicted Labels/ABRicate Labels/Master_Table_with_abricate_labels_resfinder.csv',
     'Master_Table_with_abricate_labels_resfinder.csv'),
    # Master Table
    ('/content/drive/MyDrive/MRSA datasets/Master_Table.csv',
     'Master_Table.csv'),
    # Species-specific model tables
    ('/content/drive/MyDrive/MRSA datasets/Tables/Salmonella_Model_Table_clusters.csv',
     'Salmonella_Model_Table_clusters.csv'),
    ('/content/drive/MyDrive/MRSA datasets/Tables/Staphylococcus_Model_Table_clusters.csv',
     'Staphylococcus_Model_Table_clusters.csv'),
]

for src_path, dst_name in FILES_TO_COPY:
    dst = os.path.join('/content', dst_name)

    if not os.path.exists(src_path):
        # Report the gap and list ABRicate/ResFinder-looking files found
        # anywhere under the project folder as a hint; nothing is copied.
        print(f"MISSING: {src_path}")
        search = glob.glob('/content/drive/MyDrive/MRSA datasets/**/*abricate*resfinder*', recursive=True)
        if search:
            print("  FOUND similar files:")
            print("\n".join(f"    {candidate}" for candidate in search))
        continue

    shutil.copy(src_path, dst)
    print(f"Copied: {dst_name}")
print()

# ============================================================
# STEP 1: Configuration
# ============================================================

# Optional restriction of the ground truth: when FILTER_TEST_SET is True,
# only rows whose SPLIT_COLUMN equals TEST_SPLIT_VALUE are evaluated.
FILTER_TEST_SET = False
SPLIT_COLUMN = 'split'
TEST_SPLIT_VALUE = 'test'

# One entry per species:
#   species_filter  - substring matched (case-insensitively) against 'Species'
#   drugs           - drug columns to score
#   model_table     - CSV holding the ground-truth labels
#   output_metrics  - destination of the per-drug metric table
#   output_samples  - destination of the per-sample label/prediction table
SPECIES_CONFIG = {
    'Salmonella enterica': dict(
        species_filter='Salmonella',
        drugs=[
            'Tetracycline',
            'Ampicillin',
            'Amoxicillin-Clavulanic acid',
            'Cefoxitin',
            'Ceftiofur',
            'Gentamicin',
            'Ceftriaxone',
        ],
        model_table='/content/Salmonella_Model_Table_clusters.csv',
        output_metrics='/content/salmonella_resfinder_results.csv',
        output_samples='/content/salmonella_resfinder_sample_predictions.csv',
    ),
    'Staphylococcus aureus': dict(
        species_filter='Staphylococcus',
        drugs=[
            'Erythromycin',
            'Ciprofloxacin',
            'Clindamycin',
            'Penicillin',
        ],
        model_table='/content/Staphylococcus_Model_Table_clusters.csv',
        output_metrics='/content/staph_resfinder_results.csv',
        output_samples='/content/staph_resfinder_sample_predictions.csv',
    ),
}

# Inputs shared by both species: the ABRicate/ResFinder hit table and the
# Master Table whose columns define the set of drug names.
ABRICATE_RAW_PATH = '/content/Master_Table_with_abricate_labels_resfinder.csv'
MASTER_TABLE_PATH = '/content/Master_Table.csv'

df_resfinder = pd.read_csv(ABRICATE_RAW_PATH, low_memory=False)
df_master = pd.read_csv(MASTER_TABLE_PATH, low_memory=False)

# Drug name mapping
# Target drug label -> the lower-case antibiotic names that are reported
# under that label. Grouping by label makes the many-to-one merges explicit
# (e.g. doxycycline is scored as Tetracycline, ertapenem as Meropenem).
_SOURCES_BY_LABEL = {
    "Minocycline": ("minocycline",),
    "Tetracycline": ("tetracycline", "doxycycline"),
    "Meropenem": ("meropenem", "ertapenem"),
    "Trimethoprim-Sulfamethoxazole": ("sulfamethoxazole",),
    "Amikacin": ("amikacin",),
    "Azithromycin": ("azithromycin",),
    "Chloramphenicol": ("florfenicol", "chloramphenicol"),
    "Colistin": ("colistin",),
    "Cefoxitin": ("cefoxitin",),
    "Piperacillin-Tazobactam": ("piperacillin+tazobactam",),
    "Amoxicillin-Clavulanic acid": (
        "amoxicillin+clavulanic_acid",
        "ampicillin+clavulanic_acid",
    ),
    "Beta-lactam": (
        "ticarcillin+clavulanic_acid",
        "piperacillin",
        "cephalothin",
        "cefixime",
        "ticarcillin",
    ),
    "Imipenem": ("imipenem",),
    "Tobramycin": ("tobramycin",),
    "Aztreonam": ("aztreonam",),
    "Ceftazidime": ("ceftazidime",),
    "Fosfomycin": ("fosfomycin",),
    "Linezolid": ("linezolid",),
    "Lincosamide": ("lincomycin",),
    "Trimethoprim": ("trimethoprim",),
    "Streptomycin": ("streptomycin",),
    "Cefepime": ("cefepime",),
    "Ampicillin": ("ampicillin",),
    "Erythromycin": ("erythromycin",),
    "Macrolides": ("telithromycin", "spiramycin"),
    "Amoxicillin": ("amoxicillin",),
    "Cefotaxime": ("cefotaxime",),
    "Gentamicin": ("gentamicin",),
    "Penicillin": ("penicillin",),
    "Ciprofloxacin": ("ciprofloxacin", "quinolone", "fluoroquinolone"),
    "Ceftriaxone": ("ceftriaxone",),
    "Tigecycline": ("tigecycline",),
    "Fusidic acid": ("fusidic_acid",),
    "Nalidixic acid": ("nalidixic_acid",),
    "Clindamycin": ("clindamycin",),
    "Ceftiofur": ("ceftiofur",),
}

# Flat lookup: antibiotic name -> target drug label.
drug_mapping = {
    source: label
    for label, sources in _SOURCES_BY_LABEL.items()
    for source in sources
}

# ============================================================
# STEP 2: Generate R/S predictions for both species
# ============================================================

# Drug columns are taken from the Master Table: everything after its first
# three columns, minus the identifier columns should they appear further right.
drug_cols = [col for col in df_master.columns[3:] if col not in ['File Name', 'Species', 'Dataset']]

# Only mapping entries whose target label is an actual drug column can ever
# produce a prediction, so the rest are dropped up front.
drug_cols_set = set(drug_cols)
filtered_drug_mapping = {k: v for k, v in drug_mapping.items() if v is not None and v in drug_cols_set}

# Every sample/drug cell starts as susceptible ('S'); a cell flips to 'R' when
# one of the sample's ';'-separated `raw_classes` tokens maps to that drug.
# The labels are filled in a plain object array and wrapped once at the end,
# which avoids one DataFrame write per row.
column_position = {drug: pos for pos, drug in enumerate(drug_cols)}
label_grid = np.full((len(df_resfinder), len(drug_cols)), 'S', dtype=object)

# Row index label -> drug columns predicted resistant for that row.
resistance_dict = {}
for row_pos, (idx, raw_classes) in enumerate(df_resfinder['raw_classes'].items()):
    # Missing hits are read back as NaN (a float), so a str check is enough.
    if not isinstance(raw_classes, str):
        continue
    tokens = (token.strip() for token in raw_classes.split(';'))
    hits = {filtered_drug_mapping[token] for token in tokens if token in filtered_drug_mapping}
    if not hits:
        continue
    resistance_dict[idx] = sorted(hits)
    label_grid[row_pos, [column_position[drug] for drug in hits]] = 'R'

# Identifier columns first, then one label column per drug. Identifiers are
# copied verbatim: a missing 'File Name'/'Species'/'Dataset' stays missing
# instead of being overwritten with the susceptibility label 'S'.
df_predictions = pd.concat(
    [
        df_resfinder[['File Name', 'Species', 'Dataset']],
        pd.DataFrame(label_grid, index=df_resfinder.index, columns=drug_cols),
    ],
    axis=1,
)

# ============================================================
# STEP 3: Loop over species and compute metrics
# ============================================================

# Binary encoding used for scoring: resistant is 0, susceptible is 1, and
# intermediate ('I') is pooled with susceptible. The class-specific metrics
# below pass pos_label=0 so that "positive" means resistant.
label_map = {'S': 1, 'R': 0, 'I': 1}

for species_name, config in SPECIES_CONFIG.items():
    print("\n" + "="*60)
    print(f"Processing {species_name}")
    print("="*60)

    species_filter = config['species_filter']
    drugs_to_evaluate = config['drugs']
    model_table_path = config['model_table']
    output_metrics = config['output_metrics']
    output_samples = config['output_samples']

    # Filter predictions to species
    df_species_preds = df_predictions[df_predictions['Species'].str.contains(species_filter, case=False, na=False)].copy()

    # Load ground truth
    df_ground_truth = pd.read_csv(model_table_path, low_memory=False)
    if FILTER_TEST_SET:
        df_ground_truth = df_ground_truth[df_ground_truth[SPLIT_COLUMN]==TEST_SPLIT_VALUE].copy()

    # Merge ground truth and predictions per sample. Only drugs present in
    # both tables are kept; their columns come out as '<drug>_gt' / '<drug>_pred'.
    merge_keys = ['File Name']
    valid_drugs = [d for d in drugs_to_evaluate if d in df_species_preds.columns and d in df_ground_truth.columns]
    df_merged = pd.merge(df_ground_truth[merge_keys + valid_drugs],
                         df_species_preds[merge_keys + valid_drugs],
                         on=merge_keys,
                         how='inner',
                         suffixes=('_gt', '_pred'))

    # Compute metrics
    results = []
    for drug in valid_drugs:
        gt_col = f'{drug}_gt'
        pred_col = f'{drug}_pred'
        y_true = df_merged[gt_col].copy()
        y_pred = df_merged[pred_col].copy()
        # Score only samples with a usable ground-truth label.
        valid_mask = y_true.isin(['S','R','I'])
        y_true = y_true[valid_mask].map(label_map).astype(int)
        # Any prediction outside the label map is treated as susceptible.
        y_pred = y_pred[valid_mask].map(label_map).fillna(1).astype(int)
        if len(y_true) < 2:
            continue
        ba = balanced_accuracy_score(y_true, y_pred)
        prec = precision_score(y_true, y_pred, pos_label=0, zero_division=0)
        rec = recall_score(y_true, y_pred, pos_label=0, zero_division=0)
        f1_val = f1_score(y_true, y_pred, pos_label=0, zero_division=0)
        fb = fbeta_score(y_true, y_pred, beta=0.5, pos_label=0, zero_division=0)
        # ROC AUC is undefined when the ground truth holds a single class and
        # roc_auc_score raises in that case; record NaN so one such drug does
        # not abort the whole evaluation before anything is saved.
        if y_true.nunique() < 2:
            roc = np.nan
            print(f"  NOTE: ROC AUC undefined for {drug} (only one class in ground truth)")
        else:
            roc = roc_auc_score(y_true, y_pred)
        jac = jaccard_score(y_true, y_pred, pos_label=0, zero_division=0)
        mcc = matthews_corrcoef(y_true, y_pred)
        results.append({'Drug': drug, 'Bal. Acc': round(ba,2), 'Precision': round(prec,2),
                        'Recall': round(rec,2), 'F1': round(f1_val,2), 'F_beta': round(fb,2),
                        'ROC AUC': round(roc,2), 'Jaccard': round(jac,2), 'MCC': round(mcc,2)})

    # Save per-drug metrics
    df_results = pd.DataFrame(results)
    df_results.to_csv(output_metrics, index=False)
    print(f"Metrics saved to: {output_metrics}")

    # Save sample-level predictions. Ground truth is normalised the same way
    # as for scoring: 'I' counts as 'S', anything else that is not 'S'/'R'
    # is blanked out.
    cols_to_save = ['File Name'] + [f'{d}_gt' for d in valid_drugs] + [f'{d}_pred' for d in valid_drugs]
    df_sample_preds = df_merged[cols_to_save].copy()
    for drug in valid_drugs:
        gt_col = f'{drug}_gt'
        df_sample_preds[gt_col] = df_sample_preds[gt_col].replace('I','S')
        valid_mask = df_sample_preds[gt_col].isin(['S','R'])
        df_sample_preds.loc[~valid_mask, gt_col] = np.nan
    df_sample_preds.to_csv(output_samples, index=False)
    print(f"Sample predictions saved to: {output_samples}")

Mounted at /content/drive
Copied: Master_Table_with_abricate_labels_resfinder.csv
Copied: Master_Table.csv
Copied: Salmonella_Model_Table_clusters.csv
Copied: Staphylococcus_Model_Table_clusters.csv


Processing Salmonella enterica
Metrics saved to: /content/salmonella_resfinder_results.csv
Sample predictions saved to: /content/salmonella_resfinder_sample_predictions.csv

Processing Staphylococcus aureus
Metrics saved to: /content/staph_resfinder_results.csv
Sample predictions saved to: /content/staph_resfinder_sample_predictions.csv


In [None]:
import pandas as pd
from statsmodels.stats.contingency_tables import mcnemar

# ====================================================================
# CONFIGURATION - FILE PATHS
# ====================================================================
# Salmonella: ResFinder and CNN per-sample prediction tables
SALMONELLA_RES_PATH = '/content/salmonella_resfinder_sample_predictions.csv'
SALMONELLA_CNN_PATH = '/content/cnn_salmonella_sample_predictions.csv'

# Staphylococcus: ResFinder and CNN per-sample prediction tables
STAPH_RES_PATH = '/content/staph_resfinder_sample_predictions.csv'
STAPH_CNN_PATH = '/content/cnn_staph_sample_predictions.csv'

# Drugs compared for each species
SALMONELLA_DRUGS = [
    'Tetracycline',
    'Ampicillin',
    'Amoxicillin-Clavulanic acid',
    'Cefoxitin',
    'Ceftiofur',
    'Gentamicin',
    'Ceftriaxone',
]
STAPH_DRUGS = [
    'Erythromycin',
    'Ciprofloxacin',
    'Clindamycin',
    'Penicillin',
]


# ====================================================================
# MCNEMAR'S TEST FUNCTION
# ====================================================================
def run_mcnemars_test(res_path, cnn_path, drugs, species_name):
    """Print a per-drug McNemar comparison of CNN and ResFinder predictions.

    Both CSVs are joined on 'File Name'. For each drug, every sample with a
    ground-truth label and both predictions is classified by which of the two
    predictors matched the ground truth, and McNemar's test is run on the
    resulting 2x2 table.

    Args:
        res_path: CSV with the ResFinder sample predictions. The columns
            '<drug>_gt' and '<drug>_pred' are looked up in the joined table.
        cnn_path: CSV with the CNN sample predictions; a column named exactly
            like the drug is used as the CNN prediction for that drug.
        drugs: Drug names to compare.
        species_name: Label used in the printed header and error message.

    Returns:
        None. Results are printed; a missing input file is reported and the
        function returns early.
    """
    print("=" * 60)
    print(f"MCNEMAR'S TEST: {species_name.upper()} (CNN vs ResFinder)")
    print("=" * 60)

    # Only the file reads can raise FileNotFoundError, so only they are guarded.
    try:
        df_res = pd.read_csv(res_path)
        df_cnn = pd.read_csv(cnn_path)
    except FileNotFoundError as e:
        print(f"ERROR: Could not find files for {species_name}. Make sure you generated them!")
        print(e, "\n")
        return

    # Give the CNN prediction columns a distinct name before joining.
    df_cnn = df_cnn.rename(columns={drug: f"{drug}_pred_cnn" for drug in drugs})

    # Merge on File Name; any other column present in both files keeps the
    # ResFinder name and the CNN copy gets the '_cnn_df' suffix.
    df_merged = pd.merge(df_res, df_cnn, on='File Name', suffixes=('', '_cnn_df'))

    print(f"{'Drug Name'.ljust(30)} | {'p-value'.ljust(10)} | {'Significant?'}")
    print("-" * 60)

    for drug in drugs:
        gt_col = f"{drug}_gt"
        res_col = f"{drug}_pred"
        cnn_col = f"{drug}_pred_cnn"

        # Skip if columns are missing
        if not all(c in df_merged.columns for c in [gt_col, res_col, cnn_col]):
            print(f"{drug.ljust(30)} | {'MISSING DATA'.ljust(10)} | -")
            continue

        # Drop rows where ground truth or either prediction is missing/NaN
        df_clean = df_merged.dropna(subset=[gt_col, res_col, cnn_col])

        if df_clean.empty:
            print(f"{drug.ljust(30)} | {'NO DATA'.ljust(10)} | -")
            continue

        # Per-sample correctness of each predictor, then the four joint counts.
        cnn_correct = df_clean[cnn_col] == df_clean[gt_col]
        res_correct = df_clean[res_col] == df_clean[gt_col]

        both_correct = int((cnn_correct & res_correct).sum())
        cnn_right_res_wrong = int((cnn_correct & ~res_correct).sum())
        res_right_cnn_wrong = int((~cnn_correct & res_correct).sum())
        both_wrong = int((~cnn_correct & ~res_correct).sum())

        # Build 2x2 table for McNemar
        # [[Both correct,             ResFinder right / CNN wrong],
        #  [CNN right / Res wrong,    Both wrong]]
        table = [[both_correct, res_right_cnn_wrong],
                 [cnn_right_res_wrong, both_wrong]]

        # The chi-squared approximation is unreliable with few discordant
        # pairs and degenerates to 0/0 (NaN p-value) when there are none, so
        # the exact binomial test is used below the customary threshold of 25.
        # `correction` only affects the chi-squared variant.
        discordant = cnn_right_res_wrong + res_right_cnn_wrong
        result = mcnemar(table, exact=discordant < 25, correction=True)

        # Significance threshold
        is_sig = "YES (*)" if result.pvalue < 0.05 else "NO"

        print(f"{drug.ljust(30)} | {result.pvalue:.4f}     | {is_sig}")

    print("\n* YES indicates a statistically significant difference in performance (p < 0.05).\n")


# ====================================================================
# EXECUTE TESTS
# ====================================================================
# One comparison per species, Salmonella first and Staphylococcus second.
for _comparison in (
    (SALMONELLA_RES_PATH, SALMONELLA_CNN_PATH, SALMONELLA_DRUGS, "Salmonella enterica"),
    (STAPH_RES_PATH, STAPH_CNN_PATH, STAPH_DRUGS, "Staphylococcus aureus"),
):
    run_mcnemars_test(*_comparison)

MCNEMAR'S TEST: SALMONELLA ENTERICA (CNN vs ResFinder)
Drug Name                      | p-value    | Significant?
------------------------------------------------------------
Tetracycline                   | 0.0000     | YES (*)
Ampicillin                     | 0.0000     | YES (*)
Amoxicillin-Clavulanic acid    | 0.0000     | YES (*)
Cefoxitin                      | 0.0000     | YES (*)
Ceftiofur                      | 0.0001     | YES (*)
Gentamicin                     | 0.0098     | YES (*)
Ceftriaxone                    | 0.0184     | YES (*)

* YES indicates a statistically significant difference in performance (p < 0.05).

MCNEMAR'S TEST: STAPHYLOCOCCUS AUREUS (CNN vs ResFinder)
Drug Name                      | p-value    | Significant?
------------------------------------------------------------
Erythromycin                   | 0.0000     | YES (*)
Ciprofloxacin                  | 0.0004     | YES (*)
Clindamycin                    | 0.0000     | YES (*)
Penicillin              