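Adapted from the TensorFlow image segmentation tutorial (U-Net with a MobileNetV2 encoder and pix2pix decoder): https://www.tensorflow.org/tutorials/images/segmentation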

In [None]:
%%capture
!pip install git+https://github.com/tensorflow/examples.git

In [None]:
import numpy as np 
import pandas as pd
import os
import cv2

import tensorflow as tf
from tensorflow.keras.layers.experimental import preprocessing

from tensorflow_examples.models.pix2pix import pix2pix

from IPython.display import clear_output
import matplotlib.pyplot as plt

In [None]:
DEBUG = False  # True: use a single batch of image ids for training and for validation (quick smoke test)

train_path = '../input/sartorius-cell-instance-segmentation/train/'

SEED = 42
# Seed the global NumPy / TensorFlow RNGs so weight initialisation and random ops are repeatable.
np.random.seed(SEED)
tf.random.set_seed(SEED)

WIDTH, HEIGHT = 704, 520  # native image size in pixels; RLE masks are decoded at this shape
# Network input size. Keep it a multiple of 32 so the encoder skip features and the
# decoder upsampling outputs have matching spatial sizes.
RESIZE_WIDTH, RESIZE_HEIGHT = 512, 512
BATCH_SIZE = 32
BUFFER_SIZE = 32  # shuffle buffer, in images (shuffle is applied before batching)

VAL_SPLIT = 0.2  # fraction of image ids held out for validation

AUTO = tf.data.AUTOTUNE

In [None]:
# Load the annotation table (one RLE `annotation` per row; several rows share an image `id`)
# and split it by image id into training and validation subsets.
train_raw = pd.read_csv('../input/sartorius-cell-instance-segmentation/train.csv')

# Shuffle the sorted ids with a fixed seed. Iterating a `set` of strings depends on
# PYTHONHASHSEED, which would make the split change between kernel sessions.
all_ids = np.random.default_rng(SEED).permutation(sorted(train_raw['id'].unique())).tolist()
n_ids = len(all_ids)

if DEBUG:
    unique_ids_train = all_ids[:BATCH_SIZE]
    unique_ids_valid = all_ids[BATCH_SIZE:2 * BATCH_SIZE]
else:
    n_train_ids = int(n_ids * (1 - VAL_SPLIT))
    unique_ids_train = all_ids[:n_train_ids]
    unique_ids_valid = all_ids[n_train_ids:]

# Both subsets are selected from the untouched full table. Filtering an already-reduced
# `train` for validation ids would return nothing, and validation must not alias `train`.
train = train_raw[train_raw['id'].isin(unique_ids_train)].reset_index(drop=True)
valid = train_raw[train_raw['id'].isin(unique_ids_valid)].reset_index(drop=True)

assert not set(unique_ids_train) & set(unique_ids_valid), 'train/valid image ids overlap'
assert len(valid) > 0, 'validation split is empty'

TRAIN_LENGTH = train['id'].nunique()
# At least one step, so Keras never receives steps_per_epoch / validation_steps of 0.
STEPS_PER_EPOCH = max(1, TRAIN_LENGTH // BATCH_SIZE)

VALID_LENGTH = valid['id'].nunique()
VALIDATION_STEPS = max(1, VALID_LENGTH // BATCH_SIZE)

train.head()

In [None]:
# ref: https://www.kaggle.com/inversion/run-length-decoding-quick-start

def rle_decode(mask_rle, shape, color=1):
    '''
    Decode a run-length-encoded mask string into a dense 2-D array.

    mask_rle: space-separated "start length" pairs; starts are 1-based indices
              into the flattened (row-major) array
    shape: (height, width) of the array to return
    color: value written into every encoded pixel
    Returns a float32 numpy array of `shape`: `color` on encoded pixels, 0 elsewhere.
    '''
    tokens = mask_rle.split()
    run_starts = np.asarray(tokens[0::2], dtype=int) - 1  # 1-based -> 0-based
    run_lengths = np.asarray(tokens[1::2], dtype=int)
    run_ends = run_starts + run_lengths

    flat = np.zeros(shape[0] * shape[1], dtype=np.float32)
    for first, stop in zip(run_starts, run_ends):
        flat[first:stop] = color
    # Default C order: consecutive flat indices run along a row.
    return flat.reshape(shape)

def get_mask(image_id, df):
    """Return the union of all annotated instances of one image as a binary mask.

    image_id: value of df["id"] selecting the image's rows
    df: annotation table with "id" and "annotation" (RLE string) columns
    Returns a float64 (HEIGHT, WIDTH) array: 1.0 on any annotated pixel, 0.0 elsewhere.
    """
    rles = df.loc[df["id"] == image_id, "annotation"]
    union = np.zeros((HEIGHT, WIDTH))
    for rle in rles:
        union += rle_decode(rle, (HEIGHT, WIDTH))
    # Overlapping instances sum to more than 1; clamp back to {0, 1}.
    return np.clip(union, 0, 1)

In [None]:
def train_generator(df):
    """Yield one (image, mask) pair per unique image id in `df`.

    df: annotation table with "id" and "annotation" columns.
    Yields:
        image: float32 (RESIZE_HEIGHT, RESIZE_WIDTH, 3) RGB array scaled to [-1, 1],
               the input range of the ImageNet-pretrained MobileNetV2 encoder.
        mask:  int32 (RESIZE_HEIGHT, RESIZE_WIDTH, 1) binary foreground mask.
    Raises:
        FileNotFoundError: if an image file cannot be read.
    """
    # cv2.resize takes dsize as (width, height); (height, width) only works for square sizes.
    dsize = (RESIZE_WIDTH, RESIZE_HEIGHT)

    # unique() keeps first-appearance order, so iteration is reproducible across sessions
    # (iterating a set of strings depends on PYTHONHASHSEED).
    for image_id in df['id'].unique():
        image_file = os.path.join(train_path, image_id + '.png')
        image = cv2.imread(image_file)
        if image is None:  # cv2.imread signals a missing/unreadable file by returning None
            raise FileNotFoundError(f'Could not read image {image_file}')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = cv2.resize(image, dsize)
        image = image.astype(np.float32) / 127.5 - 1.0

        mask = get_mask(image_id, df)
        # Nearest-neighbour keeps the mask strictly 0/1. Bilinear interpolation would give
        # fractional edge values that the int32 cast truncates to 0, eroding every object.
        mask = cv2.resize(mask, dsize, interpolation=cv2.INTER_NEAREST)
        mask = mask[..., np.newaxis].astype(np.int32)

        yield image, mask

In [None]:
# Element spec shared by the training and validation datasets: (image, mask).
_SAMPLE_SIGNATURE = (
    tf.TensorSpec(shape=(RESIZE_HEIGHT, RESIZE_WIDTH, 3), dtype=tf.float32),
    tf.TensorSpec(shape=(RESIZE_HEIGHT, RESIZE_WIDTH, 1), dtype=tf.int32),
)


def make_dataset(df):
    """Wrap train_generator(df) in an unbatched tf.data.Dataset of (image, mask) pairs.

    `df` is bound when this is called, so later reassignment of a global of the same
    name does not change what the dataset yields.
    """
    return tf.data.Dataset.from_generator(
        lambda: train_generator(df),
        output_signature=_SAMPLE_SIGNATURE)  # replaces deprecated output_types/output_shapes


train_ds = make_dataset(train)
valid_ds = make_dataset(valid)


In [None]:
class Augment(tf.keras.layers.Layer):
    """Randomly flip an (image batch, mask batch) pair left-right.

    Two separate RandomFlip layers are built with the same seed so that, as in the
    TensorFlow segmentation tutorial, they make the same flip decisions and each mask
    stays aligned with its image. NOTE(review): this relies on both layers drawing
    identical random sequences -- confirm on the TF version in use.

    seed: seed shared by both flip layers (defaults to the notebook-wide SEED).
    """

    def __init__(self, seed=SEED):
        super().__init__()
        mode = 'horizontal'
        self.augment_inputs = preprocessing.RandomFlip(mode, seed=seed)
        self.augment_labels = preprocessing.RandomFlip(mode, seed=seed)

    def call(self, inputs, labels):
        """Return (flipped inputs, flipped labels)."""
        return self.augment_inputs(inputs), self.augment_labels(labels)

In [None]:
# Training pipeline: shuffle images, batch, repeat indefinitely (epoch length comes from
# steps_per_epoch in model.fit), randomly flip image/mask pairs, prefetch.
# NOTE: this cell rebinds train_ds/valid_ds, so re-running it without re-running the
# from_generator cell would batch the data twice.
train_ds = (
    train_ds
    .shuffle(BUFFER_SIZE, seed=SEED)  # seeded for reproducible epoch ordering
    .batch(BATCH_SIZE)
    .repeat()
    .map(Augment())  # kept sequential: the paired flip layers must stay in step
    .prefetch(AUTO))

# Validation data is neither shuffled nor augmented.
valid_ds = (
    valid_ds
    .batch(BATCH_SIZE)
    .repeat()
    .prefetch(AUTO))

In [None]:
def display(display_list):
    """Plot up to three images side by side in one 20x20-inch figure.

    display_list: image-like arrays/tensors in the order
        [input image, true mask, predicted mask]; titles are assigned by position.
    Note: this name shadows IPython's built-in `display` in the notebook namespace.
    """
    panel_titles = ('Input Image', 'True Mask', 'Predicted Mask')
    n_panels = len(display_list)
    fig = plt.figure(figsize=(20, 20))
    for idx in range(n_panels):
        ax = fig.add_subplot(1, n_panels, idx + 1)
        ax.set_title(panel_titles[idx])
        # array_to_img converts the array/tensor to a PIL image, rescaling values for display.
        ax.imshow(tf.keras.preprocessing.image.array_to_img(display_list[idx]))
        ax.axis('off')
    plt.show()


In [None]:
# Show the first example of each of the first two training batches (already shuffled and
# flipped by the pipeline). sample_image / sample_mask from the last batch are reused below.
for images, masks in train_ds.take(2):
    sample_image = images[0]
    sample_mask = masks[0]
    display([sample_image, sample_mask])


In [None]:
# Sanity check: masks are binary, so this should be 1 whenever the sample contains a cell.
sample_mask.numpy().max()

In [None]:
# Pretrained MobileNetV2 (no classification head) used as the U-Net encoder.
base_model = tf.keras.applications.MobileNetV2(input_shape=[RESIZE_HEIGHT, RESIZE_WIDTH, 3], include_top=False)

# Use the activations of these layers as skip connections. Sizes are relative to the input;
# values in brackets are for the 512x512 input configured above.
layer_names = [
    'block_1_expand_relu',   # H/2  (256x256)
    'block_3_expand_relu',   # H/4  (128x128)
    'block_6_expand_relu',   # H/8  (64x64)
    'block_13_expand_relu',  # H/16 (32x32)
    'block_16_project',      # H/32 (16x16), deepest feature: input to the decoder
]

base_model_outputs = [base_model.get_layer(name).output for name in layer_names]

# Feature-extraction model returning all five activations, shallow to deep.
down_stack = tf.keras.Model(inputs=base_model.input, outputs=base_model_outputs)

# Freeze the encoder: only the decoder is trained.
down_stack.trainable = False

In [None]:
# Decoder: each pix2pix.upsample block doubles the spatial size. Values in brackets
# are for the 512x512 input configured above.
up_stack = [
    pix2pix.upsample(512, 3),  # H/32 -> H/16 (16x16 -> 32x32)
    pix2pix.upsample(256, 3),  # H/16 -> H/8  (32x32 -> 64x64)
    pix2pix.upsample(128, 3),  # H/8  -> H/4  (64x64 -> 128x128)
    pix2pix.upsample(64, 3),   # H/4  -> H/2  (128x128 -> 256x256)
]

In [None]:
def unet_model(output_channels: int):
    """Build the U-Net: frozen encoder `down_stack` + upsampling decoder `up_stack`.

    output_channels: number of channels of the final sigmoid-activated output map.
    Returns an uncompiled tf.keras.Model mapping (RESIZE_HEIGHT, RESIZE_WIDTH, 3) images
    to per-pixel probabilities with `output_channels` channels.
    """
    image_input = tf.keras.layers.Input(shape=[RESIZE_HEIGHT, RESIZE_WIDTH, 3])

    # Encoder features, shallow (high resolution) to deep; the deepest seeds the decoder.
    *skip_features, x = down_stack(image_input)

    # Decoder: upsample, then concatenate the encoder feature of matching resolution.
    for upsample_block, skip in zip(up_stack, skip_features[::-1]):
        x = tf.keras.layers.Concatenate()([upsample_block(x), skip])

    # Final stride-2 transposed conv restores the input resolution (H/2 -> H).
    head = tf.keras.layers.Conv2DTranspose(
        filters=output_channels, kernel_size=3, strides=2,
        padding='same', activation='sigmoid')

    return tf.keras.Model(inputs=image_input, outputs=head(x))

In [None]:
# Use tf.keras rather than the standalone `keras` package: mixing the two can break.
from tensorflow.keras.losses import binary_crossentropy
import tensorflow.keras.backend as K

def dice_loss(y_true, y_pred, smooth=1.):
    '''Soft Dice loss: 1 - (2*sum(t*p) + smooth) / (sum(t) + sum(p) + smooth).

    y_true: ground-truth mask (any numeric dtype; cast to y_pred's dtype)
    y_pred: predicted probabilities, same shape as y_true
    smooth: additive smoothing; keeps the ratio defined when both masks are empty
    Returns a scalar tensor.
    '''
    # Everything is flattened, batch included, so one Dice score covers the whole batch.
    y_true_f = K.flatten(tf.cast(y_true, y_pred.dtype))
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    score = (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)
    return 1. - score

def bce_dice_loss(y_true, y_pred, dice_weight=0.5):
    '''Per-pixel binary cross-entropy plus `dice_weight` times the batch Dice loss.'''
    y_true = tf.cast(y_true, y_pred.dtype)  # labels arrive as int32 from the input pipeline
    return binary_crossentropy(y_true, y_pred) + dice_weight * dice_loss(y_true, y_pred)

In [None]:
OUTPUT_CLASSES = 1  # one sigmoid channel: per-pixel foreground probability

model = unet_model(output_channels=OUTPUT_CLASSES)
compile_config = dict(
    optimizer='adam',
    loss=bce_dice_loss,    # BCE + 0.5 * Dice
    metrics=['accuracy'],  # NOTE(review): pixel accuracy is dominated by background; consider IoU/Dice
)
model.compile(**compile_config)

In [None]:
# Render the architecture graph with tensor shapes (requires pydot and graphviz).
tf.keras.utils.plot_model(model=model, show_shapes=True)


In [None]:
def create_mask(pred_mask, threshold=0.5):
    """Binarise predicted probabilities into a {0, 1} int32 mask.

    pred_mask: predicted probabilities (any shape), e.g. one example from model.predict
    threshold: probability above which a pixel counts as foreground
    Returns an int32 tensor of the same shape: 1 where pred_mask > threshold, else 0.
    """
    return tf.where(pred_mask > threshold, 1, 0)


In [None]:
def show_predictions(dataset=None, num=1):
    """Plot input image, ground-truth mask and thresholded prediction.

    dataset: batched dataset of (images, masks). The first example of each of the first
        `num` batches is shown. If None, the global sample_image / sample_mask are used
        and `num` is ignored.
    """
    # Explicit None check instead of relying on the truthiness of a tf.data.Dataset.
    if dataset is not None:
        for image, mask in dataset.take(num):
            pred_mask = model.predict(image)
            display([image[0], mask[0], create_mask(pred_mask[0])])
    else:
        prediction = model.predict(sample_image[tf.newaxis, ...])[0]
        display([sample_image, sample_mask, create_mask(prediction)])

In [None]:
# Baseline before training: the encoder is pretrained but the decoder is untrained.
show_predictions(dataset=train_ds, num=1)

In [None]:
class DisplayCallback(tf.keras.callbacks.Callback):
    """Keras callback that plots a prediction on the global sample image after every epoch."""

    def on_epoch_end(self, epoch, logs=None):
        # Keras passes a 0-based epoch index; report it 1-based.
        show_predictions()
        print(f'\nSample Prediction after epoch {epoch + 1}\n')


In [None]:
EPOCHS = 100

display_cb = DisplayCallback()
# Keep the checkpoint with the lowest validation loss. The model uses the custom
# bce_dice_loss, so reloading it requires custom_objects={'bce_dice_loss': bce_dice_loss}.
model_checkpoint = tf.keras.callbacks.ModelCheckpoint(
    'best_model/',
    monitor='val_loss',
    save_best_only=True,
    save_weights_only=False,
)
lr_reduce = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss')
# restore_best_weights: leave the in-memory model at its best epoch, not its last one.
es = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=15, restore_best_weights=True)

# Both datasets repeat indefinitely, so the step counts define an epoch.
model_history = model.fit(train_ds, epochs=EPOCHS,
                          steps_per_epoch=STEPS_PER_EPOCH,
                          validation_steps=VALIDATION_STEPS,
                          validation_data=valid_ds,
                          callbacks=[display_cb, model_checkpoint, lr_reduce, es])
