# Setup imports

In [1]:
import logging
import ntpath
import os
import random
import sys
import shutil
import tempfile

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
torch.cuda.empty_cache()

from torch.utils.tensorboard import SummaryWriter

import monai
from monai.apps import download_and_extract
from monai.config import print_config
from monai.data import DataLoader, ImageDataset
from monai.transforms import AddChannel, Compose, RandRotate90
from monai.transforms import Resize, ScaleIntensity, EnsureType

from datetime import datetime
from pathlib import Path
from tqdm import tqdm

from constants import Constants

# Use CUDA when torch reports it as available; otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The DataLoaders below pin memory exactly when the CUDA device was selected.
pin_memory = device.type == "cuda"

# Send INFO-level (and above) log records to stdout so they show up in cell output.
logging.basicConfig(level=logging.INFO, stream=sys.stdout)

# Project configuration object; later cells read paths and hyper-parameters from it.
constants = Constants()

In [2]:
# Location of the pickled labels table that is loaded in the next section.
LABELS = constants.labels_pkl

# Make sure the output directories exist. makedirs(..., exist_ok=True) also
# creates missing parent directories and is safe to re-run, unlike a bare
# os.mkdir guarded by an exists() check.
LOGS = constants.logs
os.makedirs(LOGS, exist_ok=True)

# Directory the training loop saves the best model checkpoint into.
BEST_METRICS = constants.best_metrics
os.makedirs(BEST_METRICS, exist_ok=True)

# Get existing data

In [3]:
def get_patient_data(patient_id, labels):
    """Gather one patient's entries from the labels table into a nested dict.

    The row is picked positionally: ``get_patient_idx(patient_id)`` is used
    as an ``.iloc`` index into each column of ``labels``.
    NOTE(review): this presumes the row position equals the numeric part of
    the patient ID -- confirm against how the pickle was built.

    Args:
        patient_id: Patient identifier string understood by ``get_patient_idx``.
        labels: Table exposing the columns listed below as attributes.

    Returns:
        dict with two sub-dicts: ``'data'`` (ProxID, ClinSig, fid, pos, zone,
        spacing, slices) and ``'images'`` (T2, ADC, DWI, KTrans).
    """
    row = get_patient_idx(patient_id)

    def cell(column):
        # Value of `column` at the patient's row position.
        return getattr(labels, column).iloc[row]

    data_columns = ('ProxID', 'ClinSig', 'fid', 'pos', 'zone', 'spacing', 'slices')
    image_columns = ('T2', 'ADC', 'DWI', 'KTrans')

    return {
        'data': {column: cell(column) for column in data_columns},
        'images': {column: cell(column) for column in image_columns},
    }


def get_patientID(filename):
    """Return the leading 14 characters of ``filename`` (fewer if it is shorter).

    NOTE(review): presumably file names start with a 14-character patient
    ID -- confirm against the dataset's naming scheme.
    """
    id_length = 14
    return filename[0:id_length]


def get_patient_idx(patient_id):
    """Parse everything after the first 10 characters of ``patient_id`` as an int.

    Raises:
        ValueError: if that suffix is not a valid integer literal.
    """
    numeric_suffix = patient_id[10:]
    return int(numeric_suffix)

In [4]:
# Load the labels table from the pickle at LABELS. Later cells read its
# ProxID column and pass the table to get_images_and_labels.
# NOTE(review): unpickling can execute arbitrary code -- only load a file
# that comes from a trusted source.
labels_df = pd.read_pickle(LABELS)

# Split the dataset randomly

In [5]:
# Patient IDs taken from the labels table; this list is shuffled below.
patients = list(labels_df.ProxID)

# Fraction of the patients that goes to the training side of the split.
TRAIN_TEST_RATIO = constants.split_ratio
train_num = int(len(patients) * TRAIN_TEST_RATIO)

# Shuffle with a dedicated, seeded generator so the train/test split is the
# same on every run. With an unseeded shuffle, a checkpoint saved in one
# session could be tested later on patients it was trained on.
# Falls back to 42 when the Constants object defines no `seed` attribute.
SPLIT_SEED = getattr(constants, "seed", 42)
random.Random(SPLIT_SEED).shuffle(patients)

train_data, test_data = patients[:train_num], patients[train_num:]

len(train_data), len(test_data)

(163, 41)

# Prepare training data 

In [6]:
def get_lesion_summary(ClinSig):
    """Fold an iterable of per-lesion flags with ``or``.

    Returns the first truthy element if there is one; otherwise the last
    element, or ``False`` when ``ClinSig`` is empty.
    """
    last_seen = False
    for flag in ClinSig:
        if flag:
            # Once a truthy flag is found, `or` keeps returning it.
            return flag
        last_seen = flag
    return last_seen


def get_images_and_labels(data, labels):
    """Flatten the patients' image entries into parallel image/label lists.

    For each patient ID in ``data`` the entries are appended in the order
    T2, ADC, DWI (each iterated element by element) and then the single
    KTrans entry. Every entry of a patient receives the same label: 1 when
    ``get_lesion_summary`` of the patient's ClinSig is truthy, else 0.

    Args:
        data: Iterable of patient IDs.
        labels: Labels table forwarded to ``get_patient_data``.

    Returns:
        Tuple ``(images_arr, labels_arr, num_files_idx)`` where
        ``num_files_idx[i]`` is ``len(images_arr)`` after patient ``i`` was
        added, i.e. a running total of entries.
    """
    images_arr, labels_arr, num_files_idx = [], [], []

    for patient_id in data:
        record = get_patient_data(patient_id, labels)
        label = 1 if get_lesion_summary(record['data']['ClinSig']) else 0

        scans = record['images']
        patient_files = [*scans['T2'], *scans['ADC'], *scans['DWI'], scans['KTrans']]

        images_arr.extend(patient_files)
        labels_arr.extend([label] * len(patient_files))
        # Running total: marks where this patient's entries end.
        num_files_idx.append(len(images_arr))

    return images_arr, labels_arr, num_files_idx

In [7]:
# Flatten the training patients into parallel lists of image entries and 0/1
# labels; num_files_idx holds the running per-patient entry totals.
train_images_arr, labels_arr, num_files_idx = get_images_and_labels(train_data, labels_df)

# One-hot encode the 0/1 labels as floats. num_classes=2 keeps the result at
# two columns even if only class 0 happens to be present; without it one_hot
# infers the width from the largest label it sees.
train_labels_arr = torch.nn.functional.one_hot(torch.as_tensor(labels_arr), num_classes=2).float()

len(train_images_arr), len(train_labels_arr), train_labels_arr.shape

(1029, 1029, torch.Size([1029, 2]))

## Define transforms

In [8]:
# Edge length of the cubic volume every image is resized to.
RESIZE_SIZE = constants.image_resize

# Training pipeline; RandRotate90 is the only step the validation pipeline
# does not share.
train_transforms = Compose(
    [ScaleIntensity(), AddChannel(), Resize((RESIZE_SIZE,) * 3), RandRotate90(), EnsureType()]
)

# Validation pipeline (also reused for the test set): same steps, no random rotation.
val_transforms = Compose(
    [ScaleIntensity(), AddChannel(), Resize((RESIZE_SIZE,) * 3), EnsureType()]
)

## Check loaders

In [9]:
BATCH_SIZE = constants.batch_size
NUM_WORKERS = constants.num_workers

# Sanity check: build a dataset and loader over ALL training entries and
# pull a single batch to inspect the tensor types and shapes.
check_ds = ImageDataset(image_files=train_images_arr, labels=train_labels_arr, transform=train_transforms)
check_loader = DataLoader(check_ds, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=pin_memory)

im, label = monai.utils.misc.first(check_loader)
print(type(im), im.shape, label, label.shape)

<class 'torch.Tensor'> torch.Size([2, 1, 168, 168, 168]) tensor([[1., 0.],
        [1., 0.]]) torch.Size([2, 2])


## Create data loaders 

In [10]:
# Fraction of the training patients kept for training; the rest is validation.
TRAIN_VAL_RATIO = constants.split_ratio

# num_files_idx[i] is the number of image entries contributed by patients
# 0..i, so slicing at num_files_idx[k - 1] keeps the first k patients' entries
# together. Cutting on a patient boundary keeps every image of a patient on
# one side of the split. The earlier `num_files_idx[k] - 1` cut inside
# patient k and leaked that patient into both sets.
num_train_patients = int(len(train_data) * TRAIN_VAL_RATIO)
train_val_num = num_files_idx[num_train_patients - 1] if num_train_patients > 0 else 0

train_images, val_images = train_images_arr[:train_val_num], train_images_arr[train_val_num:]
train_labels, val_labels = train_labels_arr[:train_val_num], train_labels_arr[train_val_num:]

# create a training data loader (shuffled each epoch)
train_ds = ImageDataset(
    image_files=train_images,
    labels=train_labels,
    transform=train_transforms
)
train_loader = DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=pin_memory
)

# create a validation data loader (fixed order, no augmentation)
val_ds = ImageDataset(
    image_files=val_images,
    labels=val_labels,
    transform=val_transforms
)
val_loader = DataLoader(
    val_ds,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=pin_memory
)

# Train the model

## Define network, loss function, and optimizer

In [11]:
# Timestamp of this run; the training and test cells use it as the
# checkpoint file name.
today = datetime.now()
date_format_metric = f"{today:%Y%m%d_%H%M%S}"
print(date_format_metric)

# Model (moved to the selected device), loss and optimizer all come from the
# project configuration object.
model = constants.model.to(device)
loss_function = constants.loss_function
optimizer = constants.optimizer

20220713_150900


## Training parameters

In [12]:
EPOCHS = constants.epochs

# Run validation after every `val_interval` epochs.
val_interval = 1
# -1 marks "no validation result yet"; any real accuracy is higher.
best_metric, best_metric_epoch = -1, -1

# Per-epoch history: mean training loss and validation accuracy.
epoch_loss_values, metric_values = [], []
# TensorBoard writer with its default log location.
writer = SummaryWriter()

## Start training 

In [13]:
# The checkpoint path and the number of batches per epoch do not change
# during training, so compute them once instead of on every iteration.
MODEL_STATE_DICT = BEST_METRICS / f'{date_format_metric}.pth'
# len(train_loader) counts the final partial batch too, so the global step
# logged to TensorBoard never overlaps between consecutive epochs (the earlier
# `len(train_ds) // batch_size` dropped that batch).
epoch_len = len(train_loader)

for epoch in range(EPOCHS):
    print("-" * 10)
    print(f"Epoch {epoch + 1}/{EPOCHS}")
    model.train()
    epoch_loss = 0
    step = 0

    for batch_data in tqdm(train_loader):
        step += 1
        inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        # Read the scalar loss once and reuse it for the running sum and the log.
        batch_loss = loss.item()
        epoch_loss += batch_loss
        writer.add_scalar("train_loss", batch_loss, epoch_len * epoch + step)

    # Mean loss over the batches seen in this epoch.
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"Epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    if (epoch + 1) % val_interval == 0:
        model.eval()

        num_correct = 0.0
        metric_count = 0
        # Gradients are not needed for validation.
        with torch.no_grad():
            # Distinct names keep the notebook-level `val_images` / `val_labels`
            # lists from being overwritten with batch tensors.
            for val_data in val_loader:
                val_inputs, val_targets = val_data[0].to(device), val_data[1].to(device)
                val_outputs = model(val_inputs)
                # Labels are one-hot, so argmax recovers the class index on both sides.
                value = torch.eq(val_outputs.argmax(dim=1), val_targets.argmax(dim=1))
                metric_count += len(value)
                num_correct += value.sum().item()

        # Validation accuracy: fraction of correctly classified samples.
        metric = num_correct / metric_count
        metric_values.append(metric)

        # Keep only the weights of the best-scoring epoch on disk.
        if metric > best_metric:
            best_metric = metric
            best_metric_epoch = epoch + 1
            torch.save(model.state_dict(), MODEL_STATE_DICT)
            print("Saved new best metric model")

        print(f"Current epoch: {epoch+1} current accuracy: {metric:.4f} ")
        print(f"Best accuracy: {best_metric:.4f} at epoch {best_metric_epoch}")
        writer.add_scalar("val_accuracy", metric, epoch + 1)

print(f"Training completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
writer.close()

----------
Epoch 1/100


 37%|█████████████████████████████▉                                                  | 155/414 [01:56<03:13,  1.34it/s]

KeyboardInterrupt



# Test the model

## Prepare test data 

In [None]:
# get_images_and_labels returns three values (image entries, labels, running
# per-patient totals); the totals are not needed for the test set, so the
# third value is discarded. Unpacking into two names raised ValueError.
test_images_arr, test_labels_list, _ = get_images_and_labels(test_data, labels_df)

# One-hot encode the 0/1 labels as floats; num_classes=2 keeps two columns
# even if the test set happens to contain a single class.
test_labels_arr = torch.nn.functional.one_hot(torch.as_tensor(test_labels_list), num_classes=2).float()

len(test_images_arr), len(test_labels_arr), test_labels_arr.shape

## Create test loader

In [None]:
BATCH_SIZE_TEST = constants.batch_size_test
NUM_WORKERS_TEST = constants.num_workers

# Test dataset: held-out patients, with the augmentation-free validation transforms.
test_ds = ImageDataset(
    image_files=test_images_arr,
    labels=test_labels_arr,
    transform=val_transforms
)

# The loader must wrap test_ds. It previously wrapped val_ds, so the
# "test" metrics were computed on the validation split.
test_loader = DataLoader(
    test_ds,
    batch_size=BATCH_SIZE_TEST,
    num_workers=NUM_WORKERS_TEST,
    pin_memory=torch.cuda.is_available()
)

In [None]:
def calculate_KPI(kpi, output, label):
    """Add a batch of predictions to running confusion-matrix counts.

    Every (prediction, ground-truth) pair in the batch is counted. Only the
    first element used to be inspected, which is correct only for batches of
    size one.

    Args:
        kpi: Tuple ``(tp, tn, fp, fn)`` with the counts so far.
        output: Sequence of predicted classes; a truthy value means positive.
        label: Sequence of ground-truth classes, aligned with ``output``.

    Returns:
        Updated ``(tp, tn, fp, fn)`` tuple. Empty inputs return ``kpi``
        unchanged.
    """
    tp, tn, fp, fn = kpi
    for predicted, actual in zip(output, label):
        if actual:
            if predicted:
                tp += 1
            else:
                fn += 1
        else:
            if predicted:
                fp += 1
            else:
                tn += 1
    return tp, tn, fp, fn


def get_metrics(kpi):
    """Derive accuracy, recall, precision and F1 from confusion-matrix counts.

    A metric whose denominator is zero is reported as ``-1`` ("undefined").
    This includes accuracy for all-zero counts, which used to raise
    ``ZeroDivisionError``. When precision and recall are both defined but
    zero (no true positives), F1 is ``0.0`` rather than ``-1``.

    Args:
        kpi: Tuple ``(tp, tn, fp, fn)``.

    Returns:
        Tuple ``(accuracy, recall, precision, f1)``.
    """
    tp, tn, fp, fn = kpi
    total = tp + tn + fp + fn

    accuracy = (tp + tn) / total if total != 0 else -1
    recall = tp / (tp + fn) if tp + fn != 0 else -1
    precision = tp / (tp + fp) if tp + fp != 0 else -1

    if precision > 0 and recall > 0:
        f1 = (2 * precision * recall) / (precision + recall)
    elif precision == 0 and recall == 0:
        # Both metrics are defined and zero: F1 is a genuine 0, not "undefined".
        f1 = 0.0
    else:
        # At least one of precision/recall is undefined, so F1 is too.
        f1 = -1

    return accuracy, recall, precision, f1

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): this reads the same `constants.model` attribute the training
# cell used, so it is presumably the same model object -- confirm.
model_test = constants.model.to(device)

# Checkpoint written by the training loop for this run's timestamp.
MODEL = BEST_METRICS / f'{date_format_metric}.pth'

# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
model_test.load_state_dict(torch.load(MODEL, map_location=device))
model_test.eval()

with torch.no_grad():
    # Running confusion-matrix counts: (TP, TN, FP, FN).
    KPI = 0, 0, 0, 0

    for data in tqdm(test_loader):
        test_images, test_labels = data[0].to(device), data[1].to(device)

        # Reduce both the model output and the one-hot labels to class indices.
        test_outputs = model_test(test_images).argmax(dim=1)
        test_labels = test_labels.argmax(dim=1)

        KPI = calculate_KPI(KPI, test_outputs, test_labels)

    print(f'All test data :: (TP, TN, FP, FN) = {KPI}')

    metrics = get_metrics(KPI)
    print(
        f'Evaluation metrics:\n'
        f'Accuracy  : {metrics[0]}\n'
        f'Recall    : {metrics[1]}\n'
        f'Precision : {metrics[2]}\n'
        f'F1-score  : {metrics[3]}\n'
    )