# 01 Image registration helper code
Code based on https://www.kaggle.com/boojum/connecting-voxel-spaces/notebook

Two approaches are outlined in the reference notebook:
1. Transform the scans into scanner space using the affine matrix defined in the DICOM metadata, then resample them to match one channel.
    This approach is slower, and any disruption during scanning or error in the metadata input can lead to incorrect results.
2. SimpleITK image registration.
    This approach is more general.

Backlog: 
1. **(DONE)** view a sample of img_registration --> scaling --> crop pipeline
2. **(DONE)** image registration function: take reference channel as argument
3. **(DONE)** image registration function: re-order voxels if needed 
4. check the value distribution against the metadata (MR Acquisition and Pixel Encoding)
5. log scaling

In [None]:
import os
import sys 

#from tqdm import tqdm
from pathlib import Path
#from PIL import Image

import pydicom
import numpy as np
np.set_printoptions(precision=4)
np.set_printoptions(suppress=True)

import pandas as pd

import plotly.express as px
from ipywidgets import widgets
from IPython.display import display, clear_output, Image
#import matplotlib.pyplot as plt
#import matplotlib.colors

#import nibabel as nib
import SimpleITK as sitk
sitk.ProcessObject_SetGlobalWarningDisplay(False)

train_path = '../input/rsna-miccai-brain-tumor-radiogenomic-classification/train/'
# One entry per study directory. os.listdir returns entries in an arbitrary,
# filesystem-dependent order; sorting makes positional picks such as
# train_dirs[2] select the same study on every machine.
train_dirs = sorted(os.listdir(train_path))

# Split sub-directory of data_root that study directories are read from.
DATASET = 'train'
# Scan types in the channel order used for stacked volumes.
scan_types = ['FLAIR','T1w','T1wCE','T2w']
data_root = Path("../input/rsna-miccai-brain-tumor-radiogenomic-classification")

In [None]:
def resample(image, ref_image, default_value=0.0, interpolator=None):
    """Resample ``image`` onto the voxel grid of ``ref_image``.

    An identity transform is used. Only the sampling grid changes, not the
    physical position of the data. This assumes both images already share a
    physical coordinate frame (e.g. scanner space from the DICOM headers).

    Args:
        image: SimpleITK image to resample.
        ref_image: SimpleITK image whose origin, spacing, direction and size
            define the output grid.
        default_value: intensity written to output voxels that fall outside
            the extent of ``image``. Defaults to 0 (background).
        interpolator: SimpleITK interpolator enum. ``None`` means
            ``sitk.sitkLinear``.

    Returns:
        A new SimpleITK image sampled on the grid of ``ref_image``.
    """
    if interpolator is None:
        interpolator = sitk.sitkLinear

    resampler = sitk.ResampleImageFilter()
    # SetReferenceImage copies origin, spacing, direction and size from
    # ref_image, so the output grid does not need to be set field by field.
    resampler.SetReferenceImage(ref_image)
    resampler.SetInterpolator(interpolator)
    # Identity transform: both images are assumed to live in the same space.
    resampler.SetTransform(sitk.AffineTransform(image.GetDimension()))
    # Fill value for voxels outside `image`. This must be an intensity;
    # image.GetPixelIDValue() is the pixel-type enum, not a background value.
    resampler.SetDefaultPixelValue(default_value)
    return resampler.Execute(image)


def image_registration(study_id, ref_channel='FLAIR'):
    
    reader = sitk.ImageSeriesReader()
    reader.LoadPrivateTagsOn()
    
    filenamesDICOM = reader.GetGDCMSeriesFileNames(str(data_root.joinpath(DATASET,study_id,'FLAIR')))
    reader.SetFileNames(filenamesDICOM)
    flair_sitk = reader.Execute()
    
    filenamesDICOM = reader.GetGDCMSeriesFileNames(str(data_root.joinpath(DATASET,study_id,'T1w')))
    reader.SetFileNames(filenamesDICOM)
    t1_sitk = reader.Execute()
    
    filenamesDICOM = reader.GetGDCMSeriesFileNames(str(data_root.joinpath(DATASET,study_id,'T1wCE')))
    reader.SetFileNames(filenamesDICOM)
    t1wce_sitk = reader.Execute()

    filenamesDICOM = reader.GetGDCMSeriesFileNames(str(data_root.joinpath(DATASET,study_id,'T2w')))
    reader.SetFileNames(filenamesDICOM)
    t2_sitk = reader.Execute()

    ref_sitk_dct = {
        'FLAIR': flair_sitk,
        'T1w': t1_sitk,
        'T1wCE': t1wce_sitk,
        'T2w': t2_sitk,
    }
    ref_sitk = ref_sitk_dct[ref_channel]
    
    flair_resampled = resample(flair_sitk, ref_sitk)
    t1_resampled = resample(t1_sitk, ref_sitk)
    t1wce_resampled = resample(t1wce_sitk, ref_sitk)
    t2_resampled = resample(t2_sitk, ref_sitk)
    
    flair_array = sitk.GetArrayFromImage(flair_resampled)
    t1_array = sitk.GetArrayFromImage(t1_resampled) 
    t1wce_array = sitk.GetArrayFromImage(t1wce_resampled) 
    t2_array = sitk.GetArrayFromImage(t2_resampled)
    
    stacked = np.stack([flair_array, t1_array, t1wce_array, t2_array,])
    
    ref_dir = ref_sitk.GetDirection()
    ref_dir = np.abs(np.round(np.array(ref_dir)))
    
    if np.array_equal(ref_dir, np.array([1, 0, 0, 0, 0, 1, 0, 1, 0,])):
        return stacked.transpose((0, 2, 1, 3))
    elif np.array_equal(ref_dir, np.array([1, 0, 0, 0, 1, 0, 0, 0, 1,])):
        return stacked
    elif np.array_equal(ref_dir, np.array([0, 0, 1, 1, 0, 0, 0, 1, 0,])):
        return stacked.transpose((0, 2, 3, 1))
    

def min_max_scaling(img):
    """Rescale each channel (index along axis 0) of ``img`` to the range [0, 1].

    The minimum and maximum are taken per channel over all remaining axes.

    Args:
        img: array with channels on axis 0 and any number of further axes.

    Returns:
        Floating-point array of the same shape. Integer input is converted to
        float64 first, which also avoids overflow in ``max - min`` (e.g. int16
        spanning its full range). A constant channel maps to all zeros instead
        of NaN.
    """
    img = np.asarray(img)
    if not np.issubdtype(img.dtype, np.floating):
        img = img.astype(np.float64)
    axes = tuple(range(1, img.ndim))
    mins = img.min(axis=axes, keepdims=True)
    maxs = img.max(axis=axes, keepdims=True)
    spans = maxs - mins
    # Avoid 0/0 for constant channels; (img - mins) is 0 there anyway.
    spans = np.where(spans > 0, spans, 1)
    return (img - mins) / spans


# def crop(img, min_threshold = 0.0001):
#     if img.sum() == 0:
#         return voxel
    
#     keep = (img.mean(axis=(0, 1, 2)) > min_threshold)
#     img = img[:, :, :, keep]
    
#     keep = (img.mean(axis=(0, 1, 3)) > min_threshold)
#     img = img[:, :, keep, :]
    
#     keep = (img.mean(axis=(0, 2, 3)) > min_threshold)
#     img = img[:, keep, :, :]
#     return img


def crop(img, scan_type, min_threshold=0):
    """Crop a channel-first 4-D stack to the content of one reference channel.

    Along each of the three spatial axes, slices whose mean intensity in the
    reference channel does not exceed ``min_threshold`` are dropped from every
    channel.

    Args:
        img: array of shape (4, A0, A1, A2) with channels ordered FLAIR, T1w,
            T1wCE, T2w.
        scan_type: name of the reference channel.
        min_threshold: a slice is kept only if its reference-channel mean is
            above this value. The default of 0 keeps every slice containing a
            positive voxel (for non-negative data). A small positive value
            tolerates low-level values such as interpolation residue.

    Returns:
        The cropped array. If no slice of the reference channel passes the
        threshold (e.g. an all-zero volume), ``img`` is returned unchanged
        instead of an empty array.

    Raises:
        KeyError: if ``scan_type`` is not a known channel name.
    """
    channel_map = {
        'FLAIR': 0,
        'T1w': 1,
        'T1wCE': 2,
        'T2w': 3,
    }
    c = channel_map[scan_type]

    keep = img[c].mean(axis=(0, 1)) > min_threshold
    if not keep.any():
        # Nothing to anchor the crop on. If this first mask is non-empty, the
        # kept region's mean exceeds the threshold, so the later masks are
        # guaranteed to be non-empty too.
        return img
    img = img[:, :, :, keep]

    keep = img[c].mean(axis=(0, 2)) > min_threshold
    img = img[:, :, keep, :]

    keep = img[c].mean(axis=(1, 2)) > min_threshold
    img = img[:, keep, :, :]
    return img

def log_scaling(img, floor_offset=100, upper_percentile=99, verbose=False):
    """Log-scale each channel (index along axis 0) of ``img`` into [0, 1].

    Per channel, ``log(img + 1)`` is mapped linearly so that
    ``log(channel_min + floor_offset)`` maps to 0 and
    ``log(percentile(channel, upper_percentile) + 1)`` maps to 1. The result
    is then clipped to [0, 1].

    NOTE(review): the lower bound adds ``floor_offset`` (default 100) while the
    data adds 1. As a result, voxels below ``channel_min + floor_offset - 1``
    are clipped to 0. This looks like a deliberate noise floor; confirm.

    Args:
        img: array with channels on axis 0 and any number of further axes.
        floor_offset: offset added to the channel minimum for the lower bound.
        upper_percentile: percentile (0-100) used as the upper bound.
        verbose: print the per-channel log bounds (debug output).

    Returns:
        Floating-point array of the same shape with values in [0, 1]. A
        channel whose upper bound does not exceed its lower bound is all 0.
    """
    img = np.asarray(img)
    if not np.issubdtype(img.dtype, np.floating):
        # Avoid integer overflow in `img + 1` (e.g. uint16 at 65535).
        img = img.astype(np.float64)
    axes = tuple(range(1, img.ndim))
    lower = np.log(img.min(axis=axes, keepdims=True) + floor_offset)
    upper = np.log(np.percentile(img, upper_percentile, axis=axes, keepdims=True) + 1)
    if verbose:
        print(lower)
        print(upper)
    # A non-positive span would flip or blow up the mapping; an infinite span
    # maps the whole channel to 0 instead.
    span = np.where(upper > lower, upper - lower, np.inf)
    return np.clip((np.log(img + 1) - lower) / span, 0, 1)

In [None]:
# from matplotlib import animation, rc
# rc('animation', html='jshtml')


# def display_img(img):
#     rgb = img[:,img.shape[1]//2,:,:].transpose(1,2,0)
#     im = Image.fromarray((rgb * 255).astype(np.uint8))
#     return im


# def get_slice(img, n):
#     rgb = img[:,n].transpose((1,2,0))
#     return (rgb * 255).astype(np.uint8)


# def get_channel(img, n_slice, channel):
#     rgb = img[channel,n_slice]
#     return (rgb * 255).astype(np.uint8)


# def anim_breakdown(ims, title=''):
#     fig, (ax1, ax2, ax3, ax4) = plt.subplots(1,4 ,figsize=(24,6))
    
#     plt.axis('off')
#     fig.patch.set_facecolor('white')

#     im1 = ax1.imshow(get_channel(ims, n_slice=0, channel=0))
#     im2 = ax2.imshow(get_channel(ims, n_slice=0, channel=1))
#     im3 = ax3.imshow(get_channel(ims, n_slice=0, channel=2))
#     im4 = ax4.imshow(get_channel(ims, n_slice=0, channel=3))
    
#     ax1.title.set_text('FLAIR')
#     ax2.title.set_text('T1w')
#     ax3.title.set_text('T1wCE')
#     ax4.title.set_text('T2w')
   

#     def animate_func(i):
#         im1.set_array(get_channel(ims, n_slice=i, channel=0))
#         im2.set_array(get_channel(ims, n_slice=i, channel=1))
#         im3.set_array(get_channel(ims, n_slice=i, channel=2))
#         im4.set_array(get_channel(ims, n_slice=i, channel=3))
#         return fig
    
#     plt.close()
#     return animation.FuncAnimation(fig, animate_func, frames = ims.shape[1], interval = 1000//24)

In [None]:
#ref_scantype = 'T1wCE'
#reg_im = image_registration(train_dirs[2], ref_scantype)
#print(f'{ref_scantype} aligned image {reg_im.shape}')
#anim_breakdown(crop(min_max_scaling(reg_im)))

Cropping is based on https://www.kaggle.com/ren4yu/normalized-voxels-align-planes-and-crop
Since linear interpolation is used, a threshold is required instead of checking for exact zeros.

Animation code based on https://www.kaggle.com/ihelon/brain-tumor-eda-with-animations-and-modeling

In [None]:
def show_all_alignments(study_path):
    """Display one study aligned to each of its scan types in turn.

    For every reference channel in ``scan_types``, the study is registered
    onto that channel's grid, min-max scaled and cropped. It is then shown as
    a Plotly figure with one facet per channel (axis 0) and one animation
    frame per index along axis 1.

    Args:
        study_path: study directory name, passed to ``image_registration``.
    """
    for reference in scan_types:
        registered = image_registration(study_path, reference)
        volume = crop(min_max_scaling(registered), reference)

        figure = px.imshow(
            volume,
            facet_col=0,
            animation_frame=1,
            binary_string=True,
            title=f'Reference Channel = {reference}, Dimensions = {volume.shape}',
        )

        # Label each facet with its scan type instead of the default facet title.
        for position, channel_name in enumerate(scan_types):
            figure.layout.annotations[position]['text'] = channel_name

        figure.show()

# Preview the alignments for one example study from the training split.
example_study = train_dirs[2]
show_all_alignments(example_study)