This notebook runs inference for the <a href="https://www.kaggle.com/c/sartorius-cell-instance-segmentation">Sartorius Cell Instance Segmentation competition</a>.

- Training notebook: https://www.kaggle.com/yoshikuwano/sartorius-tf-train-tfrecords
- Image : png images -> ndarray -> tfrecord (split into tiles)

Ref. https://www.kaggle.com/wrrosa/hubmap-tf-with-tpu-efficientunet-512x512-subm

<a class='anchor' id='TOC'></a>
# Table of Contents

1. [Packages](#1)
1. [Accelerator](#2)
1. [Parameters](#3)
1. [Input Data](#4)
1. [Generate TFRecord (test data)](#5)
1. [Model](#6)
1. [Predict](#7)

<a class='anchor' id='1'></a>
# 1. Packages
[Back to Table of Contents](#TOC)

In [None]:
import os, glob, gc, re, yaml, json, pathlib
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import cv2
import tensorflow as tf
import rasterio
from rasterio.windows import Window
from datetime import datetime
from pprint import pprint
from tqdm.notebook import tqdm

import warnings
warnings.filterwarnings('ignore')

<a class='anchor' id='2'></a>
# 2. Accelerator
[Back to Table of Contents](#TOC)

In [None]:
def set_strategy():
    """Choose a ``tf.distribute`` strategy for the available hardware.

    A TPU is tried first. If the TPU resolver raises ``ValueError`` the
    logical GPUs are inspected instead: several GPUs get a
    ``MirroredStrategy``, while a single GPU or a CPU-only host uses the
    default strategy.

    Returns:
        The selected ``tf.distribute`` strategy object.
    """
    try:
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver.connect() # TPU detection
    except ValueError:
        resolver = None

    if resolver:
        strategy = tf.distribute.TPUStrategy(resolver)
        print('Running on TPU ', resolver.cluster_spec().as_dict()['worker'])
    else:
        gpu_names = [gpu.name for gpu in tf.config.experimental.list_logical_devices("GPU")]
        if len(gpu_names) > 1:
            strategy = tf.distribute.MirroredStrategy(gpu_names)
            print('Running on multiple GPUs ', gpu_names)
        else:
            # The default strategy covers both the single-GPU and the CPU case.
            strategy = tf.distribute.get_strategy()
            if len(gpu_names) == 1:
                print('Running on single GPU ', gpu_names[0])
            else:
                print('Running on CPU')

    print("Number of accelerators: ", strategy.num_replicas_in_sync)

    return strategy

strategy = set_strategy()

<a class='anchor' id='3'></a>
# 3. Parameters
[Back to Table of Contents](#TOC)

In [None]:
# Paths
# INPUT_PATH: competition data (sample submission and test images are read from here).
# TRAIN_PATH: output of the training run; the trained models live under 'model/'.
INPUT_PATH = '../input/sartorius-cell-instance-segmentation/'
TRAIN_PATH = '../input/sartorius-train-unet/'

# Trained models (one saved model per fold, matched by the 'model-fold*' pattern)
MODEL_PATH = TRAIN_PATH + 'model/'
print(f'Model path: {MODEL_PATH}')
print(f'Number of models(folds): {len(glob.glob(MODEL_PATH + "model-fold*"))}\n')

# Load model metrics saved next to the models and report the mean IoU values
with open(MODEL_PATH + 'metrics.json') as json_file:
    M = json.load(json_file)
print('Model metrics:')
print(f'  train IoU: {round(M["mean_train_iou_coef"], 4)}')
print(f'  valid IoU: {round(M["mean_valid_iou_coef"], 4)}\n')

# Load model parameters (tile sizes and overlap are read from P further down)
# NOTE(review): yaml.FullLoader can construct some Python-specific objects; that is
# acceptable for a self-produced file, but confirm the file contains no Python-only
# tags before swapping it for yaml.safe_load.
print('Model paramaters:')
with open(MODEL_PATH + 'params.yaml') as file:
    P = yaml.load(file, Loader=yaml.FullLoader)
    pprint(P)

# Parameters for prediction
THRESHOLD = 0.5 # A predicted pixel value > THRESHOLD is judged as MASK
CHECKSUM = True # Whether to print the sum of mask pixels for every predicted image
MIN_SIZE = 5 # A component counts as an object only if its bounding-box extent (max of height/width) is > MIN_SIZE
MAX_SIZE = 500 # A component counts as an object only if its bounding-box extent (max of height/width) is < MAX_SIZE
AUTO = tf.data.experimental.AUTOTUNE # Prefetch buffer size passed to tf.data

<a class='anchor' id='4'></a>
# 4. Input Data
[Back to Table of Contents](#TOC)

## Sample submission

In [None]:
# Load the sample submission; its `id` column is used below to locate the test images.
sample_subm = pd.read_csv(INPUT_PATH + 'sample_submission.csv')
display(sample_subm)

## Image data

In [None]:
# Build one image path per submission id; the string concatenation is applied
# element-wise to the `id` column, so `test_imgs` is a pandas Series of paths.
test_imgs = INPUT_PATH + 'test/' + sample_subm['id'] + '.png'
print(f'test images: {len(test_imgs)} files')
display(test_imgs)

In [None]:
# Preview the first test images side by side. At most 3 are shown; the count is
# capped by the number of available images so a smaller test set does not fail.
n_preview = min(3, len(test_imgs))
if n_preview > 0:
    gs = gridspec.GridSpec(1, n_preview)
    plt.figure(figsize = (25, 20))
    for i in range(n_preview):
        img_path = test_imgs[i]
        img = cv2.imread(img_path)
        # File name without directory and extension is the image id.
        img_id = img_path.split('/')[-1].split('.')[0]
        ax = plt.subplot(gs[i])
        ax.set_title(f'id: {img_id}')
        ax.imshow(img)
        ax.set_aspect('equal')
        ax.axis('on')

    plt.show()

<a class='anchor' id='5'></a>
# 5. Generate TFRecord
[Back to Table of Contents](#TOC)

In [None]:
# Cast datatypes into 1 of the type lists (integer,float and bytes)
def _bytes_feature(value):
    """Wrap a single string / bytes value in a ``tf.train.Feature`` (bytes_list)."""
    eager_tensor_type = type(tf.constant(0))
    if isinstance(value, eager_tensor_type):
        # BytesList cannot unpack an EagerTensor, so extract the raw value first.
        value = value.numpy()
    wrapped = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=wrapped)

def _float_feature(value):
    """Wrap a single float / double value in a ``tf.train.Feature`` (float_list)."""
    wrapped = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=wrapped)

def _int64_feature(value):
    """Wrap a single bool / enum / int / uint value in a ``tf.train.Feature`` (int64_list)."""
    wrapped = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=wrapped)

# Make grid of tiles 
def make_grid(shape, window=256, min_overlap=32):
    """Compute the slice bounds of overlapping tiles covering a 2-D area.

    Args:
        shape: pair ``(x, y)`` giving the extent along the first and second axis.
        window: tile edge length.
        min_overlap: minimum overlap between neighbouring tiles; the last
            tile along each axis is shifted back to end at the border, so it
            may overlap more.

    Returns:
        ``(N, 4)`` int64 array, one row ``(x1, x2, y1, y2)`` per tile, where
        ``(x1, y1)`` is the start corner and ``(x2, y2)`` the end corner.
        Rows are ordered with the second axis varying fastest.
    """
    def _axis_bounds(length):
        # Evenly spaced start offsets; the final one is pinned so the tile ends at `length`.
        count = length // (window - min_overlap) + 1
        begin = np.linspace(0, length, num=count, endpoint=False, dtype=np.int64)
        begin[-1] = length - window
        finish = (begin + window).clip(0, length)
        return begin, finish

    x, y = shape
    x1, x2 = _axis_bounds(x)
    y1, y2 = _axis_bounds(y)
    nx, ny = len(x1), len(y1)

    # Cartesian product of the per-axis bounds: x repeats per block, y cycles within it.
    columns = [np.repeat(x1, ny), np.repeat(x2, ny), np.tile(y1, nx), np.tile(y2, nx)]
    return np.stack(columns, axis=1)

# Serialization
def serialize_example_test(image, x1, y1):
    """Serialize one test tile and its start corner into a ``tf.train.Example`` string.

    Args:
        image: raw tile bytes.
        x1: x start offset of the tile.
        y1: y start offset of the tile.

    Returns:
        The serialized example as bytes.
    """
    features = tf.train.Features(feature={
        'image': _bytes_feature(image), # tile image
        'x1': _int64_feature(x1),       # x left-up point of tile
        'y1': _int64_feature(y1),       # y left-up point of tile
    })
    return tf.train.Example(features=features).SerializeToString()

# Count images in a tfrecord file
def count_data_items(filenames):
    """Sum the item counts encoded in TFRecord file names.

    Each name is expected to contain ``-<count>.`` (e.g. ``abc-12.tfrec``);
    the first such match in the name is used.
    """
    count_pattern = re.compile(r"-([0-9]*)\.")
    counts = []
    for name in filenames:
        counts.append(int(count_pattern.search(name).group(1)))
    return np.sum(counts)

In [None]:
# Config of tiles
BASE = P['BASE_TILE'][0] # Base tile size (Not original image size)
RESIZE = P['RESIZED_TILE'][0] # Re-sized tile size
reduce = BASE//RESIZE # Reduce base image size
MIN_OVERLAP = P['MIN_OVERLAP'] # Overlap width of each tile (Note: Edge image may overlap more than MIN_OVERLAP)

# Path to save tfrecords (created idempotently so the cell can be re-run)
TFREC_PATH = f'./tfrec-{len(sample_subm)}-data_{RESIZE}x{RESIZE}-tile/'
P['DATASET'] = TFREC_PATH
os.makedirs(TFREC_PATH + 'test', exist_ok=True)

# Identity affine transform handed to rasterio.open
# (presumably to silence georeferencing warnings for plain PNGs - TODO confirm).
identity = rasterio.Affine(1, 0, 0,
                           0, 1, 0)

# One TFRecord file per test image; every record holds one resized tile plus
# the tile's start corner so predictions can be stitched back together later.
path = pathlib.Path(INPUT_PATH + 'test')
png_files = list(path.glob('*.png'))
for filename in tqdm(png_files, total=len(png_files)):

    tfrec_filepath = TFREC_PATH + f'/{filename.stem}.tfrec'
    cnt = 0
    # Context managers guarantee that the image file and the record writer are
    # closed even if reading or encoding a tile raises.
    with rasterio.open(filename.as_posix(), transform=identity) as dataset, \
         tf.io.TFRecordWriter(tfrec_filepath) as writer:
        slices = make_grid(dataset.shape, window=BASE, min_overlap=MIN_OVERLAP)
        for (x1, x2, y1, y2) in slices:
            image = dataset.read(window=Window.from_slices((x1,x2), (y1,y2))) # Shape: (bands, BASE, BASE)
            image = np.moveaxis(image, 0, -1) # Bands moved to the last axis
            image = cv2.resize(image, (RESIZE, RESIZE), interpolation = cv2.INTER_AREA) # Resized to RESIZE x RESIZE
            image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
            # Create tf.Example
            example = serialize_example_test(image.tobytes(), x1, y1)
            writer.write(example)
            cnt += 1

    # The tile count is encoded in the file name; count_data_items() parses it back.
    os.rename(tfrec_filepath, TFREC_PATH + f'/{filename.stem}-{cnt}.tfrec')
    print(f'Generate TFRecord: {filename.stem + "-" + str(cnt) + ".tfrec"}')
    gc.collect();

print(f'Successfully all completed and Saved in {TFREC_PATH}\n')

test_tfrecs = glob.glob(TFREC_PATH + '*.tfrec')
print(f'Number of TFRecord files: {len(test_tfrecs)}')
print(f'Number of total tiles: {count_data_items(test_tfrecs)}')

In [None]:
from skimage.segmentation import mark_boundaries

DIM = RESIZE
def _parse_image(example_proto):
    """Decode one serialized tile example into ``(image, x1, y1)``.

    The image bytes are decoded as uint8 and reshaped to ``(DIM, DIM, 3)``;
    ``x1`` and ``y1`` are the int64 start offsets stored with the tile.
    """
    schema = {
        'image': tf.io.FixedLenFeature([], tf.string),
        'x1': tf.io.FixedLenFeature([], tf.int64),
        'y1': tf.io.FixedLenFeature([], tf.int64)
    }
    parsed = tf.io.parse_single_example(example_proto, schema)
    pixels = tf.io.decode_raw(parsed['image'], out_type=np.dtype('uint8'))
    image = tf.reshape(pixels, (DIM, DIM, 3))
    return image, parsed['x1'], parsed['y1']

def load_dataset(filenames, ordered=True):
    """Build a ``tf.data`` pipeline of parsed tiles from TFRecord file(s).

    Args:
        filenames: TFRecord path(s) accepted by ``tf.data.TFRecordDataset``.
        ordered: when False, deterministic ordering is switched off.

    Returns:
        Dataset yielding ``(image, x1, y1)`` as produced by ``_parse_image``.
    """
    options = tf.data.Options()
    if not ordered:
        options.experimental_deterministic = False
    records = tf.data.TFRecordDataset(filenames).with_options(options)
    return records.map(_parse_image)

def get_dataset(filename, n):
    """Return the parsed tiles of ``filename`` in batches of ``n`` with prefetching."""
    return load_dataset(filename).batch(n).prefetch(AUTO)

# Confirm up to 3 tfrecords by plotting all tiles of each one.
n_rows = 6
for idx in range(min(3, len(test_tfrecs))):
    num_tiles = count_data_items([test_tfrecs[idx]])
    # A single batch of size num_tiles holds every tile of the file.
    for imgs, x1, y1 in get_dataset(test_tfrecs[idx], num_tiles).take(1):
        break
    print(f'{idx+1}. Sample image: {test_tfrecs[idx].split("/")[-1]}')
    print(f'image shape: {imgs.shape}')

    # Round the column count up so the grid always has at least num_tiles cells
    # (the floor division used before ran out of cells whenever num_tiles was
    # not a multiple of the row count).
    n_cols = max(1, int(-(-num_tiles // n_rows)))
    gs = gridspec.GridSpec(n_rows, n_cols)
    plt.figure(figsize = (8, 6))
    for i in range(num_tiles):
        ax1 = plt.subplot(gs[i])
        ax1.set_xticklabels([])
        ax1.set_yticklabels([])
        ax1.set_aspect('equal')
        ax1.set_axis_off()
        ax1.imshow(imgs[i])
    plt.subplots_adjust(wspace=0.02, hspace=0.02)
    plt.show()

<a class='anchor' id='6'></a>
# 6. Model
[Back to Table of Contents](#TOC)

In [None]:
from tensorflow.keras import backend as K

# Metric
def iou_coef(y_true, y_pred, smooth=1):
    """Smoothed intersection-over-union, averaged over the first (batch) axis.

    Sums are taken over axes 1, 2 and 3 of each sample; ``smooth`` is added to
    numerator and denominator.
    """
    reduce_axes = [1, 2, 3]
    overlap = K.sum(K.abs(y_true * y_pred), axis=reduce_axes)
    total = K.sum(y_true, reduce_axes) + K.sum(y_pred, reduce_axes)
    union = total - overlap
    return K.mean((overlap + smooth) / (union + smooth), axis=0)

# Loss
def bce_dice_loss(y_true, y_pred):
    """Equal-weight blend of binary cross-entropy and Dice loss.

    ``y_true`` is cast to float32 before both terms are computed. The Dice
    term is evaluated on the flattened tensors with a smoothing constant of 1.
    """
    target = tf.cast(y_true, tf.float32)

    # Binary cross-entropy term
    bce_term = tf.keras.losses.binary_crossentropy(target, y_pred)

    # Dice term on flattened tensors
    smooth = 1.
    target_flat = K.flatten(target)
    pred_flat = K.flatten(y_pred)
    overlap = K.sum(target_flat * pred_flat)
    dice_score = (2. * overlap + smooth) / (K.sum(target_flat) + K.sum(pred_flat) + smooth)
    dice_term = 1. - dice_score

    return bce_term * 0.5 + dice_term*0.5


# Load every fold model (sorted by path) inside the distribution strategy scope.
# The custom metric and loss must be supplied so Keras can deserialize the models.
model_paths = sorted(glob.glob(MODEL_PATH + 'model-fold*' ))
custom_objects = {'iou_coef': iou_coef, 'bce_dice_loss': bce_dice_loss}
models = []
for model_path in tqdm(model_paths, total=len(model_paths)):
    print(f'Model loading from {model_path}')
    with strategy.scope():
        models.append(tf.keras.models.load_model(model_path, custom_objects=custom_objects))

<a class='anchor' id='7'></a>
# 7. Predict
[Back to Table of Contents](#TOC)

In [None]:
# RLE Encoder: ndarry -> string
def rle_encode(img):
    '''
    Run-length encode a binary mask.

    img: numpy array, 1 - mask, 0 - background
    returns: space-separated "start length" pairs as a string, with 1-based
             starts counted over the flattened (row-major) array
    '''
    # Zero padding on both sides makes every run begin and end with a transition.
    padded = np.concatenate([[0], img.flatten(), [0]])
    changes = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    # Even slots hold run starts; turn the odd slots (run ends) into lengths.
    changes[1::2] = changes[1::2] - changes[::2]
    return ' '.join(map(str, changes))

# RLE Decoder: string -> ndarry
def rle_decode(annotation, shape):
    '''
    Decode a run-length string into a binary mask.

    annotation: string of space-separated "start length" pairs (1-based starts)
    shape: (height, width)
    return: uint8 ndarray of the given shape, mask: 1, background: 0
    '''
    tokens = np.asarray(annotation.split(), dtype=int)
    # Run starts are numbered from one, array indices from zero.
    starts = tokens[0::2] - 1
    lengths = tokens[1::2]
    ends = starts + lengths

    flat = np.zeros(shape[0]*shape[1], dtype=np.uint8)
    for begin, stop in zip(starts, ends):
        flat[begin:stop] = 1

    return flat.reshape(shape)

In [None]:
path = pathlib.Path(INPUT_PATH + 'test/')
test_files = list(path.glob('*.png'))

# For every test image: average the fold predictions per tile, threshold them,
# stitch the tiles back to the full image and store the RLE of the result.
predictions = {}
for i, filename in enumerate(tqdm(test_files, total=len(test_files))):
    # Only the image shape is needed here, so the file is closed right away.
    with rasterio.open(filename.as_posix(), transform=identity) as dataset:
        image_shape = dataset.shape
    test_tfrec = glob.glob(TFREC_PATH + filename.stem + '*.tfrec')[0]
    num_tiles = count_data_items([test_tfrec])
    print(f'{i+1}. Predicting {test_tfrec.split("/")[-1]}')

    # Predict per tile and accumulate over folds
    with strategy.scope():
        pred_tile = None
        for model in models:
            fold_pred = model.predict(get_dataset(filename=test_tfrec, n=num_tiles))
            pred_tile = fold_pred if pred_tile is None else pred_tile + fold_pred
    # Average over folds and Resize
    pred_tile = pred_tile/len(models)
    pred_tile = tf.image.resize(pred_tile, (BASE,BASE))
    # Mask probability is converted to boolean by THRESHOLD.
    # Only the trailing channel axis is squeezed: a bare squeeze() would also
    # drop the tile axis when an image produces exactly one tile.
    # assumes the model output has a single channel - TODO confirm
    pred_tile = (pred_tile > THRESHOLD).numpy().squeeze(axis=-1) # shape: (#tiles, BASE, BASE)

    # Put together tiles to be the shape of original image.
    # The dataset is read in the same order as for predict(), so the running
    # index pairs every tile prediction with its stored start corner.
    pred_together = np.zeros(image_shape, dtype=np.uint8)
    idx = 0
    for _, X1, Y1 in get_dataset(filename=test_tfrec, n=num_tiles):
        for x1, y1 in zip(X1.numpy(), Y1.numpy()):
            pred_together[x1:(x1+BASE), y1:(y1+BASE)] += pred_tile[idx]
            idx += 1

    # A pixel is mask if any overlapping tile predicted mask.
    pred_together = (pred_together>0.5).astype(np.uint8)
    if CHECKSUM:
        print(f'   Checksum: {str(np.sum(pred_together))}')

    predictions[i] = {'id': filename.stem, 'predicted': rle_encode(pred_together)}
    

In [None]:
# Split the mask of semantic segmentation into instance segments
def instance_segment(mask, min_size=MIN_SIZE, max_size=MAX_SIZE):
    """Split a binary mask into one mask per connected component.

    Args:
        mask: 2-D mask array passed to ``cv2.connectedComponentsWithStats``.
        min_size: a component is kept only if the larger side of its bounding
            box is strictly greater than this.
        max_size: a component is kept only if the larger side of its bounding
            box is strictly smaller than this.

    Returns:
        List of float32 arrays with the same shape as ``mask``; each is 1 on
        one kept component and 0 elsewhere.
    """
    num_component, component, stats, _ = cv2.connectedComponentsWithStats(mask)
    instance_segments = []
    # Label 0 is the background and must never be reported as an instance;
    # previously it was only rejected when its bounding box exceeded max_size.
    for label in range(1, num_component):
        extent = max(stats[label][cv2.CC_STAT_HEIGHT], stats[label][cv2.CC_STAT_WIDTH])
        if min_size < extent < max_size: # judge object or not
            # Output follows the input shape instead of a hard-coded image size.
            segment = np.zeros(mask.shape, np.float32)
            segment[component == label] = 1
            instance_segments.append(segment)
    return instance_segments


# Turn every per-image semantic mask into one submission row per instance.
submission = {}
count_all_instances = 0
for pred in predictions.values():
    id_ = pred['id']
    pred_segmentation = rle_decode(pred['predicted'], (520, 704))

    # Split into instance segments and encode each of them
    pred_rles = [rle_encode(segment) for segment in instance_segment(pred_segmentation)]
    for pred_rle in pred_rles:
        submission[count_all_instances] = {'id': id_, 'predicted': pred_rle}
        count_all_instances += 1

    print(f'ID: {id_},   Number of instances: {len(pred_rles)}')

In [None]:
# Write one row per predicted instance, then read the file back to show what was saved.
submission_table = pd.DataFrame.from_dict(submission, orient='index')
submission_table.to_csv('submission.csv', index=False)
display(pd.read_csv('submission.csv'))

### Confirm image and predicted mask

In [None]:
# Build mask image from all annotations with same id
def build_masks(annotations, shape, distinguish_objects=False):
    '''
    Overlay all run-length annotations of one image into a single mask.

    annotations: List[string], one run-length string per object
    shape: (height, width)
    distinguish_objects: if True, the n-th annotation is painted with value n
                         (1, 2, 3, ...); otherwise every object is painted 1
    return: ndarray, mask: integer 1,2,3,..., background:0
            (later annotations overwrite earlier ones where they overlap)
    '''
    masks = np.zeros(shape, dtype=np.uint8)
    for label, annotation in enumerate(annotations, start=1):
        fill_value = label if distinguish_objects else 1
        decoded = rle_decode(annotation, shape)
        masks = np.where(decoded == 0, masks, fill_value)

    return masks

from PIL import Image, ImageEnhance
def plot_image_and_mask(img, mask, title=None):
    """Show an image, a contrast-enhanced negative, the mask and a blended overlay.

    Args:
        img: image array as returned by ``cv2.imread``.
        mask: 2-D mask array (0 is background).
        title: optional figure title.
    """
    fig, (ax_raw, ax_contrast, ax_pred, ax_blend) = plt.subplots(1, 4, figsize=(20,4))

    ax_raw.set_title('Original image')
    ax_raw.imshow(img)

    # Invert the image, then boost its contrast.
    inverted = img.max() - img
    img_hc = np.asarray(ImageEnhance.Contrast(Image.fromarray(inverted)).enhance(24))
    ax_contrast.set_title('High contrasted img')
    ax_contrast.imshow(img_hc)

    ax_pred.set_title('Prediction')
    ax_pred.imshow(mask, cmap='inferno')

    # Expand the mask to three channels and colour it yellow for the overlay.
    overlay = np.tile(np.expand_dims(mask, 2), 3) # shape: (height, width) -> (height, width, 3)
    overlay = np.clip(overlay, 0, 1)*255 # mask: (255,255,255), background: (0,0,0)
    overlay[:,:,2] = 0 # mask: (255,255,0): yellow
    overlay = overlay.astype(np.uint8) # make sure the dtype is np.uint8 before blending
    blended = cv2.addWeighted(img_hc, 0.80, overlay, 0.20, gamma=0.0)
    ax_blend.set_title('Image + Prediction')
    ax_blend.imshow(blended)

    fig.suptitle(title, fontsize=14)

In [None]:
# Annotations are grouped by id: one row per image, `predicted` is the list of
# that image's run-length strings. groupby sorts by id, so rows come in id order.
df_subm = (
    pd.read_csv('submission.csv')
    .groupby('id', sort=True)['predicted']
    .agg(list)
    .reset_index()
)
# Show at most 3 images; fewer when fewer ids have predictions.
for i in range(min(3, len(df_subm))):
    img_id = df_subm.loc[i, 'id']
    img = cv2.imread(INPUT_PATH + 'test/' + img_id + '.png')
    mask = build_masks(annotations = df_subm.loc[i, 'predicted'],
                       shape=(520, 704),
                       distinguish_objects=True)
    plot_image_and_mask(img, mask, title=f'id: {img_id}')