# Setup imports

In [1]:
import logging
import ntpath
import os
import random
import sys
import shutil
import tempfile

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
torch.cuda.empty_cache()

from torch.utils.tensorboard import SummaryWriter

import monai
from monai.apps import download_and_extract
from monai.config import print_config
from monai.data import DataLoader, ImageDataset
from monai.transforms import AddChannel, Compose, RandRotate90
from monai.transforms import Resize, ScaleIntensity, EnsureType

from datetime import datetime
from pathlib import Path
from tqdm import tqdm

from constants import Constants

pin_memory = torch.cuda.is_available()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

logging.basicConfig(stream=sys.stdout, level=logging.INFO)

constants = Constants()

In [2]:
# Set data directory
ROOT = constants.dataset_path / 'processed'

LABELS = ROOT / 'lesion_findings.pickle'

LOGS = Path('logs')
if not os.path.exists(LOGS): os.mkdir(LOGS)
    
BEST_METRICS = Path('best_metrics')
if not os.path.exists(BEST_METRICS): os.mkdir(BEST_METRICS)

# Get existing data

In [3]:
def get_patient_data(patient_id, labels):
    idx = get_patient_idx(patient_id)
    patient_data = {
        'data' : {
            'ProxID' : labels.ProxID.iloc[idx],
            'ClinSig' : labels.ClinSig.iloc[idx],
            'fid' : labels.fid.iloc[idx],
            'pos' : labels.pos.iloc[idx],
            'zone' : labels.zone.iloc[idx],
            'spacing' : labels.spacing.iloc[idx],
            'slices' : labels.slices.iloc[idx] 
        },
        'images' : {
            'T2' : labels.T2.iloc[idx],
            'ADC' : labels.ADC.iloc[idx],
            'KTrans' : labels.KTrans.iloc[idx] 
        }
    }
            
    return patient_data


def get_patientID(filename):
    return filename[:14]

def get_patient_idx(patient_id):
    return int(patient_id[10:])

In [15]:
labels = pd.read_pickle(LABELS)
labels

Unnamed: 0,ProxID,ClinSig,fid,pos,zone,spacing,slices,T2,ADC,KTrans
0,ProstateX-0000,[True],[1],[25.7457 31.8707 -38.511],[PZ],"(0.5, 0.5, 3.0)","[7, 8, 9, 10, 11, 12, 13]",data\processed\ProstateX-0000_t2_tse_tra_t2_ts...,data\processed\ProstateX-0000_ep2d_diff_tra_ep...,data\processed\ProstateX-0000-Ktrans.nii.gz
1,ProstateX-0001,[False],[1],[-40.5367071921656 29.320722668457 -16.7076690...,[AS],"(0.5, 0.5, 3.0)","[7, 8, 9, 10, 11, 12, 13]",data\processed\ProstateX-0001_t2_tse_tra_t2_ts...,data\processed\ProstateX-0001_ep2d_diff_tra_ep...,data\processed\ProstateX-0001-Ktrans.nii.gz
2,ProstateX-0002,"[True, False]","[1, 2]","[-27.0102 41.5467 -26.0469, -2.058 38.6752 -34...","[PZ, PZ]","(0.5, 0.5, 3.0)","[11, 12, 13, 14, 15, 16, 17, 18]",data\processed\ProstateX-0002_t2_tse_tra_t2_ts...,data\processed\ProstateX-0002_ep2d_diff_tra_ep...,data\processed\ProstateX-0002-Ktrans.nii.gz
3,ProstateX-0003,"[False, False]","[1, 2]","[22.1495 31.2717 -2.45933, -21.2871 19.3995 19...","[TZ, TZ]","(0.5, 0.5, 3.0)","[11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21]",data\processed\ProstateX-0003_t2_tse_tra_t2_ts...,data\processed\ProstateX-0003_ep2d_diff_tra_ep...,data\processed\ProstateX-0003-Ktrans.nii.gz
4,ProstateX-0004,[False],[1],[-7.69665 3.64226 23.1659],[AS],"(0.5, 0.5, 3.0)","[6, 7, 8, 9, 10, 11, 12, 13]",data\processed\ProstateX-0004_t2_tse_tra_t2_ts...,data\processed\ProstateX-0004_ep2d_diff_tra_ep...,data\processed\ProstateX-0004-Ktrans.nii.gz
...,...,...,...,...,...,...,...,...,...,...
199,ProstateX-0199,"[True, True]","[1, 2]","[-4.267512 -51.1958 4.3458, -20.3406 -48.9915...","[AS, AS]","(0.5, 0.5, 3.0)","[8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]",data\processed\ProstateX-0199_t2_tse_tra_t2_ts...,data\processed\ProstateX-0199_diffusie-3Scan-4...,data\processed\ProstateX-0199-Ktrans.nii.gz
200,ProstateX-0200,"[True, False]","[1, 2]","[21.8727 -28.9887 -64.2121, -19.0211715698242...","[AS, PZ]","(0.5, 0.5, 3.0)","[10, 11, 12, 13, 14, 15]",data\processed\ProstateX-0200_t2_tse_tra_t2_ts...,data\processed\ProstateX-0200_diffusie-3Scan-4...,data\processed\ProstateX-0200-Ktrans.nii.gz
201,ProstateX-0201,[True],[1],[10.1826 -10.0427 20.9151],[AS],"(0.5, 0.5, 3.0)","[13, 14, 15, 16, 17, 18, 19, 20]",data\processed\ProstateX-0201_t2_tse_tra_t2_ts...,data\processed\ProstateX-0201_diffusie-3Scan-4...,data\processed\ProstateX-0201-Ktrans.nii.gz
202,ProstateX-0202,"[True, False]","[1, 2]","[-0.02085 -44.5506 16.7349, -12.8649 -21.7307 ...","[AS, PZ]","(0.5, 0.5, 3.0)","[7, 8, 9, 10, 11, 12, 13, 14, 15]",data\processed\ProstateX-0202_t2_tse_tra_t2_ts...,data\processed\ProstateX-0202_diffusie-3Scan-4...,data\processed\ProstateX-0202-Ktrans.nii.gz


# Split the dataset randomly

In [5]:
patients = list(labels.ProxID)

TRAIN_TEST_RATIO = 0.8
train_num = int(len(patients) * TRAIN_TEST_RATIO)

random.shuffle(patients)

train_data, test_data = patients[:train_num], patients[train_num:] 

len(train_data), len(test_data)

(163, 41)

# Prepare training data 

In [6]:
def get_lesion_summary(ClinSig):
    summary = False
    for cs in ClinSig:
        summary = summary or cs
    return summary


def get_images_and_labels(data, labels):
    images_arr = []
    labels_arr = []

    for patient_id in data:
        patient_data = get_patient_data(patient_id, labels)

        label = 0
        clin_sig = get_lesion_summary(patient_data['data']['ClinSig'])
        if clin_sig: label = 1

        images_arr.append(patient_data['images']['T2'])
        images_arr.append(patient_data['images']['ADC'])
        images_arr.append(patient_data['images']['KTrans'])

        for i in range(3):
            labels_arr.append(label)
            
    return images_arr, labels_arr

In [7]:
train_images_arr, labels_arr = get_images_and_labels(train_data, labels)

train_labels_arr = np.array(labels_arr)
train_labels_arr = torch.nn.functional.one_hot(torch.as_tensor(labels_arr)).float()

len(train_images_arr), len(train_labels_arr), train_labels_arr.shape

(489, 489, torch.Size([489, 2]))

## Define transforms

In [8]:
RESIZE_SIZE = 168

train_transforms = Compose([
    ScaleIntensity(), 
    AddChannel(), 
    Resize((RESIZE_SIZE, RESIZE_SIZE, RESIZE_SIZE)), 
    RandRotate90(), 
    EnsureType()
])

val_transforms = Compose([
    ScaleIntensity(), 
    AddChannel(), 
    Resize((RESIZE_SIZE, RESIZE_SIZE, RESIZE_SIZE)), 
    EnsureType()
])

## Check loaders

In [9]:
BATCH_SIZE = 2
NUM_WORKERS = 0

# Define nifti dataset, data loader
check_ds = ImageDataset(
    image_files=train_images_arr, 
    labels=train_labels_arr, 
    transform=train_transforms
)
check_loader = DataLoader(
    check_ds, 
    batch_size=BATCH_SIZE, 
    num_workers=NUM_WORKERS, 
    pin_memory=pin_memory
)

im, label = monai.utils.misc.first(check_loader)
print(type(im), im.shape, label, label.shape)

<class 'torch.Tensor'> torch.Size([2, 1, 168, 168, 168]) tensor([[1., 0.],
        [1., 0.]]) torch.Size([2, 2])


## Create data loaders 

In [10]:
TRAIN_VAL_RATIO = 0.8
train_val_num = int(len(train_data) * TRAIN_VAL_RATIO) * 3

train_images, val_images = train_images_arr[:train_val_num], train_images_arr[train_val_num:]
train_labels, val_labels = train_labels_arr[:train_val_num], train_labels_arr[train_val_num:]

# create a training data loader
train_ds = ImageDataset(
    image_files=train_images, 
    labels=train_labels, 
    transform=train_transforms
)
train_loader = DataLoader(
    train_ds, 
    batch_size=BATCH_SIZE, 
    shuffle=True, 
    num_workers=NUM_WORKERS, 
    pin_memory=pin_memory
)

# create a validation data loader
val_ds = ImageDataset(
    image_files=val_images, 
    labels=val_labels, 
    transform=val_transforms
)
val_loader = DataLoader(
    val_ds, 
    batch_size=BATCH_SIZE, 
    num_workers=NUM_WORKERS, 
    pin_memory=pin_memory
)

# Train the data

## Define network, loss function, and optimizer

In [11]:
today = datetime.today()
date_format_metric = today.strftime("%Y%m%d_%H%M%S")
print(date_format_metric)

model = monai.networks.nets.DenseNet121(spatial_dims=3, in_channels=1, out_channels=2).to(device)

loss_function = torch.nn.BCEWithLogitsLoss() 

optimizer = torch.optim.Adam(model.parameters(), 1e-4)

20220704_185438


## Training parameters

In [12]:
EPOCHS = 50

val_interval = 1
best_metric = -1
best_metric_epoch = -1

epoch_loss_values = []
metric_values = []
writer = SummaryWriter()

## Start training 

In [13]:
for epoch in range(EPOCHS):
    print("-" * 10)
    print(f"Epoch {epoch + 1}/{EPOCHS}")
    model.train()
    epoch_loss = 0
    step = 0

    for batch_data in tqdm(train_loader):
        step += 1
        inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        epoch_len = len(train_ds) // train_loader.batch_size
        writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)

    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    if (epoch + 1) % val_interval == 0:
        model.eval()

        num_correct = 0.0
        metric_count = 0
        for val_data in val_loader:
            val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
            with torch.no_grad():
                val_outputs = model(val_images)
                value = torch.eq(val_outputs.argmax(dim=1), val_labels.argmax(dim=1))
                metric_count += len(value)
                num_correct += value.sum().item()

        metric = num_correct / metric_count
        metric_values.append(metric)

        if metric > best_metric:
            best_metric = metric
            best_metric_epoch = epoch + 1
            torch.save(
                model.state_dict(), 
                f"best_metrics/{date_format_metric}_best_metric_model.pth"
            )
            print("Saved new best metric model")

        print(f"Current epoch: {epoch+1} current accuracy: {metric:.4f} ")
        print(f"Best accuracy: {best_metric:.4f} at epoch {best_metric_epoch}")
        writer.add_scalar("val_accuracy", metric, epoch + 1)

print(f"Training completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
writer.close()

----------
Epoch 1/50


100%|████████████████████████████████████████████████████████████████████████████████| 195/195 [02:47<00:00,  1.16it/s]


epoch 1 average loss: 0.6768
Saved new best metric model
Current epoch: 1 current accuracy: 0.5051 
Best accuracy: 0.5051 at epoch 1
----------
Epoch 2/50


100%|████████████████████████████████████████████████████████████████████████████████| 195/195 [02:38<00:00,  1.23it/s]


epoch 2 average loss: 0.6516
Saved new best metric model
Current epoch: 2 current accuracy: 0.5152 
Best accuracy: 0.5152 at epoch 2
----------
Epoch 3/50


100%|████████████████████████████████████████████████████████████████████████████████| 195/195 [02:39<00:00,  1.22it/s]


epoch 3 average loss: 0.6444
Current epoch: 3 current accuracy: 0.5152 
Best accuracy: 0.5152 at epoch 2
----------
Epoch 4/50


100%|████████████████████████████████████████████████████████████████████████████████| 195/195 [02:37<00:00,  1.24it/s]


epoch 4 average loss: 0.6438
Current epoch: 4 current accuracy: 0.5152 
Best accuracy: 0.5152 at epoch 2
----------
Epoch 5/50


100%|████████████████████████████████████████████████████████████████████████████████| 195/195 [02:35<00:00,  1.25it/s]


epoch 5 average loss: 0.6410
Current epoch: 5 current accuracy: 0.5152 
Best accuracy: 0.5152 at epoch 2
----------
Epoch 6/50


100%|████████████████████████████████████████████████████████████████████████████████| 195/195 [02:35<00:00,  1.25it/s]


epoch 6 average loss: 0.6344
Saved new best metric model
Current epoch: 6 current accuracy: 0.5354 
Best accuracy: 0.5354 at epoch 6
----------
Epoch 7/50


100%|████████████████████████████████████████████████████████████████████████████████| 195/195 [02:41<00:00,  1.21it/s]


epoch 7 average loss: 0.6434
Current epoch: 7 current accuracy: 0.5051 
Best accuracy: 0.5354 at epoch 6
----------
Epoch 8/50


100%|████████████████████████████████████████████████████████████████████████████████| 195/195 [02:29<00:00,  1.30it/s]


epoch 8 average loss: 0.6352
Current epoch: 8 current accuracy: 0.4949 
Best accuracy: 0.5354 at epoch 6
----------
Epoch 9/50


100%|████████████████████████████████████████████████████████████████████████████████| 195/195 [02:21<00:00,  1.38it/s]


epoch 9 average loss: 0.6396
Current epoch: 9 current accuracy: 0.4848 
Best accuracy: 0.5354 at epoch 6
----------
Epoch 10/50


100%|████████████████████████████████████████████████████████████████████████████████| 195/195 [02:23<00:00,  1.36it/s]


epoch 10 average loss: 0.6351
Current epoch: 10 current accuracy: 0.5152 
Best accuracy: 0.5354 at epoch 6
----------
Epoch 11/50


100%|████████████████████████████████████████████████████████████████████████████████| 195/195 [02:21<00:00,  1.38it/s]


epoch 11 average loss: 0.6406
Current epoch: 11 current accuracy: 0.4949 
Best accuracy: 0.5354 at epoch 6
----------
Epoch 12/50


100%|████████████████████████████████████████████████████████████████████████████████| 195/195 [02:21<00:00,  1.38it/s]


epoch 12 average loss: 0.6355
Current epoch: 12 current accuracy: 0.5253 
Best accuracy: 0.5354 at epoch 6
----------
Epoch 13/50


100%|████████████████████████████████████████████████████████████████████████████████| 195/195 [02:21<00:00,  1.38it/s]


epoch 13 average loss: 0.6434
Current epoch: 13 current accuracy: 0.5152 
Best accuracy: 0.5354 at epoch 6
----------
Epoch 14/50


100%|████████████████████████████████████████████████████████████████████████████████| 195/195 [02:20<00:00,  1.39it/s]


epoch 14 average loss: 0.6385
Current epoch: 14 current accuracy: 0.5152 
Best accuracy: 0.5354 at epoch 6
----------
Epoch 15/50


100%|████████████████████████████████████████████████████████████████████████████████| 195/195 [02:20<00:00,  1.39it/s]


epoch 15 average loss: 0.6346
Current epoch: 15 current accuracy: 0.5152 
Best accuracy: 0.5354 at epoch 6
----------
Epoch 16/50


100%|████████████████████████████████████████████████████████████████████████████████| 195/195 [02:20<00:00,  1.38it/s]


epoch 16 average loss: 0.6342
Current epoch: 16 current accuracy: 0.5152 
Best accuracy: 0.5354 at epoch 6
----------
Epoch 17/50


100%|████████████████████████████████████████████████████████████████████████████████| 195/195 [02:20<00:00,  1.38it/s]


epoch 17 average loss: 0.6347
Current epoch: 17 current accuracy: 0.5253 
Best accuracy: 0.5354 at epoch 6
----------
Epoch 18/50


100%|████████████████████████████████████████████████████████████████████████████████| 195/195 [02:20<00:00,  1.39it/s]


epoch 18 average loss: 0.6438
Current epoch: 18 current accuracy: 0.5354 
Best accuracy: 0.5354 at epoch 6
----------
Epoch 19/50


100%|████████████████████████████████████████████████████████████████████████████████| 195/195 [02:20<00:00,  1.38it/s]


epoch 19 average loss: 0.6326
Current epoch: 19 current accuracy: 0.5253 
Best accuracy: 0.5354 at epoch 6
----------
Epoch 20/50


100%|████████████████████████████████████████████████████████████████████████████████| 195/195 [02:20<00:00,  1.38it/s]


epoch 20 average loss: 0.6333
Current epoch: 20 current accuracy: 0.5253 
Best accuracy: 0.5354 at epoch 6
----------
Epoch 21/50


100%|████████████████████████████████████████████████████████████████████████████████| 195/195 [02:20<00:00,  1.39it/s]


epoch 21 average loss: 0.6281
Current epoch: 21 current accuracy: 0.5253 
Best accuracy: 0.5354 at epoch 6
----------
Epoch 22/50


100%|████████████████████████████████████████████████████████████████████████████████| 195/195 [02:20<00:00,  1.38it/s]


epoch 22 average loss: 0.6293
Current epoch: 22 current accuracy: 0.5354 
Best accuracy: 0.5354 at epoch 6
----------
Epoch 23/50


100%|████████████████████████████████████████████████████████████████████████████████| 195/195 [02:20<00:00,  1.39it/s]


epoch 23 average loss: 0.6339
Current epoch: 23 current accuracy: 0.5152 
Best accuracy: 0.5354 at epoch 6
----------
Epoch 24/50


100%|████████████████████████████████████████████████████████████████████████████████| 195/195 [02:20<00:00,  1.39it/s]


epoch 24 average loss: 0.6327
Current epoch: 24 current accuracy: 0.5253 
Best accuracy: 0.5354 at epoch 6
----------
Epoch 25/50


100%|████████████████████████████████████████████████████████████████████████████████| 195/195 [02:20<00:00,  1.38it/s]


epoch 25 average loss: 0.6287
Saved new best metric model
Current epoch: 25 current accuracy: 0.5758 
Best accuracy: 0.5758 at epoch 25
----------
Epoch 26/50


100%|████████████████████████████████████████████████████████████████████████████████| 195/195 [02:20<00:00,  1.39it/s]


epoch 26 average loss: 0.6356
Current epoch: 26 current accuracy: 0.5152 
Best accuracy: 0.5758 at epoch 25
----------
Epoch 27/50


100%|████████████████████████████████████████████████████████████████████████████████| 195/195 [02:21<00:00,  1.38it/s]


epoch 27 average loss: 0.6240
Current epoch: 27 current accuracy: 0.5253 
Best accuracy: 0.5758 at epoch 25
----------
Epoch 28/50


100%|████████████████████████████████████████████████████████████████████████████████| 195/195 [02:20<00:00,  1.38it/s]


epoch 28 average loss: 0.6323
Current epoch: 28 current accuracy: 0.5455 
Best accuracy: 0.5758 at epoch 25
----------
Epoch 29/50


100%|████████████████████████████████████████████████████████████████████████████████| 195/195 [02:21<00:00,  1.38it/s]


epoch 29 average loss: 0.6277
Current epoch: 29 current accuracy: 0.5152 
Best accuracy: 0.5758 at epoch 25
----------
Epoch 30/50


100%|████████████████████████████████████████████████████████████████████████████████| 195/195 [02:20<00:00,  1.38it/s]


epoch 30 average loss: 0.6220
Current epoch: 30 current accuracy: 0.5253 
Best accuracy: 0.5758 at epoch 25
----------
Epoch 31/50


100%|████████████████████████████████████████████████████████████████████████████████| 195/195 [02:20<00:00,  1.38it/s]


epoch 31 average loss: 0.6144
Current epoch: 31 current accuracy: 0.5354 
Best accuracy: 0.5758 at epoch 25
----------
Epoch 32/50


100%|████████████████████████████████████████████████████████████████████████████████| 195/195 [02:20<00:00,  1.38it/s]


epoch 32 average loss: 0.6281
Current epoch: 32 current accuracy: 0.5253 
Best accuracy: 0.5758 at epoch 25
----------
Epoch 33/50


100%|████████████████████████████████████████████████████████████████████████████████| 195/195 [02:21<00:00,  1.38it/s]


epoch 33 average loss: 0.6152
Current epoch: 33 current accuracy: 0.5556 
Best accuracy: 0.5758 at epoch 25
----------
Epoch 34/50


100%|████████████████████████████████████████████████████████████████████████████████| 195/195 [02:20<00:00,  1.38it/s]


epoch 34 average loss: 0.6187
Current epoch: 34 current accuracy: 0.5253 
Best accuracy: 0.5758 at epoch 25
----------
Epoch 35/50


100%|████████████████████████████████████████████████████████████████████████████████| 195/195 [02:20<00:00,  1.38it/s]


epoch 35 average loss: 0.6021
Current epoch: 35 current accuracy: 0.5455 
Best accuracy: 0.5758 at epoch 25
----------
Epoch 36/50


100%|████████████████████████████████████████████████████████████████████████████████| 195/195 [02:20<00:00,  1.39it/s]


epoch 36 average loss: 0.6205
Current epoch: 36 current accuracy: 0.5455 
Best accuracy: 0.5758 at epoch 25
----------
Epoch 37/50


100%|████████████████████████████████████████████████████████████████████████████████| 195/195 [02:21<00:00,  1.38it/s]


epoch 37 average loss: 0.6064
Current epoch: 37 current accuracy: 0.5354 
Best accuracy: 0.5758 at epoch 25
----------
Epoch 38/50


100%|████████████████████████████████████████████████████████████████████████████████| 195/195 [02:20<00:00,  1.39it/s]


epoch 38 average loss: 0.6142
Current epoch: 38 current accuracy: 0.5455 
Best accuracy: 0.5758 at epoch 25
----------
Epoch 39/50


100%|████████████████████████████████████████████████████████████████████████████████| 195/195 [02:21<00:00,  1.38it/s]


epoch 39 average loss: 0.6056
Current epoch: 39 current accuracy: 0.5354 
Best accuracy: 0.5758 at epoch 25
----------
Epoch 40/50


100%|████████████████████████████████████████████████████████████████████████████████| 195/195 [02:21<00:00,  1.38it/s]


epoch 40 average loss: 0.6029
Current epoch: 40 current accuracy: 0.5253 
Best accuracy: 0.5758 at epoch 25
----------
Epoch 41/50


100%|████████████████████████████████████████████████████████████████████████████████| 195/195 [02:20<00:00,  1.38it/s]


epoch 41 average loss: 0.5870
Current epoch: 41 current accuracy: 0.5758 
Best accuracy: 0.5758 at epoch 25
----------
Epoch 42/50


100%|████████████████████████████████████████████████████████████████████████████████| 195/195 [02:20<00:00,  1.39it/s]


epoch 42 average loss: 0.5963
Current epoch: 42 current accuracy: 0.5758 
Best accuracy: 0.5758 at epoch 25
----------
Epoch 43/50


100%|████████████████████████████████████████████████████████████████████████████████| 195/195 [02:20<00:00,  1.39it/s]


epoch 43 average loss: 0.5850
Current epoch: 43 current accuracy: 0.5455 
Best accuracy: 0.5758 at epoch 25
----------
Epoch 44/50


100%|████████████████████████████████████████████████████████████████████████████████| 195/195 [02:20<00:00,  1.38it/s]


epoch 44 average loss: 0.5901
Current epoch: 44 current accuracy: 0.5556 
Best accuracy: 0.5758 at epoch 25
----------
Epoch 45/50


100%|████████████████████████████████████████████████████████████████████████████████| 195/195 [02:20<00:00,  1.39it/s]


epoch 45 average loss: 0.5767
Current epoch: 45 current accuracy: 0.5657 
Best accuracy: 0.5758 at epoch 25
----------
Epoch 46/50


100%|████████████████████████████████████████████████████████████████████████████████| 195/195 [02:21<00:00,  1.38it/s]


epoch 46 average loss: 0.5697
Saved new best metric model
Current epoch: 46 current accuracy: 0.5859 
Best accuracy: 0.5859 at epoch 46
----------
Epoch 47/50


100%|████████████████████████████████████████████████████████████████████████████████| 195/195 [02:20<00:00,  1.39it/s]


epoch 47 average loss: 0.5504
Current epoch: 47 current accuracy: 0.5455 
Best accuracy: 0.5859 at epoch 46
----------
Epoch 48/50


100%|████████████████████████████████████████████████████████████████████████████████| 195/195 [02:21<00:00,  1.38it/s]


epoch 48 average loss: 0.5366
Saved new best metric model
Current epoch: 48 current accuracy: 0.5960 
Best accuracy: 0.5960 at epoch 48
----------
Epoch 49/50


100%|████████████████████████████████████████████████████████████████████████████████| 195/195 [02:21<00:00,  1.38it/s]


epoch 49 average loss: 0.5730
Current epoch: 49 current accuracy: 0.5354 
Best accuracy: 0.5960 at epoch 48
----------
Epoch 50/50


100%|████████████████████████████████████████████████████████████████████████████████| 195/195 [02:21<00:00,  1.38it/s]


epoch 50 average loss: 0.5397
Saved new best metric model
Current epoch: 50 current accuracy: 0.6364 
Best accuracy: 0.6364 at epoch 50
Training completed, best_metric: 0.6364 at epoch: 50


# Test the model

## Prepare test data 

In [16]:
test_images_arr, labels_arr = get_images_and_labels(test_data, labels)

test_labels_arr = np.array(labels_arr)
test_labels_arr = torch.nn.functional.one_hot(torch.as_tensor(labels_arr)).float()

len(test_images_arr), len(test_labels_arr), test_labels_arr.shape

(123, 123, torch.Size([123, 2]))

## Create test loader

In [17]:
BATCH_SIZE_TEST = 1
NUM_WORKERS_TEST = 0

test_ds = ImageDataset(
    image_files=test_images_arr, 
    labels=test_labels_arr, 
    transform=val_transforms
)

test_loader = DataLoader(
    val_ds, 
    batch_size=BATCH_SIZE_TEST, 
    num_workers=NUM_WORKERS_TEST, 
    pin_memory=torch.cuda.is_available()
)

In [78]:
from monai.data import CSVSaver, ImageDataset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_test = monai.networks.nets.DenseNet121(spatial_dims=3, in_channels=1, out_channels=2).to(device)

model_test.load_state_dict(torch.load("best_metrics/20220704_185438.pth"))
model_test.eval()

OUTPUT = Path('output')
if not os.path.exists(OUTPUT): os.mkdir(OUTPUT)

with torch.no_grad():
    num_correct = 0.0
    metric_count = 0
    saver = CSVSaver(output_dir="./output")
    for test_data in tqdm(test_loader):
        test_images, test_labels = test_data[0].to(device), test_data[1].to(device)
        test_outputs = model_test(test_images).argmax(dim=1)
        value = torch.eq(test_outputs, test_labels)
        metric_count += len(value)
        num_correct += int(value[0][0].item())
        saver.save_batch(test_outputs)
    metric = num_correct / metric_count
    print("evaluation metric:", metric)

100%|██████████████████████████████████████████████████████████████████████████████████| 99/99 [00:21<00:00,  4.68it/s]

evaluation metric: 0.36363636363636365



