# Split adjoining cells into subsets of non-touching ones

This is the problem of graph coloring:
"it is a way of coloring the vertices of a graph such that no two adjacent vertices are of the same color [1]"

The code is adapted from the n-color algorithm in [2] and from Cellpose [3].

If the cells do not touch each other, instance segmentation reduces to semantic segmentation (each connected foreground region is one cell) and can be solved with a UNet, etc. Hence a solution to the "Sartorius Cell Instance Segmentation" problem is:
- detect candidate cells (e.g. the centroid of a bounding box or proposal, seeds for watershed, etc.)

  image --> [instance detection]  --> seed
  
- split the instances into subsets in which no two cells touch

  seed --> [this code]  --> group1, group2 ...
  
- perform semantic segmentation on each subset, then use connected-component labelling (CCL, e.g. the skimage.measure.label function) to recover the instance segmentation

  group-n + image --> [semantic segmentation with marker]  --> segmentation-n  --> [CCL] --> instance segmentation


As a final note:
1. If the cells are not densely clustered, randomly splitting them into N subsets can also ensure that no two cells within a subset touch, provided N is large enough. This may be useful in training.

2. It is also possible to train a network to do the subset splitting.

3. This framework can be used as an add-on refinement stage for other methods such as Mask R-CNN. When you already have a good estimate of the instance segmentation and would like to refine the masks at full resolution, you can split the Mask R-CNN predictions into non-touching subsets and perform semantic segmentation with markers.

[1] https://en.wikipedia.org/wiki/Graph_coloring

[2] https://forum.image.sc/t/relabel-with-4-colors-like-map/33564

[3] https://github.com/mouseland/cellpose


In [None]:
from collections import deque

import numpy as np
import cv2
import pandas as pd
import matplotlib.pyplot as plt

from numba import njit
from scipy.ndimage import generate_binary_structure
from skimage.segmentation import expand_labels

In [None]:

def make_neighbour_idx(w, h, connectivity=2):
    """Return flat-index offsets of a pixel's neighbours in a row-major image.

    For an image of width ``w`` flattened with ``ravel()``, the neighbours of
    the pixel at flat index ``i`` are at ``i + offset`` for every returned
    offset. The pixel itself is excluded.

    Parameters
    ----------
    w : int
        Image width (number of columns), i.e. the row stride of the flat array.
    h : int
        Image height. Unused; kept for backward compatibility.
    connectivity : int, optional
        Passed to ``scipy.ndimage.generate_binary_structure``: 2 (default)
        selects the 8-neighbourhood, 1 the 4-neighbourhood.

    Returns
    -------
    numpy.ndarray
        1-D int64 array of offsets, in row-major order of the 3x3 footprint.
    """
    footprint = generate_binary_structure(rank=2, connectivity=connectivity)
    footprint[1, 1] = False  # exclude the centre pixel itself

    # (row, col) displacement of every neighbour relative to the centre (1, 1)
    rows, cols = np.nonzero(footprint)
    d_row = rows.astype(np.int64) - 1
    d_col = cols.astype(np.int64) - 1

    # In a row-major flat array, displacement (dr, dc) is offset dr * w + dc.
    return d_row * w + d_col

@njit(fastmath=True)
def search_for_neighbour(image, idx):
    """Collect every ordered pair of different, non-zero labels that touch.

    ``image`` is scanned as a flat array. For each non-zero pixel ``i`` and
    each offset ``d`` in ``idx``, the pair ``(image[i], image[i + d])`` is
    recorded when the neighbour is non-zero and has a different label.
    Pairs are not de-duplicated.

    NOTE: ``i + d`` is only evaluated for non-zero pixels and is not range
    checked, so the caller must surround the image with a zero border at least
    as wide as the largest neighbour displacement (``find_connect`` pads by 1
    pixel for the 3x3 neighbourhood).

    Returns
    -------
    numpy.ndarray
        ``(n, 2)`` array of label pairs with the same dtype as ``image``.
    """
    line = image.ravel()
    L = len(line)

    # First pass: count the pairs so the output can be allocated exactly.
    # A fixed (L, 2) buffer is not enough, because one pixel can contribute up
    # to len(idx) pairs. Numba does not bounds-check writes by default, so an
    # overflow would silently corrupt memory.
    s = 0
    for i in range(L):
        a = line[i]
        if a == 0:
            continue
        for d in idx:
            b = line[i + d]
            if b != 0 and b != a:
                s += 1

    # Second pass: fill the pairs in the same scan order.
    neighbour = np.zeros((s, 2), image.dtype)
    s = 0
    for i in range(L):
        a = line[i]
        if a == 0:
            continue
        for d in idx:
            b = line[i + d]
            if b != 0 and b != a:
                neighbour[s, 0] = a
                neighbour[s, 1] = b
                s += 1
    return neighbour


def find_connect(image):
    """Return the unique pairs of labels that touch each other in ``image``.

    Parameters
    ----------
    image : numpy.ndarray
        2-D integer label image; 0 is background and is never paired.

    Returns
    -------
    numpy.ndarray
        ``(n, 2)`` array of label pairs ``(a, b)`` with ``a < b``. Rows are
        sorted lexicographically and each pair appears once. Empty when no
        labels touch.
    """
    # Zero border so every non-zero pixel has all neighbours in bounds.
    pad = np.pad(image, 1, 'constant')
    h, w = pad.shape

    idx = make_neighbour_idx(w, h)
    neighbour = search_for_neighbour(pad, idx)

    if len(neighbour) == 0:
        return neighbour

    # Order each pair as (min, max) so (a, b) and (b, a) collapse together,
    # then de-duplicate whole rows. This avoids packing a pair into one
    # integer key: the previous ``(a << 16) + b`` overflowed for uint8/int16
    # label dtypes and collided for labels >= 2**16.
    neighbour.sort(axis=1)
    return np.unique(neighbour, axis=0)


def make_graph(connect):
    """Build an undirected adjacency list from an ``(n, 2)`` array of edges.

    Every value occurring in ``connect`` becomes a key (in ascending order).
    Each edge ``(a, b)`` appends ``b`` to ``a``'s list and ``a`` to ``b``'s list.
    """
    adjacency = {node: [] for node in np.unique(connect)}
    for a, b in connect:
        adjacency[a].append(b)
        adjacency[b].append(a)
    return adjacency


def assign_color(graph, num_color=4, rand=12, depth=0, max_depth=8):
    """Colour the nodes of ``graph`` so that adjacent nodes differ in colour.

    Nodes are visited from a work queue in shuffled order. A node takes a
    colour that none of its neighbours uses. If none is free, it moves to the
    least-used neighbour colour other than its current one. After ``rand``
    consecutive visits without a free colour, it takes a random colour instead,
    to escape cycles. Neighbours it now clashes with are queued again. If the
    queue has not drained after ``threshold`` iterations, the colouring is
    restarted with one more colour.

    Parameters
    ----------
    graph : dict
        Adjacency lists ``{node: [neighbour, ...]}``; every neighbour must
        also be a key (as produced by ``make_graph``).
    num_color : int
        Number of colours to try; colours are ``1..num_color``.
    rand : int
        Consecutive visits without a free colour before a random colour is
        chosen.
    depth, max_depth : int
        Current / maximum retry depth. Each retry adds one colour.

    Returns
    -------
    dict or None
        ``{node: colour}`` with colours >= 1, or ``None`` if ``max_depth``
        retries were exhausted.
    """
    threshold = 10000

    if depth >= max_depth:
        print('n-color algorithm exceeded max depth of', max_depth)
        return None

    node = list(graph.keys())
    # Local RNG: same sequence as np.random.seed(depth + 1), but it does not
    # clobber the caller's global NumPy random state.
    np.random.RandomState(depth + 1).shuffle(node)
    queue = deque(node)

    assign = dict.fromkeys(node, 0)   # 0 = not coloured yet (never chosen)
    counter = dict.fromkeys(node, 0)  # consecutive visits without a free colour
    t = 0
    while queue and t < threshold:
        t += 1

        k = queue.popleft()
        counter[k] += 1
        # hist[c] = number of neighbours using colour c; index 0 is blocked
        hist = [1e4] + [0] * num_color
        for p in graph[k]:
            hist[assign[p]] += 1

        if min(hist) == 0:
            # A colour no neighbour uses: take it. No new clash is created.
            assign[k] = hist.index(0)
            counter[k] = 0
            continue

        # Every colour clashes: move to the least-used other colour.
        hist[assign[k]] = 1e4
        minc = hist.index(min(hist))
        if counter[k] == rand:
            counter[k] = 0
            minc = np.random.RandomState(t).randint(1, num_color + 1)

        assign[k] = minc
        # Neighbours now sharing k's colour must be re-coloured.
        for p in graph[k]:
            if assign[p] == minc:
                queue.append(p)

    if queue:
        print('n-color algorithm failed, trying again with num_color=%d, depth=%d'%(num_color+1,depth+1))
        assign = assign_color(graph, num_color + 1, rand, depth + 1, max_depth)

    return assign


#---------------------------------------------

#4-color algorithm based on https://forum.image.sc/t/relabel-with-4-colors-like-map/33564
def do_color_label(label, num_color=4):
    """Recolour a label image so that touching labels get different colours.

    Parameters
    ----------
    label : numpy.ndarray
        2-D integer label image with labels 0 (background), 1, 2, ..., N.
        ``relabel`` produces this format.
    num_color : int
        Number of colours the graph colouring starts with. ``assign_color``
        adds one colour per retry if it cannot finish.

    Returns
    -------
    numpy.ndarray
        Image of the same shape with each label replaced by its colour index.
        Used colours are renumbered to consecutive integers 0, 1, ..., K, with
        0 only for background. Labels touching no other label get the lowest
        non-background colour.

    Raises
    ------
    RuntimeError
        If the graph colouring gives up (``assign_color`` returned ``None``).
    """
    connect = find_connect(label)
    graph = make_graph(connect)
    assign = assign_color(graph, num_color)
    if assign is None:
        raise RuntimeError('graph colouring failed; try a larger num_color')

    # Lookup table label -> colour. Labels not in the graph (isolated cells)
    # keep the default colour 1; the background stays 0.
    lut = np.ones(label.max() + 1, dtype=np.uint8)
    for i, c in assign.items():
        lut[i] = c
    lut[0] = 0

    # Renumber the colours actually used to consecutive integers 0..K.
    _, inverse = np.unique(lut, return_inverse=True)
    lut = inverse.reshape(lut.shape)

    return lut[label]


def relabel(label):
    """Map the distinct values of ``label`` onto consecutive integers 0..N.

    The smallest value becomes 0 (the background when 0 is present), the next
    smallest 1, and so on.

    Returns
    -------
    tuple of numpy.ndarray
        The relabelled image (same shape as ``label``), and the pixel count of
        each new label indexed by the new label value.
    """
    unique_result = np.unique(label, return_inverse=True, return_counts=True)
    inverse, counts = unique_result[1], unique_result[2]
    return inverse.reshape(label.shape), counts



def draw_label_to_overlay(label, color=None):
    """Map an image of colour indices to a colour image via a palette lookup.

    Parameters
    ----------
    label : array_like of int
        Palette indices of any shape, e.g. the output of ``do_color_label``.
    color : array_like, optional
        Palette with one colour triple per row; row ``i`` is used for index
        ``i``. The default palette has 8 entries (indices 0..7), with index 0
        mapped to ``[0, 0, 0]``. Nested lists are accepted.

    Returns
    -------
    numpy.ndarray
        Array of shape ``label.shape + (3,)`` for a 3-column palette.

    Raises
    ------
    IndexError
        If ``label`` contains an index outside the palette.
    """
    if color is None:
        color = np.array([
            [  0,  0,  0],
            [ 77,159,255],
            [  0,255,  0],
            [255,  0,  0],
            [  0,255,255],
            [255,255,  0],
            [255,150,255],
            [  0,  0,255],
        ])

    # np.asarray lets plain lists be used for either argument. No 2-D shape
    # is required, so any label rank works.
    palette = np.asarray(color)
    return palette[np.asarray(label)]



In [None]:
# Default image size in pixels, used for decoding RLE masks.
image_width, image_height = 704, 520

def image_show_norm(image, mode='gray'):
    """Min-max normalise ``image`` to [0, 1) and draw it with ``plt.imshow``.

    mode='gray' draws it with the gray colormap. mode='rgb' reverses the
    channel order of the last axis before drawing (e.g. BGR -> RGB). Any other
    mode draws nothing.
    """
    lo, hi = image.min(), image.max()
    scaled = (image - lo) / (hi - lo + 0.0001)  # epsilon guards a constant image
    if mode == 'gray':
        plt.imshow(scaled, 'gray')
    elif mode == 'rgb':
        plt.imshow(scaled[..., ::-1])
        
        
def rle_decode(rle, width=image_width, height=image_height, fill=1, dtype=np.float32):
    """Decode a run-length-encoded mask string into a ``(height, width)`` array.

    ``rle`` is a whitespace-separated sequence of ``start length`` pairs.
    Starts are 1-based indices into the row-major flattened image. Pixels
    covered by a run are set to ``fill``; all others stay 0.
    """
    tokens = rle.split()
    starts = np.asarray(tokens[0::2], dtype=int) - 1  # 1-based -> 0-based
    stops = starts + np.asarray(tokens[1::2], dtype=int)

    flat = np.zeros(height * width, dtype=dtype)
    for lo, hi in zip(starts, stops):
        flat[lo:hi] = fill
    return flat.reshape(height, width)




Let's try the input image "1c10ee85de67".

In [None]:
# Build an instance label image for one training image. 0 = background, and
# annotation row i becomes label i + 1; later rows overwrite earlier ones
# where masks overlap.
train_df = pd.read_csv('../input/sartorius-cell-instance-segmentation/train.csv')
df = train_df[train_df['id']=='1c10ee85de67'].reset_index(drop=True)
label = np.zeros((image_height,image_width), dtype=np.int32)
for i,d in df.iterrows():
    # Use the builtin `bool`, not `np.bool`, which raises AttributeError on
    # NumPy 1.24-1.26.
    m = rle_decode(d.annotation, fill=True, dtype=bool)
    label[m]=i+1

image_show_norm(label,'gray')

After applying graph coloring:

In [None]:
# Sanity check: labels must be exactly 0, 1, 2, ..., N with no gaps.
assert len(np.unique(label)) == label.max()+1

# Colour the raw cell labels and report which colour indices were used.
label5 = do_color_label(label, num_color=5)
print('input label :', np.unique(label5))

# Grow every cell by up to 20 px before colouring, presumably so that cells
# that are close (not only touching) also get different colours. Then mask
# the colours back to the original cell pixels.
expand = expand_labels(label, distance=20)
expand5 = do_color_label(expand, num_color=5) * (label > 0)
expand5_overlay = draw_label_to_overlay(expand5)
image_show_norm(expand5_overlay, 'rgb')

In [None]:
# Show one figure per colour subset: only the cells that received colour `u`
# stay visible, everything else is blacked out.
print('assigned color :', np.unique(expand5))
for u in np.unique(expand5)[1:]:  # [1:] skips the lowest value, the background 0
    m = expand5_overlay.copy()
    m[expand5 != u] = (0, 0, 0)
    plt.figure()
    image_show_norm(m, 'rgb')
