# Coloring network for instance segmentation

The objective of this notebook is to show the architecture of a coloring network for instance segmentation.
It is based on U-Net and makes use of [1]:

- A seed net that predicts the seed of each cell. This is the instance-detection step: it finds the centroids of the target objects. 

- A color net that propagates the segmentation from each seed to the whole cell. This step replaces the watershed transform, which can be unstable and over-segment thin, long target objects, especially when multiple seeds (i.e. multiple basins) are detected for one object.


[1] https://www.kaggle.com/hengck23/split-adjoining-cell-into-subsets-of-non-touching

In [None]:
import sys
sys.path.append('../input/timm-pytorch-image-models/pytorch-image-models-master')
sys.path.append('../input/coloring-network')


import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.cuda.amp as amp
is_amp = True   

from timm.models.resnet import *
 
import numpy as np
import cv2
import pandas as pd
import matplotlib.pyplot as plt

import skimage.measure
from color_graph import *

# Default image size (width x height) used as the rle_decode defaults and for
# the ground-truth label canvases.
image_width  = 704
image_height = 520

#helper
def image_show(image, mode='gray'):
    """Min-max normalise ``image`` to about [0, 1] and draw it with matplotlib.

    mode='gray' draws with the gray colormap; mode='rgb' first reverses the
    channel (last) axis -- presumably BGR -> RGB for OpenCV-ordered data.
    Any other mode draws nothing.
    """
    lo = image.min()
    hi = image.max()
    scaled = (image - lo) / (hi - lo + 0.0001)  # epsilon avoids 0/0 on a constant image
    if mode == 'gray':
        plt.imshow(scaled, 'gray')
    elif mode == 'rgb':
        plt.imshow(scaled[..., ::-1])
        
        
def rle_decode(rle, width=image_width, height=image_height, fill=1, dtype=np.float32):
    """Decode a run-length-encoded mask string into a (height, width) array.

    ``rle`` is a whitespace-separated list of ``start length`` pairs. ``start``
    is 1-based and indexes the row-major (C-order) flattened image.

    Args:
        rle: run-length string such as "3 2 10 4". An empty string yields an
            all-zero image.
        width, height: output image size.
        fill: value written into every run.
        dtype: dtype of the returned array.

    Returns:
        np.ndarray of shape (height, width) holding ``fill`` on the runs and
        zero elsewhere.

    Raises:
        ValueError: if ``rle`` contains an odd number of tokens.
    """
    tokens = rle.split()
    if len(tokens) % 2:
        # Without this check, a lone start broadcasts against an empty length
        # array and the mask silently decodes as empty.
        raise ValueError('RLE string must contain start/length pairs, got %d tokens' % len(tokens))
    starts = np.asarray(tokens[0::2], dtype=int) - 1  # RLE starts are 1-based
    lengths = np.asarray(tokens[1::2], dtype=int)
    flat = np.zeros(height * width, dtype=dtype)
    for start, end in zip(starts, starts + lengths):
        flat[start:end] = fill
    return flat.reshape(height, width)

In [None]:

class Backbone(nn.Module):
    """timm ResNet-D encoder that returns the outputs of its five stages.

    Args:
        in_channel: number of input channels. When it is not 3, the first stem
            convolution is replaced by a freshly initialised one, so its
            pretrained weights are discarded.
        arch: '18d', '34d' or '50d' (timm resnet18d / resnet34d / resnet50d,
            loaded with pretrained=True).

    Raises:
        ValueError: if ``arch`` is not one of the supported names.
    """

    def __init__(self, in_channel=3, arch='34d'):
        super().__init__()
        if arch == '18d':
            e = resnet18d(pretrained=True)
        elif arch == '34d':
            e = resnet34d(pretrained=True)
        elif arch == '50d':
            e = resnet50d(pretrained=True)
        else:
            # previously fell through and crashed with UnboundLocalError on `e`
            raise ValueError("unsupported arch %r; expected '18d', '34d' or '50d'" % (arch,))

        if in_channel != 3:
            # e.conv1 is indexable here (a multi-conv stem); swap its first conv
            # for one taking `in_channel` inputs, keeping the 32 output channels
            # the rest of the stem expects.
            e.conv1[0] = nn.Conv2d(in_channel, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)

        self.block0 = nn.Sequential(
            e.conv1,
            e.bn1,
            e.act1,
        )
        self.block1 = nn.Sequential(
            e.maxpool,
            e.layer1,
        )
        self.block2 = e.layer2
        self.block3 = e.layer3
        self.block4 = e.layer4
        del e  # drop the classifier head and the rest of the timm model

    def forward(self, x):
        """Run the encoder and return [x0, x1, x2, x3, x4], shallowest first."""
        x0 = self.block0(x)
        x1 = self.block1(x0)
        x2 = self.block2(x1)
        x3 = self.block3(x2)
        x4 = self.block4(x3)
        return [x0, x1, x2, x3, x4]


class FeaturePyramidNet(nn.Module):
    """Top-down feature pyramid over a list of feature maps.

    Args:
        in_channel: channel count of each input level, in the same order as
            the list given to ``forward`` (the top-down pass starts at the last).
        out_channel: channel count of every output level.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()
        self.inner_block = nn.ModuleList()  # 1x1 lateral projections
        self.layer_block = nn.ModuleList()  # 3x3 output convs

        for channels in in_channel:
            self.inner_block.append(nn.Conv2d(channels, out_channel, 1))
            self.layer_block.append(nn.Conv2d(out_channel, out_channel, 3, padding=1))

        # extra, coarsest level: 2x2 max-pool of the last output
        self.extra_max_pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # re-initialise every conv: kaiming-uniform weights, zero bias
        convs = [m for m in self.modules() if isinstance(m, nn.Conv2d)]
        for conv in convs:
            nn.init.kaiming_uniform_(conv.weight, a=1)
            nn.init.constant_(conv.bias, 0)

    def forward(self, x):
        """Return one map per input level, plus a max-pooled extra level.

        Starting from the last input, the running lateral sum is upsampled
        (nearest) to each earlier level's size, added to that level's
        projection, and passed through the level's 3x3 conv. Outputs keep the
        input order. The pooled copy of the last output is appended at the end.
        """
        running = self.inner_block[-1](x[-1])
        pyramid = [self.layer_block[-1](running)]

        for level in reversed(range(len(x) - 1)):
            lateral = self.inner_block[level](x[level])
            upsampled = F.interpolate(running, size=lateral.shape[-2:], mode='nearest')
            running = lateral + upsampled
            pyramid.insert(0, self.layer_block[level](running))

        pyramid.append(self.extra_max_pool(pyramid[-1]))
        return pyramid

#----------------------------------------------------------------------------------------
class SeedNet(nn.Module):
    """Seed-prediction network producing per-pixel class logits.

    Pipeline:
    1. Backbone('50d').
    2. FeaturePyramidNet with 64 channels per level.
    3. Every pyramid level resized (nearest) to the first level's size and
       concatenated.
    4. Fusion head: two 3x3 convs, a 2x bilinear upsample, then one more
       3x3 conv.
    5. A 1x1 conv to 4 logits per pixel (3 cell types + background, per the
       author's note).
    """

    def __init__(self):
        super().__init__()

        self.backbone = Backbone(arch='50d')
        # channel widths of the five Backbone('50d') outputs
        self.fpn = FeaturePyramidNet([64, 256, 512, 1024, 2048], 64)
        # The FPN returns 5 levels plus 1 extra pooled level, 64 channels each.
        # bias=False: Conv2d's `bias` is a bool, and these convs feed BatchNorm.
        self.fuse = nn.Sequential(
            nn.Conv2d(64 * 6, 256, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.SiLU(),
            nn.Conv2d(256, 128, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.SiLU(),

            nn.UpsamplingBilinear2d(scale_factor=2),
            nn.Conv2d(128, 64, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.SiLU(),
        )
        self.seed = nn.Sequential(
            nn.Conv2d(64, 4, kernel_size=1, padding=0),
        )  # 3 cell_type + background

    def forward(self, input):
        """Predict seed logits.

        Args:
            input: dict with 'x', an image batch (B, 3, H, W). It has 3
                channels because the backbone uses the default in_channel=3.

        Returns:
            dict with 'seed': logits (B, 4, h, w), where (h, w) is twice the
            size of the first FPN level.
        """
        x = input['x']

        feature = self.backbone(x)
        feature = self.fpn(feature)

        # resize every pyramid level to the first (finest) level and stack on channels
        size = feature[0].shape[-2:]
        z = [feature[0]] + [F.interpolate(f, size=size, mode='nearest') for f in feature[1:]]
        z = torch.cat(z, 1)  # (B, 64 * 6, *size)
        z = self.fuse(z)

        seed = self.seed(z)
        output = {
            'seed': seed,
        }
        return output


class ColorNet(nn.Module):
    """Colour (region-growing) network: grows each seed channel into a cell mask.

    Each image is concatenated with one seed channel at a time. The result is
    a 4-channel input to a Backbone('34d', in_channel=4), followed by an FPN
    with 32 channels per level, the same fusion head as SeedNet, and a 1x1
    conv giving one logit per pixel.
    """

    def __init__(self):
        super().__init__()

        self.backbone = Backbone(in_channel=3 + 1, arch='34d')  # image + one seed channel
        # channel widths of the five Backbone('34d') outputs
        self.fpn = FeaturePyramidNet([64, 64, 128, 256, 512], 32)
        # The FPN returns 5 levels plus 1 extra pooled level, 32 channels each.
        # bias=False: Conv2d's `bias` is a bool, and these convs feed BatchNorm.
        self.fuse = nn.Sequential(
            nn.Conv2d(32 * 6, 256, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.SiLU(),
            nn.Conv2d(256, 128, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.SiLU(),

            nn.UpsamplingBilinear2d(scale_factor=2),
            nn.Conv2d(128, 64, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.SiLU(),
        )
        self.color = nn.Sequential(
            nn.Conv2d(64, 1, kernel_size=1, padding=0),
        )

    def forward(self, input):
        """Predict one mask logit map per seed channel.

        Args:
            input: dict with
                'x': image batch; it is reshaped below to (B, 1, 3, H, W), so
                    it must have 3 channels and the same H, W as the seeds.
                'sampled_seed': (B, num_color, H, W) seed maps, one per colour.

        Returns:
            dict with 'color': logits (B, num_color, H, W). The final reshape
            requires the head's output size to equal the input H, W.
        """
        x = input['x']
        seed = input['sampled_seed']
        batch_size, num_color, H, W = seed.shape

        # pair the image with every seed channel -> (B * num_color, 4, H, W)
        x = x.reshape(batch_size, 1, 3, H, W).expand(-1, num_color, -1, -1, -1)
        seed = seed.reshape(batch_size, num_color, 1, H, W)
        x = torch.cat([x, seed], 2).reshape(batch_size * num_color, -1, H, W)

        feature = self.backbone(x)
        feature = self.fpn(feature)

        # resize every pyramid level to the first (finest) level and stack on channels
        size = feature[0].shape[-2:]
        z = [feature[0]] + [F.interpolate(f, size=size, mode='nearest') for f in feature[1:]]
        z = torch.cat(z, 1)  # (B * num_color, 32 * 6, *size)
        z = self.fuse(z)

        color = self.color(z)
        color = color.reshape(batch_size, num_color, H, W)
        output = {
            'color': color,
        }
        return output


Let's apply a learned model to the training images (**not validation**) to inspect the results. Can the solution learn the complex details of the instance segmentation?

If the results look good, we can proceed to the next stage: generalising the network to unseen validation images.

In [None]:
# load the learned model (both networks share one checkpoint file)
seed_net = SeedNet()
color_net = ColorNet()

model_file = '../input/coloring-network/00024500.model.pth'
f = torch.load(model_file, map_location='cpu')  # load every tensor onto the CPU
for net, key in ((seed_net, 'seed_state_dict'), (color_net, 'color_state_dict')):
    net.load_state_dict(f[key], strict=True)
print('load model ok !')


In [None]:
#load some images

image_id = [
    '11c2e4fcac6d',  # astro example
    '7ad870da5a63',  # cort example
    '1c10ee85de67',  # shsy5y example
]

image = []
for image_name in image_id:  # not `id`, which would shadow the builtin
    image_file = '../input/sartorius-cell-instance-segmentation/train/%s.png' % (image_name)
    m = cv2.imread(image_file, cv2.IMREAD_GRAYSCALE)
    if m is None:
        # cv2.imread returns None instead of raising on a missing/unreadable file
        raise FileNotFoundError(image_file)
    image.append(m)

# change to torch tensor: (B, 1, H, W) float in [0, 1]
x = torch.from_numpy(np.stack(image)).unsqueeze(1)
x = x.float() / 255

# Centre on the median intensity of the whole batch, replicate to 3 channels
# with contrast gains 2 / 4 / 16, and squash into (-1, 1) with tanh.
# NOTE(review): presumably this mirrors the training preprocessing -- confirm.
mx = torch.median(x)
x = x - mx
x = x.repeat(1, 3, 1, 1)
x[:, 0] *= 2
x[:, 1] *= 4
x[:, 2] *= 16
x = torch.tanh(x)

input = {
    'x': x
}


In [None]:
#apply seed net and visualise results

# inference only: eval-mode BatchNorm, no autograd graph, optional mixed precision
seed_net.eval()
with torch.no_grad(), amp.autocast(enabled=is_amp):
    output = seed_net(input)

    
def show_predicted_seed(output, input):
    """Plot each image with its predicted non-background seed pixels marked.

    Args:
        output: dict with 'seed' logits; class is taken as the argmax over dim 1.
        input: dict with 'x', the tanh-squashed image batch (B, 3, H, W).
    """
    images = input['x'].data.cpu().numpy() * 0.5 + 0.5  # (-1, 1) -> (0, 1)
    images = np.ascontiguousarray(np.transpose(images, (0, 2, 3, 1)))  # -> (B, H, W, 3)

    probs = torch.softmax(output['seed'], 1).float().data.cpu().numpy()
    labels = probs.argmax(axis=1)

    for b in range(len(images)):
        canvas = images[b].copy()
        canvas[labels[b] != 0] = (1, 0, 0)  # mark every seed pixel with a solid colour
        plt.figure(figsize=(12, 9))
        image_show(canvas, 'rgb')

show_predicted_seed(output, input) 

In [None]:
#split seeds
num_color=8 #exclude background (which is zero)
def split_seed_and_add_to_input(output, input):
    """Split predicted seeds into graph-coloured channels and store them in ``input``.

    Each connected component of the non-background seed mask is one seed.
    Per image, seeds are grown with ``expand_labels`` (distance 20) so that
    nearby seeds touch. The grown label map is then coloured with
    ``do_color_label``, and every seed pixel takes its component's colour.

    Args:
        output: dict from SeedNet with 'seed' logits indexed as (B, C, H, W).
        input: dict that receives 'sampled_seed', a float one-hot tensor
            (B, num_color + 1, H, W) on input['x']'s device. Channel 0 is
            "no seed".

    Returns:
        The same ``input`` dict, mutated in place.
    """
    s = torch.argmax(output['seed'], 1, keepdim=True) != 0
    s = s.data.cpu().numpy()

    sampled_seed = []
    for b in range(len(s)):
        # Label and expand each image on its own 2-D grid. Running these on the
        # whole (B, 1, H, W) array would connect components, and spread expanded
        # labels, across neighbouring images of the batch.
        s_l = skimage.measure.label(s[b, 0], background=0)  # optional: filter invalid seeds (e.g. too small)
        expand = expand_labels(s_l, distance=20)
        e_l = do_color_label(expand, num_color=num_color)
        print('assigned color', np.unique(e_l))
        # Discard colours outside 1..num_color (as the ground-truth splitter
        # does) so F.one_hot below cannot fail.
        e_l = np.where(e_l > num_color, 0, e_l)
        e_l = e_l * (s_l > 0)  # keep colours only on the seed pixels themselves
        sampled_seed.append(e_l)

    sampled_seed = torch.from_numpy(np.stack(sampled_seed)).long().to(input['x'].device)
    # one_hot already sets channel 0 to 1 exactly where no colour was assigned
    sampled_seed = F.one_hot(sampled_seed, num_color + 1).float()
    sampled_seed = sampled_seed.permute(0, 3, 1, 2).contiguous()
    print('sampled_seed.shape :', sampled_seed.shape)

    input['sampled_seed'] = sampled_seed.detach()
    return input
        
def show_sampled_seed(input):
    """Plot every channel of input['sampled_seed'] side by side, one figure per image.

    Channels are bilinearly resized to 128x128. Each tile gets a one-pixel
    border of value 0.5 so neighbouring tiles are easy to tell apart.
    """
    tiles = F.interpolate(input['sampled_seed'], size=(128, 128), mode='bilinear', align_corners=False)
    tiles = tiles.data.cpu().numpy()
    batch_size, num_color, _, _ = tiles.shape

    for b in range(batch_size):
        strip = []
        for i in range(num_color):
            tile = tiles[b, i]  # view: the border is drawn in place
            tile[0] = 0.5
            tile[-1] = 0.5
            tile[:, 0] = 0.5
            tile[:, -1] = 0.5
            strip.append(tile)
        plt.figure(figsize=(24, 18))
        image_show(np.hstack(strip), 'gray')


with torch.no_grad():
    input = split_seed_and_add_to_input(output, input)

show_sampled_seed(input)

In [None]:
#apply color net and visualise results
# inference only: eval-mode BatchNorm, no autograd graph, optional mixed precision
color_net.eval()
with torch.no_grad(), amp.autocast(enabled=is_amp):
    output1 = color_net(input)


In [None]:
def show_predicted_color(output, input):
    """Visualise ColorNet predictions; two figures per image.

    1. Every softmax colour channel, resized to 128x128 with a 0.5-valued
       border, tiled horizontally.
    2. The per-pixel argmax channel rendered through ``colormap``.

    ``input`` is not read; it is kept so existing calls still work.
    """
    # One colour per channel index. Presumably BGR, since image_show(..., 'rgb')
    # reverses the channel order before drawing.
    colormap = np.array([
        [0, 0, 0],
        [77, 159, 255],
        [0, 255, 0],
        [255, 0, 0],
        [0, 255, 255],
        [255, 255, 0],
        [255, 150, 255],
        [0, 0, 255],
        [255, 255, 255],
    ])

    prob = torch.softmax(output['color'], 1)
    winner = prob.argmax(1).data.cpu().numpy()
    thumbs = F.interpolate(prob, size=(128, 128), mode='bilinear', align_corners=False)
    thumbs = thumbs.float().data.cpu().numpy()

    batch_size, num_color = prob.shape[0], prob.shape[1]
    for b in range(batch_size):
        tiles = []
        for i in range(num_color):
            framed = thumbs[b, i].copy()
            framed[0] = framed[-1] = framed[:, 0] = framed[:, -1] = 0.5
            tiles.append(framed)
        plt.figure(figsize=(24, 18))
        image_show(np.hstack(tiles), 'gray')

        rendered = draw_label_to_overlay(winner[b], colormap)
        plt.figure(figsize=(12, 9))
        image_show(rendered, 'rgb')

show_predicted_color(output1, input) 

Note that the yellow seeds (color layer 4) merge together after region growing, so the seed-splitting algorithm still needs to be improved. 

Different seed splittings can also be used as test-time augmentation (TTA) at inference time.

In [None]:
# show region growing for a few color channels

def show_region_growing(output, input, channels=(2, 3, 4)):
    """Overlay selected ColorNet probability channels on the input images.

    Each image gets one figure with one panel per entry of ``channels``:
    - the channel's softmax probability is screen-blended into the green plane;
    - pixels of any sampled seed are set to (1, 0, 0);
    - a one-pixel border of value 1 separates the panels.

    Args:
        output: dict with 'color' logits indexed as (B, K, H, W).
        input: dict with 'x' (B, 3, H, W), tanh-squashed, and a one-hot
            'sampled_seed' (B, K', H, W).
        channels: colour-channel indices to show. Each must be smaller than
            output['color'].shape[1]. The default matches the previously
            hard-coded channels.
    """
    color = torch.softmax(output['color'], 1)
    color = color.float().data.cpu().numpy()

    x = input['x'].data.cpu().numpy() * 0.5 + 0.5  # (-1, 1) -> (0, 1)
    x = np.ascontiguousarray(x.transpose(0, 2, 3, 1))  # -> (B, H, W, 3)

    # The seeds are one-hot, so take the argmax directly; a softmax first
    # would not change it.
    seed_argmax = np.argmax(input['sampled_seed'].float().data.cpu().numpy(), 1)

    batch_size = color.shape[0]
    for b in range(batch_size):
        overlay = []
        for i in channels:
            o = x[b].copy()
            # screen blend: 1 - (1 - a) * (1 - b)
            o[..., 1] = 1 - (1 - o[..., 1]) * (1 - color[b, i])
            o[seed_argmax[b] != 0] = (1, 0, 0)
            o[:, 0] = o[:, -1] = o[0] = o[-1] = 1  # panel border
            overlay.append(o)
        overlay = np.hstack(overlay)
        plt.figure(figsize=(24, 18)), image_show(overlay, 'rgb')

show_region_growing(output1, input) 

Now let's try the "ground-truth splitting", where the seeds are colored using the annotated instances.

In [None]:
#ground truth
train_df = pd.read_csv('../input/sartorius-cell-instance-segmentation/train.csv')

# One integer instance map per image: 0 = background, k = k-th annotation row.
label = []
for image_name in image_id:  # not `id`, which would shadow the builtin
    df = train_df[train_df['id'] == image_name].reset_index(drop=True)
    l = np.zeros((image_height, image_width), dtype=np.int32)
    for i, rle in enumerate(df['annotation']):
        # `np.bool` was removed in NumPy 1.24; the builtin `bool` is the same dtype.
        m = rle_decode(rle, fill=True, dtype=bool)
        l[m] = i + 1  # where annotations overlap, the later row wins
    label.append(l)

label = torch.from_numpy(np.stack(label)).long().unsqueeze(1)  # (B, 1, H, W)
target = {
    'label': label,
}
    

#split seeds based on ground truth (e.g. use for training)
num_color=8 #exclude background (which is zero)
def split_seed_and_add_to_input_and_target(output, input, target):
    """Colour predicted seeds using the ground-truth instances (training-style split).

    Steps:
    1. Seed pixels (non-background argmax of output['seed']) take the id of
       the ground-truth instance they fall in.
    2. Instances that receive no seed pixel are dropped.
    3. The remaining instances are relabelled to consecutive ids and grown
       with ``expand_labels`` (distance 30).
    4. The grown map is coloured with ``do_color_label``; colours above
       ``num_color`` are discarded.

    Args:
        output: dict from SeedNet with 'seed' logits indexed as (B, C, H, W).
        input: dict that receives 'sampled_seed' (B, num_color + 1, H, W),
            the one-hot colour of every seed pixel.
        target: dict with an integer 'label' map (B, 1, H, W). It receives
            'sampled_color' (B, num_color + 1, H, W), the one-hot colour of
            every pixel of each kept instance. target['label'] itself is
            left unmodified.

    Returns:
        (input, target), both mutated in place as described above.
    """
    device = input['x'].device
    s = (torch.argmax(output['seed'], 1, keepdim=True) != 0) * target['label']

    s = s.data.cpu().numpy()
    # copy: .numpy() shares memory with a CPU tensor, and the loop below edits l
    l = target['label'].data.cpu().numpy().copy()

    sampled_seed = []
    sampled_color = []
    for b in range(len(s)):
        s0 = s[b, 0]
        l0 = l[b, 0]
        # drop instances that no seed pixel landed on (s0 holds only ids from l0, or 0)
        l0[~np.isin(l0, np.unique(s0))] = 0

        # relabel the remaining ids to consecutive integers 0..K
        l1 = np.unique(l0, return_inverse=True)[1].reshape(l0.shape)
        expand = expand_labels(l1, distance=30)
        l2 = do_color_label(expand, num_color=num_color)
        l2[l2 > num_color] = 0  # colours beyond the one-hot range are discarded
        l2 = l2 * (l0 > 0)      # colour only ground-truth instance pixels
        s2 = l2 * (s0 > 0)      # the seed input keeps only the seed pixels

        sampled_seed.append(s2)
        sampled_color.append(l2)

    sampled_seed = torch.from_numpy(np.stack(sampled_seed)).long().to(device)
    sampled_color = torch.from_numpy(np.stack(sampled_color)).long().to(device)
    sampled_seed = F.one_hot(sampled_seed, num_color + 1).float()
    sampled_color = F.one_hot(sampled_color, num_color + 1).float()

    input['sampled_seed'] = sampled_seed.permute(0, 3, 1, 2).contiguous()
    target['sampled_color'] = sampled_color.permute(0, 3, 1, 2).contiguous()
    return input, target


input,target = split_seed_and_add_to_input_and_target(output, input, target)

In [None]:
#apply color net and visualise results
# inference only: eval-mode BatchNorm, no autograd graph, optional mixed precision
color_net.eval()
with torch.no_grad(), amp.autocast(enabled=is_amp):
    output1 = color_net(input)

'''
for training, apply loss like:
color_loss(predict=output1['color'], truth=target['sampled_color'])

'''

In [None]:
def show_predicted_color(output, input, target):
    """Compare ColorNet predictions with the ground-truth colour split.

    Each image gets two figures:
    - every colour channel resized to 128x128 with a 0.5-valued border,
      target (top) stacked over prediction (bottom);
    - the per-pixel argmax of target (left) and prediction (right), rendered
      through ``colormap`` with a (255, 255, 255) border.

    ``input`` is not read; it is kept so existing calls still work.
    """
    # One colour per channel index. Presumably BGR, since image_show(..., 'rgb')
    # reverses the channel order before drawing.
    colormap = np.array([
        [0, 0, 0],
        [77, 159, 255],
        [0, 255, 0],
        [255, 0, 0],
        [0, 255, 255],
        [255, 255, 0],
        [255, 150, 255],
        [0, 0, 255],
        [255, 255, 255],
    ])

    def _thumbnails(maps):
        # bilinear 128x128 thumbnails of every channel, as numpy
        small = F.interpolate(maps, size=(128, 128), mode='bilinear', align_corners=False)
        return small.float().data.cpu().numpy()

    def _framed(tile):
        # copy of `tile` with a one-pixel border of value 0.5
        out = tile.copy()
        out[:, 0] = out[:, -1] = out[0] = out[-1] = 0.5
        return out

    prob = torch.softmax(output['color'], 1)
    pred_label = prob.argmax(1).data.cpu().numpy()
    pred_small = _thumbnails(prob)

    truth = target['sampled_color']
    truth_label = truth.argmax(1).data.cpu().numpy()
    truth_small = _thumbnails(truth)

    batch_size, num_color = prob.shape[0], prob.shape[1]
    for b in range(batch_size):
        columns = [
            np.vstack([_framed(truth_small[b, i]), _framed(pred_small[b, i])])
            for i in range(num_color)
        ]
        plt.figure(figsize=(24, 18))
        image_show(np.hstack(columns), 'gray')

        truth_rgb = draw_label_to_overlay(truth_label[b], colormap)
        pred_rgb = draw_label_to_overlay(pred_label[b], colormap)
        for canvas in (truth_rgb, pred_rgb):
            canvas[:, 0] = canvas[:, -1] = canvas[0, :] = canvas[-1, :] = (255, 255, 255)
        plt.figure(figsize=(24, 18))
        image_show(np.hstack([truth_rgb, pred_rgb]), 'rgb')


show_predicted_color(output1, input, target)