In [None]:
import os
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np

from modules.Dataset import FeTADataSet
from modules.Utils import get_file_names

In [None]:
def remove_zeros_axial(image, mask):
    """Move the axial axis first and drop axial slices whose mask is all zero.

    The axes of both arrays are permuted with ``(2, 0, 1)`` so that the input's
    last axis (treated as axial) becomes axis 0. Every axial slice in which
    ``mask`` contains only zeros is then removed from both ``image`` and ``mask``.

    Args:
        image: 3-D voxel array.
        mask: 3-D array, presumably the same shape as ``image`` (TODO confirm).

    Returns:
        ``(image, mask)`` wrapped as ``nib.Nifti1Image`` objects with an identity
        affine.
    """
    # Permute images axis to get axial view in the first dimension.
    image = np.transpose(image, (2, 0, 1))
    mask = np.transpose(mask, (2, 0, 1))

    # Positions of axial slices (axis 0) in which the mask argument is all zero.
    empty_slices = np.all(np.asarray(mask) == 0, axis=(1, 2))
    indexes = np.flatnonzero(empty_slices)
    image = np.delete(image, indexes, axis=0)
    mask = np.delete(mask, indexes, axis=0)

    # Convert images in numpy to nifti.
    # NOTE(review): the identity affine discards the source scan's voxel-to-world
    # mapping; pass the original affine through if spatial metadata matters.
    image = nib.Nifti1Image(image, np.eye(4))
    mask = nib.Nifti1Image(mask, np.eye(4))

    return image, mask

In [None]:
def add_padding(image, mask, ax, target_size=128):
    """Zero-pad ``image`` and ``mask`` along axis ``ax`` up to ``target_size`` slices.

    The padding is split as evenly as possible between the start and the end of
    the axis. When the amount is odd, the extra slice goes at the start.
    ``mask`` receives exactly the same padding as ``image``, so it is expected to
    have the same shape.

    Args:
        image: N-D array to pad.
        mask: array padded with the same widths as ``image``.
        ax: axis to pad.
        target_size: length of axis ``ax`` after padding (default 128).

    Returns:
        Tuple ``(image, mask)`` of the padded arrays.

    Raises:
        ValueError: if ``image.shape[ax]`` is already larger than ``target_size``.
    """
    size = image.shape[ax]
    remainder = target_size - size
    if remainder < 0:
        raise ValueError(
            f"axis {ax} has {size} slices, more than target_size={target_size}"
        )

    # ceil/floor split of the remainder: the odd slice (if any) goes first.
    pad_before = (remainder + 1) // 2
    pad_after = remainder // 2

    # Pad only the requested axis; all other axes are left untouched.
    pad_width = [(0, 0)] * np.ndim(image)
    pad_width[ax] = (pad_before, pad_after)

    image = np.pad(image, pad_width, constant_values=0)
    mask = np.pad(mask, pad_width, constant_values=0)

    return image, mask

In [None]:
def correct_indexes(indexes, target_size=128):
    """Keep some zero slices around a small brain region instead of deleting them all.

    ``indexes`` are the sorted positions of all-zero slices along one axis (as
    returned by ``np.where(...)[0]``). Deleting all of them would shrink the axis
    to the non-zero region only. When the non-zero run between two consecutive
    zero indexes is shorter than ``target_size``, the zero slices nearest to it
    are dropped from ``indexes``, so they stay in the volume. The remaining
    extent is then ``target_size`` slices, with the margin split as evenly as
    possible and the extra slice going before the region. This is only possible
    when enough zero slices exist on each side. Only the first such gap is
    handled.

    Args:
        indexes: 1-D integer array of zero-slice positions, sorted ascending.
        target_size: desired extent along the axis after deletion (default 128).

    Returns:
        1-D array of the indexes that should actually be deleted.
    """
    indexes = np.asarray(indexes)

    # Distance from each zero index to the NEXT one. A gap > 1 means non-zero
    # slices lie between them.
    gaps = np.diff(indexes)
    for pos, gap in enumerate(gaps):
        if 1 < gap < target_size:
            # e.g. indexes = [0, 1, 2, 3, 254, 255] -> x1 = 3 (at pos 3), x2 = 254.
            # The non-zero region is x1+1 .. x2-1, i.e. gap - 1 slices. Keeping x1
            # plus pad_start zeros before it and pad_end zeros from x2 onward gives
            # (pad_start + 1) + (gap - 1) + pad_end == target_size.
            total_margin = target_size - gap
            pad_start = int(np.ceil(total_margin / 2))
            pad_end = int(np.floor(total_margin / 2))

            head = indexes[:pos]       # zero indexes strictly before x1
            tail = indexes[pos + 1:]   # zero indexes from x2 onward
            start_indexes = head[:max(len(head) - pad_start, 0)]
            end_indexes = tail[pad_end:]

            return np.concatenate((start_indexes, end_indexes))

    return indexes

In [None]:
def removes_zeros_all(image, mask):
    """Strip all-zero slices from ``image`` and ``mask`` along every axis.

    Axes are handled in the order 0, 1, 2. For each axis, the slices of
    ``image`` that contain only zeros are located and the list is filtered
    through ``correct_indexes``. The surviving positions are deleted from both
    arrays. If the axis is then shorter than 128 slices, both arrays are
    zero-padded with ``add_padding``. Only ``image`` decides which slices are
    empty; ``mask`` is cut at the same positions.

    Returns:
        ``(image, mask)`` wrapped as ``nib.Nifti1Image`` objects with an identity
        affine.
    """
    for axis in (0, 1, 2):
        # A slice along `axis` is empty when every voxel over the other two
        # axes is zero.
        other_axes = tuple(a for a in (0, 1, 2) if a != axis)
        empty = np.all(np.asarray(image) == 0, axis=other_axes)
        drop = correct_indexes(np.flatnonzero(empty))

        image = np.delete(image, drop, axis=axis)
        mask = np.delete(mask, drop, axis=axis)

        if image.shape[axis] < 128:
            image, mask = add_padding(image, mask, axis)

    return (nib.Nifti1Image(image, np.eye(4)),
            nib.Nifti1Image(mask, np.eye(4)))

In [None]:
def save_data(image, mask, path_image, path_mask, out_root="data7"):
    """Save an image/mask pair under ``out_root``, mirroring the input folder layout.

    The first component of each input path (the source dataset folder) is
    replaced by ``out_root``; the rest of the path is kept unchanged.

    Args:
        image: object accepted by ``nib.save`` (e.g. a ``Nifti1Image``).
        mask: object accepted by ``nib.save``.
        path_image: original path of the image file.
        path_mask: original path of the mask file.
        out_root: output root folder (default ``"data7"``).
    """
    # Only change the main folder name. Rest of the paths will be same.
    new_path_image = os.path.join(out_root, *path_image.split(os.sep)[1:])
    new_path_mask = os.path.join(out_root, *path_mask.split(os.sep)[1:])

    # Create the destination folders. Image and mask may live in different
    # sub-folders. exist_ok avoids the exists()/makedirs() race.
    for new_path in (new_path_image, new_path_mask):
        folder = os.path.dirname(new_path)
        if folder:
            os.makedirs(folder, exist_ok=True)

    nib.save(image, new_path_image)
    nib.save(mask, new_path_mask)

In [None]:
# Crop/pad every subject found under feta_2.1/ and write it out with save_data.
files = get_file_names("feta_2.1/")

for sub in files:
    path_image, path_mask = files[sub]
    print(sub)
    # Load the image first, then the mask, as float voxel arrays.
    image, mask = (nib.load(p).get_fdata() for p in (path_image, path_mask))
    new_image, new_mask = removes_zeros_all(image, mask)
    save_data(new_image, new_mask, path_image, path_mask)

In [None]:
# Report the smallest size along each axis over all volumes in "data2".
# NOTE(review): save_data writes to "data7" while this reads "data2" — confirm
# which output folder is intended.
files = get_file_names("data2")
shapes = {'x': [], 'y': [], 'z': []}

for sub, (path_image, path_mask) in files.items():
    # The shape comes from the NIfTI header, so the voxel data does not need to
    # be loaded with get_fdata() just to read it.
    x1, y1, z1 = nib.load(path_image).shape
    shapes['x'].append(x1)
    shapes['y'].append(y1)
    shapes['z'].append(z1)

print(f"x: {min(shapes['x'])}, y: {min(shapes['y'])}, z: {min(shapes['z'])}")

In [None]:
# Visual check: one slice (along axis 2) of sub-013's image and mask from feta_2.1/.
files = get_file_names("feta_2.1/")
images = nib.load(files["sub-013"][0]).get_fdata()
masks = nib.load(files["sub-013"][1]).get_fdata()
fig, ax = plt.subplots(1, 2, figsize=(8, 8))
idx = 103
ax[0].imshow(images[:, :, idx])
ax[0].set_title(f"feta_2.1 sub-013 image, slice {idx}")
ax[1].imshow(masks[:, :, idx])
ax[1].set_title(f"feta_2.1 sub-013 mask, slice {idx}")
plt.show()

In [None]:
# Visual check: one slice (along axis 2) of sub-013's image and mask from data2.
files = get_file_names("data2")
images = nib.load(files["sub-013"][0]).get_fdata()
masks = nib.load(files["sub-013"][1]).get_fdata()
fig, ax = plt.subplots(1, 2, figsize=(8, 8))
idx = 100
ax[0].imshow(images[:, :, idx])
ax[0].set_title(f"data2 sub-013 image, slice {idx}")
ax[1].imshow(masks[:, :, idx])
ax[1].set_title(f"data2 sub-013 mask, slice {idx}")
plt.show()