In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import sys

from sklearn.metrics import roc_curve
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import auc
from sklearn.metrics import classification_report
from sklearn.metrics import accuracy_score
from sklearn.metrics import balanced_accuracy_score
from sklearn.metrics import recall_score

: 

# Prediction performance evaluation

In [None]:
# Held-out CheXpert test split; provides the race / sex / age / finding
# columns that are attached to the prediction files for subgroup analysis.
test_csv = '/rds/general/user/sea22/ephemeral/datafiles/chexpert/CheXpert-v1.0/chexpert.resample.test.csv'
df = pd.read_csv(test_csv)

# Race category strings compared against the 'race' column further down.
white, asian, black = 'White', 'Asian', 'Black'

: 

## Race classification

In [None]:
# Directory holding the per-task prediction CSVs of the evaluated model run.
data_dir = '/homes/sea22/MSC_PROJECT/main/chex-aIchemy/train/results/densenet-imagenet-None-0.0001-test/'
race_csv = os.path.join(data_dir, 'predictions_test_race.csv')
cnn_pred_race = pd.read_csv(race_csv)

: 

In [None]:
# (n_samples, 3) score matrix; column i holds the 'class_i' score.
race_score_cols = ['class_0', 'class_1', 'class_2']
preds_race = cnn_pred_race[race_score_cols].to_numpy()
# True race class of each sample (compared against class indices 0/1/2 in the ROC cell).
targets_race = np.array(cnn_pred_race['target'])

: 

In [None]:
# One-vs-rest ROC per race class: the score column for class i is ranked
# against a binary target (true class == i vs. every other class).
race_class_names = [white, asian, black]  # index i <-> 'class_i' score / target value i

fig, ax = plt.subplots(figsize=(7, 4))
for class_idx, class_name in enumerate(race_class_names):
    is_class = (targets_race == class_idx).astype(int)
    fpr_c, tpr_c, _ = roc_curve(is_class, preds_race[:, class_idx])
    ax.plot(fpr_c, tpr_c, lw=1.5, alpha=.8,
            label='%s AUC=%0.2f' % (class_name, auc(fpr_c, tpr_c)))
ax.plot([0, 1], [0, 1], linestyle='--', lw=1.5, color='k', label='Chance', alpha=.8)
ax.set_xlabel('False Positive Rate', fontsize=14)
ax.set_ylabel('True Positive Rate', fontsize=14)
ax.legend(loc="lower right", fontsize=12)
ax.set_title('Race Classification', fontsize=14)
ax.set(xlim=[-0.05, 1.05], ylim=[-0.05, 1.05])

plt.show()

# fig.savefig("CNN-race.png", bbox_inches='tight', dpi=300)

: 

## Sex classification

In [None]:
data_dir = '/homes/sea22/MSC_PROJECT/main/chex-aIchemy/train/results/densenet-imagenet-None-0.0001-test/'
cnn_pred_sex = pd.read_csv(data_dir + 'predictions_test_sex.csv')

# For subgroup analysis, attach each sample's race from the test CSV.
# The assignment aligns on the (default) row index, so it is only meaningful
# if both files list the same samples in the same order -- assumed here,
# not verifiable from this notebook. A length mismatch would silently give
# NaN races (or drop rows), so fail loudly instead.
assert len(cnn_pred_sex) == len(df), (
    'sex predictions (%d rows) and test set (%d rows) differ in length'
    % (len(cnn_pred_sex), len(df)))
cnn_pred_sex['race'] = df['race']
cnn_pred_sex_w = cnn_pred_sex[cnn_pred_sex['race'] == white].copy()
cnn_pred_sex_a = cnn_pred_sex[cnn_pred_sex['race'] == asian].copy()
cnn_pred_sex_b = cnn_pred_sex[cnn_pred_sex['race'] == black].copy()

: 

In [None]:
def sex_roc(pred_df):
    """ROC curve and AUC for one subgroup of the sex-classification predictions.

    ``pred_df`` must have a 'target' column (assumed 0/1) and a 'class_1'
    score column; 'class_1' is ranked as the score of target == 1.
    Returns ``(fpr, tpr, roc_auc)``.
    """
    fpr_g, tpr_g, _ = roc_curve(np.array(pred_df['target']), pred_df['class_1'].to_numpy())
    return fpr_g, tpr_g, auc(fpr_g, tpr_g)


# Per-race ROC of the sex classifier; these names are read by the plot cell.
fpr_w, tpr_w, roc_auc_w = sex_roc(cnn_pred_sex_w)
fpr_a, tpr_a, roc_auc_a = sex_roc(cnn_pred_sex_a)
fpr_b, tpr_b, roc_auc_b = sex_roc(cnn_pred_sex_b)

: 

In [None]:
# Overlay the per-race ROC curves of the sex classifier.
fig, ax = plt.subplots()
for race_name, race_fpr, race_tpr, race_auc in [
        ('White', fpr_w, tpr_w, roc_auc_w),
        ('Asian', fpr_a, tpr_a, roc_auc_a),
        ('Black', fpr_b, tpr_b, roc_auc_b)]:
    ax.plot(race_fpr, race_tpr, lw=1.5, alpha=.8, label='%s AUC=%0.2f' % (race_name, race_auc))
ax.plot([0, 1], [0, 1], linestyle='--', lw=1.5, color='k', label='Chance', alpha=.8)
ax.set_xlabel('False Positive Rate', fontsize=14)
ax.set_ylabel('True Positive Rate', fontsize=14)
ax.legend(loc="lower right", fontsize=12)
ax.set_title('Sex Classification', fontsize=14)
ax.set(xlim=[-0.05, 1.05], ylim=[-0.05, 1.05])

# fig.savefig("CNN-sex.png", bbox_inches='tight', dpi=300)

: 

## Disease classification (multi-label)

In [None]:
# Directory holding the per-task prediction CSVs of the evaluated model run.
data_dir = '/homes/sea22/MSC_PROJECT/main/chex-aIchemy/train/results/densenet-imagenet-None-0.0001-test/'
disease_csv = os.path.join(data_dir, 'predictions_test_disease.csv')
cnn_pred_disease = pd.read_csv(disease_csv)

: 

In [None]:
# CheXpert findings. The cells below index this list with the same integer
# they use for the 'class_<i>' / 'target_<i>' prediction columns.
labels = [
    'No Finding',                  # 0
    'Enlarged Cardiomediastinum',  # 1
    'Cardiomegaly',                # 2
    'Lung Opacity',                # 3
    'Lung Lesion',                 # 4
    'Edema',                       # 5
    'Consolidation',               # 6
    'Pneumonia',                   # 7
    'Atelectasis',                 # 8
    'Pneumothorax',                # 9
    'Pleural Effusion',            # 10
    'Pleural Other',               # 11
    'Fracture',                    # 12
    'Support Devices',             # 13
]

: 

In [None]:
# Finding to evaluate. Selecting it by name keeps the integer index (used for
# the 'class_<i>' / 'target_<i>' columns) in sync with the label text; the old
# comment claimed index 10 was "No finding" while it is 'Pleural Effusion'.
# Use 'No Finding' here to evaluate that label instead.
label = labels.index('Pleural Effusion')
print(labels[label])

: 

In [None]:
# Attach demographics and the raw finding column from the test CSV.
# These assignments align on the row index, so they assume the prediction
# file lists the same samples in the same order as `df` (not checkable here).
# A length mismatch would silently produce NaNs or dropped rows, so fail loudly.
assert len(cnn_pred_disease) == len(df), (
    'disease predictions (%d rows) and test set (%d rows) differ in length'
    % (len(cnn_pred_disease), len(df)))
cnn_pred_disease['race'] = df['race']
cnn_pred_disease['sex'] = df['sex']
cnn_pred_disease['age'] = df['age']
cnn_pred_disease[labels[label]] = df[labels[label]]

# Subgroups (copied so later modifications cannot touch the parent frame).
cnn_pred_m = cnn_pred_disease[cnn_pred_disease['sex'] == 'Male'].copy()
cnn_pred_f = cnn_pred_disease[cnn_pred_disease['sex'] == 'Female'].copy()
cnn_pred_w = cnn_pred_disease[cnn_pred_disease['race'] == white].copy()
cnn_pred_a = cnn_pred_disease[cnn_pred_disease['race'] == asian].copy()
cnn_pred_b = cnn_pred_disease[cnn_pred_disease['race'] == black].copy()

: 

In [None]:
# Scores (Series) and ground truth (ndarray) for the selected finding,
# overall and for each demographic subgroup.
score_col = 'class_' + str(label)
truth_col = 'target_' + str(label)

preds, targets = cnn_pred_disease[score_col], np.array(cnn_pred_disease[truth_col])
preds_m, targets_m = cnn_pred_m[score_col], np.array(cnn_pred_m[truth_col])
preds_f, targets_f = cnn_pred_f[score_col], np.array(cnn_pred_f[truth_col])
preds_w, targets_w = cnn_pred_w[score_col], np.array(cnn_pred_w[truth_col])
preds_a, targets_a = cnn_pred_a[score_col], np.array(cnn_pred_a[truth_col])
preds_b, targets_b = cnn_pred_b[score_col], np.array(cnn_pred_b[truth_col])

: 

In [None]:
target_fpr = 0.2  # desired overall false-positive rate for the operating point

# ROC curve and AUC overall and for each subgroup.
fpr, tpr, thres = roc_curve(targets, preds)
roc_auc = auc(fpr, tpr)
fpr_w, tpr_w, thres_w = roc_curve(targets_w, preds_w)
roc_auc_w = auc(fpr_w, tpr_w)
fpr_a, tpr_a, thres_a = roc_curve(targets_a, preds_a)
roc_auc_a = auc(fpr_a, tpr_a)
fpr_b, tpr_b, thres_b = roc_curve(targets_b, preds_b)
roc_auc_b = auc(fpr_b, tpr_b)
fpr_m, tpr_m, thres_m = roc_curve(targets_m, preds_m)
roc_auc_m = auc(fpr_m, tpr_m)
fpr_f, tpr_f, thres_f = roc_curve(targets_f, preds_f)
roc_auc_f = auc(fpr_f, tpr_f)

# Global operating point: the ROC threshold whose overall FPR is closest to
# target_fpr. That single threshold is then applied to every subgroup.
closest = np.abs(fpr - target_fpr).argmin()
op = thres[closest]

print('All \t Threshold %0.4f' % op)

: 

In [None]:
def rates_at_threshold(y_true, scores, threshold):
    """FPR and TPR of the decision rule ``scores >= threshold``.

    FPR is 1 - specificity, i.e. 1 - recall of the negative class
    (pos_label=0); TPR is recall of the positive class (pos_label=1).
    Returns ``(fpr, tpr)``.
    """
    decisions = scores >= threshold
    specificity = recall_score(y_true, decisions, pos_label=0)
    sensitivity = recall_score(y_true, decisions, pos_label=1)
    return 1 - specificity, sensitivity


# APPLYING GLOBAL THRESHOLD
fpr_t, tpr_t = rates_at_threshold(targets, preds, op)
fpr_t_w, tpr_t_w = rates_at_threshold(targets_w, preds_w, op)
fpr_t_a, tpr_t_a = rates_at_threshold(targets_a, preds_a, op)
fpr_t_b, tpr_t_b = rates_at_threshold(targets_b, preds_b, op)
fpr_t_f, tpr_t_f = rates_at_threshold(targets_f, preds_f, op)
fpr_t_m, tpr_t_m = rates_at_threshold(targets_m, preds_m, op)

: 

In [None]:
# Subgroups in plotting order:
# (name, ROC fpr, ROC tpr, AUC, FPR at global threshold, TPR at global threshold)
subgroups = [
    ('White', fpr_w, tpr_w, roc_auc_w, fpr_t_w, tpr_t_w),
    ('Asian', fpr_a, tpr_a, roc_auc_a, fpr_t_a, tpr_t_a),
    ('Black', fpr_b, tpr_b, roc_auc_b, fpr_t_b, tpr_t_b),
    ('Female', fpr_f, tpr_f, roc_auc_f, fpr_t_f, tpr_t_f),
    ('Male', fpr_m, tpr_m, roc_auc_m, fpr_t_m, tpr_t_m),
]

fig, ax = plt.subplots(figsize=(7, 4))

# Full ROC curve for each subgroup.
for group_name, roc_fpr, roc_tpr, roc_area, _, _ in subgroups:
    ax.plot(roc_fpr, roc_tpr, lw=1.5, alpha=.8, label='%s AUC=%0.2f' % (group_name, roc_area))

# Restart the colour cycle so each operating-point marker reuses its curve's colour.
ax.set_prop_cycle(None)

# Operating point of each subgroup under the single global threshold `op`.
for _, _, _, _, op_fpr, op_tpr in subgroups:
    ax.plot(op_fpr, op_tpr, 'X', alpha=.8, markersize=10,
            label='TPR=%0.2f FPR=%0.2f' % (op_tpr, op_fpr))

# The legend is built before the chance diagonal is drawn, so 'Chance' is not
# listed; with ncol=2 the 5 curve entries and 5 marker entries fill one column each.
ax.legend(loc="lower right", fontsize=12, ncol=2)
title = labels[label] + ' - Original Test-set'
ax.plot([0, 1], [0, 1], linestyle='--', lw=1.5, color='k', label='Chance', alpha=.8)
ax.set_xlabel('False Positive Rate', fontsize=14)
ax.set_ylabel('True Positive Rate', fontsize=14)
ax.set_title(title, fontsize=14)
ax.set(xlim=[-0.05, 1.05], ylim=[-0.05, 1.05])

plt.show()
# fig.savefig(title + '.png', bbox_inches='tight', dpi=300)

print('All\tTPR %0.2f | FPR %0.2f | AUC %0.2f' % (tpr_t, fpr_t, roc_auc))
for group_name, _, _, roc_area, op_fpr, op_tpr in subgroups:
    print('%s\tTPR %0.2f | FPR %0.2f | AUC %0.2f' % (group_name, op_tpr, op_fpr, roc_area))

: 

: 