# Setup imports

In [1]:
import logging
import os
import random
import sys

import numpy as np
import pandas as pd
import torch
torch.cuda.empty_cache()

from torch.utils.tensorboard import SummaryWriter

import monai
from monai.data import DataLoader, ImageDataset
from monai.transforms import AddChannel, Compose, RandRotate90
from monai.transforms import Resize, ScaleIntensity, EnsureType

from datetime import datetime
from tqdm import tqdm

from constants import Constants

pin_memory = torch.cuda.is_available()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

logging.basicConfig(stream=sys.stdout, level=logging.INFO)

constants = Constants()

In [2]:
LABELS = constants.labels_pkl

LOGS = constants.logs
if not os.path.exists(LOGS): os.mkdir(LOGS)
    
BEST_METRICS = constants.best_metrics
if not os.path.exists(BEST_METRICS): os.mkdir(BEST_METRICS)

## Initialize log saving 

In [3]:
def log_write(command, file):
    file.write(f'{command}\n')
    print(command)

# Get existing data

In [4]:
def get_patient_data(patient_id, labels):
    idx = get_patient_idx(patient_id)
    patient_data = {
        'data' : {
            'ProxID'  : labels.ProxID.iloc[idx],
            'ClinSig' : labels.ClinSig.iloc[idx],
            'fid'     : labels.fid.iloc[idx],
            'pos'     : labels.pos.iloc[idx],
            'zone'    : labels.zone.iloc[idx],
            'spacing' : labels.spacing.iloc[idx],
            'slices'  : labels.slices.iloc[idx] 
        },
        'images' : {
            'T2'     : labels.T2.iloc[idx],
            'ADC'    : labels.ADC.iloc[idx],
            'DWI'    : labels.DWI.iloc[idx],
            'KTrans' : labels.KTrans.iloc[idx] 
        }
    }
    return patient_data


def get_patientID(filename):
    return filename[:14]


def get_patient_idx(patient_id):
    return int(patient_id[10:])

In [5]:
labels_df = pd.read_pickle(LABELS)

# Split the dataset randomly

In [6]:
patients = list(labels_df.ProxID)

TRAIN_TEST_RATIO = constants.split_ratio
train_num = int(len(patients) * TRAIN_TEST_RATIO)

random.shuffle(patients)

train_data, test_data = patients[:train_num], patients[train_num:] 

len(train_data), len(test_data)

(163, 41)

# Prepare training data 

In [7]:
def get_lesion_summary(ClinSig):
    summary = False
    for cs in ClinSig:
        summary = summary or cs
    return summary


def get_images_and_labels(data, labels):
    images_arr = []
    labels_arr = []
    num_files_idx = []

    for patient_id in data:
        patient_data = get_patient_data(patient_id, labels)

        label = 0
        clin_sig = get_lesion_summary(patient_data['data']['ClinSig'])
        if clin_sig: label = 1

        for t2 in patient_data['images']['T2']: 
            images_arr.append(t2)
            labels_arr.append(label)
        for adc in patient_data['images']['ADC']: 
            images_arr.append(adc)
            labels_arr.append(label)
        for dwi in patient_data['images']['DWI']: 
            images_arr.append(dwi)
            labels_arr.append(label)
        images_arr.append(patient_data['images']['KTrans'])
        labels_arr.append(label)
        num_files_idx.append(len(images_arr))
    return images_arr, labels_arr, num_files_idx

In [8]:
train_images_arr, labels_arr, num_files_idx = get_images_and_labels(train_data, labels_df)

train_labels_arr = np.array(labels_arr)
train_labels_arr = torch.nn.functional.one_hot(torch.as_tensor(labels_arr)).float()

len(train_images_arr), len(train_labels_arr), train_labels_arr.shape

(1020, 1020, torch.Size([1020, 2]))

## Define transforms

In [9]:
RESIZE_SIZE = constants.image_resize

train_transforms = Compose([
    ScaleIntensity(), 
    AddChannel(), 
    Resize((RESIZE_SIZE, RESIZE_SIZE, RESIZE_SIZE)), 
    RandRotate90(), 
    EnsureType()
])

val_transforms = Compose([
    ScaleIntensity(), 
    AddChannel(), 
    Resize((RESIZE_SIZE, RESIZE_SIZE, RESIZE_SIZE)), 
    EnsureType()
])

## Check loaders

In [10]:
BATCH_SIZE = constants.batch_size
NUM_WORKERS = constants.num_workers

# Define nifti dataset, data loader
check_ds = ImageDataset(
    image_files=train_images_arr, 
    labels=train_labels_arr, 
    transform=train_transforms
)
check_loader = DataLoader(
    check_ds, 
    batch_size=BATCH_SIZE, 
    num_workers=NUM_WORKERS, 
    pin_memory=pin_memory
)

im, label = monai.utils.misc.first(check_loader)
print(type(im), im.shape, label, label.shape)

<class 'torch.Tensor'> torch.Size([2, 1, 168, 168, 168]) tensor([[1., 0.],
        [1., 0.]]) torch.Size([2, 2])


## Create data loaders 

In [11]:
TRAIN_VAL_RATIO = constants.split_ratio
train_val_num = num_files_idx[int(len(train_data) * TRAIN_VAL_RATIO)] - 1 

train_images, val_images = train_images_arr[:train_val_num], train_images_arr[train_val_num:]
train_labels, val_labels = train_labels_arr[:train_val_num], train_labels_arr[train_val_num:]

# create a training data loader
train_ds = ImageDataset(
    image_files=train_images, 
    labels=train_labels, 
    transform=train_transforms
)
train_loader = DataLoader(
    train_ds, 
    batch_size=BATCH_SIZE, 
    shuffle=True, 
    num_workers=NUM_WORKERS, 
    pin_memory=pin_memory
)

# create a validation data loader
val_ds = ImageDataset(
    image_files=val_images, 
    labels=val_labels, 
    transform=val_transforms
)
val_loader = DataLoader(
    val_ds, 
    batch_size=BATCH_SIZE, 
    num_workers=NUM_WORKERS, 
    pin_memory=pin_memory
)

len(train_loader), len(val_loader)

(409, 101)

# Train the data

## Define network, loss function, and optimizer

In [12]:
today = datetime.today()
date_format_metric = today.strftime("%Y%m%d_%H%M%S")
print(date_format_metric)

model = constants.model.to(device)
loss_function = constants.loss_function
optimizer = constants.optimizer

20220718_163533


## Training parameters

In [13]:
EPOCHS = constants.epochs
MODEL_TYPE = constants.model_type
LOSS_TYPE = constants.loss_type
OPTIMIZER_TYPE = constants.optimizer_type
LR = constants.lr

val_interval = 1
best_metric = -1
best_metric_epoch = -1

epoch_loss_values = []
metric_values = []
writer = SummaryWriter()

## Start training 

In [14]:
log_filename_train = LOGS / f'{date_format_metric}_E{EPOCHS}_train.txt'
log_filename_test = LOGS / f'{date_format_metric}_E{EPOCHS}_test.txt'

In [15]:
log_file = open(log_filename_train, 'w')

log_write(f'Model         :: {MODEL_TYPE}', log_file)
log_write(f'Loss function :: {LOSS_TYPE}', log_file)
log_write(f'Optimizer     :: {OPTIMIZER_TYPE}', log_file)
log_write(f'Learning rate :: {LR}\n', log_file)

for epoch in range(EPOCHS):
    log_write("-" * 10, log_file)
    log_write(f"Epoch {epoch + 1}/{EPOCHS}", log_file)
    model.train()
    epoch_loss = 0
    step = 0
    
    MODEL_STATE_DICT = BEST_METRICS / f'{date_format_metric}.pth'
    
    for batch_data in tqdm(train_loader):
        step += 1
        inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        epoch_len = len(train_ds) // train_loader.batch_size
        writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)

    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    log_write(f"Epoch {epoch + 1} average loss: {epoch_loss:.4f}", log_file)

    if (epoch + 1) % val_interval == 0:
        model.eval()

        num_correct, metric_count = 0.0, 0
        ones, zeros = 0, 0
        
        for val_data in val_loader:
            val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
            with torch.no_grad():
                val_outputs = model(val_images).argmax(dim=1)
                val_labels = val_labels.argmax(dim=1)
                value = torch.eq(val_outputs, val_labels)
                metric_count += len(value)
                num_correct += value.sum().item()

        metric = num_correct / metric_count
        metric_values.append(metric)

        if metric > best_metric:
            best_metric = metric
            best_metric_epoch = epoch + 1
            torch.save(model.state_dict(), MODEL_STATE_DICT)
            log_write("Saved new best metric model", log_file)

        log_write(f"Current epoch: {epoch+1} current accuracy: {metric:.4f}", log_file)
        log_write(f"Best accuracy: {best_metric:.4f} at epoch {best_metric_epoch}", log_file)
        writer.add_scalar("val_accuracy", metric, epoch + 1)

log_write(f"Training completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}", log_file)
writer.close()
log_file.close()

Model         :: DensNet121(3, 1, 2)
Loss function :: Binary Cross Entropy
Optimizer     :: Adam
Learning rate :: 0.0001

----------
Epoch 1/50


100%|████████████████████████████████████████████████████████████████████████████████| 409/409 [04:52<00:00,  1.40it/s]


Epoch 1 average loss: 0.6419
Saved new best metric model
Current epoch: 1 current accuracy: 0.4851
Best accuracy: 0.4851 at epoch 1
----------
Epoch 2/50


100%|████████████████████████████████████████████████████████████████████████████████| 409/409 [04:47<00:00,  1.42it/s]


Epoch 2 average loss: 0.6312
Current epoch: 2 current accuracy: 0.4851
Best accuracy: 0.4851 at epoch 1
----------
Epoch 3/50


100%|████████████████████████████████████████████████████████████████████████████████| 409/409 [04:47<00:00,  1.42it/s]


Epoch 3 average loss: 0.6313
Current epoch: 3 current accuracy: 0.4851
Best accuracy: 0.4851 at epoch 1
----------
Epoch 4/50


100%|████████████████████████████████████████████████████████████████████████████████| 409/409 [04:47<00:00,  1.42it/s]


Epoch 4 average loss: 0.6292
Current epoch: 4 current accuracy: 0.4851
Best accuracy: 0.4851 at epoch 1
----------
Epoch 5/50


100%|████████████████████████████████████████████████████████████████████████████████| 409/409 [04:46<00:00,  1.43it/s]


Epoch 5 average loss: 0.6293
Current epoch: 5 current accuracy: 0.4851
Best accuracy: 0.4851 at epoch 1
----------
Epoch 6/50


100%|████████████████████████████████████████████████████████████████████████████████| 409/409 [04:45<00:00,  1.43it/s]


Epoch 6 average loss: 0.6247
Current epoch: 6 current accuracy: 0.4851
Best accuracy: 0.4851 at epoch 1
----------
Epoch 7/50


100%|████████████████████████████████████████████████████████████████████████████████| 409/409 [04:45<00:00,  1.43it/s]


Epoch 7 average loss: 0.6240
Current epoch: 7 current accuracy: 0.4851
Best accuracy: 0.4851 at epoch 1
----------
Epoch 8/50


100%|████████████████████████████████████████████████████████████████████████████████| 409/409 [04:46<00:00,  1.43it/s]


Epoch 8 average loss: 0.6276
Current epoch: 8 current accuracy: 0.4851
Best accuracy: 0.4851 at epoch 1
----------
Epoch 9/50


100%|████████████████████████████████████████████████████████████████████████████████| 409/409 [04:47<00:00,  1.42it/s]


Epoch 9 average loss: 0.6265
Current epoch: 9 current accuracy: 0.4851
Best accuracy: 0.4851 at epoch 1
----------
Epoch 10/50


100%|████████████████████████████████████████████████████████████████████████████████| 409/409 [04:48<00:00,  1.42it/s]


Epoch 10 average loss: 0.6242
Current epoch: 10 current accuracy: 0.4851
Best accuracy: 0.4851 at epoch 1
----------
Epoch 11/50


100%|████████████████████████████████████████████████████████████████████████████████| 409/409 [04:45<00:00,  1.43it/s]


Epoch 11 average loss: 0.6257
Current epoch: 11 current accuracy: 0.4851
Best accuracy: 0.4851 at epoch 1
----------
Epoch 12/50


100%|████████████████████████████████████████████████████████████████████████████████| 409/409 [04:46<00:00,  1.43it/s]


Epoch 12 average loss: 0.6270
Current epoch: 12 current accuracy: 0.4851
Best accuracy: 0.4851 at epoch 1
----------
Epoch 13/50


100%|████████████████████████████████████████████████████████████████████████████████| 409/409 [04:46<00:00,  1.43it/s]


Epoch 13 average loss: 0.6252
Current epoch: 13 current accuracy: 0.4851
Best accuracy: 0.4851 at epoch 1
----------
Epoch 14/50


100%|████████████████████████████████████████████████████████████████████████████████| 409/409 [04:46<00:00,  1.43it/s]


Epoch 14 average loss: 0.6251
Current epoch: 14 current accuracy: 0.4851
Best accuracy: 0.4851 at epoch 1
----------
Epoch 15/50


100%|████████████████████████████████████████████████████████████████████████████████| 409/409 [04:46<00:00,  1.43it/s]


Epoch 15 average loss: 0.6251
Current epoch: 15 current accuracy: 0.4851
Best accuracy: 0.4851 at epoch 1
----------
Epoch 16/50


100%|████████████████████████████████████████████████████████████████████████████████| 409/409 [04:46<00:00,  1.43it/s]


Epoch 16 average loss: 0.6227
Saved new best metric model
Current epoch: 16 current accuracy: 0.4950
Best accuracy: 0.4950 at epoch 16
----------
Epoch 17/50


100%|████████████████████████████████████████████████████████████████████████████████| 409/409 [04:46<00:00,  1.43it/s]


Epoch 17 average loss: 0.6271
Current epoch: 17 current accuracy: 0.4851
Best accuracy: 0.4950 at epoch 16
----------
Epoch 18/50


100%|████████████████████████████████████████████████████████████████████████████████| 409/409 [04:46<00:00,  1.43it/s]


Epoch 18 average loss: 0.6218
Saved new best metric model
Current epoch: 18 current accuracy: 0.5050
Best accuracy: 0.5050 at epoch 18
----------
Epoch 19/50


100%|████████████████████████████████████████████████████████████████████████████████| 409/409 [04:45<00:00,  1.43it/s]


Epoch 19 average loss: 0.6197
Current epoch: 19 current accuracy: 0.4851
Best accuracy: 0.5050 at epoch 18
----------
Epoch 20/50


100%|████████████████████████████████████████████████████████████████████████████████| 409/409 [04:45<00:00,  1.43it/s]


Epoch 20 average loss: 0.6181
Current epoch: 20 current accuracy: 0.4851
Best accuracy: 0.5050 at epoch 18
----------
Epoch 21/50


100%|████████████████████████████████████████████████████████████████████████████████| 409/409 [04:45<00:00,  1.43it/s]


Epoch 21 average loss: 0.6196
Current epoch: 21 current accuracy: 0.4851
Best accuracy: 0.5050 at epoch 18
----------
Epoch 22/50


100%|████████████████████████████████████████████████████████████████████████████████| 409/409 [04:46<00:00,  1.43it/s]


Epoch 22 average loss: 0.6165
Current epoch: 22 current accuracy: 0.4851
Best accuracy: 0.5050 at epoch 18
----------
Epoch 23/50


100%|████████████████████████████████████████████████████████████████████████████████| 409/409 [04:45<00:00,  1.43it/s]


Epoch 23 average loss: 0.6162
Current epoch: 23 current accuracy: 0.4851
Best accuracy: 0.5050 at epoch 18
----------
Epoch 24/50


100%|████████████████████████████████████████████████████████████████████████████████| 409/409 [04:46<00:00,  1.43it/s]


Epoch 24 average loss: 0.6188
Current epoch: 24 current accuracy: 0.4851
Best accuracy: 0.5050 at epoch 18
----------
Epoch 25/50


100%|████████████████████████████████████████████████████████████████████████████████| 409/409 [04:45<00:00,  1.43it/s]


Epoch 25 average loss: 0.6155
Saved new best metric model
Current epoch: 25 current accuracy: 0.5099
Best accuracy: 0.5099 at epoch 25
----------
Epoch 26/50


100%|████████████████████████████████████████████████████████████████████████████████| 409/409 [04:45<00:00,  1.43it/s]


Epoch 26 average loss: 0.6177
Current epoch: 26 current accuracy: 0.4851
Best accuracy: 0.5099 at epoch 25
----------
Epoch 27/50


100%|████████████████████████████████████████████████████████████████████████████████| 409/409 [04:44<00:00,  1.44it/s]


Epoch 27 average loss: 0.6195
Current epoch: 27 current accuracy: 0.4851
Best accuracy: 0.5099 at epoch 25
----------
Epoch 28/50


100%|████████████████████████████████████████████████████████████████████████████████| 409/409 [04:44<00:00,  1.44it/s]


Epoch 28 average loss: 0.6215
Current epoch: 28 current accuracy: 0.4851
Best accuracy: 0.5099 at epoch 25
----------
Epoch 29/50


100%|████████████████████████████████████████████████████████████████████████████████| 409/409 [04:46<00:00,  1.43it/s]


Epoch 29 average loss: 0.6172
Current epoch: 29 current accuracy: 0.4851
Best accuracy: 0.5099 at epoch 25
----------
Epoch 30/50


100%|████████████████████████████████████████████████████████████████████████████████| 409/409 [04:44<00:00,  1.44it/s]


Epoch 30 average loss: 0.6205
Current epoch: 30 current accuracy: 0.4950
Best accuracy: 0.5099 at epoch 25
----------
Epoch 31/50


100%|████████████████████████████████████████████████████████████████████████████████| 409/409 [04:45<00:00,  1.43it/s]


Epoch 31 average loss: 0.6162
Current epoch: 31 current accuracy: 0.4851
Best accuracy: 0.5099 at epoch 25
----------
Epoch 32/50


100%|████████████████████████████████████████████████████████████████████████████████| 409/409 [04:45<00:00,  1.43it/s]


Epoch 32 average loss: 0.6051
Current epoch: 32 current accuracy: 0.4851
Best accuracy: 0.5099 at epoch 25
----------
Epoch 33/50


100%|████████████████████████████████████████████████████████████████████████████████| 409/409 [04:45<00:00,  1.43it/s]


Epoch 33 average loss: 0.6009
Current epoch: 33 current accuracy: 0.4901
Best accuracy: 0.5099 at epoch 25
----------
Epoch 34/50


100%|████████████████████████████████████████████████████████████████████████████████| 409/409 [04:47<00:00,  1.42it/s]


Epoch 34 average loss: 0.6032
Current epoch: 34 current accuracy: 0.4703
Best accuracy: 0.5099 at epoch 25
----------
Epoch 35/50


100%|████████████████████████████████████████████████████████████████████████████████| 409/409 [04:45<00:00,  1.43it/s]


Epoch 35 average loss: 0.6047
Current epoch: 35 current accuracy: 0.4950
Best accuracy: 0.5099 at epoch 25
----------
Epoch 36/50


100%|████████████████████████████████████████████████████████████████████████████████| 409/409 [04:45<00:00,  1.43it/s]


Epoch 36 average loss: 0.5946
Saved new best metric model
Current epoch: 36 current accuracy: 0.5149
Best accuracy: 0.5149 at epoch 36
----------
Epoch 37/50


100%|████████████████████████████████████████████████████████████████████████████████| 409/409 [04:45<00:00,  1.43it/s]


Epoch 37 average loss: 0.6006
Current epoch: 37 current accuracy: 0.4851
Best accuracy: 0.5149 at epoch 36
----------
Epoch 38/50


100%|████████████████████████████████████████████████████████████████████████████████| 409/409 [04:46<00:00,  1.43it/s]


Epoch 38 average loss: 0.5962
Current epoch: 38 current accuracy: 0.4851
Best accuracy: 0.5149 at epoch 36
----------
Epoch 39/50


100%|████████████████████████████████████████████████████████████████████████████████| 409/409 [04:46<00:00,  1.43it/s]


Epoch 39 average loss: 0.5837
Current epoch: 39 current accuracy: 0.5050
Best accuracy: 0.5149 at epoch 36
----------
Epoch 40/50


100%|████████████████████████████████████████████████████████████████████████████████| 409/409 [04:47<00:00,  1.42it/s]


Epoch 40 average loss: 0.5753
Current epoch: 40 current accuracy: 0.4604
Best accuracy: 0.5149 at epoch 36
----------
Epoch 41/50


100%|████████████████████████████████████████████████████████████████████████████████| 409/409 [04:46<00:00,  1.43it/s]


Epoch 41 average loss: 0.5892
Current epoch: 41 current accuracy: 0.4802
Best accuracy: 0.5149 at epoch 36
----------
Epoch 42/50


100%|████████████████████████████████████████████████████████████████████████████████| 409/409 [04:46<00:00,  1.43it/s]


Epoch 42 average loss: 0.5765
Current epoch: 42 current accuracy: 0.5050
Best accuracy: 0.5149 at epoch 36
----------
Epoch 43/50


100%|████████████████████████████████████████████████████████████████████████████████| 409/409 [04:44<00:00,  1.44it/s]


Epoch 43 average loss: 0.5751
Current epoch: 43 current accuracy: 0.5099
Best accuracy: 0.5149 at epoch 36
----------
Epoch 44/50


100%|████████████████████████████████████████████████████████████████████████████████| 409/409 [04:45<00:00,  1.43it/s]


Epoch 44 average loss: 0.5564
Current epoch: 44 current accuracy: 0.5000
Best accuracy: 0.5149 at epoch 36
----------
Epoch 45/50


100%|████████████████████████████████████████████████████████████████████████████████| 409/409 [04:45<00:00,  1.43it/s]


Epoch 45 average loss: 0.5716
Current epoch: 45 current accuracy: 0.4851
Best accuracy: 0.5149 at epoch 36
----------
Epoch 46/50


100%|████████████████████████████████████████████████████████████████████████████████| 409/409 [04:45<00:00,  1.43it/s]


Epoch 46 average loss: 0.5513
Current epoch: 46 current accuracy: 0.4851
Best accuracy: 0.5149 at epoch 36
----------
Epoch 47/50


100%|████████████████████████████████████████████████████████████████████████████████| 409/409 [04:44<00:00,  1.44it/s]


Epoch 47 average loss: 0.5425
Current epoch: 47 current accuracy: 0.4950
Best accuracy: 0.5149 at epoch 36
----------
Epoch 48/50


100%|████████████████████████████████████████████████████████████████████████████████| 409/409 [04:46<00:00,  1.43it/s]


Epoch 48 average loss: 0.5390
Current epoch: 48 current accuracy: 0.4950
Best accuracy: 0.5149 at epoch 36
----------
Epoch 49/50


100%|████████████████████████████████████████████████████████████████████████████████| 409/409 [04:45<00:00,  1.43it/s]


Epoch 49 average loss: 0.5302
Current epoch: 49 current accuracy: 0.4901
Best accuracy: 0.5149 at epoch 36
----------
Epoch 50/50


100%|████████████████████████████████████████████████████████████████████████████████| 409/409 [04:45<00:00,  1.43it/s]


Epoch 50 average loss: 0.5191
Current epoch: 50 current accuracy: 0.4851
Best accuracy: 0.5149 at epoch 36
Training completed, best_metric: 0.5149 at epoch: 36


# Test the model

## Prepare test data 

In [16]:
test_images_arr, labels_arr, _ = get_images_and_labels(test_data, labels_df)

test_labels_arr = np.array(labels_arr)
test_labels_arr = torch.nn.functional.one_hot(torch.as_tensor(labels_arr)).float()

len(test_images_arr), len(test_labels_arr), test_labels_arr.shape

(269, 269, torch.Size([269, 2]))

## Create test loader

In [17]:
BATCH_SIZE_TEST = constants.batch_size_test
NUM_WORKERS_TEST = constants.num_workers

test_ds = ImageDataset(
    image_files=test_images_arr, 
    labels=test_labels_arr, 
    transform=val_transforms
)

test_loader = DataLoader(
    test_ds, 
    batch_size=BATCH_SIZE_TEST, 
    num_workers=NUM_WORKERS_TEST, 
    pin_memory=torch.cuda.is_available()
)

In [18]:
def calculate_KPI(kpi, output, label):
    tp, tn, fp, fn = kpi
    
    type_KPI = ''
    
    if label[0]:
        if output[0]: 
            tp += 1
            type_KPI = 'TP'
        else: 
            fn +=1
            type_KPI = 'FN'
    else:
        if output[0]: 
            fp += 1
            type_KPI = 'FP'
        else: 
            tn +=1
            type_KPI = 'TN'
    return (tp, tn, fp, fn), type_KPI


def get_metrics(kpi):
    tp, tn, fp, fn = kpi
    
    recall, precision, f1 = -1, -1, -1
    accuracy = (tp+tn) / (tp+tn+fp+fn)
    
    if tp + fn != 0: 
        recall = tp / (tp+fn)
        
    if tp + fp != 0: 
        precision = tp / (tp+fp)
        
    if precision > 0 and recall > 0: 
        f1 = (2*precision*recall) / (precision+recall)
    
    return accuracy, recall, precision, f1

In [19]:
log_file = open(log_filename_test, 'w')

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_test = constants.model.to(device)

MODEL = BEST_METRICS / f'{date_format_metric}.pth'

log_write(f'Training log :: {log_filename_train}', log_file)
log_write(f'Metrics      :: {MODEL}\n', log_file)
log_write("-" * 10, log_file)

model_test.load_state_dict(torch.load(MODEL))
model_test.eval()

fps, fns = [], []

with torch.no_grad():
    KPI = 0, 0, 0, 0
    
    for i, data in enumerate(tqdm(test_loader)):
        test_images, test_labels = data[0].to(device), data[1].to(device)
        
        test_outputs = model_test(test_images).argmax(dim=1)
        test_labels = test_labels.argmax(dim=1)
        
        KPI, type_KPI = calculate_KPI(KPI, test_outputs, test_labels)
        
        if type_KPI == 'FP':
            fps.append((i, test_labels, test_outputs))
        elif type_KPI == 'FN':
            fns.append((i, test_labels, test_outputs))
        
    log_write(f'(TP, TN, FP, FN) = {KPI}', log_file)
    
    metrics = get_metrics(KPI)
    log_write(
        f'Evaluation metrics:\n'
        f'Accuracy  : {metrics[0]}\n'
        f'Recall    : {metrics[1]}\n'
        f'Precision : {metrics[2]}\n'
        f'F1-score  : {metrics[3]}',
        log_file
    )
    
log_file.close()

Training log :: logs\20220718_163533_E50_train.txt
Metrics      :: best_metrics\20220718_163533.pth

----------


100%|████████████████████████████████████████████████████████████████████████████████| 269/269 [00:57<00:00,  4.66it/s]

(TP, TN, FP, FN) = (2, 177, 4, 86)
Evaluation metrics:
Accuracy  : 0.6654275092936803
Recall    : 0.022727272727272728
Precision : 0.3333333333333333
F1-score  : 0.04255319148936171





# Understanding the KPIs

In [20]:
def get_type(file):
    path = str(file)
    if 't2' in path: return 'T2'
    if 'ADC' in path: return 'ADC'
    if 'Ktrans' in path: return 'Ktrans'
    return 'DWI'


def count_images(images, type_img):
    t2, adc, dwi, ktrans = images
    
    if type_img == 'T2': t2 += 1
    elif type_img == 'ADC': adc += 1
    elif type_img == 'DWI': dwi += 1
    elif type_img == 'Ktrans': ktrans += 1
        
    return t2, adc, dwi, ktrans

In [21]:
images = 0, 0, 0, 0
patients = []

for data in fps:
    idx, lab, pred = data
    
    test_img = test_images_arr[idx]
    test_lab = test_labels_arr[idx]
    
    filename = str(test_img).split('\\')[2]
    patient_id = get_patientID(filename)
    images = count_images(images, get_type(filename))
    
    if patient_id not in patients: patients.append(patient_id)
    print(test_img)

print(f'\n{len(patients)} patients:\n{patients}')
print(f'(T2, ADC, DWI, KTrans) = {images}')

data\processed\ProstateX-0134_ep2d_diff_tra_ep2d_diff_tra_DYNDIST_ADC.nii.gz
data\processed\ProstateX-0172_ep2d_diff_tra_ep2d_diff_tra_DYNDIST_ADC.nii.gz
data\processed\ProstateX-0152_ep2d_diff_tra_ep2d_diff_tra_DYNDIST_ADC.nii.gz
data\processed\ProstateX-0168_ep2d_diff_tra_ep2d_diff_tra_DYNDIST_ADC.nii.gz

4 patients:
['ProstateX-0134', 'ProstateX-0172', 'ProstateX-0152', 'ProstateX-0168']
(T2, ADC, DWI, KTrans) = (0, 4, 0, 0)


In [22]:
images = 0, 0, 0, 0
patients = []

for data in fns:
    idx, lab, pred = data
    
    test_img = test_images_arr[idx]
    test_lab = test_labels_arr[idx]
    
    filename = str(test_img).split('\\')[2]
    patient_id = get_patientID(filename)
    images = count_images(images, get_type(filename))
    
    if patient_id not in patients: patients.append(patient_id)
    print(test_img)

print(f'\n{len(patients)} patients:\n{patients}')
print(f'(T2, ADC, DWI, KTrans) = {images}')

data\processed\ProstateX-0197_t2_tse_tra_t2_tse_tra.nii.gz
data\processed\ProstateX-0197_t2_tse_tra_t2_tse_traa.nii.gz
data\processed\ProstateX-0197_ep2d_diff_tra_ep2d_diff_tra_DYNDIST_ADC.nii.gz
data\processed\ProstateX-0197_ep2d_diff_tra_ep2d_diff_tra_DYNDIST_0.nii.gz
data\processed\ProstateX-0197_ep2d_diff_tra_ep2d_diff_tra_DYNDIST_1.nii.gz
data\processed\ProstateX-0197_ep2d_diff_tra_ep2d_diff_tra_DYNDIST_2.nii.gz
data\processed\ProstateX-0197-Ktrans.nii.gz
data\processed\ProstateX-0018_t2_tse_tra_t2_tse_tra.nii.gz
data\processed\ProstateX-0018_ep2d_diff_tra_ep2d_diff_tra_DYNDIST_ADC.nii.gz
data\processed\ProstateX-0018_ep2d_diff_tra_ep2d_diff_tra_DYNDIST_0.nii.gz
data\processed\ProstateX-0018_ep2d_diff_tra_ep2d_diff_tra_DYNDIST_1.nii.gz
data\processed\ProstateX-0018_ep2d_diff_tra_ep2d_diff_tra_DYNDIST_2.nii.gz
data\processed\ProstateX-0018-Ktrans.nii.gz
data\processed\ProstateX-0201_t2_tse_tra_t2_tse_tra.nii.gz
data\processed\ProstateX-0201_diffusie-3Scan-4bval_fs_diffusie-3Scan-4b