In [19]:
# general imports
import os.path as op
import os

# third party imports
import numpy as np
import nibabel as nib
from scipy.ndimage.interpolation import zoom

import SimpleITK as sitk
import matplotlib.image as mpimg
import dicom2nifti
import matplotlib.pyplot as plt

In [42]:
# Organ ids that are kept in the segmentation masks.
require_labels = [
    0,  # background
    1,  # spleen
    2,  # right kidney
    3,  # left kidney
    6,  # liver
]

# Target array shape handed to the resize step.
resample_size = [256, 256, 32]

# Directory holding one sub-folder per patient with its procimg.nii.gz.
nii_path = '../../Dataset/CHAOS/CHAOS_nii/'

# Output directory name derived from the target shape.
save_path_npy = '../../Dataset/CHAOS/CHAOS_%d_%d_%d/' % tuple(resample_size)

# Source folders: index 0 = patients with labels, index 1 = patients without.
path = [
    '../../Dataset/CHAOS/data_process/label/',
    '../../Dataset/CHAOS/data_process/wo_label/',
]

# Root directory for the resized NIfTI files.
save_path = '../../Dataset/CHAOS/CHAOS_%d_%d_%d_nii/' % tuple(resample_size)

In [45]:
def read_nii(volume_path):
    """Load a NIfTI file with nibabel and unpack the pieces used downstream.

    Args:
        volume_path: Path passed straight to ``nib.load``.

    Returns:
        A tuple ``(data, header, affine, spacing)`` holding the array from
        ``get_fdata()``, the image header, the image affine, and a list built
        from entries 1..3 of the header's ``pixdim`` field.
    """
    image = nib.load(volume_path)
    voxel_dims = image.header['pixdim'][1:4]
    return image.get_fdata(), image.header, image.affine, list(voxel_dims)

def create_path(pathlist):
    """Make sure every directory in ``pathlist`` exists.

    Missing parent directories are created as well. Directories that already
    exist are left untouched.

    Args:
        pathlist: Iterable of directory paths. A single ``str``, ``bytes`` or
            ``os.PathLike`` is also accepted and treated as one path (rather
            than being iterated character by character).

    Raises:
        FileExistsError: If one of the paths exists but is not a directory.
    """
    if isinstance(pathlist, (str, bytes, os.PathLike)):
        pathlist = [pathlist]
    for path in pathlist:
        # exist_ok avoids the check-then-create race of testing op.exists first.
        os.makedirs(path, exist_ok=True)
            
def save_nii(data, save_path, header, affine):
    """Wrap ``data`` in a ``nib.Nifti1Image`` and write it to ``save_path``.

    Args:
        data: Array stored as the image data.
        save_path: Destination file path passed to ``nib.save``.
        header: Header passed to the ``Nifti1Image`` constructor.
        affine: Affine passed to the ``Nifti1Image`` constructor.
    """
    nib.save(nib.Nifti1Image(data, affine, header), save_path)

def resize(img, label, spacing, resample_size):
    """Resample an image and its label volume to ``resample_size``.

    The image is zoomed with ``order=1`` (linear interpolation) and the label
    with ``order=0`` (nearest neighbour), so no new label values are created.

    Args:
        img: Image array.
        label: Label array; must have the same shape as ``img``.
        spacing: Per-axis spacing of ``img``; any sequence or array with one
            entry per axis.
        resample_size: Target shape; any sequence or array with one entry per
            axis.

    Returns:
        A tuple ``(img, label, new_space)`` with the resampled image, the
        resampled label, and the spacing scaled by
        ``original_shape / resample_size``.

    Raises:
        AssertionError: If ``img`` and ``label`` differ in shape.
    """
    assert img.shape == label.shape, 'img and label must have the same shape'
    # Coerce everything to arrays: multiplying a plain list by a shape tuple
    # raises TypeError, so list inputs would otherwise fail here.
    origin_shape = np.asarray(img.shape, dtype=float)
    target_shape = np.asarray(resample_size, dtype=float)
    new_space = np.asarray(spacing, dtype=float) * origin_shape / target_shape
    resize_factor = target_shape / origin_shape
    print('resize_factor: ', resize_factor)
    img = zoom(img, resize_factor, order=1)  # order=1: linear interpolation
    label = zoom(label, resize_factor, order=0)  # order=0: nearest neighbour
    return img, label, new_space
         
def create_and_save_resize_data(resample_size, img, label, patient,
                                spacing,img_header, img_affine,nii_save_path):
    """Resize an image/label pair and write both as NIfTI files.

    The resized image is saved as ``procimg.nii.gz`` and the resized label as
    ``seg.nii.gz`` inside ``nii_save_path``. Both files get a header whose
    ``pixdim[1:4]`` is set to the new spacing.

    Args:
        resample_size: Target shape forwarded to ``resize``.
        img: Image array.
        label: Label array with the same shape as ``img``.
        patient: Unused; kept so existing call sites keep working.
        spacing: Spacing of ``img``; converted with ``np.array`` before use.
        img_header: Header used as the template for both output files. It is
            copied, so the caller's object is not modified.
        img_affine: Affine written to both output files.
        nii_save_path: Existing directory that receives the two files.

    Returns:
        A tuple ``(img, label, new_space)`` as returned by ``resize``.
    """
    img, label, new_space = resize(img, label, np.array(spacing), resample_size)

    # Edit a copy: writing pixdim into the caller's header would silently
    # change an object the caller may still use.
    out_header = img_header.copy()
    out_header['pixdim'][1:4] = new_space

    # NOTE(review): img_affine is passed through unchanged, so it presumably
    # still encodes the pre-resize voxel size -- confirm whether downstream
    # consumers read spacing from the affine or from pixdim.
    save_nii(img, op.join(nii_save_path, 'procimg.nii.gz'), out_header, img_affine)
    save_nii(label, op.join(nii_save_path, 'seg.nii.gz'), out_header, img_affine)

    return img, label, new_space


In [None]:

# Resize every patient volume (and its ground truth, when available) and save
# the result under save_path/<train|valid>/<patient>/.

# Patient ids held out as the validation split (randomly selected).
valid_patients = ['1', '5', '8', '10', '13', '20', '22', '31', '34', '38']

# Grey value as it appears in the loaded ground-truth arrays -> organ id
# (see require_labels: 1 spleen, 2 right kidney, 3 left kidney, 6 liver).
png_value_to_label = [
    (0.9882353, 1),
    (0.49411765, 2),
    (0.7411765, 3),
    (0.24705882, 6),
]


def _load_label_volume(ground_dir):
    """Stack the per-slice ground-truth images of ``ground_dir`` into a volume.

    Files are read in sorted name order; each image is transposed, flipped
    left-right, and placed along a new last axis.
    """
    slices = [
        np.fliplr(mpimg.imread(op.join(ground_dir, name)).transpose())[..., np.newaxis]
        for name in sorted(os.listdir(ground_dir))
    ]
    # np.concatenate raises ValueError on an empty list, so a patient without
    # ground-truth files fails loudly instead of reusing an earlier label.
    return np.concatenate(slices, axis=2)


def _show_preview(img, label, slice_index=20):
    """Plot one slice of ``img`` and of ``label`` side by side."""
    # Clamp the index so volumes with fewer slices do not raise IndexError.
    k = min(slice_index, img.shape[2] - 1, label.shape[2] - 1)
    fig, (ax1, ax2) = plt.subplots(1, 2)
    ax1.imshow(img[:, :, k])
    ax2.imshow(label[:, :, k])
    plt.show()
    plt.close(fig)  # one figure per call; release it explicitly


plt.rcParams['figure.figsize'] = (15.0, 15.0)

for j, image_path in enumerate(path):
    # path[0] holds the patients with ground truth, path[1] those without.
    has_label = (j == 0)
    files = sorted(os.listdir(image_path))
    print(files,len(files))
    for patient in files:
        print("Proprecess train image " + str(patient))
        img, img_header, img_affine, img_spacing = read_nii(
            op.join(nii_path, str(patient), 'procimg.nii.gz'))
        print(img_affine)
        img = img.astype(np.float32)
        print(img.shape,img.dtype)

        # Reset per patient so an unlabeled volume never picks up the label
        # array left over from a previous iteration.
        label = None

        if has_label:
            label = _load_label_volume(op.join(image_path, patient, 'T2SPIR', 'Ground'))
            _show_preview(img, label)
            print(label.shape)

            label_unique = np.unique(label)
            print(label_unique)

            # Replace the grey values by organ ids, then drop to integers.
            for png_value, organ_id in png_value_to_label:
                label[label == png_value] = organ_id
            label = label.astype(np.int64)
            print('Image and label shape before resize: ', img.shape,img.dtype, label.shape,label.dtype)

            label_unique = np.unique(label)
            print(label_unique)
            # Anything that is not a required organ id becomes background.
            label[~np.isin(label, require_labels)] = 0

        is_valid = patient in valid_patients
        nii_save_path = op.join(save_path, 'valid' if is_valid else 'train', str(patient))
        create_path([nii_save_path])

        if label is not None:
            img, label, space = create_and_save_resize_data(resample_size, img, label,
                                                            patient, img_spacing, img_header,
                                                            img_affine, nii_save_path)
            _show_preview(img, label)
            if is_valid:
                print(label.shape,space)
                label_unique = np.unique(label)
                print(label_unique)
        else:
            # No ground truth: resize and save the image only.
            print(patient)
            origin_shape = np.array(img.shape)
            new_space = np.asarray(img_spacing) * origin_shape / np.asarray(resample_size)
            resize_factor = 1.0 * np.asarray(resample_size) / np.asarray(origin_shape)
            print('resize_factor: ', resize_factor)
            img = zoom(img, resize_factor, order=1)  # order=1: linear interpolation
            img_header['pixdim'][1:4] = new_space
            save_nii(img, op.join(nii_save_path, 'procimg.nii.gz'), img_header, img_affine)

        print('Image and label shape after resize: ', img.shape,
              label.shape if label is not None else None)
        print()