# Purpose
This notebook works out a solution for cell segmentation.

References:
    https://github.com/CellProfiling/HPA-Cell-Segmentation

In [None]:

import pandas as pd 
import numpy as np

import matplotlib.pyplot as plt
import matplotlib.image as mpimg

from PIL import Image  
from IPython.display import display 


In [None]:
# Install pytorch_zoo from the master branch of the haoxusci GitHub repository.
# NOTE(review): the install tracks a moving branch, not a pinned commit or tag — consider pinning.
!pip install 'git+https://github.com/haoxusci/pytorch_zoo@master#egg=pytorch_zoo'

In [None]:
# Install HPA-Cell-Segmentation from the master-branch archive of the CellProfiling repository.
# NOTE(review): the archive tracks a moving branch, not a pinned commit or tag — consider pinning.
!pip install https://github.com/CellProfiling/HPA-Cell-Segmentation/archive/master.zip

In [None]:
import hpacellseg.cellsegmentator as cellsegmentator
from hpacellseg.utils import label_cell, label_nuclei

# Parameters

In [None]:
# Directory prefix used for both train.csv and the train/ image files
ROOT_PATH = '/kaggle/input/hpa-single-cell-image-classification/'

# Channel names; the position of each name matches the CHANNEL_* index below
CHANNELS = np.array(['red', 'green', 'blue', 'yellow'])

(
    CHANNEL_RED,     # microtubule channels
    CHANNEL_GREEN,   # protein of interest
    CHANNEL_BLUE,    # nuclei channels
    CHANNEL_YELLOW,  # endoplasmic reticulum
) = range(4)

# Number of channels per sample
CHANNEL_SIZE = CHANNELS.shape[0]

# Number of training rows to sample; any value <= -1 keeps the full training set
SAMPLE_SIZE = 5

# Functions

In [None]:
#
# Get array of all images for given sample id
#
def get_images(id):
    """Open the per-channel PNG files of one training sample.

    Args:
        id: Sample identifier used to build each file name
            (``train/<id>_<channel>.png`` under ``ROOT_PATH``).

    Returns:
        A list of PIL images, one per entry of ``CHANNELS`` and in the
        same order as ``CHANNELS``.
    """
    return [
        Image.open(ROOT_PATH + 'train/{}_{}.png'.format(id, channel))
        for channel in CHANNELS
    ]

In [None]:
#
# Get single image that blends all RGBY into RGB
#
def get_blended_image(id):
    """Blend the four channel images of a sample into one RGB image.

    The fourth (yellow) image is merged into the first (red) and second
    (green) images with an element-wise maximum; the third (blue) image is
    used as-is. The three planes are stacked along axis 2 and cast to uint8.

    Args:
        id: Sample identifier, passed straight to ``get_images``.

    Returns:
        A PIL image built from the blended array.
    """
    rgby = get_images(id)
    yellow = rgby[3]

    # one plane per output colour: red+yellow, green+yellow, blue
    planes = [
        np.maximum(rgby[0], yellow),
        np.maximum(rgby[1], yellow),
        rgby[2],
    ]
    stacked = np.stack(planes, 2)

    return Image.fromarray(np.uint8(stacked))

# Read Training Data

In [None]:
# Load the training table, report its row count, and preview the first rows
df_train = pd.read_csv('{}train.csv'.format(ROOT_PATH))
print("Trainning data length: {}".format(df_train.shape[0]))
df_train.head()

In [None]:
# If a sample size is set (anything greater than -1), reduce the training set
# to that many randomly chosen rows.
if SAMPLE_SIZE > -1:
    # Cap at the number of available rows so an oversized SAMPLE_SIZE does not raise.
    df_train = df_train.sample(min(SAMPLE_SIZE, len(df_train)))
    # drop=True discards the old index instead of inserting it as an 'index'
    # column, which keeps the frame's columns unchanged when this cell is re-run.
    df_train = df_train.reset_index(drop=True)

In [None]:
# Paths of the nuclei and cell model files handed to the segmentator
NUC_MODEL = "./nuclei-model.pth"
CELL_MODEL = "./cell-model.pth"

# Single segmentator instance, reused by every prediction cell below
segmentator = cellsegmentator.CellSegmentator(
    NUC_MODEL, CELL_MODEL,
    multi_channel_model=True,
    padding=False,
    device="cuda",
    scale_factor=0.25,
)


# Explore Using HPA-Cell-Segmentation

Take the first image from the samples and walk through steps to segment cells.

In [None]:
# id to use while exploring: the first row of the (sampled) training set
sample_id = df_train['ID'].iloc[0]

# channel images in CHANNELS order (red, green, blue, yellow)
images = get_images(sample_id)

# Nuclei prediction takes a list of images; only the blue (nuclei) channel is passed
nuc_segmentations = segmentator.pred_nuclei([np.asarray( images[CHANNEL_BLUE] )])
print(np.shape(nuc_segmentations))

# Full-cell prediction takes three parallel lists, one per channel:
# red (microtubules), yellow (endoplasmic reticulum), blue (nuclei)
cell_segmentations = segmentator.pred_cells([
        [np.asarray( images[CHANNEL_RED] )],
        [np.asarray( images[CHANNEL_YELLOW] )],
        [np.asarray( images[CHANNEL_BLUE] )]
    ])
print(np.shape(cell_segmentations))
   

In [None]:
# Display the shape of the nuclei prediction output
np.asarray(nuc_segmentations).shape

In [None]:
# Label nuclei from the nuclei prediction alone
nuclei_mask = label_nuclei(nuc_segmentations[0])
print(np.shape(nuclei_mask))

# Label nuclei and cells together from both predictions
cell_nuclei_mask, cell_mask = label_cell(nuc_segmentations[0], cell_segmentations[0])
# report the shape of the mask just produced (previously printed nuclei_mask a second time)
print(np.shape(cell_nuclei_mask))
print(np.shape(cell_mask))

In [None]:
fig = plt.figure(figsize=(25,25))

# Convert each label mask to an 8-bit PIL image.
# cell_image is reused by later cells, so it stays a top-level name.
nuclei_image = Image.fromarray( np.uint8(nuclei_mask) )
cell_nuclei_image = Image.fromarray( np.uint8(cell_nuclei_mask) )
cell_image = Image.fromarray( np.uint8(cell_mask) )

# One panel per mask, left to right, in a 1x4 grid
panels = (
    ("Nuclei Mask", nuclei_image),
    ("Cell Nuclei Mask", cell_nuclei_image),
    ("Cell Mask", cell_image),
)
for position, (title, mask_image) in enumerate(panels, start=1):
    ax = fig.add_subplot(1, 4, position)
    ax.set_title(title)
    plt.imshow(np.asarray(mask_image))


# Isolate Each Segment

Isolate and separate each cell's mask into separate images

In [None]:
# Get unique vector of segment numbers
numbers = set( np.ravel(cell_mask) )
numbers.remove(0)

fig = plt.figure(figsize=(25,6*len(numbers)/4))
index = 1

# plot original cell mask from above
ax = fig.add_subplot((len(numbers)//4)+1, 4, index)
ax.set_title("Complete Cell Mask")
plt.imshow(np.asarray(cell_image))
index = index + 1

for number in numbers:
    # set all other 'numbers' to zero in cell mask
    isolated = np.where(cell_mask == number, cell_mask, 0)

    # plot isolated image
    ax = fig.add_subplot((len(numbers)//4)+1, 4, index)
    ax.set_title("Segment: {}".format(number))

    plt.imshow(isolated)
    index = index + 1

# Crop Cells

Use the mask to cut cells out of the original blended image

In [None]:
blended_image = get_blended_image(sample_id)
blended_array = np.asarray(blended_image)

# Unique segment labels in ascending order, excluding label 0.
# Filtering by value (rather than deleting position 0) never drops a real
# segment when label 0 is absent from the mask.
numbers = np.unique(cell_mask)
numbers = numbers[numbers != 0]

# Two header panels (blended image + complete mask) plus one panel per segment.
# A single grid geometry is used for every subplot; mixing different row
# counts on one figure misplaces or overlaps the axes.
n_columns = 4
n_panels = len(numbers) + 2
n_rows = (n_panels + n_columns - 1) // n_columns  # ceiling division

# build figure
fig = plt.figure(figsize=(25, 6 * n_rows))
index = 1

# plot the blended RGB image
ax = fig.add_subplot(n_rows, n_columns, index)
ax.set_title("Blended Cell Image")
ax.imshow(blended_array)
index = index + 1

# plot original cell mask from above
ax = fig.add_subplot(n_rows, n_columns, index)
ax.set_title("Complete Cell Mask")
ax.imshow(np.asarray(cell_image))
index = index + 1

for number in numbers:
    # boolean mask that is True only on this segment's pixels
    isolated_mask = cell_mask == number

    # add a trailing axis so the mask broadcasts across the colour channels,
    # then zero every pixel outside the segment
    isolated_image = np.where(isolated_mask[..., np.newaxis], blended_array, 0)

    # plot isolated image
    ax = fig.add_subplot(n_rows, n_columns, index)
    ax.set_title("Segment: {}".format(number))
    ax.imshow(isolated_image)
    index = index + 1

# Segment All Samples

Use what was learned above and loop through all samples printing the blended image and the cell mask.

In [None]:
# Maps sample ID -> [nuclei_mask, cell_mask] as returned by label_cell
masks = {}

# Only the ID column is needed, so iterate over it directly; enumerate
# supplies the 1-based progress counter.
for sample_count, current_id in enumerate(df_train['ID'], start=1):
    images = get_images(current_id)

    # "\r" with end="" rewrites the same output line on each iteration
    print("\rSegmentation started (count={}, current_id={})  ".format(sample_count, current_id), end="")

    # Segment nuclei from the blue channel
    nuc_segmentations = segmentator.pred_nuclei([np.asarray( images[CHANNEL_BLUE] )])

    # Segment full cells from the red, yellow and blue channels
    cell_segmentations = segmentator.pred_cells([
        [np.asarray( images[CHANNEL_RED] )],
        [np.asarray( images[CHANNEL_YELLOW] )],
        [np.asarray( images[CHANNEL_BLUE] )]
    ])

    # Turn the raw predictions into labelled nuclei / cell masks
    nuclei_mask, cell_mask = label_cell(nuc_segmentations[0], cell_segmentations[0])

    masks[current_id] = [nuclei_mask, cell_mask]

    print("\rSegmentation completed (count={}, last_id={})   ".format(sample_count, current_id), end="")


In [None]:

COLUMN_COUNT = 5
ROW_COUNT = 50

# Total number of subplot slots in the grid; derived from both constants
# instead of a hard-coded 5 so that changing COLUMN_COUNT stays consistent.
MAX_PANELS = ROW_COUNT * COLUMN_COUNT

index = 1

fig = plt.figure(figsize=(30,5*ROW_COUNT))

for current_id in df_train['ID']:
    # stop as soon as the grid is full (also ends the loop once the inner
    # loop below has used the last slot)
    if index > MAX_PANELS:
        break

    # masks computed by the segmentation loop; only the cell mask is used here
    [nuclei_mask, cell_mask] = masks[current_id]

    # Blended image
    blended_image = get_blended_image(current_id)
    blended_array = np.asarray(blended_image)

    ax = fig.add_subplot(ROW_COUNT, COLUMN_COUNT, index)
    ax.set_title("Blended Image")
    ax.imshow(blended_array)
    index = index + 1

    # Unique segment labels in ascending order, excluding label 0
    # (filter by value so a real segment is never dropped when 0 is absent)
    numbers = np.unique(cell_mask)
    numbers = numbers[numbers != 0]

    for number in numbers:
        if index > MAX_PANELS:
            break

        # Isolate and crop cell from blended image: the boolean mask gets a
        # trailing axis so it broadcasts across the colour channels
        isolated_mask = cell_mask == number
        isolated_image = np.where(isolated_mask[..., np.newaxis], blended_array, 0)

        ax = fig.add_subplot(ROW_COUNT, COLUMN_COUNT, index)
        ax.set_title("Blended Cell")
        ax.imshow(isolated_image)

        index = index + 1


