![](https://i.ibb.co/mqDLS1P/cover.jpg)

<span style="color: #0D0D0D; font-family: Trebuchet MS; font-size: 3em;">[TPU] HubMAP Double U-Net Model + Augmentation</span>

<span style="color: #04BF6F; font-family: Trebuchet MS; font-size: 1.4em;">This notebook contains steps for Data Preparation, Data Augmentation, TFRecord Building and Custom Keras Model Building and Training on full scale 512x512 tiled images.</span>

<p style='text-align: justify;'>Double U-Net combines two intertwined U-Nets. The first U-Net uses a pre-trained VGG-19 as its encoder sub-network, paired with a custom decoder block. The second U-Net has its own custom encoder and decoder sub-networks to capture more semantic information efficiently. Atrous Spatial Pyramid Pooling (ASPP) is used to capture contextual information within the network, and several skip connections link the two U-Nets. Output concatenations and multiplications are performed at several stages, and squeeze-and-excitation blocks reduce redundant information and improve the capture of contextual information.</p>
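To make the squeeze-and-excitation step concrete, here is a minimal NumPy sketch of a generic SE block: global average pooling ("squeeze"), a small bottleneck MLP with a sigmoid ("excite"), then channel-wise rescaling of the input. The function name, the random weights and the reduction ratio `r = 8` are illustrative assumptions; this is not the notebook's Keras implementation, which lives in the model building section.

```python
import numpy as np

def squeeze_excite(feature_map, w1, b1, w2, b2):
    """Squeeze-and-excitation on one (H, W, C) feature map.

    w1: (C, C // r) and w2: (C // r, C) are dense-layer weights; r is the reduction ratio.
    """
    # Squeeze: global average pooling gives one descriptor per channel, shape (C,)
    z = feature_map.mean(axis=(0, 1))
    # Excite: bottleneck MLP (ReLU, then sigmoid) yields per-channel weights in (0, 1)
    s = np.maximum(z @ w1 + b1, 0.0)
    s = 1.0 / (1.0 + np.exp(-(s @ w2 + b2)))
    # Scale: reweight every channel of the input; (C,) broadcasts over (H, W, C)
    return feature_map * s

rng = np.random.default_rng(0)
C, r = 16, 8
x = rng.standard_normal((32, 32, C))
w1, b1 = rng.standard_normal((C, C // r)), np.zeros(C // r)
w2, b2 = rng.standard_normal((C // r, C)), np.zeros(C)
y = squeeze_excite(x, w1, b1, w2, b2)
print(y.shape)  # (32, 32, 16)
```

Every spatial position in a channel is multiplied by the same learned gate, so the block can only emphasise or suppress whole channels; it never changes spatial layout.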

This variation of U-Net is based on the paper [DoubleU-Net: A Deep Convolutional Neural Network for Medical Image Segmentation](https://arxiv.org/pdf/2006.04868.pdf) by Debesh Jha et al.

<p style='text-align: justify;'>I have added several intuitive and operational optimizations, including replacing the upsampling layers with Conv2DTranspose layers so that the model is supported by Keras on TPU; this also allows the upsampling filters to learn feature groups. Additionally, convolutional layers have been added at the output to combine the output feature maps of the two U-Nets.</p>
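The size arithmetic behind swapping upsampling layers for `Conv2DTranspose` can be checked with a small helper. This sketch reproduces the Keras output-length rule for a transposed convolution along one spatial axis (dilation 1, no `output_padding`); the helper name and the 32 to 512 example are illustrative assumptions, not values taken from the model.

```python
def conv_transpose_output_length(input_length, kernel_size, stride, padding):
    """Output length of a transposed convolution along one axis (Keras convention,
    dilation 1, no output_padding)."""
    if padding == 'same':
        return input_length * stride
    if padding == 'valid':
        return input_length * stride + max(kernel_size - stride, 0)
    raise ValueError(f"unsupported padding: {padding!r}")

# A decoder that starts at 32x32: each stride-2 'same' layer doubles the size,
# exactly like UpSampling2D(size=2), but with a learnable kernel
size = 32
for _ in range(4):
    size = conv_transpose_output_length(size, kernel_size=2, stride=2, padding='same')
print(size)  # 512
```

With `padding='same'` the output is simply `input * stride`, which is why a stride-2 transposed convolution can replace a 2x upsampling layer without changing the rest of the U-Net's shapes.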

Implementation details are given in the model building section.

<p style='text-align: justify;'>The paper reports superior performance for this network on several medical segmentation datasets covering imaging modalities such as colonoscopy, dermoscopy and microscopy. Experiments on the 2015 MICCAI sub-challenge on automatic polyp detection dataset, the CVC-ClinicDB, the 2018 Data Science Bowl challenge, and the lesion boundary segmentation datasets show that DoubleU-Net outperforms U-Net and the baseline models.</p>

<p style='text-align: justify;'><span style="color: #55038C; font-family: Trebuchet MS; font-size: 1.3em;">I have created an augmented image dataset containing all 512x512 tiles with gloms, plus TFRecords of the original and augmented images split into train and validation sets. Both have been uploaded and are available as public datasets.</span></p>

**Augmented 512x512 tiles. Augmentations were applied only to images containing glomeruli. The validation set was not augmented, to avoid leakage.**

https://www.kaggle.com/sreevishnudamodaran/hubmap-512x512-augmented


**TFRecord dataset with original and augmented images grouped into 90-130MB record files.**

https://www.kaggle.com/sreevishnudamodaran/hubmap-512x512-tfrecords-with-aug
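A back-of-the-envelope check of those file sizes, assuming 256 examples per shard (the value passed to `write_dataset` later in this notebook) and ignoring protobuf overhead and the small integer fields: each example carries a 512x512x3 uint8 image and a 512x512x1 uint8 mask, so a shard holds about 256 MiB before GZIP. If that assumption holds, 90-130MB files would imply roughly 2 to 3x compression. The `shard_sizes` helper and the 1000-example count are hypothetical.

```python
H, W = 512, 512
# Raw payload of one example: uint8 3-channel image + uint8 1-channel mask
per_example_bytes = H * W * 3 + H * W * 1
records_per_part = 256  # assumed shard size, as passed to write_dataset

def shard_sizes(num_examples, records_per_part):
    """Examples per shard when a new file is started every `records_per_part` examples."""
    full, rest = divmod(num_examples, records_per_part)
    return [records_per_part] * full + ([rest] if rest else [])

print(per_example_bytes)                              # 1048576 (1 MiB)
print(per_example_bytes * records_per_part / 2**20)   # 256.0 MiB per shard before GZIP
print(shard_sizes(1000, records_per_part))            # [256, 256, 256, 232]
```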









![TPU!](https://img.shields.io/badge/Accelerator-TPU-purple?style=flat-square&logo=appveyor)


## Import Libraries

In [None]:
%matplotlib inline

import json
import os
import glob
import re
import datetime
import os.path as osp
from pathlib import Path  # standard library; provides the .stem used below
import collections
import sys
import uuid
import random
import warnings
from itertools import chain
from tqdm import tqdm

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib import rcParams
sns.set(rc={"font.size":9,"axes.titlesize":15,"axes.labelsize":9,
            "axes.titlepad":2, "axes.labelpad":9, "legend.fontsize":7,
            "legend.title_fontsize":7, 'axes.grid' : False,
           'figure.titlesize':35})

# from skimage import measure
from PIL import Image
import cv2
# from skimage.io import imread, imshow, imread_collection, concatenate_images
# from skimage.transform import resize
# from skimage.morphology import label

import tensorflow as tf
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import Input, Layer
from tensorflow.keras.layers import Conv2D, Conv2DTranspose
from tensorflow.keras.layers import MaxPooling2D, UpSampling2D
from tensorflow.keras.layers import concatenate
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from tensorflow.keras import backend as K
from tensorflow.keras import layers
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import binary_crossentropy
# Use tf.keras throughout; keras.engine.topology does not exist in recent Keras versions
from tensorflow.keras.utils import get_custom_objects

from kaggle_datasets import KaggleDatasets
from kaggle_secrets import UserSecretsClient

from sklearn.model_selection import train_test_split
import plotly
import plotly.graph_objs as go


## Read Images and Masks

The [HuBMAP: 512x512 full size tiles](https://www.kaggle.com/xhlulu/hubmap-512x512-full-size-tiles) dataset by xhlulu is used as the base before augmentation.

In [None]:
# image_paths = glob.glob("/kaggle/input/hubmap-1024x1024/train/*.png")
# mask_paths = glob.glob("/kaggle/input/hubmap-1024x1024/masks/*.png")
# len(image_paths)

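# glob returns files in arbitrary order; sort both lists so each image lines up with its mask
image_paths = sorted(glob.glob("/kaggle/input/hubmap-512x512-full-size-tiles/train/*.png"))
mask_paths = sorted(glob.glob("/kaggle/input/hubmap-512x512-full-size-tiles/masks/*.png"))
len(image_paths)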

In [None]:
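# Augmentation is done in two parts because of hard disk limits and the large dataset size.
# Part 1: the first 11607 tiles
image_paths = image_paths[:11607]
mask_paths = mask_paths[:11607]
len(image_paths)
# Part 2: the remaining tiles. Slices are half-open, so part 2 starts at 11607 (not 11608)
# image_paths = image_paths[11607:]
# mask_paths = mask_paths[11607:]
# len(image_paths)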

## Helper Functions

In [None]:
def read_single(img_path, msk_path):
    """ Read the image (BGR channel order, as loaded by OpenCV) and the grayscale mask.
    cv2.imread returns None if a file cannot be read. """
    image = cv2.imread(img_path, cv2.IMREAD_COLOR)
    mask = cv2.imread(msk_path, cv2.IMREAD_GRAYSCALE)
    return image, mask

def read_data(image_paths, mask_paths, gloms_only=False):
    images = []
    masks = []

    for img_path, msk_path in tqdm(zip(image_paths, mask_paths), total=len(image_paths)):

        image, mask = read_single(img_path, msk_path)
        mask_density = np.count_nonzero(mask)   
        if gloms_only:
            if(mask_density>0):
                images.append(image)
                masks.append(mask)
        else:
            images.append(image)
            masks.append(mask)

    images = np.array(images)
    masks = np.array(masks)
    print('images shape:', images.shape)
    print('masks shape:', masks.shape)
    return images, masks

In [None]:
from google.cloud import storage
storage_client = storage.Client(project='placesproject-284409')

def create_bucket(dataset_name):
    """Creates a new bucket. https://cloud.google.com/storage/docs/ """
    bucket = storage_client.create_bucket(dataset_name)
    print('Bucket {} created'.format(bucket.name))

def upload_blob(bucket_name, source_file_name, destination_blob_name):
    """Uploads a file to the bucket. https://cloud.google.com/storage/docs/ """
    bucket = storage_client.get_bucket(bucket_name)
    blob = bucket.blob(destination_blob_name)
    blob.upload_from_filename(source_file_name)
#     print('File {} uploaded to {}.'.format(
#         source_file_name,
#         destination_blob_name))
    
def list_blobs(bucket_name):
    """Lists all the blobs in the bucket. https://cloud.google.com/storage/docs/"""
    blob_list = []
    blobs = storage_client.list_blobs(bucket_name)
    for blob in blobs:
        blob_list.append(blob.name)
    #print(blob_list)
    return blob_list
        
def download_to_kaggle(bucket_name, destination_directory, file_name):
    """Download `file_name` from your GCS bucket into the working directory of your Kaggle notebook"""
    os.makedirs(destination_directory, exist_ok=True)
    full_file_path = os.path.join(destination_directory, file_name)
    # Fetch only the requested blob (iterating over every blob would overwrite
    # full_file_path with each file in the bucket in turn)
    blob = storage_client.bucket(bucket_name).blob(file_name)
    blob.download_to_filename(full_file_path)

In [None]:
# Set your own project id here
PROJECT_ID = 'placesproject-284409'
from google.cloud import storage
storage_client = storage.Client(project=PROJECT_ID)
user_secrets = UserSecretsClient()
user_credential = user_secrets.get_gcloud_credential()
user_secrets.set_tensorflow_credential(user_credential)

In [None]:
lowband_density_values = []
mask_density_values = []

for img_path, msk_path in tqdm(zip(image_paths, mask_paths), total=len(image_paths)):
    image, mask = read_single(img_path, msk_path)
    # np.histogram uses 10 equal-width bins over this tile's own [min, max] range;
    # sum the counts of the 4 lowest (darkest) bins
    img_hist = np.histogram(image)
    lowband_density = np.sum(img_hist[0][0:4])
    # Number of glomerulus (non-zero) pixels in the mask
    mask_density = np.count_nonzero(mask)
    lowband_density_values.append(lowband_density)
    mask_density_values.append(mask_density)

train_helper_df = pd.DataFrame(
    data=list(zip(image_paths, mask_paths, lowband_density_values, mask_density_values)),
    columns=['image_path', 'mask_path', 'lowband_density', 'mask_density'])
# astype returns a new DataFrame, so the result must be assigned
train_helper_df = train_helper_df.astype(dtype={'image_path': 'object', 'mask_path': 'object',
                                                'lowband_density': 'int64', 'mask_density': 'int64'})

In [None]:
# bucket_name = 'hubmap_512x512_with_aug'
# # train_helper_df.to_csv('./hubmap_dataset_helper.csv')
# # upload_blob(bucket_name, './hubmap_dataset_helper.csv', 'hubmap_dataset_helper.csv')

# download_to_kaggle(bucket_name, '/kaggle/working', 'hubmap_dataset_helper.csv')
# train_helper_df = pd.read_csv('hubmap_dataset_helper.csv')

In [None]:
train_helper_df.sample(5)

## Selecting Images with Tissue
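The selection below keeps tiles whose `lowband_density` exceeds 100. With the default `bins=10`, `np.histogram` divides each tile's own [min, max] intensity range into ten equal-width bins, so `lowband_density` counts the pixel values (across all three channels) that fall in the darkest 40% of that tile's range. Below is a small NumPy sketch on a synthetic tile; the helper name and the toy intensities are illustrative assumptions.

```python
import numpy as np

def lowband_density(image, n_low_bins=4):
    """Count values in the lowest `n_low_bins` of np.histogram's default 10 bins,
    which span this image's own [min, max] range."""
    counts, _edges = np.histogram(image)  # bins=10 by default
    return int(np.sum(counts[:n_low_bins]))

# Synthetic tile: bright background (230) with a darker 20x20 patch (60)
tile = np.full((64, 64, 3), 230, dtype=np.uint8)
tile[10:30, 10:30] = 60
print(lowband_density(tile))  # 1200 = 20 * 20 pixels * 3 channels
```

Because the bins are relative to each tile's own range, the score reflects contrast within the tile rather than absolute darkness; a perfectly uniform tile scores 0.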

In [None]:
images_tissue = train_helper_df[train_helper_df.lowband_density>100].image_path
masks_tissue = train_helper_df[train_helper_df.lowband_density>100].mask_path
images_tissue.shape

## Visualize Samples

In [None]:
images, masks = read_data(images_tissue.iloc[1200:1218], masks_tissue.iloc[1200:1218])

In [None]:
max_rows = 6
max_cols = 6
fig, ax = plt.subplots(max_rows, max_cols, figsize=(20,18))
fig.suptitle('Sample Images', y=0.93)
plot_count = (max_rows*max_cols)//2
for idx, (img, mas) in enumerate(zip(images[:plot_count], masks[:plot_count])):
    row = (idx//max_cols)*2
    row_masks = row+1
    col = idx % max_cols
    ax[row, col].imshow(img)
    ax[row_masks, col].imshow(mas)


## Data Preparation
## Augmentation

Augmentation is applied only to images containing gloms, since the dataset is already large.
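In `augment_data` each transform is called as `aug(image=x, mask=y)`. Albumentations applies spatial transforms (such as HorizontalFlip, ElasticTransform, GridDistortion and OpticalDistortion) to both the image and the mask so the annotation stays aligned, while pixel-level transforms (such as CLAHE, brightness/contrast, gamma, HSV, RGB shift, blurs, noise and channel shuffle) change only the image. The NumPy sketch below illustrates the two cases; the helper names and toy arrays are hypothetical and this is not Albumentations' own code.

```python
import numpy as np

def hflip(image, mask):
    # Spatial transform: flip image AND mask so the annotation stays aligned
    return image[:, ::-1], mask[:, ::-1]

def brighten(image, mask, delta=40):
    # Pixel-level transform: change image intensities only; mask is returned unchanged
    out = np.clip(image.astype(np.int16) + delta, 0, 255).astype(np.uint8)
    return out, mask

# Toy 8x8 tile with a 2x2 bright "glomerulus" on the left edge
image = np.zeros((8, 8, 3), dtype=np.uint8)
mask = np.zeros((8, 8), dtype=np.uint8)
image[2:4, 0:2] = 200
mask[2:4, 0:2] = 255

fi, fm = hflip(image, mask)
bi, bm = brighten(image, mask)
print(np.argwhere(fm)[:, 1].min())  # 6 -> the object moved to the right edge
```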

The validation samples are split off and set aside before augmentation, so that no augmented copy of a validation tile can leak into the training data.

In [None]:
# Note: despite the "_90_per" variable names, test_size=0.30 keeps 70% of the tiles for training
image_90_per_tissues, image_val_files, mask_90_per_tissues, mask_val_files = train_test_split(images_tissue, masks_tissue, test_size=0.30, random_state=17)
print("Split Counts\n\tImage_90_per_files:\t{0}\n\tMask_90_per_files:\t{2}\n\tVal Images:\t\t{1}\n\tVal Masks:\t\t{3}\n"
      .format(len(image_90_per_tissues), len(image_val_files), len(mask_90_per_tissues), len(mask_val_files)))

In [None]:
from albumentations import (
    CLAHE,
    ElasticTransform,
    GridDistortion,
    OpticalDistortion,
    HorizontalFlip,
    RandomBrightnessContrast,
    RandomGamma,
    HueSaturationValue,
    RGBShift,
    MedianBlur,
    GaussianBlur,
    GaussNoise,
    ChannelShuffle,
    CoarseDropout
)

def augment_data(image_paths, mask_paths):
    """Write 14 augmented copies of every tile whose mask contains gloms.

    Each copy is saved as <stem>_<idx>.png, where idx is the position of the
    transform in `augmentations` (0 = CLAHE, ..., 13 = CoarseDropout).
    Returns the lists of written image and mask paths.
    """
    img_out_dir = 'hubmap_512x512_augmented/images_aug2'
    msk_out_dir = 'hubmap_512x512_augmented/masks_aug2'
    os.makedirs(img_out_dir, exist_ok=True)
    os.makedirs(msk_out_dir, exist_ok=True)

    # The list index becomes the file-name suffix used in the visualizations below
    augmentations = [
        CLAHE(clip_limit=1.0, tile_grid_size=(8, 8), always_apply=False, p=1),    # _0
        ElasticTransform(p=1, alpha=120, sigma=512*0.05, alpha_affine=512*0.03),  # _1
        GridDistortion(p=1),                                                      # _2
        OpticalDistortion(p=1, distort_limit=2, shift_limit=0.5),                 # _3
        HorizontalFlip(p=1),                                                      # _4
        RandomBrightnessContrast(p=1),                                            # _5
        RandomGamma(p=1),                                                         # _6
        HueSaturationValue(p=1),                                                  # _7
        RGBShift(p=1),                                                            # _8
        MedianBlur(p=1, blur_limit=5),                                            # _9
        GaussianBlur(p=1, blur_limit=3),                                          # _10
        GaussNoise(p=1),                                                          # _11
        ChannelShuffle(p=1),                                                      # _12
        CoarseDropout(p=1, max_holes=8, max_height=32, max_width=32),             # _13
    ]

    written_images, written_masks = [], []
    for image, mask in tqdm(zip(image_paths, mask_paths), total=len(image_paths)):
        image_name = Path(image).stem
        mask_name = Path(mask).stem

        x, y = read_single(image, mask)

        ## Augmenting only images with Gloms
        if np.count_nonzero(y) == 0:
            continue

        if x is None:
            # cv2.imread returns None for an unreadable path; as before,
            # retry once with the last character of the path removed
            x, y = read_single(image[:-1], mask)
            if x is None:
                raise FileNotFoundError(f"Could not read image {image}")

        # Every transform is applied to the original tile (they are not chained);
        # random parameters are re-sampled on each call
        for idx, aug in enumerate(augmentations):
            augmented = aug(image=x, mask=y)
            image_path = os.path.join(img_out_dir, f"{image_name}_{idx}.png")
            mask_path = os.path.join(msk_out_dir, f"{mask_name}_{idx}.png")
            cv2.imwrite(image_path, augmented['image'])
            cv2.imwrite(mask_path, augmented['mask'])
            written_images.append(image_path)
            written_masks.append(mask_path)

    return written_images, written_masks

images_aug, masks_aug = augment_data(image_90_per_tissues, mask_90_per_tissues)

## Load Augmented Dataset
Load the augmented images from the public augmented dataset linked at the top of this notebook.
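Because `glob` returns files in arbitrary order, the image and mask lists are only usable together if they line up by file name. Below is a small sanity check. It assumes each augmented mask has the same file name as its image, which holds when the source image and mask stems match. `check_pairs` is a hypothetical helper, not part of any library.

```python
import os

def check_pairs(img_paths, msk_paths):
    """Raise ValueError if the lists differ in length or are not aligned by file name."""
    if len(img_paths) != len(msk_paths):
        raise ValueError(f"{len(img_paths)} images but {len(msk_paths)} masks")
    for i, (img, msk) in enumerate(zip(img_paths, msk_paths)):
        if os.path.basename(img) != os.path.basename(msk):
            raise ValueError(f"pair {i} mismatched: {img} vs {msk}")

imgs = sorted(["/d/images_aug/b_1.png", "/d/images_aug/a_0.png"])
msks = sorted(["/d/masks_aug/a_0.png", "/d/masks_aug/b_1.png"])
check_pairs(imgs, msks)  # passes: both lists sort to a_0.png, b_1.png
```

In the notebook this would be called as `check_pairs(aug_img_paths, aug_msk_paths)` after the paths are loaded.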

In [None]:
# Sort each list so augmented images and masks pair up by file name
aug_img_paths = sorted(glob.glob("/kaggle/input/hubmap-512x512-augmented/images_aug/*.png"))
aug_msk_paths = sorted(glob.glob("/kaggle/input/hubmap-512x512-augmented/masks_aug/*.png"))
aug_img_paths2 = sorted(glob.glob("/kaggle/input/hubmap-512x512-augmented/images_aug2/*.png"))
aug_msk_paths2 = sorted(glob.glob("/kaggle/input/hubmap-512x512-augmented/masks_aug2/*.png"))

aug_img_paths.extend(aug_img_paths2)
aug_msk_paths.extend(aug_msk_paths2)
print("Number of Augmented Images", len(aug_img_paths))
print("Number of Augmented Masks", len(aug_msk_paths))

## Visualize Augmented Images
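The cells below filter the augmented files by their `_<idx>.png` suffix and plot one transform at a time. A lookup table from suffix to transform name plus an end-anchored filter would let a single loop drive all fourteen plots. `AUG_NAMES` and `select_by_suffix` are hypothetical names. The order of `AUG_NAMES` assumes the transform order used in `augment_data`.

```python
# Index == file-name suffix written by augment_data (assumed order)
AUG_NAMES = [
    "CLAHE", "ElasticTransform", "GridDistortion", "OpticalDistortion",
    "HorizontalFlip", "RandomBrightnessContrast", "RandomGamma",
    "HueSaturationValue", "RGBShift", "MedianBlur", "GaussianBlur",
    "GaussNoise", "ChannelShuffle", "CoarseDropout",
]

def select_by_suffix(paths, idx):
    """Keep only paths whose file name ends in _<idx>.png."""
    suffix = f"_{idx}.png"
    return [p for p in paths if p.endswith(suffix)]

paths = ["a_1.png", "a_11.png", "b_1.png", "b_13.png"]
print(select_by_suffix(paths, 1))  # ['a_1.png', 'b_1.png']
print(AUG_NAMES[13])               # CoarseDropout
```

In the notebook, `for idx, name in enumerate(AUG_NAMES):` could call `read_data(select_by_suffix(aug_img_paths, idx), select_by_suffix(aug_msk_paths, idx))` and reuse the existing plotting code with `fig.suptitle(name, y=0.95)`.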

In [None]:
aug_img_paths = aug_img_paths[-100:]
aug_msk_paths = aug_msk_paths[-100:]
aug_imgs, aug_msks = read_data(aug_img_paths, aug_msk_paths)

In [None]:
max_rows = 10
max_cols = 4
fig, ax = plt.subplots(max_rows, max_cols, figsize=(20,32))
plot_count = (max_rows*max_cols)//2
for idx, (img, mas) in enumerate(zip(aug_imgs[:plot_count], aug_msks[:plot_count])):
    row = (idx//max_cols)*2
    row_masks = row+1
    col = idx % max_cols
    ax[row, col].imshow(img)
    ax[row_masks, col].imshow(mas)

In [None]:
sel_img_paths = [img_path for img_path in aug_img_paths if '_0.png' in img_path]
sel_msk_paths = [msk_path for msk_path in aug_msk_paths if '_0.png' in msk_path]
aug_imgs, aug_msks = read_data(sel_img_paths, sel_msk_paths)

In [None]:
max_rows = 2
max_cols = 2
fig, ax = plt.subplots(max_rows, max_cols, figsize=(20,9))
fig.suptitle("CLAHE", y=0.95)
plot_count = (max_rows*max_cols)//2
for idx, (img, mas) in enumerate(zip(aug_imgs[:plot_count], aug_msks[:plot_count])):
    row = (idx//max_cols)*2
    row_masks = row+1
    col = idx % max_cols
    ax[row, col].imshow(img)
    ax[row_masks, col].imshow(mas)

In [None]:
sel_img_paths = [img_path for img_path in aug_img_paths if '_1.png' in img_path]
sel_msk_paths = [msk_path for msk_path in aug_msk_paths if '_1.png' in msk_path]
aug_imgs, aug_msks = read_data(sel_img_paths, sel_msk_paths)

In [None]:
max_rows = 2
max_cols = 2
fig, ax = plt.subplots(max_rows, max_cols, figsize=(20,9))
fig.suptitle("ElasticTransform", y=0.95)
plot_count = (max_rows*max_cols)//2
for idx, (img, mas) in enumerate(zip(aug_imgs[:plot_count], aug_msks[:plot_count])):
    row = (idx//max_cols)*2
    row_masks = row+1
    col = idx % max_cols
    ax[row, col].imshow(img)
    ax[row_masks, col].imshow(mas)

In [None]:
sel_img_paths = [img_path for img_path in aug_img_paths if '_2.png' in img_path]
sel_msk_paths = [msk_path for msk_path in aug_msk_paths if '_2.png' in msk_path]
aug_imgs, aug_msks = read_data(sel_img_paths, sel_msk_paths)

In [None]:
max_rows = 2
max_cols = 2
fig, ax = plt.subplots(max_rows, max_cols, figsize=(20,9))
fig.suptitle("GridDistortion", y=0.95)
plot_count = (max_rows*max_cols)//2
for idx, (img, mas) in enumerate(zip(aug_imgs[:plot_count], aug_msks[:plot_count])):
    row = (idx//max_cols)*2
    row_masks = row+1
    col = idx % max_cols
    ax[row, col].imshow(img)
    ax[row_masks, col].imshow(mas)

In [None]:
sel_img_paths = [img_path for img_path in aug_img_paths if '_3.png' in img_path]
sel_msk_paths = [msk_path for msk_path in aug_msk_paths if '_3.png' in msk_path]
aug_imgs, aug_msks = read_data(sel_img_paths, sel_msk_paths)

In [None]:
max_rows = 2
max_cols = 2
fig, ax = plt.subplots(max_rows, max_cols, figsize=(20,9))
fig.suptitle("OpticalDistortion", y=0.95)
plot_count = (max_rows*max_cols)//2
for idx, (img, mas) in enumerate(zip(aug_imgs[:plot_count], aug_msks[:plot_count])):
    row = (idx//max_cols)*2
    row_masks = row+1
    col = idx % max_cols
    ax[row, col].imshow(img)
    ax[row_masks, col].imshow(mas)

In [None]:
sel_img_paths = [img_path for img_path in aug_img_paths if '_4.png' in img_path]
sel_msk_paths = [msk_path for msk_path in aug_msk_paths if '_4.png' in msk_path]
aug_imgs, aug_msks = read_data(sel_img_paths, sel_msk_paths)

In [None]:
max_rows = 2
max_cols = 2
fig, ax = plt.subplots(max_rows, max_cols, figsize=(20,9))
fig.suptitle("HorizontalFlip", y=0.95)
plot_count = (max_rows*max_cols)//2
for idx, (img, mas) in enumerate(zip(aug_imgs[:plot_count], aug_msks[:plot_count])):
    row = (idx//max_cols)*2
    row_masks = row+1
    col = idx % max_cols
    ax[row, col].imshow(img)
    ax[row_masks, col].imshow(mas)

In [None]:
sel_img_paths = [img_path for img_path in aug_img_paths if '_5.png' in img_path]
sel_msk_paths = [msk_path for msk_path in aug_msk_paths if '_5.png' in msk_path]
aug_imgs, aug_msks = read_data(sel_img_paths, sel_msk_paths)

In [None]:
max_rows = 2
max_cols = 2
fig, ax = plt.subplots(max_rows, max_cols, figsize=(20,9))
fig.suptitle("RandomBrightnessContrast", y=0.95)
plot_count = (max_rows*max_cols)//2
for idx, (img, mas) in enumerate(zip(aug_imgs[:plot_count], aug_msks[:plot_count])):
    row = (idx//max_cols)*2
    row_masks = row+1
    col = idx % max_cols
    ax[row, col].imshow(img)
    ax[row_masks, col].imshow(mas)

In [None]:
sel_img_paths = [img_path for img_path in aug_img_paths if '_6.png' in img_path]
sel_msk_paths = [msk_path for msk_path in aug_msk_paths if '_6.png' in msk_path]
aug_imgs, aug_msks = read_data(sel_img_paths, sel_msk_paths)

In [None]:
max_rows = 2
max_cols = 2
fig, ax = plt.subplots(max_rows, max_cols, figsize=(20,9))
fig.suptitle("RandomGamma", y=0.95)
plot_count = (max_rows*max_cols)//2
for idx, (img, mas) in enumerate(zip(aug_imgs[:plot_count], aug_msks[:plot_count])):
    row = (idx//max_cols)*2
    row_masks = row+1
    col = idx % max_cols
    ax[row, col].imshow(img)
    ax[row_masks, col].imshow(mas)

In [None]:
sel_img_paths = [img_path for img_path in aug_img_paths if '_7.png' in img_path]
sel_msk_paths = [msk_path for msk_path in aug_msk_paths if '_7.png' in msk_path]
aug_imgs, aug_msks = read_data(sel_img_paths, sel_msk_paths)

In [None]:
max_rows = 2
max_cols = 2
fig, ax = plt.subplots(max_rows, max_cols, figsize=(20,9))
fig.suptitle("HueSaturationValue", y=0.95)
plot_count = (max_rows*max_cols)//2
for idx, (img, mas) in enumerate(zip(aug_imgs[:plot_count], aug_msks[:plot_count])):
    row = (idx//max_cols)*2
    row_masks = row+1
    col = idx % max_cols
    ax[row, col].imshow(img)
    ax[row_masks, col].imshow(mas)

In [None]:
sel_img_paths = [img_path for img_path in aug_img_paths if '_8.png' in img_path]
sel_msk_paths = [msk_path for msk_path in aug_msk_paths if '_8.png' in msk_path]
aug_imgs, aug_msks = read_data(sel_img_paths, sel_msk_paths)

In [None]:
max_rows = 2
max_cols = 2
fig, ax = plt.subplots(max_rows, max_cols, figsize=(20,9))
fig.suptitle("RGBShift", y=0.95)
plot_count = (max_rows*max_cols)//2
for idx, (img, mas) in enumerate(zip(aug_imgs[:plot_count], aug_msks[:plot_count])):
    row = (idx//max_cols)*2
    row_masks = row+1
    col = idx % max_cols
    ax[row, col].imshow(img)
    ax[row_masks, col].imshow(mas)

In [None]:
sel_img_paths = [img_path for img_path in aug_img_paths if '_9.png' in img_path]
sel_msk_paths = [msk_path for msk_path in aug_msk_paths if '_9.png' in msk_path]
aug_imgs, aug_msks = read_data(sel_img_paths, sel_msk_paths)

In [None]:
max_rows = 2
max_cols = 2
fig, ax = plt.subplots(max_rows, max_cols, figsize=(20,9))
fig.suptitle("MedianBlur", y=0.95)
plot_count = (max_rows*max_cols)//2
for idx, (img, mas) in enumerate(zip(aug_imgs[:plot_count], aug_msks[:plot_count])):
    row = (idx//max_cols)*2
    row_masks = row+1
    col = idx % max_cols
    ax[row, col].imshow(img)
    ax[row_masks, col].imshow(mas)

In [None]:
sel_img_paths = [img_path for img_path in aug_img_paths if '_10.png' in img_path]
sel_msk_paths = [msk_path for msk_path in aug_msk_paths if '_10.png' in msk_path]
aug_imgs, aug_msks = read_data(sel_img_paths, sel_msk_paths)

In [None]:
max_rows = 2
max_cols = 2
fig, ax = plt.subplots(max_rows, max_cols, figsize=(20,9))
fig.suptitle("GaussianBlur", y=0.95)
plot_count = (max_rows*max_cols)//2
for idx, (img, mas) in enumerate(zip(aug_imgs[:plot_count], aug_msks[:plot_count])):
    row = (idx//max_cols)*2
    row_masks = row+1
    col = idx % max_cols
    ax[row, col].imshow(img)
    ax[row_masks, col].imshow(mas)

In [None]:
sel_img_paths = [img_path for img_path in aug_img_paths if '_11.png' in img_path]
sel_msk_paths = [msk_path for msk_path in aug_msk_paths if '_11.png' in msk_path]
aug_imgs, aug_msks = read_data(sel_img_paths, sel_msk_paths)

In [None]:
max_rows = 2
max_cols = 2
fig, ax = plt.subplots(max_rows, max_cols, figsize=(20,9))
fig.suptitle("GaussNoise", y=0.95)
plot_count = (max_rows*max_cols)//2
for idx, (img, mas) in enumerate(zip(aug_imgs[:plot_count], aug_msks[:plot_count])):
    row = (idx//max_cols)*2
    row_masks = row+1
    col = idx % max_cols
    ax[row, col].imshow(img)
    ax[row_masks, col].imshow(mas)

In [None]:
sel_img_paths = [img_path for img_path in aug_img_paths if '_12.png' in img_path]
sel_msk_paths = [msk_path for msk_path in aug_msk_paths if '_12.png' in msk_path]
aug_imgs, aug_msks = read_data(sel_img_paths, sel_msk_paths)

In [None]:
max_rows = 2
max_cols = 2
fig, ax = plt.subplots(max_rows, max_cols, figsize=(20,9))
fig.suptitle("ChannelShuffle", y=0.95)
plot_count = (max_rows*max_cols)//2
for idx, (img, mas) in enumerate(zip(aug_imgs[:plot_count], aug_msks[:plot_count])):
    row = (idx//max_cols)*2
    row_masks = row+1
    col = idx % max_cols
    ax[row, col].imshow(img)
    ax[row_masks, col].imshow(mas)

In [None]:
sel_img_paths = [img_path for img_path in aug_img_paths if '_13.png' in img_path]
sel_msk_paths = [msk_path for msk_path in aug_msk_paths if '_13.png' in msk_path]
aug_imgs, aug_msks = read_data(sel_img_paths, sel_msk_paths)

In [None]:
max_rows = 2
max_cols = 2
fig, ax = plt.subplots(max_rows, max_cols, figsize=(20,9))
fig.suptitle("CoarseDropout", y=0.95)
plot_count = (max_rows*max_cols)//2
for idx, (img, mas) in enumerate(zip(aug_imgs[:plot_count], aug_msks[:plot_count])):
    row = (idx//max_cols)*2
    row_masks = row+1
    col = idx % max_cols
    ax[row, col].imshow(img)
    ax[row_masks, col].imshow(mas)


## Merge Augmented Data

In [None]:
images_tissue = train_helper_df[train_helper_df.lowband_density>100].image_path
masks_tissue = train_helper_df[train_helper_df.lowband_density>100].mask_path

print("Number of images with tissue:", images_tissue.shape)
print("Number of masks with tissue:", masks_tissue.shape)

In [None]:
# Sort each list so augmented images and masks pair up by file name
aug_img_paths = sorted(glob.glob("/kaggle/input/hubmap-512x512-augmented/images_aug/*.png"))
aug_msk_paths = sorted(glob.glob("/kaggle/input/hubmap-512x512-augmented/masks_aug/*.png"))
aug_img_paths2 = sorted(glob.glob("/kaggle/input/hubmap-512x512-augmented/images_aug2/*.png"))
aug_msk_paths2 = sorted(glob.glob("/kaggle/input/hubmap-512x512-augmented/masks_aug2/*.png"))

aug_img_paths.extend(aug_img_paths2)
aug_msk_paths.extend(aug_msk_paths2)

print("Number of Augmented Images", len(aug_img_paths))
print("Number of Augmented Masks", len(aug_msk_paths))

In [None]:
val_files = pd.read_csv('../input/hubmap-512x512-augmented/validation_files.csv')
val_files.sample(3)

In [None]:
image_val_files = val_files['image_val_files'].tolist()
mask_val_files = val_files['mask_val_files'].tolist()
print("Total Val Image Count:", len(image_val_files))
print("Total Val Mask Count:", len(mask_val_files))
image_val_files[:2]

### Note: Augmented images are added only to the train set, never to the validation set, to prevent validation data from leaking into the training data
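One way to confirm this property is to strip the `_<idx>` suffix that `augment_data` appends to each file name and intersect the result with the validation file names. The helper names are hypothetical. The check assumes augmented files are named `<source stem>_<idx>.png`, as `augment_data` writes them.

```python
import os
import re

# augment_data saves copies as <source stem>_<idx>.png
_AUG_SUFFIX = re.compile(r"_\d+$")

def _stem(path):
    return os.path.splitext(os.path.basename(path))[0]

def leaked_sources(val_paths, aug_train_paths):
    """Return validation stems that are also the source of an augmented training tile."""
    val_stems = {_stem(p) for p in val_paths}
    # Remove only the trailing _<idx> that augment_data added
    aug_sources = {_AUG_SUFFIX.sub("", _stem(p), count=1) for p in aug_train_paths}
    return sorted(val_stems & aug_sources)

val = ["/data/train/aaa_10.png", "/data/train/bbb_3.png"]
aug = ["/aug/images_aug2/ccc_7_0.png", "/aug/images_aug2/bbb_3_12.png"]
print(leaked_sources(val, aug))  # ['bbb_3']
```

Here `leaked_sources(image_val_files, aug_img_paths)` should return an empty list; any stem it reports is a validation tile whose augmented copies ended up in training.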

In [None]:
# image_90_per_tissues, image_val_files, mask_90_per_tissues, mask_val_files = train_test_split(images_tissue, masks_tissue, test_size=0.35, random_state=13)
# print("Split Counts\n\tImage_90_per_files:\t{0}\n\tMask_90_per_files:\t{2}\n\tVal Images:\t\t{1}\n\tVal Masks:\t\t{3}\n"
#       .format(len(image_90_per_tissues), len(image_val_files), len(mask_90_per_tissues), len(mask_val_files)))


In [None]:
image_90_per_tissues = image_90_per_tissues.tolist()
image_90_per_tissues.extend(aug_img_paths)
mask_90_per_tissues = mask_90_per_tissues.tolist()
mask_90_per_tissues.extend(aug_msk_paths)
print("Total Train Image Count:", len(image_90_per_tissues))
print("Total Train Mask Count:", len(mask_90_per_tissues))

In [None]:
image_45_per_files1, image_45_per_files2, mask_45_per_files1, mask_45_per_files2 = train_test_split(image_90_per_tissues, mask_90_per_tissues, test_size=0.5, random_state=13)
print("Split Counts\n\timage_45_per_files1:\t{0}\n\tmask_45_per_files1:\t{2}\n\timage_45_per_files2:\t{1}\n\tmask_45_per_files2:\t{3}\n"
      .format(len(image_45_per_files1), len(image_45_per_files2), len(mask_45_per_files1), len(mask_45_per_files2)))

## Build TFRecords
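The records below store images and masks as raw `uint8` bytes. Raw bytes carry no shape or dtype, so the reader must know both. That is why the parsing function later reshapes the decoded bytes to `(512, 512, 3)`. Note also that images keep OpenCV's BGR channel order. Here is a NumPy sketch of the encode/decode round trip; `encode` and `decode` are hypothetical helpers that mirror `image_example` and `tf.io.decode_raw` plus `tf.reshape`.

```python
import numpy as np

def encode(image, mask):
    # Mirrors image_example: raw uint8 buffers, with the shape kept separately
    return image.tobytes(), mask.tobytes(), image.shape

def decode(img_bytes, mask_bytes, shape):
    # NumPy analogue of tf.io.decode_raw(..., out_type='uint8') followed by tf.reshape
    h, w, c = shape
    image = np.frombuffer(img_bytes, dtype=np.uint8).reshape(h, w, c)
    mask = np.frombuffer(mask_bytes, dtype=np.uint8).reshape(h, w, 1)
    return image, mask

rng = np.random.default_rng(0)
img = rng.integers(0, 256, size=(512, 512, 3), dtype=np.uint8)
msk = (rng.random((512, 512, 1)) > 0.9).astype(np.uint8) * 255
img2, msk2 = decode(*encode(img, msk))
print(np.array_equal(img, img2), np.array_equal(msk, msk2))  # True True
```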

In [None]:
def _bytes_feature(value):
    """Returns a bytes_list from a string / byte."""
    if isinstance(value, type(tf.constant(0))):
        value = value.numpy() # BytesList won't unpack a string from an EagerTensor.
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _float_feature(value):
    """Returns a float_list from a float / double."""
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

def _int64_feature(value):
    """Returns an int64_list from a bool / enum / int / uint."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

In [None]:
def image_example(image, mask):
    """Build a tf.train.Example holding the raw uint8 bytes of an image and its mask."""
    image_shape = image.shape
    # ndarray.tostring() is a deprecated alias of tobytes()
    img_bytes = image.tobytes()
    mask_bytes = mask.tobytes()
    feature = {
        'height': _int64_feature(image_shape[0]),
        'width': _int64_feature(image_shape[1]),
        'num_channels': _int64_feature(image_shape[2]),
        'img_bytes': _bytes_feature(img_bytes),
        'mask' : _bytes_feature(mask_bytes),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))

def create_tfrecord(image, mask, output_path):
    """Write a single image/mask pair to a GZIP-compressed TFRecord file."""
    opts = tf.io.TFRecordOptions(compression_type="GZIP")
    # The context manager closes the writer, so no explicit close() is needed
    with tf.io.TFRecordWriter(output_path, opts) as writer:
        tf_example = image_example(image, mask)
        writer.write(tf_example.SerializeToString())

In [None]:
def write_dataset(img_files, msk_files, records_per_part, prefix):
    """Serialize image/mask pairs into GZIP-compressed TFRecord shards holding
    `records_per_part` examples each: <prefix>_part0.tfrecords, <prefix>_part1.tfrecords, ..."""
    opts = tf.io.TFRecordOptions(compression_type="GZIP")
    part_num = 0
    num_records = 0
    output_path = prefix + '_part{}.tfrecords'.format(part_num)
    writer = tf.io.TFRecordWriter(output_path, opts)

    for img_file, msk_file in tqdm(zip(img_files, msk_files), total=len(img_files), position=0, leave=True):
        image, mask = read_single(img_file, msk_file)
        assert image.shape == (512, 512, 3), f"Wrong image shape {image.shape} for {img_file}"
        assert mask.shape == (512, 512), f"Wrong mask shape {mask.shape} for {msk_file}"
        mask = np.expand_dims(mask, axis=-1)  # (512, 512) -> (512, 512, 1)
        tf_example = image_example(image, mask)
        writer.write(tf_example.SerializeToString())
        num_records += 1
        if num_records == records_per_part:
            # Shard is full: close it and open the next one
            print("wrote part #{}".format(part_num))
            writer.close()
            part_num += 1
            output_path = prefix + '_part{}.tfrecords'.format(part_num)
            writer = tf.io.TFRecordWriter(output_path, opts)
            num_records = 0
    # Close the last shard (partial, or empty if the count divides evenly)
    writer.close()

In [None]:
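bucket_name = 'hubmap_train_job_2'
try:
    create_bucket(bucket_name)
except Exception as e:
    # Most likely the bucket already exists from a previous run
    print("Bucket not created:", e)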

In [None]:
if not os.path.exists('train_tfrecords'):
    os.makedirs('train_tfrecords')

print("Writing Train Dataset")
write_dataset(image_45_per_files1, mask_45_per_files1, 256, '/kaggle/working/train_tfrecords/train1')

In [None]:
!ls -lah ./train_tfrecords | wc -l

In [None]:
files = glob.glob("/kaggle/working/train_tfrecords/*.tfrecords")
for file in files:
    file_name = os.path.join('train', os.path.basename(Path(file)))
    print(file_name)
    upload_blob(bucket_name, file, file_name)

In [None]:
!rm -R ./train_tfrecords

In [None]:
if not os.path.exists('train_tfrecords'):
    os.makedirs('train_tfrecords')
    
write_dataset(image_45_per_files2, mask_45_per_files2, 256, '/kaggle/working/train_tfrecords/train2')

In [None]:
!ls -lah ./train_tfrecords

In [None]:
files = glob.glob("/kaggle/working/train_tfrecords/*.tfrecords")
for file in files:
    file_name = os.path.join('train', os.path.basename(Path(file)))
    print(file_name)
    upload_blob(bucket_name, file, file_name)

In [None]:
!rm -R ./train_tfrecords

In [None]:
if not os.path.exists('val_tfrecords'):
    os.makedirs('val_tfrecords')
    
print("Writing Validation Dataset")
write_dataset(image_val_files, mask_val_files, 256, '/kaggle/working/val_tfrecords/val')

In [None]:
!ls ./val_tfrecords/*.tfrecords | wc -l

In [None]:
files = glob.glob("/kaggle/working/val_tfrecords/*.tfrecords")
for file in files:
    file_name = os.path.join('val', os.path.basename(Path(file)))
    print(file_name)
    upload_blob(bucket_name, file, file_name)

In [None]:
!rm -R ./val_tfrecords

## Initialize and Get TPU Ready

In [None]:
ACCELERATOR_TYPE = 'TPU'

if ACCELERATOR_TYPE == 'TPU':
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)

    strategy = tf.distribute.experimental.TPUStrategy(tpu)
else:
    strategy = tf.distribute.MirroredStrategy()
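With `TPUStrategy`, the batch size passed to `dataset.batch` is the global batch, which is split across the replicas reported by `strategy.num_replicas_in_sync`. The sketch below only does that arithmetic; the 8 replicas are an assumption (a TPU v3-8 exposes 8 cores) and `per_replica_batch` is a hypothetical helper.

```python
def per_replica_batch(global_batch, num_replicas):
    """Split a global batch evenly across replicas (hypothetical helper)."""
    if global_batch % num_replicas != 0:
        # keep per-core batches equal so every replica sees the same static shape
        raise ValueError("global batch must be divisible by the number of replicas")
    return global_batch // num_replicas


# the training pipeline batches 32 examples; assuming 8 TPU cores:
print(per_replica_batch(32, 8))  # 4
```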

In [None]:
from kaggle_secrets import UserSecretsClient

user_secrets = UserSecretsClient()
user_credential = user_secrets.get_gcloud_credential()
user_secrets.set_tensorflow_credential(user_credential)

## View TFRecord Samples

In [None]:
AUTO = tf.data.experimental.AUTOTUNE

image_feature_description = {
    'height': tf.io.FixedLenFeature([], tf.int64),
    'width': tf.io.FixedLenFeature([], tf.int64),
    'num_channels': tf.io.FixedLenFeature([], tf.int64),
    'img_bytes': tf.io.FixedLenFeature([], tf.string),
    'mask': tf.io.FixedLenFeature([], tf.string),
}

def _parse_image_and_masks_function(example_proto):
    single_example = tf.io.parse_single_example(example_proto, image_feature_description)
    img_bytes = tf.io.decode_raw(single_example['img_bytes'], out_type=tf.uint8)
    img_array = tf.reshape(img_bytes, (512, 512, 3))
    mask_bytes = tf.io.decode_raw(single_example['mask'], out_type=tf.bool)
    mask = tf.reshape(mask_bytes, (512, 512, 1))

    # Images and masks are kept as uint8/bool here so they can be displayed directly;
    # normalization and the float32 cast happen in _parse_image_function (training pipeline).
    return img_array, mask

def read_dataset(storage_file_path):
    encoded_image_dataset = tf.data.TFRecordDataset(storage_file_path, compression_type="GZIP")
    parsed_image_dataset = encoded_image_dataset.map(_parse_image_and_masks_function)
    return parsed_image_dataset
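`tf.io.decode_raw` followed by `tf.reshape` simply reinterprets the stored bytes, so the shape has to be known in advance, which is why (512, 512, 3) and (512, 512, 1) are hard-coded. The NumPy sketch below mimics that round trip with `ndarray.tobytes` and `np.frombuffer`. It assumes `image_example` stored the arrays as raw bytes via `tobytes()`, and it uses a small 4x4 tile instead of a 512x512 one.

```python
import numpy as np

image = np.arange(48, dtype=np.uint8).reshape(4, 4, 3)  # small stand-in for a tile
mask = np.arange(16).reshape(4, 4, 1) % 3 == 0            # boolean mask with a channel axis

# what would be stored in the 'img_bytes' and 'mask' features
img_bytes = image.tobytes()
mask_bytes = mask.tobytes()

# decode: reinterpret the raw bytes with the right dtype, then restore the shape
decoded_image = np.frombuffer(img_bytes, dtype=np.uint8).reshape(4, 4, 3)
decoded_mask = np.frombuffer(mask_bytes, dtype=bool).reshape(4, 4, 1)

print(len(img_bytes), len(mask_bytes))  # 48 16
```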

In [None]:
# Get the credential from the Cloud SDK
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
user_credential = user_secrets.get_gcloud_credential()

# Set the credentials
user_secrets.set_tensorflow_credential(user_credential)

# Use a familiar call to get the GCS path of the dataset
from kaggle_datasets import KaggleDatasets
GCS_DS_PATH = KaggleDatasets().get_gcs_path('hubmap-512x512-tfrecords-with-aug')
GCS_DS_PATH

In [None]:
train_tf_gcs = GCS_DS_PATH+'/train/*.tfrecords'
val_tf_gcs = GCS_DS_PATH+'/val/*.tfrecords'
train_tf_files = tf.io.gfile.glob(train_tf_gcs)
val_tf_files = tf.io.gfile.glob(val_tf_gcs)
print(val_tf_files[:3])
print("Train TFrecord Files:", len(train_tf_files))
print("Val TFrecord Files:", len(val_tf_files))

In [None]:
train_dataset = read_dataset(train_tf_files[15])
validation_dataset = read_dataset(val_tf_files[15])

# keep the 5th example of each split for a quick visual check
train_image, train_mask = next(iter(train_dataset.skip(4).take(1)))
train_mask = np.squeeze(train_mask)

test_image, test_mask = next(iter(validation_dataset.skip(4).take(1)))
test_mask = np.squeeze(test_mask)
    
fig, ax = plt.subplots(2,2,figsize=(20,10))
ax[0][0].imshow(train_image)
ax[0][1].imshow(train_mask)
ax[1][0].imshow(test_image)
ax[1][1].imshow(test_mask)

## Model Building

The model consists of a VGG19 sub-network pretrained on ImageNet, used as the encoder, and a custom decoder sub-network; together they form the first U-Net (NETWORK1). The second U-Net (NETWORK2) consists of an element-wise image-mask multiplication, custom encoder blocks and custom decoder blocks.

The VGG19 encoder and the custom encoder blocks encode the information contained in the input image. Each encoder block in NETWORK2 performs two 3×3 convolutions, each followed by batch normalization and a Rectified Linear Unit (ReLU) activation. Batch normalization reduces internal covariate shift and also regularizes the model, while ReLU introduces non-linearity. This is followed by a squeeze-and-excitation block, which enhances the quality of the feature maps, and by 2×2 max-pooling with stride 2, which halves the spatial dimensions of the feature maps.
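To make the shapes concrete, the sketch below (pure Python, with a hypothetical helper `encoder2_shapes`) traces a 512×512 input through the four NETWORK2 encoder blocks using the filter counts from `encoder2` (32, 64, 128, 256). The convolutions use `padding="same"`, so only the 2×2, stride-2 max-pooling changes the spatial size.

```python
def encoder2_shapes(size=512, filters=(32, 64, 128, 256)):
    """Return (skip-connection shapes, bottleneck shape) for NETWORK2's encoder."""
    skips = []
    for f in filters:
        # conv_block keeps the spatial size ("same" padding) and sets channels to f
        skips.append((size, size, f))
        # MaxPool2D((2, 2)) with its default stride of 2 halves height and width
        size //= 2
    return skips, (size, size, filters[-1])


skips, bottleneck = encoder2_shapes()
print(skips)       # [(512, 512, 32), (256, 256, 64), (128, 128, 128), (64, 64, 256)]
print(bottleneck)  # (32, 32, 256)
```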

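Atrous Spatial Pyramid Pooling (ASPP) is used in both sub-networks between the encoder and the decoder. It runs parallel branches over the bottleneck feature map (an image-level pooling branch, a 1×1 convolution, and 3×3 atrous convolutions with dilation rates 6, 12 and 18), concatenates them and fuses them with a 1×1 convolution. This captures multi-scale context without reducing the spatial resolution of the feature maps.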
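A 3×3 convolution with dilation rate d spreads its taps d pixels apart, so its effective kernel size is k + (k - 1)(d - 1). The short sketch below (hypothetical helper `effective_kernel`) computes this for the ASPP rates used here. At the 32×32 bottleneck of a 512×512 input, the rate-18 branch already spans more than the whole feature map, with its outer taps landing on the zero padding.

```python
def effective_kernel(kernel_size, dilation_rate):
    """Span (in pixels) covered by one dilated convolution kernel."""
    return kernel_size + (kernel_size - 1) * (dilation_rate - 1)


for rate in (1, 6, 12, 18):
    print(rate, effective_kernel(3, rate))
# 1 3
# 6 13
# 12 25
# 18 37
```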

The decoder blocks use Conv2DTranspose layers (2×2 kernel, stride 2) for upsampling, so the upsampling filters are learned instead of fixed. The decoder in NETWORK1 uses skip connections from the first encoder only, whereas the decoder in NETWORK2 uses skip connections from both encoders, which preserves spatial detail and enhances the quality of the output feature maps. Squeeze-and-excitation blocks are used in the decoder blocks of both NETWORK1 and NETWORK2 to help reduce redundant information.
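For a `Conv2DTranspose` layer with `padding="valid"` (the Keras default) and no `output_padding`, the output length along each axis is `input * stride + max(kernel - stride, 0)`. The sketch below (hypothetical helper `transpose_out_len`) applies that formula to the 2×2, stride-2 decoder layers, which bring the 32×32 bottleneck back to 512×512 in four steps.

```python
def transpose_out_len(input_len, kernel, stride):
    """Output length of a 'valid'-padded transposed convolution (no output_padding)."""
    return input_len * stride + max(kernel - stride, 0)


size = 32  # bottleneck size for a 512x512 input
sizes = [size]
for _ in range(4):  # four decoder blocks, each with a 2x2 kernel and stride 2
    size = transpose_out_len(size, kernel=2, stride=2)
    sizes.append(size)
print(sizes)  # [32, 64, 128, 256, 512]
```

With a 1×1 input and a kernel no larger than the stride, the output is `stride` pixels wide, but only the first `kernel` rows and columns receive a contribution from the input.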

The output masks from both networks are concatenated, and a final sigmoid convolution layer (a single filter with a 64×64 kernel in `build_model`) combines them into the final output mask.

Multiplying the input image by the output of NETWORK1, and then concatenating the outputs of NETWORK1 and NETWORK2, is what the original paper credits for the improved performance; this is the intuition and motivation behind the architecture.
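In `build_model` this multiplication is written as `inputs * outputs1`: the (512, 512, 3) image is multiplied by the (512, 512, 1) NETWORK1 mask, and broadcasting repeats the single mask channel over the three color channels. Below is a tiny NumPy sketch of the same broadcast, with 2×2 arrays standing in for the real tensors.

```python
import numpy as np

image = np.arange(12, dtype=np.float32).reshape(2, 2, 3)  # H x W x 3 "image"
mask = np.array([[[1.0], [0.0]],
                 [[0.5], [1.0]]], dtype=np.float32)          # H x W x 1 soft mask

gated = image * mask  # the mask broadcasts over the channel axis
print(gated.shape)    # (2, 2, 3)
print(gated[0, 1])    # [0. 0. 0.]  pixel suppressed by a 0 in the mask
```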

![](https://i.ibb.co/KyXDQwV/Double-U-Net.png)



In [None]:
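from tensorflow.keras.layers import *
from tensorflow.keras.applications import *
from tensorflow.keras.callbacks import *
from tensorflow.keras.optimizers import Adam, Nadam
from tensorflow.keras.metrics import *
from tensorflow.keras.losses import binary_crossentropy
from tensorflow.keras.models import Model  # used by build_model()
from tensorflow.keras.utils import get_custom_objects  # used to register the "dice" loss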

np.random.seed(13)
tf.random.set_seed(13)

In [None]:
def squeeze_excite_block(inputs, ratio=8):
    init = inputs
    channel_axis = -1
    filters = init.shape[channel_axis]
    se_shape = (1, 1, filters)

    se = GlobalAveragePooling2D()(init)
    se = Reshape(se_shape)(se)
    se = Dense(filters // ratio, activation='relu', kernel_initializer='he_normal', use_bias=False)(se)
    se = Dense(filters, activation='sigmoid', kernel_initializer='he_normal', use_bias=False)(se)

    x = Multiply()([init, se])
    return x

def conv_block(inputs, filters):
    x = inputs

    x = Conv2D(filters, (3, 3), padding="same")(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    x = Conv2D(filters, (3, 3), padding="same")(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    x = squeeze_excite_block(x)

    return x

def encoder1(inputs):
    skip_connections = []

    model = VGG19(include_top=False, weights='imagenet', input_tensor=inputs)
    names = ["block1_conv2", "block2_conv2", "block3_conv4", "block4_conv4"]
    for name in names:
        skip_connections.append(model.get_layer(name).output)

    output = model.get_layer("block5_conv4").output
    return output, skip_connections

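def decoder1(inputs, skip_connections):
    num_filters = [256, 128, 64, 32]
    # reverse in place so the deepest skip connection comes first; decoder2
    # later receives this same, already reversed list as skip_1
    skip_connections.reverse()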
    x = inputs
    shape = x.shape

    for i, f in enumerate(num_filters):
        x = Conv2DTranspose(shape[3], (2, 2), activation="relu", strides=(2, 2))(x)
        x = Concatenate()([x, skip_connections[i]])
        x = conv_block(x, f)

    return x

def encoder2(inputs):
    num_filters = [32, 64, 128, 256]
    skip_connections = []
    x = inputs

    for i, f in enumerate(num_filters):
        x = conv_block(x, f)
        skip_connections.append(x)
        x = MaxPool2D((2, 2))(x)

    return x, skip_connections

def decoder2(inputs, skip_1, skip_2):
    num_filters = [256, 128, 64, 32]
    skip_2.reverse()
    x = inputs
    shape = x.shape

    for i, f in enumerate(num_filters):
        x = Conv2DTranspose(shape[3], (2, 2), activation="relu", strides=(2, 2))(x)
        x = Concatenate()([x, skip_1[i], skip_2[i]])
        x = conv_block(x, f)

    return x

def output_block(inputs):
    x = Conv2D(1, (1, 1), padding="same")(inputs)
    x = Activation('sigmoid')(x)
    return x

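def ASPP(x, filter):
    shape = x.shape

    # image-level pooling branch: average the whole map down to 1x1, then project
    y1 = AveragePooling2D(pool_size=(shape[1], shape[2]))(x)
    y1 = Conv2D(filter, 1, padding="same")(y1)
    y1 = BatchNormalization()(y1)
    y1 = Activation("relu")(y1)
    shape2 = y1.shape

    # learned upsampling 1x1 -> HxW: with the kernel size equal to the stride, every output
    # pixel receives the pooled context (a smaller kernel, e.g. 8x8, only fills the
    # top-left corner of each map); this layer has filter * filter * H * W weights
    y1 = Conv2DTranspose(shape2[3], (shape[1], shape[2]), activation="relu", strides=(shape[1], shape[2]))(y1)
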
    y2 = Conv2D(filter, 1, dilation_rate=1, padding="same", use_bias=False)(x)
    y2 = BatchNormalization()(y2)
    y2 = Activation("relu")(y2)

    y3 = Conv2D(filter, 3, dilation_rate=6, padding="same", use_bias=False)(x)
    y3 = BatchNormalization()(y3)
    y3 = Activation("relu")(y3)

    y4 = Conv2D(filter, 3, dilation_rate=12, padding="same", use_bias=False)(x)
    y4 = BatchNormalization()(y4)
    y4 = Activation("relu")(y4)

    y5 = Conv2D(filter, 3, dilation_rate=18, padding="same", use_bias=False)(x)
    y5 = BatchNormalization()(y5)
    y5 = Activation("relu")(y5)

    y = Concatenate()([y1, y2, y3, y4, y5])

    y = Conv2D(filter, 1, dilation_rate=1, padding="same", use_bias=False)(y)
    y = BatchNormalization()(y)
    y = Activation("relu")(y)

    return y

def build_model():
    inputs = Input((512, 512, 3))
    x, skip_1 = encoder1(inputs)
    x = ASPP(x, 64)
    x = decoder1(x, skip_1)
    outputs1 = output_block(x)

    x = inputs * outputs1

    x, skip_2 = encoder2(x)
    x = ASPP(x, 64)
    x = decoder2(x, skip_1, skip_2)
    outputs2 = output_block(x)
    outputs = Concatenate()([outputs1, outputs2])
    
    combine_output = Conv2D(1, (64, 64), activation="sigmoid", padding="same")(outputs)

    model = Model(inputs, combine_output)
    return model
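The `squeeze_excite_block` above can be read as three steps: average each channel over the spatial dimensions (squeeze), pass the channel vector through a bottleneck Dense-ReLU-Dense-sigmoid pair (excitation), and rescale every channel by the resulting weight. The NumPy sketch below reproduces that forward pass with random weights standing in for the learned ones. Like `squeeze_excite_block`, it uses no biases and a reduction ratio of 8; `squeeze_excite_np` is a hypothetical name.

```python
import numpy as np


def squeeze_excite_np(x, w1, w2):
    """Forward pass of an SE block on an H x W x C array (no biases)."""
    squeezed = x.mean(axis=(0, 1))                # GlobalAveragePooling2D -> (C,)
    hidden = np.maximum(squeezed @ w1, 0.0)       # Dense(C // ratio, relu)
    scale = 1.0 / (1.0 + np.exp(-(hidden @ w2)))  # Dense(C, sigmoid) -> (C,)
    return x * scale                              # Multiply: broadcast over H and W


rng = np.random.default_rng(13)
channels, ratio = 32, 8
x = rng.standard_normal((16, 16, channels))
w1 = rng.standard_normal((channels, channels // ratio))  # stand-in for learned weights
w2 = rng.standard_normal((channels // ratio, channels))
y = squeeze_excite_np(x, w1, w2)
print(y.shape)  # (16, 16, 32)
```

Each channel is scaled by a single factor in (0, 1), the same at every pixel.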

In [None]:
model = build_model()
model.summary(line_length=150)

## Define Metrics

In [None]:
with strategy.scope():
    def dice_coeff(y_true, y_pred):
        # add epsilon to avoid a divide by 0 error in case a slice has no pixels set
        # we only care about relative value, not absolute so this alteration doesn't matter
        _epsilon = 10 ** -7
        intersections = tf.reduce_sum(y_true * y_pred)
        unions = tf.reduce_sum(y_true + y_pred)
        dice_scores = (2.0 * intersections + _epsilon) / (unions + _epsilon)
        return dice_scores

    def dice_loss(y_true, y_pred):
        loss = 1 - dice_coeff(y_true, y_pred)
        return loss
    
    def iou(y_true, y_pred, smooth=1e-7):
        # pure TF ops so the metric can run inside an XLA-compiled TPU step
        # (tf.numpy_function cannot); smooth avoids 0/0 on empty masks
        intersection = tf.reduce_sum(y_true * y_pred)
        union = tf.reduce_sum(y_true) + tf.reduce_sum(y_pred) - intersection
        return (intersection + smooth) / (union + smooth)
    
    def bce_dice_loss(y_true, y_pred):
        return binary_crossentropy(y_true, y_pred) + dice_loss(y_true, y_pred)

#     def focal_loss(y_true, y_pred):
#         alpha=0.25
#         gamma=2
#         def focal_loss_with_logits(logits, targets, alpha, gamma, y_pred):
#             weight_a = alpha * (1 - y_pred) ** gamma * targets
#             weight_b = (1 - alpha) * y_pred ** gamma * (1 - targets)
#             return (tf.math.log1p(tf.exp(-tf.abs(logits))) + tf.nn.relu(-logits)) * (weight_a + weight_b) + logits * weight_b

#         y_pred = tf.clip_by_value(y_pred, tf.keras.backend.epsilon(), 1 - tf.keras.backend.epsilon())
#         logits = tf.math.log(y_pred / (1 - y_pred))
#         loss = focal_loss_with_logits(logits=logits, targets=y_true, alpha=alpha, gamma=gamma, y_pred=y_pred)
#         # or reduce_sum and/or axis=-1
#         return tf.reduce_mean(loss)
    
    def tversky(y_true, y_pred, smooth=1, alpha=0.7):
        y_true_pos = tf.reshape(y_true,[-1])
        y_pred_pos = tf.reshape(y_pred,[-1])
        true_pos = tf.reduce_sum(y_true_pos * y_pred_pos)
        false_neg = tf.reduce_sum(y_true_pos * (1 - y_pred_pos))
        false_pos = tf.reduce_sum((1 - y_true_pos) * y_pred_pos)
        return (true_pos + smooth) / (true_pos + alpha * false_neg + (1 - alpha) * false_pos + smooth)

    def tversky_loss(y_true, y_pred):
        return 1 - tversky(y_true, y_pred)

    def focal_tversky_loss(y_true, y_pred, gamma=0.75):
        tv = tversky(y_true, y_pred)
        return tf.pow((1 - tv), gamma)


    get_custom_objects().update({"dice": dice_loss})
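Here is a worked example of the two overlap scores defined above, in plain Python on flattened 4-pixel masks. The formulas mirror `dice_coeff` (epsilon 1e-7) and `tversky` (smooth 1, alpha 0.7); the helper names `dice_py` and `tversky_py` are hypothetical.

```python
def dice_py(y_true, y_pred, eps=1e-7):
    """2 * intersection / (sum of both masks), as in dice_coeff."""
    inter = sum(t * p for t, p in zip(y_true, y_pred))
    total = sum(y_true) + sum(y_pred)
    return (2.0 * inter + eps) / (total + eps)


def tversky_py(y_true, y_pred, smooth=1.0, alpha=0.7):
    """(TP + smooth) / (TP + alpha * FN + (1 - alpha) * FP + smooth), as in tversky."""
    tp = sum(t * p for t, p in zip(y_true, y_pred))
    fn = sum(t * (1 - p) for t, p in zip(y_true, y_pred))
    fp = sum((1 - t) * p for t, p in zip(y_true, y_pred))
    return (tp + smooth) / (tp + alpha * fn + (1 - alpha) * fp + smooth)


y_true = [1, 1, 0, 0]
y_pred = [1, 0, 1, 0]  # one true positive, one false negative, one false positive
print(round(dice_py(y_true, y_pred), 4))     # 0.5
print(round(tversky_py(y_true, y_pred), 4))  # 0.6667
```

Because alpha = 0.7 weights false negatives more than false positives, a prediction that misses a foreground pixel scores lower than one that adds a spurious pixel.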

## Functions to Load Records

In [None]:
AUTO = tf.data.experimental.AUTOTUNE
image_feature_description = {
    'height': tf.io.FixedLenFeature([], tf.int64),
    'width': tf.io.FixedLenFeature([], tf.int64),
    'num_channels': tf.io.FixedLenFeature([], tf.int64),
    'img_bytes': tf.io.FixedLenFeature([], tf.string),
    'mask': tf.io.FixedLenFeature([], tf.string),
}

def _parse_image_function(example_proto):
    single_example = tf.io.parse_single_example(example_proto, image_feature_description)
    image = tf.reshape(tf.io.decode_raw(single_example['img_bytes'], out_type=tf.uint8), (512, 512, 3))
    mask = tf.reshape(tf.io.decode_raw(single_example['mask'], out_type=tf.bool), (512, 512, 1))
    ## normalize images array and cast image and mask to float32
    image = tf.cast(image, tf.float32) / 255.0
    mask = tf.cast(mask, tf.float32)
    return image, mask

def load_dataset(filenames, ordered=False):
    ignore_order = tf.data.Options()
    if not ordered:
        ignore_order.experimental_deterministic = False
    dataset = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO, compression_type="GZIP")
    dataset = dataset.with_options(ignore_order)
    dataset = dataset.map(_parse_image_function, num_parallel_calls=AUTO)
    return dataset

def get_training_dataset():
    dataset = load_dataset(train_tf_files)
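    # repeat indefinitely: steps_per_epoch covers only ~70% of the data per epoch, and
    # Keras does not restart a finite dataset between epochs, so it would run out of batches
    dataset = dataset.repeat()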
    dataset = dataset.shuffle(20000)
    dataset = dataset.batch(32, drop_remainder=True)
    dataset = dataset.prefetch(AUTO)
    return dataset
def get_val_dataset():
    dataset = load_dataset(val_tf_files)
    #dataset = dataset.repeat()
    dataset = dataset.shuffle(5000)
    dataset = dataset.batch(32, drop_remainder=True)
    dataset = dataset.prefetch(AUTO)
    return dataset
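Because `shuffle` is applied after `map`, the buffer holds decoded float32 tensors rather than compressed records. The arithmetic below estimates that footprint for the 20000-element training buffer and the 5000-element validation buffer. It assumes one element is a 512×512×3 float32 image plus a 512×512×1 float32 mask and ignores any tf.data overhead.

```python
BYTES_PER_FLOAT32 = 4
image_bytes = 512 * 512 * 3 * BYTES_PER_FLOAT32  # normalized image
mask_bytes = 512 * 512 * 1 * BYTES_PER_FLOAT32   # mask cast to float32
per_element = image_bytes + mask_bytes

train_buffer_gib = 20000 * per_element / 2**30
val_buffer_gib = 5000 * per_element / 2**30
print(per_element)                 # 4194304  (4 MiB)
print(round(train_buffer_gib, 1))  # 78.1
print(round(val_buffer_gib, 1))    # 19.5
```

Shuffling the serialized records before `map`, or using a smaller buffer, would shrink this considerably; whether the current size fits depends on the memory of the host running the input pipeline.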

## Model Training

In [None]:
if not os.path.exists('/kaggle/working/train_job'):
    os.makedirs('/kaggle/working/train_job')


    Training Samples         : 47703
    Validation Samples       : 5829
    



In [None]:
with strategy.scope():
    metrics = [
        dice_coeff,
#        iou,
        bce_dice_loss,
#        focal_loss,
        Recall(),
        Precision(),
        tversky_loss,
        focal_tversky_loss
    ]
    
    callbacks = [
        ModelCheckpoint('/kaggle/working/hubmap-model-1.h5', verbose=1),
        ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=5),  # shorter than EarlyStopping's patience so the LR can drop before training stops
        CSVLogger("/kaggle/working/train_job/data.csv"),
    #    TensorBoard(),
        EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=False)
    ]
    
    model = build_model()
    model.compile(optimizer=Adam(learning_rate=1e-3), loss='dice', metrics=metrics)
    train_dataset = get_training_dataset()
    validation_dataset = get_val_dataset()
    
    train_steps = round((47703//32)*0.70)
    validation_steps = round((5829//32)*0.70)

    model.fit(train_dataset, epochs=20, steps_per_epoch=train_steps,
              validation_data=validation_dataset, validation_steps=validation_steps,
              callbacks=callbacks)

    model.save_weights("/kaggle/working/train_job/hubmap_model_1.h5")
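The step counts in the training cell are computed as follows: floor-divide the sample counts by the batch size of 32 (matching `drop_remainder=True`), then keep 70% of the batches per epoch. Here is a plain-Python check of that arithmetic, using the sample counts printed above.

```python
BATCH_SIZE = 32
train_samples, val_samples = 47703, 5829

train_batches = train_samples // BATCH_SIZE  # full batches only (drop_remainder=True)
val_batches = val_samples // BATCH_SIZE

train_steps = round(train_batches * 0.70)
validation_steps = round(val_batches * 0.70)
print(train_batches, train_steps)     # 1490 1043
print(val_batches, validation_steps)  # 182 127

# fraction of the training set seen per epoch, and passes over it in 20 epochs
print(round(train_steps * BATCH_SIZE / train_samples, 3))       # 0.7
print(round(20 * train_steps * BATCH_SIZE / train_samples, 1))  # 14.0
```

So each epoch covers roughly 70% of the training set, and 20 epochs amount to about 14 passes over it. This relies on the training dataset being repeated rather than exhausted after one pass.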
