# Train U-Net models
* Some issues with using keras for data augmentation, the standard functions may only support 3 channel images?

### Manual Data Augmentation
* First try using a simple augmentation strategy, only using 

In [1]:
import os
import numpy as np
import glob
from scipy.ndimage import rotate
from PIL import Image

import keras
from keras.models import Model
from keras import backend as K
from keras.engine.topology import Layer
from keras import metrics
from keras import layers
from keras.models import Sequential
from keras.layers import Dropout, Flatten, Reshape, Input, concatenate, Conv2DTranspose
from keras.layers.core import Activation, Dense, Lambda
from keras.constraints import maxnorm
from keras.optimizers import SGD, Adam
from keras.layers.convolutional import Conv2D
from keras.layers.convolutional import MaxPooling2D, AveragePooling2D
from keras.layers.normalization import BatchNormalization


############ DATA GENERATORS
def data_gen_aug_combined(file_loc, mask_loc, batch_size, square_rot_p=.3, seed=101):
    # square_rot_p is the prob of using a 90x rotation, otherwise sample from 360. Possibly not useful
    # translate is maximum number of pixels to translate by
    # crops are done 
    square_rot_p = int(square_rot_p)
    np.random.seed(seed)
    all_files=glob.glob(os.path.join(file_loc, '*'))
    all_masks=[]

    all_files = [loc for loc in all_files if loc.rsplit('.', 1)[-1] in ['tif']]

#     for file in all_files:
#         im_name = str(file.rsplit('.', 1)[-2].rsplit('/', 1)[1].rsplit('_', 1)[0].replace(" ", "_"))
#         loc = os.path.join(mask_loc, im_name+'.npy')
#         all_masks.append(loc)
        
    for file in all_files:
        im_name = str(file.rsplit('.', 1)[-2].rsplit('/', 1)[1])
        loc = os.path.join(mask_loc, im_name+'.tif')
        all_masks.append(loc)

    while 1:
        c = list(zip(all_files, all_masks))
        np.random.shuffle(c)
        all_files, all_masks = zip(*c)

        num_batches = int(np.floor(len(all_files)/batch_size))-1

        for batch in range(num_batches):
            x=[]
            y=[]
            batch_files = all_files[batch_size*batch:batch_size*(batch+1)]
            batch_files_mask = all_masks[batch_size*batch:batch_size*(batch+1)]

            for index in range(len(batch_files)):
                image_loc = batch_files[index]
                mask_loc = batch_files_mask[index]

                # load the image
                image = Image.open(image_loc)
                width, height = image.size
                image = np.reshape(np.array(image.getdata()), (height, width, 3))

                #load the mask
                mask = Image.open(mask_loc)
                width, height = mask.size
                mask = np.reshape(np.array(mask.getdata()), (height, width, 4))
                
                # All the randomness:
                height, width = np.shape(image)[0], np.shape(image)[1]
                crop_row = np.random.randint(0, height-320)
                crop_col = np.random.randint(0, width-368)
                flip_vert = np.random.randint(0, 2)
                flip_hor = np.random.randint(0, 2)

                # APPLY AUGMENTATION:
                # flips
                if flip_vert:
                    image = np.flipud(image)
                    mask = np.flipud(mask)

                if flip_hor:
                    image = np.fliplr(image)
                    mask = np.fliplr(mask)

                # rotation
                square_rot =  bool((np.random.uniform(0, 1, 1)<square_rot_p))
                if square_rot:  # maybe this is dumb, but it cant hurt
                    rotations=['0', '90', '180', '270']
                    angle = int(random.choice(rotations))
                    image = rotate(image, angle, reshape=False)
                    mask = rotate(mask, angle, reshape=False)

                else:
                    angle = np.random.uniform(0, 360, 1)
                    image = rotate(image, angle, reshape=False)
                    mask = rotate(mask, angle, reshape=False)
 
                # crop to 320 x 360 so it will fit into network, and for data augmentation
                image = image[crop_row:crop_row+320, crop_col:crop_col+368]
                mask = mask[crop_row:crop_row+320, crop_col:crop_col+368]

                image = image/255.0 # make pixels in [0,1] 
                x.append(image)
                y.append(mask)
            x=np.array(x)
            y=np.array(y)
            yield (x, y)


def data_gen_combined(file_loc, mask_loc, batch_size, seed=101):
    np.random.seed(seed)
    all_files=glob.glob(os.path.join(file_loc, '*'))
    all_masks=[]
    for file in all_files:
        im_name = str(file.rsplit('.', 1)[-2].rsplit('/', 1)[1])
        loc = os.path.join(mask_loc, im_name+'.tif')
        all_masks.append(loc)

    all_files = [loc for loc in all_files if loc.rsplit('.', 1)[-1] in ['tif']]

    while 1:
        c = list(zip(all_files, all_masks))
        np.random.shuffle(c)
        all_files, all_masks = zip(*c)
        
        num_batches = int(np.floor(len(all_files)/batch_size))-1
        for batch in range(num_batches):
            x=[]
            y=[]
            batch_files = all_files[batch_size*batch:batch_size*(batch+1)]
            batch_files_mask = all_masks[batch_size*batch:batch_size*(batch+1)]

            for index in range(len(batch_files)):
                image_loc = batch_files[index]
                mask_loc = batch_files_mask[index]

                # load the image
                image = Image.open(image_loc)
                width, height = image.size
                image = np.reshape(np.array(image.getdata()), (height, width, 3))

                #load the mask
                mask = Image.open(mask_loc)
                width, height = mask.size
                mask = np.reshape(np.array(mask.getdata()), (height, width, 4))
                
                ################################ IMPLEMENT::::
                # We will pad the imput to make them all the same size:
                
                # make it the same size as the training examples
                height, width = np.shape(image)[0], np.shape(image)[1]
                crop_row = np.random.randint(0, height-320)
                crop_col = np.random.randint(0, width-368)

                # crop to 320 x 360 so it will fit into network, and for data augmentation
                image = image[crop_row:crop_row+320, crop_col:crop_col+368]
                mask = mask[crop_row:crop_row+320, crop_col:crop_col+368]

                image = image/255.0 # make pixels in [0,1]     
                x.append(image)
                y.append(mask)

            x=np.array(x)
            y=np.array(y)
            yield (x, y)

Using TensorFlow backend.


## Models

#### Loss
Loss Function is different than the usual dice coefficient. We won't measure overlap. It is made of two parts:
1. MSE on the distance to the nearest nuclei.
2. Class of the nearest nuclei
Both of these parts should be 0 if the nearest nuclei is over 20 pixels away? At least I think so. For sure the distance is meaningless, and the classifation would just add some noise to the model.

#### Model
* Test a models smaller and larger than the original U-Net.
* Try adding batchnorm

In [2]:
# Distance loss function
def distance_loss(y_true, y_pred):
    weight = .5 # how mush does the distance matter compared to the cross entropy (fast ai used .001 for 4 more uncertain ones)
    # Already scaled distance values between (0,1). Cut off ones larger because this doesn't hurt the prediction
#     K.int_shape(y_true)
#     K.int_shape(y_pred)
#     y_pred_clip = K.clip(y_pred[:, :, 0], -1, 1)
#     K.int_shape(y_pred_clip)
    distance_loss = K.binary_crossentropy(y_pred[:, :, :, 0], y_true[:, :, :, 0])
#     K.int_shape(distance_loss)
    
    cross_entropy = K.categorical_crossentropy(y_true[:, :, :, 1:], y_pred[:, :, :, 1:])    
#     K.int_shape(cross_entropy)

    return(distance_loss*weight+(1-weight)*cross_entropy)


# Remove all the predictions from the cost that are under 20 away for cross entropy. Not for MSE because it should learn easily
# def distance_loss_under20(y_true, y_pred):
#     weight = .05 # how mush does the distance matter compared to the cross entropy (fast ai used .001 for 4 more uncertain ones)
#     # Clip the distance values to be less than 20 :
#     y_pred[:, :, 0] = K.clip(y_pred[:, :, 0], -1, 1)
#     mse = K.mean(K.square(y_pred[:, :, 0] - y_true[:, :, 0]), axis=-1)
    
#     # Only look at the elements with a distance of less than 20  pixels from the nuclei.
#     y_true_clip = 
#     y_pred_clip = 
#     cross_entropy = categorical_crossentropy(y_true[:, :, 1:], y_pred[:, :, 1:])    
#     return(mse*weight+(1-weight)*cross_entropy)




############ UNET ARCHITECTURES 

def unet_standard(learning_rate=.0001):
    input_shape = (None, None, 3)
    img_input = Input(shape=input_shape)
    conv1 = Conv2D(32, (3, 3), activation='relu', padding='same')(img_input)
    conv1 = Conv2D(32, (3, 3), activation='relu', padding='same')(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)

    conv2 = Conv2D(64, (3, 3), activation='relu', padding='same')(pool1)
    conv2 = Conv2D(64, (3, 3), activation='relu', padding='same')(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)

    conv3 = Conv2D(128, (3, 3), activation='relu', padding='same')(pool2)
    conv3 = Conv2D(128, (3, 3), activation='relu', padding='same')(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)

    conv4 = Conv2D(256, (3, 3), activation='relu', padding='same')(pool3)
    conv4 = Conv2D(256, (3, 3), activation='relu', padding='same')(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)

    conv5 = Conv2D(512, (3, 3), activation='relu', padding='same')(pool4)
    conv5 = Conv2D(512, (3, 3), activation='relu', padding='same')(conv5)

    up6 = concatenate([Conv2DTranspose(256, (2, 2), strides=(2, 2), padding='same')(conv5), conv4], axis=3)
    conv6 = Conv2D(256, (3, 3), activation='relu', padding='same')(up6)
    conv6 = Conv2D(256, (3, 3), activation='relu', padding='same')(conv6)

    up7 = concatenate([Conv2DTranspose(128, (2, 2), strides=(2, 2), padding='same')(conv6), conv3], axis=3)
    conv7 = Conv2D(128, (3, 3), activation='relu', padding='same')(up7)
    conv7 = Conv2D(128, (3, 3), activation='relu', padding='same')(conv7)

    up8 = concatenate([Conv2DTranspose(64, (2, 2), strides=(2, 2), padding='same')(conv7), conv2], axis=3)
    conv8 = Conv2D(64, (3, 3), activation='relu', padding='same')(up8)
    conv8 = Conv2D(64, (3, 3), activation='relu', padding='same')(conv8)

    up9 = concatenate([Conv2DTranspose(32, (2, 2), strides=(2, 2), padding='same')(conv8), conv1], axis=3)
    conv9 = Conv2D(32, (3, 3), activation='relu', padding='same')(up9)
    conv9 = Conv2D(32, (3, 3), activation='relu', padding='same')(conv9)

    conv10_dist = Conv2D(1, (1, 1), activation='sigmoid')(conv9)
    conv10_cross_entropy = Conv2D(3, (1, 1), activation='softmax')(conv9)
    output = concatenate([conv10_dist, conv10_cross_entropy])

    model = Model(img_input, output)
    model.compile(optimizer=Adam(lr=learning_rate), loss=distance_loss, metrics=[distance_loss])
    return model

def unet_mid(learning_rate=.0001):
    input_shape = (None, None, 3)
    img_input = Input(shape=input_shape)
    conv1 = Conv2D(32, (3, 3), activation='relu', padding='same')(img_input)
    conv1 = Conv2D(32, (3, 3), activation='relu', padding='same')(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)

    conv2 = Conv2D(64, (3, 3), activation='relu', padding='same')(pool1)
    conv2 = Conv2D(64, (3, 3), activation='relu', padding='same')(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)

    conv3 = Conv2D(128, (3, 3), activation='relu', padding='same')(pool2)
    conv3 = Conv2D(128, (3, 3), activation='relu', padding='same')(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)

    conv5 = Conv2D(256, (3, 3), activation='relu', padding='same')(pool4)
    conv5 = Conv2D(256, (3, 3), activation='relu', padding='same')(conv5)

    up6 = concatenate([Conv2DTranspose(128, (2, 2), strides=(2, 2), padding='same')(conv5), conv3], axis=3)
    conv6 = Conv2D(128, (3, 3), activation='relu', padding='same')(up6)
    conv6 = Conv2D(128, (3, 3), activation='relu', padding='same')(conv6)

    up7 = concatenate([Conv2DTranspose(64, (2, 2), strides=(2, 2), padding='same')(conv6), conv2], axis=3)
    conv7 = Conv2D(64, (3, 3), activation='relu', padding='same')(up7)
    conv7 = Conv2D(64, (3, 3), activation='relu', padding='same')(conv7)

    up8 = concatenate([Conv2DTranspose(32, (2, 2), strides=(2, 2), padding='same')(conv7), conv1], axis=3)
    conv8 = Conv2D(32, (3, 3), activation='relu', padding='same')(up8)
    conv8 = Conv2D(32, (3, 3), activation='relu', padding='same')(conv8)

    conv9_dist = Conv2D(1, (1, 1), activation='sigmoid')(conv8)
    conv9_cross_entropy = Conv2D(3, (1, 1), activation='softmax')(conv8)
    output = concatenate([conv9_dist, conv9_cross_entropy])

    model = Model(img_input, output)
    model.compile(optimizer=Adam(lr=learning_rate), loss=distance_loss, metrics=[distance_loss])
    return model




def conv_block(x,
              filters,
              num_row,
              num_col,
              dropout, 
              padding='same',
              strides=(1, 1),
              activation='relu'):
    x = Conv2D(filters, (num_row, num_col), strides=strides, padding=padding, activation=activation)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Dropout(dropout)(x)

    x = Conv2D(filters, (num_row, num_col), strides=strides, padding=padding, activation=activation)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Dropout(dropout)(x)
    return x

def unet_paper(learning_rate=.0001):
    input_shape = (None, None, 3)
    img_input = Input(shape=input_shape)

    conv1 = conv_block(img_input, 32, 3, 3, dropout = .1, padding='same', strides=(1, 1), activation='relu')
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)

    conv2 = conv_block(pool1, 64, 3, 3, dropout = .1, padding='same', strides=(1, 1), activation='relu')
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)

    conv3 = conv_block(pool2, 128, 3, 3, dropout = .1, padding='same', strides=(1, 1), activation='relu')
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)

    conv4 = conv_block(pool3, 128, 3, 3, dropout = .1, padding='same', strides=(1, 1), activation='relu')

    up5 = concatenate([Conv2DTranspose(256, (2, 2), strides=(2, 2), padding='same')(conv4), conv3], axis=3)
    conv5 = conv_block(up5, 128, 3, 3, dropout = .1, padding='same', strides=(1, 1), activation='relu')

    up6 = concatenate([Conv2DTranspose(128, (2, 2), strides=(2, 2), padding='same')(conv5), conv2], axis=3)
    conv6 = conv_block(up6, 64, 3, 3, dropout = .1, padding='same', strides=(1, 1), activation='relu')

    up7 = concatenate([Conv2DTranspose(32, (2, 2), strides=(2, 2), padding='same')(conv6), conv1], axis=3)
    conv7 = conv_block(up7, 32, 3, 3, dropout = .1, padding='same', strides=(1, 1), activation='relu')
    
    conv8_dist = Conv2D(1, (1, 1), activation='sigmoid')(conv7)
    conv8_cross_entropy = Conv2D(3, (1, 1), activation='softmax')(conv7)
    output = concatenate([conv8_dist, conv8_cross_entropy])
    
    model = Model(img_input, output)
    model.compile(optimizer=Adam(lr=learning_rate), loss=distance_loss, metrics=[distance_loss])
    return model

## Test out

In [3]:
import os
import glob
import random
import numpy as np 
import pandas as pd
import keras
import pickle
from keras import backend as K
from keras.engine.topology import Layer
from keras.layers import Dropout, Flatten, Reshape, Input
from keras.layers.core import Activation, Dense, Lambda
from keras.callbacks import ModelCheckpoint, EarlyStopping

learning_rate=.005
epochs=100
batch_size=8
data_loc='/home/rbbidart/project/rbbidart/cancer_hist/full_slides2'
mask_loc='/home/rbbidart/project/rbbidart/cancer_hist/im_dist_labels'
out_loc='/home/rbbidart/cancer_hist_out/unet_dist/unet_paper_custom_aug'


# Locations
train_loc = os.path.join(str(data_loc),'train', str(0))
train_mask_loc = os.path.join(str(mask_loc),'train', str(0))

valid_loc = os.path.join(str(data_loc),'valid', str(0))
valid_mask_loc = os.path.join(str(mask_loc),'valid', str(0))

num_train = len(glob.glob(os.path.join(train_loc, '*')))/2-2
num_valid = len(glob.glob(os.path.join(valid_loc, '*')))/2-2
print(valid_loc)
print('num_train', num_train)
print('num_valid', num_valid)

# Params for all models
batch_size=int(batch_size)   # make this divisible by len(x_data)
steps_per_epoch = np.floor(num_train/batch_size) # num of batches from generator at each epoch. (make it full train set)
validation_steps = np.floor(num_valid/batch_size)# size of validation dataset divided by batch size
print('validation_steps', validation_steps)

model = unet_paper(learning_rate=learning_rate)
name = 'unet_paper'+'_'+str(learning_rate)+'_'+'custom_aug'
out_file=os.path.join(str(out_loc), name)

# need a batch generator to augment the labels same as the train images
valid_generator = data_gen_combined(valid_loc, valid_mask_loc, batch_size, seed=101)
train_generator = data_gen_aug_combined(train_loc, train_mask_loc, batch_size, square_rot_p=.3,  seed=101)

callbacks = [EarlyStopping(monitor='distance_loss', patience=15, verbose=0),
        ModelCheckpoint(filepath=os.path.join(out_loc, name + '_.{epoch:02d}-{val_acc:.2f}.hdf5'), 
        verbose=1, monitor='val_loss', save_best_only=True)]

hist = model.fit_generator(train_generator,
                                  validation_data=valid_generator,
                                  steps_per_epoch=steps_per_epoch, 
                                  epochs=epochs,
                                  validation_steps=validation_steps,
                                  callbacks=callbacks)
pickle.dump(hist.history, open(out_file, 'wb'))

/home/rbbidart/project/rbbidart/cancer_hist/full_slides2/valid/0
num_train 87.0
num_valid 20.0
validation_steps 2.0
Epoch 1/100

Exception in thread Thread-5:
Traceback (most recent call last):
  File "/cvmfs/soft.computecanada.ca/nix/store/v29pphgl66qjvjck1mn4pcm5g9agk5kh-python3-3.5.2/lib/python3.5/threading.py", line 914, in _bootstrap_inner
    self.run()
  File "/cvmfs/soft.computecanada.ca/nix/store/v29pphgl66qjvjck1mn4pcm5g9agk5kh-python3-3.5.2/lib/python3.5/threading.py", line 862, in run
    self._target(*self._args, **self._kwargs)
  File "/home/rbbidart/tensorflow2/lib/python3.5/site-packages/keras/utils/data_utils.py", line 568, in data_generator_task
    generator_output = next(self._generator)
  File "<ipython-input-1-1ba4f1cb8e10>", line 120, in data_gen_combined
    all_masks.append(loc)
UnboundLocalError: local variable 'all_masks' referenced before assignment



StopIteration: 

In [None]:
import os
import glob
import random
import numpy as np 
import pandas as pd
import keras
import pickle
from keras import backend as K
from keras.engine.topology import Layer
from keras.layers import Dropout, Flatten, Reshape, Input
from keras.layers.core import Activation, Dense, Lambda
from keras.callbacks import ModelCheckpoint, EarlyStopping

learning_rate=.005
epochs=100
batch_size=8
data_loc='/home/rbbidart/project/rbbidart/cancer_hist/full_slides2'
mask_loc='/home/rbbidart/project/rbbidart/cancer_hist/im_dist_labels'
out_loc='/home/rbbidart/cancer_hist_out/unet_dist/unet_standard_custom_aug'


# Locations
train_loc = os.path.join(str(data_loc),'train', str(0))
train_mask_loc = os.path.join(str(mask_loc),'train', str(0))

valid_loc = os.path.join(str(data_loc),'valid', str(0))
valid_mask_loc = os.path.join(str(mask_loc),'valid', str(0))

num_train = len(glob.glob(os.path.join(train_loc, '*')))/2-2
num_valid = len(glob.glob(os.path.join(valid_loc, '*')))/2-2
print(valid_loc)
print('num_train', num_train)
print('num_valid', num_valid)

# Params for all models
batch_size=int(batch_size)   # make this divisible by len(x_data)
steps_per_epoch = np.floor(num_train/batch_size) # num of batches from generator at each epoch. (make it full train set)
validation_steps = np.floor(num_valid/batch_size)# size of validation dataset divided by batch size
print('validation_steps', validation_steps)

model = unet_standard(learning_rate=learning_rate)
name = 'unet_standard'+'_'+str(learning_rate)+'_'+'custom_aug'
out_file=os.path.join(str(out_loc), name)

# need a batch generator to augment the labels same as the train images
valid_generator = data_gen_combined(valid_loc, valid_mask_loc, batch_size, seed=101)
train_generator = data_gen_aug_combined(train_loc, train_mask_loc, batch_size, square_rot_p=.3,  seed=101)

callbacks = [EarlyStopping(monitor='distance_loss', patience=15, verbose=0),
        ModelCheckpoint(filepath=os.path.join(out_loc, name + '_.{epoch:02d}-{val_acc:.2f}.hdf5'), 
        verbose=1, monitor='val_loss', save_best_only=True)]

hist = model.fit_generator(train_generator,
                                  validation_data=valid_generator,
                                  steps_per_epoch=steps_per_epoch, 
                                  epochs=epochs,
                                  validation_steps=validation_steps,
                                  callbacks=callbacks)
pickle.dump(hist.history, open(out_file, 'wb'))

In [None]:
from PIL import Image

data_loc = '/home/rbbidart/project/rbbidart/cancer_hist/full_slides2'
out_loc = '/home/rbbidart/project/rbbidart/cancer_hist/full_slides2_k'

all_images=glob.glob(data_loc + '/**/*.tif', recursive=True)
for image_file in all_images:
    name = image_file.rsplit('/', 1)[-1].rsplit('.', 1)[0]
    new_loc = image_file.rsplit('/', 1)[0].replace('full_slides2', 'full_slides2_k')
    if not os.path.exists(new_loc):
        os.makedirs(new_loc)
    new_name = name+'.jpg'
    im = Image.open(image_file)
    im.save(os.path.join(new_loc, new_name))
    
data_loc = '/home/rbbidart/project/rbbidart/cancer_hist/im_dist_labels'
out_loc = '/home/rbbidart/project/rbbidart/cancer_hist/im_dist_labels_k'
all_images=glob.glob(data_loc + '/**/*.tif', recursive=True)
for image_file in all_images:
    name = image_file.rsplit('/', 1)[-1].rsplit('.', 1)[0]
    new_loc = image_file.rsplit('/', 1)[0].replace('im_dist_labels', 'im_dist_labels_k')
    if not os.path.exists(new_loc):
        os.makedirs(new_loc)
    new_name = name+'.jpg'
    im = Image.open(image_file)
    im.save(os.path.join(new_loc, new_name))

## Keras 

In [None]:
from keras.preprocessing.image import ImageDataGenerator
def get_generator(train_folder, train_mask_folder, valid_folder, valid_mask_folder,
                    width_shift_range=0.2,
                    height_shift_range=0.2,
                    horizontal_flip=True,
                    rotation_range=10,
                    zoom_range=0.2,
                    classes=['keras'],
                    fill_mode="constant"):
    batch_size = 4
    seed = 42
    # Example taken from https://keras.io/preprocessing/image/
    # We create two instances with the same arguments
    data_gen_args_train = dict(
                        width_shift_range=width_shift_range,
                        height_shift_range=height_shift_range,
                        horizontal_flip=horizontal_flip,
                        rotation_range=rotation_range,
                        zoom_range=zoom_range,
                        fill_mode=fill_mode, 
                        cval=0       
                        )
    data_gen_args_masks = dict(
                        width_shift_range=width_shift_range,
                        height_shift_range=height_shift_range,
                        horizontal_flip=horizontal_flip,
                        rotation_range=rotation_range,
                        zoom_range=zoom_range,
                        fill_mode=fill_mode,
                        cval=0    
                        )
    image_datagen = ImageDataGenerator(**data_gen_args_train)
    mask_datagen = ImageDataGenerator(**data_gen_args_masks)
    image_generator = image_datagen.flow_from_directory(
        train_folder,
        batch_size = batch_size,
        target_size = (640, 800),
        class_mode=None,
        color_mode='rgb',
        classes=classes,
        seed=seed)
    mask_generator = mask_datagen.flow_from_directory(
        train_mask_folder,
        batch_size = batch_size,    
        target_size = (640, 800),    
        class_mode=None,
        color_mode='grayscale',
        classes=classes,
        seed=seed)
    valid_image_generator = image_datagen.flow_from_directory(
        valid_folder,
        batch_size = batch_size,
        target_size = (640, 800),
        class_mode=None,
        color_mode='rgb',
        classes=classes,
        seed=seed)
    valid_mask_generator = mask_datagen.flow_from_directory(
        valid_mask_folder,
        batch_size = batch_size,    
        target_size = (640, 800),    
        class_mode=None,
        color_mode='grayscale',
        classes=classes,
        seed=seed)
    # combine generators into one which yields image and masks
    train_generator = zip(image_generator, mask_generator)
    valid_generator = zip(valid_image_generator, valid_mask_generator)   
    return train_generator,valid_generator

In [None]:
out_loc='/home/rbbidart/cancer_hist_out/unet_dist/unet_standard_keras_aug'
learning_rate=.005
epochs=100
batch_size=8
data_loc='/home/rbbidart/project/rbbidart/cancer_hist/full_slides2_k'
mask_loc='/home/rbbidart/project/rbbidart/cancer_hist/im_dist_labels_k'
out_loc='/home/rbbidart/cancer_hist_out/unet_dist/unet_standard_custom_aug'


train_loc = os.path.join(str(data_loc),'train')
train_loc_mask = os.path.join(str(mask_loc),'train')

valid_loc = os.path.join(str(data_loc),'valid')
valid_loc_mask = os.path.join(str(mask_loc),'valid')


train_generator,valid_generator=get_generator(train_loc, train_loc_mask, 
                                              valid_loc, valid_loc_mask,
                    width_shift_range=0.2,
                    height_shift_range=0.2,
                    horizontal_flip=True,
                    rotation_range=180,
                    zoom_range=0.3,
                    classes=['keras'],
                    fill_mode="constant")


num_train = len(glob.glob(train_loc + '/**/*.png', recursive=True))
num_valid = len(glob.glob(valid_loc + '/**/*.png', recursive=True))
print(valid_loc)
print('num_train', num_train)
print('num_valid', num_valid)

# Params for all models
batch_size=int(batch_size)   # make this divisible by len(x_data)
steps_per_epoch = np.floor(num_train/batch_size) # num of batches from generator at each epoch. (make it full train set)
validation_steps = np.floor(num_valid/batch_size)# size of validation dataset divided by batch size
print('validation_steps', validation_steps)

model = unet_standard(learning_rate=learning_rate)
name = 'unet_standard'+'_'+str(learning_rate)+'_'+'keras_aug'
out_file=os.path.join(str(out_loc), name)

# need a batch generator to augment the labels same as the train images
valid_generator = data_gen_combined(valid_loc, valid_mask_loc, batch_size, seed=101)
train_generator = data_gen_aug_combined(train_loc, train_mask_loc, batch_size, square_rot_p=.3,  seed=101)

callbacks = [EarlyStopping(monitor='distance_loss', patience=15, verbose=0),
        ModelCheckpoint(filepath=os.path.join(out_loc, name + '_.{epoch:02d}-{val_acc:.2f}.hdf5'), 
        verbose=1, monitor='val_loss', save_best_only=True)]

hist = model.fit_generator(train_generator,
                                  validation_data=valid_generator,
                                  steps_per_epoch=steps_per_epoch, 
                                  epochs=epochs,
                                  validation_steps=validation_steps,
                                  callbacks=callbacks)
pickle.dump(hist.history, open(out_file, 'wb'))