In [None]:
import numpy as np 
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from PIL import Image

import os

!pip install https://github.com/CellProfiling/HPA-Cell-Segmentation/archive/master.zip

In [None]:
def get_blended_image(images):
    """Build a 3-channel preview image from a list of single-channel arrays.

    Every element of ``images`` except the last is stacked along a new third
    axis, so with the notebook's channel order (red, blue, yellow, green) the
    composite shows red as R, blue as G and yellow as B; green is dropped.

    Parameters
    ----------
    images : sequence of numpy.ndarray
        Per-channel pixel arrays of identical shape.

    Returns
    -------
    PIL.Image.Image
        The stacked array cast to ``uint8`` and wrapped as a PIL image.
    """
    colour_planes = images[:-1]
    stacked = np.stack(colour_planes, axis=2)
    as_bytes = np.uint8(stacked)
    return Image.fromarray(as_bytes)

def image_to_arrays(path):
    """Read each image file in ``path`` into a NumPy array.

    Parameters
    ----------
    path : iterable of str or path-like
        Image file paths. Despite the singular name, a collection is expected,
        e.g. one sample's four channel PNGs.

    Returns
    -------
    list of numpy.ndarray
        One array per file, in the same order as ``path``. An empty iterable
        yields an empty list.
    """
    image_arrays = []
    for image_path in path:
        # The context manager closes the file handle as soon as the pixels
        # have been copied out, instead of leaving it to Pillow/GC.
        with Image.open(image_path) as img:
            image_arrays.append(np.asarray(img))
    return image_arrays

In [None]:
import hpacellseg.cellsegmentator as cellsegmentator
from hpacellseg.utils import label_cell, label_nuclei

# Pretrained weights for the HPA nuclei and cell segmentation networks.
NUC_MODEL = "../input/hpacellsegmentatormodelweights/dpn_unet_nuclei_v1.pth"
CELL_MODEL = "../input/hpacellsegmentatormodelweights/dpn_unet_cell_3ch_v1.pth"

# One shared segmentator, reused by get_masks and the single-image cells below.
# multi_channel_model=True: pred_cells is fed three channel lists (see get_masks).
segmentator = cellsegmentator.CellSegmentator(
    NUC_MODEL,
    CELL_MODEL,
    device="cuda",
    multi_channel_model=True,
    scale_factor=0.25,
    padding=False,
)

def get_masks(imgs, test=True):
    """Segment a batch of multi-channel images into per-cell label masks.

    Parameters
    ----------
    imgs : sequence of numpy.ndarray
        Each element is indexed as ``img[:, :, c]``. Channel 0 is used as
        microtubules, channel 3 as ER and channel 2 as nuclei (presumably an
        RGBY stack; TODO confirm against the caller).
    test : bool
        Unused; kept for backward compatibility with existing callers.

    Returns
    -------
    list
        One cell mask per input image, as returned by ``label_cell``.

    Raises
    ------
    ValueError
        If segmentation or labelling fails. The original exception is chained
        as ``__cause__`` so the real failure stays visible.
    """
    try:
        microtubules = [img[:, :, 0] for img in imgs]
        er = [img[:, :, 3] for img in imgs]
        nuclei = [img[:, :, 2] for img in imgs]

        nuc_segmentations = segmentator.pred_nuclei(nuclei)
        cell_segmentations = segmentator.pred_cells([microtubules, er, nuclei])

        # label_cell returns (nuclei_mask, cell_mask); only the cell mask is kept.
        # (The previous tqdm progress bar was never imported and made every
        # call fail, so it is dropped here.)
        return [
            label_cell(nuc_seg, cell_seg)[1]
            for nuc_seg, cell_seg in zip(nuc_segmentations, cell_segmentations)
        ]
    except Exception as exc:
        # Catch Exception, not everything: KeyboardInterrupt/SystemExit must propagate.
        raise ValueError('Segmentation failed') from exc

In [None]:
# File-name suffixes of the four channel PNGs per sample. Order matters
# downstream: get_blended_image drops the last entry.
channels = ['_red.png', '_blue.png', '_yellow.png', '_green.png']
train_label = pd.read_csv('../input/hpa-single-cell-image-classification/train.csv')
train_data = '../input/hpa-single-cell-image-classification/train'

# One list of channel file paths per CSV row, keyed by the row's first column.
paths = [
    [os.path.join(train_data, sample_id) + suffix for suffix in channels]
    for sample_id in train_label.iloc[:, 0]
]

In [None]:
# Preview sample #1: `image` is its list of four channel file paths and
# `array` the matching list of pixel arrays. The composite stacks all but the
# last channel (red, blue, yellow -> R, G, B).
image = paths[1]
array = image_to_arrays(image)
blended_image = get_blended_image(array)
plt.imshow(blended_image)

In [None]:
# Segment the sample loaded above; `array` follows `channels`:
# index 0 = red, 1 = blue, 2 = yellow, 3 = green.
# The HPA cell model takes three channel lists in the order
# [microtubules (red), ER (yellow), nuclei (blue)], the same order get_masks
# uses. The previous `image[:-1]` fed blue as ER and yellow as nuclei.
nuclei = array[1]                        # blue channel
cell = [array[0], array[2], array[1]]    # red, yellow, blue
inter_step = [[channel] for channel in cell]  # one single-image list per channel

# For nuclei
nuc_segmentations = segmentator.pred_nuclei([nuclei])

# For full cells
cell_segmentations = segmentator.pred_cells(inter_step)

In [None]:
# post-processing
# Post-processing: label_cell turns the raw predictions for the single sample
# into a nuclei mask and a cell mask of integer labels (later cells treat 0
# as background).
nuclei_mask, cell_mask = label_cell(
    nuc_segmentations[0],
    cell_segmentations[0],
)
plt.imshow(cell_mask)

In [None]:
# Unique vector of cell_mask numbers
# Show every labelled cell on its own panel in a fixed-width grid.
# np.unique is vectorised and sorted (a Python set over every pixel was slow
# and had no guaranteed order); 0 is background and is skipped, which also
# avoids a KeyError when no background pixel exists.
N_COLS = 5
numbers = [int(label) for label in np.unique(cell_mask) if label != 0]
n_rows = max(1, -(-len(numbers) // N_COLS))  # ceiling division, at least one row

fig, axes = plt.subplots(n_rows, N_COLS, figsize=(25, 6 * n_rows), squeeze=False)
flat_axes = axes.ravel()
for ax, number in zip(flat_axes, numbers):
    isolated_cell = np.where(cell_mask == number, cell_mask, 0)
    ax.imshow(isolated_cell)
    ax.set_title(number, size=20)
# Blank out grid slots that have no cell to show.
for ax in flat_axes[len(numbers):]:
    ax.axis("off")

In [None]:
# Cell 1 alone: copy the mask and zero every pixel carrying another label.
isolated_cell = cell_mask.copy()
isolated_cell[isolated_cell != 1] = 0
plt.imshow(isolated_cell)

In [None]:
# Cell 2 alone: copy the mask and zero every pixel carrying another label.
isolated_cell = cell_mask.copy()
isolated_cell[isolated_cell != 2] = 0
plt.imshow(isolated_cell)

In [None]:
# Cell 3 alone: copy the mask and zero every pixel carrying another label.
isolated_cell = cell_mask.copy()
isolated_cell[isolated_cell != 3] = 0
plt.imshow(isolated_cell)

In [None]:
# Cell 4 alone: copy the mask and zero every pixel carrying another label.
isolated_cell = cell_mask.copy()
isolated_cell[isolated_cell != 4] = 0
plt.imshow(isolated_cell)

In [None]:
# Cell 5 alone: copy the mask and zero every pixel carrying another label.
isolated_cell = cell_mask.copy()
isolated_cell[isolated_cell != 5] = 0
plt.imshow(isolated_cell)

In [None]:
# Cell 6 alone: copy the mask and zero every pixel carrying another label.
isolated_cell = cell_mask.copy()
isolated_cell[isolated_cell != 6] = 0
plt.imshow(isolated_cell)

In [None]:
# Cell 7 alone: copy the mask and zero every pixel carrying another label.
isolated_cell = cell_mask.copy()
isolated_cell[isolated_cell != 7] = 0
plt.imshow(isolated_cell)

In [None]:
# Cell 8 alone: copy the mask and zero every pixel carrying another label.
isolated_cell = cell_mask.copy()
isolated_cell[isolated_cell != 8] = 0
plt.imshow(isolated_cell)

In [None]:
# Cell 9 alone: copy the mask and zero every pixel carrying another label.
isolated_cell = cell_mask.copy()
isolated_cell[isolated_cell != 9] = 0
plt.imshow(isolated_cell)

In [None]:
# Cell 10 alone: copy the mask and zero every pixel carrying another label.
isolated_cell = cell_mask.copy()
isolated_cell[isolated_cell != 10] = 0
plt.imshow(isolated_cell)

In [None]:
# Cell 11 alone: copy the mask and zero every pixel carrying another label.
isolated_cell = cell_mask.copy()
isolated_cell[isolated_cell != 11] = 0
plt.imshow(isolated_cell)

In [None]:
# Cell 12 alone: copy the mask and zero every pixel carrying another label.
isolated_cell = cell_mask.copy()
isolated_cell[isolated_cell != 12] = 0
plt.imshow(isolated_cell)

In [None]:
# Cell 13 alone: copy the mask and zero every pixel carrying another label.
isolated_cell = cell_mask.copy()
isolated_cell[isolated_cell != 13] = 0
plt.imshow(isolated_cell)

In [None]:
# Cell 14 alone: copy the mask and zero every pixel carrying another label.
isolated_cell = cell_mask.copy()
isolated_cell[isolated_cell != 14] = 0
plt.imshow(isolated_cell)

In [None]:
# Cell 15 alone: copy the mask and zero every pixel carrying another label.
isolated_cell = cell_mask.copy()
isolated_cell[isolated_cell != 15] = 0
plt.imshow(isolated_cell)