![logo](https://cdn.freelogovectors.net/wp-content/uploads/2018/07/tensorflow-logo.png)

In [None]:
import numpy as np
import pandas as pd
import tensorflow as tf
from functools import partial
import json
import glob
import pickle
import matplotlib.pyplot as plt
from kaggle_datasets import KaggleDatasets

# Intro
In this notebook we create a TensorFlow dataset for segmentation. The dataset yields (image, mask) pairs, and the dataset code runs on CPU, GPU, and TPU.

References:
  * [Sartorius: Create Mask Dataset](https://www.kaggle.com/mistag/sartorius-create-mask-dataset)

# Dataset creation
We already created a mask dataset [here](https://www.kaggle.com/mistag/sartorius-create-mask-dataset). The next step is to create a TF dataset, which is quite straightforward. Because the dataset fits in memory, TPU can be supported without a private GCS bucket.  
There are three types of cells, but only one type per image, so the masks can be binary.

In [None]:
with open('../input/sartorius-create-mask-dataset/mask_dict.pkl', 'rb') as f:
    mask_dict = pickle.load(f)

The dict format:

In [None]:
mask_dict['0030fd0e6378']
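
The later cells read `m_values[i]['mask']` and append a channel axis to it, which suggests that each entry holds a 2-D `uint8` array under the key `'mask'`. A minimal sketch of inspecting such an entry with a synthetic stand-in follows. The names `example_dict` and `'example_id'`, the 0/1 value range, and the absence of other keys are assumptions; the real entries may differ.

```python
import numpy as np

IMAGE_SIZE = (520, 704)  # height, width, as used later in the notebook

# Hypothetical stand-in for one mask_dict entry (assumed layout, not the real data).
example_dict = {"example_id": {"mask": np.zeros(IMAGE_SIZE, dtype=np.uint8)}}
example_dict["example_id"]["mask"][100:150, 200:260] = 1  # fake rectangular "cell"


def describe_entry(entry):
    """Summarise the mask array stored in one dictionary entry."""
    mask = entry["mask"]
    return {
        "shape": mask.shape,
        "dtype": str(mask.dtype),
        "values": np.unique(mask).tolist(),
        "foreground_fraction": float(mask.mean()),
    }


summary = describe_entry(example_dict["example_id"])
print(summary)
```

Running the same helper on `mask_dict['0030fd0e6378']` would show whether the real masks use 0/1 or some other pair of values.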

Select the hardware strategy (works on CPU, GPU, and TPU).

In [None]:
# Function to get hardware strategy
def get_hardware_strategy():
    try:
        # TPU detection. No parameters necessary if TPU_NAME environment variable is
        # set: this is always the case on Kaggle.
        tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
        print('Running on TPU ', tpu.master())
    except ValueError:
        tpu = None

    if tpu:
        tf.config.experimental_connect_to_cluster(tpu)
        tf.tpu.experimental.initialize_tpu_system(tpu)
        strategy = tf.distribute.TPUStrategy(tpu)  # non-experimental name, available from TF 2.3
        tf.config.optimizer.set_jit(True)
    else:
        # Default distribution strategy in TensorFlow. Works on CPU and single GPU.
        strategy = tf.distribute.get_strategy()

    return tpu, strategy

tpu, strategy = get_hardware_strategy()
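
When a batched dataset is handed to Keras under a distribution strategy, the batch size set on the dataset is normally treated as the global batch and split across `strategy.num_replicas_in_sync` replicas. The sketch below shows only that arithmetic in plain Python; the helper names are hypothetical, and the replica count of 8 is an assumption about a typical TPU configuration.

```python
def global_batch_size(per_replica_batch, num_replicas):
    """Global batch needed so that each replica sees `per_replica_batch` samples."""
    if per_replica_batch < 1 or num_replicas < 1:
        raise ValueError("both arguments must be positive")
    return per_replica_batch * num_replicas


def per_replica_batch_size(global_batch, num_replicas):
    """Samples each replica receives when a global batch is split evenly."""
    if global_batch % num_replicas != 0:
        raise ValueError("global batch must be divisible by the replica count")
    return global_batch // num_replicas


print(global_batch_size(8, 1))       # default strategy: 8
print(global_batch_size(8, 8))       # assumed 8 replicas: 64
print(per_replica_batch_size(8, 8))  # a global batch of 8 leaves 1 sample per replica
```

In the notebook this would correspond to something like `BATCH_SIZE = 8 * strategy.num_replicas_in_sync` if eight samples per replica are wanted.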

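Check whether we are running on TPU and set the data path accordingly. The TPU strategy reports more than one replica, while the default strategy reports exactly one.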

In [None]:
if strategy.num_replicas_in_sync > 1:
    DS_PATH = KaggleDatasets().get_gcs_path('sartorius-cell-instance-segmentation') # TPU
else:
    DS_PATH = '../input/sartorius-cell-instance-segmentation' # GPU or CPU

Dataset creation. Note that the Python dictionary created [here](https://www.kaggle.com/mistag/sartorius-create-mask-dataset) is converted to a TensorFlow hash table, keyed by image ID, so that masks can be looked up inside the `tf.data` pipeline.

In [None]:
train_files = tf.io.gfile.glob(DS_PATH+'/train/*.png')

AUTOTUNE = tf.data.experimental.AUTOTUNE
BATCH_SIZE = 8 # adapt to needs

IMAGE_SIZE = [520, 704]
GRAYSCALE = True # set to False to replicate the single channel into 3-channel RGB images

# convert mask dictionary to TF hash table of encoded PNG masks
# (an all-zero mask is returned for image IDs missing from the dictionary)
blank = tf.io.encode_png(np.zeros((IMAGE_SIZE[0], IMAGE_SIZE[1], 1), dtype=np.uint8))
masks = [tf.io.encode_png(np.expand_dims(v['mask'], -1)) for v in mask_dict.values()]
mask_init = tf.lookup.KeyValueTensorInitializer(list(mask_dict), masks)
mask_table = tf.lookup.StaticHashTable(mask_init, default_value=blank)

# read a .png file and fetch the matching mask from the lookup table
def _read_png(filename):
    img = tf.io.read_file(filename)
    # image ID = file name without directory and extension
    iid = tf.strings.split(tf.strings.split(filename, '/')[-1], '.')[0]
    mask = tf.io.decode_png(mask_table.lookup(iid), channels=1)
    mask = tf.reshape(mask, [*IMAGE_SIZE, 1]) # give the mask a static shape, as for the image
    img = tf.image.decode_png(img, channels=1) # already [H, W, 1], no extra axis needed
    img = tf.cast(img, tf.float32)
    img = img / 255.
    if not GRAYSCALE:
        img = tf.image.grayscale_to_rgb(img)
        img = tf.reshape(img, [*IMAGE_SIZE, 3])
    else:
        img = tf.reshape(img, [*IMAGE_SIZE, 1])
    return img, mask

train_ds = tf.data.Dataset.from_tensor_slices(train_files)
train_ds = train_ds.shuffle(len(train_files)) # shuffle file names (cheap) rather than decoded images
train_ds = train_ds.map(_read_png, num_parallel_calls=AUTOTUNE)
train_ds = train_ds.batch(BATCH_SIZE, drop_remainder=True)
train_ds = train_ds.repeat()
train_ds = train_ds.prefetch(buffer_size=AUTOTUNE) # prefetch last so complete batches are prepared ahead
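
The link between an image file and its mask is the image ID: the file name with the directory and extension stripped, which is then looked up in the hash table with a blank mask as the fallback. A plain-Python sketch of the same logic may make this easier to follow. The bucket path, the helper names, and the string placeholders standing in for encoded PNG bytes are all hypothetical.

```python
def image_id_from_path(path):
    """Mirror of the tf.strings.split calls: last path component, text before the first dot."""
    return path.split("/")[-1].split(".")[0]


def lookup_mask(encoded_masks, image_id, blank):
    """Dictionary analogue of StaticHashTable.lookup with a default value."""
    return encoded_masks.get(image_id, blank)


# Placeholder strings stand in for the encoded PNG bytes held in the real table.
encoded_masks = {"0030fd0e6378": "encoded-mask"}
example_path = "gs://hypothetical-bucket/train/0030fd0e6378.png"

image_id = image_id_from_path(example_path)
print(image_id)
print(lookup_mask(encoded_masks, image_id, "blank"))
print(lookup_mask(encoded_masks, "unknown-id", "blank"))
```

The fallback means that an image without an entry in the dictionary silently receives an empty mask, so it can be worth confirming that every file name has a matching key.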

# Test dataset
Finally, test the dataset by fetching one batch and plotting the image and mask pairs.

In [None]:
# fetch a batch
image_batch, mask_batch = next(iter(train_ds))
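
Because the dataset ends with `repeat()`, it never runs out of batches, so a training loop needs an explicit step count to define an epoch. A minimal sketch of that calculation follows; the helper name and the image count of 100 are illustrative assumptions, not values from the competition data.

```python
def steps_per_epoch(num_images, batch_size):
    """Number of full batches in one pass over the data when the remainder is dropped."""
    if num_images < 0 or batch_size < 1:
        raise ValueError("num_images must be >= 0 and batch_size >= 1")
    return num_images // batch_size


steps = steps_per_epoch(num_images=100, batch_size=8)
print(steps)  # 12 full batches; the last 4 images are dropped in that pass
```

With the notebook's variables this would be `steps_per_epoch(len(train_files), BATCH_SIZE)`, and the result can be passed to the `steps_per_epoch` argument of Keras `model.fit`.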

In [None]:
fig = plt.figure(figsize=(16,48))
for i in range(BATCH_SIZE):
    # left column: image
    axes = fig.add_subplot(BATCH_SIZE, 2, 2*i+1)
    plt.setp(axes, xticks=[], yticks=[])
    plt.imshow(image_batch[i].numpy(), cmap='gray')
    # right column: corresponding mask
    axes = fig.add_subplot(BATCH_SIZE, 2, 2*i+2)
    plt.setp(axes, xticks=[], yticks=[])
    plt.imshow(mask_batch[i].numpy())
plt.tight_layout() # call once, after all subplots exist
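
Besides the visual check, a few numeric assertions can catch shape or scaling mistakes. The sketch below works on NumPy arrays and is demonstrated on small synthetic data. The helper `check_batch` is hypothetical, and the expected layout (images scaled to [0, 1], masks with one channel and at most two distinct values) is an assumption based on the pipeline above.

```python
import numpy as np


def check_batch(images, masks, batch_size, image_size, channels=1):
    """Return a list of problems found in one (image, mask) batch; empty if none."""
    images = np.asarray(images)
    masks = np.asarray(masks)
    problems = []
    if images.shape != (batch_size, *image_size, channels):
        problems.append(f"unexpected image shape {images.shape}")
    if masks.shape != (batch_size, *image_size, 1):
        problems.append(f"unexpected mask shape {masks.shape}")
    if images.size and (images.min() < 0.0 or images.max() > 1.0):
        problems.append("image values outside [0, 1]")
    if np.unique(masks).size > 2:
        problems.append("mask has more than two distinct values")
    return problems


# Small synthetic batch standing in for image_batch / mask_batch.
rng = np.random.default_rng(0)
fake_images = rng.random((2, 4, 6, 1), dtype=np.float32)
fake_masks = np.zeros((2, 4, 6, 1), dtype=np.uint8)
fake_masks[:, 1:3, 2:4, :] = 1

print(check_batch(fake_images, fake_masks, batch_size=2, image_size=(4, 6)))
print(check_batch(fake_images, fake_masks[..., 0], batch_size=2, image_size=(4, 6)))
```

For the real batch the call would look like `check_batch(image_batch.numpy(), mask_batch.numpy(), BATCH_SIZE, IMAGE_SIZE)`.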