Import necessary libraries

In [1]:
import os
from glob import glob
import shutil
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from monai.transforms import (
    Compose,
    AddChanneld,
    LoadImaged,
    Resized,
    ToTensord,
    Spacingd,
    Orientationd,
    ScaleIntensityRanged,
    CropForegroundd
)
from monai.data import DataLoader, Dataset, CacheDataset
from monai.utils import set_determinism, first
from monai.losses import DiceLoss
import torch

Create a train-test split

In [2]:
def create_train_test_split(data_path, test_size=0.2, random_state=0):
    """Split the image/label files found under ``data_path`` into train and test sets.

    Files are collected from ``<data_path>/images/*`` and ``<data_path>/labels/*``
    and sorted independently, so an image and its label are paired purely by
    their position in the two sorted listings.

    Args:
        data_path: Directory containing the ``images`` and ``labels`` sub-folders.
        test_size: Forwarded to ``train_test_split`` as ``test_size``.
        random_state: Forwarded to ``train_test_split`` as ``random_state``.

    Returns:
        dict with the keys ``train_images``, ``test_images``, ``train_labels``
        and ``test_labels`` (in that order), each holding the matching part
        of the split.
    """
    image_paths = glob(os.path.join(data_path, 'images/*'))
    image_paths.sort()
    label_paths = glob(os.path.join(data_path, 'labels/*'))
    label_paths.sort()

    # train_test_split returns (train_a, test_a, train_b, test_b) for inputs (a, b).
    parts = train_test_split(
        np.array(image_paths),
        np.array(label_paths),
        test_size=test_size,
        random_state=random_state,
    )
    names = ('train_images', 'test_images', 'train_labels', 'test_labels')
    return dict(zip(names, parts))

Creating Input Transforms and Data Loaders

In [3]:
def preprocess(files, pixdim=(1.5, 1.5, 1.0), a_min=-200, a_max=200, spatial_size=[128, 128, 64], cache=False):
    """Build the train and test data loaders for paired image/label volumes.

    Both splits go through the same transform chain: load, add a channel
    axis, resample to ``pixdim``, reorient to RAS, clip/rescale the image
    intensities from ``[a_min, a_max]`` to ``[0, 1]``, crop to the image
    foreground, resize to ``spatial_size`` and convert to tensors.

    Args:
        files: Mapping with the keys ``train_images``, ``train_labels``,
            ``test_images`` and ``test_labels``; images and labels are paired
            by position.
        pixdim: Target voxel spacing passed to ``Spacingd``.
        a_min: Lower intensity bound passed to ``ScaleIntensityRanged``.
        a_max: Upper intensity bound passed to ``ScaleIntensityRanged``.
        spatial_size: Output size passed to ``Resized``. The default list is
            only read, never mutated.
        cache: When true, use ``CacheDataset`` with ``cache_rate=1.0`` instead
            of a plain ``Dataset``.

    Returns:
        ``(train_loader, test_loader)``, both with ``batch_size=1``.
    """
    set_determinism(seed = 0)

    keys = ['images', 'labels']

    def _as_records(image_paths, label_paths):
        # One dict per sample, in the shape the dictionary transforms expect.
        return [
            {'images': image_name, 'labels': label_name}
            for image_name, label_name in zip(image_paths, label_paths)
        ]

    def _build_transforms():
        # A fresh Compose per split so the two pipelines share no state.
        return Compose(
            [
                LoadImaged(keys=keys),
                AddChanneld(keys=keys),
                Spacingd(keys=keys, pixdim=pixdim, mode=('bilinear', 'nearest')),
                Orientationd(keys=keys, axcodes='RAS'),
                ScaleIntensityRanged(keys=['images'], a_min=a_min, a_max=a_max, b_min=0, b_max=1, clip=True),
                CropForegroundd(keys=keys, source_key='images'),
                # Labels must be resized with nearest-neighbour interpolation:
                # the default 'area' mode averages neighbouring voxels and
                # produces fractional label values, which a later `!= 0`
                # binarisation turns into spurious foreground.
                Resized(keys=keys, spatial_size=spatial_size, mode=('area', 'nearest')),
                ToTensord(keys=keys),
            ]
        )

    train_files = _as_records(files['train_images'], files['train_labels'])
    test_files = _as_records(files['test_images'], files['test_labels'])

    train_transforms = _build_transforms()
    test_transforms = _build_transforms()

    if cache:
        train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=1.0)
        test_ds = CacheDataset(data=test_files, transform=test_transforms, cache_rate=1.0)
    else:
        train_ds = Dataset(data=train_files, transform=train_transforms)
        test_ds = Dataset(data=test_files, transform=test_transforms)

    train_loader = DataLoader(train_ds, batch_size=1)
    test_loader = DataLoader(test_ds, batch_size=1)

    return train_loader, test_loader
    

Display a slice of the first patient data sample

In [4]:
def show_patient(data, SLICE_NUMBER=1, train=True, test=False):
    """Plot one slice of the first batch's image and label, side by side.

    Args:
        data: ``(train_loader, test_loader)`` pair. Each batch is indexed as
            ``batch[key][0, 0, :, :, SLICE_NUMBER]`` for the keys ``images``
            and ``labels``.
        SLICE_NUMBER: Index along the last axis of the slice to display.
        train: Show the first batch of the train loader.
        test: Show the first batch of the test loader.
    """
    train_loader, test_loader = data

    panels = (
        (train, train_loader, "Train Visualisation"),
        (test, test_loader, "Test Visualisation"),
    )

    for enabled, loader, figure_name in panels:
        if not enabled:
            continue

        # Fetch the batch only for the split that is actually displayed;
        # pulling a batch runs the loader's whole transform pipeline.
        patient = first(loader)

        plt.figure(figure_name, (12, 6))

        plt.subplot(1, 2, 1)
        plt.title(f'image {SLICE_NUMBER}')
        plt.imshow(patient['images'][0, 0, :, :, SLICE_NUMBER], cmap='gray')

        plt.subplot(1, 2, 2)
        plt.title(f'label {SLICE_NUMBER}')
        plt.imshow(patient['labels'][0, 0, :, :, SLICE_NUMBER])
        plt.show()

Calculate the number of background and foreground pixels

In [5]:
def calculate_pixels(data):
    """Count background and foreground label elements over every batch in ``data``.

    An element is foreground when its label value is non-zero.

    Args:
        data: Iterable of batches; each batch is indexed with ``"labels"``.

    Returns:
        ``numpy.ndarray`` of shape ``(1, 2)`` holding
        ``[[background, foreground]]`` divided by ``1e7``. The scaled totals
        are also printed.
    """
    val = np.zeros((1, 2))

    for batch in tqdm(data):
        foreground_mask = np.asarray(batch["labels"] != 0)

        # Count each class explicitly. Relying on the order of np.unique's
        # output misfiles an all-foreground batch as background.
        foreground = np.count_nonzero(foreground_mask)
        background = foreground_mask.size - foreground
        val += np.array([background, foreground])

    # True division keeps the class ratio intact; floor division would
    # truncate any total below 1e7 to zero and break weight computation.
    val = val / 1e7
    print(val)

    return val

Calculate weights for Cross Entropy Loss using counts of foreground and background

In [6]:
def calculate_weights(count):
    """Turn per-class counts into normalised inverse-frequency weights.

    Each class weight is proportional to ``1 / (count / count.sum())`` and the
    weights are rescaled to sum to 1, so rarer classes get larger weights.

    Args:
        count: Array-like of per-class counts; every entry must be finite and
            strictly positive.

    Returns:
        ``torch.float32`` tensor with the same shape as ``count``.

    Raises:
        ValueError: If ``count`` is empty or contains a zero, negative or
            non-finite entry. Such input would otherwise produce inf/NaN
            weights without any error.
    """
    count = np.asarray(count, dtype=np.float64)

    if count.size == 0 or not np.all(np.isfinite(count)) or np.any(count <= 0):
        raise ValueError(
            f"class counts must be finite and strictly positive, got {count.tolist()}"
        )

    frequencies = count / count.sum()
    inverse = 1 / frequencies
    weights = inverse / inverse.sum()

    return torch.tensor(weights, dtype=torch.float32)

Calculate the Dice Metric to evaluate the model

In [7]:
def dice_metric(predicted, target):
    """Return ``1 - DiceLoss`` for ``predicted`` against ``target`` as a Python float.

    The loss is built with ``to_onehot_y=True``, ``sigmoid=True`` and
    ``squared_pred=True`` on every call.

    Args:
        predicted: Model output passed to the loss as its first argument.
        target: Reference labels passed to the loss as its second argument.
    """
    criterion = DiceLoss(to_onehot_y=True, sigmoid=True, squared_pred=True)
    dice_loss = criterion(predicted, target)
    return 1 - dice_loss.item()

Main training function

In [9]:
def train(model, data_in, loss, optim, max_epochs, model_dir, test_interval=1, load=False, path='', device=torch.device("cuda:0")):
    """Train ``model``, evaluate it periodically and checkpoint the best epoch.

    Every epoch, the mean training loss and Dice are appended to
    ``loss_train.npy`` / ``metric_train.npy`` in ``model_dir``. Every
    ``test_interval`` epochs the model is evaluated on the test loader, the
    means are written to ``loss_test.npy`` / ``metric_test.npy``, and the
    model and optimiser state are saved to ``best_metric_model.pth`` whenever
    the mean test Dice improves.

    Args:
        model: Network to train. It is not moved to ``device`` here.
        data_in: ``(train_loader, test_loader)`` pair whose batches are
            indexed with ``"images"`` and ``"labels"``.
        loss: Callable invoked as ``loss(outputs, label)``.
        optim: Optimiser stepped once per training batch.
        max_epochs: Total number of epochs, including already completed ones
            when resuming.
        model_dir: Directory for the checkpoint and the ``.npy`` histories;
            created if missing.
        test_interval: Evaluate every this many epochs.
        load: Resume from ``<model_dir>/best_metric_model.pth``.
        path: Currently unused; kept so existing callers keep working.
        device: Device the batches (and a loaded checkpoint) are placed on.
    """
    best_metric = -1
    best_metric_epoch = -1

    save_loss_train = []
    save_loss_test = []
    save_metric_train = []
    save_metric_test = []

    train_loader, test_loader = data_in

    checkpoint_path = os.path.join(model_dir, "best_metric_model.pth")

    # np.save / torch.save do not create directories; without this a missing
    # model_dir only fails after the first full epoch has been trained.
    if model_dir:
        os.makedirs(model_dir, exist_ok=True)

    cur_epoch = 0

    if load:
        # map_location lets a checkpoint saved on another device be resumed.
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optim.load_state_dict(checkpoint['optim_state_dict'])
        cur_epoch = checkpoint['epoch']

        # Restore the best score so that a worse epoch after resuming cannot
        # overwrite the checkpoint. Older checkpoints lack this key and keep
        # the previous behaviour.
        if 'best_metric' in checkpoint:
            best_metric = checkpoint['best_metric']
            best_metric_epoch = cur_epoch

    for epoch in range(cur_epoch, max_epochs):
        print("-" * 10)
        print(f"epoch {epoch+1}/{max_epochs}")

        model.train()

        train_step = 0
        train_epoch_loss = 0
        train_epoch_metric = 0

        for batch_data in train_loader:
            train_step += 1

            image = batch_data["images"]
            label = batch_data["labels"]
            # Collapse every non-zero label value into a single foreground class.
            label = label != 0
            image, label = (image.to(device), label.to(device))

            optim.zero_grad()
            outputs = model(image)

            train_loss = loss(outputs, label)

            train_loss.backward()
            optim.step()

            train_epoch_loss += train_loss.item()

            # The metric is only reported, so no autograd graph is needed.
            with torch.no_grad():
                train_metric = dice_metric(outputs, label)
            train_epoch_metric += train_metric

            # len(train_loader) is already the number of batches.
            print(f'{train_step}/{len(train_loader)}  '
                f'train_loss: {train_loss.item():.4f}, train_dice: {train_metric:.4f}')

        print('-'*20)

        train_epoch_loss /= train_step
        print(f'epoch_loss: {train_epoch_loss:.4f}')
        save_loss_train.append(train_epoch_loss)
        np.save(os.path.join(model_dir, 'loss_train.npy'), save_loss_train)

        train_epoch_metric /= train_step
        print(f'epoch_metric: {train_epoch_metric:.4f}')
        save_metric_train.append(train_epoch_metric)
        np.save(os.path.join(model_dir, 'metric_train.npy'), save_metric_train)

        if (epoch + 1) % test_interval == 0:
            model.eval()
            with torch.no_grad():
                test_step = 0
                test_epoch_loss = 0
                test_epoch_metric = 0

                for test_data in test_loader:

                    test_step += 1

                    test_image = test_data["images"]
                    test_label = test_data["labels"]
                    test_label = test_label != 0
                    test_image, test_label = (test_image.to(device), test_label.to(device))

                    test_outputs = model(test_image)

                    test_loss = loss(test_outputs, test_label)
                    test_epoch_loss += test_loss.item()
                    test_metric = dice_metric(test_outputs, test_label)
                    test_epoch_metric += test_metric

                test_epoch_loss /= test_step
                print(f'test_loss_epoch: {test_epoch_loss:.4f}')
                save_loss_test.append(test_epoch_loss)
                np.save(os.path.join(model_dir, 'loss_test.npy'), save_loss_test)

                test_epoch_metric /= test_step
                print(f'test_dice_epoch: {test_epoch_metric:.4f}')
                save_metric_test.append(test_epoch_metric)
                np.save(os.path.join(model_dir, 'metric_test.npy'), save_metric_test)

                if test_epoch_metric > best_metric:
                    best_metric = test_epoch_metric
                    best_metric_epoch = epoch + 1

                    torch.save({
                        'epoch': best_metric_epoch,
                        'model_state_dict': model.state_dict(),
                        'optim_state_dict': optim.state_dict(),
                        'best_metric': best_metric
                    }, checkpoint_path)

                # Report the epoch mean, not the Dice of the last test batch.
                print(
                    f"current epoch: {epoch + 1} current mean dice: {test_epoch_metric:.4f}\n"
                    f"best mean dice: {best_metric:.4f} "
                    f"at epoch: {best_metric_epoch}"
                )