In [None]:
submission_mode = False  # omit preliminary data analysis, just run the model for submission

# ensure reproducibility: seed every random number generator used in this notebook
import random
import numpy as np
import numpy.random as rd
import tensorflow as tf

seed = 184618484
for seed_fn in (random.seed, np.random.seed, rd.seed, tf.random.set_seed):
    seed_fn(seed)

In [None]:
# copy saved models -- turn off for submissions!
#! cp -r /kaggle/input/2021-sartorius-neuronsegmentation/*.model /kaggle/input/2021-sartorius-neuronsegmentation/*.pkl /kaggle/working
#! ls

# training / simulation settings -- adjust for submissions
epochs_to_train = 1024          # upper bound; the EarlyStopping callbacks below may stop sooner
early_stopping_patience = 16    # epochs without val_loss improvement before stopping
simulation_epochs = 8
simulation_dt = 4.0

In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

from scipy import signal
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import pickle as pkl
from tqdm import tqdm
from copy import copy

from io import StringIO

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib import cm
from matplotlib.patches import Rectangle

from PIL import Image

import os
import glob
print(glob.glob('/kaggle/input/sartorius-cell-instance-segmentation/*'))

import tensorflow_addons as tfa
import tensorflow_probability as tfp
from tensorflow import keras
from tensorflow.keras import layers, models, callbacks
# the line below causes the code to fail if running with GPU
#tf.config.experimental.enable_mlir_graph_optimization()

val_fraction = 0.15

# decide which models to train during the current run of the notebook
task_list = dict(
    img_classifier=True,
    astro_recognitionNN=True,
    astro_segmentation=True,
    cort_recognitionNN=True,
    cort_segmentation=True,
    shsy5y_recognitionNN=True,
    shsy5y_segmentation=True,
)

# handle model dependencies: a requested model switches on every prerequisite
# that is currently switched off and either has no saved '<name>.model' file
# or must be retrained anyway (force_retrain)
force_retrain = False
dependencies = dict(
    img_classifier=[],
    astro_recognitionNN=[],
    astro_segmentation=[],
    cort_recognitionNN=[],
    cort_segmentation=[],
    shsy5y_recognitionNN=[],
    shsy5y_segmentation=[],
)
for task in dependencies:
    if not task_list[task]:
        continue
    for dependency in dependencies[task]:
        if task_list[dependency]:
            continue
        if force_retrain or not os.path.exists(dependency + '.model'):
            task_list[dependency] = True

print(task_list)

First, let's plot a few images from the training set to see what we're dealing with.

In [None]:
img_filenames = glob.glob('/kaggle/input/sartorius-cell-instance-segmentation/train/*')

if not submission_mode:
    fig = plt.figure(figsize=(16,16))
    gs = gridspec.GridSpec(4,4)
    filenames_to_plot = np.random.choice(img_filenames, 16)
    print(len(img_filenames))
    # 8 images on a 4x4 grid: each image sits above its pixel-value histogram,
    # two images per grid column
    for i, fname in enumerate(filenames_to_plot[:8]):
        top_row, col = 2 * (i % 2), i // 2
        img = np.array(Image.open(fname))
        plt.subplot(gs[top_row, col])
        plt.imshow(img, cmap='inferno')
        print(img.shape)
        plt.subplot(gs[top_row + 1, col])
        plt.hist(img.flatten(), 64)
    plt.show()
    plt.close()

Observations:
 - there are 606 train images total,
 - all images have the same shape (520x704) with a **single** color channel,
 - the pixel values range from 0 to 255, with the vast majority of pixels clustered around the midpoint (128),
 - the most noticeable cell features appear dark on the background.

In [None]:
# Let's confirm these observations on the full dataset.
# The check is slow, so it stays switched off once it has passed.
check_full_dataset = False # already checked
ok = True
if check_full_dataset:
    for filename in tqdm(img_filenames):
        # load the file of the current iteration (the earlier version re-read
        # filenames_to_plot[i] every time and so never looked at the dataset)
        img = np.array(Image.open(filename))
        if img.shape != (520, 704):
            ok = False
            break
        img = img.flatten()
        if np.min(img) < 0 or np.max(img) > 255:
            ok = False
            break
        if np.abs(np.median(img) - 128) > 20.:
            ok = False
            break
print(ok)

It may be easier to deal with the images (even visually) if we rescale them to have a mean of 0 and the color to vary between -1 and 1.

In [None]:
def img_color_rescale (img):
    """Shift an image to zero mean and divide it by 128.

    For 8-bit input (values 0..255) this puts the pixel values roughly in [-1, 1].
    """
    centred = img - np.ravel(img).mean()
    return centred / 128.

In [None]:
if not submission_mode:
    fig = plt.figure(figsize=(16,16))
    gs = gridspec.GridSpec(4,4)
    filenames_to_plot = np.random.choice(img_filenames, 16)
    # same image-above-histogram layout as before, now on the rescaled colour scale
    for i, fname in enumerate(filenames_to_plot[:8]):
        top_row, col = 2 * (i % 2), i // 2
        img = img_color_rescale(np.array(Image.open(fname)))
        plt.subplot(gs[top_row, col])
        plt.imshow(img, cmap='seismic')
        plt.subplot(gs[top_row + 1, col])
        plt.hist(img.flatten(), 64)
    plt.show()
    plt.close()

    # one image at full size
    fig = plt.figure(figsize=(16,16))
    img = img_color_rescale(np.array(Image.open(filenames_to_plot[0])))
    plt.imshow(img, cmap='seismic')
    plt.show()
    plt.close()

Ok, this looks much better. Now let's have a look at the label segmentation...

In [None]:
# load the annotations and parse the two duration columns into timedeltas
train_csv = pd.read_csv('/kaggle/input/sartorius-cell-instance-segmentation/train.csv', parse_dates=['sample_date'])
train_csv = train_csv.assign(**{
    time_col: pd.to_timedelta(train_csv[time_col])
    for time_col in ['plate_time', 'elapsed_timedelta']
})

In [None]:
# peek at the first five annotation rows
if not submission_mode:
    print(train_csv.head(n=5))

In [None]:
# column dtypes, to confirm the date / timedelta parsing above
if not submission_mode:
    column_types = train_csv.dtypes
    print(column_types)

In [None]:
# the distinct cell types, plus how many annotated cells each one has
cell_types = train_csv['cell_type'].unique()
if not submission_mode:
    print(cell_types)
    train_csv['cell_type'].hist()

In [None]:
if not submission_mode:
    # plate_time in hours. .dt.total_seconds() works on every pandas version, whereas
    # astype('timedelta64[h]') raises on pandas >= 2.0 (only s/ms/us/ns are allowed).
    # The earlier version also divided the hours by 60, which gives no meaningful unit.
    plate_hours = train_csv.plate_time.dt.total_seconds() / 3600.
    print(sorted(plate_hours.unique()))
    fig = plt.figure()
    plt.hist(plate_hours)
    plt.title('timedelta64[h]')
    plt.show()
    plt.close()

In [None]:
if not submission_mode:
    # elapsed_timedelta in minutes (histogram) and hours (printed values);
    # .dt.total_seconds() replaces astype('timedelta64[m]'), which pandas >= 2.0 rejects
    elapsed_minutes = train_csv.elapsed_timedelta.dt.total_seconds() / 60.
    print(sorted(elapsed_minutes.unique() / 60.))
    fig = plt.figure()
    plt.hist(elapsed_minutes)
    plt.title('timedelta64[m]')
    plt.show()
    plt.close()

In [None]:
if not submission_mode:
    # difference between the two time columns, in seconds; computed with
    # .dt.total_seconds() because astype('timedelta64[s]') no longer yields floats on pandas >= 2.0
    time_diff_s = (train_csv.plate_time - train_csv.elapsed_timedelta).dt.total_seconds()
    print(sorted(time_diff_s.unique()))
    fig = plt.figure()
    plt.hist(time_diff_s)
    plt.title('plate time - timedelta [s]')
    plt.show()
    plt.close()

In [None]:
if not submission_mode:
    sample_dates = train_csv['sample_date']
    print(sorted(sample_dates.unique()))
    fig = plt.figure()
    plt.hist(sample_dates)
    plt.title("Sample date")
    plt.show()
    plt.close()

In [None]:
if not submission_mode:
    print('No of cells histogram')
    # number of annotation rows (non-null width) per image id
    cells_per_image = train_csv.groupby('id')['width'].count()
    cells_per_image.hist()

Observations:
 - there is a clear imbalance in the type of cells: shsy5y are overabundant by about a factor of 5 over astro and cort (which, in turn, have about equal populations),
 - the fields plate_time and elapsed_timedelta are identical,
 - Samples were taken in two rounds: about 2/3 of them in June 2019 and 1/3 between September and November 2020. Besides this bimodality, there are very few distinct dates in the dataset, so sample_date likely designates only the date the sample was added to the database (not the date it was photographed)...
 - There is a large number of images with a very small number of cells (<80), and a broader distribution centered around n ~ 400

With these in hand, let us now parse and plot the segmentation masks given in the training dataset. It would be memory-wasteful to store the masks as separate numpy arrays at all times, so let us just use a function to translate the run length encoded pixels from the csv into a numpy array **on the fly**. Fortunately, an efficient solution can be found on Stackoverflow (thanks, jdehesa and M. Innat!).

In [None]:
def rle_decode_tf(mask_rle, shape):
    '''
    Decode a run-length encoded mask into a dense TensorFlow tensor.

    Input:
        mask_rle [string]: space-separated "start length start length ..." pairs;
            starts are 1-based positions in the flattened mask (hence the -1 below).
        shape: target mask shape, e.g. (520, 704); the flat result is reshaped
            with tf.reshape (row-major order).
    Output [tf.Tensor of shape `shape`, dtype uint8]: Segmentation mask with 1 on
        every encoded pixel and 0 elsewhere.
    By: M. Innat, jdehesa
    Source: https://stackoverflow.com/questions/58693261/decoding-rle-run-length-encoding-mask-with-tensorflow-datasets
    '''
    shape = tf.convert_to_tensor(shape, tf.int64)
    size = tf.math.reduce_prod(shape)
    # Split string
    s = tf.strings.split(mask_rle)
    s = tf.strings.to_number(s, tf.int64)
    # Get starts and lengths (even entries are 1-based starts, odd entries run lengths)
    starts = s[::2] - 1
    lens = s[1::2]
    # Make ones to be scattered: one per pixel covered by any run
    total_ones = tf.reduce_sum(lens)
    ones = tf.ones([total_ones], tf.uint8)
    # Make scattering indices
    r = tf.range(total_ones)
    lens_cum = tf.math.cumsum(lens)
    # for the r-th emitted pixel, find which run it belongs to
    s = tf.searchsorted(lens_cum, r, 'right')
    # flat index = r + (run start - number of pixels emitted by all earlier runs)
    idx = r + tf.gather(starts - tf.pad(lens_cum[:-1], [(1, 0)]), s)
    # Scatter ones into flattened mask
    # NOTE(review): tf.scatter_nd sums duplicate indices, so overlapping runs would give values > 1
    mask_flat = tf.scatter_nd(tf.expand_dims(idx, 1), ones, [size])
    # Reshape into mask
    return tf.reshape(mask_flat, shape)

Let's have a look at a single segmentation..

In [None]:
if not submission_mode:
    example_file = filenames_to_plot[0]
    print(example_file)
    # image id = file name without directory and extension
    img_id = example_file.rsplit('/', 1)[-1].split('.', 1)[0]
    print(train_csv[train_csv.id == img_id])

So there are only 5 segments on this image!

In [None]:
celltype_colors = {'shsy5y':'k', 'astro':'green', 'cort':'cyan'}
def plot_image (img_filename, ax=None):
    """Plot a training image with its annotated cells overlaid.

    Draws the rescaled image, a semi-transparent union of all cell masks and,
    for every annotation row in the global ``train_csv``, a bounding box
    labelled with (and coloured by) the cell type.

    Args:
        img_filename: path to a training PNG; its id (file name without
            directory and extension) selects the rows of ``train_csv`` to draw.
        ax: matplotlib Axes to draw into. If None, a new 16x16 figure is
            created, shown and closed.
    """
    # draw into a fresh figure unless the caller supplied an Axes
    standalone = ax is None
    if standalone:
        fig = plt.figure(figsize=(16,16), clear=True)
        ax = plt.gca()

    # plot the image itself
    img = img_color_rescale(np.array(Image.open(img_filename)))
    ax.imshow(img, cmap='seismic', zorder=-3, rasterized=True)
    shape = img.shape
    del img

    # overplot segmentation: accumulate the union of all cell masks and draw
    # one labelled bounding box per annotation
    img_id = img_filename.split('/')[-1].split('.')[0]
    mask = np.zeros(shape, dtype=float)  # builtin float: the np.float alias was removed in NumPy 1.24
    for _, row in tqdm(train_csv[train_csv.id == img_id].iterrows()):
        cell_mask = rle_decode_tf(row.annotation, shape).numpy().astype(float)
        mask += cell_mask
        ypx, xpx = np.where(cell_mask == 1)
        if xpx.size == 0:
            # empty annotation: nothing to box (np.min/np.max would raise on it)
            continue
        xmin, xmax = xpx.min(), xpx.max()
        ymin, ymax = ypx.min(), ypx.max()
        color = celltype_colors[row.cell_type]
        box = Rectangle(xy=(xmin,ymin), width=(xmax-xmin), height=(ymax-ymin), fill=False, ec=color, zorder=-1, rasterized=True)
        ax.add_patch(box)
        # add annotations to mark the cell details
        ax.text(xmin, ymin, row.cell_type, color=color, va='bottom', weight='bold', zorder=-1, rasterized=True)

    # draw the mask: covered pixels -> 1, background -> NaN (rendered transparent)
    mask[mask > 0] = 1
    mask[mask == 0] = np.nan
    # work on a copy so set_bad does not alter the shared, registered colormap;
    # plt.get_cmap is used because matplotlib.cm.get_cmap was removed in Matplotlib 3.9
    cmap = copy(plt.get_cmap('YlGn_r'))
    cmap.set_bad(alpha=0.0)
    ax.imshow(mask, cmap=cmap, alpha=0.5, zorder=-2, rasterized=True)
    del mask

    # rasterize to conserve RAM
    ax.set_rasterization_zorder(0)

    if standalone:
        plt.show()
        fig.clear()
        plt.close(fig)

In [None]:
if not submission_mode:
    # full segmentation overlays for the first three sampled images
    for filename in filenames_to_plot[0:3]:
        plot_image(filename, ax=None)

 - cort-type cells are tiny, circular, and only occupy small areas of their corresponding images
 - astro-type cells are HUGE, occupy a significant fraction of the corresponding images (both in surface area and in linear dimensions), and are elongated / branching in shape
 - shsy5y are somewhat between these two types in both size and appearance
 - **the three types of cells are clearly found in different environments**
  - cort are surrounded by a spiderweb-like network of threads (axons / dendrites, neural connections?)
  - astro usually occur by themselves, with plain featureless background (given size of other cells present, we can be relatively confident that the scale is the same?)
  - again, shsy5y is the transitional type

Given these observations, it may be helpful to work on two separate models -- one for cort-like cells and one for astro-like cells (a simple background-classifying model should be able to tell us which one we should use) -- and then merge them into a combined model that could handle the transitional shsy5y type as well.

First, let us find out the percentages of types for each picture..

In [None]:
img_topics = pd.DataFrame(train_csv['id'].unique(), columns=['id']).copy()
cells_by_img = train_csv.groupby('id')['cell_type']
for cell_type in cell_types:
    # fraction of the annotated cells in each image that are of this type.
    # groupby returns results sorted by id, while img_topics follows the order of
    # first appearance, so the values are mapped by id instead of by row position.
    type_frac = cells_by_img.apply(lambda x: np.mean(x == cell_type))
    img_topics['perc_%s' % cell_type] = img_topics['id'].map(type_frac)
if not submission_mode:
    # a bare .head() inside an `if` is not rendered, so display it explicitly
    display(img_topics.head())

In [None]:
if not submission_mode:
    for cell_type in cell_types:
        frac_col = img_topics['perc_%s' % cell_type]
        print('perc_%s\n - unique values: ' % cell_type, frac_col.unique())
        # images made up entirely of this cell type
        print(' - no of images: ', (frac_col == 1).sum())

 - There are NO images with more than one type of cells (labeled) on them.
 - The number of images is not too large -- we **definitely** need to use augmentation to train a model to distinguish between cort and astro-like cells!
 
Distinguishing these images does, however, seem relatively simple (they look clearly different even to my non-expert eye), so we probably don't need to resort to any state-of-the-art pre-trained networks. Instead, let us use a simple CNN trained from scratch on the images.

In [None]:
# since our dataset is imbalanced, we need sample weights to balance the training:
# each image is weighted by the INVERSE of its type's share of the images, so every
# type contributes the same total weight (and the mean weight over all images is 1).
# Weighting by the share itself, as before, would amplify the imbalance.
perc_cols = [col for col in img_topics.columns if col.startswith('perc_')]  # idempotent on re-run
type_fractions = img_topics[perc_cols].sum()
type_fractions /= type_fractions.sum()
sample_weights = (1.0 / (len(perc_cols) * type_fractions)).to_dict()
print(sample_weights)
img_topics['sample_weight'] = (img_topics[perc_cols] > 0).astype(float).dot(pd.Series(sample_weights))
print(img_topics.head())

In [None]:
# reload the dataset, just to be safe
train_csv = pd.read_csv('/kaggle/input/sartorius-cell-instance-segmentation/train.csv', parse_dates=['sample_date'])
for time_col in ['plate_time', 'elapsed_timedelta']:
    train_csv[time_col] = pd.to_timedelta(train_csv[time_col])

# before we start, let's set apart the validation sample.
# replace=False: drawing with replacement (the default) can pick the same id twice,
# leaving fewer distinct validation images than val_fraction asks for
all_ids = train_csv.id.unique()
val_ids = rd.choice(all_ids, int(val_fraction * len(all_ids)), replace=False)
val_csv = train_csv[train_csv.id.isin(val_ids)]
train_csv = train_csv[~train_csv.id.isin(val_ids)]

print(len(train_csv), len(val_csv))

In [None]:
# split the per-image sample weights along the same train / validation ids
is_val_img = img_topics.id.isin(val_ids)
train_weights = img_topics.loc[~is_val_img, 'sample_weight'].to_numpy()
val_weights = img_topics.loc[is_val_img, 'sample_weight'].to_numpy()

In [None]:
# ... and finally let us prepare the labels: one int column per cell type
# (the perc_* fractions, which are 0 or 1 for every image)
in_val_set = img_topics.id.isin(val_ids)
label_frame = img_topics.drop(columns=['id', 'sample_weight'])
train_labels = label_frame[~in_val_set].to_numpy(dtype=np.int32)
val_labels = label_frame[in_val_set].to_numpy(dtype=np.int32)
print(train_labels.shape, val_labels.shape)

In [None]:
# preprocessing
def tf_img_load (img_id, directory='train'):
    """Load ``<directory>/<img_id>.png`` from the competition data as a grayscale float array."""
    data_dir = '/kaggle/input/sartorius-cell-instance-segmentation/%s/' % directory
    pil_img = tf.keras.utils.load_img(data_dir + img_id + '.png', color_mode='grayscale')
    return tf.keras.preprocessing.image.img_to_array(pil_img)
def tf_img_color_rescale (img):
    """TensorFlow counterpart of img_color_rescale: subtract the mean, then divide by 128."""
    return tf.math.divide(img - tf.math.reduce_mean(img), 128.)

In [None]:
def tf_img_augment (model):
    """Append random geometric augmentation layers to a Sequential ``model`` and return it.

    Adds, in order: translation (up to the full image size, wrapped), zoom,
    full-circle rotation and random flips. Keras applies these random layers
    only in training mode.
    """
    augmentation_layers = [
        layers.RandomTranslation(height_factor=(-1,1), width_factor=(-1,1), fill_mode='wrap'),
        layers.RandomZoom(height_factor=(-0.3, 0.1), fill_mode='wrap'),  # zoom-in 30% to zoom-out 10%
        layers.RandomRotation(factor=(-1,1), fill_mode='wrap'),  # full rotation
        layers.RandomFlip(),
    ]
    for aug_layer in augmentation_layers:
        model.add(aug_layer)
    return model

In [None]:
# in principle, we could do image loading as part of the model,
# but since there aren't that many images,
# it seems more sensible to load them into RAM all at once
def _load_rescaled_imgs (img_ids):
    """Load and rescale (tf_img_load + tf_img_color_rescale) every id, stacked along a new first axis."""
    return np.array([
        tf_img_color_rescale(tf_img_load(img_id)).numpy()
        for img_id in tqdm(img_ids)
    ])

if task_list['img_classifier']:
    train_imgs = _load_rescaled_imgs(train_csv.id.unique())
    val_imgs = _load_rescaled_imgs(val_csv.id.unique())
    print(train_imgs.shape, val_imgs.shape)

In [None]:
# build our model
model_name = 'img_classifier'
model = models.Sequential()
model.add(keras.Input(shape=(520,704,1)))
# no augmentation yet -- it is prepended in a later cell (tf_img_augment)
# now some vanilla CNN layers: three (5x5 conv + 2x2 max-pool) stages
for stage in range(3):
    model.add(layers.Conv2D(filters=32, kernel_size=(5,5), activation='relu'))
    model.add(layers.MaxPooling2D((2,2)))
# and finish up with a perceptron
model.add(layers.Flatten())
model.add(layers.Dense(units=128, activation='relu'))
model.add(layers.Dropout(0.5))
# output layer emits raw logits (no activation): the loss below uses
# from_logits=True, and a relu here would clip every negative logit to 0
model.add(layers.Dense(units=3))

model.compile(optimizer='adam',
              loss=tf.keras.losses.CategoricalCrossentropy(
                  from_logits=True
              ),
              metrics=['accuracy'])

print('input', model.input.shape)
for layer in model.layers:
    print(layer.name, layer.output_shape)

In [None]:
if task_list[model_name]:
    stop_early = tf.keras.callbacks.EarlyStopping(
        monitor='val_loss',
        patience=early_stopping_patience,
        restore_best_weights=True
    )
    # not named `callbacks`: that would shadow the imported tensorflow.keras.callbacks module
    fit_callbacks = [stop_early,]
    history = model.fit(
        train_imgs,
        train_labels,
        epochs=epochs_to_train,
        validation_data=(val_imgs, val_labels),
        sample_weight=train_weights,
        callbacks=fit_callbacks
    )

In [None]:
if task_list[model_name] and not submission_mode:
    # training vs validation loss per epoch
    for curve in ('loss', 'val_loss'):
        plt.plot(history.history[curve], label=curve)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend(loc='best')
    plt.show()
    plt.close()

Hmm... without data augmentation, the network clearly memorizes the training data (as expected, for a measly sample of 300 images). While the validation performance surprisingly isn't bad (at 98%!), it would be embarrassing to use such an overfit model. Let's do this again, with image augmentation turned on:

In [None]:
# build our model: put an augmentation stage in front of the classifier trained above.
# The trained layer objects are reused (not copied), so they keep their weights.
augmented = models.Sequential()
augmented.add(keras.Input(shape=(520,704,1)))
augmented = tf_img_augment(augmented)
for trained_layer in model.layers:
    augmented.add(trained_layer)
new_model = augmented
model = augmented

model.compile(optimizer='adam',
              loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

for layer in model.layers:
    print(layer.name, layer.output_shape)

In [None]:
if task_list[model_name]:
    if force_retrain or not os.path.exists(model_name+'.model'):
        stop_early = tf.keras.callbacks.EarlyStopping(
            monitor='val_loss',
            patience=early_stopping_patience,
            restore_best_weights=True
        )
        # not named `callbacks`: that would shadow the imported tensorflow.keras.callbacks module
        fit_callbacks = [stop_early,]
        history = model.fit(train_imgs, train_labels, epochs=epochs_to_train,
                           validation_data=(val_imgs, val_labels),
                           sample_weight=train_weights,
                           callbacks=fit_callbacks)
        model.save(model_name + '.model')
    else:
        # NOTE(review): `history` is not updated on this path, so the next cell
        # plots whatever history the previous fit left behind (if any)
        model = keras.models.load_model(
            model_name + '.model'
        )
    img_classifier = model

In [None]:
if task_list['img_classifier'] and not submission_mode:
    # training vs validation loss per epoch, on a log scale
    curves = history.history
    for curve in ('loss', 'val_loss'):
        plt.plot(curves[curve], label=curve)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend(loc='best')
    plt.gca().set_yscale('log')
    plt.show()
    plt.close()

Well, this is probably more reasonable. This proves that we can recognize environments of our different cell classes in the images.

Thus, let us proceed to build three separate segmentation models, each specified for one type of neurons. We will then connect them to the model we've just trained, to build our solution to the competition problem.

In [None]:
# clean up after the img_classifier training: drop the in-RAM image arrays
# and run the garbage collector so unreachable objects are reclaimed right away
if task_list['img_classifier']:
    import gc
    del train_imgs, val_imgs
    gc.collect()

---------------

**ASTRO-cell segmentation**

Let's start with a segmentation model to recognize the astro-type cells. First, let's create a dataset containing only images with these types.

In [None]:
if not submission_mode:
    # a bare .head() inside an `if` is not rendered by Jupyter, so display it explicitly
    display(img_topics.head())

In [None]:
# ids of the images whose annotated cells are all astro
astro_ids = img_topics.loc[img_topics['perc_astro'] == 1.0, 'id'].to_numpy()
print(len(astro_ids))

That's not a lot of images... Hopefully image augmentation will help us here...

Let's split into train and validation sample:

In [None]:
astro_csv = train_csv[train_csv.id.isin(astro_ids)]
# replace=False so no id is drawn twice (the default samples with replacement
# and would leave fewer distinct validation images than requested)
astro_val_ids = rd.choice(astro_ids, int(val_fraction * len(astro_ids)), replace=False)
astro_val_csv = astro_csv[astro_csv.id.isin(astro_val_ids)]
astro_train_csv = astro_csv[~astro_csv.id.isin(astro_val_ids)]
print(len(astro_train_csv), len(astro_val_csv))

At least the number of cells is pretty decent.

In [None]:
# in principle, we could do image loading as part of the model,
# but since there aren't that many images,
# it seems more sensible to load them into RAM all at once
shape = (520,704)
save_path = 'astro_train_data.pkl'

def _load_imgs_and_masks (ann_csv, shape):
    """Load rescaled images and binary cell masks for every image id in ``ann_csv``.

    Returns ``(imgs, masks)``: ``imgs`` stacks the tf_img_load/tf_img_color_rescale
    output per image; ``masks`` is a uint8 array of shape (N, *shape) that is 1
    wherever any annotation of that image covers a pixel.
    """
    imgs, masks = [], []
    for img_id in tqdm(ann_csv.id.unique()):
        imgs.append(tf_img_color_rescale(tf_img_load(img_id)).numpy())
        # could probably make this faster with a tensorflow map, but as-is, a dtype error
        # is returned (can't map a string-type tensor to int)...
        mask = np.zeros(shape)
        for ann in ann_csv[ann_csv.id == img_id].annotation:
            mask += rle_decode_tf(ann, shape=shape)
        mask = tf.where(mask > 0, 1,0)
        masks.append(mask.numpy().astype(np.uint8))
    return np.array(imgs), np.array(masks)

if task_list['astro_recognitionNN'] or task_list['astro_segmentation']:
    if force_retrain or not os.path.exists(save_path):
        astro_train_imgs, astro_train_masks = _load_imgs_and_masks(astro_train_csv, shape)
        astro_val_imgs, astro_val_masks = _load_imgs_and_masks(astro_val_csv, shape)
        with open(save_path, 'wb') as f:
            pkl.dump((astro_train_imgs, astro_train_masks, astro_val_imgs, astro_val_masks), f)
    else:
        # NOTE: pickle.load can execute arbitrary code -- only load a cache this notebook wrote itself
        with open(save_path, 'rb') as f:
            astro_train_imgs, astro_train_masks, astro_val_imgs, astro_val_masks = pkl.load(f)
    print(astro_train_imgs.shape, astro_val_imgs.shape)
    print(astro_train_masks.shape, astro_val_masks.shape)

In [None]:
if not submission_mode:
    # quick sanity check of the decoded masks: one validation example
    example_mask = astro_val_masks[2]
    plt.imshow(example_mask)
    plt.show()
    plt.close()

Let's have a look at a couple of the images, just to refresh our memory..

In [None]:
if not submission_mode:
    # segmentation overlays for the first three astro images
    astro_filenames = ['/kaggle/input/sartorius-cell-instance-segmentation/train/' + x + '.png' for x in astro_ids[:3]]
    for filename in astro_filenames:
        plot_image(filename)

 - The cells correspond to a large fraction of the image, so the segmentation network's footprint should cover at least ~25% of the image to capture (ideally) whole cells
 - Cell segmentation maps contain fine structures (like threads, spikes, etc.), so lowering the image resolution to generate a segmentation map might not be the best idea...
 
 Given these observations, let us start with a "brute-force" idea, with a lossless convolutional network at a fixed size (no maxpool) mapping our image to the segmentation map..

In [None]:
# due to the imbalanced numbers of on-cell and off-cell pixels, let's use a weighting function
def balanced_loss (y_true, y_pred):
    """Class-balanced squared error for binary segmentation masks.

    The squared error is averaged separately over on-cell pixels (y_true > 0) and
    off-cell pixels (y_true < 1), and the two means are added, so each class
    contributes equally however rare it is. Dividing both terms by the total pixel
    count, as the earlier version did, reduces exactly to plain MSE (no balancing).

    Args:
        y_true: ground-truth mask tensor (values in [0, 1]).
        y_pred: predicted mask tensor of the same shape.
    Returns:
        Scalar loss tensor for the batch.
    """
    y_true = tf.cast(y_true, y_pred.dtype)
    sq_err = tf.square(y_true - y_pred)
    on_cell = tf.cast(y_true > 0, y_pred.dtype)
    off_cell = tf.cast(y_true < 1, y_pred.dtype)
    # the max(count, 1) guard avoids 0/0 for batches that have no pixels of one class
    loss_missed_cells = tf.reduce_sum(sq_err * on_cell) / tf.maximum(tf.reduce_sum(on_cell), 1.)
    loss_false_cells = tf.reduce_sum(sq_err * off_cell) / tf.maximum(tf.reduce_sum(off_cell), 1.)
    return loss_missed_cells + loss_false_cells

In [None]:
# since we're working with segmentation now,
# we need to apply the same augmentation
# to both the input and the label masks.
# We do this by carrying the masks along with the images
# as an extra channel.
def tf_img_augment_nomodel (X, y):
    """Apply identical random geometric augmentations to images ``X`` and masks ``y``.

    Both are expected as 4-D batches (batch, height, width, 1) of a common dtype;
    they are stacked along the channel axis, transformed together (translation,
    zoom, rotation, flip) and split apart again, so each mask stays aligned with its image.
    Returns the augmented ``(X, y)``.
    """
    stacked = tf.concat([X, y], axis=3)
    # layers are built lazily, one right before it is applied
    layer_factories = (
        lambda: layers.RandomTranslation(height_factor=(-1,1), width_factor=(-1,1), fill_mode='wrap'),
        lambda: layers.RandomZoom(height_factor=(-0.3, 0.1), fill_mode='wrap'),  # zoom-in 30% to zoom-out 10%
        lambda: layers.RandomRotation(factor=(-1,1), fill_mode='wrap'),  # full rotation
        lambda: layers.RandomFlip(),
    )
    for make_layer in layer_factories:
        stacked = make_layer()(stacked)
    # channel 0 is the image, channel 1 the mask
    return stacked[..., 0:1], stacked[..., 1:2]

In [None]:
# To work with our custom augmentation,
# we need to override the training step
# - adapted from https://keras.io/guides/customizing_what_happens_in_fit/
class AugmentedSequential (keras.Sequential):
    """Sequential model that augments every training batch with tf_img_augment_nomodel.

    Images and masks are augmented together so they stay aligned. Only
    ``train_step`` is overridden, so validation batches are not augmented. The
    augmented tensors are materialised with ``.numpy()``, which only works when the
    model runs eagerly (compile with ``run_eagerly=True``).
    """
    def train_step(self, data):
        imgs, masks = data
        imgs, masks = tf_img_augment_nomodel(imgs, masks)
        # turn the augmented batch into concrete tensors, then do the standard step
        augmented_batch = (tf.constant(imgs.numpy()), tf.constant(masks.numpy()))
        return super().train_step(augmented_batch)

In [None]:
# build our astro_model
model_name = 'astro_recognitionNN'
astro_model = AugmentedSequential()
astro_model.add(keras.Input(shape=(520,704,1)))
# full-resolution conv stack ('same' padding, no pooling) so fine structures survive
# val_loss: 0.0659
for n_filters, k in ((16, 3), (8, 5), (8, 7)):
    astro_model.add(layers.Conv2D(filters=n_filters, kernel_size=(k,k), activation='relu', padding='same'))
# 1x1 convolution: a dense layer applied per pixel across the filter axis
astro_model.add(layers.Conv2D(filters=8, kernel_size=(1,1), activation='relu', padding='same'))
# OUTPUT: one sigmoid channel, i.e. a per-pixel score in [0, 1]
astro_model.add(layers.Conv2D(filters=1, kernel_size=(1,1), activation='sigmoid', padding='same'))

# run_eagerly is required by AugmentedSequential.train_step (it calls .numpy())
astro_model.compile(optimizer='adam', loss=balanced_loss, run_eagerly=True)

astro_model.summary()

In [None]:
def plot_history (history):
    """Plot training and validation loss per epoch (log-scaled y axis) from a Keras History."""
    curves = history.history
    for key in ('loss', 'val_loss'):
        plt.plot(curves[key], label=key)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend(loc='best')
    plt.gca().set_yscale('log')
    plt.show()
    plt.close()

In [None]:
if task_list[model_name]:
    print(model_name)
    if force_retrain or not os.path.exists(model_name+'.model'):
        stop_early = tf.keras.callbacks.EarlyStopping(
            monitor='val_loss',
            patience=early_stopping_patience,
            restore_best_weights=True
        )
        # not named `callbacks`: that would shadow the imported tensorflow.keras.callbacks module
        fit_callbacks = [stop_early,]
        # masks get a trailing channel axis to match the (H, W, 1) model output
        history = astro_model.fit(
            tf.cast(astro_train_imgs,tf.float32).numpy(),
            tf.cast(tf.expand_dims(astro_train_masks, axis=3),tf.float32).numpy(),
            validation_data=(
                tf.cast(astro_val_imgs,tf.float32).numpy(),
                tf.cast(tf.expand_dims(astro_val_masks, axis=3),tf.float32).numpy()),
            callbacks=fit_callbacks,
            epochs=epochs_to_train)
        if not submission_mode:
            astro_model.save(model_name + '.model')
            plot_history(history)
    else:
        astro_model = keras.models.load_model(
            model_name+'.model',
            custom_objects={'balanced_loss':balanced_loss}
        )

In [None]:
# make sure augmentation works, check the predictions
def test_predictions (model, imgs=None, masks=None, example_idx=4, n_draws=3):
    """Plot random augmentations of one example next to the model's prediction for it.

    For each of ``n_draws`` draws, the example image and its true mask are augmented
    together (tf_img_augment_nomodel). The model then predicts a mask for the augmented
    image, and image / prediction / true mask are shown side by side.

    Args:
        model: Keras model taking a batch of (520, 704, 1) images.
        imgs: image stack indexed along axis 0; defaults to the global astro_val_imgs.
        masks: mask stack without a channel axis, matching ``imgs``; defaults to
            the global astro_val_masks.
        example_idx: which example to show (4, as before).
        n_draws: number of independent augmentations to plot.
    """
    if imgs is None:
        imgs = astro_val_imgs
    if masks is None:
        masks = astro_val_masks
    # Augment only the chosen example rather than the whole stack on every draw.
    # The mask gets a channel axis and the image dtype, so the two can be
    # concatenated inside tf_img_augment_nomodel.
    example_img = np.asarray(imgs[example_idx:example_idx + 1])
    example_mask = np.expand_dims(masks[example_idx:example_idx + 1], axis=3).astype(example_img.dtype)
    for _ in range(n_draws):
        X_val, y_val = tf_img_augment_nomodel(example_img, example_mask)

        ex_img = X_val[0].numpy()
        ex_truemask = y_val[0].numpy()

        # show an example prediction
        fig = plt.figure(figsize=(16,6), clear=True)

        plt.subplot(131)
        plt.imshow(ex_img, cmap='seismic')
        plt.title('Original image')

        plt.subplot(132)
        ex_mask = model.predict(ex_img[np.newaxis, ...])[0]
        print(ex_mask.shape)
        plt.imshow(ex_mask, cmap="YlGn_r")
        plt.title("Model prediction")
        plt.colorbar()

        plt.subplot(133)
        plt.imshow(ex_truemask, cmap="YlGn_r")
        plt.title("True mask")

        plt.show()
        plt.close(fig)

if task_list[model_name] and not submission_mode:
    test_predictions(astro_model)

This looks quite good. I have tried some other ideas, but this simple network performed best for me..

Let us now proceed to turn this detection mask into a segmentation of individual neurons. First, we will need to threshold the map.

In [None]:
# The first step will probably be to binarize our output map with a threshold
# it would be great for the threshold value to be trainable
# Fortunately, this was already done by zeka0 (https://github.com/keras-team/keras/issues/6926)
class ThresholdLayer(keras.layers.Layer):
    """Soft, trainable threshold: ``sigmoid(sharpness * (x - threshold))``.

    The threshold is a single trainable scalar. ``sharpness`` (previously a
    hard-coded 100) sets how closely the sigmoid approximates a hard step.
    The output has the same shape as the input.
    """
    def __init__(self, sharpness=100., **kwargs):
        super().__init__(**kwargs)
        self.sharpness = sharpness

    def build(self, input_shape):
        # "uniform" initializer: the starting threshold is drawn at random
        self.kernel = self.add_weight(name="threshold", shape=(1,), initializer="uniform",
                                      trainable=True)
        super().build(input_shape)

    def call(self, x):
        # tf.math.sigmoid works with both tf.keras 2.x and Keras 3 (keras.backend.sigmoid does not)
        return tf.math.sigmoid(self.sharpness * (x - self.kernel))

    def compute_output_shape(self, input_shape):
        return input_shape

    def get_config(self):
        # include sharpness so the layer round-trips through model save / load
        config = super().get_config()
        config.update({'sharpness': self.sharpness})
        return config

In [None]:
# First, let's improve the contrast of the network output, to make it easier to segment
if not submission_mode:
    # sample 5 of the astro training set: image, ground-truth mask and model prediction
    # NOTE(review): `img` and `mask` are reused by later cells (e.g. `img.shape` for RLE decoding,
    # `mask` for the point cloud), so this cell must run before them
    img = astro_train_imgs[5]
    mask = astro_train_masks[5]
    pred = astro_model.predict(np.array([img,]))[0]
    # build a ThresholdLayer by hand and fix its threshold at 0.6 for this demonstration
    th_layer = ThresholdLayer()
    th_layer.build(input_shape=pred.shape)
    print(th_layer.get_weights())
    th_layer.set_weights([np.array([0.6])])
    th = th_layer(pred).numpy()

    # compare a hard threshold (pred > 0.6) with the soft ThresholdLayer output
    plt.figure(figsize=(20,6), clear=True)
    plt.subplot(151)
    plt.imshow(img, cmap='seismic')
    plt.title('Original image')
    plt.subplot(152)
    plt.imshow(mask, cmap="YlGn_r")
    plt.title('True mask')
    plt.subplot(153)
    plt.imshow(pred, cmap="YlGn_r")
    plt.title('Predicted mask')
    plt.subplot(154)
    plt.imshow(pred > 0.6, cmap="YlGn_r")
    plt.title('Threshold 1')
    plt.subplot(155)
    plt.imshow(th, cmap="YlGn_r")
    plt.title('Threshold 2')
    plt.show()
    plt.close()

    # distribution of the raw prediction values, to judge where a threshold should sit
    plt.hist(pred.flatten())
    plt.title('Prediction value histogram')
    plt.show()
    plt.close()

Now, we need to segment the corresponding binary map. First let's try to do the following:
 - translate the np array map into a list of positive pixel coordinates (~ point cloud)
 - perform clustering of the resulting dataset using classical ml methods (K-means?)
 
Ideally, we would need a clustering algorithm that:
 - doesn't require knowing the number of clusters in advance,
 - does well on elongated structures, connected into complex networks.
 
Judging by [https://scikit-learn.org/stable/modules/clustering.html](https://scikit-learn.org/stable/modules/clustering.html), this would include:
 - DBSCAN,
 - OPTICS.
 
Let's try both of them and see how they perform. Let's use the "true" segmentation mask, to separate their performance from that of our segmentation map extraction network.

In [None]:
# Let's plot the true segmentation, so that we know what we're looking for...
if not submission_mode:
    # id of the 6th unique training image -- presumably the same image as astro_train_imgs[5]
    # above; TODO confirm both are ordered the same way
    myid = astro_train_csv.id.unique()[5]
    # NOTE(review): this expression is not the last statement of the cell, so its result is never displayed
    astro_train_csv[astro_train_csv.id == myid].head()

    # build a point cloud with one row per annotated pixel: (tf.where coordinates..., cell index)
    pixels_true = []
    cl = 0
    for annotation in tqdm(tf.random.shuffle( # shuffle helps to make adjacent cells different color
        astro_train_csv[astro_train_csv.id == myid].annotation)
                          ):
        mask_here = rle_decode_tf(annotation, shape=img.shape)
        pixels_true_here = tf.where(tf.squeeze(mask_here) > 0)
        # append the running cell index as an extra column (used as the colour below)
        pixels_true_here = tf.concat([pixels_true_here, cl*np.ones([pixels_true_here.shape[0],1])], axis=1)
        cl += 1
        pixels_true.append(pixels_true_here)
    pixels_true = tf.concat(pixels_true, axis=0)

    # x is the first tf.where coordinate (the row index for a 2-D mask), so the scatter plot
    # shows the image transposed
    x = tf.transpose(pixels_true)[0]
    y = tf.transpose(pixels_true)[1]
    c = tf.transpose(pixels_true)[2]

    plt.scatter(x, y, c=c, s=0.125, cmap='prism')
    plt.show()
    plt.close()

In [None]:
if not submission_mode:
    # point cloud of the positive pixels of `mask` (the ground-truth mask from the threshold cell)
    pixels = tf.where(tf.squeeze(mask) > 0).numpy()

    # overwrites x, y from the previous cell
    x, y = tf.transpose(pixels)
    plt.scatter(x, y, s=0.125)
    plt.show()
    plt.close()

In [None]:
# Now let's try classical clustering algorithms on this point cloud
from sklearn.cluster import DBSCAN
if not submission_mode:
    # eps=1 with min_samples=1: every pixel is a core point and pixels at distance <= 1
    # (direct horizontal/vertical neighbours) share a cluster, i.e. roughly connected components
    clusterer = DBSCAN(eps=1, min_samples=1)

    clusterer.fit(pixels)
    # number of clusters found
    print(len(np.unique(clusterer.labels_)))

    plt.scatter(x, y, c=clusterer.labels_, s=0.125, cmap='prism')
    plt.show()
    plt.close()

In [None]:
# https://scikit-learn.org/stable/auto_examples/cluster/plot_optics.html#sphx-glr-auto-examples-cluster-plot-optics-py
# Exploratory comparison only: `pixels`, `x` and `y` are defined only outside submission mode,
# so this cell is guarded the same way as the cells that create them.
if not submission_mode:
    from sklearn.cluster import OPTICS
    clusterer = OPTICS() #(min_samples=20, xi=1.0, min_cluster_size=5)

    clusterer.fit(pixels)
    # number of clusters found (including the -1 "noise" label, if present)
    print(len(np.unique(clusterer.labels_)))

    plt.scatter(x, y, c=clusterer.labels_, s=0.125, cmap='prism')
    plt.show()
    plt.close()

DBSCAN is probably a bit better (depending on parameters, I suspect), but still not ideal. It is, however, much faster, so we will stick with that model.

Segmentation is still, however, far from perfect. The clustering model has trouble splitting cells that are connected by threads (axons?). We should be able to fix this by pre-processing the segmentation map to push the cells further apart. Since I am, by background, a physicist, an idea that comes to mind is the following:
 - treat the point cloud as a collection of test particles,
 - introduce a force that will make particles stick (be attracted) to each other at close distances and repel each other at larger distances,
 - evolve the segmentation map with a particle simulator that applies the introduced force,
 - the particles should split according to their groupings in the original map, splitting at narrow features,
 - the resulting evolved map will be much simpler to segment.

There are already great particle simulators for Python available. However, since the problem here is very simple, it may be best to use a custom solution (the available simulators will be rather multi-purpose, so may be a bit unwieldy for us here). Let us quickly write a simple particle simulator then.

In [None]:
from matplotlib.animation import FuncAnimation
from IPython import display
class ParticleSimulator ():
    '''Minimal pairwise-force particle simulator (forward Euler), used to pull touching cells apart.

    Particles are rows ``(x_pos, y_pos, ...)``. Only the first two columns are moved. Any extra
    columns (e.g. a cell index used for colouring) are carried along unchanged.

    Parameters
    ----------
    particles_init : array-like, shape (n, 2+)
        Initial particle rows.
    force : callable
        Called as ``force(sparse_distance, sparse_sqr_distance, force_radius, **force_kwargs)``
        with the sparse pair separation vectors and squared distances built in ``all_forces``.
        It must return the flat force values for those pairs (see ``force_axis``).
    force_radius : float
        Only pairs with squared distance in (0.5, force_radius**2) interact.
    force_kwargs : dict
        Extra keyword arguments forwarded to ``force``.
    store_history : bool
        Keep a copy of all particle rows after every step (needed by ``animate_history``).

    NOTE(review): particles are stored as float16 to keep the N x N pair matrices small. float16
    has a spacing of 0.5 for values in [512, 1024), so large pixel coordinates are quantised and
    small per-step displacements there can be rounded away.
    '''
    
    def __init__ (self, particles_init, force, force_radius=20., force_kwargs={}, store_history=False):
        # particles stored in shape (n, 2+) with the last dimension being (x_pos, y_pos, ...)
        self.particles = tf.cast(tf.convert_to_tensor(particles_init), tf.float16)
        self.particles_shape = tf.shape(self.particles)
        self.transposed = tf.transpose(self.particles) # useful for plotting; refreshed by plot()
        self.n_part = len(particles_init)
        # the pair-force function, see the class docstring for its calling convention
        self.force = force
        self.force_radius = force_radius # force only applied to particles that are closer than that
        # copy, so the shared default {} is never aliased between instances
        self.force_kwargs = dict(force_kwargs or {})
        self.time = 0 # progress of the simulation (sum of the dt of all steps)
        self.store_history = store_history
        if self.store_history:
            self.history = tf.expand_dims(copy(self.particles.numpy()), axis=0)
        else:
            self.history = []
    
    def all_forces (self, sparse=True, symmetric=True):
        '''Return the net force on every particle as a dense (n, 2) tensor.

        sparse : if True, evaluate ``self.force`` on the interacting pairs only (sparse tensors).
            NOTE(review): the dense branch does NOT call ``self.force``. It applies a hard-coded
            ``0.1 * d / |d|^2`` force instead.
        symmetric : if True, only pairs (i, j) with j > i (upper triangle) are evaluated. The
            force on j is then obtained as minus the force on i (Newton's third law).
        '''
        # calculate distances between particles: distance_matrix[i, j] = part_j - part_i
        u = tf.repeat(tf.expand_dims(self.particles[:,:2], axis=0), tf.shape(self.particles)[0], axis=0)
        v = tf.transpose(u, perm=[1,0,2])
        # if the forces are symmetric, create upper triangular slices of the distance matrix, to save computation and memory
        if symmetric:
            distance_matrix = tf.transpose(
                tf.linalg.band_part(
                    tf.transpose(u-v, perm=[2,0,1]),
                    0,-1
                ),
                perm=[1,2,0]
            )
        else:
            distance_matrix = u-v
        sqr_distance_matrix = tf.reduce_sum(distance_matrix**2, axis=2)
        # limit to forces between nearby particles (this also drops the zeroed lower triangle
        # and coincident particles)
        mask = tf.logical_and(sqr_distance_matrix > 0.5, sqr_distance_matrix < tf.cast(self.force_radius**2, tf.float16))
        indices = tf.cast(tf.where(mask), tf.int32)
        # one (i, j, component) index per pair and vector component (x then y)
        indices_vec = tf.concat(
            [tf.repeat(indices, tf.shape(distance_matrix)[2], axis=0),
             tf.tile([[0,],[1,]], [tf.shape(indices)[0],1])],
            axis=1
        )
        if sparse:
            # turn the distance matrices into sparse tensors
            distance_matrix = tf.SparseTensor(
                indices=tf.cast(indices_vec, tf.int64),
                values=tf.gather_nd(distance_matrix, indices_vec),
                dense_shape=tf.cast(tf.shape(distance_matrix), tf.int64)
            )
            sqr_distance_matrix = tf.SparseTensor(
                indices=tf.cast(indices, tf.int64),
                values=tf.gather_nd(sqr_distance_matrix, indices),
                dense_shape=tf.cast(tf.shape(sqr_distance_matrix), tf.int64)
            )
            # fill the forces matrix; force_kwargs are forwarded so the force can be tuned per simulator
            forces = tf.SparseTensor(
                indices=distance_matrix.indices,
                values=self.force(distance_matrix, sqr_distance_matrix, self.force_radius, **self.force_kwargs), # THIS IS THE ACTUAL FORCE
                dense_shape=distance_matrix.shape
            )
            # sum forces for each particle, including the anti-symmetric lower diagonal, if needed
            if symmetric:
                forces2 = tf.sparse.transpose(forces, perm=[1,0,2])
            forces = tf.sparse.reduce_sum(forces, axis=1)
            # forces are anti-symmetric, recreate the lower diagonal
            if symmetric:
                forces -= tf.sparse.reduce_sum(forces2, axis=1)
        else: # dense matrices used
            forces = tf.where(
                tf.repeat(tf.expand_dims(mask,axis=2), tf.shape(distance_matrix)[2], axis=2),
                0.1 * distance_matrix / tf.repeat(tf.expand_dims(sqr_distance_matrix, axis=2), tf.shape(distance_matrix)[2], axis=2),
                tf.zeros(tf.shape(distance_matrix), dtype=tf.float16)
            )
            forces -= tf.transpose(forces, perm=[1,0,2])
            forces = tf.reduce_sum(forces, axis=1)
        return forces
    
    def integrator (self, part, force, dt=1.0): # evolve the particle by dt given the force
        # the forward Euler integrator would be **terrible** for actual physics simulations, but for our use it is probably fine...
        return part[:2] + dt*force
            
    def plot (self):
        '''Scatter-plot the CURRENT particle positions, coloured by the third column if present.'''
        # refresh, so the plot reflects the particles after any evolve() calls
        self.transposed = tf.transpose(self.particles)
        if len(self.transposed) > 2:
            c = self.transposed[2]
        else:
            c = None
        plt.scatter(self.transposed[0], self.transposed[1], c=c, s=0.125, cmap='prism')
        plt.show()
        plt.close()
        
    def plot_history_frame (self, i, ax):
        '''Redraw frame ``i`` of the stored history on ``ax``, keeping the axis limits.'''
        xlim, ylim = ax.get_xlim(), ax.get_ylim()
        ax.clear()
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        transposed = tf.transpose(self.history[i])
        if tf.shape(transposed)[0] > 2:
            c = transposed[2].numpy()
        else:
            c = self.time * tf.ones(tf.shape(transposed[0])[0])
        return ax.scatter(transposed[0].numpy(), transposed[1].numpy(), c=c, s=0.125, cmap='prism')
    
    def animate_history (self):
        '''Render the stored history as an HTML5 video in the notebook (requires store_history=True).'''
        fig, ax = plt.subplots()
        ax.set_xlim(tf.reduce_min(self.history[:,:,0])-10, tf.reduce_max(self.history[:,:,0])+10)
        ax.set_ylim(tf.reduce_min(self.history[:,:,1])-10, tf.reduce_max(self.history[:,:,1])+10)
        self.anim = FuncAnimation(fig, lambda i : self.plot_history_frame(i, ax), frames=len(self.history), interval=200)
        video = self.anim.to_html5_video()
        html = display.HTML(video)
        display.display(html)
        plt.close()
        
    def evolve (self, n_steps=1, animate=False, dt=1.0, verbose=True):
        '''Advance all particles by ``n_steps`` forward-Euler steps of size ``dt``.'''
        if verbose:
            print('Evolving the particles...', flush=True)
            view = tqdm
        else:
            view = (lambda x : x)
        for i in view(range(n_steps)):
            
            # scatter indices (particle, column) for the x and y columns of every particle,
            # shaped (n, 2, 2) to match the (n, 2) force updates
            indices = tf.cast(tf.range(tf.shape(self.particles)[0]), tf.int32)
            indices = tf.reshape(
                tf.stack(
                    [
                        indices,
                        tf.zeros(tf.shape(self.particles)[0], dtype=tf.int32),
                        indices,
                        tf.ones(tf.shape(self.particles)[0], dtype=tf.int32)
                    ],
                    axis=1
                ),
                [-1,2,2]
            )
            
            # forward Euler step on the position columns only
            self.particles = tf.tensor_scatter_nd_add(
                self.particles,
                indices,
                dt * self.all_forces()
            )
            
            if self.store_history:
                self.history = tf.concat([self.history, tf.expand_dims(copy(self.particles.numpy()), axis=0)], axis=0)
            self.time += dt
        if animate:
            self.animate_history()

In [None]:
# trying out some different force prescriptions here..

def force_axis (sparse_distance, sparse_sqr_distance, force_radius, norm=0.05, attractive=1.0, repulsive=1.0):
    '''Anisotropic pair force: attract along the local "thread" axis, repel across it.

    For every interacting pair (i, j) the local axis of particle i is the normalised sum of the
    unit directions to its neighbours. Particles are pulled towards neighbours lying along that
    axis and pushed away from neighbours lying to the side of it.

    Parameters
    ----------
    sparse_distance : tf.SparseTensor
        Pair separation vectors ``part_j - part_i`` with two entries (x, y) per pair, as built by
        ``ParticleSimulator.all_forces``.
    sparse_sqr_distance : tf.SparseTensor
        Squared pair distances, one entry per pair, in the same pair order.
    force_radius : float
        Interaction radius of the simulator.
    norm, attractive, repulsive : float
        Overall scale, strength of the attractive term and strength of the repulsive term.

    Returns
    -------
    tf.Tensor
        Flat (2 * n_pairs,) force values laid out like ``sparse_distance.values``.

    NOTE(review): with ``all_forces(symmetric=True)`` (its default) only pairs with j > i are
    passed in. The neighbour count and the mean direction of particle i are therefore computed
    from its higher-index neighbours only.
    '''
    distance = tf.reshape(sparse_distance.values, [-1,2])
    sqr_distance = sparse_sqr_distance.values
    part0_indices = sparse_sqr_distance.indices[:,0]
    # calculate the number of neighbours of the first particle of every pair
    n_neighbours = tf.SparseTensor(
        indices = sparse_sqr_distance.indices, 
        values = tf.ones(tf.shape(sparse_sqr_distance.values), dtype=tf.float16),
        dense_shape = sparse_sqr_distance.dense_shape
    )
    n_neighbours = tf.sparse.reduce_sum(n_neighbours, axis=1)
    n_neighbours = tf.gather_nd(n_neighbours, tf.expand_dims(part0_indices,-1))
    # calculate the local mean direction of neighbours
    direction = distance / tf.expand_dims(tf.sqrt(tf.reduce_sum(distance**2,-1)),-1)
    sparse_direction = tf.SparseTensor(
        indices = sparse_distance.indices,
        values = tf.reshape(direction, [-1]),
        dense_shape = sparse_distance.dense_shape
    )
    main_axis = tf.sparse.reduce_sum(sparse_direction, axis=1)
    main_axis = tf.reshape(main_axis, [-1,2])
    main_axis = tf.gather_nd(main_axis, tf.expand_dims(part0_indices,-1))
    # normalise each particle's axis to unit length (per row, not over the whole tensor);
    # divide_no_nan keeps the axis at zero when the neighbour directions cancel out
    main_axis = tf.math.divide_no_nan(
        main_axis,
        tf.sqrt(tf.reduce_sum(main_axis**2, axis=-1, keepdims=True))
    )
    # have the axis always point at positive x
    main_axis = tf.where(
        tf.expand_dims(main_axis[:,0] > 0,-1),
        main_axis,
        -main_axis
    )
    # unit vector perpendicular to the axis: (x, y) -> (-y, x)
    axis_perp = tf.stack([-main_axis[:,1], main_axis[:,0]], axis=-1)
    # distance of the neighbour from the axis line through particle i
    perp_offset = tf.abs(tf.reduce_sum(axis_perp*distance, axis=1))
    # the particles will feel the strongest force towards particles in the mean direction
    force_proposed = (
        + attractive / ((1.0 + perp_offset**2)**2 * sqr_distance) / 25**(n_neighbours/force_radius**2)
    )
    # and will be repelled from particles perpendicular to the mean direction
    force_proposed += (
        - repulsive * tf.sin(sqr_distance*0.5*np.pi/force_radius**2)
            * perp_offset
            / ((0.2*force_radius + perp_offset**3) * sqr_distance)
    )
    force_proposed = distance * tf.expand_dims(norm*force_proposed, axis=-1)
    # replace the weakest 25% of pair forces by a push along the local axis, to avoid
    # particles being stuck in threads (comparing squared magnitudes is order-preserving)
    force_sqr_value = tf.reduce_sum(force_proposed**2, axis=1)
    border_force = tfp.stats.percentile(force_sqr_value,25)
    force_proposed = tf.where(
        tf.expand_dims(force_sqr_value > border_force, axis=-1),
        force_proposed,
        1.0*main_axis
    )
    # apply limits: cap the force MAGNITUDE at distance / n_neighbours
    force_max = tf.sqrt(sqr_distance) / n_neighbours
    force_value = tf.sqrt(tf.reduce_sum(force_proposed**2, axis=1))
    force_direction = tf.math.divide_no_nan(force_proposed, tf.expand_dims(force_value, axis=-1))
    force_proposed = tf.where(
        tf.expand_dims(force_value < force_max, -1),
        force_proposed,
        norm * force_direction * tf.expand_dims(force_max,-1)
    )
    # reshape to sparse representation
    force_proposed = tf.reshape(force_proposed, [-1])
    return force_proposed

In [None]:
# Well, let's try this
if not submission_mode:
    # toy input: three straight arms of test_size particles each, all starting at the origin
    # NOTE(review): the first arm lies along +x, while the other two use (sin, cos) of 120/240
    # degrees, i.e. they point at -30 and -150 degrees. If a symmetric three-way junction was
    # intended, the first arm should be (0, t).
    test_size = 20
    simulator = ParticleSimulator(
        tf.cast(
            tf.concat([
                tf.concat([np.expand_dims(np.linspace(0,test_size, test_size), axis=1),np.expand_dims(np.zeros([test_size,]),axis=1)], axis=1),
                #tf.concat([np.expand_dims(np.linspace(0,20, 20), axis=1),np.expand_dims(np.ones([20,]),axis=1)], axis=1),
                tf.concat([np.expand_dims(np.sin(np.pi*2./3.)*np.linspace(0,test_size, test_size), axis=1),np.expand_dims(np.cos(np.pi*2./3.)*np.linspace(0,test_size, test_size),axis=1)], axis=1),
                tf.concat([np.expand_dims(np.sin(np.pi*4./3.)*np.linspace(0,test_size, test_size), axis=1),np.expand_dims(np.cos(np.pi*4./3.)*np.linspace(0,test_size, test_size),axis=1)], axis=1),
            ], axis=0),
            tf.float32
        ),
        force=force_axis,
        store_history=True,
        force_radius=10.
    )
    simulator.plot()
    # evolve with the same settings used for the real masks and show an animation
    simulator.evolve(animate=True, n_steps=simulation_epochs, dt=simulation_dt)

This looks pretty good, if not perfect. The pixels detach at network nodes, but otherwise stay together. Let's try on our dataset.

The full mask has a very large number of pixels (N~50000), which would not fit into GPU memory as a single particle simulator (it builds N×N matrices). Thus, let's limit our test to only part of the map.

In [None]:
if not submission_mode:
    plt.figure(figsize=(12,4))

    # Let's plot the true segmentation, so that we know what we're looking for...
    # (same point-cloud construction as in the earlier true-segmentation cell)
    myid = astro_train_csv.id.unique()[5]
    # NOTE(review): not the last statement of the cell, so this result is never displayed
    astro_train_csv[astro_train_csv.id == myid].head()

    pixels_true = []
    cl = 0
    for annotation in tqdm(tf.random.shuffle( # shuffle helps to make adjacent cells different color
        astro_train_csv[astro_train_csv.id == myid].annotation)
                          ):
        mask_here = rle_decode_tf(annotation, shape=img.shape)
        pixels_true_here = tf.where(tf.squeeze(mask_here) > 0)
        pixels_true_here = tf.concat([pixels_true_here, cl*np.ones([pixels_true_here.shape[0],1])], axis=1)
        cl += 1
        pixels_true.append(pixels_true_here)
    pixels_true = tf.concat(pixels_true, axis=0)

    x = tf.transpose(pixels_true)[0]
    y = tf.transpose(pixels_true)[1]
    c = tf.transpose(pixels_true)[2]

    plt.subplot(121)
    plt.title('Full mask')
    plt.scatter(x, y, c=c, s=0.125, cmap='prism')

    # limit to a test region
    # NOTE(review): this rebinds `mask` (previously the ground-truth mask) to a boolean row selector
    mask = tf.logical_and(x > 300, y < 400).numpy()
    print(pixels_true.shape)
    pixels_true = pixels_true.numpy()[mask,:]
    print(pixels_true.shape)

    # x, y, c now describe the cropped fragment only; later cells rely on this
    x = tf.transpose(pixels_true)[0]
    y = tf.transpose(pixels_true)[1]
    c = tf.transpose(pixels_true)[2]

    plt.subplot(122)
    plt.title('Cropped fragment')
    plt.scatter(x, y, c=c, s=0.125, cmap='prism')

    plt.show()
    plt.close()

In [None]:
if not submission_mode:
    # the cropped fragment without cell colours: the input the clustering has to work with
    plt.scatter(x, y, s=0.125)
    plt.show()
    plt.close()

In [None]:
# some adjustments to avoid OOM errors: let TensorFlow allocate GPU memory on demand
# instead of reserving all of it up front. Memory growth has to be set identically on every
# GPU and before the GPUs are initialised. On a CPU-only machine the loop simply does nothing.
physical_devices = tf.config.list_physical_devices('GPU')
for gpu in physical_devices:
    try:
        tf.config.experimental.set_memory_growth(gpu, True)
    except (RuntimeError, ValueError) as err:
        # RuntimeError: the runtime is already initialised; ValueError: invalid device.
        # Best effort only -- report it instead of silently ignoring it.
        print('Could not enable memory growth on %s: %s' % (gpu.name, err))

In [None]:
# now evolve our cropped fragment
if not submission_mode:
    # `simulator` (with its evolved particles) is reused by the clustering comparison below
    simulator = ParticleSimulator(
        tf.cast(pixels_true, tf.float16),
        force=force_axis,
        store_history=True,
        force_radius=25
    )
    simulator.plot()
    simulator.evolve(animate=True, n_steps=simulation_epochs, dt=simulation_dt)

Now we can use this evolved picture to cluster our point cloud. Let us compare 3 cases:
 - using the original recognition map (the NN output) only,
 - using the evolved map (the final state from our particle simulator) only,
 - using both the original and the evolved map as features.

In [None]:
# limit the clusters to only those with at least `min_members` members
def limit_labels (label_list, min_members=10, verbose=True):
    '''Relabel points of too-small clusters as noise (-1).

    Parameters
    ----------
    label_list : sequence of int
        One cluster label per point (e.g. ``DBSCAN.labels_``); -1 already means noise.
    min_members : int
        Minimum number of points a cluster needs to be kept.
    verbose : bool
        If True, print the number of distinct labels left (including -1).

    Returns
    -------
    list of int
        The original label for points in clusters with at least ``min_members`` points, and -1
        for all other points.
    '''
    label_list = np.asarray(label_list)
    # cluster sizes in one pass; `inverse` maps every point to its cluster's position in `counts`
    _, inverse, counts = np.unique(label_list, return_inverse=True, return_counts=True)
    keep_point = (counts >= min_members)[inverse]
    labels = np.where(keep_point, label_list, -1).tolist()
    if verbose:
        print(len(np.unique(labels)))
    return labels

In [None]:
# Now let's try classical clustering algorithms on this point cloud
from sklearn.cluster import DBSCAN
# DBSCAN neighbourhood radius used on the evolved / combined features
dbscan_eps = 6
if not submission_mode:

    # x, y, c, pixels_true and simulator all refer to the cropped fragment from the cells above
    plt.figure(figsize=(12,16))

    # target classification
    plt.subplot(321)
    plt.title('TARGET (%i)' % len(np.unique(c)))
    plt.scatter(x, y, c=c, s=0.125, cmap='prism')

    # classification without evolution
    clusterer = DBSCAN(eps=1, min_samples=10)
    clusterer.fit(pixels_true[:,:2])

    # discard classes below 10 pixels
    labels = limit_labels(clusterer.labels_)

    plt.subplot(322)
    # NOTE(review): this title counts the labels BEFORE limit_labels, unlike panels 324/325
    plt.title('No evolution (%i)' % len(np.unique(clusterer.labels_)))
    plt.scatter(x, y, c=labels, s=0.125, cmap='prism')

    # just evolved

    clusterer = DBSCAN(eps=dbscan_eps, min_samples=10)
    clusterer.fit(simulator.particles[:,:2])

    # discard classes below 10 pixels
    labels = limit_labels(clusterer.labels_)

    plt.subplot(323)
    plt.title('Evolved only (%i)' % len(np.unique(clusterer.labels_)))
    # colour the ORIGINAL pixel positions by the cluster of their evolved counterpart
    plt.scatter(x, y, c=labels, s=0.125, cmap='prism')

    # for better clustering, combine the original positions and the evolved ones
    # (4-D features: original x, y and evolved x, y)
    features = tf.concat([pixels_true[:,:2], simulator.particles[:,:2]], axis=1)

    clusterer = DBSCAN(eps=dbscan_eps, min_samples=10)
    clusterer.fit(features)

    # discard classes below 10 pixels
    labels = limit_labels(clusterer.labels_)

    plt.subplot(324)
    plt.title('Original + evolved (%i)' % len(np.unique(labels)))
    plt.scatter(x, y, c=labels, s=0.125, cmap='prism')

    # let's try optics too, just out of curiosity
    from sklearn.cluster import OPTICS
    clusterer = OPTICS(xi=0.3) #(min_samples=20, xi=1.0, min_cluster_size=5)

    clusterer.fit(features)

    # discard classes below 10 pixels
    labels = limit_labels(clusterer.labels_)

    plt.subplot(325)
    plt.title('Original + evolved (%i), OPTICS' % len(np.unique(labels)))
    plt.scatter(x, y, c=labels, s=0.125, cmap='prism')

    plt.show()
    plt.close()

The original+evolved DBSCAN model performs best here. The result is not perfect, but with some tuning, this could do great!

Since the trick with particle simulation paid off, let us now apply it to the full map. Since we can't hold all the pixels in memory at the same time, let us use cropouts. We will proceed as follows:
 - split the full binary detection map into a number of cropouts, where each single one can be held in GPU memory at a time -- note that they should overlap each other!,
 - evolve the cropouts with the particle simulator in sequence, by a small time step at a time,
 - repeat until the cells are sufficiently separated.

In [None]:
# recalculate the full list of pixels
# (same construction as in the crop cell above, without the cropping step)
if not submission_mode:
    pixels_true = []
    cl = 0
    for annotation in tqdm(tf.random.shuffle( # shuffle helps to make adjacent cells different color
        astro_train_csv[astro_train_csv.id == myid].annotation)
                          ):
        mask_here = rle_decode_tf(annotation, shape=img.shape)
        pixels_true_here = tf.where(tf.squeeze(mask_here) > 0)
        # append the running cell index as an extra column (used as colour / target label)
        pixels_true_here = tf.concat([pixels_true_here, cl*np.ones([pixels_true_here.shape[0],1])], axis=1)
        cl += 1
        pixels_true.append(pixels_true_here)
    pixels_true = tf.concat(pixels_true, axis=0)

    x = tf.transpose(pixels_true)[0]
    y = tf.transpose(pixels_true)[1]
    c = tf.transpose(pixels_true)[2]

In [None]:
class ParticleSimulatorGroup ():
    '''Evolve a point cloud too large for one ParticleSimulator using overlapping crops.

    The bounding box of the particles is split into a regular grid of overlapping rectangles.
    Each rectangle gets its own ParticleSimulator, and every step evolves the crops one after
    another and writes their particles back into ``self.particles``.

    NOTE(review): each simulator keeps its own copy of its crop and is never re-synchronised
    from ``self.particles``. Particles in overlap regions are therefore evolved independently by
    every crop that contains them, and the value written last wins.
    '''
    
    def __init__ (
        self, particles_init,
        force, force_radius=20., force_kwargs={}, store_history=False, # ParticleSimulator arguments
        max_px=10000, # the maximum number of particles per simulator instance
        overlap=0.25, # fractional overlap between cropped images
    ):
        # read the properties of the particle list
        self.time = 0.0
        self.particles = tf.cast(particles_init, tf.float16)
        self.n_particles = tf.shape(particles_init)[0].numpy()
        transposed = tf.transpose(self.particles[:,:2])
        self.x = transposed[0]
        self.y = transposed[1]
        self.xmin, self.xmax = tf.reduce_min(self.x).numpy(), tf.reduce_max(self.x).numpy()
        self.xlen = self.xmax - self.xmin
        self.ymin, self.ymax = tf.reduce_min(self.y).numpy(), tf.reduce_max(self.y).numpy()
        self.ylen = self.ymax - self.ymin
        # decide how to split the pixel list
        if self.n_particles < max_px: # use a single simulator
            self.n_simulators = 1
            self.split_borders = [[self.xmin, self.xmax, self.ymin, self.ymax],]
            # (n, 1) index list -- the same layout tf.where produces below -- so that
            # tf.gather_nd / tf.tensor_scatter_nd_update address whole particle rows
            self.split = [tf.expand_dims(tf.cast(tf.range(self.n_particles), tf.int64), -1),]
            self.overlap = 0.
        else:
            # split, ensuring significant overlap between cropped images:
            # aim for ~2 * n / max_px crops, rounded up to the next perfect square (a k x k grid)
            sqrt_n_simulators = int(np.ceil(np.sqrt(2 * self.n_particles / max_px)))
            self.n_simulators = sqrt_n_simulators**2
            # calculate the split boundaries: [x_lo, x_hi, y_lo, y_hi] per grid cell, widened by overlap
            self.overlap = overlap
            self.split_borders = [[
                self.xmin + (i[0]  -self.overlap) * self.xlen/sqrt_n_simulators,
                self.xmin + (i[0]+1+self.overlap) * self.xlen/sqrt_n_simulators,
                self.ymin + (i[1]  -self.overlap) * self.ylen/sqrt_n_simulators,
                self.ymin + (i[1]+1+self.overlap) * self.ylen/sqrt_n_simulators,
            ] for i in np.reshape(np.transpose(np.meshgrid(np.arange(sqrt_n_simulators), np.arange(sqrt_n_simulators), indexing='ij')),[-1,2])]
            # now, gather the indices in each border set (tf.where gives shape (k, 1))
            self.split = [
                tf.where(
                    tf.logical_and(
                        tf.logical_and(
                            self.particles[:,0] > b[0], self.particles[:,0] < b[1]
                        ),
                        tf.logical_and(
                            self.particles[:,1] > b[2], self.particles[:,1] < b[3]
                        )
                    )
                ) for b in self.split_borders
            ]
        # create particle simulators -- one for each crop-out
        self.simulators = [
            ParticleSimulator(
                tf.gather_nd(self.particles, indices),
                force=force, force_radius=force_radius, force_kwargs=force_kwargs, store_history=store_history
            ) for indices in self.split
        ]
        # prepare for history storage, if needed
        self.store_history = store_history
        if self.store_history:
            self.history = tf.expand_dims(copy(self.particles.numpy()), axis=0)
        else:
            self.history = []
            
    def plot (self):
        '''Scatter-plot the current particles of the whole group, coloured by the third column if present.'''
        transposed = tf.transpose(self.particles)
        if len(transposed) > 2:
            c = transposed[2]
        else:
            c = None
        plt.scatter(transposed[0], transposed[1], c=c, s=0.125, cmap='prism')
        plt.show()
        plt.close()
        
    def plot_history_frame (self, i, ax):
        '''Redraw frame ``i`` of the stored history on ``ax``, keeping the axis limits.'''
        xlim, ylim = ax.get_xlim(), ax.get_ylim()
        ax.clear()
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        transposed = tf.transpose(self.history[i])
        if tf.shape(transposed)[0] > 2:
            c = transposed[2].numpy()
        else:
            c = self.time * tf.ones(tf.shape(transposed[0])[0])
        return ax.scatter(transposed[0].numpy(), transposed[1].numpy(), c=c, s=0.125, cmap='prism')
    
    def animate_history (self):
        '''Render the stored history as an HTML5 video in the notebook (requires store_history=True).'''
        fig, ax = plt.subplots()
        ax.set_xlim(tf.reduce_min(self.history[:,:,0])-10, tf.reduce_max(self.history[:,:,0])+10)
        ax.set_ylim(tf.reduce_min(self.history[:,:,1])-10, tf.reduce_max(self.history[:,:,1])+10)
        self.anim = FuncAnimation(fig, lambda i : self.plot_history_frame(i, ax), frames=len(self.history), interval=200)
        video = self.anim.to_html5_video()
        html = display.HTML(video)
        display.display(html)
        plt.close()
        
    def evolve (self, n_steps=1, animate=False, dt=1.0, verbose=True):
        '''Advance every crop by ``n_steps`` steps of size ``dt``, one step at a time for all crops.'''
        if verbose:
            print('Evolving the particles...', flush=True)
        for i in tqdm(range(n_steps)):
            
            # evolve EVERY particle simulator by one time step, in order
            for simulator, indices in zip(self.simulators, self.split):
                simulator.evolve(
                    n_steps=1,
                    animate=False,
                    dt=dt,
                    verbose=False
                )
                # write the crop's particles back into the full particle list
                self.particles = tf.tensor_scatter_nd_update(
                    self.particles,
                    indices,
                    simulator.particles
                )
            
            # save history if needed
            if self.store_history:
                self.history = tf.concat([self.history, tf.expand_dims(copy(self.particles.numpy()), axis=0)], axis=0)
            self.time += dt
        if animate:
            self.animate_history()

In [None]:
if not submission_mode:
    # split the full annotated mask into overlapping crops, each with its own ParticleSimulator
    test = ParticleSimulatorGroup(
        pixels_true,
        force=force_axis,
        store_history=True,
        force_radius=25
    )
    # NOTE(review): assumes the mask was split into at least 4 crops
    print(test.simulators[3].n_part)

In [None]:
if not submission_mode:
    # evolve all crops with the same settings as the single-crop experiment and animate the result
    test.evolve(n_steps=simulation_epochs, animate=True, dt=simulation_dt)

In [None]:
# Now let's try classical clustering algorithms on this point cloud
from sklearn.cluster import DBSCAN
dbscan_eps = 6
# created outside the `if` on purpose: AstroModel below uses astro_clusterer, also in submission mode
astro_clusterer = DBSCAN(eps=dbscan_eps, min_samples=10)

if not submission_mode:
    # x, y, c and pixels_true refer to the full mask; `test` is the evolved ParticleSimulatorGroup
    plt.figure(figsize=(12,16))

    # target classification
    plt.subplot(321)
    plt.title('TARGET (%i)' % len(np.unique(c)))
    plt.scatter(x, y, c=c, s=0.125, cmap='prism')

    # classification without evolution
    clusterer = DBSCAN(eps=1, min_samples=10)
    clusterer.fit(pixels_true[:,:2])

    # discard classes below 10 pixels
    labels = limit_labels(clusterer.labels_)

    plt.subplot(322)
    # NOTE(review): this title counts the labels BEFORE limit_labels, unlike panels 324/325
    plt.title('No evolution (%i)' % len(np.unique(clusterer.labels_)))
    plt.scatter(x, y, c=labels, s=0.125, cmap='prism')

    # just evolved

    clusterer = DBSCAN(eps=dbscan_eps, min_samples=10)
    clusterer.fit(test.particles[:,:2])

    # discard classes below 10 pixels
    labels = limit_labels(clusterer.labels_)

    plt.subplot(323)
    plt.title('Evolved only (%i)' % len(np.unique(clusterer.labels_)))
    plt.scatter(x, y, c=labels, s=0.125, cmap='prism')

    # for better clustering, combine the original positions and the evolved ones
    # the evolved float16 positions are cast to int64 -- presumably to match the dtype of
    # pixels_true for tf.concat; note that this truncates them to whole pixels
    features = tf.concat([pixels_true[:,:2], tf.cast(test.particles[:,:2], tf.int64)], axis=1)

    astro_clusterer.fit(features)

    # discard classes below 10 pixels
    labels = limit_labels(astro_clusterer.labels_)

    plt.subplot(324)
    plt.title('Original + evolved (%i)' % len(np.unique(labels)))
    plt.scatter(x, y, c=labels, s=0.125, cmap='prism')

    # let's try optics too, just out of curiosity
    from sklearn.cluster import OPTICS
    clusterer = OPTICS(xi=0.3) #(min_samples=20, xi=1.0, min_cluster_size=5)

    clusterer.fit(features)

    # discard classes below 10 pixels
    labels = limit_labels(clusterer.labels_)

    plt.subplot(325)
    plt.title('Original + evolved (%i), OPTICS' % len(np.unique(labels)))
    plt.scatter(x, y, c=labels, s=0.125, cmap='prism')

    plt.show()
    plt.close()

All right, let's put it all together into a classifier.
 - Input: raw image with astro cells
 - Processing:
   - preprocessing (color adjustments, etc.),
   - detection of cell structures with the NN trained above,
   - transformation of the resulting mask with particle simulators,
   - mask splitting into separate cells with DBSCAN,
   - conversion into competition pixel format.
 - Output:
   - imaged clustering map (side-by-side with the original image)
   - the clustering map in the competition pixel format, as a string

In [None]:
def rle_encode (mask, classes=(1,), img_id='noid', progress=True):
    '''Encode the pixels of each class in `mask` as competition-format RLE lines.

    Pixels are numbered in row-major order starting from 1, and every maximal run
    of consecutive pixel numbers becomes one "start length" pair. Runs may wrap
    from the end of one row to the start of the next, which matches flat pixel
    numbering. Pixels that are only diagonally adjacent in consecutive rows are
    NOT merged into one run.

    Parameters
    ----------
    mask : array-like (numpy array or eager tf tensor)
        Label image; singleton dimensions are squeezed away before encoding.
    classes : iterable
        Label values to encode; each class with at least one pixel yields one line.
    img_id : str
        Id written at the start of every line as "img_id,".
    progress : bool
        Wrap the loop over `classes` in a tqdm progress bar (previous behaviour).

    Returns
    -------
    str
        Newline-separated "img_id,start length start length ..." lines; an empty
        string if no requested class is present.
    '''
    mask_here = np.squeeze(np.asarray(mask))
    class_iter = tqdm(classes) if progress else classes

    lines = []
    for cl in class_iter:
        # flat (row-major) indices of this class; flatnonzero returns them sorted
        flat = np.flatnonzero(mask_here == cl)
        if flat.size == 0:
            # nothing to encode (previously this raised an IndexError)
            continue
        # a new run starts wherever the flat index does not continue the previous one;
        # comparing flat indices (not only column indices) keeps pixels from
        # different rows from being merged into one run
        breaks = np.flatnonzero(np.diff(flat) != 1) + 1
        run_starts = np.concatenate(([0], breaks))
        run_ends = np.concatenate((breaks, [flat.size]))
        starting_px = flat[run_starts] + 1   # competition pixel numbers are 1-indexed
        run_lengths = run_ends - run_starts
        pairs = np.column_stack((starting_px, run_lengths)).ravel()
        lines.append(img_id + ',' + ' '.join(pairs.astype(str)))

    return '\n'.join(lines)

# test that the rle encoding works properly: re-encoding a decoded annotation
# should reproduce the original annotation string
if not submission_mode:
    annotation = astro_train_csv[astro_train_csv.id == myid].annotation.iloc[0]
    rle_encoded = rle_encode(rle_decode_tf(annotation, shape=img.shape))
    # strip the default 'noid,' id prefix; a bare comparison inside an `if` block
    # is never rendered by Jupyter, so print the result explicitly
    print('RLE round trip matches:', rle_encoded[len('noid,'):] == annotation)

In [None]:
class AstroModel:
    '''Astro-cell instance segmentation pipeline.

    `predict` runs these steps:
    - optional colour preprocessing,
    - pixel detection with `detection_model`,
    - particle-simulator evolution of the detected pixels,
    - clustering of the original and evolved pixel positions,
    - conversion to competition RLE lines.
    '''

    def __init__ (
        self,
        preprocess=False,
        detection_model=astro_model,
        detection_threshold=0.6,
        simulator_force=force_axis,
        simulator_force_radius=25,
        simulator_force_kwargs=None,
        simulator_steps=simulation_epochs,
        simulator_dt=simulation_dt,
        clusterer=None
    ):
        '''
        Parameters
        ----------
        preprocess : bool
            Apply `img_color_rescale` to the image before detection.
        detection_model : keras model
            Network whose output is thresholded into cell pixels.
        detection_threshold : float
            Pixels with a predicted value above this count as cell pixels.
        simulator_force, simulator_force_radius, simulator_force_kwargs :
            Passed to `ParticleSimulatorGroup` as `force`, `force_radius` and
            `force_kwargs`. `None` kwargs means a fresh empty dict per instance.
        simulator_steps, simulator_dt :
            `n_steps` and `dt` for `ParticleSimulatorGroup.evolve`.
        clusterer : object with `fit` and `labels_`, optional
            Defaults to the notebook-level `astro_clusterer`. Previously this
            argument was accepted but silently ignored, and `astro_clusterer`
            was always used.
        '''
        self.preprocess = preprocess
        self.detection_model = detection_model
        self.detection_threshold = detection_threshold
        self.simulator_force = simulator_force
        self.simulator_force_radius = simulator_force_radius
        self.simulator_force_kwargs = {} if simulator_force_kwargs is None else simulator_force_kwargs
        self.simulator_steps = simulator_steps
        self.simulator_dt = simulator_dt
        self.clusterer = astro_clusterer if clusterer is None else clusterer

    def predict (self, img, img_id='noid', mask_pixels=None, verbose=True):
        '''Segment one image and return the cells as competition RLE lines.

        Parameters
        ----------
        img : array
            Image fed to `detection_model` as a batch of one.
        img_id : str
            Id written at the start of every RLE line.
        mask_pixels : tensor/array, optional
            Ground-truth pixels as rows of (row, col, cell index). Only used for
            the side-by-side comparison plot when `verbose` is set.
        verbose : bool
            Print progress and show diagnostic plots and animations.

        Returns
        -------
        str
            Newline-separated "img_id,start length ..." lines, one per label.
        '''
        # image preprocessing
        if verbose:
            print(' - predicting an astro-cell image segmentation:')
            print('   - image preprocessing.. ', end='', flush=True)
            plt.figure(figsize=(12,6))
            plt.subplot(121)
            plt.imshow(1.*img, cmap="seismic")
        if self.preprocess:
            img = img_color_rescale(img)
        if verbose:
            plt.subplot(122)
            plt.imshow(img, cmap="seismic")
            plt.show()
            plt.close()
            print('done.', flush=True)

        # use the NN to detect cell structures
        if verbose:
            print('   - detecting cell pixels with a NN.. ', end='', flush=True)
        pred = self.detection_model.predict(np.array([img,]))[0]
        pixels_true = tf.where(pred > self.detection_threshold)
        # x: first (row) index, y: second (column) index of every detected pixel
        x = tf.transpose(pixels_true)[0]
        y = tf.transpose(pixels_true)[1]
        if verbose:
            print('done.', flush=True)
            plt.figure(figsize=(12,6))
            plt.subplot(121)
            plt.imshow(pred)
            plt.colorbar()
            plt.subplot(122)
            plt.scatter(y,x, s=0.125)
            plt.show()
            plt.close()

        # transform with particle simulators
        if verbose:
            print('   - transforming with particle simulators.. ', end='', flush=True)
        self.psg = ParticleSimulatorGroup(
            pixels_true,
            force=self.simulator_force,
            store_history=verbose,
            force_radius=self.simulator_force_radius,
            force_kwargs=self.simulator_force_kwargs
        )
        self.psg.evolve(n_steps=self.simulator_steps, animate=verbose, dt=self.simulator_dt)
        if verbose:
            print('done.', flush=True)

        # cluster on the original AND the evolved positions (4 features per pixel)
        if verbose:
            print('   - clustering.. ', end='', flush=True)
        features = tf.concat([pixels_true[:,:2], tf.cast(self.psg.particles[:,:2], tf.int64)], axis=1)
        self.clusterer.fit(features)
        labels = np.array(limit_labels(self.clusterer.labels_))
        if verbose:
            if mask_pixels is not None and len(mask_pixels) > 0:
                plt.figure(figsize=(12,6))
                plt.subplot(121)
                mask_t = tf.transpose(mask_pixels)
                xm, ym, cmask = mask_t[0], mask_t[1], mask_t[2]
                plt.scatter(ym,xm,c=cmask, s=0.125, cmap='prism')
                plt.subplot(122)
            else:
                plt.figure(figsize=(6,6))
            # (column, row) orientation, matching the ground-truth panel and the
            # detection scatter above (this used to be plotted transposed)
            plt.scatter(y, x, c=labels, s=0.125, cmap='prism')
            plt.show()
            plt.close()
            print('done.', flush=True)

        # convert to the competition pixel number format
        if verbose:
            print('   - converting to RLE string.. ', end='', flush=True)
        # NaN background: a zero background would be merged into the cluster
        # labelled 0, because NaN never compares equal to any label
        mask_pred = np.squeeze(np.full(img.shape, np.nan))
        mask_pred[tuple(np.transpose(pixels_true[:,:2]))] = labels
        # NOTE(review): every distinct value in `labels` becomes one instance,
        # including any noise label (e.g. DBSCAN's -1) unless limit_labels removes it
        # -- confirm
        rle_encoded = rle_encode(mask_pred, classes=np.unique(labels), img_id=img_id)
        if verbose:
            print('done.', flush=True)
        return rle_encoded

In [None]:
if not submission_mode:
    # smoke test of the full astro pipeline on one training image
    img, mask = astro_train_imgs[5], astro_train_masks[5]
    model = AstroModel()
    encoded = model.predict(img)
    print(encoded[:50])

When improving the model later, it will be very helpful to have our own implementation of the competition score. Let us prepare one here.

In [None]:
def single_intersection (b1,e1,b2,e2):
    '''Length of the overlap of the half-open intervals [b1, e1) and [b2, e2); 0 if disjoint.'''
    overlap_start = b1 if b1 >= b2 else b2
    overlap_end = e1 if e1 <= e2 else e2
    return max(overlap_end - overlap_start, 0)
def rle_intersection_over_union (rle1, rle2):
    '''Intersection-over-union of two RLE-encoded masks.

    Each argument is a whitespace-separated "start length start length ..." string.
    Runs within one string are assumed not to overlap, so the intersection is the
    sum of the overlaps over ALL pairs of runs. The previous zip() only compared
    run k of one mask with run k of the other.

    Returns 0.0 when both masks are empty (empty union).
    '''
    def _runs(rle):
        # plain int64 instead of the np.int alias, which was removed in NumPy 1.24
        pairs = np.array(rle.split()).astype(np.int64).reshape([-1, 2])
        starts = pairs[:, 0]
        return starts, starts + pairs[:, 1]

    start1, end1 = _runs(rle1)
    start2, end2 = _runs(rle2)
    # pairwise overlap of half-open runs [start, end), clipped at 0 for disjoint runs
    overlaps = np.minimum(end1[:, None], end2[None, :]) - np.maximum(start1[:, None], start2[None, :])
    intersection = np.clip(overlaps, 0, None).sum()
    union = (end1 - start1).sum() + (end2 - start2).sum() - intersection
    if union == 0:
        return 0.0
    return intersection / union

def comp_score (rle_pred, rle_true, progress=True):
    '''Calculate competition score, https://www.kaggle.com/c/sartorius-cell-instance-segmentation/overview/evaluation

    Both arguments are newline-separated "img_id,start length ..." strings, with
    one line per cell instance. For each image id in `rle_true`:
    - predicted and true instances are compared by IoU;
    - at every threshold 0.5, 0.55, ..., 0.95 the score TP / (TP + FP + FN) is
      computed (an IoU must exceed the threshold to count as a match);
    - these scores are averaged.
    The result is the mean of the per-image averages. Images present only in
    `rle_pred` are ignored. `progress` toggles the tqdm progress bar.
    '''
    def _parse(rle):
        # key annotations with their image ids; blank lines (e.g. a trailing newline) are skipped
        rows = []
        for line in rle.split('\n'):
            if not line.strip():
                continue
            parts = line.split(',')
            rows.append([parts[0], ' '.join(parts[1].split())])
        return pd.DataFrame(rows, columns=['id', 'annotation'])

    data_true = _parse(rle_true)
    data_pred = _parse(rle_pred)
    thresholds = np.arange(0.5, 0.99, 0.05)

    img_ids = data_true.id.unique()
    global_precision = []
    for img_id in (tqdm(img_ids) if progress else img_ids):
        preds = list(data_pred.annotation[data_pred.id == img_id])
        trues = list(data_true.annotation[data_true.id == img_id])
        # IoU matrix: rows = predicted instances, columns = true instances.
        # The explicit reshape keeps the 2-D shape for an image without predictions.
        ious = np.array(
            [[rle_intersection_over_union(p, t) for t in trues] for p in preds],
            dtype=float
        ).reshape(len(preds), len(trues))
        precisions = []
        for threshold in thresholds:
            matches = ious > threshold
            true_positives = np.sum(np.any(matches, axis=0))                # true cells with a match
            false_negatives = len(trues) - true_positives                   # true cells left unmatched
            false_positives = len(preds) - np.sum(np.any(matches, axis=1))  # predictions matching nothing
            precisions.append(
                true_positives / (true_positives + false_positives + false_negatives)
            )
        global_precision.append(np.mean(precisions))
    return np.mean(global_precision)

# sanity check: scoring a prediction against itself must give a perfect score
if not submission_mode:
    self_score = comp_score(encoded, encoded)
    print("Should be 1: ", self_score)

In [None]:
# let's see how our astro model performs with respect to the competition score
if not submission_mode:

    # choose the image ids to test on
    val_ids = np.random.choice(astro_val_csv.id.unique(), 1)
    print(val_ids)

    # parse the correct answers into "img_id,annotation" lines
    rows_true = astro_val_csv[astro_val_csv.id.isin(val_ids)]
    encoded_true = '\n'.join(rows_true['id'] + ',' + rows_true['annotation'])

    # the model does not depend on the image, so build it once instead of per image
    model = AstroModel(
        preprocess=False,
        detection_model=astro_model,
        detection_threshold=0.6,
        simulator_force=force_axis,
        simulator_force_radius=25,
        simulator_steps=32,
        clusterer=DBSCAN(eps=dbscan_eps, min_samples=10)
    )

    # calculate predictions
    encoded_parts = []
    for myid in val_ids:
        print('--- Predicting image %s ---' % myid, flush=True)
        img = tf_img_color_rescale(
            tf_img_load(myid)
        ).numpy()
        # decode the target mask for comparison: rows of (row, col, cell index)
        pixels_true = []
        shuffled_annotations = tf.random.shuffle(  # shuffle helps to make adjacent cells different color
            astro_val_csv[astro_val_csv.id == myid].annotation
        )
        for cl, annotation in enumerate(shuffled_annotations):
            mask_here = rle_decode_tf(annotation, shape=img.shape)
            pixels_true_here = tf.where(tf.squeeze(mask_here) > 0)
            pixels_true_here = tf.concat([pixels_true_here, cl*np.ones([pixels_true_here.shape[0],1])], axis=1)
            pixels_true.append(pixels_true_here)
        pixels_true = tf.concat(pixels_true, axis=0)
        # pass the image through our model
        encoded_parts.append(model.predict(
            img,
            img_id=myid,
            mask_pixels=pixels_true
        ))
        print('--- Image %s predicted ---' % myid, flush=True)
    encoded = '\n'.join(encoded_parts)

    score = comp_score(encoded, encoded_true)
    print('Competition score: %.3f' % score)

Hmm.. not so great.. Still, let us obtain a full model first -- with the other types of cells included -- before we proceed to any tuning. This way we can compare with a submission score and see whether we calculate it correctly (also, time is now starting to run out ;), and I can do only so much Kaggle work on weekends).

-----------
**CORT-cell segmentation**

In [None]:
# ids of the images whose perc_cort equals 1.0 (presumably: images containing only cort cells)
cort_ids = img_topics.loc[img_topics.perc_cort == 1.0, 'id'].to_numpy()
print(len(cort_ids))

In [None]:
cort_csv = train_csv[train_csv.id.isin(cort_ids)]
# sample validation ids WITHOUT replacement: rd.choice defaults to replace=True,
# which can draw the same id twice and silently shrink the validation set.
# NOTE: this changes the split, so delete any stale cort_train_data.pkl / cort_*.model caches.
cort_val_ids = rd.choice(cort_ids, int(val_fraction * len(cort_ids)), replace=False)
cort_val_csv = cort_csv[cort_csv.id.isin(cort_val_ids)]
cort_train_csv = cort_csv[~cort_csv.id.isin(cort_val_ids)]
print(len(cort_train_csv), len(cort_val_csv))

In [None]:
# in principle, we could do image loading as part of the model
# , but, since there aren't that many of them,
# it seems more sensible to load them into our RAM memory all at once
shape = (520,704)
save_path = 'cort_train_data.pkl'

def _load_cort_split (csv, mask_shape):
    '''Load every image listed in `csv` together with a binary mask per image.

    Returns (imgs, masks):
    - imgs: a numpy array of colour-rescaled images, in `csv.id.unique()` order;
    - masks: uint8 arrays of shape `mask_shape`, equal to 1 wherever any
      annotation of that image is non-zero and 0 elsewhere.
    '''
    imgs, masks = [], []
    for img_id in tqdm(csv.id.unique()):
        imgs.append(
            tf_img_color_rescale(
                tf_img_load(img_id)
            ).numpy()
        )
        # could probably make this faster with a tensorflow map, but as-is, a dtype error
        # is returned (can't map a string-type tensor to int)...
        mask = np.zeros(mask_shape)
        for ann in csv[csv.id == img_id].annotation:
            mask += rle_decode_tf(ann, shape=mask_shape)
        masks.append(tf.where(mask > 0, 1, 0).numpy().astype(np.uint8))
    return np.array(imgs), np.array(masks)

if task_list['cort_recognitionNN'] or task_list['cort_segmentation']:
    if force_retrain or not os.path.exists(save_path):
        cort_train_imgs, cort_train_masks = _load_cort_split(cort_train_csv, shape)
        cort_val_imgs, cort_val_masks = _load_cort_split(cort_val_csv, shape)
        if not submission_mode:
            with open(save_path, 'wb') as f:
                pkl.dump((cort_train_imgs, cort_train_masks, cort_val_imgs, cort_val_masks), f)
    else:
        # NOTE: pickle.load can execute arbitrary code -- only load caches this notebook wrote itself
        with open(save_path, 'rb') as f:
            cort_train_imgs, cort_train_masks, cort_val_imgs, cort_val_masks = pkl.load(f)
    print(cort_train_imgs.shape, cort_val_imgs.shape)
    print(cort_train_masks.shape, cort_val_masks.shape)

In [None]:
if not submission_mode:
    # one validation mask, to eyeball what the network should learn
    fig, ax = plt.subplots()
    ax.imshow(cort_val_masks[2])
    plt.show()
    plt.close(fig)

In [None]:
if not submission_mode:
    # show the first three cort-only training images
    train_dir = '/kaggle/input/sartorius-cell-instance-segmentation/train/'
    for cort_id in cort_ids[:3]:
        plot_image(train_dir + cort_id + '.png')

 - The cells are very small, circular, and very well separated.
 - Within the images, there are lots of similar and dissimilar features in addition to cort cells.
 
Given these observations, a CNN with a few stacked layers of small filters could prove useful. Let's try.

In [None]:
# build our cort_model: a small fully-convolutional network that keeps the
# 520x704 input resolution and outputs one sigmoid value per pixel
model_name = 'cort_recognitionNN'
cort_model = AugmentedSequential()
cort_model.add(keras.Input(shape=(520,704,1)))

# vanilla CNN layers as (filters, kernel size), all ReLU with 'same' padding.
# This stack reached val_loss: 0.0098; the alternative of three 16-filter 3x3
# layers reached val_loss: 0.0142.
for conv_filters, conv_kernel in [(16, (3,3)), (8, (5,5)), (8, (7,7))]:
    cort_model.add(layers.Conv2D(
        filters=conv_filters, kernel_size=conv_kernel,
        activation='relu',
        padding='same'
    ))

# 1x1 convolution: a dense layer along the filter axis, keeping the image size the same
cort_model.add(layers.Conv2D(filters=8, kernel_size=(1,1), activation='relu', padding='same'))
# OUTPUT: one probability per pixel (a Dropout(0.5) before it was tried and disabled)
cort_model.add(layers.Conv2D(filters=1, kernel_size=(1,1), activation='sigmoid', padding='same'))

# alternatives tried: loss=tf.keras.losses.CategoricalCrossentropy(from_logits=False), metrics=['accuracy']
cort_model.compile(
    optimizer='adam',
    loss=balanced_loss,
    run_eagerly=True,
)

cort_model.summary()

In [None]:
if task_list[model_name]:
    print(model_name)
    if force_retrain or not os.path.exists(model_name+'.model'):
        stop_early = tf.keras.callbacks.EarlyStopping(
            monitor='val_loss',
            patience=early_stopping_patience,
            restore_best_weights=True
        )
        # deliberately not named `callbacks`: that would shadow the `callbacks`
        # module imported from tensorflow.keras at the top of the notebook
        fit_callbacks = [stop_early,]
        history = cort_model.fit(
            tf.cast(cort_train_imgs,tf.float32).numpy(),
            tf.cast(tf.expand_dims(cort_train_masks, axis=3),tf.float32).numpy(),
            validation_data=(
                tf.cast(cort_val_imgs,tf.float32).numpy(),
                tf.cast(tf.expand_dims(cort_val_masks, axis=3),tf.float32).numpy()),
            callbacks=fit_callbacks,
            epochs=epochs_to_train)
        if not submission_mode:
            cort_model.save(model_name + '.model')
            plot_history(history)
    else:
        cort_model = keras.models.load_model(
            model_name+'.model',
            custom_objects={'balanced_loss':balanced_loss}
        )

In [None]:
# make sure augmentation works, check the predictions
def test_predictions (model, val_imgs=None, val_masks=None, n_examples=3, example_idx=4):
    '''Plot augmented validation examples next to the predicted and the true mask.

    Parameters
    ----------
    model : keras model
        Network called on a batch containing one single-channel image.
    val_imgs, val_masks : arrays, optional
        Validation images and masks to augment. They default to `cort_val_imgs` /
        `cort_val_masks`, which are looked up at call time, so the function can be
        defined before the data is loaded.
    n_examples : int
        Number of independently augmented batches to draw and plot.
    example_idx : int
        Index of the image shown from each augmented batch.
    '''
    if val_imgs is None:
        val_imgs = cort_val_imgs
    if val_masks is None:
        val_masks = cort_val_masks
    for _ in range(n_examples):
        X_val, y_val = tf_img_augment_nomodel(
            val_imgs,
            tf.expand_dims(val_masks, axis=3).numpy()
        )

        ex_img = X_val[example_idx].numpy()
        ex_truemask = y_val[example_idx].numpy()

        # show an example prediction
        plt.figure(figsize=(16,6), clear=True)

        plt.subplot(131)
        plt.imshow(ex_img, cmap='seismic')
        plt.title('Original image')

        plt.subplot(132)
        # add the batch axis and a trailing channel axis: (1, H, W, 1)
        ex_mask = model.predict(ex_img.reshape((1,) + ex_img.shape[:2] + (1,)))[0]
        print(ex_mask.shape)
        plt.imshow(ex_mask, cmap="YlGn_r")
        plt.title('Predicted mask')
        plt.colorbar()

        plt.subplot(133)
        plt.imshow(ex_truemask, cmap="YlGn_r")
        plt.title("True mask")

        plt.show()
        plt.close()

if task_list[model_name] and not submission_mode:
    test_predictions(cort_model)

In [None]:
# First, let's improve the contrast of the network output, to make it easier to segment
if not submission_mode:
    img = cort_train_imgs[5]
    mask = cort_train_masks[5]
    pred = cort_model.predict(np.array([img,]))[0]

    # (title, image, colormap) for the four side-by-side panels
    panels = [
        ('Original image', img, 'seismic'),
        ('True mask', mask, "YlGn_r"),
        ('Prediction', pred, "YlGn_r"),
        ('Threshold', pred > 0.2, "YlGn_r"),
    ]
    plt.figure(figsize=(20,6), clear=True)
    for panel_no, (panel_title, panel_img, panel_cmap) in enumerate(panels, start=1):
        plt.subplot(1, 4, panel_no)
        plt.imshow(panel_img, cmap=panel_cmap)
        plt.title(panel_title)
    plt.show()
    plt.close()

    # distribution of the per-pixel network output
    plt.hist(pred.flatten())
    plt.title('Prediction histogram')
    plt.show()
    plt.close()

This should be very easy to split into cells, so the particle simulator is probably unnecessary here.

In [None]:
# Let's plot the true segmentation, so that we know what we're looking for...
if not submission_mode:
    from IPython.display import display

    myid = cort_train_csv.id.unique()[5]
    # a bare expression inside an `if` block is not rendered by Jupyter, so display it explicitly
    display(cort_train_csv[cort_train_csv.id == myid].head())

    # ground-truth pixels as rows of (row, col, cell index)
    pixels_true = []
    for cl, annotation in enumerate(tqdm(tf.random.shuffle(  # shuffle helps to make adjacent cells different color
        cort_train_csv[cort_train_csv.id == myid].annotation
    ))):
        mask_here = rle_decode_tf(annotation, shape=img.shape)
        pixels_true_here = tf.where(tf.squeeze(mask_here) > 0)
        pixels_true_here = tf.concat([pixels_true_here, cl*np.ones([pixels_true_here.shape[0],1])], axis=1)
        pixels_true.append(pixels_true_here)
    pixels_true = tf.concat(pixels_true, axis=0)

    x = tf.transpose(pixels_true)[0]
    y = tf.transpose(pixels_true)[1]
    c = tf.transpose(pixels_true)[2]

    plt.scatter(x, y, c=c, s=0.125, cmap='prism')
    plt.show()
    plt.close()

In [None]:
# Now let's try classical clustering algorithms on this point cloud
from sklearn.cluster import DBSCAN

if not submission_mode:
    # coordinates of all mask pixels, one row per pixel
    pixels = tf.where(tf.squeeze(mask) > 0).numpy()
    pixels_t = tf.transpose(pixels)
    x = pixels_t[0]
    y = pixels_t[1]

    # eps=1, min_samples=1: every pixel is a core point, so clusters are the
    # groups of pixels linked by steps of distance <= 1
    clusterer = DBSCAN(eps=1, min_samples=1)
    cluster_labels = clusterer.fit(pixels).labels_
    print(len(np.unique(cluster_labels)))

    plt.scatter(x, y, c=cluster_labels, s=0.125, cmap='prism')
    plt.show()
    plt.close()

Looks good! This completes our model for cort-type cells.

In [None]:
class CortModel:
    '''Cort-cell instance segmentation pipeline.

    `predict` runs these steps:
    - optional colour preprocessing,
    - pixel detection with `detection_model`,
    - clustering of the detected pixel coordinates,
    - conversion to competition RLE lines.
    No particle simulation is used for this cell type.
    '''

    def __init__ (
        self,
        preprocess=False,
        detection_model=cort_model,
        detection_threshold=0.2,
        clusterer=None
    ):
        '''
        Parameters
        ----------
        preprocess : bool
            Apply `img_color_rescale` to the image before detection.
        detection_model : keras model
            Network whose output is thresholded into cell pixels.
        detection_threshold : float
            Pixels with a predicted value above this count as cell pixels.
        clusterer : object with `fit` and `labels_`, optional
            Defaults to a new `DBSCAN(eps=1, min_samples=1)` per instance. The old
            default was a single DBSCAN object shared by every instance.
        '''
        self.preprocess = preprocess
        self.detection_model = detection_model
        self.detection_threshold = detection_threshold
        self.clusterer = DBSCAN(eps=1, min_samples=1) if clusterer is None else clusterer

    def predict (self, img, img_id='noid', mask_pixels=None, verbose=True):
        '''Segment one image and return the cells as competition RLE lines.

        Parameters
        ----------
        img : array
            Image fed to `detection_model` as a batch of one.
        img_id : str
            Id written at the start of every RLE line.
        mask_pixels : tensor/array, optional
            Ground-truth pixels as rows of (row, col, cell index). Only used for
            the comparison plot when `verbose` is set.
        verbose : bool
            Print progress and show diagnostic plots.

        Returns
        -------
        str
            Newline-separated "img_id,start length ..." lines, one per label.
        '''
        # image preprocessing
        if verbose:
            print(' - predicting an cort-cell image segmentation:')
            print('   - image preprocessing.. ', end='', flush=True)
            plt.figure(figsize=(12,6))
            plt.subplot(121)
            plt.imshow(1.*img, cmap="seismic")
        if self.preprocess:
            img = img_color_rescale(img)
        if verbose:
            plt.subplot(122)
            plt.imshow(img, cmap="seismic")
            plt.show()
            plt.close()
            print('done.', flush=True)

        # use the NN to detect cell structures
        if verbose:
            print('   - detecting cell pixels with a NN.. ', end='', flush=True)
        pred = self.detection_model.predict(np.array([img,]))[0]
        pixels_true = tf.where(pred > self.detection_threshold)
        # x: first (row) index, y: second (column) index of every detected pixel
        x = tf.transpose(pixels_true)[0]
        y = tf.transpose(pixels_true)[1]
        if verbose:
            print('done.', flush=True)
            plt.figure(figsize=(12,6))
            plt.subplot(121)
            plt.imshow(pred)
            plt.colorbar()
            plt.subplot(122)
            plt.scatter(x,y, s=0.125)
            plt.show()
            plt.close()

        # cluster the detected pixel coordinates
        if verbose:
            print('   - clustering.. ', end='', flush=True)
        features = pixels_true[:,:2]
        self.clusterer.fit(features)
        labels = np.array(limit_labels(self.clusterer.labels_))
        if verbose:
            if mask_pixels is not None and len(mask_pixels) > 0:
                plt.figure(figsize=(12,6))
                plt.subplot(121)
                mask_t = tf.transpose(mask_pixels)
                xm, ym, cmask = mask_t[0], mask_t[1], mask_t[2]
                plt.scatter(xm,ym,c=cmask, s=0.125, cmap='prism')
                plt.subplot(122)
            else:
                plt.figure(figsize=(6,6))
            plt.scatter(x, y, c=labels, s=0.125, cmap='prism')
            plt.show()
            plt.close()
            print('done.', flush=True)

        # convert to the competition pixel number format
        if verbose:
            print('   - converting to RLE string.. ', end='', flush=True)
        # NaN background: with a zero background, the cluster labelled 0 (DBSCAN's
        # first cluster) would absorb every undetected pixel of the image
        mask_pred = np.squeeze(np.full(img.shape, np.nan))
        mask_pred[tuple(np.transpose(pixels_true[:,:2]))] = labels
        rle_encoded = rle_encode(mask_pred, classes=np.unique(labels), img_id=img_id)
        if verbose:
            print('done.', flush=True)
        return rle_encoded

In [None]:
if not submission_mode:
    # smoke test of the full cort pipeline on one training image
    img, mask = cort_train_imgs[5], cort_train_masks[5]
    model = CortModel()
    encoded = model.predict(img)
    print(encoded[:50])

In [None]:
# let's see how our cort model performs with respect to the competition score
if not submission_mode:

    # choose the image ids to test on
    val_ids = np.random.choice(cort_val_csv.id.unique(), 1)
    print(val_ids)

    # parse the correct answers into "img_id,annotation" lines
    rows_true = cort_val_csv[cort_val_csv.id.isin(val_ids)]
    encoded_true = '\n'.join(rows_true['id'] + ',' + rows_true['annotation'])

    # the model does not depend on the image, so build it once instead of per image
    model = CortModel(
        preprocess=False,
        detection_model=cort_model,
        detection_threshold=0.2,
        clusterer=DBSCAN(eps=1, min_samples=1)
    )

    # calculate predictions
    encoded_parts = []
    for myid in val_ids:
        print('--- Predicting image %s ---' % myid, flush=True)
        img = tf_img_color_rescale(
            tf_img_load(myid)
        ).numpy()
        # decode the target mask for comparison: rows of (row, col, cell index)
        pixels_true = []
        shuffled_annotations = tf.random.shuffle(  # shuffle helps to make adjacent cells different color
            cort_val_csv[cort_val_csv.id == myid].annotation
        )
        for cl, annotation in enumerate(shuffled_annotations):
            mask_here = rle_decode_tf(annotation, shape=img.shape)
            pixels_true_here = tf.where(tf.squeeze(mask_here) > 0)
            pixels_true_here = tf.concat([pixels_true_here, cl*np.ones([pixels_true_here.shape[0],1])], axis=1)
            pixels_true.append(pixels_true_here)
        pixels_true = tf.concat(pixels_true, axis=0)
        # pass the image through our model
        encoded_parts.append(model.predict(
            img,
            img_id=myid,
            mask_pixels=pixels_true
        ))
        print('--- Image %s predicted ---' % myid, flush=True)
    encoded = '\n'.join(encoded_parts)

    score = comp_score(encoded, encoded_true)
    print('Competition score: %.3f' % score)

---------
**shsy5y-cell segmentation**

In [None]:
# ids of the images whose perc_shsy5y equals 1.0 (presumably: images containing only shsy5y cells)
shsy5y_ids = img_topics.loc[img_topics.perc_shsy5y == 1.0, 'id'].to_numpy()
print(len(shsy5y_ids))

In [None]:
shsy5y_csv = train_csv[train_csv.id.isin(shsy5y_ids)]
# sample validation ids WITHOUT replacement: rd.choice defaults to replace=True,
# which can draw the same id twice and silently shrink the validation set.
# NOTE: this changes the split, so delete any stale shsy5y_train_data.pkl / shsy5y_*.model caches.
shsy5y_val_ids = rd.choice(shsy5y_ids, int(val_fraction * len(shsy5y_ids)), replace=False)
shsy5y_val_csv = shsy5y_csv[shsy5y_csv.id.isin(shsy5y_val_ids)]
shsy5y_train_csv = shsy5y_csv[~shsy5y_csv.id.isin(shsy5y_val_ids)]
print(len(shsy5y_train_csv), len(shsy5y_val_csv))

In [None]:
# in principle, we could do image loading as part of the model
# , but, since there aren't that many of them,
# it seems more sensible to load them into our RAM memory all at once
shape = (520,704)
save_path = 'shsy5y_train_data.pkl'

def _load_shsy5y_split (csv, mask_shape):
    '''Load every image listed in `csv` together with a binary mask per image.

    Returns (imgs, masks):
    - imgs: a numpy array of colour-rescaled images, in `csv.id.unique()` order;
    - masks: uint8 arrays of shape `mask_shape`, equal to 1 wherever any
      annotation of that image is non-zero and 0 elsewhere.
    '''
    imgs, masks = [], []
    for img_id in tqdm(csv.id.unique()):
        imgs.append(
            tf_img_color_rescale(
                tf_img_load(img_id)
            ).numpy()
        )
        # could probably make this faster with a tensorflow map, but as-is, a dtype error
        # is returned (can't map a string-type tensor to int)...
        mask = np.zeros(mask_shape)
        for ann in csv[csv.id == img_id].annotation:
            mask += rle_decode_tf(ann, shape=mask_shape)
        masks.append(tf.where(mask > 0, 1, 0).numpy().astype(np.uint8))
    return np.array(imgs), np.array(masks)

if task_list['shsy5y_recognitionNN'] or task_list['shsy5y_segmentation']:
    if force_retrain or not os.path.exists(save_path):
        shsy5y_train_imgs, shsy5y_train_masks = _load_shsy5y_split(shsy5y_train_csv, shape)
        shsy5y_val_imgs, shsy5y_val_masks = _load_shsy5y_split(shsy5y_val_csv, shape)
        # like the cort data cache, only write the cache outside submission mode
        if not submission_mode:
            with open(save_path, 'wb') as f:
                pkl.dump((shsy5y_train_imgs, shsy5y_train_masks, shsy5y_val_imgs, shsy5y_val_masks), f)
    else:
        # NOTE: pickle.load can execute arbitrary code -- only load caches this notebook wrote itself
        with open(save_path, 'rb') as f:
            shsy5y_train_imgs, shsy5y_train_masks, shsy5y_val_imgs, shsy5y_val_masks = pkl.load(f)
    print(shsy5y_train_imgs.shape, shsy5y_val_imgs.shape)
    print(shsy5y_train_masks.shape, shsy5y_val_masks.shape)

In [None]:
if not submission_mode:
    # one validation mask, to eyeball what the network should learn
    fig, ax = plt.subplots()
    ax.imshow(shsy5y_val_masks[2])
    plt.show()
    plt.close(fig)

In [None]:
if not submission_mode:
    # show the first three shsy5y-only training images
    train_dir = '/kaggle/input/sartorius-cell-instance-segmentation/train/'
    for shsy5y_id in shsy5y_ids[:3]:
        plot_image(train_dir + shsy5y_id + '.png')

The previous approach worked well for both cort and astro cells, so it might fit here as well.

In [None]:
# build our shsy5y_model: a small fully-convolutional network that keeps the
# 520x704 input resolution and outputs one sigmoid value per pixel
model_name = 'shsy5y_recognitionNN'
shsy5y_model = AugmentedSequential()
shsy5y_model.add(keras.Input(shape=(520,704,1)))

# vanilla CNN layers as (filters, kernel size), all ReLU with 'same' padding
# (val_loss for this configuration: not recorded)
for conv_filters, conv_kernel in [(16, (3,3)), (8, (5,5)), (8, (7,7))]:
    shsy5y_model.add(layers.Conv2D(
        filters=conv_filters, kernel_size=conv_kernel,
        activation='relu',
        padding='same'
    ))

# 1x1 convolution: a dense layer along the filter axis, keeping the image size the same
shsy5y_model.add(layers.Conv2D(filters=8, kernel_size=(1,1), activation='relu', padding='same'))
# OUTPUT: one probability per pixel (a Dropout(0.5) before it was tried and disabled)
shsy5y_model.add(layers.Conv2D(filters=1, kernel_size=(1,1), activation='sigmoid', padding='same'))

# alternatives tried: loss=tf.keras.losses.CategoricalCrossentropy(from_logits=False), metrics=['accuracy']
shsy5y_model.compile(
    optimizer='adam',
    loss=balanced_loss,
    run_eagerly=True,
)

shsy5y_model.summary()

In [None]:
if task_list[model_name]:
    print(model_name)
    if force_retrain or not os.path.exists(model_name+'.model'):
        stop_early = tf.keras.callbacks.EarlyStopping(
            monitor='val_loss',
            patience=early_stopping_patience,
            restore_best_weights=True
        )
        # deliberately not named `callbacks`: that would shadow the `callbacks`
        # module imported from tensorflow.keras at the top of the notebook
        fit_callbacks = [stop_early,]
        history = shsy5y_model.fit(
            tf.cast(shsy5y_train_imgs,tf.float32).numpy(),
            tf.cast(tf.expand_dims(shsy5y_train_masks, axis=3),tf.float32).numpy(),
            validation_data=(
                tf.cast(shsy5y_val_imgs,tf.float32).numpy(),
                tf.cast(tf.expand_dims(shsy5y_val_masks, axis=3),tf.float32).numpy()),
            callbacks=fit_callbacks,
            epochs=epochs_to_train)
        if not submission_mode:
            shsy5y_model.save(model_name + '.model')
            plot_history(history)
    else:
        shsy5y_model = keras.models.load_model(
            model_name+'.model',
            custom_objects={'balanced_loss':balanced_loss}
        )

In [None]:
# make sure augmentation works, check the predictions
def test_predictions (model, val_imgs=None, val_masks=None, n_examples=3, example_idx=4):
    '''Plot augmented validation examples next to the predicted and the true mask.

    Parameters
    ----------
    model : keras model
        Network called on a batch containing one single-channel image.
    val_imgs, val_masks : arrays, optional
        Validation images and masks to augment. They default to `shsy5y_val_imgs` /
        `shsy5y_val_masks`, which are looked up at call time, so the function can be
        defined before the data is loaded.
    n_examples : int
        Number of independently augmented batches to draw and plot.
    example_idx : int
        Index of the image shown from each augmented batch.
    '''
    if val_imgs is None:
        val_imgs = shsy5y_val_imgs
    if val_masks is None:
        val_masks = shsy5y_val_masks
    for _ in range(n_examples):
        X_val, y_val = tf_img_augment_nomodel(
            val_imgs,
            tf.expand_dims(val_masks, axis=3).numpy()
        )

        ex_img = X_val[example_idx].numpy()
        ex_truemask = y_val[example_idx].numpy()

        # show an example prediction
        plt.figure(figsize=(16,6), clear=True)

        plt.subplot(131)
        plt.imshow(ex_img, cmap='seismic')
        plt.title('Original image')

        plt.subplot(132)
        # add the batch axis and a trailing channel axis: (1, H, W, 1)
        ex_mask = model.predict(ex_img.reshape((1,) + ex_img.shape[:2] + (1,)))[0]
        print(ex_mask.shape)
        plt.imshow(ex_mask, cmap="YlGn_r")
        plt.title('Predicted mask')
        plt.colorbar()

        plt.subplot(133)
        plt.imshow(ex_truemask, cmap="YlGn_r")
        plt.title('True mask')

        plt.show()
        plt.close()

if task_list[model_name] and not submission_mode:
    test_predictions(shsy5y_model)

In [None]:
# First, let's improve the contrast of the network output, to make it easier to segment
if not submission_mode:
    img = shsy5y_train_imgs[5]
    mask = shsy5y_train_masks[5]
    pred = shsy5y_model.predict(np.array([img,]))[0]

    # (title, image, colormap) for the four side-by-side panels
    panels = [
        ('Original image', img, 'seismic'),
        ('True mask', mask, "YlGn_r"),
        ('Prediction', pred, "YlGn_r"),
        ('Threshold', pred > 0.6, "YlGn_r"),
    ]
    plt.figure(figsize=(20,6), clear=True)
    for panel_no, (panel_title, panel_img, panel_cmap) in enumerate(panels, start=1):
        plt.subplot(1, 4, panel_no)
        plt.imshow(panel_img, cmap=panel_cmap)
        plt.title(panel_title)
    plt.show()
    plt.close()

    # distribution of the per-pixel network output
    plt.hist(pred.flatten())
    plt.title('Prediction value histogram')
    plt.show()
    plt.close()

In [None]:
# Let's plot the true segmentation, so that we know what we're looking for...
if not submission_mode:
    from IPython.display import display

    myid = shsy5y_train_csv.id.unique()[5]
    # a bare expression inside an `if` block is not rendered by Jupyter, so display it explicitly
    display(shsy5y_train_csv[shsy5y_train_csv.id == myid].head())

    # ground-truth pixels as rows of (row, col, cell index)
    pixels_true = []
    for cl, annotation in enumerate(tqdm(tf.random.shuffle(  # shuffle helps to make adjacent cells different color
        shsy5y_train_csv[shsy5y_train_csv.id == myid].annotation
    ))):
        mask_here = rle_decode_tf(annotation, shape=img.shape)
        pixels_true_here = tf.where(tf.squeeze(mask_here) > 0)
        pixels_true_here = tf.concat([pixels_true_here, cl*np.ones([pixels_true_here.shape[0],1])], axis=1)
        pixels_true.append(pixels_true_here)
    pixels_true = tf.concat(pixels_true, axis=0)

    x = tf.transpose(pixels_true)[0]
    y = tf.transpose(pixels_true)[1]
    c = tf.transpose(pixels_true)[2]

    plt.scatter(x, y, c=c, s=0.125, cmap='prism')
    plt.show()
    plt.close()

Hmm, this might benefit from the particle simulator, so the shsy5y model will be an adjusted version of the astro model.

In [None]:
# trying out some different force prescriptions here..

def force_RA (sparse_distance, sparse_sqr_distance, force_radius, norm=0.5, sqr_scale1=1.5, sqr_scale3=2.):
    '''Experimental pair force for the particle simulator, with a per-pair magnitude cap.

    Parameters
    ----------
    sparse_distance : tf.SparseTensor
        Pairwise displacement components; its values are reshaped to (n_pairs, 2).
    sparse_sqr_distance : tf.SparseTensor
        Pairwise squared distances; indices[:, 0] selects the particle each
        pair's neighbour count is taken from.
    force_radius : float
        Not used by this prescription. NOTE(review): presumably accepted to match
        the simulator's force signature -- confirm.
    norm : float
        Strength factor.
    sqr_scale1 : float
        Squared distance at which the force changes sign.
    sqr_scale3 : float
        Scale of the exp(-sqr_distance / sqr_scale3) damping.

    Returns
    -------
    tf.Tensor
        Flattened force components, in the same order as `sparse_distance.values`.
    '''
    distance = tf.reshape(sparse_distance.values, [-1,2])
    sqr_distance = sparse_sqr_distance.values
    part0_indices = sparse_sqr_distance.indices[:,0]
    # calculate the number of neighbours of each particle, then gather it onto every pair
    n_neighbours = tf.SparseTensor(
        indices = sparse_sqr_distance.indices,
        values = tf.ones(tf.shape(sparse_sqr_distance.values), dtype=tf.float16),
        dense_shape = sparse_sqr_distance.dense_shape
    )
    n_neighbours = tf.sparse.reduce_sum(n_neighbours, axis=1)
    n_neighbours = tf.gather_nd(n_neighbours, tf.expand_dims(part0_indices,-1))
    # the actual force prescription
    force_proposed = (
        - norm * (sqr_scale1-sqr_distance) * tf.exp( - sqr_distance/sqr_scale3)
    )
    # NOTE(review): `norm` is applied a second time here (effective strength norm**2) -- confirm intended
    force_proposed = distance * tf.expand_dims(norm*force_proposed, axis=-1)
    # apply limits: the force magnitude is capped at the pair distance divided by the neighbour count
    force_max = tf.sqrt(sqr_distance) / n_neighbours
    # magnitude (not squared magnitude) of each force vector, so that the comparison
    # with force_max and the unit direction below are dimensionally consistent
    force_value = tf.sqrt(tf.reduce_sum(force_proposed**2, axis=1))
    # divide_no_nan: a zero-length force gets a zero direction instead of NaN
    force_direction = tf.math.divide_no_nan(force_proposed, tf.expand_dims(force_value, axis=-1))
    force_proposed = tf.where(
        tf.expand_dims(force_value < force_max, -1),
        force_proposed,
        norm * force_direction * tf.expand_dims(force_max,-1)
    )
    # reshape to sparse representation
    force_proposed = tf.reshape(force_proposed, [-1])
    return force_proposed

In [None]:
class Shsy5yModel (AstroModel):
    """AstroModel configured with SH-SY5Y defaults.

    The defaults are the shsy5y detection network, the ``force_RA`` simulator force
    and a ``DBSCAN(eps=6, min_samples=10)`` clusterer. All parameters are forwarded
    unchanged to ``AstroModel``.
    """

    def __init__ (
        self,
        preprocess=False,
        detection_model=shsy5y_model,
        detection_threshold=0.6,
        simulator_force=force_RA,
        simulator_force_radius=10,
        simulator_force_kwargs=None,
        simulator_steps=simulation_epochs,
        simulator_dt=simulation_dt,
        clusterer=None
    ):
        # fresh objects per instance: defaults written as `{}` / `DBSCAN(...)` in the signature
        # would be created once and shared by every Shsy5yModel
        if simulator_force_kwargs is None:
            simulator_force_kwargs = {}
        if clusterer is None:
            clusterer = DBSCAN(eps=6, min_samples=10)
        super().__init__(
            preprocess=preprocess,
            detection_model=detection_model,
            detection_threshold=detection_threshold,
            simulator_force=simulator_force,
            simulator_force_radius=simulator_force_radius,
            simulator_force_kwargs=simulator_force_kwargs,
            simulator_steps=simulator_steps,
            simulator_dt=simulator_dt,
            clusterer=clusterer
        )

In [None]:
# Smoke test: run the full SH-SY5Y pipeline (Shsy5yModel with default arguments) on one
# training image and print the first 50 characters of its prediction.
if not submission_mode:
    img = shsy5y_train_imgs[5]
    # NOTE(review): mask is not used in this cell
    mask = shsy5y_train_masks[5]

    model = Shsy5yModel()
    encoded = model.predict(img)
    print(encoded[:50])

In [None]:
# let's see how our shsy5y model performs with respect to the competition score
if not submission_mode:

    # choose the image(s) to validate on
    val_ids = np.random.choice(shsy5y_val_csv.id.unique(), 1)
    print(val_ids)

    # parse the correct answers into "id,annotation" lines, the format passed to comp_score
    rows_true = shsy5y_val_csv[shsy5y_val_csv.id.isin(val_ids)]
    encoded_true = '\n'.join(rows_true['id'] + ',' + rows_true['annotation'])

    # calculate predictions; Shsy5yModel() gets identical (default) arguments for every image,
    # so build it once outside the loop
    model = Shsy5yModel()
    encoded_lines = []
    for myid in val_ids:
        print('--- Predicting image %s ---' % myid, flush=True)
        img = tf_img_color_rescale(
            tf_img_load(myid)
        ).numpy()
        # decode the target mask for comparison: one (row, col, label) entry per annotated pixel
        pixels_true = []
        annotations = tf.random.shuffle(  # shuffle helps to make adjacent cells different color
            shsy5y_val_csv[shsy5y_val_csv.id == myid].annotation
        )
        for cl, annotation in enumerate(annotations):
            mask_here = rle_decode_tf(annotation, shape=img.shape)
            pixels_true_here = tf.where(tf.squeeze(mask_here) > 0)
            pixels_true_here = tf.concat([pixels_true_here, cl*np.ones([pixels_true_here.shape[0],1])], axis=1)
            pixels_true.append(pixels_true_here)
        pixels_true = tf.concat(pixels_true, axis=0)
        # pass the image through our model
        encoded_lines.append(model.predict(
            img,
            img_id=myid,
            mask_pixels=pixels_true
        ))
        print('--- Image %s predicted ---' % myid, flush=True)
    encoded = '\n'.join(encoded_lines)

    score = comp_score(encoded, encoded_true)
    print('Competition score: %.3f' % score)

-----
**Final Model**

In [None]:
def logit_to_celltype (logit):
    """Return the cell-type label with the highest classifier score.

    ``np.argmax`` without an axis works on the flattened array, so both a single
    score vector and a one-image batch of scores are accepted. The winning
    position indexes the notebook-level ``cell_types`` sequence.
    """
    winner = int(np.argmax(logit))
    return cell_types[winner]

class FinalModel:
    """Classify an image's cell type, then delegate segmentation to the matching model.

    Parameters
    ----------
    preprocess : bool
        If True, ``img_color_rescale`` is applied to the image before classification.
    classifier :
        Object with a ``predict`` method returning per-class scores for a batch;
        the argmax indexes ``cell_types`` (see ``logit_to_celltype``).
    astro_kwargs, cort_kwargs, shsy5y_kwargs : dict or None
        Extra keyword arguments for the AstroModel / CortModel / Shsy5yModel segmenters.
    """

    def __init__ (
        self,
        preprocess=False,
        classifier=img_classifier,
        astro_kwargs=None,
        cort_kwargs=None,
        shsy5y_kwargs=None
    ):
        self.preprocess = preprocess
        # honour the argument instead of always using the global img_classifier
        self.classifier = classifier
        # preprocessing is done once here, so the segmenters must not repeat it
        self.segmenters = {
            'astro': AstroModel(preprocess=False, **(astro_kwargs or {})),
            'cort': CortModel(preprocess=False, **(cort_kwargs or {})),
            'shsy5y': Shsy5yModel(preprocess=False, **(shsy5y_kwargs or {}))
        }

    def predict (
        self,
        img, img_id='noid',
        mask_pixels=None,
        verbose=True
    ):
        """Segment one image with the segmenter of its predicted cell type.

        Parameters
        ----------
        img :
            Input image; a batch axis is added before it is passed to the classifier.
        img_id : str
            Identifier forwarded to the segmenter and used in log messages.
        mask_pixels :
            Optional ground-truth pixels forwarded to the segmenter (the notebook passes
            decoded target masks here for comparison); defaults to an empty list.
        verbose : bool
            Print progress and plot the image before/after preprocessing.

        Returns
        -------
        Whatever the selected segmenter's ``predict`` returns (an encoded prediction
        string in this notebook's usage).
        """
        if mask_pixels is None:
            mask_pixels = []
        if verbose:
            print("Processing image %s." % img_id, flush=True)
            print(' - image preprocessing.. ', end='', flush=True)
            plt.figure(figsize=(12,6))
            plt.subplot(121)
            plt.imshow(1.*img, cmap="seismic")
            plt.title('Input image')
        # preprocess if needed
        if self.preprocess:
            img = img_color_rescale(img)
        if verbose:
            plt.subplot(122)
            plt.imshow(img, cmap="seismic")
            plt.title('After preprocessing')
            plt.show()
            plt.close()
            print('done.', flush=True)
            print(' - image classification.. ', flush=True)
        # classify the image (the classifier expects a batch, hence the extra leading axis)
        img_class = logit_to_celltype(
            self.classifier.predict(tf.expand_dims(img, axis=0))
        )
        if verbose:
            print('    - classified as ', img_class, ', done.')
        # apply the appropriate type-specific segmentation model
        return self.segmenters[img_class].predict(
            img, img_id=img_id,
            mask_pixels=mask_pixels,
            verbose=verbose
        )

In [None]:
# Test the full pipeline on one training image of each cell type
if not submission_mode:
    # FinalModel gets the same arguments for every image, so build it once
    model = FinalModel(preprocess=True)
    for celltype in cell_types:
        print('Testing cell type =', celltype, flush=True)
        # image id of the 8th annotation row of this cell type
        img_id = train_csv[train_csv.cell_type == celltype].id.iloc[7]
        # load the image
        img = tf_img_load(img_id, directory='train')
        # parse the correct answers into "id,annotation" lines, the format passed to comp_score
        rows_true = train_csv[train_csv.id == img_id]
        encoded_true = '\n'.join(rows_true['id'] + ',' + rows_true['annotation'])
        # decode the target mask for comparison: one (row, col, label) entry per annotated pixel
        pixels_true = []
        annotations = tf.random.shuffle(  # shuffle helps to make adjacent cells different color
            rows_true.annotation
        )
        for cl, annotation in enumerate(annotations):
            mask_here = rle_decode_tf(annotation, shape=img.shape)
            pixels_true_here = tf.where(tf.squeeze(mask_here) > 0)
            pixels_true_here = tf.concat([
                pixels_true_here,
                cl*np.ones([pixels_true_here.shape[0],1])
            ], axis=1)
            pixels_true.append(pixels_true_here)
        pixels_true = tf.concat(pixels_true, axis=0)
        # run everything through the model
        encoded = model.predict(
            img, img_id,
            mask_pixels=pixels_true
        )
        print(encoded[:50])
        # calculate competition score
        score = comp_score(encoded, encoded_true)
        print('Competition score: %.3f' % score)

In [None]:
# process the submission: segment every test image and write submission.csv

# test image ids are the file names without the .png extension; sorted for a stable row order
test_paths = sorted(glob.glob('/kaggle/input/sartorius-cell-instance-segmentation/test/*.png'))
test_ids = [os.path.splitext(os.path.basename(path))[0] for path in test_paths]
print(len(test_ids), test_ids)

# run the model; FinalModel gets the same arguments for every image, so build it once
model = FinalModel(preprocess=True)
encoded_rows = []
for img_id in test_ids:
    # load the image
    img = tf_img_load(img_id, directory='test')
    encoded = model.predict(
        img, img_id, verbose=(not submission_mode)
    )
    print(encoded[:50])
    encoded_rows.append(encoded)

# the header must be exactly "id,predicted": a space after the comma would yield a " predicted" column
submission_text = 'id,predicted\n' + '\n'.join(encoded_rows)
# parse as csv and view to ensure correct format; dtype=str keeps ids verbatim
# (an id made only of digits/exponent characters would otherwise be parsed as a number)
submission = pd.read_csv(StringIO(submission_text), dtype=str)
print(submission.head())
# save the submission file
submission.to_csv('submission.csv', index=False)