
# Medical Image Registration Using VoxelMorph
## Intelligent Analysis of Biomedical Images - Fall 2022
## Instructor: Dr. M. H. Rohban
## HW5 - Practical
## Student Name: Zahra Dehghanian
## Student ID: 401300417


In this assignment, you will train a VoxelMorph network for unsupervised image registration. You will use the [CHAOS MR T2 dataset](https://chaos.grand-challenge.org/), which is available from this Google Drive folder: [dataset](https://drive.google.com/drive/folders/1BGPa--fsmf8I5AeyFxa142N4lXpTqzIG?usp=sharing).
The folder contains 3D MR volumes of 20 different patients. Each volume has its own number of slices (2D images) and a single channel (grayscale).

The provided folder contains four types of files for each patient:


1.   the image itself
2.   the foreground mask (fgmask) of the image
3.   the label of the image
4.   the superpixels of the image

The superpixels are not needed for this assignment; they are just included in the folder.

### Training
In the following sections, read the data, get familiar with it, and implement PyTorch code that trains a VoxelMorph network to register a moving image (2D) to a fixed image (also 2D).

1.   Train on 19 volumes: all except the last one, patient ID 39, which is reserved for testing.
2.   Train your network in an unsupervised manner (do not use the labels during training).
3.   Use the voxelmorph library; you do not need to implement the model yourself.
4.   Use both a similarity loss and a smoothness loss. You may use trial and error to find the best weighting between them.
5.   Plot the training loss.
6.   Visualize your model's ability to register images by plotting at least 10 (moving, moved, fixed) triplets.
7.   You may use creative approaches to improve your results (e.g., bidirectional training, training on non-adjacent slice pairs, etc.).
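The two loss terms in item 4 can be sketched in NumPy as follows. This is a minimal illustration, not the voxelmorph implementation: it assumes a 2D displacement field of shape `(2, H, W)`, uses MSE as the similarity term, and uses a mean squared finite-difference penalty as the smoothness term (similar in spirit to voxelmorph's `Grad('l2')` loss). The weight `lam` is a placeholder to be tuned by trial and error.

```python
import numpy as np


def mse_loss(fixed, moved):
    """Similarity term: mean squared intensity difference between two images."""
    return float(np.mean((fixed - moved) ** 2))


def grad_l2_loss(flow):
    """Smoothness term: mean squared finite differences of a (2, H, W) displacement field."""
    dy = np.diff(flow, axis=1)  # differences between neighboring rows
    dx = np.diff(flow, axis=2)  # differences between neighboring columns
    return float((np.mean(dy ** 2) + np.mean(dx ** 2)) / 2.0)


def registration_loss(fixed, moved, flow, lam=0.01):
    """Total loss: similarity + lam * smoothness (lam is a placeholder to tune)."""
    return mse_loss(fixed, moved) + lam * grad_l2_loss(flow)


rng = np.random.default_rng(0)
fixed = rng.random((64, 64))
zero_flow = np.zeros((2, 64, 64))
noisy_flow = rng.normal(size=(2, 64, 64))

print(registration_loss(fixed, fixed, zero_flow))       # 0.0: perfect match, smooth field
print(registration_loss(fixed, fixed, noisy_flow) > 0)  # True: a rough field is penalized
```

A larger `lam` trades registration accuracy for smoother, more plausible deformations, which is why the weighting is usually tuned empirically.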

### Testing


1.   Load the image and label volumes of patient 39.
2.   Find the range of slices in which every slice contains a labeled organ (i.e., slices whose label is not blank).
3.   Pick the middle slice of that range.
4.   Propagate its label to the whole volume using your trained model (i.e., using the predicted displacement fields).
5.   Visualize your results: plot the propagated labels next to the actual labels for an intuitive comparison.
6.   Compute the Dice score between the propagated and actual labels of each slice and report the scores one by one. Report their average as the final evaluation metric of your model.
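Steps 2, 3, and 6 reduce to a few array operations. The sketch below is a NumPy illustration that assumes the label volume has shape `(n_slices, H, W)`, that the labeled slices form one contiguous block, and that any nonzero label counts as foreground (a per-organ Dice would apply the same formula to each class separately). The toy volume is made up for illustration.

```python
import numpy as np


def labeled_slice_range(labels):
    """Return (first, last) indices of slices whose label is not blank.

    labels: array of shape (n_slices, H, W). Assumes the labeled slices are contiguous
    and that at least one slice is labeled.
    """
    has_label = labels.reshape(labels.shape[0], -1).max(axis=1) > 0
    labeled = np.flatnonzero(has_label)
    return int(labeled[0]), int(labeled[-1])


def dice(pred, target):
    """Dice = 2 * |A & B| / (|A| + |B|), treating every nonzero label as foreground."""
    pred, target = pred > 0, target > 0
    denom = pred.sum() + target.sum()
    if denom == 0:
        return 1.0  # both masks empty: count as perfect agreement
    return float(2.0 * np.logical_and(pred, target).sum() / denom)


# Toy label volume (made up for illustration): slices 3..7 contain an "organ".
labels = np.zeros((10, 4, 4), dtype=np.uint8)
labels[3:8, 1:3, 1:3] = 1
first, last = labeled_slice_range(labels)
middle = (first + last) // 2
print(first, last, middle)  # 3 7 5

propagated = labels.copy()  # stand-in for the labels propagated by the model
scores = [dice(propagated[i], labels[i]) for i in range(first, last + 1)]
print(scores, np.mean(scores))
```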







# Mount Drive
Mount your Google Drive here to access the data folder. If you don't know what this is or how to do it, see this [link](https://www.geeksforgeeks.org/download-anything-to-google-drive-using-google-colab/#:~:text=To%20import%20google%20drive%2C%20write,run%20it%20by%20Ctrl%2BEnter%20.&text=On%20running%20code%2C%20one%20blue,permission%20to%20access%20google%20drive.)

In [None]:
from google.colab import drive
drive.mount('/content/gdrive', force_remount=True)

Mounted at /content/gdrive


# Installations

In [None]:
!pip install SimpleITK -q
!pip install sacred==0.7.5
!pip install voxelmorph -q

# Imports

In [None]:
import os
import numpy as np
from matplotlib import pyplot as plt
import glob
from tqdm import tqdm
import torch

# Must be set before importing voxelmorph so that it loads its PyTorch backend
os.environ['VXM_BACKEND'] = 'pytorch'
import voxelmorph as vxm

import SimpleITK as sitk
import pickle
import json
import time

# Reading Data

In [None]:
class Patient:
    def __init__(self, id_):
        self.id = id_
        self.frames_count = None
        self.fgmasks = None
        self.images = None
        self.labels = None
        self.superpixs = None  # not loaded by this notebook; used by save_patients_superpixs

    def remove_without_labels(self):
        idx = []
        for frame_number, label in enumerate(self.labels):
            if label.max() > 0:
                idx.append(frame_number)
        self.frames_count = len(idx)
        self.fgmasks = self.fgmasks[idx]
        self.images = self.images[idx]
        self.labels = self.labels[idx]
        
    def print_data_shapes(self):
        print('patient ', self.id)
        print('fgmasks: ', self.fgmasks.shape)
        print('images: ', self.images.shape)
        print('labels: ', self.labels.shape)
        print('-' * 30)

    def plot(self, frame_number):
        fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))
        ax1.imshow(self.images[frame_number], origin='lower', cmap='gray')
        ax2.imshow(self.fgmasks[frame_number], origin='lower', cmap='gray')
        ax3.imshow(self.labels[frame_number], origin='lower', cmap='gray')
        ax1.set_title("image")
        ax2.set_title("fgmask")
        ax3.set_title("label")


def create_patients(ids):
    patients = {}
    for id_ in ids:
        patients[id_] = Patient(id_)
    return patients


In [None]:
def save_patients_images(path):
    """Pickle a {patient id: image volume} dict built from the global `patients`."""
    images = {id_: patient.images for id_, patient in patients.items()}
    with open(path, 'wb') as f:
        pickle.dump(images, f)


def save_patients_superpixs(path):
    """Pickle a {patient id: superpixel volume} dict (requires patient.superpixs to be set)."""
    superpixs = {id_: patient.superpixs for id_, patient in patients.items()}
    with open(path, 'wb') as f:
        pickle.dump(superpixs, f)


def save_patients_labels(path):
    """Pickle a {patient id: label volume} dict built from the global `patients`."""
    labels = {id_: patient.labels for id_, patient in patients.items()}
    with open(path, 'wb') as f:
        pickle.dump(labels, f)


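def read_nii_bysitk(input_fid, peel_info=False):
    """Read a NIfTI file into a NumPy array with SimpleITK.

    If peel_info is True, also return the spacing, origin, direction, and array size.
    Note that the array axes are ordered (z, y, x), the reverse of the (x, y, z)
    order used by GetSpacing() and GetOrigin().
    """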
    img_obj = sitk.ReadImage(input_fid)
    img_np = sitk.GetArrayFromImage(img_obj)
    if peel_info:
        info_obj = {
            "spacing": img_obj.GetSpacing(),
            "origin": img_obj.GetOrigin(),
            "direction": img_obj.GetDirection(),
            "array_size": img_np.shape
        }
        return img_np, info_obj
    else:
        return img_np

In [None]:
patient_ids = [1, 2, 3, 5, 8, 10, 13, 15, 19, 20, 21, 22, 31, 32, 33, 34, 36, 37, 38, 39]
project_root = "/content/gdrive/MyDrive/IABI-F2022/"
data_root = os.path.join(project_root, 'chaos_MR_T2_normalized/')

patients = create_patients(patient_ids)
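for path in tqdm(glob.iglob(data_root + '**/*.nii.gz', recursive=True), desc="Reading"):
    # Assumes file names end with "_<patient id>.nii.gz"
    id_ = int(path.split('_')[-1].split('.')[0])
    patient = patients.get(id_)
    if patient is None:
        continue  # skip files of patients that are not in patient_ids
    obj = read_nii_bysitk(path)
    if 'fgmask' in path:
        patient.fgmasks = obj
        patient.frames_count = obj.shape[0]
    elif 'image' in path:
        patient.images = obj
    elif 'label' in path:
        patient.labels = obj
    # Any other file (e.g., the superpixels) is read but not stored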

Reading: 80it [00:38,  2.08it/s]
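The reading loop above takes the patient id from the part of the path after the last underscore. A slightly more defensive version of that parsing, which looks only at the file name and strips the `.nii.gz` extension explicitly, is sketched below. The helper name is hypothetical, and the example file names (`image_39.nii.gz`, etc.) are assumptions about the dataset's naming, not verified against it.

```python
import os


def patient_id_from_path(path):
    """Return the patient id from a file name of the assumed form '<kind>_<id>.nii.gz'."""
    name = os.path.basename(path)  # ignore underscores in the directory names
    if name.endswith('.nii.gz'):
        name = name[:-len('.nii.gz')]
    return int(name.rsplit('_', 1)[-1])


print(patient_id_from_path('/content/data/chaos_MR_T2_normalized/image_39.nii.gz'))  # 39
```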


# Visualization
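`Patient.plot` shows one slice at a time. To get an overview of a whole volume, one option is to tile evenly spaced slices into a single image. The helper below is a hypothetical NumPy sketch (the name `slice_montage` and its defaults are illustrative); its output can be shown with a single `plt.imshow(grid, cmap='gray')`.

```python
import numpy as np


def slice_montage(volume, n_slices=6, n_cols=3):
    """Tile evenly spaced slices of a (D, H, W) volume into one 2D image.

    Returns the tiled image and the indices of the slices used.
    """
    depth, h, w = volume.shape
    idx = np.linspace(0, depth - 1, n_slices).round().astype(int)
    n_rows = int(np.ceil(n_slices / n_cols))
    grid = np.zeros((n_rows * h, n_cols * w), dtype=volume.dtype)
    for k, z in enumerate(idx):
        r, c = divmod(k, n_cols)  # position of tile k in the grid
        grid[r * h:(r + 1) * h, c * w:(c + 1) * w] = volume[z]
    return grid, idx


# Toy volume in which every pixel of slice z has the value z.
volume = np.arange(10)[:, None, None] * np.ones((10, 4, 5))
grid, idx = slice_montage(volume)
print(grid.shape)  # (8, 15)
```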

# VoxelMorph
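For reference, the unsupervised VoxelMorph objective combines the two losses from the training requirements. In the notation below, $f$ is the fixed image, $m$ the moving image, $u$ the displacement field predicted by the network, $\phi = \mathrm{Id} + u$ the resulting deformation, $\Omega$ the set of pixel locations, and $\lambda$ the smoothness weight. MSE is shown as the similarity term; other choices such as NCC are also common.

```latex
\begin{aligned}
\mathcal{L}(f, m, \phi) &= \mathcal{L}_{\mathrm{sim}}(f,\, m \circ \phi) + \lambda\, \mathcal{L}_{\mathrm{smooth}}(\phi), \\
[m \circ \phi](p) &= m\big(p + u(p)\big), \\
\mathcal{L}_{\mathrm{sim}}(f,\, m \circ \phi) &= \frac{1}{|\Omega|} \sum_{p \in \Omega} \big(f(p) - [m \circ \phi](p)\big)^2, \\
\mathcal{L}_{\mathrm{smooth}}(\phi) &= \sum_{p \in \Omega} \lVert \nabla u(p) \rVert^2 .
\end{aligned}
```

Implementations often average the smoothness term over pixels instead of summing it; the constant factor is absorbed into $\lambda$.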

# Data Set and Data Loader
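Since training is unsupervised, a training sample is just a (moving, fixed) pair of slices. The sketch below enumerates such pairs within each training patient; `max_gap > 1` adds non-adjacent pairs and `bidirectional=True` adds each pair in both directions, two of the ideas suggested in the training requirements. The function name, its parameters, and the slice counts are hypothetical. A PyTorch `Dataset` could store this list and, in `__getitem__`, return the two corresponding slices from `patients[pid].images`.

```python
import random


def make_slice_pairs(frames_per_patient, max_gap=1, bidirectional=False):
    """List (patient_id, moving_idx, fixed_idx) pairs of slices from the same patient.

    frames_per_patient: dict mapping patient id -> number of slices.
    Pairs are at most `max_gap` slices apart; with bidirectional=True each pair is
    also added in the reverse direction.
    """
    pairs = []
    for pid, n in frames_per_patient.items():
        for i in range(n):
            for gap in range(1, max_gap + 1):
                j = i + gap
                if j >= n:
                    break
                pairs.append((pid, i, j))
                if bidirectional:
                    pairs.append((pid, j, i))
    return pairs


# Hypothetical slice counts; patient 39 is held out for testing.
frames = {1: 5, 2: 4, 39: 6}
train_frames = {pid: n for pid, n in frames.items() if pid != 39}
pairs = make_slice_pairs(train_frames, max_gap=2, bidirectional=True)
random.Random(0).shuffle(pairs)  # shuffle once; a DataLoader with shuffle=True also works
print(len(pairs))  # 24
```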

# Data Loader Visualization
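Besides plotting the moving and fixed slices of a batch side by side, a checkerboard composite makes misalignment easy to spot: edges break at the tile borders wherever the two images disagree. The helper below is a hypothetical NumPy sketch, and the tile size is an arbitrary choice.

```python
import numpy as np


def checkerboard(img_a, img_b, tile=16):
    """Interleave square tiles of two equally sized 2D images."""
    h, w = img_a.shape
    rows = (np.arange(h) // tile)[:, None]  # tile row index of each pixel row
    cols = (np.arange(w) // tile)[None, :]  # tile column index of each pixel column
    mask = (rows + cols) % 2 == 0           # True on the "even" tiles
    return np.where(mask, img_a, img_b)


fixed_slice = np.zeros((4, 4))
moved_slice = np.ones((4, 4))
print(checkerboard(fixed_slice, moved_slice, tile=2))
```

The same composite is useful after training, applied to (moved, fixed) pairs, to judge how well the model aligned them.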

# Model and Optimizer
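Assuming the default U-Net configuration of `vxm.networks.VxmDense`, the encoder halves the resolution four times, so each side of `inshape` should be divisible by 16. If your slices do not satisfy this, one option is to zero-pad them, as in the hypothetical helper below, and crop the outputs back to the original shape. Check the features configured in your installed voxelmorph version before relying on the factor 16.

```python
import numpy as np


def pad_to_multiple(image, multiple=16):
    """Zero-pad a 2D image at the bottom/right so each side is a multiple of `multiple`.

    Returns the padded image and the original shape (for cropping outputs back).
    """
    h, w = image.shape
    pad_h = (-h) % multiple  # rows needed to reach the next multiple
    pad_w = (-w) % multiple  # columns needed to reach the next multiple
    padded = np.pad(image, ((0, pad_h), (0, pad_w)), mode='constant')
    return padded, (h, w)


img = np.ones((250, 200))
padded, orig_shape = pad_to_multiple(img)
print(padded.shape)  # (256, 208)
restored = padded[:orig_shape[0], :orig_shape[1]]  # crop back after registration
```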

# Spatial Transformer
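The spatial transformer resamples the moving image at the displaced locations $p + u(p)$. The NumPy sketch below illustrates this for a single 2D image with bilinear interpolation. It assumes a flow of shape `(2, H, W)` with channel 0 holding row displacements and channel 1 holding column displacements, and it fills out-of-bounds samples with zeros. It is meant to show what a layer such as `vxm.layers.SpatialTransformer` computes, to our understanding, not to reproduce its implementation (which operates on tensors via `grid_sample`).

```python
import numpy as np


def warp_bilinear(image, flow):
    """Warp a 2D image with a dense displacement field: moved[p] = image[p + u(p)].

    flow has shape (2, H, W): flow[0] holds row displacements, flow[1] column
    displacements. Values are bilinearly interpolated; outside samples count as 0.
    """
    h, w = image.shape
    yy, xx = np.meshgrid(np.arange(h), np.arange(w), indexing='ij')
    y, x = yy + flow[0], xx + flow[1]  # sampling locations p + u(p)
    y0, x0 = np.floor(y).astype(int), np.floor(x).astype(int)
    wy, wx = y - y0, x - x0  # fractional offsets in [0, 1)

    def sample(r, c):
        # image[r, c] where (r, c) is inside the image, 0 elsewhere
        valid = (r >= 0) & (r < h) & (c >= 0) & (c < w)
        out = np.zeros(r.shape, dtype=float)
        out[valid] = image[r[valid], c[valid]]
        return out

    return ((1 - wy) * (1 - wx) * sample(y0, x0)
            + (1 - wy) * wx * sample(y0, x0 + 1)
            + wy * (1 - wx) * sample(y0 + 1, x0)
            + wy * wx * sample(y0 + 1, x0 + 1))


img = np.arange(16, dtype=float).reshape(4, 4)
shift = np.zeros((2, 4, 4))
shift[1] = 1.0  # sample one column to the right of each pixel
print(np.allclose(warp_bilinear(img, np.zeros((2, 4, 4))), img))  # True
print(warp_bilinear(img, shift)[0])  # [1. 2. 3. 0.]
```

For label maps, nearest-neighbor sampling is used instead of bilinear interpolation so that label values stay integers.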

# Training
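When choosing the smoothness weight by trial and error, loss curves alone do not show whether the predicted fields are plausible. One common extra check (not required by the assignment) is the fraction of pixels where the Jacobian determinant of $\phi(p) = p + u(p)$ is non-positive, which indicates local folding. The sketch assumes a `(2, H, W)` flow with channel 0 for row and channel 1 for column displacements; the function names are hypothetical.

```python
import numpy as np


def jacobian_determinant(flow):
    """Per-pixel Jacobian determinant of phi(p) = p + u(p) for a (2, H, W) field u.

    Derivatives use np.gradient (central differences inside, one-sided at borders).
    """
    du0_drow, du0_dcol = np.gradient(flow[0])  # derivatives of the row displacement
    du1_drow, du1_dcol = np.gradient(flow[1])  # derivatives of the column displacement
    return (1 + du0_drow) * (1 + du1_dcol) - du0_dcol * du1_drow


def folding_fraction(flow):
    """Fraction of pixels with a non-positive Jacobian determinant (local folding)."""
    return float(np.mean(jacobian_determinant(flow) <= 0))


zero = np.zeros((2, 32, 32))
rough = np.random.default_rng(0).normal(scale=3.0, size=(2, 32, 32))
print(folding_fraction(zero))       # 0.0: identity mapping, no folding
print(folding_fraction(rough) > 0)  # True: a very irregular field folds
```

Comparing this fraction across candidate weights, alongside the similarity loss, gives a second criterion for the trade-off.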

# Testing
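Step 4 of the testing procedure can be organized as a sequential propagation: starting from the middle slice, register each already-labeled slice to its neighbor, warp its label with nearest-neighbor sampling (so label values stay integers), and continue outward in both directions. In the sketch below, `predict_flow` is a hypothetical stand-in for the trained model; with voxelmorph's PyTorch `VxmDense`, calling the model with `registration=True` should return the moved image and the full displacement field, but check this against your installed version. The convention `out[p] = label[p + u(p)]` with a `(2, H, W)` flow is chosen to match, to our understanding, voxelmorph's spatial transformer.

```python
import numpy as np


def warp_nearest(label, flow):
    """Warp a 2D label map with nearest-neighbor sampling: out[p] = label[p + u(p)].

    flow has shape (2, H, W) (channel 0: rows, channel 1: columns); samples that
    fall outside the image are set to 0 (background).
    """
    h, w = label.shape
    yy, xx = np.meshgrid(np.arange(h), np.arange(w), indexing='ij')
    r = np.rint(yy + flow[0]).astype(int)
    c = np.rint(xx + flow[1]).astype(int)
    valid = (r >= 0) & (r < h) & (c >= 0) & (c < w)
    out = np.zeros_like(label)
    out[valid] = label[r[valid], c[valid]]
    return out


def propagate_labels(images, middle_label, middle, predict_flow):
    """Propagate the middle slice's label outward through a (D, H, W) volume.

    predict_flow(moving, fixed) stands in for the trained model and must return a
    (2, H, W) field that warps `moving` onto `fixed`.
    """
    labels = np.zeros(images.shape, dtype=middle_label.dtype)
    labels[middle] = middle_label
    for step in (1, -1):  # first towards higher slice indices, then towards lower ones
        k = middle + step
        while 0 <= k < len(images):
            flow = predict_flow(images[k - step], images[k])  # labeled neighbor -> slice k
            labels[k] = warp_nearest(labels[k - step], flow)
            k += step
    return labels


def identity_model(moving, fixed):
    """Toy 'model' that predicts no motion, used only to exercise the code."""
    return np.zeros((2,) + fixed.shape)


images = np.random.default_rng(0).random((5, 8, 8))
mid_label = np.zeros((8, 8), dtype=np.int64)
mid_label[2:5, 2:5] = 1
out = propagate_labels(images, mid_label, middle=2, predict_flow=identity_model)
print(all((out[k] == mid_label).all() for k in range(5)))  # True
```

Registration errors can accumulate as the distance from the middle slice grows, which is one reason to report the Dice score slice by slice.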