# Tiling and data augmentation: a first approach
The EDA notebook has shown us that data augmentation is needed: color stain augmentation could help bring the distributions of Karolinska and Radboud together, hopefully closing the distinctness of both distributions so that the resulting model can be stain invariant with respect to these two providers.

A second idea is to use tiling: the tissue samples consist of a large white background, which do not need to be evaluated in the model, saving processing and training time. The focus of this notebook is threefold:

* Create a small data augmentation scheme, and plotting/comparing the distributions of the original data, and the augmented distribution

* Creating a tiling example which can be used in a processing pipeline

* Creating a first baseline model, to be improved upon.

A few starting points are provided on Kaggle. 

* https://www.kaggle.com/c/prostate-cancer-grade-assessment/discussion/146855

## Part 1: Tiling implementation



![woof](https://www.googleapis.com/download/storage/v1/b/kaggle-user-content/o/inbox%2F1212661%2Fe6fe32d759a28480343001aa3c661723%2FTILE.png?generation=1588094975239255&alt=media)
*source: https://www.kaggle.com/c/prostate-cancer-grade-assessment/discussion/146855*

In [3]:
(128 - 600%128)%128

40

### hardships with tiling in keras + TF training
The concat tile pool approach requires a dynamic batch size that changes during training. An example is shown in the original Kaggle notebook using fastAI/Pytorch, but now we would like an implementation in Tensorflow/Keras. 

Tensorflow and Keras applications are not fond of changing the batch size after it is declared. For this reason, the batch size cannot be changed through normal layer operations (i.e. KL.Reshape). However, the 

In [41]:
import tensorflow as tf
import tensorflow.keras.layers as KL
import tensorflow.keras.backend as K
from tensorflow.keras.models import Model
from tensorflow.keras.utils import to_categorical
import numpy as np


(60000, 2, 14, 28, 1)
(60000, 10)


In [54]:
def backend_reshape(x, new_shape):
    # casting a shape (bs, N, w/N, H, C) with batch shape None:
    # (-1,N,w/N, H, C), -1 handles the None argument
    return K.reshape(x, new_shape)


In [75]:
# create a toy model
# bs * N * 32 * 32 * C

(x_train, y_train), (x_test, y_test)=tf.keras.datasets.mnist.load_data(
    path='mnist.npz'
)

x_train = np.reshape(x_train, (60000,2, 14, 28,1))
print(x_train.shape)
y_train = to_categorical(y_train, num_classes=10)
print(y_train.shape)
x_test = np.reshape(x_test, (10000, 2, 14, 28, 1))
print(x_test.shape)
y_test = to_categorical(y_test, num_classes=10)
print(y_test.shape)


batch_size = None
N = 2


inputs = KL.Input(shape=(N,14,28,1), batch_size=batch_size)

tf.shape(inputs)[0]
x = KL.Lambda(backend_reshape, arguments={'new_shape': (-1, 14, 28, 1)})(inputs)
x = KL.Conv2D(32, (3,3))(x)
x = KL.Conv2D(32, (3,3))(x)
x = KL.Lambda(backend_reshape, arguments={'new_shape': (-1, 10*N, 24, 32)})(x)
x = KL.Flatten()(x)
x = KL.Dense(10, activation='softmax')(x)
model = Model(inputs=inputs, outputs=x)
model.summary()
model.compile(optimizer='adam', loss='categorical_crossentropy')

model.fit(x=x_train, y=y_train, validation_data = (x_test,y_test), epochs=10)


(60000, 2, 14, 28, 1)
(60000, 10)
(10000, 2, 14, 28, 1)
(10000, 10)
Tensor("lambda_47/Shape:0", shape=(5,), dtype=int32)
Tensor("lambda_48/Shape:0", shape=(4,), dtype=int32)
Model: "model_23"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_39 (InputLayer)        [(None, 2, 14, 28, 1)]    0         
_________________________________________________________________
lambda_47 (Lambda)           (None, 14, 28, 1)         0         
_________________________________________________________________
conv2d_50 (Conv2D)           (None, 12, 26, 32)        320       
_________________________________________________________________
conv2d_51 (Conv2D)           (None, 10, 24, 32)        9248      
_________________________________________________________________
lambda_48 (Lambda)           (None, 20, 24, 32)        0         
_________________________________________________________________
flatten_15 (Flat

<tensorflow.python.keras.callbacks.History at 0x149e46bffef0>

## Tiling
Tiling data and export to a new folder

In [76]:
import os
import cv2
import skimage.io
from tqdm.notebook import tqdm
import zipfile
import numpy as np

In [77]:
TRAIN = '../data/train_images'
MASKS = '../data/train_label_masks/'
OUT_TRAIN = 'train.zip'
OUT_MASKS = 'masks.zip'
sz = 128
N = 16

In [78]:
def tile(img, mask):
    result = []
    shape = img.shape
    pad0,pad1 = (sz - shape[0]%sz)%sz, (sz - shape[1]%sz)%sz
    img = np.pad(img,[[pad0//2,pad0-pad0//2],[pad1//2,pad1-pad1//2],[0,0]],
                constant_values=255)
    mask = np.pad(mask,[[pad0//2,pad0-pad0//2],[pad1//2,pad1-pad1//2],[0,0]],
                constant_values=0)
    img = img.reshape(img.shape[0]//sz,sz,img.shape[1]//sz,sz,3)
    img = img.transpose(0,2,1,3,4).reshape(-1,sz,sz,3)
    mask = mask.reshape(mask.shape[0]//sz,sz,mask.shape[1]//sz,sz,3)
    mask = mask.transpose(0,2,1,3,4).reshape(-1,sz,sz,3)
    if len(img) < N:
        mask = np.pad(mask,[[0,N-len(img)],[0,0],[0,0],[0,0]],constant_values=0)
        img = np.pad(img,[[0,N-len(img)],[0,0],[0,0],[0,0]],constant_values=255)
    idxs = np.argsort(img.reshape(img.shape[0],-1).sum(-1))[:N]
    img = img[idxs]
    mask = mask[idxs]
    for i in range(len(img)):
        result.append({'img':img[i], 'mask':mask[i], 'idx':i})
    return result

In [79]:
x_tot,x2_tot = [],[]
names = [name[:-10] for name in os.listdir(MASKS)]
with zipfile.ZipFile(OUT_TRAIN, 'w') as img_out,\
 zipfile.ZipFile(OUT_MASKS, 'w') as mask_out:
    for name in tqdm(names):
        img = skimage.io.MultiImage(os.path.join(TRAIN,name+'.tiff'))[-1]
        mask = skimage.io.MultiImage(os.path.join(MASKS,name+'_mask.tiff'))[-1]
        tiles = tile(img,mask)
        for t in tiles:
            img,mask,idx = t['img'],t['mask'],t['idx']
            x_tot.append((img/255.0).reshape(-1,3).mean(0))
            x2_tot.append(((img/255.0)**2).reshape(-1,3).mean(0)) 
            #if read with PIL RGB turns into BGR
            img = cv2.imencode('.png',cv2.cvtColor(img, cv2.COLOR_RGB2BGR))[1]
            img_out.writestr(f'{name}_{idx}.png', img)
            mask = cv2.imencode('.png',mask[:,:,0])[1]
            mask_out.writestr(f'{name}_{idx}.png', mask)

HBox(children=(FloatProgress(value=0.0, max=10516.0), HTML(value='')))




In [80]:
#image stats
img_avr =  np.array(x_tot).mean(0)
img_std =  np.sqrt(np.array(x2_tot).mean(0) - img_avr**2)
print('mean:',img_avr, ', std:', np.sqrt(img_std))

mean: [0.90949707 0.8188697  0.87795304] , std: [0.36357649 0.49984502 0.40477625]
