# Playing with SimpleITK and nnU-Net

## Installation

In [None]:
%cd /content
# Clone the helper repository; it provides `convert_to_nii.py`, which is run further below.
!git clone https://github.com/woctezuma/playing-with-simpleitk.git

# SimpleITK is used for image I/O and resampling; nnunetv2 provides the `nnUNetv2_*` command-line tools.
%pip install -qq SimpleITK nnunetv2

## Data nomenclature

In [None]:
ROOT_FOLDER = '/content/'
NNUNET_FOLDER_NAME = ROOT_FOLDER + "nnUNet_base/"

# nnU-Net dataset folders ("DatasetXXX_Name"), each stored directly under NNUNET_FOLDER_NAME.
# Every *_PATH keeps a trailing slash because file names are appended to it by plain concatenation.
ORIGINAL_DATASET_FNAME, DATASET_FNAME, TRIMMED_DATASET_FNAME = (
    "Dataset000_Original",
    "Dataset001_Downsampled",
    "Dataset002_Trimmed",
)

ORIGINAL_DATASET_PATH = NNUNET_FOLDER_NAME + ORIGINAL_DATASET_FNAME + "/"
DATASET_PATH = NNUNET_FOLDER_NAME + DATASET_FNAME + "/"
TRIMMED_DATASET_PATH = NNUNET_FOLDER_NAME + TRIMMED_DATASET_FNAME + "/"

In [None]:
# Patient IDs processed throughout the notebook.
PATIENT_INDICES = [1, 2, 5, 6, 8, 10, 14, 16, 18, 19, *range(21, 31)]
IMAGE_TYPE = '.nii.gz'

def get_image_file_name(patient_no, modality_no = 0, data_folder = DATASET_PATH):
  """Path of an image: `<data_folder>imagesTr/patientID<patient_no>_<modality_no, 4 digits><IMAGE_TYPE>`.

  `data_folder` must end with a slash (or be "" to get a path relative to a dataset folder).
  """
  file_stem = f'patientID{patient_no}_{modality_no:04}'
  return f'{data_folder}imagesTr/{file_stem}{IMAGE_TYPE}'

def get_ground_truth_file_name(patient_no, data_folder = DATASET_PATH):
  """Path of a label map: `<data_folder>labelsTr/patientID<patient_no><IMAGE_TYPE>`.

  `data_folder` must end with a slash (or be "" to get a path relative to a dataset folder).
  """
  file_stem = f'patientID{patient_no}'
  return f'{data_folder}labelsTr/{file_stem}{IMAGE_TYPE}'

## Import data

You can skip all the steps below and download the pre-processed datasets straight away by uncommenting and running the following cell:

In [None]:
# Optional shortcut: uncomment to download ready-made copies of Dataset000_Original and
# Dataset001_Downsampled instead of running the download/conversion/down-sampling cells.
# %cd /content

# !curl -OL https://github.com/woctezuma/playing-with-simpleitk/releases/download/original_data/Dataset000_Original.tar.gz
# !tar xzf Dataset000_Original.tar.gz

# !curl -OL https://github.com/woctezuma/playing-with-simpleitk/releases/download/data/Dataset001_Downsampled.tar.gz
# !tar xzf Dataset001_Downsampled.tar.gz

If you prefer to download the original data and run the steps by yourself, run the cells below.

### Download data

In [None]:
%cd /content
!curl -O https://zenodo.org/records/3431873/files/CHAOS_Train_Sets.zip
!unzip -qq CHAOS_Train_Sets.zip

In [None]:
# Move the extracted CHAOS training sets into the cloned repository's `data/` folder,
# then run the repository's conversion script (its output is read from `data/output/` in the next cell).
%cd /content/playing-with-simpleitk
%mv /content/Train_Sets data/
!python convert_to_nii.py

In [None]:
# Gather the converted files into the folder of the original (not down-sampled) dataset.
%mkdir -p {ORIGINAL_DATASET_PATH}
%mv data/output/* {ORIGINAL_DATASET_PATH}

### Tests

In [None]:
import numpy as np
import SimpleITK as sitk

patient_no = PATIENT_INDICES[0]

# Sanity check on the first patient of the original dataset:
# print the size, then the distinct voxel values, of its image and of its label map.
image_names = (
    get_image_file_name(patient_no, data_folder=ORIGINAL_DATASET_PATH),
    get_ground_truth_file_name(patient_no, data_folder=ORIGINAL_DATASET_PATH),
)

for image_name in image_names:
  image = sitk.ReadImage(image_name)
  print(image.GetSize())
  print(np.unique(sitk.GetArrayViewFromImage(image)))

### Down-sample images

#### Utils

In [None]:
import SimpleITK as sitk

def resample_sitk_image(volume, new_spacing, interpolator=None, default_pixel_value=0):
    """Resample `volume` to `new_spacing`, keeping its origin, direction and pixel type.

    The output size is chosen so that the physical extent (size * spacing) along each
    axis is preserved, rounded to the nearest whole number of voxels.

    Args:
        volume: SimpleITK image to resample.
        new_spacing: target spacing, one value per image dimension.
        interpolator: SimpleITK interpolator. When None, it is inferred from the pixel type:
            nearest neighbour for 8-bit unsigned images (assumed to be label maps),
            linear for 16- or 32-bit signed integer images (assumed to be intensities).
        default_pixel_value: value given to output voxels that fall outside `volume`
            (0 by default, as before).

    Returns:
        The resampled image.

    Raises:
        NotImplementedError: if `interpolator` is None and the pixel type is not one of
            the types listed above.
    """
    if interpolator is None:
        pixel_id = volume.GetPixelIDValue()
        # Compare against SimpleITK's named pixel types rather than raw enum values.
        if pixel_id == sitk.sitkUInt8:
            # Nearest neighbour never creates values that were not in the input.
            interpolator = sitk.sitkNearestNeighbor
        elif pixel_id in (sitk.sitkInt16, sitk.sitkInt32):
            interpolator = sitk.sitkLinear
        else:
            raise NotImplementedError(
                'Set `interpolator` manually, '
                'can only infer for 8-bit unsigned or 16, 32-bit signed integers')

    # Reference: https://discourse.itk.org/t/resample-volume-to-specific-voxel-spacing-simpleitk/3531/2
    original_spacing = volume.GetSpacing()
    original_size = volume.GetSize()
    new_size = [int(round(osz * ospc / nspc))
                for osz, ospc, nspc in zip(original_size, original_spacing, new_spacing)]
    # Default-constructed sitk.Transform() is the identity, so only the sampling grid changes.
    return sitk.Resample(volume, new_size, sitk.Transform(), interpolator,
                         volume.GetOrigin(), new_spacing, volume.GetDirection(),
                         default_pixel_value, volume.GetPixelID())

#### Down-sample for faster checks.

In [None]:
%mkdir -p {DATASET_PATH}imagesTr
%mkdir -p {DATASET_PATH}labelsTr

In [None]:
num_dim = 3
new_spacing = [4]*num_dim

for patient_no in PATIENT_INDICES:
  # File names relative to a dataset folder, e.g. "labelsTr/patientID1.nii.gz".
  image_file_name = get_image_file_name(patient_no, data_folder="")
  ground_truth_file_name = get_ground_truth_file_name(patient_no, data_folder="")

  # Read each file once per patient (the reference intensity image used to be re-read for every file).
  original_intensity_image = sitk.ReadImage(f"{ORIGINAL_DATASET_PATH}{image_file_name}")
  original_label_image = sitk.ReadImage(f"{ORIGINAL_DATASET_PATH}{ground_truth_file_name}")

  # Copy information w.r.t. original spacing: the label map takes the intensity image's
  # spacing/origin/direction. NOTE(review): presumably the converted label maps do not carry
  # the right geometry themselves — confirm against convert_to_nii.py.
  original_label_image.CopyInformation(original_intensity_image)

  for file_name, original_image in [
      (ground_truth_file_name, original_label_image),
      (image_file_name, original_intensity_image),
  ]:
    print(file_name)

    resampled_image = resample_sitk_image(original_image, new_spacing)

    print(original_image.GetSize())
    print(resampled_image.GetSize())

    sitk.WriteImage(resampled_image, f"{DATASET_PATH}{file_name}")

### Binarize label maps

nnU-Net expects consecutive integer labels: 0 (background), 1, 2, etc.

In the original CHAOS data, the label maps contain only two values:
- 0 (background),
- 255 (region of interest).

We can simply binarize the label map.


In [None]:
import SimpleITK as sitk
import numpy as np

write_to_disk = True

for patient_no in PATIENT_INDICES:
  print(f'Patient n°{patient_no}')

  input_image_name = get_ground_truth_file_name(patient_no, data_folder=DATASET_PATH)
  input_image = sitk.ReadImage(input_image_name)
  print(f'Image size: {input_image.GetSize()}')

  label_array = sitk.GetArrayFromImage(input_image)

  # Threshold halfway between 0 and the largest value: {0, 255} becomes {0, 1}, and an
  # already-binarized map {0, 1} (threshold 0.5) is left unchanged, so re-running is safe.
  threshold = label_array.max() / 2
  binary_array = (label_array > threshold).astype(label_array.dtype)

  print(f'Labels: {np.unique(label_array)} -> {threshold} -> {np.unique(binary_array)}')

  output_image = sitk.GetImageFromArray(binary_array)
  # Keep the original spacing, origin and direction.
  output_image.CopyInformation(input_image)

  if not write_to_disk:
    continue
  # Overwrite the label map in place.
  sitk.WriteImage(output_image, input_image_name)


## nnUNet

### Environment variables

In [None]:
import os

# Folders used by the nnU-Net command-line tools, derived from the notebook's path constants
# so they cannot drift apart if ROOT_FOLDER is changed. Setting os.environ is what `%env` does;
# the `!nnUNetv2_*` commands run as subprocesses of this kernel and see these values.
nnunet_environment = {
    "nnUNet_raw": NNUNET_FOLDER_NAME,
    "nnUNet_preprocessed": f"{ROOT_FOLDER}nnUNet_preprocessed/",
    "nnUNet_results": f"{ROOT_FOLDER}nnUNet_results/",
}
for variable_name, variable_value in nnunet_environment.items():
  os.environ[variable_name] = variable_value
  print(f"env: {variable_name}={variable_value}")

## Create dataset.json at the root of each dataset folder

Reference: https://github.com/MIC-DKFZ/nnUNet/tree/master/nnunetv2/dataset_conversion

In [None]:
import os
import glob
import json
from collections import OrderedDict

def create_dataset_json(dataset_path, channel_names=None, labels=None, file_ending=None):
  """Write nnU-Net's `dataset.json` at the root of `dataset_path`.

  Args:
    dataset_path: dataset folder, ending with a slash (sub-paths are appended to it directly).
    channel_names: mapping from channel index (as a string) to modality name.
      Defaults to {"0": "CT"}.
    labels: mapping from label name to integer value. Defaults to {"background": 0, "liver": 1}.
    file_ending: image file extension. Defaults to IMAGE_TYPE.

  `numTraining` is set to the number of `labelsTr/*<file_ending>` files under `dataset_path`.
  """
  # Build fresh dicts on every call (no shared mutable defaults).
  if channel_names is None:
    channel_names = {"0": "CT"}
  if labels is None:
    labels = {"background": 0, "liver": 1}
  if file_ending is None:
    file_ending = IMAGE_TYPE

  num_training = len(glob.glob(f'{dataset_path}labelsTr/*{file_ending}'))

  # A plain dict is enough: keys are sorted when dumped, so insertion order is irrelevant.
  json_dict = {
      'channel_names': channel_names,
      'labels': labels,
      'numTraining': num_training,
      'file_ending': file_ending,
  }

  with open(os.path.join(dataset_path, "dataset.json"), 'w') as f:
    json.dump(json_dict, f, indent=4, sort_keys=True)

In [None]:
# dataset.json for the full down-sampled dataset.
create_dataset_json(dataset_path=DATASET_PATH)

In [None]:
%mkdir -p {TRIMMED_DATASET_PATH}
%cp -r {DATASET_PATH}* {TRIMMED_DATASET_PATH}

num_patients = 5
for i in PATIENT_INDICES[num_patients:]:
  %rm {TRIMMED_DATASET_PATH}imagesTr/patientID{i}_0000{IMAGE_TYPE}
  %rm {TRIMMED_DATASET_PATH}labelsTr/patientID{i}{IMAGE_TYPE}

In [None]:
# Rewrite dataset.json: the copy taken from the full dataset still counts the removed patients.
create_dataset_json(dataset_path=TRIMMED_DATASET_PATH)

## Run pre-processing

In [None]:
# Plan and preprocess Dataset001 (all down-sampled patients), checking its integrity.
# NOTE(review): the training and prediction cells below only use models of Dataset002_Trimmed,
# so this step seems to be needed mainly for the integrity check — confirm before removing it.
!nnUNetv2_plan_and_preprocess -d 001 --verify_dataset_integrity

In [None]:
# Plan and preprocess Dataset002 (the trimmed dataset that the models are trained on).
!nnUNetv2_plan_and_preprocess -d 002 --verify_dataset_integrity

## Training

To resume an interrupted training, append `--c` to the command line.

In [None]:
# Train the `2d` configuration on the trimmed dataset with fold `all` (all of its cases used for training).
!nnUNetv2_train {TRIMMED_DATASET_FNAME} 2d all

In [None]:
# Train the `3d_fullres` configuration on the trimmed dataset with fold `all` (all of its cases used for training).
!nnUNetv2_train {TRIMMED_DATASET_FNAME} 3d_fullres all

## Inference

In [None]:
# If you want to check the results before the end of the training, copy the best checkpoint
# as `checkpoint_final.pth` (the file used for prediction).
# An existing `checkpoint_final.pth` that is at least as recent as `checkpoint_best.pth` is kept,
# so running this cell after training has finished does not replace the final model.
import os
import shutil

results_folder = os.environ.get('nnUNet_results', '/content/nnUNet_results/')

for config in ["3d_fullres", "2d"]:
  fold_folder = os.path.join(results_folder, TRIMMED_DATASET_FNAME,
                             f"nnUNetTrainer__nnUNetPlans__{config}", "fold_all")
  best_checkpoint = os.path.join(fold_folder, "checkpoint_best.pth")
  final_checkpoint = os.path.join(fold_folder, "checkpoint_final.pth")

  if not os.path.exists(best_checkpoint):
    print(f"{config}: no checkpoint_best.pth in {fold_folder}")
  elif (os.path.exists(final_checkpoint)
        and os.path.getmtime(final_checkpoint) >= os.path.getmtime(best_checkpoint)):
    print(f"{config}: keeping {final_checkpoint} (not older than checkpoint_best.pth)")
  else:
    # copy2 preserves the modification time, so a newer checkpoint_best.pth written later
    # will again be copied on the next run of this cell.
    shutil.copy2(best_checkpoint, final_checkpoint)
    print(f"{config}: copied checkpoint_best.pth -> checkpoint_final.pth")

In [None]:
# Predict every image of the full down-sampled dataset (DATASET_PATH) with the models trained on
# the trimmed dataset, once per configuration, each into its own output sub-folder.
INPUT_FOLDER = f'{DATASET_PATH}imagesTr/'
ROOT_OUTPUT_FOLDER = '/content/sample_data/output/'

for config in ["2d", "3d_fullres"]:
  print(config)

  OUTPUT_FOLDER = f"{ROOT_OUTPUT_FOLDER}{config}/"
  !mkdir -p {OUTPUT_FOLDER}

  # `-f all` selects the model trained with fold `all`.
  !nnUNetv2_predict \
  -i {INPUT_FOLDER} \
  -o {OUTPUT_FOLDER} \
  -c {config} \
  -d {TRIMMED_DATASET_FNAME} \
  -f all

## Visualize segmentation results

In [None]:
import numpy as np
import SimpleITK as sitk

INPUT_FOLDER = f'{DATASET_PATH}labelsTr/'


def dice_coefficient(ground_truth_array, prediction_array):
  """Dice coefficient between the foregrounds (non-zero voxels) of two label arrays.

  The intersection is computed on boolean masks, so it depends neither on the label values
  nor on integer overflow of the arrays' dtype (unlike multiplying the arrays).
  Returns 1.0 when neither array has any foreground voxel, instead of dividing by zero.
  """
  ground_truth_mask = np.asarray(ground_truth_array) > 0
  prediction_mask = np.asarray(prediction_array) > 0
  denominator = np.sum(ground_truth_mask) + np.sum(prediction_mask)
  if denominator == 0:
    return 1.0
  return 2 * np.sum(np.logical_and(ground_truth_mask, prediction_mask)) / denominator


for config in ["2d", "3d_fullres"]:
  print(config)

  OUTPUT_FOLDER = f"{ROOT_OUTPUT_FOLDER}{config}/"

  dice_scores = []

  for patient_no in PATIENT_INDICES:
    ground_truth_name = f"{INPUT_FOLDER}patientID{patient_no}{IMAGE_TYPE}"
    prediction_name = f"{OUTPUT_FOLDER}patientID{patient_no}{IMAGE_TYPE}"

    ground_truth = sitk.ReadImage(ground_truth_name)
    prediction = sitk.ReadImage(prediction_name)

    dice_score = dice_coefficient(sitk.GetArrayViewFromImage(ground_truth),
                                  sitk.GetArrayViewFromImage(prediction))
    print(f'Patient n°{patient_no} ; Dice = {dice_score:.3f}')

    dice_scores.append(dice_score)

  print(f'\n[training and validation dataset] Average Dice score = {np.mean(dice_scores):.3f} (#patients={len(dice_scores)})')
  print(f'NB: training was done with the first {num_patients} patients. Look for possible over-fitting!')

  # Scores follow PATIENT_INDICES order; the trimmed training set is PATIENT_INDICES[:num_patients]
  # (`num_patients` is defined in the cell that builds the trimmed dataset).
  dice_scores_for_training = dice_scores[:num_patients]
  print(f'\n[training dataset] Average Dice score = {np.mean(dice_scores_for_training):.3f} (#patients={len(dice_scores_for_training)})')

  dice_scores_for_validation = dice_scores[num_patients:]
  print(f'\n[validation dataset] Average Dice score = {np.mean(dice_scores_for_validation):.3f} (#patients={len(dice_scores_for_validation)})')
