# Training data generation

Convert a list of raw images and their segmented counterparts into training patches for a U-Net-type network.
Because of the U-Net architecture, the input patches (348x348) are larger than the output patches (164x164).
Ground truth data is saved as a (164x164x5) array.

### Imports & magics

In [None]:
%matplotlib notebook
%load_ext autoreload

In [None]:
%autoreload 2
import os
import numpy as np
import time
from scripts_training_data.data_preparation import *
from pims import ND2_Reader as nd2
from multiprocessing import Pool
from skimage import io


### Paths

In [None]:
dataset = '20190326'

# Every location below is derived from the dataset's root folder.
root_dir = 'D:/Bachelor_Project/code/project/datasets/' + dataset + '/'
out_dir = root_dir + '/patches/'                   # generated training patches

label_in_path = root_dir + '/segmented_bordered/'  # input segmentations

label_dir = root_dir + '/label/'                   # one-hot ground truth
raw_dir = root_dir + '/raw/'                       # extracted raw images

### extract data from stack

In [None]:
from scripts_training_data.data_preparation import *

# Extract the images contained in the stack into `raw_dir`.
# NOTE(review): `stack_path` is never assigned in this notebook, so this cell
# used to fail with a bare NameError. Set it (e.g. in the Paths cell) to the
# stack file belonging to `dataset` before running this cell.
stack_path = globals().get('stack_path')
if stack_path is None:
    raise NameError(
        '`stack_path` is not defined: set it to the image stack of dataset '
        '{} before running this cell'.format(dataset)
    )

stack_to_image_list(stack_path, raw_dir)

In [None]:
from scripts_training_data.data_preparation import *

# Turn the segmented label images in `label_in_path` into one-hot ground
# truth written to `label_dir`; the class count is left unspecified (None).
classification_to_one_hot_ground_truth(
    label_in_path,
    label_dir,
    number_of_classes=None,
)

### Parameters for patch augmentation
Specify the parameters of the random distributions from which the augmentation values are drawn.

In [None]:
# Settings shared by the patch-generation, statistics and weight-map cells.
# Distribution entries are [value, value, distribution name]; how the two
# numbers are interpreted (bounds vs. mean/std) is up to the consumer.
patch_augmentation_parameters = dict(
    bin_image=False,
    scaling=[0.9, 1.1, 'uniform'],
    transposing=[0, 2, 'randint'],
    rotating=[0, 4, 'randint'],
    contrast_shifting=[0.2, 2.0, 'uniform'],
    noise_mean=[0, 0.1, 'normal'],
    noise_std=[0, 0.05, 'normal'],
    # U-Net input tiles are larger than the output tiles they predict.
    input_patch_size=[348, 348],
    output_patch_size=[164, 164],
    augmentations_per_image=10,
    patches_per_augmentation=10,
    label_dir=label_dir,
    raw_dir=raw_dir,
    img_type=np.float32,
    # NOTE(review): an earlier comment claimed this had to be None to get
    # things working, yet the value is 2 — confirm which is correct.
    label_class=2,
    min_pixels=200,
    out_path_raw=out_dir + '/raw/',
    out_path_label=out_dir + '/label/',
    out_path_wmap=out_dir + '/wmap/',
)

### make patches

In [None]:
from scripts_training_data.extract_patches import *

if __name__ == '__main__':
    start = time.time()

    # Sorted so runs process the frames in a reproducible order.
    img_list = sorted(a for a in os.listdir(label_dir) if a.endswith('.tif'))

    # One job per labelled frame: a shallow copy of the shared settings plus
    # the frame's file name, so the template dict itself is never mutated.
    paramlist = [dict(patch_augmentation_parameters, frame=im) for im in img_list]

    # The context manager shuts the workers down once map() has returned,
    # instead of leaving seven idle processes behind after the cell finishes.
    with Pool(processes=7) as p:
        p.map(gt_generation, paramlist)

    print('patch generation took {:.1f} s'.format(time.time() - start))

### extract patch classes
For weight-map construction we want to correct for class imbalance, so we first count how often each class occurs in the label patches.

In [None]:
from scripts_training_data.patch_statistics import *

# Class counts over all label patches written by the patch-generation cell.
classcounts = count_classes(patch_augmentation_parameters['out_path_label'])

# Normalise to relative frequencies. Done out of place in float64: an in-place
# `/=` raises a casting TypeError when count_classes returns an integer array
# (and does not work at all on a plain list).
classcounts = np.asarray(classcounts, dtype=np.float64)
total = classcounts.sum()
if total == 0:
    raise ValueError(
        'class counts sum to zero for {}'.format(
            patch_augmentation_parameters['out_path_label']
        )
    )
classcounts = classcounts / total

with open('{}/classcounts.txt'.format(out_dir), 'w') as f:
    f.write('classcounts = {} \n'.format(classcounts))

### calculate training_dataset_statistics

To allow normalisation of the form $\frac{\text{img}-\text{mean}}{\sqrt{\text{variance}}}$ (i.e. dividing by the standard deviation), we need to compute these statistics over the entire dataset:

In [None]:
from scripts_training_data.patch_statistics import *

raw_patch_dir = patch_augmentation_parameters['out_path_raw']
mean, sampleVariance = compute_training_set_statistics(raw_patch_dir)

# Store the statistics next to the patches so they can be used for input
# normalisation later.
with open('{}/patch_mean_var.txt'.format(out_dir), 'w') as stats_file:
    stats_file.writelines([
        'mean = {} \n'.format(mean),
        'variance = {} \n'.format(sampleVariance),
    ])

### Making weightmaps

In [None]:
from scripts_training_data.make_weightmaps import *

# NOTE(review): this replaces the class frequencies computed in the
# "extract patch classes" cell with a uniform distribution, so presumably no
# class-frequency correction is applied. Confirm this is intentional; remove
# the assignment to use the measured frequencies instead.
classcounts = np.array([0.2, 0.2, 0.2, 0.2, 0.2])

make_dirs(patch_augmentation_parameters['out_path_wmap'])

if __name__ == '__main__':
    label_patch_dir = patch_augmentation_parameters['out_path_label']
    wmap_dir = patch_augmentation_parameters['out_path_wmap']

    # One job per label patch: (label path, weight-map output path with the
    # same file name, class frequencies, 10). The last two items are passed
    # through unchanged to make_weightmap, which defines their meaning.
    tuplist = [
        (
            '{}/{}'.format(label_patch_dir, a),
            '{}/{}'.format(wmap_dir, a),
            np.array(classcounts),
            10,
        )
        for a in os.listdir(label_patch_dir) if a.endswith('.tif')
    ]

    # The context manager shuts the workers down once map() has returned,
    # instead of leaving seven idle processes behind after the cell finishes.
    with Pool(processes=7) as p:
        p.map(make_weightmap, tuplist)
    
    


In [None]:
%matplotlib notebook
import matplotlib.pyplot as plt

# `wmap` was never assigned in this notebook, so this cell raised NameError.
# Load the first weight map written by the weight-map cell instead.
wmap_dir = patch_augmentation_parameters['out_path_wmap']
wmap_files = sorted(a for a in os.listdir(wmap_dir) if a.endswith('.tif'))
if not wmap_files:
    raise FileNotFoundError('no .tif weight maps found in {}'.format(wmap_dir))
wmap = io.imread('{}/{}'.format(wmap_dir, wmap_files[0]))

# NOTE(review): imshow expects a 2-D (or RGB/RGBA) array; if the saved weight
# maps have more channels, select one first — TODO confirm the stored shape.
print(wmap.shape)
plt.imshow(wmap)
plt.show()

In [None]:
# Largest weight in the current weight map, shown as the cell output.
wmap_max = np.max(wmap)
wmap_max