In [None]:
import glob, os
from random import *
import shutil
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn
import torch.nn.functional as F
import time, math
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data.distributed import DistributedSampler


import torchio as tio

from monai.config import print_config
from monai.data import CacheDataset, DataLoader, partition_dataset
from monai.inferers import sliding_window_inference
from monai.losses import DiceLoss, GeneralizedDiceLoss
from monai.metrics import compute_meandice
from monai.networks.layers import Norm
from monai.networks.nets import UNet
from monai.transforms import (
    AsDiscrete, Compose, LoadNiftid, ToTensord, AddChanneld, LabelToContour,
)
from monai.networks.blocks import Convolution, Upsample
from monai.networks.layers.factories import Pool, Act
from monai.networks.layers import split_args
from monai.utils import set_determinism
from monai.optimizers import Novograd

# Log MONAI configuration and the TorchIO version so runs can be reproduced.
print_config()
print(f'TorchIO version: {tio.__version__}')

In [None]:
# Runtime configuration: data root directory and the GPUs exposed to this process.
os.environ["MONAI_DATA_DIRECTORY"] = "./data"
# CUDA expects a plain comma-separated list of device ids (no spaces). The value
# only takes effect if set before CUDA is first initialised in this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
directory = os.environ.get("MONAI_DATA_DIRECTORY")
root_dir = directory
print(root_dir)

In [None]:
## Collect (image, mask) file pairs from <root_dir>/nifti_data/{image,mask}/*.nii.gz
data_dir = os.path.join(root_dir, "nifti_data")
train_images = sorted(glob.glob(os.path.join(data_dir, "image", "*.nii.gz")))
train_labels = sorted(glob.glob(os.path.join(data_dir, "mask", "*.nii.gz")))

# Images and masks are paired purely by sorted position, so the two lists must
# line up one-to-one; zip() would otherwise silently drop the unmatched tail.
assert train_images, f"no images found under {os.path.join(data_dir, 'image')}"
assert len(train_images) == len(train_labels), (
    f"found {len(train_images)} images but {len(train_labels)} masks under {data_dir}"
)

data_dicts = [
    {"image": image_name, "label": label_name}
    for image_name, label_name in zip(train_images, train_labels)
]

In [None]:
### hyperparameter setting
# Seed the RNGs via MONAI before any random splitting or augmentation happens.
set_determinism(seed=0)

# Batch size and the fixed volume geometry every sample is cropped/padded to.
bs = 64
Height, Width, Depth = 480, 480, 16

# Optimisation schedule.
epoch_num = 500
l_rate = 1e-3

# When True the model is wrapped in torch.nn.DataParallel.
multi_GPU = True

In [None]:
### image augmentation transform with monai and torchio API

# HistogramStandardization: learn intensity landmarks from the image files and
# save them to `histogram_landmarks_path`.
# NOTE(review): `train_images` lists every image on disk; the train/val/test
# split only happens in a later cell, so validation and test images also shape
# these landmarks (data leakage). Consider training on the training split only.
histogram_landmarks_path = 'landmarks.npy'
landmarks = tio.HistogramStandardization.train(
    train_images,
    output_path=histogram_landmarks_path,
)
np.set_printoptions(suppress=True, precision=3)
print('\nTrained landmarks:', landmarks)


def _monai_loading_transforms():
    """Return fresh MONAI steps: load NIfTI, add a channel dim, convert to tensors."""
    return [
        LoadNiftid(keys=["image", "label"]),
        AddChanneld(keys=["image", "label"]),
        ToTensord(keys=["image", "label"]),
    ]


def _torchio_preprocessing():
    """Return fresh deterministic TorchIO steps shared by train and validation.

    NOTE(review): `mask_name='label'` centres the crop on the label mask, so the
    validation/test pipeline uses ground truth at evaluation time — confirm intended.
    """
    return [
        tio.CropOrPad((Height, Width, Depth), mask_name='label', include=["image", "label"]),
        tio.HistogramStandardization({'image': landmarks}, include=["image"]),
        tio.ZNormalization(masking_method=tio.ZNormalization.mean, include=["image"]),
    ]


# Training: shared preprocessing followed by random augmentation.
train_transforms_monai = _monai_loading_transforms()
train_transforms_io = _torchio_preprocessing() + [
    tio.RandomNoise(p=0.1, include=["image"]),
    tio.RandomFlip(axes=(0,), include=["image", "label"]),
]

# Validation/test: shared preprocessing only.
validation_transforms_monai = _monai_loading_transforms()
validation_transforms_io = _torchio_preprocessing()

# transform composition
train_transforms = Compose(train_transforms_monai + train_transforms_io)
val_transforms = Compose(validation_transforms_monai + validation_transforms_io)

In [None]:
## Split 80/10/10 into train/val/test and wrap each split in a MONAI CacheDataset
## (cache_rate=1.0 keeps every cached item in memory) for speed-up.
# NOTE(review): CacheDataset only caches the leading deterministic MONAI transforms;
# confirm for the installed MONAI version that the TorchIO random transforms
# (RandomNoise/RandomFlip) are re-applied every epoch rather than cached.

train_data, val_data, test_data = partition_dataset(data_dicts, ratios=[0.8, 0.1, 0.1], shuffle=True)

train_ds = CacheDataset(data=train_data, transform=train_transforms, cache_rate=1.0, num_workers=8)
val_ds = CacheDataset(data=val_data, transform=val_transforms, cache_rate=1.0, num_workers=8)
test_ds = CacheDataset(data=test_data, transform=val_transforms, cache_rate=1.0, num_workers=8)

print('\n' + 'Training set:', len(train_data), 'subjects')
print('Validation set:', len(val_data), 'subjects')
# Previously mislabelled as "Validation set:".
print('Test set:', len(test_data), 'subjects')

In [None]:
## dataloader

# Training batches are reshuffled every epoch.
# NOTE(review): bs full 480x480x16 volumes per step is a very large 3D batch —
# confirm it fits in GPU memory.
train_loader = DataLoader(train_ds, batch_size=bs, shuffle=True, num_workers=8)

# Validation and test iterate one volume at a time, in order.
val_loader, test_loader = (
    DataLoader(split_ds, batch_size=1, num_workers=8) for split_ds in (val_ds, test_ds)
)

In [None]:
# 3D residual UNet: one input channel, two output channels.
model = UNet(
    dimensions=3,
    in_channels=1,
    out_channels=2,
    # five resolution levels; each stride-2 step halves every spatial size,
    # so the four of them reduce Depth=16 down to 1
    channels=(32, 64, 128, 256, 512),
    strides=(2, 2, 2, 2),
    num_res_units=3,
    dropout=0.3,
    norm=Norm.BATCH,
)

In [None]:
### training preparation

# GPU 0 is the primary device in both modes: batches are moved there and,
# with DataParallel, outputs are gathered back onto it.
device = torch.device('cuda', 0)
if multi_GPU:
    # wrap the model with DataParallel module to split batches across visible GPUs
    model = torch.nn.DataParallel(model, output_device=0)
model.cuda()

loss_function = DiceLoss(to_onehot_y=True, softmax=True).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=l_rate, weight_decay=0.001)

In [None]:
## train start
# Train for `epoch_num` epochs, validate every `val_interval` epochs with
# sliding-window inference, and checkpoint the weights with the best mean Dice.

val_interval = 2
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = []
metric_values = []
epoch_time = []
total_start = time.time()
post_pred = AsDiscrete(argmax=True, to_onehot=True, n_classes=2)
post_label = AsDiscrete(to_onehot=True, n_classes=2)

# Loop-invariant settings, computed once instead of per batch.
epoch_len = math.ceil(len(train_ds) / train_loader.batch_size)
roi_size = (Height, Width, Depth)
sw_batch_size = 1

for epoch in range(epoch_num):
    epoch_start = time.time()
    print("-" * 50)
    print(f"epoch {epoch + 1}/{epoch_num}")
    model.train()
    epoch_loss = 0
    step = 0
    for batch_data in train_loader:
        step_start = time.time()
        step += 1
        inputs, labels = (
            batch_data["image"].to(device),
            batch_data["label"].to(device),
        )
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        loss_value = loss.item()
        epoch_loss += loss_value
        print(
            f"{step}/{epoch_len}, train_loss: {loss_value:.4f}"
            f" step time: {(time.time() - step_start):.4f} seconds"
        )
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            metric_sum = 0.0
            metric_count = 0
            # Named `val_batch`, not `val_data`: reusing that name would overwrite
            # the validation split list returned by partition_dataset.
            for val_batch in val_loader:
                val_inputs, val_labels = (
                    val_batch["image"].to(device),
                    val_batch["label"].to(device),
                )
                val_outputs = sliding_window_inference(
                    val_inputs, roi_size, sw_batch_size, model
                )
                val_outputs = post_pred(val_outputs)
                val_labels = post_label(val_labels)
                value = compute_meandice(
                    y_pred=val_outputs,
                    y=val_labels,
                    include_background=False,
                )
                # compute_meandice reports NaN for a class absent from the ground
                # truth; average only the valid (batch, class) entries so one empty
                # mask cannot turn the whole epoch metric into NaN.
                valid = ~torch.isnan(value)
                metric_count += int(valid.sum().item())
                metric_sum += value[valid].sum().item()
            # NaN when no volume had foreground; NaN never beats best_metric below.
            metric = metric_sum / metric_count if metric_count else float("nan")
            metric_values.append(metric)
            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), os.path.join(root_dir, "best_metric_model.pth"))
                print("saved new best metric model")
            print(
                f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                f"\nbest mean dice: {best_metric:.4f} at epoch: {best_metric_epoch}"
            )
    epoch_duration = time.time() - epoch_start
    epoch_time.append(epoch_duration)
    print(
        f"time consuming of epoch {epoch + 1} is:"
        f" {epoch_duration:.4f} seconds"
    )

In [None]:
# Summarise the run: best validation mean Dice and the epoch that reached it.
print("train completed, best_metric: {:.4f}  at epoch: {}".format(best_metric, best_metric_epoch))

In [None]:
# Training curves: mean training loss per epoch (left) and validation mean Dice,
# recorded once every `val_interval` epochs (right).
fig, (loss_ax, dice_ax) = plt.subplots(1, 2, num="train", figsize=(12, 6))

loss_epochs = list(range(1, len(epoch_loss_values) + 1))
loss_ax.set_title("Epoch Average Loss")
loss_ax.set_xlabel("epoch")
loss_ax.plot(loss_epochs, epoch_loss_values)

dice_epochs = [val_interval * k for k in range(1, len(metric_values) + 1)]
dice_ax.set_title("Val Mean Dice")
dice_ax.set_xlabel("epoch")
dice_ax.plot(dice_epochs, metric_values)

plt.show()

In [None]:
# Reload the best checkpoint and, for each test volume, show one random slice
# along the last axis: image, ground-truth mask, predicted mask, predicted
# contour, and the contour burned into the image.
model.load_state_dict(torch.load(os.path.join(root_dir, "best_metric_model.pth")))
model.eval()

roi_size = (Height, Width, Depth)
sw_batch_size = 1
# Post-processing transforms, built once rather than once per volume.
to_class_map = AsDiscrete(argmax=True)
to_contour = LabelToContour()

with torch.no_grad():
    # Named `test_batch`, not `test_data`: reusing that name would overwrite the
    # test split list returned by partition_dataset.
    for i, test_batch in enumerate(test_loader):
        test_image = test_batch["image"].to(device)
        test_output = sliding_window_inference(test_image, roi_size, sw_batch_size, model)

        # plot the slice [:, :, rand] — random index along the last axis
        j = randint(0, test_image.shape[-1] - 1)

        argmax = to_class_map(test_output)
        contour = to_contour(argmax)
        map_image = test_image.clone()
        map_image[contour == 1] = map_image.max()

        fig, axes = plt.subplots(1, 5, num="check", figsize=(20, 4))
        panels = [
            (f"image {i}", test_image.cpu()[0, 0, :, :, j], "gray"),
            (f"Ground truth mask {i}", test_batch["label"][0, 0, :, :, j], None),
            (f"AI predicted mask {i}", argmax.cpu()[0, 0, :, :, j], None),
            (f"contour {i}", contour.cpu()[0, 0, :, :, j], None),
            (f"overlaying contour {i}", map_image.cpu()[0, 0, :, :, j], "gray"),
        ]
        for ax, (title, slice_2d, cmap) in zip(axes, panels):
            ax.set_title(title)
            ax.imshow(slice_2d, cmap=cmap)
        plt.show()
        # one figure per test volume; close it so they do not accumulate in memory
        plt.close(fig)