In [22]:
import logging
import os
import sys
sys.path.append("../")
import glob
import numpy as np
import torch
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt
import monai
from monai.data import ImageDataset, DataLoader
import monai.transforms as transforms
from monai.transforms import EnsureChannelFirst, Compose, RandRotate90, Resize, ScaleIntensity
import nibabel as nib
import pandas as pd
from sklearn.metrics import fbeta_score, confusion_matrix, roc_auc_score, f1_score, brier_score_loss
from sklearn.metrics import fbeta_score, make_scorer
from utils.custom_transforms import ScaleIntensityFromHistogramPeak, SetBackgroundToZero, SelectChannelsd

In [6]:
#ROOT_DIR = "/home/fehrdelt/bettik/"
ROOT_DIR = "/bettik/PROJECTS/pr-gin5_aini/fehrdelt/"

In [7]:
EXPERIMENT_NAME = "experiment_0"
SUB_EXPERIMENT_NAME = "exp_0_0"

MODELS_DIR = ROOT_DIR+f"StrokeUADiag/{EXPERIMENT_NAME}/{SUB_EXPERIMENT_NAME}/models/"
os.makedirs(MODELS_DIR, exist_ok=True)

In [8]:


ddp_bool = False

rank = 0
world_size = 1
device = 0

torch.cuda.set_device(device)
print(f"Using {device}")

torch.backends.cudnn.benchmark = True
torch.set_num_threads(torch.get_num_threads())
torch.autograd.set_detect_anomaly(False)

Using 0


<torch.autograd.anomaly_mode.set_detect_anomaly at 0x7f7ff56b3390>

In [63]:


test_df = pd.read_csv(ROOT_DIR+f"StrokeUADiag/data_splits_lists/soop/test.csv")
test_files = [{"img": img, "label": label} for img, label in zip(test_df["participant_id"], test_df["high_nihss"])]
for item in test_files:
    item["img"] = ROOT_DIR+"datasets/StrokeUADiag_classification_inputs/stacked_"+item["img"]+".nii.gz"

#test_unhealthy_datalist = test_unhealthy_images_path

batch_size = 2
num_workers = 4

test_transforms = Compose([
    transforms.LoadImaged(keys=["img"]),
    transforms.EnsureChannelFirstd(keys=["img"]),
    SelectChannelsd(keys=["img"], selected_channels=[1,3]),
    transforms.ResizeWithPadOrCropd(keys=["img"], spatial_size=(128, 128, 128)),
    #SetBackgroundToZero()
    ])

test_ds = monai.data.CacheDataset(data=test_files, transform=test_transforms)




if ddp_bool:
    test_sampler = torch.utils.data.distributed.DistributedSampler(test_ds, num_replicas=world_size, rank=rank)
else:
    test_sampler = None


test_loader = DataLoader(
    test_ds, batch_size=batch_size, shuffle=(not ddp_bool), num_workers=num_workers, pin_memory=True, sampler=test_sampler
)


Loading dataset: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 53/53 [00:28<00:00,  1.86it/s]


In [64]:
model = monai.networks.nets.DenseNet121(spatial_dims=3, in_channels=2, out_channels=2).to(device)

model.load_state_dict(torch.load(f"{ROOT_DIR}StrokeUADiag/experiment_0/exp_0_5/models/exp_0_5_best_model.pth"))
model.eval()


loss_function = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), 1e-5)

In [65]:

y_pred_list = []
y_true_list = []

model.eval()
with torch.no_grad():

    val_epoch_loss = 0
    for step, test_data in enumerate(tqdm(test_loader)):
        test_images, test_labels = test_data['img'].to(device), test_data['label'].to(device)
        test_outputs = model(test_images)
        #print(f"test_outputs: {test_outputs}")
        #print(f"test_outputs.argsort(dim=1): {test_outputs.argmax(dim=1)}")
        y_pred_list.extend(test_outputs.argmax(dim=1).cpu().numpy()) #TODO check
        y_true_list.extend(test_labels.cpu().numpy())


y_pred = np.array(y_pred_list)
y_true = np.array(y_true_list)



100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 27/27 [00:00<00:00, 32.08it/s]


In [66]:
rng = np.random.RandomState(seed=12345)

In [67]:

test_roc_auc = []
test_f1 = []
test_ftwos = []
test_brier = []
test_false_neg = []
test_false_pos = []
test_sensitivity = []
test_specificity = []
test_PPV = []
test_NPV = []
test_lr_plus = []
test_lr_minus = []
test_youden_index = []
test_tn = []
test_fp = []
test_fn = []
test_tp = []

idx = np.arange(len(y_true))

for i in range(200): 
    # bootstrap with 200 rounds: random sampling with replacement of the predictions

    pred_idx = rng.choice(idx, size=len(idx), replace=True)
    
    roc_auc_test_boot = roc_auc_score(y_score=y_pred[pred_idx], y_true=y_true[pred_idx])
    f1_test_boot = f1_score(y_pred=y_pred[pred_idx], y_true=y_true[pred_idx])
    f2_test_boot = fbeta_score(y_pred=y_pred[pred_idx], y_true=y_true[pred_idx], beta=2)
    brier_test_boot = brier_score_loss(y_proba=y_pred[pred_idx], y_true=y_true[pred_idx])
    false_neg_test_boot = confusion_matrix(y_true[pred_idx], y_pred[pred_idx])[1,0]
    false_pos_test_boot = confusion_matrix(y_true[pred_idx], y_pred[pred_idx])[0,1]
    # Sensitivity (Recall) and Specificity
    tn, fp, fn, tp = confusion_matrix(y_true[pred_idx], y_pred[pred_idx]).ravel()
    sensitivity = tp / (tp + fn) if (tp + fn) > 0 else np.nan
    specificity = tn / (tn + fp) if (tn + fp) > 0 else np.nan
    # Positive Predictive Value (PPV) and Negative Predictive Value (NPV)
    ppv = tp / (tp + fp) if (tp + fp) > 0 else np.nan
    npv = tn / (tn + fn) if (tn + fn) > 0 else np.nan
    # Likelihood Ratios
    lr_plus = sensitivity / (1 - specificity) if (1 - specificity) > 0 else np.nan
    lr_minus = (1 - sensitivity) / specificity if specificity > 0 else np.nan

    # Youden index (sensitivity + specificity - 1)
    youden_index = sensitivity + specificity - 1

    test_roc_auc.append(roc_auc_test_boot)
    test_f1.append(f1_test_boot)
    test_ftwos.append(f2_test_boot)
    test_brier.append(brier_test_boot)
    test_false_neg.append(false_neg_test_boot/len(idx)*100)
    test_false_pos.append(false_pos_test_boot/len(idx)*100)
    test_sensitivity.append(sensitivity * 100)  # Convert to percentage
    test_specificity.append(specificity * 100)
    test_PPV.append(ppv * 100)  # Convert to percentage
    test_NPV.append(npv * 100)  # Convert to percentage
    test_lr_plus.append(lr_plus)
    test_lr_minus.append(lr_minus)
    test_youden_index.append(youden_index)
    test_tn.append(tn)
    test_fp.append(fp)
    test_fn.append(fn)
    test_tp.append(tp)


print("Classification performance\n")
output = {}

# Compute the mean and 95% confidence intervals for each metric
# 95% confidence intervals are computed using the 2.5th and 97.5th percentiles of the bootstrap samples
bootstrap_roc_auc_test_mean = np.mean(test_roc_auc)
ci_lower = np.percentile(test_roc_auc, 2.5)     # 2.5 percentile (alpha=0.025)
ci_upper = np.percentile(test_roc_auc, 97.5)
output["ROC AUC"] = f"{bootstrap_roc_auc_test_mean:.2f}  -  95% CI {ci_lower:.2f}-{ci_upper:.2f}"
#print(f"ROC AUC:         {bootstrap_roc_auc_test_mean:.2f}  -  95% CI {ci_lower:.2f}-{ci_upper:.2f}")

bootstrap_f1_test_mean = np.mean(test_f1)
ci_lower = np.percentile(test_f1, 2.5)
ci_upper = np.percentile(test_f1, 97.5)
output["F1"] = f"{bootstrap_f1_test_mean:.2f}  -  95% CI {ci_lower:.2f}-{ci_upper:.2f}"
#print(f"F1:              {bootstrap_f1_test_mean:.2f}  -  95% CI {ci_lower:.2f}-{ci_upper:.2f}")

bootstrap_f2_test_mean = np.mean(test_ftwos)
ci_lower = np.percentile(test_ftwos, 2.5)
ci_upper = np.percentile(test_ftwos, 97.5)
output["F2"] = f"{bootstrap_f2_test_mean:.2f}  -  95% CI {ci_lower:.2f}-{ci_upper:.2f}"
#print(f"F2:              {bootstrap_f2_test_mean:.2f}  -  95% CI {ci_lower:.2f}-{ci_upper:.2f}")

bootstrap_brier_test_mean = np.mean(test_brier)
ci_lower = np.percentile(test_brier, 2.5)
ci_upper = np.percentile(test_brier, 97.5)
output["Brier loss"] = f"{bootstrap_brier_test_mean:.2f}  -  95% CI {ci_lower:.2f}-{ci_upper:.2f}"
#print(f"Brier loss:      {bootstrap_brier_test_mean:.2f}  -  95% CI {ci_lower:.2f}-{ci_upper:.2f}")

bootstrap_false_neg_test_mean = np.mean(test_false_neg)
ci_lower = np.percentile(test_false_neg, 2.5)
ci_upper = np.percentile(test_false_neg, 97.5)
output["False negatives"] = f"{bootstrap_false_neg_test_mean:.2f}%  -  95% CI {ci_lower:.2f}-{ci_upper:.2f}"
#print(f"False negatives: {bootstrap_false_neg_test_mean:.2f}%  - 95% CI {ci_lower:.2f}-{ci_upper:.2f}")

bootstrap_false_pos_test_mean = np.mean(test_false_pos)
ci_lower = np.percentile(test_false_pos, 2.5)
ci_upper = np.percentile(test_false_pos, 97.5)
output["False positives"] = f"{bootstrap_false_pos_test_mean:.2f}%  -  95% CI {ci_lower:.2f}-{ci_upper:.2f}"
#print(f"False positives: {bootstrap_false_pos_test_mean:.2f}%  -95% CI {ci_lower:.2f}-{ci_upper:.2f}")

bootstrap_sensitivity = np.mean(test_sensitivity)
ci_lower = np.percentile(test_sensitivity, 2.5)
ci_upper = np.percentile(test_sensitivity, 97.5)
output["Sensitivity"] = f"{bootstrap_sensitivity:.2f}%  -  95% CI {ci_lower:.2f}-{ci_upper:.2f}"
#print(f"Sensitivity:     {bootstrap_sensitivity:.2f}%  -95% CI {ci_lower:.2f}-{ci_upper:.2f}")

bootstrap_specificity = np.mean(test_specificity)
ci_lower = np.percentile(test_specificity, 2.5)
ci_upper = np.percentile(test_specificity, 97.5)
output["Specificity"] = f"{bootstrap_specificity:.2f}%  -  95% CI {ci_lower:.2f}-{ci_upper:.2f}"
#print(f"Specificity:     {bootstrap_specificity:.2f}%  -95% CI {ci_lower:.2f}-{ci_upper:.2f}")

bootstrap_PPV = np.mean(test_PPV)
ci_lower = np.percentile(test_PPV, 2.5)
ci_upper = np.percentile(test_PPV, 97.5)
output["PPV"] = f"{bootstrap_PPV:.2f}%  -  95% CI {ci_lower:.2f}-{ci_upper:.2f}"
#print(f"PPV:     {bootstrap_PPV:.2f}%  -95% CI {ci_lower:.2f}-{ci_upper:.2f}")

bootstrap_NPV = np.mean(test_NPV)
ci_lower = np.percentile(test_NPV, 2.5)
ci_upper = np.percentile(test_NPV, 97.5)
output["NPV"] = f"{bootstrap_NPV:.2f}%  -  95% CI {ci_lower:.2f}-{ci_upper:.2f}"
#print(f"NPV:     {bootstrap_NPV:.2f}%  -95% CI {ci_lower:.2f}-{ci_upper:.2f}")

bootstrap_lr_plus = np.mean(test_lr_plus)
ci_lower = np.percentile(test_lr_plus, 2.5)
ci_upper = np.percentile(test_lr_plus, 97.5)
output["LR+"] = f"{bootstrap_lr_plus:.2f}  -  95% CI {ci_lower:.2f}-{ci_upper:.2f}"
#print(f"Likelihood Ratio +: {bootstrap_lr_plus:.2f}  -95% CI {ci_lower:.2f}-{ci_upper:.2f}")

bootstrap_lr_minus = np.mean(test_lr_minus)
ci_lower = np.percentile(test_lr_minus, 2.5)
ci_upper = np.percentile(test_lr_minus, 97.5)
output["LR-"] = f"{bootstrap_lr_minus:.2f}  -  95% CI {ci_lower:.2f}-{ci_upper:.2f}"
#print(f"Likelihood Ratio +: {bootstrap_lr_minus:.2f}  -95% CI {ci_lower:.2f}-{ci_upper:.2f}")

bootstrap_youden_index = np.mean(test_youden_index)
ci_lower = np.percentile(test_youden_index, 2.5)
ci_upper = np.percentile(test_youden_index, 97.5)
output["Youden Index"] = f"{bootstrap_youden_index:.2f}  -  95% CI {ci_lower:.2f}-{ci_upper:.2f}"
#print(f"Youden Index: {bootstrap_youden_index:.2f}  -95% CI {ci_lower:.2f}-{ci_upper:.2f}")

output["Average TP"] = [np.mean(test_tp)]
output["Average TN"] = [np.mean(test_tn)]
output["Average FP"] = [np.mean(test_fp)]
output["Average FN"] = [np.mean(test_fn)]

print(output)

Classification performance

{'ROC AUC': '0.63  -  95% CI 0.52-0.77', 'F1': '0.49  -  95% CI 0.26-0.69', 'F2': '0.44  -  95% CI 0.20-0.66', 'Brier loss': '0.30  -  95% CI 0.15-0.43', 'False negatives': '20.77%  -  95% CI 9.43-32.08', 'False positives': '9.54%  -  95% CI 3.77-16.98', 'Sensitivity': '41.84%  -  95% CI 18.72-65.09', 'Specificity': '85.06%  -  95% CI 72.71-95.00', 'PPV': '60.67%  -  95% CI 36.35-85.71', 'NPV': '72.48%  -  95% CI 57.77-86.70', 'LR+': '3.52  -  95% CI 1.15-10.26', 'LR-': '0.69  -  95% CI 0.40-0.96', 'Youden Index': '0.27  -  95% CI 0.04-0.54', 'Average TP': [np.float64(7.905)], 'Average TN': [np.float64(29.03)], 'Average FP': [np.float64(5.055)], 'Average FN': [np.float64(11.01)]}


In [68]:
print(y_pred)

[0 1 1 0 0 0 0 0 1 0 0 1 0 0 0 1 0 1 1 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 1 1 0
 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0]


In [56]:
print(y_true)

[1. 1. 0. 0. 0. 0. 1. 1. 0. 0. 0. 0. 1. 0. 1. 1. 1. 0. 1. 0. 1. 1. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 1. 0. 1. 0. 0. 0. 0. 1. 0. 1. 0. 1.
 1. 0. 1. 0. 0.]
