<a href="https://www.nvidia.com/dli"> <img src="DLI Header.png" alt="Header" style="width: 400px;"/> </a>
# Prostate segmentation with V-Net

#  <span style="color:green">Overview</span>

Processing volumetric data is a necessity in medical image analysis, because most medical data is in fact volumetric. CT, MRI, PET, SPECT and even ultrasound produce 3D data containing several million voxels (3D pixels) per scan. Although this type of data is extremely useful for diagnosis, as it reflects the true 3D anatomy of the patient, its dimensionality is challenging from a computational standpoint. In particular, neural networks trained on this kind of data must cope with the sheer amount of information in each scan. Recent advances in GPU computing capabilities, research [Ref. 1] and deep learning software have made the application of 3D CNNs to 3D data possible. Large GPU memory sizes, the availability of accelerated software routines for 3D convolution, deconvolution and pooling, as well as methodological improvements such as the Dice loss, group normalization, skip connections and encoder-decoder architectures have enabled interesting advances in this field.

In this exercise we show how to implement a popular volumetric CNN design, similar to V-Net [Ref. 1], and train it on 3D MRI data of the prostate obtained from the Medical Segmentation Decathlon 2018 challenge (http://medicaldecathlon.com). Our implementation is in PyTorch.

#  <span style="color:green">Introduction</span>

V-Net introduced three important novelties in deep learning (DL) for medical image processing:
* A 3D network architecture that processes volumes natively and at high resolution
* A Dice loss layer to learn how to segment without suffering from the drawbacks of other losses
* (Long +) short skip connections to accelerate learning and convergence

## Dice loss for segmentation
The Dice coefficient measures the overlap between two (binary) contours and has been generalized and introduced as an objective function for FCNNs in [Ref. 3]. Since then it has been used in a number of scientific works and is now very well established. The formulation used in this work is

$$\mathrm{DICE} = \frac{2 \sum_i \mathrm{Gt}_i \, \mathrm{Pred}_i}{\sum_i \mathrm{Gt}_i^2 + \sum_i \mathrm{Pred}_i^2}$$

where the sums run over all voxels $i$. This corresponds to the Dice coefficient when both Gt and Pred are binary.

Other formulations, such as

$$\mathrm{DICE} = \frac{2 \sum_i \mathrm{Gt}_i \, \mathrm{Pred}_i}{\sum_i \mathrm{Gt}_i + \sum_i \mathrm{Pred}_i}$$

have been proposed but behave slightly differently, especially when it comes to gradients. You can find more information about this topic in [Ref. 5].
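
To make the difference between the two formulations concrete, here is a minimal NumPy sketch. The function names `dice_squared` and `dice_plain`, the toy arrays and the small `eps` (which only guards against division by zero) are illustrative assumptions and not part of the exercise code.

```python
import numpy as np

def dice_squared(gt, pred, eps=1e-5):
    # 2 * sum(Gt * Pred) / (sum(Gt^2) + sum(Pred^2))
    return 2.0 * np.sum(gt * pred) / (np.sum(gt ** 2) + np.sum(pred ** 2) + eps)

def dice_plain(gt, pred, eps=1e-5):
    # 2 * sum(Gt * Pred) / (sum(Gt) + sum(Pred))
    return 2.0 * np.sum(gt * pred) / (np.sum(gt) + np.sum(pred) + eps)

# toy ground truth with a binary and a soft (probabilistic) prediction
gt = np.array([1.0, 1.0, 0.0, 0.0])
binary_pred = np.array([1.0, 0.0, 0.0, 0.0])
soft_pred = np.array([0.9, 0.6, 0.2, 0.1])

binary_scores = (dice_squared(gt, binary_pred), dice_plain(gt, binary_pred))
soft_scores = (dice_squared(gt, soft_pred), dice_plain(gt, soft_pred))

print(binary_scores)
print(soft_scores)
```

For binary inputs both functions return the same value, because squaring a 0/1 value leaves it unchanged. For soft predictions the two scores differ, which is why the choice of denominator matters during training.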

## The importance of skip connections in biomedical image segmentation
In a medical segmentation task, we want to map every voxel location of a medical image to a distinct class value representing, for example, background or the organ of interest. Consequently, the input (image) and the output (segmentation) usually have the same spatial extent. A straightforward approach to designing a neural network architecture for this goal would be to use several fully connected layers without changing the spatial dimensionality. However, this approach would very quickly lead to an explosion in the number of parameters. Convolutional operations significantly reduce the number of parameters, as the same kernel is applied at every location of the image. Furthermore, they are translation-equivariant: if the object of interest is shifted in the input, the response is shifted accordingly, so the position of the object does not matter.

Designing a fully convolutional network without reducing the spatial dimensions would be possible, but research indicates that an encoder-decoder network is more effective [Ref. 6]. An encoder-decoder network typically looks like this:

![Encoder-Decoder Network](encoder.png)

The contracting path (encoder) maps the image to a feature representation that is often lower dimensional than the original input. The expanding path (decoder) maps the feature representation to the output space. Originally, this type of network was known in connection with autoencoders, where the output space is the same as the input space. In the case of segmentation, the output space often has the same spatial extent as the input space but represents different content (classes).

Thanks to the downsampling in the contracting path, fewer parameters need to be trained. However, this comes at the cost of losing spatial information. One approach to counteract this is the use of so-called skip connections: allowing the gradient to skip part of the network and to flow directly from a layer of the contracting part to the expanding path. 

## Fusing features from different layers
There are different techniques to realize these skip connections. Two of the most common ones are concatenation and summation:

### Feature concatenation
One possibility is to simply concatenate the layers. This requires the layers to have the same size in every dimension except the concatenation direction.
So a concatenation of t1 = [1 2 3] with t2 = [4 5 6] could be t_new = [1 2 3 4 5 6]. In PyTorch this operation is `torch.cat` ([see here](https://pytorch.org/docs/stable/torch.html)).

<img src=concatenation.png width="600">


### Feature summation
Another widely used approach is element-wise summation. One very nice property is that it keeps the number of features fixed. A summation of t1= [1 2 3] and t2 = [4 5 6] would be t_new = [5 7 9]. 
<img src=summation.png width="600">
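
The two fusion strategies can be compared on toy 3D feature maps. This is a minimal sketch: the tensor names and shapes are made up for illustration, and the layout (batch, channels, x, y, z) is the one PyTorch uses for 3D convolutions.

```python
import torch

# two toy feature maps with layout (batch, channels, x, y, z)
encoder_features = torch.ones(1, 16, 8, 8, 4)
decoder_features = torch.full((1, 16, 8, 8, 4), 2.0)

# concatenation along the channel dimension: the channel count doubles
fused_cat = torch.cat([encoder_features, decoder_features], dim=1)

# element-wise summation: the shape stays the same
fused_sum = encoder_features + decoder_features

print(fused_cat.shape)  # torch.Size([1, 32, 8, 8, 4])
print(fused_sum.shape)  # torch.Size([1, 16, 8, 8, 4])
```

Concatenation keeps both sets of features but the next layer has to process twice as many input channels, whereas summation leaves the number of features, and therefore the size of the following layer, unchanged.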

#  <span style="color:green">Best practice for structuring code</span>

## Code and exercise structure
Structuring code to solve a machine learning problem to ensure both flexibility and adoption of best practices is not an easy task. In this exercise we try to incorporate some of the best principles that have emerged in popular recent python projects.

With the introduction of modern frameworks such as TensorFlow and PyTorch, most of the processes around the development of DL approaches have been standardized. During the development of a typical project it is necessary to take care of only a handful of compartmentalized tasks such as:
* DATA
    1. Load data, standardize and augment it
    2. Split dataset into batches and iterate through them
* NETWORK
    1. Define a network architecture as computational graph
    2. Define suitable loss
* OPTIMIZATION
    1. Define optimization algorithm
    2. Implement training and validation loops
    
### Handling Data

For **data** handling we define transforms in charge of loading, standardizing and modifying the dataset. Our transforms are chainable (stackable) such that data handling pipelines can be created. The dataset is stored in a python **dictionary** in order to allow this behaviour.

**Transforms** are implemented by classes. In the constructor of these transforms (`__init__` method in python) we pass the parameters of the transform. We define the `__call__` method to accept only one user defined argument which is the dictionary containing data. 

```
class ExampleTransform(object):
    # this transformation adds a constant to the images of a dataset
    
    def __init__(self, constant):
        # the constant that we add is passed as a parameter of this transform 
        self.constant = constant
        
    def __call__(self, data):
        # data is a dictionary containing the dataset. we suppose it has a field 'images'
        data['images'] = data['images'] + self.constant
        
        # the modified version of the data dictionary is returned as a result of the transform
        return data 
```

The example code above implements a simple transform that adds a constant to all of the images of the dataset. Other transforms can be implemented in the same way, including those that load datasets from the filesystem and actually inject new data into the 'data' dictionary.
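
Because every transform takes a data dictionary and returns one, transforms can be chained by calling them one after another. The sketch below illustrates the idea; `AddConstant`, `Scale` and the helper `apply_transforms` are hypothetical names introduced only for this example.

```python
import numpy as np

class AddConstant(object):
    # adds a constant to the 'images' field of the data dictionary
    def __init__(self, constant):
        self.constant = constant

    def __call__(self, data):
        data['images'] = data['images'] + self.constant
        return data

class Scale(object):
    # multiplies the 'images' field of the data dictionary by a factor
    def __init__(self, factor):
        self.factor = factor

    def __call__(self, data):
        data['images'] = data['images'] * self.factor
        return data

def apply_transforms(data, transforms):
    # hypothetical helper: feed the output of each transform into the next one
    for transform in transforms:
        data = transform(data)
    return data

pipeline = [AddConstant(1.0), Scale(0.5)]
data = {'images': np.zeros((2, 4, 4, 4), dtype=np.float32)}
data = apply_transforms(data, pipeline)

print(data['images'].mean())  # (0 + 1) * 0.5 = 0.5
```

The order of the list matters: swapping the two transforms would give `0 * 0.5 + 1 = 1` instead.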

In order to split the dataset into batches and iterate through these batches during training, validation and potentially testing, we need a batch iterator. This is implemented here through a Python class which acts as a generator. That is, we can use our batch iterator object in a for loop to get the batches that we use during training. The batch iterator returns a dictionary containing data at each iteration. It is also able to execute transforms (as defined above) both before and during iteration over the dataset. More details will be shown later in the exercise.
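
As a rough preview of the idea, a toy batch iterator can be written in a few lines. `ToyBatchIterator` and the integer "dataset" are placeholders for illustration only; the iterator actually used in this exercise is defined later and additionally applies transforms and stacks arrays.

```python
class ToyBatchIterator(object):
    def __init__(self, data, batch_size):
        self.data = data              # list of data dictionaries
        self.batch_size = batch_size

    def __iter__(self):
        # defining __iter__ with yield lets the object be used in a for loop
        for start in range(0, len(self.data), self.batch_size):
            chunk = self.data[start:start + self.batch_size]
            # collate: from a list of dictionaries to a dictionary of lists
            yield {key: [entry[key] for entry in chunk] for key in chunk[0]}

toy_dataset = [{'images': i, 'labels': -i} for i in range(5)]
batches = list(ToyBatchIterator(toy_dataset, batch_size=2))

for batch in batches:
    print(batch)
```

With five entries and a batch size of two, the last batch contains a single entry.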

### Network definition

In the following sections of this exercise you will find the implementation of the network object in PyTorch. A few network block objects have been defined in order to break down the implementation into more manageable portions and to group together code that can be reused.

Loss layers can be defined very easily in PyTorch by implementing only the 'forward' computation of the loss and omitting the gradient implementation. This is possible thanks to the built-in automatic differentiation capabilities of PyTorch and other modern DL frameworks.
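
The following minimal sketch shows what this means in practice. The class `MeanSquaredErrorLoss` is a made-up example (not the loss used in this exercise): only `forward` is written, and the gradient is obtained by calling `backward()`.

```python
import torch
from torch import nn

class MeanSquaredErrorLoss(nn.Module):
    # only the forward computation is defined; no gradient code is needed
    def forward(self, prediction, target):
        return ((prediction - target) ** 2).mean()

prediction = torch.tensor([0.0, 2.0], requires_grad=True)
target = torch.tensor([1.0, 1.0])

loss = MeanSquaredErrorLoss()(prediction, target)
loss.backward()  # autograd derives d(loss)/d(prediction) automatically

print(loss.item())      # 1.0
print(prediction.grad)  # tensor([-1.,  1.])
```

The printed gradient matches the analytical derivative `2 * (prediction - target) / 2` of the mean over two elements.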

### Fitting the network's parameters to the data

To train, validate and test the network we need to write the code implementing the training, validation and testing loops (the testing loop is omitted here). This code instantiates the network, whose layers are created in its constructor (`__init__`) and whose forward computation is defined in its `forward` member function, and then instantiates the batch iterators and the optimizer.  
At this point we can iterate (using a for loop) through the batches, which are fed to the network in order to optimize it for the task at hand.

In [None]:
# In this section we import all the python packages used in this exercise

import numpy as np

import torch
import torch.optim as optim
import torch.backends.cudnn as cudnn

import os
import copy

import matplotlib
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

import nibabel as nib

from nilearn.image import resample_img
from random import shuffle

from multiprocessing.pool import ThreadPool
from functools import partial

from torch import nn
from torch.autograd import Variable
from torch.nn import Module, Conv3d, Parameter

%matplotlib inline

# set random seeds for reproducibility
torch.manual_seed(551)
torch.cuda.manual_seed_all(551)
np.random.seed(551)


## Loading the dataset
This function is in charge of creating a dataset by scanning the specified paths for images and labels and filtering the results in order to keep only NIfTI files (extension nii).

In [None]:
def make_decathlon_prostate_dataset(images_path, labels_path):
    # in this method we create a dataset which is a list of data dictionaries
    dataset = []
        
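    # keep only NIfTI files and skip hidden files (e.g. '._*' metadata files that some archives contain)
    file_list = [f for f in os.listdir(images_path) if f.endswith(('.nii', '.nii.gz')) and not f.startswith('.')]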
        
    for file in file_list:
        dataset.append({
            'images': [os.path.join(images_path, file)],
            'labels': [os.path.join(labels_path, file)],
        })
            
    return dataset

## Data transformation and management
Here we define a few transformations that allow us to manipulate data in order to feed it to the network and realize training. These transforms, once instantiated, can be chained together. All of them can be called by feeding a data dictionary as input and return a modified version of such dictionary. 

In this exercise we start out with a dataset of MRI prostate images that contains *file names* for images and labels and we execute the following transformations to obtain images that can be fed to the network.

* **LoadNiftyFromFilename** loads NIfTI data (2D, 3D or 4D+) into memory as a NIfTI image (through the nibabel package)
* **ResampleNiftyVolume** makes all NIfTI volumes have the same spatial resolution (mm per voxel)
* **NiftyToNumpy** converts the NIfTI volumes to numpy arrays
* **CropCenteredSubVolume** takes numpy volumes of any spatial size (width x height x depth) and pads or crops them to a specific size (so they can be fed to the network)
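
The last step is the least obvious one, so here is a simplified NumPy sketch of the idea: pad with zeros where the volume is too small, then cut out a centred block. The function `center_crop_or_pad` is a hypothetical helper written for this illustration and is not identical to the `CropCenteredSubVolume` transform defined below.

```python
import numpy as np

def center_crop_or_pad(volume, size):
    # volume has layout (channels, x, y, z); size is the target [x, y, z]
    size = np.asarray(size)
    spatial = np.asarray(volume.shape[1:4])

    # zero-pad every spatial dimension that is smaller than the target
    pad = np.maximum(size - spatial, 0)
    before = pad // 2
    after = pad - before
    pad_width = [(0, 0)] + [(int(b), int(a)) for b, a in zip(before, after)]
    volume = np.pad(volume, pad_width, mode='constant')

    # crop a block of the target size around the centre
    start = (np.asarray(volume.shape[1:4]) - size) // 2
    end = start + size
    return volume[:, start[0]:end[0], start[1]:end[1], start[2]:end[2]]

toy_volume = np.ones((2, 10, 300, 50), dtype=np.float32)
patch = center_crop_or_pad(toy_volume, [8, 256, 96])

print(patch.shape)  # (2, 8, 256, 96)
```

In this toy case the first two spatial dimensions are cropped while the third one is padded from 50 to 96 voxels with zeros.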



In [None]:
class LoadNiftyFromFilename(object):
    def __init__(self, field):
        self.field = field

    def __call__(self, data):
        entries = []

        for entry in data[self.field]:
            entries.append(nib.load(entry))

        data[self.field] = entries
        return data


In [None]:
class ResampleNiftyVolume(object):
    def __init__(self, resolution, field, interpolation='continuous'):
        assert len(resolution) == 3
        self.interpolation = interpolation
        self.resolution = resolution
        self.field = field

    def __call__(self, data):
        entries = []
        spacings = []
        sizes = []

        for entry in data[self.field]:
            current_spacing = entry.header.get_zooms()
            current_shape = entry.header.get_data_shape()

            image_t = resample_img(
                img=entry,
                target_affine=np.diag([self.resolution[0], self.resolution[1], self.resolution[2]]),
                interpolation=self.interpolation
            )

            entries.append(image_t)
            spacings.append(current_spacing)
            sizes.append(current_shape)

        data[self.field] = entries
        data[self.field + '_spacings'] = spacings
        data[self.field + '_sizes'] = sizes

        return data

In [None]:
class NiftyToNumpy(object):
    def __init__(self, field):
        self.field = field

    def __call__(self, data):
        entries = []

        for entry in data[self.field]:
            # get_fdata() returns the voxel data as a floating-point numpy array
            entry_t = entry.get_fdata().astype(np.float32)
            
            if entry_t.ndim < 4:  # label: 3 spatial dimensions, single channel
                # merge all foreground classes into a single class
                entry_t = (entry_t > 0.5).astype(np.float32)  
                
                # add channel dimension (1)
                entry_t = entry_t[np.newaxis]
            elif entry_t.ndim == 4:  # image: 3 spatial dimensions + 1 channel dimension (two channels)
                # move the channel axis first: (x, y, z, c) -> (c, x, y, z)
                entry_t = np.transpose(entry_t, [3, 0, 1, 2])
                
                for i in range(entry_t.shape[0]):
                    # min-max normalize channel-wise (each channel is a different MRI pulse sequence, "color" MRI)
                    entry_t[i] = (entry_t[i] - np.min(entry_t[i])) / (np.max(entry_t[i]) - np.min(entry_t[i]))

            entries.append(entry_t)

        data[self.field] = entries

        return data

In [None]:
# this transformation pads or crops images to make all of them have the same size. 
# it's a complex transform and you don't need to know all the details about it
class CropCenteredSubVolume(object):
    def __init__(self, size, image_field, label_field=None):
        self.size = size
        self.image_field = image_field
        self.label_field = label_field
        
    def pad_to_minimal_size(self, image, pad_mode='constant'):
        pad = self.size - np.asarray(image.shape[1:4]) + 1
        pad[pad < 0] = 0

        pad_before = np.floor(pad / 2.).astype(int)
        pad_after = (pad - pad_before).astype(int)

        pad_vector = [(0, 0)]
        for i in range(image.ndim - 1):
            if i < 3:
                pad_vector.append((pad_before[i], pad_after[i]))
            else:
                pad_vector.append((0, 0))
        image = np.pad(array=image, pad_width=pad_vector, mode=pad_mode)

        return image, pad_before, pad_after

    def __call__(self, data):
        image_entries = []
        label_entries = []

        image_field = self.image_field
        label_field = self.label_field

        for image_entry, label_entry in zip(data[image_field], data[label_field]):
            assert np.all(np.asarray(image_entry.shape[1:4]) == np.asarray(label_entry.shape[1:4]))

            image_entry, pad_before, pad_after = self.pad_to_minimal_size(image_entry, pad_mode='constant')
            label_entry, _, _ = self.pad_to_minimal_size(label_entry, pad_mode='constant')

            h_size = np.floor(np.asarray(self.size) / 2.).astype(int)
            centr_pix = np.floor(np.asarray(image_entry.shape[1:4]) / 2.).astype(int)

            start_px = (centr_pix - h_size).astype(int)

            end_px = (start_px + self.size).astype(int)

            assert np.all(end_px <= np.asarray(image_entry.shape[1:4]))
            assert np.all(start_px >= 0)

            image_patch = image_entry[:, start_px[0]:end_px[0], start_px[1]:end_px[1], start_px[2]:end_px[2]]

            label_patch = label_entry[:, start_px[0]:end_px[0], start_px[1]:end_px[1], start_px[2]:end_px[2]]

            crop_before = start_px
            crop_after = image_entry.shape[1:4] - end_px - 1

            assert np.all(np.asarray(image_patch.shape[1:4]) == self.size)
            assert np.all(np.asarray(label_patch.shape[1:4]) == self.size)

            image_entries.append(image_patch)
            label_entries.append(label_patch)

        data[self.image_field] = image_entries
        data[self.label_field] = label_entries

        return data

## Batching and iterating

In this exercise the batch iterator takes care of the whole data loading/transformation/batching process. It gives us the ability to iterate through the dataset and specify the transformations that need to be applied to the data. All the transformations applied in this exercise are executed upon instantiation of the batch iterator. The dataset, which consists of file names at the beginning, gets converted to 4D (3D + channel) numpy arrays with the format and characteristics required by the network.

We implement this object here. It has an `__iter__` method that allows us to use it as a generator (e.g. we can write `for batch in iterator: ...`).

When executing `for batch in training_batch_iterator:` later in this exercise we will loop through the dataset and obtain batches.

In [None]:
class BatchIterator(object):
    def __init__(
        self,
        batch_size,
        keys,
        data,
        global_transforms,
        shuffle=False
    ):
        self.data = copy.deepcopy(data)
        self.keys = keys
        self.length = len(data)
        self.batch_size = batch_size
        self.global_transforms = []
        self.n_batches = int(np.ceil(len(data) / self.batch_size))
        self.shuffle = shuffle

        self.global_transforms = global_transforms
        
        transform_helper = partial(self.transform_helper, data=self.data, transforms=self.global_transforms)

        with ThreadPool(32) as p:
            p.map(transform_helper, range(len(self.data)))
                
    @staticmethod
    def transform_helper(idx, data, transforms):
        for transform in transforms:
            data[idx] = transform(data[idx])
            
    def __len__(self):
        return self.n_batches

    def __iter__(self):
        data = copy.deepcopy(self.data)

        if self.shuffle:
            shuffle(data)

        for i in range(self.n_batches):
            curr_data = data[i * self.batch_size:np.min([(i+1) * self.batch_size, self.length])]

            # collate batch (from a list of dictionary to a dictionary of lists)
            batch = {}
            for key in self.keys:
                batch[key] = []
                for j in range(len(curr_data)):
                    batch[key].append(curr_data[j][key][0])
                batch[key] = np.stack(batch[key])

            yield batch

## Training/validation parameters
Here we specify all the hyper-parameters and options that define what/how our network will learn.

In [None]:
# DEFINING MODEL PARAMETERS

# size of the image in voxels (3d images)
image_size = [256, 256, 96]
# size of each voxel in millimeters
image_resolution = [1, 1, 1]

# batch size for iterator (how many images get fed to network at each batch)
batch_size = 2
# batch size for learning (number of batches over which gradients are accumulated before each parameter update)
# (Setting the previous parameter to 2 and this to 4 induces a behavior similar to batch size 8) 
effective_batchsize = 1 
# normalization ('none'|'batchnorm'|'groupnorm')
normalization = 'groupnorm'
# number of training epochs
num_epochs = 50
# learning rate
learning_rate = 0.0001

# path of the datasets for training and validation
training_images_path = './Task05_Prostate/imagesTr'
training_labels_path = './Task05_Prostate/labelsTr'

# [todo] change validation path
validation_images_path = './Task05_Prostate/imagesVd'
validation_labels_path = './Task05_Prostate/labelsVd'

## Data loading and transformations
Here we define the sequence of transformations applied to our dataset. First we load our images and labels, then we resample them so that they have the same voxel resolution across cases, then we convert them to numpy format and finally we crop them so that they all have the same size in voxels.

In [None]:
# DEFINING DATA TRANSFORMATION SEQUENCE. TRANSFORMS DEFINED ABOVE

global_transforms = [
    LoadNiftyFromFilename('images'),
    LoadNiftyFromFilename('labels'),
    ResampleNiftyVolume(field='images', resolution=image_resolution, interpolation='continuous'),
    ResampleNiftyVolume(field='labels', resolution=image_resolution, interpolation='nearest'),
    NiftyToNumpy('images'),
    NiftyToNumpy('labels'),
    CropCenteredSubVolume(size=image_size, image_field='images', label_field='labels'),
]

We instantiate here 1) datasets and 2) batch iterators which provide training and validation batches from our dataset to the training routine.

In [None]:
# LOADING DATASET FOR TRAINING AND VALIDATION -- LOADING WILL TAKE TIME -- LOAD ALL DATA IN MEMORY --

training_dataset = make_decathlon_prostate_dataset(training_images_path, training_labels_path)

validation_dataset = make_decathlon_prostate_dataset(validation_images_path, validation_labels_path)

# TRAINING BATCH ITERATOR
train_batch_iterator = BatchIterator(
            batch_size=batch_size,
            keys=['images', 'labels'],
            data=training_dataset,
            global_transforms=global_transforms,
            shuffle=True
)

# VALIDATION BATCH ITERATOR
valid_batch_iterator = BatchIterator(
            batch_size=batch_size,
            keys=['images', 'labels'],
            data=validation_dataset,
            global_transforms=global_transforms,
            shuffle=False
)

## Look at your data
It is important to always inspect the data before deciding which method is appropriate to solve the problem at hand. Looking at the data can mean considering the statistics and distributions underlying the dataset, but in this case we are interested in visually inspecting it to make sure that it has been loaded and transformed correctly.

In [None]:
# visualize data
for data in train_batch_iterator:
    # showing just 1 channel out of two from MRI image
    # label has been thresholded to be class 0 background, class 1 prostate
    plt.imshow(np.squeeze(data['images'][0, 0, :, : , 30] + data['labels'][0, 0, :, : , 30]))  
    plt.show()

#  <span style="color:green">V-Net network architecture</span>


### BatchNorm / GroupNorm

Normalizing the data as it passes through the network provides two main advantages. First, it keeps the values during training centered near 0, which is where activation functions are most nonlinear and hence most sensitive. Second, it introduces some noise, as the renormalizing procedure varies between images and discards a small amount of information. The addition of noise makes the network more robust and less likely to overfit.

However, there are multiple types of normalization available. A conventional choice is BatchNorm. In this method, for a single channel (the output of a particular convolution kernel from a previous layer), the mean and standard deviation of all pixel values in the channel are computed across an entire batch. The mean is subtracted from each pixel value, centering them at zero, and then they are scaled by the standard deviation, giving the new data mean 0 and standard deviation 1. Finally, every value is re-scaled by a single, trainable weight, and shifted by a trainable bias. For example, in a batch of 20 100x100 RGB images, the 20x100x100 = 200,000 red pixel values would be grouped together, as would the 200,000 green and blue values, respectively.

The disadvantage of this method is that at deployment time, when the model is applied to a single image, batch statistics are not available, so a running average of the mean and standard deviation must be accumulated during training and used instead. Similarly, if the model or input is large, the batch size may have to be small, which makes the batch statistics noisy.

A solution to these disadvantages is GroupNorm. The procedure is quite similar, but the statistics are computed for a single image at a time. To retain statistical power, multiple channels are grouped together instead. This avoids the problems with small batches and at deployment. In the example above, with the three RGB channels forming one group, 3x100x100 = 30,000 values would be normalized together for each 100x100 image, regardless of batch size.
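
The difference between the two schemes comes down to the axes over which the mean and standard deviation are computed. The NumPy sketch below reproduces the numbers of the RGB example; it uses random data, treats all three channels as a single group, and leaves out the trainable weight and bias for brevity.

```python
import numpy as np

rng = np.random.default_rng(0)
# batch of 20 RGB images of size 100x100, layout (N, C, H, W)
batch = rng.normal(loc=5.0, scale=2.0, size=(20, 3, 100, 100))

# BatchNorm-style statistics: one mean/std per channel, computed over batch and space
bn_mean = batch.mean(axis=(0, 2, 3), keepdims=True)
bn_std = batch.std(axis=(0, 2, 3), keepdims=True)
batch_normed = (batch - bn_mean) / (bn_std + 1e-5)

# GroupNorm-style statistics (one group of 3 channels): one mean/std per image
gn_mean = batch.mean(axis=(1, 2, 3), keepdims=True)
gn_std = batch.std(axis=(1, 2, 3), keepdims=True)
group_normed = (batch - gn_mean) / (gn_std + 1e-5)

values_per_bn_statistic = batch.size // bn_mean.size
values_per_gn_statistic = batch.size // gn_mean.size

print(values_per_bn_statistic)  # 200000
print(values_per_gn_statistic)  # 30000
```

Note that the GroupNorm-style statistics do not involve the batch axis at all, which is why the result for one image does not depend on the other images in the batch.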

### Dice loss/score formulation
We define here the Dice loss layer, which penalizes segmentations that do not overlap well with the ground truth. As explained before, this overlap measure between two (binary) contours has been generalized and introduced as an objective function for FCNNs in [Ref. 3]. The formulation used in this work is $\mathrm{DICE} = 2 \sum_i \mathrm{Gt}_i \, \mathrm{Pred}_i \, / \, (\sum_i \mathrm{Gt}_i^2 + \sum_i \mathrm{Pred}_i^2)$, where the sums run over all voxels $i$, and the loss is $1 - \mathrm{DICE}$.

We provide here a PyTorch implementation of said Dice loss layer.

In [None]:
EPS = 0.00001

class DiceLoss(Module):
    def forward(self, input, target):
        num = (input * target).sum(dim=4, keepdim=True).sum(dim=3, keepdim=True).sum(dim=2, keepdim=True)
        den1 = input.pow(2).sum(dim=4, keepdim=True).sum(dim=3, keepdim=True).sum(dim=2, keepdim=True)
        den2 = target.pow(2).sum(dim=4, keepdim=True).sum(dim=3, keepdim=True).sum(dim=2, keepdim=True)

        dice = (2.0 * num / (den1 + den2 + EPS))
        
        return (1.0 - dice).mean()

## Defining the network architecture
In this exercise we will use a network **similar** to V-Net, which was introduced in [Ref. 2]. Although we introduced some changes (BatchNorm/GroupNorm, summation for skip connections and a different input size) to make this exercise more up to date and general, the network architecture remains largely the same. A schematic representation of this architecture is shown below.

![V-Net](diagram.png)

### Pytorch implementation
A network inspired by V-Net, with some further improvements, has been implemented here in a slightly different manner than in the original paper.

Unlike the original V-Net, we propose to:
* Allow the user to specify whether batchnorm or groupnorm should be used
* Use a summation strategy for long skip connections instead of concatenation
* Slightly change the network architecture
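
Because the long skip connections are summed, the feature maps on both sides of each connection must have identical spatial sizes. The short sketch below traces the spatial size through the four downsampling blocks of the implementation that follows, assuming the kernel size 2, stride 2 and padding 0 used there; `downsample_size` is a hypothetical helper that applies the standard convolution output size formula.

```python
def downsample_size(size, kernel=2, stride=2, padding=0):
    # standard convolution output size formula, applied to each spatial dimension
    return [(s + 2 * padding - kernel) // stride + 1 for s in size]

image_size = [256, 256, 96]
sizes = [image_size]
for _ in range(4):  # four downsampling blocks in the contracting path
    sizes.append(downsample_size(sizes[-1]))

for level, size in enumerate(sizes, start=1):
    print(level, size)
# 1 [256, 256, 96]
# 2 [128, 128, 48]
# 3 [64, 64, 24]
# 4 [32, 32, 12]
# 5 [16, 16, 6]
```

Every dimension of the input used here is divisible by 2^4 = 16, so each halving is exact and the upsampled tensors in the expanding path line up with the tensors they are added to.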

In [None]:
class GroupNorm3D(Module):
    def __init__(self, num_features, num_groups=16, eps=1e-5):
        super(GroupNorm3D, self).__init__()
        self.weight = Parameter(torch.ones(1, num_features, 1, 1, 1))
        self.bias = Parameter(torch.zeros(1, num_features, 1, 1, 1))
        self.num_groups = num_groups
        self.eps = eps

    def forward(self, x):
        N, C, H, W, D = x.size()
        G = self.num_groups
        assert C % G == 0

        x = x.view(N, G, -1)
        mean = x.mean(-1, keepdim=True)
        var = x.var(-1, keepdim=True)

        x = (x-mean) / (var+self.eps).sqrt()
        x = x.view(N, C, H, W, D)
        return x * self.weight + self.bias
    

class ResidualConvBlock(nn.Module):
    def __init__(self, n_stages, n_filters_in, n_filters_out, normalization='none', expand_chan=False):
        super(ResidualConvBlock, self).__init__()

        self.expand_chan = expand_chan
        if self.expand_chan:
            ops = []

            ops.append(nn.Conv3d(n_filters_in, n_filters_out, 1))

            if normalization == 'batchnorm':
                ops.append(nn.BatchNorm3d(n_filters_out))
            if normalization == 'groupnorm':
                ops.append(GroupNorm3D(n_filters_out))

            ops.append(nn.ReLU(inplace=True))

            self.conv_expan = nn.Sequential(*ops)

        ops = []
        for i in range(n_stages):
            # the first stage maps n_filters_in -> n_filters_out, later stages keep n_filters_out
            stage_in = n_filters_in if i == 0 else n_filters_out
            ops.append(nn.Conv3d(stage_in, n_filters_out, 3, padding=1))

            if normalization == 'batchnorm':
                ops.append(nn.BatchNorm3d(n_filters_out))
            elif normalization == 'groupnorm':
                ops.append(GroupNorm3D(n_filters_out))

            ops.append(nn.ReLU(inplace=True))

        self.conv = nn.Sequential(*ops)

    def forward(self, x):
        if self.expand_chan:
            x = self.conv(x) + self.conv_expan(x)
        else:
            x = (self.conv(x) + x)

        return x


class DownsamplingConvBlock(nn.Module):
    def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'):
        super(DownsamplingConvBlock, self).__init__()

        ops = []
        if normalization != 'none':
            ops.append(nn.Conv3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride))
            if normalization == 'batchnorm':
                ops.append(nn.BatchNorm3d(n_filters_out))
            if normalization == 'groupnorm':
                ops.append(GroupNorm3D(n_filters_out))
        else:
            ops.append(nn.Conv3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride))

        ops.append(nn.ReLU(inplace=True))

        self.conv = nn.Sequential(*ops)

    def forward(self, x):
        x = self.conv(x)
        return x


class UpsamplingDeconvBlock(nn.Module):
    def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'):
        super(UpsamplingDeconvBlock, self).__init__()

        ops = []
        if normalization != 'none':
            ops.append(nn.ConvTranspose3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride))
            if normalization == 'batchnorm':
                ops.append(nn.BatchNorm3d(n_filters_out))
            if normalization == 'groupnorm':
                ops.append(GroupNorm3D(n_filters_out))
        else:
            ops.append(nn.ConvTranspose3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride))

        ops.append(nn.ReLU(inplace=True))

        self.conv = nn.Sequential(*ops)

    def forward(self, x):
        x = self.conv(x)
        return x


class VNet(nn.Module):
    def __init__(self, n_channels, n_classes, n_filters=16, normalization='none'):
        super(VNet, self).__init__()

        if n_channels > 1:
            self.block_one = ResidualConvBlock(1, n_channels, n_filters, normalization=normalization, expand_chan=True)
        else:
            self.block_one = ResidualConvBlock(1, n_channels, n_filters, normalization=normalization)

        self.block_one_dw = DownsamplingConvBlock(n_filters, 2 * n_filters, normalization=normalization)

        self.block_two = ResidualConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
        self.block_two_dw = DownsamplingConvBlock(n_filters * 2, n_filters * 4, normalization=normalization)

        self.block_three = ResidualConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
        self.block_three_dw = DownsamplingConvBlock(n_filters * 4, n_filters * 8, normalization=normalization)

        self.block_four = ResidualConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
        self.block_four_dw = DownsamplingConvBlock(n_filters * 8, n_filters * 16, normalization=normalization)

        self.block_five = ResidualConvBlock(3, n_filters * 16, n_filters * 16, normalization=normalization)
        self.block_five_up = UpsamplingDeconvBlock(n_filters * 16, n_filters * 8, normalization=normalization)

        self.block_six = ResidualConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
        self.block_six_up = UpsamplingDeconvBlock(n_filters * 8, n_filters * 4, normalization=normalization)

        self.block_seven = ResidualConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
        self.block_seven_up = UpsamplingDeconvBlock(n_filters * 4, n_filters * 2, normalization=normalization)

        self.block_eight = ResidualConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
        self.block_eight_up = UpsamplingDeconvBlock(n_filters * 2, n_filters, normalization=normalization)

        self.block_nine = ResidualConvBlock(1, n_filters, n_filters, normalization=normalization)

        self.out_conv = nn.Conv3d(n_filters, n_classes, 1, padding=0)

    def forward(self, input):
        x1 = self.block_one(input)
        x1_dw = self.block_one_dw(x1)

        x2 = self.block_two(x1_dw)
        x2_dw = self.block_two_dw(x2)

        x3 = self.block_three(x2_dw)
        x3_dw = self.block_three_dw(x3)

        x4 = self.block_four(x3_dw)
        x4_dw = self.block_four_dw(x4)

        x5 = self.block_five(x4_dw)

        x5_up = self.block_five_up(x5)
        x5_up = x5_up + x4

        x6 = self.block_six(x5_up)
        x6_up = self.block_six_up(x6)
        x6_up = x6_up + x3

        x7 = self.block_seven(x6_up)
        x7_up = self.block_seven_up(x7)
        x7_up = x7_up + x2

        x8 = self.block_eight(x7_up)
        x8_up = self.block_eight_up(x8)
        x8_up = x8_up + x1

        x9 = self.block_nine(x8_up)

        out = self.out_conv(x9)

        return out
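
The element-wise additions in `forward` (for example `x5_up + x4`) only work when each upsampled tensor has exactly the same spatial size as the encoder feature map it is added to. The sketch below traces the size of one spatial axis through the four downsampling and four upsampling steps. It assumes that both `DownsamplingConvBlock` and `UpsamplingDeconvBlock` use a kernel size and stride of 2 with no padding; the helper names `conv_out`, `deconv_out` and `skip_shapes_match` are hypothetical and not part of the lab code.

```python
def conv_out(size, kernel=2, stride=2, padding=0):
    # Output length of a strided convolution along one axis.
    return (size + 2 * padding - kernel) // stride + 1


def deconv_out(size, kernel=2, stride=2, padding=0):
    # Output length of a transposed convolution along one axis.
    return (size - 1) * stride - 2 * padding + kernel


def skip_shapes_match(size, levels=4):
    # Walk down the encoder, remembering the size at each resolution level.
    encoder_sizes = [size]
    for _ in range(levels):
        encoder_sizes.append(conv_out(encoder_sizes[-1]))
    # Walk back up the decoder and compare with the stored encoder sizes.
    current = encoder_sizes[-1]
    for target in reversed(encoder_sizes[:-1]):
        current = deconv_out(current)
        if current != target:
            return False
    return True


print(skip_shapes_match(64))  # 64 -> 32 -> 16 -> 8 -> 4 and back up again
print(skip_shapes_match(60))  # 60 -> 30 -> 15 -> 7 -> 3, but 3 upsamples to 6
```

Under these assumptions, each spatial dimension of the input volume has to be divisible by 16 for the skip connections to line up.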


In [None]:
net_seg = VNet(n_channels=2, n_classes=1, n_filters=16, normalization=normalization)
net_seg = torch.nn.DataParallel(net_seg)
net_seg.cuda()

## Implementation of training + validation loops
We implement here the training and validation loops. The training loop iterates over the batches drawn from the training set and feeds them to the network. After forward propagation, gradients are backpropagated so that the parameters of the network can be updated. A convenient feature of PyTorch is the ability to accumulate gradients over several batches and update the parameters only once per group of batches. This is particularly useful in situations, such as the current one, where GPU memory constraints make large batch sizes impossible (here we use a batch size of two, which almost fills 16 GB of GPU memory).
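
The accumulation pattern can be sketched on a toy model as follows. This is a minimal illustration rather than the lab's training loop: the linear model, the random data, the MSE loss and the name `accumulation_steps` (playing the role of `effective_batchsize`) are assumptions made for the example. Gradients from several small batches are summed by repeated `backward()` calls, and the optimizer steps only once per group.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

model = nn.Linear(4, 1)                      # toy stand-in for the network
optimizer = optim.SGD(model.parameters(), lr=0.1)
criterion = nn.MSELoss()

accumulation_steps = 4                       # hypothetical effective batch multiplier
batches = [(torch.randn(2, 4), torch.randn(2, 1)) for _ in range(8)]

num_updates = 0
optimizer.zero_grad()
for i, (inputs, targets) in enumerate(batches):
    loss = criterion(model(inputs), targets)
    # Scale so the accumulated gradient is the mean over the group of batches.
    (loss / accumulation_steps).backward()

    if (i + 1) % accumulation_steps == 0:
        optimizer.step()                     # one parameter update per group
        optimizer.zero_grad()                # start accumulating afresh
        num_updates += 1

print(num_updates)  # 8 batches grouped by 4 accumulation steps
```

Dividing the loss by `accumulation_steps` is a design choice: it keeps the gradient magnitude comparable to that of a single large batch, so the learning rate does not need to be retuned when the group size changes.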

The validation loop feeds data to the network and executes only the forward propagation step. With the directive `net_seg.eval()` we instruct PyTorch to put the network in evaluation mode (as opposed to training mode), and with `with torch.no_grad():` we ask PyTorch not to track operations for gradient computation, which saves memory since no gradients are needed during validation.
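
The effect of these two directives can be seen on a small stand-in model. The sketch below is illustrative only: the two-layer model with dropout is an assumption chosen because dropout behaves differently in training and evaluation mode, and it is not the network used in this lab.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5), nn.Linear(4, 1))
inputs = torch.randn(2, 4)

model.train()
train_out = model(inputs)
print(model.training, train_out.requires_grad)   # training mode, graph is recorded

model.eval()
with torch.no_grad():
    eval_out_a = model(inputs)
    eval_out_b = model(inputs)
print(model.training, eval_out_a.requires_grad)  # evaluation mode, no graph

# With dropout disabled, repeated forward passes give identical results.
print(torch.equal(eval_out_a, eval_out_b))

model.train()  # switch back before the next training epoch
```

Note that `eval()` and `no_grad()` are independent: the first changes the behaviour of layers such as dropout and batch normalization, the second only disables gradient tracking.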

Together, the two loops implement training.

In [None]:
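def run_training(net_seg):
    # make sure the network is in training mode (run_validation switches it to eval mode)
    net_seg.train()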
    optimizer_seg = optim.Adam(net_seg.parameters(), lr=learning_rate)
    criterion_seg = DiceLoss()
    
    running_seg_loss = 0.0
        
    for i, data in enumerate(train_batch_iterator, 0):
        # get the inputs
        inputs, labels = torch.from_numpy(data['images']), torch.from_numpy(data['labels'])

        if torch.cuda.is_available():
            inputs, labels = inputs.cuda(), labels.cuda()

        # wrap them in Variable
        inputs, labels = Variable(inputs), Variable(labels)
        
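        if (i % effective_batchsize) == 0:
            # zero the parameter gradients at the start of each accumulation group
            optimizer_seg.zero_grad()

        # forward pass
        outputs_seg = net_seg(inputs)

        # map the network output to per-voxel probabilities
        outputs_seg = torch.sigmoid(outputs_seg)

        loss_seg = criterion_seg(outputs_seg, labels)

        # backward pass: gradients accumulate across the batches of a group
        loss_seg.backward()

        # update the parameters once every effective_batchsize batches
        if ((i + 1) % effective_batchsize) == 0:
            optimizer_seg.step()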

        running_seg_loss += loss_seg.detach().cpu().item()
        
    avg_loss = running_seg_loss / len(train_batch_iterator)
    one_output_seg = outputs_seg.detach().cpu().numpy()[0]
    one_output_img = inputs.detach().cpu().numpy()[0]
        
    return avg_loss, one_output_seg, one_output_img 

In [None]:
def run_validation(net_seg):
    net_seg.eval()
    criterion_seg = DiceLoss()
    running_seg_loss = 0.0

    with torch.no_grad():
        for i, data in enumerate(valid_batch_iterator, 0):
            # get the inputs
            inputs, labels = torch.from_numpy(data['images']), torch.from_numpy((data['labels']))

            if torch.cuda.is_available():
                inputs, labels = inputs.cuda(), labels.cuda()

            # wrap them in Variable
            inputs, labels = Variable(inputs), Variable(labels)

            # forward pass only (no backward pass or optimizer step during validation)
            outputs_seg = net_seg(inputs)

            outputs_seg = torch.nn.Sigmoid()(outputs_seg)

            loss_seg = criterion_seg(outputs_seg, labels)

            running_seg_loss += loss_seg.detach().cpu().item()
            
    avg_loss = running_seg_loss / len(valid_batch_iterator)
    one_output_seg = outputs_seg.detach().cpu().numpy()[0]
    one_output_img = inputs.detach().cpu().numpy()[0]
        
    return avg_loss, one_output_seg, one_output_img 

In [None]:
# We run training + validation by executing this cell
train_losses = []
validation_losses = []

print('STARTING TRAINING')

for i in range(num_epochs):
    train_loss, train_output_seg, train_input_img = run_training(net_seg)
    valid_loss, valid_output_seg, valid_input_img = run_validation(net_seg)
    print('EPOCH {} of {}'.format(i + 1, num_epochs))
    print('-- train loss {} -- valid loss {} --'.format(train_loss, valid_loss))
    
    train_losses.append(train_loss)
    validation_losses.append(valid_loss)
    
    plt.imshow(np.squeeze(train_input_img[0, :, : , 30] + train_output_seg[0, :, : , 30]))  # showing just 1 channel out of two from MRI image
    plt.show()
    plt.imshow(np.squeeze(valid_input_img[0, :, : , 30] + valid_output_seg[0, :, : , 30]))  # label has been thresholded to be class 0 background, class 1 prostate
    plt.show()
    
    plt.plot(range(len(train_losses)), train_losses, 'b', range(len(validation_losses)), validation_losses, 'r')
    red_patch = mpatches.Patch(color='red', label='Validation')
    blue_patch = mpatches.Patch(color='blue', label='Training')
    plt.legend(handles=[red_patch, blue_patch])
    plt.show()
    

## Plotting scores
We now plot the training and validation scores in terms of Dice loss (the lower, the better). More information about the results and further experiments can be found in the original paper [Ref. 1].
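
To make "the lower, the better" concrete, the sketch below evaluates the Dice formulation from the overview, 2 * sum(Gt * Pred) / (sum(Gt^2) + sum(Pred^2)), on tiny flattened masks. It assumes that the loss is reported as 1 - Dice, which is a common convention but may differ from the `DiceLoss` class used in this lab; the function name `dice_score` and the `eps` term are hypothetical.

```python
def dice_score(gt, pred, eps=1e-7):
    # Soft Dice over flattened voxels: 2 * sum(gt * pred) / (sum(gt^2) + sum(pred^2)).
    intersection = sum(g * p for g, p in zip(gt, pred))
    denominator = sum(g * g for g in gt) + sum(p * p for p in pred)
    return 2.0 * intersection / (denominator + eps)


gt = [0, 1, 1, 0]
perfect = [0, 1, 1, 0]   # prediction identical to the ground truth
half = [0, 1, 0, 0]      # prediction covers half of the structure
disjoint = [1, 0, 0, 1]  # prediction misses the structure entirely

for name, pred in [('perfect', perfect), ('half', half), ('disjoint', disjoint)]:
    score = dice_score(gt, pred)
    # Print the overlap score and the corresponding loss under the 1 - Dice assumption.
    print(name, round(score, 3), round(1.0 - score, 3))
```

A perfect overlap gives a score close to 1 and therefore a loss close to 0, while a prediction with no overlap gives a score of 0 and a loss of 1.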

In [None]:
plt.plot(range(len(train_losses)), train_losses, 'b', range(len(validation_losses)), validation_losses, 'r')
red_patch = mpatches.Patch(color='red', label='Validation')
blue_patch = mpatches.Patch(color='blue', label='Training')
plt.legend(handles=[red_patch, blue_patch])
plt.show()

## Try your own experiment!
What happens when you change the normalization strategy (set when you declare `net_seg`) to `batchnorm` or `none`? What about other changes, such as the learning rate? What happens when you increase the effective batch size by varying the `effective_batchsize` parameter?

## <span style="color:green">Conclusions</span>
In this lab you have seen how to implement a network architecture **similar** to V-Net. V-Net is one of the first works proposing 3D segmentation and is the first work that introduced Dice loss as an objective function that can be optimized to deliver superior segmentation performance. Dice loss has become a widely adopted objective function for medical image segmentation and has been used by hundreds of groups worldwide. Network architectures derived from V-Net, and further improving on it, have also been developed and applied to a multitude of tasks and a wide range of anatomies.

In this lab we have shown the details of how to achieve an improved implementation of V-Net by using modern DL software and advanced Python design constructs. By trying your own experiments you should also be able to further explore the effect of specific design choices on the outcome of the experiments and the performance of the network.

Despite the slow training procedure, due to the high computational load of the task presented in this lab, the network starts showing signs of convergence within 50 epochs. In reality, such a network would need to be trained for much longer periods of time, amounting to days or even weeks (depending on the dataset) on modern GPUs. 

## References
* Ref. 1: *Milletari, F., Navab, N. and Ahmadi, S.A., 2016, October. V-net: Fully convolutional neural networks for volumetric medical image segmentation. In 3D Vision (3DV), 2016 Fourth International Conference on (pp. 565-571). IEEE.*
* Ref. 2: *Ronneberger, O., Fischer, P. and Brox, T., 2015, October. U-net: Convolutional networks for biomedical image segmentation. In International Conference on Medical image computing and computer-assisted intervention (pp. 234-241). Springer, Cham.*
* Ref. 3: *Ioffe, S. and Szegedy, C., 2015. Batch normalization: Accelerating deep network training by reducing internal covariate shift. arXiv preprint arXiv:1502.03167.*
* Ref. 4: *Drozdzal, M., Vorontsov, E., Chartrand, G., Kadoury, S. and Pal, C., 2016. The importance of skip connections in biomedical image segmentation. In Deep Learning and Data Labeling for Medical Applications (pp. 179-187). Springer, Cham.*
* Ref. 5: *Milletari, F., 2018. Hough Voting Strategies for Segmentation, Detection and Tracking. Diss. Universität München.*
