# Setup imports

In [1]:
import logging
import os
import random
import sys

import numpy as np
import pandas as pd
import torch
torch.cuda.empty_cache()

from torch.utils.tensorboard import SummaryWriter

import monai
from monai.data import DataLoader, ImageDataset
from monai.transforms import AddChannel, Compose, RandRotate90
from monai.transforms import Resize, ScaleIntensity, EnsureType

from datetime import datetime
from tqdm import tqdm

from constants import Constants
from patient import Patient

# Passed as `pin_memory` to every DataLoader below; only enabled when CUDA is present.
pin_memory = torch.cuda.is_available()
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Send INFO-level (and above) log records to stdout so they appear in the cell output.
logging.basicConfig(stream=sys.stdout, level=logging.INFO)

# Project configuration object: later cells read paths, hyperparameters, the model,
# the loss function and the optimizer from its attributes.
constants = Constants()

In [2]:
# Path of the pickled labels table (read with pd.read_pickle further down).
LABELS = constants.labels_pkl

# Output directories: plain-text run logs and best-model checkpoints.
LOGS = constants.logs
BEST_METRICS = constants.best_metrics

# makedirs(exist_ok=True) also creates missing parent directories and does not
# fail when the directory already exists, which removes the check-then-create
# race of `if not os.path.exists(...): os.mkdir(...)`.
for output_dir in (LOGS, BEST_METRICS):
    os.makedirs(output_dir, exist_ok=True)

## Initialize log saving 

In [3]:
def log_write(command, file):
    """Write one line to the log file and echo it to the notebook output.

    Args:
        command: Message to record; it is formatted with an f-string, so any
            object with a string representation is accepted.
        file: Writable text file-like object (must provide ``write``).
    """
    file.write(f'{command}\n')
    # Flush right away: training runs for hours, and buffered lines would be
    # lost if the kernel is interrupted before the file is closed.
    flush = getattr(file, 'flush', None)
    if callable(flush):
        flush()
    print(command)

# Get existing data

In [4]:
def get_patients(labels):
    """Build a ``{ProxID: Patient}`` mapping from the labels table.

    Args:
        labels: Table exposing a ``ProxID`` column and positional row access
            through ``.iloc``.

    Returns:
        dict: One ``Patient`` per ProxID, populated via ``Patient.add_labels``
        with the row at the same position. If a ProxID occurs more than once,
        the last row wins.
    """
    def _patient_from_row(row):
        patient = Patient()
        patient.add_labels(row)
        return patient

    return {
        prox_id: _patient_from_row(labels.iloc[position])
        for position, prox_id in enumerate(labels.ProxID)
    }

In [5]:
# NOTE(review): read_pickle can execute arbitrary code while unpickling; only load
# label files that come from a trusted source.
labels_df = pd.read_pickle(LABELS)
# One Patient object per ProxID, keyed by ProxID; used when building the datasets.
patient_data = get_patients(labels_df)

# Split the dataset randomly

In [6]:
# Split at patient level. dict.fromkeys drops repeated ProxIDs (keeping first-seen
# order) so that one patient can never end up in two different splits.
patients = list(dict.fromkeys(labels_df.ProxID))

# The same ratio is applied twice: first (train + val) vs. test, then train vs. val.
SPLIT_RATIO = constants.split_ratio
train_val_num = int(len(patients) * SPLIT_RATIO)
train_num = int(train_val_num * SPLIT_RATIO)

# Seed the shuffle so the split is reproducible after Restart & Run All.
# A dedicated Random instance keeps the global `random` state untouched.
# Falls back to 42 when the Constants object does not define `seed`.
SPLIT_SEED = getattr(constants, "seed", 42)
random.Random(SPLIT_SEED).shuffle(patients)

train_data = patients[:train_num]
val_data = patients[train_num:train_val_num]
test_data = patients[train_val_num:]

len(train_data), len(val_data), len(test_data)

(130, 33, 41)

# Prepare training data 

## Use Oversampling

In [7]:
def get_lesion_summary(ClinSig):
    """Collapse per-lesion flags into a single patient-level flag.

    Args:
        ClinSig: Iterable of per-lesion clinical-significance flags.

    Returns:
        The first truthy element of ``ClinSig``; when no element is truthy, the
        last element, or ``False`` for an empty iterable. Callers only rely on
        the truthiness of the result.
    """
    summary = False
    for cs in ClinSig:
        summary = summary or cs
    return summary


# Number of EXTRA copies added for each image of a positive patient when
# oversampling, i.e. every such image appears 1 + OVERSAMPLE_TIMES times.
OVERSAMPLE_TIMES = 5


def get_images_and_labels(data, labels, oversample=False, patients=None):
    """Flatten the images of the given patients into parallel image/label lists.

    Every image of a patient (all T2, then all ADC, then all DWI, then KTrans)
    gets the same patient-level label: 1 if any lesion flag is truthy, else 0.

    Args:
        data: Iterable of patient ids to include, in output order.
        labels: Unused; kept so existing callers keep working.
        oversample: When true, each image of a positive patient is emitted
            ``1 + OVERSAMPLE_TIMES`` times (consecutively) instead of once.
        patients: Mapping from patient id to an object exposing ``ClinSig``,
            ``T2``, ``ADC``, ``DWI`` (iterables) and ``KTrans`` (single image).
            Defaults to the notebook-level ``patient_data``.

    Returns:
        tuple[list, list]: ``(images, labels)`` of equal length.
    """
    if patients is None:
        patients = patient_data

    images_arr = []
    labels_arr = []

    for patient_id in data:
        patient = patients[patient_id]

        clin_sig = get_lesion_summary(patient.ClinSig)
        label = 1 if clin_sig else 0
        copies = 1 + OVERSAMPLE_TIMES if (clin_sig and oversample) else 1

        # Modalities in a fixed order; KTrans is a single image per patient.
        patient_images = [*patient.T2, *patient.ADC, *patient.DWI, patient.KTrans]
        for image in patient_images:
            images_arr.extend([image] * copies)
            labels_arr.extend([label] * copies)

    return images_arr, labels_arr

In [8]:
def get_dataset_arr(data, labels_df=labels_df, oversample=False, num_classes=2):
    """Build the image list and the one-hot label tensor for a set of patients.

    Args:
        data: Iterable of patient ids.
        labels_df: Forwarded to ``get_images_and_labels`` (which ignores it).
        oversample: Forwarded to ``get_images_and_labels``.
        num_classes: Width of the one-hot encoding. Fixing it keeps the label
            shape at ``(N, num_classes)`` even when a split happens to contain
            a single class, and lets an empty split produce a ``(0, num_classes)``
            tensor instead of raising.

    Returns:
        tuple: ``(images, labels)`` where ``labels`` is a float tensor of shape
        ``(len(images), num_classes)``.
    """
    images, labels_arr = get_images_and_labels(data, labels_df, oversample=oversample)

    # one_hot requires an integer (long) tensor; an empty Python list would
    # otherwise become a float tensor.
    class_indices = torch.as_tensor(labels_arr, dtype=torch.long)
    labels = torch.nn.functional.one_hot(class_indices, num_classes=num_classes).float()

    return images, labels

In [9]:
# Oversample images of positive patients in the training split only; this flag is
# also read by the training cell when it writes the log header.
oversample=True
train_images, train_labels = get_dataset_arr(data=train_data, oversample=oversample)
# Cell result: number of images, number of labels, and the one-hot label shape.
len(train_images), len(train_labels), train_labels.shape

(2382, 2382, torch.Size([2382, 2]))

In [10]:
# Validation split is built without oversampling (get_dataset_arr defaults to oversample=False).
val_images, val_labels = get_dataset_arr(data=val_data)
# Cell result: number of images, number of labels, and the one-hot label shape.
len(val_images), len(val_labels), val_labels.shape

(209, 209, torch.Size([209, 2]))

## Define transforms

In [11]:
# Edge length of the cubic volume every image is resized to.
RESIZE_SIZE = constants.image_resize


def _base_transforms():
    """Return fresh instances of the preprocessing steps shared by all pipelines."""
    # NOTE(review): AddChannel is deprecated in newer MONAI releases — confirm
    # against the installed version before upgrading.
    return [
        ScaleIntensity(),
        AddChannel(),
        Resize((RESIZE_SIZE, RESIZE_SIZE, RESIZE_SIZE)),
    ]


# Training adds a random 90-degree rotation as augmentation.
train_transforms = Compose([*_base_transforms(), RandRotate90(), EnsureType()])

# Validation/testing uses the same preprocessing without the random rotation.
val_transforms = Compose([*_base_transforms(), EnsureType()])

## Check loaders

In [12]:
# Loader settings come from the project configuration.
BATCH_SIZE = constants.batch_size
NUM_WORKERS = constants.num_workers

# Sanity check: build a throwaway dataset/loader over the training data and
# inspect the first batch before creating the real loaders.
check_ds = ImageDataset(
    image_files=train_images, 
    labels=train_labels, 
    transform=train_transforms
)
check_loader = DataLoader(
    check_ds, 
    batch_size=BATCH_SIZE, 
    num_workers=NUM_WORKERS, 
    pin_memory=pin_memory
)

# Pull a single batch and print the image type/shape and the label values/shape.
im, label = monai.utils.misc.first(check_loader)
print(type(im), im.shape, label, label.shape)

<class 'torch.Tensor'> torch.Size([2, 1, 168, 168, 168]) tensor([[0., 1.],
        [0., 1.]]) torch.Size([2, 2])


## Create data loaders 

In [13]:
# create a training data loader
train_ds = ImageDataset(
    image_files=train_images, 
    labels=train_labels, 
    transform=train_transforms
)
train_loader = DataLoader(
    train_ds, 
    batch_size=BATCH_SIZE, 
    shuffle=True, 
    num_workers=NUM_WORKERS, 
    pin_memory=pin_memory
)

# create a validation data loader
val_ds = ImageDataset(
    image_files=val_images, 
    labels=val_labels, 
    transform=val_transforms
)
val_loader = DataLoader(
    val_ds, 
    batch_size=BATCH_SIZE, 
    num_workers=NUM_WORKERS, 
    pin_memory=pin_memory
)

len(train_loader), len(val_loader)

(1191, 105)

# Train the model

## Define network, loss function, and optimizer

In [14]:
# Timestamp identifying this run; it is embedded in the log file names and in the
# checkpoint file name below.
today = datetime.today()
date_format_metric = today.strftime("%Y%m%d_%H%M%S")
print(date_format_metric)

# Model, loss and optimizer are all supplied by the project configuration.
# NOTE(review): the optimizer is presumably built over constants.model's
# parameters inside Constants — verify, since it is not constructed here.
model = constants.model.to(device)
loss_function = constants.loss_function
optimizer = constants.optimizer

20220727_141401


## Training parameters

In [15]:
# Hyperparameters and descriptive names; the *_TYPE values and LR are only written
# to the log header by the training cell.
EPOCHS = constants.epochs
MODEL_TYPE = constants.model_type
LOSS_TYPE = constants.loss_type
OPTIMIZER_TYPE = constants.optimizer_type
LR = constants.lr

# Validate every `val_interval` epochs; -1 marks "no best epoch recorded yet".
val_interval = 1
best_metric = -1
best_metric_epoch = -1

# Per-epoch history: mean training loss and validation accuracy.
epoch_loss_values = []
metric_values = []
# TensorBoard writer created with its default log directory.
writer = SummaryWriter()

## Start training 

In [16]:
# Log file paths for this run, named by run timestamp and epoch count.
# The `/` operator requires LOGS to be a path object (e.g. pathlib.Path).
log_filename_train = LOGS / f'{date_format_metric}_E{EPOCHS}_train.txt'
log_filename_test = LOGS / f'{date_format_metric}_E{EPOCHS}_test.txt'

In [17]:
# The checkpoint path and the number of optimizer steps per epoch are the same
# for every epoch, so compute them once instead of on every iteration.
MODEL_STATE_DICT = BEST_METRICS / f'{date_format_metric}.pth'
# len(train_loader) counts the final partial batch too, so TensorBoard step
# numbers never overlap between consecutive epochs.
steps_per_epoch = len(train_loader)

# `with` closes the log file and `finally` closes the TensorBoard writer even
# when training raises or the kernel is interrupted.
with open(log_filename_train, 'w') as log_file:
    try:
        # Log header describing this run's configuration.
        log_write(f'Model         :: {MODEL_TYPE}', log_file)
        if oversample:
            log_write(f'Oversample    :: {OVERSAMPLE_TIMES}', log_file)
        log_write(f'Loss function :: {LOSS_TYPE}', log_file)
        log_write(f'Optimizer     :: {OPTIMIZER_TYPE}', log_file)
        log_write(f'Learning rate :: {LR}\n', log_file)

        for epoch in range(EPOCHS):
            log_write("-" * 10, log_file)
            log_write(f"Epoch {epoch + 1}/{EPOCHS}", log_file)
            model.train()
            epoch_loss = 0
            step = 0

            for batch_data in tqdm(train_loader):
                step += 1
                inputs, targets = batch_data[0].to(device), batch_data[1].to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = loss_function(outputs, targets)
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()
                writer.add_scalar("train_loss", loss.item(), steps_per_epoch * epoch + step)

            # Mean loss over the batches seen in this epoch.
            epoch_loss /= step
            epoch_loss_values.append(epoch_loss)
            log_write(f"Epoch {epoch + 1} average loss: {epoch_loss:.4f}", log_file)

            if (epoch + 1) % val_interval == 0:
                model.eval()
                num_correct, metric_count = 0.0, 0

                # Batch-local names are used on purpose: reusing `val_images` /
                # `val_labels` here would overwrite the notebook-level dataset
                # lists and break a later re-run of the data-loader cell.
                with torch.no_grad():
                    for val_batch in val_loader:
                        batch_images = val_batch[0].to(device)
                        batch_targets = val_batch[1].to(device)
                        predicted = model(batch_images).argmax(dim=1)
                        # Labels are one-hot, so argmax recovers the class index.
                        expected = batch_targets.argmax(dim=1)
                        matches = torch.eq(predicted, expected)
                        metric_count += len(matches)
                        num_correct += matches.sum().item()

                # Validation accuracy for this epoch.
                metric = num_correct / metric_count
                metric_values.append(metric)

                # Keep only the checkpoint with the best validation accuracy so far.
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(model.state_dict(), MODEL_STATE_DICT)
                    log_write("Saved new best metric model", log_file)

                log_write(f"Current epoch: {epoch+1} current accuracy: {metric:.4f}", log_file)
                log_write(f"Best accuracy: {best_metric:.4f} at epoch {best_metric_epoch}", log_file)
                writer.add_scalar("val_accuracy", metric, epoch + 1)

        log_write(f"Training completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}", log_file)
    finally:
        writer.close()

Model         :: DensNet121(3, 1, 2)
Oversample    :: 5
Loss function :: Binary Cross Entropy
Optimizer     :: Adam
Learning rate :: 0.0001

----------
Epoch 1/50


100%|██████████████████████████████████████████████████████████████████████████████| 1191/1191 [13:59<00:00,  1.42it/s]


Epoch 1 average loss: 0.5334
Saved new best metric model
Current epoch: 1 current accuracy: 0.3876
Best accuracy: 0.3876 at epoch 1
----------
Epoch 2/50


100%|██████████████████████████████████████████████████████████████████████████████| 1191/1191 [13:49<00:00,  1.44it/s]


Epoch 2 average loss: 0.5253
Current epoch: 2 current accuracy: 0.3206
Best accuracy: 0.3876 at epoch 1
----------
Epoch 3/50


100%|██████████████████████████████████████████████████████████████████████████████| 1191/1191 [13:46<00:00,  1.44it/s]


Epoch 3 average loss: 0.5226
Current epoch: 3 current accuracy: 0.3206
Best accuracy: 0.3876 at epoch 1
----------
Epoch 4/50


100%|██████████████████████████████████████████████████████████████████████████████| 1191/1191 [13:58<00:00,  1.42it/s]


Epoch 4 average loss: 0.5225
Current epoch: 4 current accuracy: 0.3206
Best accuracy: 0.3876 at epoch 1
----------
Epoch 5/50


100%|██████████████████████████████████████████████████████████████████████████████| 1191/1191 [14:02<00:00,  1.41it/s]


Epoch 5 average loss: 0.5188
Current epoch: 5 current accuracy: 0.3206
Best accuracy: 0.3876 at epoch 1
----------
Epoch 6/50


100%|██████████████████████████████████████████████████████████████████████████████| 1191/1191 [13:58<00:00,  1.42it/s]


Epoch 6 average loss: 0.5212
Current epoch: 6 current accuracy: 0.3254
Best accuracy: 0.3876 at epoch 1
----------
Epoch 7/50


100%|██████████████████████████████████████████████████████████████████████████████| 1191/1191 [14:00<00:00,  1.42it/s]


Epoch 7 average loss: 0.5190
Current epoch: 7 current accuracy: 0.3206
Best accuracy: 0.3876 at epoch 1
----------
Epoch 8/50


100%|██████████████████████████████████████████████████████████████████████████████| 1191/1191 [13:55<00:00,  1.43it/s]


Epoch 8 average loss: 0.5195
Current epoch: 8 current accuracy: 0.3349
Best accuracy: 0.3876 at epoch 1
----------
Epoch 9/50


100%|██████████████████████████████████████████████████████████████████████████████| 1191/1191 [13:50<00:00,  1.43it/s]


Epoch 9 average loss: 0.5149
Current epoch: 9 current accuracy: 0.3206
Best accuracy: 0.3876 at epoch 1
----------
Epoch 10/50


100%|██████████████████████████████████████████████████████████████████████████████| 1191/1191 [13:53<00:00,  1.43it/s]


Epoch 10 average loss: 0.5137
Current epoch: 10 current accuracy: 0.3206
Best accuracy: 0.3876 at epoch 1
----------
Epoch 11/50


100%|██████████████████████████████████████████████████████████████████████████████| 1191/1191 [13:48<00:00,  1.44it/s]


Epoch 11 average loss: 0.5121
Current epoch: 11 current accuracy: 0.3206
Best accuracy: 0.3876 at epoch 1
----------
Epoch 12/50


100%|██████████████████████████████████████████████████████████████████████████████| 1191/1191 [13:51<00:00,  1.43it/s]


Epoch 12 average loss: 0.4997
Current epoch: 12 current accuracy: 0.3589
Best accuracy: 0.3876 at epoch 1
----------
Epoch 13/50


100%|██████████████████████████████████████████████████████████████████████████████| 1191/1191 [13:48<00:00,  1.44it/s]


Epoch 13 average loss: 0.4906
Current epoch: 13 current accuracy: 0.3206
Best accuracy: 0.3876 at epoch 1
----------
Epoch 14/50


100%|██████████████████████████████████████████████████████████████████████████████| 1191/1191 [13:44<00:00,  1.44it/s]


Epoch 14 average loss: 0.4848
Current epoch: 14 current accuracy: 0.3301
Best accuracy: 0.3876 at epoch 1
----------
Epoch 15/50


100%|██████████████████████████████████████████████████████████████████████████████| 1191/1191 [13:46<00:00,  1.44it/s]


Epoch 15 average loss: 0.4721
Saved new best metric model
Current epoch: 15 current accuracy: 0.3971
Best accuracy: 0.3971 at epoch 15
----------
Epoch 16/50


100%|██████████████████████████████████████████████████████████████████████████████| 1191/1191 [13:46<00:00,  1.44it/s]


Epoch 16 average loss: 0.4380
Saved new best metric model
Current epoch: 16 current accuracy: 0.4641
Best accuracy: 0.4641 at epoch 16
----------
Epoch 17/50


100%|██████████████████████████████████████████████████████████████████████████████| 1191/1191 [13:45<00:00,  1.44it/s]


Epoch 17 average loss: 0.4316
Saved new best metric model
Current epoch: 17 current accuracy: 0.5981
Best accuracy: 0.5981 at epoch 17
----------
Epoch 18/50


100%|██████████████████████████████████████████████████████████████████████████████| 1191/1191 [13:44<00:00,  1.44it/s]


Epoch 18 average loss: 0.3982
Current epoch: 18 current accuracy: 0.5072
Best accuracy: 0.5981 at epoch 17
----------
Epoch 19/50


100%|██████████████████████████████████████████████████████████████████████████████| 1191/1191 [13:45<00:00,  1.44it/s]


Epoch 19 average loss: 0.3633
Saved new best metric model
Current epoch: 19 current accuracy: 0.6268
Best accuracy: 0.6268 at epoch 19
----------
Epoch 20/50


100%|██████████████████████████████████████████████████████████████████████████████| 1191/1191 [13:43<00:00,  1.45it/s]


Epoch 20 average loss: 0.3234
Current epoch: 20 current accuracy: 0.5933
Best accuracy: 0.6268 at epoch 19
----------
Epoch 21/50


100%|██████████████████████████████████████████████████████████████████████████████| 1191/1191 [13:43<00:00,  1.45it/s]


Epoch 21 average loss: 0.2803
Current epoch: 21 current accuracy: 0.6029
Best accuracy: 0.6268 at epoch 19
----------
Epoch 22/50


100%|██████████████████████████████████████████████████████████████████████████████| 1191/1191 [13:45<00:00,  1.44it/s]


Epoch 22 average loss: 0.2496
Saved new best metric model
Current epoch: 22 current accuracy: 0.6555
Best accuracy: 0.6555 at epoch 22
----------
Epoch 23/50


100%|██████████████████████████████████████████████████████████████████████████████| 1191/1191 [13:44<00:00,  1.44it/s]


Epoch 23 average loss: 0.2248
Current epoch: 23 current accuracy: 0.5789
Best accuracy: 0.6555 at epoch 22
----------
Epoch 24/50


100%|██████████████████████████████████████████████████████████████████████████████| 1191/1191 [13:44<00:00,  1.44it/s]


Epoch 24 average loss: 0.2038
Current epoch: 24 current accuracy: 0.6364
Best accuracy: 0.6555 at epoch 22
----------
Epoch 25/50


100%|██████████████████████████████████████████████████████████████████████████████| 1191/1191 [13:45<00:00,  1.44it/s]


Epoch 25 average loss: 0.1712
Current epoch: 25 current accuracy: 0.6220
Best accuracy: 0.6555 at epoch 22
----------
Epoch 26/50


100%|██████████████████████████████████████████████████████████████████████████████| 1191/1191 [13:49<00:00,  1.44it/s]


Epoch 26 average loss: 0.1625
Current epoch: 26 current accuracy: 0.6364
Best accuracy: 0.6555 at epoch 22
----------
Epoch 27/50


100%|██████████████████████████████████████████████████████████████████████████████| 1191/1191 [13:45<00:00,  1.44it/s]


Epoch 27 average loss: 0.1544
Current epoch: 27 current accuracy: 0.6364
Best accuracy: 0.6555 at epoch 22
----------
Epoch 28/50


100%|██████████████████████████████████████████████████████████████████████████████| 1191/1191 [13:49<00:00,  1.44it/s]


Epoch 28 average loss: 0.1426
Current epoch: 28 current accuracy: 0.6411
Best accuracy: 0.6555 at epoch 22
----------
Epoch 29/50


100%|██████████████████████████████████████████████████████████████████████████████| 1191/1191 [13:46<00:00,  1.44it/s]


Epoch 29 average loss: 0.1487
Current epoch: 29 current accuracy: 0.6316
Best accuracy: 0.6555 at epoch 22
----------
Epoch 30/50


100%|██████████████████████████████████████████████████████████████████████████████| 1191/1191 [13:45<00:00,  1.44it/s]


Epoch 30 average loss: 0.1242
Current epoch: 30 current accuracy: 0.6268
Best accuracy: 0.6555 at epoch 22
----------
Epoch 31/50


100%|██████████████████████████████████████████████████████████████████████████████| 1191/1191 [13:46<00:00,  1.44it/s]


Epoch 31 average loss: 0.1095
Saved new best metric model
Current epoch: 31 current accuracy: 0.6746
Best accuracy: 0.6746 at epoch 31
----------
Epoch 32/50


100%|██████████████████████████████████████████████████████████████████████████████| 1191/1191 [13:44<00:00,  1.44it/s]


Epoch 32 average loss: 0.1040
Current epoch: 32 current accuracy: 0.6459
Best accuracy: 0.6746 at epoch 31
----------
Epoch 33/50


100%|██████████████████████████████████████████████████████████████████████████████| 1191/1191 [13:45<00:00,  1.44it/s]


Epoch 33 average loss: 0.1031
Saved new best metric model
Current epoch: 33 current accuracy: 0.7081
Best accuracy: 0.7081 at epoch 33
----------
Epoch 34/50


100%|██████████████████████████████████████████████████████████████████████████████| 1191/1191 [13:46<00:00,  1.44it/s]


Epoch 34 average loss: 0.0931
Current epoch: 34 current accuracy: 0.6603
Best accuracy: 0.7081 at epoch 33
----------
Epoch 35/50


100%|██████████████████████████████████████████████████████████████████████████████| 1191/1191 [13:44<00:00,  1.45it/s]


Epoch 35 average loss: 0.1088
Current epoch: 35 current accuracy: 0.6268
Best accuracy: 0.7081 at epoch 33
----------
Epoch 36/50


100%|██████████████████████████████████████████████████████████████████████████████| 1191/1191 [13:44<00:00,  1.44it/s]


Epoch 36 average loss: 0.0829
Current epoch: 36 current accuracy: 0.6555
Best accuracy: 0.7081 at epoch 33
----------
Epoch 37/50


100%|██████████████████████████████████████████████████████████████████████████████| 1191/1191 [13:46<00:00,  1.44it/s]


Epoch 37 average loss: 0.0876
Current epoch: 37 current accuracy: 0.6842
Best accuracy: 0.7081 at epoch 33
----------
Epoch 38/50


100%|██████████████████████████████████████████████████████████████████████████████| 1191/1191 [13:45<00:00,  1.44it/s]


Epoch 38 average loss: 0.0905
Current epoch: 38 current accuracy: 0.6890
Best accuracy: 0.7081 at epoch 33
----------
Epoch 39/50


100%|██████████████████████████████████████████████████████████████████████████████| 1191/1191 [13:44<00:00,  1.44it/s]


Epoch 39 average loss: 0.0732
Current epoch: 39 current accuracy: 0.6699
Best accuracy: 0.7081 at epoch 33
----------
Epoch 40/50


100%|██████████████████████████████████████████████████████████████████████████████| 1191/1191 [13:46<00:00,  1.44it/s]


Epoch 40 average loss: 0.0799
Current epoch: 40 current accuracy: 0.6794
Best accuracy: 0.7081 at epoch 33
----------
Epoch 41/50


100%|██████████████████████████████████████████████████████████████████████████████| 1191/1191 [13:44<00:00,  1.44it/s]


Epoch 41 average loss: 0.0688
Current epoch: 41 current accuracy: 0.6077
Best accuracy: 0.7081 at epoch 33
----------
Epoch 42/50


100%|██████████████████████████████████████████████████████████████████████████████| 1191/1191 [13:44<00:00,  1.44it/s]


Epoch 42 average loss: 0.0784
Current epoch: 42 current accuracy: 0.6986
Best accuracy: 0.7081 at epoch 33
----------
Epoch 43/50


100%|██████████████████████████████████████████████████████████████████████████████| 1191/1191 [13:46<00:00,  1.44it/s]


Epoch 43 average loss: 0.0775
Current epoch: 43 current accuracy: 0.6842
Best accuracy: 0.7081 at epoch 33
----------
Epoch 44/50


100%|██████████████████████████████████████████████████████████████████████████████| 1191/1191 [13:44<00:00,  1.44it/s]


Epoch 44 average loss: 0.0689
Current epoch: 44 current accuracy: 0.6746
Best accuracy: 0.7081 at epoch 33
----------
Epoch 45/50


100%|██████████████████████████████████████████████████████████████████████████████| 1191/1191 [13:45<00:00,  1.44it/s]


Epoch 45 average loss: 0.0588
Current epoch: 45 current accuracy: 0.6651
Best accuracy: 0.7081 at epoch 33
----------
Epoch 46/50


100%|████████████████████████████████████████████████████████████████████████████| 1191/1191 [9:59:27<00:00, 30.20s/it]


Epoch 46 average loss: 0.0728
Current epoch: 46 current accuracy: 0.6746
Best accuracy: 0.7081 at epoch 33
----------
Epoch 47/50


100%|██████████████████████████████████████████████████████████████████████████████| 1191/1191 [13:49<00:00,  1.44it/s]


Epoch 47 average loss: 0.0550
Current epoch: 47 current accuracy: 0.6651
Best accuracy: 0.7081 at epoch 33
----------
Epoch 48/50


100%|██████████████████████████████████████████████████████████████████████████████| 1191/1191 [13:46<00:00,  1.44it/s]


Epoch 48 average loss: 0.0643
Current epoch: 48 current accuracy: 0.6172
Best accuracy: 0.7081 at epoch 33
----------
Epoch 49/50


100%|██████████████████████████████████████████████████████████████████████████████| 1191/1191 [30:52<00:00,  1.56s/it]


Epoch 49 average loss: 0.0703
Saved new best metric model
Current epoch: 49 current accuracy: 0.7129
Best accuracy: 0.7129 at epoch 49
----------
Epoch 50/50


100%|██████████████████████████████████████████████████████████████████████████████| 1191/1191 [21:45<00:00,  1.10s/it]


Epoch 50 average loss: 0.0617
Current epoch: 50 current accuracy: 0.6746
Best accuracy: 0.7129 at epoch 49
Training completed, best_metric: 0.7129 at epoch: 49


# Test the model

## Prepare test data 

In [18]:
# Test split is built without oversampling (get_dataset_arr defaults to oversample=False).
test_images, test_labels = get_dataset_arr(test_data)
# Cell result: number of images, number of labels, and the one-hot label shape.
len(test_images), len(test_labels), test_labels.shape

(258, 258, torch.Size([258, 2]))

## Create test loader

In [20]:
# NOTE(review): the recorded run iterates 258 batches over 258 test images, i.e. a
# test batch size of 1; calculate_KPI below reads only element [0] of each batch,
# so the counts are only complete with that batch size — confirm before changing it.
BATCH_SIZE_TEST = constants.batch_size_test
NUM_WORKERS_TEST = constants.num_workers

# Test set uses the deterministic validation transforms (no random rotation).
test_ds = ImageDataset(
    image_files=test_images, 
    labels=test_labels, 
    transform=val_transforms
)

# No shuffling: the loader visits the test images in list order.
test_loader = DataLoader(
    test_ds, 
    batch_size=BATCH_SIZE_TEST, 
    num_workers=NUM_WORKERS_TEST, 
    pin_memory=pin_memory
)

In [21]:
def calculate_KPI(kpi, output, label):
    """Update confusion-matrix counts with one batch of predictions.

    Every element of the batch is counted, so the result is correct for any
    batch size (previously only element 0 was inspected).

    Args:
        kpi: Running counts as ``(tp, tn, fp, fn)``.
        output: 1-D sequence/tensor of predicted classes (truthy = positive).
        label: 1-D sequence/tensor of true classes (truthy = positive), same
            length as ``output``.

    Returns:
        tuple[int, int, int, int]: Updated ``(tp, tn, fp, fn)``.

    Raises:
        ValueError: If ``output`` and ``label`` differ in length.
    """
    tp, tn, fp, fn = kpi

    if len(output) != len(label):
        raise ValueError(
            f'output and label must have the same length, got {len(output)} and {len(label)}'
        )

    for predicted, actual in zip(output, label):
        if actual:
            if predicted: tp += 1
            else: fn += 1
        else:
            if predicted: fp += 1
            else: tn += 1
    return tp, tn, fp, fn


def get_metrics(kpi):
    """Compute accuracy, recall, precision and F1 from confusion-matrix counts.

    A metric whose denominator is zero is reported as ``-1`` (undefined).

    Args:
        kpi: Counts as ``(tp, tn, fp, fn)``.

    Returns:
        tuple: ``(accuracy, recall, precision, f1)``.
    """
    tp, tn, fp, fn = kpi

    accuracy, recall, precision, f1 = -1, -1, -1, -1

    # Guard against an empty evaluation set instead of raising ZeroDivisionError.
    total = tp + tn + fp + fn
    if total != 0:
        accuracy = (tp + tn) / total

    if tp + fn != 0:
        recall = tp / (tp + fn)

    if tp + fp != 0:
        precision = tp / (tp + fp)

    if precision > 0 and recall > 0:
        f1 = (2*precision*recall) / (precision+recall)
    elif tp == 0 and fp + fn != 0:
        # F1 = 2*tp / (2*tp + fp + fn): with no true positives but some errors
        # it is a genuine 0, not "undefined".
        f1 = 0.0

    return accuracy, recall, precision, f1

In [22]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): this reads the same `constants.model` attribute as the training
# section, so `model_test` presumably aliases the trained `model` — confirm that
# Constants does not build a fresh network on each access.
model_test = constants.model.to(device)

# Best checkpoint written by the training loop for this run.
MODEL = BEST_METRICS / f'{date_format_metric}.pth'

# `with` guarantees the log file is closed even if evaluation raises.
with open(log_filename_test, 'w') as log_file:
    log_write(f'Training log :: {log_filename_train}', log_file)
    log_write(f'Metrics      :: {MODEL}\n', log_file)
    log_write("-" * 10, log_file)

    # map_location lets a checkpoint saved on the GPU be loaded on a CPU-only machine.
    model_test.load_state_dict(torch.load(MODEL, map_location=device))
    model_test.eval()

    with torch.no_grad():
        # Running confusion-matrix counts: (TP, TN, FP, FN).
        KPI = 0, 0, 0, 0

        # Batch-local names avoid overwriting the notebook-level `test_images` /
        # `test_labels` lists that the test dataset was built from.
        for batch in tqdm(test_loader):
            batch_images, batch_labels = batch[0].to(device), batch[1].to(device)

            predicted = model_test(batch_images).argmax(dim=1)
            # Labels are one-hot, so argmax recovers the class index.
            expected = batch_labels.argmax(dim=1)

            KPI = calculate_KPI(KPI, predicted, expected)

        log_write(f'(TP, TN, FP, FN) = {KPI}', log_file)

        metrics = get_metrics(KPI)
        log_write(
            f'Evaluation metrics:\n'
            f'Accuracy  : {metrics[0]}\n'
            f'Recall    : {metrics[1]}\n'
            f'Precision : {metrics[2]}\n'
            f'F1-score  : {metrics[3]}',
            log_file
        )

Training log :: logs\20220727_141401_E50_train.txt
Metrics      :: best_metrics\20220727_141401.pth

----------


100%|████████████████████████████████████████████████████████████████████████████████| 258/258 [00:50<00:00,  5.12it/s]

(TP, TN, FP, FN) = (38, 147, 43, 30)
Evaluation metrics:
Accuracy  : 0.7170542635658915
Recall    : 0.5588235294117647
Precision : 0.4691358024691358
F1-score  : 0.5100671140939597



