## Imports

In [None]:
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
from matplotlib import pyplot as plt
import cv2
from sklearn.preprocessing import Binarizer
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.utils import plot_model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.layers import Input, Conv2D, Conv2DTranspose, MaxPool2D, SpatialDropout2D, Concatenate, LeakyReLU

## Constants and Configuration

In [None]:
# Seed shared by NumPy and TensorFlow for reproducible runs.
RANDOM_STATE = 7
# Size (height, width) of the raw training images.
INPUT_IMG_SHAPE = (520, 704)
# Network input size (height, width). The model pools six times, so both
# sides must be divisible by 2 ** 6 = 64 (512 = 8 * 64, 704 = 11 * 64).
TARGET_IMG_SHAPE = (512, 704)
# Root used by `transform_image` to boost contrast.
POWER = 2
# Fraction of samples passed to `validation_split` in `model.fit`.
VAL_SIZE = 0.1
LEARNING_RATE = 1e-4
BATCH_SIZE = 4
EPOCHS = 100
# Resolution of saved figures and the model diagram.
PLOTS_DPI = 150
# Epochs without val_loss improvement before early stopping.
PATIENCE = 6

np.random.seed(RANDOM_STATE)
# NumPy's seed does not reach TensorFlow; seed it too so weight
# initialisation and dropout are reproducible.
tf.random.set_seed(RANDOM_STATE)

## EDA

### Loading training data

In [None]:
# Load the training table: one row per annotation, keyed by image `id`.
train_df = pd.read_csv(filepath_or_buffer = '../input/sartorius-cell-instance-segmentation/train.csv')
train_df.head(n = 5)

In [None]:
(len(train_df), len(train_df.columns))  # (rows, columns)

### Null check

In [None]:
train_df.isna().sum().sum()  # total missing cells across all columns

### Number of unique images

In [None]:
train_df.id.nunique(dropna = True)  # distinct image ids

### Unique cell types

In [None]:
train_df.loc[:, "cell_type"].unique()  # distinct cell types, in order of appearance

### Image sizes
All images referenced in `train_df` have the same size: 520 × 704 pixels (height × width).

In [None]:
train_df.loc[:, ["height", "width"]].describe()  # summary statistics of image dimensions

### Number of annotations
The number of annotations per image varies widely, from a minimum of 4 to a maximum of 790.

In [None]:
# Number of annotations per image id, ascending.
annot_counts = (
    train_df
    .groupby('id')[['annotation']]
    .count()
    .sort_values(by = 'annotation')
)
annot_counts

In [None]:
annot_counts.describe(percentiles = [.25, .5, .75])  # quartiles of annotations per image

In [None]:
# Histogram of how many annotations each image has.
plt.style.use('ggplot')
plt.figure(figsize = (10, 6))
plt.hist(annot_counts, bins = 50, alpha = 0.8)
plt.gca().set(
    xlabel = "Number of annotations",
    ylabel = "Count",
    title = "Number of annotations per image",
)
plt.show()

## Mask generation and pixel distribution

5 sample images have been chosen for visualization:
- Random image of cell type `shsy5y`
- Random image of cell type `astro`
- Random image of cell type `cort`
- Image with least number of annotations
- Image with most number of annotations

In [None]:
def rle_decode(mask_rle, shape, color = 1):
    '''
    Decode a run-length encoded mask into a dense float32 array.

    mask_rle: whitespace-separated "start length" pairs; starts are 1-based
              offsets into the image flattened in row-major (C) order.
    shape: (height, width) or (height, width, channels) of the array to return.
    color: value written into every masked pixel. With a channel axis, a
           sequence of `channels` values (e.g. an RGB triple) colours each
           channel separately.
    Returns a numpy float32 array of `shape`; unmasked pixels are 0.
    Raises ValueError if `mask_rle` has an odd number of tokens.
    '''
    tokens = mask_rle.split()
    if len(tokens) % 2:
        raise ValueError(f"RLE string must contain start/length pairs, got {len(tokens)} tokens")

    runs = np.array([int(token) for token in tokens], dtype = np.int64)
    starts = runs[0::2] - 1  # RLE starts are 1-based
    ends = starts + runs[1::2]

    # Flatten the spatial axes only; any trailing channel axis is kept so a
    # per-channel `color` broadcasts across it.
    img = np.zeros((shape[0] * shape[1], *shape[2:]), dtype = np.float32)
    for start, end in zip(starts, ends):
        img[start : end] = color

    return img.reshape(shape)

def get_grayscale_mask(image_id, annots):
    '''
    Merge every annotation of one image into a single (*INPUT_IMG_SHAPE, 1) mask.

    annots: iterable of RLE strings understood by `rle_decode`.
    image_id: not used by this function.
    Returns a float array whose pixels are 1 where any annotation covers them
    and 0 elsewhere (overlaps are clipped back to 1).
    '''
    mask_shape = (*INPUT_IMG_SHAPE, 1)
    combined = np.zeros(mask_shape)
    for rle in annots:
        combined = combined + rle_decode(rle, shape = mask_shape)
    return np.clip(combined, 0, 1)

def get_rgb_mask(image_id, annots):
    '''
    Paint each annotation of one image with a random RGB colour.

    annots: iterable of RLE strings understood by `rle_decode`.
    image_id: not used by this function.
    Draws three values from NumPy's global RNG per annotation. Colours of
    overlapping annotations are summed, and the result is clipped to [0, 1].
    Returns a float array of shape (*INPUT_IMG_SHAPE, 3).
    '''
    mask_shape = (*INPUT_IMG_SHAPE, 3)
    canvas = np.zeros(mask_shape)
    for rle in annots:
        random_colour = np.random.rand(3)
        canvas = canvas + rle_decode(rle, shape = mask_shape, color = random_colour)
    return np.clip(canvas, 0, 1)

def get_image(image_id):
    '''
    Load one training image in grayscale, shaped (*INPUT_IMG_SHAPE, 1).

    image_id: file stem of the PNG in the competition's `train` directory.
    Raises FileNotFoundError if OpenCV cannot read the file.
    '''
    path = f"../input/sartorius-cell-instance-segmentation/train/{image_id}.png"
    image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    # cv2.imread signals failure by returning None rather than raising;
    # fail here with the path instead of an AttributeError on `.reshape`.
    if image is None:
        raise FileNotFoundError(f"Could not read image: {path}")
    return image.reshape(*INPUT_IMG_SHAPE, 1)

def plot_images(image_ids):
    '''
    Plot one row per image id with four panels:
    the raw image, the image multiplied by its grayscale mask, a randomly
    coloured instance mask, and the raw pixel histogram.

    image_ids: sequence of ids present in `train_df`.
    Returns (grayscale_masks, rgb_masks, images), each a list aligned with
    `image_ids`.
    '''
    n_rows = len(image_ids)
    grayscale_masks, rgb_masks, images, celltypes = [], [], [], []

    for img_id in image_ids:
        rows = train_df[train_df["id"] == img_id]
        annots = rows["annotation"].tolist()
        celltypes.append(rows["cell_type"].tolist()[0])
        grayscale_masks.append(get_grayscale_mask(img_id, annots))
        rgb_masks.append(get_rgb_mask(img_id, annots))
        images.append(get_image(img_id))

    plt.figure(figsize = (20 , 4 * n_rows))

    for row in range(n_rows):
        base = row * 4
        img = images[row]

        plt.subplot(n_rows, 4, base + 1)
        plt.imshow(img, cmap = 'gray')
        plt.title(f'{image_ids[row]} - {celltypes[row]}', fontsize = 16)
        plt.axis("off")

        plt.subplot(n_rows, 4, base + 2)
        plt.imshow(img * grayscale_masks[row], cmap = 'gray')
        plt.title('Input image with mask', fontsize = 16)
        plt.axis("off")

        plt.subplot(n_rows, 4, base + 3)
        plt.imshow(rgb_masks[row])
        plt.title('RGB mask', fontsize = 16)
        plt.axis("off")

        plt.subplot(n_rows, 4, base + 4)
        plt.hist(img.flatten(), bins = 255, range = (0, 255))
        plt.title('Pixel distribution', fontsize = 16)

    plt.suptitle("Sample images, masks and their pixel distributions", fontsize = 24)
    plt.tight_layout(rect = [0, 0, 0.90, 1])
    plt.show()
    return grayscale_masks, rgb_masks, images

# Hand-picked ids, presumably in the order described above: a random shsy5y,
# astro and cort image, then the least- and most-annotated images.
sample_ids = [
    '0030fd0e6378',
    '0140b3c8f445',
    '01ae5a43a2ab',
    'e92c56871769',
    'c4121689002f',
]
_, _, images = plot_images(sample_ids)

## Image transformation
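The images have very low contrast because their pixel values are tightly clustered around mid-grey. The transformation below pushes each pixel's value away from mid-grey (127.5) by taking a root of its distance from it, which spreads the values out and increases contrast.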

### Transforming input images

In [None]:
def transform_image(img_data, power = 2):
    '''
    Increase contrast by pushing pixel values away from mid-grey.

    Each pixel is centred on 127.5 (d = x - 127.5) and mapped to
    sign(d) * |d| ** (1 / power), then rescaled so that 0 -> 0.0,
    127.5 -> 0.5 and 255 -> 1.0. A larger `power` moves near-mid-grey
    pixels further from 0.5.

    img_data: array of pixel values on a 0-255 scale, any shape.
    power: root applied to the distance from mid-grey; must be positive.
    Returns a float64 array of the same shape, in [0, 1] for inputs in [0, 255].
    '''
    # Use float64 rather than an integer cast so that non-integer inputs are
    # not truncated and the sign always matches the value being transformed.
    centred = np.asarray(img_data, dtype = np.float64) - 127.5
    scale = 2 * np.power(127.5, 1 / power)
    return np.sign(centred) * np.power(np.abs(centred), 1 / power) / scale + 0.5

def plot_transformed_images(images, transformed_images, image_ids = None):
    '''
    Plot each original image beside its contrast-transformed version, with
    the pixel histogram of each (both on a 0-1 scale).

    images: original images with pixel values on a 0-255 scale.
    transformed_images: one `transform_image` output per entry of `images`.
    image_ids: ids used in the row titles. Defaults to the global
               `sample_ids`, which is only correct when `images` were loaded
               from those ids.
    Raises ValueError if there is not one transformed image and at least one
    id per image.
    '''
    if image_ids is None:
        image_ids = sample_ids

    n = len(images)
    if len(transformed_images) != n or len(image_ids) < n:
        raise ValueError(
            f'Got {n} images, {len(transformed_images)} transformed images and '
            f'{len(image_ids)} ids; need one transformed image and one id per image.'
        )

    plt.figure(figsize = (20 , 4 * n))

    for i in range(n):
        plt.subplot(n, 4, (i * 4) + 1)
        plt.imshow(images[i], cmap = 'gray')
        plt.title(f'{image_ids[i]} - Original', fontsize = 16)
        plt.axis("off")

        plt.subplot(n, 4, (i * 4) + 2)
        # Divide by 255 so both histograms share the 0-1 range.
        plt.hist(images[i].flatten() / 255, bins = 255, range = (0, 1))
        plt.title('Original pixel distribution', fontsize = 16)

        plt.subplot(n, 4, (i * 4) + 3)
        plt.imshow(transformed_images[i], cmap = 'gray')
        plt.title('Transformed image', fontsize = 16)
        plt.axis("off")

        plt.subplot(n, 4, (i * 4) + 4)
        plt.hist(transformed_images[i].flatten(), bins = 255, range = (0, 1))
        plt.title('Pixel distribution after transformation', fontsize = 16)

    plt.suptitle("Image transformation", fontsize = 24)
    plt.tight_layout(rect = [0, 0, 0.90, 1])
    plt.show()
    
# Compare the sample images before and after the default (power = 2) transform.
plot_transformed_images(images, list(map(transform_image, images)))

### Effect of transformation for different values of power

In [None]:
# One row per sample image: the original followed by the transform at
# powers 2, 3 and 4.
n = len(images)
plt.figure(figsize = (20 , 4.5 * n))

for i in range(n):
    first_panel = i * 4 + 1

    plt.subplot(n, 4, first_panel)
    plt.imshow(images[i], cmap = 'gray')
    plt.title(f'{sample_ids[i]} - Original', fontsize = 16)
    plt.axis("off")

    for offset, p in enumerate((2, 3, 4), start = 1):
        plt.subplot(n, 4, first_panel + offset)
        plt.imshow(transform_image(images[i], p), cmap = 'gray')
        plt.title(f'Power = {p}', fontsize = 16)
        plt.axis("off")

plt.suptitle("Effect of power on image transformation", fontsize = 24)
plt.tight_layout()
plt.show()

## Data preparation

### Loading and transforming the input images and their corresponding grayscale masks

In [None]:
# Collect each image's RLE strings once (row order is kept within each group)
# instead of filtering the whole DataFrame for every image id.
annots_by_id = train_df.groupby("id")["annotation"].agg(list)

image_ids = train_df["id"].unique()
# Shuffle first: model.fit(validation_split=...) holds out the *last*
# VAL_SIZE fraction of the arrays without shuffling them.
np.random.shuffle(image_ids)

# cv2.resize takes the destination size as (width, height).
resize_dims = (TARGET_IMG_SHAPE[1], TARGET_IMG_SHAPE[0])

# Preallocate float32 arrays. Keras layers compute in float32 by default, so
# float64 copies of the whole dataset would only double peak memory.
X = np.empty((len(image_ids), *TARGET_IMG_SHAPE, 1), dtype = np.float32)
y = np.empty_like(X)

for k, img_id in enumerate(tqdm(image_ids, unit = " images", desc = "Loading transformed images and their masks in grayscale")):
    image = transform_image(get_image(img_id), POWER)
    X[k] = cv2.resize(image, resize_dims).reshape(*TARGET_IMG_SHAPE, 1)

    mask = get_grayscale_mask(img_id, annots_by_id[img_id])
    # Nearest-neighbour keeps the mask strictly 0/1. Bilinear resizing followed
    # by a "> 0" threshold would turn every partially covered boundary pixel
    # into foreground and inflate every cell.
    y[k] = cv2.resize(mask, resize_dims, interpolation = cv2.INTER_NEAREST).reshape(*TARGET_IMG_SHAPE, 1)

# Guarantee a strictly binary target, in place (no extra float copy).
y[...] = y > 0.5

X.shape, y.shape

### Sample input images and output masks

In [None]:
# Six randomly drawn (input, mask) pairs laid out two per grid row.
plt.figure(figsize = (20 , 13))

picks = np.random.randint(len(image_ids), size = 6)
for slot, idx in enumerate(picks):
    left_panel = slot * 2 + 1

    plt.subplot(3, 4, left_panel)
    plt.imshow(X[idx], cmap = 'gray')
    plt.title(f'Input image - {image_ids[idx]}', fontsize = 16)
    plt.axis("off")

    plt.subplot(3, 4, left_panel + 1)
    plt.imshow(y[idx], cmap = 'gray')
    plt.title(f'Output mask - {image_ids[idx]}', fontsize = 16)
    plt.axis("off")

plt.suptitle("Sample inputs and outputs", fontsize = 24)
plt.tight_layout()
plt.show()

## Model

### Model Building

In [None]:
def unet_model(img_shape = None):
    '''
    Build the U-Net-style encoder/decoder used for binary cell segmentation.

    img_shape: (height, width) of the single-channel input; defaults to
               TARGET_IMG_SHAPE. Both sides must be divisible by 64, because
               the encoder halves the resolution six times and the decoder
               doubles it six times, concatenating with encoder tensors that
               must match spatially.
    Returns an uncompiled keras Model whose input is (height, width, 1) and
    whose output has the same shape, with a per-pixel sigmoid probability.
    Raises ValueError if a side of `img_shape` is not divisible by 64.
    '''
    if img_shape is None:
        img_shape = TARGET_IMG_SHAPE
    if any(side % 64 for side in img_shape):
        raise ValueError(f'Both image sides must be divisible by 64, got {tuple(img_shape)}')

    input_layer = Input(shape = (*img_shape, 1), name = 'Input_Layer')

    # Encoder: six conv + 2x2 max-pool stages (H -> H/64). The spatially
    # dropped-out tensors after pools 1, 3 and 5 also serve as skip connections.
    conv_1 = Conv2D(16, 5, padding = 'same', activation = LeakyReLU(), name = 'Conv_1')(input_layer)
    pool_1 = MaxPool2D(name = 'Max_Pool_1')(conv_1)
    spd_1 = SpatialDropout2D(0.1, name = 'SPD_1')(pool_1)  # H/2

    conv_2 = Conv2D(32, 4, padding = 'same', activation = LeakyReLU(), name = 'Conv_2')(spd_1)
    pool_2 = MaxPool2D(name = 'Max_Pool_2')(conv_2)
    conv_3 = Conv2D(64, 4, padding = 'same', activation = LeakyReLU(), name = 'Conv_3')(pool_2)
    pool_3 = MaxPool2D(name = 'Max_Pool_3')(conv_3)
    spd_2 = SpatialDropout2D(0.1, name = 'SPD_2')(pool_3)  # H/8

    conv_4 = Conv2D(128, 3, padding = 'same', activation = LeakyReLU(), name = 'Conv_4')(spd_2)
    pool_4 = MaxPool2D(name = 'Max_Pool_4')(conv_4)
    conv_5 = Conv2D(256, 3, padding = 'same', activation = LeakyReLU(), name = 'Conv_5')(pool_4)
    pool_5 = MaxPool2D(name = 'Max_Pool_5')(conv_5)
    spd_3 = SpatialDropout2D(0.1, name = 'SPD_3')(pool_5)  # H/32

    conv_6 = Conv2D(512, 2, padding = 'same', activation = LeakyReLU(), name = 'Conv_6')(spd_3)
    pool_6 = MaxPool2D(name = 'Max_Pool_6')(conv_6)  # H/64

    # Decoder: six stride-2 transposed convs back to H.
    # kernel_size must be >= strides. With a 1x1 kernel and stride 2, three of
    # every four output pixels receive no input at all, only the bias.
    conv_t_1 = Conv2DTranspose(256, 2, padding = 'same', strides = 2, activation = LeakyReLU(), name = 'Conv_T_1')(pool_6)
    concat_1 = Concatenate(name = 'Concat_1')([conv_t_1, spd_3])
    spd_4 = SpatialDropout2D(0.1, name = 'SPD_4')(concat_1)

    conv_t_2 = Conv2DTranspose(128, 3, padding = 'same', strides = 2, activation = LeakyReLU(), name = 'Conv_T_2')(spd_4)
    conv_t_3 = Conv2DTranspose(64, 3, padding = 'same', strides = 2, activation = LeakyReLU(), name = 'Conv_T_3')(conv_t_2)
    concat_2 = Concatenate(name = 'Concat_2')([conv_t_3, spd_2])
    spd_5 = SpatialDropout2D(0.1, name = 'SPD_5')(concat_2)

    conv_t_4 = Conv2DTranspose(32, 4, padding = 'same', strides = 2, activation = LeakyReLU(), name = 'Conv_T_4')(spd_5)
    conv_t_5 = Conv2DTranspose(16, 4, padding = 'same', strides = 2, activation = LeakyReLU(), name = 'Conv_T_5')(conv_t_4)
    concat_3 = Concatenate(name = 'Concat_3')([conv_t_5, spd_1])
    spd_6 = SpatialDropout2D(0.1, name = 'SPD_6')(concat_3)

    conv_t_6 = Conv2DTranspose(8, 5, padding = 'same', strides = 2, activation = LeakyReLU(), name = 'Conv_T_6')(spd_6)

    # Stride-1 head: one sigmoid channel = per-pixel foreground probability.
    output_layer = Conv2DTranspose(1, 5, padding = 'same', activation = 'sigmoid', name = 'Output_Layer')(conv_t_6)

    return Model(inputs = input_layer, outputs = output_layer, name = 'Sartorius')

# Binary cross-entropy matches the single sigmoid output channel.
model = unet_model()
optimizer = Adam(learning_rate = LEARNING_RATE)
model.compile(
    optimizer = optimizer,
    loss = 'binary_crossentropy',
    metrics = ['accuracy'],
)
model.summary()

In [None]:
# Save (and display) the layer graph with tensor shapes.
plot_model(
    model,
    dpi = PLOTS_DPI,
    show_shapes = True,
    to_file = 'model.jpg',
)

### Training

In [None]:
%%time

# Stop once val_loss has not improved for PATIENCE epochs and roll back to
# the best weights seen.
early_stop = EarlyStopping(restore_best_weights = True, patience = PATIENCE, monitor = 'val_loss')

# Keras holds out the last VAL_SIZE fraction of (X, y) for validation.
history = model.fit(X, y, batch_size = BATCH_SIZE, epochs = EPOCHS,
                    validation_split = VAL_SIZE, callbacks = [early_stop])

### Metrics

In [None]:
# Training vs. validation curves for loss and accuracy, saved to disk.
hist = history.history
loss, val_loss = hist['loss'], hist['val_loss']
accuracy, val_accuracy = hist['accuracy'], hist['val_accuracy']

epochs_range = history.epoch

plt.figure(figsize = (18, 6))

panels = (
    (1, loss, val_loss, 'Loss', 'upper right'),
    (2, accuracy, val_accuracy, 'Accuracy', 'lower right'),
)
for position, train_curve, val_curve, name, legend_loc in panels:
    plt.subplot(1, 2, position)
    plt.plot(epochs_range, train_curve, label = f'Training {name}')
    plt.plot(epochs_range, val_curve, label = f'Validation {name}')
    plt.legend(loc = legend_loc, fontsize = 14)
    if name == 'Loss':
        plt.ylim(0, None)  # anchor the loss axis at zero
    plt.title(name, fontsize = 20)

plt.suptitle("Evaluation Metrics", fontsize = 24)
plt.savefig('loss_and_accuracy.jpg', dpi = PLOTS_DPI, bbox_inches = 'tight')
plt.show()

## Predictions

### Getting sample predictions

In [None]:
num_preds = 5

# model.fit(validation_split=VAL_SIZE) validated on the last VAL_SIZE fraction
# of X, so draw samples (without repeats) from that slice only. Training images
# would make both the threshold search and the plots look optimistic.
val_start = int(len(X) * (1 - VAL_SIZE))
val_indices = np.arange(val_start, len(X))
sample_pred_ids = np.random.choice(val_indices, size = min(num_preds, len(val_indices)), replace = False)
num_preds = len(sample_pred_ids)  # may shrink if the validation slice is tiny

preds = model.predict(X[sample_pred_ids])

preds.shape

### Finding optimal threshold for creating a binary mask

In [None]:
# 11 evenly spaced candidates 0.25, 0.30, ..., 0.75. linspace avoids the
# accumulated float error of arange with a fractional step.
threshold_ranges = np.linspace(0.25, 0.75, 11)

truth = y[sample_pred_ids] > 0.5
accuracies = []
ious = []

for threshold in threshold_ranges:
    # Same rule as Binarizer(threshold=threshold): strictly greater -> 1.
    pred_mask = preds > threshold
    accuracies.append((pred_mask == truth).mean())
    # Foreground IoU: unlike pixel accuracy, it is not dominated by background.
    union = np.logical_or(pred_mask, truth).sum()
    ious.append(np.logical_and(pred_mask, truth).sum() / union if union else 1.0)

# Still ranked by accuracy so the next cell keeps picking the same criterion;
# the IoU column is there to compare against.
threshold_results_df = pd.DataFrame({
    'threshold': threshold_ranges,
    'accuracy': accuracies,
    'iou': ious,
}).round(3).sort_values('accuracy', ascending = False)

threshold_results_df

### Visualizing Predictions

In [None]:
# Binarise the predictions with the best threshold found above and show,
# per sample: input, ground-truth mask, raw prediction, binarised prediction.
best_threshold = threshold_results_df['threshold'].iloc[0]
pred_mask = Binarizer(threshold = best_threshold).transform(preds.reshape(-1, 1)).reshape(preds.shape)

plt.figure(figsize = (20 , 20))
for row in range(num_preds):
    sample_idx = sample_pred_ids[row]
    panels = (
        (X[sample_idx], f'Input image - {image_ids[sample_idx]}'),
        (y[sample_idx], 'Expected output mask'),
        (preds[row], 'Predicted mask'),
        (pred_mask[row], 'Binarized mask'),
    )
    for col, (picture, caption) in enumerate(panels, start = 1):
        plt.subplot(num_preds, 4, 4 * row + col)
        plt.imshow(picture, cmap = 'gray')
        plt.axis('off')
        plt.title(caption, fontsize = 16)

plt.suptitle("Sample inputs and outputs", fontsize = 24)
plt.tight_layout(rect = [0, 0, 0.90, 1])
plt.show()