# ***Disclaimer:*** 
Hello Kagglers! I am a Solution Architect with the Google Cloud Platform. I am a coach for this competition, and the focus of my contributions is on helping users leverage GCP components (GCS, TPUs, BigQuery, etc.) to solve large problems. My ideas and contributions represent my own opinion, and are not representative of an official recommendation by Google. Also, I try to develop notebooks quickly in order to help users early in competitions. There may be better ways of solving particular problems, and I welcome comments and suggestions. Use my contributions at your own risk. I don't guarantee that they will help in winning any competition, but I am hoping to learn by collaborating with everyone.


# Objective:


The objective of this notebook is to demonstrate how to feed a TFRecord dataset to a Keras Unet model for image segmentation. The advantage of using a TFRecord dataset is that you can then train using TPUs, as will be explained in the next notebook, and get 100x performance.

In previous notebooks, I demonstrated how to read the competition data and produce a TFRecord dataset. This Notebook will use this dataset as input:
--> [Link to the TFRecord Dataset Used by this Notebook.](https://www.kaggle.com/marcosnovaes/hubmap-tfrecord-512)

Previous Notebooks in this competition: 

[https://www.kaggle.com/marcosnovaes/hubmap-read-data-and-build-tfrecords/](https://www.kaggle.com/marcosnovaes/hubmap-read-data-and-build-tfrecords/): Demonstrates how the TFRecord Dataset was built

[https://www.kaggle.com/marcosnovaes/hubmap-looking-at-tfrecords/](https://www.kaggle.com/marcosnovaes/hubmap-looking-at-tfrecords/): Explains how to read the data using the TFRecord Dataset



# Setup
1) Add the TFRecord Dataset as input to the notebook: go to the Data section at the right, click "add data" and look for the dataset "hubmap-tfrecord-512".
2) This Notebook also shows how to access a Kaggle dataset directly from Google Cloud Storage (GCS). To enable this feature, you need to link the Notebook to a GCS project, by going to the menu Add-ons-->Cloud SDK

My import section is a little messy as I imported snippets from several sources. It will be cleaned eventually.

In [None]:
import os
import sys
import random
import warnings

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt

from tqdm import tqdm
from itertools import chain
from skimage.io import imread, imshow, imread_collection, concatenate_images
from skimage.transform import resize
from skimage.morphology import label

from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import Input

from tensorflow.keras.layers import Dropout, Lambda
from tensorflow.keras.layers import Conv2D, Conv2DTranspose
from tensorflow.keras.layers import MaxPooling2D, UpSampling2D
from tensorflow.keras.layers import concatenate

from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from tensorflow.keras import backend as K

from tensorflow.keras import layers

from tensorflow.keras.layers import Layer
#from tensorflow.keras.layers.merge import concatenate, add


from tensorflow.keras.optimizers import Adam
from tensorflow.keras.utils import get_custom_objects


from kaggle_datasets import KaggleDatasets
from kaggle_secrets import UserSecretsClient

import tensorflow as tf

In [None]:
user_secrets = UserSecretsClient()
user_credential = user_secrets.get_gcloud_credential()
user_secrets.set_tensorflow_credential(user_credential)

# IMPORTANT SECTION: 
The code below shows how to read TFRecords as in the previous notebook. However, the read function is altered here to return only the image and mask arrays. 

Notice the inclusion of a tf.reshape statement using literal dimensions, i.e. tf.reshape(img_bytes, (1,512,512,3)) for images and tf.reshape(mask_bytes, (1,512,512)) for masks. We were previously reading these values dynamically, but when using TPUs this led to a compilation error ("Dynamic Shape xxxx not support in function XXX"). This is because TPU compilation does not support dynamic shapes yet, so the shapes must be "baked" into the code with literals as shown below. I learned this the hard way...
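
To make the fixed shapes concrete, here is a minimal NumPy sketch of the same decode-and-reshape step. It is only an analogue of what `tf.io.decode_raw` followed by `tf.reshape` does, not the TensorFlow code path itself; `decode_tile` and the fake tile bytes are hypothetical names made up for this illustration.

```python
import numpy as np

TILE = 512  # literal tile size, mirroring the hard-coded shapes in the parser

def decode_tile(img_bytes, mask_bytes):
    """Decode raw bytes into a (1, 512, 512, 3) image and a (1, 512, 512) mask."""
    img = np.frombuffer(img_bytes, dtype=np.uint8).reshape(1, TILE, TILE, 3)
    mask = np.frombuffer(mask_bytes, dtype=np.bool_).reshape(1, TILE, TILE)
    # Scale pixels to [0, 1] and cast the mask to float, as the parser does
    return img.astype(np.float32) / 255.0, mask.astype(np.float32)

# Fake tile (made-up data): every pixel is 255, mask is set on the top half only
fake_img = np.full((TILE, TILE, 3), 255, dtype=np.uint8).tobytes()
fake_mask = np.zeros((TILE, TILE), dtype=np.bool_)
fake_mask[: TILE // 2] = True

image, mask = decode_tile(fake_img, fake_mask.tobytes())
print(image.shape, mask.shape, mask.mean())
```

The leading dimension of 1 means every record is already a batch of one tile, which matters later when the dataset is passed to model.fit.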

In [None]:
# Create a dictionary describing the features.
image_feature_description = {
    'img_index': tf.io.FixedLenFeature([], tf.int64),
    'height': tf.io.FixedLenFeature([], tf.int64),
    'width': tf.io.FixedLenFeature([], tf.int64),
    'num_channels': tf.io.FixedLenFeature([], tf.int64),
    'img_bytes': tf.io.FixedLenFeature([], tf.string),
    'mask': tf.io.FixedLenFeature([], tf.string),
    'tile_id': tf.io.FixedLenFeature([], tf.int64),
    'tile_col_pos': tf.io.FixedLenFeature([], tf.int64),
    'tile_row_pos': tf.io.FixedLenFeature([], tf.int64),
}

def _parse_image_and_masks_function(example_proto):
    single_example = tf.io.parse_single_example(example_proto, image_feature_description)
    img_height = single_example['height']
    img_width = single_example['width']
    num_channels = single_example['num_channels']
    
    img_bytes =  tf.io.decode_raw(single_example['img_bytes'],out_type='uint8')
    #img_array = tf.reshape( img_bytes, (img_height, img_width, num_channels))
    # Need to define array shape with literals to avoid dynamic shape errors
    img_array = tf.reshape( img_bytes, (1,512, 512, 3))
    
    mask_bytes =  tf.io.decode_raw(single_example['mask'],out_type='bool')
    
    #mask = tf.reshape(mask_bytes, (1,img_height,img_width,1))
    #mask = tf.reshape(mask_bytes, (img_height,img_width))
    # Need to define array shape with literals to avoid dynamic shape errors
    mask = tf.reshape(mask_bytes, (1,512,512))
    
    #cast to float 32
    img_array = tf.cast(img_array, tf.float32) / 255.0
    mask = tf.cast(mask, tf.float32)
    return img_array, mask

def _parse_mask_function(example_proto):
    single_example = tf.io.parse_single_example(example_proto, image_feature_description)
    img_height = single_example['height']
    img_width = single_example['width']
    num_channels = single_example['num_channels']   
    mask_bytes =  tf.io.decode_raw(single_example['mask'],out_type='bool') 
    mask = tf.reshape(mask_bytes, (img_height,img_width))
    return mask

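def read_tf_dataset(storage_file_path):
    # NOTE: _parse_image_data_function is not defined in this notebook (presumably it is the
    # full-record parser from the previous notebook), so calling this helper here raises a NameError.
    # Use read_images_and_masks or read_masks below instead.
    encoded_image_dataset = tf.data.TFRecordDataset(storage_file_path, compression_type="GZIP")
    parsed_image_dataset = encoded_image_dataset.map(_parse_image_data_function)
    return parsed_image_dataset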

def read_images_and_masks(storage_file_path):
    encoded_image_dataset = tf.data.TFRecordDataset(storage_file_path, compression_type="GZIP")
    parsed_image_dataset = encoded_image_dataset.map(_parse_image_and_masks_function)
    return parsed_image_dataset

def read_masks(storage_file_path):
    encoded_image_dataset = tf.data.TFRecordDataset(storage_file_path, compression_type="GZIP")
    parsed_image_dataset = encoded_image_dataset.map(_parse_mask_function)
    return parsed_image_dataset


Now look for the dataset "hubmap-tfrecord-512" in your input directory. If it is not there, add it as described in the Setup section.

In [None]:
!ls /kaggle/input

In [None]:
!ls /kaggle/input/hubmap-tfrecord-512

I have recently added two CSV files that have a lot of useful metadata for each tile. The file "train_all_tiles.csv" has the data for all image tiles in the train set and "test_all_tiles.csv" for the test set.

In [None]:
train_tiles_csv = '/kaggle/input/hubmap-tfrecord-512/train_all_tiles.csv'
test_tiles_csv = '/kaggle/input/hubmap-tfrecord-512/test_all_tiles.csv'

In [None]:
train_tiles_df = pd.read_csv(train_tiles_csv)
train_tiles_df.head()

Notice above that I have included both the local path and the GCS path for each tile. This is super handy. If you are using a GPU you can use the local file paths, but TPUs require the GCS file path. There is also metadata on some metrics for each tile, such as the mask density and lowband_density (as explained in the previous notebook). We can use these values to filter tiles of interest. For example, if we are interested in getting all tiles that have gloms, we select mask_density > 0 as below:

In [None]:
# build a dataset of all images tiles from the train set that have gloms in them
#for csv_file in file_list:

gloms_df = train_tiles_df.loc[train_tiles_df["mask_density"]  > 0]

gloms_df.head()

In [None]:
len(gloms_df)

So, we have a total of 3384 tiles that have gloms. Selecting "lowband_density > 1000" will exclude all the "black" and "gray" tiles from the border, which have no tissue. 

In [None]:
cropped_df = train_tiles_df.loc[train_tiles_df["lowband_density"]  > 1000]
cropped_df.head()

In [None]:
len(cropped_df)

So, there are 15957 tiles with tissue, and 3384 of them have gloms. We may want to produce a balanced mix for training.
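
One possible way to build such a balanced mix is sketched below. This is an illustration rather than what the notebook does later: `balanced_mix` is a hypothetical helper, the toy DataFrame values are made up, and it assumes the column names `mask_density`, `lowband_density` and `gcs_path` from the CSV shown above.

```python
import pandas as pd

def balanced_mix(tiles_df, n_per_class, seed=0):
    """Return n_per_class glom tiles plus n_per_class tissue-only tiles, shuffled."""
    # Keep only tiles that contain tissue (same filter as cropped_df)
    tissue = tiles_df[tiles_df["lowband_density"] > 1000]
    with_gloms = tissue[tissue["mask_density"] > 0]
    without_gloms = tissue[tissue["mask_density"] == 0]
    picked = pd.concat([
        with_gloms.sample(n=n_per_class, random_state=seed),
        without_gloms.sample(n=n_per_class, random_state=seed),
    ])
    # Shuffle so the two classes are interleaved
    return picked.sample(frac=1, random_state=seed).reset_index(drop=True)

# Toy stand-in for train_tiles_df (values are made up for illustration)
toy_df = pd.DataFrame({
    "gcs_path": [f"gs://example-bucket/tile_{i}.tfrecords" for i in range(8)],
    "mask_density": [0, 120, 0, 0, 300, 0, 45, 0],
    "lowband_density": [500, 5000, 4000, 3000, 6000, 2000, 7000, 100],
})

mix = balanced_mix(toy_df, n_per_class=2)
print(mix[["gcs_path", "mask_density"]])
```

With the real metadata, `n_per_class` could be at most the number of glom tiles; whether a 50/50 mix is the best ratio is an open question.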

# Selecting a Unet Model
I found several references to a [popular paper in biomedical image segmentation](https://arxiv.org/abs/1505.04597) by Olaf Ronneberger, Philipp Fischer, and Thomas Brox.

I found a few implementations with slight differences in downsampling and upsampling layers.

1) Unet1: [Kaggle Notebook](https://www.kaggle.com/keegil/keras-u-net-starter-lb-0-277) by @keegil. That notebook has some useful background. 

2) Unet2: [Github site](https://github.com/jocicmarko/ultrasound-nerve-segmentation/blob/master/train.py) by jocicmarko, built for the ultrasound-nerve-segmentation competition

3) Unet3: [The Magician's Corner repository](https://github.com/RSNA/MagiciansCorner/blob/master/UNetWithTensorflow.ipynb), maintained by [Dr. Bradley Erickson](https://github.com/slowvak). 

I slightly modified all three so that they can run on both CPUs and TPUs. They seem to work, but I could not get (2) to produce good results. The main difference between (1) and (3) is that (3) uses normalization layers, which make a huge difference. I also noticed that (3) is much faster, because it has fewer layers. I am showing all three here as a curiosity, and to encourage discussion (feel free to comment on this notebook).
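
As a rough illustration of the size difference between (1) and (3), the sketch below counts the trainable parameters of the encoder (downsampling) convolutions only, using the kernel sizes and channel counts from build_unet1 and build_unet3 in this notebook. The helpers `conv2d_params` and `stack_params` are hypothetical, and parameter count is only a crude proxy for speed.

```python
def conv2d_params(kernel, in_ch, out_ch):
    """Trainable parameters of a square Conv2D layer with bias."""
    return (kernel * kernel * in_ch + 1) * out_ch

def stack_params(layers, in_ch=3):
    """Sum parameters over a chain of (kernel, out_channels) conv layers."""
    total = 0
    for kernel, out_ch in layers:
        total += conv2d_params(kernel, in_ch, out_ch)
        in_ch = out_ch  # pooling and dropout do not change the channel count
    return total

# Encoder convolutions only, as (kernel size, output channels)
unet1_encoder = [(3, 16), (3, 16), (3, 32), (3, 32), (3, 64), (3, 64),
                 (3, 128), (3, 128), (3, 256), (3, 256)]
unet3_encoder = [(5, 8), (3, 16), (3, 32), (3, 64), (3, 72)]

unet1_total = stack_params(unet1_encoder)
unet3_total = stack_params(unet3_encoder)
print("unet1 encoder params:", unet1_total)
print("unet3 encoder params:", unet3_total)
```

For the full models, `unet_model.summary()` reports the exact totals.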

# Loss Function
I use the Dice coefficient as the loss function (as 1 - dice_coeff); the exact same function was used in (2) and (3).

Also notice the LayerNormalization class here, which is a subclass of the Keras Layer class.

LayerNormalization is the original one used in (3)

LayerNormalization2 is the slight modification I made using tensorflow.math functions so that it can run on a TPU.
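
To see what the Dice coefficient measures, here is a small NumPy version worked on 2x2 masks. It mirrors the arithmetic of the TensorFlow function in the next cell; `dice_coeff_np` is a hypothetical name used only for this example. Note that the denominator is the sum of both masks, not a set union.

```python
import numpy as np

def dice_coeff_np(y_true, y_pred, eps=1e-7):
    """Dice coefficient: 2 * sum(A * B) / (sum(A) + sum(B)), with epsilon smoothing."""
    intersection = np.sum(y_true * y_pred)
    total = np.sum(y_true + y_pred)
    return (2.0 * intersection + eps) / (total + eps)

y_true = np.array([[1, 1], [0, 0]], dtype=np.float32)  # two glom pixels
y_pred = np.array([[1, 0], [0, 0]], dtype=np.float32)  # one of them predicted

score = dice_coeff_np(y_true, y_pred)  # 2 * 1 / (2 + 1)
loss = 1.0 - score

# Two empty masks: epsilon keeps the ratio defined and gives a perfect score
empty_score = dice_coeff_np(np.zeros((2, 2)), np.zeros((2, 2)))
print(score, loss, empty_score)
```

A perfect prediction gives a score of 1 (loss 0), and no overlap gives a score near 0 (loss near 1).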

In [None]:
def dice_coeff(y_true, y_pred):
    # add epsilon to avoid a divide by 0 error in case a slice has no pixels set
    # we only care about relative value, not absolute so this alteration doesn't matter
    _epsilon = 10 ** -7
    intersections = tf.reduce_sum(y_true * y_pred)
    unions = tf.reduce_sum(y_true + y_pred)
    dice_scores = (2.0 * intersections + _epsilon) / (unions + _epsilon)
    return dice_scores

def dice_loss(y_true, y_pred):
    loss = 1 - dice_coeff(y_true, y_pred)
    return loss
  
get_custom_objects().update({"dice": dice_loss})

class LayerNormalization (Layer) :
    
    def call(self, x, mask=None, training=None) :
        axis = list (range (1, len (x.shape)))
        x /= K.std (x, axis = axis, keepdims = True) + K.epsilon()
        x -= K.mean (x, axis = axis, keepdims = True)
        return x
        
    def compute_output_shape(self, input_shape):
        return input_shape
    
class LayerNormalization2 (Layer) :
    
    def call(self, x, mask=None, training=None) :
        axis = list (range (1, len (x.shape)))
        _epsilon = 10 ** -7
        x /= tf.math.reduce_std (x, axis = axis, keepdims = True) + _epsilon
        x -= tf.math.reduce_mean (x, axis = axis, keepdims = True)
        return x
        
    def compute_output_shape(self, input_shape):
        return input_shape
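
As a sanity check on what these layers compute, the NumPy sketch below applies the same two steps in the same order (divide by the standard deviation, then subtract the mean) over all non-batch axes. `layer_norm_np` is a hypothetical stand-in, not part of the notebook's model. Unlike the built-in tf.keras.layers.LayerNormalization, this version has no learned scale or offset.

```python
import numpy as np

def layer_norm_np(x, eps=1e-7):
    """Per-sample normalization over all non-batch axes, in the same order as the layer."""
    axes = tuple(range(1, x.ndim))
    x = x / (x.std(axis=axes, keepdims=True) + eps)  # scale first
    x = x - x.mean(axis=axes, keepdims=True)         # then center
    return x

rng = np.random.default_rng(0)
batch = rng.normal(loc=5.0, scale=3.0, size=(2, 8, 8, 4))  # made-up activations

normed = layer_norm_np(batch)
per_sample_mean = normed.mean(axis=(1, 2, 3))
per_sample_std = normed.std(axis=(1, 2, 3))
print(per_sample_mean, per_sample_std)
```

Each sample in the batch ends up with a mean close to 0 and a standard deviation close to 1.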

In [None]:
def build_unet1( img_height, img_width, img_channels):
    # Build U-Net model
    inputs = Input((img_height, img_width, img_channels))
    #s = Lambda(lambda x: x / 255) (inputs)
    
    #c1 = Conv2D(16, (3, 3), activation='elu', kernel_initializer='he_normal', padding='same') (s)
    
    c1 = Conv2D(16, (3, 3), activation='linear', kernel_initializer='he_normal', padding='same') (inputs)
    c1 = Dropout(0.1) (c1)
    c1 = Conv2D(16, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same') (c1)
    p1 = MaxPooling2D((2, 2)) (c1)

    c2 = Conv2D(32, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same') (p1)
    c2 = Dropout(0.1) (c2)
    c2 = Conv2D(32, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same') (c2)
    p2 = MaxPooling2D((2, 2)) (c2)

    c3 = Conv2D(64, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same') (p2)
    c3 = Dropout(0.2) (c3)
    c3 = Conv2D(64, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same') (c3)
    p3 = MaxPooling2D((2, 2)) (c3)

    c4 = Conv2D(128, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same') (p3)
    c4 = Dropout(0.2) (c4)
    c4 = Conv2D(128, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same') (c4)
    p4 = MaxPooling2D(pool_size=(2, 2)) (c4)

    c5 = Conv2D(256, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same') (p4)
    c5 = Dropout(0.3) (c5)
    c5 = Conv2D(256, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same') (c5)

    u6 = Conv2DTranspose(128, (2, 2), strides=(2, 2), padding='same') (c5)
    u6 = concatenate([u6, c4])
    c6 = Conv2D(128, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same') (u6)
    c6 = Dropout(0.2) (c6)
    c6 = Conv2D(128, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same') (c6)

    u7 = Conv2DTranspose(64, (2, 2), strides=(2, 2), padding='same') (c6)
    u7 = concatenate([u7, c3])
    c7 = Conv2D(64, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same') (u7)
    c7 = Dropout(0.2) (c7)
    c7 = Conv2D(64, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same') (c7)

    u8 = Conv2DTranspose(32, (2, 2), strides=(2, 2), padding='same') (c7)
    u8 = concatenate([u8, c2])
    c8 = Conv2D(32, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same') (u8)
    c8 = Dropout(0.1) (c8)
    c8 = Conv2D(32, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same') (c8)

    u9 = Conv2DTranspose(16, (2, 2), strides=(2, 2), padding='same') (c8)
    u9 = concatenate([u9, c1], axis=3)
    c9 = Conv2D(16, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same') (u9)
    c9 = Dropout(0.1) (c9)
    c9 = Conv2D(16, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same') (c9)

    outputs = Conv2D(1, (1, 1), activation='sigmoid') (c9)

    model = Model(inputs=[inputs], outputs=[outputs])
    
    return model

In [None]:
def build_unet2(img_rows, img_cols, img_channels):
    #inputs = Input((img_rows, img_cols, img_channels))
    inputs = Input((512, 512, 3))
    act_fn = 'relu'
    
    #conv1 = Conv2D(32, (3, 3), activation='relu', padding='same')(inputs)
    conv1 = Conv2D(32, (3, 3), activation='linear', padding='same')(inputs)
    conv1 = Conv2D(32, (3, 3), activation='relu', padding='same')(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)

    conv2 = Conv2D(64, (3, 3), activation='relu', padding='same')(pool1)
    #conv2 = Conv2D(64, (3, 3), activation='linear', padding='same')(pool1)
    conv2 = Conv2D(64, (3, 3), activation='relu', padding='same')(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)

    conv3 = Conv2D(128, (3, 3), activation='relu', padding='same')(pool2)
    #conv3 = Conv2D(128, (3, 3), activation='linear', padding='same')(pool2)
    conv3 = Conv2D(128, (3, 3), activation='relu', padding='same')(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)

    conv4 = Conv2D(256, (3, 3), activation='relu', padding='same')(pool3)
    #conv4 = Conv2D(256, (3, 3), activation='linear', padding='same')(pool3)
    conv4 = Conv2D(256, (3, 3), activation='relu', padding='same')(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)

    conv5 = Conv2D(512, (3, 3), activation='relu', padding='same')(pool4)
    #conv5 = Conv2D(512, (3, 3), activation='linear', padding='same')(pool4)
    conv5 = Conv2D(512, (3, 3), activation='relu', padding='same')(conv5)

    up6 = concatenate([Conv2DTranspose(256, (2, 2), strides=(2, 2), padding='same')(conv5), conv4], axis=3)
    conv6 = Conv2D(256, (3, 3), activation=act_fn, padding='same')(up6)
    conv6 = Conv2D(256, (3, 3), activation=act_fn, padding='same')(conv6)

    up7 = concatenate([Conv2DTranspose(128, (2, 2), strides=(2, 2), padding='same')(conv6), conv3], axis=3)
    conv7 = Conv2D(128, (3, 3), activation=act_fn, padding='same')(up7)
    conv7 = Conv2D(128, (3, 3), activation=act_fn, padding='same')(conv7)

    up8 = concatenate([Conv2DTranspose(64, (2, 2), strides=(2, 2), padding='same')(conv7), conv2], axis=3)
    conv8 = Conv2D(64, (3, 3), activation=act_fn, padding='same')(up8)
    conv8 = Conv2D(64, (3, 3), activation=act_fn, padding='same')(conv8)

    up9 = concatenate([Conv2DTranspose(32, (2, 2), strides=(2, 2), padding='same')(conv8), conv1], axis=3)
    conv9 = Conv2D(32, (3, 3), activation=act_fn, padding='same')(up9)
    conv9 = Conv2D(32, (3, 3), activation=act_fn, padding='same')(conv9)

    conv10 = Conv2D(1, (1, 1), activation='sigmoid')(conv9)

    model = Model(inputs=[inputs], outputs=[conv10])
    #model.compile(optimizer = Adam(lr = 1e-4), loss = 'dice', metrics=[dice_coeff])
    
    return model


In [None]:
def build_unet3(act_fn = 'relu', init_fn = 'he_normal', width=512, height = 512, channels = 3): 
    inputs = Input((512,512,3))
    act_fn = 'relu'
    init_fn = 'he_normal'

    # note we use linear function before layer normalization
    conv1 = Conv2D(8, 5, activation = 'linear', padding = 'same', kernel_initializer = init_fn)(inputs)
    conv1 = LayerNormalization()(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
    conv2 = Conv2D(16, 3, activation = act_fn, padding = 'same', kernel_initializer = init_fn)(pool1)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
    conv3 = Conv2D(32, 3, activation = 'linear', padding = 'same', kernel_initializer = init_fn)(pool2)
    conv3 = LayerNormalization()(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
    conv4 = Conv2D(64, 3, activation = act_fn, padding = 'same', kernel_initializer = init_fn)(pool3)
    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)

    conv5 = Conv2D(72, 3, activation = act_fn, padding = 'same', kernel_initializer = init_fn)(pool4)

    up6 = Conv2D(64, 2, activation = 'linear', padding = 'same', kernel_initializer = init_fn)(UpSampling2D(size = (2,2))(conv5))
    up6 = LayerNormalization()(up6)
    merge6 = concatenate([conv4,up6], axis = 3)
    conv6 = Conv2D(64, 3, activation = act_fn, padding = 'same', kernel_initializer = init_fn)(merge6)

    up7 = Conv2D(32, 2, activation = act_fn, padding = 'same', kernel_initializer = init_fn)(UpSampling2D(size = (2,2))(conv6))
    merge7 = concatenate([conv3,up7], axis = 3)
    conv7 = Conv2D(32, 3, activation = act_fn, padding = 'same', kernel_initializer = init_fn)(merge7)

    up8 = Conv2D(16, 2, activation = 'linear', padding = 'same', kernel_initializer = init_fn)(UpSampling2D(size = (2,2))(conv7))
    up8 = LayerNormalization()(up8)
    merge8 = concatenate([conv2,up8], axis = 3)
    conv8 = Conv2D(16, 3, activation = act_fn, padding = 'same', kernel_initializer = init_fn)(merge8)

    up9 = Conv2D(8, 2, activation = act_fn, padding = 'same', kernel_initializer = init_fn)(UpSampling2D(size = (2,2))(conv8))
    merge9 = concatenate([conv1,up9], axis = 3)
    conv9 = Conv2D(8, 3, activation = act_fn, padding = 'same', kernel_initializer = init_fn)(merge9)
    conv10 = Conv2D(1, 1, activation = 'sigmoid')(conv9)
    model = Model(inputs = inputs, outputs = conv10)

    return model


Choose the model you want to try by moving the comments below. I am using (3) going forward

In [None]:
#unet_model = build_unet1(512, 512, 3)
#unet_model = build_unet2(512, 512, 3)
# build_unet3 ignores its arguments and always builds a 512x512x3 model
unet_model = build_unet3()

unet_model.compile(optimizer = Adam(learning_rate = 1e-4), loss = 'dice', metrics=[dice_coeff])
unet_model.summary()

In [None]:
tf.keras.utils.plot_model(unet_model, show_shapes=True)

# Build train and validation datasets
For testing purposes, I am using a very small dataset: only 10 images for training and 5 for validation. Notice that I first build file lists with the tiles we want, then create tf.data.Dataset objects from these lists. It is that simple!
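
Taking consecutive rows is fine for a smoke test, but if the CSV rows are ordered by image, neighbouring tiles may all come from the same slide. A shuffled split is one alternative, sketched below with the standard library. `split_paths` and the example paths are hypothetical; in the notebook the input would be something like `list(gloms_df['gcs_path'])`.

```python
import random

def split_paths(paths, n_train, n_val, seed=42):
    """Shuffle a list of tile paths and return disjoint train/validation lists."""
    if n_train + n_val > len(paths):
        raise ValueError("not enough tiles for the requested split")
    shuffled = list(paths)  # copy so the caller's list is not modified
    random.Random(seed).shuffle(shuffled)
    return shuffled[:n_train], shuffled[n_train:n_train + n_val]

# Hypothetical paths standing in for gloms_df['gcs_path']
all_paths = [f"gs://example-bucket/tile_{i:04d}.tfrecords" for i in range(100)]

train_paths, val_paths = split_paths(all_paths, n_train=10, n_val=5)
print(len(train_paths), len(val_paths))
```

Either list can then be passed to read_images_and_masks in place of the sliced columns.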

In [None]:
small_train = gloms_df[0:10]['gcs_path']
small_test = gloms_df[10:15]['gcs_path']

small_train

In [None]:
train_dataset = read_images_and_masks(small_train)
test_dataset = read_images_and_masks(small_test)

for image, mask in train_dataset.take(1):
    sample_image, sample_mask = image, mask

fig, ax = plt.subplots(1,2,figsize=(20,3))
ax[0].imshow(sample_image[0,:,:,:])
ax[1].imshow(sample_mask[0,:,:])


# Using Keras model.fit
Now we can use model.fit to train a model.

In [None]:
earlystopper = EarlyStopping(patience=5, verbose=1)
checkpointer = ModelCheckpoint('/kaggle/working/model-hubmap.h5', verbose=1)
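# run 1 epoch
# The dataset already yields batches of one tile (shape (1, 512, 512, 3)),
# so batch_size is not passed to fit when the input is a tf.data.Dataset
results = unet_model.fit(train_dataset, epochs=1, callbacks=[checkpointer])
#results = unet_model.fit(train_dataset, epochs=5, validation_data=test_dataset, callbacks=[checkpointer])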

#results = model.fit(X_train, Y_train, validation_split=0.1, batch_size=16, epochs=50, 
#                    callbacks=[earlystopper, checkpointer])

In [None]:
!ls /kaggle/working

In [None]:
# Take one tile from the training set to check the model on data it has already seen
for image, mask in train_dataset.take(1):
    test_image, test_mask = image, mask
test_image.shape

In [None]:
plt.imshow(test_image[0,:,:,:])

In [None]:
unet_model.load_weights("/kaggle/working/model-hubmap.h5")

In [None]:
#reshaped = tf.reshape(test_image,(1,512,512,3))
pred_mask = unet_model.predict(test_image, verbose=1)
pred_mask.shape

Let's have a look at the mask returned

In [None]:
pred_mask[0,:,:,0]

In [None]:
plt.imshow(pred_mask[0,:,:,0])

By looking at the numbers, we see that 0.45 is an interesting threshold to produce a boolean mask. 

In [None]:
#bool_mask = pred_mask[0,:,:,0] > 0.9
bool_mask = pred_mask[0,:,:,0] > 0.45
plt.imshow(bool_mask)
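
Instead of picking a threshold by eye, one option is to check how the fraction of positive pixels changes across a few candidate thresholds. The sketch below uses synthetic probabilities in place of `pred_mask[0,:,:,0]`; `positive_fraction` is a hypothetical helper.

```python
import numpy as np

def positive_fraction(prob_mask, thresholds):
    """Fraction of pixels marked positive at each threshold."""
    return {t: float(np.mean(prob_mask > t)) for t in thresholds}

# Synthetic probabilities standing in for pred_mask[0, :, :, 0]
rng = np.random.default_rng(0)
probs = rng.uniform(0.0, 1.0, size=(512, 512))

fractions = positive_fraction(probs, [0.3, 0.45, 0.5, 0.9])
for t, frac in fractions.items():
    print(f"threshold {t:.2f}: {frac:.3f} of pixels positive")
```

If the true mask_density of comparable tiles is known from the CSV, a threshold that yields a similar positive fraction may be a reasonable starting point.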

This is pretty good for a single epoch with only 10 images!!! But we tried with an image from the train set, so it seems to be overfitting already. Let's try with an image from the test set:

In [None]:
#unet_model.load_weights("/kaggle/working/model-hubmap.h5")

small_test = gloms_df[1000:1005]['gcs_path']
small_dataset = read_images_and_masks(small_test)

test_image = []
test_mask = []
pred_mask = []
for image, mask in small_dataset.take(1):
    test_image, test_mask = image, mask
    pred_mask = unet_model.predict(test_image, verbose=1)
    pred_mask = pred_mask[0,:,:,0] > 0.5
    #pred_mask = (pred_mask > 0.3).astype(np.uint8)
    #pred_mask = (pred_mask > 0.702)
    
fig, ax = plt.subplots(1,3,figsize=(20,3))
ax[0].imshow(test_image[0,:,:,:])
ax[1].imshow(test_mask[0,:,:])
ax[2].imshow(pred_mask)

In [None]:
mask_density = np.count_nonzero(pred_mask)
mask_density

It does pick up the tissue inside the glom, but also some outside. Again, we only trained with 10 images, so this is a good start. In my next notebooks we will scale up the training to thousands of images using TPUs.
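
To put a number on "inside" versus "outside", pixel-wise precision and recall can be computed from the true and predicted boolean masks. This is a minimal sketch on toy masks; `precision_recall` is a hypothetical helper and is not used elsewhere in this notebook.

```python
import numpy as np

def precision_recall(true_mask, pred_mask):
    """Pixel-wise precision and recall for two boolean masks of the same shape."""
    true_mask = np.asarray(true_mask, dtype=bool)
    pred_mask = np.asarray(pred_mask, dtype=bool)
    tp = np.count_nonzero(true_mask & pred_mask)    # predicted and inside the glom
    fp = np.count_nonzero(~true_mask & pred_mask)   # predicted but outside the glom
    fn = np.count_nonzero(true_mask & ~pred_mask)   # glom pixels that were missed
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return precision, recall

# Toy masks: the prediction covers the glom plus a strip outside it
truth = np.zeros((8, 8), dtype=bool)
truth[2:6, 2:6] = True   # 16 glom pixels
pred = np.zeros((8, 8), dtype=bool)
pred[2:6, 2:8] = True    # 24 predicted pixels, 16 of them correct

precision, recall = precision_recall(truth, pred)
print(precision, recall)
```

Low precision with high recall would match the picture above: the glom is found, but the prediction spills over its border. In the notebook, a call along the lines of `precision_recall(test_mask[0].numpy(), pred_mask)` should apply it to the tile shown.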