# In [None]:
import os
import numpy as np
import nibabel as nib
from PIL import Image
import torchio as tio

def _normalize_to_uint8(slice_2d):
    """Min-max scale a 2D array to 0-255 and return it as a uint8 array.

    A slice with no intensity range (every value identical) has no contrast
    to stretch, so it is rendered as an all-zero image instead of dividing
    by zero.
    """
    min_val = np.min(slice_2d)
    value_range = np.max(slice_2d) - min_val
    if value_range > 0:
        normalized = (slice_2d - min_val) / value_range
    else:
        # Flat slice (the comparison is also false for a NaN range): emit
        # zeros explicitly rather than relying on an `assert`, which is
        # stripped under `python -O` and would let 255 * value wrap in uint8.
        normalized = np.zeros_like(slice_2d)
    return np.uint8(255 * normalized)


def _load_volume(nifti_path):
    """Load a NIfTI file and crop/pad it to a (240, 240, 155) numpy volume."""
    img_data = nib.load(nifti_path).get_fdata()
    # torchio expects a leading channel axis, so add one before transforming.
    img_data_4d = np.expand_dims(img_data, axis=0)
    subject = tio.Subject(image=tio.ScalarImage(tensor=img_data_4d))
    transformed_subject = tio.CropOrPad((240, 240, 155))(subject)
    # Drop the channel axis again to get back to a 3D array.
    return transformed_subject['image'].data.numpy().squeeze(axis=0)


def _save_plane_slices(volume, axis, plane_dir):
    """Save the 2D slices of `volume` taken along `axis` as PNGs in `plane_dir`.

    Files are named ``png-slice-NNNN.png`` where NNNN is the loop counter
    ``i`` (not the slice position): even ``i`` maps to slice ``i // 2`` and
    odd ``i`` maps to slice ``i // 2 + slices // 2``.
    """
    slices = volume.shape[axis]
    half = slices // 2

    # NOTE(review): the even counters already cover every slice once; the odd
    # counters re-emit the second half of the volume under different names.
    # Kept as-is because downstream code may rely on these filenames --
    # confirm whether the duplication is intended.
    for i in range(2 * slices):
        idx = i // 2 if i % 2 == 0 else i // 2 + half
        if idx >= slices:
            continue

        slice_2d = np.take(volume, idx, axis=axis)
        img_pil = Image.fromarray(_normalize_to_uint8(slice_2d))
        img_pil.save(os.path.join(plane_dir, f"png-slice-{i:04d}.png"))


def slice3DImages(input_folder):
    """Slice every case volume under `input_folder` into per-plane PNG images.

    Each sub-directory ``<case>`` of `input_folder` is searched for files
    named ``<case>-<modality>.nii.gz`` for the modalities 't1c', 't1n',
    't2f', 't2w' and 'seg'. Every file found is cropped/padded to
    (240, 240, 155) and its slices along axes 0, 1 and 2 are written to
    ``<input_folder>/Sliced/<plane>/<modality>/<case>/`` with the planes
    named 'sagittal', 'coronal' and 'axial' respectively. Missing modality
    files are skipped silently.

    Args:
        input_folder: Path of the directory that contains one sub-directory
            per case.

    Returns:
        None. The result is the PNG files written to disk.
    """
    output_folder = os.path.join(input_folder, 'Sliced')
    modalities = ['t1c', 't1n', 't2f', 't2w', 'seg']
    plane_axes = {'sagittal': 0, 'coronal': 1, 'axial': 2}

    os.makedirs(output_folder, exist_ok=True)

    for subfolder in os.listdir(input_folder):
        subfolder_path = os.path.join(input_folder, subfolder)
        if not os.path.isdir(subfolder_path):
            continue
        if subfolder_path == output_folder:
            # The output directory lives inside input_folder; it is not a case.
            continue

        for modality in modalities:
            modality_path = os.path.join(subfolder_path, f'{subfolder}-{modality}.nii.gz')
            if not os.path.isfile(modality_path):
                continue

            # NOTE(review): 'seg' goes through the same per-slice min-max
            # scaling as the scans; if it holds label maps, the gray level
            # of a given label can differ between slices -- confirm intent.
            img_data = _load_volume(modality_path)

            for plane, axis in plane_axes.items():
                plane_dir = os.path.join(output_folder, plane, modality, subfolder)
                os.makedirs(plane_dir, exist_ok=True)
                _save_plane_slices(img_data, axis, plane_dir)

if __name__ == "__main__":
    # Run the batch slicing only when this file is executed directly (script
    # or notebook cell), not as a side effect of importing it.
    slice3DImages('ASNR-MICCAI-BraTS2023-PED-Challenge-TrainingData/')