In [None]:
# -*- coding:utf-8 -*-

# general imports
import os.path as op
import os

# third party imports
import numpy as np
import nibabel as nib
from scipy.ndimage.interpolation import zoom

import SimpleITK as sitk
import matplotlib.image as mpimg

In [1]:
# Class ids present in the saved segmentation volumes (produced by the relabelling step).
require_labels = [
    0,  # background
    1,  # spleen
    2,  # right kidney
    3,  # left kidney
    6,  # liver
]

# Target shape every volume is resampled to, in the axis order of the transposed
# arrays used below (rows, cols, slices).
resample_size = [256, 256, 32]

# Input roots: path[0] holds patients with ground-truth masks, path[1] those without.
_data_root = '../../Dataset/CHAOS/data_process/'
path = [_data_root + 'label/', _data_root + 'wo_label/']

# Output root, tagged with the resampled shape.
# NOTE(review): the inputs live two directories up ('../../') but this output path only
# goes up one ('../') — confirm that the different location is intended.
save_path = '../Dataset/CHAOS/CHAOS_{}_{}_{}/'.format(*resample_size)

In [2]:
def create_path(pathlist):
    """Create every directory in `pathlist`, including missing parent directories.

    Directories that already exist are left untouched, so the call is safe to
    repeat (e.g. when the notebook is re-run).

    Args:
        pathlist: iterable of directory paths to create.
    """
    for dir_path in pathlist:
        # exist_ok=True replaces the check-then-create pattern (op.exists + makedirs),
        # which can fail if the directory appears between the check and the create.
        os.makedirs(dir_path, exist_ok=True)

def resize_label(img, label, spacing, resample_size):
    """Resample an image volume and its label volume to `resample_size`.

    Args:
        img: image array; must have exactly the same shape as `label`.
        label: label array aligned voxel-for-voxel with `img`.
        spacing: voxel spacing of the input, one value per axis of `img`.
        resample_size: target shape, one value per axis of `img`.

    Returns:
        Tuple (img, label, new_space):
          - img resampled with linear interpolation (zoom order=1),
          - label resampled with nearest neighbour (zoom order=0), so label
            values are never blended into new ids,
          - new_space = spacing * old_shape / resample_size.

    NOTE(review): in this notebook `spacing` comes from SimpleITK in (x, y, z)
    order while the arrays are transposed to (y, x, z); the first two spacing
    entries are only correct when the in-plane spacing is isotropic (x == y) —
    confirm for your data.
    """
    assert img.shape == label.shape
    old_shape = np.array(img.shape)
    target = np.asarray(resample_size)
    scale = 1.0 * target / np.asarray(old_shape)

    resized_img = zoom(img, scale, order=1)      # linear interpolation for intensities
    resized_label = zoom(label, scale, order=0)  # nearest neighbour keeps labels discrete
    new_space = spacing * old_shape / resample_size
    return resized_img, resized_label, new_space

In [None]:
# Patients held out for validation (hand-picked subset of the labelled cases).
VALID_PATIENTS = {'1', '5', '8', '10', '13', '20', '22', '31', '34', '38'}

# 8-bit grey level of each organ in the ground-truth PNGs -> class id.
# These are the integer values behind the float constants mpimg.imread returns
# (e.g. 252 / 255 == 0.9882353); matching integers avoids float-equality pitfalls.
GREY_TO_CLASS = {
    252: 1,  # spleen
    126: 2,  # right kidney
    189: 3,  # left kidney
    63: 6,   # liver
}


def _load_label_volume(ground_dir):
    """Read every per-slice mask PNG in `ground_dir` and return an int8 class-id volume.

    Slices are read in sorted file-name order and stacked along the last axis, giving
    shape (H, W, n_slices) for single-channel PNGs (assumed here — RGB(A) masks would
    add a channel axis; TODO confirm). Grey levels are mapped through GREY_TO_CLASS;
    any other value becomes 0 (background).
    """
    slice_names = sorted(os.listdir(ground_dir))
    print(slice_names)
    # A single np.stack instead of concatenating inside the loop, which re-copied
    # the whole volume for every slice.
    grey = np.stack([mpimg.imread(op.join(ground_dir, name)) for name in slice_names],
                    axis=-1)
    print('label', grey.shape)
    # mpimg.imread scales 8-bit PNGs to [0, 1]; undo that to recover exact grey levels.
    grey_u8 = np.rint(grey * 255).astype(np.uint8)
    label = np.zeros(grey_u8.shape, dtype=np.int8)
    for grey_value, class_id in GREY_TO_CLASS.items():
        label[grey_u8 == grey_value] = class_id
    return label


for j, image_path in enumerate(path):
    has_labels = (j == 0)  # path[0] holds annotated patients, path[1] unannotated ones
    files = sorted(os.listdir(image_path))
    print(files, len(files))
    reader = sitk.ImageSeriesReader()
    for patient in files:
        patient_dir = op.join(image_path, patient, 'T2SPIR')

        img_names = reader.GetGDCMSeriesFileNames(op.join(patient_dir, 'DICOM_anon'))
        reader.SetFileNames(img_names)
        image = reader.Execute()
        # SimpleITK arrays are (z, y, x); move the slice axis last -> (y, x, z).
        array = sitk.GetArrayFromImage(image).transpose(1, 2, 0)
        spacing = image.GetSpacing()  # (x, y, z)
        print(array.shape, spacing)
        print(array.dtype)

        if has_labels:
            label = _load_label_volume(op.join(patient_dir, 'Ground'))
            print(label.dtype)
            print('label_unique', np.unique(label))
            new_array, new_label, new_spacing = resize_label(array, label, spacing, resample_size)
            print(patient, new_array.shape, new_label.shape, new_spacing)
        else:
            # Same image resampling as resize_label. The result must be bound to
            # new_spacing (the name saved below) so an unlabelled patient never
            # reuses the spacing of the previous labelled patient.
            origin_shape = np.array(array.shape)
            new_spacing = spacing * origin_shape / resample_size
            resize_factor = 1.0 * np.asarray(resample_size) / np.asarray(origin_shape)
            new_array = zoom(array, resize_factor, order=1)  # order=1: linear interpolation
            print(patient, new_array.shape, new_spacing)

        split = 'valid' if patient in VALID_PATIENTS else 'train'
        npy_save_path = op.join(save_path, split, str(patient))
        create_path([npy_save_path])
        np.save(op.join(npy_save_path, 'image.npy'), new_array)
        np.save(op.join(npy_save_path, 'spacing.npy'), new_spacing)
        # Only annotated patients have a segmentation; never write a stale label
        # left over from another patient.
        if has_labels:
            np.save(op.join(npy_save_path, 'seg.npy'), new_label)
        print()