In [9]:
from collections import defaultdict

from medpy.io import load
import os
import numpy as np
import torch

import random
from copy import deepcopy
from scipy.ndimage import map_coordinates, fourier_gaussian
from scipy.ndimage.filters import gaussian_filter, gaussian_gradient_magnitude
from scipy.ndimage.morphology import grey_dilation
from skimage.transform import resize
from scipy.ndimage.measurements import label as lb
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import SimpleITK as sitk
from collections import OrderedDict
import pickle
from subprocess import check_call
# from utilities.file_and_folder_operations import subfiles
def subfiles(folder, res, join=True, prefix=None, suffix=None, sort=True):
    """
    Recursively collect the files below ``folder`` into the list ``res``.

    The matching paths are appended to ``res`` in place; nothing is returned.
    Every collected entry is the full path built with ``os.path.join`` from
    ``folder`` downwards.

    :param folder: directory to walk; all of its sub-directories are visited too
    :param res: list that receives the matching file paths (mutated in place)
    :param join: accepted for backward compatibility and ignored; full paths
        are always stored, which is what the callers in this file rely on
    :param prefix: if given, only files whose name starts with it are kept
    :param suffix: if given, only files whose path ends with it are kept
    :param sort: if True, ``res`` is sorted once after the whole tree was walked
    :return: None
    """
    subdirs = []
    for entry in os.listdir(folder):
        wholepath = os.path.join(folder, entry)
        if os.path.isdir(wholepath):
            subdirs.append(wholepath)
        elif os.path.isfile(wholepath):
            # ``None`` means "no filter"; calling endswith(None) would raise.
            if prefix is not None and not entry.startswith(prefix):
                continue
            if suffix is not None and not wholepath.endswith(suffix):
                continue
            res.append(wholepath)

    # Forward the caller's filters unchanged and postpone sorting so that the
    # list is ordered a single time, by the outermost call.
    for subdir in subdirs:
        subfiles(subdir, res, join=join, prefix=prefix, suffix=suffix, sort=False)

    if sort:
        res.sort()
    

def create_nonzero_mask(data):
    """
    Build a boolean mask of the positions that are nonzero in any channel.

    The first axis of ``data`` is treated as the channel axis. A position is
    set in the mask when at least one channel differs from 0 there; enclosed
    holes in the combined mask are then filled with
    ``scipy.ndimage.binary_fill_holes``.

    :param data: array with 3 or 4 dimensions, channel axis first
    :return: boolean array of shape ``data.shape[1:]``
    """
    from scipy.ndimage import binary_fill_holes
    assert len(data.shape) in (3, 4), "data must have shape (C, X, Y, Z) or shape (C, X, Y)"
    # Collapse the channel axis: True wherever any channel is nonzero.
    any_channel_nonzero = np.any(data != 0, axis=0)
    return binary_fill_holes(any_channel_nonzero)


def get_bbox_from_mask(mask, outside_value=0):
    """
    Return the bounding box of all entries of ``mask`` that differ from ``outside_value``.

    :param mask: array with at least three axes; only the first three are used
    :param outside_value: value that marks positions outside the region of interest
    :return: list of three ``[start, stop]`` pairs, one per axis, where ``stop``
        is exclusive (largest index + 1) so each pair can be used as a slice
    """
    inside_coords = np.where(mask != outside_value)
    bbox = []
    for axis in range(3):
        axis_coords = inside_coords[axis]
        lowest = int(np.min(axis_coords))
        highest = int(np.max(axis_coords)) + 1
        bbox.append([lowest, highest])
    return bbox


def crop_to_bbox(image, bbox):
    """
    Cut the region described by ``bbox`` out of a 3D ``image``.

    :param image: array with exactly three dimensions
    :param bbox: three ``[start, stop]`` pairs, one per axis, used as slice bounds
    :return: the result of indexing ``image`` with the three slices
    """
    assert len(image.shape) == 3, "only supports 3d images"
    window = tuple(slice(bbox[axis][0], bbox[axis][1]) for axis in range(3))
    return image[window]

def crop_to_nonzero(data, seg=None, nonzero_label=-1):
    """
    Crop ``data`` (and ``seg``, if given) to the bounding box of the nonzero region.

    The nonzero mask is computed from ``data`` with ``create_nonzero_mask`` and
    its bounding box with ``get_bbox_from_mask``. Every channel is cropped with
    ``crop_to_bbox``, which only accepts 3D channels.

    :param data: array whose first axis is the channel axis
    :param seg: optional array with the channel axis first; cropped with the same box
    :param nonzero_label: this will be written into the segmentation map wherever
        the cropped segmentation is 0 and the cropped nonzero mask is 0
    :return: tuple ``(cropped data, cropped seg or None, bbox)``
    """
    foreground = create_nonzero_mask(data)
    bbox = get_bbox_from_mask(foreground, 0)

    # Crop channel by channel and restack along the leading channel axis.
    data = np.vstack([crop_to_bbox(channel, bbox)[None] for channel in data])

    if seg is None:
        return data, seg, bbox

    seg = np.vstack([crop_to_bbox(channel, bbox)[None] for channel in seg])

    # Give the mask a leading axis so it broadcasts against every seg channel.
    foreground = crop_to_bbox(foreground, bbox)[None]
    outside = (seg == 0) & (foreground == 0)
    seg[outside] = nonzero_label

    return data, seg, bbox

def preprocess_data(root_dir='/home/jovyan/main/BraTS2020_TrainingData/', y_shape=64, z_shape=64):
    """
    Crop every case below ``root_dir`` to its nonzero region and store the result.

    All ``.nii.gz`` files below ``root_dir`` are collected (sorted) with
    ``subfiles`` and consumed in consecutive groups of five files. Within a
    group, files whose name contains ``"seg"`` are treated as segmentations and
    all other files as image channels. The channels are stacked, cropped
    together with the segmentation by ``crop_to_nonzero`` and written to
    ``<root_dir>/<case_id>/<case_id>.npz`` (key ``data``; segmentation stacked
    after the image channels). A properties dict is pickled next to it as
    ``<case_id>.pkl``. ``case_id`` is the first segmentation file name without
    ``.nii.gz`` and without its last four characters.

    NOTE(review): grouping relies on every case contributing exactly five files
    that sort next to each other — confirm against the dataset layout.

    :param root_dir: directory that is searched for input files and that
        contains one sub-directory per case into which the results are written
    :param y_shape: unused; kept so existing callers keep working
    :param z_shape: unused; kept so existing callers keep working
    :return: None
    :raises ValueError: if a group of five files contains no segmentation file
        or no image file
    """
    files_per_case = 5

    image_dir = root_dir
    # This directory is only created here; nothing in this function writes to it.
    output_dir = os.path.join(root_dir, 'preprocessed')
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
        print('Created' + output_dir + '...')

    nii_files = []
    subfiles(image_dir, nii_files, suffix=".nii.gz", join=False)

    seg_files, data_files, data_itk, seg_itk = [], [], [], []
    for count, f in enumerate(nii_files, start=1):
        # Only the header is needed; it hands out the underlying SimpleITK image.
        _, metadata = load(f)
        itk_image = metadata.get_sitkimage()
        # Test the file name only, so a "seg" somewhere in the directory part
        # of the path cannot turn every file into a segmentation.
        if "seg" in os.path.basename(f):
            seg_files.append(f)
            seg_itk.append(itk_image)
        else:
            data_files.append(f)
            data_itk.append(itk_image)

        if count % files_per_case != 0:
            continue

        if not seg_files or not data_files:
            raise ValueError(
                "group of %d files ending at %r must contain at least one "
                "segmentation and one image file" % (files_per_case, f))

        print(seg_files, data_files)
        data_npy = np.vstack([sitk.GetArrayFromImage(d)[None] for d in data_itk])
        print(1, np.array(data_itk[0].GetSize()))
        print(2, np.array(data_itk[0].GetSize())[[2, 1, 0]])

        seg_npy = np.vstack([sitk.GetArrayFromImage(s)[None] for s in seg_itk])
        data_npy = data_npy.astype(np.float32)
        seg_npy = seg_npy.astype(np.float32)

        properties = OrderedDict()
        # Size and spacing are stored with their axis order reversed ([2, 1, 0]).
        properties["original_size_of_raw_data"] = np.array(data_itk[0].GetSize())[[2, 1, 0]]
        properties["original_spacing"] = np.array(data_itk[0].GetSpacing())[[2, 1, 0]]
        properties["list_of_data_files"] = data_files
        properties["seg_file"] = seg_files
        properties["itk_origin"] = data_itk[0].GetOrigin()
        properties["itk_spacing"] = data_itk[0].GetSpacing()
        properties["itk_direction"] = data_itk[0].GetDirection()

        print(4, data_npy.shape, seg_npy.shape)
        data_npy, seg_npy, bbox = crop_to_nonzero(data_npy, seg_npy, nonzero_label=-1)
        print(5, data_npy.shape, seg_npy.shape)
        properties["crop_bbox"] = bbox
        # Recorded after cropping, so the -1 "outside" label can be part of it.
        properties['classes'] = np.unique(seg_npy)
        seg_npy[seg_npy < -1] = 0
        properties["size_after_cropping"] = data_npy[0].shape
        print(6, properties['classes'], properties['classes'].shape)

        # Drop ".nii.gz" and then the trailing four characters of the name.
        case_id = os.path.basename(seg_files[0]).split(".nii.gz")[0][0:-4]
        all_data = np.vstack((data_npy, seg_npy))

        # Write below root_dir instead of a fixed absolute path so that a
        # non-default root_dir is honoured.
        case_dir = os.path.join(root_dir, case_id)
        np.savez_compressed(os.path.join(case_dir, "%s.npz" % case_id), data=all_data)
        with open(os.path.join(case_dir, "%s.pkl" % case_id), 'wb') as file:
            pickle.dump(properties, file)

        seg_files, data_files, data_itk, seg_itk = [], [], [], []

    leftover = data_files + seg_files
    if leftover:
        # An incomplete trailing group is not processed; say so instead of
        # dropping it silently.
        print('Warning: %d file(s) did not form a complete group of %d and were skipped: %s'
              % (len(leftover), files_per_case, leftover))

def intensity_normalization(cropped_output_dir='/home/jovyan/main/BraTS2020_TrainingData/', case_identifier='BraTS20_Training_001'):
    """
    Normalize the image channels of one cropped case and sample class locations.

    Reads ``<cropped_output_dir>/<case_identifier>/<case_identifier>.npz`` (key
    ``data``) and the matching ``.pkl`` properties file. The last entry along
    the first axis is treated as the segmentation, everything before it as
    image channels. Each channel is clipped to its 1st/99th percentile and then
    shifted and scaled with the mean and standard deviation that were measured
    before clipping. For each label in ``[0, 1, 2, 4]`` up to 10000 voxel
    coordinates (but at least 1% of that label's voxels) are drawn without
    replacement using a fixed seed and stored under
    ``properties['class_locations']``; a label that does not occur gets ``[]``.

    The results are written next to the inputs as
    ``<case_identifier>_normalized.npz`` and ``<case_identifier>_normalized.pkl``.

    NOTE(review): mean/std are taken before clipping, so the result is not an
    exact z-score of the clipped channel — confirm this is intended.

    :param cropped_output_dir: directory that contains one sub-directory per case
    :param case_identifier: name of the case sub-directory and file stem
    :return: None
    """
    case_dir = os.path.join(cropped_output_dir, case_identifier)

    # Close the archive as soon as the array has been read.
    with np.load(os.path.join(case_dir, "%s.npz" % case_identifier)) as archive:
        all_data = archive['data']
    data = all_data[:-1].astype(np.float32)
    seg = all_data[-1:]

    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(os.path.join(case_dir, "%s.pkl" % case_identifier), 'rb') as f:
        properties = pickle.load(f)

    print(type(data))
    for channel in data:
        # ``channel`` is a view into ``data``, so the writes below are in place.
        mn = np.mean(channel)
        std = np.std(channel)
        lower_bound = np.percentile(channel, 1)
        upper_bound = np.percentile(channel, 99)
        channel[...] = np.clip(channel, lower_bound, upper_bound)
        if std > 0:
            channel[...] = (channel - mn) / std
        else:
            # A constant channel has zero spread; only centre it so the
            # division cannot turn the whole channel into NaN.
            channel[...] = channel - mn
    all_data = np.vstack((data, seg)).astype(np.float32)

    num_samples = 10000
    min_percent_coverage = 0.01  # at least 1% of the class voxels need to be selected, otherwise it may be too sparse
    rndst = np.random.RandomState(1234)
    class_locs = {}
    all_classes = [0, 1, 2, 4]
    for c in all_classes:
        all_locs = np.argwhere(all_data[-1] == c)
        if len(all_locs) == 0:
            class_locs[c] = []
            continue
        target_num_samples = min(num_samples, len(all_locs))
        target_num_samples = max(target_num_samples, int(np.ceil(len(all_locs) * min_percent_coverage)))
        selected = all_locs[rndst.choice(len(all_locs), target_num_samples, replace=False)]
        class_locs[c] = selected
        print(c, target_num_samples)
    properties['class_locations'] = class_locs

    # Write next to the inputs rather than to a fixed absolute path, so a
    # non-default cropped_output_dir is honoured.
    np.savez_compressed(os.path.join(case_dir, "%s_normalized.npz" % case_identifier), data=all_data)
    with open(os.path.join(case_dir, "%s_normalized.pkl" % case_identifier), 'wb') as file:
        pickle.dump(properties, file)
    
    
# preprocess_data()
# Run the normalization for cases 1..80; case folders carry a zero-padded
# three-digit index (e.g. BraTS20_Training_001).
for i in range(1, 81):
    name = '%03d' % i
    case_identifier = 'BraTS20_Training_%s' % name
    intensity_normalization(case_identifier=case_identifier)

<class 'numpy.ndarray'>
0 11310
1 10000
2 10000
4 10000
<class 'numpy.ndarray'>
0 14722
1 9160
2 10000
4 6549
<class 'numpy.ndarray'>
0 12314
1 733
2 10000
4 2998
<class 'numpy.ndarray'>
0 14197
1 10000
2 10000
4 10000
<class 'numpy.ndarray'>
0 14199
1 3624
2 7553
4 10000
<class 'numpy.ndarray'>
0 13293
1 10000
2 10000
4 10000
<class 'numpy.ndarray'>
0 13379
1 3398
2 10000
4 6159
<class 'numpy.ndarray'>
0 15115
1 157
2 10000
4 1211
<class 'numpy.ndarray'>
0 11392
1 10000
2 10000
4 10000
<class 'numpy.ndarray'>
0 12446
1 521
2 10000
4 6473
<class 'numpy.ndarray'>
0 15513
1 4603
2 10000
4 5774
<class 'numpy.ndarray'>
0 15764
1 2307
2 10000
4 5973
<class 'numpy.ndarray'>
0 14199
1 10000
2 10000
4 4039
<class 'numpy.ndarray'>
0 16489
1 9313
2 10000
4 10000
<class 'numpy.ndarray'>
0 14886
1 8675
2 10000
4 9002
<class 'numpy.ndarray'>
0 13356
1 10000
2 10000
4 10000
<class 'numpy.ndarray'>
0 14645
1 7491
2 10000
4 7222
<class 'numpy.ndarray'>
0 12696
1 341
2 10000
4 1695
<class 'numpy.ndarra