In [15]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [16]:
import h5py
import nibabel as nib
import os
import glob
from dev_tools.my_tools import print_red, minmax_normalize
import pdb
import numpy as np
import yaml
from tqdm.notebook import tqdm
import pickle


def create_h5(source_folder, mean_std_file, overwrite=False):
    '''
    Pack every subject directory under `source_folder` into ``data/<suffix>.h5``,
    where ``<suffix>`` is the text after the last '_' in `source_folder`.

    One HDF5 group is created per subject directory. Each file in that
    directory is stored as a dataset named after the file. Every modality
    except 'seg' is passed through `normalize` with the '<mod>_mean' /
    '<mod>_std' entries read from `mean_std_file`. A 'brain_width' dataset of
    shape (2, ndim) stores the union of the per-modality outlines returned by
    `cal_outline` (row 0 = start edge, row 1 = end edge).

    Side effect: if 'data/affine.npy' does not exist, the affine of the first
    image loaded is saved there.

    Parameters
    ----------
    source_folder : str
        Folder holding the subject directories. For the 'Training' suffix the
        subjects are expected two levels down ('*/*'), otherwise one ('*').
    mean_std_file : str
        Pickle file produced by `cal_mean_std`.
    overwrite : bool
        When False and the target file exists, nothing is done.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If a subject directory contains no non-'seg' image, so no brain
        outline can be computed.
    '''
    affine_path = os.path.join('data', 'affine.npy')
    try:
        affine = np.load(affine_path)
    except FileNotFoundError:
        affine = None

    split_name = source_folder.split('_')[-1]
    target = os.path.join('data', split_name + '.h5')

    if os.path.exists(target) and not overwrite:
        print('{:s} exists already.'.format(target))
        return

    # NOTE: pickle.load can execute arbitrary code; only point this at a
    # statistics file written by this notebook.
    with open(mean_std_file,'rb') as f:
        mean_std_values = pickle.load(f)

    pattern = '*/*' if split_name == 'Training' else '*'
    img_dirs = glob.glob(os.path.join(source_folder, pattern))

    with h5py.File(target,'w') as f:
        for img_dir in tqdm(img_dirs,desc='writing {:s}'.format(target)):
            if not os.path.isdir(img_dir):
                continue
            # basename instead of split('/') so Windows-style paths work too.
            sub_id = os.path.basename(os.path.normpath(img_dir))
            h5_subid = f.create_group(sub_id)
            brain_widths = []
            for mod_file in os.listdir(img_dir):
                img = nib.load(os.path.join(img_dir,mod_file))
                if affine is None:
                    affine = img.affine
                    np.save(affine_path, affine)
                # Documented drop-in replacement for the deprecated
                # img.get_data(): same values, same dtype.
                img_npy = np.asanyarray(img.dataobj)
                mod = mod_file.split('_')[-1].split('.')[0]
                if mod != 'seg':
                    img_npy = normalize(img_npy,
                                        mean = mean_std_values['{:s}_mean'.format(mod)],
                                        std = mean_std_values['{:s}_std'.format(mod)])
                    brain_widths.append(cal_outline(img_npy))
                h5_subid.create_dataset(mod_file,data=img_npy)
            if not brain_widths:
                # np.min on an empty list would fail with an obscure
                # zero-size reduction error; say what is actually wrong.
                raise ValueError('{:s} contains no non-seg image; cannot compute brain_width.'
                                 .format(img_dir))
            # Union of the outlines: smallest start and largest end per axis.
            start_edge = np.min(brain_widths,axis=0)[0]
            end_edge = np.max(brain_widths,axis=0)[1]
            brain_width = np.vstack((start_edge,end_edge))
            h5_subid.create_dataset('brain_width',data=brain_width)
    return

def cal_outline(img_npy):
    '''
    Bounding box of the non-zero region of `img_npy`, padded by one voxel.

    Returns an integer array of shape (2, img_npy.ndim): row 0 is the start
    edge (smallest non-zero index minus 1, floored at 0) and row 1 is the end
    edge (largest non-zero index plus 1, capped at the axis length).
    '''
    nonzero_coords = np.stack(np.nonzero(img_npy))
    lower = np.clip(nonzero_coords.min(axis=1) - 1, 0, None)
    upper = np.clip(nonzero_coords.max(axis=1) + 1, None, img_npy.shape)
    return np.stack([lower, upper])

def normalize(img_npy,mean,std,offset=0.1, mul_factor=100):
    '''
    Normalize the non-zero voxels of `img_npy` in place and return it.

    The non-zero voxels are standardized with `mean` / `std`, passed through
    `minmax_normalize`, shifted by `offset` and scaled by `mul_factor`.
    Zero voxels (background) are left untouched; the offset and the factor
    keep normalized brain voxels distinguishable from that background.
    The values are written back into the input array, so they are stored in
    the input's dtype.
    '''
    brain_mask = np.nonzero(img_npy)
    standardized = (img_npy[brain_mask] - mean) / std
    rescaled = minmax_normalize(standardized)
    img_npy[brain_mask] = (rescaled + offset) * mul_factor
    return img_npy


def cal_mean_std(source_folder,saved_path,overwrite=False,mods=None):
    '''
    Calculate the mean and standard deviation of the non-zero voxels of each
    modality over all subjects found under ``source_folder/*/*`` and pickle
    the result to `saved_path`.

    Parameters
    ----------
    source_folder : str
        Folder whose second-level sub-directories are the subjects. Each must
        contain ``<subject>_<mod>.nii.gz`` for every modality.
    saved_path : str
        Destination of the pickled statistics.
    overwrite : bool
        When False and `saved_path` exists, nothing is computed.
    mods : iterable of str, optional
        Modalities to process. When None, the notebook-global
        ``config['data']['all_mods']`` is used, as before; pass it explicitly
        to avoid depending on that global being defined.

    Returns
    -------
    dict or None
        ``{'t1_mean': ..., 't1_std': ..., ...}`` (values rounded to 4
        decimals) when the statistics were computed, None when the existing
        file was kept.

    Raises
    ------
    ValueError
        If no non-zero voxel is found for a modality.
    '''
    if os.path.exists(saved_path) and not overwrite:
        print('{:s} exists already.'.format(saved_path))
        return
    if mods is None:
        # Legacy behavior: fall back to the notebook-global config.
        mods = config['data']['all_mods']
    sub_dirs = glob.glob(os.path.join(source_folder,'*/*')) # SD

    def _brain_voxels(sub_dir, mod):
        # Non-zero voxels of one subject's image for modality `mod`.
        sub_id = os.path.basename(os.path.normpath(sub_dir))
        file_name = os.path.join(sub_dir, sub_id + '_{:s}.nii.gz'.format(mod))
        # Documented replacement for the deprecated get_data().
        img_npy = np.asanyarray(nib.load(file_name).dataobj)
        return img_npy[np.nonzero(img_npy)]

    mean_std_values = {}

    for mod in mods:
        # First pass: global mean over all subjects.
        total = 0.0
        amount = 0
        for sub_dir in tqdm(sub_dirs,
                             desc='Calculating {:s}\'s mean value'
                             .format(mod)):
            brain_area = _brain_voxels(sub_dir, mod)
            # Accumulate in float64 so integer images cannot overflow a
            # narrow platform integer.
            total += np.sum(brain_area, dtype=np.float64)
            amount += len(brain_area)
        if amount == 0:
            raise ValueError('No non-zero voxels found for modality {:s} under {:s}.'
                             .format(mod, source_folder))
        mean = total / amount
        mean_std_values['{:s}_mean'.format(mod)] = round(mean,4)
        print('{:s}\'s mean value = {:.2f}'.format(mod,mean))

        # Second pass: squared deviations from that global mean.
        sq_dev = 0.0
        for sub_dir in tqdm(sub_dirs,
                             desc='Calculating {:s}\'s std value'
                             .format(mod)):
            brain_area = _brain_voxels(sub_dir, mod)
            sq_dev += np.sum((brain_area.astype(np.float64) - mean) ** 2)
        std = np.sqrt(sq_dev / amount)
        mean_std_values['{:s}_std'.format(mod)] = round(std,4)
        print('{:s}\'s std value = {:.2f}'.format(mod,std))
    print(mean_std_values)
    with open(saved_path,'wb') as f:
        pickle.dump(mean_std_values,f)
    return mean_std_values
   
                          
                          

In [17]:
def hello_test():
    '''
    Run the preprocessing pipeline described by 'config.yml'.

    Computes the per-modality statistics from the training source, then
    builds one HDF5 file for each of the train, validation and test sources
    using those statistics.
    '''
    with open('config.yml') as cfg_stream:
        config = yaml.load(cfg_stream,Loader=yaml.FullLoader)

    data_cfg = config['data']
    train_source = data_cfg['source_train']
    stats_path = data_cfg['mean_std_file']
    cal_mean_std(source_folder=train_source,
                 saved_path=stats_path)

    for split_key in ('source_train', 'source_val', 'source_test'):
        create_h5(data_cfg[split_key], stats_path)


In [49]:
# Scratch check: rounding 1/3 to three decimals before multiplying by 3 does
# not give back 1 -- the rounding error is scaled along with the value.
round(1/3,3) * 3

0.9990000000000001

In [56]:
from random import shuffle

    
def cross_val_split(num_sbjs, saved_path, num_folds=5, overwrite=False, seed=None):
    '''
    Generate a `num_folds`-fold cross-validation split of the subject indices
    ``0 .. num_sbjs-1`` and pickle it to `saved_path`.

    The indices are shuffled once and cut into `num_folds` contiguous slices.
    Fold ``i`` uses slice ``i`` for validation and everything else for
    training. Slice boundaries use integer arithmetic, so the validation
    folds always partition the subjects (float boundaries could leave a fold
    empty and make its neighbour too large).

    Parameters
    ----------
    num_sbjs : int
        Number of subjects.
    saved_path : str
        Destination pickle file.
    num_folds : int
        Number of folds; must be at least 1.
    overwrite : bool
        When False and `saved_path` exists, nothing is generated.
    seed : int, optional
        Seed for a private random generator, making the split reproducible.
        When None the module-level `shuffle` (global random state) is used,
        as before.

    Returns
    -------
    dict or None
        ``{'train_list_<i>': [...], 'val_list_<i>': [...]}`` when a split was
        generated, None when the existing file was kept.

    Raises
    ------
    ValueError
        If `num_folds` is smaller than 1.
    '''
    if os.path.exists(saved_path) and not overwrite:
        print('{:s} exists already.'.format(saved_path))
        return
    if num_folds < 1:
        raise ValueError('num_folds must be >= 1, got {}'.format(num_folds))

    subid_indices = list(range(num_sbjs))
    if seed is None:
        shuffle(subid_indices)
    else:
        import random
        random.Random(seed).shuffle(subid_indices)

    res = {}
    for i in range(num_folds):
        # Exact floor(i * num_sbjs / num_folds) without float round-off.
        left = i * num_sbjs // num_folds
        right = (i + 1) * num_sbjs // num_folds
        res['train_list_{:d}'.format(i)] = subid_indices[:left] + subid_indices[right:]
        res['val_list_{:d}'.format(i)] = subid_indices[left : right]
    with open(saved_path,'wb') as f:
        pickle.dump(res,f)
    return res
        
        

# Notebook-level driver: load the configuration into the global `config`
# (cal_mean_std reads this global), count the subject groups stored in the
# training HDF5 file, and write the cross-validation split unless it exists.
with open('config.yml') as f:
    config = yaml.load(f,Loader=yaml.FullLoader)
# len() of an open h5py.File is the number of top-level entries; create_h5
# writes one group per subject.
with h5py.File('data/Training.h5','r') as f:
    num_sbjs = len(f)
# Path of the pickle file that holds the split.
cross_val_indices = config['data']['cross_val_indices']  
cross_val_split(num_sbjs, cross_val_indices)

def get_training_and_validation_generators(data_file, batch_size, n_labels, training_keys_file, 
                                           validation_keys_file,
                                           data_split=0.8, overwrite=False, labels=None, augment=False,
                                           augment_flip=True, augment_distortion_factor=0.25, 
                                           patch_shape=None,
                                           validation_patch_overlap=0, training_patch_start_offset=None,
                                           validation_batch_size=None, skip_blank=True, permute=False,
                                           num_model=1,
                                           pred_specific=False, overlap_label=True,
                                           for_final_val=False):
    '''
    Placeholder: the body is currently only `pass`, so every call returns
    None regardless of the arguments.

    The previous implementation is kept commented out right after this
    function. NOTE(review): a caller that unpacks four return values will
    fail until that implementation is restored.
    '''
    pass
    #     pdb.set_trace()
#     if not validation_batch_size:
#         validation_batch_size = batch_size

#     training_list, validation_list = get_validation_split(data_file,
#                                                           data_split=data_split,
#                                                           overwrite=overwrite,
#                                                           training_file=training_keys_file,
#                                                           validation_file=validation_keys_file)
#     if for_final_val:
#         training_list = training_list + validation_list

#     training_generator = data_generator(data_file, training_list,
#                                         batch_size=batch_size,
#                                         n_labels=n_labels,
#                                         labels=labels,
#                                         augment=augment,
#                                         augment_flip=augment_flip,
#                                         augment_distortion_factor=augment_distortion_factor,
#                                         patch_shape=patch_shape,
#                                         patch_overlap=validation_patch_overlap,
#                                         patch_start_offset=training_patch_start_offset,
#                                         skip_blank=skip_blank,
#                                         permute=permute,
#                                         num_model=num_model,
#                                         pred_specific=pred_specific,
#                                         overlap_label=overlap_label)

#     validation_generator = data_generator(data_file, validation_list,
#                                           batch_size=validation_batch_size,
#                                           n_labels=n_labels,
#                                           labels=labels,
#                                           patch_shape=patch_shape,
#                                           patch_overlap=validation_patch_overlap,
#                                           skip_blank=skip_blank,
#                                           num_model=num_model,
#                                           pred_specific=pred_specific,
#                                           overlap_label=overlap_label)

#     # Set the number of training and testing samples per epoch correctly
#     #     pdb.set_trace()
#     if os.path.exists('num_patches_training.npy'):
#         num_patches_training = int(np.load('num_patches_training.npy'))
#     else:
#         num_patches_training = get_number_of_patches(data_file, training_list, patch_shape,
#                                                        skip_blank=skip_blank,
#                                                        patch_start_offset=training_patch_start_offset,
#                                                        patch_overlap=validation_patch_overlap,
#                                                        pred_specific=pred_specific)
#         np.save('num_patches_training', num_patches_training)
#     num_training_steps = get_number_of_steps(num_patches_training, batch_size)
#     print("Number of training steps in each epoch: ", num_training_steps)

#     if os.path.exists('num_patches_val.npy'):
#         num_patches_val = int(np.load('num_patches_val.npy'))
#     else:
#         num_patches_val = get_number_of_patches(data_file, validation_list, patch_shape,
#                                                  skip_blank=skip_blank,
#                                                  patch_overlap=validation_patch_overlap,
#                                                  pred_specific=pred_specific)
#         np.save('num_patches_val', num_patches_val)
#     num_validation_steps = get_number_of_steps(num_patches_val, validation_batch_size)
#     print("Number of validation steps in each epoch: ", num_validation_steps)

#     return training_generator, validation_generator, num_training_steps, num_validation_steps

data/cross_val_indices.pkl exists already.


In [None]:
import os
import copy

import itertools

import numpy as np

from .utils import pickle_dump, pickle_load
from .patches import compute_patch_indices, get_random_nd_index, get_patch_from_3d_data, compute_patch_indices_for_prediction
from .augment import augment_data, random_permutation_x_y

import pdb
from dev_tools.my_tools import print_red
from tqdm import tqdm
import time

# NOTE(review): `Generator` is defined further down in this cell, so this
# class statement raises NameError when the cell runs top to bottom; move
# this class below `Generator`.
class train_generator(Generator):
    '''
    Training-data generator; currently adds nothing on top of `Generator`.
    '''
    def __init__(self, config_file='config.yml'):
        '''
        Parameters
        ----------
        config_file : str
            YAML configuration forwarded to `Generator.__init__`, which
            requires it. Defaults to 'config.yml', the path used by the
            other cells, so the former zero-argument call keeps working.
        '''
        super().__init__(config_file)

class Generator():
    def __init__(self,config_file):
        '''
        Load the YAML configuration and keep it on the instance.

        Parameters
        ----------
        config_file : str
            Path of the YAML configuration file.

        The parsed configuration is stored as ``self.config``; previously it
        was read into a local variable and discarded.
        '''
        with open(config_file) as f:
            config = yaml.load(f,Loader=yaml.FullLoader)
        self.config = config


# train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
#         data_file_opened,
#         batch_size=config["batch_size"],
#         data_split=config["validation_split"],
#         overwrite=overwrite,
#         validation_keys_file=config["validation_file"],
#         training_keys_file=config["training_file"],
#         n_labels=config["n_labels"],
#         labels=config["labels"],
#         patch_shape=config["patch_shape"],
#         validation_batch_size=config["validation_batch_size"],
#         validation_patch_overlap=config["validation_patch_overlap"],
#         training_patch_start_offset=config["training_patch_start_offset"],
#         permute=config["permute"],
#         augment=config["augment"],
#         skip_blank=config["skip_blank"],
#         augment_flip=config["flip"],
#         augment_distortion_factor=config["distort"],
#         pred_specific=config['pred_specific'],
#         overlap_label=config['overlap_label_generator'],
#         for_final_val=config['for_final_val'])        
        
   



    def get_number_of_steps(n_samples, batch_size):
        '''
        Number of generator steps for `n_samples` samples.

        When `n_samples` does not exceed `batch_size` the sample count itself
        is returned; otherwise the result is ``n_samples / batch_size``
        rounded up.
        '''
        if n_samples <= batch_size:
            return n_samples
        full_batches, leftover = divmod(n_samples, batch_size)
        return full_batches + 1 if leftover else full_batches


    


    




    def data_generator(data_file, index_list, batch_size=1, n_labels=1, labels=None, augment=False, augment_flip=True,
                       augment_distortion_factor=0.25, patch_shape=None, patch_overlap=0, patch_start_offset=None,
                       shuffle_index_list=True, skip_blank=True, permute=False, num_model=1, pred_specific=False,overlap_label=False):
        '''
        Endless generator of ``(x, y)`` batches.

        Each pass rebuilds the index list (per-subject indices, or
        (subject, patch) pairs when `patch_shape` is given), optionally
        shuffles it, and feeds the samples through `add_data`. A batch is
        emitted through `convert_data` as soon as `batch_size` samples are
        collected, and once more at the end of the pass for a final, smaller
        batch.

        Sibling helpers are called through `Generator.` because functions
        defined in a class body cannot see each other by bare name.

        Raises
        ------
        ValueError
            If a complete pass produces no sample at all (empty `index_list`
            or every sample skipped); the loop would otherwise spin forever
            without yielding.
        '''
        # Local import so this function does not depend on another notebook
        # cell having imported `shuffle` first.
        from random import shuffle

        orig_index_list = index_list
        while True:
            x_list = list()
            y_list = list()
            if patch_shape:
                index_list = Generator.create_patch_index_list(orig_index_list, data_file, patch_shape,
                                                               patch_overlap, patch_start_offset,pred_specific=pred_specific)
            else:
                # Copy: the list is consumed with pop() below.
                index_list = copy.copy(orig_index_list)

            if shuffle_index_list:
                shuffle(index_list)
            yielded_any = False
            while len(index_list) > 0:
                index = index_list.pop()
                Generator.add_data(x_list, y_list, data_file, index, augment=augment, augment_flip=augment_flip,
                                   augment_distortion_factor=augment_distortion_factor, patch_shape=patch_shape,
                                   skip_blank=skip_blank, permute=permute)
                if len(x_list) == batch_size or (len(index_list) == 0 and len(x_list) > 0):
                    yield Generator.convert_data(x_list, y_list, n_labels=n_labels, labels=labels, num_model=num_model,overlap_label=overlap_label)
                    yielded_any = True
                    x_list = list()
                    y_list = list()
            if not yielded_any:
                raise ValueError('data_generator: a full pass over the index list produced no samples.')



    def get_number_of_patches(data_file, index_list, patch_shape=None, patch_overlap=0, patch_start_offset=None,
                              skip_blank=True,pred_specific=False):
        '''
        Count the samples one pass of the generator would produce.

        Without `patch_shape` this is simply ``len(index_list)``. With
        `patch_shape`, every candidate patch is run through `add_data` and
        only the patches it actually keeps are counted.

        Sibling helpers are called through `Generator.` because functions
        defined in a class body cannot see each other by bare name.
        '''
        if not patch_shape:
            return len(index_list)

        index_list = Generator.create_patch_index_list(index_list, data_file, patch_shape, patch_overlap,
                                                       patch_start_offset,pred_specific=pred_specific)
        count = 0
        for index in tqdm(index_list):
            x_list = list()
            y_list = list()
            Generator.add_data(x_list, y_list, data_file, index, skip_blank=skip_blank, patch_shape=patch_shape)
            if len(x_list) > 0:
                count += 1
        return count


    def create_patch_index_list(index_list, data_file, patch_shape, patch_overlap, patch_start_offset=None, pred_specific=False):
        '''
        Expand subject indices into ``(subject_index, patch_corner)`` pairs.

        For each subject the cropped extent is taken from
        ``data_file.root.brain_width`` (end edge - start edge + 1). Patch
        corners come from `compute_patch_indices_for_prediction` when
        `pred_specific` is set, otherwise from `compute_patch_indices`, with
        a random negative start offset when `patch_start_offset` is given.
        '''
        pairs = []
        for subject_idx in index_list:
            bounds = data_file.root.brain_width[subject_idx]
            extent = bounds[1] - bounds[0] + 1
            if pred_specific:
                corners = compute_patch_indices_for_prediction(extent, patch_shape)
            elif patch_start_offset is None:
                corners = compute_patch_indices(extent, patch_shape, overlap=patch_overlap)
            else:
                shifted_start = np.negative(get_random_nd_index(patch_start_offset))
                corners = compute_patch_indices(extent, patch_shape, overlap=patch_overlap, start=shifted_start)
            pairs.extend((subject_idx, corner) for corner in corners)
        return pairs


    def add_data(x_list, y_list, data_file, index, augment=False, augment_flip=False, augment_distortion_factor=0.25,
                 patch_shape=False, skip_blank=True, permute=False):
        '''
        Load one sample and append it to `x_list` / `y_list` if it qualifies.

        Parameters
        ----------
        x_list, y_list : list
            Output lists, extended in place.
        data_file, index, patch_shape
            Forwarded to `get_data_from_file`.
        augment, augment_flip, augment_distortion_factor
            When `augment` is set, the sample goes through `augment_data`.
        skip_blank : bool
            When True, samples whose truth is all zero are dropped. When
            False they are kept (the former unconditional early return
            dropped them regardless of this flag).
        permute : bool
            Apply `random_permutation_x_y`; requires cubic data.

        Raises
        ------
        ValueError
            If `permute` is set and the last three axes differ in length.
        '''
        data, truth = Generator.get_data_from_file(data_file, index, patch_shape=patch_shape)

        # Drop blank samples before paying for augmentation, but only when
        # the caller asked for blanks to be skipped.
        if skip_blank and not np.any(truth != 0):
            return
        if augment:
            # NOTE(review): create_h5 saves the affine to 'data/affine.npy',
            # while this loads 'affine.npy' from the working directory --
            # confirm which path is intended.
            affine = np.load('affine.npy')
            data, truth = augment_data(data, truth, affine, flip=augment_flip, scale_deviation=augment_distortion_factor)

        if permute:
            if data.shape[-3] != data.shape[-2] or data.shape[-2] != data.shape[-1]:
                raise ValueError("To utilize permutations, data array must be in 3D cube shape with all dimensions having "
                                 "the same length.")
            data, truth = random_permutation_x_y(data, truth[np.newaxis])
        else:
            # Add the leading channel axis the permute branch adds itself.
            truth = truth[np.newaxis]

        # Re-check after augmentation, which may have changed the truth.
        if not skip_blank or np.any(truth != 0):
            x_list.append(data)
            y_list.append(truth)


    def get_data_from_file(data_file, index, patch_shape=None):
        '''
        Read one sample, cropped to its brain bounding box.

        Parameters
        ----------
        data_file
            Object exposing ``root.brain_width``, ``root.t1``, ``root.t1ce``,
            ``root.flair``, ``root.t2`` and ``root.truth``.
        index
            Subject index, or a ``(subject_index, patch_index)`` pair when
            `patch_shape` is given.
        patch_shape
            When given, the cropped volume is further cut with
            `get_patch_from_3d_data`.

        Returns
        -------
        (x, y)
            `x` stacks the four modalities (t1, t1ce, flair, t2) on a new
            leading axis; `y` is the matching truth volume.

        The patch case no longer re-enters this function by bare name, which
        raised NameError for a function defined inside a class body.
        '''
        patch_index = None
        if patch_shape:
            index, patch_index = index

        brain_width = data_file.root.brain_width[index]
        # Both edges are inclusive, hence the +1 on the stop of each slice.
        crop = (index, 0,
                slice(brain_width[0,0], brain_width[1,0]+1),
                slice(brain_width[0,1], brain_width[1,1]+1),
                slice(brain_width[0,2], brain_width[1,2]+1))
        x = np.array([modality_img[crop]
                      for modality_img in [data_file.root.t1,
                                           data_file.root.t1ce,
                                           data_file.root.flair,
                                           data_file.root.t2]])
        y = data_file.root.truth[crop]

        if patch_shape:
            x = get_patch_from_3d_data(x, patch_shape, patch_index)
            y = get_patch_from_3d_data(y, patch_shape, patch_index)
        return x, y


    def convert_data(x_list, y_list, n_labels=1, labels=None, num_model=1,overlap_label=False):
        '''
        Turn lists of samples into batch arrays.

        Parameters
        ----------
        x_list, y_list : list of arrays
            Samples and their truth volumes.
        n_labels : int
            1 binarizes the truth (every value > 0 becomes 1); a larger
            value expands the truth into one channel per label.
        labels
            Forwarded to the label-expansion helper.
        num_model : int
            When not 1, `x` is returned as a list holding `num_model`
            references to the same array.
        overlap_label : bool
            Use `get_multi_class_labels_overlap` instead of
            `get_multi_class_labels`.

        Returns
        -------
        (x, y)
        '''
        x = np.asarray(x_list)
        y = np.asarray(y_list)
        if n_labels == 1:
            y[y > 0] = 1
        elif n_labels > 1:
            if overlap_label:
                # Qualified: functions in a class body cannot see each other
                # by bare name.
                y = Generator.get_multi_class_labels_overlap(y, n_labels=n_labels, labels=labels)
            else:
                # NOTE(review): get_multi_class_labels is not defined in the
                # visible part of this notebook -- confirm where it comes
                # from before using this branch.
                y = get_multi_class_labels(y, n_labels=n_labels, labels=labels)
        if num_model == 1:
            return x, y
        else:
            return [x]*num_model, y


    def get_multi_class_labels_overlap(data, n_labels=3, labels=(1,2,4)):
        """
        Expand a label map into overlapping binary region channels.

        Channels of the result (label values are hard-coded; `labels` is
        accepted for signature compatibility but not used):
            0 -- TC: labels 1 + 4
            1 -- WT: labels 1 + 2 + 4
            2 -- ET: label 4

        Parameters
        ----------
        data : array, shape (batch, 1, ...)
            Label map; only channel 0 is read.
        n_labels : int
            Number of output channels; must be at least 3.

        Returns
        -------
        int8 array of shape (batch, n_labels, ...).
        """
        new_shape = [data.shape[0], n_labels] + list(data.shape[2:])
        y = np.zeros(new_shape, np.int8)

        seg = data[:,0]
        y[:,0][np.isin(seg, (1, 4))] = 1
        # np.logical_or takes only two inputs -- a third positional argument
        # is the `out` buffer -- so label 4 was silently left out of WT.
        y[:,1][np.isin(seg, (1, 2, 4))] = 1
        y[:,2][seg == 4] = 1
        return y

In [None]:
from dev_tools.my_tools import print2d
import matplotlib.pyplot as plt

# Visual check on one T1 volume: show the raw NIfTI image and its intensity
# histogram, then the dataset stored for the same file in data/Training.h5
# (written by create_h5, which normalizes non-seg modalities).
# NOTE(review): get_data() is deprecated in nibabel; see the nibabel
# documentation for the replacement.
img = nib.load('data/MICCAI_BraTS_2019_Data_Training/HGG\
/BraTS19_TMC_11964_1/BraTS19_TMC_11964_1_t1.nii.gz').get_data()
plt.figure()
plt.hist(np.ravel(img))
print2d(img)
# The dataset is only readable while the file is open, so everything that
# touches `norm_img` stays inside the with-block.
with h5py.File('data/Training.h5','r') as f:
    norm_img = f['BraTS19_TMC_11964_1']['BraTS19_TMC_11964_1_t1.nii.gz']
    print2d(norm_img)
    plt.figure()
    plt.hist(np.ravel(norm_img))

In [73]:
# Mean intensity over the non-zero voxels of `img` (the raw volume loaded in
# an earlier cell).
b = img[np.nonzero(img)]
np.sum(b)/len(b)

534.3123637025672

In [52]:
# Scratch check: pickle a small dict to 'test.pkl' in the working directory.
a = {'a':1,'b':2}
with open('test.pkl','wb') as f:
    pickle.dump(a,f)

In [60]:
# Inspect the per-modality statistics stored in data/mean_std.pkl.
with open('data/mean_std.pkl','rb') as f:
    res = pickle.load(f)
    print(res)

{'t1_std': 1082.5379, 't1_mean': 571.9798, 't1ce_std': 1093.0112, 't2_mean': 652.5108, 'flair_mean': 411.4047, 't2_std': 1285.4105, 'flair_std': 1219.138, 't1ce_mean': 637.505}


In [64]:
# Scratch check: mean of five values pooled into one list.
np.mean([1,2,3,50,34])

18.0

In [66]:
# Scratch check: the mean of two group means, with groups of different sizes
# (2 and 3 values). Each group gets equal weight, so this is not the mean of
# the five values taken together.
np.mean([np.mean([1,34]),np.mean([2,3,50])])

17.916666666666664