In [None]:
''' 
This notebook was executed with the following package versions:

MONAI version: 0.2.0+166.g12b3fbf
Python version: 3.6.10 |Anaconda, Inc.| (default, May  8 2020, 02:54:21)  [GCC 7.3.0]
Numpy version: 1.19.5
Pytorch version: 1.7.1

Optional dependencies:
Pytorch Ignite version: 0.3.0
Nibabel version: 3.1.1
scikit-image version: 0.15.0
Pillow version: 7.2.0
Tensorboard version: 1.15.0+nv
gdown version: 3.12.2
TorchVision version: 0.8.0a0
ITK version: 5.1.1

Later MONAI/PyTorch versions are likely to have slight changes in syntax.
'''

In [None]:
import logging
import os
import shutil
import sys
import time
import tempfile
from glob import glob

import nibabel as nib
import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from matplotlib import pylab as plt

import monai
from monai.networks.layers import Norm
from monai.data import create_test_image_3d, list_data_collate, ITKReader
from monai.inferers import sliding_window_inference
from monai.inferers import SimpleInferer
from monai.metrics import DiceMetric
from monai.transforms import (
    AsChannelFirstd,
    AsChannelLastd,
    AddChanneld,
    RandAdjustContrastd,
    Compose,
    DivisiblePadd,
    LoadNiftid,
    LoadImaged,
    RandCropByPosNegLabeld,
    RandRotated,
    RandZoomd,
    RandFlipd,
    RandShiftIntensityd,
    RandScaleIntensityd,
    RandAffined,
    Rand3DElasticd,
    RandGaussianNoised,
    ScaleIntensityd,
    SpatialPadd,
    ToTensord,
    DataStats,
)
from monai.utils import first, set_determinism
from monai.visualize import plot_2d_or_3d_image
import itk

In [None]:
# collect files for train/valid/test sets into lists
monai.config.print_config()
logging.basicConfig(stream=sys.stdout, level=logging.INFO)

# load dataset
dpath = '/data/IEVnet/train'
data = []

# create a list of volume/segmentation pairs in training data:
# every "subject_*" directory contributes two entries, left (L) first, then right (R)
img_path_pattern = os.path.join(dpath, "subject_*")
for subject_base in sorted(glob(img_path_pattern)):
    for side in ['L', 'R']:
        data.append({"img": os.path.join(subject_base, 'vol_%s_02mm.nii.gz'%side),
                     "seg": os.path.join(subject_base, 'seg_%s_02mm.nii.gz'%side)})

# create another dataset for the test data:
# subject ids 'D2_01' ... 'D5_10' (4 groups x 10 subjects)
test_subject_ids = sorted('D%d_%02d'%(group, number)
                          for group in range(2, 6)
                          for number in range(1, 11))

pn_test = '/data/IEVnet/test'
data_test = []
for sid in test_subject_ids:
    for side in ['L','R']:
        # volume
        ff_img = os.path.join(pn_test, sid, 'vol_%s_02mm.nii.gz'%side)
        # manual groundtruth segmentation (used for computation of test metrics, Dice etc.)
        ff_seg = os.path.join(pn_test, sid, 'seg_%s_02mm.nii.gz'%side)
        # only keep sides whose volume file is actually present on disk
        if os.path.exists(ff_img):
            data_test.append({"img": ff_img,
                              "seg": ff_seg})
print('Number of vols in test-set: %d. Expected: 80 (40 subjects, L/R IEs).'%(len(data_test)))

# The last `val_fraction` of the (sorted) training list is held out for validation.
# NOTE(review): the split index counts volumes, not subjects, so an odd index puts
# the L and R volume of one subject on different sides of the split - confirm this
# is acceptable.
val_fraction = 0.1
idx_val_split = int(len(data)*(1-val_fraction))
# each list entry is one volume (two per subject), so report volumes
print('Number of volumes in train set: %d'%(len(data[:idx_val_split])))
print('Number of volumes in valid set: %d'%(len(data[idx_val_split:])))


In [None]:
# Partition the collected pairs: everything before the split index is used for
# training, the remainder for validation.
data_train, data_val = data[:idx_val_split], data[idx_val_split:]

# Limits for the random rotation / translation of the elastic augmentation.
vol_size = np.array([200, 150, 100])
rot_max = 20*np.pi/180.0
trans_max = tuple((vol_size*0.15).astype(int))


def _base_transforms():
    """Return fresh instances of the preprocessing steps shared by all pipelines.

    Loads image and segmentation, adds a channel axis, pads both to
    208 x 160 x 112 and rescales the image intensities.
    """
    return [
        LoadNiftid(keys=["img", "seg"]),
        AddChanneld(keys=["img", "seg"]),
        SpatialPadd(keys=["img", "seg"],spatial_size=[208, 160, 112]),
        ScaleIntensityd(keys="img"),
    ]


# Training pipeline: shared preprocessing followed by random augmentation.
_train_augmentation = [
    # intensity augmentations act on the image only
    RandAdjustContrastd(keys="img", prob=0.9, gamma=(0.3, 1.5)),
    RandGaussianNoised(keys="img", prob=0.5),
    # spatial augmentations act on image and segmentation alike
    RandFlipd(keys=["img","seg"],prob=0.5, spatial_axis=[0]),
    ToTensord(keys=["img", "seg"]),
    Rand3DElasticd(
        keys=["img", "seg"],
        # bilinear for the image, nearest for the segmentation
        mode=("bilinear", "nearest"),
        prob=0.75,
        sigma_range=(5, 8),
        magnitude_range=(5, 100),
        translate_range=trans_max,
        rotate_range=(rot_max,rot_max,rot_max),
        scale_range=(0.15, 0.15, 0.15),
        padding_mode="border",
    ),
]
train_transforms = Compose(_base_transforms() + _train_augmentation)

# Validation/testing pipeline: shared preprocessing only, no augmentation.
val_transforms = Compose(_base_transforms() + [ToTensord(keys=["img", "seg"])])

In [None]:
# Check a few example augmentations/transforms with a temporary DataLoader.
# The loader is not shuffled, so every pass shows the first training sample
# under a newly drawn random augmentation.
check_ds = monai.data.Dataset(data=data_train, transform=train_transforms)
check_loader = DataLoader(check_ds, batch_size=1)
# centre index along each of the three padded spatial axes
slice_idxs = (np.array([208, 160, 112])/2).astype(int)
for i in range(10):
    check_data = monai.utils.first(check_loader)
    # drop the batch and channel axes
    image, label = (check_data["img"][0][0], check_data["seg"][0][0])
    print(f"image shape: {image.shape}, label shape: {label.shape}")
    plt.figure("check", (12, 6))
    # one row per axis: image on the left, segmentation on the right
    for axis in range(3):
        # use the centre of *this* axis (the 3rd axis previously reused the
        # centre index of the 2nd axis)
        slice_idx = slice_idxs[axis]
        plane = [slice(None)] * 3
        plane[axis] = slice_idx
        plane = tuple(plane)
        plt.subplot(3, 2, 2*axis + 1)
        plt.title("img")
        plt.imshow(np.squeeze(image[plane]), cmap="gray", vmin=0.0, vmax=1.0)
        plt.subplot(3, 2, 2*axis + 2)
        plt.title("seg")
        plt.imshow(np.squeeze(label[plane]))
    # show
    plt.show()

In [None]:
# Datasets and data loaders for the train / validation / test lists.
batch_size_train = 4


def _eval_loader(dataset):
    """Build an unshuffled loader that yields one volume per batch."""
    return DataLoader(dataset, batch_size=1, num_workers=4, collate_fn=list_data_collate)


# training: cached dataset, shuffled batches of `batch_size_train` volumes
train_ds = monai.data.CacheDataset(data=data_train, transform=train_transforms, num_workers=4)
_train_loader_options = dict(
    batch_size=batch_size_train,
    shuffle=True,
    num_workers=4,
    collate_fn=list_data_collate,
    pin_memory=torch.cuda.is_available(),
)
train_loader = DataLoader(train_ds, **_train_loader_options)

# validation: cached dataset, one volume at a time
val_ds = monai.data.CacheDataset(data=data_val, transform=val_transforms, num_workers=4)
val_loader = _eval_loader(val_ds)

# test: same transforms and loader settings as validation
test_ds = monai.data.CacheDataset(data=data_test, transform=val_transforms, num_workers=4)
test_loader = _eval_loader(test_ds)

In [None]:
# create VNet, DiceLoss, DiceMetric and Adam optimizer
dice_metric = DiceMetric(include_background=True, to_onehot_y=False, sigmoid=True, reduction="mean")
# prefer the second GPU as before; fall back to the first GPU or the CPU when
# that device does not exist on this machine
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
# using MONAI Unet implementation for cuda optimizations
# parametrize UNet like VNet (see: https://arxiv.org/abs/1606.04797):
#   - 4x downsampling
#   - 16/32/64/128/256 filter channels,
#   - down/up-convolutions with stride 2 instead of max-pooling/up-pooling
#   - number of residual units in each layer: 2
#   - Dice loss
model = monai.networks.nets.UNet(
    dimensions=3,
    in_channels=1,
    out_channels=1,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    dropout=0.5,
    num_res_units=2).to(device)
loss_function = monai.losses.DiceLoss(include_background=True, sigmoid=True, squared_pred=True)
optimizer = torch.optim.Adam(model.parameters(), 3e-4)
# The model starts from its default initialisation. The former
# `model.load_weights()` call was removed: torch modules have no such method.
# Stored weights are restored with `model.load_state_dict(torch.load(...))`,
# as done in the inference cell.

In [None]:
# start a typical PyTorch training
export_tag = 'best_metric_model'   # file name stem for the saved weights
val_interval = 1                   # run validation every `val_interval` epochs
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = list()         # mean training loss per epoch
metric_values = list()             # mean validation Dice per validation run
writer = SummaryWriter()
n_epochs = 120
# Number of optimisation steps per epoch. Taken from the loader itself so a
# trailing partial batch is counted; the previous `len(train_ds) // batch_size`
# dropped it, which made TensorBoard step indices overlap between epochs.
epoch_len = len(train_loader)
for epoch in range(n_epochs):
    t0 = time.time()
    print("-" * 10)
    print(f"epoch {epoch + 1}/{n_epochs}")
    model.train()
    epoch_loss = 0
    step = 0
    for batch_data in train_loader:
        step += 1
        inputs, labels = batch_data["img"].to(device), batch_data["seg"].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
        # global step = steps of all previous epochs + step within this epoch
        writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
    t1 = time.time()
    print('Elapsed time for epoch: %0.2f sec.'%(t1-t0))

    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            metric_sum = 0.0
            metric_count = 0
            val_images = None
            val_labels = None
            val_outputs = None
            for val_data in val_loader:
                val_images, val_labels = val_data["img"].to(device), val_data["seg"].to(device)
                # whole-volume forward pass (no sliding window)
                val_outputs = model(val_images)
                value = dice_metric(y_pred=val_outputs, y=val_labels)
                metric_count += len(value)
                metric_sum += value.item() * len(value)
            metric = metric_sum / metric_count
            metric_values.append(metric)
            # keep the weights of the best validation epoch seen so far
            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), export_tag+'.pth')
                print("saved new best metric model")
            print(
                "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}".format(
                    epoch + 1, metric, best_metric, best_metric_epoch
                )
            )
            writer.add_scalar("val_mean_dice", metric, epoch + 1)
            # plot the last model output as GIF image in TensorBoard with the corresponding image and label
            plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="image")
            plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="label")
            plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="output")
            t2 = time.time()
            print('Elapsed time for validation eval: %0.2f sec.'%(t2-t1))

print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
writer.close()

In [None]:
# plot training loss / validation metric, save in publication quality
loss_curve = torch.tensor(epoch_loss_values)
dice_curve = torch.tensor(metric_values)
# Epochs are reported 1-based, as in the training log. The "+1" must be applied
# to the arg-index itself: adding it to the values (as before) leaves the index
# unchanged and reports a 0-based epoch.
min_loss_epoch = int(torch.argmin(loss_curve)) + 1
max_dice_epoch = int(torch.argmax(dice_curve)) + 1

fig, axs = plt.subplots(1,2,figsize=(14,5),dpi=300)
# left panel: training loss per epoch
axs[0].plot(epoch_loss_values)
axs[0].xaxis.set_ticks(np.arange(0, n_epochs, 10))
axs[0].yaxis.set_ticks(np.linspace(0, 1, 21))
axs[0].grid(True, color = '0.9')
axs[0].set_title('Dice loss (train)\n(Min loss %0.4f at epoch %d)'%(float(torch.min(loss_curve)),
                                                                    min_loss_epoch))
# right panel: validation Dice per validation run
axs[1].plot(metric_values)
axs[1].xaxis.set_ticks(np.arange(0, n_epochs, 10))
axs[1].yaxis.set_ticks(np.linspace(0, 1, 21))
axs[1].grid(True, color = '0.9')
axs[1].set_title('Dice metric (validation)\n(Max Dice %0.4f at epoch %d)'%(float(torch.max(dice_curve)),
                                                                           max_dice_epoch))
plt.show()

fig.savefig(export_tag+'_Scalars_DiceLossMetric.png',dpi=300)
fig.savefig(export_tag+'_Scalars_DiceLossMetric.pdf',dpi=300)
# also save the numpy arrays of loss and metric
np.save(export_tag+'_Scalars_DiceLoss.npy',loss_curve.cpu().numpy())
np.save(export_tag+'_Scalars_DiceValidationMetric.npy',dice_curve.cpu().numpy())

In [None]:
# now, run the forward inference on validation set
val_ds = monai.data.Dataset(data=data_val, transform=val_transforms)
tmp_val_loader = DataLoader(val_ds, batch_size=1, num_workers=1, collate_fn=list_data_collate)

# can load previously computed weights (if kernel restarted, or e.g. for fine-tuning)
# NOTE: `export_tag` is set in the training cell and must be defined before this runs.
load_stored_weights = False
if load_stored_weights:
    # prefer the second GPU; fall back to the first GPU or the CPU if it is absent
    if torch.cuda.device_count() > 1:
        device = torch.device("cuda:1")
    elif torch.cuda.is_available():
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")
    dice_metric = DiceMetric(include_background=True, to_onehot_y=False, sigmoid=True, reduction="mean")
    ff_model_weights = export_tag+'.pth'
    # same architecture as used for training, so the stored state dict matches
    model = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        dropout=0.5,
        num_res_units=2).to(device)
    # map_location lets weights saved on another device be restored onto `device`
    model.load_state_dict(torch.load(ff_model_weights, map_location=device))
    model.eval()

t0 = time.time()
times_elapsed = []   # wall-clock seconds per validation volume (loading + inference + metric)
inferer = SimpleInferer()
model.eval()
with torch.no_grad():
    val_images = []
    val_labels = []
    val_outputs = []
    val_metrics = []
    t0a = time.time()
    for idx, val_data in enumerate(tmp_val_loader):
        val_image, val_label = val_data["img"].to(device), val_data["seg"].to(device)
        val_output = inferer(val_image, model)
        val_metric = dice_metric(y_pred=val_output, y=val_label)
        # store results
        val_images.append(val_image)
        val_labels.append(val_label)
        val_outputs.append(val_output)
        val_metrics.append(val_metric.cpu().numpy())
        # report against the loader that is actually being iterated
        print(f'Evaluated val vol {idx+1} of {len(tmp_val_loader)}')
        times_elapsed.append(time.time()-t0a)
        t0a = time.time()
t1 = time.time()
t_elapsed = (t1-t0)
t_avg = t_elapsed/len(tmp_val_loader)
t_avg2 = np.mean(times_elapsed)
print('Total inference time for %d samples: %0.3f sec.'%(len(tmp_val_loader), t_elapsed))
print('Average inference time: %0.3f sec (sd: %0.3f sec).'%(t_avg2, np.std(times_elapsed)))
print('Speedup over ANTs segmentation (377 sec.): %0.2f'%(377/t_avg2))
# summary statistics of the per-volume Dice values, computed inline with numpy
# (the previously used `arrayStats` helper is not defined in this notebook)
dice_values = np.array(val_metrics, dtype=float).ravel()
print('Dice metric stats:\nn: %d, mean: %0.4f, sd: %0.4f, median: %0.4f, min: %0.4f, max: %0.4f'%(
    dice_values.size, np.mean(dice_values), np.std(dice_values),
    np.median(dice_values), np.min(dice_values), np.max(dice_values)))

In [None]:
# after training: plot inference result on validation set:
normalizeTo01 = False   # if True, rescale the raw network output to [0, 1] before plotting
nrcols = 5
sample_idx = 10         # index of the validation sample to visualise
cropper = monai.transforms.CenterSpatialCrop([200, 150, 100])
if sample_idx >= len(val_images):
    print('Val sample %d not available (only %d samples evaluated).'%(sample_idx, len(val_images)))
else:
    idx = sample_idx
    img, seg, pred, dice = val_images[idx], val_labels[idx], val_outputs[idx], val_metrics[idx]
    print('Val sample: %d'%(idx))
    print('Vol: %s'%(data_val[idx]['img']))
    # report the shapes of *this* sample (previously printed the stale
    # `image`/`label` left over from the augmentation preview cell)
    print(f"image shape: {img.shape}, label shape: {seg.shape}")
    print('Dice: %0.4f'%(float(np.mean(dice))))
    # drop the batch axis; the channel axis is kept and indexed with 0 below
    img = np.squeeze(img.cpu().numpy(),axis=0)
    seg = np.squeeze(seg.cpu().numpy(),axis=0)
    pred = np.squeeze(pred.cpu().numpy(),axis=0)
    if normalizeTo01:
        # out-of-place, so the stored network output is never modified, and
        # guarded against a constant prediction (zero value range)
        pred = pred - np.min(pred)
        pred_max = np.max(pred)
        if pred_max > 0:
            pred = pred / pred_max
    # centre indices of the padded volume; the view is taken 10 slices before
    # the centre of the 1st spatial axis
    slice_idxs = (np.array([208, 160, 112])/2).astype(int)
    plt.figure("check val sample %d"%idx, (12, 6))
    # plot along 1st axis
    # img
    slice_idx = slice_idxs[0]-10
    plt.subplot(1, nrcols, 1)
    plt.title("img")
    plt.imshow(np.squeeze(img[0,slice_idx, :, :]), cmap="gray")
    # seg
    plt.subplot(1, nrcols, 2)
    plt.title("seg")
    plt.imshow(np.squeeze(seg[0,slice_idx, :, :]))
    # pred
    plt.subplot(1, nrcols, 3)
    plt.title("pred")
    plt.imshow(np.squeeze(pred[0,slice_idx, :, :]))
    # pred sigmoid activate (torch.sigmoid replaces the undefined `np_sigmoid`)
    pred_sig = torch.sigmoid(torch.from_numpy(np.ascontiguousarray(pred))).numpy()
    plt.subplot(1, nrcols, 4)
    plt.title("pred sigmoid")
    plt.imshow(np.squeeze(pred_sig[0,slice_idx, :, :]))
    # pred|0.5
    th = 0.5
    plt.subplot(1, nrcols, 5)
    plt.title(f"pred > {th}")
    plt.imshow(np.squeeze(pred_sig[0,slice_idx, :, :]>th))
    # overlay
    # todo
    plt.show()