In [None]:
import pandas as pd
import numpy as np 

In [2]:
timestamp = '04-22_02-08-47'
part = '7'
ws = f'/data/geraugi/plural/dataset_files/xplainer_inference_2025-{timestamp}'
confusion_path = f'{ws}/confusion_matrix_part{part}.csv'
descriptor_probs_path = f'{ws}/descriptor_probs_long_part{part}.csv'
probs_path = f'{ws}/xp_inf_probabilities_{part}.csv'
neg_probs_path = f'{ws}/xp_inf_neg_probabilities_{part}.csv'
gt_path = f'{ws}/xp_inf_labels_{part}.csv'
preds_path = f'{ws}/xp_inf_predictions_{part}.csv'

confusion_matrix_df = pd.read_csv(confusion_path)
desc_probs_df = pd.read_csv(descriptor_probs_path)
dis_probs_df = pd.read_csv(probs_path)
dis_neg_probs_df = pd.read_csv(neg_probs_path)
gt_df = pd.read_csv(gt_path)
preds_df = pd.read_csv(preds_path)


In [3]:
desc_probs_df.head(5)

Unnamed: 0,dicom_id,descriptor,prob,neg_prob
0,424ec6e6-d4f6483d-72a716a2-3f9e6778-b9b5c5f5,Absence of lung markings indicating Pneumothorax,0.541488,0.458512
1,424ec6e6-d4f6483d-72a716a2-3f9e6778-b9b5c5f5,Air bronchograms indicating Pneumonia,0.773128,0.226872
2,424ec6e6-d4f6483d-72a716a2-3f9e6778-b9b5c5f5,Air bronchograms within the opacity indicating...,0.492451,0.507548
3,424ec6e6-d4f6483d-72a716a2-3f9e6778-b9b5c5f5,Barrel Chest indicating Emphysema,0.586808,0.413192
4,424ec6e6-d4f6483d-72a716a2-3f9e6778-b9b5c5f5,Blunting of costophrenic angles indicating Ple...,0.541042,0.458958


In [5]:
confusion_matrix_df = confusion_matrix_df.rename(columns={'Unnamed: 0': 'dicom_id'}) \
                                         .set_index('dicom_id')
dis_probs_df        = dis_probs_df       .rename(columns={'Unnamed: 0': 'dicom_id'}) \
                                         .set_index('dicom_id')
dis_neg_probs_df    = dis_neg_probs_df   .rename(columns={'Unnamed: 0': 'dicom_id'}) \
                                         .set_index('dicom_id')

In [14]:
from IPython.display import display
conf_long = (
    confusion_matrix_df
      .reset_index()                       # bring dicom_id back as a column if it was the index
      .melt(
         id_vars='dicom_id',              # keep dicom_id
         var_name='disease',              # new column for what used to be each disease column name
         value_name='res'                 # new column for TP/TN/FP/FN values
      )
)

# — 2) Parse disease out of each descriptor
# desc_probs_df has index=dicom_id, columns: descriptor, prob, neg_prob
# and descriptor strings end in "indicating <disease>"
desc = desc_probs_df.copy()
desc['disease'] = desc['descriptor'].str.extract(r'indicating\s+(.+)$')[0].str.strip()


# — 3) Merge and aggregate
merged = (
    desc
      .reset_index()        # make dicom_id a column
      .merge(
         conf_long,
         on=['dicom_id', 'disease'],
         how='inner'
      )
)

merged['dif_prob'] = merged['prob'] - merged['neg_prob']
# Now each row has: dicom_id, descriptor, prob, neg_prob, disease, res

# Group by disease / descriptor / res and compute:
stats = (
    merged
      .groupby(['disease','descriptor','res'])
      .agg(
         #n               = ('res',      'count'),
         mean_prob       = ('prob',     'mean'),
         std_prob        = ('prob',     'std'),
         mean_neg_prob   = ('neg_prob', 'mean'),
         std_neg_prob    = ('neg_prob', 'std'),
         mean_dif_prob   = ('dif_prob', 'mean'),
         std_dif_prob    = ('dif_prob', 'std'),
      )
      .reset_index()
)
#stats_fp = stats.query("res == 'FP'")
merged_fp = merged.query("res == 'FP'")
#display(stats_fp)
display(merged_fp)

Unnamed: 0,index,dicom_id,descriptor,prob,neg_prob,disease,res,dif_prob
0,0,424ec6e6-d4f6483d-72a716a2-3f9e6778-b9b5c5f5,Absence of lung markings indicating Pneumothorax,0.541488,0.458512,Pneumothorax,FP,0.082975
1,11,424ec6e6-d4f6483d-72a716a2-3f9e6778-b9b5c5f5,Deep sulcus sign indicating Pneumothorax,0.533381,0.466619,Pneumothorax,FP,0.066762
2,17,424ec6e6-d4f6483d-72a716a2-3f9e6778-b9b5c5f5,Flattening of the hemidiaphragm indicating Pne...,0.624963,0.375037,Pneumothorax,FP,0.249927
3,28,424ec6e6-d4f6483d-72a716a2-3f9e6778-b9b5c5f5,Increased radiolucency indicating Pneumothorax,0.497868,0.502132,Pneumothorax,FP,-0.004264
4,51,424ec6e6-d4f6483d-72a716a2-3f9e6778-b9b5c5f5,Shifting of the mediastinum indicating Pneumot...,0.462463,0.537537,Pneumothorax,FP,-0.075075
...,...,...,...,...,...,...,...,...
1252927,1273437,fa62fc78-9b66c0fd-aa7ee648-8b82e0fc-b0e5c0d4,Fracture lines that are jagged or irregular in...,0.545534,0.454466,Fracture,FP,0.091067
1252928,1273457,fa62fc78-9b66c0fd-aa7ee648-8b82e0fc-b0e5c0d4,Misalignments of bone fragments indicating Fra...,0.629041,0.370959,Fracture,FP,0.258082
1252929,1273458,fa62fc78-9b66c0fd-aa7ee648-8b82e0fc-b0e5c0d4,Multiple fracture lines that intersect at diff...,0.502643,0.497357,Fracture,FP,0.005285
1252930,1273475,fa62fc78-9b66c0fd-aa7ee648-8b82e0fc-b0e5c0d4,Visible breaks in the continuity of the bone i...,0.570649,0.429351,Fracture,FP,0.141297


In [15]:
# 1) filter FP & dif_prob < 0
fp_neg = merged[(merged['res'] == 'FP') & (merged['dif_prob'] < 0)]

# 2) count how many times each descriptor appears
desc_counts = (
    fp_neg
      .groupby('descriptor')
      .size()
      .reset_index(name='count')
      .sort_values('count', ascending=False)
)

print(desc_counts)

                                           descriptor  count
57  Visible callus or healing tissue indicating Fr...  10796
28     Increased radiolucency indicating Pneumothorax   3542
50  Shifting of the mediastinum indicating Pneumot...   3532
14  Disruptions of the cortex or outer layer of th...   2441
24  Increased density of lung tissue indicating Co...   1889
..                                                ...    ...
35  Loss of normal lung markings within the opacit...    158
21             Hyperlucent lungs indicating Emphysema    155
51  Silhouette sign loss with adjacent structures ...    144
12  Displacement of the diaphragm indicating Atele...    143
26  Increased interstitial markings in the lungs i...     70

[61 rows x 2 columns]


In [16]:
# 1) filter FP & dif_prob < 0
tp_neg = merged[(merged['res'] == 'TP') & (merged['dif_prob'] < 0)]

# 2) count how many times each descriptor appears
desc_counts = (
    tp_neg
      .groupby('descriptor')
      .size()
      .reset_index(name='count')
      .sort_values('count', ascending=False)
)

print(desc_counts)

                                           descriptor  count
37      Mediastinal shift indicating Pleural Effusion   1379
49  Shifting of the mediastinum indicating Atelect...   1310
47    Reduced lung volume indicating Pleural Effusion    882
18  Fluid levels within the opacity indicating Lun...    706
23  Increased density in the lung field indicating...    525
..                                                ...    ...
32  Lobulated peripheral shadowing indicating Pleu...      3
46              Pulmonary bullae indicating Emphysema      3
8   Calcifications on the pleura indicating Pleura...      2
54   Thickening of intestinal folds indicating Hernia      2
21             Hyperlucent lungs indicating Emphysema      1

[61 rows x 2 columns]


In [4]:
# 1. Suffix the columns
res_df      = confusion_matrix_df.add_suffix('_res')
prob_df     = dis_probs_df.add_suffix('_prob')
neg_prob_df = dis_neg_probs_df.add_suffix('_neg_prob')

# 2. Concatenate on the index
merged_df = pd.concat([res_df, prob_df, neg_prob_df], axis=1, join='inner')

merged_df.head(5)

Unnamed: 0_level_0,Atelectasis_res,Cardiomegaly_res,Consolidation_res,Edema_res,Emphysema_res,Fracture_res,Hernia_res,Lung Opacity_res,Pleural Effusion_res,Pleural Thickening_res,...,Edema_neg_prob,Emphysema_neg_prob,Fracture_neg_prob,Hernia_neg_prob,Lung Opacity_neg_prob,Pleural Effusion_neg_prob,Pleural Thickening_neg_prob,Pneumonia_neg_prob,Pneumothorax_neg_prob,No Findings_neg_prob
dicom_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
424ec6e6-d4f6483d-72a716a2-3f9e6778-b9b5c5f5,FP,FP,FN,FP,FP,FP,FP,TP,TN,FP,...,0.441639,0.393692,0.307714,0.334435,0.452264,0.607101,0.312133,0.291387,0.474995,0.700465
db9df957-41ff2879-0874458d-cc8460c2-f0925266,FP,TP,FP,TP,TN,TN,TN,TP,TP,TN,...,0.280656,0.576938,0.562461,0.535271,0.275968,0.168234,0.528432,0.365828,0.448424,0.840053
3bf44e0c-e62502e1-f189ad54-9ca3bb7f-a17ae8f0,FP,FP,FP,TP,TN,FP,FP,TP,TP,FP,...,0.42711,0.556135,0.491582,0.461899,0.325164,0.18844,0.473784,0.427636,0.440037,0.793909
bdcdc2e7-65506aad-b3a1a137-c65690b8-40b0abaa,FP,FP,FP,FP,TN,FP,FP,TP,TN,FP,...,0.476697,0.505984,0.361549,0.45447,0.425833,0.573824,0.435382,0.374213,0.477932,0.637559
aeb30959-32075af6-296e6124-fca3e68c-ebbd23d7,FP,TN,TN,TN,FP,TN,FP,TN,FP,TN,...,0.571906,0.481342,0.542168,0.434497,0.561001,0.464954,0.540161,0.594268,0.506347,0.564732


In [12]:
diseases = [
            "Atelectasis", "Cardiomegaly", "Consolidation", "Edema", "Emphysema",
            "Fracture", "Hernia", "Lung Opacity", "Pleural Effusion", "Pleural Thickening",
            "Pneumonia", "Pneumothorax", "No Findings"
        ]
for disease in diseases:
    prob_col = f"{disease}_prob"
    neg_col  = f"{disease}_neg_prob"
    diff_col = f"{disease}_dif_prob"
    
    # only if both source columns exist
    if prob_col in merged_df.columns and neg_col in merged_df.columns:
        merged_df[diff_col] = merged_df[prob_col] - merged_df[neg_col]

In [None]:

for disease in diseases:
    res_col      = f"{disease}_res"
    prob_col     = f"{disease}_prob"
    neg_prob_col = f"{disease}_neg_prob"
    dif_prob_col = f"{disease}_dif_prob"
    
    # group by TP/TN/FP/FN
    grp = merged_df.groupby(res_col)
    grp.head()
    # aggregate counts + stats for both prob and neg_prob
    agg = grp[res_col, prob_col, neg_prob_col, dif_prob_col].agg(
        count       =  (res_col, 'count'),      # count per group
        mean_prob   =  (prob_col, 'mean'),
        std_prob    =  (prob_col, 'std'),
        mean_neg    =  (neg_prob_col, 'mean'),
        std_neg     =  (neg_prob_col, 'std'),
        mean_dif    =  (dif_prob_col, 'mean'),
        std_dif     =  (dif_prob_col, 'std')
    )
    
    # flatten MultiIndex columns
    agg.columns = ['n', 'mean_prob', 'std_prob', 'mean_neg_prob', 'std_neg_prob','mean_dif_prob', 'std_dif_prob'] 
    
    agg = agg.reset_index()  # bring res (TP/TN/…) back as a column
    print(agg)


  Atelectasis_res     n  mean_prob  std_prob  mean_neg_prob  std_neg_prob  \
0              FN  1549   0.405920  0.077168       0.588352      0.077345   
1              FP  6340   0.604492  0.067911       0.389262      0.068177   
2              TN  6316   0.377763  0.086164       0.616627      0.085401   
3              TP  6335   0.630796  0.066948       0.362254      0.067308   

   mean_dif_prob  std_dif_prob  
0      -0.182432      0.154467  
1       0.215229      0.135995  
2      -0.238864      0.171528  
3       0.268542      0.134132  
  Cardiomegaly_res     n  mean_prob  std_prob  mean_neg_prob  std_neg_prob  \
0               FN   604   0.429092  0.072148       0.569921      0.071754   
1               FP  7753   0.609795  0.090864       0.389373      0.090956   
2               TN  9226   0.373553  0.098021       0.625300      0.097503   
3               TP  2957   0.664401  0.106896       0.334581      0.106925   

   mean_dif_prob  std_dif_prob  
0      -0.140829      0.1

  # This is added back by InteractiveShellApp.init_path()
  # This is added back by InteractiveShellApp.init_path()
  # This is added back by InteractiveShellApp.init_path()
  # This is added back by InteractiveShellApp.init_path()
  # This is added back by InteractiveShellApp.init_path()
  # This is added back by InteractiveShellApp.init_path()
  # This is added back by InteractiveShellApp.init_path()
  # This is added back by InteractiveShellApp.init_path()
  # This is added back by InteractiveShellApp.init_path()
  # This is added back by InteractiveShellApp.init_path()
  # This is added back by InteractiveShellApp.init_path()
  # This is added back by InteractiveShellApp.init_path()
  # This is added back by InteractiveShellApp.init_path()
