In [83]:
import nibabel as nib
import numpy as np
from tqdm.auto import tqdm

In [37]:
import nibabel as nib
import numpy as np
# Quick look at one test volume: zero-pad its depth (axis 2) symmetrically up to 512 slices.
file_path = "/scratch/student/sinaziaee/datasets/3d_dataset/testing/images/kidney_0038_0000.nii.gz"

# Load straight into an array; the nibabel image object itself is not kept in a variable.
img = np.array(nib.load(file_path).get_fdata())
print(img.shape)

# Split the missing depth between both ends: floor before, ceil after (any odd slice goes last).
# NOTE: np.pad raises on negative widths, so this cell assumes the volume has <= 512 slices.
padding_needed = 512 - img.shape[2]
pad_before, pad_after = padding_needed // 2, -(-padding_needed // 2)
padded_img = np.pad(img, [(0, 0), (0, 0), (pad_before, pad_after)], mode='constant', constant_values=0)

padded_img.shape

(512, 512, 205)


(512, 512, 512)

In [55]:
import os

def make_path(path):
    """Ensure directory ``path`` (including missing parents) exists and return ``path`` unchanged.

    Raises:
        FileExistsError: if ``path`` exists but is not a directory.
    """
    # exist_ok=True makes this idempotent and avoids the exists()/makedirs() race, where another
    # process creating the directory in between would make makedirs raise.
    os.makedirs(path, exist_ok=True)
    return path

# Roots for the padding run. Edit these two constants to relocate input or output.
# DEST_ROOT is relative to the notebook's working directory; make_path creates the
# split subdirectories if they are missing.
DEST_ROOT = '3d_original_padded_images'
SRC_ROOT = '/scratch/student/sinaziaee/datasets/3d_dataset'

dest_train_images_path = make_path(f'{DEST_ROOT}/training/images')
dest_train_masks_path = make_path(f'{DEST_ROOT}/training/labels')
dest_valid_images_path = make_path(f'{DEST_ROOT}/validation/images')
dest_valid_masks_path = make_path(f'{DEST_ROOT}/validation/labels')
dest_test_images_path = make_path(f'{DEST_ROOT}/testing/images')
dest_test_masks_path = make_path(f'{DEST_ROOT}/testing/labels')

src_train_images_path = f'{SRC_ROOT}/training/images'
src_train_masks_path = f'{SRC_ROOT}/training/labels'
src_valid_images_path = f'{SRC_ROOT}/validation/images'
src_valid_masks_path = f'{SRC_ROOT}/validation/labels'
src_test_images_path = f'{SRC_ROOT}/testing/images'
src_test_masks_path = f'{SRC_ROOT}/testing/labels'

In [80]:
def _foreground_slice_bounds(mask, threshold=0.5):
    """Return ``(first, last)`` axis-2 indices of slices with any voxel > ``threshold``, or ``None``."""
    has_foreground = (mask > threshold).any(axis=(0, 1))
    foreground_indices = np.flatnonzero(has_foreground)
    if foreground_indices.size == 0:
        return None
    return int(foreground_indices[0]), int(foreground_indices[-1])


def _fit_depth(image, mask, desired_depth):
    """Pad or crop ``image``/``mask`` along axis 2 so both end up with exactly ``desired_depth`` slices."""
    depth = mask.shape[2]
    bounds = _foreground_slice_bounds(mask)
    if bounds is None:
        # No labelled slice: both background slabs are empty, so we fall through to the
        # zero-padding / centred-cropping paths below.
        start_inx, end_inx = 0, depth
    else:
        first, last = bounds
        # Same offsets as before (the leading slab stops one slice short of the first labelled
        # slice), but clamped: a label on slice 0 used to give -1, which NumPy slicing reads as
        # "all but the last slice" and so copied the labelled region into the padding.
        start_inx = max(first - 1, 0)
        end_inx = last + 1
    n_before = start_inx          # slices available to replicate before the labelled region
    n_after = depth - end_inx     # slices available to replicate after it
    n_background = n_before + n_after

    if depth < desired_depth:
        if n_background == 0:
            # Nothing to replicate (labels touch both ends, or there are no labels at all):
            # centre the volume with zeros instead of looping forever.
            pad_before = (desired_depth - depth) // 2
            pad_width = ((0, 0), (0, 0), (pad_before, desired_depth - depth - pad_before))
            return (np.pad(image, pad_width, mode='constant', constant_values=0),
                    np.pad(mask, pad_width, mode='constant', constant_values=0))
        # Repeat the ORIGINAL slabs. Re-slicing the growing array with a fixed end index (old
        # behaviour) pulled labelled slices into the tail and duplicated them.
        repeats = (desired_depth - depth + n_background - 1) // n_background  # ceil division
        before_image, after_image = image[:, :, :start_inx], image[:, :, end_inx:]
        before_mask, after_mask = mask[:, :, :start_inx], mask[:, :, end_inx:]
        image = np.concatenate([before_image] * repeats + [image] + [after_image] * repeats, axis=2)
        mask = np.concatenate([before_mask] * repeats + [mask] + [after_mask] * repeats, axis=2)

    # Trim the surplus from both ends in proportion to each side's slab size (centred if none).
    excess = image.shape[2] - desired_depth
    before_share = n_before / n_background if n_background else 0.5
    crop_start = int(excess * before_share)
    crop_end = crop_start + desired_depth
    return image[:, :, crop_start:crop_end], mask[:, :, crop_start:crop_end]


def pad_image(image, mask, dest_image_path, dest_mask_path, new_image_name, new_mask_name,
              desired_depth=512, affine=None):
    """Resize an image/mask pair to ``desired_depth`` slices along axis 2 and save both as NIfTI.

    Missing depth is filled by repeating the slabs of slices that lie before and after the
    labelled region (slices with any mask voxel > 0.5), so labelled slices are never duplicated.
    Surplus depth is then cropped from both ends in proportion to those slab sizes. If there
    is no slab to repeat, the volume is centred with zero padding instead.

    Args:
        image: 3-D array; axis 2 is the axis that gets padded/cropped.
        mask: label array with the same shape as ``image``.
        dest_image_path, dest_mask_path: output directories.
        new_image_name, new_mask_name: output file names inside those directories.
        desired_depth: target size of axis 2 (default 512, as before).
        affine: 4x4 affine written to both outputs; defaults to ``np.eye(4)`` as before.

    Returns:
        ``(new_image, new_mask)``: the arrays that were saved.

    Raises:
        ValueError: if ``image`` and ``mask`` shapes differ.
    """
    image = np.asarray(image)
    mask = np.asarray(mask)
    if image.shape != mask.shape:
        raise ValueError(f'image shape {image.shape} does not match mask shape {mask.shape}')

    new_image, new_mask = _fit_depth(image, mask, desired_depth)

    if affine is None:
        affine = np.eye(4)
    nib.save(nib.Nifti1Image(new_image, affine), f'{dest_image_path}/{new_image_name}')
    nib.save(nib.Nifti1Image(new_mask, affine), f'{dest_mask_path}/{new_mask_name}')
    return new_image, new_mask

In [86]:
def padding_3d_image(src_image_path, src_mask_path, dest_image_path, dest_mask_path):
    """Run ``pad_image`` on every image/label pair of one dataset split.

    Images and labels are paired by position after sorting each directory listing, so both
    directories must hold the same number of consistently named files. Hidden entries
    (names starting with '.') are ignored.

    Raises:
        ValueError: if the two directories contain different numbers of files.
    """
    image_names = sorted(name for name in os.listdir(src_image_path) if not name.startswith('.'))
    mask_names = sorted(name for name in os.listdir(src_mask_path) if not name.startswith('.'))
    # zip() would silently drop the unmatched tail; a count mismatch means the pairing is wrong.
    if len(image_names) != len(mask_names):
        raise ValueError(
            f'{len(image_names)} images in {src_image_path!r} but '
            f'{len(mask_names)} labels in {src_mask_path!r}'
        )
    for image_file_name, mask_file_name in tqdm(zip(image_names, mask_names),
                                                total=len(image_names), desc=dest_image_path):
        image = nib.load(os.path.join(src_image_path, image_file_name)).get_fdata()
        mask = nib.load(os.path.join(src_mask_path, mask_file_name)).get_fdata()
        pad_image(image, mask, dest_image_path, dest_mask_path, image_file_name, mask_file_name)
# Pad every split in turn: training, then validation, then testing.
for split_dirs in (
    (src_train_images_path, src_train_masks_path, dest_train_images_path, dest_train_masks_path),
    (src_valid_images_path, src_valid_masks_path, dest_valid_images_path, dest_valid_masks_path),
    (src_test_images_path, src_test_masks_path, dest_test_images_path, dest_test_masks_path),
):
    padding_3d_image(*split_dirs)