This notebook takes the raw prediction outputs, performs mapping to metadata files, and produces summary metrics.

In [3]:
import pandas as pd
import os

In [4]:
# helper to load pred files
def load_predictions_df(backbone, training_strategy, seed, dataset, base_path = "/home/local/data/sophie/DGX_GridSearch/raw_predictions/"):
    """
    Helper to fetch predictions dataframe.

    Args: 
        backbone (string): Training Backbone (e.g. base, base65, grey, grey65, single, single65)
        training_strategy (string): Network training strategy (e.g. half, first, full)
        seed (int): Seed used for model training (e.g. 42, 43, 44)
        dataset (string): Name of eval dataset (e.g. cxr14, padchest, openi, jsrt)
        base_path (string, optional): Directory containing prediction file

    Returns:
        dataframe: predictions dataframe
    """
    return pd.read_csv(os.path.join(base_path, f"preds_{backbone}_{training_strategy}_seed{seed}_{dataset}.csv"))


In [5]:
meta_data = pd.read_csv("/home/local/data/sophie/processed_meta_full.csv")

In [18]:
meta_data['img_name'] = meta_data['mha'].apply(lambda x: x[:-4])
meta_data.head(1)

Unnamed: 0,original_image_name,orig_dataset,label,mha,Finding Labels,Patient ID,Patient Age,Patient Gender,original_image_width,original_image_height,...,shared_image_width,shared_image_height,Study Year,Manufacturer,Modality,subtlety,nodule_size,malignant,Partition,img_name
0,00010496_001,chestxray14,1,n0239.mha,Nodule,10496,44.0,M,2992.0,2991.0,...,1024,1024,,,,,,,Train,n0239


In [41]:
# helper to convert raw outputs to binary preds
def get_max_pred(x):
    return 1 if x['pred_0'] < x['pred_1'] else 0
def pred_based_on_pos_logit(x):
    return 1 if x['pred_1'] >= 0.5 else 0

In [43]:
test_df = load_predictions_df("base", "full", 42, "cxr14")
test_df['pred_int'] = test_df.apply(get_max_pred, axis=1)
test_df['pred1_thresh_int'] = test_df['pred_1'].apply(lambda x: 1 if x >= 0.5 else 0)
test_df['pred0_thresh_int'] = test_df['pred_0'].apply(lambda x: 0 if x >= 0.5 else 1)
test_df.head(5)

Unnamed: 0,img_name,true_label,pred_0,pred_1,pred_int,pred1_thresh_int,pred0_thresh_int
0,c0021,0,0.720273,-0.135957,0,0,0
1,c0029,0,0.162161,-0.043217,0,0,1
2,c0034,0,1.091611,-0.891011,0,0,0
3,c0035,0,0.063651,0.094966,1,0,1
4,c0056,0,1.18016,-0.596414,0,0,0


In [44]:
mapped_df = test_df.merge(meta_data, how="left", on="img_name")
mapped_df.head(1)

Unnamed: 0,img_name,true_label,pred_0,pred_1,pred_int,pred1_thresh_int,pred0_thresh_int,original_image_name,orig_dataset,label,...,orig_bbox_height,shared_image_width,shared_image_height,Study Year,Manufacturer,Modality,subtlety,nodule_size,malignant,Partition
0,c0021,0,0.720273,-0.135957,0,0,0,00000425_001,chestxray14,0,...,0,1024,1024,,,,,,,Test


In [45]:
tp = len(mapped_df.query("true_label == pred_int and true_label==1"))
tn = len(mapped_df.query("true_label == pred_int and true_label==0"))
fp = len(mapped_df.query("true_label != pred_int and true_label==0"))
fn = len(mapped_df.query("true_label != pred_int and true_label==1"))

In [46]:
def calc_prec(tp, fp):
    return tp/(tp+fp)
def calc_rec(tp, fn):
    return tp/(tp+fn)
def calc_spec(tn, fp):
    return tn/(tn+fp)

In [47]:
print(f"Accuracy: {round(len(mapped_df.query('true_label == pred_int'))/len(mapped_df),4)*100}%")
print(f"Precision: {round(calc_prec(tp, fp),4)}")
print(f"Recall/Sens: {round(calc_rec(tp, fn),4)}")
print(f"Spec: {round(calc_spec(tn, fp),4)}")

Accuracy: 84.72%
Precision: 0.8655
Recall/Sens: 0.8222
Spec: 0.8722


In [None]:
len(mapped_df.query("true_label != pred1_thresh_int and true_label==1"))

In [50]:
calc_prec(len(mapped_df.query("true_label == pred1_thresh_int and true_label==1")), len(mapped_df.query("true_label != pred1_thresh_int and true_label==0")))

0.9194630872483222

0.7611111111111111

In [55]:
print(f"Accuracy 1: {round(len(mapped_df.query('true_label == pred_int'))/len(mapped_df),4)*100}%")
print(f"Precision: {round(calc_prec(tp, fp),4)}")
print(f"Recall/Sens: {round(calc_rec(tp, fn),4)}")
print(f"Spec: {round(calc_spec(tn, fp),4)}")
print(f"Accuracy 2: {round(len(mapped_df.query('true_label == pred1_thresh_int'))/len(mapped_df),4)*100}%")
print(f'Precision: {round(calc_prec(len(mapped_df.query("true_label == pred1_thresh_int and true_label==1")), len(mapped_df.query("true_label != pred1_thresh_int and true_label==0"))),4)}')
print(f'Recall/Sens: {round(calc_rec(len(mapped_df.query("true_label == pred1_thresh_int and true_label==1")), len(mapped_df.query("true_label != pred1_thresh_int and true_label==1"))),4)}')
print(f'Spec: {round(calc_spec(len(mapped_df.query("true_label == pred1_thresh_int and true_label==0")), len(mapped_df.query("true_label != pred1_thresh_int and true_label==0"))),4)}')
print(f"Accuracy 3: {round(len(mapped_df.query('true_label == pred0_thresh_int'))/len(mapped_df),4)*100}%")


Accuracy 1: 84.72%
Precision: 0.8655
Recall/Sens: 0.8222
Spec: 0.8722
Accuracy 2: 84.72%
Precision: 0.9195
Recall/Sens: 0.7611
Spec: 0.9333
Accuracy 3: 83.89%


In [28]:
mapped_df.query("true_label != pred_int")

Unnamed: 0,img_name,true_label,pred_0,pred_1,pred_int,original_image_name,orig_dataset,label,mha,Finding Labels,...,orig_bbox_height,shared_image_width,shared_image_height,Study Year,Manufacturer,Modality,subtlety,nodule_size,malignant,Partition
3,c0035,0,0.063651,0.094966,1,00001645_000,chestxray14,0,c0035.mha,No Finding,...,0,1024,1024,,,,,,,Test
12,c0249,0,-0.897105,1.171756,1,00000701_000,chestxray14,0,c0249.mha,No Finding,...,0,1024,1024,,,,,,,Test
29,c0426,0,-1.05681,1.293607,1,00000502_000,chestxray14,0,c0426.mha,No Finding,...,0,1024,1024,,,,,,,Test
33,c0512,0,0.104742,0.138245,1,00001266_000,chestxray14,0,c0512.mha,No Finding,...,0,1024,1024,,,,,,,Test
49,c0926,0,-3.199633,3.537657,1,00001761_000,chestxray14,0,c0926.mha,No Finding,...,0,1024,1024,,,,,,,Test
50,c0931,0,-1.280314,1.771777,1,00001447_000,chestxray14,0,c0931.mha,No Finding,...,0,1024,1024,,,,,,,Test
70,c1424,0,0.066005,0.141601,1,00000417_002,chestxray14,0,c1424.mha,No Finding,...,0,1024,1024,,,,,,,Test
73,c1488,0,-0.025901,0.125242,1,00000184_000,chestxray14,0,c1488.mha,No Finding,...,0,1024,1024,,,,,,,Test
77,c1616,0,-0.90013,1.317359,1,00000083_000,chestxray14,0,c1616.mha,No Finding,...,0,1024,1024,,,,,,,Test
79,c1658,0,-2.188762,2.542366,1,00000357_000,chestxray14,0,c1658.mha,No Finding,...,0,1024,1024,,,,,,,Test
