In [6]:
from collections import defaultdict

from medpy.io import load
import os
import numpy as np
import torch

import random
from copy import deepcopy
from scipy.ndimage import map_coordinates, fourier_gaussian
from scipy.ndimage.filters import gaussian_filter, gaussian_gradient_magnitude
from scipy.ndimage.morphology import grey_dilation
from skimage.transform import resize
from scipy.ndimage.measurements import label as lb
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import SimpleITK as sitk
from collections import OrderedDict
import pickle
# from utilities.file_and_folder_operations import subfiles
def subfiles(folder,res, join=True, prefix=None, suffix=None, sort=True):
    # if join:
    #     l = os.path.join
    # else:
    #     l = lambda x, y: y
    # for i in os.listdir(folder):
    #     print(i)
    # res = [l(folder, i) for i in os.listdir(folder) if os.path.isfile(os.path.join(folder, i))
    #         and (prefix is None or i.startswith(prefix))
    #         and (suffix is None or i.endswith(suffix))]
    
    dirList=[]
    for i in os.listdir(folder):
        wholepath = os.path.join(folder, i)
        if os.path.isdir(wholepath):
            dirList.append(wholepath)
        if os.path.isfile(wholepath):
            res.append(wholepath)
            if not wholepath.endswith(suffix):
                res.remove(wholepath)
    if dirList:
        for subDir in dirList:
            subfiles(subDir,res,join=False,suffix=".nii.gz")
    if sort:
        res.sort()
    
    
def pad_nd_image(image, new_shape=None, mode="constant", kwargs=None, return_slicer=False, shape_must_be_divisible_by=None):
    """
    one padder to pad them all. Documentation? Well okay. A little bit
    :param image: nd image. can be anything
    :param new_shape: what shape do you want? new_shape does not have to have the same dimensionality as image. If
    len(new_shape) < len(image.shape) then the last axes of image will be padded. If new_shape < image.shape in any of
    the axes then we will not pad that axis, but also not crop! (interpret new_shape as new_min_shape)
    Example:
    image.shape = (10, 1, 512, 512); new_shape = (768, 768) -> result: (10, 1, 768, 768). Cool, huh?
    image.shape = (10, 1, 512, 512); new_shape = (364, 768) -> result: (10, 1, 512, 768).
    :param mode: see np.pad for documentation
    :param return_slicer: if True then this function will also return what coords you will need to use when cropping back
    to original shape
    :param shape_must_be_divisible_by: for network prediction. After applying new_shape, make sure the new shape is
    divisibly by that number (can also be a list with an entry for each axis). Whatever is missing to match that will
    be padded (so the result may be larger than new_shape if shape_must_be_divisible_by is not None)
    :param kwargs: see np.pad for documentation
    """
    if kwargs is None:
        kwargs = {'constant_values': 0}

    if new_shape is not None:
        old_shape = np.array(image.shape[-len(new_shape):])
    else:
        assert shape_must_be_divisible_by is not None
        assert isinstance(shape_must_be_divisible_by, (list, tuple, np.ndarray))
        new_shape = image.shape[-len(shape_must_be_divisible_by):]
        old_shape = new_shape

    num_axes_nopad = len(image.shape) - len(new_shape)

    new_shape = [max(new_shape[i], old_shape[i]) for i in range(len(new_shape))]

    if not isinstance(new_shape, np.ndarray):
        new_shape = np.array(new_shape)

    if shape_must_be_divisible_by is not None:
        if not isinstance(shape_must_be_divisible_by, (list, tuple, np.ndarray)):
            shape_must_be_divisible_by = [shape_must_be_divisible_by] * len(new_shape)
        else:
            assert len(shape_must_be_divisible_by) == len(new_shape)

        for i in range(len(new_shape)):
            if new_shape[i] % shape_must_be_divisible_by[i] == 0:
                new_shape[i] -= shape_must_be_divisible_by[i]

        new_shape = np.array([new_shape[i] + shape_must_be_divisible_by[i] - new_shape[i] % shape_must_be_divisible_by[i] for i in range(len(new_shape))])

    difference = new_shape - old_shape
    pad_below = difference // 2
    pad_above = difference // 2 + difference % 2
    pad_list = [[0, 0]]*num_axes_nopad + list([list(i) for i in zip(pad_below, pad_above)])

    if not ((all([i == 0 for i in pad_below])) and (all([i == 0 for i in pad_above]))):
        res = np.pad(image, pad_list, mode, **kwargs)
    else:
        res = image

    if not return_slicer:
        return res
    else:
        pad_list = np.array(pad_list)
        pad_list[:, 1] = np.array(res.shape) - pad_list[:, 1]
        slicer = list(slice(*i) for i in pad_list)
        return res, slicer

def create_nonzero_mask(data):
    from scipy.ndimage import binary_fill_holes
    assert len(data.shape) == 4 or len(data.shape) == 3, "data must have shape (C, X, Y, Z) or shape (C, X, Y)"
    nonzero_mask = np.zeros(data.shape[1:], dtype=bool)
    print(4,data.shape[1:])
    print(5,data.shape[0])
    for c in range(data.shape[0]):
        this_mask = data[c] != 0
        nonzero_mask = nonzero_mask | this_mask
    nonzero_mask = binary_fill_holes(nonzero_mask)
    return nonzero_mask


def get_bbox_from_mask(mask, outside_value=0):
    mask_voxel_coords = np.where(mask != outside_value)
    minzidx = int(np.min(mask_voxel_coords[0]))
    maxzidx = int(np.max(mask_voxel_coords[0])) + 1
    minxidx = int(np.min(mask_voxel_coords[1]))
    maxxidx = int(np.max(mask_voxel_coords[1])) + 1
    minyidx = int(np.min(mask_voxel_coords[2]))
    maxyidx = int(np.max(mask_voxel_coords[2])) + 1
    return [[minzidx, maxzidx], [minxidx, maxxidx], [minyidx, maxyidx]]


def crop_to_bbox(image, bbox):
    assert len(image.shape) == 3, "only supports 3d images"
    resizer = (slice(bbox[0][0], bbox[0][1]), slice(bbox[1][0], bbox[1][1]), slice(bbox[2][0], bbox[2][1]))
    return image[resizer]

def crop_to_nonzero(data, seg=None, nonzero_label=-1):
    """
    :param data:
    :param seg:
    :param nonzero_label: this will be written into the segmentation map
    :return:
    """
    nonzero_mask = create_nonzero_mask(data)
    bbox = get_bbox_from_mask(nonzero_mask, 0)

    cropped_data = []
    for c in range(data.shape[0]):
        cropped = crop_to_bbox(data[c], bbox)
        cropped_data.append(cropped[None])
    data = np.vstack(cropped_data)

    if seg is not None:
        cropped_seg = []
        for c in range(seg.shape[0]):
            cropped = crop_to_bbox(seg[c], bbox)
            cropped_seg.append(cropped[None])
        seg = np.vstack(cropped_seg)

    nonzero_mask = crop_to_bbox(nonzero_mask, bbox)[None]
    if seg is not None:
        seg[(seg == 0) & (nonzero_mask == 0)] = nonzero_label
    
    return data, seg, bbox

def preprocess_data(root_dir='/home/jovyan/main/BraTS2020_TrainingData/', y_shape=64, z_shape=64):
    # image_dir = os.path.join(root_dir, 'imagesTr')
    image_dir = root_dir
    label_dir = os.path.join(root_dir, 'labelsTr')
    output_dir = os.path.join(root_dir, 'preprocessed')
    classes = 4

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
        print('Created' + output_dir + '...')

    class_stats = defaultdict(int)
    total = 0
    nii_files=[]
    subfiles(image_dir,nii_files, suffix=".nii.gz", join=False)

    # for i in range(0, len(nii_files)):
    #     if nii_files[i].startswith("._"):
    #         nii_files[i] = nii_files[i][2:]
    # print("--------")
    seg_files=[]
    data_files=[]
    data_itk = []
    seg_itk=[]
    count=0
    for f in nii_files:
        count=count+1
        image,metadata=load(f)
        label=metadata.get_sitkimage()
        if "seg" in f:
            seg_files.append(f)
            
            seg_itk.append(label)
        else:
            data_itk.append(label)
            data_files.append(f)
        if count==5:
            print(seg_files,data_files)
            data_npy = np.vstack([sitk.GetArrayFromImage(d)[None] for d in data_itk])
            print(1,np.array(data_itk[0].GetSize()))
            print(2,np.array(data_itk[0].GetSize())[[2, 1, 0]])
            
            seg_npy = np.vstack([sitk.GetArrayFromImage(s)[None] for s in seg_itk])
            data_npy= data_npy.astype(np.float32)
            seg_npy= data_npy.astype(np.float32)
            # npImage = sitk.GetArrayFromImage(label)
            # print(2,npImage)
            # z = int(label.GetSize()[2]/2)
            # plt.figure(figsize=(5,5))
            # plt.imshow(image[:,:,z], 'gray')
            # plt.show()
            properties = OrderedDict()
            properties["original_size_of_raw_data"] = np.array(data_itk[0].GetSize())[[2, 1, 0]]
            properties["original_spacing"] = np.array(data_itk[0].GetSpacing())[[2, 1, 0]]
            properties["list_of_data_files"] = data_files
            properties["seg_file"] = seg_files
            properties["itk_origin"] = data_itk[0].GetOrigin()
            properties["itk_spacing"] = data_itk[0].GetSpacing()
            properties["itk_direction"] = data_itk[0].GetDirection()
            data_npy,seg_npy,bbox=crop_to_nonzero(data_npy, seg_npy, nonzero_label=-1)    
            properties["crop_bbox"] = bbox
            properties['classes'] = np.unique(seg_npy)
            seg_npy[seg_npy < -1] = 0
            properties["size_after_cropping"] = data_npy[0].shape 
            print(6,properties['classes'])
            case_id=seg_files[0].split("/")[-1].split(".nii.gz")[0][0:-4]
            all_data = np.vstack((data_npy, seg_npy))
            np.savez_compressed(os.path.join('/home/jovyan/main/BraTS2020_TrainingData/'+case_id, "%s.npz" % case_id), data=all_data)
            
            with open(os.path.join('/home/jovyan/main/BraTS2020_TrainingData/'+case_id, "%s.pkl" % case_id), 'wb') as file:
                pickle.dump(properties, file)
            
            count=0
            data_files=[]
            seg_files=[]
            data_itk = []
            seg_itk=[]
        # print(os.path.join(image_dir, f))
        # image, _ = load(os.path.join(image_dir, f))
        # label, _ = load(os.path.join(label_dir, f.replace('_0000', '')))
        # print(image.shape)
        # print(1,image)
        # print(label.getSize())
    
    
preprocess_data()

['/home/jovyan/main/BraTS2020_TrainingData/BraTS20_Training_001/BraTS20_Training_001_seg.nii.gz'] ['/home/jovyan/main/BraTS2020_TrainingData/BraTS20_Training_001/BraTS20_Training_001_flair.nii.gz', '/home/jovyan/main/BraTS2020_TrainingData/BraTS20_Training_001/BraTS20_Training_001_t1.nii.gz', '/home/jovyan/main/BraTS2020_TrainingData/BraTS20_Training_001/BraTS20_Training_001_t1ce.nii.gz', '/home/jovyan/main/BraTS2020_TrainingData/BraTS20_Training_001/BraTS20_Training_001_t2.nii.gz']
1 [240 240 155]
2 [155 240 240]
4 (155, 240, 240)
5 4
6 [-1.000e+00  5.000e+00  6.000e+00 ...  1.823e+03  1.834e+03  1.845e+03]
['/home/jovyan/main/BraTS2020_TrainingData/BraTS20_Training_002/BraTS20_Training_002_seg.nii.gz'] ['/home/jovyan/main/BraTS2020_TrainingData/BraTS20_Training_002/BraTS20_Training_002_flair.nii.gz', '/home/jovyan/main/BraTS2020_TrainingData/BraTS20_Training_002/BraTS20_Training_002_t1.nii.gz', '/home/jovyan/main/BraTS2020_TrainingData/BraTS20_Training_002/BraTS20_Training_002_t1ce.n