# Performing medical imaging segmentation like a pro (Part 1)
*A highly opinionated and biased tutorial on MRI lesion segmentation using PyTorch*

**Sergi Valverde, PhD**
*Universitat de Girona, Spain*



# Introduction:
---


*Deep learning* techniques have been applied to a wide range of computer vision and medical imaging tasks such as image registration, classification and segmentation, where they often outperform previous state-of-the-art methods. In particular, *U-Net*-like architectures are nowadays the *de facto* standard for most medical image segmentation tasks.

The goal of this tutorial is to introduce you to these new techniques. To do so, we will use MRI brain lesion segmentation as a context. Being able to hack and modify these models for new problems is a valuable skill that you may want to leverage throughout your research or professional career.

The tutorial is divided into two parts. In the first part, I will introduce you to the awesome [PyTorch](http://pytorch.org) library, one of the most commonly used libraries for *deep learning* research. We will cover the basic concepts underlying the library. As you will see, although PyTorch is a low-level library, it has a very *pythonic* and easy-to-use syntax, which lets us modify our models extensively and try out new ideas very quickly.

In the second part, we will implement the *U-Net* model and apply it to the MRI white matter lesion segmentation problem. We will cover the entire training and inference procedures, showing some tricks to learn better and faster models. Finally, I will introduce some of the latest techniques that have been proposed in the context of medical image segmentation, showing how easy it is to incorporate them into our models.


# Why PyTorch?

---

PyTorch is a strong player in the field of deep learning and artificial intelligence, and it can be considered primarily as a research-first library. Some reasons to use PyTorch: 

* PyTorch is Pythonic (covered)
* PyTorch is a low-level library but easy to hack (covered)
* PyTorch is easy to debug (covered)
* Data parallelism is straightforward (not covered)
* Dynamic computational graph support (not covered)

## PyTorch is Pythonic:

The PyTorch syntax for tensor operations is very similar to plain Python and `numpy` code. This makes the code very readable and hackable. Compare the following code to compute the element-wise product of two matrices in both `numpy` and PyTorch:

In [1]:
import numpy as np
A = np.ones((4,4))
B = np.ones((4,4)) * 2
C = A * B
print(C)

import torch
A = torch.ones((4,4))
B = torch.ones((4,4)) * 2
C = A * B
print(C)

[[2. 2. 2. 2.]
 [2. 2. 2. 2.]
 [2. 2. 2. 2.]
 [2. 2. 2. 2.]]
tensor([[2., 2., 2., 2.],
        [2., 2., 2., 2.],
        [2., 2., 2., 2.],
        [2., 2., 2., 2.]])
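
Note that the `*` operator multiplies the two matrices entry by entry. As a small sketch of the difference from a true matrix product, the snippet below uses the `@` operator, which both libraries support; the variable names `elementwise` and `matmul` are only illustrative:

```python
import numpy as np
import torch

A = torch.ones((4, 4))
B = torch.ones((4, 4)) * 2

elementwise = A * B   # each entry is 1 * 2 = 2
matmul = A @ B        # each entry sums four 1 * 2 products = 8

print(elementwise[0, 0].item(), matmul[0, 0].item())

# NumPy follows the same convention for both operators
A_np = np.ones((4, 4))
B_np = np.ones((4, 4)) * 2
print((A_np * B_np)[0, 0], (A_np @ B_np)[0, 0])
```

In other words, the cell above would need `A @ B` (or `torch.matmul(A, B)`) to compute a matrix product in the linear algebra sense.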


In fact, `torch` tensors can be converted into `numpy` arrays very easily using the NumPy bridge. The `torch` tensor and the `numpy` array will share the same underlying memory (as long as the tensor lives on the CPU), so changing one also changes the other:

In [4]:
a = torch.ones(5)
b = a.numpy()
print('torch tensor:', a)
print('numpy array:', b)

# add 1 in place to the torch tensor (the trailing underscore means in-place)
a.add_(1)
print('torch tensor:', a)
print('numpy array:', b)

torch tensor: tensor([1., 1., 1., 1., 1.])
numpy array: [1. 1. 1. 1. 1.]
torch tensor: tensor([2., 2., 2., 2., 2.])
numpy array: [2. 2. 2. 2. 2.]


Conversely, `numpy` arrays can also be converted to `torch` tensors, again sharing the same underlying memory:

In [5]:
a = np.ones(5)
b = torch.from_numpy(a)
np.add(a, 1, out=a)
print('numpy array:', a)
print('torch tensor:', b)

numpy array: [2. 2. 2. 2. 2.]
torch tensor: tensor([2., 2., 2., 2., 2.], dtype=torch.float64)
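
Sharing memory is not always what we want. As a minimal sketch, the snippet below contrasts `torch.from_numpy`, which shares memory with the array, with `torch.tensor`, which copies the data; the names `shared` and `copied` are only illustrative:

```python
import numpy as np
import torch

a = np.ones(3)
shared = torch.from_numpy(a)  # shares memory with `a`
copied = torch.tensor(a)      # makes an independent copy of the data

a[0] = 5.0  # modify the NumPy array in place

print('shared:', shared)  # reflects the change made to `a`
print('copied:', copied)  # still contains only ones
```

Use the copying constructor whenever later in-place changes to the array should not leak into the tensor.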


# PyTorch is low-level but easy to hack:
---

In contrast to other libraries like [Keras](https://keras.io), in PyTorch most of our codebase has to be built from scratch. However, this is more a feature than a drawback in most situations where we need more control over the task at hand. Given the pythonic syntax, it is very easy to move on and build complex models.

For instance, moving tensors between CPU and GPU devices is straightforward in `torch`:

In [6]:
x = torch.ones((4,4))
# use the first GPU when available, otherwise fall back to the CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
y = torch.ones_like(x, device=device)
x = x.to(device)
z = x + y # sum is performed on the GPU when one is available!
print(z)
print(z.cpu())

tensor([[2., 2., 2., 2.],
        [2., 2., 2., 2.],
        [2., 2., 2., 2.],
        [2., 2., 2., 2.]], device='cuda:0')
tensor([[2., 2., 2., 2.],
        [2., 2., 2., 2.],
        [2., 2., 2., 2.],
        [2., 2., 2., 2.]])


## Autograd:

The autograd package is the central element in all neural networks in PyTorch. The `autograd` package provides automatic differentiation for all the operations on tensors. It is a define-by-run framework, which means that your backprop is defined by how your code is run, and that every single iteration can be different (dynamic computational graph support). 

Every `torch.Tensor` has an attribute called `.requires_grad`, which controls whether the operations applied to that tensor are tracked during computation. Once the computation is finished, calling `.backward()` on the result computes all the gradients automatically. The gradient for each tracked `torch.Tensor` is accumulated in its `.grad` attribute.

`torch.Tensor` and `Function` (`+`, `-`, `torch.mul`, ...) are interconnected and build up an acyclic graph that encodes the complete history of the computation. Each `torch.Tensor` produced by an operation has a `.grad_fn` attribute that references the `Function` that created it. Let's see an example:



In [7]:
x = torch.ones(2, 2, requires_grad=True) # requires_grad is False by default
print(x)
y = x + 2
print(y) # see how grad_fn contains the '+' function
z = y * y * 3
out = z.mean()
print(out)

tensor([[1., 1.],
        [1., 1.]], requires_grad=True)
tensor([[3., 3.],
        [3., 3.]], grad_fn=<AddBackward0>)
tensor(27., grad_fn=<MeanBackward0>)
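
To make the tracking rules more concrete, here is a small sketch of which tensors end up in the graph. It only relies on standard tensor attributes (`is_leaf`, `requires_grad`, `grad_fn`), `torch.no_grad()` and `detach()`; the names `plain`, `w` and `d` are illustrative:

```python
import torch

plain = torch.ones(2, 2)
print(plain.requires_grad)          # tensors are not tracked unless requested

x = torch.ones(2, 2, requires_grad=True)
y = x + 2

print(x.is_leaf, x.grad_fn)         # created by the user: a leaf without grad_fn
print(y.is_leaf, y.requires_grad)   # created by an operation: tracked, not a leaf

# operations inside torch.no_grad() are not recorded in the graph
with torch.no_grad():
    w = x + 2
print(w.requires_grad, w.grad_fn)

# detach() returns a tensor with the same data that is cut off from the graph
d = y.detach()
print(d.requires_grad)
```

Both `torch.no_grad()` and `detach()` are handy at inference time, when the history of the computation is not needed.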


## Gradients:

This is where backprop is applied. We can easily compute the gradients of `out` with `out.backward()`, and then inspect $\dfrac{d(out)}{dx}$:

In [8]:
out.backward()
print(x.grad)

tensor([[4.5000, 4.5000],
        [4.5000, 4.5000]])


Let's check that the result makes sense. We can write `out` as $out=\dfrac{1}{4}\sum_i z_i$ with $z_i = 3(x_i + 2)^2$, so given that $x_i = 1$ we get $out = 27$. Therefore $\dfrac{\partial out}{\partial x_i} = \dfrac{3}{2}(x_i + 2)$ and $\dfrac{\partial out}{\partial x_i} \Big|_{x_i = 1} = \dfrac{9}{2} = 4.5$.
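
The same check can be done in code. The sketch below recomputes the example and compares the gradient returned by `autograd` with the closed-form expression derived above; `analytic` is just an illustrative name:

```python
import torch

x = torch.ones(2, 2, requires_grad=True)
out = (3 * (x + 2) ** 2).mean()
out.backward()

# closed-form gradient from the derivation: (3/2) * (x_i + 2)
analytic = 1.5 * (x.detach() + 2)

print(x.grad)
print(torch.allclose(x.grad, analytic))
```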

## Our first CNN in PyTorch:

Let's build our first Convolutional Neural Network (CNN) from scratch. To be fair to history, let's build the `LeNet` network proposed by LeCun in 1995 for digit recognition: 

![LeNet](media/mnist.png)

Let's define the network first. In PyTorch, any building block of a model (a loss function, a layer or the network itself) is defined by subclassing `nn.Module`, and has to implement at least a `forward` method:

In [10]:
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        # 1 input image channel, 6 output channels, 3x3 square convolution
        # kernel
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        # an affine operation: y = Wx + b
        self.fc1 = nn.Linear(16 * 6 * 6, 120)  # 6*6 from image dimension
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        # Max pooling over a (2, 2) window
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        # If the window is square you can specify just a single number
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def num_flat_features(self, x):
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features




As you can see, we have introduced some common neural network layers and activations, all included in the `torch.nn` and `torch.nn.functional` modules:

* `nn.Conv2d` layer: 2D convolutional layer
* `nn.Linear` layer: fully connected (affine) layer
* `F.max_pool2d` layer: 2D max pooling
* `F.relu` activation: non-linear ReLU activation
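
To see where the `16 * 6 * 6` input size of the first linear layer comes from, the following sketch pushes a random 32x32 image through the same convolution and pooling layers and prints the shape after each step. The 32x32 input size is an assumption taken from the comment in the network definition:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

conv1 = nn.Conv2d(1, 6, 3)
conv2 = nn.Conv2d(6, 16, 3)

x = torch.randn(1, 1, 32, 32)   # (batch, channels, height, width)

x = F.relu(conv1(x))            # 3x3 conv without padding: 32 -> 30
print(x.shape)
x = F.max_pool2d(x, 2)          # 2x2 pooling: 30 -> 15
print(x.shape)
x = F.relu(conv2(x))            # 3x3 conv without padding: 15 -> 13
print(x.shape)
x = F.max_pool2d(x, 2)          # 2x2 pooling rounds down: 13 -> 6
print(x.shape)

flat = x.view(x.size(0), -1)    # flatten everything except the batch dimension
print(flat.shape)
```

The flattened tensor has 16 channels of 6x6 features, that is 576 values per image, which matches the input size of `fc1`.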

Now, let's initialize the neural network we have created: 

In [11]:
net = Net()
params = list(net.parameters())
num_params = sum([np.prod(p.size()) for p in params])
print('Neural net with', num_params, 'parameters')

Neural net with 81194 parameters
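
As a sanity check, the total can be broken down per layer. The sketch below rebuilds the same five layers in a plain dictionary (an illustrative stand-in for the `Net` class) and counts weights and biases with `numel()`:

```python
import torch.nn as nn

# same layer sizes as in the network defined above
layers = {
    'conv1': nn.Conv2d(1, 6, 3),          # 1*6*3*3 weights + 6 biases
    'conv2': nn.Conv2d(6, 16, 3),         # 6*16*3*3 weights + 16 biases
    'fc1': nn.Linear(16 * 6 * 6, 120),    # 576*120 weights + 120 biases
    'fc2': nn.Linear(120, 84),            # 120*84 weights + 84 biases
    'fc3': nn.Linear(84, 10),             # 84*10 weights + 10 biases
}

counts = {name: sum(p.numel() for p in layer.parameters())
          for name, layer in layers.items()}

for name, count in counts.items():
    print(name, count)

total = sum(counts.values())
print('total', total)
```

Most of the parameters sit in the first fully connected layer, which is typical for this kind of architecture.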


### Loss function

Let's try a random 32x32 input and a random target to emulate training on a mini-batch. We estimate how far we are from the target using a loss function from the `nn` package:

In [12]:
# input and target
input = torch.randn(1, 1, 32,32)
target = torch.randn(10) # dummy probs for each label

# network forward pass
output = net(input)

# Mean squared error (MSE) criterion loss
criterion = nn.MSELoss()
loss = criterion(output, target.view(1, -1))
print(loss)

tensor(1.2209, grad_fn=<MseLossBackward>)
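
For reference, `nn.MSELoss` with its default settings is simply the mean of the squared differences. The sketch below checks this by hand, using random tensors as a stand-in for the network output and the target:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
output = torch.randn(1, 10)   # stand-in for the network output
target = torch.randn(1, 10)   # dummy target with the same shape

loss_module = nn.MSELoss()(output, target)
loss_manual = ((output - target) ** 2).mean()

print(loss_module.item(), loss_manual.item())
print(torch.allclose(loss_module, loss_manual))
```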


### Backprop:

Finally, the last thing to do is to propagate the error with `loss.backward()`. Keep in mind that gradients accumulate across calls to `.backward()`, so we need to clear the existing gradients for each mini-batch.

In [13]:
# clear gradients
net.zero_grad()

# inspect the gradient of the conv1 bias, for instance
print('before backward:', net.conv1.bias.grad)

# forward pass
output = net(input)

# compute the loss and backprop the error
loss = criterion(output, target.view(1, -1))
loss.backward()
print('after backward:', net.conv1.bias.grad)




before backward: None
after backward: tensor([ 0.0036, -0.0092, -0.0036,  0.0004, -0.0027, -0.0021])
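
The reason for clearing the gradients is that every call to `.backward()` adds to the values already stored in `.grad`. A minimal sketch with a single illustrative tensor `w` instead of a full network:

```python
import torch

w = torch.ones(3, requires_grad=True)

(2 * w).sum().backward()
first = w.grad.clone()        # gradient of sum(2 * w) is 2 for every entry

(2 * w).sum().backward()      # no clearing: the new gradient is added on top
accumulated = w.grad.clone()

w.grad.zero_()                # reset the stored gradient by hand
cleared = w.grad.clone()

print(first, accumulated, cleared)
```

Without the reset, each mini-batch would be updated with the sum of all the gradients seen so far.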


### Update the weights of the model:

So far, we have not updated the weights of the model. As you may know, the most widely used technique is Stochastic Gradient Descent (`SGD`). The update rule is defined as:

$$ \text{weight} = \text{weight} - \text{learning rate} \times \text{gradient} $$

So we could manually update all the model weights as follows:

In [14]:
learning_rate = 0.01
with torch.no_grad():  # the update itself must not be tracked by autograd
    for f in net.parameters():
        f.sub_(f.grad * learning_rate)

However, PyTorch incorporates the package `torch.optim`, which implements most of the state-of-the-art optimizers available today. Using `torch.optim`, we can finally define the entire training step for a particular mini-batch of data:



In [15]:
import torch.optim as optim

optimizer = optim.SGD(net.parameters(), lr=0.01)

# clear gradients
net.zero_grad()

# forward pass
output = net(input)

# compute the loss and backprop the error
loss = criterion(output, target.view(1, -1))
loss.backward()

# update the weights of the network
optimizer.step()
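
In practice, this step is repeated many times. The following sketch wraps it in a loop, under a few assumptions that are not part of the example above: a tiny `nn.Linear` model stands in for `net`, the data is a single fixed random mini-batch, and the learning rate and number of steps are arbitrary:

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

model = nn.Linear(4, 1)          # hypothetical stand-in for `net`
inputs = torch.randn(16, 4)      # one fixed random mini-batch
targets = torch.randn(16, 1)

criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.05)

losses = []
for step in range(50):
    optimizer.zero_grad()                      # clear gradients
    loss = criterion(model(inputs), targets)   # forward pass and loss
    loss.backward()                            # backprop the error
    optimizer.step()                           # update the weights
    losses.append(loss.item())

print(f'first loss: {losses[0]:.4f}, last loss: {losses[-1]:.4f}')
```

With real data, the fixed mini-batch would be replaced by batches drawn from a dataset, but the four calls inside the loop stay the same.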

