# Tiling and data augmentation: a first approach
The EDA notebook showed that data augmentation is needed: color stain augmentation could help bring the stain distributions of Karolinska and Radboud closer together, so that the resulting model is stain invariant with respect to these two providers.

A second idea is tiling: the slides consist largely of white background, which does not need to be evaluated by the model, so skipping it saves processing and training time. This notebook has three goals:

* Create a small data augmentation scheme, and plot and compare the distributions of the original and the augmented data

* Create a tiling example which can be used in a processing pipeline

* Create a first baseline model, to be improved upon.

A few starting points are provided on Kaggle. 

* https://www.kaggle.com/c/prostate-cancer-grade-assessment/discussion/146855

## Part 1: Tiling implementation



![woof](https://www.googleapis.com/download/storage/v1/b/kaggle-user-content/o/inbox%2F1212661%2Fe6fe32d759a28480343001aa3c661723%2FTILE.png?generation=1588094975239255&alt=media)
*source: https://www.kaggle.com/c/prostate-cancer-grade-assessment/discussion/146855*

In [3]:
(128 - 600%128)%128

40
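
The expression above is the padding needed to bring a side of 600 pixels up to the next multiple of 128. A minimal sketch of the same arithmetic as a function follows; `pad_to_multiple` is a hypothetical helper name used only for illustration, not something defined elsewhere in this notebook.

```python
def pad_to_multiple(length, sz=128):
    """Number of pixels to add so that `length` becomes a multiple of `sz`."""
    # The outer modulo maps the "already a multiple" case from sz to 0.
    return (sz - length % sz) % sz


for length in (600, 640, 1):
    pad = pad_to_multiple(length)
    # Split the padding over both sides, as the tiling code further down does.
    before, after = pad // 2, pad - pad // 2
    print(length, pad, (before, after), (length + pad) // 128)
```

The outer `% sz` matters only when the side is already a multiple of the tile size: without it, a 640-pixel side would receive a full extra tile of padding.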

### Hardships with tiling in Keras + TF training
The concat tile pool approach requires a batch size that changes dynamically during training. The original Kaggle notebook shows an example using fastai/PyTorch, but here we would like an implementation in TensorFlow/Keras.

TensorFlow and Keras do not cope well with a batch size that changes after it has been declared. For this reason, the batch size cannot be changed through normal layer operations (e.g. `KL.Reshape`). However, the 
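
To make the shape bookkeeping concrete, here is a rough NumPy-only sketch of the concat tile pool idea as I understand it from the linked discussion: tiles are folded into the batch axis, encoded independently, and their feature maps are concatenated again before pooling. The `toy_encoder` function is a stand-in assumption (plain 2x2 average pooling) for a real CNN backbone, and the shapes are arbitrary toy values.

```python
import numpy as np

# batch size, tiles per image, tile height, tile width, channels
B, N, H, W, C = 4, 2, 14, 28, 1
x = np.random.rand(B, N, H, W, C).astype(np.float32)

# 1) Fold the tile axis into the batch axis so that an ordinary 2D encoder
#    sees every tile as an independent sample: the batch becomes B * N.
tiles = x.reshape(B * N, H, W, C)


def toy_encoder(t):
    """Placeholder for a CNN backbone (assumption): 2x2 average pooling."""
    n, h, w, c = t.shape
    return t.reshape(n, h // 2, 2, w // 2, 2, c).mean(axis=(2, 4))


feats = toy_encoder(tiles)                  # shape (B * N, H/2, W/2, C)

# 2) Unfold again: stack the N feature maps of each image along the height axis.
_, h, w, c = feats.shape
concat = feats.reshape(B, N * h, w, c)      # shape (B, N * H/2, W/2, C)

# 3) Global average pooling over the concatenated map gives one vector per image.
pooled = concat.mean(axis=(1, 2))           # shape (B, C)

print(tiles.shape, feats.shape, concat.shape, pooled.shape)
```

The reshapes in steps 1 and 2 are exactly the operations that change the leading (batch) dimension, which is where a Keras implementation needs special care.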

In [41]:
import tensorflow as tf
import tensorflow.keras.layers as KL
import tensorflow.keras.backend as K
from tensorflow.keras.models import Model
from tensorflow.keras.utils import to_categorical
import numpy as np
import os
import cv2
import skimage.io
from tqdm.notebook import tqdm
import zipfile


(60000, 2, 14, 28, 1)
(60000, 10)


## Tiling
The code below tiles the input data at level 2 (the lowest resolution) and saves the tiles as PNG images in separate zip archives for images and masks.
An on-the-fly generator that does the same is defined in the TiffGenerator.
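
The tiling relies on a reshape/transpose trick that is easier to follow on a tiny array. The sketch below uses a toy 4x6 single-channel "image" and a tile size of 2 (both values are illustrative assumptions, not the notebook's settings); the image must already be padded to a multiple of the tile size.

```python
import numpy as np

sz = 2                                    # toy tile size
img = np.arange(4 * 6).reshape(4, 6)      # 4x6 single-channel "image"

# Split each spatial axis into (number of tiles, tile size) ...
blocks = img.reshape(img.shape[0] // sz, sz, img.shape[1] // sz, sz)
# ... then move the two tile-index axes to the front and flatten them.
tiles = blocks.transpose(0, 2, 1, 3).reshape(-1, sz, sz)

print(tiles.shape)    # number of tiles, tile height, tile width
print(tiles[0])       # top-left tile
print(tiles[-1])      # bottom-right tile
```

Tiles come out in row-major order: left to right within a row of tiles, then top to bottom.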

In [77]:
TRAIN = '../data/train_images'
MASKS = '../data/train_label_masks/'
OUT_TRAIN = 'train.zip'
OUT_MASKS = 'masks.zip'
sz = 128
N = 16

In [78]:
def tile(img, mask):
    """Cut img and mask into sz x sz tiles and keep the N tiles with the most tissue."""
    result = []
    shape = img.shape
    # padding needed to make both sides a multiple of sz
    pad0,pad1 = (sz - shape[0]%sz)%sz, (sz - shape[1]%sz)%sz
    # pad the image with white (255) and the mask with background (0)
    img = np.pad(img,[[pad0//2,pad0-pad0//2],[pad1//2,pad1-pad1//2],[0,0]],
                constant_values=255)
    mask = np.pad(mask,[[pad0//2,pad0-pad0//2],[pad1//2,pad1-pad1//2],[0,0]],
                constant_values=0)
    # split into a grid of tiles: (rows, sz, cols, sz, 3) -> (rows*cols, sz, sz, 3)
    img = img.reshape(img.shape[0]//sz,sz,img.shape[1]//sz,sz,3)
    img = img.transpose(0,2,1,3,4).reshape(-1,sz,sz,3)
    mask = mask.reshape(mask.shape[0]//sz,sz,mask.shape[1]//sz,sz,3)
    mask = mask.transpose(0,2,1,3,4).reshape(-1,sz,sz,3)
    if len(img) < N:
        # fewer than N tiles: add empty tiles (mask first, while len(img) is still the old count)
        mask = np.pad(mask,[[0,N-len(img)],[0,0],[0,0],[0,0]],constant_values=0)
        img = np.pad(img,[[0,N-len(img)],[0,0],[0,0],[0,0]],constant_values=255)
    # lowest pixel sum = darkest tiles = least white background
    idxs = np.argsort(img.reshape(img.shape[0],-1).sum(-1))[:N]
    img = img[idxs]
    mask = mask[idxs]
    for i in range(len(img)):
        result.append({'img':img[i], 'mask':mask[i], 'idx':i})
    return result
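
The tile selection step can be checked in isolation. The following sketch builds five synthetic 4x4 RGB tiles (made-up data: white background, with two tiles containing darker "tissue") and keeps the darkest ones using the same pixel-sum ranking; the value of `N` here is a toy setting.

```python
import numpy as np

N = 3  # number of tiles to keep (toy value)

# Five white 4x4 RGB tiles; tiles 1 and 3 contain darker "tissue".
tiles = np.full((5, 4, 4, 3), 255, dtype=np.uint8)
tiles[1, :2] = 100     # upper half of tile 1 is darker
tiles[3] = 50          # tile 3 is dark everywhere

# Sum of all pixel values per tile: a lower sum means less white background.
brightness = tiles.reshape(len(tiles), -1).sum(-1)
idxs = np.argsort(brightness)[:N]
selected = tiles[idxs]

print(brightness)
print(idxs[:2])        # the two tissue tiles come first
print(selected.shape)
```

Since the remaining tiles are all pure white, which of them fills the third slot depends on how `argsort` breaks ties.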

In [79]:
x_tot,x2_tot = [],[]
# strip the '_mask.tiff' suffix (10 characters) to get the slide IDs
names = [name[:-10] for name in os.listdir(MASKS)]
with zipfile.ZipFile(OUT_TRAIN, 'w') as img_out,\
 zipfile.ZipFile(OUT_MASKS, 'w') as mask_out:
    for name in tqdm(names):
        img = skimage.io.MultiImage(os.path.join(TRAIN,name+'.tiff'))[-1]
        mask = skimage.io.MultiImage(os.path.join(MASKS,name+'_mask.tiff'))[-1]
        tiles = tile(img,mask)
        for t in tiles:
            img,mask,idx = t['img'],t['mask'],t['idx']
            x_tot.append((img/255.0).reshape(-1,3).mean(0))
            x2_tot.append(((img/255.0)**2).reshape(-1,3).mean(0)) 
            # cv2.imencode expects BGR channel order, so convert from RGB first
            img = cv2.imencode('.png',cv2.cvtColor(img, cv2.COLOR_RGB2BGR))[1]
            img_out.writestr(f'{name}_{idx}.png', img)
            mask = cv2.imencode('.png',mask[:,:,0])[1]
            mask_out.writestr(f'{name}_{idx}.png', mask)
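
Because every tile is stored as `{name}_{idx}.png`, the tiles of one slide can be regrouped from the archive by splitting the file name at the last underscore. A small sketch using an in-memory archive follows; the slide IDs and the placeholder payload are made up for illustration, and real entries would hold PNG bytes.

```python
import io
import zipfile
from collections import defaultdict

# Build a small in-memory archive that follows the same naming scheme.
buf = io.BytesIO()
with zipfile.ZipFile(buf, 'w') as zf:
    for name in ('slideA', 'slideB'):          # hypothetical slide IDs
        for idx in range(3):
            zf.writestr(f'{name}_{idx}.png', b'placeholder')

# Group the tile indices per slide by parsing the file names.
tiles_per_slide = defaultdict(list)
with zipfile.ZipFile(buf) as zf:
    for fname in zf.namelist():
        stem = fname[:-len('.png')]
        name, idx = stem.rsplit('_', 1)        # split at the last underscore only
        tiles_per_slide[name].append(int(idx))

print(dict(tiles_per_slide))
```

Using `rsplit('_', 1)` keeps the parsing valid even if a slide ID itself contained an underscore.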

In [80]:
#image stats
img_avr =  np.array(x_tot).mean(0)
img_std =  np.sqrt(np.array(x2_tot).mean(0) - img_avr**2)
print('mean:',img_avr, ', std:', img_std)

mean: [0.90949707 0.8188697  0.87795304] , std: [0.13218786 0.24984504 0.16384381]
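
The statistics cell uses the identity Var[x] = E[x²] − E[x]², accumulated per tile and then averaged, which is exact as long as all tiles have the same number of pixels. A minimal check on synthetic data (random tiles, not the competition images) shows that this agrees with a direct computation and how the result would typically be used for normalization:

```python
import numpy as np

rng = np.random.default_rng(0)
tiles = rng.random((10, 8, 8, 3))     # synthetic tiles, already scaled to [0, 1]

# Per-tile first and second moments, one value per color channel.
x_tot = [t.reshape(-1, 3).mean(0) for t in tiles]
x2_tot = [(t ** 2).reshape(-1, 3).mean(0) for t in tiles]

mean = np.array(x_tot).mean(0)
std = np.sqrt(np.array(x2_tot).mean(0) - mean ** 2)

# Direct computation over all pixels for comparison.
direct_std = tiles.reshape(-1, 3).std(0)
print(np.allclose(std, direct_std))

# Channel-wise normalization with the dataset statistics.
normalized = (tiles - mean) / std
print(normalized.reshape(-1, 3).mean(0), normalized.reshape(-1, 3).std(0))
```

The moment-based form is convenient because it needs only two running sums per channel instead of holding every pixel in memory.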
