<table>
<tr height="150px">
<th>
<img height="80px" margin="20px" src='https://www.archimedesai.gr/images/logo_en.svg' />
</th>
<th>
<img height="150px" src='https://stergioc.github.io/assets/img/logos.png' />
</th>
</tr>
</table>

<h1>Introduction to Deep Learning (Hands-on Tutorial)</h1>
<h3>Maria Vakalopoulou & Stergios CHRISTODOULIDIS</h3>

[![ML-tutorial](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/stergioc/BioMed-AI-Summer-School/blob/master/DL/dl-tutorial.ipynb)

In this tutorial, we will code a neural network to perform segmentation of tumours in brain images. The data are retrieved from the [BraTS](https://www.med.upenn.edu/cbica/brats2019/data.html) dataset and have already been lightly preprocessed for this tutorial.

# **1. Imports**

In [None]:
import torch
import torch.nn as nn
import torch.utils.data
import torch.utils.tensorboard
import torchvision
import torchvision.transforms
import os
import time
import numpy as np
import skimage.transform
import matplotlib.pyplot as plt
import pandas as pd
import sklearn.model_selection as model_selection
from tqdm import tqdm
from glob import glob
from IPython.display import clear_output

!pip install SimpleITK
import SimpleITK as sitk

!pip install tensorboard

# **2. Database**



## 2.a. Download the database

Run the following cell to download the database.

In [None]:
# A large portion of the BraTS data (~2 GB) will be downloaded with
# this command; it should take about 3 minutes.
!wget https://nextcloud.centralesupelec.fr/s/YXd3S8sAYbjaz2Z/download/dl-tutorial-data.zip
!unzip -q dl-tutorial-data.zip -d GEP1

## 2.b. Overview of the database

The data is stored in the folder `./GEP1/`.

In that folder you will find:
> The `origin_data` folder: contains 4 patients.
>- For each patient, you will find a whole volume per modality in nifti format.
>- Each volume has a shape `(155, 192, 192)`

> The `data` folder: contains 336 patients.
>- Images in that folder have undergone some preprocessing and now have a shape `(78, 96, 96)`. This makes training deep learning models faster and less memory-hungry.
>- For a patient `BraTS19_EXAMPLE`, you will have a corresponding folder `./GEP1/data/BraTS19_EXAMPLE`. Inside this folder, you will have nifti files (`.nii.gz`) for each modality and each z slice.
>- There are 78 slices per volume, i.e. 78 slices per modality.



In [None]:
# The data is stored in the folder ./GEP1/
data_path = './GEP1/data/'
original_data_path = './GEP1/origin_data/'

### **i. Inside of the database**

In [None]:
# Original data
print("Content of the folder:\n", os.listdir(original_data_path))
patient = "BraTS19_TCIA01_131_1"
print("Content of a patient's folder:\n", os.listdir(original_data_path + patient))

In [None]:
# Processed data
files = os.listdir(data_path) # All the files in the folder /GEP1/data/
print('Content of the folder {} \n: {}'.format(data_path, files[:5]))
# print('Number of files for each patient : {}'.format(len(os.listdir(data_path + files[0])) ))
print('Number of patients in {} : {}'.format(data_path, len(files)))

### **ii. Patient's folder**

For each patient, we have 4 modalities and the segmentation:

- t1
- t2
- flair
- t1ce (gado)
- segmentation

For each modality, you also have a file for each slice along the Z axis.

Each patient has **78 slices** per modality.
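
As a quick sanity check on these numbers, the sketch below builds the slice file names expected for one patient. It assumes the naming pattern `{patient}_{modality}_z_{z}.nii.gz` used elsewhere in this notebook and 0-based slice indices; `expected_slice_files` is a hypothetical helper, not part of the tutorial code.

```python
def expected_slice_files(patient, modalities, n_slices=78):
    """List the per-slice file names expected in one patient's folder.

    Assumes 0-based slice indices and the naming pattern
    {patient}_{modality}_z_{z}.nii.gz (an assumption of this sketch).
    """
    return [
        f"{patient}_{modality}_z_{z}.nii.gz"
        for modality in modalities
        for z in range(n_slices)
    ]


names = expected_slice_files(
    "BraTS19_EXAMPLE", ["t1", "t2", "t1ce", "flair", "seg"]
)
print(len(names))  # 5 entries (4 modalities + segmentation) x 78 slices
print(names[0])
```

Comparing such a list with the output of `os.listdir` on a patient folder is one way to spot missing slices.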

In [None]:
modalities = ['t1', 't2', 't1ce', 'flair', 'seg']

In [None]:
patient = 'BraTS19_2013_20_1'
patient_path = os.path.join(data_path, patient)
patient_files = os.listdir(patient_path)
patient_files[:10]

In [None]:
"""
Filter for the Flair modality
"""

flair_modality_files = sorted([e for e in patient_files if 'flair' in e])
print("Number of Z slices:", len(flair_modality_files))
flair_modality_files[-5:]

### **iii. SimpleITK tutorial**

Use the `SimpleITK` Python package in order to read the nifti files of the database.

To open a nifti image:

        image = sitk.ReadImage(image_path)

Using this package, you can access relevant and physical information about the image:
- spacing: `image.GetSpacing()`
- direction: `image.GetDirection()`
- origin: `image.GetOrigin()`
- size: `image.GetSize()`
- metadata: `image.GetMetaDataKeys()`
- access the value of a pixel: `image.GetPixel(pixel_x, pixel_y, pixel_z)`

You can also convert the `sitk` image into a `numpy` array:

        array = sitk.GetArrayFromImage(image)

In [None]:
### START CODE HERE ###
#img = ...
#array = ...
### END CODE HERE ###

### **iv. Comparison between original data and preprocessed data.**

In order to accelerate the calculation time and have good results quickly, the original images of the dataset of shape `(155, 240, 240)` have been preprocessed according to the following steps:
- Cropped the images to a shape of `(155, 192, 192)`
- Downsampled the images by interpolation of scale 0.5 (https://scikit-image.org/docs/dev/auto_examples/transform/plot_rescale.html) to a shape of `(78, 96, 96)`
- Saved all the Z slices **independently** in a new array of shape `(96, 96)`
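
The shape arithmetic behind these steps can be checked on a dummy volume. The sketch below is an approximation under stated assumptions: it uses a centered crop (the actual crop offsets are not given here) and plain stride-2 subsampling instead of the interpolation-based rescaling used for the real data, so only the resulting shapes are meant to match.

```python
import numpy as np

# Dummy volume with the original BraTS shape (z, y, x)
volume = np.zeros((155, 240, 240), dtype=np.float32)

# 1. Crop in-plane from 240 to 192 (centered crop: an assumption)
crop = 192
start = (volume.shape[1] - crop) // 2
cropped = volume[:, start:start + crop, start:start + crop]

# 2. Halve every axis. Stride-2 subsampling stands in for the
#    interpolation-based rescaling applied to the real data.
downsampled = cropped[::2, ::2, ::2]

# 3. Each Z slice is then handled as an independent 2D image
slices = [downsampled[z] for z in range(downsampled.shape[0])]

print(cropped.shape)      # (155, 192, 192)
print(downsampled.shape)  # (78, 96, 96)
print(slices[0].shape)    # (96, 96)
```

With a factor of 2, slice `z` of the processed volume roughly corresponds to slice `2 * z` of the original one, which is why the comparison cell below indexes the original array with `z_slice*2`.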



In [None]:
patient = 'BraTS19_CBICA_ANP_1'

# Define the image path in the original data
z = 3
modality = 'flair'
patient_folder = os.path.join(original_data_path, patient)
image_name = "{patient}_{modality}.nii.gz".format(patient=patient, modality=modality)
image_path = os.path.join(patient_folder, image_name)

# We use the SimpleITK library (sitk) to open the nifti images
image = sitk.ReadImage(image_path)
orig_array = sitk.GetArrayFromImage(image)
print('Original array shape : {}'.format(orig_array.shape))

# open corresponding preprocessed data slice
patient_folder = os.path.join(data_path, patient)
z_slice = 35
path = os.path.join(patient_folder, "{patient}_{modality}_z_{z_slice}.nii.gz".format(patient=patient, modality=modality, z_slice=z_slice))
image = sitk.ReadImage(path)
processed_array = sitk.GetArrayFromImage(image)
print('Processed array shape : {}'.format(processed_array.shape))

plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.imshow(orig_array[z_slice*2, :, :], cmap='gray')
plt.title('Original array')
plt.subplot(1, 2, 2)
plt.imshow(processed_array, cmap='gray')
plt.title('Processed array')
plt.show()

### **v. Visualize all modalities**

Consider a patient in the `data` folder. Plot each modality side by side, by iterating over the number of slices.

In [None]:
# Choose the Z slice to display (change this value to browse the volume)
z = 46

f, axes = plt.subplots(1, 5, figsize=(20, 6))
# Plot each modality for that slice
for i, modality in enumerate(modalities):
    colormap = None if i==4 else 'gray'
    # Fetch the modality-slice file and open using SimpleITK
    file_path = data_path + f"{patient}/{patient}_{modality}_z_{z}.nii.gz"
    # Named slice_array to avoid shadowing the built-in `slice`
    slice_array = sitk.GetArrayFromImage(sitk.ReadImage(file_path))

    # Plot the slice
    axes[i].set_title(modality)
    axes[i].imshow(slice_array, cmap=colormap)

plt.suptitle("Slice {}/{}".format(z+1, len(flair_modality_files)), y=0.85)
plt.show()

## 2.c. Creating the train, validation and test sets


The train, validation and test splits are stored in the folder `./GEP1/datasets`. For each split, there is a text file listing its patients.

Execute the following code to:
*   Load the train, validation and test set.
*   Print the first 5 patients of the train set.
*   Print the length of the train, validation and test set.


In [None]:
datasets_path = './GEP1/datasets/'

train_set = np.loadtxt(datasets_path + 'train.txt', dtype=str)
validation_set = np.loadtxt(datasets_path + 'val.txt', dtype=str)
test_set = np.loadtxt(datasets_path + 'test.txt', dtype=str)

# train_set, validation_set and test_set are lists of patients
print('Train set, first 5 patients : {}\n'.format(train_set[:5])) # Print the first 5 patients of train_set
print('Train set length :\t {}'.format(len(train_set)))
print('Validation set length :\t {}'.format(len(validation_set)))
print('Test set length :\t {}'.format(len(test_set)))

# **3. Creation of the neural network**

In this part, we will implement and train a [**UNet**](https://arxiv.org/pdf/1505.04597.pdf). UNets are widely used for segmentation tasks in medical imaging. You can study its architecture in the following figure.

The UNet has two main features:

1.   In the **encoder part** of the model, each block halves the spatial size of its input with a `MaxPooling` layer. In the **decoder part**, the extracted features are progressively upsampled using a transposed convolution (`ConvTranspose2d` in PyTorch).

2.   To preserve high-resolution information, **skip-connections** pass feature maps from the encoder part of the network to the decoder part.

![](https://lmb.informatik.uni-freiburg.de/people/ronneber/u-net/u-net-architecture.png)
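
To make these two features concrete, the following sketch tracks tensor shapes through one pooling step, one transposed convolution and one skip-connection concatenation. The channel counts and tensor names are illustrative assumptions, not the layers of the network built in this tutorial.

```python
import torch
import torch.nn as nn

# Illustrative encoder feature map: (batch, channels, height, width)
encoder_features = torch.rand(1, 64, 96, 96)

# Encoder side: max pooling halves the spatial size
pooled = nn.MaxPool2d(2)(encoder_features)

# Decoder side: a transposed convolution doubles it again
upsample = nn.ConvTranspose2d(64, 64, kernel_size=3, stride=2,
                              padding=1, output_padding=1)
upsampled = upsample(pooled)

# Skip connection: concatenate encoder and decoder features
# along the channel dimension (dim=1)
merged = torch.cat([upsampled, encoder_features], dim=1)

print(pooled.shape)     # torch.Size([1, 64, 48, 48])
print(upsampled.shape)  # torch.Size([1, 64, 96, 96])
print(merged.shape)     # torch.Size([1, 128, 96, 96])
```

Concatenation adds the channel counts, which is why the convolution that follows a skip connection must accept more input channels than the decoder features alone provide.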



## 3.a. Create the network


### Exercise 1 - Implement the UNet Building blocks

>Complete the following class `ConvBatchNorm`. You should use:
>- [Conv2d](https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html)
>- [BatchNorm2d](https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html)

In [None]:
class ConvBatchNorm(nn.Module):
    """This block implements the sequence: (convolution => [BN] => ReLU)"""
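    def __init__(self, in_channels, out_channels):
        # Initialize nn.Module so that the layers assigned below are registered
        super(ConvBatchNorm, self).__init__()
        ### START CODE HERE ###
        #self.conv = ...
        #self.norm = ...
        #self.activation = ...
        ### END CODE HERE ###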

    def forward(self, x):
        out = self.conv(x)
        out = self.norm(out)
        out = self.activation(out)
        return out


What does the function `_make_nConv` do?

In [None]:
def _make_nConv(in_channels, out_channels, nb_Conv):
    layers = []
    layers.append(ConvBatchNorm(in_channels, out_channels))
    for _ in range(nb_Conv-1):
        layers.append(ConvBatchNorm(out_channels, out_channels))
    return nn.Sequential(*layers)

>Complete the classes:
>- `ConvBatchNorm`
>- `DownBlock`: these will be used to build the **encoder** part of the UNet.
>- `UpBlock`: these are used to build the **decoder** part of the UNet.

>You should use PyTorch layers such as:
>- `MaxPool2d`,
>- `ConvTranspose2d` or `Upsample`,
>- `ReLU`.

In [None]:
class DownBlock(nn.Module):
    """Downscaling with maxpooling and convolutions"""
    def __init__(self, in_channels, out_channels, nb_Conv=2):
        super(DownBlock, self).__init__()
        self.maxpool = nn.MaxPool2d(2)
        self.nConvs = _make_nConv(in_channels, out_channels, nb_Conv)

    def forward(self, x):
        out = self.maxpool(x)
        out = self.nConvs(out)
        return out

class Bottleneck(nn.Module):
    def __init__(self, in_channels, out_channels, nb_Conv=2):
        super(Bottleneck, self).__init__()
        self.nConvs = _make_nConv(in_channels, out_channels, nb_Conv)
        self.last = nn.ConvTranspose2d(out_channels, in_channels,
                                           kernel_size=3, stride=2,
                                           padding=1, output_padding=1)
    def forward(self, input):
        out = self.nConvs(input)
        out = self.last(out)
        return out

class UpBlock(nn.Module):
    """Upscaling then conv"""
    def __init__(self, in_channels, out_channels, nb_Conv=2):
        super(UpBlock, self).__init__()
        self.up = nn.Upsample(scale_factor=2)
        self.nConvs = _make_nConv(in_channels, out_channels, nb_Conv)

    def forward(self, x, skip_x):
        ### START CODE HERE ###
        # Note that this function takes two inputs.
        # Useful function: https://pytorch.org/docs/stable/generated/torch.cat.html
        ### END CODE HERE ###
        return out

### **ii. UNet architecture**

Here you will use the building blocks coded above to construct the full UNet architecture.

- Pay attention to the number of input / output channels of each block when implementing the skip connections.

- Try to compare the following code with the figure. Where are the `DownBlock` and `UpBlock` modules?

In [None]:
class UNet(nn.Module):
    def __init__(self, n_channels=4, n_classes=4):
        '''
        n_channels : number of channels of the input.
                        By default 4, because we have 4 modalities
        n_classes : number of channels of the output.
                      By default 4 (3 labels + 1 for the background)
        '''
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes

        # Question here
        self.inc = ConvBatchNorm(n_channels, 64)
        self.down1 = DownBlock(64, 128, nb_Conv=2)
        self.down2 = DownBlock(128, 256, nb_Conv=2)
        self.down3 = DownBlock(256, 512, nb_Conv=2)
        self.down4 = DownBlock(512, 512, nb_Conv=2)

        self.Encoder = [self.down1, self.down2, self.down3, self.down4]

        self.bottleneck = Bottleneck(512, 512)

        self.up1 = UpBlock(1024, 256, nb_Conv=2)
        self.up2 = UpBlock(512, 128, nb_Conv=2)
        self.up3 = UpBlock(256, 64, nb_Conv=2)
        self.up4 = UpBlock(128, 64, nb_Conv=2)

        self.Decoder = [self.up1, self.up2, self.up3, self.up4]

        self.outc = nn.Sequential(
            nn.ConvTranspose2d(
                64, 64, kernel_size=3,stride=2,padding=1, output_padding=1
            ),
            nn.Conv2d(
                64, self.n_classes, kernel_size=3, stride=1, padding=1
            )
        )
        self.last_activation = nn.Softmax(dim=1)


    def forward(self, x):
        # Forward
        skip_inputs = []
        x = self.inc(x)

        # Forward through encoder
        for i, block in enumerate(self.Encoder):
            x = block(x)
            skip_inputs += [x]

        # We are at the bottleneck.
        bottleneck = self.bottleneck(x)

        # Forward through decoder
        skip_inputs.reverse()

        decoded = bottleneck
        for i, block in enumerate(self.Decoder):
            # Concat with skipconnections
            skipped = skip_inputs[i+1]
            decoded = block(decoded, skipped)
        out = self.last_activation(self.outc(decoded))
        return out

## 3.b. Study the model

> To study and debug a neural network, you can try to feed a random tensor of size `(1, 4, 96, 96)` (batch size, number of modalities, image height, image width) using the `torch.rand` function.

> The output's shape should be the same as the input.

> To debug more effectively, you may need to modify your code to look at the output shape of each layer in your network.
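
One way to look at intermediate shapes without editing the `forward` methods is to register forward hooks. The sketch below does this on a small stand-in network; `toy_model` and `record_shape` are hypothetical names, and the same pattern could be applied to the submodules of your own model.

```python
import torch
import torch.nn as nn

# Small stand-in network (hypothetical, not the UNet of this tutorial)
toy_model = nn.Sequential(
    nn.Conv2d(4, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(8, 4, kernel_size=3, padding=1),
)

shapes = []

def record_shape(module, inputs, output):
    # Called by PyTorch after each forward pass of the hooked module
    shapes.append((module.__class__.__name__, tuple(output.shape)))

# Attach the hook to every layer of the toy network
handles = [layer.register_forward_hook(record_shape) for layer in toy_model]

with torch.no_grad():
    toy_model(torch.rand(1, 4, 96, 96))

for name, shape in shapes:
    print(name, shape)

# Remove the hooks once debugging is done
for handle in handles:
    handle.remove()
```

Hooks leave the model code untouched, so they can be removed once the shapes are understood.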



In [None]:
model = UNet(n_channels=4, n_classes=4)
print(model)

# Image of size 96*96 with 4 modalities + batch size = 1
x = torch.rand((1, 4, 96, 96))
y = model(x)
print(y.shape)

# **4. Dataset creation**



## 4.a. Helper functions - CODE TO EXECUTE AND HIDE

All the following code is needed to execute the model, but we don't ask you to implement it. You just need to execute it once. If you want, you can try to understand what the different functions are doing but it is not needed.
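
One step of the helper code below that is worth isolating is how `SegmentationDataset.load` turns a label mask into a one-hot array. The toy example here reproduces that step on a 2x2 mask; the variable names are illustrative, and `np.moveaxis` is used in place of the `np.rollaxis` call, with the same effect.

```python
import numpy as np

# Toy 2x2 segmentation mask using the label values 0, 1, 2 and 4
toy_mask = np.array([[0, 1],
                     [2, 4]])

# Remap label 4 to 3 so that labels are contiguous: 0, 1, 2, 3
toy_mask[toy_mask == 4] = 3

# Indexing the identity matrix with the labels yields one-hot vectors
# in the last axis; moving that axis first gives (classes, H, W)
n_labels = 4
one_hot = np.moveaxis(np.eye(n_labels, dtype=np.uint8)[toy_mask], -1, 0)

print(one_hot.shape)  # (4, 2, 2)
print(one_hot[3])     # channel of the remapped label
```

Each pixel has exactly one active channel, so summing `one_hot` over its first axis gives an array of ones.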

In [None]:
#@title
end = '.nii.gz'
seg_name = '_seg'

def load_split(split_folder):
    '''
        return train, val, test split with loadtxt
    '''
    train_split = np.loadtxt(os.path.join(
        split_folder, 'train.txt'), dtype=str)
    val_split = np.loadtxt(os.path.join(split_folder, 'val.txt'), dtype=str)
    test_split = np.loadtxt(os.path.join(split_folder, 'test.txt'), dtype=str)
    return train_split, val_split, test_split


def load_sitk(path):
    return sitk.GetArrayFromImage(sitk.ReadImage(path))


def find_z_slice(list_patient, threshold, dataframe):
    """
    For each patient in list_patient, this function returns the list of slices where
    the corresponding image is not empty"""

    list_IDs = []
    for patient in list_patient:
        if threshold > 0:
            condition = dataframe[patient].values >= threshold
            z_slice = np.where(condition)[0]
        else:
            z_slice = range(155)
        list_IDs += list(set([(patient, int(z//2)) for z in z_slice]))

    return list_IDs


def generate_IDs(train_split, val_split, test_split,
                 tumor_percentage, csv_path, image_size=(240, 240)):

    tumor_volume_dataframe = pd.read_csv(csv_path)
    threshold = int(tumor_percentage * np.prod(image_size) / 100)

    train_IDs, val_IDs, test_IDs = [], [], []
    train_IDs = find_z_slice(train_split, threshold, tumor_volume_dataframe)
    val_IDs = find_z_slice(val_split, threshold, tumor_volume_dataframe)
    test_IDs = find_z_slice(test_split, 0, tumor_volume_dataframe)
    return train_IDs, val_IDs, test_IDs


def to_var(x, device):
    if isinstance(x, np.ndarray):
        x = torch.from_numpy(x)
    x = x.to(device)
    return x

def to_numpy(x):
    if not (isinstance(x, np.ndarray) or x is None):
        if x.is_cuda:
            x = x.data.cpu()
        x = x.numpy()
    return x

def save_checkpoint(state, save_path):
    '''
        Save the current training state (model and optimizer) in save_path.
        The file name encodes the epoch and the validation loss.
    '''

    if not os.path.isdir(save_path):
        os.makedirs(save_path)

    epoch = state['epoch']
    val_loss = state['val_loss']
    filename = save_path + '/' + \
        'model.{:02d}--{:.3f}.pth.tar'.format(epoch, val_loss)
    torch.save(state, filename)


def print_summary(epoch, i, nb_batch, loss, batch_time,
                  average_loss, average_time, mode):
    '''
        mode = Train or Test
    '''
    summary = '[' + str(mode) + '] Epoch: [{0}][{1}/{2}]\t'.format(
        epoch, i, nb_batch)

    string = ''
    string += ('Dice Loss {:.4f} ').format(loss)
    string += ('(Average {:.4f}) \t').format(average_loss)
    string += ('Batch Time {:.4f} ').format(batch_time)
    string += ('(Average {:.4f}) \t').format(average_time)

    summary += string
    print(summary)

def plot(irms, masks=None, pred_masks=None):

    kwargs = {'cmap': 'gray'}
    fig, ax = plt.subplots(2, 3, gridspec_kw={'wspace': 0.15, 'hspace': 0.2,
                                              'top': 0.85, 'bottom': 0.1,
                                              'left': 0.05, 'right': 0.95},
                           figsize=(12, 7))
    ax[0, 0].imshow(irms[0, :, :], **kwargs)

    if masks is not None:
        masks = np.argmax(masks, axis=0)
        ax[0, 1].imshow(masks, vmin=0, vmax=3)

    if pred_masks is not None:
        pred_masks = np.argmax(pred_masks, axis=0)
        ax[0, 2].imshow(pred_masks, vmin=0, vmax=3)

    for i in range(3):
        ax[1, i].imshow(irms[i+1, :, :], **kwargs)

    for i in range(2):
        for j in range(3):
            ax[i, j].grid(False)
            ax[i, j].axis('off')
            ax[i, j].set_xticks([])
            ax[i, j].set_yticks([])

    # Channel order follows modalities = ['t1', 't2', 't1ce', 'flair']
    ax[0, 0].set_title('MRI T1')
    ax[1, 0].set_title('MRI T2')
    ax[1, 1].set_title('MRI Gado')
    ax[1, 2].set_title('MRI Flair')
    ax[0, 1].set_title('Ground Truth Seg')
    ax[0, 2].set_title('Predicted Seg')
    fig.canvas.draw()

    return fig

class SegmentationDataset(torch.utils.data.Dataset):
    'Generates data for torch'

    def __init__(self, files_list, data_path, modalities=['t1', 't2', 't1ce', 'flair'], transform=None):
        super(SegmentationDataset, self).__init__()
        self.files_list = files_list
        self.transform = transform
        self.data_path = data_path
        self.modalities = modalities

    def __len__(self):
        return len(self.files_list)

    def __getitem__(self, idx):
        'Get a patient given idx'
        patient = self.files_list[idx]

        # Load the patient's modalities and segmentation masks
        irm, mask = self.load(patient)
        sample = (irm, mask)

        # Apply data transformation
        if self.transform:
            irm, mask = self.transform(sample)
        return (irm, mask, patient)

    def load(self, ID):

        patient, z_slice = ID
        patient_path = os.path.join(self.data_path, patient)

        # Get all modalities for the given slice
        irm = []
        for modality in self.modalities:
            file_name = "{patient}_{modality}_z_{z_slice}.nii.gz".format(patient=patient, modality=modality, z_slice=z_slice)
            path = os.path.join(patient_path, file_name)
            irm.append(load_sitk(path))
        irm = np.stack(irm, axis=0)

        # Get the segmentation mask for the given slice
        seg_name = "{patient}_seg_z_{z_slice}.nii.gz".format(patient=patient, z_slice=z_slice)
        mask_path = os.path.join(patient_path, seg_name)
        mask = load_sitk(mask_path)
        mask[mask == 4] = 3

        # Convert segmentation mask to one-hot encoding
        label = 4
        mask = mask.astype(np.int16)
        mask = np.rollaxis(np.eye(label, dtype=np.uint8)[mask], -1, 0)
        return irm, mask


# Load the split, generate the IDs list
datasets_path ='./GEP1/datasets/'
csv_path = './GEP1/data/tumor_count.csv'

# The tumour percentage is the percentage of tumour in an image. It's a threshold
# that is used when selecting relevant slice indexes in a patient's images.
tumour_percentage = 0.5
train_split, val_split, test_split = load_split(datasets_path)

(train_IDs, val_IDs, test_IDs) = generate_IDs(train_split, val_split, test_split, tumour_percentage, csv_path)

## 4.b. Creating instances of the `SegmentationDataset` class for each split

In [None]:
# No data augmentation implemented yet.
train_Dataset = SegmentationDataset(train_IDs, data_path=data_path)
val_Dataset = SegmentationDataset(val_IDs, data_path=data_path)
test_Dataset = SegmentationDataset(test_IDs, data_path=data_path)


The following cell calls for a sample of the training set. Running `train_Dataset[0]` actually calls the `__getitem__` method of the `SegmentationDataset` class.

In [None]:
input_modalities, segmentation_mask, patient = train_Dataset[0]

In [None]:
print("Shape of the input:", input_modalities.shape)
print("Shape of the segmentation masks:", segmentation_mask.shape)
print("Patient identification:", patient[0])
print("Selected slice:", patient[1])

## 4.c. Create the `DataLoader` that will be used during training

For each split, you need to specify:
- the batch size with the `batch_size` argument
- whether to shuffle your dataset when feeding batches to the model using the `shuffle` argument
- whether to drop the last incomplete batch, if the dataset size is not divisible by the batch size, using the `drop_last` argument. Dropping it keeps every batch the same size, which avoids a very small final batch producing noisy batch-normalization statistics during training.
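
The effect of `drop_last` is easiest to see on a tiny dataset. In the sketch below, `toy_dataset` is a hypothetical 10-sample dataset served in batches of 4.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy dataset with 10 samples (hypothetical data)
toy_dataset = TensorDataset(torch.arange(10).float())

keep_last = DataLoader(toy_dataset, batch_size=4, shuffle=False, drop_last=False)
drop_last = DataLoader(toy_dataset, batch_size=4, shuffle=False, drop_last=True)

# Batch sizes seen when iterating over each loader
sizes_keep = [batch[0].shape[0] for batch in keep_last]
sizes_drop = [batch[0].shape[0] for batch in drop_last]

print(sizes_keep)  # [4, 4, 2]
print(sizes_drop)  # [4, 4]
```

Note that with `drop_last=True` the two leftover samples are never seen in that pass, which matters if it is used on a validation or test set.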

In [None]:
# Define the batch size
batch_size = 64

train_loader = torch.utils.data.DataLoader(train_Dataset,
                                           batch_size=batch_size, shuffle=True,
                                           drop_last=True)

val_loader = torch.utils.data.DataLoader(val_Dataset,
                                         batch_size=batch_size, shuffle=False,
                                         drop_last=True)

test_loader = torch.utils.data.DataLoader(test_Dataset,
                                          batch_size=1, shuffle=False,
                                          drop_last=False)

# **5. Training**

**Tensorboard**: it is used to plot the predictions of the network during training and to study its performance.

## 5.a. Loss function, optimizer and hyperparameters

Here you need to choose:
- the loss function used to optimize the model (see [here](https://pytorch.org/docs/stable/nn.html#loss-functions)),
- the optimizer (see [here](https://pytorch.org/docs/stable/optim.html))
- all the hyperparameters: learning rate, weight decay ...


>In our application, you can start by:
>- choosing the `Adam` optimizer
>- coding your own custom loss function as the dice loss


### Exercise 2 - Implement the dice loss

![dice_loss](https://i.stack.imgur.com/Usv5J.png)
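
For reference, a commonly used formulation of the Dice loss between a predicted probability map $p$ and a binary target $t$, both flattened to $N$ values, is written out below. The smoothing constant $\epsilon$ is an assumption of this sketch (it avoids a division by zero when both maps are empty), and the exact form shown in the figure may differ.

```latex
\mathrm{Dice}(p, t) = \frac{2 \sum_{i=1}^{N} p_i \, t_i + \epsilon}{\sum_{i=1}^{N} p_i + \sum_{i=1}^{N} t_i + \epsilon},
\qquad
\mathcal{L}_{\mathrm{Dice}}(p, t) = 1 - \mathrm{Dice}(p, t)
```

A perfect overlap gives a Dice coefficient of 1 and therefore a loss of 0.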


In [None]:
def dice_loss(input, target):
    iflat = torch.flatten(input.float())
    tflat = torch.flatten(target.float())
    ### START CODE HERE ###

    ### END CODE HERE ###
    return loss

def mean_dice_loss(input, target):
    channels = list(range(target.shape[1]))
    loss = 0
    for channel in channels:
        dice = dice_loss(input[:, channel, ...],
                         target[:, channel, ...])
        loss += dice
    return loss / len(channels)


## 5.b. Instantiate your model

You need to specify:
- the number of input channels of the model `n_channels` which should be equal to the number of modalities, so 4 modalities
- the number of segmentation classes `n_classes` = 4

In [None]:
n_modalities = 4
n_classes = 4
model = UNet(n_channels=n_modalities, n_classes=n_classes) # Create model
model.cuda() # move model to GPU

In [None]:
learning_rate = 1e-2

criterion = mean_dice_loss # Choose loss function
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) # Choose optimizer

## 5.c. Training loop for a single epoch

Here you should complete the `train_loop` function, which goes over the whole dataset exactly once, i.e. one epoch.

When the model is in training mode (`model.training == True`), the function should perform backpropagation and update the parameters.
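
The pattern of one function serving both training and evaluation can be sketched on a toy problem. Everything below (`toy_model`, `run_batch`, the regression data) is hypothetical and only illustrates that parameters change in training mode and stay fixed in evaluation mode.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy regression problem and model (hypothetical, for illustration only)
inputs = torch.rand(16, 3)
targets = inputs.sum(dim=1, keepdim=True)
toy_model = nn.Linear(3, 1)
toy_criterion = nn.MSELoss()
toy_optimizer = torch.optim.SGD(toy_model.parameters(), lr=0.1)

def run_batch(model, criterion, optimizer, x, y):
    """Forward pass, plus a parameter update only in training mode."""
    loss = criterion(model(x), y)
    if model.training:
        optimizer.zero_grad()  # reset gradients from the previous step
        loss.backward()        # backpropagation
        optimizer.step()       # parameter update
    return loss.item()

weights_before = toy_model.weight.detach().clone()

toy_model.train()
train_loss = run_batch(toy_model, toy_criterion, toy_optimizer, inputs, targets)
weights_after_train = toy_model.weight.detach().clone()

toy_model.eval()
with torch.no_grad():
    val_loss = run_batch(toy_model, toy_criterion, toy_optimizer, inputs, targets)
weights_after_eval = toy_model.weight.detach().clone()

print(torch.equal(weights_before, weights_after_train))      # False: updated
print(torch.equal(weights_after_train, weights_after_eval))  # True: unchanged
```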

In [None]:
# Train the model
def train_loop(loader, model, criterion, optimizer, writer, epoch):

    logging_mode = 'Train' if model.training else 'Val'
    if model.training:
        print('training')

    epoch_time_sum, epoch_loss_sum = [], []

    for i, sample in enumerate(loader, 1):
        start = time.time()
        # Take variable
        (irms, masks, patients) = sample
        # print(irms.shape) # Batch * Number of Modalities * Width * Height

        # Put variables to GPU
        irms = irms.float().cuda()
        masks = masks.float().cuda()

        # compute model prediction
        pred_masks = model(irms)

        # compute loss
        dice_loss = criterion(pred_masks, masks)

        # If in training mode ...
        if model.training:
            # Initialize optimizer gradients to zero
            optimizer.zero_grad()
            # Perform backpropagation
            dice_loss.backward()
            # Update the model's trainable parameters using the computed gradients
            optimizer.step()


        # Compute elapsed time
        batch_time = time.time() - start

        epoch_time_sum += [batch_time]
        epoch_loss_sum += [dice_loss.item()]

        average_time = np.mean(epoch_time_sum)
        average_loss = np.mean(epoch_loss_sum)

        if i % print_frequency == 0:
            # Pass a plain float (not a tensor) to the formatting helper
            print_summary(epoch + 1, i, len(loader), dice_loss.item(), batch_time,
                          average_loss, average_time, logging_mode)
        step = epoch*len(loader) + i
        writer.add_scalar(logging_mode + '_dice', dice_loss.item(),step)



    writer.add_scalar(logging_mode + '_global_loss', np.mean(epoch_loss_sum), epoch)


    # Save some figures to monitor segmentation quality
    # (one figure per sample of the last batch)
    n_samples = irms.shape[0]
    irms = to_numpy(irms)
    masks = to_numpy(masks)
    pred_masks = to_numpy(pred_masks)

    for batch in range(n_samples):
        fig = plot(irms[batch, ...], masks[batch, ...], pred_masks[batch, ...])
        writer.add_figure(logging_mode + str(batch), fig, epoch)
    writer.flush()
    return np.mean(epoch_loss_sum)

## 5.d. Perform the training

You will use **Tensorboard** to monitor the decrease of the loss function.

In [None]:
%reload_ext tensorboard
%tensorboard --logdir './GEP1/save/tensorboard_logs/'

In [None]:
save_path = "./GEP1/save/"
session_name = 'Test_session' + '_' + time.strftime('%m.%d %Hh%M')
model_path = save_path + 'models/' + session_name + '/'
# Configure tensorboard (reuses the session_name defined above)
tensorboard_folder = save_path + 'tensorboard_logs/'
log_dir = tensorboard_folder + session_name + '/'

if not os.path.isdir(log_dir):
    os.makedirs(log_dir)
writer = torch.utils.tensorboard.SummaryWriter(log_dir)
# Training parameters
epochs = 5
print_frequency = 10
save_frequency = 1
save_model = True

In [None]:
for epoch in range(epochs):  # loop over the dataset multiple times
    print('******** Epoch [{}/{}]  ********'.format(epoch+1, epochs))
    print(session_name)

    # train for one epoch
    model.train()
    print('Training')
    train_loop(train_loader, model, criterion, optimizer, writer, epoch)

    # evaluate on validation set
    print('Validation')
    with torch.no_grad():   # Disable gradient computation (faster and saves memory)
        model.eval()        # Evaluation mode: Dropout off, BatchNorm uses its running statistics
        val_loss = train_loop(val_loader, model, criterion, optimizer, writer, epoch)

    if save_model and epoch % save_frequency == 0:
        save_checkpoint({'epoch': epoch,
                        'state_dict': model.state_dict(),
                         'val_loss': val_loss,
                         'optimizer': optimizer.state_dict()}, model_path)
