# 3D Segmentation with UNet

In [1]:
import logging
import os
import shutil
import sys
import tempfile
from glob import glob

import nibabel as nib
import numpy as np
import torch
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.handlers import ModelCheckpoint
from torch.utils.data import DataLoader

import monai
from monai.data import NiftiDataset, create_test_image_3d
from monai.transforms import Compose, AddChannel, ScaleIntensity, Resize, ToTensor, RandUniformPatch
from monai.handlers import \
    StatsHandler, TensorBoardStatsHandler, TensorBoardImageHandler, MeanDice, stopping_fn_from_metric
from monai.networks.utils import predict_segmentation

monai.config.print_config()
logging.basicConfig(stream=sys.stdout, level=logging.INFO)

# Seed the global RNGs so a Restart & Run All reproduces the same results.
# NOTE(review): assumes create_test_image_3d draws from the global np.random state by
# default and that network init uses torch's global generator — confirm for this MONAI version.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

MONAI version: 0+untagged.106.gf2e8580.dirty
Python version: 3.6.9 |Anaconda, Inc.| (default, Jul 30 2019, 19:07:31)  [GCC 7.3.0]
Numpy version: 1.17.4
Pytorch version: 1.4.0a0+a5b4d78
Ignite version: 0.3.0


## Set up demo data

In [2]:
# Create a temporary directory and NUM_SAMPLES random (image, segmentation mask) pairs,
# saved as NIfTI files im{i}.nii.gz and seg{i}.nii.gz.
NUM_SAMPLES = 50
tempdir = tempfile.mkdtemp()

for i in range(NUM_SAMPLES):
    im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1)

    # Image and mask are stored identically; np.eye(4) is an identity affine.
    for prefix, volume in (('im', im), ('seg', seg)):
        nib.save(nib.Nifti1Image(volume, np.eye(4)), os.path.join(tempdir, '%s%i.nii.gz' % (prefix, i)))

## Set up transforms and dataset

In [3]:
# Collect the file lists; both use the same naming scheme, so after sorting
# im{i} and seg{i} sit at the same list index.
images = sorted(glob(os.path.join(tempdir, 'im*.nii.gz')))
segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz')))

# Training transforms: only the image is intensity-scaled; both image and mask get a
# channel axis, a random 96^3 patch and tensor conversion.
# NOTE(review): relies on NiftiDataset to give both RandUniformPatch instances the same
# random state so image and mask patches line up — confirm for this MONAI version.
imtrans = Compose([ScaleIntensity(), AddChannel(), RandUniformPatch((96, 96, 96)), ToTensor()])
segtrans = Compose([AddChannel(), RandUniformPatch((96, 96, 96)), ToTensor()])

# NIfTI dataset plus a loader used here only to sanity-check one batch's shapes.
ds = NiftiDataset(images, segs, transform=imtrans, seg_transform=segtrans)
loader = DataLoader(
    ds,
    batch_size=10,
    num_workers=2,
    pin_memory=torch.cuda.is_available(),
)
im, seg = monai.utils.misc.first(loader)
print(im.shape, seg.shape)

torch.Size([10, 1, 96, 96, 96]) torch.Size([10, 1, 96, 96, 96])


## Create Model, Loss, Optimizer

In [4]:
# 3D UNet with a single input channel and a single output channel;
# `channels`/`strides` define the encoder stages, with 2 residual units each.
net = monai.networks.nets.UNet(dimensions=3, in_channels=1, out_channels=1,
                               channels=(16, 32, 64, 128, 256), strides=(2, 2, 2, 2),
                               num_res_units=2)

# Dice loss applied after a sigmoid on the network output; Adam optimiser.
loss = monai.losses.DiceLoss(do_sigmoid=True)
lr = 1e-3
opt = torch.optim.Adam(net.parameters(), lr=lr)

## Create supervised_trainer using ignite

In [6]:
# Create trainer. Use the first GPU when available, otherwise fall back to the CPU so the
# notebook still runs end to end (slowly) on machines without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
trainer = create_supervised_trainer(net, opt, loss, device=device, non_blocking=False)

## Set up event handlers for checkpointing and logging

In [7]:
### optional section for checkpoint and tensorboard logging
# Checkpoint the network weights and optimizer state into ./runs/ at the end of every epoch
# (n_saved=10 checkpoints kept; the directory may already contain files).
checkpoint_handler = ModelCheckpoint('./runs/', 'net', n_saved=10, require_empty=False)
trainer.add_event_handler(
    event_name=Events.EPOCH_COMPLETED,
    handler=checkpoint_handler,
    to_save={'net': net, 'opt': opt},
)

# Console logging: the loss every iteration and metrics every epoch (the trainer has no
# metrics, so only the loss appears). Supply output_transform if engine.state.output
# is not already a loss value; print functions can also be customised.
train_stats_handler = StatsHandler(name='trainer')
train_stats_handler.attach(trainer)

# Same quantities as StatsHandler, but plotted to TensorBoard instead of printed.
train_tensorboard_stats_handler = TensorBoardStatsHandler()
train_tensorboard_stats_handler.attach(trainer)

## Add validation every N epochs

In [8]:
### optional section for model validation during training
validation_every_n_epochs = 1
# Key under which the validation metric is registered on the evaluator
metric_name = 'Mean_Dice'
val_metrics = {metric_name: MeanDice(add_sigmoid=True, to_onehot_y=False)}

# ignite evaluator expects batch=(img, seg) and returns output=(y_pred, y) at every iteration,
# user can add output_transform to return other values
evaluator = create_supervised_evaluator(net, val_metrics, device=device, non_blocking=True)

# Validation resizes whole volumes to 96^3 instead of cropping random patches.
val_imtrans = Compose([
    ScaleIntensity(),
    AddChannel(),
    Resize((96, 96, 96)),
    ToTensor()
])
# The mask must keep discrete label values, so resize it with nearest-neighbour
# interpolation (order=0) and no anti-aliasing.
# NOTE(review): Resize's defaults in this MONAI version appear to be linear, anti-aliased
# (skimage-backed) interpolation, which would yield fractional labels at object boundaries
# and skew Mean Dice — confirm parameter names against the installed version.
val_segtrans = Compose([
    AddChannel(),
    Resize((96, 96, 96), order=0, anti_aliasing=False),
    ToTensor()
])
# The last 20 volumes form the validation set.
val_ds = NiftiDataset(images[-20:], segs[-20:], transform=val_imtrans, seg_transform=val_segtrans)
val_loader = DataLoader(val_ds, batch_size=5, num_workers=8, pin_memory=torch.cuda.is_available())


@trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs))
def run_validation(engine):
    """Evaluate the model on ``val_loader``.

    Fired by the trainer every ``validation_every_n_epochs`` epochs; the ``engine``
    argument (the trainer) is not used.
    """
    evaluator.run(data=val_loader)


def _trainer_epoch(_):
    """Return the trainer's current epoch so validation logs are indexed by training epoch."""
    return trainer.state.epoch


def _drop_output(_):
    """Discard the evaluator's per-iteration output (no loss is logged during validation)."""
    return None


# Add stats event handler to print validation stats via evaluator
val_stats_handler = StatsHandler(
    name='evaluator',
    output_transform=_drop_output,  # no need to print loss value, so disable per iteration output
    global_epoch_transform=_trainer_epoch,  # fetch global epoch number from trainer
)
val_stats_handler.attach(evaluator)

# add handler to record metrics to TensorBoard at every validation epoch
val_tensorboard_stats_handler = TensorBoardStatsHandler(
    output_transform=_drop_output,  # no need to plot loss value, so disable per iteration output
    global_epoch_transform=_trainer_epoch,  # fetch global epoch number from trainer
)
val_tensorboard_stats_handler.attach(evaluator)

# add handler to draw the first image and the corresponding label and model output in the last batch
# here we draw the 3D output as GIF format along Depth axis, at every validation epoch
val_tensorboard_image_handler = TensorBoardImageHandler(
    batch_transform=lambda batch: (batch[0], batch[1]),
    output_transform=lambda output: predict_segmentation(output[0]),
    global_iter_transform=_trainer_epoch,
)
# Trailing ';' suppresses display of the returned RemovableEventHandle in the notebook
evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=val_tensorboard_image_handler);

<ignite.engine.engine.RemovableEventHandle at 0x7fa60714a978>

## Run training loop

In [None]:
# Create the training data loader and run training.
# Logging was already configured in the first cell; a second logging.basicConfig call is a no-op.
# The first 20 volumes are used for training (validation uses images[-20:]).
train_ds = NiftiDataset(images[:20], segs[:20], transform=imtrans, seg_transform=segtrans)
train_loader = DataLoader(train_ds, batch_size=5, shuffle=True, num_workers=8,
                          pin_memory=torch.cuda.is_available())

train_epochs = 5
state = trainer.run(train_loader, train_epochs)

## Visualizing Tensorboard logs

In [None]:
log_dir = './runs'  # by default TensorBoard logs go into './runs'
# Note: ModelCheckpoint above also writes its checkpoint files under './runs/'.

# Load the TensorBoard notebook extension, then launch it on log_dir
# (IPython substitutes the Python variable for $log_dir).
%load_ext tensorboard
%tensorboard --logdir $log_dir

In [None]:
# Remove the temporary demo data. shutil.rmtree works without a POSIX shell and does not
# interpolate the path into a shell command line.
shutil.rmtree(tempdir)