In [None]:
import os
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import cv2
import tensorflow.keras.backend as K
from sklearn.model_selection import train_test_split
from IPython.display import clear_output


import warnings

warnings.filterwarnings("ignore")

# Sartorius Competition 

**Overview**

In the Sartorius Cell Instance Segmentation challenge the task is to perform instance segmentation on microscopy images in order to find neuronal cells. Successful segmentation could lead to the discovery of new drugs and treatments for neurodegenerative diseases such as Alzheimer's and for brain tumors.

**Data**

<code>train.csv</code> file provided to competitors contains the following columns 
- <code>id</code> - unique identifier
- <code>annotation</code> - run length encoded pixels for the identified neuronal cell
- <code>width</code> - image width
- <code>height</code> - image height
- <code>cell_type</code> - class of the cell
- <code>plate_time</code> - time the plate was created
- <code>sample_id</code> - id of the sample
- <code>sample_date</code> - date the sample was created
- <code>elapsed_timedelta</code> - time since the image was shot


<code>ids</code> present in the column <code>id</code> correspond to image files in <code>train</code> folder

<code>annotations</code> are run length encoded pixels in order to create a mask we will have to decode the annotations

In [None]:
class Config:
    '''
    Central place for hyperparameters and frequently used values (paths, image shape, ...).

    Everything except DEBUG is a class attribute, so it can be read either from the
    class (Config.X) or from an instance (config.X).
    '''
    # --- Paths ---
    TRAIN_CSV = '../input/sartorius-cell-instance-segmentation/train.csv'
    TRAIN_DIR = '../input/sartorius-cell-instance-segmentation/train/'
    TEST_DIR = '../input/sartorius-cell-instance-segmentation/test/'
    WEIGHTS_PATH = os.path.join('./', 'model.h5')  # checkpoint file for the model weights

    # --- Data ---
    IMG_SHAPE = (512, 512)  # size images and masks are resized to
    N_CLASSES = 1           # number of output channels of the model
    val_size = 0.1          # fraction of image ids held out for validation
    seed = 123              # seed for the train/validation split and augmentation

    # --- Input pipeline ---
    BATCH_SIZE = 4
    BUFFER_SIZE = 2         # shuffle buffer size
    AUTOTUNE = tf.data.AUTOTUNE

    # --- Model and optimisation ---
    N_FILTERS = 32          # filters of the first encoder block
    LR = 1e-3
    EPOCHS = 100

    def __init__(self, DEBUG=False):
        self.DEBUG = DEBUG


config = Config()

# Data Loading And Exploration

First, let's load <code>train.csv</code> into a pandas dataframe and take a look at our data

In [None]:
# Read the annotation table and preview its first rows.
train_csv = pd.read_csv(Config.TRAIN_CSV)
train_csv.head()

First thing I've noticed is that the same ID appears several times with different annotations, it means that each annotation is only for one cell, and that the annotations might overlap 

**Let's take a look at a single image**

In [None]:
# Show one training image at a large size, without axis ticks.
plt.figure(figsize=(15,15))

plt.imshow(cv2.imread(config.TRAIN_DIR + '0030fd0e6378' + '.png'))
plt.axis("off")
plt.show()

Knowing that ids might appear several times let's see how large our dataset is

In [None]:
# (rows, columns) of the annotation table.
train_csv.shape

But how large is it when we take into the account that id is unique?

In [None]:
# Number of distinct ids, i.e. of distinct images.
train_csv["id"].unique().shape

Let's see how our dataframe looks for a single id

In [None]:
# All annotation rows that belong to one image id.
train_csv[train_csv['id'] == '0030fd0e6378']

# Load Data

Masks are encoded in the annotation column by an algorithm called **Run Length Encoding**. RLE stores a mask as a sequence of (start, length) pairs: each start is a 1-based index into the flattened mask, and each length is the number of consecutive mask pixels beginning at that index. For a more in-depth understanding I recommend looking into [this](https://www.kaggle.com/ihelon/cell-segmentation-run-length-decoding) notebook

In [None]:
def rle_decode(mask_rle, shape, color=1):
    '''
    Decode a run-length encoded annotation into a dense mask.

    mask_rle: run-length string formatted as "start length start length ...",
              where each start is a 1-based index into the flattened image
    shape: (height, width) or (height, width, channels) of the array to return
    color: value written to every annotated pixel; a scalar, or one value per channel
    Returns float32 numpy array of the given shape, color - mask, 0 - background

    '''
    tokens = np.asarray(mask_rle.split(), dtype=int)
    # Even positions hold starts, odd positions hold run lengths.
    starts = tokens[0::2] - 1  # RLE starts are 1-based, numpy indices are 0-based
    ends = starts + tokens[1::2]

    # Work on a flattened (pixels[, channels]) view so each run is a single slice.
    # shape[2:] is empty for a 2-D shape, which yields a plain 1-D pixel vector.
    flat_shape = (shape[0] * shape[1],) + tuple(shape[2:])
    img = np.zeros(flat_shape, dtype=np.float32)
    for lo, hi in zip(starts, ends):
        img[lo:hi] = color
    return img.reshape(shape)


def build_masks(labels, input_shape, colors=True):
    '''
    Merge the run-length encoded annotations of one image into a single mask.

    labels: iterable of run-length encoded strings, one per annotation
    input_shape: (height, width) of the image
    colors: True - 3-channel mask, every annotation drawn in a random colour
            False - 1-channel mask
    Returns numpy array of shape (height, width, 3) or (height, width, 1) clipped to [0, 1]
    '''
    height, width = input_shape
    n_channels = 3 if colors else 1
    mask = np.zeros((height, width, n_channels))

    for label in labels:
        if colors:
            mask += rle_decode(label, shape=(height, width, 3), color=np.random.rand(3))
        else:
            mask += rle_decode(label, shape=(height, width, 1))

    # Overlapping annotations add up, so bring the result back into [0, 1].
    return mask.clip(0, 1)

# Data Visualization

Now knowing how to load our data we can finally see how the annotations for neuronal cells look like

Let's see how many classes there are for the <code>cell_type</code> column and what the distribution looks like

In [None]:
# Distinct values of the cell_type column.
train_csv["cell_type"].unique()

In [None]:
# value_counts() counts rows of train_csv, i.e. annotations (not images), per cell type.
cell_types = train_csv["cell_type"].value_counts()

plt.figure(figsize=(10, 6), tight_layout=True)

plt.bar(cell_types.index, cell_types.values)
# Label the figure so it can be read on its own.
plt.xlabel("Cell type")
plt.ylabel("Number of annotations")
plt.title("Annotations per cell type")
plt.show()

The distribution is clearly skewed: **shsy5y** cells dominate the dataset, which might cause our model to perform worse on the other cell types

Data augmentation should be performed on the dataset in order to fight this issue

In [None]:
# Pick two distinct images per cell type.
# train_csv repeats an id once per annotation, so ids are de-duplicated before sampling;
# otherwise the same image could be drawn twice. random_state makes the choice repeatable.
shy5y_sample = (
    train_csv.loc[train_csv['cell_type'] == 'shsy5y', 'id']
    .drop_duplicates()
    .sample(2, random_state=config.seed)
)
cort_sample = (
    train_csv.loc[train_csv['cell_type'] == 'cort', 'id']
    .drop_duplicates()
    .sample(2, random_state=config.seed)
)
astro_sample = (
    train_csv.loc[train_csv['cell_type'] == 'astro', 'id']
    .drop_duplicates()
    .sample(2, random_state=config.seed)
)

**Let's see how the annotations look like**

In [None]:
def display_sample(sample_ids, n_samples=2, hspace=-0.6):
    '''
    Function to visualize images and their annotations

    Every row of the figure shows one sample: the image, its mask with one random
    colour per annotation, and its single-channel mask.

    sample_ids - iterable of image ids; must not contain more than n_samples ids
    n_samples - number of rows in the figure
    hspace - parameter for matplotlib, it controls spacing between images
    '''
    # squeeze=False keeps axs two-dimensional even when n_samples == 1
    fig, axs = plt.subplots(n_samples, 3, figsize=(22, 25), squeeze=False)

    for idx, sample_id in enumerate(sample_ids):
        image_path = os.path.join(config.TRAIN_DIR, sample_id + '.png')
        sample_image = cv2.imread(image_path)
        if sample_image is None:
            # cv2.imread signals an unreadable file by returning None instead of raising
            raise FileNotFoundError(image_path)

        sample_rles = train_csv.loc[train_csv['id'] == sample_id, 'annotation'].values

        # The annotations index into the original image, so decode at the image's own size
        mask_shape = sample_image.shape[:2]
        sample_mask_colors = build_masks(sample_rles, mask_shape, colors=True)
        sample_mask = build_masks(sample_rles, mask_shape, colors=False)

        for col, picture in enumerate((sample_image, sample_mask_colors, sample_mask)):
            axs[idx][col].imshow(picture)
            axs[idx][col].axis('off')

    axs[0][0].set_title("Image", fontsize=16)
    axs[0][1].set_title("Mask Color", fontsize=16)
    axs[0][2].set_title("Mask", fontsize=16)

    fig.subplots_adjust(hspace=hspace)
    plt.show()

**SHY5Y**

In [None]:
# Images and masks for the sampled shsy5y ids.
display_sample(shy5y_sample)

**CORT**

In [None]:
# Images and masks for the sampled cort ids.
display_sample(cort_sample)

**ASTRO**

In [None]:
# Images and masks for the sampled astro ids.
display_sample(astro_sample)

# Preparing training and validation datasets

In order to train the neural network we have to load the data and optimize it for training

first train dataset has to be split into training and validation parts

**Validation dataset** is used to check how well your model is doing before performing predictions on the test dataset. Having a validation dataset helps to detect whether the model is overfitting

In [None]:
# Split on unique image ids (not on annotation rows) so that all annotations of an
# image end up on the same side of the split.
ids = train_csv['id'].unique()

train_ids, val_ids = train_test_split(ids, test_size=config.val_size, random_state=config.seed)

In [None]:
# Number of image ids in the training split.
train_ids.shape

Tensorflow provides the <code>tf.data.Dataset</code> API, which is very useful when creating data pipelines for ML models: it supports caching and batching, and data preprocessing can be performed with the <code>.map</code> method of a <code>Dataset</code> object.

Tensorflow docs - https://www.tensorflow.org/guide/data

In [None]:
def _load_image_and_mask(image_id):
    '''
    Load one training image and build its mask, both resized to config.IMG_SHAPE

    Returns (image, mask): float32 image with 3 channels and int32 mask with 1 channel
    '''
    rows = train_csv.loc[train_csv['id'] == image_id]
    image = tf.io.read_file(config.TRAIN_DIR + image_id + '.png')
    image = tf.image.decode_image(image, channels=3, dtype=tf.float32)
    rles = rows['annotation'].values

    mask = build_masks(rles, (520, 704), colors=False)
    # Nearest-neighbour keeps the mask binary. Bilinear resizing produces fractional
    # values along cell borders which the int cast would truncate to 0, eroding the mask.
    mask = tf.cast(tf.image.resize(mask, config.IMG_SHAPE, method='nearest'), tf.int32)

    image = tf.image.resize(image, config.IMG_SHAPE)
    # NOTE(review): per the TF docs decode_image with a float dtype already scales pixels
    # to [0, 1], so this division appears to shrink them to [0, 1/255]. It is kept because
    # the test loader and the display cells (which multiply by 255) rely on the same
    # scaling; change all of them together.
    image /= 255.0

    return image, mask


def load_train_ds():
    '''
    This function creates a generator for train dataset
    '''
    for image_id in train_ids:
        yield _load_image_and_mask(image_id)


def load_val_ds():
    '''
    This function creates a generator for validation dataset
    '''
    for image_id in val_ids:
        yield _load_image_and_mask(image_id)

Having created functions that yield images and masks we can use <code>from_generator</code> method to create datasets from generators

In [None]:
# output_signature replaces the deprecated output_types argument and additionally gives
# the datasets static shapes: 3-channel float image, 1-channel int mask.
train_ds = tf.data.Dataset.from_generator(
    load_train_ds,
    output_signature=(
        tf.TensorSpec(shape=(*config.IMG_SHAPE, 3), dtype=tf.float32),
        tf.TensorSpec(shape=(*config.IMG_SHAPE, 1), dtype=tf.int32),
    ),
)

val_ds = tf.data.Dataset.from_generator(
    load_val_ds,
    output_signature=(
        tf.TensorSpec(shape=(*config.IMG_SHAPE, 3), dtype=tf.float32),
        tf.TensorSpec(shape=(*config.IMG_SHAPE, 1), dtype=tf.int32),
    ),
)

As I've mentioned before, the dataset is severely imbalanced, which can cause our model to perform worse in some cases. One popular and simple method to fight this problem is **data augmentation**.

**Data Augmentation** is a method to artificially expand a dataset by slightly modifying existing data in a realistic way. In my case I will flip the images horizontally and vertically

Tensorflow docs - https://www.tensorflow.org/tutorials/images/data_augmentation

In [None]:
def augment_ds(image, mask):
    '''
    Randomly flip an image and its mask, always applying the same flips to both

    image, mask: tensors of one sample (height, width, channels)
    Returns the (possibly flipped) image and mask
    '''
    # One explicit decision per flip, shared by image and mask. Previously alignment
    # relied on two separate random ops producing the same value from the same seed.
    # The two flips use different seeds so they are decided independently of each other.
    flip_vertical = tf.random.uniform([], seed=config.seed) < 0.5
    flip_horizontal = tf.random.uniform([], seed=config.seed + 1) < 0.5

    image, mask = tf.cond(
        flip_vertical,
        lambda: (tf.image.flip_up_down(image), tf.image.flip_up_down(mask)),
        lambda: (image, mask),
    )
    image, mask = tf.cond(
        flip_horizontal,
        lambda: (tf.image.flip_left_right(image), tf.image.flip_left_right(mask)),
        lambda: (image, mask),
    )

    return image, mask

In [None]:
# Training pipeline, built one stage at a time:
# shuffle -> random flips -> batches -> prefetch of upcoming batches.
train_ds = train_ds.shuffle(config.BUFFER_SIZE)
train_ds = train_ds.map(augment_ds)
train_ds = train_ds.batch(config.BATCH_SIZE)
train_ds = train_ds.prefetch(Config.AUTOTUNE)

# Validation data is only batched: no shuffling and no augmentation.
val_ds = val_ds.batch(config.BATCH_SIZE)

In [None]:
# Take the first training batch and show every image next to its mask.
images, masks = sample_batch = next(iter(train_ds))

fig, ax = plt.subplots(config.BATCH_SIZE, 2, figsize=(20, 20))

for i in range(config.BATCH_SIZE):
    # Left column: image (rescaled for display), right column: its mask.
    ax[i][0].imshow(images[i] * 255)
    ax[i][0].set_axis_off()

    ax[i][1].imshow(masks[i])
    ax[i][1].set_axis_off()

plt.show()

# Modeling

**UNET** is a conv net architecture proposed by Olaf Ronneberger, Philipp Fischer and Thomas Brox in their paper [U-Net: Convolutional Networks for Biomedical Image Segmentation](https://arxiv.org/pdf/1505.04597.pdf). It has been very successful in performing semantic segmentation on many benchmarks. The architecture is composed of encoder and decoder networks with a bottleneck in between. Let's see a visualization from the authors.

<img src='https://miro.medium.com/max/680/1*TXfEPqTbFBPCbXYh2bstlA.png'/>

**The encoder** is composed of conv blocks, each with two 3x3 conv layers followed by max pooling with a pool size of 2. There are 4 such blocks in total, with 64, 128, 256 and 512 filters respectively

**The bottleneck** is a simple conv block of two 3x3 conv layers with 1024 filters

**The decoder** consists of 4 upsampling conv blocks, each using a transposed conv layer with a 2x2 kernel and a stride of 2. After upsampling, the skip connection is concatenated, and finally two 3x3 conv layers are applied

The key idea of UNET is **skip connections**. The output of each encoder block is concatenated to the input of the corresponding decoder block. This preserves the spatial structure of the input image, since upsampling in the decoder alone yields imprecise expansions. Adding the output of the encoder block helps to preserve a lot of information.

# Residual UNET 
**Residual UNET** simply is a UNET with residual conv blocks instead of regular conv blocks, the architecture looks like this

<img src='https://ichi.pro/assets/images/max/640/0*Q9iM4_vhdCYDlTsO.png'/>

It was proposed by Zhengxin Zhang, Qingjie Liu, Yunhong Wang in their paper [Road Extraction by Deep Residual U-Net](https://arxiv.org/abs/1711.10684).
<br/>
**What are residual conv blocks?** First let's see a simple visualization of a residual unit

<img src='https://miro.medium.com/max/1140/1*D0F3UitQ2l5Q0Ak-tjEdJg.png' />

In normal conv units tensors are directly propagated through conv layers. This way of propagation has one big issue - **vanishing gradients**. The vanishing gradients problem occurs when training very deep networks: during backpropagation gradients are propagated from deeper layers into shallower ones, and they can get smaller (close to 0) at each consecutive layer, which prevents the network from learning. Residual units first save the input tensor, then propagate it through the conv layers, and finally add the saved input tensor to the conv layers' output, thus learning an **identity mapping** and greatly helping with the vanishing gradients problem.

# Residual UNET with Attention
Attention was introduced to UNET in 2018's paper [Attention U-Net: Learning Where to Look for the Pancreas](https://arxiv.org/pdf/1804.03999) by Ozan Oktay et al.

**What is attention in the context of computer vision?** Attention is very often used in NLP problems as a way to make a model focus more on, for example, a part of a sentence. In computer vision attention is a mechanism that allows your network to look only at certain parts of an image. Such a part is called a **region of interest** (ROI). Looking at only parts of an image increases computational efficiency, while adding only a small number of parameters. Below is a diagram from the paper; as you can see, the attention gate is applied to the skip connection before it is concatenated to the decoder layer.

<img src='https://www.researchgate.net/publication/324472010/figure/fig1/AS:614439988494349@1523505317982/A-block-diagram-of-the-proposed-Attention-U-Net-segmentation-model-Input-image-is.png' />

**Why is attention needed for UNET?** Skip connections are the main characteristic of UNET; they help to preserve spatial structure in the upsampling layers. One issue with skip connections is that, since they come from shallower layers of the network, they carry less complex feature maps. This means that many low-level features that are not useful are concatenated to the decoder. Attention learns which of those features are worth taking a look at and which are just noise. The end result is a more computationally efficient network and slightly better performance. Let's break down the attention gate architecture.

<img src='https://miro.medium.com/max/1838/1*Q1aMxFm1L6KJeia5wCmC5A.png' />

The attention gate takes as input a skip connection and the output of the previous decoder layer. Mathematically, the attention gate performs the following operation

$$ q_{att}^{l} = \psi^{T}(\sigma_1(W^{T}_{x}x^{l}_{i} + W^{T}_{g}g_{i} + b_{g})) + b_{\psi} $$
$$ \alpha_{i}^{l} = \sigma_{2}(q_{att}^{l}(x_{i}^{l}, g_{i}; \Theta_{att})) $$

Where $\sigma_{2}(x_{i,c}) = \frac{1}{1+\exp(-x_{i, c})}$, $\Theta_{att}$ contains linear transformations $W_x, W_g$, which are computed using 1x1x1 convolutions for the input tensors. Let's see how we can implement this.

* Input g (previous decoder layer output) and x (skip connection)
* Convolve x with 1x1 filter and stride = 2, and g with 1x1 filter and stride = 1
* Add together x and g
* Apply ReLU activation function
* $\psi =$ 1x1x1 convolution 
* Apply sigmoid activation function
* Upsample sigmoid output to original input size (2x2)
* $att = multiply(upsample, x_{input})$ 
* 1x1 convolution with n_filters = n_input_x_filters and batch normalization

In [None]:
def attention(input_tensor, g, inter_shape):
    '''
    Attention gate applied to a skip connection before it is concatenated in the decoder

    input_tensor: skip connection x; it is halved spatially by a stride-2 convolution,
                  so it is expected to have twice the spatial size of g
    g: gating signal, output of the previous decoder layer
    inter_shape: number of filters of the intermediate 1x1 convolutions
    Returns the gated skip connection with the same number of filters as input_tensor
    '''
    n_input_filters = input_tensor.shape[-1]

    # Bring x and g to the same spatial size and number of filters
    x = tf.keras.layers.Conv2D(inter_shape, 1, 2, padding="same")(input_tensor)
    g = tf.keras.layers.Conv2D(inter_shape, 1, padding="same")(g)

    add = tf.keras.layers.add([x, g])
    # The non-linearity acts on x + g; feeding x alone would discard the gating signal
    relu = tf.keras.layers.Activation('relu')(add)

    # psi: a single-channel map squashed to (0, 1) gives one attention weight per position
    psi = tf.keras.layers.Conv2D(1, 1, padding="same")(relu)
    sigmoid = tf.keras.layers.Activation('sigmoid')(psi)

    # Back to the spatial size of the skip connection
    upsample = tf.keras.layers.UpSampling2D(size=(2, 2))(sigmoid)

    att = tf.keras.layers.multiply([upsample, input_tensor])

    output = tf.keras.layers.Conv2D(n_input_filters, 1, padding="same")(att)
    return tf.keras.layers.BatchNormalization()(output)

def conv_block(input_tensor, n_filters, dropout=0.5, batch_norm=True):
    '''
    Residual conv block: two 3x3 convolutions whose result is added to the output of the first one

    input_tensor: input feature map
    n_filters: number of filters of both convolutions
    dropout: dropout rate applied before the residual addition, falsy value disables it
    batch_norm: whether to apply batch normalization after each convolution
    Returns the output feature map with n_filters channels
    '''
    x_save = tf.keras.layers.Conv2D(n_filters, 3, activation="relu", padding="same")(input_tensor)

    # Start from the saved tensor so x is defined even when batch_norm is False
    x = x_save
    if batch_norm:
        x = tf.keras.layers.BatchNormalization()(x)

    x = tf.keras.layers.Conv2D(n_filters, 3, activation="relu", padding="same")(x)
    if batch_norm:
        x = tf.keras.layers.BatchNormalization()(x)

    if dropout:
        x = tf.keras.layers.Dropout(dropout)(x)

    # Residual connection
    x = tf.keras.layers.add([x, x_save])
    x = tf.keras.layers.Activation("relu")(x)

    return x

def downsample(x, n_filters, dropout=0.5, batch_norm=True):
    '''
    Encoder step: conv block followed by 2x2 max pooling

    Returns (pooled, features): the pooled tensor for the next encoder step and the
    un-pooled conv block output that serves as skip connection
    '''
    features = conv_block(x, n_filters, dropout=dropout, batch_norm=batch_norm)
    pooled = tf.keras.layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2))(features)
    return pooled, features

def upsample(x, n_filters, skip_conn, dropout=0.5, batch_norm=True):
    '''
    Decoder step: attention-gated skip connection, 2x upsampling, concatenation and conv block

    x: output of the previous (lower resolution) layer
    n_filters: number of filters of this decoder step
    skip_conn: output of the matching encoder step
    dropout: dropout rate used in the conv block and after it, falsy value disables it
    batch_norm: whether to use batch normalization in the conv block and after it
    Returns the upsampled feature map with n_filters channels
    '''
    # The gate uses x before it is upsampled
    att = attention(skip_conn, x, n_filters)
    x = tf.keras.layers.Conv2DTranspose(n_filters, (2, 2), strides=2, padding="same", activation="relu")(x)
    x = tf.keras.layers.Concatenate()([x, att])
    # Forward the settings so the arguments of this function also control the conv block
    # (unchanged for the default arguments)
    x = conv_block(x, n_filters, dropout=dropout, batch_norm=batch_norm)

    if dropout:
        x = tf.keras.layers.Dropout(dropout)(x)

    if batch_norm:
        x = tf.keras.layers.BatchNormalization()(x)

    return x
    
def create_model(n_filters):
    '''
    Build the residual attention UNET

    n_filters: number of filters of the first encoder step; it doubles with every
               encoder step and halves with every decoder step
    Returns tf.keras.Model mapping an image to a sigmoid mask with config.N_CLASSES channels
    '''
    inputs = tf.keras.layers.Input(shape=(*Config.IMG_SHAPE, 3))

    # Encoder: keep every skip connection for the decoder
    x = inputs
    skip_conns = []
    for multiplier in (1, 2, 4, 8):
        x, skip_conn = downsample(x, n_filters * multiplier)
        skip_conns.append(skip_conn)

    # Bottleneck
    x = conv_block(x, n_filters * 16)

    # Decoder: consume the skip connections from the deepest to the shallowest
    for multiplier in (8, 4, 2, 1):
        x = upsample(x, n_filters * multiplier, skip_conns.pop())

    outputs = tf.keras.layers.Conv2D(config.N_CLASSES, 3, activation="sigmoid", padding="same")(x)

    return tf.keras.Model(inputs=inputs, outputs=outputs)

In [None]:
# Build the network and print its layers and parameter counts.
model = create_model(config.N_FILTERS)
model.summary()

Since our segmentation task is a per-pixel binary classification problem, one might think that the binary crossentropy loss function is a perfect fit; however, this is not the case. Binary crossentropy makes training segmentation models difficult because it considers every pixel on its own and doesn't take the whole image into account. How can we do better?

<h2>Dice Loss</h2>
The Dice coefficient is a statistic from the 1940s developed as a measure of similarity between two samples. It was introduced to the field of computer vision in 2016 for 3D segmentation by Milletari et al.

$$D = \frac{2\sum\limits_{i = 1}^{n} p_{i}g_{i}}{\sum\limits_{i = 1}^{n} p_{i}^{2} + \sum\limits_{i = 1}^{n} g_{i}^{2}}$$

$p_{i}$ and $g_{i}$ are the values of corresponding pixels in the prediction and the ground truth. The numerator is (twice) the intersection of the two sets and the denominator is the sum of the areas of these two sets. Dice coefficient values range from 0 to 1, where 1 means that the images are practically the same and 0 means that there is no similarity at all. Since optimizers in machine learning try to minimize the loss function, dice loss is defined as
$$\ell = 1 - D$$

<h2>Intersection Over Union</h2>
IoU is this competition's evaluation metric, it is defined as
$$IoU(A, B) = \frac{\mid A \cap B \mid}{\mid A \cup B \mid}$$
Where $A$ and $B$ are both sets. Similarly to the Dice coefficient, it takes values between 0 and 1, where 1 means that the two sets are identical and 0 means that the sets have nothing in common. IoU isn't used as a loss function mainly because it is not differentiable.

In [None]:
def dice_loss(y_true, y_pred, smooth=1.0):
    '''
    Dice loss: 1 - Dice coefficient, computed over all pixels of the batch at once

    y_true: ground truth masks, cast to float32 here
    y_pred: predicted masks
    smooth: constant added to numerator and denominator
    '''
    truth = K.flatten(tf.cast(y_true, tf.float32))
    pred = K.flatten(y_pred)

    overlap = K.sum(truth * pred)
    numerator = 2. * overlap + smooth
    denominator = K.sum(K.square(truth)) + K.sum(K.square(pred)) + smooth

    return 1 - numerator / denominator

def iou_coef(y_true, y_pred, smooth=1):
    '''
    Intersection over union, computed per sample and averaged over the batch

    y_true: ground truth masks (batch, height, width, channels)
    y_pred: predicted masks of the same shape; used as they are, without thresholding
    smooth: constant added to numerator and denominator
    '''
    # Same cast as in dice_loss: the masks of the datasets are int32 and cannot be
    # multiplied with the float predictions directly
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)

    intersection = K.sum(K.abs(y_true * y_pred), axis=[1,2,3])
    union = K.sum(y_true,[1,2,3]) + K.sum(y_pred,[1,2,3]) - intersection
    iou = K.mean((intersection + smooth) / (union + smooth), axis=0)
    return iou

In [None]:
# Adam with the configured learning rate; train on dice loss and report IoU.
optimizer = tf.keras.optimizers.Adam(learning_rate=config.LR)

metrics = [iou_coef]
model.compile(optimizer=optimizer, loss=dice_loss, metrics=metrics)

# Training
**Callbacks** are a way to introduce additional logic to the training loop. For example, Tensorflow allows you to create a callback that saves the model's weights at the end of each epoch, and you can tweak this callback to save only the best weights (the weights for which the model's loss is minimal). Tensorflow docs - https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/Callback

The callbacks I am going to use:
- <code>ModelCheckpoint</code>
- <code>ReduceLROnPlateau</code>
- <code>EarlyStopping</code>

In [None]:
# Keep only the weights of the epoch with the lowest validation loss.
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=config.WEIGHTS_PATH,
    monitor="val_loss",
    mode="min",
    save_best_only=True,
    save_weights_only=True,
    verbose=1,
)

# Multiply the learning rate by `factor` when the validation loss has not improved
# by at least `min_delta` for `patience` epochs.
rlr_callback = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    min_delta=1e-2,
    patience=5,
    factor=0.01,
)

# Stop training when the validation loss has not improved by at least `min_delta`
# for `patience` epochs.
es_callback = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    mode='min',
    min_delta=1e-2,
    patience=15,
    verbose=1,
)

In [None]:
# Train with validation after every epoch; the callbacks checkpoint the weights,
# lower the learning rate and stop early based on val_loss.
history = model.fit(
    train_ds, 
    epochs=config.EPOCHS, 
    validation_data=val_ds,
    callbacks=[cp_callback, rlr_callback, es_callback]
)

In [None]:
history_dict = history.history

fig, ax = plt.subplots(1, 2, figsize=(15, 5), tight_layout=True)

# Left panel: training vs. validation loss per epoch.
ax[0].plot(history_dict['loss'], label="Training loss", linewidth=3)
ax[0].plot(history_dict['val_loss'], label="Validation loss", linewidth=3)
ax[0].set(xlabel="Epoch", ylabel="Loss", title="Loss")
ax[0].legend()

# Right panel: training vs. validation IoU per epoch.
ax[1].plot(history_dict['iou_coef'], label="Training IOU", linewidth=3)
ax[1].plot(history_dict['val_iou_coef'], label="Validation IOU", linewidth=3)
ax[1].set(xlabel="Epoch", ylabel="IOU", title="IOU")
ax[1].legend()

plt.show()

# Testing

In [None]:
# Get the best model: restore the weights from the checkpoint file, which holds the
# epoch with the lowest validation loss (ModelCheckpoint with save_best_only=True).

model.load_weights(config.WEIGHTS_PATH)

In [None]:
# os.listdir returns file names in arbitrary order; sorting makes the order of the
# predictions reproducible.
test_ids = sorted(os.listdir(config.TEST_DIR))

def load_test_ds():
    '''
    This function creates a generator for test dataset, it yields images only
    '''
    for image_id in test_ids:
        # Entries of test_ids are file names, including the extension
        image = tf.io.read_file(os.path.join(config.TEST_DIR, image_id))
        image = tf.image.decode_image(image, channels=3, dtype=tf.float32)
        image = tf.image.resize(image, config.IMG_SHAPE)
        # Same scaling as in the train/validation loaders, so the model sees the same
        # input range it was trained on
        image /= 255.0
        yield image

# output_signature replaces the deprecated output_types argument and gives the dataset
# a static shape.
test_ds = (
    tf.data.Dataset.from_generator(
        load_test_ds,
        output_signature=tf.TensorSpec(shape=(*config.IMG_SHAPE, 3), dtype=tf.float32),
    )
    .batch(3)
)

In [None]:
# Run the model over all test batches.
preds = model.predict(test_ds)

In [None]:
# Binarize the sigmoid outputs at 0.5.
preds = (preds > 0.5).astype(np.int32)

In [None]:
# First test batch (3 images) next to the corresponding predicted masks.
test_iter = next(iter(test_ds))

fig, ax = plt.subplots(3, 2, figsize=(20, 20))

for i in range(3):
    # Left column: image (rescaled for display), right column: predicted mask.
    ax[i][0].imshow(test_iter[i] * 255)
    ax[i][0].set_axis_off()

    ax[i][1].imshow(preds[i])
    ax[i][1].set_axis_off()

plt.show()