In [25]:
import torch
import pandas as pd
import scipy.ndimage as ndi
import numpy as np
import nibabel as nib
import os
import pprint
import sys

# Shared pretty-printer for inspecting nested structures; levels nested deeper
# than 4 are abbreviated in the printed output.
pp = pprint.PrettyPrinter(depth=4)

## Set data paths

In [2]:
# Folders holding the raw (unprocessed) training and test data, relative to this notebook.
ROOT_RAW_TRAINING_DATA_PATH = "../../../unet-lits-2d-pipeline/OriginalData/Training_Data/"
ROOT_RAW_TEST_DATA_PATH = "../../../unet-lits-2d-pipeline/OriginalData/Test_Data/"

# The training count is the number of directory entries halved.
# NOTE(review): presumably every training sample is stored as a volume file plus a
# segmentation file, hence the "// 2" - confirm against the folder contents.
print(f"Found {len(os.listdir(ROOT_RAW_TRAINING_DATA_PATH)) // 2} training samples at {ROOT_RAW_TRAINING_DATA_PATH}")
# Fixed label: this line counts the test folder, it previously also said "training samples".
print(f"Found {len(os.listdir(ROOT_RAW_TEST_DATA_PATH))} test samples at {ROOT_RAW_TEST_DATA_PATH}")

Found 131 training samples at ../../../unet-lits-2d-pipeline/OriginalData/Training_Data/
Found 70 training samples at ../../../unet-lits-2d-pipeline/OriginalData/Test_Data/


In [3]:
# Root folders of the preprocessed per-slice data.
# Generate_Required_Datasets reads its 'Assign_2D_*.csv' index tables from the
# training folder; the test folder is not read in this part of the notebook.
ROOT_PREPROCESSED_TRAINING_DATA_PATH = "../../../unet-lits-2d-pipeline/LOADDATA/Training_Data_2D/"
ROOT_PREPROCESSED_TEST_DATA_PATH = "../../../unet-lits-2d-pipeline/LOADDATA/Test_Data_2D/"

## Compute Basic Dataset Parameter Values

In [4]:
# Minimum Hounsfield voxel value for tissue. Any value smaller than this is conventionally thought to correspond to water.
# NOTE(review): on the Hounsfield scale water is defined as 0 HU, so values below -100
# more likely correspond to fat/air - confirm the wording above.
# normalize() hard-codes the same value locally instead of reading this constant.
MIN_BOUND = -100

# Maximum Hounsfield voxel value for tissue. Any value greater than this is conventionally thought to correspond to bone structure.
# normalize() hard-codes the same value locally instead of reading this constant.
MAX_BOUND = 400

In [5]:
# Compute LiTS dataset mean and standard deviation
# test_volumes = os.listdir(ROOT_RAW_TEST_DATA_PATH)
# training_volumes = os.listdir(ROOT_RAW_TRAINING_DATA_PATH)
# training_volumes = [elem for elem in training_volumes if "volume" in elem]
# dataset_volumes = training_volumes + test_volumes

# volume_means = []
# volume_stds = []

# for volume_index, volume_name in enumerate(dataset_volumes):
#     root_path = ""
#     if volume_name in test_volumes:
#         root_path = ROOT_RAW_TEST_DATA_PATH
#     elif volume_name in training_volumes:
#         root_path = ROOT_RAW_TRAINING_DATA_PATH

#     volume = nib.load(root_path + volume_name)

#     volume_data = volume.get_fdata()
#     min_voxel_value, max_voxel_value = volume_data.min(), volume_data.max()
#     normalized_volume_data = (volume_data - min_voxel_value) / (max_voxel_value - min_voxel_value)
    
#     normalized_volume_mean = np.mean(normalized_volume_data)
#     volume_means.append(normalized_volume_mean)
    
#     normalized_volume_std = np.sqrt(np.sum((normalized_volume_data - normalized_volume_mean) ** 2) / normalized_volume_data.size)
#     volume_stds.append(normalized_volume_std)

# DATASET_MEAN = np.mean(np.array(volume_means))
# DATASET_STD = np.mean(np.array(volume_stds))

# print(DATASET_MEAN)
# print(DATASET_STD)

In [6]:
# Statistics computed by hand for this dataset.
# NOTE(review): presumably the output of the (now commented-out) mean/std cell above - confirm.
# They are kept for reference only; nothing in this part of the notebook reads them.
MANUALLY_COMPUTED_DATASET_MEAN = 0.1572
MANUALLY_COMPUTED_DATASET_STD = 0.14909

# Expected values on the LiTS dataset
# These two are the values normalize() subtracts (zero_center) and divides by (unit_variance).
DATASET_MEAN = 0.1021
DATASET_STD = 0.19177

## Reference Dataset Class

In [7]:
"""======================================"""
"""========== Basic Utilities ==========="""
"""======================================"""
def set_bounds(image,MIN_BOUND,MAX_BOUND):
    """
    Limit every value of `image` to the closed interval [MIN_BOUND, MAX_BOUND].

    Arguments:
    image:      Array-like of intensities.
    MIN_BOUND:  Values below this are replaced by MIN_BOUND.
    MAX_BOUND:  Values above this are replaced by MAX_BOUND.

    Returns the clipped array; the input is not modified.
    """
    bounded_image = np.clip(image, MIN_BOUND, MAX_BOUND)
    return bounded_image

def normalize(image,use_bd=True,zero_center=True,unit_variance=True,supply_mode="orig"):
    """
    Rescale intensities to [0, 1] and optionally standardize them.

    Arguments:
    image:          Array of voxel intensities.
    use_bd:         If True, clip to the fixed window [-100, 400] and rescale that window
                    to [0, 1]. If False, rescale using the image's own minimum and maximum.
    zero_center:    If True, subtract the module-level DATASET_MEAN after rescaling.
    unit_variance:  If True, divide by the module-level DATASET_STD after rescaling.
    supply_mode:    Accepted for backward compatibility; it is not used by this function.

    Returns a new array; the input is not modified.

    A constant image with use_bd=False has no intensity range to rescale. It is mapped
    to all zeros instead of producing NaNs from a 0/0 division.
    """
    if use_bd:
        lower, upper = -100.0, 400.0
        image = np.clip(image, lower, upper)
    else:
        lower, upper = np.min(image), np.max(image)

    value_range = upper - lower
    if value_range == 0:
        # Only reachable with use_bd=False: min == max, so min-max scaling is undefined.
        image = np.zeros_like(image, dtype=float)
    else:
        image = (image - lower) / value_range
    # Guard against values leaving [0, 1] through floating point rounding.
    image = np.clip(image,0.,1.)

    # TODO: Figure out how the mean and std values are computed
    if zero_center:
        image = image - DATASET_MEAN
    if unit_variance:
        image = image/DATASET_STD
    return image

"""======================================"""
"""============ Augmentation ============"""
"""======================================"""
##############################################################################################
def rotate_2D(to_aug, rng=np.random.RandomState(1)):
    """
    Rotate every channel of every array in `to_aug` by one shared random angle.

    Arguments:
    to_aug: List of arrays indexed as (C, ...). The arrays are rotated in place,
            channel by channel, using nearest-neighbour sampling (order=0) and
            without changing their shape.
    rng:    Random number generator used to draw the angle. The default generator is
            created once when the function is defined, so calls relying on the default
            share (and advance) the same generator state.

    Returns the tuple (to_aug, angle), where `angle` is the rotation in degrees drawn
    from [-10, 10).
    """
    angle = (rng.rand()*2-1)*10
    for volume in to_aug:
        n_channels = volume.shape[0]
        for channel in range(n_channels):
            plane = volume[channel, :].astype(np.float32)
            # Same angle for all channels and all arrays keeps images and masks aligned.
            volume[channel, :] = ndi.rotate(plane, angle, reshape=False, order=0, mode="nearest")
    return to_aug, angle


##############################################################################################
def zoom_2D(to_aug, rng=np.random.RandomState(1)):
    """
    Zoom every channel of every array in `to_aug` by one shared random factor.

    Arguments:
    to_aug: List of arrays indexed as (C, ...). The arrays are modified in place and
            keep their shape.
    rng:    Random number generator used to draw the zoom factor. The default generator
            is created once when the function is defined, so calls relying on the
            default share (and advance) the same generator state.

    Behaviour per channel:
    * factor < 1: the channel is shrunk, centred on a canvas of the original shape, and the
      surrounding border is filled with the value of the channel's last (corner) pixel.
    * factor > 1: the channel is enlarged and the central window of the original shape is kept.
    * factor == 1: the channel is left untouched.

    Returns `to_aug`.
    """
    # TODO: Figure out how the magnification range limits are computed
    magnif = rng.uniform(0.825,1.175)
    for volume in to_aug:
        for channel in range(volume.shape[0]):
            plane   = volume[channel, :]
            factors = (magnif,) * plane.ndim

            if magnif < 1:
                shrunk_shape = [int(np.round(magnif * extent)) for extent in plane.shape]
                margins      = [(full - small) // 2 for full, small in zip(plane.shape, shrunk_shape)]
                inner_window = tuple(slice(start, start + length) for start, length in zip(margins, shrunk_shape))
                corner       = (-1,) * plane.ndim
                # Canvas pre-filled with the corner value acts as the padding around the shrunk image.
                canvas = np.ones_like(plane) * plane[corner]
                canvas[inner_window] = ndi.zoom(plane.astype(np.float32), factors, order=0, mode="nearest")
                volume[channel, :] = canvas
            elif magnif > 1:
                enlarged      = ndi.zoom(plane.astype(np.float32), factors, order=0, mode="nearest")
                offsets       = [(big - full) // 2 for big, full in zip(enlarged.shape, plane.shape)]
                centre_window = tuple(slice(start, start + length) for start, length in zip(offsets, plane.shape))
                volume[channel, :] = enlarged[centre_window]

    return to_aug


##############################################################################################
def hflip_2D(to_aug, rng=np.random.RandomState(1)):
    """
    Mirror every channel of every array in `to_aug` left-to-right.

    Arguments:
    to_aug: List of arrays indexed as (C, ...). Each channel is flipped in place
            with np.fliplr, i.e. along its second axis.
    rng:    Accepted so all augmentation functions share one call signature;
            it is not used here because the flip is deterministic.

    Returns `to_aug`.
    """
    for volume in to_aug:
        n_channels = volume.shape[0]
        for channel in range(n_channels):
            mirrored = np.fliplr(volume[channel, :])
            volume[channel, :] = mirrored

    return to_aug


##############################################################################################
def vflip_2D(to_aug, rng=np.random.RandomState(1)):
    """
    Mirror every channel of every array in `to_aug` top-to-bottom.

    Arguments:
    to_aug: List of arrays indexed as (C, ...). Each channel is flipped in place
            with np.flipud, i.e. along its first axis.
    rng:    Accepted so all augmentation functions share one call signature;
            it is not used here because the flip is deterministic.

    Returns `to_aug`.
    """
    for volume in to_aug:
        n_channels = volume.shape[0]
        for channel in range(n_channels):
            mirrored = np.flipud(volume[channel, :])
            volume[channel, :] = mirrored

    return to_aug

def augment_2D(to_aug, mode_dict=["rot","zoom"], copy_files=False, return_files=False, seed=1, is_mask=[0,1,0]):
    """
    Randomly apply the enabled 2D augmentations to a list of arrays.

    For each of the four augmentations - rotation, zoom, horizontal flip, vertical flip,
    in that order - a fair coin is drawn from a generator seeded with `seed`; the
    augmentation runs only if the coin is 1 and its name is listed in `mode_dict`.
    The coin is drawn for every augmentation, enabled or not.

    Arguments:
    to_aug:       List of arrays indexed as (C, ...). All arrays receive the same
                  transformation and are modified in place.
    mode_dict:    Names of the enabled augmentations: "rot", "zoom", "hflip", "vflip".
    seed:         Seed of the random number generator used for all draws.
    copy_files, return_files, is_mask:
                  Accepted for interface compatibility; they are not used, so the inputs
                  are never copied and masks are treated exactly like images.

    Returns the (augmented) list.
    """
    rng = np.random.RandomState(seed)

    # The helpers are looked up lazily so only the selected ones are ever called.
    steps = (
        ("rot",   lambda files: rotate_2D(files, rng)[0]),
        ("zoom",  lambda files: zoom_2D(files, rng)),
        ("hflip", lambda files: hflip_2D(files, rng)),
        ("vflip", lambda files: vflip_2D(files, rng)),
    )

    for name, transform in steps:
        coin = rng.randint(2)
        if coin and name in mode_dict:
            to_aug = transform(to_aug)

    return to_aug


In [236]:
"""================================================="""
"""============ Cropping for DataLoader ============"""
"""================================================="""
def get_crops_per_batch(batch_to_crop, idx_batch=None, crop_size=[128,128], n_crops=1, seed=1):
    """
    Cut the same spatial window out of every array in `batch_to_crop`.

    Arguments:
    batch_to_crop:  List of same-shaped arrays indexed as (C, W, H[, D]). The first
                    len(crop_size) axes after the channel axis are cropped.
    idx_batch:      Optional array indexed as (1, W, H[, D]). If it contains entries equal
                    to 1, one of them is picked at random as the crop centre (shifted
                    inwards where needed so the window stays inside the array). If it is
                    None or has no such entry, the centre is drawn uniformly at random.
    crop_size:      Edge length of the window per cropped axis.
    n_crops:        Accepted for interface compatibility; exactly one crop is produced.
    seed:           Seed of the random number generator used for the centre selection.

    Returns a tuple with one cropped view per input array, in input order.

    Raises ValueError if `crop_size` names more axes than the arrays have, or if the
    window is larger than the array along some axis.
    """
    rng = np.random.RandomState(seed)

    spatial_shape = tuple(batch_to_crop[0].shape[1:])
    n_axes = len(crop_size)
    if n_axes > len(spatial_shape):
        raise ValueError("crop_size has more entries than the arrays have spatial axes!")
    for size, extent in zip(crop_size, spatial_shape):
        if size > extent:
            raise ValueError("Crop size chosen to be bigger than volume!")

    half  = [size // 2 for size in crop_size]
    # Even sizes have no exact centre pixel: the window start is moved up by one so that
    # [centre - half + shift, centre + half + 1) spans exactly `size` pixels.
    shift = [1 - size % 2 for size in crop_size]
    # Inclusive range of centres whose window lies completely inside the array.
    lowest  = [h - s for h, s in zip(half, shift)]
    highest = [extent - h - 1 for extent, h in zip(spatial_shape, half)]

    cluster = np.where(idx_batch[0] == 1) if idx_batch is not None else None

    if cluster is not None and len(cluster[0]) > 0:
        # Pick ONE marked pixel and use all of its coordinates; drawing each axis
        # independently could combine coordinates of different pixels and land
        # outside the cluster.
        pick = rng.randint(len(cluster[0]))
        centres = [int(np.clip(cluster[axis][pick], lowest[axis], highest[axis])) for axis in range(n_axes)]
    else:
        # The upper limit is exclusive; max() keeps the range non-empty when the window
        # is as large as the array.
        centres = [int(rng.randint(lowest[axis], max(lowest[axis] + 1, highest[axis]))) for axis in range(n_axes)]

    window = (slice(None),) + tuple(
        slice(centre - h + s, centre + h + 1) for centre, h, s in zip(centres, half, shift)
    )

    return tuple(volume[window] for volume in batch_to_crop)

In [237]:
"""========================================================================="""
### CONFIG PARAMETERS

def _make_training_config(data, crop_size, initialization):
    """
    Build a fresh training configuration dict.

    The liver and lesion setups share every value except `data`, `crop_size` and
    `initialization`, so only those three are parameters. A new dict (with new list
    objects) is returned on every call, so the two configurations never share state.
    """
    return {
        "mode": "2D",
        # Selects which mask/weightmap tables Generate_Required_Datasets loads.
        "data": data,
        "n_epochs": 50,
        "lr": 3e-05,
        "l2_reg": 1e-05,
        "gpu": 0,
        "num_workers": 8,
        "batch_size": 2,
        "step_size": [25, 42],
        "gamma": 0.1,
        # An empty list disables cropping in Basic_Image_Dataset_2D.__getitem__.
        "crop_size": crop_size,
        # Fraction of the available volumes to use, and the share of those used for training.
        "perc_data": 1,
        "train_val_split": 0.9,
        "seed": 1,
        "loss_func": "multiclass_pwce",
        "class_weights": [1, 1],
        "num_classes": 2,
        "augment": ['rot', 'zoom', 'hflip', 'vflip'],
        "verbose_idx": 200,
        "initialization": initialization,
        # Every pos_sample_chance-th training index is served from the 'Pos' sample pool.
        "pos_sample_chance": 2,
        "no_standardize": True,
        "epsilon": 1e-06,
        "wmap_weight": 3,
        "weight_score": [1, 1],
        "focal_gamma": 1.5,
        "Training_ROI_Vicinity": 4,
        # NOTE(review): both configurations use "liver_small" - confirm that the lesion
        # run is not meant to use its own save name.
        "savename": "liver_small",
        "use_weightmaps": True,
        "require_one_hot": False,
        "num_out_classes": 2,
    }


liver_training_config = _make_training_config(
    data="liver",
    crop_size=[256, 256],
    initialization="",
)

lesion_training_config = _make_training_config(
    data="lesion",
    crop_size=[],
    initialization="placeholder/SAVEDATA/Standard_Liver_Networks/vUnet2D_liver_small",
)

"""========================================================================="""
### META FUNCTION TO RETURN ADJUSTED DATASETS, e.g. train-val-split --- 2D
def Generate_Required_Datasets(config):
    """
    Load the slice index tables and build the training and validation datasets.

    Arguments:
    config: Configuration dict. Keys read here: 'seed', 'data', 'use_weightmaps',
            'perc_data' and 'train_val_split'; the dict is also handed to the datasets.

    The volumes listed in 'Assign_2D_Volumes.csv' are ordered by the number after the
    last '-' in their name, shuffled with a generator seeded by config['seed'], limited
    to the first config['perc_data'] fraction and then split into training/validation
    at config['train_val_split'].

    Returns the tuple (training_dataset, validation_dataset).
    """
    rng = np.random.RandomState(config['seed'])
    is_lesion_task = config['data'] == 'lesion'

    def read_table(file_suffix):
        # All index tables live in the preprocessed training folder and have a header row.
        return pd.read_csv(ROOT_PREPROCESSED_TRAINING_DATA_PATH + file_suffix, header=0)

    vol_info = {}
    vol_info['volume_slice_info'] = read_table('/Assign_2D_Volumes.csv')
    if is_lesion_task:
        vol_info['target_mask_info'] = read_table('/Assign_2D_LesionMasks.csv')
        # For lesions the liver masks additionally serve as reference masks.
        vol_info['ref_mask_info'] = read_table('/Assign_2D_LiverMasks.csv')
    else:
        vol_info['target_mask_info'] = read_table('/Assign_2D_LiverMasks.csv')
    if config['use_weightmaps']:
        if is_lesion_task:
            vol_info['weight_mask_info'] = read_table('/Assign_2D_LesionWmaps.csv')
        else:
            vol_info['weight_mask_info'] = read_table('/Assign_2D_LiverWmaps.csv')

    # Deterministic order first, so the seeded shuffle is reproducible.
    available_volumes = list(set(np.array(vol_info['volume_slice_info']['Volume'])))
    available_volumes.sort(key=lambda name: int(name.split('-')[-1]))
    rng.shuffle(available_volumes)

    n_used  = int(len(available_volumes) * config['perc_data'])
    n_train = int(n_used * config['train_val_split'])
    used_volumes       = available_volumes[:n_used]
    training_volumes   = used_volumes[:n_train]
    validation_volumes = used_volumes[n_train:]

    training_dataset   = Basic_Image_Dataset_2D(vol_info, training_volumes, config)
    validation_dataset = Basic_Image_Dataset_2D(vol_info, validation_volumes, config, is_validation=True)
    return training_dataset, validation_dataset


"""========================================================================="""
### BASE DATASET CLASS IN 2D
class Basic_Image_Dataset_2D(torch.utils.data.Dataset):
    """
    Slice-wise (2D) dataset over a set of volumes.

    Arguments:
    vol_info:      Dict of pandas tables whose rows line up one-to-one:
                   'volume_slice_info' with columns 'Volume' and 'Slice Path';
                   'target_mask_info' with columns 'Has Mask' and 'Slice Path';
                   'ref_mask_info' (only read when config['data']=='lesion') with columns
                   'Has Mask' and 'Slice Path';
                   'weight_mask_info' (only read when config['use_weightmaps']) with
                   column 'Slice Path'.
    volumes:       Names of the volumes this dataset may draw slices from.
    config:        Configuration dict. Keys read here: 'seed', 'Training_ROI_Vicinity',
                   'augment', 'use_weightmaps', 'data', 'pos_sample_chance',
                   'no_standardize' and 'crop_size'.
    is_validation: If True, augmentation and cropping are switched off and every sample
                   is filed under 'Pos', so samples are served in stored order.
    """
    def __init__(self, vol_info, volumes, config, is_validation=False):
        self.config = config

        self.vol_info = vol_info

        self.is_validation = is_validation

        self.rng = np.random.RandomState(self.config["seed"])

        self.available_volumes = volumes

        # Number of neighbouring slices (in each direction) inspected for a reference mask.
        self.rvic = self.config["Training_ROI_Vicinity"]

        # Number of adjacent slices stacked as input channels.
        self.channel_size = 1

        self.data_augmentation = True
        if len(self.config["augment"]) == 0 or self.is_validation:
            self.data_augmentation = False

        # Samples are (volume name, index into volume_details[volume]) pairs.
        self.input_samples = {'Neg':[], 'Pos':[]}

        self.div_in_volumes = {
            key: {
                    'Input_Image_Paths':[],
                    'Has Target Mask':[],
                    'Wmap_Paths':[],
                    'TargetMask_Paths':[],
                    'Has Ref Mask':[],
                    'RefMask_Paths':[]
                } for key in self.available_volumes
            }

        # Record data paths for each slice in all training/validation volumes:
        # * scan slice
        # * object annotation flag (whether or not object class is visible in the scan slice/annotated scan slice)
        # * weightmap
        # * scan annotation
        # * liver annotation flag (whether or not the liver is visible in the scan slice/annotated scan slice)
        # * annotated liver scan slice
        for i,vol in enumerate(vol_info['volume_slice_info']['Volume']):
            if vol in self.div_in_volumes.keys():
                self.div_in_volumes[vol]['Input_Image_Paths'].append(vol_info['volume_slice_info']['Slice Path'][i])
                self.div_in_volumes[vol]['Has Target Mask'].append(vol_info['target_mask_info']['Has Mask'][i])
                if self.config['use_weightmaps']: self.div_in_volumes[vol]['Wmap_Paths'].append(vol_info['weight_mask_info']['Slice Path'][i])
                self.div_in_volumes[vol]['TargetMask_Paths'].append(vol_info['target_mask_info']['Slice Path'][i])
                if self.config['data']=='lesion': self.div_in_volumes[vol]['Has Ref Mask'].append(vol_info['ref_mask_info']['Has Mask'][i])
                if self.config['data']=='lesion': self.div_in_volumes[vol]['RefMask_Paths'].append(vol_info['ref_mask_info']['Slice Path'][i])

        self.volume_details = {
            key: {
                    'Input_Image_Paths':[],
                    'TargetMask_Paths':[],
                    'Wmap_Paths':[],
                    'RefMask_Paths':[]
                } for key in self.available_volumes
            }

        # Populate dictionary with all necessary data for training
        for vol in self.div_in_volumes.keys():
            n_slices = len(self.div_in_volumes[vol]['Input_Image_Paths'])
            for j in range(n_slices):
                # Keep a slice only if a reference mask exists within `rvic` slices of it;
                # for liver data every slice is kept.
                crop_condition = np.sum(self.div_in_volumes[vol]['Has Ref Mask'][int(np.clip(j-self.rvic, 0, None)):j+self.rvic])
                if self.config['data']=='liver': crop_condition=True

                if crop_condition:
                    # Collect the neighbouring slices that form the input channels. Where the
                    # neighbourhood reaches past the first/last slice, the missing slices are
                    # filled by mirroring the ones next to the boundary.
                    extra_ch = self.channel_size//2
                    low_bound, low_diff = np.clip(j-extra_ch,0,None).astype(int), extra_ch-j
                    up_bound, up_diff = np.clip(j+extra_ch+1,None,n_slices).astype(int), j+extra_ch+1-n_slices

                    vol_slices = self.div_in_volumes[vol]["Input_Image_Paths"][low_bound:up_bound]

                    if low_diff>0:
                        extra_slices = self.div_in_volumes[vol]["Input_Image_Paths"][low_bound+1:low_bound+1+low_diff][::-1]
                        vol_slices = extra_slices+vol_slices
                    if up_diff>0:
                        extra_slices = self.div_in_volumes[vol]["Input_Image_Paths"][up_bound-up_diff-1:up_bound-1][::-1]
                        vol_slices = vol_slices+extra_slices

                    self.volume_details[vol]['Input_Image_Paths'].append(vol_slices)
                    self.volume_details[vol]['TargetMask_Paths'].append(self.div_in_volumes[vol]['TargetMask_Paths'][j])

                    if self.config['data']!='liver':  self.volume_details[vol]['RefMask_Paths'].append(self.div_in_volumes[vol]['RefMask_Paths'][j])
                    if self.config['use_weightmaps']: self.volume_details[vol]['Wmap_Paths'].append(self.div_in_volumes[vol]['Wmap_Paths'][j])

                    # During validation every sample is filed under 'Pos'.
                    type_key = 'Pos' if self.div_in_volumes[vol]['Has Target Mask'][j] or self.is_validation else 'Neg'
                    self.input_samples[type_key].append((vol, len(self.volume_details[vol]['Input_Image_Paths'])-1))

        self.n_files  = np.sum([len(self.input_samples[key]) for key in self.input_samples.keys()])
        self.curr_vol = self.input_samples['Pos'][0][0] if len(self.input_samples['Pos']) else self.input_samples['Neg'][0][0]

    def __getitem__(self, idx):
        """
        Load, normalize, optionally augment and crop the sample selected by `idx`.

        Returns a dict with 'input_images', 'targets', 'internal_slice_name' and
        'vol_change' (True if the next sample of the same pool belongs to a different
        volume), plus 'weightmaps' when config['use_weightmaps'] is set and 'crop_option'
        when config['data']=='lesion'.
        """
        #Serve a positive example for every pos_sample_chance-th index if training.
        #During validation, 'Pos' will contain all validation samples.
        #If the chosen pool is empty (e.g. no positive target masks), fall back to the other one.
        type_choice = not idx % self.config['pos_sample_chance'] or self.is_validation
        modes       = list(self.input_samples.keys())
        type_key    = modes[type_choice] if len(self.input_samples[modes[type_choice]]) else modes[not type_choice]

        type_len = len(self.input_samples[type_key])

        #Keep the position in the sample pool separate from the slice index inside the volume:
        #the successor lookup must use the pool position, not the slice index.
        sample_pos     = idx % type_len
        vol, slice_idx = self.input_samples[type_key][sample_pos]
        next_vol, _    = self.input_samples[type_key][(sample_pos+1) % type_len]

        vol_change = next_vol!=vol
        self.curr_vol   = vol

        #Name of the centre slice of the input stack.
        slice_paths = self.volume_details[vol]["Input_Image_Paths"][slice_idx]
        intvol = slice_paths[len(slice_paths)//2]

        input_image  = np.concatenate([np.expand_dims(np.load(sub_vol),0) for sub_vol in slice_paths],axis=0)

        #Perform data standardization
        if self.config['no_standardize']:
            input_image  = normalize(input_image, zero_center=False, unit_variance=False, supply_mode="orig")
        else:
            input_image  = normalize(input_image)

        #Lesion/Liver Mask to output
        target_mask = np.load(self.volume_details[vol]["TargetMask_Paths"][slice_idx])
        target_mask = np.expand_dims(target_mask,0)

        #Liver Mask to use for defining training region of interest
        crop_mask = np.expand_dims(np.load(self.volume_details[vol]["RefMask_Paths"][slice_idx]),0) if self.config['data']=='lesion' else None
        #Weightmask to output
        weightmap = np.expand_dims(np.load(self.volume_details[vol]["Wmap_Paths"][slice_idx]),0).astype(float) if self.config['use_weightmaps'] else None

        #Generate list of all files that would need to be cropped, if cropping is required.
        files_to_crop  = [input_image, target_mask]
        is_mask        = [0,1]
        if weightmap is not None:
            files_to_crop.append(weightmap)
            is_mask.append(0)
        if crop_mask is not None:
            files_to_crop.append(crop_mask)
            is_mask.append(1)

        #First however, augmentation, if required, is performed (i.e. on fullsize images to remove border artefacts in crops).
        #The augmentation works in place, so input_image/target_mask/... are updated as well.
        if self.data_augmentation:
            files_to_crop = list(augment_2D(files_to_crop, mode_dict = self.config["augment"],
                                               seed=self.rng.randint(0,1e8), is_mask = is_mask))

        #If Cropping is required, we crop now.
        if len(self.config['crop_size']) and not self.is_validation:
            crops_for_picked_batch  = get_crops_per_batch(files_to_crop, crop_mask, crop_size=self.config['crop_size'], seed=self.rng.randint(0,1e8))
            input_image     = crops_for_picked_batch[0]
            target_mask     = crops_for_picked_batch[1]
            weightmap       = crops_for_picked_batch[2] if weightmap is not None else None
            crop_mask       = crops_for_picked_batch[-1] if crop_mask is not None else None

        #One-hot targets and auxiliary targets/weightmaps are not produced by this dataset;
        #the placeholders stay None and are therefore dropped from the returned dict.
        one_hot_target = None
        auxiliary_targets = None
        one_hot_auxiliary_targets = None
        auxiliary_wmaps = None

        #Final Output Dictionary
        return_dict = {"input_images":input_image.astype(float), "targets":target_mask.astype(float),
                       "crop_option":crop_mask.astype(float) if crop_mask is not None else None,
                       "weightmaps":weightmap.astype(float) if weightmap is not None else None,
                       "one_hot_targets":one_hot_target,
                       "aux_targets":auxiliary_targets, "one_hot_aux_targets": one_hot_auxiliary_targets,
                       "aux_weightmaps": auxiliary_wmaps, 'internal_slice_name':intvol, 'vol_change':vol_change}

        return_dict = {key:item for key,item in return_dict.items() if item is not None}
        return return_dict


    def __len__(self):
        """Return the total number of samples over both the 'Pos' and 'Neg' pools."""
        return int(self.n_files)

In [238]:
# Build the training and validation datasets for the liver configuration.
train_dataset, val_dataset = Generate_Required_Datasets(liver_training_config)

In [239]:
# Training loader: shuffled, batch size and worker count taken from the liver configuration.
# NOTE(review): each worker process gets its own copy of the dataset, including its
# numpy random generator - confirm that workers do not replay identical augmentation seeds.
train_data_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=liver_training_config['batch_size'],
    shuffle=True,
    num_workers=liver_training_config['num_workers'],
    pin_memory=False,
)

# Validation loader: one sample at a time, in stored order, loaded in the main process.
val_data_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=0,
)