# Test Code for Registration and Synthesis

## Install Packages 

In [None]:
!pip install antspyx

In [None]:
!pip insatll SimpleITK

In [None]:
from numpy import spacing
import ants
import os
import torch
import shutil
import SimpleITK as sitk
import nibabel as nib
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.transforms import Compose
import time
import matplotlib.pyplot as plt
import glob
from tqdm import tqdm
import matplotlib
from scipy.ndimage import zoom
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader

## Task A: Skull Stripping

### Clone this repository:

In [None]:
# Download the HD-BET skull-stripping tool (git places it in a ./HD-BET folder)
!git clone https://github.com/MIC-DKFZ/HD-BET.git

### Go into the repository (the folder with the setup.py file) and install:

In [None]:
!cd HD-BET

In [None]:
!pip install -e .

### Run HD-BET on your dataset to remove the skull region:

In [None]:
'''
INPUT_FOLDER: Raw subject folder path (need to remove skull region);
OUTPUT_FOLDER: Skull stripping subject folder path.
'''
!hd-bet -i INPUT_FOLDER -o OUTPUT_FOLDER

## Task B: Registration

### Rigidly register T2w to T1w; rigidly and non-rigidly register FA and ADC to T1w

In [41]:
# Input folder (skull-stripped images) and output folder -- set these to your own paths
input_path = '/skull_removed_folder_path'
output_path = '/output_path'

os.makedirs(output_path, exist_ok=True)

# Define the paths to the input images
t1_path = f"{input_path}/T1w_1mm.nii.gz"
t2_path = f"{input_path}/T2w_1mm_noalign.nii.gz"
fa_path = f"{input_path}/FA_deformed.nii.gz"
adc_path = f"{input_path}/ADC_deformed.nii.gz"

# Load the input images using ANTs
t1 = ants.image_read(t1_path)
t2 = ants.image_read(t2_path)
fa = ants.image_read(fa_path)
adc = ants.image_read(adc_path)

# Rigidly align T2w, FA and ADC to T1w (fixed image first, moving image second).
# ants.registration has no `cost_function` or `device` parameter (ANTs runs on the
# CPU), so the mutual-information metric is requested through `aff_metric`
# ('mattes' = Mattes mutual information).
t2_rigid = ants.registration(t1, t2, type_of_transform='Rigid', aff_metric='mattes')
fa_rigid = ants.registration(t1, fa, type_of_transform='Rigid', aff_metric='mattes')
adc_rigid = ants.registration(t1, adc, type_of_transform='Rigid', aff_metric='mattes')

# Non-rigidly (SyN) refine the rigidly aligned FA and ADC against T1w
fa_nonrigid = ants.registration(t1, fa_rigid['warpedmovout'], type_of_transform='SyN')
adc_nonrigid = ants.registration(t1, adc_rigid['warpedmovout'], type_of_transform='SyN')

# Save T2w (rigid only), FA and ADC (rigid + SyN), and a copy of the T1w reference
ants.image_write(t2_rigid['warpedmovout'], f"{output_path}/T2w_registered.nii.gz")
ants.image_write(fa_nonrigid['warpedmovout'], f"{output_path}/FA_registered.nii.gz")
ants.image_write(adc_nonrigid['warpedmovout'], f"{output_path}/ADC_registered.nii.gz")
shutil.copy(t1_path, f"{output_path}/T1w_1mm.nii.gz")


### Resample the registered FA and ADC from dimension (182, 218, 182) to (145, 174, 145), and from 1 mm to 1.25 mm resolution

#### First, use ANTs resampling to change the dimensions

In [41]:
if not os.path.exists(output_path):
    os.makedirs(output_path)

# use_voxels=True: (145, 174, 145) is the target grid size in voxels, not a spacing;
# interp_type=4 selects B-spline interpolation.
target_shape = (145, 174, 145)
FA_resampled_image = ants.resample_image(fa_nonrigid['warpedmovout'], target_shape, use_voxels=True, interp_type=4)
ADC_resampled_image = ants.resample_image(adc_nonrigid['warpedmovout'], target_shape, use_voxels=True, interp_type=4)

# Write the resampled maps; T2w is kept at 1 mm and only copied under its "align" name
ants.image_write(FA_resampled_image, f"{output_path}/FA_ants_resample.nii.gz")
ants.image_write(ADC_resampled_image, f"{output_path}/ADC_ants_resample.nii.gz")
shutil.copy(f"{output_path}/T2w_registered.nii.gz", f"{output_path}/T2w_align.nii.gz")

#### Second, use SimpleITK to set the resolution (voxel spacing) to 1.25 mm

In [41]:
submit_path = '/registration_submit_path'

if not os.path.exists(submit_path):
    os.makedirs(submit_path)

# Read the ANTs-resampled maps with the NIfTI ImageIO selected explicitly (adjust paths if needed)
FA_image = sitk.ReadImage(f"{output_path}/FA_ants_resample.nii.gz", imageIO='NiftiImageIO')
ADC_image = sitk.ReadImage(f"{output_path}/ADC_ants_resample.nii.gz", imageIO='NiftiImageIO')

# Overwrite only the voxel-spacing metadata with exactly 1.25 mm; voxel values,
# origin and direction are left untouched (no interpolation happens here).
for img in (FA_image, ADC_image):
    img.SetSpacing([1.25, 1.25, 1.25])

# Write the images for the registration submission
sitk.WriteImage(FA_image, f"{submit_path}/FA_align.nii.gz")
sitk.WriteImage(ADC_image, f"{submit_path}/ADC_align.nii.gz")

## Task C: Synthesis

### Pad T1w_1mm and T2w_registered to (192, 224, 192) for our network

In [41]:
t1_path = f"{output_path}/T1w_1mm.nii.gz"
t2_path = f"{output_path}/T2w_registered.nii.gz"

# Load each volume once and keep the image object so its affine can be reused
t1w_nii = nib.load(t1_path)
t2w_nii = nib.load(t2_path)
t1w_image = t1w_nii.get_fdata()
t2w_image = t2w_nii.get_fdata()

# Get the current shape of the images
t1w_shape = t1w_image.shape
t2w_shape = t2w_image.shape

# Smallest padding that makes every dimension a multiple of 32 (e.g. 182 -> 192,
# 218 -> 224). The UNet pools four times, so it needs sizes divisible by 16.
t1w_padding = tuple((32 - dim % 32) % 32 for dim in t1w_shape)
t2w_padding = tuple((32 - dim % 32) % 32 for dim in t2w_shape)

# Zero-pad only at the end of each axis so voxel (0, 0, 0) and the affine stay
# valid and the prediction can later be un-padded by simple cropping.
t1w_padded = np.pad(t1w_image, [(0, pad) for pad in t1w_padding], mode='constant')
t2w_padded = np.pad(t2w_image, [(0, pad) for pad in t2w_padding], mode='constant')

# Create new Nifti images with the padded data and the original affines
t1w_padded_nii = nib.Nifti1Image(t1w_padded, t1w_nii.affine)
t2w_padded_nii = nib.Nifti1Image(t2w_padded, t2w_nii.affine)

# Save the padded images in nii.gz format in the output folder
nib.save(t1w_padded_nii, f"{output_path}/T1w_1mm_padded.nii.gz")
nib.save(t2w_padded_nii, f"{output_path}/T2w_registered_padded.nii.gz")

### Define the Network - UNet 

In [None]:
def conv_block(in_channels, out_channels):
    """Build two (3x3 conv -> BatchNorm -> ReLU) stages as one ``nn.Sequential``.

    The 3x3 convolutions use padding 1, so height and width are preserved.
    Layer order (indices 0-5) is Conv, BN, ReLU, Conv, BN, ReLU, which fixes
    the state_dict key names used by saved checkpoints.

    :param in_channels: channels entering the first convolution
    :param out_channels: channels produced by both convolutions
    :return: the assembled ``nn.Sequential``
    """
    layers = []
    # The first stage maps in_channels -> out_channels; the second keeps out_channels
    for stage_in in (in_channels, out_channels):
        layers += [
            nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
    return nn.Sequential(*layers)

class UNet(nn.Module):
    """2D U-Net mapping a 2-channel slice (T1w, T2w) to a 1-channel map.

    Four encoder stages (16, 32, 64, 128 channels), each followed by 2x2 max
    pooling, then a 256-channel bridge and four decoder stages. Each decoder
    stage upsamples with a stride-2 transposed convolution and concatenates the
    matching encoder output (skip connection). Because of the four poolings,
    input height and width must be divisible by 16; otherwise the upsampled
    tensors do not match the skip tensors in ``torch.cat``.
    """

    def __init__(self):
        super().__init__()
        # Encoder widths: 2 -> 16 -> 32 -> 64 -> 128
        enc_widths = [16, 32, 64, 128]
        self.encoder = nn.ModuleList(
            [conv_block(c_in, c_out) for c_in, c_out in zip([2] + enc_widths[:-1], enc_widths)]
        )
        self.pool = nn.MaxPool2d(2)
        self.bridge = conv_block(128, 256)
        # Decoder widths: 256 -> 128 -> 64 -> 32 -> 16. Each decoder stage sees the
        # upsampled features (c_out) concatenated with the skip (c_out) = c_in channels.
        dec_widths = [256, 128, 64, 32, 16]
        dec_pairs = list(zip(dec_widths[:-1], dec_widths[1:]))
        self.decoder = nn.ModuleList([conv_block(c_in, c_out) for c_in, c_out in dec_pairs])
        self.upconv = nn.ModuleList(
            [nn.ConvTranspose2d(c_in, c_out, kernel_size=2, stride=2) for c_in, c_out in dec_pairs]
        )
        self.final = nn.Conv2d(16, 1, kernel_size=1)

    def forward(self, x):
        skips = []
        for stage in self.encoder:
            x = stage(x)
            skips.append(x)  # kept at full resolution for the skip connection
            x = self.pool(x)
        x = self.bridge(x)
        # The deepest skip pairs with the first decoder stage
        for up, stage, skip in zip(self.upconv, self.decoder, reversed(skips)):
            x = stage(torch.cat([up(x), skip], dim=1))
        return self.final(x)

### Generate the test dataset — run with one T1w and one T2w image at a time

In [41]:
def image_proc_2C(filepath1, filepath2):
    """Load paired volumes and turn them into a stack of 2-channel 2D slices.

    Generates the 2-channel (T1w, T2w) source images used as network input.
    Files are paired after sorting each list independently, so matching
    subjects must sort to the same position in both lists.

    :param filepath1: paths to the first-channel volumes (here: T1w)
    :param filepath2: paths to the second-channel volumes (here: T2w)
    :return: float32 array of shape (total_slices, dim0, dim1, 2), where slices
        are taken along axis 2 of each volume and concatenated across pairs
    :raises ValueError: if no files are given or the two lists differ in length
    """
    files1 = sorted(filepath1)
    files2 = sorted(filepath2)
    # zip() would silently drop unmatched files, so reject mismatched inputs up front
    if len(files1) != len(files2):
        raise ValueError(
            f"image_proc_2C: got {len(files1)} first-channel files but {len(files2)} second-channel files"
        )
    # np.concatenate([]) would fail with an unhelpful message; usually a glob matched nothing
    if not files1:
        raise ValueError("image_proc_2C: no input files were given (check the glob patterns)")

    slabs = []
    for item1, item2 in tqdm(zip(files1, files2), desc='Processing'):
        img1 = nib.load(item1).get_fdata()
        img2 = nib.load(item2).get_fdata()
        # (X, Y, Z) + (X, Y, Z) -> (X, Y, Z, 2); cast now so only float32 copies are kept
        slabs.append(np.stack((img1, img2), axis=-1).astype('float32'))

    # Join all pairs along the slice axis, then move slices first: (slices, X, Y, 2)
    return np.ascontiguousarray(np.moveaxis(np.concatenate(slabs, axis=2), 2, 0))

# Collect the padded source volumes; the patterns are exact file names, so each
# glob matches at most one file (one T1w/T2w pair per run).
dir_list_sc_t1 = sorted(glob.glob(os.path.join(output_path, 'T1w_1mm_padded.nii.gz')))
dir_list_sc_t2 = sorted(glob.glob(os.path.join(output_path, 'T2w_registered_padded.nii.gz')))
img_sc = image_proc_2C(dir_list_sc_t1, dir_list_sc_t2)

# (slices, X, Y, 2) -> (slices, 2, X, Y): channels-first layout expected by nn.Conv2d
img_sc = img_sc.transpose(0, 3, 1, 2)
test_dataset = torch.from_numpy(img_sc).float()

### Load the saved model and test

#### FA

In [None]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

model_FA = UNet()
# map_location remaps tensors saved on a GPU onto `device`, so the checkpoint
# also loads on a CPU-only machine.
model_FA.load_state_dict(torch.load("/path_to_saved_FA_model", map_location=device))
model_FA = model_FA.to(device)
model_FA.eval()

# Slices per forward pass; lower this if the GPU runs out of memory. In eval()
# mode BatchNorm uses its running statistics, so slices do not influence each
# other and the batch size does not change the prediction (up to float rounding).
batch_size = 16

with torch.no_grad():
    test_dataset = test_dataset.to(device)
    output_FA = torch.cat([model_FA(batch) for batch in torch.split(test_dataset, batch_size)])

# (slices, 1, X, Y) -> (slices, X, Y) as a NumPy array on the CPU
output_FA = output_FA.squeeze(1).cpu().numpy()

#### ADC

In [None]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

model_ADC = UNet()
# Choose either saved checkpoint (best model by MAE or best model by loss).
# map_location remaps tensors saved on a GPU onto `device`, so the checkpoint
# also loads on a CPU-only machine.
model_ADC.load_state_dict(torch.load("/path_to_saved_ADC_model", map_location=device))
model_ADC = model_ADC.to(device)
model_ADC.eval()

# Slices per forward pass; lower this if the GPU runs out of memory. In eval()
# mode BatchNorm uses its running statistics, so slices do not influence each
# other and the batch size does not change the prediction (up to float rounding).
batch_size = 16

with torch.no_grad():
    test_dataset = test_dataset.to(device)
    output_ADC = torch.cat([model_ADC(batch) for batch in torch.split(test_dataset, batch_size)])

# (slices, 1, X, Y) -> (slices, X, Y) as a NumPy array on the CPU
output_ADC = output_ADC.squeeze(1).cpu().numpy()

### Reassemble the predicted 2D FA and ADC slices into 3D volumes and unpad the synthesized FA and ADC back to (182, 218, 182)

In [None]:
def reverse_image_proc(image_data):
    """Undo the slice-first layout used for network input.

    Moves axis 0 (the slice axis) back to position 2, so an array of shape
    (slices, X, Y) becomes (X, Y, slices).

    :param image_data: array whose first axis indexes slices (at least 3-D)
    :return: view of ``image_data`` with the slice axis moved to index 2
    """
    return np.moveaxis(image_data, source=0, destination=2)

# The predictions are on the padded grid of T1w_1mm_padded, which was padded only
# at the end of each axis, so cropping to T1w_1mm's shape restores its voxel grid
# and T1w_1mm's affine (orientation, origin, spacing) applies unchanged. Using
# that affine instead of an identity matrix keeps the synthesized maps in the
# same space as the registered T1w/FA/ADC images.
t1_reference = nib.load(f"{output_path}/T1w_1mm.nii.gz")
affine = t1_reference.affine
nx, ny, nz = t1_reference.shape[:3]

# Reverse to the original 3D layout and crop away the padding
FA_reversed_image = reverse_image_proc(output_FA)[:nx, :ny, :nz]
FA_img_nii = nib.Nifti1Image(FA_reversed_image, affine)

ADC_reversed_image = reverse_image_proc(output_ADC)[:nx, :ny, :nz]
ADC_img_nii = nib.Nifti1Image(ADC_reversed_image, affine)

# Saved volumes have T1w_1mm's shape, e.g. (182, 218, 182)
nib.save(FA_img_nii, f"{output_path}/FA_registered_synthsized_need_resample.nii.gz")
nib.save(ADC_img_nii, f"{output_path}/ADC_registered_synthsized_need_resample.nii.gz")

### Resample the unpadded synthesized FA and ADC to the original dimensions (145, 174, 145) and 1.25 mm resolution

#### First, use the ANTs resampling method

In [41]:
# Synthesized maps on the 1 mm grid, (182, 218, 182) voxels
FA_image = ants.image_read(f"{output_path}/FA_registered_synthsized_need_resample.nii.gz")
ADC_image = ants.image_read(f"{output_path}/ADC_registered_synthsized_need_resample.nii.gz")

# use_voxels=True: (145, 174, 145) is the target grid size in voxels, not a spacing;
# interp_type=4 selects B-spline interpolation.
target_shape = (145, 174, 145)
FA_resampled_image = ants.resample_image(FA_image, target_shape, use_voxels=True, interp_type=4)
ADC_resampled_image = ants.resample_image(ADC_image, target_shape, use_voxels=True, interp_type=4)

# Write the resampled maps for the SimpleITK spacing step
ants.image_write(FA_resampled_image, f"{output_path}/FA_syn_ants_resample.nii.gz")
ants.image_write(ADC_resampled_image, f"{output_path}/ADC_syn_ants_resample.nii.gz")

#### Second, use SimpleITK to set the resolution (voxel spacing) to 1.25 mm

In [41]:
submitted_path = '/synthesis_submit_folder_path'

if not os.path.exists(submitted_path):
    os.makedirs(submitted_path)

# Read the ANTs-resampled synthesized maps with the NIfTI ImageIO selected explicitly (adjust paths if needed)
FA_image = sitk.ReadImage(f"{output_path}/FA_syn_ants_resample.nii.gz", imageIO='NiftiImageIO')
ADC_image = sitk.ReadImage(f"{output_path}/ADC_syn_ants_resample.nii.gz", imageIO='NiftiImageIO')

# Overwrite only the voxel-spacing metadata with exactly 1.25 mm; voxel values,
# origin and direction are left untouched (no interpolation happens here).
for img in (FA_image, ADC_image):
    img.SetSpacing([1.25, 1.25, 1.25])

# Write the images for the synthesis submission
sitk.WriteImage(FA_image, f"{submitted_path}/FA_syn.nii.gz")
sitk.WriteImage(ADC_image, f"{submitted_path}/ADC_syn.nii.gz")

### Finally, apply window/level normalization to obtain the final synthesized FA and ADC

In [41]:
# NOTE: this cell reads FA_syn/ADC_syn from submitted_path and overwrites the same
# files with the windowed result, so running it twice applies the window twice.
# Re-run the SimpleITK cell above before repeating it.

# Window width and level (centre) for each map
ADC_window, ADC_level = 0.00223, 0.00111
FA_window, FA_level = 0.781, 0.390

# Display range [level - window/2, level + window/2]
min_value_ADC, max_value_ADC = ADC_level - ADC_window / 2, ADC_level + ADC_window / 2
min_value_FA, max_value_FA = FA_level - FA_window / 2, FA_level + FA_window / 2

img_FA = nib.load(f"{submitted_path}/FA_syn.nii.gz")
img_ADC = nib.load(f"{submitted_path}/ADC_syn.nii.gz")
data_FA = img_FA.get_fdata()
data_ADC = img_ADC.get_fdata()

# Linearly map [min, max] onto [0, 1]; values below min become 0 and values above max become 1
wl_data_ADC = ((data_ADC - min_value_ADC) / (max_value_ADC - min_value_ADC)).clip(0, 1)
wl_data_FA = ((data_FA - min_value_FA) / (max_value_FA - min_value_FA)).clip(0, 1)

# Wrap the windowed data with each image's original affine and save the final images
wl_img_ADC = nib.Nifti1Image(wl_data_ADC, img_ADC.affine)
wl_img_FA = nib.Nifti1Image(wl_data_FA, img_FA.affine)
nib.save(wl_img_ADC, f"{submitted_path}/ADC_syn.nii.gz")
nib.save(wl_img_FA, f"{submitted_path}/FA_syn.nii.gz")